diff --git a/homeassistant/components/asuswrt/device_tracker.py b/homeassistant/components/asuswrt/device_tracker.py index bb96cb184a1..c3afce88c18 100644 --- a/homeassistant/components/asuswrt/device_tracker.py +++ b/homeassistant/components/asuswrt/device_tracker.py @@ -4,11 +4,8 @@ from __future__ import annotations from homeassistant.components.device_tracker import SOURCE_TYPE_ROUTER from homeassistant.components.device_tracker.config_entry import ScannerEntity from homeassistant.config_entries import ConfigEntry -from homeassistant.const import ATTR_DEFAULT_NAME from homeassistant.core import HomeAssistant, callback -from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.entity import DeviceInfo from .const import DATA_ASUSWRT, DOMAIN from .router import AsusWrtRouter @@ -62,12 +59,6 @@ class AsusWrtDevice(ScannerEntity): self._device = device self._attr_unique_id = device.mac self._attr_name = device.name or DEFAULT_DEVICE_NAME - self._attr_device_info = DeviceInfo( - connections={(CONNECTION_NETWORK_MAC, device.mac)}, - default_model="ASUSWRT Tracked device", - ) - if device.name: - self._attr_device_info[ATTR_DEFAULT_NAME] = device.name @property def is_connected(self): diff --git a/homeassistant/components/device_tracker/__init__.py b/homeassistant/components/device_tracker/__init__.py index 035b1923c4c..24075ee1a7d 100644 --- a/homeassistant/components/device_tracker/__init__.py +++ b/homeassistant/components/device_tracker/__init__.py @@ -1,4 +1,6 @@ """Provide functionality to keep track of devices.""" +from __future__ import annotations + from homeassistant.const import ATTR_GPS_ACCURACY, STATE_HOME # noqa: F401 from homeassistant.core import HomeAssistant from homeassistant.helpers.typing import ConfigType diff --git a/homeassistant/components/device_tracker/config_entry.py b/homeassistant/components/device_tracker/config_entry.py index 97b79306e7b..096268c8fed 100644 --- a/homeassistant/components/device_tracker/config_entry.py +++ b/homeassistant/components/device_tracker/config_entry.py @@ -1,6 +1,7 @@ """Code to set up a device tracker platform using a config entry.""" from __future__ import annotations +import asyncio from typing import final from homeassistant.components import zone @@ -13,9 +14,11 @@ from homeassistant.const import ( STATE_HOME, STATE_NOT_HOME, ) -from homeassistant.core import HomeAssistant -from homeassistant.helpers.entity import Entity +from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.helpers import device_registry as dr, entity_registry as er +from homeassistant.helpers.entity import DeviceInfo, Entity, EntityCategory from homeassistant.helpers.entity_component import EntityComponent +from homeassistant.helpers.entity_platform import EntityPlatform from homeassistant.helpers.typing import StateType from .const import ATTR_HOST_NAME, ATTR_IP, ATTR_MAC, ATTR_SOURCE_TYPE, DOMAIN, LOGGER @@ -25,8 +28,32 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up an entry.""" component: EntityComponent | None = hass.data.get(DOMAIN) - if component is None: - component = hass.data[DOMAIN] = EntityComponent(LOGGER, DOMAIN, hass) + if component is not None: + return await component.async_setup_entry(entry) + + component = hass.data[DOMAIN] = EntityComponent(LOGGER, DOMAIN, hass) + + # Clean up old devices created by device tracker entities in the past. + # Can be removed after 2022.6 + ent_reg = er.async_get(hass) + dev_reg = dr.async_get(hass) + + devices_with_trackers = set() + devices_with_non_trackers = set() + + for entity in ent_reg.entities.values(): + if entity.device_id is None: + continue + + if entity.domain == DOMAIN: + devices_with_trackers.add(entity.device_id) + else: + devices_with_non_trackers.add(entity.device_id) + + for device_id in devices_with_trackers - devices_with_non_trackers: + for entity in er.async_entries_for_device(ent_reg, device_id, True): + ent_reg.async_update_entity(entity.entity_id, device_id=None) + dev_reg.async_remove_device(device_id) return await component.async_setup_entry(entry) @@ -37,9 +64,80 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return await component.async_unload_entry(entry) +@callback +def _async_register_mac( + hass: HomeAssistant, domain: str, mac: str, unique_id: str +) -> None: + """Register a mac address with a unique ID.""" + data_key = "device_tracker_mac" + mac = dr.format_mac(mac) + if data_key in hass.data: + hass.data[data_key][mac] = (domain, unique_id) + return + + # Setup listening. + + # dict mapping mac -> partial unique ID + data = hass.data[data_key] = {mac: (domain, unique_id)} + + @callback + def handle_device_event(ev: Event) -> None: + """Enable the online status entity for the mac of a newly created device.""" + # Only for new devices + if ev.data["action"] != "create": + return + + dev_reg = dr.async_get(hass) + device_entry = dev_reg.async_get(ev.data["device_id"]) + + if device_entry is None: + return + + # Check if device has a mac + mac = None + for conn in device_entry.connections: + if conn[0] == dr.CONNECTION_NETWORK_MAC: + mac = conn[1] + break + + if mac is None: + return + + # Check if we have an entity for this mac + if (unique_id := data.get(mac)) is None: + return + + ent_reg = er.async_get(hass) + entity_id = ent_reg.async_get_entity_id(DOMAIN, *unique_id) + + if entity_id is None: + return + + entity_entry = ent_reg.async_get(entity_id) + + if entity_entry is None: + return + + # Make sure entity has a config entry and was disabled by the + # default disable logic in the integration. + if ( + entity_entry.config_entry_id is None + or entity_entry.disabled_by != er.RegistryEntryDisabler.INTEGRATION + ): + return + + # Enable entity + ent_reg.async_update_entity(entity_id, disabled_by=None) + + hass.bus.async_listen(dr.EVENT_DEVICE_REGISTRY_UPDATED, handle_device_event) + + class BaseTrackerEntity(Entity): """Represent a tracked device.""" + _attr_device_info: None = None + _attr_entity_category = EntityCategory.DIAGNOSTIC + @property def battery_level(self) -> int | None: """Return the battery level of the device. @@ -164,6 +262,86 @@ class ScannerEntity(BaseTrackerEntity): """Return true if the device is connected to the network.""" raise NotImplementedError + @property + def unique_id(self) -> str | None: + """Return unique ID of the entity.""" + return self.mac_address + + @final + @property + def device_info(self) -> DeviceInfo | None: + """Device tracker entities should not create device registry entries.""" + return None + + @property + def entity_registry_enabled_default(self) -> bool: + """Return if entity is enabled by default.""" + # If mac_address is None, we can never find a device entry. + return ( + # Do not disable if we won't activate our attach to device logic + self.mac_address is None + or self.device_info is not None + # Disable if we automatically attach but there is no device + or self.find_device_entry() is not None + ) + + @callback + def add_to_platform_start( + self, + hass: HomeAssistant, + platform: EntityPlatform, + parallel_updates: asyncio.Semaphore | None, + ) -> None: + """Start adding an entity to a platform.""" + super().add_to_platform_start(hass, platform, parallel_updates) + if self.mac_address and self.unique_id: + _async_register_mac( + hass, platform.platform_name, self.mac_address, self.unique_id + ) + + @callback + def find_device_entry(self) -> dr.DeviceEntry | None: + """Return device entry.""" + assert self.mac_address is not None + + return dr.async_get(self.hass).async_get_device( + set(), {(dr.CONNECTION_NETWORK_MAC, self.mac_address)} + ) + + async def async_internal_added_to_hass(self) -> None: + """Handle added to Home Assistant.""" + # Entities without a unique ID don't have a device + if ( + not self.registry_entry + or not self.platform + or not self.platform.config_entry + or not self.mac_address + or (device_entry := self.find_device_entry()) is None + # Entities should not have a device info. We opt them out + # of this logic if they do. + or self.device_info + ): + if self.device_info: + LOGGER.debug("Entity %s unexpectedly has a device info", self.entity_id) + await super().async_internal_added_to_hass() + return + + # Attach entry to device + if self.registry_entry.device_id != device_entry.id: + self.registry_entry = er.async_get(self.hass).async_update_entity( + self.entity_id, device_id=device_entry.id + ) + + # Attach device to config entry + if self.platform.config_entry.entry_id not in device_entry.config_entries: + dr.async_get(self.hass).async_update_device( + device_entry.id, + add_config_entry_id=self.platform.config_entry.entry_id, + ) + + # Do this last or else the entity registry update listener has been installed + await super().async_internal_added_to_hass() + @final @property def state_attributes(self) -> dict[str, StateType]: diff --git a/homeassistant/components/freebox/device_tracker.py b/homeassistant/components/freebox/device_tracker.py index 38a781c8c12..f512eaf33aa 100644 --- a/homeassistant/components/freebox/device_tracker.py +++ b/homeassistant/components/freebox/device_tracker.py @@ -8,9 +8,7 @@ from homeassistant.components.device_tracker import SOURCE_TYPE_ROUTER from homeassistant.components.device_tracker.config_entry import ScannerEntity from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback -from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.entity import DeviceInfo from .const import DEFAULT_DEVICE_NAME, DEVICE_ICONS, DOMAIN from .router import FreeboxRouter @@ -82,7 +80,7 @@ class FreeboxDevice(ScannerEntity): self._attrs = device["attrs"] @property - def unique_id(self) -> str: + def mac_address(self) -> str: """Return a unique ID.""" return self._mac @@ -111,16 +109,6 @@ class FreeboxDevice(ScannerEntity): """Return the attributes.""" return self._attrs - @property - def device_info(self) -> DeviceInfo: - """Return the device information.""" - return DeviceInfo( - connections={(CONNECTION_NETWORK_MAC, self._mac)}, - identifiers={(DOMAIN, self.unique_id)}, - manufacturer=self._manufacturer, - name=self.name, - ) - @property def should_poll(self) -> bool: """No polling needed.""" diff --git a/homeassistant/components/fritz/common.py b/homeassistant/components/fritz/common.py index 68bea01ff3d..52b3d823651 100644 --- a/homeassistant/components/fritz/common.py +++ b/homeassistant/components/fritz/common.py @@ -519,21 +519,6 @@ class FritzDeviceBase(update_coordinator.CoordinatorEntity): return self._router.devices[self._mac].hostname return None - @property - def device_info(self) -> DeviceInfo: - """Return the device information.""" - return DeviceInfo( - connections={(CONNECTION_NETWORK_MAC, self._mac)}, - default_manufacturer="AVM", - default_model="FRITZ!Box Tracked device", - default_name=self.name, - identifiers={(DOMAIN, self._mac)}, - via_device=( - DOMAIN, - self._router.unique_id, - ), - ) - @property def should_poll(self) -> bool: """No polling needed.""" diff --git a/homeassistant/components/fritz/device_tracker.py b/homeassistant/components/fritz/device_tracker.py index 5a0cd71a728..1a5cff98904 100644 --- a/homeassistant/components/fritz/device_tracker.py +++ b/homeassistant/components/fritz/device_tracker.py @@ -130,6 +130,11 @@ class FritzBoxTracker(FritzDeviceBase, ScannerEntity): """Return device unique id.""" return f"{self._mac}_tracker" + @property + def mac_address(self) -> str: + """Return mac_address.""" + return self._mac + @property def icon(self) -> str: """Return device icon.""" diff --git a/homeassistant/components/fritz/switch.py b/homeassistant/components/fritz/switch.py index 2d0f9743d72..773b9c771b7 100644 --- a/homeassistant/components/fritz/switch.py +++ b/homeassistant/components/fritz/switch.py @@ -19,8 +19,9 @@ from homeassistant.components.network import async_get_source_ip from homeassistant.components.switch import SwitchEntity from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.entity import Entity, EntityCategory +from homeassistant.helpers.entity import DeviceInfo, Entity, EntityCategory from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.util import slugify @@ -605,6 +606,21 @@ class FritzBoxProfileSwitch(FritzDeviceBase, SwitchEntity): """Switch status.""" return self._router.devices[self._mac].wan_access + @property + def device_info(self) -> DeviceInfo: + """Return the device information.""" + return DeviceInfo( + connections={(CONNECTION_NETWORK_MAC, self._mac)}, + default_manufacturer="AVM", + default_model="FRITZ!Box Tracked device", + default_name=self.name, + identifiers={(DOMAIN, self._mac)}, + via_device=( + DOMAIN, + self._router.unique_id, + ), + ) + async def async_turn_on(self, **kwargs: Any) -> None: """Turn on switch.""" await self._async_handle_turn_on_off(turn_on=True) diff --git a/homeassistant/components/huawei_lte/__init__.py b/homeassistant/components/huawei_lte/__init__.py index b9738b9136a..f7dc1ca9fe8 100644 --- a/homeassistant/components/huawei_lte/__init__.py +++ b/homeassistant/components/huawei_lte/__init__.py @@ -653,14 +653,6 @@ class HuaweiLteBaseEntity(Entity): """Huawei LTE entities report their state without polling.""" return False - @property - def device_info(self) -> DeviceInfo: - """Get info for matching with parent router.""" - return DeviceInfo( - connections=self.router.device_connections, - identifiers=self.router.device_identifiers, - ) - async def async_update(self) -> None: """Update state.""" raise NotImplementedError @@ -681,3 +673,15 @@ class HuaweiLteBaseEntity(Entity): for unsub in self._unsub_handlers: unsub() self._unsub_handlers.clear() + + +class HuaweiLteBaseEntityWithDevice(HuaweiLteBaseEntity): + """Base entity with device info.""" + + @property + def device_info(self) -> DeviceInfo: + """Get info for matching with parent router.""" + return DeviceInfo( + connections=self.router.device_connections, + identifiers=self.router.device_identifiers, + ) diff --git a/homeassistant/components/huawei_lte/binary_sensor.py b/homeassistant/components/huawei_lte/binary_sensor.py index bc4e7ce33a2..9bd455d9d59 100644 --- a/homeassistant/components/huawei_lte/binary_sensor.py +++ b/homeassistant/components/huawei_lte/binary_sensor.py @@ -16,7 +16,7 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_platform import AddEntitiesCallback -from . import HuaweiLteBaseEntity +from . import HuaweiLteBaseEntityWithDevice from .const import ( DOMAIN, KEY_MONITORING_CHECK_NOTIFICATIONS, @@ -49,7 +49,7 @@ async def async_setup_entry( @dataclass -class HuaweiLteBaseBinarySensor(HuaweiLteBaseEntity, BinarySensorEntity): +class HuaweiLteBaseBinarySensor(HuaweiLteBaseEntityWithDevice, BinarySensorEntity): """Huawei LTE binary sensor device base class.""" key: str = field(init=False) diff --git a/homeassistant/components/huawei_lte/sensor.py b/homeassistant/components/huawei_lte/sensor.py index ad75e6f84ac..85f63ed2bc6 100644 --- a/homeassistant/components/huawei_lte/sensor.py +++ b/homeassistant/components/huawei_lte/sensor.py @@ -28,7 +28,7 @@ from homeassistant.helpers.entity import Entity, EntityCategory from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import StateType -from . import HuaweiLteBaseEntity +from . import HuaweiLteBaseEntityWithDevice from .const import ( DOMAIN, KEY_DEVICE_INFORMATION, @@ -523,7 +523,7 @@ def format_default(value: StateType) -> tuple[StateType, str | None]: @dataclass -class HuaweiLteSensor(HuaweiLteBaseEntity, SensorEntity): +class HuaweiLteSensor(HuaweiLteBaseEntityWithDevice, SensorEntity): """Huawei LTE sensor entity.""" key: str diff --git a/homeassistant/components/huawei_lte/switch.py b/homeassistant/components/huawei_lte/switch.py index 112e6c820e6..cc5e8e446c5 100644 --- a/homeassistant/components/huawei_lte/switch.py +++ b/homeassistant/components/huawei_lte/switch.py @@ -15,7 +15,7 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_platform import AddEntitiesCallback -from . import HuaweiLteBaseEntity +from . import HuaweiLteBaseEntityWithDevice from .const import DOMAIN, KEY_DIALUP_MOBILE_DATASWITCH _LOGGER = logging.getLogger(__name__) @@ -37,7 +37,7 @@ async def async_setup_entry( @dataclass -class HuaweiLteBaseSwitch(HuaweiLteBaseEntity, SwitchEntity): +class HuaweiLteBaseSwitch(HuaweiLteBaseEntityWithDevice, SwitchEntity): """Huawei LTE switch device base class.""" key: str = field(init=False) diff --git a/homeassistant/components/keenetic_ndms2/device_tracker.py b/homeassistant/components/keenetic_ndms2/device_tracker.py index 4ec353045c7..4db32cae079 100644 --- a/homeassistant/components/keenetic_ndms2/device_tracker.py +++ b/homeassistant/components/keenetic_ndms2/device_tracker.py @@ -24,9 +24,7 @@ from homeassistant.const import ( from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import entity_registry import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.entity import DeviceInfo import homeassistant.util.dt as dt_util from .const import ( @@ -217,15 +215,6 @@ class KeeneticTracker(ScannerEntity): } return None - @property - def device_info(self) -> DeviceInfo: - """Return a client description for device registry.""" - return DeviceInfo( - connections={(CONNECTION_NETWORK_MAC, self._device.mac)}, - identifiers={(DOMAIN, self._device.mac)}, - name=self._device.name if self._device.name else None, - ) - async def async_added_to_hass(self): """Client entity created.""" _LOGGER.debug("New network device tracker %s (%s)", self.name, self.unique_id) diff --git a/homeassistant/components/mikrotik/device_tracker.py b/homeassistant/components/mikrotik/device_tracker.py index 31a279aeee7..bab4459fec3 100644 --- a/homeassistant/components/mikrotik/device_tracker.py +++ b/homeassistant/components/mikrotik/device_tracker.py @@ -8,9 +8,7 @@ from homeassistant.components.device_tracker.const import ( ) from homeassistant.core import callback from homeassistant.helpers import entity_registry -from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.entity import DeviceInfo import homeassistant.util.dt as dt_util from .const import DOMAIN @@ -130,17 +128,6 @@ class MikrotikHubTracker(ScannerEntity): return {k: v for k, v in self.device.attrs.items() if k not in FILTER_ATTRS} return None - @property - def device_info(self) -> DeviceInfo: - """Return a client description for device registry.""" - # We only get generic info from device discovery and so don't want - # to override API specific info that integrations can provide - return DeviceInfo( - connections={(CONNECTION_NETWORK_MAC, self.device.mac)}, - default_name=self.name, - identifiers={(DOMAIN, self.device.mac)}, - ) - async def async_added_to_hass(self): """Client entity created.""" _LOGGER.debug("New network device tracker %s (%s)", self.name, self.unique_id) diff --git a/homeassistant/components/nmap_tracker/device_tracker.py b/homeassistant/components/nmap_tracker/device_tracker.py index d1115fa8934..04ae07e8715 100644 --- a/homeassistant/components/nmap_tracker/device_tracker.py +++ b/homeassistant/components/nmap_tracker/device_tracker.py @@ -22,9 +22,7 @@ from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry from homeassistant.const import CONF_EXCLUDE, CONF_HOSTS from homeassistant.core import HomeAssistant, callback import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.typing import ConfigType from . import NmapDevice, NmapDeviceScanner, short_hostname, signal_device_update @@ -169,15 +167,6 @@ class NmapTrackerEntity(ScannerEntity): """Return tracker source type.""" return SOURCE_TYPE_ROUTER - @property - def device_info(self) -> DeviceInfo: - """Return the device information.""" - return DeviceInfo( - connections={(CONNECTION_NETWORK_MAC, self._mac_address)}, - default_manufacturer=self._device.manufacturer, - default_name=self.name, - ) - @property def should_poll(self) -> bool: """No polling needed.""" diff --git a/homeassistant/components/ruckus_unleashed/device_tracker.py b/homeassistant/components/ruckus_unleashed/device_tracker.py index 6a923e05641..8a27ea3aa78 100644 --- a/homeassistant/components/ruckus_unleashed/device_tracker.py +++ b/homeassistant/components/ruckus_unleashed/device_tracker.py @@ -6,12 +6,9 @@ from homeassistant.components.device_tracker.config_entry import ScannerEntity from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import entity_registry -from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC -from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.update_coordinator import CoordinatorEntity from .const import ( - API_ACCESS_POINT, API_CLIENTS, API_NAME, COORDINATOR, @@ -93,8 +90,8 @@ class RuckusUnleashedDevice(CoordinatorEntity, ScannerEntity): self._name = name @property - def unique_id(self) -> str: - """Return a unique ID.""" + def mac_address(self) -> str: + """Return a mac address.""" return self._mac @property @@ -116,17 +113,3 @@ class RuckusUnleashedDevice(CoordinatorEntity, ScannerEntity): def source_type(self) -> str: """Return the source type.""" return SOURCE_TYPE_ROUTER - - @property - def device_info(self) -> DeviceInfo | None: - """Return the device information.""" - if self.is_connected: - return DeviceInfo( - name=self.name, - connections={(CONNECTION_NETWORK_MAC, self._mac)}, - via_device=( - CONNECTION_NETWORK_MAC, - self.coordinator.data[API_CLIENTS][self._mac][API_ACCESS_POINT], - ), - ) - return None diff --git a/homeassistant/components/unifi/device_tracker.py b/homeassistant/components/unifi/device_tracker.py index a27c4de2244..0b345a13b17 100644 --- a/homeassistant/components/unifi/device_tracker.py +++ b/homeassistant/components/unifi/device_tracker.py @@ -19,16 +19,12 @@ from homeassistant.components.device_tracker import DOMAIN from homeassistant.components.device_tracker.config_entry import ScannerEntity from homeassistant.components.device_tracker.const import SOURCE_TYPE_ROUTER from homeassistant.config_entries import ConfigEntry -from homeassistant.const import ATTR_NAME from homeassistant.core import HomeAssistant, callback -from homeassistant.helpers import device_registry as dr -from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback import homeassistant.util.dt as dt_util -from .const import ATTR_MANUFACTURER, DOMAIN as UNIFI_DOMAIN +from .const import DOMAIN as UNIFI_DOMAIN from .unifi_client import UniFiClient from .unifi_entity_base import UniFiBase @@ -242,6 +238,11 @@ class UniFiClientTracker(UniFiClient, ScannerEntity): self._is_connected = False self.async_write_ha_state() + @property + def device_info(self) -> None: + """Return no device info.""" + return None + @property def is_connected(self): """Return true if the client is connected to the network.""" @@ -365,13 +366,6 @@ class UniFiDeviceTracker(UniFiBase, ScannerEntity): self._is_connected = True self.schedule_update = True - elif ( - self.device.last_updated == SOURCE_EVENT - and self.device.event.event in DEVICE_UPGRADED - ): - self.hass.async_create_task(self.async_update_device_registry()) - return - if self.schedule_update: self.schedule_update = False self.controller.async_heartbeat( @@ -412,28 +406,6 @@ class UniFiDeviceTracker(UniFiBase, ScannerEntity): """Return if controller is available.""" return not self.device.disabled and self.controller.available - @property - def device_info(self) -> DeviceInfo: - """Return a device description for device registry.""" - info = DeviceInfo( - connections={(CONNECTION_NETWORK_MAC, self.device.mac)}, - manufacturer=ATTR_MANUFACTURER, - model=self.device.model, - sw_version=self.device.version, - ) - - if self.device.name: - info[ATTR_NAME] = self.device.name - - return info - - async def async_update_device_registry(self) -> None: - """Update device registry.""" - device_registry = dr.async_get(self.hass) - device_registry.async_get_or_create( - config_entry_id=self.controller.config_entry.entry_id, **self.device_info - ) - @property def extra_state_attributes(self): """Return the device state attributes.""" diff --git a/homeassistant/components/unifi/unifi_entity_base.py b/homeassistant/components/unifi/unifi_entity_base.py index 25e10ab13ec..1c3251d213c 100644 --- a/homeassistant/components/unifi/unifi_entity_base.py +++ b/homeassistant/components/unifi/unifi_entity_base.py @@ -6,7 +6,6 @@ from homeassistant.core import callback from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity import Entity -from homeassistant.helpers.entity_registry import async_entries_for_device _LOGGER = logging.getLogger(__name__) @@ -102,35 +101,10 @@ class UniFiBase(Entity): entity_registry.async_remove(self.entity_id) return - if ( - len( - entries_for_device := async_entries_for_device( - entity_registry, - entity_entry.device_id, - include_disabled_entities=True, - ) - ) - ) == 1: - device_registry.async_remove_device(device_entry.id) - return - - if ( - len( - entries_for_device_from_this_config_entry := [ - entry_for_device - for entry_for_device in entries_for_device - if entry_for_device.config_entry_id - == self.controller.config_entry.entry_id - ] - ) - != len(entries_for_device) - and len(entries_for_device_from_this_config_entry) == 1 - ): - device_registry.async_update_device( - entity_entry.device_id, - remove_config_entry_id=self.controller.config_entry.entry_id, - ) - + device_registry.async_update_device( + entity_entry.device_id, + remove_config_entry_id=self.controller.config_entry.entry_id, + ) entity_registry.async_remove(self.entity_id) @property diff --git a/homeassistant/components/zha/device_tracker.py b/homeassistant/components/zha/device_tracker.py index 4219bfd6288..78d43952f8a 100644 --- a/homeassistant/components/zha/device_tracker.py +++ b/homeassistant/components/zha/device_tracker.py @@ -1,4 +1,6 @@ """Support for the ZHA platform.""" +from __future__ import annotations + import functools import time @@ -8,6 +10,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import Platform from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.dispatcher import async_dispatcher_connect +from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback from .core import discovery @@ -103,3 +106,19 @@ class ZHADeviceScannerEntity(ScannerEntity, ZhaEntity): Percentage from 0-100. """ return self._battery_level + + @property + def device_info( # pylint: disable=overridden-final-method + self, + ) -> DeviceInfo | None: + """Return device info.""" + # We opt ZHA device tracker back into overriding this method because + # it doesn't track IP-based devices. + # Call Super because ScannerEntity overrode it. + return super(ZhaEntity, self).device_info + + @property + def unique_id(self) -> str | None: + """Return unique ID.""" + # Call Super because ScannerEntity overrode it. + return super(ZhaEntity, self).unique_id diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 8e5042db722..6bec4e592e6 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -372,7 +372,7 @@ class DeviceRegistry: ) entry_type = DeviceEntryType(entry_type) - device = self._async_update_device( + device = self.async_update_device( device.id, add_config_entry_id=config_entry_id, configuration_url=configuration_url, @@ -396,45 +396,6 @@ class DeviceRegistry: @callback def async_update_device( - self, - device_id: str, - *, - add_config_entry_id: str | UndefinedType = UNDEFINED, - area_id: str | None | UndefinedType = UNDEFINED, - configuration_url: str | None | UndefinedType = UNDEFINED, - disabled_by: DeviceEntryDisabler | None | UndefinedType = UNDEFINED, - manufacturer: str | None | UndefinedType = UNDEFINED, - model: str | None | UndefinedType = UNDEFINED, - name_by_user: str | None | UndefinedType = UNDEFINED, - name: str | None | UndefinedType = UNDEFINED, - new_identifiers: set[tuple[str, str]] | UndefinedType = UNDEFINED, - remove_config_entry_id: str | UndefinedType = UNDEFINED, - suggested_area: str | None | UndefinedType = UNDEFINED, - sw_version: str | None | UndefinedType = UNDEFINED, - hw_version: str | None | UndefinedType = UNDEFINED, - via_device_id: str | None | UndefinedType = UNDEFINED, - ) -> DeviceEntry | None: - """Update properties of a device.""" - return self._async_update_device( - device_id, - add_config_entry_id=add_config_entry_id, - area_id=area_id, - configuration_url=configuration_url, - disabled_by=disabled_by, - manufacturer=manufacturer, - model=model, - name_by_user=name_by_user, - name=name, - new_identifiers=new_identifiers, - remove_config_entry_id=remove_config_entry_id, - suggested_area=suggested_area, - sw_version=sw_version, - hw_version=hw_version, - via_device_id=via_device_id, - ) - - @callback - def _async_update_device( self, device_id: str, *, @@ -568,7 +529,7 @@ class DeviceRegistry: ) for other_device in list(self.devices.values()): if other_device.via_device_id == device_id: - self._async_update_device(other_device.id, via_device_id=None) + self.async_update_device(other_device.id, via_device_id=None) self.hass.bus.async_fire( EVENT_DEVICE_REGISTRY_UPDATED, {"action": "remove", "device_id": device_id} ) @@ -669,7 +630,7 @@ class DeviceRegistry: """Clear config entry from registry entries.""" now_time = time.time() for device in list(self.devices.values()): - self._async_update_device(device.id, remove_config_entry_id=config_entry_id) + self.async_update_device(device.id, remove_config_entry_id=config_entry_id) for deleted_device in list(self.deleted_devices.values()): config_entries = deleted_device.config_entries if config_entry_id not in config_entries: @@ -711,7 +672,7 @@ class DeviceRegistry: """Clear area id from registry entries.""" for dev_id, device in self.devices.items(): if area_id == device.area_id: - self._async_update_device(dev_id, area_id=None) + self.async_update_device(dev_id, area_id=None) @callback diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 7c49e884646..2a4e14ba050 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -335,7 +335,7 @@ class EntityRegistry: entity_id = self.async_get_entity_id(domain, platform, unique_id) if entity_id: - return self._async_update_entity( + return self.async_update_entity( entity_id, area_id=area_id or UNDEFINED, capabilities=capabilities or UNDEFINED, @@ -460,43 +460,6 @@ class EntityRegistry: @callback def async_update_entity( - self, - entity_id: str, - *, - area_id: str | None | UndefinedType = UNDEFINED, - config_entry_id: str | None | UndefinedType = UNDEFINED, - device_class: str | None | UndefinedType = UNDEFINED, - disabled_by: RegistryEntryDisabler | None | UndefinedType = UNDEFINED, - entity_category: str | None | UndefinedType = UNDEFINED, - icon: str | None | UndefinedType = UNDEFINED, - name: str | None | UndefinedType = UNDEFINED, - new_entity_id: str | UndefinedType = UNDEFINED, - new_unique_id: str | UndefinedType = UNDEFINED, - original_device_class: str | None | UndefinedType = UNDEFINED, - original_icon: str | None | UndefinedType = UNDEFINED, - original_name: str | None | UndefinedType = UNDEFINED, - unit_of_measurement: str | None | UndefinedType = UNDEFINED, - ) -> RegistryEntry: - """Update properties of an entity.""" - return self._async_update_entity( - entity_id, - area_id=area_id, - config_entry_id=config_entry_id, - device_class=device_class, - disabled_by=disabled_by, - entity_category=entity_category, - icon=icon, - name=name, - new_entity_id=new_entity_id, - new_unique_id=new_unique_id, - original_device_class=original_device_class, - original_icon=original_icon, - original_name=original_name, - unit_of_measurement=unit_of_measurement, - ) - - @callback - def _async_update_entity( self, entity_id: str, *, @@ -693,7 +656,7 @@ class EntityRegistry: """Clear area id from registry entries.""" for entity_id, entry in self.entities.items(): if area_id == entry.area_id: - self._async_update_entity(entity_id, area_id=None) + self.async_update_entity(entity_id, area_id=None) @callback diff --git a/tests/components/asuswrt/test_sensor.py b/tests/components/asuswrt/test_sensor.py index b8537a5e6a6..bfb62dae7e0 100644 --- a/tests/components/asuswrt/test_sensor.py +++ b/tests/components/asuswrt/test_sensor.py @@ -19,7 +19,7 @@ from homeassistant.const import ( STATE_HOME, STATE_NOT_HOME, ) -from homeassistant.helpers import entity_registry as er +from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.util import slugify from homeassistant.util.dt import utcnow @@ -41,6 +41,9 @@ MOCK_BYTES_TOTAL = [60000000000, 50000000000] MOCK_CURRENT_TRANSFER_RATES = [20000000, 10000000] MOCK_LOAD_AVG = [1.1, 1.2, 1.3] MOCK_TEMPERATURES = {"2.4GHz": 40, "5.0GHz": 0, "CPU": 71.2} +MOCK_MAC_1 = "a1:b1:c1:d1:e1:f1" +MOCK_MAC_2 = "a2:b2:c2:d2:e2:f2" +MOCK_MAC_3 = "a3:b3:c3:d3:e3:f3" SENSOR_NAMES = [ "Devices Connected", @@ -61,8 +64,8 @@ SENSOR_NAMES = [ def mock_devices_fixture(): """Mock a list of devices.""" return { - "a1:b1:c1:d1:e1:f1": Device("a1:b1:c1:d1:e1:f1", "192.168.1.2", "Test"), - "a2:b2:c2:d2:e2:f2": Device("a2:b2:c2:d2:e2:f2", "192.168.1.3", "TestTwo"), + MOCK_MAC_1: Device(MOCK_MAC_1, "192.168.1.2", "Test"), + MOCK_MAC_2: Device(MOCK_MAC_2, "192.168.1.3", "TestTwo"), } @@ -74,6 +77,26 @@ def mock_available_temps_list(): return [True, False] +@pytest.fixture(name="create_device_registry_devices") +def create_device_registry_devices_fixture(hass): + """Create device registry devices so the device tracker entities are enabled.""" + dev_reg = dr.async_get(hass) + config_entry = MockConfigEntry(domain="something_else") + + for idx, device in enumerate( + ( + MOCK_MAC_1, + MOCK_MAC_2, + MOCK_MAC_3, + ) + ): + dev_reg.async_get_or_create( + name=f"Device {idx}", + config_entry_id=config_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, device)}, + ) + + @pytest.fixture(name="connect") def mock_controller_connect(mock_devices, mock_available_temps): """Mock a successful connection.""" @@ -109,7 +132,13 @@ def mock_controller_connect(mock_devices, mock_available_temps): yield service_mock -async def test_sensors(hass, connect, mock_devices, mock_available_temps): +async def test_sensors( + hass, + connect, + mock_devices, + mock_available_temps, + create_device_registry_devices, +): """Test creating an AsusWRT sensor.""" entity_reg = er.async_get(hass) @@ -161,10 +190,8 @@ async def test_sensors(hass, connect, mock_devices, mock_available_temps): assert not hass.states.get(f"{sensor_prefix}_cpu_temperature") # add one device and remove another - mock_devices.pop("a1:b1:c1:d1:e1:f1") - mock_devices["a3:b3:c3:d3:e3:f3"] = Device( - "a3:b3:c3:d3:e3:f3", "192.168.1.4", "TestThree" - ) + mock_devices.pop(MOCK_MAC_1) + mock_devices[MOCK_MAC_3] = Device(MOCK_MAC_3, "192.168.1.4", "TestThree") async_fire_time_changed(hass, utcnow() + timedelta(seconds=30)) await hass.async_block_till_done() diff --git a/tests/components/device_tracker/test_config_entry.py b/tests/components/device_tracker/test_config_entry.py index 9b6a85cf8a0..3c8efad5b05 100644 --- a/tests/components/device_tracker/test_config_entry.py +++ b/tests/components/device_tracker/test_config_entry.py @@ -1,11 +1,14 @@ """Test Device Tracker config entry things.""" -from homeassistant.components.device_tracker import config_entry +from homeassistant.components.device_tracker import DOMAIN, config_entry as ce +from homeassistant.helpers import device_registry as dr, entity_registry as er + +from tests.common import MockConfigEntry def test_tracker_entity(): """Test tracker entity.""" - class TestEntry(config_entry.TrackerEntity): + class TestEntry(ce.TrackerEntity): """Mock tracker class.""" should_poll = False @@ -17,3 +20,111 @@ def test_tracker_entity(): instance.should_poll = True assert not instance.force_update + + +async def test_cleanup_legacy(hass, enable_custom_integrations): + """Test we clean up devices created by old device tracker.""" + dev_reg = dr.async_get(hass) + ent_reg = er.async_get(hass) + config_entry = MockConfigEntry(domain="test") + config_entry.add_to_hass(hass) + + device1 = dev_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, identifiers={(DOMAIN, "device1")} + ) + device2 = dev_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, identifiers={(DOMAIN, "device2")} + ) + device3 = dev_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, identifiers={(DOMAIN, "device3")} + ) + + # Device with light + device tracker entity + entity1a = ent_reg.async_get_or_create( + DOMAIN, + "test", + "entity1a-unique", + config_entry=config_entry, + device_id=device1.id, + ) + entity1b = ent_reg.async_get_or_create( + "light", + "test", + "entity1b-unique", + config_entry=config_entry, + device_id=device1.id, + ) + # Just device tracker entity + entity2a = ent_reg.async_get_or_create( + DOMAIN, + "test", + "entity2a-unique", + config_entry=config_entry, + device_id=device2.id, + ) + # Device with no device tracker entities + entity3a = ent_reg.async_get_or_create( + "light", + "test", + "entity3a-unique", + config_entry=config_entry, + device_id=device3.id, + ) + # Device tracker but no device + entity4a = ent_reg.async_get_or_create( + DOMAIN, + "test", + "entity4a-unique", + config_entry=config_entry, + ) + # Completely different entity + entity5a = ent_reg.async_get_or_create( + "light", + "test", + "entity4a-unique", + config_entry=config_entry, + ) + + await hass.config_entries.async_forward_entry_setup(config_entry, DOMAIN) + await hass.async_block_till_done() + + for entity in (entity1a, entity1b, entity3a, entity4a, entity5a): + assert ent_reg.async_get(entity.entity_id) is not None + + # We've removed device so device ID cleared + assert ent_reg.async_get(entity2a.entity_id).device_id is None + # Removed because only had device tracker entity + assert dev_reg.async_get(device2.id) is None + + +async def test_register_mac(hass): + """Test registering a mac.""" + dev_reg = dr.async_get(hass) + ent_reg = er.async_get(hass) + + config_entry = MockConfigEntry(domain="test") + config_entry.add_to_hass(hass) + + mac1 = "12:34:56:AB:CD:EF" + + entity_entry_1 = ent_reg.async_get_or_create( + "device_tracker", + "test", + mac1 + "yo1", + original_name="name 1", + config_entry=config_entry, + disabled_by=er.RegistryEntryDisabler.INTEGRATION, + ) + + ce._async_register_mac(hass, "test", mac1, mac1 + "yo1") + + dev_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, mac1)}, + ) + + await hass.async_block_till_done() + + entity_entry_1 = ent_reg.async_get(entity_entry_1.entity_id) + + assert entity_entry_1.disabled_by is None diff --git a/tests/components/device_tracker/test_entities.py b/tests/components/device_tracker/test_entities.py index 88e1dccdb34..12059cad601 100644 --- a/tests/components/device_tracker/test_entities.py +++ b/tests/components/device_tracker/test_entities.py @@ -14,25 +14,33 @@ from homeassistant.components.device_tracker.const import ( SOURCE_TYPE_ROUTER, ) from homeassistant.const import ATTR_BATTERY_LEVEL, STATE_HOME, STATE_NOT_HOME +from homeassistant.helpers import device_registry as dr from tests.common import MockConfigEntry async def test_scanner_entity_device_tracker(hass, enable_custom_integrations): """Test ScannerEntity based device tracker.""" + # Make device tied to other integration so device tracker entities get enabled + dr.async_get(hass).async_get_or_create( + name="Device from other integration", + config_entry_id=MockConfigEntry().entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, "ad:de:ef:be:ed:fe")}, + ) + config_entry = MockConfigEntry(domain="test") config_entry.add_to_hass(hass) await hass.config_entries.async_forward_entry_setup(config_entry, DOMAIN) await hass.async_block_till_done() - entity_id = "device_tracker.unnamed_device" + entity_id = "device_tracker.test_ad_de_ef_be_ed_fe" entity_state = hass.states.get(entity_id) assert entity_state.attributes == { ATTR_SOURCE_TYPE: SOURCE_TYPE_ROUTER, ATTR_BATTERY_LEVEL: 100, ATTR_IP: "0.0.0.0", - ATTR_MAC: "ad:de:ef:be:ed:fe:", + ATTR_MAC: "ad:de:ef:be:ed:fe", ATTR_HOST_NAME: "test.hostname.org", } assert entity_state.state == STATE_NOT_HOME diff --git a/tests/components/freebox/conftest.py b/tests/components/freebox/conftest.py index 3220552b6cf..2d9f0844115 100644 --- a/tests/components/freebox/conftest.py +++ b/tests/components/freebox/conftest.py @@ -3,6 +3,8 @@ from unittest.mock import AsyncMock, patch import pytest +from homeassistant.helpers import device_registry as dr + from .const import ( DATA_CALL_GET_CALLS_LOG, DATA_CONNECTION_GET_STATUS, @@ -12,6 +14,8 @@ from .const import ( WIFI_GET_GLOBAL_CONFIG, ) +from tests.common import MockConfigEntry + @pytest.fixture(autouse=True) def mock_path(): @@ -20,8 +24,30 @@ def mock_path(): yield +@pytest.fixture +def mock_device_registry_devices(hass): + """Create device registry devices so the device tracker entities are enabled.""" + dev_reg = dr.async_get(hass) + config_entry = MockConfigEntry(domain="something_else") + + for idx, device in enumerate( + ( + "68:A3:78:00:00:00", + "8C:97:EA:00:00:00", + "DE:00:B0:00:00:00", + "DC:00:B0:00:00:00", + "5E:65:55:00:00:00", + ) + ): + dev_reg.async_get_or_create( + name=f"Device {idx}", + config_entry_id=config_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, device)}, + ) + + @pytest.fixture(name="router") -def mock_router(): +def mock_router(mock_device_registry_devices): """Mock a successful connection.""" with patch("homeassistant.components.freebox.router.Freepybox") as service_mock: instance = service_mock.return_value diff --git a/tests/components/mikrotik/test_device_tracker.py b/tests/components/mikrotik/test_device_tracker.py index fcd29c18682..f36129c223a 100644 --- a/tests/components/mikrotik/test_device_tracker.py +++ b/tests/components/mikrotik/test_device_tracker.py @@ -1,9 +1,11 @@ """The tests for the Mikrotik device tracker platform.""" from datetime import timedelta +import pytest + from homeassistant.components import mikrotik import homeassistant.components.device_tracker as device_tracker -from homeassistant.helpers import entity_registry as er +from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util @@ -15,6 +17,25 @@ from tests.common import MockConfigEntry, patch DEFAULT_DETECTION_TIME = timedelta(seconds=300) +@pytest.fixture +def mock_device_registry_devices(hass): + """Create device registry devices so the device tracker entities are enabled.""" + dev_reg = dr.async_get(hass) + config_entry = MockConfigEntry(domain="something_else") + + for idx, device in enumerate( + ( + "00:00:00:00:00:01", + "00:00:00:00:00:02", + ) + ): + dev_reg.async_get_or_create( + name=f"Device {idx}", + config_entry_id=config_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, device)}, + ) + + def mock_command(self, cmd, params=None): """Mock the Mikrotik command method.""" if cmd == mikrotik.const.MIKROTIK_SERVICES[mikrotik.const.IS_WIRELESS]: @@ -39,7 +60,9 @@ async def test_platform_manually_configured(hass): assert mikrotik.DOMAIN not in hass.data -async def test_device_trackers(hass, legacy_patchable_time): +async def test_device_trackers( + hass, legacy_patchable_time, mock_device_registry_devices +): """Test device_trackers created by mikrotik.""" # test devices are added from wireless list only diff --git a/tests/components/ruckus_unleashed/__init__.py b/tests/components/ruckus_unleashed/__init__.py index eff80a0387a..5c50f845064 100644 --- a/tests/components/ruckus_unleashed/__init__.py +++ b/tests/components/ruckus_unleashed/__init__.py @@ -16,6 +16,7 @@ from homeassistant.components.ruckus_unleashed.const import ( API_VERSION, ) from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME +from homeassistant.helpers import device_registry as dr from tests.common import MockConfigEntry @@ -68,6 +69,13 @@ def mock_config_entry() -> MockConfigEntry: async def init_integration(hass) -> MockConfigEntry: """Set up the Ruckus Unleashed integration in Home Assistant.""" entry = mock_config_entry() + entry.add_to_hass(hass) + # Make device tied to other integration so device tracker entities get enabled + dr.async_get(hass).async_get_or_create( + name="Device from other integration", + config_entry_id=MockConfigEntry().entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, TEST_CLIENT[API_MAC])}, + ) with patch( "homeassistant.components.ruckus_unleashed.Ruckus.connect", return_value=None, @@ -86,7 +94,6 @@ async def init_integration(hass) -> MockConfigEntry: TEST_CLIENT[API_MAC]: TEST_CLIENT, }, ): - entry.add_to_hass(hass) await hass.config_entries.async_setup(entry.entry_id) await hass.async_block_till_done() diff --git a/tests/components/ruckus_unleashed/test_device_tracker.py b/tests/components/ruckus_unleashed/test_device_tracker.py index 92382007273..2c64bd3d0a8 100644 --- a/tests/components/ruckus_unleashed/test_device_tracker.py +++ b/tests/components/ruckus_unleashed/test_device_tracker.py @@ -3,10 +3,8 @@ from datetime import timedelta from unittest.mock import patch from homeassistant.components.ruckus_unleashed import API_MAC, DOMAIN -from homeassistant.components.ruckus_unleashed.const import API_AP, API_ID, API_NAME from homeassistant.const import STATE_HOME, STATE_NOT_HOME, STATE_UNAVAILABLE -from homeassistant.helpers import device_registry as dr, entity_registry as er -from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC +from homeassistant.helpers import entity_registry as er from homeassistant.util import utcnow from tests.common import async_fire_time_changed @@ -112,24 +110,3 @@ async def test_restoring_clients(hass): device = hass.states.get(TEST_CLIENT_ENTITY_ID) assert device is not None assert device.state == STATE_NOT_HOME - - -async def test_client_device_setup(hass): - """Test a client device is created.""" - await init_integration(hass) - - router_info = DEFAULT_AP_INFO[API_AP][API_ID]["1"] - - device_registry = dr.async_get(hass) - client_device = device_registry.async_get_device( - identifiers={}, - connections={(CONNECTION_NETWORK_MAC, TEST_CLIENT[API_MAC])}, - ) - router_device = device_registry.async_get_device( - identifiers={(CONNECTION_NETWORK_MAC, router_info[API_MAC])}, - connections={(CONNECTION_NETWORK_MAC, router_info[API_MAC])}, - ) - - assert client_device - assert client_device.name == TEST_CLIENT[API_NAME] - assert client_device.via_device_id == router_device.id diff --git a/tests/components/unifi/conftest.py b/tests/components/unifi/conftest.py index 42e9db6b958..e21c458386f 100644 --- a/tests/components/unifi/conftest.py +++ b/tests/components/unifi/conftest.py @@ -6,6 +6,10 @@ from unittest.mock import patch from aiounifi.websocket import SIGNAL_CONNECTION_STATE, SIGNAL_DATA import pytest +from homeassistant.helpers import device_registry as dr + +from tests.common import MockConfigEntry + @pytest.fixture(autouse=True) def mock_unifi_websocket(): @@ -34,3 +38,27 @@ def mock_discovery(): return_value=None, ) as mock: yield mock + + +@pytest.fixture +def mock_device_registry(hass): + """Mock device registry.""" + dev_reg = dr.async_get(hass) + config_entry = MockConfigEntry(domain="something_else") + + for idx, device in enumerate( + ( + "00:00:00:00:00:01", + "00:00:00:00:00:02", + "00:00:00:00:00:03", + "00:00:00:00:00:04", + "00:00:00:00:00:05", + "00:00:00:00:01:01", + "00:00:00:00:02:02", + ) + ): + dev_reg.async_get_or_create( + name=f"Device {idx}", + config_entry_id=config_entry.entry_id, + connections={(dr.CONNECTION_NETWORK_MAC, device)}, + ) diff --git a/tests/components/unifi/test_controller.py b/tests/components/unifi/test_controller.py index 738cb28e1e3..81990c42231 100644 --- a/tests/components/unifi/test_controller.py +++ b/tests/components/unifi/test_controller.py @@ -346,7 +346,9 @@ async def test_reset_fails(hass, aioclient_mock): assert result is False -async def test_connection_state_signalling(hass, aioclient_mock, mock_unifi_websocket): +async def test_connection_state_signalling( + hass, aioclient_mock, mock_unifi_websocket, mock_device_registry +): """Verify connection statesignalling and connection state are working.""" client = { "hostname": "client", diff --git a/tests/components/unifi/test_device_tracker.py b/tests/components/unifi/test_device_tracker.py index 4014062ee27..d7d29db1be9 100644 --- a/tests/components/unifi/test_device_tracker.py +++ b/tests/components/unifi/test_device_tracker.py @@ -38,7 +38,9 @@ async def test_no_entities(hass, aioclient_mock): assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 0 -async def test_tracked_wireless_clients(hass, aioclient_mock, mock_unifi_websocket): +async def test_tracked_wireless_clients( + hass, aioclient_mock, mock_unifi_websocket, mock_device_registry +): """Verify tracking of wireless clients.""" client = { "ap_mac": "00:00:00:00:02:01", @@ -157,7 +159,9 @@ async def test_tracked_wireless_clients(hass, aioclient_mock, mock_unifi_websock assert hass.states.get("device_tracker.client").state == STATE_HOME -async def test_tracked_clients(hass, aioclient_mock, mock_unifi_websocket): +async def test_tracked_clients( + hass, aioclient_mock, mock_unifi_websocket, mock_device_registry +): """Test the update_items function with some clients.""" client_1 = { "ap_mac": "00:00:00:00:02:01", @@ -234,7 +238,9 @@ async def test_tracked_clients(hass, aioclient_mock, mock_unifi_websocket): assert hass.states.get("device_tracker.client_1").state == STATE_HOME -async def test_tracked_devices(hass, aioclient_mock, mock_unifi_websocket): +async def test_tracked_devices( + hass, aioclient_mock, mock_unifi_websocket, mock_device_registry +): """Test the update_items function with some devices.""" device_1 = { "board_rev": 3, @@ -321,45 +327,10 @@ async def test_tracked_devices(hass, aioclient_mock, mock_unifi_websocket): assert hass.states.get("device_tracker.device_1").state == STATE_UNAVAILABLE assert hass.states.get("device_tracker.device_2").state == STATE_HOME - # Update device registry when device is upgraded - event = { - "_id": "5eae7fe02ab79c00f9d38960", - "datetime": "2020-05-09T20:06:37Z", - "key": "EVT_SW_Upgraded", - "msg": f'Switch[{device_2["mac"]}] was upgraded from "{device_2["version"]}" to "4.3.13.11253"', - "subsystem": "lan", - "sw": device_2["mac"], - "sw_name": device_2["name"], - "time": 1589054797635, - "version_from": {device_2["version"]}, - "version_to": "4.3.13.11253", - } - - device_2["version"] = event["version_to"] - mock_unifi_websocket( - data={ - "meta": {"message": MESSAGE_DEVICE}, - "data": [device_2], - } - ) - mock_unifi_websocket( - data={ - "meta": {"message": MESSAGE_EVENT}, - "data": [event], - } - ) - await hass.async_block_till_done() - - # Verify device registry has been updated - entity_registry = er.async_get(hass) - entry = entity_registry.async_get("device_tracker.device_2") - device_registry = dr.async_get(hass) - device = device_registry.async_get(entry.device_id) - assert device.sw_version == event["version_to"] - - -async def test_remove_clients(hass, aioclient_mock, mock_unifi_websocket): +async def test_remove_clients( + hass, aioclient_mock, mock_unifi_websocket, mock_device_registry +): """Test the remove_items function with some clients.""" client_1 = { "essid": "ssid", @@ -399,7 +370,7 @@ async def test_remove_clients(hass, aioclient_mock, mock_unifi_websocket): async def test_remove_client_but_keep_device_entry( - hass, aioclient_mock, mock_unifi_websocket + hass, aioclient_mock, mock_unifi_websocket, mock_device_registry ): """Test that unifi entity base remove config entry id from a multi integration device registry entry.""" client_1 = { @@ -424,7 +395,7 @@ async def test_remove_client_but_keep_device_entry( "unique_id", device_id=device_entry.id, ) - assert len(device_entry.config_entries) == 2 + assert len(device_entry.config_entries) == 3 mock_unifi_websocket( data={ @@ -438,10 +409,12 @@ async def test_remove_client_but_keep_device_entry( assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 0 device_entry = device_registry.async_get(other_entity.device_id) - assert len(device_entry.config_entries) == 1 + assert len(device_entry.config_entries) == 2 -async def test_controller_state_change(hass, aioclient_mock, mock_unifi_websocket): +async def test_controller_state_change( + hass, aioclient_mock, mock_unifi_websocket, mock_device_registry +): """Verify entities state reflect on controller becoming unavailable.""" client = { "essid": "ssid", @@ -495,7 +468,7 @@ async def test_controller_state_change(hass, aioclient_mock, mock_unifi_websocke async def test_controller_state_change_client_to_listen_on_all_state_changes( - hass, aioclient_mock, mock_unifi_websocket + hass, aioclient_mock, mock_unifi_websocket, mock_device_registry ): """Verify entities state reflect on controller becoming unavailable.""" client = { @@ -579,7 +552,7 @@ async def test_controller_state_change_client_to_listen_on_all_state_changes( assert hass.states.get("device_tracker.client").state == STATE_HOME -async def test_option_track_clients(hass, aioclient_mock): +async def test_option_track_clients(hass, aioclient_mock, mock_device_registry): """Test the tracking of clients can be turned off.""" wireless_client = { "essid": "ssid", @@ -645,7 +618,7 @@ async def test_option_track_clients(hass, aioclient_mock): assert hass.states.get("device_tracker.device") -async def test_option_track_wired_clients(hass, aioclient_mock): +async def test_option_track_wired_clients(hass, aioclient_mock, mock_device_registry): """Test the tracking of wired clients can be turned off.""" wireless_client = { "essid": "ssid", @@ -711,7 +684,7 @@ async def test_option_track_wired_clients(hass, aioclient_mock): assert hass.states.get("device_tracker.device") -async def test_option_track_devices(hass, aioclient_mock): +async def test_option_track_devices(hass, aioclient_mock, mock_device_registry): """Test the tracking of devices can be turned off.""" client = { "hostname": "client", @@ -764,7 +737,9 @@ async def test_option_track_devices(hass, aioclient_mock): assert hass.states.get("device_tracker.device") -async def test_option_ssid_filter(hass, aioclient_mock, mock_unifi_websocket): +async def test_option_ssid_filter( + hass, aioclient_mock, mock_unifi_websocket, mock_device_registry +): """Test the SSID filter works. Client will travel from a supported SSID to an unsupported ssid. @@ -896,7 +871,7 @@ async def test_option_ssid_filter(hass, aioclient_mock, mock_unifi_websocket): async def test_wireless_client_go_wired_issue( - hass, aioclient_mock, mock_unifi_websocket + hass, aioclient_mock, mock_unifi_websocket, mock_device_registry ): """Test the solution to catch wireless device go wired UniFi issue. @@ -979,7 +954,9 @@ async def test_wireless_client_go_wired_issue( assert client_state.attributes["is_wired"] is False -async def test_option_ignore_wired_bug(hass, aioclient_mock, mock_unifi_websocket): +async def test_option_ignore_wired_bug( + hass, aioclient_mock, mock_unifi_websocket, mock_device_registry +): """Test option to ignore wired bug.""" client = { "ap_mac": "00:00:00:00:02:01", @@ -1061,7 +1038,7 @@ async def test_option_ignore_wired_bug(hass, aioclient_mock, mock_unifi_websocke assert client_state.attributes["is_wired"] is False -async def test_restoring_client(hass, aioclient_mock): +async def test_restoring_client(hass, aioclient_mock, mock_device_registry): """Verify clients are restored from clients_all if they ever was registered to entity registry.""" client = { "hostname": "client", @@ -1115,7 +1092,7 @@ async def test_restoring_client(hass, aioclient_mock): assert not hass.states.get("device_tracker.not_restored") -async def test_dont_track_clients(hass, aioclient_mock): +async def test_dont_track_clients(hass, aioclient_mock, mock_device_registry): """Test don't track clients config works.""" wireless_client = { "essid": "ssid", @@ -1175,7 +1152,7 @@ async def test_dont_track_clients(hass, aioclient_mock): assert hass.states.get("device_tracker.device") -async def test_dont_track_devices(hass, aioclient_mock): +async def test_dont_track_devices(hass, aioclient_mock, mock_device_registry): """Test don't track devices config works.""" client = { "hostname": "client", @@ -1224,7 +1201,7 @@ async def test_dont_track_devices(hass, aioclient_mock): assert hass.states.get("device_tracker.device") -async def test_dont_track_wired_clients(hass, aioclient_mock): +async def test_dont_track_wired_clients(hass, aioclient_mock, mock_device_registry): """Test don't track wired clients config works.""" wireless_client = { "essid": "ssid", diff --git a/tests/components/unifi/test_switch.py b/tests/components/unifi/test_switch.py index 38231ef9609..a5ca1c0ee6a 100644 --- a/tests/components/unifi/test_switch.py +++ b/tests/components/unifi/test_switch.py @@ -5,7 +5,6 @@ from unittest.mock import patch from aiounifi.controller import MESSAGE_CLIENT_REMOVED, MESSAGE_EVENT from homeassistant import config_entries, core -from homeassistant.components.device_tracker import DOMAIN as TRACKER_DOMAIN from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN from homeassistant.components.unifi.const import ( CONF_BLOCK_CLIENT, @@ -784,8 +783,6 @@ async def test_ignore_multiple_poe_clients_on_same_port(hass, aioclient_mock): devices_response=[DEVICE_1], ) - assert len(hass.states.async_entity_ids(TRACKER_DOMAIN)) == 3 - switch_1 = hass.states.get("switch.poe_client_1") switch_2 = hass.states.get("switch.poe_client_2") assert switch_1 is None diff --git a/tests/testing_config/custom_components/test/device_tracker.py b/tests/testing_config/custom_components/test/device_tracker.py index e4853d156ce..d5f34f48ec8 100644 --- a/tests/testing_config/custom_components/test/device_tracker.py +++ b/tests/testing_config/custom_components/test/device_tracker.py @@ -18,7 +18,7 @@ class MockScannerEntity(ScannerEntity): self.connected = False self._hostname = "test.hostname.org" self._ip_address = "0.0.0.0" - self._mac_address = "ad:de:ef:be:ed:fe:" + self._mac_address = "ad:de:ef:be:ed:fe" @property def source_type(self):