diff --git a/homeassistant/components/device_tracker/config_entry.py b/homeassistant/components/device_tracker/config_entry.py index 1be47b9b981..8cdc843f680 100644 --- a/homeassistant/components/device_tracker/config_entry.py +++ b/homeassistant/components/device_tracker/config_entry.py @@ -13,7 +13,7 @@ from homeassistant.const import ( from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_component import EntityComponent -from .const import ATTR_SOURCE_TYPE, DOMAIN, LOGGER +from .const import ATTR_HOST_NAME, ATTR_IP, ATTR_MAC, ATTR_SOURCE_TYPE, DOMAIN, LOGGER async def async_setup_entry(hass, entry): @@ -47,6 +47,21 @@ class BaseTrackerEntity(Entity): """Return the source type, eg gps or router, of the device.""" raise NotImplementedError + @property + def ip_address(self) -> str: + """Return the primary ip address of the device.""" + return None + + @property + def mac_address(self) -> str: + """Return the mac address of the device.""" + return None + + @property + def hostname(self) -> str: + """Return hostname of the device.""" + return None + @property def state_attributes(self): """Return the device state attributes.""" @@ -54,6 +69,12 @@ class BaseTrackerEntity(Entity): if self.battery_level: attr[ATTR_BATTERY_LEVEL] = self.battery_level + if self.ip_address is not None: + attr[ATTR_IP] = self.ip_address + if self.ip_address is not None: + attr[ATTR_MAC] = self.mac_address + if self.hostname is not None: + attr[ATTR_HOST_NAME] = self.hostname return attr diff --git a/homeassistant/components/device_tracker/const.py b/homeassistant/components/device_tracker/const.py index c9ce9f2024a..aa1b349ef12 100644 --- a/homeassistant/components/device_tracker/const.py +++ b/homeassistant/components/device_tracker/const.py @@ -34,3 +34,4 @@ ATTR_LOCATION_NAME = "location_name" ATTR_MAC = "mac" ATTR_SOURCE_TYPE = "source_type" ATTR_CONSIDER_HOME = "consider_home" +ATTR_IP = "ip" diff --git a/homeassistant/components/dhcp/__init__.py b/homeassistant/components/dhcp/__init__.py index b45add24ee9..a65548d6654 100644 --- a/homeassistant/components/dhcp/__init__.py +++ b/homeassistant/components/dhcp/__init__.py @@ -1,18 +1,32 @@ """The dhcp integration.""" +from abc import abstractmethod import fnmatch import logging import os -from threading import Event, Thread +import threading from scapy.error import Scapy_Exception from scapy.layers.dhcp import DHCP from scapy.layers.l2 import Ether from scapy.sendrecv import sniff -from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP -from homeassistant.core import HomeAssistant +from homeassistant.components.device_tracker.const import ( + ATTR_HOST_NAME, + ATTR_IP, + ATTR_MAC, + ATTR_SOURCE_TYPE, + DOMAIN as DEVICE_TRACKER_DOMAIN, + SOURCE_TYPE_ROUTER, +) +from homeassistant.const import ( + EVENT_HOMEASSISTANT_STARTED, + EVENT_HOMEASSISTANT_STOP, + STATE_HOME, +) +from homeassistant.core import Event, HomeAssistant, State, callback from homeassistant.helpers.device_registry import format_mac +from homeassistant.helpers.event import async_track_state_added_domain from homeassistant.loader import async_get_dhcp from .const import DOMAIN @@ -32,35 +46,162 @@ async def async_setup(hass: HomeAssistant, config: dict) -> bool: """Set up the dhcp component.""" async def _initialize(_): - dhcp_watcher = DHCPWatcher(hass, await async_get_dhcp(hass)) - dhcp_watcher.start() + address_data = {} + integration_matchers = await async_get_dhcp(hass) + watchers = [] - def _stop(*_): - dhcp_watcher.stop() - dhcp_watcher.join() + for cls in (DHCPWatcher, DeviceTrackerWatcher): + watcher = cls(hass, address_data, integration_matchers) + watcher.async_start() + watchers.append(watcher) - hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _stop) + async def _async_stop(*_): + for watcher in watchers: + if hasattr(watcher, "async_stop"): + watcher.async_stop() + else: + await hass.async_add_executor_job(watcher.stop) + + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_stop) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, _initialize) return True -class DHCPWatcher(Thread): - """Class to watch dhcp requests.""" +class WatcherBase: + """Base class for dhcp and device tracker watching.""" - def __init__(self, hass, integration_matchers): + def __init__(self, hass, address_data, integration_matchers): """Initialize class.""" super().__init__() self.hass = hass - self.name = "dhcp-discovery" self._integration_matchers = integration_matchers - self._address_data = {} - self._stop_event = Event() + self._address_data = address_data + + def process_client(self, ip_address, hostname, mac_address): + """Process a client.""" + data = self._address_data.get(ip_address) + + if data and data[MAC_ADDRESS] == mac_address and data[HOSTNAME] == hostname: + # If the address data is the same no need + # to process it + return + + self._address_data[ip_address] = {MAC_ADDRESS: mac_address, HOSTNAME: hostname} + + self.process_updated_address_data(ip_address, self._address_data[ip_address]) + + def process_updated_address_data(self, ip_address, data): + """Process the address data update.""" + lowercase_hostname = data[HOSTNAME].lower() + uppercase_mac = data[MAC_ADDRESS].upper() + + _LOGGER.debug( + "Processing updated address data for %s: mac=%s hostname=%s", + ip_address, + uppercase_mac, + lowercase_hostname, + ) + + for entry in self._integration_matchers: + if MAC_ADDRESS in entry and not fnmatch.fnmatch( + uppercase_mac, entry[MAC_ADDRESS] + ): + continue + + if HOSTNAME in entry and not fnmatch.fnmatch( + lowercase_hostname, entry[HOSTNAME] + ): + continue + + _LOGGER.debug("Matched %s against %s", data, entry) + + self.create_task( + self.hass.config_entries.flow.async_init( + entry["domain"], + context={"source": DOMAIN}, + data={IP_ADDRESS: ip_address, **data}, + ) + ) + + @abstractmethod + def create_task(self, task): + """Pass a task to async_add_task based on which context we are in.""" + + +class DeviceTrackerWatcher(WatcherBase): + """Class to watch dhcp data from routers.""" + + def __init__(self, hass, address_data, integration_matchers): + """Initialize class.""" + super().__init__(hass, address_data, integration_matchers) + self._unsub = None + + @callback + def async_stop(self): + """Stop watching for new device trackers.""" + if self._unsub: + self._unsub() + self._unsub = None + + @callback + def async_start(self): + """Stop watching for new device trackers.""" + self._unsub = async_track_state_added_domain( + self.hass, [DEVICE_TRACKER_DOMAIN], self._async_process_device_event + ) + for state in self.hass.states.async_all(DEVICE_TRACKER_DOMAIN): + self._async_process_device_state(state) + + @callback + def _async_process_device_event(self, event: Event): + """Process a device tracker state change event.""" + self._async_process_device_state(event.data.get("new_state")) + + @callback + def _async_process_device_state(self, state: State): + """Process a device tracker state.""" + if state.state != STATE_HOME: + return + + attributes = state.attributes + + if attributes.get(ATTR_SOURCE_TYPE) != SOURCE_TYPE_ROUTER: + return + + ip_address = attributes.get(ATTR_IP) + hostname = attributes.get(ATTR_HOST_NAME) + mac_address = attributes.get(ATTR_MAC) + + if ip_address is None or hostname is None or mac_address is None: + return + + self.process_client(ip_address, hostname, _format_mac(mac_address)) + + def create_task(self, task): + """Pass a task to async_create_task since we are in async context.""" + self.hass.async_create_task(task) + + +class DHCPWatcher(WatcherBase, threading.Thread): + """Class to watch dhcp requests.""" + + def __init__(self, hass, address_data, integration_matchers): + """Initialize class.""" + super().__init__(hass, address_data, integration_matchers) + self.name = "dhcp-discovery" + self._stop_event = threading.Event() def stop(self): """Stop the thread.""" self._stop_event.set() + self.join() + + @callback + def async_start(self): + """Start the thread.""" + self.start() def run(self): """Start watching for dhcp packets.""" @@ -98,49 +239,11 @@ class DHCPWatcher(Thread): if ip_address is None or hostname is None or mac_address is None: return - data = self._address_data.get(ip_address) + self.process_client(ip_address, hostname, mac_address) - if data and data[MAC_ADDRESS] == mac_address and data[HOSTNAME] == hostname: - # If the address data is the same no need - # to process it - return - - self._address_data[ip_address] = {MAC_ADDRESS: mac_address, HOSTNAME: hostname} - - self.process_updated_address_data(ip_address, self._address_data[ip_address]) - - def process_updated_address_data(self, ip_address, data): - """Process the address data update.""" - lowercase_hostname = data[HOSTNAME].lower() - uppercase_mac = data[MAC_ADDRESS].upper() - - _LOGGER.debug( - "Processing updated address data for %s: mac=%s hostname=%s", - ip_address, - uppercase_mac, - lowercase_hostname, - ) - - for entry in self._integration_matchers: - if MAC_ADDRESS in entry and not fnmatch.fnmatch( - uppercase_mac, entry[MAC_ADDRESS] - ): - continue - - if HOSTNAME in entry and not fnmatch.fnmatch( - lowercase_hostname, entry[HOSTNAME] - ): - continue - - _LOGGER.debug("Matched %s against %s", data, entry) - - self.hass.add_job( - self.hass.config_entries.flow.async_init( - entry["domain"], - context={"source": DOMAIN}, - data={IP_ADDRESS: ip_address, **data}, - ) - ) + def create_task(self, task): + """Pass a task to hass.add_job since we are in a thread.""" + self.hass.add_job(task) def _decode_dhcp_option(dhcp_options, key): diff --git a/homeassistant/components/unifi/device_tracker.py b/homeassistant/components/unifi/device_tracker.py index 9f7726e1ba1..22e4904ab8b 100644 --- a/homeassistant/components/unifi/device_tracker.py +++ b/homeassistant/components/unifi/device_tracker.py @@ -52,6 +52,7 @@ CLIENT_STATIC_ATTRIBUTES = [ "oui", ] + CLIENT_CONNECTED_ALL_ATTRIBUTES = CLIENT_CONNECTED_ATTRIBUTES + CLIENT_STATIC_ATTRIBUTES DEVICE_UPGRADED = (ACCESS_POINT_UPGRADED, GATEWAY_UPGRADED, SWITCH_UPGRADED) @@ -239,6 +240,21 @@ class UniFiClientTracker(UniFiClient, ScannerEntity): return attributes + @property + def ip_address(self) -> str: + """Return the primary ip address of the device.""" + return self.client.raw.get("ip") + + @property + def mac_address(self) -> str: + """Return the mac address of the device.""" + return self.client.raw.get("mac") + + @property + def hostname(self) -> str: + """Return hostname of the device.""" + return self.client.raw.get("hostname") + async def options_updated(self) -> None: """Config entry options are updated, remove entity if option is disabled.""" if not self.controller.option_track_clients: diff --git a/tests/components/device_tracker/test_entities.py b/tests/components/device_tracker/test_entities.py index a0b2553543d..6c8674f97b9 100644 --- a/tests/components/device_tracker/test_entities.py +++ b/tests/components/device_tracker/test_entities.py @@ -59,3 +59,6 @@ def test_base_tracker_entity(): assert entity.battery_level is None with pytest.raises(NotImplementedError): assert entity.state_attributes is None + assert entity.ip_address is None + assert entity.mac_address is None + assert entity.hostname is None diff --git a/tests/components/dhcp/test_init.py b/tests/components/dhcp/test_init.py index 04cbc401b08..b5724f4a303 100644 --- a/tests/components/dhcp/test_init.py +++ b/tests/components/dhcp/test_init.py @@ -7,7 +7,19 @@ from scapy.layers.dhcp import DHCP from scapy.layers.l2 import Ether from homeassistant.components import dhcp -from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP +from homeassistant.components.device_tracker.const import ( + ATTR_HOST_NAME, + ATTR_IP, + ATTR_MAC, + ATTR_SOURCE_TYPE, + SOURCE_TYPE_ROUTER, +) +from homeassistant.const import ( + EVENT_HOMEASSISTANT_STARTED, + EVENT_HOMEASSISTANT_STOP, + STATE_HOME, + STATE_NOT_HOME, +) from homeassistant.setup import async_setup_component from tests.common import mock_coro @@ -41,6 +53,7 @@ async def test_dhcp_match_hostname_and_macaddress(hass): """Test matching based on hostname and macaddress.""" dhcp_watcher = dhcp.DHCPWatcher( hass, + {}, [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], ) @@ -66,7 +79,7 @@ async def test_dhcp_match_hostname_and_macaddress(hass): async def test_dhcp_match_hostname(hass): """Test matching based on hostname only.""" dhcp_watcher = dhcp.DHCPWatcher( - hass, [{"domain": "mock-domain", "hostname": "connect"}] + hass, {}, [{"domain": "mock-domain", "hostname": "connect"}] ) packet = Ether(RAW_DHCP_REQUEST) @@ -89,7 +102,7 @@ async def test_dhcp_match_hostname(hass): async def test_dhcp_match_macaddress(hass): """Test matching based on macaddress only.""" dhcp_watcher = dhcp.DHCPWatcher( - hass, [{"domain": "mock-domain", "macaddress": "B8B7F1*"}] + hass, {}, [{"domain": "mock-domain", "macaddress": "B8B7F1*"}] ) packet = Ether(RAW_DHCP_REQUEST) @@ -112,7 +125,7 @@ async def test_dhcp_match_macaddress(hass): async def test_dhcp_nomatch(hass): """Test not matching based on macaddress only.""" dhcp_watcher = dhcp.DHCPWatcher( - hass, [{"domain": "mock-domain", "macaddress": "ABC123*"}] + hass, {}, [{"domain": "mock-domain", "macaddress": "ABC123*"}] ) packet = Ether(RAW_DHCP_REQUEST) @@ -128,7 +141,7 @@ async def test_dhcp_nomatch(hass): async def test_dhcp_nomatch_hostname(hass): """Test not matching based on hostname only.""" dhcp_watcher = dhcp.DHCPWatcher( - hass, [{"domain": "mock-domain", "hostname": "nomatch*"}] + hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}] ) packet = Ether(RAW_DHCP_REQUEST) @@ -144,7 +157,7 @@ async def test_dhcp_nomatch_hostname(hass): async def test_dhcp_nomatch_non_dhcp_packet(hass): """Test matching does not throw on a non-dhcp packet.""" dhcp_watcher = dhcp.DHCPWatcher( - hass, [{"domain": "mock-domain", "hostname": "nomatch*"}] + hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}] ) packet = Ether(b"") @@ -160,7 +173,7 @@ async def test_dhcp_nomatch_non_dhcp_packet(hass): async def test_dhcp_nomatch_non_dhcp_request_packet(hass): """Test nothing happens with the wrong message-type.""" dhcp_watcher = dhcp.DHCPWatcher( - hass, [{"domain": "mock-domain", "hostname": "nomatch*"}] + hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}] ) packet = Ether(RAW_DHCP_REQUEST) @@ -185,7 +198,7 @@ async def test_dhcp_nomatch_non_dhcp_request_packet(hass): async def test_dhcp_invalid_hostname(hass): """Test we ignore invalid hostnames.""" dhcp_watcher = dhcp.DHCPWatcher( - hass, [{"domain": "mock-domain", "hostname": "nomatch*"}] + hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}] ) packet = Ether(RAW_DHCP_REQUEST) @@ -210,7 +223,7 @@ async def test_dhcp_invalid_hostname(hass): async def test_dhcp_missing_hostname(hass): """Test we ignore missing hostnames.""" dhcp_watcher = dhcp.DHCPWatcher( - hass, [{"domain": "mock-domain", "hostname": "nomatch*"}] + hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}] ) packet = Ether(RAW_DHCP_REQUEST) @@ -235,7 +248,7 @@ async def test_dhcp_missing_hostname(hass): async def test_dhcp_invalid_option(hass): """Test we ignore invalid hostname option.""" dhcp_watcher = dhcp.DHCPWatcher( - hass, [{"domain": "mock-domain", "hostname": "nomatch*"}] + hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}] ) packet = Ether(RAW_DHCP_REQUEST) @@ -327,3 +340,167 @@ async def test_setup_fails_non_root(hass, caplog): await hass.async_block_till_done() wait_event.set() assert "Cannot watch for dhcp packets without root or CAP_NET_RAW" in caplog.text + + +async def test_device_tracker_hostname_and_macaddress_exists_before_start(hass): + """Test matching based on hostname and macaddress before start.""" + hass.states.async_set( + "device_tracker.august_connect", + STATE_HOME, + { + ATTR_HOST_NAME: "connect", + ATTR_IP: "192.168.210.56", + ATTR_SOURCE_TYPE: SOURCE_TYPE_ROUTER, + ATTR_MAC: "B8:B7:F1:6D:B5:33", + }, + ) + + with patch.object( + hass.config_entries.flow, "async_init", return_value=mock_coro() + ) as mock_init: + device_tracker_watcher = dhcp.DeviceTrackerWatcher( + hass, + {}, + [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], + ) + device_tracker_watcher.async_start() + await hass.async_block_till_done() + device_tracker_watcher.async_stop() + await hass.async_block_till_done() + + assert len(mock_init.mock_calls) == 1 + assert mock_init.mock_calls[0][1][0] == "mock-domain" + assert mock_init.mock_calls[0][2]["context"] == {"source": "dhcp"} + assert mock_init.mock_calls[0][2]["data"] == { + dhcp.IP_ADDRESS: "192.168.210.56", + dhcp.HOSTNAME: "connect", + dhcp.MAC_ADDRESS: "b8b7f16db533", + } + + +async def test_device_tracker_hostname_and_macaddress_after_start(hass): + """Test matching based on hostname and macaddress after start.""" + + with patch.object( + hass.config_entries.flow, "async_init", return_value=mock_coro() + ) as mock_init: + device_tracker_watcher = dhcp.DeviceTrackerWatcher( + hass, + {}, + [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], + ) + device_tracker_watcher.async_start() + await hass.async_block_till_done() + hass.states.async_set( + "device_tracker.august_connect", + STATE_HOME, + { + ATTR_HOST_NAME: "connect", + ATTR_IP: "192.168.210.56", + ATTR_SOURCE_TYPE: SOURCE_TYPE_ROUTER, + ATTR_MAC: "B8:B7:F1:6D:B5:33", + }, + ) + await hass.async_block_till_done() + device_tracker_watcher.async_stop() + await hass.async_block_till_done() + + assert len(mock_init.mock_calls) == 1 + assert mock_init.mock_calls[0][1][0] == "mock-domain" + assert mock_init.mock_calls[0][2]["context"] == {"source": "dhcp"} + assert mock_init.mock_calls[0][2]["data"] == { + dhcp.IP_ADDRESS: "192.168.210.56", + dhcp.HOSTNAME: "connect", + dhcp.MAC_ADDRESS: "b8b7f16db533", + } + + +async def test_device_tracker_hostname_and_macaddress_after_start_not_home(hass): + """Test matching based on hostname and macaddress after start but not home.""" + + with patch.object( + hass.config_entries.flow, "async_init", return_value=mock_coro() + ) as mock_init: + device_tracker_watcher = dhcp.DeviceTrackerWatcher( + hass, + {}, + [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], + ) + device_tracker_watcher.async_start() + await hass.async_block_till_done() + hass.states.async_set( + "device_tracker.august_connect", + STATE_NOT_HOME, + { + ATTR_HOST_NAME: "connect", + ATTR_IP: "192.168.210.56", + ATTR_SOURCE_TYPE: SOURCE_TYPE_ROUTER, + ATTR_MAC: "B8:B7:F1:6D:B5:33", + }, + ) + await hass.async_block_till_done() + device_tracker_watcher.async_stop() + await hass.async_block_till_done() + + assert len(mock_init.mock_calls) == 0 + + +async def test_device_tracker_hostname_and_macaddress_after_start_not_router(hass): + """Test matching based on hostname and macaddress after start but not router.""" + + with patch.object( + hass.config_entries.flow, "async_init", return_value=mock_coro() + ) as mock_init: + device_tracker_watcher = dhcp.DeviceTrackerWatcher( + hass, + {}, + [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], + ) + device_tracker_watcher.async_start() + await hass.async_block_till_done() + hass.states.async_set( + "device_tracker.august_connect", + STATE_HOME, + { + ATTR_HOST_NAME: "connect", + ATTR_IP: "192.168.210.56", + ATTR_SOURCE_TYPE: "something_else", + ATTR_MAC: "B8:B7:F1:6D:B5:33", + }, + ) + await hass.async_block_till_done() + device_tracker_watcher.async_stop() + await hass.async_block_till_done() + + assert len(mock_init.mock_calls) == 0 + + +async def test_device_tracker_hostname_and_macaddress_after_start_hostname_missing( + hass, +): + """Test matching based on hostname and macaddress after start but missing hostname.""" + + with patch.object( + hass.config_entries.flow, "async_init", return_value=mock_coro() + ) as mock_init: + device_tracker_watcher = dhcp.DeviceTrackerWatcher( + hass, + {}, + [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], + ) + device_tracker_watcher.async_start() + await hass.async_block_till_done() + hass.states.async_set( + "device_tracker.august_connect", + STATE_HOME, + { + ATTR_IP: "192.168.210.56", + ATTR_SOURCE_TYPE: SOURCE_TYPE_ROUTER, + ATTR_MAC: "B8:B7:F1:6D:B5:33", + }, + ) + await hass.async_block_till_done() + device_tracker_watcher.async_stop() + await hass.async_block_till_done() + + assert len(mock_init.mock_calls) == 0 diff --git a/tests/components/unifi/test_device_tracker.py b/tests/components/unifi/test_device_tracker.py index 6bfe8f44b5c..6936b2e0fb5 100644 --- a/tests/components/unifi/test_device_tracker.py +++ b/tests/components/unifi/test_device_tracker.py @@ -189,6 +189,10 @@ async def test_tracked_wireless_clients(hass): client_1 = hass.states.get("device_tracker.client_1") assert client_1.state == "home" + assert client_1.attributes["ip"] == "10.0.0.1" + assert client_1.attributes["mac"] == "00:00:00:00:00:01" + assert client_1.attributes["hostname"] == "client_1" + assert client_1.attributes["host_name"] == "client_1" # State change signalling works with events controller.api.websocket._data = {