diff --git a/homeassistant/components/utility_meter/__init__.py b/homeassistant/components/utility_meter/__init__.py index c579a684406..c6a8635f831 100644 --- a/homeassistant/components/utility_meter/__init__.py +++ b/homeassistant/components/utility_meter/__init__.py @@ -11,12 +11,11 @@ from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN from homeassistant.config_entries import ConfigEntry from homeassistant.const import ATTR_ENTITY_ID, CONF_NAME, CONF_UNIQUE_ID, Platform from homeassistant.core import HomeAssistant, split_entity_id -from homeassistant.helpers import ( - device_registry as dr, - discovery, - entity_registry as er, -) +from homeassistant.helpers import discovery, entity_registry as er import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.device import ( + async_remove_stale_devices_links_keep_entity_device, +) from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.typing import ConfigType @@ -192,7 +191,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up Utility Meter from a config entry.""" - await async_remove_stale_device_links( + async_remove_stale_devices_links_keep_entity_device( hass, entry.entry_id, entry.options[CONF_SOURCE_SENSOR] ) @@ -266,27 +265,3 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> _LOGGER.info("Migration to version %s successful", config_entry.version) return True - - -async def async_remove_stale_device_links( - hass: HomeAssistant, entry_id: str, entity_id: str -) -> None: - """Remove device link for entry, the source device may have changed.""" - - device_registry = dr.async_get(hass) - entity_registry = er.async_get(hass) - - # Resolve source entity device - current_device_id = None - if ((source_entity := entity_registry.async_get(entity_id)) is not None) and ( - source_entity.device_id is not None - ): - current_device_id = source_entity.device_id - - devices_in_entry = device_registry.devices.get_devices_for_config_entry_id(entry_id) - - # Removes all devices from the config entry that are not the same as the current device - for device in devices_in_entry: - if device.id == current_device_id: - continue - device_registry.async_update_device(device.id, remove_config_entry_id=entry_id) diff --git a/homeassistant/components/utility_meter/select.py b/homeassistant/components/utility_meter/select.py index 461fee3ba9f..d5b1206d046 100644 --- a/homeassistant/components/utility_meter/select.py +++ b/homeassistant/components/utility_meter/select.py @@ -8,7 +8,7 @@ from homeassistant.components.select import SelectEntity from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_UNIQUE_ID from homeassistant.core import HomeAssistant -from homeassistant.helpers import device_registry as dr, entity_registry as er +from homeassistant.helpers.device import async_device_info_to_link_from_entity from homeassistant.helpers.device_registry import DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.restore_state import RestoreEntity @@ -30,28 +30,10 @@ async def async_setup_entry( unique_id = config_entry.entry_id - registry = er.async_get(hass) - source_entity = registry.async_get(config_entry.options[CONF_SOURCE_SENSOR]) - dev_reg = dr.async_get(hass) - # Resolve source entity device - if ( - (source_entity is not None) - and (source_entity.device_id is not None) - and ( - ( - device := dev_reg.async_get( - device_id=source_entity.device_id, - ) - ) - is not None - ) - ): - device_info = DeviceInfo( - identifiers=device.identifiers, - connections=device.connections, - ) - else: - device_info = None + device_info = async_device_info_to_link_from_entity( + hass, + config_entry.options[CONF_SOURCE_SENSOR], + ) tariff_select = TariffSelect( name, diff --git a/homeassistant/components/utility_meter/sensor.py b/homeassistant/components/utility_meter/sensor.py index 4a68248f067..6b8c07c7ef7 100644 --- a/homeassistant/components/utility_meter/sensor.py +++ b/homeassistant/components/utility_meter/sensor.py @@ -37,12 +37,8 @@ from homeassistant.core import ( State, callback, ) -from homeassistant.helpers import ( - device_registry as dr, - entity_platform, - entity_registry as er, -) -from homeassistant.helpers.device_registry import DeviceInfo +from homeassistant.helpers import entity_platform, entity_registry as er +from homeassistant.helpers.device import async_device_info_to_link_from_entity from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.event import ( @@ -130,27 +126,10 @@ async def async_setup_entry( registry, config_entry.options[CONF_SOURCE_SENSOR] ) - source_entity = registry.async_get(source_entity_id) - dev_reg = dr.async_get(hass) - # Resolve source entity device - if ( - (source_entity is not None) - and (source_entity.device_id is not None) - and ( - ( - device := dev_reg.async_get( - device_id=source_entity.device_id, - ) - ) - is not None - ) - ): - device_info = DeviceInfo( - identifiers=device.identifiers, - connections=device.connections, - ) - else: - device_info = None + device_info = async_device_info_to_link_from_entity( + hass, + source_entity_id, + ) cron_pattern = None delta_values = config_entry.options[CONF_METER_DELTA_VALUES] diff --git a/homeassistant/helpers/device.py b/homeassistant/helpers/device.py new file mode 100644 index 00000000000..b9df721ec6c --- /dev/null +++ b/homeassistant/helpers/device.py @@ -0,0 +1,75 @@ +"""Provides useful helpers for handling devices.""" + +from homeassistant.core import HomeAssistant, callback + +from . import device_registry as dr, entity_registry as er + + +@callback +def async_entity_id_to_device_id( + hass: HomeAssistant, + entity_id_or_uuid: str, +) -> str | None: + """Resolve the device id to the entity id or entity uuid.""" + + ent_reg = er.async_get(hass) + + entity_id = er.async_validate_entity_id(ent_reg, entity_id_or_uuid) + if (entity := ent_reg.async_get(entity_id)) is None: + return None + + return entity.device_id + + +@callback +def async_device_info_to_link_from_entity( + hass: HomeAssistant, + entity_id_or_uuid: str, +) -> dr.DeviceInfo | None: + """DeviceInfo with information to link a device to a configuration entry in the link category from a entity id or entity uuid.""" + + dev_reg = dr.async_get(hass) + + if (device_id := async_entity_id_to_device_id(hass, entity_id_or_uuid)) is None or ( + device := dev_reg.async_get(device_id=device_id) + ) is None: + return None + + return dr.DeviceInfo( + identifiers=device.identifiers, + connections=device.connections, + ) + + +@callback +def async_remove_stale_devices_links_keep_entity_device( + hass: HomeAssistant, + entry_id: str, + source_entity_id_or_uuid: str, +) -> None: + """Remove the link between stales devices and a configuration entry, keeping only the device that the informed entity is linked to.""" + + async_remove_stale_devices_links_keep_current_device( + hass=hass, + entry_id=entry_id, + current_device_id=async_entity_id_to_device_id(hass, source_entity_id_or_uuid), + ) + + +@callback +def async_remove_stale_devices_links_keep_current_device( + hass: HomeAssistant, + entry_id: str, + current_device_id: str | None, +) -> None: + """Remove the link between stales devices and a configuration entry, keeping only the device informed. + + Device passed in the current_device_id parameter will be kept linked to the configuration entry. + """ + + dev_reg = dr.async_get(hass) + # Removes all devices from the config entry that are not the same as the current device + for device in dev_reg.devices.get_devices_for_config_entry_id(entry_id): + if device.id == current_device_id: + continue + dev_reg.async_update_device(device.id, remove_config_entry_id=entry_id) diff --git a/tests/components/utility_meter/test_select.py b/tests/components/utility_meter/test_select.py new file mode 100644 index 00000000000..61f6cbe75b9 --- /dev/null +++ b/tests/components/utility_meter/test_select.py @@ -0,0 +1,56 @@ +"""The tests for the utility_meter select platform.""" + +from homeassistant.components.utility_meter.const import DOMAIN +from homeassistant.core import HomeAssistant +from homeassistant.helpers import device_registry as dr, entity_registry as er + +from tests.common import MockConfigEntry + + +async def test_device_id( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test for source entity device for Utility Meter.""" + source_config_entry = MockConfigEntry() + source_config_entry.add_to_hass(hass) + source_device_entry = device_registry.async_get_or_create( + config_entry_id=source_config_entry.entry_id, + identifiers={("sensor", "identifier_test")}, + connections={("mac", "30:31:32:33:34:35")}, + ) + source_entity = entity_registry.async_get_or_create( + "sensor", + "test", + "source", + config_entry=source_config_entry, + device_id=source_device_entry.id, + ) + await hass.async_block_till_done() + assert entity_registry.async_get("sensor.test_source") is not None + + utility_meter_config_entry = MockConfigEntry( + data={}, + domain=DOMAIN, + options={ + "cycle": "monthly", + "delta_values": False, + "name": "Energy", + "net_consumption": False, + "offset": 0, + "periodically_resetting": True, + "source": "sensor.test_source", + "tariffs": ["peak", "offpeak"], + }, + title="Energy", + ) + + utility_meter_config_entry.add_to_hass(hass) + + assert await hass.config_entries.async_setup(utility_meter_config_entry.entry_id) + await hass.async_block_till_done() + + utility_meter_entity_select = entity_registry.async_get("select.energy") + assert utility_meter_entity_select is not None + assert utility_meter_entity_select.device_id == source_entity.device_id diff --git a/tests/helpers/test_device.py b/tests/helpers/test_device.py new file mode 100644 index 00000000000..9e29288027c --- /dev/null +++ b/tests/helpers/test_device.py @@ -0,0 +1,211 @@ +"""Tests for the Device Utils.""" + +import pytest +import voluptuous as vol + +from homeassistant.core import HomeAssistant +from homeassistant.helpers import device_registry as dr, entity_registry as er +from homeassistant.helpers.device import ( + async_device_info_to_link_from_entity, + async_entity_id_to_device_id, + async_remove_stale_devices_links_keep_current_device, + async_remove_stale_devices_links_keep_entity_device, +) + +from tests.common import MockConfigEntry + + +async def test_entity_id_to_device_id( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test returning an entity's device ID.""" + config_entry = MockConfigEntry(domain="my") + config_entry.add_to_hass(hass) + + device = device_registry.async_get_or_create( + identifiers={("test", "current_device")}, + connections={("mac", "30:31:32:33:34:00")}, + config_entry_id=config_entry.entry_id, + ) + assert device is not None + + # Entity registry + entity = entity_registry.async_get_or_create( + "sensor", + "test", + "source", + config_entry=config_entry, + device_id=device.id, + ) + await hass.async_block_till_done() + assert entity_registry.async_get("sensor.test_source") is not None + + device_id = async_entity_id_to_device_id( + hass, + entity_id_or_uuid=entity.entity_id, + ) + assert device_id == device.id + + with pytest.raises(vol.Invalid): + async_entity_id_to_device_id( + hass, + entity_id_or_uuid="unknown_uuid", + ) + + +async def test_device_info_to_link( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test for returning device info with device link information.""" + config_entry = MockConfigEntry(domain="my") + config_entry.add_to_hass(hass) + + device = device_registry.async_get_or_create( + identifiers={("test", "my_device")}, + connections={("mac", "30:31:32:33:34:00")}, + config_entry_id=config_entry.entry_id, + ) + assert device is not None + + # Source entity registry + source_entity = entity_registry.async_get_or_create( + "sensor", + "test", + "source", + config_entry=config_entry, + device_id=device.id, + ) + await hass.async_block_till_done() + assert entity_registry.async_get("sensor.test_source") is not None + + result = async_device_info_to_link_from_entity( + hass, entity_id_or_uuid=source_entity.entity_id + ) + assert result == { + "identifiers": {("test", "my_device")}, + "connections": {("mac", "30:31:32:33:34:00")}, + } + + # With a non-existent entity id + result = async_device_info_to_link_from_entity( + hass, entity_id_or_uuid="sensor.invalid" + ) + assert result is None + + +async def test_remove_stale_device_links_keep_entity_device( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test cleaning works for entity.""" + config_entry = MockConfigEntry(domain="hue") + config_entry.add_to_hass(hass) + + current_device = device_registry.async_get_or_create( + identifiers={("test", "current_device")}, + connections={("mac", "30:31:32:33:34:00")}, + config_entry_id=config_entry.entry_id, + ) + assert current_device is not None + + device_registry.async_get_or_create( + identifiers={("test", "stale_device_1")}, + connections={("mac", "30:31:32:33:34:01")}, + config_entry_id=config_entry.entry_id, + ) + + device_registry.async_get_or_create( + identifiers={("test", "stale_device_2")}, + connections={("mac", "30:31:32:33:34:02")}, + config_entry_id=config_entry.entry_id, + ) + + # Source entity registry + source_entity = entity_registry.async_get_or_create( + "sensor", + "test", + "source", + config_entry=config_entry, + device_id=current_device.id, + ) + await hass.async_block_till_done() + assert entity_registry.async_get("sensor.test_source") is not None + + devices_config_entry = device_registry.devices.get_devices_for_config_entry_id( + config_entry.entry_id + ) + + # 3 devices linked to the config entry are expected (1 current device + 2 stales) + assert len(devices_config_entry) == 3 + + # Manual cleanup should unlink stales devices from the config entry + async_remove_stale_devices_links_keep_entity_device( + hass, + entry_id=config_entry.entry_id, + source_entity_id_or_uuid=source_entity.entity_id, + ) + + devices_config_entry = device_registry.devices.get_devices_for_config_entry_id( + config_entry.entry_id + ) + + # After cleanup, only one device is expected to be linked to the configuration entry if at least source_entity_id_or_uuid or device_id was given, else zero + assert len(devices_config_entry) == 1 + + assert current_device in devices_config_entry + + +async def test_remove_stale_devices_links_keep_current_device( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, +) -> None: + """Test cleanup works for device id.""" + config_entry = MockConfigEntry(domain="hue") + config_entry.add_to_hass(hass) + + current_device = device_registry.async_get_or_create( + identifiers={("test", "current_device")}, + connections={("mac", "30:31:32:33:34:00")}, + config_entry_id=config_entry.entry_id, + ) + assert current_device is not None + + device_registry.async_get_or_create( + identifiers={("test", "stale_device_1")}, + connections={("mac", "30:31:32:33:34:01")}, + config_entry_id=config_entry.entry_id, + ) + + device_registry.async_get_or_create( + identifiers={("test", "stale_device_2")}, + connections={("mac", "30:31:32:33:34:02")}, + config_entry_id=config_entry.entry_id, + ) + + devices_config_entry = device_registry.devices.get_devices_for_config_entry_id( + config_entry.entry_id + ) + + # 3 devices linked to the config entry are expected (1 current device + 2 stales) + assert len(devices_config_entry) == 3 + + # Manual cleanup should unlink stales devices from the config entry + async_remove_stale_devices_links_keep_current_device( + hass, + entry_id=config_entry.entry_id, + current_device_id=current_device.id, + ) + + devices_config_entry = device_registry.devices.get_devices_for_config_entry_id( + config_entry.entry_id + ) + + # After cleanup, only one device is expected to be linked to the configuration entry + assert len(devices_config_entry) == 1 + + assert current_device in devices_config_entry