Restore_state helper to restore entity states from the DB on startup (#4614)

* Restore states

* feedback

* Remove component move into recorder

* space

* helper

* Address my own comments

* Improve test coverage

* Add test for light restore state
This commit is contained in:
Johann Kellerman 2017-02-21 09:40:27 +02:00 committed by Paulus Schoutsen
parent 2b9fb73032
commit fdc373f27e
18 changed files with 425 additions and 184 deletions

View file

@ -15,7 +15,6 @@ import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
HTTP_BAD_REQUEST, CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE) HTTP_BAD_REQUEST, CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE)
import homeassistant.helpers.config_validation as cv
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.components import recorder, script from homeassistant.components import recorder, script
from homeassistant.components.frontend import register_built_in_panel from homeassistant.components.frontend import register_built_in_panel
@ -28,34 +27,22 @@ DOMAIN = 'history'
DEPENDENCIES = ['recorder', 'http'] DEPENDENCIES = ['recorder', 'http']
CONFIG_SCHEMA = vol.Schema({ CONFIG_SCHEMA = vol.Schema({
DOMAIN: vol.Schema({ DOMAIN: recorder.FILTER_SCHEMA,
CONF_EXCLUDE: vol.Schema({
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_DOMAINS, default=[]):
vol.All(cv.ensure_list, [cv.string])
}),
CONF_INCLUDE: vol.Schema({
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_DOMAINS, default=[]):
vol.All(cv.ensure_list, [cv.string])
})
}),
}, extra=vol.ALLOW_EXTRA) }, extra=vol.ALLOW_EXTRA)
SIGNIFICANT_DOMAINS = ('thermostat', 'climate') SIGNIFICANT_DOMAINS = ('thermostat', 'climate')
IGNORE_DOMAINS = ('zone', 'scene',) IGNORE_DOMAINS = ('zone', 'scene',)
def last_5_states(entity_id): def last_recorder_run():
"""Return the last 5 states for entity_id.""" """Retireve the last closed recorder run from the DB."""
entity_id = entity_id.lower() rec_runs = recorder.get_model('RecorderRuns')
with recorder.session_scope() as session:
states = recorder.get_model('States') res = recorder.query(rec_runs).order_by(rec_runs.end.desc()).first()
return recorder.execute( if res is None:
recorder.query('States').filter( return None
(states.entity_id == entity_id) & session.expunge(res)
(states.last_changed == states.last_updated) return res
).order_by(states.state_id.desc()).limit(5))
def get_significant_states(start_time, end_time=None, entity_id=None, def get_significant_states(start_time, end_time=None, entity_id=None,
@ -91,7 +78,7 @@ def get_significant_states(start_time, end_time=None, entity_id=None,
def state_changes_during_period(start_time, end_time=None, entity_id=None): def state_changes_during_period(start_time, end_time=None, entity_id=None):
"""Return states changes during UTC period start_time - end_time.""" """Return states changes during UTC period start_time - end_time."""
states = recorder.get_model('States') states = recorder.get_model('States')
query = recorder.query('States').filter( query = recorder.query(states).filter(
(states.last_changed == states.last_updated) & (states.last_changed == states.last_updated) &
(states.last_changed > start_time)) (states.last_changed > start_time))
@ -132,7 +119,7 @@ def get_states(utc_point_in_time, entity_ids=None, run=None, filters=None):
most_recent_state_ids = most_recent_state_ids.group_by( most_recent_state_ids = most_recent_state_ids.group_by(
states.entity_id).subquery() states.entity_id).subquery()
query = recorder.query('States').join(most_recent_state_ids, and_( query = recorder.query(states).join(most_recent_state_ids, and_(
states.state_id == most_recent_state_ids.c.max_state_id)) states.state_id == most_recent_state_ids.c.max_state_id))
for state in recorder.execute(query): for state in recorder.execute(query):
@ -185,27 +172,13 @@ def setup(hass, config):
filters.included_entities = include[CONF_ENTITIES] filters.included_entities = include[CONF_ENTITIES]
filters.included_domains = include[CONF_DOMAINS] filters.included_domains = include[CONF_DOMAINS]
hass.http.register_view(Last5StatesView) recorder.get_instance()
hass.http.register_view(HistoryPeriodView(filters)) hass.http.register_view(HistoryPeriodView(filters))
register_built_in_panel(hass, 'history', 'History', 'mdi:poll-box') register_built_in_panel(hass, 'history', 'History', 'mdi:poll-box')
return True return True
class Last5StatesView(HomeAssistantView):
"""Handle last 5 state view requests."""
url = '/api/history/entity/{entity_id}/recent_states'
name = 'api:history:entity-recent-states'
@asyncio.coroutine
def get(self, request, entity_id):
"""Retrieve last 5 states of entity."""
result = yield from request.app['hass'].loop.run_in_executor(
None, last_5_states, entity_id)
return self.json(result)
class HistoryPeriodView(HomeAssistantView): class HistoryPeriodView(HomeAssistantView):
"""Handle history period requests.""" """Handle history period requests."""

View file

@ -15,6 +15,7 @@ from homeassistant.const import (
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import async_get_last_state
DOMAIN = 'input_boolean' DOMAIN = 'input_boolean'
@ -139,6 +140,14 @@ class InputBoolean(ToggleEntity):
"""Return true if entity is on.""" """Return true if entity is on."""
return self._state return self._state
@asyncio.coroutine
def async_added_to_hass(self):
"""Called when entity about to be added to hass."""
state = yield from async_get_last_state(self.hass, self.entity_id)
if not state:
return
self._state = state.state == 'on'
@asyncio.coroutine @asyncio.coroutine
def async_turn_on(self, **kwargs): def async_turn_on(self, **kwargs):
"""Turn the entity on.""" """Turn the entity on."""

View file

@ -22,6 +22,7 @@ from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.restore_state import async_restore_state
import homeassistant.util.color as color_util import homeassistant.util.color as color_util
from homeassistant.util.async import run_callback_threadsafe from homeassistant.util.async import run_callback_threadsafe
@ -126,6 +127,14 @@ PROFILE_SCHEMA = vol.Schema(
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def extract_info(state):
"""Extract light parameters from a state object."""
params = {key: state.attributes[key] for key in PROP_TO_ATTR
if key in state.attributes}
params['is_on'] = state.state == STATE_ON
return params
def is_on(hass, entity_id=None): def is_on(hass, entity_id=None):
"""Return if the lights are on based on the statemachine.""" """Return if the lights are on based on the statemachine."""
entity_id = entity_id or ENTITY_ID_ALL_LIGHTS entity_id = entity_id or ENTITY_ID_ALL_LIGHTS
@ -369,3 +378,9 @@ class Light(ToggleEntity):
def supported_features(self): def supported_features(self):
"""Flag supported features.""" """Flag supported features."""
return 0 return 0
@asyncio.coroutine
def async_added_to_hass(self):
"""Component added, restore_state using platforms."""
if hasattr(self, 'async_restore_state'):
yield from async_restore_state(self, extract_info)

View file

@ -4,6 +4,7 @@ Demo light platform that implements lights.
For more details about this platform, please refer to the documentation For more details about this platform, please refer to the documentation
https://home-assistant.io/components/demo/ https://home-assistant.io/components/demo/
""" """
import asyncio
import random import random
from homeassistant.components.light import ( from homeassistant.components.light import (
@ -149,3 +150,26 @@ class DemoLight(Light):
# As we have disabled polling, we need to inform # As we have disabled polling, we need to inform
# Home Assistant about updates in our state ourselves. # Home Assistant about updates in our state ourselves.
self.schedule_update_ha_state() self.schedule_update_ha_state()
@asyncio.coroutine
def async_restore_state(self, is_on, **kwargs):
"""Restore the demo state."""
self._state = is_on
if 'brightness' in kwargs:
self._brightness = kwargs['brightness']
if 'color_temp' in kwargs:
self._ct = kwargs['color_temp']
if 'rgb_color' in kwargs:
self._rgb = kwargs['rgb_color']
if 'xy_color' in kwargs:
self._xy_color = kwargs['xy_color']
if 'white_value' in kwargs:
self._white = kwargs['white_value']
if 'effect' in kwargs:
self._effect = kwargs['effect']

View file

@ -22,6 +22,7 @@ from homeassistant.const import (
ATTR_ENTITY_ID, CONF_ENTITIES, CONF_EXCLUDE, CONF_DOMAINS, ATTR_ENTITY_ID, CONF_ENTITIES, CONF_EXCLUDE, CONF_DOMAINS,
CONF_INCLUDE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, CONF_INCLUDE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL) EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL)
from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import ConfigType, QueryType from homeassistant.helpers.typing import ConfigType, QueryType
@ -42,36 +43,35 @@ CONNECT_RETRY_WAIT = 10
QUERY_RETRY_WAIT = 0.1 QUERY_RETRY_WAIT = 0.1
ERROR_QUERY = "Error during query: %s" ERROR_QUERY = "Error during query: %s"
FILTER_SCHEMA = vol.Schema({
vol.Optional(CONF_EXCLUDE, default={}): vol.Schema({
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_DOMAINS, default=[]):
vol.All(cv.ensure_list, [cv.string])
}),
vol.Optional(CONF_INCLUDE, default={}): vol.Schema({
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_DOMAINS, default=[]):
vol.All(cv.ensure_list, [cv.string])
})
})
CONFIG_SCHEMA = vol.Schema({ CONFIG_SCHEMA = vol.Schema({
DOMAIN: vol.Schema({ DOMAIN: FILTER_SCHEMA.extend({
vol.Optional(CONF_PURGE_DAYS): vol.Optional(CONF_PURGE_DAYS):
vol.All(vol.Coerce(int), vol.Range(min=1)), vol.All(vol.Coerce(int), vol.Range(min=1)),
vol.Optional(CONF_DB_URL): cv.string, vol.Optional(CONF_DB_URL): cv.string,
vol.Optional(CONF_EXCLUDE, default={}): vol.Schema({
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_DOMAINS, default=[]):
vol.All(cv.ensure_list, [cv.string])
}),
vol.Optional(CONF_INCLUDE, default={}): vol.Schema({
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_DOMAINS, default=[]):
vol.All(cv.ensure_list, [cv.string])
})
}) })
}, extra=vol.ALLOW_EXTRA) }, extra=vol.ALLOW_EXTRA)
_INSTANCE = None # type: Any _INSTANCE = None # type: Any
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# These classes will be populated during setup()
# scoped_session, in the same thread session_scope() stays the same
_SESSION = None
@contextmanager @contextmanager
def session_scope(): def session_scope():
"""Provide a transactional scope around a series of operations.""" """Provide a transactional scope around a series of operations."""
session = _SESSION() session = _INSTANCE.get_session()
try: try:
yield session yield session
session.commit() session.commit()
@ -83,15 +83,28 @@ def session_scope():
session.close() session.close()
def get_instance() -> None:
"""Throw error if recorder not initialized."""
if _INSTANCE is None:
raise RuntimeError("Recorder not initialized.")
ident = _INSTANCE.hass.loop.__dict__.get("_thread_ident")
if ident is not None and ident == threading.get_ident():
raise RuntimeError('Cannot be called from within the event loop')
_wait(_INSTANCE.db_ready, "Database not ready")
return _INSTANCE
# pylint: disable=invalid-sequence-index # pylint: disable=invalid-sequence-index
def execute(qry: QueryType) -> List[Any]: def execute(qry: QueryType) -> List[Any]:
"""Query the database and convert the objects to HA native form. """Query the database and convert the objects to HA native form.
This method also retries a few times in the case of stale connections. This method also retries a few times in the case of stale connections.
""" """
_verify_instance() get_instance()
from sqlalchemy.exc import SQLAlchemyError
import sqlalchemy.exc
with session_scope() as session: with session_scope() as session:
for _ in range(0, RETRIES): for _ in range(0, RETRIES):
try: try:
@ -99,7 +112,7 @@ def execute(qry: QueryType) -> List[Any]:
row for row in row for row in
(row.to_native() for row in qry) (row.to_native() for row in qry)
if row is not None] if row is not None]
except sqlalchemy.exc.SQLAlchemyError as err: except SQLAlchemyError as err:
_LOGGER.error(ERROR_QUERY, err) _LOGGER.error(ERROR_QUERY, err)
session.rollback() session.rollback()
time.sleep(QUERY_RETRY_WAIT) time.sleep(QUERY_RETRY_WAIT)
@ -111,13 +124,13 @@ def run_information(point_in_time: Optional[datetime]=None):
There is also the run that covers point_in_time. There is also the run that covers point_in_time.
""" """
_verify_instance() ins = get_instance()
recorder_runs = get_model('RecorderRuns') recorder_runs = get_model('RecorderRuns')
if point_in_time is None or point_in_time > _INSTANCE.recording_start: if point_in_time is None or point_in_time > ins.recording_start:
return recorder_runs( return recorder_runs(
end=None, end=None,
start=_INSTANCE.recording_start, start=ins.recording_start,
closed_incorrect=False) closed_incorrect=False)
with session_scope() as session: with session_scope() as session:
@ -148,17 +161,19 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool:
exclude = config.get(DOMAIN, {}).get(CONF_EXCLUDE, {}) exclude = config.get(DOMAIN, {}).get(CONF_EXCLUDE, {})
_INSTANCE = Recorder(hass, purge_days=purge_days, uri=db_url, _INSTANCE = Recorder(hass, purge_days=purge_days, uri=db_url,
include=include, exclude=exclude) include=include, exclude=exclude)
_INSTANCE.start()
return True return True
def query(model_name: Union[str, Any], *args) -> QueryType: def query(model_name: Union[str, Any], session=None, *args) -> QueryType:
"""Helper to return a query handle.""" """Helper to return a query handle."""
_verify_instance() if session is None:
session = get_instance().get_session()
if isinstance(model_name, str): if isinstance(model_name, str):
return _SESSION().query(get_model(model_name), *args) return session.query(get_model(model_name), *args)
return _SESSION().query(model_name, *args) return session.query(model_name, *args)
def get_model(model_name: str) -> Any: def get_model(model_name: str) -> Any:
@ -185,6 +200,7 @@ class Recorder(threading.Thread):
self.recording_start = dt_util.utcnow() self.recording_start = dt_util.utcnow()
self.db_url = uri self.db_url = uri
self.db_ready = threading.Event() self.db_ready = threading.Event()
self.start_recording = threading.Event()
self.engine = None # type: Any self.engine = None # type: Any
self._run = None # type: Any self._run = None # type: Any
@ -195,23 +211,26 @@ class Recorder(threading.Thread):
def start_recording(event): def start_recording(event):
"""Start recording.""" """Start recording."""
self.start() self.start_recording.set()
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_recording) hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_recording)
hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, self.shutdown) hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, self.shutdown)
hass.bus.listen(MATCH_ALL, self.event_listener) hass.bus.listen(MATCH_ALL, self.event_listener)
self.get_session = None
def run(self): def run(self):
"""Start processing events to save.""" """Start processing events to save."""
from homeassistant.components.recorder.models import Events, States from homeassistant.components.recorder.models import Events, States
import sqlalchemy.exc from sqlalchemy.exc import SQLAlchemyError
while True: while True:
try: try:
self._setup_connection() self._setup_connection()
self._setup_run() self._setup_run()
self.db_ready.set()
break break
except sqlalchemy.exc.SQLAlchemyError as err: except SQLAlchemyError as err:
_LOGGER.error("Error during connection setup: %s (retrying " _LOGGER.error("Error during connection setup: %s (retrying "
"in %s seconds)", err, CONNECT_RETRY_WAIT) "in %s seconds)", err, CONNECT_RETRY_WAIT)
time.sleep(CONNECT_RETRY_WAIT) time.sleep(CONNECT_RETRY_WAIT)
@ -220,6 +239,8 @@ class Recorder(threading.Thread):
async_track_time_interval( async_track_time_interval(
self.hass, self._purge_old_data, timedelta(days=2)) self.hass, self._purge_old_data, timedelta(days=2))
_wait(self.start_recording, "Waiting to start recording")
while True: while True:
event = self.queue.get() event = self.queue.get()
@ -275,10 +296,9 @@ class Recorder(threading.Thread):
def shutdown(self, event): def shutdown(self, event):
"""Tell the recorder to shut down.""" """Tell the recorder to shut down."""
global _INSTANCE # pylint: disable=global-statement global _INSTANCE # pylint: disable=global-statement
_INSTANCE = None
self.queue.put(None) self.queue.put(None)
self.join() self.join()
_INSTANCE = None
def block_till_done(self): def block_till_done(self):
"""Block till all events processed.""" """Block till all events processed."""
@ -286,15 +306,10 @@ class Recorder(threading.Thread):
def block_till_db_ready(self): def block_till_db_ready(self):
"""Block until the database session is ready.""" """Block until the database session is ready."""
self.db_ready.wait(10) _wait(self.db_ready, "Database not ready")
while not self.db_ready.is_set():
_LOGGER.warning('Database not ready, waiting another 10 seconds.')
self.db_ready.wait(10)
def _setup_connection(self): def _setup_connection(self):
"""Ensure database is ready to fly.""" """Ensure database is ready to fly."""
global _SESSION # pylint: disable=invalid-name,global-statement
import homeassistant.components.recorder.models as models import homeassistant.components.recorder.models as models
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session from sqlalchemy.orm import scoped_session
@ -312,9 +327,8 @@ class Recorder(threading.Thread):
models.Base.metadata.create_all(self.engine) models.Base.metadata.create_all(self.engine)
session_factory = sessionmaker(bind=self.engine) session_factory = sessionmaker(bind=self.engine)
_SESSION = scoped_session(session_factory) self.get_session = scoped_session(session_factory)
self._migrate_schema() self._migrate_schema()
self.db_ready.set()
def _migrate_schema(self): def _migrate_schema(self):
"""Check if the schema needs to be upgraded.""" """Check if the schema needs to be upgraded."""
@ -396,16 +410,16 @@ class Recorder(threading.Thread):
def _close_connection(self): def _close_connection(self):
"""Close the connection.""" """Close the connection."""
global _SESSION # pylint: disable=invalid-name,global-statement
self.engine.dispose() self.engine.dispose()
self.engine = None self.engine = None
_SESSION = None self.get_session = None
def _setup_run(self): def _setup_run(self):
"""Log the start of the current run.""" """Log the start of the current run."""
recorder_runs = get_model('RecorderRuns') recorder_runs = get_model('RecorderRuns')
with session_scope() as session: with session_scope() as session:
for run in query('RecorderRuns').filter_by(end=None): for run in query(
recorder_runs, session=session).filter_by(end=None):
run.closed_incorrect = True run.closed_incorrect = True
run.end = self.recording_start run.end = self.recording_start
_LOGGER.warning("Ended unfinished session (id=%s from %s)", _LOGGER.warning("Ended unfinished session (id=%s from %s)",
@ -482,13 +496,13 @@ class Recorder(threading.Thread):
return False return False
def _verify_instance() -> None: def _wait(event, message):
"""Throw error if recorder not initialized.""" """Event wait helper."""
if _INSTANCE is None: for retry in (10, 20, 30):
raise RuntimeError("Recorder not initialized.") event.wait(10)
if event.is_set():
ident = _INSTANCE.hass.loop.__dict__.get("_thread_ident") return
if ident is not None and ident == threading.get_ident(): msg = message + " ({} seconds)".format(retry)
raise RuntimeError('Cannot be called from within the event loop') _LOGGER.warning(msg)
if not event.is_set():
_INSTANCE.block_till_db_ready() raise HomeAssistantError(msg)

View file

@ -199,7 +199,7 @@ class HistoryStatsSensor(Entity):
if self._start is not None: if self._start is not None:
try: try:
start_rendered = self._start.render() start_rendered = self._start.render()
except TemplateError as ex: except (TemplateError, TypeError) as ex:
HistoryStatsHelper.handle_template_exception(ex, 'start') HistoryStatsHelper.handle_template_exception(ex, 'start')
return return
start = dt_util.parse_datetime(start_rendered) start = dt_util.parse_datetime(start_rendered)
@ -216,7 +216,7 @@ class HistoryStatsSensor(Entity):
if self._end is not None: if self._end is not None:
try: try:
end_rendered = self._end.render() end_rendered = self._end.render()
except TemplateError as ex: except (TemplateError, TypeError) as ex:
HistoryStatsHelper.handle_template_exception(ex, 'end') HistoryStatsHelper.handle_template_exception(ex, 'end')
return return
end = dt_util.parse_datetime(end_rendered) end = dt_util.parse_datetime(end_rendered)

View file

@ -288,7 +288,7 @@ class Entity(object):
self.hass.add_job(self.async_update_ha_state(force_refresh)) self.hass.add_job(self.async_update_ha_state(force_refresh))
def remove(self) -> None: def remove(self) -> None:
"""Remove entitiy from HASS.""" """Remove entity from HASS."""
run_coroutine_threadsafe( run_coroutine_threadsafe(
self.async_remove(), self.hass.loop self.async_remove(), self.hass.loop
).result() ).result()

View file

@ -202,6 +202,10 @@ class EntityComponent(object):
'Invalid entity id: {}'.format(entity.entity_id)) 'Invalid entity id: {}'.format(entity.entity_id))
self.entities[entity.entity_id] = entity self.entities[entity.entity_id] = entity
if hasattr(entity, 'async_added_to_hass'):
yield from entity.async_added_to_hass()
yield from entity.async_update_ha_state() yield from entity.async_update_ha_state()
return True return True

View file

@ -0,0 +1,82 @@
"""Support for restoring entity states on startup."""
import asyncio
import logging
from datetime import timedelta
from homeassistant.core import HomeAssistant, CoreState, callback
from homeassistant.const import EVENT_HOMEASSISTANT_START
from homeassistant.components.history import get_states, last_recorder_run
from homeassistant.components.recorder import DOMAIN as _RECORDER
import homeassistant.util.dt as dt_util
_LOGGER = logging.getLogger(__name__)
DATA_RESTORE_CACHE = 'restore_state_cache'
_LOCK = 'restore_lock'
def _load_restore_cache(hass: HomeAssistant):
"""Load the restore cache to be used by other components."""
@callback
def remove_cache(event):
"""Remove the states cache."""
hass.data.pop(DATA_RESTORE_CACHE, None)
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, remove_cache)
last_run = last_recorder_run()
if last_run is None or last_run.end is None:
_LOGGER.debug('Not creating cache - no suitable last run found: %s',
last_run)
hass.data[DATA_RESTORE_CACHE] = {}
return
last_end_time = last_run.end - timedelta(seconds=1)
# Unfortunately the recorder_run model do not return offset-aware time
last_end_time = last_end_time.replace(tzinfo=dt_util.UTC)
_LOGGER.debug("Last run: %s - %s", last_run.start, last_end_time)
states = get_states(last_end_time, run=last_run)
# Cache the states
hass.data[DATA_RESTORE_CACHE] = {
state.entity_id: state for state in states}
_LOGGER.debug('Created cache with %s', list(hass.data[DATA_RESTORE_CACHE]))
@asyncio.coroutine
def async_get_last_state(hass, entity_id: str):
"""Helper to restore state."""
if (_RECORDER not in hass.config.components or
hass.state != CoreState.starting):
return None
if DATA_RESTORE_CACHE in hass.data:
return hass.data[DATA_RESTORE_CACHE].get(entity_id)
if _LOCK not in hass.data:
hass.data[_LOCK] = asyncio.Lock(loop=hass.loop)
with (yield from hass.data[_LOCK]):
if DATA_RESTORE_CACHE not in hass.data:
yield from hass.loop.run_in_executor(
None, _load_restore_cache, hass)
return hass.data[DATA_RESTORE_CACHE].get(entity_id)
@asyncio.coroutine
def async_restore_state(entity, extract_info):
"""Helper to call entity.async_restore_state with cached info."""
if entity.hass.state != CoreState.starting:
_LOGGER.debug("Not restoring state: State is not starting: %s",
entity.hass.state)
return
state = yield from async_get_last_state(entity.hass, entity.entity_id)
if not state:
return
yield from entity.async_restore_state(**extract_info(state))

View file

@ -197,8 +197,8 @@ def load_order_components(components: Sequence[str]) -> OrderedSet:
load_order.update(comp_load_order) load_order.update(comp_load_order)
# Push some to first place in load order # Push some to first place in load order
for comp in ('mqtt_eventstream', 'mqtt', 'logger', for comp in ('mqtt_eventstream', 'mqtt', 'recorder',
'recorder', 'introduction'): 'introduction', 'logger'):
if comp in load_order: if comp in load_order:
load_order.promote(comp) load_order.promote(comp)

View file

@ -22,7 +22,7 @@ from homeassistant.const import (
STATE_ON, STATE_OFF, DEVICE_DEFAULT_NAME, EVENT_TIME_CHANGED, STATE_ON, STATE_OFF, DEVICE_DEFAULT_NAME, EVENT_TIME_CHANGED,
EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE, EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE,
ATTR_DISCOVERED, SERVER_PORT) ATTR_DISCOVERED, SERVER_PORT)
from homeassistant.components import sun, mqtt from homeassistant.components import sun, mqtt, recorder
from homeassistant.components.http.auth import auth_middleware from homeassistant.components.http.auth import auth_middleware
from homeassistant.components.http.const import ( from homeassistant.components.http.const import (
KEY_USE_X_FORWARDED_FOR, KEY_BANS_ENABLED, KEY_TRUSTED_NETWORKS) KEY_USE_X_FORWARDED_FOR, KEY_BANS_ENABLED, KEY_TRUSTED_NETWORKS)
@ -452,3 +452,31 @@ def assert_setup_component(count, domain=None):
res_len = 0 if res is None else len(res) res_len = 0 if res is None else len(res)
assert res_len == count, 'setup_component failed, expected {} got {}: {}' \ assert res_len == count, 'setup_component failed, expected {} got {}: {}' \
.format(count, res_len, res) .format(count, res_len, res)
def init_recorder_component(hass, add_config=None, db_ready_callback=None):
"""Initialize the recorder."""
config = dict(add_config) if add_config else {}
config[recorder.CONF_DB_URL] = 'sqlite://' # In memory DB
saved_recorder = recorder.Recorder
class Recorder2(saved_recorder):
"""Recorder with a callback after db_ready."""
def _setup_connection(self):
"""Setup the connection and run the callback."""
super(Recorder2, self)._setup_connection()
if db_ready_callback:
_LOGGER.debug('db_ready_callback start (db_ready not set,'
'never use get_instance in the callback)')
db_ready_callback()
_LOGGER.debug('db_ready_callback completed')
with patch('homeassistant.components.recorder.Recorder',
side_effect=Recorder2):
assert setup_component(hass, recorder.DOMAIN,
{recorder.DOMAIN: config})
assert recorder.DOMAIN in hass.config.components
recorder.get_instance().block_till_db_ready()
_LOGGER.info("In-memory recorder successfully started")

View file

@ -1,17 +1,20 @@
"""The tests for the demo light component.""" """The tests for the demo light component."""
# pylint: disable=protected-access # pylint: disable=protected-access
import asyncio
import unittest import unittest
from homeassistant.bootstrap import setup_component from homeassistant.core import State, CoreState
from homeassistant.bootstrap import setup_component, async_setup_component
import homeassistant.components.light as light import homeassistant.components.light as light
from homeassistant.helpers.restore_state import DATA_RESTORE_CACHE
from tests.common import get_test_home_assistant from tests.common import get_test_home_assistant
ENTITY_LIGHT = 'light.bed_light' ENTITY_LIGHT = 'light.bed_light'
class TestDemoClimate(unittest.TestCase): class TestDemoLight(unittest.TestCase):
"""Test the demo climate hvac.""" """Test the demo light."""
# pylint: disable=invalid-name # pylint: disable=invalid-name
def setUp(self): def setUp(self):
@ -60,3 +63,36 @@ class TestDemoClimate(unittest.TestCase):
light.turn_off(self.hass, ENTITY_LIGHT) light.turn_off(self.hass, ENTITY_LIGHT)
self.hass.block_till_done() self.hass.block_till_done()
self.assertFalse(light.is_on(self.hass, ENTITY_LIGHT)) self.assertFalse(light.is_on(self.hass, ENTITY_LIGHT))
@asyncio.coroutine
def test_restore_state(hass):
"""Test state gets restored."""
hass.config.components.add('recorder')
hass.state = CoreState.starting
hass.data[DATA_RESTORE_CACHE] = {
'light.bed_light': State('light.bed_light', 'on', {
'brightness': 'value-brightness',
'color_temp': 'value-color_temp',
'rgb_color': 'value-rgb_color',
'xy_color': 'value-xy_color',
'white_value': 'value-white_value',
'effect': 'value-effect',
}),
}
yield from async_setup_component(hass, 'light', {
'light': {
'platform': 'demo',
}})
state = hass.states.get('light.bed_light')
assert state is not None
assert state.entity_id == 'light.bed_light'
assert state.state == 'on'
assert state.attributes.get('brightness') == 'value-brightness'
assert state.attributes.get('color_temp') == 'value-color_temp'
assert state.attributes.get('rgb_color') == 'value-rgb_color'
assert state.attributes.get('xy_color') == 'value-xy_color'
assert state.attributes.get('white_value') == 'value-white_value'
assert state.attributes.get('effect') == 'value-effect'

View file

@ -11,8 +11,7 @@ from sqlalchemy import create_engine
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.const import MATCH_ALL from homeassistant.const import MATCH_ALL
from homeassistant.components import recorder from homeassistant.components import recorder
from homeassistant.bootstrap import setup_component from tests.common import get_test_home_assistant, init_recorder_component
from tests.common import get_test_home_assistant
from tests.components.recorder import models_original from tests.components.recorder import models_original
@ -22,18 +21,15 @@ class BaseTestRecorder(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant() self.hass = get_test_home_assistant()
db_uri = 'sqlite://' # In memory DB init_recorder_component(self.hass)
setup_component(self.hass, recorder.DOMAIN, {
recorder.DOMAIN: {recorder.CONF_DB_URL: db_uri}})
self.hass.start() self.hass.start()
recorder._verify_instance() recorder.get_instance().block_till_done()
recorder._INSTANCE.block_till_done()
def tearDown(self): # pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name
"""Stop everything that was started.""" """Stop everything that was started."""
recorder._INSTANCE.shutdown(None)
self.hass.stop() self.hass.stop()
assert recorder._INSTANCE is None with self.assertRaises(RuntimeError):
recorder.get_instance()
def _add_test_states(self): def _add_test_states(self):
"""Add multiple states to the db for testing.""" """Add multiple states to the db for testing."""
@ -228,7 +224,7 @@ class TestMigrateRecorder(BaseTestRecorder):
@patch('sqlalchemy.create_engine', new=create_engine_test) @patch('sqlalchemy.create_engine', new=create_engine_test)
@patch('homeassistant.components.recorder.Recorder._migrate_schema') @patch('homeassistant.components.recorder.Recorder._migrate_schema')
def setUp(self, migrate): # pylint: disable=invalid-name def setUp(self, migrate): # pylint: disable=invalid-name,arguments-differ
"""Setup things to be run when tests are started. """Setup things to be run when tests are started.
create_engine is patched to create a db that starts with the old create_engine is patched to create a db that starts with the old
@ -261,16 +257,12 @@ def hass_recorder():
"""HASS fixture with in-memory recorder.""" """HASS fixture with in-memory recorder."""
hass = get_test_home_assistant() hass = get_test_home_assistant()
def setup_recorder(config={}): def setup_recorder(config=None):
"""Setup with params.""" """Setup with params."""
db_uri = 'sqlite://' # In memory DB init_recorder_component(hass, config)
conf = {recorder.CONF_DB_URL: db_uri}
conf.update(config)
assert setup_component(hass, recorder.DOMAIN, {recorder.DOMAIN: conf})
hass.start() hass.start()
hass.block_till_done() hass.block_till_done()
recorder._verify_instance() recorder.get_instance().block_till_done()
recorder._INSTANCE.block_till_done()
return hass return hass
yield setup_recorder yield setup_recorder
@ -352,12 +344,12 @@ def test_recorder_errors_exceptions(hass_recorder): \
# Verify the instance fails before setup # Verify the instance fails before setup
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
recorder._verify_instance() recorder.get_instance()
# Setup the recorder # Setup the recorder
hass_recorder() hass_recorder()
recorder._verify_instance() recorder.get_instance()
# Verify session scope raises (and prints) an exception # Verify session scope raises (and prints) an exception
with patch('homeassistant.components.recorder._LOGGER.error') as e_mock, \ with patch('homeassistant.components.recorder._LOGGER.error') as e_mock, \

View file

@ -1,16 +1,17 @@
"""The test for the History Statistics sensor platform.""" """The test for the History Statistics sensor platform."""
# pylint: disable=protected-access # pylint: disable=protected-access
import unittest
from datetime import timedelta from datetime import timedelta
import unittest
from unittest.mock import patch from unittest.mock import patch
import homeassistant.components.recorder as recorder
import homeassistant.core as ha
import homeassistant.util.dt as dt_util
from homeassistant.bootstrap import setup_component from homeassistant.bootstrap import setup_component
import homeassistant.components.recorder as recorder
from homeassistant.components.sensor.history_stats import HistoryStatsSensor from homeassistant.components.sensor.history_stats import HistoryStatsSensor
import homeassistant.core as ha
from homeassistant.helpers.template import Template from homeassistant.helpers.template import Template
from tests.common import get_test_home_assistant import homeassistant.util.dt as dt_util
from tests.common import init_recorder_component, get_test_home_assistant
class TestHistoryStatsSensor(unittest.TestCase): class TestHistoryStatsSensor(unittest.TestCase):
@ -204,12 +205,8 @@ class TestHistoryStatsSensor(unittest.TestCase):
def init_recorder(self): def init_recorder(self):
"""Initialize the recorder.""" """Initialize the recorder."""
db_uri = 'sqlite://' init_recorder_component(self.hass)
with patch('homeassistant.core.Config.path', return_value=db_uri):
setup_component(self.hass, recorder.DOMAIN, {
"recorder": {
"db_url": db_uri}})
self.hass.start() self.hass.start()
recorder._INSTANCE.block_till_db_ready() recorder.get_instance().block_till_db_ready()
self.hass.block_till_done() self.hass.block_till_done()
recorder._INSTANCE.block_till_done() recorder.get_instance().block_till_done()

View file

@ -1,5 +1,5 @@
"""The tests the History component.""" """The tests the History component."""
# pylint: disable=protected-access # pylint: disable=protected-access,invalid-name
from datetime import timedelta from datetime import timedelta
import unittest import unittest
from unittest.mock import patch, sentinel from unittest.mock import patch, sentinel
@ -10,68 +10,47 @@ import homeassistant.util.dt as dt_util
from homeassistant.components import history, recorder from homeassistant.components import history, recorder
from tests.common import ( from tests.common import (
mock_http_component, mock_state_change_event, get_test_home_assistant) init_recorder_component, mock_http_component, mock_state_change_event,
get_test_home_assistant)
class TestComponentHistory(unittest.TestCase): class TestComponentHistory(unittest.TestCase):
"""Test History component.""" """Test History component."""
# pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
def setUp(self):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant() self.hass = get_test_home_assistant()
# pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name
def tearDown(self):
"""Stop everything that was started.""" """Stop everything that was started."""
self.hass.stop() self.hass.stop()
def init_recorder(self): def init_recorder(self):
"""Initialize the recorder.""" """Initialize the recorder."""
db_uri = 'sqlite://' init_recorder_component(self.hass)
with patch('homeassistant.core.Config.path', return_value=db_uri):
setup_component(self.hass, recorder.DOMAIN, {
"recorder": {
"db_url": db_uri}})
self.hass.start() self.hass.start()
recorder._INSTANCE.block_till_db_ready() recorder.get_instance().block_till_db_ready()
self.wait_recording_done() self.wait_recording_done()
def wait_recording_done(self): def wait_recording_done(self):
"""Block till recording is done.""" """Block till recording is done."""
self.hass.block_till_done() self.hass.block_till_done()
recorder._INSTANCE.block_till_done() recorder.get_instance().block_till_done()
def test_setup(self): def test_setup(self):
"""Test setup method of history.""" """Test setup method of history."""
mock_http_component(self.hass) mock_http_component(self.hass)
config = history.CONFIG_SCHEMA({ config = history.CONFIG_SCHEMA({
ha.DOMAIN: {}, # ha.DOMAIN: {},
history.DOMAIN: {history.CONF_INCLUDE: { history.DOMAIN: {
history.CONF_INCLUDE: {
history.CONF_DOMAINS: ['media_player'], history.CONF_DOMAINS: ['media_player'],
history.CONF_ENTITIES: ['thermostat.test']}, history.CONF_ENTITIES: ['thermostat.test']},
history.CONF_EXCLUDE: { history.CONF_EXCLUDE: {
history.CONF_DOMAINS: ['thermostat'], history.CONF_DOMAINS: ['thermostat'],
history.CONF_ENTITIES: ['media_player.test']}}}) history.CONF_ENTITIES: ['media_player.test']}}})
self.assertTrue(setup_component(self.hass, history.DOMAIN, config))
def test_last_5_states(self):
"""Test retrieving the last 5 states."""
self.init_recorder() self.init_recorder()
states = [] self.assertTrue(setup_component(self.hass, history.DOMAIN, config))
entity_id = 'test.last_5_states'
for i in range(7):
self.hass.states.set(entity_id, "State {}".format(i))
self.wait_recording_done()
if i > 1:
states.append(self.hass.states.get(entity_id))
self.assertEqual(
list(reversed(states)), history.last_5_states(entity_id))
def test_get_states(self): def test_get_states(self):
"""Test getting states at a specific point in time.""" """Test getting states at a specific point in time."""
@ -121,6 +100,7 @@ class TestComponentHistory(unittest.TestCase):
entity_id = 'media_player.test' entity_id = 'media_player.test'
def set_state(state): def set_state(state):
"""Set the state."""
self.hass.states.set(entity_id, state) self.hass.states.set(entity_id, state)
self.wait_recording_done() self.wait_recording_done()
return self.hass.states.get(entity_id) return self.hass.states.get(entity_id)
@ -311,7 +291,8 @@ class TestComponentHistory(unittest.TestCase):
config = history.CONFIG_SCHEMA({ config = history.CONFIG_SCHEMA({
ha.DOMAIN: {}, ha.DOMAIN: {},
history.DOMAIN: {history.CONF_INCLUDE: { history.DOMAIN: {
history.CONF_INCLUDE: {
history.CONF_DOMAINS: ['media_player']}, history.CONF_DOMAINS: ['media_player']},
history.CONF_EXCLUDE: { history.CONF_EXCLUDE: {
history.CONF_DOMAINS: ['media_player']}}}) history.CONF_DOMAINS: ['media_player']}}})
@ -332,7 +313,8 @@ class TestComponentHistory(unittest.TestCase):
config = history.CONFIG_SCHEMA({ config = history.CONFIG_SCHEMA({
ha.DOMAIN: {}, ha.DOMAIN: {},
history.DOMAIN: {history.CONF_INCLUDE: { history.DOMAIN: {
history.CONF_INCLUDE: {
history.CONF_ENTITIES: ['media_player.test']}, history.CONF_ENTITIES: ['media_player.test']},
history.CONF_EXCLUDE: { history.CONF_EXCLUDE: {
history.CONF_ENTITIES: ['media_player.test']}}}) history.CONF_ENTITIES: ['media_player.test']}}})
@ -351,7 +333,8 @@ class TestComponentHistory(unittest.TestCase):
config = history.CONFIG_SCHEMA({ config = history.CONFIG_SCHEMA({
ha.DOMAIN: {}, ha.DOMAIN: {},
history.DOMAIN: {history.CONF_INCLUDE: { history.DOMAIN: {
history.CONF_INCLUDE: {
history.CONF_DOMAINS: ['media_player'], history.CONF_DOMAINS: ['media_player'],
history.CONF_ENTITIES: ['thermostat.test']}, history.CONF_ENTITIES: ['thermostat.test']},
history.CONF_EXCLUDE: { history.CONF_EXCLUDE: {
@ -359,7 +342,8 @@ class TestComponentHistory(unittest.TestCase):
history.CONF_ENTITIES: ['media_player.test']}}}) history.CONF_ENTITIES: ['media_player.test']}}})
self.check_significant_states(zero, four, states, config) self.check_significant_states(zero, four, states, config)
def check_significant_states(self, zero, four, states, config): def check_significant_states(self, zero, four, states, config): \
# pylint: disable=no-self-use
"""Check if significant states are retrieved.""" """Check if significant states are retrieved."""
filters = history.Filters() filters = history.Filters()
exclude = config[history.DOMAIN].get(history.CONF_EXCLUDE) exclude = config[history.DOMAIN].get(history.CONF_EXCLUDE)
@ -390,6 +374,7 @@ class TestComponentHistory(unittest.TestCase):
script_c = 'script.can_cancel_this_one' script_c = 'script.can_cancel_this_one'
def set_state(entity_id, state, **kwargs): def set_state(entity_id, state, **kwargs):
"""Set the state."""
self.hass.states.set(entity_id, state, **kwargs) self.hass.states.set(entity_id, state, **kwargs)
self.wait_recording_done() self.wait_recording_done()
return self.hass.states.get(entity_id) return self.hass.states.get(entity_id)

View file

@ -1,15 +1,18 @@
"""The tests for the input_boolean component.""" """The tests for the input_boolean component."""
# pylint: disable=protected-access # pylint: disable=protected-access
import asyncio
import unittest import unittest
import logging import logging
from tests.common import get_test_home_assistant from tests.common import get_test_home_assistant
from homeassistant.bootstrap import setup_component from homeassistant.core import CoreState, State
from homeassistant.bootstrap import setup_component, async_setup_component
from homeassistant.components.input_boolean import ( from homeassistant.components.input_boolean import (
DOMAIN, is_on, toggle, turn_off, turn_on) DOMAIN, is_on, toggle, turn_off, turn_on)
from homeassistant.const import ( from homeassistant.const import (
STATE_ON, STATE_OFF, ATTR_ICON, ATTR_FRIENDLY_NAME) STATE_ON, STATE_OFF, ATTR_ICON, ATTR_FRIENDLY_NAME)
from homeassistant.helpers.restore_state import DATA_RESTORE_CACHE
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -103,3 +106,30 @@ class TestInputBoolean(unittest.TestCase):
self.assertEqual('Hello World', self.assertEqual('Hello World',
state_2.attributes.get(ATTR_FRIENDLY_NAME)) state_2.attributes.get(ATTR_FRIENDLY_NAME))
self.assertEqual('mdi:work', state_2.attributes.get(ATTR_ICON)) self.assertEqual('mdi:work', state_2.attributes.get(ATTR_ICON))
@asyncio.coroutine
def test_restore_state(hass):
"""Ensure states are restored on startup."""
hass.data[DATA_RESTORE_CACHE] = {
'input_boolean.b1': State('input_boolean.b1', 'on'),
'input_boolean.b2': State('input_boolean.b2', 'off'),
'input_boolean.b3': State('input_boolean.b3', 'on'),
}
hass.state = CoreState.starting
hass.config.components.add('recorder')
yield from async_setup_component(hass, DOMAIN, {
DOMAIN: {
'b1': None,
'b2': None,
}})
state = hass.states.get('input_boolean.b1')
assert state
assert state.state == 'on'
state = hass.states.get('input_boolean.b2')
assert state
assert state.state == 'off'

View file

@ -1,5 +1,6 @@
"""The tests for the logbook component.""" """The tests for the logbook component."""
# pylint: disable=protected-access # pylint: disable=protected-access,invalid-name
import logging
from datetime import timedelta from datetime import timedelta
import unittest import unittest
from unittest.mock import patch from unittest.mock import patch
@ -13,7 +14,11 @@ import homeassistant.util.dt as dt_util
from homeassistant.components import logbook from homeassistant.components import logbook
from homeassistant.bootstrap import setup_component from homeassistant.bootstrap import setup_component
from tests.common import mock_http_component, get_test_home_assistant from tests.common import (
mock_http_component, init_recorder_component, get_test_home_assistant)
_LOGGER = logging.getLogger(__name__)
class TestComponentLogbook(unittest.TestCase): class TestComponentLogbook(unittest.TestCase):
@ -24,12 +29,14 @@ class TestComponentLogbook(unittest.TestCase):
def setUp(self): def setUp(self):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant() self.hass = get_test_home_assistant()
init_recorder_component(self.hass) # Force an in memory DB
mock_http_component(self.hass) mock_http_component(self.hass)
self.hass.config.components |= set(['frontend', 'recorder', 'api']) self.hass.config.components |= set(['frontend', 'recorder', 'api'])
with patch('homeassistant.components.logbook.' with patch('homeassistant.components.logbook.'
'register_built_in_panel'): 'register_built_in_panel'):
assert setup_component(self.hass, logbook.DOMAIN, assert setup_component(self.hass, logbook.DOMAIN,
self.EMPTY_CONFIG) self.EMPTY_CONFIG)
self.hass.start()
def tearDown(self): def tearDown(self):
"""Stop everything that was started.""" """Stop everything that was started."""
@ -41,6 +48,7 @@ class TestComponentLogbook(unittest.TestCase):
@ha.callback @ha.callback
def event_listener(event): def event_listener(event):
"""Append on event."""
calls.append(event) calls.append(event)
self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener) self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener)
@ -72,6 +80,7 @@ class TestComponentLogbook(unittest.TestCase):
@ha.callback @ha.callback
def event_listener(event): def event_listener(event):
"""Append on event."""
calls.append(event) calls.append(event)
self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener) self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener)
@ -242,17 +251,17 @@ class TestComponentLogbook(unittest.TestCase):
entity_id2 = 'sensor.blu' entity_id2 = 'sensor.blu'
eventA = ha.Event(logbook.EVENT_LOGBOOK_ENTRY, { eventA = ha.Event(logbook.EVENT_LOGBOOK_ENTRY, {
logbook.ATTR_NAME: name, logbook.ATTR_NAME: name,
logbook.ATTR_MESSAGE: message, logbook.ATTR_MESSAGE: message,
logbook.ATTR_DOMAIN: domain, logbook.ATTR_DOMAIN: domain,
logbook.ATTR_ENTITY_ID: entity_id, logbook.ATTR_ENTITY_ID: entity_id,
}) })
eventB = ha.Event(logbook.EVENT_LOGBOOK_ENTRY, { eventB = ha.Event(logbook.EVENT_LOGBOOK_ENTRY, {
logbook.ATTR_NAME: name, logbook.ATTR_NAME: name,
logbook.ATTR_MESSAGE: message, logbook.ATTR_MESSAGE: message,
logbook.ATTR_DOMAIN: domain, logbook.ATTR_DOMAIN: domain,
logbook.ATTR_ENTITY_ID: entity_id2, logbook.ATTR_ENTITY_ID: entity_id2,
}) })
config = logbook.CONFIG_SCHEMA({ config = logbook.CONFIG_SCHEMA({
ha.DOMAIN: {}, ha.DOMAIN: {},
@ -532,7 +541,8 @@ class TestComponentLogbook(unittest.TestCase):
def create_state_changed_event(self, event_time_fired, entity_id, state, def create_state_changed_event(self, event_time_fired, entity_id, state,
attributes=None, last_changed=None, attributes=None, last_changed=None,
last_updated=None): last_updated=None): \
# pylint: disable=no-self-use
"""Create state changed event.""" """Create state changed event."""
# Logbook only cares about state change events that # Logbook only cares about state change events that
# contain an old state but will not actually act on it. # contain an old state but will not actually act on it.

View file

@ -0,0 +1,42 @@
"""The tests for the Restore component."""
import asyncio
from unittest.mock import patch, MagicMock
from homeassistant.const import EVENT_HOMEASSISTANT_START
from homeassistant.core import CoreState, State
import homeassistant.util.dt as dt_util
from homeassistant.helpers.restore_state import (
async_get_last_state, DATA_RESTORE_CACHE)
@asyncio.coroutine
def test_caching_data(hass):
"""Test that we cache data."""
hass.config.components.add('recorder')
hass.state = CoreState.starting
states = [
State('input_boolean.b0', 'on'),
State('input_boolean.b1', 'on'),
State('input_boolean.b2', 'on'),
]
with patch('homeassistant.helpers.restore_state.last_recorder_run',
return_value=MagicMock(end=dt_util.utcnow())), \
patch('homeassistant.helpers.restore_state.get_states',
return_value=states):
state = yield from async_get_last_state(hass, 'input_boolean.b1')
assert DATA_RESTORE_CACHE in hass.data
assert hass.data[DATA_RESTORE_CACHE] == {st.entity_id: st for st in states}
assert state is not None
assert state.entity_id == 'input_boolean.b1'
assert state.state == 'on'
hass.bus.async_fire(EVENT_HOMEASSISTANT_START)
yield from hass.async_block_till_done()
assert DATA_RESTORE_CACHE not in hass.data