diff --git a/.github/workflows/build-and-run.yml b/.github/workflows/build-and-run.yml index a17c9228..46dc3bf4 100644 --- a/.github/workflows/build-and-run.yml +++ b/.github/workflows/build-and-run.yml @@ -385,6 +385,107 @@ jobs: flags="" \ prompt="Say hello" + - name: Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Prefill-Decode + env: + JAVA_TOOL_OPTIONS: >- + -Dllama.metrics.format=json + -Dllama.metrics.output=file + -Dllama.metrics.file=${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-prefill-decode.json + run: | + cd ${{ github.workspace }} + export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --${{ matrix.backend.name }} \ + --model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \ + --prompt "Say hello" \ + --with-prefill-decode + python3 scripts/write_metrics_sidecar.py \ + --out "${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-prefill-decode.meta.json" \ + backend="${{ matrix.backend.name }}" \ + task=llama-inference \ + model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \ + model=Llama-3.2-1B-Instruct \ + quantization=Q8_0 \ + configuration=prefill-decode \ + "flags=--with-prefill-decode" \ + prompt="Say hello" + + - name: Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Batch-Prefill-Decode + env: + JAVA_TOOL_OPTIONS: >- + -Dllama.metrics.format=json + -Dllama.metrics.output=file + -Dllama.metrics.file=${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-batch-prefill-decode.json + run: | + cd ${{ github.workspace }} + export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --${{ matrix.backend.name }} \ + --model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \ + --prompt "Say hello" \ + --with-prefill-decode --batch-prefill-size 32 + python3 scripts/write_metrics_sidecar.py \ + --out "${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-batch-prefill-decode.meta.json" \ + backend="${{ matrix.backend.name }}" \ + task=llama-inference \ + model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \ + model=Llama-3.2-1B-Instruct \ + quantization=Q8_0 \ + configuration=batch-prefill-decode \ + "flags=--with-prefill-decode --batch-prefill-size 32" \ + prompt="Say hello" + + # ── PTX-only: CUDA-graph variants ──────────────────────────────────────── + - name: PTX - Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Prefill-Decode-CUDA-Graphs + if: matrix.backend.name == 'ptx' + env: + JAVA_TOOL_OPTIONS: >- + -Dllama.metrics.format=json + -Dllama.metrics.output=file + -Dllama.metrics.file=${{ runner.temp }}/metrics-ptx-llama-1b-q8-prefill-decode-cuda-graphs.json + run: | + cd ${{ github.workspace }} + export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --ptx \ + --model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \ + --prompt "Say hello" \ + --with-prefill-decode \ + --cuda-graphs + python3 scripts/write_metrics_sidecar.py \ + --out "${{ runner.temp }}/metrics-ptx-llama-1b-q8-prefill-decode-cuda-graphs.meta.json" \ + backend=ptx \ + task=llama-inference \ + model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \ + model=Llama-3.2-1B-Instruct \ + quantization=Q8_0 \ + configuration=prefill-decode-cuda-graphs \ + "flags=--with-prefill-decode --cuda-graphs" \ + prompt="Say hello" + + - name: PTX - Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Batch-Prefill-Decode-CUDA-Graphs + if: matrix.backend.name == 'ptx' + env: + JAVA_TOOL_OPTIONS: >- + -Dllama.metrics.format=json + -Dllama.metrics.output=file + -Dllama.metrics.file=${{ runner.temp }}/metrics-ptx-llama-1b-q8-batch-prefill-decode-cuda-graphs.json + run: | + cd ${{ github.workspace }} + export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --ptx \ + --model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \ + --prompt "Say hello" \ + --with-prefill-decode --batch-prefill-size 32 \ + --cuda-graphs + python3 scripts/write_metrics_sidecar.py \ + --out "${{ runner.temp }}/metrics-ptx-llama-1b-q8-batch-prefill-decode-cuda-graphs.meta.json" \ + backend=ptx \ + task=llama-inference \ + model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \ + model=Llama-3.2-1B-Instruct \ + quantization=Q8_0 \ + configuration=batch-prefill-decode-cuda-graphs \ + "flags=--with-prefill-decode --batch-prefill-size 32 --cuda-graphs" \ + prompt="Say hello" + - name: Q8 - Run Qwen3-0.6B-Q8_0.gguf env: JAVA_TOOL_OPTIONS: >- diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java index 9beade35..f63f2078 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java @@ -814,7 +814,7 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType()); } - return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position); + return tornadoVMMasterPlan.tornadoVMForwardDecode(position); } } diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java index d3c74599..ec259cb4 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java @@ -3,13 +3,16 @@ import org.beehive.gpullama3.auxiliary.Parallel; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.standard.StandardWeights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; import org.beehive.gpullama3.tensor.standard.FloatTensor; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import java.lang.foreign.MemorySegment; + /** * Low-level forward passes for the batched prefill/decode inference path (Phase 3/4). * @@ -20,11 +23,10 @@ *
Delegates the full chunk to - * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill}, - * which handles embedding lookup and GPU execution internally.
+ *Copies {@code chunkSize} token embeddings into device-visible state buffers, + * then delegates graph execution to the plan.
* * @param model the LLaMA model + * @param state mutable inference state * @param tokens token ids for this chunk * @param startPos sequence position of {@code tokens[0]} * @param chunkSize number of tokens in this chunk * @param plan the batched prefill/decode GPU plan */ - public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int startPos, int chunkSize, - TornadoVMMasterPlanWithBatchPrefillDecode plan) { - plan.tornadoVMForwardBatchPrefill(tokens, startPos, model, chunkSize); + public static void batchForwardTornadoVMPrefill(Model model, State state, int[] tokens, int startPos, + int chunkSize, TornadoVMMasterPlanBatchPrefillDecode plan) { + final Configuration config = model.configuration(); + final TornadoWeights weights = (TornadoWeights) model.weights(); + + state.batchStartPosHolder.set(0, startPos); + + switch (weights.getWeightType()) { + case F16 -> { + MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + long dimBytes = (long) config.dim() * Short.BYTES; + for (int b = 0; b < chunkSize; b++) { + MemorySegment.copy(embTable, (long) tokens[b] * dimBytes, + state.embeddingXBatch.getSegment(), (long) b * dimBytes, dimBytes); + } + } + case Q8_0 -> { + var embTable = weights.getTokenEmbeddingTable().asByteArray(); + int dim = config.dim(); + int blocksPerRow = (dim + Q8_0_BLOCK_SIZE - 1) / Q8_0_BLOCK_SIZE; + for (int b = 0; b < chunkSize; b++) { + int tokenId = tokens[b]; + for (int j = 0; j < dim; j++) { + int blockByteOffset = (tokenId * blocksPerRow + j / Q8_0_BLOCK_SIZE) * Q8_0_BLOCK_BYTES; + float scale = embTable.getHalfFloat(blockByteOffset).getFloat32(); + float quant = embTable.get(blockByteOffset + 2 + j % Q8_0_BLOCK_SIZE); + state.wrapXBatch.set(b * dim + j, quant * scale); + } + } + } + default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType()); + } + + plan.tornadoVMForwardBatchPrefill(); } /** * GPU decode forward pass (Phase 4). * - *Delegates a single-token decode step to - * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode}, - * which copies the token embedding and runs the decode + logits graphs.
+ *Copies the token embedding into device-visible state, then delegates + * graph execution to the plan.
* * @param model the LLaMA model + * @param state mutable inference state * @param token current token id * @param position sequence position * @param plan the batched prefill/decode GPU plan * @return logits array for token sampling */ - public static FloatArray forwardTornadoVMDecode(Model model, int token, int position, - TornadoVMMasterPlanWithBatchPrefillDecode plan) { - return plan.tornadoVMForwardDecode(token, position, model); + public static FloatArray forwardTornadoVMDecode(Model model, State state, int token, int position, + TornadoVMMasterPlanBatchPrefillDecode plan) { + final Configuration config = model.configuration(); + final TornadoWeights weights = (TornadoWeights) model.weights(); + + switch (weights.getWeightType()) { + case F16 -> { + MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + MemorySegment.copy(embTable, (long) token * config.dim() * Short.BYTES, + state.embeddingX.getSegment(), 0L, (long) config.dim() * Short.BYTES); + } + case Q8_0 -> { + MemorySegment embTable = weights.getTokenEmbeddingTable().asByteArray().getSegment(); + int blocksPerToken = (config.dim() + Q8_0_BLOCK_SIZE - 1) / Q8_0_BLOCK_SIZE; + long bytesPerToken = (long) blocksPerToken * Q8_0_BLOCK_BYTES; + MemorySegment.copy(embTable, (long) token * bytesPerToken, + state.embeddingX.getSegment(), 0L, bytesPerToken); + } + default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType()); + } + + return plan.tornadoVMForwardDecode(position); } } diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java index 91bb6f79..9c812b10 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java @@ -7,7 +7,7 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tensor.standard.FloatTensor; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode; import java.lang.foreign.MemorySegment; @@ -131,7 +131,7 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p * *Copies the token embedding into {@code state.embeddingX} (same as * {@link InferenceCore#forwardTornadoVM}) then delegates to - * {@link TornadoVMMasterPlanWithPrefillDecode#tornadoVMForwardPrefill}, + * {@link TornadoVMMasterPlanPrefillDecode#tornadoVMForwardPrefill}, * which executes preprocessing + layer graphs but skips the logits graph.
* * @param model the LLaMA model (must carry {@link TornadoWeights}, FP16 only) @@ -142,7 +142,7 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p * @throws UnsupportedOperationException if the model uses Q8_0 weights */ public static void forwardTornadoVMPrefill(Model model, State state, int token, int position, - TornadoVMMasterPlanWithPrefillDecode prefillPlan) { + TornadoVMMasterPlanPrefillDecode prefillPlan) { final Configuration configuration = model.configuration(); final TornadoWeights weights = (TornadoWeights) model.weights(); @@ -153,9 +153,13 @@ public static void forwardTornadoVMPrefill(Model model, State state, int token, MemorySegment.copy(tokenEmbeddings, (long) token * configuration.dim() * bytes, state.embeddingX.getSegment(), 0, (long) configuration.dim() * bytes); } - case Q8_0 -> throw new UnsupportedOperationException( - // TODO Phase 4: implement Q8_0 GPU batched prefill kernels - "GPU prefill/decode path not yet implemented for Q8_0 weights"); + case Q8_0 -> { + MemorySegment tokenEmbeddings = weights.getTokenEmbeddingTable().asByteArray().getSegment(); + int blocksPerToken = (configuration.dim() + 31) / 32; + long bytesPerToken = (long) blocksPerToken * 34; + MemorySegment.copy(tokenEmbeddings, (long) token * bytesPerToken, + state.embeddingX.getSegment(), 0, bytesPerToken); + } default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType()); } diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java index 06dbe256..774ae6e1 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java @@ -7,7 +7,7 @@ import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode; import java.util.ArrayList; import java.util.Arrays; @@ -163,8 +163,8 @@ public static ListThree concrete implementations exist:
*The {@link #initializeTornadoVMPlan} factory selects the implementation based on * {@code llama.withPrefillDecode} and {@code llama.prefillBatchSize}:
*When {@code llama.withPrefillDecode=true} and {@code llama.prefillBatchSize > 1}, - * a {@link TornadoVMMasterPlanWithBatchPrefillDecode} is returned. - * Otherwise a {@link TornadoVMMasterPlanStandard} is returned (used for the baseline + * a {@link TornadoVMMasterPlanBatchPrefillDecode} is returned. + * Otherwise a {@link TornadoVMMasterPlanSingleToken} is returned (used for the baseline * path and the sequential prefill/decode path when batch size is 1).
* * @param state the model state @@ -56,13 +56,13 @@ static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { if (WITH_PREFILL_DECODE && PREFILL_BATCH_SIZE > 1) { // GPU path with batched prefill/decode - plan = new TornadoVMMasterPlanWithBatchPrefillDecode(state, model); + plan = new TornadoVMMasterPlanBatchPrefillDecode(state, model); } else if (WITH_PREFILL_DECODE) { // GPU path with simple prefill/decode - plan = new TornadoVMMasterPlanWithPrefillDecode(state, model); + plan = new TornadoVMMasterPlanPrefillDecode(state, model); } else { // GPU path with no prefill/decode - plan = new TornadoVMMasterPlanStandard(state, model); + plan = new TornadoVMMasterPlanSingleToken(state, model); } model.setTornadoVMPlan(plan); return plan; @@ -76,7 +76,7 @@ static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { void forceCopyInReadOnlyData(); - FloatArray tornadoVMForwardExecuteLayered(int position); + FloatArray tornadoVMForwardDecode(int position); /** Releases all device memory held by this plan. */ void freeTornadoExecutionPlan(); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefillDecode.java new file mode 100644 index 00000000..593621a0 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefillDecode.java @@ -0,0 +1,153 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.auxiliary.RunMetrics; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tornadovm.plan.BatchPrefillDecodeForwardPlan; +import org.beehive.gpullama3.tornadovm.plan.ForwardPlanFactory; +import org.beehive.gpullama3.tornadovm.plan.layout.BatchPrefillDecodeForwardTaskGraphLayout; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +/** + * GPU execution plan for batched prefill + single-token decode. + * + *A single {@link TornadoExecutionPlan} holds all TaskGraphs for + * batched prefill and single-token decode phases:
+ * + *TaskGraph layout (2N+3 TaskGraphs total):
+ *+ * [0] batchPrefillActivation B×dim embeddings → FP32 wrapXBatch + * [1..N] batch-prefill layers B tokens, all transformer ops + * [N+1] decodeActivation single-token embedding → FP32 + KV-cache pass-through + * [N+2..2N+1] decode layers single-token, standard kernels + * [2N+2] logits + *+ */ +public class TornadoVMMasterPlanBatchPrefillDecode implements TornadoVMMasterPlan { + + private final State state; + private final Model model; + private final Configuration config; + + BatchPrefillDecodeForwardPlan batchPrefillDecodeForwardPlan; + BatchPrefillDecodeForwardTaskGraphLayout taskGraphLayout; + public TornadoExecutionPlan executionPlan; + + // ── Construction ───────────────────────────────────────────────────────── + TornadoVMMasterPlanBatchPrefillDecode(State initialState, Model model) { + if (ENABLE_TORNADOVM_INIT_TIME) { + System.err.println("\nStarting TornadoVM initialization..."); + } + + this.state = initialState; + this.model = model; + this.config = model.configuration(); + + long startTime = System.nanoTime(); + this.executionPlan = createExecutionPlan(); + long planCreationTime = System.nanoTime(); + + if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph(); + executionPlan.withPreCompilation(); + long warmupTime = System.nanoTime(); + + forceCopyInReadOnlyData(); + long copyTime = System.nanoTime(); + + RunMetrics.setTornadoMetrics(planCreationTime - startTime, warmupTime - planCreationTime, copyTime - warmupTime); + } + + // ── Plan construction ───────────────────────────────────────────────────── + + @Override + public TornadoExecutionPlan createExecutionPlan() { + GGMLType weightType = model.weights().getWeightType(); + this.batchPrefillDecodeForwardPlan = + ForwardPlanFactory.createBatchPrefillDecode(weightType, state, model); + this.taskGraphLayout = batchPrefillDecodeForwardPlan.getTaskGraphLayout(); + var taskGraphs = batchPrefillDecodeForwardPlan.getImmutableTaskGraphs(); + return new TornadoExecutionPlan(taskGraphs.toArray(new ImmutableTaskGraph[0])); + } + + // ── Initialisation ──────────────────────────────────────────────────────── + + @Override + public void forceCopyInReadOnlyData() { + state.wrapX.clear(); + state.positionHolder.init(0); + state.wrapXBatch.clear(); + state.batchStartPosHolder.init(0); + + for (int i = 0; i <= taskGraphLayout.logitsIdx(); i++) { + var g = executionPlan.withGraph(i) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) g.withCUDAGraph(); + g.execute(); + } + } + + // ── Forward passes ──────────────────────────────────────────────────────── + + /** + * Batch prefill: runs graphs 0..N (activation + N layers), skips logits. + * Caller is responsible for copying batch embeddings into state before calling this. + */ + public void tornadoVMForwardBatchPrefill() { + var batchAct = executionPlan.withGraph(taskGraphLayout.batchActivationIdx()) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) batchAct.withCUDAGraph(); + batchAct.execute(); + + for (int l = 0; l < config.numberOfLayers(); l++) { + var batchLayer = executionPlan.withGraph(taskGraphLayout.batchLayerIdx(l)) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) batchLayer.withCUDAGraph(); + batchLayer.execute(); + } + } + + /** + * Single-token decode: runs graphs N+1..2N+2 (activation + N layers + logits). + * Caller is responsible for copying the decode embedding into state before calling this. + * + * @param position sequence position + * @return logits array for sampling + */ + @Override + public FloatArray tornadoVMForwardDecode(int position) { + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + var decodeAct = executionPlan.withGraph(taskGraphLayout.decodeActivationIdx()) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) decodeAct.withCUDAGraph(); + decodeAct.execute(); + + for (int l = 0; l < config.numberOfLayers(); l++) { + var decodeLayer = executionPlan.withGraph(taskGraphLayout.decodeLayerIdx(l)) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) decodeLayer.withCUDAGraph(); + decodeLayer.execute(); + } + + state.tempLogits.clear(); + state.wrapLogits.clear(); + + var logits = executionPlan.withGraph(taskGraphLayout.logitsIdx()) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) logits.withCUDAGraph(); + logits.execute(); + + return state.wrapLogits; + } + + @Override + public void freeTornadoExecutionPlan() { + executionPlan.freeDeviceMemory(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanPrefillDecode.java new file mode 100644 index 00000000..b8b93187 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanPrefillDecode.java @@ -0,0 +1,162 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.auxiliary.RunMetrics; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tornadovm.plan.ForwardPlanFactory; +import org.beehive.gpullama3.tornadovm.plan.PrefillDecodeForwardPlan; +import org.beehive.gpullama3.tornadovm.plan.layout.PrefillDecodeForwardTaskGraphLayout; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +/** + * GPU execution plan for sequential (single-token) prefill/decode separation. + * + *
A single {@link TornadoExecutionPlan} holds all graphs so that the KV cache + * ({@code wrapKeyCache}, {@code wrapValueCache}) is allocated once and remains on + * device across both phases. Prefill and decode reuse the same N layer graphs; + * only the logits graph is skipped during prefill.
+ * + *Graph layout (N+2 graphs total):
+ *+ * [0] decodeActivation single-token FP16 → FP32; KV-cache allocated on first execution + * [1..N] layer_0..layer_N-1 transformer layers (attention + FFN) + * [N+1] logits final RMSNorm + wcls matmul + *+ * + *
Two forward passes:
+ *A single {@link TornadoExecutionPlan} holds all {@link TaskGraph} for - * batched prefill and single-token decode phases with the following structure:
. - * - *TaskGraph layout (2N+3 TaskGraphs total):
- *- * [0] prefill batch activation B×dim FP16 → FP32 - * [1..N] prefill batch layer graphs B tokens, all transformer ops - * [N+1] decode activation single-token FP16 → FP32 + KV-cache pass-through - * [N+2..2N+1] decode layer graphs single-token, standard kernels - * [2N+2] logits graph - *- * - *
- * Incorporating cross-phase {@link TaskGraph}s withing a single {@link TornadoExecutionPlan} - * is necessary to enable KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) sharing - * across prefill and decode phases. The KV cache pointers are chained across {@link TaskGraph}s - * via the {@code persistOnDevice}/{@code consumeFromDevice} API within the {@link TornadoExecutionPlan}. - *
- * - *KV cache pointer chain across phases:
- *- * batchLayer[N-1] --persistOnDevice(wrapKeyCache)-→ - * decodeActivation --consumeFromDevice(wrapKeyCache)-→ (pass-through) - * decodeLayer[0] --consumeFromDevice(wrapKeyCache)-→ (used by attention) - *- */ -public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMasterPlan { - - private final LlamaState state; - private final Model model; - private final LlamaConfiguration config; - private final int batchSize; - private final int N; // numberOfLayers - private final TornadoExecutionPlan executionPlan; - private final GridScheduler gridScheduler; - - // ── Graph-index helpers ─────────────────────────────────────────────────── - private int batchActivationIdx() { return 0; } - private int batchLayerIdx(int i) { return 1 + i; } - private int decodeActivationIdx() { return N + 1; } - private int decodeLayerIdx(int i) { return N + 2 + i; } - private int logitsIdx() { return 2 * N + 2; } - - // ── Construction ───────────────────────────────────────────────────────── - TornadoVMMasterPlanWithBatchPrefillDecode(State initialState, Model model) { - if (ENABLE_TORNADOVM_INIT_TIME) { - System.err.println("\nStarting TornadoVM initialization..."); - } - - this.state = (LlamaState) initialState; // only LlamaFP16 supports batched prefill for now - this.model = model; - this.config = (LlamaConfiguration) model.configuration(); - this.batchSize = PREFILL_BATCH_SIZE; - this.N = config.numberOfLayers(); - this.gridScheduler = new GridScheduler(); - - long startTime = System.nanoTime(); - this.executionPlan = createExecutionPlan(); - long planCreationTime = System.nanoTime(); - - if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph(); - executionPlan.withPreCompilation(); - long warmupTime = System.nanoTime(); - - forceCopyInReadOnlyData(); - long copyTime = System.nanoTime(); - - RunMetrics.setTornadoMetrics(planCreationTime - startTime, warmupTime - planCreationTime, copyTime - warmupTime); - } - - // ── Batch Prefill Activation graphs ───────────────────────────────────────────────────── - - /** Graph 0: B×dim FP16 embeddings → FP32 wrapXBatch. */ - private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) { - return new TaskGraph("prefillActivation") - .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch) - .task("updateX", TransformerComputeKernels::convertFP16toFP32, - ctx, state.embeddingXBatch, state.wrapXBatch) - .persistOnDevice(state.wrapXBatch); - } - - /** - * Graph N+1: single-token FP16 → FP32. - * - *
Receives the KV-cache device pointer from batch layer N via - * {@code consumeFromDevice}, then re-emits it via {@code persistOnDevice} so - * that {@code updatePersistedObjectState()} can propagate it to decode layer 0. - * Both halves of the chain are required; without the re-persist the pointer is - * not forwarded in interpreter (non-CUDA-graph) mode.
- */ - private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID) { - return new TaskGraph("decodeActivation") - .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache) // KV pass-through - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", - TransformerComputeKernels::convertFP16toFP32, - ctx, (HalfFloatArray) state.embeddingX, state.wrapX) - // wrapX persisted for decode layer 0; wrapKeyCache/wrapValueCache - // re-persisted so updatePersistedObjectState() propagates the device - // pointer to decode layer 0's consumeFromDevice without CUDA graphs. - .persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache); - } - - /** - * Creates the {@link TornadoExecutionPlan} for forward pass with *prefill in batches and separated decode*. - * - * TODO: support Q8_0 weights - * To implement this, consult how {@link TornadoVMMasterPlanStandard} uses the {@link QuantizationPlannerFactory} - */ - @Override - public TornadoExecutionPlan createExecutionPlan() { - GGMLType weightType = model.weights().getWeightType(); - switch (weightType) { - case F16 -> { /* supported — continue below */ } - case Q8_0 -> throw new UnsupportedOperationException( - "Batched prefill/decode GPU path not yet implemented for Q8_0 weights"); - default -> throw new UnsupportedOperationException( - "Batched prefill/decode GPU path not supported for weight type: " + weightType); - } - - LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); - SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); - - ListA single {@link TornadoExecutionPlan} holds all graphs so that the KV cache - * ({@code wrapKeyCache}, {@code wrapValueCache}) is allocated once and remains on - * device across both phases. Prefill and decode reuse the same N layer graphs; - * only the logits graph is skipped during prefill.
- * - *Graph layout (N+2 graphs total):
- *- * [0] decodeActivation single-token FP16 → FP32; KV-cache allocated on first execution - * [1..N] layer_0..layer_N-1 transformer layers (attention + FFN) - * [N+1] logits final RMSNorm + wcls matmul - *- * - *
Two forward passes:
- *Outputs {@code wrapX} (FP32 hidden state) and persists it on device so that - * decode layer 0 can pick it up via {@code consumeFromDevice("decodeActivation", wrapX)}. - * The KV cache is not managed here — it is allocated on the first forward pass - * by decode layer 0 via {@code FIRST_EXECUTION}.
- */ - private TaskGraph buildActivationGraph(KernelContext ctx) { - return new TaskGraph("decodeActivation") - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", TransformerComputeKernels::convertFP16toFP32, - ctx, (HalfFloatArray) state.embeddingX, state.wrapX) - .persistOnDevice(state.wrapX); - } - - // ── Plan construction ───────────────────────────────────────────────────── - /** - * Creates the {@link TornadoExecutionPlan} for forward pass with *prefill/decode separation*. - * Prefill is token-by-token but does not compute logits. - * - * TODO: support Q8_0 weights - * To implement this, consult how {@link TornadoVMMasterPlanStandard} uses the {@link QuantizationPlannerFactory} - */ - @Override - public TornadoExecutionPlan createExecutionPlan() { - GGMLType weightType = model.weights().getWeightType(); - switch (weightType) { - case F16 -> { /* supported — continue below */ } - case Q8_0 -> throw new UnsupportedOperationException( - "Prefill/decode GPU path not yet implemented for Q8_0 weights"); - default -> throw new UnsupportedOperationException( - "Prefill/decode GPU path not supported for weight type: " + weightType); - } - - LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); - SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); - - ListThese kernels are meant to be registered in {@link TornadoVMMasterPlanWithPrefillDecode} + *
These kernels are meant to be registered in {@link TornadoVMMasterPlanBatchPrefillDecode} * batch task graphs; they are NOT invoked directly.
*/ public final class TransformerBatchPrefillKernels { @@ -458,4 +459,229 @@ public static void batchedFusedRmsNormFFNGateUp(KernelContext context, wrapHbBatch.set(batchIdx * hiddenDim + rowIdx, silu * result3); } } + + // ── Q8_0 Batch Kernels ─────────────────────────────────────────────────── + + /** + * No-op kernel for Q8_0 batch activation graph. + * The host fills wrapXBatch with dequantized FP32 embeddings before execution. + * Worker: 1 global thread. + */ + public static void batchPassthrough(KernelContext context, FloatArray wrapXBatch) { + if (context.globalIdx == 0) { + wrapXBatch.set(0, wrapXBatch.get(0)); + } + } + + /** + * Applies RMS normalization to FP32 — Q8_0 variant. + * Writes normalized FP32 to wrapXbBatch (reused as xb intermediate before QKV). + * Worker: B*dim global threads, localSize=256. + */ + public static void batchedRmsApplyFP32(KernelContext context, + FloatArray wrapXbBatch, + FloatArray wrapXBatch, + FloatArray rmsWeights, + FloatArray attnScaleBatch, + int dim) { + int gid = context.globalIdx; + int b = gid / dim; + int i = gid % dim; + wrapXbBatch.set(gid, rmsWeights.get(i) * attnScaleBatch.get(b) * wrapXBatch.get(gid)); + } + + /** + * Fused batched QKV projection with Q8_0 weight dequantization. + * Input wrapXbBatch is FP32 (written by batchedRmsApplyFP32). + * groupIdx = batchIdx * (dim + 2*kvDim) + rowIdx. + * Worker: B*(dim+2*kvDim) workgroups × localWorkGroupSize threads. + */ + public static void batchedFusedQKVMatmulQ8(KernelContext context, + FloatArray wrapXbBatch, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray wrapVBatch, + ByteArray wq, + ByteArray wk, + ByteArray wv, + int dim, int kvDim, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int totalRows = dim + 2 * kvDim; + int batchIdx = groupId / totalRows; + int rowIdx = groupId % totalRows; + int inputOff = batchIdx * dim; + + int blockSize = 32; + int Q8_0_BLOCK_BYTES = 34; + int blocksPerRow = (dim + blockSize - 1) / blockSize; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowIdx < dim) { + int rowBlockOffset = rowIdx * blocksPerRow; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES; + float scale = wq.getHalfFloat(blockByteOffset).getFloat32(); + float quant = wq.get(blockByteOffset + 2 + j % blockSize); + partial += quant * scale * wrapXbBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapQBatch.set(batchIdx * dim + rowIdx, localSum[0]); + + } else if (rowIdx < dim + kvDim) { + int kRow = rowIdx - dim; + int rowBlockOffset = kRow * blocksPerRow; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES; + float scale = wk.getHalfFloat(blockByteOffset).getFloat32(); + float quant = wk.get(blockByteOffset + 2 + j % blockSize); + partial += quant * scale * wrapXbBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapKBatch.set(batchIdx * kvDim + kRow, localSum[0]); + + } else { + int vRow = rowIdx - dim - kvDim; + int rowBlockOffset = vRow * blocksPerRow; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES; + float scale = wv.getHalfFloat(blockByteOffset).getFloat32(); + float quant = wv.get(blockByteOffset + 2 + j % blockSize); + partial += quant * scale * wrapXbBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapVBatch.set(batchIdx * kvDim + vRow, localSum[0]); + } + } + + /** + * Batched matrix-vector multiply with residual add (Q8_0 weights). + * Used for attention output (Wo) and FFN down (W2) projections. + * groupIdx = batchIdx * d + rowIdx. + * Worker: B*d workgroups × localWorkGroupSize threads. + */ + public static void batchedMatVecWithResidualQ8(KernelContext context, + FloatArray inputBatch, + FloatArray outputBatch, + ByteArray w, + int n, int d, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int batchIdx = groupId / d; + int rowIdx = groupId % d; + + int blockSize = 32; + int Q8_0_BLOCK_BYTES = 34; + int blocksPerRow = (n + blockSize - 1) / blockSize; + int rowBlockOffset = rowIdx * blocksPerRow; + int inputOff = batchIdx * n; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + float partial = 0.0f; + for (int j = localId; j < n; j += localWorkGroupSize) { + int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES; + float scale = w.getHalfFloat(blockByteOffset).getFloat32(); + float quant = w.get(blockByteOffset + 2 + j % blockSize); + partial += quant * scale * inputBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) { + int outIdx = batchIdx * d + rowIdx; + outputBatch.set(outIdx, outputBatch.get(outIdx) + localSum[0]); + } + } + + /** + * Batched fused RMS-apply + W1/W3 gate-up projections + SiLU + GLU (Q8_0 weights). + * groupIdx = batchIdx * hiddenDim + rowIdx. + * Worker: B*hiddenDim workgroups × localWorkGroupSize threads. + */ + public static void batchedFusedRmsNormFFNGateUpQ8(KernelContext context, + FloatArray wrapXBatch, + FloatArray wrapHbBatch, + FloatArray rmsFFNWeights, + FloatArray ffnScaleBatch, + ByteArray w1, + ByteArray w3, + int dim, int hiddenDim, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int batchIdx = groupId / hiddenDim; + int rowIdx = groupId % hiddenDim; + + float scale = ffnScaleBatch.get(batchIdx); + int inputOff = batchIdx * dim; + + int blockSize = 32; + int Q8_0_BLOCK_BYTES = 34; + int blocksPerRow = (dim + blockSize - 1) / blockSize; + int rowBlockOffset = rowIdx * blocksPerRow; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + float sum1 = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES; + float w1Scale = w1.getHalfFloat(blockByteOffset).getFloat32(); + float w1Quant = w1.get(blockByteOffset + 2 + j % blockSize); + float normed = rmsFFNWeights.get(j) * scale * wrapXBatch.get(inputOff + j); + sum1 += w1Quant * w1Scale * normed; + } + localSum[localId] = sum1; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + float result1 = localSum[0]; + + float sum3 = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES; + float w3Scale = w3.getHalfFloat(blockByteOffset).getFloat32(); + float w3Quant = w3.get(blockByteOffset + 2 + j % blockSize); + float normed = rmsFFNWeights.get(j) * scale * wrapXBatch.get(inputOff + j); + sum3 += w3Quant * w3Scale * normed; + } + localSum[localId] = sum3; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + float result3 = localSum[0]; + + if (localId == 0) { + float silu = result1 / (1.0f + TornadoMath.exp(-result1)); + wrapHbBatch.set(batchIdx * hiddenDim + rowIdx, silu * result3); + } + } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java deleted file mode 100644 index 9e211051..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java +++ /dev/null @@ -1,14 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner; - -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; - -import java.util.List; - -public interface GenericLayerPlanner { - - List- * The factory follows a routing logic: - *
- * Examples: - *
- * Each component is represented as an {@link ImmutableTaskGraph}, along with a - * corresponding {@link GridScheduler} configuration that defines how tasks are - * mapped on the GPU. - *
- * This method assembles all components into a unified execution pipeline and - * caches the resulting task graphs and scheduler for reuse across inference runs. - */ - protected final void createTornadoInferencePlan() { - ListImplemented by {@link Activation} and custom activation wrappers used by + * {@link org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents}.
+ */ +public interface ActivationTaskGraph { + ImmutableTaskGraph getImmutableTaskGraph(); + GridScheduler updateGridScheduler(GridScheduler scheduler); +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/BatchPrefillTransformerLayerTaskGraphs.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/BatchPrefillTransformerLayerTaskGraphs.java new file mode 100644 index 00000000..0e7a762f --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/BatchPrefillTransformerLayerTaskGraphs.java @@ -0,0 +1,17 @@ +package org.beehive.gpullama3.tornadovm.layers; + +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; + +import java.util.List; + +/** + * Interface for a group of N batched-prefill transformer-layer {@link uk.ac.manchester.tornado.api.TaskGraph}. + * + *Implemented by {@code LlamaFP16LayersBatchPrefill} and {@code LlamaQ8_0LayersBatchPrefill}.
+ */ +public interface BatchPrefillTransformerLayerTaskGraphs { + ListImplemented by {@link AbstractTransformerLayerTaskGraphs} and its subclasses.
+ */ +public interface TransformerLayerTaskGraphs { + ListThe no-arg form is safe in CUDA-graph mode (device pointers are frozen at capture time)
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
index 1858408e..1295db09 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
@@ -7,15 +7,15 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-public class LogitsFP16Layer extends AbstractLogitsLayer {
+public class LogitsFP16Layer extends AbstractLogitsTaskGraph {
public LogitsFP16Layer(String name, State state, Weights weights, Configuration config,
String lastTaskGraphID, SchedulerType schedulerType) {
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java
index 54ec9641..2e3a6fe6 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java
@@ -8,7 +8,7 @@
import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java
index 499fc176..cd03b2ba 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java
@@ -5,15 +5,15 @@
import org.beehive.gpullama3.model.mistral.MistralConfiguration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-public class MistralFP16FFNLayers extends AbstractFFNLayers Overrides data-transfer declarations so that all cross-graph boundaries use
* the explicit-source form of {@code consumeFromDevice}. The no-arg form (used by
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java
index 2f5bac64..4a74e69b 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java
@@ -3,14 +3,14 @@
import org.beehive.gpullama3.inference.state.LlamaState;
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
- * Decode FFN layers for the single-token prefill/decode plan
- * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode}).
+ * Decode transformer-layer TaskGraphs for the single-token prefill/decode plan
+ * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode}).
*
* Combines two concerns: Extends {@link LogitsFP16Layer} with KV-cache pass-through so the device
* pointers for {@code wrapKeyCache} and {@code wrapValueCache} survive the
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java
index a44425ef..6f0b3e4b 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java
@@ -4,7 +4,8 @@
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
@@ -16,8 +17,8 @@
import java.util.stream.IntStream;
/**
- * Prefill FFN layers with batching for the unified batched prefill-decode plan
- * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}).
+ * Batched-prefill transformer-layer TaskGraphs for the unified batched prefill-decode plan
+ * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}).
*
* One {@link ImmutableTaskGraph} per transformer layer, each processing
* {@code batchSize} tokens simultaneously via {@link TransformerBatchPrefillKernels}. KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) is persisted on device
* after every layer so the subsequent single-token decode layers can consume it. Layer 0 consumes the KV cache from device (passed through by the decode activation
+ * graph, which relays it from the last batch prefill layer). No FIRST_EXECUTION allocation
+ * for the KV cache — it was already allocated in the batch prefill phase. Layer 0 delegates to {@link LlamaQ8_0FFNLayers#configureLayerDataTransfers} which
+ * includes {@code FIRST_EXECUTION} for {@code wrapKeyCache} and {@code wrapValueCache},
+ * allocating the KV cache on the very first forward pass. Layers 1+ use explicit
+ * predecessor names for all consumed objects, required by TornadoVM's interpreter mode. Extends {@link LogitsQ8_0Layer} with KV-cache pass-through so the device pointers for
+ * {@code wrapKeyCache} and {@code wrapValueCache} survive the logits → decode-activation
+ * boundary between decode tokens. Without the pass-through, the KV-cache pointer is absent
+ * from the logits persisted set, cleared to null, and the first decode layer crashes with
+ * an NPE in {@code executeAlloc}. Mirrors {@link org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill}
+ * but uses Q8_0 kernels with inline dequantization. Key differences from the FP16 path: Graph layout: During batch prefill, the master plan executes graphs 0..N.
+ * During decode, graphs N+1..2N+2 run. Subclasses assemble the {@link ImmutableTaskGraph} list and {@link GridScheduler}
+ * in their constructors by calling {@link #setGraphs}, then expose the results
+ * via {@link #getImmutableTaskGraphs} and {@link #getGridScheduler}. Dispatches across three axes in order:
+ * Use the typed convenience methods when the execution mode is known at the call site: Graph layout: During prefill, the master plan executes graphs 0..N (skipping logits).
+ * During decode, all N+2 graphs run. Graph layout:
+ * Batch-prefill/decode inference with TornadoVM is implemented by {@link TornadoVMMasterPlanBatchPrefillDecode}.
+ * It employs a {@link BatchPrefillDecodeForwardPlan} instance to represent the complete
+ * batch-prefill/decode forward operation as a chain of distinct TornadoVM TaskGraphs.
+ * The components of this chain are represented by the following components:
+ *
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java
index 350e6760..7ee529d0 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java
@@ -3,13 +3,13 @@
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.model.Configuration;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
import uk.ac.manchester.tornado.api.TaskGraph;
/**
* Logits layer of the unified batched prefill-decode plan
- * * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}).
+ * * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}).
*
*
+ *
+ */
+public class LlamaQ8_0LayersBatchPrefill implements BatchPrefillTransformerLayerTaskGraphs {
+
+ static final int LOCAL_WORK_GROUP_SIZE = 32;
+
+ private final LlamaState state;
+ private final LlamaTornadoWeights weights;
+ private final LlamaConfiguration config;
+ private final KernelContext context = new KernelContext();
+ private final int batchSize;
+ private final List
+ * [0] batch activation ← batchPrefillActivation(int)
+ * [1..N] batch layers ← batchPrefillTransformerLayers(int)
+ * [N+1] decode activation ← batchDecodeActivation(String)
+ * [N+2..2N+1] decode layers ← batchDecodeTransformerLayers()
+ * [2N+2] logits ← decodeLogits(String)
+ *
+ *
+ *
+ *
+ */
+public enum ExecutionMode {
+ STANDARD,
+ PREFILL_DECODE,
+ BATCH_PREFILL_DECODE
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlan.java
new file mode 100644
index 00000000..5ccdb4c0
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlan.java
@@ -0,0 +1,32 @@
+package org.beehive.gpullama3.tornadovm.plan;
+
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+
+import java.util.List;
+
+/**
+ * Abstract base for GPU forward-pass execution plans.
+ *
+ *
+ *
+ *
+ *
+ *
+ */
+public class ForwardPlanFactory {
+
+ private ForwardPlanFactory() {}
+
+ // ── Typed public API ──────────────────────────────────────────────────────
+
+ public static SingleTokenForwardPlan createSingleToken(GGMLType quantization, State state, Model model) {
+ ForwardPlan plan = create(quantization, ExecutionMode.STANDARD, state, model);
+ if (plan instanceof SingleTokenForwardPlan singleToken) return singleToken;
+ throw new IllegalStateException("Expected SingleTokenForwardPlan for STANDARD mode but got " + plan.getClass().getSimpleName());
+ }
+
+ public static PrefillDecodeForwardPlan createPrefillDecode(GGMLType quantization, State state, Model model) {
+ ForwardPlan plan = create(quantization, ExecutionMode.PREFILL_DECODE, state, model);
+ if (plan instanceof PrefillDecodeForwardPlan prefillDecode) return prefillDecode;
+ throw new IllegalStateException("Expected PrefillDecodeForwardPlan for PREFILL_DECODE mode but got " + plan.getClass().getSimpleName());
+ }
+
+ public static BatchPrefillDecodeForwardPlan createBatchPrefillDecode(GGMLType quantization, State state, Model model) {
+ ForwardPlan plan = create(quantization, ExecutionMode.BATCH_PREFILL_DECODE, state, model);
+ if (plan instanceof BatchPrefillDecodeForwardPlan batchPrefillDecode) return batchPrefillDecode;
+ throw new IllegalStateException("Expected BatchPrefillDecodeForwardPlan for BATCH_PREFILL_DECODE mode but got " + plan.getClass().getSimpleName());
+ }
+
+ // ── Generic dispatch ──────────────────────────────────────────────────────
+
+ static ForwardPlan create(GGMLType quantization, ExecutionMode mode, State state, Model model) {
+ return switch (quantization) {
+ case F16 -> createFP16Plan(mode, state, model);
+ case Q8_0 -> createQ8_0Plan(mode, state, model);
+ case F32 -> throw new UnsupportedOperationException("F32 plans not yet implemented");
+ case Q4_0 -> throw new UnsupportedOperationException("Q4_0 plans not yet implemented");
+ default -> throw new UnsupportedOperationException("Quantization not supported: " + quantization);
+ };
+ }
+
+ // ── FP16 branch ───────────────────────────────────────────────────────────
+
+ private static ForwardPlan createFP16Plan(ExecutionMode mode, State state, Model model) {
+ return switch (model.getModelType()) {
+ case LLAMA_3 -> createLlamaFP16Plan(mode, (LlamaState) state, model);
+ case MISTRAL -> createMistralFP16Plan(mode, (LlamaState) state, model);
+ case DEVSTRAL_2 -> createDevstralFP16Plan(mode, (DevstralState) state, model);
+ case QWEN_2 -> createQwen2FP16Plan(mode, (Qwen2State) state, model);
+ case QWEN_3 -> createQwen3FP16Plan(mode, (Qwen3State) state, model);
+ case PHI_3 -> createPhi3FP16Plan(mode, (Phi3State) state, model);
+ case GRANITE -> createGraniteFP16Plan(mode, (GraniteState) state, model);
+ case DEEPSEEK_R1_DISTILL_QWEN -> createQwen2FP16Plan(mode, (Qwen2State) state, model);
+ default -> throw new UnsupportedOperationException("F16 not supported for model: " + model.getModelType());
+ };
+ }
+
+ // ── Q8_0 branch ───────────────────────────────────────────────────────────
+
+ private static ForwardPlan createQ8_0Plan(ExecutionMode mode, State state, Model model) {
+ return switch (model.getModelType()) {
+ case LLAMA_3 -> createLlamaQ8_0Plan(mode, (LlamaState) state, model);
+ case MISTRAL -> createMistralQ8_0Plan(mode, (LlamaState) state, model);
+ case DEVSTRAL_2 -> createDevstralQ8_0Plan(mode, (DevstralState) state, model);
+ case QWEN_2 -> createQwen2Q8_0Plan(mode, (Qwen2State) state, model);
+ case QWEN_3 -> createQwen3Q8_0Plan(mode, (Qwen3State) state, model);
+ case PHI_3 -> createPhi3Q8_0Plan(mode, (Phi3State) state, model);
+ case GRANITE -> createGraniteQ8_0Plan(mode, (GraniteState) state, model);
+ case DEEPSEEK_R1_DISTILL_QWEN -> createQwen2Q8_0Plan(mode, (Qwen2State) state, model);
+ default -> throw new UnsupportedOperationException("Q8_0 not supported for model: " + model.getModelType());
+ };
+ }
+
+ // ── Model+quant helpers — Llama (all 3 modes supported) ──────────────────
+
+ private static ForwardPlan createLlamaFP16Plan(ExecutionMode mode, LlamaState state, Model model) {
+ BatchPrefillDecodeForwardPlanComponents components = new LlamaFP16PlanComponents(state, model);
+ return switch (mode) {
+ case STANDARD -> new SingleTokenForwardPlan(model, components);
+ case PREFILL_DECODE -> new PrefillDecodeForwardPlan(model, components);
+ case BATCH_PREFILL_DECODE -> new BatchPrefillDecodeForwardPlan(model, components, TornadoVMMasterPlan.PREFILL_BATCH_SIZE);
+ };
+ }
+
+ private static ForwardPlan createLlamaQ8_0Plan(ExecutionMode mode, LlamaState state, Model model) {
+ BatchPrefillDecodeForwardPlanComponents components = new LlamaQ8_0PlanComponents(state, model);
+ return switch (mode) {
+ case STANDARD -> new SingleTokenForwardPlan(model, components);
+ case PREFILL_DECODE -> new PrefillDecodeForwardPlan(model, components);
+ case BATCH_PREFILL_DECODE -> new BatchPrefillDecodeForwardPlan(model, components, TornadoVMMasterPlan.PREFILL_BATCH_SIZE);
+ };
+ }
+
+ // ── Model+quant helpers — STANDARD only ──────────────────────────────────
+
+ private static ForwardPlan createMistralFP16Plan(ExecutionMode mode, LlamaState state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for MISTRAL + F16");
+ return new SingleTokenForwardPlan(model, new MistralFP16PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createMistralQ8_0Plan(ExecutionMode mode, LlamaState state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for MISTRAL + Q8_0");
+ return new SingleTokenForwardPlan(model, new MistralQ8_0PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createDevstralFP16Plan(ExecutionMode mode, DevstralState state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for DEVSTRAL_2 + F16");
+ return new SingleTokenForwardPlan(model, new DevstralFP16PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createDevstralQ8_0Plan(ExecutionMode mode, DevstralState state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for DEVSTRAL_2 + Q8_0");
+ return new SingleTokenForwardPlan(model, new DevstralQ8_0PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createQwen2FP16Plan(ExecutionMode mode, Qwen2State state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for QWEN_2 + F16");
+ return new SingleTokenForwardPlan(model, new Qwen2FP16PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createQwen2Q8_0Plan(ExecutionMode mode, Qwen2State state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for QWEN_2 + Q8_0");
+ return new SingleTokenForwardPlan(model, new Qwen2Q8_0PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createQwen3FP16Plan(ExecutionMode mode, Qwen3State state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for QWEN_3 + F16");
+ return new SingleTokenForwardPlan(model, new Qwen3FP16PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createQwen3Q8_0Plan(ExecutionMode mode, Qwen3State state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for QWEN_3 + Q8_0");
+ return new SingleTokenForwardPlan(model, new Qwen3Q8_0PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createPhi3FP16Plan(ExecutionMode mode, Phi3State state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for PHI_3 + F16");
+ return new SingleTokenForwardPlan(model, new Phi3FP16PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createPhi3Q8_0Plan(ExecutionMode mode, Phi3State state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for PHI_3 + Q8_0");
+ return new SingleTokenForwardPlan(model, new Phi3Q8_0PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createGraniteFP16Plan(ExecutionMode mode, GraniteState state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for GRANITE + F16");
+ return new SingleTokenForwardPlan(model, new GraniteFP16PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createGraniteQ8_0Plan(ExecutionMode mode, GraniteState state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for GRANITE + Q8_0");
+ return new SingleTokenForwardPlan(model, new GraniteQ8_0PlanComponents(state, model));
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java
new file mode 100644
index 00000000..fa4b4cd3
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java
@@ -0,0 +1,57 @@
+package org.beehive.gpullama3.tornadovm.plan;
+
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.plan.components.PrefillDecodeForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.layout.PrefillDecodeForwardTaskGraphLayout;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Topology plan for the N+2 prefill/decode forward pass.
+ *
+ *
+ * [0] activation ← prefillDecodeActivation()
+ * [1..N] layers ← prefillDecodeTransformerLayers()
+ * [N+1] logits ← decodeLogits(String)
+ *
+ *
+ *
+ * [0] activation ← singleTokenActivation()
+ * [1..N] layers ← singleTokenTransformerLayers()
+ * [N+1] logits ← singleTokenLogits(String)
+ *
+ */
+public class SingleTokenForwardPlan extends ForwardPlan {
+
+ private final SingleTokenForwardTaskGraphLayout taskGraphLayout;
+
+ public SingleTokenForwardPlan(Model model, SingleTokenForwardPlanComponents components) {
+ int N = model.configuration().numberOfLayers();
+ this.taskGraphLayout = new SingleTokenForwardTaskGraphLayout(N);
+
+ List
+ *
+ *
+ * Prefill-decode inference with TornadoVM is implemented by {@link TornadoVMMasterPlanPrefillDecode}. + * It employs a {@link PrefillDecodeForwardPlan} instance to represent the complete + * prefill-decode forward operation as a chain of distinct TornadoVM TaskGraphs. + * The components of this chain are represented by the following components: + *
+ * Single-token inference with TornadoVM is implemented by {@link TornadoVMMasterPlanSingleToken}. + * It employees a {@link SingleTokenForwardPlan} instance to represent the complete + * single-token forward operation as a chain of distinct TornadoVM TaskGraphs. + * The components of this chain are represented by the following components: + *
Used in the 2N+3 batch-prefill/decode plan. Consumes + * {@code wrapKeyCache}/{@code wrapValueCache} from the last batch-prefill layer, + * converts the single-token embedding to FP32, then re-persists the KV cache so + * that decode layer 0 can consume it.
+ */ +public class BatchDecodeActivation implements ActivationTaskGraph { + + private final ImmutableTaskGraph itg; + private final int dim; + + public BatchDecodeActivation(LlamaState state, LlamaConfiguration config, + String lastBatchLayerId, boolean isQ8) { + this.dim = config.dim(); + KernelContext ctx = new KernelContext(); + this.itg = buildGraph(ctx, state, lastBatchLayerId, isQ8).snapshot(); + } + + private TaskGraph buildGraph(KernelContext ctx, LlamaState state, + String lastBatchLayerId, boolean isQ8) { + TaskGraph tg = new TaskGraph("decodeActivation") + .consumeFromDevice(lastBatchLayerId, state.wrapKeyCache, state.wrapValueCache) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX); + if (isQ8) { + tg.task("updateX", TransformerComputeKernels::convertQ8_0toFP32, + ctx, (ByteArray) state.embeddingX, state.wrapX); + } else { + tg.task("updateX", TransformerComputeKernels::convertFP16toFP32, + ctx, (HalfFloatArray) state.embeddingX, state.wrapX); + } + return tg.persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache); + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return itg; + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler scheduler) { + scheduler.addWorkerGrid("decodeActivation.updateX", + WorkerGridFactory.genericWorker(dim, 128)); + return scheduler; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java new file mode 100644 index 00000000..17b80cfa --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java @@ -0,0 +1,72 @@ +package org.beehive.gpullama3.tornadovm.plan.components.activation; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * Batch-prefill activation graph ("prefillActivation"). + * + *Provides all layer objects and the FP16 embedding preparer (raw half-float byte copy).
+ */ +public class LlamaFP16PlanComponents implements BatchPrefillDecodeForwardPlanComponents { + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final LlamaConfiguration config; + private final SchedulerType schedulerType; + + public LlamaFP16PlanComponents(LlamaState state, Model model) { + this.state = state; + this.config = (LlamaConfiguration) model.configuration(); + this.weights = (LlamaTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + // ── Activations ─────────────────────────────────────────────────────────── + + @Override public ActivationTaskGraph singleTokenActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public ActivationTaskGraph prefillDecodeActivation() { + return new Activation("decodeActivation", state, weights, config); + } + + @Override public ActivationTaskGraph batchPrefillActivation(int batchSize) { + return new BatchPrefillActivation(state, config, batchSize, false); + } + + @Override public ActivationTaskGraph batchDecodeActivation(String lastBatchLayerId) { + return new BatchDecodeActivation(state, config, lastBatchLayerId, false); + } + + // ── Transformer layer TaskGraphs ────────────────────────────────────────────────────── + + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { + return new LlamaFP16FFNLayers("llamaFFN", state, weights, config, schedulerType); + } + + @Override public TransformerLayerTaskGraphs prefillDecodeTransformerLayers() { + return new LlamaFP16FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); + } + + @Override public TransformerLayerTaskGraphs batchDecodeTransformerLayers() { + return new LlamaFP16FFNLayersDecode("decode", state, weights, config, schedulerType); + } + + @Override public BatchPrefillTransformerLayerTaskGraphs batchPrefillTransformerLayers(int batchSize) { + return new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize); + } + + // ── Logits layers ───────────────────────────────────────────────────────── + + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { + return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } + + @Override public AbstractLogitsTaskGraph decodeLogits(String previousGraphId) { + return new LogitsFP16LayerDecode("logits", state, weights, config, previousGraphId, schedulerType); + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java new file mode 100644 index 00000000..7e648070 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.fp16; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.mistral.MistralConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.MistralFP16FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class MistralFP16PlanComponents implements SingleTokenForwardPlanComponents { + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final MistralConfiguration config; + private final SchedulerType schedulerType; + + public MistralFP16PlanComponents(LlamaState state, Model model) { + this.state = state; + this.config = (MistralConfiguration) model.configuration(); + this.weights = (LlamaTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationTaskGraph singleTokenActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { + return new MistralFP16FFNLayers("mistralFFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { + return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java new file mode 100644 index 00000000..f8c56759 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.fp16; + +import org.beehive.gpullama3.inference.state.Phi3State; +import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.phi3.Phi3Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.Phi3FP16FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Phi3FP16PlanComponents implements SingleTokenForwardPlanComponents { + + private final Phi3State state; + private final Phi3TornadoWeights weights; + private final Phi3Configuration config; + private final SchedulerType schedulerType; + + public Phi3FP16PlanComponents(Phi3State state, Model model) { + this.state = state; + this.config = (Phi3Configuration) model.configuration(); + this.weights = (Phi3TornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationTaskGraph singleTokenActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { + return new Phi3FP16FFNLayers("phi3FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { + return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java new file mode 100644 index 00000000..2bba98f7 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.fp16; + +import org.beehive.gpullama3.inference.state.Qwen2State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen2FP16FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Qwen2FP16PlanComponents implements SingleTokenForwardPlanComponents { + + private final Qwen2State state; + private final Qwen2TornadoWeights weights; + private final Qwen2Configuration config; + private final SchedulerType schedulerType; + + public Qwen2FP16PlanComponents(Qwen2State state, Model model) { + this.state = state; + this.config = (Qwen2Configuration) model.configuration(); + this.weights = (Qwen2TornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationTaskGraph singleTokenActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { + return new Qwen2FP16FFNLayers("qwen2FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { + return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java new file mode 100644 index 00000000..59cd920d --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.fp16; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen3FP16FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Qwen3FP16PlanComponents implements SingleTokenForwardPlanComponents { + + private final Qwen3State state; + private final Qwen3TornadoWeights weights; + private final Qwen3Configuration config; + private final SchedulerType schedulerType; + + public Qwen3FP16PlanComponents(Qwen3State state, Model model) { + this.state = state; + this.config = (Qwen3Configuration) model.configuration(); + this.weights = (Qwen3TornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationTaskGraph singleTokenActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { + return new Qwen3FP16FFNLayers("qwen3FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { + return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java new file mode 100644 index 00000000..bae1b073 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.DevstralState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.devstral.DevstralConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.DevstralQ8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class DevstralQ8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final DevstralState state; + private final LlamaTornadoWeights weights; + private final DevstralConfiguration config; + private final SchedulerType schedulerType; + + public DevstralQ8_0PlanComponents(DevstralState state, Model model) { + this.state = state; + this.config = (DevstralConfiguration) model.configuration(); + this.weights = (LlamaTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationTaskGraph singleTokenActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { + return new DevstralQ8_0FFNLayers("devstralFFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java new file mode 100644 index 00000000..67f4aca3 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.GraniteState; +import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.GraniteQ8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsGraniteQ8_0Layer; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class GraniteQ8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final GraniteState state; + private final GraniteTornadoWeights weights; + private final GraniteConfiguration config; + private final SchedulerType schedulerType; + + public GraniteQ8_0PlanComponents(GraniteState state, Model model) { + this.state = state; + this.config = (GraniteConfiguration) model.configuration(); + this.weights = (GraniteTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationTaskGraph singleTokenActivation() { + return new ActivationGranite("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { + return new GraniteQ8_0FFNLayers("graniteFFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { + return new LogitsGraniteQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java new file mode 100644 index 00000000..d8cc916e --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java @@ -0,0 +1,94 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LlamaQ8_0FFNLayersDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LlamaQ8_0FFNLayersPrefillDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LogitsQ8_0LayerDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.prefill.LlamaQ8_0LayersBatchPrefill; +import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchDecodeActivation; +import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchPrefillActivation; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + + +/** + * {@link BatchPrefillDecodeForwardPlanComponents} for Llama + Q8_0. + * + *Batch embedding prep: CPU dequantizes Q8_0 embeddings into {@code wrapXBatch} (FP32). + * Decode embedding prep: raw Q8_0 block copy into {@code embeddingX} for on-device conversion.
+ */ +public class LlamaQ8_0PlanComponents implements BatchPrefillDecodeForwardPlanComponents { + + private static final int BLOCK_SIZE = 32; + private static final int Q8_0_BLOCK_BYTES = 34; + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final LlamaConfiguration config; + private final SchedulerType schedulerType; + + public LlamaQ8_0PlanComponents(LlamaState state, Model model) { + this.state = state; + this.config = (LlamaConfiguration) model.configuration(); + this.weights = (LlamaTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + // ── Activations ─────────────────────────────────────────────────────────── + + @Override public ActivationTaskGraph singleTokenActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public ActivationTaskGraph prefillDecodeActivation() { + return new Activation("decodeActivation", state, weights, config); + } + + @Override public ActivationTaskGraph batchPrefillActivation(int batchSize) { + return new BatchPrefillActivation(state, config, batchSize, true); + } + + @Override public ActivationTaskGraph batchDecodeActivation(String lastBatchLayerId) { + return new BatchDecodeActivation(state, config, lastBatchLayerId, true); + } + + // ── Transformer layer TaskGraphs ────────────────────────────────────────────────────── + + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { + return new LlamaQ8_0FFNLayers("llamaFFN", state, weights, config, schedulerType); + } + + @Override public TransformerLayerTaskGraphs prefillDecodeTransformerLayers() { + return new LlamaQ8_0FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); + } + + @Override public TransformerLayerTaskGraphs batchDecodeTransformerLayers() { + return new LlamaQ8_0FFNLayersDecode("decode", state, weights, config, schedulerType); + } + + @Override public BatchPrefillTransformerLayerTaskGraphs batchPrefillTransformerLayers(int batchSize) { + return new LlamaQ8_0LayersBatchPrefill(state, weights, config, batchSize); + } + + // ── Logits layers ───────────────────────────────────────────────────────── + + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } + + @Override public AbstractLogitsTaskGraph decodeLogits(String previousGraphId) { + return new LogitsQ8_0LayerDecode("logits", state, weights, config, previousGraphId, schedulerType); + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java new file mode 100644 index 00000000..025dd3a7 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.mistral.MistralConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.MistralQ8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class MistralQ8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final MistralConfiguration config; + private final SchedulerType schedulerType; + + public MistralQ8_0PlanComponents(LlamaState state, Model model) { + this.state = state; + this.config = (MistralConfiguration) model.configuration(); + this.weights = (LlamaTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationTaskGraph singleTokenActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { + return new MistralQ8_0FFNLayers("mistralFFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java new file mode 100644 index 00000000..2d91aa9b --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.Phi3State; +import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.phi3.Phi3Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Phi3Q8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Phi3Q8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final Phi3State state; + private final Phi3TornadoWeights weights; + private final Phi3Configuration config; + private final SchedulerType schedulerType; + + public Phi3Q8_0PlanComponents(Phi3State state, Model model) { + this.state = state; + this.config = (Phi3Configuration) model.configuration(); + this.weights = (Phi3TornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationTaskGraph singleTokenActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { + return new Phi3Q8_0FFNLayers("phi3FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java new file mode 100644 index 00000000..9dbed69f --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.Qwen2State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen2Q8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Qwen2Q8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final Qwen2State state; + private final Qwen2TornadoWeights weights; + private final Qwen2Configuration config; + private final SchedulerType schedulerType; + + public Qwen2Q8_0PlanComponents(Qwen2State state, Model model) { + this.state = state; + this.config = (Qwen2Configuration) model.configuration(); + this.weights = (Qwen2TornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationTaskGraph singleTokenActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { + return new Qwen2Q8_0FFNLayers("qwen2FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java new file mode 100644 index 00000000..792e889b --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen3Q8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Qwen3Q8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final Qwen3State state; + private final Qwen3TornadoWeights weights; + private final Qwen3Configuration config; + private final SchedulerType schedulerType; + + public Qwen3Q8_0PlanComponents(Qwen3State state, Model model) { + this.state = state; + this.config = (Qwen3Configuration) model.configuration(); + this.weights = (Qwen3TornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationTaskGraph singleTokenActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { + return new Qwen3Q8_0FFNLayers("qwen3FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/BatchPrefillDecodeForwardTaskGraphLayout.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/BatchPrefillDecodeForwardTaskGraphLayout.java new file mode 100644 index 00000000..aec212bd --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/BatchPrefillDecodeForwardTaskGraphLayout.java @@ -0,0 +1,21 @@ +package org.beehive.gpullama3.tornadovm.plan.layout; + +/** + * Graph-index arithmetic for the 2N+3 batch-prefill/decode forward plan. + * + *
+ * [0] batchPrefillActivation
+ * [1..N] batchPrefillLayer_0 .. batchPrefillLayer_{N-1}
+ * [N+1] decodeActivation (consumes + re-persists KV cache)
+ * [N+2..2N+1] decodeLayer_0 .. decodeLayer_{N-1}
+ * [2N+2] logits
+ *
+ */
+public record BatchPrefillDecodeForwardTaskGraphLayout(int N) {
+ public int batchActivationIdx() { return 0; }
+ public int batchLayerIdx(int i) { return 1 + i; }
+ public int decodeActivationIdx() { return N + 1; }
+ public int decodeLayerIdx(int i) { return N + 2 + i; }
+ public int logitsIdx() { return 2 * N + 2; }
+ public int totalGraphs() { return 2 * N + 3; }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/PrefillDecodeForwardTaskGraphLayout.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/PrefillDecodeForwardTaskGraphLayout.java
new file mode 100644
index 00000000..c252c92d
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/PrefillDecodeForwardTaskGraphLayout.java
@@ -0,0 +1,17 @@
+package org.beehive.gpullama3.tornadovm.plan.layout;
+
+/**
+ * Graph-index arithmetic for the N+2 prefill/decode forward plan.
+ *
+ *
+ * [0] decodeActivation
+ * [1..N] layer_0 .. layer_{N-1}
+ * [N+1] logits
+ *
+ */
+public record PrefillDecodeForwardTaskGraphLayout(int N) {
+ public int activationIdx() { return 0; }
+ public int layerIdx(int i) { return 1 + i; }
+ public int logitsIdx() { return N + 1; }
+ public int totalGraphs() { return N + 2; }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/SingleTokenForwardTaskGraphLayout.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/SingleTokenForwardTaskGraphLayout.java
new file mode 100644
index 00000000..d8fe8d13
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/SingleTokenForwardTaskGraphLayout.java
@@ -0,0 +1,17 @@
+package org.beehive.gpullama3.tornadovm.plan.layout;
+
+/**
+ * Graph-index arithmetic for the N+2 single-token forward plan.
+ *
+ *
+ * [0] activation
+ * [1..N] layer_0 .. layer_{N-1}
+ * [N+1] logits
+ *
+ */
+public record SingleTokenForwardTaskGraphLayout(int N) {
+ public int activationIdx() { return 0; }
+ public int layerIdx(int i) { return 1 + i; }
+ public int logitsIdx() { return N + 1; }
+ public int totalGraphs() { return N + 2; }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerDetectionService.java b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerDetectionService.java
similarity index 93%
rename from src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerDetectionService.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerDetectionService.java
index 5a81caa8..5e29c528 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerDetectionService.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerDetectionService.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.strategy;
+package org.beehive.gpullama3.tornadovm.scheduling;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.ModelType;
@@ -9,7 +9,6 @@
public class SchedulerDetectionService {
-
public static SchedulerType determineSchedulerType(Model model) {
TornadoRuntime tornadoRuntime = TornadoRuntimeProvider.getTornadoRuntime();
String platformName = tornadoRuntime.getBackend(0)
@@ -24,4 +23,4 @@ public static SchedulerType determineSchedulerType(Model model) {
return (isNvidia && isNotMistral) ? SchedulerType.NVIDIA : SchedulerType.NON_NVIDIA;
}
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerType.java b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerType.java
new file mode 100644
index 00000000..bbd27169
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerType.java
@@ -0,0 +1,5 @@
+package org.beehive.gpullama3.tornadovm.scheduling;
+
+public enum SchedulerType {
+ NVIDIA, NON_NVIDIA
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/WorkerGridFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/WorkerGridFactory.java
similarity index 98%
rename from src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/WorkerGridFactory.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/scheduling/WorkerGridFactory.java
index af39c133..4d5a55a5 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/WorkerGridFactory.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/WorkerGridFactory.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.tornadovm.layerplanner;
+package org.beehive.gpullama3.tornadovm.scheduling;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
@@ -98,5 +98,4 @@ private static int findOptimalLocalSize(int size) {
}
return optimal;
}
-
}