From 8cf6ac281ee6ad4122fa0836fad297d2da50cc9c Mon Sep 17 00:00:00 2001
From: "J. Nick Koston" <nick@koston.org>
Date: Fri, 1 Apr 2022 05:49:21 -1000
Subject: [PATCH] Convert statistics to use history api for database access
 (#68411)

---
 homeassistant/components/statistics/sensor.py | 54 ++++++++-----------
 1 file changed, 22 insertions(+), 32 deletions(-)

diff --git a/homeassistant/components/statistics/sensor.py b/homeassistant/components/statistics/sensor.py
index d632e9710cf..aaca8a98290 100644
--- a/homeassistant/components/statistics/sensor.py
+++ b/homeassistant/components/statistics/sensor.py
@@ -12,9 +12,7 @@ from typing import Any, Literal, cast
 import voluptuous as vol
 
 from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN
-from homeassistant.components.recorder import get_instance
-from homeassistant.components.recorder.models import StateAttributes, States
-from homeassistant.components.recorder.util import execute, session_scope
+from homeassistant.components.recorder import get_instance, history
 from homeassistant.components.sensor import (
     PLATFORM_SCHEMA,
     SensorDeviceClass,
@@ -474,37 +472,29 @@ class StatisticsSensor(SensorEntity):
     def _fetch_states_from_database(self) -> list[State]:
         """Fetch the states from the database."""
         _LOGGER.debug("%s: initializing values from the database", self.entity_id)
-        states = []
-
-        with session_scope(hass=self.hass) as session:
-            query = session.query(States, StateAttributes).filter(
-                States.entity_id == self._source_entity_id.lower()
+        lower_entity_id = self._source_entity_id.lower()
+        if self._samples_max_age is not None:
+            start_date = (
+                dt_util.utcnow() - self._samples_max_age - timedelta(microseconds=1)
             )
-
-            if self._samples_max_age is not None:
-                records_older_then = dt_util.utcnow() - self._samples_max_age
-                _LOGGER.debug(
-                    "%s: retrieve records not older then %s",
-                    self.entity_id,
-                    records_older_then,
-                )
-                query = query.filter(States.last_updated >= records_older_then)
-            else:
-                _LOGGER.debug("%s: retrieving all records", self.entity_id)
-
-            query = query.outerjoin(
-                StateAttributes, States.attributes_id == StateAttributes.attributes_id
+            _LOGGER.debug(
+                "%s: retrieve records not older then %s",
+                self.entity_id,
+                start_date,
             )
-            query = query.order_by(States.last_updated.desc()).limit(
-                self._samples_max_buffer_size
-            )
-            if results := execute(query, to_native=False, validate_entity_ids=False):
-                for state, attributes in results:
-                    native = state.to_native()
-                    if not native.attributes:
-                        native.attributes = attributes.to_native()
-                    states.append(native)
-        return states
+        else:
+            start_date = datetime.fromtimestamp(0, tz=dt_util.UTC)
+            _LOGGER.debug("%s: retrieving all records", self.entity_id)
+        entity_states = history.state_changes_during_period(
+            self.hass,
+            start_date,
+            entity_id=lower_entity_id,
+            descending=True,
+            limit=self._samples_max_buffer_size,
+            include_start_time_state=False,
+        )
+        # Need to cast since minimal responses is not passed in
+        return cast(list[State], entity_states.get(lower_entity_id, []))
 
     async def _initialize_from_database(self) -> None:
         """Initialize the list of states from the database.