Make unit converter use a factory to avoid looking up the ratios each conversion (#93706)
This commit is contained in:
parent
7f3f2eea38
commit
2f1f32f0bb
4 changed files with 227 additions and 66 deletions
|
@ -219,50 +219,34 @@ def _get_statistic_to_display_unit_converter(
|
|||
if display_unit == statistic_unit:
|
||||
return None
|
||||
|
||||
convert = converter.convert
|
||||
|
||||
def _from_normalized_unit(val: float | None) -> float | None:
|
||||
"""Return val."""
|
||||
if val is None:
|
||||
return val
|
||||
return convert(val, statistic_unit, display_unit)
|
||||
|
||||
return _from_normalized_unit
|
||||
return converter.converter_factory_allow_none(
|
||||
from_unit=statistic_unit, to_unit=display_unit
|
||||
)
|
||||
|
||||
|
||||
def _get_display_to_statistic_unit_converter(
|
||||
display_unit: str | None,
|
||||
statistic_unit: str | None,
|
||||
) -> Callable[[float], float]:
|
||||
) -> Callable[[float], float] | None:
|
||||
"""Prepare a converter from the display unit to the statistics unit."""
|
||||
|
||||
def no_conversion(val: float) -> float:
|
||||
"""Return val."""
|
||||
return val
|
||||
|
||||
if (converter := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(statistic_unit)) is None:
|
||||
return no_conversion
|
||||
|
||||
return partial(converter.convert, from_unit=display_unit, to_unit=statistic_unit)
|
||||
if (
|
||||
display_unit == statistic_unit
|
||||
or (converter := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(statistic_unit)) is None
|
||||
):
|
||||
return None
|
||||
return converter.converter_factory(from_unit=display_unit, to_unit=statistic_unit)
|
||||
|
||||
|
||||
def _get_unit_converter(
|
||||
from_unit: str, to_unit: str
|
||||
) -> Callable[[float | None], float | None]:
|
||||
) -> Callable[[float | None], float | None] | None:
|
||||
"""Prepare a converter from a unit to another unit."""
|
||||
|
||||
def convert_units(
|
||||
val: float | None, conv: type[BaseUnitConverter], from_unit: str, to_unit: str
|
||||
) -> float | None:
|
||||
"""Return converted val."""
|
||||
if val is None:
|
||||
return val
|
||||
return conv.convert(val, from_unit=from_unit, to_unit=to_unit)
|
||||
|
||||
for conv in STATISTIC_UNIT_TO_UNIT_CONVERTER.values():
|
||||
if from_unit in conv.VALID_UNITS and to_unit in conv.VALID_UNITS:
|
||||
return partial(
|
||||
convert_units, conv=conv, from_unit=from_unit, to_unit=to_unit
|
||||
if from_unit == to_unit:
|
||||
return None
|
||||
return conv.converter_factory_allow_none(
|
||||
from_unit=from_unit, to_unit=to_unit
|
||||
)
|
||||
raise HomeAssistantError
|
||||
|
||||
|
@ -2290,10 +2274,10 @@ def adjust_statistics(
|
|||
return True
|
||||
|
||||
statistic_unit = metadata[statistic_id][1]["unit_of_measurement"]
|
||||
convert = _get_display_to_statistic_unit_converter(
|
||||
if convert := _get_display_to_statistic_unit_converter(
|
||||
adjustment_unit, statistic_unit
|
||||
)
|
||||
sum_adjustment = convert(sum_adjustment)
|
||||
):
|
||||
sum_adjustment = convert(sum_adjustment)
|
||||
|
||||
_adjust_sum_statistics(
|
||||
session,
|
||||
|
@ -2360,7 +2344,14 @@ def change_statistics_unit(
|
|||
|
||||
metadata_id = metadata[0]
|
||||
|
||||
convert = _get_unit_converter(old_unit, new_unit)
|
||||
if not (convert := _get_unit_converter(old_unit, new_unit)):
|
||||
_LOGGER.warning(
|
||||
"Statistics unit of measurement for %s is already %s",
|
||||
statistic_id,
|
||||
new_unit,
|
||||
)
|
||||
return
|
||||
|
||||
tables: tuple[type[StatisticsBase], ...] = (
|
||||
Statistics,
|
||||
StatisticsShortTerm,
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
"""Typing Helpers for Home Assistant."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import lru_cache
|
||||
|
||||
from homeassistant.const import (
|
||||
CONCENTRATION_PARTS_PER_BILLION,
|
||||
CONCENTRATION_PARTS_PER_MILLION,
|
||||
|
@ -67,30 +70,49 @@ class BaseUnitConverter:
|
|||
@classmethod
|
||||
def convert(cls, value: float, from_unit: str | None, to_unit: str | None) -> float:
|
||||
"""Convert one unit of measurement to another."""
|
||||
if from_unit == to_unit:
|
||||
return value
|
||||
|
||||
try:
|
||||
from_ratio = cls._UNIT_CONVERSION[from_unit]
|
||||
except KeyError as err:
|
||||
raise HomeAssistantError(
|
||||
UNIT_NOT_RECOGNIZED_TEMPLATE.format(from_unit, cls.UNIT_CLASS)
|
||||
) from err
|
||||
|
||||
try:
|
||||
to_ratio = cls._UNIT_CONVERSION[to_unit]
|
||||
except KeyError as err:
|
||||
raise HomeAssistantError(
|
||||
UNIT_NOT_RECOGNIZED_TEMPLATE.format(to_unit, cls.UNIT_CLASS)
|
||||
) from err
|
||||
|
||||
new_value = value / from_ratio
|
||||
return new_value * to_ratio
|
||||
return cls.converter_factory(from_unit, to_unit)(value)
|
||||
|
||||
@classmethod
|
||||
@lru_cache
|
||||
def converter_factory(
|
||||
cls, from_unit: str | None, to_unit: str | None
|
||||
) -> Callable[[float], float]:
|
||||
"""Return a function to convert one unit of measurement to another."""
|
||||
if from_unit == to_unit:
|
||||
return lambda value: value
|
||||
from_ratio, to_ratio = cls._get_from_to_ratio(from_unit, to_unit)
|
||||
return lambda val: (val / from_ratio) * to_ratio
|
||||
|
||||
@classmethod
|
||||
def _get_from_to_ratio(
|
||||
cls, from_unit: str | None, to_unit: str | None
|
||||
) -> tuple[float, float]:
|
||||
"""Get unit ratio between units of measurement."""
|
||||
unit_conversion = cls._UNIT_CONVERSION
|
||||
try:
|
||||
return unit_conversion[from_unit], unit_conversion[to_unit]
|
||||
except KeyError as err:
|
||||
raise HomeAssistantError(
|
||||
UNIT_NOT_RECOGNIZED_TEMPLATE.format(err.args[0], cls.UNIT_CLASS)
|
||||
) from err
|
||||
|
||||
@classmethod
|
||||
@lru_cache
|
||||
def converter_factory_allow_none(
|
||||
cls, from_unit: str | None, to_unit: str | None
|
||||
) -> Callable[[float | None], float | None]:
|
||||
"""Return a function to convert one unit of measurement to another which allows None."""
|
||||
if from_unit == to_unit:
|
||||
return lambda value: value
|
||||
from_ratio, to_ratio = cls._get_from_to_ratio(from_unit, to_unit)
|
||||
return lambda val: None if val is None else (val / from_ratio) * to_ratio
|
||||
|
||||
@classmethod
|
||||
@lru_cache
|
||||
def get_unit_ratio(cls, from_unit: str | None, to_unit: str | None) -> float:
|
||||
"""Get unit ratio between units of measurement."""
|
||||
return cls._UNIT_CONVERSION[from_unit] / cls._UNIT_CONVERSION[to_unit]
|
||||
from_ratio, to_ratio = cls._get_from_to_ratio(from_unit, to_unit)
|
||||
return from_ratio / to_ratio
|
||||
|
||||
|
||||
class DataRateConverter(BaseUnitConverter):
|
||||
|
@ -339,7 +361,37 @@ class TemperatureConverter(BaseUnitConverter):
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def convert(cls, value: float, from_unit: str | None, to_unit: str | None) -> float:
|
||||
@lru_cache(maxsize=8)
|
||||
def converter_factory(
|
||||
cls, from_unit: str | None, to_unit: str | None
|
||||
) -> Callable[[float], float]:
|
||||
"""Return a function to convert a temperature from one unit to another."""
|
||||
if from_unit == to_unit:
|
||||
# Return a function that does nothing. This is not
|
||||
# in _converter_factory because we do not want to wrap
|
||||
# it with the None check in converter_factory_allow_none.
|
||||
return lambda value: value
|
||||
|
||||
return cls._converter_factory(from_unit, to_unit)
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=8)
|
||||
def converter_factory_allow_none(
|
||||
cls, from_unit: str | None, to_unit: str | None
|
||||
) -> Callable[[float | None], float | None]:
|
||||
"""Return a function to convert a temperature from one unit to another which allows None."""
|
||||
if from_unit == to_unit:
|
||||
# Return a function that does nothing. This is not
|
||||
# in _converter_factory because we do not want to wrap
|
||||
# it with the None check in this case.
|
||||
return lambda value: value
|
||||
convert = cls._converter_factory(from_unit, to_unit)
|
||||
return lambda value: None if value is None else convert(value)
|
||||
|
||||
@classmethod
|
||||
def _converter_factory(
|
||||
cls, from_unit: str | None, to_unit: str | None
|
||||
) -> Callable[[float], float]:
|
||||
"""Convert a temperature from one unit to another.
|
||||
|
||||
eg. 10°C will return 50°F
|
||||
|
@ -349,32 +401,29 @@ class TemperatureConverter(BaseUnitConverter):
|
|||
"""
|
||||
# We cannot use the implementation from BaseUnitConverter here because the
|
||||
# temperature units do not use the same floor: 0°C, 0°F and 0K do not align
|
||||
if from_unit == to_unit:
|
||||
return value
|
||||
|
||||
if from_unit == UnitOfTemperature.CELSIUS:
|
||||
if to_unit == UnitOfTemperature.FAHRENHEIT:
|
||||
return cls._celsius_to_fahrenheit(value)
|
||||
return cls._celsius_to_fahrenheit
|
||||
if to_unit == UnitOfTemperature.KELVIN:
|
||||
return cls._celsius_to_kelvin(value)
|
||||
return cls._celsius_to_kelvin
|
||||
raise HomeAssistantError(
|
||||
UNIT_NOT_RECOGNIZED_TEMPLATE.format(to_unit, cls.UNIT_CLASS)
|
||||
)
|
||||
|
||||
if from_unit == UnitOfTemperature.FAHRENHEIT:
|
||||
if to_unit == UnitOfTemperature.CELSIUS:
|
||||
return cls._fahrenheit_to_celsius(value)
|
||||
return cls._fahrenheit_to_celsius
|
||||
if to_unit == UnitOfTemperature.KELVIN:
|
||||
return cls._celsius_to_kelvin(cls._fahrenheit_to_celsius(value))
|
||||
return cls._fahrenheit_to_kelvin
|
||||
raise HomeAssistantError(
|
||||
UNIT_NOT_RECOGNIZED_TEMPLATE.format(to_unit, cls.UNIT_CLASS)
|
||||
)
|
||||
|
||||
if from_unit == UnitOfTemperature.KELVIN:
|
||||
if to_unit == UnitOfTemperature.CELSIUS:
|
||||
return cls._kelvin_to_celsius(value)
|
||||
return cls._kelvin_to_celsius
|
||||
if to_unit == UnitOfTemperature.FAHRENHEIT:
|
||||
return cls._celsius_to_fahrenheit(cls._kelvin_to_celsius(value))
|
||||
return cls._kelvin_to_fahrenheit
|
||||
raise HomeAssistantError(
|
||||
UNIT_NOT_RECOGNIZED_TEMPLATE.format(to_unit, cls.UNIT_CLASS)
|
||||
)
|
||||
|
@ -393,7 +442,17 @@ class TemperatureConverter(BaseUnitConverter):
|
|||
"""
|
||||
# We use BaseUnitConverter implementation here because we are only interested
|
||||
# in the ratio between the units.
|
||||
return super().convert(interval, from_unit, to_unit)
|
||||
return super().converter_factory(from_unit, to_unit)(interval)
|
||||
|
||||
@classmethod
|
||||
def _kelvin_to_fahrenheit(cls, kelvin: float) -> float:
|
||||
"""Convert a temperature in Kelvin to Fahrenheit."""
|
||||
return (kelvin - 273.15) * 1.8 + 32.0
|
||||
|
||||
@classmethod
|
||||
def _fahrenheit_to_kelvin(cls, fahrenheit: float) -> float:
|
||||
"""Convert a temperature in Fahrenheit to Kelvin."""
|
||||
return 273.15 + ((fahrenheit - 32.0) / 1.8)
|
||||
|
||||
@classmethod
|
||||
def _fahrenheit_to_celsius(cls, fahrenheit: float) -> float:
|
||||
|
|
|
@ -1973,6 +1973,36 @@ async def test_change_statistics_unit(
|
|||
],
|
||||
}
|
||||
|
||||
# Changing to the same unit is allowed but does nothing
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 6,
|
||||
"type": "recorder/change_statistics_unit",
|
||||
"statistic_id": "sensor.test",
|
||||
"new_unit_of_measurement": "W",
|
||||
"old_unit_of_measurement": "W",
|
||||
}
|
||||
)
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
await async_recorder_block_till_done(hass)
|
||||
|
||||
await client.send_json({"id": 7, "type": "recorder/list_statistic_ids"})
|
||||
response = await client.receive_json()
|
||||
assert response["success"]
|
||||
assert response["result"] == [
|
||||
{
|
||||
"statistic_id": "sensor.test",
|
||||
"display_unit_of_measurement": "kW",
|
||||
"has_mean": True,
|
||||
"has_sum": False,
|
||||
"name": None,
|
||||
"source": "recorder",
|
||||
"statistics_unit_of_measurement": "W",
|
||||
"unit_class": "power",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
async def test_change_statistics_unit_errors(
|
||||
recorder_mock: Recorder,
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from itertools import chain
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -534,6 +535,86 @@ def test_unit_conversion(
|
|||
assert converter.convert(value, from_unit, to_unit) == pytest.approx(expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("converter", "value", "from_unit", "expected", "to_unit"),
|
||||
[
|
||||
# Process all items in _CONVERTED_VALUE
|
||||
(converter, value, from_unit, expected, to_unit)
|
||||
for converter, item in _CONVERTED_VALUE.items()
|
||||
for value, from_unit, expected, to_unit in item
|
||||
],
|
||||
)
|
||||
def test_unit_conversion_factory(
|
||||
converter: type[BaseUnitConverter],
|
||||
value: float,
|
||||
from_unit: str,
|
||||
expected: float,
|
||||
to_unit: str,
|
||||
) -> None:
|
||||
"""Test conversion to other units."""
|
||||
assert converter.converter_factory(from_unit, to_unit)(value) == pytest.approx(
|
||||
expected
|
||||
)
|
||||
|
||||
|
||||
def test_unit_conversion_factory_allow_none_with_none() -> None:
|
||||
"""Test test_unit_conversion_factory_allow_none with None."""
|
||||
assert (
|
||||
SpeedConverter.converter_factory_allow_none(
|
||||
UnitOfSpeed.FEET_PER_SECOND, UnitOfSpeed.FEET_PER_SECOND
|
||||
)(1)
|
||||
== 1
|
||||
)
|
||||
assert (
|
||||
SpeedConverter.converter_factory_allow_none(
|
||||
UnitOfSpeed.FEET_PER_SECOND, UnitOfSpeed.FEET_PER_SECOND
|
||||
)(None)
|
||||
is None
|
||||
)
|
||||
assert (
|
||||
TemperatureConverter.converter_factory_allow_none(
|
||||
UnitOfTemperature.CELSIUS, UnitOfTemperature.CELSIUS
|
||||
)(1)
|
||||
== 1
|
||||
)
|
||||
assert (
|
||||
TemperatureConverter.converter_factory_allow_none(
|
||||
UnitOfTemperature.CELSIUS, UnitOfTemperature.CELSIUS
|
||||
)(None)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("converter", "value", "from_unit", "expected", "to_unit"),
|
||||
chain(
|
||||
[
|
||||
# Process all items in _CONVERTED_VALUE
|
||||
(converter, value, from_unit, expected, to_unit)
|
||||
for converter, item in _CONVERTED_VALUE.items()
|
||||
for value, from_unit, expected, to_unit in item
|
||||
],
|
||||
[
|
||||
# Process all items in _CONVERTED_VALUE and replace the value with None
|
||||
(converter, None, from_unit, None, to_unit)
|
||||
for converter, item in _CONVERTED_VALUE.items()
|
||||
for value, from_unit, expected, to_unit in item
|
||||
],
|
||||
),
|
||||
)
|
||||
def test_unit_conversion_factory_allow_none(
|
||||
converter: type[BaseUnitConverter],
|
||||
value: float,
|
||||
from_unit: str,
|
||||
expected: float,
|
||||
to_unit: str,
|
||||
) -> None:
|
||||
"""Test conversion to other units."""
|
||||
assert converter.converter_factory_allow_none(from_unit, to_unit)(
|
||||
value
|
||||
) == pytest.approx(expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("value", "from_unit", "expected", "to_unit"),
|
||||
[
|
||||
|
|
Loading…
Add table
Reference in a new issue