gemma/gemma.cc: 152 changes (69 additions, 83 deletions)
@@ -1057,96 +1057,82 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
   // In single-turn (non-chat) usage, pos and pos_offset start at 0 and are
   // always equal.
   size_t pos_offset = 0;  // offset relative to pos
 
-  auto prefill_phase = [&]() HWY_ATTR {
-    bool keep_on = true;
-    const double prefill_start = hwy::platform::Now();
-
-    // Prefill stops before prompt_size - 1 since the last prompt token is the
-    // first input token for generation.
-    while (pos_offset < prompt_size - 1 && keep_on) {
-      const size_t batch_size =
-          std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
-      HWY_DASSERT(batch_size <= kPrefillBatchSize);
-      HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
-      const int* batch_tokens = prompt.data() + pos_offset;
-      Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
-                                 prefill_activations, kv_cache, pool, inner_pool);
-      for (size_t idx = 0; idx < batch_size; ++idx) {
-        keep_on = stream_token(batch_tokens[idx], 0.0f);
-        if(!keep_on) {
-          break;
-        }
-      }
-      pos += batch_size;
-      pos_offset += batch_size;
-    }
+  const double prefill_start = hwy::platform::Now();
+
+  // Prefill stops before prompt_size - 1 since the last prompt token is the
+  // first input token for generation.
+  while (pos_offset < prompt_size - 1) {
+    const size_t batch_size =
+        std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
+    HWY_DASSERT(batch_size <= kPrefillBatchSize);
+    HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
+    const int* batch_tokens = prompt.data() + pos_offset;
+    Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
+                               prefill_activations, kv_cache, pool, inner_pool);
+    for (size_t idx = 0; idx < batch_size; ++idx) {
+      if (!stream_token(batch_tokens[idx], 0.0f)) return;
+    }
+    pos += batch_size;
+    pos_offset += batch_size;
+  }
 
-    if (verbosity >= 2) {
-      // in the future this output should not occur in GenerateImpl but instead
-      // should be available as observable state for frontend code to handle I/O.
-      const double prefill_end = hwy::platform::Now();
-      const double prefill_tok_sec =
-          static_cast<double>(pos_offset) / (prefill_end - prefill_start);
-      std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
-    }
-
-    return keep_on;
-  };
-
-  auto transform_phase = [&]() HWY_ATTR {
-
-    const double gen_start = hwy::platform::Now();
+  if (verbosity >= 2) {
+    // in the future this output should not occur in GenerateImpl but instead
+    // should be available as observable state for frontend code to handle I/O.
+    const double prefill_end = hwy::platform::Now();
+    const double prefill_tok_sec =
+        static_cast<double>(pos_offset) / (prefill_end - prefill_start);
+    std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
+  }
+
+  const double gen_start = hwy::platform::Now();
 
-    HWY_DASSERT(pos_offset == prompt_size - 1);
+  HWY_DASSERT(pos_offset == prompt_size - 1);
 
-    size_t pos_gen_start = pos_offset;
-    int token = prompt.at(pos_offset);
-    stream_token(token, 0);
-    for (size_t generate_pos = 0;
-         pos < max_tokens && generate_pos < max_generated_tokens;
-         ++pos, ++pos_offset, ++generate_pos) {
-      const bool is_generating_phase = pos_offset >= prompt_size - 1;
-      Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
-                  layers_output);
-      float* final_activation = activations.x.data();
-      // The condition below is always true if we are doing Prefill above.
-      // We keep it here for clarity so that the code is correct even if Prefill
-      // is disabled.
-      if (is_generating_phase) {
-        PROFILER_ZONE("Gen.Embedding");
-        // Generation phase
-        MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding,
-                                               0, final_activation,
-                                               activations.logits.data(), pool);
-        // Barrier: must have all logits so we can subtract max.
-        Softmax(activations.logits.data(), kVocabSize);
-        token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
-                                           gen, temperature, accept_token);
-        if (!stream_token(token, activations.logits[token])) {
-          token = EOS_ID;
-        }
-      } else {
-        // We would take this branch if we were not doing Prefill but would
-        // process the tokens of the prompt one at a time.
-        token = prompt.at(pos_offset + 1);
-        stream_token(token, 0);
-      }
-      if (token == EOS_ID) {
-        if (verbosity >= 2) {
-          const double gen_end = hwy::platform::Now();
-          const double gen_tok_sec =
-              static_cast<double>(pos_offset - pos_gen_start) /
-              (gen_end - gen_start);
-          std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
-        }
-        break;
-      }
-    }
-  };
-
-  if(prefill_phase()) {
-    transform_phase();
-  }
+  size_t pos_gen_start = pos_offset;
+  int token = prompt.at(pos_offset);
+  stream_token(token, 0);
+  for (size_t generate_pos = 0;
+       pos < max_tokens && generate_pos < max_generated_tokens;
+       ++pos, ++pos_offset, ++generate_pos) {
+    const bool is_generating_phase = pos_offset >= prompt_size - 1;
+    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
+                layers_output);
+    float* final_activation = activations.x.data();
+    // The condition below is always true if we are doing Prefill above.
+    // We keep it here for clarity so that the code is correct even if Prefill
+    // is disabled.
+    if (is_generating_phase) {
+      PROFILER_ZONE("Gen.Embedding");
+      // Generation phase
+      MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding,
+                                             0, final_activation,
+                                             activations.logits.data(), pool);
+      // Barrier: must have all logits so we can subtract max.
+      Softmax(activations.logits.data(), kVocabSize);
+      token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
+                                         gen, temperature, accept_token);
+      if (!stream_token(token, activations.logits[token])) {
+        token = EOS_ID;
+      }
+    } else {
+      // We would take this branch if we were not doing Prefill but would
+      // process the tokens of the prompt one at a time.
+      token = prompt.at(pos_offset + 1);
+      if (!stream_token(token, 0)) {
+        token = EOS_ID;
+      }
+    }
+    if (token == EOS_ID) {
+      if (verbosity >= 2) {
+        const double gen_end = hwy::platform::Now();
+        const double gen_tok_sec =
+            static_cast<double>(pos_offset - pos_gen_start) /
+            (gen_end - gen_start);
+        std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
+      }
+      break;
+    }
+  }
 }
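The behavioral core of this hunk is the contract around the stream_token callback: it returns a bool, and false means the caller wants output to stop. The old code threaded that flag through keep_on and an if(prefill_phase()) guard; the new code returns from prefill directly and steers the generation loop toward EOS_ID. A minimal sketch of that callback contract follows; the StreamToken alias, the Decode helper, and the token values are illustrative assumptions, not gemma.cpp types.

#include <cstdio>
#include <functional>
#include <vector>

// Hypothetical callback type: receives a token id and a probability,
// returns false to ask the generator to stop early.
using StreamToken = std::function<bool(int token, float prob)>;

// Sketch of a loop that honors the callback's return value. Tokens are
// faked here; a real loop would run the model to produce the next one.
void Decode(const std::vector<int>& fake_tokens, const StreamToken& stream_token) {
  for (int token : fake_tokens) {
    if (!stream_token(token, /*prob=*/0.0f)) {
      return;  // caller asked us to stop, mirroring the early return in prefill
    }
  }
}

int main() {
  int emitted = 0;
  // Stop after three tokens by returning false from the callback.
  Decode({10, 11, 12, 13, 14}, [&](int token, float) {
    std::printf("token %d\n", token);
    return ++emitted < 3;
  });
}

Returning false after three tokens ends the loop, which is the same early-exit behavior the diff wires into both the prefill and generation paths.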

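The prefill loop consumes the prompt in chunks of kPrefillBatchSize and deliberately stops before prompt_size - 1, because the final prompt token seeds the generation loop instead. A small worked example of that batch arithmetic, using made-up sizes rather than gemma.cpp's actual constants:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  // Assumed values for illustration only.
  constexpr std::size_t kPrefillBatchSize = 16;
  const std::size_t prompt_size = 37;

  // Stop before prompt_size - 1: the last prompt token seeds generation.
  std::size_t pos_offset = 0;
  while (pos_offset < prompt_size - 1) {
    const std::size_t batch_size =
        std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
    std::printf("prefill tokens [%zu, %zu)\n", pos_offset, pos_offset + batch_size);
    pos_offset += batch_size;
  }
  // With 37 prompt tokens this prints batches of 16, 16 and 4, covering
  // tokens 0..35 and leaving token 36 for the generation loop.
}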
Expand Down
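In the generation branch, the logits are normalized with Softmax and the next token is drawn by SampleTopK under a temperature and an accept_token filter. The sketch below shows one common way a top-k sampler over an already normalized probability vector can be written; it is a generic illustration with assumed names (SampleFromTopK, the toy distribution), not gemma.cpp's implementation.

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <random>
#include <vector>

// Generic top-k sampler (illustrative only): keep the k most probable
// entries, then draw one of them proportionally to their probabilities.
int SampleFromTopK(const std::vector<float>& probs, int k, std::mt19937& gen) {
  std::vector<int> idx(probs.size());
  std::iota(idx.begin(), idx.end(), 0);
  // Partially sort so the k largest probabilities come first.
  std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                    [&](int a, int b) { return probs[a] > probs[b]; });
  std::vector<float> top(k);
  for (int i = 0; i < k; ++i) top[i] = probs[idx[i]];
  std::discrete_distribution<int> pick(top.begin(), top.end());
  return idx[pick(gen)];
}

int main() {
  std::mt19937 gen(42);
  // A toy six-entry "vocabulary" distribution.
  const std::vector<float> probs = {0.02f, 0.40f, 0.10f, 0.30f, 0.03f, 0.15f};
  for (int i = 0; i < 5; ++i) {
    std::printf("sampled token %d\n", SampleFromTopK(probs, /*k=*/3, gen));
  }
}

Restricting the draw to the k most probable entries keeps some diversity in the output while cutting off the long low-probability tail.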