Commits
19 commits
7a16589
refactor(optimizer): hoist learning_rate_ to Optimizer base and add l…
Mar 3, 2026
0514862
refactor(distributed_optimizer): passthrough SetLearningRate/GetLearn…
Mar 3, 2026
81295e8
feat(lr_scheduler): add LRScheduler abstract base class with StateDict
Mar 3, 2026
8e7cda0
refactor(examples): add scheduler placeholder and use runtime lr in logs
Mar 3, 2026
1e65881
feat: add ConstantLR, StepLR and LinearWarmupLR
Mar 4, 2026
d924d3d
refactor(lr_scheduler): replace ComputeLR with virtual Step and Apply…
Mar 5, 2026
baca2ef
feat(lr_schedulers): add LambdaLR strategy
Mar 5, 2026
7df75d7
refactor(optimizer): add initial_learning_rate and its accessors
Mar 5, 2026
d0ac538
feat(lr_schedulers): add SequentialLR composite strategy
Mar 5, 2026
df4c68d
refactor(lr_scheduler): apply template method pattern to LRScheduler …
Mar 5, 2026
5b4ef6d
feat(lr_scheduler): add factory method Create<T>() with two-phase ini…
Mar 5, 2026
8c11dd9
feat(lr_scheduler): add closed and chained form, adjust LinearLR, Sequ…
Mar 6, 2026
6823244
feat(lr_schedulers): add ChainedScheduler composite strategy
Mar 6, 2026
fb9d997
feat(lr_scheduler): add scheduler factory for CLI integration
Mar 8, 2026
7a29a61
feat(lr_scheduler): add scheduler factory for CLI integration (Sequen…
Mar 8, 2026
b64566e
feat(lr_scheduler): add warmup start_factor and end_factor, remove c…
Mar 8, 2026
3a7abb4
refactor(gpt2,llama3): integrate scheduler into training loop
Mar 8, 2026
f7b3fcb
Merge branch 'InfiniTensor:master' into lr_scheduler
littleotherut Mar 11, 2026
1f95e29
style: apply clang-format to all legacy files
littleotherut Mar 11, 2026
21 changes: 21 additions & 0 deletions CMakeLists.txt
@@ -200,3 +200,24 @@ target_link_libraries(test_hook infini_train)

add_executable(test_precision_check test/hook/test_precision_check.cc)
target_link_libraries(test_precision_check infini_train)

add_executable(test_lr_scheduler test/lr_scheduler/test_lr_scheduler.cc)
target_link_libraries(test_lr_scheduler infini_train)

add_executable(test_constant_lr test/lr_scheduler/test_constant_lr.cc)
target_link_libraries(test_constant_lr infini_train)

add_executable(test_step_lr test/lr_scheduler/test_step_lr.cc)
target_link_libraries(test_step_lr infini_train)

add_executable(test_linear_lr test/lr_scheduler/test_linear_lr.cc)
target_link_libraries(test_linear_lr infini_train)

add_executable(test_lambda_lr test/lr_scheduler/test_lambda_lr.cc)
target_link_libraries(test_lambda_lr infini_train)

add_executable(test_sequential_lr test/lr_scheduler/test_sequential_lr.cc)
target_link_libraries(test_sequential_lr infini_train)

add_executable(test_chained_lr test/lr_scheduler/test_chained_lr.cc)
target_link_libraries(test_chained_lr infini_train)
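
The seven new test targets each compile a single source under test/lr_scheduler/ against the infini_train library. Those sources are not part of this diff; purely as a sketch of the shape such a test could take (the optimizer construction is a placeholder, and the StepLR decay rule is assumed to follow the usual lr = base_lr * gamma^floor(step / step_size) form, neither is verified here):

// Hypothetical sketch of test/lr_scheduler/test_step_lr.cc; not the file built in this PR.
#include <cmath>
#include <memory>

#include "glog/logging.h"

#include "infini_train/include/lr_scheduler.h"

using namespace infini_train;

int main() {
    // Placeholder: build an optimizer with base learning rate 1.0f via the repo's
    // own optimizer API (not shown in this diff).
    std::shared_ptr<Optimizer> opt = /* ... */ nullptr;

    // Two-phase construction: Create<T>() runs InitialStep() after the constructor.
    auto sched = LRScheduler::Create<lr_schedulers::StepLR>(opt, /*step_size=*/2, /*gamma=*/0.5f);

    for (int i = 0; i < 4; ++i) {
        sched->Step();
    }
    // Expect 1.0 * 0.5^2 = 0.25 under the assumed decay rule.
    CHECK_LT(std::fabs(sched->GetLR() - 0.25f), 1e-6f);
    return 0;
}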
37 changes: 34 additions & 3 deletions example/gpt2/main.cc
@@ -13,6 +13,7 @@
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
#include "infini_train/include/lr_scheduler.h"
#include "infini_train/include/nn/modules/loss.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h"
@@ -55,6 +56,16 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-4, "learning rate");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer (only takes effect when DP > 1)");
// lr scheduler
DEFINE_string(lr_scheduler, "none", "Learning rate scheduler type: none|constant|step|linear");
DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)");
DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)");
DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)");
DEFINE_int64(step_size, 30, "StepLR: period of learning rate decay");
DEFINE_double(gamma, 0.1, "StepLR: multiplicative factor of lr decay");
DEFINE_double(start_factor, 0.333333, "LinearLR: starting multiplicative factor");
DEFINE_double(end_factor, 1.0, "LinearLR: ending multiplicative factor");
DEFINE_int64(lr_total_iters, 5, "ConstantLR/LinearLR: total iterations for the scheduler");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
@@ -268,6 +279,20 @@ void Train(const nn::parallel::Rank &rank) {
optimizer = optimizer_creator(model->Parameters());
}

LRSchedulerConfig sched_config;
sched_config.type = FLAGS_lr_scheduler;
sched_config.warmup_steps = FLAGS_warmup_steps;
sched_config.warmup_start_factor = static_cast<float>(FLAGS_warmup_start_factor);
sched_config.warmup_end_factor = static_cast<float>(FLAGS_warmup_end_factor);
sched_config.step_size = FLAGS_step_size;
sched_config.step_gamma = static_cast<float>(FLAGS_gamma);
sched_config.linear_start_factor = static_cast<float>(FLAGS_start_factor);
sched_config.linear_end_factor = static_cast<float>(FLAGS_end_factor);
sched_config.constant_factor = static_cast<float>(FLAGS_start_factor); // reuse FLAGS_start_factor for ConstantLR
sched_config.constant_total_iters = FLAGS_lr_total_iters;
sched_config.linear_total_iters = FLAGS_lr_total_iters;
auto scheduler = CreateLRScheduler(optimizer, sched_config);

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
@@ -354,6 +379,9 @@ void Train(const nn::parallel::Rank &rank) {
}

optimizer->Step();
if (scheduler) {
scheduler->Step();
}
} else {
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
@@ -363,6 +391,9 @@
y = std::make_shared<Tensor>(y->To(device));

lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype);
if (scheduler) {
scheduler->Step();
}
}

if (ddp_world_size > 1) {
@@ -378,11 +409,11 @@
if (rank.IsLastRank()) {
size_t used_mb = 0, reserved_mb = 0;
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);

const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps,
used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
pp_world_size);

if ((step + 1) % FLAGS_freq_generate_txt == 0) {
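
The gflags surface only exposes none|constant|step|linear plus the warmup knobs, while the header below also ships LambdaLR, SequentialLR and ChainedScheduler. A schedule the flags do not cover can be built directly against the scheduler API; a minimal sketch, assuming it is placed in Train() right after the optimizer is constructed (the 0.95-per-step exponential decay is only an illustrative choice, and <cmath> is needed for std::pow):

// Sketch: exponential decay via LambdaLR, constructed directly instead of
// through the FLAGS-driven CreateLRScheduler factory. `optimizer` is the
// std::shared_ptr<Optimizer> already created above in Train().
auto scheduler = infini_train::LRScheduler::Create<infini_train::lr_schedulers::LambdaLR>(
    optimizer, [](int64_t step) { return std::pow(0.95f, static_cast<float>(step)); });
// Driven exactly like the factory-created schedulers in the training loop:
//   optimizer->Step();
//   scheduler->Step();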
37 changes: 34 additions & 3 deletions example/llama3/main.cc
@@ -11,6 +11,7 @@
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
#include "infini_train/include/lr_scheduler.h"
#include "infini_train/include/nn/modules/loss.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h"
@@ -54,6 +55,16 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-5, "learning rate");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer (only takes effect when DP > 1)");
// lr scheduler
DEFINE_string(lr_scheduler, "none", "Learning rate scheduler type: none|constant|step|linear");
DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)");
DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)");
DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)");
DEFINE_int64(step_size, 30, "StepLR: period of learning rate decay");
DEFINE_double(gamma, 0.1, "StepLR: multiplicative factor of lr decay");
DEFINE_double(start_factor, 0.333333, "LinearLR: starting multiplicative factor");
DEFINE_double(end_factor, 1.0, "LinearLR: ending multiplicative factor");
DEFINE_int64(lr_total_iters, 5, "ConstantLR/LinearLR: total iterations for the scheduler");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
@@ -247,6 +258,20 @@ void Train(const nn::parallel::Rank &rank) {
optimizer = optimizer_creator(model->Parameters());
}

LRSchedulerConfig sched_config;
sched_config.type = FLAGS_lr_scheduler;
sched_config.warmup_steps = FLAGS_warmup_steps;
sched_config.warmup_start_factor = static_cast<float>(FLAGS_warmup_start_factor);
sched_config.warmup_end_factor = static_cast<float>(FLAGS_warmup_end_factor);
sched_config.step_size = FLAGS_step_size;
sched_config.step_gamma = static_cast<float>(FLAGS_gamma);
sched_config.linear_start_factor = static_cast<float>(FLAGS_start_factor);
sched_config.linear_end_factor = static_cast<float>(FLAGS_end_factor);
sched_config.constant_factor = static_cast<float>(FLAGS_start_factor); // reuse FLAGS_start_factor for ConstantLR
sched_config.constant_total_iters = FLAGS_lr_total_iters;
sched_config.linear_total_iters = FLAGS_lr_total_iters;
auto scheduler = CreateLRScheduler(optimizer, sched_config);

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(std::make_shared<VocabParallelCrossEntropyLoss>())
@@ -330,6 +355,9 @@
}

optimizer->Step();
if (scheduler) {
scheduler->Step();
}
} else {
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
@@ -339,6 +367,9 @@
y = std::make_shared<Tensor>(y->To(device));

lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype);
if (scheduler) {
scheduler->Step();
}
}

if (ddp_world_size > 1) {
@@ -354,11 +385,11 @@
if (rank.IsLastRank()) {
size_t used_mb = 0, reserved_mb = 0;
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);

const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps,
used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
pp_world_size);

if ((step + 1) % FLAGS_freq_generate_txt == 0) {
186 changes: 186 additions & 0 deletions infini_train/include/lr_scheduler.h
@@ -0,0 +1,186 @@
#pragma once

#include <cmath>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <variant>
#include <vector>

namespace infini_train {

class Optimizer;

using StateValue = std::variant<int64_t, float, double, std::string, std::vector<float>>;
using StateDict = std::unordered_map<std::string, StateValue>;

struct LRSchedulerConfig {
std::string type = "none";
// ConstantLR
float constant_factor = 1.0f / 3.0f;
int constant_total_iters = 5;
// StepLR
int64_t step_size = 10;
float step_gamma = 0.1f;
// LinearLR
float linear_start_factor = 1.0f / 3.0f;
float linear_end_factor = 1.0f;
int linear_total_iters = 5;
// LambdaLR
std::function<float(int64_t)> lambda_fn = nullptr;
// SequentialLR
std::vector<LRSchedulerConfig> sequential_configs;
std::vector<int64_t> sequential_milestones;
// ChainedScheduler
std::vector<LRSchedulerConfig> chained_configs;
// warmup
int64_t warmup_steps = 0;
float warmup_start_factor = 1.0f / 3.0f;
float warmup_end_factor = 1.0f;
};

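// Design notes (inferred from this PR's commit history rather than separate docs):
//  - Template method: Step() advances last_step_, asks the concrete scheduler for
//    the next value via GetClosedFormLR() / GetChainedFormLR(), and applies it
//    through ApplyLR() on the attached optimizer.
//  - Two-phase initialization: construct through Create<T>(), which calls
//    InitialStep() after the constructor completes so virtual dispatch is safe.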
class LRScheduler {
public:
template <typename T, typename... Args> static std::shared_ptr<T> Create(Args &&...args) {
auto scheduler = std::make_shared<T>(std::forward<Args>(args)...);
scheduler->InitialStep();
return scheduler;
}

explicit LRScheduler(std::shared_ptr<Optimizer> optimizer, int64_t last_step = -1);
virtual ~LRScheduler() = default;

LRScheduler(const LRScheduler &) = delete;
LRScheduler &operator=(const LRScheduler &) = delete;

virtual void Step();
virtual void Step(int64_t epoch);
virtual void InitialStep();

float GetLR() const;
float BaseLR() const;
int64_t LastStep() const;

void ResetStep(int64_t step = -1);
virtual StateDict State() const;
virtual void LoadState(const StateDict &state);

protected:
virtual float GetClosedFormLR() const = 0;
virtual float GetChainedFormLR() const;
void ApplyLR(float lr);

std::shared_ptr<Optimizer> optimizer_;
int64_t last_step_;
float current_lr_;
float base_lr_;
bool is_initial_ = false;
};

std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimizer, const LRSchedulerConfig &config);

namespace lr_schedulers {

class ConstantLR : public LRScheduler {
public:
ConstantLR(std::shared_ptr<Optimizer> optimizer, float factor = 1.0f / 3.0f, int total_iters = 5,
int64_t last_step = -1);
~ConstantLR() override = default;

protected:
float GetChainedFormLR() const override;
float GetClosedFormLR() const override;

private:
const float factor_;
const int64_t total_iters_;
};

class StepLR : public LRScheduler {
public:
StepLR(std::shared_ptr<Optimizer> optimizer, int64_t step_size, float gamma = 0.1f, int64_t last_step = -1);
~StepLR() override = default;

protected:
float GetChainedFormLR() const override;
float GetClosedFormLR() const override;

private:
const int64_t step_size_;
const float gamma_;
};

class LinearLR : public LRScheduler {
public:
LinearLR(std::shared_ptr<Optimizer> optimizer, float start_factor = 1.0f / 3.0f, float end_factor = 1.0f,
int64_t total_iters = 5, int64_t last_step = -1);
~LinearLR() override = default;

protected:
float GetChainedFormLR() const override;
float GetClosedFormLR() const override;

private:
const float start_factor_;
const float end_factor_;
const int64_t total_iters_;
};

class LambdaLR : public LRScheduler {
public:
using LambdaFunc = std::function<float(int64_t)>;

LambdaLR(std::shared_ptr<Optimizer> optimizer, LambdaFunc lr_lambda, int64_t last_step = -1);
~LambdaLR() override = default;

protected:
float GetClosedFormLR() const override;

private:
const LambdaFunc lr_lambda_;
};

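// Composite strategy: runs one child scheduler at a time and switches to the next
// child at each milestone step (mirroring the PyTorch SequentialLR it appears to follow).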
class SequentialLR : public LRScheduler {
public:
SequentialLR(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers,
std::vector<int64_t> milestones, int64_t last_step = -1);
~SequentialLR() override = default;

void Step() override;
void InitialStep() override;

StateDict State() const override;
void LoadState(const StateDict &state) override;

protected:
float GetClosedFormLR() const override { return current_lr_; }
void UndoChildInitialSteps();

private:
std::vector<std::shared_ptr<LRScheduler>> schedulers_;
std::vector<int64_t> milestones_;
};

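// Composite strategy: applies every child scheduler on each Step(), chaining their
// adjustments in order (mirroring PyTorch's ChainedScheduler).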
class ChainedScheduler : public LRScheduler {
public:
ChainedScheduler(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers,
int64_t last_step = -1);
~ChainedScheduler() override = default;

void Step() override;
void InitialStep() override;

StateDict State() const override;
void LoadState(const StateDict &state) override;

protected:
float GetClosedFormLR() const override { return current_lr_; }

private:
std::vector<std::shared_ptr<LRScheduler>> schedulers_;
};

} // namespace lr_schedulers
} // namespace infini_train
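
For reference, a sketch of driving the config-based factory above programmatically, assuming an existing std::shared_ptr<Optimizer> is available. Whether warmup_steps > 0 makes CreateLRScheduler wrap the base schedule in a SequentialLR warmup phase is inferred from the config fields and commit messages, not verified against the .cc file:

// Sketch only: composes a 50-step linear warmup with a StepLR decay via the
// factory. CreateLRScheduler presumably returns nullptr for type == "none",
// since the training loops guard every use with `if (scheduler)`.
#include "infini_train/include/lr_scheduler.h"

using namespace infini_train;

std::shared_ptr<LRScheduler> MakeWarmupStepSchedule(std::shared_ptr<Optimizer> opt) {
    LRSchedulerConfig cfg;
    cfg.type = "step";
    cfg.step_size = 100;
    cfg.step_gamma = 0.5f;
    cfg.warmup_steps = 50;
    cfg.warmup_start_factor = 0.1f;
    cfg.warmup_end_factor = 1.0f;
    return CreateLRScheduler(opt, cfg);
}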
3 changes: 3 additions & 0 deletions infini_train/include/nn/parallel/ddp/distributed_optimizer.h
@@ -34,6 +34,9 @@ class DistributedOptimizer final : public infini_train::Optimizer {
void StartParamSync(bool force_sync = false);
void FinishParamSync(bool skip_next_bucket_dispatch = false);

virtual void SetLearningRate(float lr) override;
virtual float GetLearningRate() const override;

private:
void BuildShardParamsAndBindGrads();

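
The matching definitions are outside this excerpt. Given the commit messages ("hoist learning_rate_ to Optimizer base", "passthrough SetLearningRate/GetLearn…"), a plausible shape for the passthrough, shown purely as an illustration and not as the repository's actual code, is:

// Illustrative sketch only; the real distributed_optimizer.cc may also propagate
// the new rate to its sharded inner optimizers.
void DistributedOptimizer::SetLearningRate(float lr) {
    Optimizer::SetLearningRate(lr);  // keep the base class as the single source of truth
}

float DistributedOptimizer::GetLearningRate() const { return Optimizer::GetLearningRate(); }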