Add LLM tools support for Ollama (#120454)

* Add LLM tools support for Ollama

* fix tests

* coverage

* Separate call for tool parameters

* Fix example

* hint on parameters schema if LLM forgot to request it

* Switch to native tool call functionality

* Fix tests

* Fix tools list

* update strings and default model

* Ignore mypy error until fixed upstream

* Ignore mypy error until fixed upstream

* Add missing prompt part

* Update default model
This commit is contained in:
Denis Shulyaka 2024-07-29 04:19:53 +03:00 committed by GitHub
parent f98487ef18
commit 4b2073ca59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 465 additions and 255 deletions

View file

@ -18,7 +18,9 @@ from homeassistant.config_entries import (
ConfigFlowResult,
OptionsFlow,
)
from homeassistant.const import CONF_URL
from homeassistant.const import CONF_LLM_HASS_API, CONF_URL
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.helpers.selector import (
NumberSelector,
NumberSelectorConfig,
@ -40,7 +42,6 @@ from .const import (
DEFAULT_KEEP_ALIVE,
DEFAULT_MAX_HISTORY,
DEFAULT_MODEL,
DEFAULT_PROMPT,
DEFAULT_TIMEOUT,
DOMAIN,
MODEL_NAMES,
@ -208,25 +209,52 @@ class OllamaOptionsFlow(OptionsFlow):
) -> ConfigFlowResult:
"""Manage the options."""
if user_input is not None:
if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API)
return self.async_create_entry(
title=_get_title(self.model), data=user_input
)
options = self.config_entry.options or MappingProxyType({})
schema = ollama_config_option_schema(options)
schema = ollama_config_option_schema(self.hass, options)
return self.async_show_form(
step_id="init",
data_schema=vol.Schema(schema),
)
def ollama_config_option_schema(options: MappingProxyType[str, Any]) -> dict:
def ollama_config_option_schema(
hass: HomeAssistant, options: MappingProxyType[str, Any]
) -> dict:
"""Ollama options schema."""
hass_apis: list[SelectOptionDict] = [
SelectOptionDict(
label="No control",
value="none",
)
]
hass_apis.extend(
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
)
return {
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
description={
"suggested_value": options.get(
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
)
},
): TemplateSelector(),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none",
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
vol.Optional(
CONF_MAX_HISTORY,
description={

View file

@ -4,73 +4,6 @@ DOMAIN = "ollama"
CONF_MODEL = "model"
CONF_PROMPT = "prompt"
DEFAULT_PROMPT = """{%- set used_domains = set([
"binary_sensor",
"climate",
"cover",
"fan",
"light",
"lock",
"sensor",
"switch",
"weather",
]) %}
{%- set used_attributes = set([
"temperature",
"current_temperature",
"temperature_unit",
"brightness",
"humidity",
"unit_of_measurement",
"device_class",
"current_position",
"percentage",
]) %}
This smart home is controlled by Home Assistant.
The current time is {{ now().strftime("%X") }}.
Today's date is {{ now().strftime("%x") }}.
An overview of the areas and the devices in this smart home:
```yaml
{%- for entity in exposed_entities: %}
{%- if entity.domain not in used_domains: %}
{%- continue %}
{%- endif %}
- domain: {{ entity.domain }}
{%- if entity.names | length == 1: %}
name: {{ entity.names[0] }}
{%- else: %}
names:
{%- for name in entity.names: %}
- {{ name }}
{%- endfor %}
{%- endif %}
{%- if entity.area_names | length == 1: %}
area: {{ entity.area_names[0] }}
{%- elif entity.area_names: %}
areas:
{%- for area_name in entity.area_names: %}
- {{ area_name }}
{%- endfor %}
{%- endif %}
state: {{ entity.state.state }}
{%- set attributes_key_printed = False %}
{%- for attr_name, attr_value in entity.state.attributes.items(): %}
{%- if attr_name in used_attributes: %}
{%- if not attributes_key_printed: %}
attributes:
{%- set attributes_key_printed = True %}
{%- endif %}
{{ attr_name }}: {{ attr_value }}
{%- endif %}
{%- endfor %}
{%- endfor %}
```
Answer the user's questions using the information about this smart home.
Keep your answers brief and do not apologize."""
CONF_KEEP_ALIVE = "keep_alive"
DEFAULT_KEEP_ALIVE = -1 # seconds. -1 = indefinite, 0 = never
@ -187,4 +120,4 @@ MODEL_NAMES = [ # https://ollama.com/library
"yi",
"zephyr",
]
DEFAULT_MODEL = "llama2:latest"
DEFAULT_MODEL = "llama3.1:latest"

View file

@ -2,26 +2,23 @@
from __future__ import annotations
from collections.abc import Callable
import json
import logging
import time
from typing import Literal
from typing import Any, Literal
import ollama
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
intent,
template,
)
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import intent, llm, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid
@ -32,11 +29,13 @@ from .const import (
CONF_PROMPT,
DEFAULT_KEEP_ALIVE,
DEFAULT_MAX_HISTORY,
DEFAULT_PROMPT,
DOMAIN,
MAX_HISTORY_SECONDS,
)
from .models import ExposedEntity, MessageHistory, MessageRole
from .models import MessageHistory, MessageRole
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
_LOGGER = logging.getLogger(__name__)
@ -51,6 +50,19 @@ async def async_setup_entry(
async_add_entities([agent])
def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]:
"""Format tool specification."""
tool_spec = {
"name": tool.name,
"parameters": convert(tool.parameters, custom_serializer=custom_serializer),
}
if tool.description:
tool_spec["description"] = tool.description
return {"type": "function", "function": tool_spec}
class OllamaConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
@ -94,6 +106,47 @@ class OllamaConversationEntity(
client = self.hass.data[DOMAIN][self.entry.entry_id]
conversation_id = user_input.conversation_id or ulid.ulid_now()
model = settings[CONF_MODEL]
intent_response = intent.IntentResponse(language=user_input.language)
llm_api: llm.APIInstance | None = None
tools: list[dict[str, Any]] | None = None
user_name: str | None = None
llm_context = llm.LLMContext(
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
)
if settings.get(CONF_LLM_HASS_API):
try:
llm_api = await llm.async_get_api(
self.hass,
settings[CONF_LLM_HASS_API],
llm_context,
)
except HomeAssistantError as err:
_LOGGER.error("Error getting LLM API: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Error preparing LLM API: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
tools = [
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
]
if (
user_input.context
and user_input.context.user_id
and (
user := await self.hass.auth.async_get_user(user_input.context.user_id)
)
):
user_name = user.name
# Look up message history
message_history: MessageHistory | None = None
@ -102,13 +155,24 @@ class OllamaConversationEntity(
# New history
#
# Render prompt and error out early if there's a problem
raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT)
try:
prompt = self._generate_prompt(raw_prompt)
_LOGGER.debug("Prompt: %s", prompt)
prompt_parts = [
template.Template(
llm.BASE_PROMPT
+ settings.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
self.hass,
).async_render(
{
"ha_name": self.hass.config.location_name,
"user_name": user_name,
"llm_context": llm_context,
},
parse_result=False,
)
]
except TemplateError as err:
_LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem generating my prompt: {err}",
@ -117,6 +181,13 @@ class OllamaConversationEntity(
response=intent_response, conversation_id=conversation_id
)
if llm_api:
prompt_parts.append(llm_api.api_prompt)
prompt = "\n".join(prompt_parts)
_LOGGER.debug("Prompt: %s", prompt)
_LOGGER.debug("Tools: %s", tools)
message_history = MessageHistory(
timestamp=time.monotonic(),
messages=[
@ -146,35 +217,66 @@ class OllamaConversationEntity(
)
# Get response
try:
response = await client.chat(
model=model,
# Make a copy of the messages because we mutate the list later
messages=list(message_history.messages),
stream=False,
# keep_alive requires specifying unit. In this case, seconds
keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
)
except (ollama.RequestError, ollama.ResponseError) as err:
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to the Ollama server: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
response = await client.chat(
model=model,
# Make a copy of the messages because we mutate the list later
messages=list(message_history.messages),
tools=tools,
stream=False,
# keep_alive requires specifying unit. In this case, seconds
keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
)
except (ollama.RequestError, ollama.ResponseError) as err:
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to the Ollama server: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
response_message = response["message"]
message_history.messages.append(
ollama.Message(
role=response_message["role"],
content=response_message.get("content"),
tool_calls=response_message.get("tool_calls"),
)
)
response_message = response["message"]
message_history.messages.append(
ollama.Message(
role=response_message["role"], content=response_message["content"]
)
)
tool_calls = response_message.get("tool_calls")
if not tool_calls or not llm_api:
break
for tool_call in tool_calls:
tool_input = llm.ToolInput(
tool_name=tool_call["function"]["name"],
tool_args=tool_call["function"]["arguments"],
)
_LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
)
try:
tool_response = await llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
tool_response = {"error": type(e).__name__}
if str(e):
tool_response["error_text"] = str(e)
_LOGGER.debug("Tool response: %s", tool_response)
message_history.messages.append(
ollama.Message(
role=MessageRole.TOOL.value, # type: ignore[typeddict-item]
content=json.dumps(tool_response),
)
)
# Create intent response
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response_message["content"])
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
@ -204,62 +306,3 @@ class OllamaConversationEntity(
message_history.messages = [
message_history.messages[0]
] + message_history.messages[drop_index:]
def _generate_prompt(self, raw_prompt: str) -> str:
"""Generate a prompt for the user."""
return template.Template(raw_prompt, self.hass).async_render(
{
"ha_name": self.hass.config.location_name,
"ha_language": self.hass.config.language,
"exposed_entities": self._get_exposed_entities(),
},
parse_result=False,
)
def _get_exposed_entities(self) -> list[ExposedEntity]:
"""Get state list of exposed entities."""
area_registry = ar.async_get(self.hass)
entity_registry = er.async_get(self.hass)
device_registry = dr.async_get(self.hass)
exposed_entities = []
exposed_states = [
state
for state in self.hass.states.async_all()
if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
]
for state in exposed_states:
entity_entry = entity_registry.async_get(state.entity_id)
names = [state.name]
area_names = []
if entity_entry is not None:
# Add aliases
names.extend(entity_entry.aliases)
if entity_entry.area_id and (
area := area_registry.async_get_area(entity_entry.area_id)
):
# Entity is in area
area_names.append(area.name)
area_names.extend(area.aliases)
elif entity_entry.device_id and (
device := device_registry.async_get(entity_entry.device_id)
):
# Check device area
if device.area_id and (
area := area_registry.async_get_area(device.area_id)
):
area_names.append(area.name)
area_names.extend(area.aliases)
exposed_entities.append(
ExposedEntity(
entity_id=state.entity_id,
state=state,
names=names,
area_names=area_names,
)
)
return exposed_entities

View file

@ -1,7 +1,7 @@
{
"domain": "ollama",
"name": "Ollama",
"after_dependencies": ["assist_pipeline"],
"after_dependencies": ["assist_pipeline", "intent"],
"codeowners": ["@synesthesiam"],
"config_flow": true,
"dependencies": ["conversation"],

View file

@ -2,18 +2,17 @@
from dataclasses import dataclass
from enum import StrEnum
from functools import cached_property
import ollama
from homeassistant.core import State
class MessageRole(StrEnum):
"""Role of a chat message."""
SYSTEM = "system" # prompt
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
@dataclass
@ -30,18 +29,3 @@ class MessageHistory:
def num_user_messages(self) -> int:
"""Return a count of user messages."""
return sum(m["role"] == MessageRole.USER.value for m in self.messages)
@dataclass(frozen=True)
class ExposedEntity:
"""Relevant information about an exposed entity."""
entity_id: str
state: State
names: list[str]
area_names: list[str]
@cached_property
def domain(self) -> str:
"""Get domain from entity id."""
return self.entity_id.split(".", maxsplit=1)[0]

View file

@ -24,11 +24,13 @@
"step": {
"init": {
"data": {
"prompt": "Prompt template",
"prompt": "Instructions",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
"max_history": "Max history messages",
"keep_alive": "Keep alive"
},
"data_description": {
"prompt": "Instruct how the LLM should respond. This can be a template.",
"keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never."
}
}

View file

@ -1,7 +1,7 @@
"""Tests for the Ollama integration."""
from homeassistant.components import ollama
from homeassistant.components.ollama.const import DEFAULT_PROMPT
from homeassistant.helpers import llm
TEST_USER_DATA = {
ollama.CONF_URL: "http://localhost:11434",
@ -9,6 +9,6 @@ TEST_USER_DATA = {
}
TEST_OPTIONS = {
ollama.CONF_PROMPT: DEFAULT_PROMPT,
ollama.CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
ollama.CONF_MAX_HISTORY: 2,
}

View file

@ -5,7 +5,9 @@ from unittest.mock import patch
import pytest
from homeassistant.components import ollama
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.setup import async_setup_component
from . import TEST_OPTIONS, TEST_USER_DATA
@ -25,6 +27,17 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
return entry
@pytest.fixture
def mock_config_entry_with_assist(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> MockConfigEntry:
"""Mock a config entry with assist."""
hass.config_entries.async_update_entry(
mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
)
return mock_config_entry
@pytest.fixture
async def mock_init_component(hass: HomeAssistant, mock_config_entry: MockConfigEntry):
"""Initialize integration."""

View file

@ -0,0 +1,34 @@
# serializer version: 1
# name: test_unknown_hass_api
dict({
'conversation_id': None,
'response': IntentResponse(
card=dict({
}),
error_code=<IntentResponseErrorCode.UNKNOWN: 'unknown'>,
failed_results=list([
]),
intent=None,
intent_targets=list([
]),
language='en',
matched_states=list([
]),
reprompt=dict({
}),
response_type=<IntentResponseType.ERROR: 'error'>,
speech=dict({
'plain': dict({
'extra_data': None,
'speech': 'Error preparing LLM API: API non-existing not found',
}),
}),
speech_slots=dict({
}),
success_results=list([
]),
unmatched_states=list([
]),
),
})
# ---

View file

@ -1,21 +1,18 @@
"""Tests for the Ollama integration."""
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch
from ollama import Message, ResponseError
import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import conversation, ollama
from homeassistant.components.conversation import trace
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
intent,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent, llm
from tests.common import MockConfigEntry
@ -25,9 +22,6 @@ async def test_chat(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
agent_id: str,
) -> None:
"""Test that the chat function is called with the appropriate arguments."""
@ -35,48 +29,8 @@ async def test_chat(
if agent_id is None:
agent_id = mock_config_entry.entry_id
# Create some areas, devices, and entities
area_kitchen = area_registry.async_get_or_create("kitchen_id")
area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen")
area_bedroom = area_registry.async_get_or_create("bedroom_id")
area_bedroom = area_registry.async_update(area_bedroom.id, name="bedroom")
area_office = area_registry.async_get_or_create("office_id")
area_office = area_registry.async_update(area_office.id, name="office")
entry = MockConfigEntry()
entry.add_to_hass(hass)
kitchen_device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections=set(),
identifiers={("demo", "id-1234")},
)
device_registry.async_update_device(kitchen_device.id, area_id=area_kitchen.id)
kitchen_light = entity_registry.async_get_or_create("light", "demo", "1234")
kitchen_light = entity_registry.async_update_entity(
kitchen_light.entity_id, device_id=kitchen_device.id
)
hass.states.async_set(
kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
)
bedroom_light = entity_registry.async_get_or_create("light", "demo", "5678")
bedroom_light = entity_registry.async_update_entity(
bedroom_light.entity_id, area_id=area_bedroom.id
)
hass.states.async_set(
bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"}
)
# Hide the office light
office_light = entity_registry.async_get_or_create("light", "demo", "ABCD")
office_light = entity_registry.async_update_entity(
office_light.entity_id, area_id=area_office.id
)
hass.states.async_set(
office_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "office light"}
)
async_expose_entity(hass, conversation.DOMAIN, office_light.entity_id, False)
with patch(
"ollama.AsyncClient.chat",
@ -100,12 +54,6 @@ async def test_chat(
Message({"role": "user", "content": "test message"}),
]
# Verify only exposed devices/areas are in prompt
assert "kitchen light" in prompt
assert "bedroom light" in prompt
assert "office light" not in prompt
assert "office" not in prompt
assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), result
@ -122,7 +70,232 @@ async def test_chat(
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "The current time is" in detail_event["data"]["messages"][0]["content"]
assert "Current time is" in detail_event["data"]["messages"][0]["content"]
async def test_template_variables(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test that template variables work."""
context = Context(user_id="12345")
mock_user = Mock()
mock_user.id = "12345"
mock_user.name = "Test User"
hass.config_entries.async_update_entry(
mock_config_entry,
options={
"prompt": (
"The user name is {{ user_name }}. "
"The user id is {{ llm_context.context.user_id }}."
),
},
)
with (
patch("ollama.AsyncClient.list"),
patch(
"ollama.AsyncClient.chat",
return_value={"message": {"role": "assistant", "content": "test response"}},
) as mock_chat,
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
result = await conversation.async_converse(
hass, "hello", None, context, agent_id=mock_config_entry.entry_id
)
assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), result
args = mock_chat.call_args.kwargs
prompt = args["messages"][0]["content"]
assert "The user name is Test User." in prompt
assert "The user id is 12345." in prompt
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
async def test_function_call(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
) -> None:
"""Test function call from the assistant."""
agent_id = mock_config_entry_with_assist.entry_id
context = Context()
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema(
{vol.Optional("param1", description="Test parameters"): str}
)
mock_tool.async_call.return_value = "Test response"
mock_get_tools.return_value = [mock_tool]
def completion_result(*args, messages, **kwargs):
for message in messages:
if message["role"] == "tool":
return {
"message": {
"role": "assistant",
"content": "I have successfully called the function",
}
}
return {
"message": {
"role": "assistant",
"tool_calls": [
{
"function": {
"name": "test_tool",
"arguments": {"param1": "test_value"},
}
}
],
}
}
with patch(
"ollama.AsyncClient.chat",
side_effect=completion_result,
) as mock_chat:
result = await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
)
assert mock_chat.call_count == 2
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.speech["plain"]["speech"]
== "I have successfully called the function"
)
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": "test_value"},
),
llm.LLMContext(
platform="ollama",
context=context,
user_prompt="Please call the test function",
language="en",
assistant="conversation",
device_id=None,
),
)
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
async def test_function_exception(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
) -> None:
"""Test function call with exception."""
agent_id = mock_config_entry_with_assist.entry_id
context = Context()
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema(
{vol.Optional("param1", description="Test parameters"): str}
)
mock_tool.async_call.side_effect = HomeAssistantError("Test tool exception")
mock_get_tools.return_value = [mock_tool]
def completion_result(*args, messages, **kwargs):
for message in messages:
if message["role"] == "tool":
return {
"message": {
"role": "assistant",
"content": "There was an error calling the function",
}
}
return {
"message": {
"role": "assistant",
"tool_calls": [
{
"function": {
"name": "test_tool",
"arguments": {"param1": "test_value"},
}
}
],
}
}
with patch(
"ollama.AsyncClient.chat",
side_effect=completion_result,
) as mock_chat:
result = await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
)
assert mock_chat.call_count == 2
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.speech["plain"]["speech"]
== "There was an error calling the function"
)
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": "test_value"},
),
llm.LLMContext(
platform="ollama",
context=context,
user_prompt="Please call the test function",
language="en",
assistant="conversation",
device_id=None,
),
)
async def test_unknown_hass_api(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
snapshot: SnapshotAssertion,
mock_init_component,
) -> None:
"""Test when we reference an API that no longer exists."""
hass.config_entries.async_update_entry(
mock_config_entry,
options={
**mock_config_entry.options,
CONF_LLM_HASS_API: "non-existing",
},
)
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result == snapshot
async def test_message_history_trimming(