diff --git a/CMakeLists.txt b/CMakeLists.txt index 74536707..ca479abe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -200,3 +200,24 @@ target_link_libraries(test_hook infini_train) add_executable(test_precision_check test/hook/test_precision_check.cc) target_link_libraries(test_precision_check infini_train) + +add_executable(test_lr_scheduler test/lr_scheduler/test_lr_scheduler.cc) +target_link_libraries(test_lr_scheduler infini_train) + +add_executable(test_constant_lr test/lr_scheduler/test_constant_lr.cc) +target_link_libraries(test_constant_lr infini_train) + +add_executable(test_step_lr test/lr_scheduler/test_step_lr.cc) +target_link_libraries(test_step_lr infini_train) + +add_executable(test_linear_lr test/lr_scheduler/test_linear_lr.cc) +target_link_libraries(test_linear_lr infini_train) + +add_executable(test_lambda_lr test/lr_scheduler/test_lambda_lr.cc) +target_link_libraries(test_lambda_lr infini_train) + +add_executable(test_sequential_lr test/lr_scheduler/test_sequential_lr.cc) +target_link_libraries(test_sequential_lr infini_train) + +add_executable(test_chained_lr test/lr_scheduler/test_chained_lr.cc) +target_link_libraries(test_chained_lr infini_train) \ No newline at end of file diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index a007dff1..90f60262 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -13,6 +13,7 @@ #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" +#include "infini_train/include/lr_scheduler.h" #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h" @@ -55,6 +56,16 @@ DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations"); DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); +// lr scheduler +DEFINE_string(lr_scheduler, "none", "Learning rate scheduler type: none|constant|step|linear"); +DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)"); +DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)"); +DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)"); +DEFINE_int64(step_size, 30, "StepLR: period of learning rate decay"); +DEFINE_double(gamma, 0.1, "StepLR: multiplicative factor of lr decay"); +DEFINE_double(start_factor, 0.333333, "LinearLR: starting multiplicative factor"); +DEFINE_double(end_factor, 1.0, "LinearLR: ending multiplicative factor"); +DEFINE_int64(lr_total_iters, 5, "ConstantLR/LinearLR: total iterations for the scheduler"); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -268,6 +279,20 @@ void Train(const nn::parallel::Rank &rank) { optimizer = optimizer_creator(model->Parameters()); } + LRSchedulerConfig sched_config; + sched_config.type = FLAGS_lr_scheduler; + sched_config.warmup_steps = FLAGS_warmup_steps; + sched_config.warmup_start_factor = static_cast(FLAGS_warmup_start_factor); + sched_config.warmup_end_factor = static_cast(FLAGS_warmup_end_factor); + sched_config.step_size = FLAGS_step_size; + sched_config.step_gamma = static_cast(FLAGS_gamma); + sched_config.linear_start_factor = static_cast(FLAGS_start_factor); + sched_config.linear_end_factor = static_cast(FLAGS_end_factor); + sched_config.constant_factor = static_cast(FLAGS_start_factor); // 复用 + sched_config.constant_total_iters = FLAGS_lr_total_iters; + sched_config.linear_total_iters = FLAGS_lr_total_iters; + auto scheduler = CreateLRScheduler(optimizer, sched_config); + auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? std::static_pointer_cast( @@ -354,6 +379,9 @@ void Train(const nn::parallel::Rank &rank) { } optimizer->Step(); + if (scheduler) { + scheduler->Step(); + } } else { auto [x, y] = *train_iter; // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below @@ -363,6 +391,9 @@ void Train(const nn::parallel::Rank &rank) { y = std::make_shared(y->To(device)); lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); + if (scheduler) { + scheduler->Step(); + } } if (ddp_world_size > 1) { @@ -378,11 +409,11 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsLastRank()) { size_t used_mb = 0, reserved_mb = 0; std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); - + const float current_lr = scheduler ? scheduler->GetLR() : static_cast(FLAGS_learning_rate); LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", - step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, - tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, + step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps, + used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 2b1e2121..5b1bffbb 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -11,6 +11,7 @@ #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" +#include "infini_train/include/lr_scheduler.h" #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h" @@ -54,6 +55,16 @@ DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations"); DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); +// lr scheduler +DEFINE_string(lr_scheduler, "none", "Learning rate scheduler type: none|constant|step|linear"); +DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)"); +DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)"); +DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)"); +DEFINE_int64(step_size, 30, "StepLR: period of learning rate decay"); +DEFINE_double(gamma, 0.1, "StepLR: multiplicative factor of lr decay"); +DEFINE_double(start_factor, 0.333333, "LinearLR: starting multiplicative factor"); +DEFINE_double(end_factor, 1.0, "LinearLR: ending multiplicative factor"); +DEFINE_int64(lr_total_iters, 5, "ConstantLR/LinearLR: total iterations for the scheduler"); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -247,6 +258,20 @@ void Train(const nn::parallel::Rank &rank) { optimizer = optimizer_creator(model->Parameters()); } + LRSchedulerConfig sched_config; + sched_config.type = FLAGS_lr_scheduler; + sched_config.warmup_steps = FLAGS_warmup_steps; + sched_config.warmup_start_factor = static_cast(FLAGS_warmup_start_factor); + sched_config.warmup_end_factor = static_cast(FLAGS_warmup_end_factor); + sched_config.step_size = FLAGS_step_size; + sched_config.step_gamma = static_cast(FLAGS_gamma); + sched_config.linear_start_factor = static_cast(FLAGS_start_factor); + sched_config.linear_end_factor = static_cast(FLAGS_end_factor); + sched_config.constant_factor = static_cast(FLAGS_start_factor); // 复用 + sched_config.constant_total_iters = FLAGS_lr_total_iters; + sched_config.linear_total_iters = FLAGS_lr_total_iters; + auto scheduler = CreateLRScheduler(optimizer, sched_config); + auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? std::static_pointer_cast(std::make_shared()) @@ -330,6 +355,9 @@ void Train(const nn::parallel::Rank &rank) { } optimizer->Step(); + if (scheduler) { + scheduler->Step(); + } } else { auto [x, y] = *train_iter; // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below @@ -339,6 +367,9 @@ void Train(const nn::parallel::Rank &rank) { y = std::make_shared(y->To(device)); lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); + if (scheduler) { + scheduler->Step(); + } } if (ddp_world_size > 1) { @@ -354,11 +385,11 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsLastRank()) { size_t used_mb = 0, reserved_mb = 0; std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); - + const float current_lr = scheduler ? scheduler->GetLR() : static_cast(FLAGS_learning_rate); LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", - step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, - tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, + step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps, + used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h new file mode 100644 index 00000000..4e4695ce --- /dev/null +++ b/infini_train/include/lr_scheduler.h @@ -0,0 +1,186 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace infini_train { + +class Optimizer; + +using StateValue = std::variant>; +using StateDict = std::unordered_map; + +struct LRSchedulerConfig { + std::string type = "none"; + // ConstantLR + float constant_factor = 1.0f / 3.0f; + int constant_total_iters = 5; + // StepLR + int64_t step_size = 10; + float step_gamma = 0.1f; + // LinearLR + float linear_start_factor = 1.0f / 3.0f; + float linear_end_factor = 1.0f; + int linear_total_iters = 5; + // LambdaLR + std::function lambda_fn = nullptr; + // SequentialLR + std::vector sequential_configs; + std::vector sequential_milestones; + // ChainedScheduler + std::vector chained_configs; + // warmup + int64_t warmup_steps = 0; + float warmup_start_factor = 1.0f / 3.0f; + float warmup_end_factor = 1.0f; +}; + +class LRScheduler { +public: + template static std::shared_ptr Create(Args &&...args) { + auto scheduler = std::make_shared(std::forward(args)...); + scheduler->InitialStep(); + return scheduler; + } + + explicit LRScheduler(std::shared_ptr optimizer, int64_t last_step = -1); + virtual ~LRScheduler() = default; + + LRScheduler(const LRScheduler &) = delete; + LRScheduler &operator=(const LRScheduler &) = delete; + + virtual void Step(); + virtual void Step(int64_t epoch); + virtual void InitialStep(); + + float GetLR() const; + float BaseLR() const; + int64_t LastStep() const; + + void ResetStep(int64_t step = -1); + virtual StateDict State() const; + virtual void LoadState(const StateDict &state); + +protected: + virtual float GetClosedFormLR() const = 0; + virtual float GetChainedFormLR() const; + void ApplyLR(float lr); + + std::shared_ptr optimizer_; + int64_t last_step_; + float current_lr_; + float base_lr_; + bool is_initial_ = false; +}; + +std::shared_ptr CreateLRScheduler(std::shared_ptr optimizer, const LRSchedulerConfig &config); + +namespace lr_schedulers { + +class ConstantLR : public LRScheduler { +public: + ConstantLR(std::shared_ptr optimizer, float factor = 1.0f / 3.0f, int total_iters = 5, + int64_t last_step = -1); + ~ConstantLR() override = default; + +protected: + float GetChainedFormLR() const override; + float GetClosedFormLR() const override; + +private: + const float factor_; + const int64_t total_iters_; +}; + +class StepLR : public LRScheduler { +public: + StepLR(std::shared_ptr optimizer, int64_t step_size, float gamma = 0.1f, int64_t last_step = -1); + ~StepLR() override = default; + +protected: + float GetChainedFormLR() const override; + float GetClosedFormLR() const override; + +private: + const int64_t step_size_; + const float gamma_; +}; + +class LinearLR : public LRScheduler { +public: + LinearLR(std::shared_ptr optimizer, float start_factor = 1.0f / 3.0f, float end_factor = 1.0f, + int64_t total_iters = 5, int64_t last_step = -1); + ~LinearLR() override = default; + +protected: + float GetChainedFormLR() const override; + float GetClosedFormLR() const override; + +private: + const float start_factor_; + const float end_factor_; + const int64_t total_iters_; +}; + +class LambdaLR : public LRScheduler { +public: + using LambdaFunc = std::function; + + LambdaLR(std::shared_ptr optimizer, LambdaFunc lr_lambda, int64_t last_step = -1); + ~LambdaLR() override = default; + +protected: + float GetClosedFormLR() const override; + +private: + const LambdaFunc lr_lambda_; +}; + +class SequentialLR : public LRScheduler { +public: + SequentialLR(std::shared_ptr optimizer, std::vector> schedulers, + std::vector milestones, int64_t last_step = -1); + ~SequentialLR() override = default; + + void Step() override; + void InitialStep() override; + + StateDict State() const override; + void LoadState(const StateDict &state) override; + +protected: + float GetClosedFormLR() const override { return current_lr_; } + void UndoChildInitialSteps(); + +private: + std::vector> schedulers_; + std::vector milestones_; +}; + +class ChainedScheduler : public LRScheduler { +public: + ChainedScheduler(std::shared_ptr optimizer, std::vector> schedulers, + int64_t last_step = -1); + ~ChainedScheduler() override = default; + + void Step() override; + void InitialStep() override; + + StateDict State() const override; + void LoadState(const StateDict &state) override; + +protected: + float GetClosedFormLR() const override { return current_lr_; } + +private: + std::vector> schedulers_; +}; + +} // namespace lr_schedulers +} // namespace infini_train \ No newline at end of file diff --git a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h index bc31442e..f95801a2 100644 --- a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h +++ b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h @@ -34,6 +34,9 @@ class DistributedOptimizer final : public infini_train::Optimizer { void StartParamSync(bool force_sync = false); void FinishParamSync(bool skip_next_bucket_dispatch = false); + virtual void SetLearningRate(float lr) override; + virtual float GetLearningRate() const override; + private: void BuildShardParamsAndBindGrads(); diff --git a/infini_train/include/optimizer.h b/infini_train/include/optimizer.h index fb0ae2d5..c72ee6c9 100644 --- a/infini_train/include/optimizer.h +++ b/infini_train/include/optimizer.h @@ -15,14 +15,25 @@ using OptimizerCreator = std::function(const std::vec class Optimizer { public: - explicit Optimizer(const std::vector> ¶ms); + explicit Optimizer(const std::vector> ¶ms, float learning_rate = 0.0f); virtual void ZeroGrad(bool set_to_none = true); virtual void Step() = 0; + virtual void SetLearningRate(float lr); + + virtual float GetLearningRate() const; + + float GetInitialLearningRate() const; + + void SetInitialLearningRate(float lr); + protected: std::vector> params_; + float learning_rate_ = 0.0f; + float initial_learning_rate_ = 0.0f; + bool initial_lr_set_ = false; }; namespace optimizers { @@ -37,9 +48,6 @@ class SGD : public Optimizer { return std::make_shared(params, learning_rate); }; } - -private: - const float learning_rate_ = 0.0; }; class Adam : public Optimizer { @@ -58,7 +66,6 @@ class Adam : public Optimizer { private: int64_t t_; - const float learning_rate_; const float beta1_; const float beta2_; const float eps_; diff --git a/infini_train/src/lr_scheduler.cc b/infini_train/src/lr_scheduler.cc new file mode 100644 index 00000000..11b6034b --- /dev/null +++ b/infini_train/src/lr_scheduler.cc @@ -0,0 +1,326 @@ +#include "infini_train/include/lr_scheduler.h" + +#include "glog/logging.h" + +#include "infini_train/include/optimizer.h" + +namespace infini_train { + +std::shared_ptr CreateLRScheduler(std::shared_ptr optimizer, const LRSchedulerConfig &config) { + if (config.type == "none") { + return nullptr; + } + + auto create_main = [&](std::shared_ptr opt) -> std::shared_ptr { + if (config.type == "constant") { + return LRScheduler::Create(opt, config.constant_factor, + config.constant_total_iters); + } + if (config.type == "step") { + return LRScheduler::Create(opt, config.step_size, config.step_gamma); + } + if (config.type == "linear") { + return LRScheduler::Create(opt, config.linear_start_factor, + config.linear_end_factor, config.linear_total_iters); + } + if (config.type == "lambda") { + return LRScheduler::Create(opt, config.lambda_fn); + } + if (config.type == "sequential") { + std::vector> schedulers; + std::vector milestones = config.sequential_milestones; + for (const auto &sub_config : config.sequential_configs) { + auto sub_sched = CreateLRScheduler(opt, sub_config); + if (sub_sched) { + schedulers.push_back(sub_sched); + } + } + return LRScheduler::Create(opt, schedulers, milestones); + } + if (config.type == "chained") { + std::vector> schedulers; + for (const auto &sub_config : config.chained_configs) { + auto sub_sched = CreateLRScheduler(opt, sub_config); + if (sub_sched) { + schedulers.push_back(sub_sched); + } + } + return LRScheduler::Create(opt, schedulers); + } + LOG(FATAL) << "Unsupported LR scheduler type: " << config.type; + return nullptr; + }; + + if (config.warmup_steps <= 0) { + return create_main(optimizer); + } + + auto warmup_scheduler = LRScheduler::Create(optimizer, + /*start_factor=*/config.warmup_start_factor, + /*end_factor=*/config.warmup_end_factor, + /*total_iters=*/config.warmup_steps); + + auto main_scheduler = create_main(optimizer); + + return LRScheduler::Create( + optimizer, std::vector>{warmup_scheduler, main_scheduler}, + std::vector{config.warmup_steps}); +}; + +LRScheduler::LRScheduler(std::shared_ptr optimizer, int64_t last_step) + : optimizer_(std::move(optimizer)), last_step_(last_step), current_lr_(0.0f), base_lr_(0.0f) { + CHECK(optimizer_) << "LRScheduler: optimizer must not be null."; + optimizer_->SetInitialLearningRate(optimizer_->GetLearningRate()); + base_lr_ = optimizer_->GetInitialLearningRate(); + current_lr_ = base_lr_; +} + +void LRScheduler::Step() { + ++last_step_; + ApplyLR(GetChainedFormLR()); +} + +void LRScheduler::Step(int64_t epoch) { + last_step_ = epoch; + ApplyLR(GetClosedFormLR()); +} + +void LRScheduler::InitialStep() { + is_initial_ = true; + Step(); + is_initial_ = false; +} + +void LRScheduler::ApplyLR(float lr) { + current_lr_ = lr; + optimizer_->SetLearningRate(current_lr_); +} + +float LRScheduler::GetChainedFormLR() const { return GetClosedFormLR(); } + +float LRScheduler::GetLR() const { return current_lr_; } + +float LRScheduler::BaseLR() const { return base_lr_; } + +int64_t LRScheduler::LastStep() const { return last_step_; } + +void LRScheduler::ResetStep(int64_t step) { last_step_ = step; } + +StateDict LRScheduler::State() const { + return { + {"last_step", last_step_}, + {"current_lr", current_lr_}, + {"base_lr", base_lr_}, + }; +} + +void LRScheduler::LoadState(const StateDict &state) { + last_step_ = std::get(state.at("last_step")); + current_lr_ = std::get(state.at("current_lr")); + base_lr_ = std::get(state.at("base_lr")); + optimizer_->SetLearningRate(current_lr_); +} + +// Concrete LR Schedulers + +namespace lr_schedulers { + +// --- ConstantLR --- + +ConstantLR::ConstantLR(std::shared_ptr optimizer, float factor, int total_iters, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), factor_(factor), total_iters_(total_iters) {} + +float ConstantLR::GetClosedFormLR() const { return last_step_ < total_iters_ ? base_lr_ * factor_ : base_lr_; } + +float ConstantLR::GetChainedFormLR() const { + const float lr = optimizer_->GetLearningRate(); + if (last_step_ == 0) { + return lr * factor_; + } else if (last_step_ < total_iters_) { + return lr; + } else if (last_step_ == total_iters_) { + return lr / factor_; + } + return lr; +} + +// --- StepLR --- + +StepLR::StepLR(std::shared_ptr optimizer, int64_t step_size, float gamma, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), step_size_(step_size), gamma_(gamma) {} + +float StepLR::GetClosedFormLR() const { + return base_lr_ + * static_cast(std::pow(static_cast(gamma_), static_cast(last_step_ / step_size_))); +} + +float StepLR::GetChainedFormLR() const { + const float lr = optimizer_->GetLearningRate(); + if (last_step_ == 0 || (last_step_ % step_size_) != 0) { + return lr; + } + return lr * gamma_; +} + +LinearLR::LinearLR(std::shared_ptr optimizer, float start_factor, float end_factor, int64_t total_iters, + int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), start_factor_(start_factor), end_factor_(end_factor), + total_iters_(total_iters) {} + +float LinearLR::GetClosedFormLR() const { + if (last_step_ >= total_iters_) { + return base_lr_ * end_factor_; + } + return base_lr_ + * (start_factor_ + + (end_factor_ - start_factor_) * static_cast(last_step_) / static_cast(total_iters_)); +} + +float LinearLR::GetChainedFormLR() const { + const float lr = optimizer_->GetLearningRate(); + if (last_step_ == 0) { + return lr * start_factor_; + } + if (last_step_ > total_iters_ || is_initial_) { + return lr; + } + if (last_step_ == total_iters_) { + const float prev_factor + = start_factor_ + + (end_factor_ - start_factor_) * static_cast(total_iters_ - 1) / static_cast(total_iters_); + return lr * (end_factor_ / prev_factor); + } + + const float numerator = end_factor_ - start_factor_; + const float denominator + = start_factor_ * static_cast(total_iters_) + static_cast(last_step_ - 1) * numerator; + return lr * (1.0f + numerator / denominator); +} + +LambdaLR::LambdaLR(std::shared_ptr optimizer, std::function lr_lambda, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), lr_lambda_(std::move(lr_lambda)) {} + +float LambdaLR::GetClosedFormLR() const { return base_lr_ * lr_lambda_(last_step_); } + +SequentialLR::SequentialLR(std::shared_ptr optimizer, std::vector> schedulers, + std::vector milestones, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)), + milestones_(std::move(milestones)) {} + +void SequentialLR::InitialStep() { + CHECK(!schedulers_.empty()) << "SequentialLR requires at least one scheduler."; + CHECK_EQ(milestones_.size(), schedulers_.size() - 1) + << "SequentialLR: milestones count must be schedulers count - 1."; + + for (size_t i = 1; i < milestones_.size(); ++i) { + CHECK_GT(milestones_[i], milestones_[i - 1]) << "Milestones must be strictly increasing."; + } + + optimizer_->SetLearningRate(schedulers_[0]->BaseLR()); + + UndoChildInitialSteps(); + + ++last_step_; + schedulers_[0]->InitialStep(); + current_lr_ = schedulers_[0]->GetLR(); +} + +void SequentialLR::UndoChildInitialSteps() { + for (auto &sched : schedulers_) { + if (auto nested = std::dynamic_pointer_cast(sched)) { + nested->UndoChildInitialSteps(); + } + sched->ResetStep(sched->LastStep() - 1); + } +} + +void SequentialLR::Step() { + ++last_step_; + size_t idx = std::upper_bound(milestones_.begin(), milestones_.end(), last_step_) - milestones_.begin(); + + auto &scheduler = schedulers_[idx]; + + if (idx > 0 && milestones_[idx - 1] == last_step_) { + scheduler->Step(0); + } else { + scheduler->Step(); + } + + current_lr_ = optimizer_->GetLearningRate(); +} + +StateDict SequentialLR::State() const { + StateDict state; + state["last_step"] = last_step_; + state["current_lr"] = current_lr_; + state["base_lr"] = base_lr_; + for (size_t i = 0; i < schedulers_.size(); ++i) { + auto sub_state = schedulers_[i]->State(); + for (const auto &[key, value] : sub_state) { state["scheduler_" + std::to_string(i) + "." + key] = value; } + } + return state; +} + +void SequentialLR::LoadState(const StateDict &state) { + last_step_ = std::get(state.at("last_step")); + current_lr_ = std::get(state.at("current_lr")); + base_lr_ = std::get(state.at("base_lr")); + + for (size_t i = 0; i < schedulers_.size(); ++i) { + StateDict sub_state; + std::string prefix = "scheduler_" + std::to_string(i) + "."; + for (const auto &[key, value] : state) { + if (key.substr(0, prefix.size()) == prefix) { + sub_state[key.substr(prefix.size())] = value; + } + } + if (!sub_state.empty()) { + schedulers_[i]->LoadState(sub_state); + } + } + optimizer_->SetLearningRate(current_lr_); +} + +ChainedScheduler::ChainedScheduler(std::shared_ptr optimizer, + std::vector> schedulers, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)) {} + +void ChainedScheduler::InitialStep() { + CHECK(!schedulers_.empty()) << "ChainedScheduler requires at least one scheduler."; + + current_lr_ = optimizer_->GetLearningRate(); +} + +void ChainedScheduler::Step() { + ++last_step_; + for (auto &sched : schedulers_) { sched->Step(); } + current_lr_ = optimizer_->GetLearningRate(); +} + +StateDict ChainedScheduler::State() const { + StateDict state = LRScheduler::State(); + for (size_t i = 0; i < schedulers_.size(); ++i) { + auto sub_state = schedulers_[i]->State(); + for (const auto &[key, value] : sub_state) { state["scheduler_" + std::to_string(i) + "." + key] = value; } + } + return state; +} + +void ChainedScheduler::LoadState(const StateDict &state) { + LRScheduler::LoadState(state); + for (size_t i = 0; i < schedulers_.size(); ++i) { + StateDict sub_state; + std::string prefix = "scheduler_" + std::to_string(i) + "."; + for (const auto &[key, value] : state) { + if (key.substr(0, prefix.size()) == prefix) { + sub_state[key.substr(prefix.size())] = value; + } + } + if (!sub_state.empty()) { + schedulers_[i]->LoadState(sub_state); + } + } +} + +} // namespace lr_schedulers +} // namespace infini_train \ No newline at end of file diff --git a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc index 55e5800b..c7f3de2a 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc @@ -114,6 +114,20 @@ void DistributedOptimizer::ZeroGrad(bool set_to_none) { } } +void DistributedOptimizer::SetLearningRate(float lr) { + Optimizer::SetLearningRate(lr); + if (base_optimizer_) { + base_optimizer_->SetLearningRate(lr); + } +} + +float DistributedOptimizer::GetLearningRate() const { + if (base_optimizer_) { + return base_optimizer_->GetLearningRate(); + } + return Optimizer::GetLearningRate(); +} + void DistributedOptimizer::Step() { // 1. Ensure grads are synced FinishGradSync(); diff --git a/infini_train/src/optimizer.cc b/infini_train/src/optimizer.cc index 2c9b218a..c86c40f1 100644 --- a/infini_train/src/optimizer.cc +++ b/infini_train/src/optimizer.cc @@ -8,16 +8,32 @@ #include "infini_train/include/tensor.h" namespace infini_train { -Optimizer::Optimizer(const std::vector> ¶ms) : params_(params) {} +Optimizer::Optimizer(const std::vector> ¶ms, float learning_rate) + : params_(params), learning_rate_(learning_rate) {} void Optimizer::ZeroGrad(bool set_to_none) { for (auto param : params_) { param->ZeroGrad(set_to_none); } } +void Optimizer::SetLearningRate(float lr) { learning_rate_ = lr; } + +float Optimizer::GetLearningRate() const { return learning_rate_; } + +float Optimizer::GetInitialLearningRate() const { + CHECK(initial_lr_set_) << "Optimizer: initial_learning_rate not set. " + "Use with an LRScheduler first."; + return initial_learning_rate_; +} + +void Optimizer::SetInitialLearningRate(float lr) { + if (!initial_lr_set_) { + initial_learning_rate_ = lr; + initial_lr_set_ = true; + } +} namespace optimizers { -SGD::SGD(const std::vector> ¶ms, float learning_rate) - : Optimizer(params), learning_rate_(learning_rate) {} +SGD::SGD(const std::vector> ¶ms, float learning_rate) : Optimizer(params, learning_rate) {} void SGD::Step() { for (auto param : params_) { @@ -33,7 +49,7 @@ void SGD::Step() { } Adam::Adam(const std::vector> ¶ms, float learning_rate, float beta1, float beta2, float eps) - : Optimizer(params), t_(0), learning_rate_(learning_rate), beta1_(beta1), beta2_(beta2), eps_(eps) { + : Optimizer(params, learning_rate), t_(0), beta1_(beta1), beta2_(beta2), eps_(eps) { for (const auto ¶m : params_) { m_.emplace_back(std::make_shared(param->Dims(), param->Dtype(), param->GetDevice())); diff --git a/test/lr_scheduler/test_chained_lr.cc b/test/lr_scheduler/test_chained_lr.cc new file mode 100644 index 00000000..3ea6bd55 --- /dev/null +++ b/test/lr_scheduler/test_chained_lr.cc @@ -0,0 +1,202 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { +constexpr float kBaseLR = 0.1f; +} +// TC1: 单子调度器退化 +void TestSingleScheduler() { + std::cout << "[TC1] TestSingleScheduler" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto step_lr = CreateLRScheduler(opt, { + .type = "step", + .step_size = 3, + .step_gamma = 0.5f, + }); + auto sched = LRScheduler::Create(opt, std::vector>{step_lr}); + + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + sched->Step(); // step=1 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); +} + +// TC2: StepLR + LambdaLR 乘法叠加 +void TestMultiplicativeChain() { + std::cout << "[TC2] TestMultiplicativeChain" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = CreateLRScheduler( + opt, { + .type = "chained", + .chained_configs = {{ + .type = "step", + .step_size = 2, + .step_gamma = 0.5f, + }, + { + .type = "lambda", + .lambda_fn = [](int64_t step) { return 1.0f - 0.1f * step; }, + }}, + }); + + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.09f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.08f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.07f, kEps); +} + +// TC3: ConstantLR + StepLR 叠加 (无穿插声明) +void TestConstantPlusStep() { + std::cout << "[TC3] TestConstantPlusStep" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = CreateLRScheduler(opt, { + .type = "chained", + .chained_configs = {{ + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 2, + }, + { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }}, + }); + + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.05f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.05f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.01f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.01f, kEps); +} + +// TC4: ConstantLR + StepLR 叠加(有穿插声明) +void TestConstantPlusStepDLC() { + std::cout << "[TC4] TestConstantPlusStepDLC" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto constant = CreateLRScheduler(opt, { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 2, + }); + auto linear = CreateLRScheduler(opt, { + .type = "linear", + .linear_start_factor = 1e-8f, + .linear_end_factor = 1.0f, + .linear_total_iters = 3, + }); + auto step_lr = CreateLRScheduler(opt, { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }); + auto Lambda = CreateLRScheduler(opt, { + .type = "lambda", + .lambda_fn = [](int64_t step) { return 1.0f - 0.1f * step; }, + }); + + auto sched + = LRScheduler::Create(opt, std::vector>{constant, step_lr}); + + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.2f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.02f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.02f, kEps); +} + +// TC5: State/LoadState 往返 +void TestStateRoundTrip() { + std::cout << "[TC5] TestStateRoundTrip" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto step_lr = std::make_shared(opt, 2, 0.5f); + auto lambda_lr = std::make_shared(opt, [](int64_t step) { return 1.0f - 0.05f * step; }); + auto sched + = LRScheduler::Create(opt, std::vector>{step_lr, lambda_lr}); + + for (int i = 0; i < 5; ++i) { sched->Step(); } + StateDict saved = sched->State(); + + auto opt2 = MakeDummyOptimizer(kBaseLR); + auto step_lr2 = std::make_shared(opt2, 2, 0.5f); + auto lambda_lr2 = std::make_shared(opt2, [](int64_t step) { return 1.0f - 0.05f * step; }); + auto sched2 + = LRScheduler::Create(opt2, std::vector>{step_lr2, lambda_lr2}); + sched2->LoadState(saved); + + ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); + ASSERT_FLOAT_NEAR(sched2->GetLR(), sched->GetLR(), kEps); +} + +// TC6: resume 一致性 +void TestResumeConsistency() { + std::cout << "[TC6] TestResumeConsistency" << std::endl; + constexpr int kN = 10, kK = 4; + auto lambda_fn = [](int64_t step) { return 1.0f - 0.05f * step; }; + + auto make_sched = [&](std::shared_ptr opt) { + auto step_lr = std::make_shared(opt, 2, 0.5f); + auto lambda_lr = std::make_shared(opt, lambda_fn); + return LRScheduler::Create(opt, + std::vector>{step_lr, lambda_lr}); + }; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = make_sched(opt_ref); + for (int i = 0; i < kN; ++i) { sched_ref->Step(); } + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = make_sched(opt_a); + for (int i = 0; i < kK; ++i) { sched_a->Step(); } + StateDict ckpt = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = make_sched(opt_b); + sched_b->LoadState(ckpt); + for (int i = 0; i < kN - kK; ++i) { sched_b->Step(); } + + ASSERT_FLOAT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); + ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== ChainedScheduler Tests ===" << std::endl; + TestSingleScheduler(); + TestMultiplicativeChain(); + TestConstantPlusStep(); + TestConstantPlusStepDLC(); + TestStateRoundTrip(); + TestResumeConsistency(); + std::cout << "========================" << std::endl; + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} \ No newline at end of file diff --git a/test/lr_scheduler/test_constant_lr.cc b/test/lr_scheduler/test_constant_lr.cc new file mode 100644 index 00000000..df5e9be2 --- /dev/null +++ b/test/lr_scheduler/test_constant_lr.cc @@ -0,0 +1,184 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { +constexpr float kBaseLR = 0.1f; +} // namespace + +void TestInitialState() { + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 3, + }; + auto sched = CreateLRScheduler(opt, config); + ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); + ASSERT_TRUE(sched->LastStep() == 0); + ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f); +} + +void TestFirstStepAppliesFactor() { + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 3, + }; + + auto sched = CreateLRScheduler(opt, config); + sched->Step(); // last_step_ = 0 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); + ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f); + ASSERT_TRUE(sched->LastStep() == 1); +} + +void TestWithinTotalIters() { + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 3, + }; + auto sched = CreateLRScheduler(opt, config); + for (int i = 0; i < 2; ++i) { sched->Step(); } + // last_step_ = 2, still < 3 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); +} + +void TestBeyondTotalIters() { + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 3, + }; + auto sched = CreateLRScheduler(opt, config); + for (int i = 0; i < 10; ++i) { sched->Step(); } + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); + ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR); +} + +void TestPyTorchAlignment() { + const std::vector expected = {0.05f, 0.05f, 0.1f, 0.1f, 0.1f}; + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 3, + }; + auto sched = CreateLRScheduler(opt, config); + for (size_t i = 0; i < expected.size(); ++i) { + sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), expected[i]); + } +} + +void TestStateRoundTrip() { + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 5, + }; + auto sched = CreateLRScheduler(opt, config); + for (int i = 0; i < 3; ++i) { sched->Step(); } + StateDict saved = sched->State(); + + auto opt2 = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config2 = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 5, + }; + auto sched2 = CreateLRScheduler(opt2, config2); + sched2->LoadState(saved); + + ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); + ASSERT_FLOAT_EQ(sched2->GetLR(), sched->GetLR()); + ASSERT_FLOAT_EQ(opt2->GetLearningRate(), sched->GetLR()); +} + +void TestResumeConsistency() { + constexpr int kN = 8; + constexpr int kK = 3; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config_ref = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 5, + }; + auto sched_ref = CreateLRScheduler(opt_ref, config_ref); + for (int i = 0; i < kN; ++i) { sched_ref->Step(); } + + auto opt_a = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config_a = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 5, + }; + auto sched_a = CreateLRScheduler(opt_a, config_a); + for (int i = 0; i < kK; ++i) { sched_a->Step(); } + StateDict ckpt = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config_b = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 5, + }; + auto sched_b = CreateLRScheduler(opt_b, config_b); + sched_b->LoadState(ckpt); + for (int i = 0; i < kN - kK; ++i) { sched_b->Step(); } + + ASSERT_FLOAT_EQ(sched_b->GetLR(), sched_ref->GetLR()); + ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); +} + +void TestChainableAndClosedFormConsistency() { + auto opt_a = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config_a = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 5, + }; + auto chainable = CreateLRScheduler(opt_a, config_a); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config_b = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 5, + }; + auto closed_form = CreateLRScheduler(opt_b, config_b); + + for (int epoch = 1; epoch <= 12; ++epoch) { + chainable->Step(); + closed_form->Step(epoch); + ASSERT_FLOAT_NEAR(chainable->GetLR(), closed_form->GetLR(), 1e-7f); + } +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== ConstantLR Tests ===" << std::endl; + TestInitialState(); + TestFirstStepAppliesFactor(); + TestWithinTotalIters(); + TestBeyondTotalIters(); + TestPyTorchAlignment(); + TestStateRoundTrip(); + TestResumeConsistency(); + TestChainableAndClosedFormConsistency(); + std::cout << "========================" << std::endl; + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} \ No newline at end of file diff --git a/test/lr_scheduler/test_helpers.h b/test/lr_scheduler/test_helpers.h new file mode 100644 index 00000000..f4b22430 --- /dev/null +++ b/test/lr_scheduler/test_helpers.h @@ -0,0 +1,35 @@ +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +namespace { + +constexpr float kEps = 1e-6f; + +std::shared_ptr MakeDummyOptimizer(float lr) { + std::vector> empty_params; + return std::make_shared(empty_params, lr); +} + +bool FloatNear(float a, float b, float eps = kEps) { return std::fabs(a - b) < eps; } + +int g_fail_count = 0; + +void Check(bool cond, const char *expr, int line) { + if (!cond) { + std::cerr << "FAIL [line " << line << "]: " << expr << std::endl; + ++g_fail_count; + } +} + +#define ASSERT_TRUE(cond) Check((cond), #cond, __LINE__) +#define ASSERT_FLOAT_EQ(a, b) Check(FloatNear((a), (b)), #a " == " #b, __LINE__) +#define ASSERT_FLOAT_NEAR(a, b, eps) Check(FloatNear((a), (b), (eps)), #a " ≈ " #b, __LINE__) + +} // namespace \ No newline at end of file diff --git a/test/lr_scheduler/test_lambda_lr.cc b/test/lr_scheduler/test_lambda_lr.cc new file mode 100644 index 00000000..a89d6356 --- /dev/null +++ b/test/lr_scheduler/test_lambda_lr.cc @@ -0,0 +1,127 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { +constexpr float kBaseLR = 0.1f; +} // namespace + +void TestIdentityLambda() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = CreateLRScheduler(opt, { + .type = "lambda", + .lambda_fn = [](int64_t) { return 1.0f; }, + }); + // 构造器内 Step() → last_step_=0, lr = 0.1 * 1.0 = 0.1 + ASSERT_TRUE(sched->LastStep() == 0); + ASSERT_FLOAT_NEAR(sched->GetLR(), kBaseLR, kEps); + ASSERT_FLOAT_NEAR(opt->GetLearningRate(), kBaseLR, kEps); +} + +void TestLinearDecayLambda() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = CreateLRScheduler(opt, { + .type = "lambda", + .lambda_fn = [](int64_t step) { return 1.0f - step * 0.1f; }, + }); + // step=0, lambda(0)=1.0, lr=0.1 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + sched->Step(); // step=1, lambda(1)=0.9, lr=0.09 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.09f, kEps); + + sched->Step(); // step=2, lambda(2)=0.8, lr=0.08 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.08f, kEps); + + sched->Step(); // step=3, lambda(3)=0.7, lr=0.07 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.07f, kEps); +} + +void TestPyTorchAlignment() { + // PyTorch: LambdaLR(opt, lr_lambda=lambda epoch: 0.95**epoch) + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched + = CreateLRScheduler(opt, { + .type = "lambda", + .lambda_fn = [](int64_t step) { return static_cast(std::pow(0.95, step)); }, + }); + // step=0, lr = 0.1 * 0.95^0 = 0.1 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + std::vector expected = {0.095f, 0.09025f, 0.0857375f}; + for (size_t i = 0; i < expected.size(); ++i) { + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), expected[i], 1e-5f); + } +} + +void TestStateRoundTrip() { + auto lambda_fn = [](int64_t step) { return 1.0f - step * 0.05f; }; + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = CreateLRScheduler(opt, { + .type = "lambda", + .lambda_fn = lambda_fn, + }); + for (int i = 0; i < 5; ++i) { sched->Step(); } + StateDict saved = sched->State(); + + auto opt2 = MakeDummyOptimizer(kBaseLR); + auto sched2 = CreateLRScheduler(opt2, { + .type = "lambda", + .lambda_fn = lambda_fn, + }); // same lambda + sched2->LoadState(saved); + + ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); + ASSERT_FLOAT_NEAR(sched2->GetLR(), sched->GetLR(), kEps); +} + +void TestResumeConsistency() { + auto lambda_fn = [](int64_t step) { return 1.0f - step * 0.05f; }; + constexpr int kN = 10, kK = 4; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = CreateLRScheduler(opt_ref, { + .type = "lambda", + .lambda_fn = lambda_fn, + }); + for (int i = 0; i < kN; ++i) { sched_ref->Step(); } + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = CreateLRScheduler(opt_a, { + .type = "lambda", + .lambda_fn = lambda_fn, + }); + for (int i = 0; i < kK; ++i) { sched_a->Step(); } + StateDict ckpt = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = CreateLRScheduler(opt_b, { + .type = "lambda", + .lambda_fn = lambda_fn, + }); + sched_b->LoadState(ckpt); + for (int i = 0; i < kN - kK; ++i) { sched_b->Step(); } + + ASSERT_FLOAT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); + ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== LambdaLR Tests ===" << std::endl; + TestIdentityLambda(); + TestLinearDecayLambda(); + TestPyTorchAlignment(); + TestStateRoundTrip(); + TestResumeConsistency(); + std::cout << "======================" << std::endl; + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} \ No newline at end of file diff --git a/test/lr_scheduler/test_linear_lr.cc b/test/lr_scheduler/test_linear_lr.cc new file mode 100644 index 00000000..e1659a7f --- /dev/null +++ b/test/lr_scheduler/test_linear_lr.cc @@ -0,0 +1,140 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { +constexpr float kBaseLR = 0.1f; +} + +void TestFirstStepFromZero() { + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "linear", + .linear_start_factor = 0.2f, + .linear_end_factor = 1.0f, + .linear_total_iters = 5, + }; + + auto sched = CreateLRScheduler(opt, config); + ASSERT_FLOAT_EQ(sched->GetLR(), 0.02f); +} + +void TestMidpointLR() { + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "linear", + .linear_start_factor = 0.2f, + .linear_end_factor = 1.0f, + .linear_total_iters = 5, + }; + auto sched = CreateLRScheduler(opt, config); + for (int i = 0; i < 3; ++i) { sched->Step(); } + // last_step_=3 -> 0.1*(0.2 + 0.8*3/5) = 0.068 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.068f); +} + +void TestWarmupEnd() { + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "linear", + .linear_start_factor = 0.2f, + .linear_end_factor = 1.0f, + .linear_total_iters = 5, + }; + auto sched = CreateLRScheduler(opt, config); + for (int i = 0; i < 5; ++i) { sched->Step(); } + // last_step_ >= total_iters -> base_lr * end_factor + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); +} + +void TestBeyondWarmup() { + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "linear", + .linear_start_factor = 0.2f, + .linear_end_factor = 1.0f, + .linear_total_iters = 5, + }; + auto sched = CreateLRScheduler(opt, config); + for (int i = 0; i < 20; ++i) { sched->Step(); } + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); +} + +void TestCustomStartFactor() { + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "linear", + .linear_start_factor = 0.25f, + .linear_end_factor = 1.0f, + .linear_total_iters = 4, + }; + auto sched = CreateLRScheduler(opt, config); + sched->Step(); // last_step_=1, lr=0.1*(0.25+0.75*1/4)=0.04375 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.04375f, 1e-6f); + sched->Step(); // last_step_=2, lr=0.1*(0.25+0.75*2/4)=0.0625 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.0625f, 1e-6f); +} + +void TestPyTorchAlignment() { + const std::vector expected = {0.036f, 0.052f, 0.068f, 0.084f, 0.1f, 0.1f, 0.1f}; + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "linear", + .linear_start_factor = 0.2f, + .linear_end_factor = 1.0f, + .linear_total_iters = 5, + }; + auto sched = CreateLRScheduler(opt, config); + for (size_t i = 0; i < expected.size(); ++i) { + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), expected[i], 1e-7f); + } +} + +void TestChainableAndClosedFormConsistency() { + auto opt_a = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config_a = { + .type = "linear", + .linear_start_factor = 0.2f, + .linear_end_factor = 1.0f, + .linear_total_iters = 5, + }; + auto chainable = CreateLRScheduler(opt_a, config_a); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config_b = { + .type = "linear", + .linear_start_factor = 0.2f, + .linear_end_factor = 1.0f, + .linear_total_iters = 5, + }; + auto closed_form = CreateLRScheduler(opt_b, config_b); + + for (int epoch = 1; epoch <= 10; ++epoch) { + chainable->Step(); + closed_form->Step(epoch); + ASSERT_FLOAT_NEAR(chainable->GetLR(), closed_form->GetLR(), 1e-6f); + } +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== Linear Tests ===" << std::endl; + TestFirstStepFromZero(); + TestMidpointLR(); + TestWarmupEnd(); + TestBeyondWarmup(); + TestCustomStartFactor(); + TestPyTorchAlignment(); + TestChainableAndClosedFormConsistency(); + + std::cout << "========================" << std::endl; + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} \ No newline at end of file diff --git a/test/lr_scheduler/test_lr_scheduler.cc b/test/lr_scheduler/test_lr_scheduler.cc new file mode 100644 index 00000000..58f6bdd6 --- /dev/null +++ b/test/lr_scheduler/test_lr_scheduler.cc @@ -0,0 +1,178 @@ +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/lr_scheduler.h" +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +using namespace infini_train; + +namespace { + +constexpr float kBaseLR = 0.1f; +constexpr float kEps = 1e-7f; + +class IdentityScheduler : public LRScheduler { +public: + IdentityScheduler(std::shared_ptr optimizer, int64_t last_step = -1) + : LRScheduler(std::move(optimizer), last_step) {} + ~IdentityScheduler() override = default; + +protected: + float GetClosedFormLR() const override { return base_lr_; } +}; + +class LinearDecayScheduler : public LRScheduler { +public: + LinearDecayScheduler(std::shared_ptr optimizer, int64_t total_steps, int64_t last_step = -1) + : LRScheduler(std::move(optimizer), last_step), total_steps_(total_steps) {} + +protected: + float GetClosedFormLR() const override { + if (last_step_ >= total_steps_) { + return 0.0f; + } + return base_lr_ * (1.0f - static_cast(last_step_) / static_cast(total_steps_)); + } + +private: + int64_t total_steps_; +}; + +std::shared_ptr MakeDummyOptimizer(float lr) { + std::vector> empty_params; + return std::make_shared(empty_params, lr); +} + +bool FloatEq(float a, float b) { return std::fabs(a - b) < kEps; } + +int g_fail_count = 0; + +void Check(bool cond, const char *expr, int line) { + if (!cond) { + std::cerr << "FAIL [line " << line << "]: " << expr << std::endl; + ++g_fail_count; + } +} + +#define ASSERT_TRUE(cond) Check((cond), #cond, __LINE__) +#define ASSERT_FLOAT_EQ(a, b) Check(FloatEq((a), (b)), #a " == " #b, __LINE__) + +// T1: Init +void TestInitialState() { + std::cout << "[T1] TestInitialState" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt); + + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); + ASSERT_TRUE(sched->LastStep() == 0); + ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR); +} + +// T2: SingleStep +void TestSingleStep() { + std::cout << "[T2] TestSingleStep" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt); + + sched->Step(); + + ASSERT_TRUE(sched->LastStep() == 1); + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); + ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR); +} + +// T3: ComputeLR +void TestLinearDecay() { + std::cout << "[T3] TestLinearDecay" << std::endl; + constexpr int64_t kTotalSteps = 10; + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, kTotalSteps); + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); + ASSERT_FLOAT_EQ(opt->GetLearningRate(), kBaseLR); + + sched->Step(); // last_step = 1 -> 0.09 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.09f); + ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.09f); + + for (int i = 0; i < 4; ++i) { sched->Step(); } // last_step = 5 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); + ASSERT_FLOAT_EQ(opt->GetLearningRate(), 0.05f); +} + +// T4: State → LoadState 往返一致性。 +void TestStateRoundTrip() { + std::cout << "[T4] TestStateRoundTrip" << std::endl; + constexpr int64_t kTotalSteps = 20; + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, kTotalSteps); + + for (int i = 0; i < 7; ++i) { sched->Step(); } + + StateDict saved = sched->State(); + + ASSERT_TRUE(saved.count("last_step") == 1); + ASSERT_TRUE(saved.count("current_lr") == 1); + ASSERT_TRUE(saved.count("base_lr") == 1); + + auto opt2 = MakeDummyOptimizer(kBaseLR); + auto sched2 = LRScheduler::Create(opt2, kTotalSteps); + sched2->LoadState(saved); + + ASSERT_TRUE(sched2->LastStep() == 7); + ASSERT_FLOAT_EQ(sched2->GetLR(), sched->GetLR()); + ASSERT_FLOAT_EQ(opt2->GetLearningRate(), sched->GetLR()); +} + +// T5: resume Step +void TestResumeAndContinue() { + std::cout << "[T5] TestResumeAndContinue" << std::endl; + constexpr int64_t kTotalSteps = 20; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = LRScheduler::Create(opt_ref, kTotalSteps); + for (int i = 0; i < 10; ++i) { sched_ref->Step(); } + float lr_at_10 = sched_ref->GetLR(); + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = LRScheduler::Create(opt_a, kTotalSteps); + for (int i = 0; i < 5; ++i) { sched_a->Step(); } + StateDict checkpoint = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = LRScheduler::Create(opt_b, kTotalSteps); + sched_b->LoadState(checkpoint); + for (int i = 0; i < 5; ++i) { sched_b->Step(); } + + ASSERT_FLOAT_EQ(sched_b->GetLR(), lr_at_10); + ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); +} + +} // namespace + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + + std::cout << "========================================" << std::endl; + std::cout << " LRScheduler Base Class Tests" << std::endl; + std::cout << "========================================" << std::endl; + + TestInitialState(); + TestSingleStep(); + TestLinearDecay(); + TestStateRoundTrip(); + TestResumeAndContinue(); + + std::cout << "========================================" << std::endl; + if (g_fail_count == 0) { + std::cout << " All Tests PASSED" << std::endl; + } else { + std::cout << " " << g_fail_count << " test(s) FAILED" << std::endl; + } + std::cout << "========================================" << std::endl; + + return g_fail_count > 0 ? 1 : 0; +} \ No newline at end of file diff --git a/test/lr_scheduler/test_sequential_lr.cc b/test/lr_scheduler/test_sequential_lr.cc new file mode 100644 index 00000000..df8d6cbd --- /dev/null +++ b/test/lr_scheduler/test_sequential_lr.cc @@ -0,0 +1,222 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; +namespace { +constexpr float kBaseLR = 0.1f; +} // namespace + +void TestLinearThenConstant() { + std::cout << "[TC1] TestLinearThenConstant" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + + LRSchedulerConfig linear_config = { + .type = "linear", + .linear_start_factor = 1e-8f, + .linear_end_factor = 1.0f, + .linear_total_iters = 3, + }; + auto linear = CreateLRScheduler(opt, linear_config); + LRSchedulerConfig constant_config = { + .type = "constant", + .constant_factor = 1.0f, + .constant_total_iters = 100, + }; + auto constant = CreateLRScheduler(opt, constant_config); + auto sched = CreateLRScheduler(opt, { + .type = "sequential", + .sequential_configs = {linear_config, constant_config}, + .sequential_milestones = {3}, + }); + + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.0f, kEps); + + sched->Step(); // global=1, warmup step=1, lr=0.1*(1/3) + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f / 3.0f, 1e-5f); + + sched->Step(); // global=2, warmup step=2, lr=0.1*(2/3) + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.2f / 3.0f, 1e-5f); + + sched->Step(); // global=3, constant step=0, lr=0.1*1.0=0.1 + ASSERT_FLOAT_NEAR(sched->GetLR(), kBaseLR, kEps); + + sched->Step(); // global=4, constant step=1, lr=0.1 + ASSERT_FLOAT_NEAR(sched->GetLR(), kBaseLR, kEps); +} + +void TestLinearThenStepLR() { + std::cout << "[TC2] TestLinearThenStepLR" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + + LRSchedulerConfig linear_config = { + .type = "linear", + .linear_start_factor = 1e-8f, + .linear_end_factor = 1.0f, + .linear_total_iters = 3, + }; + auto linear = CreateLRScheduler(opt, linear_config); + LRSchedulerConfig step_config = { + .type = "step", + .step_size = 3, + .step_gamma = 0.5f, + }; + auto step_lr = CreateLRScheduler(opt, step_config); + + auto sched = CreateLRScheduler(opt, { + .type = "sequential", + .sequential_configs = {linear_config, step_config}, + .sequential_milestones = {3}, + }); + + sched->Step(); // global=1 + sched->Step(); // global=2 + + sched->Step(); // global=3, StepLR step=0, lr=0.1 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + sched->Step(); // global=4, StepLR step=1 + sched->Step(); // global=5, StepLR step=2 + sched->Step(); // global=6, StepLR step=3, 3//3=1, lr=0.1*0.5=0.05 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.05f, kEps); +} + +void TestLinearThenStepThenConstant() { + std::cout << "[TC3] TestLinearThenStepThenConstant" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + + LRSchedulerConfig linear_config = { + .type = "linear", + .linear_start_factor = 1e-8f, + .linear_end_factor = 1.0f, + .linear_total_iters = 3, + }; + auto linear = CreateLRScheduler(opt, linear_config); + LRSchedulerConfig step_config = { + .type = "step", + .step_size = 3, + .step_gamma = 0.5f, + }; + auto step_lr = CreateLRScheduler(opt, step_config); + LRSchedulerConfig constant_config = { + .type = "constant", + .constant_factor = 0.5f, + .constant_total_iters = 2, + }; + auto constant = CreateLRScheduler(opt, constant_config); + + auto sched = CreateLRScheduler(opt, { + .type = "sequential", + .sequential_configs = {linear_config, step_config, constant_config}, + .sequential_milestones = {3, 6}, + }); + const std::vector expected = {0.033333f, 0.066667f, 0.1f, 0.1f, 0.1f, 0.05f, 0.05f, 0.1f, 0.1f, 0.1f}; + for (size_t i = 0; i < expected.size(); ++i) { + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), expected[i], 1e-5f); + } +} + +void TestStateRoundTrip() { + std::cout << "[TC4] TestStateRoundTrip" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig linear_config = { + .type = "linear", + .linear_start_factor = 1e-8f, + .linear_end_factor = 1.0f, + .linear_total_iters = 3, + }; + auto linear = CreateLRScheduler(opt, linear_config); + LRSchedulerConfig step_config = { + .type = "step", + .step_size = 3, + .step_gamma = 0.5f, + }; + auto step_lr = CreateLRScheduler(opt, step_config); + auto sched = CreateLRScheduler(opt, { + .type = "sequential", + .sequential_configs = {linear_config, step_config}, + .sequential_milestones = {3}, + }); + for (int i = 0; i < 5; ++i) { sched->Step(); } + StateDict saved = sched->State(); + + auto opt2 = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig linear_config2 = { + .type = "linear", + .linear_start_factor = 1e-8f, + .linear_end_factor = 1.0f, + .linear_total_iters = 3, + }; + auto linear2 = CreateLRScheduler(opt2, linear_config2); + LRSchedulerConfig step_config2 = { + .type = "step", + .step_size = 3, + .step_gamma = 0.5f, + }; + auto step_lr2 = CreateLRScheduler(opt2, step_config2); + auto sched2 = CreateLRScheduler(opt2, { + .type = "sequential", + .sequential_configs = {linear_config2, step_config2}, + .sequential_milestones = {3}, + }); + sched2->LoadState(saved); + + ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); + ASSERT_FLOAT_NEAR(sched2->GetLR(), sched->GetLR(), kEps); +} + +void TestResumeConsistency() { + std::cout << "[TC5] TestResumeConsistency" << std::endl; + constexpr int kN = 10, kK = 4; + + auto make_sched = [](std::shared_ptr opt) { + return CreateLRScheduler(opt, { + .type = "sequential", + .sequential_configs = {{ + .type = "linear", + .linear_start_factor = 1e-8f, + .linear_end_factor = 1.0f, + .linear_total_iters = 3, + }, + { + .type = "step", + .step_size = 3, + .step_gamma = 0.5f, + }}, + .sequential_milestones = {3}, + }); + }; + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = make_sched(opt_ref); + for (int i = 0; i < kN; ++i) { sched_ref->Step(); } + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = make_sched(opt_a); + for (int i = 0; i < kK; ++i) { sched_a->Step(); } + StateDict ckpt = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = make_sched(opt_b); + sched_b->LoadState(ckpt); + for (int i = 0; i < kN - kK; ++i) { sched_b->Step(); } + + ASSERT_FLOAT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); + ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== SequentialLR Tests ===" << std::endl; + TestLinearThenConstant(); + TestLinearThenStepLR(); + TestLinearThenStepThenConstant(); + TestStateRoundTrip(); + TestResumeConsistency(); + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} \ No newline at end of file diff --git a/test/lr_scheduler/test_step_lr.cc b/test/lr_scheduler/test_step_lr.cc new file mode 100644 index 00000000..698bcc49 --- /dev/null +++ b/test/lr_scheduler/test_step_lr.cc @@ -0,0 +1,119 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { +constexpr float kBaseLR = 0.1f; +} + +void TestWithinFirstPeriod() { + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }; + auto sched = CreateLRScheduler(opt, config); + for (int i = 0; i < 2; ++i) { + sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); // last_step 1,2 → 指数 0 + } +} + +void TestFirstDecay() { + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }; + auto sched = CreateLRScheduler(opt, config); + for (int i = 0; i < 3; ++i) { sched->Step(); } + // last_step=3, 3//3=1 → 0.1^1 = 0.1 → lr=0.01 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.01f); +} + +void TestMultipleDecays() { + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }; + auto sched = CreateLRScheduler(opt, config); + for (int i = 0; i < 6; ++i) { sched->Step(); } + // last_step=6, 6//3=2 → 0.1^2 = 0.01 → lr=0.001 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.001f, 1e-7f); +} + +void TestPyTorchAlignment() { + const std::vector expected = {0.1f, 0.1f, 0.01f, 0.01f, 0.01f, 0.001f, 0.001f}; + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }; + auto sched = CreateLRScheduler(opt, config); + for (size_t i = 0; i < expected.size(); ++i) { + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), expected[i], 1e-7f); + } +} + +void TestGammaOne() { + auto opt = MakeDummyOptimizer(kBaseLR); + LRSchedulerConfig config = { + .type = "step", + .step_size = 3, + .step_gamma = 1.0f, + }; + auto sched = CreateLRScheduler(opt, config); + for (int i = 0; i < 20; ++i) { + sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); + } +} + +void TestChainableAndClosedFormConsistency() { + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto chainable = CreateLRScheduler(opt_a, { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto closed_form = CreateLRScheduler(opt_b, { + .type = "step", + .step_size = 3, + .step_gamma = 0.1f, + }); + + for (int epoch = 1; epoch <= 12; ++epoch) { + chainable->Step(); + closed_form->Step(epoch); + ASSERT_FLOAT_NEAR(chainable->GetLR(), closed_form->GetLR(), 1e-7f); + } +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== Step Tests ===" << std::endl; + TestWithinFirstPeriod(); + TestFirstDecay(); + TestMultipleDecays(); + TestPyTorchAlignment(); + TestGammaOne(); + TestChainableAndClosedFormConsistency(); + + std::cout << "========================" << std::endl; + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} \ No newline at end of file