Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions src/google/adk/tools/_function_parameter_parse_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,22 @@ def _is_default_value_compatible(
for item in default_value
)

if origin is tuple:
if not isinstance(default_value, tuple):
return False
args = get_args(annotation)
if len(args) == 2 and args[1] is Ellipsis:
return all(
_is_default_value_compatible(item, args[0])
for item in default_value
)
if len(args) != len(default_value):
return False
return all(
_is_default_value_compatible(item, arg)
for item, arg in zip(default_value, args)
)

if origin is Literal:
return default_value in get_args(annotation)

Expand Down Expand Up @@ -334,6 +350,31 @@ def _parse_schema_from_parameter(
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
if origin is tuple:
if len(args) == 2 and args[1] is Ellipsis:
item_annotation = args[0]
elif args and all(arg == args[0] for arg in args):
item_annotation = args[0]
else:
raise ValueError(
f'Tuple type {param.annotation} must use one repeated item type.'
)
schema.type = types.Type.ARRAY
schema.items = _parse_schema_from_parameter(
variant,
inspect.Parameter(
'item',
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=item_annotation,
),
func_name,
)
if param.default is not inspect.Parameter.empty:
if not _is_default_value_compatible(param.default, param.annotation):
raise ValueError(default_value_error_msg)
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
if origin in (Union, typing_types.UnionType):
schema.any_of = []
schema.type = types.Type.OBJECT
Expand Down
44 changes: 29 additions & 15 deletions tests/unittests/tools/test_from_function_with_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,31 @@ def test_function(
assert declaration.response.type == types.Type.STRING


def test_from_function_with_tuple_type_parameter():
"""Test from_function_with_options with tuple type parameter."""

def test_function(
coordinate: tuple[float, float],
) -> str:
"""Formats a coordinate pair."""
return f'{coordinate[0]}, {coordinate[1]}'

declaration = _automatic_function_calling_util.from_function_with_options(
test_function, GoogleLLMVariant.VERTEX_AI
)

assert declaration.name == 'test_function'
assert declaration.parameters.type == types.Type.OBJECT
assert declaration.parameters.properties['coordinate'].type == (
types.Type.ARRAY
)
assert (
declaration.parameters.properties['coordinate'].items.type
== types.Type.NUMBER
)
assert declaration.response.type == types.Type.STRING


def test_from_function_with_collections_return_type():
"""Test from_function_with_options with collections return type."""

Expand Down Expand Up @@ -321,14 +346,8 @@ async def test_function(param: str) -> AsyncGenerator[Dict[str, str], None]:
assert declaration.response.type == types.Type.OBJECT


def test_required_fields_set_in_json_schema_fallback():
"""Test that required fields are populated when the json_schema fallback path is used.

When a parameter has a complex type (e.g. tuple[str, ...] | None) that
_parse_schema_from_parameter can't handle, from_function_with_options falls
back to the parameters_json_schema branch. This test verifies that the
required fields are correctly populated in that fallback branch.
"""
def test_required_fields_set_with_optional_tuple_parameter():
"""Test that required fields are populated with optional tuple parameters."""

def complex_tool(
query: str,
Expand All @@ -350,14 +369,9 @@ def complex_tool(
'query': types.Schema(type=types.Type.STRING),
'mode': types.Schema(type=types.Type.STRING, default='default'),
'tags': types.Schema(
any_of=[
types.Schema(
items=types.Schema(type=types.Type.STRING),
type=types.Type.ARRAY,
),
types.Schema(type=types.Type.NULL),
],
items=types.Schema(type=types.Type.STRING),
nullable=True,
type=types.Type.ARRAY,
),
},
)