OpenAI to respect custom conversation IDs (#119307)

This commit is contained in:
Paulus Schoutsen 2024-06-11 01:49:14 -04:00 committed by GitHub
parent cdd9f19cf9
commit 0ea9581cfc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 55 additions and 2 deletions

View file

@ -141,11 +141,25 @@ class OpenAIConversationEntity(
)
tools = [_format_tool(tool) for tool in llm_api.tools]
if user_input.conversation_id in self.history:
if user_input.conversation_id is None:
conversation_id = ulid.ulid_now()
messages = []
elif user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id
messages = self.history[conversation_id]
else:
conversation_id = ulid.ulid_now()
# Conversation IDs are ULIDs. We generate a new one if not provided.
# If an old OLID is passed in, we will generate a new one to indicate
# a new conversation was started. If the user picks their own, they
# want to track a conversation and we respect it.
try:
ulid.ulid_to_bytes(user_input.conversation_id)
conversation_id = ulid.ulid_now()
except ValueError:
conversation_id = user_input.conversation_id
messages = []
if (

View file

@ -22,6 +22,7 @@ from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent, llm
from homeassistant.setup import async_setup_component
from homeassistant.util import ulid
from tests.common import MockConfigEntry
@ -497,3 +498,41 @@ async def test_unknown_hass_api(
)
assert result == snapshot
@patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
)
async def test_conversation_id(
mock_create,
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test conversation ID is honored."""
result = await conversation.async_converse(
hass, "hello", None, None, agent_id=mock_config_entry.entry_id
)
conversation_id = result.conversation_id
result = await conversation.async_converse(
hass, "hello", conversation_id, None, agent_id=mock_config_entry.entry_id
)
assert result.conversation_id == conversation_id
unknown_id = ulid.ulid()
result = await conversation.async_converse(
hass, "hello", unknown_id, None, agent_id=mock_config_entry.entry_id
)
assert result.conversation_id != unknown_id
result = await conversation.async_converse(
hass, "hello", "koala", None, agent_id=mock_config_entry.entry_id
)
assert result.conversation_id == "koala"