Drop language parameter from async_get_pipeline (#91612)

This commit is contained in:
Erik Montnemery 2023-04-18 18:07:20 +02:00 committed by GitHub
parent 10606c4d1e
commit bdffb1f298
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 34 additions and 58 deletions

View file

@ -45,28 +45,16 @@ async def async_pipeline_from_audio_stream(
event_callback: PipelineEventCallback, event_callback: PipelineEventCallback,
stt_metadata: stt.SpeechMetadata, stt_metadata: stt.SpeechMetadata,
stt_stream: AsyncIterable[bytes], stt_stream: AsyncIterable[bytes],
language: str | None = None,
pipeline_id: str | None = None, pipeline_id: str | None = None,
conversation_id: str | None = None, conversation_id: str | None = None,
context: Context | None = None, context: Context | None = None,
tts_options: dict | None = None, tts_options: dict | None = None,
) -> None: ) -> None:
"""Create an audio pipeline from an audio stream.""" """Create an audio pipeline from an audio stream."""
if language is None and pipeline_id is None:
language = hass.config.language
# Temporary workaround for language codes
if language == "en":
language = "en-US"
if context is None: if context is None:
context = Context() context = Context()
pipeline = await async_get_pipeline( pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
hass,
pipeline_id=pipeline_id,
language=language,
)
if pipeline is None: if pipeline is None:
raise PipelineNotFound( raise PipelineNotFound(
"pipeline_not_found", f"Pipeline {pipeline_id} not found" "pipeline_not_found", f"Pipeline {pipeline_id} not found"

View file

@ -53,7 +53,7 @@ SAVE_DELAY = 10
async def async_get_pipeline( async def async_get_pipeline(
hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None hass: HomeAssistant, pipeline_id: str | None = None
) -> Pipeline | None: ) -> Pipeline | None:
"""Get a pipeline by id or create one for a language.""" """Get a pipeline by id or create one for a language."""
pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_data: PipelineData = hass.data[DOMAIN]
@ -64,12 +64,11 @@ async def async_get_pipeline(
if pipeline_id is None: if pipeline_id is None:
# There's no preferred pipeline, construct a pipeline for the # There's no preferred pipeline, construct a pipeline for the
# required/configured language # configured language
language = language or hass.config.language
return await pipeline_data.pipeline_store.async_create_item( return await pipeline_data.pipeline_store.async_create_item(
{ {
"name": language, "name": hass.config.language,
"language": language, "language": hass.config.language,
"stt_engine": None, # first engine "stt_engine": None, # first engine
"conversation_engine": None, # first agent "conversation_engine": None, # first agent
"tts_engine": None, # first engine "tts_engine": None, # first engine

View file

@ -46,7 +46,6 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
# pylint: disable-next=unnecessary-lambda # pylint: disable-next=unnecessary-lambda
vol.Required("end_stage"): lambda val: PipelineStage(val), vol.Required("end_stage"): lambda val: PipelineStage(val),
vol.Optional("input"): dict, vol.Optional("input"): dict,
vol.Optional("language"): str,
vol.Optional("pipeline"): str, vol.Optional("pipeline"): str,
vol.Optional("conversation_id"): vol.Any(str, None), vol.Optional("conversation_id"): vol.Any(str, None),
vol.Optional("timeout"): vol.Any(float, int), vol.Optional("timeout"): vol.Any(float, int),
@ -82,23 +81,13 @@ async def websocket_run(
msg: dict[str, Any], msg: dict[str, Any],
) -> None: ) -> None:
"""Run a pipeline.""" """Run a pipeline."""
language = msg.get("language", hass.config.language)
# Temporary workaround for language codes
if language == "en":
language = "en-US"
pipeline_id = msg.get("pipeline") pipeline_id = msg.get("pipeline")
pipeline = await async_get_pipeline( pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
hass,
pipeline_id=pipeline_id,
language=language,
)
if pipeline is None: if pipeline is None:
connection.send_error( connection.send_error(
msg["id"], msg["id"],
"pipeline-not-found", "pipeline-not-found",
f"Pipeline not found: id={pipeline_id}, language={language}", f"Pipeline not found: id={pipeline_id}",
) )
return return
@ -147,7 +136,7 @@ async def websocket_run(
# Audio input must be raw PCM at 16Khz with 16-bit mono samples # Audio input must be raw PCM at 16Khz with 16-bit mono samples
input_args["stt_metadata"] = stt.SpeechMetadata( input_args["stt_metadata"] = stt.SpeechMetadata(
language=language, language=pipeline.language,
format=stt.AudioFormats.WAV, format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM, codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16, bit_rate=stt.AudioBitRates.BITRATE_16,

View file

@ -3,8 +3,8 @@
list([ list([
dict({ dict({
'data': dict({ 'data': dict({
'language': 'en-US', 'language': 'en',
'pipeline': 'en-US', 'pipeline': 'en',
}), }),
'type': <PipelineEventType.RUN_START: 'run-start'>, 'type': <PipelineEventType.RUN_START: 'run-start'>,
}), }),
@ -47,7 +47,7 @@
'data': dict({ 'data': dict({
'code': 'no_intent_match', 'code': 'no_intent_match',
}), }),
'language': 'en-US', 'language': 'en',
'response_type': 'error', 'response_type': 'error',
'speech': dict({ 'speech': dict({
'plain': dict({ 'plain': dict({
@ -70,7 +70,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'tts_output': dict({ 'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US", 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'mime_type': 'audio/mpeg', 'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}), }),

View file

@ -1,8 +1,8 @@
# serializer version: 1 # serializer version: 1
# name: test_audio_pipeline # name: test_audio_pipeline
dict({ dict({
'language': 'en-US', 'language': 'en',
'pipeline': 'en-US', 'pipeline': 'en',
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': 1, 'stt_binary_handler_id': 1,
'timeout': 30, 'timeout': 30,
@ -45,7 +45,7 @@
'data': dict({ 'data': dict({
'code': 'no_intent_match', 'code': 'no_intent_match',
}), }),
'language': 'en-US', 'language': 'en',
'response_type': 'error', 'response_type': 'error',
'speech': dict({ 'speech': dict({
'plain': dict({ 'plain': dict({
@ -66,7 +66,7 @@
# name: test_audio_pipeline.6 # name: test_audio_pipeline.6
dict({ dict({
'tts_output': dict({ 'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US", 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'mime_type': 'audio/mpeg', 'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}), }),
@ -74,8 +74,8 @@
# --- # ---
# name: test_audio_pipeline_debug # name: test_audio_pipeline_debug
dict({ dict({
'language': 'en-US', 'language': 'en',
'pipeline': 'en-US', 'pipeline': 'en',
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': 1, 'stt_binary_handler_id': 1,
'timeout': 30, 'timeout': 30,
@ -118,7 +118,7 @@
'data': dict({ 'data': dict({
'code': 'no_intent_match', 'code': 'no_intent_match',
}), }),
'language': 'en-US', 'language': 'en',
'response_type': 'error', 'response_type': 'error',
'speech': dict({ 'speech': dict({
'plain': dict({ 'plain': dict({
@ -139,7 +139,7 @@
# name: test_audio_pipeline_debug.6 # name: test_audio_pipeline_debug.6
dict({ dict({
'tts_output': dict({ 'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US", 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'mime_type': 'audio/mpeg', 'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}), }),
@ -147,8 +147,8 @@
# --- # ---
# name: test_intent_failed # name: test_intent_failed
dict({ dict({
'language': 'en-US', 'language': 'en',
'pipeline': 'en-US', 'pipeline': 'en',
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': None, 'stt_binary_handler_id': None,
'timeout': 30, 'timeout': 30,
@ -163,8 +163,8 @@
# --- # ---
# name: test_intent_timeout # name: test_intent_timeout
dict({ dict({
'language': 'en-US', 'language': 'en',
'pipeline': 'en-US', 'pipeline': 'en',
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': None, 'stt_binary_handler_id': None,
'timeout': 0.1, 'timeout': 0.1,
@ -185,8 +185,8 @@
# --- # ---
# name: test_stt_provider_missing # name: test_stt_provider_missing
dict({ dict({
'language': 'en-US', 'language': 'en',
'pipeline': 'en-US', 'pipeline': 'en',
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': 1, 'stt_binary_handler_id': 1,
'timeout': 30, 'timeout': 30,
@ -201,15 +201,15 @@
'channel': 1, 'channel': 1,
'codec': 'pcm', 'codec': 'pcm',
'format': 'wav', 'format': 'wav',
'language': 'en-US', 'language': 'en',
'sample_rate': 16000, 'sample_rate': 16000,
}), }),
}) })
# --- # ---
# name: test_stt_stream_failed # name: test_stt_stream_failed
dict({ dict({
'language': 'en-US', 'language': 'en',
'pipeline': 'en-US', 'pipeline': 'en',
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': 1, 'stt_binary_handler_id': 1,
'timeout': 30, 'timeout': 30,
@ -231,8 +231,8 @@
# --- # ---
# name: test_text_only_pipeline # name: test_text_only_pipeline
dict({ dict({
'language': 'en-US', 'language': 'en',
'pipeline': 'en-US', 'pipeline': 'en',
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': None, 'stt_binary_handler_id': None,
'timeout': 30, 'timeout': 30,
@ -255,7 +255,7 @@
'data': dict({ 'data': dict({
'code': 'no_intent_match', 'code': 'no_intent_match',
}), }),
'language': 'en-US', 'language': 'en',
'response_type': 'error', 'response_type': 'error',
'speech': dict({ 'speech': dict({
'plain': dict({ 'plain': dict({
@ -275,8 +275,8 @@
# --- # ---
# name: test_tts_failed # name: test_tts_failed
dict({ dict({
'language': 'en-US', 'language': 'en',
'pipeline': 'en-US', 'pipeline': 'en',
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': None, 'stt_binary_handler_id': None,
'timeout': 30, 'timeout': 30,