Fix dhcp discovery matching due to deferred imports (#56814)

This commit is contained in:
J. Nick Koston 2021-09-29 23:50:21 -05:00 committed by GitHub
parent a7f554e6da
commit 2ed35debdc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 83 additions and 81 deletions

View file

@ -282,6 +282,9 @@ class DHCPWatcher(WatcherBase):
from scapy import ( # pylint: disable=import-outside-toplevel,unused-import # noqa: F401
arch,
)
from scapy.layers.dhcp import DHCP # pylint: disable=import-outside-toplevel
from scapy.layers.inet import IP # pylint: disable=import-outside-toplevel
from scapy.layers.l2 import Ether # pylint: disable=import-outside-toplevel
#
# Importing scapy.sendrecv will cause a scapy resync which will
@ -294,6 +297,24 @@ class DHCPWatcher(WatcherBase):
AsyncSniffer,
)
def _handle_dhcp_packet(packet):
"""Process a dhcp packet."""
if DHCP not in packet:
return
options = packet[DHCP].options
request_type = _decode_dhcp_option(options, MESSAGE_TYPE)
if request_type != DHCP_REQUEST:
# Not a DHCP request
return
ip_address = _decode_dhcp_option(options, REQUESTED_ADDR) or packet[IP].src
hostname = _decode_dhcp_option(options, HOSTNAME) or ""
mac_address = _format_mac(packet[Ether].src)
if ip_address is not None and mac_address is not None:
self.process_client(ip_address, hostname, mac_address)
# disable scapy promiscuous mode as we do not need it
conf.sniff_promisc = 0
@ -320,7 +341,7 @@ class DHCPWatcher(WatcherBase):
self._sniffer = AsyncSniffer(
filter=FILTER,
started_callback=self._started.set,
prn=self.handle_dhcp_packet,
prn=_handle_dhcp_packet,
store=0,
)
@ -328,33 +349,6 @@ class DHCPWatcher(WatcherBase):
if self._sniffer.thread:
self._sniffer.thread.name = self.__class__.__name__
def handle_dhcp_packet(self, packet):
"""Process a dhcp packet."""
# Local import because importing from scapy has side effects such as opening
# sockets
from scapy.layers.dhcp import DHCP # pylint: disable=import-outside-toplevel
from scapy.layers.inet import IP # pylint: disable=import-outside-toplevel
from scapy.layers.l2 import Ether # pylint: disable=import-outside-toplevel
if DHCP not in packet:
return
options = packet[DHCP].options
request_type = _decode_dhcp_option(options, MESSAGE_TYPE)
if request_type != DHCP_REQUEST:
# DHCP request
return
ip_address = _decode_dhcp_option(options, REQUESTED_ADDR) or packet[IP].src
hostname = _decode_dhcp_option(options, HOSTNAME) or ""
mac_address = _format_mac(packet[Ether].src)
if ip_address is None or mac_address is None:
return
self.process_client(ip_address, hostname, mac_address)
def create_task(self, task):
"""Pass a task to hass.add_job since we are in a thread."""
return self.hass.add_job(task)

View file

@ -1,7 +1,7 @@
"""Test the DHCP discovery integration."""
import datetime
import threading
from unittest.mock import patch
from unittest.mock import MagicMock, patch
from scapy.error import Scapy_Exception
from scapy.layers.dhcp import DHCP
@ -123,20 +123,39 @@ RAW_DHCP_REQUEST_WITHOUT_HOSTNAME = (
)
async def test_dhcp_match_hostname_and_macaddress(hass):
"""Test matching based on hostname and macaddress."""
async def _async_get_handle_dhcp_packet(hass, integration_matchers):
dhcp_watcher = dhcp.DHCPWatcher(
hass,
{},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
integration_matchers,
)
handle_dhcp_packet = None
def _mock_sniffer(*args, **kwargs):
nonlocal handle_dhcp_packet
handle_dhcp_packet = kwargs["prn"]
return MagicMock()
with patch("homeassistant.components.dhcp._verify_l2socket_setup",), patch(
"scapy.arch.common.compile_filter"
), patch("scapy.sendrecv.AsyncSniffer", _mock_sniffer):
await dhcp_watcher.async_start()
return handle_dhcp_packet
async def test_dhcp_match_hostname_and_macaddress(hass):
"""Test matching based on hostname and macaddress."""
integration_matchers = [
{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}
]
packet = Ether(RAW_DHCP_REQUEST)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet)
handle_dhcp_packet(packet)
# Ensure no change is ignored
dhcp_watcher.handle_dhcp_packet(packet)
handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain"
@ -152,18 +171,17 @@ async def test_dhcp_match_hostname_and_macaddress(hass):
async def test_dhcp_renewal_match_hostname_and_macaddress(hass):
"""Test renewal matching based on hostname and macaddress."""
dhcp_watcher = dhcp.DHCPWatcher(
hass,
{},
[{"domain": "mock-domain", "hostname": "irobot-*", "macaddress": "501479*"}],
)
integration_matchers = [
{"domain": "mock-domain", "hostname": "irobot-*", "macaddress": "501479*"}
]
packet = Ether(RAW_DHCP_RENEWAL)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet)
handle_dhcp_packet(packet)
# Ensure no change is ignored
dhcp_watcher.handle_dhcp_packet(packet)
handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain"
@ -179,14 +197,13 @@ async def test_dhcp_renewal_match_hostname_and_macaddress(hass):
async def test_dhcp_match_hostname(hass):
"""Test matching based on hostname only."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, {}, [{"domain": "mock-domain", "hostname": "connect"}]
)
integration_matchers = [{"domain": "mock-domain", "hostname": "connect"}]
packet = Ether(RAW_DHCP_REQUEST)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet)
handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain"
@ -202,14 +219,13 @@ async def test_dhcp_match_hostname(hass):
async def test_dhcp_match_macaddress(hass):
"""Test matching based on macaddress only."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, {}, [{"domain": "mock-domain", "macaddress": "B8B7F1*"}]
)
integration_matchers = [{"domain": "mock-domain", "macaddress": "B8B7F1*"}]
packet = Ether(RAW_DHCP_REQUEST)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet)
handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain"
@ -225,14 +241,13 @@ async def test_dhcp_match_macaddress(hass):
async def test_dhcp_match_macaddress_without_hostname(hass):
"""Test matching based on macaddress only."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, {}, [{"domain": "mock-domain", "macaddress": "606BBD*"}]
)
integration_matchers = [{"domain": "mock-domain", "macaddress": "606BBD*"}]
packet = Ether(RAW_DHCP_REQUEST_WITHOUT_HOSTNAME)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet)
handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain"
@ -248,51 +263,46 @@ async def test_dhcp_match_macaddress_without_hostname(hass):
async def test_dhcp_nomatch(hass):
"""Test not matching based on macaddress only."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, {}, [{"domain": "mock-domain", "macaddress": "ABC123*"}]
)
integration_matchers = [{"domain": "mock-domain", "macaddress": "ABC123*"}]
packet = Ether(RAW_DHCP_REQUEST)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet)
handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0
async def test_dhcp_nomatch_hostname(hass):
"""Test not matching based on hostname only."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
packet = Ether(RAW_DHCP_REQUEST)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet)
handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0
async def test_dhcp_nomatch_non_dhcp_packet(hass):
"""Test matching does not throw on a non-dhcp packet."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
packet = Ether(b"")
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet)
handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0
async def test_dhcp_nomatch_non_dhcp_request_packet(hass):
"""Test nothing happens with the wrong message-type."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
packet = Ether(RAW_DHCP_REQUEST)
@ -305,17 +315,16 @@ async def test_dhcp_nomatch_non_dhcp_request_packet(hass):
("hostname", b"connect"),
]
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet)
handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0
async def test_dhcp_invalid_hostname(hass):
"""Test we ignore invalid hostnames."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
packet = Ether(RAW_DHCP_REQUEST)
@ -328,17 +337,16 @@ async def test_dhcp_invalid_hostname(hass):
("hostname", "connect"),
]
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet)
handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0
async def test_dhcp_missing_hostname(hass):
"""Test we ignore missing hostnames."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
packet = Ether(RAW_DHCP_REQUEST)
@ -351,17 +359,16 @@ async def test_dhcp_missing_hostname(hass):
("hostname", None),
]
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet)
handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0
async def test_dhcp_invalid_option(hass):
"""Test we ignore invalid hostname option."""
dhcp_watcher = dhcp.DHCPWatcher(
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
packet = Ether(RAW_DHCP_REQUEST)
@ -374,8 +381,9 @@ async def test_dhcp_invalid_option(hass):
("hostname"),
]
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet)
handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0