Refactor config entry storage and index (#107590)
This commit is contained in:
parent
2c6aa80bc7
commit
3649cb96de
3 changed files with 153 additions and 57 deletions
|
@ -2,7 +2,15 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable, Coroutine, Generator, Iterable, Mapping
|
from collections import UserDict
|
||||||
|
from collections.abc import (
|
||||||
|
Callable,
|
||||||
|
Coroutine,
|
||||||
|
Generator,
|
||||||
|
Iterable,
|
||||||
|
Mapping,
|
||||||
|
ValuesView,
|
||||||
|
)
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from enum import Enum, StrEnum
|
from enum import Enum, StrEnum
|
||||||
|
@ -336,6 +344,13 @@ class ConfigEntry:
|
||||||
self._tries = 0
|
self._tries = 0
|
||||||
self._setup_again_job: HassJob | None = None
|
self._setup_again_job: HassJob | None = None
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""Representation of ConfigEntry."""
|
||||||
|
return (
|
||||||
|
f"<ConfigEntry entry_id={self.entry_id} version={self.version} domain={self.domain} "
|
||||||
|
f"title={self.title} state={self.state} unique_id={self.unique_id}>"
|
||||||
|
)
|
||||||
|
|
||||||
async def async_setup(
|
async def async_setup(
|
||||||
self,
|
self,
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
@ -1057,6 +1072,67 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigEntryItems(UserDict[str, ConfigEntry]):
|
||||||
|
"""Container for config items, maps config_entry_id -> entry.
|
||||||
|
|
||||||
|
Maintains two additional indexes:
|
||||||
|
- domain -> list[ConfigEntry]
|
||||||
|
- domain -> unique_id -> ConfigEntry
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the container."""
|
||||||
|
super().__init__()
|
||||||
|
self._domain_index: dict[str, list[ConfigEntry]] = {}
|
||||||
|
self._domain_unique_id_index: dict[str, dict[str, ConfigEntry]] = {}
|
||||||
|
|
||||||
|
def values(self) -> ValuesView[ConfigEntry]:
|
||||||
|
"""Return the underlying values to avoid __iter__ overhead."""
|
||||||
|
return self.data.values()
|
||||||
|
|
||||||
|
def __setitem__(self, entry_id: str, entry: ConfigEntry) -> None:
|
||||||
|
"""Add an item."""
|
||||||
|
data = self.data
|
||||||
|
if entry_id in data:
|
||||||
|
# This is likely a bug in a test that is adding the same entry twice.
|
||||||
|
# In the future, once we have fixed the tests, this will raise HomeAssistantError.
|
||||||
|
_LOGGER.error("An entry with the id %s already exists", entry_id)
|
||||||
|
self._unindex_entry(entry_id)
|
||||||
|
data[entry_id] = entry
|
||||||
|
self._domain_index.setdefault(entry.domain, []).append(entry)
|
||||||
|
if entry.unique_id is not None:
|
||||||
|
self._domain_unique_id_index.setdefault(entry.domain, {})[
|
||||||
|
entry.unique_id
|
||||||
|
] = entry
|
||||||
|
|
||||||
|
def _unindex_entry(self, entry_id: str) -> None:
|
||||||
|
"""Unindex an entry."""
|
||||||
|
entry = self.data[entry_id]
|
||||||
|
domain = entry.domain
|
||||||
|
self._domain_index[domain].remove(entry)
|
||||||
|
if not self._domain_index[domain]:
|
||||||
|
del self._domain_index[domain]
|
||||||
|
if (unique_id := entry.unique_id) is not None:
|
||||||
|
del self._domain_unique_id_index[domain][unique_id]
|
||||||
|
if not self._domain_unique_id_index[domain]:
|
||||||
|
del self._domain_unique_id_index[domain]
|
||||||
|
|
||||||
|
def __delitem__(self, entry_id: str) -> None:
|
||||||
|
"""Remove an item."""
|
||||||
|
self._unindex_entry(entry_id)
|
||||||
|
super().__delitem__(entry_id)
|
||||||
|
|
||||||
|
def get_entries_for_domain(self, domain: str) -> list[ConfigEntry]:
|
||||||
|
"""Get entries for a domain."""
|
||||||
|
return self._domain_index.get(domain, [])
|
||||||
|
|
||||||
|
def get_entry_by_domain_and_unique_id(
|
||||||
|
self, domain: str, unique_id: str
|
||||||
|
) -> ConfigEntry | None:
|
||||||
|
"""Get entry by domain and unique id."""
|
||||||
|
return self._domain_unique_id_index.get(domain, {}).get(unique_id)
|
||||||
|
|
||||||
|
|
||||||
class ConfigEntries:
|
class ConfigEntries:
|
||||||
"""Manage the configuration entries.
|
"""Manage the configuration entries.
|
||||||
|
|
||||||
|
@ -1069,8 +1145,7 @@ class ConfigEntries:
|
||||||
self.flow = ConfigEntriesFlowManager(hass, self, hass_config)
|
self.flow = ConfigEntriesFlowManager(hass, self, hass_config)
|
||||||
self.options = OptionsFlowManager(hass)
|
self.options = OptionsFlowManager(hass)
|
||||||
self._hass_config = hass_config
|
self._hass_config = hass_config
|
||||||
self._entries: dict[str, ConfigEntry] = {}
|
self._entries = ConfigEntryItems()
|
||||||
self._domain_index: dict[str, list[ConfigEntry]] = {}
|
|
||||||
self._store = storage.Store[dict[str, list[dict[str, Any]]]](
|
self._store = storage.Store[dict[str, list[dict[str, Any]]]](
|
||||||
hass, STORAGE_VERSION, STORAGE_KEY
|
hass, STORAGE_VERSION, STORAGE_KEY
|
||||||
)
|
)
|
||||||
|
@ -1093,23 +1168,29 @@ class ConfigEntries:
|
||||||
@callback
|
@callback
|
||||||
def async_get_entry(self, entry_id: str) -> ConfigEntry | None:
|
def async_get_entry(self, entry_id: str) -> ConfigEntry | None:
|
||||||
"""Return entry with matching entry_id."""
|
"""Return entry with matching entry_id."""
|
||||||
return self._entries.get(entry_id)
|
return self._entries.data.get(entry_id)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_entries(self, domain: str | None = None) -> list[ConfigEntry]:
|
def async_entries(self, domain: str | None = None) -> list[ConfigEntry]:
|
||||||
"""Return all entries or entries for a specific domain."""
|
"""Return all entries or entries for a specific domain."""
|
||||||
if domain is None:
|
if domain is None:
|
||||||
return list(self._entries.values())
|
return list(self._entries.values())
|
||||||
return list(self._domain_index.get(domain, []))
|
return list(self._entries.get_entries_for_domain(domain))
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_entry_for_domain_unique_id(
|
||||||
|
self, domain: str, unique_id: str
|
||||||
|
) -> ConfigEntry | None:
|
||||||
|
"""Return entry for a domain with a matching unique id."""
|
||||||
|
return self._entries.get_entry_by_domain_and_unique_id(domain, unique_id)
|
||||||
|
|
||||||
async def async_add(self, entry: ConfigEntry) -> None:
|
async def async_add(self, entry: ConfigEntry) -> None:
|
||||||
"""Add and setup an entry."""
|
"""Add and setup an entry."""
|
||||||
if entry.entry_id in self._entries:
|
if entry.entry_id in self._entries.data:
|
||||||
raise HomeAssistantError(
|
raise HomeAssistantError(
|
||||||
f"An entry with the id {entry.entry_id} already exists."
|
f"An entry with the id {entry.entry_id} already exists."
|
||||||
)
|
)
|
||||||
self._entries[entry.entry_id] = entry
|
self._entries[entry.entry_id] = entry
|
||||||
self._domain_index.setdefault(entry.domain, []).append(entry)
|
|
||||||
self._async_dispatch(ConfigEntryChange.ADDED, entry)
|
self._async_dispatch(ConfigEntryChange.ADDED, entry)
|
||||||
await self.async_setup(entry.entry_id)
|
await self.async_setup(entry.entry_id)
|
||||||
self._async_schedule_save()
|
self._async_schedule_save()
|
||||||
|
@ -1127,9 +1208,6 @@ class ConfigEntries:
|
||||||
await entry.async_remove(self.hass)
|
await entry.async_remove(self.hass)
|
||||||
|
|
||||||
del self._entries[entry.entry_id]
|
del self._entries[entry.entry_id]
|
||||||
self._domain_index[entry.domain].remove(entry)
|
|
||||||
if not self._domain_index[entry.domain]:
|
|
||||||
del self._domain_index[entry.domain]
|
|
||||||
self._async_schedule_save()
|
self._async_schedule_save()
|
||||||
|
|
||||||
dev_reg = device_registry.async_get(self.hass)
|
dev_reg = device_registry.async_get(self.hass)
|
||||||
|
@ -1189,13 +1267,10 @@ class ConfigEntries:
|
||||||
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._async_shutdown)
|
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._async_shutdown)
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
self._entries = {}
|
self._entries = ConfigEntryItems()
|
||||||
self._domain_index = {}
|
|
||||||
return
|
return
|
||||||
|
|
||||||
entries = {}
|
entries: ConfigEntryItems = ConfigEntryItems()
|
||||||
domain_index: dict[str, list[ConfigEntry]] = {}
|
|
||||||
|
|
||||||
for entry in config["entries"]:
|
for entry in config["entries"]:
|
||||||
pref_disable_new_entities = entry.get("pref_disable_new_entities")
|
pref_disable_new_entities = entry.get("pref_disable_new_entities")
|
||||||
|
|
||||||
|
@ -1230,9 +1305,7 @@ class ConfigEntries:
|
||||||
pref_disable_polling=entry.get("pref_disable_polling"),
|
pref_disable_polling=entry.get("pref_disable_polling"),
|
||||||
)
|
)
|
||||||
entries[entry_id] = config_entry
|
entries[entry_id] = config_entry
|
||||||
domain_index.setdefault(domain, []).append(config_entry)
|
|
||||||
|
|
||||||
self._domain_index = domain_index
|
|
||||||
self._entries = entries
|
self._entries = entries
|
||||||
|
|
||||||
async def async_setup(self, entry_id: str) -> bool:
|
async def async_setup(self, entry_id: str) -> bool:
|
||||||
|
@ -1365,8 +1438,15 @@ class ConfigEntries:
|
||||||
"""
|
"""
|
||||||
changed = False
|
changed = False
|
||||||
|
|
||||||
|
if unique_id is not UNDEFINED and entry.unique_id != unique_id:
|
||||||
|
# Reindex the entry if the unique_id has changed
|
||||||
|
entry_id = entry.entry_id
|
||||||
|
del self._entries[entry_id]
|
||||||
|
entry.unique_id = unique_id
|
||||||
|
self._entries[entry_id] = entry
|
||||||
|
changed = True
|
||||||
|
|
||||||
for attr, value in (
|
for attr, value in (
|
||||||
("unique_id", unique_id),
|
|
||||||
("title", title),
|
("title", title),
|
||||||
("pref_disable_new_entities", pref_disable_new_entities),
|
("pref_disable_new_entities", pref_disable_new_entities),
|
||||||
("pref_disable_polling", pref_disable_polling),
|
("pref_disable_polling", pref_disable_polling),
|
||||||
|
@ -1579,9 +1659,13 @@ class ConfigFlow(data_entry_flow.FlowHandler):
|
||||||
if self.unique_id is None:
|
if self.unique_id is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
for entry in self._async_current_entries(include_ignore=True):
|
if not (
|
||||||
if entry.unique_id != self.unique_id:
|
entry := self.hass.config_entries.async_entry_for_domain_unique_id(
|
||||||
continue
|
self.handler, self.unique_id
|
||||||
|
)
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
should_reload = False
|
should_reload = False
|
||||||
if (
|
if (
|
||||||
updates is not None
|
updates is not None
|
||||||
|
@ -1589,8 +1673,7 @@ class ConfigFlow(data_entry_flow.FlowHandler):
|
||||||
entry, data={**entry.data, **updates}
|
entry, data={**entry.data, **updates}
|
||||||
)
|
)
|
||||||
and reload_on_update
|
and reload_on_update
|
||||||
and entry.state
|
and entry.state in (ConfigEntryState.LOADED, ConfigEntryState.SETUP_RETRY)
|
||||||
in (ConfigEntryState.LOADED, ConfigEntryState.SETUP_RETRY)
|
|
||||||
):
|
):
|
||||||
# Existing config entry present, and the
|
# Existing config entry present, and the
|
||||||
# entry data just changed
|
# entry data just changed
|
||||||
|
@ -1604,7 +1687,7 @@ class ConfigFlow(data_entry_flow.FlowHandler):
|
||||||
should_reload = True
|
should_reload = True
|
||||||
# Allow ignored entries to be configured on manual user step
|
# Allow ignored entries to be configured on manual user step
|
||||||
if entry.source == SOURCE_IGNORE and self.source == SOURCE_USER:
|
if entry.source == SOURCE_IGNORE and self.source == SOURCE_USER:
|
||||||
continue
|
return
|
||||||
if should_reload:
|
if should_reload:
|
||||||
self.hass.async_create_task(
|
self.hass.async_create_task(
|
||||||
self.hass.config_entries.async_reload(entry.entry_id),
|
self.hass.config_entries.async_reload(entry.entry_id),
|
||||||
|
@ -1639,11 +1722,9 @@ class ConfigFlow(data_entry_flow.FlowHandler):
|
||||||
):
|
):
|
||||||
self.hass.config_entries.flow.async_abort(progress["flow_id"])
|
self.hass.config_entries.flow.async_abort(progress["flow_id"])
|
||||||
|
|
||||||
for entry in self._async_current_entries(include_ignore=True):
|
return self.hass.config_entries.async_entry_for_domain_unique_id(
|
||||||
if entry.unique_id == unique_id:
|
self.handler, unique_id
|
||||||
return entry
|
)
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _set_confirm_only(
|
def _set_confirm_only(
|
||||||
|
|
|
@ -939,12 +939,10 @@ class MockConfigEntry(config_entries.ConfigEntry):
|
||||||
def add_to_hass(self, hass: HomeAssistant) -> None:
|
def add_to_hass(self, hass: HomeAssistant) -> None:
|
||||||
"""Test helper to add entry to hass."""
|
"""Test helper to add entry to hass."""
|
||||||
hass.config_entries._entries[self.entry_id] = self
|
hass.config_entries._entries[self.entry_id] = self
|
||||||
hass.config_entries._domain_index.setdefault(self.domain, []).append(self)
|
|
||||||
|
|
||||||
def add_to_manager(self, manager: config_entries.ConfigEntries) -> None:
|
def add_to_manager(self, manager: config_entries.ConfigEntries) -> None:
|
||||||
"""Test helper to add entry to entry manager."""
|
"""Test helper to add entry to entry manager."""
|
||||||
manager._entries[self.entry_id] = self
|
manager._entries[self.entry_id] = self
|
||||||
manager._domain_index.setdefault(self.domain, []).append(self)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_yaml_files(files_dict, endswith=True):
|
def patch_yaml_files(files_dict, endswith=True):
|
||||||
|
|
|
@ -3123,6 +3123,9 @@ async def test_updating_entry_with_and_without_changes(
|
||||||
state=config_entries.ConfigEntryState.SETUP_ERROR,
|
state=config_entries.ConfigEntryState.SETUP_ERROR,
|
||||||
)
|
)
|
||||||
entry.add_to_manager(manager)
|
entry.add_to_manager(manager)
|
||||||
|
assert "abc123" in str(entry)
|
||||||
|
|
||||||
|
assert manager.async_entry_for_domain_unique_id("test", "abc123") is entry
|
||||||
|
|
||||||
assert manager.async_update_entry(entry) is False
|
assert manager.async_update_entry(entry) is False
|
||||||
|
|
||||||
|
@ -3138,6 +3141,10 @@ async def test_updating_entry_with_and_without_changes(
|
||||||
assert manager.async_update_entry(entry, **change) is True
|
assert manager.async_update_entry(entry, **change) is True
|
||||||
assert manager.async_update_entry(entry, **change) is False
|
assert manager.async_update_entry(entry, **change) is False
|
||||||
|
|
||||||
|
assert manager.async_entry_for_domain_unique_id("test", "abc123") is None
|
||||||
|
assert manager.async_entry_for_domain_unique_id("test", "abcd1234") is entry
|
||||||
|
assert "abcd1234" in str(entry)
|
||||||
|
|
||||||
|
|
||||||
async def test_entry_reload_calls_on_unload_listeners(
|
async def test_entry_reload_calls_on_unload_listeners(
|
||||||
hass: HomeAssistant, manager: config_entries.ConfigEntries
|
hass: HomeAssistant, manager: config_entries.ConfigEntries
|
||||||
|
@ -4127,3 +4134,13 @@ async def test_preview_not_supported(
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["preview"] is None
|
assert result["preview"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_raise_trying_to_add_same_config_entry_twice(
|
||||||
|
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||||
|
) -> None:
|
||||||
|
"""Test we log an error if trying to add same config entry twice."""
|
||||||
|
entry = MockConfigEntry(domain="test")
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
assert f"An entry with the id {entry.entry_id} already exists" in caplog.text
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue