Use template selector for prompt template in OpenAI (#87205)
* Use template selector for prompt template in OpenAI * Fix tests * Do not parse template result
This commit is contained in:
parent
22698b1cc5
commit
f9d7d65f3a
3 changed files with 18 additions and 16 deletions
|
@ -151,5 +151,6 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
|
||||||
{
|
{
|
||||||
"ha_name": self.hass.config.location_name,
|
"ha_name": self.hass.config.location_name,
|
||||||
"areas": list(area_registry.async_get(self.hass).areas.values()),
|
"areas": list(area_registry.async_get(self.hass).areas.values()),
|
||||||
}
|
},
|
||||||
|
parse_result=False,
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,8 +18,7 @@ from homeassistant.data_entry_flow import FlowResult
|
||||||
from homeassistant.helpers.selector import (
|
from homeassistant.helpers.selector import (
|
||||||
NumberSelector,
|
NumberSelector,
|
||||||
NumberSelectorConfig,
|
NumberSelectorConfig,
|
||||||
TextSelector,
|
TemplateSelector,
|
||||||
TextSelectorConfig,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
|
@ -132,9 +131,7 @@ def openai_config_option_schema(options: MappingProxyType[str, Any]) -> dict:
|
||||||
if not options:
|
if not options:
|
||||||
options = DEFAULT_OPTIONS
|
options = DEFAULT_OPTIONS
|
||||||
return {
|
return {
|
||||||
vol.Required(CONF_PROMPT, default=options.get(CONF_PROMPT)): TextSelector(
|
vol.Required(CONF_PROMPT, default=options.get(CONF_PROMPT)): TemplateSelector(),
|
||||||
TextSelectorConfig(multiline=True)
|
|
||||||
),
|
|
||||||
vol.Required(CONF_MODEL, default=options.get(CONF_MODEL)): str,
|
vol.Required(CONF_MODEL, default=options.get(CONF_MODEL)): str,
|
||||||
vol.Required(CONF_MAX_TOKENS, default=options.get(CONF_MAX_TOKENS)): int,
|
vol.Required(CONF_MAX_TOKENS, default=options.get(CONF_MAX_TOKENS)): int,
|
||||||
vol.Required(CONF_TOP_P, default=options.get(CONF_TOP_P)): NumberSelector(
|
vol.Required(CONF_TOP_P, default=options.get(CONF_TOP_P)): NumberSelector(
|
||||||
|
|
|
@ -4,9 +4,11 @@ from unittest.mock import patch
|
||||||
from openai import error
|
from openai import error
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.core import Context
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.helpers import area_registry, device_registry, intent
|
from homeassistant.helpers import area_registry, device_registry, intent
|
||||||
|
|
||||||
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
|
||||||
async def test_default_prompt(hass, mock_init_component):
|
async def test_default_prompt(hass, mock_init_component):
|
||||||
"""Test that the default prompt works."""
|
"""Test that the default prompt works."""
|
||||||
|
@ -107,19 +109,21 @@ async def test_error_handling(hass, mock_init_component):
|
||||||
assert result.response.error_code == "unknown", result
|
assert result.response.error_code == "unknown", result
|
||||||
|
|
||||||
|
|
||||||
async def test_template_error(hass, mock_config_entry, mock_init_component):
|
async def test_template_error(
|
||||||
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||||
|
) -> None:
|
||||||
"""Test that template error handling works."""
|
"""Test that template error handling works."""
|
||||||
options_flow = await hass.config_entries.options.async_init(
|
hass.config_entries.async_update_entry(
|
||||||
mock_config_entry.entry_id
|
mock_config_entry,
|
||||||
)
|
options={
|
||||||
await hass.config_entries.options.async_configure(
|
|
||||||
options_flow["flow_id"],
|
|
||||||
{
|
|
||||||
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
|
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
with patch(
|
||||||
with patch("openai.Completion.acreate"):
|
"openai.Engine.list",
|
||||||
|
), patch("openai.Completion.acreate"):
|
||||||
|
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
result = await conversation.async_converse(hass, "hello", None, Context())
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue