diff --git a/homeassistant/components/ollama/config_flow.py b/homeassistant/components/ollama/config_flow.py index 475d5339dea..bcdd6e06f48 100644 --- a/homeassistant/components/ollama/config_flow.py +++ b/homeassistant/components/ollama/config_flow.py @@ -18,7 +18,9 @@ from homeassistant.config_entries import ( ConfigFlowResult, OptionsFlow, ) -from homeassistant.const import CONF_URL +from homeassistant.const import CONF_LLM_HASS_API, CONF_URL +from homeassistant.core import HomeAssistant +from homeassistant.helpers import llm from homeassistant.helpers.selector import ( NumberSelector, NumberSelectorConfig, @@ -40,7 +42,6 @@ from .const import ( DEFAULT_KEEP_ALIVE, DEFAULT_MAX_HISTORY, DEFAULT_MODEL, - DEFAULT_PROMPT, DEFAULT_TIMEOUT, DOMAIN, MODEL_NAMES, @@ -208,25 +209,52 @@ class OllamaOptionsFlow(OptionsFlow): ) -> ConfigFlowResult: """Manage the options.""" if user_input is not None: + if user_input[CONF_LLM_HASS_API] == "none": + user_input.pop(CONF_LLM_HASS_API) return self.async_create_entry( title=_get_title(self.model), data=user_input ) options = self.config_entry.options or MappingProxyType({}) - schema = ollama_config_option_schema(options) + schema = ollama_config_option_schema(self.hass, options) return self.async_show_form( step_id="init", data_schema=vol.Schema(schema), ) -def ollama_config_option_schema(options: MappingProxyType[str, Any]) -> dict: +def ollama_config_option_schema( + hass: HomeAssistant, options: MappingProxyType[str, Any] +) -> dict: """Ollama options schema.""" + hass_apis: list[SelectOptionDict] = [ + SelectOptionDict( + label="No control", + value="none", + ) + ] + hass_apis.extend( + SelectOptionDict( + label=api.name, + value=api.id, + ) + for api in llm.async_get_apis(hass) + ) + return { vol.Optional( CONF_PROMPT, - description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)}, + description={ + "suggested_value": options.get( + CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT + ) + }, ): TemplateSelector(), + vol.Optional( + CONF_LLM_HASS_API, + description={"suggested_value": options.get(CONF_LLM_HASS_API)}, + default="none", + ): SelectSelector(SelectSelectorConfig(options=hass_apis)), vol.Optional( CONF_MAX_HISTORY, description={ diff --git a/homeassistant/components/ollama/const.py b/homeassistant/components/ollama/const.py index b3bce3624c2..97c4f1186fc 100644 --- a/homeassistant/components/ollama/const.py +++ b/homeassistant/components/ollama/const.py @@ -4,73 +4,6 @@ DOMAIN = "ollama" CONF_MODEL = "model" CONF_PROMPT = "prompt" -DEFAULT_PROMPT = """{%- set used_domains = set([ - "binary_sensor", - "climate", - "cover", - "fan", - "light", - "lock", - "sensor", - "switch", - "weather", -]) %} -{%- set used_attributes = set([ - "temperature", - "current_temperature", - "temperature_unit", - "brightness", - "humidity", - "unit_of_measurement", - "device_class", - "current_position", - "percentage", -]) %} - -This smart home is controlled by Home Assistant. -The current time is {{ now().strftime("%X") }}. -Today's date is {{ now().strftime("%x") }}. - -An overview of the areas and the devices in this smart home: -```yaml -{%- for entity in exposed_entities: %} -{%- if entity.domain not in used_domains: %} - {%- continue %} -{%- endif %} - -- domain: {{ entity.domain }} -{%- if entity.names | length == 1: %} - name: {{ entity.names[0] }} -{%- else: %} - names: -{%- for name in entity.names: %} - - {{ name }} -{%- endfor %} -{%- endif %} -{%- if entity.area_names | length == 1: %} - area: {{ entity.area_names[0] }} -{%- elif entity.area_names: %} - areas: -{%- for area_name in entity.area_names: %} - - {{ area_name }} -{%- endfor %} -{%- endif %} - state: {{ entity.state.state }} - {%- set attributes_key_printed = False %} -{%- for attr_name, attr_value in entity.state.attributes.items(): %} - {%- if attr_name in used_attributes: %} - {%- if not attributes_key_printed: %} - attributes: - {%- set attributes_key_printed = True %} - {%- endif %} - {{ attr_name }}: {{ attr_value }} - {%- endif %} -{%- endfor %} -{%- endfor %} -``` - -Answer the user's questions using the information about this smart home. -Keep your answers brief and do not apologize.""" CONF_KEEP_ALIVE = "keep_alive" DEFAULT_KEEP_ALIVE = -1 # seconds. -1 = indefinite, 0 = never @@ -187,4 +120,4 @@ MODEL_NAMES = [ # https://ollama.com/library "yi", "zephyr", ] -DEFAULT_MODEL = "llama2:latest" +DEFAULT_MODEL = "llama3.1:latest" diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py index ccc7b9bdecc..ae0acef1077 100644 --- a/homeassistant/components/ollama/conversation.py +++ b/homeassistant/components/ollama/conversation.py @@ -2,26 +2,23 @@ from __future__ import annotations +from collections.abc import Callable +import json import logging import time -from typing import Literal +from typing import Any, Literal import ollama +import voluptuous as vol +from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation from homeassistant.components.conversation import trace -from homeassistant.components.homeassistant.exposed_entities import async_should_expose from homeassistant.config_entries import ConfigEntry -from homeassistant.const import MATCH_ALL +from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant -from homeassistant.exceptions import TemplateError -from homeassistant.helpers import ( - area_registry as ar, - device_registry as dr, - entity_registry as er, - intent, - template, -) +from homeassistant.exceptions import HomeAssistantError, TemplateError +from homeassistant.helpers import intent, llm, template from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.util import ulid @@ -32,11 +29,13 @@ from .const import ( CONF_PROMPT, DEFAULT_KEEP_ALIVE, DEFAULT_MAX_HISTORY, - DEFAULT_PROMPT, DOMAIN, MAX_HISTORY_SECONDS, ) -from .models import ExposedEntity, MessageHistory, MessageRole +from .models import MessageHistory, MessageRole + +# Max number of back and forth with the LLM to generate a response +MAX_TOOL_ITERATIONS = 10 _LOGGER = logging.getLogger(__name__) @@ -51,6 +50,19 @@ async def async_setup_entry( async_add_entities([agent]) +def _format_tool( + tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None +) -> dict[str, Any]: + """Format tool specification.""" + tool_spec = { + "name": tool.name, + "parameters": convert(tool.parameters, custom_serializer=custom_serializer), + } + if tool.description: + tool_spec["description"] = tool.description + return {"type": "function", "function": tool_spec} + + class OllamaConversationEntity( conversation.ConversationEntity, conversation.AbstractConversationAgent ): @@ -94,6 +106,47 @@ class OllamaConversationEntity( client = self.hass.data[DOMAIN][self.entry.entry_id] conversation_id = user_input.conversation_id or ulid.ulid_now() model = settings[CONF_MODEL] + intent_response = intent.IntentResponse(language=user_input.language) + llm_api: llm.APIInstance | None = None + tools: list[dict[str, Any]] | None = None + user_name: str | None = None + llm_context = llm.LLMContext( + platform=DOMAIN, + context=user_input.context, + user_prompt=user_input.text, + language=user_input.language, + assistant=conversation.DOMAIN, + device_id=user_input.device_id, + ) + + if settings.get(CONF_LLM_HASS_API): + try: + llm_api = await llm.async_get_api( + self.hass, + settings[CONF_LLM_HASS_API], + llm_context, + ) + except HomeAssistantError as err: + _LOGGER.error("Error getting LLM API: %s", err) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Error preparing LLM API: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=user_input.conversation_id + ) + tools = [ + _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools + ] + + if ( + user_input.context + and user_input.context.user_id + and ( + user := await self.hass.auth.async_get_user(user_input.context.user_id) + ) + ): + user_name = user.name # Look up message history message_history: MessageHistory | None = None @@ -102,13 +155,24 @@ class OllamaConversationEntity( # New history # # Render prompt and error out early if there's a problem - raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT) try: - prompt = self._generate_prompt(raw_prompt) - _LOGGER.debug("Prompt: %s", prompt) + prompt_parts = [ + template.Template( + llm.BASE_PROMPT + + settings.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT), + self.hass, + ).async_render( + { + "ha_name": self.hass.config.location_name, + "user_name": user_name, + "llm_context": llm_context, + }, + parse_result=False, + ) + ] + except TemplateError as err: _LOGGER.error("Error rendering prompt: %s", err) - intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_error( intent.IntentResponseErrorCode.UNKNOWN, f"Sorry, I had a problem generating my prompt: {err}", @@ -117,6 +181,13 @@ class OllamaConversationEntity( response=intent_response, conversation_id=conversation_id ) + if llm_api: + prompt_parts.append(llm_api.api_prompt) + + prompt = "\n".join(prompt_parts) + _LOGGER.debug("Prompt: %s", prompt) + _LOGGER.debug("Tools: %s", tools) + message_history = MessageHistory( timestamp=time.monotonic(), messages=[ @@ -146,35 +217,66 @@ class OllamaConversationEntity( ) # Get response - try: - response = await client.chat( - model=model, - # Make a copy of the messages because we mutate the list later - messages=list(message_history.messages), - stream=False, - # keep_alive requires specifying unit. In this case, seconds - keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s", - ) - except (ollama.RequestError, ollama.ResponseError) as err: - _LOGGER.error("Unexpected error talking to Ollama server: %s", err) - intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - f"Sorry, I had a problem talking to the Ollama server: {err}", - ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id + # To prevent infinite loops, we limit the number of iterations + for _iteration in range(MAX_TOOL_ITERATIONS): + try: + response = await client.chat( + model=model, + # Make a copy of the messages because we mutate the list later + messages=list(message_history.messages), + tools=tools, + stream=False, + # keep_alive requires specifying unit. In this case, seconds + keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s", + ) + except (ollama.RequestError, ollama.ResponseError) as err: + _LOGGER.error("Unexpected error talking to Ollama server: %s", err) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Sorry, I had a problem talking to the Ollama server: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + response_message = response["message"] + message_history.messages.append( + ollama.Message( + role=response_message["role"], + content=response_message.get("content"), + tool_calls=response_message.get("tool_calls"), + ) ) - response_message = response["message"] - message_history.messages.append( - ollama.Message( - role=response_message["role"], content=response_message["content"] - ) - ) + tool_calls = response_message.get("tool_calls") + if not tool_calls or not llm_api: + break + + for tool_call in tool_calls: + tool_input = llm.ToolInput( + tool_name=tool_call["function"]["name"], + tool_args=tool_call["function"]["arguments"], + ) + _LOGGER.debug( + "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args + ) + + try: + tool_response = await llm_api.async_call_tool(tool_input) + except (HomeAssistantError, vol.Invalid) as e: + tool_response = {"error": type(e).__name__} + if str(e): + tool_response["error_text"] = str(e) + + _LOGGER.debug("Tool response: %s", tool_response) + message_history.messages.append( + ollama.Message( + role=MessageRole.TOOL.value, # type: ignore[typeddict-item] + content=json.dumps(tool_response), + ) + ) # Create intent response - intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_speech(response_message["content"]) return conversation.ConversationResult( response=intent_response, conversation_id=conversation_id @@ -204,62 +306,3 @@ class OllamaConversationEntity( message_history.messages = [ message_history.messages[0] ] + message_history.messages[drop_index:] - - def _generate_prompt(self, raw_prompt: str) -> str: - """Generate a prompt for the user.""" - return template.Template(raw_prompt, self.hass).async_render( - { - "ha_name": self.hass.config.location_name, - "ha_language": self.hass.config.language, - "exposed_entities": self._get_exposed_entities(), - }, - parse_result=False, - ) - - def _get_exposed_entities(self) -> list[ExposedEntity]: - """Get state list of exposed entities.""" - area_registry = ar.async_get(self.hass) - entity_registry = er.async_get(self.hass) - device_registry = dr.async_get(self.hass) - - exposed_entities = [] - exposed_states = [ - state - for state in self.hass.states.async_all() - if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id) - ] - - for state in exposed_states: - entity_entry = entity_registry.async_get(state.entity_id) - names = [state.name] - area_names = [] - - if entity_entry is not None: - # Add aliases - names.extend(entity_entry.aliases) - if entity_entry.area_id and ( - area := area_registry.async_get_area(entity_entry.area_id) - ): - # Entity is in area - area_names.append(area.name) - area_names.extend(area.aliases) - elif entity_entry.device_id and ( - device := device_registry.async_get(entity_entry.device_id) - ): - # Check device area - if device.area_id and ( - area := area_registry.async_get_area(device.area_id) - ): - area_names.append(area.name) - area_names.extend(area.aliases) - - exposed_entities.append( - ExposedEntity( - entity_id=state.entity_id, - state=state, - names=names, - area_names=area_names, - ) - ) - - return exposed_entities diff --git a/homeassistant/components/ollama/manifest.json b/homeassistant/components/ollama/manifest.json index f7265d87aab..4d4321b8e3d 100644 --- a/homeassistant/components/ollama/manifest.json +++ b/homeassistant/components/ollama/manifest.json @@ -1,7 +1,7 @@ { "domain": "ollama", "name": "Ollama", - "after_dependencies": ["assist_pipeline"], + "after_dependencies": ["assist_pipeline", "intent"], "codeowners": ["@synesthesiam"], "config_flow": true, "dependencies": ["conversation"], diff --git a/homeassistant/components/ollama/models.py b/homeassistant/components/ollama/models.py index 56cc552fad1..3b6fc958587 100644 --- a/homeassistant/components/ollama/models.py +++ b/homeassistant/components/ollama/models.py @@ -2,18 +2,17 @@ from dataclasses import dataclass from enum import StrEnum -from functools import cached_property import ollama -from homeassistant.core import State - class MessageRole(StrEnum): """Role of a chat message.""" SYSTEM = "system" # prompt USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" @dataclass @@ -30,18 +29,3 @@ class MessageHistory: def num_user_messages(self) -> int: """Return a count of user messages.""" return sum(m["role"] == MessageRole.USER.value for m in self.messages) - - -@dataclass(frozen=True) -class ExposedEntity: - """Relevant information about an exposed entity.""" - - entity_id: str - state: State - names: list[str] - area_names: list[str] - - @cached_property - def domain(self) -> str: - """Get domain from entity id.""" - return self.entity_id.split(".", maxsplit=1)[0] diff --git a/homeassistant/components/ollama/strings.json b/homeassistant/components/ollama/strings.json index cc0f05d3068..2366ecd0848 100644 --- a/homeassistant/components/ollama/strings.json +++ b/homeassistant/components/ollama/strings.json @@ -24,11 +24,13 @@ "step": { "init": { "data": { - "prompt": "Prompt template", + "prompt": "Instructions", + "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]", "max_history": "Max history messages", "keep_alive": "Keep alive" }, "data_description": { + "prompt": "Instruct how the LLM should respond. This can be a template.", "keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never." } } diff --git a/tests/components/ollama/__init__.py b/tests/components/ollama/__init__.py index 22a576e94a4..6ad77bb2217 100644 --- a/tests/components/ollama/__init__.py +++ b/tests/components/ollama/__init__.py @@ -1,7 +1,7 @@ """Tests for the Ollama integration.""" from homeassistant.components import ollama -from homeassistant.components.ollama.const import DEFAULT_PROMPT +from homeassistant.helpers import llm TEST_USER_DATA = { ollama.CONF_URL: "http://localhost:11434", @@ -9,6 +9,6 @@ TEST_USER_DATA = { } TEST_OPTIONS = { - ollama.CONF_PROMPT: DEFAULT_PROMPT, + ollama.CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT, ollama.CONF_MAX_HISTORY: 2, } diff --git a/tests/components/ollama/conftest.py b/tests/components/ollama/conftest.py index db1689bd416..0355a13eba7 100644 --- a/tests/components/ollama/conftest.py +++ b/tests/components/ollama/conftest.py @@ -5,7 +5,9 @@ from unittest.mock import patch import pytest from homeassistant.components import ollama +from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import HomeAssistant +from homeassistant.helpers import llm from homeassistant.setup import async_setup_component from . import TEST_OPTIONS, TEST_USER_DATA @@ -25,6 +27,17 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry: return entry +@pytest.fixture +def mock_config_entry_with_assist( + hass: HomeAssistant, mock_config_entry: MockConfigEntry +) -> MockConfigEntry: + """Mock a config entry with assist.""" + hass.config_entries.async_update_entry( + mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST} + ) + return mock_config_entry + + @pytest.fixture async def mock_init_component(hass: HomeAssistant, mock_config_entry: MockConfigEntry): """Initialize integration.""" diff --git a/tests/components/ollama/snapshots/test_conversation.ambr b/tests/components/ollama/snapshots/test_conversation.ambr new file mode 100644 index 00000000000..e4dd7cd00bb --- /dev/null +++ b/tests/components/ollama/snapshots/test_conversation.ambr @@ -0,0 +1,34 @@ +# serializer version: 1 +# name: test_unknown_hass_api + dict({ + 'conversation_id': None, + 'response': IntentResponse( + card=dict({ + }), + error_code=, + failed_results=list([ + ]), + intent=None, + intent_targets=list([ + ]), + language='en', + matched_states=list([ + ]), + reprompt=dict({ + }), + response_type=, + speech=dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': 'Error preparing LLM API: API non-existing not found', + }), + }), + speech_slots=dict({ + }), + success_results=list([ + ]), + unmatched_states=list([ + ]), + ), + }) +# --- diff --git a/tests/components/ollama/test_conversation.py b/tests/components/ollama/test_conversation.py index b6f0be3c414..9be6f3b33a3 100644 --- a/tests/components/ollama/test_conversation.py +++ b/tests/components/ollama/test_conversation.py @@ -1,21 +1,18 @@ """Tests for the Ollama integration.""" -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch from ollama import Message, ResponseError import pytest +from syrupy.assertion import SnapshotAssertion +import voluptuous as vol from homeassistant.components import conversation, ollama from homeassistant.components.conversation import trace -from homeassistant.components.homeassistant.exposed_entities import async_expose_entity -from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL +from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import Context, HomeAssistant -from homeassistant.helpers import ( - area_registry as ar, - device_registry as dr, - entity_registry as er, - intent, -) +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import intent, llm from tests.common import MockConfigEntry @@ -25,9 +22,6 @@ async def test_chat( hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component, - area_registry: ar.AreaRegistry, - device_registry: dr.DeviceRegistry, - entity_registry: er.EntityRegistry, agent_id: str, ) -> None: """Test that the chat function is called with the appropriate arguments.""" @@ -35,48 +29,8 @@ async def test_chat( if agent_id is None: agent_id = mock_config_entry.entry_id - # Create some areas, devices, and entities - area_kitchen = area_registry.async_get_or_create("kitchen_id") - area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen") - area_bedroom = area_registry.async_get_or_create("bedroom_id") - area_bedroom = area_registry.async_update(area_bedroom.id, name="bedroom") - area_office = area_registry.async_get_or_create("office_id") - area_office = area_registry.async_update(area_office.id, name="office") - entry = MockConfigEntry() entry.add_to_hass(hass) - kitchen_device = device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections=set(), - identifiers={("demo", "id-1234")}, - ) - device_registry.async_update_device(kitchen_device.id, area_id=area_kitchen.id) - - kitchen_light = entity_registry.async_get_or_create("light", "demo", "1234") - kitchen_light = entity_registry.async_update_entity( - kitchen_light.entity_id, device_id=kitchen_device.id - ) - hass.states.async_set( - kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"} - ) - - bedroom_light = entity_registry.async_get_or_create("light", "demo", "5678") - bedroom_light = entity_registry.async_update_entity( - bedroom_light.entity_id, area_id=area_bedroom.id - ) - hass.states.async_set( - bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"} - ) - - # Hide the office light - office_light = entity_registry.async_get_or_create("light", "demo", "ABCD") - office_light = entity_registry.async_update_entity( - office_light.entity_id, area_id=area_office.id - ) - hass.states.async_set( - office_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "office light"} - ) - async_expose_entity(hass, conversation.DOMAIN, office_light.entity_id, False) with patch( "ollama.AsyncClient.chat", @@ -100,12 +54,6 @@ async def test_chat( Message({"role": "user", "content": "test message"}), ] - # Verify only exposed devices/areas are in prompt - assert "kitchen light" in prompt - assert "bedroom light" in prompt - assert "office light" not in prompt - assert "office" not in prompt - assert ( result.response.response_type == intent.IntentResponseType.ACTION_DONE ), result @@ -122,7 +70,232 @@ async def test_chat( ] # AGENT_DETAIL event contains the raw prompt passed to the model detail_event = trace_events[1] - assert "The current time is" in detail_event["data"]["messages"][0]["content"] + assert "Current time is" in detail_event["data"]["messages"][0]["content"] + + +async def test_template_variables( + hass: HomeAssistant, mock_config_entry: MockConfigEntry +) -> None: + """Test that template variables work.""" + context = Context(user_id="12345") + mock_user = Mock() + mock_user.id = "12345" + mock_user.name = "Test User" + + hass.config_entries.async_update_entry( + mock_config_entry, + options={ + "prompt": ( + "The user name is {{ user_name }}. " + "The user id is {{ llm_context.context.user_id }}." + ), + }, + ) + with ( + patch("ollama.AsyncClient.list"), + patch( + "ollama.AsyncClient.chat", + return_value={"message": {"role": "assistant", "content": "test response"}}, + ) as mock_chat, + patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user), + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + result = await conversation.async_converse( + hass, "hello", None, context, agent_id=mock_config_entry.entry_id + ) + + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), result + + args = mock_chat.call_args.kwargs + prompt = args["messages"][0]["content"] + + assert "The user name is Test User." in prompt + assert "The user id is 12345." in prompt + + +@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools") +async def test_function_call( + mock_get_tools, + hass: HomeAssistant, + mock_config_entry_with_assist: MockConfigEntry, + mock_init_component, +) -> None: + """Test function call from the assistant.""" + agent_id = mock_config_entry_with_assist.entry_id + context = Context() + + mock_tool = AsyncMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test function" + mock_tool.parameters = vol.Schema( + {vol.Optional("param1", description="Test parameters"): str} + ) + mock_tool.async_call.return_value = "Test response" + + mock_get_tools.return_value = [mock_tool] + + def completion_result(*args, messages, **kwargs): + for message in messages: + if message["role"] == "tool": + return { + "message": { + "role": "assistant", + "content": "I have successfully called the function", + } + } + + return { + "message": { + "role": "assistant", + "tool_calls": [ + { + "function": { + "name": "test_tool", + "arguments": {"param1": "test_value"}, + } + } + ], + } + } + + with patch( + "ollama.AsyncClient.chat", + side_effect=completion_result, + ) as mock_chat: + result = await conversation.async_converse( + hass, + "Please call the test function", + None, + context, + agent_id=agent_id, + ) + + assert mock_chat.call_count == 2 + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert ( + result.response.speech["plain"]["speech"] + == "I have successfully called the function" + ) + mock_tool.async_call.assert_awaited_once_with( + hass, + llm.ToolInput( + tool_name="test_tool", + tool_args={"param1": "test_value"}, + ), + llm.LLMContext( + platform="ollama", + context=context, + user_prompt="Please call the test function", + language="en", + assistant="conversation", + device_id=None, + ), + ) + + +@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools") +async def test_function_exception( + mock_get_tools, + hass: HomeAssistant, + mock_config_entry_with_assist: MockConfigEntry, + mock_init_component, +) -> None: + """Test function call with exception.""" + agent_id = mock_config_entry_with_assist.entry_id + context = Context() + + mock_tool = AsyncMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test function" + mock_tool.parameters = vol.Schema( + {vol.Optional("param1", description="Test parameters"): str} + ) + mock_tool.async_call.side_effect = HomeAssistantError("Test tool exception") + + mock_get_tools.return_value = [mock_tool] + + def completion_result(*args, messages, **kwargs): + for message in messages: + if message["role"] == "tool": + return { + "message": { + "role": "assistant", + "content": "There was an error calling the function", + } + } + + return { + "message": { + "role": "assistant", + "tool_calls": [ + { + "function": { + "name": "test_tool", + "arguments": {"param1": "test_value"}, + } + } + ], + } + } + + with patch( + "ollama.AsyncClient.chat", + side_effect=completion_result, + ) as mock_chat: + result = await conversation.async_converse( + hass, + "Please call the test function", + None, + context, + agent_id=agent_id, + ) + + assert mock_chat.call_count == 2 + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert ( + result.response.speech["plain"]["speech"] + == "There was an error calling the function" + ) + mock_tool.async_call.assert_awaited_once_with( + hass, + llm.ToolInput( + tool_name="test_tool", + tool_args={"param1": "test_value"}, + ), + llm.LLMContext( + platform="ollama", + context=context, + user_prompt="Please call the test function", + language="en", + assistant="conversation", + device_id=None, + ), + ) + + +async def test_unknown_hass_api( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + snapshot: SnapshotAssertion, + mock_init_component, +) -> None: + """Test when we reference an API that no longer exists.""" + hass.config_entries.async_update_entry( + mock_config_entry, + options={ + **mock_config_entry.options, + CONF_LLM_HASS_API: "non-existing", + }, + ) + + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) + + assert result == snapshot async def test_message_history_trimming(