Start script runs eagerly (#113190)

This commit is contained in:
J. Nick Koston 2024-03-14 16:53:26 -10:00 committed by GitHub
parent 92e73312ea
commit bdede0e0da
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 33 additions and 8 deletions

View file

@ -77,6 +77,7 @@ from homeassistant.core import (
callback,
)
from homeassistant.util import slugify
from homeassistant.util.async_ import create_eager_task
from homeassistant.util.dt import utcnow
from . import condition, config_validation as cv, service, template
@ -1611,7 +1612,7 @@ class Script:
self._changed()
try:
return await asyncio.shield(run.async_run())
return await asyncio.shield(create_eager_task(run.async_run()))
except asyncio.CancelledError:
await run.async_stop()
self._changed()

View file

@ -44,6 +44,7 @@ async def async_turn_on(
}
await hass.services.async_call(DOMAIN, SERVICE_TURN_ON, data, blocking=True)
await hass.async_block_till_done()
async def async_turn_off(hass, entity_id=ENTITY_MATCH_ALL) -> None:
@ -51,6 +52,7 @@ async def async_turn_off(hass, entity_id=ENTITY_MATCH_ALL) -> None:
data = {ATTR_ENTITY_ID: entity_id} if entity_id else {}
await hass.services.async_call(DOMAIN, SERVICE_TURN_OFF, data, blocking=True)
await hass.async_block_till_done()
async def async_oscillate(
@ -67,6 +69,7 @@ async def async_oscillate(
}
await hass.services.async_call(DOMAIN, SERVICE_OSCILLATE, data, blocking=True)
await hass.async_block_till_done()
async def async_set_preset_mode(
@ -80,6 +83,7 @@ async def async_set_preset_mode(
}
await hass.services.async_call(DOMAIN, SERVICE_SET_PRESET_MODE, data, blocking=True)
await hass.async_block_till_done()
async def async_set_percentage(
@ -93,6 +97,7 @@ async def async_set_percentage(
}
await hass.services.async_call(DOMAIN, SERVICE_SET_PERCENTAGE, data, blocking=True)
await hass.async_block_till_done()
async def async_increase_speed(
@ -109,6 +114,7 @@ async def async_increase_speed(
}
await hass.services.async_call(DOMAIN, SERVICE_INCREASE_SPEED, data, blocking=True)
await hass.async_block_till_done()
async def async_decrease_speed(
@ -125,6 +131,7 @@ async def async_decrease_speed(
}
await hass.services.async_call(DOMAIN, SERVICE_DECREASE_SPEED, data, blocking=True)
await hass.async_block_till_done()
async def async_set_direction(
@ -138,3 +145,4 @@ async def async_set_direction(
}
await hass.services.async_call(DOMAIN, SERVICE_SET_DIRECTION, data, blocking=True)
await hass.async_block_till_done()

View file

@ -1207,7 +1207,10 @@ async def test_if_not_fires_on_entities_change_with_for_after_stop(
"below": below,
"for": {"seconds": 5},
},
"action": {"service": "test.automation"},
"action": [
{"delay": "0.0001"},
{"service": "test.automation"},
],
}
},
)
@ -1833,7 +1836,10 @@ async def test_attribute_if_not_fires_on_entities_change_with_for_after_stop(
"attribute": "test-measurement",
"for": 5,
},
"action": {"service": "test.automation"},
"action": [
{"delay": "0.0001"},
{"service": "test.automation"},
],
}
},
)

View file

@ -666,7 +666,10 @@ async def test_if_not_fires_on_entities_change_with_for_after_stop(
"to": "world",
"for": {"seconds": 5},
},
"action": {"service": "test.automation"},
"action": [
{"delay": "0.0001"},
{"service": "test.automation"},
],
}
},
)
@ -1624,7 +1627,10 @@ async def test_attribute_if_not_fires_on_entities_change_with_for_after_stop(
"attribute": "name",
"for": 5,
},
"action": {"service": "test.automation"},
"action": [
{"delay": "0.0001"},
{"service": "test.automation"},
],
}
},
)

View file

@ -428,6 +428,7 @@ async def test_set_invalid_direction_from_initial_stage(
await common.async_turn_on(hass, _TEST_FAN)
await common.async_set_direction(hass, _TEST_FAN, "invalid")
assert hass.states.get(_DIRECTION_INPUT_SELECT).state == ""
_verify(hass, STATE_ON, 0, None, None, None)

View file

@ -3441,7 +3441,8 @@ async def test_parallel_loop(
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")
hass.async_create_task(
script_obj.async_run(MappingProxyType({"what": "world"}), Context())
script_obj.async_run(MappingProxyType({"what": "world"}), Context()),
eager_start=True,
)
await hass.async_block_till_done()
@ -3456,7 +3457,6 @@ async def test_parallel_loop(
expected_trace = {
"0": [{"variables": {"what": "world"}}],
"0/parallel/0/sequence/0": [{}],
"0/parallel/1/sequence/0": [{}],
"0/parallel/0/sequence/0/repeat/sequence/0": [
{
"variables": {
@ -3492,6 +3492,7 @@ async def test_parallel_loop(
"result": {"event": "loop1", "event_data": {"hello1": "loop1_c"}},
},
],
"0/parallel/1/sequence/0": [{}],
"0/parallel/1/sequence/0/repeat/sequence/0": [
{
"variables": {
@ -4118,7 +4119,9 @@ async def test_max_exceeded(
)
hass.states.async_set("switch.test", "on")
for _ in range(max_runs + 1):
hass.async_create_task(script_obj.async_run(context=Context()))
hass.async_create_task(
script_obj.async_run(context=Context()), eager_start=True
)
hass.states.async_set("switch.test", "off")
await hass.async_block_till_done()
if max_exceeded is None: