From e1576d59981ceb7e92db315e678f762fe24a95f8 Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Tue, 30 Jan 2024 21:58:16 +0100 Subject: [PATCH] Handle deprecated cloud tts voice (#109124) * Handle deprecated cloud tts voice * Add test * Fix test logic * Add breaks in ha version * Adjust translation string --- homeassistant/components/cloud/strings.json | 11 ++ homeassistant/components/cloud/tts.py | 41 +++++++- tests/components/cloud/test_tts.py | 111 ++++++++++++++++++++ 3 files changed, 161 insertions(+), 2 deletions(-) diff --git a/homeassistant/components/cloud/strings.json b/homeassistant/components/cloud/strings.json index 56fb3c0f5c9..6f1e3c80bf7 100644 --- a/homeassistant/components/cloud/strings.json +++ b/homeassistant/components/cloud/strings.json @@ -24,6 +24,17 @@ } }, "issues": { + "deprecated_voice": { + "title": "A deprecated voice was used", + "fix_flow": { + "step": { + "confirm": { + "title": "[%key:component::cloud::issues::deprecated_voice::title%]", + "description": "The '{deprecated_voice}' voice is deprecated and will be removed.\nPlease update your automations and scripts to replace the '{deprecated_voice}' with another voice like eg. '{replacement_voice}'." + } + } + } + }, "legacy_subscription": { "title": "Legacy subscription detected", "fix_flow": { diff --git a/homeassistant/components/cloud/tts.py b/homeassistant/components/cloud/tts.py index 2626c01e66f..ba34ac7a9b0 100644 --- a/homeassistant/components/cloud/tts.py +++ b/homeassistant/components/cloud/tts.py @@ -23,6 +23,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import Platform from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from .assist_pipeline import async_migrate_cloud_pipeline_engine @@ -32,6 +33,7 @@ from .prefs import CloudPreferences ATTR_GENDER = "gender" +DEPRECATED_VOICES = {"XiaoxuanNeural": "XiaozhenNeural"} SUPPORT_LANGUAGES = list(TTS_VOICES) _LOGGER = logging.getLogger(__name__) @@ -158,13 +160,15 @@ class CloudTTSEntity(TextToSpeechEntity): self, message: str, language: str, options: dict[str, Any] ) -> TtsAudioType: """Load TTS from Home Assistant Cloud.""" + original_voice: str | None = options.get(ATTR_VOICE) + voice = handle_deprecated_voice(self.hass, original_voice) # Process TTS try: data = await self.cloud.voice.process_tts( text=message, language=language, gender=options.get(ATTR_GENDER), - voice=options.get(ATTR_VOICE), + voice=voice, output=options[ATTR_AUDIO_OUTPUT], ) except VoiceError as err: @@ -230,13 +234,16 @@ class CloudProvider(Provider): self, message: str, language: str, options: dict[str, Any] ) -> TtsAudioType: """Load TTS from Home Assistant Cloud.""" + original_voice: str | None = options.get(ATTR_VOICE) + assert self.hass is not None + voice = handle_deprecated_voice(self.hass, original_voice) # Process TTS try: data = await self.cloud.voice.process_tts( text=message, language=language, gender=options.get(ATTR_GENDER), - voice=options.get(ATTR_VOICE), + voice=voice, output=options[ATTR_AUDIO_OUTPUT], ) except VoiceError as err: @@ -244,3 +251,33 @@ class CloudProvider(Provider): return (None, None) return (str(options[ATTR_AUDIO_OUTPUT].value), data) + + +@callback +def handle_deprecated_voice( + hass: HomeAssistant, + original_voice: str | None, +) -> str | None: + """Handle deprecated voice.""" + voice = original_voice + if ( + original_voice + and voice + and (voice := DEPRECATED_VOICES.get(original_voice, original_voice)) + != original_voice + ): + async_create_issue( + hass, + DOMAIN, + f"deprecated_voice_{original_voice}", + is_fixable=True, + is_persistent=True, + severity=IssueSeverity.WARNING, + breaks_in_ha_version="2024.8.0", + translation_key="deprecated_voice", + translation_placeholders={ + "deprecated_voice": original_voice, + "replacement_voice": voice, + }, + ) + return voice diff --git a/tests/components/cloud/test_tts.py b/tests/components/cloud/test_tts.py index b75d2361070..92a9cb10992 100644 --- a/tests/components/cloud/test_tts.py +++ b/tests/components/cloud/test_tts.py @@ -17,6 +17,7 @@ from homeassistant.config import async_process_ha_core_config from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_registry import EntityRegistry +from homeassistant.helpers.issue_registry import IssueRegistry, IssueSeverity from homeassistant.setup import async_setup_component from . import PIPELINE_DATA @@ -408,3 +409,113 @@ async def test_migrating_pipelines( assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_id"] is None assert hass_storage[STORAGE_KEY]["data"]["items"][1] == PIPELINE_DATA["items"][1] assert hass_storage[STORAGE_KEY]["data"]["items"][2] == PIPELINE_DATA["items"][2] + + +@pytest.mark.parametrize( + ("data", "expected_url_suffix"), + [ + ({"platform": DOMAIN}, DOMAIN), + ({"engine_id": DOMAIN}, DOMAIN), + ({"engine_id": "tts.home_assistant_cloud"}, "tts.home_assistant_cloud"), + ], +) +async def test_deprecated_voice( + hass: HomeAssistant, + issue_registry: IssueRegistry, + cloud: MagicMock, + hass_client: ClientSessionGenerator, + data: dict[str, Any], + expected_url_suffix: str, +) -> None: + """Test we create an issue when a deprecated voice is used for text-to-speech.""" + language = "zh-CN" + deprecated_voice = "XiaoxuanNeural" + replacement_voice = "XiaozhenNeural" + mock_process_tts = AsyncMock( + return_value=b"", + ) + cloud.voice.process_tts = mock_process_tts + + assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}}) + await hass.async_block_till_done() + await cloud.login("test-user", "test-pass") + client = await hass_client() + + # Test with non deprecated voice. + url = "/api/tts_get_url" + data |= { + "message": "There is someone at the door.", + "language": language, + "options": {"voice": replacement_voice}, + } + + req = await client.post(url, json=data) + assert req.status == HTTPStatus.OK + response = await req.json() + + assert response == { + "url": ( + "http://example.local:8123/api/tts_proxy/" + "42f18378fd4393d18c8dd11d03fa9563c1e54491" + f"_{language.lower()}_1c4ec2f170_{expected_url_suffix}.mp3" + ), + "path": ( + "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491" + f"_{language.lower()}_1c4ec2f170_{expected_url_suffix}.mp3" + ), + } + await hass.async_block_till_done() + + assert mock_process_tts.call_count == 1 + assert mock_process_tts.call_args is not None + assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door." + assert mock_process_tts.call_args.kwargs["language"] == language + assert mock_process_tts.call_args.kwargs["gender"] == "female" + assert mock_process_tts.call_args.kwargs["voice"] == replacement_voice + assert mock_process_tts.call_args.kwargs["output"] == "mp3" + issue = issue_registry.async_get_issue( + "cloud", f"deprecated_voice_{replacement_voice}" + ) + assert issue is None + mock_process_tts.reset_mock() + + # Test with deprecated voice. + data["options"] = {"voice": deprecated_voice} + + req = await client.post(url, json=data) + assert req.status == HTTPStatus.OK + response = await req.json() + + assert response == { + "url": ( + "http://example.local:8123/api/tts_proxy/" + "42f18378fd4393d18c8dd11d03fa9563c1e54491" + f"_{language.lower()}_a1c3b0ac0e_{expected_url_suffix}.mp3" + ), + "path": ( + "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491" + f"_{language.lower()}_a1c3b0ac0e_{expected_url_suffix}.mp3" + ), + } + await hass.async_block_till_done() + + assert mock_process_tts.call_count == 1 + assert mock_process_tts.call_args is not None + assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door." + assert mock_process_tts.call_args.kwargs["language"] == language + assert mock_process_tts.call_args.kwargs["gender"] == "female" + assert mock_process_tts.call_args.kwargs["voice"] == replacement_voice + assert mock_process_tts.call_args.kwargs["output"] == "mp3" + issue = issue_registry.async_get_issue( + "cloud", f"deprecated_voice_{deprecated_voice}" + ) + assert issue is not None + assert issue.breaks_in_ha_version == "2024.8.0" + assert issue.is_fixable is True + assert issue.is_persistent is True + assert issue.severity == IssueSeverity.WARNING + assert issue.translation_key == "deprecated_voice" + assert issue.translation_placeholders == { + "deprecated_voice": deprecated_voice, + "replacement_voice": replacement_voice, + }