Skip to content

Commit 6bb0b74

Browse files
google-genai-botcopybara-github
authored andcommitted
chore: Add abstract type annotation support to AFC
PiperOrigin-RevId: 831545133
1 parent 69627b6 commit 6bb0b74

File tree

3 files changed

+21
-215
lines changed

3 files changed

+21
-215
lines changed

src/google/adk/tools/_automatic_function_calling_util.py

Lines changed: 21 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -296,55 +296,20 @@ def from_function_with_options(
296296
) -> 'types.FunctionDeclaration':
297297

298298
parameters_properties = {}
299-
parameters_json_schema = {}
300-
annotation_under_future = typing.get_type_hints(func)
301-
try:
302-
for name, param in inspect.signature(func).parameters.items():
303-
if param.kind in (
304-
inspect.Parameter.POSITIONAL_OR_KEYWORD,
305-
inspect.Parameter.KEYWORD_ONLY,
306-
inspect.Parameter.POSITIONAL_ONLY,
307-
):
308-
param = _function_parameter_parse_util._handle_params_as_deferred_annotations(
309-
param, annotation_under_future, name
310-
)
311-
312-
schema = _function_parameter_parse_util._parse_schema_from_parameter(
313-
variant, param, func.__name__
314-
)
315-
parameters_properties[name] = schema
316-
except ValueError:
317-
# If the function has complex parameter types that fail in _parse_schema_from_parameter,
318-
# we try to generate a json schema for the parameter using pydantic.TypeAdapter.
319-
parameters_properties = {}
320-
for name, param in inspect.signature(func).parameters.items():
321-
if param.kind in (
322-
inspect.Parameter.POSITIONAL_OR_KEYWORD,
323-
inspect.Parameter.KEYWORD_ONLY,
324-
inspect.Parameter.POSITIONAL_ONLY,
325-
):
326-
try:
327-
if param.annotation == inspect.Parameter.empty:
328-
param = param.replace(annotation=Any)
329-
330-
param = _function_parameter_parse_util._handle_params_as_deferred_annotations(
331-
param, annotation_under_future, name
332-
)
333-
334-
_function_parameter_parse_util._raise_for_invalid_enum_value(param)
335-
336-
json_schema_dict = _function_parameter_parse_util._generate_json_schema_for_parameter(
337-
param
338-
)
339-
340-
parameters_json_schema[name] = types.Schema.model_validate(
341-
json_schema_dict
342-
)
343-
except Exception as e:
344-
_function_parameter_parse_util._raise_for_unsupported_param(
345-
param, func.__name__, e
346-
)
347-
299+
for name, param in inspect.signature(func).parameters.items():
300+
if param.kind in (
301+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
302+
inspect.Parameter.KEYWORD_ONLY,
303+
inspect.Parameter.POSITIONAL_ONLY,
304+
):
305+
# This snippet catches the case when type hints are stored as strings
306+
if isinstance(param.annotation, str):
307+
param = param.replace(annotation=typing.get_type_hints(func)[name])
308+
309+
schema = _function_parameter_parse_util._parse_schema_from_parameter(
310+
variant, param, func.__name__
311+
)
312+
parameters_properties[name] = schema
348313
declaration = types.FunctionDeclaration(
349314
name=func.__name__,
350315
description=func.__doc__,
@@ -359,12 +324,6 @@ def from_function_with_options(
359324
declaration.parameters
360325
)
361326
)
362-
elif parameters_json_schema:
363-
declaration.parameters = types.Schema(
364-
type='OBJECT',
365-
properties=parameters_json_schema,
366-
)
367-
368327
if variant == GoogleLLMVariant.GEMINI_API:
369328
return declaration
370329

@@ -413,35 +372,17 @@ def from_function_with_options(
413372
inspect.Parameter.POSITIONAL_OR_KEYWORD,
414373
annotation=return_annotation,
415374
)
375+
# This snippet catches the case when type hints are stored as strings
416376
if isinstance(return_value.annotation, str):
417377
return_value = return_value.replace(
418378
annotation=typing.get_type_hints(func)['return']
419379
)
420380

