hass-core/tests/components/assist_pipeline/conftest.py
Paulus Schoutsen bd22e0bd43
Allow picking a pipeline for voip devices (#91524)
* Allow picking a pipeline for voip device

* Add tests

* Fix test

* Adjust on new pipeline data
2023-04-17 12:09:11 -05:00

147 lines
4.5 KiB
Python

"""Test fixtures for voice assistant."""
from collections.abc import AsyncIterable
from typing import Any
from unittest.mock import AsyncMock, Mock
import pytest
from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
from tests.common import MockModule, mock_integration, mock_platform
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
mock_get_cache_files,
mock_init_cache_dir,
)
_TRANSCRIPT = "test transcript"
class MockSttProvider(stt.Provider):
    """Speech-to-text provider stub that answers with a canned transcript.

    Every non-empty chunk read from an audio stream (up to the first empty
    one) is recorded in ``received`` so tests can check what audio arrived.
    """

    def __init__(self, hass: HomeAssistant, text: str) -> None:
        """Store ``hass`` and the transcript returned for every stream."""
        self.hass, self.text = hass, text
        # Chunks collected by async_process_audio_stream, in arrival order.
        self.received: list[bytes] = []

    @property
    def supported_languages(self) -> list[str]:
        """Advertise US English as the only language."""
        return ["en-US"]

    @property
    def supported_formats(self) -> list[stt.AudioFormats]:
        """Advertise WAV as the only container format."""
        return [stt.AudioFormats.WAV]

    @property
    def supported_codecs(self) -> list[stt.AudioCodecs]:
        """Advertise PCM as the only codec."""
        return [stt.AudioCodecs.PCM]

    @property
    def supported_bit_rates(self) -> list[stt.AudioBitRates]:
        """Advertise 16-bit samples only."""
        return [stt.AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
        """Advertise a 16 kHz sample rate only."""
        return [stt.AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[stt.AudioChannels]:
        """Advertise mono audio only."""
        return [stt.AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> stt.SpeechResult:
        """Record incoming audio and answer with the configured transcript.

        ``metadata`` is not inspected. Reading stops at the first empty chunk
        or when the stream is exhausted; the result is always a success.
        """
        async for chunk in stream:
            if chunk:
                self.received.append(chunk)
            else:
                # Stop consuming at the first empty chunk.
                break
        return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS)
class MockTTSProvider(tts.Provider):
    """Text-to-speech provider stub that produces an empty MP3 payload."""

    # Class-level name of this mock provider.
    name = "Test"

    @property
    def default_language(self) -> str:
        """Report ``"en"`` as the default language."""
        # NOTE(review): "en" is not listed in supported_languages (["en-US"]);
        # confirm this mismatch is intentional before relying on the default.
        return "en"

    @property
    def supported_languages(self) -> list[str]:
        """Advertise US English as the only language."""
        return ["en-US"]

    @property
    def supported_options(self) -> list[str]:
        """Advertise the voice, age and audio-output option keys."""
        return ["voice", "age", tts.ATTR_AUDIO_OUTPUT]

    def get_tts_audio(
        self, message: str, language: str, options: dict[str, Any] | None = None
    ) -> tts.TtsAudioType:
        """Return an empty MP3 payload, ignoring message, language and options."""
        extension, audio = "mp3", b""
        return extension, audio
class MockTTS:
    """Stand-in for a ``tts`` platform, registered with ``mock_platform``."""

    # Reuse the tts component's own platform schema.
    PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA

    async def async_get_engine(
        self,
        hass: HomeAssistant,
        config: ConfigType,
        discovery_info: DiscoveryInfoType | None = None,
    ) -> tts.Provider:
        """Build a fresh ``MockTTSProvider``.

        ``hass``, ``config`` and ``discovery_info`` are accepted to match the
        platform signature but are not used.
        """
        provider = MockTTSProvider()
        return provider
@pytest.fixture
async def mock_stt_provider(hass: HomeAssistant) -> MockSttProvider:
    """Mock STT provider.

    Returns a fresh ``MockSttProvider`` whose transcript is ``_TRANSCRIPT``;
    it records the audio chunks it receives in ``received``.
    """
    return MockSttProvider(hass, _TRANSCRIPT)
@pytest.fixture
async def init_components(
    hass: HomeAssistant,
    mock_stt_provider: MockSttProvider,
    mock_get_cache_files,  # noqa: F811
    mock_init_cache_dir,  # noqa: F811
) -> None:
    """Initialize relevant components with empty configs.

    Registers a mock ``test`` integration whose ``tts`` platform is a
    ``MockTTS`` instance and whose ``stt`` platform hands out
    ``mock_stt_provider``. It then sets up the ``tts``, ``stt``,
    ``media_source`` and ``assist_pipeline`` components, asserting each
    setup succeeds.

    ``mock_get_cache_files`` and ``mock_init_cache_dir`` come from the tts
    conftest and are requested only for their side effects.
    """
    mock_integration(hass, MockModule(domain="test"))
    mock_platform(hass, "test.tts", MockTTS())
    # The STT platform is a bare Mock whose async_get_engine returns the same
    # ``mock_stt_provider`` instance that tests receive from that fixture.
    mock_platform(
        hass,
        "test.stt",
        Mock(async_get_engine=AsyncMock(return_value=mock_stt_provider)),
    )
    assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
    assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
    assert await async_setup_component(hass, "media_source", {})
    assert await async_setup_component(hass, DOMAIN, {})
@pytest.fixture
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
    """Expose the assist_pipeline pipeline storage collection.

    ``init_components`` is requested only so the components are set up before
    the store is read from ``hass.data``.
    """
    pipeline_data = hass.data[DOMAIN]
    return pipeline_data.pipeline_store