Compare commits

...
Sign in to create a new pull request.

3 commits

Author SHA1 Message Date
Erik
643003c47a Merge colliding devices when loading the device registry 2024-06-26 16:45:59 +02:00
Erik
d77c4cda0c Also merge config entries 2024-06-26 15:27:23 +02:00
Erik
8d06baf0a5 Merge devices on connection or identifier collision 2024-06-26 14:11:42 +02:00
2 changed files with 172 additions and 21 deletions

View file

@ -15,6 +15,7 @@ from yarl import URL
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import (
DOMAIN,
Event,
HomeAssistant,
ReleaseChannel,
@ -75,6 +76,7 @@ class DeviceEntryDisabler(StrEnum):
"""What disabled a device entry."""
CONFIG_ENTRY = "config_entry"
DUPLICATE = "duplicate"
INTEGRATION = "integration"
USER = "user"
@ -522,12 +524,36 @@ class DeviceRegistryItems[_EntryTypeT: (DeviceEntry, DeletedDeviceEntry)](
for identifier in old_entry.identifiers:
del self._identifiers[identifier]
def get_entries(
self,
identifiers: set[tuple[str, str]] | None,
connections: set[tuple[str, str]] | None,
) -> set[str]:
"""Get all matching entry ids from identifiers or connections."""
entries = set()
if identifiers:
entries = {
self._identifiers[identifier].id
for identifier in identifiers
if identifier in self._identifiers
}
if not connections:
return entries
return entries | {
self._connections[connection].id
for connection in _normalize_connections(connections)
if connection in self._connections
}
def get_entry(
self,
identifiers: set[tuple[str, str]] | None,
connections: set[tuple[str, str]] | None,
) -> _EntryTypeT | None:
"""Get entry from identifiers or connections."""
"""Get the first matching entry from identifiers or connections.
Identifiers are tried first, then connections.
"""
if identifiers:
for identifier in identifiers:
if identifier in self._identifiers:
@ -754,9 +780,11 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
else:
connections = _normalize_connections(connections)
device = self.async_get_device(identifiers=identifiers, connections=connections)
device_ids = self.devices.get_entries(identifiers, connections)
if device is None:
if len(device_ids) > 1:
device = self._merge_devices(device_ids)
elif not device_ids:
deleted_device = self._async_get_deleted_device(identifiers, connections)
if deleted_device is None:
device = DeviceEntry(is_new=True)
@ -769,6 +797,8 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
# If creating a new device, default to the config entry name
if device_info_type == "primary" and (not name or name is UNDEFINED):
name = config_entry.title
else:
device = self.devices[next(iter(device_ids))]
if default_manufacturer is not UNDEFINED and device.manufacturer is None:
manufacturer = default_manufacturer
@ -796,7 +826,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
)
entry_type = DeviceEntryType(entry_type)
device = self.async_update_device(
updated_device = self.async_update_device(
device.id,
allow_collisions=True,
add_config_entry_id=config_entry_id,
@ -819,8 +849,8 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
# This is safe because _async_update_device will always return a device
# in this use case.
assert device
return device
assert updated_device
return updated_device
@callback
def async_update_device( # noqa: C901
@ -1110,6 +1140,62 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
)
self.async_schedule_save()
@callback
def _merge_devices(self, device_ids: set[str]) -> DeviceEntry:
# Pick a device to be the main device. For now, we just pick the first
# device in the set
main_device = self.devices[next(iter(device_ids))]
merged_config_entries = set()
merged_connections = set()
merged_identifiers = set()
# Disable other devices, and clear their connections and identifiers
for device_id in device_ids:
device = self.devices[device_id]
merged_config_entries |= device.config_entries
merged_connections |= device.connections
merged_identifiers |= device.identifiers
if device.id == main_device.id:
continue
self.async_update_device(
device.id,
disabled_by=DeviceEntryDisabler.DUPLICATE,
new_connections=set(),
new_identifiers={(DOMAIN, device.id)},
)
self.async_update_device(
main_device.id,
new_connections=merged_connections,
new_identifiers=merged_identifiers,
)
for config_entry_id in merged_config_entries:
self.async_update_device(
main_device.id,
add_config_entry_id=config_entry_id,
)
return main_device
@callback
def _find_collision(self) -> set[str] | None:
for device in self.devices.values():
for identifier in device.identifiers:
if len(device_ids := self.devices.get_entries({identifier}, None)) > 1:
return device_ids
for connection in device.connections:
if len(device_ids := self.devices.get_entries({connection}, None)) > 1:
return device_ids
return None
@callback
def _merge_collisions(self) -> None:
while collision := self._find_collision():
self._merge_devices(collision)
async def async_load(self) -> None:
"""Load the device registry."""
async_setup_cleanup(self.hass, self)
@ -1169,6 +1255,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
self.devices = devices
self.deleted_devices = deleted_devices
self._device_data = devices.data
self._merge_collisions()
@callback
def _data_to_save(self) -> dict[str, Any]:

View file

@ -2501,7 +2501,14 @@ def test_all() -> None:
help_test_all(dr)
@pytest.mark.parametrize(("enum"), list(dr.DeviceEntryDisabler))
@pytest.mark.parametrize(
("enum"),
[
enum
for enum in dr.DeviceEntryDisabler
if enum != dr.DeviceEntryDisabler.DUPLICATE
],
)
def test_deprecated_constants(
caplog: pytest.LogCaptureFixture,
enum: dr.DeviceEntryDisabler,
@ -2822,17 +2829,19 @@ async def test_device_registry_connections_collision(
hass: HomeAssistant, device_registry: dr.DeviceRegistry
) -> None:
"""Test connection collisions in the device registry."""
config_entry = MockConfigEntry()
config_entry.add_to_hass(hass)
config_entry_1 = MockConfigEntry()
config_entry_1.add_to_hass(hass)
config_entry_2 = MockConfigEntry()
config_entry_2.add_to_hass(hass)
device1 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
config_entry_id=config_entry_1.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "none")},
manufacturer="manufacturer",
model="model",
)
device2 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
config_entry_id=config_entry_1.entry_id,
connections={(dr.CONNECTION_NETWORK_MAC, "none")},
manufacturer="manufacturer",
model="model",
@ -2841,7 +2850,7 @@ async def test_device_registry_connections_collision(
assert device1.id == device2.id
device3 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
config_entry_id=config_entry_2.entry_id,
identifiers={("bridgeid", "0123")},
manufacturer="manufacturer",
model="model",
@ -2891,7 +2900,7 @@ async def test_device_registry_connections_collision(
# Attempt to implicitly merge connection for device3 with the same
# connection that already exists in device1
device4 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
config_entry_id=config_entry_2.entry_id,
identifiers={("bridgeid", "0123")},
connections={
(dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE"),
@ -2903,24 +2912,54 @@ async def test_device_registry_connections_collision(
device3_refetched = device_registry.async_get(device3.id)
device1_refetched = device_registry.async_get(device1.id)
assert not device1_refetched.connections.isdisjoint(device3_refetched.connections)
# One of the devices should now:
# - Be disabled
# - Have all its connections removed
# - Have a single identifier
if device1_refetched.disabled_by is dr.DeviceEntryDisabler.DUPLICATE:
main_device = device3_refetched
duplicate_device = device1_refetched
else:
main_device = device1_refetched
duplicate_device = device3_refetched
assert duplicate_device.disabled_by is dr.DeviceEntryDisabler.DUPLICATE
assert main_device.disabled_by is None
assert duplicate_device.config_entries in (
{config_entry_1.entry_id},
{config_entry_2.entry_id},
)
assert duplicate_device.connections == set()
assert duplicate_device.identifiers == {("homeassistant", duplicate_device.id)}
assert main_device.config_entries == {
config_entry_1.entry_id,
config_entry_2.entry_id,
}
assert main_device.connections == {
(dr.CONNECTION_NETWORK_MAC, "ee:ee:ee:ee:ee:ee"),
(dr.CONNECTION_NETWORK_MAC, "none"),
}
assert main_device.identifiers == {("bridgeid", "0123")}
async def test_device_registry_identifiers_collision(
hass: HomeAssistant, device_registry: dr.DeviceRegistry
) -> None:
"""Test identifiers collisions in the device registry."""
config_entry = MockConfigEntry()
config_entry.add_to_hass(hass)
config_entry_1 = MockConfigEntry()
config_entry_1.add_to_hass(hass)
config_entry_2 = MockConfigEntry()
config_entry_2.add_to_hass(hass)
device1 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
config_entry_id=config_entry_1.entry_id,
identifiers={("bridgeid", "0123")},
manufacturer="manufacturer",
model="model",
)
device2 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
config_entry_id=config_entry_1.entry_id,
identifiers={("bridgeid", "0123")},
manufacturer="manufacturer",
model="model",
@ -2929,7 +2968,7 @@ async def test_device_registry_identifiers_collision(
assert device1.id == device2.id
device3 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
config_entry_id=config_entry_2.entry_id,
identifiers={("bridgeid", "4567")},
manufacturer="manufacturer",
model="model",
@ -2971,7 +3010,7 @@ async def test_device_registry_identifiers_collision(
# Attempt to implicitly merge identifiers for device3 with the same
# connection that already exists in device1
device4 = device_registry.async_get_or_create(
config_entry_id=config_entry.entry_id,
config_entry_id=config_entry_2.entry_id,
identifiers={("bridgeid", "4567"), ("bridgeid", "0123")},
)
assert len(device_registry.devices) == 2
@ -2979,7 +3018,32 @@ async def test_device_registry_identifiers_collision(
device3_refetched = device_registry.async_get(device3.id)
device1_refetched = device_registry.async_get(device1.id)
assert not device1_refetched.identifiers.isdisjoint(device3_refetched.identifiers)
# One of the devices should now:
# - Be disabled
# - Have all its connections removed
# - Have a single identifier
if device1_refetched.disabled_by is dr.DeviceEntryDisabler.DUPLICATE:
main_device = device3_refetched
duplicate_device = device1_refetched
else:
main_device = device1_refetched
duplicate_device = device3_refetched
assert duplicate_device.disabled_by is dr.DeviceEntryDisabler.DUPLICATE
assert main_device.disabled_by is None
assert duplicate_device.config_entries in (
{config_entry_1.entry_id},
{config_entry_2.entry_id},
)
assert duplicate_device.connections == set()
assert duplicate_device.identifiers == {("homeassistant", duplicate_device.id)}
assert main_device.config_entries == {
config_entry_1.entry_id,
config_entry_2.entry_id,
}
assert main_device.connections == set()
assert main_device.identifiers == {("bridgeid", "0123"), ("bridgeid", "4567")}
async def test_primary_config_entry(