llama: use f16 mask for FA to save VRAM#23764
Conversation
|
This stacked along with reserving only |
JohannesGaessler
left a comment
There was a problem hiding this comment.
This seems like the right approach to me. For non-FA configurations the mask is still FP32 but since the mask size scales linearly with context depth while the size of the KQ matrix scales quadratically it should not make a difference. For most models we would not even need FP16 since the mask values will just be either 0 or -inf. Or we could even calculate the mask values from indices (but which gets complicated a lot for multiple concurrent contexts). But these further optimizations would be a lot more invasive, require a lot of effort to implement properly, and only make sense if with this PR the mask would still be the largest tensor in the compute graph.
|
@am17an since you wrote:
Does that mean there will be more changes or is this PR ready for review? |
|
You can review, I just didn't like the extra llama_mask function |
| // store an f32 mask value into a buffer that is either f32 or f16 | ||
| template <typename T> | ||
| static inline T llama_mask_value(float v) { | ||
| if constexpr (std::is_same_v<T, ggml_fp16_t>) { | ||
| return ggml_fp32_to_fp16(v); | ||
| } else { | ||
| return v; | ||
| } | ||
| } | ||
|
|
There was a problem hiding this comment.
I think it would be preferable to have a function like llama_cast analogous to ggml_cuda_cast as defined in convert.cuh (if it doesn't already exist somewhere). And to then use that function consistently for type conversions in both directions.
| { | ||
| GGML_ASSERT(self_kq_mask); | ||
| GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); | ||
| const auto fill = [&](ggml_tensor * mask, int n_swa, llama_swa_type swa_type) { |
There was a problem hiding this comment.
I think the naming here is confusing since you have fill_mask, fill, and std::fill. Maybe rename fill -> fill_mask and fill_mask -> fill_mask_inner?
JohannesGaessler
left a comment
There was a problem hiding this comment.
Looking at the code for creating the mask again, I think it can be simplified a bit more with this patch:
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 92e8d0d29..1b976eb8e 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -396,8 +396,9 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
const int64_t n_kv = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens;
- const auto fill_mask_inner = [&](auto * data, int n_swa, llama_swa_type swa_type) {
+ const auto fill_mask = [&](auto * data, int64_t ne, int n_swa, llama_swa_type swa_type) {
using T = std::remove_reference_t<decltype(*data)>;
+ std::fill(data, data + ne, llama_cast<T>(-INFINITY));
for (int i1 = 0; i1 < n_tokens; ++i1) {
const llama_seq_id s1 = ubatch->seq_id[i1][0];
const llama_pos p1 = ubatch->pos[i1];
@@ -426,39 +427,27 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
data[idst + i0] = llama_cast<T>(hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f);
}
}
- };
-
- const auto fill_mask = [&](ggml_tensor * mask, int n_swa, llama_swa_type swa_type) {
- GGML_ASSERT(mask);
- GGML_ASSERT(ggml_backend_buffer_is_host(mask->buffer));
-
- if (mask->type == GGML_TYPE_F16) {
- ggml_fp16_t * data = (ggml_fp16_t *) mask->data;
-
- std::fill(data, data + ggml_nelements(mask), llama_cast<ggml_fp16_t>(-INFINITY));
-
- fill_mask_inner(data, n_swa, swa_type);
-
- if (debug) {
- print_mask(data, n_tokens, n_kv, n_swa, swa_type);
- }
- } else {
- float * data = (float *) mask->data;
-
- std::fill(data, data + ggml_nelements(mask), -INFINITY);
-
- fill_mask_inner(data, n_swa, swa_type);
-
- if (debug) {
- print_mask(data, n_tokens, n_kv, n_swa, swa_type);
- }
+ if (debug) {
+ print_mask(data, n_tokens, n_kv, n_swa, swa_type);
}
};
- fill_mask(self_kq_mask, 0, LLAMA_SWA_TYPE_NONE);
+ GGML_ASSERT(self_kq_mask);
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+ if (self_kq_mask->type == GGML_TYPE_F16) {
+ fill_mask((ggml_fp16_t *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
+ } else {
+ fill_mask((float *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE);
+ }
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
- fill_mask(self_kq_mask_swa, hparams.n_swa, hparams.swa_type);
+ GGML_ASSERT(self_kq_mask_swa);
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+ if (self_kq_mask_swa->type == GGML_TYPE_F16) {
+ fill_mask((ggml_fp16_t *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), 0, LLAMA_SWA_TYPE_NONE);
+ } else {
+ fill_mask((float *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), 0, LLAMA_SWA_TYPE_NONE);
+ }
}
}| } else if (std::is_same_v<src_t, ggml_fp16_t> && std::is_same_v<dst_t, float>) { | ||
| return ggml_fp16_to_fp32(v); | ||
| } else if (std::is_same_v<src_t, float> && std::is_same_v<dst_t, ggml_fp16_t>) { | ||
| return ggml_fp32_to_fp16(v); | ||
| } |
There was a problem hiding this comment.
| } else if (std::is_same_v<src_t, ggml_fp16_t> && std::is_same_v<dst_t, float>) { | |
| return ggml_fp16_to_fp32(v); | |
| } else if (std::is_same_v<src_t, float> && std::is_same_v<dst_t, ggml_fp16_t>) { | |
| return ggml_fp32_to_fp16(v); | |
| } | |
| } else if constexpr (std::is_same_v<src_t, ggml_fp16_t> && std::is_same_v<dst_t, float>) { | |
| return ggml_fp16_to_fp32(v); | |
| } else if constexpr (std::is_same_v<src_t, float> && std::is_same_v<dst_t, ggml_fp16_t>) { | |
| return ggml_fp32_to_fp16(v); | |
| } else { | |
| static_assert(std::is_same_v(dst_t, void), "unsupported type combination"); | |
| } |
I think constexpr should be added to avoid potential unexpected compiler errors; in this case casts between 16 bit unsigned integers and 32 bit floats are defined but will result in unexpected behavior. Also add a static_assert to detect unintended misuse.
| float val; | ||
| if constexpr (std::is_same_v<T, ggml_fp16_t>) { | ||
| val = ggml_fp16_to_fp32(data[i * n_kv + j]); | ||
| } else { | ||
| val = data[i * n_kv + j]; | ||
| } |
There was a problem hiding this comment.
| float val; | |
| if constexpr (std::is_same_v<T, ggml_fp16_t>) { | |
| val = ggml_fp16_to_fp32(data[i * n_kv + j]); | |
| } else { | |
| val = data[i * n_kv + j]; | |
| } | |
| float val = llama_cast<float>(data[i * n_kv + j]); |
Overview
Currently we reserve the KQ mask in f32 even if FA is used, which is then is converted to f16 while passing to backends. The f32 mask still uses the compute buffer even though is not used, taking up extra VRAM. This PR reserves the kq-mask in f16. This provides 1.2GB of VRAM saving at
-ub 2048and ~300Mb at-ub 512when using MTPAdditional information
Requirements