Add support for simultaneous runs of Script helper - Part 2 (#32442)

* Add limit parameter to service call methods

* Break out prep part of async_call_from_config for use elsewhere

* Minor cleanup

* Fix improper use of asyncio.wait

* Fix state update

Call change listener immediately if its a callback

* Fix exception handling and logging

* Merge Script helper if_running/run_mode parameters into script_mode

- Remove background/blocking _ScriptRun subclasses which are no longer needed.

* Add queued script mode

* Disable timeout when making fully blocking script call

* Don't call change listener when restarting script

This makes restart mode behavior consistent with parallel & queue modes.

* Changes per review

- Call all script services (except script.turn_off) with no time limit.
- Fix handling of lock in _QueuedScriptRun and add comments to make it
  clearer how this code works.

* Changes per review 2

- Move cancel shielding "up" from _ScriptRun.async_run to Script.async_run
  (and apply to new style scripts only.) This makes sure Script class also
  properly handles cancellation which it wasn't doing before.
- In _ScriptRun._async_call_service_step, instead of using script.turn_off
  service, just cancel service call and let it handle the cancellation
  accordingly.

* Fix bugs

- Add missing call to change listener in Script.async_run
  in cancelled path.
- Cancel service task if ServiceRegistry.async_call cancelled.

* Revert last changes to ServiceRegistry.async_call

* Minor Script helper fixes & test improvements

- Don't log asyncio.CancelledError exceptions.
- Make change_listener a public attribute.
- Test overhaul
  - Parametrize tests.
  - Use common test functions.
  - Mock timeout so tests don't need to wait for real time to elapse.
  - Add common function for waiting for script action step.
This commit is contained in:
Phil Bruckner 2020-03-11 18:34:50 -05:00 committed by GitHub
parent da761fdd39
commit 5f5cb8bea8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 948 additions and 1444 deletions

View file

@ -93,7 +93,7 @@ SOURCE_DISCOVERED = "discovered"
SOURCE_STORAGE = "storage" SOURCE_STORAGE = "storage"
SOURCE_YAML = "yaml" SOURCE_YAML = "yaml"
# How long to wait till things that run on startup have to finish. # How long to wait until things that run on startup have to finish.
TIMEOUT_EVENT_START = 15 TIMEOUT_EVENT_START = 15
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -249,7 +249,7 @@ class HomeAssistant:
try: try:
# Only block for EVENT_HOMEASSISTANT_START listener # Only block for EVENT_HOMEASSISTANT_START listener
self.async_stop_track_tasks() self.async_stop_track_tasks()
with timeout(TIMEOUT_EVENT_START): async with timeout(TIMEOUT_EVENT_START):
await self.async_block_till_done() await self.async_block_till_done()
except asyncio.TimeoutError: except asyncio.TimeoutError:
_LOGGER.warning( _LOGGER.warning(
@ -374,13 +374,13 @@ class HomeAssistant:
self.async_add_job(target, *args) self.async_add_job(target, *args)
def block_till_done(self) -> None: def block_till_done(self) -> None:
"""Block till all pending work is done.""" """Block until all pending work is done."""
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(
self.async_block_till_done(), self.loop self.async_block_till_done(), self.loop
).result() ).result()
async def async_block_till_done(self) -> None: async def async_block_till_done(self) -> None:
"""Block till all pending work is done.""" """Block until all pending work is done."""
# To flush out any call_soon_threadsafe # To flush out any call_soon_threadsafe
await asyncio.sleep(0) await asyncio.sleep(0)
@ -1150,25 +1150,15 @@ class ServiceRegistry:
service_data: Optional[Dict] = None, service_data: Optional[Dict] = None,
blocking: bool = False, blocking: bool = False,
context: Optional[Context] = None, context: Optional[Context] = None,
limit: Optional[float] = SERVICE_CALL_LIMIT,
) -> Optional[bool]: ) -> Optional[bool]:
""" """
Call a service. Call a service.
Specify blocking=True to wait till service is executed. See description of async_call for details.
Waits a maximum of SERVICE_CALL_LIMIT.
If blocking = True, will return boolean if service executed
successfully within SERVICE_CALL_LIMIT.
This method will fire an event to call the service.
This event will be picked up by this ServiceRegistry and any
other ServiceRegistry that is listening on the EventBus.
Because the service is sent as an event you are not allowed to use
the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data.
""" """
return asyncio.run_coroutine_threadsafe( return asyncio.run_coroutine_threadsafe(
self.async_call(domain, service, service_data, blocking, context), self.async_call(domain, service, service_data, blocking, context, limit),
self._hass.loop, self._hass.loop,
).result() ).result()
@ -1179,19 +1169,18 @@ class ServiceRegistry:
service_data: Optional[Dict] = None, service_data: Optional[Dict] = None,
blocking: bool = False, blocking: bool = False,
context: Optional[Context] = None, context: Optional[Context] = None,
limit: Optional[float] = SERVICE_CALL_LIMIT,
) -> Optional[bool]: ) -> Optional[bool]:
""" """
Call a service. Call a service.
Specify blocking=True to wait till service is executed. Specify blocking=True to wait until service is executed.
Waits a maximum of SERVICE_CALL_LIMIT. Waits a maximum of limit, which may be None for no timeout.
If blocking = True, will return boolean if service executed If blocking = True, will return boolean if service executed
successfully within SERVICE_CALL_LIMIT. successfully within limit.
This method will fire an event to call the service. This method will fire an event to indicate the service has been called.
This event will be picked up by this ServiceRegistry and any
other ServiceRegistry that is listening on the EventBus.
Because the service is sent as an event you are not allowed to use Because the service is sent as an event you are not allowed to use
the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data. the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data.
@ -1230,7 +1219,7 @@ class ServiceRegistry:
return None return None
try: try:
with timeout(SERVICE_CALL_LIMIT): async with timeout(limit):
await asyncio.shield(self._execute_service(handler, service_call)) await asyncio.shield(self._execute_service(handler, service_call))
return True return True
except asyncio.TimeoutError: except asyncio.TimeoutError:

View file

@ -7,6 +7,7 @@ from itertools import islice
import logging import logging
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, cast
from async_timeout import timeout
import voluptuous as vol import voluptuous as vol
from homeassistant import exceptions from homeassistant import exceptions
@ -14,6 +15,7 @@ import homeassistant.components.device_automation as device_automation
import homeassistant.components.scene as scene import homeassistant.components.scene as scene
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
CONF_ALIAS,
CONF_CONDITION, CONF_CONDITION,
CONF_CONTINUE_ON_TIMEOUT, CONF_CONTINUE_ON_TIMEOUT,
CONF_DELAY, CONF_DELAY,
@ -25,47 +27,53 @@ from homeassistant.const import (
CONF_SCENE, CONF_SCENE,
CONF_TIMEOUT, CONF_TIMEOUT,
CONF_WAIT_TEMPLATE, CONF_WAIT_TEMPLATE,
SERVICE_TURN_OFF,
SERVICE_TURN_ON, SERVICE_TURN_ON,
) )
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback from homeassistant.core import (
CALLBACK_TYPE,
SERVICE_CALL_LIMIT,
Context,
HomeAssistant,
callback,
is_callback,
)
from homeassistant.helpers import ( from homeassistant.helpers import (
condition, condition,
config_validation as cv, config_validation as cv,
service,
template as template, template as template,
) )
from homeassistant.helpers.event import ( from homeassistant.helpers.event import (
async_track_point_in_utc_time, async_track_point_in_utc_time,
async_track_template, async_track_template,
) )
from homeassistant.helpers.service import (
CONF_SERVICE_DATA,
async_prepare_call_from_config,
)
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.util import slugify
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
CONF_ALIAS = "alias" SCRIPT_MODE_ERROR = "error"
SCRIPT_MODE_IGNORE = "ignore"
IF_RUNNING_ERROR = "error" SCRIPT_MODE_LEGACY = "legacy"
IF_RUNNING_IGNORE = "ignore" SCRIPT_MODE_PARALLEL = "parallel"
IF_RUNNING_PARALLEL = "parallel" SCRIPT_MODE_QUEUE = "queue"
IF_RUNNING_RESTART = "restart" SCRIPT_MODE_RESTART = "restart"
# First choice is default SCRIPT_MODE_CHOICES = [
IF_RUNNING_CHOICES = [ SCRIPT_MODE_ERROR,
IF_RUNNING_PARALLEL, SCRIPT_MODE_IGNORE,
IF_RUNNING_ERROR, SCRIPT_MODE_LEGACY,
IF_RUNNING_IGNORE, SCRIPT_MODE_PARALLEL,
IF_RUNNING_RESTART, SCRIPT_MODE_QUEUE,
SCRIPT_MODE_RESTART,
] ]
DEFAULT_SCRIPT_MODE = SCRIPT_MODE_LEGACY
RUN_MODE_BACKGROUND = "background" DEFAULT_QUEUE_MAX = 10
RUN_MODE_BLOCKING = "blocking"
RUN_MODE_LEGACY = "legacy"
# First choice is default
RUN_MODE_CHOICES = [
RUN_MODE_BLOCKING,
RUN_MODE_BACKGROUND,
RUN_MODE_LEGACY,
]
_LOG_EXCEPTION = logging.ERROR + 1 _LOG_EXCEPTION = logging.ERROR + 1
_TIMEOUT_MSG = "Timeout reached, abort script." _TIMEOUT_MSG = "Timeout reached, abort script."
@ -102,6 +110,14 @@ class _SuspendScript(Exception):
"""Throw if script needs to suspend.""" """Throw if script needs to suspend."""
class AlreadyRunning(exceptions.HomeAssistantError):
"""Throw if script already running and user wants error."""
class QueueFull(exceptions.HomeAssistantError):
"""Throw if script already running, user wants new run queued, but queue is full."""
class _ScriptRunBase(ABC): class _ScriptRunBase(ABC):
"""Common data & methods for managing Script sequence run.""" """Common data & methods for managing Script sequence run."""
@ -137,11 +153,11 @@ class _ScriptRunBase(ABC):
await getattr( await getattr(
self, f"_async_{cv.determine_script_action(self._action)}_step" self, f"_async_{cv.determine_script_action(self._action)}_step"
)() )()
except Exception as err: except Exception as ex:
if not isinstance(err, (_SuspendScript, _StopScript)) and ( if not isinstance(
self._log_exceptions or log_exceptions ex, (_SuspendScript, _StopScript, asyncio.CancelledError)
): ) and (self._log_exceptions or log_exceptions):
self._log_exception(err) self._log_exception(ex)
raise raise
@abstractmethod @abstractmethod
@ -166,6 +182,12 @@ class _ScriptRunBase(ABC):
elif isinstance(exception, exceptions.ServiceNotFound): elif isinstance(exception, exceptions.ServiceNotFound):
error_desc = "Service not found" error_desc = "Service not found"
elif isinstance(exception, AlreadyRunning):
error_desc = "Already running"
elif isinstance(exception, QueueFull):
error_desc = "Run queue is full"
else: else:
error_desc = "Unexpected error" error_desc = "Unexpected error"
level = _LOG_EXCEPTION level = _LOG_EXCEPTION
@ -189,12 +211,13 @@ class _ScriptRunBase(ABC):
template.render_complex(self._action[CONF_DELAY], self._variables) template.render_complex(self._action[CONF_DELAY], self._variables)
) )
except (exceptions.TemplateError, vol.Invalid) as ex: except (exceptions.TemplateError, vol.Invalid) as ex:
self._raise( self._log(
"Error rendering %s delay template: %s", "Error rendering %s delay template: %s",
self._script.name, self._script.name,
ex, ex,
exception=_StopScript, level=logging.ERROR,
) )
raise _StopScript
self._script.last_action = self._action.get(CONF_ALIAS, f"delay {delay}") self._script.last_action = self._action.get(CONF_ALIAS, f"delay {delay}")
self._log("Executing step %s", self._script.last_action) self._log("Executing step %s", self._script.last_action)
@ -220,18 +243,14 @@ class _ScriptRunBase(ABC):
self._hass, wait_template, async_script_wait, self._variables self._hass, wait_template, async_script_wait, self._variables
) )
@abstractmethod
async def _async_call_service_step(self): async def _async_call_service_step(self):
"""Call the service specified in the action.""" """Call the service specified in the action."""
def _prep_call_service_step(self):
self._script.last_action = self._action.get(CONF_ALIAS, "call service") self._script.last_action = self._action.get(CONF_ALIAS, "call service")
self._log("Executing step %s", self._script.last_action) self._log("Executing step %s", self._script.last_action)
await service.async_call_from_config( return async_prepare_call_from_config(self._hass, self._action, self._variables)
self._hass,
self._action,
blocking=True,
variables=self._variables,
validate_config=False,
context=self._context,
)
async def _async_device_step(self): async def _async_device_step(self):
"""Perform the device automation specified in the action.""" """Perform the device automation specified in the action."""
@ -298,10 +317,6 @@ class _ScriptRunBase(ABC):
def _log(self, msg, *args, level=logging.INFO): def _log(self, msg, *args, level=logging.INFO):
self._script._log(msg, *args, level=level) # pylint: disable=protected-access self._script._log(msg, *args, level=level) # pylint: disable=protected-access
def _raise(self, msg, *args, exception=None):
# pylint: disable=protected-access
self._script._raise(msg, *args, exception=exception)
class _ScriptRun(_ScriptRunBase): class _ScriptRun(_ScriptRunBase):
"""Manage Script sequence run.""" """Manage Script sequence run."""
@ -318,24 +333,33 @@ class _ScriptRun(_ScriptRunBase):
self._stop = asyncio.Event() self._stop = asyncio.Event()
self._stopped = asyncio.Event() self._stopped = asyncio.Event()
async def _async_run(self, propagate_exceptions=True): def _changed(self):
self._log("Running script") if not self._stop.is_set():
super()._changed()
async def async_run(self) -> None:
"""Run script."""
try: try:
if self._stop.is_set():
return
self._script.last_triggered = utcnow()
self._changed()
self._log("Running script")
for self._step, self._action in enumerate(self._script.sequence): for self._step, self._action in enumerate(self._script.sequence):
if self._stop.is_set(): if self._stop.is_set():
break break
await self._async_step(not propagate_exceptions) await self._async_step(log_exceptions=False)
except _StopScript: except _StopScript:
pass pass
except Exception: # pylint: disable=broad-except
if propagate_exceptions:
raise
finally: finally:
if not self._stop.is_set(): self._finish()
self._changed()
def _finish(self):
self._script._runs.remove(self) # pylint: disable=protected-access
if not self._script.is_running:
self._script.last_action = None self._script.last_action = None
self._script._runs.remove(self) # pylint: disable=protected-access self._changed()
self._stopped.set() self._stopped.set()
async def async_stop(self) -> None: async def async_stop(self) -> None:
"""Stop script run.""" """Stop script run."""
@ -344,10 +368,13 @@ class _ScriptRun(_ScriptRunBase):
async def _async_delay_step(self): async def _async_delay_step(self):
"""Handle delay.""" """Handle delay."""
timeout = self._prep_delay_step().total_seconds() delay = self._prep_delay_step().total_seconds()
if not self._stop.is_set(): self._changed()
self._changed() try:
await asyncio.wait({self._stop.wait()}, timeout=timeout) async with timeout(delay):
await self._stop.wait()
except asyncio.TimeoutError:
pass
async def _async_wait_template_step(self): async def _async_wait_template_step(self):
"""Handle a wait template.""" """Handle a wait template."""
@ -361,21 +388,20 @@ class _ScriptRun(_ScriptRunBase):
if not unsub: if not unsub:
return return
if not self._stop.is_set(): self._changed()
self._changed()
try: try:
timeout = self._action[CONF_TIMEOUT].total_seconds() delay = self._action[CONF_TIMEOUT].total_seconds()
except KeyError: except KeyError:
timeout = None delay = None
done = asyncio.Event() done = asyncio.Event()
try: try:
await asyncio.wait_for( async with timeout(delay):
asyncio.wait( _, pending = await asyncio.wait(
{self._stop.wait(), done.wait()}, {self._stop.wait(), done.wait()},
return_when=asyncio.FIRST_COMPLETED, return_when=asyncio.FIRST_COMPLETED,
), )
timeout, for pending_task in pending:
) pending_task.cancel()
except asyncio.TimeoutError: except asyncio.TimeoutError:
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True): if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
self._log(_TIMEOUT_MSG) self._log(_TIMEOUT_MSG)
@ -383,25 +409,78 @@ class _ScriptRun(_ScriptRunBase):
finally: finally:
unsub() unsub()
async def _async_call_service_step(self):
"""Call the service specified in the action."""
domain, service, service_data = self._prep_call_service_step()
class _BackgroundScriptRun(_ScriptRun): # If this might start a script then disable the call timeout.
"""Manage background Script sequence run.""" # Otherwise use the normal service call limit.
if domain == "script" and service != SERVICE_TURN_OFF:
limit = None
else:
limit = SERVICE_CALL_LIMIT
coro = self._hass.services.async_call(
domain,
service,
service_data,
blocking=True,
context=self._context,
limit=limit,
)
if limit is not None:
# There is a call limit, so just wait for it to finish.
await coro
return
# No call limit (i.e., potentially starting one or more fully blocking scripts)
# so watch for a stop request.
done, pending = await asyncio.wait(
{self._stop.wait(), coro}, return_when=asyncio.FIRST_COMPLETED,
)
# Note that cancelling the service call, if it has not yet returned, will also
# stop any non-background script runs that it may have started.
for pending_task in pending:
pending_task.cancel()
# Propagate any exceptions that might have happened.
for done_task in done:
done_task.result()
class _QueuedScriptRun(_ScriptRun):
"""Manage queued Script sequence run."""
lock_acquired = False
async def async_run(self) -> None: async def async_run(self) -> None:
"""Run script.""" """Run script."""
self._hass.async_create_task(self._async_run(False)) # Wait for previous run, if any, to finish by attempting to acquire the script's
# shared lock. At the same time monitor if we've been told to stop.
lock_task = self._hass.async_create_task(
self._script._queue_lck.acquire() # pylint: disable=protected-access
)
done, pending = await asyncio.wait(
{self._stop.wait(), lock_task}, return_when=asyncio.FIRST_COMPLETED
)
for pending_task in pending:
pending_task.cancel()
self.lock_acquired = lock_task in done
# If we've been told to stop, then just finish up. Otherwise, we've acquired the
# lock so we can go ahead and start the run.
if self._stop.is_set():
self._finish()
else:
await super().async_run()
class _BlockingScriptRun(_ScriptRun): def _finish(self):
"""Manage blocking Script sequence run.""" # pylint: disable=protected-access
self._script._queue_len -= 1
async def async_run(self) -> None: if self.lock_acquired:
"""Run script.""" self._script._queue_lck.release()
try: self.lock_acquired = False
await asyncio.shield(self._async_run()) super()._finish()
except asyncio.CancelledError:
await self.async_stop()
raise
class _LegacyScriptRun(_ScriptRunBase): class _LegacyScriptRun(_ScriptRunBase):
@ -445,6 +524,7 @@ class _LegacyScriptRun(_ScriptRunBase):
async def _async_run(self, propagate_exceptions=True): async def _async_run(self, propagate_exceptions=True):
if self._cur == -1: if self._cur == -1:
self._script.last_triggered = utcnow()
self._log("Running script") self._log("Running script")
self._cur = 0 self._cur = 0
@ -457,7 +537,7 @@ class _LegacyScriptRun(_ScriptRunBase):
for self._step, self._action in islice( for self._step, self._action in islice(
enumerate(self._script.sequence), self._cur, None enumerate(self._script.sequence), self._cur, None
): ):
await self._async_step(not propagate_exceptions) await self._async_step(log_exceptions=not propagate_exceptions)
except _StopScript: except _StopScript:
pass pass
except _SuspendScript: except _SuspendScript:
@ -469,11 +549,12 @@ class _LegacyScriptRun(_ScriptRunBase):
if propagate_exceptions: if propagate_exceptions:
raise raise
finally: finally:
if self._cur != -1: _cur_was = self._cur
self._changed()
if not suspended: if not suspended:
self._script.last_action = None self._script.last_action = None
await self.async_stop() await self.async_stop()
if _cur_was != -1:
self._changed()
async def async_stop(self) -> None: async def async_stop(self) -> None:
"""Stop script run.""" """Stop script run."""
@ -512,9 +593,9 @@ class _LegacyScriptRun(_ScriptRunBase):
@callback @callback
def async_script_timeout(now): def async_script_timeout(now):
"""Call after timeout is retrieve.""" """Call after timeout has expired."""
with suppress(ValueError): with suppress(ValueError):
self._async_listener.remove(unsub) self._async_listener.remove(unsub_timeout)
# Check if we want to continue to execute # Check if we want to continue to execute
# the script after the timeout # the script after the timeout
@ -530,13 +611,19 @@ class _LegacyScriptRun(_ScriptRunBase):
self._async_listener.append(unsub_wait) self._async_listener.append(unsub_wait)
if CONF_TIMEOUT in self._action: if CONF_TIMEOUT in self._action:
unsub = async_track_point_in_utc_time( unsub_timeout = async_track_point_in_utc_time(
self._hass, async_script_timeout, utcnow() + self._action[CONF_TIMEOUT] self._hass, async_script_timeout, utcnow() + self._action[CONF_TIMEOUT]
) )
self._async_listener.append(unsub) self._async_listener.append(unsub_timeout)
raise _SuspendScript raise _SuspendScript
async def _async_call_service_step(self):
"""Call the service specified in the action."""
await self._hass.services.async_call(
*self._prep_call_service_step(), blocking=True, context=self._context
)
def _async_remove_listener(self): def _async_remove_listener(self):
"""Remove listeners, if any.""" """Remove listeners, if any."""
for unsub in self._async_listener: for unsub in self._async_listener:
@ -553,47 +640,60 @@ class Script:
sequence: Sequence[Dict[str, Any]], sequence: Sequence[Dict[str, Any]],
name: Optional[str] = None, name: Optional[str] = None,
change_listener: Optional[Callable[..., Any]] = None, change_listener: Optional[Callable[..., Any]] = None,
if_running: Optional[str] = None, script_mode: str = DEFAULT_SCRIPT_MODE,
run_mode: Optional[str] = None, queue_max: int = DEFAULT_QUEUE_MAX,
logger: Optional[logging.Logger] = None, logger: Optional[logging.Logger] = None,
log_exceptions: bool = True, log_exceptions: bool = True,
) -> None: ) -> None:
"""Initialize the script.""" """Initialize the script."""
self._logger = logger or logging.getLogger(__name__)
self._hass = hass self._hass = hass
self.sequence = sequence self.sequence = sequence
template.attach(hass, self.sequence) template.attach(hass, self.sequence)
self.name = name self.name = name
self._change_listener = change_listener self.change_listener = change_listener
self._script_mode = script_mode
if logger:
self._logger = logger
else:
logger_name = __name__
if name:
logger_name = ".".join([logger_name, slugify(name)])
self._logger = logging.getLogger(logger_name)
self._log_exceptions = log_exceptions
self.last_action = None self.last_action = None
self.last_triggered: Optional[datetime] = None self.last_triggered: Optional[datetime] = None
self.can_cancel = any( self.can_cancel = not self.is_legacy or any(
CONF_DELAY in action or CONF_WAIT_TEMPLATE in action CONF_DELAY in action or CONF_WAIT_TEMPLATE in action
for action in self.sequence for action in self.sequence
) )
if not if_running and not run_mode:
self._if_running = IF_RUNNING_PARALLEL
self._run_mode = RUN_MODE_LEGACY
elif if_running and run_mode == RUN_MODE_LEGACY:
self._raise('Cannot use if_running if run_mode is "legacy"')
else:
self._if_running = if_running or IF_RUNNING_CHOICES[0]
self._run_mode = run_mode or RUN_MODE_CHOICES[0]
self._runs: List[_ScriptRunBase] = [] self._runs: List[_ScriptRunBase] = []
self._log_exceptions = log_exceptions if script_mode == SCRIPT_MODE_QUEUE:
self._queue_max = queue_max
self._queue_len = 0
self._queue_lck = asyncio.Lock()
self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {} self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {}
self._referenced_entities: Optional[Set[str]] = None self._referenced_entities: Optional[Set[str]] = None
self._referenced_devices: Optional[Set[str]] = None self._referenced_devices: Optional[Set[str]] = None
def _changed(self): def _changed(self):
if self._change_listener: if self.change_listener:
self._hass.async_add_job(self._change_listener) if is_callback(self.change_listener):
self.change_listener()
else:
self._hass.async_add_job(self.change_listener)
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
"""Return true if script is on.""" """Return true if script is on."""
return len(self._runs) > 0 return len(self._runs) > 0
@property
def is_legacy(self) -> bool:
"""Return if using legacy mode."""
return self._script_mode == SCRIPT_MODE_LEGACY
@property @property
def referenced_devices(self): def referenced_devices(self):
"""Return a set of referenced devices.""" """Return a set of referenced devices."""
@ -626,7 +726,7 @@ class Script:
action = cv.determine_script_action(step) action = cv.determine_script_action(step)
if action == cv.SCRIPT_ACTION_CALL_SERVICE: if action == cv.SCRIPT_ACTION_CALL_SERVICE:
data = step.get(service.CONF_SERVICE_DATA) data = step.get(CONF_SERVICE_DATA)
if not data: if not data:
continue continue
@ -661,18 +761,26 @@ class Script:
) -> None: ) -> None:
"""Run script.""" """Run script."""
if self.is_running: if self.is_running:
if self._if_running == IF_RUNNING_IGNORE: if self._script_mode == SCRIPT_MODE_IGNORE:
self._log("Skipping script") self._log("Skipping script")
return return
if self._if_running == IF_RUNNING_ERROR: if self._script_mode == SCRIPT_MODE_ERROR:
self._raise("Already running") raise AlreadyRunning
if self._if_running == IF_RUNNING_RESTART:
self._log("Restarting script")
await self.async_stop()
self.last_triggered = utcnow() if self._script_mode == SCRIPT_MODE_RESTART:
if self._run_mode == RUN_MODE_LEGACY: self._log("Restarting script")
await self.async_stop(update_state=False)
elif self._script_mode == SCRIPT_MODE_QUEUE:
self._log(
"Queueing script behind %i run%s",
self._queue_len,
"s" if self._queue_len > 1 else "",
)
if self._queue_len >= self._queue_max:
raise QueueFull
if self.is_legacy:
if self._runs: if self._runs:
shared = cast(Optional[_LegacyScriptRun], self._runs[0]) shared = cast(Optional[_LegacyScriptRun], self._runs[0])
else: else:
@ -681,23 +789,31 @@ class Script:
self._hass, self, variables, context, self._log_exceptions, shared self._hass, self, variables, context, self._log_exceptions, shared
) )
else: else:
if self._run_mode == RUN_MODE_BACKGROUND: if self._script_mode != SCRIPT_MODE_QUEUE:
run = _BackgroundScriptRun( cls = _ScriptRun
self._hass, self, variables, context, self._log_exceptions
)
else: else:
run = _BlockingScriptRun( cls = _QueuedScriptRun
self._hass, self, variables, context, self._log_exceptions self._queue_len += 1
) run = cls(self._hass, self, variables, context, self._log_exceptions)
self._runs.append(run) self._runs.append(run)
await run.async_run()
async def async_stop(self) -> None: try:
if self.is_legacy:
await run.async_run()
else:
await asyncio.shield(run.async_run())
except asyncio.CancelledError:
await run.async_stop()
self._changed()
raise
async def async_stop(self, update_state: bool = True) -> None:
"""Stop running script.""" """Stop running script."""
if not self.is_running: if not self.is_running:
return return
await asyncio.shield(asyncio.gather(*(run.async_stop() for run in self._runs))) await asyncio.shield(asyncio.gather(*(run.async_stop() for run in self._runs)))
self._changed() if update_state:
self._changed()
def _log(self, msg, *args, level=logging.INFO): def _log(self, msg, *args, level=logging.INFO):
if self.name: if self.name:
@ -706,9 +822,3 @@ class Script:
self._logger.exception(msg, *args) self._logger.exception(msg, *args)
else: else:
self._logger.log(level, msg, *args) self._logger.log(level, msg, *args)
def _raise(self, msg, *args, exception=None):
if not exception:
exception = exceptions.HomeAssistantError
self._log(msg, *args, level=logging.ERROR)
raise exception(msg % args)

View file

@ -56,12 +56,27 @@ async def async_call_from_config(
hass, config, blocking=False, variables=None, validate_config=True, context=None hass, config, blocking=False, variables=None, validate_config=True, context=None
): ):
"""Call a service based on a config hash.""" """Call a service based on a config hash."""
try:
parms = async_prepare_call_from_config(hass, config, variables, validate_config)
except HomeAssistantError as ex:
if blocking:
raise
_LOGGER.error(ex)
else:
await hass.services.async_call(*parms, blocking, context)
@ha.callback
@bind_hass
def async_prepare_call_from_config(hass, config, variables=None, validate_config=False):
"""Prepare to call a service based on a config hash."""
if validate_config: if validate_config:
try: try:
config = cv.SERVICE_SCHEMA(config) config = cv.SERVICE_SCHEMA(config)
except vol.Invalid as ex: except vol.Invalid as ex:
_LOGGER.error("Invalid config for calling service: %s", ex) raise HomeAssistantError(
return f"Invalid config for calling service: {ex}"
) from ex
if CONF_SERVICE in config: if CONF_SERVICE in config:
domain_service = config[CONF_SERVICE] domain_service = config[CONF_SERVICE]
@ -71,17 +86,15 @@ async def async_call_from_config(
domain_service = config[CONF_SERVICE_TEMPLATE].async_render(variables) domain_service = config[CONF_SERVICE_TEMPLATE].async_render(variables)
domain_service = cv.service(domain_service) domain_service = cv.service(domain_service)
except TemplateError as ex: except TemplateError as ex:
if blocking: raise HomeAssistantError(
raise f"Error rendering service name template: {ex}"
_LOGGER.error("Error rendering service name template: %s", ex) ) from ex
return except vol.Invalid as ex:
except vol.Invalid: raise HomeAssistantError(
if blocking: f"Template rendered invalid service: {domain_service}"
raise ) from ex
_LOGGER.error("Template rendered invalid service: %s", domain_service)
return
domain, service_name = domain_service.split(".", 1) domain, service = domain_service.split(".", 1)
service_data = dict(config.get(CONF_SERVICE_DATA, {})) service_data = dict(config.get(CONF_SERVICE_DATA, {}))
if CONF_SERVICE_DATA_TEMPLATE in config: if CONF_SERVICE_DATA_TEMPLATE in config:
@ -91,15 +104,12 @@ async def async_call_from_config(
template.render_complex(config[CONF_SERVICE_DATA_TEMPLATE], variables) template.render_complex(config[CONF_SERVICE_DATA_TEMPLATE], variables)
) )
except TemplateError as ex: except TemplateError as ex:
_LOGGER.error("Error rendering data template: %s", ex) raise HomeAssistantError(f"Error rendering data template: {ex}") from ex
return
if CONF_SERVICE_ENTITY_ID in config: if CONF_SERVICE_ENTITY_ID in config:
service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID] service_data[ATTR_ENTITY_ID] = config[CONF_SERVICE_ENTITY_ID]
await hass.services.async_call( return domain, service, service_data
domain, service_name, service_data, blocking=blocking, context=context
)
@bind_hass @bind_hass

File diff suppressed because it is too large Load diff