Use template selector for prompt template in OpenAI (#87205)
* Use template selector for prompt template in OpenAI * Fix tests * Do not parse template result
This commit is contained in:
parent
22698b1cc5
commit
f9d7d65f3a
3 changed files with 18 additions and 16 deletions
|
@ -151,5 +151,6 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
|
|||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
"areas": list(area_registry.async_get(self.hass).areas.values()),
|
||||
}
|
||||
},
|
||||
parse_result=False,
|
||||
)
|
||||
|
|
|
@ -18,8 +18,7 @@ from homeassistant.data_entry_flow import FlowResult
|
|||
from homeassistant.helpers.selector import (
|
||||
NumberSelector,
|
||||
NumberSelectorConfig,
|
||||
TextSelector,
|
||||
TextSelectorConfig,
|
||||
TemplateSelector,
|
||||
)
|
||||
|
||||
from .const import (
|
||||
|
@ -132,9 +131,7 @@ def openai_config_option_schema(options: MappingProxyType[str, Any]) -> dict:
|
|||
if not options:
|
||||
options = DEFAULT_OPTIONS
|
||||
return {
|
||||
vol.Required(CONF_PROMPT, default=options.get(CONF_PROMPT)): TextSelector(
|
||||
TextSelectorConfig(multiline=True)
|
||||
),
|
||||
vol.Required(CONF_PROMPT, default=options.get(CONF_PROMPT)): TemplateSelector(),
|
||||
vol.Required(CONF_MODEL, default=options.get(CONF_MODEL)): str,
|
||||
vol.Required(CONF_MAX_TOKENS, default=options.get(CONF_MAX_TOKENS)): int,
|
||||
vol.Required(CONF_TOP_P, default=options.get(CONF_TOP_P)): NumberSelector(
|
||||
|
|
|
@ -4,9 +4,11 @@ from unittest.mock import patch
|
|||
from openai import error
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.core import Context
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import area_registry, device_registry, intent
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
async def test_default_prompt(hass, mock_init_component):
|
||||
"""Test that the default prompt works."""
|
||||
|
@ -107,19 +109,21 @@ async def test_error_handling(hass, mock_init_component):
|
|||
assert result.response.error_code == "unknown", result
|
||||
|
||||
|
||||
async def test_template_error(hass, mock_config_entry, mock_init_component):
|
||||
async def test_template_error(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> None:
|
||||
"""Test that template error handling works."""
|
||||
options_flow = await hass.config_entries.options.async_init(
|
||||
mock_config_entry.entry_id
|
||||
)
|
||||
await hass.config_entries.options.async_configure(
|
||||
options_flow["flow_id"],
|
||||
{
|
||||
hass.config_entries.async_update_entry(
|
||||
mock_config_entry,
|
||||
options={
|
||||
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
with patch("openai.Completion.acreate"):
|
||||
with patch(
|
||||
"openai.Engine.list",
|
||||
), patch("openai.Completion.acreate"):
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
result = await conversation.async_converse(hass, "hello", None, Context())
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
|
|
Loading…
Add table
Reference in a new issue