Add context to scripts and automations (#16415)

* Add context to script helper

* Update script component

* Add context to automations

* Lint
This commit is contained in:
Paulus Schoutsen 2018-09-04 21:16:24 +02:00 committed by Pascal Vizeli
parent e1501c83f8
commit 746f4ac158
17 changed files with 164 additions and 144 deletions

View file

@ -158,27 +158,26 @@ def async_reload(hass):
return hass.services.async_call(DOMAIN, SERVICE_RELOAD)
@asyncio.coroutine
def async_setup(hass, config):
async def async_setup(hass, config):
"""Set up the automation."""
component = EntityComponent(_LOGGER, DOMAIN, hass,
group_name=GROUP_NAME_ALL_AUTOMATIONS)
yield from _async_process_config(hass, config, component)
await _async_process_config(hass, config, component)
@asyncio.coroutine
def trigger_service_handler(service_call):
async def trigger_service_handler(service_call):
"""Handle automation triggers."""
tasks = []
for entity in component.async_extract_from_service(service_call):
tasks.append(entity.async_trigger(
service_call.data.get(ATTR_VARIABLES), True))
service_call.data.get(ATTR_VARIABLES),
skip_condition=True,
context=service_call.context))
if tasks:
yield from asyncio.wait(tasks, loop=hass.loop)
await asyncio.wait(tasks, loop=hass.loop)
@asyncio.coroutine
def turn_onoff_service_handler(service_call):
async def turn_onoff_service_handler(service_call):
"""Handle automation turn on/off service calls."""
tasks = []
method = 'async_{}'.format(service_call.service)
@ -186,10 +185,9 @@ def async_setup(hass, config):
tasks.append(getattr(entity, method)())
if tasks:
yield from asyncio.wait(tasks, loop=hass.loop)
await asyncio.wait(tasks, loop=hass.loop)
@asyncio.coroutine
def toggle_service_handler(service_call):
async def toggle_service_handler(service_call):
"""Handle automation toggle service calls."""
tasks = []
for entity in component.async_extract_from_service(service_call):
@ -199,15 +197,14 @@ def async_setup(hass, config):
tasks.append(entity.async_turn_on())
if tasks:
yield from asyncio.wait(tasks, loop=hass.loop)
await asyncio.wait(tasks, loop=hass.loop)
@asyncio.coroutine
def reload_service_handler(service_call):
async def reload_service_handler(service_call):
"""Remove all automations and load new ones from config."""
conf = yield from component.async_prepare_reload()
conf = await component.async_prepare_reload()
if conf is None:
return
yield from _async_process_config(hass, conf, component)
await _async_process_config(hass, conf, component)
hass.services.async_register(
DOMAIN, SERVICE_TRIGGER, trigger_service_handler,
@ -272,15 +269,14 @@ class AutomationEntity(ToggleEntity):
"""Return True if entity is on."""
return self._async_detach_triggers is not None
@asyncio.coroutine
def async_added_to_hass(self) -> None:
async def async_added_to_hass(self) -> None:
"""Startup with initial state or previous state."""
if self._initial_state is not None:
enable_automation = self._initial_state
_LOGGER.debug("Automation %s initial state %s from config "
"initial_state", self.entity_id, enable_automation)
else:
state = yield from async_get_last_state(self.hass, self.entity_id)
state = await async_get_last_state(self.hass, self.entity_id)
if state:
enable_automation = state.state == STATE_ON
self._last_triggered = state.attributes.get('last_triggered')
@ -298,54 +294,50 @@ class AutomationEntity(ToggleEntity):
# HomeAssistant is starting up
if self.hass.state == CoreState.not_running:
@asyncio.coroutine
def async_enable_automation(event):
async def async_enable_automation(event):
"""Start automation on startup."""
yield from self.async_enable()
await self.async_enable()
self.hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_START, async_enable_automation)
# HomeAssistant is running
else:
yield from self.async_enable()
await self.async_enable()
@asyncio.coroutine
def async_turn_on(self, **kwargs) -> None:
async def async_turn_on(self, **kwargs) -> None:
"""Turn the entity on and update the state."""
if self.is_on:
return
yield from self.async_enable()
await self.async_enable()
@asyncio.coroutine
def async_turn_off(self, **kwargs) -> None:
async def async_turn_off(self, **kwargs) -> None:
"""Turn the entity off."""
if not self.is_on:
return
self._async_detach_triggers()
self._async_detach_triggers = None
yield from self.async_update_ha_state()
await self.async_update_ha_state()
@asyncio.coroutine
def async_trigger(self, variables, skip_condition=False):
async def async_trigger(self, variables, skip_condition=False,
context=None):
"""Trigger automation.
This method is a coroutine.
"""
if skip_condition or self._cond_func(variables):
yield from self._async_action(self.entity_id, variables)
self.async_set_context(context)
await self._async_action(self.entity_id, variables, context)
self._last_triggered = utcnow()
yield from self.async_update_ha_state()
await self.async_update_ha_state()
@asyncio.coroutine
def async_will_remove_from_hass(self):
async def async_will_remove_from_hass(self):
"""Remove listeners when removing automation from HASS."""
yield from self.async_turn_off()
await self.async_turn_off()
@asyncio.coroutine
def async_enable(self):
async def async_enable(self):
"""Enable this automation entity.
This method is a coroutine.
@ -353,9 +345,9 @@ class AutomationEntity(ToggleEntity):
if self.is_on:
return
self._async_detach_triggers = yield from self._async_attach_triggers(
self._async_detach_triggers = await self._async_attach_triggers(
self.async_trigger)
yield from self.async_update_ha_state()
await self.async_update_ha_state()
@property
def device_state_attributes(self):
@ -368,8 +360,7 @@ class AutomationEntity(ToggleEntity):
}
@asyncio.coroutine
def _async_process_config(hass, config, component):
async def _async_process_config(hass, config, component):
"""Process config and add automations.
This method is a coroutine.
@ -411,20 +402,19 @@ def _async_process_config(hass, config, component):
entities.append(entity)
if entities:
yield from component.async_add_entities(entities)
await component.async_add_entities(entities)
def _async_get_action(hass, config, name):
"""Return an action based on a configuration."""
script_obj = script.Script(hass, config, name)
@asyncio.coroutine
def action(entity_id, variables):
async def action(entity_id, variables, context):
"""Execute an action."""
_LOGGER.info('Executing %s', name)
logbook.async_log_entry(
hass, name, 'has been triggered', DOMAIN, entity_id)
yield from script_obj.async_run(variables)
await script_obj.async_run(variables, context)
return action
@ -448,8 +438,7 @@ def _async_process_if(hass, config, p_config):
return if_action
@asyncio.coroutine
def _async_process_trigger(hass, config, trigger_configs, name, action):
async def _async_process_trigger(hass, config, trigger_configs, name, action):
"""Set up the triggers.
This method is a coroutine.
@ -457,13 +446,13 @@ def _async_process_trigger(hass, config, trigger_configs, name, action):
removes = []
for conf in trigger_configs:
platform = yield from async_prepare_setup_platform(
platform = await async_prepare_setup_platform(
hass, config, DOMAIN, conf.get(CONF_PLATFORM))
if platform is None:
return None
remove = yield from platform.async_trigger(hass, conf, action)
remove = await platform.async_trigger(hass, conf, action)
if not remove:
_LOGGER.error("Error setting up trigger %s", name)

View file

@ -45,11 +45,11 @@ def async_trigger(hass, config, action):
# If event data doesn't match requested schema, skip event
return
hass.async_run_job(action, {
hass.async_run_job(action({
'trigger': {
'platform': 'event',
'event': event,
},
})
}, context=event.context))
return hass.bus.async_listen(event_type, handle_event)

View file

@ -32,12 +32,12 @@ def async_trigger(hass, config, action):
@callback
def hass_shutdown(event):
"""Execute when Home Assistant is shutting down."""
hass.async_run_job(action, {
hass.async_run_job(action({
'trigger': {
'platform': 'homeassistant',
'event': event,
},
})
}, context=event.context))
return hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP,
hass_shutdown)
@ -45,11 +45,11 @@ def async_trigger(hass, config, action):
# Automation are enabled while hass is starting up, fire right away
# Check state because a config reload shouldn't trigger it.
if hass.state == CoreState.starting:
hass.async_run_job(action, {
hass.async_run_job(action({
'trigger': {
'platform': 'homeassistant',
'event': event,
},
})
}))
return lambda: None

View file

@ -66,7 +66,7 @@ def async_trigger(hass, config, action):
@callback
def call_action():
"""Call action with right context."""
hass.async_run_job(action, {
hass.async_run_job(action({
'trigger': {
'platform': 'numeric_state',
'entity_id': entity,
@ -75,7 +75,7 @@ def async_trigger(hass, config, action):
'from_state': from_s,
'to_state': to_s,
}
})
}, context=to_s.context))
matching = check_numeric_state(entity, from_s, to_s)

View file

@ -43,7 +43,7 @@ def async_trigger(hass, config, action):
@callback
def call_action():
"""Call action with right context."""
hass.async_run_job(action, {
hass.async_run_job(action({
'trigger': {
'platform': 'state',
'entity_id': entity,
@ -51,7 +51,7 @@ def async_trigger(hass, config, action):
'to_state': to_s,
'for': time_delta,
}
})
}, context=to_s.context))
# Ignore changes to state attributes if from/to is in use
if (not match_all and from_s is not None and to_s is not None and

View file

@ -32,13 +32,13 @@ def async_trigger(hass, config, action):
@callback
def template_listener(entity_id, from_s, to_s):
"""Listen for state changes and calls action."""
hass.async_run_job(action, {
hass.async_run_job(action({
'trigger': {
'platform': 'template',
'entity_id': entity_id,
'from_state': from_s,
'to_state': to_s,
},
})
}, context=to_s.context))
return async_track_template(hass, value_template, template_listener)

View file

@ -51,7 +51,7 @@ def async_trigger(hass, config, action):
# pylint: disable=too-many-boolean-expressions
if event == EVENT_ENTER and not from_match and to_match or \
event == EVENT_LEAVE and from_match and not to_match:
hass.async_run_job(action, {
hass.async_run_job(action({
'trigger': {
'platform': 'zone',
'entity_id': entity,
@ -60,7 +60,7 @@ def async_trigger(hass, config, action):
'zone': zone_state,
'event': event,
},
})
}, context=to_s.context))
return async_track_state_change(hass, entity_id, zone_automation_listener,
MATCH_ALL, MATCH_ALL)

View file

@ -63,11 +63,11 @@ def is_on(hass, entity_id):
@bind_hass
def turn_on(hass, entity_id, variables=None):
def turn_on(hass, entity_id, variables=None, context=None):
"""Turn script on."""
_, object_id = split_entity_id(entity_id)
hass.services.call(DOMAIN, object_id, variables)
hass.services.call(DOMAIN, object_id, variables, context=context)
@bind_hass
@ -97,45 +97,41 @@ def async_reload(hass):
return hass.services.async_call(DOMAIN, SERVICE_RELOAD)
@asyncio.coroutine
def async_setup(hass, config):
async def async_setup(hass, config):
"""Load the scripts from the configuration."""
component = EntityComponent(
_LOGGER, DOMAIN, hass, group_name=GROUP_NAME_ALL_SCRIPTS)
yield from _async_process_config(hass, config, component)
await _async_process_config(hass, config, component)
@asyncio.coroutine
def reload_service(service):
async def reload_service(service):
"""Call a service to reload scripts."""
conf = yield from component.async_prepare_reload()
conf = await component.async_prepare_reload()
if conf is None:
return
yield from _async_process_config(hass, conf, component)
await _async_process_config(hass, conf, component)
@asyncio.coroutine
def turn_on_service(service):
async def turn_on_service(service):
"""Call a service to turn script on."""
# We could turn on script directly here, but we only want to offer
# one way to do it. Otherwise no easy way to detect invocations.
var = service.data.get(ATTR_VARIABLES)
for script in component.async_extract_from_service(service):
yield from hass.services.async_call(DOMAIN, script.object_id, var)
await hass.services.async_call(DOMAIN, script.object_id, var,
context=service.context)
@asyncio.coroutine
def turn_off_service(service):
async def turn_off_service(service):
"""Cancel a script."""
# Stopping a script is ok to be done in parallel
yield from asyncio.wait(
await asyncio.wait(
[script.async_turn_off() for script
in component.async_extract_from_service(service)], loop=hass.loop)
@asyncio.coroutine
def toggle_service(service):
async def toggle_service(service):
"""Toggle a script."""
for script in component.async_extract_from_service(service):
yield from script.async_toggle()
await script.async_toggle(context=service.context)
hass.services.async_register(DOMAIN, SERVICE_RELOAD, reload_service,
schema=RELOAD_SERVICE_SCHEMA)
@ -149,18 +145,17 @@ def async_setup(hass, config):
return True
@asyncio.coroutine
def _async_process_config(hass, config, component):
"""Process group configuration."""
@asyncio.coroutine
def service_handler(service):
async def _async_process_config(hass, config, component):
"""Process script configuration."""
async def service_handler(service):
"""Execute a service call to script.<script name>."""
entity_id = ENTITY_ID_FORMAT.format(service.service)
script = component.get_entity(entity_id)
if script.is_on:
_LOGGER.warning("Script %s already running.", entity_id)
return
yield from script.async_turn_on(variables=service.data)
await script.async_turn_on(variables=service.data,
context=service.context)
scripts = []
@ -171,7 +166,7 @@ def _async_process_config(hass, config, component):
hass.services.async_register(
DOMAIN, object_id, service_handler, schema=SCRIPT_SERVICE_SCHEMA)
yield from component.async_add_entities(scripts)
await component.async_add_entities(scripts)
class ScriptEntity(ToggleEntity):
@ -209,18 +204,16 @@ class ScriptEntity(ToggleEntity):
"""Return true if script is on."""
return self.script.is_running
@asyncio.coroutine
def async_turn_on(self, **kwargs):
async def async_turn_on(self, **kwargs):
"""Turn the script on."""
yield from self.script.async_run(kwargs.get(ATTR_VARIABLES))
await self.script.async_run(
kwargs.get(ATTR_VARIABLES), kwargs.get('context'))
@asyncio.coroutine
def async_turn_off(self, **kwargs):
async def async_turn_off(self, **kwargs):
"""Turn script off."""
self.script.async_stop()
@asyncio.coroutine
def async_will_remove_from_hass(self):
async def async_will_remove_from_hass(self):
"""Stop script and remove service when it will be removed from HASS."""
if self.script.is_running:
self.script.async_stop()

View file

@ -6,7 +6,7 @@ from typing import Optional, Sequence
import voluptuous as vol
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import HomeAssistant, Context, callback
from homeassistant.const import CONF_CONDITION, CONF_TIMEOUT
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import (
@ -34,9 +34,10 @@ CONF_CONTINUE = 'continue_on_timeout'
def call_from_config(hass: HomeAssistant, config: ConfigType,
variables: Optional[Sequence] = None) -> None:
variables: Optional[Sequence] = None,
context: Optional[Context] = None) -> None:
"""Call a script based on a config entry."""
Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables)
Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables, context)
class Script():
@ -64,12 +65,13 @@ class Script():
"""Return true if script is on."""
return self._cur != -1
def run(self, variables=None):
def run(self, variables=None, context=None):
"""Run script."""
run_coroutine_threadsafe(
self.async_run(variables), self.hass.loop).result()
self.async_run(variables, context), self.hass.loop).result()
async def async_run(self, variables: Optional[Sequence] = None) -> None:
async def async_run(self, variables: Optional[Sequence] = None,
context: Optional[Context] = None) -> None:
"""Run script.
This method is a coroutine.
@ -94,7 +96,8 @@ class Script():
"""Handle delay."""
# pylint: disable=cell-var-from-loop
self._async_listener.remove(unsub)
self.hass.async_add_job(self.async_run(variables))
self.hass.async_create_task(
self.async_run(variables, context))
delay = action[CONF_DELAY]
@ -134,7 +137,8 @@ class Script():
def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true."""
self._async_remove_listener()
self.hass.async_add_job(self.async_run(variables))
self.hass.async_create_task(
self.async_run(variables, context))
self._async_listener.append(async_track_template(
self.hass, wait_template, async_script_wait, variables))
@ -145,7 +149,8 @@ class Script():
if CONF_TIMEOUT in action:
self._async_set_timeout(
action, variables, action.get(CONF_CONTINUE, True))
action, variables, context,
action.get(CONF_CONTINUE, True))
return
@ -154,10 +159,10 @@ class Script():
break
elif CONF_EVENT in action:
self._async_fire_event(action, variables)
self._async_fire_event(action, variables, context)
else:
await self._async_call_service(action, variables)
await self._async_call_service(action, variables, context)
self._cur = -1
self.last_action = None
@ -178,7 +183,7 @@ class Script():
if self._change_listener:
self.hass.async_add_job(self._change_listener)
async def _async_call_service(self, action, variables):
async def _async_call_service(self, action, variables, context):
"""Call the service specified in the action.
This method is a coroutine.
@ -186,9 +191,14 @@ class Script():
self.last_action = action.get(CONF_ALIAS, 'call service')
self._log("Executing step %s" % self.last_action)
await service.async_call_from_config(
self.hass, action, True, variables, validate_config=False)
self.hass, action,
blocking=True,
variables=variables,
validate_config=False,
context=context
)
def _async_fire_event(self, action, variables):
def _async_fire_event(self, action, variables, context):
"""Fire an event."""
self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT])
self._log("Executing step %s" % self.last_action)
@ -201,7 +211,7 @@ class Script():
_LOGGER.error('Error rendering event data template: %s', ex)
self.hass.bus.async_fire(action[CONF_EVENT],
event_data)
event_data, context=context)
def _async_check_condition(self, action, variables):
"""Test if condition is matching."""
@ -216,7 +226,8 @@ class Script():
self._log("Test condition {}: {}".format(self.last_action, check))
return check
def _async_set_timeout(self, action, variables, continue_on_timeout=True):
def _async_set_timeout(self, action, variables, context,
continue_on_timeout):
"""Schedule a timeout to abort or continue script."""
timeout = action[CONF_TIMEOUT]
unsub = None
@ -229,7 +240,8 @@ class Script():
# Check if we want to continue to execute
# the script after the timeout
if continue_on_timeout:
self.hass.async_add_job(self.async_run(variables))
self.hass.async_create_task(
self.async_run(variables, context))
else:
self._log("Timeout reached, abort script.")
self.async_stop()

View file

@ -36,7 +36,7 @@ def call_from_config(hass, config, blocking=False, variables=None,
@bind_hass
async def async_call_from_config(hass, config, blocking=False, variables=None,
validate_config=True):
validate_config=True, context=None):
"""Call a service based on a config hash."""
if validate_config:
try:
@ -77,7 +77,7 @@ async def async_call_from_config(hass, config, blocking=False, variables=None,
service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
await hass.services.async_call(
domain, service_name, service_data, blocking)
domain, service_name, service_data, blocking=blocking, context=context)
@bind_hass

View file

@ -1,7 +1,7 @@
"""The tests for the Event automation."""
import unittest
from homeassistant.core import callback
from homeassistant.core import Context, callback
from homeassistant.setup import setup_component
import homeassistant.components.automation as automation
@ -31,6 +31,8 @@ class TestAutomationEvent(unittest.TestCase):
def test_if_fires_on_event(self):
"""Test the firing of events."""
context = Context()
assert setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: {
'trigger': {
@ -43,9 +45,10 @@ class TestAutomationEvent(unittest.TestCase):
}
})
self.hass.bus.fire('test_event')
self.hass.bus.fire('test_event', context=context)
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
assert self.calls[0].context is context
automation.turn_off(self.hass)
self.hass.block_till_done()

