Add audio_settings for pipeline from ESPHome device (#100894)

* Add audio_settings for pipeline from ESPHome device

* ruff fixes

* Bump aioesphomeapi to 17.0.0

* Mypy

* Fix tests

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Jesse Hills 2023-09-27 10:27:26 +13:00 committed by GitHub
parent f899e5159b
commit 4c21aa18db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 50 additions and 80 deletions

View file

@ -7,14 +7,20 @@ import logging
import socket
from typing import cast
from aioesphomeapi import VoiceAssistantCommandFlag, VoiceAssistantEventType
from aioesphomeapi import (
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
)
from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import (
AudioSettings,
PipelineEvent,
PipelineEventType,
PipelineNotFound,
PipelineStage,
WakeWordSettings,
async_pipeline_from_audio_stream,
select as pipeline_select,
)
@ -64,7 +70,6 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
entry_data: RuntimeEntryData,
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
handle_finished: Callable[[], None],
audio_timeout: float = 2.0,
) -> None:
"""Initialize UDP receiver."""
self.context = Context()
@ -78,7 +83,6 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
self.handle_event = handle_event
self.handle_finished = handle_finished
self._tts_done = asyncio.Event()
self.audio_timeout = audio_timeout
async def start_server(self) -> int:
"""Start accepting connections."""
@ -212,9 +216,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
device_id: str,
conversation_id: str | None,
flags: int = 0,
pipeline_timeout: float = 30.0,
audio_settings: VoiceAssistantAudioSettings | None = None,
) -> None:
"""Run the Voice Assistant pipeline."""
if audio_settings is None:
audio_settings = VoiceAssistantAudioSettings()
tts_audio_output = (
"raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
@ -226,31 +232,36 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
else:
start_stage = PipelineStage.STT
try:
async with asyncio.timeout(pipeline_timeout):
await async_pipeline_from_audio_stream(
self.hass,
context=self.context,
event_callback=self._event_callback,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=self._iterate_packets(),
pipeline_id=pipeline_select.get_chosen_pipeline(
self.hass, DOMAIN, self.device_info.mac_address
),
conversation_id=conversation_id,
device_id=device_id,
tts_audio_output=tts_audio_output,
start_stage=start_stage,
)
await async_pipeline_from_audio_stream(
self.hass,
context=self.context,
event_callback=self._event_callback,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=self._iterate_packets(),
pipeline_id=pipeline_select.get_chosen_pipeline(
self.hass, DOMAIN, self.device_info.mac_address
),
conversation_id=conversation_id,
device_id=device_id,
tts_audio_output=tts_audio_output,
start_stage=start_stage,
wake_word_settings=WakeWordSettings(timeout=5),
audio_settings=AudioSettings(
noise_suppression_level=audio_settings.noise_suppression_level,
auto_gain_dbfs=audio_settings.auto_gain,
volume_multiplier=audio_settings.volume_multiplier,
),
)
# Block until TTS is done sending
await self._tts_done.wait()
# Block until TTS is done sending
await self._tts_done.wait()
_LOGGER.debug("Pipeline finished")
except PipelineNotFound:
@ -271,18 +282,6 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
},
)
_LOGGER.warning("No Wake word provider found")
except asyncio.TimeoutError:
if self.stopped:
# The pipeline was stopped gracefully
return
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{
"code": "pipeline-timeout",
"message": "Pipeline timeout",
},
)
_LOGGER.warning("Pipeline timeout")
finally:
self.handle_finished()