Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
45204f1
[prf/dec]Implement prefill-decode for Llama Q8_0
orionpapadakis May 26, 2026
8ebf91f
Reorganize TornadoVM execution planning and improve naming conventions
orionpapadakis May 28, 2026
4e4478a
Update naming from `ActivationGraph` to `ActivationTaskGraph` across …
orionpapadakis May 28, 2026
ea478f8
Rename `AbstractFFNLayers` to `AbstractTransformerLayerTaskGraphs` an…
orionpapadakis May 28, 2026
e20ebc5
Refactor FFN layer comments to `transformer-layer task graphs`, align…
orionpapadakis May 28, 2026
c7522d1
[ci] Add workflows for Llama-3.2-1B-Instruct Q8_0 inference with pref…
orionpapadakis May 28, 2026
d830429
[prf/dec] Move embedding copy to InferenceCore, in alignment to singl…
orionpapadakis Jun 5, 2026
26afbd0
[prf/dec] Update TornadoVM method naming to `tornadoVMForwardDecode` …
orionpapadakis Jun 5, 2026
4ae2e8c
[prf/dec] Drop redundant `EmbeddingPreparer`
orionpapadakis Jun 5, 2026
34cee3a
[prf/dec] Make batch-state reset model-agnostic
orionpapadakis Jun 5, 2026
61c08ae
[prf/dec] Make batch-state reset model-agnostic
orionpapadakis Jun 5, 2026
ce57287
[prf/dec] Consolidate batch-prefill state management into the base `S…
orionpapadakis Jun 5, 2026
bd98e6e
[prf/dec] Replace model-specific reset methods with direct field mani…
orionpapadakis Jun 5, 2026
e215ceb
[prf/dec] Remove redundant buffer reset methods from `State` class
orionpapadakis Jun 5, 2026
b90fc9c
[prf/dec] Update TornadoVM method and interface naming to reflect sin…
orionpapadakis Jun 5, 2026
7f1864b
[prf/dec] Update comments
orionpapadakis Jun 5, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions .github/workflows/build-and-run.yml
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,107 @@ jobs:
flags="" \
prompt="Say hello"

- name: Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Prefill-Decode
env:
JAVA_TOOL_OPTIONS: >-
-Dllama.metrics.format=json
-Dllama.metrics.output=file
-Dllama.metrics.file=${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-prefill-decode.json
run: |
cd ${{ github.workspace }}
export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
./llama-tornado --gpu --${{ matrix.backend.name }} \
--model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \
--prompt "Say hello" \
--with-prefill-decode
python3 scripts/write_metrics_sidecar.py \
--out "${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-prefill-decode.meta.json" \
backend="${{ matrix.backend.name }}" \
task=llama-inference \
model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \
model=Llama-3.2-1B-Instruct \
quantization=Q8_0 \
configuration=prefill-decode \
"flags=--with-prefill-decode" \
prompt="Say hello"

- name: Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Batch-Prefill-Decode
env:
JAVA_TOOL_OPTIONS: >-
-Dllama.metrics.format=json
-Dllama.metrics.output=file
-Dllama.metrics.file=${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-batch-prefill-decode.json
run: |
cd ${{ github.workspace }}
export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
./llama-tornado --gpu --${{ matrix.backend.name }} \
--model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \
--prompt "Say hello" \
--with-prefill-decode --batch-prefill-size 32
python3 scripts/write_metrics_sidecar.py \
--out "${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-batch-prefill-decode.meta.json" \
backend="${{ matrix.backend.name }}" \
task=llama-inference \
model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \
model=Llama-3.2-1B-Instruct \
quantization=Q8_0 \
configuration=batch-prefill-decode \
"flags=--with-prefill-decode --batch-prefill-size 32" \
prompt="Say hello"

# ── PTX-only: CUDA-graph variants ────────────────────────────────────────
- name: PTX - Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Prefill-Decode-CUDA-Graphs
if: matrix.backend.name == 'ptx'
env:
JAVA_TOOL_OPTIONS: >-
-Dllama.metrics.format=json
-Dllama.metrics.output=file
-Dllama.metrics.file=${{ runner.temp }}/metrics-ptx-llama-1b-q8-prefill-decode-cuda-graphs.json
run: |
cd ${{ github.workspace }}
export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
./llama-tornado --gpu --ptx \
--model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \
--prompt "Say hello" \
--with-prefill-decode \
--cuda-graphs
python3 scripts/write_metrics_sidecar.py \
--out "${{ runner.temp }}/metrics-ptx-llama-1b-q8-prefill-decode-cuda-graphs.meta.json" \
backend=ptx \
task=llama-inference \
model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \
model=Llama-3.2-1B-Instruct \
quantization=Q8_0 \
configuration=prefill-decode-cuda-graphs \
"flags=--with-prefill-decode --cuda-graphs" \
prompt="Say hello"

- name: PTX - Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Batch-Prefill-Decode-CUDA-Graphs
if: matrix.backend.name == 'ptx'
env:
JAVA_TOOL_OPTIONS: >-
-Dllama.metrics.format=json
-Dllama.metrics.output=file
-Dllama.metrics.file=${{ runner.temp }}/metrics-ptx-llama-1b-q8-batch-prefill-decode-cuda-graphs.json
run: |
cd ${{ github.workspace }}
export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
./llama-tornado --gpu --ptx \
--model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \
--prompt "Say hello" \
--with-prefill-decode --batch-prefill-size 32 \
--cuda-graphs
python3 scripts/write_metrics_sidecar.py \
--out "${{ runner.temp }}/metrics-ptx-llama-1b-q8-batch-prefill-decode-cuda-graphs.meta.json" \
backend=ptx \
task=llama-inference \
model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \
model=Llama-3.2-1B-Instruct \
quantization=Q8_0 \
configuration=batch-prefill-decode-cuda-graphs \
"flags=--with-prefill-decode --batch-prefill-size 32 --cuda-graphs" \
prompt="Say hello"

- name: Q8 - Run Qwen3-0.6B-Q8_0.gguf
env:
JAVA_TOOL_OPTIONS: >-
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i
default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType());
}

return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position);
return tornadoVMMasterPlan.tornadoVMForwardDecode(position);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
import org.beehive.gpullama3.auxiliary.Parallel;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.standard.StandardWeights;
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

import java.lang.foreign.MemorySegment;

/**
* Low-level forward passes for the batched prefill/decode inference path (Phase 3/4).
*
Expand All @@ -20,11 +23,10 @@
* <li>{@link #batchForwardJavaPrefill} — CPU batch prefill: processes a chunk of
* prompt tokens in one pass using batch matmul, avoiding redundant weight
* traversals. Only the KV cache is populated; logits are intentionally omitted.</li>
* <li>{@link #batchForwardTornadoVMPrefill} — GPU batch prefill: delegates the chunk
* to {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill}.</li>
* <li>{@link #forwardTornadoVMDecode} — GPU decode: delegates a single decode step to
* {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode}, which
* handles the embedding copy and runs the full decode + logits graphs.</li>
* <li>{@link #batchForwardTornadoVMPrefill} — GPU batch prefill: copies batch embeddings
* into device-visible state buffers then runs the batch activation + layer graphs.</li>
* <li>{@link #forwardTornadoVMDecode} — GPU decode: copies the decode token embedding
* then runs the decode activation + layer + logits graphs.</li>
* </ul>
*/
public final class InferenceCoreBatchPrefillDecode {
Expand Down Expand Up @@ -161,39 +163,92 @@ public static void batchForwardJavaPrefill(Model model, State state, int[] token
// logits are not needed for any token in a prefill batch.
}

private static final int Q8_0_BLOCK_SIZE = 32;
private static final int Q8_0_BLOCK_BYTES = 34;

/**
* GPU batched prefill forward pass (Phase 4).
*
* <p>Delegates the full chunk to
* {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill},
* which handles embedding lookup and GPU execution internally.</p>
* <p>Copies {@code chunkSize} token embeddings into device-visible state buffers,
* then delegates graph execution to the plan.</p>
*
* @param model the LLaMA model
* @param state mutable inference state
* @param tokens token ids for this chunk
* @param startPos sequence position of {@code tokens[0]}
* @param chunkSize number of tokens in this chunk
* @param plan the batched prefill/decode GPU plan
*/
public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int startPos, int chunkSize,
TornadoVMMasterPlanWithBatchPrefillDecode plan) {
plan.tornadoVMForwardBatchPrefill(tokens, startPos, model, chunkSize);
public static void batchForwardTornadoVMPrefill(Model model, State state, int[] tokens, int startPos,
int chunkSize, TornadoVMMasterPlanBatchPrefillDecode plan) {
final Configuration config = model.configuration();
final TornadoWeights weights = (TornadoWeights) model.weights();

state.batchStartPosHolder.set(0, startPos);

switch (weights.getWeightType()) {
case F16 -> {
MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment();
long dimBytes = (long) config.dim() * Short.BYTES;
for (int b = 0; b < chunkSize; b++) {
MemorySegment.copy(embTable, (long) tokens[b] * dimBytes,
state.embeddingXBatch.getSegment(), (long) b * dimBytes, dimBytes);
}
}
case Q8_0 -> {
var embTable = weights.getTokenEmbeddingTable().asByteArray();
int dim = config.dim();
int blocksPerRow = (dim + Q8_0_BLOCK_SIZE - 1) / Q8_0_BLOCK_SIZE;
for (int b = 0; b < chunkSize; b++) {
int tokenId = tokens[b];
for (int j = 0; j < dim; j++) {
int blockByteOffset = (tokenId * blocksPerRow + j / Q8_0_BLOCK_SIZE) * Q8_0_BLOCK_BYTES;
float scale = embTable.getHalfFloat(blockByteOffset).getFloat32();
float quant = embTable.get(blockByteOffset + 2 + j % Q8_0_BLOCK_SIZE);
state.wrapXBatch.set(b * dim + j, quant * scale);
}
}
}
default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType());
}

plan.tornadoVMForwardBatchPrefill();
}

/**
* GPU decode forward pass (Phase 4).
*
* <p>Delegates a single-token decode step to
* {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode},
* which copies the token embedding and runs the decode + logits graphs.</p>
* <p>Copies the token embedding into device-visible state, then delegates
* graph execution to the plan.</p>
*
* @param model the LLaMA model
* @param state mutable inference state
* @param token current token id
* @param position sequence position
* @param plan the batched prefill/decode GPU plan
* @return logits array for token sampling
*/
public static FloatArray forwardTornadoVMDecode(Model model, int token, int position,
TornadoVMMasterPlanWithBatchPrefillDecode plan) {
return plan.tornadoVMForwardDecode(token, position, model);
public static FloatArray forwardTornadoVMDecode(Model model, State state, int token, int position,
TornadoVMMasterPlanBatchPrefillDecode plan) {
final Configuration config = model.configuration();
final TornadoWeights weights = (TornadoWeights) model.weights();

switch (weights.getWeightType()) {
case F16 -> {
MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment();
MemorySegment.copy(embTable, (long) token * config.dim() * Short.BYTES,
state.embeddingX.getSegment(), 0L, (long) config.dim() * Short.BYTES);
}
case Q8_0 -> {
MemorySegment embTable = weights.getTokenEmbeddingTable().asByteArray().getSegment();
int blocksPerToken = (config.dim() + Q8_0_BLOCK_SIZE - 1) / Q8_0_BLOCK_SIZE;
long bytesPerToken = (long) blocksPerToken * Q8_0_BLOCK_BYTES;
MemorySegment.copy(embTable, (long) token * bytesPerToken,
state.embeddingX.getSegment(), 0L, bytesPerToken);
}
default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType());
}

return plan.tornadoVMForwardDecode(position);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode;

import java.lang.foreign.MemorySegment;

Expand Down Expand Up @@ -131,7 +131,7 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p
*
* <p>Copies the token embedding into {@code state.embeddingX} (same as
* {@link InferenceCore#forwardTornadoVM}) then delegates to
* {@link TornadoVMMasterPlanWithPrefillDecode#tornadoVMForwardPrefill},
* {@link TornadoVMMasterPlanPrefillDecode#tornadoVMForwardPrefill},
* which executes preprocessing + layer graphs but skips the logits graph.</p>
*
* @param model the LLaMA model (must carry {@link TornadoWeights}, FP16 only)
Expand All @@ -142,7 +142,7 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p
* @throws UnsupportedOperationException if the model uses Q8_0 weights
*/
public static void forwardTornadoVMPrefill(Model model, State state, int token, int position,
TornadoVMMasterPlanWithPrefillDecode prefillPlan) {
TornadoVMMasterPlanPrefillDecode prefillPlan) {
final Configuration configuration = model.configuration();
final TornadoWeights weights = (TornadoWeights) model.weights();

Expand All @@ -153,9 +153,13 @@ public static void forwardTornadoVMPrefill(Model model, State state, int token,
MemorySegment.copy(tokenEmbeddings, (long) token * configuration.dim() * bytes,
state.embeddingX.getSegment(), 0, (long) configuration.dim() * bytes);
}
case Q8_0 -> throw new UnsupportedOperationException(
// TODO Phase 4: implement Q8_0 GPU batched prefill kernels
"GPU prefill/decode path not yet implemented for Q8_0 weights");
case Q8_0 -> {
MemorySegment tokenEmbeddings = weights.getTokenEmbeddingTable().asByteArray().getSegment();
int blocksPerToken = (configuration.dim() + 31) / 32;
long bytesPerToken = (long) blocksPerToken * 34;
MemorySegment.copy(tokenEmbeddings, (long) token * bytesPerToken,
state.embeddingX.getSegment(), 0, bytesPerToken);
}
Comment on lines +157 to +162
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe this should be a method on each own. Same for the above

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i agree, but imho this should be part of another pr where embeddings copy will be refactored as a distinct component that will cleanly facilitate dispatch across quantizations and plan types (single-token, prefill-decode, batch-prefill-decode) in a well-structured manner

default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tokenizer.Tokenizer;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -163,8 +163,8 @@ public static List<Integer> generateTokensGPULlama(
? config.contextLength() : maxTokens;
final int batchSize = TornadoVMMasterPlan.PREFILL_BATCH_SIZE;

TornadoVMMasterPlanWithBatchPrefillDecode plan =
(TornadoVMMasterPlanWithBatchPrefillDecode) tornadoVMPlan;
TornadoVMMasterPlanBatchPrefillDecode plan =
(TornadoVMMasterPlanBatchPrefillDecode) tornadoVMPlan;

List<Integer> generatedTokens = new ArrayList<>();

Expand All @@ -185,7 +185,7 @@ public static List<Integer> generateTokensGPULlama(
int chunkSize = chunkEnd - chunkStart;
int[] chunk = Arrays.copyOfRange(prefillSeq, chunkStart, chunkEnd);

InferenceCoreBatchPrefillDecode.batchForwardTornadoVMPrefill(model, chunk, pos + chunkStart, chunkSize, plan);
InferenceCoreBatchPrefillDecode.batchForwardTornadoVMPrefill(model, state, chunk, pos + chunkStart, chunkSize, plan);

if (echo) {
for (int b = 0; b < chunkSize; b++) {
Expand All @@ -203,7 +203,7 @@ public static List<Integer> generateTokensGPULlama(

// ── Decode ────────────────────────────────────────────────────────────
while (pos < actualMaxTokens) {
var logits = InferenceCoreBatchPrefillDecode.forwardTornadoVMDecode(model, currentToken, pos, plan);
var logits = InferenceCoreBatchPrefillDecode.forwardTornadoVMDecode(model, state, currentToken, pos, plan);
int nextToken = sampler.sampleToken(logits);

if (echo) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tokenizer.Tokenizer;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -135,8 +135,8 @@ public static List<Integer> generateTokensGPULlama(
int actualMaxTokens = (maxTokens < 0 || config.contextLength() < maxTokens)
? config.contextLength() : maxTokens;

TornadoVMMasterPlanWithPrefillDecode prefillPlan =
(TornadoVMMasterPlanWithPrefillDecode) tornadoVMPlan;
TornadoVMMasterPlanPrefillDecode prefillPlan =
(TornadoVMMasterPlanPrefillDecode) tornadoVMPlan;

List<Integer> generatedTokens = new ArrayList<>();

Expand Down
Loading
Loading