Add service to OpenAI to Generate an image (#97018)

Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
Paulus Schoutsen 2023-07-26 04:14:18 -04:00 committed by GitHub
parent 5caa1969c5
commit aad281db18
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 176 additions and 8 deletions

View file

@ -7,13 +7,24 @@ from typing import Literal
import openai
from openai import error
import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
from homeassistant.helpers import intent, template
from homeassistant.core import (
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
)
from homeassistant.exceptions import (
ConfigEntryNotReady,
HomeAssistantError,
TemplateError,
)
from homeassistant.helpers import config_validation as cv, intent, selector, template
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import ulid
from .const import (
@ -27,18 +38,61 @@ from .const import (
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DOMAIN,
)
_LOGGER = logging.getLogger(__name__)
SERVICE_GENERATE_IMAGE = "generate_image"
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up OpenAI Conversation."""
async def render_image(call: ServiceCall) -> ServiceResponse:
"""Render an image with dall-e."""
try:
response = await openai.Image.acreate(
api_key=hass.data[DOMAIN][call.data["config_entry"]],
prompt=call.data["prompt"],
n=1,
size=f'{call.data["size"]}x{call.data["size"]}',
)
except error.OpenAIError as err:
raise HomeAssistantError(f"Error generating image: {err}") from err
return response["data"][0]
hass.services.async_register(
DOMAIN,
SERVICE_GENERATE_IMAGE,
render_image,
schema=vol.Schema(
{
vol.Required("config_entry"): selector.ConfigEntrySelector(
{
"integration": DOMAIN,
}
),
vol.Required("prompt"): cv.string,
vol.Optional("size", default="512"): vol.In(("256", "512", "1024")),
}
),
supports_response=SupportsResponse.ONLY,
)
return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up OpenAI Conversation from a config entry."""
openai.api_key = entry.data[CONF_API_KEY]
try:
await hass.async_add_executor_job(
partial(openai.Engine.list, request_timeout=10)
partial(
openai.Engine.list,
api_key=entry.data[CONF_API_KEY],
request_timeout=10,
)
)
except error.AuthenticationError as err:
_LOGGER.error("Invalid API key: %s", err)
@ -46,13 +100,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
except error.OpenAIError as err:
raise ConfigEntryNotReady(err) from err
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry.data[CONF_API_KEY]
conversation.async_set_agent(hass, entry, OpenAIAgent(hass, entry))
return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload OpenAI."""
openai.api_key = None
hass.data[DOMAIN].pop(entry.entry_id)
conversation.async_unset_agent(hass, entry)
return True
@ -106,6 +162,7 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
try:
result = await openai.ChatCompletion.acreate(
api_key=self.entry.data[CONF_API_KEY],
model=model,
messages=messages,
max_tokens=max_tokens,

View file

@ -0,0 +1,22 @@
generate_image:
fields:
config_entry:
required: true
selector:
config_entry:
integration: openai_conversation
prompt:
required: true
selector:
text:
multiline: true
size:
required: true
example: "512"
default: "512"
selector:
select:
options:
- "256"
- "512"
- "1024"

View file

@ -25,5 +25,26 @@
}
}
}
},
"services": {
"generate_image": {
"name": "Generate image",
"description": "Turn a prompt into an image",
"fields": {
"config_entry": {
"name": "Config Entry",
"description": "The config entry to use for this service"
},
"prompt": {
"name": "Prompt",
"description": "The text to turn into an image",
"example": "A photo of a dog"
},
"size": {
"name": "Size",
"description": "The size of the image to generate"
}
}
}
}
}

View file

@ -539,7 +539,7 @@ class ConversationAgentSelectorConfig(TypedDict, total=False):
@SELECTORS.register("conversation_agent")
class COnversationAgentSelector(Selector[ConversationAgentSelectorConfig]):
class ConversationAgentSelector(Selector[ConversationAgentSelectorConfig]):
"""Selector for a conversation agent."""
selector_type = "conversation_agent"

View file

@ -2,10 +2,12 @@
from unittest.mock import patch
from openai import error
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
from tests.common import MockConfigEntry
@ -158,3 +160,69 @@ async def test_conversation_agent(
mock_config_entry.entry_id
)
assert agent.supported_languages == "*"
@pytest.mark.parametrize(
("service_data", "expected_args"),
[
(
{"prompt": "Picture of a dog"},
{"prompt": "Picture of a dog", "size": "512x512"},
),
(
{"prompt": "Picture of a dog", "size": "256"},
{"prompt": "Picture of a dog", "size": "256x256"},
),
(
{"prompt": "Picture of a dog", "size": "1024"},
{"prompt": "Picture of a dog", "size": "1024x1024"},
),
],
)
async def test_generate_image_service(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
service_data,
expected_args,
) -> None:
"""Test generate image service."""
service_data["config_entry"] = mock_config_entry.entry_id
expected_args["api_key"] = mock_config_entry.data["api_key"]
expected_args["n"] = 1
with patch(
"openai.Image.acreate", return_value={"data": [{"url": "A"}]}
) as mock_create:
response = await hass.services.async_call(
"openai_conversation",
"generate_image",
service_data,
blocking=True,
return_response=True,
)
assert response == {"url": "A"}
assert len(mock_create.mock_calls) == 1
assert mock_create.mock_calls[0][2] == expected_args
@pytest.mark.usefixtures("mock_init_component")
async def test_generate_image_service_error(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test generate image service handles errors."""
with patch(
"openai.Image.acreate", side_effect=error.ServiceUnavailableError("Reason")
), pytest.raises(HomeAssistantError, match="Error generating image: Reason"):
await hass.services.async_call(
"openai_conversation",
"generate_image",
{
"config_entry": mock_config_entry.entry_id,
"prompt": "Image of an epic fail",
},
blocking=True,
return_response=True,
)