421-
response_schema: Optional[types.Schema] = None
422-
response_json_schema: Optional[Union[Dict[str, Any], types.Schema]] = None
423-
try:
424-
response_schema = (
425-
_function_parameter_parse_util._parse_schema_from_parameter(
426-
variant,
427-
return_value,
428-
func.__name__,
429-
)
430-
)
431-
except ValueError:
432-
try:
433-
response_json_schema = (
434-
_function_parameter_parse_util._generate_json_schema_for_parameter(
435-
return_value
436-
)
437-
)
438-
response_json_schema = types.Schema.model_validate(response_json_schema)
439-
except Exception as e:
440-
_function_parameter_parse_util._raise_for_unsupported_param(
441-
return_value, func.__name__, e
381+
declaration.response = (
382+
_function_parameter_parse_util._parse_schema_from_parameter(
383+
variant,
384+
return_value,
385+
func.__name__,
442386
)
443-
if response_schema:
444-
declaration.response = response_schema
445-
elif response_json_schema:
446-
declaration.response = response_json_schema
387+
)
447388
return declaration

src/google/adk/tools/_function_parameter_parse_util.py

Lines changed: 0 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -49,91 +49,6 @@
4949
logger = logging.getLogger('google_adk.' + __name__)
5050

5151

52-
def _handle_params_as_deferred_annotations(
53-
param: inspect.Parameter, annotation_under_future: dict[str, Any], name: str
54-
) -> inspect.Parameter:
55-
"""Catches the case when type hints are stored as strings."""
56-
if isinstance(param.annotation, str):
57-
param = param.replace(annotation=annotation_under_future[name])
58-
return param
59-
60-
61-
def _add_unevaluated_items_to_fixed_len_tuple_schema(
62-
json_schema: dict[str, Any],
63-
) -> dict[str, Any]:
64-
"""Adds 'unevaluatedItems': False to schemas for fixed-length tuples.
65-
66-
For example, the schema for a parameter of type `tuple[float, float]` would
67-
be:
68-
{
69-
"type": "array",
70-
"prefixItems": [
71-
{
72-
"type": "number"
73-
},
74-
{
75-
"type": "number"
76-
},
77-
],
78-
"minItems": 2,
79-
"maxItems": 2,
80-
"unevaluatedItems": False
81-
}
82-
83-
"""
84-
if (
85-
json_schema.get('maxItems')
86-
and (
87-
json_schema.get('prefixItems')
88-
and len(json_schema['prefixItems']) == json_schema['maxItems']
89-
)
90-
and json_schema.get('type') == 'array'
91-
):
92-
json_schema['unevaluatedItems'] = False
93-
return json_schema
94-
95-
96-
def _raise_for_unsupported_param(
97-
param: inspect.Parameter,
98-
func_name: str,
99-
exception: Exception,
100-
) -> None:
101-
raise ValueError(
102-
f'Failed to parse the parameter {param} of function {func_name} for'
103-
' automatic function calling.Automatic function calling works best with'
104-
' simpler function signature schema, consider manually parsing your'
105-
f' function declaration for function {func_name}.'
106-
) from exception
107-
108-
109-
def _raise_for_invalid_enum_value(param: inspect.Parameter):
110-
"""Raises an error if the default value is not a valid enum value."""
111-
if inspect.isclass(param.annotation) and issubclass(param.annotation, Enum):
112-
if param.default is not inspect.Parameter.empty and param.default not in [
113-
e.value for e in param.annotation
114-
]:
115-
raise ValueError(
116-
f'Default value {param.default} is not a valid enum value for'
117-
f' {param.annotation}.'
118-
)
119-
120-
121-
def _generate_json_schema_for_parameter(
122-
param: inspect.Parameter,
123-
) -> dict[str, Any]:
124-
"""Generates a JSON schema for a parameter using pydantic.TypeAdapter."""
125-
126-
param_schema_adapter = pydantic.TypeAdapter(
127-
param.annotation,
128-
config=pydantic.ConfigDict(arbitrary_types_allowed=True),
129-
)
130-
json_schema_dict = param_schema_adapter.json_schema()
131-
json_schema_dict = _add_unevaluated_items_to_fixed_len_tuple_schema(
132-
json_schema_dict
133-
)
134-
return json_schema_dict
135-
136-
13752
def _is_builtin_primitive_or_compound(
13853
annotation: inspect.Parameter.annotation,
13954
) -> bool:

tests/unittests/tools/test_from_function_with_options.py

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from collections.abc import Sequence
1615
from typing import Any
1716
from typing import Dict
1817

@@ -193,52 +192,3 @@ def test_function() -> None:
193192
# VERTEX_AI should have response schema for None return
194193
assert declaration.response is not None
195194
assert declaration.response.type == types.Type.NULL
196-
197-
198-
def test_from_function_with_collections_type_parameter():
199-
"""Test from_function_with_options with collections type parameter."""
200-
201-
def test_function(
202-
artifact_key: str,
203-
input_edit_ids: Sequence[str],
204-
) -> str:
205-
"""Saves a sequence of edit IDs."""
206-
return f'Saved {len(input_edit_ids)} edit IDs for artifact {artifact_key}'
207-
208-
declaration = _automatic_function_calling_util.from_function_with_options(
209-
test_function, GoogleLLMVariant.VERTEX_AI
210-
)
211-
212-
assert declaration.name == 'test_function'
213-
assert declaration.parameters.type == types.Type.OBJECT
214-
assert (
215-
declaration.parameters.properties['artifact_key'].type
216-
== types.Type.STRING
217-
)
218-
assert (
219-
declaration.parameters.properties['input_edit_ids'].type
220-
== types.Type.ARRAY
221-
)
222-
assert (
223-
declaration.parameters.properties['input_edit_ids'].items.type
224-
== types.Type.STRING
225-
)
226-
assert declaration.response.type == types.Type.STRING
227-
228-
229-
def test_from_function_with_collections_return_type():
230-
"""Test from_function_with_options with collections return type."""
231-
232-
def test_function(
233-
names: list[str],
234-
) -> Sequence[str]:
235-
"""Returns a sequence of names."""
236-
return names
237-
238-
declaration = _automatic_function_calling_util.from_function_with_options(
239-
test_function, GoogleLLMVariant.VERTEX_AI
240-
)
241-
242-
assert declaration.name == 'test_function'
243-
assert declaration.response.type == types.Type.ARRAY
244-
assert declaration.response.items.type == types.Type.STRING

0 commit comments

Comments
 (0)