"""The dhcp integration."""
from abc import ABC, abstractmethod
import fnmatch
from ipaddress import ip_address as make_ip_address
import logging
import os
import threading

from scapy.config import conf
from scapy.data import ETH_P_ALL
from scapy.error import Scapy_Exception
from scapy.layers.dhcp import DHCP
from scapy.layers.l2 import Ether
from scapy.sendrecv import AsyncSniffer

from homeassistant.components.device_tracker.const import (
    ATTR_HOST_NAME,
    ATTR_IP,
    ATTR_MAC,
    ATTR_SOURCE_TYPE,
    DOMAIN as DEVICE_TRACKER_DOMAIN,
    SOURCE_TYPE_ROUTER,
)
from homeassistant.const import (
    EVENT_HOMEASSISTANT_STARTED,
    EVENT_HOMEASSISTANT_STOP,
    STATE_HOME,
)
from homeassistant.core import Event, HomeAssistant, State, callback
from homeassistant.helpers.device_registry import format_mac
from homeassistant.helpers.event import async_track_state_added_domain
from homeassistant.loader import async_get_dhcp
from homeassistant.util.network import is_link_local

from .const import DOMAIN

FILTER = "udp and (port 67 or 68)"
REQUESTED_ADDR = "requested_addr"
MESSAGE_TYPE = "message-type"
HOSTNAME = "hostname"
MAC_ADDRESS = "macaddress"
IP_ADDRESS = "ip"
DHCP_REQUEST = 3

_LOGGER = logging.getLogger(__name__)


async def async_setup(hass: HomeAssistant, config: dict) -> bool:
    """Set up the dhcp component.

    The watchers are created once EVENT_HOMEASSISTANT_STARTED fires and are
    stopped again on EVENT_HOMEASSISTANT_STOP.
    """

    async def _initialize(_):
        # Both watchers share one address cache, so a client reported by
        # either source is only re-processed when its mac/hostname changes.
        address_data = {}
        integration_matchers = await async_get_dhcp(hass)
        watchers = []

        for cls in (DHCPWatcher, DeviceTrackerWatcher):
            watcher = cls(hass, address_data, integration_matchers)
            await watcher.async_start()
            watchers.append(watcher)

        async def _async_stop(*_):
            for watcher in watchers:
                await watcher.async_stop()

        hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_stop)

    hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, _initialize)
    return True


class WatcherBase(ABC):
    """Base class for dhcp and device tracker watching.

    Subclasses feed (ip, hostname, mac) observations into ``process_client``.
    Every integration matcher that matches starts a config flow with
    ``source`` set to ``DOMAIN``.
    """

    def __init__(self, hass, address_data, integration_matchers):
        """Initialize class.

        ``address_data`` is a dict (shared between watchers) keyed by IP
        address, holding the last seen mac address and hostname.
        ``integration_matchers`` are the matchers from ``async_get_dhcp``;
        each has a ``domain`` and optional ``macaddress`` / ``hostname``
        fnmatch patterns.
        """
        super().__init__()
        self.hass = hass
        self._integration_matchers = integration_matchers
        self._address_data = address_data

    def process_client(self, ip_address, hostname, mac_address):
        """Process a client."""
        try:
            made_ip_address = make_ip_address(ip_address)
        except ValueError:
            # Not a valid IP address; ignore it instead of raising. The dhcp
            # watcher calls this from the sniffer's packet callback, where an
            # exception would presumably stop sniffing.
            _LOGGER.debug(
                "Ignoring invalid ip address %s for %s", ip_address, mac_address
            )
            return

        if is_link_local(made_ip_address):
            # Ignore self assigned addresses
            return

        data = self._address_data.get(ip_address)

        if data and data[MAC_ADDRESS] == mac_address and data[HOSTNAME] == hostname:
            # If the address data is the same no need
            # to process it
            return

        self._address_data[ip_address] = {MAC_ADDRESS: mac_address, HOSTNAME: hostname}

        self.process_updated_address_data(ip_address, self._address_data[ip_address])

    def process_updated_address_data(self, ip_address, data):
        """Process the address data update.

        Hostnames are matched lowercased and mac addresses uppercased against
        the matcher patterns; every matching entry starts a config flow.
        """
        lowercase_hostname = data[HOSTNAME].lower()
        uppercase_mac = data[MAC_ADDRESS].upper()

        _LOGGER.debug(
            "Processing updated address data for %s: mac=%s hostname=%s",
            ip_address,
            uppercase_mac,
            lowercase_hostname,
        )

        for entry in self._integration_matchers:
            if MAC_ADDRESS in entry and not fnmatch.fnmatch(
                uppercase_mac, entry[MAC_ADDRESS]
            ):
                continue

            if HOSTNAME in entry and not fnmatch.fnmatch(
                lowercase_hostname, entry[HOSTNAME]
            ):
                continue

            _LOGGER.debug("Matched %s against %s", data, entry)

            self.create_task(
                self.hass.config_entries.flow.async_init(
                    entry["domain"],
                    context={"source": DOMAIN},
                    data={IP_ADDRESS: ip_address, **data},
                )
            )

    @abstractmethod
    def create_task(self, task):
        """Pass a task to async_add_task based on which context we are in."""


