diff --git a/infini_train/include/autograd/topk_mask.h b/infini_train/include/autograd/topk_mask.h new file mode 100644 index 00000000..355ef400 --- /dev/null +++ b/infini_train/include/autograd/topk_mask.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class TopKMask : public Function { +public: + static constexpr char kType[] = "TopKMaskFunction"; + + explicit TopKMask(int64_t topk) : Function(kType), topk_(topk) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + int64_t topk_ = 1; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/nn/modules/transformer/moe/experts.h b/infini_train/include/nn/modules/transformer/moe/experts.h new file mode 100644 index 00000000..a3dda7f0 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/experts.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class SequentialMLP : public CloneableModule { +public: + static constexpr char kType[] = "SequentialMLP"; + static constexpr char kExpertNamePrefix[] = "expert_"; + + explicit SequentialMLP(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; + int64_t num_local_experts_ = 0; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/moe_layer.h b/infini_train/include/nn/modules/transformer/moe/moe_layer.h new file mode 100644 index 00000000..e5fdb3ab --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/moe_layer.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class MoELayer : public CloneableModule { +public: + static constexpr char kType[] = "MoELayer"; + static constexpr char kRouterLayerName[] = "router"; + static constexpr char kExpertsLayerName[] = "experts"; + + explicit MoELayer(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/moe_utils.h b/infini_train/include/nn/modules/transformer/moe/moe_utils.h new file mode 100644 index 00000000..e0dd3744 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/moe_utils.h @@ -0,0 +1,9 @@ +#pragma once + +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +const MoEConfig &RequireMoEConfig(const TransformerConfig &config); + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/router.h b/infini_train/include/nn/modules/transformer/moe/router.h new file mode 100644 index 00000000..1279c217 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/router.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class TopKRouter : public CloneableModule { +public: + static constexpr char kType[] = "TopKRouter"; + static constexpr char kParamWeightName[] = "weight"; + static constexpr char kParamBiasName[] = "bias"; + + explicit TopKRouter(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/transformer_config.h b/infini_train/include/nn/modules/transformer/transformer_config.h index 62379666..3a96625d 100644 --- a/infini_train/include/nn/modules/transformer/transformer_config.h +++ b/infini_train/include/nn/modules/transformer/transformer_config.h @@ -20,11 +20,42 @@ enum class MLPType { kSwiGLU // SwiGLU activation }; +enum class FFNType { + kDense, // Standard dense MLP + kMoE // Mixture-of-Experts MLP +}; + enum class NormType { kLayerNorm, // LayerNorm kRMSNorm // RMSNorm }; +enum class MoERouterType { + kTopK // Top-k router. +}; + +enum class MoEDispatcherType { + kLocal, // No cross-rank token exchange + kAllGather // Reserved for expert parallel MoE +}; + +enum class MoEExpertImpl { + kSequential // Run local experts sequentially +}; + +struct MoEConfig { + int64_t num_experts = 0; + int64_t expert_parallel_size = 1; + int64_t router_topk = 1; + float aux_loss_coeff = 0.0f; + std::optional expert_capacity_factor = std::nullopt; + bool pad_expert_input_to_capacity = false; + int64_t moe_ffn_hidden_size = 0; + MoERouterType router_type = MoERouterType::kTopK; + MoEDispatcherType dispatcher_type = MoEDispatcherType::kLocal; + MoEExpertImpl expert_impl = MoEExpertImpl::kSequential; +}; + struct TransformerConfig { int64_t block_size = 1024; // Max seq_len int64_t vocab_size = 50304; // Vocab size @@ -36,6 +67,7 @@ struct TransformerConfig { AttentionType attention_type = AttentionType::kStandard; // Attention mechanism type MLPType activation_type = MLPType::kGELU; // MLP activation type + FFNType ffn_type = FFNType::kDense; // Feed-forward module type NormType norm_type = NormType::kLayerNorm; // Normalization type bool add_bias_linear = true; // Whether to add learnable bias to all Linear layers in the Transformer block, @@ -48,6 +80,7 @@ struct TransformerConfig { float ffn_expansion_ratio = 4.0f; // MLP output: n_embd * ffn_expansion_ratio std::optional ffn_dim_multiplier = 1.5f; // FFN dim multiplier int64_t multiple_of = 256; // FFN dims must be multiple of this number + std::optional moe_config = std::nullopt; // RoPE config float rope_theta = 500000.0f; // theta in RoPE diff --git a/infini_train/src/autograd/topk_mask.cc b/infini_train/src/autograd/topk_mask.cc new file mode 100644 index 00000000..16dc6629 --- /dev/null +++ b/infini_train/src/autograd/topk_mask.cc @@ -0,0 +1,32 @@ +#include "infini_train/include/autograd/topk_mask.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> TopKMask::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + CHECK_GT(topk_, 0); + const auto &input = input_tensors[0]; + auto device = input->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "TopKMaskForward"}, input, topk_)}; +} + +void TopKMask::SetupContext(const std::vector> &, + const std::vector> &output_tensors) { + saved_tensors_ = {output_tensors[0]}; +} + +std::vector> TopKMask::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + const auto &mask_values = saved_tensors_[0]; + auto device = grad_output->GetDevice().type(); + return { + Dispatcher::Instance().Call>({device, "TopKMaskBackward"}, grad_output, mask_values)}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/topk_mask.cc b/infini_train/src/kernels/cpu/topk_mask.cc new file mode 100644 index 00000000..6a7191b9 --- /dev/null +++ b/infini_train/src/kernels/cpu/topk_mask.cc @@ -0,0 +1,88 @@ +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::kernels::cpu { + +std::shared_ptr TopKMaskForward(const std::shared_ptr &input, int64_t topk) { + CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU TopKMaskForward currently supports float32 only"; + CHECK_GE(input->Dims().size(), 1); + + const auto &dims = input->Dims(); + const int64_t num_experts = dims.back(); + CHECK_GT(num_experts, 0); + CHECK_GT(topk, 0); + CHECK_LE(topk, num_experts); + const int64_t rows = input->NumElements() / num_experts; + + auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); + output->Fill(0.0f); + + const float *in = static_cast(input->DataPtr()); + float *out = static_cast(output->DataPtr()); + for (int64_t row = 0; row < rows; ++row) { + const int64_t row_offset = row * num_experts; + std::vector selected_experts(num_experts, false); + float selected_sum = 0.0f; + for (int64_t selected = 0; selected < topk; ++selected) { + int64_t best_idx = -1; + float best_value = -std::numeric_limits::infinity(); + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + if (selected_experts[expert_idx]) { + continue; + } + const float value = in[row_offset + expert_idx]; + if (value > best_value) { + best_value = value; + best_idx = expert_idx; + } + } + CHECK_GE(best_idx, 0); + selected_experts[best_idx] = true; + out[row_offset + best_idx] = best_value; + selected_sum += best_value; + } + if (topk > 1 && selected_sum != 0.0f) { + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + out[row_offset + expert_idx] + = out[row_offset + expert_idx] == 0.0f ? 0.0f : out[row_offset + expert_idx] / selected_sum; + } + } + } + + return output; +} + +std::shared_ptr TopKMaskBackward(const std::shared_ptr &grad_output, + const std::shared_ptr &mask_values) { + CHECK(grad_output->Dtype() == DataType::kFLOAT32) << "CPU TopKMaskBackward currently supports float32 only"; + CHECK(mask_values->Dtype() == DataType::kFLOAT32); + CHECK(grad_output->Dims() == mask_values->Dims()); + + auto grad_input = std::make_shared(grad_output->Dims(), grad_output->Dtype(), grad_output->GetDevice()); + grad_input->Fill(0.0f); + + const float *grad = static_cast(grad_output->DataPtr()); + const float *mask = static_cast(mask_values->DataPtr()); + float *out = static_cast(grad_input->DataPtr()); + for (int64_t i = 0; i < static_cast(grad_output->NumElements()); ++i) { + out[i] = mask[i] != 0.0f ? grad[i] : 0.0f; + } + + return grad_input; +} + +} // namespace infini_train::kernels::cpu + +#define REGISTER_CPU_TOPK_MASK_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + +REGISTER_CPU_TOPK_MASK_KERNEL(TopKMaskForward) +REGISTER_CPU_TOPK_MASK_KERNEL(TopKMaskBackward) + +#undef REGISTER_CPU_TOPK_MASK_KERNEL diff --git a/infini_train/src/kernels/cuda/topk_mask.cu b/infini_train/src/kernels/cuda/topk_mask.cu new file mode 100644 index 00000000..e38c793e --- /dev/null +++ b/infini_train/src/kernels/cuda/topk_mask.cu @@ -0,0 +1,118 @@ +#include "glog/logging.h" + +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +#include "infini_train/src/core/runtime/cuda/cuda_dispatch.h" +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" + +namespace infini_train::kernels::cuda { + +template +__global__ void TopKMaskForwardKernel(const T *__restrict__ input, T *__restrict__ output, int64_t rows, + int64_t num_experts, int64_t topk) { + int64_t row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= rows) { + return; + } + + const int64_t offset = row * num_experts; + float selected_sum = 0.0f; + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + const float value = static_cast(input[offset + expert_idx]); + int64_t rank = 0; + for (int64_t other_idx = 0; other_idx < num_experts; ++other_idx) { + const float other_value = static_cast(input[offset + other_idx]); + if (other_value > value || (other_value == value && other_idx < expert_idx)) { + ++rank; + } + } + const bool selected = rank < topk; + output[offset + expert_idx] = selected ? input[offset + expert_idx] : T(0.0f); + selected_sum += selected ? value : 0.0f; + } + if (topk > 1 && selected_sum != 0.0f) { + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + if (static_cast(output[offset + expert_idx]) != 0.0f) { + output[offset + expert_idx] = T(static_cast(output[offset + expert_idx]) / selected_sum); + } + } + } +} + +std::shared_ptr TopKMaskForward(const std::shared_ptr &input, int64_t topk) { + CHECK_GE(input->Dims().size(), 1); + const auto &dims = input->Dims(); + const int64_t num_experts = dims.back(); + CHECK_GT(num_experts, 0); + CHECK_GT(topk, 0); + CHECK_LE(topk, num_experts); + const int64_t rows = input->NumElements() / num_experts; + + auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); + + auto device = input->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int threads = 256; + const int blocks = static_cast((rows + threads - 1) / threads); + + core::cuda::DispatchCudaFunc( + input->Dtype(), + [=]() { + TopKMaskForwardKernel<<>>( + static_cast(input->DataPtr()), static_cast(output->DataPtr()), rows, num_experts, topk); + }, + "CUDA TopKMaskForward"); + + return output; +} + +template +__global__ void TopKMaskBackwardKernel(const T *__restrict__ grad_output, const T *__restrict__ mask_values, + T *__restrict__ grad_input, int64_t total_elements) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_elements) { + return; + } + grad_input[idx] = static_cast(mask_values[idx]) != 0.0f ? grad_output[idx] : T(0.0f); +} + +std::shared_ptr TopKMaskBackward(const std::shared_ptr &grad_output, + const std::shared_ptr &mask_values) { + CHECK(grad_output->Dims() == mask_values->Dims()); + CHECK(grad_output->Dtype() == mask_values->Dtype()); + auto grad_input = std::make_shared(grad_output->Dims(), grad_output->Dtype(), grad_output->GetDevice()); + + auto device = grad_output->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int64_t total_elements = grad_output->NumElements(); + const int threads = 256; + const int blocks = static_cast((total_elements + threads - 1) / threads); + + core::cuda::DispatchCudaFunc( + grad_output->Dtype(), + [=]() { + TopKMaskBackwardKernel<<>>( + static_cast(grad_output->DataPtr()), static_cast(mask_values->DataPtr()), + static_cast(grad_input->DataPtr()), total_elements); + }, + "CUDA TopKMaskBackward"); + + return grad_input; +} + +} // namespace infini_train::kernels::cuda + +#define REGISTER_CUDA_TOPK_MASK_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + +REGISTER_CUDA_TOPK_MASK_KERNEL(TopKMaskForward) +REGISTER_CUDA_TOPK_MASK_KERNEL(TopKMaskBackward) + +#undef REGISTER_CUDA_TOPK_MASK_KERNEL diff --git a/infini_train/src/nn/modules/transformer/mlp.cc b/infini_train/src/nn/modules/transformer/mlp.cc index 3af341b2..ac35d144 100644 --- a/infini_train/src/nn/modules/transformer/mlp.cc +++ b/infini_train/src/nn/modules/transformer/mlp.cc @@ -35,9 +35,14 @@ MLP::MLP(const TransformerConfig &config) : CloneableModule(kType) { } // Round up to multiple_of - int64_t before_round = ffn_hidden; ffn_hidden = (ffn_hidden + config.multiple_of - 1) / config.multiple_of * config.multiple_of; + if (config.ffn_type == FFNType::kMoE && config.moe_config.has_value() + && config.moe_config->moe_ffn_hidden_size > 0) { + ffn_hidden = config.moe_config->moe_ffn_hidden_size; + } + CHECK_GT(ffn_hidden, 0); + // c_fc: ColumnParallel (input full, output parallel) modules_[kCFcLayerName] = std::make_shared( /*in_features=*/config.n_embd, /*out_features=*/ffn_hidden, diff --git a/infini_train/src/nn/modules/transformer/moe/experts.cc b/infini_train/src/nn/modules/transformer/moe/experts.cc new file mode 100644 index 00000000..8f3b1be8 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/experts.cc @@ -0,0 +1,50 @@ +#include "infini_train/include/nn/modules/transformer/moe/experts.h" + +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +SequentialMLP::SequentialMLP(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK(moe_config.expert_impl == MoEExpertImpl::kSequential); + CHECK_EQ(moe_config.expert_parallel_size, 1) + << "Current InfiniTrain MoE implementation supports expert_parallel_size=1 only"; + CHECK(moe_config.dispatcher_type == MoEDispatcherType::kLocal) + << "Current InfiniTrain MoE implementation supports local dispatch only"; + + num_local_experts_ = moe_config.num_experts; + CHECK_GT(num_local_experts_, 0); + + for (int64_t expert_idx = 0; expert_idx < num_local_experts_; ++expert_idx) { + modules_[std::string(kExpertNamePrefix) + std::to_string(expert_idx)] = std::make_shared(config_); + } +} + +std::vector> SequentialMLP::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 2); + auto hidden_states = input_tensors[0]; + auto routing_probs = input_tensors[1]; + CHECK_EQ(routing_probs->Dims().back(), num_local_experts_); + + std::shared_ptr output = nullptr; + const int64_t expert_dim = static_cast(routing_probs->Dims().size()) - 1; + for (int64_t expert_idx = 0; expert_idx < num_local_experts_; ++expert_idx) { + auto expert_name = std::string(kExpertNamePrefix) + std::to_string(expert_idx); + auto expert_output = (*modules_.at(expert_name))({hidden_states})[0]; + auto expert_prob = routing_probs->Slice(expert_dim, expert_idx, expert_idx + 1); + auto weighted_output = expert_output * expert_prob; + output = output == nullptr ? weighted_output : output + weighted_output; + } + + return {output}; +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/moe_layer.cc b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc new file mode 100644 index 00000000..8efd51c0 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc @@ -0,0 +1,32 @@ +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" + +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/transformer/moe/experts.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/nn/modules/transformer/moe/router.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +MoELayer::MoELayer(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK(config_.ffn_type == FFNType::kMoE); + CHECK(moe_config.dispatcher_type == MoEDispatcherType::kLocal) + << "Current InfiniTrain MoE implementation supports local dispatch only"; + + modules_[kRouterLayerName] = std::make_shared(config_); + modules_[kExpertsLayerName] = std::make_shared(config_); +} + +std::vector> MoELayer::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + auto hidden_states = input_tensors[0]; + auto routing_probs = (*modules_.at(kRouterLayerName))({hidden_states})[0]; + return (*modules_.at(kExpertsLayerName))({hidden_states, routing_probs}); +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/moe_utils.cc b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc new file mode 100644 index 00000000..80ef01c1 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc @@ -0,0 +1,12 @@ +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" + +#include "glog/logging.h" + +namespace infini_train::nn::moe { + +const MoEConfig &RequireMoEConfig(const TransformerConfig &config) { + CHECK(config.moe_config.has_value()) << "MoE layer requires TransformerConfig::moe_config"; + return config.moe_config.value(); +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/router.cc b/infini_train/src/nn/modules/transformer/moe/router.cc new file mode 100644 index 00000000..851c57be --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/router.cc @@ -0,0 +1,52 @@ +#include "infini_train/include/nn/modules/transformer/moe/router.h" + +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/autograd/linear.h" +#include "infini_train/include/autograd/topk_mask.h" +#include "infini_train/include/nn/functional.h" +#include "infini_train/include/nn/init.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +TopKRouter::TopKRouter(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK(moe_config.router_type == MoERouterType::kTopK); + CHECK_GT(moe_config.num_experts, 0); + CHECK_GT(moe_config.router_topk, 0); + CHECK_LE(moe_config.router_topk, moe_config.num_experts); + + parameters_[kParamWeightName] + = std::make_shared(std::vector{moe_config.num_experts, config_.n_embd}, DataType::kFLOAT32, + device_) + ->RequiresGrad(); + init::KaimingUniform(parameters_[kParamWeightName]); + + if (config_.add_bias_linear) { + parameters_[kParamBiasName] + = std::make_shared(std::vector{moe_config.num_experts}, DataType::kFLOAT32, device_) + ->RequiresGrad(); + parameters_[kParamBiasName]->Fill(0.0f); + } +} + +std::vector> TopKRouter::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + std::vector> linear_inputs{input_tensors[0], parameters_.at(kParamWeightName)}; + if (parameters_.contains(kParamBiasName)) { + linear_inputs.push_back(parameters_.at(kParamBiasName)); + } + + auto logits = std::make_shared()->Apply(linear_inputs)[0]; + auto scores = function::Softmax(logits, -1); + const auto &moe_config = RequireMoEConfig(config_); + auto routing_probs = std::make_shared(moe_config.router_topk)->Apply({scores})[0]; + return {routing_probs}; +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/transformer.cc b/infini_train/src/nn/modules/transformer/transformer.cc index c7e0f28c..bdcde449 100644 --- a/infini_train/src/nn/modules/transformer/transformer.cc +++ b/infini_train/src/nn/modules/transformer/transformer.cc @@ -15,6 +15,7 @@ #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" #include "infini_train/include/nn/modules/transformer/utils.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" @@ -86,7 +87,11 @@ TransformerLayer::TransformerLayer(const nn::TransformerConfig &config) : Clonea } modules_[kAttnLayerName] = std::make_shared(config); - modules_[kMlpLayerName] = std::make_shared(config); + if (config.ffn_type == FFNType::kMoE) { + modules_[kMlpLayerName] = std::make_shared(config); + } else { + modules_[kMlpLayerName] = std::make_shared(config); + } } std::vector> TransformerLayer::Forward(const std::vector> &x) { diff --git a/test/transformer/test_transformer_architecture.cc b/test/transformer/test_transformer_architecture.cc index 1c6ed28f..42efda2d 100644 --- a/test/transformer/test_transformer_architecture.cc +++ b/test/transformer/test_transformer_architecture.cc @@ -11,6 +11,7 @@ #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" #include "infini_train/include/nn/modules/transformer/transformer.h" #include "infini_train/include/nn/modules/transformer/transformer_config.h" #include "infini_train/include/nn/modules/transformer/utils.h" @@ -525,6 +526,71 @@ void TestStateDict() { } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } } +// ============================================================================ +// Test 11: MoE Layer +// ============================================================================ +void TestMoELayer() { + std::cout << "\n=== Test 11: MoE Layer ===" << std::endl; + + nn::TransformerConfig config; + config.n_embd = 32; + config.n_head = 2; + config.n_kv_head = 2; + config.activation_type = nn::MLPType::kGELU; + config.add_bias_linear = true; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 2; + config.moe_config->router_topk = 1; + + auto moe = std::make_shared(config); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); + input->Uniform(); + + auto output = (*moe)({input}); + CHECK_EQ(output.size(), 1); + CHECK(output[0]->Dims() == input->Dims()); + + auto params = moe->Parameters(); + CHECK(!params.empty()); + + std::cout << "SUCCESS: MoE layer forward works correctly!" << std::endl; +} + +void TestMoELayerTop2() { + std::cout << "\n=== Test 12: MoE Layer Top-2 ===" << std::endl; + + nn::TransformerConfig config; + config.n_embd = 32; + config.n_head = 2; + config.n_kv_head = 2; + config.activation_type = nn::MLPType::kSwiGLU; + config.add_bias_linear = false; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 4; + config.moe_config->router_topk = 2; + config.moe_config->moe_ffn_hidden_size = 48; + + auto moe = std::make_shared(config); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); + input->Uniform(); + + auto output = (*moe)({input}); + CHECK_EQ(output.size(), 1); + CHECK(output[0]->Dims() == input->Dims()); + + auto state = moe->StateDict(); + CHECK(state.contains("experts.expert_0.c_fc.weight")); + CHECK(state.contains("experts.expert_0.c_fc2.weight")); + CHECK(state.contains("experts.expert_0.c_proj.weight")); + CHECK(state.at("experts.expert_0.c_fc.weight")->Dims() == std::vector({48, config.n_embd})); + CHECK(state.at("experts.expert_0.c_fc2.weight")->Dims() == std::vector({48, config.n_embd})); + CHECK(state.at("experts.expert_0.c_proj.weight")->Dims() == std::vector({config.n_embd, 48})); + + std::cout << "SUCCESS: MoE layer top-2 forward works correctly!" << std::endl; +} + // ============================================================================ // Main // ============================================================================ @@ -547,6 +613,8 @@ int main(int argc, char *argv[]) { TestLlama3Model(); TestRopeUtils(); TestStateDict(); + TestMoELayer(); + TestMoELayerTop2(); std::cout << "\n========================================" << std::endl; std::cout << " All Tests Completed" << std::endl;