From 0bba44a7000f8d386f7af96068ed0e6dd61e5bfe Mon Sep 17 00:00:00 2001 From: Github Executorch Date: Tue, 3 Mar 2026 17:53:16 -0800 Subject: [PATCH 1/3] [Cortex-M]: Add int8 I/O quantization to Cortex-M export path Apply QuantizeInputs and QuantizeOutputs passes in the Cortex-M compilation path to strip the float-in/float-out wrapper from quantized models. This produces a fully int8 model that accepts and returns int8 tensors directly. The passes are applied after to_edge_transform_and_lower but before CortexMPassManager, since the latter renames quantized_decomposed ops to cortex_m variants which the I/O passes cannot recognize. --- examples/arm/aot_arm_compiler.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 61fae8ff8fc..551337ec32f 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -47,6 +47,8 @@ from executorch.devtools.backend_debug import get_delegation_info from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite +from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs + from executorch.exir import ( EdgeCompileConfig, ExecutorchBackendConfig, @@ -860,6 +862,17 @@ def _to_channels_last(x): ), ) + # Strip the float I/O wrapper from the quantized model to produce + # fully int8 inputs and outputs. This must run before CortexMPassManager + # which renames quantized_decomposed ops to cortex_m variants. + if args.quantize: + print("Applying passes to create a fully int8 quantized model...") + + edge = edge.transform([ + QuantizeInputs(edge, [0]), + QuantizeOutputs(edge, [0]), + ]) + pass_manager = CortexMPassManager(edge.exported_program()) edge._edge_programs["forward"] = pass_manager.transform() From dd8f4e96ec4d2a977b3f4f97b5cc87b2d2fa0dd3 Mon Sep 17 00:00:00 2001 From: Github Executorch Date: Mon, 9 Mar 2026 23:54:11 -0700 Subject: [PATCH 2/3] Move graph HTML visualization to devtools/visualization/ Add a self-contained HTML graph visualizer as devtools/visualization/html_visualization.py, complementing the existing Model Explorer-based visualization. Generates interactive Cytoscape.js HTML files from .pt2, .pte, .etrecord, and multi-pass trace .json files with no server or external dependencies required. Key changes from the original repo-root visualize_graph.py: - Fix broken ETRecord import (executorch.sdk -> executorch.devtools.etrecord) and rewrite extract_from_etrecord to use correct ETRecord attributes (edge_dialect_program, graph_map) instead of non-existent graph_module - Replace Cortex-M-specific "cortex_m" category with generic "backend" category and configurable _BACKEND_OP_PREFIXES for all backends - Merge duplicate extract_from_pte / extract_from_pte_enhanced into one function with bounds checking and generic delegate blob analysis - Add escapeHtml to single-pass HTML template (XSS fix) - Fix O(n*m) edge filter to O(n) set lookup - Remove dead code (extract_delegated_graph, unreachable PTE branches) - Replace Arm-specific extract_arm_delegate_info with backend-agnostic _extract_delegate_blob_info - Make __init__.py imports from visualization_utils conditional so html_visualization works without model_explorer installed The old visualize_graph.py becomes a thin deprecation shim. Authored with Claude. --- devtools/visualization/TARGETS | 2 + devtools/visualization/__init__.py | 26 +- devtools/visualization/html_visualization.py | 1030 ++++++++++++++++++ trace_cortex_m_passes.py | 445 ++++++++ visualize_graph.py | 42 + 5 files changed, 1539 insertions(+), 6 deletions(-) create mode 100644 devtools/visualization/html_visualization.py create mode 100644 trace_cortex_m_passes.py create mode 100644 visualize_graph.py diff --git a/devtools/visualization/TARGETS b/devtools/visualization/TARGETS index 88a5ba77107..87c86f80551 100644 --- a/devtools/visualization/TARGETS +++ b/devtools/visualization/TARGETS @@ -9,12 +9,14 @@ runtime.python_library( srcs = [ "__init__.py", "visualization_utils.py", + "html_visualization.py", ], visibility = ["PUBLIC"], deps = [ "//caffe2:torch", "//executorch/exir:lib", "//executorch/exir/_serialize:lib", + "//executorch/devtools/etrecord:lib", ], ) diff --git a/devtools/visualization/__init__.py b/devtools/visualization/__init__.py index 8e91d7ffdb2..be0f2f8959d 100644 --- a/devtools/visualization/__init__.py +++ b/devtools/visualization/__init__.py @@ -3,11 +3,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# model_explorer-based visualization (requires model_explorer pip package) +try: + from executorch.devtools.visualization.visualization_utils import ( # noqa: F401 + ModelExplorerServer, + SingletonModelExplorerServer, + visualize, + visualize_graph, + visualize_model_explorer, + ) +except ImportError: + pass -from executorch.devtools.visualization.visualization_utils import ( # noqa: F401 - ModelExplorerServer, - SingletonModelExplorerServer, - visualize, - visualize_graph, - visualize_model_explorer, +# Self-contained HTML visualization (no external dependencies) +from executorch.devtools.visualization.html_visualization import ( # noqa: F401 + extract_from_exported_program, + extract_from_etrecord, + extract_from_pt2, + extract_from_pte, + generate_html, + generate_multi_pass_html, + visualize_edge_manager, ) diff --git a/devtools/visualization/html_visualization.py b/devtools/visualization/html_visualization.py new file mode 100644 index 00000000000..adea83e637a --- /dev/null +++ b/devtools/visualization/html_visualization.py @@ -0,0 +1,1030 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +"""Generate an interactive HTML visualization of an ExecuTorch graph. + +Supports pre-serialization (.pt2), post-serialization (.pte), ETRecord +(.etrecord), and multi-pass trace (.json) files. Produces a self-contained +HTML file using Cytoscape.js with dagre layout. + +Usage: + python3 -m executorch.devtools.visualization.html_visualization model.pt2 + python3 -m executorch.devtools.visualization.html_visualization model.pte -o graph.html + python3 -m executorch.devtools.visualization.html_visualization trace.json -o passes.html + +Authored with Claude. +""" + +import argparse +import json +import os +import re +import sys +from typing import Any, Dict, List, Optional + + +CATEGORY_COLORS = { + "backend": "#4caf50", + "aten_compute": "#2196f3", + "quantize": "#ff9800", + "memory": "#9e9e9e", + "placeholder": "#03a9f4", + "param": "#78909c", + "delegate": "#ab47bc", +} + +# Backend custom-op prefixes. Ops containing these (case-insensitive) are +# categorized as "backend" rather than generic compute. Extend this tuple +# when new backends register custom op namespaces. +_BACKEND_OP_PREFIXES = ( + "cortex_m", + "cadence", + "qaisw", +) + + +def categorize_node(op_name: str) -> str: + name = op_name.lower() + if any(prefix in name for prefix in _BACKEND_OP_PREFIXES): + return "backend" + if any( + k in name + for k in ( + "quantize_per_tensor", + "dequantize_per_", + "quantize_per_channel", + "dequantize_per_channel", + ) + ): + return "quantize" + if any( + k in name + for k in ( + "view", + "clone", + "permute", + "slice", + "copy", + "expand", + "reshape", + "t_copy", + "unsqueeze", + "squeeze", + ) + ): + return "memory" + if any(k in name for k in ("placeholder", "output", "getitem", "get_attr")): + return "placeholder" + if "delegate" in name: + return "delegate" + return "aten_compute" + + +def _make_label(op_name: str) -> str: + name = op_name.split("::")[-1] if "::" in op_name else op_name + if "." in name: + name = name.rsplit(".", 1)[0] + if len(name) > 30: + name = name[:27] + "..." + return name + + +def extract_from_exported_program(ep, model_name: str) -> dict: + """Walk an in-memory ExportedProgram's graph and extract visualization data.""" + graph = ep.graph + + nodes = [] + edges = [] + node_map = {} + + for node in graph.nodes: + node_id = node.name + op_name = node.op + if node.op == "call_function": + op_name = str(node.target) + elif node.op == "call_method": + op_name = node.target + + details = {"op": node.op, "target": str(getattr(node, "target", ""))} + meta_val = node.meta.get("val") + if meta_val is not None: + if hasattr(meta_val, "shape"): + details["shape"] = str(list(meta_val.shape)) + details["dtype"] = str(meta_val.dtype) + elif isinstance(meta_val, (list, tuple)): + shapes = [] + for v in meta_val: + if hasattr(v, "shape"): + shapes.append(f"{list(v.shape)} {v.dtype}") + if shapes: + details["shapes"] = shapes + + category = categorize_node(op_name) + label = _make_label(op_name) + + if node.op == "placeholder": + target = str(getattr(node, "target", "")) + if target.startswith("p_") and target.endswith("_weight"): + label = "weight" + category = "param" + elif target.startswith("p_") and target.endswith("_bias"): + label = "bias" + category = "param" + elif target.startswith("b_") and "running_mean" in target: + label = "bn_mean" + category = "param" + elif target.startswith("b_") and "running_var" in target: + label = "bn_var" + category = "param" + elif target == "x" or not target.startswith(("p_", "b_")): + label = "input" + elif node.op == "output": + label = "output" + + node_map[node_id] = len(nodes) + nodes.append( + { + "id": node_id, + "label": label, + "w": max(len(label) * 8 + 16, 60), + "category": category, + "op_name": op_name, + "details": details, + } + ) + + for arg in node.args: + if hasattr(arg, "name") and arg.name in node_map: + edges.append({"source": arg.name, "target": node_id}) + elif isinstance(arg, (list, tuple)): + for a in arg: + if hasattr(a, "name") and a.name in node_map: + edges.append({"source": a.name, "target": node_id}) + + category_counts = {} + for n in nodes: + cat = n["category"] + category_counts[cat] = category_counts.get(cat, 0) + 1 + + return { + "metadata": { + "model_name": model_name, + "source_type": "pt2", + "total_nodes": len(nodes), + "category_counts": category_counts, + }, + "nodes": nodes, + "edges": edges, + } + + +def extract_from_pt2(path: str) -> dict: + import torch + + ep = torch.export.load(path) + return extract_from_exported_program(ep, os.path.basename(path)) + + +def _extract_delegate_blob_info( + delegate_data: bytes, +) -> Optional[Dict[str, Any]]: + """Extract metadata from a delegate blob. Backend-agnostic.""" + if len(delegate_data) < 8: + return None + info: Dict[str, Any] = {"blob_size_bytes": len(delegate_data)} + op_patterns = re.findall(rb"(?:tosa|aten|xnnpack|qnn)\.\w+", delegate_data) + if op_patterns: + info["detected_ops"] = sorted(set(op.decode() for op in op_patterns[:20])) + return info + + +def extract_from_pte(path: str) -> dict: + from executorch.exir._serialize._program import deserialize_pte_binary + from executorch.exir.schema import DelegateCall, KernelCall, Tensor + + with open(path, "rb") as f: + data = f.read() + + pte_file = deserialize_pte_binary(data) + plan = pte_file.program.execution_plan[0] + + nodes: List[dict] = [] + edges: List[dict] = [] + value_producers: Dict[int, str] = {} + + # Extract delegate blobs for analysis + delegate_info_map: Dict[int, dict] = {} + for idx, delegate in enumerate(plan.delegates): + if hasattr(delegate, "processed") and delegate.processed: + blob = getattr(delegate.processed, "data", None) + if blob: + info = _extract_delegate_blob_info(bytes(blob)) + if info: + delegate_info_map[idx] = info + + for chain_idx, chain in enumerate(plan.chains): + for instr_idx, instr in enumerate(chain.instructions): + args = instr.instr_args + + if isinstance(args, KernelCall): + op = plan.operators[args.op_index] + op_name = f"{op.name}.{op.overload}" if op.overload else op.name + node_id = f"k_{chain_idx}_{instr_idx}" + + details: Dict[str, Any] = {"op_name": op_name} + input_tensors = [] + output_tensors = [] + for val_idx in args.args: + if val_idx < len(plan.values): + val = plan.values[val_idx].val + if isinstance(val, Tensor): + shape_str = ( + f"[{','.join(str(s) for s in val.sizes)}]" + ) + dtype_str = ( + val.scalar_type.name + if hasattr(val.scalar_type, "name") + else str(val.scalar_type) + ) + info_str = f"{shape_str} {dtype_str}" + if val_idx in value_producers: + input_tensors.append(info_str) + else: + output_tensors.append(info_str) + + if input_tensors: + details["inputs"] = input_tensors + if output_tensors: + details["outputs"] = output_tensors + + category = categorize_node(op_name) + label = _make_label(op_name) + + nodes.append( + { + "id": node_id, + "label": label, + "w": max(len(label) * 8 + 16, 60), + "category": category, + "op_name": op_name, + "details": details, + } + ) + + for val_idx in args.args: + if val_idx in value_producers: + edges.append( + {"source": value_producers[val_idx], "target": node_id} + ) + + for val_idx in args.args: + value_producers[val_idx] = node_id + + elif isinstance(args, DelegateCall): + node_id = f"d_{chain_idx}_{instr_idx}" + delegate = plan.delegates[args.delegate_index] + + details = { + "delegate_id": delegate.id, + "delegate_index": args.delegate_index, + } + if args.delegate_index in delegate_info_map: + details.update(delegate_info_map[args.delegate_index]) + + label = delegate.id + if len(label) > 25: + label = label[:22] + "..." + + nodes.append( + { + "id": node_id, + "label": label, + "w": max(len(label) * 8 + 20, 100), + "category": "delegate", + "op_name": f"delegate:{delegate.id}", + "details": details, + } + ) + + for val_idx in args.args: + if val_idx in value_producers: + edges.append( + {"source": value_producers[val_idx], "target": node_id} + ) + + for val_idx in args.args: + value_producers[val_idx] = node_id + + for i, idx in enumerate(plan.inputs): + node_id = f"input_{i}" + val = plan.values[idx].val + details = {"value_index": idx} + if isinstance(val, Tensor): + details["shape"] = list(val.sizes) + details["dtype"] = ( + val.scalar_type.name + if hasattr(val.scalar_type, "name") + else str(val.scalar_type) + ) + + nodes.insert( + 0, + { + "id": node_id, + "label": f"input_{i}", + "w": 70, + "category": "placeholder", + "op_name": "input", + "details": details, + }, + ) + value_producers[idx] = node_id + + for i, idx in enumerate(plan.outputs): + node_id = f"output_{i}" + val = plan.values[idx].val + details = {"value_index": idx} + if isinstance(val, Tensor): + details["shape"] = list(val.sizes) + details["dtype"] = ( + val.scalar_type.name + if hasattr(val.scalar_type, "name") + else str(val.scalar_type) + ) + nodes.append( + { + "id": node_id, + "label": f"output_{i}", + "w": 80, + "category": "placeholder", + "op_name": "output", + "details": details, + } + ) + if idx in value_producers: + edges.append({"source": value_producers[idx], "target": node_id}) + + # Filter edges to only reference existing nodes + node_ids = {n["id"] for n in nodes} + edges = [e for e in edges if e["source"] in node_ids and e["target"] in node_ids] + + category_counts = {} + for n in nodes: + cat = n["category"] + category_counts[cat] = category_counts.get(cat, 0) + 1 + + return { + "metadata": { + "model_name": os.path.basename(path), + "source_type": "pte", + "total_nodes": len(nodes), + "category_counts": category_counts, + }, + "nodes": nodes, + "edges": edges, + } + + +def extract_from_trace_json(path: str) -> dict: + """Load a multi-pass trace JSON.""" + with open(path) as f: + data = json.load(f) + if "passes" not in data: + raise ValueError(f"{path} does not contain a 'passes' key") + return data + + +def extract_from_etrecord(path: str) -> dict: + """Extract visualization data from an ETRecord file.""" + from executorch.devtools.etrecord import parse_etrecord + + etrecord = parse_etrecord(path) + passes = [] + + # edge_dialect_program can be a single ExportedProgram or a dict of them + edp = etrecord.edge_dialect_program + if edp is not None: + if isinstance(edp, dict): + for method_name, ep in edp.items(): + passes.append( + extract_from_exported_program(ep, f"Edge Dialect: {method_name}") + ) + else: + passes.append( + extract_from_exported_program(edp, "Edge Dialect (pre-delegation)") + ) + + # graph_map contains additional stages + if etrecord.graph_map: + for name, ep in etrecord.graph_map.items(): + passes.append(extract_from_exported_program(ep, name)) + + # Fallback to exported_program if nothing else is available + if etrecord.exported_program is not None and not passes: + passes.append( + extract_from_exported_program(etrecord.exported_program, "Exported Program") + ) + + if len(passes) == 0: + raise ValueError(f"No graph data found in {path}") + + if len(passes) == 1: + return passes[0] + + return { + "model_name": os.path.basename(path), + "passes": passes, + } + + +# --------------------------------------------------------------------------- +# Single-pass HTML template +# --------------------------------------------------------------------------- + +HTML_TEMPLATE = """ + + + +ExecuTorch Graph: $$MODEL_NAME$$ + + + + +
+
+
+ + + +
+
+
+ × +

