diff --git a/homeassistant/components/almond/__init__.py b/homeassistant/components/almond/__init__.py index 3da49e51f21..39bd2a10335 100644 --- a/homeassistant/components/almond/__init__.py +++ b/homeassistant/components/almond/__init__.py @@ -22,7 +22,7 @@ from homeassistant.const import ( CONF_TYPE, EVENT_HOMEASSISTANT_START, ) -from homeassistant.core import Context, CoreState, HomeAssistant +from homeassistant.core import CoreState, HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.helpers import ( aiohttp_client, @@ -147,7 +147,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, almond_hass_start) - conversation.async_set_agent(hass, agent) + conversation.async_set_agent(hass, entry, agent) return True @@ -223,7 +223,7 @@ async def _configure_almond_for_ha( async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload Almond.""" - conversation.async_set_agent(hass, None) + conversation.async_unset_agent(hass, entry) return True @@ -264,40 +264,13 @@ class AlmondAgent(conversation.AbstractConversationAgent): """Return the attribution.""" return {"name": "Powered by Almond", "url": "https://almond.stanford.edu/"} - async def async_get_onboarding(self): - """Get onboard url if not onboarded.""" - if self.entry.data.get("onboarded"): - return None - - host = self.entry.data["host"] - if self.entry.data.get("is_hassio"): - host = "/core_almond" - return { - "text": ( - "Would you like to opt-in to share your anonymized commands with" - " Stanford to improve Almond's responses?" - ), - "url": f"{host}/conversation", - } - - async def async_set_onboarding(self, shown): - """Set onboarding status.""" - self.hass.config_entries.async_update_entry( - self.entry, data={**self.entry.data, "onboarded": shown} - ) - - return True - async def async_process( - self, - text: str, - context: Context, - conversation_id: str | None = None, - language: str | None = None, + self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" - response = await self.api.async_converse_text(text, conversation_id) - language = language or self.hass.config.language + response = await self.api.async_converse_text( + user_input.text, user_input.conversation_id + ) first_choice = True buffer = "" @@ -318,8 +291,8 @@ class AlmondAgent(conversation.AbstractConversationAgent): buffer += "," buffer += f" {message['title']}" - intent_response = intent.IntentResponse(language=language) + intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_speech(buffer.strip()) return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id + response=intent_response, conversation_id=user_input.conversation_id ) diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index ed06234707f..a9356ab8b7e 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -10,12 +10,13 @@ import voluptuous as vol from homeassistant import core from homeassistant.components import http, websocket_api from homeassistant.components.http.data_validator import RequestDataValidator +from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers import config_validation as cv, intent from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass -from .agent import AbstractConversationAgent, ConversationResult +from .agent import AbstractConversationAgent, ConversationInput, ConversationResult from .default_agent import DefaultAgent _LOGGER = logging.getLogger(__name__) @@ -62,11 +63,25 @@ CONFIG_SCHEMA = vol.Schema( @core.callback @bind_hass -def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent | None): +def async_set_agent( + hass: core.HomeAssistant, + config_entry: ConfigEntry, + agent: AbstractConversationAgent, +): """Set the agent to handle the conversations.""" hass.data[DATA_AGENT] = agent +@core.callback +@bind_hass +def async_unset_agent( + hass: core.HomeAssistant, + config_entry: ConfigEntry, +): + """Set the agent to handle the conversations.""" + hass.data[DATA_AGENT] = None + + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Register the process service.""" if config_intents := config.get(DOMAIN, {}).get("intents"): @@ -79,7 +94,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: agent = await _get_agent(hass) try: await agent.async_process( - text, service.context, language=service.data.get(ATTR_LANGUAGE) + ConversationInput( + text=text, + context=service.context, + conversation_id=None, + language=service.data.get(ATTR_LANGUAGE, hass.config.language), + ) ) except intent.IntentHandleError as err: _LOGGER.error("Error processing %s: %s", text, err) @@ -99,7 +119,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: websocket_api.async_register_command(hass, websocket_process) websocket_api.async_register_command(hass, websocket_prepare) websocket_api.async_register_command(hass, websocket_get_agent_info) - websocket_api.async_register_command(hass, websocket_set_onboarding) return True @@ -158,41 +177,17 @@ async def websocket_get_agent_info( connection: websocket_api.ActiveConnection, msg: dict[str, Any], ) -> None: - """Do we need onboarding.""" + """Info about the agent in use.""" agent = await _get_agent(hass) connection.send_result( msg["id"], { - "onboarding": await agent.async_get_onboarding(), "attribution": agent.attribution, }, ) -@websocket_api.websocket_command( - { - vol.Required("type"): "conversation/onboarding/set", - vol.Required("shown"): bool, - } -) -@websocket_api.async_response -async def websocket_set_onboarding( - hass: HomeAssistant, - connection: websocket_api.ActiveConnection, - msg: dict[str, Any], -) -> None: - """Set onboarding status.""" - agent = await _get_agent(hass) - - success = await agent.async_set_onboarding(msg["shown"]) - - if success: - connection.send_result(msg["id"]) - else: - connection.send_error(msg["id"], "error", "Failed to set onboarding") - - class ConversationProcessView(http.HomeAssistantView): """View to process text.""" @@ -242,5 +237,12 @@ async def async_converse( if language is None: language = hass.config.language - result = await agent.async_process(text, context, conversation_id, language) + result = await agent.async_process( + ConversationInput( + text=text, + context=context, + conversation_id=conversation_id, + language=language, + ) + ) return result diff --git a/homeassistant/components/conversation/agent.py b/homeassistant/components/conversation/agent.py index 2d01a7a1e3e..2b2c307f824 100644 --- a/homeassistant/components/conversation/agent.py +++ b/homeassistant/components/conversation/agent.py @@ -3,12 +3,22 @@ from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any +from typing import Any, TypedDict from homeassistant.core import Context from homeassistant.helpers import intent +@dataclass +class ConversationInput: + """User input to be processed.""" + + text: str + context: Context + conversation_id: str | None + language: str + + @dataclass class ConversationResult: """Result of async_process.""" @@ -24,34 +34,27 @@ class ConversationResult: } +class Attribution(TypedDict): + """Attribution for a conversation agent.""" + + name: str + url: str + + class AbstractConversationAgent(ABC): """Abstract conversation agent.""" @property - def attribution(self): + def attribution(self) -> Attribution | None: """Return the attribution.""" return None - async def async_get_onboarding(self): - """Get onboard data.""" - return None - - async def async_set_onboarding(self, shown): - """Set onboard data.""" - return True - @abstractmethod - async def async_process( - self, - text: str, - context: Context, - conversation_id: str | None = None, - language: str | None = None, - ) -> ConversationResult: + async def async_process(self, user_input: ConversationInput) -> ConversationResult: """Process a sentence.""" - async def async_reload(self, language: str | None = None): + async def async_reload(self, language: str | None = None) -> None: """Clear cached intents for a language.""" - async def async_prepare(self, language: str | None = None): + async def async_prepare(self, language: str | None = None) -> None: """Load intents for a language.""" diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index b33991c0540..7be37062d13 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -20,7 +20,7 @@ from homeassistant import core, setup from homeassistant.helpers import area_registry, entity_registry, intent, template from homeassistant.helpers.json import json_loads -from .agent import AbstractConversationAgent, ConversationResult +from .agent import AbstractConversationAgent, ConversationInput, ConversationResult from .const import DOMAIN _LOGGER = logging.getLogger(__name__) @@ -81,16 +81,11 @@ class DefaultAgent(AbstractConversationAgent): if config_intents: self._config_intents = config_intents - async def async_process( - self, - text: str, - context: core.Context, - conversation_id: str | None = None, - language: str | None = None, - ) -> ConversationResult: + async def async_process(self, user_input: ConversationInput) -> ConversationResult: """Process a sentence.""" - language = language or self.hass.config.language + language = user_input.language or self.hass.config.language lang_intents = self._lang_intents.get(language) + conversation_id = None # Not supported # Reload intents if missing or new components if lang_intents is None or ( @@ -114,9 +109,9 @@ class DefaultAgent(AbstractConversationAgent): "name": self._make_names_list(), } - result = recognize(text, lang_intents.intents, slot_lists=slot_lists) + result = recognize(user_input.text, lang_intents.intents, slot_lists=slot_lists) if result is None: - _LOGGER.debug("No intent was matched for '%s'", text) + _LOGGER.debug("No intent was matched for '%s'", user_input.text) return _make_error_result( language, intent.IntentResponseErrorCode.NO_INTENT_MATCH, @@ -133,8 +128,8 @@ class DefaultAgent(AbstractConversationAgent): entity.name: {"value": entity.value} for entity in result.entities_list }, - text, - context, + user_input.text, + user_input.context, language, ) except intent.IntentHandleError: diff --git a/homeassistant/components/google_assistant_sdk/__init__.py b/homeassistant/components/google_assistant_sdk/__init__.py index c25db0a856e..185e49435ba 100644 --- a/homeassistant/components/google_assistant_sdk/__init__.py +++ b/homeassistant/components/google_assistant_sdk/__init__.py @@ -9,7 +9,7 @@ import voluptuous as vol from homeassistant.components import conversation from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform -from homeassistant.core import Context, HomeAssistant, ServiceCall +from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.helpers import config_validation as cv, discovery, intent from homeassistant.helpers.config_entry_oauth2_flow import ( @@ -124,9 +124,9 @@ async def update_listener(hass, entry): """Handle options update.""" if entry.options.get(CONF_ENABLE_CONVERSATION_AGENT, False): agent = GoogleAssistantConversationAgent(hass, entry) - conversation.async_set_agent(hass, agent) + conversation.async_set_agent(hass, entry, agent) else: - conversation.async_set_agent(hass, None) + conversation.async_unset_agent(hass, entry) class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent): @@ -148,11 +148,7 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent): } async def async_process( - self, - text: str, - context: Context, - conversation_id: str | None = None, - language: str | None = None, + self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" if self.session: @@ -170,12 +166,11 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent): ) self.assistant = TextAssistant(credentials, language_code) - resp = self.assistant.assist(text) + resp = self.assistant.assist(user_input.text) text_response = resp[0] - language = language or self.hass.config.language - intent_response = intent.IntentResponse(language=language) + intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_speech(text_response) return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id + response=intent_response, conversation_id=user_input.conversation_id ) diff --git a/tests/components/conversation/__init__.py b/tests/components/conversation/__init__.py index 998885ea218..8c5371f8cbe 100644 --- a/tests/components/conversation/__init__.py +++ b/tests/components/conversation/__init__.py @@ -2,7 +2,6 @@ from __future__ import annotations from homeassistant.components import conversation -from homeassistant.core import Context from homeassistant.helpers import intent @@ -15,16 +14,12 @@ class MockAgent(conversation.AbstractConversationAgent): self.response = "Test response" async def async_process( - self, - text: str, - context: Context, - conversation_id: str | None = None, - language: str | None = None, + self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process some text.""" - self.calls.append((text, context, conversation_id, language)) - response = intent.IntentResponse(language=language) + self.calls.append(user_input) + response = intent.IntentResponse(language=user_input.language) response.async_set_speech(self.response) return conversation.ConversationResult( - response=response, conversation_id=conversation_id + response=response, conversation_id=user_input.conversation_id ) diff --git a/tests/components/conversation/conftest.py b/tests/components/conversation/conftest.py index 5dbd52dc841..35f9937e5a0 100644 --- a/tests/components/conversation/conftest.py +++ b/tests/components/conversation/conftest.py @@ -11,5 +11,5 @@ from . import MockAgent def mock_agent(hass): """Mock agent.""" agent = MockAgent() - conversation.async_set_agent(hass, agent) + conversation.async_set_agent(hass, None, agent) return agent diff --git a/tests/components/conversation/test_init.py b/tests/components/conversation/test_init.py index f48ad4c0dfd..e79cd69475c 100644 --- a/tests/components/conversation/test_init.py +++ b/tests/components/conversation/test_init.py @@ -1,6 +1,6 @@ """The tests for the Conversation component.""" from http import HTTPStatus -from unittest.mock import ANY, patch +from unittest.mock import patch import pytest @@ -295,10 +295,10 @@ async def test_custom_agent(hass, hass_client, hass_admin_user, mock_agent): } assert len(mock_agent.calls) == 1 - assert mock_agent.calls[0][0] == "Test Text" - assert mock_agent.calls[0][1].user_id == hass_admin_user.id - assert mock_agent.calls[0][2] == "test-conv-id" - assert mock_agent.calls[0][3] == "test-language" + assert mock_agent.calls[0].text == "Test Text" + assert mock_agent.calls[0].context.user_id == hass_admin_user.id + assert mock_agent.calls[0].conversation_id == "test-conv-id" + assert mock_agent.calls[0].language == "test-language" @pytest.mark.parametrize( @@ -349,7 +349,7 @@ async def test_ws_api(hass, hass_ws_client, payload): "language": payload.get("language", hass.config.language), "data": {"code": "no_intent_match"}, }, - "conversation_id": payload.get("conversation_id") or ANY, + "conversation_id": None, } @@ -560,5 +560,12 @@ async def test_non_default_response(hass, init_components): agent = await conversation._get_agent(hass) assert isinstance(agent, conversation.DefaultAgent) - result = await agent.async_process("open the front door", Context()) + result = await agent.async_process( + conversation.ConversationInput( + text="open the front door", + context=Context(), + conversation_id=None, + language=hass.config.language, + ) + ) assert result.response.speech["plain"]["speech"] == "Opened front door"