From 45204f1bb3e138b145014ac003ab35106cb6bdd2 Mon Sep 17 00:00:00 2001
From: Orion Papadakis Receives the KV-cache device pointer from batch layer N via
* {@code consumeFromDevice}, then re-emits it via {@code persistOnDevice} so
@@ -121,73 +137,86 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) {
* Both halves of the chain are required; without the re-persist the pointer is
* not forwarded in interpreter (non-CUDA-graph) mode.
Layer 0 consumes the KV cache from device (passed through by the decode activation + * graph, which relays it from the last batch prefill layer). No FIRST_EXECUTION allocation + * for the KV cache — it was already allocated in the batch prefill phase.
+ */ +public class LlamaQ8_0FFNLayersDecode extends LlamaQ8_0FFNLayers { + + public LlamaQ8_0FFNLayersDecode(String taskGraph, LlamaState state, + LlamaTornadoWeights weights, LlamaConfiguration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, state.temp, state.tempFFN); + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapAtt, state.wrapHb); + layer.consumeFromDevice("decodeActivation", state.wrapKeyCache, state.wrapValueCache); + } else { + String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + state.positionHolder); + } + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java new file mode 100644 index 00000000..a1f02ada --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java @@ -0,0 +1,47 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers; +import uk.ac.manchester.tornado.api.TaskGraph; + +/** + * Decode FFN layers for the single-token prefill/decode plan + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode}). + * + *Layer 0 delegates to {@link LlamaQ8_0FFNLayers#configureLayerDataTransfers} which + * includes {@code FIRST_EXECUTION} for {@code wrapKeyCache} and {@code wrapValueCache}, + * allocating the KV cache on the very first forward pass. Layers 1+ use explicit + * predecessor names for all consumed objects, required by TornadoVM's interpreter mode.
+ */ +public class LlamaQ8_0FFNLayersPrefillDecode extends LlamaQ8_0FFNLayers { + + public LlamaQ8_0FFNLayersPrefillDecode(String taskGraph, LlamaState state, + LlamaTornadoWeights weights, LlamaConfiguration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + return super.configureLayerDataTransfers(layer, 0); + } + String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + state.positionHolder); + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java new file mode 100644 index 00000000..054b7f7f --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java @@ -0,0 +1,35 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import uk.ac.manchester.tornado.api.TaskGraph; + +/** + * Logits layer for the unified batched prefill-decode plan (Q8_0). + * + *Extends {@link LogitsQ8_0Layer} with KV-cache pass-through so the device pointers for + * {@code wrapKeyCache} and {@code wrapValueCache} survive the logits → decode-activation + * boundary between decode tokens. Without the pass-through, the KV-cache pointer is absent + * from the logits persisted set, cleared to null, and the first decode layer crashes with + * an NPE in {@code executeAlloc}.
+ */ +public class LogitsQ8_0LayerDecode extends LogitsQ8_0Layer { + + public LogitsQ8_0LayerDecode(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config, lastTaskGraphID, schedulerType); + } + + @Override + protected void configureAdditionalConsumes(TaskGraph logits) { + logits.consumeFromDevice(lastTaskGraphID, state.wrapKeyCache, state.wrapValueCache); + } + + @Override + protected void configureAdditionalPersists(TaskGraph logits) { + logits.persistOnDevice(state.wrapKeyCache, state.wrapValueCache); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java new file mode 100644 index 00000000..79ea9bc2 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java @@ -0,0 +1,230 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0.prefill; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +/** + * Prefill FFN layers with batching for the unified batched prefill-decode plan (Q8_0). + * + *Mirrors {@link org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill} + * but uses Q8_0 kernels with inline dequantization. Key differences from the FP16 path:
+ *Delegates the full chunk to - * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill}, + * {@link TornadoVMMasterPlanBatchPrefillDecode#tornadoVMForwardBatchPrefill}, * which handles embedding lookup and GPU execution internally.
* * @param model the LLaMA model @@ -175,7 +175,7 @@ public static void batchForwardJavaPrefill(Model model, State state, int[] token * @param plan the batched prefill/decode GPU plan */ public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int startPos, int chunkSize, - TornadoVMMasterPlanWithBatchPrefillDecode plan) { + TornadoVMMasterPlanBatchPrefillDecode plan) { plan.tornadoVMForwardBatchPrefill(tokens, startPos, model, chunkSize); } @@ -183,7 +183,7 @@ public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int s * GPU decode forward pass (Phase 4). * *Delegates a single-token decode step to - * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode}, + * {@link TornadoVMMasterPlanBatchPrefillDecode#tornadoVMForwardDecode}, * which copies the token embedding and runs the decode + logits graphs.
* * @param model the LLaMA model @@ -193,7 +193,7 @@ public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int s * @return logits array for token sampling */ public static FloatArray forwardTornadoVMDecode(Model model, int token, int position, - TornadoVMMasterPlanWithBatchPrefillDecode plan) { + TornadoVMMasterPlanBatchPrefillDecode plan) { return plan.tornadoVMForwardDecode(token, position, model); } } diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java index dbd81297..9c812b10 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java @@ -7,7 +7,7 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tensor.standard.FloatTensor; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode; import java.lang.foreign.MemorySegment; @@ -131,7 +131,7 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p * *Copies the token embedding into {@code state.embeddingX} (same as * {@link InferenceCore#forwardTornadoVM}) then delegates to - * {@link TornadoVMMasterPlanWithPrefillDecode#tornadoVMForwardPrefill}, + * {@link TornadoVMMasterPlanPrefillDecode#tornadoVMForwardPrefill}, * which executes preprocessing + layer graphs but skips the logits graph.
* * @param model the LLaMA model (must carry {@link TornadoWeights}, FP16 only) @@ -142,7 +142,7 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p * @throws UnsupportedOperationException if the model uses Q8_0 weights */ public static void forwardTornadoVMPrefill(Model model, State state, int token, int position, - TornadoVMMasterPlanWithPrefillDecode prefillPlan) { + TornadoVMMasterPlanPrefillDecode prefillPlan) { final Configuration configuration = model.configuration(); final TornadoWeights weights = (TornadoWeights) model.weights(); diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java index 06dbe256..5f617bda 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java @@ -7,7 +7,7 @@ import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode; import java.util.ArrayList; import java.util.Arrays; @@ -163,8 +163,8 @@ public static ListThree concrete implementations exist:
*The {@link #initializeTornadoVMPlan} factory selects the implementation based on * {@code llama.withPrefillDecode} and {@code llama.prefillBatchSize}:
*When {@code llama.withPrefillDecode=true} and {@code llama.prefillBatchSize > 1}, - * a {@link TornadoVMMasterPlanWithBatchPrefillDecode} is returned. - * Otherwise a {@link TornadoVMMasterPlanStandard} is returned (used for the baseline + * a {@link TornadoVMMasterPlanBatchPrefillDecode} is returned. + * Otherwise a {@link TornadoVMMasterPlanSingleToken} is returned (used for the baseline * path and the sequential prefill/decode path when batch size is 1).
* * @param state the model state @@ -56,13 +56,13 @@ static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { if (WITH_PREFILL_DECODE && PREFILL_BATCH_SIZE > 1) { // GPU path with batched prefill/decode - plan = new TornadoVMMasterPlanWithBatchPrefillDecode(state, model); + plan = new TornadoVMMasterPlanBatchPrefillDecode(state, model); } else if (WITH_PREFILL_DECODE) { // GPU path with simple prefill/decode - plan = new TornadoVMMasterPlanWithPrefillDecode(state, model); + plan = new TornadoVMMasterPlanPrefillDecode(state, model); } else { // GPU path with no prefill/decode - plan = new TornadoVMMasterPlanStandard(state, model); + plan = new TornadoVMMasterPlanSingleToken(state, model); } model.setTornadoVMPlan(plan); return plan; @@ -76,7 +76,7 @@ static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { void forceCopyInReadOnlyData(); - FloatArray tornadoVMForwardExecuteLayered(int position); + FloatArray tornadoVMExecuteForward(int position); /** Releases all device memory held by this plan. */ void freeTornadoExecutionPlan(); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefillDecode.java new file mode 100644 index 00000000..968e7b19 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefillDecode.java @@ -0,0 +1,165 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.auxiliary.RunMetrics; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tornadovm.plan.BatchPrefillDecodeForwardPlan; +import org.beehive.gpullama3.tornadovm.plan.ForwardPlanFactory; +import org.beehive.gpullama3.tornadovm.plan.layout.BatchPrefillDecodeForwardTaskGraphLayout; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +/** + * GPU execution plan for batched prefill + single-token decode. + * + *A single {@link TornadoExecutionPlan} holds all task graphs for + * batched prefill and single-token decode phases:
+ * + *TaskGraph layout (2N+3 TaskGraphs total):
+ *+ * [0] batchPrefillActivation B×dim embeddings → FP32 wrapXBatch + * [1..N] batch-prefill layers B tokens, all transformer ops + * [N+1] decodeActivation single-token embedding → FP32 + KV-cache pass-through + * [N+2..2N+1] decode layers single-token, standard kernels + * [2N+2] logits + *+ */ +public class TornadoVMMasterPlanBatchPrefillDecode implements TornadoVMMasterPlan { + + private final State state; + private final Model model; + private final Configuration config; + + BatchPrefillDecodeForwardPlan batchPrefillDecodeForwardPlan; + BatchPrefillDecodeForwardTaskGraphLayout taskGraphLayout; + public TornadoExecutionPlan executionPlan; + + // ── Construction ───────────────────────────────────────────────────────── + TornadoVMMasterPlanBatchPrefillDecode(State initialState, Model model) { + if (ENABLE_TORNADOVM_INIT_TIME) { + System.err.println("\nStarting TornadoVM initialization..."); + } + + this.state = initialState; + this.model = model; + this.config = model.configuration(); + + long startTime = System.nanoTime(); + this.executionPlan = createExecutionPlan(); + long planCreationTime = System.nanoTime(); + + if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph(); + executionPlan.withPreCompilation(); + long warmupTime = System.nanoTime(); + + forceCopyInReadOnlyData(); + long copyTime = System.nanoTime(); + + RunMetrics.setTornadoMetrics(planCreationTime - startTime, warmupTime - planCreationTime, copyTime - warmupTime); + } + + // ── Plan construction ───────────────────────────────────────────────────── + + @Override + public TornadoExecutionPlan createExecutionPlan() { + GGMLType weightType = model.weights().getWeightType(); + this.batchPrefillDecodeForwardPlan = + ForwardPlanFactory.createBatchPrefillDecode(weightType, state, model); + this.taskGraphLayout = batchPrefillDecodeForwardPlan.getTaskGraphLayout(); + var taskGraphs = batchPrefillDecodeForwardPlan.getImmutableTaskGraphs(); + return new TornadoExecutionPlan(taskGraphs.toArray(new ImmutableTaskGraph[0])); + } + + // ── Initialisation ──────────────────────────────────────────────────────── + + @Override + public void forceCopyInReadOnlyData() { + batchPrefillDecodeForwardPlan.getEmbeddingPreparer().initBatchState(); + state.wrapX.clear(); + state.positionHolder.init(0); + + for (int i = 0; i <= taskGraphLayout.logitsIdx(); i++) { + var g = executionPlan.withGraph(i) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) g.withCUDAGraph(); + g.execute(); + } + } + + // ── Forward passes ──────────────────────────────────────────────────────── + + /** + * Batch prefill: runs graphs 0..N (activation + N layers), skips logits. + * + * @param tokenIds token IDs for this chunk + * @param startPos sequence position of tokenIds[0] + * @param model model (unused — kept for API compatibility) + * @param chunkSize actual number of tokens in this chunk (≤ batchSize) + */ + public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model model, int chunkSize) { + batchPrefillDecodeForwardPlan.getEmbeddingPreparer().copyBatchEmbeddings(tokenIds, startPos, chunkSize); + + var batchAct = executionPlan.withGraph(taskGraphLayout.batchActivationIdx()) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) batchAct.withCUDAGraph(); + batchAct.execute(); + + for (int l = 0; l < config.numberOfLayers(); l++) { + var batchLayer = executionPlan.withGraph(taskGraphLayout.batchLayerIdx(l)) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) batchLayer.withCUDAGraph(); + batchLayer.execute(); + } + } + + /** + * Single-token decode: runs graphs N+1..2N+2 (activation + N layers + logits). + * + * @param token token ID to process + * @param position sequence position + * @param model model (unused — kept for API compatibility) + * @return logits array for sampling + */ + public FloatArray tornadoVMForwardDecode(int token, int position, Model model) { + batchPrefillDecodeForwardPlan.getEmbeddingPreparer().copyDecodeEmbedding(token); + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + var decodeAct = executionPlan.withGraph(taskGraphLayout.decodeActivationIdx()) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) decodeAct.withCUDAGraph(); + decodeAct.execute(); + + for (int l = 0; l < config.numberOfLayers(); l++) { + var decodeLayer = executionPlan.withGraph(taskGraphLayout.decodeLayerIdx(l)) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) decodeLayer.withCUDAGraph(); + decodeLayer.execute(); + } + + state.tempLogits.clear(); + state.wrapLogits.clear(); + + var logits = executionPlan.withGraph(taskGraphLayout.logitsIdx()) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) logits.withCUDAGraph(); + logits.execute(); + + return state.wrapLogits; + } + + @Override + public FloatArray tornadoVMExecuteForward(int position) { + throw new UnsupportedOperationException( + "Use tornadoVMForwardBatchPrefill / tornadoVMForwardDecode for batch plan"); + } + + @Override + public void freeTornadoExecutionPlan() { + executionPlan.freeDeviceMemory(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanPrefillDecode.java new file mode 100644 index 00000000..a1f56549 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanPrefillDecode.java @@ -0,0 +1,166 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.auxiliary.RunMetrics; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tornadovm.plan.ForwardPlanFactory; +import org.beehive.gpullama3.tornadovm.plan.PrefillDecodeForwardPlan; +import org.beehive.gpullama3.tornadovm.plan.layout.PrefillDecodeForwardTaskGraphLayout; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +/** + * GPU execution plan for sequential (single-token) prefill/decode separation. + * + *
A single {@link TornadoExecutionPlan} holds all graphs so that the KV cache + * ({@code wrapKeyCache}, {@code wrapValueCache}) is allocated once and remains on + * device across both phases. Prefill and decode reuse the same N layer graphs; + * only the logits graph is skipped during prefill.
+ * + *Graph layout (N+2 graphs total):
+ *+ * [0] decodeActivation single-token FP16 → FP32; KV-cache allocated on first execution + * [1..N] layer_0..layer_N-1 transformer layers (attention + FFN) + * [N+1] logits final RMSNorm + wcls matmul + *+ * + *
Two forward passes:
+ *A single {@link TornadoExecutionPlan} holds all {@link TaskGraph} for - * batched prefill and single-token decode phases with the following structure:
. - * - *TaskGraph layout (2N+3 TaskGraphs total):
- *- * [0] prefill batch activation B×dim FP16 → FP32 - * [1..N] prefill batch layer graphs B tokens, all transformer ops - * [N+1] decode activation single-token FP16 → FP32 + KV-cache pass-through - * [N+2..2N+1] decode layer graphs single-token, standard kernels - * [2N+2] logits graph - *- * - *
- * Incorporating cross-phase {@link TaskGraph}s withing a single {@link TornadoExecutionPlan} - * is necessary to enable KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) sharing - * across prefill and decode phases. The KV cache pointers are chained across {@link TaskGraph}s - * via the {@code persistOnDevice}/{@code consumeFromDevice} API within the {@link TornadoExecutionPlan}. - *
- * - *KV cache pointer chain across phases:
- *- * batchLayer[N-1] --persistOnDevice(wrapKeyCache)-→ - * decodeActivation --consumeFromDevice(wrapKeyCache)-→ (pass-through) - * decodeLayer[0] --consumeFromDevice(wrapKeyCache)-→ (used by attention) - *- */ -public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMasterPlan { - - private final LlamaState state; - private final Model model; - private final LlamaConfiguration config; - private final int batchSize; - private final int N; // numberOfLayers - private final TornadoExecutionPlan executionPlan; - private final GridScheduler gridScheduler; - - // ── Graph-index helpers ─────────────────────────────────────────────────── - private int batchActivationIdx() { return 0; } - private int batchLayerIdx(int i) { return 1 + i; } - private int decodeActivationIdx() { return N + 1; } - private int decodeLayerIdx(int i) { return N + 2 + i; } - private int logitsIdx() { return 2 * N + 2; } - - // ── Construction ───────────────────────────────────────────────────────── - TornadoVMMasterPlanWithBatchPrefillDecode(State initialState, Model model) { - if (ENABLE_TORNADOVM_INIT_TIME) { - System.err.println("\nStarting TornadoVM initialization..."); - } - - this.state = (LlamaState) initialState; // only LlamaFP16 supports batched prefill for now - this.model = model; - this.config = (LlamaConfiguration) model.configuration(); - this.batchSize = PREFILL_BATCH_SIZE; - this.N = config.numberOfLayers(); - this.gridScheduler = new GridScheduler(); - - long startTime = System.nanoTime(); - this.executionPlan = createExecutionPlan(); - long planCreationTime = System.nanoTime(); - - if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph(); - executionPlan.withPreCompilation(); - long warmupTime = System.nanoTime(); - - forceCopyInReadOnlyData(); - long copyTime = System.nanoTime(); - - RunMetrics.setTornadoMetrics(planCreationTime - startTime, warmupTime - planCreationTime, copyTime - warmupTime); - } - - // ── Batch Prefill Activation graphs ───────────────────────────────────────────────────── - - /** Graph 0 (FP16): B×dim FP16 embeddings → FP32 wrapXBatch. */ - private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) { - return new TaskGraph("prefillActivation") - .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch) - .task("updateX", TransformerComputeKernels::convertFP16toFP32, - ctx, state.embeddingXBatch, state.wrapXBatch) - .persistOnDevice(state.wrapXBatch); - } - - /** Graph 0 (Q8_0): wrapXBatch pre-filled with FP32 by host; upload and persist. */ - private TaskGraph buildQ8_0BatchPrefillActivationGraph(KernelContext ctx) { - return new TaskGraph("prefillActivation") - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapXBatch) - .task("batchPassthrough", TransformerBatchPrefillKernels::batchPassthrough, - ctx, state.wrapXBatch) - .persistOnDevice(state.wrapXBatch); - } - - /** - * Graph N+1: single-token embedding → FP32 wrapX, with KV-cache pass-through. - * - *
Receives the KV-cache device pointer from batch layer N via - * {@code consumeFromDevice}, then re-emits it via {@code persistOnDevice} so - * that {@code updatePersistedObjectState()} can propagate it to decode layer 0. - * Both halves of the chain are required; without the re-persist the pointer is - * not forwarded in interpreter (non-CUDA-graph) mode.
- */ - private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID, - GGMLType weightType) { - TaskGraph tg = new TaskGraph("decodeActivation") - .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX); - if (weightType == GGMLType.Q8_0) { - tg.task("updateX", TransformerComputeKernels::convertQ8_0toFP32, - ctx, (ByteArray) state.embeddingX, state.wrapX); - } else { - tg.task("updateX", TransformerComputeKernels::convertFP16toFP32, - ctx, (HalfFloatArray) state.embeddingX, state.wrapX); - } - return tg.persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache); - } - - @Override - public TornadoExecutionPlan createExecutionPlan() { - GGMLType weightType = model.weights().getWeightType(); - if (weightType != GGMLType.F16 && weightType != GGMLType.Q8_0) { - throw new UnsupportedOperationException( - "Batched prefill/decode GPU path not supported for weight type: " + weightType); - } - - LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); - SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); - ListA single {@link TornadoExecutionPlan} holds all graphs so that the KV cache - * ({@code wrapKeyCache}, {@code wrapValueCache}) is allocated once and remains on - * device across both phases. Prefill and decode reuse the same N layer graphs; - * only the logits graph is skipped during prefill.
- * - *Graph layout (N+2 graphs total):
- *- * [0] decodeActivation single-token FP16 → FP32; KV-cache allocated on first execution - * [1..N] layer_0..layer_N-1 transformer layers (attention + FFN) - * [N+1] logits final RMSNorm + wcls matmul - *- * - *
Two forward passes:
- *Outputs {@code wrapX} (FP32 hidden state) and persists it on device so that - * decode layer 0 can pick it up via {@code consumeFromDevice("decodeActivation", wrapX)}. - * The KV cache is not managed here — it is allocated on the first forward pass - * by decode layer 0 via {@code FIRST_EXECUTION}.
- */ - private TaskGraph buildActivationGraph(KernelContext ctx, GGMLType weightType) { - if (weightType == GGMLType.Q8_0) { - return new TaskGraph("decodeActivation") - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", TransformerComputeKernels::convertQ8_0toFP32, - ctx, (ByteArray) state.embeddingX, state.wrapX) - .persistOnDevice(state.wrapX); - } - return new TaskGraph("decodeActivation") - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", TransformerComputeKernels::convertFP16toFP32, - ctx, (HalfFloatArray) state.embeddingX, state.wrapX) - .persistOnDevice(state.wrapX); - } - - // ── Plan construction ───────────────────────────────────────────────────── - @Override - public TornadoExecutionPlan createExecutionPlan() { - GGMLType weightType = model.weights().getWeightType(); - if (weightType != GGMLType.F16 && weightType != GGMLType.Q8_0) { - throw new UnsupportedOperationException( - "Prefill/decode GPU path not supported for weight type: " + weightType); - } - - LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); - SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); - - ListThese kernels are meant to be registered in {@link TornadoVMMasterPlanWithPrefillDecode} + *
These kernels are meant to be registered in {@link TornadoVMMasterPlanBatchPrefillDecode} * batch task graphs; they are NOT invoked directly.
*/ public final class TransformerBatchPrefillKernels { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java deleted file mode 100644 index 9e211051..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java +++ /dev/null @@ -1,14 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner; - -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; - -import java.util.List; - -public interface GenericLayerPlanner { - - List- * The factory follows a routing logic: - *
- * Examples: - *
- * Each component is represented as an {@link ImmutableTaskGraph}, along with a - * corresponding {@link GridScheduler} configuration that defines how tasks are - * mapped on the GPU. - *
- * This method assembles all components into a unified execution pipeline and - * caches the resulting task graphs and scheduler for reuse across inference runs. - */ - protected final void createTornadoInferencePlan() { - ListImplemented by {@link Activation} and custom activation wrappers used by + * {@link org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents}.
+ */ +public interface ActivationGraph { + ImmutableTaskGraph getImmutableTaskGraph(); + GridScheduler updateGridScheduler(GridScheduler scheduler); +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/BatchPrefillTransformerLayerTaskGraphs.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/BatchPrefillTransformerLayerTaskGraphs.java new file mode 100644 index 00000000..0e7a762f --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/BatchPrefillTransformerLayerTaskGraphs.java @@ -0,0 +1,17 @@ +package org.beehive.gpullama3.tornadovm.layers; + +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; + +import java.util.List; + +/** + * Interface for a group of N batched-prefill transformer-layer {@link uk.ac.manchester.tornado.api.TaskGraph}. + * + *Implemented by {@code LlamaFP16LayersBatchPrefill} and {@code LlamaQ8_0LayersBatchPrefill}.
+ */ +public interface BatchPrefillTransformerLayerTaskGraphs { + ListImplemented by {@link AbstractFFNLayers} and its subclasses.
+ */ +public interface TransformerLayerTaskGraphs { + ListOverrides data-transfer declarations so that all cross-graph boundaries use * the explicit-source form of {@code consumeFromDevice}. The no-arg form (used by diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java index 2f5bac64..c1167796 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java @@ -3,14 +3,14 @@ import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** * Decode FFN layers for the single-token prefill/decode plan - * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode}). + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode}). * *
Combines two concerns:
*Extends {@link LogitsFP16Layer} with KV-cache pass-through so the device * pointers for {@code wrapKeyCache} and {@code wrapValueCache} survive the diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java index a44425ef..44ce4b35 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java @@ -4,7 +4,8 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.llama.LlamaConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.KernelContext; @@ -17,7 +18,7 @@ /** * Prefill FFN layers with batching for the unified batched prefill-decode plan - * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}). + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}). * *
One {@link ImmutableTaskGraph} per transformer layer, each processing * {@code batchSize} tokens simultaneously via {@link TransformerBatchPrefillKernels}.
@@ -25,7 +26,7 @@ *KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) is persisted on device * after every layer so the subsequent single-token decode layers can consume it.
*/ -public class LlamaFP16LayersBatchPrefill { +public class LlamaFP16LayersBatchPrefill implements BatchPrefillTransformerLayerTaskGraphs { // Matches the local workgroup size used by the single-token kernels. static final int LOCAL_WORK_GROUP_SIZE = 32; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java index 0c5aedd6..0d717d2e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java @@ -4,8 +4,8 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.devstral.DevstralConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java index 8a91c75a..1aca4714 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java @@ -5,8 +5,8 @@ import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java index 2cfacdc0..f5683904 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java @@ -4,8 +4,8 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.llama.LlamaConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java index c583bb00..dca19c83 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java @@ -8,7 +8,7 @@ import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.enums.DataTransferMode; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java index 1a1313e2..bf66118a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java @@ -7,8 +7,8 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java index 64864114..4a6f3151 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java @@ -4,8 +4,8 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.mistral.MistralConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java index 8f693adc..0ffe25d6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java @@ -5,8 +5,8 @@ import org.beehive.gpullama3.model.phi3.Phi3Configuration; import org.beehive.gpullama3.tornadovm.kernels.Phi3Kernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java index 6cdf32db..936c3573 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java @@ -6,8 +6,8 @@ import org.beehive.gpullama3.tornadovm.kernels.Qwen2Kernels; import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java index 43b88fb1..1696644d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java @@ -5,8 +5,8 @@ import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java index e159c99b..0b2f13cb 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java @@ -3,14 +3,14 @@ import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** * Decode FFN layers for the unified batched prefill-decode plan - * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}). + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}). * *Layer 0 consumes the KV cache from device (passed through by the decode activation * graph, which relays it from the last batch prefill layer). No FIRST_EXECUTION allocation diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java index a1f02ada..388853f7 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java @@ -3,13 +3,13 @@ import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers; import uk.ac.manchester.tornado.api.TaskGraph; /** * Decode FFN layers for the single-token prefill/decode plan - * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode}). + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode}). * *
Layer 0 delegates to {@link LlamaQ8_0FFNLayers#configureLayerDataTransfers} which * includes {@code FIRST_EXECUTION} for {@code wrapKeyCache} and {@code wrapValueCache}, diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java index 054b7f7f..2a6714f2 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java @@ -3,7 +3,7 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.model.Configuration; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java index 79ea9bc2..14c30462 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java @@ -4,7 +4,8 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.llama.LlamaConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.KernelContext; @@ -30,7 +31,7 @@ *
Graph layout:
+ *+ * [0] batchPrefillActivation + * [1..N] batch-prefill transformer layers + * [N+1] decodeActivation (consumes + re-persists KV cache) + * [N+2..2N+1] decode transformer layers + * [2N+2] logits + *+ * + *
During batch prefill, the master plan executes graphs 0..N. + * During decode, graphs N+1..2N+2 run.
+ */ +public class BatchPrefillDecodeForwardPlan extends ForwardPlan { + + private final BatchPrefillDecodeForwardTaskGraphLayout taskGraphLayout; + private final EmbeddingPreparer embeddingPreparer; + + public BatchPrefillDecodeForwardPlan(Model model, BatchPrefillDecodeForwardPlanComponents components, int batchSize) { + int N = model.configuration().numberOfLayers(); + this.taskGraphLayout = new BatchPrefillDecodeForwardTaskGraphLayout(N); + this.embeddingPreparer = components.embeddingPreparer(); + + ListSubclasses assemble the {@link ImmutableTaskGraph} list and {@link GridScheduler} + * in their constructors by calling {@link #setGraphs}, then expose the results + * via {@link #getImmutableTaskGraphs} and {@link #getGridScheduler}.
+ */ +public abstract class ForwardPlan { + + private ListDispatches across three axes in order: + *
Use the typed convenience methods when the execution mode is known at the call site:
+ *Graph layout:
+ *+ * [0] decodeActivation + * [1..N] decode transformer layers + * [N+1] logits + *+ * + *
During prefill, the master plan executes graphs 0..N (skipping logits). + * During decode, all N+2 graphs run.
+ */ +public class PrefillDecodeForwardPlan extends ForwardPlan { + + private final PrefillDecodeForwardTaskGraphLayout taskGraphLayout; + + public PrefillDecodeForwardPlan(Model model, PrefillDecodeForwardPlanComponents components) { + int N = model.configuration().numberOfLayers(); + this.taskGraphLayout = new PrefillDecodeForwardTaskGraphLayout(N); + + ListGraph layout:
+ *+ * [0] activation + * [1..N] transformer layers + * [N+1] logits + *+ */ +public class SingleTokenForwardPlan extends ForwardPlan { + + private final SingleTokenForwardTaskGraphLayout taskGraphLayout; + + public SingleTokenForwardPlan(Model model, SingleTokenForwardPlanComponents components) { + int N = model.configuration().numberOfLayers(); + this.taskGraphLayout = new SingleTokenForwardTaskGraphLayout(N); + + List
Extends {@link PrefillDecodeForwardPlanComponents} with the batch-prefill + * activation, batch layer group, KV-cache decode activation, and the + * host-side embedding preparer.
+ */ +public interface BatchPrefillDecodeForwardPlanComponents extends PrefillDecodeForwardPlanComponents { + + ActivationGraph batchPrefillActivation(int batchSize); + + ActivationGraph batchDecodeActivation(String lastBatchLayerId); + + TransformerLayerTaskGraphs batchDecodeLayers(); + + BatchPrefillTransformerLayerTaskGraphs batchPrefillLayers(int batchSize); + + EmbeddingPreparer embeddingPreparer(); +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/EmbeddingPreparer.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/EmbeddingPreparer.java new file mode 100644 index 00000000..074bc4cd --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/EmbeddingPreparer.java @@ -0,0 +1,18 @@ +package org.beehive.gpullama3.tornadovm.plan.components; + +/** + * Prepares embedding vectors in host memory before GPU activation graphs run. + * + *Concrete implementations are format-specific (FP16 byte copy vs Q8_0 CPU dequantization) + * and live in the component-provider classes.
+ */ +public interface EmbeddingPreparer { + /** Clears the batch input buffer and resets the batch-start position holder to zero. */ + void initBatchState(); + + /** Copies or dequantizes embeddings for {@code chunkSize} tokens into the batch buffer and sets the batch-start position holder. */ + void copyBatchEmbeddings(int[] tokenIds, int startPos, int chunkSize); + + /** Copies or dequantizes the embedding for a single decode token into the single-token buffer. */ + void copyDecodeEmbedding(int token); +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java new file mode 100644 index 00000000..8f0b6cdd --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java @@ -0,0 +1,20 @@ +package org.beehive.gpullama3.tornadovm.plan.components; + +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; + +/** + * Components for the N+2 prefill/decode forward plan. + * + *Extends {@link SingleTokenForwardPlanComponents} with the decode-phase + * activation, KV-cache-aware layer group, and decode logits.
+ */ +public interface PrefillDecodeForwardPlanComponents extends SingleTokenForwardPlanComponents { + + ActivationGraph decodeActivation(); + + TransformerLayerTaskGraphs prefillDecodeLayers(); + + AbstractLogitsLayer decodeLogits(String previousGraphId); +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java new file mode 100644 index 00000000..6ecc2077 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java @@ -0,0 +1,21 @@ +package org.beehive.gpullama3.tornadovm.plan.components; + +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; + +/** + * Components for the single-token forward pass. + * + *All model+quantization combinations implement this interface. + * Models that support prefill/decode modes implement the richer + * {@link PrefillDecodeForwardPlanComponents} or {@link BatchPrefillDecodeForwardPlanComponents}.
+ */ +public interface SingleTokenForwardPlanComponents { + + ActivationGraph standardActivation(); + + TransformerLayerTaskGraphs standardLayers(); + + AbstractLogitsLayer standardLogits(String previousGraphId); +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java new file mode 100644 index 00000000..d3617a6c --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java @@ -0,0 +1,62 @@ +package org.beehive.gpullama3.tornadovm.plan.components.activation; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; + +/** + * Decode activation graph with KV-cache pass-through ("decodeActivation"). + * + *Used in the 2N+3 batch-prefill/decode plan. Consumes + * {@code wrapKeyCache}/{@code wrapValueCache} from the last batch-prefill layer, + * converts the single-token embedding to FP32, then re-persists the KV cache so + * that decode layer 0 can consume it.
+ */ +public class BatchDecodeActivation implements ActivationGraph { + + private final ImmutableTaskGraph itg; + private final int dim; + + public BatchDecodeActivation(LlamaState state, LlamaConfiguration config, + String lastBatchLayerId, boolean isQ8) { + this.dim = config.dim(); + KernelContext ctx = new KernelContext(); + this.itg = buildGraph(ctx, state, lastBatchLayerId, isQ8).snapshot(); + } + + private TaskGraph buildGraph(KernelContext ctx, LlamaState state, + String lastBatchLayerId, boolean isQ8) { + TaskGraph tg = new TaskGraph("decodeActivation") + .consumeFromDevice(lastBatchLayerId, state.wrapKeyCache, state.wrapValueCache) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX); + if (isQ8) { + tg.task("updateX", TransformerComputeKernels::convertQ8_0toFP32, + ctx, (ByteArray) state.embeddingX, state.wrapX); + } else { + tg.task("updateX", TransformerComputeKernels::convertFP16toFP32, + ctx, (HalfFloatArray) state.embeddingX, state.wrapX); + } + return tg.persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache); + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return itg; + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler scheduler) { + scheduler.addWorkerGrid("decodeActivation.updateX", + WorkerGridFactory.genericWorker(dim, 128)); + return scheduler; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java new file mode 100644 index 00000000..73eec954 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java @@ -0,0 +1,72 @@ +package org.beehive.gpullama3.tornadovm.plan.components.activation; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * Batch-prefill activation graph ("prefillActivation"). + * + *Provides all layer objects and the FP16 embedding preparer (raw half-float byte copy).
+ */ +public class LlamaFP16PlanComponents implements BatchPrefillDecodeForwardPlanComponents { + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final LlamaConfiguration config; + private final SchedulerType schedulerType; + + public LlamaFP16PlanComponents(LlamaState state, Model model) { + this.state = state; + this.config = (LlamaConfiguration) model.configuration(); + this.weights = (LlamaTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + // ── Activations ─────────────────────────────────────────────────────────── + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public ActivationGraph decodeActivation() { + return new Activation("decodeActivation", state, weights, config); + } + + @Override public ActivationGraph batchPrefillActivation(int batchSize) { + return new BatchPrefillActivation(state, config, batchSize, false); + } + + @Override public ActivationGraph batchDecodeActivation(String lastBatchLayerId) { + return new BatchDecodeActivation(state, config, lastBatchLayerId, false); + } + + // ── FFN layer groups ────────────────────────────────────────────────────── + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new LlamaFP16FFNLayers("llamaFFN", state, weights, config, schedulerType); + } + + @Override public TransformerLayerTaskGraphs prefillDecodeLayers() { + return new LlamaFP16FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); + } + + @Override public TransformerLayerTaskGraphs batchDecodeLayers() { + return new LlamaFP16FFNLayersDecode("decode", state, weights, config, schedulerType); + } + + @Override public BatchPrefillTransformerLayerTaskGraphs batchPrefillLayers(int batchSize) { + return new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize); + } + + // ── Logits layers ───────────────────────────────────────────────────────── + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } + + @Override public AbstractLogitsLayer decodeLogits(String previousGraphId) { + return new LogitsFP16LayerDecode("logits", state, weights, config, previousGraphId, schedulerType); + } + + // ── Embedding preparation ───────────────────────────────────────────────── + + @Override public EmbeddingPreparer embeddingPreparer() { + MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + int dim = config.dim(); + int bytes = Short.BYTES; + return new EmbeddingPreparer() { + @Override + public void initBatchState() { + state.wrapXBatch.clear(); + state.batchStartPosHolder.init(0); + } + @Override + public void copyBatchEmbeddings(int[] tokenIds, int startPos, int chunkSize) { + state.batchStartPosHolder.set(0, startPos); + for (int b = 0; b < chunkSize; b++) { + MemorySegment.copy(embTable, (long) tokenIds[b] * dim * bytes, + state.embeddingXBatch.getSegment(), (long) b * dim * bytes, + (long) dim * bytes); + } + } + @Override + public void copyDecodeEmbedding(int token) { + MemorySegment.copy(embTable, (long) token * dim * bytes, + state.embeddingX.getSegment(), 0L, (long) dim * bytes); + } + }; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java new file mode 100644 index 00000000..5cc01685 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.fp16; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.mistral.MistralConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.MistralFP16FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class MistralFP16PlanComponents implements SingleTokenForwardPlanComponents { + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final MistralConfiguration config; + private final SchedulerType schedulerType; + + public MistralFP16PlanComponents(LlamaState state, Model model) { + this.state = state; + this.config = (MistralConfiguration) model.configuration(); + this.weights = (LlamaTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new MistralFP16FFNLayers("mistralFFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java new file mode 100644 index 00000000..b98eb47d --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.fp16; + +import org.beehive.gpullama3.inference.state.Phi3State; +import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.phi3.Phi3Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.Phi3FP16FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Phi3FP16PlanComponents implements SingleTokenForwardPlanComponents { + + private final Phi3State state; + private final Phi3TornadoWeights weights; + private final Phi3Configuration config; + private final SchedulerType schedulerType; + + public Phi3FP16PlanComponents(Phi3State state, Model model) { + this.state = state; + this.config = (Phi3Configuration) model.configuration(); + this.weights = (Phi3TornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new Phi3FP16FFNLayers("phi3FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java new file mode 100644 index 00000000..2eaae7a2 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.fp16; + +import org.beehive.gpullama3.inference.state.Qwen2State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen2FP16FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Qwen2FP16PlanComponents implements SingleTokenForwardPlanComponents { + + private final Qwen2State state; + private final Qwen2TornadoWeights weights; + private final Qwen2Configuration config; + private final SchedulerType schedulerType; + + public Qwen2FP16PlanComponents(Qwen2State state, Model model) { + this.state = state; + this.config = (Qwen2Configuration) model.configuration(); + this.weights = (Qwen2TornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new Qwen2FP16FFNLayers("qwen2FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java new file mode 100644 index 00000000..2ed13ab1 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.fp16; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen3FP16FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Qwen3FP16PlanComponents implements SingleTokenForwardPlanComponents { + + private final Qwen3State state; + private final Qwen3TornadoWeights weights; + private final Qwen3Configuration config; + private final SchedulerType schedulerType; + + public Qwen3FP16PlanComponents(Qwen3State state, Model model) { + this.state = state; + this.config = (Qwen3Configuration) model.configuration(); + this.weights = (Qwen3TornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new Qwen3FP16FFNLayers("qwen3FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java new file mode 100644 index 00000000..8302f7cb --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.DevstralState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.devstral.DevstralConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.DevstralQ8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class DevstralQ8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final DevstralState state; + private final LlamaTornadoWeights weights; + private final DevstralConfiguration config; + private final SchedulerType schedulerType; + + public DevstralQ8_0PlanComponents(DevstralState state, Model model) { + this.state = state; + this.config = (DevstralConfiguration) model.configuration(); + this.weights = (LlamaTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new DevstralQ8_0FFNLayers("devstralFFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java new file mode 100644 index 00000000..93f01f10 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.GraniteState; +import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.GraniteQ8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsGraniteQ8_0Layer; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class GraniteQ8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final GraniteState state; + private final GraniteTornadoWeights weights; + private final GraniteConfiguration config; + private final SchedulerType schedulerType; + + public GraniteQ8_0PlanComponents(GraniteState state, Model model) { + this.state = state; + this.config = (GraniteConfiguration) model.configuration(); + this.weights = (GraniteTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new ActivationGranite("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new GraniteQ8_0FFNLayers("graniteFFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsGraniteQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java new file mode 100644 index 00000000..bd82a767 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java @@ -0,0 +1,131 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LlamaQ8_0FFNLayersDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LlamaQ8_0FFNLayersPrefillDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LogitsQ8_0LayerDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.prefill.LlamaQ8_0LayersBatchPrefill; +import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.EmbeddingPreparer; +import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchDecodeActivation; +import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchPrefillActivation; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; + +import java.lang.foreign.MemorySegment; + +/** + * {@link BatchPrefillDecodeForwardPlanComponents} for Llama + Q8_0. + * + *Batch embedding prep: CPU dequantizes Q8_0 embeddings into {@code wrapXBatch} (FP32). + * Decode embedding prep: raw Q8_0 block copy into {@code embeddingX} for on-device conversion.
+ */ +public class LlamaQ8_0PlanComponents implements BatchPrefillDecodeForwardPlanComponents { + + private static final int BLOCK_SIZE = 32; + private static final int Q8_0_BLOCK_BYTES = 34; + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final LlamaConfiguration config; + private final SchedulerType schedulerType; + + public LlamaQ8_0PlanComponents(LlamaState state, Model model) { + this.state = state; + this.config = (LlamaConfiguration) model.configuration(); + this.weights = (LlamaTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + // ── Activations ─────────────────────────────────────────────────────────── + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public ActivationGraph decodeActivation() { + return new Activation("decodeActivation", state, weights, config); + } + + @Override public ActivationGraph batchPrefillActivation(int batchSize) { + return new BatchPrefillActivation(state, config, batchSize, true); + } + + @Override public ActivationGraph batchDecodeActivation(String lastBatchLayerId) { + return new BatchDecodeActivation(state, config, lastBatchLayerId, true); + } + + // ── FFN layer groups ────────────────────────────────────────────────────── + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new LlamaQ8_0FFNLayers("llamaFFN", state, weights, config, schedulerType); + } + + @Override public TransformerLayerTaskGraphs prefillDecodeLayers() { + return new LlamaQ8_0FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); + } + + @Override public TransformerLayerTaskGraphs batchDecodeLayers() { + return new LlamaQ8_0FFNLayersDecode("decode", state, weights, config, schedulerType); + } + + @Override public BatchPrefillTransformerLayerTaskGraphs batchPrefillLayers(int batchSize) { + return new LlamaQ8_0LayersBatchPrefill(state, weights, config, batchSize); + } + + // ── Logits layers ───────────────────────────────────────────────────────── + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } + + @Override public AbstractLogitsLayer decodeLogits(String previousGraphId) { + return new LogitsQ8_0LayerDecode("logits", state, weights, config, previousGraphId, schedulerType); + } + + // ── Embedding preparation ───────────────────────────────────────────────── + + @Override public EmbeddingPreparer embeddingPreparer() { + ByteArray embTable = weights.getTokenEmbeddingTable().asByteArray(); + int dim = config.dim(); + int blocksPerRow = (dim + BLOCK_SIZE - 1) / BLOCK_SIZE; + long bytesPerToken = (long) blocksPerRow * Q8_0_BLOCK_BYTES; + + return new EmbeddingPreparer() { + @Override + public void initBatchState() { + state.wrapXBatch.clear(); + state.batchStartPosHolder.init(0); + } + @Override + public void copyBatchEmbeddings(int[] tokenIds, int startPos, int chunkSize) { + state.batchStartPosHolder.set(0, startPos); + for (int b = 0; b < chunkSize; b++) { + int tokenId = tokenIds[b]; + for (int j = 0; j < dim; j++) { + int blockByteOffset = (tokenId * blocksPerRow + j / BLOCK_SIZE) * Q8_0_BLOCK_BYTES; + float scale = embTable.getHalfFloat(blockByteOffset).getFloat32(); + float quant = embTable.get(blockByteOffset + 2 + j % BLOCK_SIZE); + state.wrapXBatch.set(b * dim + j, quant * scale); + } + } + } + @Override + public void copyDecodeEmbedding(int token) { + MemorySegment.copy(embTable.getSegment(), token * bytesPerToken, + state.embeddingX.getSegment(), 0L, bytesPerToken); + } + }; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java new file mode 100644 index 00000000..4b8a78ec --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.mistral.MistralConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.MistralQ8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class MistralQ8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final MistralConfiguration config; + private final SchedulerType schedulerType; + + public MistralQ8_0PlanComponents(LlamaState state, Model model) { + this.state = state; + this.config = (MistralConfiguration) model.configuration(); + this.weights = (LlamaTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new MistralQ8_0FFNLayers("mistralFFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java new file mode 100644 index 00000000..57cd9508 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.Phi3State; +import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.phi3.Phi3Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Phi3Q8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Phi3Q8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final Phi3State state; + private final Phi3TornadoWeights weights; + private final Phi3Configuration config; + private final SchedulerType schedulerType; + + public Phi3Q8_0PlanComponents(Phi3State state, Model model) { + this.state = state; + this.config = (Phi3Configuration) model.configuration(); + this.weights = (Phi3TornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new Phi3Q8_0FFNLayers("phi3FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java new file mode 100644 index 00000000..534eedc5 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.Qwen2State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen2Q8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Qwen2Q8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final Qwen2State state; + private final Qwen2TornadoWeights weights; + private final Qwen2Configuration config; + private final SchedulerType schedulerType; + + public Qwen2Q8_0PlanComponents(Qwen2State state, Model model) { + this.state = state; + this.config = (Qwen2Configuration) model.configuration(); + this.weights = (Qwen2TornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new Qwen2Q8_0FFNLayers("qwen2FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java new file mode 100644 index 00000000..3c615b39 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen3Q8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Qwen3Q8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final Qwen3State state; + private final Qwen3TornadoWeights weights; + private final Qwen3Configuration config; + private final SchedulerType schedulerType; + + public Qwen3Q8_0PlanComponents(Qwen3State state, Model model) { + this.state = state; + this.config = (Qwen3Configuration) model.configuration(); + this.weights = (Qwen3TornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new Qwen3Q8_0FFNLayers("qwen3FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/BatchPrefillDecodeForwardTaskGraphLayout.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/BatchPrefillDecodeForwardTaskGraphLayout.java new file mode 100644 index 00000000..aec212bd --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/BatchPrefillDecodeForwardTaskGraphLayout.java @@ -0,0 +1,21 @@ +package org.beehive.gpullama3.tornadovm.plan.layout; + +/** + * Graph-index arithmetic for the 2N+3 batch-prefill/decode forward plan. + * + *
+ * [0] batchPrefillActivation
+ * [1..N] batchPrefillLayer_0 .. batchPrefillLayer_{N-1}
+ * [N+1] decodeActivation (consumes + re-persists KV cache)
+ * [N+2..2N+1] decodeLayer_0 .. decodeLayer_{N-1}
+ * [2N+2] logits
+ *
+ */
+public record BatchPrefillDecodeForwardTaskGraphLayout(int N) {
+ public int batchActivationIdx() { return 0; }
+ public int batchLayerIdx(int i) { return 1 + i; }
+ public int decodeActivationIdx() { return N + 1; }
+ public int decodeLayerIdx(int i) { return N + 2 + i; }
+ public int logitsIdx() { return 2 * N + 2; }
+ public int totalGraphs() { return 2 * N + 3; }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/PrefillDecodeForwardTaskGraphLayout.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/PrefillDecodeForwardTaskGraphLayout.java
new file mode 100644
index 00000000..c252c92d
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/PrefillDecodeForwardTaskGraphLayout.java
@@ -0,0 +1,17 @@
+package org.beehive.gpullama3.tornadovm.plan.layout;
+
+/**
+ * Graph-index arithmetic for the N+2 prefill/decode forward plan.
+ *
+ *
+ * [0] decodeActivation
+ * [1..N] layer_0 .. layer_{N-1}
+ * [N+1] logits
+ *
+ */
+public record PrefillDecodeForwardTaskGraphLayout(int N) {
+ public int activationIdx() { return 0; }
+ public int layerIdx(int i) { return 1 + i; }
+ public int logitsIdx() { return N + 1; }
+ public int totalGraphs() { return N + 2; }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/SingleTokenForwardTaskGraphLayout.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/SingleTokenForwardTaskGraphLayout.java
new file mode 100644
index 00000000..d8fe8d13
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/SingleTokenForwardTaskGraphLayout.java
@@ -0,0 +1,17 @@
+package org.beehive.gpullama3.tornadovm.plan.layout;
+
+/**
+ * Graph-index arithmetic for the N+2 single-token forward plan.
+ *
+ *
+ * [0] activation
+ * [1..N] layer_0 .. layer_{N-1}
+ * [N+1] logits
+ *
+ */
+public record SingleTokenForwardTaskGraphLayout(int N) {
+ public int activationIdx() { return 0; }
+ public int layerIdx(int i) { return 1 + i; }
+ public int logitsIdx() { return N + 1; }
+ public int totalGraphs() { return N + 2; }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerDetectionService.java b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerDetectionService.java
similarity index 93%
rename from src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerDetectionService.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerDetectionService.java
index 5a81caa8..5e29c528 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerDetectionService.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerDetectionService.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.strategy;
+package org.beehive.gpullama3.tornadovm.scheduling;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.ModelType;
@@ -9,7 +9,6 @@
public class SchedulerDetectionService {
-
public static SchedulerType determineSchedulerType(Model model) {
TornadoRuntime tornadoRuntime = TornadoRuntimeProvider.getTornadoRuntime();
String platformName = tornadoRuntime.getBackend(0)
@@ -24,4 +23,4 @@ public static SchedulerType determineSchedulerType(Model model) {
return (isNvidia && isNotMistral) ? SchedulerType.NVIDIA : SchedulerType.NON_NVIDIA;
}
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerType.java b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerType.java
new file mode 100644
index 00000000..bbd27169
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerType.java
@@ -0,0 +1,5 @@
+package org.beehive.gpullama3.tornadovm.scheduling;
+
+public enum SchedulerType {
+ NVIDIA, NON_NVIDIA
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/WorkerGridFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/WorkerGridFactory.java
similarity index 98%
rename from src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/WorkerGridFactory.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/scheduling/WorkerGridFactory.java
index af39c133..4d5a55a5 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/WorkerGridFactory.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/WorkerGridFactory.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.tornadovm.layerplanner;
+package org.beehive.gpullama3.tornadovm.scheduling;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
@@ -98,5 +98,4 @@ private static int findOptimalLocalSize(int size) {
}
return optimal;
}
-
}
From 4e4478af24a77b80111618f7e76050f65d44b33f Mon Sep 17 00:00:00 2001
From: Orion Papadakis Implemented by {@link Activation} and custom activation wrappers used by * {@link org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents}.
*/ -public interface ActivationGraph { +public interface ActivationTaskGraph { ImmutableTaskGraph getImmutableTaskGraph(); GridScheduler updateGridScheduler(GridScheduler scheduler); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java index 1246265c..55202738 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java @@ -2,7 +2,7 @@ import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents; @@ -42,7 +42,7 @@ public BatchPrefillDecodeForwardPlan(Model model, BatchPrefillDecodeForwardPlanC ListImplemented by {@link AbstractFFNLayers} and its subclasses.
+ *Implemented by {@link AbstractTransformerLayerTaskGraphs} and its subclasses.
*/ public interface TransformerLayerTaskGraphs { ListOverrides data-transfer declarations so that all cross-graph boundaries use diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java index c1167796..f3f628b3 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java @@ -9,7 +9,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Decode FFN layers for the single-token prefill/decode plan + * Decode transformer-layer task graphs for the single-token prefill/decode plan * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode}). * *
Combines two concerns:
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java index 44ce4b35..dcfe626e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java @@ -17,7 +17,7 @@ import java.util.stream.IntStream; /** - * Prefill FFN layers with batching for the unified batched prefill-decode plan + * Batched-prefill transformer-layer task graphs for the unified batched prefill-decode plan * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}). * *One {@link ImmutableTaskGraph} per transformer layer, each processing
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java
index 26686738..b8a9fafa 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java
@@ -13,7 +13,7 @@
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
- * Q8_0 FFN layers for Devstral 2 models.
+ * Q8_0 transformer-layer task graphs for Devstral 2 models.
* Uses precomputed RoPE frequencies (YaRN scaling) instead of on-the-fly computation.
*/
public class DevstralQ8_0FFNLayers extends AbstractTransformerLayerTaskGraphs Layer 0 consumes the KV cache from device (passed through by the decode activation
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java
index 388853f7..6c4c22b6 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java
@@ -8,7 +8,7 @@
import uk.ac.manchester.tornado.api.TaskGraph;
/**
- * Decode FFN layers for the single-token prefill/decode plan
+ * Decode transformer-layer task graphs for the single-token prefill/decode plan
* ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode}).
*
* Layer 0 delegates to {@link LlamaQ8_0FFNLayers#configureLayerDataTransfers} which
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java
index 14c30462..c0904ef3 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java
@@ -17,7 +17,7 @@
import java.util.stream.IntStream;
/**
- * Prefill FFN layers with batching for the unified batched prefill-decode plan (Q8_0).
+ * Batched-prefill transformer-layer task graphs for the unified batched prefill-decode plan (Q8_0).
*
* Mirrors {@link org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill}
* but uses Q8_0 kernels with inline dequantization. Key differences from the FP16 path: Delegates the full chunk to
- * {@link TornadoVMMasterPlanBatchPrefillDecode#tornadoVMForwardBatchPrefill},
- * which handles embedding lookup and GPU execution internally. Copies {@code chunkSize} token embeddings into device-visible state buffers,
+ * then delegates graph execution to the plan. Delegates a single-token decode step to
- * {@link TornadoVMMasterPlanBatchPrefillDecode#tornadoVMForwardDecode},
- * which copies the token embedding and runs the decode + logits graphs. Copies the token embedding into device-visible state, then delegates
+ * graph execution to the plan. Concrete implementations are format-specific (FP16 byte copy vs Q8_0 CPU dequantization)
- * and live in the component-provider classes.
Graph layout:
*- * [0] batchPrefillActivation - * [1..N] batch-prefill transformer layers - * [N+1] decodeActivation (consumes + re-persists KV cache) - * [N+2..2N+1] decode transformer layers - * [2N+2] logits + * [0] batch activation ← batchPrefillActivation(int) + * [1..N] batch layers ← batchPrefillTransformerLayers(int) + * [N+1] decode activation ← batchDecodeActivation(String) + * [N+2..2N+1] decode layers ← batchDecodeTransformerLayers() + * [2N+2] logits ← decodeLogits(String) ** *
During batch prefill, the master plan executes graphs 0..N. @@ -43,7 +43,7 @@ public BatchPrefillDecodeForwardPlan(Model model, BatchPrefillDecodeForwardPlanC all.add(batchAct.getImmutableTaskGraph()); batchAct.updateGridScheduler(scheduler); - BatchPrefillTransformerLayerTaskGraphs batchLayers = components.batchPrefillLayers(batchSize); + BatchPrefillTransformerLayerTaskGraphs batchLayers = components.batchPrefillTransformerLayers(batchSize); all.addAll(batchLayers.getLayerImmutableTaskGraphs()); batchLayers.updateGridScheduler(scheduler); @@ -51,7 +51,7 @@ public BatchPrefillDecodeForwardPlan(Model model, BatchPrefillDecodeForwardPlanC all.add(decodeAct.getImmutableTaskGraph()); decodeAct.updateGridScheduler(scheduler); - TransformerLayerTaskGraphs decodeLayers = components.batchDecodeLayers(); + TransformerLayerTaskGraphs decodeLayers = components.batchDecodeTransformerLayers(); all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); decodeLayers.updateGridScheduler(scheduler); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java index d526080d..fa4b4cd3 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java @@ -17,9 +17,9 @@ * *
Graph layout:
*- * [0] decodeActivation - * [1..N] decode transformer layers - * [N+1] logits + * [0] activation ← prefillDecodeActivation() + * [1..N] layers ← prefillDecodeTransformerLayers() + * [N+1] logits ← decodeLogits(String) ** *
During prefill, the master plan executes graphs 0..N (skipping logits).
@@ -36,11 +36,11 @@ public PrefillDecodeForwardPlan(Model model, PrefillDecodeForwardPlanComponents
List Graph layout:
+ * Batch-prefill/decode inference with TornadoVM is implemented by {@link TornadoVMMasterPlanBatchPrefillDecode}.
+ * It employs a {@link BatchPrefillDecodeForwardPlan} instance to represent the complete
+ * batch-prefill/decode forward operation as a chain of distinct TornadoVM TaskGraphs.
+ * The components of this chain are represented by the following components:
+ *
- * [0] activation
- * [1..N] transformer layers
- * [N+1] logits
+ * [0] activation ← singleTokenActivation()
+ * [1..N] layers ← singleTokenTransformerLayers()
+ * [N+1] logits ← singleTokenLogits(String)
*
*/
public class SingleTokenForwardPlan extends ForwardPlan {
@@ -33,15 +33,15 @@ public SingleTokenForwardPlan(Model model, SingleTokenForwardPlanComponents comp
List
+ *
+ *
Extends {@link PrefillDecodeForwardPlanComponents} with the batch-prefill - * activation, batch layer group, KV-cache decode activation, and the - * host-side embedding preparer.
+ * Note: Consult also the {@link org.beehive.gpullama3.tornadovm.plan.layout.BatchPrefillDecodeForwardTaskGraphLayout} */ public interface BatchPrefillDecodeForwardPlanComponents extends PrefillDecodeForwardPlanComponents { @@ -17,8 +31,8 @@ public interface BatchPrefillDecodeForwardPlanComponents extends PrefillDecodeFo ActivationTaskGraph batchDecodeActivation(String lastBatchLayerId); - TransformerLayerTaskGraphs batchDecodeLayers(); + TransformerLayerTaskGraphs batchDecodeTransformerLayers(); - BatchPrefillTransformerLayerTaskGraphs batchPrefillLayers(int batchSize); + BatchPrefillTransformerLayerTaskGraphs batchPrefillTransformerLayers(int batchSize); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java index dad27f94..2d6bcefc 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java @@ -1,20 +1,34 @@ package org.beehive.gpullama3.tornadovm.plan.components; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.plan.PrefillDecodeForwardPlan; /** - * Components for the N+2 prefill/decode forward plan. + * The necessary components that any model+quantization combination + * should implement to support *prefill-decode inference*. + *+ * Prefill-decode inference with TornadoVM is implemented by {@link TornadoVMMasterPlanPrefillDecode}. + * It employs a {@link PrefillDecodeForwardPlan} instance to represent the complete + * prefill-decode forward operation as a chain of distinct TornadoVM TaskGraphs. + * The components of this chain are represented by the following components: + *
Extends {@link SingleTokenForwardPlanComponents} with the decode-phase - * activation, KV-cache-aware layer group, and decode logits.
*/ public interface PrefillDecodeForwardPlanComponents extends SingleTokenForwardPlanComponents { - ActivationTaskGraph decodeActivation(); + ActivationTaskGraph prefillDecodeActivation(); - TransformerLayerTaskGraphs prefillDecodeLayers(); + TransformerLayerTaskGraphs prefillDecodeTransformerLayers(); AbstractLogitsTaskGraph decodeLogits(String previousGraphId); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java index 596156f4..0f93710c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java @@ -1,21 +1,33 @@ package org.beehive.gpullama3.tornadovm.plan.components; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanSingleToken; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.plan.SingleTokenForwardPlan; /** - * Components for the single-token forward pass. + * The necessary components that any model+quantization combination + * should implement to support *single-token inference*. + *+ * Single-token inference with TornadoVM is implemented by {@link TornadoVMMasterPlanSingleToken}. + * It employees a {@link SingleTokenForwardPlan} instance to represent the complete + * single-token forward operation as a chain of distinct TornadoVM TaskGraphs. + * The components of this chain are represented by the following components: + *
All model+quantization combinations implement this interface. - * Models that support prefill/decode modes implement the richer - * {@link PrefillDecodeForwardPlanComponents} or {@link BatchPrefillDecodeForwardPlanComponents}.
+ * Note: Consult also the {@link org.beehive.gpullama3.tornadovm.plan.layout.SingleTokenForwardTaskGraphLayout} */ public interface SingleTokenForwardPlanComponents { - ActivationTaskGraph standardActivation(); + ActivationTaskGraph singleTokenActivation(); - TransformerLayerTaskGraphs standardLayers(); + TransformerLayerTaskGraphs singleTokenTransformerLayers(); - AbstractLogitsTaskGraph standardLogits(String previousGraphId); + AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java index 00f49f28..d0817f7e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java @@ -28,15 +28,15 @@ public DevstralFP16PlanComponents(DevstralState state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationTaskGraph standardActivation() { + @Override public ActivationTaskGraph singleTokenActivation() { return new Activation("activationUpdate", state, weights, config); } - @Override public TransformerLayerTaskGraphs standardLayers() { + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new DevstralFP16FFNLayers("devstralFFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java index 5f47ec94..88e741fc 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java @@ -28,15 +28,15 @@ public GraniteFP16PlanComponents(GraniteState state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationTaskGraph standardActivation() { + @Override public ActivationTaskGraph singleTokenActivation() { return new ActivationGranite("activationUpdate", state, weights, config); } - @Override public TransformerLayerTaskGraphs standardLayers() { + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new GraniteFP16FFNLayers("graniteFFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsGraniteFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java index 22e6ef9d..8055c149 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java @@ -43,11 +43,11 @@ public LlamaFP16PlanComponents(LlamaState state, Model model) { // ── Activations ─────────────────────────────────────────────────────────── - @Override public ActivationTaskGraph standardActivation() { + @Override public ActivationTaskGraph singleTokenActivation() { return new Activation("activationUpdate", state, weights, config); } - @Override public ActivationTaskGraph decodeActivation() { + @Override public ActivationTaskGraph prefillDecodeActivation() { return new Activation("decodeActivation", state, weights, config); } @@ -61,25 +61,25 @@ public LlamaFP16PlanComponents(LlamaState state, Model model) { // ── Transformer layer task graphs ────────────────────────────────────────────────────── - @Override public TransformerLayerTaskGraphs standardLayers() { + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new LlamaFP16FFNLayers("llamaFFN", state, weights, config, schedulerType); } - @Override public TransformerLayerTaskGraphs prefillDecodeLayers() { + @Override public TransformerLayerTaskGraphs prefillDecodeTransformerLayers() { return new LlamaFP16FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); } - @Override public TransformerLayerTaskGraphs batchDecodeLayers() { + @Override public TransformerLayerTaskGraphs batchDecodeTransformerLayers() { return new LlamaFP16FFNLayersDecode("decode", state, weights, config, schedulerType); } - @Override public BatchPrefillTransformerLayerTaskGraphs batchPrefillLayers(int batchSize) { + @Override public BatchPrefillTransformerLayerTaskGraphs batchPrefillTransformerLayers(int batchSize) { return new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize); } // ── Logits layers ───────────────────────────────────────────────────────── - @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java index fdd4d43f..7e648070 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java @@ -28,15 +28,15 @@ public MistralFP16PlanComponents(LlamaState state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationTaskGraph standardActivation() { + @Override public ActivationTaskGraph singleTokenActivation() { return new Activation("activationUpdate", state, weights, config); } - @Override public TransformerLayerTaskGraphs standardLayers() { + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new MistralFP16FFNLayers("mistralFFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java index fb21c4c2..f8c56759 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java @@ -28,15 +28,15 @@ public Phi3FP16PlanComponents(Phi3State state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationTaskGraph standardActivation() { + @Override public ActivationTaskGraph singleTokenActivation() { return new Activation("activationUpdate", state, weights, config); } - @Override public TransformerLayerTaskGraphs standardLayers() { + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new Phi3FP16FFNLayers("phi3FFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java index 03fc3c88..2bba98f7 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java @@ -28,15 +28,15 @@ public Qwen2FP16PlanComponents(Qwen2State state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationTaskGraph standardActivation() { + @Override public ActivationTaskGraph singleTokenActivation() { return new Activation("activationUpdate", state, weights, config); } - @Override public TransformerLayerTaskGraphs standardLayers() { + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new Qwen2FP16FFNLayers("qwen2FFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java index a11270bd..59cd920d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java @@ -28,15 +28,15 @@ public Qwen3FP16PlanComponents(Qwen3State state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationTaskGraph standardActivation() { + @Override public ActivationTaskGraph singleTokenActivation() { return new Activation("activationUpdate", state, weights, config); } - @Override public TransformerLayerTaskGraphs standardLayers() { + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new Qwen3FP16FFNLayers("qwen3FFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java index 8589ee05..bae1b073 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java @@ -28,15 +28,15 @@ public DevstralQ8_0PlanComponents(DevstralState state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationTaskGraph standardActivation() { + @Override public ActivationTaskGraph singleTokenActivation() { return new Activation("activationUpdate", state, weights, config); } - @Override public TransformerLayerTaskGraphs standardLayers() { + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new DevstralQ8_0FFNLayers("devstralFFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java index 03e0ccb9..67f4aca3 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java @@ -28,15 +28,15 @@ public GraniteQ8_0PlanComponents(GraniteState state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationTaskGraph standardActivation() { + @Override public ActivationTaskGraph singleTokenActivation() { return new ActivationGranite("activationUpdate", state, weights, config); } - @Override public TransformerLayerTaskGraphs standardLayers() { + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new GraniteQ8_0FFNLayers("graniteFFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsGraniteQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java index 6c561a4c..eb55e39b 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java @@ -47,11 +47,11 @@ public LlamaQ8_0PlanComponents(LlamaState state, Model model) { // ── Activations ─────────────────────────────────────────────────────────── - @Override public ActivationTaskGraph standardActivation() { + @Override public ActivationTaskGraph singleTokenActivation() { return new Activation("activationUpdate", state, weights, config); } - @Override public ActivationTaskGraph decodeActivation() { + @Override public ActivationTaskGraph prefillDecodeActivation() { return new Activation("decodeActivation", state, weights, config); } @@ -65,25 +65,25 @@ public LlamaQ8_0PlanComponents(LlamaState state, Model model) { // ── Transformer layer task graphs ────────────────────────────────────────────────────── - @Override public TransformerLayerTaskGraphs standardLayers() { + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new LlamaQ8_0FFNLayers("llamaFFN", state, weights, config, schedulerType); } - @Override public TransformerLayerTaskGraphs prefillDecodeLayers() { + @Override public TransformerLayerTaskGraphs prefillDecodeTransformerLayers() { return new LlamaQ8_0FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); } - @Override public TransformerLayerTaskGraphs batchDecodeLayers() { + @Override public TransformerLayerTaskGraphs batchDecodeTransformerLayers() { return new LlamaQ8_0FFNLayersDecode("decode", state, weights, config, schedulerType); } - @Override public BatchPrefillTransformerLayerTaskGraphs batchPrefillLayers(int batchSize) { + @Override public BatchPrefillTransformerLayerTaskGraphs batchPrefillTransformerLayers(int batchSize) { return new LlamaQ8_0LayersBatchPrefill(state, weights, config, batchSize); } // ── Logits layers ───────────────────────────────────────────────────────── - @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java index d4d91993..025dd3a7 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java @@ -28,15 +28,15 @@ public MistralQ8_0PlanComponents(LlamaState state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationTaskGraph standardActivation() { + @Override public ActivationTaskGraph singleTokenActivation() { return new Activation("activationUpdate", state, weights, config); } - @Override public TransformerLayerTaskGraphs standardLayers() { + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new MistralQ8_0FFNLayers("mistralFFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java index 7974a3fe..2d91aa9b 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java @@ -28,15 +28,15 @@ public Phi3Q8_0PlanComponents(Phi3State state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationTaskGraph standardActivation() { + @Override public ActivationTaskGraph singleTokenActivation() { return new Activation("activationUpdate", state, weights, config); } - @Override public TransformerLayerTaskGraphs standardLayers() { + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new Phi3Q8_0FFNLayers("phi3FFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java index fb739835..9dbed69f 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java @@ -28,15 +28,15 @@ public Qwen2Q8_0PlanComponents(Qwen2State state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationTaskGraph standardActivation() { + @Override public ActivationTaskGraph singleTokenActivation() { return new Activation("activationUpdate", state, weights, config); } - @Override public TransformerLayerTaskGraphs standardLayers() { + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new Qwen2Q8_0FFNLayers("qwen2FFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java index 8ef1881d..792e889b 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java @@ -28,15 +28,15 @@ public Qwen3Q8_0PlanComponents(Qwen3State state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationTaskGraph standardActivation() { + @Override public ActivationTaskGraph singleTokenActivation() { return new Activation("activationUpdate", state, weights, config); } - @Override public TransformerLayerTaskGraphs standardLayers() { + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new Qwen3Q8_0FFNLayers("qwen3FFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); } } From 7f1864bac956d3ea5d8e2c785b8afcf409add9dd Mon Sep 17 00:00:00 2001 From: Orion PapadakisA single {@link TornadoExecutionPlan} holds all task graphs for + *
A single {@link TornadoExecutionPlan} holds all TaskGraphs for * batched prefill and single-token decode phases:
* *TaskGraph layout (2N+3 TaskGraphs total):
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java index f73a55a3..7c746814 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java @@ -8,7 +8,7 @@ import uk.ac.manchester.tornado.api.TaskGraph; /** - * Abstract base class for activation, transformer-layer, and logits task graphs. + * Abstract base class for activation, transformer-layer, and logits TaskGraphs. */ public abstract class AbstractLayer { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java index 61bfa573..31adc154 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java @@ -11,7 +11,7 @@ /** * Granite-specific activation: applies an embedding scale factor during the FP32 conversion. - * Overrides only the task graph builder; all other behaviour is inherited from Activation. + * Overrides only the TaskGraph builder; all other behaviour is inherited from Activation. */ public class ActivationGranite extends Activation { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java index 7b709797..851d2940 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java @@ -14,7 +14,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * FP16 transformer-layer task graphs for Devstral 2 models. + * FP16 transformer-layer TaskGraphs for Devstral 2 models. * Uses precomputed RoPE frequencies (YaRN scaling) instead of on-the-fly computation. */ public class DevstralFP16FFNLayers extends AbstractTransformerLayerTaskGraphsThe no-arg form is safe in CUDA-graph mode (device pointers are frozen at capture time) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index b06fb942..2839bc5e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -14,7 +14,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Phi3FP16FFNLayers: FP16 transformer-layer task graphs for Phi3 with Group Query Attention (GQA) support. + * Phi3FP16FFNLayers: FP16 transformer-layer TaskGraphs for Phi3 with Group Query Attention (GQA) support. * * Key Differences from Qwen2/Qwen3: - Uses combined QKV matrix (wqkv) instead of separate Q, K, V matrices - Includes splitQKV task to separate combined buffer - Uses ropeRotationPhi3 kernel for * position embeddings - FFN uses single wUp matrix that outputs both Gate and Up (2 * hiddenDim) - Includes splitGateUpAndSiLU task for FFN activation - Uses wDown for final FFN projection - No Q, K, diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index 4b3bbca3..b75a7f70 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -17,7 +17,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Qwen2FP16FFNLayers: FP16 transformer-layer task graphs for Qwen2 with Group Query Attention (GQA) support. + * Qwen2FP16FFNLayers: FP16 transformer-layer TaskGraphs for Qwen2 with Group Query Attention (GQA) support. * * Key Differences from Qwen3: - No tempQcur/tempKcur fields in Qwen2State - Includes bias terms for Q, K, V projections - Standard GQA (no parallel offset RMSNorm) - Uses * Qwen2Kernels::processHeadsFlashAttention for attention computation - Uses Qwen3Kernels::ropeRotation for position embeddings - Simpler matrix dimensions (uses config.dim() and config.kvDim() diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index d0bee6a9..0e8b21b0 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -14,7 +14,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Qwen3FP16FFNLayers: FP16 transformer-layer task graphs for Qwen3 with Group Query Attention (GQA) support. + * Qwen3FP16FFNLayers: FP16 transformer-layer TaskGraphs for Qwen3 with Group Query Attention (GQA) support. * * Key Differences from Llama: - Supports GQA with separate KV heads (nHeadKv) - Uses Qwen3Kernels for RMSNorm with parallel offset - Custom RoPE rotation for Qwen3 - Different attention computation * due to GQA structure diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java index 28d8a1f6..a8010433 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java @@ -9,7 +9,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Decode transformer-layer task graphs of the unified batched prefill-decode plan + * Decode transformer-layer TaskGraphs of the unified batched prefill-decode plan * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}). * *
Overrides data-transfer declarations so that all cross-graph boundaries use diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java index f3f628b3..4a74e69b 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java @@ -9,7 +9,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Decode transformer-layer task graphs for the single-token prefill/decode plan + * Decode transformer-layer TaskGraphs for the single-token prefill/decode plan * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode}). * *
Combines two concerns:
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java index dcfe626e..6f0b3e4b 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java @@ -17,7 +17,7 @@ import java.util.stream.IntStream; /** - * Batched-prefill transformer-layer task graphs for the unified batched prefill-decode plan + * Batched-prefill transformer-layer TaskGraphs for the unified batched prefill-decode plan * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}). * *One {@link ImmutableTaskGraph} per transformer layer, each processing
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java
index b8a9fafa..509f37db 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java
@@ -13,7 +13,7 @@
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
- * Q8_0 transformer-layer task graphs for Devstral 2 models.
+ * Q8_0 transformer-layer TaskGraphs for Devstral 2 models.
* Uses precomputed RoPE frequencies (YaRN scaling) instead of on-the-fly computation.
*/
public class DevstralQ8_0FFNLayers extends AbstractTransformerLayerTaskGraphs Layer 0 consumes the KV cache from device (passed through by the decode activation
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java
index 6c4c22b6..d189b6a4 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java
@@ -8,7 +8,7 @@
import uk.ac.manchester.tornado.api.TaskGraph;
/**
- * Decode transformer-layer task graphs for the single-token prefill/decode plan
+ * Decode transformer-layer TaskGraphs for the single-token prefill/decode plan
* ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode}).
*
* Layer 0 delegates to {@link LlamaQ8_0FFNLayers#configureLayerDataTransfers} which
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java
index c0904ef3..164aead2 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java
@@ -17,7 +17,7 @@
import java.util.stream.IntStream;
/**
- * Batched-prefill transformer-layer task graphs for the unified batched prefill-decode plan (Q8_0).
+ * Batched-prefill transformer-layer TaskGraphs for the unified batched prefill-decode plan (Q8_0).
*
* Mirrors {@link org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill}
* but uses Q8_0 kernels with inline dequantization. Key differences from the FP16 path:
*
*