KV Cache Management for ExecuTorch
A runtime-swappable KV cache system for ExecuTorch backends, designed for single-user
desktop inference, including agent workloads. It supports plain decode and
tree/speculative (Eagle-style) decoding behind a single custom op; keeps bookkeeping shared
across backends; confines tensor work to each backend; and bounds memory without
over-allocating for short sequences.
1. Motivation & Principles
Target workload
The design target is a desktop application running one user's model locally. The workload is chat, and
increasingly agents: loops that call the model many times, fan out into parallel
sub-agents or tool calls, and carry a long shared context (system prompt, tool definitions,
a reasoning trunk) across branches. This target shapes every decision:
- Latency per step > throughput. An agent waits on every step of a multi-call chain, so
speculative/tree decoding (a per-step latency win) is first-class.
- Hardware is heterogeneous: Apple silicon (MLX) today, but also XNNPACK and Vulkan. "Support a new
backend" must be a clean operation, not a rewrite.
- Concurrency is light, not multi-tenant. A handful of live sequences sharing weights, not
thousands batched, so the cache stays simple (no scheduler) while serving agent fan-out.
- Memory is shared with the desktop. Short conversations must not reserve full context; the cache
grows lazily and bounds hard.
The result is a design that generalizes across the axes this workload exercises (decoding strategy
and backend portability) and is deliberately scoped below multi-tenant serving (see §6).
What a KV cache is
Not a key-value store: the access pattern is append-only writes at the current position + a
bulk read of history, so the cache's job is addressing and byte movement, not lookup.
Principles
Each is load-bearing, and each ties back to the target:
- Bookkeeping is logical and tensor-free: slot allocation, sequence membership, refcounts, and
length are pure integer logic, written once and shared. (A new backend reuses all of it.)
- Byte movement is backend-typed: writing/gathering KV and running attention use the backend's
tensor type and kernels (for MLX, mlx::core::array). (A new backend implements only this.)
- The cache owns its attention semantics: whether attention is causal, membership-masked, or
tree-masked is a property of the cache, not of the op (which only dispatches). (Tree is a new cache, not an op change.)
- The op is functional to the tracer; the cache is runtime state: the graph never represents
the cache, keeping the .pte decoding-strategy-agnostic. (Swap to speculative without re-export.)
- One model, swappable caches: weights load once, stateless across sessions; all per-session
state lives in the cache. (Sub-agents share weights; branches fork a prefix, §5.2.)
2. Architecture
Three tiers, separated by what they depend on:

| Tier | Location | Depends on | Contents |
|---|---|---|---|
| Runner-facing interface | core (neutral) | nothing (ints only) | Cache, SeqCache, TreeCache |
| Shared bookkeeping | core (concrete) | nothing (ints only) | StepPlan, *Bookkeeping |
| Backend byte layer | per-backend | backend tensors | MLXCacheImpl + concrete caches |

The cut follows the coupling: runner-facing control (can_extend, clear, seq_cp/seq_rm) is
neutral, while update_and_fetch is backend-typed. One concrete cache object implements both faces.
host runner ──calls──► Cache / SeqCache / TreeCache (neutral face)
│
one concrete cache object
│
op handler ──calls──► MLXCacheImpl::update_and_fetch (backend face)
│
shared *Bookkeeping (plan) + MLX byte layer (pool, kernels)
3. Ahead-of-Time (AOT)
The AOT half consists of two backend-independent pieces: the neutral custom op the model calls, and
a shared utility that writes model-dependent constants into the .pte. Together they shape the
artifact that every backend then lowers and runs.
3.1 The Torch Custom Op (cross-backend)
Two things are named update_and_attend, and they share no code: (1) the PyTorch custom op (traced
by torch.export; backend-neutral; the cross-backend contract) and (2) the serialized runtime op (a
backend node plus a handler over that backend's tensor type; the lowered form).
Definition
import torch
from torch import Tensor
@torch.library.custom_op("kvcache::update_and_attend", mutates_args=())
def update_and_attend(
q: Tensor, k: Tensor, v: Tensor,
position: Tensor, # int, one entry per query token this forward (length = q_len):
# decode [pos]; prefill [0..T); tree per-node. ([B] only in
# the special case of one token per batch row.)
layer_id: int, # constant per call site (one node per layer at export)
scale: float, # per-model attention scale (NOT derivable safely)
out_dtype: torch.dtype, # the dtype the op must output (graph-compatibility contract)
) -> Tensor:
... # eager reference implementation (also the test oracle)
@update_and_attend.register_fake
def _(q, k, v, position, layer_id, scale, out_dtype):
return q.new_empty(q.shape, dtype=out_dtype) # out_dtype drives downstream dtype-prop
Signature decisions
mutates_args=(): the op is functional to the tracer. The cache is runtime state (reached through
the registry, §4.3), not a graph tensor, so there is no functionalization and no mutable-buffer
machinery; statefulness lives behind the delegate at runtime.
layer_id: int, a per-node constant. torch.export captures it per call site (L layers give L nodes,
node N tagged layer_id=N); it is the only channel carrying layer identity within one delegated graph.
position: Tensor, one entry per query token (q_len = q.shape[2]): [pos] for decode, [0..T) for
prefill, per-node for tree. It is a tensor (not a scalar) because tree decoding needs non-contiguous
per-token positions (siblings share, branches differ) and RoPE consumes it in-graph; multi-seq /
heterogeneous batching reuse the same shape with no contract change.
scale: float, a node constant. Usually 1/sqrt(head_dim) but models override it, so not
safely derivable in the handler.
out_dtype: torch.dtype, the op's output contract, as a node constant (same channel as
scale/layer_id). The downstream graph was traced expecting it, so the op must emit it
(and the meta kernel propagates it). It is independent of KV storage precision: a quantized
cache stores int4 yet still emits e.g. fp16; storage dtype is runtime policy, output dtype is this
op contract. Common case: pass q.dtype.
No causal flag. Masking is the cache's job, returned via AttendSpec (None / Causal / Explicit).
Cache-produced masking is required for tree decoding's per-node mask, and it subsumes the
prefill-causal vs. decode-unmasked distinction.
Eager implementation
The op's eager (non-meta) implementation lets the model run in plain PyTorch. It is also the
reference oracle that every backend's runtime op is diffed against, since it implements the same
bookkeeping logic as the C++ runtime.
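The numerical core such an oracle checks is masked scaled-dot-product attention. The sketch below is a single-head, scalar C++ version under these assumptions:

- Buffers are row-major: q is [T_q, D], and k and v are [T_kv, D].
- The mask is boolean, where true means attend.
- reference_attention is a hypothetical name. It illustrates the math and does not reproduce the eager op.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical single-head reference: out[i] = sum_j softmax_j(scale * q_i . k_j over masked j) * v_j.
// q: [Tq, D], k/v: [Tkv, D], mask: [Tq, Tkv] (1 = attend), all row-major.
std::vector<float> reference_attention(const std::vector<float>& q, const std::vector<float>& k,
                                       const std::vector<float>& v, const std::vector<uint8_t>& mask,
                                       int Tq, int Tkv, int D, float scale) {
  std::vector<float> out(static_cast<size_t>(Tq) * D, 0.0f);
  std::vector<float> s(Tkv);
  for (int i = 0; i < Tq; ++i) {
    float mx = -std::numeric_limits<float>::infinity();
    for (int j = 0; j < Tkv; ++j) {  // scores for visible keys only
      if (!mask[i * Tkv + j]) continue;
      float dot = 0.0f;
      for (int d = 0; d < D; ++d) dot += q[i * D + d] * k[j * D + d];
      s[j] = scale * dot;
      mx = std::max(mx, s[j]);
    }
    assert(std::isfinite(mx) && "every query must attend at least one key");
    float denom = 0.0f;
    for (int j = 0; j < Tkv; ++j)
      if (mask[i * Tkv + j]) denom += std::exp(s[j] - mx);  // max-subtracted for stability
    for (int j = 0; j < Tkv; ++j) {
      if (!mask[i * Tkv + j]) continue;
      const float w = std::exp(s[j] - mx) / denom;  // softmax weight of key j
      for (int d = 0; d < D; ++d) out[i * D + d] += w * v[j * D + d];
    }
  }
  return out;
}
```

A backend test can feed the same q/k/v and the mask implied by the returned AttendSpec to both paths, then compare the outputs within a dtype-appropriate tolerance.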
The lowering contract
A backend's partitioner extracts q, k, v and position as tensor inputs (position is a tensor input,
a Tid, not a scalar) and layer_id / scale / out_dtype as node constants, then emits its backend
node. This is the one new thing on each backend's AOT side.
What "a backend opts in" means
Shared: the neutral op + meta + eager reference. Per-backend: (AOT) partition+lower to a node;
(runtime) a handler over its tensor type + cache impls + builder registration. The model is written
once against the neutral op; retargeting changes only the partitioner and the linked backend.
3.2 Metadata Writer
Sizing splits by who owns the value; only architecture facts belong in the .pte:

| Input | Bucket | Source |
|---|---|---|
| n_layers | architecture fact | .pte metadata |
| n_kv_heads | architecture fact | .pte metadata |
| head_dim | architecture fact | .pte metadata |
| out_dtype (attention output) | per-op output contract | op arg (§3.1), not the metadata block |
| capacity (context length) | runtime policy | caller config |
| kv_dtype (KV storage precision) | runtime policy | caller config |
| min_chunk | tuning | caller config |

The test for the .pte is whether the value is fixed by the architecture.
- Layer, head, and dim counts pass.
- Context length fails: it is a deployment choice, and RoPE imposes no hard architectural maximum. It is therefore set at runtime (§4.5).
- KV storage dtype fails: a quantized cache stores int4 regardless of the model, so it is cache policy.
- Output dtype is fixed (the next node was traced expecting it) but is per-op, so it rides as an op arg (§3.1), not metadata.

The block therefore holds architecture facts only. It is written under a versioned kv_cache.* namespace so the reader (§4.5) is model-agnostic:
|
|
# extension/llm/export, shared writer; one canonical, versioned schema.
def write_kv_cache_metadata(metadata: dict, *, n_layers, n_kv_heads, head_dim) -> dict:
metadata.update({
"kv_cache.schema_version": 1,
"kv_cache.n_layers": n_layers,
"kv_cache.n_kv_heads": n_kv_heads, # GQA-aware: KV heads, not query heads
"kv_cache.head_dim": head_dim,
})
return metadata
# fed into export_llm's `metadata` param -> serialized into the .pte's method metadata.
The schema_version makes this a versioned contract: the reader rejects or adapts an incompatible
.pte rather than misreading a renamed field. Absent: max_context_len and kv_dtype, both runtime
policy. A runtime-chosen context length requires exporting with dynamic sequence shapes
(enable_dynamic_shape) so the graph bakes in no fixed length.
4. Runtime
The runtime half (on device): the neutral interfaces, shared bookkeeping, install/rendezvous
machinery, the runner loop, and the sizing reader that turns .pte metadata + runtime policy into a
CacheConfig.
4.1 Neutral Interfaces (core)
These are the interfaces the runner depends on: tensor-free (ints only) and backend-independent.
#include <cstdint>
namespace executorch::runtime {
// Base: what every cache exposes to the host generation loop.
class Cache {
public:
virtual ~Cache() = default;
virtual bool can_extend(int n = 1) const = 0; // admission / hard-stop
virtual int capacity() const = 0; // logical cap
virtual void clear() = 0; // reset for reuse
};
// Tree/speculative adds propose (declare the candidate tree) and commit (accept a path).
class TreeCache : public Cache {
public:
virtual void propose(const int32_t* tree_parent, int num_nodes) = 0;
virtual void commit(const int32_t* accepted_nodes, int len) = 0;
};
} // namespace executorch::runtime
propose / commit live only on the TreeCache sub-interface. A decode runner holds a Cache* and never
sees them. Capability is determined by which interface the runner holds, checked at construction,
not by a runtime flag.
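Capability-by-interface can be checked once, when the runner is built. The minimal sketch below assumes the Cache / TreeCache declarations above (trimmed), a hypothetical require_face helper, and hypothetical toy caches. A speculative runner handed a plain contiguous cache fails at construction rather than mid-generation.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>

// Neutral faces, as declared above (trimmed to what this sketch uses).
struct Cache {
  virtual ~Cache() = default;
  virtual bool can_extend(int n = 1) const = 0;
};
struct TreeCache : Cache {
  virtual void propose(const int32_t* tree_parent, int num_nodes) = 0;
  virtual void commit(const int32_t* accepted_nodes, int len) = 0;
};

// Hypothetical toy caches standing in for real backend implementations.
struct ToyContiguous : Cache { bool can_extend(int) const override { return true; } };
struct ToyTree : TreeCache {
  bool can_extend(int) const override { return true; }
  void propose(const int32_t*, int) override {}
  void commit(const int32_t*, int) override {}
};

// Hypothetical helper: narrow to the face a runner needs, or fail loudly at construction.
template <class Face>
std::shared_ptr<Face> require_face(const std::shared_ptr<Cache>& c, const std::string& who) {
  auto typed = std::dynamic_pointer_cast<Face>(c);
  if (!typed) throw std::invalid_argument(who + ": cache does not implement the required face");
  return typed;
}

// A speculative runner holds a TreeCache; a decode runner would hold only a Cache.
struct SpeculativeRunner {
  std::shared_ptr<TreeCache> cache;
  explicit SpeculativeRunner(const std::shared_ptr<Cache>& c)
      : cache(require_face<TreeCache>(c, "SpeculativeRunner")) {}
};
```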
Attention semantics are produced by the cache
The op does not carry a causal flag. Instead, the read result declares how to attend to it, so a
tree cache (tree mask) or a paged cache (ragged reads) can each express its own attention
pattern:
struct AttendSpec {
mlx::core::array K, V; // backend tensors (op-facing struct)
enum Mask { None, Causal, Explicit } mask_kind;
std::optional<mlx::core::array> mask; // for Explicit: a BOOLEAN array (true = attend)
};
The three kinds map onto MLX's fast::scaled_dot_product_attention mask argument directly,
which is why this is the right seam:
None: no mask. The read is the selection (a single decode token attends the whole
prefix; a gathered multi-sequence read selects its cells).
Causal: the "causal" string (fused path, no mask tensor). MLX's "causal" is lower-right aligned
(last query ↔ last key). That is correct only if the new queries occupy the tail of the read window,
which is the cache's contract because writes append. It serves both fresh and chunked prefill
without a mask tensor.
Explicit: a boolean mask (true = attend); boolean avoids MLX's additive-mask
dtype-promotion (a float mask vs fp16 q/k/v is a correctness+perf hazard). ≤4-D, broadcast-compatible
with [B, N_heads, T_q, T_kv] (a 2-D tree mask broadcasts as [1,1,T_q,T_kv], §5.3).
AttendSpec carries backend tensors, so it lives on the backend (op-facing) side.
A cross-backend version would template on the tensor type or use the backend's own
struct; the neutral runner face never sees it.
4.2 Shared Bookkeeping (core, concrete)
Tensor-free, reused by every backend. The cache impls inherit these for all runner-facing
methods and for the per-step plan.
// Integer-only handoff from bookkeeping to the byte layer.
struct StepPlan {
bool contiguous;
int valid_len; // contiguous read bound
int write_start, write_len; // contiguous write run
};
Contiguous, a length counter
class ContiguousBookkeeping : public Cache {
int capacity_, length_ = 0;
public:
explicit ContiguousBookkeeping(int capacity) : capacity_(capacity) {}
bool can_extend(int n = 1) const override { return length_ + n <= capacity_; }
int capacity() const override { return capacity_; }
void clear() override { length_ = 0; } // rows overwritten on reuse; no byte work
// Truncate to new_len (no byte work; rows overwritten on next write), agent backtracking.
void rewind(int new_len) {
if (new_len > length_) throw std::runtime_error("rewind: cannot grow");
length_ = new_len;
}
StepPlan plan(int position, int T) {
int end = position + T;
if (end > capacity_) throw std::runtime_error("cache: exceeds capacity");
if (end > length_) length_ = end;
return StepPlan{ /*contiguous=*/true, /*valid_len=*/end,
/*write_start=*/position, /*write_len=*/T };
}
};
Tree, committed prefix + appended candidate frontiers
The tree cache appends speculative nodes after the committed prefix; each newly-appended
frontier attends the prefix plus its ancestor chain. The target (verifier) model appends the whole
tree as one frontier, verifying it in one forward; a BFS draft appends one frontier per level. Both
then commit the same accepted chain. The parent tree, ancestor mask, and commit row mapping are all
integer/bool logic; only the pool bytes and the mask wrap are backend-specific.
class TreeBookkeeping : public TreeCache {
protected:
int capacity_, committed_len_=0, frontier_start_=0;
std::vector<int32_t> parent_; // spec-node parents across frontiers; -1 = root
std::vector<uint8_t> mask_bits_; bool mask_valid_=false;
public:
void propose(const int32_t* parents, int n) override; // append a frontier (next forward's queries)
bool can_extend(int n = 1) const override { return committed_len_ + n <= capacity_; }
int capacity() const override;
void clear() override;
int base() const { return committed_len_; }
int spec_count() const { return (int)parent_.size(); }
int frontier_start() const { return frontier_start_; }
int frontier_size() const { return spec_count() - frontier_start_; }
bool active() const { return !parent_.empty(); }
void advance_prefix(int K) { committed_len_ += K; } // no-tree prefill
// [Q, L] bool ancestor mask (memoized): each frontier query attends prefix + its ancestor chain + itself.
const std::vector<uint8_t>& mask_bits() {
if (mask_valid_) return mask_bits_;
const int Q = frontier_size(), base = committed_len_, L = base + spec_count();
mask_bits_.assign(Q * L, 0);
for (int qi = 0; qi < Q; ++qi) {
const int node = frontier_start_ + qi;
for (int j = 0; j < base; ++j) mask_bits_[qi * L + j] = 1; // prefix
for (int nd = node; nd != -1; nd = parent_[nd]) mask_bits_[qi * L + (base + nd)] = 1; // ancestors + self
}
mask_valid_ = true;
return mask_bits_;
}
struct CommitPlan { int base; std::vector<int> rows; }; // rows[i] = src row for dst base+i
CommitPlan commit_plan(const int32_t* accepted, int len); // map accepted chain, advance prefix, reset
};
The backend tree cache (§5.3) inherits this and adds only the pool, the frontier write, the
[0, base+spec_count) read, the compaction gather over commit_plan().rows, and wrapping mask_bits().
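The ancestor mask and the commit row mapping are pure integer logic, so they can be exercised without any backend. The sketch below re-derives both as free functions. commit_rows is an assumption about what commit_plan computes: the accepted spec nodes, root first, become committed rows base, base+1, and so on, consistent with CommitPlan's "rows[i] = src row for dst base+i" comment. It is not the source's definition.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Ancestor mask over the read window [0, base + spec_count): each frontier query
// (spec nodes frontier_start .. spec_count-1) attends the committed prefix, its ancestors, and itself.
std::vector<uint8_t> ancestor_mask(int base, const std::vector<int32_t>& parent, int frontier_start) {
  const int S = static_cast<int>(parent.size()), L = base + S, Q = S - frontier_start;
  std::vector<uint8_t> m(static_cast<size_t>(Q) * L, 0);
  for (int qi = 0; qi < Q; ++qi) {
    for (int j = 0; j < base; ++j) m[qi * L + j] = 1;  // committed prefix
    for (int nd = frontier_start + qi; nd != -1; nd = parent[nd]) m[qi * L + base + nd] = 1;  // self + ancestors
  }
  return m;
}

// Assumed commit mapping: the accepted chain (root first) is compacted so that committed
// row base+i is read from the spec node's current row base+accepted[i].
std::vector<int> commit_rows(int base, const std::vector<int32_t>& parent,
                             const std::vector<int32_t>& accepted) {
  std::vector<int> rows;
  for (size_t i = 0; i < accepted.size(); ++i) {
    const int32_t expect_parent = i == 0 ? -1 : accepted[i - 1];
    if (parent[accepted[i]] != expect_parent)
      throw std::invalid_argument("accepted nodes are not a root-first chain");
    rows.push_back(base + accepted[i]);  // source row for destination base+i
  }
  return rows;
}
```

For a tree {root 0; children 1, 2 of 0; child 3 of 1} on a prefix of 4, accepting {0, 1, 3} yields rows {4, 5, 7}. Destination rows 4 and 5 are already in place, so the gather only moves row 7 into row 6.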
4.3 Setting / Installing the Cache
The constraint
The DelegateHandle is opaque to the host: there is no public API to reach it. So
the cache cannot be installed into the handle directly, nor read back from it. The cache
is created by the runner (which knows the kind) and reaches the delegate through a
process-global registry, with the two sides rendezvousing on a cache_key passed as
a runtime backend-load option.
Registry
// Shared TU (runner + backends). A mutex-guarded global map<string, shared_ptr<Cache>> with
// install(key,c) / get(key) / erase(key); neutral (Cache*).
class CacheRegistry { /* global() singleton; install / get / erase */ };
Ownership is shared_ptr, held by three parties: the registry entry, the runner's
session guard, and the delegate handle. The cache lives until all three release it, so
erasing the registry entry mid-method is safe.
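Here is a minimal sketch of that registry. It assumes a plain std::mutex-guarded std::unordered_map; the source only specifies install / get / erase and a global() singleton. The usage at the end shows why the three-owner scheme makes a mid-method erase safe: the delegate's own shared_ptr keeps the cache alive.

```cpp
#include <cassert>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

struct Cache { virtual ~Cache() = default; };  // neutral face (trimmed)

class CacheRegistry {
 public:
  static CacheRegistry& global() {
    static CacheRegistry instance;  // thread-safe initialization since C++11
    return instance;
  }
  void install(const std::string& key, std::shared_ptr<Cache> c) {
    std::lock_guard<std::mutex> lock(mu_);
    map_[key] = std::move(c);
  }
  std::shared_ptr<Cache> get(const std::string& key) const {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = map_.find(key);
    return it == map_.end() ? nullptr : it->second;  // a copy: the caller becomes a co-owner
  }
  void erase(const std::string& key) {
    std::lock_guard<std::mutex> lock(mu_);
    map_.erase(key);
  }

 private:
  mutable std::mutex mu_;
  std::unordered_map<std::string, std::shared_ptr<Cache>> map_;
};
```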
Builder registry + factory
Cache kind is expressed by which factory function you call: there is no kind enum
in the config and no runtime kind flag. Backends register builders per kind under their
backend_id; kind survives only as an internal lookup tag.
struct CacheConfig { int capacity, n_layers, n_kv_heads, head_dim, min_chunk; ScalarType kv_dtype; int group_size=64, bits=4; };
using CacheBuilder = std::function<std::shared_ptr<Cache>(const CacheConfig&)>;
// CacheBuilderRegistry: a mutex-guarded global map (backend_id + ":" + kind_tag) -> CacheBuilder,
// with register_builder(id, tag, b) and build(id, tag, cfg) (throws if no builder for that pair).
class CacheBuilderRegistry { /* global() singleton; register_builder; build */ };
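The builder registry can be sketched the same way, under these stated assumptions:

- CacheConfig is trimmed to two fields.
- The key format follows the backend_id + ":" + kind_tag comment above.
- ToyCache and its builder are hypothetical.

The builder is copied out under the lock, so a slow cache constructor never blocks other lookups.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct Cache { virtual ~Cache() = default; virtual int capacity() const = 0; };
struct CacheConfig { int capacity = 0, n_layers = 0; };  // trimmed for the sketch
using CacheBuilder = std::function<std::shared_ptr<Cache>(const CacheConfig&)>;

class CacheBuilderRegistry {
 public:
  static CacheBuilderRegistry& global() { static CacheBuilderRegistry r; return r; }
  void register_builder(const std::string& id, const std::string& tag, CacheBuilder b) {
    std::lock_guard<std::mutex> lock(mu_);
    builders_[id + ":" + tag] = std::move(b);
  }
  std::shared_ptr<Cache> build(const std::string& id, const std::string& tag, const CacheConfig& cfg) const {
    CacheBuilder b;
    {
      std::lock_guard<std::mutex> lock(mu_);
      auto it = builders_.find(id + ":" + tag);
      if (it == builders_.end()) throw std::runtime_error("no cache builder for " + id + ":" + tag);
      b = it->second;  // copy out so the builder runs without holding the lock
    }
    return b(cfg);
  }

 private:
  mutable std::mutex mu_;
  std::unordered_map<std::string, CacheBuilder> builders_;
};

// Hypothetical toy cache a backend might register under the "contiguous" tag.
struct ToyCache : Cache {
  int cap;
  explicit ToyCache(int c) : cap(c) {}
  int capacity() const override { return cap; }
};
```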
RAII session guard
// RAII: installs the cache into the global CacheRegistry under a unique key on construction,
// erases on destruction (no leak on any exit path); holds the runner's shared_ptr + typed face.
template <class Face> // Face = Cache | SeqCache | TreeCache
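class CacheSession { std::string key_; std::shared_ptr<Face> cache_;
public: CacheSession(std::string k, std::shared_ptr<Face> c); ~CacheSession(); // install / erase (only if key_ non-empty)
CacheSession(CacheSession&& o) noexcept; // takes key_ and leaves o.key_ empty, so o's dtor does not erase
CacheSession(const CacheSession&) = delete; CacheSession& operator=(const CacheSession&) = delete; // move-only
Face* operator->() const; };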
std::string make_unique_key(); // process-global atomic counter -> "cache-N"
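One subtlety in the session guard: make_*_session returns it inside a std::pair, so the guard must be move-only. A moved-from guard must also not erase the entry; otherwise the temporary's destructor would uninstall the cache before the Module loads.

The minimal sketch below assumes a trimmed registry (a function-local std::map without locking) and the "cache-N" key format. The key() accessor is hypothetical.

```cpp
#include <atomic>
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>

struct Cache { virtual ~Cache() = default; };
// Trimmed stand-in for CacheRegistry::global() (no locking, for brevity).
std::map<std::string, std::shared_ptr<Cache>>& registry() {
  static std::map<std::string, std::shared_ptr<Cache>> r;
  return r;
}

std::string make_unique_key() {
  static std::atomic<unsigned long long> counter{0};
  return "cache-" + std::to_string(counter.fetch_add(1));  // process-global, collision-free
}

template <class Face>
class CacheSession {
 public:
  CacheSession(std::string key, std::shared_ptr<Face> c) : key_(std::move(key)), cache_(std::move(c)) {
    registry()[key_] = cache_;  // install on construction
  }
  CacheSession(CacheSession&& o) noexcept : key_(std::move(o.key_)), cache_(std::move(o.cache_)) {
    o.key_.clear();  // the moved-from guard owns nothing and must not erase
  }
  CacheSession(const CacheSession&) = delete;
  CacheSession& operator=(const CacheSession&) = delete;
  CacheSession& operator=(CacheSession&&) = delete;
  ~CacheSession() { if (!key_.empty()) registry().erase(key_); }  // erase on every exit path
  Face* operator->() const { return cache_.get(); }
  const std::string& key() const { return key_; }

 private:
  std::string key_;
  std::shared_ptr<Face> cache_;
};
```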
Per-use-case factories
// Each factory: build(backend_id, kind_tag, cfg) -> dynamic_pointer_cast to the typed Face
// (mismatch throws at the named call) -> wrap in a CacheSession (RAII install) -> emit cache_key.
inline std::pair<CacheSession<Cache>, BackendOptions>
make_decode_session(const std::string& backend_id, const CacheConfig& cfg); // kind tag "contiguous"
// make_seq_session ("seq") and make_tree_session ("tree") are identical modulo the Face and tag.
Only cache_key goes in the option; the spec lives on the cache (no drift). Key
generation is centralized (no collisions). Install happens-before load by construction
(the session is built before the Module).
Delegate-side binding (in init)
// MLXBackend::init, after state.bind(...). Read cache_key from runtime backend options.
const char* cache_key = /* read "cache_key" from BackendOptions on the init context */;
if (cache_key && *cache_key) {
auto cache = executorch::cache::CacheRegistry::global().get(cache_key);
if (!cache) throw std::runtime_error("cache_key not in registry");
handle->cache_shared = cache; // keep-alive shared_ptr
handle->state.cache = dynamic_cast<MLXCacheImpl*>(cache.get()); // op-face
if (!handle->state.cache) throw std::runtime_error("registry cache not an MLX impl");
}
The exact backend-options accessor depends on the ExecuTorch version (whether options
reach init via BackendInitContext or via a set_option backend virtual). This is a
minor integration detail to match against the target's headers.
The op handler (MLX)
// ExecutionState carries MLXCacheImpl* cache (non-owning; reset() must NOT clear it). The node
// carries constants layer_id / scale / out_dtype.
inline void exec_update_and_attend(const UpdateAndAttendNode& n, ExecutionState& st, StreamOrDevice s) {
if (!st.cache) throw std::runtime_error("update_and_attend: no cache installed");
AttendSpec spec = st.cache->update_and_fetch(n.layer_id, st.const_tensor_ref(n.position),
st.const_tensor_ref(n.k), st.const_tensor_ref(n.v), s); // KV in storage dtype
std::string mode = (spec.mask_kind==AttendSpec::Causal) ? "causal" : ""; // fused, tail-aligned
std::optional<array> mask = (spec.mask_kind==AttendSpec::Explicit) ? spec.mask : std::nullopt; // bool
array out = fast::scaled_dot_product_attention(st.const_tensor_ref(n.q), spec.K, spec.V,
n.scale, mode, mask, std::nullopt, s);
if (out.dtype() != n.out_dtype) out = astype(out, n.out_dtype, s); // honor output contract
st.set_tensor(n.out, std::move(out));
}
The op is attention-pattern-agnostic: it applies whatever the cache produced. Install
a contiguous / seq / tree cache and the same op, same .pte, does causal decode / batched
multi-seq / tree-masked speculative decode.
4.4 The Runner (neutral)
The runner names the backend by string (ideally read from method metadata) and picks a
factory by use-case. Everything else is neutral types.
The only runtime entry point is Module::execute(method_name, inputs) -> outputs
(or, equivalently, the forward convenience). Everything the model needs (the tokens and the
position tensor, one int per query token) is passed as EValue tensor inputs to that one call;
logits come back as a tensor output. There is no execute_forward-style convenience; the runner
builds input tensors and reads output tensors itself.
auto [session, opts] = make_decode_session("MLXBackend",
{.capacity=4096, .n_layers=32, .min_chunk=512, .kv_dtype=ScalarType::Half});
Module module(model_path, opts);
module.load_method("forward"); // init reads cache_key, binds the SAME cache
// Small host helpers (not shown) wrap the real Module::execute surface: to_tensor / to_tensor_i32
// build EValue tensors, and range_i32(a, b) builds [a..b). position = one int per query token (q_len).
auto step = [&](const std::vector<int>& toks, const std::vector<int32_t>& pos) {
std::vector<EValue> inputs = { to_tensor(toks), to_tensor_i32(pos) };
auto outputs = module.execute("forward", inputs); // <-- the ONLY runtime call; Result<std::vector<EValue>>
if (!outputs.ok()) throw std::runtime_error("forward failed");
return outputs.get().at(0).toTensor(); // logits
};
int pos = (int)prompt.size();
auto logits = step(prompt, range_i32(0, prompt.size())); // prefill: position [0..T)
int tok = sampler(logits);
for (int i = 0; i < max_new && !is_eos(tok); ++i) {
if (!session->can_extend()) break; // hard-stop, neutral face
logits = step({tok}, /*position=*/{ pos }); // decode: position [pos]
tok = sampler(logits);
++pos;
}
// session dtor erases the registry entry on scope exit.
The per-step flow is forward, then sample, then (seq update | commit), then forward again. The
structural change happens between execute calls: integer-only bookkeeping for multi-seq, and a
tree commit that also compacts the accepted chain. The op signature is invariant across decode /
multi-seq / tree.
4.5 Sizing: Reading Metadata at Runtime
The mirror of §3.2: a shared, model-agnostic reader pulls the architecture facts from
MethodMeta and make_config composes them with caller policy (context length, KV storage
dtype) into a CacheConfig. This removes the hardcoded dimensions and magic capacity from the
examples: sizing is derived from the .pte plus an explicit runtime budget.
// shared reader (schema matches write_kv_cache_metadata, §3.2); checks kv_cache.schema_version==1.
struct ModelCacheParams { int n_layers, n_kv_heads, head_dim; };
ModelCacheParams read_kv_cache_metadata(const MethodMeta& meta); // reads kv_cache.{n_layers,n_kv_heads,head_dim}
// model facts (.pte) + caller policy (runtime) -> full CacheConfig:
CacheConfig make_config(const ModelCacheParams& m, int capacity, // capacity = RUNTIME context budget
ScalarType kv_dtype = ScalarType::Half, int min_chunk = 512) {
return { capacity, m.n_layers, m.n_kv_heads, m.head_dim, min_chunk, kv_dtype }; // kv_dtype = storage, NOT out_dtype
}
CacheConfig carries n_kv_heads/head_dim (pool sized at construction) and kv_dtype (storage),
not the attention output dtype (that's the op's, §3.1). capacity is the only growth bound (no
serialized max_context_len; the lazy-doubling pool keeps short sessions cheap but capacity is the
hard ceiling). Size it for the peak live-cell count across active sequences (seq_cp-shared prefixes
count once; tree needs committed_len + max_tree_size).
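The runner then sizes from the .pte. Reading MethodMeta needs only the loaded program, not an
initialized method, so it can happen before the session exists. Here it goes through a
metadata-only Module:
Module probe(model_path); // method_meta loads the program, not the method (no backend init)
auto mp = read_kv_cache_metadata(probe.method_meta("forward").get()); // method_meta returns Result<MethodMeta>
auto cfg = make_config(mp, /*capacity=*/8192); // this session's context budget
auto [session, opts] = make_decode_session("MLXBackend", cfg);
Module module(model_path, opts); // then load_method("forward") binds the cache, as in §4.4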
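A worked sizing example helps when choosing capacity. The sketch below computes the fully grown KV footprint as 2 (K and V) × n_layers × n_kv_heads × head_dim × cells × bytes per element. It also models lazy growth as doubling from min_chunk until the request fits.

Three things here are assumptions, not values or code from the source:
- The 32-layer / 8-KV-head / 128-dim shape is illustrative.
- The growth rule is one plausible reading of "lazy doubling", not the pool's actual code.
- Clamping to capacity is likewise assumed.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Bytes for K and V across all layers once `cells` token slots are allocated.
// bytes_per_elem: 2 for fp16/bf16, 4 for fp32.
int64_t kv_bytes(int64_t n_layers, int64_t n_kv_heads, int64_t head_dim, int64_t cells, int64_t bytes_per_elem) {
  return 2 * n_layers * n_kv_heads * head_dim * cells * bytes_per_elem;
}

// One plausible reading of "lazy doubling": start at min_chunk, double until `needed` fits,
// then clamp to the logical capacity (the clamp is an assumption of this sketch).
int grown_cells(int needed, int min_chunk, int capacity) {
  assert(needed <= capacity && "bookkeeping rejects writes past capacity before the pool grows");
  int64_t cells = min_chunk;
  while (cells < needed) cells *= 2;
  return static_cast<int>(std::min<int64_t>(cells, capacity));
}
```

At that shape, an 8192-cell fp16 budget is exactly 1 GiB when fully grown. A 200-token chat touches only min_chunk = 512 cells: 1/16 of the budget, or 64 MiB.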
5. MLX Examples
There are three concrete MLX caches, all sharing the same Pool byte-layer policy (§5), so
each gets fp16/bf16/fp32 and quantized KV for free, and all inheriting their integer
bookkeeping from §4.2:
ContiguousCache (§5.1), single-stream decode/prefill; the guaranteed-fast hot path.
SeqCache (§5.2), the multi-sequence cell pool (batching, fork, eviction).
TreeSeqCache (§5.3), speculative / tree decoding.
Each adds only its MLXCacheImpl::update_and_fetch:
// op-facing interface (standalone, does not derive Cache, to avoid a diamond)
class MLXCacheImpl {
public:
virtual AttendSpec update_and_fetch(int layer, const mlx::core::array& position,
const mlx::core::array& k, const mlx::core::array& v,
mlx::core::StreamOrDevice s) = 0;
virtual ~MLXCacheImpl() = default;
};
The contiguous cache uses one helper, int run_start(position, T, s), which returns position[0]
after asserting that the T entries form the contiguous run [start, start+T).
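A host-side sketch of that check follows. It assumes the positions have already been read into a std::vector<int32_t>; the real helper reads an mlx::core::array on stream s, which this sketch omits. The name run_start_host is hypothetical.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Returns position[0] if the T positions form the contiguous run [start, start + T); throws otherwise.
int run_start_host(const std::vector<int32_t>& position, int T) {
  if (static_cast<int>(position.size()) != T) throw std::invalid_argument("run_start: expected one position per token");
  if (T == 0) throw std::invalid_argument("run_start: empty step");
  const int start = position[0];
  for (int i = 1; i < T; ++i)
    if (position[i] != start + i) throw std::invalid_argument("run_start: positions are not a contiguous run");
  return start;
}
```

Tree positions such as {10, 11, 11} fail this check, which is why tree decoding goes through TreeSeqCache rather than the contiguous cache.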
The Pool byte-layer policy (shared by all three caches)
All three MLX caches store KV through a per-layer Pool, the only place fp16 and quantized
differ. The pool is SDPA-major ([1, H, cells, D], cells on axis 2), so a read is a plain
contiguous slice (no transpose) and writes are slice_update/scatter on axis 2.
// per-layer cell store: the ONLY thing fp16 and quantized differ on. SDPA-major [1, H, cells, D].
struct Pool {
virtual void ensure(int needed, int H, int D, StreamOrDevice s) = 0; // doubling growth (axis 2)
virtual void write_run (int start, const array& kv, StreamOrDevice s) = 0; // contiguous T cells (axis 2)
virtual void scatter_rows(const array& cells, const array& kv, StreamOrDevice s) = 0; // scatter on axis 2
virtual void gather_into (int base, const array& src_cells, StreamOrDevice s) = 0; // compact -> [base,..) (tree commit, §5.3)
virtual array read(int start, int len, StreamOrDevice s) = 0; // slice axis 2 -> [1,H,len,D]
virtual ~Pool() = default;
};
// unquantized float (fp16/bf16/fp32): ONE tensor buf_ [1,H,cap,D], storage dtype dt_. Passthrough:
// ensure = grow buf_ by doubling on axis 2 (the slice_update copy donates the old buffer);
// write_run / scatter_rows = slice_update / scatter kv on axis 2;
// read = slice [0,len) on axis 2 -> [1,H,len,D] (a contiguous view; no transpose, no dequant).
struct FPPool : Pool { std::optional<array> buf_; int cap_=0, min_chunk_; Dtype dt_; /* ... */ };
// quantized (affine, per-group): THREE tensors (packed q_, scales sc_, biases bs_). Identical axis-2
// addressing, but write_run/scatter quantize(kv) first and read slices all three then dequantize
// -> [1,H,len,D] (or passes packed to a fused quantized-SDPA). Only the representation differs.
struct QuantPool : Pool { std::optional<array> q_,sc_,bs_; int cap_=0, min_chunk_, group_, bits_; /* ... */ };
// the factory picks the policy from kv_dtype; there is NO separate quantized cache class:
using PoolFactory = std::function<std::unique_ptr<Pool>()>;
inline PoolFactory pool_for(const CacheConfig& c) {
if (is_quantized(c.kv_dtype))
return [c]{ return std::make_unique<QuantPool>(c.min_chunk, c.group_size, c.bits); };
return [c]{ return std::make_unique<FPPool>(c.min_chunk, resolve_dtype(c.kv_dtype)); };
}
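QuantPool's "affine, per-group" representation can be illustrated on the host. The sketch below quantizes one group of floats to `bits`-bit codes with a per-group scale and bias, then dequantizes back:
- scale = (max - min) / (2^bits - 1)
- bias = min
- x ≈ code × scale + bias

This is in the spirit of MLX's affine quantize / dequantize, but it is a scalar illustration, not MLX's kernel, bit packing, or exact rounding.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

struct QuantGroup {
  std::vector<uint8_t> codes;  // one code per element (a real pool packs several codes per machine word)
  float scale, bias;
};

// Affine per-group quantization: x ~= code * scale + bias, code in [0, 2^bits - 1].
QuantGroup quantize_group(const std::vector<float>& x, int bits) {
  assert(!x.empty());
  const auto [lo, hi] = std::minmax_element(x.begin(), x.end());
  const float levels = static_cast<float>((1 << bits) - 1);
  QuantGroup g{{}, (*hi - *lo) / levels, *lo};
  if (g.scale == 0.0f) g.scale = 1.0f;  // constant group: every code is 0
  for (float v : x) {
    const float c = std::round((v - g.bias) / g.scale);
    g.codes.push_back(static_cast<uint8_t>(std::clamp(c, 0.0f, levels)));
  }
  return g;
}

std::vector<float> dequantize_group(const QuantGroup& g) {
  std::vector<float> out;
  for (uint8_t c : g.codes) out.push_back(c * g.scale + g.bias);
  return out;
}
```

Each element's reconstruction error is at most scale / 2. For 4-bit codes that is 1/30 of the group's range, which is why group size trades storage against accuracy.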
5.1 Contiguous Cache (ContiguousCache)
The simplest example and single-stream hot path: one sequence, a length counter, no cell pool, no
classification. Uses the shared Pool (so quant/growth come free); writes a contiguous run, reads
the prefix, causal for prefill / none for decode, guaranteed maskless.
class MLXContiguousCache final
: public executorch::runtime::ContiguousBookkeeping, // length counter: plan(start,T), can_extend, clear
public MLXCacheImpl {
std::vector<std::unique_ptr<Pool>> kpool_, vpool_; // shared Pool policy (FPPool | QuantPool)
public:
MLXContiguousCache(int capacity, int n_layers, const PoolFactory& mk)
: ContiguousBookkeeping(capacity) {
for (int l=0;l<n_layers;++l){ kpool_.push_back(mk()); vpool_.push_back(mk()); }
}
AttendSpec update_and_fetch(int layer, const array& position,
const array& k, const array& v, StreamOrDevice s) override {
int T=k.shape(2), H=k.shape(1), D=k.shape(3);
int start = run_start(position, T, s); // contiguous run start (position[0]; debug-checks the run)
StepPlan p = plan(start, T); // SHARED: write_start, valid_len (length counter)
Pool& K=*kpool_[layer]; Pool& V=*vpool_[layer];
K.ensure(p.valid_len,H,D,s); V.ensure(p.valid_len,H,D,s); // lazy doubling growth
K.write_run(p.write_start,k,s); V.write_run(p.write_start,v,s); // contiguous run, in-place (donation)
return { K.read(0,p.valid_len,s), V.read(0,p.valid_len,s), // contiguous prefix
(T>1)?AttendSpec::Causal:AttendSpec::None, std::nullopt };
}
};
Chunked prefill (T < valid_len) stays correct because the new tokens are written at the tail and
"causal" is lower-right aligned (§7). This is SeqCache's single-sequence fast path (§5.2), carved
out for the common case as a standalone class that has no general (Explicit) path at all.
5.2 Sequence Cache (SeqCache)
Plain decode, quantized KV, and agent prefix-sharing are all specializations of one structure: a
pool of cells, each tagged with a position and a set of sequences. That buys
multi-sequence batching, zero-copy prefix sharing, positional rollback, and continuous admission,
with single-stream decode as a fused fast path at no cost and no op-signature change. TreeSeqCache
(§5.3) is a thin front-end over the same substrate. (vLLM's logical/physical addressing + llama.cpp's
per-cell bitset, at token granularity, §6.) Throughout, the batch is flat on the token axis (q =
[1, H, n_tok, D], B=1): sequences are distinguished by seq_id + the mask, not a batch dimension.
seq_id stays out of the op
position must be a graph input because RoPE consumes it in-graph (§3.1). seq_id is
different: nothing in the transformer compute touches it. It is cache control only (which
cells a token writes, and which cells it may read via the mask). Since the cache is runtime
state reached through the registry (§4.3), the per-token sequence assignment rides the same
between-execute channel as propose/commit, a begin_step(seq_id, n_tok) call, not a
graph tensor. Consequence: going from single- to multi-sequence requires no re-export. The
op signature (and update_and_fetch) is unchanged; only position is ever export-locked.
Shared bookkeeping: the cell pool (core, tensor-free)
struct SeqInfo { int min_cell=INT_MAX, max_cell=-1, count=0; }; // per-seq, tracked incrementally
class SeqBookkeeping : public Cache {
using SeqBits = uint64_t; // up to 64 live sequences
std::vector<int32_t> pos_; std::vector<SeqBits> seq_; // per cell: position; owning-seq bitset
std::vector<int32_t> free_; std::vector<SeqBits> step_seqs_; // free cells; next forward's per-token seqs
int capacity_, high_water_=0; std::array<SeqInfo,64> info_{}; StepPlan plan_; bool plan_valid_=false;
int alloc(int32_t pos, SeqBits bits); // pop a free cell; set pos_/seq_; update info_ + high_water_
public:
bool can_extend(int n=1) const override { return (int)free_.size() >= n; }
int capacity() const override; void clear() override;
// between-execute (pure bookkeeping, no bytes); each invalidates the memoized plan (§7):
void begin_step(const int32_t* seq_id, int n_tok); // declare next forward's per-token seqs
void seq_cp (int32_t src,int32_t dst,int p0=0,int p1=INT_MAX); // zero-copy share: set dst's bit on src's cells
void seq_rm (int32_t seq,int p0=0,int p1=INT_MAX); // clear bit; free a cell ONLY when no owner remains
void seq_keep(int32_t keep); // commit one seq, free the rest
int seq_len(int32_t seq) const { return info_[seq].count; }
const StepPlan& plan(const int32_t* posn, int n_tok); // alloc cells, classify, build Explicit mask
};
The free-list hands out cells ascending; seq_cp/seq_rm set/clear a bit (freeing a cell only when
its bitset empties) and reindex the affected seq. begin_step/seq_cp/seq_rm are the corruption
sites (each invalidates the memoized plan, §7).
Read-path classification (StepPlan::Kind)
plan allocates a cell per query token, then classifies the batch so the byte layer can
pick the cheapest correct path. StepPlan (extending §4.2) carries:
struct StepPlan { // (sequence-cache fields; earlier fields elided)
enum Kind {
SingleContiguous, // 1 seq, contiguous tail, no holes -> fused causal/none, in-place write
Explicit, // everything else (multi-seq, fork, holes, tree) -> scatter + dense mask (ALWAYS correct)
} kind;
int write_start, valid_len; // SingleContiguous
std::vector<int32_t> write_cells; int read_len; // Explicit: scatter targets + read window
std::vector<uint8_t> mask_bits; // Explicit: [n_tok, read_len] bool
};
There are two Kinds, classified in O(active_seqs) from the incremental info_:
SingleContiguous: one active seq, contiguous (count == max_cell - min_cell + 1), queries at its tail. Fused causal/none, in-place write, no mask.
Explicit: everything else, any multi-sequence batch (independent decode/prefill and the shared-prefix agent fork), holes from seq_rm, tree masks, interleaving. Scatter write + dense boolean mask.
Explicit is always correct, so SingleContiguous is a pure optimization for the single-stream
hot path; a backend could implement only Explicit. Fragmentation (after seq_rm) keeps you
on Explicit.
All multi-sequence work (independent sequences and the shared-prefix fork) goes through
Explicit. Two faster paths, disjoint/varlen and prefix-shared/online-softmax, were dropped:
neither reliably beats Explicit on MLX today, both need a kernel MLX doesn't expose, and
Explicit is already copy-free for the fork (the shared trunk is read as one window slice; the mask
isolates branches). Both remain possible future optimizations, gated on a varlen / LSE kernel.
The Explicit mask is the correctness core. It costs O(n_tok·L), where L is the read window, and
is memoized across the per-layer calls of one forward. Its visibility rule: query i attends cell j
iff pos_[j] >= 0 && (seq_[j] & seq_i) && pos_[j] <= pos_i (occupied, same sequence, and causal).
The read window is L = high_water, the longest occupied extent (never capacity/phys_cap, so a short
batch reads a small window at any capacity). classify is the cheap O(active_seqs) front end: it
picks SingleContiguous when one seq is contiguous with its queries at the tail (setting
write_start/valid_len), else Explicit.
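Both halves of plan, classification and the Explicit mask, fit in one short function. This is a minimal sketch that assumes the query cells were already allocated with their pos/seq set. PlanSketch and make_plan are illustrative names. For clarity, the sketch classifies by scanning cells (O(cells)), whereas the design uses the O(active_seqs) summaries.

It is also deliberately stricter than the prose. The fast path here additionally requires the run to start at cell 0 with consecutive positions, because the MLX byte layer reads [0, valid_len).

```cpp
#include <bitset>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using SeqBits = std::bitset<64>;

// Hypothetical mirror of the two StepPlan kinds.
struct PlanSketch {
  bool single_contiguous = false;
  int write_start = 0, valid_len = 0;  // SingleContiguous: write run start, read [0, valid_len)
  int read_len = 0;                    // Explicit: read window [0, high_water)
  std::vector<uint8_t> mask_bits;      // Explicit: [n_tok, read_len] row-major, 1 = attend
};

// pos/seq: per-cell state (pos < 0 = free); q_cells[i]: the cell already allocated for query i.
PlanSketch make_plan(const std::vector<int32_t>& pos, const std::vector<SeqBits>& seq,
                     int high_water, const std::vector<int>& q_cells) {
  PlanSketch p;
  const int n_tok = static_cast<int>(q_cells.size());
  // Classify: gather owners and the occupied extent.
  SeqBits owners;
  int occupied = 0, max_cell = -1;
  for (int c = 0; c < high_water; ++c)
    if (pos[c] >= 0) { owners |= seq[c]; ++occupied; max_cell = c; }
  // One owner and no holes in [0, max_cell] (the MLX fast path reads from cell 0).
  bool fast = owners.count() == 1 && occupied == max_cell + 1;
  for (int c = 1; fast && c <= max_cell; ++c) fast = pos[c] == pos[c - 1] + 1;  // cell order == position order
  for (int i = 0; fast && i < n_tok; ++i) fast = q_cells[i] == max_cell - n_tok + 1 + i;  // queries at the tail
  if (fast) {
    p.single_contiguous = true;
    p.write_start = max_cell - n_tok + 1;
    p.valid_len = max_cell + 1;
    return p;
  }
  // Explicit: query i attends cell j iff occupied && shares a sequence && pos[j] <= pos_i.
  p.read_len = high_water;
  p.mask_bits.assign(static_cast<std::size_t>(n_tok) * high_water, 0);
  for (int i = 0; i < n_tok; ++i) {
    const int32_t pos_i = pos[q_cells[i]];
    const SeqBits& seq_i = seq[q_cells[i]];
    for (int j = 0; j < high_water; ++j)
      p.mask_bits[static_cast<std::size_t>(i) * high_water + j] =
          pos[j] >= 0 && (seq[j] & seq_i).any() && pos[j] <= pos_i;
  }
  return p;
}
```

The second case in the test is the agent fork. Two branches share trunk cells 0 and 1 and each adds one token at position 2. Each row sees the trunk plus its own cell and never the sibling's.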
MLX byte layer (MLXSeqCache)
update_and_fetch is written once over the shared Pool (§5): plan classifies, then
SingleContiguous does a contiguous write+slice (fused, no mask) and Explicit scatters + reads the
[0, high_water) window with a dense mask. fp16 vs quantized is decided by pool_for(cfg).
class MLXSeqCache final
: public executorch::runtime::SeqBookkeeping, // cell pool, seq verbs, plan + classification
public MLXCacheImpl { // MLX bytes only
std::vector<std::unique_ptr<Pool>> kpool_, vpool_;// one per layer; pool_for(cfg) -> FPPool | QuantPool
public:
MLXSeqCache(int capacity, int n_layers, const PoolFactory& mk)
: SeqBookkeeping(capacity) {
for (int l=0;l<n_layers;++l){ kpool_.push_back(mk()); vpool_.push_back(mk()); }
}
AttendSpec update_and_fetch(int layer, const array& position,
const array& k, const array& v, StreamOrDevice s) override {
int T=k.shape(2), H=k.shape(1), D=k.shape(3);
eval(position);
const StepPlan& p = plan(position.data<int32_t>(), T); // SHARED: place + classify
Pool& K=*kpool_[layer]; Pool& V=*vpool_[layer];
K.ensure(high_water_,H,D,s); V.ensure(high_water_,H,D,s); // lazy doubling growth to high_water
switch (p.kind) {
case StepPlan::SingleContiguous: // == ContiguousCache fast path (§5.1)
K.write_run(p.write_start,k,s); V.write_run(p.write_start,v,s);
return { K.read(0,p.valid_len,s), V.read(0,p.valid_len,s),
(T>1)?AttendSpec::Causal:AttendSpec::None, std::nullopt }; // fused, no mask
case StepPlan::Explicit: default: { // multi-seq / fork / holes / tree / interleaved
auto cells = array(p.write_cells.data(), {T});
K.scatter_rows(cells,k,s); V.scatter_rows(cells,v,s);
array m = array(p.mask_bits.data(), {T,p.read_len}, bool_); // host data; array ctors take no stream. [T,L] -> [1,1,T,L]
return { K.read(0,p.read_len,s), V.read(0,p.read_len,s), AttendSpec::Explicit, m };
}
}
}
};
Single-sequence ⇒ contiguous cells, so plain decode/prefill always hits SingleContiguous (the
same fused, maskless path as ContiguousCache, §5.1); the scatter + dense-mask cost is paid
only for genuine multi-sequence/tree use.
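Lazy doubling happens in the K.ensure(high_water_, ...) call above. The following is a CPU stand-in with a few assumptions. GrowableRows is hypothetical, the first allocation size min_rows is an assumed parameter, and capacity is the hard ceiling. std::vector::resize plays the role of allocating a larger MLX array and copying (or donating) the old rows into it.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU stand-in for one layer's K (or V) pool: rows of row_elems floats.
class GrowableRows {
 public:
  GrowableRows(int capacity, int row_elems, int min_rows)
      : capacity_(capacity), row_elems_(row_elems), min_rows_(min_rows) {}

  // Ensure rows [0, need) exist. Doubling keeps the copy cost amortized O(1) per row,
  // and short sequences never allocate the full capacity.
  void ensure(int need) {
    assert(need <= capacity_);  // the host is expected to stop earlier via can_extend
    if (need <= phys_cap_) return;
    int cap = std::max({phys_cap_, min_rows_, 1});
    while (cap < need) cap *= 2;
    cap = std::min(cap, capacity_);
    data_.resize(static_cast<std::size_t>(cap) * row_elems_);  // keeps existing rows
    phys_cap_ = cap;
  }
  float* row(int r) {
    assert(r < phys_cap_);
    return data_.data() + static_cast<std::size_t>(r) * row_elems_;
  }
  int phys_cap() const { return phys_cap_; }

 private:
  int capacity_, row_elems_, min_rows_, phys_cap_ = 0;
  std::vector<float> data_;
};
```

With capacity 12 and min_rows 4, demands of 1, 5 and 9 rows grow the pool to 4, 8 and then 12 rows (clamped, not 16). The written rows survive each growth.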
Tree / speculative as a front-end over the Pool policy
A tree's per-node ancestor mask (siblings sharing a parent must not see each other) can't be
expressed by the seq-bitset + position-causal rule, so MLXTreeSeqCache keeps its own
bookkeeping (§4.2 TreeBookkeeping: parent array, frontier, ancestor mask_bits(), commit
row-mapping) and shares only the Pool policy + the Explicit byte path, gaining quantized KV
and lazy growth for free:
class MLXTreeSeqCache final : public TreeBookkeeping, public TreeSeqCache, public MLXCacheImpl {
std::vector<std::unique_ptr<Pool>> kpool_, vpool_; // SAME Pool policy (FPPool | QuantPool)
public:
void commit(const int32_t* accepted,int len) override; // §4.2 row-map, then per-layer Pool::gather_into
AttendSpec update_and_fetch(int layer,const array& pos,const array& k,const array& v,StreamOrDevice s) override {
// no active tree -> prefill/decode: write_run at base(); read [0, base()+Q); Causal/None.
// active tree -> write THIS frontier at base()+frontier_start(); read [0, base()+spec_count());
// return AttendSpec::Explicit with the §4.2 ancestor mask wrapped as a bool array.
}
};
The only new Pool method is gather_into(base, idx), the tree's analog of scatter_rows: it compacts
the scattered accepted rows into the contiguous run at base (an fp16 pool gathers one tensor, a
QuantPool three). Everything tree-specific is unchanged §4.2 bookkeeping; only storage moves to Pool.
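This is a CPU sketch of the compaction under two assumptions. First, idx holds absolute, strictly ascending row indices of the accepted chain. Second, each index is at or past its destination row. Both hold if accepted nodes sit at base + node and, in BFS order, chain node indices increase with depth. Together they make an in-place forward copy safe. Here gather_into is an illustrative free function over a std::vector, not the Pool method.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU analogue of Pool::gather_into for one tensor (an fp16 pool's K or V;
// a QuantPool would repeat this for each of its three tensors).
// rows: [n_rows, row_elems] row-major; idx: absolute, ascending rows of the accepted chain.
void gather_into(std::vector<float>& rows, int row_elems, int base, const std::vector<int>& idx) {
  for (std::size_t i = 0; i < idx.size(); ++i) {
    const int src = idx[i];
    const int dst = base + static_cast<int>(i);
    // src >= dst and idx ascending, so no later source row is ever overwritten.
    assert(src >= dst && (i == 0 || src > idx[i - 1]));
    if (src != dst)
      std::copy_n(rows.begin() + static_cast<std::ptrdiff_t>(src) * row_elems, row_elems,
                  rows.begin() + static_cast<std::ptrdiff_t>(dst) * row_elems);
  }
}
```

For the §5.3 example with base = 2 and accepted nodes 0, 1, 3, the rows at 2, 3, 5 collapse into 2, 3, 4. Row 5 is left stale, to be overwritten by the next frontier.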
Runner-facing API
class SeqCache : public executorch::runtime::Cache { // neutral; what the runner holds
public: // base: can_extend / capacity / clear
virtual void begin_step(const int32_t* seq_id, int n_tok) = 0; // next forward's seqs
virtual void seq_cp (int32_t src, int32_t dst, int p0=0, int p1=INT_MAX) = 0;
virtual void seq_rm (int32_t seq, int p0=0, int p1=INT_MAX) = 0;
virtual void seq_keep(int32_t seq) = 0;
virtual int seq_len (int32_t seq) const = 0;
};
class TreeSeqCache : public SeqCache { // speculative front-end (§5.3)
public:
virtual void propose(const int32_t* parents, int n) = 0;
virtual void commit (const int32_t* accepted, int len) = 0;
};
// factories parallel to §4.3 (kind named by the function; no enum):
std::pair<CacheSession<SeqCache>, BackendOptions> make_seq_session (const std::string& backend_id, const CacheConfig&);
std::pair<CacheSession<TreeSeqCache>, BackendOptions> make_tree_session(const std::string& backend_id, const CacheConfig&);
Use cases
A tiny host helper bundles (token, position, seq) and steps once (begin_step delivers seq
out-of-band; execute carries tokens + positions):
struct Batch { std::vector<int> tok; std::vector<int32_t> pos, seq;
void add(int t,int32_t p,int32_t s){ tok.push_back(t); pos.push_back(p); seq.push_back(s);} };
auto step = [&](Batch& b){ session->begin_step(b.seq.data(), (int)b.seq.size());
return module.execute("forward", { to_tensor(b.tok), to_tensor_i32(b.pos) }).get()[0].toTensor(); }; // execute returns a Result
// (1) PLAIN DECODE, single sequence. Hits SingleContiguous (fused, no mask).
Batch pf; for (int i=0;i<P;i++) pf.add(prompt[i], i, /*seq=*/0);
auto lg = step(pf); int tok = sample(lg, P-1), pos = P;
while (!is_eos(tok) && session->can_extend()) { Batch d; d.add(tok,pos,0);
lg = step(d); tok = sample(lg,0); ++pos; }
// (2) BATCHED MULTI-SEQ DECODE, N independent streams in one forward. Explicit (dense mask).
Batch b; for (int s=0;s<N;s++) if (active[s]) b.add(tok[s], pos[s], s);
lg = step(b); /* sample per row, ++pos[s] */
// (3) AGENT FORK, shared trunk + branches. seq_cp (zero copy); branches run via Explicit.
Batch t; for (int i=0;i<C;i++) t.add(ctx[i], i, 0); step(t); // trunk into seq 0
int trunk = session->seq_len(0);
for (int br=1;br<=K;br++) session->seq_cp(/*src=*/0, /*dst=*/br); // fork, no bytes (shared cells)
Batch s; for (int br=1;br<=K;br++) s.add(seed[br], trunk, br); step(s); // Explicit: trunk read once, mask isolates branches
int best = score(); session->seq_keep(best); // commit winner, free rest
// (4) CONTINUOUS ADMISSION, new request joins mid-flight (mixed prefill+decode). Explicit.
session->seq_cp(0, 3); // optional shared system prompt
if (session->can_extend(prompt3.size())) { Batch m;
m.add(tok1,pos1,1); m.add(tok2,pos2,2); // ongoing DECODE
int base = session->seq_len(3);
for (int i=0;i<(int)prompt3.size();i++) m.add(prompt3[i], base+i, 3); // new PREFILL
step(m); } // scheduler decides WHEN; cache exposes the seam
// (5) EVICTION / ROLLBACK / SLIDING WINDOW, all positional seq_rm (byte-free).
session->seq_rm(4); // evict whole sequence
session->seq_rm(1, 0, drop_before); // sliding window: drop oldest of seq 1
session->seq_rm(2, keep_len); // backtrack seq 2's last response, then regenerate
Speculative/tree decoding uses make_tree_session + propose/commit (§5.3); the byte layer it
needs is this section's Explicit fallback.
What this subsumes
SeqCache (+ the Pool policy) covers, on one substrate, what would otherwise be separate caches:
- Plain decode/prefill: the SingleContiguous fast path (also offered standalone as ContiguousCache, §5.1).
- Quantized KV: a QuantPool (pool_for on kv_dtype), not a separate class.
- Agent fork / prefix sharing: seq_cp/seq_rm/seq_keep + Explicit (the trunk is read copy-free as one window slice; the mask isolates branches).
- Tree / speculative: TreeSeqCache (§5.3), a front-end sharing the Pool + Explicit path while keeping its own propose/commit + ancestor mask.
The one cache not folded in is beam search. It is expressible (beams = sequences), but reparenting
beams costs O(depth·beams) as a bitset rewrite versus a slot pool's O(beams), so it would need its
own front-end; it is omitted as deprecated/niche.
Invariants & costs
- No re-export when going from single- to multi-sequence: seq_id is between-execute control, not a
graph input; position is the only export-locked input.
- Fast paths are optional; Explicit is always correct. Classification is conservative
(O(active_seqs)); fragmentation degrades to Explicit, and compaction restores the fast path.
- Multi-sequence ⇒ Explicit: a dense mask each step, because the sequences' positions differ and a
single Causal flag can't serve them (RoPE has already applied those positions to Q and K in-graph).
Only SingleContiguous keeps the fused, maskless kernel. Explicit is copy-free for the fork (shared
trunk = one window slice).
- Reads are bounded by occupancy (valid_len, or read_len = high_water), never phys_cap or
capacity, so a short batch reads a small window at any capacity.
- Scheduler-ready, scheduler-free: admission/preemption sit above the seam verbs (§6, §8).
- Elided helpers (backend-defined): active_seqs and each Pool's quantize/dequantize/gather_into.
5.3 Tree / Speculative runtime (Eagle3)
The tree/speculative cache is MLXTreeSeqCache (§5.2, inheriting §4.2 TreeBookkeeping over the
shared Pool). This section is the runtime: how a target + draft drive two tree-cache sessions
for Eagle-style speculative decoding behind the one unchanged op. Recap: propose appends a frontier
(next forward's queries) attending the prefix + its ancestor chain via an Explicit mask; commit
compacts the accepted chain into the prefix. The target verifies the whole tree in one forward; a
BFS draft appends one frontier per level; both use the same cache.
Mask example. Prefix [0..base); a frontier of 4 nodes 0,1,2,3 with parent=[-1,0,0,1]
(node 0 root; 1,2 children of 0; 3 child of 1), e.g. the target verifying the whole tree at
once:
| query \ key | prefix [0..base) | base | base+1 | base+2 | base+3 |
|---|---|---|---|---|---|
| node 0 | all attend | T | . | . | . |
| node 1 (←0) | all attend | T | T | . | . |
| node 2 (←0) | all attend | T | . | T | . |
| node 3 (←1) | all attend | T | T | . | T |
Node 2 does not see node 1 (different branch); node 3 sees 0 and 1 (its ancestors).
A causal flag cannot express row 2's "see 0, skip 1", which is why masking must be
cache-produced.
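The table can be generated from the parent array alone. This is a minimal sketch that assumes nodes are listed in BFS order, so every parent index is smaller than its child's. It also assumes each row is laid out as [prefix | frontier]. tree_mask is an illustrative name, not the §4.2 mask_bits() implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Explicit mask for a tree frontier: [n, base + n], row-major, 1 = attend.
// parents[i] is node i's parent within the frontier (-1 for the root); parents precede
// children (BFS order), so each row can start from its parent's already-finished row.
std::vector<uint8_t> tree_mask(const std::vector<int32_t>& parents, int base) {
  const int n = static_cast<int>(parents.size());
  const int L = base + n;
  std::vector<uint8_t> m(static_cast<std::size_t>(n) * L, 0);
  for (int i = 0; i < n; ++i) {
    uint8_t* row = m.data() + static_cast<std::size_t>(i) * L;
    std::fill(row, row + base, uint8_t{1});  // the committed prefix is visible to every node
    const int p = parents[i];
    assert(p >= -1 && p < i);
    if (p >= 0) {  // inherit the parent's ancestor chain within the frontier
      const uint8_t* prow = m.data() + static_cast<std::size_t>(p) * L;
      std::copy(prow + base, prow + L, row + base);
    }
    row[base + i] = 1;  // a node sees itself
  }
  return m;
}
```

With parents = [-1, 0, 0, 1] and base = 2, the four rows reproduce the table above. Row 2 inherits only node 0 from its parent and never picks up node 1.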
Eagle3 decode loop. Target and draft are separate Modules, each with its own
TreeSeqCache (own .pte/cache_key/metadata); everything is Module::execute("forward", …).
The target's forward emits {logits, features} (Eagle taps the residual stream). The draft
expands breadth-first: one forward per tree level (propose+forward), so a depth-d tree
costs d draft forwards regardless of width. The target appends the whole tree as one frontier
and verifies it in ONE forward. Greedy verify descends from the root, accepting each child whose
token matches the target's argmax at its parent; the argmax past the chain is a free, always-correct bonus
(re-fed as the next root). Both caches then commit the same accepted chain. The step is symmetric
because each cache already holds that chain's KV, so commit only compacts it into the prefix. Host helpers
(TreeProposal{tokens,parents,depths}, VerifyResult{accepted,bonus}, verify_tree,
gather_features) are not ExecuTorch API. Skeleton:
// target + draft: make_tree_session("MLXBackend", make_config(meta, capacity)); load "forward".
// Prefill BOTH with all prompt tokens EXCEPT the last (the last is round 1's root, at pos_base).
while (!done && produced < max_len) {
TreeProposal tree = propose_bfs(feat, root, pos_base, depth, topk); // draft: per-level propose+forward; fills draft cache
const int n = (int)tree.tokens.size();
std::vector<int32_t> tree_positions(n); // per-node positions: siblings share, branches differ
for (int i = 0; i < n; ++i) tree_positions[i] = pos_base + tree.depths[i];
tgt_session->propose(tree.parents.data(), n); // target: whole tree as one frontier
auto out = target.execute("forward", { to_tensor(tree.tokens), to_tensor_i32(tree_positions) }).get(); // {logits, features}
VerifyResult vr = verify_tree(tree, out[0].toTensor()); // longest matching path + bonus
const int n_acc = (int)vr.accepted.size(); // accepted chain, root included
tgt_session->commit(vr.accepted.data(), n_acc); // both caches commit the SAME chain
drf_session->commit(vr.accepted.data(), n_acc);
pos_base += n_acc; root = vr.bonus; feat = gather_features(out[1].toTensor(), vr);
// emit (n_acc-1) accepted candidates + the bonus (produced += n_acc); done = is_eos(vr.bonus).
}
The speedup: multiple tokens (accepted chain + bonus) from one target forward, with the draft
expanding the tree in depth forwards. The bonus is the target's own greedy choice, so it is always
correct; rejection only shortens the chain. With no active tree (propose not yet called) the same
TreeSeqCache returns causal/none, so it serves prompt prefill, the per-level draft forwards, and
the target verify; propose/commit are the only tree-specific verbs. The loop drives two
independent tree sessions and exercises the per-node position tensor, all behind the one unchanged op.
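Greedy verification is a walk down the tree. Below is a hypothetical sketch of what a host helper like verify_tree could do. It assumes the target's logits have already been reduced to a per-node argmax, where argmax[i] is the prediction after node i, and that node 0 is the root. TreeSketch and VerifySketch are stand-ins, not the document's TreeProposal/VerifyResult.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical host-side stand-ins (not ExecuTorch API) for TreeProposal / VerifyResult.
struct TreeSketch { std::vector<int32_t> tokens, parents; };  // node 0 is the root
struct VerifySketch { std::vector<int32_t> accepted; int32_t bonus = 0; };

// argmax[i]: the target's greedy next token after node i, read from row i of the verify forward.
VerifySketch verify_greedy(const TreeSketch& t, const std::vector<int32_t>& argmax) {
  assert(argmax.size() == t.tokens.size() && t.parents.size() == t.tokens.size());
  VerifySketch r;
  r.accepted.push_back(0);  // the root (last round's bonus) is already correct
  int cur = 0;
  for (;;) {
    int next = -1;
    for (int c = cur + 1; c < static_cast<int>(t.tokens.size()); ++c)  // children follow parents
      if (t.parents[c] == cur && t.tokens[c] == argmax[cur]) { next = c; break; }
    if (next < 0) break;  // no child matches: the accepted chain ends at cur
    r.accepted.push_back(next);
    cur = next;
  }
  r.bonus = argmax[cur];  // the target's own choice past the chain: always correct
  return r;
}
```

accepted includes the root, so its length is exactly how far pos_base advances each round. A full rejection still yields one new token, the bonus.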
Backend registration (one place MLX cache types are named)
// MLX backend TU, near register_backend(...). All three share pool_for(c) -> FPPool|QuantPool,
// so quantized KV is cfg.kv_dtype; there is no separate "quantized" builder.
namespace {
const int _mlx_cache_reg = [] {
using namespace executorch::cache;
auto& R = CacheBuilderRegistry::global();
R.register_builder("MLXBackend", "contiguous", [](const CacheConfig& c){
return std::make_shared<MLXContiguousCache>(c.capacity, c.n_layers, pool_for(c)); });
R.register_builder("MLXBackend", "seq", [](const CacheConfig& c){
return std::make_shared<MLXSeqCache>(c.capacity, c.n_layers, pool_for(c)); });
R.register_builder("MLXBackend", "tree", [](const CacheConfig& c){
return std::make_shared<MLXTreeSeqCache>(c.capacity, c.n_layers, pool_for(c)); });
return 0;
}();
} // namespace
Quantized KV is not a separate builder: pool_for(cfg) picks QuantPool when
cfg.kv_dtype is a quantized type, so any of the three caches is quantized by setting
kv_dtype (§4.5). make_decode_session / make_seq_session / make_tree_session (§4.3,
§5.2) select the "contiguous" / "seq" / "tree" builders respectively.
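The registry these lambdas feed can be as small as a map from (backend_id, kind) to a builder. The sketch below is hypothetical: RegistrySketch, CacheSketch, CacheConfigSketch and make_seq_cache_sketch are illustrative names, not the proposed CacheBuilderRegistry API. It shows how a factory that names its kind can resolve to the "seq" builder without an enum. It also shows how a backend that never registered a kind yields no cache.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-ins for CacheConfig / Cache, just enough to exercise the lookup.
struct CacheConfigSketch { int capacity = 0, n_layers = 0; };
struct CacheSketch {
  virtual ~CacheSketch() = default;
  virtual std::string kind() const = 0;
};

// Hypothetical process-global registry keyed by (backend_id, kind).
class RegistrySketch {
 public:
  using Builder = std::function<std::shared_ptr<CacheSketch>(const CacheConfigSketch&)>;
  static RegistrySketch& global() {
    static RegistrySketch r;  // one instance per process
    return r;
  }
  void register_builder(const std::string& backend, const std::string& kind, Builder b) {
    builders_[std::make_pair(backend, kind)] = std::move(b);
  }
  // nullptr when the backend never registered this kind (e.g. no "tree" support).
  std::shared_ptr<CacheSketch> build(const std::string& backend, const std::string& kind,
                                     const CacheConfigSketch& c) const {
    auto it = builders_.find(std::make_pair(backend, kind));
    if (it == builders_.end()) return nullptr;
    return it->second(c);
  }

 private:
  std::map<std::pair<std::string, std::string>, Builder> builders_;
};

// The factory names the kind, so callers never pass an enum.
inline std::shared_ptr<CacheSketch> make_seq_cache_sketch(const std::string& backend,
                                                          const CacheConfigSketch& c) {
  return RegistrySketch::global().build(backend, "seq", c);
}
```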
Note on the target's forward. Eagle3 requires the exported target graph to emit
hidden-state features as a second output, so target.execute("forward", …) returns
{logits, features}. This is an AOT/graph change on the target model (expose the
residual stream), orthogonal to the cache: the cache produces K,V + mask regardless;
Eagle also taps features. The draft model is independent: it shares nothing with the target or
the target's cache beyond what the host loop passes between them.
6. Comparison to vLLM and llama.cpp
At its core this is PagedAttention's logical/physical addressing at token granularity, stripped
of the scheduler/preemption/swapping that make up most of vLLM's mass, and adapted to a frozen,
exported-graph on-device runtime: the single-user version of the same ideas.
Shared ideas
Logical addressing decoupled from storage (vLLM's block table ≈ our cell map: per-cell seq bitset +
pos); refcounted shared history with fork (vLLM block COW ≈ our seq_cp bit-set, no COW); and
append-write + bulk history read from logical state (vLLM slot_mapping/block-walk ≈ our StepPlan,
write_cells + read_len).
Where it differs (and why)
The Level column marks each row design (neutral) vs MLX impl (a backend choice another
backend could make differently):
| Dimension | vLLM | This design | Level |
|---|---|---|---|
| Granularity | Block (e.g. 16 tokens) | Token | design |
| Bookkeeping cost | Cheaper (per-block) | Per-step per-cell bookkeeping (per-token) | design |
| Fragmentation / COW | Internal (partial blocks) + COW | Zero internal fragmentation, no COW (token granularity) | design |
| Scheduler | Central: admission, preemption, swapping, continuous batching | None: single user / a few sequences, host hard-stop only | design |
| Tree / speculative | External / kernel-specific | First-class behind the same op (Explicit mask via AttendSpec) | design |
| Export boundary | None (Python runtime) | Op exported & lowered per backend; cache is runtime state outside the graph | design |
| Attention semantics | Fixed by the paged kernel | Produced by the cache (AttendSpec), op-agnostic | design |
| Attention kernel | Custom paged-attention kernel | Backend's choice: MLX uses stock SDPA after gather/slice; a paged backend could use a fused paged kernel | MLX impl |
| Read mechanics | In-kernel block-table walk | Backend's choice: MLX scatters/gathers cells (multi-seq) or slices (decode) before attention | MLX impl |
| Storage layout | Paged blocks | Backend's choice: MLX uses a contiguous doubling pool; the same interface allows a paged Cache impl | MLX impl |
| Target | Multi-tenant GPU serving | Single-user on-device (current scope) | scope |

The design-level contrasts (not MLX choices) are: token granularity, zero fragmentation, and no COW
(divergence appends to fresh cells), paid for with cheap per-cell bookkeeping; the absence of a
scheduler; first-class tree/speculative decoding; the export boundary; and cache-produced attention
semantics. The kernel, read-mechanics, and storage rows are MLX choices; a paged backend could make
them vLLM-like behind the same AttendSpec.
What vLLM has that this omits
The entire scheduler layer: admission, continuous batching, preemption, KV swapping. That's
most of what makes vLLM a serving system and what a single-user target doesn't need;
omitting it is what lets the cache be a simple pool with a host-side hard-stop.
Where it goes beyond vLLM
Tree/speculative decoding is first-class here. Token-granularity cells express the per-node
ancestor mask (Explicit via AttendSpec) behind the same op that serves plain decode. The
runtime-swappable cache, decode vs. tree chosen by which cache is installed, with one unchanged
.pte, has no vLLM analog, because vLLM has no export/delegate boundary to be invariant across.
(Token granularity would also let beam search be expressed exactly, beams as sequences, but beam search is not implemented; see §5.2.)
The path to the vLLM regime
If multi-tenant serving ever becomes a target, it's reachable without a rewrite (§8): a paged batched
Cache impl, the per-token position (already a tensor), a ragged/paged read, and a scheduler above
the runner, reusing the op-as-contract, registry, and neutral/byte split, but reintroducing COW +
the scheduler. PagedAttention's addressing without its scheduler, sized for the device, open toward
the server.
Fit for the desktop-agent target
The two interfaces are general along different axes. Read against the target from §1
(single-user desktop, including agents) rather than multi-tenant serving, the comparison
favors this design on the axes that matter here:
| Extension axis | This design | vLLM | Relevant to desktop agents? |
|---|---|---|---|
| New decoding strategy (tree/speculative) | Clean: new Cache impl behind the same op | Kernel/engine change | Yes: speculative decoding is a latency win, and agent loops are latency-sensitive every step |
| New hardware backend | Clean: implement update_and_fetch over the backend's tensor | Effectively a rewrite (CUDA-shaped) | Yes: desktop is heterogeneous (MLX, CoreML, Vulkan, …) |
| Concurrent sequences | SeqCache: N sequences in one forward (begin_step + Explicit), or separate caches per cache_key | Native (multi-row, one batch) | Partly: sub-agents / parallel tool calls want a handful of live sequences, not a ragged batch |
| Cross-branch prefix reuse | Via SeqCache fork (seq_cp, §5.2); cross-instance forking (§8) | Native (refcounted blocks) | Yes: agents share a long system/tool prompt across branches |
| Eviction / preemption / swap | Not modeled (host hard-stop only) | Native | No: single-user has nothing to preempt against |

Takeaway: the axes this design is general across, decoding strategy and backend portability, are
the ones desktop agents exercise most; vLLM's extra generality (scheduler, eviction, swap) is mostly
out of scope. The two agent pressures that do bite, light multi-session and cross-branch prefix
reuse, are both SeqCache features (begin_step, seq_cp, §5.2), not the paged-store + scheduler regime.
Comparison to llama.cpp
With the cell-pool design, the cache layer here is essentially llama.cpp's: a pool of cells,
each carrying a position and a per-cell sequence bitset (std::bitset<LLAMA_MAX_SEQ> there,
SeqBits here), with zero-copy prefix sharing, positional rollback, and a position-based causal +
sequence-membership mask. The one fundamental difference is the boundary: llama.cpp is its own
runtime (gguf weights + a ggml graph rebuilt every decode), so it can be fully dynamic; this design
recovers that flexibility inside ExecuTorch's ahead-of-time-exported, delegated, multi-backend
runtime, behind a single op.
Near-identical shared ideas: per-cell pos + sequence bitset (llama_kv_cells ≈ our pos_/seq_),
zero-copy seq_cp, positional seq_rm (free a cell when its bitset empties), the same
position-based causal + membership mask (seq_has(cell,seq) && pos[cell] ≤ pos[q], both built dense
on CPU per decode), and a flat batch with per-token seq_id (llama_batch ≈ our begin_step).
| Dimension | llama.cpp | This design |
|---|---|---|
| Boundary | Own runtime; ggml graph rebuilt per decode | Exported op lowered per backend; cache is runtime state outside a frozen graph |
| Dynamic growth | Fixed kv_size (decode fails when full) | Lazy doubling up to capacity |
| KV dtype / quant | Set at context creation (type_k/type_v) | Per-session Pool (pool_for on kv_dtype), no re-load |
| True tree mask (Eagle) | None: lookahead via multi-seq; eagle3 is a // TODO stub | First-class Explicit ancestor mask (TreeSeqCache, §5.3) |
| Scheduler / continuous batching | Yes (server: admission, preemption, continuous batch) | Deferred, seams only (§8) |
| SWA / recurrent / hybrid caches | Yes (iSWA, Mamba/RWKV, hybrid) | Not addressed |
| Multi-backend from one exported model | n/a (its own backends) | The point: runs behind MLX / CoreML / Vulkan delegates from one .pte |
| Maturity | Production, huge model coverage | Proposal |
Performance. On Apple silicon both ride tuned Metal, so common-path parity is plausible; but
llama.cpp wins today on maturity, kernel breadth, architecture coverage, and concurrent
throughput, while this design's edges (fused single-stream, tree/Eagle, dynamic growth, runtime
quant) are designed-in but unbuilt. The goal is not to beat llama.cpp but to bring its cache
flexibility into ExecuTorch's export/delegate/portability model (which llama.cpp doesn't target),
adding tree/Eagle + dynamic growth, scoped below the scheduler.
7. Key Invariants & Notes
- Bookkeeping is shared and tensor-free; each backend adds only update_and_fetch. A new backend =
register builders + one op handler + one *CacheImpl. A new cache kind = a bookkeeping type + a byte
layer + a factory function (kind named by the function, not an enum).
- begin_step/seq_cp/seq_rm/propose/commit are the corruption sites: each must invalidate the
memoized StepPlan/mask, or the next forward reads a stale plan. Test bookkeeping against the eager
reference oracle before wiring kernels.
- No COW at token granularity (shared history is read-only; divergence appends to fresh cells). COW
returns only if you coarsen to blocks.
- Donation discipline on every pool reassignment, or in-place writes silently become full copies.
Verify with a memory check at long context.
- shared_ptr ownership across registry / session / handle; install happens-before load (factory
return order); erase on session scope exit (RAII).
- Hard-stop bounding via can_extend (host-side); lazy doubling growth so short sequences do not
reserve full context.
- Masks/causality are cache-produced (AttendSpec), never an op flag. This is what lets one op serve
contiguous, multi-sequence, and tree decoding by runtime cache swap.
- Causal requires queries at the tail. MLX's "causal" is lower-right aligned, so a cache returning
Causal must place new tokens at the end of the read window (this holds by construction and makes
chunked prefill maskless). Explicit masks are boolean (avoiding the additive-mask dtype-promotion
rule and the float-mask perf cliff).
- Two dtype channels. KV storage dtype is runtime policy (CacheConfig.kv_dtype); attention output
dtype is the op contract (out_dtype node constant). They are independent: int4 KV still emits
out_dtype. Sizing splits the same way (.pte architecture facts §3.2; runtime policy §4.5).
- rewind/seq_rm truncate without moving bytes (keep [0, new_len), overwrite later) for agent
backtracking/regeneration. Independent caches (target + draft, sub-agents) coexist via distinct
cache_keys.
- Same-process assumption: the registry is a process global, so runner and backend must share an
address space. Cross-process would route control (seq_cp/commit) through execute instead.
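The "queries at the tail" condition can be checked mechanically. This small sketch assumes a single sequence whose cell j holds position j. It compares a lower-right aligned causal rule, as described for MLX above, against the cache's visibility rule. The two coincide exactly when the T queries occupy the last T cells of the window, as in decode and chunked prefill, and diverge otherwise. The function names are illustrative.

```cpp
#include <cassert>

// Lower-right aligned causal for T queries over an L-key window:
// query i may see keys [0, L - T + i], so the last query sees everything.
bool lower_right_causal(int i, int j, int T, int L) { return j <= L - T + i; }

// The cache's rule for one sequence whose cell j holds position j and whose
// T queries were written at cells q0 .. q0 + T - 1: attend iff pos_j <= pos_i.
bool visible(int i, int j, int q0) { return j <= q0 + i; }

// True iff returning AttendSpec::Causal would reproduce the explicit mask exactly.
bool causal_is_exact(int T, int L, int q0) {
  assert(T >= 1 && q0 >= 0 && q0 + T <= L);
  for (int i = 0; i < T; ++i)
    for (int j = 0; j < L; ++j)
      if (lower_right_causal(i, j, T, L) != visible(i, j, q0)) return false;
  return true;
}
```

Decode (T = 1 at the last cell) and a 4-token prefill chunk written at cells 6..9 of a 10-cell window are both exact. The same 4 queries written at cells 0..3 are not.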
8. Future Extensions (open, not built)
SeqCache already covers what earlier drafts listed here: concurrent sequences (begin_step),
intra-session prefix sharing / fork (seq_cp), eviction and rollback (seq_rm), and swappable
per-session caches (per-cache_key). Two things remain open, both additions behind the
existing interfaces:
Shared prefix across separate cache instances
SeqCache shares cells within one cache (seq_cp). Sharing a frozen prefix across independent
cache instances, e.g. parallel sub-agents each with their own runner and cache_key, needs a
refcounted read-only SharedPrefix that each branch concatenates with its private tail:
struct SharedPrefix { std::vector<array> pk, pv; int prefix_len; std::atomic<int> refcount; }; // per-layer [1,H,prefix_len,D], READ-ONLY once frozen
class SharedPrefixCache final : public Cache, public MLXCacheImpl {
std::shared_ptr<SharedPrefix> prefix_; // shared across forks (incref on fork; no bytes copied)
std::vector<std::optional<array>> tail_k_, tail_v_; // this branch's PRIVATE divergence
public:
AttendSpec update_and_fetch(...) override { // WRITE only into the private tail; then
// READ = concatenate(prefix_->pk[layer], tail_k_[layer]) (likewise pv / tail_v_) on axis 2 before SDPA; Causal/None.
}
static std::shared_ptr<SharedPrefixCache> fork(std::shared_ptr<SharedPrefix> p); // incref, empty tail
};
Cost: a two-region read (concatenate of shared prefix + private tail before SDPA, the one thing
token-granularity contiguity otherwise avoids) and a mild "shared up to the fork point, private
after" form of COW. It does not adopt vLLM's block model or scheduler. It is SeqCache's
per-cell membership/refcount idea lifted from intra-cache branches to inter-cache forks. Use it only
when the prefix is large and branches are many enough that N independent copies hurt; otherwise
seq_cp within a single SeqCache is simpler.
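The fork and two-region read reduce to a shared, read-only pointer plus a private vector. This is a hypothetical one-layer, one-head sketch with a single float per token instead of [H, D]. PrefixSketch and BranchSketch are illustrative names. shared_ptr's own count stands in for the explicit refcount field of SharedPrefix.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical one-layer, one-head stand-in: a single float per token instead of [H, D].
struct PrefixSketch { std::vector<float> k; };  // frozen once built; shared only as const

struct BranchSketch {
  std::shared_ptr<const PrefixSketch> prefix;  // copying the pointer is the "incref"
  std::vector<float> tail_k;                   // this branch's private divergence

  // Fork: share the prefix and start with an empty tail; no prefix bytes are copied.
  static BranchSketch fork(std::shared_ptr<const PrefixSketch> p) {
    return BranchSketch{std::move(p), {}};
  }
  void append(float k) { tail_k.push_back(k); }  // writes only ever touch the tail
  // Two-region read: prefix then tail, the concatenate done before SDPA.
  std::vector<float> read() const {
    std::vector<float> out(prefix->k);
    out.insert(out.end(), tail_k.begin(), tail_k.end());
    return out;
  }
};
```

Every read pays the concatenate, which is the cost noted above. The prefix itself is never copied or mutated, no matter how many branches fork from it.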
Heterogeneous / continuous batching (the serving regime)
SeqCache + begin_step is already the multi-sequence substrate. What remains for true serving is
a scheduler above the runner (admission, preemption, batch formation) and a paged / varlen
attention read so ragged batches don't pay the dense-mask cost (§5.2). This is vLLM's shape (§6);
the design is scoped below it but open toward it, and making position a tensor now is the one
hedge that avoids a re-export when it arrives.
KV Cache Management for ExecuTorch
A runtime-swappable KV cache system for ExecuTorch backends, designed for single-user
desktop inference, including agent workloads. It supports plain decode and
tree/speculative (Eagle-style) decoding behind a single custom op; keeps bookkeeping shared
across backends; confines tensor work to each backend; and bounds memory without
over-allocating for short sequences.
1. Motivation & Principles
Target workload
The design target is a desktop application running one user's model locally. The workload is chat, and
increasingly agents: loops that call the model many times, fan out into parallel
sub-agents or tool calls, and carry a long shared context (system prompt, tool definitions,
a reasoning trunk) across branches. This target shapes every decision:
speculative/tree decoding (a per-step latency win) is first-class.
backend" must be a clean operation, not a rewrite.
thousands batched, so the cache stays simple (no scheduler) while serving agent fan-out.
grows lazily and bounds hard.
The result is a design general across the axes this workload exercises, decoding strategy
and backend portability, and scoped below multi-tenant serving (see §6).
What a KV cache is
Not a key-value store: the access pattern is append-only writes at the current position + a
bulk read of history, so the cache's job is addressing and byte movement, not lookup.
Principles
Each is load-bearing, and each ties back to the target:
length: pure integer logic, written once and shared. (A new backend reuses all of it.)
type and kernels (
mlx::core::array). (A new backend implements only this.)what the cache does, not the op (a dispatcher). (Tree is a new cache, not an op change.)
the cache, keeping the
.ptedecoding-strategy-agnostic. (Swap to speculative without re-export.)state lives in the cache. (Sub-agents share weights; branches fork a prefix, §5.2.)
2. Architecture
Three tiers, separated by what they depend on:
Cache,SeqCache,TreeCacheStepPlan,*BookkeepingMLXCacheImpl+ concrete cachescan_extend,clear,seq_cp/seq_rm) isupdate_and_fetchis backend-typed. One concrete cache object implements both faces.3. Ahead-of-Time (AOT)
The AOT half: the neutral custom op the model calls, and a shared utility that writes
model-dependent constants into the
.pte, both backend-independent, shaping the artifact everybackend then lowers and runs.
3.1 The Torch Custom Op (cross-backend)
Two
update_and_attends, never sharing code: (1) the PyTorch custom op (traced bytorch.export, backend-neutral, the cross-backend contract) and (2) the serialized/runtime op(a backend node + handler over its tensor type, the lowered form).
Definition
Signature decisions
mutates_args=(), functional to the tracer. The cache is runtime state (registry, §4.3),not a graph tensor, no functionalization, no mutable-buffer machinery; statefulness lives behind
the delegate at runtime.
layer_id: int, a per-node constant.torch.exportcaptures it per call site (L nodes,node N tagged
layer_id=N), the only channel carrying layer identity in one delegated graph.position: Tensor, one entry per query token (q_len = k.shape(2)):[pos]decode,[0..T)prefill, per-node for tree. A tensor (not a scalar) because tree decoding needs non-contiguous
per-token positions (siblings share, branches differ) and RoPE consumes it in-graph; multi-seq /
heterogeneous batching reuse the same shape with no contract change.
scale: float, a node constant. Usually1/sqrt(head_dim)but models override it, so notsafely derivable in the handler.
out_dtype: torch.dtype, the op's output contract, as a node constant (same channel asscale/layer_id). The downstream graph was traced expecting it, so the op must emit it(and the meta kernel propagates it). It is independent of KV storage precision: a quantized
cache stores int4 yet still emits e.g. fp16; storage dtype is runtime policy, output dtype is this
op contract. Common case: pass
q.dtype.causalflag. Masking is the cache's job, returned viaAttendSpec(None/Causal/Explicit), required for tree's per-node mask, and it subsumes prefill-causal vs decode-none.Eager implementation
The op's eager (non-meta) impl lets the model run in plain PyTorch and is the reference oracle that every backend's runtime op is diffed against (it shares the same bookkeeping logic as the C++).
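A minimal pure-Python sketch of the attend half of that oracle, to make the AttendSpec semantics concrete. It is not the PyTorch eager impl: it assumes a hypothetical AttendSpec encoding (a kind string plus an optional boolean mask) and uses one head with lists of floats instead of tensors. What it shows is how None/Causal/Explicit resolve to visibility, with Causal lower-right aligned as §4.1 requires.

```python
import math
from dataclasses import dataclass
from typing import List, Optional

Vec = List[float]


@dataclass
class AttendSpec:
    """Hypothetical stand-in: kind is "none", "causal", or "explicit"."""
    kind: str
    mask: Optional[List[List[bool]]] = None  # [T_q][T_kv], True = attend


def visible(spec: AttendSpec, t_q: int, t_kv: int) -> List[List[bool]]:
    """Resolve an AttendSpec into a dense [T_q][T_kv] boolean visibility matrix."""
    if spec.kind == "none":
        return [[True] * t_kv for _ in range(t_q)]
    if spec.kind == "causal":
        # Lower-right aligned: the last query lines up with the last key,
        # so query i sees keys [0, t_kv - t_q + i].
        offset = t_kv - t_q
        return [[j <= offset + i for j in range(t_kv)] for i in range(t_q)]
    if spec.kind == "explicit" and spec.mask is not None:
        return spec.mask
    raise ValueError(f"bad AttendSpec: {spec}")


def attend(q: List[Vec], k: List[Vec], v: List[Vec], spec: AttendSpec, scale: float) -> List[Vec]:
    """Single-head SDPA over the cache's read window (assumes every row sees >= 1 key)."""
    vis = visible(spec, len(q), len(k))
    out = []
    for i, qi in enumerate(q):
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) if vis[i][j] else -math.inf
                  for j, kj in enumerate(k)]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]  # masked keys: exp(-inf) == 0.0
        z = sum(w)
        out.append([sum(wj * vj[d] for wj, vj in zip(w, v)) / z for d in range(len(v[0]))])
    return out
```

With T_q = 2 new tokens over a 4-cell window, Causal lets the first new token see cells [0, 2] and the second see all four, which is why the cache must keep new queries at the tail.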
The lowering contract
A backend's partitioner extracts q, k, v, and position as tensor inputs (position was a scalar, now a Tid) and layer_id/scale/out_dtype as node constants, then emits its backend node: the one new thing on each backend's AOT side.
What "a backend opts in" means
Shared: the neutral op + meta + eager reference. Per-backend: (AOT) partition+lower to a node;
(runtime) a handler over its tensor type + cache impls + builder registration. The model is written
once against the neutral op; retargeting changes only the partitioner and the linked backend.
3.2 Metadata Writer
Sizing splits by who owns the value; only architecture facts belong in the .pte:
- n_layers: .pte metadata
- n_kv_heads: .pte metadata
- head_dim: .pte metadata
- out_dtype (attention output): op node constant (§3.1)
- capacity (context length): runtime policy (§4.5)
- kv_dtype (KV storage precision): runtime policy (§4.5)
- min_chunk
The test for the .pte: is the value fixed by the architecture? Layer/head/dim counts pass. The keys live under a kv_cache.* namespace so the reader (§4.5) is model-agnostic.
The schema_version makes this a versioned contract: the reader rejects or adapts an incompatible .pte rather than misreading a renamed field. Absent: max_context_len and kv_dtype, both runtime policy. (A runtime-chosen context length needs export with dynamic sequence shapes (enable_dynamic_shape) so the graph bakes in no fixed length.)
4. Runtime
The runtime half (on device): the neutral interfaces, shared bookkeeping, install/rendezvous machinery, the runner loop, and the sizing reader that turns .pte metadata + runtime policy into a CacheConfig.
4.1 Neutral Interfaces (core)
These are what the runner depends on: tensor-free (only ints) and backend-independent. propose/commit live only on the tree sub-interface; a decode runner holds a Cache* and never sees them. Capability is which interface the runner holds, checked at construction, not a runtime flag.
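A sketch of the capability split, using hypothetical Python stand-ins for the C++ interfaces (the class names and method signatures here are illustrative, not ExecuTorch API). A decode runner would type against Cache; a tree runner rejects anything that is not a TreeCache once, in its constructor, so no step ever consults a flag.

```python
from abc import ABC, abstractmethod
from typing import List


class Cache(ABC):
    """Runner-facing base interface: integers only, no tensors."""

    @abstractmethod
    def length(self) -> int: ...

    @abstractmethod
    def can_extend(self, n_tok: int) -> bool: ...


class TreeCache(Cache):
    """Sub-interface: only tree-capable caches expose propose/commit."""

    @abstractmethod
    def propose(self, parents: List[int]) -> None: ...

    @abstractmethod
    def commit(self, accepted: List[int]) -> None: ...


class TreeRunner:
    def __init__(self, cache: Cache):
        # Capability is checked once, here, rather than via a runtime flag on every step.
        if not isinstance(cache, TreeCache):
            raise TypeError("tree decoding needs a TreeCache")
        self.cache = cache
```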
Attention semantics are produced by the cache
The op does not carry a causal flag. Instead, the read result declares how to attend it, so tree (tree mask) and paged (ragged) caches can each express their own patterns:
The three kinds map directly onto the mask argument of MLX's fast::scaled_dot_product_attention, which is why this is the right seam:
- None: no mask. The read is the selection (a single decode token attends the whole prefix; a gathered multi-sequence read selects its cells).
- Causal: the "causal" string (fused path, no mask tensor). MLX's "causal" is lower-right aligned (last query ↔ last key), which is correct only if new queries occupy the tail of the read window; that is the cache's contract (writes append). This serves fresh and chunked prefill without a mask.
- Explicit: a boolean mask (true = attend). Boolean avoids MLX's additive-mask dtype promotion (a float mask against fp16 q/k/v is a correctness and performance hazard). It must be at most 4-D and broadcast-compatible with [B, N_heads, T_q, T_kv] (a 2-D tree mask broadcasts as [1, 1, T_q, T_kv], §5.3).
4.2 Shared Bookkeeping (core, concrete)
Tensor-free and reused by every backend. The cache impls inherit these for all runner-facing methods and for the per-step plan.
Contiguous, a length counter
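A sketch of the contiguous bookkeeping, assuming a hypothetical StepPlan with just write_start, valid_len, and a causal bit (the real StepPlan fields are not spelled out here). plan() is written to be idempotent for a given (start, n_tok), so every layer of one forward can call it and get the same answer.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StepPlan:
    write_start: int   # first cell written this step
    valid_len: int     # read window is [0, valid_len)
    causal: bool       # Causal for multi-token steps, None (False) for single-token decode


class ContiguousBookkeeping:
    """Tensor-free state for one stream: a length and a hard capacity."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.length = 0

    def can_extend(self, n_tok: int) -> bool:
        return self.length + n_tok <= self.capacity

    def plan(self, start: int, n_tok: int) -> StepPlan:
        # Writes append (or overwrite after a rewind); they never leave a hole.
        if start > self.length or start + n_tok > self.capacity:
            raise ValueError("write must be contiguous and within capacity")
        # Idempotent for a given (start, n_tok): every layer gets the same plan.
        self.length = start + n_tok
        return StepPlan(start, self.length, causal=n_tok > 1)

    def rewind(self, new_len: int) -> None:
        # Truncate logically; bytes past new_len are simply overwritten later.
        self.length = min(self.length, new_len)
```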
Tree, committed prefix + appended candidate frontiers
The tree cache appends speculative nodes after the committed prefix; each newly appended frontier attends the prefix plus its ancestor chain. The target appends the whole tree as one frontier (verify in one forward); a BFS draft appends one frontier per level. Both then commit the same accepted chain. All of this is integer/bool logic (parent tree, ancestor mask, commit row-mapping); only the pool bytes and the mask wrap are backend-specific.
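A tensor-free sketch of that bookkeeping (parent array, frontier, ancestor mask, commit row-mapping). The class and method shapes are hypothetical; the design point it tries to capture is that commit() only emits (src_row, dst_row) pairs, leaving the actual byte gather to the backend.

```python
from typing import List, Tuple


class TreeBookkeeping:
    """Integer-only tree state: committed prefix [0, base) plus speculative nodes after it."""

    def __init__(self, base: int = 0):
        self.base = base
        self.parents: List[int] = []   # parent per speculative node; -1 = hangs off the prefix
        self.frontier = range(0)       # speculative nodes appended by the latest propose()

    def propose(self, parents: List[int]) -> None:
        """Append one frontier; each parent must be an earlier speculative node or -1."""
        start = len(self.parents)
        for i, p in enumerate(parents):
            if not -1 <= p < start + i:
                raise ValueError("parents must precede their children")
        self.parents.extend(parents)
        self.frontier = range(start, len(self.parents))

    def ancestors(self, node: int) -> List[int]:
        chain = []
        while node != -1:
            chain.append(node)
            node = self.parents[node]
        return chain

    def mask_bits(self) -> List[List[bool]]:
        """Rows = frontier nodes; columns = [0, base + spec_count): prefix + own ancestor chain."""
        width = self.base + len(self.parents)
        rows = []
        for n in self.frontier:
            row = [j < self.base for j in range(width)]
            for a in self.ancestors(n):
                row[self.base + a] = True
            rows.append(row)
        return rows

    def commit(self, accepted: List[int]) -> List[Tuple[int, int]]:
        """Return (src_row, dst_row) moves that compact the accepted chain onto the prefix."""
        moves = [(self.base + n, self.base + i) for i, n in enumerate(accepted)]
        self.base += len(accepted)
        self.parents, self.frontier = [], range(0)
        return moves
```

A BFS draft would call propose([-1]) and then propose([0, 0]) for two levels, while the target calls propose once with the whole parent array.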
The backend tree cache (§5.3) inherits this and adds only the pool, the frontier write, the [0, base+spec_count) read, the compaction gather over commit_plan().rows, and wrapping mask_bits().
4.3 Setting / Installing the Cache
The constraint
The DelegateHandle is opaque to the host: there is no public API to reach it, so the cache can be neither installed into the handle directly nor read back from it. Instead, the cache is created by the runner (which knows the kind) and reaches the delegate through a process-global registry, with the two sides rendezvousing on a cache_key passed as a runtime backend-load option.
Registry
Ownership is shared_ptr, held by three parties: the registry entry, the runner's session guard, and the delegate handle. The cache lives until all three release it, so erasing the registry entry mid-method is safe.
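A sketch of the registry in Python, where ordinary object references play the role of shared_ptr (the class and method names are hypothetical, and the real registry is C++). Erasing the entry does not free the cache while the delegate still holds its own reference.

```python
import threading
from typing import Any, Dict


class CacheRegistry:
    """Process-global map from cache_key to a shared cache object.

    Python references stand in for shared_ptr: the registry, the runner's session,
    and the delegate handle each hold one, and the cache lives until all drop it.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._entries: Dict[str, Any] = {}

    def install(self, key: str, cache: Any) -> None:
        with self._lock:
            if key in self._entries:
                raise KeyError(f"cache_key already installed: {key}")
            self._entries[key] = cache

    def lookup(self, key: str) -> Any:
        # Called from the delegate's init with the cache_key load option.
        with self._lock:
            return self._entries[key]

    def erase(self, key: str) -> None:
        with self._lock:
            self._entries.pop(key, None)
```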
Builder registry + factory
Cache kind is expressed by which factory function you call: there is no kind enum in the config and no runtime kind flag. Backends register builders per kind under their backend_id; kind survives only as an internal lookup tag.
RAII session guard
Per-use-case factories
Only the cache_key goes in the option; the spec lives on the cache (no drift). Key generation is centralized (no collisions). Install happens-before load by construction (the session is built before the Module).
Delegate-side binding (in init)
The op handler (MLX)
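A sketch of the delegate-side handler with hypothetical names throughout: the registry is a plain dict, sdpa is an injected callable rather than MLX's kernel, and out_dtype handling is omitted. It shows the two halves: binding the cache in init from the cache_key load option, and applying whatever mask the cache returns instead of a causal flag.

```python
from typing import Any, Callable, Dict, List


class UpdateAndAttendHandler:
    """Hypothetical handler: binds the cache at init, then applies the cache's own mask."""

    def __init__(self, registry: Dict[str, Any], load_options: Dict[str, str],
                 sdpa: Callable[..., Any]):
        # Delegate-side binding: the rendezvous key arrives as a backend-load option.
        self.cache = registry[load_options["cache_key"]]   # the handle keeps its own reference
        self.sdpa = sdpa

    def execute(self, layer_id: int, q: Any, k: Any, v: Any, position: List[int],
                scale: float) -> Any:
        # The cache writes k/v for this layer and returns the read window plus how to attend it.
        k_all, v_all, spec = self.cache.update_and_fetch(layer_id, k, v, position)
        # No causal flag here: the mask (None / Causal / Explicit) comes from the cache.
        return self.sdpa(q, k_all, v_all, scale=scale, mask=spec)
```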
The op is attention-pattern-agnostic: it applies whatever the cache produced. Install a contiguous, seq, or tree cache and the same op, with the same .pte, does causal decode, batched multi-sequence decode, or tree-masked speculative decode.
4.4 The Runner (neutral)
The runner names the backend by string (ideally read from method metadata) and picks a
factory by use-case. Everything else is neutral types.
The only runtime entry point is Module::execute(method_name, inputs) -> outputs (equivalently, the forward convenience). Everything the model needs (tokens, and the position tensor with one int per query token) is passed as EValue tensor inputs to that one call; logits come back as a tensor output. There is no execute_forward-style convenience; the runner builds input tensors and reads output tensors itself.
The per-step flow is forward, sample, then (seq update | commit), then forward again; the structural change happens between execute calls (integer-only bookkeeping for multi-seq; tree commit also compacts the accepted chain). The op signature is invariant across decode, multi-seq, and tree.
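A sketch of that loop for the contiguous case. execute is a stand-in callable for Module::execute, and the assumption that forward returns logits for the last query token only is ours, not stated here. Multi-seq or tree runs would insert begin_step or propose/commit between the calls; the inputs themselves stay tokens plus positions.

```python
from typing import Callable, List, Sequence

Execute = Callable[[str, list], list]   # stand-in for Module::execute(method_name, inputs)


def argmax(xs: Sequence[float]) -> int:
    return max(range(len(xs)), key=xs.__getitem__)


def generate(execute: Execute, prompt: List[int], max_new: int, eos: int) -> List[int]:
    """Greedy decode: forward, sample, forward. Only tokens and positions cross execute."""
    tokens, positions = list(prompt), list(range(len(prompt)))   # prefill: [0..T)
    out: List[int] = []
    next_pos = len(prompt)
    for _ in range(max_new):
        logits = execute("forward", [tokens, positions])[0]       # assumed: last query's logits
        tok = argmax(logits)
        out.append(tok)
        if tok == eos:
            break
        tokens, positions = [tok], [next_pos]                     # decode: [pos]
        next_pos += 1
    return out
```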
4.5 Sizing: Reading Metadata at Runtime
The mirror of §3.2: a shared, model-agnostic reader pulls the architecture facts from MethodMeta, and make_config composes them with caller policy (context length, KV storage dtype) into a CacheConfig. This removes the hardcoded dimensions and magic capacity from the examples: sizing is derived from the .pte plus an explicit runtime budget.
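A sketch of both halves of the sizing contract (§3.2 writer, §4.5 reader), assuming a plain dict in place of the .pte metadata and MethodMeta, and hypothetical key names and version number under the kv_cache.* namespace. The architecture facts come from the artifact; capacity and kv_dtype come only from the caller.

```python
from dataclasses import dataclass
from typing import Dict, Union

Meta = Dict[str, Union[int, str]]
SCHEMA_VERSION = 1   # hypothetical version number


def write_kv_metadata(n_layers: int, n_kv_heads: int, head_dim: int) -> Meta:
    """AOT side (§3.2): architecture facts only, under the kv_cache.* namespace."""
    return {
        "kv_cache.schema_version": SCHEMA_VERSION,
        "kv_cache.n_layers": n_layers,
        "kv_cache.n_kv_heads": n_kv_heads,
        "kv_cache.head_dim": head_dim,
    }


@dataclass(frozen=True)
class CacheConfig:
    n_layers: int
    n_kv_heads: int
    head_dim: int
    capacity: int    # runtime policy: hard ceiling on live cells
    kv_dtype: str    # runtime policy: storage precision, not the attention output dtype


def make_config(meta: Meta, capacity: int, kv_dtype: str = "fp16") -> CacheConfig:
    """Runtime side (§4.5): architecture facts from the .pte plus caller policy."""
    version = meta.get("kv_cache.schema_version")
    if version != SCHEMA_VERSION:
        raise ValueError(f"unsupported kv_cache schema_version: {version}")
    return CacheConfig(
        n_layers=int(meta["kv_cache.n_layers"]),
        n_kv_heads=int(meta["kv_cache.n_kv_heads"]),
        head_dim=int(meta["kv_cache.head_dim"]),
        capacity=capacity,
        kv_dtype=kv_dtype,
    )
```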
CacheConfig carries n_kv_heads/head_dim (pool sized at construction) and kv_dtype (storage), not the attention output dtype (that's the op's, §3.1).
capacity is the only growth bound (there is no serialized max_context_len; the lazy-doubling pool keeps short sessions cheap, but capacity is the hard ceiling). Size it for the peak live-cell count across active sequences (seq_cp-shared prefixes count once; tree needs committed_len + max_tree_size).
The runner then sizes from the .pte.
5. MLX Examples
There are three concrete MLX caches, all sharing the same Pool byte-layer policy (§5), so each gets fp16/bf16/fp32 and quantized KV for free, and all inheriting their integer bookkeeping from §4.2:
- ContiguousCache (§5.1): single-stream decode/prefill; the guaranteed-fast hot path.
- SeqCache (§5.2): the multi-sequence cell pool (batching, fork, eviction).
- TreeSeqCache (§5.3): speculative / tree decoding.
Each adds only its MLXCacheImpl::update_and_fetch. The contiguous cache uses one helper, int run_start(position, T, s), which returns position[0] after asserting that the T entries form a contiguous run [start, start+T).
The Pool byte-layer policy (shared by all three caches)
All three MLX caches store KV through a per-layer Pool, the only place where fp16 and quantized storage differ. The pool is SDPA-major ([1, H, cells, D], cells on axis 2), so a read is a plain contiguous slice (no transpose) and writes are slice_update/scatter on axis 2.
5.1 Contiguous Cache (ContiguousCache)
The simplest example and the single-stream hot path: one sequence, a length counter, no cell pool, no classification. It uses the shared Pool (so quantization and growth come free), writes a contiguous run, and reads the prefix: causal for prefill, none for decode, guaranteed maskless.
Chunked prefill (T < valid_len) stays correct because the new tokens are written at the tail and "causal" is lower-right aligned (§7). This is SeqCache's single-sequence fast path (§5.2), carved out as a standalone class, from which no general path is reachable, for the common case.
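A sketch of the single-stream write/read path, assuming Python lists stand in for the SDPA-major [1, H, cells, D] pool and ignoring quantization. run_start mirrors the helper named in §5 but drops its third argument s, whose meaning is not given here; initial_cells is a hypothetical starting size for the lazy-doubling growth.

```python
from typing import List, Sequence, Tuple

Row = List[float]   # one cell's K or V values; stands in for a slice of [1, H, cells, D]


def run_start(position: Sequence[int], t: int) -> int:
    """Return position[0] after checking that the t entries form the run [start, start + t)."""
    start = position[0]
    if list(position) != list(range(start, start + t)):
        raise ValueError("ContiguousCache needs a contiguous position run")
    return start


class ContiguousLayer:
    """One layer of a single-stream cache with lazy-doubling storage and a hard capacity."""

    def __init__(self, capacity: int, initial_cells: int = 16):
        self.capacity = capacity
        self.phys_cap = max(1, min(initial_cells, capacity))
        self.k: List[Row] = [[] for _ in range(self.phys_cap)]
        self.v: List[Row] = [[] for _ in range(self.phys_cap)]

    def _grow_to(self, need: int) -> None:
        if need > self.capacity:
            raise RuntimeError("capacity exceeded; the host should check can_extend first")
        while self.phys_cap < need:                   # double, clamped to capacity
            new_cap = min(2 * self.phys_cap, self.capacity)
            self.k += [[] for _ in range(new_cap - self.phys_cap)]
            self.v += [[] for _ in range(new_cap - self.phys_cap)]
            self.phys_cap = new_cap

    def update_and_fetch(self, k_new: List[Row], v_new: List[Row],
                         position: Sequence[int]) -> Tuple[List[Row], List[Row], str]:
        t = len(k_new)
        start = run_start(position, t)
        self._grow_to(start + t)
        self.k[start:start + t] = k_new               # contiguous write at the tail
        self.v[start:start + t] = v_new
        valid_len = start + t                         # read the prefix [0, valid_len)
        return self.k[:valid_len], self.v[:valid_len], ("causal" if t > 1 else "none")
```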
5.2 Sequence Cache (SeqCache)
Plain decode, quantized KV, and agent prefix-sharing are all specializations of one structure: a pool of cells, each tagged with a position and a set of sequences. That buys multi-sequence batching, zero-copy prefix sharing, positional rollback, and continuous admission, with single-stream decode as a fused fast path at no cost and no op-signature change.
TreeSeqCache (§5.3) is a thin front-end over the same substrate. (This is vLLM's logical/physical addressing plus llama.cpp's per-cell bitset, at token granularity; see §6.) Throughout, the batch is flat on the token axis (q = [1, H, n_tok, D], B = 1): sequences are distinguished by seq_id plus the mask, not by a batch dimension.
seq_id stays out of the op
position must be a graph input because RoPE consumes it in-graph (§3.1). seq_id is different: nothing in the transformer compute touches it. It is cache control only (which cells a token writes, and which cells it may read via the mask). Since the cache is runtime state reached through the registry (§4.3), the per-token sequence assignment rides the same between-execute channel as propose/commit, a begin_step(seq_id, n_tok) call, rather than a graph tensor. Consequence: going from single- to multi-sequence requires no re-export. The op signature (and update_and_fetch) is unchanged; only position is ever export-locked.
Shared bookkeeping: the cell pool (core, tensor-free)
The free-list hands out cells in ascending order; seq_cp/seq_rm set/clear a bit (freeing a cell only when its bitset empties) and reindex the affected seq. begin_step/seq_cp/seq_rm are the corruption sites (each invalidates the memoized plan, §7).
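A tensor-free sketch of the cell pool under stated simplifications: sequence bitsets are Python ints, seq_rm(seq, p0) removes a sequence's cells at positions >= p0 (a llama.cpp-style range, assumed here), and per-sequence membership is found by rescanning rather than through the incremental info_ the design keeps.

```python
import heapq
from typing import List


class CellPool:
    """Each cell has a position (-1 = free) and a sequence bitset (bit s => member of seq s)."""

    def __init__(self, capacity: int):
        self.pos_: List[int] = [-1] * capacity
        self.seq_: List[int] = [0] * capacity
        self.free_: List[int] = list(range(capacity))   # min-heap: cells handed out ascending
        self.plan_valid = False

    def alloc(self, seq: int, pos: int) -> int:
        cell = heapq.heappop(self.free_)                 # IndexError when full; host checks first
        self.pos_[cell], self.seq_[cell] = pos, 1 << seq
        self.plan_valid = False                          # a corruption site: drop the memoized plan
        return cell

    def cells_of(self, seq: int) -> List[int]:
        # The design keeps incremental per-seq info_ (reindex); this sketch just rescans.
        return [c for c, bits in enumerate(self.seq_) if bits >> seq & 1]

    def seq_cp(self, src: int, dst: int) -> None:
        # Zero-copy fork: dst shares every src cell by setting one more bit.
        for c in self.cells_of(src):
            self.seq_[c] |= 1 << dst
        self.plan_valid = False

    def seq_rm(self, seq: int, p0: int = 0) -> None:
        # Positional rollback: drop seq's cells at pos >= p0; free a cell once no seq holds it.
        for c in self.cells_of(seq):
            if self.pos_[c] >= p0:
                self.seq_[c] &= ~(1 << seq)
                if self.seq_[c] == 0:
                    self.pos_[c] = -1
                    heapq.heappush(self.free_, c)
        self.plan_valid = False
```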
Read-path classification (StepPlan::Kind)
plan allocates a cell per query token, then classifies the batch so the byte layer can pick the cheapest correct path. StepPlan extends the §4.2 plan. There are two Kinds, classified in O(active_seqs) from the incremental info_:
- SingleContiguous: one active seq, contiguous (count == max_cell - min_cell + 1), queries at its tail. Fused causal/none, in-place write, no mask.
- Explicit: everything else: any multi-sequence batch (independent decode/prefill and the shared-prefix agent fork), holes from seq_rm, tree masks, interleaving. Scatter write + dense boolean mask.
Explicit is always correct, so SingleContiguous is a pure optimization for the single-stream hot path; a backend could implement only Explicit. Fragmentation (after seq_rm) keeps you on Explicit.
Explicitmask is the correctness core,O(n_tok·L), memoized across the L layer calls, withthe visibility rule: query
iattends celljiffpos_[j] >= 0 && (seq_[j] & seq_i) && pos_[j] <= pos_i(occupied and same-sequence and causal). The read window is
L = high_water, the longestoccupied extent (never
capacity/phys_cap, so a short batch reads a small window at anycapacity).
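A sketch of the visibility rule and the classification, with hypothetical argument shapes: each query is (seq_id, pos), seq_ holds one bitset per cell, and each active sequence is summarized as (min_cell, max_cell, count). The mask is built over [0, high_water) only, never the full capacity.

```python
from typing import List, Sequence, Tuple


def explicit_mask(pos_: Sequence[int], seq_: Sequence[int], high_water: int,
                  queries: Sequence[Tuple[int, int]]) -> List[List[bool]]:
    """queries = (seq_id, pos) per query token; the read window is [0, high_water)."""
    mask = []
    for seq_i, pos_i in queries:
        bit = 1 << seq_i
        # Occupied, same sequence, and causal by position.
        mask.append([pos_[j] >= 0 and bool(seq_[j] & bit) and pos_[j] <= pos_i
                     for j in range(high_water)])
    return mask


def classify(active: Sequence[Tuple[int, int, int]], query_cells: Sequence[int]) -> str:
    """active = (min_cell, max_cell, count) per active seq; query_cells = cells written this step."""
    if len(active) == 1:
        lo, hi, count = active[0]
        contiguous = count == hi - lo + 1
        at_tail = list(query_cells) == list(range(hi - len(query_cells) + 1, hi + 1))
        if contiguous and at_tail:
            return "SingleContiguous"
    return "Explicit"   # always correct; the fast path is purely an optimization
```

In the test case below, two sequences share a three-cell trunk and each owns one private cell at position 3; each row of the mask sees the trunk plus only its own branch cell.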
classify is the cheap O(active_seqs) front that picks SingleContiguous when one seq is contiguous with its queries at the tail (setting write_start/valid_len), else Explicit.
MLX byte layer (MLXSeqCache)
update_and_fetch is written once over the shared Pool (§5): plan classifies, then SingleContiguous does a contiguous write + slice (fused, no mask) and Explicit scatters and reads the [0, high_water) window with a dense mask. fp16 vs quantized is decided by pool_for(cfg). Single-sequence ⇒ contiguous cells, so plain decode/prefill always hits SingleContiguous (the same fused, maskless path as ContiguousCache, §5.1); the scatter + dense-mask cost is paid only for genuine multi-sequence/tree use.
Tree / speculative as a front-end over the Pool policy
A tree's per-node ancestor mask (siblings sharing a parent must not see each other) can't be expressed by the seq-bitset + position-causal rule, so MLXTreeSeqCache keeps its own bookkeeping (§4.2 TreeBookkeeping: parent array, frontier, ancestor mask_bits(), commit row-mapping) and shares only the Pool policy and the Explicit byte path, gaining quantized KV and lazy growth for free.
The only new Pool method is gather_into(base, idx): it compacts scattered accepted rows into the contiguous run at base (fp16 gathers one tensor, QuantPool three), the tree's analog of scatter_rows. Everything tree-specific is unchanged §4.2 bookkeeping; only storage moves to Pool.
Runner-facing API
Use cases
A tiny host helper bundles (token, position, seq) and steps once (begin_step delivers seq out-of-band; execute carries tokens and positions). Speculative/tree decoding uses make_tree_session + propose/commit (§5.3); the byte layer it needs is this section's Explicit fallback.
What this subsumes
SeqCache (plus the Pool policy) covers, on one substrate, what would otherwise be separate caches:
- Plain decode: the SingleContiguous fast path (also offered standalone as ContiguousCache, §5.1).
- Quantized KV: QuantPool (pool_for on kv_dtype), not a separate class.
- Agent prefix sharing: seq_cp/seq_rm/seq_keep + Explicit (the trunk is read copy-free as one window slice; the mask isolates branches).
- Tree / speculative: TreeSeqCache (§5.3), a front-end sharing the Pool + Explicit path while keeping its own propose/commit and ancestor mask.
The one cache not folded in is beam search: it is expressible (beams = sequences), but reparent costs O(depth·beams) as a bitset rewrite vs a slot pool's O(beams), so it would need its own front-end; it is omitted as deprecated/niche.
Invariants & costs
- seq_id is between-execute control, not a graph input; position is the only export-locked input.
- Explicit is always correct. Classification is conservative (O(active_seqs)); fragmentation degrades to Explicit, and compaction restores the fast path.
- Multi-sequence steps pay for Explicit (a dense mask each step: positions differ, so a single Causal can't serve them, and RoPE has already applied them). Only SingleContiguous keeps the fused, maskless kernel.
- Explicit is copy-free for the fork (shared trunk = one window slice).
- The read window tracks occupancy (valid_len/read_len = high_water), never phys_cap or capacity, so a short batch reads a small window at any capacity.
- active_seqs and each Pool's quantize/dequantize/gather_into.
5.3 Tree / Speculative runtime (Eagle3)
The tree/speculative cache is MLXTreeSeqCache (§5.2, inheriting §4.2 TreeBookkeeping over the shared Pool). This section is the runtime: how a target and a draft drive two tree-cache sessions for Eagle-style speculative decoding behind the one unchanged op. Recap: propose appends a frontier (the next forward's queries) that attends the prefix plus its ancestor chain via an Explicit mask; commit compacts the accepted chain into the prefix. The target verifies the whole tree in one forward; a BFS draft appends one frontier per level; both use the same cache.
Mask example. Prefix [0..base); a frontier of 4 nodes 0, 1, 2, 3 with parent=[-1, 0, 0, 1] (node 0 is the root; 1 and 2 are children of 0; 3 is a child of 1), e.g. the target verifying the whole tree at once:
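The mask for this example, with base = 2, can be reproduced in a few lines. frontier_mask is a hypothetical helper; the printout marks attended columns with 1 (two prefix columns, then nodes 0 to 3).

```python
from typing import List


def frontier_mask(parents: List[int], base: int) -> List[List[bool]]:
    """Row i: the prefix [0, base) plus node i's ancestor chain (itself included)."""
    rows = []
    for i in range(len(parents)):
        row = [True] * base + [False] * len(parents)
        node = i
        while node != -1:
            row[base + node] = True
            node = parents[node]
        rows.append(row)
    return rows


for row in frontier_mask([-1, 0, 0, 1], base=2):
    print("".join("1" if b else "." for b in row))
# 111...   node 0: prefix + itself
# 1111..   node 1: prefix + 0 + itself
# 111.1.   node 2: prefix + 0 + itself (skips sibling 1)
# 1111.1   node 3: prefix + 0 + 1 + itself
```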
Node 2 does not see node 1 (different branch); node 3 sees 0 and 1 (its ancestors).
A causal flag cannot express row 2's "see 0, skip 1": proof that masking must be cache-produced.
Eagle3 decode loop. Target and draft are separate Modules, each with its own TreeSeqCache (own .pte/cache_key/metadata); everything is Module::execute("forward", …). The target's forward emits {logits, features} (Eagle taps the residual stream). The draft expands breadth-first: one forward per tree level (propose + forward), so a depth-d tree costs d draft forwards regardless of width. The target appends the whole tree as one frontier and verifies it in ONE forward. Greedy verify descends from the root, accepting each child whose token matches the target's argmax; the argmax past the chain is a free, always-correct bonus (re-fed as the next root). Both caches then commit the same accepted chain. This is symmetric: each cache already holds the KV, so commit compacts it into the prefix. The host helpers (TreeProposal{tokens,parents,depths}, VerifyResult{accepted,bonus}, verify_tree, gather_features) are not ExecuTorch API.
The speedup: multiple tokens (accepted chain + bonus) from one target forward, with the draft expanding the tree in depth forwards. The bonus is the target's own greedy choice, so it is always correct; rejection only shortens the chain. With no active tree (propose not yet called) the same TreeSeqCache returns causal/none, so it serves prompt prefill, the per-level draft forwards, and the target verify; propose/commit are the only tree-specific verbs. The loop drives two independent tree sessions and exercises the per-node position tensor, all behind the one unchanged op.
Backend registration (one place MLX cache types are named)
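A sketch of that registration, with a dict keyed by (backend_id, kind) as the builder registry; all function names and the tuple-returning builders are hypothetical stand-ins. Kind appears only as the lookup tag inside the per-use-case factories, so the runner never names it.

```python
from typing import Any, Callable, Dict, Tuple

Builder = Callable[[Any], Any]                 # CacheConfig -> cache impl
_BUILDERS: Dict[Tuple[str, str], Builder] = {}


def register_builder(backend_id: str, kind: str, builder: Builder) -> None:
    # kind is only an internal lookup tag; callers never pass it directly.
    key = (backend_id, kind)
    if key in _BUILDERS:
        raise KeyError(f"duplicate builder for {key}")
    _BUILDERS[key] = builder


def _build(backend_id: str, kind: str, cfg: Any) -> Any:
    builder = _BUILDERS.get((backend_id, kind))
    if builder is None:
        raise LookupError(f"backend {backend_id!r} has no {kind!r} cache builder")
    return builder(cfg)


# Per-use-case factories: the kind is named by which function you call.
def make_decode_cache(backend_id: str, cfg: Any) -> Any:
    return _build(backend_id, "contiguous", cfg)


def make_tree_cache(backend_id: str, cfg: Any) -> Any:
    return _build(backend_id, "tree", cfg)


# The one place the backend names its cache types (string stand-ins here).
register_builder("mlx", "contiguous", lambda cfg: ("ContiguousCache", cfg))
register_builder("mlx", "tree", lambda cfg: ("MLXTreeSeqCache", cfg))
```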
6. Comparison to vLLM and llama.cpp
At its core this is PagedAttention's logical/physical addressing at token granularity, stripped
of the scheduler/preemption/swapping that are most of vLLM's mass, adapted to a frozen
exported-graph on-device runtime, the single-user version of the same ideas.
Shared ideas
Logical addressing decoupled from storage (vLLM's block table ≈ our cell map: per-cell seq bitset + pos); refcounted shared history with fork (vLLM block COW ≈ our seq_cp bit-set, no COW); and append-write + bulk history read from logical state (vLLM slot_mapping/block-walk ≈ our StepPlan, write_cells + read_len).
Where it differs (and why)
The Level column marks each row as design (neutral) or MLX impl (a backend choice another backend could make differently):
Explicit mask via AttendSpec) AttendSpec), op-agnostic Cache impl is allowed by the same interface AttendSpec.
What vLLM has that this omits
The entire scheduler layer: admission, continuous batching, preemption, KV swapping. That's
most of what makes vLLM a serving system and what a single-user target doesn't need;
omitting it is what lets the cache be a simple pool with a host-side hard-stop.
Where it goes beyond vLLM
Tree/speculative decoding is first-class here. Token-granularity cells express the per-node ancestor mask (Explicit via AttendSpec) behind the same op that serves plain decode. The runtime-swappable cache (decode vs. tree chosen by which cache is installed, with one unchanged .pte) has no vLLM analog, because vLLM has no export/delegate boundary to be invariant across. (Token granularity would also make beam search exact and cheap, but beam is not implemented, §5.)
The path to the vLLM regime
If multi-tenant serving ever becomes a target, it's reachable without a rewrite (§8): a paged batched Cache impl, the per-token position (already a tensor), a ragged/paged read, and a scheduler above the runner, reusing the op-as-contract, registry, and neutral/byte split, but reintroducing COW and the scheduler. PagedAttention's addressing without its scheduler, sized for the device, open toward the server.
Fit for the desktop-agent target
The two interfaces are general along different axes. Read against the target from §1
(single-user desktop, including agents) rather than multi-tenant serving, the comparison
favors this design on the axes that matter here:
Cache impl behind the same op; update_and_fetch over the backend's tensor; SeqCache: N sequences in one forward (begin_step + Explicit), or separate caches per cache_key; SeqCache fork (seq_cp, §5.2); cross-instance forking (§8); SeqCache features (begin_step, seq_cp, §5.2), not the paged-store + scheduler regime.
Comparison to llama.cpp
After the cell-pool design, the cache layer here is llama.cpp's: a pool of cells, each carrying a position and a per-cell sequence bitset (std::bitset<LLAMA_MAX_SEQ> there, SeqBits here), with zero-copy prefix sharing, positional rollback, and a position-based causal + sequence-membership mask. The one fundamental difference is the boundary: llama.cpp is its own runtime (gguf weights + a ggml graph rebuilt every decode), so it can be fully dynamic; this design recovers that flexibility inside ExecuTorch's ahead-of-time-exported, delegated, multi-backend runtime, behind a single op.
Near-identical shared ideas: per-cell pos + sequence bitset (llama_kv_cells ≈ our pos_/seq_), zero-copy seq_cp, positional seq_rm (free a cell when its bitset empties), the same position-based causal + membership mask (seq_has(cell,seq) && pos[cell] ≤ pos[q], both built dense on CPU per decode), and a flat batch with per-token seq_id (llama_batch ≈ our begin_step).
decode; kv_size (decode fails when full); capacity; type_k/type_v); Pool (pool_for on kv_dtype), no re-load; eagle3 is a // TODO stub; Explicit ancestor mask (TreeSeqCache, §5.3); .pte
Performance. On Apple silicon both ride tuned Metal, so common-path parity is plausible, but llama.cpp wins today on maturity, kernel breadth, architecture coverage, and concurrent throughput, while this design's edges (fused single-stream, tree/Eagle, dynamic growth, runtime quant) are designed-in but unbuilt. The goal is not to beat llama.cpp but to bring its cache flexibility into ExecuTorch's export/delegate/portability model (which llama.cpp doesn't target), adding tree/Eagle + dynamic growth, scoped below the scheduler.
7. Key Invariants & Notes
Backend-specific code is confined to update_and_fetch. A new backend = register builders + one op handler + one *CacheImpl. A new cache kind = a bookkeeping type + a byte layer + a factory function (kind named by the function, not an enum).
StepPlan/mask invalidation after begin_step/seq_cp/seq_rm/propose/commit: these are the corruption sites. Test bookkeeping against the eager reference oracle before wiring kernels.
cells). Returns only if you coarsen to blocks.
full copies. Verify with a memory check at long context.
shared_ptr ownership across registry / session / handle; install happens-before load (factory return order); erase on session scope exit (RAII).
Memory is bounded by capacity + can_extend (host-side); lazy doubling growth means short sequences do not reserve full context.
Masking is produced by the cache (AttendSpec), never an op flag. This is what lets one op serve contiguous, multi-sequence, and tree decoding by runtime cache swap.
Causal requires queries at the tail. MLX's "causal" is lower-right aligned, so a cache returning Causal must place new tokens at the end of the read window (this holds by construction, and it makes chunked prefill maskless).
Explicit masks are boolean (this avoids the additive-mask dtype-promotion rule and the float-mask perf cliff).
KV storage dtype is runtime policy (CacheConfig.kv_dtype); attention output dtype is the op contract (out_dtype node constant). They are independent: int4 KV still emits out_dtype. Sizing splits the same way (.pte arch facts §3.2; runtime policy §4.5).
rewind/seq_rm truncate without moving bytes (keep [0, new_len), overwrite later) for agent backtracking/regeneration. Independent caches (target + draft, sub-agents) coexist via distinct cache_keys.
The runner and the delegate must share an address space. Cross-process use would route control (seq_cp/commit) through execute instead.
8. Future Extensions (open, not built)
SeqCache already covers what earlier drafts listed here: concurrent sequences (begin_step), intra-session prefix sharing / fork (seq_cp), eviction and rollback (seq_rm), and swappable per-session caches (per-cache_key). Two things remain open, both additions behind the existing interfaces:
Shared prefix across separate cache instances
SeqCache shares cells within one cache (seq_cp). Sharing a frozen prefix across independent cache instances, e.g. parallel sub-agents each with their own runner and cache_key, needs a refcounted read-only SharedPrefix that each branch concatenates with its private tail.
Cost: a two-region read (a concatenate of shared prefix + private tail before SDPA, the one thing token-granularity contiguity otherwise avoids) and a mild "shared up to the fork point, private after" form of COW. It does not adopt vLLM's block model or scheduler. It is SeqCache's per-cell membership/refcount idea lifted from intra-cache branches to inter-cache forks. Use it only when the prefix is large and branches are many enough that N independent copies hurt; otherwise seq_cp within a single SeqCache is simpler.
Heterogeneous / continuous batching (the serving regime)
SeqCache + begin_step is already the multi-sequence substrate. What remains for true serving is a scheduler above the runner (admission, preemption, batch formation) and a paged / varlen attention read so ragged batches don't pay the dense-mask cost (§5.2). This is vLLM's shape (§6); the design is scoped below it but open toward it, and making position a tensor now is the one hedge that avoids a re-export when it arrives.