Use state_reported events in Riemann sum sensor (#113869)

This commit is contained in:
Erik Montnemery 2024-06-26 13:35:01 +02:00 committed by GitHub
parent f0590f08b1
commit a36c40a434
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 113 additions and 96 deletions

View file

@ -8,7 +8,7 @@ from datetime import UTC, datetime, timedelta
from decimal import Decimal, InvalidOperation
from enum import Enum
import logging
from typing import Any, Final, Self
from typing import TYPE_CHECKING, Any, Final, Self
import voluptuous as vol
@ -27,6 +27,8 @@ from homeassistant.const import (
CONF_METHOD,
CONF_NAME,
CONF_UNIQUE_ID,
EVENT_STATE_CHANGED,
EVENT_STATE_REPORTED,
STATE_UNAVAILABLE,
UnitOfTime,
)
@ -34,6 +36,7 @@ from homeassistant.core import (
CALLBACK_TYPE,
Event,
EventStateChangedData,
EventStateReportedData,
HomeAssistant,
State,
callback,
@ -42,7 +45,7 @@ from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.device import async_device_info_to_link_from_entity
from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_call_later, async_track_state_change_event
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .const import (
@ -107,9 +110,7 @@ class _IntegrationMethod(ABC):
return _NAME_TO_INTEGRATION_METHOD[method_name]()
@abstractmethod
def validate_states(
self, left: State, right: State
) -> tuple[Decimal, Decimal] | None:
def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None:
"""Check state requirements for integration."""
@abstractmethod
@ -130,11 +131,9 @@ class _Trapezoidal(_IntegrationMethod):
) -> Decimal:
return elapsed_time * (left + right) / 2
def validate_states(
self, left: State, right: State
) -> tuple[Decimal, Decimal] | None:
if (left_dec := _decimal_state(left.state)) is None or (
right_dec := _decimal_state(right.state)
def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None:
if (left_dec := _decimal_state(left)) is None or (
right_dec := _decimal_state(right)
) is None:
return None
return (left_dec, right_dec)
@ -146,10 +145,8 @@ class _Left(_IntegrationMethod):
) -> Decimal:
return self.calculate_area_with_one_state(elapsed_time, left)
def validate_states(
self, left: State, right: State
) -> tuple[Decimal, Decimal] | None:
if (left_dec := _decimal_state(left.state)) is None:
def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None:
if (left_dec := _decimal_state(left)) is None:
return None
return (left_dec, left_dec)
@ -160,10 +157,8 @@ class _Right(_IntegrationMethod):
) -> Decimal:
return self.calculate_area_with_one_state(elapsed_time, right)
def validate_states(
self, left: State, right: State
) -> tuple[Decimal, Decimal] | None:
if (right_dec := _decimal_state(right.state)) is None:
def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None:
if (right_dec := _decimal_state(right)) is None:
return None
return (right_dec, right_dec)
@ -183,7 +178,7 @@ _NAME_TO_INTEGRATION_METHOD: dict[str, type[_IntegrationMethod]] = {
class _IntegrationTrigger(Enum):
StateChange = "state_change"
StateEvent = "state_event"
TimeElapsed = "time_elapsed"
@ -343,7 +338,7 @@ class IntegrationSensor(RestoreSensor):
)
self._max_sub_interval_exceeded_callback: CALLBACK_TYPE = lambda *args: None
self._last_integration_time: datetime = datetime.now(tz=UTC)
self._last_integration_trigger = _IntegrationTrigger.StateChange
self._last_integration_trigger = _IntegrationTrigger.StateEvent
self._attr_suggested_display_precision = round_digits or 2
def _calculate_unit(self, source_unit: str) -> str:
@ -433,9 +428,11 @@ class IntegrationSensor(RestoreSensor):
source_state = self.hass.states.get(self._sensor_source_id)
self._schedule_max_sub_interval_exceeded_if_state_is_numeric(source_state)
self.async_on_remove(self._cancel_max_sub_interval_exceeded_callback)
handle_state_change = self._integrate_on_state_change_and_max_sub_interval
handle_state_change = self._integrate_on_state_change_with_max_sub_interval
handle_state_report = self._integrate_on_state_report_with_max_sub_interval
else:
handle_state_change = self._integrate_on_state_change_callback
handle_state_report = self._integrate_on_state_report_callback
if (
state := self.hass.states.get(self._source_entity)
@ -443,16 +440,50 @@ class IntegrationSensor(RestoreSensor):
self._derive_and_set_attributes_from_state(state)
self.async_on_remove(
async_track_state_change_event(
self.hass,
[self._sensor_source_id],
self.hass.bus.async_listen(
EVENT_STATE_CHANGED,
handle_state_change,
event_filter=callback(
lambda event_data: event_data["entity_id"] == self._sensor_source_id
),
run_immediately=True,
)
)
self.async_on_remove(
self.hass.bus.async_listen(
EVENT_STATE_REPORTED,
handle_state_report,
event_filter=callback(
lambda event_data: event_data["entity_id"] == self._sensor_source_id
),
run_immediately=True,
)
)
@callback
def _integrate_on_state_change_and_max_sub_interval(
def _integrate_on_state_change_with_max_sub_interval(
self, event: Event[EventStateChangedData]
) -> None:
"""Handle sensor state update when sub interval is configured."""
self._integrate_on_state_update_with_max_sub_interval(
None, event.data["old_state"], event.data["new_state"]
)
@callback
def _integrate_on_state_report_with_max_sub_interval(
self, event: Event[EventStateReportedData]
) -> None:
"""Handle sensor state report when sub interval is configured."""
self._integrate_on_state_update_with_max_sub_interval(
event.data["old_last_reported"], None, event.data["new_state"]
)
@callback
def _integrate_on_state_update_with_max_sub_interval(
self,
old_last_reported: datetime | None,
old_state: State | None,
new_state: State | None,
) -> None:
"""Integrate based on state change and time.
@ -460,11 +491,9 @@ class IntegrationSensor(RestoreSensor):
reschedules time based integration.
"""
self._cancel_max_sub_interval_exceeded_callback()
old_state = event.data["old_state"]
new_state = event.data["new_state"]
try:
self._integrate_on_state_change(old_state, new_state)
self._last_integration_trigger = _IntegrationTrigger.StateChange
self._integrate_on_state_change(old_last_reported, old_state, new_state)
self._last_integration_trigger = _IntegrationTrigger.StateEvent
self._last_integration_time = datetime.now(tz=UTC)
finally:
# When max_sub_interval exceeds without state change the source is assumed
@ -475,13 +504,25 @@ class IntegrationSensor(RestoreSensor):
def _integrate_on_state_change_callback(
self, event: Event[EventStateChangedData]
) -> None:
"""Handle the sensor state changes."""
old_state = event.data["old_state"]
new_state = event.data["new_state"]
return self._integrate_on_state_change(old_state, new_state)
"""Handle sensor state change."""
return self._integrate_on_state_change(
None, event.data["old_state"], event.data["new_state"]
)
@callback
def _integrate_on_state_report_callback(
self, event: Event[EventStateReportedData]
) -> None:
"""Handle sensor state report."""
return self._integrate_on_state_change(
event.data["old_last_reported"], None, event.data["new_state"]
)
def _integrate_on_state_change(
self, old_state: State | None, new_state: State | None
self,
old_last_reported: datetime | None,
old_state: State | None,
new_state: State | None,
) -> None:
if new_state is None:
return
@ -491,21 +532,33 @@ class IntegrationSensor(RestoreSensor):
self.async_write_ha_state()
return
if old_state:
# state has changed, we recover old_state from the event
old_state_state = old_state.state
old_last_reported = old_state.last_reported
else:
# event state reported without any state change
old_state_state = new_state.state
self._attr_available = True
self._derive_and_set_attributes_from_state(new_state)
if old_state is None:
if old_last_reported is None and old_state is None:
self.async_write_ha_state()
return
if not (states := self._method.validate_states(old_state, new_state)):
if not (
states := self._method.validate_states(old_state_state, new_state.state)
):
self.async_write_ha_state()
return
if TYPE_CHECKING:
assert old_last_reported is not None
elapsed_seconds = Decimal(
(new_state.last_updated - old_state.last_updated).total_seconds()
if self._last_integration_trigger == _IntegrationTrigger.StateChange
else (new_state.last_updated - self._last_integration_time).total_seconds()
(new_state.last_reported - old_last_reported).total_seconds()
if self._last_integration_trigger == _IntegrationTrigger.StateEvent
else (new_state.last_reported - self._last_integration_time).total_seconds()
)
area = self._method.calculate_area_with_two_states(elapsed_seconds, *states)