Add suggested_unit_of_measurement attribute to sensors (#80638)

* Add suggested_unit_of_measurement attribute to sensors

* Lazy calculation of initial entity options

* Add type alias for entity options

* Small tweak

* Add tests

* Store suggested_unit_of_measurement in its own option key

* Adapt to renaming of IMPERIAL_SYSTEM

* Fix rebase mistakes

* Apply suggestions from code review

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
Erik Montnemery 2022-10-24 16:08:02 +02:00 committed by GitHub
parent 0c79a9a33d
commit 6979cd95b0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 380 additions and 48 deletions

View file

@ -49,6 +49,7 @@ from homeassistant.const import ( # noqa: F401, pylint: disable=[hass-deprecate
TEMP_KELVIN,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.config_validation import ( # noqa: F401
PLATFORM_SCHEMA,
PLATFORM_SCHEMA_BASE,
@ -407,6 +408,7 @@ class SensorEntityDescription(EntityDescription):
"""A class that describes sensor entities."""
device_class: SensorDeviceClass | str | None = None
suggested_unit_of_measurement: str | None = None
last_reset: datetime | None = None
native_unit_of_measurement: str | None = None
state_class: SensorStateClass | str | None = None
@ -423,6 +425,7 @@ class SensorEntity(Entity):
_attr_native_value: StateType | date | datetime | Decimal = None
_attr_state_class: SensorStateClass | str | None
_attr_state: None = None # Subclasses of SensorEntity should not set this
_attr_suggested_unit_of_measurement: str | None
_attr_unit_of_measurement: None = (
None # Subclasses of SensorEntity should not set this
)
@ -471,6 +474,30 @@ class SensorEntity(Entity):
return None
def get_initial_entity_options(self) -> er.EntityOptionsType | None:
"""Return initial entity options.
These will be stored in the entity registry the first time the entity is seen,
and then never updated.
"""
# Unit suggested by the integration
suggested_unit_of_measurement = self.suggested_unit_of_measurement
if suggested_unit_of_measurement is None:
# Fallback to suggested by the unit conversion rules
suggested_unit_of_measurement = self.hass.config.units.get_converted_unit(
self.device_class, self.native_unit_of_measurement
)
if suggested_unit_of_measurement is None:
return None
return {
f"{DOMAIN}.private": {
"suggested_unit_of_measurement": suggested_unit_of_measurement
}
}
@final
@property
def state_attributes(self) -> dict[str, Any] | None:
@ -514,13 +541,45 @@ class SensorEntity(Entity):
return self.entity_description.native_unit_of_measurement
return None
@property
def suggested_unit_of_measurement(self) -> str | None:
"""Return the unit which should be used for the sensor's state.
This can be used by integrations to override automatic unit conversion rules,
for example to make a temperature sensor display in °C even if the configured
unit system prefers °F.
For sensors without a `unique_id`, this takes precedence over legacy
temperature conversion rules only.
For sensors with a `unique_id`, this is applied only if the unit is not set by the user,
and takes precedence over automatic device-class conversion rules.
Note:
suggested_unit_of_measurement is stored in the entity registry the first time
the entity is seen, and then never updated.
"""
if hasattr(self, "_attr_suggested_unit_of_measurement"):
return self._attr_suggested_unit_of_measurement
if hasattr(self, "entity_description"):
return self.entity_description.suggested_unit_of_measurement
return None
@final
@property
def unit_of_measurement(self) -> str | None:
"""Return the unit of measurement of the entity, after unit conversion."""
# Highest priority, for registered entities: unit set by user, with fallback to unit suggested
# by integration or secondary fallback to unit conversion rules
if self._sensor_option_unit_of_measurement:
return self._sensor_option_unit_of_measurement
# Second priority, for non registered entities: unit suggested by integration
if not self.registry_entry and self.suggested_unit_of_measurement:
return self.suggested_unit_of_measurement
# Third priority: Legacy temperature conversion, which applies
# to both registered and non registered entities
native_unit_of_measurement = self.native_unit_of_measurement
if (
@ -529,6 +588,7 @@ class SensorEntity(Entity):
):
return self.hass.config.units.temperature_unit
# Fourth priority: Native unit
return native_unit_of_measurement
@final
@ -624,22 +684,30 @@ class SensorEntity(Entity):
return super().__repr__()
@callback
def async_registry_entry_updated(self) -> None:
"""Run when the entity registry entry has been updated."""
def _custom_unit_or_none(self, primary_key: str, secondary_key: str) -> str | None:
"""Return a custom unit, or None if it's not compatible with the native unit."""
assert self.registry_entry
if (
(sensor_options := self.registry_entry.options.get(DOMAIN))
and (custom_unit := sensor_options.get(CONF_UNIT_OF_MEASUREMENT))
(sensor_options := self.registry_entry.options.get(primary_key))
and (custom_unit := sensor_options.get(secondary_key))
and (device_class := self.device_class) in UNIT_CONVERTERS
and self.native_unit_of_measurement
in UNIT_CONVERTERS[device_class].VALID_UNITS
and custom_unit in UNIT_CONVERTERS[device_class].VALID_UNITS
):
self._sensor_option_unit_of_measurement = custom_unit
return
return cast(str, custom_unit)
return None
self._sensor_option_unit_of_measurement = None
@callback
def async_registry_entry_updated(self) -> None:
"""Run when the entity registry entry has been updated."""
self._sensor_option_unit_of_measurement = self._custom_unit_or_none(
DOMAIN, CONF_UNIT_OF_MEASUREMENT
)
if not self._sensor_option_unit_of_measurement:
self._sensor_option_unit_of_measurement = self._custom_unit_or_none(
f"{DOMAIN}.private", "suggested_unit_of_measurement"
)
@dataclass

View file

@ -340,6 +340,18 @@ class Entity(ABC):
"""
return self._attr_capability_attributes
def get_initial_entity_options(self) -> er.EntityOptionsType | None:
"""Return initial entity options.
These will be stored in the entity registry the first time the entity is seen,
and then never updated.
Implemented by component base class, should not be extended by integrations.
Note: Not a property to avoid calculating unless needed.
"""
return None
@property
def state_attributes(self) -> dict[str, Any] | None:
"""Return the state attributes.

View file

@ -607,9 +607,10 @@ class EntityPlatform:
device_id=device_id,
disabled_by=disabled_by,
entity_category=entity.entity_category,
get_initial_options=entity.get_initial_entity_options,
has_entity_name=entity.has_entity_name,
hidden_by=hidden_by,
known_object_ids=self.entities.keys(),
has_entity_name=entity.has_entity_name,
original_device_class=entity.device_class,
original_icon=entity.icon,
original_name=entity.name,

View file

@ -94,6 +94,9 @@ class RegistryEntryHider(StrEnum):
USER = "user"
EntityOptionsType = Mapping[str, Mapping[str, Any]]
@attr.s(slots=True, frozen=True)
class RegistryEntry:
"""Entity Registry Entry."""
@ -114,7 +117,7 @@ class RegistryEntry:
id: str = attr.ib(factory=uuid_util.random_uuid_hex)
has_entity_name: bool = attr.ib(default=False)
name: str | None = attr.ib(default=None)
options: Mapping[str, Mapping[str, Any]] = attr.ib(
options: EntityOptionsType = attr.ib(
default=None, converter=attr.converters.default_if_none(factory=dict) # type: ignore[misc]
)
# As set by integration
@ -397,6 +400,8 @@ class EntityRegistry:
# To disable or hide an entity if it gets created
disabled_by: RegistryEntryDisabler | None = None,
hidden_by: RegistryEntryHider | None = None,
# Function to generate initial entity options if it gets created
get_initial_options: Callable[[], EntityOptionsType | None] | None = None,
# Data that we want entry to have
capabilities: Mapping[str, Any] | None | UndefinedType = UNDEFINED,
config_entry: ConfigEntry | None | UndefinedType = UNDEFINED,
@ -465,6 +470,8 @@ class EntityRegistry:
"""Return None if value is UNDEFINED, otherwise return value."""
return None if value is UNDEFINED else value
initial_options = get_initial_options() if get_initial_options else None
entry = RegistryEntry(
capabilities=none_if_undefined(capabilities),
config_entry_id=none_if_undefined(config_entry_id),
@ -474,6 +481,7 @@ class EntityRegistry:
entity_id=entity_id,
hidden_by=hidden_by,
has_entity_name=none_if_undefined(has_entity_name) or False,
options=initial_options,
original_device_class=none_if_undefined(original_device_class),
original_icon=none_if_undefined(original_icon),
original_name=none_if_undefined(original_name),
@ -590,7 +598,7 @@ class EntityRegistry:
supported_features: int | UndefinedType = UNDEFINED,
unit_of_measurement: str | None | UndefinedType = UNDEFINED,
platform: str | None | UndefinedType = UNDEFINED,
options: Mapping[str, Mapping[str, Any]] | UndefinedType = UNDEFINED,
options: EntityOptionsType | UndefinedType = UNDEFINED,
) -> RegistryEntry:
"""Private facing update properties method."""
old = self.entities[entity_id]
@ -779,7 +787,7 @@ class EntityRegistry:
) -> RegistryEntry:
"""Update entity options."""
old = self.entities[entity_id]
new_options: Mapping[str, Mapping[str, Any]] = {**old.options, domain: options}
new_options: EntityOptionsType = {**old.options, domain: options}
return self._async_update_entity(entity_id, options=new_options)
async def async_load(self) -> None:

View file

@ -2,7 +2,7 @@
from __future__ import annotations
from numbers import Number
from typing import Final
from typing import TYPE_CHECKING, Final
import voluptuous as vol
@ -42,6 +42,9 @@ from .unit_conversion import (
VolumeConverter,
)
if TYPE_CHECKING:
from homeassistant.components.sensor import SensorDeviceClass
_CONF_UNIT_SYSTEM_IMPERIAL: Final = "imperial"
_CONF_UNIT_SYSTEM_METRIC: Final = "metric"
_CONF_UNIT_SYSTEM_US_CUSTOMARY: Final = "us_customary"
@ -90,6 +93,7 @@ class UnitSystem:
*,
accumulated_precipitation: str,
length: str,
length_conversions: dict[str | None, str],
mass: str,
pressure: str,
temperature: str,
@ -122,6 +126,7 @@ class UnitSystem:
self.pressure_unit = pressure
self.volume_unit = volume
self.wind_speed_unit = wind_speed
self._length_conversions = length_conversions
@property
def name(self) -> str:
@ -215,6 +220,17 @@ class UnitSystem:
WIND_SPEED: self.wind_speed_unit,
}
def get_converted_unit(
self,
device_class: SensorDeviceClass | str | None,
original_unit: str | None,
) -> str | None:
"""Return converted unit given a device class or an original unit."""
if device_class == "distance":
return self._length_conversions.get(original_unit)
return None
def get_unit_system(key: str) -> UnitSystem:
"""Get unit system based on key."""
@ -244,6 +260,7 @@ METRIC_SYSTEM = UnitSystem(
_CONF_UNIT_SYSTEM_METRIC,
accumulated_precipitation=PRECIPITATION_MILLIMETERS,
length=LENGTH_KILOMETERS,
length_conversions={LENGTH_MILES: LENGTH_KILOMETERS},
mass=MASS_GRAMS,
pressure=PRESSURE_PA,
temperature=TEMP_CELSIUS,
@ -255,6 +272,7 @@ US_CUSTOMARY_SYSTEM = UnitSystem(
_CONF_UNIT_SYSTEM_US_CUSTOMARY,
accumulated_precipitation=PRECIPITATION_INCHES,
length=LENGTH_MILES,
length_conversions={LENGTH_KILOMETERS: LENGTH_MILES},
mass=MASS_POUNDS,
pressure=PRESSURE_PSI,
temperature=TEMP_FAHRENHEIT,

View file

@ -11,7 +11,9 @@ from homeassistant.const import (
LENGTH_CENTIMETERS,
LENGTH_INCHES,
LENGTH_KILOMETERS,
LENGTH_METERS,
LENGTH_MILES,
LENGTH_YARD,
MASS_GRAMS,
MASS_OUNCES,
PRESSURE_HPA,
@ -661,3 +663,213 @@ async def test_custom_unit_change(
state = hass.states.get(entity0.entity_id)
assert float(state.state) == approx(float(native_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == native_unit
@pytest.mark.parametrize(
"unit_system, native_unit, automatic_unit, suggested_unit, custom_unit, native_value, automatic_value, suggested_value, custom_value, device_class",
[
# Distance
(
US_CUSTOMARY_SYSTEM,
LENGTH_KILOMETERS,
LENGTH_MILES,
LENGTH_METERS,
LENGTH_YARD,
1000,
621,
1000000,
1093613,
SensorDeviceClass.DISTANCE,
),
],
)
async def test_unit_conversion_priority(
hass,
enable_custom_integrations,
unit_system,
native_unit,
automatic_unit,
suggested_unit,
custom_unit,
native_value,
automatic_value,
suggested_value,
custom_value,
device_class,
):
"""Test priority of unit conversion."""
hass.config.units = unit_system
entity_registry = er.async_get(hass)
platform = getattr(hass.components, "test.sensor")
platform.init(empty=True)
platform.ENTITIES["0"] = platform.MockSensor(
name="Test",
device_class=device_class,
native_unit_of_measurement=native_unit,
native_value=str(native_value),
unique_id="very_unique",
)
entity0 = platform.ENTITIES["0"]
platform.ENTITIES["1"] = platform.MockSensor(
name="Test",
device_class=device_class,
native_unit_of_measurement=native_unit,
native_value=str(native_value),
)
entity1 = platform.ENTITIES["1"]
platform.ENTITIES["2"] = platform.MockSensor(
name="Test",
device_class=device_class,
native_unit_of_measurement=native_unit,
native_value=str(native_value),
suggested_unit_of_measurement=suggested_unit,
unique_id="very_unique_2",
)
entity2 = platform.ENTITIES["2"]
platform.ENTITIES["3"] = platform.MockSensor(
name="Test",
device_class=device_class,
native_unit_of_measurement=native_unit,
native_value=str(native_value),
suggested_unit_of_measurement=suggested_unit,
)
entity3 = platform.ENTITIES["3"]
assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}})
await hass.async_block_till_done()
# Registered entity -> Follow automatic unit conversion
state = hass.states.get(entity0.entity_id)
assert float(state.state) == approx(float(automatic_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == automatic_unit
# Assert the automatic unit conversion is stored in the registry
entry = entity_registry.async_get(entity0.entity_id)
assert entry.options == {
"sensor.private": {"suggested_unit_of_measurement": automatic_unit}
}
# Unregistered entity -> Follow native unit
state = hass.states.get(entity1.entity_id)
assert float(state.state) == approx(float(native_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == native_unit
# Registered entity with suggested unit
state = hass.states.get(entity2.entity_id)
assert float(state.state) == approx(float(suggested_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == suggested_unit
# Assert the suggested unit is stored in the registry
entry = entity_registry.async_get(entity2.entity_id)
assert entry.options == {
"sensor.private": {"suggested_unit_of_measurement": suggested_unit}
}
# Unregistered entity with suggested unit
state = hass.states.get(entity3.entity_id)
assert float(state.state) == approx(float(suggested_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == suggested_unit
# Set a custom unit, this should have priority over the automatic unit conversion
entity_registry.async_update_entity_options(
entity0.entity_id, "sensor", {"unit_of_measurement": custom_unit}
)
await hass.async_block_till_done()
state = hass.states.get(entity0.entity_id)
assert float(state.state) == approx(float(custom_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit
entity_registry.async_update_entity_options(
entity2.entity_id, "sensor", {"unit_of_measurement": custom_unit}
)
await hass.async_block_till_done()
state = hass.states.get(entity2.entity_id)
assert float(state.state) == approx(float(custom_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit
@pytest.mark.parametrize(
"unit_system, native_unit, original_unit, suggested_unit, native_value, original_value, device_class",
[
# Distance
(
US_CUSTOMARY_SYSTEM,
LENGTH_KILOMETERS,
LENGTH_YARD,
LENGTH_METERS,
1000,
1093613,
SensorDeviceClass.DISTANCE,
),
],
)
async def test_unit_conversion_priority_suggested_unit_change(
hass,
enable_custom_integrations,
unit_system,
native_unit,
original_unit,
suggested_unit,
native_value,
original_value,
device_class,
):
"""Test priority of unit conversion."""
hass.config.units = unit_system
entity_registry = er.async_get(hass)
platform = getattr(hass.components, "test.sensor")
platform.init(empty=True)
# Pre-register entities
entry = entity_registry.async_get_or_create("sensor", "test", "very_unique")
entity_registry.async_update_entity_options(
entry.entity_id,
"sensor.private",
{"suggested_unit_of_measurement": original_unit},
)
entry = entity_registry.async_get_or_create("sensor", "test", "very_unique_2")
entity_registry.async_update_entity_options(
entry.entity_id,
"sensor.private",
{"suggested_unit_of_measurement": original_unit},
)
platform.ENTITIES["0"] = platform.MockSensor(
name="Test",
device_class=device_class,
native_unit_of_measurement=native_unit,
native_value=str(native_value),
unique_id="very_unique",
)
entity0 = platform.ENTITIES["0"]
platform.ENTITIES["1"] = platform.MockSensor(
name="Test",
device_class=device_class,
native_unit_of_measurement=native_unit,
native_value=str(native_value),
suggested_unit_of_measurement=suggested_unit,
unique_id="very_unique_2",
)
entity1 = platform.ENTITIES["1"]
assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}})
await hass.async_block_till_done()
# Registered entity -> Follow automatic unit conversion the first time the entity was seen
state = hass.states.get(entity0.entity_id)
assert float(state.state) == approx(float(original_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == original_unit
# Registered entity -> Follow suggested unit the first time the entity was seen
state = hass.states.get(entity1.entity_id)
assert float(state.state) == approx(float(original_value))
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == original_unit

View file

@ -43,13 +43,14 @@ def _set_up_units(hass):
"""Set up the tests."""
hass.config.units = UnitSystem(
"custom",
temperature=TEMP_CELSIUS,
accumulated_precipitation=LENGTH_MILLIMETERS,
length=LENGTH_METERS,
wind_speed=SPEED_KILOMETERS_PER_HOUR,
volume=VOLUME_LITERS,
length_conversions={},
mass=MASS_GRAMS,
pressure=PRESSURE_PA,
accumulated_precipitation=LENGTH_MILLIMETERS,
temperature=TEMP_CELSIUS,
volume=VOLUME_LITERS,
wind_speed=SPEED_KILOMETERS_PER_HOUR,
)

View file

@ -112,6 +112,11 @@ class MockSensor(MockEntity, SensorEntity):
"""Return the state class of this sensor."""
return self._handle("state_class")
@property
def suggested_unit_of_measurement(self):
"""Return the state class of this sensor."""
return self._handle("suggested_unit_of_measurement")
class MockRestoreSensor(MockSensor, RestoreSensor):
"""Mock RestoreSensor class."""

View file

@ -39,85 +39,92 @@ def test_invalid_units():
with pytest.raises(ValueError):
UnitSystem(
SYSTEM_NAME,
accumulated_precipitation=LENGTH_MILLIMETERS,
length=LENGTH_METERS,
length_conversions={},
mass=MASS_GRAMS,
pressure=PRESSURE_PA,
temperature=INVALID_UNIT,
length=LENGTH_METERS,
wind_speed=SPEED_METERS_PER_SECOND,
volume=VOLUME_LITERS,
mass=MASS_GRAMS,
pressure=PRESSURE_PA,
accumulated_precipitation=LENGTH_MILLIMETERS,
wind_speed=SPEED_METERS_PER_SECOND,
)
with pytest.raises(ValueError):
UnitSystem(
SYSTEM_NAME,
temperature=TEMP_CELSIUS,
accumulated_precipitation=LENGTH_MILLIMETERS,
length=INVALID_UNIT,
wind_speed=SPEED_METERS_PER_SECOND,
volume=VOLUME_LITERS,
length_conversions={},
mass=MASS_GRAMS,
pressure=PRESSURE_PA,
accumulated_precipitation=LENGTH_MILLIMETERS,
temperature=TEMP_CELSIUS,
volume=VOLUME_LITERS,
wind_speed=SPEED_METERS_PER_SECOND,
)
with pytest.raises(ValueError):
UnitSystem(
SYSTEM_NAME,
temperature=TEMP_CELSIUS,
accumulated_precipitation=LENGTH_MILLIMETERS,
length=LENGTH_METERS,
length_conversions={},
mass=MASS_GRAMS,
pressure=PRESSURE_PA,
temperature=TEMP_CELSIUS,
volume=VOLUME_LITERS,
wind_speed=INVALID_UNIT,
volume=VOLUME_LITERS,
mass=MASS_GRAMS,
pressure=PRESSURE_PA,
accumulated_precipitation=LENGTH_MILLIMETERS,
)
with pytest.raises(ValueError):
UnitSystem(
SYSTEM_NAME,
temperature=TEMP_CELSIUS,
accumulated_precipitation=LENGTH_MILLIMETERS,
length=LENGTH_METERS,
wind_speed=SPEED_METERS_PER_SECOND,
length_conversions={},
mass=MASS_GRAMS,
pressure=PRESSURE_PA,
temperature=TEMP_CELSIUS,
volume=INVALID_UNIT,
mass=MASS_GRAMS,
pressure=PRESSURE_PA,
accumulated_precipitation=LENGTH_MILLIMETERS,
wind_speed=SPEED_METERS_PER_SECOND,
)
with pytest.raises(ValueError):
UnitSystem(
SYSTEM_NAME,
temperature=TEMP_CELSIUS,
accumulated_precipitation=LENGTH_MILLIMETERS,
length=LENGTH_METERS,
wind_speed=SPEED_METERS_PER_SECOND,
volume=VOLUME_LITERS,
length_conversions={},
mass=INVALID_UNIT,
pressure=PRESSURE_PA,
accumulated_precipitation=LENGTH_MILLIMETERS,
temperature=TEMP_CELSIUS,
volume=VOLUME_LITERS,
wind_speed=SPEED_METERS_PER_SECOND,
)
with pytest.raises(ValueError):
UnitSystem(
SYSTEM_NAME,
temperature=TEMP_CELSIUS,
accumulated_precipitation=LENGTH_MILLIMETERS,
length=LENGTH_METERS,
wind_speed=SPEED_METERS_PER_SECOND,
volume=VOLUME_LITERS,
length_conversions={},
mass=MASS_GRAMS,
pressure=INVALID_UNIT,
accumulated_precipitation=LENGTH_MILLIMETERS,
temperature=TEMP_CELSIUS,
volume=VOLUME_LITERS,
wind_speed=SPEED_METERS_PER_SECOND,
)
with pytest.raises(ValueError):
UnitSystem(
SYSTEM_NAME,
temperature=TEMP_CELSIUS,
accumulated_precipitation=INVALID_UNIT,
length=LENGTH_METERS,
wind_speed=SPEED_METERS_PER_SECOND,
volume=VOLUME_LITERS,
length_conversions={},
mass=MASS_GRAMS,
pressure=PRESSURE_PA,
accumulated_precipitation=INVALID_UNIT,
temperature=TEMP_CELSIUS,
volume=VOLUME_LITERS,
wind_speed=SPEED_METERS_PER_SECOND,
)