+
+
+
+ + + + + +""" + + +# --------------------------------------------------------------------------- +# Multi-pass HTML template +# --------------------------------------------------------------------------- + +MULTI_PASS_HTML_TEMPLATE = """ + + + +ExecuTorch Pass Trace: $$MODEL_NAME$$ + + + + +
+ + + + Nodes: 0 + +
+
+
+
+
Traceback
+
+
+
+ + + +
+
+
+ × +

+
+
+
+ + + + + +""" + + +def generate_html(graph_data: dict, output_path: str) -> None: + html = HTML_TEMPLATE + html = html.replace("$$MODEL_NAME$$", graph_data["metadata"]["model_name"]) + html = html.replace("$$GRAPH_JSON$$", json.dumps(graph_data)) + html = html.replace("$$COLORS_JSON$$", json.dumps(CATEGORY_COLORS)) + with open(output_path, "w") as f: + f.write(html) + print( + f"Wrote {output_path} " + f"({graph_data['metadata']['total_nodes']} nodes, " + f"{len(graph_data['edges'])} edges)" + ) + + +def generate_multi_pass_html(trace_data: dict, output_path: str) -> None: + model_name = trace_data.get("model_name", "unknown") + passes = trace_data["passes"] + + html = MULTI_PASS_HTML_TEMPLATE + html = html.replace("$$MODEL_NAME$$", model_name) + html = html.replace("$$PASSES_JSON$$", json.dumps(passes)) + html = html.replace("$$COLORS_JSON$$", json.dumps(CATEGORY_COLORS)) + with open(output_path, "w") as f: + f.write(html) + + total_nodes = sum(p["metadata"]["total_nodes"] for p in passes) + print( + f"Wrote {output_path} ({len(passes)} passes, " + f"{total_nodes} total nodes across all snapshots)" + ) + + +def visualize_edge_manager(edge_manager, output_path: str = "graph.html") -> str: + """Visualize an EdgeProgramManager as HTML before to_executorch(). + + Usage in your export script: + from executorch.devtools.visualization.html_visualization import ( + visualize_edge_manager, + ) + + edge_manager = to_edge_transform_and_lower(...) + visualize_edge_manager(edge_manager, "my_model_graph.html") + et_program = edge_manager.to_executorch() + """ + ep = edge_manager.exported_program() + graph_data = extract_from_exported_program(ep, "Edge Manager Graph") + generate_html(graph_data, output_path) + return output_path + + +def main(): + parser = argparse.ArgumentParser( + description="Visualize ExecuTorch graph as interactive HTML" + ) + parser.add_argument("input", help="Path to .pt2, .pte, .etrecord, or .json file") + parser.add_argument("-o", "--output", default=None) + args = parser.parse_args() + + output = args.output or os.path.splitext(args.input)[0] + ".html" + ext = os.path.splitext(args.input)[1].lower() + + if ext == ".json": + trace_data = extract_from_trace_json(args.input) + generate_multi_pass_html(trace_data, output) + elif ext == ".pt2": + graph_data = extract_from_pt2(args.input) + generate_html(graph_data, output) + elif ext == ".pte": + graph_data = extract_from_pte(args.input) + generate_html(graph_data, output) + elif ext in (".etrecord", ".bin"): + data = extract_from_etrecord(args.input) + if "passes" in data: + generate_multi_pass_html(data, output) + else: + generate_html(data, output) + else: + print(f"Error: unsupported '{ext}'", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/trace_cortex_m_passes.py b/trace_cortex_m_passes.py new file mode 100644 index 00000000000..4f4a6a30d92 --- /dev/null +++ b/trace_cortex_m_passes.py @@ -0,0 +1,445 @@ +"""Trace the Cortex-M compilation pipeline pass-by-pass, capturing graph snapshots. + +Runs quantization, export, to_edge, then each CortexMPassManager pass individually, +saving a JSON file with per-pass graph snapshots for use with visualize_graph.py. + +Usage: + python3 trace_cortex_m_passes.py --model mobilenet_v2 -o mv2_trace.json + +Authored with Claude. +""" + +import argparse +import inspect +import json +import os +import sys +import traceback + +import torch +from executorch.backends.arm.constants import DQ_OPS, Q_OPS +from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager +from executorch.backends.cortex_m.quantizer.quantizer import CortexMQuantizer +from executorch.exir import EdgeCompileConfig, to_edge +from executorch.exir.pass_base import ExportPass +from executorch.exir.program._program import _transform +from torch.export import export +from torchao.quantization.pt2e.export_utils import model_is_exported +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + + +CATEGORY_COLORS = { + "backend": "#4caf50", + "aten_compute": "#2196f3", + "quantize": "#ff9800", + "memory": "#9e9e9e", + "placeholder": "#03a9f4", + "param": "#78909c", + "delegate": "#ab47bc", +} + + +def categorize_node(op_name: str) -> str: + name = op_name.lower() + if "cortex_m" in name: + return "backend" + if any( + k in name + for k in ( + "quantize_per_tensor", + "dequantize_per_", + "quantize_per_channel", + "dequantize_per_channel", + ) + ): + return "quantize" + if any( + k in name + for k in ( + "view", + "clone", + "permute", + "slice", + "copy", + "expand", + "reshape", + "t_copy", + "unsqueeze", + "squeeze", + ) + ): + return "memory" + if any(k in name for k in ("placeholder", "output", "getitem", "get_attr")): + return "placeholder" + if "delegate" in name: + return "delegate" + return "aten_compute" + + +def _make_label(op_name: str) -> str: + name = op_name.split("::")[-1] if "::" in op_name else op_name + if "." in name: + name = name.rsplit(".", 1)[0] + if len(name) > 30: + name = name[:27] + "..." + return name + + +def _serialize_qparams(qparams_dict): + """Serialize a dict[int, QuantArgs] to JSON-safe form.""" + if not qparams_dict: + return None + result = {} + for idx, qa in qparams_dict.items(): + result[str(idx)] = { + "scale": qa.scale, + "zp": qa.zp, + "qmin": qa.qmin, + "qmax": qa.qmax, + "dtype": str(qa.dtype), + "axis": qa.axis, + "per_channel": qa.per_channel, + } + return result + + +def detect_qdq_groups(graph) -> dict: + """Detect DQ -> compute_op -> Q chains and assign group IDs. + + Returns {node_name: group_id} for nodes in DQ->op->Q chains. + """ + groups = {} + group_id = 0 + + q_op_set = set() + dq_op_set = set() + for node in graph.nodes: + if node.op == "call_function": + target = node.target + if target in Q_OPS: + q_op_set.add(target) + if target in DQ_OPS: + dq_op_set.add(target) + + for node in graph.nodes: + if node.op != "call_function" or node.target not in Q_OPS: + continue + # This is a Q node. Find its compute input. + if not node.args: + continue + compute_node = node.args[0] + if not hasattr(compute_node, "name"): + continue + if not hasattr(compute_node, "target"): + continue + # Skip if compute_node is itself a Q/DQ op + if compute_node.target in Q_OPS or compute_node.target in DQ_OPS: + continue + + # Find DQ nodes feeding the compute node + dq_nodes = [] + for arg in compute_node.args: + if hasattr(arg, "target") and arg.target in DQ_OPS: + dq_nodes.append(arg) + elif isinstance(arg, (list, tuple)): + for a in arg: + if hasattr(a, "target") and a.target in DQ_OPS: + dq_nodes.append(a) + + if not dq_nodes: + continue + + # Assign all members the same group_id + members = [n.name for n in dq_nodes] + [compute_node.name, node.name] + for m in members: + if m not in groups: + groups[m] = group_id + group_id += 1 + + return groups + + +def extract_from_exported_program(ep, stage_name, qdq_groups=None): + """Walk an ExportedProgram's graph and extract visualization data.""" + graph = ep.graph + + nodes = [] + edges = [] + node_map = {} + + for node in graph.nodes: + node_id = node.name + op_name = node.op + if node.op == "call_function": + op_name = str(node.target) + elif node.op == "call_method": + op_name = node.target + + details = {"op": node.op, "target": str(getattr(node, "target", ""))} + + # Shape/dtype from meta + meta_val = node.meta.get("val") + if meta_val is not None: + if hasattr(meta_val, "shape"): + details["shape"] = str(list(meta_val.shape)) + details["dtype"] = str(meta_val.dtype) + elif isinstance(meta_val, (list, tuple)): + shapes = [] + for v in meta_val: + if hasattr(v, "shape"): + shapes.append(f"{list(v.shape)} {v.dtype}") + if shapes: + details["shapes"] = shapes + + # Stack trace + stack_trace = node.meta.get("stack_trace") + if stack_trace: + # Truncate long traces to last 500 chars + if len(stack_trace) > 500: + stack_trace = "..." + stack_trace[-500:] + details["stack_trace"] = stack_trace + + # Quantization params (post-fold stages) + input_qparams = node.meta.get("input_qparams") + if input_qparams: + serialized = _serialize_qparams(input_qparams) + if serialized: + details["input_qparams"] = serialized + + output_qparams = node.meta.get("output_qparams") + if output_qparams: + serialized = _serialize_qparams(output_qparams) + if serialized: + details["output_qparams"] = serialized + + category = categorize_node(op_name) + label = _make_label(op_name) + + # Meaningful labels for placeholders + if node.op == "placeholder": + target = str(getattr(node, "target", "")) + if target.startswith("p_") and target.endswith("_weight"): + label = "weight" + category = "param" + elif target.startswith("p_") and target.endswith("_bias"): + label = "bias" + category = "param" + elif target.startswith("b_") and "running_mean" in target: + label = "bn_mean" + category = "param" + elif target.startswith("b_") and "running_var" in target: + label = "bn_var" + category = "param" + elif target == "x" or not target.startswith(("p_", "b_")): + label = "input" + elif node.op == "output": + label = "output" + + node_data = { + "id": node_id, + "label": label, + "w": max(len(label) * 8 + 16, 60), + "category": category, + "op_name": op_name, + "details": details, + } + + if qdq_groups and node_id in qdq_groups: + node_data["qdq_group_id"] = qdq_groups[node_id] + + node_map[node_id] = len(nodes) + nodes.append(node_data) + + for arg in node.args: + if hasattr(arg, "name") and arg.name in node_map: + edges.append({"source": arg.name, "target": node_id}) + elif isinstance(arg, (list, tuple)): + for a in arg: + if hasattr(a, "name") and a.name in node_map: + edges.append({"source": a.name, "target": node_id}) + + category_counts = {} + for n in nodes: + cat = n["category"] + category_counts[cat] = category_counts.get(cat, 0) + 1 + + return { + "metadata": { + "model_name": stage_name, + "source_type": "trace", + "total_nodes": len(nodes), + "category_counts": category_counts, + "error": None, + }, + "nodes": nodes, + "edges": edges, + } + + +def _to_channels_last(x): + if isinstance(x, torch.Tensor): + return x.to(memory_format=torch.channels_last) if x.dim() == 4 else x + elif isinstance(x, tuple): + return tuple(_to_channels_last(t) for t in x) + return x + + +def get_model(model_name: str): + """Load a model by name, returning (module, example_inputs).""" + if model_name == "mobilenet_v2": + from torchvision.models import mobilenet_v2 + + model = mobilenet_v2(weights=None) + model.eval() + return model, (torch.randn(1, 3, 224, 224),) + elif model_name == "lstm": + from torch.nn.quantizable.modules import rnn + + model = rnn.LSTM(10, 20, 2) + model.eval() + example_inputs = ( + torch.randn(5, 3, 10), + (torch.randn(2, 3, 20), torch.randn(2, 3, 20)), + ) + return model, example_inputs + else: + raise ValueError(f"Unknown model: {model_name}") + + +def run_pipeline(model_name: str) -> dict: + """Run the full Cortex-M pipeline, capturing snapshots after each stage.""" + model, example_inputs = get_model(model_name) + snapshots = [] + + def _randn_like(x): + if isinstance(x, torch.Tensor): + return torch.randn_like(x) + elif isinstance(x, tuple): + return tuple(_randn_like(t) for t in x) + return x + + # --- Stage 1: Quantize --- + print("Quantizing...") + quantizer = CortexMQuantizer() + model = torch.export.export_for_training(model, example_inputs).module() + prepared = prepare_pt2e(model, quantizer) + # Calibrate with random data + with torch.no_grad(): + for _ in range(5): + prepared(*[_randn_like(t) for t in example_inputs]) + quantized = convert_pt2e(prepared) + + # --- Stage 2: Export --- + print("Exporting...") + ep = export(quantized, example_inputs, strict=True) + qdq_groups = detect_qdq_groups(ep.graph) + snapshots.append( + extract_from_exported_program(ep, "1_post_export", qdq_groups) + ) + print(f" 1_post_export: {snapshots[-1]['metadata']['total_nodes']} nodes") + + # --- Stage 3: to_edge --- + print("Converting to edge...") + edge_config = EdgeCompileConfig( + preserve_ops=[ + torch.ops.aten.linear.default, + torch.ops.aten.hardsigmoid.default, + torch.ops.aten.hardsigmoid_.default, + torch.ops.aten.hardswish.default, + torch.ops.aten.hardswish_.default, + ], + _check_ir_validity=False, + _core_aten_ops_exception_list=[torch.ops.aten.max_pool2d.default], + ) + edge_program = to_edge(ep, compile_config=edge_config) + edge_ep = edge_program.exported_program() + qdq_groups = detect_qdq_groups(edge_ep.graph) + snapshots.append( + extract_from_exported_program(edge_ep, "2_post_to_edge", qdq_groups) + ) + print(f" 2_post_to_edge: {snapshots[-1]['metadata']['total_nodes']} nodes") + + # --- Stages 4-10+: Individual passes --- + pass_list = CortexMPassManager.pass_list + ep = edge_ep + + for i, pass_cls in enumerate(pass_list): + pass_name = f"{i + 3}_{pass_cls.__name__}" + print(f"Running {pass_name}...") + + try: + signature = inspect.signature(pass_cls.__init__) + if "exported_program" in signature.parameters: + transform_pass = pass_cls(ep) + elif issubclass(pass_cls, ExportPass): + transform_pass = pass_cls() + else: + raise RuntimeError( + f"Unexpected pass type: {pass_cls} ({type(pass_cls)})" + ) + ep = _transform(ep, transform_pass) + + # Detect QDQ groups for pre-fold stages + qdq_groups_for_pass = None + if pass_cls != FoldAndAnnotateQParamsPass: + try: + qdq_groups_for_pass = detect_qdq_groups(ep.graph) + except Exception: + pass + + snapshot = extract_from_exported_program( + ep, pass_name, qdq_groups_for_pass + ) + snapshots.append(snapshot) + print(f" {pass_name}: {snapshot['metadata']['total_nodes']} nodes") + + except Exception as exc: + print(f" ERROR in {pass_name}: {exc}", file=sys.stderr) + error_snapshot = extract_from_exported_program( + ep, f"{pass_name}_ERROR" + ) + error_snapshot["metadata"]["error"] = { + "pass_name": pass_name, + "message": str(exc), + "traceback": traceback.format_exc(), + } + snapshots.append(error_snapshot) + break + + return {"model_name": model_name, "passes": snapshots} + + +# Import here so it's available for the isinstance check in the pass loop +from executorch.backends.arm._passes import FoldAndAnnotateQParamsPass # noqa: E402 + + +def main(): + parser = argparse.ArgumentParser( + description="Trace Cortex-M compilation passes and output JSON snapshots" + ) + parser.add_argument( + "--model", + default="mobilenet_v2", + help="Model name (default: mobilenet_v2)", + ) + parser.add_argument( + "-o", + "--output", + default=None, + help="Output JSON path (default: _trace.json)", + ) + args = parser.parse_args() + + output = args.output or f"{args.model}_trace.json" + result = run_pipeline(args.model) + + with open(output, "w") as f: + json.dump(result, f) + print( + f"Wrote {output} ({len(result['passes'])} passes, " + f"{os.path.getsize(output) / 1024:.0f} KB)" + ) + + +if __name__ == "__main__": + main() diff --git a/visualize_graph.py b/visualize_graph.py new file mode 100644 index 00000000000..b7b018c91ea --- /dev/null +++ b/visualize_graph.py @@ -0,0 +1,42 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +"""DEPRECATED: Moved to executorch.devtools.visualization.html_visualization. + +Update your imports: + from executorch.devtools.visualization.html_visualization import ( + visualize_edge_manager, + generate_html, + ) +""" + +import warnings + +warnings.warn( + "visualize_graph has moved to " + "executorch.devtools.visualization.html_visualization. " + "Update your imports.", + DeprecationWarning, + stacklevel=2, +) + +from executorch.devtools.visualization.html_visualization import ( # noqa: F401 + CATEGORY_COLORS, + categorize_node, + extract_from_exported_program, + extract_from_etrecord, + extract_from_pt2, + extract_from_pte, + extract_from_trace_json, + generate_html, + generate_multi_pass_html, + main, + visualize_edge_manager, +) + +if __name__ == "__main__": + main() From 191ef83bc509ceca978a348838b6389ae4731b5e Mon Sep 17 00:00:00 2001 From: Github Executorch Date: Tue, 10 Mar 2026 07:40:05 -0700 Subject: [PATCH 3/3] Add generic multi-backend pass tracer for graph visualization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add devtools/visualization/trace_passes.py — a generic version of trace_cortex_m_passes.py that works with any ExecuTorch backend. Traces quantization, export, to_edge, then each backend pass individually, producing a JSON file with per-pass graph snapshots for visualization with html_visualization.py. Supports 5 backends out of the box: - cortex_m: full pass-by-pass tracing (8 passes) - xnnpack: full pass-by-pass tracing (16 passes) - cadence: full pass-by-pass tracing - vulkan: export/edge stages (no static pass list) - qnn: export/edge stages (no static pass list) New backends can be added by calling register_backend() with a BackendConfig specifying the quantizer class, pass list source, and edge compile config. Usage: python -m executorch.devtools.visualization.trace_passes \ --backend xnnpack --model mobilenet_v2 -o trace.json python -m executorch.devtools.visualization.html_visualization \ trace.json -o trace.html Authored with Claude. --- devtools/visualization/TARGETS | 1 + devtools/visualization/trace_passes.py | 838 +++++++++++++++++++++++++ 2 files changed, 839 insertions(+) create mode 100644 devtools/visualization/trace_passes.py diff --git a/devtools/visualization/TARGETS b/devtools/visualization/TARGETS index 87c86f80551..c6b485f1461 100644 --- a/devtools/visualization/TARGETS +++ b/devtools/visualization/TARGETS @@ -10,6 +10,7 @@ runtime.python_library( "__init__.py", "visualization_utils.py", "html_visualization.py", + "trace_passes.py", ], visibility = ["PUBLIC"], deps = [ diff --git a/devtools/visualization/trace_passes.py b/devtools/visualization/trace_passes.py new file mode 100644 index 00000000000..1ee188cd4d2 --- /dev/null +++ b/devtools/visualization/trace_passes.py @@ -0,0 +1,838 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +"""Trace any ExecuTorch backend's compilation pipeline pass-by-pass. + +Runs quantization, export, to_edge, then each backend pass individually, +saving a JSON file with per-pass graph snapshots for visualization with +html_visualization.py. + +Usage: + # Cortex-M backend: + python -m executorch.devtools.visualization.trace_passes \\ + --backend cortex_m --model mobilenet_v2 -o trace.json + + # XNNPACK backend: + python -m executorch.devtools.visualization.trace_passes \\ + --backend xnnpack --model mobilenet_v2 -o trace.json + + # Cadence backend: + python -m executorch.devtools.visualization.trace_passes \\ + --backend cadence --model mobilenet_v2 -o trace.json + + # Skip quantization (trace passes on a float model): + python -m executorch.devtools.visualization.trace_passes \\ + --backend xnnpack --model mobilenet_v2 --no-quantize -o trace.json + + # Then visualize: + python -m executorch.devtools.visualization.html_visualization trace.json -o trace.html + +Authored with Claude. +""" + +import argparse +import inspect +import json +import os +import sys +import traceback +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +from executorch.exir import EdgeCompileConfig, to_edge +from executorch.exir.pass_base import ExportPass +from executorch.exir.program._program import _transform +from torch.export import export + + +# --------------------------------------------------------------------------- +# Node categorization (shared with html_visualization.py) +# --------------------------------------------------------------------------- + +CATEGORY_COLORS = { + "backend": "#4caf50", + "aten_compute": "#2196f3", + "quantize": "#ff9800", + "memory": "#9e9e9e", + "placeholder": "#03a9f4", + "param": "#78909c", + "delegate": "#ab47bc", +} + +_BACKEND_OP_PREFIXES = ( + "cortex_m", + "cadence", + "qaisw", +) + + +def categorize_node(op_name: str) -> str: + name = op_name.lower() + if any(prefix in name for prefix in _BACKEND_OP_PREFIXES): + return "backend" + if any( + k in name + for k in ( + "quantize_per_tensor", + "dequantize_per_", + "quantize_per_channel", + "dequantize_per_channel", + ) + ): + return "quantize" + if any( + k in name + for k in ( + "view", + "clone", + "permute", + "slice", + "copy", + "expand", + "reshape", + "t_copy", + "unsqueeze", + "squeeze", + ) + ): + return "memory" + if any(k in name for k in ("placeholder", "output", "getitem", "get_attr")): + return "placeholder" + if "delegate" in name: + return "delegate" + return "aten_compute" + + +def _make_label(op_name: str) -> str: + name = op_name.split("::")[-1] if "::" in op_name else op_name + if "." in name: + name = name.rsplit(".", 1)[0] + if len(name) > 30: + name = name[:27] + "..." + return name + + +# --------------------------------------------------------------------------- +# QDQ group detection +# --------------------------------------------------------------------------- + + +def _get_qdq_ops(): + """Get Q/DQ op sets. Works with or without Arm constants.""" + q_ops = set() + dq_ops = set() + try: + from executorch.backends.arm.constants import DQ_OPS, Q_OPS + + q_ops.update(Q_OPS) + dq_ops.update(DQ_OPS) + except ImportError: + pass + # Fallback: common quantize/dequantize ops + for op_name in [ + "quantize_per_tensor", + "quantize_per_channel", + "quantize_per_tensor_tensor", + ]: + op = getattr(torch.ops.quantized_decomposed, op_name, None) + if op is not None: + q_ops.add(op.default) + for op_name in [ + "dequantize_per_tensor", + "dequantize_per_channel", + "dequantize_per_tensor_tensor", + ]: + op = getattr(torch.ops.quantized_decomposed, op_name, None) + if op is not None: + dq_ops.add(op.default) + return q_ops, dq_ops + + +def detect_qdq_groups(graph) -> dict: + """Detect DQ -> compute_op -> Q chains and assign group IDs.""" + q_ops, dq_ops = _get_qdq_ops() + if not q_ops or not dq_ops: + return {} + + groups = {} + group_id = 0 + + for node in graph.nodes: + if node.op != "call_function" or node.target not in q_ops: + continue + if not node.args: + continue + compute_node = node.args[0] + if not hasattr(compute_node, "name") or not hasattr(compute_node, "target"): + continue + if compute_node.target in q_ops or compute_node.target in dq_ops: + continue + + dq_nodes = [] + for arg in compute_node.args: + if hasattr(arg, "target") and arg.target in dq_ops: + dq_nodes.append(arg) + elif isinstance(arg, (list, tuple)): + for a in arg: + if hasattr(a, "target") and a.target in dq_ops: + dq_nodes.append(a) + + if not dq_nodes: + continue + + members = [n.name for n in dq_nodes] + [compute_node.name, node.name] + for m in members: + if m not in groups: + groups[m] = group_id + group_id += 1 + + return groups + + +# --------------------------------------------------------------------------- +# Graph snapshot extraction +# --------------------------------------------------------------------------- + + +def _serialize_qparams(qparams_dict): + if not qparams_dict: + return None + result = {} + for idx, qa in qparams_dict.items(): + entry = {} + for attr in ("scale", "zp", "qmin", "qmax", "axis", "per_channel"): + val = getattr(qa, attr, None) + if val is not None: + entry[attr] = val + dtype = getattr(qa, "dtype", None) + if dtype is not None: + entry["dtype"] = str(dtype) + result[str(idx)] = entry + return result + + +def extract_from_exported_program(ep, stage_name, qdq_groups=None): + """Walk an ExportedProgram's graph and extract visualization data.""" + graph = ep.graph + nodes = [] + edges = [] + node_map = {} + + for node in graph.nodes: + node_id = node.name + op_name = node.op + if node.op == "call_function": + op_name = str(node.target) + elif node.op == "call_method": + op_name = node.target + + details = {"op": node.op, "target": str(getattr(node, "target", ""))} + + meta_val = node.meta.get("val") + if meta_val is not None: + if hasattr(meta_val, "shape"): + details["shape"] = str(list(meta_val.shape)) + details["dtype"] = str(meta_val.dtype) + elif isinstance(meta_val, (list, tuple)): + shapes = [] + for v in meta_val: + if hasattr(v, "shape"): + shapes.append(f"{list(v.shape)} {v.dtype}") + if shapes: + details["shapes"] = shapes + + stack_trace = node.meta.get("stack_trace") + if stack_trace: + if len(stack_trace) > 500: + stack_trace = "..." + stack_trace[-500:] + details["stack_trace"] = stack_trace + + input_qparams = node.meta.get("input_qparams") + if input_qparams: + serialized = _serialize_qparams(input_qparams) + if serialized: + details["input_qparams"] = serialized + + output_qparams = node.meta.get("output_qparams") + if output_qparams: + serialized = _serialize_qparams(output_qparams) + if serialized: + details["output_qparams"] = serialized + + category = categorize_node(op_name) + label = _make_label(op_name) + + if node.op == "placeholder": + target = str(getattr(node, "target", "")) + if target.startswith("p_") and target.endswith("_weight"): + label = "weight" + category = "param" + elif target.startswith("p_") and target.endswith("_bias"): + label = "bias" + category = "param" + elif target.startswith("b_") and "running_mean" in target: + label = "bn_mean" + category = "param" + elif target.startswith("b_") and "running_var" in target: + label = "bn_var" + category = "param" + elif target == "x" or not target.startswith(("p_", "b_")): + label = "input" + elif node.op == "output": + label = "output" + + node_data = { + "id": node_id, + "label": label, + "w": max(len(label) * 8 + 16, 60), + "category": category, + "op_name": op_name, + "details": details, + } + + if qdq_groups and node_id in qdq_groups: + node_data["qdq_group_id"] = qdq_groups[node_id] + + node_map[node_id] = len(nodes) + nodes.append(node_data) + + for arg in node.args: + if hasattr(arg, "name") and arg.name in node_map: + edges.append({"source": arg.name, "target": node_id}) + elif isinstance(arg, (list, tuple)): + for a in arg: + if hasattr(a, "name") and a.name in node_map: + edges.append({"source": a.name, "target": node_id}) + + category_counts = {} + for n in nodes: + cat = n["category"] + category_counts[cat] = category_counts.get(cat, 0) + 1 + + return { + "metadata": { + "model_name": stage_name, + "source_type": "trace", + "total_nodes": len(nodes), + "category_counts": category_counts, + "error": None, + }, + "nodes": nodes, + "edges": edges, + } + + +# --------------------------------------------------------------------------- +# Backend registry +# --------------------------------------------------------------------------- + + +class BackendConfig: + """Configuration for tracing a backend's compilation pipeline.""" + + def __init__( + self, + name: str, + quantizer_cls: Optional[str] = None, + pass_list_source: Optional[str] = None, + edge_compile_config: Optional[EdgeCompileConfig] = None, + # Some passes should skip QDQ detection after them + skip_qdq_after: Optional[List[str]] = None, + ): + self.name = name + self._quantizer_cls_path = quantizer_cls + self._pass_list_source = pass_list_source + self.edge_compile_config = edge_compile_config or EdgeCompileConfig( + _check_ir_validity=False, + ) + self.skip_qdq_after = set(skip_qdq_after or []) + + def get_quantizer(self): + if not self._quantizer_cls_path: + return None + module_path, cls_name = self._quantizer_cls_path.rsplit(".", 1) + import importlib + + mod = importlib.import_module(module_path) + cls = getattr(mod, cls_name) + return cls() + + def get_pass_list(self) -> List[Type]: + """Return a list of pass classes to iterate through.""" + if not self._pass_list_source: + return [] + + import importlib + + if ":" in self._pass_list_source: + # "module.path:function_or_attribute" format + module_path, attr = self._pass_list_source.split(":", 1) + mod = importlib.import_module(module_path) + obj = getattr(mod, attr) + if callable(obj) and not isinstance(obj, type): + return obj() + return list(obj) + else: + # "module.path.Class.pass_list" — dotted attribute access + parts = self._pass_list_source.rsplit(".", 1) + module_path, attr = parts + # Try importing as module.Class.attr + try: + mod = importlib.import_module(module_path) + return list(getattr(mod, attr)) + except (ImportError, AttributeError): + # Try one level up: module.Class has attr + mod_path, cls_name = module_path.rsplit(".", 1) + mod = importlib.import_module(mod_path) + cls = getattr(mod, cls_name) + return list(getattr(cls, attr)) + + +# Backend configurations — add new backends here +_BACKEND_REGISTRY: Dict[str, BackendConfig] = {} + + +def register_backend(config: BackendConfig): + _BACKEND_REGISTRY[config.name] = config + + +def _register_builtin_backends(): + register_backend( + BackendConfig( + name="cortex_m", + quantizer_cls=( + "executorch.backends.cortex_m.quantizer.quantizer" + ".CortexMQuantizer" + ), + pass_list_source=( + "executorch.backends.cortex_m.passes" + ".cortex_m_pass_manager.CortexMPassManager.pass_list" + ), + edge_compile_config=EdgeCompileConfig( + preserve_ops=[ + torch.ops.aten.linear.default, + torch.ops.aten.hardsigmoid.default, + torch.ops.aten.hardsigmoid_.default, + torch.ops.aten.hardswish.default, + torch.ops.aten.hardswish_.default, + ], + _check_ir_validity=False, + _core_aten_ops_exception_list=[ + torch.ops.aten.max_pool2d.default, + ], + ), + skip_qdq_after=["FoldAndAnnotateQParamsPass"], + ) + ) + + register_backend( + BackendConfig( + name="xnnpack", + quantizer_cls=( + "executorch.backends.xnnpack.quantizer" + ".xnnpack_quantizer.XNNPACKQuantizer" + ), + pass_list_source=( + "executorch.devtools.visualization.trace_passes" + ":_default_pass_list" + ), + ) + ) + + register_backend( + BackendConfig( + name="cadence", + quantizer_cls=( + "executorch.backends.cadence.aot.quantizer" + ".quantizer.CadenceDefaultQuantizer" + ), + pass_list_source=( + "executorch.backends.cadence.aot.passes" + ":get_passes_in_default_order" + ), + ) + ) + + register_backend( + BackendConfig( + name="vulkan", + quantizer_cls=( + "executorch.backends.vulkan.quantizer" + ".vulkan_quantizer.VulkanQuantizer" + ), + # Vulkan has no static pass list — passes are inline in preprocess() + pass_list_source=None, + ) + ) + + register_backend( + BackendConfig( + name="qnn", + quantizer_cls=( + "executorch.backends.qualcomm.quantizer" + ".quantizer.QnnQuantizer" + ), + # QNN uses dynamic pipeline methods, no static pass list + pass_list_source=None, + ) + ) + + +# --------------------------------------------------------------------------- +# Pass instantiation +# --------------------------------------------------------------------------- + + +def _instantiate_pass(pass_cls, exported_program): + """Instantiate a pass class, passing exported_program if needed.""" + sig = inspect.signature(pass_cls.__init__) + params = list(sig.parameters.keys()) + # Skip 'self' + if "exported_program" in params: + return pass_cls(exported_program) + elif len(params) > 1 and params[1] == "exported_program": + return pass_cls(exported_program) + else: + try: + return pass_cls() + except TypeError: + # Some passes require arguments we can't infer + return None + + +# --------------------------------------------------------------------------- +# Models +# --------------------------------------------------------------------------- + + +def get_model(model_name: str) -> Tuple[torch.nn.Module, tuple]: + """Load a model by name, returning (module, example_inputs).""" + if model_name == "mobilenet_v2": + from torchvision.models import mobilenet_v2 + + model = mobilenet_v2(weights=None) + model.eval() + return model, (torch.randn(1, 3, 224, 224),) + elif model_name == "mobilenet_v3_small": + from torchvision.models import mobilenet_v3_small + + model = mobilenet_v3_small(weights=None) + model.eval() + return model, (torch.randn(1, 3, 224, 224),) + elif model_name == "resnet18": + from torchvision.models import resnet18 + + model = resnet18(weights=None) + model.eval() + return model, (torch.randn(1, 3, 224, 224),) + elif model_name == "resnet50": + from torchvision.models import resnet50 + + model = resnet50(weights=None) + model.eval() + return model, (torch.randn(1, 3, 224, 224),) + elif model_name == "lstm": + from torch.nn.quantizable.modules import rnn + + model = rnn.LSTM(10, 20, 2) + model.eval() + return model, ( + torch.randn(5, 3, 10), + (torch.randn(2, 3, 20), torch.randn(2, 3, 20)), + ) + else: + raise ValueError( + f"Unknown model: {model_name}. " + f"Available: mobilenet_v2, mobilenet_v3_small, resnet18, resnet50, lstm" + ) + + +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- + + +def _randn_like(x): + if isinstance(x, torch.Tensor): + return torch.randn_like(x) + elif isinstance(x, tuple): + return tuple(_randn_like(t) for t in x) + return x + + +def run_pipeline( + backend_config: BackendConfig, + model_name: str, + quantize: bool = True, +) -> dict: + """Run a backend's compilation pipeline, capturing snapshots.""" + model, example_inputs = get_model(model_name) + snapshots = [] + stage_num = 0 + + # --- Stage: Quantize --- + if quantize: + quantizer = backend_config.get_quantizer() + if quantizer is None: + print( + f"Warning: no quantizer configured for {backend_config.name}, " + f"skipping quantization" + ) + else: + stage_num += 1 + print("Quantizing...") + from torchao.quantization.pt2e.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + ) + + model = torch.export.export_for_training( + model, example_inputs + ).module() + prepared = prepare_pt2e(model, quantizer) + with torch.no_grad(): + for _ in range(5): + prepared(*[_randn_like(t) for t in example_inputs]) + model = convert_pt2e(prepared) + + # --- Stage: Export --- + stage_num += 1 + print("Exporting...") + ep = export(model, example_inputs, strict=True) + qdq_groups = detect_qdq_groups(ep.graph) + snapshots.append( + extract_from_exported_program( + ep, f"{stage_num}_post_export", qdq_groups + ) + ) + print(f" {stage_num}_post_export: {snapshots[-1]['metadata']['total_nodes']} nodes") + + # --- Stage: to_edge --- + stage_num += 1 + print("Converting to edge...") + edge_program = to_edge(ep, compile_config=backend_config.edge_compile_config) + edge_ep = edge_program.exported_program() + qdq_groups = detect_qdq_groups(edge_ep.graph) + snapshots.append( + extract_from_exported_program( + edge_ep, f"{stage_num}_post_to_edge", qdq_groups + ) + ) + print( + f" {stage_num}_post_to_edge: " + f"{snapshots[-1]['metadata']['total_nodes']} nodes" + ) + + # --- Stages: Individual passes --- + pass_list = backend_config.get_pass_list() + if not pass_list: + print( + f" No static pass list for {backend_config.name} — " + f"showing export/edge stages only" + ) + else: + ep = edge_ep + for i, pass_cls in enumerate(pass_list): + stage_num += 1 + pass_name = f"{stage_num}_{pass_cls.__name__}" + print(f"Running {pass_name}...") + + try: + transform_pass = _instantiate_pass(pass_cls, ep) + if transform_pass is None: + print(f" Skipping {pass_name} (cannot instantiate)") + stage_num -= 1 + continue + + ep = _transform(ep, transform_pass) + + qdq_groups_for_pass = None + if pass_cls.__name__ not in backend_config.skip_qdq_after: + try: + qdq_groups_for_pass = detect_qdq_groups(ep.graph) + except Exception: + pass + + snapshot = extract_from_exported_program( + ep, pass_name, qdq_groups_for_pass + ) + snapshots.append(snapshot) + print( + f" {pass_name}: " + f"{snapshot['metadata']['total_nodes']} nodes" + ) + + except Exception as exc: + print(f" ERROR in {pass_name}: {exc}", file=sys.stderr) + error_snapshot = extract_from_exported_program( + ep, f"{pass_name}_ERROR" + ) + error_snapshot["metadata"]["error"] = { + "pass_name": pass_name, + "message": str(exc), + "traceback": traceback.format_exc(), + } + snapshots.append(error_snapshot) + break + + return { + "model_name": f"{model_name} ({backend_config.name})", + "passes": snapshots, + } + + +# --------------------------------------------------------------------------- +# XNNPACK helper — extract default pass list from XNNPACKPassManager +# --------------------------------------------------------------------------- + + +def _default_pass_list(): + """Extract the default pass list from XNNPACKPassManager. + + XNNPACKPassManager stores passes as an instance attribute, so we need + to peek at the default set in __init__. + """ + try: + from executorch.backends.xnnpack._passes import XNNPACKPassManager + + # XNNPACKPassManager.__init__ takes exported_program, but we just + # need the default pass class list. Use the source to get it. + src = inspect.getsource(XNNPACKPassManager.__init__) + # The pass list is assigned as self.passes = [...] in __init__ + # Rather than parsing source, just grab from a dummy instance + # Actually, we can't instantiate without an exported_program. + # Fall back to the known default list. + from executorch.backends.xnnpack._passes import ( + ChannelsLastTaggedReshapePass, + ConstPropPass, + Conv1dUnsqueezePass, + ConvertToLinearPass, + ConvertToSDPAPass, + ConvertToUpsampleBilinear2d, + DecomposeBatchNorm, + DecomposeConcatenate, + DimOrderOpsRevertPass, + FuseActivationPass, + FuseBatchNormPass, + PReLUReshapePass, + PropagateCustomMetaPass, + RemoveGetItemPass, + RemoveRedundantCopyPass, + XNNPACKRemoveCloneOpsTransform, + ) + + return [ + XNNPACKRemoveCloneOpsTransform, + DimOrderOpsRevertPass, + ConvertToUpsampleBilinear2d, + ConvertToLinearPass, + PropagateCustomMetaPass, + ConvertToSDPAPass, + ConstPropPass, + FuseBatchNormPass, + DecomposeBatchNorm, + FuseActivationPass, + DecomposeConcatenate, + RemoveGetItemPass, + Conv1dUnsqueezePass, + PReLUReshapePass, + ChannelsLastTaggedReshapePass, + RemoveRedundantCopyPass, + ] + except ImportError as e: + print(f"Warning: Could not import XNNPACK passes: {e}", file=sys.stderr) + return [] + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def main(): + _register_builtin_backends() + + parser = argparse.ArgumentParser( + description="Trace ExecuTorch backend compilation passes", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=( + "Available backends: " + + ", ".join(sorted(_BACKEND_REGISTRY.keys())) + + "\n\nAvailable models: mobilenet_v2, mobilenet_v3_small, " + "resnet18, resnet50, lstm" + ), + ) + parser.add_argument( + "--backend", + default=None, + help="Backend name (e.g., cortex_m, xnnpack, cadence)", + ) + parser.add_argument( + "--model", + default="mobilenet_v2", + help="Model name (default: mobilenet_v2)", + ) + parser.add_argument( + "-o", + "--output", + default=None, + help="Output JSON path (default: __trace.json)", + ) + parser.add_argument( + "--no-quantize", + action="store_true", + help="Skip quantization (trace on float model)", + ) + parser.add_argument( + "--list-backends", + action="store_true", + help="List available backends and exit", + ) + args = parser.parse_args() + + if args.list_backends: + print("Available backends:") + for name, cfg in sorted(_BACKEND_REGISTRY.items()): + has_passes = "with passes" if cfg._pass_list_source else "export/edge only" + has_quant = "quantized" if cfg._quantizer_cls_path else "float" + print(f" {name:15s} ({has_quant}, {has_passes})") + return + + if not args.backend: + parser.error("--backend is required (use --list-backends to see options)") + + if args.backend not in _BACKEND_REGISTRY: + print( + f"Error: unknown backend '{args.backend}'. " + f"Available: {', '.join(sorted(_BACKEND_REGISTRY.keys()))}", + file=sys.stderr, + ) + sys.exit(1) + + backend_config = _BACKEND_REGISTRY[args.backend] + output = args.output or f"{args.model}_{args.backend}_trace.json" + + result = run_pipeline( + backend_config, + args.model, + quantize=not args.no_quantize, + ) + + with open(output, "w") as f: + json.dump(result, f) + print( + f"\nWrote {output} ({len(result['passes'])} passes, " + f"{os.path.getsize(output) / 1024:.0f} KB)" + ) + print( + f"Visualize: python -m executorch.devtools.visualization" + f".html_visualization {output} -o {output.replace('.json', '.html')}" + ) + + +if __name__ == "__main__": + main()