Don't resolve default stt engine in assist pipelines (#91936)

* Don't resolve default stt engine in assist pipelines

* Apply suggestion from code review

* Add tests

* Tweak

* Add test

* Improve test coverage
This commit is contained in:
Erik Montnemery 2023-04-24 13:37:13 +02:00 committed by GitHub
parent 0d815a1688
commit c5d0c392a9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 192 additions and 24 deletions

View file

@ -83,14 +83,16 @@ async def async_get_pipeline(
if pipeline_id is None:
# There's no preferred pipeline, construct a pipeline for the
# configured language
stt_engine = stt.async_default_provider(hass)
stt_language = hass.config.language if stt_engine else None
return await pipeline_data.pipeline_store.async_create_item(
{
"conversation_engine": None,
"conversation_language": None,
"language": hass.config.language,
"name": hass.config.language,
"stt_engine": None,
"stt_language": None,
"stt_engine": stt_engine,
"stt_language": stt_language,
"tts_engine": None,
"tts_language": None,
"tts_voice": None,
@ -261,22 +263,14 @@ class PipelineRun:
"""Prepare speech to text."""
stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None
if self.pipeline.stt_engine is not None:
# Try entity first
stt_provider = stt.async_get_speech_to_text_entity(
# pipeline.stt_engine can't be None or this function is not called
stt_provider = stt.async_get_speech_to_text_engine(
self.hass,
self.pipeline.stt_engine,
self.pipeline.stt_engine, # type: ignore[arg-type]
)
if stt_provider is None:
# Try legacy provider second
stt_provider = stt.async_get_provider(
self.hass,
self.pipeline.stt_engine,
)
if stt_provider is None:
engine = self.pipeline.stt_engine or "default"
engine = self.pipeline.stt_engine
raise SpeechToTextError(
code="stt-provider-missing",
message=f"No speech to text provider for: {engine}",
@ -580,11 +574,14 @@ class PipelineInput:
async def validate(self) -> None:
"""Validate pipeline input against start stage."""
if self.run.start_stage == PipelineStage.STT:
if self.run.pipeline.stt_engine is None:
raise PipelineRunValidationError(
"the pipeline does not support speech to text"
)
if self.stt_metadata is None:
raise PipelineRunValidationError(
"stt_metadata is required for speech to text"
)
if self.stt_stream is None:
raise PipelineRunValidationError(
"stt_stream is required for speech to text"

View file

@ -41,12 +41,14 @@ from .legacy import (
Provider,
SpeechMetadata,
SpeechResult,
async_default_provider,
async_get_provider,
async_setup_legacy,
)
__all__ = [
"async_get_provider",
"async_get_speech_to_text_engine",
"async_get_speech_to_text_entity",
"AudioBitRates",
"AudioChannels",
@ -64,6 +66,14 @@ __all__ = [
_LOGGER = logging.getLogger(__name__)
@callback
def async_default_engine(hass: HomeAssistant) -> str | None:
"""Return the domain or entity id of the default engine."""
return async_default_provider(hass) or next(
iter(hass.states.async_entity_ids(DOMAIN)), None
)
@callback
def async_get_speech_to_text_entity(
hass: HomeAssistant, entity_id: str
@ -74,6 +84,16 @@ def async_get_speech_to_text_entity(
return component.get_entity(entity_id)
@callback
def async_get_speech_to_text_engine(
hass: HomeAssistant, engine_id: str
) -> SpeechToTextEntity | Provider | None:
"""Return stt entity or legacy provider."""
if entity := async_get_speech_to_text_entity(hass, engine_id):
return entity
return async_get_provider(hass, engine_id)
@callback
def async_get_speech_to_text_languages(hass: HomeAssistant) -> set[str]:
"""Return a set with the union of languages supported by stt engines."""

View file

@ -27,6 +27,15 @@ from .const import (
_LOGGER = logging.getLogger(__name__)
@callback
def async_default_provider(hass: HomeAssistant) -> str | None:
"""Return the domain of the default provider."""
if "cloud" in hass.data[DATA_PROVIDERS]:
return "cloud"
return next(iter(hass.data[DATA_PROVIDERS]), None)
@callback
def async_get_provider(
hass: HomeAssistant, domain: str | None = None
@ -35,13 +44,8 @@ def async_get_provider(
if domain:
return hass.data[DATA_PROVIDERS].get(domain)
if not hass.data[DATA_PROVIDERS]:
return None
if "cloud" in hass.data[DATA_PROVIDERS]:
return hass.data[DATA_PROVIDERS]["cloud"]
return next(iter(hass.data[DATA_PROVIDERS].values()))
provider = async_default_provider(hass)
return hass.data[DATA_PROVIDERS][provider] if provider is not None else None
@callback

View file

@ -1,6 +1,7 @@
"""Test Voice Assistant init."""
from dataclasses import asdict
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import assist_pipeline, stt
@ -184,3 +185,63 @@ async def test_pipeline_from_audio_stream_entity(
assert processed == snapshot
assert mock_stt_provider_entity.received == [b"part1", b"part2"]
async def test_pipeline_from_audio_stream_no_stt(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_stt_provider: MockSttProvider,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test creating a pipeline from an audio stream.
In this test, the pipeline does not support stt
"""
client = await hass_ws_client(hass)
events = []
async def audio_data():
yield b"part1"
yield b"part2"
yield b""
# Create a pipeline without stt support
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "homeassistant",
"conversation_language": "en-US",
"language": "en",
"name": "test_name",
"stt_engine": None,
"stt_language": None,
"tts_engine": "test",
"tts_language": "en-AU",
"tts_voice": "Arnold Schwarzenegger",
}
)
msg = await client.receive_json()
assert msg["success"]
pipeline_id = msg["result"]["id"]
# Try to use the created pipeline
with pytest.raises(assist_pipeline.pipeline.PipelineRunValidationError):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
Context(),
events.append,
stt.SpeechMetadata(
language="en-UK",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
audio_data(),
pipeline_id=pipeline_id,
)
assert not events

View file

@ -18,7 +18,9 @@ from homeassistant.components.stt import (
SpeechResult,
SpeechResultState,
SpeechToTextEntity,
async_default_engine,
async_get_provider,
async_get_speech_to_text_engine,
)
from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow
from homeassistant.core import HomeAssistant, State
@ -349,6 +351,9 @@ async def test_get_provider(
await mock_setup(hass, tmp_path, mock_provider)
assert mock_provider == async_get_provider(hass, TEST_DOMAIN)
# Test getting the default provider
assert mock_provider == async_get_provider(hass)
async def test_config_entry_unload(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
@ -444,3 +449,84 @@ async def test_ws_list_engines(
assert msg["result"] == {
"providers": [{"engine_id": engine_id, "supported_languages": ["de-CH", "de"]}]
}
async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None:
"""Test async_default_engine."""
assert await async_setup_component(hass, "stt", {"stt": {}})
await hass.async_block_till_done()
assert async_default_engine(hass) is None
async def test_default_engine(hass: HomeAssistant, tmp_path: Path) -> None:
"""Test async_default_engine."""
mock_stt_platform(
hass,
tmp_path,
TEST_DOMAIN,
async_get_engine=AsyncMock(return_value=mock_provider),
)
assert await async_setup_component(hass, "stt", {"stt": {"platform": TEST_DOMAIN}})
await hass.async_block_till_done()
assert async_default_engine(hass) == TEST_DOMAIN
async def test_default_engine_entity(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
) -> None:
"""Test async_default_engine."""
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}"
async def test_default_engine_prefer_cloud(hass: HomeAssistant, tmp_path: Path) -> None:
"""Test async_default_engine."""
mock_stt_platform(
hass,
tmp_path,
TEST_DOMAIN,
async_get_engine=AsyncMock(return_value=mock_provider),
)
mock_stt_platform(
hass,
tmp_path,
"cloud",
async_get_engine=AsyncMock(return_value=mock_provider),
)
assert await async_setup_component(
hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]}
)
await hass.async_block_till_done()
assert async_default_engine(hass) == "cloud"
async def test_get_engine_legacy(
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
) -> None:
"""Test async_get_speech_to_text_engine."""
mock_stt_platform(
hass,
tmp_path,
TEST_DOMAIN,
async_get_engine=AsyncMock(return_value=mock_provider),
)
assert await async_setup_component(
hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]}
)
await hass.async_block_till_done()
assert async_get_speech_to_text_engine(hass, "no_such_provider") is None
assert async_get_speech_to_text_engine(hass, "test") is mock_provider
async def test_get_engine_entity(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
) -> None:
"""Test async_get_speech_to_text_engine."""
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
assert async_get_speech_to_text_engine(hass, "stt.test") is mock_provider_entity