Fix dhcp discovery matching due to deferred imports (#56814)
This commit is contained in:
parent
a7f554e6da
commit
2ed35debdc
2 changed files with 83 additions and 81 deletions
|
@ -282,6 +282,9 @@ class DHCPWatcher(WatcherBase):
|
|||
from scapy import ( # pylint: disable=import-outside-toplevel,unused-import # noqa: F401
|
||||
arch,
|
||||
)
|
||||
from scapy.layers.dhcp import DHCP # pylint: disable=import-outside-toplevel
|
||||
from scapy.layers.inet import IP # pylint: disable=import-outside-toplevel
|
||||
from scapy.layers.l2 import Ether # pylint: disable=import-outside-toplevel
|
||||
|
||||
#
|
||||
# Importing scapy.sendrecv will cause a scapy resync which will
|
||||
|
@ -294,6 +297,24 @@ class DHCPWatcher(WatcherBase):
|
|||
AsyncSniffer,
|
||||
)
|
||||
|
||||
def _handle_dhcp_packet(packet):
|
||||
"""Process a dhcp packet."""
|
||||
if DHCP not in packet:
|
||||
return
|
||||
|
||||
options = packet[DHCP].options
|
||||
request_type = _decode_dhcp_option(options, MESSAGE_TYPE)
|
||||
if request_type != DHCP_REQUEST:
|
||||
# Not a DHCP request
|
||||
return
|
||||
|
||||
ip_address = _decode_dhcp_option(options, REQUESTED_ADDR) or packet[IP].src
|
||||
hostname = _decode_dhcp_option(options, HOSTNAME) or ""
|
||||
mac_address = _format_mac(packet[Ether].src)
|
||||
|
||||
if ip_address is not None and mac_address is not None:
|
||||
self.process_client(ip_address, hostname, mac_address)
|
||||
|
||||
# disable scapy promiscuous mode as we do not need it
|
||||
conf.sniff_promisc = 0
|
||||
|
||||
|
@ -320,7 +341,7 @@ class DHCPWatcher(WatcherBase):
|
|||
self._sniffer = AsyncSniffer(
|
||||
filter=FILTER,
|
||||
started_callback=self._started.set,
|
||||
prn=self.handle_dhcp_packet,
|
||||
prn=_handle_dhcp_packet,
|
||||
store=0,
|
||||
)
|
||||
|
||||
|
@ -328,33 +349,6 @@ class DHCPWatcher(WatcherBase):
|
|||
if self._sniffer.thread:
|
||||
self._sniffer.thread.name = self.__class__.__name__
|
||||
|
||||
def handle_dhcp_packet(self, packet):
|
||||
"""Process a dhcp packet."""
|
||||
# Local import because importing from scapy has side effects such as opening
|
||||
# sockets
|
||||
from scapy.layers.dhcp import DHCP # pylint: disable=import-outside-toplevel
|
||||
from scapy.layers.inet import IP # pylint: disable=import-outside-toplevel
|
||||
from scapy.layers.l2 import Ether # pylint: disable=import-outside-toplevel
|
||||
|
||||
if DHCP not in packet:
|
||||
return
|
||||
|
||||
options = packet[DHCP].options
|
||||
|
||||
request_type = _decode_dhcp_option(options, MESSAGE_TYPE)
|
||||
if request_type != DHCP_REQUEST:
|
||||
# DHCP request
|
||||
return
|
||||
|
||||
ip_address = _decode_dhcp_option(options, REQUESTED_ADDR) or packet[IP].src
|
||||
hostname = _decode_dhcp_option(options, HOSTNAME) or ""
|
||||
mac_address = _format_mac(packet[Ether].src)
|
||||
|
||||
if ip_address is None or mac_address is None:
|
||||
return
|
||||
|
||||
self.process_client(ip_address, hostname, mac_address)
|
||||
|
||||
def create_task(self, task):
|
||||
"""Pass a task to hass.add_job since we are in a thread."""
|
||||
return self.hass.add_job(task)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""Test the DHCP discovery integration."""
|
||||
import datetime
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from scapy.error import Scapy_Exception
|
||||
from scapy.layers.dhcp import DHCP
|
||||
|
@ -123,20 +123,39 @@ RAW_DHCP_REQUEST_WITHOUT_HOSTNAME = (
|
|||
)
|
||||
|
||||
|
||||
async def test_dhcp_match_hostname_and_macaddress(hass):
|
||||
"""Test matching based on hostname and macaddress."""
|
||||
async def _async_get_handle_dhcp_packet(hass, integration_matchers):
|
||||
dhcp_watcher = dhcp.DHCPWatcher(
|
||||
hass,
|
||||
{},
|
||||
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
||||
integration_matchers,
|
||||
)
|
||||
handle_dhcp_packet = None
|
||||
|
||||
def _mock_sniffer(*args, **kwargs):
|
||||
nonlocal handle_dhcp_packet
|
||||
handle_dhcp_packet = kwargs["prn"]
|
||||
return MagicMock()
|
||||
|
||||
with patch("homeassistant.components.dhcp._verify_l2socket_setup",), patch(
|
||||
"scapy.arch.common.compile_filter"
|
||||
), patch("scapy.sendrecv.AsyncSniffer", _mock_sniffer):
|
||||
await dhcp_watcher.async_start()
|
||||
|
||||
return handle_dhcp_packet
|
||||
|
||||
|
||||
async def test_dhcp_match_hostname_and_macaddress(hass):
|
||||
"""Test matching based on hostname and macaddress."""
|
||||
integration_matchers = [
|
||||
{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}
|
||||
]
|
||||
packet = Ether(RAW_DHCP_REQUEST)
|
||||
|
||||
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
|
||||
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
|
||||
dhcp_watcher.handle_dhcp_packet(packet)
|
||||
handle_dhcp_packet(packet)
|
||||
# Ensure no change is ignored
|
||||
dhcp_watcher.handle_dhcp_packet(packet)
|
||||
handle_dhcp_packet(packet)
|
||||
|
||||
assert len(mock_init.mock_calls) == 1
|
||||
assert mock_init.mock_calls[0][1][0] == "mock-domain"
|
||||
|
@ -152,18 +171,17 @@ async def test_dhcp_match_hostname_and_macaddress(hass):
|
|||
|
||||
async def test_dhcp_renewal_match_hostname_and_macaddress(hass):
|
||||
"""Test renewal matching based on hostname and macaddress."""
|
||||
dhcp_watcher = dhcp.DHCPWatcher(
|
||||
hass,
|
||||
{},
|
||||
[{"domain": "mock-domain", "hostname": "irobot-*", "macaddress": "501479*"}],
|
||||
)
|
||||
integration_matchers = [
|
||||
{"domain": "mock-domain", "hostname": "irobot-*", "macaddress": "501479*"}
|
||||
]
|
||||
|
||||
packet = Ether(RAW_DHCP_RENEWAL)
|
||||
|
||||
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
|
||||
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
|
||||
dhcp_watcher.handle_dhcp_packet(packet)
|
||||
handle_dhcp_packet(packet)
|
||||
# Ensure no change is ignored
|
||||
dhcp_watcher.handle_dhcp_packet(packet)
|
||||
handle_dhcp_packet(packet)
|
||||
|
||||
assert len(mock_init.mock_calls) == 1
|
||||
assert mock_init.mock_calls[0][1][0] == "mock-domain"
|
||||
|
@ -179,14 +197,13 @@ async def test_dhcp_renewal_match_hostname_and_macaddress(hass):
|
|||
|
||||
async def test_dhcp_match_hostname(hass):
|
||||
"""Test matching based on hostname only."""
|
||||
dhcp_watcher = dhcp.DHCPWatcher(
|
||||
hass, {}, [{"domain": "mock-domain", "hostname": "connect"}]
|
||||
)
|
||||
integration_matchers = [{"domain": "mock-domain", "hostname": "connect"}]
|
||||
|
||||
packet = Ether(RAW_DHCP_REQUEST)
|
||||
|
||||
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
|
||||
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
|
||||
dhcp_watcher.handle_dhcp_packet(packet)
|
||||
handle_dhcp_packet(packet)
|
||||
|
||||
assert len(mock_init.mock_calls) == 1
|
||||
assert mock_init.mock_calls[0][1][0] == "mock-domain"
|
||||
|
@ -202,14 +219,13 @@ async def test_dhcp_match_hostname(hass):
|
|||
|
||||
async def test_dhcp_match_macaddress(hass):
|
||||
"""Test matching based on macaddress only."""
|
||||
dhcp_watcher = dhcp.DHCPWatcher(
|
||||
hass, {}, [{"domain": "mock-domain", "macaddress": "B8B7F1*"}]
|
||||
)
|
||||
integration_matchers = [{"domain": "mock-domain", "macaddress": "B8B7F1*"}]
|
||||
|
||||
packet = Ether(RAW_DHCP_REQUEST)
|
||||
|
||||
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
|
||||
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
|
||||
dhcp_watcher.handle_dhcp_packet(packet)
|
||||
handle_dhcp_packet(packet)
|
||||
|
||||
assert len(mock_init.mock_calls) == 1
|
||||
assert mock_init.mock_calls[0][1][0] == "mock-domain"
|
||||
|
@ -225,14 +241,13 @@ async def test_dhcp_match_macaddress(hass):
|
|||
|
||||
async def test_dhcp_match_macaddress_without_hostname(hass):
|
||||
"""Test matching based on macaddress only."""
|
||||
dhcp_watcher = dhcp.DHCPWatcher(
|
||||
hass, {}, [{"domain": "mock-domain", "macaddress": "606BBD*"}]
|
||||
)
|
||||
integration_matchers = [{"domain": "mock-domain", "macaddress": "606BBD*"}]
|
||||
|
||||
packet = Ether(RAW_DHCP_REQUEST_WITHOUT_HOSTNAME)
|
||||
|
||||
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
|
||||
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
|
||||
dhcp_watcher.handle_dhcp_packet(packet)
|
||||
handle_dhcp_packet(packet)
|
||||
|
||||
assert len(mock_init.mock_calls) == 1
|
||||
assert mock_init.mock_calls[0][1][0] == "mock-domain"
|
||||
|
@ -248,51 +263,46 @@ async def test_dhcp_match_macaddress_without_hostname(hass):
|
|||
|
||||
async def test_dhcp_nomatch(hass):
|
||||
"""Test not matching based on macaddress only."""
|
||||
dhcp_watcher = dhcp.DHCPWatcher(
|
||||
hass, {}, [{"domain": "mock-domain", "macaddress": "ABC123*"}]
|
||||
)
|
||||
integration_matchers = [{"domain": "mock-domain", "macaddress": "ABC123*"}]
|
||||
|
||||
packet = Ether(RAW_DHCP_REQUEST)
|
||||
|
||||
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
|
||||
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
|
||||
dhcp_watcher.handle_dhcp_packet(packet)
|
||||
handle_dhcp_packet(packet)
|
||||
|
||||
assert len(mock_init.mock_calls) == 0
|
||||
|
||||
|
||||
async def test_dhcp_nomatch_hostname(hass):
|
||||
"""Test not matching based on hostname only."""
|
||||
dhcp_watcher = dhcp.DHCPWatcher(
|
||||
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
|
||||
)
|
||||
integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
|
||||
|
||||
packet = Ether(RAW_DHCP_REQUEST)
|
||||
|
||||
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
|
||||
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
|
||||
dhcp_watcher.handle_dhcp_packet(packet)
|
||||
handle_dhcp_packet(packet)
|
||||
|
||||
assert len(mock_init.mock_calls) == 0
|
||||
|
||||
|
||||
async def test_dhcp_nomatch_non_dhcp_packet(hass):
|
||||
"""Test matching does not throw on a non-dhcp packet."""
|
||||
dhcp_watcher = dhcp.DHCPWatcher(
|
||||
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
|
||||
)
|
||||
integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
|
||||
|
||||
packet = Ether(b"")
|
||||
|
||||
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
|
||||
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
|
||||
dhcp_watcher.handle_dhcp_packet(packet)
|
||||
handle_dhcp_packet(packet)
|
||||
|
||||
assert len(mock_init.mock_calls) == 0
|
||||
|
||||
|
||||
async def test_dhcp_nomatch_non_dhcp_request_packet(hass):
|
||||
"""Test nothing happens with the wrong message-type."""
|
||||
dhcp_watcher = dhcp.DHCPWatcher(
|
||||
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
|
||||
)
|
||||
integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
|
||||
|
||||
packet = Ether(RAW_DHCP_REQUEST)
|
||||
|
||||
|
@ -305,17 +315,16 @@ async def test_dhcp_nomatch_non_dhcp_request_packet(hass):
|
|||
("hostname", b"connect"),
|
||||
]
|
||||
|
||||
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
|
||||
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
|
||||
dhcp_watcher.handle_dhcp_packet(packet)
|
||||
handle_dhcp_packet(packet)
|
||||
|
||||
assert len(mock_init.mock_calls) == 0
|
||||
|
||||
|
||||
async def test_dhcp_invalid_hostname(hass):
|
||||
"""Test we ignore invalid hostnames."""
|
||||
dhcp_watcher = dhcp.DHCPWatcher(
|
||||
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
|
||||
)
|
||||
integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
|
||||
|
||||
packet = Ether(RAW_DHCP_REQUEST)
|
||||
|
||||
|
@ -328,17 +337,16 @@ async def test_dhcp_invalid_hostname(hass):
|
|||
("hostname", "connect"),
|
||||
]
|
||||
|
||||
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
|
||||
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
|
||||
dhcp_watcher.handle_dhcp_packet(packet)
|
||||
handle_dhcp_packet(packet)
|
||||
|
||||
assert len(mock_init.mock_calls) == 0
|
||||
|
||||
|
||||
async def test_dhcp_missing_hostname(hass):
|
||||
"""Test we ignore missing hostnames."""
|
||||
dhcp_watcher = dhcp.DHCPWatcher(
|
||||
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
|
||||
)
|
||||
integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
|
||||
|
||||
packet = Ether(RAW_DHCP_REQUEST)
|
||||
|
||||
|
@ -351,17 +359,16 @@ async def test_dhcp_missing_hostname(hass):
|
|||
("hostname", None),
|
||||
]
|
||||
|
||||
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
|
||||
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
|
||||
dhcp_watcher.handle_dhcp_packet(packet)
|
||||
handle_dhcp_packet(packet)
|
||||
|
||||
assert len(mock_init.mock_calls) == 0
|
||||
|
||||
|
||||
async def test_dhcp_invalid_option(hass):
|
||||
"""Test we ignore invalid hostname option."""
|
||||
dhcp_watcher = dhcp.DHCPWatcher(
|
||||
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
|
||||
)
|
||||
integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
|
||||
|
||||
packet = Ether(RAW_DHCP_REQUEST)
|
||||
|
||||
|
@ -374,8 +381,9 @@ async def test_dhcp_invalid_option(hass):
|
|||
("hostname"),
|
||||
]
|
||||
|
||||
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
|
||||
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
|
||||
dhcp_watcher.handle_dhcp_packet(packet)
|
||||
handle_dhcp_packet(packet)
|
||||
|
||||
assert len(mock_init.mock_calls) == 0
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue