Index entity_registry_updated listeners (#37940)
This commit is contained in:
parent
9ae08585dc
commit
910b6c9c2c
4 changed files with 189 additions and 19 deletions
|
@ -27,11 +27,8 @@ from homeassistant.const import (
|
|||
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback
|
||||
from homeassistant.exceptions import NoEntitySpecifiedError
|
||||
from homeassistant.helpers.entity_platform import EntityPlatform
|
||||
from homeassistant.helpers.entity_registry import (
|
||||
EVENT_ENTITY_REGISTRY_UPDATED,
|
||||
RegistryEntry,
|
||||
)
|
||||
from homeassistant.helpers.event import Event
|
||||
from homeassistant.helpers.entity_registry import RegistryEntry
|
||||
from homeassistant.helpers.event import Event, async_track_entity_registry_updated_event
|
||||
from homeassistant.util import dt as dt_util, ensure_unique_string, slugify
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
|
||||
|
@ -518,8 +515,8 @@ class Entity(ABC):
|
|||
if self.registry_entry is not None:
|
||||
assert self.hass is not None
|
||||
self.async_on_remove(
|
||||
self.hass.bus.async_listen(
|
||||
EVENT_ENTITY_REGISTRY_UPDATED, self._async_registry_updated
|
||||
async_track_entity_registry_updated_event(
|
||||
self.hass, self.entity_id, self._async_registry_updated
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -532,14 +529,11 @@ class Entity(ABC):
|
|||
async def _async_registry_updated(self, event: Event) -> None:
|
||||
"""Handle entity registry update."""
|
||||
data = event.data
|
||||
if data["action"] == "remove" and data["entity_id"] == self.entity_id:
|
||||
if data["action"] == "remove":
|
||||
await self.async_removed_from_registry()
|
||||
await self.async_remove()
|
||||
|
||||
if (
|
||||
data["action"] != "update"
|
||||
or data.get("old_entity_id", data["entity_id"]) != self.entity_id
|
||||
):
|
||||
if data["action"] != "update":
|
||||
return
|
||||
|
||||
assert self.hass is not None
|
||||
|
|
|
@ -17,6 +17,7 @@ from homeassistant.const import (
|
|||
SUN_EVENT_SUNSET,
|
||||
)
|
||||
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, State, callback
|
||||
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
|
||||
from homeassistant.helpers.sun import get_astral_event_next
|
||||
from homeassistant.helpers.template import Template
|
||||
from homeassistant.loader import bind_hass
|
||||
|
@ -26,6 +27,9 @@ from homeassistant.util.async_ import run_callback_threadsafe
|
|||
TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks"
|
||||
TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener"
|
||||
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks"
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# PyLint does not like the use of threaded_listener_factory
|
||||
|
@ -137,7 +141,7 @@ track_state_change = threaded_listener_factory(async_track_state_change)
|
|||
def async_track_state_change_event(
|
||||
hass: HomeAssistant,
|
||||
entity_ids: Union[str, Iterable[str]],
|
||||
action: Callable[[Event], None],
|
||||
action: Callable[[Event], Any],
|
||||
) -> Callable[[], None]:
|
||||
"""Track specific state change events indexed by entity_id.
|
||||
|
||||
|
@ -186,17 +190,28 @@ def async_track_state_change_event(
|
|||
@callback
|
||||
def remove_listener() -> None:
|
||||
"""Remove state change listener."""
|
||||
_async_remove_state_change_listeners(hass, entity_ids, action)
|
||||
_async_remove_entity_listeners(
|
||||
hass,
|
||||
TRACK_STATE_CHANGE_CALLBACKS,
|
||||
TRACK_STATE_CHANGE_LISTENER,
|
||||
entity_ids,
|
||||
action,
|
||||
)
|
||||
|
||||
return remove_listener
|
||||
|
||||
|
||||
@callback
|
||||
def _async_remove_state_change_listeners(
|
||||
hass: HomeAssistant, entity_ids: Iterable[str], action: Callable[[Event], None]
|
||||
def _async_remove_entity_listeners(
|
||||
hass: HomeAssistant,
|
||||
storage_key: str,
|
||||
listener_key: str,
|
||||
entity_ids: Iterable[str],
|
||||
action: Callable[[Event], Any],
|
||||
) -> None:
|
||||
"""Remove a listener."""
|
||||
entity_callbacks = hass.data[TRACK_STATE_CHANGE_CALLBACKS]
|
||||
|
||||
entity_callbacks = hass.data[storage_key]
|
||||
|
||||
for entity_id in entity_ids:
|
||||
entity_callbacks[entity_id].remove(action)
|
||||
|
@ -204,8 +219,66 @@ def _async_remove_state_change_listeners(
|
|||
del entity_callbacks[entity_id]
|
||||
|
||||
if not entity_callbacks:
|
||||
hass.data[TRACK_STATE_CHANGE_LISTENER]()
|
||||
del hass.data[TRACK_STATE_CHANGE_LISTENER]
|
||||
hass.data[listener_key]()
|
||||
del hass.data[listener_key]
|
||||
|
||||
|
||||
@bind_hass
|
||||
def async_track_entity_registry_updated_event(
|
||||
hass: HomeAssistant,
|
||||
entity_ids: Union[str, Iterable[str]],
|
||||
action: Callable[[Event], Any],
|
||||
) -> Callable[[], None]:
|
||||
"""Track specific entity registry updated events indexed by entity_id.
|
||||
|
||||
Similar to async_track_state_change_event.
|
||||
"""
|
||||
|
||||
entity_callbacks = hass.data.setdefault(TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, {})
|
||||
|
||||
if TRACK_ENTITY_REGISTRY_UPDATED_LISTENER not in hass.data:
|
||||
|
||||
@callback
|
||||
def _async_entity_registry_updated_dispatcher(event: Event) -> None:
|
||||
"""Dispatch entity registry updates by entity_id."""
|
||||
entity_id = event.data.get("old_entity_id", event.data["entity_id"])
|
||||
|
||||
if entity_id not in entity_callbacks:
|
||||
return
|
||||
|
||||
for action in entity_callbacks[entity_id][:]:
|
||||
try:
|
||||
hass.async_run_job(action, event)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception(
|
||||
"Error while processing entity registry update for %s",
|
||||
entity_id,
|
||||
)
|
||||
|
||||
hass.data[TRACK_ENTITY_REGISTRY_UPDATED_LISTENER] = hass.bus.async_listen(
|
||||
EVENT_ENTITY_REGISTRY_UPDATED, _async_entity_registry_updated_dispatcher
|
||||
)
|
||||
|
||||
if isinstance(entity_ids, str):
|
||||
entity_ids = [entity_ids]
|
||||
|
||||
entity_ids = [entity_id.lower() for entity_id in entity_ids]
|
||||
|
||||
for entity_id in entity_ids:
|
||||
entity_callbacks.setdefault(entity_id, []).append(action)
|
||||
|
||||
@callback
|
||||
def remove_listener() -> None:
|
||||
"""Remove state change listener."""
|
||||
_async_remove_entity_listeners(
|
||||
hass,
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS,
|
||||
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER,
|
||||
entity_ids,
|
||||
action,
|
||||
)
|
||||
|
||||
return remove_listener
|
||||
|
||||
|
||||
@callback
|
||||
|
|
|
@ -338,6 +338,7 @@ async def test_cleanup_device(hass, device_reg, entity_reg, mqtt_mock):
|
|||
# Verify state is removed
|
||||
state = hass.states.get("sensor.mqtt_sensor")
|
||||
assert state is None
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Verify retained discovery topic has been cleared
|
||||
mqtt_mock.async_publish.assert_called_once_with(
|
||||
|
|
|
@ -10,6 +10,7 @@ from homeassistant.components import sun
|
|||
from homeassistant.const import MATCH_ALL
|
||||
import homeassistant.core as ha
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
|
||||
from homeassistant.helpers.event import (
|
||||
async_call_later,
|
||||
async_track_point_in_time,
|
||||
|
@ -1180,3 +1181,104 @@ async def test_async_track_point_in_time_cancel(hass):
|
|||
|
||||
assert len(times) == 1
|
||||
assert times[0].tzinfo.zone == "US/Hawaii"
|
||||
|
||||
|
||||
async def test_async_track_entity_registry_updated_event(hass):
|
||||
"""Test tracking entity registry updates for an entity_id."""
|
||||
|
||||
entity_id = "switch.puppy_feeder"
|
||||
new_entity_id = "switch.dog_feeder"
|
||||
untracked_entity_id = "switch.kitty_feeder"
|
||||
|
||||
hass.states.async_set(entity_id, "on")
|
||||
await hass.async_block_till_done()
|
||||
event_data = []
|
||||
|
||||
@ha.callback
|
||||
def run_callback(event):
|
||||
event_data.append(event.data)
|
||||
|
||||
unsub1 = hass.helpers.event.async_track_entity_registry_updated_event(
|
||||
entity_id, run_callback
|
||||
)
|
||||
unsub2 = hass.helpers.event.async_track_entity_registry_updated_event(
|
||||
new_entity_id, run_callback
|
||||
)
|
||||
hass.bus.async_fire(
|
||||
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id}
|
||||
)
|
||||
hass.bus.async_fire(
|
||||
EVENT_ENTITY_REGISTRY_UPDATED,
|
||||
{"action": "create", "entity_id": untracked_entity_id},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
hass.bus.async_fire(
|
||||
EVENT_ENTITY_REGISTRY_UPDATED,
|
||||
{
|
||||
"action": "update",
|
||||
"entity_id": new_entity_id,
|
||||
"old_entity_id": entity_id,
|
||||
"changes": {},
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
hass.bus.async_fire(
|
||||
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": new_entity_id}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
unsub1()
|
||||
unsub2()
|
||||
hass.bus.async_fire(
|
||||
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id}
|
||||
)
|
||||
hass.bus.async_fire(
|
||||
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": new_entity_id}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert event_data[0] == {"action": "create", "entity_id": "switch.puppy_feeder"}
|
||||
assert event_data[1] == {
|
||||
"action": "update",
|
||||
"changes": {},
|
||||
"entity_id": "switch.dog_feeder",
|
||||
"old_entity_id": "switch.puppy_feeder",
|
||||
}
|
||||
assert event_data[2] == {"action": "remove", "entity_id": "switch.dog_feeder"}
|
||||
|
||||
|
||||
async def test_async_track_entity_registry_updated_event_with_a_callback_that_throws(
|
||||
hass,
|
||||
):
|
||||
"""Test tracking entity registry updates for an entity_id when one callback throws."""
|
||||
|
||||
entity_id = "switch.puppy_feeder"
|
||||
|
||||
hass.states.async_set(entity_id, "on")
|
||||
await hass.async_block_till_done()
|
||||
event_data = []
|
||||
|
||||
@ha.callback
|
||||
def run_callback(event):
|
||||
event_data.append(event.data)
|
||||
|
||||
@ha.callback
|
||||
def failing_callback(event):
|
||||
raise ValueError
|
||||
|
||||
unsub1 = hass.helpers.event.async_track_entity_registry_updated_event(
|
||||
entity_id, failing_callback
|
||||
)
|
||||
unsub2 = hass.helpers.event.async_track_entity_registry_updated_event(
|
||||
entity_id, run_callback
|
||||
)
|
||||
hass.bus.async_fire(
|
||||
EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
unsub1()
|
||||
unsub2()
|
||||
|
||||
assert event_data[0] == {"action": "create", "entity_id": "switch.puppy_feeder"}
|
||||
|
|
Loading…
Add table
Reference in a new issue