View file

@ -4,7 +4,7 @@ import unittest
from unittest.mock import patch
import homeassistant.components.automation as automation
from homeassistant.core import callback
from homeassistant.core import Context, callback
from homeassistant.setup import setup_component
import homeassistant.util.dt as dt_util
@ -36,6 +36,7 @@ class TestAutomationNumericState(unittest.TestCase):
def test_if_fires_on_entity_change_below(self):
"""Test the firing with changed entity."""
context = Context()
assert setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: {
'trigger': {
@ -49,9 +50,10 @@ class TestAutomationNumericState(unittest.TestCase):
}
})
# 9 is below 10
self.hass.states.set('test.entity', 9)
self.hass.states.set('test.entity', 9, context=context)
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
assert self.calls[0].context is context
# Set above 12 so the automation will fire again
self.hass.states.set('test.entity', 12)
@ -116,6 +118,7 @@ class TestAutomationNumericState(unittest.TestCase):
def test_if_not_fires_on_entity_change_below_to_below(self):
"""Test the firing with changed entity."""
context = Context()
self.hass.states.set('test.entity', 11)
self.hass.block_till_done()
@ -133,9 +136,10 @@ class TestAutomationNumericState(unittest.TestCase):
})
# 9 is below 10 so this should fire
self.hass.states.set('test.entity', 9)
self.hass.states.set('test.entity', 9, context=context)
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
assert self.calls[0].context is context
# already below so should not fire again
self.hass.states.set('test.entity', 5)

