diff --git a/homeassistant/components/ollama/__init__.py b/homeassistant/components/ollama/__init__.py index 8c9b00f3c9c..323642a8d90 100644 --- a/homeassistant/components/ollama/__init__.py +++ b/homeassistant/components/ollama/__init__.py @@ -4,40 +4,17 @@ from __future__ import annotations import asyncio import logging -import time -from typing import Literal import httpx import ollama -from homeassistant.components import conversation -from homeassistant.components.homeassistant.exposed_entities import async_should_expose from homeassistant.config_entries import ConfigEntry -from homeassistant.const import CONF_URL, MATCH_ALL +from homeassistant.const import CONF_URL, Platform from homeassistant.core import HomeAssistant -from homeassistant.exceptions import ConfigEntryNotReady, TemplateError -from homeassistant.helpers import ( - area_registry as ar, - config_validation as cv, - device_registry as dr, - entity_registry as er, - intent, - template, -) -from homeassistant.util import ulid +from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.helpers import config_validation as cv -from .const import ( - CONF_MAX_HISTORY, - CONF_MODEL, - CONF_PROMPT, - DEFAULT_MAX_HISTORY, - DEFAULT_PROMPT, - DEFAULT_TIMEOUT, - DOMAIN, - KEEP_ALIVE_FOREVER, - MAX_HISTORY_SECONDS, -) -from .models import ExposedEntity, MessageHistory, MessageRole +from .const import CONF_MAX_HISTORY, CONF_MODEL, CONF_PROMPT, DEFAULT_TIMEOUT, DOMAIN _LOGGER = logging.getLogger(__name__) @@ -46,11 +23,11 @@ __all__ = [ "CONF_PROMPT", "CONF_MODEL", "CONF_MAX_HISTORY", - "MAX_HISTORY_NO_LIMIT", "DOMAIN", ] CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) +PLATFORMS = (Platform.CONVERSATION,) async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: @@ -65,202 +42,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client - conversation.async_set_agent(hass, entry, 
OllamaAgent(hass, entry)) + await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) return True async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload Ollama.""" + if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS): + return False hass.data[DOMAIN].pop(entry.entry_id) - conversation.async_unset_agent(hass, entry) return True - - -class OllamaAgent(conversation.AbstractConversationAgent): - """Ollama conversation agent.""" - - def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None: - """Initialize the agent.""" - self.hass = hass - self.entry = entry - - # conversation id -> message history - self._history: dict[str, MessageHistory] = {} - - @property - def supported_languages(self) -> list[str] | Literal["*"]: - """Return a list of supported languages.""" - return MATCH_ALL - - async def async_process( - self, user_input: conversation.ConversationInput - ) -> conversation.ConversationResult: - """Process a sentence.""" - settings = {**self.entry.data, **self.entry.options} - - client = self.hass.data[DOMAIN][self.entry.entry_id] - conversation_id = user_input.conversation_id or ulid.ulid_now() - model = settings[CONF_MODEL] - - # Look up message history - message_history: MessageHistory | None = None - message_history = self._history.get(conversation_id) - if message_history is None: - # New history - # - # Render prompt and error out early if there's a problem - raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT) - try: - prompt = self._generate_prompt(raw_prompt) - _LOGGER.debug("Prompt: %s", prompt) - except TemplateError as err: - _LOGGER.error("Error rendering prompt: %s", err) - intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - f"Sorry, I had a problem generating my prompt: {err}", - ) - return conversation.ConversationResult( - response=intent_response, 
conversation_id=conversation_id - ) - - message_history = MessageHistory( - timestamp=time.monotonic(), - messages=[ - ollama.Message(role=MessageRole.SYSTEM.value, content=prompt) - ], - ) - self._history[conversation_id] = message_history - else: - # Bump timestamp so this conversation won't get cleaned up - message_history.timestamp = time.monotonic() - - # Clean up old histories - self._prune_old_histories() - - # Trim this message history to keep a maximum number of *user* messages - max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY)) - self._trim_history(message_history, max_messages) - - # Add new user message - message_history.messages.append( - ollama.Message(role=MessageRole.USER.value, content=user_input.text) - ) - - # Get response - try: - response = await client.chat( - model=model, - # Make a copy of the messages because we mutate the list later - messages=list(message_history.messages), - stream=False, - keep_alive=KEEP_ALIVE_FOREVER, - ) - except (ollama.RequestError, ollama.ResponseError) as err: - _LOGGER.error("Unexpected error talking to Ollama server: %s", err) - intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - f"Sorry, I had a problem talking to the Ollama server: {err}", - ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id - ) - - response_message = response["message"] - message_history.messages.append( - ollama.Message( - role=response_message["role"], content=response_message["content"] - ) - ) - - # Create intent response - intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_speech(response_message["content"]) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id - ) - - def _prune_old_histories(self) -> None: - """Remove old message histories.""" - now = time.monotonic() - 
self._history = { - conversation_id: message_history - for conversation_id, message_history in self._history.items() - if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS - } - - def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None: - """Trims excess messages from a single history.""" - if max_messages < 1: - # Keep all messages - return - - if message_history.num_user_messages >= max_messages: - # Trim history but keep system prompt (first message). - # Every other message should be an assistant message, so keep 2x - # message objects. - num_keep = 2 * max_messages - drop_index = len(message_history.messages) - num_keep - message_history.messages = [ - message_history.messages[0] - ] + message_history.messages[drop_index:] - - def _generate_prompt(self, raw_prompt: str) -> str: - """Generate a prompt for the user.""" - return template.Template(raw_prompt, self.hass).async_render( - { - "ha_name": self.hass.config.location_name, - "ha_language": self.hass.config.language, - "exposed_entities": self._get_exposed_entities(), - }, - parse_result=False, - ) - - def _get_exposed_entities(self) -> list[ExposedEntity]: - """Get state list of exposed entities.""" - area_registry = ar.async_get(self.hass) - entity_registry = er.async_get(self.hass) - device_registry = dr.async_get(self.hass) - - exposed_entities = [] - exposed_states = [ - state - for state in self.hass.states.async_all() - if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id) - ] - - for state in exposed_states: - entity = entity_registry.async_get(state.entity_id) - names = [state.name] - area_names = [] - - if entity is not None: - # Add aliases - names.extend(entity.aliases) - if entity.area_id and ( - area := area_registry.async_get_area(entity.area_id) - ): - # Entity is in area - area_names.append(area.name) - area_names.extend(area.aliases) - elif entity.device_id and ( - device := device_registry.async_get(entity.device_id) - ): - # Check 
device area - if device.area_id and ( - area := area_registry.async_get_area(device.area_id) - ): - area_names.append(area.name) - area_names.extend(area.aliases) - - exposed_entities.append( - ExposedEntity( - entity_id=state.entity_id, - state=state, - names=names, - area_names=area_names, - ) - ) - - return exposed_entities diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py new file mode 100644 index 00000000000..8a5f6e7d5c5 --- /dev/null +++ b/homeassistant/components/ollama/conversation.py @@ -0,0 +1,258 @@ +"""The conversation platform for the Ollama integration.""" + +from __future__ import annotations + +import logging +import time +from typing import Literal + +import ollama + +from homeassistant.components import assist_pipeline, conversation +from homeassistant.components.homeassistant.exposed_entities import async_should_expose +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import MATCH_ALL +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import TemplateError +from homeassistant.helpers import ( + area_registry as ar, + device_registry as dr, + entity_registry as er, + intent, + template, +) +from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.util import ulid + +from .const import ( + CONF_MAX_HISTORY, + CONF_MODEL, + CONF_PROMPT, + DEFAULT_MAX_HISTORY, + DEFAULT_PROMPT, + DOMAIN, + KEEP_ALIVE_FOREVER, + MAX_HISTORY_SECONDS, +) +from .models import ExposedEntity, MessageHistory, MessageRole + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up conversation entities.""" + agent = OllamaConversationEntity(hass, config_entry) + async_add_entities([agent]) + + +class OllamaConversationEntity( + conversation.ConversationEntity, 
conversation.AbstractConversationAgent +): + """Ollama conversation agent.""" + + _attr_has_entity_name = True + + def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None: + """Initialize the agent.""" + self.hass = hass + self.entry = entry + + # conversation id -> message history + self._history: dict[str, MessageHistory] = {} + self._attr_name = entry.title + self._attr_unique_id = entry.entry_id + + async def async_added_to_hass(self) -> None: + """When entity is added to Home Assistant.""" + await super().async_added_to_hass() + assist_pipeline.async_migrate_engine( + self.hass, "conversation", self.entry.entry_id, self.entity_id + ) + conversation.async_set_agent(self.hass, self.entry, self) + + async def async_will_remove_from_hass(self) -> None: + """When entity will be removed from Home Assistant.""" + conversation.async_unset_agent(self.hass, self.entry) + await super().async_will_remove_from_hass() + + @property + def supported_languages(self) -> list[str] | Literal["*"]: + """Return a list of supported languages.""" + return MATCH_ALL + + async def async_process( + self, user_input: conversation.ConversationInput + ) -> conversation.ConversationResult: + """Process a sentence.""" + settings = {**self.entry.data, **self.entry.options} + + client = self.hass.data[DOMAIN][self.entry.entry_id] + conversation_id = user_input.conversation_id or ulid.ulid_now() + model = settings[CONF_MODEL] + + # Look up message history + message_history: MessageHistory | None = None + message_history = self._history.get(conversation_id) + if message_history is None: + # New history + # + # Render prompt and error out early if there's a problem + raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT) + try: + prompt = self._generate_prompt(raw_prompt) + _LOGGER.debug("Prompt: %s", prompt) + except TemplateError as err: + _LOGGER.error("Error rendering prompt: %s", err) + intent_response = intent.IntentResponse(language=user_input.language) + 
intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Sorry, I had a problem generating my prompt: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + message_history = MessageHistory( + timestamp=time.monotonic(), + messages=[ + ollama.Message(role=MessageRole.SYSTEM.value, content=prompt) + ], + ) + self._history[conversation_id] = message_history + else: + # Bump timestamp so this conversation won't get cleaned up + message_history.timestamp = time.monotonic() + + # Clean up old histories + self._prune_old_histories() + + # Trim this message history to keep a maximum number of *user* messages + max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY)) + self._trim_history(message_history, max_messages) + + # Add new user message + message_history.messages.append( + ollama.Message(role=MessageRole.USER.value, content=user_input.text) + ) + + # Get response + try: + response = await client.chat( + model=model, + # Make a copy of the messages because we mutate the list later + messages=list(message_history.messages), + stream=False, + keep_alive=KEEP_ALIVE_FOREVER, + ) + except (ollama.RequestError, ollama.ResponseError) as err: + _LOGGER.error("Unexpected error talking to Ollama server: %s", err) + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Sorry, I had a problem talking to the Ollama server: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + response_message = response["message"] + message_history.messages.append( + ollama.Message( + role=response_message["role"], content=response_message["content"] + ) + ) + + # Create intent response + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_speech(response_message["content"]) + return 
conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + def _prune_old_histories(self) -> None: + """Remove old message histories.""" + now = time.monotonic() + self._history = { + conversation_id: message_history + for conversation_id, message_history in self._history.items() + if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS + } + + def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None: + """Trims excess messages from a single history.""" + if max_messages < 1: + # Keep all messages + return + + if message_history.num_user_messages >= max_messages: + # Trim history but keep system prompt (first message). + # Every other message should be an assistant message, so keep 2x + # message objects. + num_keep = 2 * max_messages + drop_index = len(message_history.messages) - num_keep + message_history.messages = [ + message_history.messages[0] + ] + message_history.messages[drop_index:] + + def _generate_prompt(self, raw_prompt: str) -> str: + """Generate a prompt for the user.""" + return template.Template(raw_prompt, self.hass).async_render( + { + "ha_name": self.hass.config.location_name, + "ha_language": self.hass.config.language, + "exposed_entities": self._get_exposed_entities(), + }, + parse_result=False, + ) + + def _get_exposed_entities(self) -> list[ExposedEntity]: + """Get state list of exposed entities.""" + area_registry = ar.async_get(self.hass) + entity_registry = er.async_get(self.hass) + device_registry = dr.async_get(self.hass) + + exposed_entities = [] + exposed_states = [ + state + for state in self.hass.states.async_all() + if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id) + ] + + for state in exposed_states: + entity = entity_registry.async_get(state.entity_id) + names = [state.name] + area_names = [] + + if entity is not None: + # Add aliases + names.extend(entity.aliases) + if entity.area_id and ( + area := 
area_registry.async_get_area(entity.area_id) + ): + # Entity is in area + area_names.append(area.name) + area_names.extend(area.aliases) + elif entity.device_id and ( + device := device_registry.async_get(entity.device_id) + ): + # Check device area + if device.area_id and ( + area := area_registry.async_get_area(device.area_id) + ): + area_names.append(area.name) + area_names.extend(area.aliases) + + exposed_entities.append( + ExposedEntity( + entity_id=state.entity_id, + state=state, + names=names, + area_names=area_names, + ) + ) + + return exposed_entities diff --git a/homeassistant/components/ollama/manifest.json b/homeassistant/components/ollama/manifest.json index 6b16ae667f1..7afaaa3dbd4 100644 --- a/homeassistant/components/ollama/manifest.json +++ b/homeassistant/components/ollama/manifest.json @@ -1,6 +1,7 @@ { "domain": "ollama", "name": "Ollama", + "after_dependencies": ["assist_pipeline"], "codeowners": ["@synesthesiam"], "config_flow": true, "dependencies": ["conversation"], diff --git a/tests/components/ollama/test_conversation.py b/tests/components/ollama/test_conversation.py new file mode 100644 index 00000000000..080d0d34f2d --- /dev/null +++ b/tests/components/ollama/test_conversation.py @@ -0,0 +1,347 @@ +"""Tests for the Ollama integration.""" + +from unittest.mock import AsyncMock, patch + +from ollama import Message, ResponseError +import pytest + +from homeassistant.components import conversation, ollama +from homeassistant.components.homeassistant.exposed_entities import async_expose_entity +from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL +from homeassistant.core import Context, HomeAssistant +from homeassistant.helpers import ( + area_registry as ar, + device_registry as dr, + entity_registry as er, + intent, +) + +from tests.common import MockConfigEntry + + +@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"]) +async def test_chat( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + 
mock_init_component, + area_registry: ar.AreaRegistry, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, + agent_id: str, +) -> None: + """Test that the chat function is called with the appropriate arguments.""" + + if agent_id is None: + agent_id = mock_config_entry.entry_id + + # Create some areas, devices, and entities + area_kitchen = area_registry.async_get_or_create("kitchen_id") + area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen") + area_bedroom = area_registry.async_get_or_create("bedroom_id") + area_bedroom = area_registry.async_update(area_bedroom.id, name="bedroom") + area_office = area_registry.async_get_or_create("office_id") + area_office = area_registry.async_update(area_office.id, name="office") + + entry = MockConfigEntry() + entry.add_to_hass(hass) + kitchen_device = device_registry.async_get_or_create( + config_entry_id=entry.entry_id, + connections=set(), + identifiers={("demo", "id-1234")}, + ) + device_registry.async_update_device(kitchen_device.id, area_id=area_kitchen.id) + + kitchen_light = entity_registry.async_get_or_create("light", "demo", "1234") + kitchen_light = entity_registry.async_update_entity( + kitchen_light.entity_id, device_id=kitchen_device.id + ) + hass.states.async_set( + kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"} + ) + + bedroom_light = entity_registry.async_get_or_create("light", "demo", "5678") + bedroom_light = entity_registry.async_update_entity( + bedroom_light.entity_id, area_id=area_bedroom.id + ) + hass.states.async_set( + bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"} + ) + + # Hide the office light + office_light = entity_registry.async_get_or_create("light", "demo", "ABCD") + office_light = entity_registry.async_update_entity( + office_light.entity_id, area_id=area_office.id + ) + hass.states.async_set( + office_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "office light"} + ) + 
async_expose_entity(hass, conversation.DOMAIN, office_light.entity_id, False) + + with patch( + "ollama.AsyncClient.chat", + return_value={"message": {"role": "assistant", "content": "test response"}}, + ) as mock_chat: + result = await conversation.async_converse( + hass, + "test message", + None, + Context(), + agent_id=agent_id, + ) + + assert mock_chat.call_count == 1 + args = mock_chat.call_args.kwargs + prompt = args["messages"][0]["content"] + + assert args["model"] == "test model" + assert args["messages"] == [ + Message({"role": "system", "content": prompt}), + Message({"role": "user", "content": "test message"}), + ] + + # Verify only exposed devices/areas are in prompt + assert "kitchen light" in prompt + assert "bedroom light" in prompt + assert "office light" not in prompt + assert "office" not in prompt + + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), result + assert result.response.speech["plain"]["speech"] == "test response" + + +async def test_message_history_trimming( + hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component +) -> None: + """Test that a single message history is trimmed according to the config.""" + response_idx = 0 + + def response(*args, **kwargs) -> dict: + nonlocal response_idx + response_idx += 1 + return {"message": {"role": "assistant", "content": f"response {response_idx}"}} + + with patch( + "ollama.AsyncClient.chat", + side_effect=response, + ) as mock_chat: + # mock_init_component sets "max_history" to 2 + for i in range(5): + result = await conversation.async_converse( + hass, + f"message {i+1}", + conversation_id="1234", + context=Context(), + agent_id=mock_config_entry.entry_id, + ) + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), result + + assert mock_chat.call_count == 5 + args = mock_chat.call_args_list + prompt = args[0].kwargs["messages"][0]["content"] + + # system + user-1 + assert 
len(args[0].kwargs["messages"]) == 2 + assert args[0].kwargs["messages"][1]["content"] == "message 1" + + # Full history + # system + user-1 + assistant-1 + user-2 + assert len(args[1].kwargs["messages"]) == 4 + assert args[1].kwargs["messages"][0]["role"] == "system" + assert args[1].kwargs["messages"][0]["content"] == prompt + assert args[1].kwargs["messages"][1]["role"] == "user" + assert args[1].kwargs["messages"][1]["content"] == "message 1" + assert args[1].kwargs["messages"][2]["role"] == "assistant" + assert args[1].kwargs["messages"][2]["content"] == "response 1" + assert args[1].kwargs["messages"][3]["role"] == "user" + assert args[1].kwargs["messages"][3]["content"] == "message 2" + + # Full history + # system + user-1 + assistant-1 + user-2 + assistant-2 + user-3 + assert len(args[2].kwargs["messages"]) == 6 + assert args[2].kwargs["messages"][0]["role"] == "system" + assert args[2].kwargs["messages"][0]["content"] == prompt + assert args[2].kwargs["messages"][1]["role"] == "user" + assert args[2].kwargs["messages"][1]["content"] == "message 1" + assert args[2].kwargs["messages"][2]["role"] == "assistant" + assert args[2].kwargs["messages"][2]["content"] == "response 1" + assert args[2].kwargs["messages"][3]["role"] == "user" + assert args[2].kwargs["messages"][3]["content"] == "message 2" + assert args[2].kwargs["messages"][4]["role"] == "assistant" + assert args[2].kwargs["messages"][4]["content"] == "response 2" + assert args[2].kwargs["messages"][5]["role"] == "user" + assert args[2].kwargs["messages"][5]["content"] == "message 3" + + # Trimmed down to two user messages. 
+ # system + user-2 + assistant-2 + user-3 + assistant-3 + user-4 + assert len(args[3].kwargs["messages"]) == 6 + assert args[3].kwargs["messages"][0]["role"] == "system" + assert args[3].kwargs["messages"][0]["content"] == prompt + assert args[3].kwargs["messages"][1]["role"] == "user" + assert args[3].kwargs["messages"][1]["content"] == "message 2" + assert args[3].kwargs["messages"][2]["role"] == "assistant" + assert args[3].kwargs["messages"][2]["content"] == "response 2" + assert args[3].kwargs["messages"][3]["role"] == "user" + assert args[3].kwargs["messages"][3]["content"] == "message 3" + assert args[3].kwargs["messages"][4]["role"] == "assistant" + assert args[3].kwargs["messages"][4]["content"] == "response 3" + assert args[3].kwargs["messages"][5]["role"] == "user" + assert args[3].kwargs["messages"][5]["content"] == "message 4" + + # Trimmed down to two user messages. + # system + user-3 + assistant-3 + user-4 + assistant-4 + user-5 + assert len(args[4].kwargs["messages"]) == 6 + assert args[4].kwargs["messages"][0]["role"] == "system" + assert args[4].kwargs["messages"][0]["content"] == prompt + assert args[4].kwargs["messages"][1]["role"] == "user" + assert args[4].kwargs["messages"][1]["content"] == "message 3" + assert args[4].kwargs["messages"][2]["role"] == "assistant" + assert args[4].kwargs["messages"][2]["content"] == "response 3" + assert args[4].kwargs["messages"][3]["role"] == "user" + assert args[4].kwargs["messages"][3]["content"] == "message 4" + assert args[4].kwargs["messages"][4]["role"] == "assistant" + assert args[4].kwargs["messages"][4]["content"] == "response 4" + assert args[4].kwargs["messages"][5]["role"] == "user" + assert args[4].kwargs["messages"][5]["content"] == "message 5" + + +async def test_message_history_pruning( + hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component +) -> None: + """Test that old message histories are pruned.""" + with patch( + "ollama.AsyncClient.chat", + 
return_value={"message": {"role": "assistant", "content": "test response"}}, + ): + # Create 3 different message histories + conversation_ids: list[str] = [] + for i in range(3): + result = await conversation.async_converse( + hass, + f"message {i+1}", + conversation_id=None, + context=Context(), + agent_id=mock_config_entry.entry_id, + ) + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), result + assert isinstance(result.conversation_id, str) + conversation_ids.append(result.conversation_id) + + agent = conversation.get_agent_manager(hass).async_get_agent( + mock_config_entry.entry_id + ) + assert len(agent._history) == 3 + assert agent._history.keys() == set(conversation_ids) + + # Modify the timestamps of the first 2 histories so they will be pruned + # on the next cycle. + for conversation_id in conversation_ids[:2]: + # Move back 2 hours + agent._history[conversation_id].timestamp -= 2 * 60 * 60 + + # Next cycle + result = await conversation.async_converse( + hass, + "test message", + conversation_id=None, + context=Context(), + agent_id=mock_config_entry.entry_id, + ) + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), result + + # Only the most recent histories should remain + assert len(agent._history) == 2 + assert conversation_ids[-1] in agent._history + assert result.conversation_id in agent._history + + +async def test_message_history_unlimited( + hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component +) -> None: + """Test that message history is not trimmed when max_history = 0.""" + conversation_id = "1234" + with ( + patch( + "ollama.AsyncClient.chat", + return_value={"message": {"role": "assistant", "content": "test response"}}, + ), + patch.object(mock_config_entry, "options", {ollama.CONF_MAX_HISTORY: 0}), + ): + for i in range(100): + result = await conversation.async_converse( + hass, + f"message {i+1}", + conversation_id=conversation_id, + 
context=Context(), + agent_id=mock_config_entry.entry_id, + ) + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), result + + agent = conversation.get_agent_manager(hass).async_get_agent( + mock_config_entry.entry_id + ) + + assert len(agent._history) == 1 + assert conversation_id in agent._history + assert agent._history[conversation_id].num_user_messages == 100 + + +async def test_error_handling( + hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component +) -> None: + """Test error handling during converse.""" + with patch( + "ollama.AsyncClient.chat", + new_callable=AsyncMock, + side_effect=ResponseError("test error"), + ): + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR, result + assert result.response.error_code == "unknown", result + + +async def test_template_error( + hass: HomeAssistant, mock_config_entry: MockConfigEntry +) -> None: + """Test that template error handling works.""" + hass.config_entries.async_update_entry( + mock_config_entry, + options={ + "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", + }, + ) + with patch( + "ollama.AsyncClient.list", + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR, result + assert result.response.error_code == "unknown", result + + +async def test_conversation_agent( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, +) -> None: + """Test OllamaConversationEntity.""" + agent = conversation.get_agent_manager(hass).async_get_agent( + mock_config_entry.entry_id + ) + assert agent.supported_languages == MATCH_ALL 
diff --git a/tests/components/ollama/test_init.py b/tests/components/ollama/test_init.py index 5326a8ed609..c296d6de700 100644 --- a/tests/components/ollama/test_init.py +++ b/tests/components/ollama/test_init.py @@ -1,351 +1,17 @@ """Tests for the Ollama integration.""" -from unittest.mock import AsyncMock, patch +from unittest.mock import patch from httpx import ConnectError -from ollama import Message, ResponseError import pytest -from homeassistant.components import conversation, ollama -from homeassistant.components.homeassistant.exposed_entities import async_expose_entity -from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL -from homeassistant.core import Context, HomeAssistant -from homeassistant.helpers import ( - area_registry as ar, - device_registry as dr, - entity_registry as er, - intent, -) +from homeassistant.components import ollama +from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry -async def test_chat( - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, - mock_init_component, - area_registry: ar.AreaRegistry, - device_registry: dr.DeviceRegistry, - entity_registry: er.EntityRegistry, -) -> None: - """Test that the chat function is called with the appropriate arguments.""" - - # Create some areas, devices, and entities - area_kitchen = area_registry.async_get_or_create("kitchen_id") - area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen") - area_bedroom = area_registry.async_get_or_create("bedroom_id") - area_bedroom = area_registry.async_update(area_bedroom.id, name="bedroom") - area_office = area_registry.async_get_or_create("office_id") - area_office = area_registry.async_update(area_office.id, name="office") - - entry = MockConfigEntry() - entry.add_to_hass(hass) - kitchen_device = device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections=set(), - identifiers={("demo", "id-1234")}, - ) - 
device_registry.async_update_device(kitchen_device.id, area_id=area_kitchen.id) - - kitchen_light = entity_registry.async_get_or_create("light", "demo", "1234") - kitchen_light = entity_registry.async_update_entity( - kitchen_light.entity_id, device_id=kitchen_device.id - ) - hass.states.async_set( - kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"} - ) - - bedroom_light = entity_registry.async_get_or_create("light", "demo", "5678") - bedroom_light = entity_registry.async_update_entity( - bedroom_light.entity_id, area_id=area_bedroom.id - ) - hass.states.async_set( - bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"} - ) - - # Hide the office light - office_light = entity_registry.async_get_or_create("light", "demo", "ABCD") - office_light = entity_registry.async_update_entity( - office_light.entity_id, area_id=area_office.id - ) - hass.states.async_set( - office_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "office light"} - ) - async_expose_entity(hass, conversation.DOMAIN, office_light.entity_id, False) - - with patch( - "ollama.AsyncClient.chat", - return_value={"message": {"role": "assistant", "content": "test response"}}, - ) as mock_chat: - result = await conversation.async_converse( - hass, - "test message", - None, - Context(), - agent_id=mock_config_entry.entry_id, - ) - - assert mock_chat.call_count == 1 - args = mock_chat.call_args.kwargs - prompt = args["messages"][0]["content"] - - assert args["model"] == "test model" - assert args["messages"] == [ - Message({"role": "system", "content": prompt}), - Message({"role": "user", "content": "test message"}), - ] - - # Verify only exposed devices/areas are in prompt - assert "kitchen light" in prompt - assert "bedroom light" in prompt - assert "office light" not in prompt - assert "office" not in prompt - - assert ( - result.response.response_type == intent.IntentResponseType.ACTION_DONE - ), result - assert 
result.response.speech["plain"]["speech"] == "test response" - - -async def test_message_history_trimming( - hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component -) -> None: - """Test that a single message history is trimmed according to the config.""" - response_idx = 0 - - def response(*args, **kwargs) -> dict: - nonlocal response_idx - response_idx += 1 - return {"message": {"role": "assistant", "content": f"response {response_idx}"}} - - with patch( - "ollama.AsyncClient.chat", - side_effect=response, - ) as mock_chat: - # mock_init_component sets "max_history" to 2 - for i in range(5): - result = await conversation.async_converse( - hass, - f"message {i+1}", - conversation_id="1234", - context=Context(), - agent_id=mock_config_entry.entry_id, - ) - assert ( - result.response.response_type == intent.IntentResponseType.ACTION_DONE - ), result - - assert mock_chat.call_count == 5 - args = mock_chat.call_args_list - prompt = args[0].kwargs["messages"][0]["content"] - - # system + user-1 - assert len(args[0].kwargs["messages"]) == 2 - assert args[0].kwargs["messages"][1]["content"] == "message 1" - - # Full history - # system + user-1 + assistant-1 + user-2 - assert len(args[1].kwargs["messages"]) == 4 - assert args[1].kwargs["messages"][0]["role"] == "system" - assert args[1].kwargs["messages"][0]["content"] == prompt - assert args[1].kwargs["messages"][1]["role"] == "user" - assert args[1].kwargs["messages"][1]["content"] == "message 1" - assert args[1].kwargs["messages"][2]["role"] == "assistant" - assert args[1].kwargs["messages"][2]["content"] == "response 1" - assert args[1].kwargs["messages"][3]["role"] == "user" - assert args[1].kwargs["messages"][3]["content"] == "message 2" - - # Full history - # system + user-1 + assistant-1 + user-2 + assistant-2 + user-3 - assert len(args[2].kwargs["messages"]) == 6 - assert args[2].kwargs["messages"][0]["role"] == "system" - assert args[2].kwargs["messages"][0]["content"] == prompt - assert 
args[2].kwargs["messages"][1]["role"] == "user" - assert args[2].kwargs["messages"][1]["content"] == "message 1" - assert args[2].kwargs["messages"][2]["role"] == "assistant" - assert args[2].kwargs["messages"][2]["content"] == "response 1" - assert args[2].kwargs["messages"][3]["role"] == "user" - assert args[2].kwargs["messages"][3]["content"] == "message 2" - assert args[2].kwargs["messages"][4]["role"] == "assistant" - assert args[2].kwargs["messages"][4]["content"] == "response 2" - assert args[2].kwargs["messages"][5]["role"] == "user" - assert args[2].kwargs["messages"][5]["content"] == "message 3" - - # Trimmed down to two user messages. - # system + user-2 + assistant-2 + user-3 + assistant-3 + user-4 - assert len(args[3].kwargs["messages"]) == 6 - assert args[3].kwargs["messages"][0]["role"] == "system" - assert args[3].kwargs["messages"][0]["content"] == prompt - assert args[3].kwargs["messages"][1]["role"] == "user" - assert args[3].kwargs["messages"][1]["content"] == "message 2" - assert args[3].kwargs["messages"][2]["role"] == "assistant" - assert args[3].kwargs["messages"][2]["content"] == "response 2" - assert args[3].kwargs["messages"][3]["role"] == "user" - assert args[3].kwargs["messages"][3]["content"] == "message 3" - assert args[3].kwargs["messages"][4]["role"] == "assistant" - assert args[3].kwargs["messages"][4]["content"] == "response 3" - assert args[3].kwargs["messages"][5]["role"] == "user" - assert args[3].kwargs["messages"][5]["content"] == "message 4" - - # Trimmed down to two user messages. 
- # system + user-3 + assistant-3 + user-4 + assistant-4 + user-5 - assert len(args[3].kwargs["messages"]) == 6 - assert args[4].kwargs["messages"][0]["role"] == "system" - assert args[4].kwargs["messages"][0]["content"] == prompt - assert args[4].kwargs["messages"][1]["role"] == "user" - assert args[4].kwargs["messages"][1]["content"] == "message 3" - assert args[4].kwargs["messages"][2]["role"] == "assistant" - assert args[4].kwargs["messages"][2]["content"] == "response 3" - assert args[4].kwargs["messages"][3]["role"] == "user" - assert args[4].kwargs["messages"][3]["content"] == "message 4" - assert args[4].kwargs["messages"][4]["role"] == "assistant" - assert args[4].kwargs["messages"][4]["content"] == "response 4" - assert args[4].kwargs["messages"][5]["role"] == "user" - assert args[4].kwargs["messages"][5]["content"] == "message 5" - - -async def test_message_history_pruning( - hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component -) -> None: - """Test that old message histories are pruned.""" - with patch( - "ollama.AsyncClient.chat", - return_value={"message": {"role": "assistant", "content": "test response"}}, - ): - # Create 3 different message histories - conversation_ids: list[str] = [] - for i in range(3): - result = await conversation.async_converse( - hass, - f"message {i+1}", - conversation_id=None, - context=Context(), - agent_id=mock_config_entry.entry_id, - ) - assert ( - result.response.response_type == intent.IntentResponseType.ACTION_DONE - ), result - assert isinstance(result.conversation_id, str) - conversation_ids.append(result.conversation_id) - - agent = conversation.get_agent_manager(hass).async_get_agent( - mock_config_entry.entry_id - ) - assert isinstance(agent, ollama.OllamaAgent) - assert len(agent._history) == 3 - assert agent._history.keys() == set(conversation_ids) - - # Modify the timestamps of the first 2 histories so they will be pruned - # on the next cycle. 
- for conversation_id in conversation_ids[:2]: - # Move back 2 hours - agent._history[conversation_id].timestamp -= 2 * 60 * 60 - - # Next cycle - result = await conversation.async_converse( - hass, - "test message", - conversation_id=None, - context=Context(), - agent_id=mock_config_entry.entry_id, - ) - assert ( - result.response.response_type == intent.IntentResponseType.ACTION_DONE - ), result - - # Only the most recent histories should remain - assert len(agent._history) == 2 - assert conversation_ids[-1] in agent._history - assert result.conversation_id in agent._history - - -async def test_message_history_unlimited( - hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component -) -> None: - """Test that message history is not trimmed when max_history = 0.""" - conversation_id = "1234" - with ( - patch( - "ollama.AsyncClient.chat", - return_value={"message": {"role": "assistant", "content": "test response"}}, - ), - patch.object(mock_config_entry, "options", {ollama.CONF_MAX_HISTORY: 0}), - ): - for i in range(100): - result = await conversation.async_converse( - hass, - f"message {i+1}", - conversation_id=conversation_id, - context=Context(), - agent_id=mock_config_entry.entry_id, - ) - assert ( - result.response.response_type == intent.IntentResponseType.ACTION_DONE - ), result - - agent = conversation.get_agent_manager(hass).async_get_agent( - mock_config_entry.entry_id - ) - assert isinstance(agent, ollama.OllamaAgent) - - assert len(agent._history) == 1 - assert conversation_id in agent._history - assert agent._history[conversation_id].num_user_messages == 100 - - -async def test_error_handling( - hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component -) -> None: - """Test error handling during converse.""" - with patch( - "ollama.AsyncClient.chat", - new_callable=AsyncMock, - side_effect=ResponseError("test error"), - ): - result = await conversation.async_converse( - hass, "hello", None, Context(), 
agent_id=mock_config_entry.entry_id - ) - - assert result.response.response_type == intent.IntentResponseType.ERROR, result - assert result.response.error_code == "unknown", result - - -async def test_template_error( - hass: HomeAssistant, mock_config_entry: MockConfigEntry -) -> None: - """Test that template error handling works.""" - hass.config_entries.async_update_entry( - mock_config_entry, - options={ - "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", - }, - ) - with patch( - "ollama.AsyncClient.list", - ): - await hass.config_entries.async_setup(mock_config_entry.entry_id) - await hass.async_block_till_done() - result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id - ) - - assert result.response.response_type == intent.IntentResponseType.ERROR, result - assert result.response.error_code == "unknown", result - - -async def test_conversation_agent( - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, - mock_init_component, -) -> None: - """Test OllamaAgent.""" - agent = conversation.get_agent_manager(hass).async_get_agent( - mock_config_entry.entry_id - ) - assert agent.supported_languages == MATCH_ALL - - @pytest.mark.parametrize( ("side_effect", "error"), [