Enable basic type checking for recorder (#52440)

* Enable basic type checking for recorder

* Tweak
Erik Montnemery 2021-07-13 21:21:45 +02:00 committed by GitHub
parent 960684346f
commit 19d3aa71ad
7 changed files with 108 additions and 39 deletions

homeassistant/components/recorder/__init__.py

@@ -324,7 +324,7 @@ class PerodicCleanupTask:
 class StatisticsTask(NamedTuple):
     """An object to insert into the recorder queue to run a statistics task."""
-    start: datetime.datetime
+    start: datetime
 class WaitTask:
@@ -358,7 +358,7 @@ class Recorder(threading.Thread):
         self.db_url = uri
         self.db_max_retries = db_max_retries
         self.db_retry_wait = db_retry_wait
-        self.async_db_ready = asyncio.Future()
+        self.async_db_ready: asyncio.Future = asyncio.Future()
         self.async_recorder_ready = asyncio.Event()
         self._queue_watch = threading.Event()
         self.engine: Any = None
@@ -370,8 +370,8 @@ class Recorder(threading.Thread):
         self._timechanges_seen = 0
         self._commits_without_expire = 0
         self._keepalive_count = 0
-        self._old_states = {}
-        self._pending_expunge = []
+        self._old_states: dict[str, States] = {}
+        self._pending_expunge: list[States] = []
         self.event_session = None
         self.get_session = None
         self._completed_first_database_setup = None
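A quick aside on the two kinds of annotation added above: mypy cannot infer an element type for a container created empty, and a bare asyncio.Future() similarly needs an annotation because Future is generic. A minimal sketch of the pattern, with hypothetical names rather than the real recorder attributes:

    from __future__ import annotations

    import asyncio


    class Example:
        """Hypothetical class illustrating the annotation pattern."""

        def __init__(self) -> None:
            # Without annotations, mypy reports "Need type annotation" for
            # the empty literals, since the element types are unknowable:
            self.old_states: dict[str, str] = {}
            self.pending: list[str] = []
            # Annotating the Future records intent even when its result
            # type is left unparameterized, as in this commit:
            self.db_ready: asyncio.Future = asyncio.Future()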

homeassistant/components/recorder/models.py

@@ -1,6 +1,10 @@
 """Models for SQLAlchemy."""
 from __future__ import annotations
+from datetime import datetime
 import json
 import logging
+from typing import TypedDict
 from sqlalchemy import (
     Boolean,
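Worth a note: the `datetime | None` and `str | None` unions added below are PEP 604 syntax, usable before Python 3.10 only because `from __future__ import annotations` keeps annotations as strings instead of evaluating them at import time. A small illustration with a hypothetical helper:

    from __future__ import annotations  # X | Y in annotations works pre-3.10

    from datetime import datetime


    def to_isoformat(ts: datetime | None) -> str | None:
        """Hypothetical helper: ISO string, or None when no timestamp."""
        return ts.isoformat() if ts is not None else None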
@@ -206,6 +210,17 @@ class States(Base):  # type: ignore
             return None
+class StatisticData(TypedDict, total=False):
+    """Statistic data class."""
+    mean: float
+    min: float
+    max: float
+    last_reset: datetime | None
+    state: float
+    sum: float
 class Statistics(Base):  # type: ignore
     """Statistics."""
@@ -230,7 +245,7 @@ class Statistics(Base):  # type: ignore
     sum = Column(Float())
     @staticmethod
-    def from_stats(metadata_id, start, stats):
+    def from_stats(metadata_id: str, start: datetime, stats: StatisticData):
         """Create object from a statistics."""
         return Statistics(
             metadata_id=metadata_id,
@@ -239,6 +254,14 @@ class Statistics(Base):  # type: ignore
         )
+class StatisticMetaData(TypedDict, total=False):
+    """Statistic meta data class."""
+    unit_of_measurement: str | None
+    has_mean: bool
+    has_sum: bool
 class StatisticsMeta(Base):  # type: ignore
     """Statistics meta data."""
@@ -251,7 +274,13 @@ class StatisticsMeta(Base):  # type: ignore
     has_sum = Column(Boolean)
     @staticmethod
-    def from_meta(source, statistic_id, unit_of_measurement, has_mean, has_sum):
+    def from_meta(
+        source: str,
+        statistic_id: str,
+        unit_of_measurement: str | None,
+        has_mean: bool,
+        has_sum: bool,
+    ) -> StatisticsMeta:
         """Create object from meta data."""
         return StatisticsMeta(
             source=source,
@@ -340,7 +369,7 @@ def process_timestamp(ts):
     return dt_util.as_utc(ts)
-def process_timestamp_to_utc_isoformat(ts):
+def process_timestamp_to_utc_isoformat(ts: datetime | None) -> str | None:
     """Process a timestamp into UTC isotime."""
     if ts is None:
         return None
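Taken together, the TypedDicts and the typed factory methods let mypy check statistics rows end to end. A compressed sketch of the flow, with the ORM model stubbed out as a plain dict since only the typing pattern matters here:

    from __future__ import annotations

    from datetime import datetime, timezone
    from typing import TypedDict


    class StatisticData(TypedDict, total=False):
        mean: float
        sum: float


    def from_stats(metadata_id: str, start: datetime, stats: StatisticData) -> dict:
        """Stand-in for Statistics.from_stats; builds a row-like dict."""
        return {"metadata_id": metadata_id, "start": start, **stats}


    row = from_stats("42", datetime.now(timezone.utc), {"mean": 20.5, "sum": 41.0})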

homeassistant/components/recorder/statistics.py

@@ -5,18 +5,26 @@ from collections import defaultdict
 from datetime import datetime, timedelta
 from itertools import groupby
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Callable
 from sqlalchemy import bindparam
 from sqlalchemy.ext import baked
+from sqlalchemy.orm.scoping import scoped_session
 from homeassistant.const import PRESSURE_PA, TEMP_CELSIUS
+from homeassistant.core import HomeAssistant
 import homeassistant.util.dt as dt_util
 import homeassistant.util.pressure as pressure_util
 import homeassistant.util.temperature as temperature_util
+from homeassistant.util.unit_system import UnitSystem
 from .const import DOMAIN
-from .models import Statistics, StatisticsMeta, process_timestamp_to_utc_isoformat
+from .models import (
+    StatisticMetaData,
+    Statistics,
+    StatisticsMeta,
+    process_timestamp_to_utc_isoformat,
+)
 from .util import execute, retryable_database_job, session_scope
 if TYPE_CHECKING:
@@ -60,20 +68,22 @@ UNIT_CONVERSIONS = {
 _LOGGER = logging.getLogger(__name__)
-def async_setup(hass):
+def async_setup(hass: HomeAssistant) -> None:
     """Set up the history hooks."""
     hass.data[STATISTICS_BAKERY] = baked.bakery()
     hass.data[STATISTICS_META_BAKERY] = baked.bakery()
-def get_start_time() -> datetime.datetime:
+def get_start_time() -> datetime:
     """Return start time."""
     last_hour = dt_util.utcnow() - timedelta(hours=1)
     start = last_hour.replace(minute=0, second=0, microsecond=0)
     return start
-def _get_metadata_ids(hass, session, statistic_ids):
+def _get_metadata_ids(
+    hass: HomeAssistant, session: scoped_session, statistic_ids: list[str]
+) -> list[str]:
     """Resolve metadata_id for a list of statistic_ids."""
     baked_query = hass.data[STATISTICS_META_BAKERY](
         lambda session: session.query(*QUERY_STATISTIC_META)
@@ -83,10 +93,15 @@ def _get_metadata_ids(hass, session, statistic_ids):
     )
     result = execute(baked_query(session).params(statistic_ids=statistic_ids))
-    return [id for id, _, _ in result]
+    return [id for id, _, _ in result] if result else []
-def _get_or_add_metadata_id(hass, session, statistic_id, metadata):
+def _get_or_add_metadata_id(
+    hass: HomeAssistant,
+    session: scoped_session,
+    statistic_id: str,
+    metadata: StatisticMetaData,
+) -> str:
     """Get metadata_id for a statistic_id, add if it doesn't exist."""
     metadata_id = _get_metadata_ids(hass, session, [statistic_id])
     if not metadata_id:
@@ -101,7 +116,7 @@ def _get_or_add_metadata_id(hass, session, statistic_id, metadata):
 @retryable_database_job("statistics")
-def compile_statistics(instance: Recorder, start: datetime.datetime) -> bool:
+def compile_statistics(instance: Recorder, start: datetime) -> bool:
     """Compile statistics."""
     start = dt_util.as_utc(start)
     end = start + timedelta(hours=1)
@@ -126,10 +141,15 @@ def compile_statistics(instance: Recorder, start: datetime.datetime) -> bool:
     return True
-def _get_metadata(hass, session, statistic_ids, statistic_type):
+def _get_metadata(
+    hass: HomeAssistant,
+    session: scoped_session,
+    statistic_ids: list[str] | None,
+    statistic_type: str | None,
+) -> dict[str, dict[str, str]]:
     """Fetch meta data."""
-    def _meta(metas, wanted_metadata_id):
+    def _meta(metas: list, wanted_metadata_id: str) -> dict[str, str] | None:
         meta = None
         for metadata_id, statistic_id, unit in metas:
             if metadata_id == wanted_metadata_id:
@@ -150,12 +170,19 @@ def _get_metadata(hass, session, statistic_ids, statistic_type):
     elif statistic_type is not None:
         return {}
     result = execute(baked_query(session).params(statistic_ids=statistic_ids))
+    if not result:
+        return {}
     metadata_ids = [metadata[0] for metadata in result]
-    return {id: _meta(result, id) for id in metadata_ids}
+    metadata = {}
+    for _id in metadata_ids:
+        meta = _meta(result, _id)
+        if meta:
+            metadata[_id] = meta
+    return metadata
-def _configured_unit(unit: str, units) -> str:
+def _configured_unit(unit: str, units: UnitSystem) -> str:
     """Return the pressure and temperature units configured by the user."""
     if unit == PRESSURE_PA:
         return units.pressure_unit
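The loop that replaces the dict comprehension here is as much a typing fix as a behavior change: `_meta` can return None, so the old `{id: _meta(result, id) for id in metadata_ids}` would have value type `dict[str, str] | None`, which contradicts the declared `dict[str, dict[str, str]]` return. Filtering inside a loop narrows the values. A reduced sketch with hypothetical helpers:

    from __future__ import annotations


    def lookup(key: str) -> str | None:
        """Hypothetical lookup that can miss."""
        return {"a": "x"}.get(key)


    def collect(keys: list[str]) -> dict[str, str]:
        # A comprehension would infer dict[str, str | None] and fail the
        # declared return type; the explicit loop drops the None results.
        result: dict[str, str] = {}
        for key in keys:
            value = lookup(key)
            if value:
                result[key] = value
        return result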
@@ -164,7 +191,9 @@ def _configured_unit(unit: str, units) -> str:
     return unit
-def list_statistic_ids(hass, statistic_type=None):
+def list_statistic_ids(
+    hass: HomeAssistant, statistic_type: str | None = None
+) -> list[dict[str, str] | None]:
     """Return statistic_ids and meta data."""
     units = hass.config.units
     with session_scope(hass=hass) as session:
@@ -177,7 +206,12 @@ def list_statistic_ids(hass, statistic_type=None):
         return list(metadata.values())
-def statistics_during_period(hass, start_time, end_time=None, statistic_ids=None):
+def statistics_during_period(
+    hass: HomeAssistant,
+    start_time: datetime,
+    end_time: datetime | None = None,
+    statistic_ids: list[str] | None = None,
+) -> dict[str, list[dict[str, str]]]:
     """Return states changes during UTC period start_time - end_time."""
     metadata = None
     with session_scope(hass=hass) as session:
@@ -208,10 +242,14 @@ def statistics_during_period(hass, start_time, end_time=None, statistic_ids=None):
                 start_time=start_time, end_time=end_time, metadata_ids=metadata_ids
             )
         )
+        if not stats:
+            return {}
         return _sorted_statistics_to_dict(hass, stats, statistic_ids, metadata)
-def get_last_statistics(hass, number_of_stats, statistic_id):
+def get_last_statistics(
+    hass: HomeAssistant, number_of_stats: int, statistic_id: str
+) -> dict[str, list[dict]]:
     """Return the last number_of_stats statistics for a statistic_id."""
     statistic_ids = [statistic_id]
     with session_scope(hass=hass) as session:
@@ -237,18 +275,20 @@ def get_last_statistics(hass, number_of_stats, statistic_id):
                 number_of_stats=number_of_stats, metadata_id=metadata_id
             )
         )
+        if not stats:
+            return {}
         return _sorted_statistics_to_dict(hass, stats, statistic_ids, metadata)
 def _sorted_statistics_to_dict(
-    hass,
-    stats,
-    statistic_ids,
-    metadata,
-):
+    hass: HomeAssistant,
+    stats: list,
+    statistic_ids: list[str] | None,
+    metadata: dict[str, dict[str, str]],
+) -> dict[str, list[dict]]:
     """Convert SQL results into JSON friendly data structure."""
-    result = defaultdict(list)
+    result: dict = defaultdict(list)
     units = hass.config.units
     # Set all statistic IDs to empty lists in result set to maintain the order
@@ -260,10 +300,12 @@ def _sorted_statistics_to_dict(
     _process_timestamp_to_utc_isoformat = process_timestamp_to_utc_isoformat
     # Append all statistic entries, and do unit conversion
-    for meta_id, group in groupby(stats, lambda state: state.metadata_id):
+    for meta_id, group in groupby(stats, lambda stat: stat.metadata_id):  # type: ignore
         unit = metadata[meta_id]["unit_of_measurement"]
         statistic_id = metadata[meta_id]["statistic_id"]
-        convert = UNIT_CONVERSIONS.get(unit, lambda x, units: x)
+        convert: Callable[[Any, Any], float | None] = UNIT_CONVERSIONS.get(
+            unit, lambda x, units: x  # type: ignore
+        )
         ent_results = result[meta_id]
         ent_results.extend(
             {
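The `convert` annotation above is the usual way to type a dispatch table with a fall-through default: `dict.get` with a lambda default infers an awkward join type, so the variable is pinned explicitly. A reduced sketch of the pattern (the table contents are illustrative, not the real UNIT_CONVERSIONS):

    from __future__ import annotations

    from typing import Callable

    # Unit -> converter table; the identity lambda is the fall-through.
    CONVERSIONS: dict[str, Callable[[float], float]] = {
        "Pa": lambda value: value / 100.0,  # Pa -> hPa, for illustration
    }


    def convert_value(unit: str, value: float) -> float:
        convert: Callable[[float], float] = CONVERSIONS.get(unit, lambda v: v)
        return convert(value)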

homeassistant/components/recorder/util.py

@@ -8,7 +8,7 @@ import functools
 import logging
 import os
 import time
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable
 from sqlalchemy.exc import OperationalError, SQLAlchemyError
 from sqlalchemy.orm.session import Session
@@ -91,7 +91,7 @@ def commit(session, work):
         return False
-def execute(qry, to_native=False, validate_entity_ids=True):
+def execute(qry, to_native=False, validate_entity_ids=True) -> list | None:
     """Query the database and convert the objects to HA native form.
     This method also retries a few times in the case of stale connections.
@@ -135,6 +135,8 @@ def execute(qry, to_native=False, validate_entity_ids=True):
                 raise
             time.sleep(QUERY_RETRY_WAIT)
+    return None
 def validate_or_move_away_sqlite_database(dburl: str) -> bool:
     """Ensure that the database is valid or move it away."""
@@ -288,13 +290,13 @@ def end_incomplete_runs(session, start_time):
         session.add(run)
-def retryable_database_job(description: str):
+def retryable_database_job(description: str) -> Callable:
     """Try to execute a database job.
     The job should return True if it finished, and False if it needs to be rescheduled.
     """
-    def decorator(job: callable):
+    def decorator(job: Callable) -> Callable:
         @functools.wraps(job)
         def wrapper(instance: Recorder, *args, **kwargs):
             try:
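One subtle fix in this hunk: `callable` is the built-in predicate, not a type, so `job: callable` was never a valid annotation; `typing.Callable` is what mypy expects. A minimal typed decorator of the same shape (a sketch, not the recorder implementation, which retries on database errors):

    from __future__ import annotations

    import functools
    from typing import Callable


    def retryable(description: str) -> Callable:
        """Wrap a job so failures are reported instead of raised (sketch)."""

        def decorator(job: Callable) -> Callable:
            @functools.wraps(job)
            def wrapper(*args, **kwargs):
                try:
                    return job(*args, **kwargs)
                except Exception:  # broad catch is for the sketch only
                    print(f"Error executing {description}")
                    return False

            return wrapper

        return decorator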

homeassistant/components/sensor/recorder.py

@@ -244,7 +244,7 @@ def compile_statistics(
     last_reset = old_last_reset = None
     new_state = old_state = None
     _sum = 0
-    last_stats = statistics.get_last_statistics(hass, 1, entity_id)  # type: ignore
+    last_stats = statistics.get_last_statistics(hass, 1, entity_id)
     if entity_id in last_stats:
         # We have compiled history for this sensor before, use that as a starting point
         last_reset = old_last_reset = last_stats[entity_id][0]["last_reset"]
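This removal follows from the statistics module gaining annotations: the suppression was only needed while `get_last_statistics` was untyped, and once the callee is annotated the call checks cleanly, so the stale ignore can go (mypy's warn_unused_ignores, where enabled, will even insist on it). A compressed illustration with a stand-in helper:

    from __future__ import annotations


    def get_last(count: int) -> dict[str, list[dict]]:
        """Stand-in for the now-annotated statistics helper."""
        return {}


    # Before the annotation existed the call needed "# type: ignore";
    # afterwards the plain call type-checks.
    last_stats = get_last(1)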

mypy.ini

@@ -1470,9 +1470,6 @@ ignore_errors = true
 [mypy-homeassistant.components.recollect_waste.*]
 ignore_errors = true
-[mypy-homeassistant.components.recorder.*]
-ignore_errors = true
 [mypy-homeassistant.components.reddit.*]
 ignore_errors = true

script/hassfest/mypy_config.py

@@ -152,7 +152,6 @@ IGNORED_MODULES: Final[list[str]] = [
     "homeassistant.components.rachio.*",
     "homeassistant.components.rainmachine.*",
     "homeassistant.components.recollect_waste.*",
-    "homeassistant.components.recorder.*",
     "homeassistant.components.reddit.*",
     "homeassistant.components.ring.*",
     "homeassistant.components.rpi_power.*",