Merged
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -142,6 +142,7 @@ cc_test(
":mat",
":matmul",
":query",
":test_util",
":threading_context",
":weights",
"@googletest//:gtest_main", # buildcleaner: keep
148 changes: 148 additions & 0 deletions gemma/flash_attention.cc
@@ -20,7 +20,9 @@
#include <array>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <vector>

#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "gemma/flash_structs.h"
@@ -438,6 +440,152 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
return scale;
}

// Reduces the lanes of each x_j with `reducer` and stores the result for x_j
// in lane j of the returned 4-lane vector (tested with float32).
template <class DF, typename T = hn::TFromD<DF>,
class DF4 = hn::CappedTag<T, 4>, class VF4 = hn::Vec<DF4>,
class VF = hn::Vec<DF>, typename F>
static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
F reducer) {
const DF4 df4;
constexpr size_t kMaxLanes = hn::MaxLanes(df);
HWY_LANES_CONSTEXPR size_t kLanes = hn::Lanes(df);
HWY_ALIGN T x_transposed[4 * kMaxLanes];
hn::StoreInterleaved4<DF>(x_0, x_1, x_2, x_3, df, x_transposed);
VF4 result = hn::Load(df4, x_transposed);
for (size_t i = 1; i < kLanes; ++i) {
result = reducer(result, hn::Load(df4, x_transposed + i * 4));
}
return result;
}
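
// For reference, a scalar model of Reduce4 (an illustrative sketch, not part
// of this change; kLanes and reducer as above):
//   for (size_t j = 0; j < 4; ++j) {
//     out[j] = x_j[0];
//     for (size_t i = 1; i < kLanes; ++i) out[j] = reducer(out[j], x_j[i]);
//   }
// StoreInterleaved4 writes lane i of input j to x_transposed[4 * i + j], so
// each 4-lane load gathers lane i of all four inputs and kLanes - 1 reducer
// steps complete the transpose-reduce.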

// Handles up to 4 Q rows by NF*2 timesteps of flash attention.
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d,
float* HWY_RESTRICT scales) {
using DF4 = hn::CappedTag<float, 4>;
const DF4 df4;
using VF4 = hn::Vec<DF4>;
static_assert(kNumQueries >= 1 && kNumQueries <= 4);
VF4 new_max = hn::Set(df4, -std::numeric_limits<float>::max() / 2.0f);
VF max_0 = hn::Max(x_0_p0, x_0_p1);
// Unused rows are zero-initialized so Reduce4 never reads an uninitialized
// vector; their result lanes are ignored below.
VF max_1 = hn::Zero(df), max_2 = hn::Zero(df), max_3 = hn::Zero(df);
if constexpr (kNumQueries >= 2) {
max_1 = hn::Max(x_1_p0, x_1_p1);
}
if constexpr (kNumQueries >= 3) {
max_2 = hn::Max(x_2_p0, x_2_p1);
}
if constexpr (kNumQueries >= 4) {
max_3 = hn::Max(x_3_p0, x_3_p1);
}
if constexpr (kNumQueries == 1) {
new_max = hn::InsertLane(new_max, 0, hn::ReduceMax(df, max_0));
} else {
new_max = Reduce4(df, max_0, max_1, max_2, max_3,
[](auto a, auto b) { return hn::Max(a, b); });
}
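// Soft-capping bounds the running max to (-att_cap, att_cap):
//   new_max <- att_cap * tanh(new_max / att_cap).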
if (att_cap > 0.0f) {
VF4 cap = hn::Set(df4, att_cap);
VF4 one_over_cap = hn::Set(df4, one_over_att_cap);
new_max = hn::Mul(cap, hn::Tanh(df4, hn::Mul(new_max, one_over_cap)));
}
const VF4 old_max_vf = hn::LoadU(df4, old_max);
new_max = hn::Max(new_max, old_max_vf);
// TODO: figure out what was wrong with broadcasts and switch to them.
HWY_ALIGN float tmp_max[4];
hn::Store(new_max, df4, tmp_max);
if constexpr (kNumQueries >= 1) {
const VF new_max_0 = hn::Set(df, tmp_max[0]);
x_0_p0 = hn::Exp(df, hn::Sub(x_0_p0, new_max_0));
x_0_p1 = hn::Exp(df, hn::Sub(x_0_p1, new_max_0));
}
if constexpr (kNumQueries >= 2) {
const VF new_max_1 = hn::Set(df, tmp_max[1]);
x_1_p0 = hn::Exp(df, hn::Sub(x_1_p0, new_max_1));
x_1_p1 = hn::Exp(df, hn::Sub(x_1_p1, new_max_1));
}
if constexpr (kNumQueries >= 3) {
const VF new_max_2 = hn::Set(df, tmp_max[2]);
x_2_p0 = hn::Exp(df, hn::Sub(x_2_p0, new_max_2));
x_2_p1 = hn::Exp(df, hn::Sub(x_2_p1, new_max_2));
}
if constexpr (kNumQueries >= 4) {
const VF new_max_3 = hn::Set(df, tmp_max[3]);
x_3_p0 = hn::Exp(df, hn::Sub(x_3_p0, new_max_3));
x_3_p1 = hn::Exp(df, hn::Sub(x_3_p1, new_max_3));
}
VF4 old_d_vf = hn::LoadU(df4, old_d);
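// Online-softmax rescaling: contributions accumulated under old_max are
// multiplied by exp(old_max - new_max) so they remain normalized with
// respect to the new running max.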
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));

hn::StoreU(new_max, df4, old_max);

VF4 x_sum = hn::Zero(df4);
if constexpr (kNumQueries == 1) {
x_sum = hn::Set(df4, hn::ReduceSum(df, x_0_p0) + hn::ReduceSum(df, x_0_p1));
} else {
VF x_0_sum = hn::Add(x_0_p0, x_0_p1);
VF x_1_sum = hn::Add(x_1_p0, x_1_p1);
VF x_2_sum = hn::Add(x_2_p0, x_2_p1);
VF x_3_sum = hn::Add(x_3_p0, x_3_p1);
x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum,
[](auto a, auto b) { return hn::Add(a, b); });
}
old_d_vf = hn::Add(scale, x_sum);
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f));
const VF zero = hn::Zero(df);
const VF4 zero4 = hn::Zero(df4);
const VF4 one_over_d =
hn::MaskedDivOr(zero4, non_zero_mask, hn::Set(df4, 1.0f), old_d_vf);
HWY_ALIGN float tmp_one_over_d[4];
hn::Store(one_over_d, df4, tmp_one_over_d);
hn::StoreU(old_d_vf, df4, old_d);
scale = hn::Mul(scale, one_over_d);
hn::StoreU(scale, df4, scales);
if (hn::ExtractLane(old_d_vf, 0) > 0.0f) {
const VF one_over_d_0 = hn::Set(df, tmp_one_over_d[0]);
x_0_p0 = hn::Mul(x_0_p0, one_over_d_0);
x_0_p1 = hn::Mul(x_0_p1, one_over_d_0);
} else {
x_0_p0 = zero;
x_0_p1 = zero;
}
if constexpr (kNumQueries >= 2) {
if (hn::ExtractLane(old_d_vf, 1) > 0.0f) {
const VF one_over_d_1 = hn::Set(df, tmp_one_over_d[1]);
x_1_p0 = hn::Mul(x_1_p0, one_over_d_1);
x_1_p1 = hn::Mul(x_1_p1, one_over_d_1);
} else {
x_1_p0 = zero;
x_1_p1 = zero;
}
}
if constexpr (kNumQueries >= 3) {
if (hn::ExtractLane(old_d_vf, 2) > 0.0f) {
const VF one_over_d_2 = hn::Set(df, tmp_one_over_d[2]);
x_2_p0 = hn::Mul(x_2_p0, one_over_d_2);
x_2_p1 = hn::Mul(x_2_p1, one_over_d_2);
} else {
x_2_p0 = zero;
x_2_p1 = zero;
}
}
if constexpr (kNumQueries >= 4) {
if (hn::ExtractLane(old_d_vf, 3) > 0.0f) {
const VF one_over_d_3 = hn::Set(df, tmp_one_over_d[3]);
x_3_p0 = hn::Mul(x_3_p0, one_over_d_3);
x_3_p1 = hn::Mul(x_3_p1, one_over_d_3);
} else {
x_3_p0 = zero;
x_3_p1 = zero;
}
}
}
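
// For reference, a scalar sketch of the per-row online-softmax step above
// (illustrative only, not part of this change; x holds one query's logits):
//   float RowStep(float* x, size_t n, float att_cap, float& m, float& d) {
//     float mx = -std::numeric_limits<float>::max() / 2.0f;
//     for (size_t i = 0; i < n; ++i) mx = std::max(mx, x[i]);
//     if (att_cap > 0.0f) mx = att_cap * std::tanh(mx / att_cap);
//     mx = std::max(mx, m);
//     float sum = 0.0f;
//     for (size_t i = 0; i < n; ++i) sum += (x[i] = std::exp(x[i] - mx));
//     const float scale = d * std::exp(m - mx);  // rescale old contributions
//     m = mx;
//     d = scale + sum;
//     const float inv = d > 0.0f ? 1.0f / d : 0.0f;
//     for (size_t i = 0; i < n; ++i) x[i] *= inv;
//     return scale * inv;  // applied to previously accumulated output rows
//   }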

// Implements flash attention for a strip of 4 query vectors.
// It iterates through timesteps in K from `start_pos` up to `max_last_pos`.
// Timesteps up to `min_last_pos` (*) are processed in tiles of shape 4 Q rows
Expand Down
3 changes: 3 additions & 0 deletions gemma/flash_attention_test.cc
@@ -14,6 +14,8 @@
// limitations under the License.

#include <cstring>
#include <iostream>
#include <limits>
#include <numeric>
#include <vector>

@@ -24,6 +26,7 @@
#include "gemma/kv_cache.h"
#include "gemma/weights.h"
#include "ops/matmul.h"
#include "util/test_util.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
15 changes: 15 additions & 0 deletions util/mat.h
@@ -454,6 +454,21 @@ decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func,
}
}

// Like CallUpcasted, but only for kv_cache types: kBF16 and kF32.
template <class Func, typename... Args>
decltype(auto) CallUpcastedKV(const MatPtr* base, const Func& func,
Args&&... args) {
if (base->GetType() == Type::kF32) {
const MatPtrT<float> mat(*base);
return func(&mat, std::forward<Args>(args)...);
} else if (base->GetType() == Type::kBF16) {
const MatPtrT<BF16> mat(*base);
return func(&mat, std::forward<Args>(args)...);
} else {
HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
}
}
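
// Hypothetical usage sketch (names are illustrative, not from this change):
//   CallUpcastedKV(&kv_mat, [&](const auto* typed) {
//     // `typed` is a const MatPtrT<float>* or const MatPtrT<BF16>*.
//     Accumulate(*typed, out);
//   });
// This dispatches once on the stored element type instead of duplicating the
// call site for kF32 and kBF16.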

void CopyMat(const MatPtr& from, MatPtr& to);
void ZeroInit(MatPtr& mat);
