diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 8414bb912c2..7c19d540704 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -27,11 +27,8 @@ from homeassistant.const import ( from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback from homeassistant.exceptions import NoEntitySpecifiedError from homeassistant.helpers.entity_platform import EntityPlatform -from homeassistant.helpers.entity_registry import ( - EVENT_ENTITY_REGISTRY_UPDATED, - RegistryEntry, -) -from homeassistant.helpers.event import Event +from homeassistant.helpers.entity_registry import RegistryEntry +from homeassistant.helpers.event import Event, async_track_entity_registry_updated_event from homeassistant.util import dt as dt_util, ensure_unique_string, slugify from homeassistant.util.async_ import run_callback_threadsafe @@ -518,8 +515,8 @@ class Entity(ABC): if self.registry_entry is not None: assert self.hass is not None self.async_on_remove( - self.hass.bus.async_listen( - EVENT_ENTITY_REGISTRY_UPDATED, self._async_registry_updated + async_track_entity_registry_updated_event( + self.hass, self.entity_id, self._async_registry_updated ) ) @@ -532,14 +529,11 @@ class Entity(ABC): async def _async_registry_updated(self, event: Event) -> None: """Handle entity registry update.""" data = event.data - if data["action"] == "remove" and data["entity_id"] == self.entity_id: + if data["action"] == "remove": await self.async_removed_from_registry() await self.async_remove() - if ( - data["action"] != "update" - or data.get("old_entity_id", data["entity_id"]) != self.entity_id - ): + if data["action"] != "update": return assert self.hass is not None diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 24110e8e63c..ecbf88d67a9 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -17,6 +17,7 @@ from homeassistant.const import ( SUN_EVENT_SUNSET, ) from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, State, callback +from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED from homeassistant.helpers.sun import get_astral_event_next from homeassistant.helpers.template import Template from homeassistant.loader import bind_hass @@ -26,6 +27,9 @@ from homeassistant.util.async_ import run_callback_threadsafe TRACK_STATE_CHANGE_CALLBACKS = "track_state_change_callbacks" TRACK_STATE_CHANGE_LISTENER = "track_state_change_listener" +TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS = "track_entity_registry_updated_callbacks" +TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener" + _LOGGER = logging.getLogger(__name__) # PyLint does not like the use of threaded_listener_factory @@ -137,7 +141,7 @@ track_state_change = threaded_listener_factory(async_track_state_change) def async_track_state_change_event( hass: HomeAssistant, entity_ids: Union[str, Iterable[str]], - action: Callable[[Event], None], + action: Callable[[Event], Any], ) -> Callable[[], None]: """Track specific state change events indexed by entity_id. @@ -186,17 +190,28 @@ def async_track_state_change_event( @callback def remove_listener() -> None: """Remove state change listener.""" - _async_remove_state_change_listeners(hass, entity_ids, action) + _async_remove_entity_listeners( + hass, + TRACK_STATE_CHANGE_CALLBACKS, + TRACK_STATE_CHANGE_LISTENER, + entity_ids, + action, + ) return remove_listener @callback -def _async_remove_state_change_listeners( - hass: HomeAssistant, entity_ids: Iterable[str], action: Callable[[Event], None] +def _async_remove_entity_listeners( + hass: HomeAssistant, + storage_key: str, + listener_key: str, + entity_ids: Iterable[str], + action: Callable[[Event], Any], ) -> None: """Remove a listener.""" - entity_callbacks = hass.data[TRACK_STATE_CHANGE_CALLBACKS] + + entity_callbacks = hass.data[storage_key] for entity_id in entity_ids: entity_callbacks[entity_id].remove(action) @@ -204,8 +219,66 @@ def _async_remove_state_change_listeners( del entity_callbacks[entity_id] if not entity_callbacks: - hass.data[TRACK_STATE_CHANGE_LISTENER]() - del hass.data[TRACK_STATE_CHANGE_LISTENER] + hass.data[listener_key]() + del hass.data[listener_key] + + +@bind_hass +def async_track_entity_registry_updated_event( + hass: HomeAssistant, + entity_ids: Union[str, Iterable[str]], + action: Callable[[Event], Any], +) -> Callable[[], None]: + """Track specific entity registry updated events indexed by entity_id. + + Similar to async_track_state_change_event. + """ + + entity_callbacks = hass.data.setdefault(TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, {}) + + if TRACK_ENTITY_REGISTRY_UPDATED_LISTENER not in hass.data: + + @callback + def _async_entity_registry_updated_dispatcher(event: Event) -> None: + """Dispatch entity registry updates by entity_id.""" + entity_id = event.data.get("old_entity_id", event.data["entity_id"]) + + if entity_id not in entity_callbacks: + return + + for action in entity_callbacks[entity_id][:]: + try: + hass.async_run_job(action, event) + except Exception: # pylint: disable=broad-except + _LOGGER.exception( + "Error while processing entity registry update for %s", + entity_id, + ) + + hass.data[TRACK_ENTITY_REGISTRY_UPDATED_LISTENER] = hass.bus.async_listen( + EVENT_ENTITY_REGISTRY_UPDATED, _async_entity_registry_updated_dispatcher + ) + + if isinstance(entity_ids, str): + entity_ids = [entity_ids] + + entity_ids = [entity_id.lower() for entity_id in entity_ids] + + for entity_id in entity_ids: + entity_callbacks.setdefault(entity_id, []).append(action) + + @callback + def remove_listener() -> None: + """Remove state change listener.""" + _async_remove_entity_listeners( + hass, + TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, + TRACK_ENTITY_REGISTRY_UPDATED_LISTENER, + entity_ids, + action, + ) + + return remove_listener @callback diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index 6c317e17989..c1388aeb1c1 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -338,6 +338,7 @@ async def test_cleanup_device(hass, device_reg, entity_reg, mqtt_mock): # Verify state is removed state = hass.states.get("sensor.mqtt_sensor") assert state is None + await hass.async_block_till_done() # Verify retained discovery topic has been cleared mqtt_mock.async_publish.assert_called_once_with( diff --git a/tests/helpers/test_event.py b/tests/helpers/test_event.py index 674dca474cd..99b4cad6eca 100644 --- a/tests/helpers/test_event.py +++ b/tests/helpers/test_event.py @@ -10,6 +10,7 @@ from homeassistant.components import sun from homeassistant.const import MATCH_ALL import homeassistant.core as ha from homeassistant.core import callback +from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED from homeassistant.helpers.event import ( async_call_later, async_track_point_in_time, @@ -1180,3 +1181,104 @@ async def test_async_track_point_in_time_cancel(hass): assert len(times) == 1 assert times[0].tzinfo.zone == "US/Hawaii" + + +async def test_async_track_entity_registry_updated_event(hass): + """Test tracking entity registry updates for an entity_id.""" + + entity_id = "switch.puppy_feeder" + new_entity_id = "switch.dog_feeder" + untracked_entity_id = "switch.kitty_feeder" + + hass.states.async_set(entity_id, "on") + await hass.async_block_till_done() + event_data = [] + + @ha.callback + def run_callback(event): + event_data.append(event.data) + + unsub1 = hass.helpers.event.async_track_entity_registry_updated_event( + entity_id, run_callback + ) + unsub2 = hass.helpers.event.async_track_entity_registry_updated_event( + new_entity_id, run_callback + ) + hass.bus.async_fire( + EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id} + ) + hass.bus.async_fire( + EVENT_ENTITY_REGISTRY_UPDATED, + {"action": "create", "entity_id": untracked_entity_id}, + ) + await hass.async_block_till_done() + + hass.bus.async_fire( + EVENT_ENTITY_REGISTRY_UPDATED, + { + "action": "update", + "entity_id": new_entity_id, + "old_entity_id": entity_id, + "changes": {}, + }, + ) + await hass.async_block_till_done() + + hass.bus.async_fire( + EVENT_ENTITY_REGISTRY_UPDATED, {"action": "remove", "entity_id": new_entity_id} + ) + await hass.async_block_till_done() + + unsub1() + unsub2() + hass.bus.async_fire( + EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id} + ) + hass.bus.async_fire( + EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": new_entity_id} + ) + await hass.async_block_till_done() + + assert event_data[0] == {"action": "create", "entity_id": "switch.puppy_feeder"} + assert event_data[1] == { + "action": "update", + "changes": {}, + "entity_id": "switch.dog_feeder", + "old_entity_id": "switch.puppy_feeder", + } + assert event_data[2] == {"action": "remove", "entity_id": "switch.dog_feeder"} + + +async def test_async_track_entity_registry_updated_event_with_a_callback_that_throws( + hass, +): + """Test tracking entity registry updates for an entity_id when one callback throws.""" + + entity_id = "switch.puppy_feeder" + + hass.states.async_set(entity_id, "on") + await hass.async_block_till_done() + event_data = [] + + @ha.callback + def run_callback(event): + event_data.append(event.data) + + @ha.callback + def failing_callback(event): + raise ValueError + + unsub1 = hass.helpers.event.async_track_entity_registry_updated_event( + entity_id, failing_callback + ) + unsub2 = hass.helpers.event.async_track_entity_registry_updated_event( + entity_id, run_callback + ) + hass.bus.async_fire( + EVENT_ENTITY_REGISTRY_UPDATED, {"action": "create", "entity_id": entity_id} + ) + await hass.async_block_till_done() + unsub1() + unsub2() + + assert event_data[0] == {"action": "create", "entity_id": "switch.puppy_feeder"}