Allow scripts to turn themselves on (#71289)

This commit is contained in:
Erik Montnemery 2022-05-04 15:54:37 +02:00 committed by GitHub
parent fdee8800a0
commit 1df99badcf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 95 additions and 9 deletions

View file

@ -42,6 +42,7 @@ from homeassistant.helpers.script import (
CONF_MAX, CONF_MAX,
CONF_MAX_EXCEEDED, CONF_MAX_EXCEEDED,
Script, Script,
script_stack_cv,
) )
from homeassistant.helpers.service import async_set_service_schema from homeassistant.helpers.service import async_set_service_schema
from homeassistant.helpers.trace import trace_get, trace_path from homeassistant.helpers.trace import trace_get, trace_path
@ -398,10 +399,14 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
return return
# Caller does not want to wait for called script to finish so let script run in # Caller does not want to wait for called script to finish so let script run in
# separate Task. However, wait for first state change so we can guarantee that # separate Task. Make a new empty script stack; scripts are allowed to
# it is written to the State Machine before we return. # recursively turn themselves on when not waiting.
script_stack_cv.set([])
self._changed.clear() self._changed.clear()
self.hass.async_create_task(coro) self.hass.async_create_task(coro)
# Wait for first state change so we can guarantee that
# it is written to the State Machine before we return.
await self._changed.wait() await self._changed.wait()
async def _async_run(self, variables, context): async def _async_run(self, variables, context):

View file

@ -1791,18 +1791,12 @@ async def test_recursive_automation(hass: HomeAssistant, automation_mode, caplog
) )
service_called = asyncio.Event() service_called = asyncio.Event()
service_called_late = []
async def async_service_handler(service): async def async_service_handler(service):
if service.service == "automation_done": if service.service == "automation_done":
service_called.set() service_called.set()
if service.service == "automation_started_late":
service_called_late.append(service)
hass.services.async_register("test", "automation_done", async_service_handler) hass.services.async_register("test", "automation_done", async_service_handler)
hass.services.async_register(
"test", "automation_started_late", async_service_handler
)
hass.bus.async_fire("trigger_automation") hass.bus.async_fire("trigger_automation")
await asyncio.wait_for(service_called.wait(), 1) await asyncio.wait_for(service_called.wait(), 1)

View file

@ -1,6 +1,7 @@
"""The tests for the Script component.""" """The tests for the Script component."""
# pylint: disable=protected-access # pylint: disable=protected-access
import asyncio import asyncio
from datetime import timedelta
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
@ -33,12 +34,13 @@ from homeassistant.helpers.script import (
SCRIPT_MODE_QUEUED, SCRIPT_MODE_QUEUED,
SCRIPT_MODE_RESTART, SCRIPT_MODE_RESTART,
SCRIPT_MODE_SINGLE, SCRIPT_MODE_SINGLE,
_async_stop_scripts_at_shutdown,
) )
from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from tests.common import async_mock_service, mock_restore_cache from tests.common import async_fire_time_changed, async_mock_service, mock_restore_cache
from tests.components.logbook.test_init import MockLazyEventPartialState from tests.components.logbook.test_init import MockLazyEventPartialState
ENTITY_ID = "script.test" ENTITY_ID = "script.test"
@ -919,6 +921,91 @@ async def test_recursive_script_indirect(hass, script_mode, warning_msg, caplog)
assert warning_msg in caplog.text assert warning_msg in caplog.text
@pytest.mark.parametrize(
"script_mode", [SCRIPT_MODE_PARALLEL, SCRIPT_MODE_QUEUED, SCRIPT_MODE_RESTART]
)
async def test_recursive_script_turn_on(hass: HomeAssistant, script_mode, caplog):
"""Test script turning itself on.
- Illegal recursion detection should not be triggered
- Home Assistant should not hang on shut down
- SCRIPT_MODE_SINGLE is not relevant because suca script can't turn itself on
"""
# Make sure we cover all script modes
assert SCRIPT_MODE_CHOICES == [
SCRIPT_MODE_PARALLEL,
SCRIPT_MODE_QUEUED,
SCRIPT_MODE_RESTART,
SCRIPT_MODE_SINGLE,
]
stop_scripts_at_shutdown_called = asyncio.Event()
real_stop_scripts_at_shutdown = _async_stop_scripts_at_shutdown
async def stop_scripts_at_shutdown(*args):
await real_stop_scripts_at_shutdown(*args)
stop_scripts_at_shutdown_called.set()
with patch(
"homeassistant.helpers.script._async_stop_scripts_at_shutdown",
wraps=stop_scripts_at_shutdown,
):
assert await async_setup_component(
hass,
script.DOMAIN,
{
script.DOMAIN: {
"script1": {
"mode": script_mode,
"sequence": [
{
"choose": {
"conditions": {
"condition": "template",
"value_template": "{{ request == 'step_2' }}",
},
"sequence": {"service": "test.script_done"},
},
"default": {
"service": "script.turn_on",
"data": {
"entity_id": "script.script1",
"variables": {"request": "step_2"},
},
},
},
{
"service": "script.turn_on",
"data": {"entity_id": "script.script1"},
},
],
}
}
},
)
service_called = asyncio.Event()
async def async_service_handler(service):
if service.service == "script_done":
service_called.set()
hass.services.async_register("test", "script_done", async_service_handler)
await hass.services.async_call("script", "script1")
await asyncio.wait_for(service_called.wait(), 1)
# Trigger 1st stage script shutdown
hass.state = CoreState.stopping
hass.bus.async_fire("homeassistant_stop")
await asyncio.wait_for(stop_scripts_at_shutdown_called.wait(), 1)
# Trigger 2nd stage script shutdown
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=90))
await hass.async_block_till_done()
assert "Disallowed recursion detected" not in caplog.text
async def test_setup_with_duplicate_scripts( async def test_setup_with_duplicate_scripts(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None: