Add pipeline VAD events (#98603)

* Add stt-vad-start and stt-vad-end pipeline events

* Update tests
This commit is contained in:
Michael Hansen 2023-08-17 18:58:58 -05:00 committed by GitHub
parent c17f08a3f5
commit 49d2c60992
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 6 deletions

View file

@ -254,6 +254,8 @@ class PipelineEventType(StrEnum):
WAKE_WORD_START = "wake_word-start" WAKE_WORD_START = "wake_word-start"
WAKE_WORD_END = "wake_word-end" WAKE_WORD_END = "wake_word-end"
STT_START = "stt-start" STT_START = "stt-start"
STT_VAD_START = "stt-vad-start"
STT_VAD_END = "stt-vad-end"
STT_END = "stt-end" STT_END = "stt-end"
INTENT_START = "intent-start" INTENT_START = "intent-start"
INTENT_END = "intent-end" INTENT_END = "intent-end"
@ -612,11 +614,31 @@ class PipelineRun:
stream: AsyncIterable[bytes], stream: AsyncIterable[bytes],
) -> AsyncGenerator[bytes, None]: ) -> AsyncGenerator[bytes, None]:
"""Stop stream when voice command is finished.""" """Stop stream when voice command is finished."""
sent_vad_start = False
timestamp_ms = 0
async for chunk in stream: async for chunk in stream:
if not segmenter.process(chunk): if not segmenter.process(chunk):
# Silence detected at the end of voice command
self.process_event(
PipelineEvent(
PipelineEventType.STT_VAD_END,
{"timestamp": timestamp_ms},
)
)
break break
if segmenter.in_command and (not sent_vad_start):
# Speech detected at start of voice command
self.process_event(
PipelineEvent(
PipelineEventType.STT_VAD_START,
{"timestamp": timestamp_ms},
)
)
sent_vad_start = True
yield chunk yield chunk
timestamp_ms += (len(chunk) // 2) // 16 # milliseconds @ 16Khz
# Transcribe audio stream # Transcribe audio stream
result = await self.stt_provider.async_process_audio_stream( result = await self.stt_provider.async_process_audio_stream(

View file

@ -311,6 +311,12 @@
}), }),
'type': <PipelineEventType.STT_START: 'stt-start'>, 'type': <PipelineEventType.STT_START: 'stt-start'>,
}), }),
dict({
'data': dict({
'timestamp': 0,
}),
'type': <PipelineEventType.STT_VAD_START: 'stt-vad-start'>,
}),
dict({ dict({
'data': dict({ 'data': dict({
'stt_output': dict({ 'stt_output': dict({

View file

@ -40,7 +40,7 @@ async def test_pipeline_from_audio_stream_auto(
In this test, no pipeline is specified. In this test, no pipeline is specified.
""" """
events = [] events: list[assist_pipeline.PipelineEvent] = []
async def audio_data(): async def audio_data():
yield b"part1" yield b"part1"
@ -79,7 +79,7 @@ async def test_pipeline_from_audio_stream_legacy(
""" """
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
events = [] events: list[assist_pipeline.PipelineEvent] = []
async def audio_data(): async def audio_data():
yield b"part1" yield b"part1"
@ -139,7 +139,7 @@ async def test_pipeline_from_audio_stream_entity(
""" """
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
events = [] events: list[assist_pipeline.PipelineEvent] = []
async def audio_data(): async def audio_data():
yield b"part1" yield b"part1"
@ -199,7 +199,7 @@ async def test_pipeline_from_audio_stream_no_stt(
""" """
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
events = [] events: list[assist_pipeline.PipelineEvent] = []
async def audio_data(): async def audio_data():
yield b"part1" yield b"part1"
@ -257,7 +257,7 @@ async def test_pipeline_from_audio_stream_unknown_pipeline(
In this test, the pipeline does not exist. In this test, the pipeline does not exist.
""" """
events = [] events: list[assist_pipeline.PipelineEvent] = []
async def audio_data(): async def audio_data():
yield b"part1" yield b"part1"
@ -294,7 +294,7 @@ async def test_pipeline_from_audio_stream_wake_word(
) -> None: ) -> None:
"""Test creating a pipeline from an audio stream with wake word.""" """Test creating a pipeline from an audio stream with wake word."""
events = [] events: list[assist_pipeline.PipelineEvent] = []
# [0, 1, ...] # [0, 1, ...]
wake_chunk_1 = bytes(it.islice(it.cycle(range(256)), BYTES_ONE_SECOND)) wake_chunk_1 = bytes(it.islice(it.cycle(range(256)), BYTES_ONE_SECOND))