diff --git a/homeassistant/components/google_assistant_sdk/__init__.py b/homeassistant/components/google_assistant_sdk/__init__.py index c784ebb500e..a414239b69f 100644 --- a/homeassistant/components/google_assistant_sdk/__init__.py +++ b/homeassistant/components/google_assistant_sdk/__init__.py @@ -11,23 +11,36 @@ from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform from homeassistant.core import Context, HomeAssistant, ServiceCall from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady -from homeassistant.helpers import discovery, intent +from homeassistant.helpers import config_validation as cv, discovery, intent from homeassistant.helpers.config_entry_oauth2_flow import ( OAuth2Session, async_get_config_entry_implementation, ) from homeassistant.helpers.typing import ConfigType -from .const import CONF_ENABLE_CONVERSATION_AGENT, CONF_LANGUAGE_CODE, DOMAIN -from .helpers import async_send_text_commands, default_language_code +from .const import ( + CONF_ENABLE_CONVERSATION_AGENT, + CONF_LANGUAGE_CODE, + DATA_MEM_STORAGE, + DATA_SESSION, + DOMAIN, +) +from .helpers import ( + GoogleAssistantSDKAudioView, + InMemoryStorage, + async_send_text_commands, + default_language_code, +) SERVICE_SEND_TEXT_COMMAND = "send_text_command" SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND = "command" +SERVICE_SEND_TEXT_COMMAND_FIELD_MEDIA_PLAYER = "media_player" SERVICE_SEND_TEXT_COMMAND_SCHEMA = vol.All( { vol.Required(SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND): vol.All( str, vol.Length(min=1) ), + vol.Optional(SERVICE_SEND_TEXT_COMMAND_FIELD_MEDIA_PLAYER): cv.comp_entity_ids, }, ) @@ -45,6 +58,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up Google Assistant SDK from a config entry.""" + hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {} + implementation = await async_get_config_entry_implementation(hass, entry) session = OAuth2Session(hass, entry, implementation) try: @@ -57,7 +72,11 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: raise ConfigEntryNotReady from err except aiohttp.ClientError as err: raise ConfigEntryNotReady from err - hass.data.setdefault(DOMAIN, {})[entry.entry_id] = session + hass.data[DOMAIN][entry.entry_id][DATA_SESSION] = session + + mem_storage = InMemoryStorage(hass) + hass.data[DOMAIN][entry.entry_id][DATA_MEM_STORAGE] = mem_storage + hass.http.register_view(GoogleAssistantSDKAudioView(mem_storage)) await async_setup_service(hass) @@ -88,7 +107,10 @@ async def async_setup_service(hass: HomeAssistant) -> None: async def send_text_command(call: ServiceCall) -> None: """Send a text command to Google Assistant SDK.""" command: str = call.data[SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND] - await async_send_text_commands([command], hass) + media_players: list[str] | None = call.data.get( + SERVICE_SEND_TEXT_COMMAND_FIELD_MEDIA_PLAYER + ) + await async_send_text_commands(hass, [command], media_players) hass.services.async_register( DOMAIN, @@ -136,7 +158,7 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent): if self.session: session = self.session else: - session = self.hass.data[DOMAIN].get(self.entry.entry_id) + session = self.hass.data[DOMAIN][self.entry.entry_id][DATA_SESSION] self.session = session if not session.valid_token: await session.async_ensure_token_valid() diff --git a/homeassistant/components/google_assistant_sdk/const.py b/homeassistant/components/google_assistant_sdk/const.py index 1b77b58d0fb..c9f86160bb4 100644 --- a/homeassistant/components/google_assistant_sdk/const.py +++ b/homeassistant/components/google_assistant_sdk/const.py @@ -5,8 +5,12 @@ DOMAIN: Final = "google_assistant_sdk" DEFAULT_NAME: Final = "Google Assistant SDK" +CONF_ENABLE_CONVERSATION_AGENT: Final = "enable_conversation_agent" CONF_LANGUAGE_CODE: Final = "language_code" +DATA_MEM_STORAGE: Final = "mem_storage" +DATA_SESSION: Final = "session" + # https://developers.google.com/assistant/sdk/reference/rpc/languages SUPPORTED_LANGUAGE_CODES: Final = [ "de-DE", @@ -24,5 +28,3 @@ SUPPORTED_LANGUAGE_CODES: Final = [ "ko-KR", "pt-BR", ] - -CONF_ENABLE_CONVERSATION_AGENT: Final = "enable_conversation_agent" diff --git a/homeassistant/components/google_assistant_sdk/helpers.py b/homeassistant/components/google_assistant_sdk/helpers.py index e2d704a917a..1c85e5b6a4b 100644 --- a/homeassistant/components/google_assistant_sdk/helpers.py +++ b/homeassistant/components/google_assistant_sdk/helpers.py @@ -1,18 +1,38 @@ """Helper classes for Google Assistant SDK integration.""" from __future__ import annotations +from http import HTTPStatus import logging +from typing import Any +import uuid import aiohttp +from aiohttp import web from gassist_text import TextAssistant from google.oauth2.credentials import Credentials +from homeassistant.components.http import HomeAssistantView +from homeassistant.components.media_player import ( + ATTR_MEDIA_ANNOUNCE, + ATTR_MEDIA_CONTENT_ID, + ATTR_MEDIA_CONTENT_TYPE, + DOMAIN as DOMAIN_MP, + SERVICE_PLAY_MEDIA, + MediaType, +) from homeassistant.config_entries import ConfigEntry -from homeassistant.const import CONF_ACCESS_TOKEN +from homeassistant.const import ATTR_ENTITY_ID, CONF_ACCESS_TOKEN from homeassistant.core import HomeAssistant from homeassistant.helpers.config_entry_oauth2_flow import OAuth2Session +from homeassistant.helpers.event import async_call_later -from .const import CONF_LANGUAGE_CODE, DOMAIN, SUPPORTED_LANGUAGE_CODES +from .const import ( + CONF_LANGUAGE_CODE, + DATA_MEM_STORAGE, + DATA_SESSION, + DOMAIN, + SUPPORTED_LANGUAGE_CODES, +) _LOGGER = logging.getLogger(__name__) @@ -28,12 +48,14 @@ DEFAULT_LANGUAGE_CODES = { } -async def async_send_text_commands(commands: list[str], hass: HomeAssistant) -> None: +async def async_send_text_commands( + hass: HomeAssistant, commands: list[str], media_players: list[str] | None = None +) -> None: """Send text commands to Google Assistant Service.""" # There can only be 1 entry (config_flow has single_instance_allowed) entry: ConfigEntry = hass.config_entries.async_entries(DOMAIN)[0] - session: OAuth2Session = hass.data[DOMAIN].get(entry.entry_id) + session: OAuth2Session = hass.data[DOMAIN][entry.entry_id][DATA_SESSION] try: await session.async_ensure_token_valid() except aiohttp.ClientResponseError as err: @@ -43,10 +65,32 @@ async def async_send_text_commands(commands: list[str], hass: HomeAssistant) -> credentials = Credentials(session.token[CONF_ACCESS_TOKEN]) language_code = entry.options.get(CONF_LANGUAGE_CODE, default_language_code(hass)) - with TextAssistant(credentials, language_code) as assistant: + with TextAssistant( + credentials, language_code, audio_out=bool(media_players) + ) as assistant: for command in commands: - text_response = assistant.assist(command)[0] + resp = assistant.assist(command) + text_response = resp[0] _LOGGER.debug("command: %s\nresponse: %s", command, text_response) + audio_response = resp[2] + if media_players and audio_response: + mem_storage: InMemoryStorage = hass.data[DOMAIN][entry.entry_id][ + DATA_MEM_STORAGE + ] + audio_url = GoogleAssistantSDKAudioView.url.format( + filename=mem_storage.store_and_get_identifier(audio_response) + ) + await hass.services.async_call( + DOMAIN_MP, + SERVICE_PLAY_MEDIA, + { + ATTR_ENTITY_ID: media_players, + ATTR_MEDIA_CONTENT_ID: audio_url, + ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC, + ATTR_MEDIA_ANNOUNCE: True, + }, + blocking=True, + ) def default_language_code(hass: HomeAssistant): @@ -55,3 +99,53 @@ def default_language_code(hass: HomeAssistant): if language_code in SUPPORTED_LANGUAGE_CODES: return language_code return DEFAULT_LANGUAGE_CODES.get(hass.config.language, "en-US") + + +class InMemoryStorage: + """Temporarily store and retrieve data from in memory storage.""" + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize InMemoryStorage.""" + self.hass: HomeAssistant = hass + self.mem: dict[str, bytes] = {} + + def store_and_get_identifier(self, data: bytes) -> str: + """ + Temporarily store data and return identifier to be able to retrieve it. + + Data expires after 5 minutes. + """ + identifier: str = uuid.uuid1().hex + self.mem[identifier] = data + + def async_remove_from_mem(*_: Any) -> None: + """Cleanup memory.""" + self.mem.pop(identifier, None) + + # Remove the entry from memory 5 minutes later + async_call_later(self.hass, 5 * 60, async_remove_from_mem) + + return identifier + + def retrieve(self, identifier: str) -> bytes | None: + """Retrieve previously stored data.""" + return self.mem.get(identifier) + + +class GoogleAssistantSDKAudioView(HomeAssistantView): + """Google Assistant SDK view to serve audio responses.""" + + requires_auth = True + url = "/api/google_assistant_sdk/audio/{filename}" + name = "api:google_assistant_sdk:audio" + + def __init__(self, mem_storage: InMemoryStorage) -> None: + """Initialize GoogleAssistantSDKView.""" + self.mem_storage: InMemoryStorage = mem_storage + + async def get(self, request: web.Request, filename: str) -> web.Response: + """Start a get request.""" + audio = self.mem_storage.retrieve(filename) + if not audio: + return web.Response(status=HTTPStatus.NOT_FOUND) + return web.Response(body=audio, content_type="audio/mpeg") diff --git a/homeassistant/components/google_assistant_sdk/manifest.json b/homeassistant/components/google_assistant_sdk/manifest.json index e1b390f9496..86684242b73 100644 --- a/homeassistant/components/google_assistant_sdk/manifest.json +++ b/homeassistant/components/google_assistant_sdk/manifest.json @@ -2,9 +2,9 @@ "domain": "google_assistant_sdk", "name": "Google Assistant SDK", "config_flow": true, - "dependencies": ["application_credentials"], + "dependencies": ["application_credentials", "http"], "documentation": "https://www.home-assistant.io/integrations/google_assistant_sdk/", - "requirements": ["gassist-text==0.0.7"], + "requirements": ["gassist-text==0.0.8"], "codeowners": ["@tronikos"], "iot_class": "cloud_polling", "integration_type": "service" diff --git a/homeassistant/components/google_assistant_sdk/notify.py b/homeassistant/components/google_assistant_sdk/notify.py index f9a212b54c3..80d0e70f44c 100644 --- a/homeassistant/components/google_assistant_sdk/notify.py +++ b/homeassistant/components/google_assistant_sdk/notify.py @@ -70,4 +70,4 @@ class BroadcastNotificationService(BaseNotificationService): commands.append( broadcast_commands(language_code)[1].format(message, target) ) - await async_send_text_commands(commands, self.hass) + await async_send_text_commands(self.hass, commands) diff --git a/homeassistant/components/google_assistant_sdk/services.yaml b/homeassistant/components/google_assistant_sdk/services.yaml index b9d4e8635de..c010843ed92 100644 --- a/homeassistant/components/google_assistant_sdk/services.yaml +++ b/homeassistant/components/google_assistant_sdk/services.yaml @@ -8,3 +8,10 @@ send_text_command: example: turn off kitchen TV selector: text: + media_player: + name: Media Player Entity + description: Name(s) of media player entities to play response on + example: media_player.living_room_speaker + selector: + entity: + domain: media_player diff --git a/requirements_all.txt b/requirements_all.txt index cbc045ac7d5..280d7f043b0 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -754,7 +754,7 @@ fritzconnection==1.10.3 gTTS==2.2.4 # homeassistant.components.google_assistant_sdk -gassist-text==0.0.7 +gassist-text==0.0.8 # homeassistant.components.google gcal-sync==4.1.2 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index ba3a076f0d4..8fc204d3651 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -573,7 +573,7 @@ fritzconnection==1.10.3 gTTS==2.2.4 # homeassistant.components.google_assistant_sdk -gassist-text==0.0.7 +gassist-text==0.0.8 # homeassistant.components.google gcal-sync==4.1.2 diff --git a/tests/components/google_assistant_sdk/test_init.py b/tests/components/google_assistant_sdk/test_init.py index b93f83feda7..01993389c80 100644 --- a/tests/components/google_assistant_sdk/test_init.py +++ b/tests/components/google_assistant_sdk/test_init.py @@ -1,4 +1,5 @@ """Tests for Google Assistant SDK.""" +from datetime import timedelta import http import time from unittest.mock import call, patch @@ -10,12 +11,22 @@ from homeassistant.components.google_assistant_sdk import DOMAIN from homeassistant.config_entries import ConfigEntryState from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component +from homeassistant.util.dt import utcnow from .conftest import ComponentSetup, ExpectedCredentials +from tests.common import async_fire_time_changed, async_mock_service from tests.test_util.aiohttp import AiohttpClientMocker +async def fetch_api_url(hass_client, url): + """Fetch an API URL and return HTTP status and contents.""" + client = await hass_client() + response = await client.get(url) + contents = await response.read() + return response.status, contents + + async def test_setup_success( hass: HomeAssistant, setup_integration: ComponentSetup ) -> None: @@ -129,7 +140,7 @@ async def test_send_text_command( blocking=True, ) mock_text_assistant.assert_called_once_with( - ExpectedCredentials(), expected_language_code + ExpectedCredentials(), expected_language_code, audio_out=False ) mock_text_assistant.assert_has_calls([call().__enter__().assist(command)]) @@ -180,6 +191,88 @@ async def test_send_text_command_expired_token_refresh_failure( assert any(entry.async_get_active_flows(hass, {"reauth"})) == requires_reauth +async def test_send_text_command_media_player( + hass: HomeAssistant, setup_integration: ComponentSetup, hass_client +) -> None: + """Test send_text_command with media_player.""" + await setup_integration() + + play_media_calls = async_mock_service(hass, "media_player", "play_media") + + command = "tell me a joke" + media_player = "media_player.office_speaker" + audio_response1 = b"joke1 audio response bytes" + audio_response2 = b"joke2 audio response bytes" + with patch( + "homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist", + side_effect=[ + ("joke1 text", None, audio_response1), + ("joke2 text", None, audio_response2), + ], + ) as mock_assist_call: + # Run the same command twice, getting different audio response each time. + await hass.services.async_call( + DOMAIN, + "send_text_command", + { + "command": command, + "media_player": media_player, + }, + blocking=True, + ) + await hass.services.async_call( + DOMAIN, + "send_text_command", + { + "command": command, + "media_player": media_player, + }, + blocking=True, + ) + + mock_assist_call.assert_has_calls([call(command), call(command)]) + assert len(play_media_calls) == 2 + for play_media_call in play_media_calls: + assert play_media_call.data["entity_id"] == [media_player] + assert play_media_call.data["media_content_id"].startswith( + "/api/google_assistant_sdk/audio/" + ) + + audio_url1 = play_media_calls[0].data["media_content_id"] + audio_url2 = play_media_calls[1].data["media_content_id"] + assert audio_url1 != audio_url2 + + # Assert that both audio responses can be served + status, response = await fetch_api_url(hass_client, audio_url1) + assert status == http.HTTPStatus.OK + assert response == audio_response1 + status, response = await fetch_api_url(hass_client, audio_url2) + assert status == http.HTTPStatus.OK + assert response == audio_response2 + + # Assert a nonexistent URL returns 404 + status, _ = await fetch_api_url( + hass_client, "/api/google_assistant_sdk/audio/nonexistent" + ) + assert status == http.HTTPStatus.NOT_FOUND + + # Assert that both audio responses can still be served before the 5 minutes expiration + async_fire_time_changed(hass, utcnow() + timedelta(minutes=4)) + status, response = await fetch_api_url(hass_client, audio_url1) + assert status == http.HTTPStatus.OK + assert response == audio_response1 + status, response = await fetch_api_url(hass_client, audio_url2) + assert status == http.HTTPStatus.OK + assert response == audio_response2 + + # Assert that they cannot be served after the 5 minutes expiration + async_fire_time_changed(hass, utcnow() + timedelta(minutes=6)) + status, response = await fetch_api_url(hass_client, audio_url1) + assert status == http.HTTPStatus.NOT_FOUND + status, response = await fetch_api_url(hass_client, audio_url2) + assert status == http.HTTPStatus.NOT_FOUND + + async def test_conversation_agent( hass: HomeAssistant, setup_integration: ComponentSetup, diff --git a/tests/components/google_assistant_sdk/test_notify.py b/tests/components/google_assistant_sdk/test_notify.py index 85d421b1675..95d5720cc7e 100644 --- a/tests/components/google_assistant_sdk/test_notify.py +++ b/tests/components/google_assistant_sdk/test_notify.py @@ -44,7 +44,9 @@ async def test_broadcast_no_targets( {notify.ATTR_MESSAGE: message}, ) await hass.async_block_till_done() - mock_text_assistant.assert_called_once_with(ExpectedCredentials(), language_code) + mock_text_assistant.assert_called_once_with( + ExpectedCredentials(), language_code, audio_out=False + ) mock_text_assistant.assert_has_calls([call().__enter__().assist(expected_command)]) @@ -84,7 +86,7 @@ async def test_broadcast_one_target( with patch( "homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist", - return_value=["text_response", None], + return_value=("text_response", None, b""), ) as mock_assist_call: await hass.services.async_call( notify.DOMAIN, @@ -108,7 +110,7 @@ async def test_broadcast_two_targets( expected_command2 = "broadcast to master bedroom time for dinner" with patch( "homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist", - return_value=["text_response", None], + return_value=("text_response", None, b""), ) as mock_assist_call: await hass.services.async_call( notify.DOMAIN, @@ -129,7 +131,7 @@ async def test_broadcast_empty_message( with patch( "homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist", - return_value=["text_response", None], + return_value=("text_response", None, b""), ) as mock_assist_call: await hass.services.async_call( notify.DOMAIN,