Add cloud tts entity (#108293)
* Add cloud tts entity * Test test_login_view_missing_entity * Fix pipeline iteration for migration * Update tests * Make migration more strict * Fix docstring
This commit is contained in:
parent
d0da457a04
commit
e086cd9fef
12 changed files with 428 additions and 102 deletions
|
@ -65,7 +65,7 @@ from .subscription import async_subscription_info
|
|||
|
||||
DEFAULT_MODE = MODE_PROD
|
||||
|
||||
PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT]
|
||||
PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT, Platform.TTS]
|
||||
|
||||
SERVICE_REMOTE_CONNECT = "remote_connect"
|
||||
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"
|
||||
|
@ -288,9 +288,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
loaded = False
|
||||
stt_platform_loaded = asyncio.Event()
|
||||
tts_platform_loaded = asyncio.Event()
|
||||
stt_tts_entities_added = asyncio.Event()
|
||||
hass.data[DATA_PLATFORMS_SETUP] = {
|
||||
Platform.STT: stt_platform_loaded,
|
||||
Platform.TTS: tts_platform_loaded,
|
||||
"stt_tts_entities_added": stt_tts_entities_added,
|
||||
}
|
||||
|
||||
async def _on_start() -> None:
|
||||
|
@ -330,6 +332,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
|
||||
account_link.async_setup(hass)
|
||||
|
||||
# Load legacy tts platform for backwards compatibility.
|
||||
hass.async_create_task(
|
||||
async_load_platform(
|
||||
hass,
|
||||
|
@ -377,8 +380,10 @@ def _remote_handle_prefs_updated(cloud: Cloud[CloudClient]) -> None:
|
|||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||
stt_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.STT]
|
||||
stt_platform_loaded.set()
|
||||
stt_tts_entities_added: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][
|
||||
"stt_tts_entities_added"
|
||||
]
|
||||
stt_tts_entities_added.set()
|
||||
|
||||
return True
|
||||
|
||||
|
|
|
@ -9,16 +9,23 @@ from homeassistant.components.assist_pipeline import (
|
|||
)
|
||||
from homeassistant.components.conversation import HOME_ASSISTANT_AGENT
|
||||
from homeassistant.components.stt import DOMAIN as STT_DOMAIN
|
||||
from homeassistant.components.tts import DOMAIN as TTS_DOMAIN
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
import homeassistant.helpers.entity_registry as er
|
||||
|
||||
from .const import DATA_PLATFORMS_SETUP, DOMAIN, STT_ENTITY_UNIQUE_ID
|
||||
from .const import (
|
||||
DATA_PLATFORMS_SETUP,
|
||||
DOMAIN,
|
||||
STT_ENTITY_UNIQUE_ID,
|
||||
TTS_ENTITY_UNIQUE_ID,
|
||||
)
|
||||
|
||||
|
||||
async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
|
||||
"""Create a cloud assist pipeline."""
|
||||
# Wait for stt and tts platforms to set up before creating the pipeline.
|
||||
# Wait for stt and tts platforms to set up and entities to be added
|
||||
# before creating the pipeline.
|
||||
platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP]
|
||||
await asyncio.gather(*(event.wait() for event in platforms_setup.values()))
|
||||
# Make sure the pipeline store is loaded, needed because assist_pipeline
|
||||
|
@ -29,8 +36,11 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
|
|||
new_stt_engine_id = entity_registry.async_get_entity_id(
|
||||
STT_DOMAIN, DOMAIN, STT_ENTITY_UNIQUE_ID
|
||||
)
|
||||
if new_stt_engine_id is None:
|
||||
# If there's no cloud stt entity, we can't create a cloud pipeline.
|
||||
new_tts_engine_id = entity_registry.async_get_entity_id(
|
||||
TTS_DOMAIN, DOMAIN, TTS_ENTITY_UNIQUE_ID
|
||||
)
|
||||
if new_stt_engine_id is None or new_tts_engine_id is None:
|
||||
# If there's no cloud stt or tts entity, we can't create a cloud pipeline.
|
||||
return None
|
||||
|
||||
def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
|
||||
|
@ -43,7 +53,7 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
|
|||
if (
|
||||
pipeline.conversation_engine == HOME_ASSISTANT_AGENT
|
||||
and pipeline.stt_engine in (DOMAIN, new_stt_engine_id)
|
||||
and pipeline.tts_engine == DOMAIN
|
||||
and pipeline.tts_engine in (DOMAIN, new_tts_engine_id)
|
||||
):
|
||||
return pipeline.id
|
||||
return None
|
||||
|
@ -52,7 +62,7 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
|
|||
cloud_pipeline := await async_create_default_pipeline(
|
||||
hass,
|
||||
stt_engine_id=new_stt_engine_id,
|
||||
tts_engine_id=DOMAIN,
|
||||
tts_engine_id=new_tts_engine_id,
|
||||
pipeline_name="Home Assistant Cloud",
|
||||
)
|
||||
) is None:
|
||||
|
@ -61,25 +71,34 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
|
|||
return cloud_pipeline.id
|
||||
|
||||
|
||||
async def async_migrate_cloud_pipeline_stt_engine(
|
||||
hass: HomeAssistant, stt_engine_id: str
|
||||
async def async_migrate_cloud_pipeline_engine(
|
||||
hass: HomeAssistant, platform: Platform, engine_id: str
|
||||
) -> None:
|
||||
"""Migrate the speech-to-text engine in the cloud assist pipeline."""
|
||||
# Migrate existing pipelines with cloud stt to use new cloud stt engine id.
|
||||
# Added in 2024.01.0. Can be removed in 2025.01.0.
|
||||
"""Migrate the pipeline engines in the cloud assist pipeline."""
|
||||
# Migrate existing pipelines with cloud stt or tts to use new cloud engine id.
|
||||
# Added in 2024.02.0. Can be removed in 2025.02.0.
|
||||
|
||||
# We need to make sure that both stt and tts are loaded before this migration.
|
||||
# Assist pipeline will call default engine when setting up the store.
|
||||
# Wait for the stt or tts platform loaded event here.
|
||||
if platform == Platform.STT:
|
||||
wait_for_platform = Platform.TTS
|
||||
pipeline_attribute = "stt_engine"
|
||||
elif platform == Platform.TTS:
|
||||
wait_for_platform = Platform.STT
|
||||
pipeline_attribute = "tts_engine"
|
||||
else:
|
||||
raise ValueError(f"Invalid platform {platform}")
|
||||
|
||||
# We need to make sure that tts is loaded before this migration.
|
||||
# Assist pipeline will call default engine of tts when setting up the store.
|
||||
# Wait for the tts platform loaded event here.
|
||||
platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP]
|
||||
await platforms_setup[Platform.TTS].wait()
|
||||
await platforms_setup[wait_for_platform].wait()
|
||||
|
||||
# Make sure the pipeline store is loaded, needed because assist_pipeline
|
||||
# is an after dependency of cloud
|
||||
await async_setup_pipeline_store(hass)
|
||||
|
||||
kwargs: dict[str, str] = {pipeline_attribute: engine_id}
|
||||
pipelines = async_get_pipelines(hass)
|
||||
for pipeline in pipelines:
|
||||
if pipeline.stt_engine != DOMAIN:
|
||||
continue
|
||||
await async_update_pipeline(hass, pipeline, stt_engine=stt_engine_id)
|
||||
if getattr(pipeline, pipeline_attribute) == DOMAIN:
|
||||
await async_update_pipeline(hass, pipeline, **kwargs)
|
||||
|
|
|
@ -73,3 +73,4 @@ MODE_PROD = "production"
|
|||
DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update")
|
||||
|
||||
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
|
||||
TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"
|
||||
|
|
|
@ -104,10 +104,18 @@ class CloudPreferences:
|
|||
@callback
|
||||
def async_listen_updates(
|
||||
self, listener: Callable[[CloudPreferences], Coroutine[Any, Any, None]]
|
||||
) -> None:
|
||||
) -> Callable[[], None]:
|
||||
"""Listen for updates to the preferences."""
|
||||
|
||||
@callback
|
||||
def unsubscribe() -> None:
|
||||
"""Remove the listener."""
|
||||
self._listeners.remove(listener)
|
||||
|
||||
self._listeners.append(listener)
|
||||
|
||||
return unsubscribe
|
||||
|
||||
async def async_update(
|
||||
self,
|
||||
*,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Support for the cloud for speech to text service."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
import logging
|
||||
|
||||
|
@ -19,12 +20,13 @@ from homeassistant.components.stt import (
|
|||
SpeechToTextEntity,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .assist_pipeline import async_migrate_cloud_pipeline_stt_engine
|
||||
from .assist_pipeline import async_migrate_cloud_pipeline_engine
|
||||
from .client import CloudClient
|
||||
from .const import DOMAIN, STT_ENTITY_UNIQUE_ID
|
||||
from .const import DATA_PLATFORMS_SETUP, DOMAIN, STT_ENTITY_UNIQUE_ID
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -35,18 +37,20 @@ async def async_setup_entry(
|
|||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Home Assistant Cloud speech platform via config entry."""
|
||||
stt_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.STT]
|
||||
stt_platform_loaded.set()
|
||||
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
|
||||
async_add_entities([CloudProviderEntity(cloud)])
|
||||
|
||||
|
||||
class CloudProviderEntity(SpeechToTextEntity):
|
||||
"""NabuCasa speech API provider."""
|
||||
"""Home Assistant Cloud speech API provider."""
|
||||
|
||||
_attr_name = "Home Assistant Cloud"
|
||||
_attr_unique_id = STT_ENTITY_UNIQUE_ID
|
||||
|
||||
def __init__(self, cloud: Cloud[CloudClient]) -> None:
|
||||
"""Home Assistant NabuCasa Speech to text."""
|
||||
"""Initialize cloud Speech to text entity."""
|
||||
self.cloud = cloud
|
||||
|
||||
@property
|
||||
|
@ -81,7 +85,9 @@ class CloudProviderEntity(SpeechToTextEntity):
|
|||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity is about to be added to hass."""
|
||||
await async_migrate_cloud_pipeline_stt_engine(self.hass, self.entity_id)
|
||||
await async_migrate_cloud_pipeline_engine(
|
||||
self.hass, platform=Platform.STT, engine_id=self.entity_id
|
||||
)
|
||||
|
||||
async def async_process_audio_stream(
|
||||
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Support for the cloud for text-to-speech service."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
|
@ -12,16 +13,21 @@ from homeassistant.components.tts import (
|
|||
ATTR_AUDIO_OUTPUT,
|
||||
ATTR_VOICE,
|
||||
CONF_LANG,
|
||||
PLATFORM_SCHEMA,
|
||||
PLATFORM_SCHEMA as TTS_PLATFORM_SCHEMA,
|
||||
Provider,
|
||||
TextToSpeechEntity,
|
||||
TtsAudioType,
|
||||
Voice,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from .assist_pipeline import async_migrate_cloud_pipeline_engine
|
||||
from .client import CloudClient
|
||||
from .const import DOMAIN
|
||||
from .const import DATA_PLATFORMS_SETUP, DOMAIN, TTS_ENTITY_UNIQUE_ID
|
||||
from .prefs import CloudPreferences
|
||||
|
||||
ATTR_GENDER = "gender"
|
||||
|
@ -48,7 +54,7 @@ def validate_lang(value: dict[str, Any]) -> dict[str, Any]:
|
|||
|
||||
|
||||
PLATFORM_SCHEMA = vol.All(
|
||||
PLATFORM_SCHEMA.extend(
|
||||
TTS_PLATFORM_SCHEMA.extend(
|
||||
{
|
||||
vol.Optional(CONF_LANG): str,
|
||||
vol.Optional(ATTR_GENDER): str,
|
||||
|
@ -81,8 +87,95 @@ async def async_get_engine(
|
|||
return cloud_provider
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Home Assistant Cloud text-to-speech platform."""
|
||||
tts_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.TTS]
|
||||
tts_platform_loaded.set()
|
||||
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
|
||||
async_add_entities([CloudTTSEntity(cloud)])
|
||||
|
||||
|
||||
class CloudTTSEntity(TextToSpeechEntity):
|
||||
"""Home Assistant Cloud text-to-speech entity."""
|
||||
|
||||
_attr_name = "Home Assistant Cloud"
|
||||
_attr_unique_id = TTS_ENTITY_UNIQUE_ID
|
||||
|
||||
def __init__(self, cloud: Cloud[CloudClient]) -> None:
|
||||
"""Initialize cloud text-to-speech entity."""
|
||||
self.cloud = cloud
|
||||
self._language, self._gender = cloud.client.prefs.tts_default_voice
|
||||
|
||||
async def _sync_prefs(self, prefs: CloudPreferences) -> None:
|
||||
"""Sync preferences."""
|
||||
self._language, self._gender = prefs.tts_default_voice
|
||||
|
||||
@property
|
||||
def default_language(self) -> str:
|
||||
"""Return the default language."""
|
||||
return self._language
|
||||
|
||||
@property
|
||||
def default_options(self) -> dict[str, Any]:
|
||||
"""Return a dict include default options."""
|
||||
return {
|
||||
ATTR_GENDER: self._gender,
|
||||
ATTR_AUDIO_OUTPUT: AudioOutput.MP3,
|
||||
}
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return list of supported languages."""
|
||||
return SUPPORT_LANGUAGES
|
||||
|
||||
@property
|
||||
def supported_options(self) -> list[str]:
|
||||
"""Return list of supported options like voice, emotion."""
|
||||
return [ATTR_GENDER, ATTR_VOICE, ATTR_AUDIO_OUTPUT]
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Handle entity which will be added."""
|
||||
await super().async_added_to_hass()
|
||||
await async_migrate_cloud_pipeline_engine(
|
||||
self.hass, platform=Platform.TTS, engine_id=self.entity_id
|
||||
)
|
||||
self.async_on_remove(
|
||||
self.cloud.client.prefs.async_listen_updates(self._sync_prefs)
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_get_supported_voices(self, language: str) -> list[Voice] | None:
|
||||
"""Return a list of supported voices for a language."""
|
||||
if not (voices := TTS_VOICES.get(language)):
|
||||
return None
|
||||
return [Voice(voice, voice) for voice in voices]
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load TTS from Home Assistant Cloud."""
|
||||
# Process TTS
|
||||
try:
|
||||
data = await self.cloud.voice.process_tts(
|
||||
text=message,
|
||||
language=language,
|
||||
gender=options.get(ATTR_GENDER),
|
||||
voice=options.get(ATTR_VOICE),
|
||||
output=options[ATTR_AUDIO_OUTPUT],
|
||||
)
|
||||
except VoiceError as err:
|
||||
_LOGGER.error("Voice error: %s", err)
|
||||
return (None, None)
|
||||
|
||||
return (str(options[ATTR_AUDIO_OUTPUT].value), data)
|
||||
|
||||
|
||||
class CloudProvider(Provider):
|
||||
"""NabuCasa Cloud speech API provider."""
|
||||
"""Home Assistant Cloud speech API provider."""
|
||||
|
||||
def __init__(
|
||||
self, cloud: Cloud[CloudClient], language: str | None, gender: str | None
|
||||
|
@ -136,7 +229,7 @@ class CloudProvider(Provider):
|
|||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load TTS from NabuCasa Cloud."""
|
||||
"""Load TTS from Home Assistant Cloud."""
|
||||
# Process TTS
|
||||
try:
|
||||
data = await self.cloud.voice.process_tts(
|
||||
|
|
|
@ -7,6 +7,54 @@ from homeassistant.components import cloud
|
|||
from homeassistant.components.cloud import const, prefs as cloud_prefs
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
PIPELINE_DATA = {
|
||||
"items": [
|
||||
{
|
||||
"conversation_engine": "conversation_engine_1",
|
||||
"conversation_language": "language_1",
|
||||
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
||||
"language": "language_1",
|
||||
"name": "Home Assistant Cloud",
|
||||
"stt_engine": "cloud",
|
||||
"stt_language": "language_1",
|
||||
"tts_engine": "cloud",
|
||||
"tts_language": "language_1",
|
||||
"tts_voice": "Arnold Schwarzenegger",
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
},
|
||||
{
|
||||
"conversation_engine": "conversation_engine_2",
|
||||
"conversation_language": "language_2",
|
||||
"id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
|
||||
"language": "language_2",
|
||||
"name": "name_2",
|
||||
"stt_engine": "stt_engine_2",
|
||||
"stt_language": "language_2",
|
||||
"tts_engine": "tts_engine_2",
|
||||
"tts_language": "language_2",
|
||||
"tts_voice": "The Voice",
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
},
|
||||
{
|
||||
"conversation_engine": "conversation_engine_3",
|
||||
"conversation_language": "language_3",
|
||||
"id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
|
||||
"language": "language_3",
|
||||
"name": "name_3",
|
||||
"stt_engine": None,
|
||||
"stt_language": None,
|
||||
"tts_engine": None,
|
||||
"tts_language": None,
|
||||
"tts_voice": None,
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
},
|
||||
],
|
||||
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
||||
}
|
||||
|
||||
|
||||
async def mock_cloud(hass, config=None):
|
||||
"""Mock cloud."""
|
||||
|
|
|
@ -15,11 +15,22 @@ import jwt
|
|||
import pytest
|
||||
|
||||
from homeassistant.components.cloud import CloudClient, const, prefs
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
from . import mock_cloud, mock_cloud_prefs
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def load_homeassistant(hass: HomeAssistant) -> None:
|
||||
"""Load the homeassistant integration.
|
||||
|
||||
This is needed for the cloud integration to work.
|
||||
"""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
||||
|
||||
@pytest.fixture(name="cloud")
|
||||
async def cloud_fixture() -> AsyncGenerator[MagicMock, None]:
|
||||
"""Mock the cloud object.
|
||||
|
|
16
tests/components/cloud/test_assist_pipeline.py
Normal file
16
tests/components/cloud/test_assist_pipeline.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
"""Test the cloud assist pipeline."""
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.cloud.assist_pipeline import (
|
||||
async_migrate_cloud_pipeline_engine,
|
||||
)
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
async def test_migrate_pipeline_invalid_platform(hass: HomeAssistant) -> None:
|
||||
"""Test migrate pipeline with invalid platform."""
|
||||
with pytest.raises(ValueError):
|
||||
await async_migrate_cloud_pipeline_engine(
|
||||
hass, Platform.BINARY_SENSOR, "test-engine-id"
|
||||
)
|
|
@ -147,15 +147,19 @@ async def test_google_actions_sync_fails(
|
|||
assert mock_request_sync.call_count == 1
|
||||
|
||||
|
||||
async def test_login_view_missing_stt_entity(
|
||||
@pytest.mark.parametrize(
|
||||
"entity_id", ["stt.home_assistant_cloud", "tts.home_assistant_cloud"]
|
||||
)
|
||||
async def test_login_view_missing_entity(
|
||||
hass: HomeAssistant,
|
||||
setup_cloud: None,
|
||||
entity_registry: er.EntityRegistry,
|
||||
hass_client: ClientSessionGenerator,
|
||||
entity_id: str,
|
||||
) -> None:
|
||||
"""Test logging in when the cloud stt entity is missing."""
|
||||
# Make sure that the cloud stt entity does not exist.
|
||||
entity_registry.async_remove("stt.home_assistant_cloud")
|
||||
"""Test logging in when a cloud assist pipeline needed entity is missing."""
|
||||
# Make sure that the cloud entity does not exist.
|
||||
entity_registry.async_remove(entity_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
cloud_client = await hass_client()
|
||||
|
@ -243,7 +247,7 @@ async def test_login_view_create_pipeline(
|
|||
create_pipeline_mock.assert_awaited_once_with(
|
||||
hass,
|
||||
stt_engine_id="stt.home_assistant_cloud",
|
||||
tts_engine_id="cloud",
|
||||
tts_engine_id="tts.home_assistant_cloud",
|
||||
pipeline_name="Home Assistant Cloud",
|
||||
)
|
||||
|
||||
|
@ -282,7 +286,7 @@ async def test_login_view_create_pipeline_fail(
|
|||
create_pipeline_mock.assert_awaited_once_with(
|
||||
hass,
|
||||
stt_engine_id="stt.home_assistant_cloud",
|
||||
tts_engine_id="cloud",
|
||||
tts_engine_id="tts.home_assistant_cloud",
|
||||
pipeline_name="Home Assistant Cloud",
|
||||
)
|
||||
|
||||
|
|
|
@ -14,62 +14,10 @@ from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
|||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import PIPELINE_DATA
|
||||
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
PIPELINE_DATA = {
|
||||
"items": [
|
||||
{
|
||||
"conversation_engine": "conversation_engine_1",
|
||||
"conversation_language": "language_1",
|
||||
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
||||
"language": "language_1",
|
||||
"name": "Home Assistant Cloud",
|
||||
"stt_engine": "cloud",
|
||||
"stt_language": "language_1",
|
||||
"tts_engine": "cloud",
|
||||
"tts_language": "language_1",
|
||||
"tts_voice": "Arnold Schwarzenegger",
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
},
|
||||
{
|
||||
"conversation_engine": "conversation_engine_2",
|
||||
"conversation_language": "language_2",
|
||||
"id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
|
||||
"language": "language_2",
|
||||
"name": "name_2",
|
||||
"stt_engine": "stt_engine_2",
|
||||
"stt_language": "language_2",
|
||||
"tts_engine": "tts_engine_2",
|
||||
"tts_language": "language_2",
|
||||
"tts_voice": "The Voice",
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
},
|
||||
{
|
||||
"conversation_engine": "conversation_engine_3",
|
||||
"conversation_language": "language_3",
|
||||
"id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
|
||||
"language": "language_3",
|
||||
"name": "name_3",
|
||||
"stt_engine": None,
|
||||
"stt_language": None,
|
||||
"tts_engine": None,
|
||||
"tts_language": None,
|
||||
"tts_voice": None,
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
},
|
||||
],
|
||||
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def load_homeassistant(hass: HomeAssistant) -> None:
|
||||
"""Load the homeassistant integration."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def delay_save_fixture() -> AsyncGenerator[None, None]:
|
||||
|
@ -143,6 +91,7 @@ async def test_migrating_pipelines(
|
|||
hass_storage: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test migrating pipelines when cloud stt entity is added."""
|
||||
entity_id = "stt.home_assistant_cloud"
|
||||
cloud.voice.process_stt = AsyncMock(
|
||||
return_value=STTResponse(True, "Turn the Kitchen Lights on")
|
||||
)
|
||||
|
@ -157,18 +106,18 @@ async def test_migrating_pipelines(
|
|||
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
on_start_callback = cloud.register_on_start.call_args[0][0]
|
||||
await on_start_callback()
|
||||
await cloud.login("test-user", "test-pass")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.get("stt.home_assistant_cloud")
|
||||
state = hass.states.get(entity_id)
|
||||
assert state
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
# The stt engine should be updated to the new cloud stt engine id.
|
||||
# The stt/tts engines should have been updated to the new cloud engine ids.
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_engine"] == entity_id
|
||||
assert (
|
||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_engine"]
|
||||
== "stt.home_assistant_cloud"
|
||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_engine"]
|
||||
== "tts.home_assistant_cloud"
|
||||
)
|
||||
|
||||
# The other items should stay the same.
|
||||
|
@ -189,7 +138,6 @@ async def test_migrating_pipelines(
|
|||
hass_storage[STORAGE_KEY]["data"]["items"][0]["name"] == "Home Assistant Cloud"
|
||||
)
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_language"] == "language_1"
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_engine"] == "cloud"
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_language"] == "language_1"
|
||||
assert (
|
||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_voice"]
|
||||
|
|
|
@ -1,23 +1,36 @@
|
|||
"""Tests for cloud tts."""
|
||||
from collections.abc import Callable, Coroutine
|
||||
from collections.abc import AsyncGenerator, Callable, Coroutine
|
||||
from copy import deepcopy
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from hass_nabucasa.voice import MAP_VOICE, VoiceError, VoiceTokenError
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY
|
||||
from homeassistant.components.cloud import DOMAIN, const, tts
|
||||
from homeassistant.components.tts import DOMAIN as TTS_DOMAIN
|
||||
from homeassistant.components.tts.helper import get_engine_instance
|
||||
from homeassistant.config import async_process_ha_core_config
|
||||
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_registry import EntityRegistry
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import PIPELINE_DATA
|
||||
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def delay_save_fixture() -> AsyncGenerator[None, None]:
|
||||
"""Load the homeassistant integration."""
|
||||
with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def internal_url_mock(hass: HomeAssistant) -> None:
|
||||
"""Mock internal URL of the instance."""
|
||||
|
@ -70,6 +83,10 @@ def test_schema() -> None:
|
|||
"gender": "female",
|
||||
},
|
||||
),
|
||||
(
|
||||
"tts.home_assistant_cloud",
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_prefs_default_voice(
|
||||
|
@ -104,9 +121,17 @@ async def test_prefs_default_voice(
|
|||
assert engine.default_options == {"gender": "male", "audio_output": "mp3"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"engine_id",
|
||||
[
|
||||
DOMAIN,
|
||||
"tts.home_assistant_cloud",
|
||||
],
|
||||
)
|
||||
async def test_provider_properties(
|
||||
hass: HomeAssistant,
|
||||
cloud: MagicMock,
|
||||
engine_id: str,
|
||||
) -> None:
|
||||
"""Test cloud provider."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
@ -115,7 +140,7 @@ async def test_provider_properties(
|
|||
on_start_callback = cloud.register_on_start.call_args[0][0]
|
||||
await on_start_callback()
|
||||
|
||||
engine = get_engine_instance(hass, DOMAIN)
|
||||
engine = get_engine_instance(hass, engine_id)
|
||||
|
||||
assert engine is not None
|
||||
assert engine.supported_options == ["gender", "voice", "audio_output"]
|
||||
|
@ -132,6 +157,7 @@ async def test_provider_properties(
|
|||
[
|
||||
({"platform": DOMAIN}, DOMAIN),
|
||||
({"engine_id": DOMAIN}, DOMAIN),
|
||||
({"engine_id": "tts.home_assistant_cloud"}, "tts.home_assistant_cloud"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -241,3 +267,144 @@ async def test_get_tts_audio_logged_out(
|
|||
assert mock_process_tts.call_args.kwargs["language"] == "en-US"
|
||||
assert mock_process_tts.call_args.kwargs["gender"] == "female"
|
||||
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mock_process_tts_return_value", "mock_process_tts_side_effect"),
|
||||
[
|
||||
(b"", None),
|
||||
(None, VoiceError("Boom!")),
|
||||
],
|
||||
)
|
||||
async def test_tts_entity(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
entity_registry: EntityRegistry,
|
||||
cloud: MagicMock,
|
||||
mock_process_tts_return_value: bytes | None,
|
||||
mock_process_tts_side_effect: Exception | None,
|
||||
) -> None:
|
||||
"""Test text-to-speech entity."""
|
||||
mock_process_tts = AsyncMock(
|
||||
return_value=mock_process_tts_return_value,
|
||||
side_effect=mock_process_tts_side_effect,
|
||||
)
|
||||
cloud.voice.process_tts = mock_process_tts
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
|
||||
await hass.async_block_till_done()
|
||||
on_start_callback = cloud.register_on_start.call_args[0][0]
|
||||
await on_start_callback()
|
||||
client = await hass_client()
|
||||
entity_id = "tts.home_assistant_cloud"
|
||||
|
||||
state = hass.states.get(entity_id)
|
||||
assert state
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
url = "/api/tts_get_url"
|
||||
data = {
|
||||
"engine_id": entity_id,
|
||||
"message": "There is someone at the door.",
|
||||
}
|
||||
|
||||
req = await client.post(url, json=data)
|
||||
assert req.status == HTTPStatus.OK
|
||||
response = await req.json()
|
||||
|
||||
assert response == {
|
||||
"url": (
|
||||
"http://example.local:8123/api/tts_proxy/"
|
||||
"42f18378fd4393d18c8dd11d03fa9563c1e54491"
|
||||
f"_en-us_e09b5a0968_{entity_id}.mp3"
|
||||
),
|
||||
"path": (
|
||||
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
|
||||
f"_en-us_e09b5a0968_{entity_id}.mp3"
|
||||
),
|
||||
}
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_process_tts.call_count == 1
|
||||
assert mock_process_tts.call_args is not None
|
||||
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
|
||||
assert mock_process_tts.call_args.kwargs["language"] == "en-US"
|
||||
assert mock_process_tts.call_args.kwargs["gender"] == "female"
|
||||
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||
|
||||
state = hass.states.get(entity_id)
|
||||
assert state
|
||||
assert state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
|
||||
|
||||
# Test removing the entity
|
||||
entity_registry.async_remove(entity_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.get(entity_id)
|
||||
assert state is None
|
||||
|
||||
|
||||
async def test_migrating_pipelines(
|
||||
hass: HomeAssistant,
|
||||
cloud: MagicMock,
|
||||
hass_client: ClientSessionGenerator,
|
||||
hass_storage: dict[str, Any],
|
||||
) -> None:
|
||||
"""Test migrating pipelines when cloud tts entity is added."""
|
||||
entity_id = "tts.home_assistant_cloud"
|
||||
mock_process_tts = AsyncMock(
|
||||
return_value=b"",
|
||||
)
|
||||
cloud.voice.process_tts = mock_process_tts
|
||||
hass_storage[STORAGE_KEY] = {
|
||||
"version": 1,
|
||||
"minor_version": 1,
|
||||
"key": "assist_pipeline.pipelines",
|
||||
"data": deepcopy(PIPELINE_DATA),
|
||||
}
|
||||
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
await cloud.login("test-user", "test-pass")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.get(entity_id)
|
||||
assert state
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
# The stt/tts engines should have been updated to the new cloud engine ids.
|
||||
assert (
|
||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_engine"]
|
||||
== "stt.home_assistant_cloud"
|
||||
)
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_engine"] == entity_id
|
||||
|
||||
# The other items should stay the same.
|
||||
assert (
|
||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["conversation_engine"]
|
||||
== "conversation_engine_1"
|
||||
)
|
||||
assert (
|
||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["conversation_language"]
|
||||
== "language_1"
|
||||
)
|
||||
assert (
|
||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["id"]
|
||||
== "01GX8ZWBAQYWNB1XV3EXEZ75DY"
|
||||
)
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["language"] == "language_1"
|
||||
assert (
|
||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["name"] == "Home Assistant Cloud"
|
||||
)
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_language"] == "language_1"
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_language"] == "language_1"
|
||||
assert (
|
||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_voice"]
|
||||
== "Arnold Schwarzenegger"
|
||||
)
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_entity"] is None
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_id"] is None
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][1] == PIPELINE_DATA["items"][1]
|
||||
assert hass_storage[STORAGE_KEY]["data"]["items"][2] == PIPELINE_DATA["items"][2]
|
||||
|
|
Loading…
Add table
Reference in a new issue