Add Ollama Conversation Agent Entity (#116363)

* Add ConversationEntity to Ollama integration

* Add assist_pipeline dependencies
This commit is contained in:
Allen Porter 2024-04-29 07:15:46 -07:00 committed by GitHub
parent eced3b0f57
commit f1dda8ef63
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 617 additions and 557 deletions

View file

@ -4,40 +4,17 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import time
from typing import Literal
import httpx import httpx
import ollama import ollama
from homeassistant.components import conversation
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_URL, MATCH_ALL from homeassistant.const import CONF_URL, Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady, TemplateError from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import ( from homeassistant.helpers import config_validation as cv
area_registry as ar,
config_validation as cv,
device_registry as dr,
entity_registry as er,
intent,
template,
)
from homeassistant.util import ulid
from .const import ( from .const import CONF_MAX_HISTORY, CONF_MODEL, CONF_PROMPT, DEFAULT_TIMEOUT, DOMAIN
CONF_MAX_HISTORY,
CONF_MODEL,
CONF_PROMPT,
DEFAULT_MAX_HISTORY,
DEFAULT_PROMPT,
DEFAULT_TIMEOUT,
DOMAIN,
KEEP_ALIVE_FOREVER,
MAX_HISTORY_SECONDS,
)
from .models import ExposedEntity, MessageHistory, MessageRole
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -46,11 +23,11 @@ __all__ = [
"CONF_PROMPT", "CONF_PROMPT",
"CONF_MODEL", "CONF_MODEL",
"CONF_MAX_HISTORY", "CONF_MAX_HISTORY",
"MAX_HISTORY_NO_LIMIT",
"DOMAIN", "DOMAIN",
] ]
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (Platform.CONVERSATION,)
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
@ -65,202 +42,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client
conversation.async_set_agent(hass, entry, OllamaAgent(hass, entry)) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
return True return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload Ollama.""" """Unload Ollama."""
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
return False
hass.data[DOMAIN].pop(entry.entry_id) hass.data[DOMAIN].pop(entry.entry_id)
conversation.async_unset_agent(hass, entry)
return True return True
class OllamaAgent(conversation.AbstractConversationAgent):
"""Ollama conversation agent."""
def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Initialize the agent."""
self.hass = hass
self.entry = entry
# conversation id -> message history
self._history: dict[str, MessageHistory] = {}
@property
def supported_languages(self) -> list[str] | Literal["*"]:
"""Return a list of supported languages."""
return MATCH_ALL
async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
settings = {**self.entry.data, **self.entry.options}
client = self.hass.data[DOMAIN][self.entry.entry_id]
conversation_id = user_input.conversation_id or ulid.ulid_now()
model = settings[CONF_MODEL]
# Look up message history
message_history: MessageHistory | None = None
message_history = self._history.get(conversation_id)
if message_history is None:
# New history
#
# Render prompt and error out early if there's a problem
raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT)
try:
prompt = self._generate_prompt(raw_prompt)
_LOGGER.debug("Prompt: %s", prompt)
except TemplateError as err:
_LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem generating my prompt: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
message_history = MessageHistory(
timestamp=time.monotonic(),
messages=[
ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
],
)
self._history[conversation_id] = message_history
else:
# Bump timestamp so this conversation won't get cleaned up
message_history.timestamp = time.monotonic()
# Clean up old histories
self._prune_old_histories()
# Trim this message history to keep a maximum number of *user* messages
max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
self._trim_history(message_history, max_messages)
# Add new user message
message_history.messages.append(
ollama.Message(role=MessageRole.USER.value, content=user_input.text)
)
# Get response
try:
response = await client.chat(
model=model,
# Make a copy of the messages because we mutate the list later
messages=list(message_history.messages),
stream=False,
keep_alive=KEEP_ALIVE_FOREVER,
)
except (ollama.RequestError, ollama.ResponseError) as err:
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to the Ollama server: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
response_message = response["message"]
message_history.messages.append(
ollama.Message(
role=response_message["role"], content=response_message["content"]
)
)
# Create intent response
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response_message["content"])
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
def _prune_old_histories(self) -> None:
"""Remove old message histories."""
now = time.monotonic()
self._history = {
conversation_id: message_history
for conversation_id, message_history in self._history.items()
if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS
}
def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
"""Trims excess messages from a single history."""
if max_messages < 1:
# Keep all messages
return
if message_history.num_user_messages >= max_messages:
# Trim history but keep system prompt (first message).
# Every other message should be an assistant message, so keep 2x
# message objects.
num_keep = 2 * max_messages
drop_index = len(message_history.messages) - num_keep
message_history.messages = [
message_history.messages[0]
] + message_history.messages[drop_index:]
def _generate_prompt(self, raw_prompt: str) -> str:
"""Generate a prompt for the user."""
return template.Template(raw_prompt, self.hass).async_render(
{
"ha_name": self.hass.config.location_name,
"ha_language": self.hass.config.language,
"exposed_entities": self._get_exposed_entities(),
},
parse_result=False,
)
def _get_exposed_entities(self) -> list[ExposedEntity]:
"""Get state list of exposed entities."""
area_registry = ar.async_get(self.hass)
entity_registry = er.async_get(self.hass)
device_registry = dr.async_get(self.hass)
exposed_entities = []
exposed_states = [
state
for state in self.hass.states.async_all()
if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
]
for state in exposed_states:
entity = entity_registry.async_get(state.entity_id)
names = [state.name]
area_names = []
if entity is not None:
# Add aliases
names.extend(entity.aliases)
if entity.area_id and (
area := area_registry.async_get_area(entity.area_id)
):
# Entity is in area
area_names.append(area.name)
area_names.extend(area.aliases)
elif entity.device_id and (
device := device_registry.async_get(entity.device_id)
):
# Check device area
if device.area_id and (
area := area_registry.async_get_area(device.area_id)
):
area_names.append(area.name)
area_names.extend(area.aliases)
exposed_entities.append(
ExposedEntity(
entity_id=state.entity_id,
state=state,
names=names,
area_names=area_names,
)
)
return exposed_entities

