diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 33dade8bf29..f85cf2530dc 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -149,13 +149,22 @@ class GoogleGenerativeAIConversationEntity( ) -> conversation.ConversationResult: """Process a sentence.""" intent_response = intent.IntentResponse(language=user_input.language) - llm_api: llm.API | None = None + llm_api: llm.APIInstance | None = None tools: list[dict[str, Any]] | None = None if self.entry.options.get(CONF_LLM_HASS_API): try: - llm_api = llm.async_get_api( - self.hass, self.entry.options[CONF_LLM_HASS_API] + llm_api = await llm.async_get_api( + self.hass, + self.entry.options[CONF_LLM_HASS_API], + llm.ToolContext( + platform=DOMAIN, + context=user_input.context, + user_prompt=user_input.text, + language=user_input.language, + assistant=conversation.DOMAIN, + device_id=user_input.device_id, + ), ) except HomeAssistantError as err: LOGGER.error("Error getting LLM API: %s", err) @@ -166,7 +175,7 @@ class GoogleGenerativeAIConversationEntity( return conversation.ConversationResult( response=intent_response, conversation_id=user_input.conversation_id ) - tools = [_format_tool(tool) for tool in llm_api.async_get_tools()] + tools = [_format_tool(tool) for tool in llm_api.tools] model = genai.GenerativeModel( model_name=self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), @@ -206,19 +215,7 @@ class GoogleGenerativeAIConversationEntity( try: if llm_api: - empty_tool_input = llm.ToolInput( - tool_name="", - tool_args={}, - platform=DOMAIN, - context=user_input.context, - user_prompt=user_input.text, - language=user_input.language, - assistant=conversation.DOMAIN, - device_id=user_input.device_id, - ) - - api_prompt = await llm_api.async_get_api_prompt(empty_tool_input) - + api_prompt = llm_api.api_prompt else: api_prompt = llm.async_render_no_api_prompt(self.hass) @@ -309,12 +306,6 @@ class GoogleGenerativeAIConversationEntity( tool_input = llm.ToolInput( tool_name=tool_call.name, tool_args=dict(tool_call.args), - platform=DOMAIN, - context=user_input.context, - user_prompt=user_input.text, - language=user_input.language, - assistant=conversation.DOMAIN, - device_id=user_input.device_id, ) LOGGER.debug( "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args diff --git a/homeassistant/components/intent/__init__.py b/homeassistant/components/intent/__init__.py index 7fba729e96b..9b09fa9167b 100644 --- a/homeassistant/components/intent/__init__.py +++ b/homeassistant/components/intent/__init__.py @@ -50,6 +50,7 @@ from .timers import ( TimerManager, TimerStatusIntentHandler, UnpauseTimerIntentHandler, + async_device_supports_timers, async_register_timer_handler, ) @@ -59,6 +60,7 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) __all__ = [ "async_register_timer_handler", + "async_device_supports_timers", "TimerInfo", "TimerEventType", "DOMAIN", diff --git a/homeassistant/components/intent/timers.py b/homeassistant/components/intent/timers.py index f5a06e6e028..167f37ed6fc 100644 --- a/homeassistant/components/intent/timers.py +++ b/homeassistant/components/intent/timers.py @@ -415,6 +415,15 @@ class TimerManager: return device_id in self.handlers +@callback +def async_device_supports_timers(hass: HomeAssistant, device_id: str) -> bool: + """Return True if device has been registered to handle timer events.""" + timer_manager: TimerManager | None = hass.data.get(TIMER_DATA) + if timer_manager is None: + return False + return timer_manager.is_timer_device(device_id) + + @callback def async_register_timer_handler( hass: HomeAssistant, device_id: str, handler: TimerHandler diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index ab76d9cfb56..f4652a1f820 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -99,12 +99,23 @@ class OpenAIConversationEntity( """Process a sentence.""" options = self.entry.options intent_response = intent.IntentResponse(language=user_input.language) - llm_api: llm.API | None = None + llm_api: llm.APIInstance | None = None tools: list[dict[str, Any]] | None = None if options.get(CONF_LLM_HASS_API): try: - llm_api = llm.async_get_api(self.hass, options[CONF_LLM_HASS_API]) + llm_api = await llm.async_get_api( + self.hass, + options[CONF_LLM_HASS_API], + llm.ToolContext( + platform=DOMAIN, + context=user_input.context, + user_prompt=user_input.text, + language=user_input.language, + assistant=conversation.DOMAIN, + device_id=user_input.device_id, + ), + ) except HomeAssistantError as err: LOGGER.error("Error getting LLM API: %s", err) intent_response.async_set_error( @@ -114,7 +125,7 @@ class OpenAIConversationEntity( return conversation.ConversationResult( response=intent_response, conversation_id=user_input.conversation_id ) - tools = [_format_tool(tool) for tool in llm_api.async_get_tools()] + tools = [_format_tool(tool) for tool in llm_api.tools] if user_input.conversation_id in self.history: conversation_id = user_input.conversation_id @@ -123,19 +134,7 @@ class OpenAIConversationEntity( conversation_id = ulid.ulid_now() try: if llm_api: - empty_tool_input = llm.ToolInput( - tool_name="", - tool_args={}, - platform=DOMAIN, - context=user_input.context, - user_prompt=user_input.text, - language=user_input.language, - assistant=conversation.DOMAIN, - device_id=user_input.device_id, - ) - - api_prompt = await llm_api.async_get_api_prompt(empty_tool_input) - + api_prompt = llm_api.api_prompt else: api_prompt = llm.async_render_no_api_prompt(self.hass) @@ -182,7 +181,7 @@ class OpenAIConversationEntity( result = await client.chat.completions.create( model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), messages=messages, - tools=tools, + tools=tools or None, max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P), temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), @@ -210,12 +209,6 @@ class OpenAIConversationEntity( tool_input = llm.ToolInput( tool_name=tool_call.function.name, tool_args=json.loads(tool_call.function.arguments), - platform=DOMAIN, - context=user_input.context, - user_prompt=user_input.text, - language=user_input.language, - assistant=conversation.DOMAIN, - device_id=user_input.device_id, ) LOGGER.debug( "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index 8271c247e23..2f808321c13 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import asdict, dataclass, replace +from dataclasses import asdict, dataclass from enum import Enum from typing import Any @@ -15,6 +15,7 @@ from homeassistant.components.conversation.trace import ( async_conversation_trace_append, ) from homeassistant.components.homeassistant.exposed_entities import async_should_expose +from homeassistant.components.intent import async_device_supports_timers from homeassistant.components.weather.intent import INTENT_GET_WEATHER from homeassistant.core import Context, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError @@ -68,15 +69,16 @@ def async_register_api(hass: HomeAssistant, api: API) -> None: apis[api.id] = api -@callback -def async_get_api(hass: HomeAssistant, api_id: str) -> API: +async def async_get_api( + hass: HomeAssistant, api_id: str, tool_context: ToolContext +) -> APIInstance: """Get an API.""" apis = _async_get_apis(hass) if api_id not in apis: raise HomeAssistantError(f"API {api_id} not found") - return apis[api_id] + return await apis[api_id].async_get_api_instance(tool_context) @callback @@ -86,11 +88,9 @@ def async_get_apis(hass: HomeAssistant) -> list[API]: @dataclass(slots=True) -class ToolInput(ABC): +class ToolContext: """Tool input to be processed.""" - tool_name: str - tool_args: dict[str, Any] platform: str context: Context | None user_prompt: str | None @@ -99,6 +99,14 @@ class ToolInput(ABC): device_id: str | None +@dataclass(slots=True) +class ToolInput: + """Tool input to be processed.""" + + tool_name: str + tool_args: dict[str, Any] + + class Tool: """LLM Tool base class.""" @@ -108,7 +116,7 @@ class Tool: @abstractmethod async def async_call( - self, hass: HomeAssistant, tool_input: ToolInput + self, hass: HomeAssistant, tool_input: ToolInput, tool_context: ToolContext ) -> JsonObjectType: """Call the tool.""" raise NotImplementedError @@ -118,6 +126,30 @@ class Tool: return f"<{self.__class__.__name__} - {self.name}>" +@dataclass +class APIInstance: + """Instance of an API to be used by an LLM.""" + + api: API + api_prompt: str + tool_context: ToolContext + tools: list[Tool] + + async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType: + """Call a LLM tool, validate args and return the response.""" + async_conversation_trace_append( + ConversationTraceEventType.LLM_TOOL_CALL, asdict(tool_input) + ) + + for tool in self.tools: + if tool.name == tool_input.tool_name: + break + else: + raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found') + + return await tool.async_call(self.api.hass, tool_input, self.tool_context) + + @dataclass(slots=True, kw_only=True) class API(ABC): """An API to expose to LLMs.""" @@ -127,38 +159,10 @@ class API(ABC): name: str @abstractmethod - async def async_get_api_prompt(self, tool_input: ToolInput) -> str: - """Return the prompt for the API.""" + async def async_get_api_instance(self, tool_context: ToolContext) -> APIInstance: + """Return the instance of the API.""" raise NotImplementedError - @abstractmethod - @callback - def async_get_tools(self) -> list[Tool]: - """Return a list of tools.""" - raise NotImplementedError - - async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType: - """Call a LLM tool, validate args and return the response.""" - async_conversation_trace_append( - ConversationTraceEventType.LLM_TOOL_CALL, asdict(tool_input) - ) - - for tool in self.async_get_tools(): - if tool.name == tool_input.tool_name: - break - else: - raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found') - - return await tool.async_call( - self.hass, - replace( - tool_input, - tool_name=tool.name, - tool_args=tool.parameters(tool_input.tool_args), - context=tool_input.context or Context(), - ), - ) - class IntentTool(Tool): """LLM Tool representing an Intent.""" @@ -176,21 +180,20 @@ class IntentTool(Tool): self.parameters = vol.Schema(slot_schema) async def async_call( - self, hass: HomeAssistant, tool_input: ToolInput + self, hass: HomeAssistant, tool_input: ToolInput, tool_context: ToolContext ) -> JsonObjectType: """Handle the intent.""" slots = {key: {"value": val} for key, val in tool_input.tool_args.items()} - intent_response = await intent.async_handle( - hass, - tool_input.platform, - self.name, - slots, - tool_input.user_prompt, - tool_input.context, - tool_input.language, - tool_input.assistant, - tool_input.device_id, + hass=hass, + platform=tool_context.platform, + intent_type=self.name, + slots=slots, + text_input=tool_context.user_prompt, + context=tool_context.context, + language=tool_context.language, + assistant=tool_context.assistant, + device_id=tool_context.device_id, ) return intent_response.as_dict() @@ -213,15 +216,26 @@ class AssistAPI(API): name="Assist", ) - async def async_get_api_prompt(self, tool_input: ToolInput) -> str: - """Return the prompt for the API.""" - if tool_input.assistant: + async def async_get_api_instance(self, tool_context: ToolContext) -> APIInstance: + """Return the instance of the API.""" + if tool_context.assistant: exposed_entities: dict | None = _get_exposed_entities( - self.hass, tool_input.assistant + self.hass, tool_context.assistant ) else: exposed_entities = None + return APIInstance( + api=self, + api_prompt=await self._async_get_api_prompt(tool_context, exposed_entities), + tool_context=tool_context, + tools=self._async_get_tools(tool_context, exposed_entities), + ) + + async def _async_get_api_prompt( + self, tool_context: ToolContext, exposed_entities: dict | None + ) -> str: + """Return the prompt for the API.""" if not exposed_entities: return ( "Only if the user wants to control a device, tell them to expose entities " @@ -236,9 +250,9 @@ class AssistAPI(API): ] area: ar.AreaEntry | None = None floor: fr.FloorEntry | None = None - if tool_input.device_id: + if tool_context.device_id: device_reg = dr.async_get(self.hass) - device = device_reg.async_get(tool_input.device_id) + device = device_reg.async_get(tool_context.device_id) if device: area_reg = ar.async_get(self.hass) @@ -259,11 +273,16 @@ class AssistAPI(API): "don't know in what area this conversation is happening." ) - if tool_input.context and tool_input.context.user_id: - user = await self.hass.auth.async_get_user(tool_input.context.user_id) + if tool_context.context and tool_context.context.user_id: + user = await self.hass.auth.async_get_user(tool_context.context.user_id) if user: prompt.append(f"The user name is {user.name}.") + if not tool_context.device_id or not async_device_supports_timers( + self.hass, tool_context.device_id + ): + prompt.append("This device does not support timers.") + if exposed_entities: prompt.append( "An overview of the areas and the devices in this smart home:" @@ -273,14 +292,44 @@ class AssistAPI(API): return "\n".join(prompt) @callback - def async_get_tools(self) -> list[Tool]: + def _async_get_tools( + self, tool_context: ToolContext, exposed_entities: dict | None + ) -> list[Tool]: """Return a list of LLM tools.""" - return [ - IntentTool(intent_handler) + ignore_intents = self.IGNORE_INTENTS + if not tool_context.device_id or not async_device_supports_timers( + self.hass, tool_context.device_id + ): + ignore_intents = ignore_intents | { + intent.INTENT_START_TIMER, + intent.INTENT_CANCEL_TIMER, + intent.INTENT_INCREASE_TIMER, + intent.INTENT_DECREASE_TIMER, + intent.INTENT_PAUSE_TIMER, + intent.INTENT_UNPAUSE_TIMER, + intent.INTENT_TIMER_STATUS, + } + + intent_handlers = [ + intent_handler for intent_handler in intent.async_get(self.hass) - if intent_handler.intent_type not in self.IGNORE_INTENTS + if intent_handler.intent_type not in ignore_intents ] + exposed_domains: set[str] | None = None + if exposed_entities is not None: + exposed_domains = { + entity_id.split(".")[0] for entity_id in exposed_entities + } + intent_handlers = [ + intent_handler + for intent_handler in intent_handlers + if intent_handler.platforms is None + or intent_handler.platforms & exposed_domains + ] + + return [IntentTool(intent_handler) for intent_handler in intent_handlers] + def _get_exposed_entities( hass: HomeAssistant, assistant: str diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index e3a938a04d6..4c7f2de5e2e 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -61,11 +61,11 @@ async def test_default_prompt( with ( patch("google.generativeai.GenerativeModel") as mock_model, patch( - "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools", + "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools", return_value=[], ) as mock_get_tools, patch( - "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_api_prompt", + "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_api_prompt", return_value="", ), patch( @@ -148,7 +148,7 @@ async def test_chat_history( @patch( - "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools" + "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools" ) async def test_function_call( mock_get_tools, @@ -182,7 +182,7 @@ async def test_function_call( mock_part.function_call.name = "test_tool" mock_part.function_call.args = {"param1": ["test_value"]} - def tool_call(hass, tool_input): + def tool_call(hass, tool_input, tool_context): mock_part.function_call = None mock_part.text = "Hi there!" return {"result": "Test response"} @@ -221,6 +221,8 @@ async def test_function_call( llm.ToolInput( tool_name="test_tool", tool_args={"param1": ["test_value"]}, + ), + llm.ToolContext( platform="google_generative_ai_conversation", context=context, user_prompt="Please call the test function", @@ -246,7 +248,7 @@ async def test_function_call( @patch( - "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools" + "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools" ) async def test_function_exception( mock_get_tools, @@ -280,7 +282,7 @@ async def test_function_exception( mock_part.function_call.name = "test_tool" mock_part.function_call.args = {"param1": 1} - def tool_call(hass, tool_input): + def tool_call(hass, tool_input, tool_context): mock_part.function_call = None mock_part.text = "Hi there!" raise HomeAssistantError("Test tool exception") @@ -319,6 +321,8 @@ async def test_function_exception( llm.ToolInput( tool_name="test_tool", tool_args={"param1": 1}, + ), + llm.ToolContext( platform="google_generative_ai_conversation", context=context, user_prompt="Please call the test function", diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index 3fa5c307b6d..0eec14395e5 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -86,7 +86,7 @@ async def test_conversation_agent( @patch( - "homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools" + "homeassistant.components.openai_conversation.conversation.llm.AssistAPI._async_get_tools" ) async def test_function_call( mock_get_tools, @@ -192,6 +192,8 @@ async def test_function_call( llm.ToolInput( tool_name="test_tool", tool_args={"param1": "test_value"}, + ), + llm.ToolContext( platform="openai_conversation", context=context, user_prompt="Please call the test function", @@ -217,7 +219,7 @@ async def test_function_call( @patch( - "homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools" + "homeassistant.components.openai_conversation.conversation.llm.AssistAPI._async_get_tools" ) async def test_function_exception( mock_get_tools, @@ -323,6 +325,8 @@ async def test_function_exception( llm.ToolInput( tool_name="test_tool", tool_args={"param1": "test_value"}, + ), + llm.ToolContext( platform="openai_conversation", context=context, user_prompt="Please call the test function", diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index 873e2796d1e..c71d11da8a2 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -5,6 +5,7 @@ from unittest.mock import Mock, patch import pytest import voluptuous as vol +from homeassistant.components.intent import async_register_timer_handler from homeassistant.core import Context, HomeAssistant, State from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import ( @@ -22,53 +23,84 @@ from homeassistant.util import yaml from tests.common import MockConfigEntry -async def test_get_api_no_existing(hass: HomeAssistant) -> None: +@pytest.fixture +def tool_input_context() -> llm.ToolContext: + """Return tool input context.""" + return llm.ToolContext( + platform="", + context=None, + user_prompt=None, + language=None, + assistant=None, + device_id=None, + ) + + +async def test_get_api_no_existing( + hass: HomeAssistant, tool_input_context: llm.ToolContext +) -> None: """Test getting an llm api where no config exists.""" with pytest.raises(HomeAssistantError): - llm.async_get_api(hass, "non-existing") + await llm.async_get_api(hass, "non-existing", tool_input_context) -async def test_register_api(hass: HomeAssistant) -> None: +async def test_register_api( + hass: HomeAssistant, tool_input_context: llm.ToolContext +) -> None: """Test registering an llm api.""" class MyAPI(llm.API): - async def async_get_api_prompt(self, tool_input: llm.ToolInput) -> str: - """Return a prompt for the tool.""" - return "" - - def async_get_tools(self) -> list[llm.Tool]: + async def async_get_api_instance( + self, tool_input: llm.ToolInput + ) -> llm.APIInstance: """Return a list of tools.""" - return [] + return llm.APIInstance(self, "", [], tool_input_context) api = MyAPI(hass=hass, id="test", name="Test") llm.async_register_api(hass, api) - assert llm.async_get_api(hass, "test") is api + instance = await llm.async_get_api(hass, "test", tool_input_context) + assert instance.api is api assert api in llm.async_get_apis(hass) with pytest.raises(HomeAssistantError): llm.async_register_api(hass, api) -async def test_call_tool_no_existing(hass: HomeAssistant) -> None: +async def test_call_tool_no_existing( + hass: HomeAssistant, tool_input_context: llm.ToolContext +) -> None: """Test calling an llm tool where no config exists.""" + instance = await llm.async_get_api(hass, "assist", tool_input_context) with pytest.raises(HomeAssistantError): - await llm.async_get_api(hass, "intent").async_call_tool( - llm.ToolInput( - "test_tool", - {}, - "test_platform", - None, - None, - None, - None, - None, - ), + await instance.async_call_tool( + llm.ToolInput("test_tool", {}), ) -async def test_assist_api(hass: HomeAssistant) -> None: +async def test_assist_api( + hass: HomeAssistant, entity_registry: er.EntityRegistry +) -> None: """Test Assist API.""" + assert await async_setup_component(hass, "homeassistant", {}) + + entity_registry.async_get_or_create( + "light", + "kitchen", + "mock-id-kitchen", + original_name="Kitchen", + suggested_object_id="kitchen", + ).write_unavailable_state(hass) + + test_context = Context() + tool_context = llm.ToolContext( + platform="test_platform", + context=test_context, + user_prompt="test_text", + language="*", + assistant="conversation", + device_id="test_device", + ) schema = { vol.Optional("area"): cv.string, vol.Optional("floor"): cv.string, @@ -77,22 +109,33 @@ async def test_assist_api(hass: HomeAssistant) -> None: class MyIntentHandler(intent.IntentHandler): intent_type = "test_intent" slot_schema = schema + platforms = set() # Match none intent_handler = MyIntentHandler() intent.async_register(hass, intent_handler) assert len(llm.async_get_apis(hass)) == 1 - api = llm.async_get_api(hass, "assist") - tools = api.async_get_tools() - assert len(tools) == 1 - tool = tools[0] + api = await llm.async_get_api(hass, "assist", tool_context) + assert len(api.tools) == 0 + + # Match all + intent_handler.platforms = None + + api = await llm.async_get_api(hass, "assist", tool_context) + assert len(api.tools) == 1 + + # Match specific domain + intent_handler.platforms = {"light"} + + api = await llm.async_get_api(hass, "assist", tool_context) + assert len(api.tools) == 1 + tool = api.tools[0] assert tool.name == "test_intent" assert tool.description == "Execute Home Assistant test_intent intent" assert tool.parameters == vol.Schema(intent_handler.slot_schema) assert str(tool) == "" - test_context = Context() assert test_context.json_fragment # To reproduce an error case in tracing intent_response = intent.IntentResponse("*") intent_response.matched_states = [State("light.matched", "on")] @@ -100,12 +143,6 @@ async def test_assist_api(hass: HomeAssistant) -> None: tool_input = llm.ToolInput( tool_name="test_intent", tool_args={"area": "kitchen", "floor": "ground_floor"}, - platform="test_platform", - context=test_context, - user_prompt="test_text", - language="*", - assistant="test_assistant", - device_id="test_device", ) with patch( @@ -114,18 +151,18 @@ async def test_assist_api(hass: HomeAssistant) -> None: response = await api.async_call_tool(tool_input) mock_intent_handle.assert_awaited_once_with( - hass, - "test_platform", - "test_intent", - { + hass=hass, + platform="test_platform", + intent_type="test_intent", + slots={ "area": {"value": "kitchen"}, "floor": {"value": "ground_floor"}, }, - "test_text", - test_context, - "*", - "test_assistant", - "test_device", + text_input="test_text", + context=test_context, + language="*", + assistant="conversation", + device_id="test_device", ) assert response == { "card": {}, @@ -140,7 +177,27 @@ async def test_assist_api(hass: HomeAssistant) -> None: } -async def test_assist_api_description(hass: HomeAssistant) -> None: +async def test_assist_api_get_timer_tools( + hass: HomeAssistant, tool_input_context: llm.ToolContext +) -> None: + """Test getting timer tools with Assist API.""" + assert await async_setup_component(hass, "homeassistant", {}) + assert await async_setup_component(hass, "intent", {}) + api = await llm.async_get_api(hass, "assist", tool_input_context) + + assert "HassStartTimer" not in [tool.name for tool in api.tools] + + tool_input_context.device_id = "test_device" + + async_register_timer_handler(hass, "test_device", lambda *args: None) + + api = await llm.async_get_api(hass, "assist", tool_input_context) + assert "HassStartTimer" in [tool.name for tool in api.tools] + + +async def test_assist_api_description( + hass: HomeAssistant, tool_input_context: llm.ToolContext +) -> None: """Test intent description with Assist API.""" class MyIntentHandler(intent.IntentHandler): @@ -150,10 +207,9 @@ async def test_assist_api_description(hass: HomeAssistant) -> None: intent.async_register(hass, MyIntentHandler()) assert len(llm.async_get_apis(hass)) == 1 - api = llm.async_get_api(hass, "assist") - tools = api.async_get_tools() - assert len(tools) == 1 - tool = tools[0] + api = await llm.async_get_api(hass, "assist", tool_input_context) + assert len(api.tools) == 1 + tool = api.tools[0] assert tool.name == "test_intent" assert tool.description == "my intent handler" @@ -167,20 +223,18 @@ async def test_assist_api_prompt( ) -> None: """Test prompt for the assist API.""" assert await async_setup_component(hass, "homeassistant", {}) + assert await async_setup_component(hass, "intent", {}) context = Context() - tool_input = llm.ToolInput( - tool_name=None, - tool_args=None, + tool_context = llm.ToolContext( platform="test_platform", context=context, user_prompt="test_text", language="*", assistant="conversation", - device_id="test_device", + device_id=None, ) - api = llm.async_get_api(hass, "assist") - prompt = await api.async_get_api_prompt(tool_input) - assert prompt == ( + api = await llm.async_get_api(hass, "assist", tool_context) + assert api.api_prompt == ( "Only if the user wants to control a device, tell them to expose entities to their " "voice assistant in Home Assistant." ) @@ -308,7 +362,7 @@ async def test_assist_api_prompt( ) ) - exposed_entities = llm._get_exposed_entities(hass, tool_input.assistant) + exposed_entities = llm._get_exposed_entities(hass, tool_context.assistant) assert exposed_entities == { "light.1": { "areas": "Test Area 2", @@ -373,40 +427,55 @@ async def test_assist_api_prompt( "Call the intent tools to control Home Assistant. " "When controlling an area, prefer passing area name." ) + no_timer_prompt = "This device does not support timers." - prompt = await api.async_get_api_prompt(tool_input) area_prompt = ( "Reject all generic commands like 'turn on the lights' because we don't know in what area " "this conversation is happening." ) - assert prompt == ( + api = await llm.async_get_api(hass, "assist", tool_context) + assert api.api_prompt == ( f"""{first_part_prompt} {area_prompt} +{no_timer_prompt} {exposed_entities_prompt}""" ) - # Fake that request is made from a specific device ID - tool_input.device_id = device.id - prompt = await api.async_get_api_prompt(tool_input) + # Fake that request is made from a specific device ID with an area + tool_context.device_id = device.id area_prompt = ( "You are in area Test Area and all generic commands like 'turn on the lights' " "should target this area." ) - assert prompt == ( + api = await llm.async_get_api(hass, "assist", tool_context) + assert api.api_prompt == ( f"""{first_part_prompt} {area_prompt} +{no_timer_prompt} {exposed_entities_prompt}""" ) # Add floor floor = floor_registry.async_create("2") area_registry.async_update(area.id, floor_id=floor.floor_id) - prompt = await api.async_get_api_prompt(tool_input) area_prompt = ( "You are in area Test Area (floor 2) and all generic commands like 'turn on the lights' " "should target this area." ) - assert prompt == ( + api = await llm.async_get_api(hass, "assist", tool_context) + assert api.api_prompt == ( + f"""{first_part_prompt} +{area_prompt} +{no_timer_prompt} +{exposed_entities_prompt}""" + ) + + # Register device for timers + async_register_timer_handler(hass, device.id, lambda *args: None) + + api = await llm.async_get_api(hass, "assist", tool_context) + # The no_timer_prompt is gone + assert api.api_prompt == ( f"""{first_part_prompt} {area_prompt} {exposed_entities_prompt}""" @@ -418,8 +487,8 @@ async def test_assist_api_prompt( mock_user.id = "12345" mock_user.name = "Test User" with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user): - prompt = await api.async_get_api_prompt(tool_input) - assert prompt == ( + api = await llm.async_get_api(hass, "assist", tool_context) + assert api.api_prompt == ( f"""{first_part_prompt} {area_prompt} The user name is Test User.