diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index f9a7d6660da..b2e3bfd7a32 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -1,7 +1,7 @@ """Provide a way to connect entities belonging to one device.""" from collections import OrderedDict import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union import uuid import attr @@ -32,6 +32,11 @@ CONNECTION_NETWORK_MAC = "mac" CONNECTION_UPNP = "upnp" CONNECTION_ZIGBEE = "zigbee" +IDX_CONNECTIONS = "connections" +IDX_IDENTIFIERS = "identifiers" +REGISTERED_DEVICE = "registered" +DELETED_DEVICE = "deleted" + @attr.s(slots=True, frozen=True) class DeletedDeviceEntry: @@ -98,11 +103,13 @@ class DeviceRegistry: devices: Dict[str, DeviceEntry] deleted_devices: Dict[str, DeletedDeviceEntry] + _devices_index: Dict[str, Dict[str, Dict[str, str]]] def __init__(self, hass: HomeAssistantType) -> None: """Initialize the device registry.""" self.hass = hass self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) + self._clear_index() @callback def async_get(self, device_id: str) -> Optional[DeviceEntry]: @@ -114,25 +121,84 @@ class DeviceRegistry: self, identifiers: set, connections: set ) -> Optional[DeviceEntry]: """Check if device is registered.""" - for device in self.devices.values(): - if any(iden in device.identifiers for iden in identifiers) or any( - conn in device.connections for conn in connections - ): - return device - return None + device_id = self._async_get_device_id_from_index( + REGISTERED_DEVICE, identifiers, connections + ) + if device_id is None: + return None + return self.devices[device_id] - @callback def _async_get_deleted_device( self, identifiers: set, connections: set ) -> Optional[DeletedDeviceEntry]: + """Check if device is deleted.""" + device_id = self._async_get_device_id_from_index( + DELETED_DEVICE, identifiers, connections + ) + if device_id is None: + return None + return self.deleted_devices[device_id] + + def _async_get_device_id_from_index( + self, index: str, identifiers: set, connections: set + ) -> Optional[str]: """Check if device has previously been registered.""" - for device in self.deleted_devices.values(): - if any(iden in device.identifiers for iden in identifiers) or any( - conn in device.connections for conn in connections - ): - return device + devices_index = self._devices_index[index] + for identifier in identifiers: + if identifier in devices_index[IDX_IDENTIFIERS]: + return devices_index[IDX_IDENTIFIERS][identifier] + if not connections: + return None + for connection in _normalize_connections(connections): + if connection in devices_index[IDX_CONNECTIONS]: + return devices_index[IDX_CONNECTIONS][connection] return None + def _add_device(self, device: Union[DeviceEntry, DeletedDeviceEntry]) -> None: + """Add a device and index it.""" + if isinstance(device, DeletedDeviceEntry): + devices_index = self._devices_index[DELETED_DEVICE] + self.deleted_devices[device.id] = device + else: + devices_index = self._devices_index[REGISTERED_DEVICE] + self.devices[device.id] = device + + _add_device_to_index(devices_index, device) + + def _remove_device(self, device: Union[DeviceEntry, DeletedDeviceEntry]) -> None: + """Remove a device and remove it from the index.""" + if isinstance(device, DeletedDeviceEntry): + devices_index = self._devices_index[DELETED_DEVICE] + self.deleted_devices.pop(device.id) + else: + devices_index = self._devices_index[REGISTERED_DEVICE] + self.devices.pop(device.id) + + _remove_device_from_index(devices_index, device) + + def _update_device(self, old_device: DeviceEntry, new_device: DeviceEntry) -> None: + """Update a device and the index.""" + self.devices[new_device.id] = new_device + + devices_index = self._devices_index[REGISTERED_DEVICE] + _remove_device_from_index(devices_index, old_device) + _add_device_to_index(devices_index, new_device) + + def _clear_index(self): + """Clear the index.""" + self._devices_index = { + REGISTERED_DEVICE: {IDX_IDENTIFIERS: {}, IDX_CONNECTIONS: {}}, + DELETED_DEVICE: {IDX_IDENTIFIERS: {}, IDX_CONNECTIONS: {}}, + } + + def _rebuild_index(self): + """Create the index after loading devices.""" + self._clear_index() + for device in self.devices.values(): + _add_device_to_index(self._devices_index[REGISTERED_DEVICE], device) + for device in self.deleted_devices.values(): + _add_device_to_index(self._devices_index[DELETED_DEVICE], device) + @callback def async_get_or_create( self, @@ -156,11 +222,8 @@ class DeviceRegistry: if connections is None: connections = set() - - connections = { - (key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value) - for key, value in connections - } + else: + connections = _normalize_connections(connections) device = self.async_get_device(identifiers, connections) @@ -169,9 +232,9 @@ class DeviceRegistry: if deleted_device is None: device = DeviceEntry(is_new=True) else: - self.deleted_devices.pop(deleted_device.id) + self._remove_device(deleted_device) device = deleted_device.to_device_entry() - self.devices[device.id] = device + self._add_device(device) if via_device is not None: via = self.async_get_device({via_device}, set()) @@ -301,7 +364,8 @@ class DeviceRegistry: if not changes: return old - new = self.devices[device_id] = attr.evolve(old, **changes) + new = attr.evolve(old, **changes) + self._update_device(old, new) self.async_schedule_save() self.hass.bus.async_fire( @@ -317,12 +381,15 @@ class DeviceRegistry: @callback def async_remove_device(self, device_id: str) -> None: """Remove a device from the device registry.""" - device = self.devices.pop(device_id) - self.deleted_devices[device_id] = DeletedDeviceEntry( - config_entries=device.config_entries, - connections=device.connections, - identifiers=device.identifiers, - id=device.id, + device = self.devices[device_id] + self._remove_device(device) + self._add_device( + DeletedDeviceEntry( + config_entries=device.config_entries, + connections=device.connections, + identifiers=device.identifiers, + id=device.id, + ) ) self.hass.bus.async_fire( EVENT_DEVICE_REGISTRY_UPDATED, {"action": "remove", "device_id": device_id} @@ -371,6 +438,7 @@ class DeviceRegistry: self.devices = devices self.deleted_devices = deleted_devices + self._rebuild_index() @callback def async_schedule_save(self) -> None: @@ -422,9 +490,11 @@ class DeviceRegistry: continue if config_entries == {config_entry_id}: # Permanently remove the device from the device registry. - del self.deleted_devices[deleted_device.id] + self._remove_device(deleted_device) else: config_entries = config_entries - {config_entry_id} + # No need to reindex here since we currently + # do not have a lookup by config entry self.deleted_devices[deleted_device.id] = attr.evolve( deleted_device, config_entries=config_entries ) @@ -536,3 +606,33 @@ def async_setup_cleanup(hass: HomeAssistantType, dev_reg: DeviceRegistry) -> Non await debounced_cleanup.async_call() hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, startup_clean) + + +def _normalize_connections(connections: set) -> set: + """Normalize connections to ensure we can match mac addresses.""" + return { + (key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value) + for key, value in connections + } + + +def _add_device_to_index( + devices_index: dict, device: Union[DeviceEntry, DeletedDeviceEntry] +) -> None: + """Add a device to the index.""" + for identifier in device.identifiers: + devices_index[IDX_IDENTIFIERS][identifier] = device.id + for connection in device.connections: + devices_index[IDX_CONNECTIONS][connection] = device.id + + +def _remove_device_from_index( + devices_index: dict, device: Union[DeviceEntry, DeletedDeviceEntry] +) -> None: + """Remove a device from the index.""" + for identifier in device.identifiers: + if identifier in devices_index[IDX_IDENTIFIERS]: + del devices_index[IDX_IDENTIFIERS][identifier] + for connection in device.connections: + if connection in devices_index[IDX_CONNECTIONS]: + del devices_index[IDX_CONNECTIONS][connection] diff --git a/tests/common.py b/tests/common.py index db060bc6b91..5fa2ba59ed1 100644 --- a/tests/common.py +++ b/tests/common.py @@ -371,6 +371,7 @@ def mock_device_registry(hass, mock_entries=None, mock_deleted_entries=None): registry = device_registry.DeviceRegistry(hass) registry.devices = mock_entries or OrderedDict() registry.deleted_devices = mock_deleted_entries or OrderedDict() + registry._rebuild_index() hass.data[device_registry.DATA_REGISTRY] = registry return registry diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index 82fadc35dd2..181a012807a 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -562,6 +562,21 @@ async def test_update(registry): assert updated_entry.identifiers == new_identifiers assert updated_entry.via_device_id == "98765B" + assert registry.async_get_device({("hue", "456")}, {}) is None + assert registry.async_get_device({("bla", "123")}, {}) is None + + assert registry.async_get_device({("hue", "654")}, {}) == updated_entry + assert registry.async_get_device({("bla", "321")}, {}) == updated_entry + + assert ( + registry.async_get_device( + {}, {(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")} + ) + == updated_entry + ) + + assert registry.async_get(updated_entry.id) is not None + async def test_update_remove_config_entries(hass, registry, update_events): """Make sure we do not get duplicate entries."""