LLM Assist API to ignore intents if not needed for exposed entities or calling device (#118283)

* LLM Assist API to ignore timer intents if device doesn't support it

* Refactor to use API instances

* Extract ToolContext class

* Limit exposed intents based on exposed entities
This commit is contained in:
Paulus Schoutsen 2024-05-28 21:29:18 -04:00 committed by GitHub
parent e0264c8604
commit 615a1eda51
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 302 additions and 181 deletions

View file

@ -61,11 +61,11 @@ async def test_default_prompt(
with (
patch("google.generativeai.GenerativeModel") as mock_model,
patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools",
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools",
return_value=[],
) as mock_get_tools,
patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_api_prompt",
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_api_prompt",
return_value="<api_prompt>",
),
patch(
@ -148,7 +148,7 @@ async def test_chat_history(
@patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
)
async def test_function_call(
mock_get_tools,
@ -182,7 +182,7 @@ async def test_function_call(
mock_part.function_call.name = "test_tool"
mock_part.function_call.args = {"param1": ["test_value"]}
def tool_call(hass, tool_input):
def tool_call(hass, tool_input, tool_context):
mock_part.function_call = None
mock_part.text = "Hi there!"
return {"result": "Test response"}
@ -221,6 +221,8 @@ async def test_function_call(
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": ["test_value"]},
),
llm.ToolContext(
platform="google_generative_ai_conversation",
context=context,
user_prompt="Please call the test function",
@ -246,7 +248,7 @@ async def test_function_call(
@patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
)
async def test_function_exception(
mock_get_tools,
@ -280,7 +282,7 @@ async def test_function_exception(
mock_part.function_call.name = "test_tool"
mock_part.function_call.args = {"param1": 1}
def tool_call(hass, tool_input):
def tool_call(hass, tool_input, tool_context):
mock_part.function_call = None
mock_part.text = "Hi there!"
raise HomeAssistantError("Test tool exception")
@ -319,6 +321,8 @@ async def test_function_exception(
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": 1},
),
llm.ToolContext(
platform="google_generative_ai_conversation",
context=context,
user_prompt="Please call the test function",