Add property supported_languages
to AbstractConversationAgent
(#91588)
* Add property supported_languages to AbstractConversationAgent * Fix test * Use MATCH_ALL for openai supported languages
This commit is contained in:
parent
d7eb4c4740
commit
dc3c47986b
8 changed files with 62 additions and 2 deletions
|
@ -49,6 +49,11 @@ class AbstractConversationAgent(ABC):
|
||||||
"""Return the attribution."""
|
"""Return the attribution."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supported_languages(self) -> list[str]:
|
||||||
|
"""Return a list of supported languages."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
async def async_process(self, user_input: ConversationInput) -> ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
|
|
|
@ -13,7 +13,7 @@ from typing import IO, Any
|
||||||
from hassil.intents import Intents, ResponseType, SlotList, TextSlotList
|
from hassil.intents import Intents, ResponseType, SlotList, TextSlotList
|
||||||
from hassil.recognize import RecognizeResult, recognize_all
|
from hassil.recognize import RecognizeResult, recognize_all
|
||||||
from hassil.util import merge_dict
|
from hassil.util import merge_dict
|
||||||
from home_assistant_intents import get_intents
|
from home_assistant_intents import get_domains_and_languages, get_intents
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from homeassistant import core, setup
|
from homeassistant import core, setup
|
||||||
|
@ -86,6 +86,11 @@ class DefaultAgent(AbstractConversationAgent):
|
||||||
self._config_intents: dict[str, Any] = {}
|
self._config_intents: dict[str, Any] = {}
|
||||||
self._slot_lists: dict[str, SlotList] | None = None
|
self._slot_lists: dict[str, SlotList] | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_languages(self) -> list[str]:
|
||||||
|
"""Return a list of supported languages."""
|
||||||
|
return get_domains_and_languages()["homeassistant"]
|
||||||
|
|
||||||
async def async_initialize(self, config_intents):
|
async def async_initialize(self, config_intents):
|
||||||
"""Initialize the default agent."""
|
"""Initialize the default agent."""
|
||||||
if "intent" not in self.hass.config.components:
|
if "intent" not in self.hass.config.components:
|
||||||
|
|
|
@ -150,6 +150,14 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
|
||||||
"url": "https://www.home-assistant.io/integrations/google_assistant_sdk/",
|
"url": "https://www.home-assistant.io/integrations/google_assistant_sdk/",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_languages(self) -> list[str]:
|
||||||
|
"""Return a list of supported languages."""
|
||||||
|
language_code = self.entry.options.get(
|
||||||
|
CONF_LANGUAGE_CODE, default_language_code(self.hass)
|
||||||
|
)
|
||||||
|
return [language_code]
|
||||||
|
|
||||||
async def async_process(
|
async def async_process(
|
||||||
self, user_input: conversation.ConversationInput
|
self, user_input: conversation.ConversationInput
|
||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
|
|
|
@ -9,7 +9,7 @@ from openai import error
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import CONF_API_KEY
|
from homeassistant.const import CONF_API_KEY, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
|
from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
|
||||||
from homeassistant.helpers import intent, template
|
from homeassistant.helpers import intent, template
|
||||||
|
@ -70,6 +70,11 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
|
||||||
"""Return the attribution."""
|
"""Return the attribution."""
|
||||||
return {"name": "Powered by OpenAI", "url": "https://www.openai.com"}
|
return {"name": "Powered by OpenAI", "url": "https://www.openai.com"}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_languages(self) -> list[str]:
|
||||||
|
"""Return a list of supported languages."""
|
||||||
|
return [MATCH_ALL]
|
||||||
|
|
||||||
async def async_process(
|
async def async_process(
|
||||||
self, user_input: conversation.ConversationInput
|
self, user_input: conversation.ConversationInput
|
||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
|
|
|
@ -18,6 +18,11 @@ class MockAgent(conversation.AbstractConversationAgent):
|
||||||
self.calls = []
|
self.calls = []
|
||||||
self.response = "Test response"
|
self.response = "Test response"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_languages(self) -> list[str]:
|
||||||
|
"""Return a list of supported languages."""
|
||||||
|
return ["smurfish"]
|
||||||
|
|
||||||
async def async_process(
|
async def async_process(
|
||||||
self, user_input: conversation.ConversationInput
|
self, user_input: conversation.ConversationInput
|
||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Test for the default agent."""
|
"""Test for the default agent."""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -121,3 +122,18 @@ async def test_exposed_areas(
|
||||||
# This should be an intent match failure because the area isn't in the slot list
|
# This should be an intent match failure because the area isn't in the slot list
|
||||||
assert result.response.response_type == intent.IntentResponseType.ERROR
|
assert result.response.response_type == intent.IntentResponseType.ERROR
|
||||||
assert result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH
|
assert result.response.error_code == intent.IntentResponseErrorCode.NO_INTENT_MATCH
|
||||||
|
|
||||||
|
|
||||||
|
async def test_conversation_agent(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components,
|
||||||
|
) -> None:
|
||||||
|
"""Test DefaultAgent."""
|
||||||
|
agent = await conversation._get_agent_manager(hass).async_get_agent(
|
||||||
|
conversation.HOME_ASSISTANT_AGENT
|
||||||
|
)
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.conversation.default_agent.get_domains_and_languages",
|
||||||
|
return_value={"homeassistant": ["dwarvish", "elvish", "entish"]},
|
||||||
|
):
|
||||||
|
assert agent.supported_languages == ["dwarvish", "elvish", "entish"]
|
||||||
|
|
|
@ -7,6 +7,7 @@ from unittest.mock import call, patch
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.google_assistant_sdk import DOMAIN
|
from homeassistant.components.google_assistant_sdk import DOMAIN
|
||||||
from homeassistant.config_entries import ConfigEntryState
|
from homeassistant.config_entries import ConfigEntryState
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
@ -334,6 +335,9 @@ async def test_conversation_agent(
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
agent = await conversation._get_agent_manager(hass).async_get_agent(entry.entry_id)
|
||||||
|
assert agent.supported_languages == ["en-US"]
|
||||||
|
|
||||||
text1 = "tell me a joke"
|
text1 = "tell me a joke"
|
||||||
text2 = "tell me another one"
|
text2 = "tell me another one"
|
||||||
with patch(
|
with patch(
|
||||||
|
|
|
@ -137,3 +137,15 @@ async def test_template_error(
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
assert result.response.error_code == "unknown", result
|
assert result.response.error_code == "unknown", result
|
||||||
|
|
||||||
|
|
||||||
|
async def test_conversation_agent(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
) -> None:
|
||||||
|
"""Test OpenAIAgent."""
|
||||||
|
agent = await conversation._get_agent_manager(hass).async_get_agent(
|
||||||
|
mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
assert agent.supported_languages == ["*"]
|
||||||
|
|
Loading…
Add table
Reference in a new issue