parent
403dffc12d
commit
885be98f8f
9 changed files with 107 additions and 78 deletions
homeassistant/components/openai_conversation/__init__.py

@@ -16,13 +16,13 @@ from homeassistant.helpers import area_registry as ar, intent, template
 from homeassistant.util import ulid
 
 from .const import (
+    CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
-    CONF_MODEL,
     CONF_PROMPT,
     CONF_TEMPERATURE,
     CONF_TOP_P,
+    DEFAULT_CHAT_MODEL,
     DEFAULT_MAX_TOKENS,
-    DEFAULT_MODEL,
     DEFAULT_PROMPT,
     DEFAULT_TEMPERATURE,
     DEFAULT_TOP_P,
@@ -63,7 +63,7 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
         """Initialize the agent."""
         self.hass = hass
         self.entry = entry
-        self.history: dict[str, str] = {}
+        self.history: dict[str, list[dict]] = {}
 
     @property
     def attribution(self):
@@ -75,14 +75,14 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
     ) -> conversation.ConversationResult:
         """Process a sentence."""
         raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
-        model = self.entry.options.get(CONF_MODEL, DEFAULT_MODEL)
+        model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
         max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
         top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
         temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
 
         if user_input.conversation_id in self.history:
             conversation_id = user_input.conversation_id
-            prompt = self.history[conversation_id]
+            messages = self.history[conversation_id]
         else:
             conversation_id = ulid.ulid()
             try:
@@ -97,25 +97,16 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
                 return conversation.ConversationResult(
                     response=intent_response, conversation_id=conversation_id
                 )
+            messages = [{"role": "system", "content": prompt}]
 
-        user_name = "User"
-        if (
-            user_input.context.user_id
-            and (
-                user := await self.hass.auth.async_get_user(user_input.context.user_id)
-            )
-            and user.name
-        ):
-            user_name = user.name
-
-        prompt += f"\n{user_name}: {user_input.text}\nSmart home: "
+        messages.append({"role": "user", "content": user_input.text})
 
-        _LOGGER.debug("Prompt for %s: %s", model, prompt)
+        _LOGGER.debug("Prompt for %s: %s", model, messages)
 
         try:
-            result = await openai.Completion.acreate(
-                engine=model,
-                prompt=prompt,
+            result = await openai.ChatCompletion.acreate(
+                model=model,
+                messages=messages,
                 max_tokens=max_tokens,
                 top_p=top_p,
                 temperature=temperature,
@@ -132,15 +123,12 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
             )
 
         _LOGGER.debug("Response %s", result)
-        response = result["choices"][0]["text"].strip()
-        self.history[conversation_id] = prompt + response
+        response = result["choices"][0]["message"]
+        messages.append(response)
+        self.history[conversation_id] = messages
 
-        stripped_response = response
-        if response.startswith("Smart home:"):
-            stripped_response = response[11:].strip()
-
         intent_response = intent.IntentResponse(language=user_input.language)
-        intent_response.async_set_speech(stripped_response)
+        intent_response.async_set_speech(response["content"])
         return conversation.ConversationResult(
             response=intent_response, conversation_id=conversation_id
         )
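Note: the hunks above replace the single ever-growing prompt string with a list of role-tagged message dicts, which is the payload shape the chat endpoint expects. A minimal sketch of what self.history[conversation_id] holds after one exchange under the new scheme (the ULID key and message text are illustrative, borrowed from the test fixtures further down):

    # Illustrative sketch only; contents mirror the snapshot in the tests below.
    history: dict[str, list[dict]] = {
        "01GXXXXXXXXXXXXXXXXXXXXXXX": [  # conversation_id from ulid.ulid()
            {"role": "system", "content": "This smart home is controlled by ..."},
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "Hello, how can I help you?"},
        ]
    }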
homeassistant/components/openai_conversation/config_flow.py

@@ -22,13 +22,13 @@ from homeassistant.helpers.selector import (
 )
 
 from .const import (
+    CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
-    CONF_MODEL,
     CONF_PROMPT,
     CONF_TEMPERATURE,
     CONF_TOP_P,
+    DEFAULT_CHAT_MODEL,
     DEFAULT_MAX_TOKENS,
-    DEFAULT_MODEL,
     DEFAULT_PROMPT,
     DEFAULT_TEMPERATURE,
     DEFAULT_TOP_P,
@@ -46,7 +46,7 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
 DEFAULT_OPTIONS = types.MappingProxyType(
     {
         CONF_PROMPT: DEFAULT_PROMPT,
-        CONF_MODEL: DEFAULT_MODEL,
+        CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
         CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
         CONF_TOP_P: DEFAULT_TOP_P,
         CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
@@ -131,13 +131,32 @@ def openai_config_option_schema(options: MappingProxyType[str, Any]) -> dict:
     if not options:
         options = DEFAULT_OPTIONS
     return {
-        vol.Required(CONF_PROMPT, default=options.get(CONF_PROMPT)): TemplateSelector(),
-        vol.Required(CONF_MODEL, default=options.get(CONF_MODEL)): str,
-        vol.Required(CONF_MAX_TOKENS, default=options.get(CONF_MAX_TOKENS)): int,
-        vol.Required(CONF_TOP_P, default=options.get(CONF_TOP_P)): NumberSelector(
-            NumberSelectorConfig(min=0, max=1, step=0.05)
-        ),
-        vol.Required(
-            CONF_TEMPERATURE, default=options.get(CONF_TEMPERATURE)
+        vol.Optional(
+            CONF_PROMPT,
+            description={"suggested_value": options[CONF_PROMPT]},
+            default=DEFAULT_PROMPT,
+        ): TemplateSelector(),
+        vol.Optional(
+            CONF_CHAT_MODEL,
+            description={
+                # New key in HA 2023.4
+                "suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
+            },
+            default=DEFAULT_CHAT_MODEL,
+        ): str,
+        vol.Optional(
+            CONF_MAX_TOKENS,
+            description={"suggested_value": options[CONF_MAX_TOKENS]},
+            default=DEFAULT_MAX_TOKENS,
+        ): int,
+        vol.Optional(
+            CONF_TOP_P,
+            description={"suggested_value": options[CONF_TOP_P]},
+            default=DEFAULT_TOP_P,
+        ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
+        vol.Optional(
+            CONF_TEMPERATURE,
+            description={"suggested_value": options[CONF_TEMPERATURE]},
+            default=DEFAULT_TEMPERATURE,
         ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
     }
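Note: this hunk also changes the schema pattern. With vol.Required(..., default=stored_value) the saved value was baked in as the schema default; with vol.Optional plus a suggested_value description, the options form is merely prefilled and the declared default stays the integration-wide constant. Only the chat_model entry uses options.get() with a fallback, because entries saved before the new key existed have no chat_model value, while the other keys are always present. A hedged sketch of the two patterns side by side (names from the diff):

    # Old pattern: the stored value becomes the field's default, so clearing
    # the field is impossible and the saved value shadows the real default.
    vol.Required(CONF_MAX_TOKENS, default=options.get(CONF_MAX_TOKENS)): int
    # New pattern: prefill via suggested_value; the schema default remains
    # DEFAULT_MAX_TOKENS even after the user has saved a custom value.
    vol.Optional(
        CONF_MAX_TOKENS,
        description={"suggested_value": options[CONF_MAX_TOKENS]},
        default=DEFAULT_MAX_TOKENS,
    ): int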
homeassistant/components/openai_conversation/const.py

@@ -22,13 +22,9 @@ An overview of the areas and the devices in this smart home:
 Answer the user's questions about the world truthfully.
 
 If the user wants to control a device, reject the request and suggest using the Home Assistant app.
-
-Now finish this conversation:
-
-Smart home: How can I assist?
 """
-CONF_MODEL = "model"
-DEFAULT_MODEL = "text-davinci-003"
+CONF_CHAT_MODEL = "chat_model"
+DEFAULT_CHAT_MODEL = "gpt-3.5-turbo"
 CONF_MAX_TOKENS = "max_tokens"
 DEFAULT_MAX_TOKENS = 150
 CONF_TOP_P = "top_p"
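Note: the dropped "Now finish this conversation:" / "Smart home: How can I assist?" lines existed only to coax a plain completion model into continuing a fake transcript. With the chat endpoint the turn structure is carried by message roles, so the system prompt can end after the behavioral instructions. A hedged sketch of the same first turn in both styles (rendered_prompt is a hypothetical stand-in for the rendered template above):

    # Completion style (removed): the dialogue is faked inside one string.
    prompt = rendered_prompt + "\nUser: hello\nSmart home: "
    # Chat style (new): roles are structural; no transcript scaffolding needed.
    messages = [
        {"role": "system", "content": rendered_prompt},
        {"role": "user", "content": "hello"},
    ]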
homeassistant/components/openai_conversation/manifest.json

@@ -7,5 +7,5 @@
   "documentation": "https://www.home-assistant.io/integrations/openai_conversation",
   "integration_type": "service",
   "iot_class": "cloud_polling",
-  "requirements": ["openai==0.26.2"]
+  "requirements": ["openai==0.27.2"]
 }
requirements_all.txt

@@ -1269,7 +1269,7 @@ open-garage==0.2.0
 open-meteo==0.2.1
 
 # homeassistant.components.openai_conversation
-openai==0.26.2
+openai==0.27.2
 
 # homeassistant.components.opencv
 # opencv-python-headless==4.6.0.66
requirements_test_all.txt

@@ -947,7 +947,7 @@ open-garage==0.2.0
 open-meteo==0.2.1
 
 # homeassistant.components.openai_conversation
-openai==0.26.2
+openai==0.27.2
 
 # homeassistant.components.openerz
 openerz-api==0.2.0
tests/components/openai_conversation/snapshots/test_init.ambr (new file)

@@ -0,0 +1,34 @@
+# serializer version: 1
+# name: test_default_prompt
+  list([
+    dict({
+      'content': '''
+        This smart home is controlled by Home Assistant.
+
+        An overview of the areas and the devices in this smart home:
+
+        Test Area:
+        - Test Device (Test Model)
+
+        Test Area 2:
+        - Test Device 2
+        - Test Device 3 (Test Model 3A)
+        - Test Device 4
+        - 1 (3)
+
+        Answer the user's questions about the world truthfully.
+
+        If the user wants to control a device, reject the request and suggest using the Home Assistant app.
+      ''',
+      'role': 'system',
+    }),
+    dict({
+      'content': 'hello',
+      'role': 'user',
+    }),
+    dict({
+      'content': 'Hello, how can I help you?',
+      'role': 'assistant',
+    }),
+  ])
+# ---
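Note: this new .ambr file is a syrupy snapshot (the exact whitespace shown here is reconstructed and may differ character-for-character from the generated file). The rewritten test below asserts the messages payload against it, and the file is regenerated rather than hand-edited; with syrupy that is typically done by running pytest with --snapshot-update. A minimal sketch of the fixture usage, under those assumptions:

    # snapshot is the syrupy pytest fixture; the stored .ambr is created or
    # refreshed via `pytest --snapshot-update`, then compared on later runs.
    def test_payload(snapshot):
        assert [{"role": "user", "content": "hello"}] == snapshot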
tests/components/openai_conversation/test_config_flow.py

@@ -6,8 +6,8 @@ import pytest
 
 from homeassistant import config_entries
 from homeassistant.components.openai_conversation.const import (
-    CONF_MODEL,
-    DEFAULT_MODEL,
+    CONF_CHAT_MODEL,
+    DEFAULT_CHAT_MODEL,
     DOMAIN,
 )
 from homeassistant.core import HomeAssistant
@@ -72,7 +72,7 @@ async def test_options(
     assert options["type"] == FlowResultType.CREATE_ENTRY
     assert options["data"]["prompt"] == "Speak like a pirate"
     assert options["data"]["max_tokens"] == 200
-    assert options["data"][CONF_MODEL] == DEFAULT_MODEL
+    assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL
 
 
 @pytest.mark.parametrize(
tests/components/openai_conversation/test_init.py

@@ -2,6 +2,7 @@
 from unittest.mock import patch
 
 from openai import error
+from syrupy.assertion import SnapshotAssertion
 
 from homeassistant.components import conversation
 from homeassistant.core import Context, HomeAssistant
@@ -15,6 +16,7 @@ async def test_default_prompt(
     mock_init_component,
     area_registry: ar.AreaRegistry,
     device_registry: dr.DeviceRegistry,
+    snapshot: SnapshotAssertion,
 ) -> None:
     """Test that the default prompt works."""
     for i in range(3):
@@ -86,40 +88,30 @@ async def test_default_prompt(
         model=3,
         suggested_area="Test Area 2",
     )
-    with patch("openai.Completion.acreate") as mock_create:
+    with patch(
+        "openai.ChatCompletion.acreate",
+        return_value={
+            "choices": [
+                {
+                    "message": {
+                        "role": "assistant",
+                        "content": "Hello, how can I help you?",
+                    }
+                }
+            ]
+        },
+    ) as mock_create:
         result = await conversation.async_converse(hass, "hello", None, Context())
 
     assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
-    assert (
-        mock_create.mock_calls[0][2]["prompt"]
-        == """This smart home is controlled by Home Assistant.
-
-An overview of the areas and the devices in this smart home:
-
-Test Area:
-- Test Device (Test Model)
-
-Test Area 2:
-- Test Device 2
-- Test Device 3 (Test Model 3A)
-- Test Device 4
-- 1 (3)
-
-Answer the user's questions about the world truthfully.
-
-If the user wants to control a device, reject the request and suggest using the Home Assistant app.
-
-Now finish this conversation:
-
-Smart home: How can I assist?
-User: hello
-Smart home: """
-    )
+    assert mock_create.mock_calls[0][2]["messages"] == snapshot
 
 
 async def test_error_handling(hass: HomeAssistant, mock_init_component) -> None:
     """Test that the default prompt works."""
-    with patch("openai.Completion.acreate", side_effect=error.ServiceUnavailableError):
+    with patch(
+        "openai.ChatCompletion.acreate", side_effect=error.ServiceUnavailableError
+    ):
         result = await conversation.async_converse(hass, "hello", None, Context())
 
     assert result.response.response_type == intent.IntentResponseType.ERROR, result
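Note on the rewritten assertion: entries in mock_calls unpack as (name, args, kwargs), so mock_create.mock_calls[0][2]["messages"] reads the messages keyword argument of the first ChatCompletion.acreate call and compares it against the syrupy snapshot above. A self-contained illustration of that indexing:

    from unittest.mock import MagicMock

    m = MagicMock()
    m(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hello"}])
    name, args, kwargs = m.mock_calls[0]  # mock_calls entries are 3-tuples
    assert kwargs is m.mock_calls[0][2]   # [2] is the kwargs dict, as in the test
    assert kwargs["messages"][0]["content"] == "hello"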
@@ -138,7 +130,7 @@ async def test_template_error(
     )
     with patch(
         "openai.Engine.list",
-    ), patch("openai.Completion.acreate"):
+    ), patch("openai.ChatCompletion.acreate"):
         await hass.config_entries.async_setup(mock_config_entry.entry_id)
         await hass.async_block_till_done()
         result = await conversation.async_converse(hass, "hello", None, Context())