Add support for simultaneous runs of Script helper (#31937)
* Add tests for legacy Script helper behavior * Add Script helper if_running and run_mode options - if_running controls what happens if Script run while previous run has not completed. Can be: - error: Raise an exception - ignore: Return without doing anything (previous run continues as-is) - parallel: Start run in new task - restart: Stop previous run before starting new run - run_mode controls when call to async_run will return. Can be: - background: Returns immediately - legacy: Implements previous behavior, which is to return when done, or when suspended by delay or wait_template - blocking: Returns when run has completed - If neither is specified, default is run_mode=legacy (and if_running is not used.) Otherwise, defaults are if_running=parallel and run_mode=background. If run_mode is set to legacy then if_running must be None. - Caller may supply a logger which will be used throughout instead of default module logger. - Move Script running state into new helper classes, comprised of an abstract base class and two concrete clases, one for legacy behavior and one for new behavior. - Remove some non-async methods, as well as call_from_config which has only been used in tests. - Adjust tests accordingly. * Change per review - Change run_mode default from background to blocking. - Make sure change listener is called, even when there's an unexpected exception. - Make _ScriptRun.async_stop more graceful by using an asyncio.Event for signaling instead of simply cancelling Task. - Subclass _ScriptRun for background & blocking behavior. Also: - Fix timeouts in _ScriptRun by converting timedeltas to float seconds. - General cleanup. * Change per review 2 - Don't propagate exceptions if call from user has already returned (i.e., for background runs or legacy runs that have suspended.) - Allow user to specify if exceptions should be logged. They will still be logged regardless if exception is not propagated. - Rename _start_script_delay and _start_wait_template_delay for clarity. - Remove return value from Script.async_run. - Fix missing await. - Change call to self.is_running in Script.async_run to direct test of self._runs. * Change per review 3 and add tests - Remove Script.set_logger(). - Enhance existing tests to check all run modes. - Add tests for new features. - Fix a few minor bugs found by tests.
This commit is contained in:
parent
309989be89
commit
b2d7bc40dc
5 changed files with 1774 additions and 881 deletions
|
@ -393,10 +393,8 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
|
|||
|
||||
try:
|
||||
await self.action_script.async_run(variables, trigger_context)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
self.action_script.async_log_exception(
|
||||
_LOGGER, f"Error while executing automation {self.entity_id}", err
|
||||
)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
|
||||
self._last_triggered = utcnow()
|
||||
await self.async_update_ha_state()
|
||||
|
@ -504,7 +502,9 @@ async def _async_process_config(hass, config, component):
|
|||
hidden = config_block[CONF_HIDE_ENTITY]
|
||||
initial_state = config_block.get(CONF_INITIAL_STATE)
|
||||
|
||||
action_script = script.Script(hass, config_block.get(CONF_ACTION, {}), name)
|
||||
action_script = script.Script(
|
||||
hass, config_block.get(CONF_ACTION, {}), name, logger=_LOGGER
|
||||
)
|
||||
|
||||
if CONF_CONDITION in config_block:
|
||||
cond_func = await _async_process_if(hass, config, config_block)
|
||||
|
|
|
@ -242,7 +242,9 @@ class ScriptEntity(ToggleEntity):
|
|||
self.object_id = object_id
|
||||
self.icon = icon
|
||||
self.entity_id = ENTITY_ID_FORMAT.format(object_id)
|
||||
self.script = Script(hass, sequence, name, self.async_update_ha_state)
|
||||
self.script = Script(
|
||||
hass, sequence, name, self.async_update_ha_state, logger=_LOGGER
|
||||
)
|
||||
|
||||
@property
|
||||
def should_poll(self):
|
||||
|
@ -279,22 +281,15 @@ class ScriptEntity(ToggleEntity):
|
|||
{ATTR_NAME: self.script.name, ATTR_ENTITY_ID: self.entity_id},
|
||||
context=context,
|
||||
)
|
||||
try:
|
||||
await self.script.async_run(kwargs.get(ATTR_VARIABLES), context)
|
||||
except Exception as err:
|
||||
self.script.async_log_exception(
|
||||
_LOGGER, f"Error executing script {self.entity_id}", err
|
||||
)
|
||||
raise err
|
||||
await self.script.async_run(kwargs.get(ATTR_VARIABLES), context)
|
||||
|
||||
async def async_turn_off(self, **kwargs):
|
||||
"""Turn script off."""
|
||||
self.script.async_stop()
|
||||
await self.script.async_stop()
|
||||
|
||||
async def async_will_remove_from_hass(self):
|
||||
"""Stop script and remove service when it will be removed from Home Assistant."""
|
||||
if self.script.is_running:
|
||||
self.script.async_stop()
|
||||
await self.script.async_stop()
|
||||
|
||||
# remove service
|
||||
self.hass.services.async_remove(DOMAIN, self.object_id)
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
"""Helpers to execute scripts."""
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
from datetime import datetime
|
||||
from itertools import islice
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -31,13 +32,10 @@ from homeassistant.helpers.event import (
|
|||
async_track_template,
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
import homeassistant.util.dt as date_util
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CONF_ALIAS = "alias"
|
||||
CONF_SERVICE = "service"
|
||||
CONF_SERVICE_DATA = "data"
|
||||
|
@ -50,7 +48,6 @@ CONF_WAIT_TEMPLATE = "wait_template"
|
|||
CONF_CONTINUE = "continue_on_timeout"
|
||||
CONF_SCENE = "scene"
|
||||
|
||||
|
||||
ACTION_DELAY = "delay"
|
||||
ACTION_WAIT_TEMPLATE = "wait_template"
|
||||
ACTION_CHECK_CONDITION = "condition"
|
||||
|
@ -59,6 +56,31 @@ ACTION_CALL_SERVICE = "call_service"
|
|||
ACTION_DEVICE_AUTOMATION = "device"
|
||||
ACTION_ACTIVATE_SCENE = "scene"
|
||||
|
||||
IF_RUNNING_ERROR = "error"
|
||||
IF_RUNNING_IGNORE = "ignore"
|
||||
IF_RUNNING_PARALLEL = "parallel"
|
||||
IF_RUNNING_RESTART = "restart"
|
||||
# First choice is default
|
||||
IF_RUNNING_CHOICES = [
|
||||
IF_RUNNING_PARALLEL,
|
||||
IF_RUNNING_ERROR,
|
||||
IF_RUNNING_IGNORE,
|
||||
IF_RUNNING_RESTART,
|
||||
]
|
||||
|
||||
RUN_MODE_BACKGROUND = "background"
|
||||
RUN_MODE_BLOCKING = "blocking"
|
||||
RUN_MODE_LEGACY = "legacy"
|
||||
# First choice is default
|
||||
RUN_MODE_CHOICES = [
|
||||
RUN_MODE_BLOCKING,
|
||||
RUN_MODE_BACKGROUND,
|
||||
RUN_MODE_LEGACY,
|
||||
]
|
||||
|
||||
_LOG_EXCEPTION = logging.ERROR + 1
|
||||
_TIMEOUT_MSG = "Timeout reached, abort script."
|
||||
|
||||
|
||||
def _determine_action(action):
|
||||
"""Determine action type."""
|
||||
|
@ -83,16 +105,6 @@ def _determine_action(action):
|
|||
return ACTION_CALL_SERVICE
|
||||
|
||||
|
||||
def call_from_config(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
variables: Optional[Sequence] = None,
|
||||
context: Optional[Context] = None,
|
||||
) -> None:
|
||||
"""Call a script based on a config entry."""
|
||||
Script(hass, cv.SCRIPT_SCHEMA(config)).run(variables, context)
|
||||
|
||||
|
||||
async def async_validate_action_config(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> ConfigType:
|
||||
|
@ -121,6 +133,446 @@ class _SuspendScript(Exception):
|
|||
"""Throw if script needs to suspend."""
|
||||
|
||||
|
||||
class _ScriptRunBase(ABC):
|
||||
"""Common data & methods for managing Script sequence run."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
script: "Script",
|
||||
variables: Optional[Sequence],
|
||||
context: Optional[Context],
|
||||
log_exceptions: bool,
|
||||
) -> None:
|
||||
self._hass = hass
|
||||
self._script = script
|
||||
self._variables = variables
|
||||
self._context = context
|
||||
self._log_exceptions = log_exceptions
|
||||
self._step = -1
|
||||
self._action: Optional[Dict[str, Any]] = None
|
||||
|
||||
def _changed(self):
|
||||
self._script._changed() # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def _config_cache(self):
|
||||
return self._script._config_cache # pylint: disable=protected-access
|
||||
|
||||
@abstractmethod
|
||||
async def async_run(self) -> None:
|
||||
"""Run script."""
|
||||
|
||||
async def _async_step(self, log_exceptions):
|
||||
try:
|
||||
await getattr(self, f"_async_{_determine_action(self._action)}_step")()
|
||||
except Exception as err:
|
||||
if not isinstance(err, (_SuspendScript, _StopScript)) and (
|
||||
self._log_exceptions or log_exceptions
|
||||
):
|
||||
self._log_exception(err)
|
||||
raise
|
||||
|
||||
@abstractmethod
|
||||
async def async_stop(self) -> None:
|
||||
"""Stop script run."""
|
||||
|
||||
def _log_exception(self, exception):
|
||||
action_type = _determine_action(self._action)
|
||||
|
||||
error = str(exception)
|
||||
level = logging.ERROR
|
||||
|
||||
if isinstance(exception, vol.Invalid):
|
||||
error_desc = "Invalid data"
|
||||
|
||||
elif isinstance(exception, exceptions.TemplateError):
|
||||
error_desc = "Error rendering template"
|
||||
|
||||
elif isinstance(exception, exceptions.Unauthorized):
|
||||
error_desc = "Unauthorized"
|
||||
|
||||
elif isinstance(exception, exceptions.ServiceNotFound):
|
||||
error_desc = "Service not found"
|
||||
|
||||
else:
|
||||
error_desc = "Unexpected error"
|
||||
level = _LOG_EXCEPTION
|
||||
|
||||
self._log(
|
||||
"Error executing script. %s for %s at pos %s: %s",
|
||||
error_desc,
|
||||
action_type,
|
||||
self._step + 1,
|
||||
error,
|
||||
level=level,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def _async_delay_step(self):
|
||||
"""Handle delay."""
|
||||
|
||||
def _prep_delay_step(self):
|
||||
try:
|
||||
delay = vol.All(cv.time_period, cv.positive_timedelta)(
|
||||
template.render_complex(self._action[CONF_DELAY], self._variables)
|
||||
)
|
||||
except (exceptions.TemplateError, vol.Invalid) as ex:
|
||||
self._raise(
|
||||
"Error rendering %s delay template: %s",
|
||||
self._script.name,
|
||||
ex,
|
||||
exception=_StopScript,
|
||||
)
|
||||
|
||||
self._script.last_action = self._action.get(CONF_ALIAS, f"delay {delay}")
|
||||
self._log("Executing step %s", self._script.last_action)
|
||||
|
||||
return delay
|
||||
|
||||
@abstractmethod
|
||||
async def _async_wait_template_step(self):
|
||||
"""Handle a wait template."""
|
||||
|
||||
def _prep_wait_template_step(self, async_script_wait):
|
||||
wait_template = self._action[CONF_WAIT_TEMPLATE]
|
||||
wait_template.hass = self._hass
|
||||
|
||||
self._script.last_action = self._action.get(CONF_ALIAS, "wait template")
|
||||
self._log("Executing step %s", self._script.last_action)
|
||||
|
||||
# check if condition already okay
|
||||
if condition.async_template(self._hass, wait_template, self._variables):
|
||||
return None
|
||||
|
||||
return async_track_template(
|
||||
self._hass, wait_template, async_script_wait, self._variables
|
||||
)
|
||||
|
||||
async def _async_call_service_step(self):
|
||||
"""Call the service specified in the action."""
|
||||
self._script.last_action = self._action.get(CONF_ALIAS, "call service")
|
||||
self._log("Executing step %s", self._script.last_action)
|
||||
await service.async_call_from_config(
|
||||
self._hass,
|
||||
self._action,
|
||||
blocking=True,
|
||||
variables=self._variables,
|
||||
validate_config=False,
|
||||
context=self._context,
|
||||
)
|
||||
|
||||
async def _async_device_step(self):
|
||||
"""Perform the device automation specified in the action."""
|
||||
self._script.last_action = self._action.get(CONF_ALIAS, "device automation")
|
||||
self._log("Executing step %s", self._script.last_action)
|
||||
platform = await device_automation.async_get_device_automation_platform(
|
||||
self._hass, self._action[CONF_DOMAIN], "action"
|
||||
)
|
||||
await platform.async_call_action_from_config(
|
||||
self._hass, self._action, self._variables, self._context
|
||||
)
|
||||
|
||||
async def _async_scene_step(self):
|
||||
"""Activate the scene specified in the action."""
|
||||
self._script.last_action = self._action.get(CONF_ALIAS, "activate scene")
|
||||
self._log("Executing step %s", self._script.last_action)
|
||||
await self._hass.services.async_call(
|
||||
scene.DOMAIN,
|
||||
SERVICE_TURN_ON,
|
||||
{ATTR_ENTITY_ID: self._action[CONF_SCENE]},
|
||||
blocking=True,
|
||||
context=self._context,
|
||||
)
|
||||
|
||||
async def _async_event_step(self):
|
||||
"""Fire an event."""
|
||||
self._script.last_action = self._action.get(
|
||||
CONF_ALIAS, self._action[CONF_EVENT]
|
||||
)
|
||||
self._log("Executing step %s", self._script.last_action)
|
||||
event_data = dict(self._action.get(CONF_EVENT_DATA, {}))
|
||||
if CONF_EVENT_DATA_TEMPLATE in self._action:
|
||||
try:
|
||||
event_data.update(
|
||||
template.render_complex(
|
||||
self._action[CONF_EVENT_DATA_TEMPLATE], self._variables
|
||||
)
|
||||
)
|
||||
except exceptions.TemplateError as ex:
|
||||
self._log(
|
||||
"Error rendering event data template: %s", ex, level=logging.ERROR
|
||||
)
|
||||
|
||||
self._hass.bus.async_fire(
|
||||
self._action[CONF_EVENT], event_data, context=self._context
|
||||
)
|
||||
|
||||
async def _async_condition_step(self):
|
||||
"""Test if condition is matching."""
|
||||
config_cache_key = frozenset((k, str(v)) for k, v in self._action.items())
|
||||
config = self._config_cache.get(config_cache_key)
|
||||
if not config:
|
||||
config = await condition.async_from_config(self._hass, self._action, False)
|
||||
self._config_cache[config_cache_key] = config
|
||||
|
||||
self._script.last_action = self._action.get(
|
||||
CONF_ALIAS, self._action[CONF_CONDITION]
|
||||
)
|
||||
check = config(self._hass, self._variables)
|
||||
self._log("Test condition %s: %s", self._script.last_action, check)
|
||||
if not check:
|
||||
raise _StopScript
|
||||
|
||||
def _log(self, msg, *args, level=logging.INFO):
|
||||
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
|
||||
|
||||
def _raise(self, msg, *args, exception=None):
|
||||
# pylint: disable=protected-access
|
||||
self._script._raise(msg, *args, exception=exception)
|
||||
|
||||
|
||||
class _ScriptRun(_ScriptRunBase):
|
||||
"""Manage Script sequence run."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
script: "Script",
|
||||
variables: Optional[Sequence],
|
||||
context: Optional[Context],
|
||||
log_exceptions: bool,
|
||||
) -> None:
|
||||
super().__init__(hass, script, variables, context, log_exceptions)
|
||||
self._stop = asyncio.Event()
|
||||
self._stopped = asyncio.Event()
|
||||
|
||||
async def _async_run(self, propagate_exceptions=True):
|
||||
self._log("Running script")
|
||||
try:
|
||||
for self._step, self._action in enumerate(self._script.sequence):
|
||||
if self._stop.is_set():
|
||||
break
|
||||
await self._async_step(not propagate_exceptions)
|
||||
except _StopScript:
|
||||
pass
|
||||
except Exception: # pylint: disable=broad-except
|
||||
if propagate_exceptions:
|
||||
raise
|
||||
finally:
|
||||
if not self._stop.is_set():
|
||||
self._changed()
|
||||
self._script.last_action = None
|
||||
self._script._runs.remove(self) # pylint: disable=protected-access
|
||||
self._stopped.set()
|
||||
|
||||
async def async_stop(self) -> None:
|
||||
"""Stop script run."""
|
||||
self._stop.set()
|
||||
await self._stopped.wait()
|
||||
|
||||
async def _async_delay_step(self):
|
||||
"""Handle delay."""
|
||||
timeout = self._prep_delay_step().total_seconds()
|
||||
if not self._stop.is_set():
|
||||
self._changed()
|
||||
await asyncio.wait({self._stop.wait()}, timeout=timeout)
|
||||
|
||||
async def _async_wait_template_step(self):
|
||||
"""Handle a wait template."""
|
||||
|
||||
@callback
|
||||
def async_script_wait(entity_id, from_s, to_s):
|
||||
"""Handle script after template condition is true."""
|
||||
done.set()
|
||||
|
||||
unsub = self._prep_wait_template_step(async_script_wait)
|
||||
if not unsub:
|
||||
return
|
||||
|
||||
if not self._stop.is_set():
|
||||
self._changed()
|
||||
try:
|
||||
timeout = self._action[CONF_TIMEOUT].total_seconds()
|
||||
except KeyError:
|
||||
timeout = None
|
||||
done = asyncio.Event()
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.wait(
|
||||
{self._stop.wait(), done.wait()},
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
),
|
||||
timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
if not self._action.get(CONF_CONTINUE, True):
|
||||
self._log(_TIMEOUT_MSG)
|
||||
raise _StopScript
|
||||
finally:
|
||||
unsub()
|
||||
|
||||
|
||||
class _BackgroundScriptRun(_ScriptRun):
|
||||
"""Manage background Script sequence run."""
|
||||
|
||||
async def async_run(self) -> None:
|
||||
"""Run script."""
|
||||
self._hass.async_create_task(self._async_run(False))
|
||||
|
||||
|
||||
class _BlockingScriptRun(_ScriptRun):
|
||||
"""Manage blocking Script sequence run."""
|
||||
|
||||
async def async_run(self) -> None:
|
||||
"""Run script."""
|
||||
try:
|
||||
await asyncio.shield(self._async_run())
|
||||
except asyncio.CancelledError:
|
||||
await self.async_stop()
|
||||
raise
|
||||
|
||||
|
||||
class _LegacyScriptRun(_ScriptRunBase):
|
||||
"""Manage legacy Script sequence run."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
script: "Script",
|
||||
variables: Optional[Sequence],
|
||||
context: Optional[Context],
|
||||
log_exceptions: bool,
|
||||
shared: Optional["_LegacyScriptRun"],
|
||||
) -> None:
|
||||
super().__init__(hass, script, variables, context, log_exceptions)
|
||||
if shared:
|
||||
self._shared = shared
|
||||
else:
|
||||
# To implement legacy behavior we need to share the following "run state"
|
||||
# amongst all runs, so it will only exist in the first instantiation of
|
||||
# concurrent runs, and the rest will use it, too.
|
||||
self._current = -1
|
||||
self._async_listeners: List[CALLBACK_TYPE] = []
|
||||
self._shared = self
|
||||
|
||||
@property
|
||||
def _cur(self):
|
||||
return self._shared._current # pylint: disable=protected-access
|
||||
|
||||
@_cur.setter
|
||||
def _cur(self, value):
|
||||
self._shared._current = value # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def _async_listener(self):
|
||||
return self._shared._async_listeners # pylint: disable=protected-access
|
||||
|
||||
async def async_run(self) -> None:
|
||||
"""Run script."""
|
||||
await self._async_run()
|
||||
|
||||
async def _async_run(self, propagate_exceptions=True):
|
||||
if self._cur == -1:
|
||||
self._log("Running script")
|
||||
self._cur = 0
|
||||
|
||||
# Unregister callback if we were in a delay or wait but turn on is
|
||||
# called again. In that case we just continue execution.
|
||||
self._async_remove_listener()
|
||||
|
||||
suspended = False
|
||||
try:
|
||||
for self._step, self._action in islice(
|
||||
enumerate(self._script.sequence), self._cur, None
|
||||
):
|
||||
await self._async_step(not propagate_exceptions)
|
||||
except _StopScript:
|
||||
pass
|
||||
except _SuspendScript:
|
||||
# Store next step to take and notify change listeners
|
||||
self._cur = self._step + 1
|
||||
suspended = True
|
||||
return
|
||||
except Exception: # pylint: disable=broad-except
|
||||
if propagate_exceptions:
|
||||
raise
|
||||
finally:
|
||||
if self._cur != -1:
|
||||
self._changed()
|
||||
if not suspended:
|
||||
self._script.last_action = None
|
||||
await self.async_stop()
|
||||
|
||||
async def async_stop(self) -> None:
|
||||
"""Stop script run."""
|
||||
if self._cur == -1:
|
||||
return
|
||||
|
||||
self._cur = -1
|
||||
self._async_remove_listener()
|
||||
self._script._runs.clear() # pylint: disable=protected-access
|
||||
|
||||
async def _async_delay_step(self):
|
||||
"""Handle delay."""
|
||||
delay = self._prep_delay_step()
|
||||
|
||||
@callback
|
||||
def async_script_delay(now):
|
||||
"""Handle delay."""
|
||||
with suppress(ValueError):
|
||||
self._async_listener.remove(unsub)
|
||||
self._hass.async_create_task(self._async_run(False))
|
||||
|
||||
unsub = async_track_point_in_utc_time(
|
||||
self._hass, async_script_delay, utcnow() + delay
|
||||
)
|
||||
self._async_listener.append(unsub)
|
||||
raise _SuspendScript
|
||||
|
||||
async def _async_wait_template_step(self):
|
||||
"""Handle a wait template."""
|
||||
|
||||
@callback
|
||||
def async_script_wait(entity_id, from_s, to_s):
|
||||
"""Handle script after template condition is true."""
|
||||
self._async_remove_listener()
|
||||
self._hass.async_create_task(self._async_run(False))
|
||||
|
||||
@callback
|
||||
def async_script_timeout(now):
|
||||
"""Call after timeout is retrieve."""
|
||||
with suppress(ValueError):
|
||||
self._async_listener.remove(unsub)
|
||||
|
||||
# Check if we want to continue to execute
|
||||
# the script after the timeout
|
||||
if self._action.get(CONF_CONTINUE, True):
|
||||
self._hass.async_create_task(self._async_run(False))
|
||||
else:
|
||||
self._log(_TIMEOUT_MSG)
|
||||
self._hass.async_create_task(self.async_stop())
|
||||
|
||||
unsub_wait = self._prep_wait_template_step(async_script_wait)
|
||||
if not unsub_wait:
|
||||
return
|
||||
self._async_listener.append(unsub_wait)
|
||||
|
||||
if CONF_TIMEOUT in self._action:
|
||||
unsub = async_track_point_in_utc_time(
|
||||
self._hass, async_script_timeout, utcnow() + self._action[CONF_TIMEOUT]
|
||||
)
|
||||
self._async_listener.append(unsub)
|
||||
|
||||
raise _SuspendScript
|
||||
|
||||
def _async_remove_listener(self):
|
||||
"""Remove listeners, if any."""
|
||||
for unsub in self._async_listener:
|
||||
unsub()
|
||||
self._async_listener.clear()
|
||||
|
||||
|
||||
class Script:
|
||||
"""Representation of a script."""
|
||||
|
||||
|
@ -130,39 +582,46 @@ class Script:
|
|||
sequence: Sequence[Dict[str, Any]],
|
||||
name: Optional[str] = None,
|
||||
change_listener: Optional[Callable[..., Any]] = None,
|
||||
if_running: Optional[str] = None,
|
||||
run_mode: Optional[str] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
log_exceptions: bool = True,
|
||||
) -> None:
|
||||
"""Initialize the script."""
|
||||
self.hass = hass
|
||||
self._logger = logger or logging.getLogger(__name__)
|
||||
self._hass = hass
|
||||
self.sequence = sequence
|
||||
template.attach(hass, self.sequence)
|
||||
self.name = name
|
||||
self._change_listener = change_listener
|
||||
self._cur = -1
|
||||
self._exception_step: Optional[int] = None
|
||||
self.last_action = None
|
||||
self.last_triggered: Optional[datetime] = None
|
||||
self.can_cancel = any(
|
||||
CONF_DELAY in action or CONF_WAIT_TEMPLATE in action
|
||||
for action in self.sequence
|
||||
)
|
||||
self._async_listener: List[CALLBACK_TYPE] = []
|
||||
if not if_running and not run_mode:
|
||||
self._if_running = IF_RUNNING_PARALLEL
|
||||
self._run_mode = RUN_MODE_LEGACY
|
||||
elif if_running and run_mode == RUN_MODE_LEGACY:
|
||||
self._raise('Cannot use if_running if run_mode is "legacy"')
|
||||
else:
|
||||
self._if_running = if_running or IF_RUNNING_CHOICES[0]
|
||||
self._run_mode = run_mode or RUN_MODE_CHOICES[0]
|
||||
self._runs: List[_ScriptRunBase] = []
|
||||
self._log_exceptions = log_exceptions
|
||||
self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {}
|
||||
self._actions = {
|
||||
ACTION_DELAY: self._async_delay,
|
||||
ACTION_WAIT_TEMPLATE: self._async_wait_template,
|
||||
ACTION_CHECK_CONDITION: self._async_check_condition,
|
||||
ACTION_FIRE_EVENT: self._async_fire_event,
|
||||
ACTION_CALL_SERVICE: self._async_call_service,
|
||||
ACTION_DEVICE_AUTOMATION: self._async_device_automation,
|
||||
ACTION_ACTIVATE_SCENE: self._async_activate_scene,
|
||||
}
|
||||
self._referenced_entities: Optional[Set[str]] = None
|
||||
self._referenced_devices: Optional[Set[str]] = None
|
||||
|
||||
def _changed(self):
|
||||
if self._change_listener:
|
||||
self._hass.async_add_job(self._change_listener)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Return true if script is on."""
|
||||
return self._cur != -1
|
||||
return len(self._runs) > 0
|
||||
|
||||
@property
|
||||
def referenced_devices(self):
|
||||
|
@ -223,288 +682,62 @@ class Script:
|
|||
def run(self, variables=None, context=None):
|
||||
"""Run script."""
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.async_run(variables, context), self.hass.loop
|
||||
self.async_run(variables, context), self._hass.loop
|
||||
).result()
|
||||
|
||||
async def async_run(
|
||||
self, variables: Optional[Sequence] = None, context: Optional[Context] = None
|
||||
) -> None:
|
||||
"""Run script.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
self.last_triggered = date_util.utcnow()
|
||||
if self._cur == -1:
|
||||
self._log("Running script")
|
||||
self._cur = 0
|
||||
|
||||
# Unregister callback if we were in a delay or wait but turn on is
|
||||
# called again. In that case we just continue execution.
|
||||
self._async_remove_listener()
|
||||
|
||||
for cur, action in islice(enumerate(self.sequence), self._cur, None):
|
||||
try:
|
||||
await self._handle_action(action, variables, context)
|
||||
except _SuspendScript:
|
||||
# Store next step to take and notify change listeners
|
||||
self._cur = cur + 1
|
||||
if self._change_listener:
|
||||
self.hass.async_add_job(self._change_listener)
|
||||
"""Run script."""
|
||||
if self.is_running:
|
||||
if self._if_running == IF_RUNNING_IGNORE:
|
||||
self._log("Skipping script")
|
||||
return
|
||||
except _StopScript:
|
||||
break
|
||||
except Exception:
|
||||
# Store the step that had an exception
|
||||
self._exception_step = cur
|
||||
# Set script to not running
|
||||
self._cur = -1
|
||||
self.last_action = None
|
||||
# Pass exception on.
|
||||
raise
|
||||
|
||||
# Set script to not-running.
|
||||
self._cur = -1
|
||||
self.last_action = None
|
||||
if self._change_listener:
|
||||
self.hass.async_add_job(self._change_listener)
|
||||
if self._if_running == IF_RUNNING_ERROR:
|
||||
self._raise("Already running")
|
||||
if self._if_running == IF_RUNNING_RESTART:
|
||||
self._log("Restarting script")
|
||||
await self.async_stop()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop running script."""
|
||||
run_callback_threadsafe(self.hass.loop, self.async_stop).result()
|
||||
|
||||
@callback
|
||||
def async_stop(self) -> None:
|
||||
"""Stop running script."""
|
||||
if self._cur == -1:
|
||||
return
|
||||
|
||||
self._cur = -1
|
||||
self._async_remove_listener()
|
||||
if self._change_listener:
|
||||
self.hass.async_add_job(self._change_listener)
|
||||
|
||||
@callback
|
||||
def async_log_exception(self, logger, message_base, exception):
|
||||
"""Log an exception for this script.
|
||||
|
||||
Should only be called on exceptions raised by this scripts async_run.
|
||||
"""
|
||||
step = self._exception_step
|
||||
action = self.sequence[step]
|
||||
action_type = _determine_action(action)
|
||||
|
||||
error = None
|
||||
meth = logger.error
|
||||
|
||||
if isinstance(exception, vol.Invalid):
|
||||
error_desc = "Invalid data"
|
||||
|
||||
elif isinstance(exception, exceptions.TemplateError):
|
||||
error_desc = "Error rendering template"
|
||||
|
||||
elif isinstance(exception, exceptions.Unauthorized):
|
||||
error_desc = "Unauthorized"
|
||||
|
||||
elif isinstance(exception, exceptions.ServiceNotFound):
|
||||
error_desc = "Service not found"
|
||||
|
||||
else:
|
||||
# Print the full stack trace, unknown error
|
||||
error_desc = "Unknown error"
|
||||
meth = logger.exception
|
||||
error = ""
|
||||
|
||||
if error is None:
|
||||
error = str(exception)
|
||||
|
||||
meth(
|
||||
"%s. %s for %s at pos %s: %s",
|
||||
message_base,
|
||||
error_desc,
|
||||
action_type,
|
||||
step + 1,
|
||||
error,
|
||||
)
|
||||
|
||||
async def _handle_action(self, action, variables, context):
|
||||
"""Handle an action."""
|
||||
await self._actions[_determine_action(action)](action, variables, context)
|
||||
|
||||
async def _async_delay(self, action, variables, context):
|
||||
"""Handle delay."""
|
||||
# Call ourselves in the future to continue work
|
||||
unsub = None
|
||||
|
||||
@callback
|
||||
def async_script_delay(now):
|
||||
"""Handle delay."""
|
||||
with suppress(ValueError):
|
||||
self._async_listener.remove(unsub)
|
||||
|
||||
self.hass.async_create_task(self.async_run(variables, context))
|
||||
|
||||
delay = action[CONF_DELAY]
|
||||
|
||||
try:
|
||||
if isinstance(delay, template.Template):
|
||||
delay = vol.All(cv.time_period, cv.positive_timedelta)(
|
||||
delay.async_render(variables)
|
||||
)
|
||||
elif isinstance(delay, dict):
|
||||
delay_data = {}
|
||||
delay_data.update(template.render_complex(delay, variables))
|
||||
delay = cv.time_period(delay_data)
|
||||
except (exceptions.TemplateError, vol.Invalid) as ex:
|
||||
_LOGGER.error("Error rendering '%s' delay template: %s", self.name, ex)
|
||||
raise _StopScript
|
||||
|
||||
self.last_action = action.get(CONF_ALIAS, f"delay {delay}")
|
||||
self._log("Executing step %s" % self.last_action)
|
||||
|
||||
unsub = async_track_point_in_utc_time(
|
||||
self.hass, async_script_delay, date_util.utcnow() + delay
|
||||
)
|
||||
self._async_listener.append(unsub)
|
||||
raise _SuspendScript
|
||||
|
||||
async def _async_wait_template(self, action, variables, context):
|
||||
"""Handle a wait template."""
|
||||
# Call ourselves in the future to continue work
|
||||
wait_template = action[CONF_WAIT_TEMPLATE]
|
||||
wait_template.hass = self.hass
|
||||
|
||||
self.last_action = action.get(CONF_ALIAS, "wait template")
|
||||
self._log("Executing step %s" % self.last_action)
|
||||
|
||||
# check if condition already okay
|
||||
if condition.async_template(self.hass, wait_template, variables):
|
||||
return
|
||||
|
||||
@callback
|
||||
def async_script_wait(entity_id, from_s, to_s):
|
||||
"""Handle script after template condition is true."""
|
||||
self._async_remove_listener()
|
||||
self.hass.async_create_task(self.async_run(variables, context))
|
||||
|
||||
self._async_listener.append(
|
||||
async_track_template(self.hass, wait_template, async_script_wait, variables)
|
||||
)
|
||||
|
||||
if CONF_TIMEOUT in action:
|
||||
self._async_set_timeout(
|
||||
action, variables, context, action.get(CONF_CONTINUE, True)
|
||||
)
|
||||
|
||||
raise _SuspendScript
|
||||
|
||||
async def _async_call_service(self, action, variables, context):
|
||||
"""Call the service specified in the action.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
self.last_action = action.get(CONF_ALIAS, "call service")
|
||||
self._log("Executing step %s" % self.last_action)
|
||||
await service.async_call_from_config(
|
||||
self.hass,
|
||||
action,
|
||||
blocking=True,
|
||||
variables=variables,
|
||||
validate_config=False,
|
||||
context=context,
|
||||
)
|
||||
|
||||
async def _async_device_automation(self, action, variables, context):
|
||||
"""Perform the device automation specified in the action.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
self.last_action = action.get(CONF_ALIAS, "device automation")
|
||||
self._log("Executing step %s" % self.last_action)
|
||||
platform = await device_automation.async_get_device_automation_platform(
|
||||
self.hass, action[CONF_DOMAIN], "action"
|
||||
)
|
||||
await platform.async_call_action_from_config(
|
||||
self.hass, action, variables, context
|
||||
)
|
||||
|
||||
async def _async_activate_scene(self, action, variables, context):
|
||||
"""Activate the scene specified in the action.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
self.last_action = action.get(CONF_ALIAS, "activate scene")
|
||||
self._log("Executing step %s" % self.last_action)
|
||||
await self.hass.services.async_call(
|
||||
scene.DOMAIN,
|
||||
SERVICE_TURN_ON,
|
||||
{ATTR_ENTITY_ID: action[CONF_SCENE]},
|
||||
blocking=True,
|
||||
context=context,
|
||||
)
|
||||
|
||||
async def _async_fire_event(self, action, variables, context):
|
||||
"""Fire an event."""
|
||||
self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT])
|
||||
self._log("Executing step %s" % self.last_action)
|
||||
event_data = dict(action.get(CONF_EVENT_DATA, {}))
|
||||
if CONF_EVENT_DATA_TEMPLATE in action:
|
||||
try:
|
||||
event_data.update(
|
||||
template.render_complex(action[CONF_EVENT_DATA_TEMPLATE], variables)
|
||||
)
|
||||
except exceptions.TemplateError as ex:
|
||||
_LOGGER.error("Error rendering event data template: %s", ex)
|
||||
|
||||
self.hass.bus.async_fire(action[CONF_EVENT], event_data, context=context)
|
||||
|
||||
async def _async_check_condition(self, action, variables, context):
|
||||
"""Test if condition is matching."""
|
||||
config_cache_key = frozenset((k, str(v)) for k, v in action.items())
|
||||
config = self._config_cache.get(config_cache_key)
|
||||
if not config:
|
||||
config = await condition.async_from_config(self.hass, action, False)
|
||||
self._config_cache[config_cache_key] = config
|
||||
|
||||
self.last_action = action.get(CONF_ALIAS, action[CONF_CONDITION])
|
||||
check = config(self.hass, variables)
|
||||
self._log(f"Test condition {self.last_action}: {check}")
|
||||
|
||||
if not check:
|
||||
raise _StopScript
|
||||
|
||||
def _async_set_timeout(self, action, variables, context, continue_on_timeout):
|
||||
"""Schedule a timeout to abort or continue script."""
|
||||
timeout = action[CONF_TIMEOUT]
|
||||
unsub = None
|
||||
|
||||
@callback
|
||||
def async_script_timeout(now):
|
||||
"""Call after timeout is retrieve."""
|
||||
with suppress(ValueError):
|
||||
self._async_listener.remove(unsub)
|
||||
|
||||
# Check if we want to continue to execute
|
||||
# the script after the timeout
|
||||
if continue_on_timeout:
|
||||
self.hass.async_create_task(self.async_run(variables, context))
|
||||
self.last_triggered = utcnow()
|
||||
if self._run_mode == RUN_MODE_LEGACY:
|
||||
if self._runs:
|
||||
shared = cast(Optional[_LegacyScriptRun], self._runs[0])
|
||||
else:
|
||||
self._log("Timeout reached, abort script.")
|
||||
self.async_stop()
|
||||
shared = None
|
||||
run: _ScriptRunBase = _LegacyScriptRun(
|
||||
self._hass, self, variables, context, self._log_exceptions, shared
|
||||
)
|
||||
else:
|
||||
if self._run_mode == RUN_MODE_BACKGROUND:
|
||||
run = _BackgroundScriptRun(
|
||||
self._hass, self, variables, context, self._log_exceptions
|
||||
)
|
||||
else:
|
||||
run = _BlockingScriptRun(
|
||||
self._hass, self, variables, context, self._log_exceptions
|
||||
)
|
||||
self._runs.append(run)
|
||||
await run.async_run()
|
||||
|
||||
unsub = async_track_point_in_utc_time(
|
||||
self.hass, async_script_timeout, date_util.utcnow() + timeout
|
||||
)
|
||||
self._async_listener.append(unsub)
|
||||
async def async_stop(self) -> None:
|
||||
"""Stop running script."""
|
||||
if not self.is_running:
|
||||
return
|
||||
await asyncio.shield(asyncio.gather(*(run.async_stop() for run in self._runs)))
|
||||
self._changed()
|
||||
|
||||
def _async_remove_listener(self):
|
||||
"""Remove point in time listener, if any."""
|
||||
for unsub in self._async_listener:
|
||||
unsub()
|
||||
self._async_listener.clear()
|
||||
def _log(self, msg, *args, level=logging.INFO):
|
||||
if self.name:
|
||||
msg = f"{self.name}: {msg}"
|
||||
if level == _LOG_EXCEPTION:
|
||||
self._logger.exception(msg, *args)
|
||||
else:
|
||||
self._logger.log(level, msg, *args)
|
||||
|
||||
def _log(self, msg):
|
||||
"""Logger helper."""
|
||||
if self.name is not None:
|
||||
msg = f"Script {self.name}: {msg}"
|
||||
|
||||
_LOGGER.info(msg)
|
||||
def _raise(self, msg, *args, exception=None):
|
||||
if not exception:
|
||||
exception = exceptions.HomeAssistantError
|
||||
self._log(msg, *args, level=logging.ERROR)
|
||||
raise exception(msg % args)
|
||||
|
|
|
@ -8,7 +8,7 @@ import voluptuous as vol
|
|||
import homeassistant.components.demo.notify as demo
|
||||
import homeassistant.components.notify as notify
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers import discovery, script
|
||||
from homeassistant.helpers import discovery
|
||||
from homeassistant.setup import setup_component
|
||||
|
||||
from tests.common import assert_setup_component, get_test_home_assistant
|
||||
|
@ -121,7 +121,7 @@ class TestNotifyDemo(unittest.TestCase):
|
|||
def test_calling_notify_from_script_loaded_from_yaml_without_title(self):
|
||||
"""Test if we can call a notify from a script."""
|
||||
self._setup_notify()
|
||||
conf = {
|
||||
step = {
|
||||
"service": "notify.notify",
|
||||
"data": {
|
||||
"data": {
|
||||
|
@ -130,8 +130,8 @@ class TestNotifyDemo(unittest.TestCase):
|
|||
},
|
||||
"data_template": {"message": "Test 123 {{ 2 + 2 }}\n"},
|
||||
}
|
||||
|
||||
script.call_from_config(self.hass, conf)
|
||||
setup_component(self.hass, "script", {"script": {"test": {"sequence": step}}})
|
||||
self.hass.services.call("script", "test")
|
||||
self.hass.block_till_done()
|
||||
assert len(self.events) == 1
|
||||
assert {
|
||||
|
@ -144,7 +144,7 @@ class TestNotifyDemo(unittest.TestCase):
|
|||
def test_calling_notify_from_script_loaded_from_yaml_with_title(self):
|
||||
"""Test if we can call a notify from a script."""
|
||||
self._setup_notify()
|
||||
conf = {
|
||||
step = {
|
||||
"service": "notify.notify",
|
||||
"data": {
|
||||
"data": {
|
||||
|
@ -153,8 +153,8 @@ class TestNotifyDemo(unittest.TestCase):
|
|||
},
|
||||
"data_template": {"message": "Test 123 {{ 2 + 2 }}\n", "title": "Test"},
|
||||
}
|
||||
|
||||
script.call_from_config(self.hass, conf)
|
||||
setup_component(self.hass, "script", {"script": {"test": {"sequence": step}}})
|
||||
self.hass.services.call("script", "test")
|
||||
self.hass.block_till_done()
|
||||
assert len(self.events) == 1
|
||||
assert {
|
||||
|
|
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue