From e1036b3af0169113cb649b6aee4d8511dc7b2172 Mon Sep 17 00:00:00 2001 From: Ron Weikamp <15732230+ronweikamp@users.noreply.github.com> Date: Tue, 26 Mar 2024 19:09:48 +0100 Subject: [PATCH] Refactor Riemann sum integral sensor to prepare for time based trigger (#113932) * Refactor Integration sensor. * Use local simple function to verify the State is numeric. * Merge two methods to one. * Method renaming: _handle_state_change * Move async_write_ha_state to the caller. * Add comment on why attr_icon is set to None * Remove possible None type of State in validation methods. * Use a dict to map method name to method class. * Explain derived unit after integration. * Renaming to _multiply_unit_with_time and elaborate in docstring. * Set integral unit_of_measurement explicitly to None if source unit_of_measurement is None * One function for unit of measurement related steps. * Improve docstring of _multiply_unit_with_time Co-authored-by: Erik Montnemery * Apply f-string suggestions from code review Co-authored-by: Erik Montnemery * Be more clear in comment about removing the sensors icon default. * Apply suggestions from code review Co-authored-by: Diogo Gomes * Update homeassistant/components/integration/sensor.py * Update homeassistant/components/integration/sensor.py * Update homeassistant/components/integration/sensor.py --------- Co-authored-by: Erik Montnemery Co-authored-by: Diogo Gomes --- .../components/integration/sensor.py | 249 ++++++++++-------- tests/components/integration/test_sensor.py | 2 +- 2 files changed, 146 insertions(+), 105 deletions(-) diff --git a/homeassistant/components/integration/sensor.py b/homeassistant/components/integration/sensor.py index 956b868272f..62a0dbdec78 100644 --- a/homeassistant/components/integration/sensor.py +++ b/homeassistant/components/integration/sensor.py @@ -2,6 +2,7 @@ from __future__ import annotations +from abc import ABC, abstractmethod from dataclasses import dataclass from decimal import Decimal, DecimalException, InvalidOperation import logging @@ -27,8 +28,9 @@ from homeassistant.const import ( STATE_UNKNOWN, UnitOfTime, ) -from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.core import Event, HomeAssistant, State, callback from homeassistant.helpers import ( + condition, config_validation as cv, device_registry as dr, entity_registry as er, @@ -89,6 +91,72 @@ PLATFORM_SCHEMA = vol.All( ) +class _IntegrationMethod(ABC): + @staticmethod + def from_name(method_name: str) -> _IntegrationMethod: + return _NAME_TO_INTEGRATION_METHOD[method_name]() + + @abstractmethod + def validate_states(self, left: State, right: State) -> bool: + """Check state requirements for integration.""" + + @abstractmethod + def calculate_area_with_two_states( + self, elapsed_time: float, left: State, right: State + ) -> Decimal: + """Calculate area given two states.""" + + def calculate_area_with_one_state( + self, elapsed_time: float, constant_state: State + ) -> Decimal: + return Decimal(constant_state.state) * Decimal(elapsed_time) + + +class _Trapezoidal(_IntegrationMethod): + def calculate_area_with_two_states( + self, elapsed_time: float, left: State, right: State + ) -> Decimal: + return Decimal(elapsed_time) * (Decimal(left.state) + Decimal(right.state)) / 2 + + def validate_states(self, left: State, right: State) -> bool: + return _is_numeric_state(left) and _is_numeric_state(right) + + +class _Left(_IntegrationMethod): + def calculate_area_with_two_states( + self, elapsed_time: float, left: State, right: State + ) -> Decimal: + return self.calculate_area_with_one_state(elapsed_time, left) + + def validate_states(self, left: State, right: State) -> bool: + return _is_numeric_state(left) + + +class _Right(_IntegrationMethod): + def calculate_area_with_two_states( + self, elapsed_time: float, left: State, right: State + ) -> Decimal: + return self.calculate_area_with_one_state(elapsed_time, right) + + def validate_states(self, left: State, right: State) -> bool: + return _is_numeric_state(right) + + +def _is_numeric_state(state: State) -> bool: + try: + float(state.state) + return True + except (ValueError, TypeError): + return False + + +_NAME_TO_INTEGRATION_METHOD: dict[str, type[_IntegrationMethod]] = { + METHOD_LEFT: _Left, + METHOD_RIGHT: _Right, + METHOD_TRAPEZOIDAL: _Trapezoidal, +} + + @dataclass class IntegrationSensorExtraStoredData(SensorExtraStoredData): """Object to hold extra stored data.""" @@ -231,10 +299,10 @@ class IntegrationSensor(RestoreSensor): self._sensor_source_id = source_entity self._round_digits = round_digits self._state: Decimal | None = None - self._method = integration_method + self._method = _IntegrationMethod.from_name(integration_method) self._attr_name = name if name is not None else f"{source_entity} integral" - self._unit_template = f"{'' if unit_prefix is None else unit_prefix}{{}}" + self._unit_prefix_string = "" if unit_prefix is None else unit_prefix self._unit_of_measurement: str | None = None self._unit_prefix = UNIT_PREFIXES[unit_prefix] self._unit_time = UNIT_TIME[unit_time] @@ -244,15 +312,52 @@ class IntegrationSensor(RestoreSensor): self._last_valid_state: Decimal | None = None self._attr_device_info = device_info - def _unit(self, source_unit: str) -> str: - """Derive unit from the source sensor, SI prefix and time unit.""" + def _calculate_unit(self, source_unit: str) -> str: + """Multiply source_unit with time unit of the integral. + + Possibly cancelling out a time unit in the denominator of the source_unit. + Note that this is a heuristic string manipulation method and might not + transform all source units in a sensible way. + + Examples: + - Speed to distance: 'km/h' and 'h' will be transformed to 'km' + - Power to energy: 'W' and 'h' will be transformed to 'Wh' + + """ unit_time = self._unit_time_str if source_unit.endswith(f"/{unit_time}"): integral_unit = source_unit[0 : (-(1 + len(unit_time)))] else: integral_unit = f"{source_unit}{unit_time}" - return self._unit_template.format(integral_unit) + return f"{self._unit_prefix_string}{integral_unit}" + + def _derive_and_set_attributes_from_state(self, source_state: State) -> None: + source_unit = source_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) + if source_unit is not None: + self._unit_of_measurement = self._calculate_unit(source_unit) + else: + # If the source has no defined unit we cannot derive a unit for the integral + self._unit_of_measurement = None + + if ( + self.device_class is None + and source_state.attributes.get(ATTR_DEVICE_CLASS) + == SensorDeviceClass.POWER + ): + self._attr_device_class = SensorDeviceClass.ENERGY + self._attr_icon = None # Remove this sensors icon default and allow to fallback to the ENERGY default + + def _update_integral(self, area: Decimal) -> None: + area_scaled = area / (self._unit_prefix * self._unit_time) + if isinstance(self._state, Decimal): + self._state += area_scaled + else: + self._state = area_scaled + _LOGGER.debug( + "area = %s, area_scaled = %s new state = %s", area, area_scaled, self._state + ) + self._last_valid_state = self._state async def async_added_to_hass(self) -> None: """Handle entity which will be added.""" @@ -292,109 +397,45 @@ class IntegrationSensor(RestoreSensor): self._attr_device_class = state.attributes.get(ATTR_DEVICE_CLASS) self._unit_of_measurement = state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) - @callback - def calc_integration(event: Event[EventStateChangedData]) -> None: - """Handle the sensor state changes.""" - old_state = event.data["old_state"] - new_state = event.data["new_state"] - - if ( - source_state := self.hass.states.get(self._sensor_source_id) - ) is None or source_state.state == STATE_UNAVAILABLE: - self._attr_available = False - self.async_write_ha_state() - return - - self._attr_available = True - - if old_state is None or new_state is None: - # we can't calculate the elapsed time, so we can't calculate the integral - return - - unit = new_state.attributes.get(ATTR_UNIT_OF_MEASUREMENT) - if unit is not None: - self._unit_of_measurement = self._unit(unit) - - if ( - self.device_class is None - and new_state.attributes.get(ATTR_DEVICE_CLASS) - == SensorDeviceClass.POWER - ): - self._attr_device_class = SensorDeviceClass.ENERGY - self._attr_icon = None - - self.async_write_ha_state() - - try: - # integration as the Riemann integral of previous measures. - elapsed_time = ( - new_state.last_updated - old_state.last_updated - ).total_seconds() - - if ( - self._method == METHOD_TRAPEZOIDAL - and new_state.state - not in ( - STATE_UNKNOWN, - STATE_UNAVAILABLE, - ) - and old_state.state - not in ( - STATE_UNKNOWN, - STATE_UNAVAILABLE, - ) - ): - area = ( - (Decimal(new_state.state) + Decimal(old_state.state)) - * Decimal(elapsed_time) - / 2 - ) - elif self._method == METHOD_LEFT and old_state.state not in ( - STATE_UNKNOWN, - STATE_UNAVAILABLE, - ): - area = Decimal(old_state.state) * Decimal(elapsed_time) - elif self._method == METHOD_RIGHT and new_state.state not in ( - STATE_UNKNOWN, - STATE_UNAVAILABLE, - ): - area = Decimal(new_state.state) * Decimal(elapsed_time) - else: - _LOGGER.debug( - "Could not apply method %s to %s -> %s", - self._method, - old_state.state, - new_state.state, - ) - return - - integral = area / (self._unit_prefix * self._unit_time) - _LOGGER.debug( - "area = %s, integral = %s state = %s", area, integral, self._state - ) - assert isinstance(integral, Decimal) - except ValueError as err: - _LOGGER.warning("While calculating integration: %s", err) - except DecimalException as err: - _LOGGER.warning( - "Invalid state (%s > %s): %s", old_state.state, new_state.state, err - ) - except AssertionError as err: - _LOGGER.error("Could not calculate integral: %s", err) - else: - if isinstance(self._state, Decimal): - self._state += integral - else: - self._state = integral - self._last_valid_state = self._state - self.async_write_ha_state() - self.async_on_remove( async_track_state_change_event( - self.hass, [self._sensor_source_id], calc_integration + self.hass, + [self._sensor_source_id], + self._handle_state_change, ) ) + @callback + def _handle_state_change(self, event: Event[EventStateChangedData]) -> None: + old_state = event.data["old_state"] + new_state = event.data["new_state"] + + if old_state is None or new_state is None: + return + + if condition.state(self.hass, new_state, [STATE_UNAVAILABLE]): + self._attr_available = False + self.async_write_ha_state() + return + + self._attr_available = True + self._derive_and_set_attributes_from_state(new_state) + + if not self._method.validate_states(old_state, new_state): + self.async_write_ha_state() + return + + elapsed_seconds = ( + new_state.last_updated - old_state.last_updated + ).total_seconds() + + area = self._method.calculate_area_with_two_states( + elapsed_seconds, old_state, new_state + ) + + self._update_integral(area) + self.async_write_ha_state() + @property def native_value(self) -> Decimal | None: """Return the state of the sensor.""" diff --git a/tests/components/integration/test_sensor.py b/tests/components/integration/test_sensor.py index 904c31e9896..53763247bdf 100644 --- a/tests/components/integration/test_sensor.py +++ b/tests/components/integration/test_sensor.py @@ -563,7 +563,7 @@ async def test_units(hass: HomeAssistant) -> None: # When source state goes to None / Unknown, expect an early exit without # changes to the state or unit_of_measurement - hass.states.async_set(entity_id, None, None) + hass.states.async_set(entity_id, None, {"unit_of_measurement": UnitOfPower.WATT}) await hass.async_block_till_done() new_state = hass.states.get("sensor.integration")