Add speech detection and conversation_id to esphome voice assistant (#93578)
* Add speech detection to esphome voice assistant
* Timeout after silence
  Ensure events are sent before finish is called
* use va_version 3 instead of ESPHome version
* Convert repeated fixtures to factory
* Add some v3 tests
* Add conversation_id
* Bump aioesphomeapi to 13.8.0
* Fix missed buffering of detected chunk
* Alter log message
* Updates
* Spelling
* Fix return type
This commit is contained in:
parent
49f10eecaa
commit
d7d9143a44
7 changed files with 352 additions and 101 deletions
|
@ -302,7 +302,7 @@ async def async_setup_entry( # noqa: C901
|
|||
voice_assistant_udp_server.close()
|
||||
voice_assistant_udp_server = None
|
||||
|
||||
async def _handle_pipeline_start() -> int | None:
|
||||
async def _handle_pipeline_start(conversation_id: str, use_vad: bool) -> int | None:
|
||||
"""Start a voice assistant pipeline."""
|
||||
nonlocal voice_assistant_udp_server
|
||||
|
||||
|
@ -315,7 +315,10 @@ async def async_setup_entry( # noqa: C901
|
|||
port = await voice_assistant_udp_server.start_server()
|
||||
|
||||
hass.async_create_background_task(
|
||||
voice_assistant_udp_server.run_pipeline(),
|
||||
voice_assistant_udp_server.run_pipeline(
|
||||
conversation_id=conversation_id or None,
|
||||
use_vad=use_vad,
|
||||
),
|
||||
"esphome.voice_assistant_udp_server.run_pipeline",
|
||||
)
|
||||
entry_data.async_set_assist_pipeline_state(True)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
"iot_class": "local_push",
|
||||
"loggers": ["aioesphomeapi", "noiseprotocol"],
|
||||
"requirements": [
|
||||
"aioesphomeapi==13.7.5",
|
||||
"aioesphomeapi==13.9.0",
|
||||
"bluetooth-data-tools==0.4.0",
|
||||
"esphome-dashboard-api==1.2.3"
|
||||
],
|
||||
|
|
|
@ -2,7 +2,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable, Callable
|
||||
from collections import deque
|
||||
from collections.abc import AsyncIterable, Callable, MutableSequence, Sequence
|
||||
import logging
|
||||
import socket
|
||||
from typing import cast
|
||||
|
@ -17,6 +18,7 @@ from homeassistant.components.assist_pipeline import (
|
|||
async_pipeline_from_audio_stream,
|
||||
select as pipeline_select,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter
|
||||
from homeassistant.components.media_player import async_process_play_media_url
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
|
||||
|
@ -50,7 +52,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
"""Receive UDP packets and forward them to the voice assistant."""
|
||||
|
||||
started = False
|
||||
queue: asyncio.Queue[bytes] | None = None
|
||||
stopped = False
|
||||
transport: asyncio.DatagramTransport | None = None
|
||||
remote_addr: tuple[str, int] | None = None
|
||||
|
||||
|
@ -60,6 +62,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
entry_data: RuntimeEntryData,
|
||||
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
|
||||
handle_finished: Callable[[], None],
|
||||
audio_timeout: float = 2.0,
|
||||
) -> None:
|
||||
"""Initialize UDP receiver."""
|
||||
self.context = Context()
|
||||
|
@ -68,10 +71,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
assert entry_data.device_info is not None
|
||||
self.device_info = entry_data.device_info
|
||||
|
||||
self.queue = asyncio.Queue()
|
||||
self.queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self.handle_event = handle_event
|
||||
self.handle_finished = handle_finished
|
||||
self._tts_done = asyncio.Event()
|
||||
self.audio_timeout = audio_timeout
|
||||
|
||||
async def start_server(self) -> int:
|
||||
"""Start accepting connections."""
|
||||
|
@ -80,7 +84,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
"""Accept connection."""
|
||||
if self.started:
|
||||
raise RuntimeError("Can only start once")
|
||||
if self.queue is None:
|
||||
if self.stopped:
|
||||
raise RuntimeError("No longer accepting connections")
|
||||
|
||||
self.started = True
|
||||
|
@ -105,12 +109,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
@callback
|
||||
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
|
||||
"""Handle incoming UDP packet."""
|
||||
if not self.started:
|
||||
if not self.started or self.stopped:
|
||||
return
|
||||
if self.remote_addr is None:
|
||||
self.remote_addr = addr
|
||||
if self.queue is not None:
|
||||
self.queue.put_nowait(data)
|
||||
self.queue.put_nowait(data)
|
||||
|
||||
def error_received(self, exc: Exception) -> None:
|
||||
"""Handle when a send or receive operation raises an OSError.
|
||||
|
@ -123,21 +126,21 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
@callback
|
||||
def stop(self) -> None:
|
||||
"""Stop the receiver."""
|
||||
if self.queue is not None:
|
||||
self.queue.put_nowait(b"")
|
||||
self.queue.put_nowait(b"")
|
||||
self.started = False
|
||||
self.stopped = True
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the receiver."""
|
||||
if self.queue is not None:
|
||||
self.queue = None
|
||||
self.started = False
|
||||
self.stopped = True
|
||||
if self.transport is not None:
|
||||
self.transport.close()
|
||||
|
||||
async def _iterate_packets(self) -> AsyncIterable[bytes]:
|
||||
"""Iterate over incoming packets."""
|
||||
if self.queue is None:
|
||||
raise RuntimeError("Already stopped")
|
||||
if not self.started or self.stopped:
|
||||
raise RuntimeError("Not running")
|
||||
|
||||
while data := await self.queue.get():
|
||||
yield data
|
||||
|
@ -152,9 +155,15 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
return
|
||||
|
||||
data_to_send = None
|
||||
error = False
|
||||
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
|
||||
assert event.data is not None
|
||||
data_to_send = {"text": event.data["stt_output"]["text"]}
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
|
||||
assert event.data is not None
|
||||
data_to_send = {
|
||||
"conversation_id": event.data["intent_output"]["conversation_id"] or "",
|
||||
}
|
||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
|
||||
assert event.data is not None
|
||||
data_to_send = {"text": event.data["tts_input"]}
|
||||
|
@ -177,19 +186,132 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
"code": event.data["code"],
|
||||
"message": event.data["message"],
|
||||
}
|
||||
self.handle_finished()
|
||||
self._tts_done.set()
|
||||
error = True
|
||||
|
||||
self.handle_event(event_type, data_to_send)
|
||||
if error:
|
||||
self.handle_finished()
|
||||
|
||||
async def _wait_for_speech(
|
||||
self,
|
||||
segmenter: VoiceCommandSegmenter,
|
||||
chunk_buffer: MutableSequence[bytes],
|
||||
) -> bool:
|
||||
"""Buffer audio chunks until speech is detected.
|
||||
|
||||
Raises asyncio.TimeoutError if no audio data is retrievable from the queue (device stops sending packets / networking issue).
|
||||
|
||||
Returns True if speech was detected
|
||||
Returns False if the connection was stopped gracefully (b"" put onto the queue).
|
||||
"""
|
||||
# Timeout if no audio comes in for a while.
|
||||
async with async_timeout.timeout(self.audio_timeout):
|
||||
chunk = await self.queue.get()
|
||||
|
||||
while chunk:
|
||||
segmenter.process(chunk)
|
||||
# Buffer the data we have taken from the queue
|
||||
chunk_buffer.append(chunk)
|
||||
if segmenter.in_command:
|
||||
return True
|
||||
|
||||
async with async_timeout.timeout(self.audio_timeout):
|
||||
chunk = await self.queue.get()
|
||||
|
||||
# If chunk is falsey, `stop()` was called
|
||||
return False
|
||||
|
||||
async def _segment_audio(
|
||||
self,
|
||||
segmenter: VoiceCommandSegmenter,
|
||||
chunk_buffer: Sequence[bytes],
|
||||
) -> AsyncIterable[bytes]:
|
||||
"""Yield audio chunks until voice command has finished.
|
||||
|
||||
Raises asyncio.TimeoutError if no audio data is retrievable from the queue.
|
||||
"""
|
||||
# Buffered chunks first
|
||||
for buffered_chunk in chunk_buffer:
|
||||
yield buffered_chunk
|
||||
|
||||
# Timeout if no audio comes in for a while.
|
||||
async with async_timeout.timeout(self.audio_timeout):
|
||||
chunk = await self.queue.get()
|
||||
|
||||
while chunk:
|
||||
if not segmenter.process(chunk):
|
||||
# Voice command is finished
|
||||
break
|
||||
|
||||
yield chunk
|
||||
|
||||
async with async_timeout.timeout(self.audio_timeout):
|
||||
chunk = await self.queue.get()
|
||||
|
||||
async def _iterate_packets_with_vad(
|
||||
self, pipeline_timeout: float
|
||||
) -> Callable[[], AsyncIterable[bytes]] | None:
|
||||
segmenter = VoiceCommandSegmenter()
|
||||
chunk_buffer: deque[bytes] = deque(maxlen=100)
|
||||
try:
|
||||
async with async_timeout.timeout(pipeline_timeout):
|
||||
speech_detected = await self._wait_for_speech(segmenter, chunk_buffer)
|
||||
if not speech_detected:
|
||||
_LOGGER.debug(
|
||||
"Device stopped sending audio before speech was detected"
|
||||
)
|
||||
self.handle_finished()
|
||||
return None
|
||||
except asyncio.TimeoutError:
|
||||
self.handle_event(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||
{
|
||||
"code": "speech-timeout",
|
||||
"message": "Timed out waiting for speech",
|
||||
},
|
||||
)
|
||||
self.handle_finished()
|
||||
return None
|
||||
|
||||
async def _stream_packets() -> AsyncIterable[bytes]:
|
||||
try:
|
||||
async for chunk in self._segment_audio(segmenter, chunk_buffer):
|
||||
yield chunk
|
||||
except asyncio.TimeoutError:
|
||||
self.handle_event(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||
{
|
||||
"code": "speech-timeout",
|
||||
"message": "No speech detected",
|
||||
},
|
||||
)
|
||||
self.handle_finished()
|
||||
|
||||
return _stream_packets
|
||||
|
||||
async def run_pipeline(
|
||||
self,
|
||||
conversation_id: str | None,
|
||||
use_vad: bool = False,
|
||||
pipeline_timeout: float = 30.0,
|
||||
) -> None:
|
||||
"""Run the Voice Assistant pipeline."""
|
||||
|
||||
tts_audio_output = (
|
||||
"raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
|
||||
)
|
||||
|
||||
if use_vad:
|
||||
stt_stream = await self._iterate_packets_with_vad(pipeline_timeout)
|
||||
# Error or timeout occurred and was handled already
|
||||
if stt_stream is None:
|
||||
return
|
||||
else:
|
||||
stt_stream = self._iterate_packets
|
||||
|
||||
_LOGGER.debug("Starting pipeline")
|
||||
try:
|
||||
tts_audio_output = (
|
||||
"raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
|
||||
)
|
||||
async with async_timeout.timeout(pipeline_timeout):
|
||||
await async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
|
@ -203,10 +325,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=self._iterate_packets(),
|
||||
stt_stream=stt_stream(),
|
||||
pipeline_id=pipeline_select.get_chosen_pipeline(
|
||||
self.hass, DOMAIN, self.device_info.mac_address
|
||||
),
|
||||
conversation_id=conversation_id,
|
||||
tts_audio_output=tts_audio_output,
|
||||
)
|
||||
|
||||
|
@ -215,6 +338,13 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
except asyncio.TimeoutError:
|
||||
self.handle_event(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||
{
|
||||
"code": "pipeline-timeout",
|
||||
"message": "Pipeline timeout",
|
||||
},
|
||||
)
|
||||
_LOGGER.warning("Pipeline timeout")
|
||||
finally:
|
||||
self.handle_finished()
|
||||
|
|
|
@ -159,7 +159,7 @@ aioecowitt==2023.5.0
|
|||
aioemonitor==1.0.5
|
||||
|
||||
# homeassistant.components.esphome
|
||||
aioesphomeapi==13.7.5
|
||||
aioesphomeapi==13.9.0
|
||||
|
||||
# homeassistant.components.flo
|
||||
aioflo==2021.11.0
|
||||
|
|
|
@ -149,7 +149,7 @@ aioecowitt==2023.5.0
|
|||
aioemonitor==1.0.5
|
||||
|
||||
# homeassistant.components.esphome
|
||||
aioesphomeapi==13.7.5
|
||||
aioesphomeapi==13.9.0
|
||||
|
||||
# homeassistant.components.flo
|
||||
aioflo==2021.11.0
|
||||
|
|
|
@ -132,70 +132,51 @@ async def mock_dashboard(hass):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_voice_assistant_v1_entry(
|
||||
async def mock_voice_assistant_entry(
|
||||
hass: HomeAssistant,
|
||||
mock_client,
|
||||
) -> MockConfigEntry:
|
||||
"""Set up an ESPHome entry with voice assistant."""
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_HOST: "test.local",
|
||||
CONF_PORT: 6053,
|
||||
CONF_PASSWORD: "",
|
||||
},
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
device_info = DeviceInfo(
|
||||
name="test",
|
||||
friendly_name="Test",
|
||||
voice_assistant_version=1,
|
||||
mac_address="11:22:33:44:55:aa",
|
||||
esphome_version="1.0.0",
|
||||
)
|
||||
async def _mock_voice_assistant_entry(version: int):
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_HOST: "test.local",
|
||||
CONF_PORT: 6053,
|
||||
CONF_PASSWORD: "",
|
||||
},
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
mock_client.device_info = AsyncMock(return_value=device_info)
|
||||
mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock())
|
||||
device_info = DeviceInfo(
|
||||
name="test",
|
||||
friendly_name="Test",
|
||||
voice_assistant_version=version,
|
||||
mac_address="11:22:33:44:55:aa",
|
||||
esphome_version="1.0.0",
|
||||
)
|
||||
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
mock_client.device_info = AsyncMock(return_value=device_info)
|
||||
mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock())
|
||||
|
||||
return entry
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
return entry
|
||||
|
||||
return _mock_voice_assistant_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_voice_assistant_v2_entry(
|
||||
hass: HomeAssistant,
|
||||
mock_client,
|
||||
) -> MockConfigEntry:
|
||||
async def mock_voice_assistant_v1_entry(mock_voice_assistant_entry) -> MockConfigEntry:
|
||||
"""Set up an ESPHome entry with voice assistant."""
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_HOST: "test.local",
|
||||
CONF_PORT: 6053,
|
||||
CONF_PASSWORD: "",
|
||||
},
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
return await mock_voice_assistant_entry(version=1)
|
||||
|
||||
device_info = DeviceInfo(
|
||||
name="test",
|
||||
friendly_name="Test",
|
||||
voice_assistant_version=2,
|
||||
mac_address="11:22:33:44:55:aa",
|
||||
esphome_version="1.0.0",
|
||||
)
|
||||
|
||||
mock_client.device_info = AsyncMock(return_value=device_info)
|
||||
mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock())
|
||||
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
return entry
|
||||
@pytest.fixture
|
||||
async def mock_voice_assistant_v2_entry(mock_voice_assistant_entry) -> MockConfigEntry:
|
||||
"""Set up an ESPHome entry with voice assistant."""
|
||||
return await mock_voice_assistant_entry(version=2)
|
||||
|
|
|
@ -19,43 +19,47 @@ _TEST_OUTPUT_TEXT = "This is an output test"
|
|||
_TEST_OUTPUT_URL = "output.mp3"
|
||||
_TEST_MEDIA_ID = "12345"
|
||||
|
||||
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def voice_assistant_udp_server(
|
||||
hass: HomeAssistant,
|
||||
) -> VoiceAssistantUDPServer:
|
||||
"""Return the UDP server factory."""
|
||||
|
||||
def _voice_assistant_udp_server(entry):
|
||||
entry_data = DomainData.get(hass).get_entry_data(entry)
|
||||
|
||||
server: VoiceAssistantUDPServer = None
|
||||
|
||||
def handle_finished():
|
||||
nonlocal server
|
||||
assert server is not None
|
||||
server.close()
|
||||
|
||||
server = VoiceAssistantUDPServer(hass, entry_data, Mock(), handle_finished)
|
||||
return server
|
||||
|
||||
return _voice_assistant_udp_server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def voice_assistant_udp_server_v1(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_server,
|
||||
mock_voice_assistant_v1_entry,
|
||||
) -> VoiceAssistantUDPServer:
|
||||
"""Return the UDP server."""
|
||||
entry_data = DomainData.get(hass).get_entry_data(mock_voice_assistant_v1_entry)
|
||||
|
||||
server: VoiceAssistantUDPServer = None
|
||||
|
||||
def handle_finished():
|
||||
nonlocal server
|
||||
assert server is not None
|
||||
server.close()
|
||||
|
||||
server = VoiceAssistantUDPServer(hass, entry_data, Mock(), handle_finished)
|
||||
return server
|
||||
return voice_assistant_udp_server(entry=mock_voice_assistant_v1_entry)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def voice_assistant_udp_server_v2(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_server,
|
||||
mock_voice_assistant_v2_entry,
|
||||
) -> VoiceAssistantUDPServer:
|
||||
"""Return the UDP server."""
|
||||
entry_data = DomainData.get(hass).get_entry_data(mock_voice_assistant_v2_entry)
|
||||
|
||||
server: VoiceAssistantUDPServer = None
|
||||
|
||||
def handle_finished():
|
||||
nonlocal server
|
||||
assert server is not None
|
||||
server.close()
|
||||
|
||||
server = VoiceAssistantUDPServer(hass, entry_data, Mock(), handle_finished)
|
||||
return server
|
||||
return voice_assistant_udp_server(entry=mock_voice_assistant_v2_entry)
|
||||
|
||||
|
||||
async def test_pipeline_events(
|
||||
|
@ -117,7 +121,7 @@ async def test_pipeline_events(
|
|||
):
|
||||
voice_assistant_udp_server_v1.transport = Mock()
|
||||
|
||||
await voice_assistant_udp_server_v1.run_pipeline()
|
||||
await voice_assistant_udp_server_v1.run_pipeline(conversation_id=None)
|
||||
|
||||
|
||||
async def test_udp_server(
|
||||
|
@ -335,3 +339,136 @@ async def test_send_tts(
|
|||
await voice_assistant_udp_server_v2._tts_done.wait()
|
||||
|
||||
voice_assistant_udp_server_v2.transport.sendto.assert_called()
|
||||
|
||||
|
||||
async def test_speech_detection(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||
) -> None:
|
||||
"""Test the UDP server queues incoming data."""
|
||||
|
||||
def is_speech(self, chunk, sample_rate):
|
||||
"""Anything non-zero is speech."""
|
||||
return sum(chunk) > 0
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
event_callback = kwargs["event_callback"]
|
||||
async for _chunk in stt_stream:
|
||||
pass
|
||||
|
||||
# Test empty data
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.STT_END,
|
||||
data={"stt_output": {"text": _TEST_INPUT_TEXT}},
|
||||
)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"webrtcvad.Vad.is_speech",
|
||||
new=is_speech,
|
||||
), patch(
|
||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
):
|
||||
voice_assistant_udp_server_v2.started = True
|
||||
|
||||
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
|
||||
voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * _ONE_SECOND * 2))
|
||||
voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * _ONE_SECOND * 2))
|
||||
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
|
||||
|
||||
await voice_assistant_udp_server_v2.run_pipeline(
|
||||
conversation_id=None, use_vad=True, pipeline_timeout=1.0
|
||||
)
|
||||
|
||||
|
||||
async def test_no_speech(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||
) -> None:
|
||||
"""Test there is no speech."""
|
||||
|
||||
def is_speech(self, chunk, sample_rate):
|
||||
"""Anything non-zero is speech."""
|
||||
return sum(chunk) > 0
|
||||
|
||||
def handle_event(
|
||||
event_type: esphome.VoiceAssistantEventType, data: dict[str, str] | None
|
||||
) -> None:
|
||||
assert event_type == esphome.VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
|
||||
assert data is not None
|
||||
assert data["code"] == "speech-timeout"
|
||||
|
||||
voice_assistant_udp_server_v2.handle_event = handle_event
|
||||
|
||||
with patch(
|
||||
"webrtcvad.Vad.is_speech",
|
||||
new=is_speech,
|
||||
):
|
||||
voice_assistant_udp_server_v2.started = True
|
||||
|
||||
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
|
||||
|
||||
await voice_assistant_udp_server_v2.run_pipeline(
|
||||
conversation_id=None, use_vad=True, pipeline_timeout=1.0
|
||||
)
|
||||
|
||||
|
||||
async def test_speech_timeout(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||
) -> None:
|
||||
"""Test when speech was detected, but the pipeline times out."""
|
||||
|
||||
def is_speech(self, chunk, sample_rate):
|
||||
"""Anything non-zero is speech."""
|
||||
return sum(chunk) > 255
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
stt_stream = kwargs["stt_stream"]
|
||||
async for _chunk in stt_stream:
|
||||
# Stream will end when VAD detects end of "speech"
|
||||
pass
|
||||
|
||||
async def segment_audio(*args, **kwargs):
|
||||
raise asyncio.TimeoutError()
|
||||
async for chunk in []:
|
||||
yield chunk
|
||||
|
||||
with patch(
|
||||
"webrtcvad.Vad.is_speech",
|
||||
new=is_speech,
|
||||
), patch(
|
||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
), patch(
|
||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._segment_audio",
|
||||
new=segment_audio,
|
||||
):
|
||||
voice_assistant_udp_server_v2.started = True
|
||||
|
||||
voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * (_ONE_SECOND * 2)))
|
||||
|
||||
await voice_assistant_udp_server_v2.run_pipeline(
|
||||
conversation_id=None, use_vad=True, pipeline_timeout=1.0
|
||||
)
|
||||
|
||||
|
||||
async def test_cancelled(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||
) -> None:
|
||||
"""Test when the server is stopped while waiting for speech."""
|
||||
|
||||
voice_assistant_udp_server_v2.started = True
|
||||
|
||||
voice_assistant_udp_server_v2.queue.put_nowait(b"")
|
||||
|
||||
await voice_assistant_udp_server_v2.run_pipeline(
|
||||
conversation_id=None, use_vad=True, pipeline_timeout=1.0
|
||||
)
|
||||
|
||||
# No events should be sent if cancelled while waiting for speech
|
||||
voice_assistant_udp_server_v2.handle_event.assert_not_called()
|
||||
|
|
Loading…
Add table
Reference in a new issue