Add ollama context window size configuration (#124555)

* Add ollama context window size configuration

* Set higher max context size
This commit is contained in:
Allen Porter 2024-08-25 12:22:57 -07:00 committed by GitHub
parent be206156b0
commit 3304e27fa3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 73 additions and 4 deletions

View file

@ -18,6 +18,7 @@ from .const import (
CONF_KEEP_ALIVE,
CONF_MAX_HISTORY,
CONF_MODEL,
CONF_NUM_CTX,
CONF_PROMPT,
DEFAULT_TIMEOUT,
DOMAIN,
@ -30,6 +31,7 @@ __all__ = [
"CONF_PROMPT",
"CONF_MODEL",
"CONF_MAX_HISTORY",
"CONF_NUM_CTX",
"CONF_KEEP_ALIVE",
"DOMAIN",
]

View file

@ -38,12 +38,16 @@ from .const import (
CONF_KEEP_ALIVE,
CONF_MAX_HISTORY,
CONF_MODEL,
CONF_NUM_CTX,
CONF_PROMPT,
DEFAULT_KEEP_ALIVE,
DEFAULT_MAX_HISTORY,
DEFAULT_MODEL,
DEFAULT_NUM_CTX,
DEFAULT_TIMEOUT,
DOMAIN,
MAX_NUM_CTX,
MIN_NUM_CTX,
MODEL_NAMES,
)
@ -255,6 +259,14 @@ def ollama_config_option_schema(
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none",
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
vol.Optional(
CONF_NUM_CTX,
description={"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
): NumberSelector(
NumberSelectorConfig(
min=MIN_NUM_CTX, max=MAX_NUM_CTX, step=1, mode=NumberSelectorMode.BOX
)
),
vol.Optional(
CONF_MAX_HISTORY,
description={

View file

@ -11,6 +11,11 @@ DEFAULT_KEEP_ALIVE = -1 # seconds. -1 = indefinite, 0 = never
# -1 keeps the model loaded in Ollama memory indefinitely.
KEEP_ALIVE_FOREVER = -1
DEFAULT_TIMEOUT = 5.0 # seconds
# Ollama "num_ctx" option: context window size, in tokens.
CONF_NUM_CTX = "num_ctx"
DEFAULT_NUM_CTX = 8192
# Bounds for the num_ctx number selector shown in the options flow.
MIN_NUM_CTX = 2048
MAX_NUM_CTX = 131072
CONF_MAX_HISTORY = "max_history"
DEFAULT_MAX_HISTORY = 20

View file

@ -26,9 +26,11 @@ from .const import (
CONF_KEEP_ALIVE,
CONF_MAX_HISTORY,
CONF_MODEL,
CONF_NUM_CTX,
CONF_PROMPT,
DEFAULT_KEEP_ALIVE,
DEFAULT_MAX_HISTORY,
DEFAULT_NUM_CTX,
DOMAIN,
MAX_HISTORY_SECONDS,
)
@ -263,6 +265,7 @@ class OllamaConversationEntity(
stream=False,
# keep_alive requires specifying unit. In this case, seconds
keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
)
except (ollama.RequestError, ollama.ResponseError) as err:
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)

View file

@ -27,11 +27,13 @@
"prompt": "Instructions",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
"max_history": "Max history messages",
"num_ctx": "Context window size",
"keep_alive": "Keep alive"
},
"data_description": {
"prompt": "Instruct how the LLM should respond. This can be a template.",
"keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never."
"keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never.",
"num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities."
}
}
}

View file

@ -1,5 +1,6 @@
"""Tests Ollama integration."""
from typing import Any
from unittest.mock import patch
import pytest
@ -16,12 +17,20 @@ from tests.common import MockConfigEntry
@pytest.fixture
# NOTE(review): this span is a unified diff without +/- markers — the next line is
# the OLD signature; the new fixture pair follows. Do not read both as one definition.
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
# New fixture: options payload, overridable per-test via indirect parametrization.
def mock_config_entry_options() -> dict[str, Any]:
"""Fixture for configuration entry options."""
return TEST_OPTIONS
@pytest.fixture
def mock_config_entry(
hass: HomeAssistant, mock_config_entry_options: dict[str, Any]
) -> MockConfigEntry:
"""Mock a config entry."""
entry = MockConfigEntry(
domain=ollama.DOMAIN,
data=TEST_USER_DATA,
# NOTE(review): diff interleaving again — old kwarg line first, new one second;
# the applied file keeps only options=mock_config_entry_options.
options=TEST_OPTIONS,
options=mock_config_entry_options,
)
entry.add_to_hass(hass)
return entry

View file

@ -164,13 +164,18 @@ async def test_options(
)
options = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
{ollama.CONF_PROMPT: "test prompt", ollama.CONF_MAX_HISTORY: 100},
{
ollama.CONF_PROMPT: "test prompt",
ollama.CONF_MAX_HISTORY: 100,
ollama.CONF_NUM_CTX: 32768,
},
)
await hass.async_block_till_done()
assert options["type"] is FlowResultType.CREATE_ENTRY
assert options["data"] == {
ollama.CONF_PROMPT: "test prompt",
ollama.CONF_MAX_HISTORY: 100,
ollama.CONF_NUM_CTX: 32768,
}

View file

@ -578,3 +578,34 @@ async def test_conversation_agent_with_assist(
state.attributes[ATTR_SUPPORTED_FEATURES]
== conversation.ConversationEntityFeature.CONTROL
)
@pytest.mark.parametrize(
    ("mock_config_entry_options", "expected_options"),
    [
        ({}, {"num_ctx": 8192}),
        ({"num_ctx": 16384}, {"num_ctx": 16384}),
    ],
)
async def test_options(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    expected_options: dict[str, Any],
) -> None:
    """Test that options are passed correctly to ollama client."""
    # Canned assistant reply so async_converse completes without a real server.
    canned_reply = {"message": {"role": "assistant", "content": "test response"}}
    with patch("ollama.AsyncClient.chat", return_value=canned_reply) as mock_chat:
        await conversation.async_converse(
            hass,
            "test message",
            None,
            Context(),
            agent_id="conversation.mock_title",
        )
    # Exactly one chat call, carrying the configured (or default) num_ctx option.
    assert mock_chat.call_count == 1
    assert mock_chat.call_args.kwargs.get("options") == expected_options