From d815bba0e3f2b5505d38344844cd1b8c32f66a3f Mon Sep 17 00:00:00 2001
From: Nausicaa Li <2239638+nausicaalii@users.noreply.github.com>
Date: Sat, 11 Apr 2026 15:51:51 -0700
Subject: [PATCH] perf: vectorize prefix matching with numpy
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Replace O(n) Python for-loop in KV cache prefix matching and
longest_token_prefix() with numpy vectorized comparison. The
element-wise numpy comparison runs in optimized C/SIMD instead of
Python's interpreter loop, which matters as conversation history
grows (10K+ tokens).

No change in behavior — both paths find the first position where
cached and new token sequences diverge.
---
 llama_cpp/llama.py | 31 ++++++++++++++++++-------------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 11fe169cf..6529a75c8 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -887,12 +887,18 @@ def generate(
 
         # Check for kv cache prefix match
         if reset and self.n_tokens > 0:
-            longest_prefix = 0
-            for a, b in zip(self._input_ids, tokens[:-1]):
-                if a == b:
-                    longest_prefix += 1
-                else:
-                    break
+            cached = self._input_ids
+            n = min(len(cached), len(tokens) - 1)
+            if n > 0:
+                eq = np.asarray(cached[:n]) == np.asarray(
+                    tokens[:n]
+                )
+                mismatch = np.argmin(eq)
+                longest_prefix = (
+                    int(n) if eq[mismatch] else int(mismatch)
+                )
+            else:
+                longest_prefix = 0
             if longest_prefix > 0:
                 if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1):
                     reset = False
@@ -2252,13 +2258,12 @@ def logits_to_logprobs(
 
     @staticmethod
     def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
-        longest_prefix = 0
-        for _a, _b in zip(a, b):
-            if _a == _b:
-                longest_prefix += 1
-            else:
-                break
-        return longest_prefix
+        n = min(len(a), len(b))
+        if n == 0:
+            return 0
+        eq = np.asarray(a[:n]) == np.asarray(b[:n])
+        mismatch = np.argmin(eq)
+        return int(n) if eq[mismatch] else int(mismatch)
 
     @classmethod
     def from_pretrained(