"""The conversation platform for the Ollama integration.""" from __future__ import annotations from collections.abc import Callable import json import logging import time from typing import Any, Literal import ollama import voluptuous as vol from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation from homeassistant.components.conversation import trace from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError, TemplateError from homeassistant.helpers import intent, llm, template from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.util import ulid from .const import ( CONF_KEEP_ALIVE, CONF_MAX_HISTORY, CONF_MODEL, CONF_NUM_CTX, CONF_PROMPT, DEFAULT_KEEP_ALIVE, DEFAULT_MAX_HISTORY, DEFAULT_NUM_CTX, DOMAIN, MAX_HISTORY_SECONDS, ) from .models import MessageHistory, MessageRole # Max number of back and forth with the LLM to generate a response MAX_TOOL_ITERATIONS = 10 _LOGGER = logging.getLogger(__name__) async def async_setup_entry( hass: HomeAssistant, config_entry: ConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: """Set up conversation entities.""" agent = OllamaConversationEntity(config_entry) async_add_entities([agent]) def _format_tool( tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None ) -> dict[str, Any]: """Format tool specification.""" tool_spec = { "name": tool.name, "parameters": convert(tool.parameters, custom_serializer=custom_serializer), } if tool.description: tool_spec["description"] = tool.description return {"type": "function", "function": tool_spec} def _fix_invalid_arguments(value: Any) -> Any: """Attempt to repair incorrectly formatted json function arguments. Small models (for example llama3.1 8B) may produce invalid argument values which we attempt to repair here. """ if not isinstance(value, str): return value if (value.startswith("[") and value.endswith("]")) or ( value.startswith("{") and value.endswith("}") ): try: return json.loads(value) except json.decoder.JSONDecodeError: pass return value def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]: """Rewrite ollama tool arguments. This function improves tool use quality by fixing common mistakes made by small local tool use models. This will repair invalid json arguments and omit unnecessary arguments with empty values that will fail intent parsing. """ return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v} class OllamaConversationEntity( conversation.ConversationEntity, conversation.AbstractConversationAgent ): """Ollama conversation agent.""" _attr_has_entity_name = True def __init__(self, entry: ConfigEntry) -> None: """Initialize the agent.""" self.entry = entry # conversation id -> message history self._history: dict[str, MessageHistory] = {} self._attr_name = entry.title self._attr_unique_id = entry.entry_id if self.entry.options.get(CONF_LLM_HASS_API): self._attr_supported_features = ( conversation.ConversationEntityFeature.CONTROL ) async def async_added_to_hass(self) -> None: """When entity is added to Home Assistant.""" await super().async_added_to_hass() assist_pipeline.async_migrate_engine( self.hass, "conversation", self.entry.entry_id, self.entity_id ) conversation.async_set_agent(self.hass, self.entry, self) self.entry.async_on_unload( self.entry.add_update_listener(self._async_entry_update_listener) ) async def async_will_remove_from_hass(self) -> None: """When entity will be removed from Home Assistant.""" conversation.async_unset_agent(self.hass, self.entry) await super().async_will_remove_from_hass() @property def supported_languages(self) -> list[str] | Literal["*"]: """Return a list of supported languages.""" return MATCH_ALL async def async_process( self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" settings = {**self.entry.data, **self.entry.options} client = self.hass.data[DOMAIN][self.entry.entry_id] conversation_id = user_input.conversation_id or ulid.ulid_now() model = settings[CONF_MODEL] intent_response = intent.IntentResponse(language=user_input.language) llm_api: llm.APIInstance | None = None tools: list[dict[str, Any]] | None = None user_name: str | None = None llm_context = llm.LLMContext( platform=DOMAIN, context=user_input.context, user_prompt=user_input.text, language=user_input.language, assistant=conversation.DOMAIN, device_id=user_input.device_id, ) if settings.get(CONF_LLM_HASS_API): try: llm_api = await llm.async_get_api( self.hass, settings[CONF_LLM_HASS_API], llm_context, ) except HomeAssistantError as err: _LOGGER.error("Error getting LLM API: %s", err) intent_response.async_set_error( intent.IntentResponseErrorCode.UNKNOWN, f"Error preparing LLM API: {err}", ) return conversation.ConversationResult( response=intent_response, conversation_id=user_input.conversation_id ) tools = [ _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools ] if ( user_input.context and user_input.context.user_id and ( user := await self.hass.auth.async_get_user(user_input.context.user_id) ) ): user_name = user.name # Look up message history message_history: MessageHistory | None = None message_history = self._history.get(conversation_id) if message_history is None: # New history # # Render prompt and error out early if there's a problem try: prompt_parts = [ template.Template( llm.BASE_PROMPT + settings.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT), self.hass, ).async_render( { "ha_name": self.hass.config.location_name, "user_name": user_name, "llm_context": llm_context, }, parse_result=False, ) ] except TemplateError as err: _LOGGER.error("Error rendering prompt: %s", err) intent_response.async_set_error( intent.IntentResponseErrorCode.UNKNOWN, f"Sorry, I had a problem generating my prompt: {err}", ) return conversation.ConversationResult( response=intent_response, conversation_id=conversation_id ) if llm_api: prompt_parts.append(llm_api.api_prompt) prompt = "\n".join(prompt_parts) _LOGGER.debug("Prompt: %s", prompt) _LOGGER.debug("Tools: %s", tools) message_history = MessageHistory( timestamp=time.monotonic(), messages=[ ollama.Message(role=MessageRole.SYSTEM.value, content=prompt) ], ) self._history[conversation_id] = message_history else: # Bump timestamp so this conversation won't get cleaned up message_history.timestamp = time.monotonic() # Clean up old histories self._prune_old_histories() # Trim this message history to keep a maximum number of *user* messages max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY)) self._trim_history(message_history, max_messages) # Add new user message message_history.messages.append( ollama.Message(role=MessageRole.USER.value, content=user_input.text) ) trace.async_conversation_trace_append( trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": message_history.messages}, ) # Get response # To prevent infinite loops, we limit the number of iterations for _iteration in range(MAX_TOOL_ITERATIONS): try: response = await client.chat( model=model, # Make a copy of the messages because we mutate the list later messages=list(message_history.messages), tools=tools, stream=False, # keep_alive requires specifying unit. In this case, seconds keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s", options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)}, ) except (ollama.RequestError, ollama.ResponseError) as err: _LOGGER.error("Unexpected error talking to Ollama server: %s", err) intent_response.async_set_error( intent.IntentResponseErrorCode.UNKNOWN, f"Sorry, I had a problem talking to the Ollama server: {err}", ) return conversation.ConversationResult( response=intent_response, conversation_id=conversation_id ) response_message = response["message"] message_history.messages.append( ollama.Message( role=response_message["role"], content=response_message.get("content"), tool_calls=response_message.get("tool_calls"), ) ) tool_calls = response_message.get("tool_calls") if not tool_calls or not llm_api: break for tool_call in tool_calls: tool_input = llm.ToolInput( tool_name=tool_call["function"]["name"], tool_args=_parse_tool_args(tool_call["function"]["arguments"]), ) _LOGGER.debug( "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args ) try: tool_response = await llm_api.async_call_tool(tool_input) except (HomeAssistantError, vol.Invalid) as e: tool_response = {"error": type(e).__name__} if str(e): tool_response["error_text"] = str(e) _LOGGER.debug("Tool response: %s", tool_response) message_history.messages.append( ollama.Message( role=MessageRole.TOOL.value, content=json.dumps(tool_response), ) ) # Create intent response intent_response.async_set_speech(response_message["content"]) return conversation.ConversationResult( response=intent_response, conversation_id=conversation_id ) def _prune_old_histories(self) -> None: """Remove old message histories.""" now = time.monotonic() self._history = { conversation_id: message_history for conversation_id, message_history in self._history.items() if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS } def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None: """Trims excess messages from a single history.""" if max_messages < 1: # Keep all messages return if message_history.num_user_messages >= max_messages: # Trim history but keep system prompt (first message). # Every other message should be an assistant message, so keep 2x # message objects. num_keep = 2 * max_messages drop_index = len(message_history.messages) - num_keep message_history.messages = [ message_history.messages[0] ] + message_history.messages[drop_index:] async def _async_entry_update_listener( self, hass: HomeAssistant, entry: ConfigEntry ) -> None: """Handle options update.""" # Reload as we update device info + entity name + supported features await hass.config_entries.async_reload(entry.entry_id)