Enable basic type checking for recorder (#52440)

* Enable basic type checking for recorder

* Tweak
This commit is contained in:
Erik Montnemery 2021-07-13 21:21:45 +02:00 committed by GitHub
parent 960684346f
commit 19d3aa71ad
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 108 additions and 39 deletions

View file

@ -324,7 +324,7 @@ class PerodicCleanupTask:
class StatisticsTask(NamedTuple): class StatisticsTask(NamedTuple):
"""An object to insert into the recorder queue to run a statistics task.""" """An object to insert into the recorder queue to run a statistics task."""
start: datetime.datetime start: datetime
class WaitTask: class WaitTask:
@ -358,7 +358,7 @@ class Recorder(threading.Thread):
self.db_url = uri self.db_url = uri
self.db_max_retries = db_max_retries self.db_max_retries = db_max_retries
self.db_retry_wait = db_retry_wait self.db_retry_wait = db_retry_wait
self.async_db_ready = asyncio.Future() self.async_db_ready: asyncio.Future = asyncio.Future()
self.async_recorder_ready = asyncio.Event() self.async_recorder_ready = asyncio.Event()
self._queue_watch = threading.Event() self._queue_watch = threading.Event()
self.engine: Any = None self.engine: Any = None
@ -370,8 +370,8 @@ class Recorder(threading.Thread):
self._timechanges_seen = 0 self._timechanges_seen = 0
self._commits_without_expire = 0 self._commits_without_expire = 0
self._keepalive_count = 0 self._keepalive_count = 0
self._old_states = {} self._old_states: dict[str, States] = {}
self._pending_expunge = [] self._pending_expunge: list[States] = []
self.event_session = None self.event_session = None
self.get_session = None self.get_session = None
self._completed_first_database_setup = None self._completed_first_database_setup = None

View file

@ -1,6 +1,10 @@
"""Models for SQLAlchemy.""" """Models for SQLAlchemy."""
from __future__ import annotations
from datetime import datetime
import json import json
import logging import logging
from typing import TypedDict
from sqlalchemy import ( from sqlalchemy import (
Boolean, Boolean,
@ -206,6 +210,17 @@ class States(Base): # type: ignore
return None return None
class StatisticData(TypedDict, total=False):
"""Statistic data class."""
mean: float
min: float
max: float
last_reset: datetime | None
state: float
sum: float
class Statistics(Base): # type: ignore class Statistics(Base): # type: ignore
"""Statistics.""" """Statistics."""
@ -230,7 +245,7 @@ class Statistics(Base): # type: ignore
sum = Column(Float()) sum = Column(Float())
@staticmethod @staticmethod
def from_stats(metadata_id, start, stats): def from_stats(metadata_id: str, start: datetime, stats: StatisticData):
"""Create object from a statistics.""" """Create object from a statistics."""
return Statistics( return Statistics(
metadata_id=metadata_id, metadata_id=metadata_id,
@ -239,6 +254,14 @@ class Statistics(Base): # type: ignore
) )
class StatisticMetaData(TypedDict, total=False):
"""Statistic meta data class."""
unit_of_measurement: str | None
has_mean: bool
has_sum: bool
class StatisticsMeta(Base): # type: ignore class StatisticsMeta(Base): # type: ignore
"""Statistics meta data.""" """Statistics meta data."""
@ -251,7 +274,13 @@ class StatisticsMeta(Base): # type: ignore
has_sum = Column(Boolean) has_sum = Column(Boolean)
@staticmethod @staticmethod
def from_meta(source, statistic_id, unit_of_measurement, has_mean, has_sum): def from_meta(
source: str,
statistic_id: str,
unit_of_measurement: str | None,
has_mean: bool,
has_sum: bool,
) -> StatisticsMeta:
"""Create object from meta data.""" """Create object from meta data."""
return StatisticsMeta( return StatisticsMeta(
source=source, source=source,
@ -340,7 +369,7 @@ def process_timestamp(ts):
return dt_util.as_utc(ts) return dt_util.as_utc(ts)
def process_timestamp_to_utc_isoformat(ts): def process_timestamp_to_utc_isoformat(ts: datetime | None) -> str | None:
"""Process a timestamp into UTC isotime.""" """Process a timestamp into UTC isotime."""
if ts is None: if ts is None:
return None return None

View file

@ -5,18 +5,26 @@ from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from itertools import groupby from itertools import groupby
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any, Callable
from sqlalchemy import bindparam from sqlalchemy import bindparam
from sqlalchemy.ext import baked from sqlalchemy.ext import baked
from sqlalchemy.orm.scoping import scoped_session
from homeassistant.const import PRESSURE_PA, TEMP_CELSIUS from homeassistant.const import PRESSURE_PA, TEMP_CELSIUS
from homeassistant.core import HomeAssistant
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
import homeassistant.util.pressure as pressure_util import homeassistant.util.pressure as pressure_util
import homeassistant.util.temperature as temperature_util import homeassistant.util.temperature as temperature_util
from homeassistant.util.unit_system import UnitSystem
from .const import DOMAIN from .const import DOMAIN
from .models import Statistics, StatisticsMeta, process_timestamp_to_utc_isoformat from .models import (
StatisticMetaData,
Statistics,
StatisticsMeta,
process_timestamp_to_utc_isoformat,
)
from .util import execute, retryable_database_job, session_scope from .util import execute, retryable_database_job, session_scope
if TYPE_CHECKING: if TYPE_CHECKING:
@ -60,20 +68,22 @@ UNIT_CONVERSIONS = {
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def async_setup(hass): def async_setup(hass: HomeAssistant) -> None:
"""Set up the history hooks.""" """Set up the history hooks."""
hass.data[STATISTICS_BAKERY] = baked.bakery() hass.data[STATISTICS_BAKERY] = baked.bakery()
hass.data[STATISTICS_META_BAKERY] = baked.bakery() hass.data[STATISTICS_META_BAKERY] = baked.bakery()
def get_start_time() -> datetime.datetime: def get_start_time() -> datetime:
"""Return start time.""" """Return start time."""
last_hour = dt_util.utcnow() - timedelta(hours=1) last_hour = dt_util.utcnow() - timedelta(hours=1)
start = last_hour.replace(minute=0, second=0, microsecond=0) start = last_hour.replace(minute=0, second=0, microsecond=0)
return start return start
def _get_metadata_ids(hass, session, statistic_ids): def _get_metadata_ids(
hass: HomeAssistant, session: scoped_session, statistic_ids: list[str]
) -> list[str]:
"""Resolve metadata_id for a list of statistic_ids.""" """Resolve metadata_id for a list of statistic_ids."""
baked_query = hass.data[STATISTICS_META_BAKERY]( baked_query = hass.data[STATISTICS_META_BAKERY](
lambda session: session.query(*QUERY_STATISTIC_META) lambda session: session.query(*QUERY_STATISTIC_META)
@ -83,10 +93,15 @@ def _get_metadata_ids(hass, session, statistic_ids):
) )
result = execute(baked_query(session).params(statistic_ids=statistic_ids)) result = execute(baked_query(session).params(statistic_ids=statistic_ids))
return [id for id, _, _ in result] return [id for id, _, _ in result] if result else []
def _get_or_add_metadata_id(hass, session, statistic_id, metadata): def _get_or_add_metadata_id(
hass: HomeAssistant,
session: scoped_session,
statistic_id: str,
metadata: StatisticMetaData,
) -> str:
"""Get metadata_id for a statistic_id, add if it doesn't exist.""" """Get metadata_id for a statistic_id, add if it doesn't exist."""
metadata_id = _get_metadata_ids(hass, session, [statistic_id]) metadata_id = _get_metadata_ids(hass, session, [statistic_id])
if not metadata_id: if not metadata_id:
@ -101,7 +116,7 @@ def _get_or_add_metadata_id(hass, session, statistic_id, metadata):
@retryable_database_job("statistics") @retryable_database_job("statistics")
def compile_statistics(instance: Recorder, start: datetime.datetime) -> bool: def compile_statistics(instance: Recorder, start: datetime) -> bool:
"""Compile statistics.""" """Compile statistics."""
start = dt_util.as_utc(start) start = dt_util.as_utc(start)
end = start + timedelta(hours=1) end = start + timedelta(hours=1)
@ -126,10 +141,15 @@ def compile_statistics(instance: Recorder, start: datetime.datetime) -> bool:
return True return True
def _get_metadata(hass, session, statistic_ids, statistic_type): def _get_metadata(
hass: HomeAssistant,
session: scoped_session,
statistic_ids: list[str] | None,
statistic_type: str | None,
) -> dict[str, dict[str, str]]:
"""Fetch meta data.""" """Fetch meta data."""
def _meta(metas, wanted_metadata_id): def _meta(metas: list, wanted_metadata_id: str) -> dict[str, str] | None:
meta = None meta = None
for metadata_id, statistic_id, unit in metas: for metadata_id, statistic_id, unit in metas:
if metadata_id == wanted_metadata_id: if metadata_id == wanted_metadata_id:
@ -150,12 +170,19 @@ def _get_metadata(hass, session, statistic_ids, statistic_type):
elif statistic_type is not None: elif statistic_type is not None:
return {} return {}
result = execute(baked_query(session).params(statistic_ids=statistic_ids)) result = execute(baked_query(session).params(statistic_ids=statistic_ids))
if not result:
return {}
metadata_ids = [metadata[0] for metadata in result] metadata_ids = [metadata[0] for metadata in result]
return {id: _meta(result, id) for id in metadata_ids} metadata = {}
for _id in metadata_ids:
meta = _meta(result, _id)
if meta:
metadata[_id] = meta
return metadata
def _configured_unit(unit: str, units) -> str: def _configured_unit(unit: str, units: UnitSystem) -> str:
"""Return the pressure and temperature units configured by the user.""" """Return the pressure and temperature units configured by the user."""
if unit == PRESSURE_PA: if unit == PRESSURE_PA:
return units.pressure_unit return units.pressure_unit
@ -164,7 +191,9 @@ def _configured_unit(unit: str, units) -> str:
return unit return unit
def list_statistic_ids(hass, statistic_type=None): def list_statistic_ids(
hass: HomeAssistant, statistic_type: str | None = None
) -> list[dict[str, str] | None]:
"""Return statistic_ids and meta data.""" """Return statistic_ids and meta data."""
units = hass.config.units units = hass.config.units
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
@ -177,7 +206,12 @@ def list_statistic_ids(hass, statistic_type=None):
return list(metadata.values()) return list(metadata.values())
def statistics_during_period(hass, start_time, end_time=None, statistic_ids=None): def statistics_during_period(
hass: HomeAssistant,
start_time: datetime,
end_time: datetime | None = None,
statistic_ids: list[str] | None = None,
) -> dict[str, list[dict[str, str]]]:
"""Return states changes during UTC period start_time - end_time.""" """Return states changes during UTC period start_time - end_time."""
metadata = None metadata = None
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
@ -208,10 +242,14 @@ def statistics_during_period(hass, start_time, end_time=None, statistic_ids=None
start_time=start_time, end_time=end_time, metadata_ids=metadata_ids start_time=start_time, end_time=end_time, metadata_ids=metadata_ids
) )
) )
if not stats:
return {}
return _sorted_statistics_to_dict(hass, stats, statistic_ids, metadata) return _sorted_statistics_to_dict(hass, stats, statistic_ids, metadata)
def get_last_statistics(hass, number_of_stats, statistic_id): def get_last_statistics(
hass: HomeAssistant, number_of_stats: int, statistic_id: str
) -> dict[str, list[dict]]:
"""Return the last number_of_stats statistics for a statistic_id.""" """Return the last number_of_stats statistics for a statistic_id."""
statistic_ids = [statistic_id] statistic_ids = [statistic_id]
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
@ -237,18 +275,20 @@ def get_last_statistics(hass, number_of_stats, statistic_id):
number_of_stats=number_of_stats, metadata_id=metadata_id number_of_stats=number_of_stats, metadata_id=metadata_id
) )
) )
if not stats:
return {}
return _sorted_statistics_to_dict(hass, stats, statistic_ids, metadata) return _sorted_statistics_to_dict(hass, stats, statistic_ids, metadata)
def _sorted_statistics_to_dict( def _sorted_statistics_to_dict(
hass, hass: HomeAssistant,
stats, stats: list,
statistic_ids, statistic_ids: list[str] | None,
metadata, metadata: dict[str, dict[str, str]],
): ) -> dict[str, list[dict]]:
"""Convert SQL results into JSON friendly data structure.""" """Convert SQL results into JSON friendly data structure."""
result = defaultdict(list) result: dict = defaultdict(list)
units = hass.config.units units = hass.config.units
# Set all statistic IDs to empty lists in result set to maintain the order # Set all statistic IDs to empty lists in result set to maintain the order
@ -260,10 +300,12 @@ def _sorted_statistics_to_dict(
_process_timestamp_to_utc_isoformat = process_timestamp_to_utc_isoformat _process_timestamp_to_utc_isoformat = process_timestamp_to_utc_isoformat
# Append all statistic entries, and do unit conversion # Append all statistic entries, and do unit conversion
for meta_id, group in groupby(stats, lambda state: state.metadata_id): for meta_id, group in groupby(stats, lambda stat: stat.metadata_id): # type: ignore
unit = metadata[meta_id]["unit_of_measurement"] unit = metadata[meta_id]["unit_of_measurement"]
statistic_id = metadata[meta_id]["statistic_id"] statistic_id = metadata[meta_id]["statistic_id"]
convert = UNIT_CONVERSIONS.get(unit, lambda x, units: x) convert: Callable[[Any, Any], float | None] = UNIT_CONVERSIONS.get(
unit, lambda x, units: x # type: ignore
)
ent_results = result[meta_id] ent_results = result[meta_id]
ent_results.extend( ent_results.extend(
{ {

View file

@ -8,7 +8,7 @@ import functools
import logging import logging
import os import os
import time import time
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Callable
from sqlalchemy.exc import OperationalError, SQLAlchemyError from sqlalchemy.exc import OperationalError, SQLAlchemyError
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
@ -91,7 +91,7 @@ def commit(session, work):
return False return False
def execute(qry, to_native=False, validate_entity_ids=True): def execute(qry, to_native=False, validate_entity_ids=True) -> list | None:
"""Query the database and convert the objects to HA native form. """Query the database and convert the objects to HA native form.
This method also retries a few times in the case of stale connections. This method also retries a few times in the case of stale connections.
@ -135,6 +135,8 @@ def execute(qry, to_native=False, validate_entity_ids=True):
raise raise
time.sleep(QUERY_RETRY_WAIT) time.sleep(QUERY_RETRY_WAIT)
return None
def validate_or_move_away_sqlite_database(dburl: str) -> bool: def validate_or_move_away_sqlite_database(dburl: str) -> bool:
"""Ensure that the database is valid or move it away.""" """Ensure that the database is valid or move it away."""
@ -288,13 +290,13 @@ def end_incomplete_runs(session, start_time):
session.add(run) session.add(run)
def retryable_database_job(description: str): def retryable_database_job(description: str) -> Callable:
"""Try to execute a database job. """Try to execute a database job.
The job should return True if it finished, and False if it needs to be rescheduled. The job should return True if it finished, and False if it needs to be rescheduled.
""" """
def decorator(job: callable): def decorator(job: Callable) -> Callable:
@functools.wraps(job) @functools.wraps(job)
def wrapper(instance: Recorder, *args, **kwargs): def wrapper(instance: Recorder, *args, **kwargs):
try: try:

View file

@ -244,7 +244,7 @@ def compile_statistics(
last_reset = old_last_reset = None last_reset = old_last_reset = None
new_state = old_state = None new_state = old_state = None
_sum = 0 _sum = 0
last_stats = statistics.get_last_statistics(hass, 1, entity_id) # type: ignore last_stats = statistics.get_last_statistics(hass, 1, entity_id)
if entity_id in last_stats: if entity_id in last_stats:
# We have compiled history for this sensor before, use that as a starting point # We have compiled history for this sensor before, use that as a starting point
last_reset = old_last_reset = last_stats[entity_id][0]["last_reset"] last_reset = old_last_reset = last_stats[entity_id][0]["last_reset"]

View file

@ -1470,9 +1470,6 @@ ignore_errors = true
[mypy-homeassistant.components.recollect_waste.*] [mypy-homeassistant.components.recollect_waste.*]
ignore_errors = true ignore_errors = true
[mypy-homeassistant.components.recorder.*]
ignore_errors = true
[mypy-homeassistant.components.reddit.*] [mypy-homeassistant.components.reddit.*]
ignore_errors = true ignore_errors = true

View file

@ -152,7 +152,6 @@ IGNORED_MODULES: Final[list[str]] = [
"homeassistant.components.rachio.*", "homeassistant.components.rachio.*",
"homeassistant.components.rainmachine.*", "homeassistant.components.rainmachine.*",
"homeassistant.components.recollect_waste.*", "homeassistant.components.recollect_waste.*",
"homeassistant.components.recorder.*",
"homeassistant.components.reddit.*", "homeassistant.components.reddit.*",
"homeassistant.components.ring.*", "homeassistant.components.ring.*",
"homeassistant.components.rpi_power.*", "homeassistant.components.rpi_power.*",