Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
3cf4f55
Allow functions to not be inlined
justinchuby Mar 13, 2026
ebbf30f
Add functions
justinchuby Mar 13, 2026
b195105
Lint
justinchuby Mar 13, 2026
72639b1
Potential fix for pull request finding
justinchuby Mar 13, 2026
50322f9
Merge branch 'main' into justinchu/function-inline
Mar 23, 2026
5a9e00b
Update impl
Mar 24, 2026
0a3efad
Fix OpBuilder to properly extract _domain, _version, _outputs from kw…
Mar 24, 2026
98aecbf
Address comments
justinchuby Apr 17, 2026
320abb3
Update onnxscript/_internal/builder.py
justinchuby Apr 17, 2026
c94561e
optimizer: Prevent constant folding of DynamicQuantizeLinear (#2865)
Copilot Mar 26, 2026
3688398
Support None as op input in GraphBuilder (#2868)
justinchuby Mar 27, 2026
a9136fb
Unify failure-handling in rewrite-rule (#2866)
gramalingam Mar 27, 2026
3f6dc1d
Add parent/root tracking to GraphBuilder for subgraph Parameter reali…
gramalingam Apr 2, 2026
a8495d6
Fix non-deterministic rewriter behavior in multi-output pattern match…
Copilot Apr 2, 2026
acd79c8
chore(deps): bump actions/deploy-pages from 4 to 5 (#2869)
dependabot[bot] Apr 2, 2026
1b2e36b
chore(deps): bump codecov/codecov-action from 5 to 6 (#2871)
dependabot[bot] Apr 2, 2026
7b6cb95
chore(deps): bump onnx-weekly from 1.21.0.dev20260302 to 1.22.0.dev20…
dependabot[bot] Apr 3, 2026
958d50e
chore(deps): bump actions/configure-pages from 4 to 6 (#2870)
dependabot[bot] Apr 3, 2026
c400ca7
[torchlib] Add missing dtype parameter to aten_mean_dim (#2885)
linusjuni Apr 10, 2026
cfed3f5
Fix BatchNorm fusion producing invalid ONNX when Conv nodes share wei…
Copilot Apr 10, 2026
a46d489
Add input() and add_output() methods to GraphBuilder (#2828)
justinchuby Apr 10, 2026
29a5e09
Add fusion rule to remove Expand before broadcast-capable binary oper…
Copilot Apr 10, 2026
b5fe709
fix(fuse_batchnorm): support convtranpose + bn fusion with group != 1…
AyoubMDL Apr 10, 2026
49c05a2
fix: normalize cache key dtype to prevent initializer name collisions…
gramalingam Apr 14, 2026
98eaa1f
Handling initializers in GraphBuilder (#2889)
gramalingam Apr 16, 2026
da82e42
Update prefix naming
justinchuby Apr 17, 2026
79c2f39
Fix renaming
justinchuby Apr 17, 2026
f40cd7c
Merge branch 'main' into justinchu/function-inline
justinchuby Apr 17, 2026
0c240a6
Address PR review comments: delegate functions to root, fix annotatio…
justinchuby Apr 17, 2026
3f99c05
Remove onnx function
justinchuby Apr 17, 2026
5e2a946
overload
justinchuby Apr 17, 2026
91affd4
Merge branch 'main' into justinchu/function-inline
titaiwangms Apr 17, 2026
2c28041
Enhance GraphBuilder to support outer-scope values in inlined functio…
justinchuby Apr 17, 2026
461a10c
Merge branch 'main' into justinchu/function-inline
justinchuby Apr 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 125 additions & 27 deletions onnxscript/_internal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ def __init__(self, graph: ir.Graph, *, parent: GraphBuilder | None = None) -> No
# visible to subgraphs per the ONNX spec).
if parent is None:
self._constant_cache: dict[tuple[Any, ir.DataType | None], ir.Value] = {}
self._functions: dict[ir.OperatorIdentifier, ir.Function] = {}

def opset(self, domain: str, version: int = 1) -> OpBuilder:
"""Create an OpBuilder bound to the given domain and version."""
Expand All @@ -469,6 +470,10 @@ def root(self) -> GraphBuilder:
def graph(self) -> ir.Graph:
return self._graph

@property
def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]:
return self._root._functions

def initializer(
self, tensor: ir.TensorProtocol, name: str | None = None, *, qualify: bool = True
) -> ir.Value:
Expand Down Expand Up @@ -796,12 +801,12 @@ def call_op(
op_type: str,
inputs: Sequence[ir.Value | ir.TensorProtocol | None],
kwargs: dict[str, Any],
/,
domain: str = "",
version: int | None = None,
outputs: int | Sequence[str | ir.Value] = 1,
):
"""Create an ONNX node and add it to the graph, returning its output value(s)."""
domain = kwargs.pop("_domain", "")
version = kwargs.pop("_version", None)
outputs = kwargs.pop("_outputs", 1)

count = self.graph.num_nodes()
node_name = self._qualify_node_name(f"{op_type}_node_{count}")

Expand Down Expand Up @@ -833,7 +838,54 @@ def call_op(

def call(
self,
function,
function: ir.Function | onnxscript.OnnxFunction,
*args,
_outputs: int | Sequence[str | ir.Value] | None = None,
**kwargs,
):
"""Call a function as a single function node."""
if isinstance(function, ir.Function):
Comment thread
gramalingam marked this conversation as resolved.
graph = function.graph
Comment thread
justinchuby marked this conversation as resolved.
elif isinstance(function, onnxscript.OnnxFunction):
graph = function.graph()
function = function.function_ir
else:
raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction")

if _outputs is None:
_outputs = len(graph.outputs)
Comment thread
justinchuby marked this conversation as resolved.
output_values = self._adapt_outputs(_outputs, function.name)

# Adapt inputs similarly to call_op: promote constants/tensors to ir.Value.
adapted_args = [self._input_to_ir_value(arg) for arg in args]

count = self.graph.num_nodes()
node_name = self._qualify_node_name(f"{function.name}_node_{count}")

node = ir.node(
op_type=function.name,
inputs=adapted_args,
attributes=kwargs or None,
outputs=output_values,
domain=function.domain,
name=node_name,
overload=function.overload,
)
# Attach scope metadata to the node
node.metadata_props["namespace"] = self._build_namespace()
node.metadata_props["pkg.onnxscript.class_hierarchy"] = repr(self._scope_classes())
node.metadata_props["pkg.onnxscript.name_scopes"] = repr(self._scope_names())

self.add_node(node)
self._root._functions[function.identifier()] = function

if len(node.outputs) == 0:
return ()
return node.outputs if len(node.outputs) > 1 else node.outputs[0]
Comment thread
justinchuby marked this conversation as resolved.

def call_inline(
self,
function: ir.Function | onnxscript.OnnxFunction,
*args,
_outputs: Sequence[str] | None = None,
_prefix: str = "",
Expand All @@ -842,35 +894,56 @@ def call(
if isinstance(function, ir.Function):
graph = function.graph
elif isinstance(function, onnxscript.OnnxFunction):
graph = function.graph()
# TODO(justinchuby): Reason about support for outer-scope values in inlined function bodies.
graph = function.graph().clone(allow_outer_scope_values=True)
else:
raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction")
output_renaming: dict[str, str] = {}
if _outputs is not None:
if len(_outputs) != len(graph.outputs):
raise ValueError(
f"Number of provided output names {_outputs} does not match "
f"Number of provided output names {_outputs} does not match "
f"number of function outputs {len(graph.outputs)}."
)
for output, name in zip(graph.outputs, _outputs):
output_renaming[output.name] = self._qualify_value_name(name)
# Compute desired output names before pushing prefix scope so they
# are not affected by the prefix.
desired_output_names: list[str] = [
self._qualify_value_name(name) for name in _outputs
]
else:
for output in graph.outputs:
output_renaming[output.name] = self._qualify_value_name(output.name)
nodes, outputs = _inliner.instantiate(graph, args, kwargs)
desired_output_names = []

if _prefix:
self.push_module(_prefix)

count = self.graph.num_nodes()
node_name_prefix = self._qualify_node_name(f"{function.name}_node_{count}/")
nodes, outputs = _inliner.instantiate(graph, args, kwargs, prefix=node_name_prefix)
Comment thread
justinchuby marked this conversation as resolved.

# Track final output values so we can rename them separately.
# The inliner prefixes all names, which would prevent name-based lookup
# from matching the original graph output names.
output_value_ids = {id(v) for v in outputs if v is not None}

for node in nodes:
node.name = self._qualify_node_name(node.name)
for output in node.outputs:
if output.name:
if output.name in output_renaming:
output.name = output_renaming[output.name]
else:
output.name = self._qualify_value_name(output.name)
if output.name and id(output) not in output_value_ids:
output.name = self._qualify_value_name(output.name)
self.add_node(node)

# Apply names to final output values
if desired_output_names:
for output_val, name in zip(outputs, desired_output_names):
if output_val is not None:
output_val.name = name
else:
for output_val in outputs:
if output_val is not None and output_val.name:
output_val.name = self._qualify_value_name(output_val.name)

if _prefix:
self.pop_module()
Comment thread
justinchuby marked this conversation as resolved.
if len(outputs) == 0:
return ()
return outputs if len(outputs) > 1 else outputs[0]

def push_module(self, module: str, class_name: str = "") -> None:
Expand Down Expand Up @@ -962,27 +1035,52 @@ def version(self) -> int | None:
return self._version

def _call_op(self, op_type: str, inputs: Sequence[Any], kwargs: dict[str, Any]):
if "_domain" not in kwargs:
kwargs["_domain"] = self._domain
if self._version is not None and "_version" not in kwargs:
kwargs["_version"] = self._version
return self._builder.call_op(op_type, inputs, kwargs)
domain = kwargs.pop("_domain", self._domain)
version = kwargs.pop("_version", self._version)
outputs = kwargs.pop("_outputs", 1)
return self._builder.call_op(
op_type, inputs, kwargs, domain=domain, version=version, outputs=outputs
)

def __getattr__(self, op_type: str) -> Callable:
return lambda *args, **kwargs: self._call_op(op_type, args, kwargs)

def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
return self._builder.initializer(tensor, name)

Comment thread
justinchuby marked this conversation as resolved.
@property
def functions(self) -> dict[ir.OperatorIdentifier, ir.Function]:
Comment thread
justinchuby marked this conversation as resolved.
return self._builder.functions

def call(
self,
function,
*args,
_outputs: Sequence[str] | int | None = None,
**kwargs,
):
"""Call a function as a single function node.

Args:
function: The function to call (ir.Function or onnxscript.OnnxFunction).
*args: Positional arguments to pass to the function.
_outputs: Optional sequence of output names, or an integer specifying the number of outputs.
**kwargs: Keyword arguments to pass to the function.

Returns:
The output value(s) from the function call.
"""
return self._builder.call(function, *args, _outputs=_outputs, **kwargs)

def call_inline(
self,
function,
*args,
_outputs: Sequence[str] | None = None,
_prefix: str = "",
**kwargs,
):
"""Call a function and inline it into the graph.
"""Inline a function body into the current graph.

Args:
function: The function to call (ir.Function or onnxscript.OnnxFunction).
Expand All @@ -993,8 +1091,8 @@ def call(
**kwargs: Keyword arguments to pass to the function.

Returns:
The output value(s) from the function call.
The output value(s) from the inlined function body.
"""
return self._builder.call(
return self._builder.call_inline(
function, *args, _outputs=_outputs, _prefix=_prefix, **kwargs
)
Loading
Loading