Use pipeline ID in event (#92100)

* Use pipeline ID in event

* Fix tests
This commit is contained in:
Paulus Schoutsen 2023-04-26 22:40:17 -04:00 committed by GitHub
parent 32ed45084a
commit 7c696754ed
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 40 additions and 38 deletions

View file

@ -1,5 +1,6 @@
"""Test Voice Assistant init."""
from dataclasses import asdict
from unittest.mock import ANY
import pytest
from syrupy.assertion import SnapshotAssertion
@ -12,6 +13,19 @@ from .conftest import MockSttProvider, MockSttProviderEntity
from tests.typing import WebSocketGenerator
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
"""Process events to remove dynamic values."""
processed = []
for event in events:
as_dict = asdict(event)
as_dict.pop("timestamp")
if as_dict["type"] == assist_pipeline.PipelineEventType.RUN_START:
as_dict["data"]["pipeline"] = ANY
processed.append(as_dict)
return processed
async def test_pipeline_from_audio_stream_auto(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
@ -45,13 +59,7 @@ async def test_pipeline_from_audio_stream_auto(
audio_data(),
)
processed = []
for event in events:
as_dict = asdict(event)
as_dict.pop("timestamp")
processed.append(as_dict)
assert processed == snapshot
assert process_events(events) == snapshot
assert mock_stt_provider.received == [b"part1", b"part2"]
@ -111,13 +119,7 @@ async def test_pipeline_from_audio_stream_legacy(
pipeline_id=pipeline_id,
)
processed = []
for event in events:
as_dict = asdict(event)
as_dict.pop("timestamp")
processed.append(as_dict)
assert processed == snapshot
assert process_events(events) == snapshot
assert mock_stt_provider.received == [b"part1", b"part2"]
@ -177,13 +179,7 @@ async def test_pipeline_from_audio_stream_entity(
pipeline_id=pipeline_id,
)
processed = []
for event in events:
as_dict = asdict(event)
as_dict.pop("timestamp")
processed.append(as_dict)
assert processed == snapshot
assert process_events(events) == snapshot
assert mock_stt_provider_entity.received == [b"part1", b"part2"]