diff --git a/homeassistant/components/binary_sensor/template.py b/homeassistant/components/binary_sensor/template.py index fbdfa2eb4de..396f591923b 100644 --- a/homeassistant/components/binary_sensor/template.py +++ b/homeassistant/components/binary_sensor/template.py @@ -16,7 +16,7 @@ from homeassistant.components.binary_sensor import ( from homeassistant.const import ( ATTR_FRIENDLY_NAME, ATTR_ENTITY_ID, CONF_VALUE_TEMPLATE, CONF_SENSOR_CLASS, CONF_SENSORS, CONF_DEVICE_CLASS, - EVENT_HOMEASSISTANT_START) + EVENT_HOMEASSISTANT_START, STATE_ON) from homeassistant.exceptions import TemplateError import homeassistant.helpers.config_validation as cv from homeassistant.helpers.deprecation import get_deprecated @@ -92,7 +92,7 @@ class BinarySensorTemplate(BinarySensorDevice): """Register callbacks.""" state = yield from async_get_last_state(self.hass, self.entity_id) if state: - self._state = state.state + self._state = state.state == STATE_ON @callback def template_bsensor_state_listener(entity, old_state, new_state): diff --git a/homeassistant/components/switch/template.py b/homeassistant/components/switch/template.py index 91ac16fe06c..4ea2d82388d 100644 --- a/homeassistant/components/switch/template.py +++ b/homeassistant/components/switch/template.py @@ -14,12 +14,13 @@ from homeassistant.components.switch import ( ENTITY_ID_FORMAT, SwitchDevice, PLATFORM_SCHEMA) from homeassistant.const import ( ATTR_FRIENDLY_NAME, CONF_VALUE_TEMPLATE, STATE_OFF, STATE_ON, - ATTR_ENTITY_ID, CONF_SWITCHES) + ATTR_ENTITY_ID, CONF_SWITCHES, EVENT_HOMEASSISTANT_START) from homeassistant.exceptions import TemplateError +import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity import async_generate_entity_id from homeassistant.helpers.event import async_track_state_change +from homeassistant.helpers.restore_state import async_get_last_state from homeassistant.helpers.script import Script -import homeassistant.helpers.config_validation as cv _LOGGER = logging.getLogger(__name__) _VALID_STATES = [STATE_ON, STATE_OFF, 'true', 'false'] @@ -88,14 +89,30 @@ class SwitchTemplate(SwitchDevice): self._on_script = Script(hass, on_action) self._off_script = Script(hass, off_action) self._state = False + self._entities = entity_ids + + @asyncio.coroutine + def async_added_to_hass(self): + """Register callbacks.""" + state = yield from async_get_last_state(self.hass, self.entity_id) + if state: + self._state = state.state == STATE_ON @callback def template_switch_state_listener(entity, old_state, new_state): """Called when the target device changes state.""" - hass.async_add_job(self.async_update_ha_state(True)) + self.hass.async_add_job(self.async_update_ha_state(True)) - async_track_state_change( - hass, entity_ids, template_switch_state_listener) + @callback + def template_switch_startup(event): + """Update template on startup.""" + async_track_state_change( + self.hass, self._entities, template_switch_state_listener) + + self.hass.async_add_job(self.async_update_ha_state(True)) + + self.hass.bus.async_listen_once( + EVENT_HOMEASSISTANT_START, template_switch_startup) @property def name(self): diff --git a/tests/components/sensor/test_template.py b/tests/components/sensor/test_template.py index 7ba4ca136e0..adfdc08d510 100644 --- a/tests/components/sensor/test_template.py +++ b/tests/components/sensor/test_template.py @@ -39,6 +39,7 @@ class TestTemplateSensor: }) self.hass.start() + self.hass.block_till_done() state = self.hass.states.get('sensor.test_template_sensor') assert state.state == 'It .' @@ -68,6 +69,7 @@ class TestTemplateSensor: }) self.hass.start() + self.hass.block_till_done() state = self.hass.states.get('sensor.test_template_sensor') assert 'icon' not in state.attributes @@ -93,6 +95,7 @@ class TestTemplateSensor: }) self.hass.start() + self.hass.block_till_done() assert self.hass.states.all() == [] def test_template_attribute_missing(self): @@ -111,6 +114,7 @@ class TestTemplateSensor: }) self.hass.start() + self.hass.block_till_done() state = self.hass.states.get('sensor.test_template_sensor') assert state.state == 'unknown' @@ -131,6 +135,8 @@ class TestTemplateSensor: }) self.hass.start() + self.hass.block_till_done() + assert self.hass.states.all() == [] def test_invalid_sensor_does_not_create(self): @@ -146,6 +152,7 @@ class TestTemplateSensor: }) self.hass.start() + assert self.hass.states.all() == [] def test_no_sensors_does_not_create(self): @@ -158,6 +165,8 @@ class TestTemplateSensor: }) self.hass.start() + self.hass.block_till_done() + assert self.hass.states.all() == [] def test_missing_template_does_not_create(self): @@ -176,6 +185,8 @@ class TestTemplateSensor: }) self.hass.start() + self.hass.block_till_done() + assert self.hass.states.all() == [] diff --git a/tests/components/switch/test_template.py b/tests/components/switch/test_template.py index 2f67564e6e8..dabdaa2b4d7 100644 --- a/tests/components/switch/test_template.py +++ b/tests/components/switch/test_template.py @@ -1,12 +1,14 @@ """The tests for the Template switch platform.""" -from homeassistant.core import callback +import asyncio + +from homeassistant.core import callback, State, CoreState import homeassistant.bootstrap as bootstrap import homeassistant.components as core -from homeassistant.const import ( - STATE_ON, - STATE_OFF) +from homeassistant.const import STATE_ON, STATE_OFF +from homeassistant.helpers.restore_state import DATA_RESTORE_CACHE -from tests.common import get_test_home_assistant, assert_setup_component +from tests.common import ( + get_test_home_assistant, assert_setup_component, mock_component) class TestTemplateSwitch: @@ -55,6 +57,9 @@ class TestTemplateSwitch: } }) + self.hass.start() + self.hass.block_till_done() + state = self.hass.states.set('switch.test_state', STATE_ON) self.hass.block_till_done() @@ -90,6 +95,9 @@ class TestTemplateSwitch: } }) + self.hass.start() + self.hass.block_till_done() + state = self.hass.states.get('switch.test_template_switch') assert state.state == STATE_ON @@ -116,6 +124,9 @@ class TestTemplateSwitch: } }) + self.hass.start() + self.hass.block_till_done() + state = self.hass.states.get('switch.test_template_switch') assert state.state == STATE_OFF @@ -141,6 +152,10 @@ class TestTemplateSwitch: } } }) + + self.hass.start() + self.hass.block_till_done() + assert self.hass.states.all() == [] def test_invalid_name_does_not_create(self): @@ -165,6 +180,10 @@ class TestTemplateSwitch: } } }) + + self.hass.start() + self.hass.block_till_done() + assert self.hass.states.all() == [] def test_invalid_switch_does_not_create(self): @@ -178,6 +197,10 @@ class TestTemplateSwitch: } } }) + + self.hass.start() + self.hass.block_till_done() + assert self.hass.states.all() == [] def test_no_switches_does_not_create(self): @@ -188,6 +211,10 @@ class TestTemplateSwitch: 'platform': 'template' } }) + + self.hass.start() + self.hass.block_till_done() + assert self.hass.states.all() == [] def test_missing_template_does_not_create(self): @@ -212,6 +239,10 @@ class TestTemplateSwitch: } } }) + + self.hass.start() + self.hass.block_till_done() + assert self.hass.states.all() == [] def test_missing_on_does_not_create(self): @@ -236,6 +267,10 @@ class TestTemplateSwitch: } } }) + + self.hass.start() + self.hass.block_till_done() + assert self.hass.states.all() == [] def test_missing_off_does_not_create(self): @@ -260,6 +295,10 @@ class TestTemplateSwitch: } } }) + + self.hass.start() + self.hass.block_till_done() + assert self.hass.states.all() == [] def test_on_action(self): @@ -282,6 +321,10 @@ class TestTemplateSwitch: } } }) + + self.hass.start() + self.hass.block_till_done() + self.hass.states.set('switch.test_state', STATE_OFF) self.hass.block_till_done() @@ -314,6 +357,10 @@ class TestTemplateSwitch: } } }) + + self.hass.start() + self.hass.block_till_done() + self.hass.states.set('switch.test_state', STATE_ON) self.hass.block_till_done() @@ -324,3 +371,44 @@ class TestTemplateSwitch: self.hass.block_till_done() assert len(self.calls) == 1 + + +@asyncio.coroutine +def test_restore_state(hass): + """Ensure states are restored on startup.""" + hass.data[DATA_RESTORE_CACHE] = { + 'switch.test_template_switch': + State('switch.test_template_switch', 'on'), + } + + hass.state = CoreState.starting + mock_component(hass, 'recorder') + + yield from bootstrap.async_setup_component(hass, 'switch', { + 'switch': { + 'platform': 'template', + 'switches': { + 'test_template_switch': { + 'value_template': + "{{ states.switch.test_state.state }}", + 'turn_on': { + 'service': 'switch.turn_on', + 'entity_id': 'switch.test_state' + }, + 'turn_off': { + 'service': 'switch.turn_off', + 'entity_id': 'switch.test_state' + }, + } + } + } + }) + + state = hass.states.get('switch.test_template_switch') + assert state.state == 'on' + + yield from hass.async_start() + yield from hass.async_block_till_done() + + state = hass.states.get('switch.test_template_switch') + assert state.state == 'unavailable'