Add context to scripts and automations (#16415)
* Add context to script helper * Update script component * Add context to automations * Lint
This commit is contained in:
parent
e1501c83f8
commit
746f4ac158
17 changed files with 164 additions and 144 deletions
|
@ -158,27 +158,26 @@ def async_reload(hass):
|
||||||
return hass.services.async_call(DOMAIN, SERVICE_RELOAD)
|
return hass.services.async_call(DOMAIN, SERVICE_RELOAD)
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def async_setup(hass, config):
|
||||||
def async_setup(hass, config):
|
|
||||||
"""Set up the automation."""
|
"""Set up the automation."""
|
||||||
component = EntityComponent(_LOGGER, DOMAIN, hass,
|
component = EntityComponent(_LOGGER, DOMAIN, hass,
|
||||||
group_name=GROUP_NAME_ALL_AUTOMATIONS)
|
group_name=GROUP_NAME_ALL_AUTOMATIONS)
|
||||||
|
|
||||||
yield from _async_process_config(hass, config, component)
|
await _async_process_config(hass, config, component)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def trigger_service_handler(service_call):
|
||||||
def trigger_service_handler(service_call):
|
|
||||||
"""Handle automation triggers."""
|
"""Handle automation triggers."""
|
||||||
tasks = []
|
tasks = []
|
||||||
for entity in component.async_extract_from_service(service_call):
|
for entity in component.async_extract_from_service(service_call):
|
||||||
tasks.append(entity.async_trigger(
|
tasks.append(entity.async_trigger(
|
||||||
service_call.data.get(ATTR_VARIABLES), True))
|
service_call.data.get(ATTR_VARIABLES),
|
||||||
|
skip_condition=True,
|
||||||
|
context=service_call.context))
|
||||||
|
|
||||||
if tasks:
|
if tasks:
|
||||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
await asyncio.wait(tasks, loop=hass.loop)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def turn_onoff_service_handler(service_call):
|
||||||
def turn_onoff_service_handler(service_call):
|
|
||||||
"""Handle automation turn on/off service calls."""
|
"""Handle automation turn on/off service calls."""
|
||||||
tasks = []
|
tasks = []
|
||||||
method = 'async_{}'.format(service_call.service)
|
method = 'async_{}'.format(service_call.service)
|
||||||
|
@ -186,10 +185,9 @@ def async_setup(hass, config):
|
||||||
tasks.append(getattr(entity, method)())
|
tasks.append(getattr(entity, method)())
|
||||||
|
|
||||||
if tasks:
|
if tasks:
|
||||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
await asyncio.wait(tasks, loop=hass.loop)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def toggle_service_handler(service_call):
|
||||||
def toggle_service_handler(service_call):
|
|
||||||
"""Handle automation toggle service calls."""
|
"""Handle automation toggle service calls."""
|
||||||
tasks = []
|
tasks = []
|
||||||
for entity in component.async_extract_from_service(service_call):
|
for entity in component.async_extract_from_service(service_call):
|
||||||
|
@ -199,15 +197,14 @@ def async_setup(hass, config):
|
||||||
tasks.append(entity.async_turn_on())
|
tasks.append(entity.async_turn_on())
|
||||||
|
|
||||||
if tasks:
|
if tasks:
|
||||||
yield from asyncio.wait(tasks, loop=hass.loop)
|
await asyncio.wait(tasks, loop=hass.loop)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def reload_service_handler(service_call):
|
||||||
def reload_service_handler(service_call):
|
|
||||||
"""Remove all automations and load new ones from config."""
|
"""Remove all automations and load new ones from config."""
|
||||||
conf = yield from component.async_prepare_reload()
|
conf = await component.async_prepare_reload()
|
||||||
if conf is None:
|
if conf is None:
|
||||||
return
|
return
|
||||||
yield from _async_process_config(hass, conf, component)
|
await _async_process_config(hass, conf, component)
|
||||||
|
|
||||||
hass.services.async_register(
|
hass.services.async_register(
|
||||||
DOMAIN, SERVICE_TRIGGER, trigger_service_handler,
|
DOMAIN, SERVICE_TRIGGER, trigger_service_handler,
|
||||||
|
@ -272,15 +269,14 @@ class AutomationEntity(ToggleEntity):
|
||||||
"""Return True if entity is on."""
|
"""Return True if entity is on."""
|
||||||
return self._async_detach_triggers is not None
|
return self._async_detach_triggers is not None
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def async_added_to_hass(self) -> None:
|
||||||
def async_added_to_hass(self) -> None:
|
|
||||||
"""Startup with initial state or previous state."""
|
"""Startup with initial state or previous state."""
|
||||||
if self._initial_state is not None:
|
if self._initial_state is not None:
|
||||||
enable_automation = self._initial_state
|
enable_automation = self._initial_state
|
||||||
_LOGGER.debug("Automation %s initial state %s from config "
|
_LOGGER.debug("Automation %s initial state %s from config "
|
||||||
"initial_state", self.entity_id, enable_automation)
|
"initial_state", self.entity_id, enable_automation)
|
||||||
else:
|
else:
|
||||||
state = yield from async_get_last_state(self.hass, self.entity_id)
|
state = await async_get_last_state(self.hass, self.entity_id)
|
||||||
if state:
|
if state:
|
||||||
enable_automation = state.state == STATE_ON
|
enable_automation = state.state == STATE_ON
|
||||||
self._last_triggered = state.attributes.get('last_triggered')
|
self._last_triggered = state.attributes.get('last_triggered')
|
||||||
|
@ -298,54 +294,50 @@ class AutomationEntity(ToggleEntity):
|
||||||
|
|
||||||
# HomeAssistant is starting up
|
# HomeAssistant is starting up
|
||||||
if self.hass.state == CoreState.not_running:
|
if self.hass.state == CoreState.not_running:
|
||||||
@asyncio.coroutine
|
async def async_enable_automation(event):
|
||||||
def async_enable_automation(event):
|
|
||||||
"""Start automation on startup."""
|
"""Start automation on startup."""
|
||||||
yield from self.async_enable()
|
await self.async_enable()
|
||||||
|
|
||||||
self.hass.bus.async_listen_once(
|
self.hass.bus.async_listen_once(
|
||||||
EVENT_HOMEASSISTANT_START, async_enable_automation)
|
EVENT_HOMEASSISTANT_START, async_enable_automation)
|
||||||
|
|
||||||
# HomeAssistant is running
|
# HomeAssistant is running
|
||||||
else:
|
else:
|
||||||
yield from self.async_enable()
|
await self.async_enable()
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def async_turn_on(self, **kwargs) -> None:
|
||||||
def async_turn_on(self, **kwargs) -> None:
|
|
||||||
"""Turn the entity on and update the state."""
|
"""Turn the entity on and update the state."""
|
||||||
if self.is_on:
|
if self.is_on:
|
||||||
return
|
return
|
||||||
|
|
||||||
yield from self.async_enable()
|
await self.async_enable()
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def async_turn_off(self, **kwargs) -> None:
|
||||||
def async_turn_off(self, **kwargs) -> None:
|
|
||||||
"""Turn the entity off."""
|
"""Turn the entity off."""
|
||||||
if not self.is_on:
|
if not self.is_on:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._async_detach_triggers()
|
self._async_detach_triggers()
|
||||||
self._async_detach_triggers = None
|
self._async_detach_triggers = None
|
||||||
yield from self.async_update_ha_state()
|
await self.async_update_ha_state()
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def async_trigger(self, variables, skip_condition=False,
|
||||||
def async_trigger(self, variables, skip_condition=False):
|
context=None):
|
||||||
"""Trigger automation.
|
"""Trigger automation.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
if skip_condition or self._cond_func(variables):
|
if skip_condition or self._cond_func(variables):
|
||||||
yield from self._async_action(self.entity_id, variables)
|
self.async_set_context(context)
|
||||||
|
await self._async_action(self.entity_id, variables, context)
|
||||||
self._last_triggered = utcnow()
|
self._last_triggered = utcnow()
|
||||||
yield from self.async_update_ha_state()
|
await self.async_update_ha_state()
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def async_will_remove_from_hass(self):
|
||||||
def async_will_remove_from_hass(self):
|
|
||||||
"""Remove listeners when removing automation from HASS."""
|
"""Remove listeners when removing automation from HASS."""
|
||||||
yield from self.async_turn_off()
|
await self.async_turn_off()
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def async_enable(self):
|
||||||
def async_enable(self):
|
|
||||||
"""Enable this automation entity.
|
"""Enable this automation entity.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
|
@ -353,9 +345,9 @@ class AutomationEntity(ToggleEntity):
|
||||||
if self.is_on:
|
if self.is_on:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._async_detach_triggers = yield from self._async_attach_triggers(
|
self._async_detach_triggers = await self._async_attach_triggers(
|
||||||
self.async_trigger)
|
self.async_trigger)
|
||||||
yield from self.async_update_ha_state()
|
await self.async_update_ha_state()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device_state_attributes(self):
|
def device_state_attributes(self):
|
||||||
|
@ -368,8 +360,7 @@ class AutomationEntity(ToggleEntity):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def _async_process_config(hass, config, component):
|
||||||
def _async_process_config(hass, config, component):
|
|
||||||
"""Process config and add automations.
|
"""Process config and add automations.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
|
@ -411,20 +402,19 @@ def _async_process_config(hass, config, component):
|
||||||
entities.append(entity)
|
entities.append(entity)
|
||||||
|
|
||||||
if entities:
|
if entities:
|
||||||
yield from component.async_add_entities(entities)
|
await component.async_add_entities(entities)
|
||||||
|
|
||||||
|
|
||||||
def _async_get_action(hass, config, name):
|
def _async_get_action(hass, config, name):
|
||||||
"""Return an action based on a configuration."""
|
"""Return an action based on a configuration."""
|
||||||
script_obj = script.Script(hass, config, name)
|
script_obj = script.Script(hass, config, name)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def action(entity_id, variables, context):
|
||||||
def action(entity_id, variables):
|
|
||||||
"""Execute an action."""
|
"""Execute an action."""
|
||||||
_LOGGER.info('Executing %s', name)
|
_LOGGER.info('Executing %s', name)
|
||||||
logbook.async_log_entry(
|
logbook.async_log_entry(
|
||||||
hass, name, 'has been triggered', DOMAIN, entity_id)
|
hass, name, 'has been triggered', DOMAIN, entity_id)
|
||||||
yield from script_obj.async_run(variables)
|
await script_obj.async_run(variables, context)
|
||||||
|
|
||||||
return action
|
return action
|
||||||
|
|
||||||
|
@ -448,8 +438,7 @@ def _async_process_if(hass, config, p_config):
|
||||||
return if_action
|
return if_action
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def _async_process_trigger(hass, config, trigger_configs, name, action):
|
||||||
def _async_process_trigger(hass, config, trigger_configs, name, action):
|
|
||||||
"""Set up the triggers.
|
"""Set up the triggers.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
|
@ -457,13 +446,13 @@ def _async_process_trigger(hass, config, trigger_configs, name, action):
|
||||||
removes = []
|
removes = []
|
||||||
|
|
||||||
for conf in trigger_configs:
|
for conf in trigger_configs:
|
||||||
platform = yield from async_prepare_setup_platform(
|
platform = await async_prepare_setup_platform(
|
||||||
hass, config, DOMAIN, conf.get(CONF_PLATFORM))
|
hass, config, DOMAIN, conf.get(CONF_PLATFORM))
|
||||||
|
|
||||||
if platform is None:
|
if platform is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
remove = yield from platform.async_trigger(hass, conf, action)
|
remove = await platform.async_trigger(hass, conf, action)
|
||||||
|
|
||||||
if not remove:
|
if not remove:
|
||||||
_LOGGER.error("Error setting up trigger %s", name)
|
_LOGGER.error("Error setting up trigger %s", name)
|
||||||
|
|
|
@ -45,11 +45,11 @@ def async_trigger(hass, config, action):
|
||||||
# If event data doesn't match requested schema, skip event
|
# If event data doesn't match requested schema, skip event
|
||||||
return
|
return
|
||||||
|
|
||||||
hass.async_run_job(action, {
|
hass.async_run_job(action({
|
||||||
'trigger': {
|
'trigger': {
|
||||||
'platform': 'event',
|
'platform': 'event',
|
||||||
'event': event,
|
'event': event,
|
||||||
},
|
},
|
||||||
})
|
}, context=event.context))
|
||||||
|
|
||||||
return hass.bus.async_listen(event_type, handle_event)
|
return hass.bus.async_listen(event_type, handle_event)
|
||||||
|
|
|
@ -32,12 +32,12 @@ def async_trigger(hass, config, action):
|
||||||
@callback
|
@callback
|
||||||
def hass_shutdown(event):
|
def hass_shutdown(event):
|
||||||
"""Execute when Home Assistant is shutting down."""
|
"""Execute when Home Assistant is shutting down."""
|
||||||
hass.async_run_job(action, {
|
hass.async_run_job(action({
|
||||||
'trigger': {
|
'trigger': {
|
||||||
'platform': 'homeassistant',
|
'platform': 'homeassistant',
|
||||||
'event': event,
|
'event': event,
|
||||||
},
|
},
|
||||||
})
|
}, context=event.context))
|
||||||
|
|
||||||
return hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP,
|
return hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP,
|
||||||
hass_shutdown)
|
hass_shutdown)
|
||||||
|
@ -45,11 +45,11 @@ def async_trigger(hass, config, action):
|
||||||
# Automation are enabled while hass is starting up, fire right away
|
# Automation are enabled while hass is starting up, fire right away
|
||||||
# Check state because a config reload shouldn't trigger it.
|
# Check state because a config reload shouldn't trigger it.
|
||||||
if hass.state == CoreState.starting:
|
if hass.state == CoreState.starting:
|
||||||
hass.async_run_job(action, {
|
hass.async_run_job(action({
|
||||||
'trigger': {
|
'trigger': {
|
||||||
'platform': 'homeassistant',
|
'platform': 'homeassistant',
|
||||||
'event': event,
|
'event': event,
|
||||||
},
|
},
|
||||||
})
|
}))
|
||||||
|
|
||||||
return lambda: None
|
return lambda: None
|
||||||
|
|
|
@ -66,7 +66,7 @@ def async_trigger(hass, config, action):
|
||||||
@callback
|
@callback
|
||||||
def call_action():
|
def call_action():
|
||||||
"""Call action with right context."""
|
"""Call action with right context."""
|
||||||
hass.async_run_job(action, {
|
hass.async_run_job(action({
|
||||||
'trigger': {
|
'trigger': {
|
||||||
'platform': 'numeric_state',
|
'platform': 'numeric_state',
|
||||||
'entity_id': entity,
|
'entity_id': entity,
|
||||||
|
@ -75,7 +75,7 @@ def async_trigger(hass, config, action):
|
||||||
'from_state': from_s,
|
'from_state': from_s,
|
||||||
'to_state': to_s,
|
'to_state': to_s,
|
||||||
}
|
}
|
||||||
})
|
}, context=to_s.context))
|
||||||
|
|
||||||
matching = check_numeric_state(entity, from_s, to_s)
|
matching = check_numeric_state(entity, from_s, to_s)
|
||||||
|
|
||||||
|
|
|
@ -43,7 +43,7 @@ def async_trigger(hass, config, action):
|
||||||
@callback
|
@callback
|
||||||
def call_action():
|
def call_action():
|
||||||
"""Call action with right context."""
|
"""Call action with right context."""
|
||||||
hass.async_run_job(action, {
|
hass.async_run_job(action({
|
||||||
'trigger': {
|
'trigger': {
|
||||||
'platform': 'state',
|
'platform': 'state',
|
||||||
'entity_id': entity,
|
'entity_id': entity,
|
||||||
|
@ -51,7 +51,7 @@ def async_trigger(hass, config, action):
|
||||||
'to_state': to_s,
|
'to_state': to_s,
|
||||||
'for': time_delta,
|
'for': time_delta,
|
||||||
}
|
}
|
||||||
})
|
}, context=to_s.context))
|
||||||
|
|
||||||
# Ignore changes to state attributes if from/to is in use
|
# Ignore changes to state attributes if from/to is in use
|
||||||
if (not match_all and from_s is not None and to_s is not None and
|
if (not match_all and from_s is not None and to_s is not None and
|
||||||
|
|
|
@ -32,13 +32,13 @@ def async_trigger(hass, config, action):
|
||||||
@callback
|
@callback
|
||||||
def template_listener(entity_id, from_s, to_s):
|
def template_listener(entity_id, from_s, to_s):
|
||||||
"""Listen for state changes and calls action."""
|
"""Listen for state changes and calls action."""
|
||||||
hass.async_run_job(action, {
|
hass.async_run_job(action({
|
||||||
'trigger': {
|
'trigger': {
|
||||||
'platform': 'template',
|
'platform': 'template',
|
||||||
'entity_id': entity_id,
|
'entity_id': entity_id,
|
||||||
'from_state': from_s,
|
'from_state': from_s,
|
||||||
'to_state': to_s,
|
'to_state': to_s,
|
||||||
},
|
},
|
||||||
})
|
}, context=to_s.context))
|
||||||
|
|
||||||
return async_track_template(hass, value_template, template_listener)
|
return async_track_template(hass, value_template, template_listener)
|
||||||
|
|
|
@ -51,7 +51,7 @@ def async_trigger(hass, config, action):
|
||||||
# pylint: disable=too-many-boolean-expressions
|
# pylint: disable=too-many-boolean-expressions
|
||||||
if event == EVENT_ENTER and not from_match and to_match or \
|
if event == EVENT_ENTER and not from_match and to_match or \
|
||||||
event == EVENT_LEAVE and from_match and not to_match:
|
event == EVENT_LEAVE and from_match and not to_match:
|
||||||
hass.async_run_job(action, {
|
hass.async_run_job(action({
|
||||||
'trigger': {
|
'trigger': {
|
||||||
'platform': 'zone',
|
'platform': 'zone',
|
||||||
'entity_id': entity,
|
'entity_id': entity,
|
||||||
|
@ -60,7 +60,7 @@ def async_trigger(hass, config, action):
|
||||||
'zone': zone_state,
|
'zone': zone_state,
|
||||||
'event': event,
|
'event': event,
|
||||||
},
|
},
|
||||||
})
|
}, context=to_s.context))
|
||||||
|
|
||||||
return async_track_state_change(hass, entity_id, zone_automation_listener,
|
return async_track_state_change(hass, entity_id, zone_automation_listener,
|
||||||
MATCH_ALL, MATCH_ALL)
|
MATCH_ALL, MATCH_ALL)
|
||||||
|
|
|
@ -63,11 +63,11 @@ def is_on(hass, entity_id):
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def turn_on(hass, entity_id, variables=None):
|
def turn_on(hass, entity_id, variables=None, context=None):
|
||||||
"""Turn script on."""
|
"""Turn script on."""
|
||||||
_, object_id = split_entity_id(entity_id)
|
_, object_id = split_entity_id(entity_id)
|
||||||
|
|
||||||
hass.services.call(DOMAIN, object_id, variables)
|
hass.services.call(DOMAIN, object_id, variables, context=context)
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
|
@ -97,45 +97,41 @@ def async_reload(hass):
|
||||||
return hass.services.async_call(DOMAIN, SERVICE_RELOAD)
|
return hass.services.async_call(DOMAIN, SERVICE_RELOAD)
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def async_setup(hass, config):
|
||||||
def async_setup(hass, config):
|
|
||||||
"""Load the scripts from the configuration."""
|
"""Load the scripts from the configuration."""
|
||||||
component = EntityComponent(
|
component = EntityComponent(
|
||||||
_LOGGER, DOMAIN, hass, group_name=GROUP_NAME_ALL_SCRIPTS)
|
_LOGGER, DOMAIN, hass, group_name=GROUP_NAME_ALL_SCRIPTS)
|
||||||
|
|
||||||
yield from _async_process_config(hass, config, component)
|
await _async_process_config(hass, config, component)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def reload_service(service):
|
||||||
def reload_service(service):
|
|
||||||
"""Call a service to reload scripts."""
|
"""Call a service to reload scripts."""
|
||||||
conf = yield from component.async_prepare_reload()
|
conf = await component.async_prepare_reload()
|
||||||
if conf is None:
|
if conf is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
yield from _async_process_config(hass, conf, component)
|
await _async_process_config(hass, conf, component)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def turn_on_service(service):
|
||||||
def turn_on_service(service):
|
|
||||||
"""Call a service to turn script on."""
|
"""Call a service to turn script on."""
|
||||||
# We could turn on script directly here, but we only want to offer
|
# We could turn on script directly here, but we only want to offer
|
||||||
# one way to do it. Otherwise no easy way to detect invocations.
|
# one way to do it. Otherwise no easy way to detect invocations.
|
||||||
var = service.data.get(ATTR_VARIABLES)
|
var = service.data.get(ATTR_VARIABLES)
|
||||||
for script in component.async_extract_from_service(service):
|
for script in component.async_extract_from_service(service):
|
||||||
yield from hass.services.async_call(DOMAIN, script.object_id, var)
|
await hass.services.async_call(DOMAIN, script.object_id, var,
|
||||||
|
context=service.context)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def turn_off_service(service):
|
||||||
def turn_off_service(service):
|
|
||||||
"""Cancel a script."""
|
"""Cancel a script."""
|
||||||
# Stopping a script is ok to be done in parallel
|
# Stopping a script is ok to be done in parallel
|
||||||
yield from asyncio.wait(
|
await asyncio.wait(
|
||||||
[script.async_turn_off() for script
|
[script.async_turn_off() for script
|
||||||
in component.async_extract_from_service(service)], loop=hass.loop)
|
in component.async_extract_from_service(service)], loop=hass.loop)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def toggle_service(service):
|
||||||
def toggle_service(service):
|
|
||||||
"""Toggle a script."""
|
"""Toggle a script."""
|
||||||
for script in component.async_extract_from_service(service):
|
for script in component.async_extract_from_service(service):
|
||||||
yield from script.async_toggle()
|
await script.async_toggle(context=service.context)
|
||||||
|
|
||||||
hass.services.async_register(DOMAIN, SERVICE_RELOAD, reload_service,
|
hass.services.async_register(DOMAIN, SERVICE_RELOAD, reload_service,
|
||||||
schema=RELOAD_SERVICE_SCHEMA)
|
schema=RELOAD_SERVICE_SCHEMA)
|
||||||
|
@ -149,18 +145,17 @@ def async_setup(hass, config):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def _async_process_config(hass, config, component):
|
||||||
def _async_process_config(hass, config, component):
|
"""Process script configuration."""
|
||||||
"""Process group configuration."""
|
async def service_handler(service):
|
||||||
@asyncio.coroutine
|
|
||||||
def service_handler(service):
|
|
||||||
"""Execute a service call to script.<script name>."""
|
"""Execute a service call to script.<script name>."""
|
||||||
entity_id = ENTITY_ID_FORMAT.format(service.service)
|
entity_id = ENTITY_ID_FORMAT.format(service.service)
|
||||||
script = component.get_entity(entity_id)
|
script = component.get_entity(entity_id)
|
||||||
if script.is_on:
|
if script.is_on:
|
||||||
_LOGGER.warning("Script %s already running.", entity_id)
|
_LOGGER.warning("Script %s already running.", entity_id)
|
||||||
return
|
return
|
||||||
yield from script.async_turn_on(variables=service.data)
|
await script.async_turn_on(variables=service.data,
|
||||||
|
context=service.context)
|
||||||
|
|
||||||
scripts = []
|
scripts = []
|
||||||
|
|
||||||
|
@ -171,7 +166,7 @@ def _async_process_config(hass, config, component):
|
||||||
hass.services.async_register(
|
hass.services.async_register(
|
||||||
DOMAIN, object_id, service_handler, schema=SCRIPT_SERVICE_SCHEMA)
|
DOMAIN, object_id, service_handler, schema=SCRIPT_SERVICE_SCHEMA)
|
||||||
|
|
||||||
yield from component.async_add_entities(scripts)
|
await component.async_add_entities(scripts)
|
||||||
|
|
||||||
|
|
||||||
class ScriptEntity(ToggleEntity):
|
class ScriptEntity(ToggleEntity):
|
||||||
|
@ -209,18 +204,16 @@ class ScriptEntity(ToggleEntity):
|
||||||
"""Return true if script is on."""
|
"""Return true if script is on."""
|
||||||
return self.script.is_running
|
return self.script.is_running
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def async_turn_on(self, **kwargs):
|
||||||
def async_turn_on(self, **kwargs):
|
|
||||||
"""Turn the script on."""
|
"""Turn the script on."""
|
||||||
yield from self.script.async_run(kwargs.get(ATTR_VARIABLES))
|
await self.script.async_run(
|
||||||
|
kwargs.get(ATTR_VARIABLES), kwargs.get('context'))
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def async_turn_off(self, **kwargs):
|
||||||
def async_turn_off(self, **kwargs):
|
|
||||||
"""Turn script off."""
|
"""Turn script off."""
|
||||||
self.script.async_stop()
|
self.script.async_stop()
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def async_will_remove_from_hass(self):
|
||||||
def async_will_remove_from_hass(self):
|
|
||||||
"""Stop script and remove service when it will be removed from HASS."""
|
"""Stop script and remove service when it will be removed from HASS."""
|
||||||
if self.script.is_running:
|
if self.script.is_running:
|
||||||
self.script.async_stop()
|
self.script.async_stop()
|
||||||
|
|
|
@ -6,7 +6,7 @@ from typing import Optional, Sequence
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, Context, callback
|
||||||
from homeassistant.const import CONF_CONDITION, CONF_TIMEOUT
|
from homeassistant.const import CONF_CONDITION, CONF_TIMEOUT
|
||||||
from homeassistant.exceptions import TemplateError
|
from homeassistant.exceptions import TemplateError
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
|
@ -34,9 +34,10 @@ CONF_CONTINUE = 'continue_on_timeout'
|
||||||
|
|
||||||
|
|
||||||
def call_from_config(hass: HomeAssistant, config: ConfigType,
|
def call_from_config(hass: HomeAssistant, config: ConfigType,
|
||||||
variables: Optional[Sequence] = None) -> None:
|
variables: Optional[Sequence] = None,
|
||||||
|
context: Optional[Context] = None) -> None:
|
||||||
"""Call a script based on a config entry."""
|
"""Call a script based on a config entry."""
|
||||||
Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables)
|
Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables, context)
|
||||||
|
|
||||||
|
|
||||||
class Script():
|
class Script():
|
||||||
|
@ -64,12 +65,13 @@ class Script():
|
||||||
"""Return true if script is on."""
|
"""Return true if script is on."""
|
||||||
return self._cur != -1
|
return self._cur != -1
|
||||||
|
|
||||||
def run(self, variables=None):
|
def run(self, variables=None, context=None):
|
||||||
"""Run script."""
|
"""Run script."""
|
||||||
run_coroutine_threadsafe(
|
run_coroutine_threadsafe(
|
||||||
self.async_run(variables), self.hass.loop).result()
|
self.async_run(variables, context), self.hass.loop).result()
|
||||||
|
|
||||||
async def async_run(self, variables: Optional[Sequence] = None) -> None:
|
async def async_run(self, variables: Optional[Sequence] = None,
|
||||||
|
context: Optional[Context] = None) -> None:
|
||||||
"""Run script.
|
"""Run script.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
|
@ -94,7 +96,8 @@ class Script():
|
||||||
"""Handle delay."""
|
"""Handle delay."""
|
||||||
# pylint: disable=cell-var-from-loop
|
# pylint: disable=cell-var-from-loop
|
||||||
self._async_listener.remove(unsub)
|
self._async_listener.remove(unsub)
|
||||||
self.hass.async_add_job(self.async_run(variables))
|
self.hass.async_create_task(
|
||||||
|
self.async_run(variables, context))
|
||||||
|
|
||||||
delay = action[CONF_DELAY]
|
delay = action[CONF_DELAY]
|
||||||
|
|
||||||
|
@ -134,7 +137,8 @@ class Script():
|
||||||
def async_script_wait(entity_id, from_s, to_s):
|
def async_script_wait(entity_id, from_s, to_s):
|
||||||
"""Handle script after template condition is true."""
|
"""Handle script after template condition is true."""
|
||||||
self._async_remove_listener()
|
self._async_remove_listener()
|
||||||
self.hass.async_add_job(self.async_run(variables))
|
self.hass.async_create_task(
|
||||||
|
self.async_run(variables, context))
|
||||||
|
|
||||||
self._async_listener.append(async_track_template(
|
self._async_listener.append(async_track_template(
|
||||||
self.hass, wait_template, async_script_wait, variables))
|
self.hass, wait_template, async_script_wait, variables))
|
||||||
|
@ -145,7 +149,8 @@ class Script():
|
||||||
|
|
||||||
if CONF_TIMEOUT in action:
|
if CONF_TIMEOUT in action:
|
||||||
self._async_set_timeout(
|
self._async_set_timeout(
|
||||||
action, variables, action.get(CONF_CONTINUE, True))
|
action, variables, context,
|
||||||
|
action.get(CONF_CONTINUE, True))
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -154,10 +159,10 @@ class Script():
|
||||||
break
|
break
|
||||||
|
|
||||||
elif CONF_EVENT in action:
|
elif CONF_EVENT in action:
|
||||||
self._async_fire_event(action, variables)
|
self._async_fire_event(action, variables, context)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await self._async_call_service(action, variables)
|
await self._async_call_service(action, variables, context)
|
||||||
|
|
||||||
self._cur = -1
|
self._cur = -1
|
||||||
self.last_action = None
|
self.last_action = None
|
||||||
|
@ -178,7 +183,7 @@ class Script():
|
||||||
if self._change_listener:
|
if self._change_listener:
|
||||||
self.hass.async_add_job(self._change_listener)
|
self.hass.async_add_job(self._change_listener)
|
||||||
|
|
||||||
async def _async_call_service(self, action, variables):
|
async def _async_call_service(self, action, variables, context):
|
||||||
"""Call the service specified in the action.
|
"""Call the service specified in the action.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
|
@ -186,9 +191,14 @@ class Script():
|
||||||
self.last_action = action.get(CONF_ALIAS, 'call service')
|
self.last_action = action.get(CONF_ALIAS, 'call service')
|
||||||
self._log("Executing step %s" % self.last_action)
|
self._log("Executing step %s" % self.last_action)
|
||||||
await service.async_call_from_config(
|
await service.async_call_from_config(
|
||||||
self.hass, action, True, variables, validate_config=False)
|
self.hass, action,
|
||||||
|
blocking=True,
|
||||||
|
variables=variables,
|
||||||
|
validate_config=False,
|
||||||
|
context=context
|
||||||
|
)
|
||||||
|
|
||||||
def _async_fire_event(self, action, variables):
|
def _async_fire_event(self, action, variables, context):
|
||||||
"""Fire an event."""
|
"""Fire an event."""
|
||||||
self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT])
|
self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT])
|
||||||
self._log("Executing step %s" % self.last_action)
|
self._log("Executing step %s" % self.last_action)
|
||||||
|
@ -201,7 +211,7 @@ class Script():
|
||||||
_LOGGER.error('Error rendering event data template: %s', ex)
|
_LOGGER.error('Error rendering event data template: %s', ex)
|
||||||
|
|
||||||
self.hass.bus.async_fire(action[CONF_EVENT],
|
self.hass.bus.async_fire(action[CONF_EVENT],
|
||||||
event_data)
|
event_data, context=context)
|
||||||
|
|
||||||
def _async_check_condition(self, action, variables):
|
def _async_check_condition(self, action, variables):
|
||||||
"""Test if condition is matching."""
|
"""Test if condition is matching."""
|
||||||
|
@ -216,7 +226,8 @@ class Script():
|
||||||
self._log("Test condition {}: {}".format(self.last_action, check))
|
self._log("Test condition {}: {}".format(self.last_action, check))
|
||||||
return check
|
return check
|
||||||
|
|
||||||
def _async_set_timeout(self, action, variables, continue_on_timeout=True):
|
def _async_set_timeout(self, action, variables, context,
|
||||||
|
continue_on_timeout):
|
||||||
"""Schedule a timeout to abort or continue script."""
|
"""Schedule a timeout to abort or continue script."""
|
||||||
timeout = action[CONF_TIMEOUT]
|
timeout = action[CONF_TIMEOUT]
|
||||||
unsub = None
|
unsub = None
|
||||||
|
@ -229,7 +240,8 @@ class Script():
|
||||||
# Check if we want to continue to execute
|
# Check if we want to continue to execute
|
||||||
# the script after the timeout
|
# the script after the timeout
|
||||||
if continue_on_timeout:
|
if continue_on_timeout:
|
||||||
self.hass.async_add_job(self.async_run(variables))
|
self.hass.async_create_task(
|
||||||
|
self.async_run(variables, context))
|
||||||
else:
|
else:
|
||||||
self._log("Timeout reached, abort script.")
|
self._log("Timeout reached, abort script.")
|
||||||
self.async_stop()
|
self.async_stop()
|
||||||
|
|
|
@ -36,7 +36,7 @@ def call_from_config(hass, config, blocking=False, variables=None,
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
async def async_call_from_config(hass, config, blocking=False, variables=None,
|
async def async_call_from_config(hass, config, blocking=False, variables=None,
|
||||||
validate_config=True):
|
validate_config=True, context=None):
|
||||||
"""Call a service based on a config hash."""
|
"""Call a service based on a config hash."""
|
||||||
if validate_config:
|
if validate_config:
|
||||||
try:
|
try:
|
||||||
|
@ -77,7 +77,7 @@ async def async_call_from_config(hass, config, blocking=False, variables=None,
|
||||||
service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
|
service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
|
||||||
|
|
||||||
await hass.services.async_call(
|
await hass.services.async_call(
|
||||||
domain, service_name, service_data, blocking)
|
domain, service_name, service_data, blocking=blocking, context=context)
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""The tests for the Event automation."""
|
"""The tests for the Event automation."""
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import Context, callback
|
||||||
from homeassistant.setup import setup_component
|
from homeassistant.setup import setup_component
|
||||||
import homeassistant.components.automation as automation
|
import homeassistant.components.automation as automation
|
||||||
|
|
||||||
|
@ -31,6 +31,8 @@ class TestAutomationEvent(unittest.TestCase):
|
||||||
|
|
||||||
def test_if_fires_on_event(self):
|
def test_if_fires_on_event(self):
|
||||||
"""Test the firing of events."""
|
"""Test the firing of events."""
|
||||||
|
context = Context()
|
||||||
|
|
||||||
assert setup_component(self.hass, automation.DOMAIN, {
|
assert setup_component(self.hass, automation.DOMAIN, {
|
||||||
automation.DOMAIN: {
|
automation.DOMAIN: {
|
||||||
'trigger': {
|
'trigger': {
|
||||||
|
@ -43,9 +45,10 @@ class TestAutomationEvent(unittest.TestCase):
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
self.hass.bus.fire('test_event')
|
self.hass.bus.fire('test_event', context=context)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
self.assertEqual(1, len(self.calls))
|
self.assertEqual(1, len(self.calls))
|
||||||
|
assert self.calls[0].context is context
|
||||||
|
|
||||||
automation.turn_off(self.hass)
|
automation.turn_off(self.hass)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
|
@ -4,7 +4,7 @@ import unittest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import homeassistant.components.automation as automation
|
import homeassistant.components.automation as automation
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import Context, callback
|
||||||
from homeassistant.setup import setup_component
|
from homeassistant.setup import setup_component
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
|
|
||||||
|
@ -36,6 +36,7 @@ class TestAutomationNumericState(unittest.TestCase):
|
||||||
|
|
||||||
def test_if_fires_on_entity_change_below(self):
|
def test_if_fires_on_entity_change_below(self):
|
||||||
"""Test the firing with changed entity."""
|
"""Test the firing with changed entity."""
|
||||||
|
context = Context()
|
||||||
assert setup_component(self.hass, automation.DOMAIN, {
|
assert setup_component(self.hass, automation.DOMAIN, {
|
||||||
automation.DOMAIN: {
|
automation.DOMAIN: {
|
||||||
'trigger': {
|
'trigger': {
|
||||||
|
@ -49,9 +50,10 @@ class TestAutomationNumericState(unittest.TestCase):
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
# 9 is below 10
|
# 9 is below 10
|
||||||
self.hass.states.set('test.entity', 9)
|
self.hass.states.set('test.entity', 9, context=context)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
self.assertEqual(1, len(self.calls))
|
self.assertEqual(1, len(self.calls))
|
||||||
|
assert self.calls[0].context is context
|
||||||
|
|
||||||
# Set above 12 so the automation will fire again
|
# Set above 12 so the automation will fire again
|
||||||
self.hass.states.set('test.entity', 12)
|
self.hass.states.set('test.entity', 12)
|
||||||
|
@ -116,6 +118,7 @@ class TestAutomationNumericState(unittest.TestCase):
|
||||||
|
|
||||||
def test_if_not_fires_on_entity_change_below_to_below(self):
|
def test_if_not_fires_on_entity_change_below_to_below(self):
|
||||||
"""Test the firing with changed entity."""
|
"""Test the firing with changed entity."""
|
||||||
|
context = Context()
|
||||||
self.hass.states.set('test.entity', 11)
|
self.hass.states.set('test.entity', 11)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
|
@ -133,9 +136,10 @@ class TestAutomationNumericState(unittest.TestCase):
|
||||||
})
|
})
|
||||||
|
|
||||||
# 9 is below 10 so this should fire
|
# 9 is below 10 so this should fire
|
||||||
self.hass.states.set('test.entity', 9)
|
self.hass.states.set('test.entity', 9, context=context)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
self.assertEqual(1, len(self.calls))
|
self.assertEqual(1, len(self.calls))
|
||||||
|
assert self.calls[0].context is context
|
||||||
|
|
||||||
# already below so should not fire again
|
# already below so should not fire again
|
||||||
self.hass.states.set('test.entity', 5)
|
self.hass.states.set('test.entity', 5)
|
||||||
|
|
|
@ -4,7 +4,7 @@ from datetime import timedelta
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import Context, callback
|
||||||
from homeassistant.setup import setup_component
|
from homeassistant.setup import setup_component
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
import homeassistant.components.automation as automation
|
import homeassistant.components.automation as automation
|
||||||
|
@ -38,6 +38,7 @@ class TestAutomationState(unittest.TestCase):
|
||||||
|
|
||||||
def test_if_fires_on_entity_change(self):
|
def test_if_fires_on_entity_change(self):
|
||||||
"""Test for firing on entity change."""
|
"""Test for firing on entity change."""
|
||||||
|
context = Context()
|
||||||
self.hass.states.set('test.entity', 'hello')
|
self.hass.states.set('test.entity', 'hello')
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
|
@ -59,9 +60,10 @@ class TestAutomationState(unittest.TestCase):
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
self.hass.states.set('test.entity', 'world')
|
self.hass.states.set('test.entity', 'world', context=context)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
self.assertEqual(1, len(self.calls))
|
self.assertEqual(1, len(self.calls))
|
||||||
|
assert self.calls[0].context is context
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
'state - test.entity - hello - world - None',
|
'state - test.entity - hello - world - None',
|
||||||
self.calls[0].data['some'])
|
self.calls[0].data['some'])
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""The tests for the Template automation."""
|
"""The tests for the Template automation."""
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import Context, callback
|
||||||
from homeassistant.setup import setup_component
|
from homeassistant.setup import setup_component
|
||||||
import homeassistant.components.automation as automation
|
import homeassistant.components.automation as automation
|
||||||
|
|
||||||
|
@ -232,15 +232,12 @@ class TestAutomationTemplate(unittest.TestCase):
|
||||||
|
|
||||||
def test_if_fires_on_change_with_template_advanced(self):
|
def test_if_fires_on_change_with_template_advanced(self):
|
||||||
"""Test for firing on change with template advanced."""
|
"""Test for firing on change with template advanced."""
|
||||||
|
context = Context()
|
||||||
assert setup_component(self.hass, automation.DOMAIN, {
|
assert setup_component(self.hass, automation.DOMAIN, {
|
||||||
automation.DOMAIN: {
|
automation.DOMAIN: {
|
||||||
'trigger': {
|
'trigger': {
|
||||||
'platform': 'template',
|
'platform': 'template',
|
||||||
'value_template': '''{%- if is_state("test.entity", "world") -%}
|
'value_template': '{{ is_state("test.entity", "world") }}'
|
||||||
true
|
|
||||||
{%- else -%}
|
|
||||||
false
|
|
||||||
{%- endif -%}''',
|
|
||||||
},
|
},
|
||||||
'action': {
|
'action': {
|
||||||
'service': 'test.automation',
|
'service': 'test.automation',
|
||||||
|
@ -257,9 +254,10 @@ class TestAutomationTemplate(unittest.TestCase):
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
self.calls = []
|
self.calls = []
|
||||||
|
|
||||||
self.hass.states.set('test.entity', 'world')
|
self.hass.states.set('test.entity', 'world', context=context)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
self.assertEqual(1, len(self.calls))
|
self.assertEqual(1, len(self.calls))
|
||||||
|
assert self.calls[0].context is context
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
'template - test.entity - hello - world',
|
'template - test.entity - hello - world',
|
||||||
self.calls[0].data['some'])
|
self.calls[0].data['some'])
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""The tests for the location automation."""
|
"""The tests for the location automation."""
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import Context, callback
|
||||||
from homeassistant.setup import setup_component
|
from homeassistant.setup import setup_component
|
||||||
from homeassistant.components import automation, zone
|
from homeassistant.components import automation, zone
|
||||||
|
|
||||||
|
@ -40,6 +40,7 @@ class TestAutomationZone(unittest.TestCase):
|
||||||
|
|
||||||
def test_if_fires_on_zone_enter(self):
|
def test_if_fires_on_zone_enter(self):
|
||||||
"""Test for firing on zone enter."""
|
"""Test for firing on zone enter."""
|
||||||
|
context = Context()
|
||||||
self.hass.states.set('test.entity', 'hello', {
|
self.hass.states.set('test.entity', 'hello', {
|
||||||
'latitude': 32.881011,
|
'latitude': 32.881011,
|
||||||
'longitude': -117.234758
|
'longitude': -117.234758
|
||||||
|
@ -70,10 +71,11 @@ class TestAutomationZone(unittest.TestCase):
|
||||||
self.hass.states.set('test.entity', 'hello', {
|
self.hass.states.set('test.entity', 'hello', {
|
||||||
'latitude': 32.880586,
|
'latitude': 32.880586,
|
||||||
'longitude': -117.237564
|
'longitude': -117.237564
|
||||||
})
|
}, context=context)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(1, len(self.calls))
|
self.assertEqual(1, len(self.calls))
|
||||||
|
assert self.calls[0].context is context
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
'zone - test.entity - hello - hello - test',
|
'zone - test.entity - hello - hello - test',
|
||||||
self.calls[0].data['some'])
|
self.calls[0].data['some'])
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import Context, callback
|
||||||
from homeassistant.setup import setup_component
|
from homeassistant.setup import setup_component
|
||||||
from homeassistant.components import script
|
from homeassistant.components import script
|
||||||
|
|
||||||
|
@ -134,6 +134,7 @@ class TestScriptComponent(unittest.TestCase):
|
||||||
def test_passing_variables(self):
|
def test_passing_variables(self):
|
||||||
"""Test different ways of passing in variables."""
|
"""Test different ways of passing in variables."""
|
||||||
calls = []
|
calls = []
|
||||||
|
context = Context()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def record_call(service):
|
def record_call(service):
|
||||||
|
@ -157,21 +158,23 @@ class TestScriptComponent(unittest.TestCase):
|
||||||
|
|
||||||
script.turn_on(self.hass, ENTITY_ID, {
|
script.turn_on(self.hass, ENTITY_ID, {
|
||||||
'greeting': 'world'
|
'greeting': 'world'
|
||||||
})
|
}, context=context)
|
||||||
|
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
assert calls[-1].data['hello'] == 'world'
|
assert calls[0].context is context
|
||||||
|
assert calls[0].data['hello'] == 'world'
|
||||||
|
|
||||||
self.hass.services.call('script', 'test', {
|
self.hass.services.call('script', 'test', {
|
||||||
'greeting': 'universe',
|
'greeting': 'universe',
|
||||||
})
|
}, context=context)
|
||||||
|
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
assert len(calls) == 2
|
assert len(calls) == 2
|
||||||
assert calls[-1].data['hello'] == 'universe'
|
assert calls[1].context is context
|
||||||
|
assert calls[1].data['hello'] == 'universe'
|
||||||
|
|
||||||
def test_reload_service(self):
|
def test_reload_service(self):
|
||||||
"""Verify that the turn_on service."""
|
"""Verify that the turn_on service."""
|
||||||
|
|
|
@ -4,7 +4,7 @@ from datetime import timedelta
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import Context, callback
|
||||||
# Otherwise can't test just this file (import order issue)
|
# Otherwise can't test just this file (import order issue)
|
||||||
import homeassistant.components # noqa
|
import homeassistant.components # noqa
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
|
@ -32,6 +32,7 @@ class TestScriptHelper(unittest.TestCase):
|
||||||
def test_firing_event(self):
|
def test_firing_event(self):
|
||||||
"""Test the firing of events."""
|
"""Test the firing of events."""
|
||||||
event = 'test_event'
|
event = 'test_event'
|
||||||
|
context = Context()
|
||||||
calls = []
|
calls = []
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -48,17 +49,19 @@ class TestScriptHelper(unittest.TestCase):
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
script_obj.run()
|
script_obj.run(context=context)
|
||||||
|
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
|
assert calls[0].context is context
|
||||||
assert calls[0].data.get('hello') == 'world'
|
assert calls[0].data.get('hello') == 'world'
|
||||||
assert not script_obj.can_cancel
|
assert not script_obj.can_cancel
|
||||||
|
|
||||||
def test_firing_event_template(self):
|
def test_firing_event_template(self):
|
||||||
"""Test the firing of events."""
|
"""Test the firing of events."""
|
||||||
event = 'test_event'
|
event = 'test_event'
|
||||||
|
context = Context()
|
||||||
calls = []
|
calls = []
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -82,11 +85,12 @@ class TestScriptHelper(unittest.TestCase):
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
script_obj.run({'is_world': 'yes'})
|
script_obj.run({'is_world': 'yes'}, context=context)
|
||||||
|
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
|
assert calls[0].context is context
|
||||||
assert calls[0].data == {
|
assert calls[0].data == {
|
||||||
'dict': {
|
'dict': {
|
||||||
1: 'yes',
|
1: 'yes',
|
||||||
|
@ -100,6 +104,7 @@ class TestScriptHelper(unittest.TestCase):
|
||||||
def test_calling_service(self):
|
def test_calling_service(self):
|
||||||
"""Test the calling of a service."""
|
"""Test the calling of a service."""
|
||||||
calls = []
|
calls = []
|
||||||
|
context = Context()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def record_call(service):
|
def record_call(service):
|
||||||
|
@ -113,16 +118,18 @@ class TestScriptHelper(unittest.TestCase):
|
||||||
'data': {
|
'data': {
|
||||||
'hello': 'world'
|
'hello': 'world'
|
||||||
}
|
}
|
||||||
})
|
}, context=context)
|
||||||
|
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
|
assert calls[0].context is context
|
||||||
assert calls[0].data.get('hello') == 'world'
|
assert calls[0].data.get('hello') == 'world'
|
||||||
|
|
||||||
def test_calling_service_template(self):
|
def test_calling_service_template(self):
|
||||||
"""Test the calling of a service."""
|
"""Test the calling of a service."""
|
||||||
calls = []
|
calls = []
|
||||||
|
context = Context()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def record_call(service):
|
def record_call(service):
|
||||||
|
@ -147,17 +154,19 @@ class TestScriptHelper(unittest.TestCase):
|
||||||
{% endif %}
|
{% endif %}
|
||||||
"""
|
"""
|
||||||
}
|
}
|
||||||
}, {'is_world': 'yes'})
|
}, {'is_world': 'yes'}, context=context)
|
||||||
|
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
assert len(calls) == 1
|
assert len(calls) == 1
|
||||||
|
assert calls[0].context is context
|
||||||
assert calls[0].data.get('hello') == 'world'
|
assert calls[0].data.get('hello') == 'world'
|
||||||
|
|
||||||
def test_delay(self):
|
def test_delay(self):
|
||||||
"""Test the delay."""
|
"""Test the delay."""
|
||||||
event = 'test_event'
|
event = 'test_event'
|
||||||
events = []
|
events = []
|
||||||
|
context = Context()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def record_event(event):
|
def record_event(event):
|
||||||
|
@ -171,7 +180,7 @@ class TestScriptHelper(unittest.TestCase):
|
||||||
{'delay': {'seconds': 5}},
|
{'delay': {'seconds': 5}},
|
||||||
{'event': event}]))
|
{'event': event}]))
|
||||||
|
|
||||||
script_obj.run()
|
script_obj.run(context=context)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
assert script_obj.is_running
|
assert script_obj.is_running
|
||||||
|
@ -185,6 +194,8 @@ class TestScriptHelper(unittest.TestCase):
|
||||||
|
|
||||||
assert not script_obj.is_running
|
assert not script_obj.is_running
|
||||||
assert len(events) == 2
|
assert len(events) == 2
|
||||||
|
assert events[0].context is context
|
||||||
|
assert events[1].context is context
|
||||||
|
|
||||||
def test_delay_template(self):
|
def test_delay_template(self):
|
||||||
"""Test the delay as a template."""
|
"""Test the delay as a template."""
|
||||||
|
@ -282,6 +293,7 @@ class TestScriptHelper(unittest.TestCase):
|
||||||
"""Test the wait template."""
|
"""Test the wait template."""
|
||||||
event = 'test_event'
|
event = 'test_event'
|
||||||
events = []
|
events = []
|
||||||
|
context = Context()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def record_event(event):
|
def record_event(event):
|
||||||
|
@ -297,7 +309,7 @@ class TestScriptHelper(unittest.TestCase):
|
||||||
{'wait_template': "{{states.switch.test.state == 'off'}}"},
|
{'wait_template': "{{states.switch.test.state == 'off'}}"},
|
||||||
{'event': event}]))
|
{'event': event}]))
|
||||||
|
|
||||||
script_obj.run()
|
script_obj.run(context=context)
|
||||||
self.hass.block_till_done()
|
self.hass.block_till_done()
|
||||||
|
|
||||||
assert script_obj.is_running
|
assert script_obj.is_running
|
||||||
|
@ -310,6 +322,8 @@ class TestScriptHelper(unittest.TestCase):
|
||||||
|
|
||||||
assert not script_obj.is_running
|
assert not script_obj.is_running
|
||||||
assert len(events) == 2
|
assert len(events) == 2
|
||||||
|
assert events[0].context is context
|
||||||
|
assert events[1].context is context
|
||||||
|
|
||||||
def test_wait_template_cancel(self):
|
def test_wait_template_cancel(self):
|
||||||
"""Test the wait template cancel action."""
|
"""Test the wait template cancel action."""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue