Don't reload other automations when saving an automation (#80254)
* Only reload modified automation * Correct check for existing automation * Add tests * Remove the new service, improve ReloadServiceHelper * Revert unneeded changes * Update tests * Address review comments * Improve test coverage * Address review comments * Tweak reloader code + add a targetted test * Apply suggestions from code review Co-authored-by: Martin Hjelmare <marhje52@gmail.com> * Explain the tests + add more variations * Fix copy-paste mistake in test * Rephrase explanation of expected test outcome --------- Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
679752ceb8
commit
7cd0fe3c5f
6 changed files with 484 additions and 26 deletions
|
@ -331,17 +331,25 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
await async_get_blueprints(hass).async_reset_cache()
|
||||
if (conf := await component.async_prepare_reload(skip_reset=True)) is None:
|
||||
return
|
||||
await _async_process_config(hass, conf, component)
|
||||
if automation_id := service_call.data.get(CONF_ID):
|
||||
await _async_process_single_config(hass, conf, component, automation_id)
|
||||
else:
|
||||
await _async_process_config(hass, conf, component)
|
||||
hass.bus.async_fire(EVENT_AUTOMATION_RELOADED, context=service_call.context)
|
||||
|
||||
reload_helper = ReloadServiceHelper(reload_service_handler)
|
||||
def reload_targets(service_call: ServiceCall) -> set[str | None]:
|
||||
if automation_id := service_call.data.get(CONF_ID):
|
||||
return {automation_id}
|
||||
return {automation.unique_id for automation in component.entities}
|
||||
|
||||
reload_helper = ReloadServiceHelper(reload_service_handler, reload_targets)
|
||||
|
||||
async_register_admin_service(
|
||||
hass,
|
||||
DOMAIN,
|
||||
SERVICE_RELOAD,
|
||||
reload_helper.execute_service,
|
||||
schema=vol.Schema({}),
|
||||
schema=vol.Schema({vol.Optional(CONF_ID): str}),
|
||||
)
|
||||
|
||||
websocket_api.async_register_command(hass, websocket_config)
|
||||
|
@ -859,6 +867,7 @@ class AutomationEntityConfig:
|
|||
async def _prepare_automation_config(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
wanted_automation_id: str | None,
|
||||
) -> list[AutomationEntityConfig]:
|
||||
"""Parse configuration and prepare automation entity configuration."""
|
||||
automation_configs: list[AutomationEntityConfig] = []
|
||||
|
@ -866,6 +875,10 @@ async def _prepare_automation_config(
|
|||
conf: list[ConfigType] = config[DOMAIN]
|
||||
|
||||
for list_no, config_block in enumerate(conf):
|
||||
automation_id: str | None = config_block.get(CONF_ID)
|
||||
if wanted_automation_id is not None and automation_id != wanted_automation_id:
|
||||
continue
|
||||
|
||||
raw_config = cast(AutomationConfig, config_block).raw_config
|
||||
raw_blueprint_inputs = cast(AutomationConfig, config_block).raw_blueprint_inputs
|
||||
validation_failed = cast(AutomationConfig, config_block).validation_failed
|
||||
|
@ -1025,7 +1038,7 @@ async def _async_process_config(
|
|||
|
||||
return automation_matches, config_matches
|
||||
|
||||
automation_configs = await _prepare_automation_config(hass, config)
|
||||
automation_configs = await _prepare_automation_config(hass, config, None)
|
||||
automations: list[BaseAutomationEntity] = list(component.entities)
|
||||
|
||||
# Find automations and configurations which have matches
|
||||
|
@ -1049,6 +1062,41 @@ async def _async_process_config(
|
|||
await component.async_add_entities(entities)
|
||||
|
||||
|
||||
def _automation_matches_config(
|
||||
automation: BaseAutomationEntity | None, config: AutomationEntityConfig | None
|
||||
) -> bool:
|
||||
"""Return False if an automation's config has been changed."""
|
||||
if not automation:
|
||||
return False
|
||||
if not config:
|
||||
return False
|
||||
name = _automation_name(config)
|
||||
return automation.name == name and automation.raw_config == config.raw_config
|
||||
|
||||
|
||||
async def _async_process_single_config(
|
||||
hass: HomeAssistant,
|
||||
config: dict[str, Any],
|
||||
component: EntityComponent[BaseAutomationEntity],
|
||||
automation_id: str,
|
||||
) -> None:
|
||||
"""Process config and add a single automation."""
|
||||
|
||||
automation_configs = await _prepare_automation_config(hass, config, automation_id)
|
||||
automation = next(
|
||||
(x for x in component.entities if x.unique_id == automation_id), None
|
||||
)
|
||||
automation_config = automation_configs[0] if automation_configs else None
|
||||
|
||||
if _automation_matches_config(automation, automation_config):
|
||||
return
|
||||
|
||||
if automation:
|
||||
await automation.async_remove()
|
||||
entities = await _create_automation_entities(hass, automation_configs)
|
||||
await component.async_add_entities(entities)
|
||||
|
||||
|
||||
async def _async_process_if(
|
||||
hass: HomeAssistant, name: str, config: dict[str, Any]
|
||||
) -> IfAction | None:
|
||||
|
|
|
@ -26,7 +26,9 @@ def async_setup(hass: HomeAssistant) -> bool:
|
|||
async def hook(action: str, config_key: str) -> None:
|
||||
"""post_write_hook for Config View that reloads automations."""
|
||||
if action != ACTION_DELETE:
|
||||
await hass.services.async_call(DOMAIN, SERVICE_RELOAD)
|
||||
await hass.services.async_call(
|
||||
DOMAIN, SERVICE_RELOAD, {CONF_ID: config_key}
|
||||
)
|
||||
return
|
||||
|
||||
ent_reg = er.async_get(hass)
|
||||
|
|
|
@ -77,6 +77,8 @@ _LOGGER = logging.getLogger(__name__)
|
|||
SERVICE_DESCRIPTION_CACHE = "service_description_cache"
|
||||
ALL_SERVICE_DESCRIPTIONS_CACHE = "all_service_descriptions_cache"
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
@cache
|
||||
def _base_components() -> dict[str, ModuleType]:
|
||||
|
@ -1154,40 +1156,67 @@ def verify_domain_control(
|
|||
|
||||
|
||||
class ReloadServiceHelper:
|
||||
"""Helper for reload services to minimize unnecessary reloads."""
|
||||
"""Helper for reload services.
|
||||
|
||||
def __init__(self, service_func: Callable[[ServiceCall], Awaitable]) -> None:
|
||||
The helper has the following purposes:
|
||||
- Make sure reloads do not happen in parallel
|
||||
- Avoid redundant reloads of the same target
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
service_func: Callable[[ServiceCall], Awaitable],
|
||||
reload_targets_func: Callable[[ServiceCall], set[_T]],
|
||||
) -> None:
|
||||
"""Initialize ReloadServiceHelper."""
|
||||
self._service_func = service_func
|
||||
self._service_running = False
|
||||
self._service_condition = asyncio.Condition()
|
||||
self._pending_reload_targets: set[_T] = set()
|
||||
self._reload_targets_func = reload_targets_func
|
||||
|
||||
async def execute_service(self, service_call: ServiceCall) -> None:
|
||||
"""Execute the service.
|
||||
|
||||
If a previous reload task if currently in progress, wait for it to finish first.
|
||||
If a previous reload task is currently in progress, wait for it to finish first.
|
||||
Once the previous reload task has finished, one of the waiting tasks will be
|
||||
assigned to execute the reload, the others will wait for the reload to finish.
|
||||
assigned to execute the reload of the targets it is assigned to reload. The
|
||||
other tasks will wait if they should reload the same target, otherwise they
|
||||
will wait for the next round.
|
||||
"""
|
||||
|
||||
do_reload = False
|
||||
reload_targets = None
|
||||
async with self._service_condition:
|
||||
if self._service_running:
|
||||
# A previous reload task is already in progress, wait for it to finish
|
||||
# A previous reload task is already in progress, wait for it to finish,
|
||||
# because that task may be reloading a stale version of the resource.
|
||||
await self._service_condition.wait()
|
||||
|
||||
async with self._service_condition:
|
||||
if not self._service_running:
|
||||
# This task will do the reload
|
||||
self._service_running = True
|
||||
do_reload = True
|
||||
else:
|
||||
# Another task will perform the reload, wait for it to finish
|
||||
while True:
|
||||
async with self._service_condition:
|
||||
# Once we've passed this point, we assume the version of the resource is
|
||||
# the one our task was assigned to reload, or a newer one. Regardless of
|
||||
# which, our task is happy as long as the target is reloaded at least
|
||||
# once.
|
||||
if reload_targets is None:
|
||||
reload_targets = self._reload_targets_func(service_call)
|
||||
self._pending_reload_targets |= reload_targets
|
||||
if not self._service_running:
|
||||
# This task will do a reload
|
||||
self._service_running = True
|
||||
do_reload = True
|
||||
break
|
||||
# Another task will perform a reload, wait for it to finish
|
||||
await self._service_condition.wait()
|
||||
# Check if the reload this task is waiting for has been completed
|
||||
if reload_targets.isdisjoint(self._pending_reload_targets):
|
||||
break
|
||||
|
||||
if do_reload:
|
||||
# Reload, then notify other tasks
|
||||
await self._service_func(service_call)
|
||||
async with self._service_condition:
|
||||
self._service_running = False
|
||||
self._pending_reload_targets -= reload_targets
|
||||
self._service_condition.notify_all()
|
||||
|
|
|
@ -21,6 +21,7 @@ from homeassistant.config_entries import ConfigEntryState
|
|||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID,
|
||||
ATTR_NAME,
|
||||
CONF_ID,
|
||||
EVENT_HOMEASSISTANT_STARTED,
|
||||
SERVICE_RELOAD,
|
||||
SERVICE_TOGGLE,
|
||||
|
@ -692,7 +693,9 @@ async def test_reload_config_handles_load_fails(hass: HomeAssistant, calls) -> N
|
|||
assert len(calls) == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("service", ["turn_off_stop", "turn_off_no_stop", "reload"])
|
||||
@pytest.mark.parametrize(
|
||||
"service", ["turn_off_stop", "turn_off_no_stop", "reload", "reload_single"]
|
||||
)
|
||||
async def test_automation_stops(hass: HomeAssistant, calls, service) -> None:
|
||||
"""Test that turning off / reloading stops any running actions as appropriate."""
|
||||
entity_id = "automation.hello"
|
||||
|
@ -700,6 +703,7 @@ async def test_automation_stops(hass: HomeAssistant, calls, service) -> None:
|
|||
|
||||
config = {
|
||||
automation.DOMAIN: {
|
||||
"id": "sun",
|
||||
"alias": "hello",
|
||||
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||
"action": [
|
||||
|
@ -737,7 +741,7 @@ async def test_automation_stops(hass: HomeAssistant, calls, service) -> None:
|
|||
{ATTR_ENTITY_ID: entity_id, automation.CONF_STOP_ACTIONS: False},
|
||||
blocking=True,
|
||||
)
|
||||
else:
|
||||
elif service == "reload":
|
||||
config[automation.DOMAIN]["alias"] = "goodbye"
|
||||
with patch(
|
||||
"homeassistant.config.load_yaml_config_file",
|
||||
|
@ -747,6 +751,19 @@ async def test_automation_stops(hass: HomeAssistant, calls, service) -> None:
|
|||
await hass.services.async_call(
|
||||
automation.DOMAIN, SERVICE_RELOAD, blocking=True
|
||||
)
|
||||
else: # service == "reload_single"
|
||||
config[automation.DOMAIN]["alias"] = "goodbye"
|
||||
with patch(
|
||||
"homeassistant.config.load_yaml_config_file",
|
||||
autospec=True,
|
||||
return_value=config,
|
||||
):
|
||||
await hass.services.async_call(
|
||||
automation.DOMAIN,
|
||||
SERVICE_RELOAD,
|
||||
{CONF_ID: "sun"},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
hass.states.async_set(test_entity, "goodbye")
|
||||
await hass.async_block_till_done()
|
||||
|
@ -801,6 +818,238 @@ async def test_reload_unchanged_does_not_stop(
|
|||
assert len(calls) == 1
|
||||
|
||||
|
||||
async def test_reload_single_unchanged_does_not_stop(
|
||||
hass: HomeAssistant, calls
|
||||
) -> None:
|
||||
"""Test that reloading stops any running actions as appropriate."""
|
||||
test_entity = "test.entity"
|
||||
|
||||
config = {
|
||||
automation.DOMAIN: {
|
||||
"id": "sun",
|
||||
"alias": "hello",
|
||||
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||
"action": [
|
||||
{"event": "running"},
|
||||
{"wait_template": "{{ is_state('test.entity', 'goodbye') }}"},
|
||||
{"service": "test.automation"},
|
||||
],
|
||||
}
|
||||
}
|
||||
assert await async_setup_component(hass, automation.DOMAIN, config)
|
||||
|
||||
running = asyncio.Event()
|
||||
|
||||
@callback
|
||||
def running_cb(event):
|
||||
running.set()
|
||||
|
||||
hass.bus.async_listen_once("running", running_cb)
|
||||
hass.states.async_set(test_entity, "hello")
|
||||
|
||||
hass.bus.async_fire("test_event")
|
||||
await running.wait()
|
||||
assert len(calls) == 0
|
||||
|
||||
with patch(
|
||||
"homeassistant.config.load_yaml_config_file",
|
||||
autospec=True,
|
||||
return_value=config,
|
||||
):
|
||||
await hass.services.async_call(
|
||||
automation.DOMAIN,
|
||||
SERVICE_RELOAD,
|
||||
{CONF_ID: "sun"},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
hass.states.async_set(test_entity, "goodbye")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
|
||||
|
||||
async def test_reload_single_add_automation(hass: HomeAssistant, calls) -> None:
|
||||
"""Test that reloading a single automation."""
|
||||
config1 = {automation.DOMAIN: {}}
|
||||
config2 = {
|
||||
automation.DOMAIN: {
|
||||
"id": "sun",
|
||||
"alias": "hello",
|
||||
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||
"action": [{"service": "test.automation"}],
|
||||
}
|
||||
}
|
||||
assert await async_setup_component(hass, automation.DOMAIN, config1)
|
||||
|
||||
hass.bus.async_fire("test_event")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 0
|
||||
|
||||
with patch(
|
||||
"homeassistant.config.load_yaml_config_file",
|
||||
autospec=True,
|
||||
return_value=config2,
|
||||
):
|
||||
await hass.services.async_call(
|
||||
automation.DOMAIN,
|
||||
SERVICE_RELOAD,
|
||||
{CONF_ID: "sun"},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
hass.bus.async_fire("test_event")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
|
||||
|
||||
async def test_reload_single_parallel_calls(hass: HomeAssistant, calls) -> None:
|
||||
"""Test reloading single automations in parallel."""
|
||||
config1 = {automation.DOMAIN: {}}
|
||||
config2 = {
|
||||
automation.DOMAIN: [
|
||||
{
|
||||
"id": "sun",
|
||||
"alias": "hello",
|
||||
"trigger": {"platform": "event", "event_type": "test_event_sun"},
|
||||
"action": [{"service": "test.automation"}],
|
||||
},
|
||||
{
|
||||
"id": "moon",
|
||||
"alias": "goodbye",
|
||||
"trigger": {"platform": "event", "event_type": "test_event_moon"},
|
||||
"action": [{"service": "test.automation"}],
|
||||
},
|
||||
{
|
||||
"id": "mars",
|
||||
"alias": "goodbye",
|
||||
"trigger": {"platform": "event", "event_type": "test_event_mars"},
|
||||
"action": [{"service": "test.automation"}],
|
||||
},
|
||||
{
|
||||
"id": "venus",
|
||||
"alias": "goodbye",
|
||||
"trigger": {"platform": "event", "event_type": "test_event_venus"},
|
||||
"action": [{"service": "test.automation"}],
|
||||
},
|
||||
]
|
||||
}
|
||||
assert await async_setup_component(hass, automation.DOMAIN, config1)
|
||||
|
||||
hass.bus.async_fire("test_event")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 0
|
||||
|
||||
# Trigger multiple reload service calls, each automation is reloaded twice.
|
||||
# This tests the logic in the `ReloadServiceHelper` which avoids redundant
|
||||
# reloads of the same target automation.
|
||||
with patch(
|
||||
"homeassistant.config.load_yaml_config_file",
|
||||
autospec=True,
|
||||
return_value=config2,
|
||||
):
|
||||
tasks = [
|
||||
hass.services.async_call(
|
||||
automation.DOMAIN,
|
||||
SERVICE_RELOAD,
|
||||
{CONF_ID: "sun"},
|
||||
blocking=False,
|
||||
),
|
||||
hass.services.async_call(
|
||||
automation.DOMAIN,
|
||||
SERVICE_RELOAD,
|
||||
{CONF_ID: "moon"},
|
||||
blocking=False,
|
||||
),
|
||||
hass.services.async_call(
|
||||
automation.DOMAIN,
|
||||
SERVICE_RELOAD,
|
||||
{CONF_ID: "mars"},
|
||||
blocking=False,
|
||||
),
|
||||
hass.services.async_call(
|
||||
automation.DOMAIN,
|
||||
SERVICE_RELOAD,
|
||||
{CONF_ID: "venus"},
|
||||
blocking=False,
|
||||
),
|
||||
hass.services.async_call(
|
||||
automation.DOMAIN,
|
||||
SERVICE_RELOAD,
|
||||
{CONF_ID: "sun"},
|
||||
blocking=False,
|
||||
),
|
||||
hass.services.async_call(
|
||||
automation.DOMAIN,
|
||||
SERVICE_RELOAD,
|
||||
{CONF_ID: "moon"},
|
||||
blocking=False,
|
||||
),
|
||||
hass.services.async_call(
|
||||
automation.DOMAIN,
|
||||
SERVICE_RELOAD,
|
||||
{CONF_ID: "mars"},
|
||||
blocking=False,
|
||||
),
|
||||
hass.services.async_call(
|
||||
automation.DOMAIN,
|
||||
SERVICE_RELOAD,
|
||||
{CONF_ID: "venus"},
|
||||
blocking=False,
|
||||
),
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Sanity check to ensure all automations are correctly setup
|
||||
hass.bus.async_fire("test_event_sun")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
hass.bus.async_fire("test_event_moon")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 2
|
||||
hass.bus.async_fire("test_event_mars")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 3
|
||||
hass.bus.async_fire("test_event_venus")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 4
|
||||
|
||||
|
||||
async def test_reload_single_remove_automation(hass: HomeAssistant, calls) -> None:
|
||||
"""Test that reloading a single automation."""
|
||||
config1 = {
|
||||
automation.DOMAIN: {
|
||||
"id": "sun",
|
||||
"alias": "hello",
|
||||
"trigger": {"platform": "event", "event_type": "test_event"},
|
||||
"action": [{"service": "test.automation"}],
|
||||
}
|
||||
}
|
||||
config2 = {automation.DOMAIN: {}}
|
||||
assert await async_setup_component(hass, automation.DOMAIN, config1)
|
||||
|
||||
hass.bus.async_fire("test_event")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
|
||||
with patch(
|
||||
"homeassistant.config.load_yaml_config_file",
|
||||
autospec=True,
|
||||
return_value=config2,
|
||||
):
|
||||
await hass.services.async_call(
|
||||
automation.DOMAIN,
|
||||
SERVICE_RELOAD,
|
||||
{CONF_ID: "sun"},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
hass.bus.async_fire("test_event")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
|
||||
|
||||
async def test_reload_moved_automation_without_alias(
|
||||
hass: HomeAssistant, calls
|
||||
) -> None:
|
||||
|
|
|
@ -10,7 +10,7 @@ import pytest
|
|||
from homeassistant.bootstrap import async_setup_component
|
||||
from homeassistant.components import config
|
||||
from homeassistant.components.config import automation
|
||||
from homeassistant.const import STATE_ON, STATE_UNAVAILABLE
|
||||
from homeassistant.const import STATE_ON
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.util import yaml
|
||||
|
@ -82,10 +82,8 @@ async def test_update_automation_config(
|
|||
)
|
||||
await hass.async_block_till_done()
|
||||
assert sorted(hass.states.async_entity_ids("automation")) == [
|
||||
"automation.automation_0",
|
||||
"automation.automation_1",
|
||||
]
|
||||
assert hass.states.get("automation.automation_0").state == STATE_UNAVAILABLE
|
||||
assert hass.states.get("automation.automation_1").state == STATE_ON
|
||||
|
||||
assert resp.status == HTTPStatus.OK
|
||||
|
@ -260,10 +258,8 @@ async def test_update_remove_key_automation_config(
|
|||
)
|
||||
await hass.async_block_till_done()
|
||||
assert sorted(hass.states.async_entity_ids("automation")) == [
|
||||
"automation.automation_0",
|
||||
"automation.automation_1",
|
||||
]
|
||||
assert hass.states.get("automation.automation_0").state == STATE_UNAVAILABLE
|
||||
assert hass.states.get("automation.automation_1").state == STATE_ON
|
||||
|
||||
assert resp.status == HTTPStatus.OK
|
||||
|
@ -305,10 +301,8 @@ async def test_bad_formatted_automations(
|
|||
)
|
||||
await hass.async_block_till_done()
|
||||
assert sorted(hass.states.async_entity_ids("automation")) == [
|
||||
"automation.automation_0",
|
||||
"automation.automation_1",
|
||||
]
|
||||
assert hass.states.get("automation.automation_0").state == STATE_UNAVAILABLE
|
||||
assert hass.states.get("automation.automation_1").state == STATE_ON
|
||||
|
||||
assert resp.status == HTTPStatus.OK
|
||||
|
|
|
@ -1852,3 +1852,139 @@ async def test_async_extract_config_entry_ids(hass: HomeAssistant) -> None:
|
|||
)
|
||||
|
||||
assert await service.async_extract_config_entry_ids(hass, call) == {"abc"}
|
||||
|
||||
|
||||
async def test_reload_service_helper(hass: HomeAssistant) -> None:
|
||||
"""Test the reload service helper."""
|
||||
|
||||
active_reload_calls = 0
|
||||
reloaded = []
|
||||
|
||||
async def reload_service_handler(service_call: ServiceCall) -> None:
|
||||
"""Remove all automations and load new ones from config."""
|
||||
nonlocal active_reload_calls
|
||||
# Assert the reload helper prevents parallel reloads
|
||||
assert not active_reload_calls
|
||||
active_reload_calls += 1
|
||||
if not (target := service_call.data.get("target")):
|
||||
reloaded.append("all")
|
||||
else:
|
||||
reloaded.append(target)
|
||||
await asyncio.sleep(0.01)
|
||||
active_reload_calls -= 1
|
||||
|
||||
def reload_targets(service_call: ServiceCall) -> set[str | None]:
|
||||
if target_id := service_call.data.get("target"):
|
||||
return {target_id}
|
||||
return {"target1", "target2", "target3", "target4"}
|
||||
|
||||
# Test redundant reload of single targets
|
||||
reloader = service.ReloadServiceHelper(reload_service_handler, reload_targets)
|
||||
tasks = [
|
||||
# This reload task will start executing first, (target1)
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||
# These reload tasks will be deduplicated to (target2, target3, target4, target1)
|
||||
# while the first task is reloaded, note that target1 can't be deduplicated
|
||||
# because it's already being reloaded.
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
assert reloaded == unordered(
|
||||
["target1", "target2", "target3", "target4", "target1"]
|
||||
)
|
||||
|
||||
# Test redundant reload of multiple targets + single target
|
||||
reloaded.clear()
|
||||
tasks = [
|
||||
# This reload task will start executing first, (target1)
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||
# These reload tasks will be deduplicated to (target2, target3, target4, all)
|
||||
# while the first task is reloaded.
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||
reloader.execute_service(ServiceCall("test", "test")),
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
assert reloaded == unordered(["target1", "target2", "target3", "target4", "all"])
|
||||
|
||||
# Test redundant reload of multiple targets + single target
|
||||
reloaded.clear()
|
||||
tasks = [
|
||||
# This reload task will start executing first, (all)
|
||||
reloader.execute_service(ServiceCall("test", "test")),
|
||||
# These reload tasks will be deduplicated to (target1, target2, target3, target4)
|
||||
# while the first task is reloaded.
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
assert reloaded == unordered(["all", "target1", "target2", "target3", "target4"])
|
||||
|
||||
# Test redundant reload of single targets
|
||||
reloaded.clear()
|
||||
tasks = [
|
||||
# This reload task will start executing first, (target1)
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||
# These reload tasks will be deduplicated to (target2, target3, target4, target1)
|
||||
# while the first task is reloaded, note that target1 can't be deduplicated
|
||||
# because it's already being reloaded.
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
assert reloaded == unordered(
|
||||
["target1", "target2", "target3", "target4", "target1"]
|
||||
)
|
||||
|
||||
# Test redundant reload of multiple targets + single target
|
||||
reloaded.clear()
|
||||
tasks = [
|
||||
# This reload task will start executing first, (target1)
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||
# These reload tasks will be deduplicated to (target2, target3, target4, all)
|
||||
# while the first task is reloaded.
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||
reloader.execute_service(ServiceCall("test", "test")),
|
||||
reloader.execute_service(ServiceCall("test", "test")),
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
assert reloaded == unordered(["target1", "target2", "target3", "target4", "all"])
|
||||
|
||||
# Test redundant reload of multiple targets + single target
|
||||
reloaded.clear()
|
||||
tasks = [
|
||||
# This reload task will start executing first, (all)
|
||||
reloader.execute_service(ServiceCall("test", "test")),
|
||||
# These reload tasks will be deduplicated to (target1, target2, target3, target4)
|
||||
# while the first task is reloaded.
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target1"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target2"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target3"})),
|
||||
reloader.execute_service(ServiceCall("test", "test", {"target": "target4"})),
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
assert reloaded == unordered(["all", "target1", "target2", "target3", "target4"])
|
||||
|
|
Loading…
Add table
Reference in a new issue