diff --git a/homeassistant/components/esphome/__init__.py b/homeassistant/components/esphome/__init__.py index af2c1d59505..76f218d3668 100644 --- a/homeassistant/components/esphome/__init__.py +++ b/homeassistant/components/esphome/__init__.py @@ -675,6 +675,7 @@ async def _cleanup_instance( data.disconnect_callbacks = [] for cleanup_callback in data.cleanup_callbacks: cleanup_callback() + await data.async_cleanup() await data.client.disconnect() return data diff --git a/homeassistant/components/esphome/alarm_control_panel.py b/homeassistant/components/esphome/alarm_control_panel.py index 6fadd7d4408..f69560945c3 100644 --- a/homeassistant/components/esphome/alarm_control_panel.py +++ b/homeassistant/components/esphome/alarm_control_panel.py @@ -73,7 +73,6 @@ async def async_setup_entry( hass, entry, async_add_entities, - component_key="alarm_control_panel", info_type=AlarmControlPanelInfo, entity_type=EsphomeAlarmControlPanel, state_type=AlarmControlPanelEntityState, diff --git a/homeassistant/components/esphome/binary_sensor.py b/homeassistant/components/esphome/binary_sensor.py index a755bcf10ef..65a237de4f7 100644 --- a/homeassistant/components/esphome/binary_sensor.py +++ b/homeassistant/components/esphome/binary_sensor.py @@ -29,7 +29,6 @@ async def async_setup_entry( hass, entry, async_add_entities, - component_key="binary_sensor", info_type=BinarySensorInfo, entity_type=EsphomeBinarySensor, state_type=BinarySensorState, diff --git a/homeassistant/components/esphome/button.py b/homeassistant/components/esphome/button.py index 71bb7017c55..7087cb034ae 100644 --- a/homeassistant/components/esphome/button.py +++ b/homeassistant/components/esphome/button.py @@ -23,7 +23,6 @@ async def async_setup_entry( hass, entry, async_add_entities, - component_key="button", info_type=ButtonInfo, entity_type=EsphomeButton, state_type=EntityState, diff --git a/homeassistant/components/esphome/camera.py b/homeassistant/components/esphome/camera.py index 17f73f7c770..94a9b03b90c 100644 --- a/homeassistant/components/esphome/camera.py +++ b/homeassistant/components/esphome/camera.py @@ -27,7 +27,6 @@ async def async_setup_entry( hass, entry, async_add_entities, - component_key="camera", info_type=CameraInfo, entity_type=EsphomeCamera, state_type=CameraState, diff --git a/homeassistant/components/esphome/climate.py b/homeassistant/components/esphome/climate.py index 5c252e888d9..2c1d005f9be 100644 --- a/homeassistant/components/esphome/climate.py +++ b/homeassistant/components/esphome/climate.py @@ -72,7 +72,6 @@ async def async_setup_entry( hass, entry, async_add_entities, - component_key="climate", info_type=ClimateInfo, entity_type=EsphomeClimateEntity, state_type=ClimateState, diff --git a/homeassistant/components/esphome/cover.py b/homeassistant/components/esphome/cover.py index 347ff98e689..45ef8a132f9 100644 --- a/homeassistant/components/esphome/cover.py +++ b/homeassistant/components/esphome/cover.py @@ -32,7 +32,6 @@ async def async_setup_entry( hass, entry, async_add_entities, - component_key="cover", info_type=CoverInfo, entity_type=EsphomeCover, state_type=CoverState, diff --git a/homeassistant/components/esphome/diagnostics.py b/homeassistant/components/esphome/diagnostics.py index 8de1501bc43..292d1921abf 100644 --- a/homeassistant/components/esphome/diagnostics.py +++ b/homeassistant/components/esphome/diagnostics.py @@ -1,7 +1,7 @@ """Diagnostics support for ESPHome.""" from __future__ import annotations -from typing import Any, cast +from typing import Any from homeassistant.components.bluetooth import async_scanner_by_source from homeassistant.components.diagnostics import async_redact_data @@ -28,7 +28,6 @@ async def async_get_config_entry_diagnostics( entry_data = DomainData.get(hass).get_entry_data(config_entry) if (storage_data := await entry_data.store.async_load()) is not None: - storage_data = cast("dict[str, Any]", storage_data) diag["storage_data"] = storage_data if config_entry.unique_id and ( diff --git a/homeassistant/components/esphome/domain_data.py b/homeassistant/components/esphome/domain_data.py index 1379b274122..2fc32129d1f 100644 --- a/homeassistant/components/esphome/domain_data.py +++ b/homeassistant/components/esphome/domain_data.py @@ -12,10 +12,9 @@ from typing_extensions import Self from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.json import JSONEncoder -from homeassistant.helpers.storage import Store from .const import DOMAIN -from .entry_data import RuntimeEntryData +from .entry_data import ESPHomeStorage, RuntimeEntryData STORAGE_VERSION = 1 MAX_CACHED_SERVICES = 128 @@ -26,7 +25,7 @@ class DomainData: """Define a class that stores global esphome data in hass.data[DOMAIN].""" _entry_datas: dict[str, RuntimeEntryData] = field(default_factory=dict) - _stores: dict[str, Store] = field(default_factory=dict) + _stores: dict[str, ESPHomeStorage] = field(default_factory=dict) _gatt_services_cache: MutableMapping[int, BleakGATTServiceCollection] = field( default_factory=lambda: LRU(MAX_CACHED_SERVICES) ) @@ -83,11 +82,13 @@ class DomainData: """Check whether the given entry is loaded.""" return entry.entry_id in self._entry_datas - def get_or_create_store(self, hass: HomeAssistant, entry: ConfigEntry) -> Store: + def get_or_create_store( + self, hass: HomeAssistant, entry: ConfigEntry + ) -> ESPHomeStorage: """Get or create a Store instance for the given config entry.""" return self._stores.setdefault( entry.entry_id, - Store( + ESPHomeStorage( hass, STORAGE_VERSION, f"esphome.{entry.entry_id}", encoder=JSONEncoder ), ) diff --git a/homeassistant/components/esphome/entity.py b/homeassistant/components/esphome/entity.py index 18bf15ce4ee..dbb16fe481d 100644 --- a/homeassistant/components/esphome/entity.py +++ b/homeassistant/components/esphome/entity.py @@ -27,7 +27,6 @@ import homeassistant.helpers.config_validation as cv import homeassistant.helpers.device_registry as dr from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, - async_dispatcher_send, ) from homeassistant.helpers.entity import DeviceInfo, Entity from homeassistant.helpers.entity_platform import AddEntitiesCallback @@ -49,7 +48,6 @@ async def platform_async_setup_entry( entry: ConfigEntry, async_add_entities: AddEntitiesCallback, *, - component_key: str, info_type: type[_InfoT], entity_type: type[_EntityT], state_type: type[_StateT], @@ -60,42 +58,35 @@ async def platform_async_setup_entry( info and state updates. """ entry_data: RuntimeEntryData = DomainData.get(hass).get_entry_data(entry) - entry_data.info[component_key] = {} - entry_data.old_info[component_key] = {} + entry_data.info[info_type] = {} entry_data.state.setdefault(state_type, {}) @callback def async_list_entities(infos: list[EntityInfo]) -> None: """Update entities of this platform when entities are listed.""" - old_infos = entry_data.info[component_key] + current_infos = entry_data.info[info_type] new_infos: dict[int, EntityInfo] = {} add_entities: list[_EntityT] = [] + for info in infos: - if info.key in old_infos: - # Update existing entity - old_infos.pop(info.key) - else: + if not current_infos.pop(info.key, None): # Create new entity - entity = entity_type(entry_data, component_key, info, state_type) + entity = entity_type(entry_data, info, state_type) add_entities.append(entity) new_infos[info.key] = info - # Remove old entities - for info in old_infos.values(): - entry_data.async_remove_entity(hass, component_key, info.key) - - # First copy the now-old info into the backup object - entry_data.old_info[component_key] = entry_data.info[component_key] - # Then update the actual info - entry_data.info[component_key] = new_infos - - for key, new_info in new_infos.items(): - async_dispatcher_send( - hass, - entry_data.signal_component_key_static_info_updated(component_key, key), - new_info, + # Anything still in current_infos is now gone + if current_infos: + hass.async_create_task( + entry_data.async_remove_entities(current_infos.values()) ) + # Then update the actual info + entry_data.info[info_type] = new_infos + + if new_infos: + entry_data.async_update_entity_infos(new_infos.values()) + if add_entities: # Add entities to Home Assistant async_add_entities(add_entities) @@ -154,14 +145,12 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]): def __init__( self, entry_data: RuntimeEntryData, - component_key: str, entity_info: EntityInfo, state_type: type[_StateT], ) -> None: """Initialize.""" self._entry_data = entry_data self._on_entry_data_changed() - self._component_key = component_key self._key = entity_info.key self._state_type = state_type self._on_static_info_update(entity_info) @@ -178,13 +167,11 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]): """Register callbacks.""" entry_data = self._entry_data hass = self.hass - component_key = self._component_key key = self._key self.async_on_remove( - async_dispatcher_connect( - hass, - f"esphome_{self._entry_id}_remove_{component_key}_{key}", + entry_data.async_register_key_static_info_remove_callback( + self._static_info, functools.partial(self.async_remove, force_remove=True), ) ) @@ -201,10 +188,8 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]): ) ) self.async_on_remove( - async_dispatcher_connect( - hass, - entry_data.signal_component_key_static_info_updated(component_key, key), - self._on_static_info_update, + entry_data.async_register_key_static_info_updated_callback( + self._static_info, self._on_static_info_update ) ) self._update_state_from_entry_data() diff --git a/homeassistant/components/esphome/entry_data.py b/homeassistant/components/esphome/entry_data.py index 4cde32e6a79..e0d989c4b8b 100644 --- a/homeassistant/components/esphome/entry_data.py +++ b/homeassistant/components/esphome/entry_data.py @@ -2,10 +2,10 @@ from __future__ import annotations import asyncio -from collections.abc import Callable +from collections.abc import Callable, Coroutine, Iterable from dataclasses import dataclass, field import logging -from typing import Any, cast +from typing import TYPE_CHECKING, Any, Final, TypedDict, cast from aioesphomeapi import ( COMPONENT_TYPE_TO_INFO, @@ -41,6 +41,8 @@ from homeassistant.helpers.storage import Store from .dashboard import async_get_dashboard +INFO_TO_COMPONENT_TYPE: Final = {v: k for k, v in COMPONENT_TYPE_TO_INFO.items()} + _SENTINEL = object() SAVE_DELAY = 120 _LOGGER = logging.getLogger(__name__) @@ -65,26 +67,31 @@ INFO_TYPE_TO_PLATFORM: dict[type[EntityInfo], Platform] = { } +class StoreData(TypedDict, total=False): + """ESPHome storage data.""" + + device_info: dict[str, Any] + services: list[dict[str, Any]] + api_version: dict[str, Any] + + +class ESPHomeStorage(Store[StoreData]): + """ESPHome Storage.""" + + @dataclass class RuntimeEntryData: """Store runtime data for esphome config entries.""" entry_id: str client: APIClient - store: Store + store: ESPHomeStorage state: dict[type[EntityState], dict[int, EntityState]] = field(default_factory=dict) # When the disconnect callback is called, we mark all states # as stale so we will always dispatch a state update when the # device reconnects. This is the same format as state_subscriptions. stale_state: set[tuple[type[EntityState], int]] = field(default_factory=set) - info: dict[str, dict[int, EntityInfo]] = field(default_factory=dict) - - # A second list of EntityInfo objects - # This is necessary for when an entity is being removed. HA requires - # some static info to be accessible during removal (unique_id, maybe others) - # If an entity can't find anything in the info array, it will look for info here. - old_info: dict[str, dict[int, EntityInfo]] = field(default_factory=dict) - + info: dict[type[EntityInfo], dict[int, EntityInfo]] = field(default_factory=dict) services: dict[int, UserService] = field(default_factory=dict) available: bool = False device_info: DeviceInfo | None = None @@ -96,7 +103,8 @@ class RuntimeEntryData: ] = field(default_factory=dict) loaded_platforms: set[Platform] = field(default_factory=set) platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock) - _storage_contents: dict[str, Any] | None = None + _storage_contents: StoreData | None = None + _pending_storage: Callable[[], StoreData] | None = None ble_connections_free: int = 0 ble_connections_limit: int = 0 _ble_connection_free_futures: list[asyncio.Future[int]] = field( @@ -109,6 +117,12 @@ class RuntimeEntryData: entity_info_callbacks: dict[ type[EntityInfo], list[Callable[[list[EntityInfo]], None]] ] = field(default_factory=dict) + entity_info_key_remove_callbacks: dict[ + tuple[type[EntityInfo], int], list[Callable[[], Coroutine[Any, Any, None]]] + ] = field(default_factory=dict) + entity_info_key_updated_callbacks: dict[ + tuple[type[EntityInfo], int], list[Callable[[EntityInfo], None]] + ] = field(default_factory=dict) original_options: dict[str, Any] = field(default_factory=dict) @property @@ -133,12 +147,6 @@ class RuntimeEntryData: """Return the signal to listen to for updates on static info.""" return f"esphome_{self.entry_id}_on_list" - def signal_component_key_static_info_updated( - self, component_key: str, key: int - ) -> str: - """Return the signal to listen to for updates on static info for a specific component_key and key.""" - return f"esphome_{self.entry_id}_static_info_updated_{component_key}_{key}" - @callback def async_register_static_info_callback( self, @@ -154,6 +162,38 @@ class RuntimeEntryData: return _unsub + @callback + def async_register_key_static_info_remove_callback( + self, + static_info: EntityInfo, + callback_: Callable[[], Coroutine[Any, Any, None]], + ) -> CALLBACK_TYPE: + """Register to receive callbacks when static info is removed for a specific key.""" + callback_key = (type(static_info), static_info.key) + callbacks = self.entity_info_key_remove_callbacks.setdefault(callback_key, []) + callbacks.append(callback_) + + def _unsub() -> None: + callbacks.remove(callback_) + + return _unsub + + @callback + def async_register_key_static_info_updated_callback( + self, + static_info: EntityInfo, + callback_: Callable[[EntityInfo], None], + ) -> CALLBACK_TYPE: + """Register to receive callbacks when static info is updated for a specific key.""" + callback_key = (type(static_info), static_info.key) + callbacks = self.entity_info_key_updated_callbacks.setdefault(callback_key, []) + callbacks.append(callback_) + + def _unsub() -> None: + callbacks.remove(callback_) + + return _unsub + @callback def async_update_ble_connection_limits(self, free: int, limit: int) -> None: """Update the BLE connection limits.""" @@ -203,13 +243,25 @@ class RuntimeEntryData: self.assist_pipeline_update_callbacks.append(update_callback) return _unsubscribe - @callback - def async_remove_entity( - self, hass: HomeAssistant, component_key: str, key: int - ) -> None: + async def async_remove_entities(self, static_infos: Iterable[EntityInfo]) -> None: """Schedule the removal of an entity.""" - signal = f"esphome_{self.entry_id}_remove_{component_key}_{key}" - async_dispatcher_send(hass, signal) + callbacks: list[Coroutine[Any, Any, None]] = [] + for static_info in static_infos: + callback_key = (type(static_info), static_info.key) + if key_callbacks := self.entity_info_key_remove_callbacks.get(callback_key): + callbacks.extend([callback_() for callback_ in key_callbacks]) + if callbacks: + await asyncio.gather(*callbacks) + + @callback + def async_update_entity_infos(self, static_infos: Iterable[EntityInfo]) -> None: + """Call static info updated callbacks.""" + for static_info in static_infos: + callback_key = (type(static_info), static_info.key) + for callback_ in self.entity_info_key_updated_callbacks.get( + callback_key, [] + ): + callback_(static_info) async def _ensure_platforms_loaded( self, hass: HomeAssistant, entry: ConfigEntry, platforms: set[Platform] @@ -288,7 +340,7 @@ class RuntimeEntryData: and subscription_key not in stale_state and not ( type(state) is SensorState # pylint: disable=unidiomatic-typecheck - and (platform_info := self.info.get(Platform.SENSOR)) + and (platform_info := self.info.get(SensorInfo)) and (entity_info := platform_info.get(state.key)) and (cast(SensorInfo, entity_info)).force_update ) @@ -326,47 +378,57 @@ class RuntimeEntryData: """Load the retained data from store and return de-serialized data.""" if (restored := await self.store.async_load()) is None: return [], [] - restored = cast("dict[str, Any]", restored) self._storage_contents = restored.copy() self.device_info = DeviceInfo.from_dict(restored.pop("device_info")) self.api_version = APIVersion.from_dict(restored.pop("api_version", {})) - infos = [] + infos: list[EntityInfo] = [] for comp_type, restored_infos in restored.items(): + if TYPE_CHECKING: + restored_infos = cast(list[dict[str, Any]], restored_infos) if comp_type not in COMPONENT_TYPE_TO_INFO: continue for info in restored_infos: cls = COMPONENT_TYPE_TO_INFO[comp_type] infos.append(cls.from_dict(info)) - services = [] - for service in restored.get("services", []): - services.append(UserService.from_dict(service)) + services = [ + UserService.from_dict(service) for service in restored.pop("services", []) + ] return infos, services async def async_save_to_store(self) -> None: """Generate dynamic data to store and save it to the filesystem.""" if self.device_info is None: raise ValueError("device_info is not set yet") - store_data: dict[str, Any] = { + store_data: StoreData = { "device_info": self.device_info.to_dict(), "services": [], "api_version": self.api_version.to_dict(), } - - for comp_type, infos in self.info.items(): - store_data[comp_type] = [info.to_dict() for info in infos.values()] + for info_type, infos in self.info.items(): + comp_type = INFO_TO_COMPONENT_TYPE[info_type] + store_data[comp_type] = [info.to_dict() for info in infos.values()] # type: ignore[literal-required] for service in self.services.values(): store_data["services"].append(service.to_dict()) if store_data == self._storage_contents: return - def _memorized_storage() -> dict[str, Any]: + def _memorized_storage() -> StoreData: + self._pending_storage = None self._storage_contents = store_data return store_data + self._pending_storage = _memorized_storage self.store.async_delay_save(_memorized_storage, SAVE_DELAY) + async def async_cleanup(self) -> None: + """Cleanup the entry data when disconnected or unloading.""" + if self._pending_storage: + # Ensure we save the data if we are unloading before the + # save delay has passed. + await self.store.async_save(self._pending_storage()) + async def async_update_listener( self, hass: HomeAssistant, entry: ConfigEntry ) -> None: diff --git a/homeassistant/components/esphome/fan.py b/homeassistant/components/esphome/fan.py index 388413f161f..c6be200e2b2 100644 --- a/homeassistant/components/esphome/fan.py +++ b/homeassistant/components/esphome/fan.py @@ -40,7 +40,6 @@ async def async_setup_entry( hass, entry, async_add_entities, - component_key="fan", info_type=FanInfo, entity_type=EsphomeFan, state_type=FanState, diff --git a/homeassistant/components/esphome/light.py b/homeassistant/components/esphome/light.py index f4232e320b0..aa67a8124fc 100644 --- a/homeassistant/components/esphome/light.py +++ b/homeassistant/components/esphome/light.py @@ -48,7 +48,6 @@ async def async_setup_entry( hass, entry, async_add_entities, - component_key="light", info_type=LightInfo, entity_type=EsphomeLight, state_type=LightState, diff --git a/homeassistant/components/esphome/lock.py b/homeassistant/components/esphome/lock.py index 0cfc25e3882..00b94cd15ff 100644 --- a/homeassistant/components/esphome/lock.py +++ b/homeassistant/components/esphome/lock.py @@ -26,7 +26,6 @@ async def async_setup_entry( hass, entry, async_add_entities, - component_key="lock", info_type=LockInfo, entity_type=EsphomeLock, state_type=LockEntityState, diff --git a/homeassistant/components/esphome/media_player.py b/homeassistant/components/esphome/media_player.py index 9933f523c26..9d008300966 100644 --- a/homeassistant/components/esphome/media_player.py +++ b/homeassistant/components/esphome/media_player.py @@ -43,7 +43,6 @@ async def async_setup_entry( hass, entry, async_add_entities, - component_key="media_player", info_type=MediaPlayerInfo, entity_type=EsphomeMediaPlayer, state_type=MediaPlayerEntityState, diff --git a/homeassistant/components/esphome/number.py b/homeassistant/components/esphome/number.py index ead3d5c4307..e876fe412f6 100644 --- a/homeassistant/components/esphome/number.py +++ b/homeassistant/components/esphome/number.py @@ -34,7 +34,6 @@ async def async_setup_entry( hass, entry, async_add_entities, - component_key="number", info_type=NumberInfo, entity_type=EsphomeNumber, state_type=NumberState, diff --git a/homeassistant/components/esphome/select.py b/homeassistant/components/esphome/select.py index 2de6ddd7111..9849f7cded8 100644 --- a/homeassistant/components/esphome/select.py +++ b/homeassistant/components/esphome/select.py @@ -29,7 +29,6 @@ async def async_setup_entry( hass, entry, async_add_entities, - component_key="select", info_type=SelectInfo, entity_type=EsphomeSelect, state_type=SelectState, diff --git a/homeassistant/components/esphome/sensor.py b/homeassistant/components/esphome/sensor.py index ac2fb9629a8..3185a5eb536 100644 --- a/homeassistant/components/esphome/sensor.py +++ b/homeassistant/components/esphome/sensor.py @@ -41,7 +41,6 @@ async def async_setup_entry( hass, entry, async_add_entities, - component_key="sensor", info_type=SensorInfo, entity_type=EsphomeSensor, state_type=SensorState, @@ -50,7 +49,6 @@ async def async_setup_entry( hass, entry, async_add_entities, - component_key="text_sensor", info_type=TextSensorInfo, entity_type=EsphomeTextSensor, state_type=TextSensorState, diff --git a/homeassistant/components/esphome/switch.py b/homeassistant/components/esphome/switch.py index 4ecee203fa0..99894b8501e 100644 --- a/homeassistant/components/esphome/switch.py +++ b/homeassistant/components/esphome/switch.py @@ -26,7 +26,6 @@ async def async_setup_entry( hass, entry, async_add_entities, - component_key="switch", info_type=SwitchInfo, entity_type=EsphomeSwitch, state_type=SwitchState, diff --git a/tests/components/esphome/conftest.py b/tests/components/esphome/conftest.py index 8bb41a92d80..d78af769a17 100644 --- a/tests/components/esphome/conftest.py +++ b/tests/components/esphome/conftest.py @@ -169,19 +169,22 @@ async def _mock_generic_device_entry( mock_device_info: dict[str, Any], mock_list_entities_services: tuple[list[EntityInfo], list[UserService]], states: list[EntityState], + entry: MockConfigEntry | None = None, ) -> MockESPHomeDevice: - entry = MockConfigEntry( - domain=DOMAIN, - data={ - CONF_HOST: "test.local", - CONF_PORT: 6053, - CONF_PASSWORD: "", - }, - options={ - CONF_ALLOW_SERVICE_CALLS: DEFAULT_NEW_CONFIG_ALLOW_ALLOW_SERVICE_CALLS - }, - ) - entry.add_to_hass(hass) + if not entry: + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_HOST: "test.local", + CONF_PORT: 6053, + CONF_PASSWORD: "", + }, + options={ + CONF_ALLOW_SERVICE_CALLS: DEFAULT_NEW_CONFIG_ALLOW_ALLOW_SERVICE_CALLS + }, + ) + entry.add_to_hass(hass) + mock_device = MockESPHomeDevice(entry) device_info = DeviceInfo( @@ -290,9 +293,10 @@ async def mock_esphome_device( entity_info: list[EntityInfo], user_service: list[UserService], states: list[EntityState], + entry: MockConfigEntry | None = None, ) -> MockESPHomeDevice: return await _mock_generic_device_entry( - hass, mock_client, {}, (entity_info, user_service), states + hass, mock_client, {}, (entity_info, user_service), states, entry ) return _mock_device diff --git a/tests/components/esphome/test_entity.py b/tests/components/esphome/test_entity.py new file mode 100644 index 00000000000..39bfec852e7 --- /dev/null +++ b/tests/components/esphome/test_entity.py @@ -0,0 +1,103 @@ +"""Test ESPHome binary sensors.""" +from collections.abc import Awaitable, Callable +from typing import Any + +from aioesphomeapi import ( + APIClient, + BinarySensorInfo, + BinarySensorState, + EntityInfo, + EntityState, + UserService, +) + +from homeassistant.const import ATTR_RESTORED, STATE_ON +from homeassistant.core import HomeAssistant + +from .conftest import MockESPHomeDevice + + +async def test_entities_removed( + hass: HomeAssistant, + mock_client: APIClient, + hass_storage: dict[str, Any], + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], +) -> None: + """Test a generic binary_sensor where has_state is false.""" + entity_info = [ + BinarySensorInfo( + object_id="mybinary_sensor", + key=1, + name="my binary_sensor", + unique_id="my_binary_sensor", + ), + BinarySensorInfo( + object_id="mybinary_sensor_to_be_removed", + key=2, + name="my binary_sensor to be removed", + unique_id="mybinary_sensor_to_be_removed", + ), + ] + states = [ + BinarySensorState(key=1, state=True, missing_state=False), + BinarySensorState(key=2, state=True, missing_state=False), + ] + user_service = [] + mock_device = await mock_esphome_device( + mock_client=mock_client, + entity_info=entity_info, + user_service=user_service, + states=states, + ) + entry = mock_device.entry + entry_id = entry.entry_id + storage_key = f"esphome.{entry_id}" + state = hass.states.get("binary_sensor.test_my_binary_sensor") + assert state is not None + assert state.state == STATE_ON + state = hass.states.get("binary_sensor.test_my_binary_sensor_to_be_removed") + assert state is not None + assert state.state == STATE_ON + + await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + + assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 2 + + state = hass.states.get("binary_sensor.test_my_binary_sensor") + assert state is not None + assert state.attributes[ATTR_RESTORED] is True + state = hass.states.get("binary_sensor.test_my_binary_sensor_to_be_removed") + assert state is not None + assert state.attributes[ATTR_RESTORED] is True + + entity_info = [ + BinarySensorInfo( + object_id="mybinary_sensor", + key=1, + name="my binary_sensor", + unique_id="my_binary_sensor", + ), + ] + states = [ + BinarySensorState(key=1, state=True, missing_state=False), + ] + mock_device = await mock_esphome_device( + mock_client=mock_client, + entity_info=entity_info, + user_service=user_service, + states=states, + entry=entry, + ) + assert mock_device.entry.entry_id == entry_id + state = hass.states.get("binary_sensor.test_my_binary_sensor") + assert state is not None + assert state.state == STATE_ON + state = hass.states.get("binary_sensor.test_my_binary_sensor_to_be_removed") + assert state is None + await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done() + assert len(hass_storage[storage_key]["data"]["binary_sensor"]) == 1