diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py index f97873ce646..9eef4413a63 100644 --- a/examples/models/llama/static_attention.py +++ b/examples/models/llama/static_attention.py @@ -632,6 +632,8 @@ def _run_once( return y, attn_updates def _update_states(self, attn_updates, update_pos, update_len): + if attn_updates["out_cache_state"] is None: + return for mask in self._masks.values(): mask.unmask(update_len) k_cache_updates, v_cache_updates = attn_updates["out_cache_state"]