diff --git a/homeassistant/components/google_generative_ai_conversation/__init__.py b/homeassistant/components/google_generative_ai_conversation/__init__.py
index 96be366a658..d4a6c5bfa69 100644
--- a/homeassistant/components/google_generative_ai_conversation/__init__.py
+++ b/homeassistant/components/google_generative_ai_conversation/__init__.py
@@ -3,55 +3,33 @@
 from __future__ import annotations
 
 from functools import partial
-import logging
 import mimetypes
 from pathlib import Path
-from typing import Literal
 
 from google.api_core.exceptions import ClientError
 import google.generativeai as genai
 import google.generativeai.types as genai_types
 import voluptuous as vol
 
-from homeassistant.components import conversation
 from homeassistant.config_entries import ConfigEntry
-from homeassistant.const import CONF_API_KEY, MATCH_ALL
+from homeassistant.const import CONF_API_KEY, Platform
 from homeassistant.core import (
     HomeAssistant,
     ServiceCall,
     ServiceResponse,
     SupportsResponse,
 )
-from homeassistant.exceptions import (
-    ConfigEntryNotReady,
-    HomeAssistantError,
-    TemplateError,
-)
-from homeassistant.helpers import config_validation as cv, intent, template
+from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
+from homeassistant.helpers import config_validation as cv
 from homeassistant.helpers.typing import ConfigType
-from homeassistant.util import ulid
 
-from .const import (
-    CONF_CHAT_MODEL,
-    CONF_MAX_TOKENS,
-    CONF_PROMPT,
-    CONF_TEMPERATURE,
-    CONF_TOP_K,
-    CONF_TOP_P,
-    DEFAULT_CHAT_MODEL,
-    DEFAULT_MAX_TOKENS,
-    DEFAULT_PROMPT,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_K,
-    DEFAULT_TOP_P,
-    DOMAIN,
-)
+from .const import CONF_CHAT_MODEL, CONF_PROMPT, DEFAULT_CHAT_MODEL, DOMAIN, LOGGER
 
-_LOGGER = logging.getLogger(__name__)
 SERVICE_GENERATE_CONTENT = "generate_content"
 CONF_IMAGE_FILENAME = "image_filename"
 
 CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
+PLATFORMS = (Platform.CONVERSATION,)
 
 
 async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
@@ -126,118 +104,19 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
         )
     except ClientError as err:
         if err.reason == "API_KEY_INVALID":
-            _LOGGER.error("Invalid API key: %s", err)
+            LOGGER.error("Invalid API key: %s", err)
             return False
         raise ConfigEntryNotReady(err) from err
 
-    conversation.async_set_agent(hass, entry, GoogleGenerativeAIAgent(hass, entry))
+    await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
+
     return True
 
 
 async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
     """Unload GoogleGenerativeAI."""
+    if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
+        return False
+
     genai.configure(api_key=None)
-    conversation.async_unset_agent(hass, entry)
     return True
-
-
-class GoogleGenerativeAIAgent(conversation.AbstractConversationAgent):
-    """Google Generative AI conversation agent."""
-
-    def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
-        """Initialize the agent."""
-        self.hass = hass
-        self.entry = entry
-        self.history: dict[str, list[genai_types.ContentType]] = {}
-
-    @property
-    def supported_languages(self) -> list[str] | Literal["*"]:
-        """Return a list of supported languages."""
-        return MATCH_ALL
-
-    async def async_process(
-        self, user_input: conversation.ConversationInput
-    ) -> conversation.ConversationResult:
-        """Process a sentence."""
-        raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
-        model = genai.GenerativeModel(
-            model_name=self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL),
-            generation_config={
-                "temperature": self.entry.options.get(
-                    CONF_TEMPERATURE, DEFAULT_TEMPERATURE
-                ),
-                "top_p": self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P),
-                "top_k": self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K),
-                "max_output_tokens": self.entry.options.get(
-                    CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS
-                ),
-            },
-        )
-        _LOGGER.debug("Model: %s", model)
-
-        if user_input.conversation_id in self.history:
-            conversation_id = user_input.conversation_id
-            messages = self.history[conversation_id]
-        else:
-            conversation_id = ulid.ulid_now()
-            messages = [{}, {}]
-
-        intent_response = intent.IntentResponse(language=user_input.language)
-        try:
-            prompt = self._async_generate_prompt(raw_prompt)
-        except TemplateError as err:
-            _LOGGER.error("Error rendering prompt: %s", err)
-            intent_response.async_set_error(
-                intent.IntentResponseErrorCode.UNKNOWN,
-                f"Sorry, I had a problem with my template: {err}",
-            )
-            return conversation.ConversationResult(
-                response=intent_response, conversation_id=conversation_id
-            )
-
-        messages[0] = {"role": "user", "parts": prompt}
-        messages[1] = {"role": "model", "parts": "Ok"}
-
-        _LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
-
-        chat = model.start_chat(history=messages)
-        try:
-            chat_response = await chat.send_message_async(user_input.text)
-        except (
-            ClientError,
-            ValueError,
-            genai_types.BlockedPromptException,
-            genai_types.StopCandidateException,
-        ) as err:
-            _LOGGER.error("Error sending message: %s", err)
-            intent_response.async_set_error(
-                intent.IntentResponseErrorCode.UNKNOWN,
-                f"Sorry, I had a problem talking to Google Generative AI: {err}",
-            )
-            return conversation.ConversationResult(
-                response=intent_response, conversation_id=conversation_id
-            )
-
-        _LOGGER.debug("Response: %s", chat_response.parts)
-        if not chat_response.parts:
-            intent_response.async_set_error(
-                intent.IntentResponseErrorCode.UNKNOWN,
-                "Sorry, I had a problem talking to Google Generative AI. Likely blocked",
-            )
-            return conversation.ConversationResult(
-                response=intent_response, conversation_id=conversation_id
-            )
-        self.history[conversation_id] = chat.history
-        intent_response.async_set_speech(chat_response.text)
-        return conversation.ConversationResult(
-            response=intent_response, conversation_id=conversation_id
-        )
-
-    def _async_generate_prompt(self, raw_prompt: str) -> str:
-        """Generate a prompt for the user."""
-        return template.Template(raw_prompt, self.hass).async_render(
-            {
-                "ha_name": self.hass.config.location_name,
-            },
-            parse_result=False,
-        )
diff --git a/homeassistant/components/google_generative_ai_conversation/const.py b/homeassistant/components/google_generative_ai_conversation/const.py
index 2798b85f308..f7e71989efd 100644
--- a/homeassistant/components/google_generative_ai_conversation/const.py
+++ b/homeassistant/components/google_generative_ai_conversation/const.py
@@ -1,6 +1,9 @@
 """Constants for the Google Generative AI Conversation integration."""
 
+import logging
+
 DOMAIN = "google_generative_ai_conversation"
+LOGGER = logging.getLogger(__package__)
 CONF_PROMPT = "prompt"
 DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
 
diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py
new file mode 100644
index 00000000000..90a3104f662
--- /dev/null
+++ b/homeassistant/components/google_generative_ai_conversation/conversation.py
@@ -0,0 +1,164 @@
+"""Conversation support for the Google Generative AI Conversation integration."""
+
+from __future__ import annotations
+
+from typing import Literal
+
+from google.api_core.exceptions import ClientError
+import google.generativeai as genai
+import google.generativeai.types as genai_types
+
+from homeassistant.components import assist_pipeline, conversation
+from homeassistant.config_entries import ConfigEntry
+from homeassistant.const import MATCH_ALL
+from homeassistant.core import HomeAssistant
+from homeassistant.exceptions import TemplateError
+from homeassistant.helpers import intent, template
+from homeassistant.helpers.entity_platform import AddEntitiesCallback
+from homeassistant.util import ulid
+
+from .const import (
+    CONF_CHAT_MODEL,
+    CONF_MAX_TOKENS,
+    CONF_PROMPT,
+    CONF_TEMPERATURE,
+    CONF_TOP_K,
+    CONF_TOP_P,
+    DEFAULT_CHAT_MODEL,
+    DEFAULT_MAX_TOKENS,
+    DEFAULT_PROMPT,
+    DEFAULT_TEMPERATURE,
+    DEFAULT_TOP_K,
+    DEFAULT_TOP_P,
+    LOGGER,
+)
+
+
+async def async_setup_entry(
+    hass: HomeAssistant,
+    config_entry: ConfigEntry,
+    async_add_entities: AddEntitiesCallback,
+) -> None:
+    """Set up conversation entities."""
+    agent = GoogleGenerativeAIConversationEntity(config_entry)
+    async_add_entities([agent])
+
+
+class GoogleGenerativeAIConversationEntity(
+    conversation.ConversationEntity, conversation.AbstractConversationAgent
+):
+    """Google Generative AI conversation agent."""
+
+    _attr_has_entity_name = True
+
+    def __init__(self, entry: ConfigEntry) -> None:
+        """Initialize the agent."""
+        self.entry = entry
+        self.history: dict[str, list[genai_types.ContentType]] = {}
+        self._attr_name = entry.title
+        self._attr_unique_id = entry.entry_id
+
+    @property
+    def supported_languages(self) -> list[str] | Literal["*"]:
+        """Return a list of supported languages."""
+        return MATCH_ALL
+
+    async def async_added_to_hass(self) -> None:
+        """When entity is added to Home Assistant."""
+        await super().async_added_to_hass()
+        assist_pipeline.async_migrate_engine(
+            self.hass, "conversation", self.entry.entry_id, self.entity_id
+        )
+        conversation.async_set_agent(self.hass, self.entry, self)
+
+    async def async_will_remove_from_hass(self) -> None:
+        """When entity will be removed from Home Assistant."""
+        conversation.async_unset_agent(self.hass, self.entry)
+        await super().async_will_remove_from_hass()
+
+    async def async_process(
+        self, user_input: conversation.ConversationInput
+    ) -> conversation.ConversationResult:
+        """Process a sentence."""
+        raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
+        model = genai.GenerativeModel(
+            model_name=self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL),
+            generation_config={
+                "temperature": self.entry.options.get(
+                    CONF_TEMPERATURE, DEFAULT_TEMPERATURE
+                ),
+                "top_p": self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P),
+                "top_k": self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K),
+                "max_output_tokens": self.entry.options.get(
+                    CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS
+                ),
+            },
+        )
+        LOGGER.debug("Model: %s", model)
+
+        if user_input.conversation_id in self.history:
+            conversation_id = user_input.conversation_id
+            messages = self.history[conversation_id]
+        else:
+            conversation_id = ulid.ulid_now()
+            messages = [{}, {}]
+
+        intent_response = intent.IntentResponse(language=user_input.language)
+        try:
+            prompt = self._async_generate_prompt(raw_prompt)
+        except TemplateError as err:
+            LOGGER.error("Error rendering prompt: %s", err)
+            intent_response.async_set_error(
+                intent.IntentResponseErrorCode.UNKNOWN,
+                f"Sorry, I had a problem with my template: {err}",
+            )
+            return conversation.ConversationResult(
+                response=intent_response, conversation_id=conversation_id
+            )
+
+        messages[0] = {"role": "user", "parts": prompt}
+        messages[1] = {"role": "model", "parts": "Ok"}
+
+        LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
+
+        chat = model.start_chat(history=messages)
+        try:
+            chat_response = await chat.send_message_async(user_input.text)
+        except (
+            ClientError,
+            ValueError,
+            genai_types.BlockedPromptException,
+            genai_types.StopCandidateException,
+        ) as err:
+            LOGGER.error("Error sending message: %s", err)
+            intent_response.async_set_error(
+                intent.IntentResponseErrorCode.UNKNOWN,
+                f"Sorry, I had a problem talking to Google Generative AI: {err}",
+            )
+            return conversation.ConversationResult(
+                response=intent_response, conversation_id=conversation_id
+            )
+
+        LOGGER.debug("Response: %s", chat_response.parts)
+        if not chat_response.parts:
+            intent_response.async_set_error(
+                intent.IntentResponseErrorCode.UNKNOWN,
+                "Sorry, I had a problem talking to Google Generative AI. Likely blocked",
+            )
+            return conversation.ConversationResult(
+                response=intent_response, conversation_id=conversation_id
+            )
+        self.history[conversation_id] = chat.history
+        intent_response.async_set_speech(chat_response.text)
+        return conversation.ConversationResult(
+            response=intent_response, conversation_id=conversation_id
+        )
+
+    def _async_generate_prompt(self, raw_prompt: str) -> str:
+        """Generate a prompt for the user."""
+        return template.Template(raw_prompt, self.hass).async_render(
+            {
+                "ha_name": self.hass.config.location_name,
+            },
+            parse_result=False,
+        )
diff --git a/homeassistant/components/google_generative_ai_conversation/manifest.json b/homeassistant/components/google_generative_ai_conversation/manifest.json
index fd2b7c26323..b4f577db0d0 100644
--- a/homeassistant/components/google_generative_ai_conversation/manifest.json
+++ b/homeassistant/components/google_generative_ai_conversation/manifest.json
@@ -1,6 +1,7 @@
 {
   "domain": "google_generative_ai_conversation",
   "name": "Google Generative AI Conversation",
+  "after_dependencies": ["assist_pipeline"],
   "codeowners": ["@tronikos"],
   "config_flow": true,
   "dependencies": ["conversation"],
diff --git a/tests/components/google_generative_ai_conversation/conftest.py b/tests/components/google_generative_ai_conversation/conftest.py
index c377a469df0..d5b4e8672e3 100644
--- a/tests/components/google_generative_ai_conversation/conftest.py
+++ b/tests/components/google_generative_ai_conversation/conftest.py
@@ -16,6 +16,7 @@ def mock_config_entry(hass):
     """Mock a config entry."""
     entry = MockConfigEntry(
         domain="google_generative_ai_conversation",
+        title="Google Generative AI Conversation",
         data={
             "api_key": "bla",
         },
diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr
new file mode 100644
index 00000000000..bf37fe0f2d9
--- /dev/null
+++ b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr
@@ -0,0 +1,169 @@
+# serializer version: 1
+# name: test_default_prompt[None]
+  list([
+    tuple(
+      '',
+      tuple(
+      ),
+      dict({
+        'generation_config': dict({
+          'max_output_tokens': 150,
+          'temperature': 0.9,
+          'top_k': 1,
+          'top_p': 1.0,
+        }),
+        'model_name': 'models/gemini-pro',
+      }),
+    ),
+    tuple(
+      '().start_chat',
+      tuple(
+      ),
+      dict({
+        'history': list([
+          dict({
+            'parts': '''
+              This smart home is controlled by Home Assistant.
+
+              An overview of the areas and the devices in this smart home:
+
+              Test Area:
+              - Test Device (Test Model)
+
+              Test Area 2:
+              - Test Device 2
+              - Test Device 3 (Test Model 3A)
+              - Test Device 4
+              - 1 (3)
+
+              Answer the user's questions about the world truthfully.
+
+              If the user wants to control a device, reject the request and suggest using the Home Assistant app.
+            ''',
+            'role': 'user',
+          }),
+          dict({
+            'parts': 'Ok',
+            'role': 'model',
+          }),
+        ]),
+      }),
+    ),
+    tuple(
+      '().start_chat().send_message_async',
+      tuple(
+        'hello',
+      ),
+      dict({
+      }),
+    ),
+  ])
+# ---
+# name: test_default_prompt[conversation.google_generative_ai_conversation]
+  list([
+    tuple(
+      '',
+      tuple(
+      ),
+      dict({
+        'generation_config': dict({
+          'max_output_tokens': 150,
+          'temperature': 0.9,
+          'top_k': 1,
+          'top_p': 1.0,
+        }),
+        'model_name': 'models/gemini-pro',
+      }),
+    ),
+    tuple(
+      '().start_chat',
+      tuple(
+      ),
+      dict({
+        'history': list([
+          dict({
+            'parts': '''
+              This smart home is controlled by Home Assistant.
+
+              An overview of the areas and the devices in this smart home:
+
+              Test Area:
+              - Test Device (Test Model)
+
+              Test Area 2:
+              - Test Device 2
+              - Test Device 3 (Test Model 3A)
+              - Test Device 4
+              - 1 (3)
+
+              Answer the user's questions about the world truthfully.
+
+              If the user wants to control a device, reject the request and suggest using the Home Assistant app.
+            ''',
+            'role': 'user',
+          }),
+          dict({
+            'parts': 'Ok',
+            'role': 'model',
+          }),
+        ]),
+      }),
+    ),
+    tuple(
+      '().start_chat().send_message_async',
+      tuple(
+        'hello',
+      ),
+      dict({
+      }),
+    ),
+  ])
+# ---
+# name: test_generate_content_service_with_image
+  list([
+    tuple(
+      '',
+      tuple(
+      ),
+      dict({
+        'model_name': 'gemini-pro-vision',
+      }),
+    ),
+    tuple(
+      '().generate_content_async',
+      tuple(
+        list([
+          'Describe this image from my doorbell camera',
+          dict({
+            'data': b'image bytes',
+            'mime_type': 'image/jpeg',
+          }),
+        ]),
+      ),
+      dict({
+      }),
+    ),
+  ])
+# ---
+# name: test_generate_content_service_without_images
+  list([
+    tuple(
+      '',
+      tuple(
+      ),
+      dict({
+        'model_name': 'gemini-pro',
+      }),
+    ),
+    tuple(
+      '().generate_content_async',
+      tuple(
+        list([
+          'Write an opening speech for a Home Assistant release party',
+        ]),
+      ),
+      dict({
+      }),
+    ),
+  ])
+# ---
diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr
index 5347c010f28..aba3f35eb19 100644
--- a/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr
+++ b/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr
@@ -1,64 +1,4 @@
 # serializer version: 1
-# name: test_default_prompt
-  list([
-    tuple(
-      '',
-      tuple(
-      ),
-      dict({
-        'generation_config': dict({
-          'max_output_tokens': 150,
-          'temperature': 0.9,
-          'top_k': 1,
-          'top_p': 1.0,
-        }),
-        'model_name': 'models/gemini-pro',
-      }),
-    ),
-    tuple(
-      '().start_chat',
-      tuple(
-      ),
-      dict({
-        'history': list([
-          dict({
-            'parts': '''
-              This smart home is controlled by Home Assistant.
-
-              An overview of the areas and the devices in this smart home:
-
-              Test Area:
-              - Test Device (Test Model)
-
-              Test Area 2:
-              - Test Device 2
-              - Test Device 3 (Test Model 3A)
-              - Test Device 4
-              - 1 (3)
-
-              Answer the user's questions about the world truthfully.
-
-              If the user wants to control a device, reject the request and suggest using the Home Assistant app.
-            ''',
-            'role': 'user',
-          }),
-          dict({
-            'parts': 'Ok',
-            'role': 'model',
-          }),
-        ]),
-      }),
-    ),
-    tuple(
-      '().start_chat().send_message_async',
-      tuple(
-        'hello',
-      ),
-      dict({
-      }),
-    ),
-  ])
-# ---
 # name: test_generate_content_service_with_image
   list([
     tuple(
diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py
new file mode 100644
index 00000000000..e56838c4b31
--- /dev/null
+++ b/tests/components/google_generative_ai_conversation/test_conversation.py
@@ -0,0 +1,198 @@
+"""Tests for the Google Generative AI Conversation integration conversation platform."""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from google.api_core.exceptions import ClientError
+import pytest
+from syrupy.assertion import SnapshotAssertion
+
+from homeassistant.components import conversation
+from homeassistant.core import Context, HomeAssistant
+from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
+
+from tests.common import MockConfigEntry
+
+
+@pytest.mark.parametrize(
+    "agent_id", [None, "conversation.google_generative_ai_conversation"]
+)
+async def test_default_prompt(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    mock_init_component,
+    area_registry: ar.AreaRegistry,
+    device_registry: dr.DeviceRegistry,
+    snapshot: SnapshotAssertion,
+    agent_id: str | None,
+) -> None:
+    """Test that the default prompt works."""
+    entry = MockConfigEntry(title=None)
+    entry.add_to_hass(hass)
+    for i in range(3):
+        area_registry.async_create(f"{i}Empty Area")
+
+    if agent_id is None:
+        agent_id = mock_config_entry.entry_id
+
+    device_registry.async_get_or_create(
+        config_entry_id=entry.entry_id,
+        connections={("test", "1234")},
+        name="Test Device",
+        manufacturer="Test Manufacturer",
+        model="Test Model",
+        suggested_area="Test Area",
+    )
+    for i in range(3):
+        device_registry.async_get_or_create(
+            config_entry_id=entry.entry_id,
+            connections={("test", f"{i}abcd")},
+            name="Test Service",
+            manufacturer="Test Manufacturer",
+            model="Test Model",
+            suggested_area="Test Area",
+            entry_type=dr.DeviceEntryType.SERVICE,
+        )
+    device_registry.async_get_or_create(
+        config_entry_id=entry.entry_id,
+        connections={("test", "5678")},
+        name="Test Device 2",
+        manufacturer="Test Manufacturer 2",
+        model="Device 2",
+        suggested_area="Test Area 2",
+    )
+    device_registry.async_get_or_create(
+        config_entry_id=entry.entry_id,
+        connections={("test", "9876")},
+        name="Test Device 3",
+        manufacturer="Test Manufacturer 3",
+        model="Test Model 3A",
+        suggested_area="Test Area 2",
+    )
+    device_registry.async_get_or_create(
+        config_entry_id=entry.entry_id,
+        connections={("test", "qwer")},
+        name="Test Device 4",
+        suggested_area="Test Area 2",
+    )
+    device = device_registry.async_get_or_create(
+        config_entry_id=entry.entry_id,
+        connections={("test", "9876-disabled")},
+        name="Test Device 3",
+        manufacturer="Test Manufacturer 3",
+        model="Test Model 3A",
+        suggested_area="Test Area 2",
+    )
+    device_registry.async_update_device(
+        device.id, disabled_by=dr.DeviceEntryDisabler.USER
+    )
+    device_registry.async_get_or_create(
+        config_entry_id=entry.entry_id,
+        connections={("test", "9876-no-name")},
+        manufacturer="Test Manufacturer NoName",
+        model="Test Model NoName",
+        suggested_area="Test Area 2",
+    )
+    device_registry.async_get_or_create(
+        config_entry_id=entry.entry_id,
+        connections={("test", "9876-integer-values")},
+        name=1,
+        manufacturer=2,
+        model=3,
+        suggested_area="Test Area 2",
+    )
+    with patch("google.generativeai.GenerativeModel") as mock_model:
+        mock_chat = AsyncMock()
+        mock_model.return_value.start_chat.return_value = mock_chat
+        chat_response = MagicMock()
+        mock_chat.send_message_async.return_value = chat_response
+        chat_response.parts = ["Hi there!"]
+        chat_response.text = "Hi there!"
+        result = await conversation.async_converse(
+            hass,
+            "hello",
+            None,
+            Context(),
+            agent_id=agent_id,
+        )
+
+    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
+    assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!"
+    assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
+
+
+async def test_error_handling(
+    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
+) -> None:
+    """Test that client errors are caught."""
+    with patch("google.generativeai.GenerativeModel") as mock_model:
+        mock_chat = AsyncMock()
+        mock_model.return_value.start_chat.return_value = mock_chat
+        mock_chat.send_message_async.side_effect = ClientError("some error")
+        result = await conversation.async_converse(
+            hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
+        )
+
+    assert result.response.response_type == intent.IntentResponseType.ERROR, result
+    assert result.response.error_code == "unknown", result
+    assert result.response.as_dict()["speech"]["plain"]["speech"] == (
+        "Sorry, I had a problem talking to Google Generative AI: None some error"
+    )
+
+
+async def test_blocked_response(
+    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
+) -> None:
+    """Test response was blocked."""
+    with patch("google.generativeai.GenerativeModel") as mock_model:
+        mock_chat = AsyncMock()
+        mock_model.return_value.start_chat.return_value = mock_chat
+        chat_response = MagicMock()
+        mock_chat.send_message_async.return_value = chat_response
+        chat_response.parts = []
+        result = await conversation.async_converse(
+            hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
+        )
+
+    assert result.response.response_type == intent.IntentResponseType.ERROR, result
+    assert result.response.error_code == "unknown", result
+    assert result.response.as_dict()["speech"]["plain"]["speech"] == (
Likely blocked" + ) + + +async def test_template_error( + hass: HomeAssistant, mock_config_entry: MockConfigEntry +) -> None: + """Test that template error handling works.""" + hass.config_entries.async_update_entry( + mock_config_entry, + options={ + "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", + }, + ) + with ( + patch( + "google.generativeai.get_model", + ), + patch("google.generativeai.GenerativeModel"), + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR, result + assert result.response.error_code == "unknown", result + + +async def test_conversation_agent( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, +) -> None: + """Test GoogleGenerativeAIAgent.""" + agent = conversation.get_agent_manager(hass).async_get_agent( + mock_config_entry.entry_id + ) + assert agent.supported_languages == "*" diff --git a/tests/components/google_generative_ai_conversation/test_init.py b/tests/components/google_generative_ai_conversation/test_init.py index bdf796b8c44..daae8582594 100644 --- a/tests/components/google_generative_ai_conversation/test_init.py +++ b/tests/components/google_generative_ai_conversation/test_init.py @@ -6,188 +6,12 @@ from google.api_core.exceptions import ClientError import pytest from syrupy.assertion import SnapshotAssertion -from homeassistant.components import conversation -from homeassistant.core import Context, HomeAssistant +from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import area_registry as ar, device_registry as dr, intent from tests.common import MockConfigEntry -async def test_default_prompt( - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, - mock_init_component, - area_registry: ar.AreaRegistry, - device_registry: dr.DeviceRegistry, - snapshot: SnapshotAssertion, -) -> None: - """Test that the default prompt works.""" - entry = MockConfigEntry(title=None) - entry.add_to_hass(hass) - for i in range(3): - area_registry.async_create(f"{i}Empty Area") - - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "1234")}, - name="Test Device", - manufacturer="Test Manufacturer", - model="Test Model", - suggested_area="Test Area", - ) - for i in range(3): - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", f"{i}abcd")}, - name="Test Service", - manufacturer="Test Manufacturer", - model="Test Model", - suggested_area="Test Area", - entry_type=dr.DeviceEntryType.SERVICE, - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "5678")}, - name="Test Device 2", - manufacturer="Test Manufacturer 2", - model="Device 2", - suggested_area="Test Area 2", - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876")}, - name="Test Device 3", - manufacturer="Test Manufacturer 3", - model="Test Model 3A", - suggested_area="Test Area 2", - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "qwer")}, - name="Test Device 4", - suggested_area="Test Area 2", - ) - device = device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - 
connections={("test", "9876-disabled")}, - name="Test Device 3", - manufacturer="Test Manufacturer 3", - model="Test Model 3A", - suggested_area="Test Area 2", - ) - device_registry.async_update_device( - device.id, disabled_by=dr.DeviceEntryDisabler.USER - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876-no-name")}, - manufacturer="Test Manufacturer NoName", - model="Test Model NoName", - suggested_area="Test Area 2", - ) - device_registry.async_get_or_create( - config_entry_id=entry.entry_id, - connections={("test", "9876-integer-values")}, - name=1, - manufacturer=2, - model=3, - suggested_area="Test Area 2", - ) - with patch("google.generativeai.GenerativeModel") as mock_model: - mock_chat = AsyncMock() - mock_model.return_value.start_chat.return_value = mock_chat - chat_response = MagicMock() - mock_chat.send_message_async.return_value = chat_response - chat_response.parts = ["Hi there!"] - chat_response.text = "Hi there!" - result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id - ) - - assert result.response.response_type == intent.IntentResponseType.ACTION_DONE - assert result.response.as_dict()["speech"]["plain"]["speech"] == "Hi there!" - assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot - - -async def test_error_handling( - hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component -) -> None: - """Test that client errors are caught.""" - with patch("google.generativeai.GenerativeModel") as mock_model: - mock_chat = AsyncMock() - mock_model.return_value.start_chat.return_value = mock_chat - mock_chat.send_message_async.side_effect = ClientError("some error") - result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id - ) - - assert result.response.response_type == intent.IntentResponseType.ERROR, result - assert result.response.error_code == "unknown", result - assert result.response.as_dict()["speech"]["plain"]["speech"] == ( - "Sorry, I had a problem talking to Google Generative AI: None some error" - ) - - -async def test_blocked_response( - hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component -) -> None: - """Test response was blocked.""" - with patch("google.generativeai.GenerativeModel") as mock_model: - mock_chat = AsyncMock() - mock_model.return_value.start_chat.return_value = mock_chat - chat_response = MagicMock() - mock_chat.send_message_async.return_value = chat_response - chat_response.parts = [] - result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id - ) - - assert result.response.response_type == intent.IntentResponseType.ERROR, result - assert result.response.error_code == "unknown", result - assert result.response.as_dict()["speech"]["plain"]["speech"] == ( - "Sorry, I had a problem talking to Google Generative AI. 
Likely blocked" - ) - - -async def test_template_error( - hass: HomeAssistant, mock_config_entry: MockConfigEntry -) -> None: - """Test that template error handling works.""" - hass.config_entries.async_update_entry( - mock_config_entry, - options={ - "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", - }, - ) - with ( - patch( - "google.generativeai.get_model", - ), - patch("google.generativeai.GenerativeModel"), - ): - await hass.config_entries.async_setup(mock_config_entry.entry_id) - await hass.async_block_till_done() - result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id - ) - - assert result.response.response_type == intent.IntentResponseType.ERROR, result - assert result.response.error_code == "unknown", result - - -async def test_conversation_agent( - hass: HomeAssistant, - mock_config_entry: MockConfigEntry, - mock_init_component, -) -> None: - """Test GoogleGenerativeAIAgent.""" - agent = conversation.get_agent_manager(hass).async_get_agent( - mock_config_entry.entry_id - ) - assert agent.supported_languages == "*" - - async def test_generate_content_service_without_images( hass: HomeAssistant, mock_config_entry: MockConfigEntry,