Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion backends/arm/_passes/arm_pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import traceback
from inspect import isclass
from typing import Optional, Sequence
from typing import List, Optional, Sequence, Tuple

import torch
import torch.fx
Expand All @@ -17,6 +17,10 @@
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.graph_module import (
_get_control_flow_submodules,
get_control_flow_submodules,
)

from torch._export.utils import (
get_buffer,
Expand All @@ -29,6 +33,7 @@
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.export.graph_signature import InputKind
from torch.fx import GraphModule, Node


def is_submodule_node(node: torch.fx.Node):
Expand Down Expand Up @@ -284,3 +289,45 @@ def set_node_arg(node: torch.fx.Node, i: int | str, value):
def get_output_dim_orders(graph_module):
output_node = graph_module.graph.output_node()
return [get_first_fake_tensor(node).dim_order() for node in output_node.args[0]]


def is_nested_control_flow_graph(graph_module: GraphModule) -> bool:
    """Return True if graph_module contains nested control flow.

    A graph is considered nested when at least one of its direct
    control-flow submodules itself holds further control-flow
    submodules.
    """
    return any(
        get_control_flow_submodules(inner_module)
        for _, inner_module, _ in get_control_flow_submodules(graph_module)
    )


def get_cond_while_submodules_nested(
    graph_module: GraphModule,
    apply_quantization: bool = False,
) -> List[Tuple[str, GraphModule, Node]]:
    """Recursively collect cond/while_loop submodules from a GraphModule.

    In nested control-flow graphs, FX records the submodule functions
    (true/false or cond/body) in reverse order compared to top-level
    graphs, so the argument indices are swapped when nesting is detected.
    This keeps cond (first) and body/true_fn (second) consistently
    identified across all nesting levels.

    When apply_quantization is True, only the cond_fn of while_loop
    nodes is returned.
    """
    if is_nested_control_flow_graph(graph_module):
        # Nested graphs list the branch submodules in reverse order.
        cond_arg_indices = [2, 1]  # [true_fn, false_fn], swapped
        while_arg_indices = [1, 0]  # [cond_fn, body_fn], swapped
    else:
        cond_arg_indices = [1, 2]  # [true_fn, false_fn]
        while_arg_indices = [0, 1]  # [cond_fn, body_fn]

    if apply_quantization:
        # Quantization only needs while_loop's cond_fn, which is always
        # the first entry regardless of nesting.
        while_arg_indices = while_arg_indices[:1]

    op_to_indices = {
        torch.ops.higher_order.cond: cond_arg_indices,
        torch.ops.higher_order.while_loop: while_arg_indices,
    }
    return _get_control_flow_submodules(graph_module, op_to_indices)
35 changes: 27 additions & 8 deletions backends/arm/_passes/control_flow_const_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

from typing import Set, Type

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import (
get_cond_while_submodules_nested,
is_submodule_node,
)
from executorch.backends.transforms.utils import is_get_attr_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_cond_while_submodules

from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule

Expand All @@ -27,15 +30,23 @@ class ControlFlowConstInlinePass(ArmPass):

_passes_required_after: Set[Type[ExportPass]] = set()

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
_targeted_ops = {
torch.ops.higher_order.cond,
torch.ops.higher_order.while_loop,
}

def call(self, graph_module: GraphModule) -> PassResult:
def _convert_getattr(self, graph_module):
modified = False

for _, submodule, _ in get_cond_while_submodules(graph_module):
for _, submodule, _ in get_cond_while_submodules_nested(graph_module):
for submodule_node in submodule.graph.nodes:
if is_get_attr_node(submodule_node):
if submodule_node.target in self._targeted_ops:
self._convert_getattr(submodule)

                # For nested control flow, a "node" may actually be a GraphModule.
                # Ensure we are only checking for nodes here.
if is_get_attr_node(submodule_node) and not is_submodule_node(
submodule_node
):
val = getattr(
submodule_node.graph.owning_module, submodule_node.target
)
Expand All @@ -53,6 +64,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
submodule_node.replace_all_uses_with(const_node)
submodule.graph.erase_node(submodule_node)
modified = True
return modified

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def call(self, graph_module: GraphModule) -> PassResult:

modified = self._convert_getattr(graph_module)

if modified:
graph_module.recompile()
Expand Down
9 changes: 5 additions & 4 deletions backends/arm/_passes/scalars_to_attribute_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm._passes.arm_pass_utils import (
get_cond_while_submodules_nested,
get_first_fake_tensor,
)
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.exir.graph_module import get_cond_while_submodules

from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule, Node
from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix
Expand Down Expand Up @@ -99,7 +100,7 @@ def handle_control_nodes(self, node: Node, graph_module: GraphModule) -> None:
"""Apply scalar argument conversion on subgraphs of control-flow
nodes.
"""
for _, submodule, _ in get_cond_while_submodules(graph_module):
for _, submodule, _ in get_cond_while_submodules_nested(graph_module):
for submodule_node in submodule.graph.nodes:
# use aten.full.default for scalar constants in control subgraphs
self._convert_scalar_args(submodule, submodule_node)
Expand Down
73 changes: 44 additions & 29 deletions backends/arm/operator_support/control_flow_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,43 @@


def _fully_partitioned(submodule: fx.GraphModule) -> bool:
"""Check that all nested control-flow ops within this submodule are also
fully partitioned.
"""
partition_tag = None

for submodule_node in submodule.graph.nodes:
if submodule_node.op == "call_function":
# Input Q ops and output DQ ops will be de-tagged even if the submodule is fully supported.
if (
submodule_node.target in Q_OPS
and list(submodule_node.all_input_nodes)[0].op == "placeholder"
):
continue
if (
submodule_node.target in DQ_OPS
and list(submodule_node.users)[0].op == "output"
):
continue
if "delegation_tag" not in submodule_node.meta:
return False
if partition_tag is None:
partition_tag = submodule_node.meta["delegation_tag"]
elif submodule_node.meta["delegation_tag"] != partition_tag:
return False
if submodule_node.target in ControlFlowOpSupported._targeted_ops:
if _submodules_fully_partitioned(submodule_node, submodule):
return True

if submodule_node.op != "call_function":
continue
# skip no-op quantize/dequantize boundary
if (
submodule_node.target in Q_OPS
and list(submodule_node.all_input_nodes)[0].op == "placeholder"
):
continue
if (
submodule_node.target in DQ_OPS
and list(submodule_node.users)[0].op == "output"
):
continue

if "delegation_tag" not in submodule_node.meta:
return False

if partition_tag is None:
partition_tag = submodule_node.meta["delegation_tag"]

elif submodule_node.meta["delegation_tag"] != partition_tag:
return False

return True


def _submodules_fully_partitioned(
node: fx.Node, exported_program: ExportedProgram
) -> bool:
def _submodules_fully_partitioned(node: fx.Node, graph_module: fx.GraphModule) -> bool:
"""Returns whether the submodule arguments to a cond node were fully
partitioned.

Expand All @@ -61,20 +72,19 @@ def _submodules_fully_partitioned(
raise ValueError(f"Unexpected target: {node.target}")
cond_submodules = (
(
exported_program.graph_module.get_submodule(
str(cast(torch.fx.Node, submodule_node).target)
),
graph_module.get_submodule(str(cast(torch.fx.Node, submodule_node).target)),
cast(torch.fx.Node, submodule_node),
)
for submodule_node in submodule_args
)
for submodule, submodule_node in cond_submodules:
submodule = cast(torch.fx.GraphModule, submodule)

if _fully_partitioned(submodule):
submodule_node.meta["val"] = submodule.graph.output_node().meta["val"]
else:
if not _fully_partitioned(submodule):
return False
else:
submodule_node.meta["val"] = submodule.graph.output_node().meta["val"]

return True


Expand Down Expand Up @@ -105,6 +115,7 @@ def __init__(
def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:

if is_submodule_node(node):
if not _tosa_spec_supports_cf(self.tosa_spec):
self.reporter.report_reject(
Expand All @@ -118,7 +129,9 @@ def is_node_supported(
node, f"Submodule had unsupported user {user}"
)
return False
if not _submodules_fully_partitioned(user, self.exported_program):
if not _submodules_fully_partitioned(
user, self.exported_program.graph_module
):
self.reporter.report_reject(
node, "One submodule was not fully partitioned"
)
Expand Down Expand Up @@ -161,7 +174,9 @@ def is_node_supported(
)
return False

if not _submodules_fully_partitioned(node, self.exported_program):
if not _submodules_fully_partitioned(
node, self.exported_program.graph_module
):
self.reporter.report_reject(
node, "Submodule was not fully partitioned."
)
Expand Down
62 changes: 33 additions & 29 deletions backends/arm/quantizer/arm_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@
from executorch.backends.arm.common.arm_compile_spec import (
ArmCompileSpec,
) # isort: skip
from executorch.backends.arm._passes.arm_pass_utils import (
get_cond_while_submodules_nested,
is_submodule_node,
)
from executorch.backends.arm.vgf import VgfCompileSpec
from executorch.exir.graph_module import _get_control_flow_submodules

from torch.fx import GraphModule, Node
from torchao.quantization.pt2e import (
Expand Down Expand Up @@ -639,27 +642,6 @@ def validate(self, model: GraphModule) -> None:
f"Quantizer detected operator {node.name} with different device inputs: {devices}."
)

@staticmethod
def _get_submodules_not_handled_by_torchao(
graph_module: GraphModule,
):
"""Returns control flow submodules that torchao's
prepare_pt2e/convert_pt2e do not handle natively. torchao now
recursively handles while_loop body_fn.

(arg 1), so we only need to manually handle:
- cond true/false branches (args 1, 2)
- while_loop cond_fn (arg 0)

"""
return _get_control_flow_submodules(
graph_module,
{
torch.ops.higher_order.cond: [1, 2],
torch.ops.higher_order.while_loop: [0],
},
)

def quantize_with_submodules(
self,
model: GraphModule,
Expand Down Expand Up @@ -691,18 +673,40 @@ def quantize_with_submodules(
prepare_fn = prepare_qat_pt2e if is_qat else prepare_pt2e

prepared = prepare_fn(model, self)
for name, submodule, _ in self._get_submodules_not_handled_by_torchao(prepared):
# Prepare conditional submodules (e.g., if/while bodies)
# prepare only cond branches and while_loop cond_fn
for name, submodule, _ in get_cond_while_submodules_nested(
prepared, apply_quantization=True
):
prepared.set_submodule(name, prepare_fn(submodule, self), strict=True)
for submodule_node in submodule.graph.nodes:
if is_submodule_node(submodule_node):
for nested_name, nested_sub, _ in get_cond_while_submodules_nested(
submodule, apply_quantization=True
):
prepared.set_submodule(
nested_name, prepare_fn(nested_sub, self), strict=True
)

for inp in calibration_samples:
prepared(*inp)

for name, submodule, _ in self._get_submodules_not_handled_by_torchao(prepared):
prepared.set_submodule(
name, convert_pt2e(submodule, fold_quantize=fold_quantize), strict=True
)
converted = convert_pt2e(prepared, fold_quantize=fold_quantize)
        # Convert conditional submodules (e.g., if/while bodies)
        # convert only cond branches and while_loop cond_fn
for _, submodule, _ in get_cond_while_submodules_nested(
prepared, apply_quantization=True
):
converted = convert_pt2e(submodule)
for submodule_node in submodule.graph.nodes:
if is_submodule_node(submodule_node):
for nested_name, nested_sub, _ in get_cond_while_submodules_nested(
submodule, apply_quantization=True
):
converted.set_submodule(
nested_name, convert_pt2e(nested_sub), strict=True
)

return converted
return convert_pt2e(prepared)


class EthosUQuantizer(TOSAQuantizer):
Expand Down
Loading
Loading