Don't resolve default stt engine in assist pipelines (#91936)
* Don't resolve default stt engine in assist pipelines * Apply suggestion from code review * Add tests * Tweak * Add test * Improve test coverage
This commit is contained in:
parent
0d815a1688
commit
c5d0c392a9
5 changed files with 192 additions and 24 deletions
|
@ -83,14 +83,16 @@ async def async_get_pipeline(
|
||||||
if pipeline_id is None:
|
if pipeline_id is None:
|
||||||
# There's no preferred pipeline, construct a pipeline for the
|
# There's no preferred pipeline, construct a pipeline for the
|
||||||
# configured language
|
# configured language
|
||||||
|
stt_engine = stt.async_default_provider(hass)
|
||||||
|
stt_language = hass.config.language if stt_engine else None
|
||||||
return await pipeline_data.pipeline_store.async_create_item(
|
return await pipeline_data.pipeline_store.async_create_item(
|
||||||
{
|
{
|
||||||
"conversation_engine": None,
|
"conversation_engine": None,
|
||||||
"conversation_language": None,
|
"conversation_language": None,
|
||||||
"language": hass.config.language,
|
"language": hass.config.language,
|
||||||
"name": hass.config.language,
|
"name": hass.config.language,
|
||||||
"stt_engine": None,
|
"stt_engine": stt_engine,
|
||||||
"stt_language": None,
|
"stt_language": stt_language,
|
||||||
"tts_engine": None,
|
"tts_engine": None,
|
||||||
"tts_language": None,
|
"tts_language": None,
|
||||||
"tts_voice": None,
|
"tts_voice": None,
|
||||||
|
@ -261,22 +263,14 @@ class PipelineRun:
|
||||||
"""Prepare speech to text."""
|
"""Prepare speech to text."""
|
||||||
stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None
|
stt_provider: stt.SpeechToTextEntity | stt.Provider | None = None
|
||||||
|
|
||||||
if self.pipeline.stt_engine is not None:
|
# pipeline.stt_engine can't be None or this function is not called
|
||||||
# Try entity first
|
stt_provider = stt.async_get_speech_to_text_engine(
|
||||||
stt_provider = stt.async_get_speech_to_text_entity(
|
self.hass,
|
||||||
self.hass,
|
self.pipeline.stt_engine, # type: ignore[arg-type]
|
||||||
self.pipeline.stt_engine,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if stt_provider is None:
|
if stt_provider is None:
|
||||||
# Try legacy provider second
|
engine = self.pipeline.stt_engine
|
||||||
stt_provider = stt.async_get_provider(
|
|
||||||
self.hass,
|
|
||||||
self.pipeline.stt_engine,
|
|
||||||
)
|
|
||||||
|
|
||||||
if stt_provider is None:
|
|
||||||
engine = self.pipeline.stt_engine or "default"
|
|
||||||
raise SpeechToTextError(
|
raise SpeechToTextError(
|
||||||
code="stt-provider-missing",
|
code="stt-provider-missing",
|
||||||
message=f"No speech to text provider for: {engine}",
|
message=f"No speech to text provider for: {engine}",
|
||||||
|
@ -580,11 +574,14 @@ class PipelineInput:
|
||||||
async def validate(self) -> None:
|
async def validate(self) -> None:
|
||||||
"""Validate pipeline input against start stage."""
|
"""Validate pipeline input against start stage."""
|
||||||
if self.run.start_stage == PipelineStage.STT:
|
if self.run.start_stage == PipelineStage.STT:
|
||||||
|
if self.run.pipeline.stt_engine is None:
|
||||||
|
raise PipelineRunValidationError(
|
||||||
|
"the pipeline does not support speech to text"
|
||||||
|
)
|
||||||
if self.stt_metadata is None:
|
if self.stt_metadata is None:
|
||||||
raise PipelineRunValidationError(
|
raise PipelineRunValidationError(
|
||||||
"stt_metadata is required for speech to text"
|
"stt_metadata is required for speech to text"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.stt_stream is None:
|
if self.stt_stream is None:
|
||||||
raise PipelineRunValidationError(
|
raise PipelineRunValidationError(
|
||||||
"stt_stream is required for speech to text"
|
"stt_stream is required for speech to text"
|
||||||
|
|
|
@ -41,12 +41,14 @@ from .legacy import (
|
||||||
Provider,
|
Provider,
|
||||||
SpeechMetadata,
|
SpeechMetadata,
|
||||||
SpeechResult,
|
SpeechResult,
|
||||||
|
async_default_provider,
|
||||||
async_get_provider,
|
async_get_provider,
|
||||||
async_setup_legacy,
|
async_setup_legacy,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"async_get_provider",
|
"async_get_provider",
|
||||||
|
"async_get_speech_to_text_engine",
|
||||||
"async_get_speech_to_text_entity",
|
"async_get_speech_to_text_entity",
|
||||||
"AudioBitRates",
|
"AudioBitRates",
|
||||||
"AudioChannels",
|
"AudioChannels",
|
||||||
|
@ -64,6 +66,14 @@ __all__ = [
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_default_engine(hass: HomeAssistant) -> str | None:
|
||||||
|
"""Return the domain or entity id of the default engine."""
|
||||||
|
return async_default_provider(hass) or next(
|
||||||
|
iter(hass.states.async_entity_ids(DOMAIN)), None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_speech_to_text_entity(
|
def async_get_speech_to_text_entity(
|
||||||
hass: HomeAssistant, entity_id: str
|
hass: HomeAssistant, entity_id: str
|
||||||
|
@ -74,6 +84,16 @@ def async_get_speech_to_text_entity(
|
||||||
return component.get_entity(entity_id)
|
return component.get_entity(entity_id)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_speech_to_text_engine(
|
||||||
|
hass: HomeAssistant, engine_id: str
|
||||||
|
) -> SpeechToTextEntity | Provider | None:
|
||||||
|
"""Return stt entity or legacy provider."""
|
||||||
|
if entity := async_get_speech_to_text_entity(hass, engine_id):
|
||||||
|
return entity
|
||||||
|
return async_get_provider(hass, engine_id)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_speech_to_text_languages(hass: HomeAssistant) -> set[str]:
|
def async_get_speech_to_text_languages(hass: HomeAssistant) -> set[str]:
|
||||||
"""Return a set with the union of languages supported by stt engines."""
|
"""Return a set with the union of languages supported by stt engines."""
|
||||||
|
|
|
@ -27,6 +27,15 @@ from .const import (
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_default_provider(hass: HomeAssistant) -> str | None:
|
||||||
|
"""Return the domain of the default provider."""
|
||||||
|
if "cloud" in hass.data[DATA_PROVIDERS]:
|
||||||
|
return "cloud"
|
||||||
|
|
||||||
|
return next(iter(hass.data[DATA_PROVIDERS]), None)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_provider(
|
def async_get_provider(
|
||||||
hass: HomeAssistant, domain: str | None = None
|
hass: HomeAssistant, domain: str | None = None
|
||||||
|
@ -35,13 +44,8 @@ def async_get_provider(
|
||||||
if domain:
|
if domain:
|
||||||
return hass.data[DATA_PROVIDERS].get(domain)
|
return hass.data[DATA_PROVIDERS].get(domain)
|
||||||
|
|
||||||
if not hass.data[DATA_PROVIDERS]:
|
provider = async_default_provider(hass)
|
||||||
return None
|
return hass.data[DATA_PROVIDERS][provider] if provider is not None else None
|
||||||
|
|
||||||
if "cloud" in hass.data[DATA_PROVIDERS]:
|
|
||||||
return hass.data[DATA_PROVIDERS]["cloud"]
|
|
||||||
|
|
||||||
return next(iter(hass.data[DATA_PROVIDERS].values()))
|
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Test Voice Assistant init."""
|
"""Test Voice Assistant init."""
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
|
|
||||||
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components import assist_pipeline, stt
|
from homeassistant.components import assist_pipeline, stt
|
||||||
|
@ -184,3 +185,63 @@ async def test_pipeline_from_audio_stream_entity(
|
||||||
|
|
||||||
assert processed == snapshot
|
assert processed == snapshot
|
||||||
assert mock_stt_provider_entity.received == [b"part1", b"part2"]
|
assert mock_stt_provider_entity.received == [b"part1", b"part2"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_from_audio_stream_no_stt(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
mock_stt_provider: MockSttProvider,
|
||||||
|
init_components,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating a pipeline from an audio stream.
|
||||||
|
|
||||||
|
In this test, the pipeline does not support stt
|
||||||
|
"""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
events = []
|
||||||
|
|
||||||
|
async def audio_data():
|
||||||
|
yield b"part1"
|
||||||
|
yield b"part2"
|
||||||
|
yield b""
|
||||||
|
|
||||||
|
# Create a pipeline without stt support
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/create",
|
||||||
|
"conversation_engine": "homeassistant",
|
||||||
|
"conversation_language": "en-US",
|
||||||
|
"language": "en",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": None,
|
||||||
|
"stt_language": None,
|
||||||
|
"tts_engine": "test",
|
||||||
|
"tts_language": "en-AU",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
pipeline_id = msg["result"]["id"]
|
||||||
|
|
||||||
|
# Try to use the created pipeline
|
||||||
|
with pytest.raises(assist_pipeline.pipeline.PipelineRunValidationError):
|
||||||
|
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||||
|
hass,
|
||||||
|
Context(),
|
||||||
|
events.append,
|
||||||
|
stt.SpeechMetadata(
|
||||||
|
language="en-UK",
|
||||||
|
format=stt.AudioFormats.WAV,
|
||||||
|
codec=stt.AudioCodecs.PCM,
|
||||||
|
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||||
|
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||||
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
|
),
|
||||||
|
audio_data(),
|
||||||
|
pipeline_id=pipeline_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not events
|
||||||
|
|
|
@ -18,7 +18,9 @@ from homeassistant.components.stt import (
|
||||||
SpeechResult,
|
SpeechResult,
|
||||||
SpeechResultState,
|
SpeechResultState,
|
||||||
SpeechToTextEntity,
|
SpeechToTextEntity,
|
||||||
|
async_default_engine,
|
||||||
async_get_provider,
|
async_get_provider,
|
||||||
|
async_get_speech_to_text_engine,
|
||||||
)
|
)
|
||||||
from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow
|
from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow
|
||||||
from homeassistant.core import HomeAssistant, State
|
from homeassistant.core import HomeAssistant, State
|
||||||
|
@ -349,6 +351,9 @@ async def test_get_provider(
|
||||||
await mock_setup(hass, tmp_path, mock_provider)
|
await mock_setup(hass, tmp_path, mock_provider)
|
||||||
assert mock_provider == async_get_provider(hass, TEST_DOMAIN)
|
assert mock_provider == async_get_provider(hass, TEST_DOMAIN)
|
||||||
|
|
||||||
|
# Test getting the default provider
|
||||||
|
assert mock_provider == async_get_provider(hass)
|
||||||
|
|
||||||
|
|
||||||
async def test_config_entry_unload(
|
async def test_config_entry_unload(
|
||||||
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
||||||
|
@ -444,3 +449,84 @@ async def test_ws_list_engines(
|
||||||
assert msg["result"] == {
|
assert msg["result"] == {
|
||||||
"providers": [{"engine_id": engine_id, "supported_languages": ["de-CH", "de"]}]
|
"providers": [{"engine_id": engine_id, "supported_languages": ["de-CH", "de"]}]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None:
|
||||||
|
"""Test async_default_engine."""
|
||||||
|
assert await async_setup_component(hass, "stt", {"stt": {}})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert async_default_engine(hass) is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_default_engine(hass: HomeAssistant, tmp_path: Path) -> None:
|
||||||
|
"""Test async_default_engine."""
|
||||||
|
mock_stt_platform(
|
||||||
|
hass,
|
||||||
|
tmp_path,
|
||||||
|
TEST_DOMAIN,
|
||||||
|
async_get_engine=AsyncMock(return_value=mock_provider),
|
||||||
|
)
|
||||||
|
assert await async_setup_component(hass, "stt", {"stt": {"platform": TEST_DOMAIN}})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert async_default_engine(hass) == TEST_DOMAIN
|
||||||
|
|
||||||
|
|
||||||
|
async def test_default_engine_entity(
|
||||||
|
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
||||||
|
) -> None:
|
||||||
|
"""Test async_default_engine."""
|
||||||
|
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
||||||
|
|
||||||
|
assert async_default_engine(hass) == f"{DOMAIN}.{TEST_DOMAIN}"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_default_engine_prefer_cloud(hass: HomeAssistant, tmp_path: Path) -> None:
|
||||||
|
"""Test async_default_engine."""
|
||||||
|
mock_stt_platform(
|
||||||
|
hass,
|
||||||
|
tmp_path,
|
||||||
|
TEST_DOMAIN,
|
||||||
|
async_get_engine=AsyncMock(return_value=mock_provider),
|
||||||
|
)
|
||||||
|
mock_stt_platform(
|
||||||
|
hass,
|
||||||
|
tmp_path,
|
||||||
|
"cloud",
|
||||||
|
async_get_engine=AsyncMock(return_value=mock_provider),
|
||||||
|
)
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]}
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert async_default_engine(hass) == "cloud"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_engine_legacy(
|
||||||
|
hass: HomeAssistant, tmp_path: Path, mock_provider: MockProvider
|
||||||
|
) -> None:
|
||||||
|
"""Test async_get_speech_to_text_engine."""
|
||||||
|
mock_stt_platform(
|
||||||
|
hass,
|
||||||
|
tmp_path,
|
||||||
|
TEST_DOMAIN,
|
||||||
|
async_get_engine=AsyncMock(return_value=mock_provider),
|
||||||
|
)
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]}
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert async_get_speech_to_text_engine(hass, "no_such_provider") is None
|
||||||
|
assert async_get_speech_to_text_engine(hass, "test") is mock_provider
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_engine_entity(
|
||||||
|
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
|
||||||
|
) -> None:
|
||||||
|
"""Test async_get_speech_to_text_engine."""
|
||||||
|
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
||||||
|
|
||||||
|
assert async_get_speech_to_text_engine(hass, "stt.test") is mock_provider_entity
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue