Fix script in restart mode that is fired from the same trigger (#116247)

This commit is contained in:
J. Nick Koston 2024-04-27 07:08:29 -05:00 committed by GitHub
parent a37d274b37
commit 7715bee6b0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 91 additions and 11 deletions

View file

@ -1692,7 +1692,7 @@ class Script:
script_stack = script_stack_cv.get()
if (
self.script_mode in (SCRIPT_MODE_RESTART, SCRIPT_MODE_QUEUED)
and (script_stack := script_stack_cv.get()) is not None
and script_stack is not None
and id(self) in script_stack
):
script_execution_set("disallowed_recursion_detected")
@ -1706,15 +1706,19 @@ class Script:
run = cls(
self._hass, self, cast(dict, variables), context, self._log_exceptions
)
has_existing_runs = bool(self._runs)
self._runs.append(run)
if self.script_mode == SCRIPT_MODE_RESTART:
if self.script_mode == SCRIPT_MODE_RESTART and has_existing_runs:
# When script mode is SCRIPT_MODE_RESTART, first add the new run and then
# stop any other runs. If we stop other runs first, self.is_running will
# return false after the other script runs were stopped until our task
# resumes running.
# resumes running. Its important that we check if there are existing
# runs before sleeping as otherwise if two runs are started at the exact
# same time they will cancel each other out.
self._log("Restarting")
# Important: yield to the event loop to allow the script to start in case
# the script is restarting itself.
# the script is restarting itself so it ends up in the script stack and
# the recursion check above will prevent the script from running.
await asyncio.sleep(0)
await self.async_stop(update_state=False, spare=run)
@ -1730,9 +1734,7 @@ class Script:
self._changed()
raise
async def _async_stop(
self, aws: list[asyncio.Task], update_state: bool, spare: _ScriptRun | None
) -> None:
async def _async_stop(self, aws: list[asyncio.Task], update_state: bool) -> None:
await asyncio.wait(aws)
if update_state:
self._changed()
@ -1749,9 +1751,7 @@ class Script:
]
if not aws:
return
await asyncio.shield(
create_eager_task(self._async_stop(aws, update_state, spare))
)
await asyncio.shield(create_eager_task(self._async_stop(aws, update_state)))
async def _async_get_condition(self, config):
if isinstance(config, template.Template):