From c0e22be7a875d7a70010a21d16714c2051d98b3e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 15 Feb 2023 20:36:00 -0600 Subject: [PATCH] Fix allowing identical flows to be created before startup (#88213) The check for identical flows only worked after the start event. We now check against pending flows as well If startup took a while we could end up with quite the thundering herd --- homeassistant/helpers/discovery_flow.py | 54 ++++++++++++++++++------- tests/helpers/test_discovery_flow.py | 30 +++++++++++++- 2 files changed, 68 insertions(+), 16 deletions(-) diff --git a/homeassistant/helpers/discovery_flow.py b/homeassistant/helpers/discovery_flow.py index 863fb58625c..2bfccf46960 100644 --- a/homeassistant/helpers/discovery_flow.py +++ b/homeassistant/helpers/discovery_flow.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Coroutine -from typing import Any +from typing import Any, NamedTuple from homeassistant.const import EVENT_HOMEASSISTANT_STARTED from homeassistant.core import CoreState, Event, HomeAssistant, callback @@ -20,17 +20,18 @@ def async_create_flow( hass: HomeAssistant, domain: str, context: dict[str, Any], data: Any ) -> None: """Create a discovery flow.""" - if hass.state == CoreState.running: + dispatcher: FlowDispatcher | None = None + if DISCOVERY_FLOW_DISPATCHER in hass.data: + dispatcher = hass.data[DISCOVERY_FLOW_DISPATCHER] + elif hass.state != CoreState.running: + dispatcher = hass.data[DISCOVERY_FLOW_DISPATCHER] = FlowDispatcher(hass) + dispatcher.async_setup() + + if not dispatcher or dispatcher.started: if init_coro := _async_init_flow(hass, domain, context, data): hass.async_create_task(init_coro) return - if DISCOVERY_FLOW_DISPATCHER not in hass.data: - dispatcher = hass.data[DISCOVERY_FLOW_DISPATCHER] = FlowDispatcher(hass) - dispatcher.async_setup() - else: - dispatcher = hass.data[DISCOVERY_FLOW_DISPATCHER] - return dispatcher.async_create(domain, context, data) @@ -49,13 +50,28 @@ def _async_init_flow( return hass.config_entries.flow.async_init(domain, context=context, data=data) +class PendingFlowKey(NamedTuple): + """Key for pending flows.""" + + domain: str + source: str + + +class PendingFlowValue(NamedTuple): + """Value for pending flows.""" + + context: dict[str, Any] + data: Any + + class FlowDispatcher: """Dispatch discovery flows.""" def __init__(self, hass: HomeAssistant) -> None: """Init the discovery dispatcher.""" self.hass = hass - self.pending_flows: list[tuple[str, dict[str, Any], Any]] = [] + self.started = False + self.pending_flows: dict[PendingFlowKey, list[PendingFlowValue]] = {} @callback def async_setup(self) -> None: @@ -64,10 +80,16 @@ class FlowDispatcher: async def _async_start(self, event: Event) -> None: """Start processing pending flows.""" - self.hass.data.pop(DISCOVERY_FLOW_DISPATCHER) - - init_coros = [_async_init_flow(self.hass, *flow) for flow in self.pending_flows] - + pending_flows = self.pending_flows + self.pending_flows = {} + self.started = True + init_coros = [ + _async_init_flow( + self.hass, flow_key.domain, flow_values.context, flow_values.data + ) + for flow_key, flows in pending_flows.items() + for flow_values in flows + ] await gather_with_concurrency( FLOW_INIT_LIMIT, *[init_coro for init_coro in init_coros if init_coro is not None], @@ -76,4 +98,8 @@ class FlowDispatcher: @callback def async_create(self, domain: str, context: dict[str, Any], data: Any) -> None: """Create and add or queue a flow.""" - self.pending_flows.append((domain, context, data)) + key = PendingFlowKey(domain, context["source"]) + values = PendingFlowValue(context, data) + existing = self.pending_flows.setdefault(key, []) + if not any(existing_values.data == data for existing_values in existing): + existing.append(values) diff --git a/tests/helpers/test_discovery_flow.py b/tests/helpers/test_discovery_flow.py index 549848e5c7b..4019be80315 100644 --- a/tests/helpers/test_discovery_flow.py +++ b/tests/helpers/test_discovery_flow.py @@ -56,8 +56,11 @@ async def test_async_create_flow_deferred_until_started(hass, mock_flow_init): ] -async def test_async_create_flow_checks_existing_flows(hass, mock_flow_init): - """Test existing flows prevent an identical one from being creates.""" +async def test_async_create_flow_checks_existing_flows_after_startup( + hass, mock_flow_init +): + """Test existing flows prevent an identical ones from being after startup.""" + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) with patch( "homeassistant.data_entry_flow.FlowManager.async_has_matching_flow", return_value=True, @@ -69,3 +72,26 @@ async def test_async_create_flow_checks_existing_flows(hass, mock_flow_init): {"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, ) assert not mock_flow_init.mock_calls + + +async def test_async_create_flow_checks_existing_flows_before_startup( + hass, mock_flow_init +): + """Test existing flows prevent an identical ones from being created before startup.""" + hass.state = CoreState.stopped + for _ in range(2): + discovery_flow.async_create_flow( + hass, + "hue", + {"source": config_entries.SOURCE_HOMEKIT}, + {"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + assert mock_flow_init.mock_calls == [ + call( + "hue", + context={"source": "homekit"}, + data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + ]