Test specifying pipeline in calls to async_pipeline_from_audio_stream (#91739)
This commit is contained in:
parent
0429b321b8
commit
a98be9dc84
3 changed files with 398 additions and 10 deletions
|
@ -1,18 +1,27 @@
|
|||
"""Test fixtures for voice assistant."""
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import AsyncIterable, Generator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import stt, tts
|
||||
from homeassistant.components.assist_pipeline import DOMAIN
|
||||
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockModule, mock_integration, mock_platform
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
MockModule,
|
||||
MockPlatform,
|
||||
mock_config_flow,
|
||||
mock_integration,
|
||||
mock_platform,
|
||||
)
|
||||
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
|
||||
mock_get_cache_files,
|
||||
mock_init_cache_dir,
|
||||
|
@ -21,7 +30,7 @@ from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unuse
|
|||
_TRANSCRIPT = "test transcript"
|
||||
|
||||
|
||||
class MockSttProvider(stt.Provider):
|
||||
class BaseProvider:
|
||||
"""Mock STT provider."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, text: str) -> None:
|
||||
|
@ -71,6 +80,14 @@ class MockSttProvider(stt.Provider):
|
|||
return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS)
|
||||
|
||||
|
||||
class MockSttProvider(BaseProvider, stt.Provider):
    """Mock STT engine exposed through the ``stt.Provider`` interface.

    The class body is empty: everything is inherited from ``BaseProvider``
    and ``stt.Provider``.
    """
|
||||
|
||||
|
||||
class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity):
    """Mock STT engine exposed as an ``stt.SpeechToTextEntity``.

    The class body is empty: everything is inherited from ``BaseProvider``
    and ``stt.SpeechToTextEntity``.
    """
|
||||
|
||||
|
||||
class MockTTSProvider(tts.Provider):
|
||||
"""Mock TTS provider."""
|
||||
|
||||
|
@ -119,27 +136,96 @@ async def mock_stt_provider(hass) -> MockSttProvider:
|
|||
return MockSttProvider(hass, _TRANSCRIPT)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_stt_provider_entity(hass) -> MockSttProviderEntity:
    """Return a ``MockSttProviderEntity`` built with the module transcript.

    The entity is constructed with ``hass`` and ``_TRANSCRIPT``; it is not
    added to Home Assistant by this fixture.
    """
    return MockSttProviderEntity(hass, _TRANSCRIPT)
|
||||
|
||||
|
||||
class MockSttPlatform(MockPlatform):
    """Provide a fake STT platform.

    Extends ``MockPlatform`` with an ``async_get_engine`` attribute.
    """

    def __init__(self, *, async_get_engine: Any, **kwargs: Any) -> None:
        """Initialize the stt platform.

        ``async_get_engine`` is keyword-only and stored on the instance as
        given. All remaining keyword arguments are forwarded unchanged to
        ``MockPlatform.__init__``.
        """
        super().__init__(**kwargs)
        # Assigned after the base initializer has run.
        self.async_get_engine = async_get_engine
|
||||
|
||||
|
||||
class MockFlow(ConfigFlow):
    """Test flow.

    Empty ``ConfigFlow`` subclass; ``config_flow_fixture`` passes it to
    ``mock_config_flow`` for the ``test`` domain.
    """
|
||||
|
||||
|
||||
@pytest.fixture
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
    """Mock config flow.

    Registers a ``test.config_flow`` platform, then keeps
    ``mock_config_flow("test", MockFlow)`` active while the requesting test
    runs.
    """
    mock_platform(hass, "test.config_flow")

    # Yielding inside the ``with`` keeps the context manager open for the
    # whole test; it exits when pytest resumes the generator at teardown.
    with mock_config_flow("test", MockFlow):
        yield
|
||||
|
||||
|
||||
@pytest.fixture
async def init_components(
    hass: HomeAssistant,
    mock_stt_provider: MockSttProvider,
    mock_stt_provider_entity: MockSttProviderEntity,
    config_flow_fixture,
    mock_get_cache_files,  # noqa: F811
    mock_init_cache_dir,  # noqa: F811
) -> None:
    """Initialize relevant components with empty configs.

    Registers a ``test`` integration whose config entry forwards setup and
    unload to the stt domain, registers ``test.tts`` and ``test.stt``
    platforms, sets up tts, stt, media_source and assist_pipeline, and
    finally sets up one ``test`` config entry so that
    ``mock_stt_provider_entity`` is added as an stt entity.
    """

    async def async_setup_entry_init(
        hass: HomeAssistant, config_entry: ConfigEntry
    ) -> bool:
        """Set up test config entry."""
        await hass.config_entries.async_forward_entry_setup(config_entry, stt.DOMAIN)
        return True

    async def async_unload_entry_init(
        hass: HomeAssistant, config_entry: ConfigEntry
    ) -> bool:
        """Unload test config entry."""
        await hass.config_entries.async_forward_entry_unload(config_entry, stt.DOMAIN)
        return True

    async def async_setup_entry_stt_platform(
        hass: HomeAssistant,
        config_entry: ConfigEntry,
        async_add_entities: AddEntitiesCallback,
    ) -> None:
        """Set up test stt platform via config entry."""
        async_add_entities([mock_stt_provider_entity])

    # A single registration of the "test" integration: the bare
    # ``MockModule(domain="test")`` call it supersedes is not repeated here.
    mock_integration(
        hass,
        MockModule(
            "test",
            async_setup_entry=async_setup_entry_init,
            async_unload_entry=async_unload_entry_init,
        ),
    )
    mock_platform(hass, "test.tts", MockTTS())
    # The stt platform serves both paths: ``async_get_engine`` hands out the
    # provider object, ``async_setup_entry`` adds the entity.
    mock_platform(
        hass,
        "test.stt",
        MockSttPlatform(
            async_get_engine=AsyncMock(return_value=mock_stt_provider),
            async_setup_entry=async_setup_entry_stt_platform,
        ),
    )
    mock_platform(hass, "test.config_flow")

    assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
    assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
    assert await async_setup_component(hass, "media_source", {})
    assert await async_setup_component(hass, "assist_pipeline", {})

    # Setting up the config entry runs async_setup_entry_init, which forwards
    # the entry to the stt domain.
    config_entry = MockConfigEntry(domain="test")
    config_entry.add_to_hass(hass)
    assert await hass.config_entries.async_setup(config_entry.entry_id)
    await hass.async_block_till_done()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# serializer version: 1
|
||||
# name: test_pipeline_from_audio_stream
|
||||
# name: test_pipeline_from_audio_stream_auto
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
|
@ -83,3 +83,171 @@
|
|||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_pipeline_from_audio_stream_entity
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'language': 'en-US',
|
||||
'pipeline': 'test_name',
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'test',
|
||||
'metadata': dict({
|
||||
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
|
||||
'channel': <AudioChannels.CHANNEL_MONO: 1>,
|
||||
'codec': <AudioCodecs.PCM: 'pcm'>,
|
||||
'format': <AudioFormats.WAV: 'wav'>,
|
||||
'language': 'en-US',
|
||||
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.STT_START: 'stt-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'stt_output': dict({
|
||||
'text': 'test transcript',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.STT_END: 'stt-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'homeassistant',
|
||||
'intent_input': 'test transcript',
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'code': 'no_intent_match',
|
||||
}),
|
||||
'language': 'en-US',
|
||||
'response_type': 'error',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': "Sorry, I couldn't understand that",
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'test',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_END: 'tts-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': None,
|
||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_pipeline_from_audio_stream_legacy
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'language': 'en-US',
|
||||
'pipeline': 'test_name',
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'test',
|
||||
'metadata': dict({
|
||||
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
|
||||
'channel': <AudioChannels.CHANNEL_MONO: 1>,
|
||||
'codec': <AudioCodecs.PCM: 'pcm'>,
|
||||
'format': <AudioFormats.WAV: 'wav'>,
|
||||
'language': 'en-US',
|
||||
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.STT_START: 'stt-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'stt_output': dict({
|
||||
'text': 'test transcript',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.STT_END: 'stt-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'homeassistant',
|
||||
'intent_input': 'test transcript',
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'conversation_id': None,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'code': 'no_intent_match',
|
||||
}),
|
||||
'language': 'en-US',
|
||||
'response_type': 'error',
|
||||
'speech': dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': "Sorry, I couldn't understand that",
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'test',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_END: 'tts-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': None,
|
||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
|
|
|
@ -6,11 +6,21 @@ from syrupy.assertion import SnapshotAssertion
|
|||
from homeassistant.components import assist_pipeline, stt
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
|
||||
from .conftest import MockSttProvider, MockSttProviderEntity
|
||||
|
||||
async def test_pipeline_from_audio_stream(
|
||||
hass: HomeAssistant, mock_stt_provider, init_components, snapshot: SnapshotAssertion
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
||||
async def test_pipeline_from_audio_stream_auto(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test creating a pipeline from an audio stream."""
|
||||
"""Test creating a pipeline from an audio stream.
|
||||
|
||||
In this test, no pipeline is specified.
|
||||
"""
|
||||
|
||||
events = []
|
||||
|
||||
|
@ -42,3 +52,127 @@ async def test_pipeline_from_audio_stream(
|
|||
|
||||
assert processed == snapshot
|
||||
assert mock_stt_provider.received == [b"part1", b"part2"]
|
||||
|
||||
|
||||
async def test_pipeline_from_audio_stream_legacy(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    mock_stt_provider: MockSttProvider,
    init_components,
    snapshot: SnapshotAssertion,
) -> None:
    """Test running an audio stream through an explicitly selected pipeline.

    The pipeline is created over the websocket API with the legacy stt
    engine "test" and is then passed by id to
    async_pipeline_from_audio_stream.
    """
    ws_client = await hass_ws_client(hass)

    async def audio_data():
        """Yield two audio chunks followed by an empty chunk."""
        for chunk in (b"part1", b"part2", b""):
            yield chunk

    # Create a pipeline using the legacy stt engine
    create_request = {
        "type": "assist_pipeline/pipeline/create",
        "conversation_engine": "homeassistant",
        "language": "en-US",
        "name": "test_name",
        "stt_engine": "test",
        "tts_engine": "test",
    }
    await ws_client.send_json_auto_id(create_request)
    response = await ws_client.receive_json()
    assert response["success"]
    pipeline_id = response["result"]["id"]

    # Use the created pipeline
    metadata = stt.SpeechMetadata(
        language="",
        format=stt.AudioFormats.WAV,
        codec=stt.AudioCodecs.PCM,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    captured = []
    await assist_pipeline.async_pipeline_from_audio_stream(
        hass,
        Context(),
        captured.append,
        metadata,
        audio_data(),
        pipeline_id=pipeline_id,
    )

    # Timestamps are removed before comparing with the snapshot.
    processed = []
    for pipeline_event in captured:
        event_dict = asdict(pipeline_event)
        del event_dict["timestamp"]
        processed.append(event_dict)

    assert processed == snapshot
    assert mock_stt_provider.received == [b"part1", b"part2"]
|
||||
|
||||
|
||||
async def test_pipeline_from_audio_stream_entity(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    mock_stt_provider_entity: MockSttProviderEntity,
    init_components,
    snapshot: SnapshotAssertion,
) -> None:
    """Test running an audio stream through an explicitly selected pipeline.

    The pipeline is created over the websocket API with the mock stt
    entity's entity_id as its stt engine and is then passed by id to
    async_pipeline_from_audio_stream.
    """
    ws_client = await hass_ws_client(hass)

    async def audio_data():
        """Yield two audio chunks followed by an empty chunk."""
        for chunk in (b"part1", b"part2", b""):
            yield chunk

    # Create a pipeline using an stt entity
    create_request = {
        "type": "assist_pipeline/pipeline/create",
        "conversation_engine": "homeassistant",
        "language": "en-US",
        "name": "test_name",
        "stt_engine": mock_stt_provider_entity.entity_id,
        "tts_engine": "test",
    }
    await ws_client.send_json_auto_id(create_request)
    response = await ws_client.receive_json()
    assert response["success"]
    pipeline_id = response["result"]["id"]

    # Use the created pipeline
    metadata = stt.SpeechMetadata(
        language="",
        format=stt.AudioFormats.WAV,
        codec=stt.AudioCodecs.PCM,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    captured = []
    await assist_pipeline.async_pipeline_from_audio_stream(
        hass,
        Context(),
        captured.append,
        metadata,
        audio_data(),
        pipeline_id=pipeline_id,
    )

    # Timestamps are removed before comparing with the snapshot.
    processed = []
    for pipeline_event in captured:
        event_dict = asdict(pipeline_event)
        del event_dict["timestamp"]
        processed.append(event_dict)

    assert processed == snapshot
    assert mock_stt_provider_entity.received == [b"part1", b"part2"]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue