Compare commits
3 commits
dev
...
device_reg
Author | SHA1 | Date | |
---|---|---|---|
|
643003c47a | ||
|
d77c4cda0c | ||
|
8d06baf0a5 |
2 changed files with 172 additions and 21 deletions
|
@ -15,6 +15,7 @@ from yarl import URL
|
|||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import (
|
||||
DOMAIN,
|
||||
Event,
|
||||
HomeAssistant,
|
||||
ReleaseChannel,
|
||||
|
@ -75,6 +76,7 @@ class DeviceEntryDisabler(StrEnum):
|
|||
"""What disabled a device entry."""
|
||||
|
||||
CONFIG_ENTRY = "config_entry"
|
||||
DUPLICATE = "duplicate"
|
||||
INTEGRATION = "integration"
|
||||
USER = "user"
|
||||
|
||||
|
@ -522,12 +524,36 @@ class DeviceRegistryItems[_EntryTypeT: (DeviceEntry, DeletedDeviceEntry)](
|
|||
for identifier in old_entry.identifiers:
|
||||
del self._identifiers[identifier]
|
||||
|
||||
def get_entries(
|
||||
self,
|
||||
identifiers: set[tuple[str, str]] | None,
|
||||
connections: set[tuple[str, str]] | None,
|
||||
) -> set[str]:
|
||||
"""Get all matching entry ids from identifiers or connections."""
|
||||
entries = set()
|
||||
if identifiers:
|
||||
entries = {
|
||||
self._identifiers[identifier].id
|
||||
for identifier in identifiers
|
||||
if identifier in self._identifiers
|
||||
}
|
||||
if not connections:
|
||||
return entries
|
||||
return entries | {
|
||||
self._connections[connection].id
|
||||
for connection in _normalize_connections(connections)
|
||||
if connection in self._connections
|
||||
}
|
||||
|
||||
def get_entry(
|
||||
self,
|
||||
identifiers: set[tuple[str, str]] | None,
|
||||
connections: set[tuple[str, str]] | None,
|
||||
) -> _EntryTypeT | None:
|
||||
"""Get entry from identifiers or connections."""
|
||||
"""Get the first matching entry from identifiers or connections.
|
||||
|
||||
Identifiers are tried first, then connections.
|
||||
"""
|
||||
if identifiers:
|
||||
for identifier in identifiers:
|
||||
if identifier in self._identifiers:
|
||||
|
@ -754,9 +780,11 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
|||
else:
|
||||
connections = _normalize_connections(connections)
|
||||
|
||||
device = self.async_get_device(identifiers=identifiers, connections=connections)
|
||||
device_ids = self.devices.get_entries(identifiers, connections)
|
||||
|
||||
if device is None:
|
||||
if len(device_ids) > 1:
|
||||
device = self._merge_devices(device_ids)
|
||||
elif not device_ids:
|
||||
deleted_device = self._async_get_deleted_device(identifiers, connections)
|
||||
if deleted_device is None:
|
||||
device = DeviceEntry(is_new=True)
|
||||
|
@ -769,6 +797,8 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
|||
# If creating a new device, default to the config entry name
|
||||
if device_info_type == "primary" and (not name or name is UNDEFINED):
|
||||
name = config_entry.title
|
||||
else:
|
||||
device = self.devices[next(iter(device_ids))]
|
||||
|
||||
if default_manufacturer is not UNDEFINED and device.manufacturer is None:
|
||||
manufacturer = default_manufacturer
|
||||
|
@ -796,7 +826,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
|||
)
|
||||
entry_type = DeviceEntryType(entry_type)
|
||||
|
||||
device = self.async_update_device(
|
||||
updated_device = self.async_update_device(
|
||||
device.id,
|
||||
allow_collisions=True,
|
||||
add_config_entry_id=config_entry_id,
|
||||
|
@ -819,8 +849,8 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
|||
|
||||
# This is safe because _async_update_device will always return a device
|
||||
# in this use case.
|
||||
assert device
|
||||
return device
|
||||
assert updated_device
|
||||
return updated_device
|
||||
|
||||
@callback
|
||||
def async_update_device( # noqa: C901
|
||||
|
@ -1110,6 +1140,62 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
|||
)
|
||||
self.async_schedule_save()
|
||||
|
||||
@callback
|
||||
def _merge_devices(self, device_ids: set[str]) -> DeviceEntry:
|
||||
# Pick a device to be the main device. For now, we just pick the first
|
||||
# device in the set
|
||||
main_device = self.devices[next(iter(device_ids))]
|
||||
|
||||
merged_config_entries = set()
|
||||
merged_connections = set()
|
||||
merged_identifiers = set()
|
||||
|
||||
# Disable other devices, and clear their connections and identifiers
|
||||
for device_id in device_ids:
|
||||
device = self.devices[device_id]
|
||||
merged_config_entries |= device.config_entries
|
||||
merged_connections |= device.connections
|
||||
merged_identifiers |= device.identifiers
|
||||
|
||||
if device.id == main_device.id:
|
||||
continue
|
||||
|
||||
self.async_update_device(
|
||||
device.id,
|
||||
disabled_by=DeviceEntryDisabler.DUPLICATE,
|
||||
new_connections=set(),
|
||||
new_identifiers={(DOMAIN, device.id)},
|
||||
)
|
||||
|
||||
self.async_update_device(
|
||||
main_device.id,
|
||||
new_connections=merged_connections,
|
||||
new_identifiers=merged_identifiers,
|
||||
)
|
||||
for config_entry_id in merged_config_entries:
|
||||
self.async_update_device(
|
||||
main_device.id,
|
||||
add_config_entry_id=config_entry_id,
|
||||
)
|
||||
|
||||
return main_device
|
||||
|
||||
@callback
|
||||
def _find_collision(self) -> set[str] | None:
|
||||
for device in self.devices.values():
|
||||
for identifier in device.identifiers:
|
||||
if len(device_ids := self.devices.get_entries({identifier}, None)) > 1:
|
||||
return device_ids
|
||||
for connection in device.connections:
|
||||
if len(device_ids := self.devices.get_entries({connection}, None)) > 1:
|
||||
return device_ids
|
||||
return None
|
||||
|
||||
@callback
|
||||
def _merge_collisions(self) -> None:
|
||||
while collision := self._find_collision():
|
||||
self._merge_devices(collision)
|
||||
|
||||
async def async_load(self) -> None:
|
||||
"""Load the device registry."""
|
||||
async_setup_cleanup(self.hass, self)
|
||||
|
@ -1169,6 +1255,7 @@ class DeviceRegistry(BaseRegistry[dict[str, list[dict[str, Any]]]]):
|
|||
self.devices = devices
|
||||
self.deleted_devices = deleted_devices
|
||||
self._device_data = devices.data
|
||||
self._merge_collisions()
|
||||
|
||||
@callback
|
||||
def _data_to_save(self) -> dict[str, Any]:
|
||||
|
|
|
@ -2501,7 +2501,14 @@ def test_all() -> None:
|
|||
help_test_all(dr)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("enum"), list(dr.DeviceEntryDisabler))
|
||||
@pytest.mark.parametrize(
|
||||
("enum"),
|
||||
[
|
||||
enum
|
||||
for enum in dr.DeviceEntryDisabler
|
||||
if enum != dr.DeviceEntryDisabler.DUPLICATE
|
||||
],
|
||||
)
|
||||
def test_deprecated_constants(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
enum: dr.DeviceEntryDisabler,
|
||||
|
@ -2822,17 +2829,19 @@ async def test_device_registry_connections_collision(
|
|||
hass: HomeAssistant, device_registry: dr.DeviceRegistry
|
||||
) -> None:
|
||||
"""Test connection collisions in the device registry."""
|
||||
config_entry = MockConfigEntry()
|
||||
config_entry.add_to_hass(hass)
|
||||
config_entry_1 = MockConfigEntry()
|
||||
config_entry_1.add_to_hass(hass)
|
||||
config_entry_2 = MockConfigEntry()
|
||||
config_entry_2.add_to_hass(hass)
|
||||
|
||||
device1 = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
config_entry_id=config_entry_1.entry_id,
|
||||
connections={(dr.CONNECTION_NETWORK_MAC, "none")},
|
||||
manufacturer="manufacturer",
|
||||
model="model",
|
||||
)
|
||||
device2 = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
config_entry_id=config_entry_1.entry_id,
|
||||
connections={(dr.CONNECTION_NETWORK_MAC, "none")},
|
||||
manufacturer="manufacturer",
|
||||
model="model",
|
||||
|
@ -2841,7 +2850,7 @@ async def test_device_registry_connections_collision(
|
|||
assert device1.id == device2.id
|
||||
|
||||
device3 = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
config_entry_id=config_entry_2.entry_id,
|
||||
identifiers={("bridgeid", "0123")},
|
||||
manufacturer="manufacturer",
|
||||
model="model",
|
||||
|
@ -2891,7 +2900,7 @@ async def test_device_registry_connections_collision(
|
|||
# Attempt to implicitly merge connection for device3 with the same
|
||||
# connection that already exists in device1
|
||||
device4 = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
config_entry_id=config_entry_2.entry_id,
|
||||
identifiers={("bridgeid", "0123")},
|
||||
connections={
|
||||
(dr.CONNECTION_NETWORK_MAC, "EE:EE:EE:EE:EE:EE"),
|
||||
|
@ -2903,24 +2912,54 @@ async def test_device_registry_connections_collision(
|
|||
|
||||
device3_refetched = device_registry.async_get(device3.id)
|
||||
device1_refetched = device_registry.async_get(device1.id)
|
||||
assert not device1_refetched.connections.isdisjoint(device3_refetched.connections)
|
||||
|
||||
# One of the devices should now:
|
||||
# - Be disabled
|
||||
# - Have all its connections removed
|
||||
# - Have a single identifier
|
||||
if device1_refetched.disabled_by is dr.DeviceEntryDisabler.DUPLICATE:
|
||||
main_device = device3_refetched
|
||||
duplicate_device = device1_refetched
|
||||
else:
|
||||
main_device = device1_refetched
|
||||
duplicate_device = device3_refetched
|
||||
|
||||
assert duplicate_device.disabled_by is dr.DeviceEntryDisabler.DUPLICATE
|
||||
assert main_device.disabled_by is None
|
||||
assert duplicate_device.config_entries in (
|
||||
{config_entry_1.entry_id},
|
||||
{config_entry_2.entry_id},
|
||||
)
|
||||
assert duplicate_device.connections == set()
|
||||
assert duplicate_device.identifiers == {("homeassistant", duplicate_device.id)}
|
||||
assert main_device.config_entries == {
|
||||
config_entry_1.entry_id,
|
||||
config_entry_2.entry_id,
|
||||
}
|
||||
assert main_device.connections == {
|
||||
(dr.CONNECTION_NETWORK_MAC, "ee:ee:ee:ee:ee:ee"),
|
||||
(dr.CONNECTION_NETWORK_MAC, "none"),
|
||||
}
|
||||
assert main_device.identifiers == {("bridgeid", "0123")}
|
||||
|
||||
|
||||
async def test_device_registry_identifiers_collision(
|
||||
hass: HomeAssistant, device_registry: dr.DeviceRegistry
|
||||
) -> None:
|
||||
"""Test identifiers collisions in the device registry."""
|
||||
config_entry = MockConfigEntry()
|
||||
config_entry.add_to_hass(hass)
|
||||
config_entry_1 = MockConfigEntry()
|
||||
config_entry_1.add_to_hass(hass)
|
||||
config_entry_2 = MockConfigEntry()
|
||||
config_entry_2.add_to_hass(hass)
|
||||
|
||||
device1 = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
config_entry_id=config_entry_1.entry_id,
|
||||
identifiers={("bridgeid", "0123")},
|
||||
manufacturer="manufacturer",
|
||||
model="model",
|
||||
)
|
||||
device2 = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
config_entry_id=config_entry_1.entry_id,
|
||||
identifiers={("bridgeid", "0123")},
|
||||
manufacturer="manufacturer",
|
||||
model="model",
|
||||
|
@ -2929,7 +2968,7 @@ async def test_device_registry_identifiers_collision(
|
|||
assert device1.id == device2.id
|
||||
|
||||
device3 = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
config_entry_id=config_entry_2.entry_id,
|
||||
identifiers={("bridgeid", "4567")},
|
||||
manufacturer="manufacturer",
|
||||
model="model",
|
||||
|
@ -2971,7 +3010,7 @@ async def test_device_registry_identifiers_collision(
|
|||
# Attempt to implicitly merge identifiers for device3 with the same
|
||||
# connection that already exists in device1
|
||||
device4 = device_registry.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
config_entry_id=config_entry_2.entry_id,
|
||||
identifiers={("bridgeid", "4567"), ("bridgeid", "0123")},
|
||||
)
|
||||
assert len(device_registry.devices) == 2
|
||||
|
@ -2979,7 +3018,32 @@ async def test_device_registry_identifiers_collision(
|
|||
|
||||
device3_refetched = device_registry.async_get(device3.id)
|
||||
device1_refetched = device_registry.async_get(device1.id)
|
||||
assert not device1_refetched.identifiers.isdisjoint(device3_refetched.identifiers)
|
||||
|
||||
# One of the devices should now:
|
||||
# - Be disabled
|
||||
# - Have all its connections removed
|
||||
# - Have a single identifier
|
||||
if device1_refetched.disabled_by is dr.DeviceEntryDisabler.DUPLICATE:
|
||||
main_device = device3_refetched
|
||||
duplicate_device = device1_refetched
|
||||
else:
|
||||
main_device = device1_refetched
|
||||
duplicate_device = device3_refetched
|
||||
|
||||
assert duplicate_device.disabled_by is dr.DeviceEntryDisabler.DUPLICATE
|
||||
assert main_device.disabled_by is None
|
||||
assert duplicate_device.config_entries in (
|
||||
{config_entry_1.entry_id},
|
||||
{config_entry_2.entry_id},
|
||||
)
|
||||
assert duplicate_device.connections == set()
|
||||
assert duplicate_device.identifiers == {("homeassistant", duplicate_device.id)}
|
||||
assert main_device.config_entries == {
|
||||
config_entry_1.entry_id,
|
||||
config_entry_2.entry_id,
|
||||
}
|
||||
assert main_device.connections == set()
|
||||
assert main_device.identifiers == {("bridgeid", "0123"), ("bridgeid", "4567")}
|
||||
|
||||
|
||||
async def test_primary_config_entry(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue