Prevent event loop delay / instability from discovery (#57463)

This commit is contained in:
J. Nick Koston 2021-10-13 05:37:14 -10:00 committed by GitHub
parent ffbe4cffae
commit b86e19143d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 353 additions and 284 deletions

View file

@ -1,6 +1,5 @@
"""The dhcp integration."""
from abc import abstractmethod
from datetime import timedelta
import fnmatch
from ipaddress import ip_address as make_ip_address
@ -17,6 +16,7 @@ from aiodiscover.discovery import (
from scapy.config import conf
from scapy.error import Scapy_Exception
from homeassistant import config_entries
from homeassistant.components.device_tracker.const import (
ATTR_HOST_NAME,
ATTR_IP,
@ -31,6 +31,7 @@ from homeassistant.const import (
STATE_HOME,
)
from homeassistant.core import Event, HomeAssistant, State, callback
from homeassistant.helpers import discovery_flow
from homeassistant.helpers.device_registry import format_mac
from homeassistant.helpers.event import (
async_track_state_added_domain,
@ -38,10 +39,9 @@ from homeassistant.helpers.event import (
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import async_get_dhcp
from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.network import is_invalid, is_link_local, is_loopback
from .const import DOMAIN
FILTER = "udp and (port 67 or 68)"
REQUESTED_ADDR = "requested_addr"
MESSAGE_TYPE = "message-type"
@ -89,6 +89,17 @@ class WatcherBase:
self._address_data = address_data
def process_client(self, ip_address, hostname, mac_address):
"""Process a client."""
return run_callback_threadsafe(
self.hass.loop,
self.async_process_client,
ip_address,
hostname,
mac_address,
).result()
@callback
def async_process_client(self, ip_address, hostname, mac_address):
"""Process a client."""
made_ip_address = make_ip_address(ip_address)
@ -101,7 +112,6 @@ class WatcherBase:
return
data = self._address_data.get(ip_address)
if (
data
and data[MAC_ADDRESS] == mac_address
@ -111,12 +121,9 @@ class WatcherBase:
# to process it
return
self._address_data[ip_address] = {MAC_ADDRESS: mac_address, HOSTNAME: hostname}
data = {MAC_ADDRESS: mac_address, HOSTNAME: hostname}
self._address_data[ip_address] = data
self.process_updated_address_data(ip_address, self._address_data[ip_address])
def process_updated_address_data(self, ip_address, data):
"""Process the address data update."""
lowercase_hostname = data[HOSTNAME].lower()
uppercase_mac = data[MAC_ADDRESS].upper()
@ -139,23 +146,17 @@ class WatcherBase:
continue
_LOGGER.debug("Matched %s against %s", data, entry)
self.create_task(
self.hass.config_entries.flow.async_init(
entry["domain"],
context={"source": DOMAIN},
data={
IP_ADDRESS: ip_address,
HOSTNAME: lowercase_hostname,
MAC_ADDRESS: data[MAC_ADDRESS],
},
)
discovery_flow.async_create_flow(
self.hass,
entry["domain"],
{"source": config_entries.SOURCE_DHCP},
{
IP_ADDRESS: ip_address,
HOSTNAME: lowercase_hostname,
MAC_ADDRESS: data[MAC_ADDRESS],
},
)
@abstractmethod
def create_task(self, task):
"""Pass a task to async_add_task based on which context we are in."""
class NetworkWatcher(WatcherBase):
"""Class to query ptr records routers."""
@ -189,21 +190,17 @@ class NetworkWatcher(WatcherBase):
"""Start a new discovery task if one is not running."""
if self._discover_task and not self._discover_task.done():
return
self._discover_task = self.create_task(self.async_discover())
self._discover_task = self.hass.async_create_task(self.async_discover())
async def async_discover(self):
"""Process discovery."""
for host in await self._discover_hosts.async_discover():
self.process_client(
self.async_process_client(
host[DISCOVERY_IP_ADDRESS],
host[DISCOVERY_HOSTNAME],
_format_mac(host[DISCOVERY_MAC_ADDRESS]),
)
def create_task(self, task):
"""Pass a task to async_create_task since we are in async context."""
return self.hass.async_create_task(task)
class DeviceTrackerWatcher(WatcherBase):
"""Class to watch dhcp data from routers."""
@ -250,11 +247,7 @@ class DeviceTrackerWatcher(WatcherBase):
if ip_address is None or mac_address is None:
return
self.process_client(ip_address, hostname, _format_mac(mac_address))
def create_task(self, task):
"""Pass a task to async_create_task since we are in async context."""
return self.hass.async_create_task(task)
self.async_process_client(ip_address, hostname, _format_mac(mac_address))
class DHCPWatcher(WatcherBase):
@ -353,10 +346,6 @@ class DHCPWatcher(WatcherBase):
if self._sniffer.thread:
self._sniffer.thread.name = self.__class__.__name__
def create_task(self, task):
"""Pass a task to hass.add_job since we are in a thread."""
return self.hass.add_job(task)
def _decode_dhcp_option(dhcp_options, key):
"""Extract and decode data from a packet option."""

View file

@ -18,19 +18,14 @@ from async_upnp_client.utils import CaseInsensitiveDict
from homeassistant import config_entries
from homeassistant.components import network
from homeassistant.const import (
EVENT_HOMEASSISTANT_STARTED,
EVENT_HOMEASSISTANT_STOP,
MATCH_ALL,
)
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, MATCH_ALL
from homeassistant.core import HomeAssistant, callback as core_callback
from homeassistant.helpers import discovery_flow
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import async_get_ssdp, bind_hass
from .flow import FlowDispatcher, SSDPFlow
DOMAIN = "ssdp"
SCAN_INTERVAL = timedelta(seconds=60)
@ -222,7 +217,6 @@ class Scanner:
self._cancel_scan: Callable[[], None] | None = None
self._ssdp_listeners: list[SsdpListener] = []
self._callbacks: list[tuple[SsdpCallback, dict[str, str]]] = []
self._flow_dispatcher: FlowDispatcher | None = None
self._description_cache: DescriptionCache | None = None
self.integration_matchers = integration_matchers
@ -327,14 +321,10 @@ class Scanner:
session = async_get_clientsession(self.hass)
requester = AiohttpSessionRequester(session, True, 10)
self._description_cache = DescriptionCache(requester)
self._flow_dispatcher = FlowDispatcher(self.hass)
await self._async_start_ssdp_listeners()
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self.async_stop)
self.hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STARTED, self._flow_dispatcher.async_start
)
self._cancel_scan = async_track_time_interval(
self.hass, self.async_scan, SCAN_INTERVAL
)
@ -417,13 +407,12 @@ class Scanner:
for domain in matching_domains:
_LOGGER.debug("Discovered %s at %s", domain, location)
flow: SSDPFlow = {
"domain": domain,
"context": {"source": config_entries.SOURCE_SSDP},
"data": discovery_info,
}
assert self._flow_dispatcher is not None
self._flow_dispatcher.create(flow)
discovery_flow.async_create_flow(
self.hass,
domain,
{"source": config_entries.SOURCE_SSDP},
discovery_info,
)
async def _async_get_description_dict(
self, location: str | None

View file

@ -1,50 +0,0 @@
"""The SSDP integration."""
from __future__ import annotations
from collections.abc import Coroutine
from typing import Any, TypedDict
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
class SSDPFlow(TypedDict):
"""A queued ssdp discovery flow."""
domain: str
context: dict[str, Any]
data: dict
class FlowDispatcher:
"""Dispatch discovery flows."""
def __init__(self, hass: HomeAssistant) -> None:
"""Init the discovery dispatcher."""
self.hass = hass
self.pending_flows: list[SSDPFlow] = []
self.started = False
@callback
def async_start(self, *_: Any) -> None:
"""Start processing pending flows."""
self.started = True
self.hass.loop.call_soon(self._async_process_pending_flows)
def _async_process_pending_flows(self) -> None:
for flow in self.pending_flows:
self.hass.async_create_task(self._init_flow(flow))
self.pending_flows = []
def create(self, flow: SSDPFlow) -> None:
"""Create and add or queue a flow."""
if self.started:
self.hass.async_create_task(self._init_flow(flow))
else:
self.pending_flows.append(flow)
def _init_flow(self, flow: SSDPFlow) -> Coroutine[None, None, FlowResult]:
"""Create a flow."""
return self.hass.config_entries.flow.async_init(
flow["domain"], context=flow["context"], data=flow["data"]
)

View file

@ -16,13 +16,12 @@ from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.connection import ActiveConnection
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers import system_info
from homeassistant.helpers import discovery_flow, system_info
from homeassistant.helpers.debounce import Debouncer
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import async_get_usb
from .const import DOMAIN
from .flow import FlowDispatcher, USBFlow
from .models import USBDevice
from .utils import usb_device_from_port
@ -65,7 +64,7 @@ def get_serial_by_id(dev_path: str) -> str:
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the USB Discovery integration."""
usb = await async_get_usb(hass)
usb_discovery = USBDiscovery(hass, FlowDispatcher(hass), usb)
usb_discovery = USBDiscovery(hass, usb)
await usb_discovery.async_setup()
hass.data[DOMAIN] = usb_discovery
websocket_api.async_register_command(hass, websocket_usb_scan)
@ -86,12 +85,10 @@ class USBDiscovery:
def __init__(
self,
hass: HomeAssistant,
flow_dispatcher: FlowDispatcher,
usb: list[dict[str, str]],
) -> None:
"""Init USB Discovery."""
self.hass = hass
self.flow_dispatcher = flow_dispatcher
self.usb = usb
self.seen: set[tuple[str, ...]] = set()
self.observer_active = False
@ -104,7 +101,6 @@ class USBDiscovery:
async def async_start(self, event: Event) -> None:
"""Start USB Discovery and run a manual scan."""
self.flow_dispatcher.async_start()
await self._async_scan_serial()
async def _async_start_monitor(self) -> None:
@ -193,12 +189,12 @@ class USBDiscovery:
if len(matcher) < most_matched_fields:
break
flow: USBFlow = {
"domain": matcher["domain"],
"context": {"source": config_entries.SOURCE_USB},
"data": dataclasses.asdict(device),
}
self.flow_dispatcher.async_create(flow)
discovery_flow.async_create_flow(
self.hass,
matcher["domain"],
{"source": config_entries.SOURCE_USB},
dataclasses.asdict(device),
)
@callback
def _async_process_ports(self, ports: list[ListPortInfo]) -> None:

View file

@ -1,48 +0,0 @@
"""The USB Discovery integration."""
from __future__ import annotations
from collections.abc import Coroutine
from typing import Any, TypedDict
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
class USBFlow(TypedDict):
"""A queued usb discovery flow."""
domain: str
context: dict[str, Any]
data: dict
class FlowDispatcher:
"""Dispatch discovery flows."""
def __init__(self, hass: HomeAssistant) -> None:
"""Init the discovery dispatcher."""
self.hass = hass
self.pending_flows: list[USBFlow] = []
self.started = False
@callback
def async_start(self, *_: Any) -> None:
"""Start processing pending flows."""
self.started = True
for flow in self.pending_flows:
self.hass.async_create_task(self._init_flow(flow))
self.pending_flows = []
@callback
def async_create(self, flow: USBFlow) -> None:
"""Create and add or queue a flow."""
if self.started:
self.hass.async_create_task(self._init_flow(flow))
else:
self.pending_flows.append(flow)
def _init_flow(self, flow: USBFlow) -> Coroutine[None, None, FlowResult]:
"""Create a flow."""
return self.hass.config_entries.flow.async_init(
flow["domain"], context=flow["context"], data=flow["data"]
)

View file

@ -2,7 +2,6 @@
from __future__ import annotations
import asyncio
from collections.abc import Coroutine
from contextlib import suppress
import fnmatch
from ipaddress import IPv4Address, IPv6Address, ip_address
@ -21,12 +20,11 @@ from homeassistant.components.network import async_get_source_ip
from homeassistant.components.network.models import Adapter
from homeassistant.const import (
EVENT_HOMEASSISTANT_START,
EVENT_HOMEASSISTANT_STARTED,
EVENT_HOMEASSISTANT_STOP,
__version__,
)
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import discovery_flow
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.network import NoURLAvailableError, get_url
from homeassistant.helpers.typing import ConfigType
@ -91,14 +89,6 @@ class HaServiceInfo(TypedDict):
properties: dict[str, Any]
class ZeroconfFlow(TypedDict):
"""A queued zeroconf discovery flow."""
domain: str
context: dict[str, Any]
data: HaServiceInfo
@bind_hass
async def async_get_instance(hass: HomeAssistant) -> HaZeroconf:
"""Zeroconf instance to be shared with other integrations that use it."""
@ -192,17 +182,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
uuid = await hass.helpers.instance_id.async_get()
await _async_register_hass_zc_service(hass, aio_zc, uuid)
@callback
def _async_start_discovery(_event: Event) -> None:
"""Start processing flows."""
discovery.async_start()
async def _async_zeroconf_hass_stop(_event: Event) -> None:
await discovery.async_stop()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_zeroconf_hass_stop)
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, _async_zeroconf_hass_start)
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, _async_start_discovery)
return True
@ -288,40 +272,6 @@ async def _async_register_hass_zc_service(
await aio_zc.async_register_service(info, allow_name_change=True)
class FlowDispatcher:
"""Dispatch discovery flows."""
def __init__(self, hass: HomeAssistant) -> None:
"""Init the discovery dispatcher."""
self.hass = hass
self.pending_flows: list[ZeroconfFlow] = []
self.started = False
@callback
def async_start(self) -> None:
"""Start processing pending flows."""
self.started = True
self.hass.loop.call_soon(self._async_process_pending_flows)
def _async_process_pending_flows(self) -> None:
for flow in self.pending_flows:
self.hass.async_create_task(self._init_flow(flow))
self.pending_flows = []
def async_create(self, flow: ZeroconfFlow) -> None:
"""Create and add or queue a flow."""
if self.started:
self.hass.async_create_task(self._init_flow(flow))
else:
self.pending_flows.append(flow)
def _init_flow(self, flow: ZeroconfFlow) -> Coroutine[None, None, FlowResult]:
"""Create a flow."""
return self.hass.config_entries.flow.async_init(
flow["domain"], context=flow["context"], data=flow["data"]
)
class ZeroconfDiscovery:
"""Discovery via zeroconf."""
@ -340,12 +290,10 @@ class ZeroconfDiscovery:
self.homekit_models = homekit_models
self.ipv6 = ipv6
self.flow_dispatcher: FlowDispatcher | None = None
self.async_service_browser: HaAsyncServiceBrowser | None = None
async def async_setup(self) -> None:
"""Start discovery."""
self.flow_dispatcher = FlowDispatcher(self.hass)
types = list(self.zeroconf_types)
# We want to make sure we know about other HomeAssistant
# instances as soon as possible to avoid name conflicts
@ -363,12 +311,6 @@ class ZeroconfDiscovery:
if self.async_service_browser:
await self.async_service_browser.async_cancel()
@callback
def async_start(self) -> None:
"""Start processing discovery flows."""
assert self.flow_dispatcher is not None
self.flow_dispatcher.async_start()
@callback
def async_service_update(
self,
@ -404,12 +346,14 @@ class ZeroconfDiscovery:
return
_LOGGER.debug("Discovered new device %s %s", name, info)
assert self.flow_dispatcher is not None
# If we can handle it as a HomeKit discovery, we do that here.
if service_type in HOMEKIT_TYPES:
if pending_flow := handle_homekit(self.hass, self.homekit_models, info):
self.flow_dispatcher.async_create(pending_flow)
props = info["properties"]
if domain := async_get_homekit_discovery_domain(self.homekit_models, props):
discovery_flow.async_create_flow(
self.hass, domain, {"source": config_entries.SOURCE_HOMEKIT}, info
)
# Continue on here as homekit_controller
# still needs to get updates on devices
# so it can see when the 'c#' field is updated.
@ -417,10 +361,10 @@ class ZeroconfDiscovery:
# We only send updates to homekit_controller
# if the device is already paired in order to avoid
# offering a second discovery for the same device
if pending_flow and HOMEKIT_PAIRED_STATUS_FLAG in info["properties"]:
if domain and HOMEKIT_PAIRED_STATUS_FLAG in props:
try:
# 0 means paired and not discoverable by iOS clients)
if int(info["properties"][HOMEKIT_PAIRED_STATUS_FLAG]):
if int(props[HOMEKIT_PAIRED_STATUS_FLAG]):
return
except ValueError:
# HomeKit pairing status unknown
@ -466,24 +410,22 @@ class ZeroconfDiscovery:
):
continue
flow: ZeroconfFlow = {
"domain": matcher["domain"],
"context": {"source": config_entries.SOURCE_ZEROCONF},
"data": info,
}
self.flow_dispatcher.async_create(flow)
discovery_flow.async_create_flow(
self.hass,
matcher["domain"],
{"source": config_entries.SOURCE_ZEROCONF},
info,
)
def handle_homekit(
hass: HomeAssistant, homekit_models: dict[str, str], info: HaServiceInfo
) -> ZeroconfFlow | None:
def async_get_homekit_discovery_domain(
homekit_models: dict[str, str], props: dict[str, Any]
) -> str | None:
"""Handle a HomeKit discovery.
Return if discovery was forwarded.
Return the domain to forward the discovery data to
"""
model = None
props = info["properties"]
for key in props:
if key.lower() == HOMEKIT_MODEL:
model = props[key]
@ -500,11 +442,7 @@ def handle_homekit(
):
continue
return {
"domain": homekit_models[test_model],
"context": {"source": config_entries.SOURCE_HOMEKIT},
"data": info,
}
return homekit_models[test_model]
return None

View file

@ -120,6 +120,19 @@ class FlowManager(abc.ABC):
async def async_post_init(self, flow: FlowHandler, result: FlowResult) -> None:
"""Entry has finished executing its first step asynchronously."""
@callback
def async_has_matching_flow(
self, handler: str, context: dict[str, Any], data: Any
) -> bool:
"""Check if an existing matching flow is in progress with the same handler, context, and data."""
return any(
flow
for flow in self._progress.values()
if flow.handler == handler
and flow.context["source"] == context["source"]
and flow.init_data == data
)
@callback
def async_progress(self, include_uninitialized: bool = False) -> list[FlowResult]:
"""Return the flows in progress."""
@ -173,6 +186,7 @@ class FlowManager(abc.ABC):
flow.handler = handler
flow.flow_id = uuid.uuid4().hex
flow.context = context
flow.init_data = data
self._progress[flow.flow_id] = flow
result = await self._async_handle_step(flow, flow.init_step, data, init_done)
return flow, result
@ -318,6 +332,9 @@ class FlowHandler:
# Set by _async_create_flow callback
init_step = "init"
# The initial data that was used to start the flow
init_data: Any = None
# Set by developer
VERSION = 1

View file

@ -0,0 +1,82 @@
"""The discovery flow helper."""
from __future__ import annotations
from collections.abc import Coroutine
from typing import Any
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.core import CoreState, Event, HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.loader import bind_hass
from homeassistant.util.async_ import gather_with_concurrency
FLOW_INIT_LIMIT = 2
DISCOVERY_FLOW_DISPATCHER = "discovery_flow_disptacher"
@bind_hass
@callback
def async_create_flow(
hass: HomeAssistant, domain: str, context: dict[str, Any], data: Any
) -> None:
"""Create a discovery flow."""
if hass.state == CoreState.running:
if init_coro := _async_init_flow(hass, domain, context, data):
hass.async_create_task(init_coro)
return
if DISCOVERY_FLOW_DISPATCHER not in hass.data:
dispatcher = hass.data[DISCOVERY_FLOW_DISPATCHER] = FlowDispatcher(hass)
dispatcher.async_setup()
else:
dispatcher = hass.data[DISCOVERY_FLOW_DISPATCHER]
return dispatcher.async_create(domain, context, data)
@callback
def _async_init_flow(
hass: HomeAssistant, domain: str, context: dict[str, Any], data: Any
) -> Coroutine[None, None, FlowResult] | None:
"""Create a discovery flow."""
# Avoid spawning flows that have the same initial discovery data
# as ones in progress as it may cause additional device probing
# which can overload devices since zeroconf/ssdp updates can happen
# multiple times in the same minute
if hass.config_entries.flow.async_has_matching_flow(domain, context, data):
return None
return hass.config_entries.flow.async_init(domain, context=context, data=data)
class FlowDispatcher:
"""Dispatch discovery flows."""
def __init__(self, hass: HomeAssistant) -> None:
"""Init the discovery dispatcher."""
self.hass = hass
self.pending_flows: list[tuple[str, dict[str, Any], Any]] = []
@callback
def async_setup(self) -> None:
"""Set up the flow disptcher."""
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, self.async_start)
@callback
def async_start(self, event: Event) -> None:
"""Start processing pending flows."""
self.hass.data.pop(DISCOVERY_FLOW_DISPATCHER)
self.hass.async_create_task(self._async_process_pending_flows())
async def _async_process_pending_flows(self) -> None:
"""Process any pending discovery flows."""
init_coros = [_async_init_flow(self.hass, *flow) for flow in self.pending_flows]
await gather_with_concurrency(
FLOW_INIT_LIMIT,
*[init_coro for init_coro in init_coros if init_coro is not None],
)
@callback
def async_create(self, domain: str, context: dict[str, Any], data: Any) -> None:
"""Create and add or queue a flow."""
self.pending_flows.append((domain, context, data))

View file

@ -3,6 +3,7 @@ import datetime
import threading
from unittest.mock import MagicMock, patch
from scapy import arch # pylint: unused-import # noqa: F401
from scapy.error import Scapy_Exception
from scapy.layers.dhcp import DHCP
from scapy.layers.l2 import Ether
@ -16,6 +17,7 @@ from homeassistant.components.device_tracker.const import (
ATTR_SOURCE_TYPE,
SOURCE_TYPE_ROUTER,
)
from homeassistant.components.dhcp.const import DOMAIN
from homeassistant.const import (
EVENT_HOMEASSISTANT_STARTED,
EVENT_HOMEASSISTANT_STOP,
@ -129,11 +131,16 @@ async def _async_get_handle_dhcp_packet(hass, integration_matchers):
{},
integration_matchers,
)
handle_dhcp_packet = None
async_handle_dhcp_packet = None
def _mock_sniffer(*args, **kwargs):
nonlocal handle_dhcp_packet
handle_dhcp_packet = kwargs["prn"]
nonlocal async_handle_dhcp_packet
callback = kwargs["prn"]
async def _async_handle_dhcp_packet(packet):
await hass.async_add_executor_job(callback, packet)
async_handle_dhcp_packet = _async_handle_dhcp_packet
return MagicMock()
with patch("homeassistant.components.dhcp._verify_l2socket_setup",), patch(
@ -141,7 +148,7 @@ async def _async_get_handle_dhcp_packet(hass, integration_matchers):
), patch("scapy.sendrecv.AsyncSniffer", _mock_sniffer):
await dhcp_watcher.async_start()
return handle_dhcp_packet
return async_handle_dhcp_packet
async def test_dhcp_match_hostname_and_macaddress(hass):
@ -151,11 +158,13 @@ async def test_dhcp_match_hostname_and_macaddress(hass):
]
packet = Ether(RAW_DHCP_REQUEST)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
async_handle_dhcp_packet = await _async_get_handle_dhcp_packet(
hass, integration_matchers
)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
handle_dhcp_packet(packet)
await async_handle_dhcp_packet(packet)
# Ensure no change is ignored
handle_dhcp_packet(packet)
await async_handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain"
@ -177,11 +186,13 @@ async def test_dhcp_renewal_match_hostname_and_macaddress(hass):
packet = Ether(RAW_DHCP_RENEWAL)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
async_handle_dhcp_packet = await _async_get_handle_dhcp_packet(
hass, integration_matchers
)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
handle_dhcp_packet(packet)
await async_handle_dhcp_packet(packet)
# Ensure no change is ignored
handle_dhcp_packet(packet)
await async_handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain"
@ -201,9 +212,11 @@ async def test_dhcp_match_hostname(hass):
packet = Ether(RAW_DHCP_REQUEST)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
async_handle_dhcp_packet = await _async_get_handle_dhcp_packet(
hass, integration_matchers
)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
handle_dhcp_packet(packet)
await async_handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain"
@ -223,9 +236,11 @@ async def test_dhcp_match_macaddress(hass):
packet = Ether(RAW_DHCP_REQUEST)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
async_handle_dhcp_packet = await _async_get_handle_dhcp_packet(
hass, integration_matchers
)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
handle_dhcp_packet(packet)
await async_handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain"
@ -245,9 +260,11 @@ async def test_dhcp_match_macaddress_without_hostname(hass):
packet = Ether(RAW_DHCP_REQUEST_WITHOUT_HOSTNAME)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
async_handle_dhcp_packet = await _async_get_handle_dhcp_packet(
hass, integration_matchers
)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
handle_dhcp_packet(packet)
await async_handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 1
assert mock_init.mock_calls[0][1][0] == "mock-domain"
@ -267,9 +284,11 @@ async def test_dhcp_nomatch(hass):
packet = Ether(RAW_DHCP_REQUEST)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
async_handle_dhcp_packet = await _async_get_handle_dhcp_packet(
hass, integration_matchers
)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
handle_dhcp_packet(packet)
await async_handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0
@ -280,9 +299,11 @@ async def test_dhcp_nomatch_hostname(hass):
packet = Ether(RAW_DHCP_REQUEST)
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
async_handle_dhcp_packet = await _async_get_handle_dhcp_packet(
hass, integration_matchers
)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
handle_dhcp_packet(packet)
await async_handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0
@ -293,9 +314,11 @@ async def test_dhcp_nomatch_non_dhcp_packet(hass):
packet = Ether(b"")
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
async_handle_dhcp_packet = await _async_get_handle_dhcp_packet(
hass, integration_matchers
)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
handle_dhcp_packet(packet)
await async_handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0
@ -315,9 +338,11 @@ async def test_dhcp_nomatch_non_dhcp_request_packet(hass):
("hostname", b"connect"),
]
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
async_handle_dhcp_packet = await _async_get_handle_dhcp_packet(
hass, integration_matchers
)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
handle_dhcp_packet(packet)
await async_handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0
@ -337,9 +362,11 @@ async def test_dhcp_invalid_hostname(hass):
("hostname", "connect"),
]
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
async_handle_dhcp_packet = await _async_get_handle_dhcp_packet(
hass, integration_matchers
)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
handle_dhcp_packet(packet)
await async_handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0
@ -359,9 +386,11 @@ async def test_dhcp_missing_hostname(hass):
("hostname", None),
]
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
async_handle_dhcp_packet = await _async_get_handle_dhcp_packet(
hass, integration_matchers
)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
handle_dhcp_packet(packet)
await async_handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0
@ -381,9 +410,11 @@ async def test_dhcp_invalid_option(hass):
("hostname"),
]
handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers)
async_handle_dhcp_packet = await _async_get_handle_dhcp_packet(
hass, integration_matchers
)
with patch.object(hass.config_entries.flow, "async_init") as mock_init:
handle_dhcp_packet(packet)
await async_handle_dhcp_packet(packet)
assert len(mock_init.mock_calls) == 0
@ -393,7 +424,7 @@ async def test_setup_and_stop(hass):
assert await async_setup_component(
hass,
dhcp.DOMAIN,
DOMAIN,
{},
)
await hass.async_block_till_done()
@ -417,7 +448,7 @@ async def test_setup_fails_as_root(hass, caplog):
assert await async_setup_component(
hass,
dhcp.DOMAIN,
DOMAIN,
{},
)
await hass.async_block_till_done()
@ -442,7 +473,7 @@ async def test_setup_fails_non_root(hass, caplog):
assert await async_setup_component(
hass,
dhcp.DOMAIN,
DOMAIN,
{},
)
await hass.async_block_till_done()
@ -464,7 +495,7 @@ async def test_setup_fails_with_broken_libpcap(hass, caplog):
assert await async_setup_component(
hass,
dhcp.DOMAIN,
DOMAIN,
{},
)
await hass.async_block_till_done()

View file

@ -431,7 +431,9 @@ async def test_scan_with_registered_callback(
"homeassistant.components.ssdp.async_get_ssdp",
return_value={"mock-domain": [{"st": "mock-st"}]},
)
async def test_getting_existing_headers(mock_get_ssdp, hass, aioclient_mock):
async def test_getting_existing_headers(
mock_get_ssdp, hass, aioclient_mock, mock_flow_init
):
"""Test getting existing/previously scanned headers."""
aioclient_mock.get(
"http://1.1.1.1",

View file

@ -0,0 +1,71 @@
"""Test the discovery flow helper."""
from unittest.mock import AsyncMock, call, patch
import pytest
from homeassistant import config_entries
from homeassistant.core import EVENT_HOMEASSISTANT_STARTED, CoreState
from homeassistant.helpers import discovery_flow
@pytest.fixture
def mock_flow_init(hass):
"""Mock hass.config_entries.flow.async_init."""
with patch.object(
hass.config_entries.flow, "async_init", return_value=AsyncMock()
) as mock_init:
yield mock_init
async def test_async_create_flow(hass, mock_flow_init):
"""Test we can create a flow."""
discovery_flow.async_create_flow(
hass,
"hue",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
assert mock_flow_init.mock_calls == [
call(
"hue",
context={"source": "homekit"},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
]
async def test_async_create_flow_deferred_until_started(hass, mock_flow_init):
"""Test flows are deferred until started."""
hass.state = CoreState.stopped
discovery_flow.async_create_flow(
hass,
"hue",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
assert not mock_flow_init.mock_calls
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
assert mock_flow_init.mock_calls == [
call(
"hue",
context={"source": "homekit"},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
]
async def test_async_create_flow_checks_existing_flows(hass, mock_flow_init):
"""Test existing flows prevent an identical one from being creates."""
with patch(
"homeassistant.data_entry_flow.FlowManager.async_has_matching_flow",
return_value=True,
):
discovery_flow.async_create_flow(
hass,
"hue",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
assert not mock_flow_init.mock_calls

View file

@ -6,6 +6,7 @@ import pytest
import voluptuous as vol
from homeassistant import config_entries, data_entry_flow
from homeassistant.core import HomeAssistant
from homeassistant.util.decorator import Registry
from tests.common import async_capture_events
@ -397,3 +398,54 @@ async def test_init_unknown_flow(manager):
manager, "async_create_flow", return_value=None
):
await manager.async_init("test")
async def test_async_has_matching_flow(
hass: HomeAssistant, manager: data_entry_flow.FlowManager
):
"""Test we can check for matching flows."""
manager.hass = hass
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
async def async_step_init(self, user_input=None):
return self.async_show_progress(
step_id="init",
progress_action="task_one",
)
result = await manager.async_init(
"test",
context={"source": config_entries.SOURCE_HOMEKIT},
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
assert result["type"] == data_entry_flow.RESULT_TYPE_SHOW_PROGRESS
assert result["progress_action"] == "task_one"
assert len(manager.async_progress()) == 1
assert (
manager.async_has_matching_flow(
"test",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is True
)
assert (
manager.async_has_matching_flow(
"test",
{"source": config_entries.SOURCE_SSDP},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is False
)
assert (
manager.async_has_matching_flow(
"other",
{"source": config_entries.SOURCE_HOMEKIT},
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
)
is False
)