ESPHome delete store data when unloading entry (#52296)

This commit is contained in:
Otto Winter 2021-06-30 00:06:24 +02:00 committed by GitHub
parent cca5964ac0
commit f772eab7b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 83 additions and 25 deletions

View file

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from dataclasses import dataclass, field
import functools import functools
import logging import logging
import math import math
@ -54,10 +55,55 @@ _T = TypeVar("_T")
STORAGE_VERSION = 1 STORAGE_VERSION = 1
@dataclass
class DomainData:
"""Define a class that stores global esphome data in hass.data[DOMAIN]."""
_entry_datas: dict[str, RuntimeEntryData] = field(default_factory=dict)
_stores: dict[str, Store] = field(default_factory=dict)
def get_entry_data(self, entry: ConfigEntry) -> RuntimeEntryData:
"""Return the runtime entry data associated with this config entry.
Raises KeyError if the entry isn't loaded yet.
"""
return self._entry_datas[entry.entry_id]
def set_entry_data(self, entry: ConfigEntry, entry_data: RuntimeEntryData) -> None:
"""Set the runtime entry data associated with this config entry."""
if entry.entry_id in self._entry_datas:
raise ValueError("Entry data for this entry is already set")
self._entry_datas[entry.entry_id] = entry_data
def pop_entry_data(self, entry: ConfigEntry) -> RuntimeEntryData:
"""Pop the runtime entry data instance associated with this config entry."""
return self._entry_datas.pop(entry.entry_id)
def is_entry_loaded(self, entry: ConfigEntry) -> bool:
"""Check whether the given entry is loaded."""
return entry.entry_id in self._entry_datas
def get_or_create_store(self, hass: HomeAssistant, entry: ConfigEntry) -> Store:
"""Get or create a Store instance for the given config entry."""
return self._stores.setdefault(
entry.entry_id,
Store(
hass, STORAGE_VERSION, f"esphome.{entry.entry_id}", encoder=JSONEncoder
),
)
@classmethod
def get(cls: type[_T], hass: HomeAssistant) -> _T:
"""Get the global DomainData instance stored in hass.data."""
# Don't use setdefault - this is a hot code path
if DOMAIN in hass.data:
return hass.data[DOMAIN]
ret = hass.data[DOMAIN] = cls()
return ret
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up the esphome component.""" """Set up the esphome component."""
hass.data.setdefault(DOMAIN, {})
host = entry.data[CONF_HOST] host = entry.data[CONF_HOST]
port = entry.data[CONF_PORT] port = entry.data[CONF_PORT]
password = entry.data[CONF_PASSWORD] password = entry.data[CONF_PASSWORD]
@ -74,13 +120,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
zeroconf_instance=zeroconf_instance, zeroconf_instance=zeroconf_instance,
) )
# Store client in per-config-entry hass.data domain_data = DomainData.get(hass)
store = Store( entry_data = RuntimeEntryData(
hass, STORAGE_VERSION, f"esphome.{entry.entry_id}", encoder=JSONEncoder client=cli,
) entry_id=entry.entry_id,
entry_data = hass.data[DOMAIN][entry.entry_id] = RuntimeEntryData( store=domain_data.get_or_create_store(hass, entry),
client=cli, entry_id=entry.entry_id, store=store
) )
domain_data.set_entry_data(entry, entry_data)
async def on_stop(event: Event) -> None: async def on_stop(event: Event) -> None:
"""Cleanup the socket client on HA stop.""" """Cleanup the socket client on HA stop."""
@ -286,7 +332,11 @@ class ReconnectLogic(RecordUpdateListener):
@property @property
def _entry_data(self) -> RuntimeEntryData | None: def _entry_data(self) -> RuntimeEntryData | None:
return self._hass.data[DOMAIN].get(self._entry.entry_id) domain_data = DomainData.get(self._hass)
try:
return domain_data.get_entry_data(self._entry)
except KeyError:
return None
async def _on_disconnect(self): async def _on_disconnect(self):
"""Log and issue callbacks when disconnecting.""" """Log and issue callbacks when disconnecting."""
@ -382,7 +432,7 @@ class ReconnectLogic(RecordUpdateListener):
return False return False
# Check if the entry got removed or disabled, in which case we shouldn't reconnect # Check if the entry got removed or disabled, in which case we shouldn't reconnect
if self._entry.entry_id not in self._hass.data[DOMAIN]: if not DomainData.get(self._hass).is_entry_loaded(self._entry):
# When removing/disconnecting manually # When removing/disconnecting manually
return return
@ -615,7 +665,8 @@ async def _cleanup_instance(
hass: HomeAssistant, entry: ConfigEntry hass: HomeAssistant, entry: ConfigEntry
) -> RuntimeEntryData: ) -> RuntimeEntryData:
"""Cleanup the esphome client if it exists.""" """Cleanup the esphome client if it exists."""
data: RuntimeEntryData = hass.data[DOMAIN].pop(entry.entry_id) domain_data = DomainData.get(hass)
data = domain_data.pop_entry_data(entry)
for disconnect_cb in data.disconnect_callbacks: for disconnect_cb in data.disconnect_callbacks:
disconnect_cb() disconnect_cb()
for cleanup_callback in data.cleanup_callbacks: for cleanup_callback in data.cleanup_callbacks:
@ -632,6 +683,11 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
) )
async def async_remove_entry(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Remove an esphome config entry."""
await DomainData.get(hass).get_or_create_store(hass, entry).async_remove()
async def platform_async_setup_entry( async def platform_async_setup_entry(
hass: HomeAssistant, hass: HomeAssistant,
entry: ConfigEntry, entry: ConfigEntry,
@ -647,7 +703,7 @@ async def platform_async_setup_entry(
This method is in charge of receiving, distributing and storing This method is in charge of receiving, distributing and storing
info and state updates. info and state updates.
""" """
entry_data: RuntimeEntryData = hass.data[DOMAIN][entry.entry_id] entry_data: RuntimeEntryData = DomainData.get(hass).get_entry_data(entry)
entry_data.info[component_key] = {} entry_data.info[component_key] = {}
entry_data.old_info[component_key] = {} entry_data.old_info[component_key] = {}
entry_data.state[component_key] = {} entry_data.state[component_key] = {}
@ -668,7 +724,7 @@ async def platform_async_setup_entry(
old_infos.pop(info.key) old_infos.pop(info.key)
else: else:
# Create new entity # Create new entity
entity = entity_type(entry.entry_id, component_key, info.key) entity = entity_type(entry_data, component_key, info.key)
add_entities.append(entity) add_entities.append(entity)
new_infos[info.key] = info new_infos[info.key] = info
@ -746,9 +802,11 @@ class EsphomeEnumMapper(Generic[_T]):
class EsphomeBaseEntity(Entity): class EsphomeBaseEntity(Entity):
"""Define a base esphome entity.""" """Define a base esphome entity."""
def __init__(self, entry_id: str, component_key: str, key: int) -> None: def __init__(
self, entry_data: RuntimeEntryData, component_key: str, key: int
) -> None:
"""Initialize.""" """Initialize."""
self._entry_id = entry_id self._entry_data = entry_data
self._component_key = component_key self._component_key = component_key
self._key = key self._key = key
@ -784,8 +842,8 @@ class EsphomeBaseEntity(Entity):
self.async_write_ha_state() self.async_write_ha_state()
@property @property
def _entry_data(self) -> RuntimeEntryData: def _entry_id(self) -> str:
return self.hass.data[DOMAIN][self._entry_id] return self._entry_data.entry_id
@property @property
def _api_version(self) -> APIVersion: def _api_version(self) -> APIVersion:

View file

@ -32,10 +32,10 @@ async def async_setup_entry(
class EsphomeCamera(Camera, EsphomeBaseEntity): class EsphomeCamera(Camera, EsphomeBaseEntity):
"""A camera implementation for ESPHome.""" """A camera implementation for ESPHome."""
def __init__(self, entry_id: str, component_key: str, key: int) -> None: def __init__(self, *args, **kwargs) -> None:
"""Initialize.""" """Initialize."""
Camera.__init__(self) Camera.__init__(self)
EsphomeBaseEntity.__init__(self, entry_id, component_key, key) EsphomeBaseEntity.__init__(self, *args, **kwargs)
self._image_cond = asyncio.Condition() self._image_cond = asyncio.Condition()
@property @property

View file

@ -12,8 +12,7 @@ from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PASSWORD, CONF_PORT
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import DOMAIN from . import DOMAIN, DomainData
from .entry_data import RuntimeEntryData
class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN): class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN):
@ -104,9 +103,9 @@ class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN):
]: ]:
# Is this address or IP address already configured? # Is this address or IP address already configured?
already_configured = True already_configured = True
elif entry.entry_id in self.hass.data.get(DOMAIN, {}): elif DomainData.get(self.hass).is_entry_loaded(entry):
# Does a config entry with this name already exist? # Does a config entry with this name already exist?
data: RuntimeEntryData = self.hass.data[DOMAIN][entry.entry_id] data = DomainData.get(self.hass).get_entry_data(entry)
# Node names are unique in the network # Node names are unique in the network
if data.device_info is not None: if data.device_info is not None:

View file

@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.esphome import DOMAIN from homeassistant.components.esphome import DOMAIN, DomainData
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT
from homeassistant.data_entry_flow import ( from homeassistant.data_entry_flow import (
RESULT_TYPE_ABORT, RESULT_TYPE_ABORT,
@ -265,7 +265,8 @@ async def test_discovery_already_configured_name(hass, mock_client):
mock_entry_data = MagicMock() mock_entry_data = MagicMock()
mock_entry_data.device_info.name = "test8266" mock_entry_data.device_info.name = "test8266"
hass.data[DOMAIN] = {entry.entry_id: mock_entry_data} domain_data = DomainData.get(hass)
domain_data.set_entry_data(entry, mock_entry_data)
service_info = { service_info = {
"host": "192.168.43.184", "host": "192.168.43.184",