Fix ESPHome not fully removing entities when entity info changes (#108823)
This commit is contained in:
parent
7f56330e3b
commit
d588ec8202
4 changed files with 196 additions and 36 deletions
|
@ -37,6 +37,51 @@ _EntityT = TypeVar("_EntityT", bound="EsphomeEntity[Any,Any]")
|
|||
_StateT = TypeVar("_StateT", bound=EntityState)
|
||||
|
||||
|
||||
@callback
|
||||
def async_static_info_updated(
|
||||
hass: HomeAssistant,
|
||||
entry_data: RuntimeEntryData,
|
||||
platform: entity_platform.EntityPlatform,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
info_type: type[_InfoT],
|
||||
entity_type: type[_EntityT],
|
||||
state_type: type[_StateT],
|
||||
infos: list[EntityInfo],
|
||||
) -> None:
|
||||
"""Update entities of this platform when entities are listed."""
|
||||
current_infos = entry_data.info[info_type]
|
||||
new_infos: dict[int, EntityInfo] = {}
|
||||
add_entities: list[_EntityT] = []
|
||||
|
||||
for info in infos:
|
||||
if not current_infos.pop(info.key, None):
|
||||
# Create new entity
|
||||
entity = entity_type(entry_data, platform.domain, info, state_type)
|
||||
add_entities.append(entity)
|
||||
new_infos[info.key] = info
|
||||
|
||||
# Anything still in current_infos is now gone
|
||||
if current_infos:
|
||||
device_info = entry_data.device_info
|
||||
if TYPE_CHECKING:
|
||||
assert device_info is not None
|
||||
hass.async_create_task(
|
||||
entry_data.async_remove_entities(
|
||||
hass, current_infos.values(), device_info.mac_address
|
||||
)
|
||||
)
|
||||
|
||||
# Then update the actual info
|
||||
entry_data.info[info_type] = new_infos
|
||||
|
||||
if new_infos:
|
||||
entry_data.async_update_entity_infos(new_infos.values())
|
||||
|
||||
if add_entities:
|
||||
# Add entities to Home Assistant
|
||||
async_add_entities(add_entities)
|
||||
|
||||
|
||||
async def platform_async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
entry: ConfigEntry,
|
||||
|
@ -55,39 +100,21 @@ async def platform_async_setup_entry(
|
|||
entry_data.info[info_type] = {}
|
||||
entry_data.state.setdefault(state_type, {})
|
||||
platform = entity_platform.async_get_current_platform()
|
||||
|
||||
@callback
|
||||
def async_list_entities(infos: list[EntityInfo]) -> None:
|
||||
"""Update entities of this platform when entities are listed."""
|
||||
current_infos = entry_data.info[info_type]
|
||||
new_infos: dict[int, EntityInfo] = {}
|
||||
add_entities: list[_EntityT] = []
|
||||
|
||||
for info in infos:
|
||||
if not current_infos.pop(info.key, None):
|
||||
# Create new entity
|
||||
entity = entity_type(entry_data, platform.domain, info, state_type)
|
||||
add_entities.append(entity)
|
||||
new_infos[info.key] = info
|
||||
|
||||
# Anything still in current_infos is now gone
|
||||
if current_infos:
|
||||
hass.async_create_task(
|
||||
entry_data.async_remove_entities(current_infos.values())
|
||||
)
|
||||
|
||||
# Then update the actual info
|
||||
entry_data.info[info_type] = new_infos
|
||||
|
||||
if new_infos:
|
||||
entry_data.async_update_entity_infos(new_infos.values())
|
||||
|
||||
if add_entities:
|
||||
# Add entities to Home Assistant
|
||||
async_add_entities(add_entities)
|
||||
|
||||
on_static_info_update = functools.partial(
|
||||
async_static_info_updated,
|
||||
hass,
|
||||
entry_data,
|
||||
platform,
|
||||
async_add_entities,
|
||||
info_type,
|
||||
entity_type,
|
||||
state_type,
|
||||
)
|
||||
entry_data.cleanup_callbacks.append(
|
||||
entry_data.async_register_static_info_callback(info_type, async_list_entities)
|
||||
entry_data.async_register_static_info_callback(
|
||||
info_type,
|
||||
on_static_info_update,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -243,8 +243,18 @@ class RuntimeEntryData:
|
|||
"""Unsubscribe to assist pipeline updates."""
|
||||
self.assist_pipeline_update_callbacks.remove(update_callback)
|
||||
|
||||
async def async_remove_entities(self, static_infos: Iterable[EntityInfo]) -> None:
|
||||
async def async_remove_entities(
|
||||
self, hass: HomeAssistant, static_infos: Iterable[EntityInfo], mac: str
|
||||
) -> None:
|
||||
"""Schedule the removal of an entity."""
|
||||
# Remove from entity registry first so the entity is fully removed
|
||||
ent_reg = er.async_get(hass)
|
||||
for info in static_infos:
|
||||
if entry := ent_reg.async_get_entity_id(
|
||||
INFO_TYPE_TO_PLATFORM[type(info)], DOMAIN, build_unique_id(mac, info)
|
||||
):
|
||||
ent_reg.async_remove(entry)
|
||||
|
||||
callbacks: list[Coroutine[Any, Any, None]] = []
|
||||
for static_info in static_infos:
|
||||
callback_key = (type(static_info), static_info.key)
|
||||
|
|
|
@ -177,9 +177,10 @@ async def mock_dashboard(hass):
|
|||
class MockESPHomeDevice:
|
||||
"""Mock an esphome device."""
|
||||
|
||||
def __init__(self, entry: MockConfigEntry) -> None:
|
||||
def __init__(self, entry: MockConfigEntry, client: APIClient) -> None:
|
||||
"""Init the mock."""
|
||||
self.entry = entry
|
||||
self.client = client
|
||||
self.state_callback: Callable[[EntityState], None]
|
||||
self.service_call_callback: Callable[[HomeassistantServiceCall], None]
|
||||
self.on_disconnect: Callable[[bool], None]
|
||||
|
@ -258,7 +259,7 @@ async def _mock_generic_device_entry(
|
|||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
mock_device = MockESPHomeDevice(entry)
|
||||
mock_device = MockESPHomeDevice(entry, mock_client)
|
||||
|
||||
default_device_info = {
|
||||
"name": "test",
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Test ESPHome binary sensors."""
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from aioesphomeapi import (
|
||||
APIClient,
|
||||
|
@ -21,6 +22,7 @@ from homeassistant.const import (
|
|||
STATE_UNAVAILABLE,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
|
||||
from .conftest import MockESPHomeDevice
|
||||
|
||||
|
@ -34,7 +36,8 @@ async def test_entities_removed(
|
|||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test a generic binary_sensor where has_state is false."""
|
||||
"""Test entities are removed when static info changes."""
|
||||
ent_reg = er.async_get(hass)
|
||||
entity_info = [
|
||||
BinarySensorInfo(
|
||||
object_id="mybinary_sensor",
|
||||
|
@ -80,6 +83,8 @@ async def test_entities_removed(
|
|||
assert state.attributes[ATTR_RESTORED] is True
|
||||
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
|
||||
assert state is not None
|
||||
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
|
||||
assert reg_entry is not None
|
||||
assert state.attributes[ATTR_RESTORED] is True
|
||||
|
||||
entity_info = [
|
||||
|
@ -106,11 +111,128 @@ async def test_entities_removed(
|
|||
assert state.state == STATE_ON
|
||||
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
|
||||
assert state is None
|
||||
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
|
||||
assert reg_entry is None
|
||||
await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 1
|
||||
|
||||
|
||||
async def test_entities_removed_after_reload(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
hass_storage: dict[str, Any],
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test entities and their registry entry are removed when static info changes after a reload."""
|
||||
ent_reg = er.async_get(hass)
|
||||
entity_info = [
|
||||
BinarySensorInfo(
|
||||
object_id="mybinary_sensor",
|
||||
key=1,
|
||||
name="my binary_sensor",
|
||||
unique_id="my_binary_sensor",
|
||||
),
|
||||
BinarySensorInfo(
|
||||
object_id="mybinary_sensor_to_be_removed",
|
||||
key=2,
|
||||
name="my binary_sensor to be removed",
|
||||
unique_id="mybinary_sensor_to_be_removed",
|
||||
),
|
||||
]
|
||||
states = [
|
||||
BinarySensorState(key=1, state=True, missing_state=False),
|
||||
BinarySensorState(key=2, state=True, missing_state=False),
|
||||
]
|
||||
user_service = []
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=entity_info,
|
||||
user_service=user_service,
|
||||
states=states,
|
||||
)
|
||||
entry = mock_device.entry
|
||||
entry_id = entry.entry_id
|
||||
storage_key = f"esphome.{entry_id}"
|
||||
state = hass.states.get("binary_sensor.test_mybinary_sensor")
|
||||
assert state is not None
|
||||
assert state.state == STATE_ON
|
||||
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
|
||||
assert state is not None
|
||||
assert state.state == STATE_ON
|
||||
|
||||
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
|
||||
assert reg_entry is not None
|
||||
|
||||
assert await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 2
|
||||
|
||||
state = hass.states.get("binary_sensor.test_mybinary_sensor")
|
||||
assert state is not None
|
||||
assert state.attributes[ATTR_RESTORED] is True
|
||||
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
|
||||
assert state is not None
|
||||
assert state.attributes[ATTR_RESTORED] is True
|
||||
|
||||
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
|
||||
assert reg_entry is not None
|
||||
|
||||
assert await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 2
|
||||
|
||||
state = hass.states.get("binary_sensor.test_mybinary_sensor")
|
||||
assert state is not None
|
||||
assert ATTR_RESTORED not in state.attributes
|
||||
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
|
||||
assert state is not None
|
||||
assert ATTR_RESTORED not in state.attributes
|
||||
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
|
||||
assert reg_entry is not None
|
||||
|
||||
assert await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
entity_info = [
|
||||
BinarySensorInfo(
|
||||
object_id="mybinary_sensor",
|
||||
key=1,
|
||||
name="my binary_sensor",
|
||||
unique_id="my_binary_sensor",
|
||||
),
|
||||
]
|
||||
states = [
|
||||
BinarySensorState(key=1, state=True, missing_state=False),
|
||||
]
|
||||
mock_device.client.list_entities_services = AsyncMock(
|
||||
return_value=(entity_info, user_service)
|
||||
)
|
||||
|
||||
assert await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_device.entry.entry_id == entry_id
|
||||
state = hass.states.get("binary_sensor.test_mybinary_sensor")
|
||||
assert state is not None
|
||||
assert state.state == STATE_ON
|
||||
state = hass.states.get("binary_sensor.test_mybinary_sensor_to_be_removed")
|
||||
assert state is None
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
reg_entry = ent_reg.async_get("binary_sensor.test_mybinary_sensor_to_be_removed")
|
||||
assert reg_entry is None
|
||||
assert await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 1
|
||||
|
||||
|
||||
async def test_entity_info_object_ids(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
|
|
Loading…
Add table
Reference in a new issue