Migrate restore_state helper to use registry loading pattern (#93773)

* Migrate restore_state helper to use registry loading pattern

As more entities have started using restore_state over time, it
has become a startup bottleneck as each entity being added is
creating a task to load restore state data that is already loaded
since it is a singleton

We now use the same pattern as the registry helpers

* fix refactoring error -- guess I am tired

* fixes

* fix tests

* fix more

* fix more

* fix zha tests

* fix zha tests

* comments

* fix error

* add missing coverage

* s/DATA_RESTORE_STATE_TASK/DATA_RESTORE_STATE/g
This commit is contained in:
J. Nick Koston 2023-05-30 20:48:17 -05:00 committed by GitHub
parent b91c6911d9
commit fba826ae9e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 147 additions and 86 deletions

View file

@ -32,6 +32,7 @@ from .helpers import (
entity_registry, entity_registry,
issue_registry, issue_registry,
recorder, recorder,
restore_state,
template, template,
) )
from .helpers.dispatcher import async_dispatcher_send from .helpers.dispatcher import async_dispatcher_send
@ -248,6 +249,7 @@ async def load_registries(hass: core.HomeAssistant) -> None:
issue_registry.async_load(hass), issue_registry.async_load(hass),
hass.async_add_executor_job(_cache_uname_processor), hass.async_add_executor_job(_cache_uname_processor),
template.async_load_custom_templates(hass), template.async_load_custom_templates(hass),
restore_state.async_load(hass),
) )

View file

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
import logging import logging
from typing import Any, cast from typing import Any, cast
@ -18,10 +17,9 @@ from . import start
from .entity import Entity from .entity import Entity
from .event import async_track_time_interval from .event import async_track_time_interval
from .json import JSONEncoder from .json import JSONEncoder
from .singleton import singleton
from .storage import Store from .storage import Store
DATA_RESTORE_STATE_TASK = "restore_state_task" DATA_RESTORE_STATE = "restore_state"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -96,31 +94,25 @@ class StoredState:
) )
async def async_load(hass: HomeAssistant) -> None:
"""Load the restore state task."""
hass.data[DATA_RESTORE_STATE] = await RestoreStateData.async_get_instance(hass)
@callback
def async_get(hass: HomeAssistant) -> RestoreStateData:
"""Get the restore state data helper."""
return cast(RestoreStateData, hass.data[DATA_RESTORE_STATE])
class RestoreStateData: class RestoreStateData:
"""Helper class for managing the helper saved data.""" """Helper class for managing the helper saved data."""
@staticmethod @staticmethod
@singleton(DATA_RESTORE_STATE_TASK)
async def async_get_instance(hass: HomeAssistant) -> RestoreStateData: async def async_get_instance(hass: HomeAssistant) -> RestoreStateData:
"""Get the singleton instance of this data helper.""" """Get the instance of this data helper."""
data = RestoreStateData(hass) data = RestoreStateData(hass)
await data.async_load()
try:
stored_states = await data.store.async_load()
except HomeAssistantError as exc:
_LOGGER.error("Error loading last states", exc_info=exc)
stored_states = None
if stored_states is None:
_LOGGER.debug("Not creating cache - no saved states found")
data.last_states = {}
else:
data.last_states = {
item["state"]["entity_id"]: StoredState.from_dict(item)
for item in stored_states
if valid_entity_id(item["state"]["entity_id"])
}
_LOGGER.debug("Created cache with %s", list(data.last_states))
async def hass_start(hass: HomeAssistant) -> None: async def hass_start(hass: HomeAssistant) -> None:
"""Start the restore state task.""" """Start the restore state task."""
@ -133,8 +125,7 @@ class RestoreStateData:
@classmethod @classmethod
async def async_save_persistent_states(cls, hass: HomeAssistant) -> None: async def async_save_persistent_states(cls, hass: HomeAssistant) -> None:
"""Dump states now.""" """Dump states now."""
data = await cls.async_get_instance(hass) await async_get(hass).async_dump_states()
await data.async_dump_states()
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the restore state data class.""" """Initialize the restore state data class."""
@ -145,6 +136,25 @@ class RestoreStateData:
self.last_states: dict[str, StoredState] = {} self.last_states: dict[str, StoredState] = {}
self.entities: dict[str, RestoreEntity] = {} self.entities: dict[str, RestoreEntity] = {}
async def async_load(self) -> None:
"""Load the instance of this data helper."""
try:
stored_states = await self.store.async_load()
except HomeAssistantError as exc:
_LOGGER.error("Error loading last states", exc_info=exc)
stored_states = None
if stored_states is None:
_LOGGER.debug("Not creating cache - no saved states found")
self.last_states = {}
else:
self.last_states = {
item["state"]["entity_id"]: StoredState.from_dict(item)
for item in stored_states
if valid_entity_id(item["state"]["entity_id"])
}
_LOGGER.debug("Created cache with %s", list(self.last_states))
@callback @callback
def async_get_stored_states(self) -> list[StoredState]: def async_get_stored_states(self) -> list[StoredState]:
"""Get the set of states which should be stored. """Get the set of states which should be stored.
@ -288,21 +298,18 @@ class RestoreEntity(Entity):
async def async_internal_added_to_hass(self) -> None: async def async_internal_added_to_hass(self) -> None:
"""Register this entity as a restorable entity.""" """Register this entity as a restorable entity."""
_, data = await asyncio.gather( await super().async_internal_added_to_hass()
super().async_internal_added_to_hass(), async_get(self.hass).async_restore_entity_added(self)
RestoreStateData.async_get_instance(self.hass),
)
data.async_restore_entity_added(self)
async def async_internal_will_remove_from_hass(self) -> None: async def async_internal_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass.""" """Run when entity will be removed from hass."""
_, data = await asyncio.gather( async_get(self.hass).async_restore_entity_removed(
super().async_internal_will_remove_from_hass(), self.entity_id, self.extra_restore_state_data
RestoreStateData.async_get_instance(self.hass),
) )
data.async_restore_entity_removed(self.entity_id, self.extra_restore_state_data) await super().async_internal_will_remove_from_hass()
async def _async_get_restored_data(self) -> StoredState | None: @callback
def _async_get_restored_data(self) -> StoredState | None:
"""Get data stored for an entity, if any.""" """Get data stored for an entity, if any."""
if self.hass is None or self.entity_id is None: if self.hass is None or self.entity_id is None:
# Return None if this entity isn't added to hass yet # Return None if this entity isn't added to hass yet
@ -310,20 +317,17 @@ class RestoreEntity(Entity):
"Cannot get last state. Entity not added to hass" "Cannot get last state. Entity not added to hass"
) )
return None return None
data = await RestoreStateData.async_get_instance(self.hass) return async_get(self.hass).last_states.get(self.entity_id)
if self.entity_id not in data.last_states:
return None
return data.last_states[self.entity_id]
async def async_get_last_state(self) -> State | None: async def async_get_last_state(self) -> State | None:
"""Get the entity state from the previous run.""" """Get the entity state from the previous run."""
if (stored_state := await self._async_get_restored_data()) is None: if (stored_state := self._async_get_restored_data()) is None:
return None return None
return stored_state.state return stored_state.state
async def async_get_last_extra_data(self) -> ExtraStoredData | None: async def async_get_last_extra_data(self) -> ExtraStoredData | None:
"""Get the entity specific state data from the previous run.""" """Get the entity specific state data from the previous run."""
if (stored_state := await self._async_get_restored_data()) is None: if (stored_state := self._async_get_restored_data()) is None:
return None return None
return stored_state.extra_data return stored_state.extra_data

View file

