Add Ollama Conversation Agent Entity (#116363)
* Add ConversationEntity to OLlama integration * Add assist_pipeline dependencies
This commit is contained in:
parent
eced3b0f57
commit
f1dda8ef63
5 changed files with 617 additions and 557 deletions
|
@ -4,40 +4,17 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import ollama
|
import ollama
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
|
||||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import CONF_URL, MATCH_ALL
|
from homeassistant.const import CONF_URL, Platform
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
|
from homeassistant.exceptions import ConfigEntryNotReady
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import config_validation as cv
|
||||||
area_registry as ar,
|
|
||||||
config_validation as cv,
|
|
||||||
device_registry as dr,
|
|
||||||
entity_registry as er,
|
|
||||||
intent,
|
|
||||||
template,
|
|
||||||
)
|
|
||||||
from homeassistant.util import ulid
|
|
||||||
|
|
||||||
from .const import (
|
from .const import CONF_MAX_HISTORY, CONF_MODEL, CONF_PROMPT, DEFAULT_TIMEOUT, DOMAIN
|
||||||
CONF_MAX_HISTORY,
|
|
||||||
CONF_MODEL,
|
|
||||||
CONF_PROMPT,
|
|
||||||
DEFAULT_MAX_HISTORY,
|
|
||||||
DEFAULT_PROMPT,
|
|
||||||
DEFAULT_TIMEOUT,
|
|
||||||
DOMAIN,
|
|
||||||
KEEP_ALIVE_FOREVER,
|
|
||||||
MAX_HISTORY_SECONDS,
|
|
||||||
)
|
|
||||||
from .models import ExposedEntity, MessageHistory, MessageRole
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -46,11 +23,11 @@ __all__ = [
|
||||||
"CONF_PROMPT",
|
"CONF_PROMPT",
|
||||||
"CONF_MODEL",
|
"CONF_MODEL",
|
||||||
"CONF_MAX_HISTORY",
|
"CONF_MAX_HISTORY",
|
||||||
"MAX_HISTORY_NO_LIMIT",
|
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
]
|
]
|
||||||
|
|
||||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||||
|
PLATFORMS = (Platform.CONVERSATION,)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
|
@ -65,202 +42,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
|
|
||||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client
|
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client
|
||||||
|
|
||||||
conversation.async_set_agent(hass, entry, OllamaAgent(hass, entry))
|
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
"""Unload Ollama."""
|
"""Unload Ollama."""
|
||||||
|
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
|
||||||
|
return False
|
||||||
hass.data[DOMAIN].pop(entry.entry_id)
|
hass.data[DOMAIN].pop(entry.entry_id)
|
||||||
conversation.async_unset_agent(hass, entry)
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class OllamaAgent(conversation.AbstractConversationAgent):
|
|
||||||
"""Ollama conversation agent."""
|
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
|
|
||||||
"""Initialize the agent."""
|
|
||||||
self.hass = hass
|
|
||||||
self.entry = entry
|
|
||||||
|
|
||||||
# conversation id -> message history
|
|
||||||
self._history: dict[str, MessageHistory] = {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_languages(self) -> list[str] | Literal["*"]:
|
|
||||||
"""Return a list of supported languages."""
|
|
||||||
return MATCH_ALL
|
|
||||||
|
|
||||||
async def async_process(
|
|
||||||
self, user_input: conversation.ConversationInput
|
|
||||||
) -> conversation.ConversationResult:
|
|
||||||
"""Process a sentence."""
|
|
||||||
settings = {**self.entry.data, **self.entry.options}
|
|
||||||
|
|
||||||
client = self.hass.data[DOMAIN][self.entry.entry_id]
|
|
||||||
conversation_id = user_input.conversation_id or ulid.ulid_now()
|
|
||||||
model = settings[CONF_MODEL]
|
|
||||||
|
|
||||||
# Look up message history
|
|
||||||
message_history: MessageHistory | None = None
|
|
||||||
message_history = self._history.get(conversation_id)
|
|
||||||
if message_history is None:
|
|
||||||
# New history
|
|
||||||
#
|
|
||||||
# Render prompt and error out early if there's a problem
|
|
||||||
raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT)
|
|
||||||
try:
|
|
||||||
prompt = self._generate_prompt(raw_prompt)
|
|
||||||
_LOGGER.debug("Prompt: %s", prompt)
|
|
||||||
except TemplateError as err:
|
|
||||||
_LOGGER.error("Error rendering prompt: %s", err)
|
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
|
||||||
intent_response.async_set_error(
|
|
||||||
intent.IntentResponseErrorCode.UNKNOWN,
|
|
||||||
f"Sorry, I had a problem generating my prompt: {err}",
|
|
||||||
)
|
|
||||||
return conversation.ConversationResult(
|
|
||||||
response=intent_response, conversation_id=conversation_id
|
|
||||||
)
|
|
||||||
|
|
||||||
message_history = MessageHistory(
|
|
||||||
timestamp=time.monotonic(),
|
|
||||||
messages=[
|
|
||||||
ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
self._history[conversation_id] = message_history
|
|
||||||
else:
|
|
||||||
# Bump timestamp so this conversation won't get cleaned up
|
|
||||||
message_history.timestamp = time.monotonic()
|
|
||||||
|
|
||||||
# Clean up old histories
|
|
||||||
self._prune_old_histories()
|
|
||||||
|
|
||||||
# Trim this message history to keep a maximum number of *user* messages
|
|
||||||
max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
|
|
||||||
self._trim_history(message_history, max_messages)
|
|
||||||
|
|
||||||
# Add new user message
|
|
||||||
message_history.messages.append(
|
|
||||||
ollama.Message(role=MessageRole.USER.value, content=user_input.text)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get response
|
|
||||||
try:
|
|
||||||
response = await client.chat(
|
|
||||||
model=model,
|
|
||||||
# Make a copy of the messages because we mutate the list later
|
|
||||||
messages=list(message_history.messages),
|
|
||||||
stream=False,
|
|
||||||
keep_alive=KEEP_ALIVE_FOREVER,
|
|
||||||
)
|
|
||||||
except (ollama.RequestError, ollama.ResponseError) as err:
|
|
||||||
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
|
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
|
||||||
intent_response.async_set_error(
|
|
||||||
intent.IntentResponseErrorCode.UNKNOWN,
|
|
||||||
f"Sorry, I had a problem talking to the Ollama server: {err}",
|
|
||||||
)
|
|
||||||
return conversation.ConversationResult(
|
|
||||||
response=intent_response, conversation_id=conversation_id
|
|
||||||
)
|
|
||||||
|
|
||||||
response_message = response["message"]
|
|
||||||
message_history.messages.append(
|
|
||||||
ollama.Message(
|
|
||||||
role=response_message["role"], content=response_message["content"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create intent response
|
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
|
||||||
intent_response.async_set_speech(response_message["content"])
|
|
||||||
return conversation.ConversationResult(
|
|
||||||
response=intent_response, conversation_id=conversation_id
|
|
||||||
)
|
|
||||||
|
|
||||||
def _prune_old_histories(self) -> None:
|
|
||||||
"""Remove old message histories."""
|
|
||||||
now = time.monotonic()
|
|
||||||
self._history = {
|
|
||||||
conversation_id: message_history
|
|
||||||
for conversation_id, message_history in self._history.items()
|
|
||||||
if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS
|
|
||||||
}
|
|
||||||
|
|
||||||
def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
|
|
||||||
"""Trims excess messages from a single history."""
|
|
||||||
if max_messages < 1:
|
|
||||||
# Keep all messages
|
|
||||||
return
|
|
||||||
|
|
||||||
if message_history.num_user_messages >= max_messages:
|
|
||||||
# Trim history but keep system prompt (first message).
|
|
||||||
# Every other message should be an assistant message, so keep 2x
|
|
||||||
# message objects.
|
|
||||||
num_keep = 2 * max_messages
|
|
||||||
drop_index = len(message_history.messages) - num_keep
|
|
||||||
message_history.messages = [
|
|
||||||
message_history.messages[0]
|
|
||||||
] + message_history.messages[drop_index:]
|
|
||||||
|
|
||||||
def _generate_prompt(self, raw_prompt: str) -> str:
|
|
||||||
"""Generate a prompt for the user."""
|
|
||||||
return template.Template(raw_prompt, self.hass).async_render(
|
|
||||||
{
|
|
||||||
"ha_name": self.hass.config.location_name,
|
|
||||||
"ha_language": self.hass.config.language,
|
|
||||||
"exposed_entities": self._get_exposed_entities(),
|
|
||||||
},
|
|
||||||
parse_result=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_exposed_entities(self) -> list[ExposedEntity]:
|
|
||||||
"""Get state list of exposed entities."""
|
|
||||||
area_registry = ar.async_get(self.hass)
|
|
||||||
entity_registry = er.async_get(self.hass)
|
|
||||||
device_registry = dr.async_get(self.hass)
|
|
||||||
|
|
||||||
exposed_entities = []
|
|
||||||
exposed_states = [
|
|
||||||
state
|
|
||||||
for state in self.hass.states.async_all()
|
|
||||||
if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
|
|
||||||
]
|
|
||||||
|
|
||||||
for state in exposed_states:
|
|
||||||
entity = entity_registry.async_get(state.entity_id)
|
|
||||||
names = [state.name]
|
|
||||||
area_names = []
|
|
||||||
|
|
||||||
if entity is not None:
|
|
||||||
# Add aliases
|
|
||||||
names.extend(entity.aliases)
|
|
||||||
if entity.area_id and (
|
|
||||||
area := area_registry.async_get_area(entity.area_id)
|
|
||||||
):
|
|
||||||
# Entity is in area
|
|
||||||
area_names.append(area.name)
|
|
||||||
area_names.extend(area.aliases)
|
|
||||||
elif entity.device_id and (
|
|
||||||
device := device_registry.async_get(entity.device_id)
|
|
||||||
):
|
|
||||||
# Check device area
|
|
||||||
if device.area_id and (
|
|
||||||
area := area_registry.async_get_area(device.area_id)
|
|
||||||
):
|
|
||||||
area_names.append(area.name)
|
|
||||||
area_names.extend(area.aliases)
|
|
||||||
|
|
||||||
exposed_entities.append(
|
|
||||||
ExposedEntity(
|
|
||||||
entity_id=state.entity_id,
|
|
||||||
state=state,
|
|
||||||
names=names,
|
|
||||||
area_names=area_names,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return exposed_entities
|
|
||||||
|
|
258
homeassistant/components/ollama/conversation.py
Normal file
258
homeassistant/components/ollama/conversation.py
Normal file
|
@ -0,0 +1,258 @@
|
||||||
|
"""The conversation platform for the Ollama integration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import ollama
|
||||||
|
|
||||||
|
from homeassistant.components import assist_pipeline, conversation
|
||||||
|
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import MATCH_ALL
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.exceptions import TemplateError
|
||||||
|
from homeassistant.helpers import (
|
||||||
|
area_registry as ar,
|
||||||
|
device_registry as dr,
|
||||||
|
entity_registry as er,
|
||||||
|
intent,
|
||||||
|
template,
|
||||||
|
)
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
from homeassistant.util import ulid
|
||||||
|
|
||||||
|
from .const import (
|
||||||
|
CONF_MAX_HISTORY,
|
||||||
|
CONF_MODEL,
|
||||||
|
CONF_PROMPT,
|
||||||
|
DEFAULT_MAX_HISTORY,
|
||||||
|
DEFAULT_PROMPT,
|
||||||
|
DOMAIN,
|
||||||
|
KEEP_ALIVE_FOREVER,
|
||||||
|
MAX_HISTORY_SECONDS,
|
||||||
|
)
|
||||||
|
from .models import ExposedEntity, MessageHistory, MessageRole
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up conversation entities."""
|
||||||
|
agent = OllamaConversationEntity(hass, config_entry)
|
||||||
|
async_add_entities([agent])
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaConversationEntity(
|
||||||
|
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||||
|
):
|
||||||
|
"""Ollama conversation agent."""
|
||||||
|
|
||||||
|
_attr_has_entity_name = True
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
|
||||||
|
"""Initialize the agent."""
|
||||||
|
self.hass = hass
|
||||||
|
self.entry = entry
|
||||||
|
|
||||||
|
# conversation id -> message history
|
||||||
|
self._history: dict[str, MessageHistory] = {}
|
||||||
|
self._attr_name = entry.title
|
||||||
|
self._attr_unique_id = entry.entry_id
|
||||||
|
|
||||||
|
async def async_added_to_hass(self) -> None:
|
||||||
|
"""When entity is added to Home Assistant."""
|
||||||
|
await super().async_added_to_hass()
|
||||||
|
assist_pipeline.async_migrate_engine(
|
||||||
|
self.hass, "conversation", self.entry.entry_id, self.entity_id
|
||||||
|
)
|
||||||
|
conversation.async_set_agent(self.hass, self.entry, self)
|
||||||
|
|
||||||
|
async def async_will_remove_from_hass(self) -> None:
|
||||||
|
"""When entity will be removed from Home Assistant."""
|
||||||
|
conversation.async_unset_agent(self.hass, self.entry)
|
||||||
|
await super().async_will_remove_from_hass()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_languages(self) -> list[str] | Literal["*"]:
|
||||||
|
"""Return a list of supported languages."""
|
||||||
|
return MATCH_ALL
|
||||||
|
|
||||||
|
async def async_process(
|
||||||
|
self, user_input: conversation.ConversationInput
|
||||||
|
) -> conversation.ConversationResult:
|
||||||
|
"""Process a sentence."""
|
||||||
|
settings = {**self.entry.data, **self.entry.options}
|
||||||
|
|
||||||
|
client = self.hass.data[DOMAIN][self.entry.entry_id]
|
||||||
|
conversation_id = user_input.conversation_id or ulid.ulid_now()
|
||||||
|
model = settings[CONF_MODEL]
|
||||||
|
|
||||||
|
# Look up message history
|
||||||
|
message_history: MessageHistory | None = None
|
||||||
|
message_history = self._history.get(conversation_id)
|
||||||
|
if message_history is None:
|
||||||
|
# New history
|
||||||
|
#
|
||||||
|
# Render prompt and error out early if there's a problem
|
||||||
|
raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT)
|
||||||
|
try:
|
||||||
|
prompt = self._generate_prompt(raw_prompt)
|
||||||
|
_LOGGER.debug("Prompt: %s", prompt)
|
||||||
|
except TemplateError as err:
|
||||||
|
_LOGGER.error("Error rendering prompt: %s", err)
|
||||||
|
intent_response = intent.IntentResponse(language=user_input.language)
|
||||||
|
intent_response.async_set_error(
|
||||||
|
intent.IntentResponseErrorCode.UNKNOWN,
|
||||||
|
f"Sorry, I had a problem generating my prompt: {err}",
|
||||||
|
)
|
||||||
|
return conversation.ConversationResult(
|
||||||
|
response=intent_response, conversation_id=conversation_id
|
||||||
|
)
|
||||||
|
|
||||||
|
message_history = MessageHistory(
|
||||||
|
timestamp=time.monotonic(),
|
||||||
|
messages=[
|
||||||
|
ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
self._history[conversation_id] = message_history
|
||||||
|
else:
|
||||||
|
# Bump timestamp so this conversation won't get cleaned up
|
||||||
|
message_history.timestamp = time.monotonic()
|
||||||
|
|
||||||
|
# Clean up old histories
|
||||||
|
self._prune_old_histories()
|
||||||
|
|
||||||
|
# Trim this message history to keep a maximum number of *user* messages
|
||||||
|
max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
|
||||||
|
self._trim_history(message_history, max_messages)
|
||||||
|
|
||||||
|
# Add new user message
|
||||||
|
message_history.messages.append(
|
||||||
|
ollama.Message(role=MessageRole.USER.value, content=user_input.text)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get response
|
||||||
|
try:
|
||||||
|
response = await client.chat(
|
||||||
|
model=model,
|
||||||
|
# Make a copy of the messages because we mutate the list later
|
||||||
|
messages=list(message_history.messages),
|
||||||
|
stream=False,
|
||||||
|
keep_alive=KEEP_ALIVE_FOREVER,
|
||||||
|
)
|
||||||
|
except (ollama.RequestError, ollama.ResponseError) as err:
|
||||||
|
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
|
||||||
|
intent_response = intent.IntentResponse(language=user_input.language)
|
||||||
|
intent_response.async_set_error(
|
||||||
|
intent.IntentResponseErrorCode.UNKNOWN,
|
||||||
|
f"Sorry, I had a problem talking to the Ollama server: {err}",
|
||||||
|
)
|
||||||
|
return conversation.ConversationResult(
|
||||||
|
response=intent_response, conversation_id=conversation_id
|
||||||
|
)
|
||||||
|
|
||||||
|
response_message = response["message"]
|
||||||
|
message_history.messages.append(
|
||||||
|
ollama.Message(
|
||||||
|
role=response_message["role"], content=response_message["content"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create intent response
|
||||||
|
intent_response = intent.IntentResponse(language=user_input.language)
|
||||||
|
intent_response.async_set_speech(response_message["content"])
|
||||||
|
return conversation.ConversationResult(
|
||||||
|
response=intent_response, conversation_id=conversation_id
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prune_old_histories(self) -> None:
|
||||||
|
"""Remove old message histories."""
|
||||||
|
now = time.monotonic()
|
||||||
|
self._history = {
|
||||||
|
conversation_id: message_history
|
||||||
|
for conversation_id, message_history in self._history.items()
|
||||||
|
if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS
|
||||||
|
}
|
||||||
|
|
||||||
|
def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
|
||||||
|
"""Trims excess messages from a single history."""
|
||||||
|
if max_messages < 1:
|
||||||
|
# Keep all messages
|
||||||
|
return
|
||||||
|
|
||||||
|
if message_history.num_user_messages >= max_messages:
|
||||||
|
# Trim history but keep system prompt (first message).
|
||||||
|
# Every other message should be an assistant message, so keep 2x
|
||||||
|
# message objects.
|
||||||
|
num_keep = 2 * max_messages
|
||||||
|
drop_index = len(message_history.messages) - num_keep
|
||||||
|
message_history.messages = [
|
||||||
|
message_history.messages[0]
|
||||||
|
] + message_history.messages[drop_index:]
|
||||||
|
|
||||||
|
def _generate_prompt(self, raw_prompt: str) -> str:
|
||||||
|
"""Generate a prompt for the user."""
|
||||||
|
return template.Template(raw_prompt, self.hass).async_render(
|
||||||
|
{
|
||||||
|
"ha_name": self.hass.config.location_name,
|
||||||
|
"ha_language": self.hass.config.language,
|
||||||
|
"exposed_entities": self._get_exposed_entities(),
|
||||||
|
},
|
||||||
|
parse_result=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_exposed_entities(self) -> list[ExposedEntity]:
|
||||||
|
"""Get state list of exposed entities."""
|
||||||
|
area_registry = ar.async_get(self.hass)
|
||||||
|
entity_registry = er.async_get(self.hass)
|
||||||
|
device_registry = dr.async_get(self.hass)
|
||||||
|
|
||||||
|
exposed_entities = []
|
||||||
|
exposed_states = [
|
||||||
|
state
|
||||||
|
for state in self.hass.states.async_all()
|
||||||
|
if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
|
||||||
|
]
|
||||||
|
|
||||||
|
for state in exposed_states:
|
||||||
|
entity = entity_registry.async_get(state.entity_id)
|
||||||
|
names = [state.name]
|
||||||
|
area_names = []
|
||||||
|
|
||||||
|
if entity is not None:
|
||||||
|
# Add aliases
|
||||||
|
names.extend(entity.aliases)
|
||||||
|
if entity.area_id and (
|
||||||
|
area := area_registry.async_get_area(entity.area_id)
|
||||||
|
):
|
||||||
|
# Entity is in area
|
||||||
|
area_names.append(area.name)
|
||||||
|
area_names.extend(area.aliases)
|
||||||
|
elif entity.device_id and (
|
||||||
|
device := device_registry.async_get(entity.device_id)
|
||||||
|
):
|
||||||
|
# Check device area
|
||||||
|
if device.area_id and (
|
||||||
|
area := area_registry.async_get_area(device.area_id)
|
||||||
|
):
|
||||||
|
area_names.append(area.name)
|
||||||
|
area_names.extend(area.aliases)
|
||||||
|
|
||||||
|
exposed_entities.append(
|
||||||
|
ExposedEntity(
|
||||||
|
entity_id=state.entity_id,
|
||||||
|
state=state,
|
||||||
|
names=names,
|
||||||
|
area_names=area_names,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return exposed_entities
|
|
@ -1,6 +1,7 @@
|
||||||
{
|
{
|
||||||
"domain": "ollama",
|
"domain": "ollama",
|
||||||
"name": "Ollama",
|
"name": "Ollama",
|
||||||
|
"after_dependencies": ["assist_pipeline"],
|
||||||
"codeowners": ["@synesthesiam"],
|
"codeowners": ["@synesthesiam"],
|
||||||
"config_flow": true,
|
"config_flow": true,
|
||||||
"dependencies": ["conversation"],
|
"dependencies": ["conversation"],
|
||||||
|
|
347
tests/components/ollama/test_conversation.py
Normal file
347
tests/components/ollama/test_conversation.py
Normal file
|
@ -0,0 +1,347 @@
|
||||||
|
"""Tests for the Ollama integration."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
from ollama import Message, ResponseError
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components import conversation, ollama
|
||||||
|
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||||
|
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
|
||||||
|
from homeassistant.core import Context, HomeAssistant
|
||||||
|
from homeassistant.helpers import (
|
||||||
|
area_registry as ar,
|
||||||
|
device_registry as dr,
|
||||||
|
entity_registry as er,
|
||||||
|
intent,
|
||||||
|
)
|
||||||
|
|
||||||
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
|
||||||
|
async def test_chat(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
area_registry: ar.AreaRegistry,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
agent_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test that the chat function is called with the appropriate arguments."""
|
||||||
|
|
||||||
|
if agent_id is None:
|
||||||
|
agent_id = mock_config_entry.entry_id
|
||||||
|
|
||||||
|
# Create some areas, devices, and entities
|
||||||
|
area_kitchen = area_registry.async_get_or_create("kitchen_id")
|
||||||
|
area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen")
|
||||||
|
area_bedroom = area_registry.async_get_or_create("bedroom_id")
|
||||||
|
area_bedroom = area_registry.async_update(area_bedroom.id, name="bedroom")
|
||||||
|
area_office = area_registry.async_get_or_create("office_id")
|
||||||
|
area_office = area_registry.async_update(area_office.id, name="office")
|
||||||
|
|
||||||
|
entry = MockConfigEntry()
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
kitchen_device = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections=set(),
|
||||||
|
identifiers={("demo", "id-1234")},
|
||||||
|
)
|
||||||
|
device_registry.async_update_device(kitchen_device.id, area_id=area_kitchen.id)
|
||||||
|
|
||||||
|
kitchen_light = entity_registry.async_get_or_create("light", "demo", "1234")
|
||||||
|
kitchen_light = entity_registry.async_update_entity(
|
||||||
|
kitchen_light.entity_id, device_id=kitchen_device.id
|
||||||
|
)
|
||||||
|
hass.states.async_set(
|
||||||
|
kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
|
||||||
|
)
|
||||||
|
|
||||||
|
bedroom_light = entity_registry.async_get_or_create("light", "demo", "5678")
|
||||||
|
bedroom_light = entity_registry.async_update_entity(
|
||||||
|
bedroom_light.entity_id, area_id=area_bedroom.id
|
||||||
|
)
|
||||||
|
hass.states.async_set(
|
||||||
|
bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Hide the office light
|
||||||
|
office_light = entity_registry.async_get_or_create("light", "demo", "ABCD")
|
||||||
|
office_light = entity_registry.async_update_entity(
|
||||||
|
office_light.entity_id, area_id=area_office.id
|
||||||
|
)
|
||||||
|
hass.states.async_set(
|
||||||
|
office_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "office light"}
|
||||||
|
)
|
||||||
|
async_expose_entity(hass, conversation.DOMAIN, office_light.entity_id, False)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"ollama.AsyncClient.chat",
|
||||||
|
return_value={"message": {"role": "assistant", "content": "test response"}},
|
||||||
|
) as mock_chat:
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"test message",
|
||||||
|
None,
|
||||||
|
Context(),
|
||||||
|
agent_id=agent_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_chat.call_count == 1
|
||||||
|
args = mock_chat.call_args.kwargs
|
||||||
|
prompt = args["messages"][0]["content"]
|
||||||
|
|
||||||
|
assert args["model"] == "test model"
|
||||||
|
assert args["messages"] == [
|
||||||
|
Message({"role": "system", "content": prompt}),
|
||||||
|
Message({"role": "user", "content": "test message"}),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Verify only exposed devices/areas are in prompt
|
||||||
|
assert "kitchen light" in prompt
|
||||||
|
assert "bedroom light" in prompt
|
||||||
|
assert "office light" not in prompt
|
||||||
|
assert "office" not in prompt
|
||||||
|
|
||||||
|
assert (
|
||||||
|
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
), result
|
||||||
|
assert result.response.speech["plain"]["speech"] == "test response"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_message_history_trimming(
|
||||||
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||||
|
) -> None:
|
||||||
|
"""Test that a single message history is trimmed according to the config."""
|
||||||
|
response_idx = 0
|
||||||
|
|
||||||
|
def response(*args, **kwargs) -> dict:
|
||||||
|
nonlocal response_idx
|
||||||
|
response_idx += 1
|
||||||
|
return {"message": {"role": "assistant", "content": f"response {response_idx}"}}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"ollama.AsyncClient.chat",
|
||||||
|
side_effect=response,
|
||||||
|
) as mock_chat:
|
||||||
|
# mock_init_component sets "max_history" to 2
|
||||||
|
for i in range(5):
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
f"message {i+1}",
|
||||||
|
conversation_id="1234",
|
||||||
|
context=Context(),
|
||||||
|
agent_id=mock_config_entry.entry_id,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
), result
|
||||||
|
|
||||||
|
assert mock_chat.call_count == 5
|
||||||
|
args = mock_chat.call_args_list
|
||||||
|
prompt = args[0].kwargs["messages"][0]["content"]
|
||||||
|
|
||||||
|
# system + user-1
|
||||||
|
assert len(args[0].kwargs["messages"]) == 2
|
||||||
|
assert args[0].kwargs["messages"][1]["content"] == "message 1"
|
||||||
|
|
||||||
|
# Full history
|
||||||
|
# system + user-1 + assistant-1 + user-2
|
||||||
|
assert len(args[1].kwargs["messages"]) == 4
|
||||||
|
assert args[1].kwargs["messages"][0]["role"] == "system"
|
||||||
|
assert args[1].kwargs["messages"][0]["content"] == prompt
|
||||||
|
assert args[1].kwargs["messages"][1]["role"] == "user"
|
||||||
|
assert args[1].kwargs["messages"][1]["content"] == "message 1"
|
||||||
|
assert args[1].kwargs["messages"][2]["role"] == "assistant"
|
||||||
|
assert args[1].kwargs["messages"][2]["content"] == "response 1"
|
||||||
|
assert args[1].kwargs["messages"][3]["role"] == "user"
|
||||||
|
assert args[1].kwargs["messages"][3]["content"] == "message 2"
|
||||||
|
|
||||||
|
# Full history
|
||||||
|
# system + user-1 + assistant-1 + user-2 + assistant-2 + user-3
|
||||||
|
assert len(args[2].kwargs["messages"]) == 6
|
||||||
|
assert args[2].kwargs["messages"][0]["role"] == "system"
|
||||||
|
assert args[2].kwargs["messages"][0]["content"] == prompt
|
||||||
|
assert args[2].kwargs["messages"][1]["role"] == "user"
|
||||||
|
assert args[2].kwargs["messages"][1]["content"] == "message 1"
|
||||||
|
assert args[2].kwargs["messages"][2]["role"] == "assistant"
|
||||||
|
assert args[2].kwargs["messages"][2]["content"] == "response 1"
|
||||||
|
assert args[2].kwargs["messages"][3]["role"] == "user"
|
||||||
|
assert args[2].kwargs["messages"][3]["content"] == "message 2"
|
||||||
|
assert args[2].kwargs["messages"][4]["role"] == "assistant"
|
||||||
|
assert args[2].kwargs["messages"][4]["content"] == "response 2"
|
||||||
|
assert args[2].kwargs["messages"][5]["role"] == "user"
|
||||||
|
assert args[2].kwargs["messages"][5]["content"] == "message 3"
|
||||||
|
|
||||||
|
# Trimmed down to two user messages.
|
||||||
|
# system + user-2 + assistant-2 + user-3 + assistant-3 + user-4
|
||||||
|
assert len(args[3].kwargs["messages"]) == 6
|
||||||
|
assert args[3].kwargs["messages"][0]["role"] == "system"
|
||||||
|
assert args[3].kwargs["messages"][0]["content"] == prompt
|
||||||
|
assert args[3].kwargs["messages"][1]["role"] == "user"
|
||||||
|
assert args[3].kwargs["messages"][1]["content"] == "message 2"
|
||||||
|
assert args[3].kwargs["messages"][2]["role"] == "assistant"
|
||||||
|
assert args[3].kwargs["messages"][2]["content"] == "response 2"
|
||||||
|
assert args[3].kwargs["messages"][3]["role"] == "user"
|
||||||
|
assert args[3].kwargs["messages"][3]["content"] == "message 3"
|
||||||
|
assert args[3].kwargs["messages"][4]["role"] == "assistant"
|
||||||
|
assert args[3].kwargs["messages"][4]["content"] == "response 3"
|
||||||
|
assert args[3].kwargs["messages"][5]["role"] == "user"
|
||||||
|
assert args[3].kwargs["messages"][5]["content"] == "message 4"
|
||||||
|
|
||||||
|
# Trimmed down to two user messages.
|
||||||
|
# system + user-3 + assistant-3 + user-4 + assistant-4 + user-5
|
||||||
|
assert len(args[3].kwargs["messages"]) == 6
|
||||||
|
assert args[4].kwargs["messages"][0]["role"] == "system"
|
||||||
|
assert args[4].kwargs["messages"][0]["content"] == prompt
|
||||||
|
assert args[4].kwargs["messages"][1]["role"] == "user"
|
||||||
|
assert args[4].kwargs["messages"][1]["content"] == "message 3"
|
||||||
|
assert args[4].kwargs["messages"][2]["role"] == "assistant"
|
||||||
|
assert args[4].kwargs["messages"][2]["content"] == "response 3"
|
||||||
|
assert args[4].kwargs["messages"][3]["role"] == "user"
|
||||||
|
assert args[4].kwargs["messages"][3]["content"] == "message 4"
|
||||||
|
assert args[4].kwargs["messages"][4]["role"] == "assistant"
|
||||||
|
assert args[4].kwargs["messages"][4]["content"] == "response 4"
|
||||||
|
assert args[4].kwargs["messages"][5]["role"] == "user"
|
||||||
|
assert args[4].kwargs["messages"][5]["content"] == "message 5"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_message_history_pruning(
|
||||||
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||||
|
) -> None:
|
||||||
|
"""Test that old message histories are pruned."""
|
||||||
|
with patch(
|
||||||
|
"ollama.AsyncClient.chat",
|
||||||
|
return_value={"message": {"role": "assistant", "content": "test response"}},
|
||||||
|
):
|
||||||
|
# Create 3 different message histories
|
||||||
|
conversation_ids: list[str] = []
|
||||||
|
for i in range(3):
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
f"message {i+1}",
|
||||||
|
conversation_id=None,
|
||||||
|
context=Context(),
|
||||||
|
agent_id=mock_config_entry.entry_id,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
), result
|
||||||
|
assert isinstance(result.conversation_id, str)
|
||||||
|
conversation_ids.append(result.conversation_id)
|
||||||
|
|
||||||
|
agent = conversation.get_agent_manager(hass).async_get_agent(
|
||||||
|
mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
assert len(agent._history) == 3
|
||||||
|
assert agent._history.keys() == set(conversation_ids)
|
||||||
|
|
||||||
|
# Modify the timestamps of the first 2 histories so they will be pruned
|
||||||
|
# on the next cycle.
|
||||||
|
for conversation_id in conversation_ids[:2]:
|
||||||
|
# Move back 2 hours
|
||||||
|
agent._history[conversation_id].timestamp -= 2 * 60 * 60
|
||||||
|
|
||||||
|
# Next cycle
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"test message",
|
||||||
|
conversation_id=None,
|
||||||
|
context=Context(),
|
||||||
|
agent_id=mock_config_entry.entry_id,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
), result
|
||||||
|
|
||||||
|
# Only the most recent histories should remain
|
||||||
|
assert len(agent._history) == 2
|
||||||
|
assert conversation_ids[-1] in agent._history
|
||||||
|
assert result.conversation_id in agent._history
|
||||||
|
|
||||||
|
|
||||||
|
async def test_message_history_unlimited(
|
||||||
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||||
|
) -> None:
|
||||||
|
"""Test that message history is not trimmed when max_history = 0."""
|
||||||
|
conversation_id = "1234"
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"ollama.AsyncClient.chat",
|
||||||
|
return_value={"message": {"role": "assistant", "content": "test response"}},
|
||||||
|
),
|
||||||
|
patch.object(mock_config_entry, "options", {ollama.CONF_MAX_HISTORY: 0}),
|
||||||
|
):
|
||||||
|
for i in range(100):
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
f"message {i+1}",
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
context=Context(),
|
||||||
|
agent_id=mock_config_entry.entry_id,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
), result
|
||||||
|
|
||||||
|
agent = conversation.get_agent_manager(hass).async_get_agent(
|
||||||
|
mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(agent._history) == 1
|
||||||
|
assert conversation_id in agent._history
|
||||||
|
assert agent._history[conversation_id].num_user_messages == 100
|
||||||
|
|
||||||
|
|
||||||
|
async def test_error_handling(
|
||||||
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||||
|
) -> None:
|
||||||
|
"""Test error handling during converse."""
|
||||||
|
with patch(
|
||||||
|
"ollama.AsyncClient.chat",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=ResponseError("test error"),
|
||||||
|
):
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
|
assert result.response.error_code == "unknown", result
|
||||||
|
|
||||||
|
|
||||||
|
async def test_template_error(
|
||||||
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||||
|
) -> None:
|
||||||
|
"""Test that template error handling works."""
|
||||||
|
hass.config_entries.async_update_entry(
|
||||||
|
mock_config_entry,
|
||||||
|
options={
|
||||||
|
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with patch(
|
||||||
|
"ollama.AsyncClient.list",
|
||||||
|
):
|
||||||
|
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
|
assert result.response.error_code == "unknown", result
|
||||||
|
|
||||||
|
|
||||||
|
async def test_conversation_agent(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
) -> None:
|
||||||
|
"""Test OllamaConversationEntity."""
|
||||||
|
agent = conversation.get_agent_manager(hass).async_get_agent(
|
||||||
|
mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
assert agent.supported_languages == MATCH_ALL
|
|
@ -1,351 +1,17 @@
|
||||||
"""Tests for the Ollama integration."""
|
"""Tests for the Ollama integration."""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from httpx import ConnectError
|
from httpx import ConnectError
|
||||||
from ollama import Message, ResponseError
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import conversation, ollama
|
from homeassistant.components import ollama
|
||||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
|
|
||||||
from homeassistant.core import Context, HomeAssistant
|
|
||||||
from homeassistant.helpers import (
|
|
||||||
area_registry as ar,
|
|
||||||
device_registry as dr,
|
|
||||||
entity_registry as er,
|
|
||||||
intent,
|
|
||||||
)
|
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
|
||||||
async def test_chat(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
mock_config_entry: MockConfigEntry,
|
|
||||||
mock_init_component,
|
|
||||||
area_registry: ar.AreaRegistry,
|
|
||||||
device_registry: dr.DeviceRegistry,
|
|
||||||
entity_registry: er.EntityRegistry,
|
|
||||||
) -> None:
|
|
||||||
"""Test that the chat function is called with the appropriate arguments."""
|
|
||||||
|
|
||||||
# Create some areas, devices, and entities
|
|
||||||
area_kitchen = area_registry.async_get_or_create("kitchen_id")
|
|
||||||
area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen")
|
|
||||||
area_bedroom = area_registry.async_get_or_create("bedroom_id")
|
|
||||||
area_bedroom = area_registry.async_update(area_bedroom.id, name="bedroom")
|
|
||||||
area_office = area_registry.async_get_or_create("office_id")
|
|
||||||
area_office = area_registry.async_update(area_office.id, name="office")
|
|
||||||
|
|
||||||
entry = MockConfigEntry()
|
|
||||||
entry.add_to_hass(hass)
|
|
||||||
kitchen_device = device_registry.async_get_or_create(
|
|
||||||
config_entry_id=entry.entry_id,
|
|
||||||
connections=set(),
|
|
||||||
identifiers={("demo", "id-1234")},
|
|
||||||
)
|
|
||||||
device_registry.async_update_device(kitchen_device.id, area_id=area_kitchen.id)
|
|
||||||
|
|
||||||
kitchen_light = entity_registry.async_get_or_create("light", "demo", "1234")
|
|
||||||
kitchen_light = entity_registry.async_update_entity(
|
|
||||||
kitchen_light.entity_id, device_id=kitchen_device.id
|
|
||||||
)
|
|
||||||
hass.states.async_set(
|
|
||||||
kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
|
|
||||||
)
|
|
||||||
|
|
||||||
bedroom_light = entity_registry.async_get_or_create("light", "demo", "5678")
|
|
||||||
bedroom_light = entity_registry.async_update_entity(
|
|
||||||
bedroom_light.entity_id, area_id=area_bedroom.id
|
|
||||||
)
|
|
||||||
hass.states.async_set(
|
|
||||||
bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Hide the office light
|
|
||||||
office_light = entity_registry.async_get_or_create("light", "demo", "ABCD")
|
|
||||||
office_light = entity_registry.async_update_entity(
|
|
||||||
office_light.entity_id, area_id=area_office.id
|
|
||||||
)
|
|
||||||
hass.states.async_set(
|
|
||||||
office_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "office light"}
|
|
||||||
)
|
|
||||||
async_expose_entity(hass, conversation.DOMAIN, office_light.entity_id, False)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"ollama.AsyncClient.chat",
|
|
||||||
return_value={"message": {"role": "assistant", "content": "test response"}},
|
|
||||||
) as mock_chat:
|
|
||||||
result = await conversation.async_converse(
|
|
||||||
hass,
|
|
||||||
"test message",
|
|
||||||
None,
|
|
||||||
Context(),
|
|
||||||
agent_id=mock_config_entry.entry_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert mock_chat.call_count == 1
|
|
||||||
args = mock_chat.call_args.kwargs
|
|
||||||
prompt = args["messages"][0]["content"]
|
|
||||||
|
|
||||||
assert args["model"] == "test model"
|
|
||||||
assert args["messages"] == [
|
|
||||||
Message({"role": "system", "content": prompt}),
|
|
||||||
Message({"role": "user", "content": "test message"}),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Verify only exposed devices/areas are in prompt
|
|
||||||
assert "kitchen light" in prompt
|
|
||||||
assert "bedroom light" in prompt
|
|
||||||
assert "office light" not in prompt
|
|
||||||
assert "office" not in prompt
|
|
||||||
|
|
||||||
assert (
|
|
||||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
|
||||||
), result
|
|
||||||
assert result.response.speech["plain"]["speech"] == "test response"
|
|
||||||
|
|
||||||
|
|
||||||
async def test_message_history_trimming(
|
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
|
||||||
) -> None:
|
|
||||||
"""Test that a single message history is trimmed according to the config."""
|
|
||||||
response_idx = 0
|
|
||||||
|
|
||||||
def response(*args, **kwargs) -> dict:
|
|
||||||
nonlocal response_idx
|
|
||||||
response_idx += 1
|
|
||||||
return {"message": {"role": "assistant", "content": f"response {response_idx}"}}
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"ollama.AsyncClient.chat",
|
|
||||||
side_effect=response,
|
|
||||||
) as mock_chat:
|
|
||||||
# mock_init_component sets "max_history" to 2
|
|
||||||
for i in range(5):
|
|
||||||
result = await conversation.async_converse(
|
|
||||||
hass,
|
|
||||||
f"message {i+1}",
|
|
||||||
conversation_id="1234",
|
|
||||||
context=Context(),
|
|
||||||
agent_id=mock_config_entry.entry_id,
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
|
||||||
), result
|
|
||||||
|
|
||||||
assert mock_chat.call_count == 5
|
|
||||||
args = mock_chat.call_args_list
|
|
||||||
prompt = args[0].kwargs["messages"][0]["content"]
|
|
||||||
|
|
||||||
# system + user-1
|
|
||||||
assert len(args[0].kwargs["messages"]) == 2
|
|
||||||
assert args[0].kwargs["messages"][1]["content"] == "message 1"
|
|
||||||
|
|
||||||
# Full history
|
|
||||||
# system + user-1 + assistant-1 + user-2
|
|
||||||
assert len(args[1].kwargs["messages"]) == 4
|
|
||||||
assert args[1].kwargs["messages"][0]["role"] == "system"
|
|
||||||
assert args[1].kwargs["messages"][0]["content"] == prompt
|
|
||||||
assert args[1].kwargs["messages"][1]["role"] == "user"
|
|
||||||
assert args[1].kwargs["messages"][1]["content"] == "message 1"
|
|
||||||
assert args[1].kwargs["messages"][2]["role"] == "assistant"
|
|
||||||
assert args[1].kwargs["messages"][2]["content"] == "response 1"
|
|
||||||
assert args[1].kwargs["messages"][3]["role"] == "user"
|
|
||||||
assert args[1].kwargs["messages"][3]["content"] == "message 2"
|
|
||||||
|
|
||||||
# Full history
|
|
||||||
# system + user-1 + assistant-1 + user-2 + assistant-2 + user-3
|
|
||||||
assert len(args[2].kwargs["messages"]) == 6
|
|
||||||
assert args[2].kwargs["messages"][0]["role"] == "system"
|
|
||||||
assert args[2].kwargs["messages"][0]["content"] == prompt
|
|
||||||
assert args[2].kwargs["messages"][1]["role"] == "user"
|
|
||||||
assert args[2].kwargs["messages"][1]["content"] == "message 1"
|
|
||||||
assert args[2].kwargs["messages"][2]["role"] == "assistant"
|
|
||||||
assert args[2].kwargs["messages"][2]["content"] == "response 1"
|
|
||||||
assert args[2].kwargs["messages"][3]["role"] == "user"
|
|
||||||
assert args[2].kwargs["messages"][3]["content"] == "message 2"
|
|
||||||
assert args[2].kwargs["messages"][4]["role"] == "assistant"
|
|
||||||
assert args[2].kwargs["messages"][4]["content"] == "response 2"
|
|
||||||
assert args[2].kwargs["messages"][5]["role"] == "user"
|
|
||||||
assert args[2].kwargs["messages"][5]["content"] == "message 3"
|
|
||||||
|
|
||||||
# Trimmed down to two user messages.
|
|
||||||
# system + user-2 + assistant-2 + user-3 + assistant-3 + user-4
|
|
||||||
assert len(args[3].kwargs["messages"]) == 6
|
|
||||||
assert args[3].kwargs["messages"][0]["role"] == "system"
|
|
||||||
assert args[3].kwargs["messages"][0]["content"] == prompt
|
|
||||||
assert args[3].kwargs["messages"][1]["role"] == "user"
|
|
||||||
assert args[3].kwargs["messages"][1]["content"] == "message 2"
|
|
||||||
assert args[3].kwargs["messages"][2]["role"] == "assistant"
|
|
||||||
assert args[3].kwargs["messages"][2]["content"] == "response 2"
|
|
||||||
assert args[3].kwargs["messages"][3]["role"] == "user"
|
|
||||||
assert args[3].kwargs["messages"][3]["content"] == "message 3"
|
|
||||||
assert args[3].kwargs["messages"][4]["role"] == "assistant"
|
|
||||||
assert args[3].kwargs["messages"][4]["content"] == "response 3"
|
|
||||||
assert args[3].kwargs["messages"][5]["role"] == "user"
|
|
||||||
assert args[3].kwargs["messages"][5]["content"] == "message 4"
|
|
||||||
|
|
||||||
# Trimmed down to two user messages.
|
|
||||||
# system + user-3 + assistant-3 + user-4 + assistant-4 + user-5
|
|
||||||
assert len(args[3].kwargs["messages"]) == 6
|
|
||||||
assert args[4].kwargs["messages"][0]["role"] == "system"
|
|
||||||
assert args[4].kwargs["messages"][0]["content"] == prompt
|
|
||||||
assert args[4].kwargs["messages"][1]["role"] == "user"
|
|
||||||
assert args[4].kwargs["messages"][1]["content"] == "message 3"
|
|
||||||
assert args[4].kwargs["messages"][2]["role"] == "assistant"
|
|
||||||
assert args[4].kwargs["messages"][2]["content"] == "response 3"
|
|
||||||
assert args[4].kwargs["messages"][3]["role"] == "user"
|
|
||||||
assert args[4].kwargs["messages"][3]["content"] == "message 4"
|
|
||||||
assert args[4].kwargs["messages"][4]["role"] == "assistant"
|
|
||||||
assert args[4].kwargs["messages"][4]["content"] == "response 4"
|
|
||||||
assert args[4].kwargs["messages"][5]["role"] == "user"
|
|
||||||
assert args[4].kwargs["messages"][5]["content"] == "message 5"
|
|
||||||
|
|
||||||
|
|
||||||
async def test_message_history_pruning(
|
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
|
||||||
) -> None:
|
|
||||||
"""Test that old message histories are pruned."""
|
|
||||||
with patch(
|
|
||||||
"ollama.AsyncClient.chat",
|
|
||||||
return_value={"message": {"role": "assistant", "content": "test response"}},
|
|
||||||
):
|
|
||||||
# Create 3 different message histories
|
|
||||||
conversation_ids: list[str] = []
|
|
||||||
for i in range(3):
|
|
||||||
result = await conversation.async_converse(
|
|
||||||
hass,
|
|
||||||
f"message {i+1}",
|
|
||||||
conversation_id=None,
|
|
||||||
context=Context(),
|
|
||||||
agent_id=mock_config_entry.entry_id,
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
|
||||||
), result
|
|
||||||
assert isinstance(result.conversation_id, str)
|
|
||||||
conversation_ids.append(result.conversation_id)
|
|
||||||
|
|
||||||
agent = conversation.get_agent_manager(hass).async_get_agent(
|
|
||||||
mock_config_entry.entry_id
|
|
||||||
)
|
|
||||||
assert isinstance(agent, ollama.OllamaAgent)
|
|
||||||
assert len(agent._history) == 3
|
|
||||||
assert agent._history.keys() == set(conversation_ids)
|
|
||||||
|
|
||||||
# Modify the timestamps of the first 2 histories so they will be pruned
|
|
||||||
# on the next cycle.
|
|
||||||
for conversation_id in conversation_ids[:2]:
|
|
||||||
# Move back 2 hours
|
|
||||||
agent._history[conversation_id].timestamp -= 2 * 60 * 60
|
|
||||||
|
|
||||||
# Next cycle
|
|
||||||
result = await conversation.async_converse(
|
|
||||||
hass,
|
|
||||||
"test message",
|
|
||||||
conversation_id=None,
|
|
||||||
context=Context(),
|
|
||||||
agent_id=mock_config_entry.entry_id,
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
|
||||||
), result
|
|
||||||
|
|
||||||
# Only the most recent histories should remain
|
|
||||||
assert len(agent._history) == 2
|
|
||||||
assert conversation_ids[-1] in agent._history
|
|
||||||
assert result.conversation_id in agent._history
|
|
||||||
|
|
||||||
|
|
||||||
async def test_message_history_unlimited(
|
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
|
||||||
) -> None:
|
|
||||||
"""Test that message history is not trimmed when max_history = 0."""
|
|
||||||
conversation_id = "1234"
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"ollama.AsyncClient.chat",
|
|
||||||
return_value={"message": {"role": "assistant", "content": "test response"}},
|
|
||||||
),
|
|
||||||
patch.object(mock_config_entry, "options", {ollama.CONF_MAX_HISTORY: 0}),
|
|
||||||
):
|
|
||||||
for i in range(100):
|
|
||||||
result = await conversation.async_converse(
|
|
||||||
hass,
|
|
||||||
f"message {i+1}",
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
context=Context(),
|
|
||||||
agent_id=mock_config_entry.entry_id,
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
|
||||||
), result
|
|
||||||
|
|
||||||
agent = conversation.get_agent_manager(hass).async_get_agent(
|
|
||||||
mock_config_entry.entry_id
|
|
||||||
)
|
|
||||||
assert isinstance(agent, ollama.OllamaAgent)
|
|
||||||
|
|
||||||
assert len(agent._history) == 1
|
|
||||||
assert conversation_id in agent._history
|
|
||||||
assert agent._history[conversation_id].num_user_messages == 100
|
|
||||||
|
|
||||||
|
|
||||||
async def test_error_handling(
|
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
|
||||||
) -> None:
|
|
||||||
"""Test error handling during converse."""
|
|
||||||
with patch(
|
|
||||||
"ollama.AsyncClient.chat",
|
|
||||||
new_callable=AsyncMock,
|
|
||||||
side_effect=ResponseError("test error"),
|
|
||||||
):
|
|
||||||
result = await conversation.async_converse(
|
|
||||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
|
||||||
assert result.response.error_code == "unknown", result
|
|
||||||
|
|
||||||
|
|
||||||
async def test_template_error(
|
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
|
||||||
) -> None:
|
|
||||||
"""Test that template error handling works."""
|
|
||||||
hass.config_entries.async_update_entry(
|
|
||||||
mock_config_entry,
|
|
||||||
options={
|
|
||||||
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
with patch(
|
|
||||||
"ollama.AsyncClient.list",
|
|
||||||
):
|
|
||||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
result = await conversation.async_converse(
|
|
||||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
|
||||||
assert result.response.error_code == "unknown", result
|
|
||||||
|
|
||||||
|
|
||||||
async def test_conversation_agent(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
mock_config_entry: MockConfigEntry,
|
|
||||||
mock_init_component,
|
|
||||||
) -> None:
|
|
||||||
"""Test OllamaAgent."""
|
|
||||||
agent = conversation.get_agent_manager(hass).async_get_agent(
|
|
||||||
mock_config_entry.entry_id
|
|
||||||
)
|
|
||||||
assert agent.supported_languages == MATCH_ALL
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("side_effect", "error"),
|
("side_effect", "error"),
|
||||||
[
|
[
|
||||||
|
|
Loading…
Add table
Reference in a new issue