Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions example/gpt2/checkpoint_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,22 @@ std::tuple<int32_t, infini_train::DataType> DetermineAndCheckVersion(const std::
return {}; // Unreachable, but keeps compiler happy
}
}

// Copies the three activation-recompute fields (granularity, method, num_layers)
// from `runtime_config` into `*config`. No other field of `*config` is written.
void ApplyRuntimeRecomputeConfig(nn::TransformerConfig *config, const nn::TransformerConfig &runtime_config) {
    auto &dst = *config;
    const auto &src = runtime_config;

    dst.recompute_granularity = src.recompute_granularity;
    dst.recompute_method = src.recompute_method;
    dst.recompute_num_layers = src.recompute_num_layers;
}
} // namespace

namespace gpt2 {

std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath) {
return LoadFromLLMC(filepath, gpt2::GPT2Config());
}

std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath,
const nn::TransformerConfig &runtime_config) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand Down Expand Up @@ -87,6 +98,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
gpt2_config.n_layer = n_layer;
gpt2_config.n_head = n_head;
gpt2_config.n_embd = n_embd;
ApplyRuntimeRecomputeConfig(&gpt2_config, runtime_config);
auto local_gpt2 = std::make_shared<nn::TransformerModel>(gpt2_config);

LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size
Expand Down
4 changes: 4 additions & 0 deletions example/gpt2/checkpoint_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
#include <memory>
#include <string>

#include "infini_train/include/nn/modules/transformer/transformer_config.h"

namespace infini_train::nn {
class TransformerModel;
} // namespace infini_train::nn

namespace gpt2 {
// Builds a TransformerModel from the file at `filepath`. Forwards to the overload below
// with gpt2::GPT2Config() as the runtime configuration.
std::shared_ptr<infini_train::nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
// Builds a TransformerModel from the file at `filepath`; the implementation aborts via
// LOG(FATAL) when the file does not exist. The recompute_* fields of `runtime_config` are
// copied into the config used to construct the model.
std::shared_ptr<infini_train::nn::TransformerModel>
LoadFromLLMC(const std::string &filepath, const infini_train::nn::TransformerConfig &runtime_config);
} // namespace gpt2
23 changes: 16 additions & 7 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");

// activation recompute
// These four flags are forwarded together to nn::SetActivationRecomputeConfig() in Train().
// recompute_num_layers is declared unsigned here and cast to int64_t at that call site.
DEFINE_bool(activation_recompute, false, "Enable activation recompute to trade compute for memory.");
DEFINE_string(recompute_granularity, "full", "Activation recompute granularity: none|full|selective");
DEFINE_string(recompute_method, "none", "Activation recompute method: none|uniform|block");
DEFINE_uint32(recompute_num_layers, 0, "Number of transformer layers per recompute region for uniform/block methods.");

// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
// precision check
Expand Down Expand Up @@ -186,21 +192,24 @@ void Train(const nn::parallel::Rank &rank) {
nn::TransformerConfig model_config = gpt2::GPT2Config();
std::shared_ptr<nn::Module> model = nullptr;

if (!FLAGS_llmc_filepath.empty()) {
model = gpt2::LoadFromLLMC(FLAGS_llmc_filepath);
} else if (kModelToConfigs.count(FLAGS_model)) {
if (FLAGS_llmc_filepath.empty() && kModelToConfigs.count(FLAGS_model)) {
model_config = kModelToConfigs.at(FLAGS_model);
}
nn::SetActivationRecomputeConfig(&model_config, FLAGS_activation_recompute, FLAGS_recompute_granularity,
FLAGS_recompute_method, static_cast<int64_t>(FLAGS_recompute_num_layers));

if (!FLAGS_llmc_filepath.empty()) {
model = gpt2::LoadFromLLMC(FLAGS_llmc_filepath, model_config);
} else {
model = std::make_shared<nn::TransformerModel>(model_config);
}

CHECK(model) << "GPT2 example expects GPT2 model.";

model->To(device);

utils::PrecisionChecker::BuildNameMap(model.get());

// Get chunk size before wrapping with LoRA (needed for PipelineParallel)
auto gpt2_model = std::dynamic_pointer_cast<nn::TransformerModel>(model);
CHECK(gpt2_model) << "GPT2 example expects GPT2 model.";

// Apply LoRA using GetLoRAModel (in-place injection)
bool lora_enabled = FLAGS_lora_rank > 0;
if (lora_enabled) {
Expand Down
12 changes: 12 additions & 0 deletions example/llama3/checkpoint_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,22 @@ static std::mt19937 gen{kRandomSeed};
namespace {
constexpr int32_t kLLaMA3Magic = 20240803;
constexpr int32_t kLLaMA3FP32Version = 3;

// Copies the three activation-recompute fields (granularity, method, num_layers)
// from `runtime_config` into `*config`. No other field of `*config` is written.
void ApplyRuntimeRecomputeConfig(nn::TransformerConfig *config, const nn::TransformerConfig &runtime_config) {
    auto &dst = *config;
    const auto &src = runtime_config;

    dst.recompute_granularity = src.recompute_granularity;
    dst.recompute_method = src.recompute_method;
    dst.recompute_num_layers = src.recompute_num_layers;
}
} // namespace

namespace llama3 {

std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath) {
return LoadFromLLMC(filepath, llama3::LLaMA3Config());
}

std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath,
const nn::TransformerConfig &runtime_config) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand Down Expand Up @@ -80,6 +91,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
llama3_config.use_scaled_rope = static_cast<bool>(use_scaled_rope);
llama3_config.norm_eps = norm_eps;
llama3_config.max_gen_batch_size = max_gen_bs;
ApplyRuntimeRecomputeConfig(&llama3_config, runtime_config);
auto llama3 = std::make_shared<nn::TransformerModel>(llama3_config);

// ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ==========
Expand Down
4 changes: 4 additions & 0 deletions example/llama3/checkpoint_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
#include <memory>
#include <string>

#include "infini_train/include/nn/modules/transformer/transformer_config.h"

namespace infini_train::nn {
class TransformerModel;
} // namespace infini_train::nn

namespace llama3 {
// Builds a TransformerModel from the file at `filepath`. Forwards to the overload below
// with llama3::LLaMA3Config() as the runtime configuration.
std::shared_ptr<infini_train::nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
// Builds a TransformerModel from the file at `filepath`; the implementation aborts via
// LOG(FATAL) when the file does not exist. The recompute_* fields of `runtime_config` are
// copied into the config used to construct the model.
std::shared_ptr<infini_train::nn::TransformerModel>
LoadFromLLMC(const std::string &filepath, const infini_train::nn::TransformerConfig &runtime_config);
} // namespace llama3
20 changes: 19 additions & 1 deletion example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");

// activation recompute
// These four flags are forwarded together to nn::SetActivationRecomputeConfig() in Train().
// recompute_num_layers is declared unsigned here and cast to int64_t at that call site.
DEFINE_bool(activation_recompute, false, "Enable activation recompute to trade compute for memory.");
DEFINE_string(recompute_granularity, "full", "Activation recompute granularity: none|full|selective");
DEFINE_string(recompute_method, "none", "Activation recompute method: none|uniform|block");
DEFINE_uint32(recompute_num_layers, 0, "Number of transformer layers per recompute region for uniform/block methods.");
// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
// precision check
Expand Down Expand Up @@ -171,12 +177,16 @@ void Train(const nn::parallel::Rank &rank) {

nn::TransformerConfig model_config = llama3::LLaMA3Config();
std::shared_ptr<nn::Module> model = nullptr;
nn::SetActivationRecomputeConfig(&model_config, FLAGS_activation_recompute, FLAGS_recompute_granularity,
FLAGS_recompute_method, static_cast<int64_t>(FLAGS_recompute_num_layers));
if (!FLAGS_llmc_filepath.empty()) {
model = llama3::LoadFromLLMC(FLAGS_llmc_filepath);
model = llama3::LoadFromLLMC(FLAGS_llmc_filepath, model_config);
} else {
model = std::make_shared<nn::TransformerModel>(model_config);
}

CHECK(model) << "LLaMA3 example expects LLaMA3 model.";

model->To(device);

utils::PrecisionChecker::BuildNameMap(model.get());
Expand Down Expand Up @@ -357,12 +367,20 @@ void Train(const nn::parallel::Rank &rank) {
autocast_guard.Disable();

LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish loss forward";
auto [forward_used_mb, forward_reserved_mb] = impl->GetMemPoolPeakMB(device);
LOG(INFO) << std::format(
"Rank {}: after forward (micro_step {}/{}), peak used: {:5d} MB | peak reserved: {:5d} MB",
rank.GlobalRank(), micro_step + 1, grad_accum_steps, forward_used_mb, forward_reserved_mb);

auto loss_cpu = loss->To(Device());
lossf += static_cast<const float *>(loss_cpu.DataPtr())[0];
LOG(INFO) << "Rank " << rank.GlobalRank() << ": start backward";
loss->Backward();
LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish backward";
auto [backward_used_mb, backward_reserved_mb] = impl->GetMemPoolPeakMB(device);
LOG(INFO) << std::format(
"Rank {}: after backward (micro_step {}/{}), peak used: {:5d} MB | peak reserved: {:5d} MB",
rank.GlobalRank(), micro_step + 1, grad_accum_steps, backward_used_mb, backward_reserved_mb);
}

optimizer->Step();
Expand Down
4 changes: 2 additions & 2 deletions example/mnist/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ int main(int argc, char *argv[]) {
auto new_image = std::make_shared<Tensor>(image->To(device));
auto new_label = std::make_shared<Tensor>(label->To(device));

auto outputs = network.Forward({new_image});
auto outputs = network({new_image});
optimizer.ZeroGrad();

auto loss = loss_fn.Forward({outputs[0], new_label});
Expand Down Expand Up @@ -101,7 +101,7 @@ int main(int argc, char *argv[]) {
auto new_label = std::make_shared<Tensor>(label->To(device));

auto label_cpu = label->To(cpu_device);
auto outputs = network.Forward({new_image});
auto outputs = network({new_image});
auto output_cpu = outputs[0]->To(cpu_device);
auto loss = loss_fn.Forward({outputs[0], new_label});
auto loss_cpu = loss[0]->To(cpu_device);
Expand Down
18 changes: 18 additions & 0 deletions infini_train/include/autocast.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,24 @@ struct AutocastContext {
// Global thread-local storage for autocast context
inline thread_local AutocastContext tls_autocast_context;

// Plain value copy of the autocast settings held in tls_autocast_context.
// Produced by GetAutocastState() and consumed by SetAutocastState().
struct AutocastState {
    // Mirrors AutocastContext::enabled.
    bool enabled{false};
    // Mirrors AutocastContext::device_type.
    Device::DeviceType device_type{Device::DeviceType::kCPU};
    // Mirrors AutocastContext::autocast_dtype.
    DataType autocast_dtype{DataType::kBFLOAT16};
};

// Returns a snapshot of the calling thread's autocast context
// (enabled flag, device type and autocast dtype).
inline AutocastState GetAutocastState() {
    const auto &ctx = tls_autocast_context;

    AutocastState snapshot;
    snapshot.enabled = ctx.enabled;
    snapshot.device_type = ctx.device_type;
    snapshot.autocast_dtype = ctx.autocast_dtype;
    return snapshot;
}

// Overwrites the calling thread's autocast context with the values stored in `state`
// (enabled flag, device type and autocast dtype).
inline void SetAutocastState(const AutocastState &state) {
    auto &ctx = tls_autocast_context;

    ctx.enabled = state.enabled;
    ctx.device_type = state.device_type;
    ctx.autocast_dtype = state.autocast_dtype;
}

// RAII guard to enable/disable autocast in a scope
class AutocastGuard {
public:
Expand Down
37 changes: 36 additions & 1 deletion infini_train/include/autograd/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <functional>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

Expand All @@ -21,6 +22,15 @@ class Function : public std::enable_shared_from_this<Function> {
using FunctionPostHook = std::function<void(Function *, const std::vector<std::shared_ptr<Tensor>> &,
const std::vector<std::shared_ptr<Tensor>> &)>;

// Definition of hooks for saved_tensors, in alignment with torch.autograd.graph.saved_tensors_hooks
// Pack hook: takes a tensor and returns an opaque, type-erased state object.
using SavedTensorPackHook = std::function<std::shared_ptr<void>(const std::shared_ptr<Tensor> &)>;
// Unpack hook: takes an opaque state object and returns a tensor.
using SavedTensorUnpackHook = std::function<std::shared_ptr<Tensor>(const std::shared_ptr<void> &)>;

// A pack/unpack hook pair; this is the argument type of SavedTensorHooksGuard's constructor.
struct SavedTensorHooks {
SavedTensorPackHook pack;
SavedTensorUnpackHook unpack;
};

static constexpr char kUndefinedType[] = "Undefined";

Function() : type_(kUndefinedType) {}
Expand All @@ -45,8 +55,33 @@ class Function : public std::enable_shared_from_this<Function> {

const std::string &type() const { return type_; }

// Records `tensors` in this Function (defined out of line).
// NOTE(review): presumably applies the active pack hook, if any — confirm in the .cc.
void SaveForBackward(const std::vector<std::shared_ptr<Tensor>> &tensors);
// Number of entries currently held in saved_tensors_.
size_t SavedTensorsSize() const { return saved_tensors_.size(); }
// Accessors for one / all saved tensors (defined out of line).
// NOTE(review): presumably these run an entry's unpack hook when one is set — confirm in the .cc.
std::shared_ptr<Tensor> GetSavedTensor(size_t index) const;
std::vector<std::shared_ptr<Tensor>> GetSavedTensors() const;

// RAII: Register pack/unpack hooks for saved_tensors, align with torch.autograd.graph.saved_tensors_hooks
// The constructor and destructor are defined out of line.
class SavedTensorHooksGuard {
public:
explicit SavedTensorHooksGuard(SavedTensorHooks hooks);
~SavedTensorHooksGuard();

// Copying is disabled so each guard object is destroyed exactly once per construction.
SavedTensorHooksGuard(const SavedTensorHooksGuard &) = delete;
SavedTensorHooksGuard &operator=(const SavedTensorHooksGuard &) = delete;

private:
// NOTE(review): presumably the hook-stack depth recorded at construction — confirm in the .cc.
size_t depth_ = 0;
};

protected:
std::vector<std::shared_ptr<Tensor>> saved_tensors_;
// One saved-tensor slot. It holds either the tensor itself or the opaque state plus the
// unpack hook for it.
// NOTE(review): which members are populated in which mode is decided in the .cc — confirm there.
struct SavedTensorEntry {
// Tensor itself, used under default or reentrant version of recomputation
std::shared_ptr<Tensor> tensor;
// Opaque state (same type a SavedTensorPackHook returns) and the hook that takes such state and
// returns a tensor, used under non-reentrant version of recomputation
std::shared_ptr<void> hook_state;
SavedTensorUnpackHook unpack;
};
// Saved-tensor slots of this Function; SavedTensorsSize() reports its size.
std::vector<SavedTensorEntry> saved_tensors_;
std::vector<bool> needs_input_grad_;

private:
Expand Down
18 changes: 18 additions & 0 deletions infini_train/include/autograd/grad_mode.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ class GradMode {
// Whether to enable Autograd (enabled by default)
static bool IsEnabled() { return grad_enabled_; }
static void SetEnabled(bool enabled) { grad_enabled_ = enabled; }
// Getter/setter for the thread-local propagate_requires_grad_ flag.
static bool PropagateRequiresGrad() { return propagate_requires_grad_; }
static void SetPropagateRequiresGrad(bool enabled) { propagate_requires_grad_ = enabled; }

private:
// grad mode should be thread_local
static thread_local bool grad_enabled_;
static thread_local bool propagate_requires_grad_;
};

// RAII: Disable grad (align with torch.no_grad)
Expand All @@ -34,4 +37,19 @@ class EnableGradGuard {
bool prev_;
};

// RAII: Propagate requires_grad metadata while graph construction is disabled.
// Used by non-reentrant checkpoint recomputation so downstream SetupContext
// calls see the same needs_input_grad_ pattern as the original forward,
// without wiring the recompute graph into the engine.
//
// On construction the current GradMode::PropagateRequiresGrad() value is remembered and
// the flag is set to true; on destruction the remembered value is written back.
class PropagateRequiresGradGuard {
public:
    PropagateRequiresGradGuard() : prev_(GradMode::PropagateRequiresGrad()) {
        GradMode::SetPropagateRequiresGrad(true);
    }
    ~PropagateRequiresGradGuard() { GradMode::SetPropagateRequiresGrad(prev_); }

    // Non-copyable: a copy would write its remembered value back a second time when
    // destroyed, possibly clobbering a flag value set by another guard in between.
    PropagateRequiresGradGuard(const PropagateRequiresGradGuard &) = delete;
    PropagateRequiresGradGuard &operator=(const PropagateRequiresGradGuard &) = delete;

private:
    // Flag value observed at construction; restored by the destructor.
    bool prev_;
};

} // namespace infini_train::autograd
2 changes: 2 additions & 0 deletions infini_train/include/autograd/normalization.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,7 @@ class LayerNorm : public Function {

private:
const float eps_ = 1e-5f;
// NOTE(review): presumably the mean and reciprocal standard deviation computed in the forward
// pass and kept for the backward pass — confirm against the LayerNorm implementation.
std::shared_ptr<Tensor> mean_ = nullptr;
std::shared_ptr<Tensor> rstd_ = nullptr;
};
} // namespace infini_train::autograd
2 changes: 2 additions & 0 deletions infini_train/include/nn/modules/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class Module : public std::enable_shared_from_this<Module> {
std::vector<ModulePreHook> backward_pre_hooks_;
std::vector<ModulePostHook> backward_post_hooks_;

// Defined out of line.
// NOTE(review): presumably invokes Forward() on `input_tensors` together with the registered
// forward pre/post hooks — confirm against the definition.
std::vector<std::shared_ptr<Tensor>> ForwardWithHooks(const std::vector<std::shared_ptr<Tensor>> &input_tensors);

private:
friend std::vector<std::shared_ptr<Module>> parallel::function::Replicate(const std::shared_ptr<Module> &network,
const std::vector<Device> &devices);
Expand Down
24 changes: 24 additions & 0 deletions infini_train/include/nn/modules/transformer/transformer_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cstdint>
#include <optional>
#include <string_view>

namespace infini_train::nn {

Expand All @@ -25,6 +26,18 @@ enum class NormType {
kRMSNorm // RMSNorm
};

// Which part of the transformer is recomputed. Stored in TransformerConfig::recompute_granularity
// and produced from text by ParseActivationRecomputeGranularity().
enum class ActivationRecomputeGranularity {
kNone, // Disable activation recompute.
kFull, // Recompute full transformer layers.
kSelective, // Recompute selected transformer submodules.
};

// How layers are chosen for recomputation. Stored in TransformerConfig::recompute_method
// and produced from text by ParseActivationRecomputeMethod().
enum class ActivationRecomputeMethod {
kNone, // Recompute every layer when granularity is full.
kUniform, // Uniformly divide layers into recompute chunks.
kBlock, // Recompute only the first N layers in the local chunk/stage.
};

struct TransformerConfig {
int64_t block_size = 1024; // Max seq_len
int64_t vocab_size = 50304; // Vocab size
Expand Down Expand Up @@ -56,12 +69,23 @@ struct TransformerConfig {
// Normalization
float norm_eps = 1e-5f; // epsilon in RMSNorm

// Activation recomputation, aligned with Megatron-Core TransformerConfig.
ActivationRecomputeGranularity recompute_granularity = ActivationRecomputeGranularity::kNone;
ActivationRecomputeMethod recompute_method = ActivationRecomputeMethod::kNone;
int64_t recompute_num_layers = 0;

// Inference
bool use_kv = false; // kv cache
bool flash = false; // flash attention
int64_t max_gen_batch_size = 4; // max batch size during inference

bool UseGQA() const;
int GetChunkSize() const;
bool RecomputeEnabled() const;
};

// Convert a textual option value into the corresponding enum (defined out of line).
// NOTE(review): accepted spellings and the behavior on unrecognized input are not visible
// here — confirm in the implementation.
ActivationRecomputeGranularity ParseActivationRecomputeGranularity(std::string_view value);
ActivationRecomputeMethod ParseActivationRecomputeMethod(std::string_view value);
// Applies the given activation-recompute options to `*config` (defined out of line).
// NOTE(review): presumably writes the recompute_* fields and leaves recompute disabled when
// `enabled` is false — confirm in the implementation.
void SetActivationRecomputeConfig(TransformerConfig *config, bool enabled, std::string_view granularity,
std::string_view method, int64_t num_layers);
} // namespace infini_train::nn
Loading
Loading