Optionally update sensor units when unit system is changed (#83851)
This commit is contained in:
parent
4b27af6a8f
commit
4d4fb2477d
5 changed files with 267 additions and 12 deletions
|
@ -6,6 +6,7 @@ import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import websocket_api
|
from homeassistant.components import websocket_api
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
|
from homeassistant.components.sensor import async_update_suggested_units
|
||||||
from homeassistant.config import async_check_ha_config_file
|
from homeassistant.config import async_check_ha_config_file
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
|
@ -40,17 +41,18 @@ class CheckConfigView(HomeAssistantView):
|
||||||
@websocket_api.websocket_command(
|
@websocket_api.websocket_command(
|
||||||
{
|
{
|
||||||
"type": "config/core/update",
|
"type": "config/core/update",
|
||||||
vol.Optional("latitude"): cv.latitude,
|
vol.Optional("country"): cv.country,
|
||||||
vol.Optional("longitude"): cv.longitude,
|
vol.Optional("currency"): cv.currency,
|
||||||
vol.Optional("elevation"): int,
|
vol.Optional("elevation"): int,
|
||||||
vol.Optional("unit_system"): unit_system.validate_unit_system,
|
|
||||||
vol.Optional("location_name"): str,
|
|
||||||
vol.Optional("time_zone"): cv.time_zone,
|
|
||||||
vol.Optional("external_url"): vol.Any(cv.url_no_path, None),
|
vol.Optional("external_url"): vol.Any(cv.url_no_path, None),
|
||||||
vol.Optional("internal_url"): vol.Any(cv.url_no_path, None),
|
vol.Optional("internal_url"): vol.Any(cv.url_no_path, None),
|
||||||
vol.Optional("currency"): cv.currency,
|
|
||||||
vol.Optional("country"): cv.country,
|
|
||||||
vol.Optional("language"): cv.language,
|
vol.Optional("language"): cv.language,
|
||||||
|
vol.Optional("latitude"): cv.latitude,
|
||||||
|
vol.Optional("location_name"): str,
|
||||||
|
vol.Optional("longitude"): cv.longitude,
|
||||||
|
vol.Optional("time_zone"): cv.time_zone,
|
||||||
|
vol.Optional("update_units"): bool,
|
||||||
|
vol.Optional("unit_system"): unit_system.validate_unit_system,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@websocket_api.async_response
|
@websocket_api.async_response
|
||||||
|
@ -64,8 +66,12 @@ async def websocket_update_config(
|
||||||
data.pop("id")
|
data.pop("id")
|
||||||
data.pop("type")
|
data.pop("type")
|
||||||
|
|
||||||
|
update_units = data.pop("update_units", False)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await hass.config.async_update(**data)
|
await hass.config.async_update(**data)
|
||||||
|
if update_units:
|
||||||
|
async_update_suggested_units(hass)
|
||||||
connection.send_result(msg["id"])
|
connection.send_result(msg["id"])
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
connection.send_error(msg["id"], "invalid_info", str(err))
|
connection.send_error(msg["id"], "invalid_info", str(err))
|
||||||
|
|
|
@ -730,6 +730,17 @@ class SensorEntity(Entity):
|
||||||
def async_registry_entry_updated(self) -> None:
|
def async_registry_entry_updated(self) -> None:
|
||||||
"""Run when the entity registry entry has been updated."""
|
"""Run when the entity registry entry has been updated."""
|
||||||
self._sensor_option_precision = self._custom_precision_or_none()
|
self._sensor_option_precision = self._custom_precision_or_none()
|
||||||
|
assert self.registry_entry
|
||||||
|
if (
|
||||||
|
sensor_options := self.registry_entry.options.get(f"{DOMAIN}.private")
|
||||||
|
) and "refresh_initial_entity_options" in sensor_options:
|
||||||
|
registry = er.async_get(self.hass)
|
||||||
|
initial_options = self.get_initial_entity_options() or {}
|
||||||
|
registry.async_update_entity_options(
|
||||||
|
self.entity_id,
|
||||||
|
f"{DOMAIN}.private",
|
||||||
|
initial_options.get(f"{DOMAIN}.private"),
|
||||||
|
)
|
||||||
self._sensor_option_unit_of_measurement = self._custom_unit_or_undef(
|
self._sensor_option_unit_of_measurement = self._custom_unit_or_undef(
|
||||||
DOMAIN, CONF_UNIT_OF_MEASUREMENT
|
DOMAIN, CONF_UNIT_OF_MEASUREMENT
|
||||||
)
|
)
|
||||||
|
@ -808,3 +819,21 @@ class RestoreSensor(SensorEntity, RestoreEntity):
|
||||||
if (restored_last_extra_data := await self.async_get_last_extra_data()) is None:
|
if (restored_last_extra_data := await self.async_get_last_extra_data()) is None:
|
||||||
return None
|
return None
|
||||||
return SensorExtraStoredData.from_dict(restored_last_extra_data.as_dict())
|
return SensorExtraStoredData.from_dict(restored_last_extra_data.as_dict())
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_update_suggested_units(hass: HomeAssistant) -> None:
|
||||||
|
"""Update the suggested_unit_of_measurement according to the unit system."""
|
||||||
|
registry = er.async_get(hass)
|
||||||
|
|
||||||
|
for entry in registry.entities.values():
|
||||||
|
if entry.domain != DOMAIN:
|
||||||
|
continue
|
||||||
|
|
||||||
|
sensor_private_options = dict(entry.options.get(f"{DOMAIN}.private", {}))
|
||||||
|
sensor_private_options["refresh_initial_entity_options"] = True
|
||||||
|
registry.async_update_entity_options(
|
||||||
|
entry.entity_id,
|
||||||
|
f"{DOMAIN}.private",
|
||||||
|
sensor_private_options,
|
||||||
|
)
|
||||||
|
|
|
@ -859,11 +859,18 @@ class EntityRegistry:
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_update_entity_options(
|
def async_update_entity_options(
|
||||||
self, entity_id: str, domain: str, options: dict[str, Any]
|
self, entity_id: str, domain: str, options: Mapping[str, Any] | None
|
||||||
) -> RegistryEntry:
|
) -> RegistryEntry:
|
||||||
"""Update entity options."""
|
"""Update entity options for a domain.
|
||||||
|
|
||||||
|
If the domain options are set to None, they will be removed.
|
||||||
|
"""
|
||||||
old = self.entities[entity_id]
|
old = self.entities[entity_id]
|
||||||
new_options: EntityOptionsType = {**old.options, domain: options}
|
new_options = {
|
||||||
|
key: value for key, value in old.options.items() if key != domain
|
||||||
|
}
|
||||||
|
if options is not None:
|
||||||
|
new_options[domain] = options
|
||||||
return self._async_update_entity(entity_id, options=new_options)
|
return self._async_update_entity(entity_id, options=new_options)
|
||||||
|
|
||||||
async def async_load(self) -> None:
|
async def async_load(self) -> None:
|
||||||
|
|
|
@ -7,7 +7,11 @@ import pytest
|
||||||
from homeassistant.bootstrap import async_setup_component
|
from homeassistant.bootstrap import async_setup_component
|
||||||
from homeassistant.components import config
|
from homeassistant.components import config
|
||||||
from homeassistant.components.websocket_api.const import TYPE_RESULT
|
from homeassistant.components.websocket_api.const import TYPE_RESULT
|
||||||
from homeassistant.const import CONF_UNIT_SYSTEM, CONF_UNIT_SYSTEM_IMPERIAL
|
from homeassistant.const import (
|
||||||
|
CONF_UNIT_SYSTEM,
|
||||||
|
CONF_UNIT_SYSTEM_IMPERIAL,
|
||||||
|
CONF_UNIT_SYSTEM_METRIC,
|
||||||
|
)
|
||||||
from homeassistant.util import dt as dt_util, location
|
from homeassistant.util import dt as dt_util, location
|
||||||
from homeassistant.util.unit_system import US_CUSTOMARY_SYSTEM
|
from homeassistant.util.unit_system import US_CUSTOMARY_SYSTEM
|
||||||
|
|
||||||
|
@ -64,7 +68,9 @@ async def test_websocket_core_update(hass, client):
|
||||||
assert hass.config.country != "SE"
|
assert hass.config.country != "SE"
|
||||||
assert hass.config.language != "sv"
|
assert hass.config.language != "sv"
|
||||||
|
|
||||||
with patch("homeassistant.util.dt.set_default_time_zone") as mock_set_tz:
|
with patch("homeassistant.util.dt.set_default_time_zone") as mock_set_tz, patch(
|
||||||
|
"homeassistant.components.config.core.async_update_suggested_units"
|
||||||
|
) as mock_update_sensor_units:
|
||||||
await client.send_json(
|
await client.send_json(
|
||||||
{
|
{
|
||||||
"id": 5,
|
"id": 5,
|
||||||
|
@ -85,6 +91,8 @@ async def test_websocket_core_update(hass, client):
|
||||||
|
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
|
|
||||||
|
mock_update_sensor_units.assert_not_called()
|
||||||
|
|
||||||
assert msg["id"] == 5
|
assert msg["id"] == 5
|
||||||
assert msg["type"] == TYPE_RESULT
|
assert msg["type"] == TYPE_RESULT
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
|
@ -100,6 +108,22 @@ async def test_websocket_core_update(hass, client):
|
||||||
assert len(mock_set_tz.mock_calls) == 1
|
assert len(mock_set_tz.mock_calls) == 1
|
||||||
assert mock_set_tz.mock_calls[0][1][0] == dt_util.get_time_zone("America/New_York")
|
assert mock_set_tz.mock_calls[0][1][0] == dt_util.get_time_zone("America/New_York")
|
||||||
|
|
||||||
|
with patch("homeassistant.util.dt.set_default_time_zone") as mock_set_tz, patch(
|
||||||
|
"homeassistant.components.config.core.async_update_suggested_units"
|
||||||
|
) as mock_update_sensor_units:
|
||||||
|
await client.send_json(
|
||||||
|
{
|
||||||
|
"id": 6,
|
||||||
|
"type": "config/core/update",
|
||||||
|
CONF_UNIT_SYSTEM: CONF_UNIT_SYSTEM_METRIC,
|
||||||
|
"update_units": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
|
||||||
|
mock_update_sensor_units.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
async def test_websocket_core_update_not_admin(hass, hass_ws_client, hass_admin_user):
|
async def test_websocket_core_update_not_admin(hass, hass_ws_client, hass_admin_user):
|
||||||
"""Test core config fails for non admin."""
|
"""Test core config fails for non admin."""
|
||||||
|
|
|
@ -12,6 +12,7 @@ from homeassistant.components.sensor import (
|
||||||
DEVICE_CLASS_UNITS,
|
DEVICE_CLASS_UNITS,
|
||||||
SensorDeviceClass,
|
SensorDeviceClass,
|
||||||
SensorStateClass,
|
SensorStateClass,
|
||||||
|
async_update_suggested_units,
|
||||||
)
|
)
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_UNIT_OF_MEASUREMENT,
|
ATTR_UNIT_OF_MEASUREMENT,
|
||||||
|
@ -1685,3 +1686,191 @@ async def test_numeric_state_expected_helper(
|
||||||
assert state is not None
|
assert state is not None
|
||||||
|
|
||||||
assert entity0._numeric_state_expected == is_numeric
|
assert entity0._numeric_state_expected == is_numeric
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"unit_system_1, unit_system_2, native_unit, automatic_unit_1, automatic_unit_2, suggested_unit, custom_unit, native_value, automatic_state_1, automatic_state_2, suggested_state, custom_state, device_class",
|
||||||
|
[
|
||||||
|
# Distance
|
||||||
|
(
|
||||||
|
US_CUSTOMARY_SYSTEM,
|
||||||
|
METRIC_SYSTEM,
|
||||||
|
UnitOfLength.KILOMETERS,
|
||||||
|
UnitOfLength.MILES,
|
||||||
|
UnitOfLength.KILOMETERS,
|
||||||
|
UnitOfLength.METERS,
|
||||||
|
UnitOfLength.YARDS,
|
||||||
|
1000,
|
||||||
|
"621",
|
||||||
|
"1000",
|
||||||
|
"1000000",
|
||||||
|
"1093613",
|
||||||
|
SensorDeviceClass.DISTANCE,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_unit_conversion_update(
|
||||||
|
hass,
|
||||||
|
enable_custom_integrations,
|
||||||
|
unit_system_1,
|
||||||
|
unit_system_2,
|
||||||
|
native_unit,
|
||||||
|
automatic_unit_1,
|
||||||
|
automatic_unit_2,
|
||||||
|
suggested_unit,
|
||||||
|
custom_unit,
|
||||||
|
native_value,
|
||||||
|
automatic_state_1,
|
||||||
|
automatic_state_2,
|
||||||
|
suggested_state,
|
||||||
|
custom_state,
|
||||||
|
device_class,
|
||||||
|
):
|
||||||
|
"""Test suggested unit can be updated."""
|
||||||
|
|
||||||
|
hass.config.units = unit_system_1
|
||||||
|
|
||||||
|
entity_registry = er.async_get(hass)
|
||||||
|
platform = getattr(hass.components, "test.sensor")
|
||||||
|
platform.init(empty=True)
|
||||||
|
|
||||||
|
platform.ENTITIES["0"] = platform.MockSensor(
|
||||||
|
name="Test 0",
|
||||||
|
device_class=device_class,
|
||||||
|
native_unit_of_measurement=native_unit,
|
||||||
|
native_value=str(native_value),
|
||||||
|
unique_id="very_unique",
|
||||||
|
)
|
||||||
|
entity0 = platform.ENTITIES["0"]
|
||||||
|
|
||||||
|
platform.ENTITIES["1"] = platform.MockSensor(
|
||||||
|
name="Test 1",
|
||||||
|
device_class=device_class,
|
||||||
|
native_unit_of_measurement=native_unit,
|
||||||
|
native_value=str(native_value),
|
||||||
|
unique_id="very_unique_1",
|
||||||
|
)
|
||||||
|
entity1 = platform.ENTITIES["1"]
|
||||||
|
|
||||||
|
platform.ENTITIES["2"] = platform.MockSensor(
|
||||||
|
name="Test 2",
|
||||||
|
device_class=device_class,
|
||||||
|
native_unit_of_measurement=native_unit,
|
||||||
|
native_value=str(native_value),
|
||||||
|
suggested_unit_of_measurement=suggested_unit,
|
||||||
|
unique_id="very_unique_2",
|
||||||
|
)
|
||||||
|
entity2 = platform.ENTITIES["2"]
|
||||||
|
|
||||||
|
platform.ENTITIES["3"] = platform.MockSensor(
|
||||||
|
name="Test 3",
|
||||||
|
device_class=device_class,
|
||||||
|
native_unit_of_measurement=native_unit,
|
||||||
|
native_value=str(native_value),
|
||||||
|
suggested_unit_of_measurement=suggested_unit,
|
||||||
|
unique_id="very_unique_3",
|
||||||
|
)
|
||||||
|
entity3 = platform.ENTITIES["3"]
|
||||||
|
|
||||||
|
assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Registered entity -> Follow automatic unit conversion
|
||||||
|
state = hass.states.get(entity0.entity_id)
|
||||||
|
assert state.state == automatic_state_1
|
||||||
|
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == automatic_unit_1
|
||||||
|
# Assert the automatic unit conversion is stored in the registry
|
||||||
|
entry = entity_registry.async_get(entity0.entity_id)
|
||||||
|
assert entry.options == {
|
||||||
|
"sensor.private": {"suggested_unit_of_measurement": automatic_unit_1}
|
||||||
|
}
|
||||||
|
|
||||||
|
state = hass.states.get(entity1.entity_id)
|
||||||
|
assert state.state == automatic_state_1
|
||||||
|
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == automatic_unit_1
|
||||||
|
# Assert the automatic unit conversion is stored in the registry
|
||||||
|
entry = entity_registry.async_get(entity1.entity_id)
|
||||||
|
assert entry.options == {
|
||||||
|
"sensor.private": {"suggested_unit_of_measurement": automatic_unit_1}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Registered entity with suggested unit
|
||||||
|
state = hass.states.get(entity2.entity_id)
|
||||||
|
assert state.state == suggested_state
|
||||||
|
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == suggested_unit
|
||||||
|
# Assert the suggested unit is stored in the registry
|
||||||
|
entry = entity_registry.async_get(entity2.entity_id)
|
||||||
|
assert entry.options == {
|
||||||
|
"sensor.private": {"suggested_unit_of_measurement": suggested_unit}
|
||||||
|
}
|
||||||
|
|
||||||
|
state = hass.states.get(entity3.entity_id)
|
||||||
|
assert state.state == suggested_state
|
||||||
|
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == suggested_unit
|
||||||
|
# Assert the suggested unit is stored in the registry
|
||||||
|
entry = entity_registry.async_get(entity3.entity_id)
|
||||||
|
assert entry.options == {
|
||||||
|
"sensor.private": {"suggested_unit_of_measurement": suggested_unit}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Set a custom unit, this should have priority over the automatic unit conversion
|
||||||
|
entity_registry.async_update_entity_options(
|
||||||
|
entity0.entity_id, "sensor", {"unit_of_measurement": custom_unit}
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
state = hass.states.get(entity0.entity_id)
|
||||||
|
assert state.state == custom_state
|
||||||
|
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit
|
||||||
|
|
||||||
|
entity_registry.async_update_entity_options(
|
||||||
|
entity2.entity_id, "sensor", {"unit_of_measurement": custom_unit}
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
state = hass.states.get(entity2.entity_id)
|
||||||
|
assert state.state == custom_state
|
||||||
|
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit
|
||||||
|
|
||||||
|
# Change unit system, states and units should be unchanged
|
||||||
|
hass.config.units = unit_system_2
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
state = hass.states.get(entity0.entity_id)
|
||||||
|
assert state.state == custom_state
|
||||||
|
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit
|
||||||
|
|
||||||
|
state = hass.states.get(entity1.entity_id)
|
||||||
|
assert state.state == automatic_state_1
|
||||||
|
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == automatic_unit_1
|
||||||
|
|
||||||
|
state = hass.states.get(entity2.entity_id)
|
||||||
|
assert state.state == custom_state
|
||||||
|
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit
|
||||||
|
|
||||||
|
state = hass.states.get(entity3.entity_id)
|
||||||
|
assert state.state == suggested_state
|
||||||
|
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == suggested_unit
|
||||||
|
|
||||||
|
# Update suggested unit
|
||||||
|
async_update_suggested_units(hass)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
state = hass.states.get(entity0.entity_id)
|
||||||
|
assert state.state == custom_state
|
||||||
|
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit
|
||||||
|
|
||||||
|
state = hass.states.get(entity1.entity_id)
|
||||||
|
assert state.state == automatic_state_2
|
||||||
|
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == automatic_unit_2
|
||||||
|
|
||||||
|
state = hass.states.get(entity2.entity_id)
|
||||||
|
assert state.state == custom_state
|
||||||
|
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit
|
||||||
|
|
||||||
|
state = hass.states.get(entity3.entity_id)
|
||||||
|
assert state.state == suggested_state
|
||||||
|
assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == suggested_unit
|
||||||
|
|
Loading…
Add table
Reference in a new issue