Only wait for import flows to initialize at setup (#86106)

* Only wait for import flows to initialize at setup

* Update hassio tests

* Update hassio tests

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Erik Montnemery 2023-01-18 10:44:18 +01:00 committed by GitHub
parent 767b43bb0e
commit f17a829bd8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 69 additions and 27 deletions

View file

@ -266,8 +266,12 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.async_create_task(
hass.config_entries.flow.async_init(
DOMAIN,
context={"source": source},
data={CONF_API_KEY: entry.data[CONF_API_KEY], **geography},
context={"source": SOURCE_IMPORT},
data={
"import_source": source,
CONF_API_KEY: entry.data[CONF_API_KEY],
**geography,
},
)
)

View file

@ -171,6 +171,13 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Define the config flow to handle options."""
return SchemaOptionsFlowHandler(config_entry, OPTIONS_FLOW)
async def async_step_import(self, import_data: dict[str, str]) -> FlowResult:
"""Handle import of config entry version 1 data."""
import_source = import_data.pop("import_source")
if import_source == "geography_by_coords":
return await self.async_step_geography_by_coords(import_data)
return await self.async_step_geography_by_name(import_data)
async def async_step_geography_by_coords(
self, user_input: dict[str, str] | None = None
) -> FlowResult:

View file

@ -761,12 +761,12 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
super().__init__(hass)
self.config_entries = config_entries
self._hass_config = hass_config
self._initializing: dict[str, dict[str, asyncio.Future]] = {}
self._pending_import_flows: dict[str, dict[str, asyncio.Future[None]]] = {}
self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
async def async_wait_init_flow_finish(self, handler: str) -> None:
"""Wait till all flows in progress are initialized."""
if not (current := self._initializing.get(handler)):
async def async_wait_import_flow_initialized(self, handler: str) -> None:
"""Wait till all import flows in progress are initialized."""
if not (current := self._pending_import_flows.get(handler)):
return
await asyncio.wait(current.values())
@ -783,12 +783,13 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
) -> FlowResult:
"""Start a configuration flow."""
if context is None:
context = {}
if not context or "source" not in context:
raise KeyError("Context not set or doesn't have a source set")
flow_id = uuid_util.random_uuid_hex()
init_done: asyncio.Future = asyncio.Future()
self._initializing.setdefault(handler, {})[flow_id] = init_done
if context["source"] == SOURCE_IMPORT:
init_done: asyncio.Future[None] = asyncio.Future()
self._pending_import_flows.setdefault(handler, {})[flow_id] = init_done
task = asyncio.create_task(self._async_init(flow_id, handler, context, data))
self._initialize_tasks.setdefault(handler, []).append(task)
@ -797,7 +798,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
flow, result = await task
finally:
self._initialize_tasks[handler].remove(task)
self._initializing[handler].pop(flow_id)
self._pending_import_flows.get(handler, {}).pop(flow_id, None)
if result["type"] != data_entry_flow.FlowResultType.ABORT:
await self.async_post_init(flow, result)
@ -824,8 +825,8 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
try:
result = await self._async_handle_step(flow, flow.init_step, data)
finally:
init_done = self._initializing[handler][flow_id]
if not init_done.done():
init_done = self._pending_import_flows.get(handler, {}).get(flow_id)
if init_done and not init_done.done():
init_done.set_result(None)
return flow, result
@ -845,7 +846,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
# We do this to avoid a circular dependency where async_finish_flow sets up a
# new entry, which needs the integration to be set up, which is waiting for
# init to be done.
init_done = self._initializing[flow.handler].get(flow.flow_id)
init_done = self._pending_import_flows.get(flow.handler, {}).get(flow.flow_id)
if init_done and not init_done.done():
init_done.set_result(None)

View file

@ -286,7 +286,7 @@ async def _async_setup_component(
# Flush out async_setup calling create_task. Fragile but covered by test.
await asyncio.sleep(0)
await hass.config_entries.flow.async_wait_init_flow_finish(domain)
await hass.config_entries.flow.async_wait_import_flow_initialized(domain)
# Add to components before the entry.async_setup
# call to avoid a deadlock when forwarding platforms

View file

@ -25,6 +25,8 @@ from tests.common import MockConfigEntry, load_fixture
TEST_API_KEY = "abcde12345"
TEST_LATITUDE = 51.528308
TEST_LONGITUDE = -0.3817765
TEST_LATITUDE2 = 37.514626
TEST_LONGITUDE2 = 127.057414
COORDS_CONFIG = {
CONF_API_KEY: TEST_API_KEY,
@ -32,6 +34,12 @@ COORDS_CONFIG = {
CONF_LONGITUDE: TEST_LONGITUDE,
}
COORDS_CONFIG2 = {
CONF_API_KEY: TEST_API_KEY,
CONF_LATITUDE: TEST_LATITUDE2,
CONF_LONGITUDE: TEST_LONGITUDE2,
}
TEST_CITY = "Beijing"
TEST_STATE = "Beijing"
TEST_COUNTRY = "China"

View file

@ -24,12 +24,15 @@ from homeassistant.helpers import device_registry as dr, issue_registry as ir
from .conftest import (
COORDS_CONFIG,
COORDS_CONFIG2,
NAME_CONFIG,
TEST_API_KEY,
TEST_CITY,
TEST_COUNTRY,
TEST_LATITUDE,
TEST_LATITUDE2,
TEST_LONGITUDE,
TEST_LONGITUDE2,
TEST_STATE,
)
@ -53,6 +56,10 @@ async def test_migration_1_2(hass, mock_pyairvisual):
CONF_STATE: TEST_STATE,
CONF_COUNTRY: TEST_COUNTRY,
},
{
CONF_LATITUDE: TEST_LATITUDE2,
CONF_LONGITUDE: TEST_LONGITUDE2,
},
],
},
version=1,
@ -63,7 +70,7 @@ async def test_migration_1_2(hass, mock_pyairvisual):
await hass.async_block_till_done()
config_entries = hass.config_entries.async_entries(DOMAIN)
assert len(config_entries) == 2
assert len(config_entries) == 3
# Ensure that after migration, each configuration has its own config entry:
identifier1 = f"{TEST_LATITUDE}, {TEST_LONGITUDE}"
@ -82,6 +89,14 @@ async def test_migration_1_2(hass, mock_pyairvisual):
CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_GEOGRAPHY_NAME,
}
identifier3 = f"{TEST_LATITUDE2}, {TEST_LONGITUDE2}"
assert config_entries[2].unique_id == identifier3
assert config_entries[2].title == f"Cloud API ({identifier3})"
assert config_entries[2].data == {
**COORDS_CONFIG2,
CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_GEOGRAPHY_COORDS,
}
async def test_migration_2_3(hass, mock_pyairvisual):
"""Test migrating from version 2 to 3."""

View file

@ -202,8 +202,9 @@ async def test_setup_api_ping(hass, aioclient_mock):
"""Test setup with API ping."""
with patch.dict(os.environ, MOCK_ENVIRON):
result = await async_setup_component(hass, "hassio", {})
assert result
await hass.async_block_till_done()
assert result
assert aioclient_mock.call_count == 16
assert hass.components.hassio.get_core_info()["version_latest"] == "1.0.0"
assert hass.components.hassio.is_hassio()
@ -241,8 +242,9 @@ async def test_setup_api_push_api_data(hass, aioclient_mock):
result = await async_setup_component(
hass, "hassio", {"http": {"server_port": 9999}, "hassio": {}}
)
assert result
await hass.async_block_till_done()
assert result
assert aioclient_mock.call_count == 16
assert not aioclient_mock.mock_calls[1][2]["ssl"]
assert aioclient_mock.mock_calls[1][2]["port"] == 9999
@ -257,8 +259,9 @@ async def test_setup_api_push_api_data_server_host(hass, aioclient_mock):
"hassio",
{"http": {"server_port": 9999, "server_host": "127.0.0.1"}, "hassio": {}},
)
assert result
await hass.async_block_till_done()
assert result
assert aioclient_mock.call_count == 16
assert not aioclient_mock.mock_calls[1][2]["ssl"]
assert aioclient_mock.mock_calls[1][2]["port"] == 9999
@ -269,8 +272,9 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock, hass_storag
"""Test setup with API push default data."""
with patch.dict(os.environ, MOCK_ENVIRON):
result = await async_setup_component(hass, "hassio", {"http": {}, "hassio": {}})
assert result
await hass.async_block_till_done()
assert result
assert aioclient_mock.call_count == 16
assert not aioclient_mock.mock_calls[1][2]["ssl"]
assert aioclient_mock.mock_calls[1][2]["port"] == 8123
@ -336,8 +340,9 @@ async def test_setup_api_existing_hassio_user(hass, aioclient_mock, hass_storage
hass_storage[STORAGE_KEY] = {"version": 1, "data": {"hassio_user": user.id}}
with patch.dict(os.environ, MOCK_ENVIRON):
result = await async_setup_component(hass, "hassio", {"http": {}, "hassio": {}})
assert result
await hass.async_block_till_done()
assert result
assert aioclient_mock.call_count == 16
assert not aioclient_mock.mock_calls[1][2]["ssl"]
assert aioclient_mock.mock_calls[1][2]["port"] == 8123
@ -350,8 +355,9 @@ async def test_setup_core_push_timezone(hass, aioclient_mock):
with patch.dict(os.environ, MOCK_ENVIRON):
result = await async_setup_component(hass, "hassio", {"hassio": {}})
assert result
await hass.async_block_till_done()
assert result
assert aioclient_mock.call_count == 16
assert aioclient_mock.mock_calls[2][2]["timezone"] == "testzone"
@ -367,8 +373,9 @@ async def test_setup_hassio_no_additional_data(hass, aioclient_mock):
os.environ, {"SUPERVISOR_TOKEN": "123456"}
):
result = await async_setup_component(hass, "hassio", {"hassio": {}})
assert result
await hass.async_block_till_done()
assert result
assert aioclient_mock.call_count == 16
assert aioclient_mock.mock_calls[-1][3]["Authorization"] == "Bearer 123456"
@ -768,9 +775,9 @@ async def test_setup_hardware_integration(hass, aioclient_mock, integration):
return_value=True,
) as mock_setup_entry:
result = await async_setup_component(hass, "hassio", {"hassio": {}})
assert result
await hass.async_block_till_done()
assert result
assert aioclient_mock.call_count == 16
assert len(mock_setup_entry.mock_calls) == 1

View file

@ -542,8 +542,8 @@ async def test_setting_up_core_update_when_addon_fails(hass, caplog):
"hassio",
{"http": {"server_port": 9999, "server_host": "127.0.0.1"}, "hassio": {}},
)
assert result
await hass.async_block_till_done()
await hass.async_block_till_done()
assert result
# Verify that the core update entity does exist
state = hass.states.get("update.home_assistant_core_update")

View file

@ -1428,7 +1428,7 @@ async def test_init_custom_integration(hass):
"homeassistant.loader.async_get_integration",
return_value=integration,
):
await hass.config_entries.flow.async_init("bla")
await hass.config_entries.flow.async_init("bla", context={"source": "user"})
async def test_support_entry_unload(hass):