Add context to scripts and automations (#16415)

* Add context to script helper

* Update script component

* Add context to automations

* Lint
This commit is contained in:
Paulus Schoutsen 2018-09-04 21:16:24 +02:00 committed by Pascal Vizeli
parent e1501c83f8
commit 746f4ac158
17 changed files with 164 additions and 144 deletions

View file

@ -158,27 +158,26 @@ def async_reload(hass):
return hass.services.async_call(DOMAIN, SERVICE_RELOAD) return hass.services.async_call(DOMAIN, SERVICE_RELOAD)
@asyncio.coroutine async def async_setup(hass, config):
def async_setup(hass, config):
"""Set up the automation.""" """Set up the automation."""
component = EntityComponent(_LOGGER, DOMAIN, hass, component = EntityComponent(_LOGGER, DOMAIN, hass,
group_name=GROUP_NAME_ALL_AUTOMATIONS) group_name=GROUP_NAME_ALL_AUTOMATIONS)
yield from _async_process_config(hass, config, component) await _async_process_config(hass, config, component)
@asyncio.coroutine async def trigger_service_handler(service_call):
def trigger_service_handler(service_call):
"""Handle automation triggers.""" """Handle automation triggers."""
tasks = [] tasks = []
for entity in component.async_extract_from_service(service_call): for entity in component.async_extract_from_service(service_call):
tasks.append(entity.async_trigger( tasks.append(entity.async_trigger(
service_call.data.get(ATTR_VARIABLES), True)) service_call.data.get(ATTR_VARIABLES),
skip_condition=True,
context=service_call.context))
if tasks: if tasks:
yield from asyncio.wait(tasks, loop=hass.loop) await asyncio.wait(tasks, loop=hass.loop)
@asyncio.coroutine async def turn_onoff_service_handler(service_call):
def turn_onoff_service_handler(service_call):
"""Handle automation turn on/off service calls.""" """Handle automation turn on/off service calls."""
tasks = [] tasks = []
method = 'async_{}'.format(service_call.service) method = 'async_{}'.format(service_call.service)
@ -186,10 +185,9 @@ def async_setup(hass, config):
tasks.append(getattr(entity, method)()) tasks.append(getattr(entity, method)())
if tasks: if tasks:
yield from asyncio.wait(tasks, loop=hass.loop) await asyncio.wait(tasks, loop=hass.loop)
@asyncio.coroutine async def toggle_service_handler(service_call):
def toggle_service_handler(service_call):
"""Handle automation toggle service calls.""" """Handle automation toggle service calls."""
tasks = [] tasks = []
for entity in component.async_extract_from_service(service_call): for entity in component.async_extract_from_service(service_call):
@ -199,15 +197,14 @@ def async_setup(hass, config):
tasks.append(entity.async_turn_on()) tasks.append(entity.async_turn_on())
if tasks: if tasks:
yield from asyncio.wait(tasks, loop=hass.loop) await asyncio.wait(tasks, loop=hass.loop)
@asyncio.coroutine async def reload_service_handler(service_call):
def reload_service_handler(service_call):
"""Remove all automations and load new ones from config.""" """Remove all automations and load new ones from config."""
conf = yield from component.async_prepare_reload() conf = await component.async_prepare_reload()
if conf is None: if conf is None:
return return
yield from _async_process_config(hass, conf, component) await _async_process_config(hass, conf, component)
hass.services.async_register( hass.services.async_register(
DOMAIN, SERVICE_TRIGGER, trigger_service_handler, DOMAIN, SERVICE_TRIGGER, trigger_service_handler,
@ -272,15 +269,14 @@ class AutomationEntity(ToggleEntity):
"""Return True if entity is on.""" """Return True if entity is on."""
return self._async_detach_triggers is not None return self._async_detach_triggers is not None
@asyncio.coroutine async def async_added_to_hass(self) -> None:
def async_added_to_hass(self) -> None:
"""Startup with initial state or previous state.""" """Startup with initial state or previous state."""
if self._initial_state is not None: if self._initial_state is not None:
enable_automation = self._initial_state enable_automation = self._initial_state
_LOGGER.debug("Automation %s initial state %s from config " _LOGGER.debug("Automation %s initial state %s from config "
"initial_state", self.entity_id, enable_automation) "initial_state", self.entity_id, enable_automation)
else: else:
state = yield from async_get_last_state(self.hass, self.entity_id) state = await async_get_last_state(self.hass, self.entity_id)
if state: if state:
enable_automation = state.state == STATE_ON enable_automation = state.state == STATE_ON
self._last_triggered = state.attributes.get('last_triggered') self._last_triggered = state.attributes.get('last_triggered')
@ -298,54 +294,50 @@ class AutomationEntity(ToggleEntity):
# HomeAssistant is starting up # HomeAssistant is starting up
if self.hass.state == CoreState.not_running: if self.hass.state == CoreState.not_running:
@asyncio.coroutine async def async_enable_automation(event):
def async_enable_automation(event):
"""Start automation on startup.""" """Start automation on startup."""
yield from self.async_enable() await self.async_enable()
self.hass.bus.async_listen_once( self.hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_START, async_enable_automation) EVENT_HOMEASSISTANT_START, async_enable_automation)
# HomeAssistant is running # HomeAssistant is running
else: else:
yield from self.async_enable() await self.async_enable()
@asyncio.coroutine async def async_turn_on(self, **kwargs) -> None:
def async_turn_on(self, **kwargs) -> None:
"""Turn the entity on and update the state.""" """Turn the entity on and update the state."""
if self.is_on: if self.is_on:
return return
yield from self.async_enable() await self.async_enable()
@asyncio.coroutine async def async_turn_off(self, **kwargs) -> None:
def async_turn_off(self, **kwargs) -> None:
"""Turn the entity off.""" """Turn the entity off."""
if not self.is_on: if not self.is_on:
return return
self._async_detach_triggers() self._async_detach_triggers()
self._async_detach_triggers = None self._async_detach_triggers = None
yield from self.async_update_ha_state() await self.async_update_ha_state()
@asyncio.coroutine async def async_trigger(self, variables, skip_condition=False,
def async_trigger(self, variables, skip_condition=False): context=None):
"""Trigger automation. """Trigger automation.
This method is a coroutine. This method is a coroutine.
""" """
if skip_condition or self._cond_func(variables): if skip_condition or self._cond_func(variables):
yield from self._async_action(self.entity_id, variables) self.async_set_context(context)
await self._async_action(self.entity_id, variables, context)
self._last_triggered = utcnow() self._last_triggered = utcnow()
yield from self.async_update_ha_state() await self.async_update_ha_state()
@asyncio.coroutine async def async_will_remove_from_hass(self):
def async_will_remove_from_hass(self):
"""Remove listeners when removing automation from HASS.""" """Remove listeners when removing automation from HASS."""
yield from self.async_turn_off() await self.async_turn_off()
@asyncio.coroutine async def async_enable(self):
def async_enable(self):
"""Enable this automation entity. """Enable this automation entity.
This method is a coroutine. This method is a coroutine.
@ -353,9 +345,9 @@ class AutomationEntity(ToggleEntity):
if self.is_on: if self.is_on:
return return
self._async_detach_triggers = yield from self._async_attach_triggers( self._async_detach_triggers = await self._async_attach_triggers(
self.async_trigger) self.async_trigger)
yield from self.async_update_ha_state() await self.async_update_ha_state()
@property @property
def device_state_attributes(self): def device_state_attributes(self):
@ -368,8 +360,7 @@ class AutomationEntity(ToggleEntity):
} }
@asyncio.coroutine async def _async_process_config(hass, config, component):
def _async_process_config(hass, config, component):
"""Process config and add automations. """Process config and add automations.
This method is a coroutine. This method is a coroutine.
@ -411,20 +402,19 @@ def _async_process_config(hass, config, component):
entities.append(entity) entities.append(entity)
if entities: if entities:
yield from component.async_add_entities(entities) await component.async_add_entities(entities)
def _async_get_action(hass, config, name): def _async_get_action(hass, config, name):
"""Return an action based on a configuration.""" """Return an action based on a configuration."""
script_obj = script.Script(hass, config, name) script_obj = script.Script(hass, config, name)
@asyncio.coroutine async def action(entity_id, variables, context):
def action(entity_id, variables):
"""Execute an action.""" """Execute an action."""
_LOGGER.info('Executing %s', name) _LOGGER.info('Executing %s', name)
logbook.async_log_entry( logbook.async_log_entry(
hass, name, 'has been triggered', DOMAIN, entity_id) hass, name, 'has been triggered', DOMAIN, entity_id)
yield from script_obj.async_run(variables) await script_obj.async_run(variables, context)
return action return action
@ -448,8 +438,7 @@ def _async_process_if(hass, config, p_config):
return if_action return if_action
@asyncio.coroutine async def _async_process_trigger(hass, config, trigger_configs, name, action):
def _async_process_trigger(hass, config, trigger_configs, name, action):
"""Set up the triggers. """Set up the triggers.
This method is a coroutine. This method is a coroutine.
@ -457,13 +446,13 @@ def _async_process_trigger(hass, config, trigger_configs, name, action):
removes = [] removes = []
for conf in trigger_configs: for conf in trigger_configs:
platform = yield from async_prepare_setup_platform( platform = await async_prepare_setup_platform(
hass, config, DOMAIN, conf.get(CONF_PLATFORM)) hass, config, DOMAIN, conf.get(CONF_PLATFORM))
if platform is None: if platform is None:
return None return None
remove = yield from platform.async_trigger(hass, conf, action) remove = await platform.async_trigger(hass, conf, action)
if not remove: if not remove:
_LOGGER.error("Error setting up trigger %s", name) _LOGGER.error("Error setting up trigger %s", name)

View file

@ -45,11 +45,11 @@ def async_trigger(hass, config, action):
# If event data doesn't match requested schema, skip event # If event data doesn't match requested schema, skip event
return return
hass.async_run_job(action, { hass.async_run_job(action({
'trigger': { 'trigger': {
'platform': 'event', 'platform': 'event',
'event': event, 'event': event,
}, },
}) }, context=event.context))
return hass.bus.async_listen(event_type, handle_event) return hass.bus.async_listen(event_type, handle_event)

View file

@ -32,12 +32,12 @@ def async_trigger(hass, config, action):
@callback @callback
def hass_shutdown(event): def hass_shutdown(event):
"""Execute when Home Assistant is shutting down.""" """Execute when Home Assistant is shutting down."""
hass.async_run_job(action, { hass.async_run_job(action({
'trigger': { 'trigger': {
'platform': 'homeassistant', 'platform': 'homeassistant',
'event': event, 'event': event,
}, },
}) }, context=event.context))
return hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, return hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP,
hass_shutdown) hass_shutdown)
@ -45,11 +45,11 @@ def async_trigger(hass, config, action):
# Automation are enabled while hass is starting up, fire right away # Automation are enabled while hass is starting up, fire right away
# Check state because a config reload shouldn't trigger it. # Check state because a config reload shouldn't trigger it.
if hass.state == CoreState.starting: if hass.state == CoreState.starting:
hass.async_run_job(action, { hass.async_run_job(action({
'trigger': { 'trigger': {
'platform': 'homeassistant', 'platform': 'homeassistant',
'event': event, 'event': event,
}, },
}) }))
return lambda: None return lambda: None

View file

@ -66,7 +66,7 @@ def async_trigger(hass, config, action):
@callback @callback
def call_action(): def call_action():
"""Call action with right context.""" """Call action with right context."""
hass.async_run_job(action, { hass.async_run_job(action({
'trigger': { 'trigger': {
'platform': 'numeric_state', 'platform': 'numeric_state',
'entity_id': entity, 'entity_id': entity,
@ -75,7 +75,7 @@ def async_trigger(hass, config, action):
'from_state': from_s, 'from_state': from_s,
'to_state': to_s, 'to_state': to_s,
} }
}) }, context=to_s.context))
matching = check_numeric_state(entity, from_s, to_s) matching = check_numeric_state(entity, from_s, to_s)

View file

@ -43,7 +43,7 @@ def async_trigger(hass, config, action):
@callback @callback
def call_action(): def call_action():
"""Call action with right context.""" """Call action with right context."""
hass.async_run_job(action, { hass.async_run_job(action({
'trigger': { 'trigger': {
'platform': 'state', 'platform': 'state',
'entity_id': entity, 'entity_id': entity,
@ -51,7 +51,7 @@ def async_trigger(hass, config, action):
'to_state': to_s, 'to_state': to_s,
'for': time_delta, 'for': time_delta,
} }
}) }, context=to_s.context))
# Ignore changes to state attributes if from/to is in use # Ignore changes to state attributes if from/to is in use
if (not match_all and from_s is not None and to_s is not None and if (not match_all and from_s is not None and to_s is not None and

View file

@ -32,13 +32,13 @@ def async_trigger(hass, config, action):
@callback @callback
def template_listener(entity_id, from_s, to_s): def template_listener(entity_id, from_s, to_s):
"""Listen for state changes and calls action.""" """Listen for state changes and calls action."""
hass.async_run_job(action, { hass.async_run_job(action({
'trigger': { 'trigger': {
'platform': 'template', 'platform': 'template',
'entity_id': entity_id, 'entity_id': entity_id,
'from_state': from_s, 'from_state': from_s,
'to_state': to_s, 'to_state': to_s,
}, },
}) }, context=to_s.context))
return async_track_template(hass, value_template, template_listener) return async_track_template(hass, value_template, template_listener)

View file

@ -51,7 +51,7 @@ def async_trigger(hass, config, action):
# pylint: disable=too-many-boolean-expressions # pylint: disable=too-many-boolean-expressions
if event == EVENT_ENTER and not from_match and to_match or \ if event == EVENT_ENTER and not from_match and to_match or \
event == EVENT_LEAVE and from_match and not to_match: event == EVENT_LEAVE and from_match and not to_match:
hass.async_run_job(action, { hass.async_run_job(action({
'trigger': { 'trigger': {
'platform': 'zone', 'platform': 'zone',
'entity_id': entity, 'entity_id': entity,
@ -60,7 +60,7 @@ def async_trigger(hass, config, action):
'zone': zone_state, 'zone': zone_state,
'event': event, 'event': event,
}, },
}) }, context=to_s.context))
return async_track_state_change(hass, entity_id, zone_automation_listener, return async_track_state_change(hass, entity_id, zone_automation_listener,
MATCH_ALL, MATCH_ALL) MATCH_ALL, MATCH_ALL)

View file

@ -63,11 +63,11 @@ def is_on(hass, entity_id):
@bind_hass @bind_hass
def turn_on(hass, entity_id, variables=None): def turn_on(hass, entity_id, variables=None, context=None):
"""Turn script on.""" """Turn script on."""
_, object_id = split_entity_id(entity_id) _, object_id = split_entity_id(entity_id)
hass.services.call(DOMAIN, object_id, variables) hass.services.call(DOMAIN, object_id, variables, context=context)
@bind_hass @bind_hass
@ -97,45 +97,41 @@ def async_reload(hass):
return hass.services.async_call(DOMAIN, SERVICE_RELOAD) return hass.services.async_call(DOMAIN, SERVICE_RELOAD)
@asyncio.coroutine async def async_setup(hass, config):
def async_setup(hass, config):
"""Load the scripts from the configuration.""" """Load the scripts from the configuration."""
component = EntityComponent( component = EntityComponent(
_LOGGER, DOMAIN, hass, group_name=GROUP_NAME_ALL_SCRIPTS) _LOGGER, DOMAIN, hass, group_name=GROUP_NAME_ALL_SCRIPTS)
yield from _async_process_config(hass, config, component) await _async_process_config(hass, config, component)
@asyncio.coroutine async def reload_service(service):
def reload_service(service):
"""Call a service to reload scripts.""" """Call a service to reload scripts."""
conf = yield from component.async_prepare_reload() conf = await component.async_prepare_reload()
if conf is None: if conf is None:
return return
yield from _async_process_config(hass, conf, component) await _async_process_config(hass, conf, component)
@asyncio.coroutine async def turn_on_service(service):
def turn_on_service(service):
"""Call a service to turn script on.""" """Call a service to turn script on."""
# We could turn on script directly here, but we only want to offer # We could turn on script directly here, but we only want to offer
# one way to do it. Otherwise no easy way to detect invocations. # one way to do it. Otherwise no easy way to detect invocations.
var = service.data.get(ATTR_VARIABLES) var = service.data.get(ATTR_VARIABLES)
for script in component.async_extract_from_service(service): for script in component.async_extract_from_service(service):
yield from hass.services.async_call(DOMAIN, script.object_id, var) await hass.services.async_call(DOMAIN, script.object_id, var,
context=service.context)
@asyncio.coroutine async def turn_off_service(service):
def turn_off_service(service):
"""Cancel a script.""" """Cancel a script."""
# Stopping a script is ok to be done in parallel # Stopping a script is ok to be done in parallel
yield from asyncio.wait( await asyncio.wait(
[script.async_turn_off() for script [script.async_turn_off() for script
in component.async_extract_from_service(service)], loop=hass.loop) in component.async_extract_from_service(service)], loop=hass.loop)
@asyncio.coroutine async def toggle_service(service):
def toggle_service(service):
"""Toggle a script.""" """Toggle a script."""
for script in component.async_extract_from_service(service): for script in component.async_extract_from_service(service):
yield from script.async_toggle() await script.async_toggle(context=service.context)
hass.services.async_register(DOMAIN, SERVICE_RELOAD, reload_service, hass.services.async_register(DOMAIN, SERVICE_RELOAD, reload_service,
schema=RELOAD_SERVICE_SCHEMA) schema=RELOAD_SERVICE_SCHEMA)
@ -149,18 +145,17 @@ def async_setup(hass, config):
return True return True
@asyncio.coroutine async def _async_process_config(hass, config, component):
def _async_process_config(hass, config, component): """Process script configuration."""
"""Process group configuration.""" async def service_handler(service):
@asyncio.coroutine
def service_handler(service):
"""Execute a service call to script.<script name>.""" """Execute a service call to script.<script name>."""
entity_id = ENTITY_ID_FORMAT.format(service.service) entity_id = ENTITY_ID_FORMAT.format(service.service)
script = component.get_entity(entity_id) script = component.get_entity(entity_id)
if script.is_on: if script.is_on:
_LOGGER.warning("Script %s already running.", entity_id) _LOGGER.warning("Script %s already running.", entity_id)
return return
yield from script.async_turn_on(variables=service.data) await script.async_turn_on(variables=service.data,
context=service.context)
scripts = [] scripts = []
@ -171,7 +166,7 @@ def _async_process_config(hass, config, component):
hass.services.async_register( hass.services.async_register(
DOMAIN, object_id, service_handler, schema=SCRIPT_SERVICE_SCHEMA) DOMAIN, object_id, service_handler, schema=SCRIPT_SERVICE_SCHEMA)
yield from component.async_add_entities(scripts) await component.async_add_entities(scripts)
class ScriptEntity(ToggleEntity): class ScriptEntity(ToggleEntity):
@ -209,18 +204,16 @@ class ScriptEntity(ToggleEntity):
"""Return true if script is on.""" """Return true if script is on."""
return self.script.is_running return self.script.is_running
@asyncio.coroutine async def async_turn_on(self, **kwargs):
def async_turn_on(self, **kwargs):
"""Turn the script on.""" """Turn the script on."""
yield from self.script.async_run(kwargs.get(ATTR_VARIABLES)) await self.script.async_run(
kwargs.get(ATTR_VARIABLES), kwargs.get('context'))
@asyncio.coroutine async def async_turn_off(self, **kwargs):
def async_turn_off(self, **kwargs):
"""Turn script off.""" """Turn script off."""
self.script.async_stop() self.script.async_stop()
@asyncio.coroutine async def async_will_remove_from_hass(self):
def async_will_remove_from_hass(self):
"""Stop script and remove service when it will be removed from HASS.""" """Stop script and remove service when it will be removed from HASS."""
if self.script.is_running: if self.script.is_running:
self.script.async_stop() self.script.async_stop()

View file

@ -6,7 +6,7 @@ from typing import Optional, Sequence
import voluptuous as vol import voluptuous as vol
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, Context, callback
from homeassistant.const import CONF_CONDITION, CONF_TIMEOUT from homeassistant.const import CONF_CONDITION, CONF_TIMEOUT
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError
from homeassistant.helpers import ( from homeassistant.helpers import (
@ -34,9 +34,10 @@ CONF_CONTINUE = 'continue_on_timeout'
def call_from_config(hass: HomeAssistant, config: ConfigType, def call_from_config(hass: HomeAssistant, config: ConfigType,
variables: Optional[Sequence] = None) -> None: variables: Optional[Sequence] = None,
context: Optional[Context] = None) -> None:
"""Call a script based on a config entry.""" """Call a script based on a config entry."""
Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables) Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables, context)
class Script(): class Script():
@ -64,12 +65,13 @@ class Script():
"""Return true if script is on.""" """Return true if script is on."""
return self._cur != -1 return self._cur != -1
def run(self, variables=None): def run(self, variables=None, context=None):
"""Run script.""" """Run script."""
run_coroutine_threadsafe( run_coroutine_threadsafe(
self.async_run(variables), self.hass.loop).result() self.async_run(variables, context), self.hass.loop).result()
async def async_run(self, variables: Optional[Sequence] = None) -> None: async def async_run(self, variables: Optional[Sequence] = None,
context: Optional[Context] = None) -> None:
"""Run script. """Run script.
This method is a coroutine. This method is a coroutine.
@ -94,7 +96,8 @@ class Script():
"""Handle delay.""" """Handle delay."""
# pylint: disable=cell-var-from-loop # pylint: disable=cell-var-from-loop
self._async_listener.remove(unsub) self._async_listener.remove(unsub)
self.hass.async_add_job(self.async_run(variables)) self.hass.async_create_task(
self.async_run(variables, context))
delay = action[CONF_DELAY] delay = action[CONF_DELAY]
@ -134,7 +137,8 @@ class Script():
def async_script_wait(entity_id, from_s, to_s): def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true.""" """Handle script after template condition is true."""
self._async_remove_listener() self._async_remove_listener()
self.hass.async_add_job(self.async_run(variables)) self.hass.async_create_task(
self.async_run(variables, context))
self._async_listener.append(async_track_template( self._async_listener.append(async_track_template(
self.hass, wait_template, async_script_wait, variables)) self.hass, wait_template, async_script_wait, variables))
@ -145,7 +149,8 @@ class Script():
if CONF_TIMEOUT in action: if CONF_TIMEOUT in action:
self._async_set_timeout( self._async_set_timeout(
action, variables, action.get(CONF_CONTINUE, True)) action, variables, context,
action.get(CONF_CONTINUE, True))
return return
@ -154,10 +159,10 @@ class Script():
break break
elif CONF_EVENT in action: elif CONF_EVENT in action:
self._async_fire_event(action, variables) self._async_fire_event(action, variables, context)
else: else:
await self._async_call_service(action, variables) await self._async_call_service(action, variables, context)
self._cur = -1 self._cur = -1
self.last_action = None self.last_action = None
@ -178,7 +183,7 @@ class Script():
if self._change_listener: if self._change_listener:
self.hass.async_add_job(self._change_listener) self.hass.async_add_job(self._change_listener)
async def _async_call_service(self, action, variables): async def _async_call_service(self, action, variables, context):
"""Call the service specified in the action. """Call the service specified in the action.
This method is a coroutine. This method is a coroutine.
@ -186,9 +191,14 @@ class Script():
self.last_action = action.get(CONF_ALIAS, 'call service') self.last_action = action.get(CONF_ALIAS, 'call service')
self._log("Executing step %s" % self.last_action) self._log("Executing step %s" % self.last_action)
await service.async_call_from_config( await service.async_call_from_config(
self.hass, action, True, variables, validate_config=False) self.hass, action,
blocking=True,
variables=variables,
validate_config=False,
context=context
)
def _async_fire_event(self, action, variables): def _async_fire_event(self, action, variables, context):
"""Fire an event.""" """Fire an event."""
self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT]) self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT])
self._log("Executing step %s" % self.last_action) self._log("Executing step %s" % self.last_action)
@ -201,7 +211,7 @@ class Script():
_LOGGER.error('Error rendering event data template: %s', ex) _LOGGER.error('Error rendering event data template: %s', ex)
self.hass.bus.async_fire(action[CONF_EVENT], self.hass.bus.async_fire(action[CONF_EVENT],
event_data) event_data, context=context)
def _async_check_condition(self, action, variables): def _async_check_condition(self, action, variables):
"""Test if condition is matching.""" """Test if condition is matching."""
@ -216,7 +226,8 @@ class Script():
self._log("Test condition {}: {}".format(self.last_action, check)) self._log("Test condition {}: {}".format(self.last_action, check))
return check return check
def _async_set_timeout(self, action, variables, continue_on_timeout=True): def _async_set_timeout(self, action, variables, context,
continue_on_timeout):
"""Schedule a timeout to abort or continue script.""" """Schedule a timeout to abort or continue script."""
timeout = action[CONF_TIMEOUT] timeout = action[CONF_TIMEOUT]
unsub = None unsub = None
@ -229,7 +240,8 @@ class Script():
# Check if we want to continue to execute # Check if we want to continue to execute
# the script after the timeout # the script after the timeout
if continue_on_timeout: if continue_on_timeout:
self.hass.async_add_job(self.async_run(variables)) self.hass.async_create_task(
self.async_run(variables, context))
else: else:
self._log("Timeout reached, abort script.") self._log("Timeout reached, abort script.")
self.async_stop() self.async_stop()

View file

@ -36,7 +36,7 @@ def call_from_config(hass, config, blocking=False, variables=None,
@bind_hass @bind_hass
async def async_call_from_config(hass, config, blocking=False, variables=None, async def async_call_from_config(hass, config, blocking=False, variables=None,
validate_config=True): validate_config=True, context=None):
"""Call a service based on a config hash.""" """Call a service based on a config hash."""
if validate_config: if validate_config:
try: try:
@ -77,7 +77,7 @@ async def async_call_from_config(hass, config, blocking=False, variables=None,
service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID] service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
await hass.services.async_call( await hass.services.async_call(
domain, service_name, service_data, blocking) domain, service_name, service_data, blocking=blocking, context=context)
@bind_hass @bind_hass

View file

@ -1,7 +1,7 @@
"""The tests for the Event automation.""" """The tests for the Event automation."""
import unittest import unittest
from homeassistant.core import callback from homeassistant.core import Context, callback
from homeassistant.setup import setup_component from homeassistant.setup import setup_component
import homeassistant.components.automation as automation import homeassistant.components.automation as automation
@ -31,6 +31,8 @@ class TestAutomationEvent(unittest.TestCase):
def test_if_fires_on_event(self): def test_if_fires_on_event(self):
"""Test the firing of events.""" """Test the firing of events."""
context = Context()
assert setup_component(self.hass, automation.DOMAIN, { assert setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: { automation.DOMAIN: {
'trigger': { 'trigger': {
@ -43,9 +45,10 @@ class TestAutomationEvent(unittest.TestCase):
} }
}) })
self.hass.bus.fire('test_event') self.hass.bus.fire('test_event', context=context)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(1, len(self.calls)) self.assertEqual(1, len(self.calls))
assert self.calls[0].context is context
automation.turn_off(self.hass) automation.turn_off(self.hass)
self.hass.block_till_done() self.hass.block_till_done()

View file

@ -4,7 +4,7 @@ import unittest
from unittest.mock import patch from unittest.mock import patch
import homeassistant.components.automation as automation import homeassistant.components.automation as automation
from homeassistant.core import callback from homeassistant.core import Context, callback
from homeassistant.setup import setup_component from homeassistant.setup import setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -36,6 +36,7 @@ class TestAutomationNumericState(unittest.TestCase):
def test_if_fires_on_entity_change_below(self): def test_if_fires_on_entity_change_below(self):
"""Test the firing with changed entity.""" """Test the firing with changed entity."""
context = Context()
assert setup_component(self.hass, automation.DOMAIN, { assert setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: { automation.DOMAIN: {
'trigger': { 'trigger': {
@ -49,9 +50,10 @@ class TestAutomationNumericState(unittest.TestCase):
} }
}) })
# 9 is below 10 # 9 is below 10
self.hass.states.set('test.entity', 9) self.hass.states.set('test.entity', 9, context=context)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(1, len(self.calls)) self.assertEqual(1, len(self.calls))
assert self.calls[0].context is context
# Set above 12 so the automation will fire again # Set above 12 so the automation will fire again
self.hass.states.set('test.entity', 12) self.hass.states.set('test.entity', 12)
@ -116,6 +118,7 @@ class TestAutomationNumericState(unittest.TestCase):
def test_if_not_fires_on_entity_change_below_to_below(self): def test_if_not_fires_on_entity_change_below_to_below(self):
"""Test the firing with changed entity.""" """Test the firing with changed entity."""
context = Context()
self.hass.states.set('test.entity', 11) self.hass.states.set('test.entity', 11)
self.hass.block_till_done() self.hass.block_till_done()
@ -133,9 +136,10 @@ class TestAutomationNumericState(unittest.TestCase):
}) })
# 9 is below 10 so this should fire # 9 is below 10 so this should fire
self.hass.states.set('test.entity', 9) self.hass.states.set('test.entity', 9, context=context)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(1, len(self.calls)) self.assertEqual(1, len(self.calls))
assert self.calls[0].context is context
# already below so should not fire again # already below so should not fire again
self.hass.states.set('test.entity', 5) self.hass.states.set('test.entity', 5)

View file

@ -4,7 +4,7 @@ from datetime import timedelta
import unittest import unittest
from unittest.mock import patch from unittest.mock import patch
from homeassistant.core import callback from homeassistant.core import Context, callback
from homeassistant.setup import setup_component from homeassistant.setup import setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
import homeassistant.components.automation as automation import homeassistant.components.automation as automation
@ -38,6 +38,7 @@ class TestAutomationState(unittest.TestCase):
def test_if_fires_on_entity_change(self): def test_if_fires_on_entity_change(self):
"""Test for firing on entity change.""" """Test for firing on entity change."""
context = Context()
self.hass.states.set('test.entity', 'hello') self.hass.states.set('test.entity', 'hello')
self.hass.block_till_done() self.hass.block_till_done()
@ -59,9 +60,10 @@ class TestAutomationState(unittest.TestCase):
} }
}) })
self.hass.states.set('test.entity', 'world') self.hass.states.set('test.entity', 'world', context=context)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(1, len(self.calls)) self.assertEqual(1, len(self.calls))
assert self.calls[0].context is context
self.assertEqual( self.assertEqual(
'state - test.entity - hello - world - None', 'state - test.entity - hello - world - None',
self.calls[0].data['some']) self.calls[0].data['some'])

View file

@ -1,7 +1,7 @@
"""The tests for the Template automation.""" """The tests for the Template automation."""
import unittest import unittest
from homeassistant.core import callback from homeassistant.core import Context, callback
from homeassistant.setup import setup_component from homeassistant.setup import setup_component
import homeassistant.components.automation as automation import homeassistant.components.automation as automation
@ -232,15 +232,12 @@ class TestAutomationTemplate(unittest.TestCase):
def test_if_fires_on_change_with_template_advanced(self): def test_if_fires_on_change_with_template_advanced(self):
"""Test for firing on change with template advanced.""" """Test for firing on change with template advanced."""
context = Context()
assert setup_component(self.hass, automation.DOMAIN, { assert setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: { automation.DOMAIN: {
'trigger': { 'trigger': {
'platform': 'template', 'platform': 'template',
'value_template': '''{%- if is_state("test.entity", "world") -%} 'value_template': '{{ is_state("test.entity", "world") }}'
true
{%- else -%}
false
{%- endif -%}''',
}, },
'action': { 'action': {
'service': 'test.automation', 'service': 'test.automation',
@ -257,9 +254,10 @@ class TestAutomationTemplate(unittest.TestCase):
self.hass.block_till_done() self.hass.block_till_done()
self.calls = [] self.calls = []
self.hass.states.set('test.entity', 'world') self.hass.states.set('test.entity', 'world', context=context)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(1, len(self.calls)) self.assertEqual(1, len(self.calls))
assert self.calls[0].context is context
self.assertEqual( self.assertEqual(
'template - test.entity - hello - world', 'template - test.entity - hello - world',
self.calls[0].data['some']) self.calls[0].data['some'])

View file

@ -1,7 +1,7 @@
"""The tests for the location automation.""" """The tests for the location automation."""
import unittest import unittest
from homeassistant.core import callback from homeassistant.core import Context, callback
from homeassistant.setup import setup_component from homeassistant.setup import setup_component
from homeassistant.components import automation, zone from homeassistant.components import automation, zone
@ -40,6 +40,7 @@ class TestAutomationZone(unittest.TestCase):
def test_if_fires_on_zone_enter(self): def test_if_fires_on_zone_enter(self):
"""Test for firing on zone enter.""" """Test for firing on zone enter."""
context = Context()
self.hass.states.set('test.entity', 'hello', { self.hass.states.set('test.entity', 'hello', {
'latitude': 32.881011, 'latitude': 32.881011,
'longitude': -117.234758 'longitude': -117.234758
@ -70,10 +71,11 @@ class TestAutomationZone(unittest.TestCase):
self.hass.states.set('test.entity', 'hello', { self.hass.states.set('test.entity', 'hello', {
'latitude': 32.880586, 'latitude': 32.880586,
'longitude': -117.237564 'longitude': -117.237564
}) }, context=context)
self.hass.block_till_done() self.hass.block_till_done()
self.assertEqual(1, len(self.calls)) self.assertEqual(1, len(self.calls))
assert self.calls[0].context is context
self.assertEqual( self.assertEqual(
'zone - test.entity - hello - hello - test', 'zone - test.entity - hello - hello - test',
self.calls[0].data['some']) self.calls[0].data['some'])

View file

@ -3,7 +3,7 @@
import unittest import unittest
from unittest.mock import patch from unittest.mock import patch
from homeassistant.core import callback from homeassistant.core import Context, callback
from homeassistant.setup import setup_component from homeassistant.setup import setup_component
from homeassistant.components import script from homeassistant.components import script
@ -134,6 +134,7 @@ class TestScriptComponent(unittest.TestCase):
def test_passing_variables(self): def test_passing_variables(self):
"""Test different ways of passing in variables.""" """Test different ways of passing in variables."""
calls = [] calls = []
context = Context()
@callback @callback
def record_call(service): def record_call(service):
@ -157,21 +158,23 @@ class TestScriptComponent(unittest.TestCase):
script.turn_on(self.hass, ENTITY_ID, { script.turn_on(self.hass, ENTITY_ID, {
'greeting': 'world' 'greeting': 'world'
}) }, context=context)
self.hass.block_till_done() self.hass.block_till_done()
assert len(calls) == 1 assert len(calls) == 1
assert calls[-1].data['hello'] == 'world' assert calls[0].context is context
assert calls[0].data['hello'] == 'world'
self.hass.services.call('script', 'test', { self.hass.services.call('script', 'test', {
'greeting': 'universe', 'greeting': 'universe',
}) }, context=context)
self.hass.block_till_done() self.hass.block_till_done()
assert len(calls) == 2 assert len(calls) == 2
assert calls[-1].data['hello'] == 'universe' assert calls[1].context is context
assert calls[1].data['hello'] == 'universe'
def test_reload_service(self): def test_reload_service(self):
"""Verify that the turn_on service.""" """Verify that the turn_on service."""

View file

@ -4,7 +4,7 @@ from datetime import timedelta
from unittest import mock from unittest import mock
import unittest import unittest
from homeassistant.core import callback from homeassistant.core import Context, callback
# Otherwise can't test just this file (import order issue) # Otherwise can't test just this file (import order issue)
import homeassistant.components # noqa import homeassistant.components # noqa
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -32,6 +32,7 @@ class TestScriptHelper(unittest.TestCase):
def test_firing_event(self): def test_firing_event(self):
"""Test the firing of events.""" """Test the firing of events."""
event = 'test_event' event = 'test_event'
context = Context()
calls = [] calls = []
@callback @callback
@ -48,17 +49,19 @@ class TestScriptHelper(unittest.TestCase):
} }
})) }))
script_obj.run() script_obj.run(context=context)
self.hass.block_till_done() self.hass.block_till_done()
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].context is context
assert calls[0].data.get('hello') == 'world' assert calls[0].data.get('hello') == 'world'
assert not script_obj.can_cancel assert not script_obj.can_cancel
def test_firing_event_template(self): def test_firing_event_template(self):
"""Test the firing of events.""" """Test the firing of events."""
event = 'test_event' event = 'test_event'
context = Context()
calls = [] calls = []
@callback @callback
@ -82,11 +85,12 @@ class TestScriptHelper(unittest.TestCase):
} }
})) }))
script_obj.run({'is_world': 'yes'}) script_obj.run({'is_world': 'yes'}, context=context)
self.hass.block_till_done() self.hass.block_till_done()
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].context is context
assert calls[0].data == { assert calls[0].data == {
'dict': { 'dict': {
1: 'yes', 1: 'yes',
@ -100,6 +104,7 @@ class TestScriptHelper(unittest.TestCase):
def test_calling_service(self): def test_calling_service(self):
"""Test the calling of a service.""" """Test the calling of a service."""
calls = [] calls = []
context = Context()
@callback @callback
def record_call(service): def record_call(service):
@ -113,16 +118,18 @@ class TestScriptHelper(unittest.TestCase):
'data': { 'data': {
'hello': 'world' 'hello': 'world'
} }
}) }, context=context)
self.hass.block_till_done() self.hass.block_till_done()
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].context is context
assert calls[0].data.get('hello') == 'world' assert calls[0].data.get('hello') == 'world'
def test_calling_service_template(self): def test_calling_service_template(self):
"""Test the calling of a service.""" """Test the calling of a service."""
calls = [] calls = []
context = Context()
@callback @callback
def record_call(service): def record_call(service):
@ -147,17 +154,19 @@ class TestScriptHelper(unittest.TestCase):
{% endif %} {% endif %}
""" """
} }
}, {'is_world': 'yes'}) }, {'is_world': 'yes'}, context=context)
self.hass.block_till_done() self.hass.block_till_done()
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].context is context
assert calls[0].data.get('hello') == 'world' assert calls[0].data.get('hello') == 'world'
def test_delay(self): def test_delay(self):
"""Test the delay.""" """Test the delay."""
event = 'test_event' event = 'test_event'
events = [] events = []
context = Context()
@callback @callback
def record_event(event): def record_event(event):
@ -171,7 +180,7 @@ class TestScriptHelper(unittest.TestCase):
{'delay': {'seconds': 5}}, {'delay': {'seconds': 5}},
{'event': event}])) {'event': event}]))
script_obj.run() script_obj.run(context=context)
self.hass.block_till_done() self.hass.block_till_done()
assert script_obj.is_running assert script_obj.is_running
@ -185,6 +194,8 @@ class TestScriptHelper(unittest.TestCase):
assert not script_obj.is_running assert not script_obj.is_running
assert len(events) == 2 assert len(events) == 2
assert events[0].context is context
assert events[1].context is context
def test_delay_template(self): def test_delay_template(self):
"""Test the delay as a template.""" """Test the delay as a template."""
@ -282,6 +293,7 @@ class TestScriptHelper(unittest.TestCase):
"""Test the wait template.""" """Test the wait template."""
event = 'test_event' event = 'test_event'
events = [] events = []
context = Context()
@callback @callback
def record_event(event): def record_event(event):
@ -297,7 +309,7 @@ class TestScriptHelper(unittest.TestCase):
{'wait_template': "{{states.switch.test.state == 'off'}}"}, {'wait_template': "{{states.switch.test.state == 'off'}}"},
{'event': event}])) {'event': event}]))
script_obj.run() script_obj.run(context=context)
self.hass.block_till_done() self.hass.block_till_done()
assert script_obj.is_running assert script_obj.is_running
@ -310,6 +322,8 @@ class TestScriptHelper(unittest.TestCase):
assert not script_obj.is_running assert not script_obj.is_running
assert len(events) == 2 assert len(events) == 2
assert events[0].context is context
assert events[1].context is context
def test_wait_template_cancel(self): def test_wait_template_cancel(self):
"""Test the wait template cancel action.""" """Test the wait template cancel action."""