Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions scripts/generate_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,12 +536,42 @@ def _generate_tensor_caster(name, is_data=False):


def _generate_generated_dispatch_entries(operator):
optional_tensor_params = _find_optional_tensor_params(operator.name)
vector_tensor_params = _find_vector_tensor_params(operator.name)
vector_int64_params = _find_vector_int64_params(operator.name)

def _is_optional_tensor(arg):
if arg.spelling in optional_tensor_params:
return True

return "std::optional" in arg.type.spelling and "Tensor" in arg.type.spelling

def _is_vector_tensor(arg):
if arg.spelling in vector_tensor_params:
return True

return "std::vector" in arg.type.spelling and "Tensor" in arg.type.spelling

def _is_vector_int64(arg):
return arg.spelling in vector_int64_params

def _generate_params(node):
return ", ".join(
f"{arg.type.spelling} {arg.spelling}"
for arg in node.get_arguments()
if arg.spelling != "stream"
)
parts = []

for arg in node.get_arguments():
if arg.spelling == "stream":
continue

if _is_optional_tensor(arg):
parts.append(f"std::optional<Tensor> {arg.spelling}")
elif _is_vector_tensor(arg):
parts.append(f"std::vector<Tensor> {arg.spelling}")
elif _is_vector_int64(arg):
parts.append(f"std::vector<int64_t> {arg.spelling}")
else:
parts.append(f"{arg.type.spelling} {arg.spelling}")

return ", ".join(parts)

def _generate_arguments(node):
return ", ".join(
Expand Down