Don't resolve default stt engine in assist pipelines (#91936)
* Don't resolve default stt engine in assist pipelines * Apply suggestion from code review * Add tests * Tweak * Add test * Improve test coverage
This commit is contained in:
parent
0d815a1688
commit
c5d0c392a9
5 changed files with 192 additions and 24 deletions
|
@ -83,14 +83,16 @@ async def async_get_pipeline(
|
|||
if pipeline_id is None:
|
||||
# There's no preferred pipeline, construct a pipeline for the
|
||||
# configured language
|
||||
stt_engine = stt.async_default_provider(hass)
|
||||
stt_language = hass.config.language if stt_engine else None
|
||||
return await pipeline_data.pipeline_store.async_create_item(
|
||||
{
|
||||
"conversation_engine": None,
|
||||
"conversation_language": None,
|
||||
"language": hass.config.language,
|
||||
"name": hass.config.language,
|
||||
"stt_engine": None,
|
||||
"stt_language": None,
|
||||
"stt_engine": stt_engine,
|
||||
"stt_language": stt_language,
|
||||
"tts_engine": None,
|
||||
"tts_language": None,
|
||||
"tts_voice": None,
|
||||
|
@ -261,22 +263,14 @@ class PipelineRun:
|
|||
"""Prepare speech to text."""
|
||||
stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None
|
||||
|
||||
if self.pipeline.stt_engine is not None:
|
||||
# Try entity first
|
||||
stt_provider = stt.async_get_speech_to_text_entity(
|
||||
self.hass,
|
||||
self.pipeline.stt_engine,
|
||||
)
|
||||
# pipeline.stt_engine can't be None or this function is not called
|
||||
stt_provider = stt.async_get_speech_to_text_engine(
|
||||
self.hass,
|
||||
self.pipeline.stt_engine, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if stt_provider is None:
|
||||
# Try legacy provider second
|
||||
stt_provider = stt.async_get_provider(
|
||||
self.hass,
|
||||
self.pipeline.stt_engine,
|
||||
)
|
||||
|
||||
if stt_provider is None:
|
||||
engine = self.pipeline.stt_engine or "default"
|
||||
engine = self.pipeline.stt_engine
|
||||
raise SpeechToTextError(
|
||||
code="stt-provider-missing",
|
||||
message=f"No speech to text provider for: {engine}",
|
||||
|
@ -580,11 +574,14 @@ class PipelineInput:
|
|||
async def validate(self) -> None:
|
||||
"""Validate pipeline input against start stage."""
|
||||
if self.run.start_stage == PipelineStage.STT:
|
||||
if self.run.pipeline.stt_engine is None:
|
||||
raise PipelineRunValidationError(
|
||||
"the pipeline does not support speech to text"
|
||||
)
|
||||
if self.stt_metadata is None:
|
||||
raise PipelineRunValidationError(
|
||||
"stt_metadata is required for speech to text"
|
||||
)
|
||||
|
||||
if self.stt_stream is None:
|
||||
raise PipelineRunValidationError(
|
||||
"stt_stream is required for speech to text"
|
||||
|
|
|
@ -41,12 +41,14 @@ from .legacy import (
|
|||
Provider,
|
||||
SpeechMetadata,
|
||||
SpeechResult,
|
||||
async_default_provider,
|
||||
async_get_provider,
|
||||
async_setup_legacy,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"async_get_provider",
|
||||
"async_get_speech_to_text_engine",
|
||||
"async_get_speech_to_text_entity",
|
||||
"AudioBitRates",
|
||||
"AudioChannels",
|
||||
|
@ -64,6 +66,14 @@ __all__ = [
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@callback
|
||||
def async_default_engine(hass: HomeAssistant) -> str | None:
|
||||
"""Return the domain or entity id of the default engine."""
|
||||
return async_default_provider(hass) or next(
|
||||
iter(hass.states.async_entity_ids(DOMAIN)), None
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_speech_to_text_entity(
|
||||
hass: HomeAssistant, entity_id: str
|
||||
|
@ -74,6 +84,16 @@ def async_get_speech_to_text_entity(
|
|||
return component.get_entity(entity_id)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_speech_to_text_engine(
|
||||
hass: HomeAssistant, engine_id: str
|
||||
) -> SpeechToTextEntity | Provider | None:
|
||||
"""Return stt entity or legacy provider."""
|
||||
if entity := async_get_speech_to_text_entity(hass, engine_id):
|
||||
return entity
|
||||
return async_get_provider(hass, engine_id)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_speech_to_text_languages(hass: HomeAssistant) -> set[str]:
|
||||
"""Return a set with the union of languages supported by stt engines."""
|
||||
|
|
|
@ -27,6 +27,15 @@ from .const import (
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@callback
|
||||
def async_default_provider(hass: HomeAssistant) -> str | None:
|
||||
"""Return the domain of the default provider."""
|
||||
if "cloud" in hass.data[DATA_PROVIDERS]:
|
||||
return "cloud"
|
||||
|
||||
return next(iter(hass.data[DATA_PROVIDERS]), None)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_provider(
|
||||
hass: HomeAssistant, domain: str | None = None
|
||||
|
@ -35,13 +44,8 @@ def async_get_provider(
|
|||
if domain:
|
||||
return hass.data[DATA_PROVIDERS].get(domain)
|
||||
|
||||
if not hass.data[DATA_PROVIDERS]:
|
||||
return None
|
||||
|
||||
if "cloud" in hass.data[DATA_PROVIDERS]:
|
||||
return hass.data[DATA_PROVIDERS]["cloud"]
|
||||
|
||||
return next(iter(hass.data[DATA_PROVIDERS].values()))
|
||||
provider = async_default_provider(hass)
|
||||
return hass.data[DATA_PROVIDERS][provider] if provider is not None else None
|
||||
|
||||
|
||||
@callback
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Test Voice Assistant init."""
|
||||
from dataclasses import asdict
|
||||
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import assist_pipeline, stt
|
||||
|
@ -184,3 +185,63 @@ async def test_pipeline_from_audio_stream_entity(
|
|||
|
||||
assert processed == snapshot
|
||||
assert mock_stt_provider_entity.received == [b"part1", b"part2"]
|
||||
|
||||
|
||||
async def test_pipeline_from_audio_stream_no_stt(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test creating a pipeline from an audio stream.
|
||||
|
||||
In this test, the pipeline does not support stt
|
||||
"""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
events = []
|
||||
|
||||
async def audio_data():
|
||||
yield b"part1"
|
||||
yield b"part2"
|
||||
yield b""
|
||||
|
||||
# Create a pipeline without stt support
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/create",
|
||||
"conversation_engine": "homeassistant",
|
||||
"conversation_language": "en-US",
|
||||
"language": "en",
|
||||
"name": "test_name",
|
||||
"stt_engine": None,
|
||||
"stt_language": None,
|
||||
"tts_engine": "test",
|
||||
"tts_language": "en-AU",
|
||||
"tts_voice": "Arnold Schwarzenegger",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
pipeline_id = msg["result"]["id"]
|
||||
|
||||
# Try to use the created pipeline
|
||||
with pytest.raises(assist_pipeline.pipeline.PipelineRunValidationError):
|
||||
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||
hass,
|
||||
Context(),
|
||||
events.append,
|
||||
stt.SpeechMetadata(
|
||||
language="en-UK",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
audio_data(),
|
||||
pipeline_id=pipeline_id,
|
||||
)
|
||||
|
||||
assert not events
|
||||
|
|
|
@ -18,7 +18,9 @@ from homeassistant.components.stt import (
|
|||
SpeechResult,
|
||||
SpeechResultState,
|
||||
SpeechToTextEntity,
|
||||
async_default_engine,
|
||||
async_get_provider,
|
||||
async_get_speech_to_text_engine,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
|
@ -349,6 +351,9 @@ async def test_get_provider(
|
|||
await mock_setup(hass, tmp_path, mock_provider)
|
||||
assert mock_provider == async_get_provider(hass, TEST_DOMAIN)
|
||||
|
||||
# Test getting the default provider
|
||||
assert mock_provider == async_get_provider(hass)
|
||||
|
||||
|
||||
async def test_config_entry_unload(
|
||||
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
||||
|
@ -444,3 +449,84 @@ async def test_ws_list_engines(
|
|||
assert msg["result"] == {
|
||||
"providers": [{"engine_id": engine_id, "supported_languages": ["de-CH", "de"]}]
|
||||
}
|
||||
|
||||
|
||||
async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None:
|
||||
"""Test async_default_engine."""
|
||||
assert await async_setup_component(hass, "stt", {"stt": {}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert async_default_engine(hass) is None
|
||||
|
||||
|
||||
async def test_default_engine(hass: HomeAssistant, tmp_path: Path) -> None:
|
||||
"""Test async_default_engine."""
|
||||
mock_stt_platform(
|
||||
hass,
|
||||
tmp_path,
|
||||
TEST_DOMAIN,
|
||||
async_get_engine=AsyncMock(return_value=mock_provider),
|
||||
)
|
||||
assert await async_setup_component(hass, "stt", {"stt": {"platform": TEST_DOMAIN}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert async_default_engine(hass) == TEST_DOMAIN
|
||||
|
||||
|
||||
async def test_default_engine_entity(
|
||||
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
||||
) -> None:
|
||||
"""Test async_default_engine."""
|
||||
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
||||
|
||||
assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}"
|
||||
|
||||
|
||||
async def test_default_engine_prefer_cloud(hass: HomeAssistant, tmp_path: Path) -> None:
|
||||
"""Test async_default_engine."""
|
||||
mock_stt_platform(
|
||||
hass,
|
||||
tmp_path,
|
||||
TEST_DOMAIN,
|
||||
async_get_engine=AsyncMock(return_value=mock_provider),
|
||||
)
|
||||
mock_stt_platform(
|
||||
hass,
|
||||
tmp_path,
|
||||
"cloud",
|
||||
async_get_engine=AsyncMock(return_value=mock_provider),
|
||||
)
|
||||
assert await async_setup_component(
|
||||
hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert async_default_engine(hass) == "cloud"
|
||||
|
||||
|
||||
async def test_get_engine_legacy(
|
||||
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
|
||||
) -> None:
|
||||
"""Test async_get_speech_to_text_engine."""
|
||||
mock_stt_platform(
|
||||
hass,
|
||||
tmp_path,
|
||||
TEST_DOMAIN,
|
||||
async_get_engine=AsyncMock(return_value=mock_provider),
|
||||
)
|
||||
assert await async_setup_component(
|
||||
hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert async_get_speech_to_text_engine(hass, "no_such_provider") is None
|
||||
assert async_get_speech_to_text_engine(hass, "test") is mock_provider
|
||||
|
||||
|
||||
async def test_get_engine_entity(
|
||||
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
||||
) -> None:
|
||||
"""Test async_get_speech_to_text_engine."""
|
||||
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
||||
|
||||
assert async_get_speech_to_text_engine(hass, "stt.test") is mock_provider_entity
|
||||
|
|
Loading…
Add table
Reference in a new issue