diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index d925bf215ab..d739fbfef98 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -1692,7 +1692,7 @@ class Script: script_stack = script_stack_cv.get() if ( self.script_mode in (SCRIPT_MODE_RESTART, SCRIPT_MODE_QUEUED) - and (script_stack := script_stack_cv.get()) is not None + and script_stack is not None and id(self) in script_stack ): script_execution_set("disallowed_recursion_detected") @@ -1706,15 +1706,19 @@ class Script: run = cls( self._hass, self, cast(dict, variables), context, self._log_exceptions ) + has_existing_runs = bool(self._runs) self._runs.append(run) - if self.script_mode == SCRIPT_MODE_RESTART: + if self.script_mode == SCRIPT_MODE_RESTART and has_existing_runs: # When script mode is SCRIPT_MODE_RESTART, first add the new run and then # stop any other runs. If we stop other runs first, self.is_running will # return false after the other script runs were stopped until our task - # resumes running. + # resumes running. Its important that we check if there are existing + # runs before sleeping as otherwise if two runs are started at the exact + # same time they will cancel each other out. self._log("Restarting") # Important: yield to the event loop to allow the script to start in case - # the script is restarting itself. + # the script is restarting itself so it ends up in the script stack and + # the recursion check above will prevent the script from running. await asyncio.sleep(0) await self.async_stop(update_state=False, spare=run) @@ -1730,9 +1734,7 @@ class Script: self._changed() raise - async def _async_stop( - self, aws: list[asyncio.Task], update_state: bool, spare: _ScriptRun | None - ) -> None: + async def _async_stop(self, aws: list[asyncio.Task], update_state: bool) -> None: await asyncio.wait(aws) if update_state: self._changed() @@ -1749,9 +1751,7 @@ class Script: ] if not aws: return - await asyncio.shield( - create_eager_task(self._async_stop(aws, update_state, spare)) - ) + await asyncio.shield(create_eager_task(self._async_stop(aws, update_state))) async def _async_get_condition(self, config): if isinstance(config, template.Template): diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index 61e6d0e4660..edf0eff878b 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -8,7 +8,7 @@ from unittest.mock import Mock, patch import pytest -from homeassistant.components import automation +from homeassistant.components import automation, input_boolean, script from homeassistant.components.automation import ( ATTR_SOURCE, DOMAIN, @@ -41,6 +41,7 @@ from homeassistant.core import ( ) from homeassistant.exceptions import HomeAssistantError, Unauthorized from homeassistant.helpers import device_registry as dr +from homeassistant.helpers.event import async_track_state_change_event from homeassistant.helpers.script import ( SCRIPT_MODE_CHOICES, SCRIPT_MODE_PARALLEL, @@ -2980,3 +2981,82 @@ async def test_automation_turns_off_other_automation( async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=5)) await hass.async_block_till_done() assert len(calls) == 0 + + +async def test_two_automations_call_restart_script_same_time( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test two automations that call a restart mode script at the same.""" + hass.states.async_set("binary_sensor.presence", "off") + await hass.async_block_till_done() + events = [] + + @callback + def _save_event(event): + events.append(event) + + assert await async_setup_component( + hass, + input_boolean.DOMAIN, + { + input_boolean.DOMAIN: { + "test_1": None, + } + }, + ) + cancel = async_track_state_change_event(hass, "input_boolean.test_1", _save_event) + + assert await async_setup_component( + hass, + script.DOMAIN, + { + script.DOMAIN: { + "fire_toggle": { + "sequence": [ + { + "service": "input_boolean.toggle", + "target": {"entity_id": "input_boolean.test_1"}, + } + ] + }, + } + }, + ) + + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "trigger": { + "platform": "state", + "entity_id": "binary_sensor.presence", + "to": "on", + }, + "action": { + "service": "script.fire_toggle", + }, + "id": "automation_0", + "mode": "single", + }, + { + "trigger": { + "platform": "state", + "entity_id": "binary_sensor.presence", + "to": "on", + }, + "action": { + "service": "script.fire_toggle", + }, + "id": "automation_1", + "mode": "single", + }, + ] + }, + ) + + hass.states.async_set("binary_sensor.presence", "on") + await hass.async_block_till_done() + assert len(events) == 2 + cancel()