Allow complex schemas for validating WS commands (#91655)
This commit is contained in:
parent
90e92aa9d8
commit
4e0b8a7363
2 changed files with 55 additions and 41 deletions
|
@ -33,47 +33,45 @@ _LOGGER = logging.getLogger(__name__)
|
||||||
@callback
|
@callback
|
||||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
"""Register the websocket API."""
|
"""Register the websocket API."""
|
||||||
websocket_api.async_register_command(
|
websocket_api.async_register_command(hass, websocket_run)
|
||||||
hass,
|
|
||||||
"assist_pipeline/run",
|
|
||||||
websocket_run,
|
|
||||||
vol.All(
|
|
||||||
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
|
||||||
{
|
|
||||||
vol.Required("type"): "assist_pipeline/run",
|
|
||||||
# pylint: disable-next=unnecessary-lambda
|
|
||||||
vol.Required("start_stage"): lambda val: PipelineStage(val),
|
|
||||||
# pylint: disable-next=unnecessary-lambda
|
|
||||||
vol.Required("end_stage"): lambda val: PipelineStage(val),
|
|
||||||
vol.Optional("input"): dict,
|
|
||||||
vol.Optional("pipeline"): str,
|
|
||||||
vol.Optional("conversation_id"): vol.Any(str, None),
|
|
||||||
vol.Optional("timeout"): vol.Any(float, int),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
cv.key_value_schemas(
|
|
||||||
"start_stage",
|
|
||||||
{
|
|
||||||
PipelineStage.STT: vol.Schema(
|
|
||||||
{vol.Required("input"): {vol.Required("sample_rate"): int}},
|
|
||||||
extra=vol.ALLOW_EXTRA,
|
|
||||||
),
|
|
||||||
PipelineStage.INTENT: vol.Schema(
|
|
||||||
{vol.Required("input"): {"text": str}},
|
|
||||||
extra=vol.ALLOW_EXTRA,
|
|
||||||
),
|
|
||||||
PipelineStage.TTS: vol.Schema(
|
|
||||||
{vol.Required("input"): {"text": str}},
|
|
||||||
extra=vol.ALLOW_EXTRA,
|
|
||||||
),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
websocket_api.async_register_command(hass, websocket_list_runs)
|
websocket_api.async_register_command(hass, websocket_list_runs)
|
||||||
websocket_api.async_register_command(hass, websocket_get_run)
|
websocket_api.async_register_command(hass, websocket_get_run)
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_api.websocket_command(
|
||||||
|
vol.All(
|
||||||
|
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||||
|
{
|
||||||
|
vol.Required("type"): "assist_pipeline/run",
|
||||||
|
# pylint: disable-next=unnecessary-lambda
|
||||||
|
vol.Required("start_stage"): lambda val: PipelineStage(val),
|
||||||
|
# pylint: disable-next=unnecessary-lambda
|
||||||
|
vol.Required("end_stage"): lambda val: PipelineStage(val),
|
||||||
|
vol.Optional("input"): dict,
|
||||||
|
vol.Optional("pipeline"): str,
|
||||||
|
vol.Optional("conversation_id"): vol.Any(str, None),
|
||||||
|
vol.Optional("timeout"): vol.Any(float, int),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
cv.key_value_schemas(
|
||||||
|
"start_stage",
|
||||||
|
{
|
||||||
|
PipelineStage.STT: vol.Schema(
|
||||||
|
{vol.Required("input"): {vol.Required("sample_rate"): int}},
|
||||||
|
extra=vol.ALLOW_EXTRA,
|
||||||
|
),
|
||||||
|
PipelineStage.INTENT: vol.Schema(
|
||||||
|
{vol.Required("input"): {"text": str}},
|
||||||
|
extra=vol.ALLOW_EXTRA,
|
||||||
|
),
|
||||||
|
PipelineStage.TTS: vol.Schema(
|
||||||
|
{vol.Required("input"): {"text": str}},
|
||||||
|
extra=vol.ALLOW_EXTRA,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
@websocket_api.async_response
|
@websocket_api.async_response
|
||||||
async def websocket_run(
|
async def websocket_run(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
|
|
@ -128,15 +128,31 @@ def ws_require_user(
|
||||||
|
|
||||||
|
|
||||||
def websocket_command(
|
def websocket_command(
|
||||||
schema: dict[vol.Marker, Any],
|
schema: dict[vol.Marker, Any] | vol.All,
|
||||||
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
|
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
|
||||||
"""Tag a function as a websocket command."""
|
"""Tag a function as a websocket command.
|
||||||
command = schema["type"]
|
|
||||||
|
The schema must be either a dictionary where the keys are voluptuous markers, or
|
||||||
|
a voluptuous.All schema where the first item is a voluptuous Mapping schema.
|
||||||
|
"""
|
||||||
|
if isinstance(schema, dict):
|
||||||
|
command = schema["type"]
|
||||||
|
else:
|
||||||
|
command = schema.validators[0].schema["type"]
|
||||||
|
|
||||||
def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
|
def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
|
||||||
"""Decorate ws command function."""
|
"""Decorate ws command function."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined]
|
if isinstance(schema, dict):
|
||||||
|
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined]
|
||||||
|
else:
|
||||||
|
extended_schema = vol.All(
|
||||||
|
schema.validators[0].extend(
|
||||||
|
messages.BASE_COMMAND_MESSAGE_SCHEMA.schema
|
||||||
|
),
|
||||||
|
*schema.validators[1:],
|
||||||
|
)
|
||||||
|
func._ws_schema = extended_schema # type: ignore[attr-defined]
|
||||||
func._ws_command = command # type: ignore[attr-defined]
|
func._ws_command = command # type: ignore[attr-defined]
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue