diff --git a/homeassistant/components/device_tracker/config_entry.py b/homeassistant/components/device_tracker/config_entry.py index 18d769df07f..c83ca669d6d 100644 --- a/homeassistant/components/device_tracker/config_entry.py +++ b/homeassistant/components/device_tracker/config_entry.py @@ -16,12 +16,21 @@ from homeassistant.const import ( ) from homeassistant.core import Event, HomeAssistant, callback from homeassistant.helpers import device_registry as dr, entity_registry as er +from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.entity import DeviceInfo, Entity, EntityCategory from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_platform import EntityPlatform from homeassistant.helpers.typing import StateType -from .const import ATTR_HOST_NAME, ATTR_IP, ATTR_MAC, ATTR_SOURCE_TYPE, DOMAIN, LOGGER +from .const import ( + ATTR_HOST_NAME, + ATTR_IP, + ATTR_MAC, + ATTR_SOURCE_TYPE, + CONNECTED_DEVICE_REGISTERED, + DOMAIN, + LOGGER, +) async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: @@ -64,9 +73,33 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return await component.async_unload_entry(entry) +@callback +def _async_connected_device_registered( + hass: HomeAssistant, mac: str, ip_address: str | None, hostname: str | None +) -> None: + """Register a newly seen connected device. + + This is currently used by the dhcp integration + to listen for newly registered connected devices + for discovery. + """ + async_dispatcher_send( + hass, + CONNECTED_DEVICE_REGISTERED, + { + ATTR_IP: ip_address, + ATTR_MAC: mac, + ATTR_HOST_NAME: hostname, + }, + ) + + @callback def _async_register_mac( - hass: HomeAssistant, domain: str, mac: str, unique_id: str + hass: HomeAssistant, + domain: str, + mac: str, + unique_id: str, ) -> None: """Register a mac address with a unique ID.""" data_key = "device_tracker_mac" @@ -297,8 +330,18 @@ class ScannerEntity(BaseTrackerEntity): super().add_to_platform_start(hass, platform, parallel_updates) if self.mac_address and self.unique_id: _async_register_mac( - hass, platform.platform_name, self.mac_address, self.unique_id + hass, + platform.platform_name, + self.mac_address, + self.unique_id, ) + if self.is_connected: + _async_connected_device_registered( + hass, + self.mac_address, + self.ip_address, + self.hostname, + ) @callback def find_device_entry(self) -> dr.DeviceEntry | None: diff --git a/homeassistant/components/device_tracker/const.py b/homeassistant/components/device_tracker/const.py index 216255b9cb6..c52241ae51f 100644 --- a/homeassistant/components/device_tracker/const.py +++ b/homeassistant/components/device_tracker/const.py @@ -37,3 +37,5 @@ ATTR_MAC: Final = "mac" ATTR_SOURCE_TYPE: Final = "source_type" ATTR_CONSIDER_HOME: Final = "consider_home" ATTR_IP: Final = "ip" + +CONNECTED_DEVICE_REGISTERED: Final = "device_tracker_connected_device_registered" diff --git a/homeassistant/components/dhcp/__init__.py b/homeassistant/components/dhcp/__init__.py index ff67f77257b..a3de0e51708 100644 --- a/homeassistant/components/dhcp/__init__.py +++ b/homeassistant/components/dhcp/__init__.py @@ -1,6 +1,7 @@ """The dhcp integration.""" from __future__ import annotations +from abc import abstractmethod from dataclasses import dataclass from datetime import timedelta import fnmatch @@ -25,6 +26,7 @@ from homeassistant.components.device_tracker.const import ( ATTR_IP, ATTR_MAC, ATTR_SOURCE_TYPE, + CONNECTED_DEVICE_REGISTERED, DOMAIN as DEVICE_TRACKER_DOMAIN, SOURCE_TYPE_ROUTER, ) @@ -42,6 +44,7 @@ from homeassistant.helpers.device_registry import ( async_get, format_mac, ) +from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.event import ( async_track_state_added_domain, async_track_time_interval, @@ -109,16 +112,23 @@ class DhcpServiceInfo(BaseServiceInfo): async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the dhcp component.""" + watchers: list[WatcherBase] = [] + address_data: dict[str, dict[str, str]] = {} + integration_matchers = await async_get_dhcp(hass) + + # For the passive classes we need to start listening + # for state changes and connect the dispatchers before + # everything else starts up or we will miss events + for passive_cls in (DeviceTrackerRegisteredWatcher, DeviceTrackerWatcher): + passive_watcher = passive_cls(hass, address_data, integration_matchers) + await passive_watcher.async_start() + watchers.append(passive_watcher) async def _initialize(_): - address_data = {} - integration_matchers = await async_get_dhcp(hass) - watchers = [] - - for cls in (DHCPWatcher, DeviceTrackerWatcher, NetworkWatcher): - watcher = cls(hass, address_data, integration_matchers) - await watcher.async_start() - watchers.append(watcher) + for active_cls in (DHCPWatcher, NetworkWatcher): + active_watcher = active_cls(hass, address_data, integration_matchers) + await active_watcher.async_start() + watchers.append(active_watcher) async def _async_stop(*_): for watcher in watchers: @@ -141,6 +151,14 @@ class WatcherBase: self._integration_matchers = integration_matchers self._address_data = address_data + @abstractmethod + async def async_stop(self): + """Stop the watcher.""" + + @abstractmethod + async def async_start(self): + """Start the watcher.""" + def process_client(self, ip_address, hostname, mac_address): """Process a client.""" return run_callback_threadsafe( @@ -320,6 +338,39 @@ class DeviceTrackerWatcher(WatcherBase): self.async_process_client(ip_address, hostname, _format_mac(mac_address)) +class DeviceTrackerRegisteredWatcher(WatcherBase): + """Class to watch data from device tracker registrations.""" + + def __init__(self, hass, address_data, integration_matchers): + """Initialize class.""" + super().__init__(hass, address_data, integration_matchers) + self._unsub = None + + async def async_stop(self): + """Stop watching for device tracker registrations.""" + if self._unsub: + self._unsub() + self._unsub = None + + async def async_start(self): + """Stop watching for device tracker registrations.""" + self._unsub = async_dispatcher_connect( + self.hass, CONNECTED_DEVICE_REGISTERED, self._async_process_device_state + ) + + @callback + def _async_process_device_state(self, data: dict[str, Any]) -> None: + """Process a device tracker state.""" + ip_address = data.get(ATTR_IP) + hostname = data.get(ATTR_HOST_NAME, "") + mac_address = data.get(ATTR_MAC) + + if ip_address is None or mac_address is None: + return + + self.async_process_client(ip_address, hostname, _format_mac(mac_address)) + + class DHCPWatcher(WatcherBase): """Class to watch dhcp requests.""" diff --git a/tests/components/device_tracker/test_config_entry.py b/tests/components/device_tracker/test_config_entry.py index 3c8efad5b05..5134123074e 100644 --- a/tests/components/device_tracker/test_config_entry.py +++ b/tests/components/device_tracker/test_config_entry.py @@ -1,8 +1,15 @@ """Test Device Tracker config entry things.""" from homeassistant.components.device_tracker import DOMAIN, config_entry as ce +from homeassistant.core import callback from homeassistant.helpers import device_registry as dr, entity_registry as er +from homeassistant.helpers.dispatcher import async_dispatcher_connect -from tests.common import MockConfigEntry +from tests.common import ( + MockConfigEntry, + MockEntityPlatform, + MockPlatform, + mock_registry, +) def test_tracker_entity(): @@ -128,3 +135,87 @@ async def test_register_mac(hass): entity_entry_1 = ent_reg.async_get(entity_entry_1.entity_id) assert entity_entry_1.disabled_by is None + + +async def test_connected_device_registered(hass): + """Test dispatch on connected device being registered.""" + + registry = mock_registry(hass) + dispatches = [] + + @callback + def _save_dispatch(msg): + dispatches.append(msg) + + unsub = async_dispatcher_connect( + hass, ce.CONNECTED_DEVICE_REGISTERED, _save_dispatch + ) + + class MockScannerEntity(ce.ScannerEntity): + """Mock a scanner entity.""" + + @property + def ip_address(self) -> str: + return "5.4.3.2" + + @property + def unique_id(self) -> str: + return self.mac_address + + class MockDisconnectedScannerEntity(MockScannerEntity): + """Mock a disconnected scanner entity.""" + + @property + def mac_address(self) -> str: + return "aa:bb:cc:dd:ee:ff" + + @property + def is_connected(self) -> bool: + return True + + @property + def hostname(self) -> str: + return "connected" + + class MockConnectedScannerEntity(MockScannerEntity): + """Mock a disconnected scanner entity.""" + + @property + def mac_address(self) -> str: + return "aa:bb:cc:dd:ee:00" + + @property + def is_connected(self) -> bool: + return False + + @property + def hostname(self) -> str: + return "disconnected" + + async def async_setup_entry(hass, config_entry, async_add_entities): + """Mock setup entry method.""" + async_add_entities( + [MockConnectedScannerEntity(), MockDisconnectedScannerEntity()] + ) + return True + + platform = MockPlatform(async_setup_entry=async_setup_entry) + config_entry = MockConfigEntry(entry_id="super-mock-id") + entity_platform = MockEntityPlatform( + hass, platform_name=config_entry.domain, platform=platform + ) + + assert await entity_platform.async_setup_entry(config_entry) + await hass.async_block_till_done() + full_name = f"{entity_platform.domain}.{config_entry.domain}" + assert full_name in hass.config.components + assert len(hass.states.async_entity_ids()) == 0 # should be disabled + assert len(registry.entities) == 2 + assert ( + registry.entities["test_domain.test_aa_bb_cc_dd_ee_ff"].config_entry_id + == "super-mock-id" + ) + unsub() + assert dispatches == [ + {"ip": "5.4.3.2", "mac": "aa:bb:cc:dd:ee:ff", "host_name": "connected"} + ] diff --git a/tests/components/dhcp/test_init.py b/tests/components/dhcp/test_init.py index 3650ed32987..d1b8d72be67 100644 --- a/tests/components/dhcp/test_init.py +++ b/tests/components/dhcp/test_init.py @@ -16,6 +16,7 @@ from homeassistant.components.device_tracker.const import ( ATTR_IP, ATTR_MAC, ATTR_SOURCE_TYPE, + CONNECTED_DEVICE_REGISTERED, SOURCE_TYPE_ROUTER, ) from homeassistant.components.dhcp.const import DOMAIN @@ -26,6 +27,7 @@ from homeassistant.const import ( STATE_NOT_HOME, ) import homeassistant.helpers.device_registry as dr +from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util @@ -630,6 +632,37 @@ async def test_device_tracker_hostname_and_macaddress_exists_before_start(hass): ) +async def test_device_tracker_registered(hass): + """Test matching based on hostname and macaddress when registered.""" + with patch.object(hass.config_entries.flow, "async_init") as mock_init: + device_tracker_watcher = dhcp.DeviceTrackerRegisteredWatcher( + hass, + {}, + [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], + ) + await device_tracker_watcher.async_start() + await hass.async_block_till_done() + async_dispatcher_send( + hass, + CONNECTED_DEVICE_REGISTERED, + {"ip": "192.168.210.56", "mac": "b8b7f16db533", "host_name": "connect"}, + ) + await hass.async_block_till_done() + + assert len(mock_init.mock_calls) == 1 + assert mock_init.mock_calls[0][1][0] == "mock-domain" + assert mock_init.mock_calls[0][2]["context"] == { + "source": config_entries.SOURCE_DHCP + } + assert mock_init.mock_calls[0][2]["data"] == dhcp.DhcpServiceInfo( + ip="192.168.210.56", + hostname="connect", + macaddress="b8b7f16db533", + ) + await device_tracker_watcher.async_stop() + await hass.async_block_till_done() + + async def test_device_tracker_hostname_and_macaddress_after_start(hass): """Test matching based on hostname and macaddress after start."""