Google Assistant SDK conversation agent (#85499)

* Google Assistant SDK conversation agent

* refresh token

* fix session

* Add tests

* Add option to enable conversation agent
This commit is contained in:
tronikos 2023-01-09 17:53:41 -08:00 committed by GitHub
parent f2df72e014
commit e24989b446
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 238 additions and 18 deletions

View file

@ -2,21 +2,24 @@
from __future__ import annotations
import aiohttp
from gassist_text import TextAssistant
from google.oauth2.credentials import Credentials
import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import CONF_NAME, Platform
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform
from homeassistant.core import Context, HomeAssistant, ServiceCall
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import discovery
from homeassistant.helpers import discovery, intent
from homeassistant.helpers.config_entry_oauth2_flow import (
OAuth2Session,
async_get_config_entry_implementation,
)
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN
from .helpers import async_send_text_commands
from .const import CONF_ENABLE_CONVERSATION_AGENT, CONF_LANGUAGE_CODE, DOMAIN
from .helpers import async_send_text_commands, default_language_code
SERVICE_SEND_TEXT_COMMAND = "send_text_command"
SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND = "command"
@ -58,6 +61,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
await async_setup_service(hass)
entry.async_on_unload(entry.add_update_listener(update_listener))
await update_listener(hass, entry)
return True
@ -90,3 +96,64 @@ async def async_setup_service(hass: HomeAssistant) -> None:
send_text_command,
schema=SERVICE_SEND_TEXT_COMMAND_SCHEMA,
)
async def update_listener(hass, entry):
"""Handle options update."""
if entry.options.get(CONF_ENABLE_CONVERSATION_AGENT, False):
agent = GoogleAssistantConversationAgent(hass, entry)
conversation.async_set_agent(hass, agent)
else:
conversation.async_set_agent(hass, None)
class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
"""Google Assistant SDK conversation agent."""
def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Initialize the agent."""
self.hass = hass
self.entry = entry
self.assistant: TextAssistant | None = None
self.session: OAuth2Session | None = None
@property
def attribution(self):
"""Return the attribution."""
return {
"name": "Powered by Google Assistant SDK",
"url": "https://www.home-assistant.io/integrations/google_assistant_sdk/",
}
async def async_process(
self,
text: str,
context: Context,
conversation_id: str | None = None,
language: str | None = None,
) -> conversation.ConversationResult | None:
"""Process a sentence."""
if self.session:
session = self.session
else:
session = self.hass.data[DOMAIN].get(self.entry.entry_id)
self.session = session
if not session.valid_token:
await session.async_ensure_token_valid()
self.assistant = None
if not self.assistant:
credentials = Credentials(session.token[CONF_ACCESS_TOKEN])
language_code = self.entry.options.get(
CONF_LANGUAGE_CODE, default_language_code(self.hass)
)
self.assistant = TextAssistant(credentials, language_code)
resp = self.assistant.assist(text)
text_response = resp[0]
language = language or self.hass.config.language
intent_response = intent.IntentResponse(language=language)
intent_response.async_set_speech(text_response)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)

View file

@ -13,7 +13,13 @@ from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import config_entry_oauth2_flow
from .const import CONF_LANGUAGE_CODE, DEFAULT_NAME, DOMAIN, SUPPORTED_LANGUAGE_CODES
from .const import (
CONF_ENABLE_CONVERSATION_AGENT,
CONF_LANGUAGE_CODE,
DEFAULT_NAME,
DOMAIN,
SUPPORTED_LANGUAGE_CODES,
)
from .helpers import default_language_code
_LOGGER = logging.getLogger(__name__)
@ -108,6 +114,12 @@ class OptionsFlowHandler(config_entries.OptionsFlow):
CONF_LANGUAGE_CODE,
default=self.config_entry.options.get(CONF_LANGUAGE_CODE),
): vol.In(SUPPORTED_LANGUAGE_CODES),
vol.Required(
CONF_ENABLE_CONVERSATION_AGENT,
default=self.config_entry.options.get(
CONF_ENABLE_CONVERSATION_AGENT
),
): bool,
}
),
)

View file

@ -24,3 +24,5 @@ SUPPORTED_LANGUAGE_CODES: Final = [
"ko-KR",
"pt-BR",
]
CONF_ENABLE_CONVERSATION_AGENT: Final = "enable_conversation_agent"

View file

@ -31,8 +31,10 @@
"step": {
"init": {
"data": {
"enable_conversation_agent": "Enable the conversation agent",
"language_code": "Language code"
}
},
"description": "Set language for interactions with Google Assistant and whether you want to enable the conversation agent."
}
}
},

View file

@ -34,8 +34,10 @@
"step": {
"init": {
"data": {
"enable_conversation_agent": "Enable conversation agent",
"language_code": "Language code"
}
},
"description": "Set language for interactions with Google Assistant and whether you want to enable the conversation agent."
}
}
}

View file

@ -87,6 +87,10 @@ async def mock_setup_integration(
class ExpectedCredentials:
"""Assert credentials have the expected access token."""
def __init__(self, expected_access_token: str = ACCESS_TOKEN) -> None:
"""Initialize ExpectedCredentials."""
self.expected_access_token = expected_access_token
def __eq__(self, other: Credentials):
"""Return true if credentials have the expected access token."""
return other.token == ACCESS_TOKEN
return other.token == self.expected_access_token

View file

@ -221,39 +221,65 @@ async def test_options_flow(
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"language_code"}
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"language_code": "es-ES"},
user_input={"enable_conversation_agent": False, "language_code": "es-ES"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {"language_code": "es-ES"}
assert config_entry.options == {
"enable_conversation_agent": False,
"language_code": "es-ES",
}
# Retrigger options flow, not change language
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"language_code"}
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"language_code": "es-ES"},
user_input={"enable_conversation_agent": False, "language_code": "es-ES"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {"language_code": "es-ES"}
assert config_entry.options == {
"enable_conversation_agent": False,
"language_code": "es-ES",
}
# Retrigger options flow, change language
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"language_code"}
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"language_code": "en-US"},
user_input={"enable_conversation_agent": False, "language_code": "en-US"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {"language_code": "en-US"}
assert config_entry.options == {
"enable_conversation_agent": False,
"language_code": "en-US",
}
# Retrigger options flow, enable conversation agent
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"enable_conversation_agent": True, "language_code": "en-US"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {
"enable_conversation_agent": True,
"language_code": "en-US",
}

View file

@ -9,6 +9,7 @@ import pytest
from homeassistant.components.google_assistant_sdk import DOMAIN
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from .conftest import ComponentSetup, ExpectedCredentials
@ -177,3 +178,107 @@ async def test_send_text_command_expired_token_refresh_failure(
)
assert any(entry.async_get_active_flows(hass, {"reauth"})) == requires_reauth
async def test_conversation_agent(
hass: HomeAssistant,
setup_integration: ComponentSetup,
) -> None:
"""Test GoogleAssistantConversationAgent."""
await setup_integration()
assert await async_setup_component(hass, "conversation", {})
entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
entry = entries[0]
assert entry.state is ConfigEntryState.LOADED
hass.config_entries.async_update_entry(
entry, options={"enable_conversation_agent": True}
)
await hass.async_block_till_done()
text1 = "tell me a joke"
text2 = "tell me another one"
with patch(
"homeassistant.components.google_assistant_sdk.TextAssistant"
) as mock_text_assistant:
await hass.services.async_call(
"conversation",
"process",
{"text": text1},
blocking=True,
)
await hass.services.async_call(
"conversation",
"process",
{"text": text2},
blocking=True,
)
# Assert constructor is called only once since it's reused across requests
assert mock_text_assistant.call_count == 1
mock_text_assistant.assert_called_once_with(ExpectedCredentials(), "en-US")
mock_text_assistant.assert_has_calls([call().assist(text1)])
mock_text_assistant.assert_has_calls([call().assist(text2)])
async def test_conversation_agent_refresh_token(
hass: HomeAssistant,
setup_integration: ComponentSetup,
aioclient_mock: AiohttpClientMocker,
) -> None:
"""Test GoogleAssistantConversationAgent when token is expired."""
await setup_integration()
assert await async_setup_component(hass, "conversation", {})
entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
entry = entries[0]
assert entry.state is ConfigEntryState.LOADED
hass.config_entries.async_update_entry(
entry, options={"enable_conversation_agent": True}
)
await hass.async_block_till_done()
text1 = "tell me a joke"
text2 = "tell me another one"
with patch(
"homeassistant.components.google_assistant_sdk.TextAssistant"
) as mock_text_assistant:
await hass.services.async_call(
"conversation",
"process",
{"text": text1},
blocking=True,
)
# Expire the token between requests
entry.data["token"]["expires_at"] = time.time() - 3600
updated_access_token = "updated-access-token"
aioclient_mock.post(
"https://oauth2.googleapis.com/token",
json={
"access_token": updated_access_token,
"refresh_token": "updated-refresh-token",
"expires_at": time.time() + 3600,
"expires_in": 3600,
},
)
await hass.services.async_call(
"conversation",
"process",
{"text": text2},
blocking=True,
)
# Assert constructor is called twice since the token was expired
assert mock_text_assistant.call_count == 2
mock_text_assistant.assert_has_calls([call(ExpectedCredentials(), "en-US")])
mock_text_assistant.assert_has_calls(
[call(ExpectedCredentials(updated_access_token), "en-US")]
)
mock_text_assistant.assert_has_calls([call().assist(text1)])
mock_text_assistant.assert_has_calls([call().assist(text2)])