diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index cfafa63ec3a..02dce7bf967 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -15,6 +15,7 @@ from yarl import URL from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP from homeassistant.core import ( + DOMAIN, Event, HomeAssistant, ReleaseChannel, @@ -75,6 +76,7 @@ class DeviceEntryDisabler(StrEnum): """What disabled a device entry.""" CONFIG_ENTRY = "config_entry" + DUPLICATE = "duplicate" INTEGRATION = "integration" USER = "user" @@ -522,12 +524,36 @@ class DeviceRegistryItems[_EntryTypeT: (DeviceEntry, DeletedDeviceEntry)]( for identifier in old_entry.identifiers: del self._identifiers[identifier] + def get_entries( + self, + identifiers: set[tuple[str, str]] | None, + connections: set[tuple[str, str]] | None, + ) -> set[str]: + """Get all matching entry ids from identifiers or connections.""" + entries = set() + if identifiers: + entries = { + self._identifiers[identifier].id + for identifier in identifiers + if identifier in self._identifiers + } + if not connections: + return entries + return entries | { + self._connections[connection].id + for connection in _normalize_connections(connections) + if connection in self._connections + } + def get_entry( self, identifiers: set[tuple[str, str]] | None, connections: set[tuple[str, str]] | None, ) -> _EntryTypeT | None: - """Get entry from identifiers or connections.""" + """Get the first matching entry from identifiers or connections. + + Identifiers are tried first, then connections. + """ if identifiers: for identifier in identifiers: if identifier in self._identifiers: @@ -754,9 +780,11 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): else: connections = _normalize_connections(connections) - device = self.async_get_device(identifiers=identifiers, connections=connections) + device_ids = self.devices.get_entries(identifiers, connections) - if device is None: + if len(device_ids) > 1: + device = self._merge_devices(device_ids) + elif not device_ids: deleted_device = self._async_get_deleted_device(identifiers, connections) if deleted_device is None: device = DeviceEntry(is_new=True) @@ -769,6 +797,8 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): # If creating a new device, default to the config entry name if device_info_type == "primary" and (not name or name is UNDEFINED): name = config_entry.title + else: + device = self.devices[next(iter(device_ids))] if default_manufacturer is not UNDEFINED and device.manufacturer is None: manufacturer = default_manufacturer @@ -796,7 +826,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): ) entry_type = DeviceEntryType(entry_type) - device = self.async_update_device( + updated_device = self.async_update_device( device.id, allow_collisions=True, add_config_entry_id=config_entry_id, @@ -819,8 +849,8 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): # This is safe because _async_update_device will always return a device # in this use case. - assert device - return device + assert updated_device + return updated_device @callback def async_update_device( # noqa: C901 @@ -1110,6 +1140,62 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): ) self.async_schedule_save() + @callback + def _merge_devices(self, device_ids: set[str]) -> DeviceEntry: + # Pick a device to be the main device. For now, we just pick the first + # device in the set + main_device = self.devices[next(iter(device_ids))] + + merged_config_entries = set() + merged_connections = set() + merged_identifiers = set() + + # Disable other devices, and clear their connections and identifiers + for device_id in device_ids: + device = self.devices[device_id] + merged_config_entries |= device.config_entries + merged_connections |= device.connections + merged_identifiers |= device.identifiers + + if device.id == main_device.id: + continue + + self.async_update_device( + device.id, + disabled_by=DeviceEntryDisabler.DUPLICATE, + new_connections=set(), + new_identifiers={(DOMAIN, device.id)}, + ) + + self.async_update_device( + main_device.id, + new_connections=merged_connections, + new_identifiers=merged_identifiers, + ) + for config_entry_id in merged_config_entries: + self.async_update_device( + main_device.id, + add_config_entry_id=config_entry_id, + ) + + return main_device + + @callback + def _find_collision(self) -> set[str] | None: + for device in self.devices.values(): + for identifier in device.identifiers: + if len(device_ids := self.devices.get_entries({identifier}, None)) > 1: + return device_ids + for connection in device.connections: + if len(device_ids := self.devices.get_entries({connection}, None)) > 1: + return device_ids + return None + + @callback + def _merge_collisions(self) -> None: + while collision := self._find_collision(): + self._merge_devices(collision) + async def async_load(self) -> None: """Load the device registry.""" async_setup_cleanup(self.hass, self) @@ -1169,6 +1255,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]): self.devices = devices self.deleted_devices = deleted_devices self._device_data = devices.data + self._merge_collisions() @callback def _data_to_save(self) -> dict[str, Any]: diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index fa57cc7557e..5ec104aba32 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -2501,7 +2501,14 @@ def test_all() -> None: help_test_all(dr) -@pytest.mark.parametrize(("enum"), list(dr.DeviceEntryDisabler)) +@pytest.mark.parametrize( + ("enum"), + [ + enum + for enum in dr.DeviceEntryDisabler + if enum != dr.DeviceEntryDisabler.DUPLICATE + ], +) def test_deprecated_constants( caplog: pytest.LogCaptureFixture, enum: dr.DeviceEntryDisabler, @@ -2822,17 +2829,19 @@ async def test_device_registry_connections_collision( hass: HomeAssistant, device_registry: dr.DeviceRegistry ) -> None: """Test connection collisions in the device registry.""" - config_entry = MockConfigEntry() - config_entry.add_to_hass(hass) + config_entry_1 = MockConfigEntry() + config_entry_1.add_to_hass(hass) + config_entry_2 = MockConfigEntry() + config_entry_2.add_to_hass(hass) device1 = device_registry.async_get_or_create( - config_entry_id=config_entry.entry_id, + config_entry_id=config_entry_1.entry_id, connections={(dr.CONNECTION_NETWORK_MAC, "none")}, manufacturer="manufacturer", model="model", ) device2 = device_registry.async_get_or_create( - config_entry_id=config_entry.entry_id, + config_entry_id=config_entry_1.entry_id, connections={(dr.CONNECTION_NETWORK_MAC, "none")}, manufacturer="manufacturer", model="model", @@ -2841,7 +2850,7 @@ async def test_device_registry_connections_collision( assert device1.id == device2.id device3 = device_registry.async_get_or_create( - config_entry_id=config_entry.entry_id, + config_entry_id=config_entry_2.entry_id, identifiers={("bridgeid", "0123")}, manufacturer="manufacturer", model="model", @@ -2891,7 +2900,7 @@ async def test_device_registry_connections_collision( # Attempt to implicitly merge connection for device3 with the same # connection that already exists in device1 device4 = device_registry.async_get_or_create( - config_entry_id=config_entry.entry_id, + config_entry_id=config_entry_2.entry_id, identifiers={("bridgeid", "0123")}, connections={ (dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE"), @@ -2903,24 +2912,54 @@ async def test_device_registry_connections_collision( device3_refetched = device_registry.async_get(device3.id) device1_refetched = device_registry.async_get(device1.id) - assert not device1_refetched.connections.isdisjoint(device3_refetched.connections) + + # One of the devices should now: + # - Be disabled + # - Have all its connections removed + # - Have a single identifier + if device1_refetched.disabled_by is dr.DeviceEntryDisabler.DUPLICATE: + main_device = device3_refetched + duplicate_device = device1_refetched + else: + main_device = device1_refetched + duplicate_device = device3_refetched + + assert duplicate_device.disabled_by is dr.DeviceEntryDisabler.DUPLICATE + assert main_device.disabled_by is None + assert duplicate_device.config_entries in ( + {config_entry_1.entry_id}, + {config_entry_2.entry_id}, + ) + assert duplicate_device.connections == set() + assert duplicate_device.identifiers == {("homeassistant", duplicate_device.id)} + assert main_device.config_entries == { + config_entry_1.entry_id, + config_entry_2.entry_id, + } + assert main_device.connections == { + (dr.CONNECTION_NETWORK_MAC, "ee:ee:ee:ee:ee:ee"), + (dr.CONNECTION_NETWORK_MAC, "none"), + } + assert main_device.identifiers == {("bridgeid", "0123")} async def test_device_registry_identifiers_collision( hass: HomeAssistant, device_registry: dr.DeviceRegistry ) -> None: """Test identifiers collisions in the device registry.""" - config_entry = MockConfigEntry() - config_entry.add_to_hass(hass) + config_entry_1 = MockConfigEntry() + config_entry_1.add_to_hass(hass) + config_entry_2 = MockConfigEntry() + config_entry_2.add_to_hass(hass) device1 = device_registry.async_get_or_create( - config_entry_id=config_entry.entry_id, + config_entry_id=config_entry_1.entry_id, identifiers={("bridgeid", "0123")}, manufacturer="manufacturer", model="model", ) device2 = device_registry.async_get_or_create( - config_entry_id=config_entry.entry_id, + config_entry_id=config_entry_1.entry_id, identifiers={("bridgeid", "0123")}, manufacturer="manufacturer", model="model", @@ -2929,7 +2968,7 @@ async def test_device_registry_identifiers_collision( assert device1.id == device2.id device3 = device_registry.async_get_or_create( - config_entry_id=config_entry.entry_id, + config_entry_id=config_entry_2.entry_id, identifiers={("bridgeid", "4567")}, manufacturer="manufacturer", model="model", @@ -2971,7 +3010,7 @@ async def test_device_registry_identifiers_collision( # Attempt to implicitly merge identifiers for device3 with the same # connection that already exists in device1 device4 = device_registry.async_get_or_create( - config_entry_id=config_entry.entry_id, + config_entry_id=config_entry_2.entry_id, identifiers={("bridgeid", "4567"), ("bridgeid", "0123")}, ) assert len(device_registry.devices) == 2 @@ -2979,7 +3018,32 @@ async def test_device_registry_identifiers_collision( device3_refetched = device_registry.async_get(device3.id) device1_refetched = device_registry.async_get(device1.id) - assert not device1_refetched.identifiers.isdisjoint(device3_refetched.identifiers) + + # One of the devices should now: + # - Be disabled + # - Have all its connections removed + # - Have a single identifier + if device1_refetched.disabled_by is dr.DeviceEntryDisabler.DUPLICATE: + main_device = device3_refetched + duplicate_device = device1_refetched + else: + main_device = device1_refetched + duplicate_device = device3_refetched + + assert duplicate_device.disabled_by is dr.DeviceEntryDisabler.DUPLICATE + assert main_device.disabled_by is None + assert duplicate_device.config_entries in ( + {config_entry_1.entry_id}, + {config_entry_2.entry_id}, + ) + assert duplicate_device.connections == set() + assert duplicate_device.identifiers == {("homeassistant", duplicate_device.id)} + assert main_device.config_entries == { + config_entry_1.entry_id, + config_entry_2.entry_id, + } + assert main_device.connections == set() + assert main_device.identifiers == {("bridgeid", "0123"), ("bridgeid", "4567")} async def test_primary_config_entry(