parent
403dffc12d
commit
885be98f8f
9 changed files with 107 additions and 78 deletions
|
@ -16,13 +16,13 @@ from homeassistant.helpers import area_registry as ar, intent, template
|
|||
from homeassistant.util import ulid
|
||||
|
||||
from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_MODEL,
|
||||
CONF_PROMPT,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOP_P,
|
||||
DEFAULT_CHAT_MODEL,
|
||||
DEFAULT_MAX_TOKENS,
|
||||
DEFAULT_MODEL,
|
||||
DEFAULT_PROMPT,
|
||||
DEFAULT_TEMPERATURE,
|
||||
DEFAULT_TOP_P,
|
||||
|
@ -63,7 +63,7 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
|
|||
"""Initialize the agent."""
|
||||
self.hass = hass
|
||||
self.entry = entry
|
||||
self.history: dict[str, str] = {}
|
||||
self.history: dict[str, list[dict]] = {}
|
||||
|
||||
@property
|
||||
def attribution(self):
|
||||
|
@ -75,14 +75,14 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
|
|||
) -> conversation.ConversationResult:
|
||||
"""Process a sentence."""
|
||||
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
|
||||
model = self.entry.options.get(CONF_MODEL, DEFAULT_MODEL)
|
||||
model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
|
||||
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
|
||||
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
|
||||
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
|
||||
|
||||
if user_input.conversation_id in self.history:
|
||||
conversation_id = user_input.conversation_id
|
||||
prompt = self.history[conversation_id]
|
||||
messages = self.history[conversation_id]
|
||||
else:
|
||||
conversation_id = ulid.ulid()
|
||||
try:
|
||||
|
@ -97,25 +97,16 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
|
|||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
messages = [{"role": "system", "content": prompt}]
|
||||
|
||||
user_name = "User"
|
||||
if (
|
||||
user_input.context.user_id
|
||||
and (
|
||||
user := await self.hass.auth.async_get_user(user_input.context.user_id)
|
||||
)
|
||||
and user.name
|
||||
):
|
||||
user_name = user.name
|
||||
messages.append({"role": "user", "content": user_input.text})
|
||||
|
||||
prompt += f"\n{user_name}: {user_input.text}\nSmart home: "
|
||||
|
||||
_LOGGER.debug("Prompt for %s: %s", model, prompt)
|
||||
_LOGGER.debug("Prompt for %s: %s", model, messages)
|
||||
|
||||
try:
|
||||
result = await openai.Completion.acreate(
|
||||
engine=model,
|
||||
prompt=prompt,
|
||||
result = await openai.ChatCompletion.acreate(
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
|
@ -132,15 +123,12 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
|
|||
)
|
||||
|
||||
_LOGGER.debug("Response %s", result)
|
||||
response = result["choices"][0]["text"].strip()
|
||||
self.history[conversation_id] = prompt + response
|
||||
|
||||
stripped_response = response
|
||||
if response.startswith("Smart home:"):
|
||||
stripped_response = response[11:].strip()
|
||||
response = result["choices"][0]["message"]
|
||||
messages.append(response)
|
||||
self.history[conversation_id] = messages
|
||||
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_speech(stripped_response)
|
||||
intent_response.async_set_speech(response["content"])
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
|
|
@ -22,13 +22,13 @@ from homeassistant.helpers.selector import (
|
|||
)
|
||||
|
||||
from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_MODEL,
|
||||
CONF_PROMPT,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOP_P,
|
||||
DEFAULT_CHAT_MODEL,
|
||||
DEFAULT_MAX_TOKENS,
|
||||
DEFAULT_MODEL,
|
||||
DEFAULT_PROMPT,
|
||||
DEFAULT_TEMPERATURE,
|
||||
DEFAULT_TOP_P,
|
||||
|
@ -46,7 +46,7 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
|
|||
DEFAULT_OPTIONS = types.MappingProxyType(
|
||||
{
|
||||
CONF_PROMPT: DEFAULT_PROMPT,
|
||||
CONF_MODEL: DEFAULT_MODEL,
|
||||
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
|
||||
CONF_TOP_P: DEFAULT_TOP_P,
|
||||
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
|
||||
|
@ -131,13 +131,32 @@ def openai_config_option_schema(options: MappingProxyType[str, Any]) -> dict:
|
|||
if not options:
|
||||
options = DEFAULT_OPTIONS
|
||||
return {
|
||||
vol.Required(CONF_PROMPT, default=options.get(CONF_PROMPT)): TemplateSelector(),
|
||||
vol.Required(CONF_MODEL, default=options.get(CONF_MODEL)): str,
|
||||
vol.Required(CONF_MAX_TOKENS, default=options.get(CONF_MAX_TOKENS)): int,
|
||||
vol.Required(CONF_TOP_P, default=options.get(CONF_TOP_P)): NumberSelector(
|
||||
NumberSelectorConfig(min=0, max=1, step=0.05)
|
||||
),
|
||||
vol.Required(
|
||||
CONF_TEMPERATURE, default=options.get(CONF_TEMPERATURE)
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={"suggested_value": options[CONF_PROMPT]},
|
||||
default=DEFAULT_PROMPT,
|
||||
): TemplateSelector(),
|
||||
vol.Optional(
|
||||
CONF_CHAT_MODEL,
|
||||
description={
|
||||
# New key in HA 2023.4
|
||||
"suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
|
||||
},
|
||||
default=DEFAULT_CHAT_MODEL,
|
||||
): str,
|
||||
vol.Optional(
|
||||
CONF_MAX_TOKENS,
|
||||
description={"suggested_value": options[CONF_MAX_TOKENS]},
|
||||
default=DEFAULT_MAX_TOKENS,
|
||||
): int,
|
||||
vol.Optional(
|
||||
CONF_TOP_P,
|
||||
description={"suggested_value": options[CONF_TOP_P]},
|
||||
default=DEFAULT_TOP_P,
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
||||
vol.Optional(
|
||||
CONF_TEMPERATURE,
|
||||
description={"suggested_value": options[CONF_TEMPERATURE]},
|
||||
default=DEFAULT_TEMPERATURE,
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
||||
}
|
||||
|
|
|
@ -22,13 +22,9 @@ An overview of the areas and the devices in this smart home:
|
|||
Answer the user's questions about the world truthfully.
|
||||
|
||||
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
|
||||
|
||||
Now finish this conversation:
|
||||
|
||||
Smart home: How can I assist?
|
||||
"""
|
||||
CONF_MODEL = "model"
|
||||
DEFAULT_MODEL = "text-davinci-003"
|
||||
CONF_CHAT_MODEL = "chat_model"
|
||||
DEFAULT_CHAT_MODEL = "gpt-3.5-turbo"
|
||||
CONF_MAX_TOKENS = "max_tokens"
|
||||
DEFAULT_MAX_TOKENS = 150
|
||||
CONF_TOP_P = "top_p"
|
||||
|
|
|
@ -7,5 +7,5 @@
|
|||
"documentation": "https://www.home-assistant.io/integrations/openai_conversation",
|
||||
"integration_type": "service",
|
||||
"iot_class": "cloud_polling",
|
||||
"requirements": ["openai==0.26.2"]
|
||||
"requirements": ["openai==0.27.2"]
|
||||
}
|
||||
|
|
|
@ -1269,7 +1269,7 @@ open-garage==0.2.0
|
|||
open-meteo==0.2.1
|
||||
|
||||
# homeassistant.components.openai_conversation
|
||||
openai==0.26.2
|
||||
openai==0.27.2
|
||||
|
||||
# homeassistant.components.opencv
|
||||
# opencv-python-headless==4.6.0.66
|
||||
|
|
|
@ -947,7 +947,7 @@ open-garage==0.2.0
|
|||
open-meteo==0.2.1
|
||||
|
||||
# homeassistant.components.openai_conversation
|
||||
openai==0.26.2
|
||||
openai==0.27.2
|
||||
|
||||
# homeassistant.components.openerz
|
||||
openerz-api==0.2.0
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
# serializer version: 1
|
||||
# name: test_default_prompt
|
||||
list([
|
||||
dict({
|
||||
'content': '''
|
||||
This smart home is controlled by Home Assistant.
|
||||
|
||||
An overview of the areas and the devices in this smart home:
|
||||
|
||||
Test Area:
|
||||
- Test Device (Test Model)
|
||||
|
||||
Test Area 2:
|
||||
- Test Device 2
|
||||
- Test Device 3 (Test Model 3A)
|
||||
- Test Device 4
|
||||
- 1 (3)
|
||||
|
||||
Answer the user's questions about the world truthfully.
|
||||
|
||||
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
|
||||
''',
|
||||
'role': 'system',
|
||||
}),
|
||||
dict({
|
||||
'content': 'hello',
|
||||
'role': 'user',
|
||||
}),
|
||||
dict({
|
||||
'content': 'Hello, how can I help you?',
|
||||
'role': 'assistant',
|
||||
}),
|
||||
])
|
||||
# ---
|
|
@ -6,8 +6,8 @@ import pytest
|
|||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.openai_conversation.const import (
|
||||
CONF_MODEL,
|
||||
DEFAULT_MODEL,
|
||||
CONF_CHAT_MODEL,
|
||||
DEFAULT_CHAT_MODEL,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
@ -72,7 +72,7 @@ async def test_options(
|
|||
assert options["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert options["data"]["prompt"] == "Speak like a pirate"
|
||||
assert options["data"]["max_tokens"] == 200
|
||||
assert options["data"][CONF_MODEL] == DEFAULT_MODEL
|
||||
assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
from openai import error
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
|
@ -15,6 +16,7 @@ async def test_default_prompt(
|
|||
mock_init_component,
|
||||
area_registry: ar.AreaRegistry,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that the default prompt works."""
|
||||
for i in range(3):
|
||||
|
@ -86,40 +88,30 @@ async def test_default_prompt(
|
|||
model=3,
|
||||
suggested_area="Test Area 2",
|
||||
)
|
||||
with patch("openai.Completion.acreate") as mock_create:
|
||||
with patch(
|
||||
"openai.ChatCompletion.acreate",
|
||||
return_value={
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello, how can I help you?",
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
) as mock_create:
|
||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert (
|
||||
mock_create.mock_calls[0][2]["prompt"]
|
||||
== """This smart home is controlled by Home Assistant.
|
||||
|
||||
An overview of the areas and the devices in this smart home:
|
||||
|
||||
Test Area:
|
||||
- Test Device (Test Model)
|
||||
|
||||
Test Area 2:
|
||||
- Test Device 2
|
||||
- Test Device 3 (Test Model 3A)
|
||||
- Test Device 4
|
||||
- 1 (3)
|
||||
|
||||
Answer the user's questions about the world truthfully.
|
||||
|
||||
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
|
||||
|
||||
Now finish this conversation:
|
||||
|
||||
Smart home: How can I assist?
|
||||
User: hello
|
||||
Smart home: """
|
||||
)
|
||||
assert mock_create.mock_calls[0][2]["messages"] == snapshot
|
||||
|
||||
|
||||
async def test_error_handling(hass: HomeAssistant, mock_init_component) -> None:
|
||||
"""Test that the default prompt works."""
|
||||
with patch("openai.Completion.acreate", side_effect=error.ServiceUnavailableError):
|
||||
with patch(
|
||||
"openai.ChatCompletion.acreate", side_effect=error.ServiceUnavailableError
|
||||
):
|
||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
|
@ -138,7 +130,7 @@ async def test_template_error(
|
|||
)
|
||||
with patch(
|
||||
"openai.Engine.list",
|
||||
), patch("openai.Completion.acreate"):
|
||||
), patch("openai.ChatCompletion.acreate"):
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
||||
|
|
Loading…
Add table
Reference in a new issue