From 65fbcfa0ba225bd8af292b337484bebece758e88 Mon Sep 17 00:00:00 2001
From: Erik Montnemery <erik@montnemery.com>
Date: Thu, 10 Mar 2022 19:28:00 +0100
Subject: [PATCH] Prevent recursive script calls from deadlocking (#67861)

* Prevent recursive script calls from deadlocking

* Address review comments, improve tests

* Tweak comment
---
 homeassistant/helpers/script.py      |  23 +++++
 tests/components/script/test_init.py | 125 +++++++++++++++++++++++++++
 2 files changed, 148 insertions(+)

diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py
index 1eabc33b89d..07a89c8cddb 100644
--- a/homeassistant/helpers/script.py
+++ b/homeassistant/helpers/script.py
@@ -4,6 +4,7 @@ from __future__ import annotations
 import asyncio
 from collections.abc import Callable, Sequence
 from contextlib import asynccontextmanager, suppress
+from contextvars import ContextVar
 from datetime import datetime, timedelta
 from functools import partial
 import itertools
@@ -126,6 +127,8 @@ SCRIPT_BREAKPOINT_HIT = "script_breakpoint_hit"
 SCRIPT_DEBUG_CONTINUE_STOP = "script_debug_continue_stop_{}_{}"
 SCRIPT_DEBUG_CONTINUE_ALL = "script_debug_continue_all"
 
+script_stack_cv: ContextVar[list[int] | None] = ContextVar("script_stack", default=None)
+
 
 def action_trace_append(variables, path):
     """Append a TraceElement to trace[path]."""
@@ -340,6 +343,12 @@ class _ScriptRun:
 
     async def async_run(self) -> None:
         """Run script."""
+        # Push the script to the script execution stack
+        if (script_stack := script_stack_cv.get()) is None:
+            script_stack = []
+            script_stack_cv.set(script_stack)
+        script_stack.append(id(self._script))
+
         try:
             self._log("Running %s", self._script.running_description)
             for self._step, self._action in enumerate(self._script.sequence):
@@ -355,6 +364,8 @@ class _ScriptRun:
             script_execution_set("error")
             raise
         finally:
+            # Pop the script from the script execution stack
+            script_stack.pop()
             self._finish()
 
     async def _async_step(self, log_exceptions):
@@ -1218,6 +1229,18 @@ class Script:
         else:
             variables = cast(dict, run_variables)
 
+        # Prevent non-allowed recursive calls which will cause deadlocks when we try to
+        # stop (restart) or wait for (queued) our own script run.
+        script_stack = script_stack_cv.get()
+        if (
+            self.script_mode in (SCRIPT_MODE_RESTART, SCRIPT_MODE_QUEUED)
+            and (script_stack := script_stack_cv.get()) is not None
+            and id(self) in script_stack
+        ):
+            script_execution_set("disallowed_recursion_detected")
+            _LOGGER.warning("Disallowed recursion detected")
+            return
+
         if self.script_mode != SCRIPT_MODE_QUEUED:
             cls = _ScriptRun
         else:
diff --git a/tests/components/script/test_init.py b/tests/components/script/test_init.py
index a6923c88aa2..35875c6da12 100644
--- a/tests/components/script/test_init.py
+++ b/tests/components/script/test_init.py
@@ -27,6 +27,13 @@ from homeassistant.core import (
 from homeassistant.exceptions import ServiceNotFound
 from homeassistant.helpers import template
 from homeassistant.helpers.event import async_track_state_change
+from homeassistant.helpers.script import (
+    SCRIPT_MODE_CHOICES,
+    SCRIPT_MODE_PARALLEL,
+    SCRIPT_MODE_QUEUED,
+    SCRIPT_MODE_RESTART,
+    SCRIPT_MODE_SINGLE,
+)
 from homeassistant.helpers.service import async_get_all_descriptions
 from homeassistant.setup import async_setup_component
 import homeassistant.util.dt as dt_util
@@ -790,3 +797,121 @@ async def test_script_restore_last_triggered(hass: HomeAssistant) -> None:
     state = hass.states.get("script.last_triggered")
     assert state
     assert state.attributes["last_triggered"] == time
+
+
+@pytest.mark.parametrize(
+    "script_mode,warning_msg",
+    (
+        (SCRIPT_MODE_PARALLEL, "Maximum number of runs exceeded"),
+        (SCRIPT_MODE_QUEUED, "Disallowed recursion detected"),
+        (SCRIPT_MODE_RESTART, "Disallowed recursion detected"),
+        (SCRIPT_MODE_SINGLE, "Already running"),
+    ),
+)
+async def test_recursive_script(hass, script_mode, warning_msg, caplog):
+    """Test recursive script calls does not deadlock."""
+    # Make sure we cover all script modes
+    assert SCRIPT_MODE_CHOICES == [
+        SCRIPT_MODE_PARALLEL,
+        SCRIPT_MODE_QUEUED,
+        SCRIPT_MODE_RESTART,
+        SCRIPT_MODE_SINGLE,
+    ]
+
+    assert await async_setup_component(
+        hass,
+        "script",
+        {
+            "script": {
+                "script1": {
+                    "mode": script_mode,
+                    "sequence": [
+                        {"service": "script.script1"},
+                        {"service": "test.script"},
+                    ],
+                },
+            }
+        },
+    )
+
+    service_called = asyncio.Event()
+
+    async def async_service_handler(service):
+        service_called.set()
+
+    hass.services.async_register("test", "script", async_service_handler)
+    hass.states.async_set("input_boolean.test", "on")
+    hass.states.async_set("input_boolean.test2", "off")
+
+    await hass.services.async_call("script", "script1")
+    await asyncio.wait_for(service_called.wait(), 1)
+
+    assert warning_msg in caplog.text
+
+
+@pytest.mark.parametrize(
+    "script_mode,warning_msg",
+    (
+        (SCRIPT_MODE_PARALLEL, "Maximum number of runs exceeded"),
+        (SCRIPT_MODE_QUEUED, "Disallowed recursion detected"),
+        (SCRIPT_MODE_RESTART, "Disallowed recursion detected"),
+        (SCRIPT_MODE_SINGLE, "Already running"),
+    ),
+)
+async def test_recursive_script_indirect(hass, script_mode, warning_msg, caplog):
+    """Test recursive script calls does not deadlock."""
+    # Make sure we cover all script modes
+    assert SCRIPT_MODE_CHOICES == [
+        SCRIPT_MODE_PARALLEL,
+        SCRIPT_MODE_QUEUED,
+        SCRIPT_MODE_RESTART,
+        SCRIPT_MODE_SINGLE,
+    ]
+
+    assert await async_setup_component(
+        hass,
+        "script",
+        {
+            "script": {
+                "script1": {
+                    "mode": script_mode,
+                    "sequence": [
+                        {"service": "script.script2"},
+                    ],
+                },
+                "script2": {
+                    "mode": script_mode,
+                    "sequence": [
+                        {"service": "script.script3"},
+                    ],
+                },
+                "script3": {
+                    "mode": script_mode,
+                    "sequence": [
+                        {"service": "script.script4"},
+                    ],
+                },
+                "script4": {
+                    "mode": script_mode,
+                    "sequence": [
+                        {"service": "script.script1"},
+                        {"service": "test.script"},
+                    ],
+                },
+            }
+        },
+    )
+
+    service_called = asyncio.Event()
+
+    async def async_service_handler(service):
+        service_called.set()
+
+    hass.services.async_register("test", "script", async_service_handler)
+    hass.states.async_set("input_boolean.test", "on")
+    hass.states.async_set("input_boolean.test2", "off")
+
+    await hass.services.async_call("script", "script1")
+    await asyncio.wait_for(service_called.wait(), 1)
+
+    assert warning_msg in caplog.text