Add additional data source to dhcp (#48430)
This commit is contained in:
parent
23c7c4c977
commit
2ff94c8ed9
6 changed files with 235 additions and 60 deletions
|
@ -1,12 +1,19 @@
|
|||
"""The dhcp integration."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from datetime import timedelta
|
||||
import fnmatch
|
||||
from ipaddress import ip_address as make_ip_address
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
from aiodiscover import DiscoverHosts
|
||||
from aiodiscover.discovery import (
|
||||
HOSTNAME as DISCOVERY_HOSTNAME,
|
||||
IP_ADDRESS as DISCOVERY_IP_ADDRESS,
|
||||
MAC_ADDRESS as DISCOVERY_MAC_ADDRESS,
|
||||
)
|
||||
from scapy.arch.common import compile_filter
|
||||
from scapy.config import conf
|
||||
from scapy.error import Scapy_Exception
|
||||
|
@ -29,7 +36,10 @@ from homeassistant.const import (
|
|||
)
|
||||
from homeassistant.core import Event, HomeAssistant, State, callback
|
||||
from homeassistant.helpers.device_registry import format_mac
|
||||
from homeassistant.helpers.event import async_track_state_added_domain
|
||||
from homeassistant.helpers.event import (
|
||||
async_track_state_added_domain,
|
||||
async_track_time_interval,
|
||||
)
|
||||
from homeassistant.loader import async_get_dhcp
|
||||
from homeassistant.util.network import is_link_local
|
||||
|
||||
|
@ -42,6 +52,7 @@ HOSTNAME = "hostname"
|
|||
MAC_ADDRESS = "macaddress"
|
||||
IP_ADDRESS = "ip"
|
||||
DHCP_REQUEST = 3
|
||||
SCAN_INTERVAL = timedelta(minutes=60)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -54,7 +65,7 @@ async def async_setup(hass: HomeAssistant, config: dict) -> bool:
|
|||
integration_matchers = await async_get_dhcp(hass)
|
||||
watchers = []
|
||||
|
||||
for cls in (DHCPWatcher, DeviceTrackerWatcher):
|
||||
for cls in (DHCPWatcher, DeviceTrackerWatcher, NetworkWatcher):
|
||||
watcher = cls(hass, address_data, integration_matchers)
|
||||
await watcher.async_start()
|
||||
watchers.append(watcher)
|
||||
|
@ -88,7 +99,11 @@ class WatcherBase:
|
|||
|
||||
data = self._address_data.get(ip_address)
|
||||
|
||||
if data and data[MAC_ADDRESS] == mac_address and data[HOSTNAME] == hostname:
|
||||
if (
|
||||
data
|
||||
and data[MAC_ADDRESS] == mac_address
|
||||
and data[HOSTNAME].startswith(hostname)
|
||||
):
|
||||
# If the address data is the same no need
|
||||
# to process it
|
||||
return
|
||||
|
@ -139,6 +154,54 @@ class WatcherBase:
|
|||
"""Pass a task to async_add_task based on which context we are in."""
|
||||
|
||||
|
||||
class NetworkWatcher(WatcherBase):
|
||||
"""Class to query ptr records routers."""
|
||||
|
||||
def __init__(self, hass, address_data, integration_matchers):
|
||||
"""Initialize class."""
|
||||
super().__init__(hass, address_data, integration_matchers)
|
||||
self._unsub = None
|
||||
self._discover_hosts = None
|
||||
self._discover_task = None
|
||||
|
||||
async def async_stop(self):
|
||||
"""Stop scanning for new devices on the network."""
|
||||
if self._unsub:
|
||||
self._unsub()
|
||||
self._unsub = None
|
||||
if self._discover_task:
|
||||
self._discover_task.cancel()
|
||||
self._discover_task = None
|
||||
|
||||
async def async_start(self):
|
||||
"""Start scanning for new devices on the network."""
|
||||
self._discover_hosts = DiscoverHosts()
|
||||
self._unsub = async_track_time_interval(
|
||||
self.hass, self.async_start_discover, SCAN_INTERVAL
|
||||
)
|
||||
self.async_start_discover()
|
||||
|
||||
@callback
|
||||
def async_start_discover(self, *_):
|
||||
"""Start a new discovery task if one is not running."""
|
||||
if self._discover_task and not self._discover_task.done():
|
||||
return
|
||||
self._discover_task = self.create_task(self.async_discover())
|
||||
|
||||
async def async_discover(self):
|
||||
"""Process discovery."""
|
||||
for host in await self._discover_hosts.async_discover():
|
||||
self.process_client(
|
||||
host[DISCOVERY_IP_ADDRESS],
|
||||
host[DISCOVERY_HOSTNAME],
|
||||
_format_mac(host[DISCOVERY_MAC_ADDRESS]),
|
||||
)
|
||||
|
||||
def create_task(self, task):
|
||||
"""Pass a task to async_create_task since we are in async context."""
|
||||
return self.hass.async_create_task(task)
|
||||
|
||||
|
||||
class DeviceTrackerWatcher(WatcherBase):
|
||||
"""Class to watch dhcp data from routers."""
|
||||
|
||||
|
@ -188,7 +251,7 @@ class DeviceTrackerWatcher(WatcherBase):
|
|||
|
||||
def create_task(self, task):
|
||||
"""Pass a task to async_create_task since we are in async context."""
|
||||
self.hass.async_create_task(task)
|
||||
return self.hass.async_create_task(task)
|
||||
|
||||
|
||||
class DHCPWatcher(WatcherBase):
|
||||
|
@ -266,7 +329,7 @@ class DHCPWatcher(WatcherBase):
|
|||
|
||||
def create_task(self, task):
|
||||
"""Pass a task to hass.add_job since we are in a thread."""
|
||||
self.hass.add_job(task)
|
||||
return self.hass.add_job(task)
|
||||
|
||||
|
||||
def _decode_dhcp_option(dhcp_options, key):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue