Add support for simultaneous runs of Script helper - Part 2 (#32442)

* Add limit parameter to service call methods

* Break out prep part of async_call_from_config for use elsewhere

* Minor cleanup

* Fix improper use of asyncio.wait

* Fix state update

Call change listener immediately if its a callback

* Fix exception handling and logging

* Merge Script helper if_running/run_mode parameters into script_mode

- Remove background/blocking _ScriptRun subclasses which are no longer needed.

* Add queued script mode

* Disable timeout when making fully blocking script call

* Don't call change listener when restarting script

This makes restart mode behavior consistent with parallel & queue modes.

* Changes per review

- Call all script services (except script.turn_off) with no time limit.
- Fix handling of lock in _QueuedScriptRun and add comments to make it
  clearer how this code works.

* Changes per review 2

- Move cancel shielding "up" from _ScriptRun.async_run to Script.async_run
  (and apply to new style scripts only.) This makes sure Script class also
  properly handles cancellation which it wasn't doing before.
- In _ScriptRun._async_call_service_step, instead of using script.turn_off
  service, just cancel service call and let it handle the cancellation
  accordingly.

* Fix bugs

- Add missing call to change listener in Script.async_run
  in cancelled path.
- Cancel service task if ServiceRegistry.async_call cancelled.

* Revert last changes to ServiceRegistry.async_call

* Minor Script helper fixes & test improvements

- Don't log asyncio.CancelledError exceptions.
- Make change_listener a public attribute.
- Test overhaul
  - Parametrize tests.
  - Use common test functions.
  - Mock timeout so tests don't need to wait for real time to elapse.
  - Add common function for waiting for script action step.
This commit is contained in:
Phil Bruckner 2020-03-11 18:34:50 -05:00 committed by GitHub
parent da761fdd39
commit 5f5cb8bea8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 948 additions and 1444 deletions

View file

@ -7,6 +7,7 @@ from itertools import islice
import logging
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast
from async_timeout import timeout
import voluptuous as vol
from homeassistant import exceptions
@ -14,6 +15,7 @@ import homeassistant.components.device_automation as device_automation
import homeassistant.components.scene as scene
from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_ALIAS,
CONF_CONDITION,
CONF_CONTINUE_ON_TIMEOUT,
CONF_DELAY,
@ -25,47 +27,53 @@ from homeassistant.const import (
CONF_SCENE,
CONF_TIMEOUT,
CONF_WAIT_TEMPLATE,
SERVICE_TURN_OFF,
SERVICE_TURN_ON,
)
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback
from homeassistant.core import (
CALLBACK_TYPE,
SERVICE_CALL_LIMIT,
Context,
HomeAssistant,
callback,
is_callback,
)
from homeassistant.helpers import (
condition,
config_validation as cv,
service,
template as template,
)
from homeassistant.helpers.event import (
async_track_point_in_utc_time,
async_track_template,
)
from homeassistant.helpers.service import (
CONF_SERVICE_DATA,
async_prepare_call_from_config,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import slugify
from homeassistant.util.dt import utcnow
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
CONF_ALIAS = "alias"
IF_RUNNING_ERROR = "error"
IF_RUNNING_IGNORE = "ignore"
IF_RUNNING_PARALLEL = "parallel"
IF_RUNNING_RESTART = "restart"
# First choice is default
IF_RUNNING_CHOICES = [
IF_RUNNING_PARALLEL,
IF_RUNNING_ERROR,
IF_RUNNING_IGNORE,
IF_RUNNING_RESTART,
SCRIPT_MODE_ERROR = "error"
SCRIPT_MODE_IGNORE = "ignore"
SCRIPT_MODE_LEGACY = "legacy"
SCRIPT_MODE_PARALLEL = "parallel"
SCRIPT_MODE_QUEUE = "queue"
SCRIPT_MODE_RESTART = "restart"
SCRIPT_MODE_CHOICES = [
SCRIPT_MODE_ERROR,
SCRIPT_MODE_IGNORE,
SCRIPT_MODE_LEGACY,
SCRIPT_MODE_PARALLEL,
SCRIPT_MODE_QUEUE,
SCRIPT_MODE_RESTART,
]
DEFAULT_SCRIPT_MODE = SCRIPT_MODE_LEGACY
RUN_MODE_BACKGROUND = "background"
RUN_MODE_BLOCKING = "blocking"
RUN_MODE_LEGACY = "legacy"
# First choice is default
RUN_MODE_CHOICES = [
RUN_MODE_BLOCKING,
RUN_MODE_BACKGROUND,
RUN_MODE_LEGACY,
]
DEFAULT_QUEUE_MAX = 10
_LOG_EXCEPTION = logging.ERROR + 1
_TIMEOUT_MSG = "Timeout reached, abort script."
@ -102,6 +110,14 @@ class _SuspendScript(Exception):
"""Throw if script needs to suspend."""
class AlreadyRunning(exceptions.HomeAssistantError):
"""Throw if script already running and user wants error."""
class QueueFull(exceptions.HomeAssistantError):
"""Throw if script already running, user wants new run queued, but queue is full."""
class _ScriptRunBase(ABC):
"""Common data & methods for managing Script sequence run."""
@ -137,11 +153,11 @@ class _ScriptRunBase(ABC):
await getattr(
self, f"_async_{cv.determine_script_action(self._action)}_step"
)()
except Exception as err:
if not isinstance(err, (_SuspendScript, _StopScript)) and (
self._log_exceptions or log_exceptions
):
self._log_exception(err)
except Exception as ex:
if not isinstance(
ex, (_SuspendScript, _StopScript, asyncio.CancelledError)
) and (self._log_exceptions or log_exceptions):
self._log_exception(ex)
raise
@abstractmethod
@ -166,6 +182,12 @@ class _ScriptRunBase(ABC):
elif isinstance(exception, exceptions.ServiceNotFound):
error_desc = "Service not found"
elif isinstance(exception, AlreadyRunning):
error_desc = "Already running"
elif isinstance(exception, QueueFull):
error_desc = "Run queue is full"
else:
error_desc = "Unexpected error"
level = _LOG_EXCEPTION
@ -189,12 +211,13 @@ class _ScriptRunBase(ABC):
template.render_complex(self._action[CONF_DELAY], self._variables)
)
except (exceptions.TemplateError, vol.Invalid) as ex:
self._raise(
self._log(
"Error rendering %s delay template: %s",
self._script.name,
ex,
exception=_StopScript,
level=logging.ERROR,
)
raise _StopScript
self._script.last_action = self._action.get(CONF_ALIAS, f"delay {delay}")
self._log("Executing step %s", self._script.last_action)
@ -220,18 +243,14 @@ class _ScriptRunBase(ABC):
self._hass, wait_template, async_script_wait, self._variables
)
@abstractmethod
async def _async_call_service_step(self):
"""Call the service specified in the action."""
def _prep_call_service_step(self):
self._script.last_action = self._action.get(CONF_ALIAS, "call service")
self._log("Executing step %s", self._script.last_action)
await service.async_call_from_config(
self._hass,
self._action,
blocking=True,
variables=self._variables,
validate_config=False,
context=self._context,
)
return async_prepare_call_from_config(self._hass, self._action, self._variables)
async def _async_device_step(self):
"""Perform the device automation specified in the action."""
@ -298,10 +317,6 @@ class _ScriptRunBase(ABC):
def _log(self, msg, *args, level=logging.INFO):
self._script._log(msg, *args, level=level) # pylint: disable=protected-access
def _raise(self, msg, *args, exception=None):
# pylint: disable=protected-access
self._script._raise(msg, *args, exception=exception)
class _ScriptRun(_ScriptRunBase):
"""Manage Script sequence run."""
@ -318,24 +333,33 @@ class _ScriptRun(_ScriptRunBase):
self._stop = asyncio.Event()
self._stopped = asyncio.Event()
async def _async_run(self, propagate_exceptions=True):
self._log("Running script")
def _changed(self):
if not self._stop.is_set():
super()._changed()
async def async_run(self) -> None:
"""Run script."""
try:
if self._stop.is_set():
return
self._script.last_triggered = utcnow()
self._changed()
self._log("Running script")
for self._step, self._action in enumerate(self._script.sequence):
if self._stop.is_set():
break
await self._async_step(not propagate_exceptions)
await self._async_step(log_exceptions=False)
except _StopScript:
pass
except Exception: # pylint: disable=broad-except
if propagate_exceptions:
raise
finally:
if not self._stop.is_set():
self._changed()
self._finish()
def _finish(self):
self._script._runs.remove(self) # pylint: disable=protected-access
if not self._script.is_running:
self._script.last_action = None
self._script._runs.remove(self) # pylint: disable=protected-access
self._stopped.set()
self._changed()
self._stopped.set()
async def async_stop(self) -> None:
"""Stop script run."""
@ -344,10 +368,13 @@ class _ScriptRun(_ScriptRunBase):
async def _async_delay_step(self):
"""Handle delay."""
timeout = self._prep_delay_step().total_seconds()
if not self._stop.is_set():
self._changed()
await asyncio.wait({self._stop.wait()}, timeout=timeout)
delay = self._prep_delay_step().total_seconds()
self._changed()
try:
async with timeout(delay):
await self._stop.wait()
except asyncio.TimeoutError:
pass
async def _async_wait_template_step(self):
"""Handle a wait template."""
@ -361,21 +388,20 @@ class _ScriptRun(_ScriptRunBase):
if not unsub:
return
if not self._stop.is_set():
self._changed()
self._changed()
try:
timeout = self._action[CONF_TIMEOUT].total_seconds()
delay = self._action[CONF_TIMEOUT].total_seconds()
except KeyError:
timeout = None
delay = None
done = asyncio.Event()
try:
await asyncio.wait_for(
asyncio.wait(
async with timeout(delay):
_, pending = await asyncio.wait(
{self._stop.wait(), done.wait()},
return_when=asyncio.FIRST_COMPLETED,
),
timeout,
)
)
for pending_task in pending:
pending_task.cancel()
except asyncio.TimeoutError:
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
self._log(_TIMEOUT_MSG)
@ -383,25 +409,78 @@ class _ScriptRun(_ScriptRunBase):
finally:
unsub()
async def _async_call_service_step(self):
"""Call the service specified in the action."""
domain, service, service_data = self._prep_call_service_step()
class _BackgroundScriptRun(_ScriptRun):
"""Manage background Script sequence run."""
# If this might start a script then disable the call timeout.
# Otherwise use the normal service call limit.
if domain == "script" and service != SERVICE_TURN_OFF:
limit = None
else:
limit = SERVICE_CALL_LIMIT
coro = self._hass.services.async_call(
domain,
service,
service_data,
blocking=True,
context=self._context,
limit=limit,
)
if limit is not None:
# There is a call limit, so just wait for it to finish.
await coro
return
# No call limit (i.e., potentially starting one or more fully blocking scripts)
# so watch for a stop request.
done, pending = await asyncio.wait(
{self._stop.wait(), coro}, return_when=asyncio.FIRST_COMPLETED,
)
# Note that cancelling the service call, if it has not yet returned, will also
# stop any non-background script runs that it may have started.
for pending_task in pending:
pending_task.cancel()
# Propagate any exceptions that might have happened.
for done_task in done:
done_task.result()
class _QueuedScriptRun(_ScriptRun):
"""Manage queued Script sequence run."""
lock_acquired = False
async def async_run(self) -> None:
"""Run script."""
self._hass.async_create_task(self._async_run(False))
# Wait for previous run, if any, to finish by attempting to acquire the script's
# shared lock. At the same time monitor if we've been told to stop.
lock_task = self._hass.async_create_task(
self._script._queue_lck.acquire() # pylint: disable=protected-access
)
done, pending = await asyncio.wait(
{self._stop.wait(), lock_task}, return_when=asyncio.FIRST_COMPLETED
)
for pending_task in pending:
pending_task.cancel()
self.lock_acquired = lock_task in done
# If we've been told to stop, then just finish up. Otherwise, we've acquired the
# lock so we can go ahead and start the run.
if self._stop.is_set():
self._finish()
else:
await super().async_run()
class _BlockingScriptRun(_ScriptRun):
"""Manage blocking Script sequence run."""
async def async_run(self) -> None:
"""Run script."""
try:
await asyncio.shield(self._async_run())
except asyncio.CancelledError:
await self.async_stop()
raise
def _finish(self):
# pylint: disable=protected-access
self._script._queue_len -= 1
if self.lock_acquired:
self._script._queue_lck.release()
self.lock_acquired = False
super()._finish()
class _LegacyScriptRun(_ScriptRunBase):
@ -445,6 +524,7 @@ class _LegacyScriptRun(_ScriptRunBase):
async def _async_run(self, propagate_exceptions=True):
if self._cur == -1:
self._script.last_triggered = utcnow()
self._log("Running script")
self._cur = 0
@ -457,7 +537,7 @@ class _LegacyScriptRun(_ScriptRunBase):
for self._step, self._action in islice(
enumerate(self._script.sequence), self._cur, None
):
await self._async_step(not propagate_exceptions)
await self._async_step(log_exceptions=not propagate_exceptions)
except _StopScript:
pass
except _SuspendScript:
@ -469,11 +549,12 @@ class _LegacyScriptRun(_ScriptRunBase):
if propagate_exceptions:
raise
finally:
if self._cur != -1:
self._changed()
_cur_was = self._cur
if not suspended:
self._script.last_action = None
await self.async_stop()
if _cur_was != -1:
self._changed()
async def async_stop(self) -> None:
"""Stop script run."""
@ -512,9 +593,9 @@ class _LegacyScriptRun(_ScriptRunBase):
@callback
def async_script_timeout(now):
"""Call after timeout is retrieve."""
"""Call after timeout has expired."""
with suppress(ValueError):
self._async_listener.remove(unsub)
self._async_listener.remove(unsub_timeout)
# Check if we want to continue to execute
# the script after the timeout
@ -530,13 +611,19 @@ class _LegacyScriptRun(_ScriptRunBase):
self._async_listener.append(unsub_wait)
if CONF_TIMEOUT in self._action:
unsub = async_track_point_in_utc_time(
unsub_timeout = async_track_point_in_utc_time(
self._hass, async_script_timeout, utcnow() + self._action[CONF_TIMEOUT]
)
self._async_listener.append(unsub)
self._async_listener.append(unsub_timeout)
raise _SuspendScript
async def _async_call_service_step(self):
"""Call the service specified in the action."""
await self._hass.services.async_call(
*self._prep_call_service_step(), blocking=True, context=self._context
)
def _async_remove_listener(self):
"""Remove listeners, if any."""
for unsub in self._async_listener:
@ -553,47 +640,60 @@ class Script:
sequence: Sequence[Dict[str, Any]],
name: Optional[str] = None,
change_listener: Optional[Callable[..., Any]] = None,
if_running: Optional[str] = None,
run_mode: Optional[str] = None,
script_mode: str = DEFAULT_SCRIPT_MODE,
queue_max: int = DEFAULT_QUEUE_MAX,
logger: Optional[logging.Logger] = None,
log_exceptions: bool = True,
) -> None:
"""Initialize the script."""
self._logger = logger or logging.getLogger(__name__)
self._hass = hass
self.sequence = sequence
template.attach(hass, self.sequence)
self.name = name
self._change_listener = change_listener
self.change_listener = change_listener
self._script_mode = script_mode
if logger:
self._logger = logger
else:
logger_name = __name__
if name:
logger_name = ".".join([logger_name, slugify(name)])
self._logger = logging.getLogger(logger_name)
self._log_exceptions = log_exceptions
self.last_action = None
self.last_triggered: Optional[datetime] = None
self.can_cancel = any(
self.can_cancel = not self.is_legacy or any(
CONF_DELAY in action or CONF_WAIT_TEMPLATE in action
for action in self.sequence
)
if not if_running and not run_mode:
self._if_running = IF_RUNNING_PARALLEL
self._run_mode = RUN_MODE_LEGACY
elif if_running and run_mode == RUN_MODE_LEGACY:
self._raise('Cannot use if_running if run_mode is "legacy"')
else:
self._if_running = if_running or IF_RUNNING_CHOICES[0]
self._run_mode = run_mode or RUN_MODE_CHOICES[0]
self._runs: List[_ScriptRunBase] = []
self._log_exceptions = log_exceptions
if script_mode == SCRIPT_MODE_QUEUE:
self._queue_max = queue_max
self._queue_len = 0
self._queue_lck = asyncio.Lock()
self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {}
self._referenced_entities: Optional[Set[str]] = None
self._referenced_devices: Optional[Set[str]] = None
def _changed(self):
if self._change_listener:
self._hass.async_add_job(self._change_listener)
if self.change_listener:
if is_callback(self.change_listener):
self.change_listener()
else:
self._hass.async_add_job(self.change_listener)
@property
def is_running(self) -> bool:
"""Return true if script is on."""
return len(self._runs) > 0
@property
def is_legacy(self) -> bool:
"""Return if using legacy mode."""
return self._script_mode == SCRIPT_MODE_LEGACY
@property
def referenced_devices(self):
"""Return a set of referenced devices."""
@ -626,7 +726,7 @@ class Script:
action = cv.determine_script_action(step)
if action == cv.SCRIPT_ACTION_CALL_SERVICE:
data = step.get(service.CONF_SERVICE_DATA)
data = step.get(CONF_SERVICE_DATA)
if not data:
continue
@ -661,18 +761,26 @@ class Script:
) -> None:
"""Run script."""
if self.is_running:
if self._if_running == IF_RUNNING_IGNORE:
if self._script_mode == SCRIPT_MODE_IGNORE:
self._log("Skipping script")
return
if self._if_running == IF_RUNNING_ERROR:
self._raise("Already running")
if self._if_running == IF_RUNNING_RESTART:
self._log("Restarting script")
await self.async_stop()
if self._script_mode == SCRIPT_MODE_ERROR:
raise AlreadyRunning
self.last_triggered = utcnow()
if self._run_mode == RUN_MODE_LEGACY:
if self._script_mode == SCRIPT_MODE_RESTART:
self._log("Restarting script")
await self.async_stop(update_state=False)
elif self._script_mode == SCRIPT_MODE_QUEUE:
self._log(
"Queueing script behind %i run%s",
self._queue_len,
"s" if self._queue_len > 1 else "",
)
if self._queue_len >= self._queue_max:
raise QueueFull
if self.is_legacy:
if self._runs:
shared = cast(Optional[_LegacyScriptRun], self._runs[0])
else:
@ -681,23 +789,31 @@ class Script:
self._hass, self, variables, context, self._log_exceptions, shared
)
else:
if self._run_mode == RUN_MODE_BACKGROUND:
run = _BackgroundScriptRun(
self._hass, self, variables, context, self._log_exceptions
)
if self._script_mode != SCRIPT_MODE_QUEUE:
cls = _ScriptRun
else:
run = _BlockingScriptRun(
self._hass, self, variables, context, self._log_exceptions
)
cls = _QueuedScriptRun
self._queue_len += 1
run = cls(self._hass, self, variables, context, self._log_exceptions)
self._runs.append(run)
await run.async_run()
async def async_stop(self) -> None:
try:
if self.is_legacy:
await run.async_run()
else:
await asyncio.shield(run.async_run())
except asyncio.CancelledError:
await run.async_stop()
self._changed()
raise
async def async_stop(self, update_state: bool = True) -> None:
"""Stop running script."""
if not self.is_running:
return
await asyncio.shield(asyncio.gather(*(run.async_stop() for run in self._runs)))
self._changed()
if update_state:
self._changed()
def _log(self, msg, *args, level=logging.INFO):
if self.name:
@ -706,9 +822,3 @@ class Script:
self._logger.exception(msg, *args)
else:
self._logger.log(level, msg, *args)
def _raise(self, msg, *args, exception=None):
if not exception:
exception = exceptions.HomeAssistantError
self._log(msg, *args, level=logging.ERROR)
raise exception(msg % args)