# diff --git a/scripts/parameter_norms/compute_layer_norms.py (new file, index b41beaf35)
#!/usr/bin/env python3
"""Load Modalities DCP checkpoints and report per-parameter L2 norms."""

import argparse
import json
import os
import re
from pathlib import Path
from typing import cast

import torch
import torch.distributed as dist
from pydantic import BaseModel
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor

from modalities.checkpointing.fsdp.fsdp_checkpoint_loading import DCPCheckpointLoading
from modalities.checkpointing.stateful.app_state import AppState
from modalities.config.config import ProcessGroupBackendType
from modalities.config.pydantic_if_types import PydanticAppStateType, PydanticDeviceMeshIFType
from modalities.main import Main
from modalities.running_env.cuda_env import CudaEnv
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_mesh_for_parallelism_method


class ComponentsInstantiationModel(BaseModel):
    """Typed view of the components instantiated from the YAML config.

    `device_mesh` is optional — presumably omitted by single-device configs;
    downstream code treats a missing mesh as "no DP-shard group available".
    """

    app_state: PydanticAppStateType
    device_mesh: PydanticDeviceMeshIFType | None = None


def _parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the norm-computation script."""
    cli = argparse.ArgumentParser(description="Load one or more Modalities DCP checkpoints into an app state.")
    cli.add_argument("--config-file-path", type=Path, required=True, help="Path to the YAML config file.")
    cli.add_argument(
        "--experiments-root-path",
        type=Path,
        required=True,
        help="Path passed to Main for resolver/context setup.",
    )
    cli.add_argument(
        "--checkpoint-dir-paths",
        type=Path,
        nargs="+",
        required=True,
        help="Paths to multiple checkpoint directories containing *.distcp files.",
    )
    cli.add_argument(
        "--json-output-path",
        type=Path,
        default=Path("layer_norms_across_checkpoints.json"),
        help="Output path for raw per-checkpoint norms as JSON.",
    )
    return cli.parse_args()
def _resolve_checkpoint_dir_paths(args: argparse.Namespace) -> list[Path]:
    """Return the checkpoint directories supplied on the command line as a fresh list."""
    return list(args.checkpoint_dir_paths)


def _normalize_parameter_name(parameter_name: str) -> str:
    """Strip known wrapper prefixes (DDP / torch.compile / FSDP) from a parameter name.

    Each prefix is stripped at most once, in the fixed order below.
    """
    name = parameter_name
    for prefix in ("module.", "_orig_mod.", "_fsdp_wrapped_module."):
        # removeprefix is a no-op when the prefix is absent — same behavior
        # as the original startswith + slice, written idiomatically.
        name = name.removeprefix(prefix)
    return name


def _get_dp_shard_group(device_mesh: DeviceMesh | None):
    """Return the DP-shard process group of `device_mesh`, or None if unavailable.

    A None return makes the collectives below use the default (WORLD) group.
    NOTE(review): with TP/PP degrees > 1 the WORLD fallback would reduce over
    more ranks than just the DP shards — confirm configs hitting the fallback
    are DP-only.
    """
    if device_mesh is None:
        return None
    try:
        return get_mesh_for_parallelism_method(device_mesh, ParallelismDegrees.DP_SHARD).get_group()
    except Exception:
        # Best-effort fallback to the default process group if a dedicated
        # DP-shard group cannot be resolved from the mesh.
        return None


def _compute_and_print_parameter_norms(app_state: AppState, dp_shard_group) -> dict[str, float]:
    """Compute per-parameter global L2 norms and print them on rank 0.

    Every rank must call this (it participates in all_reduce collectives).

    Args:
        app_state: Loaded app state whose `model_parts` hold the parameters.
        dp_shard_group: Process group to reduce over, or None for the default group.

    Returns:
        Mapping from normalized parameter name to its global L2 norm.
    """
    parameter_sq_sums: dict[str, torch.Tensor] = {}
    multiple_parts = len(app_state.model_parts) > 1

    for model_part_idx, model_part in enumerate(app_state.model_parts):
        for name, parameter in model_part.named_parameters():
            if not parameter.requires_grad:
                continue
            # Disambiguate names only when there are multiple model parts.
            raw_name = f"model_part_{model_part_idx}.{name}" if multiple_parts else name
            parameter_name = _normalize_parameter_name(raw_name)

            # FSDP2 parameters can be DTensors. Convert to local shard first so c10d all_reduce
            # operates on plain tensors instead of DTensors.
            local_param = parameter.to_local() if isinstance(parameter, DTensor) else parameter
            parameter_sq_sums[parameter_name] = local_param.detach().float().pow(2).sum()

    # Aggregate over the DP-shard group to reconstruct global norms for sharded parameters.
    # all_reduce mutates its tensor in place, so iterating the values suffices
    # (the original re-stored each entry, a no-op).
    for sq_sum in parameter_sq_sums.values():
        dist.all_reduce(sq_sum, op=dist.ReduceOp.SUM, group=dp_shard_group)

    parameter_norms = {name: torch.sqrt(sq_sum).item() for name, sq_sum in parameter_sq_sums.items()}

    if dist.get_rank() == 0:
        print("Per-parameter L2 norms (global across DP-shards):")
        for parameter_name in sorted(parameter_norms):
            print(f"{parameter_name}: {parameter_norms[parameter_name]:.6f}")

    return parameter_norms


def _extract_checkpoint_label(checkpoint_dir_path: Path) -> str:
    """Derive a short label from a checkpoint directory name.

    Directories matching `seen_steps_<N>` become `steps_<N>`; anything else
    falls back to the directory name itself.
    """
    match = re.search(r"seen_steps_(\d+)", checkpoint_dir_path.name)
    return f"steps_{match.group(1)}" if match else checkpoint_dir_path.name


def _save_json_results(results: list[dict], output_path: Path) -> None:
    """Write the collected per-checkpoint results as pretty-printed JSON."""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)


def main() -> None:
    """Load each checkpoint, compute norms collectively, and dump rank-0 JSON."""
    args = _parse_args()
    checkpoint_dir_paths = _resolve_checkpoint_dir_paths(args)

    with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
        rank = dist.get_rank()
        collected_results: list[dict] = []

        for checkpoint_dir_path in checkpoint_dir_paths:
            # Rebuild components per checkpoint because AppState only supports one load call.
            main_obj = Main(
                config_path=args.config_file_path,
                experiments_root_path=args.experiments_root_path,
            )
            components = cast(
                ComponentsInstantiationModel,
                main_obj.build_components(components_model_type=ComponentsInstantiationModel),
            )

            # Plain attribute access instead of getattr with string literals (B009);
            # `device_mesh` is a declared field defaulting to None, so this is safe.
            app_state = cast(AppState, components.app_state)
            device_mesh = cast(DeviceMesh | None, components.device_mesh)

            loader = DCPCheckpointLoading(global_rank=rank)
            loader.load_checkpoint_(app_state=app_state, checkpoint_dir_path=checkpoint_dir_path)

            dp_shard_group = _get_dp_shard_group(device_mesh)
            if rank == 0:
                print(f"\n=== {checkpoint_dir_path} ===")
            # Every rank participates in the collectives inside this call.
            parameter_norms = _compute_and_print_parameter_norms(app_state, dp_shard_group)

            if rank == 0:
                collected_results.append(
                    {
                        "checkpoint_path": str(checkpoint_dir_path),
                        "checkpoint_label": _extract_checkpoint_label(checkpoint_dir_path),
                        "parameter_norms": parameter_norms,
                    }
                )
                print(
                    f"Loaded checkpoint from {checkpoint_dir_path} on world size {dist.get_world_size()} "
                    f"(pid={os.getpid()})."
                )

        if rank == 0:
            _save_json_results(collected_results, args.json_output_path)
            print(f"Saved raw parameter norms JSON to {args.json_output_path}")


if __name__ == "__main__":
    main()


# diff --git a/scripts/parameter_norms/plot_layer_norms.py (new file, index 4a80cba01)
# ---------------------------------------------------------------------------
# Second file of the patch begins here: the plotting companion script.
# ---------------------------------------------------------------------------
#!/usr/bin/env python3
"""Plot per-parameter L2 norms across checkpoints from the JSON produced above."""

import argparse
import json
import re
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages


def _parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the plotting script."""
    parser = argparse.ArgumentParser(description="Plot parameter norms across checkpoints from a JSON log file.")
    parser.add_argument(
        "--layer-norms-json-path",
        type=Path,
        required=True,
        help="Path to JSON produced by scripts/compute_layer_norms.py.",
    )
    parser.add_argument(
        "--plot-output-path",
        type=Path,
        default=Path("parameter_norms_grouped_by_layer.pdf"),
        help="Output PDF path containing one plot page per layer.",
    )
    parser.add_argument(
        "--layer-filter-regex",
        type=str,
        default=r".*",
        help="Regex to select layer keys in the visualization.",
    )
    return parser.parse_args()


