Add support for simultaneous runs of Script helper (#31937)

* Add tests for legacy Script helper behavior

* Add Script helper if_running and run_mode options

- if_running controls what happens if Script run while previous run
  has not completed. Can be:
  - error: Raise an exception
  - ignore: Return without doing anything (previous run continues as-is)
  - parallel: Start run in new task
  - restart: Stop previous run before starting new run
- run_mode controls when call to async_run will return. Can be:
  - background: Returns immediately
  - legacy: Implements previous behavior, which is to return when done,
            or when suspended by delay or wait_template
  - blocking: Returns when run has completed
- If neither is specified, default is run_mode=legacy (and if_running
  is not used.) Otherwise, defaults are if_running=parallel and
  run_mode=background. If run_mode is set to legacy then if_running must
  be None.
- Caller may supply a logger which will be used throughout instead of
  default module logger.
- Move Script running state into new helper classes, comprised of an
  abstract base class and two concrete clases, one for legacy behavior
  and one for new behavior.
- Remove some non-async methods, as well as call_from_config which has
  only been used in tests.
- Adjust tests accordingly.

* Change per review

- Change run_mode default from background to blocking.
- Make sure change listener is called, even when there's an unexpected
  exception.
- Make _ScriptRun.async_stop more graceful by using an asyncio.Event for
  signaling instead of simply cancelling Task.
- Subclass _ScriptRun for background & blocking behavior.

Also:

- Fix timeouts in _ScriptRun by converting timedeltas to float seconds.
- General cleanup.

* Change per review 2

- Don't propagate exceptions if call from user has already returned
  (i.e., for background runs or legacy runs that have suspended.)
- Allow user to specify if exceptions should be logged. They will still
  be logged regardless if exception is not propagated.
- Rename _start_script_delay and _start_wait_template_delay for
  clarity.
- Remove return value from Script.async_run.
- Fix missing await.
- Change call to self.is_running in Script.async_run to direct test of
  self._runs.

* Change per review 3 and add tests

- Remove Script.set_logger().
- Enhance existing tests to check all run modes.
- Add tests for new features.
- Fix a few minor bugs found by tests.
This commit is contained in:
Phil Bruckner 2020-02-24 16:56:00 -06:00 committed by GitHub
parent 309989be89
commit b2d7bc40dc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 1774 additions and 881 deletions

View file

@ -393,10 +393,8 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
try:
await self.action_script.async_run(variables, trigger_context)
except Exception as err: # pylint: disable=broad-except
self.action_script.async_log_exception(
_LOGGER, f"Error while executing automation {self.entity_id}", err
)
except Exception: # pylint: disable=broad-except
pass
self._last_triggered = utcnow()
await self.async_update_ha_state()
@ -504,7 +502,9 @@ async def _async_process_config(hass, config, component):
hidden = config_block[CONF_HIDE_ENTITY]
initial_state = config_block.get(CONF_INITIAL_STATE)
action_script = script.Script(hass, config_block.get(CONF_ACTION, {}), name)
action_script = script.Script(
hass, config_block.get(CONF_ACTION, {}), name, logger=_LOGGER
)
if CONF_CONDITION in config_block:
cond_func = await _async_process_if(hass, config, config_block)

View file

@ -242,7 +242,9 @@ class ScriptEntity(ToggleEntity):
self.object_id = object_id
self.icon = icon
self.entity_id = ENTITY_ID_FORMAT.format(object_id)
self.script = Script(hass, sequence, name, self.async_update_ha_state)
self.script = Script(
hass, sequence, name, self.async_update_ha_state, logger=_LOGGER
)
@property
def should_poll(self):
@ -279,22 +281,15 @@ class ScriptEntity(ToggleEntity):
{ATTR_NAME: self.script.name, ATTR_ENTITY_ID: self.entity_id},
context=context,
)
try:
await self.script.async_run(kwargs.get(ATTR_VARIABLES), context)
except Exception as err:
self.script.async_log_exception(
_LOGGER, f"Error executing script {self.entity_id}", err
)
raise err
await self.script.async_run(kwargs.get(ATTR_VARIABLES), context)
async def async_turn_off(self, **kwargs):
"""Turn script off."""
self.script.async_stop()
await self.script.async_stop()
async def async_will_remove_from_hass(self):
"""Stop script and remove service when it will be removed from Home Assistant."""
if self.script.is_running:
self.script.async_stop()
await self.script.async_stop()
# remove service
self.hass.services.async_remove(DOMAIN, self.object_id)

View file

@ -1,10 +1,11 @@
"""Helpers to execute scripts."""
from abc import ABC, abstractmethod
import asyncio
from contextlib import suppress
from datetime import datetime
from itertools import islice
import logging
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast
import voluptuous as vol
@ -31,13 +32,10 @@ from homeassistant.helpers.event import (
async_track_template,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.util.async_ import run_callback_threadsafe
import homeassistant.util.dt as date_util
from homeassistant.util.dt import utcnow
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__)
CONF_ALIAS = "alias"
CONF_SERVICE = "service"
CONF_SERVICE_DATA = "data"
@ -50,7 +48,6 @@ CONF_WAIT_TEMPLATE = "wait_template"
CONF_CONTINUE = "continue_on_timeout"
CONF_SCENE = "scene"
ACTION_DELAY = "delay"
ACTION_WAIT_TEMPLATE = "wait_template"
ACTION_CHECK_CONDITION = "condition"
@ -59,6 +56,31 @@ ACTION_CALL_SERVICE = "call_service"
ACTION_DEVICE_AUTOMATION = "device"
ACTION_ACTIVATE_SCENE = "scene"
IF_RUNNING_ERROR = "error"
IF_RUNNING_IGNORE = "ignore"
IF_RUNNING_PARALLEL = "parallel"
IF_RUNNING_RESTART = "restart"
# First choice is default
IF_RUNNING_CHOICES = [
IF_RUNNING_PARALLEL,
IF_RUNNING_ERROR,
IF_RUNNING_IGNORE,
IF_RUNNING_RESTART,
]
RUN_MODE_BACKGROUND = "background"
RUN_MODE_BLOCKING = "blocking"
RUN_MODE_LEGACY = "legacy"
# First choice is default
RUN_MODE_CHOICES = [
RUN_MODE_BLOCKING,
RUN_MODE_BACKGROUND,
RUN_MODE_LEGACY,
]
_LOG_EXCEPTION = logging.ERROR + 1
_TIMEOUT_MSG = "Timeout reached, abort script."
def _determine_action(action):
"""Determine action type."""
@ -83,16 +105,6 @@ def _determine_action(action):
return ACTION_CALL_SERVICE
def call_from_config(
hass: HomeAssistant,
config: ConfigType,
variables: Optional[Sequence] = None,
context: Optional[Context] = None,
) -> None:
"""Call a script based on a config entry."""
Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables, context)
async def async_validate_action_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
@ -121,6 +133,446 @@ class _SuspendScript(Exception):
"""Throw if script needs to suspend."""
class _ScriptRunBase(ABC):
"""Common data & methods for managing Script sequence run."""
def __init__(
self,
hass: HomeAssistant,
script: "Script",
variables: Optional[Sequence],
context: Optional[Context],
log_exceptions: bool,
) -> None:
self._hass = hass
self._script = script
self._variables = variables
self._context = context
self._log_exceptions = log_exceptions
self._step = -1
self._action: Optional[Dict[str, Any]] = None
def _changed(self):
self._script._changed() # pylint: disable=protected-access
@property
def _config_cache(self):
return self._script._config_cache # pylint: disable=protected-access
@abstractmethod
async def async_run(self) -> None:
"""Run script."""
async def _async_step(self, log_exceptions):
try:
await getattr(self, f"_async_{_determine_action(self._action)}_step")()
except Exception as err:
if not isinstance(err, (_SuspendScript, _StopScript)) and (
self._log_exceptions or log_exceptions
):
self._log_exception(err)
raise
@abstractmethod
async def async_stop(self) -> None:
"""Stop script run."""
def _log_exception(self, exception):
action_type = _determine_action(self._action)
error = str(exception)
level = logging.ERROR
if isinstance(exception, vol.Invalid):
error_desc = "Invalid data"
elif isinstance(exception, exceptions.TemplateError):
error_desc = "Error rendering template"
elif isinstance(exception, exceptions.Unauthorized):
error_desc = "Unauthorized"
elif isinstance(exception, exceptions.ServiceNotFound):
error_desc = "Service not found"
else:
error_desc = "Unexpected error"
level = _LOG_EXCEPTION
self._log(
"Error executing script. %s for %s at pos %s: %s",
error_desc,
action_type,
self._step + 1,
error,
level=level,
)
@abstractmethod
async def _async_delay_step(self):
"""Handle delay."""
def _prep_delay_step(self):
try:
delay = vol.All(cv.time_period, cv.positive_timedelta)(
template.render_complex(self._action[CONF_DELAY], self._variables)
)
except (exceptions.TemplateError, vol.Invalid) as ex:
self._raise(
"Error rendering %s delay template: %s",
self._script.name,
ex,
exception=_StopScript,
)
self._script.last_action = self._action.get(CONF_ALIAS, f"delay {delay}")
self._log("Executing step %s", self._script.last_action)
return delay
@abstractmethod
async def _async_wait_template_step(self):
"""Handle a wait template."""
def _prep_wait_template_step(self, async_script_wait):
wait_template = self._action[CONF_WAIT_TEMPLATE]
wait_template.hass = self._hass
self._script.last_action = self._action.get(CONF_ALIAS, "wait template")
self._log("Executing step %s", self._script.last_action)
# check if condition already okay
if condition.async_template(self._hass, wait_template, self._variables):
return None
return async_track_template(
self._hass, wait_template, async_script_wait, self._variables
)
async def _async_call_service_step(self):
"""Call the service specified in the action."""
self._script.last_action = self._action.get(CONF_ALIAS, "call service")
self._log("Executing step %s", self._script.last_action)
await service.async_call_from_config(
self._hass,
self._action,
blocking=True,
variables=self._variables,
validate_config=False,
context=self._context,
)
async def _async_device_step(self):
"""Perform the device automation specified in the action."""
self._script.last_action = self._action.get(CONF_ALIAS, "device automation")
self._log("Executing step %s", self._script.last_action)
platform = await device_automation.async_get_device_automation_platform(
self._hass, self._action[CONF_DOMAIN], "action"
)
await platform.async_call_action_from_config(
self._hass, self._action, self._variables, self._context
)
async def _async_scene_step(self):
"""Activate the scene specified in the action."""
self._script.last_action = self._action.get(CONF_ALIAS, "activate scene")
self._log("Executing step %s", self._script.last_action)
await self._hass.services.async_call(
scene.DOMAIN,
SERVICE_TURN_ON,
{ATTR_ENTITY_ID: self._action[CONF_SCENE]},
blocking=True,
context=self._context,
)
async def _async_event_step(self):
"""Fire an event."""
self._script.last_action = self._action.get(
CONF_ALIAS, self._action[CONF_EVENT]
)
self._log("Executing step %s", self._script.last_action)
event_data = dict(self._action.get(CONF_EVENT_DATA, {}))
if CONF_EVENT_DATA_TEMPLATE in self._action:
try:
event_data.update(
template.render_complex(
self._action[CONF_EVENT_DATA_TEMPLATE], self._variables
)
)
except exceptions.TemplateError as ex:
self._log(
"Error rendering event data template: %s", ex, level=logging.ERROR
)
self._hass.bus.async_fire(
self._action[CONF_EVENT], event_data, context=self._context
)
async def _async_condition_step(self):
"""Test if condition is matching."""
config_cache_key = frozenset((k, str(v)) for k, v in self._action.items())
config = self._config_cache.get(config_cache_key)
if not config:
config = await condition.async_from_config(self._hass, self._action, False)
self._config_cache[config_cache_key] = config
self._script.last_action = self._action.get(
CONF_ALIAS, self._action[CONF_CONDITION]
)
check = config(self._hass, self._variables)
self._log("Test condition %s: %s", self._script.last_action, check)
if not check:
raise _StopScript
def _log(self, msg, *args, level=logging.INFO):
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
def _raise(self, msg, *args, exception=None):
# pylint: disable=protected-access
self._script._raise(msg, *args, exception=exception)
class _ScriptRun(_ScriptRunBase):
"""Manage Script sequence run."""
def __init__(
self,
hass: HomeAssistant,
script: "Script",
variables: Optional[Sequence],
context: Optional[Context],
log_exceptions: bool,
) -> None:
super().__init__(hass, script, variables, context, log_exceptions)
self._stop = asyncio.Event()
self._stopped = asyncio.Event()
async def _async_run(self, propagate_exceptions=True):
self._log("Running script")
try:
for self._step, self._action in enumerate(self._script.sequence):
if self._stop.is_set():
break
await self._async_step(not propagate_exceptions)
except _StopScript:
pass
except Exception: # pylint: disable=broad-except
if propagate_exceptions:
raise
finally:
if not self._stop.is_set():
self._changed()
self._script.last_action = None
self._script._runs.remove(self) # pylint: disable=protected-access
self._stopped.set()
async def async_stop(self) -> None:
"""Stop script run."""
self._stop.set()
await self._stopped.wait()
async def _async_delay_step(self):
"""Handle delay."""
timeout = self._prep_delay_step().total_seconds()
if not self._stop.is_set():
self._changed()
await asyncio.wait({self._stop.wait()}, timeout=timeout)
async def _async_wait_template_step(self):
"""Handle a wait template."""
@callback
def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true."""
done.set()
unsub = self._prep_wait_template_step(async_script_wait)
if not unsub:
return
if not self._stop.is_set():
self._changed()
try:
timeout = self._action[CONF_TIMEOUT].total_seconds()
except KeyError:
timeout = None
done = asyncio.Event()
try:
await asyncio.wait_for(
asyncio.wait(
{self._stop.wait(), done.wait()},
return_when=asyncio.FIRST_COMPLETED,
),
timeout,
)
except asyncio.TimeoutError:
if not self._action.get(CONF_CONTINUE, True):
self._log(_TIMEOUT_MSG)
raise _StopScript
finally:
unsub()
class _BackgroundScriptRun(_ScriptRun):
"""Manage background Script sequence run."""
async def async_run(self) -> None:
"""Run script."""
self._hass.async_create_task(self._async_run(False))
class _BlockingScriptRun(_ScriptRun):
"""Manage blocking Script sequence run."""
async def async_run(self) -> None:
"""Run script."""
try:
await asyncio.shield(self._async_run())
except asyncio.CancelledError:
await self.async_stop()
raise
class _LegacyScriptRun(_ScriptRunBase):
"""Manage legacy Script sequence run."""
def __init__(
self,
hass: HomeAssistant,
script: "Script",
variables: Optional[Sequence],
context: Optional[Context],
log_exceptions: bool,
shared: Optional["_LegacyScriptRun"],
) -> None:
super().__init__(hass, script, variables, context, log_exceptions)
if shared:
self._shared = shared
else:
# To implement legacy behavior we need to share the following "run state"
# amongst all runs, so it will only exist in the first instantiation of
# concurrent runs, and the rest will use it, too.
self._current = -1
self._async_listeners: List[CALLBACK_TYPE] = []
self._shared = self
@property
def _cur(self):
return self._shared._current # pylint: disable=protected-access
@_cur.setter
def _cur(self, value):
self._shared._current = value # pylint: disable=protected-access
@property
def _async_listener(self):
return self._shared._async_listeners # pylint: disable=protected-access
async def async_run(self) -> None:
"""Run script."""
await self._async_run()
async def _async_run(self, propagate_exceptions=True):
if self._cur == -1:
self._log("Running script")
self._cur = 0
# Unregister callback if we were in a delay or wait but turn on is
# called again. In that case we just continue execution.
self._async_remove_listener()
suspended = False
try:
for self._step, self._action in islice(
enumerate(self._script.sequence), self._cur, None
):
await self._async_step(not propagate_exceptions)
except _StopScript:
pass
except _SuspendScript:
# Store next step to take and notify change listeners
self._cur = self._step + 1
suspended = True
return
except Exception: # pylint: disable=broad-except
if propagate_exceptions:
raise
finally:
if self._cur != -1:
self._changed()
if not suspended:
self._script.last_action = None
await self.async_stop()
async def async_stop(self) -> None:
"""Stop script run."""
if self._cur == -1:
return
self._cur = -1
self._async_remove_listener()
self._script._runs.clear() # pylint: disable=protected-access
async def _async_delay_step(self):
"""Handle delay."""
delay = self._prep_delay_step()
@callback
def async_script_delay(now):
"""Handle delay."""
with suppress(ValueError):
self._async_listener.remove(unsub)
self._hass.async_create_task(self._async_run(False))
unsub = async_track_point_in_utc_time(
self._hass, async_script_delay, utcnow() + delay
)
self._async_listener.append(unsub)
raise _SuspendScript
async def _async_wait_template_step(self):
"""Handle a wait template."""
@callback
def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true."""
self._async_remove_listener()
self._hass.async_create_task(self._async_run(False))
@callback
def async_script_timeout(now):
"""Call after timeout is retrieve."""
with suppress(ValueError):
self._async_listener.remove(unsub)
# Check if we want to continue to execute
# the script after the timeout
if self._action.get(CONF_CONTINUE, True):
self._hass.async_create_task(self._async_run(False))
else:
self._log(_TIMEOUT_MSG)
self._hass.async_create_task(self.async_stop())
unsub_wait = self._prep_wait_template_step(async_script_wait)
if not unsub_wait:
return
self._async_listener.append(unsub_wait)
if CONF_TIMEOUT in self._action:
unsub = async_track_point_in_utc_time(
self._hass, async_script_timeout, utcnow() + self._action[CONF_TIMEOUT]
)
self._async_listener.append(unsub)
raise _SuspendScript
def _async_remove_listener(self):
"""Remove listeners, if any."""
for unsub in self._async_listener:
unsub()
self._async_listener.clear()
class Script:
"""Representation of a script."""
@ -130,39 +582,46 @@ class Script:
sequence: Sequence[Dict[str, Any]],
name: Optional[str] = None,
change_listener: Optional[Callable[..., Any]] = None,
if_running: Optional[str] = None,
run_mode: Optional[str] = None,
logger: Optional[logging.Logger] = None,
log_exceptions: bool = True,
) -> None:
"""Initialize the script."""
self.hass = hass
self._logger = logger or logging.getLogger(__name__)
self._hass = hass
self.sequence = sequence
template.attach(hass, self.sequence)
self.name = name
self._change_listener = change_listener
self._cur = -1
self._exception_step: Optional[int] = None
self.last_action = None
self.last_triggered: Optional[datetime] = None
self.can_cancel = any(
CONF_DELAY in action or CONF_WAIT_TEMPLATE in action
for action in self.sequence
)
self._async_listener: List[CALLBACK_TYPE] = []
if not if_running and not run_mode:
self._if_running = IF_RUNNING_PARALLEL
self._run_mode = RUN_MODE_LEGACY
elif if_running and run_mode == RUN_MODE_LEGACY:
self._raise('Cannot use if_running if run_mode is "legacy"')
else:
self._if_running = if_running or IF_RUNNING_CHOICES[0]
self._run_mode = run_mode or RUN_MODE_CHOICES[0]
self._runs: List[_ScriptRunBase] = []
self._log_exceptions = log_exceptions
self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {}
self._actions = {
ACTION_DELAY: self._async_delay,
ACTION_WAIT_TEMPLATE: self._async_wait_template,
ACTION_CHECK_CONDITION: self._async_check_condition,
ACTION_FIRE_EVENT: self._async_fire_event,
ACTION_CALL_SERVICE: self._async_call_service,
ACTION_DEVICE_AUTOMATION: self._async_device_automation,
ACTION_ACTIVATE_SCENE: self._async_activate_scene,
}
self._referenced_entities: Optional[Set[str]] = None
self._referenced_devices: Optional[Set[str]] = None
def _changed(self):
if self._change_listener:
self._hass.async_add_job(self._change_listener)
@property
def is_running(self) -> bool:
"""Return true if script is on."""
return self._cur != -1
return len(self._runs) > 0
@property
def referenced_devices(self):
@ -223,288 +682,62 @@ class Script:
def run(self, variables=None, context=None):
"""Run script."""
asyncio.run_coroutine_threadsafe(
self.async_run(variables, context), self.hass.loop
self.async_run(variables, context), self._hass.loop
).result()
async def async_run(
self, variables: Optional[Sequence] = None, context: Optional[Context] = None
) -> None:
"""Run script.
This method is a coroutine.
"""
self.last_triggered = date_util.utcnow()
if self._cur == -1:
self._log("Running script")
self._cur = 0
# Unregister callback if we were in a delay or wait but turn on is
# called again. In that case we just continue execution.
self._async_remove_listener()
for cur, action in islice(enumerate(self.sequence), self._cur, None):
try:
await self._handle_action(action, variables, context)
except _SuspendScript:
# Store next step to take and notify change listeners
self._cur = cur + 1
if self._change_listener:
self.hass.async_add_job(self._change_listener)
"""Run script."""
if self.is_running:
if self._if_running == IF_RUNNING_IGNORE:
self._log("Skipping script")
return
except _StopScript:
break
except Exception:
# Store the step that had an exception
self._exception_step = cur
# Set script to not running
self._cur = -1
self.last_action = None
# Pass exception on.
raise
# Set script to not-running.
self._cur = -1
self.last_action = None
if self._change_listener:
self.hass.async_add_job(self._change_listener)
if self._if_running == IF_RUNNING_ERROR:
self._raise("Already running")
if self._if_running == IF_RUNNING_RESTART:
self._log("Restarting script")
await self.async_stop()
def stop(self) -> None:
"""Stop running script."""
run_callback_threadsafe(self.hass.loop, self.async_stop).result()
@callback
def async_stop(self) -> None:
"""Stop running script."""
if self._cur == -1:
return
self._cur = -1
self._async_remove_listener()
if self._change_listener:
self.hass.async_add_job(self._change_listener)
@callback
def async_log_exception(self, logger, message_base, exception):
"""Log an exception for this script.
Should only be called on exceptions raised by this scripts async_run.
"""
step = self._exception_step
action = self.sequence[step]
action_type = _determine_action(action)
error = None
meth = logger.error
if isinstance(exception, vol.Invalid):
error_desc = "Invalid data"
elif isinstance(exception, exceptions.TemplateError):
error_desc = "Error rendering template"
elif isinstance(exception, exceptions.Unauthorized):
error_desc = "Unauthorized"
elif isinstance(exception, exceptions.ServiceNotFound):
error_desc = "Service not found"
else:
# Print the full stack trace, unknown error
error_desc = "Unknown error"
meth = logger.exception
error = ""
if error is None:
error = str(exception)
meth(
"%s. %s for %s at pos %s: %s",
message_base,
error_desc,
action_type,
step + 1,
error,
)
async def _handle_action(self, action, variables, context):
"""Handle an action."""
await self._actions[_determine_action(action)](action, variables, context)
async def _async_delay(self, action, variables, context):
"""Handle delay."""
# Call ourselves in the future to continue work
unsub = None
@callback
def async_script_delay(now):
"""Handle delay."""
with suppress(ValueError):
self._async_listener.remove(unsub)
self.hass.async_create_task(self.async_run(variables, context))
delay = action[CONF_DELAY]
try:
if isinstance(delay, template.Template):
delay = vol.All(cv.time_period, cv.positive_timedelta)(
delay.async_render(variables)
)
elif isinstance(delay, dict):
delay_data = {}
delay_data.update(template.render_complex(delay, variables))
delay = cv.time_period(delay_data)
except (exceptions.TemplateError, vol.Invalid) as ex:
_LOGGER.error("Error rendering '%s' delay template: %s", self.name, ex)
raise _StopScript
self.last_action = action.get(CONF_ALIAS, f"delay {delay}")
self._log("Executing step %s" % self.last_action)
unsub = async_track_point_in_utc_time(
self.hass, async_script_delay, date_util.utcnow() + delay
)
self._async_listener.append(unsub)
raise _SuspendScript
async def _async_wait_template(self, action, variables, context):
"""Handle a wait template."""
# Call ourselves in the future to continue work
wait_template = action[CONF_WAIT_TEMPLATE]
wait_template.hass = self.hass
self.last_action = action.get(CONF_ALIAS, "wait template")
self._log("Executing step %s" % self.last_action)
# check if condition already okay
if condition.async_template(self.hass, wait_template, variables):
return
@callback
def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true."""
self._async_remove_listener()
self.hass.async_create_task(self.async_run(variables, context))
self._async_listener.append(
async_track_template(self.hass, wait_template, async_script_wait, variables)
)
if CONF_TIMEOUT in action:
self._async_set_timeout(
action, variables, context, action.get(CONF_CONTINUE, True)
)
raise _SuspendScript
async def _async_call_service(self, action, variables, context):
"""Call the service specified in the action.
This method is a coroutine.
"""
self.last_action = action.get(CONF_ALIAS, "call service")
self._log("Executing step %s" % self.last_action)
await service.async_call_from_config(
self.hass,
action,
blocking=True,
variables=variables,
validate_config=False,
context=context,
)
async def _async_device_automation(self, action, variables, context):
"""Perform the device automation specified in the action.
This method is a coroutine.
"""
self.last_action = action.get(CONF_ALIAS, "device automation")
self._log("Executing step %s" % self.last_action)
platform = await device_automation.async_get_device_automation_platform(
self.hass, action[CONF_DOMAIN], "action"
)
await platform.async_call_action_from_config(
self.hass, action, variables, context
)
async def _async_activate_scene(self, action, variables, context):
"""Activate the scene specified in the action.
This method is a coroutine.
"""
self.last_action = action.get(CONF_ALIAS, "activate scene")
self._log("Executing step %s" % self.last_action)
await self.hass.services.async_call(
scene.DOMAIN,
SERVICE_TURN_ON,
{ATTR_ENTITY_ID: action[CONF_SCENE]},
blocking=True,
context=context,
)
async def _async_fire_event(self, action, variables, context):
"""Fire an event."""
self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT])
self._log("Executing step %s" % self.last_action)
event_data = dict(action.get(CONF_EVENT_DATA, {}))
if CONF_EVENT_DATA_TEMPLATE in action:
try:
event_data.update(
template.render_complex(action[CONF_EVENT_DATA_TEMPLATE], variables)
)
except exceptions.TemplateError as ex:
_LOGGER.error("Error rendering event data template: %s", ex)
self.hass.bus.async_fire(action[CONF_EVENT], event_data, context=context)
async def _async_check_condition(self, action, variables, context):
"""Test if condition is matching."""
config_cache_key = frozenset((k, str(v)) for k, v in action.items())
config = self._config_cache.get(config_cache_key)
if not config:
config = await condition.async_from_config(self.hass, action, False)
self._config_cache[config_cache_key] = config
self.last_action = action.get(CONF_ALIAS, action[CONF_CONDITION])
check = config(self.hass, variables)
self._log(f"Test condition {self.last_action}: {check}")
if not check:
raise _StopScript
def _async_set_timeout(self, action, variables, context, continue_on_timeout):
"""Schedule a timeout to abort or continue script."""
timeout = action[CONF_TIMEOUT]
unsub = None
@callback
def async_script_timeout(now):
"""Call after timeout is retrieve."""
with suppress(ValueError):
self._async_listener.remove(unsub)
# Check if we want to continue to execute
# the script after the timeout
if continue_on_timeout:
self.hass.async_create_task(self.async_run(variables, context))
self.last_triggered = utcnow()
if self._run_mode == RUN_MODE_LEGACY:
if self._runs:
shared = cast(Optional[_LegacyScriptRun], self._runs[0])
else:
self._log("Timeout reached, abort script.")
self.async_stop()
shared = None
run: _ScriptRunBase = _LegacyScriptRun(
self._hass, self, variables, context, self._log_exceptions, shared
)
else:
if self._run_mode == RUN_MODE_BACKGROUND:
run = _BackgroundScriptRun(
self._hass, self, variables, context, self._log_exceptions
)
else:
run = _BlockingScriptRun(
self._hass, self, variables, context, self._log_exceptions
)
self._runs.append(run)
await run.async_run()
unsub = async_track_point_in_utc_time(
self.hass, async_script_timeout, date_util.utcnow() + timeout
)
self._async_listener.append(unsub)
async def async_stop(self) -> None:
"""Stop running script."""
if not self.is_running:
return
await asyncio.shield(asyncio.gather(*(run.async_stop() for run in self._runs)))
self._changed()
def _async_remove_listener(self):
"""Remove point in time listener, if any."""
for unsub in self._async_listener:
unsub()
self._async_listener.clear()
def _log(self, msg, *args, level=logging.INFO):
if self.name:
msg = f"{self.name}: {msg}"
if level == _LOG_EXCEPTION:
self._logger.exception(msg, *args)
else:
self._logger.log(level, msg, *args)
def _log(self, msg):
"""Logger helper."""
if self.name is not None:
msg = f"Script {self.name}: {msg}"
_LOGGER.info(msg)
def _raise(self, msg, *args, exception=None):
if not exception:
exception = exceptions.HomeAssistantError
self._log(msg, *args, level=logging.ERROR)
raise exception(msg % args)

View file

@ -8,7 +8,7 @@ import voluptuous as vol
import homeassistant.components.demo.notify as demo
import homeassistant.components.notify as notify
from homeassistant.core import callback
from homeassistant.helpers import discovery, script
from homeassistant.helpers import discovery
from homeassistant.setup import setup_component
from tests.common import assert_setup_component, get_test_home_assistant
@ -121,7 +121,7 @@ class TestNotifyDemo(unittest.TestCase):
def test_calling_notify_from_script_loaded_from_yaml_without_title(self):
"""Test if we can call a notify from a script."""
self._setup_notify()
conf = {
step = {
"service": "notify.notify",
"data": {
"data": {
@ -130,8 +130,8 @@ class TestNotifyDemo(unittest.TestCase):
},
"data_template": {"message": "Test 123 {{ 2 + 2 }}\n"},
}
script.call_from_config(self.hass, conf)
setup_component(self.hass, "script", {"script": {"test": {"sequence": step}}})
self.hass.services.call("script", "test")
self.hass.block_till_done()
assert len(self.events) == 1
assert {
@ -144,7 +144,7 @@ class TestNotifyDemo(unittest.TestCase):
def test_calling_notify_from_script_loaded_from_yaml_with_title(self):
"""Test if we can call a notify from a script."""
self._setup_notify()
conf = {
step = {
"service": "notify.notify",
"data": {
"data": {
@ -153,8 +153,8 @@ class TestNotifyDemo(unittest.TestCase):
},
"data_template": {"message": "Test 123 {{ 2 + 2 }}\n", "title": "Test"},
}
script.call_from_config(self.hass, conf)
setup_component(self.hass, "script", {"script": {"test": {"sequence": step}}})
self.hass.services.call("script", "test")
self.hass.block_till_done()
assert len(self.events) == 1
assert {

File diff suppressed because it is too large Load diff