Enable basic type checking for recorder (#52440)

* Enable basic type checking for recorder

* Tweak
This commit is contained in:
Erik Montnemery 2021-07-13 21:21:45 +02:00 committed by GitHub
parent 960684346f
commit 19d3aa71ad
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 108 additions and 39 deletions

View file

@ -5,18 +5,26 @@ from collections import defaultdict
from datetime import datetime, timedelta
from itertools import groupby
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Callable
from sqlalchemy import bindparam
from sqlalchemy.ext import baked
from sqlalchemy.orm.scoping import scoped_session
from homeassistant.const import PRESSURE_PA, TEMP_CELSIUS
from homeassistant.core import HomeAssistant
import homeassistant.util.dt as dt_util
import homeassistant.util.pressure as pressure_util
import homeassistant.util.temperature as temperature_util
from homeassistant.util.unit_system import UnitSystem
from .const import DOMAIN
from .models import Statistics, StatisticsMeta, process_timestamp_to_utc_isoformat
from .models import (
StatisticMetaData,
Statistics,
StatisticsMeta,
process_timestamp_to_utc_isoformat,
)
from .util import execute, retryable_database_job, session_scope
if TYPE_CHECKING:
@ -60,20 +68,22 @@ UNIT_CONVERSIONS = {
_LOGGER = logging.getLogger(__name__)
def async_setup(hass):
def async_setup(hass: HomeAssistant) -> None:
"""Set up the history hooks."""
hass.data[STATISTICS_BAKERY] = baked.bakery()
hass.data[STATISTICS_META_BAKERY] = baked.bakery()
def get_start_time() -> datetime.datetime:
def get_start_time() -> datetime:
"""Return start time."""
last_hour = dt_util.utcnow() - timedelta(hours=1)
start = last_hour.replace(minute=0, second=0, microsecond=0)
return start
def _get_metadata_ids(hass, session, statistic_ids):
def _get_metadata_ids(
hass: HomeAssistant, session: scoped_session, statistic_ids: list[str]
) -> list[str]:
"""Resolve metadata_id for a list of statistic_ids."""
baked_query = hass.data[STATISTICS_META_BAKERY](
lambda session: session.query(*QUERY_STATISTIC_META)
@ -83,10 +93,15 @@ def _get_metadata_ids(hass, session, statistic_ids):
)
result = execute(baked_query(session).params(statistic_ids=statistic_ids))
return [id for id, _, _ in result]
return [id for id, _, _ in result] if result else []
def _get_or_add_metadata_id(hass, session, statistic_id, metadata):
def _get_or_add_metadata_id(
hass: HomeAssistant,
session: scoped_session,
statistic_id: str,
metadata: StatisticMetaData,
) -> str:
"""Get metadata_id for a statistic_id, add if it doesn't exist."""
metadata_id = _get_metadata_ids(hass, session, [statistic_id])
if not metadata_id:
@ -101,7 +116,7 @@ def _get_or_add_metadata_id(hass, session, statistic_id, metadata):
@retryable_database_job("statistics")
def compile_statistics(instance: Recorder, start: datetime.datetime) -> bool:
def compile_statistics(instance: Recorder, start: datetime) -> bool:
"""Compile statistics."""
start = dt_util.as_utc(start)
end = start + timedelta(hours=1)
@ -126,10 +141,15 @@ def compile_statistics(instance: Recorder, start: datetime.datetime) -> bool:
return True
def _get_metadata(hass, session, statistic_ids, statistic_type):
def _get_metadata(
hass: HomeAssistant,
session: scoped_session,
statistic_ids: list[str] | None,
statistic_type: str | None,
) -> dict[str, dict[str, str]]:
"""Fetch meta data."""
def _meta(metas, wanted_metadata_id):
def _meta(metas: list, wanted_metadata_id: str) -> dict[str, str] | None:
meta = None
for metadata_id, statistic_id, unit in metas:
if metadata_id == wanted_metadata_id:
@ -150,12 +170,19 @@ def _get_metadata(hass, session, statistic_ids, statistic_type):
elif statistic_type is not None:
return {}
result = execute(baked_query(session).params(statistic_ids=statistic_ids))
if not result:
return {}
metadata_ids = [metadata[0] for metadata in result]
return {id: _meta(result, id) for id in metadata_ids}
metadata = {}
for _id in metadata_ids:
meta = _meta(result, _id)
if meta:
metadata[_id] = meta
return metadata
def _configured_unit(unit: str, units) -> str:
def _configured_unit(unit: str, units: UnitSystem) -> str:
"""Return the pressure and temperature units configured by the user."""
if unit == PRESSURE_PA:
return units.pressure_unit
@ -164,7 +191,9 @@ def _configured_unit(unit: str, units) -> str:
return unit
def list_statistic_ids(hass, statistic_type=None):
def list_statistic_ids(
hass: HomeAssistant, statistic_type: str | None = None
) -> list[dict[str, str] | None]:
"""Return statistic_ids and meta data."""
units = hass.config.units
with session_scope(hass=hass) as session:
@ -177,7 +206,12 @@ def list_statistic_ids(hass, statistic_type=None):
return list(metadata.values())
def statistics_during_period(hass, start_time, end_time=None, statistic_ids=None):
def statistics_during_period(
hass: HomeAssistant,
start_time: datetime,
end_time: datetime | None = None,
statistic_ids: list[str] | None = None,
) -> dict[str, list[dict[str, str]]]:
"""Return states changes during UTC period start_time - end_time."""
metadata = None
with session_scope(hass=hass) as session:
@ -208,10 +242,14 @@ def statistics_during_period(hass, start_time, end_time=None, statistic_ids=None
start_time=start_time, end_time=end_time, metadata_ids=metadata_ids
)
)
if not stats:
return {}
return _sorted_statistics_to_dict(hass, stats, statistic_ids, metadata)
def get_last_statistics(hass, number_of_stats, statistic_id):
def get_last_statistics(
hass: HomeAssistant, number_of_stats: int, statistic_id: str
) -> dict[str, list[dict]]:
"""Return the last number_of_stats statistics for a statistic_id."""
statistic_ids = [statistic_id]
with session_scope(hass=hass) as session:
@ -237,18 +275,20 @@ def get_last_statistics(hass, number_of_stats, statistic_id):
number_of_stats=number_of_stats, metadata_id=metadata_id
)
)
if not stats:
return {}
return _sorted_statistics_to_dict(hass, stats, statistic_ids, metadata)
def _sorted_statistics_to_dict(
hass,
stats,
statistic_ids,
metadata,
):
hass: HomeAssistant,
stats: list,
statistic_ids: list[str] | None,
metadata: dict[str, dict[str, str]],
) -> dict[str, list[dict]]:
"""Convert SQL results into JSON friendly data structure."""
result = defaultdict(list)
result: dict = defaultdict(list)
units = hass.config.units
# Set all statistic IDs to empty lists in result set to maintain the order
@ -260,10 +300,12 @@ def _sorted_statistics_to_dict(
_process_timestamp_to_utc_isoformat = process_timestamp_to_utc_isoformat
# Append all statistic entries, and do unit conversion
for meta_id, group in groupby(stats, lambda state: state.metadata_id):
for meta_id, group in groupby(stats, lambda stat: stat.metadata_id): # type: ignore
unit = metadata[meta_id]["unit_of_measurement"]
statistic_id = metadata[meta_id]["statistic_id"]
convert = UNIT_CONVERSIONS.get(unit, lambda x, units: x)
convert: Callable[[Any, Any], float | None] = UNIT_CONVERSIONS.get(
unit, lambda x, units: x # type: ignore
)
ent_results = result[meta_id]
ent_results.extend(
{