Only reload modified scripts (#80470)
Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
parent
c5688072fd
commit
f7694c0550
3 changed files with 227 additions and 10 deletions
|
@ -2,6 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
|
@ -180,7 +181,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
async def reload_service(service: ServiceCall) -> None:
|
||||
"""Call a service to reload scripts."""
|
||||
await async_get_blueprints(hass).async_reset_cache()
|
||||
if (conf := await component.async_prepare_reload()) is None:
|
||||
if (conf := await component.async_prepare_reload(skip_reset=True)) is None:
|
||||
return
|
||||
await _async_process_config(hass, conf, component)
|
||||
|
||||
|
@ -231,12 +232,22 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
async def _async_process_config(hass, config, component) -> None:
|
||||
"""Process script configuration.
|
||||
@dataclass
|
||||
class ScriptEntityConfig:
|
||||
"""Container for prepared script entity configuration."""
|
||||
|
||||
Return true, if Blueprints were used.
|
||||
"""
|
||||
entities = []
|
||||
config_block: ConfigType
|
||||
key: str
|
||||
raw_blueprint_inputs: ConfigType | None
|
||||
raw_config: ConfigType | None
|
||||
|
||||
|
||||
async def _prepare_script_config(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
) -> list[ScriptEntityConfig]:
|
||||
"""Parse configuration and prepare script entity configuration."""
|
||||
script_configs: list[ScriptEntityConfig] = []
|
||||
|
||||
conf: dict[str, dict[str, Any] | BlueprintInputs] = config[DOMAIN]
|
||||
|
||||
|
@ -265,10 +276,86 @@ async def _async_process_config(hass, config, component) -> None:
|
|||
else:
|
||||
raw_config = cast(ScriptConfig, config_block).raw_config
|
||||
|
||||
entities.append(
|
||||
ScriptEntity(hass, key, config_block, raw_config, raw_blueprint_inputs)
|
||||
script_configs.append(
|
||||
ScriptEntityConfig(config_block, key, raw_blueprint_inputs, raw_config)
|
||||
)
|
||||
|
||||
return script_configs
|
||||
|
||||
|
||||
async def _create_script_entities(
|
||||
hass: HomeAssistant, script_configs: list[ScriptEntityConfig]
|
||||
) -> list[ScriptEntity]:
|
||||
"""Create script entities from prepared configuration."""
|
||||
entities: list[ScriptEntity] = []
|
||||
|
||||
for script_config in script_configs:
|
||||
|
||||
entity = ScriptEntity(
|
||||
hass,
|
||||
script_config.key,
|
||||
script_config.config_block,
|
||||
script_config.raw_config,
|
||||
script_config.raw_blueprint_inputs,
|
||||
)
|
||||
entities.append(entity)
|
||||
|
||||
return entities
|
||||
|
||||
|
||||
async def _async_process_config(hass, config, component) -> None:
|
||||
"""Process script configuration."""
|
||||
entities = []
|
||||
|
||||
def script_matches_config(script: ScriptEntity, config: ScriptEntityConfig) -> bool:
|
||||
return script.unique_id == config.key and script.raw_config == config.raw_config
|
||||
|
||||
def find_matches(
|
||||
scripts: list[ScriptEntity],
|
||||
script_configs: list[ScriptEntityConfig],
|
||||
) -> tuple[set[int], set[int]]:
|
||||
"""Find matches between a list of script entities and a list of configurations.
|
||||
|
||||
A script or configuration is only allowed to match at most once to handle
|
||||
the case of multiple scripts with identical configuration.
|
||||
|
||||
Returns a tuple of sets of indices: ({script_matches}, {config_matches})
|
||||
"""
|
||||
script_matches: set[int] = set()
|
||||
config_matches: set[int] = set()
|
||||
|
||||
for script_idx, script in enumerate(scripts):
|
||||
for config_idx, config in enumerate(script_configs):
|
||||
if config_idx in config_matches:
|
||||
# Only allow a script config to match at most once
|
||||
continue
|
||||
if script_matches_config(script, config):
|
||||
script_matches.add(script_idx)
|
||||
config_matches.add(config_idx)
|
||||
# Only allow a script to match at most once
|
||||
break
|
||||
|
||||
return script_matches, config_matches
|
||||
|
||||
script_configs = await _prepare_script_config(hass, config)
|
||||
scripts: list[ScriptEntity] = list(component.entities)
|
||||
|
||||
# Find scripts and configurations which have matches
|
||||
script_matches, config_matches = find_matches(scripts, script_configs)
|
||||
|
||||
# Remove scripts which have changed config or no longer exist
|
||||
tasks = [
|
||||
script.async_remove()
|
||||
for idx, script in enumerate(scripts)
|
||||
if idx not in script_matches
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Create scripts which have changed config or have been added
|
||||
updated_script_configs = [
|
||||
config for idx, config in enumerate(script_configs) if idx not in config_matches
|
||||
]
|
||||
entities = await _create_script_entities(hass, updated_script_configs)
|
||||
await component.async_add_entities(entities)
|
||||
|
||||
|
||||
|
|
|
@ -739,7 +739,7 @@ async def test_automation_stops(hass, calls, service):
|
|||
|
||||
|
||||
async def test_reload_unchanged_does_not_stop(hass, calls):
|
||||
"""Test that turning off / reloading stops any running actions as appropriate."""
|
||||
"""Test that reloading stops any running actions as appropriate."""
|
||||
test_entity = "test.entity"
|
||||
|
||||
config = {
|
||||
|
@ -766,6 +766,7 @@ async def test_reload_unchanged_does_not_stop(hass, calls):
|
|||
|
||||
hass.bus.async_fire("test_event")
|
||||
await running.wait()
|
||||
assert len(calls) == 0
|
||||
|
||||
with patch(
|
||||
"homeassistant.config.load_yaml_config_file",
|
||||
|
|
|
@ -7,7 +7,7 @@ from unittest.mock import Mock, patch
|
|||
import pytest
|
||||
|
||||
from homeassistant.components import script
|
||||
from homeassistant.components.script import DOMAIN, EVENT_SCRIPT_STARTED
|
||||
from homeassistant.components.script import DOMAIN, EVENT_SCRIPT_STARTED, ScriptEntity
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID,
|
||||
ATTR_NAME,
|
||||
|
@ -46,6 +46,12 @@ from tests.components.logbook.common import MockRow, mock_humanify
|
|||
ENTITY_ID = "script.test"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def calls(hass):
|
||||
"""Track calls to a mock service."""
|
||||
return async_mock_service(hass, "test", "script")
|
||||
|
||||
|
||||
async def test_passing_variables(hass):
|
||||
"""Test different ways of passing in variables."""
|
||||
mock_restore_cache(hass, ())
|
||||
|
@ -219,6 +225,129 @@ async def test_reload_service(hass, running):
|
|||
assert hass.services.has_service(script.DOMAIN, "test")
|
||||
|
||||
|
||||
async def test_reload_unchanged_does_not_stop(hass, calls):
|
||||
"""Test that reloading stops any running actions as appropriate."""
|
||||
test_entity = "test.entity"
|
||||
|
||||
config = {
|
||||
script.DOMAIN: {
|
||||
"test": {
|
||||
"sequence": [
|
||||
{"event": "running"},
|
||||
{"wait_template": "{{ is_state('test.entity', 'goodbye') }}"},
|
||||
{"service": "test.script"},
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
assert await async_setup_component(hass, script.DOMAIN, config)
|
||||
|
||||
assert hass.states.get(ENTITY_ID) is not None
|
||||
assert hass.services.has_service(script.DOMAIN, "test")
|
||||
|
||||
running = asyncio.Event()
|
||||
|
||||
@callback
|
||||
def running_cb(event):
|
||||
running.set()
|
||||
|
||||
hass.bus.async_listen_once("running", running_cb)
|
||||
hass.states.async_set(test_entity, "hello")
|
||||
|
||||
# Start the script and wait for it to start
|
||||
_, object_id = split_entity_id(ENTITY_ID)
|
||||
await hass.services.async_call(DOMAIN, object_id)
|
||||
await running.wait()
|
||||
assert len(calls) == 0
|
||||
|
||||
with patch(
|
||||
"homeassistant.config.load_yaml_config_file",
|
||||
autospec=True,
|
||||
return_value=config,
|
||||
):
|
||||
await hass.services.async_call(script.DOMAIN, SERVICE_RELOAD, blocking=True)
|
||||
|
||||
hass.states.async_set(test_entity, "goodbye")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"script_config",
|
||||
(
|
||||
{
|
||||
"test": {
|
||||
"sequence": [{"service": "test.script"}],
|
||||
}
|
||||
},
|
||||
# A script using templates
|
||||
{
|
||||
"test": {
|
||||
"sequence": [{"service": "{{ 'test.script' }}"}],
|
||||
}
|
||||
},
|
||||
# A script using blueprint
|
||||
{
|
||||
"test": {
|
||||
"use_blueprint": {
|
||||
"path": "test_service.yaml",
|
||||
"input": {
|
||||
"service_to_call": "test.script",
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
# A script using blueprint with templated input
|
||||
{
|
||||
"test": {
|
||||
"use_blueprint": {
|
||||
"path": "test_service.yaml",
|
||||
"input": {
|
||||
"service_to_call": "{{ 'test.script' }}",
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
),
|
||||
)
|
||||
async def test_reload_unchanged_script(hass, calls, script_config):
|
||||
"""Test an unmodified script is not reloaded."""
|
||||
with patch(
|
||||
"homeassistant.components.script.ScriptEntity", wraps=ScriptEntity
|
||||
) as script_entity_init:
|
||||
config = {script.DOMAIN: [script_config]}
|
||||
assert await async_setup_component(hass, script.DOMAIN, config)
|
||||
assert hass.states.get(ENTITY_ID) is not None
|
||||
assert hass.services.has_service(script.DOMAIN, "test")
|
||||
|
||||
assert script_entity_init.call_count == 1
|
||||
script_entity_init.reset_mock()
|
||||
|
||||
# Start the script and wait for it to finish
|
||||
_, object_id = split_entity_id(ENTITY_ID)
|
||||
await hass.services.async_call(DOMAIN, object_id)
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 1
|
||||
|
||||
# Reload the scripts without any change
|
||||
with patch(
|
||||
"homeassistant.config.load_yaml_config_file",
|
||||
autospec=True,
|
||||
return_value=config,
|
||||
):
|
||||
await hass.services.async_call(script.DOMAIN, SERVICE_RELOAD, blocking=True)
|
||||
|
||||
assert script_entity_init.call_count == 0
|
||||
script_entity_init.reset_mock()
|
||||
|
||||
# Start the script and wait for it to start
|
||||
_, object_id = split_entity_id(ENTITY_ID)
|
||||
await hass.services.async_call(DOMAIN, object_id)
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 2
|
||||
|
||||
|
||||
async def test_service_descriptions(hass):
|
||||
"""Test that service descriptions are loaded and reloaded correctly."""
|
||||
# Test 1: has "description" but no "fields"
|
||||
|
|
Loading…
Add table
Reference in a new issue