OpenAI to respect custom conversation IDs (#119307)
This commit is contained in:
parent
cdd9f19cf9
commit
0ea9581cfc
2 changed files with 55 additions and 2 deletions
|
@ -141,11 +141,25 @@ class OpenAIConversationEntity(
|
|||
)
|
||||
tools = [_format_tool(tool) for tool in llm_api.tools]
|
||||
|
||||
if user_input.conversation_id in self.history:
|
||||
if user_input.conversation_id is None:
|
||||
conversation_id = ulid.ulid_now()
|
||||
messages = []
|
||||
|
||||
elif user_input.conversation_id in self.history:
|
||||
conversation_id = user_input.conversation_id
|
||||
messages = self.history[conversation_id]
|
||||
|
||||
else:
|
||||
conversation_id = ulid.ulid_now()
|
||||
# Conversation IDs are ULIDs. We generate a new one if not provided.
|
||||
# If an old OLID is passed in, we will generate a new one to indicate
|
||||
# a new conversation was started. If the user picks their own, they
|
||||
# want to track a conversation and we respect it.
|
||||
try:
|
||||
ulid.ulid_to_bytes(user_input.conversation_id)
|
||||
conversation_id = ulid.ulid_now()
|
||||
except ValueError:
|
||||
conversation_id = user_input.conversation_id
|
||||
|
||||
messages = []
|
||||
|
||||
if (
|
||||
|
|
|
@ -22,6 +22,7 @@ from homeassistant.core import Context, HomeAssistant
|
|||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import intent, llm
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import ulid
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
@ -497,3 +498,41 @@ async def test_unknown_hass_api(
|
|||
)
|
||||
|
||||
assert result == snapshot
|
||||
|
||||
|
||||
@patch(
|
||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
async def test_conversation_id(
|
||||
mock_create,
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
) -> None:
|
||||
"""Test conversation ID is honored."""
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, None, agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
conversation_id = result.conversation_id
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", conversation_id, None, agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.conversation_id == conversation_id
|
||||
|
||||
unknown_id = ulid.ulid()
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", unknown_id, None, agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.conversation_id != unknown_id
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", "koala", None, agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.conversation_id == "koala"
|
||||
|
|
Loading…
Add table
Reference in a new issue