Convert statistics to use history api for database access (#68411)

This commit is contained in:
J. Nick Koston 2022-04-01 05:49:21 -10:00 committed by GitHub
parent 4f4f7e40e3
commit 8cf6ac281e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -12,9 +12,7 @@ from typing import Any, Literal, cast
import voluptuous as vol
from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN
from homeassistant.components.recorder import get_instance
from homeassistant.components.recorder.models import StateAttributes, States
from homeassistant.components.recorder.util import execute, session_scope
from homeassistant.components.recorder import get_instance, history
from homeassistant.components.sensor import (
PLATFORM_SCHEMA,
SensorDeviceClass,
@ -474,37 +472,29 @@ class StatisticsSensor(SensorEntity):
def _fetch_states_from_database(self) -> list[State]:
"""Fetch the states from the database."""
_LOGGER.debug("%s: initializing values from the database", self.entity_id)
states = []
with session_scope(hass=self.hass) as session:
query = session.query(States, StateAttributes).filter(
States.entity_id == self._source_entity_id.lower()
lower_entity_id = self._source_entity_id.lower()
if self._samples_max_age is not None:
start_date = (
dt_util.utcnow() - self._samples_max_age - timedelta(microseconds=1)
)
if self._samples_max_age is not None:
records_older_then = dt_util.utcnow() - self._samples_max_age
_LOGGER.debug(
"%s: retrieve records not older then %s",
self.entity_id,
records_older_then,
)
query = query.filter(States.last_updated >= records_older_then)
else:
_LOGGER.debug("%s: retrieving all records", self.entity_id)
query = query.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id
_LOGGER.debug(
"%s: retrieve records not older then %s",
self.entity_id,
start_date,
)
query = query.order_by(States.last_updated.desc()).limit(
self._samples_max_buffer_size
)
if results := execute(query, to_native=False, validate_entity_ids=False):
for state, attributes in results:
native = state.to_native()
if not native.attributes:
native.attributes = attributes.to_native()
states.append(native)
return states
else:
start_date = datetime.fromtimestamp(0, tz=dt_util.UTC)
_LOGGER.debug("%s: retrieving all records", self.entity_id)
entity_states = history.state_changes_during_period(
self.hass,
start_date,
entity_id=lower_entity_id,
descending=True,
limit=self._samples_max_buffer_size,
include_start_time_state=False,
)
# Need to cast since minimal responses is not passed in
return cast(list[State], entity_states.get(lower_entity_id, []))
async def _initialize_from_database(self) -> None:
"""Initialize the list of states from the database.