From aad281db187aab69a5f57b83434fd569f8406ff3 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 26 Jul 2023 04:14:18 -0400 Subject: [PATCH] Add service to OpenAI to Generate an image (#97018) Co-authored-by: Franck Nijhof --- .../openai_conversation/__init__.py | 71 +++++++++++++++++-- .../openai_conversation/services.yaml | 22 ++++++ .../openai_conversation/strings.json | 21 ++++++ homeassistant/helpers/selector.py | 2 +- .../openai_conversation/test_init.py | 68 ++++++++++++++++++ 5 files changed, 176 insertions(+), 8 deletions(-) create mode 100644 homeassistant/components/openai_conversation/services.yaml diff --git a/homeassistant/components/openai_conversation/__init__.py b/homeassistant/components/openai_conversation/__init__.py index efa81c7b73c..9f4c30d91ba 100644 --- a/homeassistant/components/openai_conversation/__init__.py +++ b/homeassistant/components/openai_conversation/__init__.py @@ -7,13 +7,24 @@ from typing import Literal import openai from openai import error +import voluptuous as vol from homeassistant.components import conversation from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_API_KEY, MATCH_ALL -from homeassistant.core import HomeAssistant -from homeassistant.exceptions import ConfigEntryNotReady, TemplateError -from homeassistant.helpers import intent, template +from homeassistant.core import ( + HomeAssistant, + ServiceCall, + ServiceResponse, + SupportsResponse, +) +from homeassistant.exceptions import ( + ConfigEntryNotReady, + HomeAssistantError, + TemplateError, +) +from homeassistant.helpers import config_validation as cv, intent, selector, template +from homeassistant.helpers.typing import ConfigType from homeassistant.util import ulid from .const import ( @@ -27,18 +38,61 @@ from .const import ( DEFAULT_PROMPT, DEFAULT_TEMPERATURE, DEFAULT_TOP_P, + DOMAIN, ) _LOGGER = logging.getLogger(__name__) +SERVICE_GENERATE_IMAGE = "generate_image" + +CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) + + +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: + """Set up OpenAI Conversation.""" + + async def render_image(call: ServiceCall) -> ServiceResponse: + """Render an image with dall-e.""" + try: + response = await openai.Image.acreate( + api_key=hass.data[DOMAIN][call.data["config_entry"]], + prompt=call.data["prompt"], + n=1, + size=f'{call.data["size"]}x{call.data["size"]}', + ) + except error.OpenAIError as err: + raise HomeAssistantError(f"Error generating image: {err}") from err + + return response["data"][0] + + hass.services.async_register( + DOMAIN, + SERVICE_GENERATE_IMAGE, + render_image, + schema=vol.Schema( + { + vol.Required("config_entry"): selector.ConfigEntrySelector( + { + "integration": DOMAIN, + } + ), + vol.Required("prompt"): cv.string, + vol.Optional("size", default="512"): vol.In(("256", "512", "1024")), + } + ), + supports_response=SupportsResponse.ONLY, + ) + return True async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up OpenAI Conversation from a config entry.""" - openai.api_key = entry.data[CONF_API_KEY] - try: await hass.async_add_executor_job( - partial(openai.Engine.list, request_timeout=10) + partial( + openai.Engine.list, + api_key=entry.data[CONF_API_KEY], + request_timeout=10, + ) ) except error.AuthenticationError as err: _LOGGER.error("Invalid API key: %s", err) @@ -46,13 +100,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: except error.OpenAIError as err: raise ConfigEntryNotReady(err) from err + hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry.data[CONF_API_KEY] + conversation.async_set_agent(hass, entry, OpenAIAgent(hass, entry)) return True async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload OpenAI.""" - openai.api_key = None + hass.data[DOMAIN].pop(entry.entry_id) conversation.async_unset_agent(hass, entry) return True @@ -106,6 +162,7 @@ class OpenAIAgent(conversation.AbstractConversationAgent): try: result = await openai.ChatCompletion.acreate( + api_key=self.entry.data[CONF_API_KEY], model=model, messages=messages, max_tokens=max_tokens, diff --git a/homeassistant/components/openai_conversation/services.yaml b/homeassistant/components/openai_conversation/services.yaml new file mode 100644 index 00000000000..81818fb3e71 --- /dev/null +++ b/homeassistant/components/openai_conversation/services.yaml @@ -0,0 +1,22 @@ +generate_image: + fields: + config_entry: + required: true + selector: + config_entry: + integration: openai_conversation + prompt: + required: true + selector: + text: + multiline: true + size: + required: true + example: "512" + default: "512" + selector: + select: + options: + - "256" + - "512" + - "1024" diff --git a/homeassistant/components/openai_conversation/strings.json b/homeassistant/components/openai_conversation/strings.json index 9583e759bd2..542fe06dd56 100644 --- a/homeassistant/components/openai_conversation/strings.json +++ b/homeassistant/components/openai_conversation/strings.json @@ -25,5 +25,26 @@ } } } + }, + "services": { + "generate_image": { + "name": "Generate image", + "description": "Turn a prompt into an image", + "fields": { + "config_entry": { + "name": "Config Entry", + "description": "The config entry to use for this service" + }, + "prompt": { + "name": "Prompt", + "description": "The text to turn into an image", + "example": "A photo of a dog" + }, + "size": { + "name": "Size", + "description": "The size of the image to generate" + } + } + } } } diff --git a/homeassistant/helpers/selector.py b/homeassistant/helpers/selector.py index 8ec8d5eac3e..08975c5c881 100644 --- a/homeassistant/helpers/selector.py +++ b/homeassistant/helpers/selector.py @@ -539,7 +539,7 @@ class ConversationAgentSelectorConfig(TypedDict, total=False): @SELECTORS.register("conversation_agent") -class COnversationAgentSelector(Selector[ConversationAgentSelectorConfig]): +class ConversationAgentSelector(Selector[ConversationAgentSelectorConfig]): """Selector for a conversation agent.""" selector_type = "conversation_agent" diff --git a/tests/components/openai_conversation/test_init.py b/tests/components/openai_conversation/test_init.py index fe23bbac56c..1b9f81f60c0 100644 --- a/tests/components/openai_conversation/test_init.py +++ b/tests/components/openai_conversation/test_init.py @@ -2,10 +2,12 @@ from unittest.mock import patch from openai import error +import pytest from syrupy.assertion import SnapshotAssertion from homeassistant.components import conversation from homeassistant.core import Context, HomeAssistant +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import area_registry as ar, device_registry as dr, intent from tests.common import MockConfigEntry @@ -158,3 +160,69 @@ async def test_conversation_agent( mock_config_entry.entry_id ) assert agent.supported_languages == "*" + + +@pytest.mark.parametrize( + ("service_data", "expected_args"), + [ + ( + {"prompt": "Picture of a dog"}, + {"prompt": "Picture of a dog", "size": "512x512"}, + ), + ( + {"prompt": "Picture of a dog", "size": "256"}, + {"prompt": "Picture of a dog", "size": "256x256"}, + ), + ( + {"prompt": "Picture of a dog", "size": "1024"}, + {"prompt": "Picture of a dog", "size": "1024x1024"}, + ), + ], +) +async def test_generate_image_service( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, + service_data, + expected_args, +) -> None: + """Test generate image service.""" + service_data["config_entry"] = mock_config_entry.entry_id + expected_args["api_key"] = mock_config_entry.data["api_key"] + expected_args["n"] = 1 + + with patch( + "openai.Image.acreate", return_value={"data": [{"url": "A"}]} + ) as mock_create: + response = await hass.services.async_call( + "openai_conversation", + "generate_image", + service_data, + blocking=True, + return_response=True, + ) + + assert response == {"url": "A"} + assert len(mock_create.mock_calls) == 1 + assert mock_create.mock_calls[0][2] == expected_args + + +@pytest.mark.usefixtures("mock_init_component") +async def test_generate_image_service_error( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, +) -> None: + """Test generate image service handles errors.""" + with patch( + "openai.Image.acreate", side_effect=error.ServiceUnavailableError("Reason") + ), pytest.raises(HomeAssistantError, match="Error generating image: Reason"): + await hass.services.async_call( + "openai_conversation", + "generate_image", + { + "config_entry": mock_config_entry.entry_id, + "prompt": "Image of an epic fail", + }, + blocking=True, + return_response=True, + )