Skip to content

llama: use f16 mask for FA to save VRAM#23764

Open
am17an wants to merge 2 commits into
ggml-org:masterfrom
am17an:kq_mask_f16
Open

llama: use f16 mask for FA to save VRAM#23764
am17an wants to merge 2 commits into
ggml-org:masterfrom
am17an:kq_mask_f16

Conversation

@am17an
Copy link
Copy Markdown
Contributor

@am17an am17an commented May 27, 2026

Overview

Currently we reserve the KQ mask in f32 even if FA is used, which is then is converted to f16 while passing to backends. The f32 mask still uses the compute buffer even though is not used, taking up extra VRAM. This PR reserves the kq-mask in f16. This provides 1.2GB of VRAM saving at -ub 2048 and ~300Mb at -ub 512 when using MTP

Additional information

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES, used CC and Codex to identify the problem and write the code. Will polish it up a bit

@am17an
Copy link
Copy Markdown
Contributor Author

am17an commented May 27, 2026

This stacked along with reserving only n_outputs == n_seqs saves some more VRAM. On -ub 512 I go from 824 Mb as compute buffer to 444 Mb. On -ub 2048 from 3.2GB to 1.5GB

Copy link
Copy Markdown
Contributor

@JohannesGaessler JohannesGaessler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like the right approach to me. For non-FA configurations the mask is still FP32 but since the mask size scales linearly with context depth while the size of the KQ matrix scales quadratically it should not make a difference. For most models we would not even need FP16 since the mask values will just be either 0 or -inf. Or we could even calculate the mask values from indices (but which gets complicated a lot for multiple concurrent contexts). But these further optimizations would be a lot more invasive, require a lot of effort to implement properly, and only make sense if with this PR the mask would still be the largest tensor in the compute graph.

@JohannesGaessler
Copy link
Copy Markdown
Contributor

@am17an since you wrote:

Will polish it up a bit

Does that mean there will be more changes or is this PR ready for review?

@am17an
Copy link
Copy Markdown
Contributor Author

am17an commented May 28, 2026

You can review, I just didn't like the extra llama_mask function

Comment thread src/llama-impl.h Outdated
Comment on lines +44 to +53
// store an f32 mask value into a buffer that is either f32 or f16
template <typename T>
static inline T llama_mask_value(float v) {
if constexpr (std::is_same_v<T, ggml_fp16_t>) {
return ggml_fp32_to_fp16(v);
} else {
return v;
}
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be preferable to have a function like llama_cast analogous to ggml_cuda_cast as defined in convert.cuh (if it doesn't already exist somewhere). And to then use that function consistently for type conversions in both directions.

Comment thread src/llama-graph.cpp Outdated
{
GGML_ASSERT(self_kq_mask);
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
const auto fill = [&](ggml_tensor * mask, int n_swa, llama_swa_type swa_type) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the naming here is confusing since you have fill_mask, fill, and std::fill. Maybe rename fill -> fill_mask and fill_mask -> fill_mask_inner?

Comment thread src/llama-graph.cpp Outdated
Comment thread src/llama-graph.cpp
Comment thread src/llama-graph.h Outdated
Comment thread src/llama-graph.h Outdated
Comment thread src/llama-graph.h Outdated
Comment thread src/llama-graph.h Outdated
Comment thread src/llama-graph.h Outdated
Comment thread src/llama-graph.h Outdated
@am17an am17an requested a review from JohannesGaessler May 28, 2026 15:38
Copy link
Copy Markdown
Contributor

@JohannesGaessler JohannesGaessler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the code for creating the mask again, I think it can be simplified a bit more with this patch:

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 92e8d0d29..1b976eb8e 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -396,8 +396,9 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
 
-    const auto fill_mask_inner = [&](auto * data, int n_swa, llama_swa_type swa_type) {
+    const auto fill_mask = [&](auto * data, int64_t ne, int n_swa, llama_swa_type swa_type) {
         using T = std::remove_reference_t<decltype(*data)>;
+        std::fill(data, data + ne, llama_cast<T>(-INFINITY));
         for (int i1 = 0; i1 < n_tokens; ++i1) {
             const llama_seq_id s1 = ubatch->seq_id[i1][0];
             const llama_pos    p1 = ubatch->pos[i1];
@@ -426,39 +427,27 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                 data[idst + i0] = llama_cast<T>(hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f);
             }
         }
-    };
-
-    const auto fill_mask = [&](ggml_tensor * mask, int n_swa, llama_swa_type swa_type) {
-        GGML_ASSERT(mask);
-        GGML_ASSERT(ggml_backend_buffer_is_host(mask->buffer));
-
-        if (mask->type == GGML_TYPE_F16) {
-            ggml_fp16_t * data = (ggml_fp16_t *) mask->data;
-
-            std::fill(data, data + ggml_nelements(mask), llama_cast<ggml_fp16_t>(-INFINITY));
-
-            fill_mask_inner(data, n_swa, swa_type);
-
-            if (debug) {
-                print_mask(data, n_tokens, n_kv, n_swa, swa_type);
-            }
-        } else {
-            float * data = (float *) mask->data;
-
-            std::fill(data, data + ggml_nelements(mask), -INFINITY);
-
-            fill_mask_inner(data, n_swa, swa_type);
-
-            if (debug) {
-                print_mask(data, n_tokens, n_kv, n_swa, swa_type);
-            }
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, n_swa, swa_type);
         }
     };
 
-    fill_mask(self_kq_mask, 0, LLAMA_SWA_TYPE_NONE);
+    GGML_ASSERT(self_kq_mask);
+    GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+    if (self_kq_mask->type == GGML_TYPE_F16) {
+        fill_mask((ggml_fp16_t *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
+    } else {
+        fill_mask((float       *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
+    }
 
     if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-        fill_mask(self_kq_mask_swa, hparams.n_swa, hparams.swa_type);
+        GGML_ASSERT(self_kq_mask_swa);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+        if (self_kq_mask_swa->type == GGML_TYPE_F16) {
+            fill_mask((ggml_fp16_t *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), 0, LLAMA_SWA_TYPE_NONE);
+        } else {
+            fill_mask((float       *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), 0, LLAMA_SWA_TYPE_NONE);
+        }
     }
 }

Comment thread src/llama-impl.h
Comment on lines +48 to +52
} else if (std::is_same_v<src_t, ggml_fp16_t> && std::is_same_v<dst_t, float>) {
return ggml_fp16_to_fp32(v);
} else if (std::is_same_v<src_t, float> && std::is_same_v<dst_t, ggml_fp16_t>) {
return ggml_fp32_to_fp16(v);
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
} else if (std::is_same_v<src_t, ggml_fp16_t> && std::is_same_v<dst_t, float>) {
return ggml_fp16_to_fp32(v);
} else if (std::is_same_v<src_t, float> && std::is_same_v<dst_t, ggml_fp16_t>) {
return ggml_fp32_to_fp16(v);
}
} else if constexpr (std::is_same_v<src_t, ggml_fp16_t> && std::is_same_v<dst_t, float>) {
return ggml_fp16_to_fp32(v);
} else if constexpr (std::is_same_v<src_t, float> && std::is_same_v<dst_t, ggml_fp16_t>) {
return ggml_fp32_to_fp16(v);
} else {
static_assert(std::is_same_v(dst_t, void), "unsupported type combination");
}

I think constexpr should be added to avoid potential unexpected compiler errors; in this case casts between 16 bit unsigned integers and 32 bit floats are defined but will result in unexpected behavior. Also add a static_assert to detect unintended misuse.

Comment thread src/llama-graph.cpp
Comment on lines +379 to +384
float val;
if constexpr (std::is_same_v<T, ggml_fp16_t>) {
val = ggml_fp16_to_fp32(data[i * n_kv + j]);
} else {
val = data[i * n_kv + j];
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
float val;
if constexpr (std::is_same_v<T, ggml_fp16_t>) {
val = ggml_fp16_to_fp32(data[i * n_kv + j]);
} else {
val = data[i * n_kv + j];
}
float val = llama_cast<float>(data[i * n_kv + j]);

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants