Allow complex schemas for validating WS commands (#91655)

This commit is contained in:
Erik Montnemery 2023-04-19 17:37:09 +02:00 committed by GitHub
parent 90e92aa9d8
commit 4e0b8a7363
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 41 deletions

View file

@ -33,47 +33,45 @@ _LOGGER = logging.getLogger(__name__)
@callback @callback
def async_register_websocket_api(hass: HomeAssistant) -> None: def async_register_websocket_api(hass: HomeAssistant) -> None:
"""Register the websocket API.""" """Register the websocket API."""
websocket_api.async_register_command( websocket_api.async_register_command(hass, websocket_run)
hass,
"assist_pipeline/run",
websocket_run,
vol.All(
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
vol.Required("type"): "assist_pipeline/run",
# pylint: disable-next=unnecessary-lambda
vol.Required("start_stage"): lambda val: PipelineStage(val),
# pylint: disable-next=unnecessary-lambda
vol.Required("end_stage"): lambda val: PipelineStage(val),
vol.Optional("input"): dict,
vol.Optional("pipeline"): str,
vol.Optional("conversation_id"): vol.Any(str, None),
vol.Optional("timeout"): vol.Any(float, int),
},
),
cv.key_value_schemas(
"start_stage",
{
PipelineStage.STT: vol.Schema(
{vol.Required("input"): {vol.Required("sample_rate"): int}},
extra=vol.ALLOW_EXTRA,
),
PipelineStage.INTENT: vol.Schema(
{vol.Required("input"): {"text": str}},
extra=vol.ALLOW_EXTRA,
),
PipelineStage.TTS: vol.Schema(
{vol.Required("input"): {"text": str}},
extra=vol.ALLOW_EXTRA,
),
},
),
),
)
websocket_api.async_register_command(hass, websocket_list_runs) websocket_api.async_register_command(hass, websocket_list_runs)
websocket_api.async_register_command(hass, websocket_get_run) websocket_api.async_register_command(hass, websocket_get_run)
@websocket_api.websocket_command(
vol.All(
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
vol.Required("type"): "assist_pipeline/run",
# pylint: disable-next=unnecessary-lambda
vol.Required("start_stage"): lambda val: PipelineStage(val),
# pylint: disable-next=unnecessary-lambda
vol.Required("end_stage"): lambda val: PipelineStage(val),
vol.Optional("input"): dict,
vol.Optional("pipeline"): str,
vol.Optional("conversation_id"): vol.Any(str, None),
vol.Optional("timeout"): vol.Any(float, int),
},
),
cv.key_value_schemas(
"start_stage",
{
PipelineStage.STT: vol.Schema(
{vol.Required("input"): {vol.Required("sample_rate"): int}},
extra=vol.ALLOW_EXTRA,
),
PipelineStage.INTENT: vol.Schema(
{vol.Required("input"): {"text": str}},
extra=vol.ALLOW_EXTRA,
),
PipelineStage.TTS: vol.Schema(
{vol.Required("input"): {"text": str}},
extra=vol.ALLOW_EXTRA,
),
},
),
),
)
@websocket_api.async_response @websocket_api.async_response
async def websocket_run( async def websocket_run(
hass: HomeAssistant, hass: HomeAssistant,

View file

@ -128,15 +128,31 @@ def ws_require_user(
def websocket_command( def websocket_command(
schema: dict[vol.Marker, Any], schema: dict[vol.Marker, Any] | vol.All,
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]: ) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
"""Tag a function as a websocket command.""" """Tag a function as a websocket command.
command = schema["type"]
The schema must be either a dictionary where the keys are voluptuous markers, or
a voluptuous.All schema where the first item is a voluptuous Mapping schema.
"""
if isinstance(schema, dict):
command = schema["type"]
else:
command = schema.validators[0].schema["type"]
def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler: def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
"""Decorate ws command function.""" """Decorate ws command function."""
# pylint: disable=protected-access # pylint: disable=protected-access
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined] if isinstance(schema, dict):
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined]
else:
extended_schema = vol.All(
schema.validators[0].extend(
messages.BASE_COMMAND_MESSAGE_SCHEMA.schema
),
*schema.validators[1:],
)
func._ws_schema = extended_schema # type: ignore[attr-defined]
func._ws_command = command # type: ignore[attr-defined] func._ws_command = command # type: ignore[attr-defined]
return func return func