View file

@ -0,0 +1,258 @@
"""The conversation platform for the Ollama integration."""
from __future__ import annotations
import logging
import time
from typing import Literal
import ollama
from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
intent,
template,
)
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid
from .const import (
CONF_MAX_HISTORY,
CONF_MODEL,
CONF_PROMPT,
DEFAULT_MAX_HISTORY,
DEFAULT_PROMPT,
DOMAIN,
KEEP_ALIVE_FOREVER,
MAX_HISTORY_SECONDS,
)
from .models import ExposedEntity, MessageHistory, MessageRole
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Create the Ollama conversation entity for a config entry.

    A single OllamaConversationEntity is built from the entry and handed to
    the platform's add-entities callback.
    """
    async_add_entities([OllamaConversationEntity(hass, config_entry)])
class OllamaConversationEntity(
    conversation.ConversationEntity, conversation.AbstractConversationAgent
):
    """Conversation entity that answers user input with an Ollama chat model.

    One message history is kept in memory per conversation id. On every
    request, histories older than MAX_HISTORY_SECONDS are pruned and the
    active history is trimmed according to the CONF_MAX_HISTORY setting.
    """

    _attr_has_entity_name = True

    def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
        """Store the config entry and start with an empty history table."""
        self.hass = hass
        self.entry = entry

        # conversation id -> message history
        self._history: dict[str, MessageHistory] = {}
        self._attr_name = entry.title
        self._attr_unique_id = entry.entry_id

    async def async_added_to_hass(self) -> None:
        """Register this entity as the conversation agent for its entry."""
        await super().async_added_to_hass()
        # NOTE(review): presumably re-points pipelines that referenced the
        # config entry id at this entity's id -- confirm against
        # assist_pipeline.async_migrate_engine.
        assist_pipeline.async_migrate_engine(
            self.hass, "conversation", self.entry.entry_id, self.entity_id
        )
        conversation.async_set_agent(self.hass, self.entry, self)

    async def async_will_remove_from_hass(self) -> None:
        """Unregister the agent before the entity is removed."""
        conversation.async_unset_agent(self.hass, self.entry)
        await super().async_will_remove_from_hass()

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return a list of supported languages."""
        return MATCH_ALL

    async def async_process(
        self, user_input: conversation.ConversationInput
    ) -> conversation.ConversationResult:
        """Send the user's sentence to the Ollama server and return its reply.

        A failure to render the prompt template, or an Ollama request/response
        error, is returned as an error intent response instead of being raised.
        """
        # Options override the values stored in the entry data.
        settings = {**self.entry.data, **self.entry.options}

        client = self.hass.data[DOMAIN][self.entry.entry_id]
        conversation_id = user_input.conversation_id or ulid.ulid_now()
        model = settings[CONF_MODEL]

        # Look up message history
        message_history = self._history.get(conversation_id)
        if message_history is None:
            # New history
            #
            # Render prompt and error out early if there's a problem
            raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT)
            try:
                prompt = self._generate_prompt(raw_prompt)
            except TemplateError as err:
                _LOGGER.error("Error rendering prompt: %s", err)
                return self._error_result(
                    user_input,
                    conversation_id,
                    f"Sorry, I had a problem generating my prompt: {err}",
                )
            _LOGGER.debug("Prompt: %s", prompt)

            message_history = MessageHistory(
                timestamp=time.monotonic(),
                messages=[
                    ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
                ],
            )
            self._history[conversation_id] = message_history
        else:
            # Bump timestamp so this conversation won't get cleaned up
            message_history.timestamp = time.monotonic()

        # Clean up old histories
        self._prune_old_histories()

        # Trim this message history to keep a maximum number of *user* messages
        max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
        self._trim_history(message_history, max_messages)

        # Add new user message
        message_history.messages.append(
            ollama.Message(role=MessageRole.USER.value, content=user_input.text)
        )

        # Get response
        try:
            response = await client.chat(
                model=model,
                # Make a copy of the messages because we mutate the list later
                messages=list(message_history.messages),
                stream=False,
                keep_alive=KEEP_ALIVE_FOREVER,
            )
        except (ollama.RequestError, ollama.ResponseError) as err:
            _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
            return self._error_result(
                user_input,
                conversation_id,
                f"Sorry, I had a problem talking to the Ollama server: {err}",
            )

        response_message = response["message"]
        message_history.messages.append(
            ollama.Message(
                role=response_message["role"], content=response_message["content"]
            )
        )

        # Create intent response
        intent_response = intent.IntentResponse(language=user_input.language)
        intent_response.async_set_speech(response_message["content"])
        return conversation.ConversationResult(
            response=intent_response, conversation_id=conversation_id
        )

    @staticmethod
    def _error_result(
        user_input: conversation.ConversationInput,
        conversation_id: str,
        message: str,
    ) -> conversation.ConversationResult:
        """Build a conversation result carrying an UNKNOWN error with *message*."""
        intent_response = intent.IntentResponse(language=user_input.language)
        intent_response.async_set_error(
            intent.IntentResponseErrorCode.UNKNOWN,
            message,
        )
        return conversation.ConversationResult(
            response=intent_response, conversation_id=conversation_id
        )

    def _prune_old_histories(self) -> None:
        """Remove message histories not touched within MAX_HISTORY_SECONDS."""
        now = time.monotonic()
        self._history = {
            conversation_id: message_history
            for conversation_id, message_history in self._history.items()
            if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS
        }

    def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
        """Trim excess messages from a single history, keeping the system prompt.

        A max_messages below 1 means "no limit" and leaves the history as is.
        """
        if max_messages < 1:
            # Keep all messages
            return

        if message_history.num_user_messages >= max_messages:
            # Trim history but keep system prompt (first message).
            # Every other message should be an assistant message, so keep 2x
            # message objects.
            num_keep = 2 * max_messages
            # A failed chat call leaves a user message without an assistant
            # reply, so the history can be shorter than the 2x assumption.
            # Clamp to index 1: an index of 0 would duplicate the system
            # prompt, and a negative index would wrap around from the end.
            drop_index = max(1, len(message_history.messages) - num_keep)
            message_history.messages = [
                message_history.messages[0],
                *message_history.messages[drop_index:],
            ]

    def _generate_prompt(self, raw_prompt: str) -> str:
        """Render the prompt template with the instance name, language and entities."""
        return template.Template(raw_prompt, self.hass).async_render(
            {
                "ha_name": self.hass.config.location_name,
                "ha_language": self.hass.config.language,
                "exposed_entities": self._get_exposed_entities(),
            },
            parse_result=False,
        )

    def _get_exposed_entities(self) -> list[ExposedEntity]:
        """Collect every state exposed to conversation with its names and areas."""
        area_registry = ar.async_get(self.hass)
        entity_registry = er.async_get(self.hass)
        device_registry = dr.async_get(self.hass)

        exposed_entities = []
        exposed_states = [
            state
            for state in self.hass.states.async_all()
            if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
        ]

        for state in exposed_states:
            entity = entity_registry.async_get(state.entity_id)
            names = [state.name]
            area_names = []

            if entity is not None:
                # Add aliases
                names.extend(entity.aliases)
                if entity.area_id and (
                    area := area_registry.async_get_area(entity.area_id)
                ):
                    # Entity is in area
                    area_names.append(area.name)
                    area_names.extend(area.aliases)
                elif entity.device_id and (
                    device := device_registry.async_get(entity.device_id)
                ):
                    # Check device area
                    if device.area_id and (
                        area := area_registry.async_get_area(device.area_id)
                    ):
                        area_names.append(area.name)
                        area_names.extend(area.aliases)

            exposed_entities.append(
                ExposedEntity(
                    entity_id=state.entity_id,
                    state=state,
                    names=names,
                    area_names=area_names,
                )
            )

        return exposed_entities

View file

@ -1,6 +1,7 @@
{ {
"domain": "ollama", "domain": "ollama",
"name": "Ollama", "name": "Ollama",
"after_dependencies": ["assist_pipeline"],
"codeowners": ["@synesthesiam"], "codeowners": ["@synesthesiam"],
"config_flow": true, "config_flow": true,
"dependencies": ["conversation"], "dependencies": ["conversation"],

View file

@ -0,0 +1,347 @@
"""Tests for the Ollama integration."""
from unittest.mock import AsyncMock, patch
from ollama import Message, ResponseError
import pytest
from homeassistant.components import conversation, ollama
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
intent,
)
from tests.common import MockConfigEntry
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
async def test_chat(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    area_registry: ar.AreaRegistry,
    device_registry: dr.DeviceRegistry,
    entity_registry: er.EntityRegistry,
    agent_id: str | None,
) -> None:
    """Test that the chat function is called with the appropriate arguments.

    Runs once addressing the agent by config entry id (agent_id=None is
    replaced below) and once by the entity id "conversation.mock_title".
    """
    if agent_id is None:
        agent_id = mock_config_entry.entry_id

    # Create some areas, devices, and entities
    area_kitchen = area_registry.async_get_or_create("kitchen_id")
    area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen")
    area_bedroom = area_registry.async_get_or_create("bedroom_id")
    area_bedroom = area_registry.async_update(area_bedroom.id, name="bedroom")
    area_office = area_registry.async_get_or_create("office_id")
    area_office = area_registry.async_update(area_office.id, name="office")

    entry = MockConfigEntry()
    entry.add_to_hass(hass)

    # The kitchen light gets its area through its device ...
    kitchen_device = device_registry.async_get_or_create(
        config_entry_id=entry.entry_id,
        connections=set(),
        identifiers={("demo", "id-1234")},
    )
    device_registry.async_update_device(kitchen_device.id, area_id=area_kitchen.id)

    kitchen_light = entity_registry.async_get_or_create("light", "demo", "1234")
    kitchen_light = entity_registry.async_update_entity(
        kitchen_light.entity_id, device_id=kitchen_device.id
    )
    hass.states.async_set(
        kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
    )

    # ... while the bedroom light is assigned to its area directly.
    bedroom_light = entity_registry.async_get_or_create("light", "demo", "5678")
    bedroom_light = entity_registry.async_update_entity(
        bedroom_light.entity_id, area_id=area_bedroom.id
    )
    hass.states.async_set(
        bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"}
    )

    # Hide the office light
    office_light = entity_registry.async_get_or_create("light", "demo", "ABCD")
    office_light = entity_registry.async_update_entity(
        office_light.entity_id, area_id=area_office.id
    )
    hass.states.async_set(
        office_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "office light"}
    )
    async_expose_entity(hass, conversation.DOMAIN, office_light.entity_id, False)

    with patch(
        "ollama.AsyncClient.chat",
        return_value={"message": {"role": "assistant", "content": "test response"}},
    ) as mock_chat:
        result = await conversation.async_converse(
            hass,
            "test message",
            None,
            Context(),
            agent_id=agent_id,
        )

        assert mock_chat.call_count == 1
        args = mock_chat.call_args.kwargs
        prompt = args["messages"][0]["content"]

        assert args["model"] == "test model"
        assert args["messages"] == [
            Message({"role": "system", "content": prompt}),
            Message({"role": "user", "content": "test message"}),
        ]

        # Verify only exposed devices/areas are in prompt
        assert "kitchen light" in prompt
        assert "bedroom light" in prompt
        assert "office light" not in prompt
        assert "office" not in prompt

    assert (
        result.response.response_type == intent.IntentResponseType.ACTION_DONE
    ), result
    assert result.response.speech["plain"]["speech"] == "test response"
