diff --git a/homeassistant/components/esphome/entity.py b/homeassistant/components/esphome/entity.py index 1abf60be18a..14602077a94 100644 --- a/homeassistant/components/esphome/entity.py +++ b/homeassistant/components/esphome/entity.py @@ -37,6 +37,51 @@ _EntityT = TypeVar("_EntityT", bound="EsphomeEntity[Any,Any]") _StateT = TypeVar("_StateT", bound=EntityState) +@callback +def async_static_info_updated( + hass: HomeAssistant, + entry_data: RuntimeEntryData, + platform: entity_platform.EntityPlatform, + async_add_entities: AddEntitiesCallback, + info_type: type[_InfoT], + entity_type: type[_EntityT], + state_type: type[_StateT], + infos: list[EntityInfo], +) -> None: + """Update entities of this platform when entities are listed.""" + current_infos = entry_data.info[info_type] + new_infos: dict[int, EntityInfo] = {} + add_entities: list[_EntityT] = [] + + for info in infos: + if not current_infos.pop(info.key, None): + # Create new entity + entity = entity_type(entry_data, platform.domain, info, state_type) + add_entities.append(entity) + new_infos[info.key] = info + + # Anything still in current_infos is now gone + if current_infos: + device_info = entry_data.device_info + if TYPE_CHECKING: + assert device_info is not None + hass.async_create_task( + entry_data.async_remove_entities( + hass, current_infos.values(), device_info.mac_address + ) + ) + + # Then update the actual info + entry_data.info[info_type] = new_infos + + if new_infos: + entry_data.async_update_entity_infos(new_infos.values()) + + if add_entities: + # Add entities to Home Assistant + async_add_entities(add_entities) + + async def platform_async_setup_entry( hass: HomeAssistant, entry: ConfigEntry, @@ -55,39 +100,21 @@ async def platform_async_setup_entry( entry_data.info[info_type] = {} entry_data.state.setdefault(state_type, {}) platform = entity_platform.async_get_current_platform() - - @callback - def async_list_entities(infos: list[EntityInfo]) -> None: - """Update entities of this platform when entities are listed.""" - current_infos = entry_data.info[info_type] - new_infos: dict[int, EntityInfo] = {} - add_entities: list[_EntityT] = [] - - for info in infos: - if not current_infos.pop(info.key, None): - # Create new entity - entity = entity_type(entry_data, platform.domain, info, state_type) - add_entities.append(entity) - new_infos[info.key] = info - - # Anything still in current_infos is now gone - if current_infos: - hass.async_create_task( - entry_data.async_remove_entities(current_infos.values()) - ) - - # Then update the actual info - entry_data.info[info_type] = new_infos - - if new_infos: - entry_data.async_update_entity_infos(new_infos.values()) - - if add_entities: - # Add entities to Home Assistant - async_add_entities(add_entities) - + on_static_info_update = functools.partial( + async_static_info_updated, + hass, + entry_data, + platform, + async_add_entities, + info_type, + entity_type, + state_type, + ) entry_data.cleanup_callbacks.append( - entry_data.async_register_static_info_callback(info_type, async_list_entities) + entry_data.async_register_static_info_callback( + info_type, + on_static_info_update, + ) ) diff --git a/homeassistant/components/esphome/entry_data.py b/homeassistant/components/esphome/entry_data.py index 723141a94a2..940b1560ba4 100644 --- a/homeassistant/components/esphome/entry_data.py +++ b/homeassistant/components/esphome/entry_data.py @@ -243,8 +243,18 @@ class RuntimeEntryData: """Unsubscribe to assist pipeline updates.""" self.assist_pipeline_update_callbacks.remove(update_callback) - async def async_remove_entities(self, static_infos: Iterable[EntityInfo]) -> None: + async def async_remove_entities( + self, hass: HomeAssistant, static_infos: Iterable[EntityInfo], mac: str + ) -> None: """Schedule the removal of an entity.""" + # Remove from entity registry first so the entity is fully removed + ent_reg = er.async_get(hass) + for info in static_infos: + if entry := ent_reg.async_get_entity_id( + INFO_TYPE_TO_PLATFORM[type(info)], DOMAIN, build_unique_id(mac, info) + ): + ent_reg.async_remove(entry) + callbacks: list[Coroutine[Any, Any, None]] = [] for static_info in static_infos: callback_key = (type(static_info), static_info.key) diff --git a/tests/components/esphome/conftest.py b/tests/components/esphome/conftest.py index 8c46fac08d4..ac9d9235917 100644 --- a/tests/components/esphome/conftest.py +++ b/tests/components/esphome/conftest.py @@ -177,9 +177,10 @@ async def mock_dashboard(hass): class MockESPHomeDevice: """Mock an esphome device.""" - def __init__(self, entry: MockConfigEntry) -> None: + def __init__(self, entry: MockConfigEntry, client: APIClient) -> None: """Init the mock.""" self.entry = entry + self.client = client self.state_callback: Callable[[EntityState], None] self.service_call_callback: Callable[[HomeassistantServiceCall], None] self.on_disconnect: Callable[[bool], None] @@ -258,7 +259,7 @@ async def _mock_generic_device_entry( ) entry.add_to_hass(hass) - mock_device = MockESPHomeDevice(entry) + mock_device = MockESPHomeDevice(entry, mock_client) default_device_info = { "name": "test", diff --git a/tests/components/esphome/test_entity.py b/tests/components/esphome/test_entity.py index 9a5cb441f28..03fd21c32f8 100644 --- a/tests/components/esphome/test_entity.py +++ b/tests/components/esphome/test_entity.py @@ -1,6 +1,7 @@ """Test ESPHome binary sensors.""" from collections.abc import Awaitable, Callable from typing import Any +from unittest.mock import AsyncMock from aioesphomeapi import ( APIClient, @@ -21,6 +22,7 @@ from homeassistant.const import ( STATE_UNAVAILABLE, ) from homeassistant.core import HomeAssistant +from homeassistant.helpers import entity_registry as er from .conftest import MockESPHomeDevice @@ -34,7 +36,8 @@ async def test_entities_removed( Awaitable[MockESPHomeDevice], ], ) -> None: - """Test a generic binary_sensor where has_state is false.""" + """Test entities are removed when static info changes.""" + ent_reg = er.async_get(hass) entity_info = [ BinarySensorInfo( object_id="mybinary_sensor", @@ -80,6 +83,8 @@ async def test_entities_removed( assert state.attributes[ATTR_RESTORED] is True state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed") assert state is not None + reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed") + assert reg_entry is not None assert state.attributes[ATTR_RESTORED] is True entity_info = [ @@ -106,11 +111,128 @@ async def test_entities_removed( assert state.state == STATE_ON state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed") assert state is None + reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed") + assert reg_entry is None await hass.config_entries.async_unload(entry.entry_id) await hass.async_block_till_done() assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 1 +async def test_entities_removed_after_reload( + hass: HomeAssistant, + mock_client: APIClient, + hass_storage: dict[str, Any], + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], +) -> None: + """Test entities and their registry entry are removed when static info changes after a reload.""" + ent_reg = er.async_get(hass) + entity_info = [ + BinarySensorInfo( + object_id="mybinary_sensor", + key=1, + name="my binary_sensor", + unique_id="my_binary_sensor", + ), + BinarySensorInfo( + object_id="mybinary_sensor_to_be_removed", + key=2, + name="my binary_sensor to be removed", + unique_id="mybinary_sensor_to_be_removed", + ), + ] + states = [ + BinarySensorState(key=1, state=True, missing_state=False), + BinarySensorState(key=2, state=True, missing_state=False), + ] + user_service = [] + mock_device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=entity_info, + user_service=user_service, + states=states, + ) + entry = mock_device.entry + entry_id = entry.entry_id + storage_key = f"esphome.{entry_id}" + state = hass.states.get("binary_sensor.test_mybinary_sensor") + assert state is not None + assert state.state == STATE_ON + state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed") + assert state is not None + assert state.state == STATE_ON + + reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed") + assert reg_entry is not None + + assert await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + + assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 2 + + state = hass.states.get("binary_sensor.test_mybinary_sensor") + assert state is not None + assert state.attributes[ATTR_RESTORED] is True + state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed") + assert state is not None + assert state.attributes[ATTR_RESTORED] is True + + reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed") + assert reg_entry is not None + + assert await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + + assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 2 + + state = hass.states.get("binary_sensor.test_mybinary_sensor") + assert state is not None + assert ATTR_RESTORED not in state.attributes + state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed") + assert state is not None + assert ATTR_RESTORED not in state.attributes + reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed") + assert reg_entry is not None + + assert await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + + entity_info = [ + BinarySensorInfo( + object_id="mybinary_sensor", + key=1, + name="my binary_sensor", + unique_id="my_binary_sensor", + ), + ] + states = [ + BinarySensorState(key=1, state=True, missing_state=False), + ] + mock_device.client.list_entities_services = AsyncMock( + return_value=(entity_info, user_service) + ) + + assert await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + + assert mock_device.entry.entry_id == entry_id + state = hass.states.get("binary_sensor.test_mybinary_sensor") + assert state is not None + assert state.state == STATE_ON + state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed") + assert state is None + + await hass.async_block_till_done() + + reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed") + assert reg_entry is None + assert await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 1 + + async def test_entity_info_object_ids( hass: HomeAssistant, mock_client: APIClient,