Allow async_setup changes to config entry data be taken into a… (#34166)
* Allow async_setup changes to config entry data be taken into account * Fix tests * Limit scope try…finally * Update tests/test_config_entries.py Co-Authored-By: Martin Hjelmare <marhje52@gmail.com> * Fix import Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
0db1fcca0f
commit
0b90ebf91e
10 changed files with 176 additions and 33 deletions
|
@ -22,6 +22,7 @@ from homeassistant.helpers.event import async_track_time_interval
|
|||
from .const import (
|
||||
CONF_CITY,
|
||||
CONF_COUNTRY,
|
||||
CONF_GEOGRAPHIES,
|
||||
DATA_CLIENT,
|
||||
DEFAULT_SCAN_INTERVAL,
|
||||
DOMAIN,
|
||||
|
@ -34,8 +35,6 @@ DATA_LISTENER = "listener"
|
|||
|
||||
DEFAULT_OPTIONS = {CONF_SHOW_ON_MAP: True}
|
||||
|
||||
CONF_GEOGRAPHIES = "geographies"
|
||||
|
||||
GEOGRAPHY_COORDINATES_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_LATITUDE): cv.latitude,
|
||||
|
@ -158,8 +157,7 @@ async def async_migrate_entry(hass, config_entry):
|
|||
|
||||
# Update the config entry to only include the first geography (there is always
|
||||
# guaranteed to be at least one):
|
||||
data = {**config_entry.data}
|
||||
geographies = data.pop(CONF_GEOGRAPHIES)
|
||||
geographies = list(config_entry.data[CONF_GEOGRAPHIES])
|
||||
first_geography = geographies.pop(0)
|
||||
first_id = async_get_geography_id(first_geography)
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ from homeassistant.core import callback
|
|||
from homeassistant.helpers import aiohttp_client, config_validation as cv
|
||||
|
||||
from . import async_get_geography_id
|
||||
from .const import DOMAIN # pylint: disable=unused-import
|
||||
from .const import CONF_GEOGRAPHIES, DOMAIN # pylint: disable=unused-import
|
||||
|
||||
|
||||
class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
|
@ -69,6 +69,18 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
|
||||
geo_id = async_get_geography_id(user_input)
|
||||
await self._async_set_unique_id(geo_id)
|
||||
self._abort_if_unique_id_configured()
|
||||
|
||||
# Find older config entries without unique ID
|
||||
for entry in self._async_current_entries():
|
||||
if entry.version != 1:
|
||||
continue
|
||||
|
||||
if any(
|
||||
geo_id == async_get_geography_id(geography)
|
||||
for geography in entry.data[CONF_GEOGRAPHIES]
|
||||
):
|
||||
return self.async_abort(reason="already_configured")
|
||||
|
||||
websession = aiohttp_client.async_get_clientsession(self.hass)
|
||||
client = Client(websession, api_key=user_input[CONF_API_KEY])
|
||||
|
@ -90,9 +102,10 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
)
|
||||
|
||||
checked_keys.add(user_input[CONF_API_KEY])
|
||||
return self.async_create_entry(
|
||||
title=f"Cloud API ({geo_id})", data=user_input
|
||||
)
|
||||
|
||||
return self.async_create_entry(
|
||||
title=f"Cloud API ({geo_id})", data=user_input
|
||||
)
|
||||
|
||||
|
||||
class AirVisualOptionsFlowHandler(config_entries.OptionsFlow):
|
||||
|
|
|
@ -5,6 +5,7 @@ DOMAIN = "airvisual"
|
|||
|
||||
CONF_CITY = "city"
|
||||
CONF_COUNTRY = "country"
|
||||
CONF_GEOGRAPHIES = "geographies"
|
||||
|
||||
DATA_CLIENT = "client"
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Classes to help gather user submissions."""
|
||||
import abc
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
import uuid
|
||||
|
@ -53,8 +54,18 @@ class FlowManager(abc.ABC):
|
|||
def __init__(self, hass: HomeAssistant,) -> None:
|
||||
"""Initialize the flow manager."""
|
||||
self.hass = hass
|
||||
self._initializing: Dict[str, List[asyncio.Future]] = {}
|
||||
self._progress: Dict[str, Any] = {}
|
||||
|
||||
async def async_wait_init_flow_finish(self, handler: str) -> None:
|
||||
"""Wait till all flows in progress are initialized."""
|
||||
current = self._initializing.get(handler)
|
||||
|
||||
if not current:
|
||||
return
|
||||
|
||||
await asyncio.wait(current)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_create_flow(
|
||||
self,
|
||||
|
@ -94,8 +105,13 @@ class FlowManager(abc.ABC):
|
|||
"""Start a configuration flow."""
|
||||
if context is None:
|
||||
context = {}
|
||||
|
||||
init_done: asyncio.Future = asyncio.Future()
|
||||
self._initializing.setdefault(handler, []).append(init_done)
|
||||
|
||||
flow = await self.async_create_flow(handler, context=context, data=data)
|
||||
if not flow:
|
||||
self._initializing[handler].remove(init_done)
|
||||
raise UnknownFlow("Flow was not created")
|
||||
flow.hass = self.hass
|
||||
flow.handler = handler
|
||||
|
@ -103,7 +119,12 @@ class FlowManager(abc.ABC):
|
|||
flow.context = context
|
||||
self._progress[flow.flow_id] = flow
|
||||
|
||||
result = await self._async_handle_step(flow, flow.init_step, data)
|
||||
try:
|
||||
result = await self._async_handle_step(
|
||||
flow, flow.init_step, data, init_done
|
||||
)
|
||||
finally:
|
||||
self._initializing[handler].remove(init_done)
|
||||
|
||||
if result["type"] != RESULT_TYPE_ABORT:
|
||||
await self.async_post_init(flow, result)
|
||||
|
@ -154,13 +175,19 @@ class FlowManager(abc.ABC):
|
|||
raise UnknownFlow
|
||||
|
||||
async def _async_handle_step(
|
||||
self, flow: Any, step_id: str, user_input: Optional[Dict]
|
||||
self,
|
||||
flow: Any,
|
||||
step_id: str,
|
||||
user_input: Optional[Dict],
|
||||
step_done: Optional[asyncio.Future] = None,
|
||||
) -> Dict:
|
||||
"""Handle a step of a flow."""
|
||||
method = f"async_step_{step_id}"
|
||||
|
||||
if not hasattr(flow, method):
|
||||
self._progress.pop(flow.flow_id)
|
||||
if step_done:
|
||||
step_done.set_result(None)
|
||||
raise UnknownStep(
|
||||
f"Handler {flow.__class__.__name__} doesn't support step {step_id}"
|
||||
)
|
||||
|
@ -172,6 +199,13 @@ class FlowManager(abc.ABC):
|
|||
flow.flow_id, flow.handler, err.reason, err.description_placeholders
|
||||
)
|
||||
|
||||
# Mark the step as done.
|
||||
# We do this before calling async_finish_flow because config entries will hit a
|
||||
# circular dependency where async_finish_flow sets up new entry, which needs the
|
||||
# integration to be set up, which is waiting for init to be done.
|
||||
if step_done:
|
||||
step_done.set_result(None)
|
||||
|
||||
if result["type"] not in (
|
||||
RESULT_TYPE_FORM,
|
||||
RESULT_TYPE_EXTERNAL_STEP,
|
||||
|
|
|
@ -197,9 +197,12 @@ async def _async_setup_component(
|
|||
)
|
||||
return False
|
||||
|
||||
if hass.config_entries:
|
||||
for entry in hass.config_entries.async_entries(domain):
|
||||
await entry.async_setup(hass, integration=integration)
|
||||
# Flush out async_setup calling create_task. Fragile but covered by test.
|
||||
await asyncio.sleep(0)
|
||||
await hass.config_entries.flow.async_wait_init_flow_finish(domain)
|
||||
|
||||
for entry in hass.config_entries.async_entries(domain):
|
||||
await entry.async_setup(hass, integration=integration)
|
||||
|
||||
hass.config.components.add(domain)
|
||||
|
||||
|
|
|
@ -9,19 +9,6 @@ from .test_device import MAC, setup_axis_integration
|
|||
from tests.common import MockConfigEntry, mock_coro
|
||||
|
||||
|
||||
async def test_setup_device_already_configured(hass):
|
||||
"""Test already configured device does not configure a second."""
|
||||
with patch.object(hass, "config_entries") as mock_config_entries:
|
||||
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
axis.DOMAIN,
|
||||
{axis.DOMAIN: {"device_name": {axis.CONF_HOST: "1.2.3.4"}}},
|
||||
)
|
||||
|
||||
assert not mock_config_entries.flow.mock_calls
|
||||
|
||||
|
||||
async def test_setup_no_config(hass):
|
||||
"""Test setup without configuration."""
|
||||
assert await async_setup_component(hass, axis.DOMAIN, {})
|
||||
|
|
|
@ -230,7 +230,7 @@ async def test_setup_with_no_config(hass):
|
|||
assert konnected.YAML_CONFIGS not in hass.data[konnected.DOMAIN]
|
||||
|
||||
|
||||
async def test_setup_defined_hosts_known_auth(hass):
|
||||
async def test_setup_defined_hosts_known_auth(hass, mock_panel):
|
||||
"""Test we don't initiate a config entry if configured panel is known."""
|
||||
MockConfigEntry(
|
||||
domain="konnected",
|
||||
|
|
|
@ -15,7 +15,9 @@ async def test_config_with_sensor_passed_to_config_entry(hass):
|
|||
CONF_SCAN_INTERVAL: 600,
|
||||
}
|
||||
|
||||
with patch.object(hass, "config_entries") as mock_config_entries, patch.object(
|
||||
with patch.object(
|
||||
hass.config_entries.flow, "async_init"
|
||||
) as mock_config_entries, patch.object(
|
||||
luftdaten, "configured_sensors", return_value=[]
|
||||
):
|
||||
assert await async_setup_component(hass, DOMAIN, conf) is True
|
||||
|
@ -27,7 +29,9 @@ async def test_config_already_registered_not_passed_to_config_entry(hass):
|
|||
"""Test that an already registered sensor does not initiate an import."""
|
||||
conf = {CONF_SENSOR_ID: "12345abcde"}
|
||||
|
||||
with patch.object(hass, "config_entries") as mock_config_entries, patch.object(
|
||||
with patch.object(
|
||||
hass.config_entries.flow, "async_init"
|
||||
) as mock_config_entries, patch.object(
|
||||
luftdaten, "configured_sensors", return_value=["12345abcde"]
|
||||
):
|
||||
assert await async_setup_component(hass, DOMAIN, conf) is True
|
||||
|
|
|
@ -55,7 +55,9 @@ def get_homekit_info_mock(model):
|
|||
|
||||
async def test_setup(hass, mock_zeroconf):
|
||||
"""Test configured options for a device are loaded via config entry."""
|
||||
with patch.object(hass.config_entries, "flow") as mock_config_flow, patch.object(
|
||||
with patch.object(
|
||||
hass.config_entries.flow, "async_init"
|
||||
) as mock_config_flow, patch.object(
|
||||
zeroconf, "ServiceBrowser", side_effect=service_update_mock
|
||||
) as mock_service_browser:
|
||||
mock_zeroconf.get_service_info.side_effect = get_service_info_mock
|
||||
|
@ -72,7 +74,9 @@ async def test_homekit_match_partial_space(hass, mock_zeroconf):
|
|||
"""Test configured options for a device are loaded via config entry."""
|
||||
with patch.dict(
|
||||
zc_gen.ZEROCONF, {zeroconf.HOMEKIT_TYPE: ["homekit_controller"]}, clear=True
|
||||
), patch.object(hass.config_entries, "flow") as mock_config_flow, patch.object(
|
||||
), patch.object(
|
||||
hass.config_entries.flow, "async_init"
|
||||
) as mock_config_flow, patch.object(
|
||||
zeroconf, "ServiceBrowser", side_effect=service_update_mock
|
||||
) as mock_service_browser:
|
||||
mock_zeroconf.get_service_info.side_effect = get_homekit_info_mock("LIFX bulb")
|
||||
|
@ -87,7 +91,9 @@ async def test_homekit_match_partial_dash(hass, mock_zeroconf):
|
|||
"""Test configured options for a device are loaded via config entry."""
|
||||
with patch.dict(
|
||||
zc_gen.ZEROCONF, {zeroconf.HOMEKIT_TYPE: ["homekit_controller"]}, clear=True
|
||||
), patch.object(hass.config_entries, "flow") as mock_config_flow, patch.object(
|
||||
), patch.object(
|
||||
hass.config_entries.flow, "async_init"
|
||||
) as mock_config_flow, patch.object(
|
||||
zeroconf, "ServiceBrowser", side_effect=service_update_mock
|
||||
) as mock_service_browser:
|
||||
mock_zeroconf.get_service_info.side_effect = get_homekit_info_mock(
|
||||
|
@ -104,7 +110,9 @@ async def test_homekit_match_full(hass, mock_zeroconf):
|
|||
"""Test configured options for a device are loaded via config entry."""
|
||||
with patch.dict(
|
||||
zc_gen.ZEROCONF, {zeroconf.HOMEKIT_TYPE: ["homekit_controller"]}, clear=True
|
||||
), patch.object(hass.config_entries, "flow") as mock_config_flow, patch.object(
|
||||
), patch.object(
|
||||
hass.config_entries.flow, "async_init"
|
||||
) as mock_config_flow, patch.object(
|
||||
zeroconf, "ServiceBrowser", side_effect=service_update_mock
|
||||
) as mock_service_browser:
|
||||
mock_zeroconf.get_service_info.side_effect = get_homekit_info_mock("BSB002")
|
||||
|
|
|
@ -3,6 +3,7 @@ import asyncio
|
|||
from datetime import timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from asynctest import CoroutineMock
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries, data_entry_flow, loader
|
||||
|
@ -1463,3 +1464,97 @@ async def test_partial_flows_hidden(hass, manager):
|
|||
await hass.async_block_till_done()
|
||||
state = hass.states.get("persistent_notification.config_entry_discovery")
|
||||
assert state is not None
|
||||
|
||||
|
||||
async def test_async_setup_init_entry(hass):
|
||||
"""Test a config entry being initialized during integration setup."""
|
||||
|
||||
async def mock_async_setup(hass, config):
|
||||
"""Mock setup."""
|
||||
hass.async_create_task(
|
||||
hass.config_entries.flow.async_init(
|
||||
"comp", context={"source": config_entries.SOURCE_IMPORT}, data={},
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
async_setup_entry = CoroutineMock(return_value=True)
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
"comp", async_setup=mock_async_setup, async_setup_entry=async_setup_entry
|
||||
),
|
||||
)
|
||||
mock_entity_platform(hass, "config_flow.comp", None)
|
||||
await async_setup_component(hass, "persistent_notification", {})
|
||||
|
||||
class TestFlow(config_entries.ConfigFlow):
|
||||
"""Test flow."""
|
||||
|
||||
VERSION = 1
|
||||
|
||||
async def async_step_import(self, user_input):
|
||||
"""Test import step creating entry."""
|
||||
return self.async_create_entry(title="title", data={})
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
||||
assert await async_setup_component(hass, "comp", {})
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(async_setup_entry.mock_calls) == 1
|
||||
|
||||
entries = hass.config_entries.async_entries("comp")
|
||||
assert len(entries) == 1
|
||||
assert entries[0].state == config_entries.ENTRY_STATE_LOADED
|
||||
|
||||
|
||||
async def test_async_setup_update_entry(hass):
|
||||
"""Test a config entry being updated during integration setup."""
|
||||
entry = MockConfigEntry(domain="comp", data={"value": "initial"})
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
async def mock_async_setup(hass, config):
|
||||
"""Mock setup."""
|
||||
hass.async_create_task(
|
||||
hass.config_entries.flow.async_init(
|
||||
"comp", context={"source": config_entries.SOURCE_IMPORT}, data={},
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
async def mock_async_setup_entry(hass, entry):
|
||||
"""Mock setting up an entry."""
|
||||
assert entry.data["value"] == "updated"
|
||||
return True
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
"comp",
|
||||
async_setup=mock_async_setup,
|
||||
async_setup_entry=mock_async_setup_entry,
|
||||
),
|
||||
)
|
||||
mock_entity_platform(hass, "config_flow.comp", None)
|
||||
await async_setup_component(hass, "persistent_notification", {})
|
||||
|
||||
class TestFlow(config_entries.ConfigFlow):
|
||||
"""Test flow."""
|
||||
|
||||
VERSION = 1
|
||||
|
||||
async def async_step_import(self, user_input):
|
||||
"""Test import step updating existing entry."""
|
||||
self.hass.config_entries.async_update_entry(
|
||||
entry, data={"value": "updated"}
|
||||
)
|
||||
return self.async_abort(reason="yo")
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
||||
assert await async_setup_component(hass, "comp", {})
|
||||
|
||||
entries = hass.config_entries.async_entries("comp")
|
||||
assert len(entries) == 1
|
||||
assert entries[0].state == config_entries.ENTRY_STATE_LOADED
|
||||
assert entries[0].data == {"value": "updated"}
|
||||
|
|
Loading…
Add table
Reference in a new issue