Test specifying pipeline in calls to async_pipeline_from_audio_stream (#91739)

This commit is contained in:
Erik Montnemery 2023-04-20 15:01:31 +02:00 committed by GitHub
parent 0429b321b8
commit a98be9dc84
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 398 additions and 10 deletions

View file

@ -1,18 +1,27 @@
"""Test fixtures for voice assistant."""
from collections.abc import AsyncIterable
from collections.abc import AsyncIterable, Generator
from typing import Any
from unittest.mock import AsyncMock, Mock
from unittest.mock import AsyncMock
import pytest
from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
from tests.common import MockModule, mock_integration, mock_platform
from tests.common import (
MockConfigEntry,
MockModule,
MockPlatform,
mock_config_flow,
mock_integration,
mock_platform,
)
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
mock_get_cache_files,
mock_init_cache_dir,
@ -21,7 +30,7 @@ from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unuse
_TRANSCRIPT = "test transcript"
class MockSttProvider(stt.Provider):
class BaseProvider:
"""Mock STT provider."""
def __init__(self, hass: HomeAssistant, text: str) -> None:
@ -71,6 +80,14 @@ class MockSttProvider(stt.Provider):
return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS)
class MockSttProvider(BaseProvider, stt.Provider):
"""Mock provider."""
class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity):
"""Mock provider entity."""
class MockTTSProvider(tts.Provider):
"""Mock TTS provider."""
@ -119,27 +136,96 @@ async def mock_stt_provider(hass) -> MockSttProvider:
return MockSttProvider(hass, _TRANSCRIPT)
@pytest.fixture
def mock_stt_provider_entity(hass) -> MockSttProviderEntity:
"""Test provider entity fixture."""
return MockSttProviderEntity(hass, _TRANSCRIPT)
class MockSttPlatform(MockPlatform):
"""Provide a fake STT platform."""
def __init__(self, *, async_get_engine, **kwargs):
"""Initialize the stt platform."""
super().__init__(**kwargs)
self.async_get_engine = async_get_engine
class MockFlow(ConfigFlow):
"""Test flow."""
@pytest.fixture
def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
"""Mock config flow."""
mock_platform(hass, "test.config_flow")
with mock_config_flow("test", MockFlow):
yield
@pytest.fixture
async def init_components(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_stt_provider_entity: MockSttProviderEntity,
config_flow_fixture,
mock_get_cache_files, # noqa: F811
mock_init_cache_dir, # noqa: F811,
):
"""Initialize relevant components with empty configs."""
mock_integration(hass, MockModule(domain="test"))
async def async_setup_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Set up test config entry."""
await hass.config_entries.async_forward_entry_setup(config_entry, stt.DOMAIN)
return True
async def async_unload_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Unload up test config entry."""
await hass.config_entries.async_forward_entry_unload(config_entry, stt.DOMAIN)
return True
async def async_setup_entry_stt_platform(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up test stt platform via config entry."""
async_add_entities([mock_stt_provider_entity])
mock_integration(
hass,
MockModule(
"test",
async_setup_entry=async_setup_entry_init,
async_unload_entry=async_unload_entry_init,
),
)
mock_platform(hass, "test.tts", MockTTS())
mock_platform(
hass,
"test.stt",
Mock(async_get_engine=AsyncMock(return_value=mock_stt_provider)),
MockSttPlatform(
async_get_engine=AsyncMock(return_value=mock_stt_provider),
async_setup_entry=async_setup_entry_stt_platform,
),
)
mock_platform(hass, "test.config_flow")
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
assert await async_setup_component(hass, "media_source", {})
assert await async_setup_component(hass, "assist_pipeline", {})
config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
@pytest.fixture
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection: