Add FlowManager.async_has_matching_flow (#126804)
* Add FlowManager.async_flow_has_matching_flow * Revert changes from the future * Apply suggested changes to apple_tv config flow * Rename methods after discussion * Update homeassistant/data_entry_flow.py Co-authored-by: J. Nick Koston <nick@koston.org> * Move deduplication functions to config_entries, add tests * Adjust tests --------- Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
parent
26b5dab12b
commit
3c0be47d3c
7 changed files with 262 additions and 121 deletions
|
@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable, Mapping
|
||||||
from ipaddress import ip_address
|
from ipaddress import ip_address
|
||||||
import logging
|
import logging
|
||||||
from random import randrange
|
from random import randrange
|
||||||
from typing import Any
|
from typing import Any, Self
|
||||||
|
|
||||||
from pyatv import exceptions, pair, scan
|
from pyatv import exceptions, pair, scan
|
||||||
from pyatv.const import DeviceModel, PairingRequirement, Protocol
|
from pyatv.const import DeviceModel, PairingRequirement, Protocol
|
||||||
|
@ -98,8 +98,11 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||||
VERSION = 1
|
VERSION = 1
|
||||||
|
|
||||||
scan_filter: str | None = None
|
scan_filter: str | None = None
|
||||||
|
all_identifiers: set[str]
|
||||||
atv: BaseConfig | None = None
|
atv: BaseConfig | None = None
|
||||||
atv_identifiers: list[str] | None = None
|
atv_identifiers: list[str] | None = None
|
||||||
|
_host: str # host in zeroconf discovery info, should not be accessed by other flows
|
||||||
|
host: str | None = None # set by _async_aggregate_discoveries, for other flows
|
||||||
protocol: Protocol | None = None
|
protocol: Protocol | None = None
|
||||||
pairing: PairingHandler | None = None
|
pairing: PairingHandler | None = None
|
||||||
protocols_to_pair: deque[Protocol] | None = None
|
protocols_to_pair: deque[Protocol] | None = None
|
||||||
|
@ -157,7 +160,6 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||||
"type": "Apple TV",
|
"type": "Apple TV",
|
||||||
}
|
}
|
||||||
self.scan_filter = self.unique_id
|
self.scan_filter = self.unique_id
|
||||||
self.context["identifier"] = self.unique_id
|
|
||||||
return await self.async_step_restore_device()
|
return await self.async_step_restore_device()
|
||||||
|
|
||||||
async def async_step_restore_device(
|
async def async_step_restore_device(
|
||||||
|
@ -192,7 +194,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||||
self.device_identifier, raise_on_progress=False
|
self.device_identifier, raise_on_progress=False
|
||||||
)
|
)
|
||||||
assert self.atv
|
assert self.atv
|
||||||
self.context["all_identifiers"] = self.atv.all_identifiers
|
self.all_identifiers = set(self.atv.all_identifiers)
|
||||||
return await self.async_step_confirm()
|
return await self.async_step_confirm()
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
|
@ -207,7 +209,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||||
"""Handle device found via zeroconf."""
|
"""Handle device found via zeroconf."""
|
||||||
if discovery_info.ip_address.version == 6:
|
if discovery_info.ip_address.version == 6:
|
||||||
return self.async_abort(reason="ipv6_not_supported")
|
return self.async_abort(reason="ipv6_not_supported")
|
||||||
host = discovery_info.host
|
self._host = host = discovery_info.host
|
||||||
service_type = discovery_info.type[:-1] # Remove leading .
|
service_type = discovery_info.type[:-1] # Remove leading .
|
||||||
name = discovery_info.name.replace(f".{service_type}.", "")
|
name = discovery_info.name.replace(f".{service_type}.", "")
|
||||||
properties = discovery_info.properties
|
properties = discovery_info.properties
|
||||||
|
@ -255,7 +257,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||||
# as two separate flows.
|
# as two separate flows.
|
||||||
#
|
#
|
||||||
# To solve this, all identifiers are stored as
|
# To solve this, all identifiers are stored as
|
||||||
# "all_identifiers" in the flow context. When a new service is discovered, the
|
# "all_identifiers" in the flow. When a new service is discovered, the
|
||||||
# code below will check these identifiers for all active flows and abort if a
|
# code below will check these identifiers for all active flows and abort if a
|
||||||
# match is found. Before aborting, the original flow is updated with any
|
# match is found. Before aborting, the original flow is updated with any
|
||||||
# potentially new identifiers. In the example above, when service C is
|
# potentially new identifiers. In the example above, when service C is
|
||||||
|
@ -277,32 +279,32 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||||
self._async_check_and_update_in_progress(host, unique_id)
|
self._async_check_and_update_in_progress(host, unique_id)
|
||||||
# Host must only be set AFTER checking and updating in progress
|
# Host must only be set AFTER checking and updating in progress
|
||||||
# flows or we will have a race condition where no flows move forward.
|
# flows or we will have a race condition where no flows move forward.
|
||||||
self.context[CONF_ADDRESS] = host
|
self.host = host
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_check_and_update_in_progress(self, host: str, unique_id: str) -> None:
|
def _async_check_and_update_in_progress(self, host: str, unique_id: str) -> None:
|
||||||
"""Check for in-progress flows and update them with identifiers if needed."""
|
"""Check for in-progress flows and update them with identifiers if needed."""
|
||||||
for flow in self._async_in_progress(include_uninitialized=True):
|
if self.hass.config_entries.flow.async_has_matching_flow(self):
|
||||||
context = flow["context"]
|
|
||||||
if (
|
|
||||||
context.get("source") != SOURCE_ZEROCONF
|
|
||||||
or context.get(CONF_ADDRESS) != host
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
if (
|
|
||||||
"all_identifiers" in context
|
|
||||||
and unique_id not in context["all_identifiers"]
|
|
||||||
):
|
|
||||||
# Add potentially new identifiers from this device to the existing flow
|
|
||||||
context["all_identifiers"].append(unique_id)
|
|
||||||
raise AbortFlow("already_in_progress")
|
raise AbortFlow("already_in_progress")
|
||||||
|
|
||||||
|
def is_matching(self, other_flow: Self) -> bool:
|
||||||
|
"""Return True if other_flow is matching this flow."""
|
||||||
|
if (
|
||||||
|
other_flow.context.get("source") != SOURCE_ZEROCONF
|
||||||
|
or other_flow.host != self._host
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
if self.unique_id is not None:
|
||||||
|
# Add potentially new identifiers from this device to the existing flow
|
||||||
|
other_flow.all_identifiers.add(self.unique_id)
|
||||||
|
return True
|
||||||
|
|
||||||
async def async_found_zeroconf_device(
|
async def async_found_zeroconf_device(
|
||||||
self, user_input: dict[str, str] | None = None
|
self, user_input: dict[str, str] | None = None
|
||||||
) -> ConfigFlowResult:
|
) -> ConfigFlowResult:
|
||||||
"""Handle device found after Zeroconf discovery."""
|
"""Handle device found after Zeroconf discovery."""
|
||||||
assert self.atv
|
assert self.atv
|
||||||
self.context["all_identifiers"] = self.atv.all_identifiers
|
self.all_identifiers = set(self.atv.all_identifiers)
|
||||||
# Also abort if an integration with this identifier already exists
|
# Also abort if an integration with this identifier already exists
|
||||||
await self.async_set_unique_id(self.device_identifier)
|
await self.async_set_unique_id(self.device_identifier)
|
||||||
# but be sure to update the address if its changed so the scanner
|
# but be sure to update the address if its changed so the scanner
|
||||||
|
@ -310,7 +312,6 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||||
self._abort_if_unique_id_configured(
|
self._abort_if_unique_id_configured(
|
||||||
updates={CONF_ADDRESS: str(self.atv.address)}
|
updates={CONF_ADDRESS: str(self.atv.address)}
|
||||||
)
|
)
|
||||||
self.context["identifier"] = self.unique_id
|
|
||||||
return await self.async_step_confirm()
|
return await self.async_step_confirm()
|
||||||
|
|
||||||
async def async_find_device_wrapper(
|
async def async_find_device_wrapper(
|
||||||
|
@ -390,7 +391,7 @@ class AppleTVConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||||
"""Handle user-confirmation of discovered node."""
|
"""Handle user-confirmation of discovered node."""
|
||||||
assert self.atv
|
assert self.atv
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
expected_identifier_count = len(self.context["all_identifiers"])
|
expected_identifier_count = len(self.all_identifiers)
|
||||||
# If number of services found during device scan mismatch number of
|
# If number of services found during device scan mismatch number of
|
||||||
# identifiers collected during Zeroconf discovery, then trigger a new scan
|
# identifiers collected during Zeroconf discovery, then trigger a new scan
|
||||||
# with hopes of finding all services.
|
# with hopes of finding all services.
|
||||||
|
|
|
@ -1544,6 +1544,35 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
||||||
notification_id=DISCOVERY_NOTIFICATION_ID,
|
notification_id=DISCOVERY_NOTIFICATION_ID,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_has_matching_discovery_flow(
|
||||||
|
self, handler: str, match_context: dict[str, Any], data: Any
|
||||||
|
) -> bool:
|
||||||
|
"""Check if an existing matching discovery flow is in progress.
|
||||||
|
|
||||||
|
A flow with the same handler, context, and data.
|
||||||
|
|
||||||
|
If match_context is passed, only return flows with a context that is a
|
||||||
|
superset of match_context.
|
||||||
|
"""
|
||||||
|
if not (flows := self._handler_progress_index.get(handler)):
|
||||||
|
return False
|
||||||
|
match_items = match_context.items()
|
||||||
|
for progress in flows:
|
||||||
|
if match_items <= progress.context.items() and progress.init_data == data:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_has_matching_flow(self, flow: ConfigFlow) -> bool:
|
||||||
|
"""Check if an existing matching flow is in progress."""
|
||||||
|
if not (flows := self._handler_progress_index.get(flow.handler)):
|
||||||
|
return False
|
||||||
|
for other_flow in flows:
|
||||||
|
if other_flow is not flow and flow.is_matching(other_flow): # type: ignore[arg-type]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class ConfigEntryItems(UserDict[str, ConfigEntry]):
|
class ConfigEntryItems(UserDict[str, ConfigEntry]):
|
||||||
"""Container for config items, maps config_entry_id -> entry.
|
"""Container for config items, maps config_entry_id -> entry.
|
||||||
|
@ -2693,6 +2722,10 @@ class ConfigFlow(ConfigEntryBaseFlow):
|
||||||
self.hass.config_entries.async_schedule_reload(entry.entry_id)
|
self.hass.config_entries.async_schedule_reload(entry.entry_id)
|
||||||
return self.async_abort(reason=reason)
|
return self.async_abort(reason=reason)
|
||||||
|
|
||||||
|
def is_matching(self, other_flow: Self) -> bool:
|
||||||
|
"""Return True if other_flow is matching this flow."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
|
||||||
"""Flow to set options for a configuration entry."""
|
"""Flow to set options for a configuration entry."""
|
||||||
|
|
|
@ -237,25 +237,6 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Entry has finished executing its first step asynchronously."""
|
"""Entry has finished executing its first step asynchronously."""
|
||||||
|
|
||||||
@callback
|
|
||||||
def async_has_matching_flow(
|
|
||||||
self, handler: _HandlerT, match_context: dict[str, Any], data: Any
|
|
||||||
) -> bool:
|
|
||||||
"""Check if an existing matching flow is in progress.
|
|
||||||
|
|
||||||
A flow with the same handler, context, and data.
|
|
||||||
|
|
||||||
If match_context is passed, only return flows with a context that is a
|
|
||||||
superset of match_context.
|
|
||||||
"""
|
|
||||||
if not (flows := self._handler_progress_index.get(handler)):
|
|
||||||
return False
|
|
||||||
match_items = match_context.items()
|
|
||||||
for progress in flows:
|
|
||||||
if match_items <= progress.context.items() and progress.init_data == data:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get(self, flow_id: str) -> _FlowResultT:
|
def async_get(self, flow_id: str) -> _FlowResultT:
|
||||||
"""Return a flow in progress as a partial FlowResult."""
|
"""Return a flow in progress as a partial FlowResult."""
|
||||||
|
|
|
@ -78,7 +78,9 @@ def _async_init_flow(
|
||||||
# which can overload devices since zeroconf/ssdp updates can happen
|
# which can overload devices since zeroconf/ssdp updates can happen
|
||||||
# multiple times in the same minute
|
# multiple times in the same minute
|
||||||
if (
|
if (
|
||||||
hass.config_entries.flow.async_has_matching_flow(domain, context, data)
|
hass.config_entries.flow.async_has_matching_discovery_flow(
|
||||||
|
domain, context, data
|
||||||
|
)
|
||||||
or hass.is_stopping
|
or hass.is_stopping
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -91,7 +91,7 @@ async def test_async_create_flow_checks_existing_flows_after_startup(
|
||||||
"""Test existing flows prevent an identical ones from being after startup."""
|
"""Test existing flows prevent an identical ones from being after startup."""
|
||||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.data_entry_flow.FlowManager.async_has_matching_flow",
|
"homeassistant.config_entries.ConfigEntriesFlowManager.async_has_matching_discovery_flow",
|
||||||
return_value=True,
|
return_value=True,
|
||||||
):
|
):
|
||||||
discovery_flow.async_create_flow(
|
discovery_flow.async_create_flow(
|
||||||
|
|
|
@ -7,7 +7,7 @@ from collections.abc import Generator
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, Self
|
||||||
from unittest.mock import ANY, AsyncMock, Mock, patch
|
from unittest.mock import ANY, AsyncMock, Mock, patch
|
||||||
|
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
|
@ -6180,3 +6180,204 @@ async def test_async_loaded_entries(
|
||||||
assert await hass.config_entries.async_unload(entry1.entry_id)
|
assert await hass.config_entries.async_unload(entry1.entry_id)
|
||||||
|
|
||||||
assert hass.config_entries.async_loaded_entries("comp") == []
|
assert hass.config_entries.async_loaded_entries("comp") == []
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_has_matching_discovery_flow(
|
||||||
|
hass: HomeAssistant, manager: config_entries.ConfigEntries
|
||||||
|
) -> None:
|
||||||
|
"""Test we can check for matching discovery flows."""
|
||||||
|
assert (
|
||||||
|
manager.flow.async_has_matching_discovery_flow(
|
||||||
|
"test",
|
||||||
|
{"source": config_entries.SOURCE_HOMEKIT},
|
||||||
|
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_integration(hass, MockModule("test"))
|
||||||
|
mock_platform(hass, "test.config_flow", None)
|
||||||
|
|
||||||
|
class TestFlow(config_entries.ConfigFlow):
|
||||||
|
VERSION = 5
|
||||||
|
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
return self.async_show_progress(
|
||||||
|
step_id="init",
|
||||||
|
progress_action="task_one",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_step_homekit(self, discovery_info=None):
|
||||||
|
return await self.async_step_init(discovery_info)
|
||||||
|
|
||||||
|
with mock_config_flow("test", TestFlow):
|
||||||
|
result = await manager.flow.async_init(
|
||||||
|
"test",
|
||||||
|
context={"source": config_entries.SOURCE_HOMEKIT},
|
||||||
|
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||||
|
)
|
||||||
|
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
|
||||||
|
assert result["progress_action"] == "task_one"
|
||||||
|
assert len(manager.flow.async_progress()) == 1
|
||||||
|
assert len(manager.flow.async_progress_by_handler("test")) == 1
|
||||||
|
assert (
|
||||||
|
len(
|
||||||
|
manager.flow.async_progress_by_handler(
|
||||||
|
"test", match_context={"source": config_entries.SOURCE_HOMEKIT}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
== 1
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
len(
|
||||||
|
manager.flow.async_progress_by_handler(
|
||||||
|
"test", match_context={"source": config_entries.SOURCE_BLUETOOTH}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
== 0
|
||||||
|
)
|
||||||
|
assert manager.flow.async_get(result["flow_id"])["handler"] == "test"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
manager.flow.async_has_matching_discovery_flow(
|
||||||
|
"test",
|
||||||
|
{"source": config_entries.SOURCE_HOMEKIT},
|
||||||
|
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||||
|
)
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
manager.flow.async_has_matching_discovery_flow(
|
||||||
|
"test",
|
||||||
|
{"source": config_entries.SOURCE_SSDP},
|
||||||
|
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
manager.flow.async_has_matching_discovery_flow(
|
||||||
|
"other",
|
||||||
|
{"source": config_entries.SOURCE_HOMEKIT},
|
||||||
|
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_has_matching_flow(
|
||||||
|
hass: HomeAssistant, manager: config_entries.ConfigEntries
|
||||||
|
) -> None:
|
||||||
|
"""Test check for matching flows when there is no active flow."""
|
||||||
|
mock_integration(hass, MockModule("test"))
|
||||||
|
mock_platform(hass, "test.config_flow", None)
|
||||||
|
|
||||||
|
class TestFlow(config_entries.ConfigFlow):
|
||||||
|
VERSION = 5
|
||||||
|
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
return self.async_show_progress(
|
||||||
|
step_id="init",
|
||||||
|
progress_action="task_one",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_step_homekit(self, discovery_info=None):
|
||||||
|
return await self.async_step_init(discovery_info)
|
||||||
|
|
||||||
|
def is_matching(self, other_flow: Self) -> bool:
|
||||||
|
"""Return True if other_flow is matching this flow."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Initiate a flow
|
||||||
|
with mock_config_flow("test", TestFlow):
|
||||||
|
await manager.flow.async_init(
|
||||||
|
"test",
|
||||||
|
context={"source": config_entries.SOURCE_HOMEKIT},
|
||||||
|
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||||
|
)
|
||||||
|
flow = list(manager.flow._handler_progress_index.get("test"))[0]
|
||||||
|
|
||||||
|
assert manager.flow.async_has_matching_flow(flow) is False
|
||||||
|
|
||||||
|
# Initiate another flow
|
||||||
|
with mock_config_flow("test", TestFlow):
|
||||||
|
await manager.flow.async_init(
|
||||||
|
"test",
|
||||||
|
context={"source": config_entries.SOURCE_HOMEKIT},
|
||||||
|
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert manager.flow.async_has_matching_flow(flow) is True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_has_matching_flow_no_flows(
|
||||||
|
hass: HomeAssistant, manager: config_entries.ConfigEntries
|
||||||
|
) -> None:
|
||||||
|
"""Test check for matching flows when there is no active flow."""
|
||||||
|
mock_integration(hass, MockModule("test"))
|
||||||
|
mock_platform(hass, "test.config_flow", None)
|
||||||
|
|
||||||
|
class TestFlow(config_entries.ConfigFlow):
|
||||||
|
VERSION = 5
|
||||||
|
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
return self.async_show_progress(
|
||||||
|
step_id="init",
|
||||||
|
progress_action="task_one",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_step_homekit(self, discovery_info=None):
|
||||||
|
return await self.async_step_init(discovery_info)
|
||||||
|
|
||||||
|
with mock_config_flow("test", TestFlow):
|
||||||
|
result = await manager.flow.async_init(
|
||||||
|
"test",
|
||||||
|
context={"source": config_entries.SOURCE_HOMEKIT},
|
||||||
|
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||||
|
)
|
||||||
|
flow = list(manager.flow._handler_progress_index.get("test"))[0]
|
||||||
|
|
||||||
|
# Abort the flow before checking for matching flows
|
||||||
|
manager.flow.async_abort(result["flow_id"])
|
||||||
|
|
||||||
|
assert manager.flow.async_has_matching_flow(flow) is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_has_matching_flow_not_implemented(
|
||||||
|
hass: HomeAssistant, manager: config_entries.ConfigEntries
|
||||||
|
) -> None:
|
||||||
|
"""Test check for matching flows when there is no active flow."""
|
||||||
|
mock_integration(hass, MockModule("test"))
|
||||||
|
mock_platform(hass, "test.config_flow", None)
|
||||||
|
|
||||||
|
class TestFlow(config_entries.ConfigFlow):
|
||||||
|
VERSION = 5
|
||||||
|
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
return self.async_show_progress(
|
||||||
|
step_id="init",
|
||||||
|
progress_action="task_one",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_step_homekit(self, discovery_info=None):
|
||||||
|
return await self.async_step_init(discovery_info)
|
||||||
|
|
||||||
|
# Initiate a flow
|
||||||
|
with mock_config_flow("test", TestFlow):
|
||||||
|
await manager.flow.async_init(
|
||||||
|
"test",
|
||||||
|
context={"source": config_entries.SOURCE_HOMEKIT},
|
||||||
|
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||||
|
)
|
||||||
|
flow = list(manager.flow._handler_progress_index.get("test"))[0]
|
||||||
|
|
||||||
|
# Initiate another flow
|
||||||
|
with mock_config_flow("test", TestFlow):
|
||||||
|
await manager.flow.async_init(
|
||||||
|
"test",
|
||||||
|
context={"source": config_entries.SOURCE_HOMEKIT},
|
||||||
|
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
# The flow does not implement is_matching
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
manager.flow.async_has_matching_flow(flow)
|
||||||
|
|
|
@ -781,83 +781,6 @@ async def test_async_get_unknown_flow(manager: MockFlowManager) -> None:
|
||||||
await manager.async_get("does_not_exist")
|
await manager.async_get("does_not_exist")
|
||||||
|
|
||||||
|
|
||||||
async def test_async_has_matching_flow(
|
|
||||||
hass: HomeAssistant, manager: MockFlowManager
|
|
||||||
) -> None:
|
|
||||||
"""Test we can check for matching flows."""
|
|
||||||
manager.hass = hass
|
|
||||||
assert (
|
|
||||||
manager.async_has_matching_flow(
|
|
||||||
"test",
|
|
||||||
{"source": config_entries.SOURCE_HOMEKIT},
|
|
||||||
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
|
||||||
)
|
|
||||||
is False
|
|
||||||
)
|
|
||||||
|
|
||||||
@manager.mock_reg_handler("test")
|
|
||||||
class TestFlow(data_entry_flow.FlowHandler):
|
|
||||||
VERSION = 5
|
|
||||||
|
|
||||||
async def async_step_init(self, user_input=None):
|
|
||||||
return self.async_show_progress(
|
|
||||||
step_id="init",
|
|
||||||
progress_action="task_one",
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await manager.async_init(
|
|
||||||
"test",
|
|
||||||
context={"source": config_entries.SOURCE_HOMEKIT},
|
|
||||||
data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
|
||||||
)
|
|
||||||
assert result["type"] == data_entry_flow.FlowResultType.SHOW_PROGRESS
|
|
||||||
assert result["progress_action"] == "task_one"
|
|
||||||
assert len(manager.async_progress()) == 1
|
|
||||||
assert len(manager.async_progress_by_handler("test")) == 1
|
|
||||||
assert (
|
|
||||||
len(
|
|
||||||
manager.async_progress_by_handler(
|
|
||||||
"test", match_context={"source": config_entries.SOURCE_HOMEKIT}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
== 1
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
len(
|
|
||||||
manager.async_progress_by_handler(
|
|
||||||
"test", match_context={"source": config_entries.SOURCE_BLUETOOTH}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
== 0
|
|
||||||
)
|
|
||||||
assert manager.async_get(result["flow_id"])["handler"] == "test"
|
|
||||||
|
|
||||||
assert (
|
|
||||||
manager.async_has_matching_flow(
|
|
||||||
"test",
|
|
||||||
{"source": config_entries.SOURCE_HOMEKIT},
|
|
||||||
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
|
||||||
)
|
|
||||||
is True
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
manager.async_has_matching_flow(
|
|
||||||
"test",
|
|
||||||
{"source": config_entries.SOURCE_SSDP},
|
|
||||||
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
|
||||||
)
|
|
||||||
is False
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
manager.async_has_matching_flow(
|
|
||||||
"other",
|
|
||||||
{"source": config_entries.SOURCE_HOMEKIT},
|
|
||||||
{"properties": {"id": "aa:bb:cc:dd:ee:ff"}},
|
|
||||||
)
|
|
||||||
is False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_move_to_unknown_step_raises_and_removes_from_in_progress(
|
async def test_move_to_unknown_step_raises_and_removes_from_in_progress(
|
||||||
manager: MockFlowManager,
|
manager: MockFlowManager,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
Loading…
Add table
Reference in a new issue