From 99192d27d07c9bb21fd9e99ba7b301b2a5cf195f Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 23 Mar 2026 12:12:05 +0000 Subject: [PATCH 1/6] feat: Compute layer norms for a model checkpoint --- scripts/compute_layer_norms.py | 158 +++++++++++++++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 scripts/compute_layer_norms.py diff --git a/scripts/compute_layer_norms.py b/scripts/compute_layer_norms.py new file mode 100644 index 000000000..b193d713f --- /dev/null +++ b/scripts/compute_layer_norms.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 + +import argparse +import json +import os +from pathlib import Path +from typing import cast + +import torch +import torch.distributed as dist +from pydantic import BaseModel +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor + +from modalities.checkpointing.fsdp.fsdp_checkpoint_loading import DCPCheckpointLoading +from modalities.checkpointing.stateful.app_state import AppState +from modalities.config.config import ProcessGroupBackendType +from modalities.config.pydantic_if_types import PydanticAppStateType, PydanticDeviceMeshIFType +from modalities.main import Main +from modalities.running_env.cuda_env import CudaEnv +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_mesh_for_parallelism_method + + +class ComponentsInstantiationModel(BaseModel): + app_state: PydanticAppStateType + device_mesh: PydanticDeviceMeshIFType | None = None + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Load a Modalities DCP checkpoint into an app state.") + parser.add_argument("--config-file-path", type=Path, required=True, help="Path to the YAML config file.") + parser.add_argument( + "--experiments-root-path", + type=Path, + required=True, + help="Path passed to Main for resolver/context setup.", + ) + parser.add_argument( + "--checkpoint-dir-path", + type=Path, + default=None, + help="Path to a checkpoint directory containing *.distcp files.", + ) + parser.add_argument( + "--last-checkpoint-info-path", + type=Path, + default=None, + help="Path to last_checkpoint_info.json. Used when checkpoint-dir-path is omitted.", + ) + return parser.parse_args() + + +def _resolve_checkpoint_dir_path(args: argparse.Namespace) -> Path: + if args.checkpoint_dir_path is not None and args.last_checkpoint_info_path is not None: + raise ValueError("Pass either --checkpoint-dir-path or --last-checkpoint-info-path, not both.") + + if args.checkpoint_dir_path is not None: + return args.checkpoint_dir_path + + if args.last_checkpoint_info_path is None: + raise ValueError("Pass one of --checkpoint-dir-path or --last-checkpoint-info-path.") + + with open(args.last_checkpoint_info_path, "r", encoding="utf-8") as f: + checkpoint_info = json.load(f) + + return Path(checkpoint_info["checkpoint_folder_path"]) + + +def _get_layer_key(parameter_name: str) -> str: + # Strip common wrapping prefixes that appear for wrapped modules. + name = parameter_name + for prefix in ("module.", "_orig_mod.", "_fsdp_wrapped_module."): + if name.startswith(prefix): + name = name[len(prefix) :] + + tokens = name.split(".") + for i in range(len(tokens) - 1): + if tokens[i] in {"h", "layers", "blocks"} and tokens[i + 1].isdigit(): + if i > 0: + return ".".join(tokens[i - 1 : i + 2]) + return ".".join(tokens[i : i + 2]) + + # Fallback: group by parent module path if no canonical layer index token exists. + return ".".join(tokens[:-1]) if len(tokens) > 1 else name + + +def _get_dp_shard_group(device_mesh: DeviceMesh | None): + if device_mesh is None: + return None + try: + return get_mesh_for_parallelism_method(device_mesh, ParallelismDegrees.DP_SHARD).get_group() + except Exception: + # Fallback to the default process group if a dedicated DP-shard group is unavailable. + return None + + +def _compute_and_print_layer_norms(app_state: AppState, dp_shard_group) -> None: + layer_sq_sums: dict[str, torch.Tensor] = {} + + for model_part_idx, model_part in enumerate(app_state.model_parts): + for name, parameter in model_part.named_parameters(): + if not parameter.requires_grad: + continue + full_name = f"model_part_{model_part_idx}.{name}" if len(app_state.model_parts) > 1 else name + layer_key = _get_layer_key(full_name) + + # FSDP2 parameters can be DTensors. Convert to local shard first so c10d all_reduce + # operates on plain tensors instead of DTensors. + local_param = parameter.to_local() if isinstance(parameter, DTensor) else parameter + local_sq_sum = local_param.detach().float().pow(2).sum() + layer_sq_sums[layer_key] = layer_sq_sums.get(layer_key, torch.zeros_like(local_sq_sum)) + local_sq_sum + + # Aggregate over the DP-shard group to reconstruct global norms for sharded parameters. + for layer_key, sq_sum in layer_sq_sums.items(): + dist.all_reduce(sq_sum, op=dist.ReduceOp.SUM, group=dp_shard_group) + layer_sq_sums[layer_key] = sq_sum + + if dist.get_rank() == 0: + print("Per-layer parameter L2 norms (global across DP-shards):") + for layer_key in sorted(layer_sq_sums): + norm = torch.sqrt(layer_sq_sums[layer_key]).item() + print(f"{layer_key}: {norm:.6f}") + + +def main() -> None: + args = _parse_args() + checkpoint_dir_path = _resolve_checkpoint_dir_path(args) + + with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): + rank = dist.get_rank() + + main_obj = Main( + config_path=args.config_file_path, + experiments_root_path=args.experiments_root_path, + ) + components = cast( + ComponentsInstantiationModel, + main_obj.build_components(components_model_type=ComponentsInstantiationModel), + ) + + app_state = cast(AppState, getattr(components, "app_state")) + device_mesh = cast(DeviceMesh | None, getattr(components, "device_mesh", None)) + + loader = DCPCheckpointLoading(global_rank=rank) + loader.load_checkpoint_(app_state=app_state, checkpoint_dir_path=checkpoint_dir_path) + + dp_shard_group = _get_dp_shard_group(device_mesh) + _compute_and_print_layer_norms(app_state, dp_shard_group) + + if rank == 0: + print( + f"Loaded checkpoint from {checkpoint_dir_path} on world size {dist.get_world_size()} " + f"(pid={os.getpid()})." + ) + + +if __name__ == "__main__": + main() From 89a9400eb3255dad3c33c276ea350f8b7b55cae6 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 23 Mar 2026 12:34:40 +0000 Subject: [PATCH 2/6] feat: Support multiple checkpoints --- scripts/compute_layer_norms.py | 76 ++++++++++++++-------------------- 1 file changed, 31 insertions(+), 45 deletions(-) diff --git a/scripts/compute_layer_norms.py b/scripts/compute_layer_norms.py index b193d713f..862dc5a0c 100644 --- a/scripts/compute_layer_norms.py +++ b/scripts/compute_layer_norms.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import argparse -import json import os from pathlib import Path from typing import cast @@ -27,7 +26,7 @@ class ComponentsInstantiationModel(BaseModel): def _parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Load a Modalities DCP checkpoint into an app state.") + parser = argparse.ArgumentParser(description="Load one or more Modalities DCP checkpoints into an app state.") parser.add_argument("--config-file-path", type=Path, required=True, help="Path to the YAML config file.") parser.add_argument( "--experiments-root-path", @@ -36,34 +35,17 @@ def _parse_args() -> argparse.Namespace: help="Path passed to Main for resolver/context setup.", ) parser.add_argument( - "--checkpoint-dir-path", + "--checkpoint-dir-paths", type=Path, - default=None, - help="Path to a checkpoint directory containing *.distcp files.", - ) - parser.add_argument( - "--last-checkpoint-info-path", - type=Path, - default=None, - help="Path to last_checkpoint_info.json. Used when checkpoint-dir-path is omitted.", + nargs="+", + required=True, + help="Paths to multiple checkpoint directories containing *.distcp files.", ) return parser.parse_args() -def _resolve_checkpoint_dir_path(args: argparse.Namespace) -> Path: - if args.checkpoint_dir_path is not None and args.last_checkpoint_info_path is not None: - raise ValueError("Pass either --checkpoint-dir-path or --last-checkpoint-info-path, not both.") - - if args.checkpoint_dir_path is not None: - return args.checkpoint_dir_path - - if args.last_checkpoint_info_path is None: - raise ValueError("Pass one of --checkpoint-dir-path or --last-checkpoint-info-path.") - - with open(args.last_checkpoint_info_path, "r", encoding="utf-8") as f: - checkpoint_info = json.load(f) - - return Path(checkpoint_info["checkpoint_folder_path"]) +def _resolve_checkpoint_dir_paths(args: argparse.Namespace) -> list[Path]: + return list(args.checkpoint_dir_paths) def _get_layer_key(parameter_name: str) -> str: @@ -124,34 +106,38 @@ def _compute_and_print_layer_norms(app_state: AppState, dp_shard_group) -> None: def main() -> None: args = _parse_args() - checkpoint_dir_path = _resolve_checkpoint_dir_path(args) + checkpoint_dir_paths = _resolve_checkpoint_dir_paths(args) with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): rank = dist.get_rank() - main_obj = Main( - config_path=args.config_file_path, - experiments_root_path=args.experiments_root_path, - ) - components = cast( - ComponentsInstantiationModel, - main_obj.build_components(components_model_type=ComponentsInstantiationModel), - ) + for checkpoint_dir_path in checkpoint_dir_paths: + # Rebuild components per checkpoint because AppState only supports one load call. + main_obj = Main( + config_path=args.config_file_path, + experiments_root_path=args.experiments_root_path, + ) + components = cast( + ComponentsInstantiationModel, + main_obj.build_components(components_model_type=ComponentsInstantiationModel), + ) - app_state = cast(AppState, getattr(components, "app_state")) - device_mesh = cast(DeviceMesh | None, getattr(components, "device_mesh", None)) + app_state = cast(AppState, getattr(components, "app_state")) + device_mesh = cast(DeviceMesh | None, getattr(components, "device_mesh", None)) - loader = DCPCheckpointLoading(global_rank=rank) - loader.load_checkpoint_(app_state=app_state, checkpoint_dir_path=checkpoint_dir_path) + loader = DCPCheckpointLoading(global_rank=rank) + loader.load_checkpoint_(app_state=app_state, checkpoint_dir_path=checkpoint_dir_path) - dp_shard_group = _get_dp_shard_group(device_mesh) - _compute_and_print_layer_norms(app_state, dp_shard_group) + dp_shard_group = _get_dp_shard_group(device_mesh) + if rank == 0: + print(f"\n=== {checkpoint_dir_path} ===") + _compute_and_print_layer_norms(app_state, dp_shard_group) - if rank == 0: - print( - f"Loaded checkpoint from {checkpoint_dir_path} on world size {dist.get_world_size()} " - f"(pid={os.getpid()})." - ) + if rank == 0: + print( + f"Loaded checkpoint from {checkpoint_dir_path} on world size {dist.get_world_size()} " + f"(pid={os.getpid()})." + ) if __name__ == "__main__": From 76cc747403fc60bcb3a4ba740fe49b7649c0e144 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 23 Mar 2026 13:35:07 +0000 Subject: [PATCH 3/6] refactor: Separate logic for layer norm computation and plotting --- scripts/compute_layer_norms.py | 46 +++++++++++++-- scripts/plot_layer_norms.py | 103 +++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 5 deletions(-) create mode 100644 scripts/plot_layer_norms.py diff --git a/scripts/compute_layer_norms.py b/scripts/compute_layer_norms.py index 862dc5a0c..decdf9bc2 100644 --- a/scripts/compute_layer_norms.py +++ b/scripts/compute_layer_norms.py @@ -1,7 +1,9 @@ #!/usr/bin/env python3 import argparse +import json import os +import re from pathlib import Path from typing import cast @@ -41,6 +43,12 @@ def _parse_args() -> argparse.Namespace: required=True, help="Paths to multiple checkpoint directories containing *.distcp files.", ) + parser.add_argument( + "--json-output-path", + type=Path, + default=Path("layer_norms_across_checkpoints.json"), + help="Output path for raw per-checkpoint norms as JSON.", + ) return parser.parse_args() @@ -76,7 +84,7 @@ def _get_dp_shard_group(device_mesh: DeviceMesh | None): return None -def _compute_and_print_layer_norms(app_state: AppState, dp_shard_group) -> None: +def _compute_and_print_layer_norms(app_state: AppState, dp_shard_group) -> dict[str, float]: layer_sq_sums: dict[str, torch.Tensor] = {} for model_part_idx, model_part in enumerate(app_state.model_parts): @@ -97,11 +105,27 @@ def _compute_and_print_layer_norms(app_state: AppState, dp_shard_group) -> None: dist.all_reduce(sq_sum, op=dist.ReduceOp.SUM, group=dp_shard_group) layer_sq_sums[layer_key] = sq_sum + layer_norms = {layer_key: torch.sqrt(sq_sum).item() for layer_key, sq_sum in layer_sq_sums.items()} + if dist.get_rank() == 0: print("Per-layer parameter L2 norms (global across DP-shards):") - for layer_key in sorted(layer_sq_sums): - norm = torch.sqrt(layer_sq_sums[layer_key]).item() - print(f"{layer_key}: {norm:.6f}") + for layer_key in sorted(layer_norms): + print(f"{layer_key}: {layer_norms[layer_key]:.6f}") + + return layer_norms + + +def _extract_checkpoint_label(checkpoint_dir_path: Path) -> str: + match = re.search(r"seen_steps_(\d+)", checkpoint_dir_path.name) + if match: + return f"steps_{match.group(1)}" + return checkpoint_dir_path.name + + +def _save_json_results(results: list[dict], output_path: Path) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2) def main() -> None: @@ -110,6 +134,7 @@ def main() -> None: with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): rank = dist.get_rank() + collected_results: list[dict] = [] for checkpoint_dir_path in checkpoint_dir_paths: # Rebuild components per checkpoint because AppState only supports one load call. @@ -131,14 +156,25 @@ def main() -> None: dp_shard_group = _get_dp_shard_group(device_mesh) if rank == 0: print(f"\n=== {checkpoint_dir_path} ===") - _compute_and_print_layer_norms(app_state, dp_shard_group) + layer_norms = _compute_and_print_layer_norms(app_state, dp_shard_group) if rank == 0: + collected_results.append( + { + "checkpoint_path": str(checkpoint_dir_path), + "checkpoint_label": _extract_checkpoint_label(checkpoint_dir_path), + "layer_norms": layer_norms, + } + ) print( f"Loaded checkpoint from {checkpoint_dir_path} on world size {dist.get_world_size()} " f"(pid={os.getpid()})." ) + if rank == 0: + _save_json_results(collected_results, args.json_output_path) + print(f"Saved raw layer norms JSON to {args.json_output_path}") + if __name__ == "__main__": main() diff --git a/scripts/plot_layer_norms.py b/scripts/plot_layer_norms.py new file mode 100644 index 000000000..6398660a3 --- /dev/null +++ b/scripts/plot_layer_norms.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 + +import argparse +import json +import re +from pathlib import Path + +import matplotlib.pyplot as plt +import torch + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Plot layer norms across checkpoints from a JSON log file.") + parser.add_argument( + "--layer-norms-json-path", + type=Path, + required=True, + help="Path to JSON produced by scripts/compute_layer_norms.py.", + ) + parser.add_argument( + "--plot-output-path", + type=Path, + default=Path("layer_norms_across_checkpoints.png"), + help="Output image path for cross-checkpoint layer-norm visualization.", + ) + parser.add_argument( + "--layer-filter-regex", + type=str, + default=r".*", + help="Regex to select layer keys in the visualization.", + ) + return parser.parse_args() + + +def _load_results(path: Path) -> list[dict]: + with open(path, "r", encoding="utf-8") as f: + results = json.load(f) + if not isinstance(results, list) or not results: + raise ValueError("Expected a non-empty JSON list of checkpoint results.") + return results + + +def _plot_checkpoint_comparison( + results: list[dict], + plot_output_path: Path, + layer_filter_regex: str, +) -> None: + layer_pattern = re.compile(layer_filter_regex) + filtered_layers = sorted( + { + layer_name + for checkpoint_result in results + for layer_name in checkpoint_result["layer_norms"].keys() + if layer_pattern.search(layer_name) + } + ) + if not filtered_layers: + raise ValueError(f"No layer names matched --layer-filter-regex={layer_filter_regex!r}.") + + checkpoint_labels = [checkpoint_result["checkpoint_label"] for checkpoint_result in results] + matrix = torch.tensor( + [ + [checkpoint_result["layer_norms"].get(layer_name, float("nan")) for layer_name in filtered_layers] + for checkpoint_result in results + ], + dtype=torch.float32, + ) + + fig_width = max(12, 0.55 * len(checkpoint_labels)) + fig_height = max(8, 0.25 * len(filtered_layers)) + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + image = ax.imshow(matrix.T.numpy(), aspect="auto", interpolation="nearest") + + ax.set_title("Layer Norms Across Checkpoints") + ax.set_xlabel("Checkpoint") + ax.set_ylabel("Layer") + ax.set_xticks(range(len(checkpoint_labels))) + ax.set_xticklabels(checkpoint_labels, rotation=45, ha="right") + ax.set_yticks(range(len(filtered_layers))) + ax.set_yticklabels(filtered_layers) + + cbar = fig.colorbar(image, ax=ax) + cbar.set_label("L2 norm") + + fig.tight_layout() + plot_output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(plot_output_path, dpi=180) + plt.close(fig) + + +def main() -> None: + args = _parse_args() + results = _load_results(args.layer_norms_json_path) + _plot_checkpoint_comparison( + results=results, + plot_output_path=args.plot_output_path, + layer_filter_regex=args.layer_filter_regex, + ) + print(f"Saved cross-checkpoint layer-norm plot to {args.plot_output_path}") + + +if __name__ == "__main__": + main() From e53cb467d73af28e75fec8065901e6ea7464fafe Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 23 Mar 2026 15:06:44 +0000 Subject: [PATCH 4/6] feat: Compute norms for parameters instead of layers --- scripts/compute_layer_norms.py | 44 +++++++++++++--------------------- scripts/plot_layer_norms.py | 13 +++++----- 2 files changed, 24 insertions(+), 33 deletions(-) diff --git a/scripts/compute_layer_norms.py b/scripts/compute_layer_norms.py index decdf9bc2..b41beaf35 100644 --- a/scripts/compute_layer_norms.py +++ b/scripts/compute_layer_norms.py @@ -56,22 +56,12 @@ def _resolve_checkpoint_dir_paths(args: argparse.Namespace) -> list[Path]: return list(args.checkpoint_dir_paths) -def _get_layer_key(parameter_name: str) -> str: - # Strip common wrapping prefixes that appear for wrapped modules. +def _normalize_parameter_name(parameter_name: str) -> str: name = parameter_name for prefix in ("module.", "_orig_mod.", "_fsdp_wrapped_module."): if name.startswith(prefix): name = name[len(prefix) :] - - tokens = name.split(".") - for i in range(len(tokens) - 1): - if tokens[i] in {"h", "layers", "blocks"} and tokens[i + 1].isdigit(): - if i > 0: - return ".".join(tokens[i - 1 : i + 2]) - return ".".join(tokens[i : i + 2]) - - # Fallback: group by parent module path if no canonical layer index token exists. - return ".".join(tokens[:-1]) if len(tokens) > 1 else name + return name def _get_dp_shard_group(device_mesh: DeviceMesh | None): @@ -84,35 +74,35 @@ def _get_dp_shard_group(device_mesh: DeviceMesh | None): return None -def _compute_and_print_layer_norms(app_state: AppState, dp_shard_group) -> dict[str, float]: - layer_sq_sums: dict[str, torch.Tensor] = {} +def _compute_and_print_parameter_norms(app_state: AppState, dp_shard_group) -> dict[str, float]: + parameter_sq_sums: dict[str, torch.Tensor] = {} for model_part_idx, model_part in enumerate(app_state.model_parts): for name, parameter in model_part.named_parameters(): if not parameter.requires_grad: continue - full_name = f"model_part_{model_part_idx}.{name}" if len(app_state.model_parts) > 1 else name - layer_key = _get_layer_key(full_name) + raw_name = f"model_part_{model_part_idx}.{name}" if len(app_state.model_parts) > 1 else name + parameter_name = _normalize_parameter_name(raw_name) # FSDP2 parameters can be DTensors. Convert to local shard first so c10d all_reduce # operates on plain tensors instead of DTensors. local_param = parameter.to_local() if isinstance(parameter, DTensor) else parameter local_sq_sum = local_param.detach().float().pow(2).sum() - layer_sq_sums[layer_key] = layer_sq_sums.get(layer_key, torch.zeros_like(local_sq_sum)) + local_sq_sum + parameter_sq_sums[parameter_name] = local_sq_sum # Aggregate over the DP-shard group to reconstruct global norms for sharded parameters. - for layer_key, sq_sum in layer_sq_sums.items(): + for parameter_name, sq_sum in parameter_sq_sums.items(): dist.all_reduce(sq_sum, op=dist.ReduceOp.SUM, group=dp_shard_group) - layer_sq_sums[layer_key] = sq_sum + parameter_sq_sums[parameter_name] = sq_sum - layer_norms = {layer_key: torch.sqrt(sq_sum).item() for layer_key, sq_sum in layer_sq_sums.items()} + parameter_norms = {name: torch.sqrt(sq_sum).item() for name, sq_sum in parameter_sq_sums.items()} if dist.get_rank() == 0: - print("Per-layer parameter L2 norms (global across DP-shards):") - for layer_key in sorted(layer_norms): - print(f"{layer_key}: {layer_norms[layer_key]:.6f}") + print("Per-parameter L2 norms (global across DP-shards):") + for parameter_name in sorted(parameter_norms): + print(f"{parameter_name}: {parameter_norms[parameter_name]:.6f}") - return layer_norms + return parameter_norms def _extract_checkpoint_label(checkpoint_dir_path: Path) -> str: @@ -156,14 +146,14 @@ def main() -> None: dp_shard_group = _get_dp_shard_group(device_mesh) if rank == 0: print(f"\n=== {checkpoint_dir_path} ===") - layer_norms = _compute_and_print_layer_norms(app_state, dp_shard_group) + parameter_norms = _compute_and_print_parameter_norms(app_state, dp_shard_group) if rank == 0: collected_results.append( { "checkpoint_path": str(checkpoint_dir_path), "checkpoint_label": _extract_checkpoint_label(checkpoint_dir_path), - "layer_norms": layer_norms, + "parameter_norms": parameter_norms, } ) print( @@ -173,7 +163,7 @@ def main() -> None: if rank == 0: _save_json_results(collected_results, args.json_output_path) - print(f"Saved raw layer norms JSON to {args.json_output_path}") + print(f"Saved raw parameter norms JSON to {args.json_output_path}") if __name__ == "__main__": diff --git a/scripts/plot_layer_norms.py b/scripts/plot_layer_norms.py index 6398660a3..754121a30 100644 --- a/scripts/plot_layer_norms.py +++ b/scripts/plot_layer_norms.py @@ -10,7 +10,7 @@ def _parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Plot layer norms across checkpoints from a JSON log file.") + parser = argparse.ArgumentParser(description="Plot parameter norms across checkpoints from a JSON log file.") parser.add_argument( "--layer-norms-json-path", type=Path, @@ -45,12 +45,13 @@ def _plot_checkpoint_comparison( plot_output_path: Path, layer_filter_regex: str, ) -> None: + metric_key = "parameter_norms" if "parameter_norms" in results[0] else "layer_norms" layer_pattern = re.compile(layer_filter_regex) filtered_layers = sorted( { layer_name for checkpoint_result in results - for layer_name in checkpoint_result["layer_norms"].keys() + for layer_name in checkpoint_result[metric_key].keys() if layer_pattern.search(layer_name) } ) @@ -60,7 +61,7 @@ def _plot_checkpoint_comparison( checkpoint_labels = [checkpoint_result["checkpoint_label"] for checkpoint_result in results] matrix = torch.tensor( [ - [checkpoint_result["layer_norms"].get(layer_name, float("nan")) for layer_name in filtered_layers] + [checkpoint_result[metric_key].get(layer_name, float("nan")) for layer_name in filtered_layers] for checkpoint_result in results ], dtype=torch.float32, @@ -71,9 +72,9 @@ def _plot_checkpoint_comparison( fig, ax = plt.subplots(figsize=(fig_width, fig_height)) image = ax.imshow(matrix.T.numpy(), aspect="auto", interpolation="nearest") - ax.set_title("Layer Norms Across Checkpoints") + ax.set_title("Parameter Norms Across Checkpoints") ax.set_xlabel("Checkpoint") - ax.set_ylabel("Layer") + ax.set_ylabel("Parameter") ax.set_xticks(range(len(checkpoint_labels))) ax.set_xticklabels(checkpoint_labels, rotation=45, ha="right") ax.set_yticks(range(len(filtered_layers))) @@ -96,7 +97,7 @@ def main() -> None: plot_output_path=args.plot_output_path, layer_filter_regex=args.layer_filter_regex, ) - print(f"Saved cross-checkpoint layer-norm plot to {args.plot_output_path}") + print(f"Saved cross-checkpoint parameter-norm plot to {args.plot_output_path}") if __name__ == "__main__": From cb5a8f276570c5350974eb1821b535b6a2e662cc Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 23 Mar 2026 15:57:04 +0000 Subject: [PATCH 5/6] feat: Group plots by layer --- scripts/plot_layer_norms.py | 109 +++++++++++++++++++++++++----------- 1 file changed, 75 insertions(+), 34 deletions(-) diff --git a/scripts/plot_layer_norms.py b/scripts/plot_layer_norms.py index 754121a30..4a80cba01 100644 --- a/scripts/plot_layer_norms.py +++ b/scripts/plot_layer_norms.py @@ -6,7 +6,7 @@ from pathlib import Path import matplotlib.pyplot as plt -import torch +from matplotlib.backends.backend_pdf import PdfPages def _parse_args() -> argparse.Namespace: @@ -20,8 +20,8 @@ def _parse_args() -> argparse.Namespace: parser.add_argument( "--plot-output-path", type=Path, - default=Path("layer_norms_across_checkpoints.png"), - help="Output image path for cross-checkpoint layer-norm visualization.", + default=Path("parameter_norms_grouped_by_layer.pdf"), + help="Output PDF path containing one plot page per layer.", ) parser.add_argument( "--layer-filter-regex", @@ -40,6 +40,24 @@ def _load_results(path: Path) -> list[dict]: return results +def _extract_layer_key(parameter_name: str) -> str: + tokens = parameter_name.split(".") + for i in range(len(tokens) - 1): + if tokens[i] in {"h", "layers", "blocks"} and tokens[i + 1].isdigit(): + if i > 0: + return ".".join(tokens[i - 1 : i + 2]) + return ".".join(tokens[i : i + 2]) + return ".".join(tokens[:-1]) if len(tokens) > 1 else parameter_name + + +def _layer_sort_key(layer_key: str) -> tuple: + # Prefer numeric ordering for transformer block keys like h.0, layers.12, blocks.3. + match = re.search(r"(?:^|\.)(h|layers|blocks)\.(\d+)(?:\.|$)", layer_key) + if match: + return (0, match.group(1), int(match.group(2)), layer_key) + return (1, layer_key) + + def _plot_checkpoint_comparison( results: list[dict], plot_output_path: Path, @@ -47,46 +65,69 @@ def _plot_checkpoint_comparison( ) -> None: metric_key = "parameter_norms" if "parameter_norms" in results[0] else "layer_norms" layer_pattern = re.compile(layer_filter_regex) - filtered_layers = sorted( + filtered_parameters = sorted( { - layer_name + parameter_name for checkpoint_result in results - for layer_name in checkpoint_result[metric_key].keys() - if layer_pattern.search(layer_name) + for parameter_name in checkpoint_result[metric_key].keys() + if layer_pattern.search(parameter_name) } ) - if not filtered_layers: + if not filtered_parameters: raise ValueError(f"No layer names matched --layer-filter-regex={layer_filter_regex!r}.") checkpoint_labels = [checkpoint_result["checkpoint_label"] for checkpoint_result in results] - matrix = torch.tensor( - [ - [checkpoint_result[metric_key].get(layer_name, float("nan")) for layer_name in filtered_layers] - for checkpoint_result in results - ], - dtype=torch.float32, - ) - - fig_width = max(12, 0.55 * len(checkpoint_labels)) - fig_height = max(8, 0.25 * len(filtered_layers)) - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - image = ax.imshow(matrix.T.numpy(), aspect="auto", interpolation="nearest") - - ax.set_title("Parameter Norms Across Checkpoints") - ax.set_xlabel("Checkpoint") - ax.set_ylabel("Parameter") - ax.set_xticks(range(len(checkpoint_labels))) - ax.set_xticklabels(checkpoint_labels, rotation=45, ha="right") - ax.set_yticks(range(len(filtered_layers))) - ax.set_yticklabels(filtered_layers) - cbar = fig.colorbar(image, ax=ax) - cbar.set_label("L2 norm") + grouped_parameters: dict[str, list[str]] = {} + for parameter_name in filtered_parameters: + layer_key = _extract_layer_key(parameter_name) + grouped_parameters.setdefault(layer_key, []).append(parameter_name) + ordered_layer_keys = sorted(grouped_parameters, key=_layer_sort_key) - fig.tight_layout() plot_output_path.parent.mkdir(parents=True, exist_ok=True) - fig.savefig(plot_output_path, dpi=180) - plt.close(fig) + with PdfPages(plot_output_path) as pdf: + # First page: quick summary of layers and parameter counts. + summary_lines = [ + f"checkpoints: {len(checkpoint_labels)}", + f"layers: {len(grouped_parameters)}", + f"parameters plotted: {len(filtered_parameters)}", + "", + "Layer -> #parameters", + ] + for layer_key in ordered_layer_keys: + summary_lines.append(f"{layer_key}: {len(grouped_parameters[layer_key])}") + + fig, ax = plt.subplots(figsize=(10, 12)) + ax.axis("off") + ax.text(0.01, 0.99, "\n".join(summary_lines), va="top", ha="left", fontsize=10) + fig.tight_layout() + pdf.savefig(fig) + plt.close(fig) + + # One page per layer with all parameter curves for that layer. + x = list(range(len(checkpoint_labels))) + for layer_key in ordered_layer_keys: + parameter_names = sorted(grouped_parameters[layer_key]) + fig, ax = plt.subplots(figsize=(12, 6)) + for parameter_name in parameter_names: + y = [checkpoint_result[metric_key].get(parameter_name, float("nan")) for checkpoint_result in results] + short_name = ( + parameter_name[len(layer_key) + 1 :] + if parameter_name.startswith(layer_key + ".") + else parameter_name + ) + ax.plot(x, y, marker="o", linewidth=1.5, label=short_name) + + ax.set_title(f"{layer_key} parameter norms across checkpoints") + ax.set_xlabel("Checkpoint") + ax.set_ylabel("L2 norm") + ax.set_xticks(x) + ax.set_xticklabels(checkpoint_labels, rotation=45, ha="right") + ax.grid(True, alpha=0.25) + ax.legend(loc="best", fontsize=8) + fig.tight_layout() + pdf.savefig(fig) + plt.close(fig) def main() -> None: @@ -97,7 +138,7 @@ def main() -> None: plot_output_path=args.plot_output_path, layer_filter_regex=args.layer_filter_regex, ) - print(f"Saved cross-checkpoint parameter-norm plot to {args.plot_output_path}") + print(f"Saved grouped parameter-norm plots to {args.plot_output_path}") if __name__ == "__main__": From e9368ff01c535113a3f53687ebf69cbae17306ae Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 23 Mar 2026 15:59:29 +0000 Subject: [PATCH 6/6] chore: Move parameter norm scripts to separate folder --- scripts/{ => parameter_norms}/compute_layer_norms.py | 0 scripts/{ => parameter_norms}/plot_layer_norms.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename scripts/{ => parameter_norms}/compute_layer_norms.py (100%) rename scripts/{ => parameter_norms}/plot_layer_norms.py (100%) diff --git a/scripts/compute_layer_norms.py b/scripts/parameter_norms/compute_layer_norms.py similarity index 100% rename from scripts/compute_layer_norms.py rename to scripts/parameter_norms/compute_layer_norms.py diff --git a/scripts/plot_layer_norms.py b/scripts/parameter_norms/plot_layer_norms.py similarity index 100% rename from scripts/plot_layer_norms.py rename to scripts/parameter_norms/plot_layer_norms.py