Add cloud tts entity (#108293)
* Add cloud tts entity * Test test_login_view_missing_entity * Fix pipeline iteration for migration * Update tests * Make migration more strict * Fix docstring
This commit is contained in:
parent
d0da457a04
commit
e086cd9fef
12 changed files with 428 additions and 102 deletions
|
@ -65,7 +65,7 @@ from .subscription import async_subscription_info
|
||||||
|
|
||||||
DEFAULT_MODE = MODE_PROD
|
DEFAULT_MODE = MODE_PROD
|
||||||
|
|
||||||
PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT]
|
PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT, Platform.TTS]
|
||||||
|
|
||||||
SERVICE_REMOTE_CONNECT = "remote_connect"
|
SERVICE_REMOTE_CONNECT = "remote_connect"
|
||||||
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"
|
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"
|
||||||
|
@ -288,9 +288,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
loaded = False
|
loaded = False
|
||||||
stt_platform_loaded = asyncio.Event()
|
stt_platform_loaded = asyncio.Event()
|
||||||
tts_platform_loaded = asyncio.Event()
|
tts_platform_loaded = asyncio.Event()
|
||||||
|
stt_tts_entities_added = asyncio.Event()
|
||||||
hass.data[DATA_PLATFORMS_SETUP] = {
|
hass.data[DATA_PLATFORMS_SETUP] = {
|
||||||
Platform.STT: stt_platform_loaded,
|
Platform.STT: stt_platform_loaded,
|
||||||
Platform.TTS: tts_platform_loaded,
|
Platform.TTS: tts_platform_loaded,
|
||||||
|
"stt_tts_entities_added": stt_tts_entities_added,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _on_start() -> None:
|
async def _on_start() -> None:
|
||||||
|
@ -330,6 +332,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
|
|
||||||
account_link.async_setup(hass)
|
account_link.async_setup(hass)
|
||||||
|
|
||||||
|
# Load legacy tts platform for backwards compatibility.
|
||||||
hass.async_create_task(
|
hass.async_create_task(
|
||||||
async_load_platform(
|
async_load_platform(
|
||||||
hass,
|
hass,
|
||||||
|
@ -377,8 +380,10 @@ def _remote_handle_prefs_updated(cloud: Cloud[CloudClient]) -> None:
|
||||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
"""Set up a config entry."""
|
"""Set up a config entry."""
|
||||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||||
stt_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.STT]
|
stt_tts_entities_added: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][
|
||||||
stt_platform_loaded.set()
|
"stt_tts_entities_added"
|
||||||
|
]
|
||||||
|
stt_tts_entities_added.set()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
|
@ -9,16 +9,23 @@ from homeassistant.components.assist_pipeline import (
|
||||||
)
|
)
|
||||||
from homeassistant.components.conversation import HOME_ASSISTANT_AGENT
|
from homeassistant.components.conversation import HOME_ASSISTANT_AGENT
|
||||||
from homeassistant.components.stt import DOMAIN as STT_DOMAIN
|
from homeassistant.components.stt import DOMAIN as STT_DOMAIN
|
||||||
|
from homeassistant.components.tts import DOMAIN as TTS_DOMAIN
|
||||||
from homeassistant.const import Platform
|
from homeassistant.const import Platform
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
import homeassistant.helpers.entity_registry as er
|
import homeassistant.helpers.entity_registry as er
|
||||||
|
|
||||||
from .const import DATA_PLATFORMS_SETUP, DOMAIN, STT_ENTITY_UNIQUE_ID
|
from .const import (
|
||||||
|
DATA_PLATFORMS_SETUP,
|
||||||
|
DOMAIN,
|
||||||
|
STT_ENTITY_UNIQUE_ID,
|
||||||
|
TTS_ENTITY_UNIQUE_ID,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
|
async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
|
||||||
"""Create a cloud assist pipeline."""
|
"""Create a cloud assist pipeline."""
|
||||||
# Wait for stt and tts platforms to set up before creating the pipeline.
|
# Wait for stt and tts platforms to set up and entities to be added
|
||||||
|
# before creating the pipeline.
|
||||||
platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP]
|
platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP]
|
||||||
await asyncio.gather(*(event.wait() for event in platforms_setup.values()))
|
await asyncio.gather(*(event.wait() for event in platforms_setup.values()))
|
||||||
# Make sure the pipeline store is loaded, needed because assist_pipeline
|
# Make sure the pipeline store is loaded, needed because assist_pipeline
|
||||||
|
@ -29,8 +36,11 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
|
||||||
new_stt_engine_id = entity_registry.async_get_entity_id(
|
new_stt_engine_id = entity_registry.async_get_entity_id(
|
||||||
STT_DOMAIN, DOMAIN, STT_ENTITY_UNIQUE_ID
|
STT_DOMAIN, DOMAIN, STT_ENTITY_UNIQUE_ID
|
||||||
)
|
)
|
||||||
if new_stt_engine_id is None:
|
new_tts_engine_id = entity_registry.async_get_entity_id(
|
||||||
# If there's no cloud stt entity, we can't create a cloud pipeline.
|
TTS_DOMAIN, DOMAIN, TTS_ENTITY_UNIQUE_ID
|
||||||
|
)
|
||||||
|
if new_stt_engine_id is None or new_tts_engine_id is None:
|
||||||
|
# If there's no cloud stt or tts entity, we can't create a cloud pipeline.
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
|
def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
|
||||||
|
@ -43,7 +53,7 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
|
||||||
if (
|
if (
|
||||||
pipeline.conversation_engine == HOME_ASSISTANT_AGENT
|
pipeline.conversation_engine == HOME_ASSISTANT_AGENT
|
||||||
and pipeline.stt_engine in (DOMAIN, new_stt_engine_id)
|
and pipeline.stt_engine in (DOMAIN, new_stt_engine_id)
|
||||||
and pipeline.tts_engine == DOMAIN
|
and pipeline.tts_engine in (DOMAIN, new_tts_engine_id)
|
||||||
):
|
):
|
||||||
return pipeline.id
|
return pipeline.id
|
||||||
return None
|
return None
|
||||||
|
@ -52,7 +62,7 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
|
||||||
cloud_pipeline := await async_create_default_pipeline(
|
cloud_pipeline := await async_create_default_pipeline(
|
||||||
hass,
|
hass,
|
||||||
stt_engine_id=new_stt_engine_id,
|
stt_engine_id=new_stt_engine_id,
|
||||||
tts_engine_id=DOMAIN,
|
tts_engine_id=new_tts_engine_id,
|
||||||
pipeline_name="Home Assistant Cloud",
|
pipeline_name="Home Assistant Cloud",
|
||||||
)
|
)
|
||||||
) is None:
|
) is None:
|
||||||
|
@ -61,25 +71,34 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
|
||||||
return cloud_pipeline.id
|
return cloud_pipeline.id
|
||||||
|
|
||||||
|
|
||||||
async def async_migrate_cloud_pipeline_stt_engine(
|
async def async_migrate_cloud_pipeline_engine(
|
||||||
hass: HomeAssistant, stt_engine_id: str
|
hass: HomeAssistant, platform: Platform, engine_id: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Migrate the speech-to-text engine in the cloud assist pipeline."""
|
"""Migrate the pipeline engines in the cloud assist pipeline."""
|
||||||
# Migrate existing pipelines with cloud stt to use new cloud stt engine id.
|
# Migrate existing pipelines with cloud stt or tts to use new cloud engine id.
|
||||||
# Added in 2024.01.0. Can be removed in 2025.01.0.
|
# Added in 2024.02.0. Can be removed in 2025.02.0.
|
||||||
|
|
||||||
|
# We need to make sure that both stt and tts are loaded before this migration.
|
||||||
|
# Assist pipeline will call default engine when setting up the store.
|
||||||
|
# Wait for the stt or tts platform loaded event here.
|
||||||
|
if platform == Platform.STT:
|
||||||
|
wait_for_platform = Platform.TTS
|
||||||
|
pipeline_attribute = "stt_engine"
|
||||||
|
elif platform == Platform.TTS:
|
||||||
|
wait_for_platform = Platform.STT
|
||||||
|
pipeline_attribute = "tts_engine"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid platform {platform}")
|
||||||
|
|
||||||
# We need to make sure that tts is loaded before this migration.
|
|
||||||
# Assist pipeline will call default engine of tts when setting up the store.
|
|
||||||
# Wait for the tts platform loaded event here.
|
|
||||||
platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP]
|
platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP]
|
||||||
await platforms_setup[Platform.TTS].wait()
|
await platforms_setup[wait_for_platform].wait()
|
||||||
|
|
||||||
# Make sure the pipeline store is loaded, needed because assist_pipeline
|
# Make sure the pipeline store is loaded, needed because assist_pipeline
|
||||||
# is an after dependency of cloud
|
# is an after dependency of cloud
|
||||||
await async_setup_pipeline_store(hass)
|
await async_setup_pipeline_store(hass)
|
||||||
|
|
||||||
|
kwargs: dict[str, str] = {pipeline_attribute: engine_id}
|
||||||
pipelines = async_get_pipelines(hass)
|
pipelines = async_get_pipelines(hass)
|
||||||
for pipeline in pipelines:
|
for pipeline in pipelines:
|
||||||
if pipeline.stt_engine != DOMAIN:
|
if getattr(pipeline, pipeline_attribute) == DOMAIN:
|
||||||
continue
|
await async_update_pipeline(hass, pipeline, **kwargs)
|
||||||
await async_update_pipeline(hass, pipeline, stt_engine=stt_engine_id)
|
|
||||||
|
|
|
@ -73,3 +73,4 @@ MODE_PROD = "production"
|
||||||
DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update")
|
DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update")
|
||||||
|
|
||||||
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
|
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
|
||||||
|
TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"
|
||||||
|
|
|
@ -104,10 +104,18 @@ class CloudPreferences:
|
||||||
@callback
|
@callback
|
||||||
def async_listen_updates(
|
def async_listen_updates(
|
||||||
self, listener: Callable[[CloudPreferences], Coroutine[Any, Any, None]]
|
self, listener: Callable[[CloudPreferences], Coroutine[Any, Any, None]]
|
||||||
) -> None:
|
) -> Callable[[], None]:
|
||||||
"""Listen for updates to the preferences."""
|
"""Listen for updates to the preferences."""
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def unsubscribe() -> None:
|
||||||
|
"""Remove the listener."""
|
||||||
|
self._listeners.remove(listener)
|
||||||
|
|
||||||
self._listeners.append(listener)
|
self._listeners.append(listener)
|
||||||
|
|
||||||
|
return unsubscribe
|
||||||
|
|
||||||
async def async_update(
|
async def async_update(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Support for the cloud for speech to text service."""
|
"""Support for the cloud for speech to text service."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from collections.abc import AsyncIterable
|
from collections.abc import AsyncIterable
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -19,12 +20,13 @@ from homeassistant.components.stt import (
|
||||||
SpeechToTextEntity,
|
SpeechToTextEntity,
|
||||||
)
|
)
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import Platform
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
|
||||||
from .assist_pipeline import async_migrate_cloud_pipeline_stt_engine
|
from .assist_pipeline import async_migrate_cloud_pipeline_engine
|
||||||
from .client import CloudClient
|
from .client import CloudClient
|
||||||
from .const import DOMAIN, STT_ENTITY_UNIQUE_ID
|
from .const import DATA_PLATFORMS_SETUP, DOMAIN, STT_ENTITY_UNIQUE_ID
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -35,18 +37,20 @@ async def async_setup_entry(
|
||||||
async_add_entities: AddEntitiesCallback,
|
async_add_entities: AddEntitiesCallback,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up Home Assistant Cloud speech platform via config entry."""
|
"""Set up Home Assistant Cloud speech platform via config entry."""
|
||||||
|
stt_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.STT]
|
||||||
|
stt_platform_loaded.set()
|
||||||
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
|
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
|
||||||
async_add_entities([CloudProviderEntity(cloud)])
|
async_add_entities([CloudProviderEntity(cloud)])
|
||||||
|
|
||||||
|
|
||||||
class CloudProviderEntity(SpeechToTextEntity):
|
class CloudProviderEntity(SpeechToTextEntity):
|
||||||
"""NabuCasa speech API provider."""
|
"""Home Assistant Cloud speech API provider."""
|
||||||
|
|
||||||
_attr_name = "Home Assistant Cloud"
|
_attr_name = "Home Assistant Cloud"
|
||||||
_attr_unique_id = STT_ENTITY_UNIQUE_ID
|
_attr_unique_id = STT_ENTITY_UNIQUE_ID
|
||||||
|
|
||||||
def __init__(self, cloud: Cloud[CloudClient]) -> None:
|
def __init__(self, cloud: Cloud[CloudClient]) -> None:
|
||||||
"""Home Assistant NabuCasa Speech to text."""
|
"""Initialize cloud Speech to text entity."""
|
||||||
self.cloud = cloud
|
self.cloud = cloud
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -81,7 +85,9 @@ class CloudProviderEntity(SpeechToTextEntity):
|
||||||
|
|
||||||
async def async_added_to_hass(self) -> None:
|
async def async_added_to_hass(self) -> None:
|
||||||
"""Run when entity is about to be added to hass."""
|
"""Run when entity is about to be added to hass."""
|
||||||
await async_migrate_cloud_pipeline_stt_engine(self.hass, self.entity_id)
|
await async_migrate_cloud_pipeline_engine(
|
||||||
|
self.hass, platform=Platform.STT, engine_id=self.entity_id
|
||||||
|
)
|
||||||
|
|
||||||
async def async_process_audio_stream(
|
async def async_process_audio_stream(
|
||||||
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Support for the cloud for text-to-speech service."""
|
"""Support for the cloud for text-to-speech service."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
@ -12,16 +13,21 @@ from homeassistant.components.tts import (
|
||||||
ATTR_AUDIO_OUTPUT,
|
ATTR_AUDIO_OUTPUT,
|
||||||
ATTR_VOICE,
|
ATTR_VOICE,
|
||||||
CONF_LANG,
|
CONF_LANG,
|
||||||
PLATFORM_SCHEMA,
|
PLATFORM_SCHEMA as TTS_PLATFORM_SCHEMA,
|
||||||
Provider,
|
Provider,
|
||||||
|
TextToSpeechEntity,
|
||||||
TtsAudioType,
|
TtsAudioType,
|
||||||
Voice,
|
Voice,
|
||||||
)
|
)
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import Platform
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
|
||||||
|
from .assist_pipeline import async_migrate_cloud_pipeline_engine
|
||||||
from .client import CloudClient
|
from .client import CloudClient
|
||||||
from .const import DOMAIN
|
from .const import DATA_PLATFORMS_SETUP, DOMAIN, TTS_ENTITY_UNIQUE_ID
|
||||||
from .prefs import CloudPreferences
|
from .prefs import CloudPreferences
|
||||||
|
|
||||||
ATTR_GENDER = "gender"
|
ATTR_GENDER = "gender"
|
||||||
|
@ -48,7 +54,7 @@ def validate_lang(value: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
|
||||||
|
|
||||||
PLATFORM_SCHEMA = vol.All(
|
PLATFORM_SCHEMA = vol.All(
|
||||||
PLATFORM_SCHEMA.extend(
|
TTS_PLATFORM_SCHEMA.extend(
|
||||||
{
|
{
|
||||||
vol.Optional(CONF_LANG): str,
|
vol.Optional(CONF_LANG): str,
|
||||||
vol.Optional(ATTR_GENDER): str,
|
vol.Optional(ATTR_GENDER): str,
|
||||||
|
@ -81,8 +87,95 @@ async def async_get_engine(
|
||||||
return cloud_provider
|
return cloud_provider
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up Home Assistant Cloud text-to-speech platform."""
|
||||||
|
tts_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.TTS]
|
||||||
|
tts_platform_loaded.set()
|
||||||
|
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
|
||||||
|
async_add_entities([CloudTTSEntity(cloud)])
|
||||||
|
|
||||||
|
|
||||||
|
class CloudTTSEntity(TextToSpeechEntity):
|
||||||
|
"""Home Assistant Cloud text-to-speech entity."""
|
||||||
|
|
||||||
|
_attr_name = "Home Assistant Cloud"
|
||||||
|
_attr_unique_id = TTS_ENTITY_UNIQUE_ID
|
||||||
|
|
||||||
|
def __init__(self, cloud: Cloud[CloudClient]) -> None:
|
||||||
|
"""Initialize cloud text-to-speech entity."""
|
||||||
|
self.cloud = cloud
|
||||||
|
self._language, self._gender = cloud.client.prefs.tts_default_voice
|
||||||
|
|
||||||
|
async def _sync_prefs(self, prefs: CloudPreferences) -> None:
|
||||||
|
"""Sync preferences."""
|
||||||
|
self._language, self._gender = prefs.tts_default_voice
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_language(self) -> str:
|
||||||
|
"""Return the default language."""
|
||||||
|
return self._language
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_options(self) -> dict[str, Any]:
|
||||||
|
"""Return a dict include default options."""
|
||||||
|
return {
|
||||||
|
ATTR_GENDER: self._gender,
|
||||||
|
ATTR_AUDIO_OUTPUT: AudioOutput.MP3,
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_languages(self) -> list[str]:
|
||||||
|
"""Return list of supported languages."""
|
||||||
|
return SUPPORT_LANGUAGES
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_options(self) -> list[str]:
|
||||||
|
"""Return list of supported options like voice, emotion."""
|
||||||
|
return [ATTR_GENDER, ATTR_VOICE, ATTR_AUDIO_OUTPUT]
|
||||||
|
|
||||||
|
async def async_added_to_hass(self) -> None:
|
||||||
|
"""Handle entity which will be added."""
|
||||||
|
await super().async_added_to_hass()
|
||||||
|
await async_migrate_cloud_pipeline_engine(
|
||||||
|
self.hass, platform=Platform.TTS, engine_id=self.entity_id
|
||||||
|
)
|
||||||
|
self.async_on_remove(
|
||||||
|
self.cloud.client.prefs.async_listen_updates(self._sync_prefs)
|
||||||
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_supported_voices(self, language: str) -> list[Voice] | None:
|
||||||
|
"""Return a list of supported voices for a language."""
|
||||||
|
if not (voices := TTS_VOICES.get(language)):
|
||||||
|
return None
|
||||||
|
return [Voice(voice, voice) for voice in voices]
|
||||||
|
|
||||||
|
async def async_get_tts_audio(
|
||||||
|
self, message: str, language: str, options: dict[str, Any]
|
||||||
|
) -> TtsAudioType:
|
||||||
|
"""Load TTS from Home Assistant Cloud."""
|
||||||
|
# Process TTS
|
||||||
|
try:
|
||||||
|
data = await self.cloud.voice.process_tts(
|
||||||
|
text=message,
|
||||||
|
language=language,
|
||||||
|
gender=options.get(ATTR_GENDER),
|
||||||
|
voice=options.get(ATTR_VOICE),
|
||||||
|
output=options[ATTR_AUDIO_OUTPUT],
|
||||||
|
)
|
||||||
|
except VoiceError as err:
|
||||||
|
_LOGGER.error("Voice error: %s", err)
|
||||||
|
return (None, None)
|
||||||
|
|
||||||
|
return (str(options[ATTR_AUDIO_OUTPUT].value), data)
|
||||||
|
|
||||||
|
|
||||||
class CloudProvider(Provider):
|
class CloudProvider(Provider):
|
||||||
"""NabuCasa Cloud speech API provider."""
|
"""Home Assistant Cloud speech API provider."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, cloud: Cloud[CloudClient], language: str | None, gender: str | None
|
self, cloud: Cloud[CloudClient], language: str | None, gender: str | None
|
||||||
|
@ -136,7 +229,7 @@ class CloudProvider(Provider):
|
||||||
async def async_get_tts_audio(
|
async def async_get_tts_audio(
|
||||||
self, message: str, language: str, options: dict[str, Any]
|
self, message: str, language: str, options: dict[str, Any]
|
||||||
) -> TtsAudioType:
|
) -> TtsAudioType:
|
||||||
"""Load TTS from NabuCasa Cloud."""
|
"""Load TTS from Home Assistant Cloud."""
|
||||||
# Process TTS
|
# Process TTS
|
||||||
try:
|
try:
|
||||||
data = await self.cloud.voice.process_tts(
|
data = await self.cloud.voice.process_tts(
|
||||||
|
|
|
@ -7,6 +7,54 @@ from homeassistant.components import cloud
|
||||||
from homeassistant.components.cloud import const, prefs as cloud_prefs
|
from homeassistant.components.cloud import const, prefs as cloud_prefs
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
PIPELINE_DATA = {
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"conversation_engine": "conversation_engine_1",
|
||||||
|
"conversation_language": "language_1",
|
||||||
|
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
||||||
|
"language": "language_1",
|
||||||
|
"name": "Home Assistant Cloud",
|
||||||
|
"stt_engine": "cloud",
|
||||||
|
"stt_language": "language_1",
|
||||||
|
"tts_engine": "cloud",
|
||||||
|
"tts_language": "language_1",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
|
"wake_word_entity": None,
|
||||||
|
"wake_word_id": None,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"conversation_engine": "conversation_engine_2",
|
||||||
|
"conversation_language": "language_2",
|
||||||
|
"id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
|
||||||
|
"language": "language_2",
|
||||||
|
"name": "name_2",
|
||||||
|
"stt_engine": "stt_engine_2",
|
||||||
|
"stt_language": "language_2",
|
||||||
|
"tts_engine": "tts_engine_2",
|
||||||
|
"tts_language": "language_2",
|
||||||
|
"tts_voice": "The Voice",
|
||||||
|
"wake_word_entity": None,
|
||||||
|
"wake_word_id": None,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"conversation_engine": "conversation_engine_3",
|
||||||
|
"conversation_language": "language_3",
|
||||||
|
"id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
|
||||||
|
"language": "language_3",
|
||||||
|
"name": "name_3",
|
||||||
|
"stt_engine": None,
|
||||||
|
"stt_language": None,
|
||||||
|
"tts_engine": None,
|
||||||
|
"tts_language": None,
|
||||||
|
"tts_voice": None,
|
||||||
|
"wake_word_entity": None,
|
||||||
|
"wake_word_id": None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def mock_cloud(hass, config=None):
|
async def mock_cloud(hass, config=None):
|
||||||
"""Mock cloud."""
|
"""Mock cloud."""
|
||||||
|
|
|
@ -15,11 +15,22 @@ import jwt
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.cloud import CloudClient, const, prefs
|
from homeassistant.components.cloud import CloudClient, const, prefs
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.util.dt import utcnow
|
from homeassistant.util.dt import utcnow
|
||||||
|
|
||||||
from . import mock_cloud, mock_cloud_prefs
|
from . import mock_cloud, mock_cloud_prefs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def load_homeassistant(hass: HomeAssistant) -> None:
|
||||||
|
"""Load the homeassistant integration.
|
||||||
|
|
||||||
|
This is needed for the cloud integration to work.
|
||||||
|
"""
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="cloud")
|
@pytest.fixture(name="cloud")
|
||||||
async def cloud_fixture() -> AsyncGenerator[MagicMock, None]:
|
async def cloud_fixture() -> AsyncGenerator[MagicMock, None]:
|
||||||
"""Mock the cloud object.
|
"""Mock the cloud object.
|
||||||
|
|
16
tests/components/cloud/test_assist_pipeline.py
Normal file
16
tests/components/cloud/test_assist_pipeline.py
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
"""Test the cloud assist pipeline."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components.cloud.assist_pipeline import (
|
||||||
|
async_migrate_cloud_pipeline_engine,
|
||||||
|
)
|
||||||
|
from homeassistant.const import Platform
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
|
||||||
|
async def test_migrate_pipeline_invalid_platform(hass: HomeAssistant) -> None:
|
||||||
|
"""Test migrate pipeline with invalid platform."""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await async_migrate_cloud_pipeline_engine(
|
||||||
|
hass, Platform.BINARY_SENSOR, "test-engine-id"
|
||||||
|
)
|
|
@ -147,15 +147,19 @@ async def test_google_actions_sync_fails(
|
||||||
assert mock_request_sync.call_count == 1
|
assert mock_request_sync.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_login_view_missing_stt_entity(
|
@pytest.mark.parametrize(
|
||||||
|
"entity_id", ["stt.home_assistant_cloud", "tts.home_assistant_cloud"]
|
||||||
|
)
|
||||||
|
async def test_login_view_missing_entity(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
setup_cloud: None,
|
setup_cloud: None,
|
||||||
entity_registry: er.EntityRegistry,
|
entity_registry: er.EntityRegistry,
|
||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
|
entity_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test logging in when the cloud stt entity is missing."""
|
"""Test logging in when a cloud assist pipeline needed entity is missing."""
|
||||||
# Make sure that the cloud stt entity does not exist.
|
# Make sure that the cloud entity does not exist.
|
||||||
entity_registry.async_remove("stt.home_assistant_cloud")
|
entity_registry.async_remove(entity_id)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
cloud_client = await hass_client()
|
cloud_client = await hass_client()
|
||||||
|
@ -243,7 +247,7 @@ async def test_login_view_create_pipeline(
|
||||||
create_pipeline_mock.assert_awaited_once_with(
|
create_pipeline_mock.assert_awaited_once_with(
|
||||||
hass,
|
hass,
|
||||||
stt_engine_id="stt.home_assistant_cloud",
|
stt_engine_id="stt.home_assistant_cloud",
|
||||||
tts_engine_id="cloud",
|
tts_engine_id="tts.home_assistant_cloud",
|
||||||
pipeline_name="Home Assistant Cloud",
|
pipeline_name="Home Assistant Cloud",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -282,7 +286,7 @@ async def test_login_view_create_pipeline_fail(
|
||||||
create_pipeline_mock.assert_awaited_once_with(
|
create_pipeline_mock.assert_awaited_once_with(
|
||||||
hass,
|
hass,
|
||||||
stt_engine_id="stt.home_assistant_cloud",
|
stt_engine_id="stt.home_assistant_cloud",
|
||||||
tts_engine_id="cloud",
|
tts_engine_id="tts.home_assistant_cloud",
|
||||||
pipeline_name="Home Assistant Cloud",
|
pipeline_name="Home Assistant Cloud",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -14,62 +14,10 @@ from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from . import PIPELINE_DATA
|
||||||
|
|
||||||
from tests.typing import ClientSessionGenerator
|
from tests.typing import ClientSessionGenerator
|
||||||
|
|
||||||
PIPELINE_DATA = {
|
|
||||||
"items": [
|
|
||||||
{
|
|
||||||
"conversation_engine": "conversation_engine_1",
|
|
||||||
"conversation_language": "language_1",
|
|
||||||
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
|
||||||
"language": "language_1",
|
|
||||||
"name": "Home Assistant Cloud",
|
|
||||||
"stt_engine": "cloud",
|
|
||||||
"stt_language": "language_1",
|
|
||||||
"tts_engine": "cloud",
|
|
||||||
"tts_language": "language_1",
|
|
||||||
"tts_voice": "Arnold Schwarzenegger",
|
|
||||||
"wake_word_entity": None,
|
|
||||||
"wake_word_id": None,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"conversation_engine": "conversation_engine_2",
|
|
||||||
"conversation_language": "language_2",
|
|
||||||
"id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
|
|
||||||
"language": "language_2",
|
|
||||||
"name": "name_2",
|
|
||||||
"stt_engine": "stt_engine_2",
|
|
||||||
"stt_language": "language_2",
|
|
||||||
"tts_engine": "tts_engine_2",
|
|
||||||
"tts_language": "language_2",
|
|
||||||
"tts_voice": "The Voice",
|
|
||||||
"wake_word_entity": None,
|
|
||||||
"wake_word_id": None,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"conversation_engine": "conversation_engine_3",
|
|
||||||
"conversation_language": "language_3",
|
|
||||||
"id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
|
|
||||||
"language": "language_3",
|
|
||||||
"name": "name_3",
|
|
||||||
"stt_engine": None,
|
|
||||||
"stt_language": None,
|
|
||||||
"tts_engine": None,
|
|
||||||
"tts_language": None,
|
|
||||||
"tts_voice": None,
|
|
||||||
"wake_word_entity": None,
|
|
||||||
"wake_word_id": None,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
async def load_homeassistant(hass: HomeAssistant) -> None:
|
|
||||||
"""Load the homeassistant integration."""
|
|
||||||
assert await async_setup_component(hass, "homeassistant", {})
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
async def delay_save_fixture() -> AsyncGenerator[None, None]:
|
async def delay_save_fixture() -> AsyncGenerator[None, None]:
|
||||||
|
@ -143,6 +91,7 @@ async def test_migrating_pipelines(
|
||||||
hass_storage: dict[str, Any],
|
hass_storage: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test migrating pipelines when cloud stt entity is added."""
|
"""Test migrating pipelines when cloud stt entity is added."""
|
||||||
|
entity_id = "stt.home_assistant_cloud"
|
||||||
cloud.voice.process_stt = AsyncMock(
|
cloud.voice.process_stt = AsyncMock(
|
||||||
return_value=STTResponse(True, "Turn the Kitchen Lights on")
|
return_value=STTResponse(True, "Turn the Kitchen Lights on")
|
||||||
)
|
)
|
||||||
|
@ -157,18 +106,18 @@ async def test_migrating_pipelines(
|
||||||
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
|
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
on_start_callback = cloud.register_on_start.call_args[0][0]
|
await cloud.login("test-user", "test-pass")
|
||||||
await on_start_callback()
|
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
state = hass.states.get("stt.home_assistant_cloud")
|
state = hass.states.get(entity_id)
|
||||||
assert state
|
assert state
|
||||||
assert state.state == STATE_UNKNOWN
|
assert state.state == STATE_UNKNOWN
|
||||||
|
|
||||||
# The stt engine should be updated to the new cloud stt engine id.
|
# The stt/tts engines should have been updated to the new cloud engine ids.
|
||||||
|
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_engine"] == entity_id
|
||||||
assert (
|
assert (
|
||||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_engine"]
|
hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_engine"]
|
||||||
== "stt.home_assistant_cloud"
|
== "tts.home_assistant_cloud"
|
||||||
)
|
)
|
||||||
|
|
||||||
# The other items should stay the same.
|
# The other items should stay the same.
|
||||||
|
@ -189,7 +138,6 @@ async def test_migrating_pipelines(
|
||||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["name"] == "Home Assistant Cloud"
|
hass_storage[STORAGE_KEY]["data"]["items"][0]["name"] == "Home Assistant Cloud"
|
||||||
)
|
)
|
||||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_language"] == "language_1"
|
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_language"] == "language_1"
|
||||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_engine"] == "cloud"
|
|
||||||
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_language"] == "language_1"
|
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_language"] == "language_1"
|
||||||
assert (
|
assert (
|
||||||
hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_voice"]
|
hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_voice"]
|
||||||
|
|
|
@ -1,23 +1,36 @@
|
||||||
"""Tests for cloud tts."""
|
"""Tests for cloud tts."""
|
||||||
from collections.abc import Callable, Coroutine
|
from collections.abc import AsyncGenerator, Callable, Coroutine
|
||||||
|
from copy import deepcopy
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
from hass_nabucasa.voice import MAP_VOICE, VoiceError, VoiceTokenError
|
from hass_nabucasa.voice import MAP_VOICE, VoiceError, VoiceTokenError
|
||||||
import pytest
|
import pytest
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY
|
||||||
from homeassistant.components.cloud import DOMAIN, const, tts
|
from homeassistant.components.cloud import DOMAIN, const, tts
|
||||||
from homeassistant.components.tts import DOMAIN as TTS_DOMAIN
|
from homeassistant.components.tts import DOMAIN as TTS_DOMAIN
|
||||||
from homeassistant.components.tts.helper import get_engine_instance
|
from homeassistant.components.tts.helper import get_engine_instance
|
||||||
from homeassistant.config import async_process_ha_core_config
|
from homeassistant.config import async_process_ha_core_config
|
||||||
|
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.entity_registry import EntityRegistry
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from . import PIPELINE_DATA
|
||||||
|
|
||||||
from tests.typing import ClientSessionGenerator
|
from tests.typing import ClientSessionGenerator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def delay_save_fixture() -> AsyncGenerator[None, None]:
|
||||||
|
"""Load the homeassistant integration."""
|
||||||
|
with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
async def internal_url_mock(hass: HomeAssistant) -> None:
|
async def internal_url_mock(hass: HomeAssistant) -> None:
|
||||||
"""Mock internal URL of the instance."""
|
"""Mock internal URL of the instance."""
|
||||||
|
@ -70,6 +83,10 @@ def test_schema() -> None:
|
||||||
"gender": "female",
|
"gender": "female",
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
"tts.home_assistant_cloud",
|
||||||
|
None,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_prefs_default_voice(
|
async def test_prefs_default_voice(
|
||||||
|
@ -104,9 +121,17 @@ async def test_prefs_default_voice(
|
||||||
assert engine.default_options == {"gender": "male", "audio_output": "mp3"}
|
assert engine.default_options == {"gender": "male", "audio_output": "mp3"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"engine_id",
|
||||||
|
[
|
||||||
|
DOMAIN,
|
||||||
|
"tts.home_assistant_cloud",
|
||||||
|
],
|
||||||
|
)
|
||||||
async def test_provider_properties(
|
async def test_provider_properties(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
cloud: MagicMock,
|
cloud: MagicMock,
|
||||||
|
engine_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test cloud provider."""
|
"""Test cloud provider."""
|
||||||
assert await async_setup_component(hass, "homeassistant", {})
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
@ -115,7 +140,7 @@ async def test_provider_properties(
|
||||||
on_start_callback = cloud.register_on_start.call_args[0][0]
|
on_start_callback = cloud.register_on_start.call_args[0][0]
|
||||||
await on_start_callback()
|
await on_start_callback()
|
||||||
|
|
||||||
engine = get_engine_instance(hass, DOMAIN)
|
engine = get_engine_instance(hass, engine_id)
|
||||||
|
|
||||||
assert engine is not None
|
assert engine is not None
|
||||||
assert engine.supported_options == ["gender", "voice", "audio_output"]
|
assert engine.supported_options == ["gender", "voice", "audio_output"]
|
||||||
|
@ -132,6 +157,7 @@ async def test_provider_properties(
|
||||||
[
|
[
|
||||||
({"platform": DOMAIN}, DOMAIN),
|
({"platform": DOMAIN}, DOMAIN),
|
||||||
({"engine_id": DOMAIN}, DOMAIN),
|
({"engine_id": DOMAIN}, DOMAIN),
|
||||||
|
({"engine_id": "tts.home_assistant_cloud"}, "tts.home_assistant_cloud"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -241,3 +267,144 @@ async def test_get_tts_audio_logged_out(
|
||||||
assert mock_process_tts.call_args.kwargs["language"] == "en-US"
|
assert mock_process_tts.call_args.kwargs["language"] == "en-US"
|
||||||
assert mock_process_tts.call_args.kwargs["gender"] == "female"
|
assert mock_process_tts.call_args.kwargs["gender"] == "female"
|
||||||
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("mock_process_tts_return_value", "mock_process_tts_side_effect"),
|
||||||
|
[
|
||||||
|
(b"", None),
|
||||||
|
(None, VoiceError("Boom!")),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_tts_entity(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
entity_registry: EntityRegistry,
|
||||||
|
cloud: MagicMock,
|
||||||
|
mock_process_tts_return_value: bytes | None,
|
||||||
|
mock_process_tts_side_effect: Exception | None,
|
||||||
|
) -> None:
|
||||||
|
"""Test text-to-speech entity."""
|
||||||
|
mock_process_tts = AsyncMock(
|
||||||
|
return_value=mock_process_tts_return_value,
|
||||||
|
side_effect=mock_process_tts_side_effect,
|
||||||
|
)
|
||||||
|
cloud.voice.process_tts = mock_process_tts
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
on_start_callback = cloud.register_on_start.call_args[0][0]
|
||||||
|
await on_start_callback()
|
||||||
|
client = await hass_client()
|
||||||
|
entity_id = "tts.home_assistant_cloud"
|
||||||
|
|
||||||
|
state = hass.states.get(entity_id)
|
||||||
|
assert state
|
||||||
|
assert state.state == STATE_UNKNOWN
|
||||||
|
|
||||||
|
url = "/api/tts_get_url"
|
||||||
|
data = {
|
||||||
|
"engine_id": entity_id,
|
||||||
|
"message": "There is someone at the door.",
|
||||||
|
}
|
||||||
|
|
||||||
|
req = await client.post(url, json=data)
|
||||||
|
assert req.status == HTTPStatus.OK
|
||||||
|
response = await req.json()
|
||||||
|
|
||||||
|
assert response == {
|
||||||
|
"url": (
|
||||||
|
"http://example.local:8123/api/tts_proxy/"
|
||||||
|
"42f18378fd4393d18c8dd11d03fa9563c1e54491"
|
||||||
|
f"_en-us_e09b5a0968_{entity_id}.mp3"
|
||||||
|
),
|
||||||
|
"path": (
|
||||||
|
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
|
||||||
|
f"_en-us_e09b5a0968_{entity_id}.mp3"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert mock_process_tts.call_count == 1
|
||||||
|
assert mock_process_tts.call_args is not None
|
||||||
|
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
|
||||||
|
assert mock_process_tts.call_args.kwargs["language"] == "en-US"
|
||||||
|
assert mock_process_tts.call_args.kwargs["gender"] == "female"
|
||||||
|
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
|
||||||
|
|
||||||
|
state = hass.states.get(entity_id)
|
||||||
|
assert state
|
||||||
|
assert state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
|
||||||
|
|
||||||
|
# Test removing the entity
|
||||||
|
entity_registry.async_remove(entity_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
state = hass.states.get(entity_id)
|
||||||
|
assert state is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_migrating_pipelines(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
cloud: MagicMock,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
hass_storage: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""Test migrating pipelines when cloud tts entity is added."""
|
||||||
|
entity_id = "tts.home_assistant_cloud"
|
||||||
|
mock_process_tts = AsyncMock(
|
||||||
|
return_value=b"",
|
||||||
|
)
|
||||||
|
cloud.voice.process_tts = mock_process_tts
|
||||||
|
hass_storage[STORAGE_KEY] = {
|
||||||
|
"version": 1,
|
||||||
|
"minor_version": 1,
|
||||||
|
"key": "assist_pipeline.pipelines",
|
||||||
|
"data": deepcopy(PIPELINE_DATA),
|
||||||
|
}
|
||||||
|
|
||||||
|
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||||
|
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
await cloud.login("test-user", "test-pass")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
state = hass.states.get(entity_id)
|
||||||
|
assert state
|
||||||
|
assert state.state == STATE_UNKNOWN
|
||||||
|
|
||||||
|
# The stt/tts engines should have been updated to the new cloud engine ids.
|
||||||
|
assert (
|
||||||
|
hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_engine"]
|
||||||
|
== "stt.home_assistant_cloud"
|
||||||
|
)
|
||||||
|
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_engine"] == entity_id
|
||||||
|
|
||||||
|
# The other items should stay the same.
|
||||||
|
assert (
|
||||||
|
hass_storage[STORAGE_KEY]["data"]["items"][0]["conversation_engine"]
|
||||||
|
== "conversation_engine_1"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
hass_storage[STORAGE_KEY]["data"]["items"][0]["conversation_language"]
|
||||||
|
== "language_1"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
hass_storage[STORAGE_KEY]["data"]["items"][0]["id"]
|
||||||
|
== "01GX8ZWBAQYWNB1XV3EXEZ75DY"
|
||||||
|
)
|
||||||
|
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["language"] == "language_1"
|
||||||
|
assert (
|
||||||
|
hass_storage[STORAGE_KEY]["data"]["items"][0]["name"] == "Home Assistant Cloud"
|
||||||
|
)
|
||||||
|
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_language"] == "language_1"
|
||||||
|
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_language"] == "language_1"
|
||||||
|
assert (
|
||||||
|
hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_voice"]
|
||||||
|
== "Arnold Schwarzenegger"
|
||||||
|
)
|
||||||
|
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_entity"] is None
|
||||||
|
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_id"] is None
|
||||||
|
assert hass_storage[STORAGE_KEY]["data"]["items"][1] == PIPELINE_DATA["items"][1]
|
||||||
|
assert hass_storage[STORAGE_KEY]["data"]["items"][2] == PIPELINE_DATA["items"][2]
|
||||||
|
|
Loading…
Add table
Reference in a new issue