Enable basic type checking for recorder (#52440)
* Enable basic type checking for recorder * Tweak
This commit is contained in:
parent
960684346f
commit
19d3aa71ad
7 changed files with 108 additions and 39 deletions
|
@ -324,7 +324,7 @@ class PerodicCleanupTask:
|
|||
class StatisticsTask(NamedTuple):
|
||||
"""An object to insert into the recorder queue to run a statistics task."""
|
||||
|
||||
start: datetime.datetime
|
||||
start: datetime
|
||||
|
||||
|
||||
class WaitTask:
|
||||
|
@ -358,7 +358,7 @@ class Recorder(threading.Thread):
|
|||
self.db_url = uri
|
||||
self.db_max_retries = db_max_retries
|
||||
self.db_retry_wait = db_retry_wait
|
||||
self.async_db_ready = asyncio.Future()
|
||||
self.async_db_ready: asyncio.Future = asyncio.Future()
|
||||
self.async_recorder_ready = asyncio.Event()
|
||||
self._queue_watch = threading.Event()
|
||||
self.engine: Any = None
|
||||
|
@ -370,8 +370,8 @@ class Recorder(threading.Thread):
|
|||
self._timechanges_seen = 0
|
||||
self._commits_without_expire = 0
|
||||
self._keepalive_count = 0
|
||||
self._old_states = {}
|
||||
self._pending_expunge = []
|
||||
self._old_states: dict[str, States] = {}
|
||||
self._pending_expunge: list[States] = []
|
||||
self.event_session = None
|
||||
self.get_session = None
|
||||
self._completed_first_database_setup = None
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
"""Models for SQLAlchemy."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
from typing import TypedDict
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
|
@ -206,6 +210,17 @@ class States(Base): # type: ignore
|
|||
return None
|
||||
|
||||
|
||||
class StatisticData(TypedDict, total=False):
|
||||
"""Statistic data class."""
|
||||
|
||||
mean: float
|
||||
min: float
|
||||
max: float
|
||||
last_reset: datetime | None
|
||||
state: float
|
||||
sum: float
|
||||
|
||||
|
||||
class Statistics(Base): # type: ignore
|
||||
"""Statistics."""
|
||||
|
||||
|
@ -230,7 +245,7 @@ class Statistics(Base): # type: ignore
|
|||
sum = Column(Float())
|
||||
|
||||
@staticmethod
|
||||
def from_stats(metadata_id, start, stats):
|
||||
def from_stats(metadata_id: str, start: datetime, stats: StatisticData):
|
||||
"""Create object from a statistics."""
|
||||
return Statistics(
|
||||
metadata_id=metadata_id,
|
||||
|
@ -239,6 +254,14 @@ class Statistics(Base): # type: ignore
|
|||
)
|
||||
|
||||
|
||||
class StatisticMetaData(TypedDict, total=False):
|
||||
"""Statistic meta data class."""
|
||||
|
||||
unit_of_measurement: str | None
|
||||
has_mean: bool
|
||||
has_sum: bool
|
||||
|
||||
|
||||
class StatisticsMeta(Base): # type: ignore
|
||||
"""Statistics meta data."""
|
||||
|
||||
|
@ -251,7 +274,13 @@ class StatisticsMeta(Base): # type: ignore
|
|||
has_sum = Column(Boolean)
|
||||
|
||||
@staticmethod
|
||||
def from_meta(source, statistic_id, unit_of_measurement, has_mean, has_sum):
|
||||
def from_meta(
|
||||
source: str,
|
||||
statistic_id: str,
|
||||
unit_of_measurement: str | None,
|
||||
has_mean: bool,
|
||||
has_sum: bool,
|
||||
) -> StatisticsMeta:
|
||||
"""Create object from meta data."""
|
||||
return StatisticsMeta(
|
||||
source=source,
|
||||
|
@ -340,7 +369,7 @@ def process_timestamp(ts):
|
|||
return dt_util.as_utc(ts)
|
||||
|
||||
|
||||
def process_timestamp_to_utc_isoformat(ts):
|
||||
def process_timestamp_to_utc_isoformat(ts: datetime | None) -> str | None:
|
||||
"""Process a timestamp into UTC isotime."""
|
||||
if ts is None:
|
||||
return None
|
||||
|
|
|
@ -5,18 +5,26 @@ from collections import defaultdict
|
|||
from datetime import datetime, timedelta
|
||||
from itertools import groupby
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from sqlalchemy import bindparam
|
||||
from sqlalchemy.ext import baked
|
||||
from sqlalchemy.orm.scoping import scoped_session
|
||||
|
||||
from homeassistant.const import PRESSURE_PA, TEMP_CELSIUS
|
||||
from homeassistant.core import HomeAssistant
|
||||
import homeassistant.util.dt as dt_util
|
||||
import homeassistant.util.pressure as pressure_util
|
||||
import homeassistant.util.temperature as temperature_util
|
||||
from homeassistant.util.unit_system import UnitSystem
|
||||
|
||||
from .const import DOMAIN
|
||||
from .models import Statistics, StatisticsMeta, process_timestamp_to_utc_isoformat
|
||||
from .models import (
|
||||
StatisticMetaData,
|
||||
Statistics,
|
||||
StatisticsMeta,
|
||||
process_timestamp_to_utc_isoformat,
|
||||
)
|
||||
from .util import execute, retryable_database_job, session_scope
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -60,20 +68,22 @@ UNIT_CONVERSIONS = {
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def async_setup(hass):
|
||||
def async_setup(hass: HomeAssistant) -> None:
|
||||
"""Set up the history hooks."""
|
||||
hass.data[STATISTICS_BAKERY] = baked.bakery()
|
||||
hass.data[STATISTICS_META_BAKERY] = baked.bakery()
|
||||
|
||||
|
||||
def get_start_time() -> datetime.datetime:
|
||||
def get_start_time() -> datetime:
|
||||
"""Return start time."""
|
||||
last_hour = dt_util.utcnow() - timedelta(hours=1)
|
||||
start = last_hour.replace(minute=0, second=0, microsecond=0)
|
||||
return start
|
||||
|
||||
|
||||
def _get_metadata_ids(hass, session, statistic_ids):
|
||||
def _get_metadata_ids(
|
||||
hass: HomeAssistant, session: scoped_session, statistic_ids: list[str]
|
||||
) -> list[str]:
|
||||
"""Resolve metadata_id for a list of statistic_ids."""
|
||||
baked_query = hass.data[STATISTICS_META_BAKERY](
|
||||
lambda session: session.query(*QUERY_STATISTIC_META)
|
||||
|
@ -83,10 +93,15 @@ def _get_metadata_ids(hass, session, statistic_ids):
|
|||
)
|
||||
result = execute(baked_query(session).params(statistic_ids=statistic_ids))
|
||||
|
||||
return [id for id, _, _ in result]
|
||||
return [id for id, _, _ in result] if result else []
|
||||
|
||||
|
||||
def _get_or_add_metadata_id(hass, session, statistic_id, metadata):
|
||||
def _get_or_add_metadata_id(
|
||||
hass: HomeAssistant,
|
||||
session: scoped_session,
|
||||
statistic_id: str,
|
||||
metadata: StatisticMetaData,
|
||||
) -> str:
|
||||
"""Get metadata_id for a statistic_id, add if it doesn't exist."""
|
||||
metadata_id = _get_metadata_ids(hass, session, [statistic_id])
|
||||
if not metadata_id:
|
||||
|
@ -101,7 +116,7 @@ def _get_or_add_metadata_id(hass, session, statistic_id, metadata):
|
|||
|
||||
|
||||
@retryable_database_job("statistics")
|
||||
def compile_statistics(instance: Recorder, start: datetime.datetime) -> bool:
|
||||
def compile_statistics(instance: Recorder, start: datetime) -> bool:
|
||||
"""Compile statistics."""
|
||||
start = dt_util.as_utc(start)
|
||||
end = start + timedelta(hours=1)
|
||||
|
@ -126,10 +141,15 @@ def compile_statistics(instance: Recorder, start: datetime.datetime) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def _get_metadata(hass, session, statistic_ids, statistic_type):
|
||||
def _get_metadata(
|
||||
hass: HomeAssistant,
|
||||
session: scoped_session,
|
||||
statistic_ids: list[str] | None,
|
||||
statistic_type: str | None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""Fetch meta data."""
|
||||
|
||||
def _meta(metas, wanted_metadata_id):
|
||||
def _meta(metas: list, wanted_metadata_id: str) -> dict[str, str] | None:
|
||||
meta = None
|
||||
for metadata_id, statistic_id, unit in metas:
|
||||
if metadata_id == wanted_metadata_id:
|
||||
|
@ -150,12 +170,19 @@ def _get_metadata(hass, session, statistic_ids, statistic_type):
|
|||
elif statistic_type is not None:
|
||||
return {}
|
||||
result = execute(baked_query(session).params(statistic_ids=statistic_ids))
|
||||
if not result:
|
||||
return {}
|
||||
|
||||
metadata_ids = [metadata[0] for metadata in result]
|
||||
return {id: _meta(result, id) for id in metadata_ids}
|
||||
metadata = {}
|
||||
for _id in metadata_ids:
|
||||
meta = _meta(result, _id)
|
||||
if meta:
|
||||
metadata[_id] = meta
|
||||
return metadata
|
||||
|
||||
|
||||
def _configured_unit(unit: str, units) -> str:
|
||||
def _configured_unit(unit: str, units: UnitSystem) -> str:
|
||||
"""Return the pressure and temperature units configured by the user."""
|
||||
if unit == PRESSURE_PA:
|
||||
return units.pressure_unit
|
||||
|
@ -164,7 +191,9 @@ def _configured_unit(unit: str, units) -> str:
|
|||
return unit
|
||||
|
||||
|
||||
def list_statistic_ids(hass, statistic_type=None):
|
||||
def list_statistic_ids(
|
||||
hass: HomeAssistant, statistic_type: str | None = None
|
||||
) -> list[dict[str, str] | None]:
|
||||
"""Return statistic_ids and meta data."""
|
||||
units = hass.config.units
|
||||
with session_scope(hass=hass) as session:
|
||||
|
@ -177,7 +206,12 @@ def list_statistic_ids(hass, statistic_type=None):
|
|||
return list(metadata.values())
|
||||
|
||||
|
||||
def statistics_during_period(hass, start_time, end_time=None, statistic_ids=None):
|
||||
def statistics_during_period(
|
||||
hass: HomeAssistant,
|
||||
start_time: datetime,
|
||||
end_time: datetime | None = None,
|
||||
statistic_ids: list[str] | None = None,
|
||||
) -> dict[str, list[dict[str, str]]]:
|
||||
"""Return states changes during UTC period start_time - end_time."""
|
||||
metadata = None
|
||||
with session_scope(hass=hass) as session:
|
||||
|
@ -208,10 +242,14 @@ def statistics_during_period(hass, start_time, end_time=None, statistic_ids=None
|
|||
start_time=start_time, end_time=end_time, metadata_ids=metadata_ids
|
||||
)
|
||||
)
|
||||
if not stats:
|
||||
return {}
|
||||
return _sorted_statistics_to_dict(hass, stats, statistic_ids, metadata)
|
||||
|
||||
|
||||
def get_last_statistics(hass, number_of_stats, statistic_id):
|
||||
def get_last_statistics(
|
||||
hass: HomeAssistant, number_of_stats: int, statistic_id: str
|
||||
) -> dict[str, list[dict]]:
|
||||
"""Return the last number_of_stats statistics for a statistic_id."""
|
||||
statistic_ids = [statistic_id]
|
||||
with session_scope(hass=hass) as session:
|
||||
|
@ -237,18 +275,20 @@ def get_last_statistics(hass, number_of_stats, statistic_id):
|
|||
number_of_stats=number_of_stats, metadata_id=metadata_id
|
||||
)
|
||||
)
|
||||
if not stats:
|
||||
return {}
|
||||
|
||||
return _sorted_statistics_to_dict(hass, stats, statistic_ids, metadata)
|
||||
|
||||
|
||||
def _sorted_statistics_to_dict(
|
||||
hass,
|
||||
stats,
|
||||
statistic_ids,
|
||||
metadata,
|
||||
):
|
||||
hass: HomeAssistant,
|
||||
stats: list,
|
||||
statistic_ids: list[str] | None,
|
||||
metadata: dict[str, dict[str, str]],
|
||||
) -> dict[str, list[dict]]:
|
||||
"""Convert SQL results into JSON friendly data structure."""
|
||||
result = defaultdict(list)
|
||||
result: dict = defaultdict(list)
|
||||
units = hass.config.units
|
||||
|
||||
# Set all statistic IDs to empty lists in result set to maintain the order
|
||||
|
@ -260,10 +300,12 @@ def _sorted_statistics_to_dict(
|
|||
_process_timestamp_to_utc_isoformat = process_timestamp_to_utc_isoformat
|
||||
|
||||
# Append all statistic entries, and do unit conversion
|
||||
for meta_id, group in groupby(stats, lambda state: state.metadata_id):
|
||||
for meta_id, group in groupby(stats, lambda stat: stat.metadata_id): # type: ignore
|
||||
unit = metadata[meta_id]["unit_of_measurement"]
|
||||
statistic_id = metadata[meta_id]["statistic_id"]
|
||||
convert = UNIT_CONVERSIONS.get(unit, lambda x, units: x)
|
||||
convert: Callable[[Any, Any], float | None] = UNIT_CONVERSIONS.get(
|
||||
unit, lambda x, units: x # type: ignore
|
||||
)
|
||||
ent_results = result[meta_id]
|
||||
ent_results.extend(
|
||||
{
|
||||
|
|
|
@ -8,7 +8,7 @@ import functools
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
from sqlalchemy.exc import OperationalError, SQLAlchemyError
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
@ -91,7 +91,7 @@ def commit(session, work):
|
|||
return False
|
||||
|
||||
|
||||
def execute(qry, to_native=False, validate_entity_ids=True):
|
||||
def execute(qry, to_native=False, validate_entity_ids=True) -> list | None:
|
||||
"""Query the database and convert the objects to HA native form.
|
||||
|
||||
This method also retries a few times in the case of stale connections.
|
||||
|
@ -135,6 +135,8 @@ def execute(qry, to_native=False, validate_entity_ids=True):
|
|||
raise
|
||||
time.sleep(QUERY_RETRY_WAIT)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def validate_or_move_away_sqlite_database(dburl: str) -> bool:
|
||||
"""Ensure that the database is valid or move it away."""
|
||||
|
@ -288,13 +290,13 @@ def end_incomplete_runs(session, start_time):
|
|||
session.add(run)
|
||||
|
||||
|
||||
def retryable_database_job(description: str):
|
||||
def retryable_database_job(description: str) -> Callable:
|
||||
"""Try to execute a database job.
|
||||
|
||||
The job should return True if it finished, and False if it needs to be rescheduled.
|
||||
"""
|
||||
|
||||
def decorator(job: callable):
|
||||
def decorator(job: Callable) -> Callable:
|
||||
@functools.wraps(job)
|
||||
def wrapper(instance: Recorder, *args, **kwargs):
|
||||
try:
|
||||
|
|
|
@ -244,7 +244,7 @@ def compile_statistics(
|
|||
last_reset = old_last_reset = None
|
||||
new_state = old_state = None
|
||||
_sum = 0
|
||||
last_stats = statistics.get_last_statistics(hass, 1, entity_id) # type: ignore
|
||||
last_stats = statistics.get_last_statistics(hass, 1, entity_id)
|
||||
if entity_id in last_stats:
|
||||
# We have compiled history for this sensor before, use that as a starting point
|
||||
last_reset = old_last_reset = last_stats[entity_id][0]["last_reset"]
|
||||
|
|
3
mypy.ini
3
mypy.ini
|
@ -1470,9 +1470,6 @@ ignore_errors = true
|
|||
[mypy-homeassistant.components.recollect_waste.*]
|
||||
ignore_errors = true
|
||||
|
||||
[mypy-homeassistant.components.recorder.*]
|
||||
ignore_errors = true
|
||||
|
||||
[mypy-homeassistant.components.reddit.*]
|
||||
ignore_errors = true
|
||||
|
||||
|
|
|
@ -152,7 +152,6 @@ IGNORED_MODULES: Final[list[str]] = [
|
|||
"homeassistant.components.rachio.*",
|
||||
"homeassistant.components.rainmachine.*",
|
||||
"homeassistant.components.recollect_waste.*",
|
||||
"homeassistant.components.recorder.*",
|
||||
"homeassistant.components.reddit.*",
|
||||
"homeassistant.components.ring.*",
|
||||
"homeassistant.components.rpi_power.*",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue