From 766e6797d9f79b108fbe73b45409f3de4b72dacb Mon Sep 17 00:00:00 2001 From: kilinchange Date: Mon, 11 May 2026 11:32:03 +0000 Subject: [PATCH 1/3] feat: implement MoE infrastructure --- infini_train/include/autograd/moe.h | 26 +++++ .../nn/modules/transformer/moe/experts.h | 25 ++++ .../nn/modules/transformer/moe/moe_layer.h | 25 ++++ .../nn/modules/transformer/moe/moe_utils.h | 9 ++ .../nn/modules/transformer/moe/router.h | 25 ++++ .../modules/transformer/transformer_config.h | 33 ++++++ infini_train/src/autograd/moe.cc | 31 +++++ infini_train/src/kernels/cpu/top1_mask.cc | 67 +++++++++++ infini_train/src/kernels/cuda/top1_mask.cu | 107 ++++++++++++++++++ .../src/nn/modules/transformer/moe/experts.cc | 50 ++++++++ .../nn/modules/transformer/moe/moe_layer.cc | 32 ++++++ .../nn/modules/transformer/moe/moe_utils.cc | 12 ++ .../src/nn/modules/transformer/moe/router.cc | 50 ++++++++ .../src/nn/modules/transformer/transformer.cc | 7 +- .../test_transformer_architecture.cc | 44 +++++++ 15 files changed, 542 insertions(+), 1 deletion(-) create mode 100644 infini_train/include/autograd/moe.h create mode 100644 infini_train/include/nn/modules/transformer/moe/experts.h create mode 100644 infini_train/include/nn/modules/transformer/moe/moe_layer.h create mode 100644 infini_train/include/nn/modules/transformer/moe/moe_utils.h create mode 100644 infini_train/include/nn/modules/transformer/moe/router.h create mode 100644 infini_train/src/autograd/moe.cc create mode 100644 infini_train/src/kernels/cpu/top1_mask.cc create mode 100644 infini_train/src/kernels/cuda/top1_mask.cu create mode 100644 infini_train/src/nn/modules/transformer/moe/experts.cc create mode 100644 infini_train/src/nn/modules/transformer/moe/moe_layer.cc create mode 100644 infini_train/src/nn/modules/transformer/moe/moe_utils.cc create mode 100644 infini_train/src/nn/modules/transformer/moe/router.cc diff --git a/infini_train/include/autograd/moe.h b/infini_train/include/autograd/moe.h new file mode 100644 index 00000000..5317de8e --- /dev/null +++ b/infini_train/include/autograd/moe.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class Top1Mask : public Function { +public: + static constexpr char kType[] = "Top1MaskFunction"; + + Top1Mask() : Function(kType) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/nn/modules/transformer/moe/experts.h b/infini_train/include/nn/modules/transformer/moe/experts.h new file mode 100644 index 00000000..a3dda7f0 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/experts.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class SequentialMLP : public CloneableModule { +public: + static constexpr char kType[] = "SequentialMLP"; + static constexpr char kExpertNamePrefix[] = "expert_"; + + explicit SequentialMLP(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; + int64_t num_local_experts_ = 0; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/moe_layer.h b/infini_train/include/nn/modules/transformer/moe/moe_layer.h new file mode 100644 index 00000000..e5fdb3ab --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/moe_layer.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class MoELayer : public CloneableModule { +public: + static constexpr char kType[] = "MoELayer"; + static constexpr char kRouterLayerName[] = "router"; + static constexpr char kExpertsLayerName[] = "experts"; + + explicit MoELayer(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/moe_utils.h b/infini_train/include/nn/modules/transformer/moe/moe_utils.h new file mode 100644 index 00000000..e0dd3744 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/moe_utils.h @@ -0,0 +1,9 @@ +#pragma once + +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +const MoEConfig &RequireMoEConfig(const TransformerConfig &config); + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/router.h b/infini_train/include/nn/modules/transformer/moe/router.h new file mode 100644 index 00000000..1279c217 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/router.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class TopKRouter : public CloneableModule { +public: + static constexpr char kType[] = "TopKRouter"; + static constexpr char kParamWeightName[] = "weight"; + static constexpr char kParamBiasName[] = "bias"; + + explicit TopKRouter(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/transformer_config.h b/infini_train/include/nn/modules/transformer/transformer_config.h index 62379666..b55ce4fc 100644 --- a/infini_train/include/nn/modules/transformer/transformer_config.h +++ b/infini_train/include/nn/modules/transformer/transformer_config.h @@ -20,11 +20,42 @@ enum class MLPType { kSwiGLU // SwiGLU activation }; +enum class FFNType { + kDense, // Standard dense MLP + kMoE // Mixture-of-Experts MLP +}; + enum class NormType { kLayerNorm, // LayerNorm kRMSNorm // RMSNorm }; +enum class MoERouterType { + kTopK // Top-k router. The initial implementation supports top-1. +}; + +enum class MoEDispatcherType { + kLocal, // No cross-rank token exchange + kAllGather // Reserved for expert parallel MoE +}; + +enum class MoEExpertImpl { + kSequential // Run local experts sequentially +}; + +struct MoEConfig { + int64_t num_experts = 0; + int64_t expert_parallel_size = 1; + int64_t router_topk = 1; + float aux_loss_coeff = 0.0f; + std::optional expert_capacity_factor = std::nullopt; + bool pad_expert_input_to_capacity = false; + int64_t moe_ffn_hidden_size = 0; + MoERouterType router_type = MoERouterType::kTopK; + MoEDispatcherType dispatcher_type = MoEDispatcherType::kLocal; + MoEExpertImpl expert_impl = MoEExpertImpl::kSequential; +}; + struct TransformerConfig { int64_t block_size = 1024; // Max seq_len int64_t vocab_size = 50304; // Vocab size @@ -36,6 +67,7 @@ struct TransformerConfig { AttentionType attention_type = AttentionType::kStandard; // Attention mechanism type MLPType activation_type = MLPType::kGELU; // MLP activation type + FFNType ffn_type = FFNType::kDense; // Feed-forward module type NormType norm_type = NormType::kLayerNorm; // Normalization type bool add_bias_linear = true; // Whether to add learnable bias to all Linear layers in the Transformer block, @@ -48,6 +80,7 @@ struct TransformerConfig { float ffn_expansion_ratio = 4.0f; // MLP output: n_embd * ffn_expansion_ratio std::optional ffn_dim_multiplier = 1.5f; // FFN dim multiplier int64_t multiple_of = 256; // FFN dims must be multiple of this number + std::optional moe_config = std::nullopt; // RoPE config float rope_theta = 500000.0f; // theta in RoPE diff --git a/infini_train/src/autograd/moe.cc b/infini_train/src/autograd/moe.cc new file mode 100644 index 00000000..05134e82 --- /dev/null +++ b/infini_train/src/autograd/moe.cc @@ -0,0 +1,31 @@ +#include "infini_train/include/autograd/moe.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> Top1Mask::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + const auto &input = input_tensors[0]; + auto device = input->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "Top1MaskForward"}, input)}; +} + +void Top1Mask::SetupContext(const std::vector> &, + const std::vector> &output_tensors) { + saved_tensors_ = {output_tensors[0]}; +} + +std::vector> Top1Mask::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + const auto &mask_values = saved_tensors_[0]; + auto device = grad_output->GetDevice().type(); + return { + Dispatcher::Instance().Call>({device, "Top1MaskBackward"}, grad_output, mask_values)}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/top1_mask.cc b/infini_train/src/kernels/cpu/top1_mask.cc new file mode 100644 index 00000000..d6ae91d6 --- /dev/null +++ b/infini_train/src/kernels/cpu/top1_mask.cc @@ -0,0 +1,67 @@ +#include + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::kernels::cpu { + +std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { + CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU Top1MaskForward currently supports float32 only"; + CHECK_GE(input->Dims().size(), 1); + + const auto &dims = input->Dims(); + const int64_t num_experts = dims.back(); + CHECK_GT(num_experts, 0); + const int64_t rows = input->NumElements() / num_experts; + + auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); + output->Fill(0.0f); + + const float *in = static_cast(input->DataPtr()); + float *out = static_cast(output->DataPtr()); + for (int64_t row = 0; row < rows; ++row) { + int64_t best_idx = 0; + float best_value = in[row * num_experts]; + for (int64_t expert_idx = 1; expert_idx < num_experts; ++expert_idx) { + const float value = in[row * num_experts + expert_idx]; + if (value > best_value) { + best_value = value; + best_idx = expert_idx; + } + } + out[row * num_experts + best_idx] = best_value; + } + + return output; +} + +std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_output, + const std::shared_ptr &mask_values) { + CHECK(grad_output->Dtype() == DataType::kFLOAT32) << "CPU Top1MaskBackward currently supports float32 only"; + CHECK(mask_values->Dtype() == DataType::kFLOAT32); + CHECK(grad_output->Dims() == mask_values->Dims()); + + auto grad_input = std::make_shared(grad_output->Dims(), grad_output->Dtype(), grad_output->GetDevice()); + grad_input->Fill(0.0f); + + const float *grad = static_cast(grad_output->DataPtr()); + const float *mask = static_cast(mask_values->DataPtr()); + float *out = static_cast(grad_input->DataPtr()); + for (int64_t i = 0; i < static_cast(grad_output->NumElements()); ++i) { + out[i] = mask[i] != 0.0f ? grad[i] : 0.0f; + } + + return grad_input; +} + +} // namespace infini_train::kernels::cpu + +#define REGISTER_CPU_TOP1_MASK_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + +REGISTER_CPU_TOP1_MASK_KERNEL(Top1MaskForward) +REGISTER_CPU_TOP1_MASK_KERNEL(Top1MaskBackward) + +#undef REGISTER_CPU_TOP1_MASK_KERNEL diff --git a/infini_train/src/kernels/cuda/top1_mask.cu b/infini_train/src/kernels/cuda/top1_mask.cu new file mode 100644 index 00000000..8fd00c91 --- /dev/null +++ b/infini_train/src/kernels/cuda/top1_mask.cu @@ -0,0 +1,107 @@ +#include "glog/logging.h" + +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +#include "infini_train/src/core/runtime/cuda/cuda_dispatch.h" +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" + +namespace infini_train::kernels::cuda { + +template +__global__ void Top1MaskForwardKernel(const T *__restrict__ input, T *__restrict__ output, int64_t rows, + int64_t num_experts) { + int64_t row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= rows) { + return; + } + + const int64_t offset = row * num_experts; + int64_t best_idx = 0; + float best_value = static_cast(input[offset]); + for (int64_t expert_idx = 1; expert_idx < num_experts; ++expert_idx) { + const float value = static_cast(input[offset + expert_idx]); + if (value > best_value) { + best_value = value; + best_idx = expert_idx; + } + } + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + output[offset + expert_idx] = expert_idx == best_idx ? input[offset + expert_idx] : T(0.0f); + } +} + +std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { + CHECK_GE(input->Dims().size(), 1); + const auto &dims = input->Dims(); + const int64_t num_experts = dims.back(); + CHECK_GT(num_experts, 0); + const int64_t rows = input->NumElements() / num_experts; + + auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); + + auto device = input->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int threads = 256; + const int blocks = static_cast((rows + threads - 1) / threads); + + core::cuda::DispatchCudaFunc( + input->Dtype(), + [=]() { + Top1MaskForwardKernel<<>>( + static_cast(input->DataPtr()), static_cast(output->DataPtr()), rows, num_experts); + }, + "CUDA Top1MaskForward"); + + return output; +} + +template +__global__ void Top1MaskBackwardKernel(const T *__restrict__ grad_output, const T *__restrict__ mask_values, + T *__restrict__ grad_input, int64_t total_elements) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_elements) { + return; + } + grad_input[idx] = static_cast(mask_values[idx]) != 0.0f ? grad_output[idx] : T(0.0f); +} + +std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_output, + const std::shared_ptr &mask_values) { + CHECK(grad_output->Dims() == mask_values->Dims()); + CHECK(grad_output->Dtype() == mask_values->Dtype()); + auto grad_input = std::make_shared(grad_output->Dims(), grad_output->Dtype(), grad_output->GetDevice()); + + auto device = grad_output->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int64_t total_elements = grad_output->NumElements(); + const int threads = 256; + const int blocks = static_cast((total_elements + threads - 1) / threads); + + core::cuda::DispatchCudaFunc( + grad_output->Dtype(), + [=]() { + Top1MaskBackwardKernel<<>>( + static_cast(grad_output->DataPtr()), static_cast(mask_values->DataPtr()), + static_cast(grad_input->DataPtr()), total_elements); + }, + "CUDA Top1MaskBackward"); + + return grad_input; +} + +} // namespace infini_train::kernels::cuda + +#define REGISTER_CUDA_TOP1_MASK_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + +REGISTER_CUDA_TOP1_MASK_KERNEL(Top1MaskForward) +REGISTER_CUDA_TOP1_MASK_KERNEL(Top1MaskBackward) + +#undef REGISTER_CUDA_TOP1_MASK_KERNEL diff --git a/infini_train/src/nn/modules/transformer/moe/experts.cc b/infini_train/src/nn/modules/transformer/moe/experts.cc new file mode 100644 index 00000000..8f3b1be8 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/experts.cc @@ -0,0 +1,50 @@ +#include "infini_train/include/nn/modules/transformer/moe/experts.h" + +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +SequentialMLP::SequentialMLP(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK(moe_config.expert_impl == MoEExpertImpl::kSequential); + CHECK_EQ(moe_config.expert_parallel_size, 1) + << "Current InfiniTrain MoE implementation supports expert_parallel_size=1 only"; + CHECK(moe_config.dispatcher_type == MoEDispatcherType::kLocal) + << "Current InfiniTrain MoE implementation supports local dispatch only"; + + num_local_experts_ = moe_config.num_experts; + CHECK_GT(num_local_experts_, 0); + + for (int64_t expert_idx = 0; expert_idx < num_local_experts_; ++expert_idx) { + modules_[std::string(kExpertNamePrefix) + std::to_string(expert_idx)] = std::make_shared(config_); + } +} + +std::vector> SequentialMLP::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 2); + auto hidden_states = input_tensors[0]; + auto routing_probs = input_tensors[1]; + CHECK_EQ(routing_probs->Dims().back(), num_local_experts_); + + std::shared_ptr output = nullptr; + const int64_t expert_dim = static_cast(routing_probs->Dims().size()) - 1; + for (int64_t expert_idx = 0; expert_idx < num_local_experts_; ++expert_idx) { + auto expert_name = std::string(kExpertNamePrefix) + std::to_string(expert_idx); + auto expert_output = (*modules_.at(expert_name))({hidden_states})[0]; + auto expert_prob = routing_probs->Slice(expert_dim, expert_idx, expert_idx + 1); + auto weighted_output = expert_output * expert_prob; + output = output == nullptr ? weighted_output : output + weighted_output; + } + + return {output}; +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/moe_layer.cc b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc new file mode 100644 index 00000000..8efd51c0 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc @@ -0,0 +1,32 @@ +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" + +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/transformer/moe/experts.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/nn/modules/transformer/moe/router.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +MoELayer::MoELayer(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK(config_.ffn_type == FFNType::kMoE); + CHECK(moe_config.dispatcher_type == MoEDispatcherType::kLocal) + << "Current InfiniTrain MoE implementation supports local dispatch only"; + + modules_[kRouterLayerName] = std::make_shared(config_); + modules_[kExpertsLayerName] = std::make_shared(config_); +} + +std::vector> MoELayer::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + auto hidden_states = input_tensors[0]; + auto routing_probs = (*modules_.at(kRouterLayerName))({hidden_states})[0]; + return (*modules_.at(kExpertsLayerName))({hidden_states, routing_probs}); +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/moe_utils.cc b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc new file mode 100644 index 00000000..80ef01c1 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc @@ -0,0 +1,12 @@ +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" + +#include "glog/logging.h" + +namespace infini_train::nn::moe { + +const MoEConfig &RequireMoEConfig(const TransformerConfig &config) { + CHECK(config.moe_config.has_value()) << "MoE layer requires TransformerConfig::moe_config"; + return config.moe_config.value(); +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/router.cc b/infini_train/src/nn/modules/transformer/moe/router.cc new file mode 100644 index 00000000..59dec209 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/router.cc @@ -0,0 +1,50 @@ +#include "infini_train/include/nn/modules/transformer/moe/router.h" + +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/autograd/linear.h" +#include "infini_train/include/autograd/moe.h" +#include "infini_train/include/nn/functional.h" +#include "infini_train/include/nn/init.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +TopKRouter::TopKRouter(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK(moe_config.router_type == MoERouterType::kTopK); + CHECK_EQ(moe_config.router_topk, 1) << "Current InfiniTrain MoE implementation supports top-1 routing only"; + CHECK_GT(moe_config.num_experts, 0); + + parameters_[kParamWeightName] + = std::make_shared(std::vector{moe_config.num_experts, config_.n_embd}, DataType::kFLOAT32, + device_) + ->RequiresGrad(); + init::KaimingUniform(parameters_[kParamWeightName]); + + if (config_.add_bias_linear) { + parameters_[kParamBiasName] + = std::make_shared(std::vector{moe_config.num_experts}, DataType::kFLOAT32, device_) + ->RequiresGrad(); + parameters_[kParamBiasName]->Fill(0.0f); + } +} + +std::vector> TopKRouter::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + std::vector> linear_inputs{input_tensors[0], parameters_.at(kParamWeightName)}; + if (parameters_.contains(kParamBiasName)) { + linear_inputs.push_back(parameters_.at(kParamBiasName)); + } + + auto logits = std::make_shared()->Apply(linear_inputs)[0]; + auto scores = function::Softmax(logits, -1); + auto routing_probs = std::make_shared()->Apply({scores})[0]; + return {routing_probs}; +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/transformer.cc b/infini_train/src/nn/modules/transformer/transformer.cc index c7e0f28c..bdcde449 100644 --- a/infini_train/src/nn/modules/transformer/transformer.cc +++ b/infini_train/src/nn/modules/transformer/transformer.cc @@ -15,6 +15,7 @@ #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" #include "infini_train/include/nn/modules/transformer/utils.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" @@ -86,7 +87,11 @@ TransformerLayer::TransformerLayer(const nn::TransformerConfig &config) : Clonea } modules_[kAttnLayerName] = std::make_shared(config); - modules_[kMlpLayerName] = std::make_shared(config); + if (config.ffn_type == FFNType::kMoE) { + modules_[kMlpLayerName] = std::make_shared(config); + } else { + modules_[kMlpLayerName] = std::make_shared(config); + } } std::vector> TransformerLayer::Forward(const std::vector> &x) { diff --git a/test/transformer/test_transformer_architecture.cc b/test/transformer/test_transformer_architecture.cc index 1c6ed28f..da3dd70e 100644 --- a/test/transformer/test_transformer_architecture.cc +++ b/test/transformer/test_transformer_architecture.cc @@ -11,6 +11,7 @@ #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" #include "infini_train/include/nn/modules/transformer/transformer.h" #include "infini_train/include/nn/modules/transformer/transformer_config.h" #include "infini_train/include/nn/modules/transformer/utils.h" @@ -525,6 +526,48 @@ void TestStateDict() { } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } } +// ============================================================================ +// Test 11: MoE Layer MVP +// ============================================================================ +void TestMoELayer() { + std::cout << "\n=== Test 11: MoE Layer MVP ===" << std::endl; + + nn::TransformerConfig config; + config.n_embd = 32; + config.n_head = 2; + config.n_kv_head = 2; + config.activation_type = nn::MLPType::kGELU; + config.add_bias_linear = true; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 2; + config.moe_config->router_topk = 1; + + try { + auto moe = std::make_shared(config); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); + input->Uniform(); + + auto output = (*moe)({input}); + if (output.size() != 1) { + std::cout << "FAIL: MoELayer forward should return 1 tensor" << std::endl; + return; + } + if (output[0]->Dims() != input->Dims()) { + std::cout << "FAIL: MoELayer output shape mismatch" << std::endl; + return; + } + + auto params = moe->Parameters(); + if (params.empty()) { + std::cout << "FAIL: MoELayer should own router and expert parameters" << std::endl; + return; + } + + std::cout << "SUCCESS: MoE layer MVP forward works correctly!" << std::endl; + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + // ============================================================================ // Main // ============================================================================ @@ -547,6 +590,7 @@ int main(int argc, char *argv[]) { TestLlama3Model(); TestRopeUtils(); TestStateDict(); + TestMoELayer(); std::cout << "\n========================================" << std::endl; std::cout << " All Tests Completed" << std::endl; From 8fc75c3cc4ead9be6c9cf467eb4a6eef5e71c625 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Wed, 13 May 2026 03:01:31 +0000 Subject: [PATCH 2/3] feat: support topk_router --- .../include/autograd/{moe.h => topk_mask.h} | 9 ++- .../modules/transformer/transformer_config.h | 2 +- .../src/autograd/{moe.cc => topk_mask.cc} | 13 ++-- .../cpu/{top1_mask.cc => topk_mask.cc} | 53 ++++++++++++----- .../cuda/{top1_mask.cu => topk_mask.cu} | 55 ++++++++++------- .../src/nn/modules/transformer/moe/router.cc | 8 ++- .../test_transformer_architecture.cc | 59 ++++++++++++------- 7 files changed, 126 insertions(+), 73 deletions(-) rename infini_train/include/autograd/{moe.h => topk_mask.h} (76%) rename infini_train/src/autograd/{moe.cc => topk_mask.cc} (70%) rename infini_train/src/kernels/cpu/{top1_mask.cc => topk_mask.cc} (50%) rename infini_train/src/kernels/cuda/{top1_mask.cu => topk_mask.cu} (66%) diff --git a/infini_train/include/autograd/moe.h b/infini_train/include/autograd/topk_mask.h similarity index 76% rename from infini_train/include/autograd/moe.h rename to infini_train/include/autograd/topk_mask.h index 5317de8e..355ef400 100644 --- a/infini_train/include/autograd/moe.h +++ b/infini_train/include/autograd/topk_mask.h @@ -11,16 +11,19 @@ class Tensor; namespace infini_train::autograd { -class Top1Mask : public Function { +class TopKMask : public Function { public: - static constexpr char kType[] = "Top1MaskFunction"; + static constexpr char kType[] = "TopKMaskFunction"; - Top1Mask() : Function(kType) {} + explicit TopKMask(int64_t topk) : Function(kType), topk_(topk) {} std::vector> Forward(const std::vector> &input_tensors) override; void SetupContext(const std::vector> &input_tensors, const std::vector> &output_tensors) override; std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + int64_t topk_ = 1; }; } // namespace infini_train::autograd diff --git a/infini_train/include/nn/modules/transformer/transformer_config.h b/infini_train/include/nn/modules/transformer/transformer_config.h index b55ce4fc..3a96625d 100644 --- a/infini_train/include/nn/modules/transformer/transformer_config.h +++ b/infini_train/include/nn/modules/transformer/transformer_config.h @@ -31,7 +31,7 @@ enum class NormType { }; enum class MoERouterType { - kTopK // Top-k router. The initial implementation supports top-1. + kTopK // Top-k router. }; enum class MoEDispatcherType { diff --git a/infini_train/src/autograd/moe.cc b/infini_train/src/autograd/topk_mask.cc similarity index 70% rename from infini_train/src/autograd/moe.cc rename to infini_train/src/autograd/topk_mask.cc index 05134e82..16dc6629 100644 --- a/infini_train/src/autograd/moe.cc +++ b/infini_train/src/autograd/topk_mask.cc @@ -1,4 +1,4 @@ -#include "infini_train/include/autograd/moe.h" +#include "infini_train/include/autograd/topk_mask.h" #include "glog/logging.h" @@ -7,25 +7,26 @@ namespace infini_train::autograd { -std::vector> Top1Mask::Forward(const std::vector> &input_tensors) { +std::vector> TopKMask::Forward(const std::vector> &input_tensors) { CHECK_EQ(input_tensors.size(), 1); + CHECK_GT(topk_, 0); const auto &input = input_tensors[0]; auto device = input->GetDevice().type(); - return {Dispatcher::Instance().Call>({device, "Top1MaskForward"}, input)}; + return {Dispatcher::Instance().Call>({device, "TopKMaskForward"}, input, topk_)}; } -void Top1Mask::SetupContext(const std::vector> &, +void TopKMask::SetupContext(const std::vector> &, const std::vector> &output_tensors) { saved_tensors_ = {output_tensors[0]}; } -std::vector> Top1Mask::Backward(const std::vector> &grad_outputs) { +std::vector> TopKMask::Backward(const std::vector> &grad_outputs) { CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; const auto &mask_values = saved_tensors_[0]; auto device = grad_output->GetDevice().type(); return { - Dispatcher::Instance().Call>({device, "Top1MaskBackward"}, grad_output, mask_values)}; + Dispatcher::Instance().Call>({device, "TopKMaskBackward"}, grad_output, mask_values)}; } } // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/top1_mask.cc b/infini_train/src/kernels/cpu/topk_mask.cc similarity index 50% rename from infini_train/src/kernels/cpu/top1_mask.cc rename to infini_train/src/kernels/cpu/topk_mask.cc index d6ae91d6..6a7191b9 100644 --- a/infini_train/src/kernels/cpu/top1_mask.cc +++ b/infini_train/src/kernels/cpu/topk_mask.cc @@ -1,4 +1,6 @@ +#include #include +#include #include "glog/logging.h" @@ -7,13 +9,15 @@ namespace infini_train::kernels::cpu { -std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { - CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU Top1MaskForward currently supports float32 only"; +std::shared_ptr TopKMaskForward(const std::shared_ptr &input, int64_t topk) { + CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU TopKMaskForward currently supports float32 only"; CHECK_GE(input->Dims().size(), 1); const auto &dims = input->Dims(); const int64_t num_experts = dims.back(); CHECK_GT(num_experts, 0); + CHECK_GT(topk, 0); + CHECK_LE(topk, num_experts); const int64_t rows = input->NumElements() / num_experts; auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); @@ -22,24 +26,41 @@ std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { const float *in = static_cast(input->DataPtr()); float *out = static_cast(output->DataPtr()); for (int64_t row = 0; row < rows; ++row) { - int64_t best_idx = 0; - float best_value = in[row * num_experts]; - for (int64_t expert_idx = 1; expert_idx < num_experts; ++expert_idx) { - const float value = in[row * num_experts + expert_idx]; - if (value > best_value) { - best_value = value; - best_idx = expert_idx; + const int64_t row_offset = row * num_experts; + std::vector selected_experts(num_experts, false); + float selected_sum = 0.0f; + for (int64_t selected = 0; selected < topk; ++selected) { + int64_t best_idx = -1; + float best_value = -std::numeric_limits::infinity(); + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + if (selected_experts[expert_idx]) { + continue; + } + const float value = in[row_offset + expert_idx]; + if (value > best_value) { + best_value = value; + best_idx = expert_idx; + } + } + CHECK_GE(best_idx, 0); + selected_experts[best_idx] = true; + out[row_offset + best_idx] = best_value; + selected_sum += best_value; + } + if (topk > 1 && selected_sum != 0.0f) { + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + out[row_offset + expert_idx] + = out[row_offset + expert_idx] == 0.0f ? 0.0f : out[row_offset + expert_idx] / selected_sum; } } - out[row * num_experts + best_idx] = best_value; } return output; } -std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_output, +std::shared_ptr TopKMaskBackward(const std::shared_ptr &grad_output, const std::shared_ptr &mask_values) { - CHECK(grad_output->Dtype() == DataType::kFLOAT32) << "CPU Top1MaskBackward currently supports float32 only"; + CHECK(grad_output->Dtype() == DataType::kFLOAT32) << "CPU TopKMaskBackward currently supports float32 only"; CHECK(mask_values->Dtype() == DataType::kFLOAT32); CHECK(grad_output->Dims() == mask_values->Dims()); @@ -58,10 +79,10 @@ std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_out } // namespace infini_train::kernels::cpu -#define REGISTER_CPU_TOP1_MASK_KERNEL(kernel_name) \ +#define REGISTER_CPU_TOPK_MASK_KERNEL(kernel_name) \ REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) -REGISTER_CPU_TOP1_MASK_KERNEL(Top1MaskForward) -REGISTER_CPU_TOP1_MASK_KERNEL(Top1MaskBackward) +REGISTER_CPU_TOPK_MASK_KERNEL(TopKMaskForward) +REGISTER_CPU_TOPK_MASK_KERNEL(TopKMaskBackward) -#undef REGISTER_CPU_TOP1_MASK_KERNEL +#undef REGISTER_CPU_TOPK_MASK_KERNEL diff --git a/infini_train/src/kernels/cuda/top1_mask.cu b/infini_train/src/kernels/cuda/topk_mask.cu similarity index 66% rename from infini_train/src/kernels/cuda/top1_mask.cu rename to infini_train/src/kernels/cuda/topk_mask.cu index 8fd00c91..e38c793e 100644 --- a/infini_train/src/kernels/cuda/top1_mask.cu +++ b/infini_train/src/kernels/cuda/topk_mask.cu @@ -11,33 +11,44 @@ namespace infini_train::kernels::cuda { template -__global__ void Top1MaskForwardKernel(const T *__restrict__ input, T *__restrict__ output, int64_t rows, - int64_t num_experts) { +__global__ void TopKMaskForwardKernel(const T *__restrict__ input, T *__restrict__ output, int64_t rows, + int64_t num_experts, int64_t topk) { int64_t row = blockIdx.x * blockDim.x + threadIdx.x; if (row >= rows) { return; } const int64_t offset = row * num_experts; - int64_t best_idx = 0; - float best_value = static_cast(input[offset]); - for (int64_t expert_idx = 1; expert_idx < num_experts; ++expert_idx) { + float selected_sum = 0.0f; + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { const float value = static_cast(input[offset + expert_idx]); - if (value > best_value) { - best_value = value; - best_idx = expert_idx; + int64_t rank = 0; + for (int64_t other_idx = 0; other_idx < num_experts; ++other_idx) { + const float other_value = static_cast(input[offset + other_idx]); + if (other_value > value || (other_value == value && other_idx < expert_idx)) { + ++rank; + } } + const bool selected = rank < topk; + output[offset + expert_idx] = selected ? input[offset + expert_idx] : T(0.0f); + selected_sum += selected ? value : 0.0f; } - for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { - output[offset + expert_idx] = expert_idx == best_idx ? input[offset + expert_idx] : T(0.0f); + if (topk > 1 && selected_sum != 0.0f) { + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + if (static_cast(output[offset + expert_idx]) != 0.0f) { + output[offset + expert_idx] = T(static_cast(output[offset + expert_idx]) / selected_sum); + } + } } } -std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { +std::shared_ptr TopKMaskForward(const std::shared_ptr &input, int64_t topk) { CHECK_GE(input->Dims().size(), 1); const auto &dims = input->Dims(); const int64_t num_experts = dims.back(); CHECK_GT(num_experts, 0); + CHECK_GT(topk, 0); + CHECK_LE(topk, num_experts); const int64_t rows = input->NumElements() / num_experts; auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); @@ -52,16 +63,16 @@ std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { core::cuda::DispatchCudaFunc( input->Dtype(), [=]() { - Top1MaskForwardKernel<<>>( - static_cast(input->DataPtr()), static_cast(output->DataPtr()), rows, num_experts); + TopKMaskForwardKernel<<>>( + static_cast(input->DataPtr()), static_cast(output->DataPtr()), rows, num_experts, topk); }, - "CUDA Top1MaskForward"); + "CUDA TopKMaskForward"); return output; } template -__global__ void Top1MaskBackwardKernel(const T *__restrict__ grad_output, const T *__restrict__ mask_values, +__global__ void TopKMaskBackwardKernel(const T *__restrict__ grad_output, const T *__restrict__ mask_values, T *__restrict__ grad_input, int64_t total_elements) { int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= total_elements) { @@ -70,7 +81,7 @@ __global__ void Top1MaskBackwardKernel(const T *__restrict__ grad_output, const grad_input[idx] = static_cast(mask_values[idx]) != 0.0f ? grad_output[idx] : T(0.0f); } -std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_output, +std::shared_ptr TopKMaskBackward(const std::shared_ptr &grad_output, const std::shared_ptr &mask_values) { CHECK(grad_output->Dims() == mask_values->Dims()); CHECK(grad_output->Dtype() == mask_values->Dtype()); @@ -87,21 +98,21 @@ std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_out core::cuda::DispatchCudaFunc( grad_output->Dtype(), [=]() { - Top1MaskBackwardKernel<<>>( + TopKMaskBackwardKernel<<>>( static_cast(grad_output->DataPtr()), static_cast(mask_values->DataPtr()), static_cast(grad_input->DataPtr()), total_elements); }, - "CUDA Top1MaskBackward"); + "CUDA TopKMaskBackward"); return grad_input; } } // namespace infini_train::kernels::cuda -#define REGISTER_CUDA_TOP1_MASK_KERNEL(kernel_name) \ +#define REGISTER_CUDA_TOPK_MASK_KERNEL(kernel_name) \ REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) -REGISTER_CUDA_TOP1_MASK_KERNEL(Top1MaskForward) -REGISTER_CUDA_TOP1_MASK_KERNEL(Top1MaskBackward) +REGISTER_CUDA_TOPK_MASK_KERNEL(TopKMaskForward) +REGISTER_CUDA_TOPK_MASK_KERNEL(TopKMaskBackward) -#undef REGISTER_CUDA_TOP1_MASK_KERNEL +#undef REGISTER_CUDA_TOPK_MASK_KERNEL diff --git a/infini_train/src/nn/modules/transformer/moe/router.cc b/infini_train/src/nn/modules/transformer/moe/router.cc index 59dec209..851c57be 100644 --- a/infini_train/src/nn/modules/transformer/moe/router.cc +++ b/infini_train/src/nn/modules/transformer/moe/router.cc @@ -6,7 +6,7 @@ #include "glog/logging.h" #include "infini_train/include/autograd/linear.h" -#include "infini_train/include/autograd/moe.h" +#include "infini_train/include/autograd/topk_mask.h" #include "infini_train/include/nn/functional.h" #include "infini_train/include/nn/init.h" #include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" @@ -17,8 +17,9 @@ namespace infini_train::nn::moe { TopKRouter::TopKRouter(const TransformerConfig &config) : CloneableModule(kType), config_(config) { const auto &moe_config = RequireMoEConfig(config_); CHECK(moe_config.router_type == MoERouterType::kTopK); - CHECK_EQ(moe_config.router_topk, 1) << "Current InfiniTrain MoE implementation supports top-1 routing only"; CHECK_GT(moe_config.num_experts, 0); + CHECK_GT(moe_config.router_topk, 0); + CHECK_LE(moe_config.router_topk, moe_config.num_experts); parameters_[kParamWeightName] = std::make_shared(std::vector{moe_config.num_experts, config_.n_embd}, DataType::kFLOAT32, @@ -43,7 +44,8 @@ std::vector> TopKRouter::Forward(const std::vector()->Apply(linear_inputs)[0]; auto scores = function::Softmax(logits, -1); - auto routing_probs = std::make_shared()->Apply({scores})[0]; + const auto &moe_config = RequireMoEConfig(config_); + auto routing_probs = std::make_shared(moe_config.router_topk)->Apply({scores})[0]; return {routing_probs}; } diff --git a/test/transformer/test_transformer_architecture.cc b/test/transformer/test_transformer_architecture.cc index da3dd70e..469ff386 100644 --- a/test/transformer/test_transformer_architecture.cc +++ b/test/transformer/test_transformer_architecture.cc @@ -527,10 +527,10 @@ void TestStateDict() { } // ============================================================================ -// Test 11: MoE Layer MVP +// Test 11: MoE Layer // ============================================================================ void TestMoELayer() { - std::cout << "\n=== Test 11: MoE Layer MVP ===" << std::endl; + std::cout << "\n=== Test 11: MoE Layer ===" << std::endl; nn::TransformerConfig config; config.n_embd = 32; @@ -543,29 +543,43 @@ void TestMoELayer() { config.moe_config->num_experts = 2; config.moe_config->router_topk = 1; - try { - auto moe = std::make_shared(config); - auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); - input->Uniform(); + auto moe = std::make_shared(config); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); + input->Uniform(); - auto output = (*moe)({input}); - if (output.size() != 1) { - std::cout << "FAIL: MoELayer forward should return 1 tensor" << std::endl; - return; - } - if (output[0]->Dims() != input->Dims()) { - std::cout << "FAIL: MoELayer output shape mismatch" << std::endl; - return; - } + auto output = (*moe)({input}); + CHECK_EQ(output.size(), 1); + CHECK(output[0]->Dims() == input->Dims()); - auto params = moe->Parameters(); - if (params.empty()) { - std::cout << "FAIL: MoELayer should own router and expert parameters" << std::endl; - return; - } + auto params = moe->Parameters(); + CHECK(!params.empty()); - std::cout << "SUCCESS: MoE layer MVP forward works correctly!" << std::endl; - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } + std::cout << "SUCCESS: MoE layer forward works correctly!" << std::endl; +} + +void TestMoELayerTop2() { + std::cout << "\n=== Test 12: MoE Layer Top-2 ===" << std::endl; + + nn::TransformerConfig config; + config.n_embd = 32; + config.n_head = 2; + config.n_kv_head = 2; + config.activation_type = nn::MLPType::kGELU; + config.add_bias_linear = true; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 4; + config.moe_config->router_topk = 2; + + auto moe = std::make_shared(config); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); + input->Uniform(); + + auto output = (*moe)({input}); + CHECK_EQ(output.size(), 1); + CHECK(output[0]->Dims() == input->Dims()); + + std::cout << "SUCCESS: MoE layer top-2 forward works correctly!" << std::endl; } // ============================================================================ @@ -591,6 +605,7 @@ int main(int argc, char *argv[]) { TestRopeUtils(); TestStateDict(); TestMoELayer(); + TestMoELayerTop2(); std::cout << "\n========================================" << std::endl; std::cout << " All Tests Completed" << std::endl; From 556a0d497f7df05c1f80206db47745ef26acfe84 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Wed, 13 May 2026 07:36:52 +0000 Subject: [PATCH 3/3] feat: support moe_ffn_hidden_size config --- infini_train/src/nn/modules/transformer/mlp.cc | 7 ++++++- test/transformer/test_transformer_architecture.cc | 13 +++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/infini_train/src/nn/modules/transformer/mlp.cc b/infini_train/src/nn/modules/transformer/mlp.cc index 3af341b2..ac35d144 100644 --- a/infini_train/src/nn/modules/transformer/mlp.cc +++ b/infini_train/src/nn/modules/transformer/mlp.cc @@ -35,9 +35,14 @@ MLP::MLP(const TransformerConfig &config) : CloneableModule(kType) { } // Round up to multiple_of - int64_t before_round = ffn_hidden; ffn_hidden = (ffn_hidden + config.multiple_of - 1) / config.multiple_of * config.multiple_of; + if (config.ffn_type == FFNType::kMoE && config.moe_config.has_value() + && config.moe_config->moe_ffn_hidden_size > 0) { + ffn_hidden = config.moe_config->moe_ffn_hidden_size; + } + CHECK_GT(ffn_hidden, 0); + // c_fc: ColumnParallel (input full, output parallel) modules_[kCFcLayerName] = std::make_shared( /*in_features=*/config.n_embd, /*out_features=*/ffn_hidden, diff --git a/test/transformer/test_transformer_architecture.cc b/test/transformer/test_transformer_architecture.cc index 469ff386..42efda2d 100644 --- a/test/transformer/test_transformer_architecture.cc +++ b/test/transformer/test_transformer_architecture.cc @@ -564,12 +564,13 @@ void TestMoELayerTop2() { config.n_embd = 32; config.n_head = 2; config.n_kv_head = 2; - config.activation_type = nn::MLPType::kGELU; - config.add_bias_linear = true; + config.activation_type = nn::MLPType::kSwiGLU; + config.add_bias_linear = false; config.ffn_type = nn::FFNType::kMoE; config.moe_config = nn::MoEConfig{}; config.moe_config->num_experts = 4; config.moe_config->router_topk = 2; + config.moe_config->moe_ffn_hidden_size = 48; auto moe = std::make_shared(config); auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); @@ -579,6 +580,14 @@ void TestMoELayerTop2() { CHECK_EQ(output.size(), 1); CHECK(output[0]->Dims() == input->Dims()); + auto state = moe->StateDict(); + CHECK(state.contains("experts.expert_0.c_fc.weight")); + CHECK(state.contains("experts.expert_0.c_fc2.weight")); + CHECK(state.contains("experts.expert_0.c_proj.weight")); + CHECK(state.at("experts.expert_0.c_fc.weight")->Dims() == std::vector({48, config.n_embd})); + CHECK(state.at("experts.expert_0.c_fc2.weight")->Dims() == std::vector({48, config.n_embd})); + CHECK(state.at("experts.expert_0.c_proj.weight")->Dims() == std::vector({config.n_embd, 48})); + std::cout << "SUCCESS: MoE layer top-2 forward works correctly!" << std::endl; }