LLM Assist API to ignore intents if not needed for exposed entities or calling device (#118283)

* LLM Assist API to ignore timer intents if device doesn't support it

* Refactor to use API instances

* Extract ToolContext class

* Limit exposed intents based on exposed entities
This commit is contained in:
Paulus Schoutsen 2024-05-28 21:29:18 -04:00 committed by GitHub
parent e0264c8604
commit 615a1eda51
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 302 additions and 181 deletions

View file

@ -149,13 +149,22 @@ class GoogleGenerativeAIConversationEntity(
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Process a sentence.""" """Process a sentence."""
intent_response = intent.IntentResponse(language=user_input.language) intent_response = intent.IntentResponse(language=user_input.language)
llm_api: llm.API | None = None llm_api: llm.APIInstance | None = None
tools: list[dict[str, Any]] | None = None tools: list[dict[str, Any]] | None = None
if self.entry.options.get(CONF_LLM_HASS_API): if self.entry.options.get(CONF_LLM_HASS_API):
try: try:
llm_api = llm.async_get_api( llm_api = await llm.async_get_api(
self.hass, self.entry.options[CONF_LLM_HASS_API] self.hass,
self.entry.options[CONF_LLM_HASS_API],
llm.ToolContext(
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
),
) )
except HomeAssistantError as err: except HomeAssistantError as err:
LOGGER.error("Error getting LLM API: %s", err) LOGGER.error("Error getting LLM API: %s", err)
@ -166,7 +175,7 @@ class GoogleGenerativeAIConversationEntity(
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id response=intent_response, conversation_id=user_input.conversation_id
) )
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()] tools = [_format_tool(tool) for tool in llm_api.tools]
model = genai.GenerativeModel( model = genai.GenerativeModel(
model_name=self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), model_name=self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
@ -206,19 +215,7 @@ class GoogleGenerativeAIConversationEntity(
try: try:
if llm_api: if llm_api:
empty_tool_input = llm.ToolInput( api_prompt = llm_api.api_prompt
tool_name="",
tool_args={},
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
)
api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
else: else:
api_prompt = llm.async_render_no_api_prompt(self.hass) api_prompt = llm.async_render_no_api_prompt(self.hass)
@ -309,12 +306,6 @@ class GoogleGenerativeAIConversationEntity(
tool_input = llm.ToolInput( tool_input = llm.ToolInput(
tool_name=tool_call.name, tool_name=tool_call.name,
tool_args=dict(tool_call.args), tool_args=dict(tool_call.args),
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
) )
LOGGER.debug( LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args

View file

@ -50,6 +50,7 @@ from .timers import (
TimerManager, TimerManager,
TimerStatusIntentHandler, TimerStatusIntentHandler,
UnpauseTimerIntentHandler, UnpauseTimerIntentHandler,
async_device_supports_timers,
async_register_timer_handler, async_register_timer_handler,
) )
@ -59,6 +60,7 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
__all__ = [ __all__ = [
"async_register_timer_handler", "async_register_timer_handler",
"async_device_supports_timers",
"TimerInfo", "TimerInfo",
"TimerEventType", "TimerEventType",
"DOMAIN", "DOMAIN",

View file

@ -415,6 +415,15 @@ class TimerManager:
return device_id in self.handlers return device_id in self.handlers
@callback
def async_device_supports_timers(hass: HomeAssistant, device_id: str) -> bool:
"""Return True if device has been registered to handle timer events."""
timer_manager: TimerManager | None = hass.data.get(TIMER_DATA)
if timer_manager is None:
return False
return timer_manager.is_timer_device(device_id)
@callback @callback
def async_register_timer_handler( def async_register_timer_handler(
hass: HomeAssistant, device_id: str, handler: TimerHandler hass: HomeAssistant, device_id: str, handler: TimerHandler

View file

@ -99,12 +99,23 @@ class OpenAIConversationEntity(
"""Process a sentence.""" """Process a sentence."""
options = self.entry.options options = self.entry.options
intent_response = intent.IntentResponse(language=user_input.language) intent_response = intent.IntentResponse(language=user_input.language)
llm_api: llm.API | None = None llm_api: llm.APIInstance | None = None
tools: list[dict[str, Any]] | None = None tools: list[dict[str, Any]] | None = None
if options.get(CONF_LLM_HASS_API): if options.get(CONF_LLM_HASS_API):
try: try:
llm_api = llm.async_get_api(self.hass, options[CONF_LLM_HASS_API]) llm_api = await llm.async_get_api(
self.hass,
options[CONF_LLM_HASS_API],
llm.ToolContext(
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
),
)
except HomeAssistantError as err: except HomeAssistantError as err:
LOGGER.error("Error getting LLM API: %s", err) LOGGER.error("Error getting LLM API: %s", err)
intent_response.async_set_error( intent_response.async_set_error(
@ -114,7 +125,7 @@ class OpenAIConversationEntity(
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id response=intent_response, conversation_id=user_input.conversation_id
) )
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()] tools = [_format_tool(tool) for tool in llm_api.tools]
if user_input.conversation_id in self.history: if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id conversation_id = user_input.conversation_id
@ -123,19 +134,7 @@ class OpenAIConversationEntity(
conversation_id = ulid.ulid_now() conversation_id = ulid.ulid_now()
try: try:
if llm_api: if llm_api:
empty_tool_input = llm.ToolInput( api_prompt = llm_api.api_prompt
tool_name="",
tool_args={},
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
)
api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
else: else:
api_prompt = llm.async_render_no_api_prompt(self.hass) api_prompt = llm.async_render_no_api_prompt(self.hass)
@ -182,7 +181,7 @@ class OpenAIConversationEntity(
result = await client.chat.completions.create( result = await client.chat.completions.create(
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
messages=messages, messages=messages,
tools=tools, tools=tools or None,
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P), top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
@ -210,12 +209,6 @@ class OpenAIConversationEntity(
tool_input = llm.ToolInput( tool_input = llm.ToolInput(
tool_name=tool_call.function.name, tool_name=tool_call.function.name,
tool_args=json.loads(tool_call.function.arguments), tool_args=json.loads(tool_call.function.arguments),
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
) )
LOGGER.debug( LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args

View file

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, replace from dataclasses import asdict, dataclass
from enum import Enum from enum import Enum
from typing import Any from typing import Any
@ -15,6 +15,7 @@ from homeassistant.components.conversation.trace import (
async_conversation_trace_append, async_conversation_trace_append,
) )
from homeassistant.components.homeassistant.exposed_entities import async_should_expose from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.components.intent import async_device_supports_timers
from homeassistant.components.weather.intent import INTENT_GET_WEATHER from homeassistant.components.weather.intent import INTENT_GET_WEATHER
from homeassistant.core import Context, HomeAssistant, callback from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@ -68,15 +69,16 @@ def async_register_api(hass: HomeAssistant, api: API) -> None:
apis[api.id] = api apis[api.id] = api
@callback async def async_get_api(
def async_get_api(hass: HomeAssistant, api_id: str) -> API: hass: HomeAssistant, api_id: str, tool_context: ToolContext
) -> APIInstance:
"""Get an API.""" """Get an API."""
apis = _async_get_apis(hass) apis = _async_get_apis(hass)
if api_id not in apis: if api_id not in apis:
raise HomeAssistantError(f"API {api_id} not found") raise HomeAssistantError(f"API {api_id} not found")
return apis[api_id] return await apis[api_id].async_get_api_instance(tool_context)
@callback @callback
@ -86,11 +88,9 @@ def async_get_apis(hass: HomeAssistant) -> list[API]:
@dataclass(slots=True) @dataclass(slots=True)
class ToolInput(ABC): class ToolContext:
"""Tool input to be processed.""" """Tool input to be processed."""
tool_name: str
tool_args: dict[str, Any]
platform: str platform: str
context: Context | None context: Context | None
user_prompt: str | None user_prompt: str | None
@ -99,6 +99,14 @@ class ToolInput(ABC):
device_id: str | None device_id: str | None
@dataclass(slots=True)
class ToolInput:
"""Tool input to be processed."""
tool_name: str
tool_args: dict[str, Any]
class Tool: class Tool:
"""LLM Tool base class.""" """LLM Tool base class."""
@ -108,7 +116,7 @@ class Tool:
@abstractmethod @abstractmethod
async def async_call( async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput self, hass: HomeAssistant, tool_input: ToolInput, tool_context: ToolContext
) -> JsonObjectType: ) -> JsonObjectType:
"""Call the tool.""" """Call the tool."""
raise NotImplementedError raise NotImplementedError
@ -118,6 +126,30 @@ class Tool:
return f"<{self.__class__.__name__} - {self.name}>" return f"<{self.__class__.__name__} - {self.name}>"
@dataclass
class APIInstance:
"""Instance of an API to be used by an LLM."""
api: API
api_prompt: str
tool_context: ToolContext
tools: list[Tool]
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
"""Call a LLM tool, validate args and return the response."""
async_conversation_trace_append(
ConversationTraceEventType.LLM_TOOL_CALL, asdict(tool_input)
)
for tool in self.tools:
if tool.name == tool_input.tool_name:
break
else:
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
return await tool.async_call(self.api.hass, tool_input, self.tool_context)
@dataclass(slots=True, kw_only=True) @dataclass(slots=True, kw_only=True)
class API(ABC): class API(ABC):
"""An API to expose to LLMs.""" """An API to expose to LLMs."""
@ -127,38 +159,10 @@ class API(ABC):
name: str name: str
@abstractmethod @abstractmethod
async def async_get_api_prompt(self, tool_input: ToolInput) -> str: async def async_get_api_instance(self, tool_context: ToolContext) -> APIInstance:
"""Return the prompt for the API.""" """Return the instance of the API."""
raise NotImplementedError raise NotImplementedError
@abstractmethod
@callback
def async_get_tools(self) -> list[Tool]:
"""Return a list of tools."""
raise NotImplementedError
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
"""Call a LLM tool, validate args and return the response."""
async_conversation_trace_append(
ConversationTraceEventType.LLM_TOOL_CALL, asdict(tool_input)
)
for tool in self.async_get_tools():
if tool.name == tool_input.tool_name:
break
else:
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
return await tool.async_call(
self.hass,
replace(
tool_input,
tool_name=tool.name,
tool_args=tool.parameters(tool_input.tool_args),
context=tool_input.context or Context(),
),
)
class IntentTool(Tool): class IntentTool(Tool):
"""LLM Tool representing an Intent.""" """LLM Tool representing an Intent."""
@ -176,21 +180,20 @@ class IntentTool(Tool):
self.parameters = vol.Schema(slot_schema) self.parameters = vol.Schema(slot_schema)
async def async_call( async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput self, hass: HomeAssistant, tool_input: ToolInput, tool_context: ToolContext
) -> JsonObjectType: ) -> JsonObjectType:
"""Handle the intent.""" """Handle the intent."""
slots = {key: {"value": val} for key, val in tool_input.tool_args.items()} slots = {key: {"value": val} for key, val in tool_input.tool_args.items()}
intent_response = await intent.async_handle( intent_response = await intent.async_handle(
hass, hass=hass,
tool_input.platform, platform=tool_context.platform,
self.name, intent_type=self.name,
slots, slots=slots,
tool_input.user_prompt, text_input=tool_context.user_prompt,
tool_input.context, context=tool_context.context,
tool_input.language, language=tool_context.language,
tool_input.assistant, assistant=tool_context.assistant,
tool_input.device_id, device_id=tool_context.device_id,
) )
return intent_response.as_dict() return intent_response.as_dict()
@ -213,15 +216,26 @@ class AssistAPI(API):
name="Assist", name="Assist",
) )
async def async_get_api_prompt(self, tool_input: ToolInput) -> str: async def async_get_api_instance(self, tool_context: ToolContext) -> APIInstance:
"""Return the prompt for the API.""" """Return the instance of the API."""
if tool_input.assistant: if tool_context.assistant:
exposed_entities: dict | None = _get_exposed_entities( exposed_entities: dict | None = _get_exposed_entities(
self.hass, tool_input.assistant self.hass, tool_context.assistant
) )
else: else:
exposed_entities = None exposed_entities = None
return APIInstance(
api=self,
api_prompt=await self._async_get_api_prompt(tool_context, exposed_entities),
tool_context=tool_context,
tools=self._async_get_tools(tool_context, exposed_entities),
)
async def _async_get_api_prompt(
self, tool_context: ToolContext, exposed_entities: dict | None
) -> str:
"""Return the prompt for the API."""
if not exposed_entities: if not exposed_entities:
return ( return (
"Only if the user wants to control a device, tell them to expose entities " "Only if the user wants to control a device, tell them to expose entities "
@ -236,9 +250,9 @@ class AssistAPI(API):
] ]
area: ar.AreaEntry | None = None area: ar.AreaEntry | None = None
floor: fr.FloorEntry | None = None floor: fr.FloorEntry | None = None
if tool_input.device_id: if tool_context.device_id:
device_reg = dr.async_get(self.hass) device_reg = dr.async_get(self.hass)
device = device_reg.async_get(tool_input.device_id) device = device_reg.async_get(tool_context.device_id)
if device: if device:
area_reg = ar.async_get(self.hass) area_reg = ar.async_get(self.hass)
@ -259,11 +273,16 @@ class AssistAPI(API):
"don't know in what area this conversation is happening." "don't know in what area this conversation is happening."
) )
if tool_input.context and tool_input.context.user_id: if tool_context.context and tool_context.context.user_id:
user = await self.hass.auth.async_get_user(tool_input.context.user_id) user = await self.hass.auth.async_get_user(tool_context.context.user_id)
if user: if user:
prompt.append(f"The user name is {user.name}.") prompt.append(f"The user name is {user.name}.")
if not tool_context.device_id or not async_device_supports_timers(
self.hass, tool_context.device_id
):
prompt.append("This device does not support timers.")
if exposed_entities: if exposed_entities:
prompt.append( prompt.append(
"An overview of the areas and the devices in this smart home:" "An overview of the areas and the devices in this smart home:"
@ -273,14 +292,44 @@ class AssistAPI(API):
return "\n".join(prompt) return "\n".join(prompt)
@callback @callback
def async_get_tools(self) -> list[Tool]: def _async_get_tools(
self, tool_context: ToolContext, exposed_entities: dict | None
) -> list[Tool]:
"""Return a list of LLM tools.""" """Return a list of LLM tools."""
return [ ignore_intents = self.IGNORE_INTENTS
IntentTool(intent_handler) if not tool_context.device_id or not async_device_supports_timers(
self.hass, tool_context.device_id
):
ignore_intents = ignore_intents | {
intent.INTENT_START_TIMER,
intent.INTENT_CANCEL_TIMER,
intent.INTENT_INCREASE_TIMER,
intent.INTENT_DECREASE_TIMER,
intent.INTENT_PAUSE_TIMER,
intent.INTENT_UNPAUSE_TIMER,
intent.INTENT_TIMER_STATUS,
}
intent_handlers = [
intent_handler
for intent_handler in intent.async_get(self.hass) for intent_handler in intent.async_get(self.hass)
if intent_handler.intent_type not in self.IGNORE_INTENTS if intent_handler.intent_type not in ignore_intents
] ]
exposed_domains: set[str] | None = None
if exposed_entities is not None:
exposed_domains = {
entity_id.split(".")[0] for entity_id in exposed_entities
}
intent_handlers = [
intent_handler
for intent_handler in intent_handlers
if intent_handler.platforms is None
or intent_handler.platforms & exposed_domains
]
return [IntentTool(intent_handler) for intent_handler in intent_handlers]
def _get_exposed_entities( def _get_exposed_entities(
hass: HomeAssistant, assistant: str hass: HomeAssistant, assistant: str

View file

@ -61,11 +61,11 @@ async def test_default_prompt(
with ( with (
patch("google.generativeai.GenerativeModel") as mock_model, patch("google.generativeai.GenerativeModel") as mock_model,
patch( patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools", "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools",
return_value=[], return_value=[],
) as mock_get_tools, ) as mock_get_tools,
patch( patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_api_prompt", "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_api_prompt",
return_value="<api_prompt>", return_value="<api_prompt>",
), ),
patch( patch(
@ -148,7 +148,7 @@ async def test_chat_history(
@patch( @patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools" "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
) )
async def test_function_call( async def test_function_call(
mock_get_tools, mock_get_tools,
@ -182,7 +182,7 @@ async def test_function_call(
mock_part.function_call.name = "test_tool" mock_part.function_call.name = "test_tool"
mock_part.function_call.args = {"param1": ["test_value"]} mock_part.function_call.args = {"param1": ["test_value"]}
def tool_call(hass, tool_input): def tool_call(hass, tool_input, tool_context):
mock_part.function_call = None mock_part.function_call = None
mock_part.text = "Hi there!" mock_part.text = "Hi there!"
return {"result": "Test response"} return {"result": "Test response"}
@ -221,6 +221,8 @@ async def test_function_call(
llm.ToolInput( llm.ToolInput(
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": ["test_value"]}, tool_args={"param1": ["test_value"]},
),
llm.ToolContext(
platform="google_generative_ai_conversation", platform="google_generative_ai_conversation",
context=context, context=context,
user_prompt="Please call the test function", user_prompt="Please call the test function",
@ -246,7 +248,7 @@ async def test_function_call(
@patch( @patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools" "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
) )
async def test_function_exception( async def test_function_exception(
mock_get_tools, mock_get_tools,
@ -280,7 +282,7 @@ async def test_function_exception(
mock_part.function_call.name = "test_tool" mock_part.function_call.name = "test_tool"
mock_part.function_call.args = {"param1": 1} mock_part.function_call.args = {"param1": 1}
def tool_call(hass, tool_input): def tool_call(hass, tool_input, tool_context):
mock_part.function_call = None mock_part.function_call = None
mock_part.text = "Hi there!" mock_part.text = "Hi there!"
raise HomeAssistantError("Test tool exception") raise HomeAssistantError("Test tool exception")
@ -319,6 +321,8 @@ async def test_function_exception(
llm.ToolInput( llm.ToolInput(
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": 1}, tool_args={"param1": 1},
),
llm.ToolContext(
platform="google_generative_ai_conversation", platform="google_generative_ai_conversation",
context=context, context=context,
user_prompt="Please call the test function", user_prompt="Please call the test function",

View file

@ -86,7 +86,7 @@ async def test_conversation_agent(
@patch( @patch(
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools" "homeassistant.components.openai_conversation.conversation.llm.AssistAPI._async_get_tools"
) )
async def test_function_call( async def test_function_call(
mock_get_tools, mock_get_tools,
@ -192,6 +192,8 @@ async def test_function_call(
llm.ToolInput( llm.ToolInput(
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "test_value"}, tool_args={"param1": "test_value"},
),
llm.ToolContext(
platform="openai_conversation", platform="openai_conversation",
context=context, context=context,
user_prompt="Please call the test function", user_prompt="Please call the test function",
@ -217,7 +219,7 @@ async def test_function_call(
@patch( @patch(
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools" "homeassistant.components.openai_conversation.conversation.llm.AssistAPI._async_get_tools"
) )
async def test_function_exception( async def test_function_exception(
mock_get_tools, mock_get_tools,
@ -323,6 +325,8 @@ async def test_function_exception(
llm.ToolInput( llm.ToolInput(
tool_name="test_tool", tool_name="test_tool",
tool_args={"param1": "test_value"}, tool_args={"param1": "test_value"},
),
llm.ToolContext(
platform="openai_conversation", platform="openai_conversation",
context=context, context=context,
user_prompt="Please call the test function", user_prompt="Please call the test function",

View file

@ -5,6 +5,7 @@ from unittest.mock import Mock, patch
import pytest import pytest
import voluptuous as vol import voluptuous as vol
from homeassistant.components.intent import async_register_timer_handler
from homeassistant.core import Context, HomeAssistant, State from homeassistant.core import Context, HomeAssistant, State
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import ( from homeassistant.helpers import (
@ -22,53 +23,84 @@ from homeassistant.util import yaml
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
async def test_get_api_no_existing(hass: HomeAssistant) -> None: @pytest.fixture
def tool_input_context() -> llm.ToolContext:
"""Return tool input context."""
return llm.ToolContext(
platform="",
context=None,
user_prompt=None,
language=None,
assistant=None,
device_id=None,
)
async def test_get_api_no_existing(
hass: HomeAssistant, tool_input_context: llm.ToolContext
) -> None:
"""Test getting an llm api where no config exists.""" """Test getting an llm api where no config exists."""
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
llm.async_get_api(hass, "non-existing") await llm.async_get_api(hass, "non-existing", tool_input_context)
async def test_register_api(hass: HomeAssistant) -> None: async def test_register_api(
hass: HomeAssistant, tool_input_context: llm.ToolContext
) -> None:
"""Test registering an llm api.""" """Test registering an llm api."""
class MyAPI(llm.API): class MyAPI(llm.API):
async def async_get_api_prompt(self, tool_input: llm.ToolInput) -> str: async def async_get_api_instance(
"""Return a prompt for the tool.""" self, tool_input: llm.ToolInput
return "" ) -> llm.APIInstance:
def async_get_tools(self) -> list[llm.Tool]:
"""Return a list of tools.""" """Return a list of tools."""
return [] return llm.APIInstance(self, "", [], tool_input_context)
api = MyAPI(hass=hass, id="test", name="Test") api = MyAPI(hass=hass, id="test", name="Test")
llm.async_register_api(hass, api) llm.async_register_api(hass, api)
assert llm.async_get_api(hass, "test") is api instance = await llm.async_get_api(hass, "test", tool_input_context)
assert instance.api is api
assert api in llm.async_get_apis(hass) assert api in llm.async_get_apis(hass)
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
llm.async_register_api(hass, api) llm.async_register_api(hass, api)
async def test_call_tool_no_existing(hass: HomeAssistant) -> None: async def test_call_tool_no_existing(
hass: HomeAssistant, tool_input_context: llm.ToolContext
) -> None:
"""Test calling an llm tool where no config exists.""" """Test calling an llm tool where no config exists."""
instance = await llm.async_get_api(hass, "assist", tool_input_context)
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
await llm.async_get_api(hass, "intent").async_call_tool( await instance.async_call_tool(
llm.ToolInput( llm.ToolInput("test_tool", {}),
"test_tool",
{},
"test_platform",
None,
None,
None,
None,
None,
),
) )
async def test_assist_api(hass: HomeAssistant) -> None: async def test_assist_api(
hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
"""Test Assist API.""" """Test Assist API."""
assert await async_setup_component(hass, "homeassistant", {})
entity_registry.async_get_or_create(
"light",
"kitchen",
"mock-id-kitchen",
original_name="Kitchen",
suggested_object_id="kitchen",
).write_unavailable_state(hass)
test_context = Context()
tool_context = llm.ToolContext(
platform="test_platform",
context=test_context,
user_prompt="test_text",
language="*",
assistant="conversation",
device_id="test_device",
)
schema = { schema = {
vol.Optional("area"): cv.string, vol.Optional("area"): cv.string,
vol.Optional("floor"): cv.string, vol.Optional("floor"): cv.string,
@ -77,22 +109,33 @@ async def test_assist_api(hass: HomeAssistant) -> None:
class MyIntentHandler(intent.IntentHandler): class MyIntentHandler(intent.IntentHandler):
intent_type = "test_intent" intent_type = "test_intent"
slot_schema = schema slot_schema = schema
platforms = set() # Match none
intent_handler = MyIntentHandler() intent_handler = MyIntentHandler()
intent.async_register(hass, intent_handler) intent.async_register(hass, intent_handler)
assert len(llm.async_get_apis(hass)) == 1 assert len(llm.async_get_apis(hass)) == 1
api = llm.async_get_api(hass, "assist") api = await llm.async_get_api(hass, "assist", tool_context)
tools = api.async_get_tools() assert len(api.tools) == 0
assert len(tools) == 1
tool = tools[0] # Match all
intent_handler.platforms = None
api = await llm.async_get_api(hass, "assist", tool_context)
assert len(api.tools) == 1
# Match specific domain
intent_handler.platforms = {"light"}
api = await llm.async_get_api(hass, "assist", tool_context)
assert len(api.tools) == 1
tool = api.tools[0]
assert tool.name == "test_intent" assert tool.name == "test_intent"
assert tool.description == "Execute Home Assistant test_intent intent" assert tool.description == "Execute Home Assistant test_intent intent"
assert tool.parameters == vol.Schema(intent_handler.slot_schema) assert tool.parameters == vol.Schema(intent_handler.slot_schema)
assert str(tool) == "<IntentTool - test_intent>" assert str(tool) == "<IntentTool - test_intent>"
test_context = Context()
assert test_context.json_fragment # To reproduce an error case in tracing assert test_context.json_fragment # To reproduce an error case in tracing
intent_response = intent.IntentResponse("*") intent_response = intent.IntentResponse("*")
intent_response.matched_states = [State("light.matched", "on")] intent_response.matched_states = [State("light.matched", "on")]
@ -100,12 +143,6 @@ async def test_assist_api(hass: HomeAssistant) -> None:
tool_input = llm.ToolInput( tool_input = llm.ToolInput(
tool_name="test_intent", tool_name="test_intent",
tool_args={"area": "kitchen", "floor": "ground_floor"}, tool_args={"area": "kitchen", "floor": "ground_floor"},
platform="test_platform",
context=test_context,
user_prompt="test_text",
language="*",
assistant="test_assistant",
device_id="test_device",
) )
with patch( with patch(
@ -114,18 +151,18 @@ async def test_assist_api(hass: HomeAssistant) -> None:
response = await api.async_call_tool(tool_input) response = await api.async_call_tool(tool_input)
mock_intent_handle.assert_awaited_once_with( mock_intent_handle.assert_awaited_once_with(
hass, hass=hass,
"test_platform", platform="test_platform",
"test_intent", intent_type="test_intent",
{ slots={
"area": {"value": "kitchen"}, "area": {"value": "kitchen"},
"floor": {"value": "ground_floor"}, "floor": {"value": "ground_floor"},
}, },
"test_text", text_input="test_text",
test_context, context=test_context,
"*", language="*",
"test_assistant", assistant="conversation",
"test_device", device_id="test_device",
) )
assert response == { assert response == {
"card": {}, "card": {},
@ -140,7 +177,27 @@ async def test_assist_api(hass: HomeAssistant) -> None:
} }
async def test_assist_api_description(hass: HomeAssistant) -> None: async def test_assist_api_get_timer_tools(
hass: HomeAssistant, tool_input_context: llm.ToolContext
) -> None:
"""Test getting timer tools with Assist API."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "intent", {})
api = await llm.async_get_api(hass, "assist", tool_input_context)
assert "HassStartTimer" not in [tool.name for tool in api.tools]
tool_input_context.device_id = "test_device"
async_register_timer_handler(hass, "test_device", lambda *args: None)
api = await llm.async_get_api(hass, "assist", tool_input_context)
assert "HassStartTimer" in [tool.name for tool in api.tools]
async def test_assist_api_description(
hass: HomeAssistant, tool_input_context: llm.ToolContext
) -> None:
"""Test intent description with Assist API.""" """Test intent description with Assist API."""
class MyIntentHandler(intent.IntentHandler): class MyIntentHandler(intent.IntentHandler):
@ -150,10 +207,9 @@ async def test_assist_api_description(hass: HomeAssistant) -> None:
intent.async_register(hass, MyIntentHandler()) intent.async_register(hass, MyIntentHandler())
assert len(llm.async_get_apis(hass)) == 1 assert len(llm.async_get_apis(hass)) == 1
api = llm.async_get_api(hass, "assist") api = await llm.async_get_api(hass, "assist", tool_input_context)
tools = api.async_get_tools() assert len(api.tools) == 1
assert len(tools) == 1 tool = api.tools[0]
tool = tools[0]
assert tool.name == "test_intent" assert tool.name == "test_intent"
assert tool.description == "my intent handler" assert tool.description == "my intent handler"
@ -167,20 +223,18 @@ async def test_assist_api_prompt(
) -> None: ) -> None:
"""Test prompt for the assist API.""" """Test prompt for the assist API."""
assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "intent", {})
context = Context() context = Context()
tool_input = llm.ToolInput( tool_context = llm.ToolContext(
tool_name=None,
tool_args=None,
platform="test_platform", platform="test_platform",
context=context, context=context,
user_prompt="test_text", user_prompt="test_text",
language="*", language="*",
assistant="conversation", assistant="conversation",
device_id="test_device", device_id=None,
) )
api = llm.async_get_api(hass, "assist") api = await llm.async_get_api(hass, "assist", tool_context)
prompt = await api.async_get_api_prompt(tool_input) assert api.api_prompt == (
assert prompt == (
"Only if the user wants to control a device, tell them to expose entities to their " "Only if the user wants to control a device, tell them to expose entities to their "
"voice assistant in Home Assistant." "voice assistant in Home Assistant."
) )
@ -308,7 +362,7 @@ async def test_assist_api_prompt(
) )
) )
exposed_entities = llm._get_exposed_entities(hass, tool_input.assistant) exposed_entities = llm._get_exposed_entities(hass, tool_context.assistant)
assert exposed_entities == { assert exposed_entities == {
"light.1": { "light.1": {
"areas": "Test Area 2", "areas": "Test Area 2",
@ -373,40 +427,55 @@ async def test_assist_api_prompt(
"Call the intent tools to control Home Assistant. " "Call the intent tools to control Home Assistant. "
"When controlling an area, prefer passing area name." "When controlling an area, prefer passing area name."
) )
no_timer_prompt = "This device does not support timers."
prompt = await api.async_get_api_prompt(tool_input)
area_prompt = ( area_prompt = (
"Reject all generic commands like 'turn on the lights' because we don't know in what area " "Reject all generic commands like 'turn on the lights' because we don't know in what area "
"this conversation is happening." "this conversation is happening."
) )
assert prompt == ( api = await llm.async_get_api(hass, "assist", tool_context)
assert api.api_prompt == (
f"""{first_part_prompt} f"""{first_part_prompt}
{area_prompt} {area_prompt}
{no_timer_prompt}
{exposed_entities_prompt}""" {exposed_entities_prompt}"""
) )
# Fake that request is made from a specific device ID # Fake that request is made from a specific device ID with an area
tool_input.device_id = device.id tool_context.device_id = device.id
prompt = await api.async_get_api_prompt(tool_input)
area_prompt = ( area_prompt = (
"You are in area Test Area and all generic commands like 'turn on the lights' " "You are in area Test Area and all generic commands like 'turn on the lights' "
"should target this area." "should target this area."
) )
assert prompt == ( api = await llm.async_get_api(hass, "assist", tool_context)
assert api.api_prompt == (
f"""{first_part_prompt} f"""{first_part_prompt}
{area_prompt} {area_prompt}
{no_timer_prompt}
{exposed_entities_prompt}""" {exposed_entities_prompt}"""
) )
# Add floor # Add floor
floor = floor_registry.async_create("2") floor = floor_registry.async_create("2")
area_registry.async_update(area.id, floor_id=floor.floor_id) area_registry.async_update(area.id, floor_id=floor.floor_id)
prompt = await api.async_get_api_prompt(tool_input)
area_prompt = ( area_prompt = (
"You are in area Test Area (floor 2) and all generic commands like 'turn on the lights' " "You are in area Test Area (floor 2) and all generic commands like 'turn on the lights' "
"should target this area." "should target this area."
) )
assert prompt == ( api = await llm.async_get_api(hass, "assist", tool_context)
assert api.api_prompt == (
f"""{first_part_prompt}
{area_prompt}
{no_timer_prompt}
{exposed_entities_prompt}"""
)
# Register device for timers
async_register_timer_handler(hass, device.id, lambda *args: None)
api = await llm.async_get_api(hass, "assist", tool_context)
# The no_timer_prompt is gone
assert api.api_prompt == (
f"""{first_part_prompt} f"""{first_part_prompt}
{area_prompt} {area_prompt}
{exposed_entities_prompt}""" {exposed_entities_prompt}"""
@ -418,8 +487,8 @@ async def test_assist_api_prompt(
mock_user.id = "12345" mock_user.id = "12345"
mock_user.name = "Test User" mock_user.name = "Test User"
with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user): with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user):
prompt = await api.async_get_api_prompt(tool_input) api = await llm.async_get_api(hass, "assist", tool_context)
assert prompt == ( assert api.api_prompt == (
f"""{first_part_prompt} f"""{first_part_prompt}
{area_prompt} {area_prompt}
The user name is Test User. The user name is Test User.