Commit 693486b

birukw authored and copybara-github committed
[Gemma.cpp] Allows non-owned arguments for attention methods.
* Adds and uses a new `AttentionActivationsPtrs` that holds non-owning `MatPtr`s and acts as a view into `AttentionActivations`.
* Updates `QBatch` to hold non-owning `MatPtr`s to the KV caches.
* Enables the `MatPtrT` default constructor for simpler initializations.
* Pulls out and passes `LayerWeightsPtrs::query_norm_scale` directly. While `LayerWeightsPtrs` already held non-owning `MatPtr`s, this change avoids having to find and construct several empty weight tensors just to construct the one `query_norm_scale` tensor.

PiperOrigin-RevId: 823702392
1 parent 86200ce commit 693486b
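The core of the change is a split between owning storage and a non-owning view. Below is a minimal C++ sketch of that pattern, simplified from the real definitions in gemma/activations.h shown in the diff further down (member lists abbreviated; illustrative only):

// Owning storage: allocates the activation buffers (simplified).
struct AttentionActivations {
  MatStorageT<float> q;
  MatStorageT<BF16> att_sums;
  // ... further MatStorageT members ...
};

// Non-owning view: MatPtrT members alias the storage above. Because MatPtrT
// is now default-constructible, the view can start empty and be bound later.
struct AttentionActivationsPtrs {
  AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
                           const AttentionActivations& a)
      : config(config) {
    q = a.q;                // rebinds the view; the underlying data is not copied
    att_sums = a.att_sums;
  }
  const ModelConfig& config;
  MatPtrT<float> q;
  MatPtrT<BF16> att_sums;
};

Attention methods then take AttentionActivationsPtrs&, so callers can hand them either a view of an AttentionActivations or views over buffers they own themselves.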

File tree

12 files changed

+185 -120 lines changed

BUILD.bazel

Lines changed: 0 additions & 1 deletion

@@ -139,7 +139,6 @@ cc_test(
         ":kv_cache",
         ":mat",
         ":matmul",
-        ":ops",
         ":threading_context",
         ":weights",
         "@googletest//:gtest_main",  # buildcleaner: keep

gemma/activations.h

Lines changed: 53 additions & 29 deletions

@@ -31,26 +31,24 @@
 
 namespace gcpp {
 
-struct AttentionActivations {
-  // Returns the scale value to use for the query in the attention computation.
-  // Also called by ops_test.
-  static inline float ChooseQueryScale(const ModelConfig& config) {
-    const LayerConfig& layer_config = config.layer_configs[0];
-    if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
-      return 1.0f /
-             sqrtf(static_cast<float>(config.model_dim / layer_config.heads));
-    // QueryScaleType::SqrtKeySize
-    return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
-  }
+// Returns the scale value to use for the query in the attention computation.
+// Also called by ops_test.
+static inline float ChooseQueryScale(const ModelConfig& config) {
+  const LayerConfig& layer_config = config.layer_configs[0];
+  if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
+    return 1.0f /
+           sqrtf(static_cast<float>(config.model_dim / layer_config.heads));
+  // QueryScaleType::SqrtKeySize
+  return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
+}
 
+struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
       size_t batch_size, size_t seq_len, const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
-      : config(config),
-
-        // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
-        // and does not use an external KV cache.
+      : // `vocab_size == 0` means it is for Vit part, VitAttention is still
+        // MHA and does not use an external KV cache.
         q(MatFactory("q", batch_size,
                      config.vocab_size == 0
                          ? layer_config.heads * 3 * layer_config.qkv_dim

@@ -76,11 +74,7 @@ struct AttentionActivations {
                      layer_config.post_qk == PostQKType::HalfRope)),
         inv_timescale_global(CreateInvTimescale(
             allocator, layer_config.qkv_dim,
-            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),
-
-        div_seq_len(static_cast<uint32_t>(seq_len)),
-        div_heads(static_cast<uint32_t>(layer_config.heads)),
-        query_scale(ChooseQueryScale(config)) {
+            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) {
     // Batch size can be 0 in experimental code so do not assert.
     if (batch_size == 0) {
       static std::atomic_flag warned = ATOMIC_FLAG_INIT;

@@ -108,9 +102,7 @@ struct AttentionActivations {
     att_sums.OverrideRows(batch_size);
   }
 
-  const ModelConfig& config;
-
-  MatStorageT<float> q;  // query
+  MatStorageT<float> q;    // query
   MatStorageT<float> q_T;  // Transposed to maximize attention speed.
 
   MatStorageT<float> pre_att_rms_out;

@@ -122,9 +114,39 @@
   // Rope
   MatStorageT<float> inv_timescale;
   MatStorageT<float> inv_timescale_global;
+};
 
+// A non-owning view of AttentionActivations.
+struct AttentionActivationsPtrs {
+  AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len)
+      : config(config),
+        div_seq_len(static_cast<uint32_t>(seq_len)),
+        div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
+        query_scale(ChooseQueryScale(config)) {}
+
+  AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
+                           const AttentionActivations& activations)
+      : AttentionActivationsPtrs(config, seq_len) {
+    q = activations.q;
+    q_T = activations.q_T;
+    pre_att_rms_out = activations.pre_att_rms_out;
+    att = activations.att;
+    att_out = activations.att_out;
+    att_sums = activations.att_sums;
+    inv_timescale = activations.inv_timescale;
+    inv_timescale_global = activations.inv_timescale_global;
+  }
+
+  const ModelConfig& config;
+  MatPtrT<float> q;
+  MatPtrT<float> q_T;
+  MatPtrT<float> pre_att_rms_out;
+  MatPtrT<float> att;
+  MatPtrT<float> att_out;
+  MatPtrT<BF16> att_sums;
+  MatPtrT<float> inv_timescale;
+  MatPtrT<float> inv_timescale_global;
   hwy::Divisor div_seq_len;
-  // Unfortunately, some models have had non-power-of-two heads.
   hwy::Divisor div_heads;
   float query_scale;
 };

@@ -150,8 +172,9 @@ struct Activations {
         ffw_out(
             MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),
 
-        attention(config, layer_config, batch_size, seq_len, ctx.allocator,
-                  row_ptrs) {
+        attention_storage(config, layer_config, batch_size, seq_len,
+                          ctx.allocator, row_ptrs),
+        attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);
 
     // For MatMul outputs, precompute their row pointers.

@@ -179,12 +202,12 @@ struct Activations {
     C2.OverrideRows(batch_size);
     ffw_out.OverrideRows(batch_size);
 
-    attention.SetBatchSize(batch_size);
+    attention_storage.SetBatchSize(batch_size);
   }
 
   const LayerConfig& layer_config;
 
-  MatStorageT<float> x;  // input
+  MatStorageT<float> x;    // input
   MatStorageT<BF16> x_bf;  // output of final RMSNorm, input to EmbeddingMatmul
   MatStorageT<float> logits;  // TODO: BF16 after Softmax supports that.
   MatStorageT<uint32_t> sampled;  // batch_size x 3 (padded)

@@ -195,7 +218,8 @@ struct Activations {
   MatStorageT<BF16> C2;
   MatStorageT<float> ffw_out;
 
-  AttentionActivations attention;
+  AttentionActivations attention_storage;
+  AttentionActivationsPtrs attention;
 };
 
 }  // namespace gcpp
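One consequence of the view's first constructor above: code that manages its own buffers can build an AttentionActivationsPtrs directly, without constructing a full AttentionActivations. A hedged sketch; the buffer names my_q, my_att, my_att_out are hypothetical placeholders for caller-owned MatPtrT views:

// Assumes the AttentionActivationsPtrs definition from the diff above.
AttentionActivationsPtrs view(config, seq_len);  // MatPtrT members start empty
view.q = my_q;            // bind caller-owned activations
view.att = my_att;
view.att_out = my_att_out;
// ... bind the remaining members as needed, then pass `view` to
// GemmaAttention() / DotSoftmaxWeightedSum().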

gemma/attention.cc

Lines changed: 27 additions & 25 deletions

@@ -73,12 +73,12 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
 }
 
 void PositionalEncodingQK(float* qk, const size_t layer_idx,
-                          const LayerWeightsPtrs& layer,
-                          const AttentionActivations& activations,
+                          const AttentionActivationsPtrs& activations,
                           ThreadingContext& ctx, const size_t worker,
                           const size_t pos, const float mul) {
-  const size_t qkv_dim = layer.layer_config.qkv_dim;
-  const PostQKType& post_qk = layer.layer_config.post_qk;
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
+  const size_t qkv_dim = layer_config.qkv_dim;
+  const PostQKType& post_qk = layer_config.post_qk;
   // qk is either q or k, so qkv_dim is the length we operate on.
   const float* inv_timescale = activations.inv_timescale.PackedScale1();
   const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);

@@ -130,23 +130,23 @@ static HWY_INLINE void WeightedSumV(
 void SingleDotSoftmaxWeightedSum(
     const size_t pos, const size_t start_pos, const size_t last_pos,
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
-    const size_t layer_idx, const LayerWeightsPtrs& layer,
-    const AttentionActivations& activations, float* HWY_RESTRICT att,
+    const MatPtrT<float>& query_norm_scale, const size_t layer_idx,
+    const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
     float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   const size_t seq_len =
       static_cast<size_t>(activations.div_seq_len.GetDivisor());
-
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
   // Apply rope and scaling to Q.
-  if (layer.query_norm_scale.HasPtr()) {
-    CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
+  if (query_norm_scale.HasPtr()) {
+    CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
       RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q,
-                     layer.layer_config.qkv_dim, ctx, worker);
+                     layer_config.qkv_dim, ctx, worker);
     });
   }
 
-  PositionalEncodingQK(q, layer_idx, layer, activations, ctx, worker, pos,
+  PositionalEncodingQK(q, layer_idx, activations, ctx, worker, pos,
                        query_scale);
 
   QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);

@@ -169,13 +169,13 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
 }
 
 void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
-                           const LayerWeightsPtrs& layer,
-                           AttentionActivations& activations, QBatch& qbatch,
-                           ThreadingContext& ctx) {
+                           const MatPtrT<float>& query_norm_scale,
+                           AttentionActivationsPtrs& activations,
+                           QBatch& qbatch, ThreadingContext& ctx) {
   GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);
 
   const hwy::Divisor div_qbatch(qbatch.Size());
-  const LayerConfig& layer_config = layer.layer_config;
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
   const size_t qkv_dim = layer_config.qkv_dim;
 
   // A "head group" in the context of GQA refers to a collection of query

@@ -223,8 +223,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
       MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
       v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
 
-      SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
-                                  layer, activations, att, att_out, ctx, worker);
+      SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
+                                  query_norm_scale, layer_idx, activations, att,
+                                  att_out, ctx, worker);
     };
 
   {

@@ -245,7 +246,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
 // Fills activations.q and writes to KV cache.
 static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
                                   const LayerWeightsPtrs& layer,
-                                  AttentionActivations& activations,
+                                  AttentionActivationsPtrs& activations,
                                   const QBatch& qbatch, const int flags,
                                   MatMulEnv& env) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(),

@@ -312,8 +313,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
         });
       }
 
-      PositionalEncodingQK(kv_f32, layer_idx, layer, activations, env.ctx,
-                           worker, pos, /*mul=*/1.0f);
+      PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
+                           pos, /*mul=*/1.0f);
       CompressPerThread tls;
       Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
     });

@@ -322,7 +323,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
 // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
 // head_dim (`qkv_dim`) into output (`layer_out`).
 static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
-                                AttentionActivations& activations,
+                                AttentionActivationsPtrs& activations,
                                 MatMulEnv& env) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
   const LayerConfig& layer_config = layer.layer_config;

@@ -340,7 +341,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
 
 void GemmaAttention(size_t num_tokens, const size_t layer_idx,
                     const LayerWeightsPtrs& layer,
-                    AttentionActivations& activations, QBatch& qbatch,
+                    AttentionActivationsPtrs& activations, QBatch& qbatch,
                     MatMulEnv& env, int flags) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention);
 

@@ -352,13 +353,14 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
 
   ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
   if (flags & kAttentionUseOld) {
-    DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
-                          env.ctx);
+    DotSoftmaxWeightedSum(num_tokens, layer_idx, layer.query_norm_scale,
+                          activations, qbatch, env.ctx);
  } else {
     // * 2 does not help on Turin.
     FlashAttention(num_tokens,
                    /*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
-                   layer_idx, layer, activations, qbatch, env.ctx);
+                   layer_idx, layer.query_norm_scale, activations, qbatch,
+                   env.ctx);
   }
   SumHeads(layer, activations, env);
 }
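At call sites, only the query-norm scale tensor is now needed rather than a whole LayerWeightsPtrs. A hedged sketch of the two call shapes (num_tokens, layer_idx, activations, qbatch and ctx are assumed to be in scope; empty_query_norm_scale is a hypothetical name):

// With real weights: forward the layer's tensor, as GemmaAttention now does.
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer.query_norm_scale,
                      activations, qbatch, ctx);

// Without weights (e.g. tests or experimental code): assuming a
// default-constructed MatPtrT reports HasPtr() == false, the RMSNormInplace
// step in SingleDotSoftmaxWeightedSum is simply skipped.
MatPtrT<float> empty_query_norm_scale;
DotSoftmaxWeightedSum(num_tokens, layer_idx, empty_query_norm_scale,
                      activations, qbatch, ctx);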

gemma/attention.h

Lines changed: 6 additions & 7 deletions

@@ -29,8 +29,7 @@ namespace gcpp {
 #define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
   namespace NAMESPACE { \
   void PositionalEncodingQK(float* qk, size_t layer_idx, \
-                            const LayerWeightsPtrs& layer, \
-                            const AttentionActivations& activations, \
+                            const AttentionActivationsPtrs& activations, \
                             ThreadingContext& ctx, size_t worker, size_t pos, \
                             float mul); \
  \

@@ -39,18 +38,18 @@ namespace gcpp {
   void SingleDotSoftmaxWeightedSum( \
       const size_t pos, const size_t start_pos, const size_t last_pos, \
       float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
-      size_t layer_idx, const LayerWeightsPtrs& layer, \
-      const AttentionActivations& activations, float* HWY_RESTRICT att, \
+      const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
+      const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, \
       float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \
  \
   void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
-                             const LayerWeightsPtrs& layer, \
-                             AttentionActivations& activations, \
+                             const MatPtrT<float>& query_norm_scale, \
+                             AttentionActivationsPtrs& activations, \
                              QBatch& qbatch, ThreadingContext& ctx); \
  \
   void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
                       const LayerWeightsPtrs& layer, \
-                      AttentionActivations& activations, QBatch& qbatch, \
+                      AttentionActivationsPtrs& activations, QBatch& qbatch, \
                       MatMulEnv& env, int flags); \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
   }  // namespace NAMESPACE