@ -61,6 +61,7 @@ from homeassistant.helpers import (
issue_registry as ir, issue_registry as ir,
recorder as recorder_helper, recorder as recorder_helper,
restore_state, restore_state,
restore_state as rs,
storage, storage,
) )
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
@ -251,12 +252,20 @@ async def async_test_home_assistant(event_loop, load_registries=True):
# Load the registries # Load the registries
entity.async_setup(hass) entity.async_setup(hass)
if load_registries: if load_registries:
with patch("homeassistant.helpers.storage.Store.async_load", return_value=None): with patch(
"homeassistant.helpers.storage.Store.async_load", return_value=None
), patch(
"homeassistant.helpers.restore_state.RestoreStateData.async_setup_dump",
return_value=None,
), patch(
"homeassistant.helpers.restore_state.start.async_at_start"
):
await asyncio.gather( await asyncio.gather(
ar.async_load(hass), ar.async_load(hass),
dr.async_load(hass), dr.async_load(hass),
er.async_load(hass), er.async_load(hass),
ir.async_load(hass), ir.async_load(hass),
rs.async_load(hass),
) )
hass.data[bootstrap.DATA_REGISTRIES_LOADED] = None hass.data[bootstrap.DATA_REGISTRIES_LOADED] = None
@ -1010,7 +1019,7 @@ def init_recorder_component(hass, add_config=None, db_url="sqlite://"):
def mock_restore_cache(hass: HomeAssistant, states: Sequence[State]) -> None: def mock_restore_cache(hass: HomeAssistant, states: Sequence[State]) -> None:
"""Mock the DATA_RESTORE_CACHE.""" """Mock the DATA_RESTORE_CACHE."""
key = restore_state.DATA_RESTORE_STATE_TASK key = restore_state.DATA_RESTORE_STATE
data = restore_state.RestoreStateData(hass) data = restore_state.RestoreStateData(hass)
now = dt_util.utcnow() now = dt_util.utcnow()
@ -1037,7 +1046,7 @@ def mock_restore_cache_with_extra_data(
hass: HomeAssistant, states: Sequence[tuple[State, Mapping[str, Any]]] hass: HomeAssistant, states: Sequence[tuple[State, Mapping[str, Any]]]
) -> None: ) -> None:
"""Mock the DATA_RESTORE_CACHE.""" """Mock the DATA_RESTORE_CACHE."""
key = restore_state.DATA_RESTORE_STATE_TASK key = restore_state.DATA_RESTORE_STATE
data = restore_state.RestoreStateData(hass) data = restore_state.RestoreStateData(hass)
now = dt_util.utcnow() now = dt_util.utcnow()
@ -1060,6 +1069,26 @@ def mock_restore_cache_with_extra_data(
hass.data[key] = data hass.data[key] = data
async def async_mock_restore_state_shutdown_restart(
hass: HomeAssistant,
) -> restore_state.RestoreStateData:
"""Mock shutting down and saving restore state and restoring."""
data = restore_state.async_get(hass)
await data.async_dump_states()
await async_mock_load_restore_state_from_storage(hass)
return data
async def async_mock_load_restore_state_from_storage(
hass: HomeAssistant,
) -> None:
"""Mock loading restore state from storage.
hass_storage must already be mocked.
"""
await restore_state.async_get(hass).async_load()
class MockEntity(entity.Entity): class MockEntity(entity.Entity):
"""Mock Entity class.""" """Mock Entity class."""

View file

@ -34,7 +34,10 @@ from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM
from tests.common import mock_restore_cache_with_extra_data from tests.common import (
async_mock_restore_state_shutdown_restart,
mock_restore_cache_with_extra_data,
)
class MockDefaultNumberEntity(NumberEntity): class MockDefaultNumberEntity(NumberEntity):
@ -635,7 +638,7 @@ async def test_restore_number_save_state(
await hass.async_block_till_done() await hass.async_block_till_done()
# Trigger saving state # Trigger saving state
await hass.async_stop() await async_mock_restore_state_shutdown_restart(hass)
assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1 assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1
state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"] state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"]

View file

@ -35,7 +35,10 @@ from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM
from tests.common import mock_restore_cache_with_extra_data from tests.common import (
async_mock_restore_state_shutdown_restart,
mock_restore_cache_with_extra_data,
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -397,7 +400,7 @@ async def test_restore_sensor_save_state(
await hass.async_block_till_done() await hass.async_block_till_done()
# Trigger saving state # Trigger saving state
await hass.async_stop() await async_mock_restore_state_shutdown_restart(hass)
assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1 assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1
state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"] state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"]

View file

@ -20,7 +20,10 @@ from homeassistant.core import HomeAssistant, ServiceCall, State
from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import mock_restore_cache_with_extra_data from tests.common import (
async_mock_restore_state_shutdown_restart,
mock_restore_cache_with_extra_data,
)
class MockTextEntity(TextEntity): class MockTextEntity(TextEntity):
@ -141,7 +144,7 @@ async def test_restore_number_save_state(
await hass.async_block_till_done() await hass.async_block_till_done()
# Trigger saving state # Trigger saving state
await hass.async_stop() await async_mock_restore_state_shutdown_restart(hass)
assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1 assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1
state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"] state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"]

View file

@ -47,11 +47,7 @@ from homeassistant.const import (
from homeassistant.core import Context, CoreState, HomeAssistant, State from homeassistant.core import Context, CoreState, HomeAssistant, State
from homeassistant.exceptions import HomeAssistantError, Unauthorized from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers import config_validation as cv, entity_registry as er from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.restore_state import ( from homeassistant.helpers.restore_state import StoredState, async_get
DATA_RESTORE_STATE_TASK,
RestoreStateData,
StoredState,
)
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
@ -838,12 +834,9 @@ async def test_restore_idle(hass: HomeAssistant) -> None:
utc_now, utc_now,
) )
data = await RestoreStateData.async_get_instance(hass) data = async_get(hass)
await hass.async_block_till_done()
await data.store.async_save([stored_state.as_dict()]) await data.store.async_save([stored_state.as_dict()])
await data.async_load()
# Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK)
entity = Timer.from_storage( entity = Timer.from_storage(
{ {
@ -878,12 +871,9 @@ async def test_restore_paused(hass: HomeAssistant) -> None:
utc_now, utc_now,
) )
data = await RestoreStateData.async_get_instance(hass) data = async_get(hass)
await hass.async_block_till_done()
await data.store.async_save([stored_state.as_dict()]) await data.store.async_save([stored_state.as_dict()])
await data.async_load()
# Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK)
entity = Timer.from_storage( entity = Timer.from_storage(
{ {
@ -922,12 +912,9 @@ async def test_restore_active_resume(hass: HomeAssistant) -> None:
utc_now, utc_now,
) )
data = await RestoreStateData.async_get_instance(hass) data = async_get(hass)
await hass.async_block_till_done()
await data.store.async_save([stored_state.as_dict()]) await data.store.async_save([stored_state.as_dict()])
await data.async_load()
# Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK)
entity = Timer.from_storage( entity = Timer.from_storage(
{ {
@ -973,12 +960,9 @@ async def test_restore_active_finished_outside_grace(hass: HomeAssistant) -> Non
utc_now, utc_now,
) )
data = await RestoreStateData.async_get_instance(hass) data = async_get(hass)
await hass.async_block_till_done()
await data.store.async_save([stored_state.as_dict()]) await data.store.async_save([stored_state.as_dict()])
await data.async_load()
# Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK)
entity = Timer.from_storage( entity = Timer.from_storage(
{ {

View file

@ -21,6 +21,8 @@ from .common import (
) )
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE
from tests.common import async_mock_load_restore_state_from_storage
DEVICE_IAS = { DEVICE_IAS = {
1: { 1: {
SIG_EP_PROFILE: zigpy.profiles.zha.PROFILE_ID, SIG_EP_PROFILE: zigpy.profiles.zha.PROFILE_ID,
@ -186,6 +188,7 @@ async def test_binary_sensor_migration_not_migrated(
entity_id = "binary_sensor.fakemanufacturer_fakemodel_iaszone" entity_id = "binary_sensor.fakemanufacturer_fakemodel_iaszone"
core_rs(entity_id, state=restored_state, attributes={}) # migration sensor state core_rs(entity_id, state=restored_state, attributes={}) # migration sensor state
await async_mock_load_restore_state_from_storage(hass)
zigpy_device = zigpy_device_mock(DEVICE_IAS) zigpy_device = zigpy_device_mock(DEVICE_IAS)
zha_device = await zha_device_restored(zigpy_device) zha_device = await zha_device_restored(zigpy_device)
@ -208,6 +211,7 @@ async def test_binary_sensor_migration_already_migrated(
entity_id = "binary_sensor.fakemanufacturer_fakemodel_iaszone" entity_id = "binary_sensor.fakemanufacturer_fakemodel_iaszone"
core_rs(entity_id, state=STATE_OFF, attributes={"migrated_to_cache": True}) core_rs(entity_id, state=STATE_OFF, attributes={"migrated_to_cache": True})
await async_mock_load_restore_state_from_storage(hass)
zigpy_device = zigpy_device_mock(DEVICE_IAS) zigpy_device = zigpy_device_mock(DEVICE_IAS)
@ -243,6 +247,7 @@ async def test_onoff_binary_sensor_restore_state(
entity_id = "binary_sensor.fakemanufacturer_fakemodel_opening" entity_id = "binary_sensor.fakemanufacturer_fakemodel_opening"
core_rs(entity_id, state=restored_state, attributes={}) core_rs(entity_id, state=restored_state, attributes={})
await async_mock_load_restore_state_from_storage(hass)
zigpy_device = zigpy_device_mock(DEVICE_ONOFF) zigpy_device = zigpy_device_mock(DEVICE_ONOFF)
zha_device = await zha_device_restored(zigpy_device) zha_device = await zha_device_restored(zigpy_device)

View file

@ -26,6 +26,8 @@ from homeassistant.util import dt as dt_util
from .common import async_enable_traffic, find_entity_id, send_attributes_report from .common import async_enable_traffic, find_entity_id, send_attributes_report
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE
from tests.common import async_mock_load_restore_state_from_storage
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def select_select_only(): def select_select_only():
@ -176,6 +178,7 @@ async def test_select_restore_state(
entity_id = "select.fakemanufacturer_fakemodel_default_siren_tone" entity_id = "select.fakemanufacturer_fakemodel_default_siren_tone"
core_rs(entity_id, state="Burglar") core_rs(entity_id, state="Burglar")
await async_mock_load_restore_state_from_storage(hass)
zigpy_device = zigpy_device_mock( zigpy_device = zigpy_device_mock(
{ {

View file

@ -47,6 +47,8 @@ from .common import (
) )
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE
from tests.common import async_mock_load_restore_state_from_storage
ENTITY_ID_PREFIX = "sensor.fakemanufacturer_fakemodel_{}" ENTITY_ID_PREFIX = "sensor.fakemanufacturer_fakemodel_{}"
@ -530,6 +532,7 @@ def core_rs(hass_storage):
], ],
) )
async def test_temp_uom( async def test_temp_uom(
hass: HomeAssistant,
uom, uom,
raw_temp, raw_temp,
expected, expected,
@ -544,6 +547,7 @@ async def test_temp_uom(
entity_id = "sensor.fake1026_fakemodel1026_004f3202_temperature" entity_id = "sensor.fake1026_fakemodel1026_004f3202_temperature"
if restore: if restore:
core_rs(entity_id, uom, state=(expected - 2)) core_rs(entity_id, uom, state=(expected - 2))
await async_mock_load_restore_state_from_storage(hass)
hass = await hass_ms( hass = await hass_ms(
CONF_UNIT_SYSTEM_METRIC CONF_UNIT_SYSTEM_METRIC

View file

@ -8,11 +8,13 @@ from homeassistant.core import CoreState, HomeAssistant, State
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.restore_state import ( from homeassistant.helpers.restore_state import (
DATA_RESTORE_STATE_TASK, DATA_RESTORE_STATE,
STORAGE_KEY, STORAGE_KEY,
RestoreEntity, RestoreEntity,
RestoreStateData, RestoreStateData,
StoredState, StoredState,
async_get,
async_load,
) )
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
@ -28,12 +30,25 @@ async def test_caching_data(hass: HomeAssistant) -> None:
StoredState(State("input_boolean.b2", "on"), None, now), StoredState(State("input_boolean.b2", "on"), None, now),
] ]
data = await RestoreStateData.async_get_instance(hass) data = async_get(hass)
await hass.async_block_till_done() await hass.async_block_till_done()
await data.store.async_save([state.as_dict() for state in stored_states]) await data.store.async_save([state.as_dict() for state in stored_states])
# Emulate a fresh load # Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK) hass.data.pop(DATA_RESTORE_STATE)
with patch(
"homeassistant.helpers.restore_state.Store.async_load",
side_effect=HomeAssistantError,
):
# Failure to load should not be treated as fatal
await async_load(hass)
data = async_get(hass)
assert data.last_states == {}
await async_load(hass)
data = async_get(hass)
entity = RestoreEntity() entity = RestoreEntity()
entity.hass = hass entity.hass = hass
@ -55,12 +70,14 @@ async def test_caching_data(hass: HomeAssistant) -> None:
async def test_periodic_write(hass: HomeAssistant) -> None: async def test_periodic_write(hass: HomeAssistant) -> None:
"""Test that we write periodiclly but not after stop.""" """Test that we write periodiclly but not after stop."""
data = await RestoreStateData.async_get_instance(hass) data = async_get(hass)
await hass.async_block_till_done() await hass.async_block_till_done()
await data.store.async_save([]) await data.store.async_save([])
# Emulate a fresh load # Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK) hass.data.pop(DATA_RESTORE_STATE)
await async_load(hass)
data = async_get(hass)
entity = RestoreEntity() entity = RestoreEntity()
entity.hass = hass entity.hass = hass
@ -101,12 +118,14 @@ async def test_periodic_write(hass: HomeAssistant) -> None:
async def test_save_persistent_states(hass: HomeAssistant) -> None: async def test_save_persistent_states(hass: HomeAssistant) -> None:
"""Test that we cancel the currently running job, save the data, and verify the perdiodic job continues.""" """Test that we cancel the currently running job, save the data, and verify the perdiodic job continues."""
data = await RestoreStateData.async_get_instance(hass) data = async_get(hass)
await hass.async_block_till_done() await hass.async_block_till_done()
await data.store.async_save([]) await data.store.async_save([])
# Emulate a fresh load # Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK) hass.data.pop(DATA_RESTORE_STATE)
await async_load(hass)
data = async_get(hass)
entity = RestoreEntity() entity = RestoreEntity()
entity.hass = hass entity.hass = hass
@ -166,13 +185,15 @@ async def test_hass_starting(hass: HomeAssistant) -> None:
StoredState(State("input_boolean.b2", "on"), None, now), StoredState(State("input_boolean.b2", "on"), None, now),
] ]
data = await RestoreStateData.async_get_instance(hass) data = async_get(hass)
await hass.async_block_till_done() await hass.async_block_till_done()
await data.store.async_save([state.as_dict() for state in stored_states]) await data.store.async_save([state.as_dict() for state in stored_states])
# Emulate a fresh load # Emulate a fresh load
hass.state = CoreState.not_running hass.state = CoreState.not_running
hass.data.pop(DATA_RESTORE_STATE_TASK) hass.data.pop(DATA_RESTORE_STATE)
await async_load(hass)
data = async_get(hass)
entity = RestoreEntity() entity = RestoreEntity()
entity.hass = hass entity.hass = hass
@ -223,7 +244,7 @@ async def test_dump_data(hass: HomeAssistant) -> None:
entity.entity_id = "input_boolean.b1" entity.entity_id = "input_boolean.b1"
await entity.async_internal_added_to_hass() await entity.async_internal_added_to_hass()
data = await RestoreStateData.async_get_instance(hass) data = async_get(hass)
now = dt_util.utcnow() now = dt_util.utcnow()
data.last_states = { data.last_states = {
"input_boolean.b0": StoredState(State("input_boolean.b0", "off"), None, now), "input_boolean.b0": StoredState(State("input_boolean.b0", "off"), None, now),
@ -297,7 +318,7 @@ async def test_dump_error(hass: HomeAssistant) -> None:
entity.entity_id = "input_boolean.b1" entity.entity_id = "input_boolean.b1"
await entity.async_internal_added_to_hass() await entity.async_internal_added_to_hass()
data = await RestoreStateData.async_get_instance(hass) data = async_get(hass)
with patch( with patch(
"homeassistant.helpers.restore_state.Store.async_save", "homeassistant.helpers.restore_state.Store.async_save",
@ -335,7 +356,7 @@ async def test_state_saved_on_remove(hass: HomeAssistant) -> None:
"input_boolean.b0", "on", {"complicated": {"value": {1, 2, now}}} "input_boolean.b0", "on", {"complicated": {"value": {1, 2, now}}}
) )
data = await RestoreStateData.async_get_instance(hass) data = async_get(hass)
# No last states should currently be saved # No last states should currently be saved
assert not data.last_states assert not data.last_states