From c0e0a6d23bc8d886c65a87d2e3b3d5af25e68914 Mon Sep 17 00:00:00 2001 From: bolunz Date: Wed, 13 May 2026 09:07:57 +0800 Subject: [PATCH 1/2] feat: init ac --- example/gpt2/checkpoint_loader.cc | 12 + example/gpt2/checkpoint_loader.h | 4 + example/gpt2/main.cc | 23 +- example/llama3/checkpoint_loader.cc | 12 + example/llama3/checkpoint_loader.h | 4 + example/llama3/main.cc | 20 +- example/mnist/main.cc | 4 +- infini_train/include/autocast.h | 18 + infini_train/include/autograd/function.h | 37 +- infini_train/include/autograd/grad_mode.h | 18 + infini_train/include/autograd/normalization.h | 2 + infini_train/include/nn/modules/module.h | 2 + .../modules/transformer/transformer_config.h | 24 ++ .../include/nn/parallel/tensor_parallel.h | 7 + infini_train/include/tensor.h | 1 + infini_train/include/utils/checkpoint.h | 41 +++ infini_train/src/autograd/activations.cc | 6 +- infini_train/src/autograd/elementwise.cc | 58 +-- infini_train/src/autograd/function.cc | 101 ++++- infini_train/src/autograd/grad_mode.cc | 1 + infini_train/src/autograd/linear.cc | 8 +- infini_train/src/autograd/loss.cc | 8 +- infini_train/src/autograd/matmul.cc | 8 +- infini_train/src/autograd/misc.cc | 10 +- infini_train/src/autograd/normalization.cc | 19 +- infini_train/src/autograd/outer.cc | 8 +- infini_train/src/autograd/reduction.cc | 16 +- infini_train/src/autograd/softmax.cc | 6 +- infini_train/src/autograd/sparse.cc | 4 +- infini_train/src/nn/modules/module.cc | 8 +- .../src/nn/modules/transformer/transformer.cc | 97 ++++- .../modules/transformer/transformer_config.cc | 70 ++++ .../src/nn/parallel/tensor_parallel.cc | 26 +- infini_train/src/tensor.cc | 11 + infini_train/src/utils/checkpoint.cc | 348 ++++++++++++++++++ scripts/test_config.json | 230 +++++++++++- 36 files changed, 1158 insertions(+), 114 deletions(-) create mode 100644 infini_train/include/utils/checkpoint.h create mode 100644 infini_train/src/utils/checkpoint.cc diff --git a/example/gpt2/checkpoint_loader.cc b/example/gpt2/checkpoint_loader.cc index 57064423..0af968c8 100644 --- a/example/gpt2/checkpoint_loader.cc +++ b/example/gpt2/checkpoint_loader.cc @@ -52,11 +52,22 @@ std::tuple DetermineAndCheckVersion(const std:: return {}; // Unreachable, but keeps compiler happy } } + +void ApplyRuntimeRecomputeConfig(nn::TransformerConfig *config, const nn::TransformerConfig &runtime_config) { + config->recompute_granularity = runtime_config.recompute_granularity; + config->recompute_method = runtime_config.recompute_method; + config->recompute_num_layers = runtime_config.recompute_num_layers; +} } // namespace namespace gpt2 { std::shared_ptr LoadFromLLMC(const std::string &filepath) { + return LoadFromLLMC(filepath, gpt2::GPT2Config()); +} + +std::shared_ptr LoadFromLLMC(const std::string &filepath, + const nn::TransformerConfig &runtime_config) { if (!std::filesystem::exists(filepath)) { LOG(FATAL) << "File not found: " << filepath; } @@ -87,6 +98,7 @@ std::shared_ptr LoadFromLLMC(const std::string &filepath) gpt2_config.n_layer = n_layer; gpt2_config.n_head = n_head; gpt2_config.n_embd = n_embd; + ApplyRuntimeRecomputeConfig(&gpt2_config, runtime_config); auto local_gpt2 = std::make_shared(gpt2_config); LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size diff --git a/example/gpt2/checkpoint_loader.h b/example/gpt2/checkpoint_loader.h index e80c356e..e486a6d8 100644 --- a/example/gpt2/checkpoint_loader.h +++ b/example/gpt2/checkpoint_loader.h @@ -3,10 +3,14 @@ #include #include +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + namespace infini_train::nn { class TransformerModel; } // namespace infini_train::nn namespace gpt2 { std::shared_ptr LoadFromLLMC(const std::string &filepath); +std::shared_ptr +LoadFromLLMC(const std::string &filepath, const infini_train::nn::TransformerConfig &runtime_config); } // namespace gpt2 diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index f69736f5..e9b48336 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -75,6 +75,12 @@ DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages."); DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); +// activation recompute +DEFINE_bool(activation_recompute, false, "Enable activation recompute to trade compute for memory."); +DEFINE_string(recompute_granularity, "full", "Activation recompute granularity: none|full|selective"); +DEFINE_string(recompute_method, "none", "Activation recompute method: none|uniform|block"); +DEFINE_uint32(recompute_num_layers, 0, "Number of transformer layers per recompute region for uniform/block methods."); + // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); // precision check @@ -186,21 +192,24 @@ void Train(const nn::parallel::Rank &rank) { nn::TransformerConfig model_config = gpt2::GPT2Config(); std::shared_ptr model = nullptr; - if (!FLAGS_llmc_filepath.empty()) { - model = gpt2::LoadFromLLMC(FLAGS_llmc_filepath); - } else if (kModelToConfigs.count(FLAGS_model)) { + if (FLAGS_llmc_filepath.empty() && kModelToConfigs.count(FLAGS_model)) { model_config = kModelToConfigs.at(FLAGS_model); + } + nn::SetActivationRecomputeConfig(&model_config, FLAGS_activation_recompute, FLAGS_recompute_granularity, + FLAGS_recompute_method, static_cast(FLAGS_recompute_num_layers)); + + if (!FLAGS_llmc_filepath.empty()) { + model = gpt2::LoadFromLLMC(FLAGS_llmc_filepath, model_config); + } else { model = std::make_shared(model_config); } + CHECK(model) << "GPT2 example expects GPT2 model."; + model->To(device); utils::PrecisionChecker::BuildNameMap(model.get()); - // Get chunk size before wrapping with LoRA (needed for PipelineParallel) - auto gpt2_model = std::dynamic_pointer_cast(model); - CHECK(gpt2_model) << "GPT2 example expects GPT2 model."; - // Apply LoRA using GetLoRAModel (in-place injection) bool lora_enabled = FLAGS_lora_rank > 0; if (lora_enabled) { diff --git a/example/llama3/checkpoint_loader.cc b/example/llama3/checkpoint_loader.cc index a31d1748..2064f8a0 100644 --- a/example/llama3/checkpoint_loader.cc +++ b/example/llama3/checkpoint_loader.cc @@ -35,11 +35,22 @@ static std::mt19937 gen{kRandomSeed}; namespace { constexpr int32_t kLLaMA3Magic = 20240803; constexpr int32_t kLLaMA3FP32Version = 3; + +void ApplyRuntimeRecomputeConfig(nn::TransformerConfig *config, const nn::TransformerConfig &runtime_config) { + config->recompute_granularity = runtime_config.recompute_granularity; + config->recompute_method = runtime_config.recompute_method; + config->recompute_num_layers = runtime_config.recompute_num_layers; +} } // namespace namespace llama3 { std::shared_ptr LoadFromLLMC(const std::string &filepath) { + return LoadFromLLMC(filepath, llama3::LLaMA3Config()); +} + +std::shared_ptr LoadFromLLMC(const std::string &filepath, + const nn::TransformerConfig &runtime_config) { if (!std::filesystem::exists(filepath)) { LOG(FATAL) << "File not found: " << filepath; } @@ -80,6 +91,7 @@ std::shared_ptr LoadFromLLMC(const std::string &filepath) llama3_config.use_scaled_rope = static_cast(use_scaled_rope); llama3_config.norm_eps = norm_eps; llama3_config.max_gen_batch_size = max_gen_bs; + ApplyRuntimeRecomputeConfig(&llama3_config, runtime_config); auto llama3 = std::make_shared(llama3_config); // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ========== diff --git a/example/llama3/checkpoint_loader.h b/example/llama3/checkpoint_loader.h index d4aea3d0..61a44873 100644 --- a/example/llama3/checkpoint_loader.h +++ b/example/llama3/checkpoint_loader.h @@ -3,10 +3,14 @@ #include #include +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + namespace infini_train::nn { class TransformerModel; } // namespace infini_train::nn namespace llama3 { std::shared_ptr LoadFromLLMC(const std::string &filepath); +std::shared_ptr +LoadFromLLMC(const std::string &filepath, const infini_train::nn::TransformerConfig &runtime_config); } // namespace llama3 diff --git a/example/llama3/main.cc b/example/llama3/main.cc index da9a1027..3fd9f85e 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -73,6 +73,12 @@ DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages."); DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); + +// activation recompute +DEFINE_bool(activation_recompute, false, "Enable activation recompute to trade compute for memory."); +DEFINE_string(recompute_granularity, "full", "Activation recompute granularity: none|full|selective"); +DEFINE_string(recompute_method, "none", "Activation recompute method: none|uniform|block"); +DEFINE_uint32(recompute_num_layers, 0, "Number of transformer layers per recompute region for uniform/block methods."); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); // precision check @@ -171,12 +177,16 @@ void Train(const nn::parallel::Rank &rank) { nn::TransformerConfig model_config = llama3::LLaMA3Config(); std::shared_ptr model = nullptr; + nn::SetActivationRecomputeConfig(&model_config, FLAGS_activation_recompute, FLAGS_recompute_granularity, + FLAGS_recompute_method, static_cast(FLAGS_recompute_num_layers)); if (!FLAGS_llmc_filepath.empty()) { - model = llama3::LoadFromLLMC(FLAGS_llmc_filepath); + model = llama3::LoadFromLLMC(FLAGS_llmc_filepath, model_config); } else { model = std::make_shared(model_config); } + CHECK(model) << "LLaMA3 example expects LLaMA3 model."; + model->To(device); utils::PrecisionChecker::BuildNameMap(model.get()); @@ -357,12 +367,20 @@ void Train(const nn::parallel::Rank &rank) { autocast_guard.Disable(); LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish loss forward"; + auto [forward_used_mb, forward_reserved_mb] = impl->GetMemPoolPeakMB(device); + LOG(INFO) << std::format( + "Rank {}: after forward (micro_step {}/{}), peak used: {:5d} MB | peak reserved: {:5d} MB", + rank.GlobalRank(), micro_step + 1, grad_accum_steps, forward_used_mb, forward_reserved_mb); auto loss_cpu = loss->To(Device()); lossf += static_cast(loss_cpu.DataPtr())[0]; LOG(INFO) << "Rank " << rank.GlobalRank() << ": start backward"; loss->Backward(); LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish backward"; + auto [backward_used_mb, backward_reserved_mb] = impl->GetMemPoolPeakMB(device); + LOG(INFO) << std::format( + "Rank {}: after backward (micro_step {}/{}), peak used: {:5d} MB | peak reserved: {:5d} MB", + rank.GlobalRank(), micro_step + 1, grad_accum_steps, backward_used_mb, backward_reserved_mb); } optimizer->Step(); diff --git a/example/mnist/main.cc b/example/mnist/main.cc index e62257d7..d3db61e9 100644 --- a/example/mnist/main.cc +++ b/example/mnist/main.cc @@ -66,7 +66,7 @@ int main(int argc, char *argv[]) { auto new_image = std::make_shared(image->To(device)); auto new_label = std::make_shared(label->To(device)); - auto outputs = network.Forward({new_image}); + auto outputs = network({new_image}); optimizer.ZeroGrad(); auto loss = loss_fn.Forward({outputs[0], new_label}); @@ -101,7 +101,7 @@ int main(int argc, char *argv[]) { auto new_label = std::make_shared(label->To(device)); auto label_cpu = label->To(cpu_device); - auto outputs = network.Forward({new_image}); + auto outputs = network({new_image}); auto output_cpu = outputs[0]->To(cpu_device); auto loss = loss_fn.Forward({outputs[0], new_label}); auto loss_cpu = loss[0]->To(cpu_device); diff --git a/infini_train/include/autocast.h b/infini_train/include/autocast.h index 499c586f..43513e2a 100644 --- a/infini_train/include/autocast.h +++ b/infini_train/include/autocast.h @@ -164,6 +164,24 @@ struct AutocastContext { // Global thread-local storage for autocast context inline thread_local AutocastContext tls_autocast_context; +// Lightweight snapshot for autocast state +struct AutocastState { + bool enabled = false; + Device::DeviceType device_type = Device::DeviceType::kCPU; + DataType autocast_dtype = DataType::kBFLOAT16; +}; + +inline AutocastState GetAutocastState() { + return AutocastState{tls_autocast_context.enabled, tls_autocast_context.device_type, + tls_autocast_context.autocast_dtype}; +} + +inline void SetAutocastState(const AutocastState &state) { + tls_autocast_context.enabled = state.enabled; + tls_autocast_context.device_type = state.device_type; + tls_autocast_context.autocast_dtype = state.autocast_dtype; +} + // RAII guard to enable/disable autocast in a scope class AutocastGuard { public: diff --git a/infini_train/include/autograd/function.h b/infini_train/include/autograd/function.h index e1e1390e..adaf1686 100644 --- a/infini_train/include/autograd/function.h +++ b/infini_train/include/autograd/function.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -21,6 +22,15 @@ class Function : public std::enable_shared_from_this { using FunctionPostHook = std::function> &, const std::vector> &)>; + // Definition of hooks for saved_tensors, in alignment with torch.autograd.graph.saved_tensors_hooks + using SavedTensorPackHook = std::function(const std::shared_ptr &)>; + using SavedTensorUnpackHook = std::function(const std::shared_ptr &)>; + + struct SavedTensorHooks { + SavedTensorPackHook pack; + SavedTensorUnpackHook unpack; + }; + static constexpr char kUndefinedType[] = "Undefined"; Function() : type_(kUndefinedType) {} @@ -45,8 +55,33 @@ class Function : public std::enable_shared_from_this { const std::string &type() const { return type_; } + void SaveForBackward(const std::vector> &tensors); + size_t SavedTensorsSize() const { return saved_tensors_.size(); } + std::shared_ptr GetSavedTensor(size_t index) const; + std::vector> GetSavedTensors() const; + + // RAII: Register pack/unpack hooks for saved_tensors, align with torch.autograd.graph.saved_tensors_hooks + class SavedTensorHooksGuard { + public: + explicit SavedTensorHooksGuard(SavedTensorHooks hooks); + ~SavedTensorHooksGuard(); + + SavedTensorHooksGuard(const SavedTensorHooksGuard &) = delete; + SavedTensorHooksGuard &operator=(const SavedTensorHooksGuard &) = delete; + + private: + size_t depth_ = 0; + }; + protected: - std::vector> saved_tensors_; + struct SavedTensorEntry { + // Tensor itself, used under default or reentrant version of recomputation + std::shared_ptr tensor; + // Function to recompute the target tensor, used under non-reentrant version of recomputation + std::shared_ptr hook_state; + SavedTensorUnpackHook unpack; + }; + std::vector saved_tensors_; std::vector needs_input_grad_; private: diff --git a/infini_train/include/autograd/grad_mode.h b/infini_train/include/autograd/grad_mode.h index 65157387..e5fe47b7 100644 --- a/infini_train/include/autograd/grad_mode.h +++ b/infini_train/include/autograd/grad_mode.h @@ -8,10 +8,13 @@ class GradMode { // Whether to enable Autograd (enabled by default) static bool IsEnabled() { return grad_enabled_; } static void SetEnabled(bool enabled) { grad_enabled_ = enabled; } + static bool PropagateRequiresGrad() { return propagate_requires_grad_; } + static void SetPropagateRequiresGrad(bool enabled) { propagate_requires_grad_ = enabled; } private: // grad mode should be thread_local static thread_local bool grad_enabled_; + static thread_local bool propagate_requires_grad_; }; // RAII: Disable grad (align with torch.no_grad) @@ -34,4 +37,19 @@ class EnableGradGuard { bool prev_; }; +// RAII: Propagate requires_grad metadata while graph construction is disabled. +// Used by non-reentrant checkpoint recomputation so downstream SetupContext +// calls see the same needs_input_grad_ pattern as the original forward, +// without wiring the recompute graph into the engine. +class PropagateRequiresGradGuard { +public: + PropagateRequiresGradGuard() : prev_(GradMode::PropagateRequiresGrad()) { + GradMode::SetPropagateRequiresGrad(true); + } + ~PropagateRequiresGradGuard() { GradMode::SetPropagateRequiresGrad(prev_); } + +private: + bool prev_; +}; + } // namespace infini_train::autograd diff --git a/infini_train/include/autograd/normalization.h b/infini_train/include/autograd/normalization.h index c4148cba..3288cb3f 100644 --- a/infini_train/include/autograd/normalization.h +++ b/infini_train/include/autograd/normalization.h @@ -23,5 +23,7 @@ class LayerNorm : public Function { private: const float eps_ = 1e-5f; + std::shared_ptr mean_ = nullptr; + std::shared_ptr rstd_ = nullptr; }; } // namespace infini_train::autograd diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h index f366661b..6a66230e 100644 --- a/infini_train/include/nn/modules/module.h +++ b/infini_train/include/nn/modules/module.h @@ -104,6 +104,8 @@ class Module : public std::enable_shared_from_this { std::vector backward_pre_hooks_; std::vector backward_post_hooks_; + std::vector> ForwardWithHooks(const std::vector> &input_tensors); + private: friend std::vector> parallel::function::Replicate(const std::shared_ptr &network, const std::vector &devices); diff --git a/infini_train/include/nn/modules/transformer/transformer_config.h b/infini_train/include/nn/modules/transformer/transformer_config.h index 62379666..8155a8d7 100644 --- a/infini_train/include/nn/modules/transformer/transformer_config.h +++ b/infini_train/include/nn/modules/transformer/transformer_config.h @@ -2,6 +2,7 @@ #include #include +#include namespace infini_train::nn { @@ -25,6 +26,18 @@ enum class NormType { kRMSNorm // RMSNorm }; +enum class ActivationRecomputeGranularity { + kNone, // Disable activation recompute. + kFull, // Recompute full transformer layers. + kSelective, // Recompute selected transformer submodules. +}; + +enum class ActivationRecomputeMethod { + kNone, // Recompute every layer when granularity is full. + kUniform, // Uniformly divide layers into recompute chunks. + kBlock, // Recompute only the first N layers in the local chunk/stage. +}; + struct TransformerConfig { int64_t block_size = 1024; // Max seq_len int64_t vocab_size = 50304; // Vocab size @@ -56,6 +69,11 @@ struct TransformerConfig { // Normalization float norm_eps = 1e-5f; // epsilon in RMSNorm + // Activation recomputation, aligned with Megatron-Core TransformerConfig. + ActivationRecomputeGranularity recompute_granularity = ActivationRecomputeGranularity::kNone; + ActivationRecomputeMethod recompute_method = ActivationRecomputeMethod::kNone; + int64_t recompute_num_layers = 0; + // Inference bool use_kv = false; // kv cache bool flash = false; // flash attention @@ -63,5 +81,11 @@ struct TransformerConfig { bool UseGQA() const; int GetChunkSize() const; + bool RecomputeEnabled() const; }; + +ActivationRecomputeGranularity ParseActivationRecomputeGranularity(std::string_view value); +ActivationRecomputeMethod ParseActivationRecomputeMethod(std::string_view value); +void SetActivationRecomputeConfig(TransformerConfig *config, bool enabled, std::string_view granularity, + std::string_view method, int64_t num_layers); } // namespace infini_train::nn diff --git a/infini_train/include/nn/parallel/tensor_parallel.h b/infini_train/include/nn/parallel/tensor_parallel.h index 0611dfbb..1b744581 100644 --- a/infini_train/include/nn/parallel/tensor_parallel.h +++ b/infini_train/include/nn/parallel/tensor_parallel.h @@ -104,6 +104,8 @@ class VocabParallelCrossEntropy : public autograd::Function { : autograd::Function(kType), vocab_size_original_(vocab_size_original), label_smoothing_(label_smoothing) {} std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; std::vector> Backward(const std::vector> &grad_outputs) override; private: @@ -113,6 +115,11 @@ class VocabParallelCrossEntropy : public autograd::Function { int64_t vocab_size_local_ = 0; int64_t vocab_size_global_ = 0; int64_t vocab_size_original_ = 0; // For padded situations + + std::shared_ptr softmax_local_; + std::shared_ptr target_mask_; + std::shared_ptr masked_target_; + std::shared_ptr valid_mask_local_; }; class VocabParallelCrossEntropyLoss : public nn::CloneableModule { diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index 4f4ed94b..e2d71e2b 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -142,6 +142,7 @@ class Tensor : public std::enable_shared_from_this { // class before this can be implemented correctly. The guard in elementwise.cu ensures // non-contiguous tensors fall back to the broadcast path until this is resolved. bool IsContiguous() const; + std::shared_ptr Detach() const; std::shared_ptr Flatten(int64_t start = 0, int64_t end = -1); std::shared_ptr Squeeze(int64_t dim); std::shared_ptr Unsqueeze(int64_t dim); diff --git a/infini_train/include/utils/checkpoint.h b/infini_train/include/utils/checkpoint.h new file mode 100644 index 00000000..e9306651 --- /dev/null +++ b/infini_train/include/utils/checkpoint.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include + +#include "infini_train/include/autocast.h" +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} // namespace infini_train + +namespace infini_train::utils::checkpoint { + +class CheckpointFunction : public autograd::Function { +public: + using ForwardFn = std::function>(const std::vector> &)>; + + explicit CheckpointFunction(ForwardFn forward_fn); + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + ForwardFn forward_fn_; + + std::vector> saved_inputs_; + std::vector saved_inputs_requires_grad_; + AutocastState saved_autocast_; +}; + +// Reentrant activation checkpointing (torch.utils.checkpoint.checkpoint style). +std::vector> Checkpoint(const CheckpointFunction::ForwardFn &forward_fn, + const std::vector> &inputs, + bool use_reentrant = false, bool preserve_rng_state = true, + bool determinism_check = true, bool early_stop = true); + +} // namespace infini_train::utils::checkpoint diff --git a/infini_train/src/autograd/activations.cc b/infini_train/src/autograd/activations.cc index 3641865a..184bcd68 100644 --- a/infini_train/src/autograd/activations.cc +++ b/infini_train/src/autograd/activations.cc @@ -17,12 +17,12 @@ std::vector> Sigmoid::Forward(const std::vector> &, const std::vector> &output_tensors) { const auto &output = output_tensors[0]; - saved_tensors_ = {output}; + SaveForBackward({output}); } std::vector> Sigmoid::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &output = saved_tensors_[0]; + CHECK_EQ(SavedTensorsSize(), 1); + const auto &output = GetSavedTensor(0); CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; diff --git a/infini_train/src/autograd/elementwise.cc b/infini_train/src/autograd/elementwise.cc index 655cd309..567599a0 100644 --- a/infini_train/src/autograd/elementwise.cc +++ b/infini_train/src/autograd/elementwise.cc @@ -33,12 +33,12 @@ std::vector> Reciprocal::Forward(const std::vector> &input_tensors, const std::vector> &) { const auto &input = input_tensors[0]; - saved_tensors_ = {input}; + SaveForBackward({input}); } std::vector> Reciprocal::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; + CHECK_EQ(SavedTensorsSize(), 1); + const auto &input = GetSavedTensor(0); CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; @@ -57,12 +57,12 @@ std::vector> Sin::Forward(const std::vector> &input_tensors, const std::vector> &) { const auto &input = input_tensors[0]; - saved_tensors_ = {input}; + SaveForBackward({input}); } std::vector> Sin::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; + CHECK_EQ(SavedTensorsSize(), 1); + const auto &input = GetSavedTensor(0); CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; @@ -81,12 +81,12 @@ std::vector> Cos::Forward(const std::vector> &input_tensors, const std::vector> &) { const auto &input = input_tensors[0]; - saved_tensors_ = {input}; + SaveForBackward({input}); } std::vector> Cos::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; + CHECK_EQ(SavedTensorsSize(), 1); + const auto &input = GetSavedTensor(0); CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; @@ -105,12 +105,12 @@ std::vector> Tanh::Forward(const std::vector> &, const std::vector> &output_tensors) { const auto &output = output_tensors[0]; - saved_tensors_ = {output}; + SaveForBackward({output}); } std::vector> Tanh::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &output = saved_tensors_[0]; + CHECK_EQ(SavedTensorsSize(), 1); + const auto &output = GetSavedTensor(0); CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; @@ -130,12 +130,12 @@ std::vector> Pow::Forward(const std::vector> &input_tensors, const std::vector> &) { const auto &input = input_tensors[0]; - saved_tensors_ = {input}; + SaveForBackward({input}); } std::vector> Pow::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; + CHECK_EQ(SavedTensorsSize(), 1); + const auto &input = GetSavedTensor(0); CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; @@ -155,12 +155,12 @@ std::vector> Rsqrt::Forward(const std::vector> &input_tensors, const std::vector> &) { const auto &input = input_tensors[0]; - saved_tensors_ = {input}; + SaveForBackward({input}); } std::vector> Rsqrt::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; + CHECK_EQ(SavedTensorsSize(), 1); + const auto &input = GetSavedTensor(0); CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; @@ -195,12 +195,12 @@ std::vector> Log::Forward(const std::vector> &input_tensors, const std::vector> &) { const auto &input = input_tensors[0]; - saved_tensors_ = {input}; + SaveForBackward({input}); } std::vector> Log::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; + CHECK_EQ(SavedTensorsSize(), 1); + const auto &input = GetSavedTensor(0); CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; @@ -455,13 +455,13 @@ void Mul::SetupContext(const std::vector> &input_tensors const std::vector> &) { const auto &a = input_tensors[0]; const auto &b = input_tensors[1]; - saved_tensors_ = {a, b}; + SaveForBackward({a, b}); } std::vector> Mul::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 2); - const auto &a = saved_tensors_[0]; - const auto &b = saved_tensors_[1]; + CHECK_EQ(SavedTensorsSize(), 2); + const auto &a = GetSavedTensor(0); + const auto &b = GetSavedTensor(1); CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; @@ -500,13 +500,13 @@ void Div::SetupContext(const std::vector> &input_tensors const std::vector> &) { const auto &a = input_tensors[0]; const auto &b = input_tensors[1]; - saved_tensors_ = {a, b}; + SaveForBackward({a, b}); } std::vector> Div::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 2); - const auto &a = saved_tensors_[0]; - const auto &b = saved_tensors_[1]; + CHECK_EQ(SavedTensorsSize(), 2); + const auto &a = GetSavedTensor(0); + const auto &b = GetSavedTensor(1); CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; diff --git a/infini_train/src/autograd/function.cc b/infini_train/src/autograd/function.cc index f7eb35c7..08675c79 100644 --- a/infini_train/src/autograd/function.cc +++ b/infini_train/src/autograd/function.cc @@ -14,6 +14,9 @@ #include "infini_train/include/utils/precision_checker.h" namespace infini_train::autograd { +namespace { +thread_local std::vector tls_saved_tensor_hooks; +} // namespace std::vector> Function::Apply(const std::vector> &input_tensors) { CHECK_GE(input_tensors.size(), 1); @@ -37,13 +40,13 @@ std::vector> Function::Apply(const std::vectorrequires_grad(); - } + // SetupContext can use it for saved-tensor pruning. This must not depend + // on GradMode: non-reentrant checkpoint recomputes under no-grad to avoid + // wiring an unused recompute graph into the engine, but SetupContext still + // needs the original per-input grad requirements. + needs_input_grad_.resize(input_tensors.size()); + for (size_t idx = 0; idx < input_tensors.size(); ++idx) { + needs_input_grad_[idx] = input_tensors[idx] && input_tensors[idx]->requires_grad(); } std::vector> output_tensors; @@ -62,13 +65,36 @@ std::vector> Function::Apply(const std::vectorrequires_grad(); + } + for (int output_idx = 0; output_idx < output_tensors.size(); ++output_idx) { + auto &output_tensor = output_tensors[output_idx]; + if (!output_tensor) { + continue; + } + output_tensor->set_requires_grad(output_requires_grad); + output_tensor->set_grad_fn(nullptr); + output_tensor->set_is_leaf(true); + output_tensor->set_output_idx(output_idx); + } + } return output_tensors; } bool output_requires_grad = false; for (int idx = 0; idx < input_tensors.size(); ++idx) { const auto &input_tensor = input_tensors[idx]; + if (!input_tensor) { + next_functions_.emplace_back(nullptr, 0); + continue; + } if (input_tensor->requires_grad() && input_tensor->is_leaf()) { next_functions_.emplace_back(input_tensor->grad_accumulator(), input_tensor->output_idx()); input_tensor->grad_accumulator()->IncreaseDependenciesNumber(); @@ -80,7 +106,6 @@ std::vector> Function::Apply(const std::vectorrequires_grad(); } - grad_outputs_reached_ = 0; grad_outputs_.resize(output_tensors.size(), nullptr); for (int output_idx = 0; output_idx < output_tensors.size(); ++output_idx) { @@ -176,6 +201,64 @@ void Function::BackwardPartial(std::shared_ptr grad_output, int grad_out } } +void Function::SaveForBackward(const std::vector> &tensors) { + saved_tensors_.clear(); + saved_tensors_.reserve(tensors.size()); + for (const auto &tensor : tensors) { + SavedTensorEntry entry; + if (!tensor || tls_saved_tensor_hooks.empty()) { + // If no hooks are registered, save the tensor itself + entry.tensor = tensor; + } else { + // Otherwise, use the pack_hook to obtain the related states and save unpack hook + const auto &hooks = tls_saved_tensor_hooks.back(); + if (!hooks.pack && !hooks.unpack) { + entry.tensor = tensor; + } else { + entry.hook_state = hooks.pack ? hooks.pack(tensor) : nullptr; + entry.unpack = hooks.unpack; + } + } + saved_tensors_.push_back(std::move(entry)); + } +} + +std::shared_ptr Function::GetSavedTensor(size_t index) const { + CHECK_LT(index, SavedTensorsSize()); + const auto &entry = saved_tensors_[index]; + if (entry.tensor) { + // If the tensor itself is saved, then no recomputation is needed + return entry.tensor; + } + if (entry.hook_state && entry.unpack) { + // If unpack hook is saved, then do the recomputation + return entry.unpack(entry.hook_state); + } + return nullptr; +} + +std::vector> Function::GetSavedTensors() const { + std::vector> out; + out.reserve(SavedTensorsSize()); + for (size_t i = 0; i < SavedTensorsSize(); ++i) { out.push_back(GetSavedTensor(i)); } + return out; +} + +Function::SavedTensorHooksGuard::SavedTensorHooksGuard(SavedTensorHooks hooks) { + tls_saved_tensor_hooks.push_back(std::move(hooks)); + depth_ = tls_saved_tensor_hooks.size(); +} + +Function::SavedTensorHooksGuard::~SavedTensorHooksGuard() { + if (tls_saved_tensor_hooks.size() == depth_) { + // Generally depth_ should be equal to the number of hooks + tls_saved_tensor_hooks.pop_back(); + } else if (!tls_saved_tensor_hooks.empty()) { + LOG(WARNING) << "SavedTensorHooksGuard: redundant hooks are detected."; + tls_saved_tensor_hooks.pop_back(); + } +} + void Function::IncreaseDependenciesNumber() { ++dependencies_number_; } std::shared_ptr Function::RegisterForwardPreHook(FunctionPreHook hook) { diff --git a/infini_train/src/autograd/grad_mode.cc b/infini_train/src/autograd/grad_mode.cc index 28a6e693..f638f2bc 100644 --- a/infini_train/src/autograd/grad_mode.cc +++ b/infini_train/src/autograd/grad_mode.cc @@ -2,4 +2,5 @@ namespace infini_train::autograd { thread_local bool GradMode::grad_enabled_ = true; +thread_local bool GradMode::propagate_requires_grad_ = false; } // namespace infini_train::autograd diff --git a/infini_train/src/autograd/linear.cc b/infini_train/src/autograd/linear.cc index ff0283ce..9d1b73be 100644 --- a/infini_train/src/autograd/linear.cc +++ b/infini_train/src/autograd/linear.cc @@ -39,7 +39,7 @@ void Linear::SetupContext(const std::vector> &input_tens }; // grad_input needs weight, grad_weight needs input - saved_tensors_ = {need_weight ? cast(input) : nullptr, need_input ? cast(weight) : nullptr}; + SaveForBackward({need_weight ? cast(input) : nullptr, need_input ? cast(weight) : nullptr}); transpose_ = true; bias_ = input_tensors.size() == 3; @@ -49,9 +49,9 @@ void Linear::SetupContext(const std::vector> &input_tens } std::vector> Linear::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 2); - const auto &input = saved_tensors_[0]; - const auto &weight = saved_tensors_[1]; + CHECK_EQ(SavedTensorsSize(), 2); + const auto &input = GetSavedTensor(0); + const auto &weight = GetSavedTensor(1); CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; diff --git a/infini_train/src/autograd/loss.cc b/infini_train/src/autograd/loss.cc index 657ea649..d76ef977 100644 --- a/infini_train/src/autograd/loss.cc +++ b/infini_train/src/autograd/loss.cc @@ -19,13 +19,13 @@ void CrossEntropy::SetupContext(const std::vector> &inpu const std::vector> &) { const auto &input = input_tensors[0]; const auto &target = input_tensors[1]; - saved_tensors_ = {input, target}; + SaveForBackward({input, target}); } std::vector> CrossEntropy::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 2); - const auto &input = saved_tensors_[0]; - const auto &target = saved_tensors_[1]; + CHECK_EQ(SavedTensorsSize(), 2); + const auto &input = GetSavedTensor(0); + const auto &target = GetSavedTensor(1); CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; diff --git a/infini_train/src/autograd/matmul.cc b/infini_train/src/autograd/matmul.cc index 8ddfc578..8a43e048 100644 --- a/infini_train/src/autograd/matmul.cc +++ b/infini_train/src/autograd/matmul.cc @@ -41,16 +41,16 @@ void Matmul::SetupContext(const std::vector> &input_tens return t->Dtype() == compute_dtype ? t : std::make_shared(t->To(compute_dtype)); }; - saved_tensors_ = {need_grad_input2 ? cast(input1) : nullptr, need_grad_input1 ? cast(input2) : nullptr}; + SaveForBackward({need_grad_input2 ? cast(input1) : nullptr, need_grad_input1 ? cast(input2) : nullptr}); input1_dims_ = input1->Dims(); input2_dims_ = input2->Dims(); out_features_ = output->Dims()[0]; } std::vector> Matmul::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 2); - const auto &input1 = saved_tensors_[0]; - const auto &input2 = saved_tensors_[1]; + CHECK_EQ(SavedTensorsSize(), 2); + const auto &input1 = GetSavedTensor(0); + const auto &input2 = GetSavedTensor(1); CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; diff --git a/infini_train/src/autograd/misc.cc b/infini_train/src/autograd/misc.cc index 601258eb..8e2810cd 100644 --- a/infini_train/src/autograd/misc.cc +++ b/infini_train/src/autograd/misc.cc @@ -42,13 +42,13 @@ void IndexGather::SetupContext(const std::vector> &input const auto &input = input_tensors[0]; const auto &index = input_tensors[1]; input_dims_ = input->Dims(); - saved_tensors_ = {index}; + SaveForBackward({index}); } std::vector> IndexGather::Backward(const std::vector> &grad_outputs) { CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; - const auto &index = saved_tensors_[0]; + const auto &index = GetSavedTensor(0); auto device = grad_outputs[0]->GetDevice(); auto kernel = Dispatcher::Instance().GetKernel({device.type(), "IndexGatherBackward"}); @@ -90,12 +90,12 @@ void Slice::SetupContext(const std::vector> &input_tenso const std::vector> &) { // FIXME(dcj): only input's dim need to be saved const auto &input = input_tensors[0]; - saved_tensors_ = {input}; + SaveForBackward({input}); } std::vector> Slice::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; + CHECK_EQ(SavedTensorsSize(), 1); + const auto &input = GetSavedTensor(0); const auto &grad_output = grad_outputs[0]; auto device = input->GetDevice().type(); diff --git a/infini_train/src/autograd/normalization.cc b/infini_train/src/autograd/normalization.cc index 79a14abb..f6dd761c 100644 --- a/infini_train/src/autograd/normalization.cc +++ b/infini_train/src/autograd/normalization.cc @@ -18,7 +18,8 @@ std::vector> LayerNorm::Forward(const std::vector, std::shared_ptr, std::shared_ptr>>( {device, "LayerNormForward"}, input, weight, bias, eps_); - saved_tensors_ = {mean, rstd}; + mean_ = mean; + rstd_ = rstd; return {output}; } @@ -27,16 +28,18 @@ void LayerNorm::SetupContext(const std::vector> &input_t const auto &input = input_tensors[0]; const auto &weight = input_tensors[1]; const auto &bias = input_tensors[2]; - saved_tensors_.insert(saved_tensors_.begin(), {input, weight, bias}); + SaveForBackward({input, weight, bias, mean_, rstd_}); + mean_.reset(); + rstd_.reset(); } std::vector> LayerNorm::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 5); - const auto &input = saved_tensors_[0]; - const auto &weight = saved_tensors_[1]; - const auto &bias = saved_tensors_[2]; - const auto &mean = saved_tensors_[3]; - const auto &rstd = saved_tensors_[4]; + CHECK_EQ(SavedTensorsSize(), 5); + const auto &input = GetSavedTensor(0); + const auto &weight = GetSavedTensor(1); + const auto &bias = GetSavedTensor(2); + const auto &mean = GetSavedTensor(3); + const auto &rstd = GetSavedTensor(4); CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; diff --git a/infini_train/src/autograd/outer.cc b/infini_train/src/autograd/outer.cc index 85a8c9ca..9f419c9f 100644 --- a/infini_train/src/autograd/outer.cc +++ b/infini_train/src/autograd/outer.cc @@ -22,13 +22,13 @@ void Outer::SetupContext(const std::vector> &input_tenso const std::vector> &output_tensors) { const auto &input1 = input_tensors[0]; const auto &input2 = input_tensors[1]; - saved_tensors_ = {input1, input2}; + SaveForBackward({input1, input2}); } std::vector> Outer::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 2); - const auto &input1 = saved_tensors_[0]; - const auto &input2 = saved_tensors_[1]; + CHECK_EQ(SavedTensorsSize(), 2); + const auto &input1 = GetSavedTensor(0); + const auto &input2 = GetSavedTensor(1); CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; diff --git a/infini_train/src/autograd/reduction.cc b/infini_train/src/autograd/reduction.cc index 5a6e086f..54613cbe 100644 --- a/infini_train/src/autograd/reduction.cc +++ b/infini_train/src/autograd/reduction.cc @@ -67,15 +67,15 @@ void Max::SetupContext(const std::vector> &input_tensors const std::vector> &output_tensors) { const auto &input = input_tensors[0]; const auto &output = output_tensors[0]; - saved_tensors_ = {input, output}; + SaveForBackward({input, output}); } std::vector> Max::Backward(const std::vector> &grad_outputs) { CHECK_EQ(grad_outputs.size(), 1); - CHECK_EQ(saved_tensors_.size(), 2); + CHECK_EQ(SavedTensorsSize(), 2); const auto &grad_output = grad_outputs[0]; - const auto &input = saved_tensors_[0]; - const auto &reduced = saved_tensors_[1]; + const auto &input = GetSavedTensor(0); + const auto &reduced = GetSavedTensor(1); auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MaxBackward"}, grad_output, input, reduced, @@ -94,15 +94,15 @@ void Min::SetupContext(const std::vector> &input_tensors const std::vector> &output_tensors) { const auto &input = input_tensors[0]; const auto &output = output_tensors[0]; - saved_tensors_ = {input, output}; + SaveForBackward({input, output}); } std::vector> Min::Backward(const std::vector> &grad_outputs) { CHECK_EQ(grad_outputs.size(), 1); - CHECK_EQ(saved_tensors_.size(), 2); + CHECK_EQ(SavedTensorsSize(), 2); const auto &grad_output = grad_outputs[0]; - const auto &input = saved_tensors_[0]; - const auto &reduced = saved_tensors_[1]; + const auto &input = GetSavedTensor(0); + const auto &reduced = GetSavedTensor(1); auto device = input->GetDevice().type(); return {Dispatcher::Instance().Call>({device, "MinBackward"}, grad_output, input, reduced, diff --git a/infini_train/src/autograd/softmax.cc b/infini_train/src/autograd/softmax.cc index 39569a8c..f1b32139 100644 --- a/infini_train/src/autograd/softmax.cc +++ b/infini_train/src/autograd/softmax.cc @@ -17,12 +17,12 @@ std::vector> Softmax::Forward(const std::vector> &, const std::vector> &output_tensors) { const auto &output = output_tensors[0]; - saved_tensors_ = {output}; + SaveForBackward({output}); } std::vector> Softmax::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &output = saved_tensors_[0]; + CHECK_EQ(SavedTensorsSize(), 1); + const auto &output = GetSavedTensor(0); CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; diff --git a/infini_train/src/autograd/sparse.cc b/infini_train/src/autograd/sparse.cc index 93315b4f..83677fa0 100644 --- a/infini_train/src/autograd/sparse.cc +++ b/infini_train/src/autograd/sparse.cc @@ -20,12 +20,12 @@ void Embedding::SetupContext(const std::vector> &input_t const auto &input = input_tensors[0]; const auto &weight = input_tensors[1]; weight_dims_ = weight->Dims(); - saved_tensors_ = {input}; + SaveForBackward({input}); } std::vector> Embedding::Backward(const std::vector> &grad_outputs) { CHECK_EQ(grad_outputs.size(), 1); - const auto &input = saved_tensors_[0]; + const auto &input = GetSavedTensor(0); const auto &grad_output = grad_outputs[0]; auto device = input->GetDevice().type(); diff --git a/infini_train/src/nn/modules/module.cc b/infini_train/src/nn/modules/module.cc index 6d48dcab..ebb83da5 100644 --- a/infini_train/src/nn/modules/module.cc +++ b/infini_train/src/nn/modules/module.cc @@ -20,7 +20,6 @@ #endif namespace infini_train::nn { - Module::Module() : Module(kUndefinedType) {} Module::Module(const std::string &type) : type_(type), device_(Device()) {} @@ -152,7 +151,8 @@ std::vector> Module::Forward(const std::vector> Module::operator()(const std::vector> &input_tensors) { +std::vector> +Module::ForwardWithHooks(const std::vector> &input_tensors) { // 1. Call global module forward pre-hooks utils::GlobalModuleHookRegistry::Instance().CallModuleForwardPreHooks(this, input_tensors); @@ -223,6 +223,10 @@ std::vector> Module::operator()(const std::vector> Module::operator()(const std::vector> &input_tensors) { + return ForwardWithHooks(input_tensors); +} + void Module::To(Device device) { if (device == device_) { return; diff --git a/infini_train/src/nn/modules/transformer/transformer.cc b/infini_train/src/nn/modules/transformer/transformer.cc index c7e0f28c..db7f2f5b 100644 --- a/infini_train/src/nn/modules/transformer/transformer.cc +++ b/infini_train/src/nn/modules/transformer/transformer.cc @@ -1,5 +1,6 @@ #include "infini_train/include/nn/modules/transformer/transformer.h" +#include #include #include #include @@ -7,6 +8,8 @@ #include "glog/logging.h" +#include "infini_train/include/autograd/function.h" +#include "infini_train/include/autograd/grad_mode.h" #include "infini_train/include/nn/functional.h" #include "infini_train/include/nn/init.h" #include "infini_train/include/nn/modules/container.h" @@ -20,8 +23,52 @@ #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/nn/parallel/utils.h" #include "infini_train/include/tensor.h" +#include "infini_train/include/utils/checkpoint.h" namespace infini_train::nn { +namespace { +std::vector> RunTransformerLayers(const std::shared_ptr &layers, size_t start, + size_t end, + const std::vector> &inputs) { + auto hidden = inputs[0]; + auto layer_inputs = inputs; + for (size_t layer_idx = start; layer_idx < end; ++layer_idx) { + layer_inputs[0] = hidden; + hidden = (*(*layers)[layer_idx])(layer_inputs)[0]; + } + return {hidden}; +} + +std::vector> CheckpointTransformerLayers(const std::shared_ptr &layers, + size_t start, size_t end, + const std::vector> &inputs) { + auto forward_fn = [layers, start, end](const std::vector> &checkpoint_inputs) { + return RunTransformerLayers(layers, start, end, checkpoint_inputs); + }; + constexpr bool kUseReentrant = false; + constexpr bool kPreserveRngState = true; + return utils::checkpoint::Checkpoint(forward_fn, inputs, kUseReentrant, kPreserveRngState); +} + +bool ShouldUseLayerRecompute(const TransformerConfig &config) { + if (!config.RecomputeEnabled() || !autograd::GradMode::IsEnabled()) { + return false; + } + if (config.recompute_granularity == ActivationRecomputeGranularity::kSelective) { + LOG(FATAL) << "Selective activation recompute is not implemented yet. Use full layer recompute."; + } + return config.recompute_granularity == ActivationRecomputeGranularity::kFull; +} + +size_t GetRecomputeNumLayers(const TransformerConfig &config) { + if (config.recompute_method == ActivationRecomputeMethod::kNone) { + return 1; + } + CHECK_GE(config.recompute_num_layers, 1) + << "recompute_num_layers must be >= 1 when recompute_method is uniform or block."; + return static_cast(config.recompute_num_layers); +} +} // namespace TransformerFirstStage::TransformerFirstStage(const TransformerConfig &config) : CloneableModule(kType), config_(config) { @@ -127,8 +174,11 @@ TransformerChunk::TransformerChunk(const TransformerConfig &config, int start_la std::vector> TransformerChunk::Forward(const std::vector> &x) { auto x1 = x[0]; + auto layers = std::dynamic_pointer_cast(modules_[kHLayerName]); + CHECK(layers); // Check if we need to pass RoPE parameters (for LLaMA3 style models) + std::vector> layer_inputs = {x1}; if (config_.attention_type == AttentionType::kRoPE) { // For RoPE models, we need to prepare freqs_cis and potentially other parameters const auto device = x1->GetDevice(); @@ -136,6 +186,7 @@ std::vector> TransformerChunk::Forward(const std::vector // Init freqs_cis on device only once if (buffers_[kFreqsCisName] == nullptr) { int64_t head_dim = config_.n_embd / config_.n_head; + autograd::Function::SavedTensorHooksGuard disable_saved_tensor_hooks({nullptr, nullptr}); buffers_[kFreqsCisName] = PrecomputeFreqsCis(head_dim, config_.block_size * 2, config_.rope_theta, config_.use_scaled_rope, device); } @@ -152,13 +203,47 @@ std::vector> TransformerChunk::Forward(const std::vector std::shared_ptr start_pos_ptr = nullptr; - // Pass RoPE parameters to each transformer block - for (auto &h : *std::dynamic_pointer_cast(modules_[kHLayerName])) { - x1 = (*h)({x1, freqs_view, start_pos_ptr, mask})[0]; + layer_inputs = {x1, freqs_view, start_pos_ptr, mask}; + } + + if (!ShouldUseLayerRecompute(config_)) { + return RunTransformerLayers(layers, 0, static_cast(layers->end() - layers->begin()), layer_inputs); + } + + const size_t num_layers = static_cast(layers->end() - layers->begin()); + const size_t recompute_num_layers = GetRecomputeNumLayers(config_); + switch (config_.recompute_method) { + case ActivationRecomputeMethod::kNone: { + for (size_t layer_idx = 0; layer_idx < num_layers; ++layer_idx) { + layer_inputs[0] = x1; + x1 = CheckpointTransformerLayers(layers, layer_idx, layer_idx + 1, layer_inputs)[0]; } - } else { - // Standard attention (GPT2 style) - for (auto &h : *std::dynamic_pointer_cast(modules_[kHLayerName])) { x1 = (*h)({x1})[0]; } + break; + } + case ActivationRecomputeMethod::kUniform: { + size_t layer_idx = 0; + while (layer_idx < num_layers) { + auto chunk_end = std::min(layer_idx + recompute_num_layers, num_layers); + layer_inputs[0] = x1; + x1 = CheckpointTransformerLayers(layers, layer_idx, chunk_end, layer_inputs)[0]; + layer_idx = chunk_end; + } + break; + } + case ActivationRecomputeMethod::kBlock: { + const size_t checkpoint_layers = std::min(recompute_num_layers, num_layers); + for (size_t layer_idx = 0; layer_idx < num_layers; ++layer_idx) { + layer_inputs[0] = x1; + if (layer_idx < checkpoint_layers) { + x1 = CheckpointTransformerLayers(layers, layer_idx, layer_idx + 1, layer_inputs)[0]; + } else { + x1 = RunTransformerLayers(layers, layer_idx, layer_idx + 1, layer_inputs)[0]; + } + } + break; + } + default: + LOG(FATAL) << "Unsupported activation recompute method."; } return {x1}; diff --git a/infini_train/src/nn/modules/transformer/transformer_config.cc b/infini_train/src/nn/modules/transformer/transformer_config.cc index b8947d4b..2fab5a30 100644 --- a/infini_train/src/nn/modules/transformer/transformer_config.cc +++ b/infini_train/src/nn/modules/transformer/transformer_config.cc @@ -1,9 +1,28 @@ #include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include +#include + +#include "glog/logging.h" + #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" namespace infini_train::nn { +namespace { +template +Enum ParseEnum(std::string_view value, const std::array, N> &entries, + std::string_view name, std::string_view expected) { + for (const auto &[key, parsed] : entries) { + if (value == key) { + return parsed; + } + } + LOG(FATAL) << "Unknown " << name << ": " << value << ". Expected " << expected << "."; + return entries[0].second; +} +} // namespace + bool TransformerConfig::UseGQA() const { return n_kv_head < n_head; } int TransformerConfig::GetChunkSize() const { @@ -12,4 +31,55 @@ int TransformerConfig::GetChunkSize() const { parallel::global::GetVirtualPipelineParallelSize()); return stage_info.layer_ranges_per_chunk.size(); } + +bool TransformerConfig::RecomputeEnabled() const { + return recompute_granularity != ActivationRecomputeGranularity::kNone; +} + +ActivationRecomputeGranularity ParseActivationRecomputeGranularity(std::string_view value) { + static constexpr std::array kEntries = { + std::pair{"none", ActivationRecomputeGranularity::kNone}, + std::pair{"full", ActivationRecomputeGranularity::kFull}, + std::pair{"selective", + ActivationRecomputeGranularity::kSelective}, + }; + return ParseEnum(value, kEntries, "recompute_granularity", "none|full|selective"); +} + +ActivationRecomputeMethod ParseActivationRecomputeMethod(std::string_view value) { + static constexpr std::array kEntries = { + std::pair{"none", ActivationRecomputeMethod::kNone}, + std::pair{"uniform", ActivationRecomputeMethod::kUniform}, + std::pair{"block", ActivationRecomputeMethod::kBlock}, + }; + return ParseEnum(value, kEntries, "recompute_method", "none|uniform|block"); +} + +void SetActivationRecomputeConfig(TransformerConfig *config, bool enabled, std::string_view granularity, + std::string_view method, int64_t num_layers) { + CHECK(config); + if (!enabled) { + config->recompute_granularity = ActivationRecomputeGranularity::kNone; + config->recompute_method = ActivationRecomputeMethod::kNone; + config->recompute_num_layers = 0; + return; + } + + config->recompute_granularity = ParseActivationRecomputeGranularity(granularity); + if (config->recompute_granularity == ActivationRecomputeGranularity::kNone + || config->recompute_granularity == ActivationRecomputeGranularity::kSelective) { + config->recompute_method = ActivationRecomputeMethod::kNone; + config->recompute_num_layers = 0; + return; + } + + config->recompute_method = ParseActivationRecomputeMethod(method); + if (config->recompute_method == ActivationRecomputeMethod::kNone) { + config->recompute_num_layers = 0; + return; + } + + CHECK_GE(num_layers, 1) << "recompute_num_layers must be >= 1 when recompute_method is uniform or block."; + config->recompute_num_layers = num_layers; +} } // namespace infini_train::nn diff --git a/infini_train/src/nn/parallel/tensor_parallel.cc b/infini_train/src/nn/parallel/tensor_parallel.cc index 44ab8189..1f9514d0 100644 --- a/infini_train/src/nn/parallel/tensor_parallel.cc +++ b/infini_train/src/nn/parallel/tensor_parallel.cc @@ -535,27 +535,39 @@ VocabParallelCrossEntropy::Forward(const std::vector> &i loss = loss->Mul(1.0f - smoothing)->Sub(mean_logp->Mul(smoothing)); } - // 8. Save for backward - saved_tensors_ = {softmax_local, target_mask, masked_target, valid_mask_local}; + softmax_local_ = softmax_local; + target_mask_ = target_mask; + masked_target_ = masked_target; + valid_mask_local_ = valid_mask_local; return {loss}; } +void VocabParallelCrossEntropy::SetupContext(const std::vector> &, + const std::vector> &output_tensors) { + (void)output_tensors; + SaveForBackward({softmax_local_, target_mask_, masked_target_, valid_mask_local_}); + softmax_local_.reset(); + target_mask_.reset(); + masked_target_.reset(); + valid_mask_local_.reset(); +} + std::vector> VocabParallelCrossEntropy::Backward(const std::vector> &grad_outputs) { CHECK_EQ(grad_outputs.size(), 1); auto grad_output = grad_outputs[0]; - auto softmax_local = saved_tensors_[0]; - auto target_mask = std::make_shared(saved_tensors_[1]->To(softmax_local->Dtype())); - auto masked_target = saved_tensors_[2]; - auto valid_mask_local = saved_tensors_[3]; + auto softmax_local = GetSavedTensor(0); + auto target_mask = std::make_shared(GetSavedTensor(1)->To(softmax_local->Dtype())); + auto masked_target = GetSavedTensor(2); + auto valid_mask_local = GetSavedTensor(3); auto device = grad_output->GetDevice().type(); auto grad_input = Dispatcher::Instance().Call>( {device, "VocabParallelCrossEntropyBackward"}, grad_output, softmax_local, target_mask, masked_target, valid_mask_local, vocab_size_local_, vocab_size_original_, label_smoothing_); - return {grad_input, nullptr}; + return std::vector>{grad_input, nullptr}; } std::vector> diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index c41aa974..446c5c98 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -378,6 +379,16 @@ std::shared_ptr Tensor::Contiguous() { // elementwise.cu ensures non-contiguous tensors fall back to the broadcast path. bool Tensor::IsContiguous() const { return true; } +std::shared_ptr Tensor::Detach() const { + // Return a detached view of original tensor + // Shares the same storage, but never uses gradient + auto view = std::make_shared(*this, offset_, dims_); + view->set_requires_grad(false); + view->set_is_leaf(true); + view->set_grad_fn(nullptr); + return view; +} + std::shared_ptr Tensor::Flatten(int64_t start, int64_t end) { auto ndim = dims_.size(); auto start_dim = start >= 0 ? start : start + ndim; diff --git a/infini_train/src/utils/checkpoint.cc b/infini_train/src/utils/checkpoint.cc new file mode 100644 index 00000000..08dd7874 --- /dev/null +++ b/infini_train/src/utils/checkpoint.cc @@ -0,0 +1,348 @@ +#include "infini_train/include/utils/checkpoint.h" + +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/autograd/grad_mode.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::utils::checkpoint { +namespace { +constexpr char kCheckpointType[] = "Checkpoint"; +std::atomic g_checkpoint_gid{0}; + +int64_t NextCheckpointGid() { return g_checkpoint_gid.fetch_add(1) + 1; } + +struct SavedTensorMeta { + std::vector dims; + DataType dtype = DataType::kFLOAT32; + Device::DeviceType device_type = Device::DeviceType::kCPU; +}; + +bool MetaEquals(const SavedTensorMeta &a, const SavedTensorMeta &b) { + return a.dtype == b.dtype && a.device_type == b.device_type && a.dims == b.dims; +} + +struct CheckpointFrame; + +struct SavedTensorHolder { + std::shared_ptr frame; + size_t index = 0; + std::unordered_map handles; + std::shared_ptr tensor; +}; + +struct CheckpointFrame { + CheckpointFunction::ForwardFn forward_fn; + std::vector> inputs; + std::vector inputs_requires_grad; + AutocastState autocast_state; + std::vector forward_metas; + bool early_stop = true; + bool determinism_check = true; + std::unordered_map recomputed; + std::vector> weak_holders; + int64_t current_gid = -1; + + struct StopRecomputeError : public std::exception {}; + + bool HasActiveHandles(int64_t gid) const { + for (const auto &weak_holder : weak_holders) { + auto holder = weak_holder.lock(); + if (!holder) { + continue; + } + auto it = holder->handles.find(gid); + if (it != holder->handles.end() && it->second) { + return true; + } + } + return false; + } + + size_t CountAliveHolders() const { + size_t alive = 0; + for (const auto &weak_holder : weak_holders) { + if (weak_holder.lock()) { + ++alive; + } + } + return alive; + } + + int64_t GetOrCreateGid() { + if (current_gid < 0) { + current_gid = NextCheckpointGid(); + return current_gid; + } + if (recomputed[current_gid] && !HasActiveHandles(current_gid)) { + current_gid = NextCheckpointGid(); + } + return current_gid; + } + + void Recompute(int64_t gid) { + if (recomputed[gid]) { + return; + } + const size_t alive_needed = CountAliveHolders(); + size_t filled = 0; + size_t recompute_index = 0; + + std::vector> detached_inputs; + detached_inputs.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + if (!inputs[i]) { + detached_inputs.push_back(nullptr); + continue; + } + auto detached = inputs[i]->Detach(); + detached->set_requires_grad(inputs_requires_grad[i]); + detached->set_is_leaf(true); + detached->set_grad_fn(nullptr); + detached_inputs.push_back(detached); + } + + auto prev_autocast = GetAutocastState(); + SetAutocastState(autocast_state); + // Unlike PyTorch's engine, this autograd implementation mutates + // dependency counters eagerly while building a graph. The recompute + // graph is not traversed by non-reentrant checkpoint here, so building + // it would pollute parameter/input dependency counts and break + // accumulation. Recompute under no-grad, while Function::Apply still + // propagates requires_grad metadata and populates needs_input_grad_ so + // SetupContext saves the same tensors as the original forward. + autograd::NoGradGuard no_grad; + autograd::PropagateRequiresGradGuard propagate_requires_grad; + + autograd::Function::SavedTensorHooks hooks; + hooks.pack = [this, gid, alive_needed, &filled, + &recompute_index](const std::shared_ptr &tensor) -> std::shared_ptr { + size_t idx = recompute_index++; + if (idx >= weak_holders.size()) { + LOG(FATAL) << "Checkpoint: recomputed more tensors than saved during forward."; + } + auto holder = weak_holders[idx].lock(); + if (tensor) { + if (determinism_check) { + SavedTensorMeta meta; + meta.dims = tensor->Dims(); + meta.dtype = tensor->Dtype(); + meta.device_type = tensor->GetDevice().type(); + if (!MetaEquals(meta, forward_metas[idx])) { + LOG(FATAL) << "Checkpoint: recomputed tensor metadata mismatch at index " << idx << "."; + } + } + if (holder) { + holder->handles[gid] = true; + holder->tensor = tensor; + ++filled; + } + } else { + if (holder) { + holder->handles[gid] = true; + holder->tensor.reset(); + ++filled; + } + } + if (early_stop && filled >= alive_needed) { + throw StopRecomputeError(); + } + return tensor; + }; + hooks.unpack = [](const std::shared_ptr &state) -> std::shared_ptr { + return std::static_pointer_cast(state); + }; + autograd::Function::SavedTensorHooksGuard guard(std::move(hooks)); + + try { + forward_fn(detached_inputs); + } catch (const StopRecomputeError &) { + // Early-stop: expected when all needed tensors are recomputed. + } + if (filled < alive_needed) { + LOG(FATAL) << "Checkpoint: recomputed fewer tensors (" << filled << ") than required (" << alive_needed + << ")."; + } + + SetAutocastState(prev_autocast); + recomputed[gid] = true; + + // Break potential reference cycles once recomputation is done. + inputs.clear(); + inputs_requires_grad.clear(); + forward_fn = nullptr; + } +}; +} // namespace + +CheckpointFunction::CheckpointFunction(ForwardFn forward_fn) + : autograd::Function(kCheckpointType), forward_fn_(std::move(forward_fn)) {} + +std::vector> +CheckpointFunction::Forward(const std::vector> &input_tensors) { + saved_autocast_ = GetAutocastState(); + saved_inputs_.clear(); + saved_inputs_requires_grad_.clear(); + saved_inputs_.reserve(input_tensors.size()); + saved_inputs_requires_grad_.reserve(input_tensors.size()); + for (const auto &input : input_tensors) { + if (!input) { + saved_inputs_requires_grad_.push_back(false); + saved_inputs_.push_back(nullptr); + continue; + } + saved_inputs_requires_grad_.push_back(input->requires_grad()); + saved_inputs_.push_back(input->Detach()); + } + + // TODO(zbl): RNG state is not captured yet. Dropout or random ops are not supported. + return forward_fn_(input_tensors); +} + +void CheckpointFunction::SetupContext(const std::vector> &, + const std::vector> &) { + // Intentionally empty: checkpoint avoids saving intermediate tensors. +} + +std::vector> +CheckpointFunction::Backward(const std::vector> &grad_outputs) { + // TODO(zbl): RNG state is not captured yet. Dropout or random ops are not supported. + + CHECK(!saved_inputs_.empty()); + CHECK_EQ(grad_outputs.size(), 1) << "Checkpoint currently supports single-output forward only."; + + auto prev_autocast = GetAutocastState(); + SetAutocastState(saved_autocast_); + autograd::EnableGradGuard enable_grad; + + std::vector> detached_inputs; + detached_inputs.reserve(saved_inputs_.size()); + for (size_t i = 0; i < saved_inputs_.size(); ++i) { + if (!saved_inputs_[i]) { + detached_inputs.push_back(nullptr); + continue; + } + auto detached = saved_inputs_[i]->Detach(); + detached->set_requires_grad(saved_inputs_requires_grad_[i]); + detached->set_is_leaf(true); + detached->set_grad_fn(nullptr); + detached_inputs.push_back(detached); + } + + auto outputs = forward_fn_(detached_inputs); + // TODO(zbl): Support multiple-output forward. + CHECK_EQ(outputs.size(), 1) << "Checkpoint currently supports single-output forward only."; + + if (grad_outputs[0]) { + outputs[0]->Backward(grad_outputs[0]); + } + + SetAutocastState(prev_autocast); + + std::vector> grad_inputs; + grad_inputs.reserve(detached_inputs.size()); + for (const auto &detached : detached_inputs) { + if (detached && detached->requires_grad()) { + grad_inputs.push_back(detached->grad()); + } else { + grad_inputs.push_back(nullptr); + } + } + + saved_inputs_.clear(); + saved_inputs_requires_grad_.clear(); + return grad_inputs; +} + +std::vector> Checkpoint(const CheckpointFunction::ForwardFn &forward_fn, + const std::vector> &inputs, bool use_reentrant, + bool preserve_rng_state, bool determinism_check, bool early_stop) { + if (preserve_rng_state) { + // TODO(zbl): Preserve and restore RNG state for CPU/CUDA. + } + if (!autograd::GradMode::IsEnabled()) { + return forward_fn(inputs); + } + + if (use_reentrant) { + const bool any_requires_grad = std::any_of( + inputs.begin(), inputs.end(), [](const std::shared_ptr &t) { return t && t->requires_grad(); }); + if (!any_requires_grad) { + return forward_fn(inputs); + } + auto func = std::make_shared(forward_fn); + return func->Apply(inputs); + } + + auto frame = std::make_shared(); + frame->forward_fn = forward_fn; + frame->early_stop = early_stop; + frame->determinism_check = determinism_check; + frame->inputs.reserve(inputs.size()); + frame->inputs_requires_grad.reserve(inputs.size()); + for (const auto &input : inputs) { + if (input) { + frame->inputs.push_back(input->Detach()); + frame->inputs_requires_grad.push_back(input->requires_grad()); + } else { + frame->inputs.push_back(nullptr); + frame->inputs_requires_grad.push_back(false); + } + } + frame->autocast_state = GetAutocastState(); + + autograd::Function::SavedTensorHooks hooks; + hooks.pack = [frame](const std::shared_ptr &tensor) -> std::shared_ptr { + auto holder = std::make_shared(); + holder->frame = frame; + holder->index = frame->forward_metas.size(); + frame->weak_holders.push_back(holder); + if (tensor) { + SavedTensorMeta meta; + meta.dims = tensor->Dims(); + meta.dtype = tensor->Dtype(); + meta.device_type = tensor->GetDevice().type(); + frame->forward_metas.push_back(std::move(meta)); + } else { + frame->forward_metas.push_back({}); + } + return holder; + }; + hooks.unpack = [](const std::shared_ptr &state) -> std::shared_ptr { + auto holder = std::static_pointer_cast(state); + auto frame = holder->frame; + const int64_t gid = frame->GetOrCreateGid(); + if (!frame->recomputed[gid]) { + frame->Recompute(gid); + } + auto it = holder->handles.find(gid); + if (it == holder->handles.end() || !it->second) { + LOG(FATAL) << "Checkpoint: unpack called more than once for index " << holder->index << "."; + } + auto recomputed = holder->tensor; + const auto &meta = frame->forward_metas[holder->index]; + if (recomputed && meta.dims.empty() && meta.dtype == DataType::kFLOAT32 + && meta.device_type == Device::DeviceType::kCPU) { + LOG(FATAL) << "Checkpoint: recomputed non-null tensor for saved null entry."; + } + if (!recomputed + && !(meta.dims.empty() && meta.dtype == DataType::kFLOAT32 + && meta.device_type == Device::DeviceType::kCPU)) { + LOG(FATAL) << "Checkpoint: recomputed null tensor for saved non-null entry."; + } + // TODO(zbl): Determinism check (shape/dtype/device) vs forward_metas. + // Release recomputed tensor as soon as it's unpacked to reduce peak memory. + holder->tensor.reset(); + it->second = false; + return recomputed; + }; + + autograd::Function::SavedTensorHooksGuard guard(std::move(hooks)); + return forward_fn(inputs); +} + +} // namespace infini_train::utils::checkpoint diff --git a/scripts/test_config.json b/scripts/test_config.json index 54332f70..f2c92d45 100644 --- a/scripts/test_config.json +++ b/scripts/test_config.json @@ -14,11 +14,6 @@ "id": "build_1", "profile": false, "cmd": "cmake -DUSE_CUDA=ON -DUSE_NCCL=ON .. && make -j" - }, - { - "id": "build_2", - "profile": true, - "cmd": "cmake -DUSE_CUDA=ON -DUSE_NCCL=ON -DPROFILE_MODE=ON .. && make -j" } ], "test_groups": [ @@ -304,6 +299,231 @@ } ] }, + { + "tag": "recompute", + "tests": [ + { + "id": "2_recompute_uniform2", + "args": { + "dtype": "float32", + "num_iteration": 10, + "batch_size": 80, + "total_batch_size": 5120, + "activation_recompute": true, + "recompute_granularity": "full", + "recompute_method": "uniform", + "recompute_num_layers": 2 + } + }, + { + "id": "2_recompute_block2_bfloat16", + "args": { + "dtype": "bfloat16", + "num_iteration": 10, + "batch_size": 80, + "total_batch_size": 5120, + "activation_recompute": true, + "recompute_granularity": "full", + "recompute_method": "block", + "recompute_num_layers": 2 + } + }, + { + "id": "3_recompute_uniform1", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "activation_recompute": true, + "recompute_granularity": "full", + "recompute_method": "uniform", + "recompute_num_layers": 1 + } + }, + { + "id": "3_recompute_block2_bfloat16", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "activation_recompute": true, + "recompute_granularity": "full", + "recompute_method": "block", + "recompute_num_layers": 2 + } + }, + { + "id": "6_recompute_uniform1", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "pipeline_parallel": 8, + "activation_recompute": true, + "recompute_granularity": "full", + "recompute_method": "uniform", + "recompute_num_layers": 1 + } + }, + { + "id": "6_recompute_block1_bfloat16", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "pipeline_parallel": 8, + "activation_recompute": true, + "recompute_granularity": "full", + "recompute_method": "block", + "recompute_num_layers": 1 + } + }, + { + "id": "7_recompute_uniform2", + "args": { + "dtype": "float32", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "pipeline_parallel": 4, + "virtual_pipeline_parallel": 2, + "activation_recompute": true, + "recompute_granularity": "full", + "recompute_method": "uniform", + "recompute_num_layers": 2 + } + }, + { + "id": "7_recompute_block2_bfloat16", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "pipeline_parallel": 4, + "virtual_pipeline_parallel": 2, + "activation_recompute": true, + "recompute_granularity": "full", + "recompute_method": "block", + "recompute_num_layers": 2 + } + }, + { + "id": "8_recompute_uniform3", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 2, + "sequence_parallel": true, + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2, + "activation_recompute": true, + "recompute_granularity": "full", + "recompute_method": "uniform", + "recompute_num_layers": 3 + } + }, + { + "id": "8_recompute_block2_bfloat16", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 2, + "sequence_parallel": true, + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2, + "activation_recompute": true, + "recompute_granularity": "full", + "recompute_method": "block", + "recompute_num_layers": 2 + } + }, + { + "id": "3_recompute_uniform1_distopt", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "use_distributed_optimizer": true, + "activation_recompute": true, + "recompute_granularity": "full", + "recompute_method": "uniform", + "recompute_num_layers": 1 + } + }, + { + "id": "3_recompute_block2_distopt_bfloat16", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "use_distributed_optimizer": true, + "activation_recompute": true, + "recompute_granularity": "full", + "recompute_method": "block", + "recompute_num_layers": 2 + } + }, + { + "id": "8_recompute_uniform3_distopt", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 2, + "sequence_parallel": true, + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2, + "use_distributed_optimizer": true, + "activation_recompute": true, + "recompute_granularity": "full", + "recompute_method": "uniform", + "recompute_num_layers": 3 + } + }, + { + "id": "8_recompute_block2_distopt_bfloat16", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 2, + "sequence_parallel": true, + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2, + "use_distributed_optimizer": true, + "activation_recompute": true, + "recompute_granularity": "full", + "recompute_method": "block", + "recompute_num_layers": 2 + } + } + ] + }, { "tag": "lora", "tests": [ From d87926e2a0ae37fb6ec937b4e7a28b340d025514 Mon Sep 17 00:00:00 2001 From: bolunz Date: Thu, 21 May 2026 10:48:42 +0800 Subject: [PATCH 2/2] fix: restore test_config.json --- scripts/test_config.json | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/scripts/test_config.json b/scripts/test_config.json index f2c92d45..e630544e 100644 --- a/scripts/test_config.json +++ b/scripts/test_config.json @@ -14,6 +14,11 @@ "id": "build_1", "profile": false, "cmd": "cmake -DUSE_CUDA=ON -DUSE_NCCL=ON .. && make -j" + }, + { + "id": "build_2", + "profile": true, + "cmd": "cmake -DUSE_CUDA=ON -DUSE_NCCL=ON -DPROFILE_MODE=ON .. && make -j" } ], "test_groups": [