From 187cbda931acb7a67eabe8416e8251e0f8ffcee0 Mon Sep 17 00:00:00 2001 From: sufubao Date: Sat, 13 Jun 2026 22:38:03 +0800 Subject: [PATCH 1/4] fix(linear-att): fix latent prefix-cache ref/buffer leaks Four latent defects in LinearAttPagedRadixCache / _linear_att_free_req, found via adversarial audit + property-based fuzzing: - Root in eviction set: _add_node added root to _evict_tree_set when the tree emptied (root is a leaf then), contradicting _evict's own 'assert node is not self.root_node'. Now excluded. - Root ref leak on miss / trim-to-empty: match_prefix takes a root ref on descent; the two 'no match' returns handed None to the caller without releasing it, so root.ref_counter drifted up on every miss. Both returns now release it. - Root ref leak in deref_to_first_big_page_node: the big-page downgrade path leaked the same root ref when it bottomed out at root. Fixed. - Big-page state-buffer leak / assert-crash: big-page state ids accumulated in req.linear_att_len_to_big_page_id during chunked prefill were neither inserted nor freed when a request was paused/aborted mid-prefill (fallback branch of _linear_att_free_req), tripping free_a_req_mem's assert (worker crash) or leaking slots with asserts off. Now released on the non-insert exit paths. New CPU-only tests (no GPU): property-based invariant fuzzers for the small-page and big-page regimes, plus regression tests for the pause/abort big-page release. The three root-ref issues are latent (root carries zero tokens and is never evictable). The big-page leak is a reachable worker crash for long-context serving with --linear_att_page_block_num set. --- .../dynamic_prompt/linear_att_radix_cache.py | 14 +- .../server/router/model_infer/infer_batch.py | 16 + .../test_linear_att_radix_cache_bigpage.py | 300 +++++++++++++ .../test_linear_att_radix_cache_invariants.py | 404 ++++++++++++++++++ .../test_linear_att_free_req_big_page.py | 88 ++++ 5 files changed, 821 insertions(+), 1 deletion(-) create mode 100644 unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_bigpage.py create mode 100644 unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_invariants.py create mode 100644 unit_tests/server/router/model_infer/test_linear_att_free_req_big_page.py diff --git a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py index c7408add39..22b79cbe2c 100644 --- a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py @@ -163,7 +163,10 @@ def _discard_node(self, node: LinearAttPagedTreeNode): return def _add_node(self, node: LinearAttPagedTreeNode): - if node.is_leaf(): + # root 永远不参与回收:当树为空时 root 自身也满足 is_leaf(),若加入 _evict_tree_set, + # 会与 _evict 中 "node is not self.root_node" 的断言相矛盾(当前仅靠 root 的 ref_counter>=1 + # 和回收水位 guard 掩盖)。这里显式排除,使数据结构与回收逻辑的意图一致。 + if node.is_leaf() and node is not self.root_node: self._evict_tree_set.add(node) if node.small_page_buffer_idx is not None: self._evict_tree_set_for_linear_att.add(node) @@ -362,12 +365,18 @@ def match_prefix( ans_node_list=ans_node_list, update_refs=update_refs, ) + # _match_prefix_helper 进入时一定对 root 自增了一次 ref_counter。命中链非空时,调用方最终会 + # 通过 dec_node_ref_counter(ans_node) 沿父链回收(含 root),增减平衡;但下面两个 "命中为空" + # 的提前返回会把 None 交给调用方,调用方不会再回收,root 自增就无人抵消,导致 root.ref_counter + # 在每次 miss / trim 到空时持续漂移。这里显式补偿这一次 root 自增。 if len(ans_node_list) == 0: + self.dec_node_ref_counter(self.root_node) return None, 0, None # 判定真正可以用的匹配节点。 ans_node_list = self._trim_unusable_match_tail(ans_node_list) if len(ans_node_list) == 0: + self.dec_node_ref_counter(self.root_node) return None, 0, None ans_node = ans_node_list[-1] @@ -482,6 +491,9 @@ def deref_to_first_big_page_node(self, node: LinearAttPagedTreeNode) -> Optional iter_node = iter_node.parent if iter_node is self.root_node: + # 没有可承接的 big-page 节点交给调用方释放:root 在 match 阶段同样被 +1, + # 这里必须补偿,否则与 match_prefix miss 路径同类的 root ref 漂移。 + self.dec_node_ref_counter(self.root_node) return None else: return iter_node diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index bae5ea1e3c..67fb03ca6e 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -149,6 +149,18 @@ def _full_att_free_req(self, free_token_index: List, req: "InferReq"): req.shared_kv_node = None return + def _release_pending_linear_att_big_page_ids(self, req: "InferReq"): + # 释放本请求 prefill 阶段在 big page 边界上申请、但尚未插入 radix cache 的 big page + # state buffer。仅当请求未走 insert 分支(小页/大页插入)就被释放时才会有残留,典型场景: + # big page 模式下请求在 prefill 跨过 big page 边界后、到达末尾前被 pause / abort。 + # 若不释放,会泄漏 big page state slot,并触发 free_a_req_mem 中 dict 为空的断言。 + if req.linear_att_len_to_big_page_id: + self.radix_cache.linear_att_big_page_buffers.free_state_cache( + list(req.linear_att_len_to_big_page_id.values()) + ) + req.linear_att_len_to_big_page_id.clear() + return + def _linear_att_free_req(self, free_token_index: List, req: "InferReq"): assert g_infer_context.is_linear_att_mixed_model is True args = get_env_start_args() @@ -164,6 +176,7 @@ def _linear_att_free_req(self, free_token_index: List, req: "InferReq"): assert req.linear_att_cache_len <= req.cur_kv_len if req.cur_kv_len == 0: + self._release_pending_linear_att_big_page_ids(req) return if req.linear_att_cache_len <= req.cur_kv_len and req.tail_linear_att_small_page_buffer_id is not None: @@ -232,6 +245,9 @@ def _linear_att_free_req(self, free_token_index: List, req: "InferReq"): assert req.shared_kv_node.node_prefix_total_len == req.cur_kv_len self.radix_cache.dec_node_ref_counter(req.shared_kv_node) req.shared_kv_node = None + # 该分支不会把 prefill 阶段累积的 big page id 插入 radix cache(典型为 pause/abort + # 在 prefill 跨过 big page 边界后、到达末尾前触发),需在此显式释放,避免泄漏。 + self._release_pending_linear_att_big_page_ids(req) return assert False, f"error state: cur_kv_len: {req.cur_kv_len}" diff --git a/unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_bigpage.py b/unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_bigpage.py new file mode 100644 index 0000000000..fc2f5d9981 --- /dev/null +++ b/unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_bigpage.py @@ -0,0 +1,300 @@ +"""Big-page-regime coverage + invariant fuzz for LinearAttPagedRadixCache. + +Active in production only when --linear_att_page_block_num is set (e.g. the GSM8K +launch scripts use 8). Here big_page_num is small so inserts create big-page nodes +plus an optional small tail, mirroring _linear_att_free_req's two insert calls and +copy_linear_att_state_to_cache_buffer's len_to_big_page_id construction. +""" +import uuid + +import numpy as np +import pytest +import torch +from sortedcontainers import SortedDict + +from lightllm.server.router.dynamic_prompt.linear_att_radix_cache import LinearAttPagedRadixCache +from lightllm.utils.kv_cache_utils import compute_token_list_hash + +PAGE = 4 +BIGN = 2 +BIG_TOKENS = PAGE * BIGN + + +class FakePool: + def __init__(self, size): + self.size = size + self.free_set = set(range(size)) + self.order = list(range(size)) + + def alloc_one_state_cache(self): + if not self.order: + return None + i = self.order.pop(0) + self.free_set.discard(i) + return i + + def free_state_cache(self, free_indexes): + for i in free_indexes: + assert i is not None and i not in self.free_set, f"double free {i}" + self.free_set.add(i) + self.order.append(i) + + def get_free_cache_num(self): + return len(self.order) + + +class FakeAllocator: + def __init__(self, size): + self.size = size + self.can_use_mem_size = size + + +class FakeMem: + def __init__(self, size, big_pool): + self.allocator = FakeAllocator(size) + self.linear_att_big_page_buffers = big_pool + + def free(self, mem_index): + self.allocator.can_use_mem_size += len(mem_index) + + +def build(small_size=32, big_size=64, mem=400_000): + small = FakePool(small_size) + big = FakePool(big_size) + mm = FakeMem(mem, big) + cache = LinearAttPagedRadixCache( + unique_name=f"bp_{uuid.uuid4().hex[:8]}", + total_token_num=mem, + rank_in_node=0, + hash_page_size=PAGE, + big_page_num=BIGN, + kv_cache_mem_manager=mm, + linear_att_small_page_buffers=small, + ) + return cache, small, big, mm + + +def walk(cache): + out = [] + st = list(cache.root_node.children.values()) + while st: + n = st.pop() + out.append(n) + st.extend(n.children.values()) + return out + + +def page_tokens(pid): + return list(range(pid * PAGE, pid * PAGE + PAGE)) + + +def hashes_for(pids): + toks = [] + for p in pids: + toks += page_tokens(p) + toks.append(-1) + return compute_token_list_hash(toks, PAGE) + + +def check(cache, small, big): + nodes = walk(cache) + # structural + accounting + total = 0 + refed = 0 + for n in nodes: + assert n.parent is not None + assert n.node_prefix_total_len == n.parent.node_prefix_total_len + n.node_value_len + assert n.ref_counter >= 0 + assert n.node_value_len == len(n.token_mem_index_value) + if n.is_big_page_node(): + assert n.page_num == BIGN and n.node_value_len == BIG_TOKENS + assert n.big_page_buffer_idx is not None + assert n.small_page_buffer_idx is None + else: + assert n.page_num == 1 and n.node_value_len == PAGE + assert n.big_page_buffer_idx is None + total += n.node_value_len + if n.ref_counter > 0: + refed += n.node_value_len + for k, c in n.children.items(): + assert c.page_hash == k and c.parent is n + assert cache.get_tree_total_tokens_num() == total + assert cache.get_refed_tokens_num() == refed + # evict set == non-root leaves + leaves = {id(n) for n in nodes if n.is_leaf()} + assert {id(n) for n in cache._evict_tree_set} == leaves + assert id(cache.root_node) not in {id(n) for n in cache._evict_tree_set} + # buffer-evict set == small-buffer holders + assert {id(n) for n in cache._evict_tree_set_for_linear_att} == { + id(n) for n in nodes if n.small_page_buffer_idx is not None + } + # big-page id conservation + big_in_tree = [n.big_page_buffer_idx for n in nodes if n.is_big_page_node()] + assert len(big_in_tree) == len(set(big_in_tree)), "big-page id reused by two nodes" + assert set(big_in_tree).isdisjoint(big.free_set) + assert set(big_in_tree) | big.free_set == set(range(big.size)), "big-page id leaked" + # small-page id conservation + small_in_tree = [n.small_page_buffer_idx for n in nodes if n.small_page_buffer_idx is not None] + assert len(small_in_tree) == len(set(small_in_tree)) + assert set(small_in_tree).isdisjoint(small.free_set) + assert set(small_in_tree) | small.free_set == set(range(small.size)), "small-page id leaked" + + +def make_insert(cache, small, big): + """Mirror _linear_att_free_req: big-page-aligned prefix (+ optional small tail).""" + + def insert(pids, mem_base, with_small_tail): + L = len(pids) + num_big = L // BIGN + # len_to_big_page_id: one fresh big id per big-page boundary along the path + l2b = SortedDict() + big_ids_alloced = [] + for j in range(1, num_big + 1): + bid = big.alloc_one_state_cache() + if bid is None: + # big pool exhausted: the real caller would not start this insert; roll back. + for got in big_ids_alloced: + big.free_state_cache([got]) + return + big_ids_alloced.append(bid) + l2b[j * BIG_TOKENS] = bid + hashs = hashes_for(pids) + key = torch.tensor([t for p in pids for t in page_tokens(p)], dtype=torch.int64) + value = torch.arange(mem_base, mem_base + L * PAGE, dtype=torch.int64) + linear_idxs = [None] * L + tail_buf = None + if with_small_tail and (L % BIGN != 0): + tail_buf = small.alloc_one_state_cache() + if tail_buf is None: + # contract: cannot insert a None-tailed non-aligned path; drop the tail page + pids = pids[:-1] + L = len(pids) + if L == 0: + # nothing to insert; release any big ids we grabbed (none, since num_big recomputed) + for bid in big_ids_alloced: + big.free_state_cache([bid]) + return + hashs = hashes_for(pids) + key = torch.tensor([t for p in pids for t in page_tokens(p)], dtype=torch.int64) + value = torch.arange(mem_base, mem_base + L * PAGE, dtype=torch.int64) + linear_idxs = [None] * L + else: + linear_idxs[-1] = tail_buf + elif L % BIGN != 0: + # no small tail wanted but path is not big-aligned -> trim to aligned length + pids = pids[: num_big * BIGN] + L = len(pids) + if L == 0: + for bid in big_ids_alloced: + big.free_state_cache([bid]) + return + hashs = hashes_for(pids) + key = torch.tensor([t for p in pids for t in page_tokens(p)], dtype=torch.int64) + value = torch.arange(mem_base, mem_base + L * PAGE, dtype=torch.int64) + linear_idxs = [None] * L + + before_small = set(small.free_set) + cache.insert(key, value, block_hashs=hashs, block_linear_idxs=linear_idxs, len_to_big_page_id=l2b) + # any tail buffer that was a duplicate got freed by the cache; nothing to track + _ = before_small + + return insert + + +def test_pure_bigpage_insert_and_match(): + cache, small, big = build()[:3] + ins = make_insert(cache, small, big) + # 4 pages -> 2 big pages, no small tail + ins([1, 2, 3, 4], 1000, with_small_tail=False) + check(cache, small, big) + assert cache.get_tree_total_tokens_num() == 4 * PAGE + + hashs = hashes_for([1, 2, 3, 4]) + key = torch.tensor([t for p in [1, 2, 3, 4] for t in page_tokens(p)], dtype=torch.int64) + node, kv, mem = cache.match_prefix(key, block_hashs=hashs, update_refs=True) + assert node is not None and node.is_big_page_node() + assert kv == 16 and len(mem) == 16 + assert torch.equal(mem, torch.arange(1000, 1016, dtype=torch.int64)) + cache.dec_node_ref_counter(node) + check(cache, small, big) + + +def test_mixed_insert_match_trims_to_bigpage_when_tail_unusable(): + cache, small, big = build(small_size=1)[:3] + ins = make_insert(cache, small, big) + # 5 pages -> 2 big pages (8 tokens *2 =16) + 1 small tail page (4) = 20 tokens + ins([1, 2, 3, 4, 5], 2000, with_small_tail=True) + check(cache, small, big) + assert cache.get_tree_total_tokens_num() == 20 + + # exhaust small pool and steal the tail buffer -> tail page unusable + while small.alloc_one_state_cache() is not None: + pass + cache.free_one_small_page_linear_att_buffer() + check(cache, small, big) + + hashs = hashes_for([1, 2, 3, 4, 5]) + key = torch.tensor([t for p in [1, 2, 3, 4, 5] for t in page_tokens(p)], dtype=torch.int64) + node, kv, mem = cache.match_prefix(key, block_hashs=hashs, update_refs=True) + # tail small page has no buffer -> trim back to the last big-page boundary (16) + assert node is not None and node.is_big_page_node() + assert kv == 16 + cache.dec_node_ref_counter(node) + check(cache, small, big) + + +@pytest.mark.parametrize("seed", list(range(10))) +def test_bigpage_fuzz(seed): + rng = np.random.default_rng(seed) + cache, small, big, mm = build(small_size=10, big_size=48, mem=400_000) + ins = make_insert(cache, small, big) + live = [] + mem_base = [10_000] + + def do_ins(): + L = int(rng.integers(1, 7)) + pids = [int(rng.integers(0, 25)) for _ in range(L)] + ins(pids, mem_base[0], with_small_tail=bool(rng.integers(0, 2))) + mem_base[0] += 100 + + def do_match(): + L = int(rng.integers(1, 7)) + pids = [int(rng.integers(0, 25)) for _ in range(L)] + hashs = hashes_for(pids) + key = torch.tensor([t for p in pids for t in page_tokens(p)], dtype=torch.int64) + node, kv, mem = cache.match_prefix(key, block_hashs=hashs, update_refs=True) + if node is None: + assert kv == 0 and mem is None + return + assert kv == node.node_prefix_total_len == len(mem) + assert node.is_big_page_node() or node.small_page_buffer_idx is not None + live.append(node) + + def do_dec(): + if live: + cache.dec_node_ref_counter(live.pop(int(rng.integers(0, len(live))))) + + def do_steal(): + cache.free_one_small_page_linear_att_buffer() + + def do_evict(): + unref = cache.get_tree_total_tokens_num() - cache.get_refed_tokens_num() + if unref < PAGE: + return + need = int(rng.integers(1, unref // PAGE + 1)) * PAGE + cache._evict(need, lambda m, b: small.free_state_cache([b]) if b is not None else None) + + ops = [do_ins, do_ins, do_match, do_match, do_dec, do_steal, do_evict] + for _ in range(400): + ops[int(rng.integers(0, len(ops)))]() + check(cache, small, big) + assert cache.root_node.ref_counter == 1 + len(live), "root ref drifted (big-page regime)" + + while live: + cache.dec_node_ref_counter(live.pop()) + assert cache.get_refed_tokens_num() == 0 + t = cache.get_tree_total_tokens_num() + if t: + cache._evict(t, lambda m, b: small.free_state_cache([b]) if b is not None else None) + assert cache.get_tree_total_tokens_num() == 0 + check(cache, small, big) diff --git a/unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_invariants.py b/unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_invariants.py new file mode 100644 index 0000000000..dac8255ffc --- /dev/null +++ b/unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_invariants.py @@ -0,0 +1,404 @@ +"""Property-based invariant fuzzer for LinearAttPagedRadixCache (small-page regime). + +This drives the cache the way infer_batch.py does in the default serving regime +(big-page matching disabled, i.e. big_page_num huge so every request inserts only +small pages with a single tail state buffer). After every random operation it +asserts the full set of internal invariants and verifies value integrity against +an independent first-write-wins oracle. + +Run: pytest unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_invariants.py -q +""" +import uuid + +import numpy as np +import pytest +import torch + +from lightllm.server.router.dynamic_prompt.linear_att_radix_cache import LinearAttPagedRadixCache + +PAGE = 4 +BIG = 10_000_000 # big pages effectively disabled (serving default regime) + + +class FakeSmallPageBuffers: + """Models LinearAttCacheManager's id pool with strict double-free detection.""" + + def __init__(self, size): + self.size = size + self.free_set = set(range(size)) + self.free_order = list(range(size)) + + def alloc_one_state_cache(self): + if not self.free_order: + return None + idx = self.free_order.pop(0) + self.free_set.discard(idx) + return idx + + def free_state_cache(self, free_indexes): + for idx in free_indexes: + assert idx is not None + assert idx not in self.free_set, f"double free of small-page buffer {idx}" + self.free_set.add(idx) + self.free_order.append(idx) + + def get_free_cache_num(self): + return len(self.free_order) + + +class FakeBigPageBuffers: + def __init__(self): + self.freed = [] + + def free_state_cache(self, free_indexes): + self.freed.extend(free_indexes) + + +class FakeAllocator: + def __init__(self, size): + self.size = size + self.can_use_mem_size = size + + +class FakeMemManager: + def __init__(self, size): + self.allocator = FakeAllocator(size) + self.linear_att_big_page_buffers = FakeBigPageBuffers() + self.freed_mem = [] + + def free(self, mem_index): + self.freed_mem.append(mem_index) + self.allocator.can_use_mem_size += len(mem_index) + + +def build(small_pool_size=32, mem_size=100_000): + small = FakeSmallPageBuffers(small_pool_size) + mm = FakeMemManager(mem_size) + cache = LinearAttPagedRadixCache( + unique_name=f"fuzz_{uuid.uuid4().hex[:8]}", + total_token_num=mem_size, + rank_in_node=0, + hash_page_size=PAGE, + big_page_num=BIG, + kv_cache_mem_manager=mm, + linear_att_small_page_buffers=small, + ) + return cache, small, mm + + +# ------------------------- tree walking helpers ------------------------- + + +def walk(cache): + """Return all non-root nodes via BFS.""" + out = [] + stack = list(cache.root_node.children.values()) + while stack: + n = stack.pop() + out.append(n) + stack.extend(n.children.values()) + return out + + +def check_invariants(cache, small: FakeSmallPageBuffers, allocated_ids: set): + nodes = walk(cache) + + # 1. structural: prefix len, child-key consistency, page bookkeeping + for n in nodes: + assert n.parent is not None + assert n.node_value_len == len(n.token_mem_index_value) == len(n.token_id_key) + assert n.node_prefix_total_len == n.parent.node_prefix_total_len + n.node_value_len + assert n.ref_counter >= 0, f"negative ref_counter {n.ref_counter}" + for k, c in n.children.items(): + assert c.page_hash == k + assert c.parent is n + # small-page regime: every node is exactly one page + assert n.node_value_len == PAGE + assert n.page_num == 1 + assert not n.is_big_page_node() or n.node_prefix_total_len == 0 + + # 2. accounting: tree_total and refed token counts match the live tree + total = sum(n.node_value_len for n in nodes) + refed = sum(n.node_value_len for n in nodes if n.ref_counter > 0) + assert ( + cache.get_tree_total_tokens_num() == total + ), f"tree_total {cache.get_tree_total_tokens_num()} != actual {total}" + assert cache.get_refed_tokens_num() == refed, f"refed {cache.get_refed_tokens_num()} != actual {refed}" + + # 3. evict-set membership: exactly the leaves, root excluded + leaves = {id(n) for n in nodes if n.is_leaf()} + evict_ids = {id(n) for n in cache._evict_tree_set} + assert evict_ids == leaves, "evict set must equal the set of non-root leaves" + assert id(cache.root_node) not in evict_ids + + # 4. buffer-eviction set membership: exactly nodes holding a small-page buffer + with_buf = {id(n) for n in nodes if n.small_page_buffer_idx is not None} + buf_evict_ids = {id(n) for n in cache._evict_tree_set_for_linear_att} + assert buf_evict_ids == with_buf, "linear-att evict set must equal nodes with a buffer" + + # 5. buffer-id conservation: ids in tree, ids free in pool, partition the universe + in_tree = [n.small_page_buffer_idx for n in nodes if n.small_page_buffer_idx is not None] + assert len(in_tree) == len(set(in_tree)), "a small-page buffer id is used by two nodes" + in_tree_set = set(in_tree) + # every allocated id is either in the tree or free in the pool, never both, never lost + assert in_tree_set.isdisjoint(small.free_set), "buffer id is both in tree and free" + assert in_tree_set | small.free_set == allocated_ids | set( + range(small.size) + ), "buffer id leaked (neither in tree nor free)" + + +# ------------------------- oracle for value integrity ------------------------- + + +def page_tokens(page_id): + # distinct, deterministic token block per logical page id + return list(range(page_id * PAGE, page_id * PAGE + PAGE)) + + +def hashes_for(page_ids): + # chained hash so prefixes are prefix-closed (same as real block hashing) + from lightllm.utils.kv_cache_utils import compute_token_list_hash + + toks = [] + for p in page_ids: + toks.extend(page_tokens(p)) + toks.append(-1) # one extra so (len-1)//PAGE == len(page_ids) + return compute_token_list_hash(toks, PAGE) + + +class Oracle: + """First-write-wins record of the value stored for each hashed page path.""" + + def __init__(self): + self.value_by_hash = {} # block_hash -> mem value tensor (PAGE long) + + def record(self, hashs, values): + for i, h in enumerate(hashs): + if h not in self.value_by_hash: + self.value_by_hash[h] = values[i * PAGE : (i + 1) * PAGE].clone() + + def forget(self, freed_hashs): + for h in freed_hashs: + self.value_by_hash.pop(h, None) + + def expected_mem(self, hashs): + return torch.cat([self.value_by_hash[h] for h in hashs]) + + +# ------------------------- the fuzzer ------------------------- + + +@pytest.mark.parametrize("seed", list(range(16))) +@pytest.mark.parametrize("pool", [6, 24]) # 6 => constant state-buffer pressure & real steals +def test_invariant_fuzz(seed, pool): + rng = np.random.default_rng(seed * 100 + pool) + cache, small, mm = build(small_pool_size=pool, mem_size=200_000) + oracle = Oracle() + + next_mem = [1000] + + def alloc_mem(n): + base = next_mem[0] + next_mem[0] += n + return torch.arange(base, base + n, dtype=torch.int64) + + allocated_ids = set() + live = [] # list of (ans_node, matched_hashs) holding a ref to dec later + + def do_insert(): + npages = int(rng.integers(1, 7)) + page_ids = [int(rng.integers(0, 40)) for _ in range(npages)] + hashs = hashes_for(page_ids) + key = torch.tensor([t for p in page_ids for t in page_tokens(p)], dtype=torch.int64) + value = alloc_mem(npages * PAGE) + buf = small.alloc_one_state_cache() + if buf is None: + # Contract (see _linear_att_free_req): in the small-page regime the radix + # insert is only performed when a tail state buffer exists. With the pool + # exhausted the request simply caches nothing. Skip — do not insert a + # None-tailed path (the cache rightly asserts against that). + return + allocated_ids.add(buf) + linear_idxs = [None] * npages + linear_idxs[-1] = buf + oracle.record(hashs, value) + before_free = set(small.free_set) + cache.insert(key, value, block_hashs=hashs, block_linear_idxs=linear_idxs) + # if our buffer was immediately freed (duplicate tail), drop it from allocated set + newly_free = small.free_set - before_free + for idx in newly_free: + allocated_ids.discard(idx) + + def do_match(): + npages = int(rng.integers(1, 7)) + page_ids = [int(rng.integers(0, 40)) for _ in range(npages)] + hashs = hashes_for(page_ids) + key = torch.tensor([t for p in page_ids for t in page_tokens(p)], dtype=torch.int64) + node, kv_len, mem = cache.match_prefix(key, block_hashs=hashs, update_refs=True) + if node is None: + assert kv_len == 0 and mem is None + return + assert kv_len == node.node_prefix_total_len == len(mem) + assert kv_len % PAGE == 0 + matched_pages = kv_len // PAGE + matched_hashs = hashs[:matched_pages] + # value integrity: returned mem must equal first-written values for this path + expected = oracle.expected_mem(matched_hashs) + assert torch.equal(mem, expected), "match returned wrong mem values" + # the matched tail must be reusable: big-page or has a state buffer + assert node.is_big_page_node() or node.small_page_buffer_idx is not None + live.append((node, matched_hashs)) + + def do_dec(): + if not live: + return + i = int(rng.integers(0, len(live))) + node, _ = live.pop(i) + cache.dec_node_ref_counter(node) + + def do_steal(): + before = small.get_free_cache_num() + cache.free_one_small_page_linear_att_buffer() + after = small.get_free_cache_num() + if after > before: + # a stolen buffer returned to the pool; drop any of our tracked ids that + # are no longer in the tree (conservation invariant rechecks the rest) + in_tree = {n.small_page_buffer_idx for n in walk(cache) if n.small_page_buffer_idx is not None} + for idx in list(allocated_ids): + if idx not in in_tree and idx in small.free_set: + allocated_ids.discard(idx) + + def do_evict(): + unref = cache.get_tree_total_tokens_num() - cache.get_refed_tokens_num() + if unref <= 0: + return + want_pages = int(rng.integers(1, unref // PAGE + 1)) if unref >= PAGE else 0 + if want_pages == 0: + return + need = want_pages * PAGE + + def cb(mem_index, small_buf_id): + # mirror free_radix_cache_to_get_enough_token: evicted node's state buffer + # is returned to the pool by the caller's callback. + if small_buf_id is not None: + small.free_state_cache([small_buf_id]) + + # capture which page hashes leave the tree + before_nodes = {n.page_hash for n in walk(cache)} + cache._evict(need, cb) + after_nodes = {n.page_hash for n in walk(cache)} + freed_hashs = before_nodes - after_nodes + oracle.forget(freed_hashs) + # drop freed buffer ids + in_tree = {n.small_page_buffer_idx for n in walk(cache) if n.small_page_buffer_idx is not None} + for idx in list(allocated_ids): + if idx not in in_tree and idx in small.free_set: + allocated_ids.discard(idx) + + ops = [do_insert, do_insert, do_match, do_match, do_dec, do_steal, do_steal, do_evict] + for step in range(600): + op = ops[int(rng.integers(0, len(ops)))] + op() + check_invariants(cache, small, allocated_ids) + # root must hold exactly baseline(1) + one ref per still-held match; it must NOT + # drift on misses / trim-to-empty (regression guard for the root-ref leak). + assert cache.root_node.ref_counter == 1 + len( + live + ), f"root ref drifted: {cache.root_node.ref_counter} != 1 + {len(live)}" + + # drain all references, then evict everything; tree must end empty and balanced + while live: + node, _ = live.pop() + cache.dec_node_ref_counter(node) + check_invariants(cache, small, allocated_ids) + assert cache.get_refed_tokens_num() == 0 + total = cache.get_tree_total_tokens_num() + if total > 0: + cache._evict(total, lambda m, b: small.free_state_cache([b]) if b is not None else None) + assert cache.get_tree_total_tokens_num() == 0 + assert len(cache._evict_tree_set) == 0 + assert len(cache._evict_tree_set_for_linear_att) == 0 + + +def test_root_ref_balanced_across_many_matches(): + """Root ref must not drift: many match+dec cycles leave it at its initial value.""" + cache, small, mm = build() + root_ref0 = cache.root_node.ref_counter + page_ids = [1, 2, 3] + hashs = hashes_for(page_ids) + key = torch.tensor([t for p in page_ids for t in page_tokens(p)], dtype=torch.int64) + value = torch.arange(1000, 1000 + len(page_ids) * PAGE, dtype=torch.int64) + buf = small.alloc_one_state_cache() + linear_idxs = [None, None, buf] + cache.insert(key, value, block_hashs=hashs, block_linear_idxs=linear_idxs) + + for _ in range(50): + node, kv_len, mem = cache.match_prefix(key, block_hashs=hashs, update_refs=True) + assert node is not None + cache.dec_node_ref_counter(node) + assert cache.root_node.ref_counter == root_ref0, "root ref_counter drifted" + assert cache.get_refed_tokens_num() == 0 + + +def test_match_then_trim_to_empty_balances_refs(): + """A match that trims all the way back to nothing must restore refs and refed tokens.""" + cache, small, mm = build(small_pool_size=2) + # insert a 2-page path whose tail has NO buffer and is not a big page: + # build it via a longer insert then steal the tail buffer so the tail is unusable. + page_ids = [5, 6] + hashs = hashes_for(page_ids) + key = torch.tensor([t for p in page_ids for t in page_tokens(p)], dtype=torch.int64) + value = torch.arange(2000, 2000 + 2 * PAGE, dtype=torch.int64) + buf = small.alloc_one_state_cache() + cache.insert(key, value, block_hashs=hashs, block_linear_idxs=[None, buf]) + + # exhaust the pool so the steal actually fires (it is a no-op while slots are free), + # then steal the only in-tree buffer -> both pages unusable (no buffer, not big page) + while small.alloc_one_state_cache() is not None: + pass + cache.free_one_small_page_linear_att_buffer() + refed0 = cache.get_refed_tokens_num() + root_ref0 = cache.root_node.ref_counter + node, kv_len, mem = cache.match_prefix(key, block_hashs=hashs, update_refs=True) + assert node is None and kv_len == 0 and mem is None + assert cache.get_refed_tokens_num() == refed0, "trim-to-empty leaked refed tokens" + assert cache.root_node.ref_counter == root_ref0, "trim-to-empty leaked a root ref" + # all nodes back to ref 0 + for n in walk(cache): + assert n.ref_counter == 0 + + +def test_deref_to_root_balances_root_ref(): + """deref_to_first_big_page_node returning None (reached root) must release the + match-time root ref — otherwise the big-page-enabled match path leaks it.""" + cache, small, mm = build() + page_ids = [1, 2] + hashs = hashes_for(page_ids) + key = torch.tensor([t for p in page_ids for t in page_tokens(p)], dtype=torch.int64) + value = torch.arange(3000, 3000 + 2 * PAGE, dtype=torch.int64) + buf = small.alloc_one_state_cache() + cache.insert(key, value, block_hashs=hashs, block_linear_idxs=[None, buf]) + + r0 = cache.root_node.ref_counter + share_node, kv, mem = cache.match_prefix(key, block_hashs=hashs, update_refs=True) + assert share_node is not None and not share_node.is_big_page_node() + assert cache.root_node.ref_counter == r0 + 1 # match took one root ref + # small-page regime: root is the only big-page node, so deref walks to root -> None + node = cache.deref_to_first_big_page_node(share_node) + assert node is None + assert cache.root_node.ref_counter == r0, "deref-to-root leaked a root ref" + assert cache.get_refed_tokens_num() == 0 + for n in walk(cache): + assert n.ref_counter == 0 + + +def test_root_ref_not_leaked_on_miss(): + """Repeated complete misses must not drift root.ref_counter (regression).""" + cache, small, mm = build() + r0 = cache.root_node.ref_counter + hashs = hashes_for([7, 8]) + key = torch.tensor([t for p in [7, 8] for t in page_tokens(p)], dtype=torch.int64) + for _ in range(25): + node, kv, mem = cache.match_prefix(key, block_hashs=hashs, update_refs=True) + assert node is None and kv == 0 and mem is None + assert cache.root_node.ref_counter == r0, "root ref leaked on cache miss" diff --git a/unit_tests/server/router/model_infer/test_linear_att_free_req_big_page.py b/unit_tests/server/router/model_infer/test_linear_att_free_req_big_page.py new file mode 100644 index 0000000000..bbebde5431 --- /dev/null +++ b/unit_tests/server/router/model_infer/test_linear_att_free_req_big_page.py @@ -0,0 +1,88 @@ +"""Regression tests for the big-page state-buffer release on pause/abort mid-prefill. + +Bug: in big-page mode (--linear_att_page_block_num set), a request whose chunked +prefill crossed a big-page boundary (filling req.linear_att_len_to_big_page_id) and +was then paused/aborted before completing took _linear_att_free_req's fallback branch, +which freed tokens but never released the accumulated big-page state-buffer ids -> +free_a_req_mem's `assert len(req.linear_att_len_to_big_page_id) == 0` crashed the +worker (or leaked big-page slots with asserts disabled). +""" +import types + +import torch +from sortedcontainers import SortedDict + +import lightllm.common.basemodel # noqa: F401 (import first to break a circular-import cycle) +from lightllm.server.router.model_infer import infer_batch as IB + + +class _BigPool: + def __init__(self): + self.freed = [] + + def free_state_cache(self, ids): + self.freed.extend(ids) + + +class _RadixCache: + def __init__(self): + self.linear_att_big_page_buffers = _BigPool() + self.deced = [] + + def dec_node_ref_counter(self, node): + self.deced.append(node) + + +def test_release_helper_frees_and_clears(): + ctx = IB.InferenceContext.__new__(IB.InferenceContext) + ctx.radix_cache = _RadixCache() + req = types.SimpleNamespace(linear_att_len_to_big_page_id=SortedDict({8: 101, 16: 102})) + + ctx._release_pending_linear_att_big_page_ids(req) + assert sorted(ctx.radix_cache.linear_att_big_page_buffers.freed) == [101, 102] + assert len(req.linear_att_len_to_big_page_id) == 0 + + # idempotent: a second call on an empty dict frees nothing more + ctx._release_pending_linear_att_big_page_ids(req) + assert sorted(ctx.radix_cache.linear_att_big_page_buffers.freed) == [101, 102] + + +def _make_ctx_and_req(monkeypatch, cur_kv_len, cache_len, pending): + ctx = IB.InferenceContext.__new__(IB.InferenceContext) + ctx.is_linear_att_mixed_model = True + ctx.radix_cache = _RadixCache() + ctx.req_manager = types.SimpleNamespace(req_to_token_indexs=torch.arange(0, 200, dtype=torch.int64).reshape(1, 200)) + # _linear_att_free_req asserts on the *global* g_infer_context, and reads start args + monkeypatch.setattr(IB, "g_infer_context", ctx) + monkeypatch.setattr( + IB, + "get_env_start_args", + lambda: types.SimpleNamespace(linear_att_hash_page_size=4, linear_att_page_block_num=2), + ) + req = IB.InferReq.__new__(IB.InferReq) + req.req_idx = 0 + req.shared_kv_node = None + req.cur_kv_len = cur_kv_len + req.linear_att_cache_len = cache_len + req.tail_linear_att_small_page_buffer_id = None + req.linear_att_len_to_big_page_id = SortedDict(pending) + return ctx, req + + +def test_branch_c_releases_pending_big_pages(monkeypatch): + # big_page_token_num = page(4)*block(2) = 8; cache_len=16 -> tail_big=16. + # cur_kv_len=8 < tail_big=16 -> branch A and B skipped, fallback branch C taken. + # pending dict holds the big-page id saved at boundary 8 during prefill. + ctx, req = _make_ctx_and_req(monkeypatch, cur_kv_len=8, cache_len=16, pending={8: 777}) + free_idx = [] + ctx._linear_att_free_req(free_idx, req) + assert ctx.radix_cache.linear_att_big_page_buffers.freed == [777], "branch C must release pending big-page ids" + assert len(req.linear_att_len_to_big_page_id) == 0, "pending dict must be empty after free (invariant)" + + +def test_cur_kv_len_zero_release(monkeypatch): + # cur_kv_len == 0 early return must also leave the dict empty (defensive). + ctx, req = _make_ctx_and_req(monkeypatch, cur_kv_len=0, cache_len=16, pending={8: 555}) + ctx._linear_att_free_req([], req) + assert ctx.radix_cache.linear_att_big_page_buffers.freed == [555] + assert len(req.linear_att_len_to_big_page_id) == 0 From 8ec5a2a7272dfbcaa642fa20a9d4ed78268ef897 Mon Sep 17 00:00:00 2001 From: wzj Date: Sun, 14 Jun 2026 05:20:57 +0000 Subject: [PATCH 2/4] fix --- .../router/dynamic_prompt/linear_att_radix_cache.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py index 22b79cbe2c..bf07e121e6 100644 --- a/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/linear_att_radix_cache.py @@ -365,18 +365,12 @@ def match_prefix( ans_node_list=ans_node_list, update_refs=update_refs, ) - # _match_prefix_helper 进入时一定对 root 自增了一次 ref_counter。命中链非空时,调用方最终会 - # 通过 dec_node_ref_counter(ans_node) 沿父链回收(含 root),增减平衡;但下面两个 "命中为空" - # 的提前返回会把 None 交给调用方,调用方不会再回收,root 自增就无人抵消,导致 root.ref_counter - # 在每次 miss / trim 到空时持续漂移。这里显式补偿这一次 root 自增。 if len(ans_node_list) == 0: - self.dec_node_ref_counter(self.root_node) return None, 0, None # 判定真正可以用的匹配节点。 ans_node_list = self._trim_unusable_match_tail(ans_node_list) if len(ans_node_list) == 0: - self.dec_node_ref_counter(self.root_node) return None, 0, None ans_node = ans_node_list[-1] @@ -491,9 +485,6 @@ def deref_to_first_big_page_node(self, node: LinearAttPagedTreeNode) -> Optional iter_node = iter_node.parent if iter_node is self.root_node: - # 没有可承接的 big-page 节点交给调用方释放:root 在 match 阶段同样被 +1, - # 这里必须补偿,否则与 match_prefix miss 路径同类的 root ref 漂移。 - self.dec_node_ref_counter(self.root_node) return None else: return iter_node From 87fecd16cdb9f28ffacfe69f9f4702fbdd4147e0 Mon Sep 17 00:00:00 2001 From: wzj Date: Sun, 14 Jun 2026 05:25:43 +0000 Subject: [PATCH 3/4] fix --- .../server/router/model_infer/infer_batch.py | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 67fb03ca6e..5c2d0d45fb 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -149,18 +149,6 @@ def _full_att_free_req(self, free_token_index: List, req: "InferReq"): req.shared_kv_node = None return - def _release_pending_linear_att_big_page_ids(self, req: "InferReq"): - # 释放本请求 prefill 阶段在 big page 边界上申请、但尚未插入 radix cache 的 big page - # state buffer。仅当请求未走 insert 分支(小页/大页插入)就被释放时才会有残留,典型场景: - # big page 模式下请求在 prefill 跨过 big page 边界后、到达末尾前被 pause / abort。 - # 若不释放,会泄漏 big page state slot,并触发 free_a_req_mem 中 dict 为空的断言。 - if req.linear_att_len_to_big_page_id: - self.radix_cache.linear_att_big_page_buffers.free_state_cache( - list(req.linear_att_len_to_big_page_id.values()) - ) - req.linear_att_len_to_big_page_id.clear() - return - def _linear_att_free_req(self, free_token_index: List, req: "InferReq"): assert g_infer_context.is_linear_att_mixed_model is True args = get_env_start_args() @@ -176,7 +164,6 @@ def _linear_att_free_req(self, free_token_index: List, req: "InferReq"): assert req.linear_att_cache_len <= req.cur_kv_len if req.cur_kv_len == 0: - self._release_pending_linear_att_big_page_ids(req) return if req.linear_att_cache_len <= req.cur_kv_len and req.tail_linear_att_small_page_buffer_id is not None: @@ -239,15 +226,25 @@ def _linear_att_free_req(self, free_token_index: List, req: "InferReq"): if shared_kv_len <= req.cur_kv_len: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][shared_kv_len : req.cur_kv_len]) + # 该分支不会把 prefill 阶段累积的 big page id 插入 radix cache(典型为 pause/abort + # 在 prefill 跨过 big page 边界后、到达末尾前触发),需在此显式释放,避免泄漏。 + + # 释放本请求 prefill 阶段在 big page 边界上申请、但尚未插入 radix cache 的 big page + # state buffer。仅当请求未走 insert 分支(小页/大页插入)就被释放时才会有残留,典型场景: + # big page 模式下请求在 prefill 跨过 big page 边界后、到达末尾前被 pause / abort。 + # 若不释放,会泄漏 big page state slot,并触发 free_a_req_mem 中 dict 为空的断言。 + if req.linear_att_len_to_big_page_id: + self.radix_cache.linear_att_big_page_buffers.free_state_cache( + list(req.linear_att_len_to_big_page_id.values()) + ) + req.linear_att_len_to_big_page_id.clear() + req.cur_kv_len = shared_kv_len assert req.tail_linear_att_small_page_buffer_id is None if req.shared_kv_node is not None: assert req.shared_kv_node.node_prefix_total_len == req.cur_kv_len self.radix_cache.dec_node_ref_counter(req.shared_kv_node) req.shared_kv_node = None - # 该分支不会把 prefill 阶段累积的 big page id 插入 radix cache(典型为 pause/abort - # 在 prefill 跨过 big page 边界后、到达末尾前触发),需在此显式释放,避免泄漏。 - self._release_pending_linear_att_big_page_ids(req) return assert False, f"error state: cur_kv_len: {req.cur_kv_len}" From c50cf6a891d32591202e2afb692a17ba37c7952d Mon Sep 17 00:00:00 2001 From: wzj Date: Sun, 14 Jun 2026 05:27:26 +0000 Subject: [PATCH 4/4] remove unit tests added by this PR --- .../test_linear_att_radix_cache_bigpage.py | 300 ------------- .../test_linear_att_radix_cache_invariants.py | 404 ------------------ .../test_linear_att_free_req_big_page.py | 88 ---- 3 files changed, 792 deletions(-) delete mode 100644 unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_bigpage.py delete mode 100644 unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_invariants.py delete mode 100644 unit_tests/server/router/model_infer/test_linear_att_free_req_big_page.py diff --git a/unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_bigpage.py b/unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_bigpage.py deleted file mode 100644 index fc2f5d9981..0000000000 --- a/unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_bigpage.py +++ /dev/null @@ -1,300 +0,0 @@ -"""Big-page-regime coverage + invariant fuzz for LinearAttPagedRadixCache. - -Active in production only when --linear_att_page_block_num is set (e.g. the GSM8K -launch scripts use 8). Here big_page_num is small so inserts create big-page nodes -plus an optional small tail, mirroring _linear_att_free_req's two insert calls and -copy_linear_att_state_to_cache_buffer's len_to_big_page_id construction. -""" -import uuid - -import numpy as np -import pytest -import torch -from sortedcontainers import SortedDict - -from lightllm.server.router.dynamic_prompt.linear_att_radix_cache import LinearAttPagedRadixCache -from lightllm.utils.kv_cache_utils import compute_token_list_hash - -PAGE = 4 -BIGN = 2 -BIG_TOKENS = PAGE * BIGN - - -class FakePool: - def __init__(self, size): - self.size = size - self.free_set = set(range(size)) - self.order = list(range(size)) - - def alloc_one_state_cache(self): - if not self.order: - return None - i = self.order.pop(0) - self.free_set.discard(i) - return i - - def free_state_cache(self, free_indexes): - for i in free_indexes: - assert i is not None and i not in self.free_set, f"double free {i}" - self.free_set.add(i) - self.order.append(i) - - def get_free_cache_num(self): - return len(self.order) - - -class FakeAllocator: - def __init__(self, size): - self.size = size - self.can_use_mem_size = size - - -class FakeMem: - def __init__(self, size, big_pool): - self.allocator = FakeAllocator(size) - self.linear_att_big_page_buffers = big_pool - - def free(self, mem_index): - self.allocator.can_use_mem_size += len(mem_index) - - -def build(small_size=32, big_size=64, mem=400_000): - small = FakePool(small_size) - big = FakePool(big_size) - mm = FakeMem(mem, big) - cache = LinearAttPagedRadixCache( - unique_name=f"bp_{uuid.uuid4().hex[:8]}", - total_token_num=mem, - rank_in_node=0, - hash_page_size=PAGE, - big_page_num=BIGN, - kv_cache_mem_manager=mm, - linear_att_small_page_buffers=small, - ) - return cache, small, big, mm - - -def walk(cache): - out = [] - st = list(cache.root_node.children.values()) - while st: - n = st.pop() - out.append(n) - st.extend(n.children.values()) - return out - - -def page_tokens(pid): - return list(range(pid * PAGE, pid * PAGE + PAGE)) - - -def hashes_for(pids): - toks = [] - for p in pids: - toks += page_tokens(p) - toks.append(-1) - return compute_token_list_hash(toks, PAGE) - - -def check(cache, small, big): - nodes = walk(cache) - # structural + accounting - total = 0 - refed = 0 - for n in nodes: - assert n.parent is not None - assert n.node_prefix_total_len == n.parent.node_prefix_total_len + n.node_value_len - assert n.ref_counter >= 0 - assert n.node_value_len == len(n.token_mem_index_value) - if n.is_big_page_node(): - assert n.page_num == BIGN and n.node_value_len == BIG_TOKENS - assert n.big_page_buffer_idx is not None - assert n.small_page_buffer_idx is None - else: - assert n.page_num == 1 and n.node_value_len == PAGE - assert n.big_page_buffer_idx is None - total += n.node_value_len - if n.ref_counter > 0: - refed += n.node_value_len - for k, c in n.children.items(): - assert c.page_hash == k and c.parent is n - assert cache.get_tree_total_tokens_num() == total - assert cache.get_refed_tokens_num() == refed - # evict set == non-root leaves - leaves = {id(n) for n in nodes if n.is_leaf()} - assert {id(n) for n in cache._evict_tree_set} == leaves - assert id(cache.root_node) not in {id(n) for n in cache._evict_tree_set} - # buffer-evict set == small-buffer holders - assert {id(n) for n in cache._evict_tree_set_for_linear_att} == { - id(n) for n in nodes if n.small_page_buffer_idx is not None - } - # big-page id conservation - big_in_tree = [n.big_page_buffer_idx for n in nodes if n.is_big_page_node()] - assert len(big_in_tree) == len(set(big_in_tree)), "big-page id reused by two nodes" - assert set(big_in_tree).isdisjoint(big.free_set) - assert set(big_in_tree) | big.free_set == set(range(big.size)), "big-page id leaked" - # small-page id conservation - small_in_tree = [n.small_page_buffer_idx for n in nodes if n.small_page_buffer_idx is not None] - assert len(small_in_tree) == len(set(small_in_tree)) - assert set(small_in_tree).isdisjoint(small.free_set) - assert set(small_in_tree) | small.free_set == set(range(small.size)), "small-page id leaked" - - -def make_insert(cache, small, big): - """Mirror _linear_att_free_req: big-page-aligned prefix (+ optional small tail).""" - - def insert(pids, mem_base, with_small_tail): - L = len(pids) - num_big = L // BIGN - # len_to_big_page_id: one fresh big id per big-page boundary along the path - l2b = SortedDict() - big_ids_alloced = [] - for j in range(1, num_big + 1): - bid = big.alloc_one_state_cache() - if bid is None: - # big pool exhausted: the real caller would not start this insert; roll back. - for got in big_ids_alloced: - big.free_state_cache([got]) - return - big_ids_alloced.append(bid) - l2b[j * BIG_TOKENS] = bid - hashs = hashes_for(pids) - key = torch.tensor([t for p in pids for t in page_tokens(p)], dtype=torch.int64) - value = torch.arange(mem_base, mem_base + L * PAGE, dtype=torch.int64) - linear_idxs = [None] * L - tail_buf = None - if with_small_tail and (L % BIGN != 0): - tail_buf = small.alloc_one_state_cache() - if tail_buf is None: - # contract: cannot insert a None-tailed non-aligned path; drop the tail page - pids = pids[:-1] - L = len(pids) - if L == 0: - # nothing to insert; release any big ids we grabbed (none, since num_big recomputed) - for bid in big_ids_alloced: - big.free_state_cache([bid]) - return - hashs = hashes_for(pids) - key = torch.tensor([t for p in pids for t in page_tokens(p)], dtype=torch.int64) - value = torch.arange(mem_base, mem_base + L * PAGE, dtype=torch.int64) - linear_idxs = [None] * L - else: - linear_idxs[-1] = tail_buf - elif L % BIGN != 0: - # no small tail wanted but path is not big-aligned -> trim to aligned length - pids = pids[: num_big * BIGN] - L = len(pids) - if L == 0: - for bid in big_ids_alloced: - big.free_state_cache([bid]) - return - hashs = hashes_for(pids) - key = torch.tensor([t for p in pids for t in page_tokens(p)], dtype=torch.int64) - value = torch.arange(mem_base, mem_base + L * PAGE, dtype=torch.int64) - linear_idxs = [None] * L - - before_small = set(small.free_set) - cache.insert(key, value, block_hashs=hashs, block_linear_idxs=linear_idxs, len_to_big_page_id=l2b) - # any tail buffer that was a duplicate got freed by the cache; nothing to track - _ = before_small - - return insert - - -def test_pure_bigpage_insert_and_match(): - cache, small, big = build()[:3] - ins = make_insert(cache, small, big) - # 4 pages -> 2 big pages, no small tail - ins([1, 2, 3, 4], 1000, with_small_tail=False) - check(cache, small, big) - assert cache.get_tree_total_tokens_num() == 4 * PAGE - - hashs = hashes_for([1, 2, 3, 4]) - key = torch.tensor([t for p in [1, 2, 3, 4] for t in page_tokens(p)], dtype=torch.int64) - node, kv, mem = cache.match_prefix(key, block_hashs=hashs, update_refs=True) - assert node is not None and node.is_big_page_node() - assert kv == 16 and len(mem) == 16 - assert torch.equal(mem, torch.arange(1000, 1016, dtype=torch.int64)) - cache.dec_node_ref_counter(node) - check(cache, small, big) - - -def test_mixed_insert_match_trims_to_bigpage_when_tail_unusable(): - cache, small, big = build(small_size=1)[:3] - ins = make_insert(cache, small, big) - # 5 pages -> 2 big pages (8 tokens *2 =16) + 1 small tail page (4) = 20 tokens - ins([1, 2, 3, 4, 5], 2000, with_small_tail=True) - check(cache, small, big) - assert cache.get_tree_total_tokens_num() == 20 - - # exhaust small pool and steal the tail buffer -> tail page unusable - while small.alloc_one_state_cache() is not None: - pass - cache.free_one_small_page_linear_att_buffer() - check(cache, small, big) - - hashs = hashes_for([1, 2, 3, 4, 5]) - key = torch.tensor([t for p in [1, 2, 3, 4, 5] for t in page_tokens(p)], dtype=torch.int64) - node, kv, mem = cache.match_prefix(key, block_hashs=hashs, update_refs=True) - # tail small page has no buffer -> trim back to the last big-page boundary (16) - assert node is not None and node.is_big_page_node() - assert kv == 16 - cache.dec_node_ref_counter(node) - check(cache, small, big) - - -@pytest.mark.parametrize("seed", list(range(10))) -def test_bigpage_fuzz(seed): - rng = np.random.default_rng(seed) - cache, small, big, mm = build(small_size=10, big_size=48, mem=400_000) - ins = make_insert(cache, small, big) - live = [] - mem_base = [10_000] - - def do_ins(): - L = int(rng.integers(1, 7)) - pids = [int(rng.integers(0, 25)) for _ in range(L)] - ins(pids, mem_base[0], with_small_tail=bool(rng.integers(0, 2))) - mem_base[0] += 100 - - def do_match(): - L = int(rng.integers(1, 7)) - pids = [int(rng.integers(0, 25)) for _ in range(L)] - hashs = hashes_for(pids) - key = torch.tensor([t for p in pids for t in page_tokens(p)], dtype=torch.int64) - node, kv, mem = cache.match_prefix(key, block_hashs=hashs, update_refs=True) - if node is None: - assert kv == 0 and mem is None - return - assert kv == node.node_prefix_total_len == len(mem) - assert node.is_big_page_node() or node.small_page_buffer_idx is not None - live.append(node) - - def do_dec(): - if live: - cache.dec_node_ref_counter(live.pop(int(rng.integers(0, len(live))))) - - def do_steal(): - cache.free_one_small_page_linear_att_buffer() - - def do_evict(): - unref = cache.get_tree_total_tokens_num() - cache.get_refed_tokens_num() - if unref < PAGE: - return - need = int(rng.integers(1, unref // PAGE + 1)) * PAGE - cache._evict(need, lambda m, b: small.free_state_cache([b]) if b is not None else None) - - ops = [do_ins, do_ins, do_match, do_match, do_dec, do_steal, do_evict] - for _ in range(400): - ops[int(rng.integers(0, len(ops)))]() - check(cache, small, big) - assert cache.root_node.ref_counter == 1 + len(live), "root ref drifted (big-page regime)" - - while live: - cache.dec_node_ref_counter(live.pop()) - assert cache.get_refed_tokens_num() == 0 - t = cache.get_tree_total_tokens_num() - if t: - cache._evict(t, lambda m, b: small.free_state_cache([b]) if b is not None else None) - assert cache.get_tree_total_tokens_num() == 0 - check(cache, small, big) diff --git a/unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_invariants.py b/unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_invariants.py deleted file mode 100644 index dac8255ffc..0000000000 --- a/unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_invariants.py +++ /dev/null @@ -1,404 +0,0 @@ -"""Property-based invariant fuzzer for LinearAttPagedRadixCache (small-page regime). - -This drives the cache the way infer_batch.py does in the default serving regime -(big-page matching disabled, i.e. big_page_num huge so every request inserts only -small pages with a single tail state buffer). After every random operation it -asserts the full set of internal invariants and verifies value integrity against -an independent first-write-wins oracle. - -Run: pytest unit_tests/server/router/dynamic_prompt/test_linear_att_radix_cache_invariants.py -q -""" -import uuid - -import numpy as np -import pytest -import torch - -from lightllm.server.router.dynamic_prompt.linear_att_radix_cache import LinearAttPagedRadixCache - -PAGE = 4 -BIG = 10_000_000 # big pages effectively disabled (serving default regime) - - -class FakeSmallPageBuffers: - """Models LinearAttCacheManager's id pool with strict double-free detection.""" - - def __init__(self, size): - self.size = size - self.free_set = set(range(size)) - self.free_order = list(range(size)) - - def alloc_one_state_cache(self): - if not self.free_order: - return None - idx = self.free_order.pop(0) - self.free_set.discard(idx) - return idx - - def free_state_cache(self, free_indexes): - for idx in free_indexes: - assert idx is not None - assert idx not in self.free_set, f"double free of small-page buffer {idx}" - self.free_set.add(idx) - self.free_order.append(idx) - - def get_free_cache_num(self): - return len(self.free_order) - - -class FakeBigPageBuffers: - def __init__(self): - self.freed = [] - - def free_state_cache(self, free_indexes): - self.freed.extend(free_indexes) - - -class FakeAllocator: - def __init__(self, size): - self.size = size - self.can_use_mem_size = size - - -class FakeMemManager: - def __init__(self, size): - self.allocator = FakeAllocator(size) - self.linear_att_big_page_buffers = FakeBigPageBuffers() - self.freed_mem = [] - - def free(self, mem_index): - self.freed_mem.append(mem_index) - self.allocator.can_use_mem_size += len(mem_index) - - -def build(small_pool_size=32, mem_size=100_000): - small = FakeSmallPageBuffers(small_pool_size) - mm = FakeMemManager(mem_size) - cache = LinearAttPagedRadixCache( - unique_name=f"fuzz_{uuid.uuid4().hex[:8]}", - total_token_num=mem_size, - rank_in_node=0, - hash_page_size=PAGE, - big_page_num=BIG, - kv_cache_mem_manager=mm, - linear_att_small_page_buffers=small, - ) - return cache, small, mm - - -# ------------------------- tree walking helpers ------------------------- - - -def walk(cache): - """Return all non-root nodes via BFS.""" - out = [] - stack = list(cache.root_node.children.values()) - while stack: - n = stack.pop() - out.append(n) - stack.extend(n.children.values()) - return out - - -def check_invariants(cache, small: FakeSmallPageBuffers, allocated_ids: set): - nodes = walk(cache) - - # 1. structural: prefix len, child-key consistency, page bookkeeping - for n in nodes: - assert n.parent is not None - assert n.node_value_len == len(n.token_mem_index_value) == len(n.token_id_key) - assert n.node_prefix_total_len == n.parent.node_prefix_total_len + n.node_value_len - assert n.ref_counter >= 0, f"negative ref_counter {n.ref_counter}" - for k, c in n.children.items(): - assert c.page_hash == k - assert c.parent is n - # small-page regime: every node is exactly one page - assert n.node_value_len == PAGE - assert n.page_num == 1 - assert not n.is_big_page_node() or n.node_prefix_total_len == 0 - - # 2. accounting: tree_total and refed token counts match the live tree - total = sum(n.node_value_len for n in nodes) - refed = sum(n.node_value_len for n in nodes if n.ref_counter > 0) - assert ( - cache.get_tree_total_tokens_num() == total - ), f"tree_total {cache.get_tree_total_tokens_num()} != actual {total}" - assert cache.get_refed_tokens_num() == refed, f"refed {cache.get_refed_tokens_num()} != actual {refed}" - - # 3. evict-set membership: exactly the leaves, root excluded - leaves = {id(n) for n in nodes if n.is_leaf()} - evict_ids = {id(n) for n in cache._evict_tree_set} - assert evict_ids == leaves, "evict set must equal the set of non-root leaves" - assert id(cache.root_node) not in evict_ids - - # 4. buffer-eviction set membership: exactly nodes holding a small-page buffer - with_buf = {id(n) for n in nodes if n.small_page_buffer_idx is not None} - buf_evict_ids = {id(n) for n in cache._evict_tree_set_for_linear_att} - assert buf_evict_ids == with_buf, "linear-att evict set must equal nodes with a buffer" - - # 5. buffer-id conservation: ids in tree, ids free in pool, partition the universe - in_tree = [n.small_page_buffer_idx for n in nodes if n.small_page_buffer_idx is not None] - assert len(in_tree) == len(set(in_tree)), "a small-page buffer id is used by two nodes" - in_tree_set = set(in_tree) - # every allocated id is either in the tree or free in the pool, never both, never lost - assert in_tree_set.isdisjoint(small.free_set), "buffer id is both in tree and free" - assert in_tree_set | small.free_set == allocated_ids | set( - range(small.size) - ), "buffer id leaked (neither in tree nor free)" - - -# ------------------------- oracle for value integrity ------------------------- - - -def page_tokens(page_id): - # distinct, deterministic token block per logical page id - return list(range(page_id * PAGE, page_id * PAGE + PAGE)) - - -def hashes_for(page_ids): - # chained hash so prefixes are prefix-closed (same as real block hashing) - from lightllm.utils.kv_cache_utils import compute_token_list_hash - - toks = [] - for p in page_ids: - toks.extend(page_tokens(p)) - toks.append(-1) # one extra so (len-1)//PAGE == len(page_ids) - return compute_token_list_hash(toks, PAGE) - - -class Oracle: - """First-write-wins record of the value stored for each hashed page path.""" - - def __init__(self): - self.value_by_hash = {} # block_hash -> mem value tensor (PAGE long) - - def record(self, hashs, values): - for i, h in enumerate(hashs): - if h not in self.value_by_hash: - self.value_by_hash[h] = values[i * PAGE : (i + 1) * PAGE].clone() - - def forget(self, freed_hashs): - for h in freed_hashs: - self.value_by_hash.pop(h, None) - - def expected_mem(self, hashs): - return torch.cat([self.value_by_hash[h] for h in hashs]) - - -# ------------------------- the fuzzer ------------------------- - - -@pytest.mark.parametrize("seed", list(range(16))) -@pytest.mark.parametrize("pool", [6, 24]) # 6 => constant state-buffer pressure & real steals -def test_invariant_fuzz(seed, pool): - rng = np.random.default_rng(seed * 100 + pool) - cache, small, mm = build(small_pool_size=pool, mem_size=200_000) - oracle = Oracle() - - next_mem = [1000] - - def alloc_mem(n): - base = next_mem[0] - next_mem[0] += n - return torch.arange(base, base + n, dtype=torch.int64) - - allocated_ids = set() - live = [] # list of (ans_node, matched_hashs) holding a ref to dec later - - def do_insert(): - npages = int(rng.integers(1, 7)) - page_ids = [int(rng.integers(0, 40)) for _ in range(npages)] - hashs = hashes_for(page_ids) - key = torch.tensor([t for p in page_ids for t in page_tokens(p)], dtype=torch.int64) - value = alloc_mem(npages * PAGE) - buf = small.alloc_one_state_cache() - if buf is None: - # Contract (see _linear_att_free_req): in the small-page regime the radix - # insert is only performed when a tail state buffer exists. With the pool - # exhausted the request simply caches nothing. Skip — do not insert a - # None-tailed path (the cache rightly asserts against that). - return - allocated_ids.add(buf) - linear_idxs = [None] * npages - linear_idxs[-1] = buf - oracle.record(hashs, value) - before_free = set(small.free_set) - cache.insert(key, value, block_hashs=hashs, block_linear_idxs=linear_idxs) - # if our buffer was immediately freed (duplicate tail), drop it from allocated set - newly_free = small.free_set - before_free - for idx in newly_free: - allocated_ids.discard(idx) - - def do_match(): - npages = int(rng.integers(1, 7)) - page_ids = [int(rng.integers(0, 40)) for _ in range(npages)] - hashs = hashes_for(page_ids) - key = torch.tensor([t for p in page_ids for t in page_tokens(p)], dtype=torch.int64) - node, kv_len, mem = cache.match_prefix(key, block_hashs=hashs, update_refs=True) - if node is None: - assert kv_len == 0 and mem is None - return - assert kv_len == node.node_prefix_total_len == len(mem) - assert kv_len % PAGE == 0 - matched_pages = kv_len // PAGE - matched_hashs = hashs[:matched_pages] - # value integrity: returned mem must equal first-written values for this path - expected = oracle.expected_mem(matched_hashs) - assert torch.equal(mem, expected), "match returned wrong mem values" - # the matched tail must be reusable: big-page or has a state buffer - assert node.is_big_page_node() or node.small_page_buffer_idx is not None - live.append((node, matched_hashs)) - - def do_dec(): - if not live: - return - i = int(rng.integers(0, len(live))) - node, _ = live.pop(i) - cache.dec_node_ref_counter(node) - - def do_steal(): - before = small.get_free_cache_num() - cache.free_one_small_page_linear_att_buffer() - after = small.get_free_cache_num() - if after > before: - # a stolen buffer returned to the pool; drop any of our tracked ids that - # are no longer in the tree (conservation invariant rechecks the rest) - in_tree = {n.small_page_buffer_idx for n in walk(cache) if n.small_page_buffer_idx is not None} - for idx in list(allocated_ids): - if idx not in in_tree and idx in small.free_set: - allocated_ids.discard(idx) - - def do_evict(): - unref = cache.get_tree_total_tokens_num() - cache.get_refed_tokens_num() - if unref <= 0: - return - want_pages = int(rng.integers(1, unref // PAGE + 1)) if unref >= PAGE else 0 - if want_pages == 0: - return - need = want_pages * PAGE - - def cb(mem_index, small_buf_id): - # mirror free_radix_cache_to_get_enough_token: evicted node's state buffer - # is returned to the pool by the caller's callback. - if small_buf_id is not None: - small.free_state_cache([small_buf_id]) - - # capture which page hashes leave the tree - before_nodes = {n.page_hash for n in walk(cache)} - cache._evict(need, cb) - after_nodes = {n.page_hash for n in walk(cache)} - freed_hashs = before_nodes - after_nodes - oracle.forget(freed_hashs) - # drop freed buffer ids - in_tree = {n.small_page_buffer_idx for n in walk(cache) if n.small_page_buffer_idx is not None} - for idx in list(allocated_ids): - if idx not in in_tree and idx in small.free_set: - allocated_ids.discard(idx) - - ops = [do_insert, do_insert, do_match, do_match, do_dec, do_steal, do_steal, do_evict] - for step in range(600): - op = ops[int(rng.integers(0, len(ops)))] - op() - check_invariants(cache, small, allocated_ids) - # root must hold exactly baseline(1) + one ref per still-held match; it must NOT - # drift on misses / trim-to-empty (regression guard for the root-ref leak). - assert cache.root_node.ref_counter == 1 + len( - live - ), f"root ref drifted: {cache.root_node.ref_counter} != 1 + {len(live)}" - - # drain all references, then evict everything; tree must end empty and balanced - while live: - node, _ = live.pop() - cache.dec_node_ref_counter(node) - check_invariants(cache, small, allocated_ids) - assert cache.get_refed_tokens_num() == 0 - total = cache.get_tree_total_tokens_num() - if total > 0: - cache._evict(total, lambda m, b: small.free_state_cache([b]) if b is not None else None) - assert cache.get_tree_total_tokens_num() == 0 - assert len(cache._evict_tree_set) == 0 - assert len(cache._evict_tree_set_for_linear_att) == 0 - - -def test_root_ref_balanced_across_many_matches(): - """Root ref must not drift: many match+dec cycles leave it at its initial value.""" - cache, small, mm = build() - root_ref0 = cache.root_node.ref_counter - page_ids = [1, 2, 3] - hashs = hashes_for(page_ids) - key = torch.tensor([t for p in page_ids for t in page_tokens(p)], dtype=torch.int64) - value = torch.arange(1000, 1000 + len(page_ids) * PAGE, dtype=torch.int64) - buf = small.alloc_one_state_cache() - linear_idxs = [None, None, buf] - cache.insert(key, value, block_hashs=hashs, block_linear_idxs=linear_idxs) - - for _ in range(50): - node, kv_len, mem = cache.match_prefix(key, block_hashs=hashs, update_refs=True) - assert node is not None - cache.dec_node_ref_counter(node) - assert cache.root_node.ref_counter == root_ref0, "root ref_counter drifted" - assert cache.get_refed_tokens_num() == 0 - - -def test_match_then_trim_to_empty_balances_refs(): - """A match that trims all the way back to nothing must restore refs and refed tokens.""" - cache, small, mm = build(small_pool_size=2) - # insert a 2-page path whose tail has NO buffer and is not a big page: - # build it via a longer insert then steal the tail buffer so the tail is unusable. - page_ids = [5, 6] - hashs = hashes_for(page_ids) - key = torch.tensor([t for p in page_ids for t in page_tokens(p)], dtype=torch.int64) - value = torch.arange(2000, 2000 + 2 * PAGE, dtype=torch.int64) - buf = small.alloc_one_state_cache() - cache.insert(key, value, block_hashs=hashs, block_linear_idxs=[None, buf]) - - # exhaust the pool so the steal actually fires (it is a no-op while slots are free), - # then steal the only in-tree buffer -> both pages unusable (no buffer, not big page) - while small.alloc_one_state_cache() is not None: - pass - cache.free_one_small_page_linear_att_buffer() - refed0 = cache.get_refed_tokens_num() - root_ref0 = cache.root_node.ref_counter - node, kv_len, mem = cache.match_prefix(key, block_hashs=hashs, update_refs=True) - assert node is None and kv_len == 0 and mem is None - assert cache.get_refed_tokens_num() == refed0, "trim-to-empty leaked refed tokens" - assert cache.root_node.ref_counter == root_ref0, "trim-to-empty leaked a root ref" - # all nodes back to ref 0 - for n in walk(cache): - assert n.ref_counter == 0 - - -def test_deref_to_root_balances_root_ref(): - """deref_to_first_big_page_node returning None (reached root) must release the - match-time root ref — otherwise the big-page-enabled match path leaks it.""" - cache, small, mm = build() - page_ids = [1, 2] - hashs = hashes_for(page_ids) - key = torch.tensor([t for p in page_ids for t in page_tokens(p)], dtype=torch.int64) - value = torch.arange(3000, 3000 + 2 * PAGE, dtype=torch.int64) - buf = small.alloc_one_state_cache() - cache.insert(key, value, block_hashs=hashs, block_linear_idxs=[None, buf]) - - r0 = cache.root_node.ref_counter - share_node, kv, mem = cache.match_prefix(key, block_hashs=hashs, update_refs=True) - assert share_node is not None and not share_node.is_big_page_node() - assert cache.root_node.ref_counter == r0 + 1 # match took one root ref - # small-page regime: root is the only big-page node, so deref walks to root -> None - node = cache.deref_to_first_big_page_node(share_node) - assert node is None - assert cache.root_node.ref_counter == r0, "deref-to-root leaked a root ref" - assert cache.get_refed_tokens_num() == 0 - for n in walk(cache): - assert n.ref_counter == 0 - - -def test_root_ref_not_leaked_on_miss(): - """Repeated complete misses must not drift root.ref_counter (regression).""" - cache, small, mm = build() - r0 = cache.root_node.ref_counter - hashs = hashes_for([7, 8]) - key = torch.tensor([t for p in [7, 8] for t in page_tokens(p)], dtype=torch.int64) - for _ in range(25): - node, kv, mem = cache.match_prefix(key, block_hashs=hashs, update_refs=True) - assert node is None and kv == 0 and mem is None - assert cache.root_node.ref_counter == r0, "root ref leaked on cache miss" diff --git a/unit_tests/server/router/model_infer/test_linear_att_free_req_big_page.py b/unit_tests/server/router/model_infer/test_linear_att_free_req_big_page.py deleted file mode 100644 index bbebde5431..0000000000 --- a/unit_tests/server/router/model_infer/test_linear_att_free_req_big_page.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Regression tests for the big-page state-buffer release on pause/abort mid-prefill. - -Bug: in big-page mode (--linear_att_page_block_num set), a request whose chunked -prefill crossed a big-page boundary (filling req.linear_att_len_to_big_page_id) and -was then paused/aborted before completing took _linear_att_free_req's fallback branch, -which freed tokens but never released the accumulated big-page state-buffer ids -> -free_a_req_mem's `assert len(req.linear_att_len_to_big_page_id) == 0` crashed the -worker (or leaked big-page slots with asserts disabled). -""" -import types - -import torch -from sortedcontainers import SortedDict - -import lightllm.common.basemodel # noqa: F401 (import first to break a circular-import cycle) -from lightllm.server.router.model_infer import infer_batch as IB - - -class _BigPool: - def __init__(self): - self.freed = [] - - def free_state_cache(self, ids): - self.freed.extend(ids) - - -class _RadixCache: - def __init__(self): - self.linear_att_big_page_buffers = _BigPool() - self.deced = [] - - def dec_node_ref_counter(self, node): - self.deced.append(node) - - -def test_release_helper_frees_and_clears(): - ctx = IB.InferenceContext.__new__(IB.InferenceContext) - ctx.radix_cache = _RadixCache() - req = types.SimpleNamespace(linear_att_len_to_big_page_id=SortedDict({8: 101, 16: 102})) - - ctx._release_pending_linear_att_big_page_ids(req) - assert sorted(ctx.radix_cache.linear_att_big_page_buffers.freed) == [101, 102] - assert len(req.linear_att_len_to_big_page_id) == 0 - - # idempotent: a second call on an empty dict frees nothing more - ctx._release_pending_linear_att_big_page_ids(req) - assert sorted(ctx.radix_cache.linear_att_big_page_buffers.freed) == [101, 102] - - -def _make_ctx_and_req(monkeypatch, cur_kv_len, cache_len, pending): - ctx = IB.InferenceContext.__new__(IB.InferenceContext) - ctx.is_linear_att_mixed_model = True - ctx.radix_cache = _RadixCache() - ctx.req_manager = types.SimpleNamespace(req_to_token_indexs=torch.arange(0, 200, dtype=torch.int64).reshape(1, 200)) - # _linear_att_free_req asserts on the *global* g_infer_context, and reads start args - monkeypatch.setattr(IB, "g_infer_context", ctx) - monkeypatch.setattr( - IB, - "get_env_start_args", - lambda: types.SimpleNamespace(linear_att_hash_page_size=4, linear_att_page_block_num=2), - ) - req = IB.InferReq.__new__(IB.InferReq) - req.req_idx = 0 - req.shared_kv_node = None - req.cur_kv_len = cur_kv_len - req.linear_att_cache_len = cache_len - req.tail_linear_att_small_page_buffer_id = None - req.linear_att_len_to_big_page_id = SortedDict(pending) - return ctx, req - - -def test_branch_c_releases_pending_big_pages(monkeypatch): - # big_page_token_num = page(4)*block(2) = 8; cache_len=16 -> tail_big=16. - # cur_kv_len=8 < tail_big=16 -> branch A and B skipped, fallback branch C taken. - # pending dict holds the big-page id saved at boundary 8 during prefill. - ctx, req = _make_ctx_and_req(monkeypatch, cur_kv_len=8, cache_len=16, pending={8: 777}) - free_idx = [] - ctx._linear_att_free_req(free_idx, req) - assert ctx.radix_cache.linear_att_big_page_buffers.freed == [777], "branch C must release pending big-page ids" - assert len(req.linear_att_len_to_big_page_id) == 0, "pending dict must be empty after free (invariant)" - - -def test_cur_kv_len_zero_release(monkeypatch): - # cur_kv_len == 0 early return must also leave the dict empty (defensive). - ctx, req = _make_ctx_and_req(monkeypatch, cur_kv_len=0, cache_len=16, pending={8: 555}) - ctx._linear_att_free_req([], req) - assert ctx.radix_cache.linear_att_big_page_buffers.freed == [555] - assert len(req.linear_att_len_to_big_page_id) == 0