Conversation cleanup (#86592)

* Require config entry when setting Conversation agent, add new unset agent method

* Remove onboarding from conversation agent

* Type attribution

* Wrap async_process params in ConversationInput object
This commit is contained in:
Paulus Schoutsen 2023-01-24 22:47:49 -05:00 committed by GitHub
parent 5c6656dcac
commit 6c8efe3a3b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 97 additions and 127 deletions

View file

@ -22,7 +22,7 @@ from homeassistant.const import (
CONF_TYPE,
EVENT_HOMEASSISTANT_START,
)
from homeassistant.core import Context, CoreState, HomeAssistant
from homeassistant.core import CoreState, HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import (
aiohttp_client,
@ -147,7 +147,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, almond_hass_start)
conversation.async_set_agent(hass, agent)
conversation.async_set_agent(hass, entry, agent)
return True
@ -223,7 +223,7 @@ async def _configure_almond_for_ha(
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload Almond."""
conversation.async_set_agent(hass, None)
conversation.async_unset_agent(hass, entry)
return True
@ -264,40 +264,13 @@ class AlmondAgent(conversation.AbstractConversationAgent):
"""Return the attribution."""
return {"name": "Powered by Almond", "url": "https://almond.stanford.edu/"}
async def async_get_onboarding(self):
"""Get onboard url if not onboarded."""
if self.entry.data.get("onboarded"):
return None
host = self.entry.data["host"]
if self.entry.data.get("is_hassio"):
host = "/core_almond"
return {
"text": (
"Would you like to opt-in to share your anonymized commands with"
" Stanford to improve Almond's responses?"
),
"url": f"{host}/conversation",
}
async def async_set_onboarding(self, shown):
"""Set onboarding status."""
self.hass.config_entries.async_update_entry(
self.entry, data={**self.entry.data, "onboarded": shown}
)
return True
async def async_process(
self,
text: str,
context: Context,
conversation_id: str | None = None,
language: str | None = None,
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
response = await self.api.async_converse_text(text, conversation_id)
language = language or self.hass.config.language
response = await self.api.async_converse_text(
user_input.text, user_input.conversation_id
)
first_choice = True
buffer = ""
@ -318,8 +291,8 @@ class AlmondAgent(conversation.AbstractConversationAgent):
buffer += ","
buffer += f" {message['title']}"
intent_response = intent.IntentResponse(language=language)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(buffer.strip())
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
response=intent_response, conversation_id=user_input.conversation_id
)

View file

@ -10,12 +10,13 @@ import voluptuous as vol
from homeassistant import core
from homeassistant.components import http, websocket_api
from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv, intent
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
from .agent import AbstractConversationAgent, ConversationResult
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
from .default_agent import DefaultAgent
_LOGGER = logging.getLogger(__name__)
@ -62,11 +63,25 @@ CONFIG_SCHEMA = vol.Schema(
@core.callback
@bind_hass
def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent | None):
def async_set_agent(
hass: core.HomeAssistant,
config_entry: ConfigEntry,
agent: AbstractConversationAgent,
):
"""Set the agent to handle the conversations."""
hass.data[DATA_AGENT] = agent
@core.callback
@bind_hass
def async_unset_agent(
hass: core.HomeAssistant,
config_entry: ConfigEntry,
):
"""Set the agent to handle the conversations."""
hass.data[DATA_AGENT] = None
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service."""
if config_intents := config.get(DOMAIN, {}).get("intents"):
@ -79,7 +94,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
agent = await _get_agent(hass)
try:
await agent.async_process(
text, service.context, language=service.data.get(ATTR_LANGUAGE)
ConversationInput(
text=text,
context=service.context,
conversation_id=None,
language=service.data.get(ATTR_LANGUAGE, hass.config.language),
)
)
except intent.IntentHandleError as err:
_LOGGER.error("Error processing %s: %s", text, err)
@ -99,7 +119,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
websocket_api.async_register_command(hass, websocket_process)
websocket_api.async_register_command(hass, websocket_prepare)
websocket_api.async_register_command(hass, websocket_get_agent_info)
websocket_api.async_register_command(hass, websocket_set_onboarding)
return True
@ -158,41 +177,17 @@ async def websocket_get_agent_info(
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Do we need onboarding."""
"""Info about the agent in use."""
agent = await _get_agent(hass)
connection.send_result(
msg["id"],
{
"onboarding": await agent.async_get_onboarding(),
"attribution": agent.attribution,
},
)
@websocket_api.websocket_command(
{
vol.Required("type"): "conversation/onboarding/set",
vol.Required("shown"): bool,
}
)
@websocket_api.async_response
async def websocket_set_onboarding(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Set onboarding status."""
agent = await _get_agent(hass)
success = await agent.async_set_onboarding(msg["shown"])
if success:
connection.send_result(msg["id"])
else:
connection.send_error(msg["id"], "error", "Failed to set onboarding")
class ConversationProcessView(http.HomeAssistantView):
"""View to process text."""
@ -242,5 +237,12 @@ async def async_converse(
if language is None:
language = hass.config.language
result = await agent.async_process(text, context, conversation_id, language)
result = await agent.async_process(
ConversationInput(
text=text,
context=context,
conversation_id=conversation_id,
language=language,
)
)
return result

View file

@ -3,12 +3,22 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
from typing import Any, TypedDict
from homeassistant.core import Context
from homeassistant.helpers import intent
@dataclass
class ConversationInput:
"""User input to be processed."""
text: str
context: Context
conversation_id: str | None
language: str
@dataclass
class ConversationResult:
"""Result of async_process."""
@ -24,34 +34,27 @@ class ConversationResult:
}
class Attribution(TypedDict):
"""Attribution for a conversation agent."""
name: str
url: str
class AbstractConversationAgent(ABC):
"""Abstract conversation agent."""
@property
def attribution(self):
def attribution(self) -> Attribution | None:
"""Return the attribution."""
return None
async def async_get_onboarding(self):
"""Get onboard data."""
return None
async def async_set_onboarding(self, shown):
"""Set onboard data."""
return True
@abstractmethod
async def async_process(
self,
text: str,
context: Context,
conversation_id: str | None = None,
language: str | None = None,
) -> ConversationResult:
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
"""Process a sentence."""
async def async_reload(self, language: str | None = None):
async def async_reload(self, language: str | None = None) -> None:
"""Clear cached intents for a language."""
async def async_prepare(self, language: str | None = None):
async def async_prepare(self, language: str | None = None) -> None:
"""Load intents for a language."""

View file

@ -20,7 +20,7 @@ from homeassistant import core, setup
from homeassistant.helpers import area_registry, entity_registry, intent, template
from homeassistant.helpers.json import json_loads
from .agent import AbstractConversationAgent, ConversationResult
from .agent import AbstractConversationAgent, ConversationInput, ConversationResult
from .const import DOMAIN
_LOGGER = logging.getLogger(__name__)
@ -81,16 +81,11 @@ class DefaultAgent(AbstractConversationAgent):
if config_intents:
self._config_intents = config_intents
async def async_process(
self,
text: str,
context: core.Context,
conversation_id: str | None = None,
language: str | None = None,
) -> ConversationResult:
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
"""Process a sentence."""
language = language or self.hass.config.language
language = user_input.language or self.hass.config.language
lang_intents = self._lang_intents.get(language)
conversation_id = None # Not supported
# Reload intents if missing or new components
if lang_intents is None or (
@ -114,9 +109,9 @@ class DefaultAgent(AbstractConversationAgent):
"name": self._make_names_list(),
}
result = recognize(text, lang_intents.intents, slot_lists=slot_lists)
result = recognize(user_input.text, lang_intents.intents, slot_lists=slot_lists)
if result is None:
_LOGGER.debug("No intent was matched for '%s'", text)
_LOGGER.debug("No intent was matched for '%s'", user_input.text)
return _make_error_result(
language,
intent.IntentResponseErrorCode.NO_INTENT_MATCH,
@ -133,8 +128,8 @@ class DefaultAgent(AbstractConversationAgent):
entity.name: {"value": entity.value}
for entity in result.entities_list
},
text,
context,
user_input.text,
user_input.context,
language,
)
except intent.IntentHandleError:

View file

@ -9,7 +9,7 @@ import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform
from homeassistant.core import Context, HomeAssistant, ServiceCall
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import config_validation as cv, discovery, intent
from homeassistant.helpers.config_entry_oauth2_flow import (
@ -124,9 +124,9 @@ async def update_listener(hass, entry):
"""Handle options update."""
if entry.options.get(CONF_ENABLE_CONVERSATION_AGENT, False):
agent = GoogleAssistantConversationAgent(hass, entry)
conversation.async_set_agent(hass, agent)
conversation.async_set_agent(hass, entry, agent)
else:
conversation.async_set_agent(hass, None)
conversation.async_unset_agent(hass, entry)
class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
@ -148,11 +148,7 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
}
async def async_process(
self,
text: str,
context: Context,
conversation_id: str | None = None,
language: str | None = None,
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
if self.session:
@ -170,12 +166,11 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
)
self.assistant = TextAssistant(credentials, language_code)
resp = self.assistant.assist(text)
resp = self.assistant.assist(user_input.text)
text_response = resp[0]
language = language or self.hass.config.language
intent_response = intent.IntentResponse(language=language)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(text_response)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
response=intent_response, conversation_id=user_input.conversation_id
)

View file

@ -2,7 +2,6 @@
from __future__ import annotations
from homeassistant.components import conversation
from homeassistant.core import Context
from homeassistant.helpers import intent
@ -15,16 +14,12 @@ class MockAgent(conversation.AbstractConversationAgent):
self.response = "Test response"
async def async_process(
self,
text: str,
context: Context,
conversation_id: str | None = None,
language: str | None = None,
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process some text."""
self.calls.append((text, context, conversation_id, language))
response = intent.IntentResponse(language=language)
self.calls.append(user_input)
response = intent.IntentResponse(language=user_input.language)
response.async_set_speech(self.response)
return conversation.ConversationResult(
response=response, conversation_id=conversation_id
response=response, conversation_id=user_input.conversation_id
)

View file

@ -11,5 +11,5 @@ from . import MockAgent
def mock_agent(hass):
"""Mock agent."""
agent = MockAgent()
conversation.async_set_agent(hass, agent)
conversation.async_set_agent(hass, None, agent)
return agent

View file

@ -1,6 +1,6 @@
"""The tests for the Conversation component."""
from http import HTTPStatus
from unittest.mock import ANY, patch
from unittest.mock import patch
import pytest
@ -295,10 +295,10 @@ async def test_custom_agent(hass, hass_client, hass_admin_user, mock_agent):
}
assert len(mock_agent.calls) == 1
assert mock_agent.calls[0][0] == "Test Text"
assert mock_agent.calls[0][1].user_id == hass_admin_user.id
assert mock_agent.calls[0][2] == "test-conv-id"
assert mock_agent.calls[0][3] == "test-language"
assert mock_agent.calls[0].text == "Test Text"
assert mock_agent.calls[0].context.user_id == hass_admin_user.id
assert mock_agent.calls[0].conversation_id == "test-conv-id"
assert mock_agent.calls[0].language == "test-language"
@pytest.mark.parametrize(
@ -349,7 +349,7 @@ async def test_ws_api(hass, hass_ws_client, payload):
"language": payload.get("language", hass.config.language),
"data": {"code": "no_intent_match"},
},
"conversation_id": payload.get("conversation_id") or ANY,
"conversation_id": None,
}
@ -560,5 +560,12 @@ async def test_non_default_response(hass, init_components):
agent = await conversation._get_agent(hass)
assert isinstance(agent, conversation.DefaultAgent)
result = await agent.async_process("open the front door", Context())
result = await agent.async_process(
conversation.ConversationInput(
text="open the front door",
context=Context(),
conversation_id=None,
language=hass.config.language,
)
)
assert result.response.speech["plain"]["speech"] == "Opened front door"