Reload config entry when entity enabled in entity registry, remove entity if disabled. (#26120)
* Reload config entry when disabled_by updated in entity registry * Add types * Remove entities that get disabled * Remove unnecessary domain checks. * Attach handler in async_setup * Remove unused var * Type * Fix test * Fix tests
This commit is contained in:
parent
05ed3c44ea
commit
65cf5a6ef5
7 changed files with 219 additions and 12 deletions
|
@ -3,13 +3,7 @@ import asyncio
|
|||
import logging
|
||||
import functools
|
||||
import uuid
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
List,
|
||||
Optional,
|
||||
Set, # noqa pylint: disable=unused-import
|
||||
)
|
||||
from typing import Any, Callable, List, Optional, Set
|
||||
import weakref
|
||||
|
||||
import attr
|
||||
|
@ -19,6 +13,7 @@ from homeassistant.core import callback, HomeAssistant
|
|||
from homeassistant.exceptions import HomeAssistantError, ConfigEntryNotReady
|
||||
from homeassistant.setup import async_setup_component, async_process_deps_reqs
|
||||
from homeassistant.util.decorator import Registry
|
||||
from homeassistant.helpers import entity_registry
|
||||
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
|
@ -161,8 +156,6 @@ class ConfigEntry:
|
|||
|
||||
try:
|
||||
component = integration.get_component()
|
||||
if self.domain == integration.domain:
|
||||
integration.get_platform("config_flow")
|
||||
except ImportError as err:
|
||||
_LOGGER.error(
|
||||
"Error importing integration %s to set up %s config entry: %s",
|
||||
|
@ -174,8 +167,20 @@ class ConfigEntry:
|
|||
self.state = ENTRY_STATE_SETUP_ERROR
|
||||
return
|
||||
|
||||
# Perform migration
|
||||
if integration.domain == self.domain:
|
||||
if self.domain == integration.domain:
|
||||
try:
|
||||
integration.get_platform("config_flow")
|
||||
except ImportError as err:
|
||||
_LOGGER.error(
|
||||
"Error importing platform config_flow from integration %s to set up %s config entry: %s",
|
||||
integration.domain,
|
||||
self.domain,
|
||||
err,
|
||||
)
|
||||
self.state = ENTRY_STATE_SETUP_ERROR
|
||||
return
|
||||
|
||||
# Perform migration
|
||||
if not await self.async_migrate(hass):
|
||||
self.state = ENTRY_STATE_MIGRATION_ERROR
|
||||
return
|
||||
|
@ -383,6 +388,7 @@ class ConfigEntries:
|
|||
self._hass_config = hass_config
|
||||
self._entries = [] # type: List[ConfigEntry]
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
EntityRegistryDisabledHandler(hass).async_setup()
|
||||
|
||||
@callback
|
||||
def async_domains(self) -> List[str]:
|
||||
|
@ -757,3 +763,91 @@ class SystemOptions:
|
|||
def as_dict(self):
|
||||
"""Return dictionary version of this config entrys system options."""
|
||||
return {"disable_new_entities": self.disable_new_entities}
|
||||
|
||||
|
||||
class EntityRegistryDisabledHandler:
|
||||
"""Handler to handle when entities related to config entries updating disabled_by."""
|
||||
|
||||
RELOAD_AFTER_UPDATE_DELAY = 30
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the handler."""
|
||||
self.hass = hass
|
||||
self.registry: Optional[entity_registry.EntityRegistry] = None
|
||||
self.changed: Set[str] = set()
|
||||
self._remove_call_later: Optional[Callable[[], None]] = None
|
||||
|
||||
@callback
|
||||
def async_setup(self) -> None:
|
||||
"""Set up the disable handler."""
|
||||
self.hass.bus.async_listen(
|
||||
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, self._handle_entry_updated
|
||||
)
|
||||
|
||||
async def _handle_entry_updated(self, event):
|
||||
"""Handle entity registry entry update."""
|
||||
if (
|
||||
event.data["action"] != "update"
|
||||
or "disabled_by" not in event.data["changes"]
|
||||
):
|
||||
return
|
||||
|
||||
if self.registry is None:
|
||||
self.registry = await entity_registry.async_get_registry(self.hass)
|
||||
|
||||
entity_entry = self.registry.async_get(event.data["entity_id"])
|
||||
|
||||
if (
|
||||
# Stop if no entry found
|
||||
entity_entry is None
|
||||
# Stop if entry not connected to config entry
|
||||
or entity_entry.config_entry_id is None
|
||||
# Stop if the entry got disabled. In that case the entity handles it
|
||||
# themselves.
|
||||
or entity_entry.disabled_by
|
||||
):
|
||||
return
|
||||
|
||||
config_entry = self.hass.config_entries.async_get_entry(
|
||||
entity_entry.config_entry_id
|
||||
)
|
||||
|
||||
if config_entry.entry_id not in self.changed and await support_entry_unload(
|
||||
self.hass, config_entry.domain
|
||||
):
|
||||
self.changed.add(config_entry.entry_id)
|
||||
|
||||
if not self.changed:
|
||||
return
|
||||
|
||||
# We are going to delay reloading on *every* entity registry change so that
|
||||
# if a user is happily clicking along, it will only reload at the end.
|
||||
|
||||
if self._remove_call_later:
|
||||
self._remove_call_later()
|
||||
|
||||
self._remove_call_later = self.hass.helpers.event.async_call_later(
|
||||
self.RELOAD_AFTER_UPDATE_DELAY, self._handle_reload
|
||||
)
|
||||
|
||||
async def _handle_reload(self, _now):
|
||||
"""Handle a reload."""
|
||||
self._remove_call_later = None
|
||||
to_reload = self.changed
|
||||
self.changed = set()
|
||||
|
||||
_LOGGER.info(
|
||||
"Reloading config entries because disabled_by changed in entity registry: %s",
|
||||
", ".join(self.changed),
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
*[self.hass.config_entries.async_reload(entry_id) for entry_id in to_reload]
|
||||
)
|
||||
|
||||
|
||||
async def support_entry_unload(hass: HomeAssistant, domain: str) -> bool:
|
||||
"""Test if a domain supports entry unloading."""
|
||||
integration = await loader.async_get_integration(hass, domain)
|
||||
component = integration.get_component()
|
||||
return hasattr(component, "async_unload_entry")
|
||||
|
|
|
@ -503,6 +503,10 @@ class Entity:
|
|||
old = self.registry_entry
|
||||
self.registry_entry = ent_reg.async_get(data["entity_id"])
|
||||
|
||||
if self.registry_entry.disabled_by is not None:
|
||||
await self.async_remove()
|
||||
return
|
||||
|
||||
if self.registry_entry.entity_id == old.entity_id:
|
||||
self.async_write_ha_state()
|
||||
return
|
||||
|
|
|
@ -302,7 +302,7 @@ class EntityRegistry:
|
|||
|
||||
self.async_schedule_save()
|
||||
|
||||
data = {"action": "update", "entity_id": entity_id}
|
||||
data = {"action": "update", "entity_id": entity_id, "changes": list(changes)}
|
||||
|
||||
if old.entity_id != entity_id:
|
||||
data["old_entity_id"] = old.entity_id
|
||||
|
|
|
@ -163,6 +163,7 @@ async def test_update_entity(hass, client):
|
|||
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert hass.states.get("test_domain.world") is None
|
||||
assert registry.entities["test_domain.world"].disabled_by == "user"
|
||||
|
||||
# UPDATE DISABLED_BY TO NONE
|
||||
|
|
|
@ -526,3 +526,34 @@ async def test_warn_disabled(hass, caplog):
|
|||
ent.async_write_ha_state()
|
||||
assert hass.states.get("hello.world") is None
|
||||
assert caplog.text == ""
|
||||
|
||||
|
||||
async def test_disabled_in_entity_registry(hass):
|
||||
"""Test entity is removed if we disable entity registry entry."""
|
||||
entry = entity_registry.RegistryEntry(
|
||||
entity_id="hello.world",
|
||||
unique_id="test-unique-id",
|
||||
platform="test-platform",
|
||||
disabled_by="user",
|
||||
)
|
||||
registry = mock_registry(hass, {"hello.world": entry})
|
||||
|
||||
ent = entity.Entity()
|
||||
ent.hass = hass
|
||||
ent.entity_id = "hello.world"
|
||||
ent.registry_entry = entry
|
||||
ent.platform = MagicMock(platform_name="test-platform")
|
||||
|
||||
await ent.async_internal_added_to_hass()
|
||||
ent.async_write_ha_state()
|
||||
assert hass.states.get("hello.world") is None
|
||||
|
||||
entry2 = registry.async_update_entity("hello.world", disabled_by=None)
|
||||
await hass.async_block_till_done()
|
||||
assert entry2 != entry
|
||||
assert ent.registry_entry == entry2
|
||||
|
||||
entry3 = registry.async_update_entity("hello.world", disabled_by="user")
|
||||
await hass.async_block_till_done()
|
||||
assert entry3 != entry2
|
||||
assert ent.registry_entry == entry3
|
||||
|
|
|
@ -219,6 +219,7 @@ async def test_updating_config_entry_id(hass, registry, update_events):
|
|||
assert update_events[0]["entity_id"] == entry.entity_id
|
||||
assert update_events[1]["action"] == "update"
|
||||
assert update_events[1]["entity_id"] == entry.entity_id
|
||||
assert update_events[1]["changes"] == ["config_entry_id"]
|
||||
|
||||
|
||||
async def test_removing_config_entry_id(hass, registry, update_events):
|
||||
|
|
|
@ -20,6 +20,7 @@ from tests.common import (
|
|||
MockEntity,
|
||||
mock_integration,
|
||||
mock_entity_platform,
|
||||
mock_registry,
|
||||
)
|
||||
|
||||
|
||||
|
@ -925,3 +926,78 @@ async def test_init_custom_integration(hass):
|
|||
return_value=mock_coro(integration),
|
||||
):
|
||||
await hass.config_entries.flow.async_init("bla")
|
||||
|
||||
|
||||
async def test_support_entry_unload(hass):
|
||||
"""Test unloading entry."""
|
||||
assert await config_entries.support_entry_unload(hass, "light")
|
||||
assert not await config_entries.support_entry_unload(hass, "auth")
|
||||
|
||||
|
||||
async def test_reload_entry_entity_registry_ignores_no_entry(hass):
|
||||
"""Test reloading entry in entity registry skips if no config entry linked."""
|
||||
handler = config_entries.EntityRegistryDisabledHandler(hass)
|
||||
registry = mock_registry(hass)
|
||||
|
||||
# Test we ignore entities without config entry
|
||||
entry = registry.async_get_or_create("light", "hue", "123")
|
||||
registry.async_update_entity(entry.entity_id, disabled_by="user")
|
||||
await hass.async_block_till_done()
|
||||
assert not handler.changed
|
||||
assert handler._remove_call_later is None
|
||||
|
||||
|
||||
async def test_reload_entry_entity_registry_works(hass):
|
||||
"""Test we schedule an entry to be reloaded if disabled_by is updated."""
|
||||
handler = config_entries.EntityRegistryDisabledHandler(hass)
|
||||
handler.async_setup()
|
||||
registry = mock_registry(hass)
|
||||
|
||||
config_entry = MockConfigEntry(
|
||||
domain="comp", state=config_entries.ENTRY_STATE_LOADED
|
||||
)
|
||||
config_entry.add_to_hass(hass)
|
||||
mock_setup_entry = MagicMock(return_value=mock_coro(True))
|
||||
mock_unload_entry = MagicMock(return_value=mock_coro(True))
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
"comp",
|
||||
async_setup_entry=mock_setup_entry,
|
||||
async_unload_entry=mock_unload_entry,
|
||||
),
|
||||
)
|
||||
mock_entity_platform(hass, "config_flow.comp", None)
|
||||
|
||||
# Only changing disabled_by should update trigger
|
||||
entity_entry = registry.async_get_or_create(
|
||||
"light", "hue", "123", config_entry=config_entry
|
||||
)
|
||||
registry.async_update_entity(entity_entry.entity_id, name="yo")
|
||||
await hass.async_block_till_done()
|
||||
assert not handler.changed
|
||||
assert handler._remove_call_later is None
|
||||
|
||||
# Disable entity, we should not do anything, only act when enabled.
|
||||
registry.async_update_entity(entity_entry.entity_id, disabled_by="user")
|
||||
await hass.async_block_till_done()
|
||||
assert not handler.changed
|
||||
assert handler._remove_call_later is None
|
||||
|
||||
# Enable entity, check we are reloading config entry.
|
||||
registry.async_update_entity(entity_entry.entity_id, disabled_by=None)
|
||||
await hass.async_block_till_done()
|
||||
assert handler.changed == {config_entry.entry_id}
|
||||
assert handler._remove_call_later is not None
|
||||
|
||||
async_fire_time_changed(
|
||||
hass,
|
||||
dt.utcnow()
|
||||
+ timedelta(
|
||||
seconds=config_entries.EntityRegistryDisabledHandler.RELOAD_AFTER_UPDATE_DELAY
|
||||
+ 1
|
||||
),
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(mock_unload_entry.mock_calls) == 1
|
||||
|
|
Loading…
Add table
Reference in a new issue