Fix unnecessary single quotes escaping in Google AI (#118522)
This commit is contained in:
parent
0d6c7d0973
commit
272c51fb38
3 changed files with 38 additions and 17 deletions
|
@ -8,6 +8,7 @@ import google.ai.generativelanguage as glm
|
||||||
from google.api_core.exceptions import GoogleAPICallError
|
from google.api_core.exceptions import GoogleAPICallError
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
import google.generativeai.types as genai_types
|
import google.generativeai.types as genai_types
|
||||||
|
from google.protobuf.json_format import MessageToDict
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
from voluptuous_openapi import convert
|
from voluptuous_openapi import convert
|
||||||
|
|
||||||
|
@ -105,6 +106,17 @@ def _format_tool(tool: llm.Tool) -> dict[str, Any]:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _adjust_value(value: Any) -> Any:
|
||||||
|
"""Reverse unnecessary single quotes escaping."""
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value.replace("\\'", "'")
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [_adjust_value(item) for item in value]
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {k: _adjust_value(v) for k, v in value.items()}
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class GoogleGenerativeAIConversationEntity(
|
class GoogleGenerativeAIConversationEntity(
|
||||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||||
):
|
):
|
||||||
|
@ -295,21 +307,22 @@ class GoogleGenerativeAIConversationEntity(
|
||||||
response=intent_response, conversation_id=conversation_id
|
response=intent_response, conversation_id=conversation_id
|
||||||
)
|
)
|
||||||
self.history[conversation_id] = chat.history
|
self.history[conversation_id] = chat.history
|
||||||
tool_calls = [
|
function_calls = [
|
||||||
part.function_call for part in chat_response.parts if part.function_call
|
part.function_call for part in chat_response.parts if part.function_call
|
||||||
]
|
]
|
||||||
if not tool_calls or not llm_api:
|
if not function_calls or not llm_api:
|
||||||
break
|
break
|
||||||
|
|
||||||
tool_responses = []
|
tool_responses = []
|
||||||
for tool_call in tool_calls:
|
for function_call in function_calls:
|
||||||
tool_input = llm.ToolInput(
|
tool_call = MessageToDict(function_call._pb) # noqa: SLF001
|
||||||
tool_name=tool_call.name,
|
tool_name = tool_call["name"]
|
||||||
tool_args=dict(tool_call.args),
|
tool_args = {
|
||||||
)
|
key: _adjust_value(value)
|
||||||
LOGGER.debug(
|
for key, value in tool_call["args"].items()
|
||||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
}
|
||||||
)
|
LOGGER.debug("Tool call: %s(%s)", tool_name, tool_args)
|
||||||
|
tool_input = llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
|
||||||
try:
|
try:
|
||||||
function_response = await llm_api.async_call_tool(tool_input)
|
function_response = await llm_api.async_call_tool(tool_input)
|
||||||
except (HomeAssistantError, vol.Invalid) as e:
|
except (HomeAssistantError, vol.Invalid) as e:
|
||||||
|
@ -321,7 +334,7 @@ class GoogleGenerativeAIConversationEntity(
|
||||||
tool_responses.append(
|
tool_responses.append(
|
||||||
glm.Part(
|
glm.Part(
|
||||||
function_response=glm.FunctionResponse(
|
function_response=glm.FunctionResponse(
|
||||||
name=tool_call.name, response=function_response
|
name=tool_name, response=function_response
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -140,7 +140,7 @@ class APIInstance:
|
||||||
"""Call a LLM tool, validate args and return the response."""
|
"""Call a LLM tool, validate args and return the response."""
|
||||||
async_conversation_trace_append(
|
async_conversation_trace_append(
|
||||||
ConversationTraceEventType.LLM_TOOL_CALL,
|
ConversationTraceEventType.LLM_TOOL_CALL,
|
||||||
{"tool_name": tool_input.tool_name, "tool_args": str(tool_input.tool_args)},
|
{"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args},
|
||||||
)
|
)
|
||||||
|
|
||||||
for tool in self.tools:
|
for tool in self.tools:
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
|
from google.ai.generativelanguage_v1beta.types.content import FunctionCall
|
||||||
from google.api_core.exceptions import GoogleAPICallError
|
from google.api_core.exceptions import GoogleAPICallError
|
||||||
import google.generativeai.types as genai_types
|
import google.generativeai.types as genai_types
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -179,8 +180,13 @@ async def test_function_call(
|
||||||
chat_response = MagicMock()
|
chat_response = MagicMock()
|
||||||
mock_chat.send_message_async.return_value = chat_response
|
mock_chat.send_message_async.return_value = chat_response
|
||||||
mock_part = MagicMock()
|
mock_part = MagicMock()
|
||||||
mock_part.function_call.name = "test_tool"
|
mock_part.function_call = FunctionCall(
|
||||||
mock_part.function_call.args = {"param1": ["test_value"]}
|
name="test_tool",
|
||||||
|
args={
|
||||||
|
"param1": ["test_value", "param1\\'s value"],
|
||||||
|
"param2": "param2\\'s value",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def tool_call(hass, tool_input, tool_context):
|
def tool_call(hass, tool_input, tool_context):
|
||||||
mock_part.function_call = None
|
mock_part.function_call = None
|
||||||
|
@ -220,7 +226,10 @@ async def test_function_call(
|
||||||
hass,
|
hass,
|
||||||
llm.ToolInput(
|
llm.ToolInput(
|
||||||
tool_name="test_tool",
|
tool_name="test_tool",
|
||||||
tool_args={"param1": ["test_value"]},
|
tool_args={
|
||||||
|
"param1": ["test_value", "param1's value"],
|
||||||
|
"param2": "param2's value",
|
||||||
|
},
|
||||||
),
|
),
|
||||||
llm.ToolContext(
|
llm.ToolContext(
|
||||||
platform="google_generative_ai_conversation",
|
platform="google_generative_ai_conversation",
|
||||||
|
@ -279,8 +288,7 @@ async def test_function_exception(
|
||||||
chat_response = MagicMock()
|
chat_response = MagicMock()
|
||||||
mock_chat.send_message_async.return_value = chat_response
|
mock_chat.send_message_async.return_value = chat_response
|
||||||
mock_part = MagicMock()
|
mock_part = MagicMock()
|
||||||
mock_part.function_call.name = "test_tool"
|
mock_part.function_call = FunctionCall(name="test_tool", args={"param1": 1})
|
||||||
mock_part.function_call.args = {"param1": 1}
|
|
||||||
|
|
||||||
def tool_call(hass, tool_input, tool_context):
|
def tool_call(hass, tool_input, tool_context):
|
||||||
mock_part.function_call = None
|
mock_part.function_call = None
|
||||||
|
|
Loading…
Add table
Reference in a new issue