Convert statistics to use history api for database access (#68411)
This commit is contained in:
parent
4f4f7e40e3
commit
8cf6ac281e
1 changed files with 22 additions and 32 deletions
|
@ -12,9 +12,7 @@ from typing import Any, Literal, cast
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN
|
||||
from homeassistant.components.recorder import get_instance
|
||||
from homeassistant.components.recorder.models import StateAttributes, States
|
||||
from homeassistant.components.recorder.util import execute, session_scope
|
||||
from homeassistant.components.recorder import get_instance, history
|
||||
from homeassistant.components.sensor import (
|
||||
PLATFORM_SCHEMA,
|
||||
SensorDeviceClass,
|
||||
|
@ -474,37 +472,29 @@ class StatisticsSensor(SensorEntity):
|
|||
def _fetch_states_from_database(self) -> list[State]:
|
||||
"""Fetch the states from the database."""
|
||||
_LOGGER.debug("%s: initializing values from the database", self.entity_id)
|
||||
states = []
|
||||
|
||||
with session_scope(hass=self.hass) as session:
|
||||
query = session.query(States, StateAttributes).filter(
|
||||
States.entity_id == self._source_entity_id.lower()
|
||||
lower_entity_id = self._source_entity_id.lower()
|
||||
if self._samples_max_age is not None:
|
||||
start_date = (
|
||||
dt_util.utcnow() - self._samples_max_age - timedelta(microseconds=1)
|
||||
)
|
||||
|
||||
if self._samples_max_age is not None:
|
||||
records_older_then = dt_util.utcnow() - self._samples_max_age
|
||||
_LOGGER.debug(
|
||||
"%s: retrieve records not older then %s",
|
||||
self.entity_id,
|
||||
records_older_then,
|
||||
)
|
||||
query = query.filter(States.last_updated >= records_older_then)
|
||||
else:
|
||||
_LOGGER.debug("%s: retrieving all records", self.entity_id)
|
||||
|
||||
query = query.outerjoin(
|
||||
StateAttributes, States.attributes_id == StateAttributes.attributes_id
|
||||
_LOGGER.debug(
|
||||
"%s: retrieve records not older then %s",
|
||||
self.entity_id,
|
||||
start_date,
|
||||
)
|
||||
query = query.order_by(States.last_updated.desc()).limit(
|
||||
self._samples_max_buffer_size
|
||||
)
|
||||
if results := execute(query, to_native=False, validate_entity_ids=False):
|
||||
for state, attributes in results:
|
||||
native = state.to_native()
|
||||
if not native.attributes:
|
||||
native.attributes = attributes.to_native()
|
||||
states.append(native)
|
||||
return states
|
||||
else:
|
||||
start_date = datetime.fromtimestamp(0, tz=dt_util.UTC)
|
||||
_LOGGER.debug("%s: retrieving all records", self.entity_id)
|
||||
entity_states = history.state_changes_during_period(
|
||||
self.hass,
|
||||
start_date,
|
||||
entity_id=lower_entity_id,
|
||||
descending=True,
|
||||
limit=self._samples_max_buffer_size,
|
||||
include_start_time_state=False,
|
||||
)
|
||||
# Need to cast since minimal responses is not passed in
|
||||
return cast(list[State], entity_states.get(lower_entity_id, []))
|
||||
|
||||
async def _initialize_from_database(self) -> None:
|
||||
"""Initialize the list of states from the database.
|
||||
|
|
Loading…
Add table
Reference in a new issue