Convert statistics to use history api for database access ()

This commit is contained in:
J. Nick Koston 2022-04-01 05:49:21 -10:00 committed by GitHub
parent 4f4f7e40e3
commit 8cf6ac281e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -12,9 +12,7 @@ from typing import Any, Literal, cast
import voluptuous as vol import voluptuous as vol
from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN
from homeassistant.components.recorder import get_instance from homeassistant.components.recorder import get_instance, history
from homeassistant.components.recorder.models import StateAttributes, States
from homeassistant.components.recorder.util import execute, session_scope
from homeassistant.components.sensor import ( from homeassistant.components.sensor import (
PLATFORM_SCHEMA, PLATFORM_SCHEMA,
SensorDeviceClass, SensorDeviceClass,
@ -474,37 +472,29 @@ class StatisticsSensor(SensorEntity):
def _fetch_states_from_database(self) -> list[State]: def _fetch_states_from_database(self) -> list[State]:
"""Fetch the states from the database.""" """Fetch the states from the database."""
_LOGGER.debug("%s: initializing values from the database", self.entity_id) _LOGGER.debug("%s: initializing values from the database", self.entity_id)
states = [] lower_entity_id = self._source_entity_id.lower()
with session_scope(hass=self.hass) as session:
query = session.query(States, StateAttributes).filter(
States.entity_id == self._source_entity_id.lower()
)
if self._samples_max_age is not None: if self._samples_max_age is not None:
records_older_then = dt_util.utcnow() - self._samples_max_age start_date = (
dt_util.utcnow() - self._samples_max_age - timedelta(microseconds=1)
)
_LOGGER.debug( _LOGGER.debug(
"%s: retrieve records not older then %s", "%s: retrieve records not older then %s",
self.entity_id, self.entity_id,
records_older_then, start_date,
) )
query = query.filter(States.last_updated >= records_older_then)
else: else:
start_date = datetime.fromtimestamp(0, tz=dt_util.UTC)
_LOGGER.debug("%s: retrieving all records", self.entity_id) _LOGGER.debug("%s: retrieving all records", self.entity_id)
entity_states = history.state_changes_during_period(
query = query.outerjoin( self.hass,
StateAttributes, States.attributes_id == StateAttributes.attributes_id start_date,
entity_id=lower_entity_id,
descending=True,
limit=self._samples_max_buffer_size,
include_start_time_state=False,
) )
query = query.order_by(States.last_updated.desc()).limit( # Need to cast since minimal responses is not passed in
self._samples_max_buffer_size return cast(list[State], entity_states.get(lower_entity_id, []))
)
if results := execute(query, to_native=False, validate_entity_ids=False):
for state, attributes in results:
native = state.to_native()
if not native.attributes:
native.attributes = attributes.to_native()
states.append(native)
return states
async def _initialize_from_database(self) -> None: async def _initialize_from_database(self) -> None:
"""Initialize the list of states from the database. """Initialize the list of states from the database.