Add context to scripts and automations (#16415)
* Add context to script helper * Update script component * Add context to automations * Lint
This commit is contained in:
parent
e1501c83f8
commit
746f4ac158
17 changed files with 164 additions and 144 deletions
|
@ -158,27 +158,26 @@ def async_reload(hass):
|
|||
return hass.services.async_call(DOMAIN, SERVICE_RELOAD)
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_setup(hass, config):
|
||||
async def async_setup(hass, config):
|
||||
"""Set up the automation."""
|
||||
component = EntityComponent(_LOGGER, DOMAIN, hass,
|
||||
group_name=GROUP_NAME_ALL_AUTOMATIONS)
|
||||
|
||||
yield from _async_process_config(hass, config, component)
|
||||
await _async_process_config(hass, config, component)
|
||||
|
||||
@asyncio.coroutine
|
||||
def trigger_service_handler(service_call):
|
||||
async def trigger_service_handler(service_call):
|
||||
"""Handle automation triggers."""
|
||||
tasks = []
|
||||
for entity in component.async_extract_from_service(service_call):
|
||||
tasks.append(entity.async_trigger(
|
||||
service_call.data.get(ATTR_VARIABLES), True))
|
||||
service_call.data.get(ATTR_VARIABLES),
|
||||
skip_condition=True,
|
||||
context=service_call.context))
|
||||
|
||||
if tasks:
|
||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
||||
await asyncio.wait(tasks, loop=hass.loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def turn_onoff_service_handler(service_call):
|
||||
async def turn_onoff_service_handler(service_call):
|
||||
"""Handle automation turn on/off service calls."""
|
||||
tasks = []
|
||||
method = 'async_{}'.format(service_call.service)
|
||||
|
@ -186,10 +185,9 @@ def async_setup(hass, config):
|
|||
tasks.append(getattr(entity, method)())
|
||||
|
||||
if tasks:
|
||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
||||
await asyncio.wait(tasks, loop=hass.loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def toggle_service_handler(service_call):
|
||||
async def toggle_service_handler(service_call):
|
||||
"""Handle automation toggle service calls."""
|
||||
tasks = []
|
||||
for entity in component.async_extract_from_service(service_call):
|
||||
|
@ -199,15 +197,14 @@ def async_setup(hass, config):
|
|||
tasks.append(entity.async_turn_on())
|
||||
|
||||
if tasks:
|
||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
||||
await asyncio.wait(tasks, loop=hass.loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def reload_service_handler(service_call):
|
||||
async def reload_service_handler(service_call):
|
||||
"""Remove all automations and load new ones from config."""
|
||||
conf = yield from component.async_prepare_reload()
|
||||
conf = await component.async_prepare_reload()
|
||||
if conf is None:
|
||||
return
|
||||
yield from _async_process_config(hass, conf, component)
|
||||
await _async_process_config(hass, conf, component)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_TRIGGER, trigger_service_handler,
|
||||
|
@ -272,15 +269,14 @@ class AutomationEntity(ToggleEntity):
|
|||
"""Return True if entity is on."""
|
||||
return self._async_detach_triggers is not None
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_added_to_hass(self) -> None:
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Startup with initial state or previous state."""
|
||||
if self._initial_state is not None:
|
||||
enable_automation = self._initial_state
|
||||
_LOGGER.debug("Automation %s initial state %s from config "
|
||||
"initial_state", self.entity_id, enable_automation)
|
||||
else:
|
||||
state = yield from async_get_last_state(self.hass, self.entity_id)
|
||||
state = await async_get_last_state(self.hass, self.entity_id)
|
||||
if state:
|
||||
enable_automation = state.state == STATE_ON
|
||||
self._last_triggered = state.attributes.get('last_triggered')
|
||||
|
@ -298,54 +294,50 @@ class AutomationEntity(ToggleEntity):
|
|||
|
||||
# HomeAssistant is starting up
|
||||
if self.hass.state == CoreState.not_running:
|
||||
@asyncio.coroutine
|
||||
def async_enable_automation(event):
|
||||
async def async_enable_automation(event):
|
||||
"""Start automation on startup."""
|
||||
yield from self.async_enable()
|
||||
await self.async_enable()
|
||||
|
||||
self.hass.bus.async_listen_once(
|
||||
EVENT_HOMEASSISTANT_START, async_enable_automation)
|
||||
|
||||
# HomeAssistant is running
|
||||
else:
|
||||
yield from self.async_enable()
|
||||
await self.async_enable()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_turn_on(self, **kwargs) -> None:
|
||||
async def async_turn_on(self, **kwargs) -> None:
|
||||
"""Turn the entity on and update the state."""
|
||||
if self.is_on:
|
||||
return
|
||||
|
||||
yield from self.async_enable()
|
||||
await self.async_enable()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_turn_off(self, **kwargs) -> None:
|
||||
async def async_turn_off(self, **kwargs) -> None:
|
||||
"""Turn the entity off."""
|
||||
if not self.is_on:
|
||||
return
|
||||
|
||||
self._async_detach_triggers()
|
||||
self._async_detach_triggers = None
|
||||
yield from self.async_update_ha_state()
|
||||
await self.async_update_ha_state()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_trigger(self, variables, skip_condition=False):
|
||||
async def async_trigger(self, variables, skip_condition=False,
|
||||
context=None):
|
||||
"""Trigger automation.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
if skip_condition or self._cond_func(variables):
|
||||
yield from self._async_action(self.entity_id, variables)
|
||||
self.async_set_context(context)
|
||||
await self._async_action(self.entity_id, variables, context)
|
||||
self._last_triggered = utcnow()
|
||||
yield from self.async_update_ha_state()
|
||||
await self.async_update_ha_state()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_will_remove_from_hass(self):
|
||||
async def async_will_remove_from_hass(self):
|
||||
"""Remove listeners when removing automation from HASS."""
|
||||
yield from self.async_turn_off()
|
||||
await self.async_turn_off()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_enable(self):
|
||||
async def async_enable(self):
|
||||
"""Enable this automation entity.
|
||||
|
||||
This method is a coroutine.
|
||||
|
@ -353,9 +345,9 @@ class AutomationEntity(ToggleEntity):
|
|||
if self.is_on:
|
||||
return
|
||||
|
||||
self._async_detach_triggers = yield from self._async_attach_triggers(
|
||||
self._async_detach_triggers = await self._async_attach_triggers(
|
||||
self.async_trigger)
|
||||
yield from self.async_update_ha_state()
|
||||
await self.async_update_ha_state()
|
||||
|
||||
@property
|
||||
def device_state_attributes(self):
|
||||
|
@ -368,8 +360,7 @@ class AutomationEntity(ToggleEntity):
|
|||
}
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def _async_process_config(hass, config, component):
|
||||
async def _async_process_config(hass, config, component):
|
||||
"""Process config and add automations.
|
||||
|
||||
This method is a coroutine.
|
||||
|
@ -411,20 +402,19 @@ def _async_process_config(hass, config, component):
|
|||
entities.append(entity)
|
||||
|
||||
if entities:
|
||||
yield from component.async_add_entities(entities)
|
||||
await component.async_add_entities(entities)
|
||||
|
||||
|
||||
def _async_get_action(hass, config, name):
|
||||
"""Return an action based on a configuration."""
|
||||
script_obj = script.Script(hass, config, name)
|
||||
|
||||
@asyncio.coroutine
|
||||
def action(entity_id, variables):
|
||||
async def action(entity_id, variables, context):
|
||||
"""Execute an action."""
|
||||
_LOGGER.info('Executing %s', name)
|
||||
logbook.async_log_entry(
|
||||
hass, name, 'has been triggered', DOMAIN, entity_id)
|
||||
yield from script_obj.async_run(variables)
|
||||
await script_obj.async_run(variables, context)
|
||||
|
||||
return action
|
||||
|
||||
|
@ -448,8 +438,7 @@ def _async_process_if(hass, config, p_config):
|
|||
return if_action
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def _async_process_trigger(hass, config, trigger_configs, name, action):
|
||||
async def _async_process_trigger(hass, config, trigger_configs, name, action):
|
||||
"""Set up the triggers.
|
||||
|
||||
This method is a coroutine.
|
||||
|
@ -457,13 +446,13 @@ def _async_process_trigger(hass, config, trigger_configs, name, action):
|
|||
removes = []
|
||||
|
||||
for conf in trigger_configs:
|
||||
platform = yield from async_prepare_setup_platform(
|
||||
platform = await async_prepare_setup_platform(
|
||||
hass, config, DOMAIN, conf.get(CONF_PLATFORM))
|
||||
|
||||
if platform is None:
|
||||
return None
|
||||
|
||||
remove = yield from platform.async_trigger(hass, conf, action)
|
||||
remove = await platform.async_trigger(hass, conf, action)
|
||||
|
||||
if not remove:
|
||||
_LOGGER.error("Error setting up trigger %s", name)
|
||||
|
|
|
@ -45,11 +45,11 @@ def async_trigger(hass, config, action):
|
|||
# If event data doesn't match requested schema, skip event
|
||||
return
|
||||
|
||||
hass.async_run_job(action, {
|
||||
hass.async_run_job(action({
|
||||
'trigger': {
|
||||
'platform': 'event',
|
||||
'event': event,
|
||||
},
|
||||
})
|
||||
}, context=event.context))
|
||||
|
||||
return hass.bus.async_listen(event_type, handle_event)
|
||||
|
|
|
@ -32,12 +32,12 @@ def async_trigger(hass, config, action):
|
|||
@callback
|
||||
def hass_shutdown(event):
|
||||
"""Execute when Home Assistant is shutting down."""
|
||||
hass.async_run_job(action, {
|
||||
hass.async_run_job(action({
|
||||
'trigger': {
|
||||
'platform': 'homeassistant',
|
||||
'event': event,
|
||||
},
|
||||
})
|
||||
}, context=event.context))
|
||||
|
||||
return hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP,
|
||||
hass_shutdown)
|
||||
|
@ -45,11 +45,11 @@ def async_trigger(hass, config, action):
|
|||
# Automation are enabled while hass is starting up, fire right away
|
||||
# Check state because a config reload shouldn't trigger it.
|
||||
if hass.state == CoreState.starting:
|
||||
hass.async_run_job(action, {
|
||||
hass.async_run_job(action({
|
||||
'trigger': {
|
||||
'platform': 'homeassistant',
|
||||
'event': event,
|
||||
},
|
||||
})
|
||||
}))
|
||||
|
||||
return lambda: None
|
||||
|
|
|
@ -66,7 +66,7 @@ def async_trigger(hass, config, action):
|
|||
@callback
|
||||
def call_action():
|
||||
"""Call action with right context."""
|
||||
hass.async_run_job(action, {
|
||||
hass.async_run_job(action({
|
||||
'trigger': {
|
||||
'platform': 'numeric_state',
|
||||
'entity_id': entity,
|
||||
|
@ -75,7 +75,7 @@ def async_trigger(hass, config, action):
|
|||
'from_state': from_s,
|
||||
'to_state': to_s,
|
||||
}
|
||||
})
|
||||
}, context=to_s.context))
|
||||
|
||||
matching = check_numeric_state(entity, from_s, to_s)
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ def async_trigger(hass, config, action):
|
|||
@callback
|
||||
def call_action():
|
||||
"""Call action with right context."""
|
||||
hass.async_run_job(action, {
|
||||
hass.async_run_job(action({
|
||||
'trigger': {
|
||||
'platform': 'state',
|
||||
'entity_id': entity,
|
||||
|
@ -51,7 +51,7 @@ def async_trigger(hass, config, action):
|
|||
'to_state': to_s,
|
||||
'for': time_delta,
|
||||
}
|
||||
})
|
||||
}, context=to_s.context))
|
||||
|
||||
# Ignore changes to state attributes if from/to is in use
|
||||
if (not match_all and from_s is not None and to_s is not None and
|
||||
|
|
|
@ -32,13 +32,13 @@ def async_trigger(hass, config, action):
|
|||
@callback
|
||||
def template_listener(entity_id, from_s, to_s):
|
||||
"""Listen for state changes and calls action."""
|
||||
hass.async_run_job(action, {
|
||||
hass.async_run_job(action({
|
||||
'trigger': {
|
||||
'platform': 'template',
|
||||
'entity_id': entity_id,
|
||||
'from_state': from_s,
|
||||
'to_state': to_s,
|
||||
},
|
||||
})
|
||||
}, context=to_s.context))
|
||||
|
||||
return async_track_template(hass, value_template, template_listener)
|
||||
|
|
|
@ -51,7 +51,7 @@ def async_trigger(hass, config, action):
|
|||
# pylint: disable=too-many-boolean-expressions
|
||||
if event == EVENT_ENTER and not from_match and to_match or \
|
||||
event == EVENT_LEAVE and from_match and not to_match:
|
||||
hass.async_run_job(action, {
|
||||
hass.async_run_job(action({
|
||||
'trigger': {
|
||||
'platform': 'zone',
|
||||
'entity_id': entity,
|
||||
|
@ -60,7 +60,7 @@ def async_trigger(hass, config, action):
|
|||
'zone': zone_state,
|
||||
'event': event,
|
||||
},
|
||||
})
|
||||
}, context=to_s.context))
|
||||
|
||||
return async_track_state_change(hass, entity_id, zone_automation_listener,
|
||||
MATCH_ALL, MATCH_ALL)
|
||||
|
|
|
@ -63,11 +63,11 @@ def is_on(hass, entity_id):
|
|||
|
||||
|
||||
@bind_hass
|
||||
def turn_on(hass, entity_id, variables=None):
|
||||
def turn_on(hass, entity_id, variables=None, context=None):
|
||||
"""Turn script on."""
|
||||
_, object_id = split_entity_id(entity_id)
|
||||
|
||||
hass.services.call(DOMAIN, object_id, variables)
|
||||
hass.services.call(DOMAIN, object_id, variables, context=context)
|
||||
|
||||
|
||||
@bind_hass
|
||||
|
@ -97,45 +97,41 @@ def async_reload(hass):
|
|||
return hass.services.async_call(DOMAIN, SERVICE_RELOAD)
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_setup(hass, config):
|
||||
async def async_setup(hass, config):
|
||||
"""Load the scripts from the configuration."""
|
||||
component = EntityComponent(
|
||||
_LOGGER, DOMAIN, hass, group_name=GROUP_NAME_ALL_SCRIPTS)
|
||||
|
||||
yield from _async_process_config(hass, config, component)
|
||||
await _async_process_config(hass, config, component)
|
||||
|
||||
@asyncio.coroutine
|
||||
def reload_service(service):
|
||||
async def reload_service(service):
|
||||
"""Call a service to reload scripts."""
|
||||
conf = yield from component.async_prepare_reload()
|
||||
conf = await component.async_prepare_reload()
|
||||
if conf is None:
|
||||
return
|
||||
|
||||
yield from _async_process_config(hass, conf, component)
|
||||
await _async_process_config(hass, conf, component)
|
||||
|
||||
@asyncio.coroutine
|
||||
def turn_on_service(service):
|
||||
async def turn_on_service(service):
|
||||
"""Call a service to turn script on."""
|
||||
# We could turn on script directly here, but we only want to offer
|
||||
# one way to do it. Otherwise no easy way to detect invocations.
|
||||
var = service.data.get(ATTR_VARIABLES)
|
||||
for script in component.async_extract_from_service(service):
|
||||
yield from hass.services.async_call(DOMAIN, script.object_id, var)
|
||||
await hass.services.async_call(DOMAIN, script.object_id, var,
|
||||
context=service.context)
|
||||
|
||||
@asyncio.coroutine
|
||||
def turn_off_service(service):
|
||||
async def turn_off_service(service):
|
||||
"""Cancel a script."""
|
||||
# Stopping a script is ok to be done in parallel
|
||||
yield from asyncio.wait(
|
||||
await asyncio.wait(
|
||||
[script.async_turn_off() for script
|
||||
in component.async_extract_from_service(service)], loop=hass.loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def toggle_service(service):
|
||||
async def toggle_service(service):
|
||||
"""Toggle a script."""
|
||||
for script in component.async_extract_from_service(service):
|
||||
yield from script.async_toggle()
|
||||
await script.async_toggle(context=service.context)
|
||||
|
||||
hass.services.async_register(DOMAIN, SERVICE_RELOAD, reload_service,
|
||||
schema=RELOAD_SERVICE_SCHEMA)
|
||||
|
@ -149,18 +145,17 @@ def async_setup(hass, config):
|
|||
return True
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def _async_process_config(hass, config, component):
|
||||
"""Process group configuration."""
|
||||
@asyncio.coroutine
|
||||
def service_handler(service):
|
||||
async def _async_process_config(hass, config, component):
|
||||
"""Process script configuration."""
|
||||
async def service_handler(service):
|
||||
"""Execute a service call to script.<script name>."""
|
||||
entity_id = ENTITY_ID_FORMAT.format(service.service)
|
||||
script = component.get_entity(entity_id)
|
||||
if script.is_on:
|
||||
_LOGGER.warning("Script %s already running.", entity_id)
|
||||
return
|
||||
yield from script.async_turn_on(variables=service.data)
|
||||
await script.async_turn_on(variables=service.data,
|
||||
context=service.context)
|
||||
|
||||
scripts = []
|
||||
|
||||
|
@ -171,7 +166,7 @@ def _async_process_config(hass, config, component):
|
|||
hass.services.async_register(
|
||||
DOMAIN, object_id, service_handler, schema=SCRIPT_SERVICE_SCHEMA)
|
||||
|
||||
yield from component.async_add_entities(scripts)
|
||||
await component.async_add_entities(scripts)
|
||||
|
||||
|
||||
class ScriptEntity(ToggleEntity):
|
||||
|
@ -209,18 +204,16 @@ class ScriptEntity(ToggleEntity):
|
|||
"""Return true if script is on."""
|
||||
return self.script.is_running
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_turn_on(self, **kwargs):
|
||||
async def async_turn_on(self, **kwargs):
|
||||
"""Turn the script on."""
|
||||
yield from self.script.async_run(kwargs.get(ATTR_VARIABLES))
|
||||
await self.script.async_run(
|
||||
kwargs.get(ATTR_VARIABLES), kwargs.get('context'))
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_turn_off(self, **kwargs):
|
||||
async def async_turn_off(self, **kwargs):
|
||||
"""Turn script off."""
|
||||
self.script.async_stop()
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_will_remove_from_hass(self):
|
||||
async def async_will_remove_from_hass(self):
|
||||
"""Stop script and remove service when it will be removed from HASS."""
|
||||
if self.script.is_running:
|
||||
self.script.async_stop()
|
||||
|
|
|
@ -6,7 +6,7 @@ from typing import Optional, Sequence
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.core import HomeAssistant, Context, callback
|
||||
from homeassistant.const import CONF_CONDITION, CONF_TIMEOUT
|
||||
from homeassistant.exceptions import TemplateError
|
||||
from homeassistant.helpers import (
|
||||
|
@ -34,9 +34,10 @@ CONF_CONTINUE = 'continue_on_timeout'
|
|||
|
||||
|
||||
def call_from_config(hass: HomeAssistant, config: ConfigType,
|
||||
variables: Optional[Sequence] = None) -> None:
|
||||
variables: Optional[Sequence] = None,
|
||||
context: Optional[Context] = None) -> None:
|
||||
"""Call a script based on a config entry."""
|
||||
Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables)
|
||||
Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables, context)
|
||||
|
||||
|
||||
class Script():
|
||||
|
@ -64,12 +65,13 @@ class Script():
|
|||
"""Return true if script is on."""
|
||||
return self._cur != -1
|
||||
|
||||
def run(self, variables=None):
|
||||
def run(self, variables=None, context=None):
|
||||
"""Run script."""
|
||||
run_coroutine_threadsafe(
|
||||
self.async_run(variables), self.hass.loop).result()
|
||||
self.async_run(variables, context), self.hass.loop).result()
|
||||
|
||||
async def async_run(self, variables: Optional[Sequence] = None) -> None:
|
||||
async def async_run(self, variables: Optional[Sequence] = None,
|
||||
context: Optional[Context] = None) -> None:
|
||||
"""Run script.
|
||||
|
||||
This method is a coroutine.
|
||||
|
@ -94,7 +96,8 @@ class Script():
|
|||
"""Handle delay."""
|
||||
# pylint: disable=cell-var-from-loop
|
||||
self._async_listener.remove(unsub)
|
||||
self.hass.async_add_job(self.async_run(variables))
|
||||
self.hass.async_create_task(
|
||||
self.async_run(variables, context))
|
||||
|
||||
delay = action[CONF_DELAY]
|
||||
|
||||
|
@ -134,7 +137,8 @@ class Script():
|
|||
def async_script_wait(entity_id, from_s, to_s):
|
||||
"""Handle script after template condition is true."""
|
||||
self._async_remove_listener()
|
||||
self.hass.async_add_job(self.async_run(variables))
|
||||
self.hass.async_create_task(
|
||||
self.async_run(variables, context))
|
||||
|
||||
self._async_listener.append(async_track_template(
|
||||
self.hass, wait_template, async_script_wait, variables))
|
||||
|
@ -145,7 +149,8 @@ class Script():
|
|||
|
||||
if CONF_TIMEOUT in action:
|
||||
self._async_set_timeout(
|
||||
action, variables, action.get(CONF_CONTINUE, True))
|
||||
action, variables, context,
|
||||
action.get(CONF_CONTINUE, True))
|
||||
|
||||
return
|
||||
|
||||
|
@ -154,10 +159,10 @@ class Script():
|
|||
break
|
||||
|
||||
elif CONF_EVENT in action:
|
||||
self._async_fire_event(action, variables)
|
||||
self._async_fire_event(action, variables, context)
|
||||
|
||||
else:
|
||||
await self._async_call_service(action, variables)
|
||||
await self._async_call_service(action, variables, context)
|
||||
|
||||
self._cur = -1
|
||||
self.last_action = None
|
||||
|
@ -178,7 +183,7 @@ class Script():
|
|||
if self._change_listener:
|
||||
self.hass.async_add_job(self._change_listener)
|
||||
|
||||
async def _async_call_service(self, action, variables):
|
||||
async def _async_call_service(self, action, variables, context):
|
||||
"""Call the service specified in the action.
|
||||
|
||||
This method is a coroutine.
|
||||
|
@ -186,9 +191,14 @@ class Script():
|
|||
self.last_action = action.get(CONF_ALIAS, 'call service')
|
||||
self._log("Executing step %s" % self.last_action)
|
||||
await service.async_call_from_config(
|
||||
self.hass, action, True, variables, validate_config=False)
|
||||
self.hass, action,
|
||||
blocking=True,
|
||||
variables=variables,
|
||||
validate_config=False,
|
||||
context=context
|
||||
)
|
||||
|
||||
def _async_fire_event(self, action, variables):
|
||||
def _async_fire_event(self, action, variables, context):
|
||||
"""Fire an event."""
|
||||
self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT])
|
||||
self._log("Executing step %s" % self.last_action)
|
||||
|
@ -201,7 +211,7 @@ class Script():
|
|||
_LOGGER.error('Error rendering event data template: %s', ex)
|
||||
|
||||
self.hass.bus.async_fire(action[CONF_EVENT],
|
||||
event_data)
|
||||
event_data, context=context)
|
||||
|
||||
def _async_check_condition(self, action, variables):
|
||||
"""Test if condition is matching."""
|
||||
|
@ -216,7 +226,8 @@ class Script():
|
|||
self._log("Test condition {}: {}".format(self.last_action, check))
|
||||
return check
|
||||
|
||||
def _async_set_timeout(self, action, variables, continue_on_timeout=True):
|
||||
def _async_set_timeout(self, action, variables, context,
|
||||
continue_on_timeout):
|
||||
"""Schedule a timeout to abort or continue script."""
|
||||
timeout = action[CONF_TIMEOUT]
|
||||
unsub = None
|
||||
|
@ -229,7 +240,8 @@ class Script():
|
|||
# Check if we want to continue to execute
|
||||
# the script after the timeout
|
||||
if continue_on_timeout:
|
||||
self.hass.async_add_job(self.async_run(variables))
|
||||
self.hass.async_create_task(
|
||||
self.async_run(variables, context))
|
||||
else:
|
||||
self._log("Timeout reached, abort script.")
|
||||
self.async_stop()
|
||||
|
|
|
@ -36,7 +36,7 @@ def call_from_config(hass, config, blocking=False, variables=None,
|
|||
|
||||
@bind_hass
|
||||
async def async_call_from_config(hass, config, blocking=False, variables=None,
|
||||
validate_config=True):
|
||||
validate_config=True, context=None):
|
||||
"""Call a service based on a config hash."""
|
||||
if validate_config:
|
||||
try:
|
||||
|
@ -77,7 +77,7 @@ async def async_call_from_config(hass, config, blocking=False, variables=None,
|
|||
service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
|
||||
|
||||
await hass.services.async_call(
|
||||
domain, service_name, service_data, blocking)
|
||||
domain, service_name, service_data, blocking=blocking, context=context)
|
||||
|
||||
|
||||
@bind_hass
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""The tests for the Event automation."""
|
||||
import unittest
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import Context, callback
|
||||
from homeassistant.setup import setup_component
|
||||
import homeassistant.components.automation as automation
|
||||
|
||||
|
@ -31,6 +31,8 @@ class TestAutomationEvent(unittest.TestCase):
|
|||
|
||||
def test_if_fires_on_event(self):
|
||||
"""Test the firing of events."""
|
||||
context = Context()
|
||||
|
||||
assert setup_component(self.hass, automation.DOMAIN, {
|
||||
automation.DOMAIN: {
|
||||
'trigger': {
|
||||
|
@ -43,9 +45,10 @@ class TestAutomationEvent(unittest.TestCase):
|
|||
}
|
||||
})
|
||||
|
||||
self.hass.bus.fire('test_event')
|
||||
self.hass.bus.fire('test_event', context=context)
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(1, len(self.calls))
|
||||
assert self.calls[0].context is context
|
||||
|
||||
automation.turn_off(self.hass)
|
||||
self.hass.block_till_done()
|
||||
|
|
|
@ -4,7 +4,7 @@ import unittest
|
|||
from unittest.mock import patch
|
||||
|
||||
import homeassistant.components.automation as automation
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import Context, callback
|
||||
from homeassistant.setup import setup_component
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
|
@ -36,6 +36,7 @@ class TestAutomationNumericState(unittest.TestCase):
|
|||
|
||||
def test_if_fires_on_entity_change_below(self):
|
||||
"""Test the firing with changed entity."""
|
||||
context = Context()
|
||||
assert setup_component(self.hass, automation.DOMAIN, {
|
||||
automation.DOMAIN: {
|
||||
'trigger': {
|
||||
|
@ -49,9 +50,10 @@ class TestAutomationNumericState(unittest.TestCase):
|
|||
}
|
||||
})
|
||||
# 9 is below 10
|
||||
self.hass.states.set('test.entity', 9)
|
||||
self.hass.states.set('test.entity', 9, context=context)
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(1, len(self.calls))
|
||||
assert self.calls[0].context is context
|
||||
|
||||
# Set above 12 so the automation will fire again
|
||||
self.hass.states.set('test.entity', 12)
|
||||
|
@ -116,6 +118,7 @@ class TestAutomationNumericState(unittest.TestCase):
|
|||
|
||||
def test_if_not_fires_on_entity_change_below_to_below(self):
|
||||
"""Test the firing with changed entity."""
|
||||
context = Context()
|
||||
self.hass.states.set('test.entity', 11)
|
||||
self.hass.block_till_done()
|
||||
|
||||
|
@ -133,9 +136,10 @@ class TestAutomationNumericState(unittest.TestCase):
|
|||
})
|
||||
|
||||
# 9 is below 10 so this should fire
|
||||
self.hass.states.set('test.entity', 9)
|
||||
self.hass.states.set('test.entity', 9, context=context)
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(1, len(self.calls))
|
||||
assert self.calls[0].context is context
|
||||
|
||||
# already below so should not fire again
|
||||
self.hass.states.set('test.entity', 5)
|
||||
|
|
|
@ -4,7 +4,7 @@ from datetime import timedelta
|
|||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import Context, callback
|
||||
from homeassistant.setup import setup_component
|
||||
import homeassistant.util.dt as dt_util
|
||||
import homeassistant.components.automation as automation
|
||||
|
@ -38,6 +38,7 @@ class TestAutomationState(unittest.TestCase):
|
|||
|
||||
def test_if_fires_on_entity_change(self):
|
||||
"""Test for firing on entity change."""
|
||||
context = Context()
|
||||
self.hass.states.set('test.entity', 'hello')
|
||||
self.hass.block_till_done()
|
||||
|
||||
|
@ -59,9 +60,10 @@ class TestAutomationState(unittest.TestCase):
|
|||
}
|
||||
})
|
||||
|
||||
self.hass.states.set('test.entity', 'world')
|
||||
self.hass.states.set('test.entity', 'world', context=context)
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(1, len(self.calls))
|
||||
assert self.calls[0].context is context
|
||||
self.assertEqual(
|
||||
'state - test.entity - hello - world - None',
|
||||
self.calls[0].data['some'])
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""The tests for the Template automation."""
|
||||
import unittest
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import Context, callback
|
||||
from homeassistant.setup import setup_component
|
||||
import homeassistant.components.automation as automation
|
||||
|
||||
|
@ -232,15 +232,12 @@ class TestAutomationTemplate(unittest.TestCase):
|
|||
|
||||
def test_if_fires_on_change_with_template_advanced(self):
|
||||
"""Test for firing on change with template advanced."""
|
||||
context = Context()
|
||||
assert setup_component(self.hass, automation.DOMAIN, {
|
||||
automation.DOMAIN: {
|
||||
'trigger': {
|
||||
'platform': 'template',
|
||||
'value_template': '''{%- if is_state("test.entity", "world") -%}
|
||||
true
|
||||
{%- else -%}
|
||||
false
|
||||
{%- endif -%}''',
|
||||
'value_template': '{{ is_state("test.entity", "world") }}'
|
||||
},
|
||||
'action': {
|
||||
'service': 'test.automation',
|
||||
|
@ -257,9 +254,10 @@ class TestAutomationTemplate(unittest.TestCase):
|
|||
self.hass.block_till_done()
|
||||
self.calls = []
|
||||
|
||||
self.hass.states.set('test.entity', 'world')
|
||||
self.hass.states.set('test.entity', 'world', context=context)
|
||||
self.hass.block_till_done()
|
||||
self.assertEqual(1, len(self.calls))
|
||||
assert self.calls[0].context is context
|
||||
self.assertEqual(
|
||||
'template - test.entity - hello - world',
|
||||
self.calls[0].data['some'])
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""The tests for the location automation."""
|
||||
import unittest
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import Context, callback
|
||||
from homeassistant.setup import setup_component
|
||||
from homeassistant.components import automation, zone
|
||||
|
||||
|
@ -40,6 +40,7 @@ class TestAutomationZone(unittest.TestCase):
|
|||
|
||||
def test_if_fires_on_zone_enter(self):
|
||||
"""Test for firing on zone enter."""
|
||||
context = Context()
|
||||
self.hass.states.set('test.entity', 'hello', {
|
||||
'latitude': 32.881011,
|
||||
'longitude': -117.234758
|
||||
|
@ -70,10 +71,11 @@ class TestAutomationZone(unittest.TestCase):
|
|||
self.hass.states.set('test.entity', 'hello', {
|
||||
'latitude': 32.880586,
|
||||
'longitude': -117.237564
|
||||
})
|
||||
}, context=context)
|
||||
self.hass.block_till_done()
|
||||
|
||||
self.assertEqual(1, len(self.calls))
|
||||
assert self.calls[0].context is context
|
||||
self.assertEqual(
|
||||
'zone - test.entity - hello - hello - test',
|
||||
self.calls[0].data['some'])
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import Context, callback
|
||||
from homeassistant.setup import setup_component
|
||||
from homeassistant.components import script
|
||||
|
||||
|
@ -134,6 +134,7 @@ class TestScriptComponent(unittest.TestCase):
|
|||
def test_passing_variables(self):
|
||||
"""Test different ways of passing in variables."""
|
||||
calls = []
|
||||
context = Context()
|
||||
|
||||
@callback
|
||||
def record_call(service):
|
||||
|
@ -157,21 +158,23 @@ class TestScriptComponent(unittest.TestCase):
|
|||
|
||||
script.turn_on(self.hass, ENTITY_ID, {
|
||||
'greeting': 'world'
|
||||
})
|
||||
}, context=context)
|
||||
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[-1].data['hello'] == 'world'
|
||||
assert calls[0].context is context
|
||||
assert calls[0].data['hello'] == 'world'
|
||||
|
||||
self.hass.services.call('script', 'test', {
|
||||
'greeting': 'universe',
|
||||
})
|
||||
}, context=context)
|
||||
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 2
|
||||
assert calls[-1].data['hello'] == 'universe'
|
||||
assert calls[1].context is context
|
||||
assert calls[1].data['hello'] == 'universe'
|
||||
|
||||
def test_reload_service(self):
|
||||
"""Verify that the turn_on service."""
|
||||
|
|
|
@ -4,7 +4,7 @@ from datetime import timedelta
|
|||
from unittest import mock
|
||||
import unittest
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import Context, callback
|
||||
# Otherwise can't test just this file (import order issue)
|
||||
import homeassistant.components # noqa
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
@ -32,6 +32,7 @@ class TestScriptHelper(unittest.TestCase):
|
|||
def test_firing_event(self):
|
||||
"""Test the firing of events."""
|
||||
event = 'test_event'
|
||||
context = Context()
|
||||
calls = []
|
||||
|
||||
@callback
|
||||
|
@ -48,17 +49,19 @@ class TestScriptHelper(unittest.TestCase):
|
|||
}
|
||||
}))
|
||||
|
||||
script_obj.run()
|
||||
script_obj.run(context=context)
|
||||
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0].context is context
|
||||
assert calls[0].data.get('hello') == 'world'
|
||||
assert not script_obj.can_cancel
|
||||
|
||||
def test_firing_event_template(self):
|
||||
"""Test the firing of events."""
|
||||
event = 'test_event'
|
||||
context = Context()
|
||||
calls = []
|
||||
|
||||
@callback
|
||||
|
@ -82,11 +85,12 @@ class TestScriptHelper(unittest.TestCase):
|
|||
}
|
||||
}))
|
||||
|
||||
script_obj.run({'is_world': 'yes'})
|
||||
script_obj.run({'is_world': 'yes'}, context=context)
|
||||
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0].context is context
|
||||
assert calls[0].data == {
|
||||
'dict': {
|
||||
1: 'yes',
|
||||
|
@ -100,6 +104,7 @@ class TestScriptHelper(unittest.TestCase):
|
|||
def test_calling_service(self):
|
||||
"""Test the calling of a service."""
|
||||
calls = []
|
||||
context = Context()
|
||||
|
||||
@callback
|
||||
def record_call(service):
|
||||
|
@ -113,16 +118,18 @@ class TestScriptHelper(unittest.TestCase):
|
|||
'data': {
|
||||
'hello': 'world'
|
||||
}
|
||||
})
|
||||
}, context=context)
|
||||
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0].context is context
|
||||
assert calls[0].data.get('hello') == 'world'
|
||||
|
||||
def test_calling_service_template(self):
|
||||
"""Test the calling of a service."""
|
||||
calls = []
|
||||
context = Context()
|
||||
|
||||
@callback
|
||||
def record_call(service):
|
||||
|
@ -147,17 +154,19 @@ class TestScriptHelper(unittest.TestCase):
|
|||
{% endif %}
|
||||
"""
|
||||
}
|
||||
}, {'is_world': 'yes'})
|
||||
}, {'is_world': 'yes'}, context=context)
|
||||
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0].context is context
|
||||
assert calls[0].data.get('hello') == 'world'
|
||||
|
||||
def test_delay(self):
|
||||
"""Test the delay."""
|
||||
event = 'test_event'
|
||||
events = []
|
||||
context = Context()
|
||||
|
||||
@callback
|
||||
def record_event(event):
|
||||
|
@ -171,7 +180,7 @@ class TestScriptHelper(unittest.TestCase):
|
|||
{'delay': {'seconds': 5}},
|
||||
{'event': event}]))
|
||||
|
||||
script_obj.run()
|
||||
script_obj.run(context=context)
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert script_obj.is_running
|
||||
|
@ -185,6 +194,8 @@ class TestScriptHelper(unittest.TestCase):
|
|||
|
||||
assert not script_obj.is_running
|
||||
assert len(events) == 2
|
||||
assert events[0].context is context
|
||||
assert events[1].context is context
|
||||
|
||||
def test_delay_template(self):
|
||||
"""Test the delay as a template."""
|
||||
|
@ -282,6 +293,7 @@ class TestScriptHelper(unittest.TestCase):
|
|||
"""Test the wait template."""
|
||||
event = 'test_event'
|
||||
events = []
|
||||
context = Context()
|
||||
|
||||
@callback
|
||||
def record_event(event):
|
||||
|
@ -297,7 +309,7 @@ class TestScriptHelper(unittest.TestCase):
|
|||
{'wait_template': "{{states.switch.test.state == 'off'}}"},
|
||||
{'event': event}]))
|
||||
|
||||
script_obj.run()
|
||||
script_obj.run(context=context)
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert script_obj.is_running
|
||||
|
@ -310,6 +322,8 @@ class TestScriptHelper(unittest.TestCase):
|
|||
|
||||
assert not script_obj.is_running
|
||||
assert len(events) == 2
|
||||
assert events[0].context is context
|
||||
assert events[1].context is context
|
||||
|
||||
def test_wait_template_cancel(self):
|
||||
"""Test the wait template cancel action."""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue