Handle multiple function_call and text parts in Google Generative AI (#118270)

This commit is contained in:
tronikos 2024-05-27 16:57:03 -07:00 committed by GitHub
parent bfc3194661
commit 722feb285b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 40 additions and 35 deletions

View file

@ -298,43 +298,47 @@ class GoogleGenerativeAIConversationEntity(
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=conversation_id
) )
self.history[conversation_id] = chat.history self.history[conversation_id] = chat.history
tool_call = chat_response.parts[0].function_call tool_calls = [
part.function_call for part in chat_response.parts if part.function_call
if not tool_call or not llm_api: ]
if not tool_calls or not llm_api:
break break
tool_input = llm.ToolInput( tool_responses = []
tool_name=tool_call.name, for tool_call in tool_calls:
tool_args=dict(tool_call.args), tool_input = llm.ToolInput(
platform=DOMAIN, tool_name=tool_call.name,
context=user_input.context, tool_args=dict(tool_call.args),
user_prompt=user_input.text, platform=DOMAIN,
language=user_input.language, context=user_input.context,
assistant=conversation.DOMAIN, user_prompt=user_input.text,
device_id=user_input.device_id, language=user_input.language,
) assistant=conversation.DOMAIN,
LOGGER.debug( device_id=user_input.device_id,
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args )
) LOGGER.debug(
try: "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
function_response = await llm_api.async_call_tool(tool_input) )
except (HomeAssistantError, vol.Invalid) as e: try:
function_response = {"error": type(e).__name__} function_response = await llm_api.async_call_tool(tool_input)
if str(e): except (HomeAssistantError, vol.Invalid) as e:
function_response["error_text"] = str(e) function_response = {"error": type(e).__name__}
if str(e):
function_response["error_text"] = str(e)
LOGGER.debug("Tool response: %s", function_response) LOGGER.debug("Tool response: %s", function_response)
chat_request = glm.Content( tool_responses.append(
parts=[
glm.Part( glm.Part(
function_response=glm.FunctionResponse( function_response=glm.FunctionResponse(
name=tool_call.name, response=function_response name=tool_call.name, response=function_response
) )
) )
] )
) chat_request = glm.Content(parts=tool_responses)
intent_response.async_set_speech(chat_response.text) intent_response.async_set_speech(
" ".join([part.text for part in chat_response.parts if part.text])
)
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=conversation_id
) )

View file

@ -191,8 +191,8 @@ async def test_default_prompt(
mock_chat.send_message_async.return_value = chat_response mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock() mock_part = MagicMock()
mock_part.function_call = None mock_part.function_call = None
mock_part.text = "Hi there!"
chat_response.parts = [mock_part] chat_response.parts = [mock_part]
chat_response.text = "Hi there!"
result = await conversation.async_converse( result = await conversation.async_converse(
hass, hass,
"hello", "hello",
@ -221,8 +221,8 @@ async def test_chat_history(
mock_chat.send_message_async.return_value = chat_response mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock() mock_part = MagicMock()
mock_part.function_call = None mock_part.function_call = None
mock_part.text = "1st model response"
chat_response.parts = [mock_part] chat_response.parts = [mock_part]
chat_response.text = "1st model response"
mock_chat.history = [ mock_chat.history = [
{"role": "user", "parts": "prompt"}, {"role": "user", "parts": "prompt"},
{"role": "model", "parts": "Ok"}, {"role": "model", "parts": "Ok"},
@ -241,7 +241,8 @@ async def test_chat_history(
result.response.as_dict()["speech"]["plain"]["speech"] result.response.as_dict()["speech"]["plain"]["speech"]
== "1st model response" == "1st model response"
) )
chat_response.text = "2nd model response" mock_part.text = "2nd model response"
chat_response.parts = [mock_part]
result = await conversation.async_converse( result = await conversation.async_converse(
hass, hass,
"2nd user request", "2nd user request",
@ -294,8 +295,8 @@ async def test_function_call(
mock_part.function_call.args = {"param1": ["test_value"]} mock_part.function_call.args = {"param1": ["test_value"]}
def tool_call(hass, tool_input): def tool_call(hass, tool_input):
mock_part.function_call = False mock_part.function_call = None
chat_response.text = "Hi there!" mock_part.text = "Hi there!"
return {"result": "Test response"} return {"result": "Test response"}
mock_tool.async_call.side_effect = tool_call mock_tool.async_call.side_effect = tool_call
@ -392,8 +393,8 @@ async def test_function_exception(
mock_part.function_call.args = {"param1": 1} mock_part.function_call.args = {"param1": 1}
def tool_call(hass, tool_input): def tool_call(hass, tool_input):
mock_part.function_call = False mock_part.function_call = None
chat_response.text = "Hi there!" mock_part.text = "Hi there!"
raise HomeAssistantError("Test tool exception") raise HomeAssistantError("Test tool exception")
mock_tool.async_call.side_effect = tool_call mock_tool.async_call.side_effect = tool_call