12 changes: 9 additions & 3 deletions gemma/activations.h
@@ -102,8 +102,8 @@ struct AttentionActivations {
att_sums.OverrideRows(batch_size);
}

MatStorageT<float> q; // query
MatStorageT<float> q_T; // Transposed to maximize attention speed.
MatStorageT<float> q; // query
MatStorageT<BF16> q_T; // Transposed to maximize attention speed.

MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector
@@ -139,7 +139,7 @@ struct AttentionActivationsPtrs {

const ModelConfig& config;
MatPtrT<float> q;
MatPtrT<float> q_T;
MatPtrT<BF16> q_T;
MatPtrT<float> pre_att_rms_out;
MatPtrT<float> att;
MatPtrT<float> att_out;
@@ -203,6 +203,12 @@ struct Activations {
ffw_out.OverrideRows(batch_size);

attention_storage.SetBatchSize(batch_size);
attention.q = attention_storage.q;
attention.q_T = attention_storage.q_T;
attention.pre_att_rms_out = attention_storage.pre_att_rms_out;
attention.att = attention_storage.att;
attention.att_out = attention_storage.att_out;
attention.att_sums = attention_storage.att_sums;
}

const LayerConfig& layer_config;
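Note on gemma/activations.h (reviewer sketch, not part of the diff): q_T now stores BF16 rather than float, halving the bytes streamed per transposed query column, and SetBatchSize re-points the AttentionActivationsPtrs views at the resized storage. A minimal sketch of the float-to-BF16 demotion this relies on, using hwy::ConvertScalarTo from hwy/base.h (the same helper TransposeQ calls below); the function and buffer names here are illustrative only:

#include <cstddef>
#include "hwy/base.h"  // hwy::bfloat16_t, hwy::ConvertScalarTo

// Demotes a float row to BF16. BF16 keeps float32's 8-bit exponent but only
// 7 mantissa bits, so magnitudes are preserved while low-order precision is
// dropped.
inline void DemoteRowToBF16(const float* src, hwy::bfloat16_t* dst, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    dst[i] = hwy::ConvertScalarTo<hwy::bfloat16_t>(src[i]);
  }
}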
21 changes: 17 additions & 4 deletions gemma/attention.cc
@@ -57,16 +57,29 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const MatPtrT<KV_t>& k, float* HWY_RESTRICT att,
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kGenAttentionQDotK);
const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols();
HWY_ALIGN BF16 q_bf[kMaxQKVDim];

CompressPerThread tls;
const hn::ScalableTag<float> df;
CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
0);

if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const float score = Dot(q, k.Row(pos), k.Cols());
// const float score = Dot(q, k.Row(pos), k.Cols());
const float score =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos), qkv_dim);
att[pos] = score;
}
} else {
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t pos_modulo = div_seq_len.Remainder(pos);
const float score = Dot(q, k.Row(pos_modulo), k.Cols());
// const float score = Dot(q, k.Row(pos_modulo), k.Cols());
const float score =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_modulo), qkv_dim);
att[pos_modulo] = score;
}
}
@@ -130,7 +143,7 @@ static HWY_INLINE void WeightedSumV(
void SingleDotSoftmaxWeightedSum(
const size_t pos, const size_t start_pos, const size_t last_pos,
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
const MatPtrT<float>& query_norm_scale, const size_t layer_idx,
const MatPtr& query_norm_scale, const size_t layer_idx,
const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
const float att_cap = activations.config.att_cap;
@@ -169,7 +182,7 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
}

void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const MatPtrT<float>& query_norm_scale,
const MatPtr& query_norm_scale,
AttentionActivationsPtrs& activations,
QBatch& qbatch, ThreadingContext& ctx) {
GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);
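Note on gemma/attention.cc (reviewer sketch, not part of the diff): QDotK now demotes the query to BF16 once per call and then runs a mixed-precision dot against each K row, and query_norm_scale is passed as the type-erased MatPtr, so its element type is presumably no longer pinned to float. Ignoring the vectorization and any compensated summation the real Dot kernel may use, the per-timestep score is equivalent to this scalar reference (names are illustrative; KVT stands in for whatever KV_t is in this build):

#include <cstddef>
#include "hwy/base.h"  // hwy::bfloat16_t, hwy::ConvertScalarTo

// Scalar reference for the mixed-mode QDotK score: both operands are promoted
// to f32 before the multiply-accumulate, so the only precision lost relative
// to the old all-f32 path is the one-time demotion of q to BF16.
template <typename KVT>
float QDotKScalarRef(const hwy::bfloat16_t* q_bf, const KVT* k_row,
                     size_t qkv_dim) {
  float score = 0.0f;
  for (size_t i = 0; i < qkv_dim; ++i) {
    score += hwy::ConvertScalarTo<float>(q_bf[i]) *
             hwy::ConvertScalarTo<float>(k_row[i]);
  }
  return score;
}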
4 changes: 2 additions & 2 deletions gemma/attention.h
@@ -38,12 +38,12 @@ namespace gcpp {
void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
const MatPtr& query_norm_scale, size_t layer_idx, \
const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \
\
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
const MatPtrT<float>& query_norm_scale, \
const MatPtr& query_norm_scale, \
AttentionActivationsPtrs& activations, \
QBatch& qbatch, ThreadingContext& ctx); \
\
102 changes: 77 additions & 25 deletions gemma/flash_attention.cc
@@ -58,7 +58,7 @@ static constexpr size_t kNFx8HTileSize = 8;
// q has shape [batch, qbatch][head, qkv_dim].
// q_t has shape [qkv_dim][qbatch, head, batch] in order to make the maximum
// possible consecutive elements have the same KV.
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
static void TransposeQ(const MatPtrT<float>& q, MatPtrT<BF16>& q_t,
const size_t qbatch_size, ThreadingContext& ctx) {
// Group floats by the number of floats in a cache line.
const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
@@ -69,12 +69,13 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
for (size_t lane = 0; lane < kNF; ++lane) {
size_t q_row = task * kNF + lane;
if (q_row >= q_t.Rows()) break;
float* HWY_RESTRICT qt_row = q_t.Row(q_row);
BF16* HWY_RESTRICT qt_row = q_t.Row(q_row);
for (size_t qi = 0; qi < qbatch_size; ++qi) {
for (size_t h = 0; h < num_heads; ++h) {
for (size_t b = 0; b < batch_size; ++b) {
qt_row[(qi * num_heads + h) * batch_size + b] =
q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row];
hwy::ConvertScalarTo<BF16>(
q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row]);
}
}
}
@@ -91,7 +92,7 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
// Updates q in place for RMSNorm and positional encoding.
void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
MatPtrT<float>& q,
const MatPtrT<float>& query_norm_scale,
const MatPtr& query_norm_scale,
const size_t layer_idx,
const AttentionActivationsPtrs& activations,
ThreadingContext& ctx) {
@@ -158,8 +159,19 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
float* HWY_RESTRICT att_out, ThreadingContext& ctx,
const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
const hn::ScalableTag<BF16> dbf;
const size_t qkv_dim = k.Cols();
HWY_ALIGN BF16 q_bf[kMaxQKVDim];

CompressPerThread tls;
const hn::ScalableTag<float> df;
CompressTraits<BF16>::Compress(df, q, qkv_dim, tls, MakeSpan(q_bf, qkv_dim),
0);
const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
float m = Dot(q, k.Row(pos_mod), k.Cols());
// TODO: Mixed-mode can be further improved for Turin: we can demote right
// before we do the dot product instruction, rather than promote both to f32.
// But some potential accuracy loss there, needs evaluation first.
float m = Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
if (float cap = activations.config.att_cap; cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
m = cap * std::tanh(m / cap);
@@ -169,7 +181,8 @@
MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), ctx, worker);
for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) {
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
float x = Dot(q, k.Row(pos_mod), k.Cols());
float x =
Dot(dbf, MakeConstSpan(q_bf, qkv_dim), 0, k.Row(pos_mod), qkv_dim);
SingleFlashAttentionStep(x, activations.config.att_cap, m, d,
v.Row(pos_mod), v.Cols(), att_out);
}
@@ -233,6 +246,50 @@ void QDotKTileFloat(DF df, const float* HWY_RESTRICT q, const size_t q_stride,
}
}

template <class DF, class VF = hn::Vec<DF>>
void QDotKTile(DF df, const BF16* HWY_RESTRICT q, const size_t q_stride,
const MatPtrT<KV_t>& k, const size_t* k_pos, VF& sum0, VF& sum1,
VF& sum2, VF& sum3, VF& sum4, VF& sum5, VF& sum6, VF& sum7) {
constexpr size_t kHTileSize = kNFx8HTileSize;
sum0 = hn::Zero(df);
sum1 = hn::Zero(df);
sum2 = hn::Zero(df);
sum3 = hn::Zero(df);
sum4 = hn::Zero(df);
sum5 = hn::Zero(df);
sum6 = hn::Zero(df);
sum7 = hn::Zero(df);
const float* HWY_RESTRICT k_row[kHTileSize];
for (int i = 0; i < kHTileSize; ++i) {
k_row[i] = k.Row(k_pos[i]);
}

const hn::Rebind<BF16, DF> dbfh;
using VBF = hn::Vec<decltype(dbfh)>;

for (size_t i = 0; i < k.Cols(); ++i) {
const VBF q_vec_bf = hn::Load(dbfh, q);
const VF q_vec = hn::PromoteTo(df, q_vec_bf);
VF k_0 = hn::Set(df, k_row[0][i]);
sum0 = hn::MulAdd(q_vec, k_0, sum0);
VF k_1 = hn::Set(df, k_row[1][i]);
sum1 = hn::MulAdd(q_vec, k_1, sum1);
VF k_2 = hn::Set(df, k_row[2][i]);
sum2 = hn::MulAdd(q_vec, k_2, sum2);
VF k_3 = hn::Set(df, k_row[3][i]);
sum3 = hn::MulAdd(q_vec, k_3, sum3);
VF k_4 = hn::Set(df, k_row[4][i]);
sum4 = hn::MulAdd(q_vec, k_4, sum4);
VF k_5 = hn::Set(df, k_row[5][i]);
sum5 = hn::MulAdd(q_vec, k_5, sum5);
VF k_6 = hn::Set(df, k_row[6][i]);
sum6 = hn::MulAdd(q_vec, k_6, sum6);
VF k_7 = hn::Set(df, k_row[7][i]);
sum7 = hn::MulAdd(q_vec, k_7, sum7);
q += q_stride;
}
}
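// Reviewer note, not part of the diff: hn::Rebind<BF16, DF> gives a BF16 tag
// with the same lane count as DF, so each hn::Load reads NF BF16 q values and
// PromoteTo widens them to a full f32 vector before the FMAs. Per q lane and
// timestep t, the tile computes the plain dot product
//   score[t][lane] = sum_i f32(qT(i, lane)) * k_row[t][i],
// i.e. the same math as QDotKTileFloat, only with q stored as (and promoted
// from) BF16.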

// Returns the element-wise maximum of 8 vectors, in a single vector.
template <class DF, class VF = hn::Vec<DF>>
VF HWY_INLINE ElementwiseMaxOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
@@ -264,17 +321,14 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
// Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos].
void TileFlashAttention(const MatPtrT<float>& q,
const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<float>& qT, const MatPtrT<KV_t>& k,
const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos,
const size_t min_last_pos, const size_t max_last_pos,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const AttentionActivationsPtrs& activations,
MatPtrT<float>& att_out,
const uint32_t* HWY_RESTRICT out_offsets,
ThreadingContext& ctx, const size_t worker) {
void TileFlashAttention(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx,
const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention);
constexpr int kHTileSize = kNFx8HTileSize;
using DF = hn::ScalableTag<float>;
@@ -291,7 +345,7 @@ void TileFlashAttention(const MatPtrT<float>& q,
VI lasts = hn::LoadU(di, last_pos);
VF old_m = hn::Set(df, -std::numeric_limits<float>::max() / 2.0f);
VF old_d = hn::Zero(df);
const float* HWY_RESTRICT qT_row = qT.Row(0);
const BF16* HWY_RESTRICT qT_row = qT.Row(0);
const size_t qT_stride = qT.Stride();
size_t position = start_pos;
while (position + kHTileSize - 1 <= min_last_pos) {
Expand All @@ -300,8 +354,7 @@ void TileFlashAttention(const MatPtrT<float>& q,
k_pos[i] = activations.div_seq_len.Remainder(position + i);
}
VF x0, x1, x2, x3, x4, x5, x6, x7;
QDotKTileFloat(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6,
x7);
QDotKTile(df, qT_row, qT_stride, k, k_pos, x0, x1, x2, x3, x4, x5, x6, x7);
if (activations.config.att_cap > 0.0f) {
// Compute tanh(x / cap) * cap, being LogitsSoftCap on the tile.
VF cap = hn::Set(df, activations.config.att_cap);
@@ -592,8 +645,7 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
// grouped together so that mode 1 or 2 can be used, and choosing which of the
// 3 modes to use for best efficiency.
void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
const size_t layer_idx,
const MatPtrT<float>& query_norm_scale,
const size_t layer_idx, const MatPtr& query_norm_scale,
AttentionActivationsPtrs& activations, QBatch& qbatch,
ThreadingContext& ctx) {
GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);
@@ -723,9 +775,9 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
// To avoid duplicating the code to setup K and V, the call to
// TileFlashAttention is inside the loop over tasks, even though it
// handles all rows in the task at once.
StridedView<float> qT =
StridedView<float>(activations.q_T.Row(0) + first_task, kVTileSize,
activations.q_T.Stride());
StridedView<BF16> qT =
StridedView<BF16>(activations.q_T.Row(0) + first_task, kVTileSize,
activations.q_T.Stride());
if (kVTileSize == kNF) {
// We can still use TileFlashAttention even if we didn't transpose Q
// above. The condition used for transposing Q above is more general
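Note on gemma/flash_attention.cc (reviewer sketch, not part of the diff): TransposeQ now demotes while it transposes, so the transposed Q tile that TileFlashAttention streams is BF16, and SingleFlashAttention uses the same compress-then-mixed-dot pattern as QDotK in attention.cc. The resulting scores still feed the usual online-softmax accumulation; for reference, a self-contained scalar version of the per-timestep update that SingleFlashAttentionStep is presumably built around (a sketch of the technique, not the file's implementation):

#include <algorithm>
#include <cmath>
#include <cstddef>

// One online-softmax step: fold score x and value row v_row into the running
// max m, denominator d, and weighted accumulator acc. After the last timestep,
// acc[i] / d equals the softmax-weighted sum of V for output element i.
void OnlineSoftmaxStep(float x, const float* v_row, size_t v_cols,
                       float& m, float& d, float* acc) {
  const float m_new = std::max(m, x);
  const float scale_old = std::exp(m - m_new);  // rescales the old state
  const float w_new = std::exp(x - m_new);      // weight of the new timestep
  d = d * scale_old + w_new;
  for (size_t i = 0; i < v_cols; ++i) {
    acc[i] = acc[i] * scale_old + w_new * v_row[i];
  }
  m = m_new;
}

Initializing m to a very negative value, d to zero, and acc to zeros makes the first step reduce to acc = v_row and d = 1; subsequent steps rescale the previous state so the running result stays numerically stable regardless of score magnitude.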
5 changes: 2 additions & 3 deletions gemma/flash_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace gcpp {
namespace NAMESPACE { \
void RMSNormAndPositionalEncoding( \
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
const MatPtr& query_norm_scale, size_t layer_idx, \
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
\
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
@@ -45,8 +45,7 @@ namespace gcpp {
size_t total_tasks, size_t target_parallelism); \
\
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
size_t layer_idx, \
const MatPtrT<float>& query_norm_scale, \
size_t layer_idx, const MatPtr& query_norm_scale, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
3 changes: 2 additions & 1 deletion ops/dot-inl.h
@@ -413,7 +413,8 @@ using DotKernelDefault =
template <class D, typename WT, typename VT>
HWY_INLINE float Dot(D d, const PackedSpan<const WT>& w, size_t w_ofs,
const VT* HWY_RESTRICT vec, size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelDefault());
return DecompressAndCall(d, w, w_ofs, MakeConstSpan(vec, num),
DotKernelDefault());
}

// Adapter for two pointers, no bounds checking.
Expand Down
Loading