OpenAI to use GPT3.5 (#90423)

* OpenAI to use GPT3.5

* Add snapshot
This commit is contained in:
Paulus Schoutsen 2023-03-28 23:37:43 -04:00 committed by GitHub
parent 403dffc12d
commit 885be98f8f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 107 additions and 78 deletions

View file

@ -16,13 +16,13 @@ from homeassistant.helpers import area_registry as ar, intent, template
from homeassistant.util import ulid from homeassistant.util import ulid
from .const import ( from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS, CONF_MAX_TOKENS,
CONF_MODEL,
CONF_PROMPT, CONF_PROMPT,
CONF_TEMPERATURE, CONF_TEMPERATURE,
CONF_TOP_P, CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS, DEFAULT_MAX_TOKENS,
DEFAULT_MODEL,
DEFAULT_PROMPT, DEFAULT_PROMPT,
DEFAULT_TEMPERATURE, DEFAULT_TEMPERATURE,
DEFAULT_TOP_P, DEFAULT_TOP_P,
@ -63,7 +63,7 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
"""Initialize the agent.""" """Initialize the agent."""
self.hass = hass self.hass = hass
self.entry = entry self.entry = entry
self.history: dict[str, str] = {} self.history: dict[str, list[dict]] = {}
@property @property
def attribution(self): def attribution(self):
@ -75,14 +75,14 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Process a sentence.""" """Process a sentence."""
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
model = self.entry.options.get(CONF_MODEL, DEFAULT_MODEL) model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P) top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE) temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
if user_input.conversation_id in self.history: if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id conversation_id = user_input.conversation_id
prompt = self.history[conversation_id] messages = self.history[conversation_id]
else: else:
conversation_id = ulid.ulid() conversation_id = ulid.ulid()
try: try:
@ -97,25 +97,16 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=conversation_id
) )
messages = [{"role": "system", "content": prompt}]
user_name = "User" messages.append({"role": "user", "content": user_input.text})
if (
user_input.context.user_id
and (
user := await self.hass.auth.async_get_user(user_input.context.user_id)
)
and user.name
):
user_name = user.name
prompt += f"\n{user_name}: {user_input.text}\nSmart home: " _LOGGER.debug("Prompt for %s: %s", model, messages)
_LOGGER.debug("Prompt for %s: %s", model, prompt)
try: try:
result = await openai.Completion.acreate( result = await openai.ChatCompletion.acreate(
engine=model, model=model,
prompt=prompt, messages=messages,
max_tokens=max_tokens, max_tokens=max_tokens,
top_p=top_p, top_p=top_p,
temperature=temperature, temperature=temperature,
@ -132,15 +123,12 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
) )
_LOGGER.debug("Response %s", result) _LOGGER.debug("Response %s", result)
response = result["choices"][0]["text"].strip() response = result["choices"][0]["message"]
self.history[conversation_id] = prompt + response messages.append(response)
self.history[conversation_id] = messages
stripped_response = response
if response.startswith("Smart home:"):
stripped_response = response[11:].strip()
intent_response = intent.IntentResponse(language=user_input.language) intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(stripped_response) intent_response.async_set_speech(response["content"])
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=conversation_id
) )

View file

@ -22,13 +22,13 @@ from homeassistant.helpers.selector import (
) )
from .const import ( from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS, CONF_MAX_TOKENS,
CONF_MODEL,
CONF_PROMPT, CONF_PROMPT,
CONF_TEMPERATURE, CONF_TEMPERATURE,
CONF_TOP_P, CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS, DEFAULT_MAX_TOKENS,
DEFAULT_MODEL,
DEFAULT_PROMPT, DEFAULT_PROMPT,
DEFAULT_TEMPERATURE, DEFAULT_TEMPERATURE,
DEFAULT_TOP_P, DEFAULT_TOP_P,
@ -46,7 +46,7 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
DEFAULT_OPTIONS = types.MappingProxyType( DEFAULT_OPTIONS = types.MappingProxyType(
{ {
CONF_PROMPT: DEFAULT_PROMPT, CONF_PROMPT: DEFAULT_PROMPT,
CONF_MODEL: DEFAULT_MODEL, CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS, CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
CONF_TOP_P: DEFAULT_TOP_P, CONF_TOP_P: DEFAULT_TOP_P,
CONF_TEMPERATURE: DEFAULT_TEMPERATURE, CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
@ -131,13 +131,32 @@ def openai_config_option_schema(options: MappingProxyType[str, Any]) -> dict:
if not options: if not options:
options = DEFAULT_OPTIONS options = DEFAULT_OPTIONS
return { return {
vol.Required(CONF_PROMPT, default=options.get(CONF_PROMPT)): TemplateSelector(), vol.Optional(
vol.Required(CONF_MODEL, default=options.get(CONF_MODEL)): str, CONF_PROMPT,
vol.Required(CONF_MAX_TOKENS, default=options.get(CONF_MAX_TOKENS)): int, description={"suggested_value": options[CONF_PROMPT]},
vol.Required(CONF_TOP_P, default=options.get(CONF_TOP_P)): NumberSelector( default=DEFAULT_PROMPT,
NumberSelectorConfig(min=0, max=1, step=0.05) ): TemplateSelector(),
), vol.Optional(
vol.Required( CONF_CHAT_MODEL,
CONF_TEMPERATURE, default=options.get(CONF_TEMPERATURE) description={
# New key in HA 2023.4
"suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
},
default=DEFAULT_CHAT_MODEL,
): str,
vol.Optional(
CONF_MAX_TOKENS,
description={"suggested_value": options[CONF_MAX_TOKENS]},
default=DEFAULT_MAX_TOKENS,
): int,
vol.Optional(
CONF_TOP_P,
description={"suggested_value": options[CONF_TOP_P]},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_TEMPERATURE,
description={"suggested_value": options[CONF_TEMPERATURE]},
default=DEFAULT_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
} }

View file

@ -22,13 +22,9 @@ An overview of the areas and the devices in this smart home:
Answer the user's questions about the world truthfully. Answer the user's questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant app. If the user wants to control a device, reject the request and suggest using the Home Assistant app.
Now finish this conversation:
Smart home: How can I assist?
""" """
CONF_MODEL = "model" CONF_CHAT_MODEL = "chat_model"
DEFAULT_MODEL = "text-davinci-003" DEFAULT_CHAT_MODEL = "gpt-3.5-turbo"
CONF_MAX_TOKENS = "max_tokens" CONF_MAX_TOKENS = "max_tokens"
DEFAULT_MAX_TOKENS = 150 DEFAULT_MAX_TOKENS = 150
CONF_TOP_P = "top_p" CONF_TOP_P = "top_p"

View file

@ -7,5 +7,5 @@
"documentation": "https://www.home-assistant.io/integrations/openai_conversation", "documentation": "https://www.home-assistant.io/integrations/openai_conversation",
"integration_type": "service", "integration_type": "service",
"iot_class": "cloud_polling", "iot_class": "cloud_polling",
"requirements": ["openai==0.26.2"] "requirements": ["openai==0.27.2"]
} }

View file

@ -1269,7 +1269,7 @@ open-garage==0.2.0
open-meteo==0.2.1 open-meteo==0.2.1
# homeassistant.components.openai_conversation # homeassistant.components.openai_conversation
openai==0.26.2 openai==0.27.2
# homeassistant.components.opencv # homeassistant.components.opencv
# opencv-python-headless==4.6.0.66 # opencv-python-headless==4.6.0.66

View file

@ -947,7 +947,7 @@ open-garage==0.2.0
open-meteo==0.2.1 open-meteo==0.2.1
# homeassistant.components.openai_conversation # homeassistant.components.openai_conversation
openai==0.26.2 openai==0.27.2
# homeassistant.components.openerz # homeassistant.components.openerz
openerz-api==0.2.0 openerz-api==0.2.0

View file

@ -0,0 +1,34 @@
# serializer version: 1
# name: test_default_prompt
list([
dict({
'content': '''
This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 4
- 1 (3)
Answer the user's questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
''',
'role': 'system',
}),
dict({
'content': 'hello',
'role': 'user',
}),
dict({
'content': 'Hello, how can I help you?',
'role': 'assistant',
}),
])
# ---

View file

@ -6,8 +6,8 @@ import pytest
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.openai_conversation.const import ( from homeassistant.components.openai_conversation.const import (
CONF_MODEL, CONF_CHAT_MODEL,
DEFAULT_MODEL, DEFAULT_CHAT_MODEL,
DOMAIN, DOMAIN,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -72,7 +72,7 @@ async def test_options(
assert options["type"] == FlowResultType.CREATE_ENTRY assert options["type"] == FlowResultType.CREATE_ENTRY
assert options["data"]["prompt"] == "Speak like a pirate" assert options["data"]["prompt"] == "Speak like a pirate"
assert options["data"]["max_tokens"] == 200 assert options["data"]["max_tokens"] == 200
assert options["data"][CONF_MODEL] == DEFAULT_MODEL assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL
@pytest.mark.parametrize( @pytest.mark.parametrize(

View file

@ -2,6 +2,7 @@
from unittest.mock import patch from unittest.mock import patch
from openai import error from openai import error
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
@ -15,6 +16,7 @@ async def test_default_prompt(
mock_init_component, mock_init_component,
area_registry: ar.AreaRegistry, area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test that the default prompt works.""" """Test that the default prompt works."""
for i in range(3): for i in range(3):
@ -86,40 +88,30 @@ async def test_default_prompt(
model=3, model=3,
suggested_area="Test Area 2", suggested_area="Test Area 2",
) )
with patch("openai.Completion.acreate") as mock_create: with patch(
"openai.ChatCompletion.acreate",
return_value={
"choices": [
{
"message": {
"role": "assistant",
"content": "Hello, how can I help you?",
}
}
]
},
) as mock_create:
result = await conversation.async_converse(hass, "hello", None, Context()) result = await conversation.async_converse(hass, "hello", None, Context())
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert ( assert mock_create.mock_calls[0][2]["messages"] == snapshot
mock_create.mock_calls[0][2]["prompt"]
== """This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 4
- 1 (3)
Answer the user's questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
Now finish this conversation:
Smart home: How can I assist?
User: hello
Smart home: """
)
async def test_error_handling(hass: HomeAssistant, mock_init_component) -> None: async def test_error_handling(hass: HomeAssistant, mock_init_component) -> None:
"""Test that the default prompt works.""" """Test that the default prompt works."""
with patch("openai.Completion.acreate", side_effect=error.ServiceUnavailableError): with patch(
"openai.ChatCompletion.acreate", side_effect=error.ServiceUnavailableError
):
result = await conversation.async_converse(hass, "hello", None, Context()) result = await conversation.async_converse(hass, "hello", None, Context())
assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.response_type == intent.IntentResponseType.ERROR, result
@ -138,7 +130,7 @@ async def test_template_error(
) )
with patch( with patch(
"openai.Engine.list", "openai.Engine.list",
), patch("openai.Completion.acreate"): ), patch("openai.ChatCompletion.acreate"):
await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
result = await conversation.async_converse(hass, "hello", None, Context()) result = await conversation.async_converse(hass, "hello", None, Context())