From 9fb1c2a90bb375f7c2282f52b65a1a89739cdd94 Mon Sep 17 00:00:00 2001 From: Zhang Shuo <52872288+fuyou4546@users.noreply.github.com> Date: Thu, 21 May 2026 08:06:06 +0000 Subject: [PATCH] fix(scripts): apply regex fallback in dispatch entries generation --- scripts/generate_wrappers.py | 40 +++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index f2ff3706..4eaa3474 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -536,12 +536,42 @@ def _generate_tensor_caster(name, is_data=False): def _generate_generated_dispatch_entries(operator): + optional_tensor_params = _find_optional_tensor_params(operator.name) + vector_tensor_params = _find_vector_tensor_params(operator.name) + vector_int64_params = _find_vector_int64_params(operator.name) + + def _is_optional_tensor(arg): + if arg.spelling in optional_tensor_params: + return True + + return "std::optional" in arg.type.spelling and "Tensor" in arg.type.spelling + + def _is_vector_tensor(arg): + if arg.spelling in vector_tensor_params: + return True + + return "std::vector" in arg.type.spelling and "Tensor" in arg.type.spelling + + def _is_vector_int64(arg): + return arg.spelling in vector_int64_params + def _generate_params(node): - return ", ".join( - f"{arg.type.spelling} {arg.spelling}" - for arg in node.get_arguments() - if arg.spelling != "stream" - ) + parts = [] + + for arg in node.get_arguments(): + if arg.spelling == "stream": + continue + + if _is_optional_tensor(arg): + parts.append(f"std::optional {arg.spelling}") + elif _is_vector_tensor(arg): + parts.append(f"std::vector {arg.spelling}") + elif _is_vector_int64(arg): + parts.append(f"std::vector {arg.spelling}") + else: + parts.append(f"{arg.type.spelling} {arg.spelling}") + + return ", ".join(parts) def _generate_arguments(node): return ", ".join(