diff --git a/src/google/adk/tools/_function_parameter_parse_util.py b/src/google/adk/tools/_function_parameter_parse_util.py index e61b61db56..848e181c1e 100644 --- a/src/google/adk/tools/_function_parameter_parse_util.py +++ b/src/google/adk/tools/_function_parameter_parse_util.py @@ -203,6 +203,22 @@ def _is_default_value_compatible( for item in default_value ) + if origin is tuple: + if not isinstance(default_value, tuple): + return False + args = get_args(annotation) + if len(args) == 2 and args[1] is Ellipsis: + return all( + _is_default_value_compatible(item, args[0]) + for item in default_value + ) + if len(args) != len(default_value): + return False + return all( + _is_default_value_compatible(item, arg) + for item, arg in zip(default_value, args) + ) + if origin is Literal: return default_value in get_args(annotation) @@ -334,6 +350,31 @@ def _parse_schema_from_parameter( schema.default = param.default _raise_if_schema_unsupported(variant, schema) return schema + if origin is tuple: + if len(args) == 2 and args[1] is Ellipsis: + item_annotation = args[0] + elif args and all(arg == args[0] for arg in args): + item_annotation = args[0] + else: + raise ValueError( + f'Tuple type {param.annotation} must use one repeated item type.' + ) + schema.type = types.Type.ARRAY + schema.items = _parse_schema_from_parameter( + variant, + inspect.Parameter( + 'item', + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=item_annotation, + ), + func_name, + ) + if param.default is not inspect.Parameter.empty: + if not _is_default_value_compatible(param.default, param.annotation): + raise ValueError(default_value_error_msg) + schema.default = param.default + _raise_if_schema_unsupported(variant, schema) + return schema if origin in (Union, typing_types.UnionType): schema.any_of = [] schema.type = types.Type.OBJECT diff --git a/tests/unittests/tools/test_from_function_with_options.py b/tests/unittests/tools/test_from_function_with_options.py index 4f77bc7b1f..839c431beb 100644 --- a/tests/unittests/tools/test_from_function_with_options.py +++ b/tests/unittests/tools/test_from_function_with_options.py @@ -228,6 +228,31 @@ def test_function( assert declaration.response.type == types.Type.STRING +def test_from_function_with_tuple_type_parameter(): + """Test from_function_with_options with tuple type parameter.""" + + def test_function( + coordinate: tuple[float, float], + ) -> str: + """Formats a coordinate pair.""" + return f'{coordinate[0]}, {coordinate[1]}' + + declaration = _automatic_function_calling_util.from_function_with_options( + test_function, GoogleLLMVariant.VERTEX_AI + ) + + assert declaration.name == 'test_function' + assert declaration.parameters.type == types.Type.OBJECT + assert declaration.parameters.properties['coordinate'].type == ( + types.Type.ARRAY + ) + assert ( + declaration.parameters.properties['coordinate'].items.type + == types.Type.NUMBER + ) + assert declaration.response.type == types.Type.STRING + + def test_from_function_with_collections_return_type(): """Test from_function_with_options with collections return type.""" @@ -321,14 +346,8 @@ async def test_function(param: str) -> AsyncGenerator[Dict[str, str], None]: assert declaration.response.type == types.Type.OBJECT -def test_required_fields_set_in_json_schema_fallback(): - """Test that required fields are populated when the json_schema fallback path is used. - - When a parameter has a complex type (e.g. tuple[str, ...] | None) that - _parse_schema_from_parameter can't handle, from_function_with_options falls - back to the parameters_json_schema branch. This test verifies that the - required fields are correctly populated in that fallback branch. - """ +def test_required_fields_set_with_optional_tuple_parameter(): + """Test that required fields are populated with optional tuple parameters.""" def complex_tool( query: str, @@ -350,14 +369,9 @@ def complex_tool( 'query': types.Schema(type=types.Type.STRING), 'mode': types.Schema(type=types.Type.STRING, default='default'), 'tags': types.Schema( - any_of=[ - types.Schema( - items=types.Schema(type=types.Type.STRING), - type=types.Type.ARRAY, - ), - types.Schema(type=types.Type.NULL), - ], + items=types.Schema(type=types.Type.STRING), nullable=True, + type=types.Type.ARRAY, ), }, )