Update OpenAI prompt on each interaction (#118747)

Paulus Schoutsen 2024-06-03 16:27:05 -04:00 committed by GitHub
parent 8ea3a6843a
commit 299c0de968
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 93 additions and 53 deletions

homeassistant/components/openai_conversation/conversation.py

@@ -146,58 +146,58 @@ class OpenAIConversationEntity(
             messages = self.history[conversation_id]
         else:
             conversation_id = ulid.ulid_now()
+            messages = []
 
-            if (
-                user_input.context
-                and user_input.context.user_id
-                and (
-                    user := await self.hass.auth.async_get_user(
-                        user_input.context.user_id
-                    )
-                )
-            ):
-                user_name = user.name
+        if (
+            user_input.context
+            and user_input.context.user_id
+            and (
+                user := await self.hass.auth.async_get_user(user_input.context.user_id)
+            )
+        ):
+            user_name = user.name
 
-            try:
-                if llm_api:
-                    api_prompt = llm_api.api_prompt
-                else:
-                    api_prompt = llm.async_render_no_api_prompt(self.hass)
+        try:
+            if llm_api:
+                api_prompt = llm_api.api_prompt
+            else:
+                api_prompt = llm.async_render_no_api_prompt(self.hass)
 
-                prompt = "\n".join(
-                    (
-                        template.Template(
-                            llm.BASE_PROMPT
-                            + options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
-                            self.hass,
-                        ).async_render(
-                            {
-                                "ha_name": self.hass.config.location_name,
-                                "user_name": user_name,
-                                "llm_context": llm_context,
-                            },
-                            parse_result=False,
-                        ),
-                        api_prompt,
-                    )
-                )
+            prompt = "\n".join(
+                (
+                    template.Template(
+                        llm.BASE_PROMPT
+                        + options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
+                        self.hass,
+                    ).async_render(
+                        {
+                            "ha_name": self.hass.config.location_name,
+                            "user_name": user_name,
+                            "llm_context": llm_context,
+                        },
+                        parse_result=False,
+                    ),
+                    api_prompt,
+                )
+            )
 
-            except TemplateError as err:
-                LOGGER.error("Error rendering prompt: %s", err)
-                intent_response = intent.IntentResponse(language=user_input.language)
-                intent_response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    f"Sorry, I had a problem with my template: {err}",
-                )
-                return conversation.ConversationResult(
-                    response=intent_response, conversation_id=conversation_id
-                )
+        except TemplateError as err:
+            LOGGER.error("Error rendering prompt: %s", err)
+            intent_response = intent.IntentResponse(language=user_input.language)
+            intent_response.async_set_error(
+                intent.IntentResponseErrorCode.UNKNOWN,
+                f"Sorry, I had a problem with my template: {err}",
+            )
+            return conversation.ConversationResult(
+                response=intent_response, conversation_id=conversation_id
+            )
 
-            messages = [ChatCompletionSystemMessageParam(role="system", content=prompt)]
-
-        messages.append(
-            ChatCompletionUserMessageParam(role="user", content=user_input.text)
-        )
+        # Create a copy of the variable because we attach it to the trace
+        messages = [
+            ChatCompletionSystemMessageParam(role="system", content=prompt),
+            *messages[1:],
+            ChatCompletionUserMessageParam(role="user", content=user_input.text),
+        ]
 
         LOGGER.debug("Prompt: %s", messages)
         trace.async_conversation_trace_append(
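Net effect of the hunk above: the system prompt is re-rendered on every conversation turn, the first entry of the stored history (the stale system message) is replaced with the fresh prompt, the remaining turns are kept, and the new user message is appended. Below is a minimal standalone sketch of that rebuild pattern, using plain dicts instead of the openai ChatCompletion*MessageParam types; rebuild_messages is a hypothetical helper for illustration only, not part of the integration.

# Sketch only: plain dicts stand in for ChatCompletionSystemMessageParam /
# ChatCompletionUserMessageParam; rebuild_messages is a hypothetical helper.
def rebuild_messages(
    history: list[dict[str, str]], prompt: str, user_text: str
) -> list[dict[str, str]]:
    # Drop the stale system message (history[0], if any), keep the rest of the
    # conversation, and append the new user turn.
    return [
        {"role": "system", "content": prompt},
        *history[1:],
        {"role": "user", "content": user_text},
    ]


# First turn: empty history -> [system, user]
messages = rebuild_messages([], "Today's date is 2024-06-03.", "hello")
messages.append({"role": "assistant", "content": "hi"})
# Next turn: the prompt is rendered again and replaces the old system message,
# while the earlier user/assistant turns survive.
messages = rebuild_messages(messages, "Today's date is 2024-06-04.", "what changed?")
assert messages[0]["content"] == "Today's date is 2024-06-04."
assert messages[1]["role"] == "user" and messages[2]["role"] == "assistant"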

tests/components/openai_conversation/test_conversation.py

@@ -2,6 +2,7 @@
 
 from unittest.mock import AsyncMock, Mock, patch
 
+from freezegun import freeze_time
 from httpx import Response
 from openai import RateLimitError
 from openai.types.chat.chat_completion import ChatCompletion, Choice
@@ -214,11 +215,14 @@ async def test_function_call(
         ),
     )
 
-    with patch(
-        "openai.resources.chat.completions.AsyncCompletions.create",
-        new_callable=AsyncMock,
-        side_effect=completion_result,
-    ) as mock_create:
+    with (
+        patch(
+            "openai.resources.chat.completions.AsyncCompletions.create",
+            new_callable=AsyncMock,
+            side_effect=completion_result,
+        ) as mock_create,
+        freeze_time("2024-06-03 23:00:00"),
+    ):
         result = await conversation.async_converse(
             hass,
             "Please call the test function",
@@ -227,6 +231,11 @@ async def test_function_call(
         agent_id=agent_id,
     )
 
+    assert (
+        "Today's date is 2024-06-03."
+        in mock_create.mock_calls[1][2]["messages"][0]["content"]
+    )
+
     assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
     assert mock_create.mock_calls[1][2]["messages"][3] == {
         "role": "tool",
@@ -262,6 +271,37 @@ async def test_function_call(
     # AGENT_DETAIL event contains the raw prompt passed to the model
     detail_event = trace_events[1]
     assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
+    assert (
+        "Today's date is 2024-06-03."
+        in trace_events[1]["data"]["messages"][0]["content"]
+    )
+
+    # Call it again, make sure we have updated prompt
+    with (
+        patch(
+            "openai.resources.chat.completions.AsyncCompletions.create",
+            new_callable=AsyncMock,
+            side_effect=completion_result,
+        ) as mock_create,
+        freeze_time("2024-06-04 23:00:00"),
+    ):
+        result = await conversation.async_converse(
+            hass,
+            "Please call the test function",
+            None,
+            context,
+            agent_id=agent_id,
+        )
+
+    assert (
+        "Today's date is 2024-06-04."
+        in mock_create.mock_calls[1][2]["messages"][0]["content"]
+    )
+    # Test old assert message not updated
+    assert (
+        "Today's date is 2024-06-03."
+        in trace_events[1]["data"]["messages"][0]["content"]
+    )
 
 
 @patch(
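For reference, a minimal self-contained sketch of the freezegun pattern the updated test relies on: freeze_time pins the clock inside the with block, so code that renders the current date produces a deterministic string that can be asserted against. render_prompt here is a simplified, hypothetical stand-in for the real prompt template rendering, and the frozen times are treated naively; how the actual integration handles time zones is outside this sketch.

# Sketch only: render_prompt is a stand-in, not the integration's code.
from datetime import datetime

from freezegun import freeze_time


def render_prompt() -> str:
    # Mimics a prompt template that embeds the current date.
    return f"Today's date is {datetime.now():%Y-%m-%d}."


def test_prompt_uses_current_date() -> None:
    with freeze_time("2024-06-03 23:00:00"):
        assert render_prompt() == "Today's date is 2024-06-03."
    # A later turn rendered under a new frozen time picks up the new date.
    with freeze_time("2024-06-04 23:00:00"):
        assert render_prompt() == "Today's date is 2024-06-04."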