Fix unnecessary single quotes escaping in Google AI (#118522)
This commit is contained in:
parent
0d6c7d0973
commit
272c51fb38
3 changed files with 38 additions and 17 deletions
|
@ -8,6 +8,7 @@ import google.ai.generativelanguage as glm
|
|||
from google.api_core.exceptions import GoogleAPICallError
|
||||
import google.generativeai as genai
|
||||
import google.generativeai.types as genai_types
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
|
@ -105,6 +106,17 @@ def _format_tool(tool: llm.Tool) -> dict[str, Any]:
|
|||
)
|
||||
|
||||
|
||||
def _adjust_value(value: Any) -> Any:
|
||||
"""Reverse unnecessary single quotes escaping."""
|
||||
if isinstance(value, str):
|
||||
return value.replace("\\'", "'")
|
||||
if isinstance(value, list):
|
||||
return [_adjust_value(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {k: _adjust_value(v) for k, v in value.items()}
|
||||
return value
|
||||
|
||||
|
||||
class GoogleGenerativeAIConversationEntity(
|
||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||
):
|
||||
|
@ -295,21 +307,22 @@ class GoogleGenerativeAIConversationEntity(
|
|||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
self.history[conversation_id] = chat.history
|
||||
tool_calls = [
|
||||
function_calls = [
|
||||
part.function_call for part in chat_response.parts if part.function_call
|
||||
]
|
||||
if not tool_calls or not llm_api:
|
||||
if not function_calls or not llm_api:
|
||||
break
|
||||
|
||||
tool_responses = []
|
||||
for tool_call in tool_calls:
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name=tool_call.name,
|
||||
tool_args=dict(tool_call.args),
|
||||
)
|
||||
LOGGER.debug(
|
||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||
)
|
||||
for function_call in function_calls:
|
||||
tool_call = MessageToDict(function_call._pb) # noqa: SLF001
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = {
|
||||
key: _adjust_value(value)
|
||||
for key, value in tool_call["args"].items()
|
||||
}
|
||||
LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args)
|
||||
tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
|
||||
try:
|
||||
function_response = await llm_api.async_call_tool(tool_input)
|
||||
except (HomeAssistantError, vol.Invalid) as e:
|
||||
|
@ -321,7 +334,7 @@ class GoogleGenerativeAIConversationEntity(
|
|||
tool_responses.append(
|
||||
glm.Part(
|
||||
function_response=glm.FunctionResponse(
|
||||
name=tool_call.name, response=function_response
|
||||
name=tool_name, response=function_response
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue