Test specifying pipeline in calls to async_pipeline_from_audio_stream (#91739)
This commit is contained in:
parent
0429b321b8
commit
a98be9dc84
3 changed files with 398 additions and 10 deletions
|
@ -1,18 +1,27 @@
|
|||
"""Test fixtures for voice assistant."""
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import AsyncIterable, Generator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import stt, tts
|
||||
from homeassistant.components.assist_pipeline import DOMAIN
|
||||
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockModule, mock_integration, mock_platform
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
MockModule,
|
||||
MockPlatform,
|
||||
mock_config_flow,
|
||||
mock_integration,
|
||||
mock_platform,
|
||||
)
|
||||
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
|
||||
mock_get_cache_files,
|
||||
mock_init_cache_dir,
|
||||
|
@ -21,7 +30,7 @@ from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unuse
|
|||
_TRANSCRIPT = "test transcript"
|
||||
|
||||
|
||||
class MockSttProvider(stt.Provider):
|
||||
class BaseProvider:
|
||||
"""Mock STT provider."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, text: str) -> None:
|
||||
|
@ -71,6 +80,14 @@ class MockSttProvider(stt.Provider):
|
|||
return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS)
|
||||
|
||||
|
||||
class MockSttProvider(BaseProvider, stt.Provider):
|
||||
"""Mock provider."""
|
||||
|
||||
|
||||
class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity):
|
||||
"""Mock provider entity."""
|
||||
|
||||
|
||||
class MockTTSProvider(tts.Provider):
|
||||
"""Mock TTS provider."""
|
||||
|
||||
|
@ -119,27 +136,96 @@ async def mock_stt_provider(hass) -> MockSttProvider:
|
|||
return MockSttProvider(hass, _TRANSCRIPT)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stt_provider_entity(hass) -> MockSttProviderEntity:
|
||||
"""Test provider entity fixture."""
|
||||
return MockSttProviderEntity(hass, _TRANSCRIPT)
|
||||
|
||||
|
||||
class MockSttPlatform(MockPlatform):
|
||||
"""Provide a fake STT platform."""
|
||||
|
||||
def __init__(self, *, async_get_engine, **kwargs):
|
||||
"""Initialize the stt platform."""
|
||||
super().__init__(**kwargs)
|
||||
self.async_get_engine = async_get_engine
|
||||
|
||||
|
||||
class MockFlow(ConfigFlow):
|
||||
"""Test flow."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
|
||||
"""Mock config flow."""
|
||||
mock_platform(hass, "test.config_flow")
|
||||
|
||||
with mock_config_flow("test", MockFlow):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_components(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
mock_stt_provider_entity: MockSttProviderEntity,
|
||||
config_flow_fixture,
|
||||
mock_get_cache_files, # noqa: F811
|
||||
mock_init_cache_dir, # noqa: F811,
|
||||
):
|
||||
"""Initialize relevant components with empty configs."""
|
||||
mock_integration(hass, MockModule(domain="test"))
|
||||
|
||||
async def async_setup_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Set up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_setup(config_entry, stt.DOMAIN)
|
||||
return True
|
||||
|
||||
async def async_unload_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Unload up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_unload(config_entry, stt.DOMAIN)
|
||||
return True
|
||||
|
||||
async def async_setup_entry_stt_platform(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up test stt platform via config entry."""
|
||||
async_add_entities([mock_stt_provider_entity])
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
"test",
|
||||
async_setup_entry=async_setup_entry_init,
|
||||
async_unload_entry=async_unload_entry_init,
|
||||
),
|
||||
)
|
||||
mock_platform(hass, "test.tts", MockTTS())
|
||||
mock_platform(
|
||||
hass,
|
||||
"test.stt",
|
||||
Mock(async_get_engine=AsyncMock(return_value=mock_stt_provider)),
|
||||
MockSttPlatform(
|
||||
async_get_engine=AsyncMock(return_value=mock_stt_provider),
|
||||
async_setup_entry=async_setup_entry_stt_platform,
|
||||
),
|
||||
)
|
||||
mock_platform(hass, "test.config_flow")
|
||||
|
||||
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
|
||||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
|
||||
assert await async_setup_component(hass, "media_source", {})
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
|
||||
config_entry = MockConfigEntry(domain="test")
|
||||
config_entry.add_to_hass(hass)
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue