Add OpenAI Conversation system prompt user_name and llm_context variables (#118512)

* OpenAI Conversation: Add variables to the system prompt

* User name and llm_context

* test for user name

* test for user id

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Denis Shulyaka 2024-06-01 03:28:23 +08:00 committed by GitHub
parent 80e9ff672a
commit 46da43d09d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 75 additions and 10 deletions

View file

@ -113,20 +113,22 @@ class OpenAIConversationEntity(
intent_response = intent.IntentResponse(language=user_input.language) intent_response = intent.IntentResponse(language=user_input.language)
llm_api: llm.APIInstance | None = None llm_api: llm.APIInstance | None = None
tools: list[ChatCompletionToolParam] | None = None tools: list[ChatCompletionToolParam] | None = None
user_name: str | None = None
llm_context = llm.LLMContext(
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
)
if options.get(CONF_LLM_HASS_API): if options.get(CONF_LLM_HASS_API):
try: try:
llm_api = await llm.async_get_api( llm_api = await llm.async_get_api(
self.hass, self.hass,
options[CONF_LLM_HASS_API], options[CONF_LLM_HASS_API],
llm.LLMContext( llm_context,
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
),
) )
except HomeAssistantError as err: except HomeAssistantError as err:
LOGGER.error("Error getting LLM API: %s", err) LOGGER.error("Error getting LLM API: %s", err)
@ -144,6 +146,18 @@ class OpenAIConversationEntity(
messages = self.history[conversation_id] messages = self.history[conversation_id]
else: else:
conversation_id = ulid.ulid_now() conversation_id = ulid.ulid_now()
if (
user_input.context
and user_input.context.user_id
and (
user := await self.hass.auth.async_get_user(
user_input.context.user_id
)
)
):
user_name = user.name
try: try:
if llm_api: if llm_api:
api_prompt = llm_api.api_prompt api_prompt = llm_api.api_prompt
@ -158,6 +172,8 @@ class OpenAIConversationEntity(
).async_render( ).async_render(
{ {
"ha_name": self.hass.config.location_name, "ha_name": self.hass.config.location_name,
"user_name": user_name,
"llm_context": llm_context,
}, },
parse_result=False, parse_result=False,
), ),

View file

@ -1,6 +1,6 @@
"""Tests for the OpenAI integration.""" """Tests for the OpenAI integration."""
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, Mock, patch
from httpx import Response from httpx import Response
from openai import RateLimitError from openai import RateLimitError
@ -73,6 +73,53 @@ async def test_template_error(
assert result.response.error_code == "unknown", result assert result.response.error_code == "unknown", result
async def test_template_variables(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test that template variables work."""
context = Context(user_id="12345")
mock_user = Mock()
mock_user.id = "12345"
mock_user.name = "Test User"
hass.config_entries.async_update_entry(
mock_config_entry,
options={
"prompt": (
"The user name is {{ user_name }}. "
"The user id is {{ llm_context.context.user_id }}."
),
},
)
with (
patch(
"openai.resources.models.AsyncModels.list",
),
patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
) as mock_create,
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
result = await conversation.async_converse(
hass, "hello", None, context, agent_id=mock_config_entry.entry_id
)
assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), result
assert (
"The user name is Test User."
in mock_create.mock_calls[0][2]["messages"][0]["content"]
)
assert (
"The user id is 12345."
in mock_create.mock_calls[0][2]["messages"][0]["content"]
)
async def test_conversation_agent( async def test_conversation_agent(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
@ -382,7 +429,9 @@ async def test_assist_api_tools_conversion(
), ),
), ),
) as mock_create: ) as mock_create:
await conversation.async_converse(hass, "hello", None, None, agent_id=agent_id) await conversation.async_converse(
hass, "hello", None, Context(), agent_id=agent_id
)
tools = mock_create.mock_calls[0][2]["tools"] tools = mock_create.mock_calls[0][2]["tools"]
assert tools assert tools