Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions infini_train/include/autograd/topk_mask.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include <memory>
#include <vector>

#include "infini_train/include/autograd/function.h"

namespace infini_train {
class Tensor;
}

namespace infini_train::autograd {

class TopKMask : public Function {
public:
static constexpr char kType[] = "TopKMaskFunction";

explicit TopKMask(int64_t topk) : Function(kType), topk_(topk) {}

std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
int64_t topk_ = 1;
};

} // namespace infini_train::autograd
25 changes: 25 additions & 0 deletions infini_train/include/nn/modules/transformer/moe/experts.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include <memory>
#include <vector>

#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/modules/transformer/transformer_config.h"

namespace infini_train::nn::moe {

class SequentialMLP : public CloneableModule<SequentialMLP> {
public:
static constexpr char kType[] = "SequentialMLP";
static constexpr char kExpertNamePrefix[] = "expert_";

explicit SequentialMLP(const TransformerConfig &config);

std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

private:
TransformerConfig config_;
int64_t num_local_experts_ = 0;
};

} // namespace infini_train::nn::moe
25 changes: 25 additions & 0 deletions infini_train/include/nn/modules/transformer/moe/moe_layer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include <memory>
#include <vector>

#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/modules/transformer/transformer_config.h"

namespace infini_train::nn::moe {

class MoELayer : public CloneableModule<MoELayer> {
public:
static constexpr char kType[] = "MoELayer";
static constexpr char kRouterLayerName[] = "router";
static constexpr char kExpertsLayerName[] = "experts";

explicit MoELayer(const TransformerConfig &config);

std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

private:
TransformerConfig config_;
};

} // namespace infini_train::nn::moe
9 changes: 9 additions & 0 deletions infini_train/include/nn/modules/transformer/moe/moe_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#pragma once

#include "infini_train/include/nn/modules/transformer/transformer_config.h"

namespace infini_train::nn::moe {

const MoEConfig &RequireMoEConfig(const TransformerConfig &config);

} // namespace infini_train::nn::moe
25 changes: 25 additions & 0 deletions infini_train/include/nn/modules/transformer/moe/router.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include <memory>
#include <vector>

#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/modules/transformer/transformer_config.h"

namespace infini_train::nn::moe {

class TopKRouter : public CloneableModule<TopKRouter> {
public:
static constexpr char kType[] = "TopKRouter";
static constexpr char kParamWeightName[] = "weight";
static constexpr char kParamBiasName[] = "bias";

explicit TopKRouter(const TransformerConfig &config);

std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

private:
TransformerConfig config_;
};

} // namespace infini_train::nn::moe
33 changes: 33 additions & 0 deletions infini_train/include/nn/modules/transformer/transformer_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,42 @@ enum class MLPType {
kSwiGLU // SwiGLU activation
};

enum class FFNType {
kDense, // Standard dense MLP
kMoE // Mixture-of-Experts MLP
};

enum class NormType {
kLayerNorm, // LayerNorm
kRMSNorm // RMSNorm
};

enum class MoERouterType {
kTopK // Top-k router.
};

enum class MoEDispatcherType {
kLocal, // No cross-rank token exchange
kAllGather // Reserved for expert parallel MoE
};

enum class MoEExpertImpl {
kSequential // Run local experts sequentially
};

struct MoEConfig {
int64_t num_experts = 0;
int64_t expert_parallel_size = 1;
int64_t router_topk = 1;
float aux_loss_coeff = 0.0f;
std::optional<float> expert_capacity_factor = std::nullopt;
bool pad_expert_input_to_capacity = false;
int64_t moe_ffn_hidden_size = 0;
MoERouterType router_type = MoERouterType::kTopK;
MoEDispatcherType dispatcher_type = MoEDispatcherType::kLocal;
MoEExpertImpl expert_impl = MoEExpertImpl::kSequential;
};

