From 116cd6eff65d494485c20eb0bdeb5c691d381f14 Mon Sep 17 00:00:00 2001
From: Phil Culliton
Date: Wed, 29 Oct 2025 01:47:59 -0700
Subject: [PATCH] BF16 mixed-mode flash attention

PiperOrigin-RevId: 825433929
---
 gemma/activations.h      |   6 +-
 gemma/attention.cc       |  15 ++++-
 gemma/flash_attention.cc | 120 ++++++++++++++++++++++++++-------------
 ops/dot-inl.h            |   3 +-
 4 files changed, 99 insertions(+), 45 deletions(-)

diff --git a/gemma/activations.h b/gemma/activations.h
index acaecb79..cd1621f9 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -104,8 +104,8 @@ struct AttentionActivations {
     // `inv_timescale*` are not batched.
   }

-  MatStorageT<float> q;    // query
-  MatStorageT<float> q_T;  // Transposed to maximize attention speed.
+  MatStorageT<float> q;   // query
+  MatStorageT<BF16> q_T;  // Transposed to maximize attention speed.

   MatStorageT<float> pre_att_rms_out;
   MatStorageT<float> att;  // attention vector
diff --git a/gemma/attention.cc b/gemma/attention.cc
index 6d81c741..95d62cdc 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -57,16 +57,27 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
                              const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
                              ThreadingContext& ctx, const size_t worker) {
   GCPP_ZONE(ctx, worker, Zones::kGenAttentionQDotK);
+  const hn::ScalableTag<BF16> dbf;
+  const size_t qkv_dim = k.Cols();
+  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
+
+  CompressPerThread tls;
+  const hn::ScalableTag<float> df;
+  CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
+                                 0);
+
   if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
     // Slightly faster: no wraparound.
     for (size_t pos = start_pos; pos <= last_pos; ++pos) {
-      const float score = Dot(q, k.Row(pos), k.Cols());
+      const float score =
+          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
       att[pos] = score;
     }
   } else {
     for (size_t pos = start_pos; pos <= last_pos; ++pos) {
       const size_t pos_modulo = div_seq_len.Remainder(pos);
-      const float score = Dot(q, k.Row(pos_modulo), k.Cols());
+      const float score =
+          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_modulo), qkv_dim);
       att[pos_modulo] = score;
     }
   }
diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc
index 49cdfdc3..5392ec0f 100644
--- a/gemma/flash_attention.cc
+++ b/gemma/flash_attention.cc
@@ -58,7 +58,7 @@ static constexpr size_t kNFx8HTileSize = 8;
 // q has shape [batch, qbatch][head, qkv_dim].
 // q_t has shape [qkv_dim][qbatch, head, batch] in order to make the maximum
 // possible consecutive elements have the same KV.
-static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
+static void TransposeQ(const MatPtrT<float>& q, MatPtrT<BF16>& q_t,
                        const size_t qbatch_size, ThreadingContext& ctx) {
   // Group floats by the number of floats in a cache line.
   const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
@@ -69,12 +69,13 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
     for (size_t lane = 0; lane < kNF; ++lane) {
       size_t q_row = task * kNF + lane;
       if (q_row >= q_t.Rows()) break;
-      float* HWY_RESTRICT qt_row = q_t.Row(q_row);
+      BF16* HWY_RESTRICT qt_row = q_t.Row(q_row);
       for (size_t qi = 0; qi < qbatch_size; ++qi) {
         for (size_t h = 0; h < num_heads; ++h) {
           for (size_t b = 0; b < batch_size; ++b) {
             qt_row[(qi * num_heads + h) * batch_size + b] =
-                q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row];
+                hwy::ConvertScalarTo<BF16>(
+                    q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row]);
           }
         }
       }
@@ -158,8 +159,19 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
                           float* HWY_RESTRICT att_out, ThreadingContext& ctx,
                           const size_t worker) {
   GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
+  const hn::ScalableTag<BF16> dbf;
+  const size_t qkv_dim = k.Cols();
+  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
+
+  CompressPerThread tls;
+  const hn::ScalableTag<float> df;
+  CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
+                                 0);
   const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
-  float m = Dot(q, k.Row(pos_mod), k.Cols());
+  // TODO: Mixed-mode can be further improved for Turin: we can demote right
+  // before we do the dot product instruction, rather than promote both to f32.
+  // But some potential accuracy loss there, needs evaluation first.
+  float m = Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
   if (float cap = activations.config.att_cap; cap > 0.0f) {
     // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
     m = cap * std::tanh(m / cap);
@@ -169,7 +181,8 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
   MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
   for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
     const size_t pos_mod = activations.div_seq_len.Remainder(pos);
-    float x = Dot(q, k.Row(pos_mod), k.Cols());
+    float x =
+        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
     SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
                              v.Row(pos_mod), v.Cols(), att_out);
   }
@@ -179,25 +192,31 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
 // the dot products of NF rows of Q for a single K timestep.
 template <class DF, class VF = hn::Vec<DF>>
 VF QDotKVector(DF df, const uint32_t* HWY_RESTRICT q_offsets,
-               const size_t k_pos, const MatPtrT<float>& q,
+               const size_t k_pos, const MatPtrT<float>& q,
                const MatPtrT<KV_t>& k) {
+  const hn::ScalableTag<BF16> dbf;
+  const size_t qkv_dim = k.Cols();
+  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
+  CompressPerThread tls;
+
   hn::TFromD<DF> results[hn::MaxLanes(df)];
   for (size_t i = 0; i < hn::Lanes(df); ++i) {
-    results[i] = Dot(q.Row(0) + q_offsets[i], k.Row(k_pos), k.Cols());
+    CompressTraits<BF16>::Compress(df, q.Row(0) + q_offsets[i], qkv_dim, tls,
+                                   MakeSpan(q_bf, qkv_dim), 0);
+    results[i] =
+        Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
   }
   return hn::LoadU(df, results);
 }

-// Returns an NF Q rows by 8 K rows tile of Q.K dot products, in single
-// precision.
+// Returns an NF Q rows by 8 K rows tile of Q.K dot products.
 // This is the result of NF rows of Q against 8 K timesteps, with positions
 // given by k_pos[0..7]. Q has been transposed so that the NF rows are read in
 // consecutive elements, and other columns by adding q_stride.
 template <class DF, class VF = hn::Vec<DF>>
