diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index c0d79c446bb..1af553165bd 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -18,11 +18,13 @@ from homeassistant.const import ( ATTR_MODE, ATTR_NAME, CONF_ALIAS, + CONF_DEFAULT, CONF_DESCRIPTION, CONF_ICON, CONF_MODE, CONF_NAME, CONF_PATH, + CONF_SELECTOR, CONF_SEQUENCE, CONF_VARIABLES, SERVICE_RELOAD, @@ -58,6 +60,7 @@ from homeassistant.helpers.script import ( ScriptRunResult, script_stack_cv, ) +from homeassistant.helpers.selector import selector from homeassistant.helpers.service import async_set_service_schema from homeassistant.helpers.trace import trace_get, trace_path from homeassistant.helpers.typing import ConfigType @@ -71,6 +74,7 @@ from .const import ( ATTR_LAST_TRIGGERED, ATTR_VARIABLES, CONF_FIELDS, + CONF_REQUIRED, CONF_TRACE, DOMAIN, ENTITY_ID_FORMAT, @@ -730,11 +734,40 @@ class ScriptEntity(BaseScriptEntity, RestoreEntity): unique_id = self.unique_id hass = self.hass + + service_schema = {} + for field_name, field_info in self.fields.items(): + key_cls = vol.Required if field_info[CONF_REQUIRED] else vol.Optional + key_kwargs = {} + if CONF_DEFAULT in field_info: + key_kwargs["default"] = field_info[CONF_DEFAULT] + + if CONF_SELECTOR in field_info: + validator: Any = selector(field_info[CONF_SELECTOR]) + + # Default values need to match the validator. + # When they don't match, we will not enforce validation + if CONF_DEFAULT in field_info: + try: + validator(field_info[CONF_DEFAULT]) + except vol.Invalid: + logging.getLogger(f"{__name__}.{self._attr_unique_id}").warning( + "Field %s has invalid default value %s", + field_name, + field_info[CONF_DEFAULT], + ) + validator = cv.match_all + + else: + validator = cv.match_all + + service_schema[key_cls(field_name, **key_kwargs)] = validator + hass.services.async_register( DOMAIN, unique_id, self._service_handler, - schema=SCRIPT_SERVICE_SCHEMA, + schema=vol.Schema(service_schema, extra=vol.ALLOW_EXTRA), supports_response=SupportsResponse.OPTIONAL, ) diff --git a/tests/components/script/test_init.py b/tests/components/script/test_init.py index a5eda3757a9..96ac73438ea 100644 --- a/tests/components/script/test_init.py +++ b/tests/components/script/test_init.py @@ -6,6 +6,7 @@ from typing import Any from unittest.mock import ANY, Mock, patch import pytest +import voluptuous as vol from homeassistant.components import script from homeassistant.components.script import DOMAIN, EVENT_SCRIPT_STARTED, ScriptEntity @@ -48,6 +49,7 @@ import homeassistant.util.dt as dt_util from tests.common import ( MockConfigEntry, MockUser, + async_capture_events, async_fire_time_changed, async_mock_service, mock_restore_cache, @@ -557,6 +559,101 @@ async def test_reload_unchanged_script( assert len(calls) == 2 +async def test_service_schema( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test that service schema are defined correctly.""" + events = async_capture_events(hass, "test_event") + + assert await async_setup_component( + hass, + "script", + { + "script": { + "test": { + "fields": { + "param_with_default": { + "default": "default_value", + }, + "required_param": { + "required": True, + }, + "selector_param": { + "selector": { + "select": { + "options": [ + "one", + "two", + ] + } + } + }, + "invalid_default": { + "default": "invalid-value", + "selector": {"number": {"min": 0, "max": 2}}, + }, + }, + "sequence": [ + { + "event": "test_event", + "event_data": { + "param_with_default": "{{ param_with_default }}", + "required_param": "{{ required_param }}", + "selector_param": "{{ selector_param | default('not_set') }}", + "invalid_default": "{{ invalid_default }}", + }, + } + ], + } + } + }, + ) + + assert ( + "Field invalid_default has invalid default value invalid-value" in caplog.text + ) + + await hass.services.async_call( + DOMAIN, + "test", + {"required_param": "required_value"}, + blocking=True, + ) + assert len(events) == 1 + assert events[0].data["param_with_default"] == "default_value" + assert events[0].data["required_param"] == "required_value" + assert events[0].data["selector_param"] == "not_set" + assert events[0].data["invalid_default"] == "invalid-value" + + with pytest.raises(vol.Invalid): + await hass.services.async_call( + DOMAIN, + "test", + { + "required_param": "required_value", + "selector_param": "invalid_value", + }, + blocking=True, + ) + + await hass.services.async_call( + DOMAIN, + "test", + { + "param_with_default": "service_set_value", + "required_param": "required_value", + "selector_param": "one", + "invalid_default": "another-value", + }, + blocking=True, + ) + assert len(events) == 2 + assert events[1].data["param_with_default"] == "service_set_value" + assert events[1].data["required_param"] == "required_value" + assert events[1].data["selector_param"] == "one" + assert events[1].data["invalid_default"] == "another-value" + + async def test_service_descriptions(hass: HomeAssistant) -> None: """Test that service descriptions are loaded and reloaded correctly.""" # Test 1: has "description" but no "fields"