Add unique ID to config entries (#29806)

* Add unique ID to config entries

* Unload existing entries with same unique ID if flow with unique ID is
finished

* Remove unused exception

* Fix typing

* silence pylint

* Fix tests

* Add unique ID to Hue

* Address typing comment

* Tweaks to comments

* lint
This commit is contained in:
Paulus Schoutsen 2019-12-16 12:27:43 +01:00 committed by GitHub
parent 87ca61ddd7
commit d851cb6f9e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 305 additions and 46 deletions

View file

@ -4,11 +4,11 @@ import logging
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries, core
from homeassistant.const import CONF_FILENAME, CONF_HOST from homeassistant.const import CONF_FILENAME, CONF_HOST
from homeassistant.helpers import config_validation as cv, device_registry as dr from homeassistant.helpers import config_validation as cv, device_registry as dr
from .bridge import HueBridge from .bridge import HueBridge, normalize_bridge_id
from .config_flow import ( # Loading the config flow file will register the flow from .config_flow import ( # Loading the config flow file will register the flow
configured_hosts, configured_hosts,
) )
@ -102,7 +102,9 @@ async def async_setup(hass, config):
return True return True
async def async_setup_entry(hass, entry): async def async_setup_entry(
hass: core.HomeAssistant, entry: config_entries.ConfigEntry
):
"""Set up a bridge from a config entry.""" """Set up a bridge from a config entry."""
host = entry.data["host"] host = entry.data["host"]
config = hass.data[DATA_CONFIGS].get(host) config = hass.data[DATA_CONFIGS].get(host)
@ -121,6 +123,13 @@ async def async_setup_entry(hass, entry):
hass.data[DOMAIN][host] = bridge hass.data[DOMAIN][host] = bridge
config = bridge.api.config config = bridge.api.config
# For backwards compat
if entry.unique_id is None:
hass.config_entries.async_update_entry(
entry, unique_id=normalize_bridge_id(config.bridgeid)
)
device_registry = await dr.async_get_registry(hass) device_registry = await dr.async_get_registry(hass)
device_registry.async_get_or_create( device_registry.async_get_or_create(
config_entry_id=entry.entry_id, config_entry_id=entry.entry_id,

View file

@ -201,3 +201,25 @@ async def get_bridge(hass, host, username=None):
except aiohue.AiohueException: except aiohue.AiohueException:
LOGGER.exception("Unknown Hue linking error occurred") LOGGER.exception("Unknown Hue linking error occurred")
raise AuthenticationRequired raise AuthenticationRequired
def normalize_bridge_id(bridge_id: str):
"""Normalize a bridge identifier.
There are three sources where we receive bridge ID from:
- ssdp/upnp: <host>/description.xml, field root/device/serialNumber
- nupnp: "id" field
- Hue Bridge API: config.bridgeid
The SSDP/UPNP source does not contain the middle 4 characters compared
to the other sources. In all our tests the middle 4 characters are "fffe".
"""
if len(bridge_id) == 16:
return bridge_id[0:6] + bridge_id[-6:]
if len(bridge_id) == 12:
return bridge_id
LOGGER.warning("Unexpected bridge id number found: %s", bridge_id)
return bridge_id

View file

@ -12,7 +12,7 @@ from homeassistant.components.ssdp import ATTR_MANUFACTURERURL, ATTR_NAME
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers import aiohttp_client from homeassistant.helpers import aiohttp_client
from .bridge import get_bridge from .bridge import get_bridge, normalize_bridge_id
from .const import DOMAIN, LOGGER from .const import DOMAIN, LOGGER
from .errors import AuthenticationRequired, CannotConnect from .errors import AuthenticationRequired, CannotConnect
@ -154,17 +154,15 @@ class HueFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
if host in configured_hosts(self.hass): if host in configured_hosts(self.hass):
return self.async_abort(reason="already_configured") return self.async_abort(reason="already_configured")
# This value is based off host/description.xml and is, weirdly, missing bridge_id = discovery_info.get("serial")
# 4 characters in the middle of the serial compared to results returned
# from the NUPNP API or when querying the bridge API for bridgeid. await self.async_set_unique_id(normalize_bridge_id(bridge_id))
# (on first gen Hue hub)
serial = discovery_info.get("serial")
return await self.async_step_import( return await self.async_step_import(
{ {
"host": host, "host": host,
# This format is the legacy format that Hue used for discovery # This format is the legacy format that Hue used for discovery
"path": f"phue-{serial}.conf", "path": f"phue-{bridge_id}.conf",
} }
) )
@ -180,6 +178,10 @@ class HueFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
if host in configured_hosts(self.hass): if host in configured_hosts(self.hass):
return self.async_abort(reason="already_configured") return self.async_abort(reason="already_configured")
await self.async_set_unique_id(
normalize_bridge_id(homekit_info["properties"]["id"].replace(":", ""))
)
return await self.async_step_import({"host": host}) return await self.async_step_import({"host": host})
async def async_step_import(self, import_info): async def async_step_import(self, import_info):
@ -234,18 +236,9 @@ class HueFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
host = bridge.host host = bridge.host
bridge_id = bridge.config.bridgeid bridge_id = bridge.config.bridgeid
same_hub_entries = [ if self.unique_id is None:
entry.entry_id await self.async_set_unique_id(
for entry in self.hass.config_entries.async_entries(DOMAIN) normalize_bridge_id(bridge_id), raise_on_progress=False
if entry.data["bridge_id"] == bridge_id or entry.data["host"] == host
]
if same_hub_entries:
await asyncio.wait(
[
self.hass.config_entries.async_remove(entry_id)
for entry_id in same_hub_entries
]
) )
return self.async_create_entry( return self.async_create_entry(

View file

@ -2,7 +2,7 @@
import asyncio import asyncio
import functools import functools
import logging import logging
from typing import Any, Callable, Dict, List, Optional, Set, cast from typing import Any, Callable, Dict, List, Optional, Set, Union, cast
import uuid import uuid
import weakref import weakref
@ -75,6 +75,10 @@ class OperationNotAllowed(ConfigError):
"""Raised when a config entry operation is not allowed.""" """Raised when a config entry operation is not allowed."""
class UniqueIdInProgress(data_entry_flow.AbortFlow):
"""Error to indicate that the unique Id is in progress."""
class ConfigEntry: class ConfigEntry:
"""Hold a configuration entry.""" """Hold a configuration entry."""
@ -85,6 +89,7 @@ class ConfigEntry:
"title", "title",
"data", "data",
"options", "options",
"unique_id",
"system_options", "system_options",
"source", "source",
"connection_class", "connection_class",
@ -104,6 +109,7 @@ class ConfigEntry:
connection_class: str, connection_class: str,
system_options: dict, system_options: dict,
options: Optional[dict] = None, options: Optional[dict] = None,
unique_id: Optional[str] = None,
entry_id: Optional[str] = None, entry_id: Optional[str] = None,
state: str = ENTRY_STATE_NOT_LOADED, state: str = ENTRY_STATE_NOT_LOADED,
) -> None: ) -> None:
@ -138,6 +144,9 @@ class ConfigEntry:
# State of the entry (LOADED, NOT_LOADED) # State of the entry (LOADED, NOT_LOADED)
self.state = state self.state = state
# Unique ID of this entry.
self.unique_id = unique_id
# Listeners to call on update # Listeners to call on update
self.update_listeners: List = [] self.update_listeners: List = []
@ -533,11 +542,15 @@ class ConfigEntries:
self, self,
entry: ConfigEntry, entry: ConfigEntry,
*, *,
unique_id: Union[str, dict, None] = _UNDEF,
data: dict = _UNDEF, data: dict = _UNDEF,
options: dict = _UNDEF, options: dict = _UNDEF,
system_options: dict = _UNDEF, system_options: dict = _UNDEF,
) -> None: ) -> None:
"""Update a config entry.""" """Update a config entry."""
if unique_id is not _UNDEF:
entry.unique_id = cast(Optional[str], unique_id)
if data is not _UNDEF: if data is not _UNDEF:
entry.data = data entry.data = data
@ -602,6 +615,25 @@ class ConfigEntries:
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY: if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return result return result
# Check if config entry exists with unique ID. Unload it.
existing_entry = None
unique_id = flow.context.get("unique_id")
if unique_id is not None:
for check_entry in self.async_entries(result["handler"]):
if check_entry.unique_id == unique_id:
existing_entry = check_entry
break
# Unload the entry before setting up the new one.
# We will remove it only after the other one is set up,
# so that device customizations are not getting lost.
if (
existing_entry is not None
and existing_entry.state not in UNRECOVERABLE_STATES
):
await self.async_unload(existing_entry.entry_id)
entry = ConfigEntry( entry = ConfigEntry(
version=result["version"], version=result["version"],
domain=result["handler"], domain=result["handler"],
@ -611,12 +643,16 @@ class ConfigEntries:
system_options={}, system_options={},
source=flow.context["source"], source=flow.context["source"],
connection_class=flow.CONNECTION_CLASS, connection_class=flow.CONNECTION_CLASS,
unique_id=unique_id,
) )
self._entries.append(entry) self._entries.append(entry)
self._async_schedule_save() self._async_schedule_save()
await self.async_setup(entry.entry_id) await self.async_setup(entry.entry_id)
if existing_entry is not None:
await self.async_remove(existing_entry.entry_id)
result["result"] = entry result["result"] = entry
return result return result
@ -687,6 +723,8 @@ async def _old_conf_migrator(old_config: Dict[str, Any]) -> Dict[str, Any]:
class ConfigFlow(data_entry_flow.FlowHandler): class ConfigFlow(data_entry_flow.FlowHandler):
"""Base class for config flows with some helpers.""" """Base class for config flows with some helpers."""
unique_id = None
def __init_subclass__(cls, domain: Optional[str] = None, **kwargs: Any) -> None: def __init_subclass__(cls, domain: Optional[str] = None, **kwargs: Any) -> None:
"""Initialize a subclass, register if possible.""" """Initialize a subclass, register if possible."""
super().__init_subclass__(**kwargs) # type: ignore super().__init_subclass__(**kwargs) # type: ignore
@ -701,6 +739,27 @@ class ConfigFlow(data_entry_flow.FlowHandler):
"""Get the options flow for this handler.""" """Get the options flow for this handler."""
raise data_entry_flow.UnknownHandler raise data_entry_flow.UnknownHandler
async def async_set_unique_id(
self, unique_id: str, *, raise_on_progress: bool = True
) -> Optional[ConfigEntry]:
"""Set a unique ID for the config flow.
Returns optionally existing config entry with same ID.
"""
if raise_on_progress:
for progress in self._async_in_progress():
if progress["context"].get("unique_id") == unique_id:
raise UniqueIdInProgress("already_in_progress")
# pylint: disable=no-member
self.context["unique_id"] = unique_id
for entry in self._async_current_entries():
if entry.unique_id == unique_id:
return entry
return None
@callback @callback
def _async_current_entries(self) -> List[ConfigEntry]: def _async_current_entries(self) -> List[ConfigEntry]:
"""Return current entries.""" """Return current entries."""

View file

@ -1,6 +1,6 @@
"""Classes to help gather user submissions.""" """Classes to help gather user submissions."""
import logging import logging
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional, cast
import uuid import uuid
import voluptuous as vol import voluptuous as vol
@ -36,6 +36,16 @@ class UnknownStep(FlowError):
"""Unknown step specified.""" """Unknown step specified."""
class AbortFlow(FlowError):
"""Exception to indicate a flow needs to be aborted."""
def __init__(self, reason: str, description_placeholders: Optional[Dict] = None):
"""Initialize an abort flow exception."""
super().__init__(f"Flow aborted: {reason}")
self.reason = reason
self.description_placeholders = description_placeholders
class FlowManager: class FlowManager:
"""Manage all the flows that are in progress.""" """Manage all the flows that are in progress."""
@ -131,7 +141,12 @@ class FlowManager:
) )
) )
result: Dict = await getattr(flow, method)(user_input) try:
result: Dict = await getattr(flow, method)(user_input)
except AbortFlow as err:
result = _create_abort_data(
flow.flow_id, flow.handler, err.reason, err.description_placeholders
)
if result["type"] not in ( if result["type"] not in (
RESULT_TYPE_FORM, RESULT_TYPE_FORM,
@ -228,13 +243,9 @@ class FlowHandler:
self, *, reason: str, description_placeholders: Optional[Dict] = None self, *, reason: str, description_placeholders: Optional[Dict] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Abort the config flow.""" """Abort the config flow."""
return { return _create_abort_data(
"type": RESULT_TYPE_ABORT, self.flow_id, cast(str, self.handler), reason, description_placeholders
"flow_id": self.flow_id, )
"handler": self.handler,
"reason": reason,
"description_placeholders": description_placeholders,
}
@callback @callback
def async_external_step( def async_external_step(
@ -259,3 +270,20 @@ class FlowHandler:
"handler": self.handler, "handler": self.handler,
"step_id": next_step_id, "step_id": next_step_id,
} }
@callback
def _create_abort_data(
flow_id: str,
handler: str,
reason: str,
description_placeholders: Optional[Dict] = None,
) -> Dict[str, Any]:
"""Return the definition of an external step for the user to take."""
return {
"type": RESULT_TYPE_ABORT,
"flow_id": flow_id,
"handler": handler,
"reason": reason,
"description_placeholders": description_placeholders,
}

View file

@ -671,6 +671,7 @@ class MockConfigEntry(config_entries.ConfigEntry):
options={}, options={},
system_options={}, system_options={},
connection_class=config_entries.CONN_CLASS_UNKNOWN, connection_class=config_entries.CONN_CLASS_UNKNOWN,
unique_id=None,
): ):
"""Initialize a mock config entry.""" """Initialize a mock config entry."""
kwargs = { kwargs = {
@ -682,6 +683,7 @@ class MockConfigEntry(config_entries.ConfigEntry):
"version": version, "version": version,
"title": title, "title": title,
"connection_class": connection_class, "connection_class": connection_class,
"unique_id": unique_id,
} }
if source is not None: if source is not None:
kwargs["source"] = source kwargs["source"] = source

View file

@ -19,6 +19,7 @@ async def test_flow_works(hass, aioclient_mock):
flow = config_flow.HueFlowHandler() flow = config_flow.HueFlowHandler()
flow.hass = hass flow.hass = hass
flow.context = {}
await flow.async_step_init() await flow.async_step_init()
with patch("aiohue.Bridge") as mock_bridge: with patch("aiohue.Bridge") as mock_bridge:
@ -349,28 +350,33 @@ async def test_creating_entry_removes_entries_for_same_host_or_bridge(hass):
accessible via a single IP. So when we create a new entry, we'll remove accessible via a single IP. So when we create a new entry, we'll remove
all existing entries that either have same IP or same bridge_id. all existing entries that either have same IP or same bridge_id.
""" """
MockConfigEntry( orig_entry = MockConfigEntry(
domain="hue", data={"host": "0.0.0.0", "bridge_id": "id-1234"} domain="hue",
).add_to_hass(hass) data={"host": "0.0.0.0", "bridge_id": "id-1234"},
unique_id="id-1234",
)
orig_entry.add_to_hass(hass)
MockConfigEntry( MockConfigEntry(
domain="hue", data={"host": "1.2.3.4", "bridge_id": "id-1234"} domain="hue",
data={"host": "1.2.3.4", "bridge_id": "id-5678"},
unique_id="id-5678",
).add_to_hass(hass) ).add_to_hass(hass)
assert len(hass.config_entries.async_entries("hue")) == 2 assert len(hass.config_entries.async_entries("hue")) == 2
flow = config_flow.HueFlowHandler()
flow.hass = hass
flow.context = {}
bridge = Mock() bridge = Mock()
bridge.username = "username-abc" bridge.username = "username-abc"
bridge.config.bridgeid = "id-1234" bridge.config.bridgeid = "id-1234"
bridge.config.name = "Mock Bridge" bridge.config.name = "Mock Bridge"
bridge.host = "0.0.0.0" bridge.host = "0.0.0.0"
with patch.object(config_flow, "get_bridge", return_value=mock_coro(bridge)): with patch.object(
result = await flow.async_step_import({"host": "0.0.0.0"}) config_flow, "_find_username_from_config", return_value="mock-user"
), patch.object(config_flow, "get_bridge", return_value=mock_coro(bridge)):
result = await hass.config_entries.flow.async_init(
"hue", data={"host": "2.2.2.2"}, context={"source": "import"}
)
assert result["type"] == "create_entry" assert result["type"] == "create_entry"
assert result["title"] == "Mock Bridge" assert result["title"] == "Mock Bridge"
@ -379,9 +385,11 @@ async def test_creating_entry_removes_entries_for_same_host_or_bridge(hass):
"bridge_id": "id-1234", "bridge_id": "id-1234",
"username": "username-abc", "username": "username-abc",
} }
# We did not process the result of this entry but already removed the old entries = hass.config_entries.async_entries("hue")
# ones. So we should have 0 entries. assert len(entries) == 2
assert len(hass.config_entries.async_entries("hue")) == 0 new_entry = entries[-1]
assert orig_entry.entry_id != new_entry.entry_id
assert new_entry.unique_id == "id-1234"
async def test_bridge_homekit(hass): async def test_bridge_homekit(hass):
@ -398,6 +406,7 @@ async def test_bridge_homekit(hass):
"host": "0.0.0.0", "host": "0.0.0.0",
"serial": "1234", "serial": "1234",
"manufacturerURL": config_flow.HUE_MANUFACTURERURL, "manufacturerURL": config_flow.HUE_MANUFACTURERURL,
"properties": {"id": "aa:bb:cc:dd:ee:ff"},
} }
) )

View file

@ -175,3 +175,19 @@ async def test_unload_entry(hass):
assert await hue.async_unload_entry(hass, entry) assert await hue.async_unload_entry(hass, entry)
assert len(mock_bridge.return_value.async_reset.mock_calls) == 1 assert len(mock_bridge.return_value.async_reset.mock_calls) == 1
assert hass.data[hue.DOMAIN] == {} assert hass.data[hue.DOMAIN] == {}
async def test_setting_unique_id(hass):
"""Test we set unique ID if not set yet."""
entry = MockConfigEntry(domain=hue.DOMAIN, data={"host": "0.0.0.0"})
entry.add_to_hass(hass)
with patch.object(hue, "HueBridge") as mock_bridge, patch(
"homeassistant.helpers.device_registry.async_get_registry",
return_value=mock_coro(Mock()),
):
mock_bridge.return_value.async_setup.return_value = mock_coro(True)
mock_bridge.return_value.api.config = Mock(bridgeid="mock-id")
assert await async_setup_component(hass, hue.DOMAIN, {}) is True
assert entry.unique_id == "mock-id"

View file

@ -1001,3 +1001,110 @@ async def test_reload_entry_entity_registry_works(hass):
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_unload_entry.mock_calls) == 1 assert len(mock_unload_entry.mock_calls) == 1
async def test_unqiue_id_persisted(hass, manager):
"""Test that a unique ID is stored in the config entry."""
mock_setup_entry = MagicMock(return_value=mock_coro(True))
mock_integration(hass, MockModule("comp", async_setup_entry=mock_setup_entry))
mock_entity_platform(hass, "config_flow.comp", None)
class TestFlow(config_entries.ConfigFlow):
VERSION = 1
async def async_step_user(self, user_input=None):
await self.async_set_unique_id("mock-unique-id")
return self.async_create_entry(title="mock-title", data={})
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
await manager.flow.async_init(
"comp", context={"source": config_entries.SOURCE_USER}
)
assert len(mock_setup_entry.mock_calls) == 1
p_hass, p_entry = mock_setup_entry.mock_calls[0][1]
assert p_hass is hass
assert p_entry.unique_id == "mock-unique-id"
async def test_unique_id_existing_entry(hass, manager):
"""Test that we remove an entry if there already is an entry with unique ID."""
hass.config.components.add("comp")
MockConfigEntry(
domain="comp",
state=config_entries.ENTRY_STATE_LOADED,
unique_id="mock-unique-id",
).add_to_hass(hass)
async_setup_entry = MagicMock(side_effect=lambda _, _2: mock_coro(True))
async_unload_entry = MagicMock(side_effect=lambda _, _2: mock_coro(True))
async_remove_entry = MagicMock(side_effect=lambda _, _2: mock_coro(True))
mock_integration(
hass,
MockModule(
"comp",
async_setup_entry=async_setup_entry,
async_unload_entry=async_unload_entry,
async_remove_entry=async_remove_entry,
),
)
mock_entity_platform(hass, "config_flow.comp", None)
class TestFlow(config_entries.ConfigFlow):
VERSION = 1
async def async_step_user(self, user_input=None):
existing_entry = await self.async_set_unique_id("mock-unique-id")
assert existing_entry is not None
return self.async_create_entry(title="mock-title", data={"via": "flow"})
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
result = await manager.flow.async_init(
"comp", context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
entries = hass.config_entries.async_entries("comp")
assert len(entries) == 1
assert entries[0].data == {"via": "flow"}
assert len(async_setup_entry.mock_calls) == 1
assert len(async_unload_entry.mock_calls) == 1
assert len(async_remove_entry.mock_calls) == 1
async def test_unique_id_in_progress(hass, manager):
"""Test that we abort if there is already a flow in progress with same unique id."""
mock_integration(hass, MockModule("comp"))
mock_entity_platform(hass, "config_flow.comp", None)
class TestFlow(config_entries.ConfigFlow):
VERSION = 1
async def async_step_user(self, user_input=None):
await self.async_set_unique_id("mock-unique-id")
return self.async_show_form(step_id="discovery")
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
# Create one to be in progress
result = await manager.flow.async_init(
"comp", context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
# Will be canceled
result2 = await manager.flow.async_init(
"comp", context={"source": config_entries.SOURCE_USER}
)
assert result2["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result2["reason"] == "already_in_progress"

View file

@ -94,7 +94,7 @@ async def test_configure_two_steps(manager):
async def test_show_form(manager): async def test_show_form(manager):
"""Test that abort removes the flow from progress.""" """Test that we can show a form."""
schema = vol.Schema({vol.Required("username"): str, vol.Required("password"): str}) schema = vol.Schema({vol.Required("username"): str, vol.Required("password"): str})
@manager.mock_reg_handler("test") @manager.mock_reg_handler("test")
@ -271,3 +271,17 @@ async def test_external_step(hass, manager):
result = await manager.async_configure(result["flow_id"]) result = await manager.async_configure(result["flow_id"])
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["title"] == "Hello" assert result["title"] == "Hello"
async def test_abort_flow_exception(manager):
"""Test that the AbortFlow exception works."""
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
raise data_entry_flow.AbortFlow("mock-reason", {"placeholder": "yo"})
form = await manager.async_init("test")
assert form["type"] == "abort"
assert form["reason"] == "mock-reason"
assert form["description_placeholders"] == {"placeholder": "yo"}