Add audio_settings for pipeline from ESPHome device (#100894)

* Add audio_settings for pipeline from ESPHome device

* ruff fixes

* Bump aioesphomeapi 17.0.0

* Mypy

* Fix tests

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Jesse Hills 2023-09-27 10:27:26 +13:00 committed by GitHub
parent f899e5159b
commit 4c21aa18db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 50 additions and 80 deletions

View file

@@ -16,6 +16,7 @@ from aioesphomeapi import (
RequiresEncryptionAPIError, RequiresEncryptionAPIError,
UserService, UserService,
UserServiceArgType, UserServiceArgType,
VoiceAssistantAudioSettings,
VoiceAssistantEventType, VoiceAssistantEventType,
) )
from awesomeversion import AwesomeVersion from awesomeversion import AwesomeVersion
@@ -319,7 +320,10 @@ class ESPHomeManager:
self.voice_assistant_udp_server = None self.voice_assistant_udp_server = None
async def _handle_pipeline_start( async def _handle_pipeline_start(
self, conversation_id: str, flags: int self,
conversation_id: str,
flags: int,
audio_settings: VoiceAssistantAudioSettings,
) -> int | None: ) -> int | None:
"""Start a voice assistant pipeline.""" """Start a voice assistant pipeline."""
if self.voice_assistant_udp_server is not None: if self.voice_assistant_udp_server is not None:
@@ -340,6 +344,7 @@ class ESPHomeManager:
device_id=self.device_id, device_id=self.device_id,
conversation_id=conversation_id or None, conversation_id=conversation_id or None,
flags=flags, flags=flags,
audio_settings=audio_settings,
), ),
"esphome.voice_assistant_udp_server.run_pipeline", "esphome.voice_assistant_udp_server.run_pipeline",
) )

View file

@@ -16,7 +16,7 @@
"loggers": ["aioesphomeapi", "noiseprotocol"], "loggers": ["aioesphomeapi", "noiseprotocol"],
"requirements": [ "requirements": [
"async-interrupt==1.1.1", "async-interrupt==1.1.1",
"aioesphomeapi==16.0.6", "aioesphomeapi==17.0.0",
"bluetooth-data-tools==1.12.0", "bluetooth-data-tools==1.12.0",
"esphome-dashboard-api==1.2.3" "esphome-dashboard-api==1.2.3"
], ],

View file

@@ -7,14 +7,20 @@ import logging
import socket import socket
from typing import cast from typing import cast
from aioesphomeapi import VoiceAssistantCommandFlag, VoiceAssistantEventType from aioesphomeapi import (
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
)
from homeassistant.components import stt, tts from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import ( from homeassistant.components.assist_pipeline import (
AudioSettings,
PipelineEvent, PipelineEvent,
PipelineEventType, PipelineEventType,
PipelineNotFound, PipelineNotFound,
PipelineStage, PipelineStage,
WakeWordSettings,
async_pipeline_from_audio_stream, async_pipeline_from_audio_stream,
select as pipeline_select, select as pipeline_select,
) )
@@ -64,7 +70,6 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
entry_data: RuntimeEntryData, entry_data: RuntimeEntryData,
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None], handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
handle_finished: Callable[[], None], handle_finished: Callable[[], None],
audio_timeout: float = 2.0,
) -> None: ) -> None:
"""Initialize UDP receiver.""" """Initialize UDP receiver."""
self.context = Context() self.context = Context()
@@ -78,7 +83,6 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
self.handle_event = handle_event self.handle_event = handle_event
self.handle_finished = handle_finished self.handle_finished = handle_finished
self._tts_done = asyncio.Event() self._tts_done = asyncio.Event()
self.audio_timeout = audio_timeout
async def start_server(self) -> int: async def start_server(self) -> int:
"""Start accepting connections.""" """Start accepting connections."""
@@ -212,9 +216,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
device_id: str, device_id: str,
conversation_id: str | None, conversation_id: str | None,
flags: int = 0, flags: int = 0,
pipeline_timeout: float = 30.0, audio_settings: VoiceAssistantAudioSettings | None = None,
) -> None: ) -> None:
"""Run the Voice Assistant pipeline.""" """Run the Voice Assistant pipeline."""
if audio_settings is None:
audio_settings = VoiceAssistantAudioSettings()
tts_audio_output = ( tts_audio_output = (
"raw" if self.device_info.voice_assistant_version >= 2 else "mp3" "raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
@@ -226,31 +232,36 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
else: else:
start_stage = PipelineStage.STT start_stage = PipelineStage.STT
try: try:
async with asyncio.timeout(pipeline_timeout): await async_pipeline_from_audio_stream(
await async_pipeline_from_audio_stream( self.hass,
self.hass, context=self.context,
context=self.context, event_callback=self._event_callback,
event_callback=self._event_callback, stt_metadata=stt.SpeechMetadata(
stt_metadata=stt.SpeechMetadata( language="", # set in async_pipeline_from_audio_stream
language="", # set in async_pipeline_from_audio_stream format=stt.AudioFormats.WAV,
format=stt.AudioFormats.WAV, codec=stt.AudioCodecs.PCM,
codec=stt.AudioCodecs.PCM, bit_rate=stt.AudioBitRates.BITRATE_16,
bit_rate=stt.AudioBitRates.BITRATE_16, sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, channel=stt.AudioChannels.CHANNEL_MONO,
channel=stt.AudioChannels.CHANNEL_MONO, ),
), stt_stream=self._iterate_packets(),
stt_stream=self._iterate_packets(), pipeline_id=pipeline_select.get_chosen_pipeline(
pipeline_id=pipeline_select.get_chosen_pipeline( self.hass, DOMAIN, self.device_info.mac_address
self.hass, DOMAIN, self.device_info.mac_address ),
), conversation_id=conversation_id,
conversation_id=conversation_id, device_id=device_id,
device_id=device_id, tts_audio_output=tts_audio_output,
tts_audio_output=tts_audio_output, start_stage=start_stage,
start_stage=start_stage, wake_word_settings=WakeWordSettings(timeout=5),
) audio_settings=AudioSettings(
noise_suppression_level=audio_settings.noise_suppression_level,
auto_gain_dbfs=audio_settings.auto_gain,
volume_multiplier=audio_settings.volume_multiplier,
),
)
# Block until TTS is done sending # Block until TTS is done sending
await self._tts_done.wait() await self._tts_done.wait()
_LOGGER.debug("Pipeline finished") _LOGGER.debug("Pipeline finished")
except PipelineNotFound: except PipelineNotFound:
@@ -271,18 +282,6 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
}, },
) )
_LOGGER.warning("No Wake word provider found") _LOGGER.warning("No Wake word provider found")
except asyncio.TimeoutError:
if self.stopped:
# The pipeline was stopped gracefully
return
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{
"code": "pipeline-timeout",
"message": "Pipeline timeout",
},
)
_LOGGER.warning("Pipeline timeout")
finally: finally:
self.handle_finished() self.handle_finished()

