Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ add_library(llmq-common SHARED
src/training/logging.cpp
src/training/checkpoint.cpp
src/training/model.cpp
src/training/gradients.cpp
src/training/transformer_config.cpp

src/models/llama_run_state.cpp
Expand Down
11 changes: 7 additions & 4 deletions src/binding/py_train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,18 +276,21 @@ std::vector<std::pair<std::string, long>> MultiGPUPyTrainer::get_stack_info(int
}

std::vector<std::pair<std::string, Tensor>> MultiGPUPyTrainer::get_gradients(int gpu_id) {
using namespace LLamaWeightID;

std::vector<std::pair<std::string, Tensor>> result;
// TODO make this work with generalized gradients
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TODO comment here indicates that get_gradients still hardcodes LLaMA-specific weight IDs (using LLamaWeightID::EMBEDDING, LM_HEAD, LNF_W, etc.) rather than leveraging the generic IGradientManager interface. This means the function is LLaMA-specific and won't work correctly for other model architectures that may use IGradientManager. Since the PR's goal is to provide a generic gradient manager implementation, this is a known gap that should be tracked.

Copilot uses AI. Check for mistakes.
run_work([&result](sThreadContext& ctx) {
const auto& config = ctx.Model->config();
auto& grads = ctx.Model->grads();
CUDA_CHECK(cudaDeviceSynchronize());
result.emplace_back("model.embed_tokens.weight", grads.get_embeddings_shard(nullptr));
result.emplace_back("model.embed_tokens.weight", grads.get_non_block_shard(LLamaWeightID::EMBEDDING, nullptr));
if (!config.TiedWordEmbeddings) {
result.emplace_back("lm_head.weight", grads.get_lmhead_shard(nullptr));
result.emplace_back("lm_head.weight", grads.get_non_block_shard(LM_HEAD, nullptr));
}
result.emplace_back("model.norm.weight", grads.get_lnf_w_shard(nullptr));
result.emplace_back("model.norm.weight", grads.get_non_block_shard(LNF_W, nullptr));
for (int l = 0; l < config.NumLayers; l++) {
using namespace LLamaWeightID;

std::string prefix = "model.layers." + std::to_string(l);
auto& block = grads.get_block_shard(l, nullptr);
result.emplace_back(prefix + ".self_attn.qkv.weight", block.get_tensor(QKV_W));
Expand Down
367 changes: 33 additions & 334 deletions src/models/llama_gradients.cpp

Large diffs are not rendered by default.

86 changes: 46 additions & 40 deletions src/models/llama_gradients.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,48 +6,54 @@
#define LLMQ_SRC_MODELS_LLAMA_GRADIENTS_H

#include "llama_weights.h"
#include "utilities/philox.h"
#include "training/gradients.h"

class LLamaGradsManager {
class LLamaModel;

// LLaMA gradient manager derived from UnshardedGradientManager.
// The constructor does nothing beyond forwarding its arguments to the base class;
// the only LLaMA-specific customization declared here is the on_first_micro_step override,
// whose definition is not part of this header.
class LLamaGradientsUnsharded : public UnshardedGradientManager {
public:
// Passes cfg, model, seed, step, rank, world and alloc unchanged to UnshardedGradientManager.
LLamaGradientsUnsharded(const TransformerConfig& cfg, IModel& model, std::uint64_t seed, int step, int rank,
int world, const std::shared_ptr<TensorAllocator>& alloc)
: UnshardedGradientManager(cfg, model, seed, step, rank, world, alloc) {
}

// Override of the base-class hook.
// NOTE(review): presumably invoked on the first micro-step of a gradient-accumulation cycle -- confirm in training/gradients.h.
void on_first_micro_step(cudaStream_t stream) override;
};

// Shared base class of the LLaMA block-sharded gradient managers (the _ScatterReduce and _AllToAll variants
// in this header derive from it). It inherits all constructors of ShardedBlocksGradientManager and overrides
// two of its hooks; the definitions are not part of this header.
class LLamaGradientsBlockShardedBase : public ShardedBlocksGradientManager {
public:
// Inherit the base-class constructors unchanged.
using ShardedBlocksGradientManager::ShardedBlocksGradientManager;

// Override of the base-class hook.
// NOTE(review): presumably invoked on the first micro-step of a gradient-accumulation cycle -- confirm in training/gradients.h.
void on_first_micro_step(cudaStream_t stream) override;
// Override of the base-class hook that receives a block's gradient container.
// NOTE(review): presumably called when a block's gradients are requested -- confirm in training/gradients.h.
void on_get_block(SimpleTensorContainer& block, cudaStream_t stream) override;
};

// Block-sharded LLaMA gradient manager variant. It reuses the constructors of
// LLamaGradientsBlockShardedBase and privately overrides the per-block notification and
// accumulation hooks.
// NOTE(review): the name suggests gradients are combined across ranks with a scatter-reduce
// collective -- confirm against the definitions in llama_gradients.cpp.
class LLamaGradientsBlockSharded_ScatterReduce : public LLamaGradientsBlockShardedBase {
public:
// Inherit the base-class constructors unchanged.
using LLamaGradientsBlockShardedBase::LLamaGradientsBlockShardedBase;
private:
// Override of the hook that receives the layer index, that layer's gradient block, the stream,
// an event and the communicator.
// NOTE(review): presumably called once the block's gradients are complete -- confirm in training/gradients.h.
void on_notify_block(int layer_idx, SimpleTensorContainer& block, cudaStream_t stream, cudaEvent_t signal, NCCLCommunicator& comm) override;
// Override of the accumulation hook taking two tensor containers for the given layer.
// NOTE(review): `dw` / `sw` look like full and sharded gradient containers respectively -- confirm.
void sr_accumulate_layer(int layer_idx,
SimpleTensorContainer& dw,
SimpleTensorContainer& sw,
cudaStream_t stream,
NCCLCommunicator& comm) override;
};

class LLamaGradientsBlockSharded_AllToAll : public LLamaGradientsBlockShardedBase {
public:
virtual ~LLamaGradsManager() = default;

void start_micro_step(cudaStream_t stream, int micro_step, int total_steps);
virtual void end_micro_step(cudaStream_t stream, NCCLCommunicator& comm) = 0;

// Get references to full gradient accumulators for use in the backward pass
virtual Tensor& get_embeddings_full(cudaStream_t stream, NCCLCommunicator& comm, bool& accumulate) = 0;
virtual Tensor& get_lmhead_full(cudaStream_t stream, NCCLCommunicator& comm, bool& accumulate) = 0;
virtual Tensor& get_lnf_w_full(cudaStream_t stream, NCCLCommunicator& comm, bool& accumulate) = 0;
virtual sLLamaBlockWeights<Tensor>& get_block_full(int layer_idx, cudaStream_t stream, NCCLCommunicator& comm, bool& accumulate) = 0;

// Get references to sharded gradients for use in the optimizer
virtual TensorShard& get_embeddings_shard(cudaStream_t stream) = 0;
virtual TensorShard& get_lmhead_shard(cudaStream_t stream) = 0;
virtual TensorShard& get_lnf_w_shard(cudaStream_t stream) = 0;
virtual SimpleTensorContainer& get_block_shard(int layer_idx, cudaStream_t stream) = 0;

// notify that gradient calculations have been completed
virtual void notify_embeddings(cudaStream_t stream, NCCLCommunicator& comm) = 0;
virtual void notify_lmhead(cudaStream_t stream, NCCLCommunicator& comm) = 0;
virtual void notify_lnf_w(cudaStream_t stream, NCCLCommunicator& comm) = 0;
virtual void notify_block(int layer_idx, cudaStream_t stream, NCCLCommunicator& comm) = 0;

static std::unique_ptr<LLamaGradsManager> create(std::uint64_t seed, int step, const TransformerConfig& config,
const LLamaOptions& options, int rank, int world,
const std::shared_ptr<TensorAllocator>& alloc);

protected:
LLamaGradsManager(std::uint64_t seed, int step);
virtual void on_first_micro_step(cudaStream_t stream) = 0;

void scatter_reduce(Tensor& tensor, cudaStream_t stream, cudaEvent_t signal, NCCLCommunicator& comm);
virtual void scatter_reduce(int layer_idx, SimpleTensorContainer& block, cudaStream_t stream, cudaEvent_t signal, NCCLCommunicator& comm);

Philox4x32 mRng;
int mStepCounter = -1;
bool mIsFirstMicroStep = true;
bool mIsLastMicroStep = false;
using LLamaGradientsBlockShardedBase::LLamaGradientsBlockShardedBase;
private:
void on_notify_block(int layer_idx, SimpleTensorContainer& block, cudaStream_t stream, cudaEvent_t signal, NCCLCommunicator& comm) override;
void sr_accumulate_layer(int layer_idx,
SimpleTensorContainer& dw,
SimpleTensorContainer& sw,
cudaStream_t stream,
NCCLCommunicator& comm) override;
};

// Factory for the LLaMA gradient manager; the caller receives ownership through the generic
// IGradientManager interface.
// NOTE(review): presumably chooses between the concrete manager classes declared in this header
// based on `options` -- confirm in llama_gradients.cpp.
std::unique_ptr<IGradientManager> create_grads_manager(
std::uint64_t seed, int step, LLamaModel& model,
const TransformerConfig& config, const LLamaOptions& options,
int rank, int world, const std::shared_ptr<TensorAllocator>& alloc);
#endif //LLMQ_SRC_MODELS_LLAMA_GRADIENTS_H
56 changes: 34 additions & 22 deletions src/models/llama_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ void LLamaModel::backward(Tensor inputs, Tensor targets, NCCLCommunicator& comm,
}

bool accumulate;
auto& d_lnf_w = Grads->get_lnf_w_full(main_stream, comm, accumulate);
auto& d_lnf_w = Grads->get_non_block_full(LLamaWeightID::LNF_W, main_stream, comm, accumulate);
Parameters->gather_lnf(comm);
// backward the final layernorm
rmsnorm_backward(rs->DActs[L-1].DResFFN.Value, d_lnf_w, rs->RMSNormScratch, rs->DActs[L - 1].DResFFN.Value, rs->DLNF,
Expand All @@ -398,7 +398,7 @@ void LLamaModel::backward(Tensor inputs, Tensor targets, NCCLCommunicator& comm,
rs->release_res_ffn(L-1, main_stream);

Parameters->release_lnf(main_stream);
Grads->notify_lnf_w(main_stream, comm);
Grads->notify_non_block(LLamaWeightID::LNF_W, main_stream, comm);
rs->fetch_res_ffn(L-2, comm.stream());
Parameters->gather_block(L - 1, comm, *rs);
// now backward all the layers
Expand Down Expand Up @@ -429,22 +429,22 @@ void LLamaModel::backward(Tensor inputs, Tensor targets, NCCLCommunicator& comm,

if(l > 0) {
auto& prev_dacts = rs->DActs.at(l - 1);
rmsnorm_backward(prev_dacts.DResFFN.Value, dw.LN1_w, rs->RMSNormScratch, prev_dacts.DResAtt.Value, d_acts.DLN1,
rmsnorm_backward(prev_dacts.DResFFN.Value, dw.get_tensor(LLamaWeightID::LN1_W), rs->RMSNormScratch, prev_dacts.DResAtt.Value, d_acts.DLN1,
rs->get_res_ffn(l-1, main_stream), weights.LN1_w, rs->Acts[l].LN1_Rstd, prev_dacts.DResFFN.Quant.abs_max(),
B, T, C, rs->DeviceProp, main_stream);
rs->release_res_ffn(l - 1, main_stream);
} else {
rmsnorm_backward(rs->DEmb, dw.LN1_w, rs->RMSNormScratch, d_acts.DResAtt.Value, d_acts.DLN1,
rmsnorm_backward(rs->DEmb, dw.get_tensor(LLamaWeightID::LN1_W), rs->RMSNormScratch, d_acts.DResAtt.Value, d_acts.DLN1,
rs->Encoded, weights.LN1_w, rs->Acts[l].LN1_Rstd, nullptr, B, T, C, rs->DeviceProp, main_stream);
}
Parameters->release_block(l, main_stream);
Grads->notify_block(l, main_stream, comm);
}

auto& d_emb = Grads->get_embeddings_full(main_stream, comm, accumulate);
auto& d_emb = Grads->get_non_block_full(LLamaWeightID::EMBEDDING, main_stream, comm, accumulate);
encoder_backward(d_emb, rs->EncoderBwdScratch, rs->EncoderBwdIndices, rs->EncoderBwdInfo,
rs->DEmb, rs->Inputs, inputs, B, T, C, OptimizerRNG(), main_stream, rs->SideStreamEvent, rs->SideStream);
Grads->notify_embeddings(main_stream, comm);
Grads->notify_non_block(LLamaWeightID::EMBEDDING, main_stream, comm);

// make sure all gradients are communicated before we go to the update step.
Grads->end_micro_step(main_stream, comm);
Expand Down Expand Up @@ -509,12 +509,21 @@ void LLamaModel::_backward_lmhead(long B, long T, float z_loss, int micro_step,

// handle the LM-head. We run the d_lmhead matmul first, so that the gradient reduction can overlap with the DLNF matmul.
bool accumulate;
auto& d_lmhead = Grads->get_lmhead_full(main_stream, comm, accumulate);
// get the correct matrix depending on whether we have tied embeddings
auto& d_lmhead = [&]() -> Tensor& {
if (Config.TiedWordEmbeddings) {
return Grads->get_non_block_full(LLamaWeightID::EMBEDDING, main_stream, comm, accumulate);
} else {
return Grads->get_non_block_full(LLamaWeightID::LM_HEAD, main_stream, comm, accumulate);
}
}();

// even if we overwrite for first micro-batch, we need to accumulate on non-first nano batch
accumulate |= nano_step != 0;
matmul(d_lmhead, lnf_slice, rs->Output, Tensor{}, nullptr, nullptr,
rs->CublasLtHandle, rs->CuBlasWorkspace, C, V, nano_batch_size, EMMTranspose::NT, accumulate, main_stream, rs->MatmulBackend);
if (nano_step == nano_batches - 1) {
Grads->notify_lmhead(main_stream, comm);
if (nano_step == nano_batches - 1 && !Config.TiedWordEmbeddings) {
Grads->notify_non_block(LLamaWeightID::LM_HEAD, main_stream, comm);
}

matmul(dlnf_slice, Parameters->get_head(main_stream), rs->Output, Tensor{}, nullptr, nullptr,
Expand Down Expand Up @@ -610,8 +619,9 @@ void LLamaModel::_recompute_block(sLLamaBlockWeights<Tensor>& weights, sLLamaLay
}
}

void LLamaModel::_backward_block(bool accumulate, sLLamaBlockWeights<Tensor>& weights, sLLamaGradBlock& d_weights,
void LLamaModel::_backward_block(bool accumulate, sLLamaBlockWeights<Tensor>& weights, SimpleTensorContainer& d_weights,
sLLamaLayerActivations& acts, sLLamaLayerGradients& d_acts) {
using namespace LLamaWeightID;
auto& rs = RunState;
cudaStream_t main_stream = rs->MainStream;
// convenience shortcuts, size_t instead of int so that pointer arithmetic doesn't overflow
Expand All @@ -626,7 +636,7 @@ void LLamaModel::_backward_block(bool accumulate, sLLamaBlockWeights<Tensor>& we
// backward the 2nd matmul of MLP
// note that _recompute_block guarantees that SwiGLu is already quantized (if necessary)
rs->temp_acquire(d_acts.DSwiGLU);
backward_qmm(d_acts.DSwiGLU, d_weights.MLP_Down_w, Tensor{}, d_acts.DResFFN, acts.SwiGLu, weights.MLP_Down_w, Tensor{},
backward_qmm(d_acts.DSwiGLU, d_weights.get_tensor(DOWN_W), Tensor{}, d_acts.DResFFN, acts.SwiGLu, weights.MLP_Down_w, Tensor{},
accumulate, *rs, B, T, D, C, true, main_stream);

swiglu_backward(d_acts.DMlpUp.Value, d_acts.DSwiGLU, acts.MlpUp, d_acts.DMlpUp.Quant.abs_max(), B, T, D, main_stream);
Expand All @@ -635,18 +645,18 @@ void LLamaModel::_backward_block(bool accumulate, sLLamaBlockWeights<Tensor>& we
if(Options.grad_dtype() != d_acts.DMlpUp.Value.DType) {
rs->temp_acquire(d_acts.DMlpUp.Quant);
}
backward_qmm(d_acts.DLN2, d_weights.MLP_Up_w, Tensor{}, d_acts.DMlpUp, acts.LN2, weights.MLP_Up_w, Tensor{},
backward_qmm(d_acts.DLN2, d_weights.get_tensor(UP_W), Tensor{}, d_acts.DMlpUp, acts.LN2, weights.MLP_Up_w, Tensor{},
accumulate, *rs, B, T, C, 2 * D, !rs->Options.RecomputeRMSNorm, main_stream);
if(Options.grad_dtype() != d_acts.DMlpUp.Value.DType) {
rs->temp_free(d_acts.DMlpUp.Quant);
}

// rmsnorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above
rmsnorm_backward(d_acts.DResAtt.Value, d_weights.LN2_w, rs->RMSNormScratch, d_acts.DResFFN.Value, d_acts.DLN2,
rmsnorm_backward(d_acts.DResAtt.Value, d_weights.get_tensor(LN2_W), rs->RMSNormScratch, d_acts.DResFFN.Value, d_acts.DLN2,
acts.ResidualAtt, weights.LN2_w, acts.LN2_Rstd, d_acts.DResAtt.Quant.abs_max(), B, T, C, rs->DeviceProp, main_stream);

bool recompute_ln1 = rs->Options.RecomputeRMSNorm || rs->Options.RecomputeAtt;
backward_qmm(d_acts.DAttY, d_weights.Attn_Out_w, Tensor{}, d_acts.DResAtt, acts.Att, weights.Attn_Out_w, Tensor{},
backward_qmm(d_acts.DAttY, d_weights.get_tensor(ATTO_W), Tensor{}, d_acts.DResAtt, acts.Att, weights.Attn_Out_w, Tensor{},
accumulate, *rs, B, T, C, C, false, main_stream);

rs->temp_acquire(d_acts.DQKV.Value);
Expand All @@ -664,7 +674,7 @@ void LLamaModel::_backward_block(bool accumulate, sLLamaBlockWeights<Tensor>& we
rs->temp_free(rs->CuDNNWorkspace);
rope_backward(d_acts.DQKV.Value, d_acts.DQKV.Value, rs->FreqCis, d_acts.DQKV.Quant.abs_max(), B, T, Hq, Hkv, Hs, main_stream);

backward_qmm(d_acts.DLN1, d_weights.Attn_QKV_w, d_weights.Attn_QKV_b, d_acts.DQKV, acts.LN1, weights.Attn_QKV_w, rs->MatmulBiasScratch,
backward_qmm(d_acts.DLN1, d_weights.get_tensor(QKV_W), d_weights.get_tensor(QKV_B), d_acts.DQKV, acts.LN1, weights.Attn_QKV_w, rs->MatmulBiasScratch,
accumulate, *rs, B, T, C, Config.qkv_channels(), !recompute_ln1, main_stream);
rs->temp_free(d_acts.DQKV.Value);
}
Expand Down Expand Up @@ -756,6 +766,8 @@ void LLamaModel::fill_non_block_shapes(GenericTensorContainer& target, const Tra
create(target.get_tensor(LLamaWeightID::LNF_W), C, 0, other_dtype);
if(!config.TiedWordEmbeddings) {
create(target.get_tensor(LLamaWeightID::LM_HEAD), V, C, matrix_dtype);
} else {
create(target.get_tensor(LLamaWeightID::LM_HEAD), 0, 0, matrix_dtype);
}
}

Expand All @@ -768,12 +780,12 @@ void LLamaModel::_calculate_gradient_norm(NCCLCommunicator& comm, float grad_cli
global_norm_squared(rs->NormBuffer, grad, grad.nelem(), rs->DeviceProp, stream);
};

norm_squared(Grads->get_embeddings_shard(stream));
norm_squared(Grads->get_non_block_shard(LLamaWeightID::EMBEDDING, stream));

if(!Config.TiedWordEmbeddings) {
norm_squared(Grads->get_lmhead_shard(stream));
norm_squared(Grads->get_non_block_shard(LLamaWeightID::LM_HEAD, stream));
}
norm_squared(Grads->get_lnf_w_shard(stream));
norm_squared(Grads->get_non_block_shard(LLamaWeightID::LNF_W, stream));

for(int i = 0; i < Config.NumLayers; i++) {
auto& block = Grads->get_block_shard(i, stream);
Expand Down Expand Up @@ -822,11 +834,11 @@ void LLamaModel::update(NCCLCommunicator& comm, float learning_rate, float beta_
auto& nb_scales = OptimizerState->non_block_m_scales();

using namespace LLamaWeightID;
run_update(Parameters->get_master_embeddings(), Grads->get_embeddings_shard(main_stream),
run_update(Parameters->get_master_embeddings(), Grads->get_non_block_shard(EMBEDDING, main_stream),
OptimizerState->non_block_m().get_tensor(EMBEDDING), OptimizerState->non_block_v().get_tensor(EMBEDDING),
nb_scales.get_tensor(EMBEDDING), weight_decay);
comm.reduce_max(Parameters->get_master_embeddings().abs_max());
run_update(Parameters->get_master_lnf_w(), Grads->get_lnf_w_shard(main_stream),
run_update(Parameters->get_master_lnf_w(), Grads->get_non_block_shard(LNF_W, main_stream),
OptimizerState->non_block_m().get_tensor(LNF_W), OptimizerState->non_block_v().get_tensor(LNF_W), nb_scales.get_tensor(LNF_W), 0.f);
comm.reduce_max(Parameters->get_master_lnf_w().abs_max());
CUDA_CHECK(cudaEventRecord(rs->OptEmbeddingsDone, main_stream));
Expand Down Expand Up @@ -868,7 +880,7 @@ void LLamaModel::update(NCCLCommunicator& comm, float learning_rate, float beta_
}

if(!Config.TiedWordEmbeddings) {
run_update(Parameters->get_master_lmhead(), Grads->get_lmhead_shard(main_stream),
run_update(Parameters->get_master_lmhead(), Grads->get_non_block_shard(LM_HEAD, main_stream),
OptimizerState->non_block_m().get_tensor(LM_HEAD), OptimizerState->non_block_v().get_tensor(LM_HEAD), nb_scales.get_tensor(LM_HEAD), weight_decay);
comm.reduce_max(Parameters->get_master_lmhead().abs_max());
}
Expand Down Expand Up @@ -914,7 +926,7 @@ void LLamaModel::allocate_run_state(const LLamaOptions& options, NCCLCommunicato

{
auto ctx = Allocator->with_context("Gradients");
Grads = LLamaGradsManager::create(42, 0, Config, options, comm.rank(), comm.world_size(), Allocator);
Grads = create_grads_manager(42, 0, *this, Config, options, comm.rank(), comm.world_size(), Allocator);
}

OptimizerRNG = std::minstd_rand{42};
Expand Down
8 changes: 5 additions & 3 deletions src/models/llama_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

// ---------------------------------------------------------------------------------------------------------------------

class IGradientManager;

struct LLamaOptions {
bool KeepAllActivations = false;
bool RecomputeSwiGLu = false;
Expand Down Expand Up @@ -113,7 +115,7 @@ class LLamaModel : public IModel {
const TensorAllocator& get_allocator() const { return *Allocator; }

const TransformerConfig& config() { return Config; }
LLamaGradsManager& grads() { return *Grads; }
IGradientManager& grads() { return *Grads; }
LLamaRunState& run_state() { return *RunState; }

IRunState& get_run_state() const override;
Expand All @@ -129,15 +131,15 @@ class LLamaModel : public IModel {
void _forward_block(sLLamaBlockWeights<Tensor>& weights, sLLamaLayerActivations& activations, Tensor& residual);
void _backward_lmhead(long B, long T, float z_loss, int micro_step, int grad_accum_steps, NCCLCommunicator& comm);
void _recompute_block(sLLamaBlockWeights<Tensor>& weights, sLLamaLayerActivations& activations, Tensor& residual);
void _backward_block(bool accumulate, sLLamaBlockWeights<Tensor>& weights, sLLamaGradBlock& grads,
void _backward_block(bool accumulate, sLLamaBlockWeights<Tensor>& weights, SimpleTensorContainer& grads,
sLLamaLayerActivations& activations, sLLamaLayerGradients& d_activations);
private:
TransformerConfig Config;
LLamaOptions Options;
std::shared_ptr<TensorAllocator> Allocator;
std::unique_ptr<LLamaWeightsManager> Parameters;
std::unique_ptr<LLamaOptimizerStateManager> OptimizerState;
std::unique_ptr<LLamaGradsManager> Grads;
std::unique_ptr<IGradientManager> Grads;
std::unique_ptr<LLamaRunState> RunState;

std::minstd_rand OptimizerRNG; //!< Seed generator for stochastic rounding in the optimizer
Expand Down
Loading