Cloud: Add web socket API to pick default TTS language (#45064)
* Allow picking default TTS language * Fix test * Fix coroutine function * Improve test coverage * Remove stale import * Clean up hass Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
4e71be852a
commit
82746616fa
6 changed files with 144 additions and 11 deletions
|
@ -20,6 +20,8 @@ PREF_GOOGLE_LOCAL_WEBHOOK_ID = "google_local_webhook_id"
|
|||
PREF_USERNAME = "username"
|
||||
PREF_ALEXA_DEFAULT_EXPOSE = "alexa_default_expose"
|
||||
PREF_GOOGLE_DEFAULT_EXPOSE = "google_default_expose"
|
||||
PREF_TTS_DEFAULT_VOICE = "tts_default_voice"
|
||||
DEFAULT_TTS_DEFAULT_VOICE = ("en-US", "female")
|
||||
DEFAULT_DISABLE_2FA = False
|
||||
DEFAULT_ALEXA_REPORT_STATE = False
|
||||
DEFAULT_GOOGLE_REPORT_STATE = False
|
||||
|
|
|
@ -8,6 +8,7 @@ import async_timeout
|
|||
import attr
|
||||
from hass_nabucasa import Cloud, auth, thingtalk
|
||||
from hass_nabucasa.const import STATE_DISCONNECTED
|
||||
from hass_nabucasa.voice import MAP_VOICE
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
|
@ -37,6 +38,7 @@ from .const import (
|
|||
PREF_GOOGLE_DEFAULT_EXPOSE,
|
||||
PREF_GOOGLE_REPORT_STATE,
|
||||
PREF_GOOGLE_SECURE_DEVICES_PIN,
|
||||
PREF_TTS_DEFAULT_VOICE,
|
||||
REQUEST_TIMEOUT,
|
||||
InvalidTrustedNetworks,
|
||||
InvalidTrustedProxies,
|
||||
|
@ -115,6 +117,7 @@ async def async_setup(hass):
|
|||
async_register_command(alexa_sync)
|
||||
|
||||
async_register_command(thingtalk_convert)
|
||||
async_register_command(tts_info)
|
||||
|
||||
hass.http.register_view(GoogleActionsSyncView)
|
||||
hass.http.register_view(CloudLoginView)
|
||||
|
@ -385,6 +388,9 @@ async def websocket_subscription(hass, connection, msg):
|
|||
vol.Optional(PREF_ALEXA_DEFAULT_EXPOSE): [str],
|
||||
vol.Optional(PREF_GOOGLE_DEFAULT_EXPOSE): [str],
|
||||
vol.Optional(PREF_GOOGLE_SECURE_DEVICES_PIN): vol.Any(None, str),
|
||||
vol.Optional(PREF_TTS_DEFAULT_VOICE): vol.All(
|
||||
vol.Coerce(tuple), vol.In(MAP_VOICE)
|
||||
),
|
||||
}
|
||||
)
|
||||
async def websocket_update_prefs(hass, connection, msg):
|
||||
|
@ -637,3 +643,11 @@ async def thingtalk_convert(hass, connection, msg):
|
|||
)
|
||||
except thingtalk.ThingTalkConversionError as err:
|
||||
connection.send_error(msg["id"], ws_const.ERR_UNKNOWN_ERROR, str(err))
|
||||
|
||||
|
||||
@websocket_api.websocket_command({"type": "cloud/tts/info"})
|
||||
def tts_info(hass, connection, msg):
|
||||
"""Fetch available tts info."""
|
||||
connection.send_result(
|
||||
msg["id"], {"languages": [(lang, gender.value) for lang, gender in MAP_VOICE]}
|
||||
)
|
||||
|
|
|
@ -12,6 +12,7 @@ from .const import (
|
|||
DEFAULT_ALEXA_REPORT_STATE,
|
||||
DEFAULT_EXPOSED_DOMAINS,
|
||||
DEFAULT_GOOGLE_REPORT_STATE,
|
||||
DEFAULT_TTS_DEFAULT_VOICE,
|
||||
DOMAIN,
|
||||
PREF_ALEXA_DEFAULT_EXPOSE,
|
||||
PREF_ALEXA_ENTITY_CONFIGS,
|
||||
|
@ -30,6 +31,7 @@ from .const import (
|
|||
PREF_GOOGLE_SECURE_DEVICES_PIN,
|
||||
PREF_OVERRIDE_NAME,
|
||||
PREF_SHOULD_EXPOSE,
|
||||
PREF_TTS_DEFAULT_VOICE,
|
||||
PREF_USERNAME,
|
||||
InvalidTrustedNetworks,
|
||||
InvalidTrustedProxies,
|
||||
|
@ -86,6 +88,7 @@ class CloudPreferences:
|
|||
google_report_state=UNDEFINED,
|
||||
alexa_default_expose=UNDEFINED,
|
||||
google_default_expose=UNDEFINED,
|
||||
tts_default_voice=UNDEFINED,
|
||||
):
|
||||
"""Update user preferences."""
|
||||
prefs = {**self._prefs}
|
||||
|
@ -103,6 +106,7 @@ class CloudPreferences:
|
|||
(PREF_GOOGLE_REPORT_STATE, google_report_state),
|
||||
(PREF_ALEXA_DEFAULT_EXPOSE, alexa_default_expose),
|
||||
(PREF_GOOGLE_DEFAULT_EXPOSE, google_default_expose),
|
||||
(PREF_TTS_DEFAULT_VOICE, tts_default_voice),
|
||||
):
|
||||
if value is not UNDEFINED:
|
||||
prefs[key] = value
|
||||
|
@ -203,6 +207,7 @@ class CloudPreferences:
|
|||
PREF_GOOGLE_ENTITY_CONFIGS: self.google_entity_configs,
|
||||
PREF_GOOGLE_REPORT_STATE: self.google_report_state,
|
||||
PREF_GOOGLE_SECURE_DEVICES_PIN: self.google_secure_devices_pin,
|
||||
PREF_TTS_DEFAULT_VOICE: self.tts_default_voice,
|
||||
}
|
||||
|
||||
@property
|
||||
|
@ -279,6 +284,11 @@ class CloudPreferences:
|
|||
"""Return the published cloud webhooks."""
|
||||
return self._prefs.get(PREF_CLOUDHOOKS, {})
|
||||
|
||||
@property
|
||||
def tts_default_voice(self):
|
||||
"""Return the default TTS voice."""
|
||||
return self._prefs.get(PREF_TTS_DEFAULT_VOICE, DEFAULT_TTS_DEFAULT_VOICE)
|
||||
|
||||
async def get_cloud_user(self) -> str:
|
||||
"""Return ID from Home Assistant Cloud system user."""
|
||||
user = await self._load_cloud_user()
|
||||
|
|
|
@ -12,13 +12,14 @@ CONF_GENDER = "gender"
|
|||
|
||||
SUPPORT_LANGUAGES = list({key[0] for key in MAP_VOICE})
|
||||
|
||||
DEFAULT_LANG = "en-US"
|
||||
DEFAULT_GENDER = "female"
|
||||
|
||||
|
||||
def validate_lang(value):
|
||||
"""Validate chosen gender or language."""
|
||||
lang = value[CONF_LANG]
|
||||
lang = value.get(CONF_LANG)
|
||||
|
||||
if lang is None:
|
||||
return value
|
||||
|
||||
gender = value.get(CONF_GENDER)
|
||||
|
||||
if gender is None:
|
||||
|
@ -35,7 +36,7 @@ def validate_lang(value):
|
|||
PLATFORM_SCHEMA = vol.All(
|
||||
PLATFORM_SCHEMA.extend(
|
||||
{
|
||||
vol.Optional(CONF_LANG, default=DEFAULT_LANG): str,
|
||||
vol.Optional(CONF_LANG): str,
|
||||
vol.Optional(CONF_GENDER): str,
|
||||
}
|
||||
),
|
||||
|
@ -48,8 +49,8 @@ async def async_get_engine(hass, config, discovery_info=None):
|
|||
cloud: Cloud = hass.data[DOMAIN]
|
||||
|
||||
if discovery_info is not None:
|
||||
language = DEFAULT_LANG
|
||||
gender = DEFAULT_GENDER
|
||||
language = None
|
||||
gender = None
|
||||
else:
|
||||
language = config[CONF_LANG]
|
||||
gender = config[CONF_GENDER]
|
||||
|
@ -67,6 +68,16 @@ class CloudProvider(Provider):
|
|||
self._language = language
|
||||
self._gender = gender
|
||||
|
||||
if self._language is not None:
|
||||
return
|
||||
|
||||
self._language, self._gender = cloud.client.prefs.tts_default_voice
|
||||
cloud.client.prefs.async_listen_updates(self._sync_prefs)
|
||||
|
||||
async def _sync_prefs(self, prefs):
|
||||
"""Sync preferences."""
|
||||
self._language, self._gender = prefs.tts_default_voice
|
||||
|
||||
@property
|
||||
def default_language(self):
|
||||
"""Return the default language."""
|
||||
|
|
|
@ -4,7 +4,7 @@ from ipaddress import ip_network
|
|||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import aiohttp
|
||||
from hass_nabucasa import thingtalk
|
||||
from hass_nabucasa import thingtalk, voice
|
||||
from hass_nabucasa.auth import Unauthenticated, UnknownError
|
||||
from hass_nabucasa.const import STATE_CONNECTED
|
||||
from jose import jwt
|
||||
|
@ -361,6 +361,7 @@ async def test_websocket_status(
|
|||
"alexa_report_state": False,
|
||||
"google_report_state": False,
|
||||
"remote_enabled": False,
|
||||
"tts_default_voice": ["en-US", "female"],
|
||||
},
|
||||
"alexa_entities": {
|
||||
"include_domains": [],
|
||||
|
@ -491,6 +492,7 @@ async def test_websocket_update_preferences(
|
|||
"google_secure_devices_pin": "1234",
|
||||
"google_default_expose": ["light", "switch"],
|
||||
"alexa_default_expose": ["sensor", "media_player"],
|
||||
"tts_default_voice": ["en-GB", "male"],
|
||||
}
|
||||
)
|
||||
response = await client.receive_json()
|
||||
|
@ -501,6 +503,7 @@ async def test_websocket_update_preferences(
|
|||
assert setup_api.google_secure_devices_pin == "1234"
|
||||
assert setup_api.google_default_expose == ["light", "switch"]
|
||||
assert setup_api.alexa_default_expose == ["sensor", "media_player"]
|
||||
assert setup_api.tts_default_voice == ("en-GB", "male")
|
||||
|
||||
|
||||
async def test_websocket_update_preferences_require_relink(
|
||||
|
@ -975,3 +978,25 @@ async def test_thingtalk_convert_internal(hass, hass_ws_client, setup_api):
|
|||
assert not response["success"]
|
||||
assert response["error"]["code"] == "unknown_error"
|
||||
assert response["error"]["message"] == "Did not understand"
|
||||
|
||||
|
||||
async def test_tts_info(hass, hass_ws_client, setup_api):
|
||||
"""Test that we can get TTS info."""
|
||||
# Verify the format is as expected
|
||||
assert voice.MAP_VOICE[("en-US", voice.Gender.FEMALE)] == "JennyNeural"
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with patch.dict(
|
||||
"homeassistant.components.cloud.http_api.MAP_VOICE",
|
||||
{
|
||||
("en-US", voice.Gender.MALE): "GuyNeural",
|
||||
("en-US", voice.Gender.FEMALE): "JennyNeural",
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
await client.send_json({"id": 5, "type": "cloud/tts/info"})
|
||||
response = await client.receive_json()
|
||||
|
||||
assert response["success"]
|
||||
assert response["result"] == {"languages": [["en-US", "male"], ["en-US", "female"]]}
|
||||
|
|
|
@ -1,5 +1,22 @@
|
|||
"""Tests for cloud tts."""
|
||||
from homeassistant.components.cloud import tts
|
||||
from unittest.mock import Mock
|
||||
|
||||
from hass_nabucasa import voice
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.cloud import const, tts
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def cloud_with_prefs(cloud_prefs):
|
||||
"""Return a cloud mock with prefs."""
|
||||
return Mock(client=Mock(prefs=cloud_prefs))
|
||||
|
||||
|
||||
def test_default_exists():
|
||||
"""Test our default language exists."""
|
||||
assert const.DEFAULT_TTS_DEFAULT_VOICE in voice.MAP_VOICE
|
||||
|
||||
|
||||
def test_schema():
|
||||
|
@ -9,7 +26,61 @@ def test_schema():
|
|||
processed = tts.PLATFORM_SCHEMA({"platform": "cloud", "language": "nl-NL"})
|
||||
assert processed["gender"] == "female"
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
tts.PLATFORM_SCHEMA(
|
||||
{"platform": "cloud", "language": "non-existing", "gender": "female"}
|
||||
)
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
tts.PLATFORM_SCHEMA(
|
||||
{"platform": "cloud", "language": "nl-NL", "gender": "not-supported"}
|
||||
)
|
||||
|
||||
# Should not raise
|
||||
processed = tts.PLATFORM_SCHEMA(
|
||||
{"platform": "cloud", "language": "nl-NL", "gender": "female"}
|
||||
tts.PLATFORM_SCHEMA({"platform": "cloud", "language": "nl-NL", "gender": "female"})
|
||||
tts.PLATFORM_SCHEMA({"platform": "cloud"})
|
||||
|
||||
|
||||
async def test_prefs_default_voice(hass, cloud_with_prefs, cloud_prefs):
|
||||
"""Test cloud provider uses the preferences."""
|
||||
assert cloud_prefs.tts_default_voice == ("en-US", "female")
|
||||
|
||||
provider_pref = await tts.async_get_engine(
|
||||
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
|
||||
)
|
||||
provider_conf = await tts.async_get_engine(
|
||||
Mock(data={const.DOMAIN: cloud_with_prefs}),
|
||||
{"language": "fr-FR", "gender": "female"},
|
||||
None,
|
||||
)
|
||||
|
||||
assert provider_pref.default_language == "en-US"
|
||||
assert provider_pref.default_options == {"gender": "female"}
|
||||
assert provider_conf.default_language == "fr-FR"
|
||||
assert provider_conf.default_options == {"gender": "female"}
|
||||
|
||||
await cloud_prefs.async_update(tts_default_voice=("nl-NL", "male"))
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert provider_pref.default_language == "nl-NL"
|
||||
assert provider_pref.default_options == {"gender": "male"}
|
||||
assert provider_conf.default_language == "fr-FR"
|
||||
assert provider_conf.default_options == {"gender": "female"}
|
||||
|
||||
|
||||
async def test_provider_properties(cloud_with_prefs):
|
||||
"""Test cloud provider."""
|
||||
provider = await tts.async_get_engine(
|
||||
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
|
||||
)
|
||||
assert provider.supported_options == ["gender"]
|
||||
assert "nl-NL" in provider.supported_languages
|
||||
|
||||
|
||||
async def test_get_tts_audio(cloud_with_prefs):
|
||||
"""Test cloud provider."""
|
||||
provider = await tts.async_get_engine(
|
||||
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
|
||||
)
|
||||
assert provider.supported_options == ["gender"]
|
||||
assert "nl-NL" in provider.supported_languages
|
||||
|
|
Loading…
Add table
Reference in a new issue