async def test_message_history_trimming(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Test that a single message history is trimmed according to the config.

    Five messages are sent in one conversation and the exact message list
    passed to each chat call is compared against the expected history.
    """
    response_idx = 0

    def response(*args, **kwargs) -> dict:
        """Return a numbered assistant reply so each turn is distinguishable."""
        nonlocal response_idx
        response_idx += 1
        return {"message": {"role": "assistant", "content": f"response {response_idx}"}}

    def sent(call) -> list[tuple[str, str]]:
        """Reduce one recorded chat call to (role, content) pairs."""
        return [(msg["role"], msg["content"]) for msg in call.kwargs["messages"]]

    with patch(
        "ollama.AsyncClient.chat",
        side_effect=response,
    ) as mock_chat:
        # mock_init_component sets "max_history" to 2
        for i in range(5):
            result = await conversation.async_converse(
                hass,
                f"message {i+1}",
                conversation_id="1234",
                context=Context(),
                agent_id=mock_config_entry.entry_id,
            )
            assert (
                result.response.response_type == intent.IntentResponseType.ACTION_DONE
            ), result

    assert mock_chat.call_count == 5
    calls = mock_chat.call_args_list
    prompt = calls[0].kwargs["messages"][0]["content"]

    # system + user-1
    assert sent(calls[0]) == [
        ("system", prompt),
        ("user", "message 1"),
    ]

    # Full history
    # system + user-1 + assistant-1 + user-2
    assert sent(calls[1]) == [
        ("system", prompt),
        ("user", "message 1"),
        ("assistant", "response 1"),
        ("user", "message 2"),
    ]

    # Full history
    # system + user-1 + assistant-1 + user-2 + assistant-2 + user-3
    assert sent(calls[2]) == [
        ("system", prompt),
        ("user", "message 1"),
        ("assistant", "response 1"),
        ("user", "message 2"),
        ("assistant", "response 2"),
        ("user", "message 3"),
    ]

    # Trimmed down to two user messages.
    # system + user-2 + assistant-2 + user-3 + assistant-3 + user-4
    assert sent(calls[3]) == [
        ("system", prompt),
        ("user", "message 2"),
        ("assistant", "response 2"),
        ("user", "message 3"),
        ("assistant", "response 3"),
        ("user", "message 4"),
    ]

    # Trimmed down to two user messages.
    # system + user-3 + assistant-3 + user-4 + assistant-4 + user-5
    # (The full-list comparison also pins the length of this fifth call.)
    assert sent(calls[4]) == [
        ("system", prompt),
        ("user", "message 3"),
        ("assistant", "response 3"),
        ("user", "message 4"),
        ("assistant", "response 4"),
        ("user", "message 5"),
    ]
async def test_message_history_pruning(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Histories whose timestamp is too old are dropped on the next request."""
    canned_reply = {"message": {"role": "assistant", "content": "test response"}}

    with patch("ollama.AsyncClient.chat", return_value=canned_reply):
        # Start three separate conversations (no conversation id supplied)
        started_ids: list[str] = []
        for number in range(1, 4):
            result = await conversation.async_converse(
                hass,
                f"message {number}",
                conversation_id=None,
                context=Context(),
                agent_id=mock_config_entry.entry_id,
            )
            assert (
                result.response.response_type == intent.IntentResponseType.ACTION_DONE
            ), result
            assert isinstance(result.conversation_id, str)
            started_ids.append(result.conversation_id)

        manager = conversation.get_agent_manager(hass)
        agent = manager.async_get_agent(mock_config_entry.entry_id)
        assert len(agent._history) == 3
        assert agent._history.keys() == set(started_ids)

        # Age the first two histories by 2 hours so the next request
        # prunes them.
        two_hours = 2 * 60 * 60
        for stale_id in started_ids[:2]:
            agent._history[stale_id].timestamp -= two_hours

        # A fourth conversation triggers the pruning pass
        result = await conversation.async_converse(
            hass,
            "test message",
            conversation_id=None,
            context=Context(),
            agent_id=mock_config_entry.entry_id,
        )
        assert (
            result.response.response_type == intent.IntentResponseType.ACTION_DONE
        ), result

        # Only the untouched third history and the new one are left
        assert len(agent._history) == 2
        assert started_ids[-1] in agent._history
        assert result.conversation_id in agent._history
async def test_message_history_unlimited(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """With max_history = 0 every user message stays in the history."""
    conversation_id = "1234"
    canned_reply = {"message": {"role": "assistant", "content": "test response"}}
    unlimited_options = {ollama.CONF_MAX_HISTORY: 0}

    with (
        patch("ollama.AsyncClient.chat", return_value=canned_reply),
        patch.object(mock_config_entry, "options", unlimited_options),
    ):
        for number in range(1, 101):
            result = await conversation.async_converse(
                hass,
                f"message {number}",
                conversation_id=conversation_id,
                context=Context(),
                agent_id=mock_config_entry.entry_id,
            )
            assert (
                result.response.response_type == intent.IntentResponseType.ACTION_DONE
            ), result

        manager = conversation.get_agent_manager(hass)
        agent = manager.async_get_agent(mock_config_entry.entry_id)

        # One conversation, holding all 100 user messages
        assert len(agent._history) == 1
        assert conversation_id in agent._history
        assert agent._history[conversation_id].num_user_messages == 100
async def test_error_handling(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """An Ollama ResponseError during chat yields an "unknown" error response."""
    failing_chat = patch(
        "ollama.AsyncClient.chat",
        new_callable=AsyncMock,
        side_effect=ResponseError("test error"),
    )
    with failing_chat:
        result = await conversation.async_converse(
            hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
        )

    outcome = result.response
    assert outcome.response_type == intent.IntentResponseType.ERROR, result
    assert outcome.error_code == "unknown", result
async def test_template_error(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
    """A prompt template that fails to render yields an "unknown" error response."""
    # The {% if %} block is never closed, so rendering this prompt fails.
    broken_prompt = "talk like a {% if True %}smarthome{% else %}pirate please."
    hass.config_entries.async_update_entry(
        mock_config_entry,
        options={"prompt": broken_prompt},
    )

    with patch("ollama.AsyncClient.list"):
        await hass.config_entries.async_setup(mock_config_entry.entry_id)
        await hass.async_block_till_done()
        result = await conversation.async_converse(
            hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
        )

    outcome = result.response
    assert outcome.response_type == intent.IntentResponseType.ERROR, result
    assert outcome.error_code == "unknown", result
async def test_conversation_agent(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
) -> None:
    """The agent registered for the config entry supports every language."""
    manager = conversation.get_agent_manager(hass)
    agent = manager.async_get_agent(mock_config_entry.entry_id)
    assert agent.supported_languages == MATCH_ALL

View file

@ -1,351 +1,17 @@
"""Tests for the Ollama integration.""" """Tests for the Ollama integration."""
from unittest.mock import AsyncMock, patch from unittest.mock import patch
from httpx import ConnectError from httpx import ConnectError
from ollama import Message, ResponseError
import pytest import pytest
from homeassistant.components import conversation, ollama from homeassistant.components import ollama
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity from homeassistant.core import HomeAssistant
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
intent,
)
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
async def test_chat(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test that the chat function is called with the appropriate arguments."""
# Create some areas, devices, and entities
area_kitchen = area_registry.async_get_or_create("kitchen_id")
area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen")
area_bedroom = area_registry.async_get_or_create("bedroom_id")
area_bedroom = area_registry.async_update(area_bedroom.id, name="bedroom")
area_office = area_registry.async_get_or_create("office_id")
area_office = area_registry.async_update(area_office.id, name="office")
entry = MockConfigEntry()
entry.add_to_hass(hass)
kitchen_device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections=set(),
identifiers={("demo", "id-1234")},
)
device_registry.async_update_device(kitchen_device.id, area_id=area_kitchen.id)
kitchen_light = entity_registry.async_get_or_create("light", "demo", "1234")
kitchen_light = entity_registry.async_update_entity(
kitchen_light.entity_id, device_id=kitchen_device.id
)
hass.states.async_set(
kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
)
bedroom_light = entity_registry.async_get_or_create("light", "demo", "5678")
bedroom_light = entity_registry.async_update_entity(
bedroom_light.entity_id, area_id=area_bedroom.id
)
hass.states.async_set(
bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"}
)
# Hide the office light
office_light = entity_registry.async_get_or_create("light", "demo", "ABCD")
office_light = entity_registry.async_update_entity(
office_light.entity_id, area_id=area_office.id
)
hass.states.async_set(
office_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "office light"}
)
async_expose_entity(hass, conversation.DOMAIN, office_light.entity_id, False)
with patch(
"ollama.AsyncClient.chat",
return_value={"message": {"role": "assistant", "content": "test response"}},
) as mock_chat:
result = await conversation.async_converse(
hass,
"test message",
None,
Context(),
agent_id=mock_config_entry.entry_id,
)
assert mock_chat.call_count == 1
args = mock_chat.call_args.kwargs
prompt = args["messages"][0]["content"]
assert args["model"] == "test model"
assert args["messages"] == [
Message({"role": "system", "content": prompt}),
Message({"role": "user", "content": "test message"}),
]
# Verify only exposed devices/areas are in prompt
assert "kitchen light" in prompt
assert "bedroom light" in prompt
assert "office light" not in prompt
assert "office" not in prompt
assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), result
assert result.response.speech["plain"]["speech"] == "test response"
async def test_message_history_trimming(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Test that a single message history is trimmed according to the config.

    Five messages are sent within one conversation and the ``messages``
    argument of every mocked chat call is inspected. The first three calls
    are expected to carry the full history; the last two are expected to
    carry only the two most recent exchanges plus the new user message.
    """
    response_idx = 0

    def response(*args, **kwargs) -> dict:
        """Return a numbered assistant reply so each turn is distinguishable."""
        nonlocal response_idx
        response_idx += 1
        return {"message": {"role": "assistant", "content": f"response {response_idx}"}}

    with patch(
        "ollama.AsyncClient.chat",
        side_effect=response,
    ) as mock_chat:
        # mock_init_component sets "max_history" to 2
        for i in range(5):
            result = await conversation.async_converse(
                hass,
                f"message {i+1}",
                conversation_id="1234",
                context=Context(),
                agent_id=mock_config_entry.entry_id,
            )
            assert (
                result.response.response_type == intent.IntentResponseType.ACTION_DONE
            ), result

    assert mock_chat.call_count == 5
    args = mock_chat.call_args_list
    prompt = args[0].kwargs["messages"][0]["content"]

    def assert_messages(call_idx: int, expected: list[tuple[str, str]]) -> None:
        """Check one chat call: system prompt first, then the expected turns.

        ``expected`` lists the (role, content) pairs that must follow the
        system message. The length check is derived from it, so the call index
        and the expected size cannot drift apart.
        """
        messages = args[call_idx].kwargs["messages"]
        assert len(messages) == len(expected) + 1
        assert messages[0]["role"] == "system"
        assert messages[0]["content"] == prompt
        for message, (role, content) in zip(messages[1:], expected):
            assert message["role"] == role
            assert message["content"] == content

    # system + user-1
    assert len(args[0].kwargs["messages"]) == 2
    assert args[0].kwargs["messages"][1]["content"] == "message 1"

    # Full history
    # system + user-1 + assistant-1 + user-2
    assert_messages(
        1,
        [
            ("user", "message 1"),
            ("assistant", "response 1"),
            ("user", "message 2"),
        ],
    )

    # Full history
    # system + user-1 + assistant-1 + user-2 + assistant-2 + user-3
    assert_messages(
        2,
        [
            ("user", "message 1"),
            ("assistant", "response 1"),
            ("user", "message 2"),
            ("assistant", "response 2"),
            ("user", "message 3"),
        ],
    )

    # Trimmed down to two user messages.
    # system + user-2 + assistant-2 + user-3 + assistant-3 + user-4
    assert_messages(
        3,
        [
            ("user", "message 2"),
            ("assistant", "response 2"),
            ("user", "message 3"),
            ("assistant", "response 3"),
            ("user", "message 4"),
        ],
    )

    # Trimmed down to two user messages.
    # system + user-3 + assistant-3 + user-4 + assistant-4 + user-5
    # (the length of this fifth call was previously never checked: the old
    # assertion re-tested call index 3 instead of 4)
    assert_messages(
        4,
        [
            ("user", "message 3"),
            ("assistant", "response 3"),
            ("user", "message 4"),
            ("assistant", "response 4"),
            ("user", "message 5"),
        ],
    )
async def test_message_history_pruning(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Test that old message histories are pruned.

    Three conversations are started, two of them are back-dated by two hours,
    and a fourth conversation is then started. Afterwards only the untouched
    third history and the newly created one are expected to remain.
    """
    canned_reply = {"message": {"role": "assistant", "content": "test response"}}

    with patch("ollama.AsyncClient.chat", return_value=canned_reply):
        # Start three separate conversations (no conversation id supplied).
        conversation_ids: list[str] = []
        for number in range(1, 4):
            result = await conversation.async_converse(
                hass,
                f"message {number}",
                conversation_id=None,
                context=Context(),
                agent_id=mock_config_entry.entry_id,
            )
            assert (
                result.response.response_type == intent.IntentResponseType.ACTION_DONE
            ), result
            assert isinstance(result.conversation_id, str)
            conversation_ids.append(result.conversation_id)

        agent_manager = conversation.get_agent_manager(hass)
        agent = agent_manager.async_get_agent(mock_config_entry.entry_id)
        assert isinstance(agent, ollama.OllamaAgent)
        assert len(agent._history) == 3
        assert agent._history.keys() == set(conversation_ids)

        # Back-date the first two histories by two hours so that the next
        # request is expected to prune them.
        for stale_id in conversation_ids[:2]:
            agent._history[stale_id].timestamp -= 2 * 60 * 60

        # A further request triggers the next cycle.
        result = await conversation.async_converse(
            hass,
            "test message",
            conversation_id=None,
            context=Context(),
            agent_id=mock_config_entry.entry_id,
        )
        assert (
            result.response.response_type == intent.IntentResponseType.ACTION_DONE
        ), result

        # Only the most recent histories should remain: the third original
        # conversation and the one just created.
        assert len(agent._history) == 2
        assert conversation_ids[-1] in agent._history
        assert result.conversation_id in agent._history
async def test_message_history_unlimited(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Test that message history is not trimmed when max_history = 0.

    One hundred messages are sent in a single conversation while the config
    entry options are patched to set max_history to 0; the stored history is
    then expected to hold all one hundred user messages.
    """
    conversation_id = "1234"
    canned_reply = {"message": {"role": "assistant", "content": "test response"}}
    chat_patch = patch("ollama.AsyncClient.chat", return_value=canned_reply)
    options_patch = patch.object(
        mock_config_entry, "options", {ollama.CONF_MAX_HISTORY: 0}
    )

    with chat_patch, options_patch:
        for number in range(1, 101):
            result = await conversation.async_converse(
                hass,
                f"message {number}",
                conversation_id=conversation_id,
                context=Context(),
                agent_id=mock_config_entry.entry_id,
            )
            assert (
                result.response.response_type == intent.IntentResponseType.ACTION_DONE
            ), result

        agent_manager = conversation.get_agent_manager(hass)
        agent = agent_manager.async_get_agent(mock_config_entry.entry_id)
        assert isinstance(agent, ollama.OllamaAgent)

        # A single conversation was used, so there is exactly one history and
        # it must still contain every user message.
        assert len(agent._history) == 1
        assert conversation_id in agent._history
        assert agent._history[conversation_id].num_user_messages == 100
async def test_error_handling(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Test error handling during converse.

    The mocked chat call raises ``ResponseError``; the conversation result is
    expected to be an error response with the "unknown" error code.
    """
    failing_chat = patch(
        "ollama.AsyncClient.chat",
        new_callable=AsyncMock,
        side_effect=ResponseError("test error"),
    )
    with failing_chat:
        result = await conversation.async_converse(
            hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
        )

    response = result.response
    assert response.response_type == intent.IntentResponseType.ERROR, result
    assert response.error_code == "unknown", result
async def test_template_error(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
    """Test that template error handling works.

    The prompt option is replaced with a template whose ``{% if %}`` block is
    never closed before the entry is set up; the conversation result is
    expected to be an error response with the "unknown" error code.
    """
    broken_options = {
        "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
    }
    hass.config_entries.async_update_entry(mock_config_entry, options=broken_options)

    # The model listing call is mocked out while the entry is set up.
    with patch("ollama.AsyncClient.list"):
        await hass.config_entries.async_setup(mock_config_entry.entry_id)
        await hass.async_block_till_done()
        result = await conversation.async_converse(
            hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
        )

    response = result.response
    assert response.response_type == intent.IntentResponseType.ERROR, result
    assert response.error_code == "unknown", result
async def test_conversation_agent(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
) -> None:
    """Test OllamaAgent.

    Looks up the agent registered under the config entry id and checks that
    its supported languages equal ``MATCH_ALL``.
    """
    agent_manager = conversation.get_agent_manager(hass)
    agent = agent_manager.async_get_agent(mock_config_entry.entry_id)
    assert agent.supported_languages == MATCH_ALL
@pytest.mark.parametrize( @pytest.mark.parametrize(
("side_effect", "error"), ("side_effect", "error"),
[ [