Use state_reported events in Riemann sum sensor (#113869)

This commit is contained in:
Erik Montnemery 2024-06-26 13:35:01 +02:00 committed by GitHub
parent f0590f08b1
commit a36c40a434
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 113 additions and 96 deletions

View file

@ -8,7 +8,7 @@ from datetime import UTC, datetime, timedelta
from decimal import Decimal, InvalidOperation from decimal import Decimal, InvalidOperation
from enum import Enum from enum import Enum
import logging import logging
from typing import Any, Final, Self from typing import TYPE_CHECKING, Any, Final, Self
import voluptuous as vol import voluptuous as vol
@ -27,6 +27,8 @@ from homeassistant.const import (
CONF_METHOD, CONF_METHOD,
CONF_NAME, CONF_NAME,
CONF_UNIQUE_ID, CONF_UNIQUE_ID,
EVENT_STATE_CHANGED,
EVENT_STATE_REPORTED,
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
UnitOfTime, UnitOfTime,
) )
@ -34,6 +36,7 @@ from homeassistant.core import (
CALLBACK_TYPE, CALLBACK_TYPE,
Event, Event,
EventStateChangedData, EventStateChangedData,
EventStateReportedData,
HomeAssistant, HomeAssistant,
State, State,
callback, callback,
@ -42,7 +45,7 @@ from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.device import async_device_info_to_link_from_entity from homeassistant.helpers.device import async_device_info_to_link_from_entity
from homeassistant.helpers.device_registry import DeviceInfo from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_call_later, async_track_state_change_event from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .const import ( from .const import (
@ -107,9 +110,7 @@ class _IntegrationMethod(ABC):
return _NAME_TO_INTEGRATION_METHOD[method_name]() return _NAME_TO_INTEGRATION_METHOD[method_name]()
@abstractmethod @abstractmethod
def validate_states( def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None:
self, left: State, right: State
) -> tuple[Decimal, Decimal] | None:
"""Check state requirements for integration.""" """Check state requirements for integration."""
@abstractmethod @abstractmethod
@ -130,11 +131,9 @@ class _Trapezoidal(_IntegrationMethod):
) -> Decimal: ) -> Decimal:
return elapsed_time * (left + right) / 2 return elapsed_time * (left + right) / 2
def validate_states( def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None:
self, left: State, right: State if (left_dec := _decimal_state(left)) is None or (
) -> tuple[Decimal, Decimal] | None: right_dec := _decimal_state(right)
if (left_dec := _decimal_state(left.state)) is None or (
right_dec := _decimal_state(right.state)
) is None: ) is None:
return None return None
return (left_dec, right_dec) return (left_dec, right_dec)
@ -146,10 +145,8 @@ class _Left(_IntegrationMethod):
) -> Decimal: ) -> Decimal:
return self.calculate_area_with_one_state(elapsed_time, left) return self.calculate_area_with_one_state(elapsed_time, left)
def validate_states( def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None:
self, left: State, right: State if (left_dec := _decimal_state(left)) is None:
) -> tuple[Decimal, Decimal] | None:
if (left_dec := _decimal_state(left.state)) is None:
return None return None
return (left_dec, left_dec) return (left_dec, left_dec)
@ -160,10 +157,8 @@ class _Right(_IntegrationMethod):
) -> Decimal: ) -> Decimal:
return self.calculate_area_with_one_state(elapsed_time, right) return self.calculate_area_with_one_state(elapsed_time, right)
def validate_states( def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None:
self, left: State, right: State if (right_dec := _decimal_state(right)) is None:
) -> tuple[Decimal, Decimal] | None:
if (right_dec := _decimal_state(right.state)) is None:
return None return None
return (right_dec, right_dec) return (right_dec, right_dec)
@ -183,7 +178,7 @@ _NAME_TO_INTEGRATION_METHOD: dict[str, type[_IntegrationMethod]] = {
class _IntegrationTrigger(Enum): class _IntegrationTrigger(Enum):
StateChange = "state_change" StateEvent = "state_event"
TimeElapsed = "time_elapsed" TimeElapsed = "time_elapsed"
@ -343,7 +338,7 @@ class IntegrationSensor(RestoreSensor):
) )
self._max_sub_interval_exceeded_callback: CALLBACK_TYPE = lambda *args: None self._max_sub_interval_exceeded_callback: CALLBACK_TYPE = lambda *args: None
self._last_integration_time: datetime = datetime.now(tz=UTC) self._last_integration_time: datetime = datetime.now(tz=UTC)
self._last_integration_trigger = _IntegrationTrigger.StateChange self._last_integration_trigger = _IntegrationTrigger.StateEvent
self._attr_suggested_display_precision = round_digits or 2 self._attr_suggested_display_precision = round_digits or 2
def _calculate_unit(self, source_unit: str) -> str: def _calculate_unit(self, source_unit: str) -> str:
@ -433,9 +428,11 @@ class IntegrationSensor(RestoreSensor):
source_state = self.hass.states.get(self._sensor_source_id) source_state = self.hass.states.get(self._sensor_source_id)
self._schedule_max_sub_interval_exceeded_if_state_is_numeric(source_state) self._schedule_max_sub_interval_exceeded_if_state_is_numeric(source_state)
self.async_on_remove(self._cancel_max_sub_interval_exceeded_callback) self.async_on_remove(self._cancel_max_sub_interval_exceeded_callback)
handle_state_change = self._integrate_on_state_change_and_max_sub_interval handle_state_change = self._integrate_on_state_change_with_max_sub_interval
handle_state_report = self._integrate_on_state_report_with_max_sub_interval
else: else:
handle_state_change = self._integrate_on_state_change_callback handle_state_change = self._integrate_on_state_change_callback
handle_state_report = self._integrate_on_state_report_callback
if ( if (
state := self.hass.states.get(self._source_entity) state := self.hass.states.get(self._source_entity)
@ -443,16 +440,50 @@ class IntegrationSensor(RestoreSensor):
self._derive_and_set_attributes_from_state(state) self._derive_and_set_attributes_from_state(state)
self.async_on_remove( self.async_on_remove(
async_track_state_change_event( self.hass.bus.async_listen(
self.hass, EVENT_STATE_CHANGED,
[self._sensor_source_id],
handle_state_change, handle_state_change,
event_filter=callback(
lambda event_data: event_data["entity_id"] == self._sensor_source_id
),
run_immediately=True,
)
)
self.async_on_remove(
self.hass.bus.async_listen(
EVENT_STATE_REPORTED,
handle_state_report,
event_filter=callback(
lambda event_data: event_data["entity_id"] == self._sensor_source_id
),
run_immediately=True,
) )
) )
@callback @callback
def _integrate_on_state_change_and_max_sub_interval( def _integrate_on_state_change_with_max_sub_interval(
self, event: Event[EventStateChangedData] self, event: Event[EventStateChangedData]
) -> None:
"""Handle sensor state update when sub interval is configured."""
self._integrate_on_state_update_with_max_sub_interval(
None, event.data["old_state"], event.data["new_state"]
)
@callback
def _integrate_on_state_report_with_max_sub_interval(
self, event: Event[EventStateReportedData]
) -> None:
"""Handle sensor state report when sub interval is configured."""
self._integrate_on_state_update_with_max_sub_interval(
event.data["old_last_reported"], None, event.data["new_state"]
)
@callback
def _integrate_on_state_update_with_max_sub_interval(
self,
old_last_reported: datetime | None,
old_state: State | None,
new_state: State | None,
) -> None: ) -> None:
"""Integrate based on state change and time. """Integrate based on state change and time.
@ -460,11 +491,9 @@ class IntegrationSensor(RestoreSensor):
reschedules time based integration. reschedules time based integration.
""" """
self._cancel_max_sub_interval_exceeded_callback() self._cancel_max_sub_interval_exceeded_callback()
old_state = event.data["old_state"]
new_state = event.data["new_state"]
try: try:
self._integrate_on_state_change(old_state, new_state) self._integrate_on_state_change(old_last_reported, old_state, new_state)
self._last_integration_trigger = _IntegrationTrigger.StateChange self._last_integration_trigger = _IntegrationTrigger.StateEvent
self._last_integration_time = datetime.now(tz=UTC) self._last_integration_time = datetime.now(tz=UTC)
finally: finally:
# When max_sub_interval exceeds without state change the source is assumed # When max_sub_interval exceeds without state change the source is assumed
@ -475,13 +504,25 @@ class IntegrationSensor(RestoreSensor):
def _integrate_on_state_change_callback( def _integrate_on_state_change_callback(
self, event: Event[EventStateChangedData] self, event: Event[EventStateChangedData]
) -> None: ) -> None:
"""Handle the sensor state changes.""" """Handle sensor state change."""
old_state = event.data["old_state"] return self._integrate_on_state_change(
new_state = event.data["new_state"] None, event.data["old_state"], event.data["new_state"]
return self._integrate_on_state_change(old_state, new_state) )
@callback
def _integrate_on_state_report_callback(
self, event: Event[EventStateReportedData]
) -> None:
"""Handle sensor state report."""
return self._integrate_on_state_change(
event.data["old_last_reported"], None, event.data["new_state"]
)
def _integrate_on_state_change( def _integrate_on_state_change(
self, old_state: State | None, new_state: State | None self,
old_last_reported: datetime | None,
old_state: State | None,
new_state: State | None,
) -> None: ) -> None:
if new_state is None: if new_state is None:
return return
@ -491,21 +532,33 @@ class IntegrationSensor(RestoreSensor):
self.async_write_ha_state() self.async_write_ha_state()
return return
if old_state:
# state has changed, we recover old_state from the event
old_state_state = old_state.state
old_last_reported = old_state.last_reported
else:
# event state reported without any state change
old_state_state = new_state.state
self._attr_available = True self._attr_available = True
self._derive_and_set_attributes_from_state(new_state) self._derive_and_set_attributes_from_state(new_state)
if old_state is None: if old_last_reported is None and old_state is None:
self.async_write_ha_state() self.async_write_ha_state()
return return
if not (states := self._method.validate_states(old_state, new_state)): if not (
states := self._method.validate_states(old_state_state, new_state.state)
):
self.async_write_ha_state() self.async_write_ha_state()
return return
if TYPE_CHECKING:
assert old_last_reported is not None
elapsed_seconds = Decimal( elapsed_seconds = Decimal(
(new_state.last_updated - old_state.last_updated).total_seconds() (new_state.last_reported - old_last_reported).total_seconds()
if self._last_integration_trigger == _IntegrationTrigger.StateChange if self._last_integration_trigger == _IntegrationTrigger.StateEvent
else (new_state.last_updated - self._last_integration_time).total_seconds() else (new_state.last_reported - self._last_integration_time).total_seconds()
) )
area = self._method.calculate_area_with_two_states(elapsed_seconds, *states) area = self._method.calculate_area_with_two_states(elapsed_seconds, *states)

View file

@ -294,28 +294,16 @@ async def test_restore_state_failed(hass: HomeAssistant, extra_attributes) -> No
assert state.state == STATE_UNKNOWN assert state.state == STATE_UNKNOWN
@pytest.mark.parametrize("force_update", [False, True])
@pytest.mark.parametrize( @pytest.mark.parametrize(
("force_update", "sequence"), "sequence",
[ [
( (
False, (20, 10, 1.67),
( (30, 30, 5.0),
(20, 10, 1.67), (40, 5, 7.92),
(30, 30, 5.0), (50, 5, 8.75),
(40, 5, 7.92), (60, 0, 9.17),
(50, 5, 7.92),
(60, 0, 8.75),
),
),
(
True,
(
(20, 10, 1.67),
(30, 30, 5.0),
(40, 5, 7.92),
(50, 5, 8.75),
(60, 0, 9.17),
),
), ),
], ],
) )
@ -358,28 +346,16 @@ async def test_trapezoidal(
assert state.attributes.get("unit_of_measurement") == UnitOfEnergy.KILO_WATT_HOUR assert state.attributes.get("unit_of_measurement") == UnitOfEnergy.KILO_WATT_HOUR
@pytest.mark.parametrize("force_update", [False, True])
@pytest.mark.parametrize( @pytest.mark.parametrize(
("force_update", "sequence"), "sequence",
[ [
( (
False, (20, 10, 0.0),
( (30, 30, 1.67),
(20, 10, 0.0), (40, 5, 6.67),
(30, 30, 1.67), (50, 5, 7.5),
(40, 5, 6.67), (60, 0, 8.33),
(50, 5, 6.67),
(60, 0, 8.33),
),
),
(
True,
(
(20, 10, 0.0),
(30, 30, 1.67),
(40, 5, 6.67),
(50, 5, 7.5),
(60, 0, 8.33),
),
), ),
], ],
) )
@ -425,28 +401,16 @@ async def test_left(
assert state.attributes.get("unit_of_measurement") == UnitOfEnergy.KILO_WATT_HOUR assert state.attributes.get("unit_of_measurement") == UnitOfEnergy.KILO_WATT_HOUR
@pytest.mark.parametrize("force_update", [False, True])
@pytest.mark.parametrize( @pytest.mark.parametrize(
("force_update", "sequence"), "sequence",
[ [
( (
False, (20, 10, 3.33),
( (30, 30, 8.33),
(20, 10, 3.33), (40, 5, 9.17),
(30, 30, 8.33), (50, 5, 10.0),
(40, 5, 9.17), (60, 0, 10.0),
(50, 5, 9.17),
(60, 0, 9.17),
),
),
(
True,
(
(20, 10, 3.33),
(30, 30, 8.33),
(40, 5, 9.17),
(50, 5, 10.0),
(60, 0, 10.0),
),
), ),
], ],
) )