Cloud: Add web socket API to pick default TTS language (#45064)

* Allow picking default TTS language

* Fix test

* Fix coroutine function

* Improve test coverage

* Remove stale import

* Clean up hass

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Paulus Schoutsen 2021-01-13 00:05:30 +01:00 committed by GitHub
parent 4e71be852a
commit 82746616fa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 144 additions and 11 deletions

View file

@ -20,6 +20,8 @@ PREF_GOOGLE_LOCAL_WEBHOOK_ID = "google_local_webhook_id"
PREF_USERNAME = "username"
PREF_ALEXA_DEFAULT_EXPOSE = "alexa_default_expose"
PREF_GOOGLE_DEFAULT_EXPOSE = "google_default_expose"
PREF_TTS_DEFAULT_VOICE = "tts_default_voice"
DEFAULT_TTS_DEFAULT_VOICE = ("en-US", "female")
DEFAULT_DISABLE_2FA = False
DEFAULT_ALEXA_REPORT_STATE = False
DEFAULT_GOOGLE_REPORT_STATE = False

View file

@ -8,6 +8,7 @@ import async_timeout
import attr
from hass_nabucasa import Cloud, auth, thingtalk
from hass_nabucasa.const import STATE_DISCONNECTED
from hass_nabucasa.voice import MAP_VOICE
import voluptuous as vol
from homeassistant.components import websocket_api
@ -37,6 +38,7 @@ from .const import (
PREF_GOOGLE_DEFAULT_EXPOSE,
PREF_GOOGLE_REPORT_STATE,
PREF_GOOGLE_SECURE_DEVICES_PIN,
PREF_TTS_DEFAULT_VOICE,
REQUEST_TIMEOUT,
InvalidTrustedNetworks,
InvalidTrustedProxies,
@ -115,6 +117,7 @@ async def async_setup(hass):
async_register_command(alexa_sync)
async_register_command(thingtalk_convert)
async_register_command(tts_info)
hass.http.register_view(GoogleActionsSyncView)
hass.http.register_view(CloudLoginView)
@ -385,6 +388,9 @@ async def websocket_subscription(hass, connection, msg):
vol.Optional(PREF_ALEXA_DEFAULT_EXPOSE): [str],
vol.Optional(PREF_GOOGLE_DEFAULT_EXPOSE): [str],
vol.Optional(PREF_GOOGLE_SECURE_DEVICES_PIN): vol.Any(None, str),
vol.Optional(PREF_TTS_DEFAULT_VOICE): vol.All(
vol.Coerce(tuple), vol.In(MAP_VOICE)
),
}
)
async def websocket_update_prefs(hass, connection, msg):
@ -637,3 +643,11 @@ async def thingtalk_convert(hass, connection, msg):
)
except thingtalk.ThingTalkConversionError as err:
connection.send_error(msg["id"], ws_const.ERR_UNKNOWN_ERROR, str(err))
@websocket_api.websocket_command({"type": "cloud/tts/info"})
def tts_info(hass, connection, msg):
"""Fetch available tts info."""
connection.send_result(
msg["id"], {"languages": [(lang, gender.value) for lang, gender in MAP_VOICE]}
)

View file