View file

@ -4,7 +4,7 @@ from datetime import timedelta
import unittest
from unittest.mock import patch
from homeassistant.core import callback
from homeassistant.core import Context, callback
from homeassistant.setup import setup_component
import homeassistant.util.dt as dt_util
import homeassistant.components.automation as automation
@ -38,6 +38,7 @@ class TestAutomationState(unittest.TestCase):
def test_if_fires_on_entity_change(self):
"""Test for firing on entity change."""
context = Context()
self.hass.states.set('test.entity', 'hello')
self.hass.block_till_done()
@ -59,9 +60,10 @@ class TestAutomationState(unittest.TestCase):
}
})
self.hass.states.set('test.entity', 'world')
self.hass.states.set('test.entity', 'world', context=context)
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
assert self.calls[0].context is context
self.assertEqual(
'state - test.entity - hello - world - None',
self.calls[0].data['some'])

View file

@ -1,7 +1,7 @@
"""The tests for the Template automation."""
import unittest
from homeassistant.core import callback
from homeassistant.core import Context, callback
from homeassistant.setup import setup_component
import homeassistant.components.automation as automation
@ -232,15 +232,12 @@ class TestAutomationTemplate(unittest.TestCase):
def test_if_fires_on_change_with_template_advanced(self):
"""Test for firing on change with template advanced."""
context = Context()
assert setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: {
'trigger': {
'platform': 'template',
'value_template': '''{%- if is_state("test.entity", "world") -%}
true
{%- else -%}
false
{%- endif -%}''',
'value_template': '{{ is_state("test.entity", "world") }}'
},
'action': {
'service': 'test.automation',
@ -257,9 +254,10 @@ class TestAutomationTemplate(unittest.TestCase):
self.hass.block_till_done()
self.calls = []
self.hass.states.set('test.entity', 'world')
self.hass.states.set('test.entity', 'world', context=context)
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
assert self.calls[0].context is context
self.assertEqual(
'template - test.entity - hello - world',
self.calls[0].data['some'])

View file

@ -1,7 +1,7 @@
"""The tests for the location automation."""
import unittest
from homeassistant.core import callback
from homeassistant.core import Context, callback
from homeassistant.setup import setup_component
from homeassistant.components import automation, zone
@ -40,6 +40,7 @@ class TestAutomationZone(unittest.TestCase):
def test_if_fires_on_zone_enter(self):
"""Test for firing on zone enter."""
context = Context()
self.hass.states.set('test.entity', 'hello', {
'latitude': 32.881011,
'longitude': -117.234758
@ -70,10 +71,11 @@ class TestAutomationZone(unittest.TestCase):
self.hass.states.set('test.entity', 'hello', {
'latitude': 32.880586,
'longitude': -117.237564
})
}, context=context)
self.hass.block_till_done()
self.assertEqual(1, len(self.calls))
assert self.calls[0].context is context
self.assertEqual(
'zone - test.entity - hello - hello - test',
self.calls[0].data['some'])

View file

@ -3,7 +3,7 @@
import unittest
from unittest.mock import patch
from homeassistant.core import callback
from homeassistant.core import Context, callback
from homeassistant.setup import setup_component
from homeassistant.components import script
@ -134,6 +134,7 @@ class TestScriptComponent(unittest.TestCase):
def test_passing_variables(self):
"""Test different ways of passing in variables."""
calls = []
context = Context()
@callback
def record_call(service):
@ -157,21 +158,23 @@ class TestScriptComponent(unittest.TestCase):
script.turn_on(self.hass, ENTITY_ID, {
'greeting': 'world'
})
}, context=context)
self.hass.block_till_done()
assert len(calls) == 1
assert calls[-1].data['hello'] == 'world'
assert calls[0].context is context
assert calls[0].data['hello'] == 'world'
self.hass.services.call('script', 'test', {
'greeting': 'universe',
})
}, context=context)
self.hass.block_till_done()
assert len(calls) == 2
assert calls[-1].data['hello'] == 'universe'
assert calls[1].context is context
assert calls[1].data['hello'] == 'universe'
def test_reload_service(self):
"""Verify that the turn_on service."""

