From f7694c055064029c396bb456377f3d4c0673f0ba Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Mon, 24 Oct 2022 20:47:06 +0200 Subject: [PATCH] Only reload modified scripts (#80470) Co-authored-by: Franck Nijhof --- homeassistant/components/script/__init__.py | 103 +++++++++++++-- tests/components/automation/test_init.py | 3 +- tests/components/script/test_init.py | 131 +++++++++++++++++++- 3 files changed, 227 insertions(+), 10 deletions(-) diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index a1c683faba7..359486d2687 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +from dataclasses import dataclass import logging from typing import Any, cast @@ -180,7 +181,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def reload_service(service: ServiceCall) -> None: """Call a service to reload scripts.""" await async_get_blueprints(hass).async_reset_cache() - if (conf := await component.async_prepare_reload()) is None: + if (conf := await component.async_prepare_reload(skip_reset=True)) is None: return await _async_process_config(hass, conf, component) @@ -231,12 +232,22 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -async def _async_process_config(hass, config, component) -> None: - """Process script configuration. +@dataclass +class ScriptEntityConfig: + """Container for prepared script entity configuration.""" - Return true, if Blueprints were used. - """ - entities = [] + config_block: ConfigType + key: str + raw_blueprint_inputs: ConfigType | None + raw_config: ConfigType | None + + +async def _prepare_script_config( + hass: HomeAssistant, + config: ConfigType, +) -> list[ScriptEntityConfig]: + """Parse configuration and prepare script entity configuration.""" + script_configs: list[ScriptEntityConfig] = [] conf: dict[str, dict[str, Any] | BlueprintInputs] = config[DOMAIN] @@ -265,10 +276,86 @@ async def _async_process_config(hass, config, component) -> None: else: raw_config = cast(ScriptConfig, config_block).raw_config - entities.append( - ScriptEntity(hass, key, config_block, raw_config, raw_blueprint_inputs) + script_configs.append( + ScriptEntityConfig(config_block, key, raw_blueprint_inputs, raw_config) ) + return script_configs + + +async def _create_script_entities( + hass: HomeAssistant, script_configs: list[ScriptEntityConfig] +) -> list[ScriptEntity]: + """Create script entities from prepared configuration.""" + entities: list[ScriptEntity] = [] + + for script_config in script_configs: + + entity = ScriptEntity( + hass, + script_config.key, + script_config.config_block, + script_config.raw_config, + script_config.raw_blueprint_inputs, + ) + entities.append(entity) + + return entities + + +async def _async_process_config(hass, config, component) -> None: + """Process script configuration.""" + entities = [] + + def script_matches_config(script: ScriptEntity, config: ScriptEntityConfig) -> bool: + return script.unique_id == config.key and script.raw_config == config.raw_config + + def find_matches( + scripts: list[ScriptEntity], + script_configs: list[ScriptEntityConfig], + ) -> tuple[set[int], set[int]]: + """Find matches between a list of script entities and a list of configurations. + + A script or configuration is only allowed to match at most once to handle + the case of multiple scripts with identical configuration. + + Returns a tuple of sets of indices: ({script_matches}, {config_matches}) + """ + script_matches: set[int] = set() + config_matches: set[int] = set() + + for script_idx, script in enumerate(scripts): + for config_idx, config in enumerate(script_configs): + if config_idx in config_matches: + # Only allow a script config to match at most once + continue + if script_matches_config(script, config): + script_matches.add(script_idx) + config_matches.add(config_idx) + # Only allow a script to match at most once + break + + return script_matches, config_matches + + script_configs = await _prepare_script_config(hass, config) + scripts: list[ScriptEntity] = list(component.entities) + + # Find scripts and configurations which have matches + script_matches, config_matches = find_matches(scripts, script_configs) + + # Remove scripts which have changed config or no longer exist + tasks = [ + script.async_remove() + for idx, script in enumerate(scripts) + if idx not in script_matches + ] + await asyncio.gather(*tasks) + + # Create scripts which have changed config or have been added + updated_script_configs = [ + config for idx, config in enumerate(script_configs) if idx not in config_matches + ] + entities = await _create_script_entities(hass, updated_script_configs) await component.async_add_entities(entities) diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index 858f3de6549..f40309bf7f6 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -739,7 +739,7 @@ async def test_automation_stops(hass, calls, service): async def test_reload_unchanged_does_not_stop(hass, calls): - """Test that turning off / reloading stops any running actions as appropriate.""" + """Test that reloading stops any running actions as appropriate.""" test_entity = "test.entity" config = { @@ -766,6 +766,7 @@ async def test_reload_unchanged_does_not_stop(hass, calls): hass.bus.async_fire("test_event") await running.wait() + assert len(calls) == 0 with patch( "homeassistant.config.load_yaml_config_file", diff --git a/tests/components/script/test_init.py b/tests/components/script/test_init.py index b48a65275b7..09d2c3c70b1 100644 --- a/tests/components/script/test_init.py +++ b/tests/components/script/test_init.py @@ -7,7 +7,7 @@ from unittest.mock import Mock, patch import pytest from homeassistant.components import script -from homeassistant.components.script import DOMAIN, EVENT_SCRIPT_STARTED +from homeassistant.components.script import DOMAIN, EVENT_SCRIPT_STARTED, ScriptEntity from homeassistant.const import ( ATTR_ENTITY_ID, ATTR_NAME, @@ -46,6 +46,12 @@ from tests.components.logbook.common import MockRow, mock_humanify ENTITY_ID = "script.test" +@pytest.fixture +def calls(hass): + """Track calls to a mock service.""" + return async_mock_service(hass, "test", "script") + + async def test_passing_variables(hass): """Test different ways of passing in variables.""" mock_restore_cache(hass, ()) @@ -219,6 +225,129 @@ async def test_reload_service(hass, running): assert hass.services.has_service(script.DOMAIN, "test") +async def test_reload_unchanged_does_not_stop(hass, calls): + """Test that reloading stops any running actions as appropriate.""" + test_entity = "test.entity" + + config = { + script.DOMAIN: { + "test": { + "sequence": [ + {"event": "running"}, + {"wait_template": "{{ is_state('test.entity', 'goodbye') }}"}, + {"service": "test.script"}, + ], + } + } + } + assert await async_setup_component(hass, script.DOMAIN, config) + + assert hass.states.get(ENTITY_ID) is not None + assert hass.services.has_service(script.DOMAIN, "test") + + running = asyncio.Event() + + @callback + def running_cb(event): + running.set() + + hass.bus.async_listen_once("running", running_cb) + hass.states.async_set(test_entity, "hello") + + # Start the script and wait for it to start + _, object_id = split_entity_id(ENTITY_ID) + await hass.services.async_call(DOMAIN, object_id) + await running.wait() + assert len(calls) == 0 + + with patch( + "homeassistant.config.load_yaml_config_file", + autospec=True, + return_value=config, + ): + await hass.services.async_call(script.DOMAIN, SERVICE_RELOAD, blocking=True) + + hass.states.async_set(test_entity, "goodbye") + await hass.async_block_till_done() + + assert len(calls) == 1 + + +@pytest.mark.parametrize( + "script_config", + ( + { + "test": { + "sequence": [{"service": "test.script"}], + } + }, + # A script using templates + { + "test": { + "sequence": [{"service": "{{ 'test.script' }}"}], + } + }, + # A script using blueprint + { + "test": { + "use_blueprint": { + "path": "test_service.yaml", + "input": { + "service_to_call": "test.script", + }, + } + } + }, + # A script using blueprint with templated input + { + "test": { + "use_blueprint": { + "path": "test_service.yaml", + "input": { + "service_to_call": "{{ 'test.script' }}", + }, + } + } + }, + ), +) +async def test_reload_unchanged_script(hass, calls, script_config): + """Test an unmodified script is not reloaded.""" + with patch( + "homeassistant.components.script.ScriptEntity", wraps=ScriptEntity + ) as script_entity_init: + config = {script.DOMAIN: [script_config]} + assert await async_setup_component(hass, script.DOMAIN, config) + assert hass.states.get(ENTITY_ID) is not None + assert hass.services.has_service(script.DOMAIN, "test") + + assert script_entity_init.call_count == 1 + script_entity_init.reset_mock() + + # Start the script and wait for it to finish + _, object_id = split_entity_id(ENTITY_ID) + await hass.services.async_call(DOMAIN, object_id) + await hass.async_block_till_done() + assert len(calls) == 1 + + # Reload the scripts without any change + with patch( + "homeassistant.config.load_yaml_config_file", + autospec=True, + return_value=config, + ): + await hass.services.async_call(script.DOMAIN, SERVICE_RELOAD, blocking=True) + + assert script_entity_init.call_count == 0 + script_entity_init.reset_mock() + + # Start the script and wait for it to start + _, object_id = split_entity_id(ENTITY_ID) + await hass.services.async_call(DOMAIN, object_id) + await hass.async_block_till_done() + assert len(calls) == 2 + + async def test_service_descriptions(hass): """Test that service descriptions are loaded and reloaded correctly.""" # Test 1: has "description" but no "fields"