Allow async_setup changes to config entry data be taken into a… (#34166)

* Allow async_setup changes to config entry data be taken into account

* Fix tests

* Limit scope try…finally

* Update tests/test_config_entries.py

Co-Authored-By: Martin Hjelmare <marhje52@gmail.com>

* Fix import

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Paulus Schoutsen 2020-04-14 18:46:41 -07:00 committed by GitHub
parent 0db1fcca0f
commit 0b90ebf91e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 176 additions and 33 deletions

View file

@ -22,6 +22,7 @@ from homeassistant.helpers.event import async_track_time_interval
from .const import (
CONF_CITY,
CONF_COUNTRY,
CONF_GEOGRAPHIES,
DATA_CLIENT,
DEFAULT_SCAN_INTERVAL,
DOMAIN,
@ -34,8 +35,6 @@ DATA_LISTENER = "listener"
DEFAULT_OPTIONS = {CONF_SHOW_ON_MAP: True}
CONF_GEOGRAPHIES = "geographies"
GEOGRAPHY_COORDINATES_SCHEMA = vol.Schema(
{
vol.Required(CONF_LATITUDE): cv.latitude,
@ -158,8 +157,7 @@ async def async_migrate_entry(hass, config_entry):
# Update the config entry to only include the first geography (there is always
# guaranteed to be at least one):
data = {**config_entry.data}
geographies = data.pop(CONF_GEOGRAPHIES)
geographies = list(config_entry.data[CONF_GEOGRAPHIES])
first_geography = geographies.pop(0)
first_id = async_get_geography_id(first_geography)

View file

@ -16,7 +16,7 @@ from homeassistant.core import callback
from homeassistant.helpers import aiohttp_client, config_validation as cv
from . import async_get_geography_id
from .const import DOMAIN # pylint: disable=unused-import
from .const import CONF_GEOGRAPHIES, DOMAIN # pylint: disable=unused-import
class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
@ -69,6 +69,18 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
geo_id = async_get_geography_id(user_input)
await self._async_set_unique_id(geo_id)
self._abort_if_unique_id_configured()
# Find older config entries without unique ID
for entry in self._async_current_entries():
if entry.version != 1:
continue
if any(
geo_id == async_get_geography_id(geography)
for geography in entry.data[CONF_GEOGRAPHIES]
):
return self.async_abort(reason="already_configured")
websession = aiohttp_client.async_get_clientsession(self.hass)
client = Client(websession, api_key=user_input[CONF_API_KEY])
@ -90,9 +102,10 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
)
checked_keys.add(user_input[CONF_API_KEY])
return self.async_create_entry(
title=f"Cloud API ({geo_id})", data=user_input
)
return self.async_create_entry(
title=f"Cloud API ({geo_id})", data=user_input
)
class AirVisualOptionsFlowHandler(config_entries.OptionsFlow):

View file

@ -5,6 +5,7 @@ DOMAIN = "airvisual"
CONF_CITY = "city"
CONF_COUNTRY = "country"
CONF_GEOGRAPHIES = "geographies"
DATA_CLIENT = "client"

View file

@ -1,5 +1,6 @@
"""Classes to help gather user submissions."""
import abc
import asyncio
import logging
from typing import Any, Dict, List, Optional, cast
import uuid
@ -53,8 +54,18 @@ class FlowManager(abc.ABC):
def __init__(self, hass: HomeAssistant,) -> None:
"""Initialize the flow manager."""
self.hass = hass
self._initializing: Dict[str, List[asyncio.Future]] = {}
self._progress: Dict[str, Any] = {}
async def async_wait_init_flow_finish(self, handler: str) -> None:
"""Wait till all flows in progress are initialized."""
current = self._initializing.get(handler)
if not current:
return
await asyncio.wait(current)
@abc.abstractmethod
async def async_create_flow(
self,
@ -94,8 +105,13 @@ class FlowManager(abc.ABC):
"""Start a configuration flow."""
if context is None:
context = {}
init_done: asyncio.Future = asyncio.Future()
self._initializing.setdefault(handler, []).append(init_done)
flow = await self.async_create_flow(handler, context=context, data=data)
if not flow:
self._initializing[handler].remove(init_done)
raise UnknownFlow("Flow was not created")
flow.hass = self.hass
flow.handler = handler
@ -103,7 +119,12 @@ class FlowManager(abc.ABC):
flow.context = context
self._progress[flow.flow_id] = flow
result = await self._async_handle_step(flow, flow.init_step, data)
try:
result = await self._async_handle_step(
flow, flow.init_step, data, init_done
)
finally:
self._initializing[handler].remove(init_done)
if result["type"] != RESULT_TYPE_ABORT:
await self.async_post_init(flow, result)
@ -154,13 +175,19 @@ class FlowManager(abc.ABC):
raise UnknownFlow
async def _async_handle_step(
self, flow: Any, step_id: str, user_input: Optional[Dict]
self,
flow: Any,
step_id: str,
user_input: Optional[Dict],
step_done: Optional[asyncio.Future] = None,
) -> Dict:
"""Handle a step of a flow."""
method = f"async_step_{step_id}"
if not hasattr(flow, method):
self._progress.pop(flow.flow_id)
if step_done:
step_done.set_result(None)
raise UnknownStep(
f"Handler {flow.__class__.__name__} doesn't support step {step_id}"
)
@ -172,6 +199,13 @@ class FlowManager(abc.ABC):
flow.flow_id, flow.handler, err.reason, err.description_placeholders
)
# Mark the step as done.
# We do this before calling async_finish_flow because config entries will hit a
# circular dependency where async_finish_flow sets up new entry, which needs the
# integration to be set up, which is waiting for init to be done.
if step_done:
step_done.set_result(None)
if result["type"] not in (
RESULT_TYPE_FORM,
RESULT_TYPE_EXTERNAL_STEP,

View file

@ -197,9 +197,12 @@ async def _async_setup_component(
)
return False
if hass.config_entries:
for entry in hass.config_entries.async_entries(domain):
await entry.async_setup(hass, integration=integration)
# Flush out async_setup calling create_task. Fragile but covered by test.
await asyncio.sleep(0)
await hass.config_entries.flow.async_wait_init_flow_finish(domain)
for entry in hass.config_entries.async_entries(domain):
await entry.async_setup(hass, integration=integration)
hass.config.components.add(domain)

View file

@ -9,19 +9,6 @@ from .test_device import MAC, setup_axis_integration
from tests.common import MockConfigEntry, mock_coro
async def test_setup_device_already_configured(hass):
"""Test already configured device does not configure a second."""
with patch.object(hass, "config_entries") as mock_config_entries:
assert await async_setup_component(
hass,
axis.DOMAIN,
{axis.DOMAIN: {"device_name": {axis.CONF_HOST: "1.2.3.4"}}},
)
assert not mock_config_entries.flow.mock_calls
async def test_setup_no_config(hass):
"""Test setup without configuration."""
assert await async_setup_component(hass, axis.DOMAIN, {})

View file

@ -230,7 +230,7 @@ async def test_setup_with_no_config(hass):
assert konnected.YAML_CONFIGS not in hass.data[konnected.DOMAIN]
async def test_setup_defined_hosts_known_auth(hass):
async def test_setup_defined_hosts_known_auth(hass, mock_panel):
"""Test we don't initiate a config entry if configured panel is known."""
MockConfigEntry(
domain="konnected",

View file

@ -15,7 +15,9 @@ async def test_config_with_sensor_passed_to_config_entry(hass):
CONF_SCAN_INTERVAL: 600,
}
with patch.object(hass, "config_entries") as mock_config_entries, patch.object(
with patch.object(
hass.config_entries.flow, "async_init"
) as mock_config_entries, patch.object(
luftdaten, "configured_sensors", return_value=[]
):
assert await async_setup_component(hass, DOMAIN, conf) is True
@ -27,7 +29,9 @@ async def test_config_already_registered_not_passed_to_config_entry(hass):
"""Test that an already registered sensor does not initiate an import."""
conf = {CONF_SENSOR_ID: "12345abcde"}
with patch.object(hass, "config_entries") as mock_config_entries, patch.object(
with patch.object(
hass.config_entries.flow, "async_init"
) as mock_config_entries, patch.object(
luftdaten, "configured_sensors", return_value=["12345abcde"]
):
assert await async_setup_component(hass, DOMAIN, conf) is True

View file

@ -55,7 +55,9 @@ def get_homekit_info_mock(model):
async def test_setup(hass, mock_zeroconf):
"""Test configured options for a device are loaded via config entry."""
with patch.object(hass.config_entries, "flow") as mock_config_flow, patch.object(
with patch.object(
hass.config_entries.flow, "async_init"
) as mock_config_flow, patch.object(
zeroconf, "ServiceBrowser", side_effect=service_update_mock
) as mock_service_browser:
mock_zeroconf.get_service_info.side_effect = get_service_info_mock
@ -72,7 +74,9 @@ async def test_homekit_match_partial_space(hass, mock_zeroconf):
"""Test configured options for a device are loaded via config entry."""
with patch.dict(
zc_gen.ZEROCONF, {zeroconf.HOMEKIT_TYPE: ["homekit_controller"]}, clear=True
), patch.object(hass.config_entries, "flow") as mock_config_flow, patch.object(
), patch.object(
hass.config_entries.flow, "async_init"
) as mock_config_flow, patch.object(
zeroconf, "ServiceBrowser", side_effect=service_update_mock
) as mock_service_browser:
mock_zeroconf.get_service_info.side_effect = get_homekit_info_mock("LIFX bulb")
@ -87,7 +91,9 @@ async def test_homekit_match_partial_dash(hass, mock_zeroconf):
"""Test configured options for a device are loaded via config entry."""
with patch.dict(
zc_gen.ZEROCONF, {zeroconf.HOMEKIT_TYPE: ["homekit_controller"]}, clear=True
), patch.object(hass.config_entries, "flow") as mock_config_flow, patch.object(
), patch.object(
hass.config_entries.flow, "async_init"
) as mock_config_flow, patch.object(
zeroconf, "ServiceBrowser", side_effect=service_update_mock
) as mock_service_browser:
mock_zeroconf.get_service_info.side_effect = get_homekit_info_mock(
@ -104,7 +110,9 @@ async def test_homekit_match_full(hass, mock_zeroconf):
"""Test configured options for a device are loaded via config entry."""
with patch.dict(
zc_gen.ZEROCONF, {zeroconf.HOMEKIT_TYPE: ["homekit_controller"]}, clear=True
), patch.object(hass.config_entries, "flow") as mock_config_flow, patch.object(
), patch.object(
hass.config_entries.flow, "async_init"
) as mock_config_flow, patch.object(
zeroconf, "ServiceBrowser", side_effect=service_update_mock
) as mock_service_browser:
mock_zeroconf.get_service_info.side_effect = get_homekit_info_mock("BSB002")

View file

@ -3,6 +3,7 @@ import asyncio
from datetime import timedelta
from unittest.mock import MagicMock, patch
from asynctest import CoroutineMock
import pytest
from homeassistant import config_entries, data_entry_flow, loader
@ -1463,3 +1464,97 @@ async def test_partial_flows_hidden(hass, manager):
await hass.async_block_till_done()
state = hass.states.get("persistent_notification.config_entry_discovery")
assert state is not None
async def test_async_setup_init_entry(hass):
"""Test a config entry being initialized during integration setup."""
async def mock_async_setup(hass, config):
"""Mock setup."""
hass.async_create_task(
hass.config_entries.flow.async_init(
"comp", context={"source": config_entries.SOURCE_IMPORT}, data={},
)
)
return True
async_setup_entry = CoroutineMock(return_value=True)
mock_integration(
hass,
MockModule(
"comp", async_setup=mock_async_setup, async_setup_entry=async_setup_entry
),
)
mock_entity_platform(hass, "config_flow.comp", None)
await async_setup_component(hass, "persistent_notification", {})
class TestFlow(config_entries.ConfigFlow):
"""Test flow."""
VERSION = 1
async def async_step_import(self, user_input):
"""Test import step creating entry."""
return self.async_create_entry(title="title", data={})
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
assert await async_setup_component(hass, "comp", {})
await hass.async_block_till_done()
assert len(async_setup_entry.mock_calls) == 1
entries = hass.config_entries.async_entries("comp")
assert len(entries) == 1
assert entries[0].state == config_entries.ENTRY_STATE_LOADED
async def test_async_setup_update_entry(hass):
"""Test a config entry being updated during integration setup."""
entry = MockConfigEntry(domain="comp", data={"value": "initial"})
entry.add_to_hass(hass)
async def mock_async_setup(hass, config):
"""Mock setup."""
hass.async_create_task(
hass.config_entries.flow.async_init(
"comp", context={"source": config_entries.SOURCE_IMPORT}, data={},
)
)
return True
async def mock_async_setup_entry(hass, entry):
"""Mock setting up an entry."""
assert entry.data["value"] == "updated"
return True
mock_integration(
hass,
MockModule(
"comp",
async_setup=mock_async_setup,
async_setup_entry=mock_async_setup_entry,
),
)
mock_entity_platform(hass, "config_flow.comp", None)
await async_setup_component(hass, "persistent_notification", {})
class TestFlow(config_entries.ConfigFlow):
"""Test flow."""
VERSION = 1
async def async_step_import(self, user_input):
"""Test import step updating existing entry."""
self.hass.config_entries.async_update_entry(
entry, data={"value": "updated"}
)
return self.async_abort(reason="yo")
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
assert await async_setup_component(hass, "comp", {})
entries = hass.config_entries.async_entries("comp")
assert len(entries) == 1
assert entries[0].state == config_entries.ENTRY_STATE_LOADED
assert entries[0].data == {"value": "updated"}