Handle multiple function_call and text parts in Google Generative AI (#118270)

This commit is contained in:
tronikos 2024-05-27 16:57:03 -07:00 committed by GitHub
parent bfc3194661
commit 722feb285b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 40 additions and 35 deletions

View file

@ -298,11 +298,14 @@ class GoogleGenerativeAIConversationEntity(
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=conversation_id
) )
self.history[conversation_id] = chat.history self.history[conversation_id] = chat.history
tool_call = chat_response.parts[0].function_call tool_calls = [
part.function_call for part in chat_response.parts if part.function_call
if not tool_call or not llm_api: ]
if not tool_calls or not llm_api:
break break
tool_responses = []
for tool_call in tool_calls:
tool_input = llm.ToolInput( tool_input = llm.ToolInput(
tool_name=tool_call.name, tool_name=tool_call.name,
tool_args=dict(tool_call.args), tool_args=dict(tool_call.args),
@ -324,17 +327,18 @@ class GoogleGenerativeAIConversationEntity(
function_response["error_text"] = str(e) function_response["error_text"] = str(e)
LOGGER.debug("Tool response: %s", function_response) LOGGER.debug("Tool response: %s", function_response)
chat_request = glm.Content( tool_responses.append(
parts=[
glm.Part( glm.Part(
function_response=glm.FunctionResponse( function_response=glm.FunctionResponse(
name=tool_call.name, response=function_response name=tool_call.name, response=function_response
) )
) )
]
) )
chat_request = glm.Content(parts=tool_responses)
intent_response.async_set_speech(chat_response.text) intent_response.async_set_speech(
" ".join([part.text for part in chat_response.parts if part.text])
)
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=conversation_id
) )

View file

@ -191,8 +191,8 @@ async def test_default_prompt(
mock_chat.send_message_async.return_value = chat_response mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock() mock_part = MagicMock()
mock_part.function_call = None mock_part.function_call = None
mock_part.text = "Hi there!"
chat_response.parts = [mock_part] chat_response.parts = [mock_part]
chat_response.text = "Hi there!"
result = await conversation.async_converse( result = await conversation.async_converse(
hass, hass,
"hello", "hello",
@ -221,8 +221,8 @@ async def test_chat_history(
mock_chat.send_message_async.return_value = chat_response mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock() mock_part = MagicMock()
mock_part.function_call = None mock_part.function_call = None
mock_part.text = "1st model response"
chat_response.parts = [mock_part] chat_response.parts = [mock_part]
chat_response.text = "1st model response"
mock_chat.history = [ mock_chat.history = [
{"role": "user", "parts": "prompt"}, {"role": "user", "parts": "prompt"},
{"role": "model", "parts": "Ok"}, {"role": "model", "parts": "Ok"},
@ -241,7 +241,8 @@ async def test_chat_history(
result.response.as_dict()["speech"]["plain"]["speech"] result.response.as_dict()["speech"]["plain"]["speech"]
== "1st model response" == "1st model response"
) )
chat_response.text = "2nd model response" mock_part.text = "2nd model response"
chat_response.parts = [mock_part]
result = await conversation.async_converse( result = await conversation.async_converse(
hass, hass,
"2nd user request", "2nd user request",
@ -294,8 +295,8 @@ async def test_function_call(
mock_part.function_call.args = {"param1": ["test_value"]} mock_part.function_call.args = {"param1": ["test_value"]}
def tool_call(hass, tool_input): def tool_call(hass, tool_input):
mock_part.function_call = False mock_part.function_call = None
chat_response.text = "Hi there!" mock_part.text = "Hi there!"
return {"result": "Test response"} return {"result": "Test response"}
mock_tool.async_call.side_effect = tool_call mock_tool.async_call.side_effect = tool_call
@ -392,8 +393,8 @@ async def test_function_exception(
mock_part.function_call.args = {"param1": 1} mock_part.function_call.args = {"param1": 1}
def tool_call(hass, tool_input): def tool_call(hass, tool_input):
mock_part.function_call = False mock_part.function_call = None
chat_response.text = "Hi there!" mock_part.text = "Hi there!"
raise HomeAssistantError("Test tool exception") raise HomeAssistantError("Test tool exception")
mock_tool.async_call.side_effect = tool_call mock_tool.async_call.side_effect = tool_call