Only reload modified scripts (#80470)

Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
Erik Montnemery 2022-10-24 20:47:06 +02:00 committed by GitHub
parent c5688072fd
commit f7694c0550
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 227 additions and 10 deletions

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
import logging
from typing import Any, cast
@ -180,7 +181,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def reload_service(service: ServiceCall) -> None:
"""Call a service to reload scripts."""
await async_get_blueprints(hass).async_reset_cache()
if (conf := await component.async_prepare_reload()) is None:
if (conf := await component.async_prepare_reload(skip_reset=True)) is None:
return
await _async_process_config(hass, conf, component)
@ -231,12 +232,22 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
async def _async_process_config(hass, config, component) -> None:
"""Process script configuration.
@dataclass
class ScriptEntityConfig:
"""Container for prepared script entity configuration."""
Return true, if Blueprints were used.
"""
entities = []
config_block: ConfigType
key: str
raw_blueprint_inputs: ConfigType | None
raw_config: ConfigType | None
async def _prepare_script_config(
hass: HomeAssistant,
config: ConfigType,
) -> list[ScriptEntityConfig]:
"""Parse configuration and prepare script entity configuration."""
script_configs: list[ScriptEntityConfig] = []
conf: dict[str, dict[str, Any] | BlueprintInputs] = config[DOMAIN]
@ -265,10 +276,86 @@ async def _async_process_config(hass, config, component) -> None:
else:
raw_config = cast(ScriptConfig, config_block).raw_config
entities.append(
ScriptEntity(hass, key, config_block, raw_config, raw_blueprint_inputs)
script_configs.append(
ScriptEntityConfig(config_block, key, raw_blueprint_inputs, raw_config)
)
return script_configs
async def _create_script_entities(
hass: HomeAssistant, script_configs: list[ScriptEntityConfig]
) -> list[ScriptEntity]:
"""Create script entities from prepared configuration."""
entities: list[ScriptEntity] = []
for script_config in script_configs:
entity = ScriptEntity(
hass,
script_config.key,
script_config.config_block,
script_config.raw_config,
script_config.raw_blueprint_inputs,
)
entities.append(entity)
return entities
async def _async_process_config(hass, config, component) -> None:
"""Process script configuration."""
entities = []
def script_matches_config(script: ScriptEntity, config: ScriptEntityConfig) -> bool:
return script.unique_id == config.key and script.raw_config == config.raw_config
def find_matches(
scripts: list[ScriptEntity],
script_configs: list[ScriptEntityConfig],
) -> tuple[set[int], set[int]]:
"""Find matches between a list of script entities and a list of configurations.
A script or configuration is only allowed to match at most once to handle
the case of multiple scripts with identical configuration.
Returns a tuple of sets of indices: ({script_matches}, {config_matches})
"""
script_matches: set[int] = set()
config_matches: set[int] = set()
for script_idx, script in enumerate(scripts):
for config_idx, config in enumerate(script_configs):
if config_idx in config_matches:
# Only allow a script config to match at most once
continue
if script_matches_config(script, config):
script_matches.add(script_idx)
config_matches.add(config_idx)
# Only allow a script to match at most once
break
return script_matches, config_matches
script_configs = await _prepare_script_config(hass, config)
scripts: list[ScriptEntity] = list(component.entities)
# Find scripts and configurations which have matches
script_matches, config_matches = find_matches(scripts, script_configs)
# Remove scripts which have changed config or no longer exist
tasks = [
script.async_remove()
for idx, script in enumerate(scripts)
if idx not in script_matches
]
await asyncio.gather(*tasks)
# Create scripts which have changed config or have been added
updated_script_configs = [
config for idx, config in enumerate(script_configs) if idx not in config_matches
]
entities = await _create_script_entities(hass, updated_script_configs)
await component.async_add_entities(entities)

View file

@ -739,7 +739,7 @@ async def test_automation_stops(hass, calls, service):
async def test_reload_unchanged_does_not_stop(hass, calls):
"""Test that turning off / reloading stops any running actions as appropriate."""
"""Test that reloading stops any running actions as appropriate."""
test_entity = "test.entity"
config = {
@ -766,6 +766,7 @@ async def test_reload_unchanged_does_not_stop(hass, calls):
hass.bus.async_fire("test_event")
await running.wait()
assert len(calls) == 0
with patch(
"homeassistant.config.load_yaml_config_file",

View file

@ -7,7 +7,7 @@ from unittest.mock import Mock, patch
import pytest
from homeassistant.components import script
from homeassistant.components.script import DOMAIN, EVENT_SCRIPT_STARTED
from homeassistant.components.script import DOMAIN, EVENT_SCRIPT_STARTED, ScriptEntity
from homeassistant.const import (
ATTR_ENTITY_ID,
ATTR_NAME,
@ -46,6 +46,12 @@ from tests.components.logbook.common import MockRow, mock_humanify
ENTITY_ID = "script.test"
@pytest.fixture
def calls(hass):
"""Track calls to a mock service."""
return async_mock_service(hass, "test", "script")
async def test_passing_variables(hass):
"""Test different ways of passing in variables."""
mock_restore_cache(hass, ())
@ -219,6 +225,129 @@ async def test_reload_service(hass, running):
assert hass.services.has_service(script.DOMAIN, "test")
async def test_reload_unchanged_does_not_stop(hass, calls):
"""Test that reloading stops any running actions as appropriate."""
test_entity = "test.entity"
config = {
script.DOMAIN: {
"test": {
"sequence": [
{"event": "running"},
{"wait_template": "{{ is_state('test.entity', 'goodbye') }}"},
{"service": "test.script"},
],
}
}
}
assert await async_setup_component(hass, script.DOMAIN, config)
assert hass.states.get(ENTITY_ID) is not None
assert hass.services.has_service(script.DOMAIN, "test")
running = asyncio.Event()
@callback
def running_cb(event):
running.set()
hass.bus.async_listen_once("running", running_cb)
hass.states.async_set(test_entity, "hello")
# Start the script and wait for it to start
_, object_id = split_entity_id(ENTITY_ID)
await hass.services.async_call(DOMAIN, object_id)
await running.wait()
assert len(calls) == 0
with patch(
"homeassistant.config.load_yaml_config_file",
autospec=True,
return_value=config,
):
await hass.services.async_call(script.DOMAIN, SERVICE_RELOAD, blocking=True)
hass.states.async_set(test_entity, "goodbye")
await hass.async_block_till_done()
assert len(calls) == 1
@pytest.mark.parametrize(
"script_config",
(
{
"test": {
"sequence": [{"service": "test.script"}],
}
},
# A script using templates
{
"test": {
"sequence": [{"service": "{{ 'test.script' }}"}],
}
},
# A script using blueprint
{
"test": {
"use_blueprint": {
"path": "test_service.yaml",
"input": {
"service_to_call": "test.script",
},
}
}
},
# A script using blueprint with templated input
{
"test": {
"use_blueprint": {
"path": "test_service.yaml",
"input": {
"service_to_call": "{{ 'test.script' }}",
},
}
}
},
),
)
async def test_reload_unchanged_script(hass, calls, script_config):
"""Test an unmodified script is not reloaded."""
with patch(
"homeassistant.components.script.ScriptEntity", wraps=ScriptEntity
) as script_entity_init:
config = {script.DOMAIN: [script_config]}
assert await async_setup_component(hass, script.DOMAIN, config)
assert hass.states.get(ENTITY_ID) is not None
assert hass.services.has_service(script.DOMAIN, "test")
assert script_entity_init.call_count == 1
script_entity_init.reset_mock()
# Start the script and wait for it to finish
_, object_id = split_entity_id(ENTITY_ID)
await hass.services.async_call(DOMAIN, object_id)
await hass.async_block_till_done()
assert len(calls) == 1
# Reload the scripts without any change
with patch(
"homeassistant.config.load_yaml_config_file",
autospec=True,
return_value=config,
):
await hass.services.async_call(script.DOMAIN, SERVICE_RELOAD, blocking=True)
assert script_entity_init.call_count == 0
script_entity_init.reset_mock()
# Start the script and wait for it to start
_, object_id = split_entity_id(ENTITY_ID)
await hass.services.async_call(DOMAIN, object_id)
await hass.async_block_till_done()
assert len(calls) == 2
async def test_service_descriptions(hass):
"""Test that service descriptions are loaded and reloaded correctly."""
# Test 1: has "description" but no "fields"