From f772eab7b77c70919ef2cfb335fcf1bf17a7a53c Mon Sep 17 00:00:00 2001 From: Otto Winter Date: Wed, 30 Jun 2021 00:06:24 +0200 Subject: [PATCH] ESPHome delete store data when unloading entry (#52296) --- homeassistant/components/esphome/__init__.py | 92 +++++++++++++++---- homeassistant/components/esphome/camera.py | 4 +- .../components/esphome/config_flow.py | 7 +- tests/components/esphome/test_config_flow.py | 5 +- 4 files changed, 83 insertions(+), 25 deletions(-) diff --git a/homeassistant/components/esphome/__init__.py b/homeassistant/components/esphome/__init__.py index b91197b019c..aa7da100505 100644 --- a/homeassistant/components/esphome/__init__.py +++ b/homeassistant/components/esphome/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +from dataclasses import dataclass, field import functools import logging import math @@ -54,10 +55,55 @@ _T = TypeVar("_T") STORAGE_VERSION = 1 +@dataclass +class DomainData: + """Define a class that stores global esphome data in hass.data[DOMAIN].""" + + _entry_datas: dict[str, RuntimeEntryData] = field(default_factory=dict) + _stores: dict[str, Store] = field(default_factory=dict) + + def get_entry_data(self, entry: ConfigEntry) -> RuntimeEntryData: + """Return the runtime entry data associated with this config entry. + + Raises KeyError if the entry isn't loaded yet. + """ + return self._entry_datas[entry.entry_id] + + def set_entry_data(self, entry: ConfigEntry, entry_data: RuntimeEntryData) -> None: + """Set the runtime entry data associated with this config entry.""" + if entry.entry_id in self._entry_datas: + raise ValueError("Entry data for this entry is already set") + self._entry_datas[entry.entry_id] = entry_data + + def pop_entry_data(self, entry: ConfigEntry) -> RuntimeEntryData: + """Pop the runtime entry data instance associated with this config entry.""" + return self._entry_datas.pop(entry.entry_id) + + def is_entry_loaded(self, entry: ConfigEntry) -> bool: + """Check whether the given entry is loaded.""" + return entry.entry_id in self._entry_datas + + def get_or_create_store(self, hass: HomeAssistant, entry: ConfigEntry) -> Store: + """Get or create a Store instance for the given config entry.""" + return self._stores.setdefault( + entry.entry_id, + Store( + hass, STORAGE_VERSION, f"esphome.{entry.entry_id}", encoder=JSONEncoder + ), + ) + + @classmethod + def get(cls: type[_T], hass: HomeAssistant) -> _T: + """Get the global DomainData instance stored in hass.data.""" + # Don't use setdefault - this is a hot code path + if DOMAIN in hass.data: + return hass.data[DOMAIN] + ret = hass.data[DOMAIN] = cls() + return ret + + async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up the esphome component.""" - hass.data.setdefault(DOMAIN, {}) - host = entry.data[CONF_HOST] port = entry.data[CONF_PORT] password = entry.data[CONF_PASSWORD] @@ -74,13 +120,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: zeroconf_instance=zeroconf_instance, ) - # Store client in per-config-entry hass.data - store = Store( - hass, STORAGE_VERSION, f"esphome.{entry.entry_id}", encoder=JSONEncoder - ) - entry_data = hass.data[DOMAIN][entry.entry_id] = RuntimeEntryData( - client=cli, entry_id=entry.entry_id, store=store + domain_data = DomainData.get(hass) + entry_data = RuntimeEntryData( + client=cli, + entry_id=entry.entry_id, + store=domain_data.get_or_create_store(hass, entry), ) + domain_data.set_entry_data(entry, entry_data) async def on_stop(event: Event) -> None: """Cleanup the socket client on HA stop.""" @@ -286,7 +332,11 @@ class ReconnectLogic(RecordUpdateListener): @property def _entry_data(self) -> RuntimeEntryData | None: - return self._hass.data[DOMAIN].get(self._entry.entry_id) + domain_data = DomainData.get(self._hass) + try: + return domain_data.get_entry_data(self._entry) + except KeyError: + return None async def _on_disconnect(self): """Log and issue callbacks when disconnecting.""" @@ -382,7 +432,7 @@ class ReconnectLogic(RecordUpdateListener): return False # Check if the entry got removed or disabled, in which case we shouldn't reconnect - if self._entry.entry_id not in self._hass.data[DOMAIN]: + if not DomainData.get(self._hass).is_entry_loaded(self._entry): # When removing/disconnecting manually return @@ -615,7 +665,8 @@ async def _cleanup_instance( hass: HomeAssistant, entry: ConfigEntry ) -> RuntimeEntryData: """Cleanup the esphome client if it exists.""" - data: RuntimeEntryData = hass.data[DOMAIN].pop(entry.entry_id) + domain_data = DomainData.get(hass) + data = domain_data.pop_entry_data(entry) for disconnect_cb in data.disconnect_callbacks: disconnect_cb() for cleanup_callback in data.cleanup_callbacks: @@ -632,6 +683,11 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ) +async def async_remove_entry(hass: HomeAssistant, entry: ConfigEntry) -> None: + """Remove an esphome config entry.""" + await DomainData.get(hass).get_or_create_store(hass, entry).async_remove() + + async def platform_async_setup_entry( hass: HomeAssistant, entry: ConfigEntry, @@ -647,7 +703,7 @@ async def platform_async_setup_entry( This method is in charge of receiving, distributing and storing info and state updates. """ - entry_data: RuntimeEntryData = hass.data[DOMAIN][entry.entry_id] + entry_data: RuntimeEntryData = DomainData.get(hass).get_entry_data(entry) entry_data.info[component_key] = {} entry_data.old_info[component_key] = {} entry_data.state[component_key] = {} @@ -668,7 +724,7 @@ async def platform_async_setup_entry( old_infos.pop(info.key) else: # Create new entity - entity = entity_type(entry.entry_id, component_key, info.key) + entity = entity_type(entry_data, component_key, info.key) add_entities.append(entity) new_infos[info.key] = info @@ -746,9 +802,11 @@ class EsphomeEnumMapper(Generic[_T]): class EsphomeBaseEntity(Entity): """Define a base esphome entity.""" - def __init__(self, entry_id: str, component_key: str, key: int) -> None: + def __init__( + self, entry_data: RuntimeEntryData, component_key: str, key: int + ) -> None: """Initialize.""" - self._entry_id = entry_id + self._entry_data = entry_data self._component_key = component_key self._key = key @@ -784,8 +842,8 @@ class EsphomeBaseEntity(Entity): self.async_write_ha_state() @property - def _entry_data(self) -> RuntimeEntryData: - return self.hass.data[DOMAIN][self._entry_id] + def _entry_id(self) -> str: + return self._entry_data.entry_id @property def _api_version(self) -> APIVersion: diff --git a/homeassistant/components/esphome/camera.py b/homeassistant/components/esphome/camera.py index 6b553de1a13..7afd89bf9be 100644 --- a/homeassistant/components/esphome/camera.py +++ b/homeassistant/components/esphome/camera.py @@ -32,10 +32,10 @@ async def async_setup_entry( class EsphomeCamera(Camera, EsphomeBaseEntity): """A camera implementation for ESPHome.""" - def __init__(self, entry_id: str, component_key: str, key: int) -> None: + def __init__(self, *args, **kwargs) -> None: """Initialize.""" Camera.__init__(self) - EsphomeBaseEntity.__init__(self, entry_id, component_key, key) + EsphomeBaseEntity.__init__(self, *args, **kwargs) self._image_cond = asyncio.Condition() @property diff --git a/homeassistant/components/esphome/config_flow.py b/homeassistant/components/esphome/config_flow.py index e31fa202a39..38e44b12508 100644 --- a/homeassistant/components/esphome/config_flow.py +++ b/homeassistant/components/esphome/config_flow.py @@ -12,8 +12,7 @@ from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PASSWORD, CONF_PORT from homeassistant.core import callback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import DOMAIN -from .entry_data import RuntimeEntryData +from . import DOMAIN, DomainData class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN): @@ -104,9 +103,9 @@ class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN): ]: # Is this address or IP address already configured? already_configured = True - elif entry.entry_id in self.hass.data.get(DOMAIN, {}): + elif DomainData.get(self.hass).is_entry_loaded(entry): # Does a config entry with this name already exist? - data: RuntimeEntryData = self.hass.data[DOMAIN][entry.entry_id] + data = DomainData.get(self.hass).get_entry_data(entry) # Node names are unique in the network if data.device_info is not None: diff --git a/tests/components/esphome/test_config_flow.py b/tests/components/esphome/test_config_flow.py index d5968e7f731..735a02e960c 100644 --- a/tests/components/esphome/test_config_flow.py +++ b/tests/components/esphome/test_config_flow.py @@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from homeassistant import config_entries -from homeassistant.components.esphome import DOMAIN +from homeassistant.components.esphome import DOMAIN, DomainData from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT from homeassistant.data_entry_flow import ( RESULT_TYPE_ABORT, @@ -265,7 +265,8 @@ async def test_discovery_already_configured_name(hass, mock_client): mock_entry_data = MagicMock() mock_entry_data.device_info.name = "test8266" - hass.data[DOMAIN] = {entry.entry_id: mock_entry_data} + domain_data = DomainData.get(hass) + domain_data.set_entry_data(entry, mock_entry_data) service_info = { "host": "192.168.43.184",