View file

@ -4,7 +4,7 @@ from datetime import timedelta
from unittest import mock
import unittest
from homeassistant.core import callback
from homeassistant.core import Context, callback
# Otherwise can't test just this file (import order issue)
import homeassistant.components # noqa
import homeassistant.util.dt as dt_util
@ -32,6 +32,7 @@ class TestScriptHelper(unittest.TestCase):
def test_firing_event(self):
"""Test the firing of events."""
event = 'test_event'
context = Context()
calls = []
@callback
@ -48,17 +49,19 @@ class TestScriptHelper(unittest.TestCase):
}
}))
script_obj.run()
script_obj.run(context=context)
self.hass.block_till_done()
assert len(calls) == 1
assert calls[0].context is context
assert calls[0].data.get('hello') == 'world'
assert not script_obj.can_cancel
def test_firing_event_template(self):
"""Test the firing of events."""
event = 'test_event'
context = Context()
calls = []
@callback
@ -82,11 +85,12 @@ class TestScriptHelper(unittest.TestCase):
}
}))
script_obj.run({'is_world': 'yes'})
script_obj.run({'is_world': 'yes'}, context=context)
self.hass.block_till_done()
assert len(calls) == 1
assert calls[0].context is context
assert calls[0].data == {
'dict': {
1: 'yes',
@ -100,6 +104,7 @@ class TestScriptHelper(unittest.TestCase):
def test_calling_service(self):
"""Test the calling of a service."""
calls = []
context = Context()
@callback
def record_call(service):
@ -113,16 +118,18 @@ class TestScriptHelper(unittest.TestCase):
'data': {
'hello': 'world'
}
})
}, context=context)
self.hass.block_till_done()
assert len(calls) == 1
assert calls[0].context is context
assert calls[0].data.get('hello') == 'world'
def test_calling_service_template(self):
"""Test the calling of a service."""
calls = []
context = Context()
@callback
def record_call(service):
@ -147,17 +154,19 @@ class TestScriptHelper(unittest.TestCase):
{% endif %}
"""
}
}, {'is_world': 'yes'})
}, {'is_world': 'yes'}, context=context)
self.hass.block_till_done()
assert len(calls) == 1
assert calls[0].context is context
assert calls[0].data.get('hello') == 'world'
def test_delay(self):
"""Test the delay."""
event = 'test_event'
events = []
context = Context()
@callback
def record_event(event):
@ -171,7 +180,7 @@ class TestScriptHelper(unittest.TestCase):
{'delay': {'seconds': 5}},
{'event': event}]))
script_obj.run()
script_obj.run(context=context)
self.hass.block_till_done()
assert script_obj.is_running
@ -185,6 +194,8 @@ class TestScriptHelper(unittest.TestCase):
assert not script_obj.is_running
assert len(events) == 2
assert events[0].context is context
assert events[1].context is context
def test_delay_template(self):
"""Test the delay as a template."""
@ -282,6 +293,7 @@ class TestScriptHelper(unittest.TestCase):
"""Test the wait template."""
event = 'test_event'
events = []
context = Context()
@callback
def record_event(event):
@ -297,7 +309,7 @@ class TestScriptHelper(unittest.TestCase):
{'wait_template': "{{states.switch.test.state == 'off'}}"},
{'event': event}]))
script_obj.run()
script_obj.run(context=context)
self.hass.block_till_done()
assert script_obj.is_running
@ -310,6 +322,8 @@ class TestScriptHelper(unittest.TestCase):
assert not script_obj.is_running
assert len(events) == 2
assert events[0].context is context
assert events[1].context is context
def test_wait_template_cancel(self):
"""Test the wait template cancel action."""