View file

@@ -231,7 +231,7 @@ aioecowitt==2023.5.0
aioemonitor==1.0.5 aioemonitor==1.0.5
# homeassistant.components.esphome # homeassistant.components.esphome
aioesphomeapi==16.0.6 aioesphomeapi==17.0.0
# homeassistant.components.flo # homeassistant.components.flo
aioflo==2021.11.0 aioflo==2021.11.0

View file

@@ -212,7 +212,7 @@ aioecowitt==2023.5.0
aioemonitor==1.0.5 aioemonitor==1.0.5
# homeassistant.components.esphome # homeassistant.components.esphome
aioesphomeapi==16.0.6 aioesphomeapi==17.0.0
# homeassistant.components.flo # homeassistant.components.flo
aioflo==2021.11.0 aioflo==2021.11.0

View file

@@ -10,7 +10,6 @@ import pytest
from homeassistant.components.assist_pipeline import ( from homeassistant.components.assist_pipeline import (
PipelineEvent, PipelineEvent,
PipelineEventType, PipelineEventType,
PipelineNotFound,
PipelineStage, PipelineStage,
) )
from homeassistant.components.assist_pipeline.error import WakeWordDetectionError from homeassistant.components.assist_pipeline.error import WakeWordDetectionError
@@ -370,6 +369,8 @@ async def test_wake_word(
with patch( with patch(
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream, new=async_pipeline_from_audio_stream,
), patch(
"asyncio.Event.wait" # TTS wait event
): ):
voice_assistant_udp_server_v2.transport = Mock() voice_assistant_udp_server_v2.transport = Mock()
@@ -377,7 +378,6 @@ async def test_wake_word(
device_id="mock-device-id", device_id="mock-device-id",
conversation_id=None, conversation_id=None,
flags=2, flags=2,
pipeline_timeout=1,
) )
@@ -410,38 +410,4 @@ async def test_wake_word_exception(
device_id="mock-device-id", device_id="mock-device-id",
conversation_id=None, conversation_id=None,
flags=2, flags=2,
pipeline_timeout=1,
)
async def test_pipeline_timeout(
hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
) -> None:
"""Test that the pipeline is set to start with Wake word."""
async def async_pipeline_from_audio_stream(*args, **kwargs):
raise PipelineNotFound("not-found", "Pipeline not found")
with patch(
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
):
voice_assistant_udp_server_v2.transport = Mock()
def handle_event(
event_type: VoiceAssistantEventType, data: dict[str, str] | None
) -> None:
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
assert data is not None
assert data["code"] == "pipeline not found"
assert data["message"] == "Selected pipeline not found"
voice_assistant_udp_server_v2.handle_event = handle_event
await voice_assistant_udp_server_v2.run_pipeline(
device_id="mock-device-id",
conversation_id=None,
flags=2,
pipeline_timeout=1,
) )