Fix esphome not removing entities when static info changes (#95202)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
J. Nick Koston 2023-06-25 21:31:31 -05:00 committed by GitHub
parent d700415045
commit 3b7095c63b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 244 additions and 104 deletions

View file

@ -675,6 +675,7 @@ async def _cleanup_instance(
data.disconnect_callbacks = [] data.disconnect_callbacks = []
for cleanup_callback in data.cleanup_callbacks: for cleanup_callback in data.cleanup_callbacks:
cleanup_callback() cleanup_callback()
await data.async_cleanup()
await data.client.disconnect() await data.client.disconnect()
return data return data

View file

@ -73,7 +73,6 @@ async def async_setup_entry(
hass, hass,
entry, entry,
async_add_entities, async_add_entities,
component_key="alarm_control_panel",
info_type=AlarmControlPanelInfo, info_type=AlarmControlPanelInfo,
entity_type=EsphomeAlarmControlPanel, entity_type=EsphomeAlarmControlPanel,
state_type=AlarmControlPanelEntityState, state_type=AlarmControlPanelEntityState,

View file

@ -29,7 +29,6 @@ async def async_setup_entry(
hass, hass,
entry, entry,
async_add_entities, async_add_entities,
component_key="binary_sensor",
info_type=BinarySensorInfo, info_type=BinarySensorInfo,
entity_type=EsphomeBinarySensor, entity_type=EsphomeBinarySensor,
state_type=BinarySensorState, state_type=BinarySensorState,

View file

@ -23,7 +23,6 @@ async def async_setup_entry(
hass, hass,
entry, entry,
async_add_entities, async_add_entities,
component_key="button",
info_type=ButtonInfo, info_type=ButtonInfo,
entity_type=EsphomeButton, entity_type=EsphomeButton,
state_type=EntityState, state_type=EntityState,

View file

@ -27,7 +27,6 @@ async def async_setup_entry(
hass, hass,
entry, entry,
async_add_entities, async_add_entities,
component_key="camera",
info_type=CameraInfo, info_type=CameraInfo,
entity_type=EsphomeCamera, entity_type=EsphomeCamera,
state_type=CameraState, state_type=CameraState,

View file

@ -72,7 +72,6 @@ async def async_setup_entry(
hass, hass,
entry, entry,
async_add_entities, async_add_entities,
component_key="climate",
info_type=ClimateInfo, info_type=ClimateInfo,
entity_type=EsphomeClimateEntity, entity_type=EsphomeClimateEntity,
state_type=ClimateState, state_type=ClimateState,

View file

@ -32,7 +32,6 @@ async def async_setup_entry(
hass, hass,
entry, entry,
async_add_entities, async_add_entities,
component_key="cover",
info_type=CoverInfo, info_type=CoverInfo,
entity_type=EsphomeCover, entity_type=EsphomeCover,
state_type=CoverState, state_type=CoverState,

View file

@ -1,7 +1,7 @@
"""Diagnostics support for ESPHome.""" """Diagnostics support for ESPHome."""
from __future__ import annotations from __future__ import annotations
from typing import Any, cast from typing import Any
from homeassistant.components.bluetooth import async_scanner_by_source from homeassistant.components.bluetooth import async_scanner_by_source
from homeassistant.components.diagnostics import async_redact_data from homeassistant.components.diagnostics import async_redact_data
@ -28,7 +28,6 @@ async def async_get_config_entry_diagnostics(
entry_data = DomainData.get(hass).get_entry_data(config_entry) entry_data = DomainData.get(hass).get_entry_data(config_entry)
if (storage_data := await entry_data.store.async_load()) is not None: if (storage_data := await entry_data.store.async_load()) is not None:
storage_data = cast("dict[str, Any]", storage_data)
diag["storage_data"] = storage_data diag["storage_data"] = storage_data
if config_entry.unique_id and ( if config_entry.unique_id and (

View file

@ -12,10 +12,9 @@ from typing_extensions import Self
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import JSONEncoder from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.storage import Store
from .const import DOMAIN from .const import DOMAIN
from .entry_data import RuntimeEntryData from .entry_data import ESPHomeStorage, RuntimeEntryData
STORAGE_VERSION = 1 STORAGE_VERSION = 1
MAX_CACHED_SERVICES = 128 MAX_CACHED_SERVICES = 128
@ -26,7 +25,7 @@ class DomainData:
"""Define a class that stores global esphome data in hass.data[DOMAIN].""" """Define a class that stores global esphome data in hass.data[DOMAIN]."""
_entry_datas: dict[str, RuntimeEntryData] = field(default_factory=dict) _entry_datas: dict[str, RuntimeEntryData] = field(default_factory=dict)
_stores: dict[str, Store] = field(default_factory=dict) _stores: dict[str, ESPHomeStorage] = field(default_factory=dict)
_gatt_services_cache: MutableMapping[int, BleakGATTServiceCollection] = field( _gatt_services_cache: MutableMapping[int, BleakGATTServiceCollection] = field(
default_factory=lambda: LRU(MAX_CACHED_SERVICES) default_factory=lambda: LRU(MAX_CACHED_SERVICES)
) )
@ -83,11 +82,13 @@ class DomainData:
"""Check whether the given entry is loaded.""" """Check whether the given entry is loaded."""
return entry.entry_id in self._entry_datas return entry.entry_id in self._entry_datas
def get_or_create_store(self, hass: HomeAssistant, entry: ConfigEntry) -> Store: def get_or_create_store(
self, hass: HomeAssistant, entry: ConfigEntry
) -> ESPHomeStorage:
"""Get or create a Store instance for the given config entry.""" """Get or create a Store instance for the given config entry."""
return self._stores.setdefault( return self._stores.setdefault(
entry.entry_id, entry.entry_id,
Store( ESPHomeStorage(
hass, STORAGE_VERSION, f"esphome.{entry.entry_id}", encoder=JSONEncoder hass, STORAGE_VERSION, f"esphome.{entry.entry_id}", encoder=JSONEncoder
), ),
) )

View file

@ -27,7 +27,6 @@ import homeassistant.helpers.config_validation as cv
import homeassistant.helpers.device_registry as dr import homeassistant.helpers.device_registry as dr
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_connect,
async_dispatcher_send,
) )
from homeassistant.helpers.entity import DeviceInfo, Entity from homeassistant.helpers.entity import DeviceInfo, Entity
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
@ -49,7 +48,6 @@ async def platform_async_setup_entry(
entry: ConfigEntry, entry: ConfigEntry,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
*, *,
component_key: str,
info_type: type[_InfoT], info_type: type[_InfoT],
entity_type: type[_EntityT], entity_type: type[_EntityT],
state_type: type[_StateT], state_type: type[_StateT],
@ -60,42 +58,35 @@ async def platform_async_setup_entry(
info and state updates. info and state updates.
""" """
entry_data: RuntimeEntryData = DomainData.get(hass).get_entry_data(entry) entry_data: RuntimeEntryData = DomainData.get(hass).get_entry_data(entry)
entry_data.info[component_key] = {} entry_data.info[info_type] = {}
entry_data.old_info[component_key] = {}
entry_data.state.setdefault(state_type, {}) entry_data.state.setdefault(state_type, {})
@callback @callback
def async_list_entities(infos: list[EntityInfo]) -> None: def async_list_entities(infos: list[EntityInfo]) -> None:
"""Update entities of this platform when entities are listed.""" """Update entities of this platform when entities are listed."""
old_infos = entry_data.info[component_key] current_infos = entry_data.info[info_type]
new_infos: dict[int, EntityInfo] = {} new_infos: dict[int, EntityInfo] = {}
add_entities: list[_EntityT] = [] add_entities: list[_EntityT] = []
for info in infos: for info in infos:
if info.key in old_infos: if not current_infos.pop(info.key, None):
# Update existing entity
old_infos.pop(info.key)
else:
# Create new entity # Create new entity
entity = entity_type(entry_data, component_key, info, state_type) entity = entity_type(entry_data, info, state_type)
add_entities.append(entity) add_entities.append(entity)
new_infos[info.key] = info new_infos[info.key] = info
# Remove old entities # Anything still in current_infos is now gone
for info in old_infos.values(): if current_infos:
entry_data.async_remove_entity(hass, component_key, info.key) hass.async_create_task(
entry_data.async_remove_entities(current_infos.values())
# First copy the now-old info into the backup object
entry_data.old_info[component_key] = entry_data.info[component_key]
# Then update the actual info
entry_data.info[component_key] = new_infos
for key, new_info in new_infos.items():
async_dispatcher_send(
hass,
entry_data.signal_component_key_static_info_updated(component_key, key),
new_info,
) )
# Then update the actual info
entry_data.info[info_type] = new_infos
if new_infos:
entry_data.async_update_entity_infos(new_infos.values())
if add_entities: if add_entities:
# Add entities to Home Assistant # Add entities to Home Assistant
async_add_entities(add_entities) async_add_entities(add_entities)
@ -154,14 +145,12 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
def __init__( def __init__(
self, self,
entry_data: RuntimeEntryData, entry_data: RuntimeEntryData,
component_key: str,
entity_info: EntityInfo, entity_info: EntityInfo,
state_type: type[_StateT], state_type: type[_StateT],
) -> None: ) -> None:
"""Initialize.""" """Initialize."""
self._entry_data = entry_data self._entry_data = entry_data
self._on_entry_data_changed() self._on_entry_data_changed()
self._component_key = component_key
self._key = entity_info.key self._key = entity_info.key
self._state_type = state_type self._state_type = state_type
self._on_static_info_update(entity_info) self._on_static_info_update(entity_info)
@ -178,13 +167,11 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
"""Register callbacks.""" """Register callbacks."""
entry_data = self._entry_data entry_data = self._entry_data
hass = self.hass hass = self.hass
component_key = self._component_key
key = self._key key = self._key
self.async_on_remove( self.async_on_remove(
async_dispatcher_connect( entry_data.async_register_key_static_info_remove_callback(
hass, self._static_info,
f"esphome_{self._entry_id}_remove_{component_key}_{key}",
functools.partial(self.async_remove, force_remove=True), functools.partial(self.async_remove, force_remove=True),
) )
) )
@ -201,10 +188,8 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
) )
) )
self.async_on_remove( self.async_on_remove(
async_dispatcher_connect( entry_data.async_register_key_static_info_updated_callback(
hass, self._static_info, self._on_static_info_update
entry_data.signal_component_key_static_info_updated(component_key, key),
self._on_static_info_update,
) )
) )
self._update_state_from_entry_data() self._update_state_from_entry_data()

View file

@ -2,10 +2,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable from collections.abc import Callable, Coroutine, Iterable
from dataclasses import dataclass, field from dataclasses import dataclass, field
import logging import logging
from typing import Any, cast from typing import TYPE_CHECKING, Any, Final, TypedDict, cast
from aioesphomeapi import ( from aioesphomeapi import (
COMPONENT_TYPE_TO_INFO, COMPONENT_TYPE_TO_INFO,
@ -41,6 +41,8 @@ from homeassistant.helpers.storage import Store
from .dashboard import async_get_dashboard from .dashboard import async_get_dashboard
INFO_TO_COMPONENT_TYPE: Final = {v: k for k, v in COMPONENT_TYPE_TO_INFO.items()}
_SENTINEL = object() _SENTINEL = object()
SAVE_DELAY = 120 SAVE_DELAY = 120
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -65,26 +67,31 @@ INFO_TYPE_TO_PLATFORM: dict[type[EntityInfo], Platform] = {
} }
class StoreData(TypedDict, total=False):
"""ESPHome storage data."""
device_info: dict[str, Any]
services: list[dict[str, Any]]
api_version: dict[str, Any]
class ESPHomeStorage(Store[StoreData]):
"""ESPHome Storage."""
@dataclass @dataclass
class RuntimeEntryData: class RuntimeEntryData:
"""Store runtime data for esphome config entries.""" """Store runtime data for esphome config entries."""
entry_id: str entry_id: str
client: APIClient client: APIClient
store: Store store: ESPHomeStorage
state: dict[type[EntityState], dict[int, EntityState]] = field(default_factory=dict) state: dict[type[EntityState], dict[int, EntityState]] = field(default_factory=dict)
# When the disconnect callback is called, we mark all states # When the disconnect callback is called, we mark all states
# as stale so we will always dispatch a state update when the # as stale so we will always dispatch a state update when the
# device reconnects. This is the same format as state_subscriptions. # device reconnects. This is the same format as state_subscriptions.
stale_state: set[tuple[type[EntityState], int]] = field(default_factory=set) stale_state: set[tuple[type[EntityState], int]] = field(default_factory=set)
info: dict[str, dict[int, EntityInfo]] = field(default_factory=dict) info: dict[type[EntityInfo], dict[int, EntityInfo]] = field(default_factory=dict)
# A second list of EntityInfo objects
# This is necessary for when an entity is being removed. HA requires
# some static info to be accessible during removal (unique_id, maybe others)
# If an entity can't find anything in the info array, it will look for info here.
old_info: dict[str, dict[int, EntityInfo]] = field(default_factory=dict)
services: dict[int, UserService] = field(default_factory=dict) services: dict[int, UserService] = field(default_factory=dict)
available: bool = False available: bool = False
device_info: DeviceInfo | None = None device_info: DeviceInfo | None = None
@ -96,7 +103,8 @@ class RuntimeEntryData:
] = field(default_factory=dict) ] = field(default_factory=dict)
loaded_platforms: set[Platform] = field(default_factory=set) loaded_platforms: set[Platform] = field(default_factory=set)
platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock) platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
_storage_contents: dict[str, Any] | None = None _storage_contents: StoreData | None = None
_pending_storage: Callable[[], StoreData] | None = None
ble_connections_free: int = 0 ble_connections_free: int = 0
ble_connections_limit: int = 0 ble_connections_limit: int = 0
_ble_connection_free_futures: list[asyncio.Future[int]] = field( _ble_connection_free_futures: list[asyncio.Future[int]] = field(
@ -109,6 +117,12 @@ class RuntimeEntryData:
entity_info_callbacks: dict[ entity_info_callbacks: dict[
type[EntityInfo], list[Callable[[list[EntityInfo]], None]] type[EntityInfo], list[Callable[[list[EntityInfo]], None]]
] = field(default_factory=dict) ] = field(default_factory=dict)
entity_info_key_remove_callbacks: dict[
tuple[type[EntityInfo], int], list[Callable[[], Coroutine[Any, Any, None]]]
] = field(default_factory=dict)
entity_info_key_updated_callbacks: dict[
tuple[type[EntityInfo], int], list[Callable[[EntityInfo], None]]
] = field(default_factory=dict)
original_options: dict[str, Any] = field(default_factory=dict) original_options: dict[str, Any] = field(default_factory=dict)
@property @property
@ -133,12 +147,6 @@ class RuntimeEntryData:
"""Return the signal to listen to for updates on static info.""" """Return the signal to listen to for updates on static info."""
return f"esphome_{self.entry_id}_on_list" return f"esphome_{self.entry_id}_on_list"
def signal_component_key_static_info_updated(
self, component_key: str, key: int
) -> str:
"""Return the signal to listen to for updates on static info for a specific component_key and key."""
return f"esphome_{self.entry_id}_static_info_updated_{component_key}_{key}"
@callback @callback
def async_register_static_info_callback( def async_register_static_info_callback(
self, self,
@ -154,6 +162,38 @@ class RuntimeEntryData:
return _unsub return _unsub
@callback
def async_register_key_static_info_remove_callback(
self,
static_info: EntityInfo,
callback_: Callable[[], Coroutine[Any, Any, None]],
) -> CALLBACK_TYPE:
"""Register to receive callbacks when static info is removed for a specific key."""
callback_key = (type(static_info), static_info.key)
callbacks = self.entity_info_key_remove_callbacks.setdefault(callback_key, [])
callbacks.append(callback_)
def _unsub() -> None:
callbacks.remove(callback_)
return _unsub
@callback
def async_register_key_static_info_updated_callback(
self,
static_info: EntityInfo,
callback_: Callable[[EntityInfo], None],
) -> CALLBACK_TYPE:
"""Register to receive callbacks when static info is updated for a specific key."""
callback_key = (type(static_info), static_info.key)
callbacks = self.entity_info_key_updated_callbacks.setdefault(callback_key, [])
callbacks.append(callback_)
def _unsub() -> None:
callbacks.remove(callback_)
return _unsub
@callback @callback
def async_update_ble_connection_limits(self, free: int, limit: int) -> None: def async_update_ble_connection_limits(self, free: int, limit: int) -> None:
"""Update the BLE connection limits.""" """Update the BLE connection limits."""
@ -203,13 +243,25 @@ class RuntimeEntryData:
self.assist_pipeline_update_callbacks.append(update_callback) self.assist_pipeline_update_callbacks.append(update_callback)
return _unsubscribe return _unsubscribe
@callback async def async_remove_entities(self, static_infos: Iterable[EntityInfo]) -> None:
def async_remove_entity(
self, hass: HomeAssistant, component_key: str, key: int
) -> None:
"""Schedule the removal of an entity.""" """Schedule the removal of an entity."""
signal = f"esphome_{self.entry_id}_remove_{component_key}_{key}" callbacks: list[Coroutine[Any, Any, None]] = []
async_dispatcher_send(hass, signal) for static_info in static_infos:
callback_key = (type(static_info), static_info.key)
if key_callbacks := self.entity_info_key_remove_callbacks.get(callback_key):
callbacks.extend([callback_() for callback_ in key_callbacks])
if callbacks:
await asyncio.gather(*callbacks)
@callback
def async_update_entity_infos(self, static_infos: Iterable[EntityInfo]) -> None:
"""Call static info updated callbacks."""
for static_info in static_infos:
callback_key = (type(static_info), static_info.key)
for callback_ in self.entity_info_key_updated_callbacks.get(
callback_key, []
):
callback_(static_info)
async def _ensure_platforms_loaded( async def _ensure_platforms_loaded(
self, hass: HomeAssistant, entry: ConfigEntry, platforms: set[Platform] self, hass: HomeAssistant, entry: ConfigEntry, platforms: set[Platform]
@ -288,7 +340,7 @@ class RuntimeEntryData:
and subscription_key not in stale_state and subscription_key not in stale_state
and not ( and not (
type(state) is SensorState # pylint: disable=unidiomatic-typecheck type(state) is SensorState # pylint: disable=unidiomatic-typecheck
and (platform_info := self.info.get(Platform.SENSOR)) and (platform_info := self.info.get(SensorInfo))
and (entity_info := platform_info.get(state.key)) and (entity_info := platform_info.get(state.key))
and (cast(SensorInfo, entity_info)).force_update and (cast(SensorInfo, entity_info)).force_update
) )
@ -326,47 +378,57 @@ class RuntimeEntryData:
"""Load the retained data from store and return de-serialized data.""" """Load the retained data from store and return de-serialized data."""
if (restored := await self.store.async_load()) is None: if (restored := await self.store.async_load()) is None:
return [], [] return [], []
restored = cast("dict[str, Any]", restored)
self._storage_contents = restored.copy() self._storage_contents = restored.copy()
self.device_info = DeviceInfo.from_dict(restored.pop("device_info")) self.device_info = DeviceInfo.from_dict(restored.pop("device_info"))
self.api_version = APIVersion.from_dict(restored.pop("api_version", {})) self.api_version = APIVersion.from_dict(restored.pop("api_version", {}))
infos = [] infos: list[EntityInfo] = []
for comp_type, restored_infos in restored.items(): for comp_type, restored_infos in restored.items():
if TYPE_CHECKING:
restored_infos = cast(list[dict[str, Any]], restored_infos)
if comp_type not in COMPONENT_TYPE_TO_INFO: if comp_type not in COMPONENT_TYPE_TO_INFO:
continue continue
for info in restored_infos: for info in restored_infos:
cls = COMPONENT_TYPE_TO_INFO[comp_type] cls = COMPONENT_TYPE_TO_INFO[comp_type]
infos.append(cls.from_dict(info)) infos.append(cls.from_dict(info))
services = [] services = [
for service in restored.get("services", []): UserService.from_dict(service) for service in restored.pop("services", [])
services.append(UserService.from_dict(service)) ]
return infos, services return infos, services
async def async_save_to_store(self) -> None: async def async_save_to_store(self) -> None:
"""Generate dynamic data to store and save it to the filesystem.""" """Generate dynamic data to store and save it to the filesystem."""
if self.device_info is None: if self.device_info is None:
raise ValueError("device_info is not set yet") raise ValueError("device_info is not set yet")
store_data: dict[str, Any] = { store_data: StoreData = {
"device_info": self.device_info.to_dict(), "device_info": self.device_info.to_dict(),
"services": [], "services": [],
"api_version": self.api_version.to_dict(), "api_version": self.api_version.to_dict(),
} }
for info_type, infos in self.info.items():
for comp_type, infos in self.info.items(): comp_type = INFO_TO_COMPONENT_TYPE[info_type]
store_data[comp_type] = [info.to_dict() for info in infos.values()] store_data[comp_type] = [info.to_dict() for info in infos.values()] # type: ignore[literal-required]
for service in self.services.values(): for service in self.services.values():
store_data["services"].append(service.to_dict()) store_data["services"].append(service.to_dict())
if store_data == self._storage_contents: if store_data == self._storage_contents:
return return
def _memorized_storage() -> dict[str, Any]: def _memorized_storage() -> StoreData:
self._pending_storage = None
self._storage_contents = store_data self._storage_contents = store_data
return store_data return store_data
self._pending_storage = _memorized_storage
self.store.async_delay_save(_memorized_storage, SAVE_DELAY) self.store.async_delay_save(_memorized_storage, SAVE_DELAY)
async def async_cleanup(self) -> None:
"""Cleanup the entry data when disconnected or unloading."""
if self._pending_storage:
# Ensure we save the data if we are unloading before the
# save delay has passed.
await self.store.async_save(self._pending_storage())
async def async_update_listener( async def async_update_listener(
self, hass: HomeAssistant, entry: ConfigEntry self, hass: HomeAssistant, entry: ConfigEntry
) -> None: ) -> None:

View file

@ -40,7 +40,6 @@ async def async_setup_entry(
hass, hass,
entry, entry,
async_add_entities, async_add_entities,
component_key="fan",
info_type=FanInfo, info_type=FanInfo,
entity_type=EsphomeFan, entity_type=EsphomeFan,
state_type=FanState, state_type=FanState,

View file

@ -48,7 +48,6 @@ async def async_setup_entry(
hass, hass,
entry, entry,
async_add_entities, async_add_entities,
component_key="light",
info_type=LightInfo, info_type=LightInfo,
entity_type=EsphomeLight, entity_type=EsphomeLight,
state_type=LightState, state_type=LightState,

View file

@ -26,7 +26,6 @@ async def async_setup_entry(
hass, hass,
entry, entry,
async_add_entities, async_add_entities,
component_key="lock",
info_type=LockInfo, info_type=LockInfo,
entity_type=EsphomeLock, entity_type=EsphomeLock,
state_type=LockEntityState, state_type=LockEntityState,

View file

@ -43,7 +43,6 @@ async def async_setup_entry(
hass, hass,
entry, entry,
async_add_entities, async_add_entities,
component_key="media_player",
info_type=MediaPlayerInfo, info_type=MediaPlayerInfo,
entity_type=EsphomeMediaPlayer, entity_type=EsphomeMediaPlayer,
state_type=MediaPlayerEntityState, state_type=MediaPlayerEntityState,

View file

@ -34,7 +34,6 @@ async def async_setup_entry(
hass, hass,
entry, entry,
async_add_entities, async_add_entities,
component_key="number",
info_type=NumberInfo, info_type=NumberInfo,
entity_type=EsphomeNumber, entity_type=EsphomeNumber,
state_type=NumberState, state_type=NumberState,

View file

@ -29,7 +29,6 @@ async def async_setup_entry(
hass, hass,
entry, entry,
async_add_entities, async_add_entities,
component_key="select",
info_type=SelectInfo, info_type=SelectInfo,
entity_type=EsphomeSelect, entity_type=EsphomeSelect,
state_type=SelectState, state_type=SelectState,

View file

@ -41,7 +41,6 @@ async def async_setup_entry(
hass, hass,
entry, entry,
async_add_entities, async_add_entities,
component_key="sensor",
info_type=SensorInfo, info_type=SensorInfo,
entity_type=EsphomeSensor, entity_type=EsphomeSensor,
state_type=SensorState, state_type=SensorState,
@ -50,7 +49,6 @@ async def async_setup_entry(
hass, hass,
entry, entry,
async_add_entities, async_add_entities,
component_key="text_sensor",
info_type=TextSensorInfo, info_type=TextSensorInfo,
entity_type=EsphomeTextSensor, entity_type=EsphomeTextSensor,
state_type=TextSensorState, state_type=TextSensorState,

View file

@ -26,7 +26,6 @@ async def async_setup_entry(
hass, hass,
entry, entry,
async_add_entities, async_add_entities,
component_key="switch",
info_type=SwitchInfo, info_type=SwitchInfo,
entity_type=EsphomeSwitch, entity_type=EsphomeSwitch,
state_type=SwitchState, state_type=SwitchState,

View file

@ -169,19 +169,22 @@ async def _mock_generic_device_entry(
mock_device_info: dict[str, Any], mock_device_info: dict[str, Any],
mock_list_entities_services: tuple[list[EntityInfo], list[UserService]], mock_list_entities_services: tuple[list[EntityInfo], list[UserService]],
states: list[EntityState], states: list[EntityState],
entry: MockConfigEntry | None = None,
) -> MockESPHomeDevice: ) -> MockESPHomeDevice:
entry = MockConfigEntry( if not entry:
domain=DOMAIN, entry = MockConfigEntry(
data={ domain=DOMAIN,
CONF_HOST: "test.local", data={
CONF_PORT: 6053, CONF_HOST: "test.local",
CONF_PASSWORD: "", CONF_PORT: 6053,
}, CONF_PASSWORD: "",
options={ },
CONF_ALLOW_SERVICE_CALLS: DEFAULT_NEW_CONFIG_ALLOW_ALLOW_SERVICE_CALLS options={
}, CONF_ALLOW_SERVICE_CALLS: DEFAULT_NEW_CONFIG_ALLOW_ALLOW_SERVICE_CALLS
) },
entry.add_to_hass(hass) )
entry.add_to_hass(hass)
mock_device = MockESPHomeDevice(entry) mock_device = MockESPHomeDevice(entry)
device_info = DeviceInfo( device_info = DeviceInfo(
@ -290,9 +293,10 @@ async def mock_esphome_device(
entity_info: list[EntityInfo], entity_info: list[EntityInfo],
user_service: list[UserService], user_service: list[UserService],
states: list[EntityState], states: list[EntityState],
entry: MockConfigEntry | None = None,
) -> MockESPHomeDevice: ) -> MockESPHomeDevice:
return await _mock_generic_device_entry( return await _mock_generic_device_entry(
hass, mock_client, {}, (entity_info, user_service), states hass, mock_client, {}, (entity_info, user_service), states, entry
) )
return _mock_device return _mock_device

View file

@ -0,0 +1,103 @@
"""Test ESPHome binary sensors."""
from collections.abc import Awaitable, Callable
from typing import Any
from aioesphomeapi import (
APIClient,
BinarySensorInfo,
BinarySensorState,
EntityInfo,
EntityState,
UserService,
)
from homeassistant.const import ATTR_RESTORED, STATE_ON
from homeassistant.core import HomeAssistant
from .conftest import MockESPHomeDevice
async def test_entities_removed(
hass: HomeAssistant,
mock_client: APIClient,
hass_storage: dict[str, Any],
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test a generic binary_sensor where has_state is false."""
entity_info = [
BinarySensorInfo(
object_id="mybinary_sensor",
key=1,
name="my binary_sensor",
unique_id="my_binary_sensor",
),
BinarySensorInfo(
object_id="mybinary_sensor_to_be_removed",
key=2,
name="my binary_sensor to be removed",
unique_id="mybinary_sensor_to_be_removed",
),
]
states = [
BinarySensorState(key=1, state=True, missing_state=False),
BinarySensorState(key=2, state=True, missing_state=False),
]
user_service = []
mock_device = await mock_esphome_device(
mock_client=mock_client,
entity_info=entity_info,
user_service=user_service,
states=states,
)
entry = mock_device.entry
entry_id = entry.entry_id
storage_key = f"esphome.{entry_id}"
state = hass.states.get("binary_sensor.test_my_binary_sensor")
assert state is not None
assert state.state == STATE_ON
state = hass.states.get("binary_sensor.test_my_binary_sensor_to_be_removed")
assert state is not None
assert state.state == STATE_ON
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 2
state = hass.states.get("binary_sensor.test_my_binary_sensor")
assert state is not None
assert state.attributes[ATTR_RESTORED] is True
state = hass.states.get("binary_sensor.test_my_binary_sensor_to_be_removed")
assert state is not None
assert state.attributes[ATTR_RESTORED] is True
entity_info = [
BinarySensorInfo(
object_id="mybinary_sensor",
key=1,
name="my binary_sensor",
unique_id="my_binary_sensor",
),
]
states = [
BinarySensorState(key=1, state=True, missing_state=False),
]
mock_device = await mock_esphome_device(
mock_client=mock_client,
entity_info=entity_info,
user_service=user_service,
states=states,
entry=entry,
)
assert mock_device.entry.entry_id == entry_id
state = hass.states.get("binary_sensor.test_my_binary_sensor")
assert state is not None
assert state.state == STATE_ON
state = hass.states.get("binary_sensor.test_my_binary_sensor_to_be_removed")
assert state is None
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 1