From 2f1f32f0bb7c5d772f246a4fc28e60323ac8c8c8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 29 May 2023 13:50:40 -0500 Subject: [PATCH] Make unit converter use a factory to avoid looking up the ratios each conversion (#93706) --- .../components/recorder/statistics.py | 61 ++++----- homeassistant/util/unit_conversion.py | 121 +++++++++++++----- .../components/recorder/test_websocket_api.py | 30 +++++ tests/util/test_unit_conversion.py | 81 ++++++++++++ 4 files changed, 227 insertions(+), 66 deletions(-) diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index 70f35d0349a..7dbe29cea7d 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -219,50 +219,34 @@ def _get_statistic_to_display_unit_converter( if display_unit == statistic_unit: return None - convert = converter.convert - - def _from_normalized_unit(val: float | None) -> float | None: - """Return val.""" - if val is None: - return val - return convert(val, statistic_unit, display_unit) - - return _from_normalized_unit + return converter.converter_factory_allow_none( + from_unit=statistic_unit, to_unit=display_unit + ) def _get_display_to_statistic_unit_converter( display_unit: str | None, statistic_unit: str | None, -) -> Callable[[float], float]: +) -> Callable[[float], float] | None: """Prepare a converter from the display unit to the statistics unit.""" - - def no_conversion(val: float) -> float: - """Return val.""" - return val - - if (converter := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(statistic_unit)) is None: - return no_conversion - - return partial(converter.convert, from_unit=display_unit, to_unit=statistic_unit) + if ( + display_unit == statistic_unit + or (converter := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(statistic_unit)) is None + ): + return None + return converter.converter_factory(from_unit=display_unit, to_unit=statistic_unit) def _get_unit_converter( from_unit: str, to_unit: str -) -> Callable[[float | None], float | None]: +) -> Callable[[float | None], float | None] | None: """Prepare a converter from a unit to another unit.""" - - def convert_units( - val: float | None, conv: type[BaseUnitConverter], from_unit: str, to_unit: str - ) -> float | None: - """Return converted val.""" - if val is None: - return val - return conv.convert(val, from_unit=from_unit, to_unit=to_unit) - for conv in STATISTIC_UNIT_TO_UNIT_CONVERTER.values(): if from_unit in conv.VALID_UNITS and to_unit in conv.VALID_UNITS: - return partial( - convert_units, conv=conv, from_unit=from_unit, to_unit=to_unit + if from_unit == to_unit: + return None + return conv.converter_factory_allow_none( + from_unit=from_unit, to_unit=to_unit ) raise HomeAssistantError @@ -2290,10 +2274,10 @@ def adjust_statistics( return True statistic_unit = metadata[statistic_id][1]["unit_of_measurement"] - convert = _get_display_to_statistic_unit_converter( + if convert := _get_display_to_statistic_unit_converter( adjustment_unit, statistic_unit - ) - sum_adjustment = convert(sum_adjustment) + ): + sum_adjustment = convert(sum_adjustment) _adjust_sum_statistics( session, @@ -2360,7 +2344,14 @@ def change_statistics_unit( metadata_id = metadata[0] - convert = _get_unit_converter(old_unit, new_unit) + if not (convert := _get_unit_converter(old_unit, new_unit)): + _LOGGER.warning( + "Statistics unit of measurement for %s is already %s", + statistic_id, + new_unit, + ) + return + tables: tuple[type[StatisticsBase], ...] = ( Statistics, StatisticsShortTerm, diff --git a/homeassistant/util/unit_conversion.py b/homeassistant/util/unit_conversion.py index a22a204e69a..5ce31b072cf 100644 --- a/homeassistant/util/unit_conversion.py +++ b/homeassistant/util/unit_conversion.py @@ -1,6 +1,9 @@ """Typing Helpers for Home Assistant.""" from __future__ import annotations +from collections.abc import Callable +from functools import lru_cache + from homeassistant.const import ( CONCENTRATION_PARTS_PER_BILLION, CONCENTRATION_PARTS_PER_MILLION, @@ -67,30 +70,49 @@ class BaseUnitConverter: @classmethod def convert(cls, value: float, from_unit: str | None, to_unit: str | None) -> float: """Convert one unit of measurement to another.""" - if from_unit == to_unit: - return value - - try: - from_ratio = cls._UNIT_CONVERSION[from_unit] - except KeyError as err: - raise HomeAssistantError( - UNIT_NOT_RECOGNIZED_TEMPLATE.format(from_unit, cls.UNIT_CLASS) - ) from err - - try: - to_ratio = cls._UNIT_CONVERSION[to_unit] - except KeyError as err: - raise HomeAssistantError( - UNIT_NOT_RECOGNIZED_TEMPLATE.format(to_unit, cls.UNIT_CLASS) - ) from err - - new_value = value / from_ratio - return new_value * to_ratio + return cls.converter_factory(from_unit, to_unit)(value) @classmethod + @lru_cache + def converter_factory( + cls, from_unit: str | None, to_unit: str | None + ) -> Callable[[float], float]: + """Return a function to convert one unit of measurement to another.""" + if from_unit == to_unit: + return lambda value: value + from_ratio, to_ratio = cls._get_from_to_ratio(from_unit, to_unit) + return lambda val: (val / from_ratio) * to_ratio + + @classmethod + def _get_from_to_ratio( + cls, from_unit: str | None, to_unit: str | None + ) -> tuple[float, float]: + """Get unit ratio between units of measurement.""" + unit_conversion = cls._UNIT_CONVERSION + try: + return unit_conversion[from_unit], unit_conversion[to_unit] + except KeyError as err: + raise HomeAssistantError( + UNIT_NOT_RECOGNIZED_TEMPLATE.format(err.args[0], cls.UNIT_CLASS) + ) from err + + @classmethod + @lru_cache + def converter_factory_allow_none( + cls, from_unit: str | None, to_unit: str | None + ) -> Callable[[float | None], float | None]: + """Return a function to convert one unit of measurement to another which allows None.""" + if from_unit == to_unit: + return lambda value: value + from_ratio, to_ratio = cls._get_from_to_ratio(from_unit, to_unit) + return lambda val: None if val is None else (val / from_ratio) * to_ratio + + @classmethod + @lru_cache def get_unit_ratio(cls, from_unit: str | None, to_unit: str | None) -> float: """Get unit ratio between units of measurement.""" - return cls._UNIT_CONVERSION[from_unit] / cls._UNIT_CONVERSION[to_unit] + from_ratio, to_ratio = cls._get_from_to_ratio(from_unit, to_unit) + return from_ratio / to_ratio class DataRateConverter(BaseUnitConverter): @@ -339,7 +361,37 @@ class TemperatureConverter(BaseUnitConverter): } @classmethod - def convert(cls, value: float, from_unit: str | None, to_unit: str | None) -> float: + @lru_cache(maxsize=8) + def converter_factory( + cls, from_unit: str | None, to_unit: str | None + ) -> Callable[[float], float]: + """Return a function to convert a temperature from one unit to another.""" + if from_unit == to_unit: + # Return a function that does nothing. This is not + # in _converter_factory because we do not want to wrap + # it with the None check in converter_factory_allow_none. + return lambda value: value + + return cls._converter_factory(from_unit, to_unit) + + @classmethod + @lru_cache(maxsize=8) + def converter_factory_allow_none( + cls, from_unit: str | None, to_unit: str | None + ) -> Callable[[float | None], float | None]: + """Return a function to convert a temperature from one unit to another which allows None.""" + if from_unit == to_unit: + # Return a function that does nothing. This is not + # in _converter_factory because we do not want to wrap + # it with the None check in this case. + return lambda value: value + convert = cls._converter_factory(from_unit, to_unit) + return lambda value: None if value is None else convert(value) + + @classmethod + def _converter_factory( + cls, from_unit: str | None, to_unit: str | None + ) -> Callable[[float], float]: """Convert a temperature from one unit to another. eg. 10°C will return 50°F @@ -349,32 +401,29 @@ class TemperatureConverter(BaseUnitConverter): """ # We cannot use the implementation from BaseUnitConverter here because the # temperature units do not use the same floor: 0°C, 0°F and 0K do not align - if from_unit == to_unit: - return value - if from_unit == UnitOfTemperature.CELSIUS: if to_unit == UnitOfTemperature.FAHRENHEIT: - return cls._celsius_to_fahrenheit(value) + return cls._celsius_to_fahrenheit if to_unit == UnitOfTemperature.KELVIN: - return cls._celsius_to_kelvin(value) + return cls._celsius_to_kelvin raise HomeAssistantError( UNIT_NOT_RECOGNIZED_TEMPLATE.format(to_unit, cls.UNIT_CLASS) ) if from_unit == UnitOfTemperature.FAHRENHEIT: if to_unit == UnitOfTemperature.CELSIUS: - return cls._fahrenheit_to_celsius(value) + return cls._fahrenheit_to_celsius if to_unit == UnitOfTemperature.KELVIN: - return cls._celsius_to_kelvin(cls._fahrenheit_to_celsius(value)) + return cls._fahrenheit_to_kelvin raise HomeAssistantError( UNIT_NOT_RECOGNIZED_TEMPLATE.format(to_unit, cls.UNIT_CLASS) ) if from_unit == UnitOfTemperature.KELVIN: if to_unit == UnitOfTemperature.CELSIUS: - return cls._kelvin_to_celsius(value) + return cls._kelvin_to_celsius if to_unit == UnitOfTemperature.FAHRENHEIT: - return cls._celsius_to_fahrenheit(cls._kelvin_to_celsius(value)) + return cls._kelvin_to_fahrenheit raise HomeAssistantError( UNIT_NOT_RECOGNIZED_TEMPLATE.format(to_unit, cls.UNIT_CLASS) ) @@ -393,7 +442,17 @@ class TemperatureConverter(BaseUnitConverter): """ # We use BaseUnitConverter implementation here because we are only interested # in the ratio between the units. - return super().convert(interval, from_unit, to_unit) + return super().converter_factory(from_unit, to_unit)(interval) + + @classmethod + def _kelvin_to_fahrenheit(cls, kelvin: float) -> float: + """Convert a temperature in Kelvin to Fahrenheit.""" + return (kelvin - 273.15) * 1.8 + 32.0 + + @classmethod + def _fahrenheit_to_kelvin(cls, fahrenheit: float) -> float: + """Convert a temperature in Fahrenheit to Kelvin.""" + return 273.15 + ((fahrenheit - 32.0) / 1.8) @classmethod def _fahrenheit_to_celsius(cls, fahrenheit: float) -> float: diff --git a/tests/components/recorder/test_websocket_api.py b/tests/components/recorder/test_websocket_api.py index 37f5dc77d00..2c76c947350 100644 --- a/tests/components/recorder/test_websocket_api.py +++ b/tests/components/recorder/test_websocket_api.py @@ -1973,6 +1973,36 @@ async def test_change_statistics_unit( ], } + # Changing to the same unit is allowed but does nothing + await client.send_json( + { + "id": 6, + "type": "recorder/change_statistics_unit", + "statistic_id": "sensor.test", + "new_unit_of_measurement": "W", + "old_unit_of_measurement": "W", + } + ) + response = await client.receive_json() + assert response["success"] + await async_recorder_block_till_done(hass) + + await client.send_json({"id": 7, "type": "recorder/list_statistic_ids"}) + response = await client.receive_json() + assert response["success"] + assert response["result"] == [ + { + "statistic_id": "sensor.test", + "display_unit_of_measurement": "kW", + "has_mean": True, + "has_sum": False, + "name": None, + "source": "recorder", + "statistics_unit_of_measurement": "W", + "unit_class": "power", + } + ] + async def test_change_statistics_unit_errors( recorder_mock: Recorder, diff --git a/tests/util/test_unit_conversion.py b/tests/util/test_unit_conversion.py index 1aea64201b5..18f0c9a12c1 100644 --- a/tests/util/test_unit_conversion.py +++ b/tests/util/test_unit_conversion.py @@ -2,6 +2,7 @@ from __future__ import annotations import inspect +from itertools import chain import pytest @@ -534,6 +535,86 @@ def test_unit_conversion( assert converter.convert(value, from_unit, to_unit) == pytest.approx(expected) +@pytest.mark.parametrize( + ("converter", "value", "from_unit", "expected", "to_unit"), + [ + # Process all items in _CONVERTED_VALUE + (converter, value, from_unit, expected, to_unit) + for converter, item in _CONVERTED_VALUE.items() + for value, from_unit, expected, to_unit in item + ], +) +def test_unit_conversion_factory( + converter: type[BaseUnitConverter], + value: float, + from_unit: str, + expected: float, + to_unit: str, +) -> None: + """Test conversion to other units.""" + assert converter.converter_factory(from_unit, to_unit)(value) == pytest.approx( + expected + ) + + +def test_unit_conversion_factory_allow_none_with_none() -> None: + """Test test_unit_conversion_factory_allow_none with None.""" + assert ( + SpeedConverter.converter_factory_allow_none( + UnitOfSpeed.FEET_PER_SECOND, UnitOfSpeed.FEET_PER_SECOND + )(1) + == 1 + ) + assert ( + SpeedConverter.converter_factory_allow_none( + UnitOfSpeed.FEET_PER_SECOND, UnitOfSpeed.FEET_PER_SECOND + )(None) + is None + ) + assert ( + TemperatureConverter.converter_factory_allow_none( + UnitOfTemperature.CELSIUS, UnitOfTemperature.CELSIUS + )(1) + == 1 + ) + assert ( + TemperatureConverter.converter_factory_allow_none( + UnitOfTemperature.CELSIUS, UnitOfTemperature.CELSIUS + )(None) + is None + ) + + +@pytest.mark.parametrize( + ("converter", "value", "from_unit", "expected", "to_unit"), + chain( + [ + # Process all items in _CONVERTED_VALUE + (converter, value, from_unit, expected, to_unit) + for converter, item in _CONVERTED_VALUE.items() + for value, from_unit, expected, to_unit in item + ], + [ + # Process all items in _CONVERTED_VALUE and replace the value with None + (converter, None, from_unit, None, to_unit) + for converter, item in _CONVERTED_VALUE.items() + for value, from_unit, expected, to_unit in item + ], + ), +) +def test_unit_conversion_factory_allow_none( + converter: type[BaseUnitConverter], + value: float, + from_unit: str, + expected: float, + to_unit: str, +) -> None: + """Test conversion to other units.""" + assert converter.converter_factory_allow_none(from_unit, to_unit)( + value + ) == pytest.approx(expected) + + @pytest.mark.parametrize( ("value", "from_unit", "expected", "to_unit"), [