Allow complex schemas for validating WS commands (#91655)
This commit is contained in:
parent
90e92aa9d8
commit
4e0b8a7363
2 changed files with 55 additions and 41 deletions
|
@ -33,47 +33,45 @@ _LOGGER = logging.getLogger(__name__)
|
|||
@callback
|
||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||
"""Register the websocket API."""
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
"assist_pipeline/run",
|
||||
websocket_run,
|
||||
vol.All(
|
||||
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||
{
|
||||
vol.Required("type"): "assist_pipeline/run",
|
||||
# pylint: disable-next=unnecessary-lambda
|
||||
vol.Required("start_stage"): lambda val: PipelineStage(val),
|
||||
# pylint: disable-next=unnecessary-lambda
|
||||
vol.Required("end_stage"): lambda val: PipelineStage(val),
|
||||
vol.Optional("input"): dict,
|
||||
vol.Optional("pipeline"): str,
|
||||
vol.Optional("conversation_id"): vol.Any(str, None),
|
||||
vol.Optional("timeout"): vol.Any(float, int),
|
||||
},
|
||||
),
|
||||
cv.key_value_schemas(
|
||||
"start_stage",
|
||||
{
|
||||
PipelineStage.STT: vol.Schema(
|
||||
{vol.Required("input"): {vol.Required("sample_rate"): int}},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
),
|
||||
PipelineStage.INTENT: vol.Schema(
|
||||
{vol.Required("input"): {"text": str}},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
),
|
||||
PipelineStage.TTS: vol.Schema(
|
||||
{vol.Required("input"): {"text": str}},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
)
|
||||
websocket_api.async_register_command(hass, websocket_run)
|
||||
websocket_api.async_register_command(hass, websocket_list_runs)
|
||||
websocket_api.async_register_command(hass, websocket_get_run)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
vol.All(
|
||||
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||
{
|
||||
vol.Required("type"): "assist_pipeline/run",
|
||||
# pylint: disable-next=unnecessary-lambda
|
||||
vol.Required("start_stage"): lambda val: PipelineStage(val),
|
||||
# pylint: disable-next=unnecessary-lambda
|
||||
vol.Required("end_stage"): lambda val: PipelineStage(val),
|
||||
vol.Optional("input"): dict,
|
||||
vol.Optional("pipeline"): str,
|
||||
vol.Optional("conversation_id"): vol.Any(str, None),
|
||||
vol.Optional("timeout"): vol.Any(float, int),
|
||||
},
|
||||
),
|
||||
cv.key_value_schemas(
|
||||
"start_stage",
|
||||
{
|
||||
PipelineStage.STT: vol.Schema(
|
||||
{vol.Required("input"): {vol.Required("sample_rate"): int}},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
),
|
||||
PipelineStage.INTENT: vol.Schema(
|
||||
{vol.Required("input"): {"text": str}},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
),
|
||||
PipelineStage.TTS: vol.Schema(
|
||||
{vol.Required("input"): {"text": str}},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_run(
|
||||
hass: HomeAssistant,
|
||||
|
|
|
@ -128,15 +128,31 @@ def ws_require_user(
|
|||
|
||||
|
||||
def websocket_command(
|
||||
schema: dict[vol.Marker, Any],
|
||||
schema: dict[vol.Marker, Any] | vol.All,
|
||||
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
|
||||
"""Tag a function as a websocket command."""
|
||||
command = schema["type"]
|
||||
"""Tag a function as a websocket command.
|
||||
|
||||
The schema must be either a dictionary where the keys are voluptuous markers, or
|
||||
a voluptuous.All schema where the first item is a voluptuous Mapping schema.
|
||||
"""
|
||||
if isinstance(schema, dict):
|
||||
command = schema["type"]
|
||||
else:
|
||||
command = schema.validators[0].schema["type"]
|
||||
|
||||
def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
|
||||
"""Decorate ws command function."""
|
||||
# pylint: disable=protected-access
|
||||
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined]
|
||||
if isinstance(schema, dict):
|
||||
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined]
|
||||
else:
|
||||
extended_schema = vol.All(
|
||||
schema.validators[0].extend(
|
||||
messages.BASE_COMMAND_MESSAGE_SCHEMA.schema
|
||||
),
|
||||
*schema.validators[1:],
|
||||
)
|
||||
func._ws_schema = extended_schema # type: ignore[attr-defined]
|
||||
func._ws_command = command # type: ignore[attr-defined]
|
||||
return func
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue