Add audio_settings for pipeline from ESPHome device (#100894)
* Add audio_settings for pipeline from ESPHome device
* ruff fixes
* Bump aioesphomeapi 17.0.0
* Mypy
* Fix tests

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
parent f899e5159b
commit 4c21aa18db

6 changed files with 50 additions and 80 deletions
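The substance of the change is a field-for-field translation of the device-supplied audio settings into the assist pipeline's own model. A minimal sketch, using only the field names visible in the diff below (the standalone helper is hypothetical; the integration builds AudioSettings inline in run_pipeline):

from aioesphomeapi import VoiceAssistantAudioSettings
from homeassistant.components.assist_pipeline import AudioSettings


def to_pipeline_audio_settings(device: VoiceAssistantAudioSettings) -> AudioSettings:
    """Map the ESPHome API model onto the assist_pipeline model (hypothetical helper)."""
    return AudioSettings(
        noise_suppression_level=device.noise_suppression_level,
        auto_gain_dbfs=device.auto_gain,  # note the rename: auto_gain -> auto_gain_dbfs
        volume_multiplier=device.volume_multiplier,
    )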

homeassistant/components/esphome/manager.py

@@ -16,6 +16,7 @@ from aioesphomeapi import (
     RequiresEncryptionAPIError,
     UserService,
     UserServiceArgType,
+    VoiceAssistantAudioSettings,
     VoiceAssistantEventType,
 )
 from awesomeversion import AwesomeVersion
@@ -319,7 +320,10 @@ class ESPHomeManager:
         self.voice_assistant_udp_server = None

     async def _handle_pipeline_start(
-        self, conversation_id: str, flags: int
+        self,
+        conversation_id: str,
+        flags: int,
+        audio_settings: VoiceAssistantAudioSettings,
     ) -> int | None:
         """Start a voice assistant pipeline."""
         if self.voice_assistant_udp_server is not None:
@@ -340,6 +344,7 @@ class ESPHomeManager:
                 device_id=self.device_id,
                 conversation_id=conversation_id or None,
                 flags=flags,
+                audio_settings=audio_settings,
             ),
             "esphome.voice_assistant_udp_server.run_pipeline",
         )
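The second hunk widens the pipeline-start callback that aioesphomeapi invokes when the device requests a pipeline: it now receives the device's audio settings alongside the conversation id and flags. A minimal sketch of that contract with a stand-in body (the print call is a placeholder, not the integration's logic):

from aioesphomeapi import VoiceAssistantAudioSettings


async def handle_pipeline_start(
    conversation_id: str,
    flags: int,
    audio_settings: VoiceAssistantAudioSettings,
) -> int | None:
    """Stand-in for ESPHomeManager._handle_pipeline_start."""
    # The real handler starts a VoiceAssistantUDPServer and forwards
    # audio_settings into run_pipeline(); this stub only shows the shape.
    print(conversation_id, flags, audio_settings)
    return None  # the integration returns the UDP port for the audio stream here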

homeassistant/components/esphome/manifest.json

@@ -16,7 +16,7 @@
   "loggers": ["aioesphomeapi", "noiseprotocol"],
   "requirements": [
     "async-interrupt==1.1.1",
-    "aioesphomeapi==16.0.6",
+    "aioesphomeapi==17.0.0",
     "bluetooth-data-tools==1.12.0",
     "esphome-dashboard-api==1.2.3"
   ],

homeassistant/components/esphome/voice_assistant.py

@@ -7,14 +7,20 @@ import logging
 import socket
 from typing import cast

-from aioesphomeapi import VoiceAssistantCommandFlag, VoiceAssistantEventType
+from aioesphomeapi import (
+    VoiceAssistantAudioSettings,
+    VoiceAssistantCommandFlag,
+    VoiceAssistantEventType,
+)

 from homeassistant.components import stt, tts
 from homeassistant.components.assist_pipeline import (
+    AudioSettings,
     PipelineEvent,
     PipelineEventType,
     PipelineNotFound,
     PipelineStage,
+    WakeWordSettings,
     async_pipeline_from_audio_stream,
     select as pipeline_select,
 )
@@ -64,7 +70,6 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
         entry_data: RuntimeEntryData,
         handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
         handle_finished: Callable[[], None],
-        audio_timeout: float = 2.0,
     ) -> None:
         """Initialize UDP receiver."""
         self.context = Context()
@@ -78,7 +83,6 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
         self.handle_event = handle_event
         self.handle_finished = handle_finished
         self._tts_done = asyncio.Event()
-        self.audio_timeout = audio_timeout

     async def start_server(self) -> int:
         """Start accepting connections."""
@@ -212,9 +216,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
         device_id: str,
         conversation_id: str | None,
         flags: int = 0,
-        pipeline_timeout: float = 30.0,
+        audio_settings: VoiceAssistantAudioSettings | None = None,
     ) -> None:
         """Run the Voice Assistant pipeline."""
+        if audio_settings is None:
+            audio_settings = VoiceAssistantAudioSettings()
+
         tts_audio_output = (
             "raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
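With the signature above, a caller can pass the device-reported settings or omit them and let run_pipeline fall back to a default VoiceAssistantAudioSettings(). A minimal caller sketch, assuming udp_server is an already-created VoiceAssistantUDPServer and that the aioesphomeapi model accepts these keyword arguments (the numeric values are invented for illustration):

from aioesphomeapi import VoiceAssistantAudioSettings


async def start_pipeline_for_device(udp_server, flags: int) -> None:
    # Explicit settings, as an ESPHome device would report them.
    await udp_server.run_pipeline(
        device_id="mock-device-id",
        conversation_id=None,
        flags=flags,
        audio_settings=VoiceAssistantAudioSettings(
            noise_suppression_level=2,  # illustrative values only
            auto_gain=31,
            volume_multiplier=1.5,
        ),
    )
    # Omitting audio_settings entirely is also valid; run_pipeline then
    # substitutes VoiceAssistantAudioSettings() with the library defaults.

Note that the old pipeline_timeout argument is gone, so callers no longer pass a timeout here.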

homeassistant/components/esphome/voice_assistant.py (continued)

@@ -226,31 +232,36 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
         else:
             start_stage = PipelineStage.STT
         try:
-            async with asyncio.timeout(pipeline_timeout):
-                await async_pipeline_from_audio_stream(
-                    self.hass,
-                    context=self.context,
-                    event_callback=self._event_callback,
-                    stt_metadata=stt.SpeechMetadata(
-                        language="",  # set in async_pipeline_from_audio_stream
-                        format=stt.AudioFormats.WAV,
-                        codec=stt.AudioCodecs.PCM,
-                        bit_rate=stt.AudioBitRates.BITRATE_16,
-                        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
-                        channel=stt.AudioChannels.CHANNEL_MONO,
-                    ),
-                    stt_stream=self._iterate_packets(),
-                    pipeline_id=pipeline_select.get_chosen_pipeline(
-                        self.hass, DOMAIN, self.device_info.mac_address
-                    ),
-                    conversation_id=conversation_id,
-                    device_id=device_id,
-                    tts_audio_output=tts_audio_output,
-                    start_stage=start_stage,
-                )
+            await async_pipeline_from_audio_stream(
+                self.hass,
+                context=self.context,
+                event_callback=self._event_callback,
+                stt_metadata=stt.SpeechMetadata(
+                    language="",  # set in async_pipeline_from_audio_stream
+                    format=stt.AudioFormats.WAV,
+                    codec=stt.AudioCodecs.PCM,
+                    bit_rate=stt.AudioBitRates.BITRATE_16,
+                    sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
+                    channel=stt.AudioChannels.CHANNEL_MONO,
+                ),
+                stt_stream=self._iterate_packets(),
+                pipeline_id=pipeline_select.get_chosen_pipeline(
+                    self.hass, DOMAIN, self.device_info.mac_address
+                ),
+                conversation_id=conversation_id,
+                device_id=device_id,
+                tts_audio_output=tts_audio_output,
+                start_stage=start_stage,
+                wake_word_settings=WakeWordSettings(timeout=5),
+                audio_settings=AudioSettings(
+                    noise_suppression_level=audio_settings.noise_suppression_level,
+                    auto_gain_dbfs=audio_settings.auto_gain,
+                    volume_multiplier=audio_settings.volume_multiplier,
+                ),
+            )

             # Block until TTS is done sending
             await self._tts_done.wait()

             _LOGGER.debug("Pipeline finished")
         except PipelineNotFound:
@@ -271,18 +282,6 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
                 },
             )
             _LOGGER.warning("No Wake word provider found")
-        except asyncio.TimeoutError:
-            if self.stopped:
-                # The pipeline was stopped gracefully
-                return
-            self.handle_event(
-                VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
-                {
-                    "code": "pipeline-timeout",
-                    "message": "Pipeline timeout",
-                },
-            )
-            _LOGGER.warning("Pipeline timeout")
         finally:
             self.handle_finished()


requirements_all.txt

@@ -231,7 +231,7 @@ aioecowitt==2023.5.0
 aioemonitor==1.0.5

 # homeassistant.components.esphome
-aioesphomeapi==16.0.6
+aioesphomeapi==17.0.0

 # homeassistant.components.flo
 aioflo==2021.11.0

requirements_test_all.txt

@@ -212,7 +212,7 @@ aioecowitt==2023.5.0
 aioemonitor==1.0.5

 # homeassistant.components.esphome
-aioesphomeapi==16.0.6
+aioesphomeapi==17.0.0

 # homeassistant.components.flo
 aioflo==2021.11.0

tests/components/esphome/test_voice_assistant.py

@@ -10,7 +10,6 @@ import pytest
 from homeassistant.components.assist_pipeline import (
     PipelineEvent,
     PipelineEventType,
-    PipelineNotFound,
     PipelineStage,
 )
 from homeassistant.components.assist_pipeline.error import WakeWordDetectionError
@@ -370,6 +369,8 @@ async def test_wake_word(
     with patch(
         "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
         new=async_pipeline_from_audio_stream,
+    ), patch(
+        "asyncio.Event.wait"  # TTS wait event
     ):
         voice_assistant_udp_server_v2.transport = Mock()

@@ -377,7 +378,6 @@ async def test_wake_word(
             device_id="mock-device-id",
             conversation_id=None,
             flags=2,
-            pipeline_timeout=1,
         )

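The extra patch("asyncio.Event.wait") keeps test_wake_word from hanging: with the pipeline_timeout argument gone, run_pipeline blocks on the TTS-done event until something sets it, so the test stubs the wait out instead of relying on the old one-second timeout. A standalone illustration of the pattern (not the integration's fixture code):

import asyncio
from unittest.mock import AsyncMock, patch


async def wait_like_run_pipeline() -> None:
    # Stand-in for "await self._tts_done.wait()"; unpatched, this would block
    # forever because nothing ever sets the event.
    await asyncio.Event().wait()


async def main() -> None:
    with patch("asyncio.Event.wait", new=AsyncMock(return_value=True)):
        await wait_like_run_pipeline()  # returns immediately, wait() is mocked
    print("did not block")


asyncio.run(main())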

tests/components/esphome/test_voice_assistant.py (continued)

@@ -410,38 +410,4 @@ async def test_wake_word_exception(
             device_id="mock-device-id",
             conversation_id=None,
             flags=2,
-            pipeline_timeout=1,
-        )
-
-
-async def test_pipeline_timeout(
-    hass: HomeAssistant,
-    voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
-) -> None:
-    """Test that the pipeline is set to start with Wake word."""
-
-    async def async_pipeline_from_audio_stream(*args, **kwargs):
-        raise PipelineNotFound("not-found", "Pipeline not found")
-
-    with patch(
-        "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
-        new=async_pipeline_from_audio_stream,
-    ):
-        voice_assistant_udp_server_v2.transport = Mock()
-
-        def handle_event(
-            event_type: VoiceAssistantEventType, data: dict[str, str] | None
-        ) -> None:
-            if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
-                assert data is not None
-                assert data["code"] == "pipeline not found"
-                assert data["message"] == "Selected pipeline not found"
-
-        voice_assistant_udp_server_v2.handle_event = handle_event
-
-        await voice_assistant_udp_server_v2.run_pipeline(
-            device_id="mock-device-id",
-            conversation_id=None,
-            flags=2,
-            pipeline_timeout=1,
-        )
         )
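A sketch of how the new forwarding could be exercised in this module, reusing the voice_assistant_udp_server_v2 fixture from the hunks above; the test name and asserted values are hypothetical, and it assumes async_pipeline_from_audio_stream receives the mapped AudioSettings as a keyword argument, as the voice_assistant.py hunk indicates:

from unittest.mock import patch

from aioesphomeapi import VoiceAssistantAudioSettings

from homeassistant.components.assist_pipeline import AudioSettings


async def test_audio_settings_forwarded(voice_assistant_udp_server_v2) -> None:
    """Hypothetical test: device audio settings reach the assist pipeline call."""
    captured: dict = {}

    async def async_pipeline_from_audio_stream(*args, **kwargs):
        captured.update(kwargs)

    with patch(
        "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
        new=async_pipeline_from_audio_stream,
    ), patch("asyncio.Event.wait"):
        await voice_assistant_udp_server_v2.run_pipeline(
            device_id="mock-device-id",
            conversation_id=None,
            flags=2,
            audio_settings=VoiceAssistantAudioSettings(
                noise_suppression_level=3, auto_gain=24, volume_multiplier=2.0
            ),
        )

    forwarded = captured["audio_settings"]
    assert isinstance(forwarded, AudioSettings)
    assert forwarded.noise_suppression_level == 3
    assert forwarded.auto_gain_dbfs == 24
    assert forwarded.volume_multiplier == 2.0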