From 2ba0f42accde760adf7d99ef9dd6aebd38cb517f Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 14 Sep 2022 16:47:08 +0200 Subject: [PATCH] Prevent deleting blueprints which are in use (#78444) --- .../components/automation/__init__.py | 24 ++++ .../components/automation/helpers.py | 9 +- .../components/blueprint/__init__.py | 2 +- homeassistant/components/blueprint/errors.py | 8 ++ homeassistant/components/blueprint/models.py | 6 + homeassistant/components/script/__init__.py | 25 ++++- homeassistant/components/script/helpers.py | 9 +- tests/components/blueprint/test_models.py | 4 +- .../blueprint/test_websocket_api.py | 105 +++++++++++++++++- .../blueprints/script/test_service.yaml | 8 ++ 10 files changed, 192 insertions(+), 8 deletions(-) create mode 100644 tests/testing_config/blueprints/script/test_service.yaml diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index a5ea30f59d2..454edce5cac 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -9,6 +9,7 @@ import voluptuous as vol from voluptuous.humanize import humanize_error from homeassistant.components import blueprint +from homeassistant.components.blueprint import CONF_USE_BLUEPRINT from homeassistant.const import ( ATTR_ENTITY_ID, ATTR_MODE, @@ -20,6 +21,7 @@ from homeassistant.const import ( CONF_EVENT_DATA, CONF_ID, CONF_MODE, + CONF_PATH, CONF_PLATFORM, CONF_VARIABLES, CONF_ZONE, @@ -233,6 +235,21 @@ def areas_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]: return list(cast(AutomationEntity, automation_entity).referenced_areas) +@callback +def automations_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list[str]: + """Return all automations that reference the blueprint.""" + if DOMAIN not in hass.data: + return [] + + component = hass.data[DOMAIN] + + return [ + automation_entity.entity_id + for automation_entity in component.entities + if automation_entity.referenced_blueprint == blueprint_path + ] + + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up all automations.""" hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass) @@ -356,6 +373,13 @@ class AutomationEntity(ToggleEntity, RestoreEntity): """Return a set of referenced areas.""" return self.action_script.referenced_areas + @property + def referenced_blueprint(self) -> str | None: + """Return referenced blueprint or None.""" + if self._blueprint_inputs is None: + return None + return cast(str, self._blueprint_inputs[CONF_USE_BLUEPRINT][CONF_PATH]) + @property def referenced_devices(self) -> set[str]: """Return a set of referenced devices.""" diff --git a/homeassistant/components/automation/helpers.py b/homeassistant/components/automation/helpers.py index 3be11afe18b..7c2efc17bf4 100644 --- a/homeassistant/components/automation/helpers.py +++ b/homeassistant/components/automation/helpers.py @@ -8,8 +8,15 @@ from .const import DOMAIN, LOGGER DATA_BLUEPRINTS = "automation_blueprints" +def _blueprint_in_use(hass: HomeAssistant, blueprint_path: str) -> bool: + """Return True if any automation references the blueprint.""" + from . import automations_with_blueprint # pylint: disable=import-outside-toplevel + + return len(automations_with_blueprint(hass, blueprint_path)) > 0 + + @singleton(DATA_BLUEPRINTS) @callback def async_get_blueprints(hass: HomeAssistant) -> blueprint.DomainBlueprints: """Get automation blueprints.""" - return blueprint.DomainBlueprints(hass, DOMAIN, LOGGER) + return blueprint.DomainBlueprints(hass, DOMAIN, LOGGER, _blueprint_in_use) diff --git a/homeassistant/components/blueprint/__init__.py b/homeassistant/components/blueprint/__init__.py index 23ab6398333..3087309f36a 100644 --- a/homeassistant/components/blueprint/__init__.py +++ b/homeassistant/components/blueprint/__init__.py @@ -3,7 +3,7 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers.typing import ConfigType from . import websocket_api -from .const import DOMAIN # noqa: F401 +from .const import CONF_USE_BLUEPRINT, DOMAIN # noqa: F401 from .errors import ( # noqa: F401 BlueprintException, BlueprintWithNameException, diff --git a/homeassistant/components/blueprint/errors.py b/homeassistant/components/blueprint/errors.py index aceca533d23..fe714542e0f 100644 --- a/homeassistant/components/blueprint/errors.py +++ b/homeassistant/components/blueprint/errors.py @@ -91,3 +91,11 @@ class FileAlreadyExists(BlueprintWithNameException): def __init__(self, domain: str, blueprint_name: str) -> None: """Initialize blueprint exception.""" super().__init__(domain, blueprint_name, "Blueprint already exists") + + +class BlueprintInUse(BlueprintWithNameException): + """Error when a blueprint is in use.""" + + def __init__(self, domain: str, blueprint_name: str) -> None: + """Initialize blueprint exception.""" + super().__init__(domain, blueprint_name, "Blueprint in use") diff --git a/homeassistant/components/blueprint/models.py b/homeassistant/components/blueprint/models.py index 0d90c663b4f..f77a2bed9a4 100644 --- a/homeassistant/components/blueprint/models.py +++ b/homeassistant/components/blueprint/models.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +from collections.abc import Callable import logging import pathlib import shutil @@ -35,6 +36,7 @@ from .const import ( ) from .errors import ( BlueprintException, + BlueprintInUse, FailedToLoad, FileAlreadyExists, InvalidBlueprint, @@ -183,11 +185,13 @@ class DomainBlueprints: hass: HomeAssistant, domain: str, logger: logging.Logger, + blueprint_in_use: Callable[[HomeAssistant, str], bool], ) -> None: """Initialize a domain blueprints instance.""" self.hass = hass self.domain = domain self.logger = logger + self._blueprint_in_use = blueprint_in_use self._blueprints: dict[str, Blueprint | None] = {} self._load_lock = asyncio.Lock() @@ -302,6 +306,8 @@ class DomainBlueprints: async def async_remove_blueprint(self, blueprint_path: str) -> None: """Remove a blueprint file.""" + if self._blueprint_in_use(self.hass, blueprint_path): + raise BlueprintInUse(self.domain, blueprint_path) path = self.blueprint_folder / blueprint_path await self.hass.async_add_executor_job(path.unlink) self._blueprints[blueprint_path] = None diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index efad242fbd0..53bd256c624 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -8,7 +8,7 @@ from typing import Any, cast import voluptuous as vol from voluptuous.humanize import humanize_error -from homeassistant.components.blueprint import BlueprintInputs +from homeassistant.components.blueprint import CONF_USE_BLUEPRINT, BlueprintInputs from homeassistant.const import ( ATTR_ENTITY_ID, ATTR_MODE, @@ -18,6 +18,7 @@ from homeassistant.const import ( CONF_ICON, CONF_MODE, CONF_NAME, + CONF_PATH, CONF_SEQUENCE, CONF_VARIABLES, SERVICE_RELOAD, @@ -165,6 +166,21 @@ def areas_in_script(hass: HomeAssistant, entity_id: str) -> list[str]: return list(script_entity.script.referenced_areas) +@callback +def scripts_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list[str]: + """Return all scripts that reference the blueprint.""" + if DOMAIN not in hass.data: + return [] + + component = hass.data[DOMAIN] + + return [ + script_entity.entity_id + for script_entity in component.entities + if script_entity.referenced_blueprint == blueprint_path + ] + + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Load the scripts from the configuration.""" hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass) @@ -372,6 +388,13 @@ class ScriptEntity(ToggleEntity, RestoreEntity): """Return true if script is on.""" return self.script.is_running + @property + def referenced_blueprint(self): + """Return referenced blueprint or None.""" + if self._blueprint_inputs is None: + return None + return self._blueprint_inputs[CONF_USE_BLUEPRINT][CONF_PATH] + @callback def async_change_listener(self): """Update state.""" diff --git a/homeassistant/components/script/helpers.py b/homeassistant/components/script/helpers.py index 3c78138a4ec..9f0d4399d3d 100644 --- a/homeassistant/components/script/helpers.py +++ b/homeassistant/components/script/helpers.py @@ -8,8 +8,15 @@ from .const import DOMAIN, LOGGER DATA_BLUEPRINTS = "script_blueprints" +def _blueprint_in_use(hass: HomeAssistant, blueprint_path: str) -> bool: + """Return True if any script references the blueprint.""" + from . import scripts_with_blueprint # pylint: disable=import-outside-toplevel + + return len(scripts_with_blueprint(hass, blueprint_path)) > 0 + + @singleton(DATA_BLUEPRINTS) @callback def async_get_blueprints(hass: HomeAssistant) -> DomainBlueprints: """Get script blueprints.""" - return DomainBlueprints(hass, DOMAIN, LOGGER) + return DomainBlueprints(hass, DOMAIN, LOGGER, _blueprint_in_use) diff --git a/tests/components/blueprint/test_models.py b/tests/components/blueprint/test_models.py index 497e8b36e99..02ed94709db 100644 --- a/tests/components/blueprint/test_models.py +++ b/tests/components/blueprint/test_models.py @@ -47,7 +47,9 @@ def blueprint_2(): @pytest.fixture def domain_bps(hass): """Domain blueprints fixture.""" - return models.DomainBlueprints(hass, "automation", logging.getLogger(__name__)) + return models.DomainBlueprints( + hass, "automation", logging.getLogger(__name__), None + ) def test_blueprint_model_init(): diff --git a/tests/components/blueprint/test_websocket_api.py b/tests/components/blueprint/test_websocket_api.py index eb2d12f5081..05c0e4adc4c 100644 --- a/tests/components/blueprint/test_websocket_api.py +++ b/tests/components/blueprint/test_websocket_api.py @@ -8,13 +8,26 @@ from homeassistant.setup import async_setup_component from homeassistant.util.yaml import parse_yaml +@pytest.fixture +def automation_config(): + """Automation config.""" + return {} + + +@pytest.fixture +def script_config(): + """Script config.""" + return {} + + @pytest.fixture(autouse=True) -async def setup_bp(hass): +async def setup_bp(hass, automation_config, script_config): """Fixture to set up the blueprint component.""" assert await async_setup_component(hass, "blueprint", {}) - # Trigger registration of automation blueprints - await async_setup_component(hass, "automation", {}) + # Trigger registration of automation and script blueprints + await async_setup_component(hass, "automation", automation_config) + await async_setup_component(hass, "script", script_config) async def test_list_blueprints(hass, hass_ws_client): @@ -251,3 +264,89 @@ async def test_delete_non_exist_file_blueprint(hass, aioclient_mock, hass_ws_cli assert msg["id"] == 9 assert not msg["success"] + + +@pytest.mark.parametrize( + "automation_config", + ( + { + "automation": { + "use_blueprint": { + "path": "test_event_service.yaml", + "input": { + "trigger_event": "blueprint_event", + "service_to_call": "test.automation", + "a_number": 5, + }, + } + } + }, + ), +) +async def test_delete_blueprint_in_use_by_automation( + hass, aioclient_mock, hass_ws_client +): + """Test deleting a blueprint which is in use.""" + + with patch("pathlib.Path.unlink", return_value=Mock()) as unlink_mock: + client = await hass_ws_client(hass) + await client.send_json( + { + "id": 9, + "type": "blueprint/delete", + "path": "test_event_service.yaml", + "domain": "automation", + } + ) + + msg = await client.receive_json() + + assert not unlink_mock.mock_calls + assert msg["id"] == 9 + assert not msg["success"] + assert msg["error"] == { + "code": "unknown_error", + "message": "Blueprint in use", + } + + +@pytest.mark.parametrize( + "script_config", + ( + { + "script": { + "test_script": { + "use_blueprint": { + "path": "test_service.yaml", + "input": { + "service_to_call": "test.automation", + }, + } + } + } + }, + ), +) +async def test_delete_blueprint_in_use_by_script(hass, aioclient_mock, hass_ws_client): + """Test deleting a blueprint which is in use.""" + + with patch("pathlib.Path.unlink", return_value=Mock()) as unlink_mock: + client = await hass_ws_client(hass) + await client.send_json( + { + "id": 9, + "type": "blueprint/delete", + "path": "test_service.yaml", + "domain": "script", + } + ) + + msg = await client.receive_json() + + assert not unlink_mock.mock_calls + assert msg["id"] == 9 + assert not msg["success"] + assert msg["error"] == { + "code": "unknown_error", + "message": "Blueprint in use", + } diff --git a/tests/testing_config/blueprints/script/test_service.yaml b/tests/testing_config/blueprints/script/test_service.yaml new file mode 100644 index 00000000000..4de991e90dc --- /dev/null +++ b/tests/testing_config/blueprints/script/test_service.yaml @@ -0,0 +1,8 @@ +blueprint: + name: "Call service" + domain: script + input: + service_to_call: +sequence: + service: !input service_to_call + entity_id: light.kitchen