From 0ea9581cfc3a7c151540dd7e29cc0a421828d9e5 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 11 Jun 2024 01:49:14 -0400 Subject: [PATCH] OpenAI to respect custom conversation IDs (#119307) --- .../openai_conversation/conversation.py | 18 ++++++++- .../openai_conversation/test_conversation.py | 39 +++++++++++++++++++ 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index d5e566678f1..d0b3ef8f895 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -141,11 +141,25 @@ class OpenAIConversationEntity( ) tools = [_format_tool(tool) for tool in llm_api.tools] - if user_input.conversation_id in self.history: + if user_input.conversation_id is None: + conversation_id = ulid.ulid_now() + messages = [] + + elif user_input.conversation_id in self.history: conversation_id = user_input.conversation_id messages = self.history[conversation_id] + else: - conversation_id = ulid.ulid_now() + # Conversation IDs are ULIDs. We generate a new one if not provided. + # If an old OLID is passed in, we will generate a new one to indicate + # a new conversation was started. If the user picks their own, they + # want to track a conversation and we respect it. + try: + ulid.ulid_to_bytes(user_input.conversation_id) + conversation_id = ulid.ulid_now() + except ValueError: + conversation_id = user_input.conversation_id + messages = [] if ( diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index 002b2df186b..5ca54611c91 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -22,6 +22,7 @@ from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import intent, llm from homeassistant.setup import async_setup_component +from homeassistant.util import ulid from tests.common import MockConfigEntry @@ -497,3 +498,41 @@ async def test_unknown_hass_api( ) assert result == snapshot + + +@patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, +) +async def test_conversation_id( + mock_create, + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, +) -> None: + """Test conversation ID is honored.""" + result = await conversation.async_converse( + hass, "hello", None, None, agent_id=mock_config_entry.entry_id + ) + + conversation_id = result.conversation_id + + result = await conversation.async_converse( + hass, "hello", conversation_id, None, agent_id=mock_config_entry.entry_id + ) + + assert result.conversation_id == conversation_id + + unknown_id = ulid.ulid() + + result = await conversation.async_converse( + hass, "hello", unknown_id, None, agent_id=mock_config_entry.entry_id + ) + + assert result.conversation_id != unknown_id + + result = await conversation.async_converse( + hass, "hello", "koala", None, agent_id=mock_config_entry.entry_id + ) + + assert result.conversation_id == "koala"