class DeviceTrackerWatcher(WatcherBase):
    """Class to watch dhcp data from routers."""

    def __init__(self, hass, address_data, integration_matchers):
        """Initialize class."""
        super().__init__(hass, address_data, integration_matchers)
        self._unsub = None

    async def async_stop(self):
        """Stop watching for new device trackers."""
        if self._unsub:
            self._unsub()
            self._unsub = None

    async def async_start(self):
        """Start watching for new device trackers."""
        self._unsub = async_track_state_added_domain(
            self.hass, [DEVICE_TRACKER_DOMAIN], self._async_process_device_event
        )
        # Also process the device trackers already present right now.
        for state in self.hass.states.async_all(DEVICE_TRACKER_DOMAIN):
            self._async_process_device_state(state)

    @callback
    def _async_process_device_event(self, event: Event):
        """Process a device tracker state change event."""
        state = event.data.get("new_state")
        if state is None:
            return
        self._async_process_device_state(state)

    @callback
    def _async_process_device_state(self, state: State):
        """Process a device tracker state.

        Only trackers that are home, come from a router, and carry ip,
        hostname and mac attributes are passed on.
        """
        if state.state != STATE_HOME:
            return

        attributes = state.attributes

        if attributes.get(ATTR_SOURCE_TYPE) != SOURCE_TYPE_ROUTER:
            return

        ip_address = attributes.get(ATTR_IP)
        hostname = attributes.get(ATTR_HOST_NAME)
        mac_address = attributes.get(ATTR_MAC)

        if ip_address is None or hostname is None or mac_address is None:
            return

        self.process_client(ip_address, hostname, _format_mac(mac_address))

    def create_task(self, task):
        """Pass a task to async_create_task since we are in async context."""
        self.hass.async_create_task(task)


class DHCPWatcher(WatcherBase):
    """Class to watch dhcp requests."""

    def __init__(self, hass, address_data, integration_matchers):
        """Initialize class."""
        super().__init__(hass, address_data, integration_matchers)
        self._sniffer = None
        # Set by the sniffer's started_callback; _stop only stops a sniffer
        # that actually started.
        self._started = threading.Event()

    async def async_stop(self):
        """Stop watching for dhcp packets."""
        await self.hass.async_add_executor_job(self._stop)

    def _stop(self):
        """Stop the thread."""
        if self._started.is_set():
            self._sniffer.stop()

    async def async_start(self):
        """Start watching for dhcp packets.

        Failure to open the raw socket is logged (as an error when running as
        root, otherwise at debug level) and watching is skipped.
        """
        try:
            sniff_socket = conf.L2socket(type=ETH_P_ALL)
            self._sniffer = AsyncSniffer(
                filter=FILTER,
                opened_socket=[sniff_socket],
                started_callback=self._started.set,
                prn=self.handle_dhcp_packet,
            )
            self._sniffer.start()
        except (Scapy_Exception, OSError) as ex:
            if os.geteuid() == 0:
                _LOGGER.error("Cannot watch for dhcp packets: %s", ex)
            else:
                _LOGGER.debug(
                    "Cannot watch for dhcp packets without root or CAP_NET_RAW: %s", ex
                )

    def handle_dhcp_packet(self, packet):
        """Process a dhcp packet."""
        if DHCP not in packet or Ether not in packet:
            # Without an Ethernet layer there is no source mac to report.
            return

        options = packet[DHCP].options

        request_type = _decode_dhcp_option(options, MESSAGE_TYPE)
        if request_type != DHCP_REQUEST:
            # Only DHCP request messages are processed
            return

        ip_address = _decode_dhcp_option(options, REQUESTED_ADDR)
        hostname = _decode_dhcp_option(options, HOSTNAME)
        mac_address = _format_mac(packet[Ether].src)

        if ip_address is None or hostname is None or mac_address is None:
            return

        self.process_client(ip_address, hostname, mac_address)

    def create_task(self, task):
        """Pass a task to hass.add_job since we are in a thread."""
        self.hass.add_job(task)


def _decode_dhcp_option(dhcp_options, key):
    """Extract and decode data from a packet option.

    Returns the value of the first option whose first element equals ``key``,
    or None when there is no such option. The hostname value is decoded from
    bytes to str; None is returned if it cannot be decoded.
    """
    for option in dhcp_options:
        # Skip entries without a value and options with another name.
        if len(option) < 2 or option[0] != key:
            continue

        value = option[1]
        if value is None or key != HOSTNAME:
            return value

        # hostname is raw bytes; decode it to str
        try:
            return value.decode()
        except (AttributeError, UnicodeDecodeError):
            return None


def _format_mac(mac_address):
    """Format a mac address for matching (normalized, without colons)."""
    return format_mac(mac_address).replace(":", "")