Google Assistant SDK: support audio response playback (#85989)

* Google Assistant SDK: support response playback

* Update PATHS_WITHOUT_AUTH

* gassist-text==0.0.8

* address review comments
This commit is contained in:
tronikos 2023-01-24 08:19:23 -08:00 committed by GitHub
parent 80a8da26bc
commit 0daaa37e09
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 244 additions and 24 deletions

View file

@ -11,23 +11,36 @@ from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform
from homeassistant.core import Context, HomeAssistant, ServiceCall from homeassistant.core import Context, HomeAssistant, ServiceCall
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import discovery, intent from homeassistant.helpers import config_validation as cv, discovery, intent
from homeassistant.helpers.config_entry_oauth2_flow import ( from homeassistant.helpers.config_entry_oauth2_flow import (
OAuth2Session, OAuth2Session,
async_get_config_entry_implementation, async_get_config_entry_implementation,
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import CONF_ENABLE_CONVERSATION_AGENT, CONF_LANGUAGE_CODE, DOMAIN from .const import (
from .helpers import async_send_text_commands, default_language_code CONF_ENABLE_CONVERSATION_AGENT,
CONF_LANGUAGE_CODE,
DATA_MEM_STORAGE,
DATA_SESSION,
DOMAIN,
)
from .helpers import (
GoogleAssistantSDKAudioView,
InMemoryStorage,
async_send_text_commands,
default_language_code,
)
SERVICE_SEND_TEXT_COMMAND = "send_text_command" SERVICE_SEND_TEXT_COMMAND = "send_text_command"
SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND = "command" SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND = "command"
SERVICE_SEND_TEXT_COMMAND_FIELD_MEDIA_PLAYER = "media_player"
SERVICE_SEND_TEXT_COMMAND_SCHEMA = vol.All( SERVICE_SEND_TEXT_COMMAND_SCHEMA = vol.All(
{ {
vol.Required(SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND): vol.All( vol.Required(SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND): vol.All(
str, vol.Length(min=1) str, vol.Length(min=1)
), ),
vol.Optional(SERVICE_SEND_TEXT_COMMAND_FIELD_MEDIA_PLAYER): cv.comp_entity_ids,
}, },
) )
@ -45,6 +58,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up Google Assistant SDK from a config entry.""" """Set up Google Assistant SDK from a config entry."""
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {}
implementation = await async_get_config_entry_implementation(hass, entry) implementation = await async_get_config_entry_implementation(hass, entry)
session = OAuth2Session(hass, entry, implementation) session = OAuth2Session(hass, entry, implementation)
try: try:
@ -57,7 +72,11 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
raise ConfigEntryNotReady from err raise ConfigEntryNotReady from err
except aiohttp.ClientError as err: except aiohttp.ClientError as err:
raise ConfigEntryNotReady from err raise ConfigEntryNotReady from err
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = session hass.data[DOMAIN][entry.entry_id][DATA_SESSION] = session
mem_storage = InMemoryStorage(hass)
hass.data[DOMAIN][entry.entry_id][DATA_MEM_STORAGE] = mem_storage
hass.http.register_view(GoogleAssistantSDKAudioView(mem_storage))
await async_setup_service(hass) await async_setup_service(hass)
@ -88,7 +107,10 @@ async def async_setup_service(hass: HomeAssistant) -> None:
async def send_text_command(call: ServiceCall) -> None: async def send_text_command(call: ServiceCall) -> None:
"""Send a text command to Google Assistant SDK.""" """Send a text command to Google Assistant SDK."""
command: str = call.data[SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND] command: str = call.data[SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND]
await async_send_text_commands([command], hass) media_players: list[str] | None = call.data.get(
SERVICE_SEND_TEXT_COMMAND_FIELD_MEDIA_PLAYER
)
await async_send_text_commands(hass, [command], media_players)
hass.services.async_register( hass.services.async_register(
DOMAIN, DOMAIN,
@ -136,7 +158,7 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
if self.session: if self.session:
session = self.session session = self.session
else: else:
session = self.hass.data[DOMAIN].get(self.entry.entry_id) session = self.hass.data[DOMAIN][self.entry.entry_id][DATA_SESSION]
self.session = session self.session = session
if not session.valid_token: if not session.valid_token:
await session.async_ensure_token_valid() await session.async_ensure_token_valid()

View file

@ -5,8 +5,12 @@ DOMAIN: Final = "google_assistant_sdk"
DEFAULT_NAME: Final = "Google Assistant SDK" DEFAULT_NAME: Final = "Google Assistant SDK"
CONF_ENABLE_CONVERSATION_AGENT: Final = "enable_conversation_agent"
CONF_LANGUAGE_CODE: Final = "language_code" CONF_LANGUAGE_CODE: Final = "language_code"
DATA_MEM_STORAGE: Final = "mem_storage"
DATA_SESSION: Final = "session"
# https://developers.google.com/assistant/sdk/reference/rpc/languages # https://developers.google.com/assistant/sdk/reference/rpc/languages
SUPPORTED_LANGUAGE_CODES: Final = [ SUPPORTED_LANGUAGE_CODES: Final = [
"de-DE", "de-DE",
@ -24,5 +28,3 @@ SUPPORTED_LANGUAGE_CODES: Final = [
"ko-KR", "ko-KR",
"pt-BR", "pt-BR",
] ]
CONF_ENABLE_CONVERSATION_AGENT: Final = "enable_conversation_agent"

View file

@ -1,18 +1,38 @@
"""Helper classes for Google Assistant SDK integration.""" """Helper classes for Google Assistant SDK integration."""
from __future__ import annotations from __future__ import annotations
from http import HTTPStatus
import logging import logging
from typing import Any
import uuid
import aiohttp import aiohttp
from aiohttp import web
from gassist_text import TextAssistant from gassist_text import TextAssistant
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.media_player import (
ATTR_MEDIA_ANNOUNCE,
ATTR_MEDIA_CONTENT_ID,
ATTR_MEDIA_CONTENT_TYPE,
DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA,
MediaType,
)
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_ACCESS_TOKEN from homeassistant.const import ATTR_ENTITY_ID, CONF_ACCESS_TOKEN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.config_entry_oauth2_flow import OAuth2Session from homeassistant.helpers.config_entry_oauth2_flow import OAuth2Session
from homeassistant.helpers.event import async_call_later
from .const import CONF_LANGUAGE_CODE, DOMAIN, SUPPORTED_LANGUAGE_CODES from .const import (
CONF_LANGUAGE_CODE,
DATA_MEM_STORAGE,
DATA_SESSION,
DOMAIN,
SUPPORTED_LANGUAGE_CODES,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -28,12 +48,14 @@ DEFAULT_LANGUAGE_CODES = {
} }
async def async_send_text_commands(commands: list[str], hass: HomeAssistant) -> None: async def async_send_text_commands(
hass: HomeAssistant, commands: list[str], media_players: list[str] | None = None
) -> None:
"""Send text commands to Google Assistant Service.""" """Send text commands to Google Assistant Service."""
# There can only be 1 entry (config_flow has single_instance_allowed) # There can only be 1 entry (config_flow has single_instance_allowed)
entry: ConfigEntry = hass.config_entries.async_entries(DOMAIN)[0] entry: ConfigEntry = hass.config_entries.async_entries(DOMAIN)[0]
session: OAuth2Session = hass.data[DOMAIN].get(entry.entry_id) session: OAuth2Session = hass.data[DOMAIN][entry.entry_id][DATA_SESSION]
try: try:
await session.async_ensure_token_valid() await session.async_ensure_token_valid()
except aiohttp.ClientResponseError as err: except aiohttp.ClientResponseError as err:
@ -43,10 +65,32 @@ async def async_send_text_commands(commands: list[str], hass: HomeAssistant) ->
credentials = Credentials(session.token[CONF_ACCESS_TOKEN]) credentials = Credentials(session.token[CONF_ACCESS_TOKEN])
language_code = entry.options.get(CONF_LANGUAGE_CODE, default_language_code(hass)) language_code = entry.options.get(CONF_LANGUAGE_CODE, default_language_code(hass))
with TextAssistant(credentials, language_code) as assistant: with TextAssistant(
credentials, language_code, audio_out=bool(media_players)
) as assistant:
for command in commands: for command in commands:
text_response = assistant.assist(command)[0] resp = assistant.assist(command)
text_response = resp[0]
_LOGGER.debug("command: %s\nresponse: %s", command, text_response) _LOGGER.debug("command: %s\nresponse: %s", command, text_response)
audio_response = resp[2]
if media_players and audio_response:
mem_storage: InMemoryStorage = hass.data[DOMAIN][entry.entry_id][
DATA_MEM_STORAGE
]
audio_url = GoogleAssistantSDKAudioView.url.format(
filename=mem_storage.store_and_get_identifier(audio_response)
)
await hass.services.async_call(
DOMAIN_MP,
SERVICE_PLAY_MEDIA,
{
ATTR_ENTITY_ID: media_players,
ATTR_MEDIA_CONTENT_ID: audio_url,
ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
ATTR_MEDIA_ANNOUNCE: True,
},
blocking=True,
)
def default_language_code(hass: HomeAssistant): def default_language_code(hass: HomeAssistant):
@ -55,3 +99,53 @@ def default_language_code(hass: HomeAssistant):
if language_code in SUPPORTED_LANGUAGE_CODES: if language_code in SUPPORTED_LANGUAGE_CODES:
return language_code return language_code
return DEFAULT_LANGUAGE_CODES.get(hass.config.language, "en-US") return DEFAULT_LANGUAGE_CODES.get(hass.config.language, "en-US")
class InMemoryStorage:
"""Temporarily store and retrieve data from in memory storage."""
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize InMemoryStorage."""
self.hass: HomeAssistant = hass
self.mem: dict[str, bytes] = {}
def store_and_get_identifier(self, data: bytes) -> str:
"""
Temporarily store data and return identifier to be able to retrieve it.
Data expires after 5 minutes.
"""
identifier: str = uuid.uuid1().hex
self.mem[identifier] = data
def async_remove_from_mem(*_: Any) -> None:
"""Cleanup memory."""
self.mem.pop(identifier, None)
# Remove the entry from memory 5 minutes later
async_call_later(self.hass, 5 * 60, async_remove_from_mem)
return identifier
def retrieve(self, identifier: str) -> bytes | None:
"""Retrieve previously stored data."""
return self.mem.get(identifier)
class GoogleAssistantSDKAudioView(HomeAssistantView):
"""Google Assistant SDK view to serve audio responses."""
requires_auth = True
url = "/api/google_assistant_sdk/audio/{filename}"
name = "api:google_assistant_sdk:audio"
def __init__(self, mem_storage: InMemoryStorage) -> None:
"""Initialize GoogleAssistantSDKView."""
self.mem_storage: InMemoryStorage = mem_storage
async def get(self, request: web.Request, filename: str) -> web.Response:
"""Start a get request."""
audio = self.mem_storage.retrieve(filename)
if not audio:
return web.Response(status=HTTPStatus.NOT_FOUND)
return web.Response(body=audio, content_type="audio/mpeg")

View file

@ -2,9 +2,9 @@
"domain": "google_assistant_sdk", "domain": "google_assistant_sdk",
"name": "Google Assistant SDK", "name": "Google Assistant SDK",
"config_flow": true, "config_flow": true,
"dependencies": ["application_credentials"], "dependencies": ["application_credentials", "http"],
"documentation": "https://www.home-assistant.io/integrations/google_assistant_sdk/", "documentation": "https://www.home-assistant.io/integrations/google_assistant_sdk/",
"requirements": ["gassist-text==0.0.7"], "requirements": ["gassist-text==0.0.8"],
"codeowners": ["@tronikos"], "codeowners": ["@tronikos"],
"iot_class": "cloud_polling", "iot_class": "cloud_polling",
"integration_type": "service" "integration_type": "service"

View file

@ -70,4 +70,4 @@ class BroadcastNotificationService(BaseNotificationService):
commands.append( commands.append(
broadcast_commands(language_code)[1].format(message, target) broadcast_commands(language_code)[1].format(message, target)
) )
await async_send_text_commands(commands, self.hass) await async_send_text_commands(self.hass, commands)

View file

@ -8,3 +8,10 @@ send_text_command:
example: turn off kitchen TV example: turn off kitchen TV
selector: selector:
text: text:
media_player:
name: Media Player Entity
description: Name(s) of media player entities to play response on
example: media_player.living_room_speaker
selector:
entity:
domain: media_player

View file

@ -754,7 +754,7 @@ fritzconnection==1.10.3
gTTS==2.2.4 gTTS==2.2.4
# homeassistant.components.google_assistant_sdk # homeassistant.components.google_assistant_sdk
gassist-text==0.0.7 gassist-text==0.0.8
# homeassistant.components.google # homeassistant.components.google
gcal-sync==4.1.2 gcal-sync==4.1.2

View file

@ -573,7 +573,7 @@ fritzconnection==1.10.3
gTTS==2.2.4 gTTS==2.2.4
# homeassistant.components.google_assistant_sdk # homeassistant.components.google_assistant_sdk
gassist-text==0.0.7 gassist-text==0.0.8
# homeassistant.components.google # homeassistant.components.google
gcal-sync==4.1.2 gcal-sync==4.1.2

View file

@ -1,4 +1,5 @@
"""Tests for Google Assistant SDK.""" """Tests for Google Assistant SDK."""
from datetime import timedelta
import http import http
import time import time
from unittest.mock import call, patch from unittest.mock import call, patch
@ -10,12 +11,22 @@ from homeassistant.components.google_assistant_sdk import DOMAIN
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow
from .conftest import ComponentSetup, ExpectedCredentials from .conftest import ComponentSetup, ExpectedCredentials
from tests.common import async_fire_time_changed, async_mock_service
from tests.test_util.aiohttp import AiohttpClientMocker from tests.test_util.aiohttp import AiohttpClientMocker
async def fetch_api_url(hass_client, url):
"""Fetch an API URL and return HTTP status and contents."""
client = await hass_client()
response = await client.get(url)
contents = await response.read()
return response.status, contents
async def test_setup_success( async def test_setup_success(
hass: HomeAssistant, setup_integration: ComponentSetup hass: HomeAssistant, setup_integration: ComponentSetup
) -> None: ) -> None:
@ -129,7 +140,7 @@ async def test_send_text_command(
blocking=True, blocking=True,
) )
mock_text_assistant.assert_called_once_with( mock_text_assistant.assert_called_once_with(
ExpectedCredentials(), expected_language_code ExpectedCredentials(), expected_language_code, audio_out=False
) )
mock_text_assistant.assert_has_calls([call().__enter__().assist(command)]) mock_text_assistant.assert_has_calls([call().__enter__().assist(command)])
@ -180,6 +191,88 @@ async def test_send_text_command_expired_token_refresh_failure(
assert any(entry.async_get_active_flows(hass, {"reauth"})) == requires_reauth assert any(entry.async_get_active_flows(hass, {"reauth"})) == requires_reauth
async def test_send_text_command_media_player(
hass: HomeAssistant, setup_integration: ComponentSetup, hass_client
) -> None:
"""Test send_text_command with media_player."""
await setup_integration()
play_media_calls = async_mock_service(hass, "media_player", "play_media")
command = "tell me a joke"
media_player = "media_player.office_speaker"
audio_response1 = b"joke1 audio response bytes"
audio_response2 = b"joke2 audio response bytes"
with patch(
"homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist",
side_effect=[
("joke1 text", None, audio_response1),
("joke2 text", None, audio_response2),
],
) as mock_assist_call:
# Run the same command twice, getting different audio response each time.
await hass.services.async_call(
DOMAIN,
"send_text_command",
{
"command": command,
"media_player": media_player,
},
blocking=True,
)
await hass.services.async_call(
DOMAIN,
"send_text_command",
{
"command": command,
"media_player": media_player,
},
blocking=True,
)
mock_assist_call.assert_has_calls([call(command), call(command)])
assert len(play_media_calls) == 2
for play_media_call in play_media_calls:
assert play_media_call.data["entity_id"] == [media_player]
assert play_media_call.data["media_content_id"].startswith(
"/api/google_assistant_sdk/audio/"
)
audio_url1 = play_media_calls[0].data["media_content_id"]
audio_url2 = play_media_calls[1].data["media_content_id"]
assert audio_url1 != audio_url2
# Assert that both audio responses can be served
status, response = await fetch_api_url(hass_client, audio_url1)
assert status == http.HTTPStatus.OK
assert response == audio_response1
status, response = await fetch_api_url(hass_client, audio_url2)
assert status == http.HTTPStatus.OK
assert response == audio_response2
# Assert a nonexistent URL returns 404
status, _ = await fetch_api_url(
hass_client, "/api/google_assistant_sdk/audio/nonexistent"
)
assert status == http.HTTPStatus.NOT_FOUND
# Assert that both audio responses can still be served before the 5 minutes expiration
async_fire_time_changed(hass, utcnow() + timedelta(minutes=4))
status, response = await fetch_api_url(hass_client, audio_url1)
assert status == http.HTTPStatus.OK
assert response == audio_response1
status, response = await fetch_api_url(hass_client, audio_url2)
assert status == http.HTTPStatus.OK
assert response == audio_response2
# Assert that they cannot be served after the 5 minutes expiration
async_fire_time_changed(hass, utcnow() + timedelta(minutes=6))
status, response = await fetch_api_url(hass_client, audio_url1)
assert status == http.HTTPStatus.NOT_FOUND
status, response = await fetch_api_url(hass_client, audio_url2)
assert status == http.HTTPStatus.NOT_FOUND
async def test_conversation_agent( async def test_conversation_agent(
hass: HomeAssistant, hass: HomeAssistant,
setup_integration: ComponentSetup, setup_integration: ComponentSetup,

View file

@ -44,7 +44,9 @@ async def test_broadcast_no_targets(
{notify.ATTR_MESSAGE: message}, {notify.ATTR_MESSAGE: message},
) )
await hass.async_block_till_done() await hass.async_block_till_done()
mock_text_assistant.assert_called_once_with(ExpectedCredentials(), language_code) mock_text_assistant.assert_called_once_with(
ExpectedCredentials(), language_code, audio_out=False
)
mock_text_assistant.assert_has_calls([call().__enter__().assist(expected_command)]) mock_text_assistant.assert_has_calls([call().__enter__().assist(expected_command)])
@ -84,7 +86,7 @@ async def test_broadcast_one_target(
with patch( with patch(
"homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist", "homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist",
return_value=["text_response", None], return_value=("text_response", None, b""),
) as mock_assist_call: ) as mock_assist_call:
await hass.services.async_call( await hass.services.async_call(
notify.DOMAIN, notify.DOMAIN,
@ -108,7 +110,7 @@ async def test_broadcast_two_targets(
expected_command2 = "broadcast to master bedroom time for dinner" expected_command2 = "broadcast to master bedroom time for dinner"
with patch( with patch(
"homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist", "homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist",
return_value=["text_response", None], return_value=("text_response", None, b""),
) as mock_assist_call: ) as mock_assist_call:
await hass.services.async_call( await hass.services.async_call(
notify.DOMAIN, notify.DOMAIN,
@ -129,7 +131,7 @@ async def test_broadcast_empty_message(
with patch( with patch(
"homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist", "homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist",
return_value=["text_response", None], return_value=("text_response", None, b""),
) as mock_assist_call: ) as mock_assist_call:
await hass.services.async_call( await hass.services.async_call(
notify.DOMAIN, notify.DOMAIN,