diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index f2ff3706..4eaa3474 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -536,12 +536,42 @@ def _generate_tensor_caster(name, is_data=False): def _generate_generated_dispatch_entries(operator): + optional_tensor_params = _find_optional_tensor_params(operator.name) + vector_tensor_params = _find_vector_tensor_params(operator.name) + vector_int64_params = _find_vector_int64_params(operator.name) + + def _is_optional_tensor(arg): + if arg.spelling in optional_tensor_params: + return True + + return "std::optional" in arg.type.spelling and "Tensor" in arg.type.spelling + + def _is_vector_tensor(arg): + if arg.spelling in vector_tensor_params: + return True + + return "std::vector" in arg.type.spelling and "Tensor" in arg.type.spelling + + def _is_vector_int64(arg): + return arg.spelling in vector_int64_params + def _generate_params(node): - return ", ".join( - f"{arg.type.spelling} {arg.spelling}" - for arg in node.get_arguments() - if arg.spelling != "stream" - ) + parts = [] + + for arg in node.get_arguments(): + if arg.spelling == "stream": + continue + + if _is_optional_tensor(arg): + parts.append(f"std::optional {arg.spelling}") + elif _is_vector_tensor(arg): + parts.append(f"std::vector {arg.spelling}") + elif _is_vector_int64(arg): + parts.append(f"std::vector {arg.spelling}") + else: + parts.append(f"{arg.type.spelling} {arg.spelling}") + + return ", ".join(parts) def _generate_arguments(node): return ", ".join(