@ -12,6 +12,7 @@ from .const import (
DEFAULT_ALEXA_REPORT_STATE,
DEFAULT_EXPOSED_DOMAINS,
DEFAULT_GOOGLE_REPORT_STATE,
DEFAULT_TTS_DEFAULT_VOICE,
DOMAIN,
PREF_ALEXA_DEFAULT_EXPOSE,
PREF_ALEXA_ENTITY_CONFIGS,
@ -30,6 +31,7 @@ from .const import (
PREF_GOOGLE_SECURE_DEVICES_PIN,
PREF_OVERRIDE_NAME,
PREF_SHOULD_EXPOSE,
PREF_TTS_DEFAULT_VOICE,
PREF_USERNAME,
InvalidTrustedNetworks,
InvalidTrustedProxies,
@ -86,6 +88,7 @@ class CloudPreferences:
google_report_state=UNDEFINED,
alexa_default_expose=UNDEFINED,
google_default_expose=UNDEFINED,
tts_default_voice=UNDEFINED,
):
"""Update user preferences."""
prefs = {**self._prefs}
@ -103,6 +106,7 @@ class CloudPreferences:
(PREF_GOOGLE_REPORT_STATE, google_report_state),
(PREF_ALEXA_DEFAULT_EXPOSE, alexa_default_expose),
(PREF_GOOGLE_DEFAULT_EXPOSE, google_default_expose),
(PREF_TTS_DEFAULT_VOICE, tts_default_voice),
):
if value is not UNDEFINED:
prefs[key] = value
@ -203,6 +207,7 @@ class CloudPreferences:
PREF_GOOGLE_ENTITY_CONFIGS: self.google_entity_configs,
PREF_GOOGLE_REPORT_STATE: self.google_report_state,
PREF_GOOGLE_SECURE_DEVICES_PIN: self.google_secure_devices_pin,
PREF_TTS_DEFAULT_VOICE: self.tts_default_voice,
}
@property
@ -279,6 +284,11 @@ class CloudPreferences:
"""Return the published cloud webhooks."""
return self._prefs.get(PREF_CLOUDHOOKS, {})
@property
def tts_default_voice(self):
"""Return the default TTS voice."""
return self._prefs.get(PREF_TTS_DEFAULT_VOICE, DEFAULT_TTS_DEFAULT_VOICE)
async def get_cloud_user(self) -> str:
"""Return ID from Home Assistant Cloud system user."""
user = await self._load_cloud_user()

View file

@ -12,13 +12,14 @@ CONF_GENDER = "gender"
SUPPORT_LANGUAGES = list({key[0] for key in MAP_VOICE})
DEFAULT_LANG = "en-US"
DEFAULT_GENDER = "female"
def validate_lang(value):
"""Validate chosen gender or language."""
lang = value[CONF_LANG]
lang = value.get(CONF_LANG)
if lang is None:
return value
gender = value.get(CONF_GENDER)
if gender is None:
@ -35,7 +36,7 @@ def validate_lang(value):
PLATFORM_SCHEMA = vol.All(
PLATFORM_SCHEMA.extend(
{
vol.Optional(CONF_LANG, default=DEFAULT_LANG): str,
vol.Optional(CONF_LANG): str,
vol.Optional(CONF_GENDER): str,
}
),
@ -48,8 +49,8 @@ async def async_get_engine(hass, config, discovery_info=None):
cloud: Cloud = hass.data[DOMAIN]
if discovery_info is not None:
language = DEFAULT_LANG
gender = DEFAULT_GENDER
language = None
gender = None
else:
language = config[CONF_LANG]
gender = config[CONF_GENDER]
@ -67,6 +68,16 @@ class CloudProvider(Provider):
self._language = language
self._gender = gender
if self._language is not None:
return
self._language, self._gender = cloud.client.prefs.tts_default_voice
cloud.client.prefs.async_listen_updates(self._sync_prefs)
async def _sync_prefs(self, prefs):
"""Sync preferences."""
self._language, self._gender = prefs.tts_default_voice
@property
def default_language(self):
"""Return the default language."""

View file

@ -4,7 +4,7 @@ from ipaddress import ip_network
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import aiohttp
from hass_nabucasa import thingtalk
from hass_nabucasa import thingtalk, voice
from hass_nabucasa.auth import Unauthenticated, UnknownError
from hass_nabucasa.const import STATE_CONNECTED
from jose import jwt
@ -361,6 +361,7 @@ async def test_websocket_status(
"alexa_report_state": False,
"google_report_state": False,
"remote_enabled": False,
"tts_default_voice": ["en-US", "female"],
},
"alexa_entities": {
"include_domains": [],
@ -491,6 +492,7 @@ async def test_websocket_update_preferences(
"google_secure_devices_pin": "1234",
"google_default_expose": ["light", "switch"],
"alexa_default_expose": ["sensor", "media_player"],
"tts_default_voice": ["en-GB", "male"],
}
)
response = await client.receive_json()
@ -501,6 +503,7 @@ async def test_websocket_update_preferences(
assert setup_api.google_secure_devices_pin == "1234"
assert setup_api.google_default_expose == ["light", "switch"]
assert setup_api.alexa_default_expose == ["sensor", "media_player"]
assert setup_api.tts_default_voice == ("en-GB", "male")
async def test_websocket_update_preferences_require_relink(
@ -975,3 +978,25 @@ async def test_thingtalk_convert_internal(hass, hass_ws_client, setup_api):
assert not response["success"]
assert response["error"]["code"] == "unknown_error"
assert response["error"]["message"] == "Did not understand"
async def test_tts_info(hass, hass_ws_client, setup_api):
"""Test that we can get TTS info."""
# Verify the format is as expected
assert voice.MAP_VOICE[("en-US", voice.Gender.FEMALE)] == "JennyNeural"
client = await hass_ws_client(hass)
with patch.dict(
"homeassistant.components.cloud.http_api.MAP_VOICE",
{
("en-US", voice.Gender.MALE): "GuyNeural",
("en-US", voice.Gender.FEMALE): "JennyNeural",
},
clear=True,
):
await client.send_json({"id": 5, "type": "cloud/tts/info"})
response = await client.receive_json()
assert response["success"]
assert response["result"] == {"languages": [["en-US", "male"], ["en-US", "female"]]}

View file

@ -1,5 +1,22 @@
"""Tests for cloud tts."""
from homeassistant.components.cloud import tts
from unittest.mock import Mock
from hass_nabucasa import voice
import pytest
import voluptuous as vol
from homeassistant.components.cloud import const, tts
@pytest.fixture()
def cloud_with_prefs(cloud_prefs):
"""Return a cloud mock with prefs."""
return Mock(client=Mock(prefs=cloud_prefs))
def test_default_exists():
"""Test our default language exists."""
assert const.DEFAULT_TTS_DEFAULT_VOICE in voice.MAP_VOICE
def test_schema():
@ -9,7 +26,61 @@ def test_schema():
processed = tts.PLATFORM_SCHEMA({"platform": "cloud", "language": "nl-NL"})
assert processed["gender"] == "female"
with pytest.raises(vol.Invalid):
tts.PLATFORM_SCHEMA(
{"platform": "cloud", "language": "non-existing", "gender": "female"}
)
with pytest.raises(vol.Invalid):
tts.PLATFORM_SCHEMA(
{"platform": "cloud", "language": "nl-NL", "gender": "not-supported"}
)
# Should not raise
processed = tts.PLATFORM_SCHEMA(
{"platform": "cloud", "language": "nl-NL", "gender": "female"}
tts.PLATFORM_SCHEMA({"platform": "cloud", "language": "nl-NL", "gender": "female"})
tts.PLATFORM_SCHEMA({"platform": "cloud"})
async def test_prefs_default_voice(hass, cloud_with_prefs, cloud_prefs):
"""Test cloud provider uses the preferences."""
assert cloud_prefs.tts_default_voice == ("en-US", "female")
provider_pref = await tts.async_get_engine(
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
)
provider_conf = await tts.async_get_engine(
Mock(data={const.DOMAIN: cloud_with_prefs}),
{"language": "fr-FR", "gender": "female"},
None,
)
assert provider_pref.default_language == "en-US"
assert provider_pref.default_options == {"gender": "female"}
assert provider_conf.default_language == "fr-FR"
assert provider_conf.default_options == {"gender": "female"}
await cloud_prefs.async_update(tts_default_voice=("nl-NL", "male"))
await hass.async_block_till_done()
assert provider_pref.default_language == "nl-NL"
assert provider_pref.default_options == {"gender": "male"}
assert provider_conf.default_language == "fr-FR"
assert provider_conf.default_options == {"gender": "female"}
async def test_provider_properties(cloud_with_prefs):
"""Test cloud provider."""
provider = await tts.async_get_engine(
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
)
assert provider.supported_options == ["gender"]
assert "nl-NL" in provider.supported_languages
async def test_get_tts_audio(cloud_with_prefs):
"""Test cloud provider."""
provider = await tts.async_get_engine(
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {}
)
assert provider.supported_options == ["gender"]
assert "nl-NL" in provider.supported_languages