Adjust typing of AbstractConversationAgent.supported_languages (#91648)

* Adjust typing of AbstractConversationAgent.supported_languages

* Update test
This commit is contained in:
Erik Montnemery 2023-04-19 16:53:49 +02:00 committed by GitHub
parent 5e9bbeb4ad
commit eabbe8969d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 5 deletions

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, TypedDict from typing import Any, Literal, TypedDict
from homeassistant.core import Context from homeassistant.core import Context
from homeassistant.helpers import intent from homeassistant.helpers import intent
@ -51,7 +51,7 @@ class AbstractConversationAgent(ABC):
@property @property
@abstractmethod @abstractmethod
def supported_languages(self) -> list[str]: def supported_languages(self) -> list[str] | Literal["*"]:
"""Return a list of supported languages.""" """Return a list of supported languages."""
@abstractmethod @abstractmethod

View file

@ -3,6 +3,7 @@ from __future__ import annotations
from functools import partial from functools import partial
import logging import logging
from typing import Literal
import openai import openai
from openai import error from openai import error
@ -71,9 +72,9 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
return {"name": "Powered by OpenAI", "url": "https://www.openai.com"} return {"name": "Powered by OpenAI", "url": "https://www.openai.com"}
@property @property
def supported_languages(self) -> list[str]: def supported_languages(self) -> list[str] | Literal["*"]:
"""Return a list of supported languages.""" """Return a list of supported languages."""
return [MATCH_ALL] return MATCH_ALL
async def async_process( async def async_process(
self, user_input: conversation.ConversationInput self, user_input: conversation.ConversationInput

View file

@ -148,4 +148,4 @@ async def test_conversation_agent(
agent = await conversation._get_agent_manager(hass).async_get_agent( agent = await conversation._get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id mock_config_entry.entry_id
) )
assert agent.supported_languages == ["*"] assert agent.supported_languages == "*"