Fix ESPHome not fully removing entities when entity info changes (#108823)

This commit is contained in:
J. Nick Koston 2024-01-24 17:29:11 -10:00 committed by GitHub
parent 7f56330e3b
commit d588ec8202
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 196 additions and 36 deletions

View file

@ -37,6 +37,51 @@ _EntityT = TypeVar("_EntityT", bound="EsphomeEntity[Any,Any]")
_StateT = TypeVar("_StateT", bound=EntityState)
@callback
def async_static_info_updated(
hass: HomeAssistant,
entry_data: RuntimeEntryData,
platform: entity_platform.EntityPlatform,
async_add_entities: AddEntitiesCallback,
info_type: type[_InfoT],
entity_type: type[_EntityT],
state_type: type[_StateT],
infos: list[EntityInfo],
) -> None:
"""Update entities of this platform when entities are listed."""
current_infos = entry_data.info[info_type]
new_infos: dict[int, EntityInfo] = {}
add_entities: list[_EntityT] = []
for info in infos:
if not current_infos.pop(info.key, None):
# Create new entity
entity = entity_type(entry_data, platform.domain, info, state_type)
add_entities.append(entity)
new_infos[info.key] = info
# Anything still in current_infos is now gone
if current_infos:
device_info = entry_data.device_info
if TYPE_CHECKING:
assert device_info is not None
hass.async_create_task(
entry_data.async_remove_entities(
hass, current_infos.values(), device_info.mac_address
)
)
# Then update the actual info
entry_data.info[info_type] = new_infos
if new_infos:
entry_data.async_update_entity_infos(new_infos.values())
if add_entities:
# Add entities to Home Assistant
async_add_entities(add_entities)
async def platform_async_setup_entry(
hass: HomeAssistant,
entry: ConfigEntry,
@ -55,39 +100,21 @@ async def platform_async_setup_entry(
entry_data.info[info_type] = {}
entry_data.state.setdefault(state_type, {})
platform = entity_platform.async_get_current_platform()
@callback
def async_list_entities(infos: list[EntityInfo]) -> None:
"""Update entities of this platform when entities are listed."""
current_infos = entry_data.info[info_type]
new_infos: dict[int, EntityInfo] = {}
add_entities: list[_EntityT] = []
for info in infos:
if not current_infos.pop(info.key, None):
# Create new entity
entity = entity_type(entry_data, platform.domain, info, state_type)
add_entities.append(entity)
new_infos[info.key] = info
# Anything still in current_infos is now gone
if current_infos:
hass.async_create_task(
entry_data.async_remove_entities(current_infos.values())
)
# Then update the actual info
entry_data.info[info_type] = new_infos
if new_infos:
entry_data.async_update_entity_infos(new_infos.values())
if add_entities:
# Add entities to Home Assistant
async_add_entities(add_entities)
on_static_info_update = functools.partial(
async_static_info_updated,
hass,
entry_data,
platform,
async_add_entities,
info_type,
entity_type,
state_type,
)
entry_data.cleanup_callbacks.append(
entry_data.async_register_static_info_callback(info_type, async_list_entities)
entry_data.async_register_static_info_callback(
info_type,
on_static_info_update,
)
)

View file

@ -243,8 +243,18 @@ class RuntimeEntryData:
"""Unsubscribe to assist pipeline updates."""
self.assist_pipeline_update_callbacks.remove(update_callback)
async def async_remove_entities(self, static_infos: Iterable[EntityInfo]) -> None:
async def async_remove_entities(
self, hass: HomeAssistant, static_infos: Iterable[EntityInfo], mac: str
) -> None:
"""Schedule the removal of an entity."""
# Remove from entity registry first so the entity is fully removed
ent_reg = er.async_get(hass)
for info in static_infos:
if entry := ent_reg.async_get_entity_id(
INFO_TYPE_TO_PLATFORM[type(info)], DOMAIN, build_unique_id(mac, info)
):
ent_reg.async_remove(entry)
callbacks: list[Coroutine[Any, Any, None]] = []
for static_info in static_infos:
callback_key = (type(static_info), static_info.key)

View file

@ -177,9 +177,10 @@ async def mock_dashboard(hass):
class MockESPHomeDevice:
"""Mock an esphome device."""
def __init__(self, entry: MockConfigEntry) -> None:
def __init__(self, entry: MockConfigEntry, client: APIClient) -> None:
"""Init the mock."""
self.entry = entry
self.client = client
self.state_callback: Callable[[EntityState], None]
self.service_call_callback: Callable[[HomeassistantServiceCall], None]
self.on_disconnect: Callable[[bool], None]
@ -258,7 +259,7 @@ async def _mock_generic_device_entry(
)
entry.add_to_hass(hass)
mock_device = MockESPHomeDevice(entry)
mock_device = MockESPHomeDevice(entry, mock_client)
default_device_info = {
"name": "test",

View file

@ -1,6 +1,7 @@
"""Test ESPHome binary sensors."""
from collections.abc import Awaitable, Callable
from typing import Any
from unittest.mock import AsyncMock
from aioesphomeapi import (
APIClient,
@ -21,6 +22,7 @@ from homeassistant.const import (
STATE_UNAVAILABLE,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from .conftest import MockESPHomeDevice
@ -34,7 +36,8 @@ async def test_entities_removed(
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test a generic binary_sensor where has_state is false."""
"""Test entities are removed when static info changes."""
ent_reg = er.async_get(hass)
entity_info = [
BinarySensorInfo(
object_id="mybinary_sensor",
@ -80,6 +83,8 @@ async def test_entities_removed(
assert state.attributes[ATTR_RESTORED] is True
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert state is not None
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert reg_entry is not None
assert state.attributes[ATTR_RESTORED] is True
entity_info = [
@ -106,11 +111,128 @@ async def test_entities_removed(
assert state.state == STATE_ON
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert state is None
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert reg_entry is None
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 1
async def test_entities_removed_after_reload(
hass: HomeAssistant,
mock_client: APIClient,
hass_storage: dict[str, Any],
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test entities and their registry entry are removed when static info changes after a reload."""
ent_reg = er.async_get(hass)
entity_info = [
BinarySensorInfo(
object_id="mybinary_sensor",
key=1,
name="my binary_sensor",
unique_id="my_binary_sensor",
),
BinarySensorInfo(
object_id="mybinary_sensor_to_be_removed",
key=2,
name="my binary_sensor to be removed",
unique_id="mybinary_sensor_to_be_removed",
),
]
states = [
BinarySensorState(key=1, state=True, missing_state=False),
BinarySensorState(key=2, state=True, missing_state=False),
]
user_service = []
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=entity_info,
user_service=user_service,
states=states,
)
entry = mock_device.entry
entry_id = entry.entry_id
storage_key = f"esphome.{entry_id}"
state = hass.states.get("binary_sensor.test_mybinary_sensor")
assert state is not None
assert state.state == STATE_ON
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert state is not None
assert state.state == STATE_ON
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert reg_entry is not None
assert await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 2
state = hass.states.get("binary_sensor.test_mybinary_sensor")
assert state is not None
assert state.attributes[ATTR_RESTORED] is True
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert state is not None
assert state.attributes[ATTR_RESTORED] is True
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert reg_entry is not None
assert await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 2
state = hass.states.get("binary_sensor.test_mybinary_sensor")
assert state is not None
assert ATTR_RESTORED not in state.attributes
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert state is not None
assert ATTR_RESTORED not in state.attributes
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert reg_entry is not None
assert await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
entity_info = [
BinarySensorInfo(
object_id="mybinary_sensor",
key=1,
name="my binary_sensor",
unique_id="my_binary_sensor",
),
]
states = [
BinarySensorState(key=1, state=True, missing_state=False),
]
mock_device.client.list_entities_services = AsyncMock(
return_value=(entity_info, user_service)
)
assert await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert mock_device.entry.entry_id == entry_id
state = hass.states.get("binary_sensor.test_mybinary_sensor")
assert state is not None
assert state.state == STATE_ON
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert state is None
await hass.async_block_till_done()
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
assert reg_entry is None
assert await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 1
async def test_entity_info_object_ids(
hass: HomeAssistant,
mock_client: APIClient,