Allow async_setup changes to config entry data be taken into a… (#34166)

* Allow async_setup changes to config entry data be taken into account

* Fix tests

* Limit scope try…finally

* Update tests/test_config_entries.py

Co-Authored-By: Martin Hjelmare <marhje52@gmail.com>

* Fix import

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Paulus Schoutsen 2020-04-14 18:46:41 -07:00 committed by GitHub
parent 0db1fcca0f
commit 0b90ebf91e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 176 additions and 33 deletions

View file

@ -22,6 +22,7 @@ from homeassistant.helpers.event import async_track_time_interval
from .const import ( from .const import (
CONF_CITY, CONF_CITY,
CONF_COUNTRY, CONF_COUNTRY,
CONF_GEOGRAPHIES,
DATA_CLIENT, DATA_CLIENT,
DEFAULT_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL,
DOMAIN, DOMAIN,
@ -34,8 +35,6 @@ DATA_LISTENER = "listener"
DEFAULT_OPTIONS = {CONF_SHOW_ON_MAP: True} DEFAULT_OPTIONS = {CONF_SHOW_ON_MAP: True}
CONF_GEOGRAPHIES = "geographies"
GEOGRAPHY_COORDINATES_SCHEMA = vol.Schema( GEOGRAPHY_COORDINATES_SCHEMA = vol.Schema(
{ {
vol.Required(CONF_LATITUDE): cv.latitude, vol.Required(CONF_LATITUDE): cv.latitude,
@ -158,8 +157,7 @@ async def async_migrate_entry(hass, config_entry):
# Update the config entry to only include the first geography (there is always # Update the config entry to only include the first geography (there is always
# guaranteed to be at least one): # guaranteed to be at least one):
data = {**config_entry.data} geographies = list(config_entry.data[CONF_GEOGRAPHIES])
geographies = data.pop(CONF_GEOGRAPHIES)
first_geography = geographies.pop(0) first_geography = geographies.pop(0)
first_id = async_get_geography_id(first_geography) first_id = async_get_geography_id(first_geography)

View file

@ -16,7 +16,7 @@ from homeassistant.core import callback
from homeassistant.helpers import aiohttp_client, config_validation as cv from homeassistant.helpers import aiohttp_client, config_validation as cv
from . import async_get_geography_id from . import async_get_geography_id
from .const import DOMAIN # pylint: disable=unused-import from .const import CONF_GEOGRAPHIES, DOMAIN # pylint: disable=unused-import
class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
@ -69,6 +69,18 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
geo_id = async_get_geography_id(user_input) geo_id = async_get_geography_id(user_input)
await self._async_set_unique_id(geo_id) await self._async_set_unique_id(geo_id)
self._abort_if_unique_id_configured()
# Find older config entries without unique ID
for entry in self._async_current_entries():
if entry.version != 1:
continue
if any(
geo_id == async_get_geography_id(geography)
for geography in entry.data[CONF_GEOGRAPHIES]
):
return self.async_abort(reason="already_configured")
websession = aiohttp_client.async_get_clientsession(self.hass) websession = aiohttp_client.async_get_clientsession(self.hass)
client = Client(websession, api_key=user_input[CONF_API_KEY]) client = Client(websession, api_key=user_input[CONF_API_KEY])
@ -90,6 +102,7 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
) )
checked_keys.add(user_input[CONF_API_KEY]) checked_keys.add(user_input[CONF_API_KEY])
return self.async_create_entry( return self.async_create_entry(
title=f"Cloud API ({geo_id})", data=user_input title=f"Cloud API ({geo_id})", data=user_input
) )

View file

@ -5,6 +5,7 @@ DOMAIN = "airvisual"
CONF_CITY = "city" CONF_CITY = "city"
CONF_COUNTRY = "country" CONF_COUNTRY = "country"
CONF_GEOGRAPHIES = "geographies"
DATA_CLIENT = "client" DATA_CLIENT = "client"

View file

@ -1,5 +1,6 @@
"""Classes to help gather user submissions.""" """Classes to help gather user submissions."""
import abc import abc
import asyncio
import logging import logging
from typing import Any, Dict, List, Optional, cast from typing import Any, Dict, List, Optional, cast
import uuid import uuid
@ -53,8 +54,18 @@ class FlowManager(abc.ABC):
def __init__(self, hass: HomeAssistant,) -> None: def __init__(self, hass: HomeAssistant,) -> None:
"""Initialize the flow manager.""" """Initialize the flow manager."""
self.hass = hass self.hass = hass
self._initializing: Dict[str, List[asyncio.Future]] = {}
self._progress: Dict[str, Any] = {} self._progress: Dict[str, Any] = {}
async def async_wait_init_flow_finish(self, handler: str) -> None:
"""Wait till all flows in progress are initialized."""
current = self._initializing.get(handler)
if not current:
return
await asyncio.wait(current)
@abc.abstractmethod @abc.abstractmethod
async def async_create_flow( async def async_create_flow(
self, self,
@ -94,8 +105,13 @@ class FlowManager(abc.ABC):
"""Start a configuration flow.""" """Start a configuration flow."""
if context is None: if context is None:
context = {} context = {}
init_done: asyncio.Future = asyncio.Future()
self._initializing.setdefault(handler, []).append(init_done)
flow = await self.async_create_flow(handler, context=context, data=data) flow = await self.async_create_flow(handler, context=context, data=data)
if not flow: if not flow:
self._initializing[handler].remove(init_done)
raise UnknownFlow("Flow was not created") raise UnknownFlow("Flow was not created")
flow.hass = self.hass flow.hass = self.hass
flow.handler = handler flow.handler = handler
@ -103,7 +119,12 @@ class FlowManager(abc.ABC):
flow.context = context flow.context = context
self._progress[flow.flow_id] = flow self._progress[flow.flow_id] = flow
result = await self._async_handle_step(flow, flow.init_step, data) try:
result = await self._async_handle_step(
flow, flow.init_step, data, init_done
)
finally:
self._initializing[handler].remove(init_done)
if result["type"] != RESULT_TYPE_ABORT: if result["type"] != RESULT_TYPE_ABORT:
await self.async_post_init(flow, result) await self.async_post_init(flow, result)
@ -154,13 +175,19 @@ class FlowManager(abc.ABC):
raise UnknownFlow raise UnknownFlow
async def _async_handle_step( async def _async_handle_step(
self, flow: Any, step_id: str, user_input: Optional[Dict] self,
flow: Any,
step_id: str,
user_input: Optional[Dict],
step_done: Optional[asyncio.Future] = None,
) -> Dict: ) -> Dict:
"""Handle a step of a flow.""" """Handle a step of a flow."""
method = f"async_step_{step_id}" method = f"async_step_{step_id}"
if not hasattr(flow, method): if not hasattr(flow, method):
self._progress.pop(flow.flow_id) self._progress.pop(flow.flow_id)
if step_done:
step_done.set_result(None)
raise UnknownStep( raise UnknownStep(
f"Handler {flow.__class__.__name__} doesn't support step {step_id}" f"Handler {flow.__class__.__name__} doesn't support step {step_id}"
) )
@ -172,6 +199,13 @@ class FlowManager(abc.ABC):
flow.flow_id, flow.handler, err.reason, err.description_placeholders flow.flow_id, flow.handler, err.reason, err.description_placeholders
) )
# Mark the step as done.
# We do this before calling async_finish_flow because config entries will hit a
# circular dependency where async_finish_flow sets up new entry, which needs the
# integration to be set up, which is waiting for init to be done.
if step_done:
step_done.set_result(None)
if result["type"] not in ( if result["type"] not in (
RESULT_TYPE_FORM, RESULT_TYPE_FORM,
RESULT_TYPE_EXTERNAL_STEP, RESULT_TYPE_EXTERNAL_STEP,

View file

@ -197,7 +197,10 @@ async def _async_setup_component(
) )
return False return False
if hass.config_entries: # Flush out async_setup calling create_task. Fragile but covered by test.
await asyncio.sleep(0)
await hass.config_entries.flow.async_wait_init_flow_finish(domain)
for entry in hass.config_entries.async_entries(domain): for entry in hass.config_entries.async_entries(domain):
await entry.async_setup(hass, integration=integration) await entry.async_setup(hass, integration=integration)

View file

@ -9,19 +9,6 @@ from .test_device import MAC, setup_axis_integration
from tests.common import MockConfigEntry, mock_coro from tests.common import MockConfigEntry, mock_coro
async def test_setup_device_already_configured(hass):
"""Test already configured device does not configure a second."""
with patch.object(hass, "config_entries") as mock_config_entries:
assert await async_setup_component(
hass,
axis.DOMAIN,
{axis.DOMAIN: {"device_name": {axis.CONF_HOST: "1.2.3.4"}}},
)
assert not mock_config_entries.flow.mock_calls
async def test_setup_no_config(hass): async def test_setup_no_config(hass):
"""Test setup without configuration.""" """Test setup without configuration."""
assert await async_setup_component(hass, axis.DOMAIN, {}) assert await async_setup_component(hass, axis.DOMAIN, {})

View file

@ -230,7 +230,7 @@ async def test_setup_with_no_config(hass):
assert konnected.YAML_CONFIGS not in hass.data[konnected.DOMAIN] assert konnected.YAML_CONFIGS not in hass.data[konnected.DOMAIN]
async def test_setup_defined_hosts_known_auth(hass): async def test_setup_defined_hosts_known_auth(hass, mock_panel):
"""Test we don't initiate a config entry if configured panel is known.""" """Test we don't initiate a config entry if configured panel is known."""
MockConfigEntry( MockConfigEntry(
domain="konnected", domain="konnected",

View file

@ -15,7 +15,9 @@ async def test_config_with_sensor_passed_to_config_entry(hass):
CONF_SCAN_INTERVAL: 600, CONF_SCAN_INTERVAL: 600,
} }
with patch.object(hass, "config_entries") as mock_config_entries, patch.object( with patch.object(
hass.config_entries.flow, "async_init"
) as mock_config_entries, patch.object(
luftdaten, "configured_sensors", return_value=[] luftdaten, "configured_sensors", return_value=[]
): ):
assert await async_setup_component(hass, DOMAIN, conf) is True assert await async_setup_component(hass, DOMAIN, conf) is True
@ -27,7 +29,9 @@ async def test_config_already_registered_not_passed_to_config_entry(hass):
"""Test that an already registered sensor does not initiate an import.""" """Test that an already registered sensor does not initiate an import."""
conf = {CONF_SENSOR_ID: "12345abcde"} conf = {CONF_SENSOR_ID: "12345abcde"}
with patch.object(hass, "config_entries") as mock_config_entries, patch.object( with patch.object(
hass.config_entries.flow, "async_init"
) as mock_config_entries, patch.object(
luftdaten, "configured_sensors", return_value=["12345abcde"] luftdaten, "configured_sensors", return_value=["12345abcde"]
): ):
assert await async_setup_component(hass, DOMAIN, conf) is True assert await async_setup_component(hass, DOMAIN, conf) is True

View file

@ -55,7 +55,9 @@ def get_homekit_info_mock(model):
async def test_setup(hass, mock_zeroconf): async def test_setup(hass, mock_zeroconf):
"""Test configured options for a device are loaded via config entry.""" """Test configured options for a device are loaded via config entry."""
with patch.object(hass.config_entries, "flow") as mock_config_flow, patch.object( with patch.object(
hass.config_entries.flow, "async_init"
) as mock_config_flow, patch.object(
zeroconf, "ServiceBrowser", side_effect=service_update_mock zeroconf, "ServiceBrowser", side_effect=service_update_mock
) as mock_service_browser: ) as mock_service_browser:
mock_zeroconf.get_service_info.side_effect = get_service_info_mock mock_zeroconf.get_service_info.side_effect = get_service_info_mock
@ -72,7 +74,9 @@ async def test_homekit_match_partial_space(hass, mock_zeroconf):
"""Test configured options for a device are loaded via config entry.""" """Test configured options for a device are loaded via config entry."""
with patch.dict( with patch.dict(
zc_gen.ZEROCONF, {zeroconf.HOMEKIT_TYPE: ["homekit_controller"]}, clear=True zc_gen.ZEROCONF, {zeroconf.HOMEKIT_TYPE: ["homekit_controller"]}, clear=True
), patch.object(hass.config_entries, "flow") as mock_config_flow, patch.object( ), patch.object(
hass.config_entries.flow, "async_init"
) as mock_config_flow, patch.object(
zeroconf, "ServiceBrowser", side_effect=service_update_mock zeroconf, "ServiceBrowser", side_effect=service_update_mock
) as mock_service_browser: ) as mock_service_browser:
mock_zeroconf.get_service_info.side_effect = get_homekit_info_mock("LIFX bulb") mock_zeroconf.get_service_info.side_effect = get_homekit_info_mock("LIFX bulb")
@ -87,7 +91,9 @@ async def test_homekit_match_partial_dash(hass, mock_zeroconf):
"""Test configured options for a device are loaded via config entry.""" """Test configured options for a device are loaded via config entry."""
with patch.dict( with patch.dict(
zc_gen.ZEROCONF, {zeroconf.HOMEKIT_TYPE: ["homekit_controller"]}, clear=True zc_gen.ZEROCONF, {zeroconf.HOMEKIT_TYPE: ["homekit_controller"]}, clear=True
), patch.object(hass.config_entries, "flow") as mock_config_flow, patch.object( ), patch.object(
hass.config_entries.flow, "async_init"
) as mock_config_flow, patch.object(
zeroconf, "ServiceBrowser", side_effect=service_update_mock zeroconf, "ServiceBrowser", side_effect=service_update_mock
) as mock_service_browser: ) as mock_service_browser:
mock_zeroconf.get_service_info.side_effect = get_homekit_info_mock( mock_zeroconf.get_service_info.side_effect = get_homekit_info_mock(
@ -104,7 +110,9 @@ async def test_homekit_match_full(hass, mock_zeroconf):
"""Test configured options for a device are loaded via config entry.""" """Test configured options for a device are loaded via config entry."""
with patch.dict( with patch.dict(
zc_gen.ZEROCONF, {zeroconf.HOMEKIT_TYPE: ["homekit_controller"]}, clear=True zc_gen.ZEROCONF, {zeroconf.HOMEKIT_TYPE: ["homekit_controller"]}, clear=True
), patch.object(hass.config_entries, "flow") as mock_config_flow, patch.object( ), patch.object(
hass.config_entries.flow, "async_init"
) as mock_config_flow, patch.object(
zeroconf, "ServiceBrowser", side_effect=service_update_mock zeroconf, "ServiceBrowser", side_effect=service_update_mock
) as mock_service_browser: ) as mock_service_browser:
mock_zeroconf.get_service_info.side_effect = get_homekit_info_mock("BSB002") mock_zeroconf.get_service_info.side_effect = get_homekit_info_mock("BSB002")

View file

@ -3,6 +3,7 @@ import asyncio
from datetime import timedelta from datetime import timedelta
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from asynctest import CoroutineMock
import pytest import pytest
from homeassistant import config_entries, data_entry_flow, loader from homeassistant import config_entries, data_entry_flow, loader
@ -1463,3 +1464,97 @@ async def test_partial_flows_hidden(hass, manager):
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("persistent_notification.config_entry_discovery") state = hass.states.get("persistent_notification.config_entry_discovery")
assert state is not None assert state is not None
async def test_async_setup_init_entry(hass):
"""Test a config entry being initialized during integration setup."""
async def mock_async_setup(hass, config):
"""Mock setup."""
hass.async_create_task(
hass.config_entries.flow.async_init(
"comp", context={"source": config_entries.SOURCE_IMPORT}, data={},
)
)
return True
async_setup_entry = CoroutineMock(return_value=True)
mock_integration(
hass,
MockModule(
"comp", async_setup=mock_async_setup, async_setup_entry=async_setup_entry
),
)
mock_entity_platform(hass, "config_flow.comp", None)
await async_setup_component(hass, "persistent_notification", {})
class TestFlow(config_entries.ConfigFlow):
"""Test flow."""
VERSION = 1
async def async_step_import(self, user_input):
"""Test import step creating entry."""
return self.async_create_entry(title="title", data={})
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
assert await async_setup_component(hass, "comp", {})
await hass.async_block_till_done()
assert len(async_setup_entry.mock_calls) == 1
entries = hass.config_entries.async_entries("comp")
assert len(entries) == 1
assert entries[0].state == config_entries.ENTRY_STATE_LOADED
async def test_async_setup_update_entry(hass):
"""Test a config entry being updated during integration setup."""
entry = MockConfigEntry(domain="comp", data={"value": "initial"})
entry.add_to_hass(hass)
async def mock_async_setup(hass, config):
"""Mock setup."""
hass.async_create_task(
hass.config_entries.flow.async_init(
"comp", context={"source": config_entries.SOURCE_IMPORT}, data={},
)
)
return True
async def mock_async_setup_entry(hass, entry):
"""Mock setting up an entry."""
assert entry.data["value"] == "updated"
return True
mock_integration(
hass,
MockModule(
"comp",
async_setup=mock_async_setup,
async_setup_entry=mock_async_setup_entry,
),
)
mock_entity_platform(hass, "config_flow.comp", None)
await async_setup_component(hass, "persistent_notification", {})
class TestFlow(config_entries.ConfigFlow):
"""Test flow."""
VERSION = 1
async def async_step_import(self, user_input):
"""Test import step updating existing entry."""
self.hass.config_entries.async_update_entry(
entry, data={"value": "updated"}
)
return self.async_abort(reason="yo")
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
assert await async_setup_component(hass, "comp", {})
entries = hass.config_entries.async_entries("comp")
assert len(entries) == 1
assert entries[0].state == config_entries.ENTRY_STATE_LOADED
assert entries[0].data == {"value": "updated"}