-void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
-                    const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0,
-                    VF& sum1, VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6,
-                    VF& sum7) {
+void QDotKTile(DF df, const BF16* HWY_RESTRICT q, const size_t q_stride,
+               const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0, VF& sum1,
+               VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7) {
   constexpr size_t kHTileSize = kNFx8HTileSize;
   sum0 = hn::Zero(df);
   sum1 = hn::Zero(df);
@@ -211,8 +230,13 @@ void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
   for (int i = 0; i < kHTileSize; ++i) {
     k_row[i] = k.Row(k_pos[i]);
   }
+
+  const hn::Rebind<BF16, DF> dbfh;
+  using VBF = hn::Vec<decltype(dbfh)>;
+
   for (size_t i = 0; i < k.Cols(); ++i) {
-    VF q_vec = hn::Load(df, q);
+    const VBF q_vec_bf = hn::Load(dbfh, q);
+    const VF q_vec = hn::PromoteTo(df, q_vec_bf);
     VF k_0 = hn::Set(df, k_row[0][i]);
     sum0 = hn::MulAdd(q_vec, k_0, sum0);
     VF k_1 = hn::Set(df, k_row[1][i]);
@@ -264,17 +288,14 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
 // Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to
 // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
 // max_last_pos].
-void TileFlashAttention(const MatPtrT<float>& q,
-                        const uint32_t* HWY_RESTRICT q_offsets,
-                        const StridedView<float>& qT, const MatPtrT<KV_t>& k,
-                        const size_t start_pos,
-                        const uint32_t* HWY_RESTRICT last_pos,
-                        const size_t min_last_pos, const size_t max_last_pos,
-                        const MatPtrT<KV_t>& v, const size_t layer_idx,
-                        const AttentionActivationsPtrs& activations,
-                        MatPtrT<float>& att_out,
-                        const uint32_t* HWY_RESTRICT out_offsets,
-                        ThreadingContext& ctx, const size_t worker) {
+void TileFlashAttention(
+    const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
+    const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, const size_t start_pos,
+    const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
+    const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
+    const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
+    const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,
+    const size_t worker) {
   GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention);
   constexpr int kHTileSize = kNFx8HTileSize;
   using DF = hn::ScalableTag<float>;
@@ -291,7 +312,7 @@ void TileFlashAttention(const MatPtrT<float>& q,
   VI lasts = hn::LoadU(di, last_pos);
   VF old_m = hn::Set(df, -std::numeric_limits<float>::max() / 2.0f);
   VF old_d = hn::Zero(df);
-  const float* HWY_RESTRICT qT_row = qT.Row(0);
+  const BF16* HWY_RESTRICT qT_row = qT.Row(0);
   const size_t qT_stride = qT.Stride();
   size_t position = start_pos;
   while (position + kHTileSize - 1 <= min_last_pos) {
@@ -300,8 +321,7 @@ void TileFlashAttention(const MatPtrT<float>& q,
       k_pos[i] = activations.div_seq_len.Remainder(position + i);
     }
     VF x0, x1, x2, x3, x4, x5, x6, x7;
-    QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6,
-                   x7);
+    QDotKTile(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6, x7);
     if (activations.config.att_cap > 0.0f) {
       // Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
       VF cap = hn::Set(df, activations.config.att_cap);
@@ -390,13 +410,17 @@ void QDotKTilex4(DF df, const float* HWY_RESTRICT q,
   VI k_offsets_vec = hn::LoadU(di, k_offsets);
   for (size_t i = 0; i < k.Cols(); ++i) {
     VF k_vec = hn::GatherIndex(df, k_base + i, k_offsets_vec);
-    VF q_0 = hn::Set(df, q[q_offsets[0] + i]);
+    VF q_0 = hn::Set(df, hwy::ConvertScalarTo<float>(
+                             hwy::ConvertScalarTo<BF16>(q[q_offsets[0] + i])));
     sum0 = hn::MulAdd(q_0, k_vec, sum0);
-    VF q_1 = hn::Set(df, q[q_offsets[1] + i]);
+    VF q_1 = hn::Set(df, hwy::ConvertScalarTo<float>(
+                             hwy::ConvertScalarTo<BF16>(q[q_offsets[1] + i])));
     sum1 = hn::MulAdd(q_1, k_vec, sum1);
-    VF q_2 = hn::Set(df, q[q_offsets[2] + i]);
+    VF q_2 = hn::Set(df, hwy::ConvertScalarTo<float>(
+                             hwy::ConvertScalarTo<BF16>(q[q_offsets[2] + i])));
     sum2 = hn::MulAdd(q_2, k_vec, sum2);
-    VF q_3 = hn::Set(df, q[q_offsets[3] + i]);
+    VF q_3 = hn::Set(df, hwy::ConvertScalarTo<float>(
+                             hwy::ConvertScalarTo<BF16>(q[q_offsets[3] + i])));
     sum3 = hn::MulAdd(q_3, k_vec, sum3);
   }
 }
@@ -478,32 +502,50 @@ void TileFlashAttention4(const MatPtrT<float>& q,
                              out_offsets, v.Cols());
     position += kHTileSize;
   }
+  const hn::ScalableTag<BF16> dbf;
+  const size_t qkv_dim = k.Cols();
+  HWY_ALIGN BF16 q_bf[kMaxQKVDim];
+  CompressPerThread tls;
+  const hn::ScalableTag<float> df_compress;
+
   while (position <= max_last_pos) {
     size_t k_pos = activations.div_seq_len.Remainder(position);
     if (position <= last_pos[0]) {
       // Past the last position, x0 doesn't count.
-      float x0 = Dot(q.Row(0) + q_offsets[0], k.Row(k_pos), k.Cols());
+      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[0],
+                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
+      float x0 =
+          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x0, activations.config.att_cap, old_m0, old_d0,
                                v.Row(k_pos), v.Cols(),
                                att_out.Row(0) + out_offsets[0]);
     }
     if (position <= last_pos[1]) {
       // Past the last position, x1 doesn't count.
-      float x1 = Dot(q.Row(0) + q_offsets[1], k.Row(k_pos), k.Cols());
+      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[1],
+                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
+      float x1 =
+          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x1, activations.config.att_cap, old_m1, old_d1,
                                v.Row(k_pos), v.Cols(),
                                att_out.Row(0) + out_offsets[1]);
     }
     if (position <= last_pos[2]) {
       // Past the last position, x2 doesn't count.
-      float x2 = Dot(q.Row(0) + q_offsets[2], k.Row(k_pos), k.Cols());
+      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[2],
+                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
+      float x2 =
+          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
      SingleFlashAttentionStep(x2, activations.config.att_cap, old_m2, old_d2,
                                v.Row(k_pos), v.Cols(),
                                att_out.Row(0) + out_offsets[2]);
     }
     if (position <= last_pos[3]) {
       // Past the last position, x3 doesn't count.
-      float x3 = Dot(q.Row(0) + q_offsets[3], k.Row(k_pos), k.Cols());
+      CompressTraits<BF16>::Compress(df_compress, q.Row(0) + q_offsets[3],
+                                     qkv_dim, tls, MakeSpan(q_bf, qkv_dim), 0);
+      float x3 =
+          Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(k_pos), qkv_dim);
       SingleFlashAttentionStep(x3, activations.config.att_cap, old_m3, old_d3,
                                v.Row(k_pos), v.Cols(),
                                att_out.Row(0) + out_offsets[3]);
@@ -722,9 +764,9 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
           // To avoid duplicating the code to setup K and V, the call to
           // TileFlashAttention is inside the loop over tasks, even though it
          // handles all rows in the task at once.
-          StridedView<float> qT =
-              StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize,
-                                 activations.q_T.Stride());
+          StridedView<BF16> qT =
+              StridedView<BF16>(activations.q_T.Row(0) + first_task, kVTileSize,
+                                activations.q_T.Stride());
           if (kVTileSize == kNF) {
             // We can still use TileFlashAttention even if we didn't transpose Q
             // above. The condition used for transposing Q above is more general
diff --git a/ops/dot-inl.h b/ops/dot-inl.h
index dae2106a..ecf1ecf6 100644
--- a/ops/dot-inl.h
+++ b/ops/dot-inl.h
@@ -413,7 +413,8 @@ using DotKernelDefault =
 template <class D, typename WT, typename VT>
 HWY_INLINE float Dot(D d, const PackedSpan<const WT>& w, size_t w_ofs,
                      const VT* HWY_RESTRICT vec, size_t num) {
-  return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDefault());
+  return DecompressAndCall(d, w, w_ofs, MakeConstSpan(vec, num),
+                           DotKernelDefault());
 }

 // Adapter for two pointers, no bounds checking.