Cleanup recorder history typing (#69408)

This commit is contained in:
J. Nick Koston 2022-04-07 00:09:05 -10:00 committed by GitHub
parent 97aa65d9a4
commit 5c7c09726a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 160 additions and 45 deletions

View file

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable, MutableMapping from collections.abc import Iterable, Iterator, MutableMapping
from datetime import datetime from datetime import datetime
from itertools import groupby from itertools import groupby
import logging import logging
@ -141,7 +141,7 @@ def get_significant_states(
significant_changes_only: bool = True, significant_changes_only: bool = True,
minimal_response: bool = False, minimal_response: bool = False,
no_attributes: bool = False, no_attributes: bool = False,
) -> MutableMapping[str, Iterable[LazyState | State | dict[str, Any]]]: ) -> MutableMapping[str, list[State | dict[str, Any]]]:
"""Wrap get_significant_states_with_session with an sql session.""" """Wrap get_significant_states_with_session with an sql session."""
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
return get_significant_states_with_session( return get_significant_states_with_session(
@ -158,31 +158,20 @@ def get_significant_states(
) )
def get_significant_states_with_session( def _query_significant_states_with_session(
hass: HomeAssistant, hass: HomeAssistant,
session: Session, session: Session,
start_time: datetime, start_time: datetime,
end_time: datetime | None = None, end_time: datetime | None = None,
entity_ids: list[str] | None = None, entity_ids: list[str] | None = None,
filters: Any = None, filters: Any = None,
include_start_time_state: bool = True,
significant_changes_only: bool = True, significant_changes_only: bool = True,
minimal_response: bool = False,
no_attributes: bool = False, no_attributes: bool = False,
) -> MutableMapping[str, Iterable[LazyState | State | dict[str, Any]]]: ) -> list[States]:
""" """Query the database for significant state changes."""
Return states changes during UTC period start_time - end_time. if _LOGGER.isEnabledFor(logging.DEBUG):
entity_ids is an optional iterable of entities to include in the results.
filters is an optional SQLAlchemy filter which will be applied to the database
queries unless entity_ids is given, in which case its ignored.
Significant states are all states where there is a state change,
as well as all states from certain domains (for instance
thermostat so that we get current temperature in our graphs).
"""
timer_start = time.perf_counter() timer_start = time.perf_counter()
baked_query, join_attributes = bake_query_and_join_attributes(hass, no_attributes) baked_query, join_attributes = bake_query_and_join_attributes(hass, no_attributes)
if entity_ids is not None and len(entity_ids) == 1: if entity_ids is not None and len(entity_ids) == 1:
@ -240,6 +229,43 @@ def get_significant_states_with_session(
elapsed = time.perf_counter() - timer_start elapsed = time.perf_counter() - timer_start
_LOGGER.debug("get_significant_states took %fs", elapsed) _LOGGER.debug("get_significant_states took %fs", elapsed)
return states
def get_significant_states_with_session(
hass: HomeAssistant,
session: Session,
start_time: datetime,
end_time: datetime | None = None,
entity_ids: list[str] | None = None,
filters: Any = None,
include_start_time_state: bool = True,
significant_changes_only: bool = True,
minimal_response: bool = False,
no_attributes: bool = False,
) -> MutableMapping[str, list[State | dict[str, Any]]]:
"""
Return states changes during UTC period start_time - end_time.
entity_ids is an optional iterable of entities to include in the results.
filters is an optional SQLAlchemy filter which will be applied to the database
queries unless entity_ids is given, in which case its ignored.
Significant states are all states where there is a state change,
as well as all states from certain domains (for instance
thermostat so that we get current temperature in our graphs).
"""
states = _query_significant_states_with_session(
hass,
session,
start_time,
end_time,
entity_ids,
filters,
significant_changes_only,
no_attributes,
)
return _sorted_states_to_dict( return _sorted_states_to_dict(
hass, hass,
session, session,
@ -253,6 +279,35 @@ def get_significant_states_with_session(
) )
def get_full_significant_states_with_session(
hass: HomeAssistant,
session: Session,
start_time: datetime,
end_time: datetime | None = None,
entity_ids: list[str] | None = None,
filters: Any = None,
include_start_time_state: bool = True,
significant_changes_only: bool = True,
no_attributes: bool = False,
) -> MutableMapping[str, list[State]]:
"""Variant of get_significant_states_with_session that does not return minimal responses."""
return cast(
MutableMapping[str, list[State]],
get_significant_states_with_session(
hass=hass,
session=session,
start_time=start_time,
end_time=end_time,
entity_ids=entity_ids,
filters=filters,
include_start_time_state=include_start_time_state,
significant_changes_only=significant_changes_only,
minimal_response=False,
no_attributes=no_attributes,
),
)
def state_changes_during_period( def state_changes_during_period(
hass: HomeAssistant, hass: HomeAssistant,
start_time: datetime, start_time: datetime,
@ -262,7 +317,7 @@ def state_changes_during_period(
descending: bool = False, descending: bool = False,
limit: int | None = None, limit: int | None = None,
include_start_time_state: bool = True, include_start_time_state: bool = True,
) -> MutableMapping[str, Iterable[LazyState]]: ) -> MutableMapping[str, list[State]]:
"""Return states changes during UTC period start_time - end_time.""" """Return states changes during UTC period start_time - end_time."""
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
baked_query, join_attributes = bake_query_and_join_attributes( baked_query, join_attributes = bake_query_and_join_attributes(
@ -303,7 +358,7 @@ def state_changes_during_period(
entity_ids = [entity_id] if entity_id is not None else None entity_ids = [entity_id] if entity_id is not None else None
return cast( return cast(
MutableMapping[str, Iterable[LazyState]], MutableMapping[str, list[State]],
_sorted_states_to_dict( _sorted_states_to_dict(
hass, hass,
session, session,
@ -317,7 +372,7 @@ def state_changes_during_period(
def get_last_state_changes( def get_last_state_changes(
hass: HomeAssistant, number_of_states: int, entity_id: str hass: HomeAssistant, number_of_states: int, entity_id: str
) -> MutableMapping[str, Iterable[LazyState]]: ) -> MutableMapping[str, list[State]]:
"""Return the last number_of_states.""" """Return the last number_of_states."""
start_time = dt_util.utcnow() start_time = dt_util.utcnow()
@ -349,7 +404,7 @@ def get_last_state_changes(
entity_ids = [entity_id] if entity_id is not None else None entity_ids = [entity_id] if entity_id is not None else None
return cast( return cast(
MutableMapping[str, Iterable[LazyState]], MutableMapping[str, list[State]],
_sorted_states_to_dict( _sorted_states_to_dict(
hass, hass,
session, session,
@ -368,7 +423,7 @@ def get_states(
run: RecorderRuns | None = None, run: RecorderRuns | None = None,
filters: Any = None, filters: Any = None,
no_attributes: bool = False, no_attributes: bool = False,
) -> list[LazyState]: ) -> list[State]:
"""Return the states at a specific point in time.""" """Return the states at a specific point in time."""
if ( if (
run is None run is None
@ -392,7 +447,7 @@ def _get_states_with_session(
run: RecorderRuns | None = None, run: RecorderRuns | None = None,
filters: Any | None = None, filters: Any | None = None,
no_attributes: bool = False, no_attributes: bool = False,
) -> list[LazyState]: ) -> list[State]:
"""Return the states at a specific point in time.""" """Return the states at a specific point in time."""
if entity_ids and len(entity_ids) == 1: if entity_ids and len(entity_ids) == 1:
return _get_single_entity_states_with_session( return _get_single_entity_states_with_session(
@ -488,7 +543,7 @@ def _get_single_entity_states_with_session(
utc_point_in_time: datetime, utc_point_in_time: datetime,
entity_id: str, entity_id: str,
no_attributes: bool = False, no_attributes: bool = False,
) -> list[LazyState]: ) -> list[State]:
# Use an entirely different (and extremely fast) query if we only # Use an entirely different (and extremely fast) query if we only
# have a single entity id # have a single entity id
baked_query, join_attributes = bake_query_and_join_attributes(hass, no_attributes) baked_query, join_attributes = bake_query_and_join_attributes(hass, no_attributes)
@ -520,7 +575,7 @@ def _sorted_states_to_dict(
include_start_time_state: bool = True, include_start_time_state: bool = True,
minimal_response: bool = False, minimal_response: bool = False,
no_attributes: bool = False, no_attributes: bool = False,
) -> MutableMapping[str, Iterable[LazyState | State | dict[str, Any]]]: ) -> MutableMapping[str, list[State | dict[str, Any]]]:
"""Convert SQL results into JSON friendly data structure. """Convert SQL results into JSON friendly data structure.
This takes our state list and turns it into a JSON friendly data This takes our state list and turns it into a JSON friendly data
@ -532,7 +587,7 @@ def _sorted_states_to_dict(
each list of states, otherwise our graphs won't start on the Y each list of states, otherwise our graphs won't start on the Y
axis correctly. axis correctly.
""" """
result: dict[str, list[LazyState | dict[str, Any]]] = defaultdict(list) result: dict[str, list[State | dict[str, Any]]] = defaultdict(list)
# Set all entity IDs to empty lists in result set to maintain the order # Set all entity IDs to empty lists in result set to maintain the order
if entity_ids is not None: if entity_ids is not None:
for ent_id in entity_ids: for ent_id in entity_ids:
@ -563,21 +618,30 @@ def _sorted_states_to_dict(
# here # here
_process_timestamp_to_utc_isoformat = process_timestamp_to_utc_isoformat _process_timestamp_to_utc_isoformat = process_timestamp_to_utc_isoformat
if entity_ids and len(entity_ids) == 1:
states_iter: Iterable[tuple[str | Column, Iterator[States]]] = (
(entity_ids[0], iter(states)),
)
else:
states_iter = groupby(states, lambda state: state.entity_id)
# Append all changes to it # Append all changes to it
for ent_id, group in groupby(states, lambda state: state.entity_id): # type: ignore[no-any-return] for ent_id, group in states_iter:
domain = split_entity_id(ent_id)[0]
ent_results = result[ent_id] ent_results = result[ent_id]
attr_cache: dict[str, dict[str, Any]] = {} attr_cache: dict[str, dict[str, Any]] = {}
if not minimal_response or domain in NEED_ATTRIBUTE_DOMAINS: if not minimal_response or split_entity_id(ent_id)[0] in NEED_ATTRIBUTE_DOMAINS:
ent_results.extend(LazyState(db_state, attr_cache) for db_state in group) ent_results.extend(LazyState(db_state, attr_cache) for db_state in group)
continue
# With minimal response we only provide a native # With minimal response we only provide a native
# State for the first and last response. All the states # State for the first and last response. All the states
# in-between only provide the "state" and the # in-between only provide the "state" and the
# "last_changed". # "last_changed".
if not ent_results: if not ent_results:
ent_results.append(LazyState(next(group), attr_cache)) if (first_state := next(group, None)) is None:
continue
ent_results.append(LazyState(first_state, attr_cache))
prev_state = ent_results[-1] prev_state = ent_results[-1]
assert isinstance(prev_state, LazyState) assert isinstance(prev_state, LazyState)
@ -615,7 +679,7 @@ def get_state(
entity_id: str, entity_id: str,
run: RecorderRuns | None = None, run: RecorderRuns | None = None,
no_attributes: bool = False, no_attributes: bool = False,
) -> LazyState | None: ) -> State | None:
"""Return a state at a specific point in time.""" """Return a state at a specific point in time."""
states = get_states(hass, utc_point_in_time, [entity_id], run, None, no_attributes) states = get_states(hass, utc_point_in_time, [entity_id], run, None, no_attributes)
return states[0] if states else None return states[0] if states else None

View file

@ -7,7 +7,7 @@ import datetime
import itertools import itertools
import logging import logging
import math import math
from typing import Any, cast from typing import Any
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
@ -19,7 +19,6 @@ from homeassistant.components.recorder import (
) )
from homeassistant.components.recorder.const import DOMAIN as RECORDER_DOMAIN from homeassistant.components.recorder.const import DOMAIN as RECORDER_DOMAIN
from homeassistant.components.recorder.models import ( from homeassistant.components.recorder.models import (
LazyState,
StatisticData, StatisticData,
StatisticMetaData, StatisticMetaData,
StatisticResult, StatisticResult,
@ -417,9 +416,9 @@ def _compile_statistics( # noqa: C901
entities_full_history = [ entities_full_history = [
i.entity_id for i in sensor_states if "sum" in wanted_statistics[i.entity_id] i.entity_id for i in sensor_states if "sum" in wanted_statistics[i.entity_id]
] ]
history_list: MutableMapping[str, Iterable[LazyState | State | dict[str, Any]]] = {} history_list: MutableMapping[str, list[State]] = {}
if entities_full_history: if entities_full_history:
history_list = history.get_significant_states_with_session( history_list = history.get_full_significant_states_with_session(
hass, hass,
session, session,
start - datetime.timedelta.resolution, start - datetime.timedelta.resolution,
@ -433,7 +432,7 @@ def _compile_statistics( # noqa: C901
if "sum" not in wanted_statistics[i.entity_id] if "sum" not in wanted_statistics[i.entity_id]
] ]
if entities_significant_history: if entities_significant_history:
_history_list = history.get_significant_states_with_session( _history_list = history.get_full_significant_states_with_session(
hass, hass,
session, session,
start - datetime.timedelta.resolution, start - datetime.timedelta.resolution,
@ -445,7 +444,7 @@ def _compile_statistics( # noqa: C901
# from the recorder. Get the state from the state machine instead. # from the recorder. Get the state from the state machine instead.
for _state in sensor_states: for _state in sensor_states:
if _state.entity_id not in history_list: if _state.entity_id not in history_list:
history_list[_state.entity_id] = (_state,) history_list[_state.entity_id] = [_state]
for _state in sensor_states: # pylint: disable=too-many-nested-blocks for _state in sensor_states: # pylint: disable=too-many-nested-blocks
entity_id = _state.entity_id entity_id = _state.entity_id
@ -459,9 +458,7 @@ def _compile_statistics( # noqa: C901
hass, hass,
session, session,
old_metadatas, old_metadatas,
# entity_history does not contain minimal responses entity_history,
# so we must cast here
cast(list[State], entity_history),
device_class, device_class,
entity_id, entity_id,
) )

View file

@ -485,16 +485,14 @@ class StatisticsSensor(SensorEntity):
else: else:
start_date = datetime.fromtimestamp(0, tz=dt_util.UTC) start_date = datetime.fromtimestamp(0, tz=dt_util.UTC)
_LOGGER.debug("%s: retrieving all records", self.entity_id) _LOGGER.debug("%s: retrieving all records", self.entity_id)
entity_states = history.state_changes_during_period( return history.state_changes_during_period(
self.hass, self.hass,
start_date, start_date,
entity_id=lower_entity_id, entity_id=lower_entity_id,
descending=True, descending=True,
limit=self._samples_max_buffer_size, limit=self._samples_max_buffer_size,
include_start_time_state=False, include_start_time_state=False,
) ).get(lower_entity_id, [])
# Need to cast since minimal responses is not passed in
return cast(list[State], entity_states.get(lower_entity_id, []))
async def _initialize_from_database(self) -> None: async def _initialize_from_database(self) -> None:
"""Initialize the list of states from the database. """Initialize the list of states from the database.

View file

@ -124,6 +124,62 @@ def test_get_states(hass_recorder):
assert history.get_state(hass, time_before_recorder_ran, "demo.id") is None assert history.get_state(hass, time_before_recorder_ran, "demo.id") is None
def test_get_full_significant_states_with_session_entity_no_matches(hass_recorder):
"""Test getting states at a specific point in time for entities that never have been recorded."""
hass = hass_recorder()
now = dt_util.utcnow()
time_before_recorder_ran = now - timedelta(days=1000)
with recorder.session_scope(hass=hass) as session:
assert (
history.get_full_significant_states_with_session(
hass, session, time_before_recorder_ran, now, entity_ids=["demo.id"]
)
== {}
)
assert (
history.get_full_significant_states_with_session(
hass,
session,
time_before_recorder_ran,
now,
entity_ids=["demo.id", "demo.id2"],
)
== {}
)
def test_significant_states_with_session_entity_minimal_response_no_matches(
hass_recorder,
):
"""Test getting states at a specific point in time for entities that never have been recorded."""
hass = hass_recorder()
now = dt_util.utcnow()
time_before_recorder_ran = now - timedelta(days=1000)
with recorder.session_scope(hass=hass) as session:
assert (
history.get_significant_states_with_session(
hass,
session,
time_before_recorder_ran,
now,
entity_ids=["demo.id"],
minimal_response=True,
)
== {}
)
assert (
history.get_significant_states_with_session(
hass,
session,
time_before_recorder_ran,
now,
entity_ids=["demo.id", "demo.id2"],
minimal_response=True,
)
== {}
)
def test_get_states_no_attributes(hass_recorder): def test_get_states_no_attributes(hass_recorder):
"""Test getting states without attributes at a specific point in time.""" """Test getting states without attributes at a specific point in time."""
hass = hass_recorder() hass = hass_recorder()