Add Ollama Conversation Agent Entity (#116363)
* Add ConversationEntity to OLlama integration * Add assist_pipeline dependencies
This commit is contained in:
parent
eced3b0f57
commit
f1dda8ef63
5 changed files with 617 additions and 557 deletions
|
@ -4,40 +4,17 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
import ollama
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_URL, MATCH_ALL
|
||||
from homeassistant.const import CONF_URL, Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
config_validation as cv,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
intent,
|
||||
template,
|
||||
)
|
||||
from homeassistant.util import ulid
|
||||
from homeassistant.exceptions import ConfigEntryNotReady
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
|
||||
from .const import (
|
||||
CONF_MAX_HISTORY,
|
||||
CONF_MODEL,
|
||||
CONF_PROMPT,
|
||||
DEFAULT_MAX_HISTORY,
|
||||
DEFAULT_PROMPT,
|
||||
DEFAULT_TIMEOUT,
|
||||
DOMAIN,
|
||||
KEEP_ALIVE_FOREVER,
|
||||
MAX_HISTORY_SECONDS,
|
||||
)
|
||||
from .models import ExposedEntity, MessageHistory, MessageRole
|
||||
from .const import CONF_MAX_HISTORY, CONF_MODEL, CONF_PROMPT, DEFAULT_TIMEOUT, DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -46,11 +23,11 @@ __all__ = [
|
|||
"CONF_PROMPT",
|
||||
"CONF_MODEL",
|
||||
"CONF_MAX_HISTORY",
|
||||
"MAX_HISTORY_NO_LIMIT",
|
||||
"DOMAIN",
|
||||
]
|
||||
|
||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
PLATFORMS = (Platform.CONVERSATION,)
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
|
@ -65,202 +42,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
|
||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client
|
||||
|
||||
conversation.async_set_agent(hass, entry, OllamaAgent(hass, entry))
|
||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||
return True
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload Ollama."""
|
||||
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
|
||||
return False
|
||||
hass.data[DOMAIN].pop(entry.entry_id)
|
||||
conversation.async_unset_agent(hass, entry)
|
||||
return True
|
||||
|
||||
|
||||
class OllamaAgent(conversation.AbstractConversationAgent):
|
||||
"""Ollama conversation agent."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
|
||||
"""Initialize the agent."""
|
||||
self.hass = hass
|
||||
self.entry = entry
|
||||
|
||||
# conversation id -> message history
|
||||
self._history: dict[str, MessageHistory] = {}
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str] | Literal["*"]:
|
||||
"""Return a list of supported languages."""
|
||||
return MATCH_ALL
|
||||
|
||||
async def async_process(
|
||||
self, user_input: conversation.ConversationInput
|
||||
) -> conversation.ConversationResult:
|
||||
"""Process a sentence."""
|
||||
settings = {**self.entry.data, **self.entry.options}
|
||||
|
||||
client = self.hass.data[DOMAIN][self.entry.entry_id]
|
||||
conversation_id = user_input.conversation_id or ulid.ulid_now()
|
||||
model = settings[CONF_MODEL]
|
||||
|
||||
# Look up message history
|
||||
message_history: MessageHistory | None = None
|
||||
message_history = self._history.get(conversation_id)
|
||||
if message_history is None:
|
||||
# New history
|
||||
#
|
||||
# Render prompt and error out early if there's a problem
|
||||
raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT)
|
||||
try:
|
||||
prompt = self._generate_prompt(raw_prompt)
|
||||
_LOGGER.debug("Prompt: %s", prompt)
|
||||
except TemplateError as err:
|
||||
_LOGGER.error("Error rendering prompt: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Sorry, I had a problem generating my prompt: {err}",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
message_history = MessageHistory(
|
||||
timestamp=time.monotonic(),
|
||||
messages=[
|
||||
ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
|
||||
],
|
||||
)
|
||||
self._history[conversation_id] = message_history
|
||||
else:
|
||||
# Bump timestamp so this conversation won't get cleaned up
|
||||
message_history.timestamp = time.monotonic()
|
||||
|
||||
# Clean up old histories
|
||||
self._prune_old_histories()
|
||||
|
||||
# Trim this message history to keep a maximum number of *user* messages
|
||||
max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
|
||||
self._trim_history(message_history, max_messages)
|
||||
|
||||
# Add new user message
|
||||
message_history.messages.append(
|
||||
ollama.Message(role=MessageRole.USER.value, content=user_input.text)
|
||||
)
|
||||
|
||||
# Get response
|
||||
try:
|
||||
response = await client.chat(
|
||||
model=model,
|
||||
# Make a copy of the messages because we mutate the list later
|
||||
messages=list(message_history.messages),
|
||||
stream=False,
|
||||
keep_alive=KEEP_ALIVE_FOREVER,
|
||||
)
|
||||
except (ollama.RequestError, ollama.ResponseError) as err:
|
||||
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Sorry, I had a problem talking to the Ollama server: {err}",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
response_message = response["message"]
|
||||
message_history.messages.append(
|
||||
ollama.Message(
|
||||
role=response_message["role"], content=response_message["content"]
|
||||
)
|
||||
)
|
||||
|
||||
# Create intent response
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_speech(response_message["content"])
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
def _prune_old_histories(self) -> None:
|
||||
"""Remove old message histories."""
|
||||
now = time.monotonic()
|
||||
self._history = {
|
||||
conversation_id: message_history
|
||||
for conversation_id, message_history in self._history.items()
|
||||
if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS
|
||||
}
|
||||
|
||||
def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
|
||||
"""Trims excess messages from a single history."""
|
||||
if max_messages < 1:
|
||||
# Keep all messages
|
||||
return
|
||||
|
||||
if message_history.num_user_messages >= max_messages:
|
||||
# Trim history but keep system prompt (first message).
|
||||
# Every other message should be an assistant message, so keep 2x
|
||||
# message objects.
|
||||
num_keep = 2 * max_messages
|
||||
drop_index = len(message_history.messages) - num_keep
|
||||
message_history.messages = [
|
||||
message_history.messages[0]
|
||||
] + message_history.messages[drop_index:]
|
||||
|
||||
def _generate_prompt(self, raw_prompt: str) -> str:
|
||||
"""Generate a prompt for the user."""
|
||||
return template.Template(raw_prompt, self.hass).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
"ha_language": self.hass.config.language,
|
||||
"exposed_entities": self._get_exposed_entities(),
|
||||
},
|
||||
parse_result=False,
|
||||
)
|
||||
|
||||
def _get_exposed_entities(self) -> list[ExposedEntity]:
|
||||
"""Get state list of exposed entities."""
|
||||
area_registry = ar.async_get(self.hass)
|
||||
entity_registry = er.async_get(self.hass)
|
||||
device_registry = dr.async_get(self.hass)
|
||||
|
||||
exposed_entities = []
|
||||
exposed_states = [
|
||||
state
|
||||
for state in self.hass.states.async_all()
|
||||
if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
|
||||
]
|
||||
|
||||
for state in exposed_states:
|
||||
entity = entity_registry.async_get(state.entity_id)
|
||||
names = [state.name]
|
||||
area_names = []
|
||||
|
||||
if entity is not None:
|
||||
# Add aliases
|
||||
names.extend(entity.aliases)
|
||||
if entity.area_id and (
|
||||
area := area_registry.async_get_area(entity.area_id)
|
||||
):
|
||||
# Entity is in area
|
||||
area_names.append(area.name)
|
||||
area_names.extend(area.aliases)
|
||||
elif entity.device_id and (
|
||||
device := device_registry.async_get(entity.device_id)
|
||||
):
|
||||
# Check device area
|
||||
if device.area_id and (
|
||||
area := area_registry.async_get_area(device.area_id)
|
||||
):
|
||||
area_names.append(area.name)
|
||||
area_names.extend(area.aliases)
|
||||
|
||||
exposed_entities.append(
|
||||
ExposedEntity(
|
||||
entity_id=state.entity_id,
|
||||
state=state,
|
||||
names=names,
|
||||
area_names=area_names,
|
||||
)
|
||||
)
|
||||
|
||||
return exposed_entities
|
||||
|
|
258
homeassistant/components/ollama/conversation.py
Normal file
258
homeassistant/components/ollama/conversation.py
Normal file
|
@ -0,0 +1,258 @@
|
|||
"""The conversation platform for the Ollama integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
import ollama
|
||||
|
||||
from homeassistant.components import assist_pipeline, conversation
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import TemplateError
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
intent,
|
||||
template,
|
||||
)
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.util import ulid
|
||||
|
||||
from .const import (
|
||||
CONF_MAX_HISTORY,
|
||||
CONF_MODEL,
|
||||
CONF_PROMPT,
|
||||
DEFAULT_MAX_HISTORY,
|
||||
DEFAULT_PROMPT,
|
||||
DOMAIN,
|
||||
KEEP_ALIVE_FOREVER,
|
||||
MAX_HISTORY_SECONDS,
|
||||
)
|
||||
from .models import ExposedEntity, MessageHistory, MessageRole
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up conversation entities."""
|
||||
agent = OllamaConversationEntity(hass, config_entry)
|
||||
async_add_entities([agent])
|
||||
|
||||
|
||||
class OllamaConversationEntity(
|
||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||
):
|
||||
"""Ollama conversation agent."""
|
||||
|
||||
_attr_has_entity_name = True
|
||||
|
||||
def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
|
||||
"""Initialize the agent."""
|
||||
self.hass = hass
|
||||
self.entry = entry
|
||||
|
||||
# conversation id -> message history
|
||||
self._history: dict[str, MessageHistory] = {}
|
||||
self._attr_name = entry.title
|
||||
self._attr_unique_id = entry.entry_id
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""When entity is added to Home Assistant."""
|
||||
await super().async_added_to_hass()
|
||||
assist_pipeline.async_migrate_engine(
|
||||
self.hass, "conversation", self.entry.entry_id, self.entity_id
|
||||
)
|
||||
conversation.async_set_agent(self.hass, self.entry, self)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""When entity will be removed from Home Assistant."""
|
||||
conversation.async_unset_agent(self.hass, self.entry)
|
||||
await super().async_will_remove_from_hass()
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str] | Literal["*"]:
|
||||
"""Return a list of supported languages."""
|
||||
return MATCH_ALL
|
||||
|
||||
async def async_process(
|
||||
self, user_input: conversation.ConversationInput
|
||||
) -> conversation.ConversationResult:
|
||||
"""Process a sentence."""
|
||||
settings = {**self.entry.data, **self.entry.options}
|
||||
|
||||
client = self.hass.data[DOMAIN][self.entry.entry_id]
|
||||
conversation_id = user_input.conversation_id or ulid.ulid_now()
|
||||
model = settings[CONF_MODEL]
|
||||
|
||||
# Look up message history
|
||||
message_history: MessageHistory | None = None
|
||||
message_history = self._history.get(conversation_id)
|
||||
if message_history is None:
|
||||
# New history
|
||||
#
|
||||
# Render prompt and error out early if there's a problem
|
||||
raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT)
|
||||
try:
|
||||
prompt = self._generate_prompt(raw_prompt)
|
||||
_LOGGER.debug("Prompt: %s", prompt)
|
||||
except TemplateError as err:
|
||||
_LOGGER.error("Error rendering prompt: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Sorry, I had a problem generating my prompt: {err}",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
message_history = MessageHistory(
|
||||
timestamp=time.monotonic(),
|
||||
messages=[
|
||||
ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
|
||||
],
|
||||
)
|
||||
self._history[conversation_id] = message_history
|
||||
else:
|
||||
# Bump timestamp so this conversation won't get cleaned up
|
||||
message_history.timestamp = time.monotonic()
|
||||
|
||||
# Clean up old histories
|
||||
self._prune_old_histories()
|
||||
|
||||
# Trim this message history to keep a maximum number of *user* messages
|
||||
max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
|
||||
self._trim_history(message_history, max_messages)
|
||||
|
||||
# Add new user message
|
||||
message_history.messages.append(
|
||||
ollama.Message(role=MessageRole.USER.value, content=user_input.text)
|
||||
)
|
||||
|
||||
# Get response
|
||||
try:
|
||||
response = await client.chat(
|
||||
model=model,
|
||||
# Make a copy of the messages because we mutate the list later
|
||||
messages=list(message_history.messages),
|
||||
stream=False,
|
||||
keep_alive=KEEP_ALIVE_FOREVER,
|
||||
)
|
||||
except (ollama.RequestError, ollama.ResponseError) as err:
|
||||
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Sorry, I had a problem talking to the Ollama server: {err}",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
response_message = response["message"]
|
||||
message_history.messages.append(
|
||||
ollama.Message(
|
||||
role=response_message["role"], content=response_message["content"]
|
||||
)
|
||||
)
|
||||
|
||||
# Create intent response
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_speech(response_message["content"])
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
def _prune_old_histories(self) -> None:
|
||||
"""Remove old message histories."""
|
||||
now = time.monotonic()
|
||||
self._history = {
|
||||
conversation_id: message_history
|
||||
for conversation_id, message_history in self._history.items()
|
||||
if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS
|
||||
}
|
||||
|
||||
def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
|
||||
"""Trims excess messages from a single history."""
|
||||
if max_messages < 1:
|
||||
# Keep all messages
|
||||
return
|
||||
|
||||
if message_history.num_user_messages >= max_messages:
|
||||
# Trim history but keep system prompt (first message).
|
||||
# Every other message should be an assistant message, so keep 2x
|
||||
# message objects.
|
||||
num_keep = 2 * max_messages
|
||||
drop_index = len(message_history.messages) - num_keep
|
||||
message_history.messages = [
|
||||
message_history.messages[0]
|
||||
] + message_history.messages[drop_index:]
|
||||
|
||||
def _generate_prompt(self, raw_prompt: str) -> str:
|
||||
"""Generate a prompt for the user."""
|
||||
return template.Template(raw_prompt, self.hass).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
"ha_language": self.hass.config.language,
|
||||
"exposed_entities": self._get_exposed_entities(),
|
||||
},
|
||||
parse_result=False,
|
||||
)
|
||||
|
||||
def _get_exposed_entities(self) -> list[ExposedEntity]:
|
||||
"""Get state list of exposed entities."""
|
||||
area_registry = ar.async_get(self.hass)
|
||||
entity_registry = er.async_get(self.hass)
|
||||
device_registry = dr.async_get(self.hass)
|
||||
|
||||
exposed_entities = []
|
||||
exposed_states = [
|
||||
state
|
||||
for state in self.hass.states.async_all()
|
||||
if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
|
||||
]
|
||||
|
||||
for state in exposed_states:
|
||||
entity = entity_registry.async_get(state.entity_id)
|
||||
names = [state.name]
|
||||
area_names = []
|
||||
|
||||
if entity is not None:
|
||||
# Add aliases
|
||||
names.extend(entity.aliases)
|
||||
if entity.area_id and (
|
||||
area := area_registry.async_get_area(entity.area_id)
|
||||
):
|
||||
# Entity is in area
|
||||
area_names.append(area.name)
|
||||
area_names.extend(area.aliases)
|
||||
elif entity.device_id and (
|
||||
device := device_registry.async_get(entity.device_id)
|
||||
):
|
||||
# Check device area
|
||||
if device.area_id and (
|
||||
area := area_registry.async_get_area(device.area_id)
|
||||
):
|
||||
area_names.append(area.name)
|
||||
area_names.extend(area.aliases)
|
||||
|
||||
exposed_entities.append(
|
||||
ExposedEntity(
|
||||
entity_id=state.entity_id,
|
||||
state=state,
|
||||
names=names,
|
||||
area_names=area_names,
|
||||
)
|
||||
)
|
||||
|
||||
return exposed_entities
|
|
@ -1,6 +1,7 @@
|
|||
{
|
||||
"domain": "ollama",
|
||||
"name": "Ollama",
|
||||
"after_dependencies": ["assist_pipeline"],
|
||||
"codeowners": ["@synesthesiam"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["conversation"],
|
||||
|
|
347
tests/components/ollama/test_conversation.py
Normal file
347
tests/components/ollama/test_conversation.py
Normal file
|
@ -0,0 +1,347 @@
|
|||
"""Tests for the Ollama integration."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from ollama import Message, ResponseError
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import conversation, ollama
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
intent,
|
||||
)
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
|
||||
async def test_chat(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
area_registry: ar.AreaRegistry,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
agent_id: str,
|
||||
) -> None:
|
||||
"""Test that the chat function is called with the appropriate arguments."""
|
||||
|
||||
if agent_id is None:
|
||||
agent_id = mock_config_entry.entry_id
|
||||
|
||||
# Create some areas, devices, and entities
|
||||
area_kitchen = area_registry.async_get_or_create("kitchen_id")
|
||||
area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen")
|
||||
area_bedroom = area_registry.async_get_or_create("bedroom_id")
|
||||
area_bedroom = area_registry.async_update(area_bedroom.id, name="bedroom")
|
||||
area_office = area_registry.async_get_or_create("office_id")
|
||||
area_office = area_registry.async_update(area_office.id, name="office")
|
||||
|
||||
entry = MockConfigEntry()
|
||||
entry.add_to_hass(hass)
|
||||
kitchen_device = device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections=set(),
|
||||
identifiers={("demo", "id-1234")},
|
||||
)
|
||||
device_registry.async_update_device(kitchen_device.id, area_id=area_kitchen.id)
|
||||
|
||||
kitchen_light = entity_registry.async_get_or_create("light", "demo", "1234")
|
||||
kitchen_light = entity_registry.async_update_entity(
|
||||
kitchen_light.entity_id, device_id=kitchen_device.id
|
||||
)
|
||||
hass.states.async_set(
|
||||
kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
|
||||
)
|
||||
|
||||
bedroom_light = entity_registry.async_get_or_create("light", "demo", "5678")
|
||||
bedroom_light = entity_registry.async_update_entity(
|
||||
bedroom_light.entity_id, area_id=area_bedroom.id
|
||||
)
|
||||
hass.states.async_set(
|
||||
bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"}
|
||||
)
|
||||
|
||||
# Hide the office light
|
||||
office_light = entity_registry.async_get_or_create("light", "demo", "ABCD")
|
||||
office_light = entity_registry.async_update_entity(
|
||||
office_light.entity_id, area_id=area_office.id
|
||||
)
|
||||
hass.states.async_set(
|
||||
office_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "office light"}
|
||||
)
|
||||
async_expose_entity(hass, conversation.DOMAIN, office_light.entity_id, False)
|
||||
|
||||
with patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
return_value={"message": {"role": "assistant", "content": "test response"}},
|
||||
) as mock_chat:
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"test message",
|
||||
None,
|
||||
Context(),
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
assert mock_chat.call_count == 1
|
||||
args = mock_chat.call_args.kwargs
|
||||
prompt = args["messages"][0]["content"]
|
||||
|
||||
assert args["model"] == "test model"
|
||||
assert args["messages"] == [
|
||||
Message({"role": "system", "content": prompt}),
|
||||
Message({"role": "user", "content": "test message"}),
|
||||
]
|
||||
|
||||
# Verify only exposed devices/areas are in prompt
|
||||
assert "kitchen light" in prompt
|
||||
assert "bedroom light" in prompt
|
||||
assert "office light" not in prompt
|
||||
assert "office" not in prompt
|
||||
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), result
|
||||
assert result.response.speech["plain"]["speech"] == "test response"
|
||||
|
||||
|
||||
async def test_message_history_trimming(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
) -> None:
|
||||
"""Test that a single message history is trimmed according to the config."""
|
||||
response_idx = 0
|
||||
|
||||
def response(*args, **kwargs) -> dict:
|
||||
nonlocal response_idx
|
||||
response_idx += 1
|
||||
return {"message": {"role": "assistant", "content": f"response {response_idx}"}}
|
||||
|
||||
with patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
side_effect=response,
|
||||
) as mock_chat:
|
||||
# mock_init_component sets "max_history" to 2
|
||||
for i in range(5):
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
f"message {i+1}",
|
||||
conversation_id="1234",
|
||||
context=Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
)
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), result
|
||||
|
||||
assert mock_chat.call_count == 5
|
||||
args = mock_chat.call_args_list
|
||||
prompt = args[0].kwargs["messages"][0]["content"]
|
||||
|
||||
# system + user-1
|
||||
assert len(args[0].kwargs["messages"]) == 2
|
||||
assert args[0].kwargs["messages"][1]["content"] == "message 1"
|
||||
|
||||
# Full history
|
||||
# system + user-1 + assistant-1 + user-2
|
||||
assert len(args[1].kwargs["messages"]) == 4
|
||||
assert args[1].kwargs["messages"][0]["role"] == "system"
|
||||
assert args[1].kwargs["messages"][0]["content"] == prompt
|
||||
assert args[1].kwargs["messages"][1]["role"] == "user"
|
||||
assert args[1].kwargs["messages"][1]["content"] == "message 1"
|
||||
assert args[1].kwargs["messages"][2]["role"] == "assistant"
|
||||
assert args[1].kwargs["messages"][2]["content"] == "response 1"
|
||||
assert args[1].kwargs["messages"][3]["role"] == "user"
|
||||
assert args[1].kwargs["messages"][3]["content"] == "message 2"
|
||||
|
||||
# Full history
|
||||
# system + user-1 + assistant-1 + user-2 + assistant-2 + user-3
|
||||
assert len(args[2].kwargs["messages"]) == 6
|
||||
assert args[2].kwargs["messages"][0]["role"] == "system"
|
||||
assert args[2].kwargs["messages"][0]["content"] == prompt
|
||||
assert args[2].kwargs["messages"][1]["role"] == "user"
|
||||
assert args[2].kwargs["messages"][1]["content"] == "message 1"
|
||||
assert args[2].kwargs["messages"][2]["role"] == "assistant"
|
||||
assert args[2].kwargs["messages"][2]["content"] == "response 1"
|
||||
assert args[2].kwargs["messages"][3]["role"] == "user"
|
||||
assert args[2].kwargs["messages"][3]["content"] == "message 2"
|
||||
assert args[2].kwargs["messages"][4]["role"] == "assistant"
|
||||
assert args[2].kwargs["messages"][4]["content"] == "response 2"
|
||||
assert args[2].kwargs["messages"][5]["role"] == "user"
|
||||
assert args[2].kwargs["messages"][5]["content"] == "message 3"
|
||||
|
||||
# Trimmed down to two user messages.
|
||||
# system + user-2 + assistant-2 + user-3 + assistant-3 + user-4
|
||||
assert len(args[3].kwargs["messages"]) == 6
|
||||
assert args[3].kwargs["messages"][0]["role"] == "system"
|
||||
assert args[3].kwargs["messages"][0]["content"] == prompt
|
||||
assert args[3].kwargs["messages"][1]["role"] == "user"
|
||||
assert args[3].kwargs["messages"][1]["content"] == "message 2"
|
||||
assert args[3].kwargs["messages"][2]["role"] == "assistant"
|
||||
assert args[3].kwargs["messages"][2]["content"] == "response 2"
|
||||
assert args[3].kwargs["messages"][3]["role"] == "user"
|
||||
assert args[3].kwargs["messages"][3]["content"] == "message 3"
|
||||
assert args[3].kwargs["messages"][4]["role"] == "assistant"
|
||||
assert args[3].kwargs["messages"][4]["content"] == "response 3"
|
||||
assert args[3].kwargs["messages"][5]["role"] == "user"
|
||||
assert args[3].kwargs["messages"][5]["content"] == "message 4"
|
||||
|
||||
# Trimmed down to two user messages.
|
||||
# system + user-3 + assistant-3 + user-4 + assistant-4 + user-5
|
||||
assert len(args[3].kwargs["messages"]) == 6
|
||||
assert args[4].kwargs["messages"][0]["role"] == "system"
|
||||
assert args[4].kwargs["messages"][0]["content"] == prompt
|
||||
assert args[4].kwargs["messages"][1]["role"] == "user"
|
||||
assert args[4].kwargs["messages"][1]["content"] == "message 3"
|
||||
assert args[4].kwargs["messages"][2]["role"] == "assistant"
|
||||
assert args[4].kwargs["messages"][2]["content"] == "response 3"
|
||||
assert args[4].kwargs["messages"][3]["role"] == "user"
|
||||
assert args[4].kwargs["messages"][3]["content"] == "message 4"
|
||||
assert args[4].kwargs["messages"][4]["role"] == "assistant"
|
||||
assert args[4].kwargs["messages"][4]["content"] == "response 4"
|
||||
assert args[4].kwargs["messages"][5]["role"] == "user"
|
||||
assert args[4].kwargs["messages"][5]["content"] == "message 5"
|
||||
|
||||
|
||||
async def test_message_history_pruning(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
) -> None:
|
||||
"""Test that old message histories are pruned."""
|
||||
with patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
return_value={"message": {"role": "assistant", "content": "test response"}},
|
||||
):
|
||||
# Create 3 different message histories
|
||||
conversation_ids: list[str] = []
|
||||
for i in range(3):
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
f"message {i+1}",
|
||||
conversation_id=None,
|
||||
context=Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
)
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), result
|
||||
assert isinstance(result.conversation_id, str)
|
||||
conversation_ids.append(result.conversation_id)
|
||||
|
||||
agent = conversation.get_agent_manager(hass).async_get_agent(
|
||||
mock_config_entry.entry_id
|
||||
)
|
||||
assert len(agent._history) == 3
|
||||
assert agent._history.keys() == set(conversation_ids)
|
||||
|
||||
# Modify the timestamps of the first 2 histories so they will be pruned
|
||||
# on the next cycle.
|
||||
for conversation_id in conversation_ids[:2]:
|
||||
# Move back 2 hours
|
||||
agent._history[conversation_id].timestamp -= 2 * 60 * 60
|
||||
|
||||
# Next cycle
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"test message",
|
||||
conversation_id=None,
|
||||
context=Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
)
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), result
|
||||
|
||||
# Only the most recent histories should remain
|
||||
assert len(agent._history) == 2
|
||||
assert conversation_ids[-1] in agent._history
|
||||
assert result.conversation_id in agent._history
|
||||
|
||||
|
||||
async def test_message_history_unlimited(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
) -> None:
|
||||
"""Test that message history is not trimmed when max_history = 0."""
|
||||
conversation_id = "1234"
|
||||
with (
|
||||
patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
return_value={"message": {"role": "assistant", "content": "test response"}},
|
||||
),
|
||||
patch.object(mock_config_entry, "options", {ollama.CONF_MAX_HISTORY: 0}),
|
||||
):
|
||||
for i in range(100):
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
f"message {i+1}",
|
||||
conversation_id=conversation_id,
|
||||
context=Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
)
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), result
|
||||
|
||||
agent = conversation.get_agent_manager(hass).async_get_agent(
|
||||
mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert len(agent._history) == 1
|
||||
assert conversation_id in agent._history
|
||||
assert agent._history[conversation_id].num_user_messages == 100
|
||||
|
||||
|
||||
async def test_error_handling(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
) -> None:
|
||||
"""Test error handling during converse."""
|
||||
with patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ResponseError("test error"),
|
||||
):
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
|
||||
|
||||
async def test_template_error(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> None:
|
||||
"""Test that template error handling works."""
|
||||
hass.config_entries.async_update_entry(
|
||||
mock_config_entry,
|
||||
options={
|
||||
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
|
||||
},
|
||||
)
|
||||
with patch(
|
||||
"ollama.AsyncClient.list",
|
||||
):
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
|
||||
|
||||
async def test_conversation_agent(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
) -> None:
|
||||
"""Test OllamaConversationEntity."""
|
||||
agent = conversation.get_agent_manager(hass).async_get_agent(
|
||||
mock_config_entry.entry_id
|
||||
)
|
||||
assert agent.supported_languages == MATCH_ALL
|
|
@ -1,351 +1,17 @@
|
|||
"""Tests for the Ollama integration."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
from httpx import ConnectError
|
||||
from ollama import Message, ResponseError
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import conversation, ollama
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
intent,
|
||||
)
|
||||
from homeassistant.components import ollama
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
async def test_chat(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
area_registry: ar.AreaRegistry,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test that the chat function is called with the appropriate arguments."""
|
||||
|
||||
# Create some areas, devices, and entities
|
||||
area_kitchen = area_registry.async_get_or_create("kitchen_id")
|
||||
area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen")
|
||||
area_bedroom = area_registry.async_get_or_create("bedroom_id")
|
||||
area_bedroom = area_registry.async_update(area_bedroom.id, name="bedroom")
|
||||
area_office = area_registry.async_get_or_create("office_id")
|
||||
area_office = area_registry.async_update(area_office.id, name="office")
|
||||
|
||||
entry = MockConfigEntry()
|
||||
entry.add_to_hass(hass)
|
||||
kitchen_device = device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections=set(),
|
||||
identifiers={("demo", "id-1234")},
|
||||
)
|
||||
device_registry.async_update_device(kitchen_device.id, area_id=area_kitchen.id)
|
||||
|
||||
kitchen_light = entity_registry.async_get_or_create("light", "demo", "1234")
|
||||
kitchen_light = entity_registry.async_update_entity(
|
||||
kitchen_light.entity_id, device_id=kitchen_device.id
|
||||
)
|
||||
hass.states.async_set(
|
||||
kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
|
||||
)
|
||||
|
||||
bedroom_light = entity_registry.async_get_or_create("light", "demo", "5678")
|
||||
bedroom_light = entity_registry.async_update_entity(
|
||||
bedroom_light.entity_id, area_id=area_bedroom.id
|
||||
)
|
||||
hass.states.async_set(
|
||||
bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"}
|
||||
)
|
||||
|
||||
# Hide the office light
|
||||
office_light = entity_registry.async_get_or_create("light", "demo", "ABCD")
|
||||
office_light = entity_registry.async_update_entity(
|
||||
office_light.entity_id, area_id=area_office.id
|
||||
)
|
||||
hass.states.async_set(
|
||||
office_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "office light"}
|
||||
)
|
||||
async_expose_entity(hass, conversation.DOMAIN, office_light.entity_id, False)
|
||||
|
||||
with patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
return_value={"message": {"role": "assistant", "content": "test response"}},
|
||||
) as mock_chat:
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"test message",
|
||||
None,
|
||||
Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
)
|
||||
|
||||
assert mock_chat.call_count == 1
|
||||
args = mock_chat.call_args.kwargs
|
||||
prompt = args["messages"][0]["content"]
|
||||
|
||||
assert args["model"] == "test model"
|
||||
assert args["messages"] == [
|
||||
Message({"role": "system", "content": prompt}),
|
||||
Message({"role": "user", "content": "test message"}),
|
||||
]
|
||||
|
||||
# Verify only exposed devices/areas are in prompt
|
||||
assert "kitchen light" in prompt
|
||||
assert "bedroom light" in prompt
|
||||
assert "office light" not in prompt
|
||||
assert "office" not in prompt
|
||||
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), result
|
||||
assert result.response.speech["plain"]["speech"] == "test response"
|
||||
|
||||
|
||||
async def test_message_history_trimming(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
) -> None:
|
||||
"""Test that a single message history is trimmed according to the config."""
|
||||
response_idx = 0
|
||||
|
||||
def response(*args, **kwargs) -> dict:
|
||||
nonlocal response_idx
|
||||
response_idx += 1
|
||||
return {"message": {"role": "assistant", "content": f"response {response_idx}"}}
|
||||
|
||||
with patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
side_effect=response,
|
||||
) as mock_chat:
|
||||
# mock_init_component sets "max_history" to 2
|
||||
for i in range(5):
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
f"message {i+1}",
|
||||
conversation_id="1234",
|
||||
context=Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
)
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), result
|
||||
|
||||
assert mock_chat.call_count == 5
|
||||
args = mock_chat.call_args_list
|
||||
prompt = args[0].kwargs["messages"][0]["content"]
|
||||
|
||||
# system + user-1
|
||||
assert len(args[0].kwargs["messages"]) == 2
|
||||
assert args[0].kwargs["messages"][1]["content"] == "message 1"
|
||||
|
||||
# Full history
|
||||
# system + user-1 + assistant-1 + user-2
|
||||
assert len(args[1].kwargs["messages"]) == 4
|
||||
assert args[1].kwargs["messages"][0]["role"] == "system"
|
||||
assert args[1].kwargs["messages"][0]["content"] == prompt
|
||||
assert args[1].kwargs["messages"][1]["role"] == "user"
|
||||
assert args[1].kwargs["messages"][1]["content"] == "message 1"
|
||||
assert args[1].kwargs["messages"][2]["role"] == "assistant"
|
||||
assert args[1].kwargs["messages"][2]["content"] == "response 1"
|
||||
assert args[1].kwargs["messages"][3]["role"] == "user"
|
||||
assert args[1].kwargs["messages"][3]["content"] == "message 2"
|
||||
|
||||
# Full history
|
||||
# system + user-1 + assistant-1 + user-2 + assistant-2 + user-3
|
||||
assert len(args[2].kwargs["messages"]) == 6
|
||||
assert args[2].kwargs["messages"][0]["role"] == "system"
|
||||
assert args[2].kwargs["messages"][0]["content"] == prompt
|
||||
assert args[2].kwargs["messages"][1]["role"] == "user"
|
||||
assert args[2].kwargs["messages"][1]["content"] == "message 1"
|
||||
assert args[2].kwargs["messages"][2]["role"] == "assistant"
|
||||
assert args[2].kwargs["messages"][2]["content"] == "response 1"
|
||||
assert args[2].kwargs["messages"][3]["role"] == "user"
|
||||
assert args[2].kwargs["messages"][3]["content"] == "message 2"
|
||||
assert args[2].kwargs["messages"][4]["role"] == "assistant"
|
||||
assert args[2].kwargs["messages"][4]["content"] == "response 2"
|
||||
assert args[2].kwargs["messages"][5]["role"] == "user"
|
||||
assert args[2].kwargs["messages"][5]["content"] == "message 3"
|
||||
|
||||
# Trimmed down to two user messages.
|
||||
# system + user-2 + assistant-2 + user-3 + assistant-3 + user-4
|
||||
assert len(args[3].kwargs["messages"]) == 6
|
||||
assert args[3].kwargs["messages"][0]["role"] == "system"
|
||||
assert args[3].kwargs["messages"][0]["content"] == prompt
|
||||
assert args[3].kwargs["messages"][1]["role"] == "user"
|
||||
assert args[3].kwargs["messages"][1]["content"] == "message 2"
|
||||
assert args[3].kwargs["messages"][2]["role"] == "assistant"
|
||||
assert args[3].kwargs["messages"][2]["content"] == "response 2"
|
||||
assert args[3].kwargs["messages"][3]["role"] == "user"
|
||||
assert args[3].kwargs["messages"][3]["content"] == "message 3"
|
||||
assert args[3].kwargs["messages"][4]["role"] == "assistant"
|
||||
assert args[3].kwargs["messages"][4]["content"] == "response 3"
|
||||
assert args[3].kwargs["messages"][5]["role"] == "user"
|
||||
assert args[3].kwargs["messages"][5]["content"] == "message 4"
|
||||
|
||||
# Trimmed down to two user messages.
|
||||
# system + user-3 + assistant-3 + user-4 + assistant-4 + user-5
|
||||
assert len(args[3].kwargs["messages"]) == 6
|
||||
assert args[4].kwargs["messages"][0]["role"] == "system"
|
||||
assert args[4].kwargs["messages"][0]["content"] == prompt
|
||||
assert args[4].kwargs["messages"][1]["role"] == "user"
|
||||
assert args[4].kwargs["messages"][1]["content"] == "message 3"
|
||||
assert args[4].kwargs["messages"][2]["role"] == "assistant"
|
||||
assert args[4].kwargs["messages"][2]["content"] == "response 3"
|
||||
assert args[4].kwargs["messages"][3]["role"] == "user"
|
||||
assert args[4].kwargs["messages"][3]["content"] == "message 4"
|
||||
assert args[4].kwargs["messages"][4]["role"] == "assistant"
|
||||
assert args[4].kwargs["messages"][4]["content"] == "response 4"
|
||||
assert args[4].kwargs["messages"][5]["role"] == "user"
|
||||
assert args[4].kwargs["messages"][5]["content"] == "message 5"
|
||||
|
||||
|
||||
async def test_message_history_pruning(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
) -> None:
|
||||
"""Test that old message histories are pruned."""
|
||||
with patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
return_value={"message": {"role": "assistant", "content": "test response"}},
|
||||
):
|
||||
# Create 3 different message histories
|
||||
conversation_ids: list[str] = []
|
||||
for i in range(3):
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
f"message {i+1}",
|
||||
conversation_id=None,
|
||||
context=Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
)
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), result
|
||||
assert isinstance(result.conversation_id, str)
|
||||
conversation_ids.append(result.conversation_id)
|
||||
|
||||
agent = conversation.get_agent_manager(hass).async_get_agent(
|
||||
mock_config_entry.entry_id
|
||||
)
|
||||
assert isinstance(agent, ollama.OllamaAgent)
|
||||
assert len(agent._history) == 3
|
||||
assert agent._history.keys() == set(conversation_ids)
|
||||
|
||||
# Modify the timestamps of the first 2 histories so they will be pruned
|
||||
# on the next cycle.
|
||||
for conversation_id in conversation_ids[:2]:
|
||||
# Move back 2 hours
|
||||
agent._history[conversation_id].timestamp -= 2 * 60 * 60
|
||||
|
||||
# Next cycle
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"test message",
|
||||
conversation_id=None,
|
||||
context=Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
)
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), result
|
||||
|
||||
# Only the most recent histories should remain
|
||||
assert len(agent._history) == 2
|
||||
assert conversation_ids[-1] in agent._history
|
||||
assert result.conversation_id in agent._history
|
||||
|
||||
|
||||
async def test_message_history_unlimited(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
) -> None:
|
||||
"""Test that message history is not trimmed when max_history = 0."""
|
||||
conversation_id = "1234"
|
||||
with (
|
||||
patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
return_value={"message": {"role": "assistant", "content": "test response"}},
|
||||
),
|
||||
patch.object(mock_config_entry, "options", {ollama.CONF_MAX_HISTORY: 0}),
|
||||
):
|
||||
for i in range(100):
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
f"message {i+1}",
|
||||
conversation_id=conversation_id,
|
||||
context=Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
)
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), result
|
||||
|
||||
agent = conversation.get_agent_manager(hass).async_get_agent(
|
||||
mock_config_entry.entry_id
|
||||
)
|
||||
assert isinstance(agent, ollama.OllamaAgent)
|
||||
|
||||
assert len(agent._history) == 1
|
||||
assert conversation_id in agent._history
|
||||
assert agent._history[conversation_id].num_user_messages == 100
|
||||
|
||||
|
||||
async def test_error_handling(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
) -> None:
|
||||
"""Test error handling during converse."""
|
||||
with patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ResponseError("test error"),
|
||||
):
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
|
||||
|
||||
async def test_template_error(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> None:
|
||||
"""Test that template error handling works."""
|
||||
hass.config_entries.async_update_entry(
|
||||
mock_config_entry,
|
||||
options={
|
||||
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
|
||||
},
|
||||
)
|
||||
with patch(
|
||||
"ollama.AsyncClient.list",
|
||||
):
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
|
||||
|
||||
async def test_conversation_agent(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
) -> None:
|
||||
"""Test OllamaAgent."""
|
||||
agent = conversation.get_agent_manager(hass).async_get_agent(
|
||||
mock_config_entry.entry_id
|
||||
)
|
||||
assert agent.supported_languages == MATCH_ALL
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("side_effect", "error"),
|
||||
[
|
||||
|
|
Loading…
Add table
Reference in a new issue