Add service to OpenAI to Generate an image (#97018)
Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
parent
5caa1969c5
commit
aad281db18
5 changed files with 176 additions and 8 deletions
|
@ -7,13 +7,24 @@ from typing import Literal
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from openai import error
|
from openai import error
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import CONF_API_KEY, MATCH_ALL
|
from homeassistant.const import CONF_API_KEY, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import (
|
||||||
from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
|
HomeAssistant,
|
||||||
from homeassistant.helpers import intent, template
|
ServiceCall,
|
||||||
|
ServiceResponse,
|
||||||
|
SupportsResponse,
|
||||||
|
)
|
||||||
|
from homeassistant.exceptions import (
|
||||||
|
ConfigEntryNotReady,
|
||||||
|
HomeAssistantError,
|
||||||
|
TemplateError,
|
||||||
|
)
|
||||||
|
from homeassistant.helpers import config_validation as cv, intent, selector, template
|
||||||
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.util import ulid
|
from homeassistant.util import ulid
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
|
@ -27,18 +38,61 @@ from .const import (
|
||||||
DEFAULT_PROMPT,
|
DEFAULT_PROMPT,
|
||||||
DEFAULT_TEMPERATURE,
|
DEFAULT_TEMPERATURE,
|
||||||
DEFAULT_TOP_P,
|
DEFAULT_TOP_P,
|
||||||
|
DOMAIN,
|
||||||
)
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
SERVICE_GENERATE_IMAGE = "generate_image"
|
||||||
|
|
||||||
|
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
|
"""Set up OpenAI Conversation."""
|
||||||
|
|
||||||
|
async def render_image(call: ServiceCall) -> ServiceResponse:
|
||||||
|
"""Render an image with dall-e."""
|
||||||
|
try:
|
||||||
|
response = await openai.Image.acreate(
|
||||||
|
api_key=hass.data[DOMAIN][call.data["config_entry"]],
|
||||||
|
prompt=call.data["prompt"],
|
||||||
|
n=1,
|
||||||
|
size=f'{call.data["size"]}x{call.data["size"]}',
|
||||||
|
)
|
||||||
|
except error.OpenAIError as err:
|
||||||
|
raise HomeAssistantError(f"Error generating image: {err}") from err
|
||||||
|
|
||||||
|
return response["data"][0]
|
||||||
|
|
||||||
|
hass.services.async_register(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_GENERATE_IMAGE,
|
||||||
|
render_image,
|
||||||
|
schema=vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Required("config_entry"): selector.ConfigEntrySelector(
|
||||||
|
{
|
||||||
|
"integration": DOMAIN,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
vol.Required("prompt"): cv.string,
|
||||||
|
vol.Optional("size", default="512"): vol.In(("256", "512", "1024")),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
supports_response=SupportsResponse.ONLY,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
"""Set up OpenAI Conversation from a config entry."""
|
"""Set up OpenAI Conversation from a config entry."""
|
||||||
openai.api_key = entry.data[CONF_API_KEY]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await hass.async_add_executor_job(
|
await hass.async_add_executor_job(
|
||||||
partial(openai.Engine.list, request_timeout=10)
|
partial(
|
||||||
|
openai.Engine.list,
|
||||||
|
api_key=entry.data[CONF_API_KEY],
|
||||||
|
request_timeout=10,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except error.AuthenticationError as err:
|
except error.AuthenticationError as err:
|
||||||
_LOGGER.error("Invalid API key: %s", err)
|
_LOGGER.error("Invalid API key: %s", err)
|
||||||
|
@ -46,13 +100,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
except error.OpenAIError as err:
|
except error.OpenAIError as err:
|
||||||
raise ConfigEntryNotReady(err) from err
|
raise ConfigEntryNotReady(err) from err
|
||||||
|
|
||||||
|
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry.data[CONF_API_KEY]
|
||||||
|
|
||||||
conversation.async_set_agent(hass, entry, OpenAIAgent(hass, entry))
|
conversation.async_set_agent(hass, entry, OpenAIAgent(hass, entry))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
"""Unload OpenAI."""
|
"""Unload OpenAI."""
|
||||||
openai.api_key = None
|
hass.data[DOMAIN].pop(entry.entry_id)
|
||||||
conversation.async_unset_agent(hass, entry)
|
conversation.async_unset_agent(hass, entry)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -106,6 +162,7 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await openai.ChatCompletion.acreate(
|
result = await openai.ChatCompletion.acreate(
|
||||||
|
api_key=self.entry.data[CONF_API_KEY],
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
|
|
22
homeassistant/components/openai_conversation/services.yaml
Normal file
22
homeassistant/components/openai_conversation/services.yaml
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
generate_image:
|
||||||
|
fields:
|
||||||
|
config_entry:
|
||||||
|
required: true
|
||||||
|
selector:
|
||||||
|
config_entry:
|
||||||
|
integration: openai_conversation
|
||||||
|
prompt:
|
||||||
|
required: true
|
||||||
|
selector:
|
||||||
|
text:
|
||||||
|
multiline: true
|
||||||
|
size:
|
||||||
|
required: true
|
||||||
|
example: "512"
|
||||||
|
default: "512"
|
||||||
|
selector:
|
||||||
|
select:
|
||||||
|
options:
|
||||||
|
- "256"
|
||||||
|
- "512"
|
||||||
|
- "1024"
|
|
@ -25,5 +25,26 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"services": {
|
||||||
|
"generate_image": {
|
||||||
|
"name": "Generate image",
|
||||||
|
"description": "Turn a prompt into an image",
|
||||||
|
"fields": {
|
||||||
|
"config_entry": {
|
||||||
|
"name": "Config Entry",
|
||||||
|
"description": "The config entry to use for this service"
|
||||||
|
},
|
||||||
|
"prompt": {
|
||||||
|
"name": "Prompt",
|
||||||
|
"description": "The text to turn into an image",
|
||||||
|
"example": "A photo of a dog"
|
||||||
|
},
|
||||||
|
"size": {
|
||||||
|
"name": "Size",
|
||||||
|
"description": "The size of the image to generate"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -539,7 +539,7 @@ class ConversationAgentSelectorConfig(TypedDict, total=False):
|
||||||
|
|
||||||
|
|
||||||
@SELECTORS.register("conversation_agent")
|
@SELECTORS.register("conversation_agent")
|
||||||
class COnversationAgentSelector(Selector[ConversationAgentSelectorConfig]):
|
class ConversationAgentSelector(Selector[ConversationAgentSelectorConfig]):
|
||||||
"""Selector for a conversation agent."""
|
"""Selector for a conversation agent."""
|
||||||
|
|
||||||
selector_type = "conversation_agent"
|
selector_type = "conversation_agent"
|
||||||
|
|
|
@ -2,10 +2,12 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from openai import error
|
from openai import error
|
||||||
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
|
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
@ -158,3 +160,69 @@ async def test_conversation_agent(
|
||||||
mock_config_entry.entry_id
|
mock_config_entry.entry_id
|
||||||
)
|
)
|
||||||
assert agent.supported_languages == "*"
|
assert agent.supported_languages == "*"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("service_data", "expected_args"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
{"prompt": "Picture of a dog"},
|
||||||
|
{"prompt": "Picture of a dog", "size": "512x512"},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"prompt": "Picture of a dog", "size": "256"},
|
||||||
|
{"prompt": "Picture of a dog", "size": "256x256"},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"prompt": "Picture of a dog", "size": "1024"},
|
||||||
|
{"prompt": "Picture of a dog", "size": "1024x1024"},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_generate_image_service(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
service_data,
|
||||||
|
expected_args,
|
||||||
|
) -> None:
|
||||||
|
"""Test generate image service."""
|
||||||
|
service_data["config_entry"] = mock_config_entry.entry_id
|
||||||
|
expected_args["api_key"] = mock_config_entry.data["api_key"]
|
||||||
|
expected_args["n"] = 1
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"openai.Image.acreate", return_value={"data": [{"url": "A"}]}
|
||||||
|
) as mock_create:
|
||||||
|
response = await hass.services.async_call(
|
||||||
|
"openai_conversation",
|
||||||
|
"generate_image",
|
||||||
|
service_data,
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response == {"url": "A"}
|
||||||
|
assert len(mock_create.mock_calls) == 1
|
||||||
|
assert mock_create.mock_calls[0][2] == expected_args
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
|
async def test_generate_image_service_error(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
) -> None:
|
||||||
|
"""Test generate image service handles errors."""
|
||||||
|
with patch(
|
||||||
|
"openai.Image.acreate", side_effect=error.ServiceUnavailableError("Reason")
|
||||||
|
), pytest.raises(HomeAssistantError, match="Error generating image: Reason"):
|
||||||
|
await hass.services.async_call(
|
||||||
|
"openai_conversation",
|
||||||
|
"generate_image",
|
||||||
|
{
|
||||||
|
"config_entry": mock_config_entry.entry_id,
|
||||||
|
"prompt": "Image of an epic fail",
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue