From ad1f0db5a46f985e995d7287bf05471d7bfdf8a7 Mon Sep 17 00:00:00 2001 From: tronikos Date: Sat, 22 Jun 2024 03:35:48 -0700 Subject: [PATCH] Pass prompt as system_instruction for Gemini 1.5 models (#120147) --- .../conversation.py | 172 ++++++++-------- homeassistant/helpers/llm.py | 1 + .../snapshots/test_conversation.ambr | 192 +++++++++++++----- .../test_conversation.py | 29 ++- 4 files changed, 253 insertions(+), 141 deletions(-) diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 65c0dc7fd93..b9f0006dbff 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -161,10 +161,14 @@ class GoogleGenerativeAIConversationEntity( self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" - intent_response = intent.IntentResponse(language=user_input.language) - llm_api: llm.APIInstance | None = None - tools: list[dict[str, Any]] | None = None - user_name: str | None = None + result = conversation.ConversationResult( + response=intent.IntentResponse(language=user_input.language), + conversation_id=user_input.conversation_id + if user_input.conversation_id in self.history + else ulid.ulid_now(), + ) + assert result.conversation_id + llm_context = llm.LLMContext( platform=DOMAIN, context=user_input.context, @@ -173,7 +177,8 @@ class GoogleGenerativeAIConversationEntity( assistant=conversation.DOMAIN, device_id=user_input.device_id, ) - + llm_api: llm.APIInstance | None = None + tools: list[dict[str, Any]] | None = None if self.entry.options.get(CONF_LLM_HASS_API): try: llm_api = await llm.async_get_api( @@ -183,17 +188,33 @@ class GoogleGenerativeAIConversationEntity( ) except HomeAssistantError as err: LOGGER.error("Error getting LLM API: %s", err) - intent_response.async_set_error( + result.response.async_set_error( intent.IntentResponseErrorCode.UNKNOWN, f"Error preparing LLM API: {err}", ) - return conversation.ConversationResult( - response=intent_response, conversation_id=user_input.conversation_id - ) + return result tools = [_format_tool(tool) for tool in llm_api.tools] + try: + prompt = await self._async_render_prompt(user_input, llm_api, llm_context) + except TemplateError as err: + LOGGER.error("Error rendering prompt: %s", err) + result.response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Sorry, I had a problem with my template: {err}", + ) + return result + + model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL) + # Gemini 1.0 doesn't support system_instruction while 1.5 does. + # Assume future versions will support it (if not, the request fails with a + # clear message at which point we can fix). + supports_system_instruction = ( + "gemini-1.0" not in model_name and "gemini-pro" not in model_name + ) + model = genai.GenerativeModel( - model_name=self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), + model_name=model_name, generation_config={ "temperature": self.entry.options.get( CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE @@ -219,69 +240,25 @@ class GoogleGenerativeAIConversationEntity( ), }, tools=tools or None, + system_instruction=prompt if supports_system_instruction else None, ) - if user_input.conversation_id in self.history: - conversation_id = user_input.conversation_id - messages = self.history[conversation_id] - else: - conversation_id = ulid.ulid_now() - messages = [{}, {"role": "model", "parts": "Ok"}] - - if ( - user_input.context - and user_input.context.user_id - and ( - user := await self.hass.auth.async_get_user(user_input.context.user_id) - ) - ): - user_name = user.name - - try: - if llm_api: - api_prompt = llm_api.api_prompt - else: - api_prompt = llm.async_render_no_api_prompt(self.hass) - - prompt = "\n".join( - ( - template.Template( - llm.BASE_PROMPT - + self.entry.options.get( - CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT - ), - self.hass, - ).async_render( - { - "ha_name": self.hass.config.location_name, - "user_name": user_name, - "llm_context": llm_context, - }, - parse_result=False, - ), - api_prompt, - ) - ) - - except TemplateError as err: - LOGGER.error("Error rendering prompt: %s", err) - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - f"Sorry, I had a problem with my template: {err}", - ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id - ) - - # Make a copy, because we attach it to the trace event. - messages = [ - {"role": "user", "parts": prompt}, - *messages[1:], - ] + messages = self.history.get(result.conversation_id, []) + if not supports_system_instruction: + if not messages: + messages = [{}, {"role": "model", "parts": "Ok"}] + messages[0] = {"role": "user", "parts": prompt} LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages) trace.async_conversation_trace_append( - trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages} + trace.ConversationTraceEventType.AGENT_DETAIL, + { + # Make a copy to attach it to the trace event. + "messages": messages[:] + if supports_system_instruction + else messages[2:], + "prompt": prompt, + }, ) chat = model.start_chat(history=messages) @@ -307,24 +284,20 @@ class GoogleGenerativeAIConversationEntity( f"Sorry, I had a problem talking to Google Generative AI: {err}" ) - intent_response.async_set_error( + result.response.async_set_error( intent.IntentResponseErrorCode.UNKNOWN, error, ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id - ) + return result LOGGER.debug("Response: %s", chat_response.parts) if not chat_response.parts: - intent_response.async_set_error( + result.response.async_set_error( intent.IntentResponseErrorCode.UNKNOWN, "Sorry, I had a problem getting a response from Google Generative AI.", ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id - ) - self.history[conversation_id] = chat.history + return result + self.history[result.conversation_id] = chat.history function_calls = [ part.function_call for part in chat_response.parts if part.function_call ] @@ -355,9 +328,48 @@ class GoogleGenerativeAIConversationEntity( ) chat_request = protos.Content(parts=tool_responses) - intent_response.async_set_speech( + result.response.async_set_speech( " ".join([part.text.strip() for part in chat_response.parts if part.text]) ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id + return result + + async def _async_render_prompt( + self, + user_input: conversation.ConversationInput, + llm_api: llm.APIInstance | None, + llm_context: llm.LLMContext, + ) -> str: + user_name: str | None = None + if ( + user_input.context + and user_input.context.user_id + and ( + user := await self.hass.auth.async_get_user(user_input.context.user_id) + ) + ): + user_name = user.name + + if llm_api: + api_prompt = llm_api.api_prompt + else: + api_prompt = llm.async_render_no_api_prompt(self.hass) + + return "\n".join( + ( + template.Template( + llm.BASE_PROMPT + + self.entry.options.get( + CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT + ), + self.hass, + ).async_render( + { + "ha_name": self.hass.config.location_name, + "user_name": user_name, + "llm_context": llm_context, + }, + parse_result=False, + ), + api_prompt, + ) ) diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index 903e52af1a2..53ec092fda2 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -43,6 +43,7 @@ BASE_PROMPT = ( ) DEFAULT_INSTRUCTIONS_PROMPT = """You are a voice assistant for Home Assistant. +Answer questions about the world truthfully. Answer in plain text. Keep it simple and to the point. """ diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr index 70db5d11868..aec8d088b20 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr @@ -1,5 +1,5 @@ # serializer version: 1 -# name: test_chat_history +# name: test_chat_history[models/gemini-1.0-pro-False] list([ tuple( '', @@ -12,13 +12,14 @@ 'top_k': 64, 'top_p': 0.95, }), - 'model_name': 'models/gemini-1.5-flash-latest', + 'model_name': 'models/gemini-1.0-pro', 'safety_settings': dict({ 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', }), + 'system_instruction': None, 'tools': None, }), ), @@ -32,6 +33,7 @@ 'parts': ''' Current time is 05:00:00. Today's date is 2024-05-24. You are a voice assistant for Home Assistant. + Answer questions about the world truthfully. Answer in plain text. Keep it simple and to the point. Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. ''', @@ -63,13 +65,14 @@ 'top_k': 64, 'top_p': 0.95, }), - 'model_name': 'models/gemini-1.5-flash-latest', + 'model_name': 'models/gemini-1.0-pro', 'safety_settings': dict({ 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', }), + 'system_instruction': None, 'tools': None, }), ), @@ -83,6 +86,7 @@ 'parts': ''' Current time is 05:00:00. Today's date is 2024-05-24. You are a voice assistant for Home Assistant. + Answer questions about the world truthfully. Answer in plain text. Keep it simple and to the point. Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. ''', @@ -113,6 +117,108 @@ ), ]) # --- +# name: test_chat_history[models/gemini-1.5-pro-True] + list([ + tuple( + '', + tuple( + ), + dict({ + 'generation_config': dict({ + 'max_output_tokens': 150, + 'temperature': 1.0, + 'top_k': 64, + 'top_p': 0.95, + }), + 'model_name': 'models/gemini-1.5-pro', + 'safety_settings': dict({ + 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', + 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', + 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', + 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', + }), + 'system_instruction': ''' + Current time is 05:00:00. Today's date is 2024-05-24. + You are a voice assistant for Home Assistant. + Answer questions about the world truthfully. + Answer in plain text. Keep it simple and to the point. + Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. + ''', + 'tools': None, + }), + ), + tuple( + '().start_chat', + tuple( + ), + dict({ + 'history': list([ + ]), + }), + ), + tuple( + '().start_chat().send_message_async', + tuple( + '1st user request', + ), + dict({ + }), + ), + tuple( + '', + tuple( + ), + dict({ + 'generation_config': dict({ + 'max_output_tokens': 150, + 'temperature': 1.0, + 'top_k': 64, + 'top_p': 0.95, + }), + 'model_name': 'models/gemini-1.5-pro', + 'safety_settings': dict({ + 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', + 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', + 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', + 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', + }), + 'system_instruction': ''' + Current time is 05:00:00. Today's date is 2024-05-24. + You are a voice assistant for Home Assistant. + Answer questions about the world truthfully. + Answer in plain text. Keep it simple and to the point. + Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. + ''', + 'tools': None, + }), + ), + tuple( + '().start_chat', + tuple( + ), + dict({ + 'history': list([ + dict({ + 'parts': '1st user request', + 'role': 'user', + }), + dict({ + 'parts': '1st model response', + 'role': 'model', + }), + ]), + }), + ), + tuple( + '().start_chat().send_message_async', + tuple( + '2nd user request', + ), + dict({ + }), + ), + ]) +# --- # name: test_default_prompt[config_entry_options0-None] list([ tuple( @@ -133,6 +239,13 @@ 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', }), + 'system_instruction': ''' + Current time is 05:00:00. Today's date is 2024-05-24. + You are a voice assistant for Home Assistant. + Answer questions about the world truthfully. + Answer in plain text. Keep it simple and to the point. + + ''', 'tools': None, }), ), @@ -142,19 +255,6 @@ ), dict({ 'history': list([ - dict({ - 'parts': ''' - Current time is 05:00:00. Today's date is 2024-05-24. - You are a voice assistant for Home Assistant. - Answer in plain text. Keep it simple and to the point. - - ''', - 'role': 'user', - }), - dict({ - 'parts': 'Ok', - 'role': 'model', - }), ]), }), ), @@ -188,6 +288,13 @@ 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', }), + 'system_instruction': ''' + Current time is 05:00:00. Today's date is 2024-05-24. + You are a voice assistant for Home Assistant. + Answer questions about the world truthfully. + Answer in plain text. Keep it simple and to the point. + + ''', 'tools': None, }), ), @@ -197,19 +304,6 @@ ), dict({ 'history': list([ - dict({ - 'parts': ''' - Current time is 05:00:00. Today's date is 2024-05-24. - You are a voice assistant for Home Assistant. - Answer in plain text. Keep it simple and to the point. - - ''', - 'role': 'user', - }), - dict({ - 'parts': 'Ok', - 'role': 'model', - }), ]), }), ), @@ -243,6 +337,13 @@ 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', }), + 'system_instruction': ''' + Current time is 05:00:00. Today's date is 2024-05-24. + You are a voice assistant for Home Assistant. + Answer questions about the world truthfully. + Answer in plain text. Keep it simple and to the point. + + ''', 'tools': None, }), ), @@ -252,19 +353,6 @@ ), dict({ 'history': list([ - dict({ - 'parts': ''' - Current time is 05:00:00. Today's date is 2024-05-24. - You are a voice assistant for Home Assistant. - Answer in plain text. Keep it simple and to the point. - - ''', - 'role': 'user', - }), - dict({ - 'parts': 'Ok', - 'role': 'model', - }), ]), }), ), @@ -298,6 +386,13 @@ 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', }), + 'system_instruction': ''' + Current time is 05:00:00. Today's date is 2024-05-24. + You are a voice assistant for Home Assistant. + Answer questions about the world truthfully. + Answer in plain text. Keep it simple and to the point. + + ''', 'tools': None, }), ), @@ -307,19 +402,6 @@ ), dict({ 'history': list([ - dict({ - 'parts': ''' - Current time is 05:00:00. Today's date is 2024-05-24. - You are a voice assistant for Home Assistant. - Answer in plain text. Keep it simple and to the point. - - ''', - 'role': 'user', - }), - dict({ - 'parts': 'Ok', - 'role': 'model', - }), ]), }), ), diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index e84efffe7df..7f4fe886e90 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -12,6 +12,9 @@ import voluptuous as vol from homeassistant.components import conversation from homeassistant.components.conversation import trace +from homeassistant.components.google_generative_ai_conversation.const import ( + CONF_CHAT_MODEL, +) from homeassistant.components.google_generative_ai_conversation.conversation import ( _escape_decode, ) @@ -99,13 +102,22 @@ async def test_default_prompt( assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options) +@pytest.mark.parametrize( + ("model_name", "supports_system_instruction"), + [("models/gemini-1.5-pro", True), ("models/gemini-1.0-pro", False)], +) async def test_chat_history( hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component, + model_name: str, + supports_system_instruction: bool, snapshot: SnapshotAssertion, ) -> None: """Test that the agent keeps track of the chat history.""" + hass.config_entries.async_update_entry( + mock_config_entry, options={CONF_CHAT_MODEL: model_name} + ) with patch("google.generativeai.GenerativeModel") as mock_model: mock_chat = AsyncMock() mock_model.return_value.start_chat.return_value = mock_chat @@ -115,9 +127,14 @@ async def test_chat_history( mock_part.function_call = None mock_part.text = "1st model response" chat_response.parts = [mock_part] - mock_chat.history = [ - {"role": "user", "parts": "prompt"}, - {"role": "model", "parts": "Ok"}, + if supports_system_instruction: + mock_chat.history = [] + else: + mock_chat.history = [ + {"role": "user", "parts": "prompt"}, + {"role": "model", "parts": "Ok"}, + ] + mock_chat.history += [ {"role": "user", "parts": "1st user request"}, {"role": "model", "parts": "1st model response"}, ] @@ -256,7 +273,7 @@ async def test_function_call( ] # AGENT_DETAIL event contains the raw prompt passed to the model detail_event = trace_events[1] - assert "Answer in plain text" in detail_event["data"]["messages"][0]["parts"] + assert "Answer in plain text" in detail_event["data"]["prompt"] @patch( @@ -492,9 +509,9 @@ async def test_template_variables( ), result assert ( "The user name is Test User." - in mock_model.mock_calls[1][2]["history"][0]["parts"] + in mock_model.mock_calls[0][2]["system_instruction"] ) - assert "The user id is 12345." in mock_model.mock_calls[1][2]["history"][0]["parts"] + assert "The user id is 12345." in mock_model.mock_calls[0][2]["system_instruction"] async def test_conversation_agent(