Fix dhcp discovery matching due to deferred imports (#56814)

This commit is contained in:
J. Nick Koston 2021-09-29 23:50:21 -05:00 committed by GitHub
parent a7f554e6da
commit 2ed35debdc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 83 additions and 81 deletions

View file

@ -282,6 +282,9 @@ class DHCPWatcher(WatcherBase):
from scapy import ( # pylint: disable=import-outside-toplevel,unused-import # noqa: F401 from scapy import ( # pylint: disable=import-outside-toplevel,unused-import # noqa: F401
arch, arch,
) )
from scapy.layers.dhcp import DHCP # pylint: disable=import-outside-toplevel
from scapy.layers.inet import IP # pylint: disable=import-outside-toplevel
from scapy.layers.l2 import Ether # pylint: disable=import-outside-toplevel
# #
# Importing scapy.sendrecv will cause a scapy resync which will # Importing scapy.sendrecv will cause a scapy resync which will
@ -294,6 +297,24 @@ class DHCPWatcher(WatcherBase):
AsyncSniffer, AsyncSniffer,
) )
def _handle_dhcp_packet(packet):
"""Process a dhcp packet."""
if DHCP not in packet:
return
options = packet[DHCP].options
request_type = _decode_dhcp_option(options, MESSAGE_TYPE)
if request_type != DHCP_REQUEST:
# Not a DHCP request
return
ip_address = _decode_dhcp_option(options, REQUESTED_ADDR) or packet[IP].src
hostname = _decode_dhcp_option(options, HOSTNAME) or ""
mac_address = _format_mac(packet[Ether].src)
if ip_address is not None and mac_address is not None:
self.process_client(ip_address, hostname, mac_address)
# disable scapy promiscuous mode as we do not need it # disable scapy promiscuous mode as we do not need it
conf.sniff_promisc = 0 conf.sniff_promisc = 0
@ -320,7 +341,7 @@ class DHCPWatcher(WatcherBase):
self._sniffer = AsyncSniffer( self._sniffer = AsyncSniffer(
filter=FILTER, filter=FILTER,
started_callback=self._started.set, started_callback=self._started.set,
prn=self.handle_dhcp_packet, prn=_handle_dhcp_packet,
store=0, store=0,
) )
@ -328,33 +349,6 @@ class DHCPWatcher(WatcherBase):
if self._sniffer.thread: if self._sniffer.thread:
self._sniffer.thread.name = self.__class__.__name__ self._sniffer.thread.name = self.__class__.__name__
def handle_dhcp_packet(self, packet):
"""Process a dhcp packet."""
# Local import because importing from scapy has side effects such as opening
# sockets
from scapy.layers.dhcp import DHCP # pylint: disable=import-outside-toplevel
from scapy.layers.inet import IP # pylint: disable=import-outside-toplevel
from scapy.layers.l2 import Ether # pylint: disable=import-outside-toplevel
if DHCP not in packet:
return
options = packet[DHCP].options
request_type = _decode_dhcp_option(options, MESSAGE_TYPE)
if request_type != DHCP_REQUEST:
# DHCP request
return
ip_address = _decode_dhcp_option(options, REQUESTED_ADDR) or packet[IP].src
hostname = _decode_dhcp_option(options, HOSTNAME) or ""
mac_address = _format_mac(packet[Ether].src)
if ip_address is None or mac_address is None:
return
self.process_client(ip_address, hostname, mac_address)
def create_task(self, task): def create_task(self, task):
"""Pass a task to hass.add_job since we are in a thread.""" """Pass a task to hass.add_job since we are in a thread."""
return self.hass.add_job(task) return self.hass.add_job(task)

View file

@ -1,7 +1,7 @@
"""Test the DHCP discovery integration.""" """Test the DHCP discovery integration."""
import datetime import datetime
import threading import threading
from unittest.mock import patch from unittest.mock import MagicMock, patch
from scapy.error import Scapy_Exception from scapy.error import Scapy_Exception
from scapy.layers.dhcp import DHCP from scapy.layers.dhcp import DHCP
@ -123,20 +123,39 @@ RAW_DHCP_REQUEST_WITHOUT_HOSTNAME = (
) )
async def test_dhcp_match_hostname_and_macaddress(hass): async def _async_get_handle_dhcp_packet(hass, integration_matchers):
"""Test matching based on hostname and macaddress."""
dhcp_watcher = dhcp.DHCPWatcher( dhcp_watcher = dhcp.DHCPWatcher(
hass, hass,
{}, {},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], integration_matchers,
) )
handle_dhcp_packet = None
def _mock_sniffer(*args, **kwargs):
nonlocal handle_dhcp_packet
handle_dhcp_packet = kwargs["prn"]
return MagicMock()
with patch("homeassistant.components.dhcp._verify_l2socket_setup",), patch(
"scapy.arch.common.compile_filter"
), patch("scapy.sendrecv.AsyncSniffer", _mock_sniffer):
await dhcp_watcher.async_start()
return handle_dhcp_packet
async def test_dhcp_match_hostname_and_macaddress(hass):
"""Test matching based on hostname and macaddress."""
integration_matchers = [
{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}
]
packet = Ether(RAW_DHCP_REQUEST) packet = Ether(RAW_DHCP_REQUEST)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init: with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet) handle_dhcp_packet(packet)
# Ensure no change is ignored # Ensure no change is ignored
dhcp_watcher.handle_dhcp_packet(packet) handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 1 assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain" assert mock_init.mock_calls[0][1][0] == "mock-domain"
@ -152,18 +171,17 @@ async def test_dhcp_match_hostname_and_macaddress(hass):
async def test_dhcp_renewal_match_hostname_and_macaddress(hass): async def test_dhcp_renewal_match_hostname_and_macaddress(hass):
"""Test renewal matching based on hostname and macaddress.""" """Test renewal matching based on hostname and macaddress."""
dhcp_watcher = dhcp.DHCPWatcher( integration_matchers = [
hass, {"domain": "mock-domain", "hostname": "irobot-*", "macaddress": "501479*"}
{}, ]
[{"domain": "mock-domain", "hostname": "irobot-*", "macaddress": "501479*"}],
)
packet = Ether(RAW_DHCP_RENEWAL) packet = Ether(RAW_DHCP_RENEWAL)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init: with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet) handle_dhcp_packet(packet)
# Ensure no change is ignored # Ensure no change is ignored
dhcp_watcher.handle_dhcp_packet(packet) handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 1 assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain" assert mock_init.mock_calls[0][1][0] == "mock-domain"
@ -179,14 +197,13 @@ async def test_dhcp_renewal_match_hostname_and_macaddress(hass):
async def test_dhcp_match_hostname(hass): async def test_dhcp_match_hostname(hass):
"""Test matching based on hostname only.""" """Test matching based on hostname only."""
dhcp_watcher = dhcp.DHCPWatcher( integration_matchers = [{"domain": "mock-domain", "hostname": "connect"}]
hass, {}, [{"domain": "mock-domain", "hostname": "connect"}]
)
packet = Ether(RAW_DHCP_REQUEST) packet = Ether(RAW_DHCP_REQUEST)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init: with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet) handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 1 assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain" assert mock_init.mock_calls[0][1][0] == "mock-domain"
@ -202,14 +219,13 @@ async def test_dhcp_match_hostname(hass):
async def test_dhcp_match_macaddress(hass): async def test_dhcp_match_macaddress(hass):
"""Test matching based on macaddress only.""" """Test matching based on macaddress only."""
dhcp_watcher = dhcp.DHCPWatcher( integration_matchers = [{"domain": "mock-domain", "macaddress": "B8B7F1*"}]
hass, {}, [{"domain": "mock-domain", "macaddress": "B8B7F1*"}]
)
packet = Ether(RAW_DHCP_REQUEST) packet = Ether(RAW_DHCP_REQUEST)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init: with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet) handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 1 assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain" assert mock_init.mock_calls[0][1][0] == "mock-domain"
@ -225,14 +241,13 @@ async def test_dhcp_match_macaddress(hass):
async def test_dhcp_match_macaddress_without_hostname(hass): async def test_dhcp_match_macaddress_without_hostname(hass):
"""Test matching based on macaddress only.""" """Test matching based on macaddress only."""
dhcp_watcher = dhcp.DHCPWatcher( integration_matchers = [{"domain": "mock-domain", "macaddress": "606BBD*"}]
hass, {}, [{"domain": "mock-domain", "macaddress": "606BBD*"}]
)
packet = Ether(RAW_DHCP_REQUEST_WITHOUT_HOSTNAME) packet = Ether(RAW_DHCP_REQUEST_WITHOUT_HOSTNAME)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init: with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet) handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 1 assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain" assert mock_init.mock_calls[0][1][0] == "mock-domain"
@ -248,51 +263,46 @@ async def test_dhcp_match_macaddress_without_hostname(hass):
async def test_dhcp_nomatch(hass): async def test_dhcp_nomatch(hass):
"""Test not matching based on macaddress only.""" """Test not matching based on macaddress only."""
dhcp_watcher = dhcp.DHCPWatcher( integration_matchers = [{"domain": "mock-domain", "macaddress": "ABC123*"}]
hass, {}, [{"domain": "mock-domain", "macaddress": "ABC123*"}]
)
packet = Ether(RAW_DHCP_REQUEST) packet = Ether(RAW_DHCP_REQUEST)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init: with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet) handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0 assert len(mock_init.mock_calls) == 0
async def test_dhcp_nomatch_hostname(hass): async def test_dhcp_nomatch_hostname(hass):
"""Test not matching based on hostname only.""" """Test not matching based on hostname only."""
dhcp_watcher = dhcp.DHCPWatcher( integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
packet = Ether(RAW_DHCP_REQUEST) packet = Ether(RAW_DHCP_REQUEST)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init: with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet) handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0 assert len(mock_init.mock_calls) == 0
async def test_dhcp_nomatch_non_dhcp_packet(hass): async def test_dhcp_nomatch_non_dhcp_packet(hass):
"""Test matching does not throw on a non-dhcp packet.""" """Test matching does not throw on a non-dhcp packet."""
dhcp_watcher = dhcp.DHCPWatcher( integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
packet = Ether(b"") packet = Ether(b"")
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init: with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet) handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0 assert len(mock_init.mock_calls) == 0
async def test_dhcp_nomatch_non_dhcp_request_packet(hass): async def test_dhcp_nomatch_non_dhcp_request_packet(hass):
"""Test nothing happens with the wrong message-type.""" """Test nothing happens with the wrong message-type."""
dhcp_watcher = dhcp.DHCPWatcher( integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
packet = Ether(RAW_DHCP_REQUEST) packet = Ether(RAW_DHCP_REQUEST)
@ -305,17 +315,16 @@ async def test_dhcp_nomatch_non_dhcp_request_packet(hass):
("hostname", b"connect"), ("hostname", b"connect"),
] ]
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init: with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet) handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0 assert len(mock_init.mock_calls) == 0
async def test_dhcp_invalid_hostname(hass): async def test_dhcp_invalid_hostname(hass):
"""Test we ignore invalid hostnames.""" """Test we ignore invalid hostnames."""
dhcp_watcher = dhcp.DHCPWatcher( integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
packet = Ether(RAW_DHCP_REQUEST) packet = Ether(RAW_DHCP_REQUEST)
@ -328,17 +337,16 @@ async def test_dhcp_invalid_hostname(hass):
("hostname", "connect"), ("hostname", "connect"),
] ]
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init: with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet) handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0 assert len(mock_init.mock_calls) == 0
async def test_dhcp_missing_hostname(hass): async def test_dhcp_missing_hostname(hass):
"""Test we ignore missing hostnames.""" """Test we ignore missing hostnames."""
dhcp_watcher = dhcp.DHCPWatcher( integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
packet = Ether(RAW_DHCP_REQUEST) packet = Ether(RAW_DHCP_REQUEST)
@ -351,17 +359,16 @@ async def test_dhcp_missing_hostname(hass):
("hostname", None), ("hostname", None),
] ]
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init: with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet) handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0 assert len(mock_init.mock_calls) == 0
async def test_dhcp_invalid_option(hass): async def test_dhcp_invalid_option(hass):
"""Test we ignore invalid hostname option.""" """Test we ignore invalid hostname option."""
dhcp_watcher = dhcp.DHCPWatcher( integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}]
hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}]
)
packet = Ether(RAW_DHCP_REQUEST) packet = Ether(RAW_DHCP_REQUEST)
@ -374,8 +381,9 @@ async def test_dhcp_invalid_option(hass):
("hostname"), ("hostname"),
] ]
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
with patch.object(hass.config_entries.flow, "async_init") as mock_init: with patch.object(hass.config_entries.flow, "async_init") as mock_init:
dhcp_watcher.handle_dhcp_packet(packet) handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0 assert len(mock_init.mock_calls) == 0