Drop language parameter from async_get_pipeline (#91612)
This commit is contained in:
parent
10606c4d1e
commit
bdffb1f298
5 changed files with 34 additions and 58 deletions
|
@ -45,28 +45,16 @@ async def async_pipeline_from_audio_stream(
|
||||||
event_callback: PipelineEventCallback,
|
event_callback: PipelineEventCallback,
|
||||||
stt_metadata: stt.SpeechMetadata,
|
stt_metadata: stt.SpeechMetadata,
|
||||||
stt_stream: AsyncIterable[bytes],
|
stt_stream: AsyncIterable[bytes],
|
||||||
language: str | None = None,
|
|
||||||
pipeline_id: str | None = None,
|
pipeline_id: str | None = None,
|
||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
context: Context | None = None,
|
context: Context | None = None,
|
||||||
tts_options: dict | None = None,
|
tts_options: dict | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create an audio pipeline from an audio stream."""
|
"""Create an audio pipeline from an audio stream."""
|
||||||
if language is None and pipeline_id is None:
|
|
||||||
language = hass.config.language
|
|
||||||
|
|
||||||
# Temporary workaround for language codes
|
|
||||||
if language == "en":
|
|
||||||
language = "en-US"
|
|
||||||
|
|
||||||
if context is None:
|
if context is None:
|
||||||
context = Context()
|
context = Context()
|
||||||
|
|
||||||
pipeline = await async_get_pipeline(
|
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
|
||||||
hass,
|
|
||||||
pipeline_id=pipeline_id,
|
|
||||||
language=language,
|
|
||||||
)
|
|
||||||
if pipeline is None:
|
if pipeline is None:
|
||||||
raise PipelineNotFound(
|
raise PipelineNotFound(
|
||||||
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
|
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
|
||||||
|
|
|
@ -53,7 +53,7 @@ SAVE_DELAY = 10
|
||||||
|
|
||||||
|
|
||||||
async def async_get_pipeline(
|
async def async_get_pipeline(
|
||||||
hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None
|
hass: HomeAssistant, pipeline_id: str | None = None
|
||||||
) -> Pipeline | None:
|
) -> Pipeline | None:
|
||||||
"""Get a pipeline by id or create one for a language."""
|
"""Get a pipeline by id or create one for a language."""
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
@ -64,12 +64,11 @@ async def async_get_pipeline(
|
||||||
|
|
||||||
if pipeline_id is None:
|
if pipeline_id is None:
|
||||||
# There's no preferred pipeline, construct a pipeline for the
|
# There's no preferred pipeline, construct a pipeline for the
|
||||||
# required/configured language
|
# configured language
|
||||||
language = language or hass.config.language
|
|
||||||
return await pipeline_data.pipeline_store.async_create_item(
|
return await pipeline_data.pipeline_store.async_create_item(
|
||||||
{
|
{
|
||||||
"name": language,
|
"name": hass.config.language,
|
||||||
"language": language,
|
"language": hass.config.language,
|
||||||
"stt_engine": None, # first engine
|
"stt_engine": None, # first engine
|
||||||
"conversation_engine": None, # first agent
|
"conversation_engine": None, # first agent
|
||||||
"tts_engine": None, # first engine
|
"tts_engine": None, # first engine
|
||||||
|
|
|
@ -46,7 +46,6 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
# pylint: disable-next=unnecessary-lambda
|
# pylint: disable-next=unnecessary-lambda
|
||||||
vol.Required("end_stage"): lambda val: PipelineStage(val),
|
vol.Required("end_stage"): lambda val: PipelineStage(val),
|
||||||
vol.Optional("input"): dict,
|
vol.Optional("input"): dict,
|
||||||
vol.Optional("language"): str,
|
|
||||||
vol.Optional("pipeline"): str,
|
vol.Optional("pipeline"): str,
|
||||||
vol.Optional("conversation_id"): vol.Any(str, None),
|
vol.Optional("conversation_id"): vol.Any(str, None),
|
||||||
vol.Optional("timeout"): vol.Any(float, int),
|
vol.Optional("timeout"): vol.Any(float, int),
|
||||||
|
@ -82,23 +81,13 @@ async def websocket_run(
|
||||||
msg: dict[str, Any],
|
msg: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run a pipeline."""
|
"""Run a pipeline."""
|
||||||
language = msg.get("language", hass.config.language)
|
|
||||||
|
|
||||||
# Temporary workaround for language codes
|
|
||||||
if language == "en":
|
|
||||||
language = "en-US"
|
|
||||||
|
|
||||||
pipeline_id = msg.get("pipeline")
|
pipeline_id = msg.get("pipeline")
|
||||||
pipeline = await async_get_pipeline(
|
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
|
||||||
hass,
|
|
||||||
pipeline_id=pipeline_id,
|
|
||||||
language=language,
|
|
||||||
)
|
|
||||||
if pipeline is None:
|
if pipeline is None:
|
||||||
connection.send_error(
|
connection.send_error(
|
||||||
msg["id"],
|
msg["id"],
|
||||||
"pipeline-not-found",
|
"pipeline-not-found",
|
||||||
f"Pipeline not found: id={pipeline_id}, language={language}",
|
f"Pipeline not found: id={pipeline_id}",
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -147,7 +136,7 @@ async def websocket_run(
|
||||||
|
|
||||||
# Audio input must be raw PCM at 16Khz with 16-bit mono samples
|
# Audio input must be raw PCM at 16Khz with 16-bit mono samples
|
||||||
input_args["stt_metadata"] = stt.SpeechMetadata(
|
input_args["stt_metadata"] = stt.SpeechMetadata(
|
||||||
language=language,
|
language=pipeline.language,
|
||||||
format=stt.AudioFormats.WAV,
|
format=stt.AudioFormats.WAV,
|
||||||
codec=stt.AudioCodecs.PCM,
|
codec=stt.AudioCodecs.PCM,
|
||||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||||
|
|
|
@ -3,8 +3,8 @@
|
||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'language': 'en-US',
|
'language': 'en',
|
||||||
'pipeline': 'en-US',
|
'pipeline': 'en',
|
||||||
}),
|
}),
|
||||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||||
}),
|
}),
|
||||||
|
@ -47,7 +47,7 @@
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'code': 'no_intent_match',
|
'code': 'no_intent_match',
|
||||||
}),
|
}),
|
||||||
'language': 'en-US',
|
'language': 'en',
|
||||||
'response_type': 'error',
|
'response_type': 'error',
|
||||||
'speech': dict({
|
'speech': dict({
|
||||||
'plain': dict({
|
'plain': dict({
|
||||||
|
@ -70,7 +70,7 @@
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
|
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||||
}),
|
}),
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_audio_pipeline
|
# name: test_audio_pipeline
|
||||||
dict({
|
dict({
|
||||||
'language': 'en-US',
|
'language': 'en',
|
||||||
'pipeline': 'en-US',
|
'pipeline': 'en',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': 1,
|
'stt_binary_handler_id': 1,
|
||||||
'timeout': 30,
|
'timeout': 30,
|
||||||
|
@ -45,7 +45,7 @@
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'code': 'no_intent_match',
|
'code': 'no_intent_match',
|
||||||
}),
|
}),
|
||||||
'language': 'en-US',
|
'language': 'en',
|
||||||
'response_type': 'error',
|
'response_type': 'error',
|
||||||
'speech': dict({
|
'speech': dict({
|
||||||
'plain': dict({
|
'plain': dict({
|
||||||
|
@ -66,7 +66,7 @@
|
||||||
# name: test_audio_pipeline.6
|
# name: test_audio_pipeline.6
|
||||||
dict({
|
dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
|
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||||
}),
|
}),
|
||||||
|
@ -74,8 +74,8 @@
|
||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_debug
|
# name: test_audio_pipeline_debug
|
||||||
dict({
|
dict({
|
||||||
'language': 'en-US',
|
'language': 'en',
|
||||||
'pipeline': 'en-US',
|
'pipeline': 'en',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': 1,
|
'stt_binary_handler_id': 1,
|
||||||
'timeout': 30,
|
'timeout': 30,
|
||||||
|
@ -118,7 +118,7 @@
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'code': 'no_intent_match',
|
'code': 'no_intent_match',
|
||||||
}),
|
}),
|
||||||
'language': 'en-US',
|
'language': 'en',
|
||||||
'response_type': 'error',
|
'response_type': 'error',
|
||||||
'speech': dict({
|
'speech': dict({
|
||||||
'plain': dict({
|
'plain': dict({
|
||||||
|
@ -139,7 +139,7 @@
|
||||||
# name: test_audio_pipeline_debug.6
|
# name: test_audio_pipeline_debug.6
|
||||||
dict({
|
dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
|
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||||
}),
|
}),
|
||||||
|
@ -147,8 +147,8 @@
|
||||||
# ---
|
# ---
|
||||||
# name: test_intent_failed
|
# name: test_intent_failed
|
||||||
dict({
|
dict({
|
||||||
'language': 'en-US',
|
'language': 'en',
|
||||||
'pipeline': 'en-US',
|
'pipeline': 'en',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': None,
|
'stt_binary_handler_id': None,
|
||||||
'timeout': 30,
|
'timeout': 30,
|
||||||
|
@ -163,8 +163,8 @@
|
||||||
# ---
|
# ---
|
||||||
# name: test_intent_timeout
|
# name: test_intent_timeout
|
||||||
dict({
|
dict({
|
||||||
'language': 'en-US',
|
'language': 'en',
|
||||||
'pipeline': 'en-US',
|
'pipeline': 'en',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': None,
|
'stt_binary_handler_id': None,
|
||||||
'timeout': 0.1,
|
'timeout': 0.1,
|
||||||
|
@ -185,8 +185,8 @@
|
||||||
# ---
|
# ---
|
||||||
# name: test_stt_provider_missing
|
# name: test_stt_provider_missing
|
||||||
dict({
|
dict({
|
||||||
'language': 'en-US',
|
'language': 'en',
|
||||||
'pipeline': 'en-US',
|
'pipeline': 'en',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': 1,
|
'stt_binary_handler_id': 1,
|
||||||
'timeout': 30,
|
'timeout': 30,
|
||||||
|
@ -201,15 +201,15 @@
|
||||||
'channel': 1,
|
'channel': 1,
|
||||||
'codec': 'pcm',
|
'codec': 'pcm',
|
||||||
'format': 'wav',
|
'format': 'wav',
|
||||||
'language': 'en-US',
|
'language': 'en',
|
||||||
'sample_rate': 16000,
|
'sample_rate': 16000,
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_stt_stream_failed
|
# name: test_stt_stream_failed
|
||||||
dict({
|
dict({
|
||||||
'language': 'en-US',
|
'language': 'en',
|
||||||
'pipeline': 'en-US',
|
'pipeline': 'en',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': 1,
|
'stt_binary_handler_id': 1,
|
||||||
'timeout': 30,
|
'timeout': 30,
|
||||||
|
@ -231,8 +231,8 @@
|
||||||
# ---
|
# ---
|
||||||
# name: test_text_only_pipeline
|
# name: test_text_only_pipeline
|
||||||
dict({
|
dict({
|
||||||
'language': 'en-US',
|
'language': 'en',
|
||||||
'pipeline': 'en-US',
|
'pipeline': 'en',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': None,
|
'stt_binary_handler_id': None,
|
||||||
'timeout': 30,
|
'timeout': 30,
|
||||||
|
@ -255,7 +255,7 @@
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'code': 'no_intent_match',
|
'code': 'no_intent_match',
|
||||||
}),
|
}),
|
||||||
'language': 'en-US',
|
'language': 'en',
|
||||||
'response_type': 'error',
|
'response_type': 'error',
|
||||||
'speech': dict({
|
'speech': dict({
|
||||||
'plain': dict({
|
'plain': dict({
|
||||||
|
@ -275,8 +275,8 @@
|
||||||
# ---
|
# ---
|
||||||
# name: test_tts_failed
|
# name: test_tts_failed
|
||||||
dict({
|
dict({
|
||||||
'language': 'en-US',
|
'language': 'en',
|
||||||
'pipeline': 'en-US',
|
'pipeline': 'en',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': None,
|
'stt_binary_handler_id': None,
|
||||||
'timeout': 30,
|
'timeout': 30,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue