"""Assist pipeline Websocket API.""" import asyncio # Suppressing disable=deprecated-module is needed for Python 3.11 import audioop # pylint: disable=deprecated-module from collections.abc import AsyncGenerator, Callable import logging from typing import Any import async_timeout import voluptuous as vol from homeassistant.components import conversation, stt, tts, websocket_api from homeassistant.const import MATCH_ALL from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import config_validation as cv from homeassistant.util import language as language_util from .const import DOMAIN from .pipeline import ( PipelineData, PipelineError, PipelineEvent, PipelineEventType, PipelineInput, PipelineRun, PipelineStage, async_get_pipeline, ) from .vad import VoiceCommandSegmenter DEFAULT_TIMEOUT = 30 _LOGGER = logging.getLogger(__name__) @callback def async_register_websocket_api(hass: HomeAssistant) -> None: """Register the websocket API.""" websocket_api.async_register_command(hass, websocket_run) websocket_api.async_register_command(hass, websocket_list_languages) websocket_api.async_register_command(hass, websocket_list_runs) websocket_api.async_register_command(hass, websocket_get_run) @websocket_api.websocket_command( vol.All( websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( { vol.Required("type"): "assist_pipeline/run", # pylint: disable-next=unnecessary-lambda vol.Required("start_stage"): lambda val: PipelineStage(val), # pylint: disable-next=unnecessary-lambda vol.Required("end_stage"): lambda val: PipelineStage(val), vol.Optional("input"): dict, vol.Optional("pipeline"): str, vol.Optional("conversation_id"): vol.Any(str, None), vol.Optional("timeout"): vol.Any(float, int), }, ), cv.key_value_schemas( "start_stage", { PipelineStage.STT: vol.Schema( {vol.Required("input"): {vol.Required("sample_rate"): int}}, extra=vol.ALLOW_EXTRA, ), PipelineStage.INTENT: vol.Schema( {vol.Required("input"): {"text": str}}, extra=vol.ALLOW_EXTRA, ), PipelineStage.TTS: vol.Schema( {vol.Required("input"): {"text": str}}, extra=vol.ALLOW_EXTRA, ), }, ), ), ) @websocket_api.async_response async def websocket_run( hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any], ) -> None: """Run a pipeline.""" pipeline_id = msg.get("pipeline") pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id) if pipeline is None: connection.send_error( msg["id"], "pipeline-not-found", f"Pipeline not found: id={pipeline_id}", ) return timeout = msg.get("timeout", DEFAULT_TIMEOUT) start_stage = PipelineStage(msg["start_stage"]) end_stage = PipelineStage(msg["end_stage"]) handler_id: int | None = None unregister_handler: Callable[[], None] | None = None # Arguments to PipelineInput input_args: dict[str, Any] = { "conversation_id": msg.get("conversation_id"), } if start_stage == PipelineStage.STT: # Audio pipeline that will receive audio as binary websocket messages audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue() incoming_sample_rate = msg["input"]["sample_rate"] async def stt_stream() -> AsyncGenerator[bytes, None]: state = None segmenter = VoiceCommandSegmenter() # Yield until we receive an empty chunk while chunk := await audio_queue.get(): chunk, state = audioop.ratecv( chunk, 2, 1, incoming_sample_rate, 16000, state ) if not segmenter.process(chunk): # Voice command is finished break yield chunk def handle_binary( _hass: HomeAssistant, _connection: websocket_api.ActiveConnection, data: bytes, ) -> None: # Forward to STT audio stream audio_queue.put_nowait(data) handler_id, unregister_handler = connection.async_register_binary_handler( handle_binary ) # Audio input must be raw PCM at 16Khz with 16-bit mono samples input_args["stt_metadata"] = stt.SpeechMetadata( language=pipeline.stt_language or pipeline.language, format=stt.AudioFormats.WAV, codec=stt.AudioCodecs.PCM, bit_rate=stt.AudioBitRates.BITRATE_16, sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, channel=stt.AudioChannels.CHANNEL_MONO, ) input_args["stt_stream"] = stt_stream() elif start_stage == PipelineStage.INTENT: # Input to conversation agent input_args["intent_input"] = msg["input"]["text"] elif start_stage == PipelineStage.TTS: # Input to text to speech system input_args["tts_input"] = msg["input"]["text"] input_args["run"] = PipelineRun( hass, context=connection.context(msg), pipeline=pipeline, start_stage=start_stage, end_stage=end_stage, event_callback=lambda event: connection.send_event(msg["id"], event), runner_data={ "stt_binary_handler_id": handler_id, "timeout": timeout, }, ) pipeline_input = PipelineInput(**input_args) try: await pipeline_input.validate() except PipelineError as error: # Report more specific error when possible connection.send_error(msg["id"], error.code, error.message) return # Confirm subscription connection.send_result(msg["id"]) run_task = hass.async_create_task(pipeline_input.execute()) # Cancel pipeline if user unsubscribes connection.subscriptions[msg["id"]] = run_task.cancel try: # Task contains a timeout async with async_timeout.timeout(timeout): await run_task except asyncio.TimeoutError: pipeline_input.run.process_event( PipelineEvent( PipelineEventType.ERROR, {"code": "timeout", "message": "Timeout running pipeline"}, ) ) finally: if unregister_handler is not None: # Unregister binary handler unregister_handler() @callback @websocket_api.require_admin @websocket_api.websocket_command( { vol.Required("type"): "assist_pipeline/pipeline_debug/list", vol.Required("pipeline_id"): str, } ) def websocket_list_runs( hass: HomeAssistant, connection: websocket_api.connection.ActiveConnection, msg: dict[str, Any], ) -> None: """List pipeline runs for which debug data is available.""" pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_id = msg["pipeline_id"] if pipeline_id not in pipeline_data.pipeline_runs: connection.send_result(msg["id"], {"pipeline_runs": []}) return pipeline_runs = pipeline_data.pipeline_runs[pipeline_id] connection.send_result( msg["id"], { "pipeline_runs": [ {"pipeline_run_id": id, "timestamp": pipeline_run.timestamp} for id, pipeline_run in pipeline_runs.items() ] }, ) @callback @websocket_api.require_admin @websocket_api.websocket_command( { vol.Required("type"): "assist_pipeline/pipeline_debug/get", vol.Required("pipeline_id"): str, vol.Required("pipeline_run_id"): str, } ) def websocket_get_run( hass: HomeAssistant, connection: websocket_api.connection.ActiveConnection, msg: dict[str, Any], ) -> None: """Get debug data for a pipeline run.""" pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_id = msg["pipeline_id"] pipeline_run_id = msg["pipeline_run_id"] if pipeline_id not in pipeline_data.pipeline_runs: connection.send_error( msg["id"], websocket_api.const.ERR_NOT_FOUND, f"pipeline_id {pipeline_id} not found", ) return pipeline_runs = pipeline_data.pipeline_runs[pipeline_id] if pipeline_run_id not in pipeline_runs: connection.send_error( msg["id"], websocket_api.const.ERR_NOT_FOUND, f"pipeline_run_id {pipeline_run_id} not found", ) return connection.send_result( msg["id"], {"events": pipeline_runs[pipeline_run_id].events}, ) @callback @websocket_api.websocket_command( { vol.Required("type"): "assist_pipeline/language/list", } ) @websocket_api.async_response async def websocket_list_languages( hass: HomeAssistant, connection: websocket_api.connection.ActiveConnection, msg: dict[str, Any], ) -> None: """List languages which are supported by a complete pipeline. This will return a list of languages which are supported by at least one stt, tts and conversation engine respectively. """ conv_language_tags = await conversation.async_get_conversation_languages(hass) stt_language_tags = stt.async_get_speech_to_text_languages(hass) tts_language_tags = tts.async_get_text_to_speech_languages(hass) pipeline_languages: set[str] | None = None if conv_language_tags and conv_language_tags != MATCH_ALL: languages = set() for language_tag in conv_language_tags: dialect = language_util.Dialect.parse(language_tag) languages.add(dialect.language) pipeline_languages = languages if stt_language_tags: languages = set() for language_tag in stt_language_tags: dialect = language_util.Dialect.parse(language_tag) languages.add(dialect.language) if pipeline_languages is not None: pipeline_languages &= languages else: pipeline_languages = languages if tts_language_tags: languages = set() for language_tag in tts_language_tags: dialect = language_util.Dialect.parse(language_tag) languages.add(dialect.language) if pipeline_languages is not None: pipeline_languages &= languages else: pipeline_languages = languages connection.send_result( msg["id"], {"languages": pipeline_languages}, )