Create a script service schema based on fields (#128622)

This commit is contained in:
Paulus Schoutsen 2024-10-25 16:05:00 -07:00 committed by GitHub
parent ababa639b3
commit 10300cc478
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 131 additions and 1 deletions

View file

@ -18,11 +18,13 @@ from homeassistant.const import (
ATTR_MODE, ATTR_MODE,
ATTR_NAME, ATTR_NAME,
CONF_ALIAS, CONF_ALIAS,
CONF_DEFAULT,
CONF_DESCRIPTION, CONF_DESCRIPTION,
CONF_ICON, CONF_ICON,
CONF_MODE, CONF_MODE,
CONF_NAME, CONF_NAME,
CONF_PATH, CONF_PATH,
CONF_SELECTOR,
CONF_SEQUENCE, CONF_SEQUENCE,
CONF_VARIABLES, CONF_VARIABLES,
SERVICE_RELOAD, SERVICE_RELOAD,
@ -58,6 +60,7 @@ from homeassistant.helpers.script import (
ScriptRunResult, ScriptRunResult,
script_stack_cv, script_stack_cv,
) )
from homeassistant.helpers.selector import selector
from homeassistant.helpers.service import async_set_service_schema from homeassistant.helpers.service import async_set_service_schema
from homeassistant.helpers.trace import trace_get, trace_path from homeassistant.helpers.trace import trace_get, trace_path
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@ -71,6 +74,7 @@ from .const import (
ATTR_LAST_TRIGGERED, ATTR_LAST_TRIGGERED,
ATTR_VARIABLES, ATTR_VARIABLES,
CONF_FIELDS, CONF_FIELDS,
CONF_REQUIRED,
CONF_TRACE, CONF_TRACE,
DOMAIN, DOMAIN,
ENTITY_ID_FORMAT, ENTITY_ID_FORMAT,
@ -730,11 +734,40 @@ class ScriptEntity(BaseScriptEntity, RestoreEntity):
unique_id = self.unique_id unique_id = self.unique_id
hass = self.hass hass = self.hass
service_schema = {}
for field_name, field_info in self.fields.items():
key_cls = vol.Required if field_info[CONF_REQUIRED] else vol.Optional
key_kwargs = {}
if CONF_DEFAULT in field_info:
key_kwargs["default"] = field_info[CONF_DEFAULT]
if CONF_SELECTOR in field_info:
validator: Any = selector(field_info[CONF_SELECTOR])
# Default values need to match the validator.
# When they don't match, we will not enforce validation
if CONF_DEFAULT in field_info:
try:
validator(field_info[CONF_DEFAULT])
except vol.Invalid:
logging.getLogger(f"{__name__}.{self._attr_unique_id}").warning(
"Field %s has invalid default value %s",
field_name,
field_info[CONF_DEFAULT],
)
validator = cv.match_all
else:
validator = cv.match_all
service_schema[key_cls(field_name, **key_kwargs)] = validator
hass.services.async_register( hass.services.async_register(
DOMAIN, DOMAIN,
unique_id, unique_id,
self._service_handler, self._service_handler,
schema=SCRIPT_SERVICE_SCHEMA, schema=vol.Schema(service_schema, extra=vol.ALLOW_EXTRA),
supports_response=SupportsResponse.OPTIONAL, supports_response=SupportsResponse.OPTIONAL,
) )

View file

@ -6,6 +6,7 @@ from typing import Any
from unittest.mock import ANY, Mock, patch from unittest.mock import ANY, Mock, patch
import pytest import pytest
import voluptuous as vol
from homeassistant.components import script from homeassistant.components import script
from homeassistant.components.script import DOMAIN, EVENT_SCRIPT_STARTED, ScriptEntity from homeassistant.components.script import DOMAIN, EVENT_SCRIPT_STARTED, ScriptEntity
@ -48,6 +49,7 @@ import homeassistant.util.dt as dt_util
from tests.common import ( from tests.common import (
MockConfigEntry, MockConfigEntry,
MockUser, MockUser,
async_capture_events,
async_fire_time_changed, async_fire_time_changed,
async_mock_service, async_mock_service,
mock_restore_cache, mock_restore_cache,
@ -557,6 +559,101 @@ async def test_reload_unchanged_script(
assert len(calls) == 2 assert len(calls) == 2
async def test_service_schema(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that service schema are defined correctly."""
events = async_capture_events(hass, "test_event")
assert await async_setup_component(
hass,
"script",
{
"script": {
"test": {
"fields": {
"param_with_default": {
"default": "default_value",
},
"required_param": {
"required": True,
},
"selector_param": {
"selector": {
"select": {
"options": [
"one",
"two",
]
}
}
},
"invalid_default": {
"default": "invalid-value",
"selector": {"number": {"min": 0, "max": 2}},
},
},
"sequence": [
{
"event": "test_event",
"event_data": {
"param_with_default": "{{ param_with_default }}",
"required_param": "{{ required_param }}",
"selector_param": "{{ selector_param | default('not_set') }}",
"invalid_default": "{{ invalid_default }}",
},
}
],
}
}
},
)
assert (
"Field invalid_default has invalid default value invalid-value" in caplog.text
)
await hass.services.async_call(
DOMAIN,
"test",
{"required_param": "required_value"},
blocking=True,
)
assert len(events) == 1
assert events[0].data["param_with_default"] == "default_value"
assert events[0].data["required_param"] == "required_value"
assert events[0].data["selector_param"] == "not_set"
assert events[0].data["invalid_default"] == "invalid-value"
with pytest.raises(vol.Invalid):
await hass.services.async_call(
DOMAIN,
"test",
{
"required_param": "required_value",
"selector_param": "invalid_value",
},
blocking=True,
)
await hass.services.async_call(
DOMAIN,
"test",
{
"param_with_default": "service_set_value",
"required_param": "required_value",
"selector_param": "one",
"invalid_default": "another-value",
},
blocking=True,
)
assert len(events) == 2
assert events[1].data["param_with_default"] == "service_set_value"
assert events[1].data["required_param"] == "required_value"
assert events[1].data["selector_param"] == "one"
assert events[1].data["invalid_default"] == "another-value"
async def test_service_descriptions(hass: HomeAssistant) -> None: async def test_service_descriptions(hass: HomeAssistant) -> None:
"""Test that service descriptions are loaded and reloaded correctly.""" """Test that service descriptions are loaded and reloaded correctly."""
# Test 1: has "description" but no "fields" # Test 1: has "description" but no "fields"