ESPHome delete store data when unloading entry (#52296)
This commit is contained in:
parent
cca5964ac0
commit
f772eab7b7
4 changed files with 83 additions and 25 deletions
|
@ -2,6 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
|
@ -54,10 +55,55 @@ _T = TypeVar("_T")
|
|||
STORAGE_VERSION = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class DomainData:
|
||||
"""Define a class that stores global esphome data in hass.data[DOMAIN]."""
|
||||
|
||||
_entry_datas: dict[str, RuntimeEntryData] = field(default_factory=dict)
|
||||
_stores: dict[str, Store] = field(default_factory=dict)
|
||||
|
||||
def get_entry_data(self, entry: ConfigEntry) -> RuntimeEntryData:
|
||||
"""Return the runtime entry data associated with this config entry.
|
||||
|
||||
Raises KeyError if the entry isn't loaded yet.
|
||||
"""
|
||||
return self._entry_datas[entry.entry_id]
|
||||
|
||||
def set_entry_data(self, entry: ConfigEntry, entry_data: RuntimeEntryData) -> None:
|
||||
"""Set the runtime entry data associated with this config entry."""
|
||||
if entry.entry_id in self._entry_datas:
|
||||
raise ValueError("Entry data for this entry is already set")
|
||||
self._entry_datas[entry.entry_id] = entry_data
|
||||
|
||||
def pop_entry_data(self, entry: ConfigEntry) -> RuntimeEntryData:
|
||||
"""Pop the runtime entry data instance associated with this config entry."""
|
||||
return self._entry_datas.pop(entry.entry_id)
|
||||
|
||||
def is_entry_loaded(self, entry: ConfigEntry) -> bool:
|
||||
"""Check whether the given entry is loaded."""
|
||||
return entry.entry_id in self._entry_datas
|
||||
|
||||
def get_or_create_store(self, hass: HomeAssistant, entry: ConfigEntry) -> Store:
|
||||
"""Get or create a Store instance for the given config entry."""
|
||||
return self._stores.setdefault(
|
||||
entry.entry_id,
|
||||
Store(
|
||||
hass, STORAGE_VERSION, f"esphome.{entry.entry_id}", encoder=JSONEncoder
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get(cls: type[_T], hass: HomeAssistant) -> _T:
|
||||
"""Get the global DomainData instance stored in hass.data."""
|
||||
# Don't use setdefault - this is a hot code path
|
||||
if DOMAIN in hass.data:
|
||||
return hass.data[DOMAIN]
|
||||
ret = hass.data[DOMAIN] = cls()
|
||||
return ret
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up the esphome component."""
|
||||
hass.data.setdefault(DOMAIN, {})
|
||||
|
||||
host = entry.data[CONF_HOST]
|
||||
port = entry.data[CONF_PORT]
|
||||
password = entry.data[CONF_PASSWORD]
|
||||
|
@ -74,13 +120,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
zeroconf_instance=zeroconf_instance,
|
||||
)
|
||||
|
||||
# Store client in per-config-entry hass.data
|
||||
store = Store(
|
||||
hass, STORAGE_VERSION, f"esphome.{entry.entry_id}", encoder=JSONEncoder
|
||||
)
|
||||
entry_data = hass.data[DOMAIN][entry.entry_id] = RuntimeEntryData(
|
||||
client=cli, entry_id=entry.entry_id, store=store
|
||||
domain_data = DomainData.get(hass)
|
||||
entry_data = RuntimeEntryData(
|
||||
client=cli,
|
||||
entry_id=entry.entry_id,
|
||||
store=domain_data.get_or_create_store(hass, entry),
|
||||
)
|
||||
domain_data.set_entry_data(entry, entry_data)
|
||||
|
||||
async def on_stop(event: Event) -> None:
|
||||
"""Cleanup the socket client on HA stop."""
|
||||
|
@ -286,7 +332,11 @@ class ReconnectLogic(RecordUpdateListener):
|
|||
|
||||
@property
|
||||
def _entry_data(self) -> RuntimeEntryData | None:
|
||||
return self._hass.data[DOMAIN].get(self._entry.entry_id)
|
||||
domain_data = DomainData.get(self._hass)
|
||||
try:
|
||||
return domain_data.get_entry_data(self._entry)
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
async def _on_disconnect(self):
|
||||
"""Log and issue callbacks when disconnecting."""
|
||||
|
@ -382,7 +432,7 @@ class ReconnectLogic(RecordUpdateListener):
|
|||
return False
|
||||
|
||||
# Check if the entry got removed or disabled, in which case we shouldn't reconnect
|
||||
if self._entry.entry_id not in self._hass.data[DOMAIN]:
|
||||
if not DomainData.get(self._hass).is_entry_loaded(self._entry):
|
||||
# When removing/disconnecting manually
|
||||
return
|
||||
|
||||
|
@ -615,7 +665,8 @@ async def _cleanup_instance(
|
|||
hass: HomeAssistant, entry: ConfigEntry
|
||||
) -> RuntimeEntryData:
|
||||
"""Cleanup the esphome client if it exists."""
|
||||
data: RuntimeEntryData = hass.data[DOMAIN].pop(entry.entry_id)
|
||||
domain_data = DomainData.get(hass)
|
||||
data = domain_data.pop_entry_data(entry)
|
||||
for disconnect_cb in data.disconnect_callbacks:
|
||||
disconnect_cb()
|
||||
for cleanup_callback in data.cleanup_callbacks:
|
||||
|
@ -632,6 +683,11 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
)
|
||||
|
||||
|
||||
async def async_remove_entry(hass: HomeAssistant, entry: ConfigEntry) -> None:
|
||||
"""Remove an esphome config entry."""
|
||||
await DomainData.get(hass).get_or_create_store(hass, entry).async_remove()
|
||||
|
||||
|
||||
async def platform_async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
entry: ConfigEntry,
|
||||
|
@ -647,7 +703,7 @@ async def platform_async_setup_entry(
|
|||
This method is in charge of receiving, distributing and storing
|
||||
info and state updates.
|
||||
"""
|
||||
entry_data: RuntimeEntryData = hass.data[DOMAIN][entry.entry_id]
|
||||
entry_data: RuntimeEntryData = DomainData.get(hass).get_entry_data(entry)
|
||||
entry_data.info[component_key] = {}
|
||||
entry_data.old_info[component_key] = {}
|
||||
entry_data.state[component_key] = {}
|
||||
|
@ -668,7 +724,7 @@ async def platform_async_setup_entry(
|
|||
old_infos.pop(info.key)
|
||||
else:
|
||||
# Create new entity
|
||||
entity = entity_type(entry.entry_id, component_key, info.key)
|
||||
entity = entity_type(entry_data, component_key, info.key)
|
||||
add_entities.append(entity)
|
||||
new_infos[info.key] = info
|
||||
|
||||
|
@ -746,9 +802,11 @@ class EsphomeEnumMapper(Generic[_T]):
|
|||
class EsphomeBaseEntity(Entity):
|
||||
"""Define a base esphome entity."""
|
||||
|
||||
def __init__(self, entry_id: str, component_key: str, key: int) -> None:
|
||||
def __init__(
|
||||
self, entry_data: RuntimeEntryData, component_key: str, key: int
|
||||
) -> None:
|
||||
"""Initialize."""
|
||||
self._entry_id = entry_id
|
||||
self._entry_data = entry_data
|
||||
self._component_key = component_key
|
||||
self._key = key
|
||||
|
||||
|
@ -784,8 +842,8 @@ class EsphomeBaseEntity(Entity):
|
|||
self.async_write_ha_state()
|
||||
|
||||
@property
|
||||
def _entry_data(self) -> RuntimeEntryData:
|
||||
return self.hass.data[DOMAIN][self._entry_id]
|
||||
def _entry_id(self) -> str:
|
||||
return self._entry_data.entry_id
|
||||
|
||||
@property
|
||||
def _api_version(self) -> APIVersion:
|
||||
|
|
|
@ -32,10 +32,10 @@ async def async_setup_entry(
|
|||
class EsphomeCamera(Camera, EsphomeBaseEntity):
|
||||
"""A camera implementation for ESPHome."""
|
||||
|
||||
def __init__(self, entry_id: str, component_key: str, key: int) -> None:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
"""Initialize."""
|
||||
Camera.__init__(self)
|
||||
EsphomeBaseEntity.__init__(self, entry_id, component_key, key)
|
||||
EsphomeBaseEntity.__init__(self, *args, **kwargs)
|
||||
self._image_cond = asyncio.Condition()
|
||||
|
||||
@property
|
||||
|
|
|
@ -12,8 +12,7 @@ from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PASSWORD, CONF_PORT
|
|||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from . import DOMAIN
|
||||
from .entry_data import RuntimeEntryData
|
||||
from . import DOMAIN, DomainData
|
||||
|
||||
|
||||
class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN):
|
||||
|
@ -104,9 +103,9 @@ class EsphomeFlowHandler(ConfigFlow, domain=DOMAIN):
|
|||
]:
|
||||
# Is this address or IP address already configured?
|
||||
already_configured = True
|
||||
elif entry.entry_id in self.hass.data.get(DOMAIN, {}):
|
||||
elif DomainData.get(self.hass).is_entry_loaded(entry):
|
||||
# Does a config entry with this name already exist?
|
||||
data: RuntimeEntryData = self.hass.data[DOMAIN][entry.entry_id]
|
||||
data = DomainData.get(self.hass).get_entry_data(entry)
|
||||
|
||||
# Node names are unique in the network
|
||||
if data.device_info is not None:
|
||||
|
|
|
@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
import pytest
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.esphome import DOMAIN
|
||||
from homeassistant.components.esphome import DOMAIN, DomainData
|
||||
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT
|
||||
from homeassistant.data_entry_flow import (
|
||||
RESULT_TYPE_ABORT,
|
||||
|
@ -265,7 +265,8 @@ async def test_discovery_already_configured_name(hass, mock_client):
|
|||
|
||||
mock_entry_data = MagicMock()
|
||||
mock_entry_data.device_info.name = "test8266"
|
||||
hass.data[DOMAIN] = {entry.entry_id: mock_entry_data}
|
||||
domain_data = DomainData.get(hass)
|
||||
domain_data.set_entry_data(entry, mock_entry_data)
|
||||
|
||||
service_info = {
|
||||
"host": "192.168.43.184",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue