diff --git a/homeassistant/components/config/core.py b/homeassistant/components/config/core.py index 43bce39082d..999e9433cbb 100644 --- a/homeassistant/components/config/core.py +++ b/homeassistant/components/config/core.py @@ -6,6 +6,7 @@ import voluptuous as vol from homeassistant.components import websocket_api from homeassistant.components.http import HomeAssistantView +from homeassistant.components.sensor import async_update_suggested_units from homeassistant.config import async_check_ha_config_file from homeassistant.core import HomeAssistant from homeassistant.helpers import config_validation as cv @@ -40,17 +41,18 @@ class CheckConfigView(HomeAssistantView): @websocket_api.websocket_command( { "type": "config/core/update", - vol.Optional("latitude"): cv.latitude, - vol.Optional("longitude"): cv.longitude, + vol.Optional("country"): cv.country, + vol.Optional("currency"): cv.currency, vol.Optional("elevation"): int, - vol.Optional("unit_system"): unit_system.validate_unit_system, - vol.Optional("location_name"): str, - vol.Optional("time_zone"): cv.time_zone, vol.Optional("external_url"): vol.Any(cv.url_no_path, None), vol.Optional("internal_url"): vol.Any(cv.url_no_path, None), - vol.Optional("currency"): cv.currency, - vol.Optional("country"): cv.country, vol.Optional("language"): cv.language, + vol.Optional("latitude"): cv.latitude, + vol.Optional("location_name"): str, + vol.Optional("longitude"): cv.longitude, + vol.Optional("time_zone"): cv.time_zone, + vol.Optional("update_units"): bool, + vol.Optional("unit_system"): unit_system.validate_unit_system, } ) @websocket_api.async_response @@ -64,8 +66,12 @@ async def websocket_update_config( data.pop("id") data.pop("type") + update_units = data.pop("update_units", False) + try: await hass.config.async_update(**data) + if update_units: + async_update_suggested_units(hass) connection.send_result(msg["id"]) except ValueError as err: connection.send_error(msg["id"], "invalid_info", str(err)) diff --git a/homeassistant/components/sensor/__init__.py b/homeassistant/components/sensor/__init__.py index 42d36f59733..0222e6f00ed 100644 --- a/homeassistant/components/sensor/__init__.py +++ b/homeassistant/components/sensor/__init__.py @@ -730,6 +730,17 @@ class SensorEntity(Entity): def async_registry_entry_updated(self) -> None: """Run when the entity registry entry has been updated.""" self._sensor_option_precision = self._custom_precision_or_none() + assert self.registry_entry + if ( + sensor_options := self.registry_entry.options.get(f"{DOMAIN}.private") + ) and "refresh_initial_entity_options" in sensor_options: + registry = er.async_get(self.hass) + initial_options = self.get_initial_entity_options() or {} + registry.async_update_entity_options( + self.entity_id, + f"{DOMAIN}.private", + initial_options.get(f"{DOMAIN}.private"), + ) self._sensor_option_unit_of_measurement = self._custom_unit_or_undef( DOMAIN, CONF_UNIT_OF_MEASUREMENT ) @@ -808,3 +819,21 @@ class RestoreSensor(SensorEntity, RestoreEntity): if (restored_last_extra_data := await self.async_get_last_extra_data()) is None: return None return SensorExtraStoredData.from_dict(restored_last_extra_data.as_dict()) + + +@callback +def async_update_suggested_units(hass: HomeAssistant) -> None: + """Update the suggested_unit_of_measurement according to the unit system.""" + registry = er.async_get(hass) + + for entry in registry.entities.values(): + if entry.domain != DOMAIN: + continue + + sensor_private_options = dict(entry.options.get(f"{DOMAIN}.private", {})) + sensor_private_options["refresh_initial_entity_options"] = True + registry.async_update_entity_options( + entry.entity_id, + f"{DOMAIN}.private", + sensor_private_options, + ) diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 08130500f11..30bd52cfffc 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -859,11 +859,18 @@ class EntityRegistry: @callback def async_update_entity_options( - self, entity_id: str, domain: str, options: dict[str, Any] + self, entity_id: str, domain: str, options: Mapping[str, Any] | None ) -> RegistryEntry: - """Update entity options.""" + """Update entity options for a domain. + + If the domain options are set to None, they will be removed. + """ old = self.entities[entity_id] - new_options: EntityOptionsType = {**old.options, domain: options} + new_options = { + key: value for key, value in old.options.items() if key != domain + } + if options is not None: + new_options[domain] = options return self._async_update_entity(entity_id, options=new_options) async def async_load(self) -> None: diff --git a/tests/components/config/test_core.py b/tests/components/config/test_core.py index f3283d32972..cf8e76f6653 100644 --- a/tests/components/config/test_core.py +++ b/tests/components/config/test_core.py @@ -7,7 +7,11 @@ import pytest from homeassistant.bootstrap import async_setup_component from homeassistant.components import config from homeassistant.components.websocket_api.const import TYPE_RESULT -from homeassistant.const import CONF_UNIT_SYSTEM, CONF_UNIT_SYSTEM_IMPERIAL +from homeassistant.const import ( + CONF_UNIT_SYSTEM, + CONF_UNIT_SYSTEM_IMPERIAL, + CONF_UNIT_SYSTEM_METRIC, +) from homeassistant.util import dt as dt_util, location from homeassistant.util.unit_system import US_CUSTOMARY_SYSTEM @@ -64,7 +68,9 @@ async def test_websocket_core_update(hass, client): assert hass.config.country != "SE" assert hass.config.language != "sv" - with patch("homeassistant.util.dt.set_default_time_zone") as mock_set_tz: + with patch("homeassistant.util.dt.set_default_time_zone") as mock_set_tz, patch( + "homeassistant.components.config.core.async_update_suggested_units" + ) as mock_update_sensor_units: await client.send_json( { "id": 5, @@ -85,6 +91,8 @@ async def test_websocket_core_update(hass, client): msg = await client.receive_json() + mock_update_sensor_units.assert_not_called() + assert msg["id"] == 5 assert msg["type"] == TYPE_RESULT assert msg["success"] @@ -100,6 +108,22 @@ async def test_websocket_core_update(hass, client): assert len(mock_set_tz.mock_calls) == 1 assert mock_set_tz.mock_calls[0][1][0] == dt_util.get_time_zone("America/New_York") + with patch("homeassistant.util.dt.set_default_time_zone") as mock_set_tz, patch( + "homeassistant.components.config.core.async_update_suggested_units" + ) as mock_update_sensor_units: + await client.send_json( + { + "id": 6, + "type": "config/core/update", + CONF_UNIT_SYSTEM: CONF_UNIT_SYSTEM_METRIC, + "update_units": True, + } + ) + + msg = await client.receive_json() + + mock_update_sensor_units.assert_called_once() + async def test_websocket_core_update_not_admin(hass, hass_ws_client, hass_admin_user): """Test core config fails for non admin.""" diff --git a/tests/components/sensor/test_init.py b/tests/components/sensor/test_init.py index b4da7f19b5a..1901381ed1c 100644 --- a/tests/components/sensor/test_init.py +++ b/tests/components/sensor/test_init.py @@ -12,6 +12,7 @@ from homeassistant.components.sensor import ( DEVICE_CLASS_UNITS, SensorDeviceClass, SensorStateClass, + async_update_suggested_units, ) from homeassistant.const import ( ATTR_UNIT_OF_MEASUREMENT, @@ -1685,3 +1686,191 @@ async def test_numeric_state_expected_helper( assert state is not None assert entity0._numeric_state_expected == is_numeric + + +@pytest.mark.parametrize( + "unit_system_1, unit_system_2, native_unit, automatic_unit_1, automatic_unit_2, suggested_unit, custom_unit, native_value, automatic_state_1, automatic_state_2, suggested_state, custom_state, device_class", + [ + # Distance + ( + US_CUSTOMARY_SYSTEM, + METRIC_SYSTEM, + UnitOfLength.KILOMETERS, + UnitOfLength.MILES, + UnitOfLength.KILOMETERS, + UnitOfLength.METERS, + UnitOfLength.YARDS, + 1000, + "621", + "1000", + "1000000", + "1093613", + SensorDeviceClass.DISTANCE, + ), + ], +) +async def test_unit_conversion_update( + hass, + enable_custom_integrations, + unit_system_1, + unit_system_2, + native_unit, + automatic_unit_1, + automatic_unit_2, + suggested_unit, + custom_unit, + native_value, + automatic_state_1, + automatic_state_2, + suggested_state, + custom_state, + device_class, +): + """Test suggested unit can be updated.""" + + hass.config.units = unit_system_1 + + entity_registry = er.async_get(hass) + platform = getattr(hass.components, "test.sensor") + platform.init(empty=True) + + platform.ENTITIES["0"] = platform.MockSensor( + name="Test 0", + device_class=device_class, + native_unit_of_measurement=native_unit, + native_value=str(native_value), + unique_id="very_unique", + ) + entity0 = platform.ENTITIES["0"] + + platform.ENTITIES["1"] = platform.MockSensor( + name="Test 1", + device_class=device_class, + native_unit_of_measurement=native_unit, + native_value=str(native_value), + unique_id="very_unique_1", + ) + entity1 = platform.ENTITIES["1"] + + platform.ENTITIES["2"] = platform.MockSensor( + name="Test 2", + device_class=device_class, + native_unit_of_measurement=native_unit, + native_value=str(native_value), + suggested_unit_of_measurement=suggested_unit, + unique_id="very_unique_2", + ) + entity2 = platform.ENTITIES["2"] + + platform.ENTITIES["3"] = platform.MockSensor( + name="Test 3", + device_class=device_class, + native_unit_of_measurement=native_unit, + native_value=str(native_value), + suggested_unit_of_measurement=suggested_unit, + unique_id="very_unique_3", + ) + entity3 = platform.ENTITIES["3"] + + assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}}) + await hass.async_block_till_done() + + # Registered entity -> Follow automatic unit conversion + state = hass.states.get(entity0.entity_id) + assert state.state == automatic_state_1 + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == automatic_unit_1 + # Assert the automatic unit conversion is stored in the registry + entry = entity_registry.async_get(entity0.entity_id) + assert entry.options == { + "sensor.private": {"suggested_unit_of_measurement": automatic_unit_1} + } + + state = hass.states.get(entity1.entity_id) + assert state.state == automatic_state_1 + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == automatic_unit_1 + # Assert the automatic unit conversion is stored in the registry + entry = entity_registry.async_get(entity1.entity_id) + assert entry.options == { + "sensor.private": {"suggested_unit_of_measurement": automatic_unit_1} + } + + # Registered entity with suggested unit + state = hass.states.get(entity2.entity_id) + assert state.state == suggested_state + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == suggested_unit + # Assert the suggested unit is stored in the registry + entry = entity_registry.async_get(entity2.entity_id) + assert entry.options == { + "sensor.private": {"suggested_unit_of_measurement": suggested_unit} + } + + state = hass.states.get(entity3.entity_id) + assert state.state == suggested_state + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == suggested_unit + # Assert the suggested unit is stored in the registry + entry = entity_registry.async_get(entity3.entity_id) + assert entry.options == { + "sensor.private": {"suggested_unit_of_measurement": suggested_unit} + } + + # Set a custom unit, this should have priority over the automatic unit conversion + entity_registry.async_update_entity_options( + entity0.entity_id, "sensor", {"unit_of_measurement": custom_unit} + ) + await hass.async_block_till_done() + + state = hass.states.get(entity0.entity_id) + assert state.state == custom_state + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit + + entity_registry.async_update_entity_options( + entity2.entity_id, "sensor", {"unit_of_measurement": custom_unit} + ) + await hass.async_block_till_done() + + state = hass.states.get(entity2.entity_id) + assert state.state == custom_state + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit + + # Change unit system, states and units should be unchanged + hass.config.units = unit_system_2 + await hass.async_block_till_done() + + state = hass.states.get(entity0.entity_id) + assert state.state == custom_state + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit + + state = hass.states.get(entity1.entity_id) + assert state.state == automatic_state_1 + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == automatic_unit_1 + + state = hass.states.get(entity2.entity_id) + assert state.state == custom_state + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit + + state = hass.states.get(entity3.entity_id) + assert state.state == suggested_state + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == suggested_unit + + # Update suggested unit + async_update_suggested_units(hass) + await hass.async_block_till_done() + await hass.async_block_till_done() + await hass.async_block_till_done() + await hass.async_block_till_done() + + state = hass.states.get(entity0.entity_id) + assert state.state == custom_state + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit + + state = hass.states.get(entity1.entity_id) + assert state.state == automatic_state_2 + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == automatic_unit_2 + + state = hass.states.get(entity2.entity_id) + assert state.state == custom_state + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == custom_unit + + state = hass.states.get(entity3.entity_id) + assert state.state == suggested_state + assert state.attributes[ATTR_UNIT_OF_MEASUREMENT] == suggested_unit