hass-core/tests/components/assist_pipeline/conftest.py

233 lines
6.9 KiB
Python

"""Test fixtures for voice assistant."""
from collections.abc import AsyncIterable, Generator
from typing import Any
from unittest.mock import AsyncMock
import pytest
from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
from tests.common import (
MockConfigEntry,
MockModule,
MockPlatform,
mock_config_flow,
mock_integration,
mock_platform,
)
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
mock_get_cache_files,
mock_init_cache_dir,
)
# Canned transcript: both mock STT fixtures pass this as ``text``, and
# BaseProvider.async_process_audio_stream returns it in every result.
_TRANSCRIPT = "test transcript"
class BaseProvider:
    """Shared speech-to-text behaviour for the mock STT provider and entity.

    Each capability property reports a single supported value. Audio passed
    to ``async_process_audio_stream`` is recorded in ``received`` and the
    result is always a success carrying ``text``.
    """

    def __init__(self, hass: HomeAssistant, text: str) -> None:
        """Remember ``hass`` and the canned transcript ``text``."""
        self.hass = hass
        self.text = text
        # Non-empty chunks read by async_process_audio_stream, in arrival order.
        self.received: list[bytes] = []

    @property
    def supported_languages(self) -> list[str]:
        """Languages the mock claims to understand."""
        return ["en-US"]

    @property
    def supported_formats(self) -> list[stt.AudioFormats]:
        """Container formats the mock claims to accept (WAV only)."""
        return [stt.AudioFormats.WAV]

    @property
    def supported_codecs(self) -> list[stt.AudioCodecs]:
        """Codecs the mock claims to accept (PCM only)."""
        return [stt.AudioCodecs.PCM]

    @property
    def supported_bit_rates(self) -> list[stt.AudioBitRates]:
        """Bit rates the mock claims to accept (BITRATE_16 only)."""
        return [stt.AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
        """Sample rates the mock claims to accept (SAMPLERATE_16000 only)."""
        return [stt.AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[stt.AudioChannels]:
        """Channel layouts the mock claims to accept (mono only)."""
        return [stt.AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> stt.SpeechResult:
        """Drain ``stream`` into ``received`` and report a successful transcript.

        Reading stops at the first empty chunk or when the stream ends.
        ``metadata`` is not inspected.
        """
        async for chunk in stream:
            if chunk:
                self.received.append(chunk)
                continue
            # An empty chunk is treated as end of audio.
            break
        return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS)
class MockSttProvider(BaseProvider, stt.Provider):
    """Legacy-style STT provider whose behaviour all comes from BaseProvider."""
class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity):
    """STT entity variant of the mock; behaviour comes from BaseProvider."""
class MockTTSProvider(tts.Provider):
    """Legacy text-to-speech provider that always synthesizes empty MP3 audio."""

    name = "Test"

    # NOTE(review): default_language ("en") is not listed in
    # supported_languages (["en-US"]); confirm this mismatch is intentional.
    @property
    def default_language(self) -> str:
        """Language code used when the caller does not request one."""
        return "en"

    @property
    def supported_languages(self) -> list[str]:
        """Language codes this provider advertises."""
        return ["en-US"]

    @property
    def supported_options(self) -> list[str]:
        """Option keys accepted by this provider, including the audio output option."""
        return ["voice", "age", tts.ATTR_AUDIO_OUTPUT]

    def get_tts_audio(
        self, message: str, language: str, options: dict[str, Any] | None = None
    ) -> tts.TtsAudioType:
        """Return an ``("mp3", b"")`` pair; message, language and options are ignored."""
        extension, payload = "mp3", b""
        return (extension, payload)
class MockTTS:
    """Stand-in for the ``tts`` platform module of the mocked ``test`` integration."""

    PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA

    async def async_get_engine(
        self,
        hass: HomeAssistant,
        config: ConfigType,
        discovery_info: DiscoveryInfoType | None = None,
    ) -> tts.Provider:
        """Build a fresh MockTTSProvider; all arguments are ignored."""
        provider = MockTTSProvider()
        return provider
@pytest.fixture
def mock_stt_provider(hass: HomeAssistant) -> MockSttProvider:
    """Return a legacy mock STT provider that transcribes audio to ``_TRANSCRIPT``.

    Nothing here is awaited, so this is a plain (synchronous) fixture, matching
    ``mock_stt_provider_entity``.
    """
    return MockSttProvider(hass, _TRANSCRIPT)
@pytest.fixture
def mock_stt_provider_entity(hass: HomeAssistant) -> MockSttProviderEntity:
    """Return a mock STT entity that transcribes audio to ``_TRANSCRIPT``."""
    return MockSttProviderEntity(hass, _TRANSCRIPT)
class MockSttPlatform(MockPlatform):
    """Fake ``stt`` platform module that additionally exposes ``async_get_engine``.

    Every keyword argument other than ``async_get_engine`` is passed straight
    through to MockPlatform.
    """

    def __init__(self, *, async_get_engine, **platform_kwargs):
        """Initialise MockPlatform, then attach the supplied engine factory."""
        super().__init__(**platform_kwargs)
        # Stored as an attribute so it can be looked up like a module-level
        # function (presumably by the stt component's legacy setup; confirm).
        self.async_get_engine = async_get_engine
class MockFlow(ConfigFlow):
    """Bare config flow handler registered for the ``test`` domain."""
@pytest.fixture
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
    """Register MockFlow as the ``test`` domain's config flow.

    A ``test.config_flow`` platform is mocked first. The flow registration
    stays in effect until the test using this fixture finishes.
    """
    mock_platform(hass, "test.config_flow")
    flow_registration = mock_config_flow("test", MockFlow)
    with flow_registration:
        yield None
@pytest.fixture
async def init_components(
    hass: HomeAssistant,
    mock_stt_provider: MockSttProvider,
    mock_stt_provider_entity: MockSttProviderEntity,
    config_flow_fixture: None,
    mock_get_cache_files,  # noqa: F811
    mock_init_cache_dir,  # noqa: F811
) -> None:
    """Set up the components the pipeline tests need, backed by a mock ``test`` integration.

    Mocks a ``test`` integration with ``tts`` and ``stt`` platforms. Sets up
    the tts and stt components with ``{"platform": "test"}``, and the
    media_source and assist_pipeline components with empty configs. Finally,
    adds one ``test`` config entry and sets it up.
    """

    async def async_setup_entry_init(
        hass: HomeAssistant, config_entry: ConfigEntry
    ) -> bool:
        """Set up the test config entry by forwarding it to the stt platform."""
        # NOTE(review): newer Home Assistant releases may deprecate
        # async_forward_entry_setup in favour of async_forward_entry_setups;
        # confirm against the targeted HA version.
        await hass.config_entries.async_forward_entry_setup(config_entry, stt.DOMAIN)
        return True

    async def async_unload_entry_init(
        hass: HomeAssistant, config_entry: ConfigEntry
    ) -> bool:
        """Unload the test config entry from the stt platform.

        Always reports success; the forwarded unload's result is not checked.
        """
        await hass.config_entries.async_forward_entry_unload(config_entry, stt.DOMAIN)
        return True

    async def async_setup_entry_stt_platform(
        hass: HomeAssistant,
        config_entry: ConfigEntry,
        async_add_entities: AddEntitiesCallback,
    ) -> None:
        """Add the mock STT entity when the stt platform is set up for an entry."""
        async_add_entities([mock_stt_provider_entity])

    mock_integration(
        hass,
        MockModule(
            "test",
            async_setup_entry=async_setup_entry_init,
            async_unload_entry=async_unload_entry_init,
        ),
    )
    mock_platform(hass, "test.tts", MockTTS())
    # The stt platform serves both paths: the legacy provider via
    # async_get_engine and the entity via the config entry's setup.
    mock_platform(
        hass,
        "test.stt",
        MockSttPlatform(
            async_get_engine=AsyncMock(return_value=mock_stt_provider),
            async_setup_entry=async_setup_entry_stt_platform,
        ),
    )
    # NOTE(review): config_flow_fixture already mocks this platform. The
    # repeat may be needed after mock_integration above replaced the "test"
    # integration; confirm before removing.
    mock_platform(hass, "test.config_flow")
    assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
    assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
    assert await async_setup_component(hass, "media_source", {})
    assert await async_setup_component(hass, "assist_pipeline", {})
    # Setting up the entry runs async_setup_entry_init, which forwards to the
    # stt platform and so adds mock_stt_provider_entity.
    config_entry = MockConfigEntry(domain="test")
    config_entry.add_to_hass(hass)
    assert await hass.config_entries.async_setup(config_entry.entry_id)
    await hass.async_block_till_done()
@pytest.fixture
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
    """Expose the assist_pipeline pipeline storage collection to tests.

    Depending on ``init_components`` ensures assist_pipeline has been set up
    before its domain data is read.
    """
    domain_data = hass.data[DOMAIN]
    return domain_data.pipeline_store