Allow scripts to turn themselves on (#71289)
This commit is contained in:
parent
fdee8800a0
commit
1df99badcf
3 changed files with 95 additions and 9 deletions
|
@ -42,6 +42,7 @@ from homeassistant.helpers.script import (
|
||||||
CONF_MAX,
|
CONF_MAX,
|
||||||
CONF_MAX_EXCEEDED,
|
CONF_MAX_EXCEEDED,
|
||||||
Script,
|
Script,
|
||||||
|
script_stack_cv,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.service import async_set_service_schema
|
from homeassistant.helpers.service import async_set_service_schema
|
||||||
from homeassistant.helpers.trace import trace_get, trace_path
|
from homeassistant.helpers.trace import trace_get, trace_path
|
||||||
|
@ -398,10 +399,14 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
|
||||||
return
|
return
|
||||||
|
|
||||||
# Caller does not want to wait for called script to finish so let script run in
|
# Caller does not want to wait for called script to finish so let script run in
|
||||||
# separate Task. However, wait for first state change so we can guarantee that
|
# separate Task. Make a new empty script stack; scripts are allowed to
|
||||||
# it is written to the State Machine before we return.
|
# recursively turn themselves on when not waiting.
|
||||||
|
script_stack_cv.set([])
|
||||||
|
|
||||||
self._changed.clear()
|
self._changed.clear()
|
||||||
self.hass.async_create_task(coro)
|
self.hass.async_create_task(coro)
|
||||||
|
# Wait for first state change so we can guarantee that
|
||||||
|
# it is written to the State Machine before we return.
|
||||||
await self._changed.wait()
|
await self._changed.wait()
|
||||||
|
|
||||||
async def _async_run(self, variables, context):
|
async def _async_run(self, variables, context):
|
||||||
|
|
|
@ -1791,18 +1791,12 @@ async def test_recursive_automation(hass: HomeAssistant, automation_mode, caplog
|
||||||
)
|
)
|
||||||
|
|
||||||
service_called = asyncio.Event()
|
service_called = asyncio.Event()
|
||||||
service_called_late = []
|
|
||||||
|
|
||||||
async def async_service_handler(service):
|
async def async_service_handler(service):
|
||||||
if service.service == "automation_done":
|
if service.service == "automation_done":
|
||||||
service_called.set()
|
service_called.set()
|
||||||
if service.service == "automation_started_late":
|
|
||||||
service_called_late.append(service)
|
|
||||||
|
|
||||||
hass.services.async_register("test", "automation_done", async_service_handler)
|
hass.services.async_register("test", "automation_done", async_service_handler)
|
||||||
hass.services.async_register(
|
|
||||||
"test", "automation_started_late", async_service_handler
|
|
||||||
)
|
|
||||||
|
|
||||||
hass.bus.async_fire("trigger_automation")
|
hass.bus.async_fire("trigger_automation")
|
||||||
await asyncio.wait_for(service_called.wait(), 1)
|
await asyncio.wait_for(service_called.wait(), 1)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""The tests for the Script component."""
|
"""The tests for the Script component."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from datetime import timedelta
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -33,12 +34,13 @@ from homeassistant.helpers.script import (
|
||||||
SCRIPT_MODE_QUEUED,
|
SCRIPT_MODE_QUEUED,
|
||||||
SCRIPT_MODE_RESTART,
|
SCRIPT_MODE_RESTART,
|
||||||
SCRIPT_MODE_SINGLE,
|
SCRIPT_MODE_SINGLE,
|
||||||
|
_async_stop_scripts_at_shutdown,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.service import async_get_all_descriptions
|
from homeassistant.helpers.service import async_get_all_descriptions
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
|
|
||||||
from tests.common import async_mock_service, mock_restore_cache
|
from tests.common import async_fire_time_changed, async_mock_service, mock_restore_cache
|
||||||
from tests.components.logbook.test_init import MockLazyEventPartialState
|
from tests.components.logbook.test_init import MockLazyEventPartialState
|
||||||
|
|
||||||
ENTITY_ID = "script.test"
|
ENTITY_ID = "script.test"
|
||||||
|
@ -919,6 +921,91 @@ async def test_recursive_script_indirect(hass, script_mode, warning_msg, caplog)
|
||||||
assert warning_msg in caplog.text
|
assert warning_msg in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"script_mode", [SCRIPT_MODE_PARALLEL, SCRIPT_MODE_QUEUED, SCRIPT_MODE_RESTART]
|
||||||
|
)
|
||||||
|
async def test_recursive_script_turn_on(hass: HomeAssistant, script_mode, caplog):
|
||||||
|
"""Test script turning itself on.
|
||||||
|
|
||||||
|
- Illegal recursion detection should not be triggered
|
||||||
|
- Home Assistant should not hang on shut down
|
||||||
|
- SCRIPT_MODE_SINGLE is not relevant because suca script can't turn itself on
|
||||||
|
"""
|
||||||
|
# Make sure we cover all script modes
|
||||||
|
assert SCRIPT_MODE_CHOICES == [
|
||||||
|
SCRIPT_MODE_PARALLEL,
|
||||||
|
SCRIPT_MODE_QUEUED,
|
||||||
|
SCRIPT_MODE_RESTART,
|
||||||
|
SCRIPT_MODE_SINGLE,
|
||||||
|
]
|
||||||
|
stop_scripts_at_shutdown_called = asyncio.Event()
|
||||||
|
real_stop_scripts_at_shutdown = _async_stop_scripts_at_shutdown
|
||||||
|
|
||||||
|
async def stop_scripts_at_shutdown(*args):
|
||||||
|
await real_stop_scripts_at_shutdown(*args)
|
||||||
|
stop_scripts_at_shutdown_called.set()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.helpers.script._async_stop_scripts_at_shutdown",
|
||||||
|
wraps=stop_scripts_at_shutdown,
|
||||||
|
):
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
script.DOMAIN,
|
||||||
|
{
|
||||||
|
script.DOMAIN: {
|
||||||
|
"script1": {
|
||||||
|
"mode": script_mode,
|
||||||
|
"sequence": [
|
||||||
|
{
|
||||||
|
"choose": {
|
||||||
|
"conditions": {
|
||||||
|
"condition": "template",
|
||||||
|
"value_template": "{{ request == 'step_2' }}",
|
||||||
|
},
|
||||||
|
"sequence": {"service": "test.script_done"},
|
||||||
|
},
|
||||||
|
"default": {
|
||||||
|
"service": "script.turn_on",
|
||||||
|
"data": {
|
||||||
|
"entity_id": "script.script1",
|
||||||
|
"variables": {"request": "step_2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"service": "script.turn_on",
|
||||||
|
"data": {"entity_id": "script.script1"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
service_called = asyncio.Event()
|
||||||
|
|
||||||
|
async def async_service_handler(service):
|
||||||
|
if service.service == "script_done":
|
||||||
|
service_called.set()
|
||||||
|
|
||||||
|
hass.services.async_register("test", "script_done", async_service_handler)
|
||||||
|
|
||||||
|
await hass.services.async_call("script", "script1")
|
||||||
|
await asyncio.wait_for(service_called.wait(), 1)
|
||||||
|
|
||||||
|
# Trigger 1st stage script shutdown
|
||||||
|
hass.state = CoreState.stopping
|
||||||
|
hass.bus.async_fire("homeassistant_stop")
|
||||||
|
await asyncio.wait_for(stop_scripts_at_shutdown_called.wait(), 1)
|
||||||
|
|
||||||
|
# Trigger 2nd stage script shutdown
|
||||||
|
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=90))
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert "Disallowed recursion detected" not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_with_duplicate_scripts(
|
async def test_setup_with_duplicate_scripts(
|
||||||
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue