Allow stopping a script with a response value (#95284)
This commit is contained in:
parent
51aa2ba835
commit
5f14cdf69d
10 changed files with 140 additions and 28 deletions
|
@ -28,7 +28,14 @@ from homeassistant.const import (
|
|||
SERVICE_TURN_ON,
|
||||
STATE_ON,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, ServiceCall, callback
|
||||
from homeassistant.core import (
|
||||
Context,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
ServiceResponse,
|
||||
SupportsResponse,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.config_validation import make_entity_service_schema
|
||||
|
@ -436,6 +443,12 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
|
|||
variables = kwargs.get("variables")
|
||||
context = kwargs.get("context")
|
||||
wait = kwargs.get("wait", True)
|
||||
await self._async_start_run(variables, context, wait)
|
||||
|
||||
async def _async_start_run(
|
||||
self, variables: dict, context: Context, wait: bool
|
||||
) -> ServiceResponse:
|
||||
"""Start the run of a script."""
|
||||
self.async_set_context(context)
|
||||
self.hass.bus.async_fire(
|
||||
EVENT_SCRIPT_STARTED,
|
||||
|
@ -444,8 +457,7 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
|
|||
)
|
||||
coro = self._async_run(variables, context)
|
||||
if wait:
|
||||
await coro
|
||||
return
|
||||
return await coro
|
||||
|
||||
# Caller does not want to wait for called script to finish so let script run in
|
||||
# separate Task. Make a new empty script stack; scripts are allowed to
|
||||
|
@ -457,6 +469,7 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
|
|||
# Wait for first state change so we can guarantee that
|
||||
# it is written to the State Machine before we return.
|
||||
await self._changed.wait()
|
||||
return None
|
||||
|
||||
async def _async_run(self, variables, context):
|
||||
with trace_script(
|
||||
|
@ -483,16 +496,25 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
|
|||
"""
|
||||
await self.script.async_stop()
|
||||
|
||||
async def _service_handler(self, service: ServiceCall) -> None:
|
||||
async def _service_handler(self, service: ServiceCall) -> ServiceResponse:
|
||||
"""Execute a service call to script.<script name>."""
|
||||
await self.async_turn_on(variables=service.data, context=service.context)
|
||||
response = await self._async_start_run(
|
||||
variables=service.data, context=service.context, wait=True
|
||||
)
|
||||
if service.return_response:
|
||||
return response
|
||||
return None
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Restore last triggered on startup and register service."""
|
||||
|
||||
unique_id = cast(str, self.unique_id)
|
||||
self.hass.services.async_register(
|
||||
DOMAIN, unique_id, self._service_handler, schema=SCRIPT_SERVICE_SCHEMA
|
||||
DOMAIN,
|
||||
unique_id,
|
||||
self._service_handler,
|
||||
schema=SCRIPT_SERVICE_SCHEMA,
|
||||
supports_response=SupportsResponse.OPTIONAL,
|
||||
)
|
||||
|
||||
# Register the service description
|
||||
|
|
|
@ -675,8 +675,14 @@ async def handle_execute_script(
|
|||
|
||||
context = connection.context(msg)
|
||||
script_obj = Script(hass, msg["sequence"], f"{const.DOMAIN} script", const.DOMAIN)
|
||||
await script_obj.async_run(msg.get("variables"), context=context)
|
||||
connection.send_result(msg["id"], {"context": context})
|
||||
response = await script_obj.async_run(msg.get("variables"), context=context)
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
{
|
||||
"context": context,
|
||||
"response": response,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
|
|
|
@ -221,9 +221,10 @@ CONF_RECIPIENT: Final = "recipient"
|
|||
CONF_REGION: Final = "region"
|
||||
CONF_REPEAT: Final = "repeat"
|
||||
CONF_RESOURCE: Final = "resource"
|
||||
CONF_RESOURCES: Final = "resources"
|
||||
CONF_RESOURCE_TEMPLATE: Final = "resource_template"
|
||||
CONF_RESOURCES: Final = "resources"
|
||||
CONF_RESPONSE_VARIABLE: Final = "response_variable"
|
||||
CONF_RESPONSE: Final = "response"
|
||||
CONF_RGB: Final = "rgb"
|
||||
CONF_ROOM: Final = "room"
|
||||
CONF_SCAN_INTERVAL: Final = "scan_interval"
|
||||
|
|
|
@ -59,6 +59,7 @@ from homeassistant.const import (
|
|||
CONF_PARALLEL,
|
||||
CONF_PLATFORM,
|
||||
CONF_REPEAT,
|
||||
CONF_RESPONSE,
|
||||
CONF_RESPONSE_VARIABLE,
|
||||
CONF_SCAN_INTERVAL,
|
||||
CONF_SCENE,
|
||||
|
@ -1689,7 +1690,11 @@ _SCRIPT_STOP_SCHEMA = vol.Schema(
|
|||
{
|
||||
**SCRIPT_ACTION_BASE_SCHEMA,
|
||||
vol.Required(CONF_STOP): vol.Any(None, string),
|
||||
vol.Optional(CONF_ERROR, default=False): boolean,
|
||||
vol.Exclusive(CONF_ERROR, "error_or_response"): boolean,
|
||||
vol.Exclusive(CONF_RESPONSE, "error_or_response"): vol.Any(
|
||||
vol.All(dict, template_complex),
|
||||
vol.All(str, template),
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -46,6 +46,7 @@ from homeassistant.const import (
|
|||
CONF_MODE,
|
||||
CONF_PARALLEL,
|
||||
CONF_REPEAT,
|
||||
CONF_RESPONSE,
|
||||
CONF_RESPONSE_VARIABLE,
|
||||
CONF_SCENE,
|
||||
CONF_SEQUENCE,
|
||||
|
@ -69,6 +70,7 @@ from homeassistant.core import (
|
|||
Event,
|
||||
HassJob,
|
||||
HomeAssistant,
|
||||
ServiceResponse,
|
||||
SupportsResponse,
|
||||
callback,
|
||||
)
|
||||
|
@ -352,6 +354,11 @@ class _ConditionFail(_HaltScript):
|
|||
class _StopScript(_HaltScript):
|
||||
"""Throw if script needs to stop."""
|
||||
|
||||
def __init__(self, message: str, response: Any) -> None:
|
||||
"""Initialize a halt exception."""
|
||||
super().__init__(message)
|
||||
self.response = response
|
||||
|
||||
|
||||
class _ScriptRun:
|
||||
"""Manage Script sequence run."""
|
||||
|
@ -396,13 +403,14 @@ class _ScriptRun:
|
|||
)
|
||||
self._log("Executing step %s%s", self._script.last_action, _timeout)
|
||||
|
||||
async def async_run(self) -> None:
|
||||
async def async_run(self) -> ServiceResponse:
|
||||
"""Run script."""
|
||||
# Push the script to the script execution stack
|
||||
if (script_stack := script_stack_cv.get()) is None:
|
||||
script_stack = []
|
||||
script_stack_cv.set(script_stack)
|
||||
script_stack.append(id(self._script))
|
||||
response = None
|
||||
|
||||
try:
|
||||
self._log("Running %s", self._script.running_description)
|
||||
|
@ -420,11 +428,15 @@ class _ScriptRun:
|
|||
raise
|
||||
except _ConditionFail:
|
||||
script_execution_set("aborted")
|
||||
except _StopScript:
|
||||
script_execution_set("finished")
|
||||
except _StopScript as err:
|
||||
script_execution_set("finished", err.response)
|
||||
response = err.response
|
||||
|
||||
# Let the _StopScript bubble up if this is a sub-script
|
||||
if not self._script.top_level:
|
||||
raise
|
||||
# We already consumed the response, do not pass it on
|
||||
err.response = None
|
||||
raise err
|
||||
except Exception:
|
||||
script_execution_set("error")
|
||||
raise
|
||||
|
@ -433,6 +445,8 @@ class _ScriptRun:
|
|||
script_stack.pop()
|
||||
self._finish()
|
||||
|
||||
return response
|
||||
|
||||
async def _async_step(self, log_exceptions):
|
||||
continue_on_error = self._action.get(CONF_CONTINUE_ON_ERROR, False)
|
||||
|
||||
|
@ -1010,13 +1024,20 @@ class _ScriptRun:
|
|||
async def _async_stop_step(self):
|
||||
"""Stop script execution."""
|
||||
stop = self._action[CONF_STOP]
|
||||
error = self._action[CONF_ERROR]
|
||||
error = self._action.get(CONF_ERROR, False)
|
||||
trace_set_result(stop=stop, error=error)
|
||||
if error:
|
||||
self._log("Error script sequence: %s", stop)
|
||||
raise _AbortScript(stop)
|
||||
|
||||
self._log("Stop script sequence: %s", stop)
|
||||
raise _StopScript(stop)
|
||||
if CONF_RESPONSE in self._action:
|
||||
response = template.render_complex(
|
||||
self._action[CONF_RESPONSE], self._variables
|
||||
)
|
||||
else:
|
||||
response = None
|
||||
raise _StopScript(stop, response)
|
||||
|
||||
@async_trace_path("parallel")
|
||||
async def _async_parallel_step(self) -> None:
|
||||
|
@ -1455,7 +1476,7 @@ class Script:
|
|||
run_variables: _VarsType | None = None,
|
||||
context: Context | None = None,
|
||||
started_action: Callable[..., Any] | None = None,
|
||||
) -> None:
|
||||
) -> ServiceResponse:
|
||||
"""Run script."""
|
||||
if context is None:
|
||||
self._log(
|
||||
|
@ -1466,7 +1487,7 @@ class Script:
|
|||
# Prevent spawning new script runs when Home Assistant is shutting down
|
||||
if DATA_NEW_SCRIPT_RUNS_NOT_ALLOWED in self._hass.data:
|
||||
self._log("Home Assistant is shutting down, starting script blocked")
|
||||
return
|
||||
return None
|
||||
|
||||
# Prevent spawning new script runs if not allowed by script mode
|
||||
if self.is_running:
|
||||
|
@ -1474,7 +1495,7 @@ class Script:
|
|||
if self._max_exceeded != "SILENT":
|
||||
self._log("Already running", level=LOGSEVERITY[self._max_exceeded])
|
||||
script_execution_set("failed_single")
|
||||
return
|
||||
return None
|
||||
if self.script_mode != SCRIPT_MODE_RESTART and self.runs == self.max_runs:
|
||||
if self._max_exceeded != "SILENT":
|
||||
self._log(
|
||||
|
@ -1482,7 +1503,7 @@ class Script:
|
|||
level=LOGSEVERITY[self._max_exceeded],
|
||||
)
|
||||
script_execution_set("failed_max_runs")
|
||||
return
|
||||
return None
|
||||
|
||||
# If this is a top level Script then make a copy of the variables in case they
|
||||
# are read-only, but more importantly, so as not to leak any variables created
|
||||
|
@ -1519,7 +1540,7 @@ class Script:
|
|||
):
|
||||
script_execution_set("disallowed_recursion_detected")
|
||||
self._log("Disallowed recursion detected", level=logging.WARNING)
|
||||
return
|
||||
return None
|
||||
|
||||
if self.script_mode != SCRIPT_MODE_QUEUED:
|
||||
cls = _ScriptRun
|
||||
|
@ -1543,7 +1564,7 @@ class Script:
|
|||
self._changed()
|
||||
|
||||
try:
|
||||
await asyncio.shield(run.async_run())
|
||||
return await asyncio.shield(run.async_run())
|
||||
except asyncio.CancelledError:
|
||||
await run.async_stop()
|
||||
self._changed()
|
||||
|
|
|
@ -441,7 +441,7 @@ class TemplateEntity(Entity):
|
|||
"""Run an action script."""
|
||||
if run_variables is None:
|
||||
run_variables = {}
|
||||
return await script.async_run(
|
||||
await script.async_run(
|
||||
run_variables={
|
||||
"this": TemplateStateFromEntityId(self.hass, self.entity_id),
|
||||
**run_variables,
|
||||
|
|
|
@ -8,6 +8,7 @@ from contextvars import ContextVar
|
|||
from functools import wraps
|
||||
from typing import Any, cast
|
||||
|
||||
from homeassistant.core import ServiceResponse
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
from .typing import TemplateVarsType
|
||||
|
@ -207,13 +208,15 @@ class StopReason:
|
|||
"""Mutable container class for script_execution."""
|
||||
|
||||
script_execution: str | None = None
|
||||
response: ServiceResponse = None
|
||||
|
||||
|
||||
def script_execution_set(reason: str) -> None:
|
||||
def script_execution_set(reason: str, response: ServiceResponse = None) -> None:
|
||||
"""Set stop reason."""
|
||||
if (data := script_execution_cv.get()) is None:
|
||||
return
|
||||
data.script_execution = reason
|
||||
data.response = response
|
||||
|
||||
|
||||
def script_execution_get() -> str | None:
|
||||
|
|
|
@ -48,7 +48,9 @@ from homeassistant.core import (
|
|||
Event,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
ServiceResponse,
|
||||
State,
|
||||
SupportsResponse,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.helpers import (
|
||||
|
@ -285,7 +287,12 @@ async def async_test_home_assistant(event_loop, load_registries=True):
|
|||
|
||||
|
||||
def async_mock_service(
|
||||
hass: HomeAssistant, domain: str, service: str, schema: vol.Schema | None = None
|
||||
hass: HomeAssistant,
|
||||
domain: str,
|
||||
service: str,
|
||||
schema: vol.Schema | None = None,
|
||||
response: ServiceResponse = None,
|
||||
supports_response: SupportsResponse | None = None,
|
||||
) -> list[ServiceCall]:
|
||||
"""Set up a fake service & return a calls log list to this service."""
|
||||
calls = []
|
||||
|
@ -294,8 +301,18 @@ def async_mock_service(
|
|||
def mock_service_log(call): # pylint: disable=unnecessary-lambda
|
||||
"""Mock service call."""
|
||||
calls.append(call)
|
||||
return response
|
||||
|
||||
hass.services.async_register(domain, service, mock_service_log, schema=schema)
|
||||
if supports_response is None and response is not None:
|
||||
supports_response = SupportsResponse.OPTIONAL
|
||||
|
||||
hass.services.async_register(
|
||||
domain,
|
||||
service,
|
||||
mock_service_log,
|
||||
schema=schema,
|
||||
supports_response=supports_response,
|
||||
)
|
||||
|
||||
return calls
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""The tests for the Script component."""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
@ -1502,3 +1503,34 @@ async def test_blueprint_script_fails_substitution(
|
|||
"{'service_to_call': 'test.automation'}: No substitution found for input blah"
|
||||
in caplog.text
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("response", ({"value": 5}, '{"value": 5}'))
|
||||
async def test_responses(hass: HomeAssistant, response: Any) -> None:
|
||||
"""Test we can get responses."""
|
||||
mock_restore_cache(hass, ())
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"script",
|
||||
{
|
||||
"script": {
|
||||
"test": {
|
||||
"sequence": {
|
||||
"stop": "done",
|
||||
"response": response,
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert await hass.services.async_call(
|
||||
DOMAIN, "test", {"greeting": "world"}, blocking=True, return_response=True
|
||||
) == {"value": 5}
|
||||
# Validate we can also call it without return_response
|
||||
assert (
|
||||
await hass.services.async_call(
|
||||
DOMAIN, "test", {"greeting": "world"}, blocking=True, return_response=False
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
|
|
@ -1672,7 +1672,9 @@ async def test_test_condition(hass: HomeAssistant, websocket_client) -> None:
|
|||
|
||||
async def test_execute_script(hass: HomeAssistant, websocket_client) -> None:
|
||||
"""Test testing a condition."""
|
||||
calls = async_mock_service(hass, "domain_test", "test_service")
|
||||
calls = async_mock_service(
|
||||
hass, "domain_test", "test_service", response={"hello": "world"}
|
||||
)
|
||||
|
||||
await websocket_client.send_json(
|
||||
{
|
||||
|
@ -1682,7 +1684,9 @@ async def test_execute_script(hass: HomeAssistant, websocket_client) -> None:
|
|||
{
|
||||
"service": "domain_test.test_service",
|
||||
"data": {"hello": "world"},
|
||||
}
|
||||
"response_variable": "service_result",
|
||||
},
|
||||
{"stop": "done", "response": "{{ service_result }}"},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
@ -1691,6 +1695,7 @@ async def test_execute_script(hass: HomeAssistant, websocket_client) -> None:
|
|||
assert msg_no_var["id"] == 5
|
||||
assert msg_no_var["type"] == const.TYPE_RESULT
|
||||
assert msg_no_var["success"]
|
||||
assert msg_no_var["result"]["response"] == {"hello": "world"}
|
||||
|
||||
await websocket_client.send_json(
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue