Add audio_settings for pipeline from ESPHome device (#100894)

* Add audio_settings for pipeline from ESPHome device

* ruff fixes

* Bump aioesphomeapi 17.0.0

* Mypy

* Fix tests

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Jesse Hills 2023-09-27 10:27:26 +13:00 committed by GitHub
parent f899e5159b
commit 4c21aa18db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 50 additions and 80 deletions

View file

@ -16,6 +16,7 @@ from aioesphomeapi import (
RequiresEncryptionAPIError,
UserService,
UserServiceArgType,
VoiceAssistantAudioSettings,
VoiceAssistantEventType,
)
from awesomeversion import AwesomeVersion
@ -319,7 +320,10 @@ class ESPHomeManager:
self.voice_assistant_udp_server = None
async def _handle_pipeline_start(
self, conversation_id: str, flags: int
self,
conversation_id: str,
flags: int,
audio_settings: VoiceAssistantAudioSettings,
) -> int | None:
"""Start a voice assistant pipeline."""
if self.voice_assistant_udp_server is not None:
@ -340,6 +344,7 @@ class ESPHomeManager:
device_id=self.device_id,
conversation_id=conversation_id or None,
flags=flags,
audio_settings=audio_settings,
),
"esphome.voice_assistant_udp_server.run_pipeline",
)

View file

@ -16,7 +16,7 @@
"loggers": ["aioesphomeapi", "noiseprotocol"],
"requirements": [
"async-interrupt==1.1.1",
"aioesphomeapi==16.0.6",
"aioesphomeapi==17.0.0",
"bluetooth-data-tools==1.12.0",
"esphome-dashboard-api==1.2.3"
],

View file

@ -7,14 +7,20 @@ import logging
import socket
from typing import cast
from aioesphomeapi import VoiceAssistantCommandFlag, VoiceAssistantEventType
from aioesphomeapi import (
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
)
from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import (
AudioSettings,
PipelineEvent,
PipelineEventType,
PipelineNotFound,
PipelineStage,
WakeWordSettings,
async_pipeline_from_audio_stream,
select as pipeline_select,
)
@ -64,7 +70,6 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
entry_data: RuntimeEntryData,
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
handle_finished: Callable[[], None],
audio_timeout: float = 2.0,
) -> None:
"""Initialize UDP receiver."""
self.context = Context()
@ -78,7 +83,6 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
self.handle_event = handle_event
self.handle_finished = handle_finished
self._tts_done = asyncio.Event()
self.audio_timeout = audio_timeout
async def start_server(self) -> int:
"""Start accepting connections."""
@ -212,9 +216,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
device_id: str,
conversation_id: str | None,
flags: int = 0,
pipeline_timeout: float = 30.0,
audio_settings: VoiceAssistantAudioSettings | None = None,
) -> None:
"""Run the Voice Assistant pipeline."""
if audio_settings is None:
audio_settings = VoiceAssistantAudioSettings()
tts_audio_output = (
"raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
@ -226,7 +232,6 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
else:
start_stage = PipelineStage.STT
try:
async with asyncio.timeout(pipeline_timeout):
await async_pipeline_from_audio_stream(
self.hass,
context=self.context,
@ -247,6 +252,12 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
device_id=device_id,
tts_audio_output=tts_audio_output,
start_stage=start_stage,
wake_word_settings=WakeWordSettings(timeout=5),
audio_settings=AudioSettings(
noise_suppression_level=audio_settings.noise_suppression_level,
auto_gain_dbfs=audio_settings.auto_gain,
volume_multiplier=audio_settings.volume_multiplier,
),
)
# Block until TTS is done sending
@ -271,18 +282,6 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
},
)
_LOGGER.warning("No Wake word provider found")
except asyncio.TimeoutError:
if self.stopped:
# The pipeline was stopped gracefully
return
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{
"code": "pipeline-timeout",
"message": "Pipeline timeout",
},
)
_LOGGER.warning("Pipeline timeout")
finally:
self.handle_finished()

View file

@ -231,7 +231,7 @@ aioecowitt==2023.5.0
aioemonitor==1.0.5
# homeassistant.components.esphome
aioesphomeapi==16.0.6
aioesphomeapi==17.0.0
# homeassistant.components.flo
aioflo==2021.11.0

View file

@ -212,7 +212,7 @@ aioecowitt==2023.5.0
aioemonitor==1.0.5
# homeassistant.components.esphome
aioesphomeapi==16.0.6
aioesphomeapi==17.0.0
# homeassistant.components.flo
aioflo==2021.11.0

View file

@ -10,7 +10,6 @@ import pytest
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
PipelineNotFound,
PipelineStage,
)
from homeassistant.components.assist_pipeline.error import WakeWordDetectionError
@ -370,6 +369,8 @@ async def test_wake_word(
with patch(
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
), patch(
"asyncio.Event.wait" # TTS wait event
):
voice_assistant_udp_server_v2.transport = Mock()
@ -377,7 +378,6 @@ async def test_wake_word(
device_id="mock-device-id",
conversation_id=None,
flags=2,
pipeline_timeout=1,
)
@ -410,38 +410,4 @@ async def test_wake_word_exception(
device_id="mock-device-id",
conversation_id=None,
flags=2,
pipeline_timeout=1,
)
async def test_pipeline_timeout(
hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
) -> None:
"""Test that the pipeline is set to start with Wake word."""
async def async_pipeline_from_audio_stream(*args, **kwargs):
raise PipelineNotFound("not-found", "Pipeline not found")
with patch(
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
):
voice_assistant_udp_server_v2.transport = Mock()
def handle_event(
event_type: VoiceAssistantEventType, data: dict[str, str] | None
) -> None:
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
assert data is not None
assert data["code"] == "pipeline not found"
assert data["message"] == "Selected pipeline not found"
voice_assistant_udp_server_v2.handle_event = handle_event
await voice_assistant_udp_server_v2.run_pipeline(
device_id="mock-device-id",
conversation_id=None,
flags=2,
pipeline_timeout=1,
)