Enable basic type checking for recorder (#52440)
* Enable basic type checking for recorder * Tweak
This commit is contained in:
parent
960684346f
commit
19d3aa71ad
7 changed files with 108 additions and 39 deletions
|
@ -5,18 +5,26 @@ from collections import defaultdict
|
|||
from datetime import datetime, timedelta
|
||||
from itertools import groupby
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from sqlalchemy import bindparam
|
||||
from sqlalchemy.ext import baked
|
||||
from sqlalchemy.orm.scoping import scoped_session
|
||||
|
||||
from homeassistant.const import PRESSURE_PA, TEMP_CELSIUS
|
||||
from homeassistant.core import HomeAssistant
|
||||
import homeassistant.util.dt as dt_util
|
||||
import homeassistant.util.pressure as pressure_util
|
||||
import homeassistant.util.temperature as temperature_util
|
||||
from homeassistant.util.unit_system import UnitSystem
|
||||
|
||||
from .const import DOMAIN
|
||||
from .models import Statistics, StatisticsMeta, process_timestamp_to_utc_isoformat
|
||||
from .models import (
|
||||
StatisticMetaData,
|
||||
Statistics,
|
||||
StatisticsMeta,
|
||||
process_timestamp_to_utc_isoformat,
|
||||
)
|
||||
from .util import execute, retryable_database_job, session_scope
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -60,20 +68,22 @@ UNIT_CONVERSIONS = {
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def async_setup(hass):
|
||||
def async_setup(hass: HomeAssistant) -> None:
|
||||
"""Set up the history hooks."""
|
||||
hass.data[STATISTICS_BAKERY] = baked.bakery()
|
||||
hass.data[STATISTICS_META_BAKERY] = baked.bakery()
|
||||
|
||||
|
||||
def get_start_time() -> datetime.datetime:
|
||||
def get_start_time() -> datetime:
|
||||
"""Return start time."""
|
||||
last_hour = dt_util.utcnow() - timedelta(hours=1)
|
||||
start = last_hour.replace(minute=0, second=0, microsecond=0)
|
||||
return start
|
||||
|
||||
|
||||
def _get_metadata_ids(hass, session, statistic_ids):
|
||||
def _get_metadata_ids(
|
||||
hass: HomeAssistant, session: scoped_session, statistic_ids: list[str]
|
||||
) -> list[str]:
|
||||
"""Resolve metadata_id for a list of statistic_ids."""
|
||||
baked_query = hass.data[STATISTICS_META_BAKERY](
|
||||
lambda session: session.query(*QUERY_STATISTIC_META)
|
||||
|
@ -83,10 +93,15 @@ def _get_metadata_ids(hass, session, statistic_ids):
|
|||
)
|
||||
result = execute(baked_query(session).params(statistic_ids=statistic_ids))
|
||||
|
||||
return [id for id, _, _ in result]
|
||||
return [id for id, _, _ in result] if result else []
|
||||
|
||||
|
||||
def _get_or_add_metadata_id(hass, session, statistic_id, metadata):
|
||||
def _get_or_add_metadata_id(
|
||||
hass: HomeAssistant,
|
||||
session: scoped_session,
|
||||
statistic_id: str,
|
||||
metadata: StatisticMetaData,
|
||||
) -> str:
|
||||
"""Get metadata_id for a statistic_id, add if it doesn't exist."""
|
||||
metadata_id = _get_metadata_ids(hass, session, [statistic_id])
|
||||
if not metadata_id:
|
||||
|
@ -101,7 +116,7 @@ def _get_or_add_metadata_id(hass, session, statistic_id, metadata):
|
|||
|
||||
|
||||
@retryable_database_job("statistics")
|
||||
def compile_statistics(instance: Recorder, start: datetime.datetime) -> bool:
|
||||
def compile_statistics(instance: Recorder, start: datetime) -> bool:
|
||||
"""Compile statistics."""
|
||||
start = dt_util.as_utc(start)
|
||||
end = start + timedelta(hours=1)
|
||||
|
@ -126,10 +141,15 @@ def compile_statistics(instance: Recorder, start: datetime.datetime) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def _get_metadata(hass, session, statistic_ids, statistic_type):
|
||||
def _get_metadata(
|
||||
hass: HomeAssistant,
|
||||
session: scoped_session,
|
||||
statistic_ids: list[str] | None,
|
||||
statistic_type: str | None,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""Fetch meta data."""
|
||||
|
||||
def _meta(metas, wanted_metadata_id):
|
||||
def _meta(metas: list, wanted_metadata_id: str) -> dict[str, str] | None:
|
||||
meta = None
|
||||
for metadata_id, statistic_id, unit in metas:
|
||||
if metadata_id == wanted_metadata_id:
|
||||
|
@ -150,12 +170,19 @@ def _get_metadata(hass, session, statistic_ids, statistic_type):
|
|||
elif statistic_type is not None:
|
||||
return {}
|
||||
result = execute(baked_query(session).params(statistic_ids=statistic_ids))
|
||||
if not result:
|
||||
return {}
|
||||
|
||||
metadata_ids = [metadata[0] for metadata in result]
|
||||
return {id: _meta(result, id) for id in metadata_ids}
|
||||
metadata = {}
|
||||
for _id in metadata_ids:
|
||||
meta = _meta(result, _id)
|
||||
if meta:
|
||||
metadata[_id] = meta
|
||||
return metadata
|
||||
|
||||
|
||||
def _configured_unit(unit: str, units) -> str:
|
||||
def _configured_unit(unit: str, units: UnitSystem) -> str:
|
||||
"""Return the pressure and temperature units configured by the user."""
|
||||
if unit == PRESSURE_PA:
|
||||
return units.pressure_unit
|
||||
|
@ -164,7 +191,9 @@ def _configured_unit(unit: str, units) -> str:
|
|||
return unit
|
||||
|
||||
|
||||
def list_statistic_ids(hass, statistic_type=None):
|
||||
def list_statistic_ids(
|
||||
hass: HomeAssistant, statistic_type: str | None = None
|
||||
) -> list[dict[str, str] | None]:
|
||||
"""Return statistic_ids and meta data."""
|
||||
units = hass.config.units
|
||||
with session_scope(hass=hass) as session:
|
||||
|
@ -177,7 +206,12 @@ def list_statistic_ids(hass, statistic_type=None):
|
|||
return list(metadata.values())
|
||||
|
||||
|
||||
def statistics_during_period(hass, start_time, end_time=None, statistic_ids=None):
|
||||
def statistics_during_period(
|
||||
hass: HomeAssistant,
|
||||
start_time: datetime,
|
||||
end_time: datetime | None = None,
|
||||
statistic_ids: list[str] | None = None,
|
||||
) -> dict[str, list[dict[str, str]]]:
|
||||
"""Return states changes during UTC period start_time - end_time."""
|
||||
metadata = None
|
||||
with session_scope(hass=hass) as session:
|
||||
|
@ -208,10 +242,14 @@ def statistics_during_period(hass, start_time, end_time=None, statistic_ids=None
|
|||
start_time=start_time, end_time=end_time, metadata_ids=metadata_ids
|
||||
)
|
||||
)
|
||||
if not stats:
|
||||
return {}
|
||||
return _sorted_statistics_to_dict(hass, stats, statistic_ids, metadata)
|
||||
|
||||
|
||||
def get_last_statistics(hass, number_of_stats, statistic_id):
|
||||
def get_last_statistics(
|
||||
hass: HomeAssistant, number_of_stats: int, statistic_id: str
|
||||
) -> dict[str, list[dict]]:
|
||||
"""Return the last number_of_stats statistics for a statistic_id."""
|
||||
statistic_ids = [statistic_id]
|
||||
with session_scope(hass=hass) as session:
|
||||
|
@ -237,18 +275,20 @@ def get_last_statistics(hass, number_of_stats, statistic_id):
|
|||
number_of_stats=number_of_stats, metadata_id=metadata_id
|
||||
)
|
||||
)
|
||||
if not stats:
|
||||
return {}
|
||||
|
||||
return _sorted_statistics_to_dict(hass, stats, statistic_ids, metadata)
|
||||
|
||||
|
||||
def _sorted_statistics_to_dict(
|
||||
hass,
|
||||
stats,
|
||||
statistic_ids,
|
||||
metadata,
|
||||
):
|
||||
hass: HomeAssistant,
|
||||
stats: list,
|
||||
statistic_ids: list[str] | None,
|
||||
metadata: dict[str, dict[str, str]],
|
||||
) -> dict[str, list[dict]]:
|
||||
"""Convert SQL results into JSON friendly data structure."""
|
||||
result = defaultdict(list)
|
||||
result: dict = defaultdict(list)
|
||||
units = hass.config.units
|
||||
|
||||
# Set all statistic IDs to empty lists in result set to maintain the order
|
||||
|
@ -260,10 +300,12 @@ def _sorted_statistics_to_dict(
|
|||
_process_timestamp_to_utc_isoformat = process_timestamp_to_utc_isoformat
|
||||
|
||||
# Append all statistic entries, and do unit conversion
|
||||
for meta_id, group in groupby(stats, lambda state: state.metadata_id):
|
||||
for meta_id, group in groupby(stats, lambda stat: stat.metadata_id): # type: ignore
|
||||
unit = metadata[meta_id]["unit_of_measurement"]
|
||||
statistic_id = metadata[meta_id]["statistic_id"]
|
||||
convert = UNIT_CONVERSIONS.get(unit, lambda x, units: x)
|
||||
convert: Callable[[Any, Any], float | None] = UNIT_CONVERSIONS.get(
|
||||
unit, lambda x, units: x # type: ignore
|
||||
)
|
||||
ent_results = result[meta_id]
|
||||
ent_results.extend(
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue