Restore for automation entities (#6254)
* Restore for automation entities * coroutine * no clue what i'm doing now * Still passes nicely in py 3.4
This commit is contained in:
parent
8232f1ef65
commit
1522e67351
3 changed files with 80 additions and 26 deletions
|
@ -21,6 +21,7 @@ from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import extract_domain_configs, script, condition
|
from homeassistant.helpers import extract_domain_configs, script, condition
|
||||||
from homeassistant.helpers.entity import ToggleEntity
|
from homeassistant.helpers.entity import ToggleEntity
|
||||||
from homeassistant.helpers.entity_component import EntityComponent
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
|
from homeassistant.helpers.restore_state import async_get_last_state
|
||||||
from homeassistant.loader import get_platform
|
from homeassistant.loader import get_platform
|
||||||
from homeassistant.util.dt import utcnow
|
from homeassistant.util.dt import utcnow
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
@ -265,9 +266,15 @@ class AutomationEntity(ToggleEntity):
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def async_added_to_hass(self) -> None:
|
def async_added_to_hass(self) -> None:
|
||||||
"""Startup if initial_state."""
|
"""Startup with initial state or previous state."""
|
||||||
|
state = yield from async_get_last_state(self.hass, self.entity_id)
|
||||||
|
if state is None:
|
||||||
if self._initial_state:
|
if self._initial_state:
|
||||||
yield from self.async_enable()
|
yield from self.async_enable()
|
||||||
|
else:
|
||||||
|
self._last_triggered = state.attributes.get('last_triggered')
|
||||||
|
if state.state == STATE_ON:
|
||||||
|
yield from self.async_enable()
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def async_turn_on(self, **kwargs) -> None:
|
def async_turn_on(self, **kwargs) -> None:
|
||||||
|
|
|
@ -131,6 +131,7 @@ def async_test_home_assistant(loop):
|
||||||
|
|
||||||
@ha.callback
|
@ha.callback
|
||||||
def clear_instance(event):
|
def clear_instance(event):
|
||||||
|
"""Clear global instance."""
|
||||||
global INST_COUNT
|
global INST_COUNT
|
||||||
INST_COUNT -= 1
|
INST_COUNT -= 1
|
||||||
|
|
||||||
|
@ -140,20 +141,18 @@ def async_test_home_assistant(loop):
|
||||||
|
|
||||||
|
|
||||||
def mock_service(hass, domain, service):
|
def mock_service(hass, domain, service):
|
||||||
"""Setup a fake service.
|
"""Setup a fake service & return a list that logs calls to this service."""
|
||||||
|
|
||||||
Return a list that logs all calls to fake service.
|
|
||||||
"""
|
|
||||||
calls = []
|
calls = []
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
@asyncio.coroutine
|
||||||
@ha.callback
|
def mock_service_log(call): # pylint: disable=unnecessary-lambda
|
||||||
def mock_service(call):
|
|
||||||
""""Mocked service call."""
|
""""Mocked service call."""
|
||||||
calls.append(call)
|
calls.append(call)
|
||||||
|
|
||||||
# pylint: disable=unnecessary-lambda
|
if hass.loop.__dict__.get("_thread_ident", 0) == threading.get_ident():
|
||||||
hass.services.register(domain, service, mock_service)
|
hass.services.async_register(domain, service, mock_service_log)
|
||||||
|
else:
|
||||||
|
hass.services.register(domain, service, mock_service_log)
|
||||||
|
|
||||||
return calls
|
return calls
|
||||||
|
|
||||||
|
|
|
@ -1,17 +1,19 @@
|
||||||
"""The tests for the automation component."""
|
"""The tests for the automation component."""
|
||||||
import unittest
|
import asyncio
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
import unittest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import State
|
||||||
from homeassistant.bootstrap import setup_component
|
from homeassistant.bootstrap import setup_component, async_setup_component
|
||||||
import homeassistant.components.automation as automation
|
import homeassistant.components.automation as automation
|
||||||
from homeassistant.const import ATTR_ENTITY_ID
|
from homeassistant.const import ATTR_ENTITY_ID, STATE_ON, STATE_OFF
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
|
|
||||||
from tests.common import get_test_home_assistant, assert_setup_component, \
|
from tests.common import (
|
||||||
fire_time_changed, mock_component
|
assert_setup_component, get_test_home_assistant, fire_time_changed,
|
||||||
|
mock_component, mock_service, mock_restore_cache)
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
|
@ -22,14 +24,7 @@ class TestAutomation(unittest.TestCase):
|
||||||
"""Setup things to be run when tests are started."""
|
"""Setup things to be run when tests are started."""
|
||||||
self.hass = get_test_home_assistant()
|
self.hass = get_test_home_assistant()
|
||||||
mock_component(self.hass, 'group')
|
mock_component(self.hass, 'group')
|
||||||
self.calls = []
|
self.calls = mock_service(self.hass, 'test', 'automation')
|
||||||
|
|
||||||
@callback
|
|
||||||
def record_call(service):
|
|
||||||
"""Helper to record calls."""
|
|
||||||
self.calls.append(service)
|
|
||||||
|
|
||||||
self.hass.services.register('test', 'automation', record_call)
|
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
"""Stop everything that was started."""
|
"""Stop everything that was started."""
|
||||||
|
@ -572,3 +567,56 @@ class TestAutomation(unittest.TestCase):
|
||||||
self.hass.bus.fire('test_event')
|
self.hass.bus.fire('test_event')
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
assert len(self.calls) == 2
|
assert len(self.calls) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_automation_restore_state(hass):
|
||||||
|
"""Ensure states are restored on startup."""
|
||||||
|
time = dt_util.utcnow()
|
||||||
|
|
||||||
|
mock_restore_cache(hass, (
|
||||||
|
State('automation.hello', STATE_ON),
|
||||||
|
State('automation.bye', STATE_OFF, {'last_triggered': time}),
|
||||||
|
))
|
||||||
|
|
||||||
|
config = {automation.DOMAIN: [{
|
||||||
|
'alias': 'hello',
|
||||||
|
'trigger': {
|
||||||
|
'platform': 'event',
|
||||||
|
'event_type': 'test_event_hello',
|
||||||
|
},
|
||||||
|
'action': {'service': 'test.automation'}
|
||||||
|
}, {
|
||||||
|
'alias': 'bye',
|
||||||
|
'trigger': {
|
||||||
|
'platform': 'event',
|
||||||
|
'event_type': 'test_event_bye',
|
||||||
|
},
|
||||||
|
'action': {'service': 'test.automation'}
|
||||||
|
}]}
|
||||||
|
|
||||||
|
assert (yield from async_setup_component(hass, automation.DOMAIN, config))
|
||||||
|
|
||||||
|
state = hass.states.get('automation.hello')
|
||||||
|
assert state
|
||||||
|
assert state.state == STATE_ON
|
||||||
|
|
||||||
|
state = hass.states.get('automation.bye')
|
||||||
|
assert state
|
||||||
|
assert state.state == STATE_OFF
|
||||||
|
assert state.attributes.get('last_triggered') == time
|
||||||
|
|
||||||
|
calls = mock_service(hass, 'test', 'automation')
|
||||||
|
|
||||||
|
assert automation.is_on(hass, 'automation.bye') is False
|
||||||
|
|
||||||
|
hass.bus.async_fire('test_event_bye')
|
||||||
|
yield from hass.async_block_till_done()
|
||||||
|
assert len(calls) == 0
|
||||||
|
|
||||||
|
assert automation.is_on(hass, 'automation.hello')
|
||||||
|
|
||||||
|
hass.bus.async_fire('test_event_hello')
|
||||||
|
yield from hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(calls) == 1
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue