Fix unnecessary single quotes escaping in Google AI (#118522)

This commit is contained in:
tronikos 2024-05-30 16:56:06 -07:00 committed by GitHub
parent 0d6c7d0973
commit 272c51fb38
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 38 additions and 17 deletions

View file

@ -8,6 +8,7 @@ import google.ai.generativelanguage as glm
from google.api_core.exceptions import GoogleAPICallError
import google.generativeai as genai
import google.generativeai.types as genai_types
from google.protobuf.json_format import MessageToDict
import voluptuous as vol
from voluptuous_openapi import convert
@ -105,6 +106,17 @@ def _format_tool(tool: llm.Tool) -> dict[str, Any]:
)
def _adjust_value(value: Any) -> Any:
"""Reverse unnecessary single quotes escaping."""
if isinstance(value, str):
return value.replace("\\'", "'")
if isinstance(value, list):
return [_adjust_value(item) for item in value]
if isinstance(value, dict):
return {k: _adjust_value(v) for k, v in value.items()}
return value
class GoogleGenerativeAIConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
@ -295,21 +307,22 @@ class GoogleGenerativeAIConversationEntity(
response=intent_response, conversation_id=conversation_id
)
self.history[conversation_id] = chat.history
tool_calls = [
function_calls = [
part.function_call for part in chat_response.parts if part.function_call
]
if not tool_calls or not llm_api:
if not function_calls or not llm_api:
break
tool_responses = []
for tool_call in tool_calls:
tool_input = llm.ToolInput(
tool_name=tool_call.name,
tool_args=dict(tool_call.args),
)
LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
)
for function_call in function_calls:
tool_call = MessageToDict(function_call._pb) # noqa: SLF001
tool_name = tool_call["name"]
tool_args = {
key: _adjust_value(value)
for key, value in tool_call["args"].items()
}
LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args)
tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
try:
function_response = await llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
@ -321,7 +334,7 @@ class GoogleGenerativeAIConversationEntity(
tool_responses.append(
glm.Part(
function_response=glm.FunctionResponse(
name=tool_call.name, response=function_response
name=tool_name, response=function_response
)
)
)