Don't reload other automations when saving an automation (#80254)

* Only reload modified automation

* Correct check for existing automation

* Add tests

* Remove the new service, improve ReloadServiceHelper

* Revert unneeded changes

* Update tests

* Address review comments

* Improve test coverage

* Address review comments

* Tweak reloader code + add a targetted test

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Explain the tests + add more variations

* Fix copy-paste mistake in test

* Rephrase explanation of expected test outcome

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Erik Montnemery 2024-04-16 15:58:57 +02:00 committed by GitHub
parent 679752ceb8
commit 7cd0fe3c5f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 484 additions and 26 deletions

View file

@ -331,17 +331,25 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
await async_get_blueprints(hass).async_reset_cache()
if (conf := await component.async_prepare_reload(skip_reset=True)) is None:
return
await _async_process_config(hass, conf, component)
if automation_id := service_call.data.get(CONF_ID):
await _async_process_single_config(hass, conf, component, automation_id)
else:
await _async_process_config(hass, conf, component)
hass.bus.async_fire(EVENT_AUTOMATION_RELOADED, context=service_call.context)
reload_helper = ReloadServiceHelper(reload_service_handler)
def reload_targets(service_call: ServiceCall) -> set[str | None]:
if automation_id := service_call.data.get(CONF_ID):
return {automation_id}
return {automation.unique_id for automation in component.entities}
reload_helper = ReloadServiceHelper(reload_service_handler, reload_targets)
async_register_admin_service(
hass,
DOMAIN,
SERVICE_RELOAD,
reload_helper.execute_service,
schema=vol.Schema({}),
schema=vol.Schema({vol.Optional(CONF_ID): str}),
)
websocket_api.async_register_command(hass, websocket_config)
@ -859,6 +867,7 @@ class AutomationEntityConfig:
async def _prepare_automation_config(
hass: HomeAssistant,
config: ConfigType,
wanted_automation_id: str | None,
) -> list[AutomationEntityConfig]:
"""Parse configuration and prepare automation entity configuration."""
automation_configs: list[AutomationEntityConfig] = []
@ -866,6 +875,10 @@ async def _prepare_automation_config(
conf: list[ConfigType] = config[DOMAIN]
for list_no, config_block in enumerate(conf):
automation_id: str | None = config_block.get(CONF_ID)
if wanted_automation_id is not None and automation_id != wanted_automation_id:
continue
raw_config = cast(AutomationConfig, config_block).raw_config
raw_blueprint_inputs = cast(AutomationConfig, config_block).raw_blueprint_inputs
validation_failed = cast(AutomationConfig, config_block).validation_failed
@ -1025,7 +1038,7 @@ async def _async_process_config(
return automation_matches, config_matches
automation_configs = await _prepare_automation_config(hass, config)
automation_configs = await _prepare_automation_config(hass, config, None)
automations: list[BaseAutomationEntity] = list(component.entities)
# Find automations and configurations which have matches
@ -1049,6 +1062,41 @@ async def _async_process_config(
await component.async_add_entities(entities)
def _automation_matches_config(
automation: BaseAutomationEntity | None, config: AutomationEntityConfig | None
) -> bool:
"""Return False if an automation's config has been changed."""
if not automation:
return False
if not config:
return False
name = _automation_name(config)
return automation.name == name and automation.raw_config == config.raw_config
async def _async_process_single_config(
hass: HomeAssistant,
config: dict[str, Any],
component: EntityComponent[BaseAutomationEntity],
automation_id: str,
) -> None:
"""Process config and add a single automation."""
automation_configs = await _prepare_automation_config(hass, config, automation_id)
automation = next(
(x for x in component.entities if x.unique_id == automation_id), None
)
automation_config = automation_configs[0] if automation_configs else None
if _automation_matches_config(automation, automation_config):
return
if automation:
await automation.async_remove()
entities = await _create_automation_entities(hass, automation_configs)
await component.async_add_entities(entities)
async def _async_process_if(
hass: HomeAssistant, name: str, config: dict[str, Any]
) -> IfAction | None:

View file

@ -26,7 +26,9 @@ def async_setup(hass: HomeAssistant) -> bool:
async def hook(action: str, config_key: str) -> None:
"""post_write_hook for Config View that reloads automations."""
if action != ACTION_DELETE:
await hass.services.async_call(DOMAIN, SERVICE_RELOAD)
await hass.services.async_call(
DOMAIN, SERVICE_RELOAD, {CONF_ID: config_key}
)
return
ent_reg = er.async_get(hass)

View file

@ -77,6 +77,8 @@ _LOGGER = logging.getLogger(__name__)
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
ALL_SERVICE_DESCRIPTIONS_CACHE = "all_service_descriptions_cache"
_T = TypeVar("_T")
@cache
def _base_components() -> dict[str, ModuleType]:
@ -1154,40 +1156,67 @@ def verify_domain_control(
class ReloadServiceHelper:
"""Helper for reload services to minimize unnecessary reloads."""
"""Helper for reload services.
def __init__(self, service_func: Callable[[ServiceCall], Awaitable]) -> None:
The helper has the following purposes:
- Make sure reloads do not happen in parallel
- Avoid redundant reloads of the same target
"""
def __init__(
self,
service_func: Callable[[ServiceCall], Awaitable],
reload_targets_func: Callable[[ServiceCall], set[_T]],
) -> None:
"""Initialize ReloadServiceHelper."""
self._service_func = service_func
self._service_running = False
self._service_condition = asyncio.Condition()
self._pending_reload_targets: set[_T] = set()
self._reload_targets_func = reload_targets_func
async def execute_service(self, service_call: ServiceCall) -> None:
"""Execute the service.
If a previous reload task if currently in progress, wait for it to finish first.
If a previous reload task is currently in progress, wait for it to finish first.
Once the previous reload task has finished, one of the waiting tasks will be
assigned to execute the reload, the others will wait for the reload to finish.
assigned to execute the reload of the targets it is assigned to reload. The
other tasks will wait if they should reload the same target, otherwise they
will wait for the next round.
"""
do_reload = False
reload_targets = None
async with self._service_condition:
if self._service_running:
# A previous reload task is already in progress, wait for it to finish
# A previous reload task is already in progress, wait for it to finish,
# because that task may be reloading a stale version of the resource.
await self._service_condition.wait()
async with self._service_condition:
if not self._service_running:
# This task will do the reload
self._service_running = True
do_reload = True
else:
# Another task will perform the reload, wait for it to finish
while True:
async with self._service_condition:
# Once we've passed this point, we assume the version of the resource is
# the one our task was assigned to reload, or a newer one. Regardless of
# which, our task is happy as long as the target is reloaded at least
# once.
if reload_targets is None:
reload_targets = self._reload_targets_func(service_call)
self._pending_reload_targets |= reload_targets
if not self._service_running:
# This task will do a reload
self._service_running = True
do_reload = True
break
# Another task will perform a reload, wait for it to finish
await self._service_condition.wait()
# Check if the reload this task is waiting for has been completed
if reload_targets.isdisjoint(self._pending_reload_targets):
break
if do_reload:
# Reload, then notify other tasks
await self._service_func(service_call)
async with self._service_condition:
self._service_running = False
self._pending_reload_targets -= reload_targets
self._service_condition.notify_all()

View file

@ -21,6 +21,7 @@ from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import (
ATTR_ENTITY_ID,
ATTR_NAME,
CONF_ID,
EVENT_HOMEASSISTANT_STARTED,
SERVICE_RELOAD,
SERVICE_TOGGLE,
@ -692,7 +693,9 @@ async def test_reload_config_handles_load_fails(hass: HomeAssistant, calls) -> N
assert len(calls) == 2
@pytest.mark.parametrize("service", ["turn_off_stop", "turn_off_no_stop", "reload"])
@pytest.mark.parametrize(
"service", ["turn_off_stop", "turn_off_no_stop", "reload", "reload_single"]
)
async def test_automation_stops(hass: HomeAssistant, calls, service) -> None:
"""Test that turning off / reloading stops any running actions as appropriate."""
entity_id = "automation.hello"
@ -700,6 +703,7 @@ async def test_automation_stops(hass: HomeAssistant, calls, service) -> None:
config = {
automation.DOMAIN: {
"id": "sun",
"alias": "hello",
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [
@ -737,7 +741,7 @@ async def test_automation_stops(hass: HomeAssistant, calls, service) -> None:
{ATTR_ENTITY_ID: entity_id, automation.CONF_STOP_ACTIONS: False},
blocking=True,
)
else:
elif service == "reload":
config[automation.DOMAIN]["alias"] = "goodbye"
with patch(
"homeassistant.config.load_yaml_config_file",
@ -747,6 +751,19 @@ async def test_automation_stops(hass: HomeAssistant, calls, service) -> None:
await hass.services.async_call(
automation.DOMAIN, SERVICE_RELOAD, blocking=True
)
else: # service == "reload_single"
config[automation.DOMAIN]["alias"] = "goodbye"
with patch(
"homeassistant.config.load_yaml_config_file",
autospec=True,
return_value=config,
):
await hass.services.async_call(
automation.DOMAIN,
SERVICE_RELOAD,
{CONF_ID: "sun"},
blocking=True,
)
hass.states.async_set(test_entity, "goodbye")
await hass.async_block_till_done()
@ -801,6 +818,238 @@ async def test_reload_unchanged_does_not_stop(
assert len(calls) == 1
async def test_reload_single_unchanged_does_not_stop(
hass: HomeAssistant, calls
) -> None:
"""Test that reloading stops any running actions as appropriate."""
test_entity = "test.entity"
config = {
automation.DOMAIN: {
"id": "sun",
"alias": "hello",
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [
{"event": "running"},
{"wait_template": "{{ is_state('test.entity', 'goodbye') }}"},
{"service": "test.automation"},
],
}
}
assert await async_setup_component(hass, automation.DOMAIN, config)
running = asyncio.Event()
@callback
def running_cb(event):
running.set()
hass.bus.async_listen_once("running", running_cb)
hass.states.async_set(test_entity, "hello")
hass.bus.async_fire("test_event")
await running.wait()
assert len(calls) == 0
with patch(
"homeassistant.config.load_yaml_config_file",
autospec=True,
return_value=config,
):
await hass.services.async_call(
automation.DOMAIN,
SERVICE_RELOAD,
{CONF_ID: "sun"},
blocking=True,
)
hass.states.async_set(test_entity, "goodbye")
await hass.async_block_till_done()
assert len(calls) == 1
async def test_reload_single_add_automation(hass: HomeAssistant, calls) -> None:
"""Test that reloading a single automation."""
config1 = {automation.DOMAIN: {}}
config2 = {
automation.DOMAIN: {
"id": "sun",
"alias": "hello",
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [{"service": "test.automation"}],
}
}
assert await async_setup_component(hass, automation.DOMAIN, config1)
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert len(calls) == 0
with patch(
"homeassistant.config.load_yaml_config_file",
autospec=True,
return_value=config2,
):
await hass.services.async_call(
automation.DOMAIN,
SERVICE_RELOAD,
{CONF_ID: "sun"},
blocking=True,
)
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert len(calls) == 1
async def test_reload_single_parallel_calls(hass: HomeAssistant, calls) -> None:
"""Test reloading single automations in parallel."""
config1 = {automation.DOMAIN: {}}
config2 = {
automation.DOMAIN: [
{
"id": "sun",
"alias": "hello",
"trigger": {"platform": "event", "event_type": "test_event_sun"},
"action": [{"service": "test.automation"}],
},
{
"id": "moon",
"alias": "goodbye",
"trigger": {"platform": "event", "event_type": "test_event_moon"},
"action": [{"service": "test.automation"}],
},
{
"id": "mars",
"alias": "goodbye",
"trigger": {"platform": "event", "event_type": "test_event_mars"},
"action": [{"service": "test.automation"}],
},
{
"id": "venus",
"alias": "goodbye",
"trigger": {"platform": "event", "event_type": "test_event_venus"},
"action": [{"service": "test.automation"}],
},
]
}
assert await async_setup_component(hass, automation.DOMAIN, config1)
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert len(calls) == 0
# Trigger multiple reload service calls, each automation is reloaded twice.
# This tests the logic in the `ReloadServiceHelper` which avoids redundant
# reloads of the same target automation.
with patch(
"homeassistant.config.load_yaml_config_file",
autospec=True,
return_value=config2,
):
tasks = [
hass.services.async_call(
automation.DOMAIN,
SERVICE_RELOAD,
{CONF_ID: "sun"},
blocking=False,
),
hass.services.async_call(
automation.DOMAIN,
SERVICE_RELOAD,
{CONF_ID: "moon"},
blocking=False,
),
hass.services.async_call(
automation.DOMAIN,
SERVICE_RELOAD,
{CONF_ID: "mars"},
blocking=False,
),
hass.services.async_call(
automation.DOMAIN,
SERVICE_RELOAD,
{CONF_ID: "venus"},
blocking=False,
),
hass.services.async_call(
automation.DOMAIN,
SERVICE_RELOAD,
{CONF_ID: "sun"},
blocking=False,
),
hass.services.async_call(
automation.DOMAIN,
SERVICE_RELOAD,
{CONF_ID: "moon"},
blocking=False,
),
hass.services.async_call(
automation.DOMAIN,
SERVICE_RELOAD,
{CONF_ID: "mars"},
blocking=False,
),
hass.services.async_call(
automation.DOMAIN,
SERVICE_RELOAD,
{CONF_ID: "venus"},
blocking=False,
),
]
await asyncio.gather(*tasks)
await hass.async_block_till_done()
# Sanity check to ensure all automations are correctly setup
hass.bus.async_fire("test_event_sun")
await hass.async_block_till_done()
assert len(calls) == 1
hass.bus.async_fire("test_event_moon")
await hass.async_block_till_done()
assert len(calls) == 2
hass.bus.async_fire("test_event_mars")
await hass.async_block_till_done()
assert len(calls) == 3
hass.bus.async_fire("test_event_venus")
await hass.async_block_till_done()
assert len(calls) == 4
async def test_reload_single_remove_automation(hass: HomeAssistant, calls) -> None:
"""Test that reloading a single automation."""
config1 = {
automation.DOMAIN: {
"id": "sun",
"alias": "hello",
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [{"service": "test.automation"}],
}
}
config2 = {automation.DOMAIN: {}}
assert await async_setup_component(hass, automation.DOMAIN, config1)
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert len(calls) == 1
with patch(
"homeassistant.config.load_yaml_config_file",
autospec=True,
return_value=config2,
):
await hass.services.async_call(
automation.DOMAIN,
SERVICE_RELOAD,
{CONF_ID: "sun"},
blocking=True,
)
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert len(calls) == 1
async def test_reload_moved_automation_without_alias(
hass: HomeAssistant, calls
) -> None:

View file

@ -10,7 +10,7 @@ import pytest
from homeassistant.bootstrap import async_setup_component
from homeassistant.components import config
from homeassistant.components.config import automation
from homeassistant.const import STATE_ON, STATE_UNAVAILABLE
from homeassistant.const import STATE_ON
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.util import yaml
@ -82,10 +82,8 @@ async def test_update_automation_config(
)
await hass.async_block_till_done()
assert sorted(hass.states.async_entity_ids("automation")) == [
"automation.automation_0",
"automation.automation_1",
]
assert hass.states.get("automation.automation_0").state == STATE_UNAVAILABLE
assert hass.states.get("automation.automation_1").state == STATE_ON
assert resp.status == HTTPStatus.OK
@ -260,10 +258,8 @@ async def test_update_remove_key_automation_config(
)
await hass.async_block_till_done()
assert sorted(hass.states.async_entity_ids("automation")) == [
"automation.automation_0",
"automation.automation_1",
]
assert hass.states.get("automation.automation_0").state == STATE_UNAVAILABLE
assert hass.states.get("automation.automation_1").state == STATE_ON
assert resp.status == HTTPStatus.OK
@ -305,10 +301,8 @@ async def test_bad_formatted_automations(
)
await hass.async_block_till_done()
assert sorted(hass.states.async_entity_ids("automation")) == [
"automation.automation_0",
"automation.automation_1",
]
assert hass.states.get("automation.automation_0").state == STATE_UNAVAILABLE
assert hass.states.get("automation.automation_1").state == STATE_ON
assert resp.status == HTTPStatus.OK

View file

@ -1852,3 +1852,139 @@ async def test_async_extract_config_entry_ids(hass: HomeAssistant) -> None:
)
assert await service.async_extract_config_entry_ids(hass, call) == {"abc"}
async def test_reload_service_helper(hass: HomeAssistant) -> None:
"""Test the reload service helper."""
active_reload_calls = 0
reloaded = []
async def reload_service_handler(service_call: ServiceCall) -> None:
"""Remove all automations and load new ones from config."""
nonlocal active_reload_calls
# Assert the reload helper prevents parallel reloads
assert not active_reload_calls
active_reload_calls += 1
if not (target := service_call.data.get("target")):
reloaded.append("all")
else:
reloaded.append(target)
await asyncio.sleep(0.01)
active_reload_calls -= 1
def reload_targets(service_call: ServiceCall) -> set[str | None]:
if target_id := service_call.data.get("target"):
return {target_id}
return {"target1", "target2", "target3", "target4"}
# Test redundant reload of single targets
reloader = service.ReloadServiceHelper(reload_service_handler, reload_targets)
tasks = [
# This reload task will start executing first, (target1)
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
# These reload tasks will be deduplicated to (target2, target3, target4, target1)
# while the first task is reloaded, note that target1 can't be deduplicated
# because it's already being reloaded.
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
]
await asyncio.gather(*tasks)
assert reloaded == unordered(
["target1", "target2", "target3", "target4", "target1"]
)
# Test redundant reload of multiple targets + single target
reloaded.clear()
tasks = [
# This reload task will start executing first, (target1)
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
# These reload tasks will be deduplicated to (target2, target3, target4, all)
# while the first task is reloaded.
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
reloader.execute_service(ServiceCall("test", "test")),
]
await asyncio.gather(*tasks)
assert reloaded == unordered(["target1", "target2", "target3", "target4", "all"])
# Test redundant reload of multiple targets + single target
reloaded.clear()
tasks = [
# This reload task will start executing first, (all)
reloader.execute_service(ServiceCall("test", "test")),
# These reload tasks will be deduplicated to (target1, target2, target3, target4)
# while the first task is reloaded.
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
]
await asyncio.gather(*tasks)
assert reloaded == unordered(["all", "target1", "target2", "target3", "target4"])
# Test redundant reload of single targets
reloaded.clear()
tasks = [
# This reload task will start executing first, (target1)
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
# These reload tasks will be deduplicated to (target2, target3, target4, target1)
# while the first task is reloaded, note that target1 can't be deduplicated
# because it's already being reloaded.
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
]
await asyncio.gather(*tasks)
assert reloaded == unordered(
["target1", "target2", "target3", "target4", "target1"]
)
# Test redundant reload of multiple targets + single target
reloaded.clear()
tasks = [
# This reload task will start executing first, (target1)
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
# These reload tasks will be deduplicated to (target2, target3, target4, all)
# while the first task is reloaded.
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
reloader.execute_service(ServiceCall("test", "test")),
reloader.execute_service(ServiceCall("test", "test")),
]
await asyncio.gather(*tasks)
assert reloaded == unordered(["target1", "target2", "target3", "target4", "all"])
# Test redundant reload of multiple targets + single target
reloaded.clear()
tasks = [
# This reload task will start executing first, (all)
reloader.execute_service(ServiceCall("test", "test")),
# These reload tasks will be deduplicated to (target1, target2, target3, target4)
# while the first task is reloaded.
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
]
await asyncio.gather(*tasks)
assert reloaded == unordered(["all", "target1", "target2", "target3", "target4"])