From 65fbcfa0ba225bd8af292b337484bebece758e88 Mon Sep 17 00:00:00 2001 From: Erik Montnemery <erik@montnemery.com> Date: Thu, 10 Mar 2022 19:28:00 +0100 Subject: [PATCH] Prevent recursive script calls from deadlocking (#67861) * Prevent recursive script calls from deadlocking * Address review comments, improve tests * Tweak comment --- homeassistant/helpers/script.py | 23 +++++ tests/components/script/test_init.py | 125 +++++++++++++++++++++++++++ 2 files changed, 148 insertions(+) diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 1eabc33b89d..07a89c8cddb 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio from collections.abc import Callable, Sequence from contextlib import asynccontextmanager, suppress +from contextvars import ContextVar from datetime import datetime, timedelta from functools import partial import itertools @@ -126,6 +127,8 @@ SCRIPT_BREAKPOINT_HIT = "script_breakpoint_hit" SCRIPT_DEBUG_CONTINUE_STOP = "script_debug_continue_stop_{}_{}" SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all" +script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None) + def action_trace_append(variables, path): """Append a TraceElement to trace[path].""" @@ -340,6 +343,12 @@ class _ScriptRun: async def async_run(self) -> None: """Run script.""" + # Push the script to the script execution stack + if (script_stack := script_stack_cv.get()) is None: + script_stack = [] + script_stack_cv.set(script_stack) + script_stack.append(id(self._script)) + try: self._log("Running %s", self._script.running_description) for self._step, self._action in enumerate(self._script.sequence): @@ -355,6 +364,8 @@ class _ScriptRun: script_execution_set("error") raise finally: + # Pop the script from the script execution stack + script_stack.pop() self._finish() async def _async_step(self, log_exceptions): @@ -1218,6 +1229,18 @@ class Script: else: variables = cast(dict, run_variables) + # Prevent non-allowed recursive calls which will cause deadlocks when we try to + # stop (restart) or wait for (queued) our own script run. + script_stack = script_stack_cv.get() + if ( + self.script_mode in (SCRIPT_MODE_RESTART, SCRIPT_MODE_QUEUED) + and (script_stack := script_stack_cv.get()) is not None + and id(self) in script_stack + ): + script_execution_set("disallowed_recursion_detected") + _LOGGER.warning("Disallowed recursion detected") + return + if self.script_mode != SCRIPT_MODE_QUEUED: cls = _ScriptRun else: diff --git a/tests/components/script/test_init.py b/tests/components/script/test_init.py index a6923c88aa2..35875c6da12 100644 --- a/tests/components/script/test_init.py +++ b/tests/components/script/test_init.py @@ -27,6 +27,13 @@ from homeassistant.core import ( from homeassistant.exceptions import ServiceNotFound from homeassistant.helpers import template from homeassistant.helpers.event import async_track_state_change +from homeassistant.helpers.script import ( + SCRIPT_MODE_CHOICES, + SCRIPT_MODE_PARALLEL, + SCRIPT_MODE_QUEUED, + SCRIPT_MODE_RESTART, + SCRIPT_MODE_SINGLE, +) from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util @@ -790,3 +797,121 @@ async def test_script_restore_last_triggered(hass: HomeAssistant) -> None: state = hass.states.get("script.last_triggered") assert state assert state.attributes["last_triggered"] == time + + +@pytest.mark.parametrize( + "script_mode,warning_msg", + ( + (SCRIPT_MODE_PARALLEL, "Maximum number of runs exceeded"), + (SCRIPT_MODE_QUEUED, "Disallowed recursion detected"), + (SCRIPT_MODE_RESTART, "Disallowed recursion detected"), + (SCRIPT_MODE_SINGLE, "Already running"), + ), +) +async def test_recursive_script(hass, script_mode, warning_msg, caplog): + """Test recursive script calls does not deadlock.""" + # Make sure we cover all script modes + assert SCRIPT_MODE_CHOICES == [ + SCRIPT_MODE_PARALLEL, + SCRIPT_MODE_QUEUED, + SCRIPT_MODE_RESTART, + SCRIPT_MODE_SINGLE, + ] + + assert await async_setup_component( + hass, + "script", + { + "script": { + "script1": { + "mode": script_mode, + "sequence": [ + {"service": "script.script1"}, + {"service": "test.script"}, + ], + }, + } + }, + ) + + service_called = asyncio.Event() + + async def async_service_handler(service): + service_called.set() + + hass.services.async_register("test", "script", async_service_handler) + hass.states.async_set("input_boolean.test", "on") + hass.states.async_set("input_boolean.test2", "off") + + await hass.services.async_call("script", "script1") + await asyncio.wait_for(service_called.wait(), 1) + + assert warning_msg in caplog.text + + +@pytest.mark.parametrize( + "script_mode,warning_msg", + ( + (SCRIPT_MODE_PARALLEL, "Maximum number of runs exceeded"), + (SCRIPT_MODE_QUEUED, "Disallowed recursion detected"), + (SCRIPT_MODE_RESTART, "Disallowed recursion detected"), + (SCRIPT_MODE_SINGLE, "Already running"), + ), +) +async def test_recursive_script_indirect(hass, script_mode, warning_msg, caplog): + """Test recursive script calls does not deadlock.""" + # Make sure we cover all script modes + assert SCRIPT_MODE_CHOICES == [ + SCRIPT_MODE_PARALLEL, + SCRIPT_MODE_QUEUED, + SCRIPT_MODE_RESTART, + SCRIPT_MODE_SINGLE, + ] + + assert await async_setup_component( + hass, + "script", + { + "script": { + "script1": { + "mode": script_mode, + "sequence": [ + {"service": "script.script2"}, + ], + }, + "script2": { + "mode": script_mode, + "sequence": [ + {"service": "script.script3"}, + ], + }, + "script3": { + "mode": script_mode, + "sequence": [ + {"service": "script.script4"}, + ], + }, + "script4": { + "mode": script_mode, + "sequence": [ + {"service": "script.script1"}, + {"service": "test.script"}, + ], + }, + } + }, + ) + + service_called = asyncio.Event() + + async def async_service_handler(service): + service_called.set() + + hass.services.async_register("test", "script", async_service_handler) + hass.states.async_set("input_boolean.test", "on") + hass.states.async_set("input_boolean.test2", "off") + + await hass.services.async_call("script", "script1") + await asyncio.wait_for(service_called.wait(), 1) + + assert warning_msg in caplog.text