From 5a05857debebfaf175d485316106317a62c76319 Mon Sep 17 00:00:00 2001
From: Biruk Mammo
Date: Mon, 27 Oct 2025 10:42:46 -0700
Subject: [PATCH] [Gemma.cpp] Allows non-owned arguments for attention methods.

* Adds and uses a new `AttentionActivationsPtrs` that holds non-owning
  `MatPtr`s. Acts as a view into `AttentionActivations`.
* Updates `QBatch` to hold non-owning `MatPtr`s to the kv caches.
* Enables the `MatPtrT` default constructor for simpler initializations.
* Pulls out and passes `LayerWeightsPtrs::query_norm_scale` directly. While
  `LayerWeightsPtrs` already held non-owning `MatPtr`s, this change avoids
  the need to find and construct several empty weight tensors just to
  construct one `query_norm_scale` tensor.

PiperOrigin-RevId: 824584177
---
 BUILD.bazel                   |  1 -
 gemma/activations.h           | 82 ++++++++++++++++++++++-------------
 gemma/attention.cc            | 52 +++++++++++-----------
 gemma/attention.h             | 13 +++---
 gemma/flash_attention.cc      | 82 ++++++++++++++++++-----------------
 gemma/flash_attention.h       | 18 ++++----
 gemma/flash_attention_test.cc | 12 ++---
 gemma/gemma.h                 | 20 +++++++--
 gemma/kv_cache.cc             | 10 +++++
 gemma/kv_cache.h              | 10 +++++
 ops/ops_test.cc               |  2 +-
 util/mat.h                    |  3 ++
 12 files changed, 185 insertions(+), 120 deletions(-)

diff --git a/BUILD.bazel b/BUILD.bazel index 38d79cf5..d14002fc 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -139,7 +139,6 @@ cc_test( ":kv_cache", ":mat", ":matmul", - ":ops", ":threading_context", ":weights", "@googletest//:gtest_main", # buildcleaner: keep
diff --git a/gemma/activations.h b/gemma/activations.h index f474c84f..40320d8e 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -31,26 +31,24 @@ namespace gcpp { -struct AttentionActivations { - // Returns the scale value to use for the query in the attention computation. - // Also called by ops_test. - static inline float ChooseQueryScale(const ModelConfig& config) { - const LayerConfig& layer_config = config.layer_configs[0]; - if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads) - return 1.0f / - sqrtf(static_cast(config.model_dim / layer_config.heads)); - // QueryScaleType::SqrtKeySize - return 1.0f / sqrtf(static_cast(layer_config.qkv_dim)); - } +// Returns the scale value to use for the query in the attention computation. +// Also called by ops_test. +static inline float ChooseQueryScale(const ModelConfig& config) { + const LayerConfig& layer_config = config.layer_configs[0]; + if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads) + return 1.0f / + sqrtf(static_cast(config.model_dim / layer_config.heads)); + // QueryScaleType::SqrtKeySize + return 1.0f / sqrtf(static_cast(layer_config.qkv_dim)); +} +struct AttentionActivations { AttentionActivations( const ModelConfig& config, const LayerConfig& layer_config, size_t batch_size, size_t seq_len, const Allocator& allocator, std::vector>& row_ptrs) - : config(config), - - // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA - // and does not use an external KV cache. + : // `vocab_size == 0` means it is for Vit part, VitAttention is still // MHA and does not use an external KV cache. q(MatFactory("q", batch_size, config.vocab_size == 0 ?
layer_config.heads * 3 * layer_config.qkv_dim @@ -76,11 +74,7 @@ struct AttentionActivations { layer_config.post_qk == PostQKType::HalfRope)), inv_timescale_global(CreateInvTimescale( allocator, layer_config.qkv_dim, - layer_config.post_qk == PostQKType::HalfRope, 1000000.0)), - - div_seq_len(static_cast(seq_len)), - div_heads(static_cast(layer_config.heads)), - query_scale(ChooseQueryScale(config)) { + layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) { // Batch size can be 0 in experimental code so do not assert. if (batch_size == 0) { static std::atomic_flag warned = ATOMIC_FLAG_INIT; @@ -108,9 +102,7 @@ struct AttentionActivations { att_sums.OverrideRows(batch_size); } - const ModelConfig& config; - - MatStorageT q; // query + MatStorageT q; // query MatStorageT q_T; // Transposed to maximize attention speed. MatStorageT pre_att_rms_out; @@ -122,9 +114,39 @@ struct AttentionActivations { // Rope MatStorageT inv_timescale; MatStorageT inv_timescale_global; +}; +// A non-owning view of AttentionActivations. +struct AttentionActivationsPtrs { + AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len) + : config(config), + div_seq_len(static_cast(seq_len)), + div_heads(static_cast(config.layer_configs[0].heads)), + query_scale(ChooseQueryScale(config)) {} + + AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len, + const AttentionActivations& activations) + : AttentionActivationsPtrs(config, seq_len) { + q = activations.q; + q_T = activations.q_T; + pre_att_rms_out = activations.pre_att_rms_out; + att = activations.att; + att_out = activations.att_out; + att_sums = activations.att_sums; + inv_timescale = activations.inv_timescale; + inv_timescale_global = activations.inv_timescale_global; + } + + const ModelConfig& config; + MatPtrT q; + MatPtrT q_T; + MatPtrT pre_att_rms_out; + MatPtrT att; + MatPtrT att_out; + MatPtrT att_sums; + MatPtrT inv_timescale; + MatPtrT inv_timescale_global; hwy::Divisor div_seq_len; - // Unfortunately, some models have had non-power-of-two heads. hwy::Divisor div_heads; float query_scale; }; @@ -150,8 +172,9 @@ struct Activations { ffw_out( MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)), - attention(config, layer_config, batch_size, seq_len, ctx.allocator, - row_ptrs) { + attention_storage(config, layer_config, batch_size, seq_len, + ctx.allocator, row_ptrs), + attention(config, seq_len, attention_storage) { HWY_ASSERT(batch_size != 0); // For MatMul outputs, precompute their row pointers. @@ -179,12 +202,12 @@ struct Activations { C2.OverrideRows(batch_size); ffw_out.OverrideRows(batch_size); - attention.SetBatchSize(batch_size); + attention_storage.SetBatchSize(batch_size); } const LayerConfig& layer_config; - MatStorageT x; // input + MatStorageT x; // input MatStorageT x_bf; // output of final RMSNorm, input to EmbeddingMatmul MatStorageT logits; // TODO: BF16 after Softmax supports that. 
MatStorageT sampled; // batch_size x 3 (padded) @@ -195,7 +218,8 @@ struct Activations { MatStorageT C2; MatStorageT ffw_out; - AttentionActivations attention; + AttentionActivations attention_storage; + AttentionActivationsPtrs attention; }; } // namespace gcpp diff --git a/gemma/attention.cc b/gemma/attention.cc index 668105aa..b7099d17 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -73,12 +73,12 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, } void PositionalEncodingQK(float* qk, const size_t layer_idx, - const LayerWeightsPtrs& layer, - const AttentionActivations& activations, + const AttentionActivationsPtrs& activations, ThreadingContext& ctx, const size_t worker, const size_t pos, const float mul) { - const size_t qkv_dim = layer.layer_config.qkv_dim; - const PostQKType& post_qk = layer.layer_config.post_qk; + const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; + const size_t qkv_dim = layer_config.qkv_dim; + const PostQKType& post_qk = layer_config.post_qk; // qk is either q or k, so qkv_dim is the length we operate on. const float* inv_timescale = activations.inv_timescale.PackedScale1(); const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx); @@ -130,23 +130,23 @@ static HWY_INLINE void WeightedSumV( void SingleDotSoftmaxWeightedSum( const size_t pos, const size_t start_pos, const size_t last_pos, float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, - const size_t layer_idx, const LayerWeightsPtrs& layer, - const AttentionActivations& activations, float* HWY_RESTRICT att, + const MatPtrT& query_norm_scale, const size_t layer_idx, + const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) { const float att_cap = activations.config.att_cap; const float query_scale = activations.query_scale; const size_t seq_len = static_cast(activations.div_seq_len.GetDivisor()); - + const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; // Apply rope and scaling to Q. 
- if (layer.query_norm_scale.HasPtr()) { - CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { + if (query_norm_scale.HasPtr()) { + CallUpcasted(&query_norm_scale, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q, - layer.layer_config.qkv_dim, ctx, worker); + layer_config.qkv_dim, ctx, worker); }); } - PositionalEncodingQK(q, layer_idx, layer, activations, ctx, worker, pos, + PositionalEncodingQK(q, layer_idx, activations, ctx, worker, pos, query_scale); QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker); @@ -169,13 +169,13 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) { } void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, - const LayerWeightsPtrs& layer, - AttentionActivations& activations, QBatch& qbatch, - ThreadingContext& ctx) { + const MatPtrT& query_norm_scale, + AttentionActivationsPtrs& activations, + QBatch& qbatch, ThreadingContext& ctx) { GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive); const hwy::Divisor div_qbatch(qbatch.Size()); - const LayerConfig& layer_config = layer.layer_config; + const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; const size_t qkv_dim = layer_config.qkv_dim; // A "head group" in the context of GQA refers to a collection of query @@ -223,8 +223,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, MatPtrT v("v_view", Extents2D(seq_len, qkv_dim)); v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride()); - SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx, - layer, activations, att, att_out, ctx, worker); + SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, + query_norm_scale, layer_idx, activations, att, + att_out, ctx, worker); }; { @@ -245,7 +246,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, // Fills activations.q and writes to KV cache. static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, - AttentionActivations& activations, + AttentionActivationsPtrs& activations, const QBatch& qbatch, const int flags, MatMulEnv& env) { GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), @@ -312,8 +313,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, }); } - PositionalEncodingQK(kv_f32, layer_idx, layer, activations, env.ctx, - worker, pos, /*mul=*/1.0f); + PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker, + pos, /*mul=*/1.0f); CompressPerThread tls; Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0); }); @@ -322,7 +323,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and // head_dim (`qkv_dim`) into output (`layer_out`). 
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, - AttentionActivations& activations, + AttentionActivationsPtrs& activations, MatMulEnv& env) { GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads); const LayerConfig& layer_config = layer.layer_config; @@ -340,7 +341,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, void GemmaAttention(size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, - AttentionActivations& activations, QBatch& qbatch, + AttentionActivationsPtrs& activations, QBatch& qbatch, MatMulEnv& env, int flags) { GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention); @@ -352,13 +353,14 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx, ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env); if (flags & kAttentionUseOld) { - DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch, - env.ctx); + DotSoftmaxWeightedSum(num_tokens, layer_idx, layer.query_norm_scale, + activations, qbatch, env.ctx); } else { // * 2 does not help on Turin. FlashAttention(num_tokens, /*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1, - layer_idx, layer, activations, qbatch, env.ctx); + layer_idx, layer.query_norm_scale, activations, qbatch, + env.ctx); } SumHeads(layer, activations, env); } diff --git a/gemma/attention.h b/gemma/attention.h index 6c4a48e7..491a0b0b 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -29,8 +29,7 @@ namespace gcpp { #define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \ namespace NAMESPACE { \ void PositionalEncodingQK(float* qk, size_t layer_idx, \ - const LayerWeightsPtrs& layer, \ - const AttentionActivations& activations, \ + const AttentionActivationsPtrs& activations, \ ThreadingContext& ctx, size_t worker, size_t pos, \ float mul); \ \ @@ -39,18 +38,18 @@ namespace gcpp { void SingleDotSoftmaxWeightedSum( \ const size_t pos, const size_t start_pos, const size_t last_pos, \ float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, \ - size_t layer_idx, const LayerWeightsPtrs& layer, \ - const AttentionActivations& activations, float* HWY_RESTRICT att, \ + const MatPtrT& query_norm_scale, size_t layer_idx, \ + const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, \ float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \ \ void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \ - const LayerWeightsPtrs& layer, \ - AttentionActivations& activations, \ + const MatPtrT& query_norm_scale, \ + AttentionActivationsPtrs& activations, \ QBatch& qbatch, ThreadingContext& ctx); \ \ void GemmaAttention(size_t num_tokens, const size_t layer_idx, \ const LayerWeightsPtrs& layer, \ - AttentionActivations& activations, QBatch& qbatch, \ + AttentionActivationsPtrs& activations, QBatch& qbatch, \ MatMulEnv& env, int flags); \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index bf3aede6..b5dd2418 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -30,7 +30,6 @@ #include "gemma/activations.h" #include "gemma/configs.h" // kMaxQKVDim #include "gemma/gemma.h" -#include "gemma/weights.h" #include "util/threading.h" #include "hwy/profiler.h" @@ -91,32 +90,33 @@ static void TransposeQ(const MatPtrT& q, MatPtrT& q_t, // Updates q in place for RMSNorm and positional encoding. 
void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, - MatPtrT& q, const size_t layer_idx, - const LayerWeightsPtrs& layer, - const AttentionActivations& activations, + MatPtrT& q, + const MatPtrT& query_norm_scale, + const size_t layer_idx, + const AttentionActivationsPtrs& activations, ThreadingContext& ctx) { + const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; const float query_scale = activations.query_scale; const hwy::Divisor div_qbatch(qbatch.Size()); const auto func = [&](const size_t task, size_t worker) HWY_ATTR { GCPP_ZONE(ctx, worker, Zones::kFlashAttentionRmsNormAndPositionalEncoding); size_t qi = div_qbatch.Remainder(task); size_t batch_idx = div_qbatch.Divide(task); - for (size_t h = 0; h < layer.layer_config.heads; ++h) { + for (size_t h = 0; h < layer_config.heads; ++h) { const size_t tq_idx = qbatch.Size() * batch_idx + qi; // Find the token position in the query and calculate // the range of cache positions to attend to. const size_t pos = qbatch.Pos(qi) + batch_idx; - float* HWY_RESTRICT q_row = - q.Row(tq_idx) + h * layer.layer_config.qkv_dim; + float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim; // Apply rope and scaling to Q. - if (layer.query_norm_scale.HasPtr()) { - CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { + if (query_norm_scale.HasPtr()) { + CallUpcasted(&query_norm_scale, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row, - layer.layer_config.qkv_dim, ctx, worker); + layer_config.qkv_dim, ctx, worker); }); } - PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx, worker, - pos, query_scale); + PositionalEncodingQK(q_row, layer_idx, activations, ctx, worker, pos, + query_scale); } }; { @@ -154,8 +154,7 @@ void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max, void SingleFlashAttention(const size_t start_pos, const size_t last_pos, const float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, const size_t layer_idx, - const LayerWeightsPtrs& layer, - const AttentionActivations& activations, + const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) { GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention); @@ -265,15 +264,17 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2, // Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos, // max_last_pos]. 
-void TileFlashAttention( - const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, - const StridedView& qT, const MatPtrT& k, - const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos, - const size_t min_last_pos, const size_t max_last_pos, - const MatPtrT& v, const size_t layer_idx, - const LayerWeightsPtrs& layer, const AttentionActivations& activations, - MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, - ThreadingContext& ctx, const size_t worker) { +void TileFlashAttention(const MatPtrT& q, + const uint32_t* HWY_RESTRICT q_offsets, + const StridedView& qT, const MatPtrT& k, + const size_t start_pos, + const uint32_t* HWY_RESTRICT last_pos, + const size_t min_last_pos, const size_t max_last_pos, + const MatPtrT& v, const size_t layer_idx, + const AttentionActivationsPtrs& activations, + MatPtrT& att_out, + const uint32_t* HWY_RESTRICT out_offsets, + ThreadingContext& ctx, const size_t worker) { GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention); constexpr int kHTileSize = kNFx8HTileSize; using DF = hn::ScalableTag; @@ -419,14 +420,16 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max, // Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos, // max_last_pos]. -void TileFlashAttention4( - const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, - const MatPtrT& k, const size_t start_pos, - const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos, - const size_t max_last_pos, const MatPtrT& v, const size_t layer_idx, - const LayerWeightsPtrs& layer, const AttentionActivations& activations, - MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, - ThreadingContext& ctx, const size_t worker) { +void TileFlashAttention4(const MatPtrT& q, + const uint32_t* HWY_RESTRICT q_offsets, + const MatPtrT& k, const size_t start_pos, + const uint32_t* HWY_RESTRICT last_pos, + const size_t min_last_pos, const size_t max_last_pos, + const MatPtrT& v, const size_t layer_idx, + const AttentionActivationsPtrs& activations, + MatPtrT& att_out, + const uint32_t* HWY_RESTRICT out_offsets, + ThreadingContext& ctx, const size_t worker) { GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4); using DF = hn::ScalableTag; const DF df; @@ -589,14 +592,15 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, // grouped together so that mode 1 or 2 can be used, and choosing which of the // 3 modes to use for best efficiency. 
void FlashAttention(const size_t num_tokens, const size_t target_parallelism, - const size_t layer_idx, const LayerWeightsPtrs& layer, - AttentionActivations& activations, QBatch& qbatch, + const size_t layer_idx, + const MatPtrT& query_norm_scale, + AttentionActivationsPtrs& activations, QBatch& qbatch, ThreadingContext& ctx) { GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive); - RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx, - layer, activations, ctx); + RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, + query_norm_scale, layer_idx, activations, ctx); const hwy::Divisor div_qbatch(qbatch.Size()); - const LayerConfig& layer_config = layer.layer_config; + const LayerConfig& layer_config = activations.config.layer_configs[layer_idx]; const size_t qkv_dim = layer_config.qkv_dim; // A "head group" in the context of GQA refers to a collection of query @@ -732,12 +736,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, // is used above to catch all cases where qT will be used. TileFlashAttention(activations.q, q_offsets, qT, k, start_positions[offset], last_pos, min_last_pos, - max_last_pos, v, layer_idx, layer, activations, + max_last_pos, v, layer_idx, activations, activations.att_out, out_offsets, ctx, worker); } else if (kVTileSize == 4) { TileFlashAttention4(activations.q, q_offsets, k, start_positions[offset], last_pos, min_last_pos, - max_last_pos, v, layer_idx, layer, activations, + max_last_pos, v, layer_idx, activations, activations.att_out, out_offsets, ctx, worker); } else { HWY_UNREACHABLE; @@ -746,7 +750,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, } else { SingleFlashAttention(start_positions[offset], last_pos[offset], activations.q.Row(0) + q_offsets[offset], k, v, - layer_idx, layer, activations, + layer_idx, activations, activations.att_out.Row(0) + out_offsets[offset], ctx, worker); } diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h index 959b2276..89af4984 100644 --- a/gemma/flash_attention.h +++ b/gemma/flash_attention.h @@ -28,17 +28,16 @@ namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. 
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \ namespace NAMESPACE { \ - void RMSNormAndPositionalEncoding(size_t num_tokens, const QBatch& qbatch, \ - MatPtrT& q, size_t layer_idx, \ - const LayerWeightsPtrs& layer, \ - const AttentionActivations& activations, \ - ThreadingContext& ctx); \ + void RMSNormAndPositionalEncoding( \ + size_t num_tokens, const QBatch& qbatch, MatPtrT& q, \ + const MatPtrT& query_norm_scale, size_t layer_idx, \ + const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \ \ void SingleFlashAttention(size_t start_pos, size_t last_pos, \ const float* HWY_RESTRICT q, \ const MatPtrT& k, const MatPtrT& v, \ - size_t layer_idx, const LayerWeightsPtrs& layer, \ - const AttentionActivations& activations, \ + size_t layer_idx, \ + const AttentionActivationsPtrs& activations, \ float* HWY_RESTRICT att_out, \ ThreadingContext& ctx, size_t worker); \ \ @@ -46,8 +45,9 @@ namespace gcpp { size_t total_tasks, size_t target_parallelism); \ \ void FlashAttention(size_t num_tokens, size_t target_parallelism, \ - size_t layer_idx, const LayerWeightsPtrs& layer, \ - AttentionActivations& activations, QBatch& qbatch, \ + size_t layer_idx, \ + const MatPtrT& query_norm_scale, \ + AttentionActivationsPtrs& activations, QBatch& qbatch, \ ThreadingContext& ctx); \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index 4147e389..b33f52f7 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -122,8 +122,9 @@ void TestFlashAttention(size_t target_parallelism) { QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries); const size_t batch_size = kOuter; std::vector> row_ptrs; - AttentionActivations attention(config, layer_config, batch_size, kOuter, - ctx.allocator, row_ptrs); + AttentionActivations attention_storage(config, layer_config, batch_size, + kOuter, ctx.allocator, row_ptrs); + AttentionActivationsPtrs attention(config, kOuter, attention_storage); const size_t qkv_dim = layer_config.qkv_dim; ASSERT_EQ(qkv_dim, kInner); const hwy::Divisor div_qbatch(qbatch.Size()); @@ -145,7 +146,8 @@ void TestFlashAttention(size_t target_parallelism) { SetMat(h + layer_config.heads * 2, v); } SetMat(1, attention.q); - DotSoftmaxWeightedSum(tokens.size(), 0, layers, attention, qbatch, ctx); + DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, attention, + qbatch, ctx); // Copy the output to saved_att to allow for comparison. auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator); SetMat(1, attention.q); @@ -158,8 +160,8 @@ void TestFlashAttention(size_t target_parallelism) { total_tasks, target_parallelism); printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n", target_parallelism, kNF, kVTileSize); - FlashAttention(tokens.size(), target_parallelism, 0, layers, attention, - qbatch, ctx); + FlashAttention(tokens.size(), target_parallelism, 0, layers.query_norm_scale, + attention, qbatch, ctx); AssertClose(attention.att_out, *saved_att); ctx.profiler.PrintResults(); } diff --git a/gemma/gemma.h b/gemma/gemma.h index 771cd1cb..59875fbc 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -51,7 +51,7 @@ struct PerQuery { // attention in Paligemma. size_t prefix_end; - KVCache& kv_cache; + KVCachePtr kv_cache; // Previous token generated for this query, or the last prompt token. Will be // fed into the next Transformer() call. 
@@ -64,7 +64,7 @@ struct AllQueries { // For `GenerateSingleT`: same prompt/pos, replicated for each KV cache. AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end, - const hwy::Span& kv_caches) { + const hwy::Span& kv_caches) { per_query_.reserve(kv_caches.size()); for (size_t i = 0; i < kv_caches.size(); ++i) { HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen()); @@ -78,11 +78,16 @@ struct AllQueries { } } + AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end, + const hwy::Span& kv_caches) + : AllQueries(prompt, pos, prefix_end, + hwy::Span(ToKVCachePtrs(kv_caches))) {} + // Batch of queries with initial position set to zero. Causal attention // is requested via empty or all-zero `prefix_end`. AllQueries( const hwy::Span& prompts, - const hwy::Span& kv_caches, + const hwy::Span& kv_caches, const hwy::Span& prefix_end = hwy::Span()) { HWY_ASSERT(prompts.size() == kv_caches.size()); HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0); @@ -99,6 +104,13 @@ struct AllQueries { } } + AllQueries( + const hwy::Span& prompts, + const hwy::Span& kv_caches, + const hwy::Span& prefix_end = hwy::Span()) + : AllQueries(prompts, hwy::Span(ToKVCachePtrs(kv_caches)), + prefix_end) {} + void Reserve(size_t size) { per_query_.reserve(size); } void Append(const PerQuery& query) { per_query_.push_back(query); } @@ -156,7 +168,7 @@ class QBatch { size_t PrefixEnd(size_t qi) const { return queries_[QueryIdx(qi)].prefix_end; } - KVCache& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; } + KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; } int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; } private: diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index ca814f47..ded8df5e 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -16,6 +16,7 @@ #include "gemma/kv_cache.h" #include +#include #include "gemma/configs.h" #include "gemma/gemma_args.h" @@ -54,4 +55,13 @@ KVCache KVCache::Copy() { return copy; } +std::vector ToKVCachePtrs(const hwy::Span& kv_caches) { + std::vector ptrs; + ptrs.reserve(kv_caches.size()); + for (size_t i = 0; i < kv_caches.size(); ++i) { + ptrs.push_back(KVCachePtr{.kv_cache = kv_caches[i].kv_cache}); + } + return ptrs; +} + } // namespace gcpp diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 31e964bc..37a4d0eb 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -17,6 +17,7 @@ #define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_ #include +#include #include "gemma/configs.h" // ModelConfig #include "gemma/gemma_args.h" // InferenceArgs @@ -46,6 +47,15 @@ struct KVCache { KVCache(const Extents2D& kv_extents, const Allocator& allocator); }; +// A non-owning view of a KVCache. +struct KVCachePtr { + size_t SeqLen() const { return kv_cache.Rows(); } + MatPtrT kv_cache; +}; + +// Convenience function to create views into KVCaches. 
+std::vector ToKVCachePtrs(const hwy::Span& kv_caches); + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_ diff --git a/ops/ops_test.cc b/ops/ops_test.cc index e89dca04..8fa3625e 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -454,7 +454,7 @@ void TestRopeAndMulBy() { x.Row(0)[i] = random_float(); } - const float qmul = AttentionActivations::ChooseQueryScale(config); + const float qmul = ChooseQueryScale(config); constexpr float kmul = 1.0f; MatStorageT qexpected("qexpected", dim_qkv, ctx.allocator); diff --git a/util/mat.h b/util/mat.h index 59eceaa6..6b8dc064 100644 --- a/util/mat.h +++ b/util/mat.h @@ -284,6 +284,9 @@ class MatPtrT : public MatPtr { public: using T = MatT; + // Default constructor for use with uninitialized views. + MatPtrT() = default; + // Called by `MatStorageT`. MatPtrT(const char* name, Extents2D extents) : MatPtr(name, TypeEnum(), extents) {}
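
Illustrative note (a minimal sketch, not part of the patch): after this change the owning activation buffers and the non-owning view that the attention kernels operate on are constructed separately, and only `query_norm_scale` is passed instead of the whole `LayerWeightsPtrs`. The fragment below is adapted from the flash_attention_test.cc hunk above; `config`, `layer_config`, `layer`, `qbatch`, `ctx`, `row_ptrs`, and the batch/sequence sizes are assumed to be set up as in that test, and the view's `MatPtrT` members rely on the new default constructor before being rebound to the storage.

  // Owning storage allocates the activation matrices (unchanged behavior).
  AttentionActivations attention_storage(config, layer_config, batch_size,
                                         seq_len, ctx.allocator, row_ptrs);
  // Non-owning view over that storage; this is what the kernels now take.
  AttentionActivationsPtrs attention(config, seq_len, attention_storage);

  // Only the query-norm weights are passed, not the whole LayerWeightsPtrs.
  DotSoftmaxWeightedSum(num_tokens, layer_idx, layer.query_norm_scale,
                        attention, qbatch, ctx);
  // The flash path follows the same convention:
  FlashAttention(num_tokens, target_parallelism, layer_idx,
                 layer.query_norm_scale, attention, qbatch, ctx);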
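
Likewise a sketch of the KV-cache side, assuming owning `KVCache` objects already exist: `PerQuery`/`QBatch` now hold a `KVCachePtr` view, and `ToKVCachePtrs` (added in kv_cache.h above) bridges owning caches to that view type, while the new `AllQueries` overloads keep call sites that pass spans of `KVCache` working unchanged. The span/vector element types below are inferred from the kv_cache.cc hunk, and `prompts` is assumed to be an existing span of prompts.

  std::vector<KVCache> kv_caches;  // owning caches, populated elsewhere (assumed)
  // Build non-owning views explicitly...
  std::vector<KVCachePtr> kv_views =
      ToKVCachePtrs(hwy::Span<KVCache>(kv_caches.data(), kv_caches.size()));
  // ...or pass the owning caches and let the new AllQueries overload convert:
  AllQueries all_queries(prompts,
                         hwy::Span<KVCache>(kv_caches.data(), kv_caches.size()));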