def _load_results(path: Path) -> list[dict]:
    """Load and validate the per-checkpoint results list from JSON.

    Raises:
        ValueError: If the file does not contain a non-empty JSON list.
    """
    with open(path, "r", encoding="utf-8") as f:
        results = json.load(f)
    if not isinstance(results, list) or not results:
        raise ValueError("Expected a non-empty JSON list of checkpoint results.")
    return results


def _extract_layer_key(parameter_name: str) -> str:
    """Group a dotted parameter name under its layer key.

    `transformer.h.0.attn.weight` -> `transformer.h.0`; names without a
    recognized block token fall back to everything but the last component.
    """
    tokens = parameter_name.split(".")
    for i in range(len(tokens) - 1):
        if tokens[i] in {"h", "layers", "blocks"} and tokens[i + 1].isdigit():
            # Include the token preceding the block marker when one exists.
            start = i - 1 if i > 0 else i
            return ".".join(tokens[start : i + 2])
    return ".".join(tokens[:-1]) if len(tokens) > 1 else parameter_name


def _layer_sort_key(layer_key: str) -> tuple:
    """Sort key placing numeric transformer blocks first, in numeric order."""
    # Prefer numeric ordering for transformer block keys like h.0, layers.12, blocks.3.
    match = re.search(r"(?:^|\.)(h|layers|blocks)\.(\d+)(?:\.|$)", layer_key)
    if match:
        return (0, match.group(1), int(match.group(2)), layer_key)
    return (1, layer_key)


def _write_summary_page(pdf, checkpoint_labels, grouped_parameters, ordered_layer_keys, num_parameters) -> None:
    """First page: quick summary of layers and parameter counts."""
    summary_lines = [
        f"checkpoints: {len(checkpoint_labels)}",
        f"layers: {len(grouped_parameters)}",
        f"parameters plotted: {num_parameters}",
        "",
        "Layer -> #parameters",
    ]
    summary_lines.extend(f"{layer_key}: {len(grouped_parameters[layer_key])}" for layer_key in ordered_layer_keys)

    fig, ax = plt.subplots(figsize=(10, 12))
    ax.axis("off")
    ax.text(0.01, 0.99, "\n".join(summary_lines), va="top", ha="left", fontsize=10)
    fig.tight_layout()
    pdf.savefig(fig)
    plt.close(fig)


def _write_layer_page(pdf, results, metric_key, checkpoint_labels, layer_key, parameter_names) -> None:
    """One page: all parameter-norm curves for a single layer across checkpoints."""
    x = list(range(len(checkpoint_labels)))
    fig, ax = plt.subplots(figsize=(12, 6))
    for parameter_name in sorted(parameter_names):
        # Parameters missing from a checkpoint plot as gaps (NaN) instead of erroring.
        y = [checkpoint_result[metric_key].get(parameter_name, float("nan")) for checkpoint_result in results]
        # Drop the redundant layer prefix in the legend; no-op when absent.
        short_name = parameter_name.removeprefix(layer_key + ".")
        ax.plot(x, y, marker="o", linewidth=1.5, label=short_name)

    ax.set_title(f"{layer_key} parameter norms across checkpoints")
    ax.set_xlabel("Checkpoint")
    ax.set_ylabel("L2 norm")
    ax.set_xticks(x)
    ax.set_xticklabels(checkpoint_labels, rotation=45, ha="right")
    ax.grid(True, alpha=0.25)
    ax.legend(loc="best", fontsize=8)
    fig.tight_layout()
    pdf.savefig(fig)
    plt.close(fig)


def _plot_checkpoint_comparison(
    results: list[dict],
    plot_output_path: Path,
    layer_filter_regex: str,
) -> None:
    """Write a PDF with a summary page plus one page of norm curves per layer.

    Args:
        results: Per-checkpoint dicts with "checkpoint_label" and a norms mapping.
        plot_output_path: Destination PDF path (parent dirs are created).
        layer_filter_regex: Only parameter names matching this regex are plotted.

    Raises:
        ValueError: If no parameter name matches `layer_filter_regex`.
    """
    # "layer_norms" is accepted as a fallback key — presumably older JSON
    # output used that name; confirm before removing.
    metric_key = "parameter_norms" if "parameter_norms" in results[0] else "layer_norms"
    layer_pattern = re.compile(layer_filter_regex)
    filtered_parameters = sorted(
        {
            parameter_name
            for checkpoint_result in results
            for parameter_name in checkpoint_result[metric_key]
            if layer_pattern.search(parameter_name)
        }
    )
    if not filtered_parameters:
        raise ValueError(f"No layer names matched --layer-filter-regex={layer_filter_regex!r}.")

    checkpoint_labels = [checkpoint_result["checkpoint_label"] for checkpoint_result in results]

    grouped_parameters: dict[str, list[str]] = {}
    for parameter_name in filtered_parameters:
        grouped_parameters.setdefault(_extract_layer_key(parameter_name), []).append(parameter_name)
    ordered_layer_keys = sorted(grouped_parameters, key=_layer_sort_key)

    plot_output_path.parent.mkdir(parents=True, exist_ok=True)
    with PdfPages(plot_output_path) as pdf:
        _write_summary_page(pdf, checkpoint_labels, grouped_parameters, ordered_layer_keys, len(filtered_parameters))
        # One page per layer with all parameter curves for that layer.
        for layer_key in ordered_layer_keys:
            _write_layer_page(pdf, results, metric_key, checkpoint_labels, layer_key, grouped_parameters[layer_key])


def main() -> None:
    """Entry point for the plotting CLI."""
    args = _parse_args()
    results = _load_results(args.layer_norms_json_path)
    _plot_checkpoint_comparison(
        results=results,
        plot_output_path=args.plot_output_path,
        layer_filter_regex=args.layer_filter_regex,
    )
    print(f"Saved grouped parameter-norm plots to {args.plot_output_path}")


if __name__ == "__main__":
    main()