Handle multiple function_call and text parts in Google Generative AI (#118270)
This commit is contained in:
parent
bfc3194661
commit
722feb285b
2 changed files with 40 additions and 35 deletions
|
@ -298,43 +298,47 @@ class GoogleGenerativeAIConversationEntity(
|
|||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
self.history[conversation_id] = chat.history
|
||||
tool_call = chat_response.parts[0].function_call
|
||||
|
||||
if not tool_call or not llm_api:
|
||||
tool_calls = [
|
||||
part.function_call for part in chat_response.parts if part.function_call
|
||||
]
|
||||
if not tool_calls or not llm_api:
|
||||
break
|
||||
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name=tool_call.name,
|
||||
tool_args=dict(tool_call.args),
|
||||
platform=DOMAIN,
|
||||
context=user_input.context,
|
||||
user_prompt=user_input.text,
|
||||
language=user_input.language,
|
||||
assistant=conversation.DOMAIN,
|
||||
device_id=user_input.device_id,
|
||||
)
|
||||
LOGGER.debug(
|
||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||
)
|
||||
try:
|
||||
function_response = await llm_api.async_call_tool(tool_input)
|
||||
except (HomeAssistantError, vol.Invalid) as e:
|
||||
function_response = {"error": type(e).__name__}
|
||||
if str(e):
|
||||
function_response["error_text"] = str(e)
|
||||
tool_responses = []
|
||||
for tool_call in tool_calls:
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name=tool_call.name,
|
||||
tool_args=dict(tool_call.args),
|
||||
platform=DOMAIN,
|
||||
context=user_input.context,
|
||||
user_prompt=user_input.text,
|
||||
language=user_input.language,
|
||||
assistant=conversation.DOMAIN,
|
||||
device_id=user_input.device_id,
|
||||
)
|
||||
LOGGER.debug(
|
||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||
)
|
||||
try:
|
||||
function_response = await llm_api.async_call_tool(tool_input)
|
||||
except (HomeAssistantError, vol.Invalid) as e:
|
||||
function_response = {"error": type(e).__name__}
|
||||
if str(e):
|
||||
function_response["error_text"] = str(e)
|
||||
|
||||
LOGGER.debug("Tool response: %s", function_response)
|
||||
chat_request = glm.Content(
|
||||
parts=[
|
||||
LOGGER.debug("Tool response: %s", function_response)
|
||||
tool_responses.append(
|
||||
glm.Part(
|
||||
function_response=glm.FunctionResponse(
|
||||
name=tool_call.name, response=function_response
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
chat_request = glm.Content(parts=tool_responses)
|
||||
|
||||
intent_response.async_set_speech(chat_response.text)
|
||||
intent_response.async_set_speech(
|
||||
" ".join([part.text for part in chat_response.parts if part.text])
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
|
|
@ -191,8 +191,8 @@ async def test_default_prompt(
|
|||
mock_chat.send_message_async.return_value = chat_response
|
||||
mock_part = MagicMock()
|
||||
mock_part.function_call = None
|
||||
mock_part.text = "Hi there!"
|
||||
chat_response.parts = [mock_part]
|
||||
chat_response.text = "Hi there!"
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"hello",
|
||||
|
@ -221,8 +221,8 @@ async def test_chat_history(
|
|||
mock_chat.send_message_async.return_value = chat_response
|
||||
mock_part = MagicMock()
|
||||
mock_part.function_call = None
|
||||
mock_part.text = "1st model response"
|
||||
chat_response.parts = [mock_part]
|
||||
chat_response.text = "1st model response"
|
||||
mock_chat.history = [
|
||||
{"role": "user", "parts": "prompt"},
|
||||
{"role": "model", "parts": "Ok"},
|
||||
|
@ -241,7 +241,8 @@ async def test_chat_history(
|
|||
result.response.as_dict()["speech"]["plain"]["speech"]
|
||||
== "1st model response"
|
||||
)
|
||||
chat_response.text = "2nd model response"
|
||||
mock_part.text = "2nd model response"
|
||||
chat_response.parts = [mock_part]
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"2nd user request",
|
||||
|
@ -294,8 +295,8 @@ async def test_function_call(
|
|||
mock_part.function_call.args = {"param1": ["test_value"]}
|
||||
|
||||
def tool_call(hass, tool_input):
|
||||
mock_part.function_call = False
|
||||
chat_response.text = "Hi there!"
|
||||
mock_part.function_call = None
|
||||
mock_part.text = "Hi there!"
|
||||
return {"result": "Test response"}
|
||||
|
||||
mock_tool.async_call.side_effect = tool_call
|
||||
|
@ -392,8 +393,8 @@ async def test_function_exception(
|
|||
mock_part.function_call.args = {"param1": 1}
|
||||
|
||||
def tool_call(hass, tool_input):
|
||||
mock_part.function_call = False
|
||||
chat_response.text = "Hi there!"
|
||||
mock_part.function_call = None
|
||||
mock_part.text = "Hi there!"
|
||||
raise HomeAssistantError("Test tool exception")
|
||||
|
||||
mock_tool.async_call.side_effect = tool_call
|
||||
|
|
Loading…
Add table
Reference in a new issue