From f9d7d65f3a44e1651a0b66ff440152c57ed0ba00 Mon Sep 17 00:00:00 2001 From: Franck Nijhof Date: Thu, 2 Feb 2023 21:20:10 +0100 Subject: [PATCH] Use template selector for prompt template in OpenAI (#87205) * Use template selector for prompt template in OpenAI * Fix tests * Do not parse template result --- .../openai_conversation/__init__.py | 3 ++- .../openai_conversation/config_flow.py | 7 ++---- .../openai_conversation/test_init.py | 24 +++++++++++-------- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/homeassistant/components/openai_conversation/__init__.py b/homeassistant/components/openai_conversation/__init__.py index 950e15f8e11..41ff6bcf9cd 100644 --- a/homeassistant/components/openai_conversation/__init__.py +++ b/homeassistant/components/openai_conversation/__init__.py @@ -151,5 +151,6 @@ class OpenAIAgent(conversation.AbstractConversationAgent): { "ha_name": self.hass.config.location_name, "areas": list(area_registry.async_get(self.hass).areas.values()), - } + }, + parse_result=False, ) diff --git a/homeassistant/components/openai_conversation/config_flow.py b/homeassistant/components/openai_conversation/config_flow.py index 9aef77e37f7..2db5e98a1f4 100644 --- a/homeassistant/components/openai_conversation/config_flow.py +++ b/homeassistant/components/openai_conversation/config_flow.py @@ -18,8 +18,7 @@ from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers.selector import ( NumberSelector, NumberSelectorConfig, - TextSelector, - TextSelectorConfig, + TemplateSelector, ) from .const import ( @@ -132,9 +131,7 @@ def openai_config_option_schema(options: MappingProxyType[str, Any]) -> dict: if not options: options = DEFAULT_OPTIONS return { - vol.Required(CONF_PROMPT, default=options.get(CONF_PROMPT)): TextSelector( - TextSelectorConfig(multiline=True) - ), + vol.Required(CONF_PROMPT, default=options.get(CONF_PROMPT)): TemplateSelector(), vol.Required(CONF_MODEL, default=options.get(CONF_MODEL)): str, vol.Required(CONF_MAX_TOKENS, default=options.get(CONF_MAX_TOKENS)): int, vol.Required(CONF_TOP_P, default=options.get(CONF_TOP_P)): NumberSelector( diff --git a/tests/components/openai_conversation/test_init.py b/tests/components/openai_conversation/test_init.py index 759c5e2e200..c5bdb7aff0b 100644 --- a/tests/components/openai_conversation/test_init.py +++ b/tests/components/openai_conversation/test_init.py @@ -4,9 +4,11 @@ from unittest.mock import patch from openai import error from homeassistant.components import conversation -from homeassistant.core import Context +from homeassistant.core import Context, HomeAssistant from homeassistant.helpers import area_registry, device_registry, intent +from tests.common import MockConfigEntry + async def test_default_prompt(hass, mock_init_component): """Test that the default prompt works.""" @@ -107,19 +109,21 @@ async def test_error_handling(hass, mock_init_component): assert result.response.error_code == "unknown", result -async def test_template_error(hass, mock_config_entry, mock_init_component): +async def test_template_error( + hass: HomeAssistant, mock_config_entry: MockConfigEntry +) -> None: """Test that template error handling works.""" - options_flow = await hass.config_entries.options.async_init( - mock_config_entry.entry_id - ) - await hass.config_entries.options.async_configure( - options_flow["flow_id"], - { + hass.config_entries.async_update_entry( + mock_config_entry, + options={ "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", }, ) - await hass.async_block_till_done() - with patch("openai.Completion.acreate"): + with patch( + "openai.Engine.list", + ), patch("openai.Completion.acreate"): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() result = await conversation.async_converse(hass, "hello", None, Context()) assert result.response.response_type == intent.IntentResponseType.ERROR, result