Commit 05e8128 (1 parent: d541bd5)

Adding Compute-Context-Length (CCL)

Signed-off-by: Vahid Janfaza <vjanfaza@qti.qualcomm.com>

23 files changed: +892 / -210 lines
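Taken together, the diffs below thread two bucket lists through the runtime: comp_ctx_lengths_prefill and comp_ctx_lengths_decode. Each entry is a compiled context length, and the runtime signals the active bucket by feeding a zero buffer of that length as the "comp_ctx_lengths" model input. A minimal sketch of what those arguments look like; the bucket values are placeholders, since the commit leaves the concrete lists to the caller:

import numpy as np

# Placeholder bucket values; ascending order is what the promotion logic below assumes.
comp_ctx_lengths_prefill = [256, 512, 1024]   # prefill buckets
comp_ctx_lengths_decode = [512, 1024]         # decode buckets

# One zero buffer per bucket; its length is what tells the compiled model
# which context length to compute over (mirrors the diffs below).
list_of_buffers = [np.zeros(length) for length in comp_ctx_lengths_prefill]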

QEfficient/generation/text_generation_inference.py

Lines changed: 0 additions & 2 deletions

@@ -814,14 +814,12 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
             self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
             prefill_ccl_id = 0
             inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
-            print(f"CCL Prefill: {self.comp_ctx_lengths_prefill[prefill_ccl_id]}")

         for i in range(num_chunks):
             if self.comp_ctx_lengths_prefill is not None:
                 if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]:
                     prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1)
                     inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
-                    print(f"CCL Prefill: {self.comp_ctx_lengths_prefill[prefill_ccl_id]}")

             chunk_inputs = inputs.copy()
             chunk_inputs["input_ids"] = inputs["input_ids"][

QEfficient/generation/vlm_generation.py

Lines changed: 16 additions & 0 deletions

@@ -83,6 +83,8 @@ def __init__(
         vision_qpc_path: str,
         device_id: Optional[List[int]] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths_prefill: Optional[List[int]] = None,
+        comp_ctx_lengths_decode: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
         full_batch_size: Optional[int] = None,
@@ -123,6 +125,8 @@ def __init__(
             qpc_path=lang_qpc_path,
             full_batch_size=full_batch_size,
             ctx_len=ctx_len,
+            comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+            comp_ctx_lengths_decode=comp_ctx_lengths_decode,
             device_id=device_id,
             enable_debug_logs=enable_debug_logs,
             write_io_dir=write_io_dir,
@@ -294,6 +298,11 @@ def _execute_chunked_prefill(
         outputs = None
         chunk_image_idx = None

+        if self.comp_ctx_lengths_prefill is not None:
+            self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
+            prefill_ccl_id = 0
+            lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
         for i in range(num_chunks):
             input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len]
             position_ids_slice = lang_inputs["position_ids"][
@@ -312,6 +321,13 @@ def _execute_chunked_prefill(
             if "cross_attention_mask" in lang_inputs:
                 chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"]

+            if self.comp_ctx_lengths_prefill is not None:
+                if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]:
+                    prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1)
+                    lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
+                chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"]
+
             outputs = self._session.run(chunk_inputs)

             if "image_idx_output" in outputs:

QEfficient/transformers/cache_utils.py

Lines changed: 9 additions & 5 deletions

@@ -622,6 +622,7 @@ def update(
         is_sliding_layer = cache_kwargs.get("is_sliding")
         sliding_window = cache_kwargs.get("sliding_window")
         batch_index = cache_kwargs.get("batch_index", None)  # Check and fetch batch index value from the kwargs
+        comp_ctx_len = cache_kwargs.get("CCL")

         if is_sliding_layer:
             kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window)
@@ -649,7 +650,10 @@ def update(
         k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

         # Original Gather
-        ctx_len = self.key_cache[layer_idx].shape[2]
+        if is_sliding_layer:
+            ctx_len = k_out.shape[2]
+        else:
+            ctx_len = comp_ctx_len
         ctx_indices = torch.arange(ctx_len)[None, None, ...]
         gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
         invalid_mask = ctx_indices > gather_limit
@@ -660,11 +664,11 @@ def update(
         ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

         if batch_index is not None:
-            k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
-            v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
+            k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len)
+            v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len)
         else:
-            k_out = CtxGatherFunc.apply(k_out, ctx_indices)
-            v_out = CtxGatherFunc.apply(v_out, ctx_indices)
+            k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len)
+            v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len)

         v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
         return k_out, v_out
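The effect of the new ctx_len choice is that full-attention layers gather only the first comp_ctx_len cache rows instead of the whole compiled context, while sliding-window layers keep their full (already small) window. A minimal sketch of the narrowed gather, using plain index_select in place of the repo's CtxGatherFunc custom op and assuming a [batch, heads, ctx_len, head_dim] cache layout:

import torch

k_cache = torch.randn(1, 8, 4096, 64)   # full compiled context of 4096 (placeholder)
comp_ctx_len = 1024                      # active CCL bucket (placeholder)

ctx_indices = torch.arange(comp_ctx_len)       # was arange over the full 4096 before this commit
k_out = k_cache.index_select(2, ctx_indices)   # -> [1, 8, 1024, 64]
# Downstream attention now scores 1024 key positions instead of 4096,
# which is where the CCL compute saving comes from.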

QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 10 additions & 0 deletions

@@ -417,6 +417,7 @@ def forward(
         attention_mask: Optional[torch.Tensor],
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         sliding_mask=None,
@@ -433,6 +434,8 @@ def forward(
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         if past_key_value is not None:
+            if comp_ctx_lengths is not None:
+                attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {
                 "sin": sin,
@@ -442,6 +445,7 @@ def forward(
                 "config": self.config,
                 "is_sliding": self.sliding_window is not None,
                 "sliding_window": past_key_value.sliding_window_len,
+                "CCL": attention_mask.shape[-1],
             }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -476,6 +480,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
@@ -492,6 +497,7 @@ def forward(
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
+            comp_ctx_lengths=comp_ctx_lengths,
             batch_index=batch_index,
             use_cache=use_cache,
             cache_position=cache_position,
@@ -526,6 +532,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
@@ -586,6 +593,7 @@ def forward(
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
+                comp_ctx_lengths=comp_ctx_lengths,
                 batch_index=batch_index,
                 use_cache=use_cache,
                 output_attentions=output_attentions,
@@ -619,6 +627,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -670,6 +679,7 @@ def forward(
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            comp_ctx_lengths=comp_ctx_lengths,
             batch_index=batch_index,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
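In the model itself, comp_ctx_lengths is threaded through every forward signature and call site, and its only substantive uses are in the attention block: slice the mask down to the active bucket, then pass the resulting width to the cache as "CCL". A minimal sketch of those two steps, assuming a boolean mask of shape [batch, 1, q_len, ctx_len] (shape and dtype are assumptions, not fixed by the diff):

import torch

attention_mask = torch.zeros(1, 1, 1, 4096, dtype=torch.bool)  # placeholder mask
comp_ctx_lengths = torch.zeros(1024)  # zero buffer whose length encodes the bucket

# 1) Narrow the mask to the active bucket, exactly as in the hunk above.
attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]

# 2) Hand the resulting width to the cache; update() reads it as cache_kwargs["CCL"].
cache_kwargs = {"CCL": attention_mask.shape[-1]}  # -> 1024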
