Fix gemini api format conversion (#122403)

* Fix gemini api format conversion

* add tests

* fix tests

* fix tests

* fix coverage
This commit is contained in:
Denis Shulyaka 2024-07-23 03:56:13 +03:00 committed by GitHub
parent 5d3c57ecfe
commit 975cfa6457
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 38 additions and 2 deletions

View file

@ -73,6 +73,14 @@ SUPPORTED_SCHEMA_KEYS = {
def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
"""Format the schema to protobuf."""
if (subschemas := schema.get("anyOf")) or (subschemas := schema.get("allOf")):
for subschema in subschemas: # Gemini API does not support anyOf and allOf keys
if "type" in subschema: # Fallback to first subschema with 'type' field
return _format_schema(subschema)
return _format_schema(
subschemas[0]
) # Or, if not found, to any of the subschemas
result = {}
for key, val in schema.items():
if key not in SUPPORTED_SCHEMA_KEYS:
@ -81,7 +89,9 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
key = "type_"
val = val.upper()
elif key == "format":
if schema.get("type") == "string" and val != "enum":
if (schema.get("type") == "string" and val != "enum") or (
schema.get("type") not in ("number", "integer", "string")
):
continue
key = "format_"
elif key == "items":
@ -89,6 +99,12 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
elif key == "properties":
val = {k: _format_schema(v) for k, v in val.items()}
result[key] = val
if result.get("type_") == "OBJECT" and not result.get("properties"):
# An object with undefined properties is not supported by Gemini API.
# Fallback to JSON string. This will probably fail for most tools that want it,
# but we don't have a better fallback strategy so far.
result["properties"] = {"json": {"type_": "STRING"}}
return result

View file

@ -442,6 +442,24 @@
description: "Test function"
parameters {
type_: OBJECT
properties {
key: "param3"
value {
type_: OBJECT
properties {
key: "json"
value {
type_: STRING
}
}
}
}
properties {
key: "param2"
value {
type_: NUMBER
}
}
properties {
key: "param1"
value {

View file

@ -185,7 +185,9 @@ async def test_function_call(
{
vol.Optional("param1", description="Test parameters"): [
vol.All(str, vol.Lower)
]
],
vol.Optional("param2"): vol.Any(float, int),
vol.Optional("param3"): dict,
}
)