Restore states through a JSON store instead of recorder (#17270)
* Restore states through a JSON store * Accept entity_id directly in restore state helper * Keep states stored between runs for a limited time * Remove warning
This commit is contained in:
parent
a039c3209b
commit
5c3a4e3d10
46 changed files with 493 additions and 422 deletions
|
@ -21,7 +21,7 @@ from homeassistant.const import (
|
|||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.event import track_point_in_time
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -116,7 +116,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
|
|||
)])
|
||||
|
||||
|
||||
class ManualAlarm(alarm.AlarmControlPanel):
|
||||
class ManualAlarm(alarm.AlarmControlPanel, RestoreEntity):
|
||||
"""
|
||||
Representation of an alarm status.
|
||||
|
||||
|
@ -310,7 +310,7 @@ class ManualAlarm(alarm.AlarmControlPanel):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Run when entity about to be added to hass."""
|
||||
state = await async_get_last_state(self.hass, self.entity_id)
|
||||
state = await self.async_get_last_state()
|
||||
if state:
|
||||
self._state = state.state
|
||||
self._state_ts = state.last_updated
|
||||
|
|
|
@ -108,8 +108,7 @@ class MqttAlarm(MqttAvailability, MqttDiscoveryUpdate,
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Subscribe mqtt events."""
|
||||
await MqttAvailability.async_added_to_hass(self)
|
||||
await MqttDiscoveryUpdate.async_added_to_hass(self)
|
||||
await super().async_added_to_hass()
|
||||
await self._subscribe_topics()
|
||||
|
||||
async def discovery_update(self, discovery_payload):
|
||||
|
|
|
@ -21,7 +21,7 @@ from homeassistant.exceptions import HomeAssistantError
|
|||
from homeassistant.helpers import extract_domain_configs, script, condition
|
||||
from homeassistant.helpers.entity import ToggleEntity
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.util.dt import utcnow
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
||||
|
@ -182,7 +182,7 @@ async def async_setup(hass, config):
|
|||
return True
|
||||
|
||||
|
||||
class AutomationEntity(ToggleEntity):
|
||||
class AutomationEntity(ToggleEntity, RestoreEntity):
|
||||
"""Entity to show status of entity."""
|
||||
|
||||
def __init__(self, automation_id, name, async_attach_triggers, cond_func,
|
||||
|
@ -227,12 +227,13 @@ class AutomationEntity(ToggleEntity):
|
|||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Startup with initial state or previous state."""
|
||||
await super().async_added_to_hass()
|
||||
if self._initial_state is not None:
|
||||
enable_automation = self._initial_state
|
||||
_LOGGER.debug("Automation %s initial state %s from config "
|
||||
"initial_state", self.entity_id, enable_automation)
|
||||
else:
|
||||
state = await async_get_last_state(self.hass, self.entity_id)
|
||||
state = await self.async_get_last_state()
|
||||
if state:
|
||||
enable_automation = state.state == STATE_ON
|
||||
self._last_triggered = state.attributes.get('last_triggered')
|
||||
|
@ -291,6 +292,7 @@ class AutomationEntity(ToggleEntity):
|
|||
|
||||
async def async_will_remove_from_hass(self):
|
||||
"""Remove listeners when removing automation from HASS."""
|
||||
await super().async_will_remove_from_hass()
|
||||
await self.async_turn_off()
|
||||
|
||||
async def async_enable(self):
|
||||
|
|
|
@ -102,8 +102,7 @@ class MqttBinarySensor(MqttAvailability, MqttDiscoveryUpdate,
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Subscribe mqtt events."""
|
||||
await MqttAvailability.async_added_to_hass(self)
|
||||
await MqttDiscoveryUpdate.async_added_to_hass(self)
|
||||
await super().async_added_to_hass()
|
||||
await self._subscribe_topics()
|
||||
|
||||
async def discovery_update(self, discovery_payload):
|
||||
|
|
|
@ -23,7 +23,7 @@ from homeassistant.helpers import condition
|
|||
from homeassistant.helpers.event import (
|
||||
async_track_state_change, async_track_time_interval)
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -96,7 +96,7 @@ async def async_setup_platform(hass, config, async_add_entities,
|
|||
precision)])
|
||||
|
||||
|
||||
class GenericThermostat(ClimateDevice):
|
||||
class GenericThermostat(ClimateDevice, RestoreEntity):
|
||||
"""Representation of a Generic Thermostat device."""
|
||||
|
||||
def __init__(self, hass, name, heater_entity_id, sensor_entity_id,
|
||||
|
@ -155,8 +155,9 @@ class GenericThermostat(ClimateDevice):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Run when entity about to be added."""
|
||||
await super().async_added_to_hass()
|
||||
# Check If we have an old state
|
||||
old_state = await async_get_last_state(self.hass, self.entity_id)
|
||||
old_state = await self.async_get_last_state()
|
||||
if old_state is not None:
|
||||
# If we have no initial temperature, restore
|
||||
if self._target_temp is None:
|
||||
|
|
|
@ -221,8 +221,7 @@ class MqttClimate(MqttAvailability, MqttDiscoveryUpdate, ClimateDevice):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Handle being added to home assistant."""
|
||||
await MqttAvailability.async_added_to_hass(self)
|
||||
await MqttDiscoveryUpdate.async_added_to_hass(self)
|
||||
await super().async_added_to_hass()
|
||||
await self._subscribe_topics()
|
||||
|
||||
async def discovery_update(self, discovery_payload):
|
||||
|
|
|
@ -10,9 +10,8 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_ICON, CONF_NAME
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -86,7 +85,7 @@ async def async_setup(hass, config):
|
|||
return True
|
||||
|
||||
|
||||
class Counter(Entity):
|
||||
class Counter(RestoreEntity):
|
||||
"""Representation of a counter."""
|
||||
|
||||
def __init__(self, object_id, name, initial, restore, step, icon):
|
||||
|
@ -128,10 +127,11 @@ class Counter(Entity):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Call when entity about to be added to Home Assistant."""
|
||||
await super().async_added_to_hass()
|
||||
# __init__ will set self._state to self._initial, only override
|
||||
# if needed.
|
||||
if self._restore:
|
||||
state = await async_get_last_state(self.hass, self.entity_id)
|
||||
state = await self.async_get_last_state()
|
||||
if state is not None:
|
||||
self._state = int(state.state)
|
||||
|
||||
|
|
|
@ -205,8 +205,7 @@ class MqttCover(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo,
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Subscribe MQTT events."""
|
||||
await MqttAvailability.async_added_to_hass(self)
|
||||
await MqttDiscoveryUpdate.async_added_to_hass(self)
|
||||
await super().async_added_to_hass()
|
||||
await self._subscribe_topics()
|
||||
|
||||
async def discovery_update(self, discovery_payload):
|
||||
|
|
|
@ -22,9 +22,8 @@ from homeassistant.components.zone.zone import async_active_zone
|
|||
from homeassistant.config import load_yaml_config_file, async_log_exception
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_per_platform, discovery
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.event import async_track_time_interval
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.helpers.typing import GPSType, ConfigType, HomeAssistantType
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant import util
|
||||
|
@ -396,7 +395,7 @@ class DeviceTracker:
|
|||
await asyncio.wait(tasks, loop=self.hass.loop)
|
||||
|
||||
|
||||
class Device(Entity):
|
||||
class Device(RestoreEntity):
|
||||
"""Represent a tracked device."""
|
||||
|
||||
host_name = None # type: str
|
||||
|
@ -564,7 +563,8 @@ class Device(Entity):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Add an entity."""
|
||||
state = await async_get_last_state(self.hass, self.entity_id)
|
||||
await super().async_added_to_hass()
|
||||
state = await self.async_get_last_state()
|
||||
if not state:
|
||||
return
|
||||
self._state = state.state
|
||||
|
|
|
@ -151,8 +151,7 @@ class MqttFan(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo,
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Subscribe to MQTT events."""
|
||||
await MqttAvailability.async_added_to_hass(self)
|
||||
await MqttDiscoveryUpdate.async_added_to_hass(self)
|
||||
await super().async_added_to_hass()
|
||||
await self._subscribe_topics()
|
||||
|
||||
async def discovery_update(self, discovery_payload):
|
||||
|
|
|
@ -38,20 +38,6 @@ SIGNIFICANT_DOMAINS = ('thermostat', 'climate')
|
|||
IGNORE_DOMAINS = ('zone', 'scene',)
|
||||
|
||||
|
||||
def last_recorder_run(hass):
|
||||
"""Retrieve the last closed recorder run from the database."""
|
||||
from homeassistant.components.recorder.models import RecorderRuns
|
||||
|
||||
with session_scope(hass=hass) as session:
|
||||
res = (session.query(RecorderRuns)
|
||||
.filter(RecorderRuns.end.isnot(None))
|
||||
.order_by(RecorderRuns.end.desc()).first())
|
||||
if res is None:
|
||||
return None
|
||||
session.expunge(res)
|
||||
return res
|
||||
|
||||
|
||||
def get_significant_states(hass, start_time, end_time=None, entity_ids=None,
|
||||
filters=None, include_start_time_state=True):
|
||||
"""
|
||||
|
|
|
@ -15,7 +15,7 @@ from homeassistant.loader import bind_hass
|
|||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entity import ToggleEntity
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
|
||||
DOMAIN = 'input_boolean'
|
||||
|
||||
|
@ -84,7 +84,7 @@ async def async_setup(hass, config):
|
|||
return True
|
||||
|
||||
|
||||
class InputBoolean(ToggleEntity):
|
||||
class InputBoolean(ToggleEntity, RestoreEntity):
|
||||
"""Representation of a boolean input."""
|
||||
|
||||
def __init__(self, object_id, name, initial, icon):
|
||||
|
@ -117,10 +117,11 @@ class InputBoolean(ToggleEntity):
|
|||
async def async_added_to_hass(self):
|
||||
"""Call when entity about to be added to hass."""
|
||||
# If not None, we got an initial value.
|
||||
await super().async_added_to_hass()
|
||||
if self._state is not None:
|
||||
return
|
||||
|
||||
state = await async_get_last_state(self.hass, self.entity_id)
|
||||
state = await self.async_get_last_state()
|
||||
self._state = state and state.state == STATE_ON
|
||||
|
||||
async def async_turn_on(self, **kwargs):
|
||||
|
|
|
@ -11,9 +11,8 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_ICON, CONF_NAME
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
|
||||
|
@ -97,7 +96,7 @@ async def async_setup(hass, config):
|
|||
return True
|
||||
|
||||
|
||||
class InputDatetime(Entity):
|
||||
class InputDatetime(RestoreEntity):
|
||||
"""Representation of a datetime input."""
|
||||
|
||||
def __init__(self, object_id, name, has_date, has_time, icon, initial):
|
||||
|
@ -112,6 +111,7 @@ class InputDatetime(Entity):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Run when entity about to be added."""
|
||||
await super().async_added_to_hass()
|
||||
restore_val = None
|
||||
|
||||
# Priority 1: Initial State
|
||||
|
@ -120,7 +120,7 @@ class InputDatetime(Entity):
|
|||
|
||||
# Priority 2: Old state
|
||||
if restore_val is None:
|
||||
old_state = await async_get_last_state(self.hass, self.entity_id)
|
||||
old_state = await self.async_get_last_state()
|
||||
if old_state is not None:
|
||||
restore_val = old_state.state
|
||||
|
||||
|
|
|
@ -11,9 +11,8 @@ import voluptuous as vol
|
|||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID, ATTR_UNIT_OF_MEASUREMENT, CONF_ICON, CONF_NAME, CONF_MODE)
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -123,7 +122,7 @@ async def async_setup(hass, config):
|
|||
return True
|
||||
|
||||
|
||||
class InputNumber(Entity):
|
||||
class InputNumber(RestoreEntity):
|
||||
"""Representation of a slider."""
|
||||
|
||||
def __init__(self, object_id, name, initial, minimum, maximum, step, icon,
|
||||
|
@ -178,10 +177,11 @@ class InputNumber(Entity):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Run when entity about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
if self._current_value is not None:
|
||||
return
|
||||
|
||||
state = await async_get_last_state(self.hass, self.entity_id)
|
||||
state = await self.async_get_last_state()
|
||||
value = state and float(state.state)
|
||||
|
||||
# Check against None because value can be 0
|
||||
|
|
|
@ -10,9 +10,8 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_ICON, CONF_NAME
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -116,7 +115,7 @@ async def async_setup(hass, config):
|
|||
return True
|
||||
|
||||
|
||||
class InputSelect(Entity):
|
||||
class InputSelect(RestoreEntity):
|
||||
"""Representation of a select input."""
|
||||
|
||||
def __init__(self, object_id, name, initial, options, icon):
|
||||
|
@ -129,10 +128,11 @@ class InputSelect(Entity):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Run when entity about to be added."""
|
||||
await super().async_added_to_hass()
|
||||
if self._current_option is not None:
|
||||
return
|
||||
|
||||
state = await async_get_last_state(self.hass, self.entity_id)
|
||||
state = await self.async_get_last_state()
|
||||
if not state or state.state not in self._options:
|
||||
self._current_option = self._options[0]
|
||||
else:
|
||||
|
|
|
@ -11,9 +11,8 @@ import voluptuous as vol
|
|||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID, ATTR_UNIT_OF_MEASUREMENT, CONF_ICON, CONF_NAME, CONF_MODE)
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -104,7 +103,7 @@ async def async_setup(hass, config):
|
|||
return True
|
||||
|
||||
|
||||
class InputText(Entity):
|
||||
class InputText(RestoreEntity):
|
||||
"""Represent a text box."""
|
||||
|
||||
def __init__(self, object_id, name, initial, minimum, maximum, icon,
|
||||
|
@ -157,10 +156,11 @@ class InputText(Entity):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Run when entity about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
if self._current_value is not None:
|
||||
return
|
||||
|
||||
state = await async_get_last_state(self.hass, self.entity_id)
|
||||
state = await self.async_get_last_state()
|
||||
value = state and state.state
|
||||
|
||||
# Check against None because value can be 0
|
||||
|
|
|
@ -18,7 +18,7 @@ from homeassistant.components.light import (
|
|||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.util.color import (
|
||||
color_temperature_mired_to_kelvin, color_hs_to_RGB)
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
|
||||
REQUIREMENTS = ['limitlessled==1.1.3']
|
||||
|
||||
|
@ -157,7 +157,7 @@ def state(new_state):
|
|||
return decorator
|
||||
|
||||
|
||||
class LimitlessLEDGroup(Light):
|
||||
class LimitlessLEDGroup(Light, RestoreEntity):
|
||||
"""Representation of a LimitessLED group."""
|
||||
|
||||
def __init__(self, group, config):
|
||||
|
@ -189,7 +189,8 @@ class LimitlessLEDGroup(Light):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Handle entity about to be added to hass event."""
|
||||
last_state = await async_get_last_state(self.hass, self.entity_id)
|
||||
await super().async_added_to_hass()
|
||||
last_state = await self.async_get_last_state()
|
||||
if last_state:
|
||||
self._is_on = (last_state.state == STATE_ON)
|
||||
self._brightness = last_state.attributes.get('brightness')
|
||||
|
|
|
@ -22,7 +22,7 @@ from homeassistant.components.mqtt import (
|
|||
CONF_AVAILABILITY_TOPIC, CONF_COMMAND_TOPIC, CONF_PAYLOAD_AVAILABLE,
|
||||
CONF_PAYLOAD_NOT_AVAILABLE, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC,
|
||||
MqttAvailability, MqttDiscoveryUpdate)
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
import homeassistant.util.color as color_util
|
||||
|
||||
|
@ -166,7 +166,7 @@ async def async_setup_entity_basic(hass, config, async_add_entities,
|
|||
)])
|
||||
|
||||
|
||||
class MqttLight(MqttAvailability, MqttDiscoveryUpdate, Light):
|
||||
class MqttLight(MqttAvailability, MqttDiscoveryUpdate, Light, RestoreEntity):
|
||||
"""Representation of a MQTT light."""
|
||||
|
||||
def __init__(self, name, unique_id, effect_list, topic, templates,
|
||||
|
@ -237,8 +237,7 @@ class MqttLight(MqttAvailability, MqttDiscoveryUpdate, Light):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Subscribe to MQTT events."""
|
||||
await MqttAvailability.async_added_to_hass(self)
|
||||
await MqttDiscoveryUpdate.async_added_to_hass(self)
|
||||
await super().async_added_to_hass()
|
||||
|
||||
templates = {}
|
||||
for key, tpl in list(self._templates.items()):
|
||||
|
@ -248,7 +247,7 @@ class MqttLight(MqttAvailability, MqttDiscoveryUpdate, Light):
|
|||
tpl.hass = self.hass
|
||||
templates[key] = tpl.async_render_with_possible_json_value
|
||||
|
||||
last_state = await async_get_last_state(self.hass, self.entity_id)
|
||||
last_state = await self.async_get_last_state()
|
||||
|
||||
@callback
|
||||
def state_received(topic, payload, qos):
|
||||
|
|
|
@ -25,7 +25,7 @@ from homeassistant.const import (
|
|||
CONF_RGB, CONF_WHITE_VALUE, CONF_XY, STATE_ON)
|
||||
from homeassistant.core import callback
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
|
||||
import homeassistant.util.color as color_util
|
||||
|
||||
|
@ -121,7 +121,8 @@ async def async_setup_entity_json(hass: HomeAssistantType, config: ConfigType,
|
|||
)])
|
||||
|
||||
|
||||
class MqttLightJson(MqttAvailability, MqttDiscoveryUpdate, Light):
|
||||
class MqttLightJson(MqttAvailability, MqttDiscoveryUpdate, Light,
|
||||
RestoreEntity):
|
||||
"""Representation of a MQTT JSON light."""
|
||||
|
||||
def __init__(self, name, unique_id, effect_list, topic, qos, retain,
|
||||
|
@ -183,10 +184,9 @@ class MqttLightJson(MqttAvailability, MqttDiscoveryUpdate, Light):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Subscribe to MQTT events."""
|
||||
await MqttAvailability.async_added_to_hass(self)
|
||||
await MqttDiscoveryUpdate.async_added_to_hass(self)
|
||||
await super().async_added_to_hass()
|
||||
|
||||
last_state = await async_get_last_state(self.hass, self.entity_id)
|
||||
last_state = await self.async_get_last_state()
|
||||
|
||||
@callback
|
||||
def state_received(topic, payload, qos):
|
||||
|
|
|
@ -21,7 +21,7 @@ from homeassistant.components.mqtt import (
|
|||
MqttAvailability)
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
import homeassistant.util.color as color_util
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -102,7 +102,7 @@ async def async_setup_entity_template(hass, config, async_add_entities,
|
|||
)])
|
||||
|
||||
|
||||
class MqttTemplate(MqttAvailability, Light):
|
||||
class MqttTemplate(MqttAvailability, Light, RestoreEntity):
|
||||
"""Representation of a MQTT Template light."""
|
||||
|
||||
def __init__(self, hass, name, effect_list, topics, templates, optimistic,
|
||||
|
@ -153,7 +153,7 @@ class MqttTemplate(MqttAvailability, Light):
|
|||
"""Subscribe to MQTT events."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
last_state = await async_get_last_state(self.hass, self.entity_id)
|
||||
last_state = await self.async_get_last_state()
|
||||
|
||||
@callback
|
||||
def state_received(topic, payload, qos):
|
||||
|
|
|
@ -111,8 +111,7 @@ class MqttLock(MqttAvailability, MqttDiscoveryUpdate, LockDevice):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Subscribe to MQTT events."""
|
||||
await MqttAvailability.async_added_to_hass(self)
|
||||
await MqttDiscoveryUpdate.async_added_to_hass(self)
|
||||
await super().async_added_to_hass()
|
||||
|
||||
@callback
|
||||
def message_received(topic, payload, qos):
|
||||
|
|
|
@ -840,6 +840,7 @@ class MqttAvailability(Entity):
|
|||
|
||||
This method must be run in the event loop and returns a coroutine.
|
||||
"""
|
||||
await super().async_added_to_hass()
|
||||
await self._availability_subscribe_topics()
|
||||
|
||||
async def availability_discovery_update(self, config: dict):
|
||||
|
@ -900,6 +901,8 @@ class MqttDiscoveryUpdate(Entity):
|
|||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Subscribe to discovery updates."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.components.mqtt.discovery import (
|
||||
ALREADY_DISCOVERED, MQTT_DISCOVERY_UPDATED)
|
||||
|
|
|
@ -28,7 +28,6 @@ import homeassistant.helpers.config_validation as cv
|
|||
from homeassistant.helpers.entityfilter import generate_filter
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.loader import bind_hass
|
||||
|
||||
from . import migration, purge
|
||||
from .const import DATA_INSTANCE
|
||||
|
@ -83,12 +82,6 @@ CONFIG_SCHEMA = vol.Schema({
|
|||
}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def wait_connection_ready(hass):
|
||||
"""Wait till the connection is ready."""
|
||||
return await hass.data[DATA_INSTANCE].async_db_ready
|
||||
|
||||
|
||||
def run_information(hass, point_in_time: Optional[datetime] = None):
|
||||
"""Return information about current run.
|
||||
|
||||
|
|
|
@ -10,9 +10,8 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.components.sensor import DOMAIN, PLATFORM_SCHEMA
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.event import track_time_change
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
REQUIREMENTS = ['fastdotcom==0.0.3']
|
||||
|
@ -51,7 +50,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
|
|||
hass.services.register(DOMAIN, 'update_fastdotcom', update)
|
||||
|
||||
|
||||
class SpeedtestSensor(Entity):
|
||||
class SpeedtestSensor(RestoreEntity):
|
||||
"""Implementation of a FAst.com sensor."""
|
||||
|
||||
def __init__(self, speedtest_data):
|
||||
|
@ -86,7 +85,8 @@ class SpeedtestSensor(Entity):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Handle entity which will be added."""
|
||||
state = await async_get_last_state(self.hass, self.entity_id)
|
||||
await super().async_added_to_hass()
|
||||
state = await self.async_get_last_state()
|
||||
if not state:
|
||||
return
|
||||
self._state = state.state
|
||||
|
|
|
@ -119,8 +119,7 @@ class MqttSensor(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo,
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Subscribe to MQTT events."""
|
||||
await MqttAvailability.async_added_to_hass(self)
|
||||
await MqttDiscoveryUpdate.async_added_to_hass(self)
|
||||
await super().async_added_to_hass()
|
||||
await self._subscribe_topics()
|
||||
|
||||
async def discovery_update(self, discovery_payload):
|
||||
|
|
|
@ -11,9 +11,8 @@ import voluptuous as vol
|
|||
from homeassistant.components.sensor import DOMAIN, PLATFORM_SCHEMA
|
||||
from homeassistant.const import ATTR_ATTRIBUTION, CONF_MONITORED_CONDITIONS
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.event import track_time_change
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
REQUIREMENTS = ['speedtest-cli==2.0.2']
|
||||
|
@ -76,7 +75,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
|
|||
hass.services.register(DOMAIN, 'update_speedtest', update)
|
||||
|
||||
|
||||
class SpeedtestSensor(Entity):
|
||||
class SpeedtestSensor(RestoreEntity):
|
||||
"""Implementation of a speedtest.net sensor."""
|
||||
|
||||
def __init__(self, speedtest_data, sensor_type):
|
||||
|
@ -137,7 +136,8 @@ class SpeedtestSensor(Entity):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Handle all entity which are about to be added."""
|
||||
state = await async_get_last_state(self.hass, self.entity_id)
|
||||
await super().async_added_to_hass()
|
||||
state = await self.async_get_last_state()
|
||||
if not state:
|
||||
return
|
||||
self._state = state.state
|
||||
|
|
|
@ -24,7 +24,7 @@ from homeassistant.components import mqtt, switch
|
|||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.typing import HomeAssistantType, ConfigType
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -102,8 +102,9 @@ async def _async_setup_entity(hass, config, async_add_entities,
|
|||
async_add_entities([newswitch])
|
||||
|
||||
|
||||
# pylint: disable=too-many-ancestors
|
||||
class MqttSwitch(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo,
|
||||
SwitchDevice):
|
||||
SwitchDevice, RestoreEntity):
|
||||
"""Representation of a switch that can be toggled using MQTT."""
|
||||
|
||||
def __init__(self, name, icon,
|
||||
|
@ -136,8 +137,7 @@ class MqttSwitch(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo,
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Subscribe to MQTT events."""
|
||||
await MqttAvailability.async_added_to_hass(self)
|
||||
await MqttDiscoveryUpdate.async_added_to_hass(self)
|
||||
await super().async_added_to_hass()
|
||||
|
||||
@callback
|
||||
def state_message_received(topic, payload, qos):
|
||||
|
@ -161,8 +161,7 @@ class MqttSwitch(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo,
|
|||
self._qos)
|
||||
|
||||
if self._optimistic:
|
||||
last_state = await async_get_last_state(self.hass,
|
||||
self.entity_id)
|
||||
last_state = await self.async_get_last_state()
|
||||
if last_state:
|
||||
self._state = last_state.state == STATE_ON
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from homeassistant.components import pilight
|
|||
from homeassistant.components.switch import (SwitchDevice, PLATFORM_SCHEMA)
|
||||
from homeassistant.const import (CONF_NAME, CONF_ID, CONF_SWITCHES, CONF_STATE,
|
||||
CONF_PROTOCOL, STATE_ON)
|
||||
from homeassistant.helpers.restore_state import async_get_last_state
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -97,7 +97,7 @@ class _ReceiveHandle:
|
|||
switch.set_state(turn_on=turn_on, send_code=self.echo)
|
||||
|
||||
|
||||
class PilightSwitch(SwitchDevice):
|
||||
class PilightSwitch(SwitchDevice, RestoreEntity):
|
||||
"""Representation of a Pilight switch."""
|
||||
|
||||
def __init__(self, hass, name, code_on, code_off, code_on_receive,
|
||||
|
@ -123,7 +123,8 @@ class PilightSwitch(SwitchDevice):
|
|||
|
||||
async def async_added_to_hass(self):
|
||||
"""Call when entity about to be added to hass."""
|
||||
state = await async_get_last_state(self._hass, self.entity_id)
|
||||
await super().async_added_to_hass()
|
||||
state = await self.async_get_last_state()
|
||||
if state:
|
||||
self._state = state.state == STATE_ON
|
||||
|
||||
|
|
|
@ -12,9 +12,9 @@ import voluptuous as vol
|
|||
import homeassistant.util.dt as dt_util
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.const import (ATTR_ENTITY_ID, CONF_ICON, CONF_NAME)
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.event import async_track_point_in_utc_time
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -97,7 +97,7 @@ async def async_setup(hass, config):
|
|||
return True
|
||||
|
||||
|
||||
class Timer(Entity):
|
||||
class Timer(RestoreEntity):
|
||||
"""Representation of a timer."""
|
||||
|
||||
def __init__(self, hass, object_id, name, icon, duration):
|
||||
|
@ -146,8 +146,7 @@ class Timer(Entity):
|
|||
if self._state is not None:
|
||||
return
|
||||
|
||||
restore_state = self._hass.helpers.restore_state
|
||||
state = await restore_state.async_get_last_state(self.entity_id)
|
||||
state = await self.async_get_last_state()
|
||||
self._state = state and state.state == state
|
||||
|
||||
async def async_start(self, duration):
|
||||
|
|
|
@ -363,10 +363,7 @@ class Entity:
|
|||
|
||||
async def async_remove(self):
|
||||
"""Remove entity from Home Assistant."""
|
||||
will_remove = getattr(self, 'async_will_remove_from_hass', None)
|
||||
|
||||
if will_remove:
|
||||
await will_remove() # pylint: disable=not-callable
|
||||
await self.async_will_remove_from_hass()
|
||||
|
||||
if self._on_remove is not None:
|
||||
while self._on_remove:
|
||||
|
@ -390,6 +387,12 @@ class Entity:
|
|||
|
||||
self.hass.async_create_task(readd())
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity about to be added to hass."""
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
|
||||
def __eq__(self, other):
|
||||
"""Return the comparison."""
|
||||
if not isinstance(other, self.__class__):
|
||||
|
|
|
@ -346,8 +346,7 @@ class EntityPlatform:
|
|||
self.entities[entity_id] = entity
|
||||
entity.async_on_remove(lambda: self.entities.pop(entity_id))
|
||||
|
||||
if hasattr(entity, 'async_added_to_hass'):
|
||||
await entity.async_added_to_hass()
|
||||
await entity.async_added_to_hass()
|
||||
|
||||
await entity.async_update_ha_state()
|
||||
|
||||
|
|
|
@ -2,97 +2,174 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Any, Dict, List, Set, Optional # noqa pylint_disable=unused-import
|
||||
|
||||
import async_timeout
|
||||
|
||||
from homeassistant.core import HomeAssistant, CoreState, callback
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_START
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.components.history import get_states, last_recorder_run
|
||||
from homeassistant.components.recorder import (
|
||||
wait_connection_ready, DOMAIN as _RECORDER)
|
||||
from homeassistant.core import HomeAssistant, callback, State, CoreState
|
||||
from homeassistant.const import (
|
||||
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP)
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.event import async_track_time_interval
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.json import JSONEncoder
|
||||
from homeassistant.helpers.storage import Store # noqa pylint_disable=unused-import
|
||||
|
||||
DATA_RESTORE_STATE_TASK = 'restore_state_task'
|
||||
|
||||
RECORDER_TIMEOUT = 10
|
||||
DATA_RESTORE_CACHE = 'restore_state_cache'
|
||||
_LOCK = 'restore_lock'
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
STORAGE_KEY = 'core.restore_state'
|
||||
STORAGE_VERSION = 1
|
||||
|
||||
# How long between periodically saving the current states to disk
|
||||
STATE_DUMP_INTERVAL = timedelta(minutes=15)
|
||||
|
||||
# How long should a saved state be preserved if the entity no longer exists
|
||||
STATE_EXPIRATION = timedelta(days=7)
|
||||
|
||||
|
||||
class RestoreStateData():
|
||||
"""Helper class for managing the helper saved data."""
|
||||
|
||||
@classmethod
|
||||
async def async_get_instance(
|
||||
cls, hass: HomeAssistant) -> 'RestoreStateData':
|
||||
"""Get the singleton instance of this data helper."""
|
||||
task = hass.data.get(DATA_RESTORE_STATE_TASK)
|
||||
|
||||
if task is None:
|
||||
async def load_instance(hass: HomeAssistant) -> 'RestoreStateData':
|
||||
"""Set up the restore state helper."""
|
||||
data = cls(hass)
|
||||
|
||||
try:
|
||||
states = await data.store.async_load()
|
||||
except HomeAssistantError as exc:
|
||||
_LOGGER.error("Error loading last states", exc_info=exc)
|
||||
states = None
|
||||
|
||||
if states is None:
|
||||
_LOGGER.debug('Not creating cache - no saved states found')
|
||||
data.last_states = {}
|
||||
else:
|
||||
data.last_states = {
|
||||
state['entity_id']: State.from_dict(state)
|
||||
for state in states}
|
||||
_LOGGER.debug(
|
||||
'Created cache with %s', list(data.last_states))
|
||||
|
||||
if hass.state == CoreState.running:
|
||||
data.async_setup_dump()
|
||||
else:
|
||||
hass.bus.async_listen_once(
|
||||
EVENT_HOMEASSISTANT_START, data.async_setup_dump)
|
||||
|
||||
return data
|
||||
|
||||
task = hass.data[DATA_RESTORE_STATE_TASK] = hass.async_create_task(
|
||||
load_instance(hass))
|
||||
|
||||
return await task
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the restore state data class."""
|
||||
self.hass = hass # type: HomeAssistant
|
||||
self.store = Store(hass, STORAGE_VERSION, STORAGE_KEY,
|
||||
encoder=JSONEncoder) # type: Store
|
||||
self.last_states = {} # type: Dict[str, State]
|
||||
self.entity_ids = set() # type: Set[str]
|
||||
|
||||
def async_get_states(self) -> List[State]:
|
||||
"""Get the set of states which should be stored.
|
||||
|
||||
This includes the states of all registered entities, as well as the
|
||||
stored states from the previous run, which have not been created as
|
||||
entities on this run, and have not expired.
|
||||
"""
|
||||
all_states = self.hass.states.async_all()
|
||||
current_entity_ids = set(state.entity_id for state in all_states)
|
||||
|
||||
# Start with the currently registered states
|
||||
states = [state for state in all_states
|
||||
if state.entity_id in self.entity_ids]
|
||||
|
||||
expiration_time = dt_util.utcnow() - STATE_EXPIRATION
|
||||
|
||||
for entity_id, state in self.last_states.items():
|
||||
# Don't save old states that have entities in the current run
|
||||
if entity_id in current_entity_ids:
|
||||
continue
|
||||
|
||||
# Don't save old states that have expired
|
||||
if state.last_updated < expiration_time:
|
||||
continue
|
||||
|
||||
states.append(state)
|
||||
|
||||
return states
|
||||
|
||||
async def async_dump_states(self) -> None:
|
||||
"""Save the current state machine to storage."""
|
||||
_LOGGER.debug("Dumping states")
|
||||
try:
|
||||
await self.store.async_save([
|
||||
state.as_dict() for state in self.async_get_states()])
|
||||
except HomeAssistantError as exc:
|
||||
_LOGGER.error("Error saving current states", exc_info=exc)
|
||||
|
||||
def _load_restore_cache(hass: HomeAssistant):
|
||||
"""Load the restore cache to be used by other components."""
|
||||
@callback
|
||||
def remove_cache(event):
|
||||
"""Remove the states cache."""
|
||||
hass.data.pop(DATA_RESTORE_CACHE, None)
|
||||
def async_setup_dump(self, *args: Any) -> None:
|
||||
"""Set up the restore state listeners."""
|
||||
# Dump the initial states now. This helps minimize the risk of having
|
||||
# old states loaded by overwritting the last states once home assistant
|
||||
# has started and the old states have been read.
|
||||
self.hass.async_create_task(self.async_dump_states())
|
||||
|
||||
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, remove_cache)
|
||||
# Dump states periodically
|
||||
async_track_time_interval(
|
||||
self.hass, lambda *_: self.hass.async_create_task(
|
||||
self.async_dump_states()), STATE_DUMP_INTERVAL)
|
||||
|
||||
last_run = last_recorder_run(hass)
|
||||
# Dump states when stopping hass
|
||||
self.hass.bus.async_listen_once(
|
||||
EVENT_HOMEASSISTANT_STOP, lambda *_: self.hass.async_create_task(
|
||||
self.async_dump_states()))
|
||||
|
||||
if last_run is None or last_run.end is None:
|
||||
_LOGGER.debug('Not creating cache - no suitable last run found: %s',
|
||||
last_run)
|
||||
hass.data[DATA_RESTORE_CACHE] = {}
|
||||
return
|
||||
@callback
|
||||
def async_register_entity(self, entity_id: str) -> None:
|
||||
"""Store this entity's state when hass is shutdown."""
|
||||
self.entity_ids.add(entity_id)
|
||||
|
||||
last_end_time = last_run.end - timedelta(seconds=1)
|
||||
# Unfortunately the recorder_run model do not return offset-aware time
|
||||
last_end_time = last_end_time.replace(tzinfo=dt_util.UTC)
|
||||
_LOGGER.debug("Last run: %s - %s", last_run.start, last_end_time)
|
||||
|
||||
states = get_states(hass, last_end_time, run=last_run)
|
||||
|
||||
# Cache the states
|
||||
hass.data[DATA_RESTORE_CACHE] = {
|
||||
state.entity_id: state for state in states}
|
||||
_LOGGER.debug('Created cache with %s', list(hass.data[DATA_RESTORE_CACHE]))
|
||||
@callback
|
||||
def async_unregister_entity(self, entity_id: str) -> None:
|
||||
"""Unregister this entity from saving state."""
|
||||
self.entity_ids.remove(entity_id)
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_get_last_state(hass, entity_id: str):
|
||||
"""Restore state."""
|
||||
if DATA_RESTORE_CACHE in hass.data:
|
||||
return hass.data[DATA_RESTORE_CACHE].get(entity_id)
|
||||
class RestoreEntity(Entity):
|
||||
"""Mixin class for restoring previous entity state."""
|
||||
|
||||
if _RECORDER not in hass.config.components:
|
||||
return None
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Register this entity as a restorable entity."""
|
||||
_, data = await asyncio.gather(
|
||||
super().async_added_to_hass(),
|
||||
RestoreStateData.async_get_instance(self.hass),
|
||||
)
|
||||
data.async_register_entity(self.entity_id)
|
||||
|
||||
if hass.state not in (CoreState.starting, CoreState.not_running):
|
||||
_LOGGER.debug("Cache for %s can only be loaded during startup, not %s",
|
||||
entity_id, hass.state)
|
||||
return None
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
_, data = await asyncio.gather(
|
||||
super().async_will_remove_from_hass(),
|
||||
RestoreStateData.async_get_instance(self.hass),
|
||||
)
|
||||
data.async_unregister_entity(self.entity_id)
|
||||
|
||||
try:
|
||||
with async_timeout.timeout(RECORDER_TIMEOUT, loop=hass.loop):
|
||||
connected = await wait_connection_ready(hass)
|
||||
except asyncio.TimeoutError:
|
||||
return None
|
||||
|
||||
if not connected:
|
||||
return None
|
||||
|
||||
if _LOCK not in hass.data:
|
||||
hass.data[_LOCK] = asyncio.Lock(loop=hass.loop)
|
||||
|
||||
async with hass.data[_LOCK]:
|
||||
if DATA_RESTORE_CACHE not in hass.data:
|
||||
await hass.async_add_job(
|
||||
_load_restore_cache, hass)
|
||||
|
||||
return hass.data.get(DATA_RESTORE_CACHE, {}).get(entity_id)
|
||||
|
||||
|
||||
async def async_restore_state(entity, extract_info):
|
||||
"""Call entity.async_restore_state with cached info."""
|
||||
if entity.hass.state not in (CoreState.starting, CoreState.not_running):
|
||||
_LOGGER.debug("Not restoring state for %s: Hass is not starting: %s",
|
||||
entity.entity_id, entity.hass.state)
|
||||
return
|
||||
|
||||
state = await async_get_last_state(entity.hass, entity.entity_id)
|
||||
|
||||
if not state:
|
||||
return
|
||||
|
||||
await entity.async_restore_state(**extract_info(state))
|
||||
async def async_get_last_state(self) -> Optional[State]:
|
||||
"""Get the entity state from the previous run."""
|
||||
if self.hass is None or self.entity_id is None:
|
||||
# Return None if this entity isn't added to hass yet
|
||||
_LOGGER.warning("Cannot get last state. Entity not added to hass")
|
||||
return None
|
||||
data = await RestoreStateData.async_get_instance(self.hass)
|
||||
return data.last_states.get(self.entity_id)
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
"""Helper to help store data."""
|
||||
import asyncio
|
||||
from json import JSONEncoder
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Optional, Callable, Any
|
||||
from typing import Dict, List, Optional, Callable, Union
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import json
|
||||
from homeassistant.util import json as json_util
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
|
||||
STORAGE_DIR = '.storage'
|
||||
|
@ -16,7 +17,7 @@ _LOGGER = logging.getLogger(__name__)
|
|||
|
||||
@bind_hass
|
||||
async def async_migrator(hass, old_path, store, *,
|
||||
old_conf_load_func=json.load_json,
|
||||
old_conf_load_func=json_util.load_json,
|
||||
old_conf_migrate_func=None):
|
||||
"""Migrate old data to a store and then load data.
|
||||
|
||||
|
@ -46,7 +47,8 @@ async def async_migrator(hass, old_path, store, *,
|
|||
class Store:
|
||||
"""Class to help storing data."""
|
||||
|
||||
def __init__(self, hass, version: int, key: str, private: bool = False):
|
||||
def __init__(self, hass, version: int, key: str, private: bool = False, *,
|
||||
encoder: JSONEncoder = None):
|
||||
"""Initialize storage class."""
|
||||
self.version = version
|
||||
self.key = key
|
||||
|
@ -57,13 +59,14 @@ class Store:
|
|||
self._unsub_stop_listener = None
|
||||
self._write_lock = asyncio.Lock(loop=hass.loop)
|
||||
self._load_task = None
|
||||
self._encoder = encoder
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
"""Return the config path."""
|
||||
return self.hass.config.path(STORAGE_DIR, self.key)
|
||||
|
||||
async def async_load(self) -> Optional[Dict[str, Any]]:
|
||||
async def async_load(self) -> Optional[Union[Dict, List]]:
|
||||
"""Load data.
|
||||
|
||||
If the expected version does not match the given version, the migrate
|
||||
|
@ -88,7 +91,7 @@ class Store:
|
|||
data['data'] = data.pop('data_func')()
|
||||
else:
|
||||
data = await self.hass.async_add_executor_job(
|
||||
json.load_json, self.path)
|
||||
json_util.load_json, self.path)
|
||||
|
||||
if data == {}:
|
||||
return None
|
||||
|
@ -103,7 +106,7 @@ class Store:
|
|||
self._load_task = None
|
||||
return stored
|
||||
|
||||
async def async_save(self, data):
|
||||
async def async_save(self, data: Union[Dict, List]) -> None:
|
||||
"""Save data."""
|
||||
self._data = {
|
||||
'version': self.version,
|
||||
|
@ -178,7 +181,7 @@ class Store:
|
|||
try:
|
||||
await self.hass.async_add_executor_job(
|
||||
self._write_data, self.path, data)
|
||||
except (json.SerializationError, json.WriteError) as err:
|
||||
except (json_util.SerializationError, json_util.WriteError) as err:
|
||||
_LOGGER.error('Error writing config for %s: %s', self.key, err)
|
||||
|
||||
def _write_data(self, path: str, data: Dict):
|
||||
|
@ -187,7 +190,7 @@ class Store:
|
|||
os.makedirs(os.path.dirname(path))
|
||||
|
||||
_LOGGER.debug('Writing data for %s', self.key)
|
||||
json.save_json(path, data, self._private)
|
||||
json_util.save_json(path, data, self._private, encoder=self._encoder)
|
||||
|
||||
async def _async_migrate_func(self, old_version, old_data):
|
||||
"""Migrate to the new version."""
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
"""JSON utility functions."""
|
||||
import logging
|
||||
from typing import Union, List, Dict
|
||||
from typing import Union, List, Dict, Optional
|
||||
|
||||
import json
|
||||
import os
|
||||
|
@ -41,7 +41,8 @@ def load_json(filename: str, default: Union[List, Dict, None] = None) \
|
|||
|
||||
|
||||
def save_json(filename: str, data: Union[List, Dict],
|
||||
private: bool = False) -> None:
|
||||
private: bool = False, *,
|
||||
encoder: Optional[json.JSONEncoder] = None) -> None:
|
||||
"""Save JSON data to a file.
|
||||
|
||||
Returns True on success.
|
||||
|
@ -49,7 +50,7 @@ def save_json(filename: str, data: Union[List, Dict],
|
|||
tmp_filename = ""
|
||||
tmp_path = os.path.split(filename)[0]
|
||||
try:
|
||||
json_data = json.dumps(data, sort_keys=True, indent=4)
|
||||
json_data = json.dumps(data, sort_keys=True, indent=4, cls=encoder)
|
||||
# Modern versions of Python tempfile create this file with mode 0o600
|
||||
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8',
|
||||
dir=tmp_path, delete=False) as fdesc:
|
||||
|
|
|
@ -114,8 +114,7 @@ def get_test_home_assistant():
|
|||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
@asyncio.coroutine
|
||||
def async_test_home_assistant(loop):
|
||||
async def async_test_home_assistant(loop):
|
||||
"""Return a Home Assistant object pointing at test config dir."""
|
||||
hass = ha.HomeAssistant(loop)
|
||||
hass.config.async_load = Mock()
|
||||
|
@ -168,13 +167,12 @@ def async_test_home_assistant(loop):
|
|||
# Mock async_start
|
||||
orig_start = hass.async_start
|
||||
|
||||
@asyncio.coroutine
|
||||
def mock_async_start():
|
||||
async def mock_async_start():
|
||||
"""Start the mocking."""
|
||||
# We only mock time during tests and we want to track tasks
|
||||
with patch('homeassistant.core._async_create_timer'), \
|
||||
patch.object(hass, 'async_stop_track_tasks'):
|
||||
yield from orig_start()
|
||||
await orig_start()
|
||||
|
||||
hass.async_start = mock_async_start
|
||||
|
||||
|
@ -715,14 +713,20 @@ def init_recorder_component(hass, add_config=None):
|
|||
|
||||
def mock_restore_cache(hass, states):
|
||||
"""Mock the DATA_RESTORE_CACHE."""
|
||||
key = restore_state.DATA_RESTORE_CACHE
|
||||
hass.data[key] = {
|
||||
key = restore_state.DATA_RESTORE_STATE_TASK
|
||||
data = restore_state.RestoreStateData(hass)
|
||||
|
||||
data.last_states = {
|
||||
state.entity_id: state for state in states}
|
||||
_LOGGER.debug('Restore cache: %s', hass.data[key])
|
||||
assert len(hass.data[key]) == len(states), \
|
||||
_LOGGER.debug('Restore cache: %s', data.last_states)
|
||||
assert len(data.last_states) == len(states), \
|
||||
"Duplicate entity_id? {}".format(states)
|
||||
hass.state = ha.CoreState.starting
|
||||
mock_component(hass, recorder.DOMAIN)
|
||||
|
||||
async def get_restore_state_data() -> restore_state.RestoreStateData:
|
||||
return data
|
||||
|
||||
# Patch the singleton task in hass.data to return our new RestoreStateData
|
||||
hass.data[key] = hass.async_create_task(get_restore_state_data())
|
||||
|
||||
|
||||
class MockDependency:
|
||||
|
@ -846,9 +850,10 @@ def mock_storage(data=None):
|
|||
|
||||
def mock_write_data(store, path, data_to_write):
|
||||
"""Mock version of write data."""
|
||||
# To ensure that the data can be serialized
|
||||
_LOGGER.info('Writing data to %s: %s', store.key, data_to_write)
|
||||
data[store.key] = json.loads(json.dumps(data_to_write))
|
||||
# To ensure that the data can be serialized
|
||||
data[store.key] = json.loads(json.dumps(
|
||||
data_to_write, cls=store._encoder))
|
||||
|
||||
with patch('homeassistant.helpers.storage.Store._async_load',
|
||||
side_effect=mock_async_load, autospec=True), \
|
||||
|
|
|
@ -6,10 +6,8 @@ from unittest.mock import patch
|
|||
import requests
|
||||
from aiohttp.hdrs import CONTENT_TYPE
|
||||
|
||||
from homeassistant import setup, const, core
|
||||
import homeassistant.components as core_components
|
||||
from homeassistant import setup, const
|
||||
from homeassistant.components import emulated_hue, http
|
||||
from homeassistant.util.async_ import run_coroutine_threadsafe
|
||||
|
||||
from tests.common import get_test_instance_port, get_test_home_assistant
|
||||
|
||||
|
@ -20,29 +18,6 @@ BRIDGE_URL_BASE = 'http://127.0.0.1:{}'.format(BRIDGE_SERVER_PORT) + '{}'
|
|||
JSON_HEADERS = {CONTENT_TYPE: const.CONTENT_TYPE_JSON}
|
||||
|
||||
|
||||
def setup_hass_instance(emulated_hue_config):
|
||||
"""Set up the Home Assistant instance to test."""
|
||||
hass = get_test_home_assistant()
|
||||
|
||||
# We need to do this to get access to homeassistant/turn_(on,off)
|
||||
run_coroutine_threadsafe(
|
||||
core_components.async_setup(hass, {core.DOMAIN: {}}), hass.loop
|
||||
).result()
|
||||
|
||||
setup.setup_component(
|
||||
hass, http.DOMAIN,
|
||||
{http.DOMAIN: {http.CONF_SERVER_PORT: HTTP_SERVER_PORT}})
|
||||
|
||||
setup.setup_component(hass, emulated_hue.DOMAIN, emulated_hue_config)
|
||||
|
||||
return hass
|
||||
|
||||
|
||||
def start_hass_instance(hass):
|
||||
"""Start the Home Assistant instance to test."""
|
||||
hass.start()
|
||||
|
||||
|
||||
class TestEmulatedHue(unittest.TestCase):
|
||||
"""Test the emulated Hue component."""
|
||||
|
||||
|
@ -53,11 +28,6 @@ class TestEmulatedHue(unittest.TestCase):
|
|||
"""Set up the class."""
|
||||
cls.hass = hass = get_test_home_assistant()
|
||||
|
||||
# We need to do this to get access to homeassistant/turn_(on,off)
|
||||
run_coroutine_threadsafe(
|
||||
core_components.async_setup(hass, {core.DOMAIN: {}}), hass.loop
|
||||
).result()
|
||||
|
||||
setup.setup_component(
|
||||
hass, http.DOMAIN,
|
||||
{http.DOMAIN: {http.CONF_SERVER_PORT: HTTP_SERVER_PORT}})
|
||||
|
|
|
@ -585,7 +585,7 @@ async def test_sending_mqtt_commands_and_optimistic(hass, mqtt_mock):
|
|||
'effect': 'random',
|
||||
'color_temp': 100,
|
||||
'white_value': 50})
|
||||
with patch('homeassistant.components.light.mqtt.schema_basic'
|
||||
with patch('homeassistant.helpers.restore_state.RestoreEntity'
|
||||
'.async_get_last_state',
|
||||
return_value=mock_coro(fake_state)):
|
||||
with assert_setup_component(1, light.DOMAIN):
|
||||
|
|
|
@ -279,7 +279,7 @@ async def test_sending_mqtt_commands_and_optimistic(hass, mqtt_mock):
|
|||
'color_temp': 100,
|
||||
'white_value': 50})
|
||||
|
||||
with patch('homeassistant.components.light.mqtt.schema_json'
|
||||
with patch('homeassistant.helpers.restore_state.RestoreEntity'
|
||||
'.async_get_last_state',
|
||||
return_value=mock_coro(fake_state)):
|
||||
assert await async_setup_component(hass, light.DOMAIN, {
|
||||
|
|
|
@ -245,7 +245,7 @@ async def test_optimistic(hass, mqtt_mock):
|
|||
'color_temp': 100,
|
||||
'white_value': 50})
|
||||
|
||||
with patch('homeassistant.components.light.mqtt.schema_template'
|
||||
with patch('homeassistant.helpers.restore_state.RestoreEntity'
|
||||
'.async_get_last_state',
|
||||
return_value=mock_coro(fake_state)):
|
||||
with assert_setup_component(1, light.DOMAIN):
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
"""The tests for the Recorder component."""
|
||||
# pylint: disable=protected-access
|
||||
import asyncio
|
||||
from unittest.mock import patch, call
|
||||
|
||||
import pytest
|
||||
|
@ -9,7 +8,7 @@ from sqlalchemy.pool import StaticPool
|
|||
|
||||
from homeassistant.bootstrap import async_setup_component
|
||||
from homeassistant.components.recorder import (
|
||||
wait_connection_ready, migration, const, models)
|
||||
migration, const, models)
|
||||
from tests.components.recorder import models_original
|
||||
|
||||
|
||||
|
@ -23,26 +22,24 @@ def create_engine_test(*args, **kwargs):
|
|||
return engine
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_schema_update_calls(hass):
|
||||
async def test_schema_update_calls(hass):
|
||||
"""Test that schema migrations occur in correct order."""
|
||||
with patch('sqlalchemy.create_engine', new=create_engine_test), \
|
||||
patch('homeassistant.components.recorder.migration._apply_update') as \
|
||||
update:
|
||||
yield from async_setup_component(hass, 'recorder', {
|
||||
await async_setup_component(hass, 'recorder', {
|
||||
'recorder': {
|
||||
'db_url': 'sqlite://'
|
||||
}
|
||||
})
|
||||
yield from wait_connection_ready(hass)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
update.assert_has_calls([
|
||||
call(hass.data[const.DATA_INSTANCE].engine, version+1, 0) for version
|
||||
in range(0, models.SCHEMA_VERSION)])
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_schema_migrate(hass):
|
||||
async def test_schema_migrate(hass):
|
||||
"""Test the full schema migration logic.
|
||||
|
||||
We're just testing that the logic can execute successfully here without
|
||||
|
@ -52,12 +49,12 @@ def test_schema_migrate(hass):
|
|||
with patch('sqlalchemy.create_engine', new=create_engine_test), \
|
||||
patch('homeassistant.components.recorder.Recorder._setup_run') as \
|
||||
setup_run:
|
||||
yield from async_setup_component(hass, 'recorder', {
|
||||
await async_setup_component(hass, 'recorder', {
|
||||
'recorder': {
|
||||
'db_url': 'sqlite://'
|
||||
}
|
||||
})
|
||||
yield from wait_connection_ready(hass)
|
||||
await hass.async_block_till_done()
|
||||
assert setup_run.called
|
||||
|
||||
|
||||
|
|
|
@ -57,7 +57,8 @@ async def test_sending_mqtt_commands_and_optimistic(hass, mock_publish):
|
|||
"""Test the sending MQTT commands in optimistic mode."""
|
||||
fake_state = ha.State('switch.test', 'on')
|
||||
|
||||
with patch('homeassistant.components.switch.mqtt.async_get_last_state',
|
||||
with patch('homeassistant.helpers.restore_state.RestoreEntity'
|
||||
'.async_get_last_state',
|
||||
return_value=mock_coro(fake_state)):
|
||||
assert await async_setup_component(hass, switch.DOMAIN, {
|
||||
switch.DOMAIN: {
|
||||
|
|
|
@ -519,7 +519,6 @@ async def test_fetch_period_api(hass, hass_client):
|
|||
"""Test the fetch period view for history."""
|
||||
await hass.async_add_job(init_recorder_component, hass)
|
||||
await async_setup_component(hass, 'history', {})
|
||||
await hass.components.recorder.wait_connection_ready()
|
||||
await hass.async_add_job(hass.data[recorder.DATA_INSTANCE].block_till_done)
|
||||
client = await hass_client()
|
||||
response = await client.get(
|
||||
|
|
|
@ -575,7 +575,6 @@ async def test_logbook_view(hass, aiohttp_client):
|
|||
"""Test the logbook view."""
|
||||
await hass.async_add_job(init_recorder_component, hass)
|
||||
await async_setup_component(hass, 'logbook', {})
|
||||
await hass.components.recorder.wait_connection_ready()
|
||||
await hass.async_add_job(hass.data[recorder.DATA_INSTANCE].block_till_done)
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
response = await client.get(
|
||||
|
@ -587,7 +586,6 @@ async def test_logbook_view_period_entity(hass, aiohttp_client):
|
|||
"""Test the logbook view with period and entity."""
|
||||
await hass.async_add_job(init_recorder_component, hass)
|
||||
await async_setup_component(hass, 'logbook', {})
|
||||
await hass.components.recorder.wait_connection_ready()
|
||||
await hass.async_add_job(hass.data[recorder.DATA_INSTANCE].block_till_done)
|
||||
|
||||
entity_id_test = 'switch.test'
|
||||
|
|
|
@ -1,60 +1,52 @@
|
|||
"""The tests for the Restore component."""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from unittest.mock import patch, MagicMock
|
||||
from datetime import datetime
|
||||
|
||||
from homeassistant.setup import setup_component
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_START
|
||||
from homeassistant.core import CoreState, split_entity_id, State
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.components import input_boolean, recorder
|
||||
from homeassistant.core import CoreState, State
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.restore_state import (
|
||||
async_get_last_state, DATA_RESTORE_CACHE)
|
||||
from homeassistant.components.recorder.models import RecorderRuns, States
|
||||
RestoreStateData, RestoreEntity, DATA_RESTORE_STATE_TASK)
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from tests.common import (
|
||||
get_test_home_assistant, mock_coro, init_recorder_component,
|
||||
mock_component)
|
||||
from asynctest import patch
|
||||
|
||||
from tests.common import mock_coro
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_caching_data(hass):
|
||||
async def test_caching_data(hass):
|
||||
"""Test that we cache data."""
|
||||
mock_component(hass, 'recorder')
|
||||
hass.state = CoreState.starting
|
||||
|
||||
states = [
|
||||
State('input_boolean.b0', 'on'),
|
||||
State('input_boolean.b1', 'on'),
|
||||
State('input_boolean.b2', 'on'),
|
||||
]
|
||||
|
||||
with patch('homeassistant.helpers.restore_state.last_recorder_run',
|
||||
return_value=MagicMock(end=dt_util.utcnow())), \
|
||||
patch('homeassistant.helpers.restore_state.get_states',
|
||||
return_value=states), \
|
||||
patch('homeassistant.helpers.restore_state.wait_connection_ready',
|
||||
return_value=mock_coro(True)):
|
||||
state = yield from async_get_last_state(hass, 'input_boolean.b1')
|
||||
data = await RestoreStateData.async_get_instance(hass)
|
||||
await data.store.async_save([state.as_dict() for state in states])
|
||||
|
||||
assert DATA_RESTORE_CACHE in hass.data
|
||||
assert hass.data[DATA_RESTORE_CACHE] == {st.entity_id: st for st in states}
|
||||
# Emulate a fresh load
|
||||
hass.data[DATA_RESTORE_STATE_TASK] = None
|
||||
|
||||
entity = RestoreEntity()
|
||||
entity.hass = hass
|
||||
entity.entity_id = 'input_boolean.b1'
|
||||
|
||||
# Mock that only b1 is present this run
|
||||
with patch('homeassistant.helpers.restore_state.Store.async_save'
|
||||
) as mock_write_data:
|
||||
state = await entity.async_get_last_state()
|
||||
|
||||
assert state is not None
|
||||
assert state.entity_id == 'input_boolean.b1'
|
||||
assert state.state == 'on'
|
||||
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_START)
|
||||
|
||||
yield from hass.async_block_till_done()
|
||||
|
||||
assert DATA_RESTORE_CACHE not in hass.data
|
||||
assert mock_write_data.called
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_hass_running(hass):
|
||||
"""Test that cache cannot be accessed while hass is running."""
|
||||
mock_component(hass, 'recorder')
|
||||
async def test_hass_starting(hass):
|
||||
"""Test that we cache data."""
|
||||
hass.state = CoreState.starting
|
||||
|
||||
states = [
|
||||
State('input_boolean.b0', 'on'),
|
||||
|
@ -62,129 +54,144 @@ def test_hass_running(hass):
|
|||
State('input_boolean.b2', 'on'),
|
||||
]
|
||||
|
||||
with patch('homeassistant.helpers.restore_state.last_recorder_run',
|
||||
return_value=MagicMock(end=dt_util.utcnow())), \
|
||||
patch('homeassistant.helpers.restore_state.get_states',
|
||||
return_value=states), \
|
||||
patch('homeassistant.helpers.restore_state.wait_connection_ready',
|
||||
return_value=mock_coro(True)):
|
||||
state = yield from async_get_last_state(hass, 'input_boolean.b1')
|
||||
assert state is None
|
||||
data = await RestoreStateData.async_get_instance(hass)
|
||||
await data.store.async_save([state.as_dict() for state in states])
|
||||
|
||||
# Emulate a fresh load
|
||||
hass.data[DATA_RESTORE_STATE_TASK] = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_not_connected(hass):
|
||||
"""Test that cache cannot be accessed if db connection times out."""
|
||||
mock_component(hass, 'recorder')
|
||||
hass.state = CoreState.starting
|
||||
entity = RestoreEntity()
|
||||
entity.hass = hass
|
||||
entity.entity_id = 'input_boolean.b1'
|
||||
|
||||
states = [State('input_boolean.b1', 'on')]
|
||||
# Mock that only b1 is present this run
|
||||
states = [
|
||||
State('input_boolean.b1', 'on'),
|
||||
]
|
||||
with patch('homeassistant.helpers.restore_state.Store.async_save'
|
||||
) as mock_write_data, patch.object(
|
||||
hass.states, 'async_all', return_value=states):
|
||||
state = await entity.async_get_last_state()
|
||||
|
||||
with patch('homeassistant.helpers.restore_state.last_recorder_run',
|
||||
return_value=MagicMock(end=dt_util.utcnow())), \
|
||||
patch('homeassistant.helpers.restore_state.get_states',
|
||||
return_value=states), \
|
||||
patch('homeassistant.helpers.restore_state.wait_connection_ready',
|
||||
return_value=mock_coro(False)):
|
||||
state = yield from async_get_last_state(hass, 'input_boolean.b1')
|
||||
assert state is None
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_no_last_run_found(hass):
|
||||
"""Test that cache cannot be accessed if no last run found."""
|
||||
mock_component(hass, 'recorder')
|
||||
hass.state = CoreState.starting
|
||||
|
||||
states = [State('input_boolean.b1', 'on')]
|
||||
|
||||
with patch('homeassistant.helpers.restore_state.last_recorder_run',
|
||||
return_value=None), \
|
||||
patch('homeassistant.helpers.restore_state.get_states',
|
||||
return_value=states), \
|
||||
patch('homeassistant.helpers.restore_state.wait_connection_ready',
|
||||
return_value=mock_coro(True)):
|
||||
state = yield from async_get_last_state(hass, 'input_boolean.b1')
|
||||
assert state is None
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_cache_timeout(hass):
|
||||
"""Test that cache timeout returns none."""
|
||||
mock_component(hass, 'recorder')
|
||||
hass.state = CoreState.starting
|
||||
|
||||
states = [State('input_boolean.b1', 'on')]
|
||||
|
||||
@asyncio.coroutine
|
||||
def timeout_coro():
|
||||
raise asyncio.TimeoutError()
|
||||
|
||||
with patch('homeassistant.helpers.restore_state.last_recorder_run',
|
||||
return_value=MagicMock(end=dt_util.utcnow())), \
|
||||
patch('homeassistant.helpers.restore_state.get_states',
|
||||
return_value=states), \
|
||||
patch('homeassistant.helpers.restore_state.wait_connection_ready',
|
||||
return_value=timeout_coro()):
|
||||
state = yield from async_get_last_state(hass, 'input_boolean.b1')
|
||||
assert state is None
|
||||
|
||||
|
||||
def _add_data_in_last_run(hass, entities):
|
||||
"""Add test data in the last recorder_run."""
|
||||
# pylint: disable=protected-access
|
||||
t_now = dt_util.utcnow() - timedelta(minutes=10)
|
||||
t_min_1 = t_now - timedelta(minutes=20)
|
||||
t_min_2 = t_now - timedelta(minutes=30)
|
||||
|
||||
with recorder.session_scope(hass=hass) as session:
|
||||
session.add(RecorderRuns(
|
||||
start=t_min_2,
|
||||
end=t_now,
|
||||
created=t_min_2
|
||||
))
|
||||
|
||||
for entity_id, state in entities.items():
|
||||
session.add(States(
|
||||
entity_id=entity_id,
|
||||
domain=split_entity_id(entity_id)[0],
|
||||
state=state,
|
||||
attributes='{}',
|
||||
last_changed=t_min_1,
|
||||
last_updated=t_min_1,
|
||||
created=t_min_1))
|
||||
|
||||
|
||||
def test_filling_the_cache():
|
||||
"""Test filling the cache from the DB."""
|
||||
test_entity_id1 = 'input_boolean.b1'
|
||||
test_entity_id2 = 'input_boolean.b2'
|
||||
|
||||
hass = get_test_home_assistant()
|
||||
hass.state = CoreState.starting
|
||||
|
||||
init_recorder_component(hass)
|
||||
|
||||
_add_data_in_last_run(hass, {
|
||||
test_entity_id1: 'on',
|
||||
test_entity_id2: 'off',
|
||||
})
|
||||
|
||||
hass.block_till_done()
|
||||
setup_component(hass, input_boolean.DOMAIN, {
|
||||
input_boolean.DOMAIN: {
|
||||
'b1': None,
|
||||
'b2': None,
|
||||
}})
|
||||
|
||||
hass.start()
|
||||
|
||||
state = hass.states.get('input_boolean.b1')
|
||||
assert state
|
||||
assert state is not None
|
||||
assert state.entity_id == 'input_boolean.b1'
|
||||
assert state.state == 'on'
|
||||
|
||||
state = hass.states.get('input_boolean.b2')
|
||||
assert state
|
||||
assert state.state == 'off'
|
||||
# Assert that no data was written yet, since hass is still starting.
|
||||
assert not mock_write_data.called
|
||||
|
||||
hass.stop()
|
||||
# Finish hass startup
|
||||
with patch('homeassistant.helpers.restore_state.Store.async_save'
|
||||
) as mock_write_data:
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_START)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Assert that this session states were written
|
||||
assert mock_write_data.called
|
||||
|
||||
|
||||
async def test_dump_data(hass):
|
||||
"""Test that we cache data."""
|
||||
states = [
|
||||
State('input_boolean.b0', 'on'),
|
||||
State('input_boolean.b1', 'on'),
|
||||
State('input_boolean.b2', 'on'),
|
||||
]
|
||||
|
||||
entity = Entity()
|
||||
entity.hass = hass
|
||||
entity.entity_id = 'input_boolean.b0'
|
||||
await entity.async_added_to_hass()
|
||||
|
||||
entity = RestoreEntity()
|
||||
entity.hass = hass
|
||||
entity.entity_id = 'input_boolean.b1'
|
||||
await entity.async_added_to_hass()
|
||||
|
||||
data = await RestoreStateData.async_get_instance(hass)
|
||||
data.last_states = {
|
||||
'input_boolean.b0': State('input_boolean.b0', 'off'),
|
||||
'input_boolean.b1': State('input_boolean.b1', 'off'),
|
||||
'input_boolean.b2': State('input_boolean.b2', 'off'),
|
||||
'input_boolean.b3': State('input_boolean.b3', 'off'),
|
||||
'input_boolean.b4': State(
|
||||
'input_boolean.b4', 'off', last_updated=datetime(
|
||||
1985, 10, 26, 1, 22, tzinfo=dt_util.UTC)),
|
||||
}
|
||||
|
||||
with patch('homeassistant.helpers.restore_state.Store.async_save'
|
||||
) as mock_write_data, patch.object(
|
||||
hass.states, 'async_all', return_value=states):
|
||||
await data.async_dump_states()
|
||||
|
||||
assert mock_write_data.called
|
||||
args = mock_write_data.mock_calls[0][1]
|
||||
written_states = args[0]
|
||||
|
||||
# b0 should not be written, since it didn't extend RestoreEntity
|
||||
# b1 should be written, since it is present in the current run
|
||||
# b2 should not be written, since it is not registered with the helper
|
||||
# b3 should be written, since it is still not expired
|
||||
# b4 should not be written, since it is now expired
|
||||
assert len(written_states) == 2
|
||||
assert written_states[0]['entity_id'] == 'input_boolean.b1'
|
||||
assert written_states[0]['state'] == 'on'
|
||||
assert written_states[1]['entity_id'] == 'input_boolean.b3'
|
||||
assert written_states[1]['state'] == 'off'
|
||||
|
||||
# Test that removed entities are not persisted
|
||||
await entity.async_will_remove_from_hass()
|
||||
|
||||
with patch('homeassistant.helpers.restore_state.Store.async_save'
|
||||
) as mock_write_data, patch.object(
|
||||
hass.states, 'async_all', return_value=states):
|
||||
await data.async_dump_states()
|
||||
|
||||
assert mock_write_data.called
|
||||
args = mock_write_data.mock_calls[0][1]
|
||||
written_states = args[0]
|
||||
assert len(written_states) == 1
|
||||
assert written_states[0]['entity_id'] == 'input_boolean.b3'
|
||||
assert written_states[0]['state'] == 'off'
|
||||
|
||||
|
||||
async def test_dump_error(hass):
|
||||
"""Test that we cache data."""
|
||||
states = [
|
||||
State('input_boolean.b0', 'on'),
|
||||
State('input_boolean.b1', 'on'),
|
||||
State('input_boolean.b2', 'on'),
|
||||
]
|
||||
|
||||
entity = Entity()
|
||||
entity.hass = hass
|
||||
entity.entity_id = 'input_boolean.b0'
|
||||
await entity.async_added_to_hass()
|
||||
|
||||
entity = RestoreEntity()
|
||||
entity.hass = hass
|
||||
entity.entity_id = 'input_boolean.b1'
|
||||
await entity.async_added_to_hass()
|
||||
|
||||
data = await RestoreStateData.async_get_instance(hass)
|
||||
|
||||
with patch('homeassistant.helpers.restore_state.Store.async_save',
|
||||
return_value=mock_coro(exception=HomeAssistantError)
|
||||
) as mock_write_data, patch.object(
|
||||
hass.states, 'async_all', return_value=states):
|
||||
await data.async_dump_states()
|
||||
|
||||
assert mock_write_data.called
|
||||
|
||||
|
||||
async def test_load_error(hass):
|
||||
"""Test that we cache data."""
|
||||
entity = RestoreEntity()
|
||||
entity.hass = hass
|
||||
entity.entity_id = 'input_boolean.b1'
|
||||
|
||||
with patch('homeassistant.helpers.storage.Store.async_load',
|
||||
return_value=mock_coro(exception=HomeAssistantError)):
|
||||
state = await entity.async_get_last_state()
|
||||
|
||||
assert state is None
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
"""Tests for the storage helper."""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from unittest.mock import patch
|
||||
import json
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -31,6 +32,21 @@ async def test_loading(hass, store):
|
|||
assert data == MOCK_DATA
|
||||
|
||||
|
||||
async def test_custom_encoder(hass):
|
||||
"""Test we can save and load data."""
|
||||
class JSONEncoder(json.JSONEncoder):
|
||||
"""Mock JSON encoder."""
|
||||
|
||||
def default(self, o):
|
||||
"""Mock JSON encode method."""
|
||||
return "9"
|
||||
|
||||
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY, encoder=JSONEncoder)
|
||||
await store.async_save(Mock())
|
||||
data = await store.async_load()
|
||||
assert data == "9"
|
||||
|
||||
|
||||
async def test_loading_non_existing(hass, store):
|
||||
"""Test we can save and load data."""
|
||||
with patch('homeassistant.util.json.open', side_effect=FileNotFoundError):
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
"""Test Home Assistant json utility functions."""
|
||||
from json import JSONEncoder
|
||||
import os
|
||||
import unittest
|
||||
import sys
|
||||
from tempfile import mkdtemp
|
||||
|
||||
from homeassistant.util.json import (SerializationError,
|
||||
load_json, save_json)
|
||||
from homeassistant.util.json import (
|
||||
SerializationError, load_json, save_json)
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import pytest
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
# Test data that can be saved as JSON
|
||||
TEST_JSON_A = {"a": 1, "B": "two"}
|
||||
TEST_JSON_B = {"a": "one", "B": 2}
|
||||
|
@ -74,3 +77,17 @@ class TestJSON(unittest.TestCase):
|
|||
fh.write(TEST_BAD_SERIALIED)
|
||||
with pytest.raises(HomeAssistantError):
|
||||
load_json(fname)
|
||||
|
||||
def test_custom_encoder(self):
|
||||
"""Test serializing with a custom encoder."""
|
||||
class MockJSONEncoder(JSONEncoder):
|
||||
"""Mock JSON encoder."""
|
||||
|
||||
def default(self, o):
|
||||
"""Mock JSON encode method."""
|
||||
return "9"
|
||||
|
||||
fname = self._path_for("test6")
|
||||
save_json(fname, Mock(), encoder=MockJSONEncoder)
|
||||
data = load_json(fname)
|
||||
self.assertEqual(data, "9")
|
||||
|
|
Loading…
Add table
Reference in a new issue