struct TransformerConfig {
int64_t block_size = 1024; // Max seq_len
int64_t vocab_size = 50304; // Vocab size
Expand All @@ -36,6 +67,7 @@ struct TransformerConfig {

AttentionType attention_type = AttentionType::kStandard; // Attention mechanism type
MLPType activation_type = MLPType::kGELU; // MLP activation type
FFNType ffn_type = FFNType::kDense; // Feed-forward module type
NormType norm_type = NormType::kLayerNorm; // Normalization type

bool add_bias_linear = true; // Whether to add learnable bias to all Linear layers in the Transformer block,
Expand All @@ -48,6 +80,7 @@ struct TransformerConfig {
float ffn_expansion_ratio = 4.0f; // MLP output: n_embd * ffn_expansion_ratio
std::optional<float> ffn_dim_multiplier = 1.5f; // FFN dim multiplier
int64_t multiple_of = 256; // FFN dims must be multiple of this number
std::optional<MoEConfig> moe_config = std::nullopt;

// RoPE config
float rope_theta = 500000.0f; // theta in RoPE
Expand Down
32 changes: 32 additions & 0 deletions infini_train/src/autograd/topk_mask.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#include "infini_train/include/autograd/topk_mask.h"

#include "glog/logging.h"

#include "infini_train/include/dispatcher.h"
#include "infini_train/include/tensor.h"

namespace infini_train::autograd {

std::vector<std::shared_ptr<Tensor>> TopKMask::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
CHECK_EQ(input_tensors.size(), 1);
CHECK_GT(topk_, 0);
const auto &input = input_tensors[0];
auto device = input->GetDevice().type();
return {Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "TopKMaskForward"}, input, topk_)};
}

void TopKMask::SetupContext(const std::vector<std::shared_ptr<Tensor>> &,
const std::vector<std::shared_ptr<Tensor>> &output_tensors) {
saved_tensors_ = {output_tensors[0]};
}

std::vector<std::shared_ptr<Tensor>> TopKMask::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];
const auto &mask_values = saved_tensors_[0];
auto device = grad_output->GetDevice().type();
return {
Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "TopKMaskBackward"}, grad_output, mask_values)};
}

} // namespace infini_train::autograd
88 changes: 88 additions & 0 deletions infini_train/src/kernels/cpu/topk_mask.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include <limits>
#include <memory>
#include <vector>

#include "glog/logging.h"

#include "infini_train/include/dispatcher.h"
#include "infini_train/include/tensor.h"

namespace infini_train::kernels::cpu {

std::shared_ptr<Tensor> TopKMaskForward(const std::shared_ptr<Tensor> &input, int64_t topk) {
CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU TopKMaskForward currently supports float32 only";
CHECK_GE(input->Dims().size(), 1);

const auto &dims = input->Dims();
const int64_t num_experts = dims.back();
CHECK_GT(num_experts, 0);
CHECK_GT(topk, 0);
CHECK_LE(topk, num_experts);
const int64_t rows = input->NumElements() / num_experts;

auto output = std::make_shared<Tensor>(dims, input->Dtype(), input->GetDevice());
output->Fill(0.0f);

const float *in = static_cast<const float *>(input->DataPtr());
float *out = static_cast<float *>(output->DataPtr());
for (int64_t row = 0; row < rows; ++row) {
const int64_t row_offset = row * num_experts;
std::vector<bool> selected_experts(num_experts, false);
float selected_sum = 0.0f;
for (int64_t selected = 0; selected < topk; ++selected) {
int64_t best_idx = -1;
float best_value = -std::numeric_limits<float>::infinity();
for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
if (selected_experts[expert_idx]) {
continue;
}
const float value = in[row_offset + expert_idx];
if (value > best_value) {
best_value = value;
best_idx = expert_idx;
}
}
CHECK_GE(best_idx, 0);
selected_experts[best_idx] = true;
out[row_offset + best_idx] = best_value;
selected_sum += best_value;
}
if (topk > 1 && selected_sum != 0.0f) {
for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
out[row_offset + expert_idx]
= out[row_offset + expert_idx] == 0.0f ? 0.0f : out[row_offset + expert_idx] / selected_sum;
}
}
}

return output;
}

std::shared_ptr<Tensor> TopKMaskBackward(const std::shared_ptr<Tensor> &grad_output,
const std::shared_ptr<Tensor> &mask_values) {
CHECK(grad_output->Dtype() == DataType::kFLOAT32) << "CPU TopKMaskBackward currently supports float32 only";
CHECK(mask_values->Dtype() == DataType::kFLOAT32);
CHECK(grad_output->Dims() == mask_values->Dims());

auto grad_input = std::make_shared<Tensor>(grad_output->Dims(), grad_output->Dtype(), grad_output->GetDevice());
grad_input->Fill(0.0f);

const float *grad = static_cast<const float *>(grad_output->DataPtr());
const float *mask = static_cast<const float *>(mask_values->DataPtr());
float *out = static_cast<float *>(grad_input->DataPtr());
for (int64_t i = 0; i < static_cast<int64_t>(grad_output->NumElements()); ++i) {
out[i] = mask[i] != 0.0f ? grad[i] : 0.0f;
}

return grad_input;
}

} // namespace infini_train::kernels::cpu

#define REGISTER_CPU_TOPK_MASK_KERNEL(kernel_name) \
REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name)

REGISTER_CPU_TOPK_MASK_KERNEL(TopKMaskForward)
REGISTER_CPU_TOPK_MASK_KERNEL(TopKMaskBackward)

#undef REGISTER_CPU_TOPK_MASK_KERNEL
Loading
Loading