Convert statistics to use history api for database access (#68411)

This commit is contained in:
J. Nick Koston 2022-04-01 05:49:21 -10:00 committed by GitHub
parent 4f4f7e40e3
commit 8cf6ac281e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -12,9 +12,7 @@ from typing import Any, Literal, cast
import voluptuous as vol import voluptuous as vol
from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN
from homeassistant.components.recorder import get_instance from homeassistant.components.recorder import get_instance, history
from homeassistant.components.recorder.models import StateAttributes, States
from homeassistant.components.recorder.util import execute, session_scope
from homeassistant.components.sensor import ( from homeassistant.components.sensor import (
PLATFORM_SCHEMA, PLATFORM_SCHEMA,
SensorDeviceClass, SensorDeviceClass,
@ -474,37 +472,29 @@ class StatisticsSensor(SensorEntity):
def _fetch_states_from_database(self) -> list[State]: def _fetch_states_from_database(self) -> list[State]:
"""Fetch the states from the database.""" """Fetch the states from the database."""
_LOGGER.debug("%s: initializing values from the database", self.entity_id) _LOGGER.debug("%s: initializing values from the database", self.entity_id)
states = [] lower_entity_id = self._source_entity_id.lower()
if self._samples_max_age is not None:
with session_scope(hass=self.hass) as session: start_date = (
query = session.query(States, StateAttributes).filter( dt_util.utcnow() - self._samples_max_age - timedelta(microseconds=1)
States.entity_id == self._source_entity_id.lower()
) )
_LOGGER.debug(
if self._samples_max_age is not None: "%s: retrieve records not older then %s",
records_older_then = dt_util.utcnow() - self._samples_max_age self.entity_id,
_LOGGER.debug( start_date,
"%s: retrieve records not older then %s",
self.entity_id,
records_older_then,
)
query = query.filter(States.last_updated >= records_older_then)
else:
_LOGGER.debug("%s: retrieving all records", self.entity_id)
query = query.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id
) )
query = query.order_by(States.last_updated.desc()).limit( else:
self._samples_max_buffer_size start_date = datetime.fromtimestamp(0, tz=dt_util.UTC)
) _LOGGER.debug("%s: retrieving all records", self.entity_id)
if results := execute(query, to_native=False, validate_entity_ids=False): entity_states = history.state_changes_during_period(
for state, attributes in results: self.hass,
native = state.to_native() start_date,
if not native.attributes: entity_id=lower_entity_id,
native.attributes = attributes.to_native() descending=True,
states.append(native) limit=self._samples_max_buffer_size,
return states include_start_time_state=False,
)
# Need to cast since minimal responses is not passed in
return cast(list[State], entity_states.get(lower_entity_id, []))
async def _initialize_from_database(self) -> None: async def _initialize_from_database(self) -> None:
"""Initialize the list of states from the database. """Initialize the list of states from the database.