Allow stopping a script with a response value (#95284)

This commit is contained in:
Paulus Schoutsen 2023-06-27 02:24:22 -04:00 committed by GitHub
parent 51aa2ba835
commit 5f14cdf69d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 140 additions and 28 deletions

View file

@ -28,7 +28,14 @@ from homeassistant.const import (
SERVICE_TURN_ON, SERVICE_TURN_ON,
STATE_ON, STATE_ON,
) )
from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.core import (
Context,
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
callback,
)
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.config_validation import make_entity_service_schema from homeassistant.helpers.config_validation import make_entity_service_schema
@ -436,6 +443,12 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
variables = kwargs.get("variables") variables = kwargs.get("variables")
context = kwargs.get("context") context = kwargs.get("context")
wait = kwargs.get("wait", True) wait = kwargs.get("wait", True)
await self._async_start_run(variables, context, wait)
async def _async_start_run(
self, variables: dict, context: Context, wait: bool
) -> ServiceResponse:
"""Start the run of a script."""
self.async_set_context(context) self.async_set_context(context)
self.hass.bus.async_fire( self.hass.bus.async_fire(
EVENT_SCRIPT_STARTED, EVENT_SCRIPT_STARTED,
@ -444,8 +457,7 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
) )
coro = self._async_run(variables, context) coro = self._async_run(variables, context)
if wait: if wait:
await coro return await coro
return
# Caller does not want to wait for called script to finish so let script run in # Caller does not want to wait for called script to finish so let script run in
# separate Task. Make a new empty script stack; scripts are allowed to # separate Task. Make a new empty script stack; scripts are allowed to
@ -457,6 +469,7 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
# Wait for first state change so we can guarantee that # Wait for first state change so we can guarantee that
# it is written to the State Machine before we return. # it is written to the State Machine before we return.
await self._changed.wait() await self._changed.wait()
return None
async def _async_run(self, variables, context): async def _async_run(self, variables, context):
with trace_script( with trace_script(
@ -483,16 +496,25 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
""" """
await self.script.async_stop() await self.script.async_stop()
async def _service_handler(self, service: ServiceCall) -> None: async def _service_handler(self, service: ServiceCall) -> ServiceResponse:
"""Execute a service call to script.<script name>.""" """Execute a service call to script.<script name>."""
await self.async_turn_on(variables=service.data, context=service.context) response = await self._async_start_run(
variables=service.data, context=service.context, wait=True
)
if service.return_response:
return response
return None
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Restore last triggered on startup and register service.""" """Restore last triggered on startup and register service."""
unique_id = cast(str, self.unique_id) unique_id = cast(str, self.unique_id)
self.hass.services.async_register( self.hass.services.async_register(
DOMAIN, unique_id, self._service_handler, schema=SCRIPT_SERVICE_SCHEMA DOMAIN,
unique_id,
self._service_handler,
schema=SCRIPT_SERVICE_SCHEMA,
supports_response=SupportsResponse.OPTIONAL,
) )
# Register the service description # Register the service description

View file

@ -675,8 +675,14 @@ async def handle_execute_script(
context = connection.context(msg) context = connection.context(msg)
script_obj = Script(hass, msg["sequence"], f"{const.DOMAIN} script", const.DOMAIN) script_obj = Script(hass, msg["sequence"], f"{const.DOMAIN} script", const.DOMAIN)
await script_obj.async_run(msg.get("variables"), context=context) response = await script_obj.async_run(msg.get("variables"), context=context)
connection.send_result(msg["id"], {"context": context}) connection.send_result(
msg["id"],
{
"context": context,
"response": response,
},
)
@callback @callback

View file

@ -221,9 +221,10 @@ CONF_RECIPIENT: Final = "recipient"
CONF_REGION: Final = "region" CONF_REGION: Final = "region"
CONF_REPEAT: Final = "repeat" CONF_REPEAT: Final = "repeat"
CONF_RESOURCE: Final = "resource" CONF_RESOURCE: Final = "resource"
CONF_RESOURCES: Final = "resources"
CONF_RESOURCE_TEMPLATE: Final = "resource_template" CONF_RESOURCE_TEMPLATE: Final = "resource_template"
CONF_RESOURCES: Final = "resources"
CONF_RESPONSE_VARIABLE: Final = "response_variable" CONF_RESPONSE_VARIABLE: Final = "response_variable"
CONF_RESPONSE: Final = "response"
CONF_RGB: Final = "rgb" CONF_RGB: Final = "rgb"
CONF_ROOM: Final = "room" CONF_ROOM: Final = "room"
CONF_SCAN_INTERVAL: Final = "scan_interval" CONF_SCAN_INTERVAL: Final = "scan_interval"

View file

@ -59,6 +59,7 @@ from homeassistant.const import (
CONF_PARALLEL, CONF_PARALLEL,
CONF_PLATFORM, CONF_PLATFORM,
CONF_REPEAT, CONF_REPEAT,
CONF_RESPONSE,
CONF_RESPONSE_VARIABLE, CONF_RESPONSE_VARIABLE,
CONF_SCAN_INTERVAL, CONF_SCAN_INTERVAL,
CONF_SCENE, CONF_SCENE,
@ -1689,7 +1690,11 @@ _SCRIPT_STOP_SCHEMA = vol.Schema(
{ {
**SCRIPT_ACTION_BASE_SCHEMA, **SCRIPT_ACTION_BASE_SCHEMA,
vol.Required(CONF_STOP): vol.Any(None, string), vol.Required(CONF_STOP): vol.Any(None, string),
vol.Optional(CONF_ERROR, default=False): boolean, vol.Exclusive(CONF_ERROR, "error_or_response"): boolean,
vol.Exclusive(CONF_RESPONSE, "error_or_response"): vol.Any(
vol.All(dict, template_complex),
vol.All(str, template),
),
} }
) )

View file

@ -46,6 +46,7 @@ from homeassistant.const import (
CONF_MODE, CONF_MODE,
CONF_PARALLEL, CONF_PARALLEL,
CONF_REPEAT, CONF_REPEAT,
CONF_RESPONSE,
CONF_RESPONSE_VARIABLE, CONF_RESPONSE_VARIABLE,
CONF_SCENE, CONF_SCENE,
CONF_SEQUENCE, CONF_SEQUENCE,
@ -69,6 +70,7 @@ from homeassistant.core import (
Event, Event,
HassJob, HassJob,
HomeAssistant, HomeAssistant,
ServiceResponse,
SupportsResponse, SupportsResponse,
callback, callback,
) )
@ -352,6 +354,11 @@ class _ConditionFail(_HaltScript):
class _StopScript(_HaltScript): class _StopScript(_HaltScript):
"""Throw if script needs to stop.""" """Throw if script needs to stop."""
def __init__(self, message: str, response: Any) -> None:
"""Initialize a halt exception."""
super().__init__(message)
self.response = response
class _ScriptRun: class _ScriptRun:
"""Manage Script sequence run.""" """Manage Script sequence run."""
@ -396,13 +403,14 @@ class _ScriptRun:
) )
self._log("Executing step %s%s", self._script.last_action, _timeout) self._log("Executing step %s%s", self._script.last_action, _timeout)
async def async_run(self) -> None: async def async_run(self) -> ServiceResponse:
"""Run script.""" """Run script."""
# Push the script to the script execution stack # Push the script to the script execution stack
if (script_stack := script_stack_cv.get()) is None: if (script_stack := script_stack_cv.get()) is None:
script_stack = [] script_stack = []
script_stack_cv.set(script_stack) script_stack_cv.set(script_stack)
script_stack.append(id(self._script)) script_stack.append(id(self._script))
response = None
try: try:
self._log("Running %s", self._script.running_description) self._log("Running %s", self._script.running_description)
@ -420,11 +428,15 @@ class _ScriptRun:
raise raise
except _ConditionFail: except _ConditionFail:
script_execution_set("aborted") script_execution_set("aborted")
except _StopScript: except _StopScript as err:
script_execution_set("finished") script_execution_set("finished", err.response)
response = err.response
# Let the _StopScript bubble up if this is a sub-script # Let the _StopScript bubble up if this is a sub-script
if not self._script.top_level: if not self._script.top_level:
raise # We already consumed the response, do not pass it on
err.response = None
raise err
except Exception: except Exception:
script_execution_set("error") script_execution_set("error")
raise raise
@ -433,6 +445,8 @@ class _ScriptRun:
script_stack.pop() script_stack.pop()
self._finish() self._finish()
return response
async def _async_step(self, log_exceptions): async def _async_step(self, log_exceptions):
continue_on_error = self._action.get(CONF_CONTINUE_ON_ERROR, False) continue_on_error = self._action.get(CONF_CONTINUE_ON_ERROR, False)
@ -1010,13 +1024,20 @@ class _ScriptRun:
async def _async_stop_step(self): async def _async_stop_step(self):
"""Stop script execution.""" """Stop script execution."""
stop = self._action[CONF_STOP] stop = self._action[CONF_STOP]
error = self._action[CONF_ERROR] error = self._action.get(CONF_ERROR, False)
trace_set_result(stop=stop, error=error) trace_set_result(stop=stop, error=error)
if error: if error:
self._log("Error script sequence: %s", stop) self._log("Error script sequence: %s", stop)
raise _AbortScript(stop) raise _AbortScript(stop)
self._log("Stop script sequence: %s", stop) self._log("Stop script sequence: %s", stop)
raise _StopScript(stop) if CONF_RESPONSE in self._action:
response = template.render_complex(
self._action[CONF_RESPONSE], self._variables
)
else:
response = None
raise _StopScript(stop, response)
@async_trace_path("parallel") @async_trace_path("parallel")
async def _async_parallel_step(self) -> None: async def _async_parallel_step(self) -> None:
@ -1455,7 +1476,7 @@ class Script:
run_variables: _VarsType | None = None, run_variables: _VarsType | None = None,
context: Context | None = None, context: Context | None = None,
started_action: Callable[..., Any] | None = None, started_action: Callable[..., Any] | None = None,
) -> None: ) -> ServiceResponse:
"""Run script.""" """Run script."""
if context is None: if context is None:
self._log( self._log(
@ -1466,7 +1487,7 @@ class Script:
# Prevent spawning new script runs when Home Assistant is shutting down # Prevent spawning new script runs when Home Assistant is shutting down
if DATA_NEW_SCRIPT_RUNS_NOT_ALLOWED in self._hass.data: if DATA_NEW_SCRIPT_RUNS_NOT_ALLOWED in self._hass.data:
self._log("Home Assistant is shutting down, starting script blocked") self._log("Home Assistant is shutting down, starting script blocked")
return return None
# Prevent spawning new script runs if not allowed by script mode # Prevent spawning new script runs if not allowed by script mode
if self.is_running: if self.is_running:
@ -1474,7 +1495,7 @@ class Script:
if self._max_exceeded != "SILENT": if self._max_exceeded != "SILENT":
self._log("Already running", level=LOGSEVERITY[self._max_exceeded]) self._log("Already running", level=LOGSEVERITY[self._max_exceeded])
script_execution_set("failed_single") script_execution_set("failed_single")
return return None
if self.script_mode != SCRIPT_MODE_RESTART and self.runs == self.max_runs: if self.script_mode != SCRIPT_MODE_RESTART and self.runs == self.max_runs:
if self._max_exceeded != "SILENT": if self._max_exceeded != "SILENT":
self._log( self._log(
@ -1482,7 +1503,7 @@ class Script:
level=LOGSEVERITY[self._max_exceeded], level=LOGSEVERITY[self._max_exceeded],
) )
script_execution_set("failed_max_runs") script_execution_set("failed_max_runs")
return return None
# If this is a top level Script then make a copy of the variables in case they # If this is a top level Script then make a copy of the variables in case they
# are read-only, but more importantly, so as not to leak any variables created # are read-only, but more importantly, so as not to leak any variables created
@ -1519,7 +1540,7 @@ class Script:
): ):
script_execution_set("disallowed_recursion_detected") script_execution_set("disallowed_recursion_detected")
self._log("Disallowed recursion detected", level=logging.WARNING) self._log("Disallowed recursion detected", level=logging.WARNING)
return return None
if self.script_mode != SCRIPT_MODE_QUEUED: if self.script_mode != SCRIPT_MODE_QUEUED:
cls = _ScriptRun cls = _ScriptRun
@ -1543,7 +1564,7 @@ class Script:
self._changed() self._changed()
try: try:
await asyncio.shield(run.async_run()) return await asyncio.shield(run.async_run())
except asyncio.CancelledError: except asyncio.CancelledError:
await run.async_stop() await run.async_stop()
self._changed() self._changed()

View file

@ -441,7 +441,7 @@ class TemplateEntity(Entity):
"""Run an action script.""" """Run an action script."""
if run_variables is None: if run_variables is None:
run_variables = {} run_variables = {}
return await script.async_run( await script.async_run(
run_variables={ run_variables={
"this": TemplateStateFromEntityId(self.hass, self.entity_id), "this": TemplateStateFromEntityId(self.hass, self.entity_id),
**run_variables, **run_variables,

View file

@ -8,6 +8,7 @@ from contextvars import ContextVar
from functools import wraps from functools import wraps
from typing import Any, cast from typing import Any, cast
from homeassistant.core import ServiceResponse
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .typing import TemplateVarsType from .typing import TemplateVarsType
@ -207,13 +208,15 @@ class StopReason:
"""Mutable container class for script_execution.""" """Mutable container class for script_execution."""
script_execution: str | None = None script_execution: str | None = None
response: ServiceResponse = None
def script_execution_set(reason: str) -> None: def script_execution_set(reason: str, response: ServiceResponse = None) -> None:
"""Set stop reason.""" """Set stop reason."""
if (data := script_execution_cv.get()) is None: if (data := script_execution_cv.get()) is None:
return return
data.script_execution = reason data.script_execution = reason
data.response = response
def script_execution_get() -> str | None: def script_execution_get() -> str | None:

View file

@ -48,7 +48,9 @@ from homeassistant.core import (
Event, Event,
HomeAssistant, HomeAssistant,
ServiceCall, ServiceCall,
ServiceResponse,
State, State,
SupportsResponse,
callback, callback,
) )
from homeassistant.helpers import ( from homeassistant.helpers import (
@ -285,7 +287,12 @@ async def async_test_home_assistant(event_loop, load_registries=True):
def async_mock_service( def async_mock_service(
hass: HomeAssistant, domain: str, service: str, schema: vol.Schema | None = None hass: HomeAssistant,
domain: str,
service: str,
schema: vol.Schema | None = None,
response: ServiceResponse = None,
supports_response: SupportsResponse | None = None,
) -> list[ServiceCall]: ) -> list[ServiceCall]:
"""Set up a fake service & return a calls log list to this service.""" """Set up a fake service & return a calls log list to this service."""
calls = [] calls = []
@ -294,8 +301,18 @@ def async_mock_service(
def mock_service_log(call): # pylint: disable=unnecessary-lambda def mock_service_log(call): # pylint: disable=unnecessary-lambda
"""Mock service call.""" """Mock service call."""
calls.append(call) calls.append(call)
return response
hass.services.async_register(domain, service, mock_service_log, schema=schema) if supports_response is None and response is not None:
supports_response = SupportsResponse.OPTIONAL
hass.services.async_register(
domain,
service,
mock_service_log,
schema=schema,
supports_response=supports_response,
)
return calls return calls

View file

@ -1,6 +1,7 @@
"""The tests for the Script component.""" """The tests for the Script component."""
import asyncio import asyncio
from datetime import timedelta from datetime import timedelta
from typing import Any
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
@ -1502,3 +1503,34 @@ async def test_blueprint_script_fails_substitution(
"{'service_to_call': 'test.automation'}: No substitution found for input blah" "{'service_to_call': 'test.automation'}: No substitution found for input blah"
in caplog.text in caplog.text
) )
@pytest.mark.parametrize("response", ({"value": 5}, '{"value": 5}'))
async def test_responses(hass: HomeAssistant, response: Any) -> None:
"""Test we can get responses."""
mock_restore_cache(hass, ())
assert await async_setup_component(
hass,
"script",
{
"script": {
"test": {
"sequence": {
"stop": "done",
"response": response,
}
}
}
},
)
assert await hass.services.async_call(
DOMAIN, "test", {"greeting": "world"}, blocking=True, return_response=True
) == {"value": 5}
# Validate we can also call it without return_response
assert (
await hass.services.async_call(
DOMAIN, "test", {"greeting": "world"}, blocking=True, return_response=False
)
is None
)

View file

@ -1672,7 +1672,9 @@ async def test_test_condition(hass: HomeAssistant, websocket_client) -> None:
async def test_execute_script(hass: HomeAssistant, websocket_client) -> None: async def test_execute_script(hass: HomeAssistant, websocket_client) -> None:
"""Test testing a condition.""" """Test testing a condition."""
calls = async_mock_service(hass, "domain_test", "test_service") calls = async_mock_service(
hass, "domain_test", "test_service", response={"hello": "world"}
)
await websocket_client.send_json( await websocket_client.send_json(
{ {
@ -1682,7 +1684,9 @@ async def test_execute_script(hass: HomeAssistant, websocket_client) -> None:
{ {
"service": "domain_test.test_service", "service": "domain_test.test_service",
"data": {"hello": "world"}, "data": {"hello": "world"},
} "response_variable": "service_result",
},
{"stop": "done", "response": "{{ service_result }}"},
], ],
} }
) )
@ -1691,6 +1695,7 @@ async def test_execute_script(hass: HomeAssistant, websocket_client) -> None:
assert msg_no_var["id"] == 5 assert msg_no_var["id"] == 5
assert msg_no_var["type"] == const.TYPE_RESULT assert msg_no_var["type"] == const.TYPE_RESULT
assert msg_no_var["success"] assert msg_no_var["success"]
assert msg_no_var["result"]["response"] == {"hello": "world"}
await websocket_client.send_json( await websocket_client.send_json(
{ {