LLM Assist API to ignore intents if not needed for exposed entities or calling device (#118283)
* LLM Assist API to ignore timer intents if device doesn't support it * Refactor to use API instances * Extract ToolContext class * Limit exposed intents based on exposed entities
This commit is contained in:
parent
e0264c8604
commit
615a1eda51
8 changed files with 302 additions and 181 deletions
|
@ -149,13 +149,22 @@ class GoogleGenerativeAIConversationEntity(
|
||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
intent_response = intent.IntentResponse(language=user_input.language)
|
||||||
llm_api: llm.API | None = None
|
llm_api: llm.APIInstance | None = None
|
||||||
tools: list[dict[str, Any]] | None = None
|
tools: list[dict[str, Any]] | None = None
|
||||||
|
|
||||||
if self.entry.options.get(CONF_LLM_HASS_API):
|
if self.entry.options.get(CONF_LLM_HASS_API):
|
||||||
try:
|
try:
|
||||||
llm_api = llm.async_get_api(
|
llm_api = await llm.async_get_api(
|
||||||
self.hass, self.entry.options[CONF_LLM_HASS_API]
|
self.hass,
|
||||||
|
self.entry.options[CONF_LLM_HASS_API],
|
||||||
|
llm.ToolContext(
|
||||||
|
platform=DOMAIN,
|
||||||
|
context=user_input.context,
|
||||||
|
user_prompt=user_input.text,
|
||||||
|
language=user_input.language,
|
||||||
|
assistant=conversation.DOMAIN,
|
||||||
|
device_id=user_input.device_id,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
except HomeAssistantError as err:
|
except HomeAssistantError as err:
|
||||||
LOGGER.error("Error getting LLM API: %s", err)
|
LOGGER.error("Error getting LLM API: %s", err)
|
||||||
|
@ -166,7 +175,7 @@ class GoogleGenerativeAIConversationEntity(
|
||||||
return conversation.ConversationResult(
|
return conversation.ConversationResult(
|
||||||
response=intent_response, conversation_id=user_input.conversation_id
|
response=intent_response, conversation_id=user_input.conversation_id
|
||||||
)
|
)
|
||||||
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
|
tools = [_format_tool(tool) for tool in llm_api.tools]
|
||||||
|
|
||||||
model = genai.GenerativeModel(
|
model = genai.GenerativeModel(
|
||||||
model_name=self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
model_name=self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
||||||
|
@ -206,19 +215,7 @@ class GoogleGenerativeAIConversationEntity(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if llm_api:
|
if llm_api:
|
||||||
empty_tool_input = llm.ToolInput(
|
api_prompt = llm_api.api_prompt
|
||||||
tool_name="",
|
|
||||||
tool_args={},
|
|
||||||
platform=DOMAIN,
|
|
||||||
context=user_input.context,
|
|
||||||
user_prompt=user_input.text,
|
|
||||||
language=user_input.language,
|
|
||||||
assistant=conversation.DOMAIN,
|
|
||||||
device_id=user_input.device_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
api_prompt = llm.async_render_no_api_prompt(self.hass)
|
api_prompt = llm.async_render_no_api_prompt(self.hass)
|
||||||
|
|
||||||
|
@ -309,12 +306,6 @@ class GoogleGenerativeAIConversationEntity(
|
||||||
tool_input = llm.ToolInput(
|
tool_input = llm.ToolInput(
|
||||||
tool_name=tool_call.name,
|
tool_name=tool_call.name,
|
||||||
tool_args=dict(tool_call.args),
|
tool_args=dict(tool_call.args),
|
||||||
platform=DOMAIN,
|
|
||||||
context=user_input.context,
|
|
||||||
user_prompt=user_input.text,
|
|
||||||
language=user_input.language,
|
|
||||||
assistant=conversation.DOMAIN,
|
|
||||||
device_id=user_input.device_id,
|
|
||||||
)
|
)
|
||||||
LOGGER.debug(
|
LOGGER.debug(
|
||||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||||
|
|
|
@ -50,6 +50,7 @@ from .timers import (
|
||||||
TimerManager,
|
TimerManager,
|
||||||
TimerStatusIntentHandler,
|
TimerStatusIntentHandler,
|
||||||
UnpauseTimerIntentHandler,
|
UnpauseTimerIntentHandler,
|
||||||
|
async_device_supports_timers,
|
||||||
async_register_timer_handler,
|
async_register_timer_handler,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -59,6 +60,7 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"async_register_timer_handler",
|
"async_register_timer_handler",
|
||||||
|
"async_device_supports_timers",
|
||||||
"TimerInfo",
|
"TimerInfo",
|
||||||
"TimerEventType",
|
"TimerEventType",
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
|
|
|
@ -415,6 +415,15 @@ class TimerManager:
|
||||||
return device_id in self.handlers
|
return device_id in self.handlers
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_device_supports_timers(hass: HomeAssistant, device_id: str) -> bool:
|
||||||
|
"""Return True if device has been registered to handle timer events."""
|
||||||
|
timer_manager: TimerManager | None = hass.data.get(TIMER_DATA)
|
||||||
|
if timer_manager is None:
|
||||||
|
return False
|
||||||
|
return timer_manager.is_timer_device(device_id)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_register_timer_handler(
|
def async_register_timer_handler(
|
||||||
hass: HomeAssistant, device_id: str, handler: TimerHandler
|
hass: HomeAssistant, device_id: str, handler: TimerHandler
|
||||||
|
|
|
@ -99,12 +99,23 @@ class OpenAIConversationEntity(
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
options = self.entry.options
|
options = self.entry.options
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
intent_response = intent.IntentResponse(language=user_input.language)
|
||||||
llm_api: llm.API | None = None
|
llm_api: llm.APIInstance | None = None
|
||||||
tools: list[dict[str, Any]] | None = None
|
tools: list[dict[str, Any]] | None = None
|
||||||
|
|
||||||
if options.get(CONF_LLM_HASS_API):
|
if options.get(CONF_LLM_HASS_API):
|
||||||
try:
|
try:
|
||||||
llm_api = llm.async_get_api(self.hass, options[CONF_LLM_HASS_API])
|
llm_api = await llm.async_get_api(
|
||||||
|
self.hass,
|
||||||
|
options[CONF_LLM_HASS_API],
|
||||||
|
llm.ToolContext(
|
||||||
|
platform=DOMAIN,
|
||||||
|
context=user_input.context,
|
||||||
|
user_prompt=user_input.text,
|
||||||
|
language=user_input.language,
|
||||||
|
assistant=conversation.DOMAIN,
|
||||||
|
device_id=user_input.device_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
except HomeAssistantError as err:
|
except HomeAssistantError as err:
|
||||||
LOGGER.error("Error getting LLM API: %s", err)
|
LOGGER.error("Error getting LLM API: %s", err)
|
||||||
intent_response.async_set_error(
|
intent_response.async_set_error(
|
||||||
|
@ -114,7 +125,7 @@ class OpenAIConversationEntity(
|
||||||
return conversation.ConversationResult(
|
return conversation.ConversationResult(
|
||||||
response=intent_response, conversation_id=user_input.conversation_id
|
response=intent_response, conversation_id=user_input.conversation_id
|
||||||
)
|
)
|
||||||
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
|
tools = [_format_tool(tool) for tool in llm_api.tools]
|
||||||
|
|
||||||
if user_input.conversation_id in self.history:
|
if user_input.conversation_id in self.history:
|
||||||
conversation_id = user_input.conversation_id
|
conversation_id = user_input.conversation_id
|
||||||
|
@ -123,19 +134,7 @@ class OpenAIConversationEntity(
|
||||||
conversation_id = ulid.ulid_now()
|
conversation_id = ulid.ulid_now()
|
||||||
try:
|
try:
|
||||||
if llm_api:
|
if llm_api:
|
||||||
empty_tool_input = llm.ToolInput(
|
api_prompt = llm_api.api_prompt
|
||||||
tool_name="",
|
|
||||||
tool_args={},
|
|
||||||
platform=DOMAIN,
|
|
||||||
context=user_input.context,
|
|
||||||
user_prompt=user_input.text,
|
|
||||||
language=user_input.language,
|
|
||||||
assistant=conversation.DOMAIN,
|
|
||||||
device_id=user_input.device_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
api_prompt = llm.async_render_no_api_prompt(self.hass)
|
api_prompt = llm.async_render_no_api_prompt(self.hass)
|
||||||
|
|
||||||
|
@ -182,7 +181,7 @@ class OpenAIConversationEntity(
|
||||||
result = await client.chat.completions.create(
|
result = await client.chat.completions.create(
|
||||||
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools,
|
tools=tools or None,
|
||||||
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
||||||
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||||
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||||
|
@ -210,12 +209,6 @@ class OpenAIConversationEntity(
|
||||||
tool_input = llm.ToolInput(
|
tool_input = llm.ToolInput(
|
||||||
tool_name=tool_call.function.name,
|
tool_name=tool_call.function.name,
|
||||||
tool_args=json.loads(tool_call.function.arguments),
|
tool_args=json.loads(tool_call.function.arguments),
|
||||||
platform=DOMAIN,
|
|
||||||
context=user_input.context,
|
|
||||||
user_prompt=user_input.text,
|
|
||||||
language=user_input.language,
|
|
||||||
assistant=conversation.DOMAIN,
|
|
||||||
device_id=user_input.device_id,
|
|
||||||
)
|
)
|
||||||
LOGGER.debug(
|
LOGGER.debug(
|
||||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import asdict, dataclass, replace
|
from dataclasses import asdict, dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ from homeassistant.components.conversation.trace import (
|
||||||
async_conversation_trace_append,
|
async_conversation_trace_append,
|
||||||
)
|
)
|
||||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||||
|
from homeassistant.components.intent import async_device_supports_timers
|
||||||
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
|
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
|
||||||
from homeassistant.core import Context, HomeAssistant, callback
|
from homeassistant.core import Context, HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
@ -68,15 +69,16 @@ def async_register_api(hass: HomeAssistant, api: API) -> None:
|
||||||
apis[api.id] = api
|
apis[api.id] = api
|
||||||
|
|
||||||
|
|
||||||
@callback
|
async def async_get_api(
|
||||||
def async_get_api(hass: HomeAssistant, api_id: str) -> API:
|
hass: HomeAssistant, api_id: str, tool_context: ToolContext
|
||||||
|
) -> APIInstance:
|
||||||
"""Get an API."""
|
"""Get an API."""
|
||||||
apis = _async_get_apis(hass)
|
apis = _async_get_apis(hass)
|
||||||
|
|
||||||
if api_id not in apis:
|
if api_id not in apis:
|
||||||
raise HomeAssistantError(f"API {api_id} not found")
|
raise HomeAssistantError(f"API {api_id} not found")
|
||||||
|
|
||||||
return apis[api_id]
|
return await apis[api_id].async_get_api_instance(tool_context)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
|
@ -86,11 +88,9 @@ def async_get_apis(hass: HomeAssistant) -> list[API]:
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class ToolInput(ABC):
|
class ToolContext:
|
||||||
"""Tool input to be processed."""
|
"""Tool input to be processed."""
|
||||||
|
|
||||||
tool_name: str
|
|
||||||
tool_args: dict[str, Any]
|
|
||||||
platform: str
|
platform: str
|
||||||
context: Context | None
|
context: Context | None
|
||||||
user_prompt: str | None
|
user_prompt: str | None
|
||||||
|
@ -99,6 +99,14 @@ class ToolInput(ABC):
|
||||||
device_id: str | None
|
device_id: str | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ToolInput:
|
||||||
|
"""Tool input to be processed."""
|
||||||
|
|
||||||
|
tool_name: str
|
||||||
|
tool_args: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class Tool:
|
class Tool:
|
||||||
"""LLM Tool base class."""
|
"""LLM Tool base class."""
|
||||||
|
|
||||||
|
@ -108,7 +116,7 @@ class Tool:
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def async_call(
|
async def async_call(
|
||||||
self, hass: HomeAssistant, tool_input: ToolInput
|
self, hass: HomeAssistant, tool_input: ToolInput, tool_context: ToolContext
|
||||||
) -> JsonObjectType:
|
) -> JsonObjectType:
|
||||||
"""Call the tool."""
|
"""Call the tool."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -118,6 +126,30 @@ class Tool:
|
||||||
return f"<{self.__class__.__name__} - {self.name}>"
|
return f"<{self.__class__.__name__} - {self.name}>"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class APIInstance:
|
||||||
|
"""Instance of an API to be used by an LLM."""
|
||||||
|
|
||||||
|
api: API
|
||||||
|
api_prompt: str
|
||||||
|
tool_context: ToolContext
|
||||||
|
tools: list[Tool]
|
||||||
|
|
||||||
|
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
|
||||||
|
"""Call a LLM tool, validate args and return the response."""
|
||||||
|
async_conversation_trace_append(
|
||||||
|
ConversationTraceEventType.LLM_TOOL_CALL, asdict(tool_input)
|
||||||
|
)
|
||||||
|
|
||||||
|
for tool in self.tools:
|
||||||
|
if tool.name == tool_input.tool_name:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
|
||||||
|
|
||||||
|
return await tool.async_call(self.api.hass, tool_input, self.tool_context)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True, kw_only=True)
|
@dataclass(slots=True, kw_only=True)
|
||||||
class API(ABC):
|
class API(ABC):
|
||||||
"""An API to expose to LLMs."""
|
"""An API to expose to LLMs."""
|
||||||
|
@ -127,38 +159,10 @@ class API(ABC):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
|
async def async_get_api_instance(self, tool_context: ToolContext) -> APIInstance:
|
||||||
"""Return the prompt for the API."""
|
"""Return the instance of the API."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
@callback
|
|
||||||
def async_get_tools(self) -> list[Tool]:
|
|
||||||
"""Return a list of tools."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
|
|
||||||
"""Call a LLM tool, validate args and return the response."""
|
|
||||||
async_conversation_trace_append(
|
|
||||||
ConversationTraceEventType.LLM_TOOL_CALL, asdict(tool_input)
|
|
||||||
)
|
|
||||||
|
|
||||||
for tool in self.async_get_tools():
|
|
||||||
if tool.name == tool_input.tool_name:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
|
|
||||||
|
|
||||||
return await tool.async_call(
|
|
||||||
self.hass,
|
|
||||||
replace(
|
|
||||||
tool_input,
|
|
||||||
tool_name=tool.name,
|
|
||||||
tool_args=tool.parameters(tool_input.tool_args),
|
|
||||||
context=tool_input.context or Context(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class IntentTool(Tool):
|
class IntentTool(Tool):
|
||||||
"""LLM Tool representing an Intent."""
|
"""LLM Tool representing an Intent."""
|
||||||
|
@ -176,21 +180,20 @@ class IntentTool(Tool):
|
||||||
self.parameters = vol.Schema(slot_schema)
|
self.parameters = vol.Schema(slot_schema)
|
||||||
|
|
||||||
async def async_call(
|
async def async_call(
|
||||||
self, hass: HomeAssistant, tool_input: ToolInput
|
self, hass: HomeAssistant, tool_input: ToolInput, tool_context: ToolContext
|
||||||
) -> JsonObjectType:
|
) -> JsonObjectType:
|
||||||
"""Handle the intent."""
|
"""Handle the intent."""
|
||||||
slots = {key: {"value": val} for key, val in tool_input.tool_args.items()}
|
slots = {key: {"value": val} for key, val in tool_input.tool_args.items()}
|
||||||
|
|
||||||
intent_response = await intent.async_handle(
|
intent_response = await intent.async_handle(
|
||||||
hass,
|
hass=hass,
|
||||||
tool_input.platform,
|
platform=tool_context.platform,
|
||||||
self.name,
|
intent_type=self.name,
|
||||||
slots,
|
slots=slots,
|
||||||
tool_input.user_prompt,
|
text_input=tool_context.user_prompt,
|
||||||
tool_input.context,
|
context=tool_context.context,
|
||||||
tool_input.language,
|
language=tool_context.language,
|
||||||
tool_input.assistant,
|
assistant=tool_context.assistant,
|
||||||
tool_input.device_id,
|
device_id=tool_context.device_id,
|
||||||
)
|
)
|
||||||
return intent_response.as_dict()
|
return intent_response.as_dict()
|
||||||
|
|
||||||
|
@ -213,15 +216,26 @@ class AssistAPI(API):
|
||||||
name="Assist",
|
name="Assist",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
|
async def async_get_api_instance(self, tool_context: ToolContext) -> APIInstance:
|
||||||
"""Return the prompt for the API."""
|
"""Return the instance of the API."""
|
||||||
if tool_input.assistant:
|
if tool_context.assistant:
|
||||||
exposed_entities: dict | None = _get_exposed_entities(
|
exposed_entities: dict | None = _get_exposed_entities(
|
||||||
self.hass, tool_input.assistant
|
self.hass, tool_context.assistant
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
exposed_entities = None
|
exposed_entities = None
|
||||||
|
|
||||||
|
return APIInstance(
|
||||||
|
api=self,
|
||||||
|
api_prompt=await self._async_get_api_prompt(tool_context, exposed_entities),
|
||||||
|
tool_context=tool_context,
|
||||||
|
tools=self._async_get_tools(tool_context, exposed_entities),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _async_get_api_prompt(
|
||||||
|
self, tool_context: ToolContext, exposed_entities: dict | None
|
||||||
|
) -> str:
|
||||||
|
"""Return the prompt for the API."""
|
||||||
if not exposed_entities:
|
if not exposed_entities:
|
||||||
return (
|
return (
|
||||||
"Only if the user wants to control a device, tell them to expose entities "
|
"Only if the user wants to control a device, tell them to expose entities "
|
||||||
|
@ -236,9 +250,9 @@ class AssistAPI(API):
|
||||||
]
|
]
|
||||||
area: ar.AreaEntry | None = None
|
area: ar.AreaEntry | None = None
|
||||||
floor: fr.FloorEntry | None = None
|
floor: fr.FloorEntry | None = None
|
||||||
if tool_input.device_id:
|
if tool_context.device_id:
|
||||||
device_reg = dr.async_get(self.hass)
|
device_reg = dr.async_get(self.hass)
|
||||||
device = device_reg.async_get(tool_input.device_id)
|
device = device_reg.async_get(tool_context.device_id)
|
||||||
|
|
||||||
if device:
|
if device:
|
||||||
area_reg = ar.async_get(self.hass)
|
area_reg = ar.async_get(self.hass)
|
||||||
|
@ -259,11 +273,16 @@ class AssistAPI(API):
|
||||||
"don't know in what area this conversation is happening."
|
"don't know in what area this conversation is happening."
|
||||||
)
|
)
|
||||||
|
|
||||||
if tool_input.context and tool_input.context.user_id:
|
if tool_context.context and tool_context.context.user_id:
|
||||||
user = await self.hass.auth.async_get_user(tool_input.context.user_id)
|
user = await self.hass.auth.async_get_user(tool_context.context.user_id)
|
||||||
if user:
|
if user:
|
||||||
prompt.append(f"The user name is {user.name}.")
|
prompt.append(f"The user name is {user.name}.")
|
||||||
|
|
||||||
|
if not tool_context.device_id or not async_device_supports_timers(
|
||||||
|
self.hass, tool_context.device_id
|
||||||
|
):
|
||||||
|
prompt.append("This device does not support timers.")
|
||||||
|
|
||||||
if exposed_entities:
|
if exposed_entities:
|
||||||
prompt.append(
|
prompt.append(
|
||||||
"An overview of the areas and the devices in this smart home:"
|
"An overview of the areas and the devices in this smart home:"
|
||||||
|
@ -273,14 +292,44 @@ class AssistAPI(API):
|
||||||
return "\n".join(prompt)
|
return "\n".join(prompt)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_tools(self) -> list[Tool]:
|
def _async_get_tools(
|
||||||
|
self, tool_context: ToolContext, exposed_entities: dict | None
|
||||||
|
) -> list[Tool]:
|
||||||
"""Return a list of LLM tools."""
|
"""Return a list of LLM tools."""
|
||||||
return [
|
ignore_intents = self.IGNORE_INTENTS
|
||||||
IntentTool(intent_handler)
|
if not tool_context.device_id or not async_device_supports_timers(
|
||||||
|
self.hass, tool_context.device_id
|
||||||
|
):
|
||||||
|
ignore_intents = ignore_intents | {
|
||||||
|
intent.INTENT_START_TIMER,
|
||||||
|
intent.INTENT_CANCEL_TIMER,
|
||||||
|
intent.INTENT_INCREASE_TIMER,
|
||||||
|
intent.INTENT_DECREASE_TIMER,
|
||||||
|
intent.INTENT_PAUSE_TIMER,
|
||||||
|
intent.INTENT_UNPAUSE_TIMER,
|
||||||
|
intent.INTENT_TIMER_STATUS,
|
||||||
|
}
|
||||||
|
|
||||||
|
intent_handlers = [
|
||||||
|
intent_handler
|
||||||
for intent_handler in intent.async_get(self.hass)
|
for intent_handler in intent.async_get(self.hass)
|
||||||
if intent_handler.intent_type not in self.IGNORE_INTENTS
|
if intent_handler.intent_type not in ignore_intents
|
||||||
]
|
]
|
||||||
|
|
||||||
|
exposed_domains: set[str] | None = None
|
||||||
|
if exposed_entities is not None:
|
||||||
|
exposed_domains = {
|
||||||
|
entity_id.split(".")[0] for entity_id in exposed_entities
|
||||||
|
}
|
||||||
|
intent_handlers = [
|
||||||
|
intent_handler
|
||||||
|
for intent_handler in intent_handlers
|
||||||
|
if intent_handler.platforms is None
|
||||||
|
or intent_handler.platforms & exposed_domains
|
||||||
|
]
|
||||||
|
|
||||||
|
return [IntentTool(intent_handler) for intent_handler in intent_handlers]
|
||||||
|
|
||||||
|
|
||||||
def _get_exposed_entities(
|
def _get_exposed_entities(
|
||||||
hass: HomeAssistant, assistant: str
|
hass: HomeAssistant, assistant: str
|
||||||
|
|
|
@ -61,11 +61,11 @@ async def test_default_prompt(
|
||||||
with (
|
with (
|
||||||
patch("google.generativeai.GenerativeModel") as mock_model,
|
patch("google.generativeai.GenerativeModel") as mock_model,
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools",
|
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools",
|
||||||
return_value=[],
|
return_value=[],
|
||||||
) as mock_get_tools,
|
) as mock_get_tools,
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_api_prompt",
|
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_api_prompt",
|
||||||
return_value="<api_prompt>",
|
return_value="<api_prompt>",
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
|
@ -148,7 +148,7 @@ async def test_chat_history(
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
|
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
||||||
)
|
)
|
||||||
async def test_function_call(
|
async def test_function_call(
|
||||||
mock_get_tools,
|
mock_get_tools,
|
||||||
|
@ -182,7 +182,7 @@ async def test_function_call(
|
||||||
mock_part.function_call.name = "test_tool"
|
mock_part.function_call.name = "test_tool"
|
||||||
mock_part.function_call.args = {"param1": ["test_value"]}
|
mock_part.function_call.args = {"param1": ["test_value"]}
|
||||||
|
|
||||||
def tool_call(hass, tool_input):
|
def tool_call(hass, tool_input, tool_context):
|
||||||
mock_part.function_call = None
|
mock_part.function_call = None
|
||||||
mock_part.text = "Hi there!"
|
mock_part.text = "Hi there!"
|
||||||
return {"result": "Test response"}
|
return {"result": "Test response"}
|
||||||
|
@ -221,6 +221,8 @@ async def test_function_call(
|
||||||
llm.ToolInput(
|
llm.ToolInput(
|
||||||
tool_name="test_tool",
|
tool_name="test_tool",
|
||||||
tool_args={"param1": ["test_value"]},
|
tool_args={"param1": ["test_value"]},
|
||||||
|
),
|
||||||
|
llm.ToolContext(
|
||||||
platform="google_generative_ai_conversation",
|
platform="google_generative_ai_conversation",
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt="Please call the test function",
|
user_prompt="Please call the test function",
|
||||||
|
@ -246,7 +248,7 @@ async def test_function_call(
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
|
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
||||||
)
|
)
|
||||||
async def test_function_exception(
|
async def test_function_exception(
|
||||||
mock_get_tools,
|
mock_get_tools,
|
||||||
|
@ -280,7 +282,7 @@ async def test_function_exception(
|
||||||
mock_part.function_call.name = "test_tool"
|
mock_part.function_call.name = "test_tool"
|
||||||
mock_part.function_call.args = {"param1": 1}
|
mock_part.function_call.args = {"param1": 1}
|
||||||
|
|
||||||
def tool_call(hass, tool_input):
|
def tool_call(hass, tool_input, tool_context):
|
||||||
mock_part.function_call = None
|
mock_part.function_call = None
|
||||||
mock_part.text = "Hi there!"
|
mock_part.text = "Hi there!"
|
||||||
raise HomeAssistantError("Test tool exception")
|
raise HomeAssistantError("Test tool exception")
|
||||||
|
@ -319,6 +321,8 @@ async def test_function_exception(
|
||||||
llm.ToolInput(
|
llm.ToolInput(
|
||||||
tool_name="test_tool",
|
tool_name="test_tool",
|
||||||
tool_args={"param1": 1},
|
tool_args={"param1": 1},
|
||||||
|
),
|
||||||
|
llm.ToolContext(
|
||||||
platform="google_generative_ai_conversation",
|
platform="google_generative_ai_conversation",
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt="Please call the test function",
|
user_prompt="Please call the test function",
|
||||||
|
|
|
@ -86,7 +86,7 @@ async def test_conversation_agent(
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools"
|
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
||||||
)
|
)
|
||||||
async def test_function_call(
|
async def test_function_call(
|
||||||
mock_get_tools,
|
mock_get_tools,
|
||||||
|
@ -192,6 +192,8 @@ async def test_function_call(
|
||||||
llm.ToolInput(
|
llm.ToolInput(
|
||||||
tool_name="test_tool",
|
tool_name="test_tool",
|
||||||
tool_args={"param1": "test_value"},
|
tool_args={"param1": "test_value"},
|
||||||
|
),
|
||||||
|
llm.ToolContext(
|
||||||
platform="openai_conversation",
|
platform="openai_conversation",
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt="Please call the test function",
|
user_prompt="Please call the test function",
|
||||||
|
@ -217,7 +219,7 @@ async def test_function_call(
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools"
|
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
||||||
)
|
)
|
||||||
async def test_function_exception(
|
async def test_function_exception(
|
||||||
mock_get_tools,
|
mock_get_tools,
|
||||||
|
@ -323,6 +325,8 @@ async def test_function_exception(
|
||||||
llm.ToolInput(
|
llm.ToolInput(
|
||||||
tool_name="test_tool",
|
tool_name="test_tool",
|
||||||
tool_args={"param1": "test_value"},
|
tool_args={"param1": "test_value"},
|
||||||
|
),
|
||||||
|
llm.ToolContext(
|
||||||
platform="openai_conversation",
|
platform="openai_conversation",
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt="Please call the test function",
|
user_prompt="Please call the test function",
|
||||||
|
|
|
@ -5,6 +5,7 @@ from unittest.mock import Mock, patch
|
||||||
import pytest
|
import pytest
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components.intent import async_register_timer_handler
|
||||||
from homeassistant.core import Context, HomeAssistant, State
|
from homeassistant.core import Context, HomeAssistant, State
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
|
@ -22,53 +23,84 @@ from homeassistant.util import yaml
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
|
||||||
async def test_get_api_no_existing(hass: HomeAssistant) -> None:
|
@pytest.fixture
|
||||||
|
def tool_input_context() -> llm.ToolContext:
|
||||||
|
"""Return tool input context."""
|
||||||
|
return llm.ToolContext(
|
||||||
|
platform="",
|
||||||
|
context=None,
|
||||||
|
user_prompt=None,
|
||||||
|
language=None,
|
||||||
|
assistant=None,
|
||||||
|
device_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_api_no_existing(
|
||||||
|
hass: HomeAssistant, tool_input_context: llm.ToolContext
|
||||||
|
) -> None:
|
||||||
"""Test getting an llm api where no config exists."""
|
"""Test getting an llm api where no config exists."""
|
||||||
with pytest.raises(HomeAssistantError):
|
with pytest.raises(HomeAssistantError):
|
||||||
llm.async_get_api(hass, "non-existing")
|
await llm.async_get_api(hass, "non-existing", tool_input_context)
|
||||||
|
|
||||||
|
|
||||||
async def test_register_api(hass: HomeAssistant) -> None:
|
async def test_register_api(
|
||||||
|
hass: HomeAssistant, tool_input_context: llm.ToolContext
|
||||||
|
) -> None:
|
||||||
"""Test registering an llm api."""
|
"""Test registering an llm api."""
|
||||||
|
|
||||||
class MyAPI(llm.API):
|
class MyAPI(llm.API):
|
||||||
async def async_get_api_prompt(self, tool_input: llm.ToolInput) -> str:
|
async def async_get_api_instance(
|
||||||
"""Return a prompt for the tool."""
|
self, tool_input: llm.ToolInput
|
||||||
return ""
|
) -> llm.APIInstance:
|
||||||
|
|
||||||
def async_get_tools(self) -> list[llm.Tool]:
|
|
||||||
"""Return a list of tools."""
|
"""Return a list of tools."""
|
||||||
return []
|
return llm.APIInstance(self, "", [], tool_input_context)
|
||||||
|
|
||||||
api = MyAPI(hass=hass, id="test", name="Test")
|
api = MyAPI(hass=hass, id="test", name="Test")
|
||||||
llm.async_register_api(hass, api)
|
llm.async_register_api(hass, api)
|
||||||
|
|
||||||
assert llm.async_get_api(hass, "test") is api
|
instance = await llm.async_get_api(hass, "test", tool_input_context)
|
||||||
|
assert instance.api is api
|
||||||
assert api in llm.async_get_apis(hass)
|
assert api in llm.async_get_apis(hass)
|
||||||
|
|
||||||
with pytest.raises(HomeAssistantError):
|
with pytest.raises(HomeAssistantError):
|
||||||
llm.async_register_api(hass, api)
|
llm.async_register_api(hass, api)
|
||||||
|
|
||||||
|
|
||||||
async def test_call_tool_no_existing(hass: HomeAssistant) -> None:
|
async def test_call_tool_no_existing(
|
||||||
|
hass: HomeAssistant, tool_input_context: llm.ToolContext
|
||||||
|
) -> None:
|
||||||
"""Test calling an llm tool where no config exists."""
|
"""Test calling an llm tool where no config exists."""
|
||||||
|
instance = await llm.async_get_api(hass, "assist", tool_input_context)
|
||||||
with pytest.raises(HomeAssistantError):
|
with pytest.raises(HomeAssistantError):
|
||||||
await llm.async_get_api(hass, "intent").async_call_tool(
|
await instance.async_call_tool(
|
||||||
llm.ToolInput(
|
llm.ToolInput("test_tool", {}),
|
||||||
"test_tool",
|
|
||||||
{},
|
|
||||||
"test_platform",
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_assist_api(hass: HomeAssistant) -> None:
|
async def test_assist_api(
|
||||||
|
hass: HomeAssistant, entity_registry: er.EntityRegistry
|
||||||
|
) -> None:
|
||||||
"""Test Assist API."""
|
"""Test Assist API."""
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
|
||||||
|
entity_registry.async_get_or_create(
|
||||||
|
"light",
|
||||||
|
"kitchen",
|
||||||
|
"mock-id-kitchen",
|
||||||
|
original_name="Kitchen",
|
||||||
|
suggested_object_id="kitchen",
|
||||||
|
).write_unavailable_state(hass)
|
||||||
|
|
||||||
|
test_context = Context()
|
||||||
|
tool_context = llm.ToolContext(
|
||||||
|
platform="test_platform",
|
||||||
|
context=test_context,
|
||||||
|
user_prompt="test_text",
|
||||||
|
language="*",
|
||||||
|
assistant="conversation",
|
||||||
|
device_id="test_device",
|
||||||
|
)
|
||||||
schema = {
|
schema = {
|
||||||
vol.Optional("area"): cv.string,
|
vol.Optional("area"): cv.string,
|
||||||
vol.Optional("floor"): cv.string,
|
vol.Optional("floor"): cv.string,
|
||||||
|
@ -77,22 +109,33 @@ async def test_assist_api(hass: HomeAssistant) -> None:
|
||||||
class MyIntentHandler(intent.IntentHandler):
|
class MyIntentHandler(intent.IntentHandler):
|
||||||
intent_type = "test_intent"
|
intent_type = "test_intent"
|
||||||
slot_schema = schema
|
slot_schema = schema
|
||||||
|
platforms = set() # Match none
|
||||||
|
|
||||||
intent_handler = MyIntentHandler()
|
intent_handler = MyIntentHandler()
|
||||||
|
|
||||||
intent.async_register(hass, intent_handler)
|
intent.async_register(hass, intent_handler)
|
||||||
|
|
||||||
assert len(llm.async_get_apis(hass)) == 1
|
assert len(llm.async_get_apis(hass)) == 1
|
||||||
api = llm.async_get_api(hass, "assist")
|
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||||
tools = api.async_get_tools()
|
assert len(api.tools) == 0
|
||||||
assert len(tools) == 1
|
|
||||||
tool = tools[0]
|
# Match all
|
||||||
|
intent_handler.platforms = None
|
||||||
|
|
||||||
|
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||||
|
assert len(api.tools) == 1
|
||||||
|
|
||||||
|
# Match specific domain
|
||||||
|
intent_handler.platforms = {"light"}
|
||||||
|
|
||||||
|
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||||
|
assert len(api.tools) == 1
|
||||||
|
tool = api.tools[0]
|
||||||
assert tool.name == "test_intent"
|
assert tool.name == "test_intent"
|
||||||
assert tool.description == "Execute Home Assistant test_intent intent"
|
assert tool.description == "Execute Home Assistant test_intent intent"
|
||||||
assert tool.parameters == vol.Schema(intent_handler.slot_schema)
|
assert tool.parameters == vol.Schema(intent_handler.slot_schema)
|
||||||
assert str(tool) == "<IntentTool - test_intent>"
|
assert str(tool) == "<IntentTool - test_intent>"
|
||||||
|
|
||||||
test_context = Context()
|
|
||||||
assert test_context.json_fragment # To reproduce an error case in tracing
|
assert test_context.json_fragment # To reproduce an error case in tracing
|
||||||
intent_response = intent.IntentResponse("*")
|
intent_response = intent.IntentResponse("*")
|
||||||
intent_response.matched_states = [State("light.matched", "on")]
|
intent_response.matched_states = [State("light.matched", "on")]
|
||||||
|
@ -100,12 +143,6 @@ async def test_assist_api(hass: HomeAssistant) -> None:
|
||||||
tool_input = llm.ToolInput(
|
tool_input = llm.ToolInput(
|
||||||
tool_name="test_intent",
|
tool_name="test_intent",
|
||||||
tool_args={"area": "kitchen", "floor": "ground_floor"},
|
tool_args={"area": "kitchen", "floor": "ground_floor"},
|
||||||
platform="test_platform",
|
|
||||||
context=test_context,
|
|
||||||
user_prompt="test_text",
|
|
||||||
language="*",
|
|
||||||
assistant="test_assistant",
|
|
||||||
device_id="test_device",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
|
@ -114,18 +151,18 @@ async def test_assist_api(hass: HomeAssistant) -> None:
|
||||||
response = await api.async_call_tool(tool_input)
|
response = await api.async_call_tool(tool_input)
|
||||||
|
|
||||||
mock_intent_handle.assert_awaited_once_with(
|
mock_intent_handle.assert_awaited_once_with(
|
||||||
hass,
|
hass=hass,
|
||||||
"test_platform",
|
platform="test_platform",
|
||||||
"test_intent",
|
intent_type="test_intent",
|
||||||
{
|
slots={
|
||||||
"area": {"value": "kitchen"},
|
"area": {"value": "kitchen"},
|
||||||
"floor": {"value": "ground_floor"},
|
"floor": {"value": "ground_floor"},
|
||||||
},
|
},
|
||||||
"test_text",
|
text_input="test_text",
|
||||||
test_context,
|
context=test_context,
|
||||||
"*",
|
language="*",
|
||||||
"test_assistant",
|
assistant="conversation",
|
||||||
"test_device",
|
device_id="test_device",
|
||||||
)
|
)
|
||||||
assert response == {
|
assert response == {
|
||||||
"card": {},
|
"card": {},
|
||||||
|
@ -140,7 +177,27 @@ async def test_assist_api(hass: HomeAssistant) -> None:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_assist_api_description(hass: HomeAssistant) -> None:
|
async def test_assist_api_get_timer_tools(
|
||||||
|
hass: HomeAssistant, tool_input_context: llm.ToolContext
|
||||||
|
) -> None:
|
||||||
|
"""Test getting timer tools with Assist API."""
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
assert await async_setup_component(hass, "intent", {})
|
||||||
|
api = await llm.async_get_api(hass, "assist", tool_input_context)
|
||||||
|
|
||||||
|
assert "HassStartTimer" not in [tool.name for tool in api.tools]
|
||||||
|
|
||||||
|
tool_input_context.device_id = "test_device"
|
||||||
|
|
||||||
|
async_register_timer_handler(hass, "test_device", lambda *args: None)
|
||||||
|
|
||||||
|
api = await llm.async_get_api(hass, "assist", tool_input_context)
|
||||||
|
assert "HassStartTimer" in [tool.name for tool in api.tools]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_assist_api_description(
|
||||||
|
hass: HomeAssistant, tool_input_context: llm.ToolContext
|
||||||
|
) -> None:
|
||||||
"""Test intent description with Assist API."""
|
"""Test intent description with Assist API."""
|
||||||
|
|
||||||
class MyIntentHandler(intent.IntentHandler):
|
class MyIntentHandler(intent.IntentHandler):
|
||||||
|
@ -150,10 +207,9 @@ async def test_assist_api_description(hass: HomeAssistant) -> None:
|
||||||
intent.async_register(hass, MyIntentHandler())
|
intent.async_register(hass, MyIntentHandler())
|
||||||
|
|
||||||
assert len(llm.async_get_apis(hass)) == 1
|
assert len(llm.async_get_apis(hass)) == 1
|
||||||
api = llm.async_get_api(hass, "assist")
|
api = await llm.async_get_api(hass, "assist", tool_input_context)
|
||||||
tools = api.async_get_tools()
|
assert len(api.tools) == 1
|
||||||
assert len(tools) == 1
|
tool = api.tools[0]
|
||||||
tool = tools[0]
|
|
||||||
assert tool.name == "test_intent"
|
assert tool.name == "test_intent"
|
||||||
assert tool.description == "my intent handler"
|
assert tool.description == "my intent handler"
|
||||||
|
|
||||||
|
@ -167,20 +223,18 @@ async def test_assist_api_prompt(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test prompt for the assist API."""
|
"""Test prompt for the assist API."""
|
||||||
assert await async_setup_component(hass, "homeassistant", {})
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
assert await async_setup_component(hass, "intent", {})
|
||||||
context = Context()
|
context = Context()
|
||||||
tool_input = llm.ToolInput(
|
tool_context = llm.ToolContext(
|
||||||
tool_name=None,
|
|
||||||
tool_args=None,
|
|
||||||
platform="test_platform",
|
platform="test_platform",
|
||||||
context=context,
|
context=context,
|
||||||
user_prompt="test_text",
|
user_prompt="test_text",
|
||||||
language="*",
|
language="*",
|
||||||
assistant="conversation",
|
assistant="conversation",
|
||||||
device_id="test_device",
|
device_id=None,
|
||||||
)
|
)
|
||||||
api = llm.async_get_api(hass, "assist")
|
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||||
prompt = await api.async_get_api_prompt(tool_input)
|
assert api.api_prompt == (
|
||||||
assert prompt == (
|
|
||||||
"Only if the user wants to control a device, tell them to expose entities to their "
|
"Only if the user wants to control a device, tell them to expose entities to their "
|
||||||
"voice assistant in Home Assistant."
|
"voice assistant in Home Assistant."
|
||||||
)
|
)
|
||||||
|
@ -308,7 +362,7 @@ async def test_assist_api_prompt(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
exposed_entities = llm._get_exposed_entities(hass, tool_input.assistant)
|
exposed_entities = llm._get_exposed_entities(hass, tool_context.assistant)
|
||||||
assert exposed_entities == {
|
assert exposed_entities == {
|
||||||
"light.1": {
|
"light.1": {
|
||||||
"areas": "Test Area 2",
|
"areas": "Test Area 2",
|
||||||
|
@ -373,40 +427,55 @@ async def test_assist_api_prompt(
|
||||||
"Call the intent tools to control Home Assistant. "
|
"Call the intent tools to control Home Assistant. "
|
||||||
"When controlling an area, prefer passing area name."
|
"When controlling an area, prefer passing area name."
|
||||||
)
|
)
|
||||||
|
no_timer_prompt = "This device does not support timers."
|
||||||
|
|
||||||
prompt = await api.async_get_api_prompt(tool_input)
|
|
||||||
area_prompt = (
|
area_prompt = (
|
||||||
"Reject all generic commands like 'turn on the lights' because we don't know in what area "
|
"Reject all generic commands like 'turn on the lights' because we don't know in what area "
|
||||||
"this conversation is happening."
|
"this conversation is happening."
|
||||||
)
|
)
|
||||||
assert prompt == (
|
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||||
|
assert api.api_prompt == (
|
||||||
f"""{first_part_prompt}
|
f"""{first_part_prompt}
|
||||||
{area_prompt}
|
{area_prompt}
|
||||||
|
{no_timer_prompt}
|
||||||
{exposed_entities_prompt}"""
|
{exposed_entities_prompt}"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fake that request is made from a specific device ID
|
# Fake that request is made from a specific device ID with an area
|
||||||
tool_input.device_id = device.id
|
tool_context.device_id = device.id
|
||||||
prompt = await api.async_get_api_prompt(tool_input)
|
|
||||||
area_prompt = (
|
area_prompt = (
|
||||||
"You are in area Test Area and all generic commands like 'turn on the lights' "
|
"You are in area Test Area and all generic commands like 'turn on the lights' "
|
||||||
"should target this area."
|
"should target this area."
|
||||||
)
|
)
|
||||||
assert prompt == (
|
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||||
|
assert api.api_prompt == (
|
||||||
f"""{first_part_prompt}
|
f"""{first_part_prompt}
|
||||||
{area_prompt}
|
{area_prompt}
|
||||||
|
{no_timer_prompt}
|
||||||
{exposed_entities_prompt}"""
|
{exposed_entities_prompt}"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add floor
|
# Add floor
|
||||||
floor = floor_registry.async_create("2")
|
floor = floor_registry.async_create("2")
|
||||||
area_registry.async_update(area.id, floor_id=floor.floor_id)
|
area_registry.async_update(area.id, floor_id=floor.floor_id)
|
||||||
prompt = await api.async_get_api_prompt(tool_input)
|
|
||||||
area_prompt = (
|
area_prompt = (
|
||||||
"You are in area Test Area (floor 2) and all generic commands like 'turn on the lights' "
|
"You are in area Test Area (floor 2) and all generic commands like 'turn on the lights' "
|
||||||
"should target this area."
|
"should target this area."
|
||||||
)
|
)
|
||||||
assert prompt == (
|
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||||
|
assert api.api_prompt == (
|
||||||
|
f"""{first_part_prompt}
|
||||||
|
{area_prompt}
|
||||||
|
{no_timer_prompt}
|
||||||
|
{exposed_entities_prompt}"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register device for timers
|
||||||
|
async_register_timer_handler(hass, device.id, lambda *args: None)
|
||||||
|
|
||||||
|
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||||
|
# The no_timer_prompt is gone
|
||||||
|
assert api.api_prompt == (
|
||||||
f"""{first_part_prompt}
|
f"""{first_part_prompt}
|
||||||
{area_prompt}
|
{area_prompt}
|
||||||
{exposed_entities_prompt}"""
|
{exposed_entities_prompt}"""
|
||||||
|
@ -418,8 +487,8 @@ async def test_assist_api_prompt(
|
||||||
mock_user.id = "12345"
|
mock_user.id = "12345"
|
||||||
mock_user.name = "Test User"
|
mock_user.name = "Test User"
|
||||||
with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user):
|
with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user):
|
||||||
prompt = await api.async_get_api_prompt(tool_input)
|
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||||
assert prompt == (
|
assert api.api_prompt == (
|
||||||
f"""{first_part_prompt}
|
f"""{first_part_prompt}
|
||||||
{area_prompt}
|
{area_prompt}
|
||||||
The user name is Test User.
|
The user name is Test User.
|
||||||
|
|
Loading…
Add table
Reference in a new issue