1 change: 0 additions & 1 deletion BUILD.bazel
@@ -139,7 +139,6 @@ cc_test(
":kv_cache",
":mat",
":matmul",
":ops",
":threading_context",
":weights",
"@googletest//:gtest_main", # buildcleaner: keep
82 changes: 53 additions & 29 deletions gemma/activations.h
@@ -31,26 +31,24 @@

namespace gcpp {

struct AttentionActivations {
// Returns the scale value to use for the query in the attention computation.
// Also called by ops_test.
static inline float ChooseQueryScale(const ModelConfig& config) {
const LayerConfig& layer_config = config.layer_configs[0];
if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
return 1.0f /
sqrtf(static_cast<float>(config.model_dim / layer_config.heads));
// QueryScaleType::SqrtKeySize
return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
}
// Returns the scale value to use for the query in the attention computation.
// Also called by ops_test.
static inline float ChooseQueryScale(const ModelConfig& config) {
const LayerConfig& layer_config = config.layer_configs[0];
if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
return 1.0f /
sqrtf(static_cast<float>(config.model_dim / layer_config.heads));
// QueryScaleType::SqrtKeySize
return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
}

struct AttentionActivations {
AttentionActivations(
const ModelConfig& config, const LayerConfig& layer_config,
size_t batch_size, size_t seq_len, const Allocator& allocator,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: config(config),

// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
// and does not use an external KV cache.
: // `vocab_size == 0` means it is for Vit part, VitAttention is still
// MHA and does not use an external KV cache.
q(MatFactory("q", batch_size,
config.vocab_size == 0
? layer_config.heads * 3 * layer_config.qkv_dim
@@ -76,11 +74,7 @@ struct AttentionActivations {
layer_config.post_qk == PostQKType::HalfRope)),
inv_timescale_global(CreateInvTimescale(
allocator, layer_config.qkv_dim,
layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),

div_seq_len(static_cast<uint32_t>(seq_len)),
div_heads(static_cast<uint32_t>(layer_config.heads)),
query_scale(ChooseQueryScale(config)) {
layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) {
// Batch size can be 0 in experimental code so do not assert.
if (batch_size == 0) {
static std::atomic_flag warned = ATOMIC_FLAG_INIT;
@@ -108,9 +102,7 @@ struct AttentionActivations {
att_sums.OverrideRows(batch_size);
}

const ModelConfig& config;

MatStorageT<float> q; // query
MatStorageT<float> q; // query
MatStorageT<float> q_T; // Transposed to maximize attention speed.

MatStorageT<float> pre_att_rms_out;
@@ -122,9 +114,39 @@ struct AttentionActivations {
// Rope
MatStorageT<float> inv_timescale;
MatStorageT<float> inv_timescale_global;
};

// A non-owning view of AttentionActivations.
struct AttentionActivationsPtrs {
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len)
: config(config),
div_seq_len(static_cast<uint32_t>(seq_len)),
div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
query_scale(ChooseQueryScale(config)) {}

AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
const AttentionActivations& activations)
: AttentionActivationsPtrs(config, seq_len) {
q = activations.q;
q_T = activations.q_T;
pre_att_rms_out = activations.pre_att_rms_out;
att = activations.att;
att_out = activations.att_out;
att_sums = activations.att_sums;
inv_timescale = activations.inv_timescale;
inv_timescale_global = activations.inv_timescale_global;
}

const ModelConfig& config;
MatPtrT<float> q;
MatPtrT<float> q_T;
MatPtrT<float> pre_att_rms_out;
MatPtrT<float> att;
MatPtrT<float> att_out;
MatPtrT<BF16> att_sums;
MatPtrT<float> inv_timescale;
MatPtrT<float> inv_timescale_global;
hwy::Divisor div_seq_len;
// Unfortunately, some models have had non-power-of-two heads.
hwy::Divisor div_heads;
float query_scale;
};
@@ -150,8 +172,9 @@ struct Activations {
ffw_out(
MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),

attention(config, layer_config, batch_size, seq_len, ctx.allocator,
row_ptrs) {
attention_storage(config, layer_config, batch_size, seq_len,
ctx.allocator, row_ptrs),
attention(config, seq_len, attention_storage) {
HWY_ASSERT(batch_size != 0);

// For MatMul outputs, precompute their row pointers.
@@ -179,12 +202,12 @@ struct Activations {
C2.OverrideRows(batch_size);
ffw_out.OverrideRows(batch_size);

attention.SetBatchSize(batch_size);
attention_storage.SetBatchSize(batch_size);
}

const LayerConfig& layer_config;

MatStorageT<float> x; // input
MatStorageT<float> x; // input
MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul
MatStorageT<float> logits; // TODO: BF16 after Softmax supports that.
MatStorageT<uint32_t> sampled; // batch_size x 3 (padded)
@@ -195,7 +218,8 @@ struct Activations {
MatStorageT<BF16> C2;
MatStorageT<float> ffw_out;

AttentionActivations attention;
AttentionActivations attention_storage;
AttentionActivationsPtrs attention;
};

} // namespace gcpp
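Note: the core of the activations.h change above is a split between owning storage (`AttentionActivations`, whose members are `MatStorageT`) and a cheap non-owning view (`AttentionActivationsPtrs`, whose members are `MatPtrT`). The following is a minimal standalone sketch of that owning-storage/view pattern, not the repo's actual types: `AttentionBuffers` and `AttentionView` are hypothetical names, and plain `std::vector` stands in for the Mat classes.

#include <cstdio>
#include <vector>

// Owning storage: allocates and holds the buffers (stand-in for MatStorageT).
struct AttentionBuffers {
  AttentionBuffers(size_t batch_size, size_t dim)
      : q(batch_size * dim), att_out(batch_size * dim) {}
  std::vector<float> q;
  std::vector<float> att_out;
};

// Non-owning view: cheap to copy and pass to kernels (stand-in for MatPtrT).
struct AttentionView {
  explicit AttentionView(AttentionBuffers& b)
      : q(b.q.data()), att_out(b.att_out.data()) {}
  float* q;
  float* att_out;
};

int main() {
  AttentionBuffers storage(/*batch_size=*/2, /*dim=*/4);
  AttentionView view(storage);        // kernels take the view, not the storage
  view.q[0] = 1.0f;
  std::printf("%f\n", storage.q[0]);  // prints 1.000000 -- same memory
  return 0;
}

One payoff of this split, visible in the diff, is that attention functions can take `AttentionActivationsPtrs` by reference without depending on how or where the buffers were allocated.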
52 changes: 27 additions & 25 deletions gemma/attention.cc
@@ -73,12 +73,12 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
}

void PositionalEncodingQK(float* qk, const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
const AttentionActivationsPtrs& activations,
ThreadingContext& ctx, const size_t worker,
const size_t pos, const float mul) {
const size_t qkv_dim = layer.layer_config.qkv_dim;
const PostQKType& post_qk = layer.layer_config.post_qk;
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
const size_t qkv_dim = layer_config.qkv_dim;
const PostQKType& post_qk = layer_config.post_qk;
// qk is either q or k, so qkv_dim is the length we operate on.
const float* inv_timescale = activations.inv_timescale.PackedScale1();
const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);
@@ -130,23 +130,23 @@ static HWY_INLINE void WeightedSumV(
void SingleDotSoftmaxWeightedSum(
const size_t pos, const size_t start_pos, const size_t last_pos,
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
const size_t layer_idx, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, float* HWY_RESTRICT att,
const MatPtrT<float>& query_norm_scale, const size_t layer_idx,
const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
const float att_cap = activations.config.att_cap;
const float query_scale = activations.query_scale;
const size_t seq_len =
static_cast<size_t>(activations.div_seq_len.GetDivisor());

const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
// Apply rope and scaling to Q.
if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
if (query_norm_scale.HasPtr()) {
CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q,
layer.layer_config.qkv_dim, ctx, worker);
layer_config.qkv_dim, ctx, worker);
});
}

PositionalEncodingQK(q, layer_idx, layer, activations, ctx, worker, pos,
PositionalEncodingQK(q, layer_idx, activations, ctx, worker, pos,
query_scale);

QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
@@ -169,13 +169,13 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
}

void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch,
ThreadingContext& ctx) {
const MatPtrT<float>& query_norm_scale,
AttentionActivationsPtrs& activations,
QBatch& qbatch, ThreadingContext& ctx) {
GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);

const hwy::Divisor div_qbatch(qbatch.Size());
const LayerConfig& layer_config = layer.layer_config;
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
const size_t qkv_dim = layer_config.qkv_dim;

// A "head group" in the context of GQA refers to a collection of query
@@ -223,8 +223,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());

SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
layer, activations, att, att_out, ctx, worker);
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
query_norm_scale, layer_idx, activations, att,
att_out, ctx, worker);
};

{
@@ -245,7 +246,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
// Fills activations.q and writes to KV cache.
static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
AttentionActivations& activations,
AttentionActivationsPtrs& activations,
const QBatch& qbatch, const int flags,
MatMulEnv& env) {
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(),
Expand Down Expand Up @@ -312,8 +313,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
});
}

PositionalEncodingQK(kv_f32, layer_idx, layer, activations, env.ctx,
worker, pos, /*mul=*/1.0f);
PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
pos, /*mul=*/1.0f);
CompressPerThread tls;
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
});
@@ -322,7 +323,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
// head_dim (`qkv_dim`) into output (`layer_out`).
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
AttentionActivations& activations,
AttentionActivationsPtrs& activations,
MatMulEnv& env) {
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
const LayerConfig& layer_config = layer.layer_config;
@@ -340,7 +341,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,

void GemmaAttention(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch,
AttentionActivationsPtrs& activations, QBatch& qbatch,
MatMulEnv& env, int flags) {
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention);

@@ -352,13 +353,14 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,

ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
if (flags & kAttentionUseOld) {
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
env.ctx);
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer.query_norm_scale,
activations, qbatch, env.ctx);
} else {
// * 2 does not help on Turin.
FlashAttention(num_tokens,
/*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
layer_idx, layer, activations, qbatch, env.ctx);
layer_idx, layer.query_norm_scale, activations, qbatch,
env.ctx);
}
SumHeads(layer, activations, env);
}
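Note: to make the control flow of `SingleDotSoftmaxWeightedSum` easier to follow, here is a scalar, single-query illustration of the same dot-softmax-weighted-sum step. This is not the repo's code: the function name is hypothetical, and the Highway vectorization, KV-cache layout, RoPE, GQA head grouping, and attention soft cap are all omitted; only the standard scaled dot-product attention math remains.

#include <cmath>
#include <cstdio>
#include <vector>

// Scalar sketch: one query attends over `num_kv` key/value rows of size `dim`.
void DotSoftmaxWeightedSumScalar(const float* q, const float* k, const float* v,
                                 size_t num_kv, size_t dim, float query_scale,
                                 float* att, float* out) {
  // Scaled dot products q.k for each cached position.
  for (size_t t = 0; t < num_kv; ++t) {
    float dot = 0.0f;
    for (size_t d = 0; d < dim; ++d) dot += q[d] * k[t * dim + d];
    att[t] = dot * query_scale;
  }
  // Softmax over the attention logits (max-subtracted for stability).
  float max_logit = att[0];
  for (size_t t = 1; t < num_kv; ++t) max_logit = std::fmax(max_logit, att[t]);
  float sum = 0.0f;
  for (size_t t = 0; t < num_kv; ++t) {
    att[t] = std::exp(att[t] - max_logit);
    sum += att[t];
  }
  for (size_t t = 0; t < num_kv; ++t) att[t] /= sum;
  // Weighted sum of V rows into the output.
  for (size_t d = 0; d < dim; ++d) {
    out[d] = 0.0f;
    for (size_t t = 0; t < num_kv; ++t) out[d] += att[t] * v[t * dim + d];
  }
}

int main() {
  const size_t num_kv = 2, dim = 2;
  std::vector<float> q{1.0f, 0.0f}, k{1.0f, 0.0f, 0.0f, 1.0f};
  std::vector<float> v{1.0f, 2.0f, 3.0f, 4.0f}, att(num_kv), out(dim);
  DotSoftmaxWeightedSumScalar(q.data(), k.data(), v.data(), num_kv, dim,
                              /*query_scale=*/1.0f, att.data(), out.data());
  std::printf("%f %f\n", out[0], out[1]);  // attention weights are ~0.73/0.27
  return 0;
}

In the diff above, this step now receives only `query_norm_scale` rather than the whole `LayerWeightsPtrs`, and reads the per-layer shape information from `activations.config.layer_configs[layer_idx]`.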
13 changes: 6 additions & 7 deletions gemma/attention.h
@@ -29,8 +29,7 @@ namespace gcpp {
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void PositionalEncodingQK(float* qk, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, \
const AttentionActivationsPtrs& activations, \
ThreadingContext& ctx, size_t worker, size_t pos, \
float mul); \
\
@@ -39,18 +38,18 @@ namespace gcpp {
void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, float* HWY_RESTRICT att, \
const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \
\
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
AttentionActivations& activations, \
const MatPtrT<float>& query_norm_scale, \
AttentionActivationsPtrs& activations, \
QBatch& qbatch, ThreadingContext& ctx); \
\
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
const LayerWeightsPtrs& layer, \
AttentionActivations& activations, QBatch& qbatch, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
MatMulEnv& env, int flags); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE
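Note: `GEMMA_DECL_ATTENTION(TARGET, NAMESPACE)` above declares the same attention entry points once per SIMD-target namespace. A minimal sketch of that declare-per-namespace macro idea, with hypothetical names (`DECL_KERNEL`, `scalar_target`, `fancy_simd_target`) and Highway's per-target compilation and dynamic dispatch left out:

#include <cstdio>

// Declare the same function signature inside each target namespace.
#define DECL_KERNEL(NAMESPACE) \
  namespace NAMESPACE {        \
  void Kernel();               \
  }

DECL_KERNEL(scalar_target)
DECL_KERNEL(fancy_simd_target)

// Each namespace provides its own definition (here trivially different).
namespace scalar_target {
void Kernel() { std::printf("scalar\n"); }
}  // namespace scalar_target
namespace fancy_simd_target {
void Kernel() { std::printf("simd\n"); }
}  // namespace fancy_simd_target

int main() {
  scalar_target::Kernel();      // a real dispatcher picks one target at runtime
  fancy_simd_target::Kernel();
  return 0;
}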