From 8cf6ac281ee6ad4122fa0836fad297d2da50cc9c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" <nick@koston.org> Date: Fri, 1 Apr 2022 05:49:21 -1000 Subject: [PATCH] Convert statistics to use history api for database access (#68411) --- homeassistant/components/statistics/sensor.py | 54 ++++++++----------- 1 file changed, 22 insertions(+), 32 deletions(-) diff --git a/homeassistant/components/statistics/sensor.py b/homeassistant/components/statistics/sensor.py index d632e9710cf..aaca8a98290 100644 --- a/homeassistant/components/statistics/sensor.py +++ b/homeassistant/components/statistics/sensor.py @@ -12,9 +12,7 @@ from typing import Any, Literal, cast import voluptuous as vol from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN -from homeassistant.components.recorder import get_instance -from homeassistant.components.recorder.models import StateAttributes, States -from homeassistant.components.recorder.util import execute, session_scope +from homeassistant.components.recorder import get_instance, history from homeassistant.components.sensor import ( PLATFORM_SCHEMA, SensorDeviceClass, @@ -474,37 +472,29 @@ class StatisticsSensor(SensorEntity): def _fetch_states_from_database(self) -> list[State]: """Fetch the states from the database.""" _LOGGER.debug("%s: initializing values from the database", self.entity_id) - states = [] - - with session_scope(hass=self.hass) as session: - query = session.query(States, StateAttributes).filter( - States.entity_id == self._source_entity_id.lower() + lower_entity_id = self._source_entity_id.lower() + if self._samples_max_age is not None: + start_date = ( + dt_util.utcnow() - self._samples_max_age - timedelta(microseconds=1) ) - - if self._samples_max_age is not None: - records_older_then = dt_util.utcnow() - self._samples_max_age - _LOGGER.debug( - "%s: retrieve records not older then %s", - self.entity_id, - records_older_then, - ) - query = query.filter(States.last_updated >= records_older_then) - else: - _LOGGER.debug("%s: retrieving all records", self.entity_id) - - query = query.outerjoin( - StateAttributes, States.attributes_id == StateAttributes.attributes_id + _LOGGER.debug( + "%s: retrieve records not older then %s", + self.entity_id, + start_date, ) - query = query.order_by(States.last_updated.desc()).limit( - self._samples_max_buffer_size - ) - if results := execute(query, to_native=False, validate_entity_ids=False): - for state, attributes in results: - native = state.to_native() - if not native.attributes: - native.attributes = attributes.to_native() - states.append(native) - return states + else: + start_date = datetime.fromtimestamp(0, tz=dt_util.UTC) + _LOGGER.debug("%s: retrieving all records", self.entity_id) + entity_states = history.state_changes_during_period( + self.hass, + start_date, + entity_id=lower_entity_id, + descending=True, + limit=self._samples_max_buffer_size, + include_start_time_state=False, + ) + # Need to cast since minimal responses is not passed in + return cast(list[State], entity_states.get(lower_entity_id, [])) async def _initialize_from_database(self) -> None: """Initialize the list of states from the database.