Add LLM tools support for Ollama (#120454)
* Add LLM tools support for Ollama
* fix tests
* coverage
* Separate call for tool parameters
* Fix example
* hint on parameters schema if LLM forgot to request it
* Switch to native tool call functionality
* Fix tests
* Fix tools list
* update strings and default model
* Ignore mypy error until fixed upstream
* Ignore mypy error until fixed upstream
* Add missing prompt part
* Update default model
parent f98487ef18
commit 4b2073ca59
10 changed files with 465 additions and 255 deletions
@@ -18,7 +18,9 @@ from homeassistant.config_entries import (
     ConfigFlowResult,
     OptionsFlow,
 )
-from homeassistant.const import CONF_URL
+from homeassistant.const import CONF_LLM_HASS_API, CONF_URL
+from homeassistant.core import HomeAssistant
+from homeassistant.helpers import llm
 from homeassistant.helpers.selector import (
     NumberSelector,
     NumberSelectorConfig,
@@ -40,7 +42,6 @@ from .const import (
     DEFAULT_KEEP_ALIVE,
     DEFAULT_MAX_HISTORY,
     DEFAULT_MODEL,
-    DEFAULT_PROMPT,
     DEFAULT_TIMEOUT,
     DOMAIN,
     MODEL_NAMES,
@@ -208,25 +209,52 @@ class OllamaOptionsFlow(OptionsFlow):
     ) -> ConfigFlowResult:
         """Manage the options."""
         if user_input is not None:
+            if user_input[CONF_LLM_HASS_API] == "none":
+                user_input.pop(CONF_LLM_HASS_API)
             return self.async_create_entry(
                 title=_get_title(self.model), data=user_input
             )

         options = self.config_entry.options or MappingProxyType({})
-        schema = ollama_config_option_schema(options)
+        schema = ollama_config_option_schema(self.hass, options)
         return self.async_show_form(
             step_id="init",
             data_schema=vol.Schema(schema),
         )


-def ollama_config_option_schema(options: MappingProxyType[str, Any]) -> dict:
+def ollama_config_option_schema(
+    hass: HomeAssistant, options: MappingProxyType[str, Any]
+) -> dict:
     """Ollama options schema."""
+    hass_apis: list[SelectOptionDict] = [
+        SelectOptionDict(
+            label="No control",
+            value="none",
+        )
+    ]
+    hass_apis.extend(
+        SelectOptionDict(
+            label=api.name,
+            value=api.id,
+        )
+        for api in llm.async_get_apis(hass)
+    )
+
     return {
         vol.Optional(
             CONF_PROMPT,
-            description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
+            description={
+                "suggested_value": options.get(
+                    CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
+                )
+            },
         ): TemplateSelector(),
+        vol.Optional(
+            CONF_LLM_HASS_API,
+            description={"suggested_value": options.get(CONF_LLM_HASS_API)},
+            default="none",
+        ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
         vol.Optional(
             CONF_MAX_HISTORY,
             description={
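A note on the `"none"` sentinel above: the select selector always returns a string, so "no API selected" arrives as `"none"` and has to be popped before the entry is saved. A minimal sketch of that round trip (option keys written as their literal string values, purely illustrative):

```python
# Sketch: why the "none" sentinel is popped before saving the options.
user_input = {"prompt": "...", "llm_hass_api": "none", "max_history": 20}

if user_input["llm_hass_api"] == "none":
    user_input.pop("llm_hass_api")

# Later, settings.get(CONF_LLM_HASS_API) is falsy, so no tools are offered.
assert "llm_hass_api" not in user_input
```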
@@ -4,73 +4,6 @@ DOMAIN = "ollama"

 CONF_MODEL = "model"
 CONF_PROMPT = "prompt"
-DEFAULT_PROMPT = """{%- set used_domains = set([
-  "binary_sensor",
-  "climate",
-  "cover",
-  "fan",
-  "light",
-  "lock",
-  "sensor",
-  "switch",
-  "weather",
-]) %}
-{%- set used_attributes = set([
-  "temperature",
-  "current_temperature",
-  "temperature_unit",
-  "brightness",
-  "humidity",
-  "unit_of_measurement",
-  "device_class",
-  "current_position",
-  "percentage",
-]) %}
-
-This smart home is controlled by Home Assistant.
-The current time is {{ now().strftime("%X") }}.
-Today's date is {{ now().strftime("%x") }}.
-
-An overview of the areas and the devices in this smart home:
-```yaml
-{%- for entity in exposed_entities: %}
-{%- if entity.domain not in used_domains: %}
-{%- continue %}
-{%- endif %}
-
-- domain: {{ entity.domain }}
-{%- if entity.names | length == 1: %}
-  name: {{ entity.names[0] }}
-{%- else: %}
-  names:
-{%- for name in entity.names: %}
-  - {{ name }}
-{%- endfor %}
-{%- endif %}
-{%- if entity.area_names | length == 1: %}
-  area: {{ entity.area_names[0] }}
-{%- elif entity.area_names: %}
-  areas:
-{%- for area_name in entity.area_names: %}
-  - {{ area_name }}
-{%- endfor %}
-{%- endif %}
-  state: {{ entity.state.state }}
-{%- set attributes_key_printed = False %}
-{%- for attr_name, attr_value in entity.state.attributes.items(): %}
-{%- if attr_name in used_attributes: %}
-{%- if not attributes_key_printed: %}
-  attributes:
-{%- set attributes_key_printed = True %}
-{%- endif %}
-    {{ attr_name }}: {{ attr_value }}
-{%- endif %}
-{%- endfor %}
-{%- endfor %}
-```
-
-Answer the user's questions using the information about this smart home.
-Keep your answers brief and do not apologize."""

 CONF_KEEP_ALIVE = "keep_alive"
 DEFAULT_KEEP_ALIVE = -1  # seconds. -1 = indefinite, 0 = never
@@ -187,4 +120,4 @@ MODEL_NAMES = [  # https://ollama.com/library
     "yi",
     "zephyr",
 ]
-DEFAULT_MODEL = "llama2:latest"
+DEFAULT_MODEL = "llama3.1:latest"
@@ -2,26 +2,23 @@

 from __future__ import annotations

+from collections.abc import Callable
+import json
 import logging
 import time
-from typing import Literal
+from typing import Any, Literal

 import ollama
+import voluptuous as vol
+from voluptuous_openapi import convert

 from homeassistant.components import assist_pipeline, conversation
 from homeassistant.components.conversation import trace
-from homeassistant.components.homeassistant.exposed_entities import async_should_expose
 from homeassistant.config_entries import ConfigEntry
-from homeassistant.const import MATCH_ALL
+from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import HomeAssistant
-from homeassistant.exceptions import TemplateError
-from homeassistant.helpers import (
-    area_registry as ar,
-    device_registry as dr,
-    entity_registry as er,
-    intent,
-    template,
-)
+from homeassistant.exceptions import HomeAssistantError, TemplateError
+from homeassistant.helpers import intent, llm, template
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
 from homeassistant.util import ulid

@@ -32,11 +29,13 @@ from .const import (
     CONF_PROMPT,
     DEFAULT_KEEP_ALIVE,
     DEFAULT_MAX_HISTORY,
-    DEFAULT_PROMPT,
     DOMAIN,
     MAX_HISTORY_SECONDS,
 )
-from .models import ExposedEntity, MessageHistory, MessageRole
+from .models import MessageHistory, MessageRole
+
+# Max number of back and forth with the LLM to generate a response
+MAX_TOOL_ITERATIONS = 10

 _LOGGER = logging.getLogger(__name__)

@@ -51,6 +50,19 @@ async def async_setup_entry(
     async_add_entities([agent])


+def _format_tool(
+    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
+) -> dict[str, Any]:
+    """Format tool specification."""
+    tool_spec = {
+        "name": tool.name,
+        "parameters": convert(tool.parameters, custom_serializer=custom_serializer),
+    }
+    if tool.description:
+        tool_spec["description"] = tool.description
+    return {"type": "function", "function": tool_spec}
+
+
 class OllamaConversationEntity(
     conversation.ConversationEntity, conversation.AbstractConversationAgent
 ):
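For context, `_format_tool` adapts a Home Assistant `llm.Tool` to the function-call layout of Ollama's native tools API, with `voluptuous_openapi.convert` turning the voluptuous schema into JSON Schema. Roughly, a tool comes out shaped like this (the tool name and parameters below are illustrative, not taken from this PR):

```python
# Sketch: the dict shape produced by _format_tool for Ollama's tools= argument.
formatted_tool = {
    "type": "function",
    "function": {
        "name": "example_tool",  # hypothetical tool name
        "description": "Example description",  # only set when tool.description exists
        "parameters": {
            # JSON Schema as emitted by voluptuous_openapi.convert
            "type": "object",
            "properties": {"param1": {"type": "string"}},
            "required": [],
        },
    },
}
```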
@@ -94,6 +106,47 @@ class OllamaConversationEntity(
         client = self.hass.data[DOMAIN][self.entry.entry_id]
         conversation_id = user_input.conversation_id or ulid.ulid_now()
         model = settings[CONF_MODEL]
+        intent_response = intent.IntentResponse(language=user_input.language)
+        llm_api: llm.APIInstance | None = None
+        tools: list[dict[str, Any]] | None = None
+        user_name: str | None = None
+        llm_context = llm.LLMContext(
+            platform=DOMAIN,
+            context=user_input.context,
+            user_prompt=user_input.text,
+            language=user_input.language,
+            assistant=conversation.DOMAIN,
+            device_id=user_input.device_id,
+        )
+
+        if settings.get(CONF_LLM_HASS_API):
+            try:
+                llm_api = await llm.async_get_api(
+                    self.hass,
+                    settings[CONF_LLM_HASS_API],
+                    llm_context,
+                )
+            except HomeAssistantError as err:
+                _LOGGER.error("Error getting LLM API: %s", err)
+                intent_response.async_set_error(
+                    intent.IntentResponseErrorCode.UNKNOWN,
+                    f"Error preparing LLM API: {err}",
+                )
+                return conversation.ConversationResult(
+                    response=intent_response, conversation_id=user_input.conversation_id
+                )
+            tools = [
+                _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
+            ]
+
+        if (
+            user_input.context
+            and user_input.context.user_id
+            and (
+                user := await self.hass.auth.async_get_user(user_input.context.user_id)
+            )
+        ):
+            user_name = user.name
+
         # Look up message history
         message_history: MessageHistory | None = None
@@ -102,13 +155,24 @@ class OllamaConversationEntity(
             # New history
             #
             # Render prompt and error out early if there's a problem
-            raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT)
             try:
-                prompt = self._generate_prompt(raw_prompt)
-                _LOGGER.debug("Prompt: %s", prompt)
+                prompt_parts = [
+                    template.Template(
+                        llm.BASE_PROMPT
+                        + settings.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
+                        self.hass,
+                    ).async_render(
+                        {
+                            "ha_name": self.hass.config.location_name,
+                            "user_name": user_name,
+                            "llm_context": llm_context,
+                        },
+                        parse_result=False,
+                    )
+                ]
+
             except TemplateError as err:
                 _LOGGER.error("Error rendering prompt: %s", err)
                 intent_response = intent.IntentResponse(language=user_input.language)
                 intent_response.async_set_error(
                     intent.IntentResponseErrorCode.UNKNOWN,
                     f"Sorry, I had a problem generating my prompt: {err}",
@@ -117,6 +181,13 @@ class OllamaConversationEntity(
                     response=intent_response, conversation_id=conversation_id
                 )

+            if llm_api:
+                prompt_parts.append(llm_api.api_prompt)
+
+            prompt = "\n".join(prompt_parts)
+            _LOGGER.debug("Prompt: %s", prompt)
+            _LOGGER.debug("Tools: %s", tools)
+
             message_history = MessageHistory(
                 timestamp=time.monotonic(),
                 messages=[
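The system prompt is now assembled from two pieces: `llm.BASE_PROMPT` plus the user-configured instructions, and, when an LLM API is selected, that API's own prompt, joined with newlines. A rough sketch of the assembly, with placeholder strings standing in for whatever the templates actually render (the tests below only assert that the rendered base part contains "Current time is"):

```python
# Sketch: how the final system prompt is assembled (placeholder strings).
prompt_parts = [
    "Current time is 16:20:05. ... <rendered instructions>",  # BASE_PROMPT + instructions
]
llm_api_selected = True  # illustrative stand-in for the llm_api check
if llm_api_selected:
    prompt_parts.append("<api_prompt from the selected LLM API>")

prompt = "\n".join(prompt_parts)
```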
@@ -146,35 +217,66 @@ class OllamaConversationEntity(
         )

         # Get response
-        try:
-            response = await client.chat(
-                model=model,
-                # Make a copy of the messages because we mutate the list later
-                messages=list(message_history.messages),
-                stream=False,
-                # keep_alive requires specifying unit. In this case, seconds
-                keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
-            )
-        except (ollama.RequestError, ollama.ResponseError) as err:
-            _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
-            intent_response = intent.IntentResponse(language=user_input.language)
-            intent_response.async_set_error(
-                intent.IntentResponseErrorCode.UNKNOWN,
-                f"Sorry, I had a problem talking to the Ollama server: {err}",
-            )
-            return conversation.ConversationResult(
-                response=intent_response, conversation_id=conversation_id
-            )
-
-        response_message = response["message"]
-        message_history.messages.append(
-            ollama.Message(
-                role=response_message["role"], content=response_message["content"]
-            )
-        )
+        # To prevent infinite loops, we limit the number of iterations
+        for _iteration in range(MAX_TOOL_ITERATIONS):
+            try:
+                response = await client.chat(
+                    model=model,
+                    # Make a copy of the messages because we mutate the list later
+                    messages=list(message_history.messages),
+                    tools=tools,
+                    stream=False,
+                    # keep_alive requires specifying unit. In this case, seconds
+                    keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
+                )
+            except (ollama.RequestError, ollama.ResponseError) as err:
+                _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
+                intent_response.async_set_error(
+                    intent.IntentResponseErrorCode.UNKNOWN,
+                    f"Sorry, I had a problem talking to the Ollama server: {err}",
+                )
+                return conversation.ConversationResult(
+                    response=intent_response, conversation_id=conversation_id
+                )
+
+            response_message = response["message"]
+            message_history.messages.append(
+                ollama.Message(
+                    role=response_message["role"],
+                    content=response_message.get("content"),
+                    tool_calls=response_message.get("tool_calls"),
+                )
+            )
+
+            tool_calls = response_message.get("tool_calls")
+            if not tool_calls or not llm_api:
+                break
+
+            for tool_call in tool_calls:
+                tool_input = llm.ToolInput(
+                    tool_name=tool_call["function"]["name"],
+                    tool_args=tool_call["function"]["arguments"],
+                )
+                _LOGGER.debug(
+                    "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
+                )
+
+                try:
+                    tool_response = await llm_api.async_call_tool(tool_input)
+                except (HomeAssistantError, vol.Invalid) as e:
+                    tool_response = {"error": type(e).__name__}
+                    if str(e):
+                        tool_response["error_text"] = str(e)
+
+                _LOGGER.debug("Tool response: %s", tool_response)
+                message_history.messages.append(
+                    ollama.Message(
+                        role=MessageRole.TOOL.value,  # type: ignore[typeddict-item]
+                        content=json.dumps(tool_response),
+                    )
+                )

         # Create intent response
-        intent_response = intent.IntentResponse(language=user_input.language)
         intent_response.async_set_speech(response_message["content"])
         return conversation.ConversationResult(
             response=intent_response, conversation_id=conversation_id
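The loop above yields a simple message flow: each assistant turn that requests tools is answered with one `tool`-role message per call, then the model is queried again, up to `MAX_TOOL_ITERATIONS` times, until it returns plain content. Assuming a single tool call, the history grows roughly like this (all values illustrative, including the tool response payload):

```python
# Sketch: message history after one tool round-trip (illustrative values).
messages = [
    {"role": "system", "content": "<rendered prompt>"},
    {"role": "user", "content": "Turn on the kitchen light"},
    {
        "role": "assistant",  # model requests a tool call
        "content": None,
        "tool_calls": [
            {"function": {"name": "HassTurnOn", "arguments": {"name": "kitchen light"}}}
        ],
    },
    {"role": "tool", "content": '{"response_type": "action_done"}'},  # JSON-dumped result
    {"role": "assistant", "content": "Turned on the kitchen light."},  # loop breaks here
]
```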
@@ -204,62 +306,3 @@ class OllamaConversationEntity(
             message_history.messages = [
                 message_history.messages[0]
             ] + message_history.messages[drop_index:]
-
-    def _generate_prompt(self, raw_prompt: str) -> str:
-        """Generate a prompt for the user."""
-        return template.Template(raw_prompt, self.hass).async_render(
-            {
-                "ha_name": self.hass.config.location_name,
-                "ha_language": self.hass.config.language,
-                "exposed_entities": self._get_exposed_entities(),
-            },
-            parse_result=False,
-        )
-
-    def _get_exposed_entities(self) -> list[ExposedEntity]:
-        """Get state list of exposed entities."""
-        area_registry = ar.async_get(self.hass)
-        entity_registry = er.async_get(self.hass)
-        device_registry = dr.async_get(self.hass)
-
-        exposed_entities = []
-        exposed_states = [
-            state
-            for state in self.hass.states.async_all()
-            if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
-        ]
-
-        for state in exposed_states:
-            entity_entry = entity_registry.async_get(state.entity_id)
-            names = [state.name]
-            area_names = []
-
-            if entity_entry is not None:
-                # Add aliases
-                names.extend(entity_entry.aliases)
-                if entity_entry.area_id and (
-                    area := area_registry.async_get_area(entity_entry.area_id)
-                ):
-                    # Entity is in area
-                    area_names.append(area.name)
-                    area_names.extend(area.aliases)
-                elif entity_entry.device_id and (
-                    device := device_registry.async_get(entity_entry.device_id)
-                ):
-                    # Check device area
-                    if device.area_id and (
-                        area := area_registry.async_get_area(device.area_id)
-                    ):
-                        area_names.append(area.name)
-                        area_names.extend(area.aliases)
-
-            exposed_entities.append(
-                ExposedEntity(
-                    entity_id=state.entity_id,
-                    state=state,
-                    names=names,
-                    area_names=area_names,
-                )
-            )
-
-        return exposed_entities
@@ -1,7 +1,7 @@
 {
   "domain": "ollama",
   "name": "Ollama",
-  "after_dependencies": ["assist_pipeline"],
+  "after_dependencies": ["assist_pipeline", "intent"],
   "codeowners": ["@synesthesiam"],
   "config_flow": true,
   "dependencies": ["conversation"],
@@ -2,18 +2,17 @@

 from dataclasses import dataclass
 from enum import StrEnum
-from functools import cached_property

 import ollama

-from homeassistant.core import State
-

 class MessageRole(StrEnum):
     """Role of a chat message."""

     SYSTEM = "system"  # prompt
     USER = "user"
     ASSISTANT = "assistant"
+    TOOL = "tool"


 @dataclass
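`TOOL` is the new role; at the time of this commit the upstream `ollama.Message` TypedDict did not yet accept it, which is why the conversation code above carries a `# type: ignore[typeddict-item]` when appending tool results ("Ignore mypy error until fixed upstream" in the commit log). A sketch of the message it enables (payload illustrative):

```python
import json

import ollama

# A tool result message as appended to the history; the "tool" role literal
# is what mypy flagged against the upstream TypedDict at the time.
tool_message = ollama.Message(
    role="tool",  # type: ignore[typeddict-item]
    content=json.dumps({"result": "ok"}),  # illustrative payload
)
```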
@@ -30,18 +29,3 @@ class MessageHistory:
     def num_user_messages(self) -> int:
         """Return a count of user messages."""
         return sum(m["role"] == MessageRole.USER.value for m in self.messages)
-
-
-@dataclass(frozen=True)
-class ExposedEntity:
-    """Relevant information about an exposed entity."""
-
-    entity_id: str
-    state: State
-    names: list[str]
-    area_names: list[str]
-
-    @cached_property
-    def domain(self) -> str:
-        """Get domain from entity id."""
-        return self.entity_id.split(".", maxsplit=1)[0]
@@ -24,11 +24,13 @@
     "step": {
       "init": {
         "data": {
-          "prompt": "Prompt template",
+          "prompt": "Instructions",
+          "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
           "max_history": "Max history messages",
           "keep_alive": "Keep alive"
         },
         "data_description": {
+          "prompt": "Instruct how the LLM should respond. This can be a template.",
           "keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never."
         }
       }
@@ -1,7 +1,7 @@
 """Tests for the Ollama integration."""

 from homeassistant.components import ollama
-from homeassistant.components.ollama.const import DEFAULT_PROMPT
+from homeassistant.helpers import llm

 TEST_USER_DATA = {
     ollama.CONF_URL: "http://localhost:11434",
@@ -9,6 +9,6 @@ TEST_USER_DATA = {
 }

 TEST_OPTIONS = {
-    ollama.CONF_PROMPT: DEFAULT_PROMPT,
+    ollama.CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
     ollama.CONF_MAX_HISTORY: 2,
 }
@@ -5,7 +5,9 @@ from unittest.mock import patch
 import pytest

 from homeassistant.components import ollama
+from homeassistant.const import CONF_LLM_HASS_API
 from homeassistant.core import HomeAssistant
+from homeassistant.helpers import llm
 from homeassistant.setup import async_setup_component

 from . import TEST_OPTIONS, TEST_USER_DATA
@@ -25,6 +27,17 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
     return entry


+@pytest.fixture
+def mock_config_entry_with_assist(
+    hass: HomeAssistant, mock_config_entry: MockConfigEntry
+) -> MockConfigEntry:
+    """Mock a config entry with assist."""
+    hass.config_entries.async_update_entry(
+        mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
+    )
+    return mock_config_entry
+
+
 @pytest.fixture
 async def mock_init_component(hass: HomeAssistant, mock_config_entry: MockConfigEntry):
     """Initialize integration."""
tests/components/ollama/snapshots/test_conversation.ambr (new file, 34 lines)
@@ -0,0 +1,34 @@
+# serializer version: 1
+# name: test_unknown_hass_api
+  dict({
+    'conversation_id': None,
+    'response': IntentResponse(
+      card=dict({
+      }),
+      error_code=<IntentResponseErrorCode.UNKNOWN: 'unknown'>,
+      failed_results=list([
+      ]),
+      intent=None,
+      intent_targets=list([
+      ]),
+      language='en',
+      matched_states=list([
+      ]),
+      reprompt=dict({
+      }),
+      response_type=<IntentResponseType.ERROR: 'error'>,
+      speech=dict({
+        'plain': dict({
+          'extra_data': None,
+          'speech': 'Error preparing LLM API: API non-existing not found',
+        }),
+      }),
+      speech_slots=dict({
+      }),
+      success_results=list([
+      ]),
+      unmatched_states=list([
+      ]),
+    ),
+  })
+# ---
@@ -1,21 +1,18 @@
 """Tests for the Ollama integration."""

-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, Mock, patch

 from ollama import Message, ResponseError
 import pytest
+from syrupy.assertion import SnapshotAssertion
+import voluptuous as vol

 from homeassistant.components import conversation, ollama
 from homeassistant.components.conversation import trace
-from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
-from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
+from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import Context, HomeAssistant
-from homeassistant.helpers import (
-    area_registry as ar,
-    device_registry as dr,
-    entity_registry as er,
-    intent,
-)
+from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers import intent, llm

 from tests.common import MockConfigEntry

@@ -25,9 +22,6 @@ async def test_chat(
     hass: HomeAssistant,
     mock_config_entry: MockConfigEntry,
     mock_init_component,
-    area_registry: ar.AreaRegistry,
-    device_registry: dr.DeviceRegistry,
-    entity_registry: er.EntityRegistry,
     agent_id: str,
 ) -> None:
     """Test that the chat function is called with the appropriate arguments."""
@@ -35,48 +29,8 @@ async def test_chat(
     if agent_id is None:
         agent_id = mock_config_entry.entry_id

-    # Create some areas, devices, and entities
-    area_kitchen = area_registry.async_get_or_create("kitchen_id")
-    area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen")
-    area_bedroom = area_registry.async_get_or_create("bedroom_id")
-    area_bedroom = area_registry.async_update(area_bedroom.id, name="bedroom")
-    area_office = area_registry.async_get_or_create("office_id")
-    area_office = area_registry.async_update(area_office.id, name="office")
-
-    entry = MockConfigEntry()
-    entry.add_to_hass(hass)
-    kitchen_device = device_registry.async_get_or_create(
-        config_entry_id=entry.entry_id,
-        connections=set(),
-        identifiers={("demo", "id-1234")},
-    )
-    device_registry.async_update_device(kitchen_device.id, area_id=area_kitchen.id)
-
-    kitchen_light = entity_registry.async_get_or_create("light", "demo", "1234")
-    kitchen_light = entity_registry.async_update_entity(
-        kitchen_light.entity_id, device_id=kitchen_device.id
-    )
-    hass.states.async_set(
-        kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
-    )
-
-    bedroom_light = entity_registry.async_get_or_create("light", "demo", "5678")
-    bedroom_light = entity_registry.async_update_entity(
-        bedroom_light.entity_id, area_id=area_bedroom.id
-    )
-    hass.states.async_set(
-        bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"}
-    )
-
-    # Hide the office light
-    office_light = entity_registry.async_get_or_create("light", "demo", "ABCD")
-    office_light = entity_registry.async_update_entity(
-        office_light.entity_id, area_id=area_office.id
-    )
-    hass.states.async_set(
-        office_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "office light"}
-    )
-    async_expose_entity(hass, conversation.DOMAIN, office_light.entity_id, False)
-
     with patch(
         "ollama.AsyncClient.chat",
@@ -100,12 +54,6 @@ async def test_chat(
             Message({"role": "user", "content": "test message"}),
         ]

-        # Verify only exposed devices/areas are in prompt
-        assert "kitchen light" in prompt
-        assert "bedroom light" in prompt
-        assert "office light" not in prompt
-        assert "office" not in prompt
-
         assert (
             result.response.response_type == intent.IntentResponseType.ACTION_DONE
         ), result
@@ -122,7 +70,232 @@ async def test_chat(
         ]
         # AGENT_DETAIL event contains the raw prompt passed to the model
         detail_event = trace_events[1]
-        assert "The current time is" in detail_event["data"]["messages"][0]["content"]
+        assert "Current time is" in detail_event["data"]["messages"][0]["content"]
+
+
+async def test_template_variables(
+    hass: HomeAssistant, mock_config_entry: MockConfigEntry
+) -> None:
+    """Test that template variables work."""
+    context = Context(user_id="12345")
+    mock_user = Mock()
+    mock_user.id = "12345"
+    mock_user.name = "Test User"
+
+    hass.config_entries.async_update_entry(
+        mock_config_entry,
+        options={
+            "prompt": (
+                "The user name is {{ user_name }}. "
+                "The user id is {{ llm_context.context.user_id }}."
+            ),
+        },
+    )
+    with (
+        patch("ollama.AsyncClient.list"),
+        patch(
+            "ollama.AsyncClient.chat",
+            return_value={"message": {"role": "assistant", "content": "test response"}},
+        ) as mock_chat,
+        patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
+    ):
+        await hass.config_entries.async_setup(mock_config_entry.entry_id)
+        await hass.async_block_till_done()
+        result = await conversation.async_converse(
+            hass, "hello", None, context, agent_id=mock_config_entry.entry_id
+        )
+
+    assert (
+        result.response.response_type == intent.IntentResponseType.ACTION_DONE
+    ), result
+
+    args = mock_chat.call_args.kwargs
+    prompt = args["messages"][0]["content"]
+
+    assert "The user name is Test User." in prompt
+    assert "The user id is 12345." in prompt
+
+
+@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
+async def test_function_call(
+    mock_get_tools,
+    hass: HomeAssistant,
+    mock_config_entry_with_assist: MockConfigEntry,
+    mock_init_component,
+) -> None:
+    """Test function call from the assistant."""
+    agent_id = mock_config_entry_with_assist.entry_id
+    context = Context()
+
+    mock_tool = AsyncMock()
+    mock_tool.name = "test_tool"
+    mock_tool.description = "Test function"
+    mock_tool.parameters = vol.Schema(
+        {vol.Optional("param1", description="Test parameters"): str}
+    )
+    mock_tool.async_call.return_value = "Test response"
+
+    mock_get_tools.return_value = [mock_tool]
+
+    def completion_result(*args, messages, **kwargs):
+        for message in messages:
+            if message["role"] == "tool":
+                return {
+                    "message": {
+                        "role": "assistant",
+                        "content": "I have successfully called the function",
+                    }
+                }
+
+        return {
+            "message": {
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "function": {
+                            "name": "test_tool",
+                            "arguments": {"param1": "test_value"},
+                        }
+                    }
+                ],
+            }
+        }
+
+    with patch(
+        "ollama.AsyncClient.chat",
+        side_effect=completion_result,
+    ) as mock_chat:
+        result = await conversation.async_converse(
+            hass,
+            "Please call the test function",
+            None,
+            context,
+            agent_id=agent_id,
+        )
+
+    assert mock_chat.call_count == 2
+    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
+    assert (
+        result.response.speech["plain"]["speech"]
+        == "I have successfully called the function"
+    )
+    mock_tool.async_call.assert_awaited_once_with(
+        hass,
+        llm.ToolInput(
+            tool_name="test_tool",
+            tool_args={"param1": "test_value"},
+        ),
+        llm.LLMContext(
+            platform="ollama",
+            context=context,
+            user_prompt="Please call the test function",
+            language="en",
+            assistant="conversation",
+            device_id=None,
+        ),
+    )
+
+
+@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
+async def test_function_exception(
+    mock_get_tools,
+    hass: HomeAssistant,
+    mock_config_entry_with_assist: MockConfigEntry,
+    mock_init_component,
+) -> None:
+    """Test function call with exception."""
+    agent_id = mock_config_entry_with_assist.entry_id
+    context = Context()
+
+    mock_tool = AsyncMock()
+    mock_tool.name = "test_tool"
+    mock_tool.description = "Test function"
+    mock_tool.parameters = vol.Schema(
+        {vol.Optional("param1", description="Test parameters"): str}
+    )
+    mock_tool.async_call.side_effect = HomeAssistantError("Test tool exception")
+
+    mock_get_tools.return_value = [mock_tool]
+
+    def completion_result(*args, messages, **kwargs):
+        for message in messages:
+            if message["role"] == "tool":
+                return {
+                    "message": {
+                        "role": "assistant",
+                        "content": "There was an error calling the function",
+                    }
+                }
+
+        return {
+            "message": {
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "function": {
+                            "name": "test_tool",
+                            "arguments": {"param1": "test_value"},
+                        }
+                    }
+                ],
+            }
+        }
+
+    with patch(
+        "ollama.AsyncClient.chat",
+        side_effect=completion_result,
+    ) as mock_chat:
+        result = await conversation.async_converse(
+            hass,
+            "Please call the test function",
+            None,
+            context,
+            agent_id=agent_id,
+        )
+
+    assert mock_chat.call_count == 2
+    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
+    assert (
+        result.response.speech["plain"]["speech"]
+        == "There was an error calling the function"
+    )
+    mock_tool.async_call.assert_awaited_once_with(
+        hass,
+        llm.ToolInput(
+            tool_name="test_tool",
+            tool_args={"param1": "test_value"},
+        ),
+        llm.LLMContext(
+            platform="ollama",
+            context=context,
+            user_prompt="Please call the test function",
+            language="en",
+            assistant="conversation",
+            device_id=None,
+        ),
+    )
+
+
+async def test_unknown_hass_api(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    snapshot: SnapshotAssertion,
+    mock_init_component,
+) -> None:
+    """Test when we reference an API that no longer exists."""
+    hass.config_entries.async_update_entry(
+        mock_config_entry,
+        options={
+            **mock_config_entry.options,
+            CONF_LLM_HASS_API: "non-existing",
+        },
+    )
+
+    result = await conversation.async_converse(
+        hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
+    )
+
+    assert result == snapshot


 async def test_message_history_trimming(