diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index e9fbc4778f5..969906a7c10 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -1649,3 +1649,129 @@ def get_stacktraces_for_row(aot_ops: List[str]) -> Dict[str, Optional[str]]: df["stacktraces"] = df["aot_ops"].apply(get_stacktraces_for_row) return df + + def calculate_numeric_gap_from_taps( + self, + flat_runtime_outputs: Sequence, + tap_specs: Sequence, + distance: Union[str, NumericalComparatorBase] = "MSE", + reference_graph: Optional[str] = None, + disable_debug_handle_valdiation: bool = False, + ) -> pd.DataFrame: + """ + Compares AOT intermediate outputs (from the ETRecord, captured by + IntermediateOutputCapturer) with runtime tap values exposed as + USER_OUTPUTs by the `intermediate_output_tap` package. + + Unlike `calculate_numeric_gap`, this method works through delegates + with no backend-side support: the runtime values come from extra + outputs the AOT pass added to the ExportedProgram before lowering. + + IMPORTANT: ETRecord serialization regenerates `debug_handle`s during + roundtrip, so the handles in `tap_specs` (set at AOT-pass time) are + stale. Each spec's `reducer_node_name` (set by + `strip_taps_(edge, tap_specs=specs)`) is used to look up the + post-roundtrip `debug_handle` in the AOT reference graph. + + Args: + flat_runtime_outputs: The flat output tuple/list returned by + running the lowered program (e.g. `Method.execute(inputs)`). + tap_specs: The list of TapSpec returned by + `strip_taps_(edge, tap_specs=specs)` — these carry + `reducer_node_name` for alignment. + distance: "MSE", "L1", "SNR", or a `NumericalComparatorBase`. + reference_graph: AOT graph to use as the golden. See + `calculate_numeric_gap` for valid values. + disable_debug_handle_valdiation: Bypass debug handle validation. + + Returns: + DataFrame with one row per (aot_handle, runtime_handle) pair. + Same shape produced by `calculate_numeric_gap`. 
+ """ + reference_graph_module, _resolved_graph_name = self._resolve_reference_graph( + reference_graph, + disable_debug_handle_valdiation, + ) + aot_intermediate_outputs, aot_debug_handle_to_op_names = ( + self._get_aot_intermediate_outputs_and_op_names(reference_graph_module) + ) + if len(aot_intermediate_outputs) == 0: + raise ValueError( + "No AOT intermediate outputs were captured. ETRecord must be " + "provided with representative_inputs for tap-based comparison." + ) + + spec_handles = _lookup_handles_by_name(reference_graph_module, tap_specs) + + runtime_intermediate_outputs: Dict[DebugHandle, Tuple[Any, int]] = {} + runtime_debug_handle_to_op_names: Dict[DebugHandle, List[str]] = {} + for spec, dh in zip(tap_specs, spec_handles): + if dh is None: + continue + key: DebugHandle = (int(dh),) + runtime_intermediate_outputs[key] = ( + flat_runtime_outputs[spec.output_index], + 1, + ) + runtime_debug_handle_to_op_names[key] = [spec.op_target] + + if not runtime_intermediate_outputs: + raise ValueError( + "Could not recover any post-roundtrip handles for tap_specs. " + "Verify that strip_taps_(edge, tap_specs=specs) was called and " + "the returned specs were passed to this method, and that " + "generate_etrecord ran AFTER strip_taps_." 
+ ) + + mapping = map_runtime_aot_intermediate_outputs( + aot_intermediate_outputs, runtime_intermediate_outputs + ) + + if isinstance(distance, NumericalComparatorBase): + comparator = distance + if comparator.inspector is None: + comparator.inspector = self + else: + metric = distance.strip().upper() + if metric == "MSE": + comparator = MSEComparator(inspector=self) + elif metric == "L1": + comparator = L1Comparator(inspector=self) + elif metric == "SNR": + comparator = SNRComparator(inspector=self) + else: + raise ValueError(f"Unsupported distance metric {distance!r}") + + return comparator.compare( + mapping, + aot_debug_handle_to_op_names, + runtime_debug_handle_to_op_names, + ) + + +def _lookup_handles_by_name( + reference_graph_module, + tap_specs: Sequence, +) -> List[Optional[int]]: + """ + For each TapSpec, return the post-roundtrip `debug_handle` of the FX node + whose name equals `spec.reducer_node_name`. Returns None for any spec + without a `reducer_node_name` set or whose name is not found. + + `reducer_node_name` is set by `strip_taps_(edge, tap_specs=specs)`. FX + node names survive ETRecord serialization roundtrip, so this lookup is + stable. 
+ """ + name_to_handle: Dict[str, Optional[int]] = {} + for n in reference_graph_module.graph.nodes: + h = n.meta.get("debug_handle") + name_to_handle[n.name] = int(h) if isinstance(h, int) else None + + out: List[Optional[int]] = [] + for spec in tap_specs: + rn = getattr(spec, "reducer_node_name", None) + if rn is None: + out.append(None) + continue + out.append(name_to_handle.get(rn)) + return out diff --git a/devtools/intermediate_output_tap/TARGETS b/devtools/intermediate_output_tap/TARGETS new file mode 100644 index 00000000000..5ae2ddc8380 --- /dev/null +++ b/devtools/intermediate_output_tap/TARGETS @@ -0,0 +1,85 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +runtime.python_library( + name = "spec", + srcs = ["_spec.py"], +) + +runtime.python_library( + name = "custom_ops_lib", + srcs = ["custom_ops_lib.py"], + deps = [ + "//caffe2:torch", + ], +) + +runtime.python_library( + name = "selectors", + srcs = ["_selectors.py"], + deps = [ + "//caffe2:torch", + ], +) + +runtime.python_library( + name = "reducers", + srcs = ["_reducers.py"], + deps = [ + "//caffe2:torch", + "//executorch/exir/dialects:lib", + ], +) + +runtime.python_library( + name = "tap_pass", + srcs = ["_tap_pass.py"], + deps = [ + "//caffe2:torch", + ":custom_ops_lib", + ":reducers", + ":selectors", + ":spec", + ], +) + +runtime.python_library( + name = "strip_pass", + srcs = ["_strip_pass.py"], + deps = [ + "//caffe2:torch", + ":reducers", + ":tap_pass", + ], +) + +runtime.python_library( + name = "convenience", + srcs = ["_convenience.py"], + deps = [ + "fbsource//third-party/pypi/pandas:pandas", + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/runtime:runtime", + ":reducers", + ":selectors", + ":spec", + ":strip_pass", + ":tap_pass", + ], +) + +runtime.python_library( + name = "lib", + srcs = ["__init__.py"], + deps = [ + ":convenience", + ":custom_ops_lib", + ":reducers", + ":selectors", + ":spec", + ":strip_pass", + 
":tap_pass", + ], +) diff --git a/devtools/intermediate_output_tap/__init__.py b/devtools/intermediate_output_tap/__init__.py new file mode 100644 index 00000000000..03221ad4c76 --- /dev/null +++ b/devtools/intermediate_output_tap/__init__.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Public API for the ExecuTorch numerical debugger. + +Backend-agnostic intermediate-value tap that complements the existing +Inspector framework: + +- AOT side : `IntermediateOutputCapturer` (existing) +- Runtime side : ETDump intermediate output events (existing, opaque inside delegates) +- Runtime side : USER_OUTPUT taps (this module — works through delegates without + any backend-side changes) + +Typical usage: + + from executorch.devtools.intermediate_output_tap import ( + tap_intermediate_outputs, strip_taps_, DEFAULT_STATS, + ) + + ep = export(model, example_inputs) + ep_tapped, specs = tap_intermediate_outputs(ep, reducer=DEFAULT_STATS) + edge = to_edge_transform_and_lower(ep_tapped, partitioner=[XnnpackPartitioner()]) + strip_taps_(edge) + et_program = edge.to_executorch() + + flat_outputs = runtime.forward(*example_inputs) + df = inspector.calculate_numeric_gap_from_taps(flat_outputs, specs) +""" + +from executorch.devtools.intermediate_output_tap import ( + custom_ops_lib, # noqa: F401 ensures torch.ops.executorch_devtools.tap is registered +) +from executorch.devtools.intermediate_output_tap._convenience import ( + compare_aot_runtime_dataframe, + format_tap_dataframe, + specs_to_dataframe, + tap_all_and_run, +) +from executorch.devtools.intermediate_output_tap._reducers import ( + ABS_MAX_ONLY, + DEFAULT_STATS, + FULL_TENSOR, + get_reducer, + MIN_MAX_MEAN, + StatReducer, +) +from executorch.devtools.intermediate_output_tap._selectors import ( + NodeSelector, + 
select_all, + select_all_call_function, + select_any, + select_by_meta_tag, + select_by_module_path, + select_by_op_type, + select_not, +) +from executorch.devtools.intermediate_output_tap._spec import TapSpec +from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_ +from executorch.devtools.intermediate_output_tap._tap_pass import ( + find_tap_nodes, + is_tap_node, + tap_intermediate_outputs, +) + + +__all__ = [ + # Core API + "tap_intermediate_outputs", + "strip_taps_", + "TapSpec", + # Convenience + "tap_all_and_run", + "specs_to_dataframe", + "format_tap_dataframe", + "compare_aot_runtime_dataframe", + # Reducers + "StatReducer", + "FULL_TENSOR", + "ABS_MAX_ONLY", + "MIN_MAX_MEAN", + "DEFAULT_STATS", + "get_reducer", + # Selectors + "NodeSelector", + "select_all_call_function", + "select_by_op_type", + "select_by_module_path", + "select_by_meta_tag", + "select_any", + "select_all", + "select_not", + # Helpers + "find_tap_nodes", + "is_tap_node", +] diff --git a/devtools/intermediate_output_tap/_convenience.py b/devtools/intermediate_output_tap/_convenience.py new file mode 100644 index 00000000000..a2a884b9ec9 --- /dev/null +++ b/devtools/intermediate_output_tap/_convenience.py @@ -0,0 +1,299 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +One-line convenience wrapper for the most common smoke-test workflow: + + df = tap_all_and_run(model, example_inputs, partitioner=[XnnpackPartitioner()]) + +Exports `model`, taps every call_function, lowers with the user's partitioner, +runs through the ExecuTorch runtime, and returns a pandas DataFrame of one row +per tap (one column per stat field). No Inspector setup, no ETRecord. 
For +AOT-vs-runtime numerical comparison, use Inspector.calculate_numeric_gap_from_taps, +then `format_tap_dataframe(df, tap_specs)` to get a friendly view. +""" + +from __future__ import annotations + +import os +import tempfile +from collections.abc import Sequence +from typing import Any + +import pandas as pd +import torch +from executorch.devtools.intermediate_output_tap._reducers import ( + get_reducer, + StatReducer, +) +from executorch.devtools.intermediate_output_tap._selectors import ( + NodeSelector, + select_all_call_function, +) +from executorch.devtools.intermediate_output_tap._spec import TapSpec +from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_ +from executorch.devtools.intermediate_output_tap._tap_pass import ( + tap_intermediate_outputs, +) + + +def tap_all_and_run( + model: torch.nn.Module, + example_inputs: tuple[Any, ...], + partitioner: list | None = None, + reducer: str | StatReducer = "DEFAULT_STATS", + selector: NodeSelector | None = None, + skip_if_no_debug_handle: bool = True, +) -> pd.DataFrame: + """ + Export -> tap -> lower -> strip -> to_executorch -> run -> DataFrame. + + Returns a DataFrame indexed by tap with columns: + node_name, op_target, debug_handle, output_index, reducer_name, plus + one column per reducer field (or `value` for FULL_TENSOR). 
+ """ + from executorch.exir import to_edge_transform_and_lower + + selector = selector or select_all_call_function() + ep = torch.export.export(model, example_inputs, strict=True) + ep_tapped, specs = tap_intermediate_outputs( + ep, + selector=selector, + reducer=reducer, + skip_if_no_debug_handle=skip_if_no_debug_handle, + ) + edge = to_edge_transform_and_lower( + ep_tapped, partitioner=partitioner or [] + ) + strip_taps_(edge) + et_program = edge.to_executorch() + + flat_outputs = _run_pte(et_program, example_inputs) + return specs_to_dataframe(specs, flat_outputs) + + +def _run_pte(et_program, example_inputs: tuple[Any, ...]) -> Sequence[Any]: + from executorch.runtime import Runtime, Verification + + with tempfile.TemporaryDirectory() as temp_dir: + pte_path = os.path.join(temp_dir, "model.pte") + et_program.save(pte_path) + rt = Runtime.get() + program = rt.load_program(pte_path, verification=Verification.Minimal) + method = program.load_method("forward") + return method.execute(example_inputs) + + +def specs_to_dataframe( + specs: Sequence[TapSpec], + flat_outputs: Sequence[Any], +) -> pd.DataFrame: + """Build a per-tap DataFrame from the tap_specs + flat output tuple.""" + rows = [] + for spec in specs: + runtime_value = flat_outputs[spec.output_index] + row: dict[str, Any] = { + "node_name": spec.node_name, + "op_target": spec.op_target, + "debug_handle": spec.debug_handle, + "output_index": spec.output_index, + "reducer_name": spec.reducer_name, + } + if spec.fields: + tensor_vals = ( + runtime_value.detach().cpu().tolist() + if isinstance(runtime_value, torch.Tensor) + else list(runtime_value) + ) + for i, field in enumerate(spec.fields): + row[field] = tensor_vals[i] if i < len(tensor_vals) else None + else: + row["value"] = runtime_value + rows.append(row) + return pd.DataFrame(rows) + + +def format_tap_dataframe( + df: pd.DataFrame, + tap_specs: Sequence[TapSpec], +) -> pd.DataFrame: + """ + Reshape the raw DataFrame returned by + 
`Inspector.calculate_numeric_gap_from_taps` into a friendlier per-tap, + per-field view. + + The raw DataFrame uses the existing Inspector comparator format, which + packs the reducer's stat tensor into a list of 0-d tensors and labels + rows by the post-strip reducer node name (e.g. `aten_stack_default`). + This helper: + - matches each raw row to a TapSpec (by reducer_node_name) + - renames `aot_ops`/`runtime_ops` columns to a single `node_name` (the + original source node name, e.g. `linear`, `linear_1`) + - expands the reducer stat tensor into one column per field + (e.g. `aot_min`, `rt_min`, `aot_max`, `rt_max`, ...) + - flattens the gap to a single float + - drops the verbose `aot_intermediate_output` / `runtime_intermediate_output` + list columns + + Returns a DataFrame with columns: + node_name, op_target, reducer_name, + gap, + aot_, rt_, aot_, rt_, ... + """ + # Map reducer_node_name -> spec for quick lookup. + name_to_spec: dict[str, TapSpec] = { + s.reducer_node_name: s + for s in tap_specs + if s.reducer_node_name is not None + } + + rows = [] + for _, row in df.iterrows(): + aot_ops = row.get("aot_ops", []) + spec = None + for op in aot_ops or []: + if op in name_to_spec: + spec = name_to_spec[op] + break + if spec is None: + # Couldn't match — keep a thin row with whatever we have. 
+ rows.append( + { + "node_name": ",".join(aot_ops or []), + "op_target": "?", + "reducer_name": "?", + "gap": _flatten_gap(row.get("gap")), + } + ) + continue + + new_row: dict[str, Any] = { + "node_name": spec.node_name, + "op_target": spec.op_target, + "reducer_name": spec.reducer_name, + "gap": _flatten_gap(row.get("gap")), + } + aot_vals = _to_float_list(row.get("aot_intermediate_output")) + rt_vals = _to_float_list(row.get("runtime_intermediate_output")) + for i, field in enumerate(spec.fields): + new_row[f"aot_{field}"] = aot_vals[i] if i < len(aot_vals) else None + new_row[f"rt_{field}"] = rt_vals[i] if i < len(rt_vals) else None + if not spec.fields: # FULL_TENSOR + new_row["aot_value"] = row.get("aot_intermediate_output") + new_row["rt_value"] = row.get("runtime_intermediate_output") + rows.append(new_row) + return pd.DataFrame(rows) + + +def _flatten_gap(g: Any) -> float | None: + if g is None: + return None + if isinstance(g, list): + if not g: + return None + g = g[0] + if isinstance(g, torch.Tensor): + return float(g) + try: + return float(g) + except (TypeError, ValueError): + return None + + +def _to_float_list(v: Any) -> list[float]: + if v is None: + return [] + if isinstance(v, torch.Tensor): + return v.detach().cpu().tolist() + if isinstance(v, list): + out: list[float] = [] + for x in v: + if isinstance(x, torch.Tensor): + out.append(float(x)) + else: + try: + out.append(float(x)) + except (TypeError, ValueError): + out.append(float("nan")) + return out + return [] + + +def _flat_floats(v: Any) -> list[float]: + """Flatten a tap value (tensor / list / scalar) to a flat list of floats.""" + if isinstance(v, torch.Tensor): + return [float(x) for x in v.detach().to(torch.float32).cpu().reshape(-1).tolist()] + if isinstance(v, (list, tuple)): + out: list[float] = [] + for x in v: + out.extend(_flat_floats(x)) + return out + try: + return [float(v)] + except (TypeError, ValueError): + return [] + + +def compare_aot_runtime_dataframe( + specs: 
Sequence[TapSpec], + aot_flat: Sequence[Any], + rt_flat: Sequence[Any], +) -> pd.DataFrame: + """ + Build a side-by-side AOT-vs-runtime DataFrame from the flat outputs of + the *tapped* ExportedProgram (eager) and the post-strip runtime program. + + AOT side: + `aot_flat[spec.output_index]` is the **raw** tapped tensor — at eager + time `tap.Tensor` is identity, so the output is the source op's + output. We apply the reducer's `eager` callable to reproduce what + `strip_taps_` materialises in the runtime graph. + + Runtime side: + `rt_flat[spec.output_index]` already contains the reduced 1-D tensor + (or original tensor for FULL_TENSOR). + + Returns one row per spec with columns: + node_name, op_target, reducer_name, output_index, + aot_, rt_, aot_, rt_, ... + """ + rows: list[dict[str, Any]] = [] + for spec in specs: + aot_raw = aot_flat[spec.output_index] + rt_raw = rt_flat[spec.output_index] + + # AOT raw might be wrapped in a 1-tuple; unwrap first tensor. + if not isinstance(aot_raw, torch.Tensor) and isinstance( + aot_raw, (list, tuple) + ) and aot_raw: + aot_raw = aot_raw[0] + + reducer = get_reducer(spec.reducer_name) + if isinstance(aot_raw, torch.Tensor): + aot_reduced = reducer.eager(aot_raw) + else: + aot_reduced = aot_raw + + aot_vals = _flat_floats(aot_reduced) + rt_vals = _flat_floats(rt_raw) + + fields = list(spec.fields) if spec.fields else [ + f"v{i}" for i in range(max(len(aot_vals), len(rt_vals))) + ] + row: dict[str, Any] = { + "node_name": spec.node_name, + "module_path": spec.module_path, + "op_target": spec.op_target, + "reducer_name": spec.reducer_name, + "output_index": spec.output_index, + } + for i, f in enumerate(fields): + row[f"aot_{f}"] = aot_vals[i] if i < len(aot_vals) else float("nan") + row[f"rt_{f}"] = rt_vals[i] if i < len(rt_vals) else float("nan") + rows.append(row) + return pd.DataFrame(rows) diff --git a/devtools/intermediate_output_tap/_reducers.py b/devtools/intermediate_output_tap/_reducers.py new file mode 100644 index 
00000000000..b2d19c00c66 --- /dev/null +++ b/devtools/intermediate_output_tap/_reducers.py @@ -0,0 +1,211 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Stat reducers used by `tap_intermediate_outputs`. + +A `StatReducer` is a small specification consumed by `strip_taps_` (after +`to_backend`) to materialise a portable reducer subgraph in place of the +`executorch_devtools::tap.Tensor` placeholder. + +`emit(graph, src_node) -> fx.Node` builds the reducer subgraph in `graph` +just before the placeholder, using the source tensor `src_node` as input, +and returns the final node whose output replaces the placeholder's output. + +The emit functions cast to fp32 first for cross-backend numerical stability +and use full-tensor reductions (no `dim=`) so the result is a stable shape +regardless of the source tensor's rank. + +For v1 we ship: FULL_TENSOR, ABS_MAX_ONLY, MIN_MAX_MEAN, DEFAULT_STATS. +HISTOGRAM_64 is deferred (`aten.histc` has restricted edge support). +""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch +from executorch.exir.dialects._ops import ops as exir_ops + + +if TYPE_CHECKING: + import torch.fx as fx + + +# --- Reducer dataclass --------------------------------------------------- + +EmitFn = Callable[["fx.Graph", "fx.Node"], "fx.Node"] +EagerFn = Callable[[torch.Tensor], torch.Tensor] + + +@dataclass(frozen=True) +class StatReducer: + """ + A reducer specification. + + `emit` is invoked by `strip_taps_` to materialise the reducer subgraph + in the post-lowering FX graph. `eager` is the equivalent pure-torch + implementation, used by callers that want to reproduce what the runtime + will compute (e.g. AOT-vs-runtime comparisons without a debugger). 
+ + `name` is what the user types and what's stored on each TapSpec. + `fields` enumerates the columns of the 1-D output tensor (empty for + FULL_TENSOR, which preserves the original tensor shape). + """ + + name: str + fields: tuple[str, ...] + emit: EmitFn + eager: EagerFn + + +# --- Helpers ------------------------------------------------------------- + + +def _cast_fp32(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": + """Insert a fp32 cast (no-op semantically if already fp32).""" + # exir_ops.edge.dim_order_ops._to_dim_order_copy.default exists for edge dialect, + # but the simpler aten._to_copy variant is broadly supported. + return graph.call_function( + exir_ops.edge.aten._to_copy.default, + args=(x,), + kwargs={"dtype": torch.float32}, + ) + + +def _scalar_node(graph: "fx.Graph", op, x: "fx.Node") -> "fx.Node": + """Call a full-reduction op (amin/amax/mean/sum) producing a 0-d tensor.""" + return graph.call_function(op, args=(x,)) + + +def _stack(graph: "fx.Graph", scalars: list["fx.Node"]) -> "fx.Node": + """Stack a list of 0-d tensors into a 1-D tensor.""" + return graph.call_function( + exir_ops.edge.aten.stack.default, + args=(scalars,), + kwargs={"dim": 0}, + ) + + +# --- Built-in reducers --------------------------------------------------- + + +def _emit_full_tensor(_graph: "fx.Graph", src: "fx.Node") -> "fx.Node": + """Identity — return the source node directly. 
strip_taps_ will splice.""" + return src + + +def _eager_full_tensor(t: torch.Tensor) -> torch.Tensor: + return t.detach() + + +FULL_TENSOR: StatReducer = StatReducer( + name="FULL_TENSOR", + fields=(), + emit=_emit_full_tensor, + eager=_eager_full_tensor, +) + + +def _emit_abs_max(graph: "fx.Graph", src: "fx.Node") -> "fx.Node": + f = _cast_fp32(graph, src) + abs_x = graph.call_function(exir_ops.edge.aten.abs.default, args=(f,)) + return _scalar_node(graph, exir_ops.edge.aten.amax.default, abs_x) + + +def _eager_abs_max(t: torch.Tensor) -> torch.Tensor: + f = t.detach().to(torch.float32) + return f.abs().amax() + + +ABS_MAX_ONLY: StatReducer = StatReducer( + name="ABS_MAX_ONLY", + fields=("abs_max",), + emit=_emit_abs_max, + eager=_eager_abs_max, +) + + +def _emit_min_max_mean(graph: "fx.Graph", src: "fx.Node") -> "fx.Node": + f = _cast_fp32(graph, src) + mn = _scalar_node(graph, exir_ops.edge.aten.amin.default, f) + mx = _scalar_node(graph, exir_ops.edge.aten.amax.default, f) + me = _scalar_node(graph, exir_ops.edge.aten.mean.default, f) + return _stack(graph, [mn, mx, me]) + + +def _eager_min_max_mean(t: torch.Tensor) -> torch.Tensor: + f = t.detach().to(torch.float32) + return torch.stack([f.amin(), f.amax(), f.mean()], dim=0) + + +MIN_MAX_MEAN: StatReducer = StatReducer( + name="MIN_MAX_MEAN", + fields=("min", "max", "mean"), + emit=_emit_min_max_mean, + eager=_eager_min_max_mean, +) + + +def _emit_default_stats(graph: "fx.Graph", src: "fx.Node") -> "fx.Node": + """ + Default stats: (min, max, mean, abs_max) — 4 floats. + + NOTE: nan_count/inf_count/std are intentionally excluded because the + underlying portable kernels (`isnan`, `isinf`, `sum.dtype`, `std.*`) + don't all have out variants registered in ExecuTorch's default runtime + op table, which fails memory planning or runtime method-load. If you + need them, supply a custom StatReducer. 
+ """ + f = _cast_fp32(graph, src) + mn = _scalar_node(graph, exir_ops.edge.aten.amin.default, f) + mx = _scalar_node(graph, exir_ops.edge.aten.amax.default, f) + me = _scalar_node(graph, exir_ops.edge.aten.mean.default, f) + + abs_x = graph.call_function(exir_ops.edge.aten.abs.default, args=(f,)) + abs_max = _scalar_node(graph, exir_ops.edge.aten.amax.default, abs_x) + + return _stack(graph, [mn, mx, me, abs_max]) + + +def _eager_default_stats(t: torch.Tensor) -> torch.Tensor: + f = t.detach().to(torch.float32) + return torch.stack( + [f.amin(), f.amax(), f.mean(), f.abs().amax()], + dim=0, + ) + + +DEFAULT_STATS: StatReducer = StatReducer( + name="DEFAULT_STATS", + fields=("min", "max", "mean", "abs_max"), + emit=_emit_default_stats, + eager=_eager_default_stats, +) + + +# --- Registry ------------------------------------------------------------- + +_BUILTIN_REDUCERS: dict[str, StatReducer] = { + r.name: r + for r in (FULL_TENSOR, ABS_MAX_ONLY, MIN_MAX_MEAN, DEFAULT_STATS) +} + + +def get_reducer(name_or_reducer: str | StatReducer) -> StatReducer: + """Look up a built-in by name, or return a user-supplied StatReducer as-is.""" + if isinstance(name_or_reducer, StatReducer): + return name_or_reducer + if name_or_reducer not in _BUILTIN_REDUCERS: + raise ValueError( + f"Unknown reducer {name_or_reducer!r}; " + f"available: {sorted(_BUILTIN_REDUCERS)}" + ) + return _BUILTIN_REDUCERS[name_or_reducer] diff --git a/devtools/intermediate_output_tap/_selectors.py b/devtools/intermediate_output_tap/_selectors.py new file mode 100644 index 00000000000..004c057df13 --- /dev/null +++ b/devtools/intermediate_output_tap/_selectors.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +Predicates for selecting which FX nodes to tap. 
+ +A `NodeSelector` is just `Callable[[fx.Node], bool]`. The provided builders +let you compose them by op type, by `nn_module_stack` path, by arbitrary meta +tag, and via boolean combinators. + +Examples: + selector = select_any( + select_by_op_type("aten.linear.default", "aten.matmul.default"), + select_by_module_path("layers.*.attention"), + ) + selector = select_all(selector, select_not(select_by_op_type("aten.view.default"))) +""" + +from __future__ import annotations + +import fnmatch +from collections.abc import Callable +from typing import Any + +import torch.fx as fx + + +NodeSelector = Callable[[fx.Node], bool] + + +def select_all_call_function( + exclude: tuple[str, ...] = ("getitem",), +) -> NodeSelector: + """Match every `call_function` node whose target name is not in `exclude`.""" + excluded = set(exclude) + + def predicate(n: fx.Node) -> bool: + if n.op != "call_function": + return False + target_name = getattr(n.target, "__name__", str(n.target)) + # `getitem` shows up as the builtin name; also normalise common aten suffixes. + return target_name not in excluded and str(n.target) not in excluded + + return predicate + + +def select_by_op_type(*op_targets: str) -> NodeSelector: + """ + Match nodes whose `str(node.target)` ends with any of `op_targets`. + + The "ends with" rule lets the user write either the short name + ("aten.linear.default") or a fully-qualified name and have it match. + """ + if not op_targets: + raise ValueError("select_by_op_type requires at least one op target") + suffixes = tuple(op_targets) + + def predicate(n: fx.Node) -> bool: + if n.op != "call_function": + return False + target_str = str(n.target) + return any(target_str.endswith(s) or target_str == s for s in suffixes) + + return predicate + + +def select_by_module_path(pattern: str) -> NodeSelector: + """ + Match nodes whose `nn_module_stack` (the chain of nn.Module attribute names + walked to reach this op during tracing) contains a path matching `pattern`. 
+ `pattern` is a shell-glob (fnmatch) — e.g. "layers.*", "layers.0.attention", + "*.attention.*". + """ + + def predicate(n: fx.Node) -> bool: + stack = n.meta.get("nn_module_stack") + if not stack: + return False + # nn_module_stack is an OrderedDict: id -> (qualified_path, module_type) + for entry in stack.values(): + path = entry[0] if isinstance(entry, tuple) else entry + if fnmatch.fnmatchcase(path, pattern): + return True + return False + + return predicate + + +# Sentinel: matches when the meta key exists at all, regardless of value. +_ANY_VALUE: object = object() + + +def select_by_meta_tag(key: str, value: Any = _ANY_VALUE) -> NodeSelector: + """ + Match nodes that carry `node.meta[key]`. If `value` is provided, also + requires `node.meta[key] == value`. + """ + + def predicate(n: fx.Node) -> bool: + if key not in n.meta: + return False + if value is _ANY_VALUE: + return True + return n.meta[key] == value + + return predicate + + +def select_any(*selectors: NodeSelector) -> NodeSelector: + """Match if ANY of `selectors` matches.""" + if not selectors: + return lambda _n: False + sels = tuple(selectors) + return lambda n: any(s(n) for s in sels) + + +def select_all(*selectors: NodeSelector) -> NodeSelector: + """Match if ALL of `selectors` match.""" + if not selectors: + return lambda _n: True + sels = tuple(selectors) + return lambda n: all(s(n) for s in sels) + + +def select_not(selector: NodeSelector) -> NodeSelector: + """Match if `selector` does NOT match.""" + return lambda n: not selector(n) diff --git a/devtools/intermediate_output_tap/_spec.py b/devtools/intermediate_output_tap/_spec.py new file mode 100644 index 00000000000..c043a142913 --- /dev/null +++ b/devtools/intermediate_output_tap/_spec.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +""" +TapSpec records one tap inserted by `tap_intermediate_outputs(...)`. + +A list of TapSpecs is returned to the user from the AOT pass; the user passes +that same list to `Inspector.calculate_numeric_gap_from_taps(...)` at runtime +to demux the flat output tuple back into per-op intermediate values. +""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class TapSpec: + """ + Metadata about a single tap. + + Attributes: + node_name: The FX node name of the *source* node (the tapped op) at the + time the AOT pass ran. Useful for debugging / pretty-printing. + op_target: `str(node.target)` of the source node, e.g. + "aten.linear.default". + debug_handle: `node.meta["debug_handle"]` of the source node, or None + if the source had no debug handle. Set at AOT-pass time. NOT used + by the Inspector integration directly — the serializer regenerates + debug_handles, so Inspector aligns by `reducer_node_name` instead. + output_index: 0-based index into the runtime program's flat output + tuple where this tap's value lands. Computed at AOT time and stable + through `to_edge` / `to_backend` / `to_executorch` because we only + ever *append* to the output node and `OutputSpec`. + reducer_name: Name of the StatReducer used (e.g. "DEFAULT_STATS"). + fields: Names of the per-element fields in the reducer's output tensor + (e.g. ("min", "max", "abs_max")). Empty tuple for FULL_TENSOR. + stack_trace: `node.meta["stack_trace"]` of the source node if present, + for human-readable error messages. + reducer_node_name: The FX node name of the post-strip reducer terminal + node — i.e., the node whose value is surfaced as the runtime tap + output. Populated by `strip_taps_` when `tap_specs` is passed. + FX node names survive ETRecord serialization roundtrip, so this + is the stable bridge `Inspector.calculate_numeric_gap_from_taps` + uses to find the post-roundtrip handle for alignment. 
+ """ + + node_name: str + op_target: str + debug_handle: int | None + output_index: int + reducer_name: str + fields: tuple[str, ...] + stack_trace: str | None = None + reducer_node_name: str | None = None + module_path: str | None = None diff --git a/devtools/intermediate_output_tap/_strip_pass.py b/devtools/intermediate_output_tap/_strip_pass.py new file mode 100644 index 00000000000..5047dba0cff --- /dev/null +++ b/devtools/intermediate_output_tap/_strip_pass.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Post-`to_backend` pass: replace each `executorch_devtools::tap.Tensor` node +with either an identity edge (FULL_TENSOR) or a portable reducer subgraph +(DEFAULT_STATS, MIN_MAX_MEAN, ABS_MAX_ONLY). + +Pattern stolen from `remove_graph_break_` in +`executorch/examples/apple/coreml/llama/export_static_llm_coreml.py`. + +This pass MUST run *after* `to_edge_transform_and_lower(...)` and *before* +`to_executorch()`. Running it before partitioning would defeat the whole +mechanism (the reducer ops would be eligible for delegation). + +When called with the `tap_specs` from `tap_intermediate_outputs`, this pass +also populates `TapSpec.reducer_node_name` for each spec — the FX node name +of the post-strip reducer terminal. This is the bridge +`Inspector.calculate_numeric_gap_from_taps` uses to recover the +post-ETRecord-roundtrip `debug_handle` for alignment. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import replace as _dataclass_replace + +import torch.fx as fx +from executorch.devtools.intermediate_output_tap._reducers import get_reducer +from executorch.devtools.intermediate_output_tap._spec import TapSpec +from executorch.devtools.intermediate_output_tap._tap_pass import find_tap_nodes + + +def strip_taps_( + edge_manager, + tap_specs: Sequence[TapSpec] | None = None, +) -> list[TapSpec] | None: + """ + Replace every `tap.Tensor(src, reducer_name, debug_handle)` node in every + method of `edge_manager` with the materialised reducer subgraph, in place. + + For FULL_TENSOR the placeholder is collapsed (the source node's value + flows directly to whatever consumed the placeholder). + + Args: + edge_manager: An EdgeProgramManager (post-`to_edge_transform_and_lower`). + tap_specs: Optional. If provided, the pass returns a NEW list of + TapSpecs with `reducer_node_name` populated for each spec — the + FX name of the post-strip reducer terminal node. This list must + be passed to `Inspector.calculate_numeric_gap_from_taps` for + alignment to work. + + Returns: + Updated tap_specs list if `tap_specs` was provided, else None. + """ + # Walk in graph order; tap nodes appear in the same order they were + # created by `tap_intermediate_outputs`, which is the same order as + # `tap_specs`. Track each tap's replacement node so we can update the + # corresponding spec. 
+ replacement_names: list[str | None] = [] + for method_name in edge_manager.methods: + ep = edge_manager.exported_program(method_name) + gm = ep.graph_module + for replacement_node in _strip_taps_in_graph_module(gm): + replacement_names.append( + replacement_node.name if replacement_node is not None else None + ) + + if tap_specs is None: + return None + + if len(tap_specs) != len(replacement_names): + raise RuntimeError( + f"strip_taps_: tap_specs length ({len(tap_specs)}) does not match " + f"the number of tap nodes found in the edge_manager " + f"({len(replacement_names)}). The strip pass cannot align specs " + f"to reducer nodes. Did you call strip_taps_ on a different " + f"edge_manager than the one produced from the tapped EP?" + ) + + return [ + _dataclass_replace(spec, reducer_node_name=name) + for spec, name in zip(tap_specs, replacement_names) + ] + + +def _strip_taps_in_graph_module(gm: fx.GraphModule) -> list[fx.Node | None]: + """ + Strip taps in a single GraphModule. Returns the list of replacement nodes + in tap-creation order (same as graph order). For FULL_TENSOR taps the + "replacement" is the source node itself (since the tap collapses to + identity). + """ + graph = gm.graph + tap_nodes = find_tap_nodes(gm) + if not tap_nodes: + return [] + + output_node = graph.output_node() + replacements: list[fx.Node | None] = [] + + # Compute next available debug_handle so each reducer terminal gets a + # unique one (necessary so Inspector can look it up by node name and find + # a non-None handle in the post-roundtrip graph). 
+    existing_handles = [
+        n.meta.get("debug_handle")
+        for n in graph.nodes
+        if isinstance(n.meta.get("debug_handle"), int)
+    ]
+    next_handle = (max(existing_handles) + 1) if existing_handles else 1
+
+    for tap in tap_nodes:
+        # tap.args = (src_node, reducer_name, debug_handle)
+        src, reducer_name, dh = tap.args[0], tap.args[1], tap.args[2]
+        reducer = get_reducer(str(reducer_name))
+
+        if reducer.name == "FULL_TENSOR":
+            # Identity: re-route all consumers to the source. The "reducer
+            # terminal" is the source itself.
+            tap.replace_all_uses_with(src)
+            replacements.append(src if isinstance(src, fx.Node) else None)
+            continue
+
+        # Build the reducer subgraph (reads from src).
+        with graph.inserting_before(tap):
+            replacement = reducer.emit(graph, src)
+        # Always assign a debug_handle to the reducer terminal so Inspector
+        # can find it post-roundtrip. Prefer the source's pre-tap handle when
+        # non-zero (0 is the "no handle" sentinel); else allocate next_handle.
+        if dh:
+            replacement.meta["debug_handle"] = dh
+        else:
+            replacement.meta["debug_handle"] = next_handle
+            next_handle += 1
+        replacement.meta["is_tap"] = True
+        replacement.meta["source_node"] = (
+            src.name if isinstance(src, fx.Node) else None
+        )
+
+        # `tap` may have ended up in the data path during to_edge's re-trace
+        # (because CompositeExplicitAutograd preserves the op as an identity
+        # node, and re-traced consumers point at it instead of `src`). So:
+        # - the OUTPUT-node use becomes the reducer (the value we want
+        #   surfaced as a tap).
+        # - every OTHER use is rewritten back to `src` (identity passthrough),
+        #   restoring the original data path.
+        for use_node in list(tap.users.keys()):
+            if use_node is output_node:
+                new_outs = tuple(
+                    replacement if a is tap else a for a in output_node.args[0]
+                )
+                output_node.args = (new_outs,)
+            else:
+                use_node.replace_input_with(tap, src)
+        replacements.append(replacement)
+
+    graph.eliminate_dead_code()
+    gm.recompile()
+    return replacements
diff --git a/devtools/intermediate_output_tap/_tap_pass.py b/devtools/intermediate_output_tap/_tap_pass.py
new file mode 100644
index 00000000000..aa4a30bf6c0
--- /dev/null
+++ b/devtools/intermediate_output_tap/_tap_pass.py
@@ -0,0 +1,261 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+"""
+AOT pass: insert `tap.Tensor` placeholders after selected nodes and surface
+them as additional USER_OUTPUTs of the ExportedProgram.
+
+Pattern adapted from `executorch/exir/passes/weights_to_outputs_pass.py`:
+- find existing output node
+- build new output args (existing + new tap nodes)
+- create new output node, replace_all_uses_with, erase old
+- append OutputSpec(USER_OUTPUT) entries to gs.output_specs
+- eliminate_dead_code() + recompile()
+"""
+
+from __future__ import annotations
+
+import copy
+from collections.abc import Callable
+
+import torch
+import torch.fx as fx
+from executorch.devtools.intermediate_output_tap import custom_ops_lib  # noqa: F401 registers tap.Tensor
+from executorch.devtools.intermediate_output_tap._reducers import (
+    DEFAULT_STATS,
+    get_reducer,
+    StatReducer,
+)
+from executorch.devtools.intermediate_output_tap._selectors import (
+    NodeSelector,
+    select_all_call_function,
+)
+from executorch.devtools.intermediate_output_tap._spec import TapSpec
+from torch.export import ExportedProgram
+from torch.export.exported_program import OutputKind, OutputSpec, TensorArgument
+
+
+# Don't ever tap our own tap nodes if a user
runs the pass twice. +# `tap.Tensor` is already an OpOverload (not a packet) since "Tensor" is the +# overload name — same convention as torch.ops.executorch_utils.graph_break.Tensor. +_TAP_TARGET = torch.ops.executorch_devtools.tap.Tensor + + +def _is_tap_node(n: fx.Node) -> bool: + return n.op == "call_function" and n.target is _TAP_TARGET + + +def tap_intermediate_outputs( + ep: ExportedProgram, + selector: NodeSelector | None = None, + reducer: str | StatReducer = DEFAULT_STATS, + *, + tap_name_prefix: str = "tap_", + skip_if_no_debug_handle: bool = False, + max_taps: int | None = None, + inplace: bool = False, +) -> tuple[ExportedProgram, list[TapSpec]]: + """ + Rewrite `ep` so each node matching `selector` has its output appended to + the program outputs (wrapped in a `tap.Tensor` placeholder that survives + partitioning). Returns the new ExportedProgram and a list of TapSpecs. + + The returned EP is safe to feed to + `to_edge_transform_and_lower(...).to_executorch()` *after* calling + `strip_taps_(edge_manager)` to replace the placeholders with their + reducer subgraphs (or identities, for FULL_TENSOR). + + Args: + ep: The ExportedProgram to tap. + selector: A predicate over fx.Node. Defaults to + `select_all_call_function()`. Tap nodes themselves are always + excluded so re-running the pass is idempotent. + reducer: Either a built-in reducer name ("DEFAULT_STATS", + "MIN_MAX_MEAN", "ABS_MAX_ONLY", "FULL_TENSOR") or a custom + StatReducer instance. + tap_name_prefix: Prefix for the tap nodes' names. Helps when + grepping the dumped graph. + skip_if_no_debug_handle: If True, only tap nodes that already + carry `node.meta["debug_handle"]`. Recommended for Inspector + integration since handle-less taps cannot be aligned with + AOT outputs. + max_taps: Optional cap on number of taps. Helps avoid OOM for + very large models. + inplace: If False (default), deep-copy `ep` before mutating. 
+ """ + if selector is None: + selector = select_all_call_function() + reducer_obj = get_reducer(reducer) + + if not inplace: + ep = copy.deepcopy(ep) + + gs = ep.graph_signature + gm = ep.graph_module + graph = gm.graph + output_node = graph.output_node() + existing_outputs = list(output_node.args[0]) + + # Snapshot before we start mutating the graph. + candidate_nodes = [n for n in graph.nodes if not _is_tap_node(n)] + + specs: list[TapSpec] = [] + new_tap_nodes: list[fx.Node] = [] + + for node in candidate_nodes: + if node.op != "call_function" or not selector(node): + continue + debug_handle = node.meta.get("debug_handle") + if skip_if_no_debug_handle and debug_handle is None: + continue + if max_taps is not None and len(specs) >= max_taps: + break + + # tap.Tensor's int arg cannot be None; sentinel 0 means "no handle". + dh_arg = int(debug_handle) if isinstance(debug_handle, int) else 0 + + with graph.inserting_after(node): + tap_node = graph.call_function( + _TAP_TARGET, + args=(node, reducer_obj.name, dh_arg), + ) + # Don't override the auto-assigned name — FX guarantees uniqueness. + # Stash the prefixed-source-name in meta for human-readable logs. + tap_node.meta["tap_label"] = f"{tap_name_prefix}{node.name}" + # Preserve provenance for Inspector's `propagate_back_debug_handle` + # and for users that pretty-print the graph. + if debug_handle is not None: + tap_node.meta["debug_handle"] = debug_handle + if "from_node" in node.meta: + tap_node.meta["from_node"] = node.meta["from_node"] + if "stack_trace" in node.meta: + tap_node.meta["stack_trace"] = node.meta["stack_trace"] + if "nn_module_stack" in node.meta: + tap_node.meta["nn_module_stack"] = node.meta["nn_module_stack"] + tap_node.meta["is_tap"] = True + tap_node.meta["source_node"] = node.name + + new_tap_nodes.append(tap_node) + # Leaf module FQN from nn_module_stack (e.g., "layers.0.attention.wqs.0"). 
+ module_path: str | None = None + stack = node.meta.get("nn_module_stack") + if stack: + try: + last_entry = list(stack.values())[-1] + module_path = ( + last_entry[0] if isinstance(last_entry, tuple) else str(last_entry) + ) + except Exception: + module_path = None + specs.append( + TapSpec( + node_name=node.name, + op_target=str(node.target), + debug_handle=debug_handle if isinstance(debug_handle, int) else None, + output_index=len(existing_outputs) + len(specs), + reducer_name=reducer_obj.name, + fields=reducer_obj.fields, + stack_trace=node.meta.get("stack_trace"), + module_path=module_path, + ) + ) + + if not new_tap_nodes: + return ep, [] + + # Splice new outputs into the graph (mirror weights_to_outputs_pass). + new_output_args = tuple(existing_outputs + new_tap_nodes) + with graph.inserting_before(output_node): + new_output = graph.output(new_output_args) + output_node.replace_all_uses_with(new_output) + graph.erase_node(output_node) + + # Append OutputSpec entries so the EP's signature matches the graph. + for tap_node in new_tap_nodes: + gs.output_specs.append( + OutputSpec( + kind=OutputKind.USER_OUTPUT, + arg=TensorArgument(name=tap_node.name), + target=None, + ) + ) + + # Update each ModuleCallSignature's out_spec so `to_edge`'s re-trace can + # unflatten the new flat output list. The "" (root) entry holds the + # user-facing forward output structure; we wrap it in a tuple alongside + # the new tap leaves and re-derive the spec. + _extend_module_call_graph_outputs(ep, new_tap_nodes) + + graph.eliminate_dead_code() + gm.recompile() + return ep, specs + + +def _extend_module_call_graph_outputs( + ep: ExportedProgram, + new_tap_nodes: list[fx.Node], +) -> None: + """ + Append `len(new_tap_nodes)` extra leaves to the root module-call entry's + `out_spec` so the pytree unflatten step in `run_decompositions` works. + Also extends the entry's `outputs: list[ArgumentSpec]`. + + NOTE: We append TensorArgument(name="") for each new tap output. 
Empty + names are *skipped* by `_verify_exported_program_module_call_graph` (its + check is `if arg.name and arg.name not in nodes`). We can't use the + pre-trace tap node names because `to_edge`'s re-trace renames nodes via + `from_node` chains, and our tap nodes' provenance wouldn't update them + correctly — leading to "Output X does not exist in the graph" errors. + The verifier's name check is metadata-only; the actual pytree unflatten + only needs `out_spec` to have the correct number of leaves. + """ + import torch.utils._pytree as pytree + from torch.export.exported_program import TensorArgument as _TensorArgument + + n_new = len(new_tap_nodes) + if n_new == 0: + return + + for entry in ep._module_call_graph: + if entry.fqn != "": + continue + sig = entry.signature + if sig is None: + continue + old_spec = sig.out_spec + # Build a dummy structure matching the old spec, then wrap with N new + # leaves and re-derive the spec. This handles arbitrary pytree shapes. + old_dummy = pytree.tree_unflatten([0] * old_spec.num_leaves, old_spec) + if isinstance(old_dummy, tuple): + new_dummy = (*old_dummy, *([0] * n_new)) + else: + new_dummy = (old_dummy, *([0] * n_new)) + sig.out_spec = pytree.tree_structure(new_dummy) + for _ in range(n_new): + sig.outputs.append(_TensorArgument(name="")) + break + + +def find_tap_nodes(gm: fx.GraphModule) -> list[fx.Node]: + """Helper: enumerate tap.Tensor nodes in a GraphModule (any dialect).""" + out: list[fx.Node] = [] + for n in gm.graph.nodes: + if n.op != "call_function": + continue + # Match across dialects: + # pre-edge: torch.ops.executorch_devtools.tap.Tensor — str ends with name + # post-edge: : schema = ... + # so substring-match the qualified name. + if "executorch_devtools.tap.Tensor" in str(n.target) or n.target is _TAP_TARGET: + out.append(n) + return out + + +# Re-export the predicate so callers can identify tap nodes without importing +# torch.ops directly. 
+is_tap_node: Callable[[fx.Node], bool] = _is_tap_node diff --git a/devtools/intermediate_output_tap/custom_ops_lib.py b/devtools/intermediate_output_tap/custom_ops_lib.py new file mode 100644 index 00000000000..93c4636ad2c --- /dev/null +++ b/devtools/intermediate_output_tap/custom_ops_lib.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Custom op registration for the intermediate-output tap mechanism. + +The op `executorch_devtools::tap.Tensor(Tensor x, str reducer_name, int debug_handle) -> Tensor` +is an identity op whose sole job is to be an unknown-to-every-partitioner FX node +that "uses" a tapped tensor `x`. Because `x` now has a user outside any partition, +every ExecuTorch partitioner must surface `x` as a partition output (this is the +canonical contract enforced in `executorch/exir/lowered_backend_module.py`). + +After `to_edge_transform_and_lower(...)` the tap.Tensor node still exists in the +parent graph; `strip_taps_` (see `_strip_pass.py`) replaces it with either an +identity edge (FULL_TENSOR) or a small reducer subgraph of portable aten ops. + +The dispatch key MUST be `CompositeExplicitAutograd` (not `CompositeImplicitAutograd`) +so the op survives tracing/decomposition; otherwise it would inline at export time +and disappear before partitioning. This mirrors the pattern in +`executorch/examples/apple/coreml/llama/export_static_llm_coreml.py`. + +`reducer_name` and `debug_handle` are stored as op arguments (not just node.meta) +so they survive any meta-stripping pass between `to_edge` and `strip_taps_`. +""" + +from __future__ import annotations + +from torch.library import impl, Library + +# Library namespace verified collision-free across fbsource as of Nov 2025. 
+lib: Library = Library("executorch_devtools", "DEF") + +lib.define("tap.Tensor(Tensor x, str reducer_name, int debug_handle) -> Tensor") + + +@impl(lib, "tap.Tensor", "CompositeExplicitAutograd") +def tap_tensor_impl(x, reducer_name, debug_handle): # noqa: ARG001 + return x diff --git a/devtools/intermediate_output_tap/tests/TARGETS b/devtools/intermediate_output_tap/tests/TARGETS new file mode 100644 index 00000000000..79344627b17 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/TARGETS @@ -0,0 +1,74 @@ +load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") +load("@fbsource//tools/target_determinator/macros:ci.bzl", "ci") + +oncall("executorch") + +python_unittest( + name = "test_selectors", + srcs = ["test_selectors.py"], + deps = [ + "//caffe2:torch", + "//executorch/devtools/intermediate_output_tap:selectors", + ], +) + +python_unittest( + name = "test_reducers", + srcs = ["test_reducers.py"], + deps = [ + "//caffe2:torch", + "//executorch/devtools/intermediate_output_tap:reducers", + ], +) + +python_unittest( + name = "test_tap_pass", + srcs = ["test_tap_pass.py"], + deps = [ + "//caffe2:torch", + "//executorch/devtools/intermediate_output_tap:reducers", + "//executorch/devtools/intermediate_output_tap:selectors", + "//executorch/devtools/intermediate_output_tap:tap_pass", + ], +) + +python_unittest( + name = "test_strip_pass", + srcs = ["test_strip_pass.py"], + deps = [ + "//caffe2:torch", + "//executorch/devtools/intermediate_output_tap:reducers", + "//executorch/devtools/intermediate_output_tap:selectors", + "//executorch/devtools/intermediate_output_tap:strip_pass", + "//executorch/devtools/intermediate_output_tap:tap_pass", + "//executorch/exir:lib", + ], +) + +python_unittest( + name = "test_xnnpack_e2e", + srcs = ["test_xnnpack_e2e.py"], + deps = [ + "//caffe2:torch", + "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/devtools/intermediate_output_tap:lib", + "//executorch/exir:lib", + 
"//executorch/runtime:runtime", + ], +) + +python_unittest( + name = "test_inspector_integration", + srcs = ["test_inspector_integration.py"], + labels = ci.labels( + ci.buckconfig("executorch.event_tracer_enabled", "true"), + ), + deps = [ + "//caffe2:torch", + "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/devtools:lib", + "//executorch/devtools/intermediate_output_tap:lib", + "//executorch/exir:lib", + "//executorch/runtime:runtime", + ], +) diff --git a/devtools/intermediate_output_tap/tests/test_inspector_integration.py b/devtools/intermediate_output_tap/tests/test_inspector_integration.py new file mode 100644 index 00000000000..b48dcfe4e0a --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_inspector_integration.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Integration test: run the full pipeline (export -> tap -> lower with XNNPACK +-> strip -> generate_etrecord -> to_executorch -> runtime) and feed the flat +runtime outputs + (post-strip) TapSpecs to +Inspector.calculate_numeric_gap_from_taps. Verify the returned DataFrame has +rows aligned by debug_handle. + +KEY DESIGN POINTS: +1. ETRecord generation MUST happen AFTER `strip_taps_` so the snapshot of the + edge program contains no `tap.Tensor` nodes (which the EXIR serializer + can't handle). +2. `strip_taps_(edge, tap_specs=specs)` returns updated specs whose + `reducer_node_name` is set to the post-strip reducer terminal node name. + Inspector uses that name to look up the post-roundtrip `debug_handle` — + FX node names survive ETRecord serialization, debug_handle values do not. 
+""" + +import os +import sys +import tempfile +import unittest + +import torch +from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackPartitioner, +) +from executorch.devtools import generate_etrecord, Inspector +from executorch.devtools.intermediate_output_tap import ( + DEFAULT_STATS, + format_tap_dataframe, + select_by_op_type, + strip_taps_, + tap_intermediate_outputs, +) +from executorch.exir import to_edge_transform_and_lower +from executorch.runtime import Runtime, Verification +from torch.export import export + + +class _MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(8, 16) + self.l2 = torch.nn.Linear(16, 4) + + def forward(self, x): + return self.l2(self.l1(x).relu()) + + +@unittest.skipIf(sys.platform.startswith("win"), "ExecuTorch runtime not available on Windows") +class InspectorIntegrationTest(unittest.TestCase): + def test_calculate_numeric_gap_from_taps(self): + model = _MLP() + example_inputs = (torch.randn(2, 8),) + + ep = export(model, example_inputs, strict=True) + ep_t, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=DEFAULT_STATS, + ) + # Do NOT pass generate_etrecord=True — we'd snapshot the EP while it + # still has tap.Tensor nodes (unserializable). + edge = to_edge_transform_and_lower( + ep_t, + partitioner=[XnnpackPartitioner()], + ) + # strip_taps_ with tap_specs returns updated specs whose + # reducer_node_name points at the post-strip reducer terminal node. + specs = strip_taps_(edge, tap_specs=specs) + et_program = edge.to_executorch() + + with tempfile.TemporaryDirectory() as temp_dir: + pte_path = os.path.join(temp_dir, "model.pte") + et_program.save(pte_path) + + # ETRecord generated AFTER strip — the edge program is now + # serializable. Don't pass exported_program: Inspector falls back + # to the edge dialect program for AOT capture. 
+ etrecord_path = os.path.join(temp_dir, "etrecord.bin") + generate_etrecord( + etrecord_path, + edge_dialect_program=edge, + executorch_program=et_program, + ) + + rt = Runtime.get() + program = rt.load_program( + pte_path, + verification=Verification.Minimal, + enable_etdump=True, + debug_buffer_size=1024 * 1024, + ) + method = program.load_method("forward") + flat_outputs = method.execute(list(example_inputs)) + + etdump_path = os.path.join(temp_dir, "etdump.etdp") + debug_buffer_path = os.path.join(temp_dir, "debug_buffer.bin") + program.write_etdump_result_to_file(etdump_path, debug_buffer_path) + if not os.path.exists(etdump_path): + self.skipTest( + "Event tracer not enabled. Run with " + "--config executorch.event_tracer_enabled=true" + ) + + inspector = Inspector( + etdump_path=etdump_path, + etrecord=etrecord_path, + debug_buffer_path=debug_buffer_path, + ) + inspector._etrecord._representative_inputs = list(example_inputs) + df = inspector.calculate_numeric_gap_from_taps( + flat_runtime_outputs=flat_outputs, + tap_specs=specs, + distance="MSE", + ) + # Print friendly per-tap view to stdout (visible via --print-passing-details). 
+ friendly = format_tap_dataframe(df, specs) + import pandas as _pd + with _pd.option_context( + "display.max_columns", None, + "display.width", 240, + "display.max_colwidth", 30, + "display.float_format", "{:.4g}".format, + ): + print("\n=== Inspector.calculate_numeric_gap_from_taps (friendly) ===") + print(friendly.to_string()) + + self.assertGreater(len(df), 0, "expected at least one tap row in DataFrame") + for col in ("aot_ops", "runtime_ops", "gap"): + self.assertIn(col, df.columns) + for _, row in df.iterrows(): + self.assertIsNotNone(row["aot_ops"]) + self.assertIsNotNone(row["runtime_ops"]) + gap = row["gap"] + if isinstance(gap, list): + gap = gap[0] if gap else 0.0 + self.assertGreaterEqual(float(gap), 0.0) diff --git a/devtools/intermediate_output_tap/tests/test_reducers.py b/devtools/intermediate_output_tap/tests/test_reducers.py new file mode 100644 index 00000000000..e36762be5ac --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_reducers.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import unittest + +from executorch.devtools.intermediate_output_tap._reducers import ( + ABS_MAX_ONLY, + DEFAULT_STATS, + FULL_TENSOR, + get_reducer, + MIN_MAX_MEAN, + StatReducer, +) + + +class ReducersTest(unittest.TestCase): + def test_get_reducer_by_name(self): + self.assertIs(get_reducer("DEFAULT_STATS"), DEFAULT_STATS) + self.assertIs(get_reducer("FULL_TENSOR"), FULL_TENSOR) + self.assertIs(get_reducer("MIN_MAX_MEAN"), MIN_MAX_MEAN) + self.assertIs(get_reducer("ABS_MAX_ONLY"), ABS_MAX_ONLY) + + def test_get_reducer_passthrough(self): + custom = StatReducer(name="X", fields=("a",), emit=lambda g, n: n) + self.assertIs(get_reducer(custom), custom) + + def test_get_reducer_unknown_raises(self): + with self.assertRaises(ValueError): + get_reducer("DOES_NOT_EXIST") + + def test_reducer_field_counts(self): + self.assertEqual(FULL_TENSOR.fields, ()) + self.assertEqual(ABS_MAX_ONLY.fields, ("abs_max",)) + self.assertEqual(MIN_MAX_MEAN.fields, ("min", "max", "mean")) + self.assertEqual( + DEFAULT_STATS.fields, + ("min", "max", "mean", "abs_max"), + ) + + def test_reducer_names_unique(self): + names = {r.name for r in (FULL_TENSOR, ABS_MAX_ONLY, MIN_MAX_MEAN, DEFAULT_STATS)} + self.assertEqual(len(names), 4) + + def test_default_stats_eager_correctness(self): + """Confirm DEFAULT_STATS spec has 4 fields (std/nan_count/inf_count excluded).""" + self.assertEqual(len(DEFAULT_STATS.fields), 4) diff --git a/devtools/intermediate_output_tap/tests/test_selectors.py b/devtools/intermediate_output_tap/tests/test_selectors.py new file mode 100644 index 00000000000..495e184de52 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_selectors.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import unittest + +import torch +from executorch.devtools.intermediate_output_tap._selectors import ( + select_all, + select_all_call_function, + select_any, + select_by_meta_tag, + select_by_module_path, + select_by_op_type, + select_not, +) +from torch.export import export + + +class _Inner(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + return self.linear(x).relu() + + +class _Outer(torch.nn.Module): + def __init__(self): + super().__init__() + self.inner = _Inner() + self.head = torch.nn.Linear(4, 2) + + def forward(self, x): + return self.head(self.inner(x)) + + +def _exported_graph(): + ep = export(_Outer(), (torch.randn(2, 4),), strict=True) + return ep.graph_module.graph + + +class SelectorsTest(unittest.TestCase): + def setUp(self): + self.graph = _exported_graph() + self.call_nodes = [n for n in self.graph.nodes if n.op == "call_function"] + + def test_select_all_call_function_excludes_getitem(self): + sel = select_all_call_function() + for n in self.call_nodes: + if "getitem" in str(n.target): + self.assertFalse(sel(n)) + else: + self.assertTrue(sel(n)) + + def test_select_by_op_type_matches_suffix(self): + sel = select_by_op_type("aten.linear.default", "aten.relu.default") + matched = [n for n in self.call_nodes if sel(n)] + # Two linears + one relu in the model. + self.assertGreaterEqual(len(matched), 2) + for n in matched: + self.assertTrue( + str(n.target).endswith("aten.linear.default") + or str(n.target).endswith("aten.relu.default") + ) + + def test_select_by_op_type_requires_target(self): + with self.assertRaises(ValueError): + select_by_op_type() + + def test_select_by_module_path(self): + sel = select_by_module_path("inner.*") + matched = [n for n in self.call_nodes if sel(n)] + # inner contains a linear and a relu. 
+ self.assertGreater(len(matched), 0) + for n in matched: + stack = n.meta.get("nn_module_stack") or {} + paths = [ + v[0] if isinstance(v, tuple) else v for v in stack.values() + ] + self.assertTrue(any(p.startswith("inner") for p in paths)) + + def test_select_by_meta_tag_presence(self): + for n in self.call_nodes[:1]: + n.meta["debug_me"] = "yes" + sel = select_by_meta_tag("debug_me") + self.assertTrue(sel(self.call_nodes[0])) + self.assertFalse(sel(self.call_nodes[1])) + + def test_select_by_meta_tag_value(self): + self.call_nodes[0].meta["color"] = "blue" + self.call_nodes[1].meta["color"] = "red" + sel = select_by_meta_tag("color", "blue") + self.assertTrue(sel(self.call_nodes[0])) + self.assertFalse(sel(self.call_nodes[1])) + + def test_select_combinators(self): + a = select_by_op_type("aten.linear.default") + b = select_by_op_type("aten.relu.default") + any_sel = select_any(a, b) + all_sel = select_all(a, b) + not_sel = select_not(a) + + for n in self.call_nodes: + if a(n) or b(n): + self.assertTrue(any_sel(n)) + self.assertEqual(all_sel(n), a(n) and b(n)) + self.assertEqual(not_sel(n), not a(n)) + + def test_select_any_empty(self): + for n in self.call_nodes: + self.assertFalse(select_any()(n)) + + def test_select_all_empty(self): + for n in self.call_nodes: + self.assertTrue(select_all()(n)) diff --git a/devtools/intermediate_output_tap/tests/test_strip_pass.py b/devtools/intermediate_output_tap/tests/test_strip_pass.py new file mode 100644 index 00000000000..2ff0d9154b1 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_strip_pass.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 

# pyre-unsafe

import unittest

import torch
from executorch.devtools.intermediate_output_tap._reducers import (
    DEFAULT_STATS,
    FULL_TENSOR,
    MIN_MAX_MEAN,
)
from executorch.devtools.intermediate_output_tap._selectors import (
    select_by_op_type,
)
from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_
from executorch.devtools.intermediate_output_tap._tap_pass import (
    find_tap_nodes,
    tap_intermediate_outputs,
)
from executorch.exir import to_edge
from torch.export import export


class _MLP(torch.nn.Module):
    """Tiny two-layer MLP fixture: Linear(8,8) -> relu -> Linear(8,4)."""

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(8, 8)
        self.l2 = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.l2(self.l1(x).relu())


def _tapped_edge(reducer):
    """Export _MLP, tap every aten.linear output with `reducer`, lower to edge.

    Returns:
        (edge_program_manager, tap_specs) — the result of `to_edge` on the
        tapped ExportedProgram, plus the TapSpec list produced by
        `tap_intermediate_outputs`.
    """
    ep = export(_MLP(), (torch.randn(2, 8),), strict=True)
    ep_t, specs = tap_intermediate_outputs(
        ep,
        selector=select_by_op_type("aten.linear.default"),
        reducer=reducer,
    )
    return to_edge(ep_t), specs


class StripPassTest(unittest.TestCase):
    """Tests for strip_taps_: replacing tap nodes with reducer subgraphs."""

    def test_strip_removes_all_tap_nodes_full_tensor(self):
        edge, _ = _tapped_edge(FULL_TENSOR)
        # Pre-strip: tap nodes present.
        for method_name in edge.methods:
            ep = edge.exported_program(method_name)
            self.assertGreater(len(find_tap_nodes(ep.graph_module)), 0)

        strip_taps_(edge)

        # Post-strip: no tap nodes.
        for method_name in edge.methods:
            ep = edge.exported_program(method_name)
            self.assertEqual(len(find_tap_nodes(ep.graph_module)), 0)

    def test_strip_full_tensor_routes_source_to_output(self):
        edge, specs = _tapped_edge(FULL_TENSOR)
        strip_taps_(edge)
        # Output node should still have all the user outputs + tap outputs.
        for method_name in edge.methods:
            ep = edge.exported_program(method_name)
            outs = list(ep.graph_module.graph.output_node().args[0])
            # Original outputs + 2 linears tapped.
            # NOTE(review): this only asserts a lower bound (>= len(specs));
            # an exact count would be a stronger check — confirm intent.
            self.assertGreaterEqual(len(outs), len(specs))

    def test_strip_min_max_mean_emits_subgraph(self):
        edge, specs = _tapped_edge(MIN_MAX_MEAN)
        strip_taps_(edge)
        for method_name in edge.methods:
            ep = edge.exported_program(method_name)
            self.assertEqual(len(find_tap_nodes(ep.graph_module)), 0)
            # Some reduction op (amin/amax/mean) should now be in the graph.
            # Substring match because EdgeOpOverload's str() looks like
            # ": schema = ..." (no clean
            # endswith).
            targets = {str(n.target) for n in ep.graph_module.graph.nodes}
            self.assertTrue(
                any(
                    "aten.amin" in t or "aten.amax" in t or "aten.mean" in t
                    for t in targets
                ),
                f"expected reducer ops in graph, got {targets}",
            )

    def test_strip_default_stats_preserves_debug_handle(self):
        edge, specs = _tapped_edge(DEFAULT_STATS)
        # Take a known debug_handle from one of the tap specs.
        known_handles = {s.debug_handle for s in specs if s.debug_handle is not None}
        if not known_handles:
            self.skipTest("Test model produced no debug_handle on tap sources")

        strip_taps_(edge)

        post_handles: set = set()
        for method_name in edge.methods:
            ep = edge.exported_program(method_name)
            for n in ep.graph_module.graph.nodes:
                if n.meta.get("is_tap"):
                    post_handles.add(n.meta.get("debug_handle"))
        # At least one tapped debug handle should still be present.
        self.assertTrue(known_handles & post_handles)

    def test_strip_idempotent(self):
        edge, _ = _tapped_edge(FULL_TENSOR)
        strip_taps_(edge)
        # Second call should be a no-op.
        strip_taps_(edge)
        for method_name in edge.methods:
            ep = edge.exported_program(method_name)
            self.assertEqual(len(find_tap_nodes(ep.graph_module)), 0)
diff --git a/devtools/intermediate_output_tap/tests/test_tap_pass.py b/devtools/intermediate_output_tap/tests/test_tap_pass.py
new file mode 100644
index 00000000000..9e01ffe61d7
--- /dev/null
+++ b/devtools/intermediate_output_tap/tests/test_tap_pass.py
@@ -0,0 +1,159 @@
# Copyright (c) Meta Platforms, Inc.
# and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import copy
import unittest

import torch
from executorch.devtools.intermediate_output_tap._reducers import (
    DEFAULT_STATS,
    FULL_TENSOR,
)
from executorch.devtools.intermediate_output_tap._selectors import (
    select_by_op_type,
)
from executorch.devtools.intermediate_output_tap._tap_pass import (
    is_tap_node,
    tap_intermediate_outputs,
)
from torch.export import export
from torch.export.exported_program import OutputKind


class _MLP(torch.nn.Module):
    """Three-layer MLP fixture; the three linears are the tap targets."""

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(8, 16)
        self.l2 = torch.nn.Linear(16, 8)
        self.l3 = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.l3(self.l2(self.l1(x).relu()).relu())


def _export():
    """Export the fixture model with a (2, 8) example input."""
    return export(_MLP(), (torch.randn(2, 8),), strict=True)


class TapPassTest(unittest.TestCase):
    """Tests for tap_intermediate_outputs: inserting tap nodes + USER_OUTPUTs."""

    def test_inserts_tap_per_selected_node(self):
        ep = _export()
        ep_t, specs = tap_intermediate_outputs(
            ep,
            selector=select_by_op_type("aten.linear.default"),
            reducer=FULL_TENSOR,
        )
        # MLP has 3 linears.
        self.assertEqual(len(specs), 3)
        tap_nodes = [n for n in ep_t.graph_module.graph.nodes if is_tap_node(n)]
        self.assertEqual(len(tap_nodes), 3)

    def test_appends_user_outputs(self):
        """Each tap adds exactly one USER_OUTPUT to the graph signature."""
        ep = _export()
        original_user_outs = sum(
            1 for s in ep.graph_signature.output_specs if s.kind == OutputKind.USER_OUTPUT
        )
        ep_t, specs = tap_intermediate_outputs(
            ep,
            selector=select_by_op_type("aten.linear.default"),
            reducer=FULL_TENSOR,
        )
        new_user_outs = sum(
            1
            for s in ep_t.graph_signature.output_specs
            if s.kind == OutputKind.USER_OUTPUT
        )
        self.assertEqual(new_user_outs, original_user_outs + len(specs))

    def test_output_indices_contiguous_after_user_outputs(self):
        """Tap output_index values follow the original user outputs, in order."""
        ep = _export()
        original_user_outs = sum(
            1 for s in ep.graph_signature.output_specs if s.kind == OutputKind.USER_OUTPUT
        )
        _, specs = tap_intermediate_outputs(
            ep,
            selector=select_by_op_type("aten.linear.default"),
            reducer=FULL_TENSOR,
        )
        for i, spec in enumerate(specs):
            self.assertEqual(spec.output_index, original_user_outs + i)

    def test_default_reducer_is_default_stats(self):
        # Omitting `reducer` must fall back to DEFAULT_STATS.
        ep = _export()
        _, specs = tap_intermediate_outputs(
            ep, selector=select_by_op_type("aten.linear.default")
        )
        for s in specs:
            self.assertEqual(s.reducer_name, DEFAULT_STATS.name)
            self.assertEqual(s.fields, DEFAULT_STATS.fields)

    def test_inplace_false_does_not_mutate_original(self):
        """The pass must not mutate the input ExportedProgram."""
        ep = _export()
        before_outs = len(list(ep.graph_module.graph.output_node().args[0]))
        before_specs = len(ep.graph_signature.output_specs)
        _ = tap_intermediate_outputs(
            ep, selector=select_by_op_type("aten.linear.default"), reducer=FULL_TENSOR
        )
        after_outs = len(list(ep.graph_module.graph.output_node().args[0]))
        after_specs = len(ep.graph_signature.output_specs)
        self.assertEqual(before_outs, after_outs)
        self.assertEqual(before_specs, after_specs)

    def test_max_taps(self):
        # max_taps caps how many matched nodes actually get tapped.
        ep = _export()
        _, specs = tap_intermediate_outputs(
            ep,
            selector=select_by_op_type("aten.linear.default"),
            reducer=FULL_TENSOR,
            max_taps=2,
        )
        self.assertEqual(len(specs), 2)

    def test_idempotent_does_not_tap_taps(self):
        ep = _export()
        ep_once, specs1 = tap_intermediate_outputs(
            ep,
            selector=select_by_op_type("aten.linear.default"),
            reducer=FULL_TENSOR,
        )
        # Running again should not add NEW taps for our existing tap nodes.
        ep_twice, specs2 = tap_intermediate_outputs(
            ep_once,
            selector=select_by_op_type("aten.linear.default"),
            reducer=FULL_TENSOR,
        )
        # Same number of linears matched; tap.Tensor itself is excluded.
        self.assertEqual(len(specs2), len(specs1))

    def test_no_match_returns_empty_specs(self):
        ep = _export()
        ep_t, specs = tap_intermediate_outputs(
            ep,
            selector=select_by_op_type("aten.does.not.exist"),
            reducer=FULL_TENSOR,
        )
        self.assertEqual(specs, [])
        # Original graph signature is unchanged.
        self.assertEqual(
            len(ep_t.graph_signature.output_specs),
            len(ep.graph_signature.output_specs),
        )

    def test_skip_if_no_debug_handle(self):
        ep = _export()
        # Strip all debug handles to simulate a graph without them.
        ep_clean = copy.deepcopy(ep)
        for n in ep_clean.graph_module.graph.nodes:
            n.meta.pop("debug_handle", None)
        _, specs = tap_intermediate_outputs(
            ep_clean,
            selector=select_by_op_type("aten.linear.default"),
            reducer=FULL_TENSOR,
            skip_if_no_debug_handle=True,
        )
        self.assertEqual(specs, [])
diff --git a/devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py b/devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py
new file mode 100644
index 00000000000..4967427a60a
--- /dev/null
+++ b/devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py
@@ -0,0 +1,127 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""
End-to-end test: prove that intermediate values surfaced as USER_OUTPUT taps
flow through XNNPACK delegation and out the runtime *with no XNNPACK-side
support*. This is the central correctness claim of the design.
"""

import os
import sys
import tempfile
import unittest

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner,
)
from executorch.devtools.intermediate_output_tap import (
    ABS_MAX_ONLY,
    DEFAULT_STATS,
    FULL_TENSOR,
    MIN_MAX_MEAN,
    select_by_op_type,
    strip_taps_,
    tap_intermediate_outputs,
)
from executorch.exir import to_edge_transform_and_lower
from executorch.runtime import Runtime, Verification
from torch.export import export


class _MLP(torch.nn.Module):
    """Two-layer MLP fixture: Linear(8,16) -> relu -> Linear(16,4)."""

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(8, 16)
        self.l2 = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.l2(self.l1(x).relu())


@unittest.skipIf(sys.platform.startswith("win"), "ExecuTorch runtime not available on Windows")
class XnnpackEndToEndTest(unittest.TestCase):
    """Runs the full tap -> lower(XNNPACK) -> strip -> .pte -> execute loop."""

    def _run_pipeline(self, reducer):
        """Export, tap both linears with `reducer`, lower, strip, and execute.

        Returns:
            (specs, flat_outputs, model, example_inputs) — tap specs, the flat
            runtime output sequence, the eager model, and its inputs.
        """
        model = _MLP()
        example_inputs = (torch.randn(2, 8),)

        ep = export(model, example_inputs, strict=True)
        ep_t, specs = tap_intermediate_outputs(
            ep,
            selector=select_by_op_type("aten.linear.default"),
            reducer=reducer,
        )
        edge = to_edge_transform_and_lower(
            ep_t, partitioner=[XnnpackPartitioner()]
        )
        strip_taps_(edge)
        et_program = edge.to_executorch()

        with tempfile.TemporaryDirectory() as temp_dir:
            pte_path = os.path.join(temp_dir, "model.pte")
            et_program.save(pte_path)

            rt = Runtime.get()
            program = rt.load_program(pte_path, verification=Verification.Minimal)
            method = program.load_method("forward")
            flat_outputs = method.execute(list(example_inputs))

        return specs, flat_outputs, model, example_inputs

    def test_full_tensor_taps_match_eager(self):
        specs, flat, model, example_inputs = self._run_pipeline(FULL_TENSOR)
        self.assertEqual(len(specs), 2)  # two linears

        # The user output is at index 0; tap outputs follow.
        for spec in specs:
            tap_value = flat[spec.output_index]
            self.assertIsInstance(tap_value, torch.Tensor)
            # FULL_TENSOR preserves the source tensor's shape — so e.g. for
            # the first linear, shape is (batch, l1.out_features).
            self.assertGreater(tap_value.numel(), 0)

    def test_abs_max_only_returns_scalar(self):
        specs, flat, _, _ = self._run_pipeline(ABS_MAX_ONLY)
        self.assertEqual(len(specs), 2)
        for spec in specs:
            tap_value = flat[spec.output_index]
            self.assertIsInstance(tap_value, torch.Tensor)
            # 0-dim scalar
            self.assertEqual(tap_value.numel(), 1)
            self.assertGreaterEqual(float(tap_value), 0.0)

    def test_min_max_mean_e2e(self):
        specs, flat, _, _ = self._run_pipeline(MIN_MAX_MEAN)
        self.assertEqual(len(specs), 2)
        for spec in specs:
            tap_value = flat[spec.output_index]
            self.assertEqual(tap_value.numel(), 3)

    def test_default_stats_returns_seven_floats(self):
        # NOTE(review): the name says "seven floats" but the body asserts
        # numel() == 4 and unpacks four stats (min, max, <unnamed>, abs_max).
        # Either the name is stale or the assertion is wrong — confirm the
        # field count of DEFAULT_STATS and align name/assertion.
        specs, flat, _, _ = self._run_pipeline(DEFAULT_STATS)
        self.assertEqual(len(specs), 2)
        for spec in specs:
            tap_value = flat[spec.output_index]
            self.assertIsInstance(tap_value, torch.Tensor)
            self.assertEqual(tap_value.numel(), 4)
            mn, mx, _, abs_max = tap_value.tolist()
            self.assertLessEqual(mn, mx)
            # abs_max must dominate |min| and |max| (small float tolerance).
            self.assertGreaterEqual(abs_max, max(abs(mn), abs(mx)) - 1e-5)

    def test_user_outputs_still_correct(self):
        """Tap outputs must not corrupt the original user outputs."""
        specs, flat, model, example_inputs = self._run_pipeline(FULL_TENSOR)

        eager_out = model(*example_inputs)
        # User output is at index 0 (one user output for our MLP).
        user_out = flat[0]
        torch.testing.assert_close(user_out, eager_out, atol=1e-3, rtol=1e-3)
        # Verify tap indices are non-overlapping with user-output index 0.
        for spec in specs:
            self.assertGreaterEqual(spec.output_index, 1)