Convert statistics to use history api for database access (#68411)
This commit is contained in:
parent
4f4f7e40e3
commit
8cf6ac281e
1 changed files with 22 additions and 32 deletions
|
@ -12,9 +12,7 @@ from typing import Any, Literal, cast
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN
|
from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN
|
||||||
from homeassistant.components.recorder import get_instance
|
from homeassistant.components.recorder import get_instance, history
|
||||||
from homeassistant.components.recorder.models import StateAttributes, States
|
|
||||||
from homeassistant.components.recorder.util import execute, session_scope
|
|
||||||
from homeassistant.components.sensor import (
|
from homeassistant.components.sensor import (
|
||||||
PLATFORM_SCHEMA,
|
PLATFORM_SCHEMA,
|
||||||
SensorDeviceClass,
|
SensorDeviceClass,
|
||||||
|
@ -474,37 +472,29 @@ class StatisticsSensor(SensorEntity):
|
||||||
def _fetch_states_from_database(self) -> list[State]:
|
def _fetch_states_from_database(self) -> list[State]:
|
||||||
"""Fetch the states from the database."""
|
"""Fetch the states from the database."""
|
||||||
_LOGGER.debug("%s: initializing values from the database", self.entity_id)
|
_LOGGER.debug("%s: initializing values from the database", self.entity_id)
|
||||||
states = []
|
lower_entity_id = self._source_entity_id.lower()
|
||||||
|
if self._samples_max_age is not None:
|
||||||
with session_scope(hass=self.hass) as session:
|
start_date = (
|
||||||
query = session.query(States, StateAttributes).filter(
|
dt_util.utcnow() - self._samples_max_age - timedelta(microseconds=1)
|
||||||
States.entity_id == self._source_entity_id.lower()
|
|
||||||
)
|
)
|
||||||
|
_LOGGER.debug(
|
||||||
if self._samples_max_age is not None:
|
"%s: retrieve records not older then %s",
|
||||||
records_older_then = dt_util.utcnow() - self._samples_max_age
|
self.entity_id,
|
||||||
_LOGGER.debug(
|
start_date,
|
||||||
"%s: retrieve records not older then %s",
|
|
||||||
self.entity_id,
|
|
||||||
records_older_then,
|
|
||||||
)
|
|
||||||
query = query.filter(States.last_updated >= records_older_then)
|
|
||||||
else:
|
|
||||||
_LOGGER.debug("%s: retrieving all records", self.entity_id)
|
|
||||||
|
|
||||||
query = query.outerjoin(
|
|
||||||
StateAttributes, States.attributes_id == StateAttributes.attributes_id
|
|
||||||
)
|
)
|
||||||
query = query.order_by(States.last_updated.desc()).limit(
|
else:
|
||||||
self._samples_max_buffer_size
|
start_date = datetime.fromtimestamp(0, tz=dt_util.UTC)
|
||||||
)
|
_LOGGER.debug("%s: retrieving all records", self.entity_id)
|
||||||
if results := execute(query, to_native=False, validate_entity_ids=False):
|
entity_states = history.state_changes_during_period(
|
||||||
for state, attributes in results:
|
self.hass,
|
||||||
native = state.to_native()
|
start_date,
|
||||||
if not native.attributes:
|
entity_id=lower_entity_id,
|
||||||
native.attributes = attributes.to_native()
|
descending=True,
|
||||||
states.append(native)
|
limit=self._samples_max_buffer_size,
|
||||||
return states
|
include_start_time_state=False,
|
||||||
|
)
|
||||||
|
# Need to cast since minimal responses is not passed in
|
||||||
|
return cast(list[State], entity_states.get(lower_entity_id, []))
|
||||||
|
|
||||||
async def _initialize_from_database(self) -> None:
|
async def _initialize_from_database(self) -> None:
|
||||||
"""Initialize the list of states from the database.
|
"""Initialize the list of states from the database.
|
||||||
|
|
Loading…
Add table
Reference in a new issue