Compare commits

...
Sign in to create a new pull request.

5 commits

Author SHA1 Message Date
Erik
903054db14 Refactor event handlers 2024-06-26 13:37:11 +02:00
Erik
423f4c5bcf Refactor event handlers 2024-06-26 13:25:52 +02:00
Erik
9cb8b5be54 Subscribe also to EVENT_STATE_CHANGED 2024-06-26 11:48:15 +02:00
Erik
2058777db6 Update tests 2024-06-26 11:27:02 +02:00
Erik
0da50c3f14 Modify integration sensor 2024-06-26 11:26:59 +02:00
2 changed files with 106 additions and 105 deletions

View file

@ -3,12 +3,14 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from decimal import Decimal, InvalidOperation
from enum import Enum
from functools import partial
import logging
from typing import Any, Final, Self
from typing import TYPE_CHECKING, Any, Final, Self
import voluptuous as vol
@ -27,6 +29,8 @@ from homeassistant.const import (
CONF_METHOD,
CONF_NAME,
CONF_UNIQUE_ID,
EVENT_STATE_CHANGED,
EVENT_STATE_REPORTED,
STATE_UNAVAILABLE,
UnitOfTime,
)
@ -34,6 +38,7 @@ from homeassistant.core import (
CALLBACK_TYPE,
Event,
EventStateChangedData,
EventStateReportedData,
HomeAssistant,
State,
callback,
@ -42,7 +47,7 @@ from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.device import async_device_info_to_link_from_entity
from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_call_later, async_track_state_change_event
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .const import (
@ -100,6 +105,8 @@ PLATFORM_SCHEMA = vol.All(
),
)
type StateUpdateFunc = Callable[[datetime | None, State | None, State | None], None]
class _IntegrationMethod(ABC):
@staticmethod
@ -107,9 +114,7 @@ class _IntegrationMethod(ABC):
return _NAME_TO_INTEGRATION_METHOD[method_name]()
@abstractmethod
def validate_states(
self, left: State, right: State
) -> tuple[Decimal, Decimal] | None:
def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None:
"""Check state requirements for integration."""
@abstractmethod
@ -130,11 +135,9 @@ class _Trapezoidal(_IntegrationMethod):
) -> Decimal:
return elapsed_time * (left + right) / 2
def validate_states(
self, left: State, right: State
) -> tuple[Decimal, Decimal] | None:
if (left_dec := _decimal_state(left.state)) is None or (
right_dec := _decimal_state(right.state)
def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None:
if (left_dec := _decimal_state(left)) is None or (
right_dec := _decimal_state(right)
) is None:
return None
return (left_dec, right_dec)
@ -146,10 +149,8 @@ class _Left(_IntegrationMethod):
) -> Decimal:
return self.calculate_area_with_one_state(elapsed_time, left)
def validate_states(
self, left: State, right: State
) -> tuple[Decimal, Decimal] | None:
if (left_dec := _decimal_state(left.state)) is None:
def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None:
if (left_dec := _decimal_state(left)) is None:
return None
return (left_dec, left_dec)
@ -160,10 +161,8 @@ class _Right(_IntegrationMethod):
) -> Decimal:
return self.calculate_area_with_one_state(elapsed_time, right)
def validate_states(
self, left: State, right: State
) -> tuple[Decimal, Decimal] | None:
if (right_dec := _decimal_state(right.state)) is None:
def validate_states(self, left: str, right: str) -> tuple[Decimal, Decimal] | None:
if (right_dec := _decimal_state(right)) is None:
return None
return (right_dec, right_dec)
@ -183,7 +182,7 @@ _NAME_TO_INTEGRATION_METHOD: dict[str, type[_IntegrationMethod]] = {
class _IntegrationTrigger(Enum):
StateChange = "state_change"
StateEvent = "state_event"
TimeElapsed = "time_elapsed"
@ -343,7 +342,7 @@ class IntegrationSensor(RestoreSensor):
)
self._max_sub_interval_exceeded_callback: CALLBACK_TYPE = lambda *args: None
self._last_integration_time: datetime = datetime.now(tz=UTC)
self._last_integration_trigger = _IntegrationTrigger.StateChange
self._last_integration_trigger = _IntegrationTrigger.StateEvent
self._attr_suggested_display_precision = round_digits or 2
def _calculate_unit(self, source_unit: str) -> str:
@ -433,9 +432,9 @@ class IntegrationSensor(RestoreSensor):
source_state = self.hass.states.get(self._sensor_source_id)
self._schedule_max_sub_interval_exceeded_if_state_is_numeric(source_state)
self.async_on_remove(self._cancel_max_sub_interval_exceeded_callback)
handle_state_change = self._integrate_on_state_change_and_max_sub_interval
handle_state_update = self._integrate_on_state_update_with_max_sub_interval
else:
handle_state_change = self._integrate_on_state_change_callback
handle_state_update = self._integrate_on_state_update
if (
state := self.hass.states.get(self._source_entity)
@ -443,16 +442,50 @@ class IntegrationSensor(RestoreSensor):
self._derive_and_set_attributes_from_state(state)
self.async_on_remove(
async_track_state_change_event(
self.hass,
[self._sensor_source_id],
handle_state_change,
self.hass.bus.async_listen(
EVENT_STATE_CHANGED,
partial(self._integrate_on_state_change_callback, handle_state_update),
event_filter=callback(
lambda event_data: event_data["entity_id"] == self._sensor_source_id
),
run_immediately=True,
)
)
self.async_on_remove(
self.hass.bus.async_listen(
EVENT_STATE_REPORTED,
partial(self._integrate_on_state_report_callback, handle_state_update),
event_filter=callback(
lambda event_data: event_data["entity_id"] == self._sensor_source_id
),
run_immediately=True,
)
)
@callback
def _integrate_on_state_change_and_max_sub_interval(
self, event: Event[EventStateChangedData]
def _integrate_on_state_change_callback(
self, handle_state_update: StateUpdateFunc, event: Event[EventStateChangedData]
) -> None:
"""Handle sensor state change."""
return handle_state_update(
None, event.data["old_state"], event.data["new_state"]
)
@callback
def _integrate_on_state_report_callback(
self, handle_state_update: StateUpdateFunc, event: Event[EventStateReportedData]
) -> None:
"""Handle sensor state report."""
return handle_state_update(
event.data["old_last_reported"], None, event.data["new_state"]
)
@callback
def _integrate_on_state_update_with_max_sub_interval(
self,
old_last_reported: datetime | None,
old_state: State | None,
new_state: State | None,
) -> None:
"""Integrate based on state change and time.
@ -460,28 +493,20 @@ class IntegrationSensor(RestoreSensor):
reschedules time based integration.
"""
self._cancel_max_sub_interval_exceeded_callback()
old_state = event.data["old_state"]
new_state = event.data["new_state"]
try:
self._integrate_on_state_change(old_state, new_state)
self._last_integration_trigger = _IntegrationTrigger.StateChange
self._integrate_on_state_update(old_last_reported, old_state, new_state)
self._last_integration_trigger = _IntegrationTrigger.StateEvent
self._last_integration_time = datetime.now(tz=UTC)
finally:
# When max_sub_interval exceeds without state change the source is assumed
# constant with the last known state (new_state).
self._schedule_max_sub_interval_exceeded_if_state_is_numeric(new_state)
@callback
def _integrate_on_state_change_callback(
self, event: Event[EventStateChangedData]
) -> None:
"""Handle the sensor state changes."""
old_state = event.data["old_state"]
new_state = event.data["new_state"]
return self._integrate_on_state_change(old_state, new_state)
def _integrate_on_state_change(
self, old_state: State | None, new_state: State | None
def _integrate_on_state_update(
self,
old_last_reported: datetime | None,
old_state: State | None,
new_state: State | None,
) -> None:
if new_state is None:
return
@ -491,21 +516,33 @@ class IntegrationSensor(RestoreSensor):
self.async_write_ha_state()
return
if old_state:
# state has changed, we recover old_state from the event
old_state_state = old_state.state
old_last_reported = old_state.last_reported
else:
# event state reported without any state change
old_state_state = new_state.state
self._attr_available = True
self._derive_and_set_attributes_from_state(new_state)
if old_state is None:
if old_last_reported is None and old_state is None:
self.async_write_ha_state()
return
if not (states := self._method.validate_states(old_state, new_state)):
if not (
states := self._method.validate_states(old_state_state, new_state.state)
):
self.async_write_ha_state()
return
if TYPE_CHECKING:
assert old_last_reported is not None
elapsed_seconds = Decimal(
(new_state.last_updated - old_state.last_updated).total_seconds()
if self._last_integration_trigger == _IntegrationTrigger.StateChange
else (new_state.last_updated - self._last_integration_time).total_seconds()
(new_state.last_reported - old_last_reported).total_seconds()
if self._last_integration_trigger == _IntegrationTrigger.StateEvent
else (new_state.last_reported - self._last_integration_time).total_seconds()
)
area = self._method.calculate_area_with_two_states(elapsed_seconds, *states)

View file

@ -294,28 +294,16 @@ async def test_restore_state_failed(hass: HomeAssistant, extra_attributes) -> No
assert state.state == STATE_UNKNOWN
@pytest.mark.parametrize("force_update", [False, True])
@pytest.mark.parametrize(
("force_update", "sequence"),
"sequence",
[
(
False,
(
(20, 10, 1.67),
(30, 30, 5.0),
(40, 5, 7.92),
(50, 5, 7.92),
(60, 0, 8.75),
),
),
(
True,
(
(20, 10, 1.67),
(30, 30, 5.0),
(40, 5, 7.92),
(50, 5, 8.75),
(60, 0, 9.17),
),
(20, 10, 1.67),
(30, 30, 5.0),
(40, 5, 7.92),
(50, 5, 8.75),
(60, 0, 9.17),
),
],
)
@ -358,28 +346,16 @@ async def test_trapezoidal(
assert state.attributes.get("unit_of_measurement") == UnitOfEnergy.KILO_WATT_HOUR
@pytest.mark.parametrize("force_update", [False, True])
@pytest.mark.parametrize(
("force_update", "sequence"),
"sequence",
[
(
False,
(
(20, 10, 0.0),
(30, 30, 1.67),
(40, 5, 6.67),
(50, 5, 6.67),
(60, 0, 8.33),
),
),
(
True,
(
(20, 10, 0.0),
(30, 30, 1.67),
(40, 5, 6.67),
(50, 5, 7.5),
(60, 0, 8.33),
),
(20, 10, 0.0),
(30, 30, 1.67),
(40, 5, 6.67),
(50, 5, 7.5),
(60, 0, 8.33),
),
],
)
@ -425,28 +401,16 @@ async def test_left(
assert state.attributes.get("unit_of_measurement") == UnitOfEnergy.KILO_WATT_HOUR
@pytest.mark.parametrize("force_update", [False, True])
@pytest.mark.parametrize(
("force_update", "sequence"),
"sequence",
[
(
False,
(
(20, 10, 3.33),
(30, 30, 8.33),
(40, 5, 9.17),
(50, 5, 9.17),
(60, 0, 9.17),
),
),
(
True,
(
(20, 10, 3.33),
(30, 30, 8.33),
(40, 5, 9.17),
(50, 5, 10.0),
(60, 0, 10.0),
),
(20, 10, 3.33),
(30, 30, 8.33),
(40, 5, 9.17),
(50, 5, 10.0),
(60, 0, 10.0),
),
],
)