Add typing to statistics results (#89118)

This commit is contained in:
J. Nick Koston 2023-03-14 09:06:56 -10:00 committed by GitHub
parent 9d2c62095f
commit a6d6807dd0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 63 additions and 42 deletions

View file

@ -13,6 +13,7 @@ from typing import Any, cast
import voluptuous as vol import voluptuous as vol
from homeassistant.components import recorder, websocket_api from homeassistant.components import recorder, websocket_api
from homeassistant.components.recorder.statistics import StatisticsRow
from homeassistant.const import UnitOfEnergy from homeassistant.const import UnitOfEnergy
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.integration_platform import ( from homeassistant.helpers.integration_platform import (
@ -277,7 +278,7 @@ async def ws_get_fossil_energy_consumption(
) )
def _combine_sum_statistics( def _combine_sum_statistics(
stats: dict[str, list[dict[str, Any]]], statistic_ids: list[str] stats: dict[str, list[StatisticsRow]], statistic_ids: list[str]
) -> dict[float, float]: ) -> dict[float, float]:
"""Combine multiple statistics, returns a dict indexed by start time.""" """Combine multiple statistics, returns a dict indexed by start time."""
result: defaultdict[float, float] = defaultdict(float) result: defaultdict[float, float] = defaultdict(float)
@ -313,11 +314,10 @@ async def ws_get_fossil_energy_consumption(
if not stat_list: if not stat_list:
return result return result
prev_stat: dict[str, Any] = stat_list[0] prev_stat: dict[str, Any] = stat_list[0]
fake_stat = {"start": stat_list[-1]["start"] + period.total_seconds()}
# Loop over the hourly deltas + a fake entry to end the period # Loop over the hourly deltas + a fake entry to end the period
for statistic in chain( for statistic in chain(stat_list, (fake_stat,)):
stat_list, ({"start": stat_list[-1]["start"] + period.total_seconds()},)
):
if not same_period(prev_stat["start"], statistic["start"]): if not same_period(prev_stat["start"], statistic["start"]):
start, _ = period_start_end(prev_stat["start"]) start, _ = period_start_end(prev_stat["start"])
# The previous statistic was the last entry of the period # The previous statistic was the last entry of the period
@ -338,10 +338,13 @@ async def ws_get_fossil_energy_consumption(
statistics, msg["energy_statistic_ids"] statistics, msg["energy_statistic_ids"]
) )
energy_deltas = _calculate_deltas(merged_energy_statistics) energy_deltas = _calculate_deltas(merged_energy_statistics)
indexed_co2_statistics = { indexed_co2_statistics = cast(
period["start"]: period["mean"] dict[float, float],
for period in statistics.get(msg["co2_statistic_id"], {}) {
} period["start"]: period["mean"]
for period in statistics.get(msg["co2_statistic_id"], {})
},
)
# Calculate amount of fossil based energy, assume 100% fossil if missing # Calculate amount of fossil based energy, assume 100% fossil if missing
fossil_energy = [ fossil_energy = [

View file

@ -14,7 +14,7 @@ from operator import itemgetter
import os import os
import re import re
from statistics import mean from statistics import mean
from typing import TYPE_CHECKING, Any, Literal, cast from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast
from sqlalchemy import Select, and_, bindparam, func, lambda_stmt, select, text from sqlalchemy import Select, and_, bindparam, func, lambda_stmt, select, text
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
@ -166,6 +166,24 @@ STATISTIC_UNIT_TO_UNIT_CONVERTER: dict[str | None, type[BaseUnitConverter]] = {
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
class BaseStatisticsRow(TypedDict, total=False):
    """Base shape of a processed statistics row.

    ``total=False`` makes the key optional at the type level. ``start`` is
    the beginning of the statistics period as a Unix timestamp in float
    seconds (rows are built from ``start_time.timestamp()`` values).
    """

    # Period start, Unix timestamp in float seconds.
    start: float
class StatisticsRow(BaseStatisticsRow, total=False):
    """A processed row of statistic data.

    All keys are optional (``total=False``): which of them are present
    depends on the ``types`` set requested when the statistics were
    fetched/reduced (``{"last_reset", "max", "mean", "min", "state", "sum"}``).
    """

    # Period end, Unix timestamp; computed as start + period duration seconds.
    end: float
    # Timestamp of the source meter's last reset, or None if not available.
    last_reset: float | None
    # Last state value within the period, or None if not recorded.
    state: float | None
    # Running sum at the end of the period, or None if not recorded.
    sum: float | None
    # Minimum value observed in the period, or None if not requested/recorded.
    min: float | None
    # Maximum value observed in the period, or None if not requested/recorded.
    max: float | None
    # Mean value over the period, or None if not requested/recorded.
    mean: float | None
def _get_unit_class(unit: str | None) -> str | None: def _get_unit_class(unit: str | None) -> str | None:
"""Get corresponding unit class from from the statistics unit.""" """Get corresponding unit class from from the statistics unit."""
if converter := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(unit): if converter := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(unit):
@ -1048,14 +1066,14 @@ def list_statistic_ids(
def _reduce_statistics( def _reduce_statistics(
stats: dict[str, list[dict[str, Any]]], stats: dict[str, list[StatisticsRow]],
same_period: Callable[[float, float], bool], same_period: Callable[[float, float], bool],
period_start_end: Callable[[float], tuple[float, float]], period_start_end: Callable[[float], tuple[float, float]],
period: timedelta, period: timedelta,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict[str, Any]]]: ) -> dict[str, list[StatisticsRow]]:
"""Reduce hourly statistics to daily or monthly statistics.""" """Reduce hourly statistics to daily or monthly statistics."""
result: dict[str, list[dict[str, Any]]] = defaultdict(list) result: dict[str, list[StatisticsRow]] = defaultdict(list)
period_seconds = period.total_seconds() period_seconds = period.total_seconds()
_want_mean = "mean" in types _want_mean = "mean" in types
_want_min = "min" in types _want_min = "min" in types
@ -1067,16 +1085,15 @@ def _reduce_statistics(
max_values: list[float] = [] max_values: list[float] = []
mean_values: list[float] = [] mean_values: list[float] = []
min_values: list[float] = [] min_values: list[float] = []
prev_stat: dict[str, Any] = stat_list[0] prev_stat: StatisticsRow = stat_list[0]
fake_entry: StatisticsRow = {"start": stat_list[-1]["start"] + period_seconds}
# Loop over the hourly statistics + a fake entry to end the period # Loop over the hourly statistics + a fake entry to end the period
for statistic in chain( for statistic in chain(stat_list, (fake_entry,)):
stat_list, ({"start": stat_list[-1]["start"] + period_seconds},)
):
if not same_period(prev_stat["start"], statistic["start"]): if not same_period(prev_stat["start"], statistic["start"]):
start, end = period_start_end(prev_stat["start"]) start, end = period_start_end(prev_stat["start"])
# The previous statistic was the last entry of the period # The previous statistic was the last entry of the period
row: dict[str, Any] = { row: StatisticsRow = {
"start": start, "start": start,
"end": end, "end": end,
} }
@ -1146,9 +1163,9 @@ def reduce_day_ts_factory() -> (
def _reduce_statistics_per_day( def _reduce_statistics_per_day(
stats: dict[str, list[dict[str, Any]]], stats: dict[str, list[StatisticsRow]],
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict[str, Any]]]: ) -> dict[str, list[StatisticsRow]]:
"""Reduce hourly statistics to daily statistics.""" """Reduce hourly statistics to daily statistics."""
_same_day_ts, _day_start_end_ts = reduce_day_ts_factory() _same_day_ts, _day_start_end_ts = reduce_day_ts_factory()
return _reduce_statistics( return _reduce_statistics(
@ -1196,9 +1213,9 @@ def reduce_week_ts_factory() -> (
def _reduce_statistics_per_week( def _reduce_statistics_per_week(
stats: dict[str, list[dict[str, Any]]], stats: dict[str, list[StatisticsRow]],
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict[str, Any]]]: ) -> dict[str, list[StatisticsRow]]:
"""Reduce hourly statistics to weekly statistics.""" """Reduce hourly statistics to weekly statistics."""
_same_week_ts, _week_start_end_ts = reduce_week_ts_factory() _same_week_ts, _week_start_end_ts = reduce_week_ts_factory()
return _reduce_statistics( return _reduce_statistics(
@ -1248,9 +1265,9 @@ def reduce_month_ts_factory() -> (
def _reduce_statistics_per_month( def _reduce_statistics_per_month(
stats: dict[str, list[dict[str, Any]]], stats: dict[str, list[StatisticsRow]],
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict[str, Any]]]: ) -> dict[str, list[StatisticsRow]]:
"""Reduce hourly statistics to monthly statistics.""" """Reduce hourly statistics to monthly statistics."""
_same_month_ts, _month_start_end_ts = reduce_month_ts_factory() _same_month_ts, _month_start_end_ts = reduce_month_ts_factory()
return _reduce_statistics( return _reduce_statistics(
@ -1724,7 +1741,7 @@ def _statistics_during_period_with_session(
period: Literal["5minute", "day", "hour", "week", "month"], period: Literal["5minute", "day", "hour", "week", "month"],
units: dict[str, str] | None, units: dict[str, str] | None,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict[str, Any]]]: ) -> dict[str, list[StatisticsRow]]:
"""Return statistic data points during UTC period start_time - end_time. """Return statistic data points during UTC period start_time - end_time.
If end_time is omitted, returns statistics newer than or equal to start_time. If end_time is omitted, returns statistics newer than or equal to start_time.
@ -1808,7 +1825,7 @@ def statistics_during_period(
period: Literal["5minute", "day", "hour", "week", "month"], period: Literal["5minute", "day", "hour", "week", "month"],
units: dict[str, str] | None, units: dict[str, str] | None,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict[str, Any]]]: ) -> dict[str, list[StatisticsRow]]:
"""Return statistic data points during UTC period start_time - end_time. """Return statistic data points during UTC period start_time - end_time.
If end_time is omitted, returns statistics newer than or equal to start_time. If end_time is omitted, returns statistics newer than or equal to start_time.
@ -1863,7 +1880,7 @@ def _get_last_statistics(
convert_units: bool, convert_units: bool,
table: type[StatisticsBase], table: type[StatisticsBase],
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict]]: ) -> dict[str, list[StatisticsRow]]:
"""Return the last number_of_stats statistics for a given statistic_id.""" """Return the last number_of_stats statistics for a given statistic_id."""
statistic_ids = [statistic_id] statistic_ids = [statistic_id]
with session_scope(hass=hass, read_only=True) as session: with session_scope(hass=hass, read_only=True) as session:
@ -1902,7 +1919,7 @@ def get_last_statistics(
statistic_id: str, statistic_id: str,
convert_units: bool, convert_units: bool,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict]]: ) -> dict[str, list[StatisticsRow]]:
"""Return the last number_of_stats statistics for a statistic_id.""" """Return the last number_of_stats statistics for a statistic_id."""
return _get_last_statistics( return _get_last_statistics(
hass, number_of_stats, statistic_id, convert_units, Statistics, types hass, number_of_stats, statistic_id, convert_units, Statistics, types
@ -1915,7 +1932,7 @@ def get_last_short_term_statistics(
statistic_id: str, statistic_id: str,
convert_units: bool, convert_units: bool,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict]]: ) -> dict[str, list[StatisticsRow]]:
"""Return the last number_of_stats short term statistics for a statistic_id.""" """Return the last number_of_stats short term statistics for a statistic_id."""
return _get_last_statistics( return _get_last_statistics(
hass, number_of_stats, statistic_id, convert_units, StatisticsShortTerm, types hass, number_of_stats, statistic_id, convert_units, StatisticsShortTerm, types
@ -1951,7 +1968,7 @@ def get_latest_short_term_statistics(
statistic_ids: list[str], statistic_ids: list[str],
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
metadata: dict[str, tuple[int, StatisticMetaData]] | None = None, metadata: dict[str, tuple[int, StatisticMetaData]] | None = None,
) -> dict[str, list[dict]]: ) -> dict[str, list[StatisticsRow]]:
"""Return the latest short term statistics for a list of statistic_ids.""" """Return the latest short term statistics for a list of statistic_ids."""
with session_scope(hass=hass, read_only=True) as session: with session_scope(hass=hass, read_only=True) as session:
# Fetch metadata for the given statistic_ids # Fetch metadata for the given statistic_ids
@ -2054,10 +2071,10 @@ def _sorted_statistics_to_dict(
start_time: datetime | None, start_time: datetime | None,
units: dict[str, str] | None, units: dict[str, str] | None,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
) -> dict[str, list[dict]]: ) -> dict[str, list[StatisticsRow]]:
"""Convert SQL results into JSON friendly data structure.""" """Convert SQL results into JSON friendly data structure."""
assert stats, "stats must not be empty" # Guard against implementation error assert stats, "stats must not be empty" # Guard against implementation error
result: dict = defaultdict(list) result: dict[str, list[StatisticsRow]] = defaultdict(list)
metadata = dict(_metadata.values()) metadata = dict(_metadata.values())
need_stat_at_start_time: set[int] = set() need_stat_at_start_time: set[int] = set()
start_time_ts = start_time.timestamp() if start_time else None start_time_ts = start_time.timestamp() if start_time else None
@ -2123,7 +2140,7 @@ def _sorted_statistics_to_dict(
# attribute lookups, and dict lookups as much as possible. # attribute lookups, and dict lookups as much as possible.
# #
for db_state in stats_list: for db_state in stats_list:
row: dict[str, Any] = { row: StatisticsRow = {
"start": (start_ts := db_state[start_ts_idx]), "start": (start_ts := db_state[start_ts_idx]),
"end": start_ts + table_duration_seconds, "end": start_ts + table_duration_seconds,
} }

View file

@ -529,11 +529,11 @@ def _compile_statistics( # noqa: C901
if entity_id in last_stats: if entity_id in last_stats:
# We have compiled history for this sensor before, # We have compiled history for this sensor before,
# use that as a starting point. # use that as a starting point.
last_reset = old_last_reset = _timestamp_to_isoformat_or_none( last_stat = last_stats[entity_id][0]
last_stats[entity_id][0]["last_reset"] last_reset = _timestamp_to_isoformat_or_none(last_stat["last_reset"])
) old_last_reset = last_reset
new_state = old_state = last_stats[entity_id][0]["state"] new_state = old_state = last_stat["state"]
_sum = last_stats[entity_id][0]["sum"] or 0.0 _sum = last_stat["sum"] or 0.0
for fstate, state in fstates: for fstate, state in fstates:
reset = False reset = False
@ -596,7 +596,7 @@ def _compile_statistics( # noqa: C901
if reset: if reset:
# The sensor has been reset, update the sum # The sensor has been reset, update the sum
if old_state is not None: if old_state is not None and new_state is not None:
_sum += new_state - old_state _sum += new_state - old_state
# ..and update the starting point # ..and update the starting point
new_state = fstate new_state = fstate

View file

@ -6,7 +6,7 @@ import datetime
from datetime import timedelta from datetime import timedelta
import logging import logging
from random import randrange from random import randrange
from typing import Any from typing import Any, cast
import aiohttp import aiohttp
import tibber import tibber
@ -614,7 +614,7 @@ class TibberDataCoordinator(DataUpdateCoordinator[None]):
5 * 365 * 24, production=is_production 5 * 365 * 24, production=is_production
) )
_sum = 0 _sum = 0.0
last_stats_time = None last_stats_time = None
else: else:
# hourly_consumption/production_data contains the last 30 days # hourly_consumption/production_data contains the last 30 days
@ -641,8 +641,9 @@ class TibberDataCoordinator(DataUpdateCoordinator[None]):
None, None,
{"sum"}, {"sum"},
) )
_sum = stat[statistic_id][0]["sum"] first_stat = stat[statistic_id][0]
last_stats_time = stat[statistic_id][0]["start"] _sum = cast(float, first_stat["sum"])
last_stats_time = first_stat["start"]
statistics = [] statistics = []