Only wait for import flows to initialize at setup (#86106)
* Only wait for import flows to initialize at setup * Update hassio tests * Update hassio tests * Apply suggestions from code review Co-authored-by: Martin Hjelmare <marhje52@gmail.com> Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
767b43bb0e
commit
f17a829bd8
9 changed files with 69 additions and 27 deletions
|
@ -266,8 +266,12 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
hass.async_create_task(
|
||||
hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": source},
|
||||
data={CONF_API_KEY: entry.data[CONF_API_KEY], **geography},
|
||||
context={"source": SOURCE_IMPORT},
|
||||
data={
|
||||
"import_source": source,
|
||||
CONF_API_KEY: entry.data[CONF_API_KEY],
|
||||
**geography,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -171,6 +171,13 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
"""Define the config flow to handle options."""
|
||||
return SchemaOptionsFlowHandler(config_entry, OPTIONS_FLOW)
|
||||
|
||||
async def async_step_import(self, import_data: dict[str, str]) -> FlowResult:
|
||||
"""Handle import of config entry version 1 data."""
|
||||
import_source = import_data.pop("import_source")
|
||||
if import_source == "geography_by_coords":
|
||||
return await self.async_step_geography_by_coords(import_data)
|
||||
return await self.async_step_geography_by_name(import_data)
|
||||
|
||||
async def async_step_geography_by_coords(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
|
|
|
@ -761,12 +761,12 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
|
|||
super().__init__(hass)
|
||||
self.config_entries = config_entries
|
||||
self._hass_config = hass_config
|
||||
self._initializing: dict[str, dict[str, asyncio.Future]] = {}
|
||||
self._pending_import_flows: dict[str, dict[str, asyncio.Future[None]]] = {}
|
||||
self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
|
||||
|
||||
async def async_wait_init_flow_finish(self, handler: str) -> None:
|
||||
"""Wait till all flows in progress are initialized."""
|
||||
if not (current := self._initializing.get(handler)):
|
||||
async def async_wait_import_flow_initialized(self, handler: str) -> None:
|
||||
"""Wait till all import flows in progress are initialized."""
|
||||
if not (current := self._pending_import_flows.get(handler)):
|
||||
return
|
||||
|
||||
await asyncio.wait(current.values())
|
||||
|
@ -783,12 +783,13 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
|
|||
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
|
||||
) -> FlowResult:
|
||||
"""Start a configuration flow."""
|
||||
if context is None:
|
||||
context = {}
|
||||
if not context or "source" not in context:
|
||||
raise KeyError("Context not set or doesn't have a source set")
|
||||
|
||||
flow_id = uuid_util.random_uuid_hex()
|
||||
init_done: asyncio.Future = asyncio.Future()
|
||||
self._initializing.setdefault(handler, {})[flow_id] = init_done
|
||||
if context["source"] == SOURCE_IMPORT:
|
||||
init_done: asyncio.Future[None] = asyncio.Future()
|
||||
self._pending_import_flows.setdefault(handler, {})[flow_id] = init_done
|
||||
|
||||
task = asyncio.create_task(self._async_init(flow_id, handler, context, data))
|
||||
self._initialize_tasks.setdefault(handler, []).append(task)
|
||||
|
@ -797,7 +798,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
|
|||
flow, result = await task
|
||||
finally:
|
||||
self._initialize_tasks[handler].remove(task)
|
||||
self._initializing[handler].pop(flow_id)
|
||||
self._pending_import_flows.get(handler, {}).pop(flow_id, None)
|
||||
|
||||
if result["type"] != data_entry_flow.FlowResultType.ABORT:
|
||||
await self.async_post_init(flow, result)
|
||||
|
@ -824,8 +825,8 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
|
|||
try:
|
||||
result = await self._async_handle_step(flow, flow.init_step, data)
|
||||
finally:
|
||||
init_done = self._initializing[handler][flow_id]
|
||||
if not init_done.done():
|
||||
init_done = self._pending_import_flows.get(handler, {}).get(flow_id)
|
||||
if init_done and not init_done.done():
|
||||
init_done.set_result(None)
|
||||
return flow, result
|
||||
|
||||
|
@ -845,7 +846,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
|
|||
# We do this to avoid a circular dependency where async_finish_flow sets up a
|
||||
# new entry, which needs the integration to be set up, which is waiting for
|
||||
# init to be done.
|
||||
init_done = self._initializing[flow.handler].get(flow.flow_id)
|
||||
init_done = self._pending_import_flows.get(flow.handler, {}).get(flow.flow_id)
|
||||
if init_done and not init_done.done():
|
||||
init_done.set_result(None)
|
||||
|
||||
|
|
|
@ -286,7 +286,7 @@ async def _async_setup_component(
|
|||
|
||||
# Flush out async_setup calling create_task. Fragile but covered by test.
|
||||
await asyncio.sleep(0)
|
||||
await hass.config_entries.flow.async_wait_init_flow_finish(domain)
|
||||
await hass.config_entries.flow.async_wait_import_flow_initialized(domain)
|
||||
|
||||
# Add to components before the entry.async_setup
|
||||
# call to avoid a deadlock when forwarding platforms
|
||||
|
|
|
@ -25,6 +25,8 @@ from tests.common import MockConfigEntry, load_fixture
|
|||
TEST_API_KEY = "abcde12345"
|
||||
TEST_LATITUDE = 51.528308
|
||||
TEST_LONGITUDE = -0.3817765
|
||||
TEST_LATITUDE2 = 37.514626
|
||||
TEST_LONGITUDE2 = 127.057414
|
||||
|
||||
COORDS_CONFIG = {
|
||||
CONF_API_KEY: TEST_API_KEY,
|
||||
|
@ -32,6 +34,12 @@ COORDS_CONFIG = {
|
|||
CONF_LONGITUDE: TEST_LONGITUDE,
|
||||
}
|
||||
|
||||
COORDS_CONFIG2 = {
|
||||
CONF_API_KEY: TEST_API_KEY,
|
||||
CONF_LATITUDE: TEST_LATITUDE2,
|
||||
CONF_LONGITUDE: TEST_LONGITUDE2,
|
||||
}
|
||||
|
||||
TEST_CITY = "Beijing"
|
||||
TEST_STATE = "Beijing"
|
||||
TEST_COUNTRY = "China"
|
||||
|
|
|
@ -24,12 +24,15 @@ from homeassistant.helpers import device_registry as dr, issue_registry as ir
|
|||
|
||||
from .conftest import (
|
||||
COORDS_CONFIG,
|
||||
COORDS_CONFIG2,
|
||||
NAME_CONFIG,
|
||||
TEST_API_KEY,
|
||||
TEST_CITY,
|
||||
TEST_COUNTRY,
|
||||
TEST_LATITUDE,
|
||||
TEST_LATITUDE2,
|
||||
TEST_LONGITUDE,
|
||||
TEST_LONGITUDE2,
|
||||
TEST_STATE,
|
||||
)
|
||||
|
||||
|
@ -53,6 +56,10 @@ async def test_migration_1_2(hass, mock_pyairvisual):
|
|||
CONF_STATE: TEST_STATE,
|
||||
CONF_COUNTRY: TEST_COUNTRY,
|
||||
},
|
||||
{
|
||||
CONF_LATITUDE: TEST_LATITUDE2,
|
||||
CONF_LONGITUDE: TEST_LONGITUDE2,
|
||||
},
|
||||
],
|
||||
},
|
||||
version=1,
|
||||
|
@ -63,7 +70,7 @@ async def test_migration_1_2(hass, mock_pyairvisual):
|
|||
await hass.async_block_till_done()
|
||||
|
||||
config_entries = hass.config_entries.async_entries(DOMAIN)
|
||||
assert len(config_entries) == 2
|
||||
assert len(config_entries) == 3
|
||||
|
||||
# Ensure that after migration, each configuration has its own config entry:
|
||||
identifier1 = f"{TEST_LATITUDE}, {TEST_LONGITUDE}"
|
||||
|
@ -82,6 +89,14 @@ async def test_migration_1_2(hass, mock_pyairvisual):
|
|||
CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_GEOGRAPHY_NAME,
|
||||
}
|
||||
|
||||
identifier3 = f"{TEST_LATITUDE2}, {TEST_LONGITUDE2}"
|
||||
assert config_entries[2].unique_id == identifier3
|
||||
assert config_entries[2].title == f"Cloud API ({identifier3})"
|
||||
assert config_entries[2].data == {
|
||||
**COORDS_CONFIG2,
|
||||
CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_GEOGRAPHY_COORDS,
|
||||
}
|
||||
|
||||
|
||||
async def test_migration_2_3(hass, mock_pyairvisual):
|
||||
"""Test migrating from version 2 to 3."""
|
||||
|
|
|
@ -202,8 +202,9 @@ async def test_setup_api_ping(hass, aioclient_mock):
|
|||
"""Test setup with API ping."""
|
||||
with patch.dict(os.environ, MOCK_ENVIRON):
|
||||
result = await async_setup_component(hass, "hassio", {})
|
||||
assert result
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result
|
||||
assert aioclient_mock.call_count == 16
|
||||
assert hass.components.hassio.get_core_info()["version_latest"] == "1.0.0"
|
||||
assert hass.components.hassio.is_hassio()
|
||||
|
@ -241,8 +242,9 @@ async def test_setup_api_push_api_data(hass, aioclient_mock):
|
|||
result = await async_setup_component(
|
||||
hass, "hassio", {"http": {"server_port": 9999}, "hassio": {}}
|
||||
)
|
||||
assert result
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result
|
||||
assert aioclient_mock.call_count == 16
|
||||
assert not aioclient_mock.mock_calls[1][2]["ssl"]
|
||||
assert aioclient_mock.mock_calls[1][2]["port"] == 9999
|
||||
|
@ -257,8 +259,9 @@ async def test_setup_api_push_api_data_server_host(hass, aioclient_mock):
|
|||
"hassio",
|
||||
{"http": {"server_port": 9999, "server_host": "127.0.0.1"}, "hassio": {}},
|
||||
)
|
||||
assert result
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result
|
||||
assert aioclient_mock.call_count == 16
|
||||
assert not aioclient_mock.mock_calls[1][2]["ssl"]
|
||||
assert aioclient_mock.mock_calls[1][2]["port"] == 9999
|
||||
|
@ -269,8 +272,9 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock, hass_storag
|
|||
"""Test setup with API push default data."""
|
||||
with patch.dict(os.environ, MOCK_ENVIRON):
|
||||
result = await async_setup_component(hass, "hassio", {"http": {}, "hassio": {}})
|
||||
assert result
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result
|
||||
assert aioclient_mock.call_count == 16
|
||||
assert not aioclient_mock.mock_calls[1][2]["ssl"]
|
||||
assert aioclient_mock.mock_calls[1][2]["port"] == 8123
|
||||
|
@ -336,8 +340,9 @@ async def test_setup_api_existing_hassio_user(hass, aioclient_mock, hass_storage
|
|||
hass_storage[STORAGE_KEY] = {"version": 1, "data": {"hassio_user": user.id}}
|
||||
with patch.dict(os.environ, MOCK_ENVIRON):
|
||||
result = await async_setup_component(hass, "hassio", {"http": {}, "hassio": {}})
|
||||
assert result
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result
|
||||
assert aioclient_mock.call_count == 16
|
||||
assert not aioclient_mock.mock_calls[1][2]["ssl"]
|
||||
assert aioclient_mock.mock_calls[1][2]["port"] == 8123
|
||||
|
@ -350,8 +355,9 @@ async def test_setup_core_push_timezone(hass, aioclient_mock):
|
|||
|
||||
with patch.dict(os.environ, MOCK_ENVIRON):
|
||||
result = await async_setup_component(hass, "hassio", {"hassio": {}})
|
||||
assert result
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result
|
||||
assert aioclient_mock.call_count == 16
|
||||
assert aioclient_mock.mock_calls[2][2]["timezone"] == "testzone"
|
||||
|
||||
|
@ -367,8 +373,9 @@ async def test_setup_hassio_no_additional_data(hass, aioclient_mock):
|
|||
os.environ, {"SUPERVISOR_TOKEN": "123456"}
|
||||
):
|
||||
result = await async_setup_component(hass, "hassio", {"hassio": {}})
|
||||
assert result
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result
|
||||
assert aioclient_mock.call_count == 16
|
||||
assert aioclient_mock.mock_calls[-1][3]["Authorization"] == "Bearer 123456"
|
||||
|
||||
|
@ -768,9 +775,9 @@ async def test_setup_hardware_integration(hass, aioclient_mock, integration):
|
|||
return_value=True,
|
||||
) as mock_setup_entry:
|
||||
result = await async_setup_component(hass, "hassio", {"hassio": {}})
|
||||
assert result
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result
|
||||
assert aioclient_mock.call_count == 16
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
|
|
@ -542,8 +542,8 @@ async def test_setting_up_core_update_when_addon_fails(hass, caplog):
|
|||
"hassio",
|
||||
{"http": {"server_port": 9999, "server_host": "127.0.0.1"}, "hassio": {}},
|
||||
)
|
||||
assert result
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
assert result
|
||||
|
||||
# Verify that the core update entity does exist
|
||||
state = hass.states.get("update.home_assistant_core_update")
|
||||
|
|
|
@ -1428,7 +1428,7 @@ async def test_init_custom_integration(hass):
|
|||
"homeassistant.loader.async_get_integration",
|
||||
return_value=integration,
|
||||
):
|
||||
await hass.config_entries.flow.async_init("bla")
|
||||
await hass.config_entries.flow.async_init("bla", context={"source": "user"})
|
||||
|
||||
|
||||
async def test_support_entry_unload(hass):
|
||||
|
|
Loading…
Add table
Reference in a new issue