Add additional data source to dhcp (#48430)

This commit is contained in:
J. Nick Koston 2021-03-28 09:47:28 -10:00 committed by GitHub
parent 23c7c4c977
commit 2ff94c8ed9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 235 additions and 60 deletions

View file

@ -1,12 +1,19 @@
"""The dhcp integration."""
from abc import abstractmethod
from datetime import timedelta
import fnmatch
from ipaddress import ip_address as make_ip_address
import logging
import os
import threading
from aiodiscover import DiscoverHosts
from aiodiscover.discovery import (
HOSTNAME as DISCOVERY_HOSTNAME,
IP_ADDRESS as DISCOVERY_IP_ADDRESS,
MAC_ADDRESS as DISCOVERY_MAC_ADDRESS,
)
from scapy.arch.common import compile_filter
from scapy.config import conf
from scapy.error import Scapy_Exception
@ -29,7 +36,10 @@ from homeassistant.const import (
)
from homeassistant.core import Event, HomeAssistant, State, callback
from homeassistant.helpers.device_registry import format_mac
from homeassistant.helpers.event import async_track_state_added_domain
from homeassistant.helpers.event import (
async_track_state_added_domain,
async_track_time_interval,
)
from homeassistant.loader import async_get_dhcp
from homeassistant.util.network import is_link_local
@ -42,6 +52,7 @@ HOSTNAME = "hostname"
MAC_ADDRESS = "macaddress"
IP_ADDRESS = "ip"
DHCP_REQUEST = 3
SCAN_INTERVAL = timedelta(minutes=60)
_LOGGER = logging.getLogger(__name__)
@ -54,7 +65,7 @@ async def async_setup(hass: HomeAssistant, config: dict) -> bool:
integration_matchers = await async_get_dhcp(hass)
watchers = []
for cls in (DHCPWatcher, DeviceTrackerWatcher):
for cls in (DHCPWatcher, DeviceTrackerWatcher, NetworkWatcher):
watcher = cls(hass, address_data, integration_matchers)
await watcher.async_start()
watchers.append(watcher)
@ -88,7 +99,11 @@ class WatcherBase:
data = self._address_data.get(ip_address)
if data and data[MAC_ADDRESS] == mac_address and data[HOSTNAME] == hostname:
if (
data
and data[MAC_ADDRESS] == mac_address
and data[HOSTNAME].startswith(hostname)
):
# If the address data is the same no need
# to process it
return
@ -139,6 +154,54 @@ class WatcherBase:
"""Pass a task to async_add_task based on which context we are in."""
class NetworkWatcher(WatcherBase):
"""Class to query ptr records routers."""
def __init__(self, hass, address_data, integration_matchers):
"""Initialize class."""
super().__init__(hass, address_data, integration_matchers)
self._unsub = None
self._discover_hosts = None
self._discover_task = None
async def async_stop(self):
"""Stop scanning for new devices on the network."""
if self._unsub:
self._unsub()
self._unsub = None
if self._discover_task:
self._discover_task.cancel()
self._discover_task = None
async def async_start(self):
"""Start scanning for new devices on the network."""
self._discover_hosts = DiscoverHosts()
self._unsub = async_track_time_interval(
self.hass, self.async_start_discover, SCAN_INTERVAL
)
self.async_start_discover()
@callback
def async_start_discover(self, *_):
"""Start a new discovery task if one is not running."""
if self._discover_task and not self._discover_task.done():
return
self._discover_task = self.create_task(self.async_discover())
async def async_discover(self):
"""Process discovery."""
for host in await self._discover_hosts.async_discover():
self.process_client(
host[DISCOVERY_IP_ADDRESS],
host[DISCOVERY_HOSTNAME],
_format_mac(host[DISCOVERY_MAC_ADDRESS]),
)
def create_task(self, task):
"""Pass a task to async_create_task since we are in async context."""
return self.hass.async_create_task(task)
class DeviceTrackerWatcher(WatcherBase):
"""Class to watch dhcp data from routers."""
@ -188,7 +251,7 @@ class DeviceTrackerWatcher(WatcherBase):
def create_task(self, task):
"""Pass a task to async_create_task since we are in async context."""
self.hass.async_create_task(task)
return self.hass.async_create_task(task)
class DHCPWatcher(WatcherBase):
@ -266,7 +329,7 @@ class DHCPWatcher(WatcherBase):
def create_task(self, task):
"""Pass a task to hass.add_job since we are in a thread."""
self.hass.add_job(task)
return self.hass.add_job(task)
def _decode_dhcp_option(dhcp_options, key):