From 7a165897720fd804822a1dca602065a45e6fffac Mon Sep 17 00:00:00 2001 From: kinorw Date: Tue, 3 Mar 2026 14:42:50 +0800 Subject: [PATCH 01/18] refactor(optimizer): hoist learning_rate_ to Optimizer base and add lr accessors --- infini_train/include/optimizer.h | 10 ++++++---- infini_train/src/optimizer.cc | 13 +++++++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/infini_train/include/optimizer.h b/infini_train/include/optimizer.h index fb0ae2d5..2ca7a054 100644 --- a/infini_train/include/optimizer.h +++ b/infini_train/include/optimizer.h @@ -15,14 +15,19 @@ using OptimizerCreator = std::function(const std::vec class Optimizer { public: - explicit Optimizer(const std::vector> ¶ms); + explicit Optimizer(const std::vector> ¶ms, float learning_rate = 0.0f); virtual void ZeroGrad(bool set_to_none = true); virtual void Step() = 0; + virtual void SetLearningRate(float lr); + + virtual float GetLearningRate() const; + protected: std::vector> params_; + float learning_rate_ = 0.0f; }; namespace optimizers { @@ -38,8 +43,6 @@ class SGD : public Optimizer { }; } -private: - const float learning_rate_ = 0.0; }; class Adam : public Optimizer { @@ -58,7 +61,6 @@ class Adam : public Optimizer { private: int64_t t_; - const float learning_rate_; const float beta1_; const float beta2_; const float eps_; diff --git a/infini_train/src/optimizer.cc b/infini_train/src/optimizer.cc index 8eacafa3..a3830ce4 100644 --- a/infini_train/src/optimizer.cc +++ b/infini_train/src/optimizer.cc @@ -8,16 +8,21 @@ #include "infini_train/include/tensor.h" namespace infini_train { -Optimizer::Optimizer(const std::vector> ¶ms) : params_(params) {} - +Optimizer::Optimizer(const std::vector> ¶ms, float learning_rate) + : params_(params), learning_rate_(learning_rate) {} + void Optimizer::ZeroGrad(bool set_to_none) { for (auto param : params_) { param->ZeroGrad(set_to_none); } } +void Optimizer::SetLearningRate(float lr) { learning_rate_ = lr; } + +float 
Optimizer::GetLearningRate() const { return learning_rate_; } + namespace optimizers { SGD::SGD(const std::vector> ¶ms, float learning_rate) - : Optimizer(params), learning_rate_(learning_rate) {} + : Optimizer(params, learning_rate) {} void SGD::Step() { for (auto param : params_) { @@ -33,7 +38,7 @@ void SGD::Step() { } Adam::Adam(const std::vector> ¶ms, float learning_rate, float beta1, float beta2, float eps) - : Optimizer(params), t_(0), learning_rate_(learning_rate), beta1_(beta1), beta2_(beta2), eps_(eps) { + : Optimizer(params, learning_rate), t_(0), beta1_(beta1), beta2_(beta2), eps_(eps) { for (const auto ¶m : params_) { m_.emplace_back(std::make_shared(param->Dims(), param->Dtype(), param->GetDevice())); From 051486200ed43c29cdb67ba11ead8b4fa4beb416 Mon Sep 17 00:00:00 2001 From: kinorw Date: Tue, 3 Mar 2026 14:43:14 +0800 Subject: [PATCH 02/18] refactor(distributed_optimizer): passthrough SetLearningRate/GetLearningRate --- .../nn/parallel/ddp/distributed_optimizer.h | 3 +++ .../src/nn/parallel/ddp/distributed_optimizer.cc | 14 ++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h index bc31442e..f95801a2 100644 --- a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h +++ b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h @@ -34,6 +34,9 @@ class DistributedOptimizer final : public infini_train::Optimizer { void StartParamSync(bool force_sync = false); void FinishParamSync(bool skip_next_bucket_dispatch = false); + virtual void SetLearningRate(float lr) override; + virtual float GetLearningRate() const override; + private: void BuildShardParamsAndBindGrads(); diff --git a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc index 55e5800b..c7f3de2a 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc +++ 
b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc @@ -114,6 +114,20 @@ void DistributedOptimizer::ZeroGrad(bool set_to_none) { } } +void DistributedOptimizer::SetLearningRate(float lr) { + Optimizer::SetLearningRate(lr); + if (base_optimizer_) { + base_optimizer_->SetLearningRate(lr); + } +} + +float DistributedOptimizer::GetLearningRate() const { + if (base_optimizer_) { + return base_optimizer_->GetLearningRate(); + } + return Optimizer::GetLearningRate(); +} + void DistributedOptimizer::Step() { // 1. Ensure grads are synced FinishGradSync(); From 81295e8fabd474d26d8e66830253c84399cf95d8 Mon Sep 17 00:00:00 2001 From: kinorw Date: Tue, 3 Mar 2026 14:43:31 +0800 Subject: [PATCH 03/18] feat(lr_scheduler): add LRScheduler abstract base class with StateDict --- CMakeLists.txt | 3 + infini_train/include/lr_scheduler.h | 47 +++++++ infini_train/src/lr_scheduler.cc | 45 +++++++ test/lr_scheduler/test_lr_scheduler.cc | 178 +++++++++++++++++++++++++ 4 files changed, 273 insertions(+) create mode 100644 infini_train/include/lr_scheduler.h create mode 100644 infini_train/src/lr_scheduler.cc create mode 100644 test/lr_scheduler/test_lr_scheduler.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 6c160686..35594ea0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -197,3 +197,6 @@ target_link_libraries(test_hook infini_train) add_executable(test_precision_check test/hook/test_precision_check.cc) target_link_libraries(test_precision_check infini_train) + +add_executable(test_lr_scheduler test/lr_scheduler/test_lr_scheduler.cc) +target_link_libraries(test_lr_scheduler infini_train) \ No newline at end of file diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h new file mode 100644 index 00000000..e697b241 --- /dev/null +++ b/infini_train/include/lr_scheduler.h @@ -0,0 +1,47 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace infini_train { + +class Optimizer; + +using StateValue = 
std::variant>; +using StateDict = std::unordered_map; + +class LRScheduler { +public: + explicit LRScheduler(std::shared_ptr optimizer, + int64_t last_step = -1); + + virtual ~LRScheduler() = default; + + LRScheduler(const LRScheduler &) = delete; + LRScheduler &operator=(const LRScheduler &) = delete; + + void Step(); + + float GetLR() const; + + int64_t LastStep() const; + + virtual StateDict State() const; + + virtual void LoadState(const StateDict &state); + +protected: + virtual float ComputeLR() = 0; + + std::shared_ptr optimizer_; + int64_t last_step_; + float current_lr_; + float base_lr_; +}; + +} // namespace infini_train \ No newline at end of file diff --git a/infini_train/src/lr_scheduler.cc b/infini_train/src/lr_scheduler.cc new file mode 100644 index 00000000..b9725d2e --- /dev/null +++ b/infini_train/src/lr_scheduler.cc @@ -0,0 +1,45 @@ +#include "infini_train/include/lr_scheduler.h" + +#include "glog/logging.h" + +#include "infini_train/include/optimizer.h" + +namespace infini_train { + +LRScheduler::LRScheduler(std::shared_ptr optimizer, + int64_t last_step) + : optimizer_(std::move(optimizer)), + last_step_(last_step), + current_lr_(0.0f), + base_lr_(0.0f) { + CHECK(optimizer_) << "LRScheduler: optimizer must not be null."; + base_lr_ = optimizer_->GetLearningRate(); + current_lr_ = base_lr_; +} + +void LRScheduler::Step() { + ++last_step_; + current_lr_ = ComputeLR(); + optimizer_->SetLearningRate(current_lr_); +} + +float LRScheduler::GetLR() const { return current_lr_; } + +int64_t LRScheduler::LastStep() const { return last_step_; } + +StateDict LRScheduler::State() const { + return { + {"last_step", last_step_}, + {"current_lr", current_lr_}, + {"base_lr", base_lr_}, + }; +} + +void LRScheduler::LoadState(const StateDict &state) { + last_step_ = std::get(state.at("last_step")); + current_lr_ = std::get(state.at("current_lr")); + base_lr_ = std::get(state.at("base_lr")); + optimizer_->SetLearningRate(current_lr_); +} + +} // namespace 
infini_train \ No newline at end of file diff --git a/test/lr_scheduler/test_lr_scheduler.cc b/test/lr_scheduler/test_lr_scheduler.cc new file mode 100644 index 00000000..1d912187 --- /dev/null +++ b/test/lr_scheduler/test_lr_scheduler.cc @@ -0,0 +1,178 @@ +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/lr_scheduler.h" +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +using namespace infini_train; + +namespace { + +constexpr float kBaseLR = 0.1f; +constexpr float kEps = 1e-7f; + +class IdentityScheduler : public LRScheduler { +public: + using LRScheduler::LRScheduler; + +protected: + float ComputeLR() override { return base_lr_; } +}; + +class LinearDecayScheduler : public LRScheduler { +public: + LinearDecayScheduler(std::shared_ptr optimizer, + int64_t total_steps, int64_t last_step = -1) + : LRScheduler(std::move(optimizer), last_step), + total_steps_(total_steps) {} + +protected: + float ComputeLR() override { + if (last_step_ >= total_steps_) return 0.0f; + return base_lr_ * (1.0f - static_cast(last_step_) + / static_cast(total_steps_)); + } + +private: + int64_t total_steps_; +}; + +std::shared_ptr MakeDummyOptimizer(float lr) { + std::vector> empty_params; + return std::make_shared(empty_params, lr); +} + +bool FloatEq(float a, float b) { return std::fabs(a - b) < kEps; } + +int g_fail_count = 0; + +void Check(bool cond, const char *expr, int line) { + if (!cond) { + std::cerr << "FAIL [line " << line << "]: " << expr << std::endl; + ++g_fail_count; + } +} + +#define ASSERT_TRUE(cond) Check((cond), #cond, __LINE__) +#define ASSERT_FLOAT_EQ(a, b) \ + Check(FloatEq((a), (b)), #a " == " #b, __LINE__) + + +// T1: Init +void TestInitialState() { + std::cout << "[T1] TestInitialState" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + IdentityScheduler sched(opt); + + ASSERT_FLOAT_EQ(sched.GetLR(), kBaseLR); + ASSERT_TRUE(sched.LastStep() == -1); + 
ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR); +} + +// T2: SingleStep +void TestSingleStep() { + std::cout << "[T2] TestSingleStep" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + IdentityScheduler sched(opt); + + sched.Step(); + + ASSERT_TRUE(sched.LastStep() == 0); + ASSERT_FLOAT_EQ(sched.GetLR(), kBaseLR); + ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR); +} + +// T3: ComputeLR +void TestLinearDecay() { + std::cout << "[T3] TestLinearDecay" << std::endl; + constexpr int64_t kTotalSteps = 10; + auto opt = MakeDummyOptimizer(kBaseLR); + LinearDecayScheduler sched(opt, kTotalSteps); + + sched.Step(); + ASSERT_FLOAT_EQ(sched.GetLR(), kBaseLR); + ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR); + + for (int i = 0; i < 4; ++i) { sched.Step(); } + sched.Step(); + ASSERT_FLOAT_EQ(sched.GetLR(), 0.05f); + ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f); +} + +// T4: State → LoadState 往返一致性。 +void TestStateRoundTrip() { + std::cout << "[T4] TestStateRoundTrip" << std::endl; + constexpr int64_t kTotalSteps = 20; + auto opt = MakeDummyOptimizer(kBaseLR); + LinearDecayScheduler sched(opt, kTotalSteps); + + for (int i = 0; i < 7; ++i) { sched.Step(); } + + StateDict saved = sched.State(); + + ASSERT_TRUE(saved.count("last_step") == 1); + ASSERT_TRUE(saved.count("current_lr") == 1); + ASSERT_TRUE(saved.count("base_lr") == 1); + + auto opt2 = MakeDummyOptimizer(kBaseLR); + LinearDecayScheduler sched2(opt2, kTotalSteps); + sched2.LoadState(saved); + + ASSERT_TRUE(sched2.LastStep() == 6); + ASSERT_FLOAT_EQ(sched2.GetLR(), sched.GetLR()); + ASSERT_FLOAT_EQ(opt2->GetLearningRate(), sched.GetLR()); +} + +// T5: resume Step +void TestResumeAndContinue() { + std::cout << "[T5] TestResumeAndContinue" << std::endl; + constexpr int64_t kTotalSteps = 20; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + LinearDecayScheduler sched_ref(opt_ref, kTotalSteps); + for (int i = 0; i < 10; ++i) { sched_ref.Step(); } + float lr_at_10 = sched_ref.GetLR(); + + auto opt_a = 
MakeDummyOptimizer(kBaseLR); + LinearDecayScheduler sched_a(opt_a, kTotalSteps); + for (int i = 0; i < 5; ++i) { sched_a.Step(); } + StateDict checkpoint = sched_a.State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + LinearDecayScheduler sched_b(opt_b, kTotalSteps); + sched_b.LoadState(checkpoint); + for (int i = 0; i < 5; ++i) { sched_b.Step(); } + + ASSERT_FLOAT_EQ(sched_b.GetLR(), lr_at_10); + ASSERT_TRUE(sched_b.LastStep() == sched_ref.LastStep()); +} + +} // namespace + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + + std::cout << "========================================" << std::endl; + std::cout << " LRScheduler Base Class Tests" << std::endl; + std::cout << "========================================" << std::endl; + + TestInitialState(); + TestSingleStep(); + TestLinearDecay(); + TestStateRoundTrip(); + TestResumeAndContinue(); + + std::cout << "========================================" << std::endl; + if (g_fail_count == 0) { + std::cout << " All Tests PASSED" << std::endl; + } else { + std::cout << " " << g_fail_count << " test(s) FAILED" << std::endl; + } + std::cout << "========================================" << std::endl; + + return g_fail_count > 0 ? 
1 : 0; +} \ No newline at end of file From 8e7cda012fa9693d98ea17453fe76754b60546fe Mon Sep 17 00:00:00 2001 From: kinorw Date: Tue, 3 Mar 2026 14:43:41 +0800 Subject: [PATCH 04/18] refactor(examples): add scheduler placeholder and use runtime lr in logs --- example/gpt2/main.cc | 6 ++++++ example/llama3/main.cc | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 3dfeadd3..e13956b0 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -13,6 +13,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" +#include "infini_train/include/lr_scheduler.h" #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h" @@ -256,6 +257,7 @@ void Train(const nn::parallel::Rank &rank) { // auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate); auto optimizer_creator = optimizers::SGD::Create(FLAGS_learning_rate); std::shared_ptr optimizer = nullptr; + std::shared_ptr scheduler = nullptr; if (FLAGS_use_distributed_optimizer) { auto model_chunks = (pp_world_size > 1) @@ -267,6 +269,10 @@ void Train(const nn::parallel::Rank &rank) { optimizer = optimizer_creator(model->Parameters()); } + if (scheduler) { + scheduler->Step(); + } + auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? 
std::static_pointer_cast( diff --git a/example/llama3/main.cc b/example/llama3/main.cc index a7de81ff..6c9bffcd 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -11,6 +11,7 @@ #include "infini_train/include/core/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" +#include "infini_train/include/lr_scheduler.h" #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h" @@ -235,6 +236,7 @@ void Train(const nn::parallel::Rank &rank) { // auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate); auto optimizer_creator = optimizers::Adam::Create(FLAGS_learning_rate); std::shared_ptr optimizer = nullptr; + std::shared_ptr scheduler = nullptr; if (FLAGS_use_distributed_optimizer) { auto model_chunks = (pp_world_size > 1) @@ -246,6 +248,10 @@ void Train(const nn::parallel::Rank &rank) { optimizer = optimizer_creator(model->Parameters()); } + if (scheduler){ + scheduler->Step(); + } + auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? 
std::static_pointer_cast(std::make_shared()) From 1e658816198b6ad28018fddd386328931f8cc05d Mon Sep 17 00:00:00 2001 From: kinorw Date: Wed, 4 Mar 2026 15:42:03 +0800 Subject: [PATCH 05/18] feat: add ConstantLR, StepLR and LinearWarmupLR --- CMakeLists.txt | 11 ++- infini_train/include/lr_scheduler.h | 44 +++++++++++ infini_train/src/lr_scheduler.cc | 38 +++++++++ test/lr_scheduler/test_constant_lr.cc | 109 ++++++++++++++++++++++++++ test/lr_scheduler/test_helpers.h | 39 +++++++++ test/lr_scheduler/test_linear_lr.cc | 78 ++++++++++++++++++ test/lr_scheduler/test_step_lr.cc | 72 +++++++++++++++++ 7 files changed, 390 insertions(+), 1 deletion(-) create mode 100644 test/lr_scheduler/test_constant_lr.cc create mode 100644 test/lr_scheduler/test_helpers.h create mode 100644 test/lr_scheduler/test_linear_lr.cc create mode 100644 test/lr_scheduler/test_step_lr.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 35594ea0..00337242 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -199,4 +199,13 @@ add_executable(test_precision_check test/hook/test_precision_check.cc) target_link_libraries(test_precision_check infini_train) add_executable(test_lr_scheduler test/lr_scheduler/test_lr_scheduler.cc) -target_link_libraries(test_lr_scheduler infini_train) \ No newline at end of file +target_link_libraries(test_lr_scheduler infini_train) + +add_executable(test_constant_lr test/lr_scheduler/test_constant_lr.cc) +target_link_libraries(test_constant_lr infini_train) + +add_executable(test_step_lr test/lr_scheduler/test_step_lr.cc) +target_link_libraries(test_step_lr infini_train) + +add_executable(test_linear_lr test/lr_scheduler/test_linear_lr.cc) +target_link_libraries(test_linear_lr infini_train) \ No newline at end of file diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h index e697b241..6a76a7d5 100644 --- a/infini_train/include/lr_scheduler.h +++ b/infini_train/include/lr_scheduler.h @@ -1,6 +1,7 @@ #pragma once #include +#include 
#include #include #include @@ -44,4 +45,47 @@ class LRScheduler { float base_lr_; }; +namespace lr_schedulers { +class ConstantLR : public LRScheduler { +public: + ConstantLR(std::shared_ptr optimizer, float factor = 1.0f / 3.0f, int total_iters = 5, + int64_t last_step = -1); + + ~ConstantLR() override = default; + +protected: + float ComputeLR() override ; + +private: + const float factor_; + const int64_t total_iters_; +}; + +class StepLR : public LRScheduler { +public: + StepLR(std::shared_ptr optimizer, int64_t step_size, float gamma = 0.1f, int64_t last_step = -1); + ~StepLR() override = default; + +protected: + float ComputeLR() override; +private: + const int64_t step_size_; + const float gamma_; +}; + +class LinearWarmupLR : public LRScheduler { +public: + LinearWarmupLR(std::shared_ptr optimizer, int64_t warmup_steps, float start_factor = 0.0f, int64_t last_step = -1); + ~LinearWarmupLR() override = default; + +protected: + float ComputeLR() override ; + +private: + const int64_t warmup_steps_; + const float start_factor_; + +}; + +} // namespace lr_schedulers } // namespace infini_train \ No newline at end of file diff --git a/infini_train/src/lr_scheduler.cc b/infini_train/src/lr_scheduler.cc index b9725d2e..5dcd2253 100644 --- a/infini_train/src/lr_scheduler.cc +++ b/infini_train/src/lr_scheduler.cc @@ -42,4 +42,42 @@ void LRScheduler::LoadState(const StateDict &state) { optimizer_->SetLearningRate(current_lr_); } +namespace lr_schedulers { +ConstantLR::ConstantLR(std::shared_ptr optimizer, float factor, int total_iters, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), factor_(factor), total_iters_(total_iters) { + Step(); +} + +float ConstantLR::ComputeLR() { + if(last_step_ < total_iters_) { + return base_lr_ * factor_; + } + return base_lr_; +} + +StepLR::StepLR(std::shared_ptr optimizer, int64_t step_size, float gamma , int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), step_size_(step_size), gamma_(gamma) { + 
Step(); +} + +float StepLR::ComputeLR() { + return base_lr_ * static_cast(std::pow(static_cast(gamma_), + static_cast(last_step_ / step_size_))); +} + +LinearWarmupLR::LinearWarmupLR(std::shared_ptr optimizer, int64_t warmup_steps, float start_factor, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), warmup_steps_(warmup_steps), start_factor_(start_factor) { + Step(); +} + +float LinearWarmupLR::ComputeLR() { + if (last_step_ >= warmup_steps_) { + return base_lr_; + } + float alpha = static_cast(last_step_) / static_cast(warmup_steps_); + return base_lr_ * ( start_factor_ + (1.0f - start_factor_) * alpha); +} + + +} // namespace lr_schedulers } // namespace infini_train \ No newline at end of file diff --git a/test/lr_scheduler/test_constant_lr.cc b/test/lr_scheduler/test_constant_lr.cc new file mode 100644 index 00000000..f997742b --- /dev/null +++ b/test/lr_scheduler/test_constant_lr.cc @@ -0,0 +1,109 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { +constexpr float kBaseLR = 0.1f; +} // namespace + + +void TestInitialState() { + auto opt = MakeDummyOptimizer(kBaseLR); + ConstantLR sched(opt, /*factor=*/0.5f, /*total_iters=*/3); + ASSERT_FLOAT_EQ(sched.GetLR(), 0.05f); + ASSERT_TRUE(sched.LastStep() == 0); + ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f); +} + +void TestFirstStepAppliesFactor() { + auto opt = MakeDummyOptimizer(kBaseLR); + ConstantLR sched(opt, 0.5f, 3); + sched.Step(); // last_step_ = 0 + ASSERT_FLOAT_EQ(sched.GetLR(), 0.05f); + ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f); + ASSERT_TRUE(sched.LastStep() == 1); +} + +void TestWithinTotalIters() { + auto opt = MakeDummyOptimizer(kBaseLR); + ConstantLR sched(opt, 0.5f, 3); + for (int i = 0; i < 2; ++i) sched.Step(); + // last_step_ = 2, still < 3 + ASSERT_FLOAT_EQ(sched.GetLR(), 0.05f); +} + +void TestBeyondTotalIters() { + auto opt = 
MakeDummyOptimizer(kBaseLR); + ConstantLR sched(opt, 0.5f, 3); + for (int i = 0; i < 10; ++i) sched.Step(); + ASSERT_FLOAT_EQ(sched.GetLR(), kBaseLR); + ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR); +} + +void TestPyTorchAlignment() { + const std::vector expected = {0.05f, 0.05f, 0.1f, 0.1f, 0.1f}; + auto opt = MakeDummyOptimizer(kBaseLR); + ConstantLR sched(opt, 0.5f, 3); + for (size_t i = 0; i < expected.size(); ++i) { + sched.Step(); + ASSERT_FLOAT_EQ(sched.GetLR(), expected[i]); + } +} + +void TestStateRoundTrip() { + auto opt = MakeDummyOptimizer(kBaseLR); + ConstantLR sched(opt, 0.5f, 5); + for (int i = 0; i < 3; ++i) sched.Step(); + StateDict saved = sched.State(); + + auto opt2 = MakeDummyOptimizer(kBaseLR); + ConstantLR sched2(opt2, 0.5f, 5); + sched2.LoadState(saved); + + ASSERT_TRUE(sched2.LastStep() == sched.LastStep()); + ASSERT_FLOAT_EQ(sched2.GetLR(), sched.GetLR()); + ASSERT_FLOAT_EQ(opt2->GetLearningRate(), sched.GetLR()); +} + +void TestResumeConsistency() { + constexpr int kN = 8; + constexpr int kK = 3; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + ConstantLR sched_ref(opt_ref, 0.5f, 5); + for (int i = 0; i < kN; ++i) sched_ref.Step(); + + auto opt_a = MakeDummyOptimizer(kBaseLR); + ConstantLR sched_a(opt_a, 0.5f, 5); + for (int i = 0; i < kK; ++i) sched_a.Step(); + StateDict ckpt = sched_a.State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + ConstantLR sched_b(opt_b, 0.5f, 5); + sched_b.LoadState(ckpt); + for (int i = 0; i < kN - kK; ++i) sched_b.Step(); + + ASSERT_FLOAT_EQ(sched_b.GetLR(), sched_ref.GetLR()); + ASSERT_TRUE(sched_b.LastStep() == sched_ref.LastStep()); +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== ConstantLR Tests ===" << std::endl; + TestInitialState(); + TestFirstStepAppliesFactor(); + TestWithinTotalIters(); + TestBeyondTotalIters(); + TestPyTorchAlignment(); + TestStateRoundTrip(); + TestResumeConsistency(); + std::cout << "========================" << 
std::endl; + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} \ No newline at end of file diff --git a/test/lr_scheduler/test_helpers.h b/test/lr_scheduler/test_helpers.h new file mode 100644 index 00000000..43f0a2f5 --- /dev/null +++ b/test/lr_scheduler/test_helpers.h @@ -0,0 +1,39 @@ +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +namespace { + +constexpr float kEps = 1e-6f; + +std::shared_ptr MakeDummyOptimizer(float lr) { + std::vector> empty_params; + return std::make_shared(empty_params, lr); +} + +bool FloatNear(float a, float b, float eps = kEps) { + return std::fabs(a - b) < eps; +} + +int g_fail_count = 0; + +void Check(bool cond, const char *expr, int line) { + if (!cond) { + std::cerr << "FAIL [line " << line << "]: " << expr << std::endl; + ++g_fail_count; + } +} + +#define ASSERT_TRUE(cond) Check((cond), #cond, __LINE__) +#define ASSERT_FLOAT_EQ(a, b) \ + Check(FloatNear((a), (b)), #a " == " #b, __LINE__) +#define ASSERT_FLOAT_NEAR(a, b, eps) \ + Check(FloatNear((a), (b), (eps)), #a " ≈ " #b, __LINE__) + +} // namespace \ No newline at end of file diff --git a/test/lr_scheduler/test_linear_lr.cc b/test/lr_scheduler/test_linear_lr.cc new file mode 100644 index 00000000..8d6981ed --- /dev/null +++ b/test/lr_scheduler/test_linear_lr.cc @@ -0,0 +1,78 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { +constexpr float kBaseLR = 0.1f; +} + +void TestFirstStepFromZero() { + auto opt = MakeDummyOptimizer(kBaseLR); + LinearWarmupLR sched(opt, /*warmup_steps=*/5, /*start_factor=*/0.0f); + sched.Step(); // last_step_=1, alpha=1/5=0.2 → lr=0.02 + ASSERT_FLOAT_EQ(sched.GetLR(), 
0.02f); +} + +void TestMidpointLR() { + auto opt = MakeDummyOptimizer(kBaseLR); + LinearWarmupLR sched(opt, 5, 0.0f); + for (int i = 0; i < 3; ++i) sched.Step(); + // last_step_=2, alpha=3/5=0.6 → lr=0.1*0.6=0.06 + ASSERT_FLOAT_EQ(sched.GetLR(), 0.06f); +} + +void TestWarmupEnd() { + auto opt = MakeDummyOptimizer(kBaseLR); + LinearWarmupLR sched(opt, 5, 0.0f); + for (int i = 0; i < 5; ++i) sched.Step(); + // last_step_=5, 5 >= 5 → base_lr + ASSERT_FLOAT_EQ(sched.GetLR(), kBaseLR); +} + +void TestBeyondWarmup() { + auto opt = MakeDummyOptimizer(kBaseLR); + LinearWarmupLR sched(opt, 5, 0.0f); + for (int i = 0; i < 20; ++i) sched.Step(); + ASSERT_FLOAT_EQ(sched.GetLR(), kBaseLR); +} + +void TestCustomStartFactor() { + auto opt = MakeDummyOptimizer(kBaseLR); + LinearWarmupLR sched(opt, 4, /*start_factor=*/0.25f); + sched.Step(); // last_step_=1, alpha=1/4=0.25 → lr=0.1*(0.25+0.75*0.25)=0.04375 + ASSERT_FLOAT_NEAR(sched.GetLR(), 0.04375f, 1e-6f); + sched.Step(); // last_step_=2, alpha=2/4=0.5 → lr=0.1*(0.25+0.75*0.5)=0.0625 + ASSERT_FLOAT_NEAR(sched.GetLR(), 0.0625f, 1e-6f); +} + +void TestPyTorchAlignment() { + const std::vector expected = { + 0.02f, 0.04f, 0.06f, 0.08f, 0.1f, 0.1f, 0.1f}; + auto opt = MakeDummyOptimizer(kBaseLR); + LinearWarmupLR sched(opt, 5, 0.0f); + for (size_t i = 0; i < expected.size(); ++i) { + sched.Step(); + ASSERT_FLOAT_NEAR(sched.GetLR(), expected[i], 1e-7f); + } +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== Linear Tests ===" << std::endl; + TestFirstStepFromZero(); + TestMidpointLR(); + TestWarmupEnd(); + TestBeyondWarmup(); + TestCustomStartFactor(); + TestPyTorchAlignment(); + + std::cout << "========================" << std::endl; + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 
1 : 0; +} \ No newline at end of file diff --git a/test/lr_scheduler/test_step_lr.cc b/test/lr_scheduler/test_step_lr.cc new file mode 100644 index 00000000..b7c1466e --- /dev/null +++ b/test/lr_scheduler/test_step_lr.cc @@ -0,0 +1,72 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { +constexpr float kBaseLR = 0.1f; +} + +void TestWithinFirstPeriod() { + auto opt = MakeDummyOptimizer(kBaseLR); + StepLR sched(opt, /*step_size=*/3, /*gamma=*/0.1f); + for (int i = 0; i < 2; ++i) { + sched.Step(); + ASSERT_FLOAT_EQ(sched.GetLR(), kBaseLR); // last_step 1,2 → 指数 0 + } +} + +void TestFirstDecay() { + auto opt = MakeDummyOptimizer(kBaseLR); + StepLR sched(opt, 3, 0.1f); + for (int i = 0; i < 3; ++i) sched.Step(); + // last_step=3, 3//3=1 → 0.1^1 = 0.1 → lr=0.01 + ASSERT_FLOAT_EQ(sched.GetLR(), 0.01f); +} + +void TestMultipleDecays() { + auto opt = MakeDummyOptimizer(kBaseLR); + StepLR sched(opt, 3, 0.1f); + for (int i = 0; i < 6; ++i) sched.Step(); + // last_step=6, 6//3=2 → 0.1^2 = 0.01 → lr=0.001 + ASSERT_FLOAT_NEAR(sched.GetLR(), 0.001f, 1e-7f); +} + +void TestPyTorchAlignment() { + const std::vector expected = { + 0.1f, 0.1f, 0.01f, 0.01f, 0.01f, 0.001f, 0.001f}; + auto opt = MakeDummyOptimizer(kBaseLR); + StepLR sched(opt, 3, 0.1f); + for (size_t i = 0; i < expected.size(); ++i) { + sched.Step(); + ASSERT_FLOAT_NEAR(sched.GetLR(), expected[i], 1e-7f); + } +} + +void TestGammaOne() { + auto opt = MakeDummyOptimizer(kBaseLR); + StepLR sched(opt, 5, 1.0f); + for (int i = 0; i < 20; ++i) { + sched.Step(); + ASSERT_FLOAT_EQ(sched.GetLR(), kBaseLR); + } +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== Step Tests ===" << std::endl; + TestWithinFirstPeriod(); + TestFirstDecay(); + TestMultipleDecays(); + TestPyTorchAlignment(); + TestGammaOne(); + + std::cout << "========================" 
<< std::endl; + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} \ No newline at end of file From d924d3dfdb8c4f50eef106918c3889f938350c61 Mon Sep 17 00:00:00 2001 From: kinorw Date: Thu, 5 Mar 2026 14:15:16 +0800 Subject: [PATCH 06/18] refactor(lr_scheduler): replace ComputeLR with virtual Step and ApplyLR helper --- infini_train/include/lr_scheduler.h | 20 ++++++++-------- infini_train/src/lr_scheduler.cc | 32 ++++++++++++++------------ test/lr_scheduler/test_lr_scheduler.cc | 16 +++++++------ 3 files changed, 36 insertions(+), 32 deletions(-) diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h index 6a76a7d5..acf33d34 100644 --- a/infini_train/include/lr_scheduler.h +++ b/infini_train/include/lr_scheduler.h @@ -26,18 +26,20 @@ class LRScheduler { LRScheduler(const LRScheduler &) = delete; LRScheduler &operator=(const LRScheduler &) = delete; - void Step(); + virtual void Step() = 0; float GetLR() const; - + float BaseLR() const; int64_t LastStep() const; - virtual StateDict State() const; + void ResetStep(int64_t step = -1); + virtual StateDict State() const; virtual void LoadState(const StateDict &state); protected: - virtual float ComputeLR() = 0; + + void ApplyLR(float lr); std::shared_ptr optimizer_; int64_t last_step_; @@ -53,8 +55,7 @@ class ConstantLR : public LRScheduler { ~ConstantLR() override = default; -protected: - float ComputeLR() override ; + void Step() override; private: const float factor_; @@ -66,8 +67,8 @@ class StepLR : public LRScheduler { StepLR(std::shared_ptr optimizer, int64_t step_size, float gamma = 0.1f, int64_t last_step = -1); ~StepLR() override = default; -protected: - float ComputeLR() override; + void Step() override; + private: const int64_t step_size_; const float gamma_; @@ -78,8 +79,7 @@ class LinearWarmupLR : public LRScheduler { 
LinearWarmupLR(std::shared_ptr optimizer, int64_t warmup_steps, float start_factor = 0.0f, int64_t last_step = -1); ~LinearWarmupLR() override = default; -protected: - float ComputeLR() override ; + void Step() override; private: const int64_t warmup_steps_; diff --git a/infini_train/src/lr_scheduler.cc b/infini_train/src/lr_scheduler.cc index 5dcd2253..1db88ade 100644 --- a/infini_train/src/lr_scheduler.cc +++ b/infini_train/src/lr_scheduler.cc @@ -17,9 +17,8 @@ LRScheduler::LRScheduler(std::shared_ptr optimizer, current_lr_ = base_lr_; } -void LRScheduler::Step() { - ++last_step_; - current_lr_ = ComputeLR(); +void LRScheduler::ApplyLR(float lr) { + current_lr_ = lr; optimizer_->SetLearningRate(current_lr_); } @@ -48,11 +47,10 @@ ConstantLR::ConstantLR(std::shared_ptr optimizer, float factor, int t Step(); } -float ConstantLR::ComputeLR() { - if(last_step_ < total_iters_) { - return base_lr_ * factor_; - } - return base_lr_; +void ConstantLR::Step() { + ++last_step_; + ApplyLR( + last_step_ < total_iters_ ? 
base_lr_ * factor_ : base_lr_); } StepLR::StepLR(std::shared_ptr optimizer, int64_t step_size, float gamma , int64_t last_step) @@ -60,9 +58,11 @@ StepLR::StepLR(std::shared_ptr optimizer, int64_t step_size, float ga Step(); } -float StepLR::ComputeLR() { - return base_lr_ * static_cast(std::pow(static_cast(gamma_), - static_cast(last_step_ / step_size_))); +void StepLR::Step() { + ++last_step_; + ApplyLR(base_lr_ * static_cast( + std::pow(static_cast(gamma_), + static_cast(last_step_ / step_size_)))); } LinearWarmupLR::LinearWarmupLR(std::shared_ptr optimizer, int64_t warmup_steps, float start_factor, int64_t last_step) @@ -70,12 +70,14 @@ LinearWarmupLR::LinearWarmupLR(std::shared_ptr optimizer, int64_t war Step(); } -float LinearWarmupLR::ComputeLR() { +void LinearWarmupLR::Step() { + ++last_step_; if (last_step_ >= warmup_steps_) { - return base_lr_; + ApplyLR(base_lr_); + } else{ + float alpha = static_cast(last_step_) / static_cast(warmup_steps_); + ApplyLR(base_lr_ * ( start_factor_ + (1.0f - start_factor_) * alpha)); } - float alpha = static_cast(last_step_) / static_cast(warmup_steps_); - return base_lr_ * ( start_factor_ + (1.0f - start_factor_) * alpha); } diff --git a/test/lr_scheduler/test_lr_scheduler.cc b/test/lr_scheduler/test_lr_scheduler.cc index 1d912187..aae8bff7 100644 --- a/test/lr_scheduler/test_lr_scheduler.cc +++ b/test/lr_scheduler/test_lr_scheduler.cc @@ -19,8 +19,10 @@ class IdentityScheduler : public LRScheduler { public: using LRScheduler::LRScheduler; -protected: - float ComputeLR() override { return base_lr_; } + void Step() override { + ++last_step_; + ApplyLR(base_lr_); + } }; class LinearDecayScheduler : public LRScheduler { @@ -30,11 +32,11 @@ class LinearDecayScheduler : public LRScheduler { : LRScheduler(std::move(optimizer), last_step), total_steps_(total_steps) {} -protected: - float ComputeLR() override { - if (last_step_ >= total_steps_) return 0.0f; - return base_lr_ * (1.0f - static_cast(last_step_) - / 
static_cast(total_steps_)); + void Step() override { + ++last_step_; + ApplyLR( + last_step_ >= total_steps_ ? 0.0f : base_lr_ * (1.0f - static_cast(last_step_) + / static_cast(total_steps_))); } private: From baca2ef58bab79bf11908200925d011d7f6e067e Mon Sep 17 00:00:00 2001 From: kinorw Date: Thu, 5 Mar 2026 14:39:07 +0800 Subject: [PATCH 07/18] feat(lr_schedulers): add LambdaLR strategy --- CMakeLists.txt | 5 +- infini_train/include/lr_scheduler.h | 15 ++++ infini_train/src/lr_scheduler.cc | 9 +++ test/lr_scheduler/test_lambda_lr.cc | 105 ++++++++++++++++++++++++++++ 4 files changed, 133 insertions(+), 1 deletion(-) create mode 100644 test/lr_scheduler/test_lambda_lr.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 00337242..eb20396a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -208,4 +208,7 @@ add_executable(test_step_lr test/lr_scheduler/test_step_lr.cc) target_link_libraries(test_step_lr infini_train) add_executable(test_linear_lr test/lr_scheduler/test_linear_lr.cc) -target_link_libraries(test_linear_lr infini_train) \ No newline at end of file +target_link_libraries(test_linear_lr infini_train) + +add_executable(test_lambda_lr test/lr_scheduler/test_lambda_lr.cc) +target_link_libraries(test_lambda_lr infini_train) \ No newline at end of file diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h index acf33d34..f88584ff 100644 --- a/infini_train/include/lr_scheduler.h +++ b/infini_train/include/lr_scheduler.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -87,5 +88,19 @@ class LinearWarmupLR : public LRScheduler { }; +class LambdaLR : public LRScheduler { +public: + using LambdaFunc = std::function; + + LambdaLR(std::shared_ptr optimizer, LambdaFunc lr_lambda, int64_t last_step = -1); + ~LambdaLR() override = default; + + void Step() override; + +private: + const LambdaFunc lr_lambda_; + +}; + } // namespace lr_schedulers } // namespace infini_train \ No newline at end of file diff --git 
a/infini_train/src/lr_scheduler.cc b/infini_train/src/lr_scheduler.cc index 1db88ade..f20b084d 100644 --- a/infini_train/src/lr_scheduler.cc +++ b/infini_train/src/lr_scheduler.cc @@ -80,6 +80,15 @@ void LinearWarmupLR::Step() { } } +LambdaLR::LambdaLR(std::shared_ptr optimizer, std::function lr_lambda, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), lr_lambda_(std::move(lr_lambda)) { + Step(); +} + +void LambdaLR::Step() { + ++last_step_; + ApplyLR(base_lr_ * lr_lambda_(last_step_)); +} } // namespace lr_schedulers } // namespace infini_train \ No newline at end of file diff --git a/test/lr_scheduler/test_lambda_lr.cc b/test/lr_scheduler/test_lambda_lr.cc new file mode 100644 index 00000000..37b0af23 --- /dev/null +++ b/test/lr_scheduler/test_lambda_lr.cc @@ -0,0 +1,105 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { +constexpr float kBaseLR = 0.1f; +} // namespace + +void TestIdentityLambda() { + auto opt = MakeDummyOptimizer(kBaseLR); + LambdaLR sched(opt, [](int64_t) { return 1.0f; }); + // 构造器内 Step() → last_step_=0, lr = 0.1 * 1.0 = 0.1 + ASSERT_TRUE(sched.LastStep() == 0); + ASSERT_FLOAT_NEAR(sched.GetLR(), kBaseLR, kEps); + ASSERT_FLOAT_NEAR(opt->GetLearningRate(), kBaseLR, kEps); +} + +void TestLinearDecayLambda() { + auto opt = MakeDummyOptimizer(kBaseLR); + LambdaLR sched(opt, [](int64_t step) { return 1.0f - step * 0.1f; }); + // step=0, lambda(0)=1.0, lr=0.1 + ASSERT_FLOAT_NEAR(sched.GetLR(), 0.1f, kEps); + + sched.Step(); // step=1, lambda(1)=0.9, lr=0.09 + ASSERT_FLOAT_NEAR(sched.GetLR(), 0.09f, kEps); + + sched.Step(); // step=2, lambda(2)=0.8, lr=0.08 + ASSERT_FLOAT_NEAR(sched.GetLR(), 0.08f, kEps); + + sched.Step(); // step=3, lambda(3)=0.7, lr=0.07 + ASSERT_FLOAT_NEAR(sched.GetLR(), 0.07f, kEps); +} + +void TestPyTorchAlignment() { + // PyTorch: LambdaLR(opt, lr_lambda=lambda epoch: 
0.95**epoch) + auto opt = MakeDummyOptimizer(kBaseLR); + LambdaLR sched(opt, [](int64_t step) { + return static_cast(std::pow(0.95, step)); + }); + // step=0, lr = 0.1 * 0.95^0 = 0.1 + ASSERT_FLOAT_NEAR(sched.GetLR(), 0.1f, kEps); + + std::vector expected = {0.095f, 0.09025f, 0.0857375f}; + for (size_t i = 0; i < expected.size(); ++i) { + sched.Step(); + ASSERT_FLOAT_NEAR(sched.GetLR(), expected[i], 1e-5f); + } +} + +void TestStateRoundTrip() { + auto lambda_fn = [](int64_t step) { return 1.0f - step * 0.05f; }; + auto opt = MakeDummyOptimizer(kBaseLR); + LambdaLR sched(opt, lambda_fn); + for (int i = 0; i < 5; ++i) sched.Step(); + StateDict saved = sched.State(); + + auto opt2 = MakeDummyOptimizer(kBaseLR); + LambdaLR sched2(opt2, lambda_fn); // same lambda + sched2.LoadState(saved); + + ASSERT_TRUE(sched2.LastStep() == sched.LastStep()); + ASSERT_FLOAT_NEAR(sched2.GetLR(), sched.GetLR(), kEps); +} + +void TestResumeConsistency() { + auto lambda_fn = [](int64_t step) { return 1.0f - step * 0.05f; }; + constexpr int kN = 10, kK = 4; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + LambdaLR sched_ref(opt_ref, lambda_fn); + for (int i = 0; i < kN; ++i) sched_ref.Step(); + + auto opt_a = MakeDummyOptimizer(kBaseLR); + LambdaLR sched_a(opt_a, lambda_fn); + for (int i = 0; i < kK; ++i) sched_a.Step(); + StateDict ckpt = sched_a.State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + LambdaLR sched_b(opt_b, lambda_fn); + sched_b.LoadState(ckpt); + for (int i = 0; i < kN - kK; ++i) sched_b.Step(); + + ASSERT_FLOAT_NEAR(sched_b.GetLR(), sched_ref.GetLR(), kEps); + ASSERT_TRUE(sched_b.LastStep() == sched_ref.LastStep()); +} + + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== LambdaLR Tests ===" << std::endl; + TestIdentityLambda(); + TestLinearDecayLambda(); + TestPyTorchAlignment(); + TestStateRoundTrip(); + TestResumeConsistency(); + std::cout << "======================" << std::endl; + if (g_fail_count == 0) { + std::cout 
<< "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} \ No newline at end of file From 7df75d7be78241abf7050a972a0b8f78b3c2aab2 Mon Sep 17 00:00:00 2001 From: kinorw Date: Thu, 5 Mar 2026 17:18:07 +0800 Subject: [PATCH 08/18] refactor(optimizer): add initial_learning_rate and its accessors --- infini_train/include/optimizer.h | 6 ++++++ infini_train/src/optimizer.cc | 13 +++++++++++++ 2 files changed, 19 insertions(+) diff --git a/infini_train/include/optimizer.h b/infini_train/include/optimizer.h index 2ca7a054..ee879215 100644 --- a/infini_train/include/optimizer.h +++ b/infini_train/include/optimizer.h @@ -25,9 +25,15 @@ class Optimizer { virtual float GetLearningRate() const; + float GetInitialLearningRate() const; + + void SetInitialLearningRate(float lr); + protected: std::vector> params_; float learning_rate_ = 0.0f; + float initial_learning_rate_ = 0.0f; + bool initial_lr_set_ = false; }; namespace optimizers { diff --git a/infini_train/src/optimizer.cc b/infini_train/src/optimizer.cc index a3830ce4..2afaa694 100644 --- a/infini_train/src/optimizer.cc +++ b/infini_train/src/optimizer.cc @@ -19,6 +19,19 @@ void Optimizer::SetLearningRate(float lr) { learning_rate_ = lr; } float Optimizer::GetLearningRate() const { return learning_rate_; } +float Optimizer::GetInitialLearningRate() const { + CHECK(initial_lr_set_) + << "Optimizer: initial_learning_rate not set. 
" + "Use with an LRScheduler first."; + return initial_learning_rate_; +} + +void Optimizer::SetInitialLearningRate(float lr) { + if (!initial_lr_set_) { + initial_learning_rate_ = lr; + initial_lr_set_ = true; + } +} namespace optimizers { SGD::SGD(const std::vector> ¶ms, float learning_rate) From d0ac538eea37daa38a1d2e3f806e9d785fcb2bc8 Mon Sep 17 00:00:00 2001 From: kinorw Date: Thu, 5 Mar 2026 17:25:07 +0800 Subject: [PATCH 09/18] feat(lr_schedulers): add SequentialLR composite strategy --- CMakeLists.txt | 5 +- infini_train/include/lr_scheduler.h | 20 ++++ infini_train/src/lr_scheduler.cc | 86 ++++++++++++++- test/lr_scheduler/test_sequential_lr.cc | 136 ++++++++++++++++++++++++ 4 files changed, 245 insertions(+), 2 deletions(-) create mode 100644 test/lr_scheduler/test_sequential_lr.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index eb20396a..02c964f6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -211,4 +211,7 @@ add_executable(test_linear_lr test/lr_scheduler/test_linear_lr.cc) target_link_libraries(test_linear_lr infini_train) add_executable(test_lambda_lr test/lr_scheduler/test_lambda_lr.cc) -target_link_libraries(test_lambda_lr infini_train) \ No newline at end of file +target_link_libraries(test_lambda_lr infini_train) + +add_executable(test_sequential_lr test/lr_scheduler/test_sequential_lr.cc) +target_link_libraries(test_sequential_lr infini_train) \ No newline at end of file diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h index f88584ff..687fe66e 100644 --- a/infini_train/include/lr_scheduler.h +++ b/infini_train/include/lr_scheduler.h @@ -102,5 +102,25 @@ class LambdaLR : public LRScheduler { }; + +class SequentialLR : public LRScheduler { +public: + SequentialLR(std::shared_ptr optimizer, + std::vector> schedulers, + std::vector milestones, + int64_t last_step = -1); + + ~SequentialLR() override = default; + + void Step() override; + StateDict State() const override; + void LoadState(const StateDict 
&state) override; + +private: + std::vector> schedulers_; + std::vector milestones_; +}; + + } // namespace lr_schedulers } // namespace infini_train \ No newline at end of file diff --git a/infini_train/src/lr_scheduler.cc b/infini_train/src/lr_scheduler.cc index f20b084d..c18ad576 100644 --- a/infini_train/src/lr_scheduler.cc +++ b/infini_train/src/lr_scheduler.cc @@ -13,7 +13,8 @@ LRScheduler::LRScheduler(std::shared_ptr optimizer, current_lr_(0.0f), base_lr_(0.0f) { CHECK(optimizer_) << "LRScheduler: optimizer must not be null."; - base_lr_ = optimizer_->GetLearningRate(); + optimizer_->SetInitialLearningRate(optimizer_->GetLearningRate()); + base_lr_ = optimizer_->GetInitialLearningRate(); current_lr_ = base_lr_; } @@ -24,8 +25,12 @@ void LRScheduler::ApplyLR(float lr) { float LRScheduler::GetLR() const { return current_lr_; } +float LRScheduler::BaseLR() const { return base_lr_; } + int64_t LRScheduler::LastStep() const { return last_step_; } +void LRScheduler::ResetStep(int64_t step) { last_step_ = step; } + StateDict LRScheduler::State() const { return { {"last_step", last_step_}, @@ -90,5 +95,84 @@ void LambdaLR::Step() { ApplyLR(base_lr_ * lr_lambda_(last_step_)); } +SequentialLR::SequentialLR(std::shared_ptr optimizer, + std::vector> schedulers, + std::vectormilestones, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)), milestones_(std::move(milestones)) { + CHECK(!schedulers_.empty()) << "SequentialLR requires at least one scheduler."; + CHECK_EQ(milestones_.size(), schedulers_.size() - 1) + << "SequentialLR: milestones count must be schedulers count - 1."; + + for(size_t i = 1; i < milestones_.size(); ++i) { + CHECK_GT(milestones_[i], milestones_[i-1]) << "Milestones must be strictly increasing."; + } + + optimizer_->SetLearningRate(schedulers_[0]->BaseLR()); + + // Reset all schedulers to the same last_step so they are in sync when Step() is called. 
+ for (auto &sched : schedulers_) { + sched->ResetStep(sched->LastStep()-1); + } + + Step(); +} + +void SequentialLR::Step() { + ++last_step_; + size_t idx = 0; + for (size_t i = 0; i < milestones_.size(); ++i) { + if (last_step_ >= milestones_[i]) { + idx = i + 1; + } else { + break; + } + } + + auto &scheduler = schedulers_[idx]; + + if (idx > 0 && milestones_[idx - 1] == last_step_) { + scheduler->ResetStep(-1); + scheduler->Step(); + } else { + scheduler->Step(); + } + + current_lr_ = scheduler->GetLR(); +} + +StateDict SequentialLR::State() const { + StateDict state; + state["last_step"] = last_step_; + state["current_lr"] = current_lr_; + state["base_lr"] = base_lr_; + for (size_t i = 0; i < schedulers_.size(); ++i) { + auto sub_state = schedulers_[i]->State(); + for (const auto &[key, value] : sub_state) { + state["scheduler_" + std::to_string(i) + "." + key] = value; + } + } + return state; +} + +void SequentialLR::LoadState(const StateDict &state) { + last_step_ = std::get(state.at("last_step")); + current_lr_ = std::get(state.at("current_lr")); + base_lr_ = std::get(state.at("base_lr")); + + for (size_t i = 0; i < schedulers_.size(); ++i) { + StateDict sub_state; + std::string prefix = "scheduler_" + std::to_string(i) + "."; + for (const auto &[key, value] : state) { + if (key.substr(0, prefix.size()) == prefix) { + sub_state[key.substr(prefix.size())] = value; + } + } + if(!sub_state.empty()) + schedulers_[i]->LoadState(sub_state); + } + optimizer_->SetLearningRate(current_lr_); +} + + } // namespace lr_schedulers } // namespace infini_train \ No newline at end of file diff --git a/test/lr_scheduler/test_sequential_lr.cc b/test/lr_scheduler/test_sequential_lr.cc new file mode 100644 index 00000000..2a29a208 --- /dev/null +++ b/test/lr_scheduler/test_sequential_lr.cc @@ -0,0 +1,136 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; 
+namespace { +constexpr float kBaseLR = 0.1f; +} // namespace + +void TestWarmupThenConstant() { + std::cout << "[TC1] TestWarmupThenConstant" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + + auto warmup = std::make_shared(opt, /*warmup_steps=*/3, /*start_factor=*/1e-8); + auto constant = std::make_shared(opt, /*factor=*/1.0f, /*total_iters=*/100); + + SequentialLR sched(opt, {warmup, constant}, {3}); + + ASSERT_FLOAT_NEAR(sched.GetLR(), 0.0f, kEps); + + sched.Step(); // global=1, warmup step=1, lr=0.1*(1/3) + ASSERT_FLOAT_NEAR(sched.GetLR(), 0.1f / 3.0f, 1e-5f); + + sched.Step(); // global=2, warmup step=2, lr=0.1*(2/3) + ASSERT_FLOAT_NEAR(sched.GetLR(), 0.2f / 3.0f, 1e-5f); + + sched.Step(); // global=3, constant step=0, lr=0.1*1.0=0.1 + ASSERT_FLOAT_NEAR(sched.GetLR(), kBaseLR, kEps); + + sched.Step(); // global=4, constant step=1, lr=0.1 + ASSERT_FLOAT_NEAR(sched.GetLR(), kBaseLR, kEps); +} + +void TestWarmupThenStepLR() { + std::cout << "[TC2] TestWarmupThenStepLR" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + + auto warmup = std::make_shared(opt, 3, 0.0f); + auto step_lr = std::make_shared(opt, /*step_size=*/3, /*gamma=*/0.5f); + + SequentialLR sched(opt, {warmup, step_lr}, {3}); + + sched.Step(); // global=1 + sched.Step(); // global=2 + + sched.Step(); // global=3, StepLR step=0, lr=0.1 + ASSERT_FLOAT_NEAR(sched.GetLR(), 0.1f, kEps); + + sched.Step(); // global=4, StepLR step=1 + sched.Step(); // global=5, StepLR step=2 + sched.Step(); // global=6, StepLR step=3, 3//3=1, lr=0.1*0.5=0.05 + ASSERT_FLOAT_NEAR(sched.GetLR(), 0.05f, kEps); +} + +void TestWarmupThenStepThenConstant(){ + std::cout << "[TC3] TestWarmupThenStepThenConstant" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + + auto warmup = std::make_shared(opt, 3, 0.0f); + auto step_lr = std::make_shared(opt, 3, 0.5f); + auto constant = std::make_shared(opt, 0.5f, 2); + + SequentialLR sched(opt, {warmup, step_lr, constant}, {3, 6}); + const std::vector expected = { + 
0.033333f, 0.066667f, 0.1f, 0.1f, 0.1f, 0.05f, 0.05f, 0.1f, 0.1f, 0.1f}; + for (size_t i = 0; i < expected.size(); ++i) { + sched.Step(); + ASSERT_FLOAT_NEAR(sched.GetLR(), expected[i], 1e-5f); + } +} + +void TestStateRoundTrip() { + std::cout << "[TC4] TestStateRoundTrip" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto warmup = std::make_shared(opt, 3, 0.0f); + auto step_lr = std::make_shared(opt, 3, 0.5f); + SequentialLR sched(opt, {warmup, step_lr}, {3}); + + for (int i = 0; i < 5; ++i) sched.Step(); + StateDict saved = sched.State(); + + auto opt2 = MakeDummyOptimizer(kBaseLR); + auto warmup2 = std::make_shared(opt2, 3, 0.0f); + auto step_lr2 = std::make_shared(opt2, 3, 0.5f); + SequentialLR sched2(opt2, {warmup2, step_lr2}, {3}); + sched2.LoadState(saved); + + ASSERT_TRUE(sched2.LastStep() == sched.LastStep()); + ASSERT_FLOAT_NEAR(sched2.GetLR(), sched.GetLR(), kEps); +} + +void TestResumeConsistency() { + std::cout << "[TC5] TestResumeConsistency" << std::endl; + constexpr int kN = 10, kK = 4; + + auto make_sched = [](std::shared_ptr opt) { + auto warmup = std::make_shared(opt, 3, 0.0f); + auto step_lr = std::make_shared(opt, 3, 0.5f); + return std::make_unique(opt, + std::vector>{warmup, step_lr}, + std::vector{3}); + }; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = make_sched(opt_ref); + for (int i = 0; i < kN; ++i) sched_ref->Step(); + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = make_sched(opt_a); + for (int i = 0; i < kK; ++i) sched_a->Step(); + StateDict ckpt = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = make_sched(opt_b); + sched_b->LoadState(ckpt); + for (int i = 0; i < kN - kK; ++i) sched_b->Step(); + + ASSERT_FLOAT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); + ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== SequentialLR Tests ===" << std::endl; + 
TestWarmupThenConstant(); + TestWarmupThenStepLR(); + TestWarmupThenStepThenConstant(); + TestStateRoundTrip(); + TestResumeConsistency(); + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} \ No newline at end of file From df4c68d7a3e5b687a35c66095d77351349643a3c Mon Sep 17 00:00:00 2001 From: kinorw Date: Fri, 6 Mar 2026 00:23:46 +0800 Subject: [PATCH 10/18] refactor(lr_scheduler): apply template method pattern to LRScheduler base class - Change Step() to virtual with default implementation - Add pure virtual ComputeLR() for subclasses to implement. - Adapt test helpers (IdentityScheduler, LinearDecayScheduler) to implement ComputeLR() instead of Step(). - All existing tests pass without behavioral changes. BREAKING CHANGE: Subclasses must implement ComputeLR() instead of Step(). --- infini_train/include/lr_scheduler.h | 46 ++++++++----- infini_train/src/lr_scheduler.cc | 93 +++++++++++++++++--------- test/lr_scheduler/test_lr_scheduler.cc | 18 ++--- 3 files changed, 102 insertions(+), 55 deletions(-) diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h index 687fe66e..43ee67b5 100644 --- a/infini_train/include/lr_scheduler.h +++ b/infini_train/include/lr_scheduler.h @@ -27,7 +27,7 @@ class LRScheduler { LRScheduler(const LRScheduler &) = delete; LRScheduler &operator=(const LRScheduler &) = delete; - virtual void Step() = 0; + virtual void Step(); float GetLR() const; float BaseLR() const; @@ -40,6 +40,8 @@ class LRScheduler { protected: + virtual float ComputeLR() const = 0; + void ApplyLR(float lr); std::shared_ptr optimizer_; @@ -49,14 +51,17 @@ class LRScheduler { }; namespace lr_schedulers { + class ConstantLR : public LRScheduler { public: - ConstantLR(std::shared_ptr optimizer, float factor = 1.0f / 3.0f, int total_iters = 5, - int64_t last_step = -1); - + ConstantLR(std::shared_ptr 
optimizer, + float factor = 1.0f / 3.0f, + int total_iters = 5, + int64_t last_step = -1); ~ConstantLR() override = default; - void Step() override; +protected: + float ComputeLR() const override; private: const float factor_; @@ -65,10 +70,14 @@ class ConstantLR : public LRScheduler { class StepLR : public LRScheduler { public: - StepLR(std::shared_ptr optimizer, int64_t step_size, float gamma = 0.1f, int64_t last_step = -1); + StepLR(std::shared_ptr optimizer, + int64_t step_size, + float gamma = 0.1f, + int64_t last_step = -1); ~StepLR() override = default; - void Step() override; +protected: + float ComputeLR() const override; private: const int64_t step_size_; @@ -77,29 +86,34 @@ class StepLR : public LRScheduler { class LinearWarmupLR : public LRScheduler { public: - LinearWarmupLR(std::shared_ptr optimizer, int64_t warmup_steps, float start_factor = 0.0f, int64_t last_step = -1); + LinearWarmupLR(std::shared_ptr optimizer, + int64_t warmup_steps, + float start_factor = 0.0f, + int64_t last_step = -1); ~LinearWarmupLR() override = default; - void Step() override; +protected: + float ComputeLR() const override; private: const int64_t warmup_steps_; const float start_factor_; - }; class LambdaLR : public LRScheduler { public: using LambdaFunc = std::function; - LambdaLR(std::shared_ptr optimizer, LambdaFunc lr_lambda, int64_t last_step = -1); + LambdaLR(std::shared_ptr optimizer, + LambdaFunc lr_lambda, + int64_t last_step = -1); ~LambdaLR() override = default; - void Step() override; +protected: + float ComputeLR() const override; private: const LambdaFunc lr_lambda_; - }; @@ -109,18 +123,20 @@ class SequentialLR : public LRScheduler { std::vector> schedulers, std::vector milestones, int64_t last_step = -1); - ~SequentialLR() override = default; void Step() override; + StateDict State() const override; void LoadState(const StateDict &state) override; + +protected: + float ComputeLR() const override { return 0.0f; } private: std::vector> schedulers_; 
std::vector milestones_; }; - } // namespace lr_schedulers } // namespace infini_train \ No newline at end of file diff --git a/infini_train/src/lr_scheduler.cc b/infini_train/src/lr_scheduler.cc index c18ad576..4978101b 100644 --- a/infini_train/src/lr_scheduler.cc +++ b/infini_train/src/lr_scheduler.cc @@ -18,6 +18,11 @@ LRScheduler::LRScheduler(std::shared_ptr optimizer, current_lr_ = base_lr_; } +void LRScheduler::Step() { + ++last_step_; + ApplyLR(ComputeLR()); +} + void LRScheduler::ApplyLR(float lr) { current_lr_ = lr; optimizer_->SetLearningRate(current_lr_); @@ -46,65 +51,91 @@ void LRScheduler::LoadState(const StateDict &state) { optimizer_->SetLearningRate(current_lr_); } + + +// Concrete LR Schedulers + namespace lr_schedulers { -ConstantLR::ConstantLR(std::shared_ptr optimizer, float factor, int total_iters, int64_t last_step) - : LRScheduler(std::move(optimizer), last_step), factor_(factor), total_iters_(total_iters) { + +// --- ConstantLR --- + +ConstantLR::ConstantLR(std::shared_ptr optimizer, + float factor, + int total_iters, + int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), + factor_(factor), + total_iters_(total_iters) { Step(); } -void ConstantLR::Step() { - ++last_step_; - ApplyLR( - last_step_ < total_iters_ ? base_lr_ * factor_ : base_lr_); +float ConstantLR::ComputeLR() const { + return last_step_ < total_iters_ ? 
base_lr_ * factor_ : base_lr_; } -StepLR::StepLR(std::shared_ptr optimizer, int64_t step_size, float gamma , int64_t last_step) - : LRScheduler(std::move(optimizer), last_step), step_size_(step_size), gamma_(gamma) { +// --- StepLR --- + +StepLR::StepLR(std::shared_ptr optimizer, + int64_t step_size, + float gamma, + int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), + step_size_(step_size), + gamma_(gamma) { Step(); } -void StepLR::Step() { - ++last_step_; - ApplyLR(base_lr_ * static_cast( - std::pow(static_cast(gamma_), - static_cast(last_step_ / step_size_)))); +float StepLR::ComputeLR() const { + return base_lr_ * static_cast(std::pow( + static_cast(gamma_), + static_cast(last_step_ / step_size_))); } -LinearWarmupLR::LinearWarmupLR(std::shared_ptr optimizer, int64_t warmup_steps, float start_factor, int64_t last_step) - : LRScheduler(std::move(optimizer), last_step), warmup_steps_(warmup_steps), start_factor_(start_factor) { +LinearWarmupLR::LinearWarmupLR(std::shared_ptr optimizer, + int64_t warmup_steps, + float start_factor, + int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), + warmup_steps_(warmup_steps), + start_factor_(start_factor) { Step(); } -void LinearWarmupLR::Step() { - ++last_step_; - if (last_step_ >= warmup_steps_) { - ApplyLR(base_lr_); - } else{ - float alpha = static_cast(last_step_) / static_cast(warmup_steps_); - ApplyLR(base_lr_ * ( start_factor_ + (1.0f - start_factor_) * alpha)); - } +float LinearWarmupLR::ComputeLR() const { + if (last_step_ >= warmup_steps_) { + return base_lr_; + } + float alpha = + static_cast(last_step_) / static_cast(warmup_steps_); + return base_lr_ * (start_factor_ + (1.0f - start_factor_) * alpha); } -LambdaLR::LambdaLR(std::shared_ptr optimizer, std::function lr_lambda, int64_t last_step) - : LRScheduler(std::move(optimizer), last_step), lr_lambda_(std::move(lr_lambda)) { +LambdaLR::LambdaLR(std::shared_ptr optimizer, + std::function lr_lambda, + int64_t last_step) + : 
LRScheduler(std::move(optimizer), last_step), + lr_lambda_(std::move(lr_lambda)) { Step(); } -void LambdaLR::Step() { - ++last_step_; - ApplyLR(base_lr_ * lr_lambda_(last_step_)); +float LambdaLR::ComputeLR() const { + return base_lr_ * lr_lambda_(last_step_); } SequentialLR::SequentialLR(std::shared_ptr optimizer, std::vector> schedulers, std::vectormilestones, int64_t last_step) - : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)), milestones_(std::move(milestones)) { - CHECK(!schedulers_.empty()) << "SequentialLR requires at least one scheduler."; + : LRScheduler(std::move(optimizer), last_step), + schedulers_(std::move(schedulers)), + milestones_(std::move(milestones)) { + CHECK(!schedulers_.empty()) + << "SequentialLR requires at least one scheduler."; CHECK_EQ(milestones_.size(), schedulers_.size() - 1) << "SequentialLR: milestones count must be schedulers count - 1."; for(size_t i = 1; i < milestones_.size(); ++i) { - CHECK_GT(milestones_[i], milestones_[i-1]) << "Milestones must be strictly increasing."; + CHECK_GT(milestones_[i], milestones_[i-1]) + << "Milestones must be strictly increasing."; } optimizer_->SetLearningRate(schedulers_[0]->BaseLR()); diff --git a/test/lr_scheduler/test_lr_scheduler.cc b/test/lr_scheduler/test_lr_scheduler.cc index aae8bff7..c2b529bf 100644 --- a/test/lr_scheduler/test_lr_scheduler.cc +++ b/test/lr_scheduler/test_lr_scheduler.cc @@ -19,10 +19,8 @@ class IdentityScheduler : public LRScheduler { public: using LRScheduler::LRScheduler; - void Step() override { - ++last_step_; - ApplyLR(base_lr_); - } +protected: + float ComputeLR() const override { return base_lr_; } }; class LinearDecayScheduler : public LRScheduler { @@ -32,11 +30,13 @@ class LinearDecayScheduler : public LRScheduler { : LRScheduler(std::move(optimizer), last_step), total_steps_(total_steps) {} - void Step() override { - ++last_step_; - ApplyLR( - last_step_ >= total_steps_ ? 
0.0f : base_lr_ * (1.0f - static_cast(last_step_) - / static_cast(total_steps_))); +protected: + float ComputeLR() const override { + if (last_step_ >= total_steps_) { + return 0.0f; + } + return base_lr_ * (1.0f - static_cast(last_step_) / + static_cast(total_steps_)); } private: From 5b4ef6d3e388d6ea4382433e9be9667be97e59ff Mon Sep 17 00:00:00 2001 From: kinorw Date: Fri, 6 Mar 2026 01:18:56 +0800 Subject: [PATCH 11/18] feat(lr_scheduler): add factory method Create() with two-phase init and update all tests to use Create() factory method. --- infini_train/include/lr_scheduler.h | 11 +++- infini_train/src/lr_scheduler.cc | 20 +++---- test/lr_scheduler/test_constant_lr.cc | 68 +++++++++++------------ test/lr_scheduler/test_lambda_lr.cc | 64 +++++++++++----------- test/lr_scheduler/test_linear_lr.cc | 40 +++++++------- test/lr_scheduler/test_lr_scheduler.cc | 71 +++++++++++++------------ test/lr_scheduler/test_sequential_lr.cc | 26 ++++----- test/lr_scheduler/test_step_lr.cc | 30 +++++------ 8 files changed, 169 insertions(+), 161 deletions(-) diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h index 43ee67b5..640b6140 100644 --- a/infini_train/include/lr_scheduler.h +++ b/infini_train/include/lr_scheduler.h @@ -19,6 +19,13 @@ using StateDict = std::unordered_map; class LRScheduler { public: + template + static std::shared_ptr Create(Args&&... 
args) { + auto scheduler = std::make_shared(std::forward(args)...); + scheduler->InitialStep(); + return scheduler; + } + explicit LRScheduler(std::shared_ptr optimizer, int64_t last_step = -1); @@ -42,6 +49,8 @@ class LRScheduler { virtual float ComputeLR() const = 0; + void InitialStep(); + void ApplyLR(float lr); std::shared_ptr optimizer_; @@ -129,7 +138,7 @@ class SequentialLR : public LRScheduler { StateDict State() const override; void LoadState(const StateDict &state) override; - + protected: float ComputeLR() const override { return 0.0f; } diff --git a/infini_train/src/lr_scheduler.cc b/infini_train/src/lr_scheduler.cc index 4978101b..0339ab06 100644 --- a/infini_train/src/lr_scheduler.cc +++ b/infini_train/src/lr_scheduler.cc @@ -23,6 +23,10 @@ void LRScheduler::Step() { ApplyLR(ComputeLR()); } +void LRScheduler::InitialStep() { + Step(); +} + void LRScheduler::ApplyLR(float lr) { current_lr_ = lr; optimizer_->SetLearningRate(current_lr_); @@ -65,9 +69,7 @@ ConstantLR::ConstantLR(std::shared_ptr optimizer, int64_t last_step) : LRScheduler(std::move(optimizer), last_step), factor_(factor), - total_iters_(total_iters) { - Step(); -} + total_iters_(total_iters) {} float ConstantLR::ComputeLR() const { return last_step_ < total_iters_ ? 
base_lr_ * factor_ : base_lr_; @@ -81,9 +83,7 @@ StepLR::StepLR(std::shared_ptr optimizer, int64_t last_step) : LRScheduler(std::move(optimizer), last_step), step_size_(step_size), - gamma_(gamma) { - Step(); -} + gamma_(gamma) {} float StepLR::ComputeLR() const { return base_lr_ * static_cast(std::pow( @@ -97,9 +97,7 @@ LinearWarmupLR::LinearWarmupLR(std::shared_ptr optimizer, int64_t last_step) : LRScheduler(std::move(optimizer), last_step), warmup_steps_(warmup_steps), - start_factor_(start_factor) { - Step(); -} + start_factor_(start_factor) {} float LinearWarmupLR::ComputeLR() const { if (last_step_ >= warmup_steps_) { @@ -114,9 +112,7 @@ LambdaLR::LambdaLR(std::shared_ptr optimizer, std::function lr_lambda, int64_t last_step) : LRScheduler(std::move(optimizer), last_step), - lr_lambda_(std::move(lr_lambda)) { - Step(); -} + lr_lambda_(std::move(lr_lambda)) {} float LambdaLR::ComputeLR() const { return base_lr_ * lr_lambda_(last_step_); diff --git a/test/lr_scheduler/test_constant_lr.cc b/test/lr_scheduler/test_constant_lr.cc index f997742b..176b0d8f 100644 --- a/test/lr_scheduler/test_constant_lr.cc +++ b/test/lr_scheduler/test_constant_lr.cc @@ -11,60 +11,60 @@ constexpr float kBaseLR = 0.1f; void TestInitialState() { auto opt = MakeDummyOptimizer(kBaseLR); - ConstantLR sched(opt, /*factor=*/0.5f, /*total_iters=*/3); - ASSERT_FLOAT_EQ(sched.GetLR(), 0.05f); - ASSERT_TRUE(sched.LastStep() == 0); + auto sched = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/3); + ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); + ASSERT_TRUE(sched->LastStep() == 0); ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f); } void TestFirstStepAppliesFactor() { auto opt = MakeDummyOptimizer(kBaseLR); - ConstantLR sched(opt, 0.5f, 3); - sched.Step(); // last_step_ = 0 - ASSERT_FLOAT_EQ(sched.GetLR(), 0.05f); + auto sched = LRScheduler::Create(opt, 0.5f, 3); + sched->Step(); // last_step_ = 0 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f); - 
ASSERT_TRUE(sched.LastStep() == 1); + ASSERT_TRUE(sched->LastStep() == 1); } void TestWithinTotalIters() { auto opt = MakeDummyOptimizer(kBaseLR); - ConstantLR sched(opt, 0.5f, 3); - for (int i = 0; i < 2; ++i) sched.Step(); + auto sched = LRScheduler::Create(opt, 0.5f, 3); + for (int i = 0; i < 2; ++i) sched->Step(); // last_step_ = 2, still < 3 - ASSERT_FLOAT_EQ(sched.GetLR(), 0.05f); + ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); } void TestBeyondTotalIters() { auto opt = MakeDummyOptimizer(kBaseLR); - ConstantLR sched(opt, 0.5f, 3); - for (int i = 0; i < 10; ++i) sched.Step(); - ASSERT_FLOAT_EQ(sched.GetLR(), kBaseLR); + auto sched = LRScheduler::Create(opt, 0.5f, 3); + for (int i = 0; i < 10; ++i) sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR); } void TestPyTorchAlignment() { const std::vector expected = {0.05f, 0.05f, 0.1f, 0.1f, 0.1f}; auto opt = MakeDummyOptimizer(kBaseLR); - ConstantLR sched(opt, 0.5f, 3); + auto sched = LRScheduler::Create(opt, 0.5f, 3); for (size_t i = 0; i < expected.size(); ++i) { - sched.Step(); - ASSERT_FLOAT_EQ(sched.GetLR(), expected[i]); + sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), expected[i]); } } void TestStateRoundTrip() { auto opt = MakeDummyOptimizer(kBaseLR); - ConstantLR sched(opt, 0.5f, 5); - for (int i = 0; i < 3; ++i) sched.Step(); - StateDict saved = sched.State(); + auto sched = LRScheduler::Create(opt, 0.5f, 5); + for (int i = 0; i < 3; ++i) sched->Step(); + StateDict saved = sched->State(); auto opt2 = MakeDummyOptimizer(kBaseLR); - ConstantLR sched2(opt2, 0.5f, 5); - sched2.LoadState(saved); + auto sched2 = LRScheduler::Create(opt2, 0.5f, 5); + sched2->LoadState(saved); - ASSERT_TRUE(sched2.LastStep() == sched.LastStep()); - ASSERT_FLOAT_EQ(sched2.GetLR(), sched.GetLR()); - ASSERT_FLOAT_EQ(opt2->GetLearningRate(), sched.GetLR()); + ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); + ASSERT_FLOAT_EQ(sched2->GetLR(), sched->GetLR()); + 
ASSERT_FLOAT_EQ(opt2->GetLearningRate(), sched->GetLR()); } void TestResumeConsistency() { @@ -72,21 +72,21 @@ void TestResumeConsistency() { constexpr int kK = 3; auto opt_ref = MakeDummyOptimizer(kBaseLR); - ConstantLR sched_ref(opt_ref, 0.5f, 5); - for (int i = 0; i < kN; ++i) sched_ref.Step(); + auto sched_ref = LRScheduler::Create(opt_ref, 0.5f, 5); + for (int i = 0; i < kN; ++i) sched_ref->Step(); auto opt_a = MakeDummyOptimizer(kBaseLR); - ConstantLR sched_a(opt_a, 0.5f, 5); - for (int i = 0; i < kK; ++i) sched_a.Step(); - StateDict ckpt = sched_a.State(); + auto sched_a = LRScheduler::Create(opt_a, 0.5f, 5); + for (int i = 0; i < kK; ++i) sched_a->Step(); + StateDict ckpt = sched_a->State(); auto opt_b = MakeDummyOptimizer(kBaseLR); - ConstantLR sched_b(opt_b, 0.5f, 5); - sched_b.LoadState(ckpt); - for (int i = 0; i < kN - kK; ++i) sched_b.Step(); + auto sched_b = LRScheduler::Create(opt_b, 0.5f, 5); + sched_b->LoadState(ckpt); + for (int i = 0; i < kN - kK; ++i) sched_b->Step(); - ASSERT_FLOAT_EQ(sched_b.GetLR(), sched_ref.GetLR()); - ASSERT_TRUE(sched_b.LastStep() == sched_ref.LastStep()); + ASSERT_FLOAT_EQ(sched_b->GetLR(), sched_ref->GetLR()); + ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); } int main(int argc, char *argv[]) { diff --git a/test/lr_scheduler/test_lambda_lr.cc b/test/lr_scheduler/test_lambda_lr.cc index 37b0af23..0a99f132 100644 --- a/test/lr_scheduler/test_lambda_lr.cc +++ b/test/lr_scheduler/test_lambda_lr.cc @@ -10,58 +10,58 @@ constexpr float kBaseLR = 0.1f; void TestIdentityLambda() { auto opt = MakeDummyOptimizer(kBaseLR); - LambdaLR sched(opt, [](int64_t) { return 1.0f; }); + auto sched = LRScheduler::Create(opt, [](int64_t) { return 1.0f; }); // 构造器内 Step() → last_step_=0, lr = 0.1 * 1.0 = 0.1 - ASSERT_TRUE(sched.LastStep() == 0); - ASSERT_FLOAT_NEAR(sched.GetLR(), kBaseLR, kEps); + ASSERT_TRUE(sched->LastStep() == 0); + ASSERT_FLOAT_NEAR(sched->GetLR(), kBaseLR, kEps); ASSERT_FLOAT_NEAR(opt->GetLearningRate(), 
kBaseLR, kEps); } void TestLinearDecayLambda() { auto opt = MakeDummyOptimizer(kBaseLR); - LambdaLR sched(opt, [](int64_t step) { return 1.0f - step * 0.1f; }); + auto sched = LRScheduler::Create(opt, [](int64_t step) { return 1.0f - step * 0.1f; }); // step=0, lambda(0)=1.0, lr=0.1 - ASSERT_FLOAT_NEAR(sched.GetLR(), 0.1f, kEps); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); - sched.Step(); // step=1, lambda(1)=0.9, lr=0.09 - ASSERT_FLOAT_NEAR(sched.GetLR(), 0.09f, kEps); + sched->Step(); // step=1, lambda(1)=0.9, lr=0.09 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.09f, kEps); - sched.Step(); // step=2, lambda(2)=0.8, lr=0.08 - ASSERT_FLOAT_NEAR(sched.GetLR(), 0.08f, kEps); + sched->Step(); // step=2, lambda(2)=0.8, lr=0.08 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.08f, kEps); - sched.Step(); // step=3, lambda(3)=0.7, lr=0.07 - ASSERT_FLOAT_NEAR(sched.GetLR(), 0.07f, kEps); + sched->Step(); // step=3, lambda(3)=0.7, lr=0.07 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.07f, kEps); } void TestPyTorchAlignment() { // PyTorch: LambdaLR(opt, lr_lambda=lambda epoch: 0.95**epoch) auto opt = MakeDummyOptimizer(kBaseLR); - LambdaLR sched(opt, [](int64_t step) { + auto sched = LRScheduler::Create(opt, [](int64_t step) { return static_cast(std::pow(0.95, step)); }); // step=0, lr = 0.1 * 0.95^0 = 0.1 - ASSERT_FLOAT_NEAR(sched.GetLR(), 0.1f, kEps); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); std::vector expected = {0.095f, 0.09025f, 0.0857375f}; for (size_t i = 0; i < expected.size(); ++i) { - sched.Step(); - ASSERT_FLOAT_NEAR(sched.GetLR(), expected[i], 1e-5f); + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), expected[i], 1e-5f); } } void TestStateRoundTrip() { auto lambda_fn = [](int64_t step) { return 1.0f - step * 0.05f; }; auto opt = MakeDummyOptimizer(kBaseLR); - LambdaLR sched(opt, lambda_fn); - for (int i = 0; i < 5; ++i) sched.Step(); - StateDict saved = sched.State(); + auto sched = LRScheduler::Create(opt, lambda_fn); + for (int i = 0; i < 5; ++i) sched->Step(); + 
StateDict saved = sched->State(); auto opt2 = MakeDummyOptimizer(kBaseLR); - LambdaLR sched2(opt2, lambda_fn); // same lambda - sched2.LoadState(saved); + auto sched2 = LRScheduler::Create(opt2, lambda_fn); // same lambda + sched2->LoadState(saved); - ASSERT_TRUE(sched2.LastStep() == sched.LastStep()); - ASSERT_FLOAT_NEAR(sched2.GetLR(), sched.GetLR(), kEps); + ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); + ASSERT_FLOAT_NEAR(sched2->GetLR(), sched->GetLR(), kEps); } void TestResumeConsistency() { @@ -69,21 +69,21 @@ void TestResumeConsistency() { constexpr int kN = 10, kK = 4; auto opt_ref = MakeDummyOptimizer(kBaseLR); - LambdaLR sched_ref(opt_ref, lambda_fn); - for (int i = 0; i < kN; ++i) sched_ref.Step(); + auto sched_ref = LRScheduler::Create(opt_ref, lambda_fn); + for (int i = 0; i < kN; ++i) sched_ref->Step(); auto opt_a = MakeDummyOptimizer(kBaseLR); - LambdaLR sched_a(opt_a, lambda_fn); - for (int i = 0; i < kK; ++i) sched_a.Step(); - StateDict ckpt = sched_a.State(); + auto sched_a = LRScheduler::Create(opt_a, lambda_fn); + for (int i = 0; i < kK; ++i) sched_a->Step(); + StateDict ckpt = sched_a->State(); auto opt_b = MakeDummyOptimizer(kBaseLR); - LambdaLR sched_b(opt_b, lambda_fn); - sched_b.LoadState(ckpt); - for (int i = 0; i < kN - kK; ++i) sched_b.Step(); + auto sched_b = LRScheduler::Create(opt_b, lambda_fn); + sched_b->LoadState(ckpt); + for (int i = 0; i < kN - kK; ++i) sched_b->Step(); - ASSERT_FLOAT_NEAR(sched_b.GetLR(), sched_ref.GetLR(), kEps); - ASSERT_TRUE(sched_b.LastStep() == sched_ref.LastStep()); + ASSERT_FLOAT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); + ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); } diff --git a/test/lr_scheduler/test_linear_lr.cc b/test/lr_scheduler/test_linear_lr.cc index 8d6981ed..33b13847 100644 --- a/test/lr_scheduler/test_linear_lr.cc +++ b/test/lr_scheduler/test_linear_lr.cc @@ -10,51 +10,51 @@ constexpr float kBaseLR = 0.1f; void TestFirstStepFromZero() { auto opt = 
MakeDummyOptimizer(kBaseLR); - LinearWarmupLR sched(opt, /*warmup_steps=*/5, /*start_factor=*/0.0f); - sched.Step(); // last_step_=1, alpha=1/5=0.2 → lr=0.02 - ASSERT_FLOAT_EQ(sched.GetLR(), 0.02f); + auto sched = LRScheduler::Create(opt, /*warmup_steps=*/5, /*start_factor=*/0.0f); + sched->Step(); // last_step_=1, alpha=1/5=0.2 → lr=0.02 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.02f); } void TestMidpointLR() { auto opt = MakeDummyOptimizer(kBaseLR); - LinearWarmupLR sched(opt, 5, 0.0f); - for (int i = 0; i < 3; ++i) sched.Step(); + auto sched = LRScheduler::Create(opt, 5, 0.0f); + for (int i = 0; i < 3; ++i) sched->Step(); // last_step_=2, alpha=3/5=0.6 → lr=0.1*0.6=0.06 - ASSERT_FLOAT_EQ(sched.GetLR(), 0.06f); + ASSERT_FLOAT_EQ(sched->GetLR(), 0.06f); } void TestWarmupEnd() { auto opt = MakeDummyOptimizer(kBaseLR); - LinearWarmupLR sched(opt, 5, 0.0f); - for (int i = 0; i < 5; ++i) sched.Step(); + auto sched = LRScheduler::Create(opt, 5, 0.0f); + for (int i = 0; i < 5; ++i) sched->Step(); // last_step_=5, 5 >= 5 → base_lr - ASSERT_FLOAT_EQ(sched.GetLR(), kBaseLR); + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); } void TestBeyondWarmup() { auto opt = MakeDummyOptimizer(kBaseLR); - LinearWarmupLR sched(opt, 5, 0.0f); - for (int i = 0; i < 20; ++i) sched.Step(); - ASSERT_FLOAT_EQ(sched.GetLR(), kBaseLR); + auto sched = LRScheduler::Create(opt, 5, 0.0f); + for (int i = 0; i < 20; ++i) sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); } void TestCustomStartFactor() { auto opt = MakeDummyOptimizer(kBaseLR); - LinearWarmupLR sched(opt, 4, /*start_factor=*/0.25f); - sched.Step(); // last_step_=1, alpha=1/4=0.25 → lr=0.1*(0.25+0.75*0.25)=0.04375 - ASSERT_FLOAT_NEAR(sched.GetLR(), 0.04375f, 1e-6f); - sched.Step(); // last_step_=2, alpha=2/4=0.5 → lr=0.1*(0.25+0.75*0.5)=0.0625 - ASSERT_FLOAT_NEAR(sched.GetLR(), 0.0625f, 1e-6f); + auto sched = LRScheduler::Create(opt, 4, /*start_factor=*/0.25f); + sched->Step(); // last_step_=1, alpha=1/4=0.25 → 
lr=0.1*(0.25+0.75*0.25)=0.04375 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.04375f, 1e-6f); + sched->Step(); // last_step_=2, alpha=2/4=0.5 → lr=0.1*(0.25+0.75*0.5)=0.0625 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.0625f, 1e-6f); } void TestPyTorchAlignment() { const std::vector expected = { 0.02f, 0.04f, 0.06f, 0.08f, 0.1f, 0.1f, 0.1f}; auto opt = MakeDummyOptimizer(kBaseLR); - LinearWarmupLR sched(opt, 5, 0.0f); + auto sched = LRScheduler::Create(opt, 5, 0.0f); for (size_t i = 0; i < expected.size(); ++i) { - sched.Step(); - ASSERT_FLOAT_NEAR(sched.GetLR(), expected[i], 1e-7f); + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), expected[i], 1e-7f); } } diff --git a/test/lr_scheduler/test_lr_scheduler.cc b/test/lr_scheduler/test_lr_scheduler.cc index c2b529bf..89856435 100644 --- a/test/lr_scheduler/test_lr_scheduler.cc +++ b/test/lr_scheduler/test_lr_scheduler.cc @@ -17,7 +17,9 @@ constexpr float kEps = 1e-7f; class IdentityScheduler : public LRScheduler { public: - using LRScheduler::LRScheduler; + IdentityScheduler(std::shared_ptr optimizer, int64_t last_step = -1) + : LRScheduler(std::move(optimizer), last_step) {} + ~IdentityScheduler() override = default; protected: float ComputeLR() const override { return base_lr_; } @@ -68,10 +70,10 @@ void Check(bool cond, const char *expr, int line) { void TestInitialState() { std::cout << "[T1] TestInitialState" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - IdentityScheduler sched(opt); + auto sched = LRScheduler::Create(opt); - ASSERT_FLOAT_EQ(sched.GetLR(), kBaseLR); - ASSERT_TRUE(sched.LastStep() == -1); + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); + ASSERT_TRUE(sched->LastStep() == 0); ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR); } @@ -79,12 +81,12 @@ void TestInitialState() { void TestSingleStep() { std::cout << "[T2] TestSingleStep" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - IdentityScheduler sched(opt); + auto sched = LRScheduler::Create(opt); - sched.Step(); + sched->Step(); - 
ASSERT_TRUE(sched.LastStep() == 0); - ASSERT_FLOAT_EQ(sched.GetLR(), kBaseLR); + ASSERT_TRUE(sched->LastStep() == 1); + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR); } @@ -93,15 +95,16 @@ void TestLinearDecay() { std::cout << "[T3] TestLinearDecay" << std::endl; constexpr int64_t kTotalSteps = 10; auto opt = MakeDummyOptimizer(kBaseLR); - LinearDecayScheduler sched(opt, kTotalSteps); - - sched.Step(); - ASSERT_FLOAT_EQ(sched.GetLR(), kBaseLR); + auto sched = LRScheduler::Create(opt, kTotalSteps); + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR); - for (int i = 0; i < 4; ++i) { sched.Step(); } - sched.Step(); - ASSERT_FLOAT_EQ(sched.GetLR(), 0.05f); + sched->Step(); // last_step = 1 -> 0.09 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.09f); + ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.09f); + + for (int i = 0; i < 4; ++i) { sched->Step(); } // last_step = 5 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f); } @@ -110,23 +113,23 @@ void TestStateRoundTrip() { std::cout << "[T4] TestStateRoundTrip" << std::endl; constexpr int64_t kTotalSteps = 20; auto opt = MakeDummyOptimizer(kBaseLR); - LinearDecayScheduler sched(opt, kTotalSteps); + auto sched = LRScheduler::Create(opt, kTotalSteps); - for (int i = 0; i < 7; ++i) { sched.Step(); } + for (int i = 0; i < 7; ++i) { sched->Step(); } - StateDict saved = sched.State(); + StateDict saved = sched->State(); ASSERT_TRUE(saved.count("last_step") == 1); ASSERT_TRUE(saved.count("current_lr") == 1); ASSERT_TRUE(saved.count("base_lr") == 1); auto opt2 = MakeDummyOptimizer(kBaseLR); - LinearDecayScheduler sched2(opt2, kTotalSteps); - sched2.LoadState(saved); + auto sched2 = LRScheduler::Create(opt2, kTotalSteps); + sched2->LoadState(saved); - ASSERT_TRUE(sched2.LastStep() == 6); - ASSERT_FLOAT_EQ(sched2.GetLR(), sched.GetLR()); - ASSERT_FLOAT_EQ(opt2->GetLearningRate(), sched.GetLR()); + 
ASSERT_TRUE(sched2->LastStep() == 7); + ASSERT_FLOAT_EQ(sched2->GetLR(), sched->GetLR()); + ASSERT_FLOAT_EQ(opt2->GetLearningRate(), sched->GetLR()); } // T5: resume Step @@ -135,22 +138,22 @@ void TestResumeAndContinue() { constexpr int64_t kTotalSteps = 20; auto opt_ref = MakeDummyOptimizer(kBaseLR); - LinearDecayScheduler sched_ref(opt_ref, kTotalSteps); - for (int i = 0; i < 10; ++i) { sched_ref.Step(); } - float lr_at_10 = sched_ref.GetLR(); + auto sched_ref = LRScheduler::Create(opt_ref, kTotalSteps); + for (int i = 0; i < 10; ++i) { sched_ref->Step(); } + float lr_at_10 = sched_ref->GetLR(); auto opt_a = MakeDummyOptimizer(kBaseLR); - LinearDecayScheduler sched_a(opt_a, kTotalSteps); - for (int i = 0; i < 5; ++i) { sched_a.Step(); } - StateDict checkpoint = sched_a.State(); + auto sched_a = LRScheduler::Create(opt_a, kTotalSteps); + for (int i = 0; i < 5; ++i) { sched_a->Step(); } + StateDict checkpoint = sched_a->State(); auto opt_b = MakeDummyOptimizer(kBaseLR); - LinearDecayScheduler sched_b(opt_b, kTotalSteps); - sched_b.LoadState(checkpoint); - for (int i = 0; i < 5; ++i) { sched_b.Step(); } + auto sched_b = LRScheduler::Create(opt_b, kTotalSteps); + sched_b->LoadState(checkpoint); + for (int i = 0; i < 5; ++i) { sched_b->Step(); } - ASSERT_FLOAT_EQ(sched_b.GetLR(), lr_at_10); - ASSERT_TRUE(sched_b.LastStep() == sched_ref.LastStep()); + ASSERT_FLOAT_EQ(sched_b->GetLR(), lr_at_10); + ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); } } // namespace diff --git a/test/lr_scheduler/test_sequential_lr.cc b/test/lr_scheduler/test_sequential_lr.cc index 2a29a208..ef558258 100644 --- a/test/lr_scheduler/test_sequential_lr.cc +++ b/test/lr_scheduler/test_sequential_lr.cc @@ -11,8 +11,8 @@ void TestWarmupThenConstant() { std::cout << "[TC1] TestWarmupThenConstant" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto warmup = std::make_shared(opt, /*warmup_steps=*/3, /*start_factor=*/1e-8); - auto constant = std::make_shared(opt, /*factor=*/1.0f, 
/*total_iters=*/100); + auto warmup = LRScheduler::Create(opt, /*warmup_steps=*/3, /*start_factor=*/1e-8); + auto constant = LRScheduler::Create(opt, /*factor=*/1.0f, /*total_iters=*/100); SequentialLR sched(opt, {warmup, constant}, {3}); @@ -35,8 +35,8 @@ void TestWarmupThenStepLR() { std::cout << "[TC2] TestWarmupThenStepLR" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto warmup = std::make_shared(opt, 3, 0.0f); - auto step_lr = std::make_shared(opt, /*step_size=*/3, /*gamma=*/0.5f); + auto warmup = LRScheduler::Create(opt, /*warmup_steps=*/3, /*start_factor=*/0.0f); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); SequentialLR sched(opt, {warmup, step_lr}, {3}); @@ -56,9 +56,9 @@ void TestWarmupThenStepThenConstant(){ std::cout << "[TC3] TestWarmupThenStepThenConstant" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto warmup = std::make_shared(opt, 3, 0.0f); - auto step_lr = std::make_shared(opt, 3, 0.5f); - auto constant = std::make_shared(opt, 0.5f, 2); + auto warmup = LRScheduler::Create(opt, /*warmup_steps=*/3, /*start_factor=*/0.0f); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); + auto constant = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/2); SequentialLR sched(opt, {warmup, step_lr, constant}, {3, 6}); const std::vector expected = { @@ -72,16 +72,16 @@ void TestWarmupThenStepThenConstant(){ void TestStateRoundTrip() { std::cout << "[TC4] TestStateRoundTrip" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto warmup = std::make_shared(opt, 3, 0.0f); - auto step_lr = std::make_shared(opt, 3, 0.5f); + auto warmup = LRScheduler::Create(opt, /*warmup_steps=*/3, /*start_factor=*/0.0f); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); SequentialLR sched(opt, {warmup, step_lr}, {3}); for (int i = 0; i < 5; ++i) sched.Step(); StateDict saved = sched.State(); auto opt2 = MakeDummyOptimizer(kBaseLR); - auto warmup2 = 
std::make_shared(opt2, 3, 0.0f); - auto step_lr2 = std::make_shared(opt2, 3, 0.5f); + auto warmup2 = LRScheduler::Create(opt2, /*warmup_steps=*/3, /*start_factor=*/0.0f); + auto step_lr2 = LRScheduler::Create(opt2, /*step_size=*/3, /*gamma=*/0.5f); SequentialLR sched2(opt2, {warmup2, step_lr2}, {3}); sched2.LoadState(saved); @@ -94,8 +94,8 @@ void TestResumeConsistency() { constexpr int kN = 10, kK = 4; auto make_sched = [](std::shared_ptr opt) { - auto warmup = std::make_shared(opt, 3, 0.0f); - auto step_lr = std::make_shared(opt, 3, 0.5f); + auto warmup = LRScheduler::Create(opt, /*warmup_steps=*/3, /*start_factor=*/0.0f); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); return std::make_unique(opt, std::vector>{warmup, step_lr}, std::vector{3}); diff --git a/test/lr_scheduler/test_step_lr.cc b/test/lr_scheduler/test_step_lr.cc index b7c1466e..74cce52e 100644 --- a/test/lr_scheduler/test_step_lr.cc +++ b/test/lr_scheduler/test_step_lr.cc @@ -10,46 +10,46 @@ constexpr float kBaseLR = 0.1f; void TestWithinFirstPeriod() { auto opt = MakeDummyOptimizer(kBaseLR); - StepLR sched(opt, /*step_size=*/3, /*gamma=*/0.1f); + auto sched = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.1f); for (int i = 0; i < 2; ++i) { - sched.Step(); - ASSERT_FLOAT_EQ(sched.GetLR(), kBaseLR); // last_step 1,2 → 指数 0 + sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); // last_step 1,2 → 指数 0 } } void TestFirstDecay() { auto opt = MakeDummyOptimizer(kBaseLR); - StepLR sched(opt, 3, 0.1f); - for (int i = 0; i < 3; ++i) sched.Step(); + auto sched = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.1f); + for (int i = 0; i < 3; ++i) sched->Step(); // last_step=3, 3//3=1 → 0.1^1 = 0.1 → lr=0.01 - ASSERT_FLOAT_EQ(sched.GetLR(), 0.01f); + ASSERT_FLOAT_EQ(sched->GetLR(), 0.01f); } void TestMultipleDecays() { auto opt = MakeDummyOptimizer(kBaseLR); - StepLR sched(opt, 3, 0.1f); - for (int i = 0; i < 6; ++i) sched.Step(); + auto sched = 
LRScheduler::Create(opt, 3, 0.1f); + for (int i = 0; i < 6; ++i) sched->Step(); // last_step=6, 6//3=2 → 0.1^2 = 0.01 → lr=0.001 - ASSERT_FLOAT_NEAR(sched.GetLR(), 0.001f, 1e-7f); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.001f, 1e-7f); } void TestPyTorchAlignment() { const std::vector expected = { 0.1f, 0.1f, 0.01f, 0.01f, 0.01f, 0.001f, 0.001f}; auto opt = MakeDummyOptimizer(kBaseLR); - StepLR sched(opt, 3, 0.1f); + auto sched = LRScheduler::Create(opt, 3, 0.1f); for (size_t i = 0; i < expected.size(); ++i) { - sched.Step(); - ASSERT_FLOAT_NEAR(sched.GetLR(), expected[i], 1e-7f); + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), expected[i], 1e-7f); } } void TestGammaOne() { auto opt = MakeDummyOptimizer(kBaseLR); - StepLR sched(opt, 5, 1.0f); + auto sched = LRScheduler::Create(opt, 3, 1.0f); for (int i = 0; i < 20; ++i) { - sched.Step(); - ASSERT_FLOAT_EQ(sched.GetLR(), kBaseLR); + sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); } } From 8c11dd9db04d3927761436dfc29f66c81b06e89c Mon Sep 17 00:00:00 2001 From: kinorw Date: Fri, 6 Mar 2026 23:12:52 +0800 Subject: [PATCH 12/18] =?UTF-8?q?feat(lr=5Fscheduler):=20add=20closed=20an?= =?UTF-8?q?d=20chained=20form,=20adjust=20LinearLR=E3=80=81SequentialLR=20?= =?UTF-8?q?-=20enhance=20LRScheduler=20with=20chained=20and=20closed=20for?= =?UTF-8?q?m=20learning=20rate=20methods=20-=20adapt=20methods(Step,=20Ini?= =?UTF-8?q?tialStep,=20GetClosedFormLR,=20GetChainedFormLR)=20to=20match?= =?UTF-8?q?=20PyTorch=E2=80=98s=20design=20-=20add=20tests=20for=20consist?= =?UTF-8?q?ency=20-=20refactor=20LinearLR:=20add=20end=5Ffactor,=20and=20r?= =?UTF-8?q?ename=20this=20class=20-=20add=20SequentialLR=20InitialStep=20a?= =?UTF-8?q?nd=20UndoChildInitialSteps?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BREAKING CHANGE: Subclasses must implement GetClosedFormLR instead of ComputeLR(). Should use LinearLR instead of LinearwarmupLR. 
--- infini_train/include/lr_scheduler.h | 45 ++++---- infini_train/src/lr_scheduler.cc | 130 +++++++++++++++++------- test/lr_scheduler/test_constant_lr.cc | 15 +++ test/lr_scheduler/test_linear_lr.cc | 42 +++++--- test/lr_scheduler/test_lr_scheduler.cc | 4 +- test/lr_scheduler/test_sequential_lr.cc | 93 +++++++++-------- test/lr_scheduler/test_step_lr.cc | 15 +++ 7 files changed, 228 insertions(+), 116 deletions(-) diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h index 640b6140..464f3a82 100644 --- a/infini_train/include/lr_scheduler.h +++ b/infini_train/include/lr_scheduler.h @@ -28,35 +28,33 @@ class LRScheduler { explicit LRScheduler(std::shared_ptr optimizer, int64_t last_step = -1); - virtual ~LRScheduler() = default; LRScheduler(const LRScheduler &) = delete; LRScheduler &operator=(const LRScheduler &) = delete; virtual void Step(); + virtual void Step(int64_t epoch); + virtual void InitialStep(); float GetLR() const; float BaseLR() const; int64_t LastStep() const; void ResetStep(int64_t step = -1); - virtual StateDict State() const; virtual void LoadState(const StateDict &state); protected: - - virtual float ComputeLR() const = 0; - - void InitialStep(); - + virtual float GetClosedFormLR() const = 0; + virtual float GetChainedFormLR() const; void ApplyLR(float lr); std::shared_ptr optimizer_; int64_t last_step_; float current_lr_; float base_lr_; + bool is_initial_ = false; }; namespace lr_schedulers { @@ -70,7 +68,8 @@ class ConstantLR : public LRScheduler { ~ConstantLR() override = default; protected: - float ComputeLR() const override; + float GetChainedFormLR() const override; + float GetClosedFormLR() const override; private: const float factor_; @@ -86,27 +85,31 @@ class StepLR : public LRScheduler { ~StepLR() override = default; protected: - float ComputeLR() const override; + float GetChainedFormLR() const override; + float GetClosedFormLR() const override; private: const int64_t step_size_; const float gamma_; 
}; -class LinearWarmupLR : public LRScheduler { -public: - LinearWarmupLR(std::shared_ptr optimizer, - int64_t warmup_steps, - float start_factor = 0.0f, - int64_t last_step = -1); - ~LinearWarmupLR() override = default; +class LinearLR : public LRScheduler { +public: + LinearLR(std::shared_ptr optimizer, + float start_factor = 1.0f / 3.0f, + float end_factor = 1.0f, + int64_t total_iters = 5, + int64_t last_step = -1); + ~LinearLR() override = default; protected: - float ComputeLR() const override; + float GetChainedFormLR() const override; + float GetClosedFormLR() const override; private: - const int64_t warmup_steps_; const float start_factor_; + const float end_factor_; + const int64_t total_iters_; }; class LambdaLR : public LRScheduler { @@ -119,7 +122,7 @@ class LambdaLR : public LRScheduler { ~LambdaLR() override = default; protected: - float ComputeLR() const override; + float GetClosedFormLR() const override; private: const LambdaFunc lr_lambda_; @@ -135,12 +138,14 @@ class SequentialLR : public LRScheduler { ~SequentialLR() override = default; void Step() override; + void InitialStep() override; StateDict State() const override; void LoadState(const StateDict &state) override; protected: - float ComputeLR() const override { return 0.0f; } + float GetClosedFormLR() const override { return current_lr_; } + void UndoChildInitialSteps(); private: std::vector> schedulers_; diff --git a/infini_train/src/lr_scheduler.cc b/infini_train/src/lr_scheduler.cc index 0339ab06..050ac2bf 100644 --- a/infini_train/src/lr_scheduler.cc +++ b/infini_train/src/lr_scheduler.cc @@ -20,11 +20,18 @@ LRScheduler::LRScheduler(std::shared_ptr optimizer, void LRScheduler::Step() { ++last_step_; - ApplyLR(ComputeLR()); + ApplyLR(GetChainedFormLR()); +} + +void LRScheduler::Step(int64_t epoch) { + last_step_ = epoch; + ApplyLR(GetClosedFormLR()); } void LRScheduler::InitialStep() { + is_initial_ = true; Step(); + is_initial_ = false; } void LRScheduler::ApplyLR(float lr) { @@ -32,6 
+39,10 @@ void LRScheduler::ApplyLR(float lr) { optimizer_->SetLearningRate(current_lr_); } +float LRScheduler::GetChainedFormLR() const { + return GetClosedFormLR(); +} + float LRScheduler::GetLR() const { return current_lr_; } float LRScheduler::BaseLR() const { return base_lr_; } @@ -71,10 +82,22 @@ ConstantLR::ConstantLR(std::shared_ptr optimizer, factor_(factor), total_iters_(total_iters) {} -float ConstantLR::ComputeLR() const { +float ConstantLR::GetClosedFormLR() const { return last_step_ < total_iters_ ? base_lr_ * factor_ : base_lr_; } +float ConstantLR::GetChainedFormLR() const { + const float lr = optimizer_->GetLearningRate(); + if (last_step_ == 0) { + return lr * factor_; + } else if (last_step_ < total_iters_) { + return lr; + } else if (last_step_ == total_iters_){ + return lr / factor_; + } + return lr; +} + // --- StepLR --- StepLR::StepLR(std::shared_ptr optimizer, @@ -85,27 +108,63 @@ StepLR::StepLR(std::shared_ptr optimizer, step_size_(step_size), gamma_(gamma) {} -float StepLR::ComputeLR() const { +float StepLR::GetClosedFormLR() const { return base_lr_ * static_cast(std::pow( static_cast(gamma_), static_cast(last_step_ / step_size_))); } -LinearWarmupLR::LinearWarmupLR(std::shared_ptr optimizer, - int64_t warmup_steps, - float start_factor, - int64_t last_step) - : LRScheduler(std::move(optimizer), last_step), - warmup_steps_(warmup_steps), - start_factor_(start_factor) {} +float StepLR::GetChainedFormLR() const { + const float lr = optimizer_->GetLearningRate(); + if (last_step_ == 0 || (last_step_ % step_size_) != 0) { + return lr; + } + return lr * gamma_; +} + + +LinearLR::LinearLR(std::shared_ptr optimizer, + float start_factor, + float end_factor, + int64_t total_iters, + int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), + start_factor_(start_factor), + end_factor_(end_factor), + total_iters_(total_iters) {} + +float LinearLR::GetClosedFormLR() const { + if (last_step_ >= total_iters_) { + return base_lr_ * 
end_factor_; + } + return base_lr_ * + (start_factor_ + (end_factor_ - start_factor_) * + static_cast(last_step_) / + static_cast(total_iters_)); +} + +float LinearLR::GetChainedFormLR() const { + const float lr = optimizer_->GetLearningRate(); + if (last_step_ == 0) { + return lr * start_factor_; + } + if (last_step_ > total_iters_ || is_initial_) { + return lr; + } + if (last_step_ == total_iters_) { + const float prev_factor = + start_factor_ + + (end_factor_ - start_factor_) * + static_cast(total_iters_ - 1) / + static_cast(total_iters_); + return lr * (end_factor_ / prev_factor); + } -float LinearWarmupLR::ComputeLR() const { - if (last_step_ >= warmup_steps_) { - return base_lr_; - } - float alpha = - static_cast(last_step_) / static_cast(warmup_steps_); - return base_lr_ * (start_factor_ + (1.0f - start_factor_) * alpha); + const float numerator = end_factor_ - start_factor_; + const float denominator = + start_factor_ * static_cast(total_iters_) + + static_cast(last_step_ - 1) * numerator; + return lr * (1.0f + numerator / denominator); } LambdaLR::LambdaLR(std::shared_ptr optimizer, @@ -114,7 +173,7 @@ LambdaLR::LambdaLR(std::shared_ptr optimizer, : LRScheduler(std::move(optimizer), last_step), lr_lambda_(std::move(lr_lambda)) {} -float LambdaLR::ComputeLR() const { +float LambdaLR::GetClosedFormLR() const { return base_lr_ * lr_lambda_(last_step_); } @@ -123,7 +182,9 @@ SequentialLR::SequentialLR(std::shared_ptr optimizer, std::vectormilestones, int64_t last_step) : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)), - milestones_(std::move(milestones)) { + milestones_(std::move(milestones)) {} + +void SequentialLR::InitialStep() { CHECK(!schedulers_.empty()) << "SequentialLR requires at least one scheduler."; CHECK_EQ(milestones_.size(), schedulers_.size() - 1) @@ -136,35 +197,36 @@ SequentialLR::SequentialLR(std::shared_ptr optimizer, optimizer_->SetLearningRate(schedulers_[0]->BaseLR()); - // Reset all schedulers to the 
same last_step so they are in sync when Step() is called. + UndoChildInitialSteps(); + + ++last_step_; + schedulers_[0]->InitialStep(); + current_lr_ = schedulers_[0]->GetLR(); + +} + +void SequentialLR::UndoChildInitialSteps() { for (auto &sched : schedulers_) { - sched->ResetStep(sched->LastStep()-1); + if (auto nested = std::dynamic_pointer_cast(sched)) { + nested->UndoChildInitialSteps(); + } + sched->ResetStep(sched->LastStep() - 1); } - - Step(); } void SequentialLR::Step() { ++last_step_; - size_t idx = 0; - for (size_t i = 0; i < milestones_.size(); ++i) { - if (last_step_ >= milestones_[i]) { - idx = i + 1; - } else { - break; - } - } + size_t idx = std::upper_bound(milestones_.begin(), milestones_.end(), last_step_) - milestones_.begin(); auto &scheduler = schedulers_[idx]; if (idx > 0 && milestones_[idx - 1] == last_step_) { - scheduler->ResetStep(-1); - scheduler->Step(); + scheduler->Step(0); } else { scheduler->Step(); } - current_lr_ = scheduler->GetLR(); + ApplyLR(scheduler->GetLR()); } StateDict SequentialLR::State() const { diff --git a/test/lr_scheduler/test_constant_lr.cc b/test/lr_scheduler/test_constant_lr.cc index 176b0d8f..6d08af50 100644 --- a/test/lr_scheduler/test_constant_lr.cc +++ b/test/lr_scheduler/test_constant_lr.cc @@ -89,6 +89,20 @@ void TestResumeConsistency() { ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); } +void TestChainableAndClosedFormConsistency() { + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto chainable = LRScheduler::Create(opt_a, 0.5f, 5); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto closed_form = LRScheduler::Create(opt_b, 0.5f, 5); + + for (int epoch = 1; epoch <= 12; ++epoch) { + chainable->Step(); + closed_form->Step(epoch); + ASSERT_FLOAT_NEAR(chainable->GetLR(), closed_form->GetLR(), 1e-7f); + } +} + int main(int argc, char *argv[]) { google::InitGoogleLogging(argv[0]); std::cout << "=== ConstantLR Tests ===" << std::endl; @@ -99,6 +113,7 @@ int main(int argc, char *argv[]) { 
TestPyTorchAlignment(); TestStateRoundTrip(); TestResumeConsistency(); + TestChainableAndClosedFormConsistency(); std::cout << "========================" << std::endl; if (g_fail_count == 0) { std::cout << "All Tests PASSED" << std::endl; diff --git a/test/lr_scheduler/test_linear_lr.cc b/test/lr_scheduler/test_linear_lr.cc index 33b13847..6379cc46 100644 --- a/test/lr_scheduler/test_linear_lr.cc +++ b/test/lr_scheduler/test_linear_lr.cc @@ -10,54 +10,69 @@ constexpr float kBaseLR = 0.1f; void TestFirstStepFromZero() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, /*warmup_steps=*/5, /*start_factor=*/0.0f); - sched->Step(); // last_step_=1, alpha=1/5=0.2 → lr=0.02 + auto sched = LRScheduler::Create( + opt, /*start_factor=*/0.2f, /*end_factor=*/1.0f, /*total_iters=*/5); ASSERT_FLOAT_EQ(sched->GetLR(), 0.02f); } void TestMidpointLR() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 5, 0.0f); + auto sched = LRScheduler::Create(opt, 0.2f, 1.0f, 5); for (int i = 0; i < 3; ++i) sched->Step(); - // last_step_=2, alpha=3/5=0.6 → lr=0.1*0.6=0.06 - ASSERT_FLOAT_EQ(sched->GetLR(), 0.06f); + // last_step_=3 -> 0.1*(0.2 + 0.8*3/5) = 0.068 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.068f); } void TestWarmupEnd() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 5, 0.0f); + auto sched = LRScheduler::Create(opt, 0.2f, 1.0f, 5); for (int i = 0; i < 5; ++i) sched->Step(); - // last_step_=5, 5 >= 5 → base_lr + // last_step_ >= total_iters -> base_lr * end_factor ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); } void TestBeyondWarmup() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 5, 0.0f); + auto sched = LRScheduler::Create(opt, 0.2f, 1.0f, 5); for (int i = 0; i < 20; ++i) sched->Step(); ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); } void TestCustomStartFactor() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 4, 
/*start_factor=*/0.25f); - sched->Step(); // last_step_=1, alpha=1/4=0.25 → lr=0.1*(0.25+0.75*0.25)=0.04375 + auto sched = LRScheduler::Create( + opt, /*start_factor=*/0.25f, /*end_factor=*/1.0f, /*total_iters=*/4); + sched->Step(); // last_step_=1, lr=0.1*(0.25+0.75*1/4)=0.04375 ASSERT_FLOAT_NEAR(sched->GetLR(), 0.04375f, 1e-6f); - sched->Step(); // last_step_=2, alpha=2/4=0.5 → lr=0.1*(0.25+0.75*0.5)=0.0625 + sched->Step(); // last_step_=2, lr=0.1*(0.25+0.75*2/4)=0.0625 ASSERT_FLOAT_NEAR(sched->GetLR(), 0.0625f, 1e-6f); } void TestPyTorchAlignment() { const std::vector expected = { - 0.02f, 0.04f, 0.06f, 0.08f, 0.1f, 0.1f, 0.1f}; + 0.036f, 0.052f, 0.068f, 0.084f, 0.1f, 0.1f, 0.1f}; auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 5, 0.0f); + auto sched = LRScheduler::Create(opt, 0.2f, 1.0f, 5); for (size_t i = 0; i < expected.size(); ++i) { sched->Step(); ASSERT_FLOAT_NEAR(sched->GetLR(), expected[i], 1e-7f); } } +void TestChainableAndClosedFormConsistency() { + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto chainable = LRScheduler::Create(opt_a, 0.2f, 1.0f, 5); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto closed_form = LRScheduler::Create(opt_b, 0.2f, 1.0f, 5); + + for (int epoch = 1; epoch <= 10; ++epoch) { + chainable->Step(); + closed_form->Step(epoch); + ASSERT_FLOAT_NEAR(chainable->GetLR(), closed_form->GetLR(), 1e-6f); + } +} + int main(int argc, char *argv[]) { google::InitGoogleLogging(argv[0]); std::cout << "=== Linear Tests ===" << std::endl; @@ -67,6 +82,7 @@ int main(int argc, char *argv[]) { TestBeyondWarmup(); TestCustomStartFactor(); TestPyTorchAlignment(); + TestChainableAndClosedFormConsistency(); std::cout << "========================" << std::endl; if (g_fail_count == 0) { diff --git a/test/lr_scheduler/test_lr_scheduler.cc b/test/lr_scheduler/test_lr_scheduler.cc index 89856435..0c3dce1d 100644 --- a/test/lr_scheduler/test_lr_scheduler.cc +++ b/test/lr_scheduler/test_lr_scheduler.cc @@ -22,7 +22,7 @@ 
class IdentityScheduler : public LRScheduler { ~IdentityScheduler() override = default; protected: - float ComputeLR() const override { return base_lr_; } + float GetClosedFormLR() const override { return base_lr_; } }; class LinearDecayScheduler : public LRScheduler { @@ -33,7 +33,7 @@ class LinearDecayScheduler : public LRScheduler { total_steps_(total_steps) {} protected: - float ComputeLR() const override { + float GetClosedFormLR() const override { if (last_step_ >= total_steps_) { return 0.0f; } diff --git a/test/lr_scheduler/test_sequential_lr.cc b/test/lr_scheduler/test_sequential_lr.cc index ef558258..d05fc7d1 100644 --- a/test/lr_scheduler/test_sequential_lr.cc +++ b/test/lr_scheduler/test_sequential_lr.cc @@ -7,86 +7,85 @@ namespace { constexpr float kBaseLR = 0.1f; } // namespace -void TestWarmupThenConstant() { - std::cout << "[TC1] TestWarmupThenConstant" << std::endl; +void TestLinearThenConstant() { + std::cout << "[TC1] TestLinearThenConstant" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto warmup = LRScheduler::Create(opt, /*warmup_steps=*/3, /*start_factor=*/1e-8); + auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8, /*end_factor=*/1.0f, /*total_iters=*/3); auto constant = LRScheduler::Create(opt, /*factor=*/1.0f, /*total_iters=*/100); + auto sched = LRScheduler::Create(opt, std::vector>{linear, constant}, std::vector{3}); - SequentialLR sched(opt, {warmup, constant}, {3}); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.0f, kEps); - ASSERT_FLOAT_NEAR(sched.GetLR(), 0.0f, kEps); + sched->Step(); // global=1, warmup step=1, lr=0.1*(1/3) + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f / 3.0f, 1e-5f); - sched.Step(); // global=1, warmup step=1, lr=0.1*(1/3) - ASSERT_FLOAT_NEAR(sched.GetLR(), 0.1f / 3.0f, 1e-5f); + sched->Step(); // global=2, warmup step=2, lr=0.1*(2/3) + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.2f / 3.0f, 1e-5f); - sched.Step(); // global=2, warmup step=2, lr=0.1*(2/3) - ASSERT_FLOAT_NEAR(sched.GetLR(), 0.2f / 3.0f, 1e-5f); + 
sched->Step(); // global=3, constant step=0, lr=0.1*1.0=0.1 + ASSERT_FLOAT_NEAR(sched->GetLR(), kBaseLR, kEps); - sched.Step(); // global=3, constant step=0, lr=0.1*1.0=0.1 - ASSERT_FLOAT_NEAR(sched.GetLR(), kBaseLR, kEps); - - sched.Step(); // global=4, constant step=1, lr=0.1 - ASSERT_FLOAT_NEAR(sched.GetLR(), kBaseLR, kEps); + sched->Step(); // global=4, constant step=1, lr=0.1 + ASSERT_FLOAT_NEAR(sched->GetLR(), kBaseLR, kEps); } -void TestWarmupThenStepLR() { - std::cout << "[TC2] TestWarmupThenStepLR" << std::endl; +void TestLinearThenStepLR() { + std::cout << "[TC2] TestLinearThenStepLR" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto warmup = LRScheduler::Create(opt, /*warmup_steps=*/3, /*start_factor=*/0.0f); + auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, /*total_iters=*/3); auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); - SequentialLR sched(opt, {warmup, step_lr}, {3}); + auto sched = LRScheduler::Create(opt, std::vector>{linear, step_lr}, std::vector{3}); - sched.Step(); // global=1 - sched.Step(); // global=2 + sched->Step(); // global=1 + sched->Step(); // global=2 - sched.Step(); // global=3, StepLR step=0, lr=0.1 - ASSERT_FLOAT_NEAR(sched.GetLR(), 0.1f, kEps); + sched->Step(); // global=3, StepLR step=0, lr=0.1 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); - sched.Step(); // global=4, StepLR step=1 - sched.Step(); // global=5, StepLR step=2 - sched.Step(); // global=6, StepLR step=3, 3//3=1, lr=0.1*0.5=0.05 - ASSERT_FLOAT_NEAR(sched.GetLR(), 0.05f, kEps); + sched->Step(); // global=4, StepLR step=1 + sched->Step(); // global=5, StepLR step=2 + sched->Step(); // global=6, StepLR step=3, 3//3=1, lr=0.1*0.5=0.05 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.05f, kEps); } -void TestWarmupThenStepThenConstant(){ - std::cout << "[TC3] TestWarmupThenStepThenConstant" << std::endl; +void TestLinearThenStepThenConstant(){ + std::cout << "[TC3] TestLinearThenStepThenConstant" << 
std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto warmup = LRScheduler::Create(opt, /*warmup_steps=*/3, /*start_factor=*/0.0f); + auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, /*total_iters=*/3); auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); auto constant = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/2); - SequentialLR sched(opt, {warmup, step_lr, constant}, {3, 6}); + auto sched = LRScheduler::Create(opt, std::vector>{linear, step_lr, constant}, std::vector{3, 6}); const std::vector expected = { 0.033333f, 0.066667f, 0.1f, 0.1f, 0.1f, 0.05f, 0.05f, 0.1f, 0.1f, 0.1f}; for (size_t i = 0; i < expected.size(); ++i) { - sched.Step(); - ASSERT_FLOAT_NEAR(sched.GetLR(), expected[i], 1e-5f); + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), expected[i], 1e-5f); } } void TestStateRoundTrip() { std::cout << "[TC4] TestStateRoundTrip" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto warmup = LRScheduler::Create(opt, /*warmup_steps=*/3, /*start_factor=*/0.0f); + auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, /*total_iters=*/3); auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); - SequentialLR sched(opt, {warmup, step_lr}, {3}); + auto sched = LRScheduler::Create(opt, std::vector>{linear, step_lr}, std::vector{3}); - for (int i = 0; i < 5; ++i) sched.Step(); - StateDict saved = sched.State(); + for (int i = 0; i < 5; ++i) sched->Step(); + StateDict saved = sched->State(); auto opt2 = MakeDummyOptimizer(kBaseLR); - auto warmup2 = LRScheduler::Create(opt2, /*warmup_steps=*/3, /*start_factor=*/0.0f); + auto linear2 = LRScheduler::Create(opt2, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, /*total_iters=*/3); auto step_lr2 = LRScheduler::Create(opt2, /*step_size=*/3, /*gamma=*/0.5f); - SequentialLR sched2(opt2, {warmup2, step_lr2}, {3}); - sched2.LoadState(saved); + auto sched2 = LRScheduler::Create(opt2, 
std::vector>{linear2, step_lr2}, std::vector{3}); + sched2->LoadState(saved); - ASSERT_TRUE(sched2.LastStep() == sched.LastStep()); - ASSERT_FLOAT_NEAR(sched2.GetLR(), sched.GetLR(), kEps); + ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); + ASSERT_FLOAT_NEAR(sched2->GetLR(), sched->GetLR(), kEps); } void TestResumeConsistency() { @@ -94,10 +93,10 @@ void TestResumeConsistency() { constexpr int kN = 10, kK = 4; auto make_sched = [](std::shared_ptr opt) { - auto warmup = LRScheduler::Create(opt, /*warmup_steps=*/3, /*start_factor=*/0.0f); + auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, /*total_iters=*/3); auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); - return std::make_unique(opt, - std::vector>{warmup, step_lr}, + return LRScheduler::Create(opt, + std::vector>{linear, step_lr}, std::vector{3}); }; @@ -122,9 +121,9 @@ void TestResumeConsistency() { int main(int argc, char *argv[]) { google::InitGoogleLogging(argv[0]); std::cout << "=== SequentialLR Tests ===" << std::endl; - TestWarmupThenConstant(); - TestWarmupThenStepLR(); - TestWarmupThenStepThenConstant(); + TestLinearThenConstant(); + TestLinearThenStepLR(); + TestLinearThenStepThenConstant(); TestStateRoundTrip(); TestResumeConsistency(); if (g_fail_count == 0) { diff --git a/test/lr_scheduler/test_step_lr.cc b/test/lr_scheduler/test_step_lr.cc index 74cce52e..0b328c40 100644 --- a/test/lr_scheduler/test_step_lr.cc +++ b/test/lr_scheduler/test_step_lr.cc @@ -53,6 +53,20 @@ void TestGammaOne() { } } +void TestChainableAndClosedFormConsistency() { + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto chainable = LRScheduler::Create(opt_a, 3, 0.1f); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto closed_form = LRScheduler::Create(opt_b, 3, 0.1f); + + for (int epoch = 1; epoch <= 12; ++epoch) { + chainable->Step(); + closed_form->Step(epoch); + ASSERT_FLOAT_NEAR(chainable->GetLR(), closed_form->GetLR(), 1e-7f); + } +} + int main(int argc, 
char *argv[]) { google::InitGoogleLogging(argv[0]); std::cout << "=== Step Tests ===" << std::endl; @@ -61,6 +75,7 @@ int main(int argc, char *argv[]) { TestMultipleDecays(); TestPyTorchAlignment(); TestGammaOne(); + TestChainableAndClosedFormConsistency(); std::cout << "========================" << std::endl; if (g_fail_count == 0) { From 6823244bff3ec5dd853a1bc17e1dcf24d5887b5d Mon Sep 17 00:00:00 2001 From: kinorw Date: Sat, 7 Mar 2026 00:14:22 +0800 Subject: [PATCH 13/18] feat(lr_schedulers): add ChainedScheduler composite strategy --- CMakeLists.txt | 5 +- infini_train/include/lr_scheduler.h | 21 ++++ infini_train/src/lr_scheduler.cc | 52 ++++++++- test/lr_scheduler/test_chained_lr.cc | 164 +++++++++++++++++++++++++++ 4 files changed, 240 insertions(+), 2 deletions(-) create mode 100644 test/lr_scheduler/test_chained_lr.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 02c964f6..8b7fa054 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -214,4 +214,7 @@ add_executable(test_lambda_lr test/lr_scheduler/test_lambda_lr.cc) target_link_libraries(test_lambda_lr infini_train) add_executable(test_sequential_lr test/lr_scheduler/test_sequential_lr.cc) -target_link_libraries(test_sequential_lr infini_train) \ No newline at end of file +target_link_libraries(test_sequential_lr infini_train) + +add_executable(test_chained_lr test/lr_scheduler/test_chained_lr.cc) +target_link_libraries(test_chained_lr infini_train) \ No newline at end of file diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h index 464f3a82..05701df9 100644 --- a/infini_train/include/lr_scheduler.h +++ b/infini_train/include/lr_scheduler.h @@ -152,5 +152,26 @@ class SequentialLR : public LRScheduler { std::vector milestones_; }; +class ChainedScheduler : public LRScheduler { +public: + ChainedScheduler(std::shared_ptr optimizer, + std::vector> schedulers, + int64_t last_step = -1); + ~ChainedScheduler() override = default; + + void Step() override; + void 
InitialStep() override; + + StateDict State() const override; + void LoadState(const StateDict& state) override; + +protected: + float GetClosedFormLR() const override { return current_lr_; } + +private: + std::vector> schedulers_; +}; + + } // namespace lr_schedulers } // namespace infini_train \ No newline at end of file diff --git a/infini_train/src/lr_scheduler.cc b/infini_train/src/lr_scheduler.cc index 050ac2bf..474e9f39 100644 --- a/infini_train/src/lr_scheduler.cc +++ b/infini_train/src/lr_scheduler.cc @@ -4,6 +4,7 @@ #include "infini_train/include/optimizer.h" + namespace infini_train { LRScheduler::LRScheduler(std::shared_ptr optimizer, @@ -226,7 +227,7 @@ void SequentialLR::Step() { scheduler->Step(); } - ApplyLR(scheduler->GetLR()); + current_lr_ = optimizer_->GetLearningRate(); } StateDict SequentialLR::State() const { @@ -262,6 +263,55 @@ void SequentialLR::LoadState(const StateDict &state) { optimizer_->SetLearningRate(current_lr_); } +ChainedScheduler::ChainedScheduler(std::shared_ptr optimizer, + std::vector> schedulers, + int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), + schedulers_(std::move(schedulers)) {} + + +void ChainedScheduler::InitialStep() { + CHECK(!schedulers_.empty()) + << "ChainedScheduler requires at least one scheduler."; + + current_lr_ = optimizer_->GetLearningRate(); +} + + +void ChainedScheduler::Step() { + ++last_step_; + for (auto &sched : schedulers_) { + sched->Step(); + } + current_lr_ = optimizer_->GetLearningRate(); +} + +StateDict ChainedScheduler::State() const { + StateDict state = LRScheduler::State(); + for (size_t i = 0; i < schedulers_.size(); ++i) { + auto sub_state = schedulers_[i]->State(); + for (const auto& [key, value] : sub_state) { + state["scheduler_" + std::to_string(i) + "." 
+ key] = value; + } + } + return state; +} + +void ChainedScheduler::LoadState(const StateDict& state) { + LRScheduler::LoadState(state); + for (size_t i = 0; i < schedulers_.size(); ++i) { + StateDict sub_state; + std::string prefix = "scheduler_" + std::to_string(i) + "."; + for (const auto& [key, value] : state) { + if (key.substr(0, prefix.size()) == prefix) { + sub_state[key.substr(prefix.size())] = value; + } + } + if (!sub_state.empty()) { + schedulers_[i]->LoadState(sub_state); + } + } +} } // namespace lr_schedulers } // namespace infini_train \ No newline at end of file diff --git a/test/lr_scheduler/test_chained_lr.cc b/test/lr_scheduler/test_chained_lr.cc new file mode 100644 index 00000000..ccdbdff1 --- /dev/null +++ b/test/lr_scheduler/test_chained_lr.cc @@ -0,0 +1,164 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace{ +constexpr float kBaseLR = 0.1f; +} +// TC1: 单子调度器退化 +void TestSingleScheduler() { + std::cout << "[TC1] TestSingleScheduler" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto step_lr = std::make_shared(opt, 3, 0.5f); + auto sched = LRScheduler::Create(opt, std::vector>{step_lr}); + + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + sched->Step(); // step=1 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); +} + +// TC2: StepLR + LambdaLR 乘法叠加 +void TestMultiplicativeChain() { + std::cout << "[TC2] TestMultiplicativeChain" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/2, /*gamma=*/0.5f); + auto lambda_lr = LRScheduler::Create(opt, + [](int64_t step) { return 1.0f - 0.1f * step; }); + + auto sched = LRScheduler::Create(opt, std::vector>{step_lr, lambda_lr}); + + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.09f, kEps); + + sched->Step(); + 
ASSERT_FLOAT_NEAR(sched->GetLR(), 0.08f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.07f, kEps); +} + +// TC3: ConstantLR + StepLR 叠加 (无穿插声明) +void TestConstantPlusStep() { + std::cout << "[TC3] TestConstantPlusStep" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto constant = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/2); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.1f); + + auto sched = LRScheduler::Create(opt, std::vector>{constant, step_lr}); + + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.05f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.05f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.01f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.01f, kEps); +} + +// TC4: ConstantLR + StepLR 叠加(有穿插声明) +void TestConstantPlusStepDLC() { + std::cout << "[TC4] TestConstantPlusStepDLC" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto constant = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/2); + auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, /*total_iters=*/3); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.1f); + auto Lambda = LRScheduler::Create(opt, + [](int64_t step) { return 1.0f - 0.1f * step; }); + + auto sched = LRScheduler::Create(opt, std::vector>{constant, step_lr}); + + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.2f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.02f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.02f, kEps); +} + +// TC5: State/LoadState 往返 +void TestStateRoundTrip() { + std::cout << "[TC5] TestStateRoundTrip" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto step_lr = std::make_shared(opt, 2, 0.5f); + 
auto lambda_lr = std::make_shared(opt, + [](int64_t step) { return 1.0f - 0.05f * step; }); + auto sched = LRScheduler::Create(opt, std::vector>{step_lr, lambda_lr}); + + for (int i = 0; i < 5; ++i) sched->Step(); + StateDict saved = sched->State(); + + auto opt2 = MakeDummyOptimizer(kBaseLR); + auto step_lr2 = std::make_shared(opt2, 2, 0.5f); + auto lambda_lr2 = std::make_shared(opt2, + [](int64_t step) { return 1.0f - 0.05f * step; }); + auto sched2 = LRScheduler::Create(opt2, std::vector>{step_lr2, lambda_lr2}); + sched2->LoadState(saved); + + ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); + ASSERT_FLOAT_NEAR(sched2->GetLR(), sched->GetLR(), kEps); +} + +// TC6: resume 一致性 +void TestResumeConsistency() { + std::cout << "[TC6] TestResumeConsistency" << std::endl; + constexpr int kN = 10, kK = 4; + auto lambda_fn = [](int64_t step) { return 1.0f - 0.05f * step; }; + + auto make_sched = [&](std::shared_ptr opt) { + auto step_lr = std::make_shared(opt, 2, 0.5f); + auto lambda_lr = std::make_shared(opt, lambda_fn); + return LRScheduler::Create(opt, std::vector>{step_lr, lambda_lr}); + }; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = make_sched(opt_ref); + for (int i = 0; i < kN; ++i) sched_ref->Step(); + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = make_sched(opt_a); + for (int i = 0; i < kK; ++i) sched_a->Step(); + StateDict ckpt = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = make_sched(opt_b); + sched_b->LoadState(ckpt); + for (int i = 0; i < kN - kK; ++i) sched_b->Step(); + + ASSERT_FLOAT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); + ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== ChainedScheduler Tests ===" << std::endl; + TestSingleScheduler(); + TestMultiplicativeChain(); + TestConstantPlusStep(); + TestConstantPlusStepDLC(); + TestStateRoundTrip(); + 
TestResumeConsistency(); + std::cout << "========================" << std::endl; + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} \ No newline at end of file From fb9d997c834da78edc75b76198a0962c732e84c8 Mon Sep 17 00:00:00 2001 From: kinorw Date: Sun, 8 Mar 2026 15:43:14 +0800 Subject: [PATCH 14/18] feat(lr_scheduler): add scheduler factory for CLI integration - Add LRSchedulerConfig struct with parameters for all basic schedulers(constant, linear, step) - Add CreateLRScheduler() factory function - Support automatic warmup wrapping via SequentialLR when warmup_steps > 0 - Adapt test files --- infini_train/include/lr_scheduler.h | 21 +++++++ infini_train/src/lr_scheduler.cc | 45 ++++++++++++++ test/lr_scheduler/test_constant_lr.cc | 85 +++++++++++++++++++++++---- test/lr_scheduler/test_linear_lr.cc | 67 +++++++++++++++++---- test/lr_scheduler/test_step_lr.cc | 47 ++++++++++++--- 5 files changed, 236 insertions(+), 29 deletions(-) diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h index 05701df9..8aa2a387 100644 --- a/infini_train/include/lr_scheduler.h +++ b/infini_train/include/lr_scheduler.h @@ -17,6 +17,23 @@ using StateValue = std::variant>; using StateDict = std::unordered_map; +struct LRSchedulerConfig { + std::string type = "none"; + // ConstantLR + float constant_factor = 1.0f / 3.0f; + int constant_total_iters = 5; + // StepLR + int64_t step_size = 10; + float step_gamma = 0.1f; + // LinearLR + float linear_start_factor = 1.0f / 3.0f; + float linear_end_factor = 1.0f; + int linear_total_iters = 5; + // common + int64_t warmup_steps = 0; + int64_t total_iters = 0; +}; + class LRScheduler { public: template @@ -57,6 +74,10 @@ class LRScheduler { bool is_initial_ = false; }; +std::shared_ptr CreateLRScheduler( + std::shared_ptr optimizer, + const LRSchedulerConfig& config); + 
namespace lr_schedulers { class ConstantLR : public LRScheduler { diff --git a/infini_train/src/lr_scheduler.cc b/infini_train/src/lr_scheduler.cc index 474e9f39..4a99bd53 100644 --- a/infini_train/src/lr_scheduler.cc +++ b/infini_train/src/lr_scheduler.cc @@ -7,6 +7,51 @@ namespace infini_train { +std::shared_ptr CreateLRScheduler( + std::shared_ptr optimizer, + const LRSchedulerConfig& config) { + if (config.type == "none") { + return nullptr; + } + + auto create_main = [&](std::shared_ptr opt) + -> std::shared_ptr { + if (config.type == "constant") { + return LRScheduler::Create( + opt, config.constant_factor, config.constant_total_iters); + } + if (config.type == "step") { + return LRScheduler::Create( + opt, config.step_size, config.step_gamma); + } + if (config.type == "linear") { + return LRScheduler::Create( + opt, config.linear_start_factor, config.linear_end_factor, + config.linear_total_iters); + } + LOG(FATAL) << "Unsupported LR scheduler type: " << config.type; + return nullptr; + }; + + if (config.warmup_steps <= 0) { + return create_main(optimizer); + } + + auto warmup_scheduler = LRScheduler::Create( + optimizer, + /*start_factor=*/1e-8f, + /*end_factor=*/1.0f, + /*total_iters=*/config.warmup_steps); + + auto main_scheduler = create_main(optimizer); + + return LRScheduler::Create( + optimizer, + std::vector>{warmup_scheduler, main_scheduler}, + std::vector{config.warmup_steps}); + +}; + LRScheduler::LRScheduler(std::shared_ptr optimizer, int64_t last_step) : optimizer_(std::move(optimizer)), diff --git a/test/lr_scheduler/test_constant_lr.cc b/test/lr_scheduler/test_constant_lr.cc index 6d08af50..aea6f6b3 100644 --- a/test/lr_scheduler/test_constant_lr.cc +++ b/test/lr_scheduler/test_constant_lr.cc @@ -11,7 +11,12 @@ constexpr float kBaseLR = 0.1f; void TestInitialState() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/3); + LRSchedulerConfig config = { + .type = "constant", + 
.constant_factor = 0.5f, + .constant_total_iters = 3, + }; + auto sched = CreateLRScheduler(opt, config); ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); ASSERT_TRUE(sched->LastStep() == 0); ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f); @@ -19,7 +24,13 @@ void TestInitialState() { void TestFirstStepAppliesFactor() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 0.5f, 3); + LRSchedulerConfig config = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 3, + }; + + auto sched = CreateLRScheduler(opt, config); sched->Step(); // last_step_ = 0 ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f); @@ -28,7 +39,12 @@ void TestFirstStepAppliesFactor() { void TestWithinTotalIters() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 0.5f, 3); + LRSchedulerConfig config = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 3, + }; + auto sched = CreateLRScheduler(opt, config); for (int i = 0; i < 2; ++i) sched->Step(); // last_step_ = 2, still < 3 ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); @@ -36,7 +52,12 @@ void TestWithinTotalIters() { void TestBeyondTotalIters() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 0.5f, 3); + LRSchedulerConfig config = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 3, + }; + auto sched = CreateLRScheduler(opt, config); for (int i = 0; i < 10; ++i) sched->Step(); ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR); @@ -45,7 +66,12 @@ void TestBeyondTotalIters() { void TestPyTorchAlignment() { const std::vector expected = {0.05f, 0.05f, 0.1f, 0.1f, 0.1f}; auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 0.5f, 3); + LRSchedulerConfig config = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 3, + }; + auto sched = CreateLRScheduler(opt, config); for 
(size_t i = 0; i < expected.size(); ++i) { sched->Step(); ASSERT_FLOAT_EQ(sched->GetLR(), expected[i]); @@ -54,12 +80,22 @@ void TestPyTorchAlignment() { void TestStateRoundTrip() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 0.5f, 5); + LRSchedulerConfig config = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 5, + }; + auto sched = CreateLRScheduler(opt, config); for (int i = 0; i < 3; ++i) sched->Step(); StateDict saved = sched->State(); auto opt2 = MakeDummyOptimizer(kBaseLR); - auto sched2 = LRScheduler::Create(opt2, 0.5f, 5); + LRSchedulerConfig config2 = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 5, + }; + auto sched2 = CreateLRScheduler(opt2, config2); sched2->LoadState(saved); ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); @@ -72,16 +108,31 @@ void TestResumeConsistency() { constexpr int kK = 3; auto opt_ref = MakeDummyOptimizer(kBaseLR); - auto sched_ref = LRScheduler::Create(opt_ref, 0.5f, 5); + LRSchedulerConfig config_ref = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 5, + }; + auto sched_ref = CreateLRScheduler(opt_ref, config_ref); for (int i = 0; i < kN; ++i) sched_ref->Step(); auto opt_a = MakeDummyOptimizer(kBaseLR); - auto sched_a = LRScheduler::Create(opt_a, 0.5f, 5); + LRSchedulerConfig config_a = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 5, + }; + auto sched_a = CreateLRScheduler(opt_a, config_a); for (int i = 0; i < kK; ++i) sched_a->Step(); StateDict ckpt = sched_a->State(); auto opt_b = MakeDummyOptimizer(kBaseLR); - auto sched_b = LRScheduler::Create(opt_b, 0.5f, 5); + LRSchedulerConfig config_b = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 5, + }; + auto sched_b = CreateLRScheduler(opt_b, config_b); sched_b->LoadState(ckpt); for (int i = 0; i < kN - kK; ++i) sched_b->Step(); @@ -91,10 +142,20 @@ void TestResumeConsistency() { void 
TestChainableAndClosedFormConsistency() { auto opt_a = MakeDummyOptimizer(kBaseLR); - auto chainable = LRScheduler::Create(opt_a, 0.5f, 5); + LRSchedulerConfig config_a = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 5, + }; + auto chainable = CreateLRScheduler(opt_a, config_a); auto opt_b = MakeDummyOptimizer(kBaseLR); - auto closed_form = LRScheduler::Create(opt_b, 0.5f, 5); + LRSchedulerConfig config_b = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 5, + }; + auto closed_form = CreateLRScheduler(opt_b, config_b); for (int epoch = 1; epoch <= 12; ++epoch) { chainable->Step(); diff --git a/test/lr_scheduler/test_linear_lr.cc b/test/lr_scheduler/test_linear_lr.cc index 6379cc46..a7391a50 100644 --- a/test/lr_scheduler/test_linear_lr.cc +++ b/test/lr_scheduler/test_linear_lr.cc @@ -10,14 +10,26 @@ constexpr float kBaseLR = 0.1f; void TestFirstStepFromZero() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create( - opt, /*start_factor=*/0.2f, /*end_factor=*/1.0f, /*total_iters=*/5); + LRSchedulerConfig config = { + .type = "linear", + .linear_start_factor = 0.2f, + .linear_end_factor = 1.0f, + .linear_total_iters = 5, + }; + + auto sched = CreateLRScheduler(opt, config); ASSERT_FLOAT_EQ(sched->GetLR(), 0.02f); } void TestMidpointLR() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 0.2f, 1.0f, 5); + LRSchedulerConfig config = { + .type = "linear", + .linear_start_factor = 0.2f, + .linear_end_factor = 1.0f, + .linear_total_iters = 5, + }; + auto sched = CreateLRScheduler(opt, config); for (int i = 0; i < 3; ++i) sched->Step(); // last_step_=3 -> 0.1*(0.2 + 0.8*3/5) = 0.068 ASSERT_FLOAT_EQ(sched->GetLR(), 0.068f); @@ -25,7 +37,13 @@ void TestMidpointLR() { void TestWarmupEnd() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 0.2f, 1.0f, 5); + LRSchedulerConfig config = { + .type = "linear", + .linear_start_factor = 0.2f, 
+ .linear_end_factor = 1.0f, + .linear_total_iters = 5, + }; + auto sched = CreateLRScheduler(opt, config); for (int i = 0; i < 5; ++i) sched->Step(); // last_step_ >= total_iters -> base_lr * end_factor ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); @@ -33,15 +51,26 @@ void TestWarmupEnd() { void TestBeyondWarmup() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 0.2f, 1.0f, 5); + LRSchedulerConfig config = { + .type = "linear", + .linear_start_factor = 0.2f, + .linear_end_factor = 1.0f, + .linear_total_iters = 5, + }; + auto sched = CreateLRScheduler(opt, config); for (int i = 0; i < 20; ++i) sched->Step(); ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); } void TestCustomStartFactor() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create( - opt, /*start_factor=*/0.25f, /*end_factor=*/1.0f, /*total_iters=*/4); + LRSchedulerConfig config = { + .type = "linear", + .linear_start_factor = 0.25f, + .linear_end_factor = 1.0f, + .linear_total_iters = 4, + }; + auto sched = CreateLRScheduler(opt, config); sched->Step(); // last_step_=1, lr=0.1*(0.25+0.75*1/4)=0.04375 ASSERT_FLOAT_NEAR(sched->GetLR(), 0.04375f, 1e-6f); sched->Step(); // last_step_=2, lr=0.1*(0.25+0.75*2/4)=0.0625 @@ -52,7 +81,13 @@ void TestPyTorchAlignment() { const std::vector expected = { 0.036f, 0.052f, 0.068f, 0.084f, 0.1f, 0.1f, 0.1f}; auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 0.2f, 1.0f, 5); + LRSchedulerConfig config = { + .type = "linear", + .linear_start_factor = 0.2f, + .linear_end_factor = 1.0f, + .linear_total_iters = 5, + }; + auto sched = CreateLRScheduler(opt, config); for (size_t i = 0; i < expected.size(); ++i) { sched->Step(); ASSERT_FLOAT_NEAR(sched->GetLR(), expected[i], 1e-7f); @@ -61,10 +96,22 @@ void TestPyTorchAlignment() { void TestChainableAndClosedFormConsistency() { auto opt_a = MakeDummyOptimizer(kBaseLR); - auto chainable = LRScheduler::Create(opt_a, 0.2f, 1.0f, 5); + LRSchedulerConfig config_a 
= { + .type = "linear", + .linear_start_factor = 0.2f, + .linear_end_factor = 1.0f, + .linear_total_iters = 5, + }; + auto chainable = CreateLRScheduler(opt_a, config_a); auto opt_b = MakeDummyOptimizer(kBaseLR); - auto closed_form = LRScheduler::Create(opt_b, 0.2f, 1.0f, 5); + LRSchedulerConfig config_b = { + .type = "linear", + .linear_start_factor = 0.2f, + .linear_end_factor = 1.0f, + .linear_total_iters = 5, + }; + auto closed_form = CreateLRScheduler(opt_b, config_b); for (int epoch = 1; epoch <= 10; ++epoch) { chainable->Step(); diff --git a/test/lr_scheduler/test_step_lr.cc b/test/lr_scheduler/test_step_lr.cc index 0b328c40..1c1bb515 100644 --- a/test/lr_scheduler/test_step_lr.cc +++ b/test/lr_scheduler/test_step_lr.cc @@ -10,7 +10,12 @@ constexpr float kBaseLR = 0.1f; void TestWithinFirstPeriod() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.1f); + LRSchedulerConfig config = { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }; + auto sched = CreateLRScheduler(opt, config); for (int i = 0; i < 2; ++i) { sched->Step(); ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); // last_step 1,2 → 指数 0 @@ -19,7 +24,12 @@ void TestWithinFirstPeriod() { void TestFirstDecay() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.1f); + LRSchedulerConfig config = { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }; + auto sched = CreateLRScheduler(opt, config); for (int i = 0; i < 3; ++i) sched->Step(); // last_step=3, 3//3=1 → 0.1^1 = 0.1 → lr=0.01 ASSERT_FLOAT_EQ(sched->GetLR(), 0.01f); @@ -27,7 +37,12 @@ void TestFirstDecay() { void TestMultipleDecays() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 3, 0.1f); + LRSchedulerConfig config = { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }; + auto sched = CreateLRScheduler(opt, config); for (int i = 0; i < 6; ++i) sched->Step(); // 
last_step=6, 6//3=2 → 0.1^2 = 0.01 → lr=0.001 ASSERT_FLOAT_NEAR(sched->GetLR(), 0.001f, 1e-7f); @@ -37,7 +52,12 @@ void TestPyTorchAlignment() { const std::vector expected = { 0.1f, 0.1f, 0.01f, 0.01f, 0.01f, 0.001f, 0.001f}; auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 3, 0.1f); + LRSchedulerConfig config = { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }; + auto sched = CreateLRScheduler(opt, config); for (size_t i = 0; i < expected.size(); ++i) { sched->Step(); ASSERT_FLOAT_NEAR(sched->GetLR(), expected[i], 1e-7f); @@ -46,7 +66,12 @@ void TestPyTorchAlignment() { void TestGammaOne() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, 3, 1.0f); + LRSchedulerConfig config = { + .type = "step", + .step_size = 3, + .step_gamma = 1.0f, + }; + auto sched = CreateLRScheduler(opt, config); for (int i = 0; i < 20; ++i) { sched->Step(); ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); @@ -55,10 +80,18 @@ void TestGammaOne() { void TestChainableAndClosedFormConsistency() { auto opt_a = MakeDummyOptimizer(kBaseLR); - auto chainable = LRScheduler::Create(opt_a, 3, 0.1f); + auto chainable = CreateLRScheduler(opt_a, { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }); auto opt_b = MakeDummyOptimizer(kBaseLR); - auto closed_form = LRScheduler::Create(opt_b, 3, 0.1f); + auto closed_form = CreateLRScheduler(opt_b, { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }); for (int epoch = 1; epoch <= 12; ++epoch) { chainable->Step(); From 7a29a61e3c89d2a6ced78e353b510b2963aa141d Mon Sep 17 00:00:00 2001 From: kinorw Date: Sun, 8 Mar 2026 16:18:12 +0800 Subject: [PATCH 15/18] feat(lr_scheduler): add scheduler factory for CLI integration (Sequential, Chained, and Lambda) --- infini_train/include/lr_scheduler.h | 7 ++ infini_train/src/lr_scheduler.cc | 27 +++++ test/lr_scheduler/test_chained_lr.cc | 63 ++++++++--- test/lr_scheduler/test_lambda_lr.cc | 40 +++++-- 
test/lr_scheduler/test_sequential_lr.cc | 135 +++++++++++++++++++----- 5 files changed, 224 insertions(+), 48 deletions(-) diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h index 8aa2a387..d80e706e 100644 --- a/infini_train/include/lr_scheduler.h +++ b/infini_train/include/lr_scheduler.h @@ -29,6 +29,13 @@ struct LRSchedulerConfig { float linear_start_factor = 1.0f / 3.0f; float linear_end_factor = 1.0f; int linear_total_iters = 5; + // LambdaLR + std::function lambda_fn = nullptr; + // SequentialLR + std::vector sequential_configs; + std::vector sequential_milestones; + // ChainedScheduler + std::vector chained_configs; // common int64_t warmup_steps = 0; int64_t total_iters = 0; diff --git a/infini_train/src/lr_scheduler.cc b/infini_train/src/lr_scheduler.cc index 4a99bd53..9fc26f17 100644 --- a/infini_train/src/lr_scheduler.cc +++ b/infini_train/src/lr_scheduler.cc @@ -29,6 +29,33 @@ std::shared_ptr CreateLRScheduler( opt, config.linear_start_factor, config.linear_end_factor, config.linear_total_iters); } + if (config.type == "lambda") { + return LRScheduler::Create( + opt, config.lambda_fn); + } + if (config.type == "sequential") { + std::vector> schedulers; + std::vector milestones = config.sequential_milestones; + for (const auto& sub_config : config.sequential_configs) { + auto sub_sched = CreateLRScheduler(opt, sub_config); + if (sub_sched) { + schedulers.push_back(sub_sched); + } + } + return LRScheduler::Create( + opt, schedulers, milestones); + } + if (config.type == "chained") { + std::vector> schedulers; + for (const auto& sub_config : config.chained_configs) { + auto sub_sched = CreateLRScheduler(opt, sub_config); + if (sub_sched) { + schedulers.push_back(sub_sched); + } + } + return LRScheduler::Create( + opt, schedulers); + } LOG(FATAL) << "Unsupported LR scheduler type: " << config.type; return nullptr; }; diff --git a/test/lr_scheduler/test_chained_lr.cc b/test/lr_scheduler/test_chained_lr.cc index 
ccdbdff1..4c451883 100644 --- a/test/lr_scheduler/test_chained_lr.cc +++ b/test/lr_scheduler/test_chained_lr.cc @@ -11,7 +11,11 @@ constexpr float kBaseLR = 0.1f; void TestSingleScheduler() { std::cout << "[TC1] TestSingleScheduler" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto step_lr = std::make_shared(opt, 3, 0.5f); + auto step_lr = CreateLRScheduler(opt, { + .type = "step", + .step_size = 3, + .step_gamma = 0.5f, + }); auto sched = LRScheduler::Create(opt, std::vector>{step_lr}); ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); @@ -23,11 +27,17 @@ void TestSingleScheduler() { void TestMultiplicativeChain() { std::cout << "[TC2] TestMultiplicativeChain" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto step_lr = LRScheduler::Create(opt, /*step_size=*/2, /*gamma=*/0.5f); - auto lambda_lr = LRScheduler::Create(opt, - [](int64_t step) { return 1.0f - 0.1f * step; }); - - auto sched = LRScheduler::Create(opt, std::vector>{step_lr, lambda_lr}); + auto sched = CreateLRScheduler(opt, { + .type = "chained", + .chained_configs = {{ + .type = "step", + .step_size = 2, + .step_gamma = 0.5f, + },{ + .type = "lambda", + .lambda_fn = [](int64_t step) { return 1.0f - 0.1f * step; }, + }}, + }); ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); @@ -45,10 +55,18 @@ void TestMultiplicativeChain() { void TestConstantPlusStep() { std::cout << "[TC3] TestConstantPlusStep" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto constant = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/2); - auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.1f); - - auto sched = LRScheduler::Create(opt, std::vector>{constant, step_lr}); + auto sched = CreateLRScheduler(opt, { + .type = "chained", + .chained_configs = {{ + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 2, + }, { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }}, + }); ASSERT_FLOAT_NEAR(sched->GetLR(), 0.05f, kEps); @@ -69,11 +87,26 @@ void 
TestConstantPlusStep() { void TestConstantPlusStepDLC() { std::cout << "[TC4] TestConstantPlusStepDLC" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto constant = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/2); - auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, /*total_iters=*/3); - auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.1f); - auto Lambda = LRScheduler::Create(opt, - [](int64_t step) { return 1.0f - 0.1f * step; }); + auto constant = CreateLRScheduler(opt, { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 2, + }); + auto linear = CreateLRScheduler(opt, { + .type = "linear", + .linear_start_factor = 1e-8f, + .linear_end_factor = 1.0f, + .linear_total_iters = 3, + }); + auto step_lr = CreateLRScheduler(opt, { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }); + auto Lambda = CreateLRScheduler(opt, { + .type = "lambda", + .lambda_fn = [](int64_t step) { return 1.0f - 0.1f * step; }, + }); auto sched = LRScheduler::Create(opt, std::vector>{constant, step_lr}); diff --git a/test/lr_scheduler/test_lambda_lr.cc b/test/lr_scheduler/test_lambda_lr.cc index 0a99f132..0e977ba0 100644 --- a/test/lr_scheduler/test_lambda_lr.cc +++ b/test/lr_scheduler/test_lambda_lr.cc @@ -10,7 +10,10 @@ constexpr float kBaseLR = 0.1f; void TestIdentityLambda() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, [](int64_t) { return 1.0f; }); + auto sched = CreateLRScheduler(opt, { + .type = "lambda", + .lambda_fn = [](int64_t) { return 1.0f; }, + }); // 构造器内 Step() → last_step_=0, lr = 0.1 * 1.0 = 0.1 ASSERT_TRUE(sched->LastStep() == 0); ASSERT_FLOAT_NEAR(sched->GetLR(), kBaseLR, kEps); @@ -19,7 +22,10 @@ void TestIdentityLambda() { void TestLinearDecayLambda() { auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, [](int64_t step) { return 1.0f - step * 0.1f; }); + auto sched = CreateLRScheduler(opt, { + 
.type = "lambda", + .lambda_fn = [](int64_t step) { return 1.0f - step * 0.1f; }, + }); // step=0, lambda(0)=1.0, lr=0.1 ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); @@ -36,8 +42,9 @@ void TestLinearDecayLambda() { void TestPyTorchAlignment() { // PyTorch: LambdaLR(opt, lr_lambda=lambda epoch: 0.95**epoch) auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, [](int64_t step) { - return static_cast(std::pow(0.95, step)); + auto sched = CreateLRScheduler(opt, { + .type = "lambda", + .lambda_fn = [](int64_t step) { return static_cast(std::pow(0.95, step)); }, }); // step=0, lr = 0.1 * 0.95^0 = 0.1 ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); @@ -52,12 +59,18 @@ void TestPyTorchAlignment() { void TestStateRoundTrip() { auto lambda_fn = [](int64_t step) { return 1.0f - step * 0.05f; }; auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = LRScheduler::Create(opt, lambda_fn); + auto sched = CreateLRScheduler(opt, { + .type = "lambda", + .lambda_fn = lambda_fn, + }); for (int i = 0; i < 5; ++i) sched->Step(); StateDict saved = sched->State(); auto opt2 = MakeDummyOptimizer(kBaseLR); - auto sched2 = LRScheduler::Create(opt2, lambda_fn); // same lambda + auto sched2 = CreateLRScheduler(opt2, { + .type = "lambda", + .lambda_fn = lambda_fn, + }); // same lambda sched2->LoadState(saved); ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); @@ -69,16 +82,25 @@ void TestResumeConsistency() { constexpr int kN = 10, kK = 4; auto opt_ref = MakeDummyOptimizer(kBaseLR); - auto sched_ref = LRScheduler::Create(opt_ref, lambda_fn); + auto sched_ref = CreateLRScheduler(opt_ref, { + .type = "lambda", + .lambda_fn = lambda_fn, + }); for (int i = 0; i < kN; ++i) sched_ref->Step(); auto opt_a = MakeDummyOptimizer(kBaseLR); - auto sched_a = LRScheduler::Create(opt_a, lambda_fn); + auto sched_a = CreateLRScheduler(opt_a, { + .type = "lambda", + .lambda_fn = lambda_fn, + }); for (int i = 0; i < kK; ++i) sched_a->Step(); StateDict ckpt = sched_a->State(); auto 
opt_b = MakeDummyOptimizer(kBaseLR); - auto sched_b = LRScheduler::Create(opt_b, lambda_fn); + auto sched_b = CreateLRScheduler(opt_b, { + .type = "lambda", + .lambda_fn = lambda_fn, + }); sched_b->LoadState(ckpt); for (int i = 0; i < kN - kK; ++i) sched_b->Step(); diff --git a/test/lr_scheduler/test_sequential_lr.cc b/test/lr_scheduler/test_sequential_lr.cc index d05fc7d1..2d866540 100644 --- a/test/lr_scheduler/test_sequential_lr.cc +++ b/test/lr_scheduler/test_sequential_lr.cc @@ -11,10 +11,25 @@ void TestLinearThenConstant() { std::cout << "[TC1] TestLinearThenConstant" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8, /*end_factor=*/1.0f, /*total_iters=*/3); - auto constant = LRScheduler::Create(opt, /*factor=*/1.0f, /*total_iters=*/100); - auto sched = LRScheduler::Create(opt, std::vector>{linear, constant}, std::vector{3}); - + LRSchedulerConfig linear_config = { + .type = "linear", + .linear_start_factor = 1e-8f, + .linear_end_factor = 1.0f, + .linear_total_iters = 3, + }; + auto linear = CreateLRScheduler(opt, linear_config); + LRSchedulerConfig constant_config = { + .type = "constant", + .constant_factor = 1.0f, + .constant_total_iters = 100, + }; + auto constant = CreateLRScheduler(opt, constant_config); + auto sched = CreateLRScheduler(opt, { + .type = "sequential", + .sequential_configs = {linear_config, constant_config}, + .sequential_milestones = {3}, + }); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.0f, kEps); sched->Step(); // global=1, warmup step=1, lr=0.1*(1/3) @@ -34,10 +49,25 @@ void TestLinearThenStepLR() { std::cout << "[TC2] TestLinearThenStepLR" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, /*total_iters=*/3); - auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); + LRSchedulerConfig linear_config = { + .type = "linear", + .linear_start_factor = 1e-8f, + 
.linear_end_factor = 1.0f, + .linear_total_iters = 3, + }; + auto linear = CreateLRScheduler(opt, linear_config); + LRSchedulerConfig step_config = { + .type = "step", + .step_size = 3, + .step_gamma = 0.5f, + }; + auto step_lr = CreateLRScheduler(opt, step_config); - auto sched = LRScheduler::Create(opt, std::vector>{linear, step_lr}, std::vector{3}); + auto sched = CreateLRScheduler(opt, { + .type = "sequential", + .sequential_configs = {linear_config, step_config}, + .sequential_milestones = {3}, + }); sched->Step(); // global=1 sched->Step(); // global=2 @@ -55,11 +85,31 @@ void TestLinearThenStepThenConstant(){ std::cout << "[TC3] TestLinearThenStepThenConstant" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, /*total_iters=*/3); - auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); - auto constant = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/2); + LRSchedulerConfig linear_config = { + .type = "linear", + .linear_start_factor = 1e-8f, + .linear_end_factor = 1.0f, + .linear_total_iters = 3, + }; + auto linear = CreateLRScheduler(opt, linear_config); + LRSchedulerConfig step_config = { + .type = "step", + .step_size = 3, + .step_gamma = 0.5f, + }; + auto step_lr = CreateLRScheduler(opt, step_config); + LRSchedulerConfig constant_config = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 2, + }; + auto constant = CreateLRScheduler(opt, constant_config); - auto sched = LRScheduler::Create(opt, std::vector>{linear, step_lr, constant}, std::vector{3, 6}); + auto sched = CreateLRScheduler(opt, { + .type = "sequential", + .sequential_configs = {linear_config, step_config, constant_config}, + .sequential_milestones = {3, 6}, + }); const std::vector expected = { 0.033333f, 0.066667f, 0.1f, 0.1f, 0.1f, 0.05f, 0.05f, 0.1f, 0.1f, 0.1f}; for (size_t i = 0; i < expected.size(); ++i) { @@ -71,17 +121,46 @@ void 
TestLinearThenStepThenConstant(){ void TestStateRoundTrip() { std::cout << "[TC4] TestStateRoundTrip" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, /*total_iters=*/3); - auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); - auto sched = LRScheduler::Create(opt, std::vector>{linear, step_lr}, std::vector{3}); - + LRSchedulerConfig linear_config = { + .type = "linear", + .linear_start_factor = 1e-8f, + .linear_end_factor = 1.0f, + .linear_total_iters = 3, + }; + auto linear = CreateLRScheduler(opt, linear_config); + LRSchedulerConfig step_config = { + .type = "step", + .step_size = 3, + .step_gamma = 0.5f, + }; + auto step_lr = CreateLRScheduler(opt, step_config); + auto sched = CreateLRScheduler(opt, { + .type = "sequential", + .sequential_configs = {linear_config, step_config}, + .sequential_milestones = {3}, + }); for (int i = 0; i < 5; ++i) sched->Step(); StateDict saved = sched->State(); auto opt2 = MakeDummyOptimizer(kBaseLR); - auto linear2 = LRScheduler::Create(opt2, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, /*total_iters=*/3); - auto step_lr2 = LRScheduler::Create(opt2, /*step_size=*/3, /*gamma=*/0.5f); - auto sched2 = LRScheduler::Create(opt2, std::vector>{linear2, step_lr2}, std::vector{3}); + LRSchedulerConfig linear_config2 = { + .type = "linear", + .linear_start_factor = 1e-8f, + .linear_end_factor = 1.0f, + .linear_total_iters = 3, + }; + auto linear2 = CreateLRScheduler(opt2, linear_config2); + LRSchedulerConfig step_config2 = { + .type = "step", + .step_size = 3, + .step_gamma = 0.5f, + }; + auto step_lr2 = CreateLRScheduler(opt2, step_config2); + auto sched2 = CreateLRScheduler(opt2, { + .type = "sequential", + .sequential_configs = {linear_config2, step_config2}, + .sequential_milestones = {3}, + }); sched2->LoadState(saved); ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); @@ -93,13 +172,21 @@ void TestResumeConsistency() { 
constexpr int kN = 10, kK = 4; auto make_sched = [](std::shared_ptr opt) { - auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, /*total_iters=*/3); - auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); - return LRScheduler::Create(opt, - std::vector>{linear, step_lr}, - std::vector{3}); + return CreateLRScheduler(opt, { + .type = "sequential", + .sequential_configs = {{ + .type = "linear", + .linear_start_factor = 1e-8f, + .linear_end_factor = 1.0f, + .linear_total_iters = 3, + }, { + .type = "step", + .step_size = 3, + .step_gamma = 0.5f, + }}, + .sequential_milestones = {3}, + }); }; - auto opt_ref = MakeDummyOptimizer(kBaseLR); auto sched_ref = make_sched(opt_ref); for (int i = 0; i < kN; ++i) sched_ref->Step(); From b64566e88bb45a6741f5c4455b57211c6ef6c829 Mon Sep 17 00:00:00 2001 From: kinorw Date: Sun, 8 Mar 2026 16:54:39 +0800 Subject: [PATCH 16/18] feat(lr_scheduler): add warmup start_factor and end_factor , remove common total_iters --- infini_train/include/lr_scheduler.h | 5 +++-- infini_train/src/lr_scheduler.cc | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h index d80e706e..c9772ac1 100644 --- a/infini_train/include/lr_scheduler.h +++ b/infini_train/include/lr_scheduler.h @@ -36,9 +36,10 @@ struct LRSchedulerConfig { std::vector sequential_milestones; // ChainedScheduler std::vector chained_configs; - // common + // warmup int64_t warmup_steps = 0; - int64_t total_iters = 0; + float warmup_start_factor = 1.0f / 3.0f; + float warmup_end_factor = 1.0f; }; class LRScheduler { diff --git a/infini_train/src/lr_scheduler.cc b/infini_train/src/lr_scheduler.cc index 9fc26f17..bf834886 100644 --- a/infini_train/src/lr_scheduler.cc +++ b/infini_train/src/lr_scheduler.cc @@ -66,8 +66,8 @@ std::shared_ptr CreateLRScheduler( auto warmup_scheduler = LRScheduler::Create( optimizer, - /*start_factor=*/1e-8f, - 
/*end_factor=*/1.0f, + /*start_factor=*/config.warmup_start_factor, + /*end_factor=*/config.warmup_end_factor, /*total_iters=*/config.warmup_steps); auto main_scheduler = create_main(optimizer); From 3a7abb488090b781bd176c6f966e3d36488919bb Mon Sep 17 00:00:00 2001 From: kinorw Date: Sun, 8 Mar 2026 16:54:44 +0800 Subject: [PATCH 17/18] refactor(gpt2,llama3): integrate scheduler into training loop - Add gflags: --lr_scheduler, --warmup_steps, --step_size, --gamma, --start_factor, --end_factor, --lr_total_iters, --total_steps - Replace nullptr scheduler with factory-created scheduler - Move scheduler.Step() after optimizer.Step() in both DP and PP paths - Replace hardcoded FLAGS_learning_rate in log with scheduler->GetLR() --- example/gpt2/main.cc | 39 +++++++++++++++++++++++++++++++++------ example/llama3/main.cc | 39 +++++++++++++++++++++++++++++++++------ 2 files changed, 66 insertions(+), 12 deletions(-) diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index e13956b0..2303b80d 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -56,6 +56,17 @@ DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations"); DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); +// lr scheduler +DEFINE_string(lr_scheduler, "none", + "Learning rate scheduler type: none|constant|step|linear"); +DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)"); +DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)"); +DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)"); +DEFINE_int64(step_size, 30, "StepLR: period of learning rate decay"); +DEFINE_double(gamma, 0.1, "StepLR: multiplicative factor of lr decay"); +DEFINE_double(start_factor, 0.333333, "LinearLR: starting 
multiplicative factor"); +DEFINE_double(end_factor, 1.0, "LinearLR: ending multiplicative factor"); +DEFINE_int64(lr_total_iters, 5, "ConstantLR/LinearLR: total iterations for the scheduler"); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -257,7 +268,6 @@ void Train(const nn::parallel::Rank &rank) { // auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate); auto optimizer_creator = optimizers::SGD::Create(FLAGS_learning_rate); std::shared_ptr optimizer = nullptr; - std::shared_ptr scheduler = nullptr; if (FLAGS_use_distributed_optimizer) { auto model_chunks = (pp_world_size > 1) @@ -269,9 +279,19 @@ void Train(const nn::parallel::Rank &rank) { optimizer = optimizer_creator(model->Parameters()); } - if (scheduler) { - scheduler->Step(); - } + LRSchedulerConfig sched_config; + sched_config.type = FLAGS_lr_scheduler; + sched_config.warmup_steps = FLAGS_warmup_steps; + sched_config.warmup_start_factor = static_cast(FLAGS_warmup_start_factor); + sched_config.warmup_end_factor = static_cast(FLAGS_warmup_end_factor); + sched_config.step_size = FLAGS_step_size; + sched_config.step_gamma = static_cast(FLAGS_gamma); + sched_config.linear_start_factor = static_cast(FLAGS_start_factor); + sched_config.linear_end_factor = static_cast(FLAGS_end_factor); + sched_config.constant_factor = static_cast(FLAGS_start_factor); // 复用 + sched_config.constant_total_iters = FLAGS_lr_total_iters; + sched_config.linear_total_iters = FLAGS_lr_total_iters; + auto scheduler = CreateLRScheduler(optimizer,sched_config); auto train_iter = train_loader.begin(); std::shared_ptr loss_fn @@ -359,6 +379,9 @@ void Train(const nn::parallel::Rank &rank) { } optimizer->Step(); + if (scheduler) { + scheduler->Step(); + } } else { auto [x, y] = *train_iter; // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below @@ -368,6 
+391,9 @@ void Train(const nn::parallel::Rank &rank) { y = std::make_shared(y->To(device)); lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); + if (scheduler) { + scheduler->Step(); + } } if (ddp_world_size > 1) { @@ -383,10 +409,11 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsLastRank()) { size_t used_mb = 0, reserved_mb = 0; std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); - + const float current_lr = scheduler ? scheduler->GetLR() + : static_cast(FLAGS_learning_rate); LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", - step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, + step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 6c9bffcd..f81fca2e 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -55,6 +55,17 @@ DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations"); DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); +// lr scheduler +DEFINE_string(lr_scheduler, "none", + "Learning rate scheduler type: none|constant|step|linear"); +DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)"); +DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)"); +DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)"); +DEFINE_int64(step_size, 30, "StepLR: period of learning rate decay"); +DEFINE_double(gamma, 0.1, "StepLR: multiplicative factor of lr decay"); +DEFINE_double(start_factor, 
0.333333, "LinearLR: starting multiplicative factor"); +DEFINE_double(end_factor, 1.0, "LinearLR: ending multiplicative factor"); +DEFINE_int64(lr_total_iters, 5, "ConstantLR/LinearLR: total iterations for the scheduler"); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -236,7 +247,6 @@ void Train(const nn::parallel::Rank &rank) { // auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate); auto optimizer_creator = optimizers::Adam::Create(FLAGS_learning_rate); std::shared_ptr optimizer = nullptr; - std::shared_ptr scheduler = nullptr; if (FLAGS_use_distributed_optimizer) { auto model_chunks = (pp_world_size > 1) @@ -248,9 +258,19 @@ void Train(const nn::parallel::Rank &rank) { optimizer = optimizer_creator(model->Parameters()); } - if (scheduler){ - scheduler->Step(); - } + LRSchedulerConfig sched_config; + sched_config.type = FLAGS_lr_scheduler; + sched_config.warmup_steps = FLAGS_warmup_steps; + sched_config.warmup_start_factor = static_cast(FLAGS_warmup_start_factor); + sched_config.warmup_end_factor = static_cast(FLAGS_warmup_end_factor); + sched_config.step_size = FLAGS_step_size; + sched_config.step_gamma = static_cast(FLAGS_gamma); + sched_config.linear_start_factor = static_cast(FLAGS_start_factor); + sched_config.linear_end_factor = static_cast(FLAGS_end_factor); + sched_config.constant_factor = static_cast(FLAGS_start_factor); // 复用 + sched_config.constant_total_iters = FLAGS_lr_total_iters; + sched_config.linear_total_iters = FLAGS_lr_total_iters; + auto scheduler = CreateLRScheduler(optimizer,sched_config); auto train_iter = train_loader.begin(); std::shared_ptr loss_fn @@ -335,6 +355,9 @@ void Train(const nn::parallel::Rank &rank) { } optimizer->Step(); + if (scheduler) { + scheduler->Step(); + } } else { auto [x, y] = *train_iter; // if we are trying to overfit a single batch, we reset the loader here by 
commenting out the line below @@ -344,6 +367,9 @@ void Train(const nn::parallel::Rank &rank) { y = std::make_shared(y->To(device)); lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); + if (scheduler) { + scheduler->Step(); + } } if (ddp_world_size > 1) { @@ -359,10 +385,11 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsLastRank()) { size_t used_mb = 0, reserved_mb = 0; std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); - + const float current_lr = scheduler ? scheduler->GetLR() + : static_cast(FLAGS_learning_rate); LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", - step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, + step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); From 1f95e29ea5cb14c28a887314762ac3f8bc33764a Mon Sep 17 00:00:00 2001 From: Kinorw Date: Wed, 11 Mar 2026 18:09:56 +0800 Subject: [PATCH 18/18] style: apply clang-format to all legacy files --- example/gpt2/main.cc | 14 +- example/llama3/main.cc | 14 +- infini_train/include/lr_scheduler.h | 54 +++---- infini_train/include/optimizer.h | 3 +- infini_train/src/lr_scheduler.cc | 189 ++++++++---------------- infini_train/src/optimizer.cc | 10 +- test/lr_scheduler/test_chained_lr.cc | 117 ++++++++------- test/lr_scheduler/test_constant_lr.cc | 19 ++- test/lr_scheduler/test_helpers.h | 12 +- test/lr_scheduler/test_lambda_lr.cc | 68 ++++----- test/lr_scheduler/test_linear_lr.cc | 15 +- test/lr_scheduler/test_lr_scheduler.cc | 19 +-- test/lr_scheduler/test_sequential_lr.cc | 102 ++++++------- test/lr_scheduler/test_step_lr.cc | 25 ++-- 14 files changed, 282 insertions(+), 379 deletions(-) diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 1e8c7794..90f60262 100644 --- a/example/gpt2/main.cc +++ 
b/example/gpt2/main.cc @@ -57,8 +57,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text"); DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations"); DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); // lr scheduler -DEFINE_string(lr_scheduler, "none", - "Learning rate scheduler type: none|constant|step|linear"); +DEFINE_string(lr_scheduler, "none", "Learning rate scheduler type: none|constant|step|linear"); DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)"); DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)"); DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)"); @@ -289,10 +288,10 @@ void Train(const nn::parallel::Rank &rank) { sched_config.step_gamma = static_cast(FLAGS_gamma); sched_config.linear_start_factor = static_cast(FLAGS_start_factor); sched_config.linear_end_factor = static_cast(FLAGS_end_factor); - sched_config.constant_factor = static_cast(FLAGS_start_factor); // 复用 + sched_config.constant_factor = static_cast(FLAGS_start_factor); // 复用 sched_config.constant_total_iters = FLAGS_lr_total_iters; sched_config.linear_total_iters = FLAGS_lr_total_iters; - auto scheduler = CreateLRScheduler(optimizer,sched_config); + auto scheduler = CreateLRScheduler(optimizer, sched_config); auto train_iter = train_loader.begin(); std::shared_ptr loss_fn @@ -410,12 +409,11 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsLastRank()) { size_t used_mb = 0, reserved_mb = 0; std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); - const float current_lr = scheduler ? scheduler->GetLR() - : static_cast(FLAGS_learning_rate); + const float current_lr = scheduler ? 
scheduler->GetLR() : static_cast(FLAGS_learning_rate); LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", - step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, - tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, + step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps, + used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 58ceb95f..5b1bffbb 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -56,8 +56,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text"); DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations"); DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); // lr scheduler -DEFINE_string(lr_scheduler, "none", - "Learning rate scheduler type: none|constant|step|linear"); +DEFINE_string(lr_scheduler, "none", "Learning rate scheduler type: none|constant|step|linear"); DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)"); DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)"); DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)"); @@ -268,10 +267,10 @@ void Train(const nn::parallel::Rank &rank) { sched_config.step_gamma = static_cast(FLAGS_gamma); sched_config.linear_start_factor = static_cast(FLAGS_start_factor); sched_config.linear_end_factor = static_cast(FLAGS_end_factor); - sched_config.constant_factor = static_cast(FLAGS_start_factor); // 复用 + sched_config.constant_factor = static_cast(FLAGS_start_factor); // 复用 sched_config.constant_total_iters = 
FLAGS_lr_total_iters; sched_config.linear_total_iters = FLAGS_lr_total_iters; - auto scheduler = CreateLRScheduler(optimizer,sched_config); + auto scheduler = CreateLRScheduler(optimizer, sched_config); auto train_iter = train_loader.begin(); std::shared_ptr loss_fn @@ -386,12 +385,11 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsLastRank()) { size_t used_mb = 0, reserved_mb = 0; std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); - const float current_lr = scheduler ? scheduler->GetLR() - : static_cast(FLAGS_learning_rate); + const float current_lr = scheduler ? scheduler->GetLR() : static_cast(FLAGS_learning_rate); LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", - step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, - tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, + step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps, + used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h index c9772ac1..4e4695ce 100644 --- a/infini_train/include/lr_scheduler.h +++ b/infini_train/include/lr_scheduler.h @@ -1,7 +1,7 @@ #pragma once -#include #include +#include #include #include #include @@ -13,12 +13,11 @@ namespace infini_train { class Optimizer; -using StateValue = std::variant>; +using StateValue = std::variant>; using StateDict = std::unordered_map; struct LRSchedulerConfig { - std::string type = "none"; + std::string type = "none"; // ConstantLR float constant_factor = 1.0f / 3.0f; int constant_total_iters = 5; @@ -44,15 +43,13 @@ struct LRSchedulerConfig { class LRScheduler { public: - template - static std::shared_ptr Create(Args&&... 
args) { + template static std::shared_ptr Create(Args &&...args) { auto scheduler = std::make_shared(std::forward(args)...); scheduler->InitialStep(); return scheduler; } - explicit LRScheduler(std::shared_ptr optimizer, - int64_t last_step = -1); + explicit LRScheduler(std::shared_ptr optimizer, int64_t last_step = -1); virtual ~LRScheduler() = default; LRScheduler(const LRScheduler &) = delete; @@ -82,17 +79,13 @@ class LRScheduler { bool is_initial_ = false; }; -std::shared_ptr CreateLRScheduler( - std::shared_ptr optimizer, - const LRSchedulerConfig& config); +std::shared_ptr CreateLRScheduler(std::shared_ptr optimizer, const LRSchedulerConfig &config); namespace lr_schedulers { class ConstantLR : public LRScheduler { public: - ConstantLR(std::shared_ptr optimizer, - float factor = 1.0f / 3.0f, - int total_iters = 5, + ConstantLR(std::shared_ptr optimizer, float factor = 1.0f / 3.0f, int total_iters = 5, int64_t last_step = -1); ~ConstantLR() override = default; @@ -107,10 +100,7 @@ class ConstantLR : public LRScheduler { class StepLR : public LRScheduler { public: - StepLR(std::shared_ptr optimizer, - int64_t step_size, - float gamma = 0.1f, - int64_t last_step = -1); + StepLR(std::shared_ptr optimizer, int64_t step_size, float gamma = 0.1f, int64_t last_step = -1); ~StepLR() override = default; protected: @@ -124,11 +114,8 @@ class StepLR : public LRScheduler { class LinearLR : public LRScheduler { public: - LinearLR(std::shared_ptr optimizer, - float start_factor = 1.0f / 3.0f, - float end_factor = 1.0f, - int64_t total_iters = 5, - int64_t last_step = -1); + LinearLR(std::shared_ptr optimizer, float start_factor = 1.0f / 3.0f, float end_factor = 1.0f, + int64_t total_iters = 5, int64_t last_step = -1); ~LinearLR() override = default; protected: @@ -145,9 +132,7 @@ class LambdaLR : public LRScheduler { public: using LambdaFunc = std::function; - LambdaLR(std::shared_ptr optimizer, - LambdaFunc lr_lambda, - int64_t last_step = -1); + LambdaLR(std::shared_ptr 
optimizer, LambdaFunc lr_lambda, int64_t last_step = -1); ~LambdaLR() override = default; protected: @@ -157,13 +142,10 @@ class LambdaLR : public LRScheduler { const LambdaFunc lr_lambda_; }; - class SequentialLR : public LRScheduler { public: - SequentialLR(std::shared_ptr optimizer, - std::vector> schedulers, - std::vector milestones, - int64_t last_step = -1); + SequentialLR(std::shared_ptr optimizer, std::vector> schedulers, + std::vector milestones, int64_t last_step = -1); ~SequentialLR() override = default; void Step() override; @@ -183,8 +165,7 @@ class SequentialLR : public LRScheduler { class ChainedScheduler : public LRScheduler { public: - ChainedScheduler(std::shared_ptr optimizer, - std::vector> schedulers, + ChainedScheduler(std::shared_ptr optimizer, std::vector> schedulers, int64_t last_step = -1); ~ChainedScheduler() override = default; @@ -192,7 +173,7 @@ class ChainedScheduler : public LRScheduler { void InitialStep() override; StateDict State() const override; - void LoadState(const StateDict& state) override; + void LoadState(const StateDict &state) override; protected: float GetClosedFormLR() const override { return current_lr_; } @@ -201,6 +182,5 @@ class ChainedScheduler : public LRScheduler { std::vector> schedulers_; }; - -} // namespace lr_schedulers -} // namespace infini_train \ No newline at end of file +} // namespace lr_schedulers +} // namespace infini_train \ No newline at end of file diff --git a/infini_train/include/optimizer.h b/infini_train/include/optimizer.h index ee879215..c72ee6c9 100644 --- a/infini_train/include/optimizer.h +++ b/infini_train/include/optimizer.h @@ -26,7 +26,7 @@ class Optimizer { virtual float GetLearningRate() const; float GetInitialLearningRate() const; - + void SetInitialLearningRate(float lr); protected: @@ -48,7 +48,6 @@ class SGD : public Optimizer { return std::make_shared(params, learning_rate); }; } - }; class Adam : public Optimizer { diff --git a/infini_train/src/lr_scheduler.cc 
b/infini_train/src/lr_scheduler.cc index bf834886..11b6034b 100644 --- a/infini_train/src/lr_scheduler.cc +++ b/infini_train/src/lr_scheduler.cc @@ -4,87 +4,71 @@ #include "infini_train/include/optimizer.h" - namespace infini_train { -std::shared_ptr CreateLRScheduler( - std::shared_ptr optimizer, - const LRSchedulerConfig& config) { +std::shared_ptr CreateLRScheduler(std::shared_ptr optimizer, const LRSchedulerConfig &config) { if (config.type == "none") { return nullptr; } - auto create_main = [&](std::shared_ptr opt) - -> std::shared_ptr { + auto create_main = [&](std::shared_ptr opt) -> std::shared_ptr { if (config.type == "constant") { - return LRScheduler::Create( - opt, config.constant_factor, config.constant_total_iters); + return LRScheduler::Create(opt, config.constant_factor, + config.constant_total_iters); } if (config.type == "step") { - return LRScheduler::Create( - opt, config.step_size, config.step_gamma); + return LRScheduler::Create(opt, config.step_size, config.step_gamma); } if (config.type == "linear") { - return LRScheduler::Create( - opt, config.linear_start_factor, config.linear_end_factor, - config.linear_total_iters); + return LRScheduler::Create(opt, config.linear_start_factor, + config.linear_end_factor, config.linear_total_iters); } if (config.type == "lambda") { - return LRScheduler::Create( - opt, config.lambda_fn); + return LRScheduler::Create(opt, config.lambda_fn); } if (config.type == "sequential") { std::vector> schedulers; std::vector milestones = config.sequential_milestones; - for (const auto& sub_config : config.sequential_configs) { + for (const auto &sub_config : config.sequential_configs) { auto sub_sched = CreateLRScheduler(opt, sub_config); if (sub_sched) { schedulers.push_back(sub_sched); } } - return LRScheduler::Create( - opt, schedulers, milestones); + return LRScheduler::Create(opt, schedulers, milestones); } if (config.type == "chained") { std::vector> schedulers; - for (const auto& sub_config : 
config.chained_configs) { + for (const auto &sub_config : config.chained_configs) { auto sub_sched = CreateLRScheduler(opt, sub_config); if (sub_sched) { schedulers.push_back(sub_sched); } } - return LRScheduler::Create( - opt, schedulers); + return LRScheduler::Create(opt, schedulers); } LOG(FATAL) << "Unsupported LR scheduler type: " << config.type; - return nullptr; + return nullptr; }; if (config.warmup_steps <= 0) { return create_main(optimizer); } - auto warmup_scheduler = LRScheduler::Create( - optimizer, - /*start_factor=*/config.warmup_start_factor, - /*end_factor=*/config.warmup_end_factor, - /*total_iters=*/config.warmup_steps); - + auto warmup_scheduler = LRScheduler::Create(optimizer, + /*start_factor=*/config.warmup_start_factor, + /*end_factor=*/config.warmup_end_factor, + /*total_iters=*/config.warmup_steps); + auto main_scheduler = create_main(optimizer); return LRScheduler::Create( - optimizer, - std::vector>{warmup_scheduler, main_scheduler}, + optimizer, std::vector>{warmup_scheduler, main_scheduler}, std::vector{config.warmup_steps}); - }; -LRScheduler::LRScheduler(std::shared_ptr optimizer, - int64_t last_step) - : optimizer_(std::move(optimizer)), - last_step_(last_step), - current_lr_(0.0f), - base_lr_(0.0f) { +LRScheduler::LRScheduler(std::shared_ptr optimizer, int64_t last_step) + : optimizer_(std::move(optimizer)), last_step_(last_step), current_lr_(0.0f), base_lr_(0.0f) { CHECK(optimizer_) << "LRScheduler: optimizer must not be null."; optimizer_->SetInitialLearningRate(optimizer_->GetLearningRate()); base_lr_ = optimizer_->GetInitialLearningRate(); @@ -112,9 +96,7 @@ void LRScheduler::ApplyLR(float lr) { optimizer_->SetLearningRate(current_lr_); } -float LRScheduler::GetChainedFormLR() const { - return GetClosedFormLR(); -} +float LRScheduler::GetChainedFormLR() const { return GetClosedFormLR(); } float LRScheduler::GetLR() const { return current_lr_; } @@ -139,25 +121,16 @@ void LRScheduler::LoadState(const StateDict &state) { 
optimizer_->SetLearningRate(current_lr_); } - - // Concrete LR Schedulers namespace lr_schedulers { -// --- ConstantLR --- +// --- ConstantLR --- -ConstantLR::ConstantLR(std::shared_ptr optimizer, - float factor, - int total_iters, - int64_t last_step) - : LRScheduler(std::move(optimizer), last_step), - factor_(factor), - total_iters_(total_iters) {} +ConstantLR::ConstantLR(std::shared_ptr optimizer, float factor, int total_iters, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), factor_(factor), total_iters_(total_iters) {} -float ConstantLR::GetClosedFormLR() const { - return last_step_ < total_iters_ ? base_lr_ * factor_ : base_lr_; -} +float ConstantLR::GetClosedFormLR() const { return last_step_ < total_iters_ ? base_lr_ * factor_ : base_lr_; } float ConstantLR::GetChainedFormLR() const { const float lr = optimizer_->GetLearningRate(); @@ -165,7 +138,7 @@ float ConstantLR::GetChainedFormLR() const { return lr * factor_; } else if (last_step_ < total_iters_) { return lr; - } else if (last_step_ == total_iters_){ + } else if (last_step_ == total_iters_) { return lr / factor_; } return lr; @@ -173,18 +146,12 @@ float ConstantLR::GetChainedFormLR() const { // --- StepLR --- -StepLR::StepLR(std::shared_ptr optimizer, - int64_t step_size, - float gamma, - int64_t last_step) - : LRScheduler(std::move(optimizer), last_step), - step_size_(step_size), - gamma_(gamma) {} +StepLR::StepLR(std::shared_ptr optimizer, int64_t step_size, float gamma, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), step_size_(step_size), gamma_(gamma) {} float StepLR::GetClosedFormLR() const { - return base_lr_ * static_cast(std::pow( - static_cast(gamma_), - static_cast(last_step_ / step_size_))); + return base_lr_ + * static_cast(std::pow(static_cast(gamma_), static_cast(last_step_ / step_size_))); } float StepLR::GetChainedFormLR() const { @@ -195,25 +162,18 @@ float StepLR::GetChainedFormLR() const { return lr * gamma_; } - 
-LinearLR::LinearLR(std::shared_ptr optimizer, - float start_factor, - float end_factor, - int64_t total_iters, +LinearLR::LinearLR(std::shared_ptr optimizer, float start_factor, float end_factor, int64_t total_iters, int64_t last_step) - : LRScheduler(std::move(optimizer), last_step), - start_factor_(start_factor), - end_factor_(end_factor), + : LRScheduler(std::move(optimizer), last_step), start_factor_(start_factor), end_factor_(end_factor), total_iters_(total_iters) {} float LinearLR::GetClosedFormLR() const { if (last_step_ >= total_iters_) { return base_lr_ * end_factor_; } - return base_lr_ * - (start_factor_ + (end_factor_ - start_factor_) * - static_cast(last_step_) / - static_cast(total_iters_)); + return base_lr_ + * (start_factor_ + + (end_factor_ - start_factor_) * static_cast(last_step_) / static_cast(total_iters_)); } float LinearLR::GetChainedFormLR() const { @@ -225,47 +185,35 @@ float LinearLR::GetChainedFormLR() const { return lr; } if (last_step_ == total_iters_) { - const float prev_factor = - start_factor_ + - (end_factor_ - start_factor_) * - static_cast(total_iters_ - 1) / - static_cast(total_iters_); + const float prev_factor + = start_factor_ + + (end_factor_ - start_factor_) * static_cast(total_iters_ - 1) / static_cast(total_iters_); return lr * (end_factor_ / prev_factor); } const float numerator = end_factor_ - start_factor_; - const float denominator = - start_factor_ * static_cast(total_iters_) + - static_cast(last_step_ - 1) * numerator; + const float denominator + = start_factor_ * static_cast(total_iters_) + static_cast(last_step_ - 1) * numerator; return lr * (1.0f + numerator / denominator); } -LambdaLR::LambdaLR(std::shared_ptr optimizer, - std::function lr_lambda, - int64_t last_step) - : LRScheduler(std::move(optimizer), last_step), - lr_lambda_(std::move(lr_lambda)) {} +LambdaLR::LambdaLR(std::shared_ptr optimizer, std::function lr_lambda, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), 
lr_lambda_(std::move(lr_lambda)) {} -float LambdaLR::GetClosedFormLR() const { - return base_lr_ * lr_lambda_(last_step_); -} +float LambdaLR::GetClosedFormLR() const { return base_lr_ * lr_lambda_(last_step_); } -SequentialLR::SequentialLR(std::shared_ptr optimizer, - std::vector> schedulers, - std::vectormilestones, int64_t last_step) - : LRScheduler(std::move(optimizer), last_step), - schedulers_(std::move(schedulers)), - milestones_(std::move(milestones)) {} +SequentialLR::SequentialLR(std::shared_ptr optimizer, std::vector> schedulers, + std::vector milestones, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)), + milestones_(std::move(milestones)) {} void SequentialLR::InitialStep() { - CHECK(!schedulers_.empty()) - << "SequentialLR requires at least one scheduler."; + CHECK(!schedulers_.empty()) << "SequentialLR requires at least one scheduler."; CHECK_EQ(milestones_.size(), schedulers_.size() - 1) << "SequentialLR: milestones count must be schedulers count - 1."; - for(size_t i = 1; i < milestones_.size(); ++i) { - CHECK_GT(milestones_[i], milestones_[i-1]) - << "Milestones must be strictly increasing."; + for (size_t i = 1; i < milestones_.size(); ++i) { + CHECK_GT(milestones_[i], milestones_[i - 1]) << "Milestones must be strictly increasing."; } optimizer_->SetLearningRate(schedulers_[0]->BaseLR()); @@ -275,7 +223,6 @@ void SequentialLR::InitialStep() { ++last_step_; schedulers_[0]->InitialStep(); current_lr_ = schedulers_[0]->GetLR(); - } void SequentialLR::UndoChildInitialSteps() { @@ -309,9 +256,7 @@ StateDict SequentialLR::State() const { state["base_lr"] = base_lr_; for (size_t i = 0; i < schedulers_.size(); ++i) { auto sub_state = schedulers_[i]->State(); - for (const auto &[key, value] : sub_state) { - state["scheduler_" + std::to_string(i) + "." + key] = value; - } + for (const auto &[key, value] : sub_state) { state["scheduler_" + std::to_string(i) + "." 
+ key] = value; } } return state; } @@ -329,32 +274,26 @@ void SequentialLR::LoadState(const StateDict &state) { sub_state[key.substr(prefix.size())] = value; } } - if(!sub_state.empty()) + if (!sub_state.empty()) { schedulers_[i]->LoadState(sub_state); + } } optimizer_->SetLearningRate(current_lr_); } ChainedScheduler::ChainedScheduler(std::shared_ptr optimizer, - std::vector> schedulers, - int64_t last_step) - : LRScheduler(std::move(optimizer), last_step), - schedulers_(std::move(schedulers)) {} - + std::vector> schedulers, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)) {} void ChainedScheduler::InitialStep() { - CHECK(!schedulers_.empty()) - << "ChainedScheduler requires at least one scheduler."; + CHECK(!schedulers_.empty()) << "ChainedScheduler requires at least one scheduler."; current_lr_ = optimizer_->GetLearningRate(); } - void ChainedScheduler::Step() { ++last_step_; - for (auto &sched : schedulers_) { - sched->Step(); - } + for (auto &sched : schedulers_) { sched->Step(); } current_lr_ = optimizer_->GetLearningRate(); } @@ -362,19 +301,17 @@ StateDict ChainedScheduler::State() const { StateDict state = LRScheduler::State(); for (size_t i = 0; i < schedulers_.size(); ++i) { auto sub_state = schedulers_[i]->State(); - for (const auto& [key, value] : sub_state) { - state["scheduler_" + std::to_string(i) + "." + key] = value; - } + for (const auto &[key, value] : sub_state) { state["scheduler_" + std::to_string(i) + "." 
+ key] = value; } } return state; } -void ChainedScheduler::LoadState(const StateDict& state) { +void ChainedScheduler::LoadState(const StateDict &state) { LRScheduler::LoadState(state); for (size_t i = 0; i < schedulers_.size(); ++i) { StateDict sub_state; std::string prefix = "scheduler_" + std::to_string(i) + "."; - for (const auto& [key, value] : state) { + for (const auto &[key, value] : state) { if (key.substr(0, prefix.size()) == prefix) { sub_state[key.substr(prefix.size())] = value; } @@ -385,5 +322,5 @@ void ChainedScheduler::LoadState(const StateDict& state) { } } -} // namespace lr_schedulers -} // namespace infini_train \ No newline at end of file +} // namespace lr_schedulers +} // namespace infini_train \ No newline at end of file diff --git a/infini_train/src/optimizer.cc b/infini_train/src/optimizer.cc index 29930814..c86c40f1 100644 --- a/infini_train/src/optimizer.cc +++ b/infini_train/src/optimizer.cc @@ -10,7 +10,7 @@ namespace infini_train { Optimizer::Optimizer(const std::vector> ¶ms, float learning_rate) : params_(params), learning_rate_(learning_rate) {} - + void Optimizer::ZeroGrad(bool set_to_none) { for (auto param : params_) { param->ZeroGrad(set_to_none); } } @@ -20,9 +20,8 @@ void Optimizer::SetLearningRate(float lr) { learning_rate_ = lr; } float Optimizer::GetLearningRate() const { return learning_rate_; } float Optimizer::GetInitialLearningRate() const { - CHECK(initial_lr_set_) - << "Optimizer: initial_learning_rate not set. " - "Use with an LRScheduler first."; + CHECK(initial_lr_set_) << "Optimizer: initial_learning_rate not set. 
" + "Use with an LRScheduler first."; return initial_learning_rate_; } @@ -34,8 +33,7 @@ void Optimizer::SetInitialLearningRate(float lr) { } namespace optimizers { -SGD::SGD(const std::vector> ¶ms, float learning_rate) - : Optimizer(params, learning_rate) {} +SGD::SGD(const std::vector> ¶ms, float learning_rate) : Optimizer(params, learning_rate) {} void SGD::Step() { for (auto param : params_) { diff --git a/test/lr_scheduler/test_chained_lr.cc b/test/lr_scheduler/test_chained_lr.cc index 4c451883..3ea6bd55 100644 --- a/test/lr_scheduler/test_chained_lr.cc +++ b/test/lr_scheduler/test_chained_lr.cc @@ -4,7 +4,7 @@ using namespace infini_train; using namespace infini_train::lr_schedulers; -namespace{ +namespace { constexpr float kBaseLR = 0.1f; } // TC1: 单子调度器退化 @@ -12,14 +12,14 @@ void TestSingleScheduler() { std::cout << "[TC1] TestSingleScheduler" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); auto step_lr = CreateLRScheduler(opt, { - .type = "step", - .step_size = 3, - .step_gamma = 0.5f, - }); + .type = "step", + .step_size = 3, + .step_gamma = 0.5f, + }); auto sched = LRScheduler::Create(opt, std::vector>{step_lr}); ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); - sched->Step(); // step=1 + sched->Step(); // step=1 ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); } @@ -27,17 +27,19 @@ void TestSingleScheduler() { void TestMultiplicativeChain() { std::cout << "[TC2] TestMultiplicativeChain" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = CreateLRScheduler(opt, { - .type = "chained", - .chained_configs = {{ - .type = "step", - .step_size = 2, - .step_gamma = 0.5f, - },{ - .type = "lambda", - .lambda_fn = [](int64_t step) { return 1.0f - 0.1f * step; }, - }}, - }); + auto sched = CreateLRScheduler( + opt, { + .type = "chained", + .chained_configs = {{ + .type = "step", + .step_size = 2, + .step_gamma = 0.5f, + }, + { + .type = "lambda", + .lambda_fn = [](int64_t step) { return 1.0f - 0.1f * step; }, + }}, + }); 
ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); @@ -56,17 +58,18 @@ void TestConstantPlusStep() { std::cout << "[TC3] TestConstantPlusStep" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); auto sched = CreateLRScheduler(opt, { - .type = "chained", - .chained_configs = {{ - .type = "constant", - .constant_factor = 0.5f, - .constant_total_iters = 2, - }, { - .type = "step", - .step_size = 3, - .step_gamma = 0.1f, - }}, - }); + .type = "chained", + .chained_configs = {{ + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 2, + }, + { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }}, + }); ASSERT_FLOAT_NEAR(sched->GetLR(), 0.05f, kEps); @@ -88,27 +91,28 @@ void TestConstantPlusStepDLC() { std::cout << "[TC4] TestConstantPlusStepDLC" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); auto constant = CreateLRScheduler(opt, { - .type = "constant", - .constant_factor = 0.5f, - .constant_total_iters = 2, - }); + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 2, + }); auto linear = CreateLRScheduler(opt, { - .type = "linear", - .linear_start_factor = 1e-8f, - .linear_end_factor = 1.0f, - .linear_total_iters = 3, - }); + .type = "linear", + .linear_start_factor = 1e-8f, + .linear_end_factor = 1.0f, + .linear_total_iters = 3, + }); auto step_lr = CreateLRScheduler(opt, { - .type = "step", - .step_size = 3, - .step_gamma = 0.1f, - }); + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }); auto Lambda = CreateLRScheduler(opt, { - .type = "lambda", - .lambda_fn = [](int64_t step) { return 1.0f - 0.1f * step; }, - }); + .type = "lambda", + .lambda_fn = [](int64_t step) { return 1.0f - 0.1f * step; }, + }); - auto sched = LRScheduler::Create(opt, std::vector>{constant, step_lr}); + auto sched + = LRScheduler::Create(opt, std::vector>{constant, step_lr}); ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); @@ -130,18 +134,18 @@ void TestStateRoundTrip() { std::cout << "[TC5] TestStateRoundTrip" << std::endl; 
auto opt = MakeDummyOptimizer(kBaseLR); auto step_lr = std::make_shared(opt, 2, 0.5f); - auto lambda_lr = std::make_shared(opt, - [](int64_t step) { return 1.0f - 0.05f * step; }); - auto sched = LRScheduler::Create(opt, std::vector>{step_lr, lambda_lr}); + auto lambda_lr = std::make_shared(opt, [](int64_t step) { return 1.0f - 0.05f * step; }); + auto sched + = LRScheduler::Create(opt, std::vector>{step_lr, lambda_lr}); - for (int i = 0; i < 5; ++i) sched->Step(); + for (int i = 0; i < 5; ++i) { sched->Step(); } StateDict saved = sched->State(); auto opt2 = MakeDummyOptimizer(kBaseLR); auto step_lr2 = std::make_shared(opt2, 2, 0.5f); - auto lambda_lr2 = std::make_shared(opt2, - [](int64_t step) { return 1.0f - 0.05f * step; }); - auto sched2 = LRScheduler::Create(opt2, std::vector>{step_lr2, lambda_lr2}); + auto lambda_lr2 = std::make_shared(opt2, [](int64_t step) { return 1.0f - 0.05f * step; }); + auto sched2 + = LRScheduler::Create(opt2, std::vector>{step_lr2, lambda_lr2}); sched2->LoadState(saved); ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); @@ -157,22 +161,23 @@ void TestResumeConsistency() { auto make_sched = [&](std::shared_ptr opt) { auto step_lr = std::make_shared(opt, 2, 0.5f); auto lambda_lr = std::make_shared(opt, lambda_fn); - return LRScheduler::Create(opt, std::vector>{step_lr, lambda_lr}); + return LRScheduler::Create(opt, + std::vector>{step_lr, lambda_lr}); }; auto opt_ref = MakeDummyOptimizer(kBaseLR); auto sched_ref = make_sched(opt_ref); - for (int i = 0; i < kN; ++i) sched_ref->Step(); + for (int i = 0; i < kN; ++i) { sched_ref->Step(); } auto opt_a = MakeDummyOptimizer(kBaseLR); auto sched_a = make_sched(opt_a); - for (int i = 0; i < kK; ++i) sched_a->Step(); + for (int i = 0; i < kK; ++i) { sched_a->Step(); } StateDict ckpt = sched_a->State(); auto opt_b = MakeDummyOptimizer(kBaseLR); auto sched_b = make_sched(opt_b); sched_b->LoadState(ckpt); - for (int i = 0; i < kN - kK; ++i) sched_b->Step(); + for (int i = 0; i < kN - kK; ++i) 
{ sched_b->Step(); } ASSERT_FLOAT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); diff --git a/test/lr_scheduler/test_constant_lr.cc b/test/lr_scheduler/test_constant_lr.cc index aea6f6b3..df5e9be2 100644 --- a/test/lr_scheduler/test_constant_lr.cc +++ b/test/lr_scheduler/test_constant_lr.cc @@ -6,8 +6,7 @@ using namespace infini_train::lr_schedulers; namespace { constexpr float kBaseLR = 0.1f; -} // namespace - +} // namespace void TestInitialState() { auto opt = MakeDummyOptimizer(kBaseLR); @@ -29,9 +28,9 @@ void TestFirstStepAppliesFactor() { .constant_factor = 0.5f, .constant_total_iters = 3, }; - + auto sched = CreateLRScheduler(opt, config); - sched->Step(); // last_step_ = 0 + sched->Step(); // last_step_ = 0 ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f); ASSERT_TRUE(sched->LastStep() == 1); @@ -45,7 +44,7 @@ void TestWithinTotalIters() { .constant_total_iters = 3, }; auto sched = CreateLRScheduler(opt, config); - for (int i = 0; i < 2; ++i) sched->Step(); + for (int i = 0; i < 2; ++i) { sched->Step(); } // last_step_ = 2, still < 3 ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); } @@ -58,7 +57,7 @@ void TestBeyondTotalIters() { .constant_total_iters = 3, }; auto sched = CreateLRScheduler(opt, config); - for (int i = 0; i < 10; ++i) sched->Step(); + for (int i = 0; i < 10; ++i) { sched->Step(); } ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR); } @@ -86,7 +85,7 @@ void TestStateRoundTrip() { .constant_total_iters = 5, }; auto sched = CreateLRScheduler(opt, config); - for (int i = 0; i < 3; ++i) sched->Step(); + for (int i = 0; i < 3; ++i) { sched->Step(); } StateDict saved = sched->State(); auto opt2 = MakeDummyOptimizer(kBaseLR); @@ -114,7 +113,7 @@ void TestResumeConsistency() { .constant_total_iters = 5, }; auto sched_ref = CreateLRScheduler(opt_ref, config_ref); - for (int i = 0; i < kN; ++i) sched_ref->Step(); + for 
(int i = 0; i < kN; ++i) { sched_ref->Step(); } auto opt_a = MakeDummyOptimizer(kBaseLR); LRSchedulerConfig config_a = { @@ -123,7 +122,7 @@ void TestResumeConsistency() { .constant_total_iters = 5, }; auto sched_a = CreateLRScheduler(opt_a, config_a); - for (int i = 0; i < kK; ++i) sched_a->Step(); + for (int i = 0; i < kK; ++i) { sched_a->Step(); } StateDict ckpt = sched_a->State(); auto opt_b = MakeDummyOptimizer(kBaseLR); @@ -134,7 +133,7 @@ void TestResumeConsistency() { }; auto sched_b = CreateLRScheduler(opt_b, config_b); sched_b->LoadState(ckpt); - for (int i = 0; i < kN - kK; ++i) sched_b->Step(); + for (int i = 0; i < kN - kK; ++i) { sched_b->Step(); } ASSERT_FLOAT_EQ(sched_b->GetLR(), sched_ref->GetLR()); ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); diff --git a/test/lr_scheduler/test_helpers.h b/test/lr_scheduler/test_helpers.h index 43f0a2f5..f4b22430 100644 --- a/test/lr_scheduler/test_helpers.h +++ b/test/lr_scheduler/test_helpers.h @@ -17,9 +17,7 @@ std::shared_ptr MakeDummyOptimizer(float lr) { return std::make_shared(empty_params, lr); } -bool FloatNear(float a, float b, float eps = kEps) { - return std::fabs(a - b) < eps; -} +bool FloatNear(float a, float b, float eps = kEps) { return std::fabs(a - b) < eps; } int g_fail_count = 0; @@ -31,9 +29,7 @@ void Check(bool cond, const char *expr, int line) { } #define ASSERT_TRUE(cond) Check((cond), #cond, __LINE__) -#define ASSERT_FLOAT_EQ(a, b) \ - Check(FloatNear((a), (b)), #a " == " #b, __LINE__) -#define ASSERT_FLOAT_NEAR(a, b, eps) \ - Check(FloatNear((a), (b), (eps)), #a " ≈ " #b, __LINE__) +#define ASSERT_FLOAT_EQ(a, b) Check(FloatNear((a), (b)), #a " == " #b, __LINE__) +#define ASSERT_FLOAT_NEAR(a, b, eps) Check(FloatNear((a), (b), (eps)), #a " ≈ " #b, __LINE__) -} // namespace \ No newline at end of file +} // namespace \ No newline at end of file diff --git a/test/lr_scheduler/test_lambda_lr.cc b/test/lr_scheduler/test_lambda_lr.cc index 0e977ba0..a89d6356 100644 --- 
a/test/lr_scheduler/test_lambda_lr.cc +++ b/test/lr_scheduler/test_lambda_lr.cc @@ -6,14 +6,14 @@ using namespace infini_train::lr_schedulers; namespace { constexpr float kBaseLR = 0.1f; -} // namespace +} // namespace void TestIdentityLambda() { auto opt = MakeDummyOptimizer(kBaseLR); auto sched = CreateLRScheduler(opt, { - .type = "lambda", - .lambda_fn = [](int64_t) { return 1.0f; }, - }); + .type = "lambda", + .lambda_fn = [](int64_t) { return 1.0f; }, + }); // 构造器内 Step() → last_step_=0, lr = 0.1 * 1.0 = 0.1 ASSERT_TRUE(sched->LastStep() == 0); ASSERT_FLOAT_NEAR(sched->GetLR(), kBaseLR, kEps); @@ -23,29 +23,30 @@ void TestIdentityLambda() { void TestLinearDecayLambda() { auto opt = MakeDummyOptimizer(kBaseLR); auto sched = CreateLRScheduler(opt, { - .type = "lambda", - .lambda_fn = [](int64_t step) { return 1.0f - step * 0.1f; }, - }); + .type = "lambda", + .lambda_fn = [](int64_t step) { return 1.0f - step * 0.1f; }, + }); // step=0, lambda(0)=1.0, lr=0.1 ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); - sched->Step(); // step=1, lambda(1)=0.9, lr=0.09 + sched->Step(); // step=1, lambda(1)=0.9, lr=0.09 ASSERT_FLOAT_NEAR(sched->GetLR(), 0.09f, kEps); - sched->Step(); // step=2, lambda(2)=0.8, lr=0.08 + sched->Step(); // step=2, lambda(2)=0.8, lr=0.08 ASSERT_FLOAT_NEAR(sched->GetLR(), 0.08f, kEps); - sched->Step(); // step=3, lambda(3)=0.7, lr=0.07 + sched->Step(); // step=3, lambda(3)=0.7, lr=0.07 ASSERT_FLOAT_NEAR(sched->GetLR(), 0.07f, kEps); } void TestPyTorchAlignment() { // PyTorch: LambdaLR(opt, lr_lambda=lambda epoch: 0.95**epoch) auto opt = MakeDummyOptimizer(kBaseLR); - auto sched = CreateLRScheduler(opt, { - .type = "lambda", - .lambda_fn = [](int64_t step) { return static_cast(std::pow(0.95, step)); }, - }); + auto sched + = CreateLRScheduler(opt, { + .type = "lambda", + .lambda_fn = [](int64_t step) { return static_cast(std::pow(0.95, step)); }, + }); // step=0, lr = 0.1 * 0.95^0 = 0.1 ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); @@ -60,17 +61,17 
@@ void TestStateRoundTrip() { auto lambda_fn = [](int64_t step) { return 1.0f - step * 0.05f; }; auto opt = MakeDummyOptimizer(kBaseLR); auto sched = CreateLRScheduler(opt, { - .type = "lambda", - .lambda_fn = lambda_fn, - }); - for (int i = 0; i < 5; ++i) sched->Step(); + .type = "lambda", + .lambda_fn = lambda_fn, + }); + for (int i = 0; i < 5; ++i) { sched->Step(); } StateDict saved = sched->State(); auto opt2 = MakeDummyOptimizer(kBaseLR); auto sched2 = CreateLRScheduler(opt2, { - .type = "lambda", - .lambda_fn = lambda_fn, - }); // same lambda + .type = "lambda", + .lambda_fn = lambda_fn, + }); // same lambda sched2->LoadState(saved); ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); @@ -83,32 +84,31 @@ void TestResumeConsistency() { auto opt_ref = MakeDummyOptimizer(kBaseLR); auto sched_ref = CreateLRScheduler(opt_ref, { - .type = "lambda", - .lambda_fn = lambda_fn, - }); - for (int i = 0; i < kN; ++i) sched_ref->Step(); + .type = "lambda", + .lambda_fn = lambda_fn, + }); + for (int i = 0; i < kN; ++i) { sched_ref->Step(); } auto opt_a = MakeDummyOptimizer(kBaseLR); auto sched_a = CreateLRScheduler(opt_a, { - .type = "lambda", - .lambda_fn = lambda_fn, - }); - for (int i = 0; i < kK; ++i) sched_a->Step(); + .type = "lambda", + .lambda_fn = lambda_fn, + }); + for (int i = 0; i < kK; ++i) { sched_a->Step(); } StateDict ckpt = sched_a->State(); auto opt_b = MakeDummyOptimizer(kBaseLR); auto sched_b = CreateLRScheduler(opt_b, { - .type = "lambda", - .lambda_fn = lambda_fn, - }); + .type = "lambda", + .lambda_fn = lambda_fn, + }); sched_b->LoadState(ckpt); - for (int i = 0; i < kN - kK; ++i) sched_b->Step(); + for (int i = 0; i < kN - kK; ++i) { sched_b->Step(); } ASSERT_FLOAT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); } - int main(int argc, char *argv[]) { google::InitGoogleLogging(argv[0]); std::cout << "=== LambdaLR Tests ===" << std::endl; diff --git a/test/lr_scheduler/test_linear_lr.cc 
b/test/lr_scheduler/test_linear_lr.cc index a7391a50..e1659a7f 100644 --- a/test/lr_scheduler/test_linear_lr.cc +++ b/test/lr_scheduler/test_linear_lr.cc @@ -16,7 +16,7 @@ void TestFirstStepFromZero() { .linear_end_factor = 1.0f, .linear_total_iters = 5, }; - + auto sched = CreateLRScheduler(opt, config); ASSERT_FLOAT_EQ(sched->GetLR(), 0.02f); } @@ -30,7 +30,7 @@ void TestMidpointLR() { .linear_total_iters = 5, }; auto sched = CreateLRScheduler(opt, config); - for (int i = 0; i < 3; ++i) sched->Step(); + for (int i = 0; i < 3; ++i) { sched->Step(); } // last_step_=3 -> 0.1*(0.2 + 0.8*3/5) = 0.068 ASSERT_FLOAT_EQ(sched->GetLR(), 0.068f); } @@ -44,7 +44,7 @@ void TestWarmupEnd() { .linear_total_iters = 5, }; auto sched = CreateLRScheduler(opt, config); - for (int i = 0; i < 5; ++i) sched->Step(); + for (int i = 0; i < 5; ++i) { sched->Step(); } // last_step_ >= total_iters -> base_lr * end_factor ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); } @@ -58,7 +58,7 @@ void TestBeyondWarmup() { .linear_total_iters = 5, }; auto sched = CreateLRScheduler(opt, config); - for (int i = 0; i < 20; ++i) sched->Step(); + for (int i = 0; i < 20; ++i) { sched->Step(); } ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); } @@ -71,15 +71,14 @@ void TestCustomStartFactor() { .linear_total_iters = 4, }; auto sched = CreateLRScheduler(opt, config); - sched->Step(); // last_step_=1, lr=0.1*(0.25+0.75*1/4)=0.04375 + sched->Step(); // last_step_=1, lr=0.1*(0.25+0.75*1/4)=0.04375 ASSERT_FLOAT_NEAR(sched->GetLR(), 0.04375f, 1e-6f); - sched->Step(); // last_step_=2, lr=0.1*(0.25+0.75*2/4)=0.0625 + sched->Step(); // last_step_=2, lr=0.1*(0.25+0.75*2/4)=0.0625 ASSERT_FLOAT_NEAR(sched->GetLR(), 0.0625f, 1e-6f); } void TestPyTorchAlignment() { - const std::vector expected = { - 0.036f, 0.052f, 0.068f, 0.084f, 0.1f, 0.1f, 0.1f}; + const std::vector expected = {0.036f, 0.052f, 0.068f, 0.084f, 0.1f, 0.1f, 0.1f}; auto opt = MakeDummyOptimizer(kBaseLR); LRSchedulerConfig config = { .type = "linear", diff --git 
a/test/lr_scheduler/test_lr_scheduler.cc b/test/lr_scheduler/test_lr_scheduler.cc index 0c3dce1d..58f6bdd6 100644 --- a/test/lr_scheduler/test_lr_scheduler.cc +++ b/test/lr_scheduler/test_lr_scheduler.cc @@ -27,18 +27,15 @@ class IdentityScheduler : public LRScheduler { class LinearDecayScheduler : public LRScheduler { public: - LinearDecayScheduler(std::shared_ptr optimizer, - int64_t total_steps, int64_t last_step = -1) - : LRScheduler(std::move(optimizer), last_step), - total_steps_(total_steps) {} + LinearDecayScheduler(std::shared_ptr optimizer, int64_t total_steps, int64_t last_step = -1) + : LRScheduler(std::move(optimizer), last_step), total_steps_(total_steps) {} protected: float GetClosedFormLR() const override { if (last_step_ >= total_steps_) { return 0.0f; } - return base_lr_ * (1.0f - static_cast(last_step_) / - static_cast(total_steps_)); + return base_lr_ * (1.0f - static_cast(last_step_) / static_cast(total_steps_)); } private: @@ -62,9 +59,7 @@ void Check(bool cond, const char *expr, int line) { } #define ASSERT_TRUE(cond) Check((cond), #cond, __LINE__) -#define ASSERT_FLOAT_EQ(a, b) \ - Check(FloatEq((a), (b)), #a " == " #b, __LINE__) - +#define ASSERT_FLOAT_EQ(a, b) Check(FloatEq((a), (b)), #a " == " #b, __LINE__) // T1: Init void TestInitialState() { @@ -99,11 +94,11 @@ void TestLinearDecay() { ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR); - sched->Step(); // last_step = 1 -> 0.09 + sched->Step(); // last_step = 1 -> 0.09 ASSERT_FLOAT_EQ(sched->GetLR(), 0.09f); ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.09f); - for (int i = 0; i < 4; ++i) { sched->Step(); } // last_step = 5 + for (int i = 0; i < 4; ++i) { sched->Step(); } // last_step = 5 ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f); } @@ -156,7 +151,7 @@ void TestResumeAndContinue() { ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); } -} // namespace +} // namespace int main(int argc, char *argv[]) 
{ google::InitGoogleLogging(argv[0]); diff --git a/test/lr_scheduler/test_sequential_lr.cc b/test/lr_scheduler/test_sequential_lr.cc index 2d866540..df8d6cbd 100644 --- a/test/lr_scheduler/test_sequential_lr.cc +++ b/test/lr_scheduler/test_sequential_lr.cc @@ -25,23 +25,23 @@ void TestLinearThenConstant() { }; auto constant = CreateLRScheduler(opt, constant_config); auto sched = CreateLRScheduler(opt, { - .type = "sequential", - .sequential_configs = {linear_config, constant_config}, - .sequential_milestones = {3}, - }); - + .type = "sequential", + .sequential_configs = {linear_config, constant_config}, + .sequential_milestones = {3}, + }); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.0f, kEps); - sched->Step(); // global=1, warmup step=1, lr=0.1*(1/3) + sched->Step(); // global=1, warmup step=1, lr=0.1*(1/3) ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f / 3.0f, 1e-5f); - sched->Step(); // global=2, warmup step=2, lr=0.1*(2/3) + sched->Step(); // global=2, warmup step=2, lr=0.1*(2/3) ASSERT_FLOAT_NEAR(sched->GetLR(), 0.2f / 3.0f, 1e-5f); - sched->Step(); // global=3, constant step=0, lr=0.1*1.0=0.1 + sched->Step(); // global=3, constant step=0, lr=0.1*1.0=0.1 ASSERT_FLOAT_NEAR(sched->GetLR(), kBaseLR, kEps); - sched->Step(); // global=4, constant step=1, lr=0.1 + sched->Step(); // global=4, constant step=1, lr=0.1 ASSERT_FLOAT_NEAR(sched->GetLR(), kBaseLR, kEps); } @@ -64,24 +64,24 @@ void TestLinearThenStepLR() { auto step_lr = CreateLRScheduler(opt, step_config); auto sched = CreateLRScheduler(opt, { - .type = "sequential", - .sequential_configs = {linear_config, step_config}, - .sequential_milestones = {3}, - }); + .type = "sequential", + .sequential_configs = {linear_config, step_config}, + .sequential_milestones = {3}, + }); - sched->Step(); // global=1 - sched->Step(); // global=2 + sched->Step(); // global=1 + sched->Step(); // global=2 - sched->Step(); // global=3, StepLR step=0, lr=0.1 + sched->Step(); // global=3, StepLR step=0, lr=0.1 ASSERT_FLOAT_NEAR(sched->GetLR(), 
0.1f, kEps); - sched->Step(); // global=4, StepLR step=1 - sched->Step(); // global=5, StepLR step=2 - sched->Step(); // global=6, StepLR step=3, 3//3=1, lr=0.1*0.5=0.05 + sched->Step(); // global=4, StepLR step=1 + sched->Step(); // global=5, StepLR step=2 + sched->Step(); // global=6, StepLR step=3, 3//3=1, lr=0.1*0.5=0.05 ASSERT_FLOAT_NEAR(sched->GetLR(), 0.05f, kEps); } -void TestLinearThenStepThenConstant(){ +void TestLinearThenStepThenConstant() { std::cout << "[TC3] TestLinearThenStepThenConstant" << std::endl; auto opt = MakeDummyOptimizer(kBaseLR); @@ -106,12 +106,11 @@ void TestLinearThenStepThenConstant(){ auto constant = CreateLRScheduler(opt, constant_config); auto sched = CreateLRScheduler(opt, { - .type = "sequential", - .sequential_configs = {linear_config, step_config, constant_config}, - .sequential_milestones = {3, 6}, - }); - const std::vector expected = { - 0.033333f, 0.066667f, 0.1f, 0.1f, 0.1f, 0.05f, 0.05f, 0.1f, 0.1f, 0.1f}; + .type = "sequential", + .sequential_configs = {linear_config, step_config, constant_config}, + .sequential_milestones = {3, 6}, + }); + const std::vector expected = {0.033333f, 0.066667f, 0.1f, 0.1f, 0.1f, 0.05f, 0.05f, 0.1f, 0.1f, 0.1f}; for (size_t i = 0; i < expected.size(); ++i) { sched->Step(); ASSERT_FLOAT_NEAR(sched->GetLR(), expected[i], 1e-5f); @@ -135,11 +134,11 @@ void TestStateRoundTrip() { }; auto step_lr = CreateLRScheduler(opt, step_config); auto sched = CreateLRScheduler(opt, { - .type = "sequential", - .sequential_configs = {linear_config, step_config}, - .sequential_milestones = {3}, - }); - for (int i = 0; i < 5; ++i) sched->Step(); + .type = "sequential", + .sequential_configs = {linear_config, step_config}, + .sequential_milestones = {3}, + }); + for (int i = 0; i < 5; ++i) { sched->Step(); } StateDict saved = sched->State(); auto opt2 = MakeDummyOptimizer(kBaseLR); @@ -157,10 +156,10 @@ void TestStateRoundTrip() { }; auto step_lr2 = CreateLRScheduler(opt2, step_config2); auto sched2 = 
CreateLRScheduler(opt2, { - .type = "sequential", - .sequential_configs = {linear_config2, step_config2}, - .sequential_milestones = {3}, - }); + .type = "sequential", + .sequential_configs = {linear_config2, step_config2}, + .sequential_milestones = {3}, + }); sched2->LoadState(saved); ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); @@ -173,33 +172,34 @@ void TestResumeConsistency() { auto make_sched = [](std::shared_ptr opt) { return CreateLRScheduler(opt, { - .type = "sequential", - .sequential_configs = {{ - .type = "linear", - .linear_start_factor = 1e-8f, - .linear_end_factor = 1.0f, - .linear_total_iters = 3, - }, { - .type = "step", - .step_size = 3, - .step_gamma = 0.5f, - }}, - .sequential_milestones = {3}, - }); + .type = "sequential", + .sequential_configs = {{ + .type = "linear", + .linear_start_factor = 1e-8f, + .linear_end_factor = 1.0f, + .linear_total_iters = 3, + }, + { + .type = "step", + .step_size = 3, + .step_gamma = 0.5f, + }}, + .sequential_milestones = {3}, + }); }; auto opt_ref = MakeDummyOptimizer(kBaseLR); auto sched_ref = make_sched(opt_ref); - for (int i = 0; i < kN; ++i) sched_ref->Step(); + for (int i = 0; i < kN; ++i) { sched_ref->Step(); } auto opt_a = MakeDummyOptimizer(kBaseLR); auto sched_a = make_sched(opt_a); - for (int i = 0; i < kK; ++i) sched_a->Step(); + for (int i = 0; i < kK; ++i) { sched_a->Step(); } StateDict ckpt = sched_a->State(); auto opt_b = MakeDummyOptimizer(kBaseLR); auto sched_b = make_sched(opt_b); sched_b->LoadState(ckpt); - for (int i = 0; i < kN - kK; ++i) sched_b->Step(); + for (int i = 0; i < kN - kK; ++i) { sched_b->Step(); } ASSERT_FLOAT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); diff --git a/test/lr_scheduler/test_step_lr.cc b/test/lr_scheduler/test_step_lr.cc index 1c1bb515..698bcc49 100644 --- a/test/lr_scheduler/test_step_lr.cc +++ b/test/lr_scheduler/test_step_lr.cc @@ -18,7 +18,7 @@ void TestWithinFirstPeriod() { auto sched = 
CreateLRScheduler(opt, config); for (int i = 0; i < 2; ++i) { sched->Step(); - ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); // last_step 1,2 → 指数 0 + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); // last_step 1,2 → 指数 0 } } @@ -30,7 +30,7 @@ void TestFirstDecay() { .step_gamma = 0.1f, }; auto sched = CreateLRScheduler(opt, config); - for (int i = 0; i < 3; ++i) sched->Step(); + for (int i = 0; i < 3; ++i) { sched->Step(); } // last_step=3, 3//3=1 → 0.1^1 = 0.1 → lr=0.01 ASSERT_FLOAT_EQ(sched->GetLR(), 0.01f); } @@ -43,14 +43,13 @@ void TestMultipleDecays() { .step_gamma = 0.1f, }; auto sched = CreateLRScheduler(opt, config); - for (int i = 0; i < 6; ++i) sched->Step(); + for (int i = 0; i < 6; ++i) { sched->Step(); } // last_step=6, 6//3=2 → 0.1^2 = 0.01 → lr=0.001 ASSERT_FLOAT_NEAR(sched->GetLR(), 0.001f, 1e-7f); } void TestPyTorchAlignment() { - const std::vector expected = { - 0.1f, 0.1f, 0.01f, 0.01f, 0.01f, 0.001f, 0.001f}; + const std::vector expected = {0.1f, 0.1f, 0.01f, 0.01f, 0.01f, 0.001f, 0.001f}; auto opt = MakeDummyOptimizer(kBaseLR); LRSchedulerConfig config = { .type = "step", @@ -81,17 +80,17 @@ void TestGammaOne() { void TestChainableAndClosedFormConsistency() { auto opt_a = MakeDummyOptimizer(kBaseLR); auto chainable = CreateLRScheduler(opt_a, { - .type = "step", - .step_size = 3, - .step_gamma = 0.1f, - }); + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }); auto opt_b = MakeDummyOptimizer(kBaseLR); auto closed_form = CreateLRScheduler(opt_b, { - .type = "step", - .step_size = 3, - .step_gamma = 0.1f, - }); + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }); for (int epoch = 1; epoch <= 12; ++epoch) { chainable->Step();