# Review notes for this patch to llama_cpp/llama.py.
#
# Purpose: both hunks replace a per-element Python loop that counts the
# length of the common token prefix with a NumPy element-wise comparison.
#
#   * Hunk 1 (inside `generate`): compares the cached `self._input_ids`
#     against `tokens`, limited to n = min(len(cached), len(tokens) - 1).
#     The `- 1` mirrors the old `tokens[:-1]`, i.e. the last prompt token is
#     never counted as part of the reusable prefix. When n <= 0 (including
#     empty `tokens`, where n == -1) the prefix length is 0.
#   * Hunk 2 (`longest_token_prefix`): same computation for two arbitrary
#     sequences, with an early `return 0` when either one is empty.
#
# How the mismatch index is found: `eq` is a boolean array; `np.argmin(eq)`
# returns the index of the first False because False < True and argmin
# reports the first occurrence of the minimum. If every element is True,
# argmin returns 0, so `eq[mismatch]` is checked to tell "all n elements
# match" (result n) apart from "first mismatch at index `mismatch`".
#
# NOTE(review): the patch adds no import; it assumes `np` is already bound to
# numpy in llama.py -- confirm.
# NOTE(review): hunk 1 repeats the logic of `longest_token_prefix`; calling the
# static method from `generate` would avoid the duplication -- consider.
# NOTE(review): unlike the old loop, the new code always compares all n
# elements (no early exit) -- presumably acceptable for typical prompt sizes;
# confirm with a benchmark if the cache holds very long sequences.
# NOTE(review): indentation of the context lines below must match the target
# file exactly for the hunks to apply -- verify before applying.
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 11fe169cf..6529a75c8 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -887,12 +887,18 @@ def generate(
 
         # Check for kv cache prefix match
         if reset and self.n_tokens > 0:
-            longest_prefix = 0
-            for a, b in zip(self._input_ids, tokens[:-1]):
-                if a == b:
-                    longest_prefix += 1
-                else:
-                    break
+            cached = self._input_ids
+            n = min(len(cached), len(tokens) - 1)
+            if n > 0:
+                eq = np.asarray(cached[:n]) == np.asarray(
+                    tokens[:n]
+                )
+                mismatch = np.argmin(eq)
+                longest_prefix = (
+                    int(n) if eq[mismatch] else int(mismatch)
+                )
+            else:
+                longest_prefix = 0
             if longest_prefix > 0:
                 if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1):
                     reset = False
@@ -2252,13 +2258,12 @@ def logits_to_logprobs(
 
     @staticmethod
     def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
-        longest_prefix = 0
-        for _a, _b in zip(a, b):
-            if _a == _b:
-                longest_prefix += 1
-            else:
-                break
-        return longest_prefix
+        n = min(len(a), len(b))
+        if n == 0:
+            return 0
+        eq = np.asarray(a[:n]) == np.asarray(b[:n])
+        mismatch = np.argmin(eq)
+        return int(n) if eq[mismatch] else int(mismatch)
 
     @classmethod
     def from_pretrained(