Skip TTS when text is empty (#104741)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
34c65749e2
commit
90bcad31b5
7 changed files with 225 additions and 43 deletions
|
@ -1024,39 +1024,38 @@ class PipelineRun:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
if tts_input := tts_input.strip():
|
||||||
# Synthesize audio and get URL
|
try:
|
||||||
tts_media_id = tts_generate_media_source_id(
|
# Synthesize audio and get URL
|
||||||
self.hass,
|
tts_media_id = tts_generate_media_source_id(
|
||||||
tts_input,
|
self.hass,
|
||||||
engine=self.tts_engine,
|
tts_input,
|
||||||
language=self.pipeline.tts_language,
|
engine=self.tts_engine,
|
||||||
options=self.tts_options,
|
language=self.pipeline.tts_language,
|
||||||
)
|
options=self.tts_options,
|
||||||
tts_media = await media_source.async_resolve_media(
|
)
|
||||||
self.hass,
|
tts_media = await media_source.async_resolve_media(
|
||||||
tts_media_id,
|
self.hass,
|
||||||
None,
|
tts_media_id,
|
||||||
)
|
None,
|
||||||
except Exception as src_error:
|
)
|
||||||
_LOGGER.exception("Unexpected error during text-to-speech")
|
except Exception as src_error:
|
||||||
raise TextToSpeechError(
|
_LOGGER.exception("Unexpected error during text-to-speech")
|
||||||
code="tts-failed",
|
raise TextToSpeechError(
|
||||||
message="Unexpected error during text-to-speech",
|
code="tts-failed",
|
||||||
) from src_error
|
message="Unexpected error during text-to-speech",
|
||||||
|
) from src_error
|
||||||
|
|
||||||
_LOGGER.debug("TTS result %s", tts_media)
|
_LOGGER.debug("TTS result %s", tts_media)
|
||||||
|
tts_output = {
|
||||||
|
"media_id": tts_media_id,
|
||||||
|
**asdict(tts_media),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
tts_output = {}
|
||||||
|
|
||||||
self.process_event(
|
self.process_event(
|
||||||
PipelineEvent(
|
PipelineEvent(PipelineEventType.TTS_END, {"tts_output": tts_output})
|
||||||
PipelineEventType.TTS_END,
|
|
||||||
{
|
|
||||||
"tts_output": {
|
|
||||||
"media_id": tts_media_id,
|
|
||||||
**asdict(tts_media),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return tts_media.url
|
return tts_media.url
|
||||||
|
|
|
@ -186,16 +186,22 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||||
data_to_send = {"text": event.data["tts_input"]}
|
data_to_send = {"text": event.data["tts_input"]}
|
||||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
|
||||||
assert event.data is not None
|
assert event.data is not None
|
||||||
path = event.data["tts_output"]["url"]
|
tts_output = event.data["tts_output"]
|
||||||
url = async_process_play_media_url(self.hass, path)
|
if tts_output:
|
||||||
data_to_send = {"url": url}
|
path = tts_output["url"]
|
||||||
|
url = async_process_play_media_url(self.hass, path)
|
||||||
|
data_to_send = {"url": url}
|
||||||
|
|
||||||
if self.device_info.voice_assistant_version >= 2:
|
if self.device_info.voice_assistant_version >= 2:
|
||||||
media_id = event.data["tts_output"]["media_id"]
|
media_id = tts_output["media_id"]
|
||||||
self._tts_task = self.hass.async_create_background_task(
|
self._tts_task = self.hass.async_create_background_task(
|
||||||
self._send_tts(media_id), "esphome_voice_assistant_tts"
|
self._send_tts(media_id), "esphome_voice_assistant_tts"
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self._tts_done.set()
|
||||||
else:
|
else:
|
||||||
|
# Empty TTS response
|
||||||
|
data_to_send = {}
|
||||||
self._tts_done.set()
|
self._tts_done.set()
|
||||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
|
||||||
assert event.data is not None
|
assert event.data is not None
|
||||||
|
|
|
@ -389,11 +389,16 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
||||||
self._conversation_id = event.data["intent_output"]["conversation_id"]
|
self._conversation_id = event.data["intent_output"]["conversation_id"]
|
||||||
elif event.type == PipelineEventType.TTS_END:
|
elif event.type == PipelineEventType.TTS_END:
|
||||||
# Send TTS audio to caller over RTP
|
# Send TTS audio to caller over RTP
|
||||||
media_id = event.data["tts_output"]["media_id"]
|
tts_output = event.data["tts_output"]
|
||||||
self.hass.async_create_background_task(
|
if tts_output:
|
||||||
self._send_tts(media_id),
|
media_id = tts_output["media_id"]
|
||||||
"voip_pipeline_tts",
|
self.hass.async_create_background_task(
|
||||||
)
|
self._send_tts(media_id),
|
||||||
|
"voip_pipeline_tts",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Empty TTS response
|
||||||
|
self._tts_done.set()
|
||||||
elif event.type == PipelineEventType.ERROR:
|
elif event.type == PipelineEventType.ERROR:
|
||||||
# Play error tone instead of wait for TTS
|
# Play error tone instead of wait for TTS
|
||||||
self._pipeline_error = True
|
self._pipeline_error = True
|
||||||
|
|
|
@ -650,6 +650,33 @@
|
||||||
'message': 'Timeout running pipeline',
|
'message': 'Timeout running pipeline',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_pipeline_empty_tts_output
|
||||||
|
dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': None,
|
||||||
|
'timeout': 300,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_pipeline_empty_tts_output.1
|
||||||
|
dict({
|
||||||
|
'engine': 'test',
|
||||||
|
'language': 'en-US',
|
||||||
|
'tts_input': '',
|
||||||
|
'voice': 'james_earl_jones',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_pipeline_empty_tts_output.2
|
||||||
|
dict({
|
||||||
|
'tts_output': dict({
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_pipeline_empty_tts_output.3
|
||||||
|
None
|
||||||
|
# ---
|
||||||
# name: test_stt_provider_missing
|
# name: test_stt_provider_missing
|
||||||
dict({
|
dict({
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
|
|
|
@ -2452,3 +2452,54 @@ async def test_device_capture_queue_full(
|
||||||
assert msg["event"] == snapshot
|
assert msg["event"] == snapshot
|
||||||
assert msg["event"]["type"] == "end"
|
assert msg["event"]["type"] == "end"
|
||||||
assert msg["event"]["overflow"]
|
assert msg["event"]["overflow"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_empty_tts_output(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test events from a pipeline run with an empty text-to-speech text."""
|
||||||
|
events = []
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "tts",
|
||||||
|
"end_stage": "tts",
|
||||||
|
"input": {
|
||||||
|
"text": "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
msg["event"]["data"]["pipeline"] = ANY
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
# text-to-speech
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "tts-start"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "tts-end"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
assert not msg["event"]["data"]["tts_output"]
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
||||||
|
# run end
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-end"
|
||||||
|
assert msg["event"]["data"] == snapshot
|
||||||
|
events.append(msg["event"])
|
||||||
|
|
|
@ -337,6 +337,28 @@ async def test_send_tts_called(
|
||||||
mock_send_tts.assert_called_with(_TEST_MEDIA_ID)
|
mock_send_tts.assert_called_with(_TEST_MEDIA_ID)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_send_tts_not_called_when_empty(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
|
||||||
|
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||||
|
) -> None:
|
||||||
|
"""Test that the UDP server with a v1/v2 device doesn't call _send_tts when the output is empty."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._send_tts"
|
||||||
|
) as mock_send_tts:
|
||||||
|
voice_assistant_udp_server_v1._event_callback(
|
||||||
|
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_send_tts.assert_not_called()
|
||||||
|
|
||||||
|
voice_assistant_udp_server_v2._event_callback(
|
||||||
|
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_send_tts.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
async def test_send_tts(
|
async def test_send_tts(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||||
|
|
|
@ -528,3 +528,75 @@ async def test_tts_wrong_wav_format(
|
||||||
# Wait for mock pipeline to exhaust the audio stream
|
# Wait for mock pipeline to exhaust the audio stream
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
await done.wait()
|
await done.wait()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_empty_tts_output(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
voip_device: VoIPDevice,
|
||||||
|
) -> None:
|
||||||
|
"""Test that TTS will not stream when output is empty."""
|
||||||
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
|
def is_speech(self, chunk):
|
||||||
|
"""Anything non-zero is speech."""
|
||||||
|
return sum(chunk) > 0
|
||||||
|
|
||||||
|
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||||
|
stt_stream = kwargs["stt_stream"]
|
||||||
|
event_callback = kwargs["event_callback"]
|
||||||
|
async for _chunk in stt_stream:
|
||||||
|
# Stream will end when VAD detects end of "speech"
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fake intent result
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.INTENT_END,
|
||||||
|
data={
|
||||||
|
"intent_output": {
|
||||||
|
"conversation_id": "fake-conversation",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Empty TTS output
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.TTS_END,
|
||||||
|
data={"tts_output": {}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
|
||||||
|
new=is_speech,
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
|
||||||
|
new=async_pipeline_from_audio_stream,
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.voip.voip.PipelineRtpDatagramProtocol._send_tts",
|
||||||
|
) as mock_send_tts:
|
||||||
|
rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
|
||||||
|
hass,
|
||||||
|
hass.config.language,
|
||||||
|
voip_device,
|
||||||
|
Context(),
|
||||||
|
opus_payload_type=123,
|
||||||
|
)
|
||||||
|
rtp_protocol.transport = Mock()
|
||||||
|
|
||||||
|
# silence
|
||||||
|
rtp_protocol.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
|
# "speech"
|
||||||
|
rtp_protocol.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||||
|
|
||||||
|
# silence (assumes relaxed VAD sensitivity)
|
||||||
|
rtp_protocol.on_chunk(bytes(_ONE_SECOND * 4))
|
||||||
|
|
||||||
|
# Wait for mock pipeline to finish
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await rtp_protocol._tts_done.wait()
|
||||||
|
|
||||||
|
mock_send_tts.assert_not_called()
|
||||||
|
|
Loading…
Add table
Reference in a new issue