Use configured voice in TTS output for assist pipeline (#91878)

This commit is contained in:
Paulus Schoutsen 2023-04-22 22:01:32 -04:00 committed by GitHub
parent 33808cd268
commit 1eef4af493
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 21 additions and 11 deletions

View file

@@ -48,7 +48,7 @@ async def async_pipeline_from_audio_stream(
stt_stream: AsyncIterable[bytes],
pipeline_id: str | None = None,
conversation_id: str | None = None,
tts_options: dict | None = None,
tts_audio_output: str | None = None,
) -> None:
"""Create an audio pipeline from an audio stream."""
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
@@ -71,7 +71,7 @@ async def async_pipeline_from_audio_stream(
start_stage=PipelineStage.STT,
end_stage=PipelineStage.TTS,
event_callback=event_callback,
tts_options=tts_options,
tts_audio_output=tts_audio_output,
),
)

View file

@@ -206,9 +206,10 @@ class PipelineRun:
stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None
intent_agent: str | None = None
tts_engine: str | None = None
tts_options: dict | None = None
tts_audio_output: str | None = None
id: str = field(default_factory=ulid_util.ulid)
tts_options: dict | None = field(init=False, default=None)
def __post_init__(self) -> None:
"""Set language for pipeline."""
@@ -428,21 +429,29 @@ class PipelineRun:
message=f"Text to speech engine '{engine}' not found",
)
tts_options = {}
if self.pipeline.tts_voice is not None:
tts_options[tts.ATTR_VOICE] = self.pipeline.tts_voice
if self.tts_audio_output is not None:
tts_options[tts.ATTR_AUDIO_OUTPUT] = self.tts_audio_output
if not await tts.async_support_options(
self.hass,
engine,
self.language,
self.tts_options,
tts_options,
):
raise TextToSpeechError(
code="tts-not-supported",
message=(
f"Text to speech engine {engine} "
f"does not support language {self.language} or options {self.tts_options}"
f"does not support language {self.language} or options {tts_options}"
),
)
self.tts_engine = engine
self.tts_options = tts_options
async def text_to_speech(self, tts_input: str) -> str:
"""Run text to speech portion of pipeline. Returns URL of TTS audio."""

View file

@@ -6,6 +6,7 @@ import voluptuous as vol
from homeassistant.components.tts import (
ATTR_AUDIO_OUTPUT,
ATTR_VOICE,
CONF_LANG,
PLATFORM_SCHEMA,
Provider,
@@ -16,7 +17,6 @@ from homeassistant.core import callback
from .const import DOMAIN
ATTR_GENDER = "gender"
ATTR_VOICE = "voice"
SUPPORT_LANGUAGES = list(TTS_VOICES)

View file

@@ -89,6 +89,7 @@ _LOGGER = logging.getLogger(__name__)
ATTR_PLATFORM = "platform"
ATTR_AUDIO_OUTPUT = "audio_output"
ATTR_MEDIA_PLAYER_ENTITY_ID = "media_player_entity_id"
ATTR_VOICE = "voice"
CONF_LANG = "language"

View file

@@ -181,7 +181,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
self.hass, DOMAIN, self.voip_device.voip_id
),
conversation_id=self._conversation_id,
tts_options={tts.ATTR_AUDIO_OUTPUT: "raw"},
tts_audio_output="raw",
)
except asyncio.TimeoutError:

View file

@@ -154,9 +154,9 @@
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=Arnold+Schwarzenegger",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_2657c1a8ee_test.mp3',
}),
}),
'type': <PipelineEventType.TTS_END: 'tts-end'>,
@@ -238,9 +238,9 @@
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=Arnold+Schwarzenegger",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_2657c1a8ee_test.mp3',
}),
}),
'type': <PipelineEventType.TTS_END: 'tts-end'>,