Fix singleton not working with falsey values (#56072)

This commit is contained in:
Paulus Schoutsen 2021-09-11 12:02:01 -07:00 committed by GitHub
parent 6e7ce89c64
commit 8a611eb640
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 51 additions and 60 deletions

View file

@ -6,16 +6,10 @@ from datetime import datetime, timedelta
import logging
from typing import Any, cast
from homeassistant.const import EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import (
CoreState,
HomeAssistant,
State,
callback,
valid_entity_id,
)
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import HomeAssistant, State, callback, valid_entity_id
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry
from homeassistant.helpers import entity_registry, start
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.json import JSONEncoder
@ -63,14 +57,11 @@ class StoredState:
class RestoreStateData:
"""Helper class for managing the helper saved data."""
@classmethod
async def async_get_instance(cls, hass: HomeAssistant) -> RestoreStateData:
"""Get the singleton instance of this data helper."""
@staticmethod
@singleton(DATA_RESTORE_STATE_TASK)
async def load_instance(hass: HomeAssistant) -> RestoreStateData:
async def async_get_instance(hass: HomeAssistant) -> RestoreStateData:
"""Get the singleton instance of this data helper."""
data = cls(hass)
data = RestoreStateData(hass)
try:
stored_states = await data.store.async_load()
@ -89,17 +80,14 @@ class RestoreStateData:
}
_LOGGER.debug("Created cache with %s", list(data.last_states))
if hass.state == CoreState.running:
async def hass_start(hass: HomeAssistant) -> None:
"""Start the restore state task."""
data.async_setup_dump()
else:
hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_START, data.async_setup_dump
)
start.async_at_start(hass, hass_start)
return data
return cast(RestoreStateData, await load_instance(hass))
@classmethod
async def async_save_persistent_states(cls, hass: HomeAssistant) -> None:
"""Dump states now."""
@ -269,7 +257,9 @@ class RestoreEntity(Entity):
# Return None if this entity isn't added to hass yet
_LOGGER.warning("Cannot get last state. Entity not added to hass") # type: ignore[unreachable]
return None
data = await RestoreStateData.async_get_instance(self.hass)
data = cast(
RestoreStateData, await RestoreStateData.async_get_instance(self.hass)
)
if self.entity_id not in data.last_states:
return None
return data.last_states[self.entity_id].state

View file

@ -26,31 +26,27 @@ def singleton(data_key: str) -> Callable[[FUNC], FUNC]:
@bind_hass
@functools.wraps(func)
def wrapped(hass: HomeAssistant) -> T:
obj: T | None = hass.data.get(data_key)
if obj is None:
obj = hass.data[data_key] = func(hass)
return obj
if data_key not in hass.data:
hass.data[data_key] = func(hass)
return cast(T, hass.data[data_key])
return wrapped
@bind_hass
@functools.wraps(func)
async def async_wrapped(hass: HomeAssistant) -> T:
obj_or_evt = hass.data.get(data_key)
if not obj_or_evt:
if data_key not in hass.data:
evt = hass.data[data_key] = asyncio.Event()
result = await func(hass)
hass.data[data_key] = result
evt.set()
return cast(T, result)
obj_or_evt = hass.data[data_key]
if isinstance(obj_or_evt, asyncio.Event):
evt = obj_or_evt
await evt.wait()
return cast(T, hass.data.get(data_key))
await obj_or_evt.wait()
return cast(T, hass.data[data_key])
return cast(T, obj_or_evt)

View file

@ -32,7 +32,7 @@ async def test_caching_data(hass):
await data.store.async_save([state.as_dict() for state in stored_states])
# Emulate a fresh load
hass.data[DATA_RESTORE_STATE_TASK] = None
hass.data.pop(DATA_RESTORE_STATE_TASK)
entity = RestoreEntity()
entity.hass = hass
@ -59,7 +59,7 @@ async def test_periodic_write(hass):
await data.store.async_save([])
# Emulate a fresh load
hass.data[DATA_RESTORE_STATE_TASK] = None
hass.data.pop(DATA_RESTORE_STATE_TASK)
entity = RestoreEntity()
entity.hass = hass
@ -105,7 +105,7 @@ async def test_save_persistent_states(hass):
await data.store.async_save([])
# Emulate a fresh load
hass.data[DATA_RESTORE_STATE_TASK] = None
hass.data.pop(DATA_RESTORE_STATE_TASK)
entity = RestoreEntity()
entity.hass = hass
@ -170,7 +170,8 @@ async def test_hass_starting(hass):
await data.store.async_save([state.as_dict() for state in stored_states])
# Emulate a fresh load
hass.data[DATA_RESTORE_STATE_TASK] = None
hass.state = CoreState.not_running
hass.data.pop(DATA_RESTORE_STATE_TASK)
entity = RestoreEntity()
entity.hass = hass

View file

@ -12,29 +12,33 @@ def mock_hass():
return Mock(data={})
async def test_singleton_async(mock_hass):
@pytest.mark.parametrize("result", (object(), {}, []))
async def test_singleton_async(mock_hass, result):
"""Test singleton with async function."""
@singleton.singleton("test_key")
async def something(hass):
return object()
return result
result1 = await something(mock_hass)
result2 = await something(mock_hass)
assert result1 is result
assert result1 is result2
assert "test_key" in mock_hass.data
assert mock_hass.data["test_key"] is result1
def test_singleton(mock_hass):
@pytest.mark.parametrize("result", (object(), {}, []))
def test_singleton(mock_hass, result):
"""Test singleton with function."""
@singleton.singleton("test_key")
def something(hass):
return object()
return result
result1 = something(mock_hass)
result2 = something(mock_hass)
assert result1 is result
assert result1 is result2
assert "test_key" in mock_hass.data
assert mock_hass.data["test_key"] is result1