Add Google Gen AI Conversation Agent Entity (#116362)

* Add Google Gen AI Conversation Agent Entity

* Rename agent to entity

* Revert ollama changes

* Don't copy service tests to conversation_test.py

* Move logger and cleanup snapshots

* Move property after init

* Set logger to use package

* Cleanup hass from constructor

* Fix merges

* Revert ollama change
This commit is contained in:
Allen Porter 2024-05-17 08:00:11 -07:00 committed by GitHub
parent fce4263493
commit caa35174cb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 548 additions and 369 deletions

View file

@ -3,55 +3,33 @@
from __future__ import annotations from __future__ import annotations
from functools import partial from functools import partial
import logging
import mimetypes import mimetypes
from pathlib import Path from pathlib import Path
from typing import Literal
from google.api_core.exceptions import ClientError from google.api_core.exceptions import ClientError
import google.generativeai as genai import google.generativeai as genai
import google.generativeai.types as genai_types import google.generativeai.types as genai_types
import voluptuous as vol import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY, MATCH_ALL from homeassistant.const import CONF_API_KEY, Platform
from homeassistant.core import ( from homeassistant.core import (
HomeAssistant, HomeAssistant,
ServiceCall, ServiceCall,
ServiceResponse, ServiceResponse,
SupportsResponse, SupportsResponse,
) )
from homeassistant.exceptions import ( from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
ConfigEntryNotReady, from homeassistant.helpers import config_validation as cv
HomeAssistantError,
TemplateError,
)
from homeassistant.helpers import config_validation as cv, intent, template
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.util import ulid
from .const import ( from .const import CONF_CHAT_MODEL, CONF_PROMPT, DEFAULT_CHAT_MODEL, DOMAIN, LOGGER
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DOMAIN,
)
_LOGGER = logging.getLogger(__name__)
SERVICE_GENERATE_CONTENT = "generate_content" SERVICE_GENERATE_CONTENT = "generate_content"
CONF_IMAGE_FILENAME = "image_filename" CONF_IMAGE_FILENAME = "image_filename"
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (Platform.CONVERSATION,)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
@ -126,118 +104,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
) )
except ClientError as err: except ClientError as err:
if err.reason == "API_KEY_INVALID": if err.reason == "API_KEY_INVALID":
_LOGGER.error("Invalid API key: %s", err) LOGGER.error("Invalid API key: %s", err)
return False return False
raise ConfigEntryNotReady(err) from err raise ConfigEntryNotReady(err) from err
conversation.async_set_agent(hass, entry, GoogleGenerativeAIAgent(hass, entry)) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
return True return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload GoogleGenerativeAI.""" """Unload GoogleGenerativeAI."""
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
return False
genai.configure(api_key=None) genai.configure(api_key=None)
conversation.async_unset_agent(hass, entry)
return True return True
class GoogleGenerativeAIAgent(conversation.AbstractConversationAgent):
"""Google Generative AI conversation agent."""
def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Initialize the agent."""
self.hass = hass
self.entry = entry
self.history: dict[str, list[genai_types.ContentType]] = {}
@property
def supported_languages(self) -> list[str] | Literal["*"]:
"""Return a list of supported languages."""
return MATCH_ALL
async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
model = genai.GenerativeModel(
model_name=self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL),
generation_config={
"temperature": self.entry.options.get(
CONF_TEMPERATURE, DEFAULT_TEMPERATURE
),
"top_p": self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P),
"top_k": self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K),
"max_output_tokens": self.entry.options.get(
CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS
),
},
)
_LOGGER.debug("Model: %s", model)
if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id
messages = self.history[conversation_id]
else:
conversation_id = ulid.ulid_now()
messages = [{}, {}]
intent_response = intent.IntentResponse(language=user_input.language)
try:
prompt = self._async_generate_prompt(raw_prompt)
except TemplateError as err:
_LOGGER.error("Error rendering prompt: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem with my template: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
messages[0] = {"role": "user", "parts": prompt}
messages[1] = {"role": "model", "parts": "Ok"}
_LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
chat = model.start_chat(history=messages)
try:
chat_response = await chat.send_message_async(user_input.text)
except (
ClientError,
ValueError,
genai_types.BlockedPromptException,
genai_types.StopCandidateException,
) as err:
_LOGGER.error("Error sending message: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to Google Generative AI: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
_LOGGER.debug("Response: %s", chat_response.parts)
if not chat_response.parts:
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"Sorry, I had a problem talking to Google Generative AI. Likely blocked",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
self.history[conversation_id] = chat.history
intent_response.async_set_speech(chat_response.text)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
def _async_generate_prompt(self, raw_prompt: str) -> str:
"""Generate a prompt for the user."""
return template.Template(raw_prompt, self.hass).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
)

View file

@ -1,6 +1,9 @@
"""Constants for the Google Generative AI Conversation integration.""" """Constants for the Google Generative AI Conversation integration."""
import logging
DOMAIN = "google_generative_ai_conversation" DOMAIN = "google_generative_ai_conversation"
LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt" CONF_PROMPT = "prompt"
DEFAULT_PROMPT = """This smart home is controlled by Home Assistant. DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.

View file

@ -0,0 +1,164 @@
"""Conversation support for the Google Generative AI Conversation integration."""
from __future__ import annotations
from typing import Literal
from google.api_core.exceptions import ClientError
import google.generativeai as genai
import google.generativeai.types as genai_types
from homeassistant.components import assist_pipeline, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import intent, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid
from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
LOGGER,
)
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up conversation entities."""
agent = GoogleGenerativeAIConversationEntity(config_entry)
async_add_entities([agent])
class GoogleGenerativeAIConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
"""Google Generative AI conversation agent."""
_attr_has_entity_name = True
def __init__(self, entry: ConfigEntry) -> None:
"""Initialize the agent."""
self.entry = entry
self.history: dict[str, list[genai_types.ContentType]] = {}
self._attr_name = entry.title
self._attr_unique_id = entry.entry_id
@property
def supported_languages(self) -> list[str] | Literal["*"]:
"""Return a list of supported languages."""
return MATCH_ALL
async def async_added_to_hass(self) -> None:
"""When entity is added to Home Assistant."""
await super().async_added_to_hass()
assist_pipeline.async_migrate_engine(
self.hass, "conversation", self.entry.entry_id, self.entity_id
)
conversation.async_set_agent(self.hass, self.entry, self)
async def async_will_remove_from_hass(self) -> None:
"""When entity will be removed from Home Assistant."""
conversation.async_unset_agent(self.hass, self.entry)
await super().async_will_remove_from_hass()
async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
model = genai.GenerativeModel(
model_name=self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL),
generation_config={
"temperature": self.entry.options.get(
CONF_TEMPERATURE, DEFAULT_TEMPERATURE
),
"top_p": self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P),
"top_k": self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K),
"max_output_tokens": self.entry.options.get(
CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS
),
},
)
LOGGER.debug("Model: %s", model)
if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id
messages = self.history[conversation_id]
else:
conversation_id = ulid.ulid_now()
messages = [{}, {}]
intent_response = intent.IntentResponse(language=user_input.language)
try:
prompt = self._async_generate_prompt(raw_prompt)
except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem with my template: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
messages[0] = {"role": "user", "parts": prompt}
messages[1] = {"role": "model", "parts": "Ok"}
LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
chat = model.start_chat(history=messages)
try:
chat_response = await chat.send_message_async(user_input.text)
except (
ClientError,
ValueError,
genai_types.BlockedPromptException,
genai_types.StopCandidateException,
) as err:
LOGGER.error("Error sending message: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to Google Generative AI: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
LOGGER.debug("Response: %s", chat_response.parts)
if not chat_response.parts:
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"Sorry, I had a problem talking to Google Generative AI. Likely blocked",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
self.history[conversation_id] = chat.history
intent_response.async_set_speech(chat_response.text)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
def _async_generate_prompt(self, raw_prompt: str) -> str:
"""Generate a prompt for the user."""
return template.Template(raw_prompt, self.hass).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
)

View file

@ -1,6 +1,7 @@
{ {
"domain": "google_generative_ai_conversation", "domain": "google_generative_ai_conversation",
"name": "Google Generative AI Conversation", "name": "Google Generative AI Conversation",
"after_dependencies": ["assist_pipeline"],
"codeowners": ["@tronikos"], "codeowners": ["@tronikos"],
"config_flow": true, "config_flow": true,
"dependencies": ["conversation"], "dependencies": ["conversation"],

View file

@ -16,6 +16,7 @@ def mock_config_entry(hass):
"""Mock a config entry.""" """Mock a config entry."""
entry = MockConfigEntry( entry = MockConfigEntry(
domain="google_generative_ai_conversation", domain="google_generative_ai_conversation",
title="Google Generative AI Conversation",
data={ data={
"api_key": "bla", "api_key": "bla",
}, },

View file

@ -0,0 +1,169 @@
# serializer version: 1
# name: test_default_prompt[None]
list([
tuple(
'',
tuple(
),
dict({
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 0.9,
'top_k': 1,
'top_p': 1.0,
}),
'model_name': 'models/gemini-pro',
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'history': list([
dict({
'parts': '''
This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 4
- 1 (3)
Answer the user's questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}),
]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'hello',
),
dict({
}),
),
])
# ---
# name: test_default_prompt[conversation.google_generative_ai_conversation]
list([
tuple(
'',
tuple(
),
dict({
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 0.9,
'top_k': 1,
'top_p': 1.0,
}),
'model_name': 'models/gemini-pro',
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'history': list([
dict({
'parts': '''
This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 4
- 1 (3)
Answer the user's questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}),
]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'hello',
),
dict({
}),
),
])
# ---
# name: test_generate_content_service_with_image
list([
tuple(
'',
tuple(
),
dict({
'model_name': 'gemini-pro-vision',
}),
),
tuple(
'().generate_content_async',
tuple(
list([
'Describe this image from my doorbell camera',
dict({
'data': b'image bytes',
'mime_type': 'image/jpeg',
}),
]),
),
dict({
}),
),
])
# ---
# name: test_generate_content_service_without_images
list([
tuple(
'',
tuple(
),
dict({
'model_name': 'gemini-pro',
}),
),
tuple(
'().generate_content_async',
tuple(
list([
'Write an opening speech for a Home Assistant release party',
]),
),
dict({
}),
),
])
# ---

View file

@ -1,64 +1,4 @@
# serializer version: 1 # serializer version: 1
# name: test_default_prompt
list([
tuple(
'',
tuple(
),
dict({
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 0.9,
'top_k': 1,
'top_p': 1.0,
}),
'model_name': 'models/gemini-pro',
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'history': list([
dict({
'parts': '''
This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 4
- 1 (3)
Answer the user's questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}),
]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'hello',
),
dict({
}),
),
])
# ---
# name: test_generate_content_service_with_image # name: test_generate_content_service_with_image
list([ list([
tuple( tuple(

View file

@ -0,0 +1,198 @@
"""Tests for the Google Generative AI Conversation integration conversation platform."""
from unittest.mock import AsyncMock, MagicMock, patch
from google.api_core.exceptions import ClientError
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
from tests.common import MockConfigEntry
@pytest.mark.parametrize(
"agent_id", [None, "conversation.google_generative_ai_conversation"]
)
async def test_default_prompt(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion,
agent_id: str | None,
) -> None:
"""Test that the default prompt works."""
entry = MockConfigEntry(title=None)
entry.add_to_hass(hass)
for i in range(3):
area_registry.async_create(f"{i}Empty Area")
if agent_id is None:
agent_id = mock_config_entry.entry_id
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "1234")},
name="Test Device",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
)
for i in range(3):
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", f"{i}abcd")},
name="Test Service",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
entry_type=dr.DeviceEntryType.SERVICE,
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "5678")},
name="Test Device 2",
manufacturer="Test Manufacturer 2",
model="Device 2",
suggested_area="Test Area 2",
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "qwer")},
name="Test Device 4",
suggested_area="Test Area 2",
)
device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-disabled")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_registry.async_update_device(
device.id, disabled_by=dr.DeviceEntryDisabler.USER
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-no-name")},
manufacturer="Test Manufacturer NoName",
model="Test Model NoName",
suggested_area="Test Area 2",
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-integer-values")},
name=1,
manufacturer=2,
model=3,
suggested_area="Test Area 2",
)
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
chat_response.parts = ["Hi there!"]
chat_response.text = "Hi there!"
result = await conversation.async_converse(
hass,
"hello",
None,
Context(),
agent_id=agent_id,
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
async def test_error_handling(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
"""Test that client errors are caught."""
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
mock_chat.send_message_async.side_effect = ClientError("some error")
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"Sorry, I had a problem talking to Google Generative AI: None some error"
)
async def test_blocked_response(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
"""Test response was blocked."""
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
chat_response.parts = []
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"Sorry, I had a problem talking to Google Generative AI. Likely blocked"
)
async def test_template_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test that template error handling works."""
hass.config_entries.async_update_entry(
mock_config_entry,
options={
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
},
)
with (
patch(
"google.generativeai.get_model",
),
patch("google.generativeai.GenerativeModel"),
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
async def test_conversation_agent(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test GoogleGenerativeAIAgent."""
agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id
)
assert agent.supported_languages == "*"

View file

@ -6,188 +6,12 @@ from google.api_core.exceptions import ClientError
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation from homeassistant.core import HomeAssistant
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
async def test_default_prompt(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion,
) -> None:
"""Test that the default prompt works."""
entry = MockConfigEntry(title=None)
entry.add_to_hass(hass)
for i in range(3):
area_registry.async_create(f"{i}Empty Area")
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "1234")},
name="Test Device",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
)
for i in range(3):
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", f"{i}abcd")},
name="Test Service",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
entry_type=dr.DeviceEntryType.SERVICE,
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "5678")},
name="Test Device 2",
manufacturer="Test Manufacturer 2",
model="Device 2",
suggested_area="Test Area 2",
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "qwer")},
name="Test Device 4",
suggested_area="Test Area 2",
)
device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-disabled")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_registry.async_update_device(
device.id, disabled_by=dr.DeviceEntryDisabler.USER
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-no-name")},
manufacturer="Test Manufacturer NoName",
model="Test Model NoName",
suggested_area="Test Area 2",
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-integer-values")},
name=1,
manufacturer=2,
model=3,
suggested_area="Test Area 2",
)
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
chat_response.parts = ["Hi there!"]
chat_response.text = "Hi there!"
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
async def test_error_handling(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
"""Test that client errors are caught."""
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
mock_chat.send_message_async.side_effect = ClientError("some error")
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"Sorry, I had a problem talking to Google Generative AI: None some error"
)
async def test_blocked_response(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
"""Test response was blocked."""
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
chat_response.parts = []
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"Sorry, I had a problem talking to Google Generative AI. Likely blocked"
)
async def test_template_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test that template error handling works."""
hass.config_entries.async_update_entry(
mock_config_entry,
options={
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
},
)
with (
patch(
"google.generativeai.get_model",
),
patch("google.generativeai.GenerativeModel"),
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
async def test_conversation_agent(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test GoogleGenerativeAIAgent."""
agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id
)
assert agent.supported_languages == "*"
async def test_generate_content_service_without_images( async def test_generate_content_service_without_images(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,