ESPHome delete store data when unloading entry (#52296)

This commit is contained in:
Otto Winter 2021-06-30 00:06:24 +02:00 committed by GitHub
parent cca5964ac0
commit f772eab7b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 83 additions and 25 deletions

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
import functools
import logging
import math
@ -54,10 +55,55 @@ _T = TypeVar("_T")
STORAGE_VERSION = 1
@dataclass
class DomainData:
"""Define a class that stores global esphome data in hass.data[DOMAIN]."""
_entry_datas: dict[str, RuntimeEntryData] = field(default_factory=dict)
_stores: dict[str, Store] = field(default_factory=dict)
def get_entry_data(self, entry: ConfigEntry) -> RuntimeEntryData:
"""Return the runtime entry data associated with this config entry.
Raises KeyError if the entry isn't loaded yet.
"""
return self._entry_datas[entry.entry_id]
def set_entry_data(self, entry: ConfigEntry, entry_data: RuntimeEntryData) -> None:
"""Set the runtime entry data associated with this config entry."""
if entry.entry_id in self._entry_datas:
raise ValueError("Entry data for this entry is already set")
self._entry_datas[entry.entry_id] = entry_data
def pop_entry_data(self, entry: ConfigEntry) -> RuntimeEntryData:
"""Pop the runtime entry data instance associated with this config entry."""
return self._entry_datas.pop(entry.entry_id)
def is_entry_loaded(self, entry: ConfigEntry) -> bool:
"""Check whether the given entry is loaded."""
return entry.entry_id in self._entry_datas
def get_or_create_store(self, hass: HomeAssistant, entry: ConfigEntry) -> Store:
"""Get or create a Store instance for the given config entry."""
return self._stores.setdefault(
entry.entry_id,
Store(
hass, STORAGE_VERSION, f"esphome.{entry.entry_id}", encoder=JSONEncoder
),
)
@classmethod
def get(cls: type[_T], hass: HomeAssistant) -> _T:
"""Get the global DomainData instance stored in hass.data."""
# Don't use setdefault - this is a hot code path
if DOMAIN in hass.data:
return hass.data[DOMAIN]
ret = hass.data[DOMAIN] = cls()
return ret
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up the esphome component."""
hass.data.setdefault(DOMAIN, {})
host = entry.data[CONF_HOST]
port = entry.data[CONF_PORT]
password = entry.data[CONF_PASSWORD]
@ -74,13 +120,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
zeroconf_instance=zeroconf_instance,
)
# Store client in per-config-entry hass.data
store = Store(
hass, STORAGE_VERSION, f"esphome.{entry.entry_id}", encoder=JSONEncoder
)
entry_data = hass.data[DOMAIN][entry.entry_id] = RuntimeEntryData(
client=cli, entry_id=entry.entry_id, store=store
domain_data = DomainData.get(hass)
entry_data = RuntimeEntryData(
client=cli,
entry_id=entry.entry_id,
store=domain_data.get_or_create_store(hass, entry),
)
domain_data.set_entry_data(entry, entry_data)
async def on_stop(event: Event) -> None:
"""Cleanup the socket client on HA stop."""
@ -286,7 +332,11 @@ class ReconnectLogic(RecordUpdateListener):
@property
def _entry_data(self) -> RuntimeEntryData | None:
return self._hass.data[DOMAIN].get(self._entry.entry_id)
domain_data = DomainData.get(self._hass)
try:
return domain_data.get_entry_data(self._entry)
except KeyError:
return None
async def _on_disconnect(self):
"""Log and issue callbacks when disconnecting."""
@ -382,7 +432,7 @@ class ReconnectLogic(RecordUpdateListener):
return False
# Check if the entry got removed or disabled, in which case we shouldn't reconnect
if self._entry.entry_id not in self._hass.data[DOMAIN]:
if not DomainData.get(self._hass).is_entry_loaded(self._entry):
# When removing/disconnecting manually
return
@ -615,7 +665,8 @@ async def _cleanup_instance(
hass: HomeAssistant, entry: ConfigEntry
) -> RuntimeEntryData:
"""Cleanup the esphome client if it exists."""
data: RuntimeEntryData = hass.data[DOMAIN].pop(entry.entry_id)
domain_data = DomainData.get(hass)
data = domain_data.pop_entry_data(entry)
for disconnect_cb in data.disconnect_callbacks:
disconnect_cb()
for cleanup_callback in data.cleanup_callbacks:
@ -632,6 +683,11 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
)
async def async_remove_entry(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Remove an esphome config entry."""
await DomainData.get(hass).get_or_create_store(hass, entry).async_remove()
async def platform_async_setup_entry(
hass: HomeAssistant,
entry: ConfigEntry,
@ -647,7 +703,7 @@ async def platform_async_setup_entry(
This method is in charge of receiving, distributing and storing
info and state updates.
"""
entry_data: RuntimeEntryData = hass.data[DOMAIN][entry.entry_id]
entry_data: RuntimeEntryData = DomainData.get(hass).get_entry_data(entry)
entry_data.info[component_key] = {}
entry_data.old_info[component_key] = {}
entry_data.state[component_key] = {}
@ -668,7 +724,7 @@ async def platform_async_setup_entry(
old_infos.pop(info.key)
else:
# Create new entity
entity = entity_type(entry.entry_id, component_key, info.key)
entity = entity_type(entry_data, component_key, info.key)
add_entities.append(entity)
new_infos[info.key] = info
@ -746,9 +802,11 @@ class EsphomeEnumMapper(Generic[_T]):
class EsphomeBaseEntity(Entity):
"""Define a base esphome entity."""
def __init__(self, entry_id: str, component_key: str, key: int) -> None:
def __init__(
self, entry_data: RuntimeEntryData, component_key: str, key: int
) -> None:
"""Initialize."""
self._entry_id = entry_id
self._entry_data = entry_data
self._component_key = component_key
self._key = key
@ -784,8 +842,8 @@ class EsphomeBaseEntity(Entity):
self.async_write_ha_state()
@property
def _entry_data(self) -> RuntimeEntryData:
return self.hass.data[DOMAIN][self._entry_id]
def _entry_id(self) -> str:
return self._entry_data.entry_id
@property
def _api_version(self) -> APIVersion:

View file

@ -32,10 +32,10 @@ async def async_setup_entry(
class EsphomeCamera(Camera, EsphomeBaseEntity):
"""A camera implementation for ESPHome."""
def __init__(self, entry_id: str, component_key: str, key: int) -> None:
def __init__(self, *args, **kwargs) -> None:
"""Initialize."""
Camera.__init__(self)
EsphomeBaseEntity.__init__(self, entry_id, component_key, key)
EsphomeBaseEntity.__init__(self, *args, **kwargs)
self._image_cond = asyncio.Condition()
@property

View file

@ -12,8 +12,7 @@ from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PASSWORD, CONF_PORT
from homeassistant.core import callback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import DOMAIN
from .entry_data import RuntimeEntryData
from . import DOMAIN, DomainData
class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN):
@ -104,9 +103,9 @@ class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN):
]:
# Is this address or IP address already configured?
already_configured = True
elif entry.entry_id in self.hass.data.get(DOMAIN, {}):
elif DomainData.get(self.hass).is_entry_loaded(entry):
# Does a config entry with this name already exist?
data: RuntimeEntryData = self.hass.data[DOMAIN][entry.entry_id]
data = DomainData.get(self.hass).get_entry_data(entry)
# Node names are unique in the network
if data.device_info is not None:

View file

@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from homeassistant import config_entries
from homeassistant.components.esphome import DOMAIN
from homeassistant.components.esphome import DOMAIN, DomainData
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT
from homeassistant.data_entry_flow import (
RESULT_TYPE_ABORT,
@ -265,7 +265,8 @@ async def test_discovery_already_configured_name(hass, mock_client):
mock_entry_data = MagicMock()
mock_entry_data.device_info.name = "test8266"
hass.data[DOMAIN] = {entry.entry_id: mock_entry_data}
domain_data = DomainData.get(hass)
domain_data.set_entry_data(entry, mock_entry_data)
service_info = {
"host": "192.168.43.184",