Handle multiple function_call and text parts in Google Generative AI (#118270)
This commit is contained in:
parent
bfc3194661
commit
722feb285b
2 changed files with 40 additions and 35 deletions
|
@ -298,43 +298,47 @@ class GoogleGenerativeAIConversationEntity(
|
||||||
response=intent_response, conversation_id=conversation_id
|
response=intent_response, conversation_id=conversation_id
|
||||||
)
|
)
|
||||||
self.history[conversation_id] = chat.history
|
self.history[conversation_id] = chat.history
|
||||||
tool_call = chat_response.parts[0].function_call
|
tool_calls = [
|
||||||
|
part.function_call for part in chat_response.parts if part.function_call
|
||||||
if not tool_call or not llm_api:
|
]
|
||||||
|
if not tool_calls or not llm_api:
|
||||||
break
|
break
|
||||||
|
|
||||||
tool_input = llm.ToolInput(
|
tool_responses = []
|
||||||
tool_name=tool_call.name,
|
for tool_call in tool_calls:
|
||||||
tool_args=dict(tool_call.args),
|
tool_input = llm.ToolInput(
|
||||||
platform=DOMAIN,
|
tool_name=tool_call.name,
|
||||||
context=user_input.context,
|
tool_args=dict(tool_call.args),
|
||||||
user_prompt=user_input.text,
|
platform=DOMAIN,
|
||||||
language=user_input.language,
|
context=user_input.context,
|
||||||
assistant=conversation.DOMAIN,
|
user_prompt=user_input.text,
|
||||||
device_id=user_input.device_id,
|
language=user_input.language,
|
||||||
)
|
assistant=conversation.DOMAIN,
|
||||||
LOGGER.debug(
|
device_id=user_input.device_id,
|
||||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
)
|
||||||
)
|
LOGGER.debug(
|
||||||
try:
|
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||||
function_response = await llm_api.async_call_tool(tool_input)
|
)
|
||||||
except (HomeAssistantError, vol.Invalid) as e:
|
try:
|
||||||
function_response = {"error": type(e).__name__}
|
function_response = await llm_api.async_call_tool(tool_input)
|
||||||
if str(e):
|
except (HomeAssistantError, vol.Invalid) as e:
|
||||||
function_response["error_text"] = str(e)
|
function_response = {"error": type(e).__name__}
|
||||||
|
if str(e):
|
||||||
|
function_response["error_text"] = str(e)
|
||||||
|
|
||||||
LOGGER.debug("Tool response: %s", function_response)
|
LOGGER.debug("Tool response: %s", function_response)
|
||||||
chat_request = glm.Content(
|
tool_responses.append(
|
||||||
parts=[
|
|
||||||
glm.Part(
|
glm.Part(
|
||||||
function_response=glm.FunctionResponse(
|
function_response=glm.FunctionResponse(
|
||||||
name=tool_call.name, response=function_response
|
name=tool_call.name, response=function_response
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
]
|
)
|
||||||
)
|
chat_request = glm.Content(parts=tool_responses)
|
||||||
|
|
||||||
intent_response.async_set_speech(chat_response.text)
|
intent_response.async_set_speech(
|
||||||
|
" ".join([part.text for part in chat_response.parts if part.text])
|
||||||
|
)
|
||||||
return conversation.ConversationResult(
|
return conversation.ConversationResult(
|
||||||
response=intent_response, conversation_id=conversation_id
|
response=intent_response, conversation_id=conversation_id
|
||||||
)
|
)
|
||||||
|
|
|
@ -191,8 +191,8 @@ async def test_default_prompt(
|
||||||
mock_chat.send_message_async.return_value = chat_response
|
mock_chat.send_message_async.return_value = chat_response
|
||||||
mock_part = MagicMock()
|
mock_part = MagicMock()
|
||||||
mock_part.function_call = None
|
mock_part.function_call = None
|
||||||
|
mock_part.text = "Hi there!"
|
||||||
chat_response.parts = [mock_part]
|
chat_response.parts = [mock_part]
|
||||||
chat_response.text = "Hi there!"
|
|
||||||
result = await conversation.async_converse(
|
result = await conversation.async_converse(
|
||||||
hass,
|
hass,
|
||||||
"hello",
|
"hello",
|
||||||
|
@ -221,8 +221,8 @@ async def test_chat_history(
|
||||||
mock_chat.send_message_async.return_value = chat_response
|
mock_chat.send_message_async.return_value = chat_response
|
||||||
mock_part = MagicMock()
|
mock_part = MagicMock()
|
||||||
mock_part.function_call = None
|
mock_part.function_call = None
|
||||||
|
mock_part.text = "1st model response"
|
||||||
chat_response.parts = [mock_part]
|
chat_response.parts = [mock_part]
|
||||||
chat_response.text = "1st model response"
|
|
||||||
mock_chat.history = [
|
mock_chat.history = [
|
||||||
{"role": "user", "parts": "prompt"},
|
{"role": "user", "parts": "prompt"},
|
||||||
{"role": "model", "parts": "Ok"},
|
{"role": "model", "parts": "Ok"},
|
||||||
|
@ -241,7 +241,8 @@ async def test_chat_history(
|
||||||
result.response.as_dict()["speech"]["plain"]["speech"]
|
result.response.as_dict()["speech"]["plain"]["speech"]
|
||||||
== "1st model response"
|
== "1st model response"
|
||||||
)
|
)
|
||||||
chat_response.text = "2nd model response"
|
mock_part.text = "2nd model response"
|
||||||
|
chat_response.parts = [mock_part]
|
||||||
result = await conversation.async_converse(
|
result = await conversation.async_converse(
|
||||||
hass,
|
hass,
|
||||||
"2nd user request",
|
"2nd user request",
|
||||||
|
@ -294,8 +295,8 @@ async def test_function_call(
|
||||||
mock_part.function_call.args = {"param1": ["test_value"]}
|
mock_part.function_call.args = {"param1": ["test_value"]}
|
||||||
|
|
||||||
def tool_call(hass, tool_input):
|
def tool_call(hass, tool_input):
|
||||||
mock_part.function_call = False
|
mock_part.function_call = None
|
||||||
chat_response.text = "Hi there!"
|
mock_part.text = "Hi there!"
|
||||||
return {"result": "Test response"}
|
return {"result": "Test response"}
|
||||||
|
|
||||||
mock_tool.async_call.side_effect = tool_call
|
mock_tool.async_call.side_effect = tool_call
|
||||||
|
@ -392,8 +393,8 @@ async def test_function_exception(
|
||||||
mock_part.function_call.args = {"param1": 1}
|
mock_part.function_call.args = {"param1": 1}
|
||||||
|
|
||||||
def tool_call(hass, tool_input):
|
def tool_call(hass, tool_input):
|
||||||
mock_part.function_call = False
|
mock_part.function_call = None
|
||||||
chat_response.text = "Hi there!"
|
mock_part.text = "Hi there!"
|
||||||
raise HomeAssistantError("Test tool exception")
|
raise HomeAssistantError("Test tool exception")
|
||||||
|
|
||||||
mock_tool.async_call.side_effect = tool_call
|
mock_tool.async_call.side_effect = tool_call
|
||||||
|
|
Loading…
Add table
Reference in a new issue