Use template selector for prompt template in OpenAI (#87205)

* Use template selector for prompt template in OpenAI

* Fix tests

* Do not parse template result
This commit is contained in:
Franck Nijhof 2023-02-02 21:20:10 +01:00 committed by GitHub
parent 22698b1cc5
commit f9d7d65f3a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 16 deletions

View file

@ -151,5 +151,6 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
{ {
"ha_name": self.hass.config.location_name, "ha_name": self.hass.config.location_name,
"areas": list(area_registry.async_get(self.hass).areas.values()), "areas": list(area_registry.async_get(self.hass).areas.values()),
} },
parse_result=False,
) )

View file

@ -18,8 +18,7 @@ from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.selector import ( from homeassistant.helpers.selector import (
NumberSelector, NumberSelector,
NumberSelectorConfig, NumberSelectorConfig,
TextSelector, TemplateSelector,
TextSelectorConfig,
) )
from .const import ( from .const import (
@ -132,9 +131,7 @@ def openai_config_option_schema(options: MappingProxyType[str, Any]) -> dict:
if not options: if not options:
options = DEFAULT_OPTIONS options = DEFAULT_OPTIONS
return { return {
vol.Required(CONF_PROMPT, default=options.get(CONF_PROMPT)): TextSelector( vol.Required(CONF_PROMPT, default=options.get(CONF_PROMPT)): TemplateSelector(),
TextSelectorConfig(multiline=True)
),
vol.Required(CONF_MODEL, default=options.get(CONF_MODEL)): str, vol.Required(CONF_MODEL, default=options.get(CONF_MODEL)): str,
vol.Required(CONF_MAX_TOKENS, default=options.get(CONF_MAX_TOKENS)): int, vol.Required(CONF_MAX_TOKENS, default=options.get(CONF_MAX_TOKENS)): int,
vol.Required(CONF_TOP_P, default=options.get(CONF_TOP_P)): NumberSelector( vol.Required(CONF_TOP_P, default=options.get(CONF_TOP_P)): NumberSelector(

View file

@ -4,9 +4,11 @@ from unittest.mock import patch
from openai import error from openai import error
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.core import Context from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import area_registry, device_registry, intent from homeassistant.helpers import area_registry, device_registry, intent
from tests.common import MockConfigEntry
async def test_default_prompt(hass, mock_init_component): async def test_default_prompt(hass, mock_init_component):
"""Test that the default prompt works.""" """Test that the default prompt works."""
@ -107,19 +109,21 @@ async def test_error_handling(hass, mock_init_component):
assert result.response.error_code == "unknown", result assert result.response.error_code == "unknown", result
async def test_template_error(hass, mock_config_entry, mock_init_component): async def test_template_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test that template error handling works.""" """Test that template error handling works."""
options_flow = await hass.config_entries.options.async_init( hass.config_entries.async_update_entry(
mock_config_entry.entry_id mock_config_entry,
) options={
await hass.config_entries.options.async_configure(
options_flow["flow_id"],
{
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
}, },
) )
with patch(
"openai.Engine.list",
), patch("openai.Completion.acreate"):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
with patch("openai.Completion.acreate"):
result = await conversation.async_converse(hass, "hello", None, Context()) result = await conversation.async_converse(hass, "hello", None, Context())
assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.response_type == intent.IntentResponseType.ERROR, result