Add ollama context window size configuration (#124555)

* Add ollama context window size configuration

* Set higher max context size

Allen Porter, 2024-08-25 12:22:57 -07:00, committed by GitHub
parent be206156b0
commit 3304e27fa3
8 changed files with 73 additions and 4 deletions
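
In Ollama, num_ctx controls the context window: the maximum number of tokens, prompt plus generated reply, that the model attends to at once. The Ollama server's own default is 2048 tokens, which a prompt embedding many exposed Home Assistant entities can easily exhaust. This commit surfaces the option in the integration's options flow with a default of 8192 and an allowed range of 2048 to 131072; larger windows raise the server's memory use roughly in proportion.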

homeassistant/components/ollama/__init__.py

@@ -18,6 +18,7 @@ from .const import (
     CONF_KEEP_ALIVE,
     CONF_MAX_HISTORY,
     CONF_MODEL,
+    CONF_NUM_CTX,
     CONF_PROMPT,
     DEFAULT_TIMEOUT,
     DOMAIN,
@@ -30,6 +31,7 @@ __all__ = [
     "CONF_PROMPT",
     "CONF_MODEL",
     "CONF_MAX_HISTORY",
+    "CONF_NUM_CTX",
     "CONF_KEEP_ALIVE",
     "DOMAIN",
 ]

homeassistant/components/ollama/config_flow.py

@@ -38,12 +38,16 @@ from .const import (
     CONF_KEEP_ALIVE,
     CONF_MAX_HISTORY,
     CONF_MODEL,
+    CONF_NUM_CTX,
     CONF_PROMPT,
     DEFAULT_KEEP_ALIVE,
     DEFAULT_MAX_HISTORY,
     DEFAULT_MODEL,
+    DEFAULT_NUM_CTX,
     DEFAULT_TIMEOUT,
     DOMAIN,
+    MAX_NUM_CTX,
+    MIN_NUM_CTX,
     MODEL_NAMES,
 )
@@ -255,6 +259,14 @@ def ollama_config_option_schema(
             description={"suggested_value": options.get(CONF_LLM_HASS_API)},
             default="none",
         ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
+        vol.Optional(
+            CONF_NUM_CTX,
+            description={"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
+        ): NumberSelector(
+            NumberSelectorConfig(
+                min=MIN_NUM_CTX, max=MAX_NUM_CTX, step=1, mode=NumberSelectorMode.BOX
+            )
+        ),
         vol.Optional(
             CONF_MAX_HISTORY,
             description={
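
For orientation, here is the selector from the hunk above built standalone; this is a minimal sketch, not code from the commit. NumberSelector, NumberSelectorConfig, and NumberSelectorMode come from Home Assistant's selector helpers, and BOX mode renders a plain numeric input instead of a slider.

# Minimal sketch (not part of this commit): the num_ctx selector in isolation.
from homeassistant.helpers.selector import (
    NumberSelector,
    NumberSelectorConfig,
    NumberSelectorMode,
)

# Bounds mirror MIN_NUM_CTX / MAX_NUM_CTX defined in const.py below.
num_ctx_selector = NumberSelector(
    NumberSelectorConfig(min=2048, max=131072, step=1, mode=NumberSelectorMode.BOX)
)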

homeassistant/components/ollama/const.py

@@ -11,6 +11,11 @@ DEFAULT_KEEP_ALIVE = -1  # seconds. -1 = indefinite, 0 = never
 KEEP_ALIVE_FOREVER = -1
 DEFAULT_TIMEOUT = 5.0  # seconds

+CONF_NUM_CTX = "num_ctx"
+DEFAULT_NUM_CTX = 8192
+MIN_NUM_CTX = 2048
+MAX_NUM_CTX = 131072
+
 CONF_MAX_HISTORY = "max_history"
 DEFAULT_MAX_HISTORY = 20
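
The UI enforces these bounds through the NumberSelector; the same range check can also be expressed directly with voluptuous, the schema library Home Assistant uses for validation. A self-contained sketch, assuming only the constants above:

# Hedged sketch: validating num_ctx against the const.py bounds with voluptuous.
import voluptuous as vol

CONF_NUM_CTX = "num_ctx"
DEFAULT_NUM_CTX = 8192
MIN_NUM_CTX = 2048
MAX_NUM_CTX = 131072

NUM_CTX_SCHEMA = vol.Schema(
    {
        vol.Optional(CONF_NUM_CTX, default=DEFAULT_NUM_CTX): vol.All(
            vol.Coerce(int), vol.Range(min=MIN_NUM_CTX, max=MAX_NUM_CTX)
        )
    }
)

print(NUM_CTX_SCHEMA({}))                 # {'num_ctx': 8192} (default applied)
print(NUM_CTX_SCHEMA({"num_ctx": 4096}))  # {'num_ctx': 4096} (within range)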

homeassistant/components/ollama/conversation.py

@@ -26,9 +26,11 @@ from .const import (
     CONF_KEEP_ALIVE,
     CONF_MAX_HISTORY,
     CONF_MODEL,
+    CONF_NUM_CTX,
     CONF_PROMPT,
     DEFAULT_KEEP_ALIVE,
     DEFAULT_MAX_HISTORY,
+    DEFAULT_NUM_CTX,
     DOMAIN,
     MAX_HISTORY_SECONDS,
 )
@@ -263,6 +265,7 @@ class OllamaConversationEntity(
                 stream=False,
                 # keep_alive requires specifying unit. In this case, seconds
                 keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
+                options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
             )
         except (ollama.RequestError, ollama.ResponseError) as err:
             _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
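
The options dict is forwarded as-is to the Ollama server; settings.get() falls back to DEFAULT_NUM_CTX for entries saved before this change. Stripped of the Home Assistant plumbing, the underlying call through the ollama Python client looks roughly like this. A sketch only: the host, model name, and values are illustrative, and dict-style response access assumes the client versions current at the time of this commit.

# Hedged sketch of the underlying client call (not code from this commit).
import asyncio

import ollama

async def main() -> None:
    client = ollama.AsyncClient(host="http://127.0.0.1:11434")  # assumed local server
    response = await client.chat(
        model="llama3.1",  # any locally pulled model
        messages=[{"role": "user", "content": "hello"}],
        stream=False,
        keep_alive="-1s",  # same unit-suffixed format the integration builds
        options={"num_ctx": 8192},  # the option this commit wires through
    )
    print(response["message"]["content"])

asyncio.run(main())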

homeassistant/components/ollama/strings.json

@@ -27,11 +27,13 @@
         "prompt": "Instructions",
         "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
         "max_history": "Max history messages",
+        "num_ctx": "Context window size",
         "keep_alive": "Keep alive"
       },
       "data_description": {
         "prompt": "Instruct how the LLM should respond. This can be a template.",
-        "keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never."
+        "keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never.",
+        "num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities."
       }
     }
   }

tests/components/ollama/conftest.py

@@ -1,5 +1,6 @@
 """Tests Ollama integration."""

+from typing import Any
 from unittest.mock import patch

 import pytest
@@ -16,12 +17,20 @@ from tests.common import MockConfigEntry


 @pytest.fixture
-def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
+def mock_config_entry_options() -> dict[str, Any]:
+    """Fixture for configuration entry options."""
+    return TEST_OPTIONS
+
+
+@pytest.fixture
+def mock_config_entry(
+    hass: HomeAssistant, mock_config_entry_options: dict[str, Any]
+) -> MockConfigEntry:
     """Mock a config entry."""
     entry = MockConfigEntry(
         domain=ollama.DOMAIN,
         data=TEST_USER_DATA,
-        options=TEST_OPTIONS,
+        options=mock_config_entry_options,
     )
     entry.add_to_hass(hass)
     return entry
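
Splitting out a mock_config_entry_options fixture lets individual tests swap the stored options: pytest gives a @pytest.mark.parametrize argument precedence over a fixture of the same name, which the new test in test_conversation.py at the end of this diff relies on. A minimal illustration of that mechanism, with hypothetical names:

# Hedged sketch: parametrize overrides a same-named fixture's value.
import pytest

@pytest.fixture
def options() -> dict:
    return {}  # default used when a test is not parametrized

@pytest.mark.parametrize("options", [{"num_ctx": 16384}])
def test_override(options: dict) -> None:
    assert options == {"num_ctx": 16384}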

tests/components/ollama/test_config_flow.py

@@ -164,13 +164,18 @@ async def test_options(
     )
     options = await hass.config_entries.options.async_configure(
         options_flow["flow_id"],
-        {ollama.CONF_PROMPT: "test prompt", ollama.CONF_MAX_HISTORY: 100},
+        {
+            ollama.CONF_PROMPT: "test prompt",
+            ollama.CONF_MAX_HISTORY: 100,
+            ollama.CONF_NUM_CTX: 32768,
+        },
     )
     await hass.async_block_till_done()

     assert options["type"] is FlowResultType.CREATE_ENTRY
     assert options["data"] == {
         ollama.CONF_PROMPT: "test prompt",
         ollama.CONF_MAX_HISTORY: 100,
+        ollama.CONF_NUM_CTX: 32768,
     }

tests/components/ollama/test_conversation.py

@@ -578,3 +578,34 @@ async def test_conversation_agent_with_assist(
         state.attributes[ATTR_SUPPORTED_FEATURES]
         == conversation.ConversationEntityFeature.CONTROL
     )
+
+
+@pytest.mark.parametrize(
+    ("mock_config_entry_options", "expected_options"),
+    [
+        ({}, {"num_ctx": 8192}),
+        ({"num_ctx": 16384}, {"num_ctx": 16384}),
+    ],
+)
+async def test_options(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    mock_init_component,
+    expected_options: dict[str, Any],
+) -> None:
+    """Test that options are passed correctly to ollama client."""
+    with patch(
+        "ollama.AsyncClient.chat",
+        return_value={"message": {"role": "assistant", "content": "test response"}},
+    ) as mock_chat:
+        await conversation.async_converse(
+            hass,
+            "test message",
+            None,
+            Context(),
+            agent_id="conversation.mock_title",
+        )
+
+        assert mock_chat.call_count == 1
+        args = mock_chat.call_args.kwargs
+        assert args.get("options") == expected_options
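
The two parametrized cases pin down the fallback behavior: an entry saved before this change, with no num_ctx stored, sends the 8192 default to the client, while an explicitly stored 16384 passes through unchanged.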