Add service to OpenAI to Generate an image (#97018)
Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
parent
5caa1969c5
commit
aad281db18
5 changed files with 176 additions and 8 deletions
|
@ -7,13 +7,24 @@ from typing import Literal
|
|||
|
||||
import openai
|
||||
from openai import error
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_API_KEY, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
|
||||
from homeassistant.helpers import intent, template
|
||||
from homeassistant.core import (
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
ServiceResponse,
|
||||
SupportsResponse,
|
||||
)
|
||||
from homeassistant.exceptions import (
|
||||
ConfigEntryNotReady,
|
||||
HomeAssistantError,
|
||||
TemplateError,
|
||||
)
|
||||
from homeassistant.helpers import config_validation as cv, intent, selector, template
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.util import ulid
|
||||
|
||||
from .const import (
|
||||
|
@ -27,18 +38,61 @@ from .const import (
|
|||
DEFAULT_PROMPT,
|
||||
DEFAULT_TEMPERATURE,
|
||||
DEFAULT_TOP_P,
|
||||
DOMAIN,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
SERVICE_GENERATE_IMAGE = "generate_image"
|
||||
|
||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up OpenAI Conversation."""
|
||||
|
||||
async def render_image(call: ServiceCall) -> ServiceResponse:
|
||||
"""Render an image with dall-e."""
|
||||
try:
|
||||
response = await openai.Image.acreate(
|
||||
api_key=hass.data[DOMAIN][call.data["config_entry"]],
|
||||
prompt=call.data["prompt"],
|
||||
n=1,
|
||||
size=f'{call.data["size"]}x{call.data["size"]}',
|
||||
)
|
||||
except error.OpenAIError as err:
|
||||
raise HomeAssistantError(f"Error generating image: {err}") from err
|
||||
|
||||
return response["data"][0]
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
SERVICE_GENERATE_IMAGE,
|
||||
render_image,
|
||||
schema=vol.Schema(
|
||||
{
|
||||
vol.Required("config_entry"): selector.ConfigEntrySelector(
|
||||
{
|
||||
"integration": DOMAIN,
|
||||
}
|
||||
),
|
||||
vol.Required("prompt"): cv.string,
|
||||
vol.Optional("size", default="512"): vol.In(("256", "512", "1024")),
|
||||
}
|
||||
),
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up OpenAI Conversation from a config entry."""
|
||||
openai.api_key = entry.data[CONF_API_KEY]
|
||||
|
||||
try:
|
||||
await hass.async_add_executor_job(
|
||||
partial(openai.Engine.list, request_timeout=10)
|
||||
partial(
|
||||
openai.Engine.list,
|
||||
api_key=entry.data[CONF_API_KEY],
|
||||
request_timeout=10,
|
||||
)
|
||||
)
|
||||
except error.AuthenticationError as err:
|
||||
_LOGGER.error("Invalid API key: %s", err)
|
||||
|
@ -46,13 +100,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
except error.OpenAIError as err:
|
||||
raise ConfigEntryNotReady(err) from err
|
||||
|
||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry.data[CONF_API_KEY]
|
||||
|
||||
conversation.async_set_agent(hass, entry, OpenAIAgent(hass, entry))
|
||||
return True
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload OpenAI."""
|
||||
openai.api_key = None
|
||||
hass.data[DOMAIN].pop(entry.entry_id)
|
||||
conversation.async_unset_agent(hass, entry)
|
||||
return True
|
||||
|
||||
|
@ -106,6 +162,7 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
|
|||
|
||||
try:
|
||||
result = await openai.ChatCompletion.acreate(
|
||||
api_key=self.entry.data[CONF_API_KEY],
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
|
|
22
homeassistant/components/openai_conversation/services.yaml
Normal file
22
homeassistant/components/openai_conversation/services.yaml
Normal file
|
@ -0,0 +1,22 @@
|
|||
generate_image:
|
||||
fields:
|
||||
config_entry:
|
||||
required: true
|
||||
selector:
|
||||
config_entry:
|
||||
integration: openai_conversation
|
||||
prompt:
|
||||
required: true
|
||||
selector:
|
||||
text:
|
||||
multiline: true
|
||||
size:
|
||||
required: true
|
||||
example: "512"
|
||||
default: "512"
|
||||
selector:
|
||||
select:
|
||||
options:
|
||||
- "256"
|
||||
- "512"
|
||||
- "1024"
|
|
@ -25,5 +25,26 @@
|
|||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
"generate_image": {
|
||||
"name": "Generate image",
|
||||
"description": "Turn a prompt into an image",
|
||||
"fields": {
|
||||
"config_entry": {
|
||||
"name": "Config Entry",
|
||||
"description": "The config entry to use for this service"
|
||||
},
|
||||
"prompt": {
|
||||
"name": "Prompt",
|
||||
"description": "The text to turn into an image",
|
||||
"example": "A photo of a dog"
|
||||
},
|
||||
"size": {
|
||||
"name": "Size",
|
||||
"description": "The size of the image to generate"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -539,7 +539,7 @@ class ConversationAgentSelectorConfig(TypedDict, total=False):
|
|||
|
||||
|
||||
@SELECTORS.register("conversation_agent")
|
||||
class COnversationAgentSelector(Selector[ConversationAgentSelectorConfig]):
|
||||
class ConversationAgentSelector(Selector[ConversationAgentSelectorConfig]):
|
||||
"""Selector for a conversation agent."""
|
||||
|
||||
selector_type = "conversation_agent"
|
||||
|
|
|
@ -2,10 +2,12 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
from openai import error
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
@ -158,3 +160,69 @@ async def test_conversation_agent(
|
|||
mock_config_entry.entry_id
|
||||
)
|
||||
assert agent.supported_languages == "*"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("service_data", "expected_args"),
|
||||
[
|
||||
(
|
||||
{"prompt": "Picture of a dog"},
|
||||
{"prompt": "Picture of a dog", "size": "512x512"},
|
||||
),
|
||||
(
|
||||
{"prompt": "Picture of a dog", "size": "256"},
|
||||
{"prompt": "Picture of a dog", "size": "256x256"},
|
||||
),
|
||||
(
|
||||
{"prompt": "Picture of a dog", "size": "1024"},
|
||||
{"prompt": "Picture of a dog", "size": "1024x1024"},
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_generate_image_service(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
service_data,
|
||||
expected_args,
|
||||
) -> None:
|
||||
"""Test generate image service."""
|
||||
service_data["config_entry"] = mock_config_entry.entry_id
|
||||
expected_args["api_key"] = mock_config_entry.data["api_key"]
|
||||
expected_args["n"] = 1
|
||||
|
||||
with patch(
|
||||
"openai.Image.acreate", return_value={"data": [{"url": "A"}]}
|
||||
) as mock_create:
|
||||
response = await hass.services.async_call(
|
||||
"openai_conversation",
|
||||
"generate_image",
|
||||
service_data,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
assert response == {"url": "A"}
|
||||
assert len(mock_create.mock_calls) == 1
|
||||
assert mock_create.mock_calls[0][2] == expected_args
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_generate_image_service_error(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test generate image service handles errors."""
|
||||
with patch(
|
||||
"openai.Image.acreate", side_effect=error.ServiceUnavailableError("Reason")
|
||||
), pytest.raises(HomeAssistantError, match="Error generating image: Reason"):
|
||||
await hass.services.async_call(
|
||||
"openai_conversation",
|
||||
"generate_image",
|
||||
{
|
||||
"config_entry": mock_config_entry.entry_id,
|
||||
"prompt": "Image of an epic fail",
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue