diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py index 814122b9d..fbff5b18b 100644 --- a/QEfficient/cloud/infer.py +++ b/QEfficient/cloud/infer.py @@ -340,6 +340,18 @@ def main( "--prompt-len", "--prompt_len", default=32, type=int, help="Sequence length for text generation." ) parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int, help="Context length for text generation.") + parser.add_argument( + "--comp-ctx-lengths-prefill", + type=lambda comp_ctx_lengths_prefill: [int(x) for x in comp_ctx_lengths_prefill.split(",")], + default=[512], + help="Define ccl list in csv format (e.g.,--comp-ctx-lengths 512,1024,2048).", + ) + parser.add_argument( + "--comp-ctx-lengths-decode", + type=lambda comp_ctx_lengths_decode: [int(x) for x in comp_ctx_lengths_decode.split(",")], + default=[2048], + help="Define ccl list in csv format (e.g.,--comp-ctx-lengths 512,1024,2048).", + ) parser.add_argument( "--mxfp6", "--mxfp6_matmul", diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index c4f5a7bbd..269ccb0be 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) -def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: - ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0])) +def CtxGather( + data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32 +) -> onnxscript.FLOAT: + # Create a shape tensor based on comp_ctx_len + shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0) + + # Directly use the shape tensor without validation + ctx_indices = ops.Expand(ctx_indices, shape_tensor) ctx_indices = ops.Unsqueeze(ctx_indices, [-1]) return ops.GatherND(data, ctx_indices, batch_dims=2) @@ -127,7 +133,7 @@ class CtxGatherFunc(torch.autograd.Function): """ @staticmethod - def forward(data: torch.Tensor, ctx_indices: torch.Tensor): + def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int): batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) head_indices = torch.arange(data.shape[1]).view(1, -1, 1) return data[batch_indices, head_indices, ctx_indices] @@ -137,5 +143,5 @@ def setup_context(ctx, inputs, outputs): pass @staticmethod - def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: - return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data) + def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value: + return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data) diff --git a/QEfficient/customop/ctx_scatter_gather_cb.py b/QEfficient/customop/ctx_scatter_gather_cb.py index 75d9a12ef..cc9693716 100644 --- a/QEfficient/customop/ctx_scatter_gather_cb.py +++ b/QEfficient/customop/ctx_scatter_gather_cb.py @@ -97,16 +97,20 @@ def symbolic( @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGatherCB( - data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32 + data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32 ) -> onnxscript.FLOAT: batch_size = ops.Gather(ops.Shape(batch_index), [0]) num_heads = ops.Gather(ops.Shape(data), [1]) - ctx_len = ops.Gather(ops.Shape(data), [2]) + # using compute-context-length (CCL) instead of context-length to do gather process based on CCL and later do attention computations based on CCL as well. + ctx_len = ops.Reshape(comp_ctx_len, [1]) # Expanded shape to create indices zero = ops.Constant(value_ints=[0]) one = ops.Constant(value_ints=[1]) - exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0) + # exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0) + exp_shape = ops.Concat( + ops.Reshape(batch_size, [1]), ops.Reshape(num_heads, [1]), ops.Reshape(ctx_len, [1]), one, axis=0 + ) # Create indices batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape) @@ -119,7 +123,7 @@ def CtxGatherCB( class CtxGatherFuncCB(torch.autograd.Function): @staticmethod - def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor): + def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int): batch_indices = batch_index.view(-1, 1, 1) head_indices = torch.arange(data.shape[1]).view(1, -1, 1) return data[batch_indices, head_indices, ctx_indices] @@ -129,8 +133,10 @@ def setup_context(ctx, inputs, outputs): pass @staticmethod - def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value: - return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data) + def symbolic( + g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int + ) -> torch.Value: + return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data) @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index e96908824..7da2300d6 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -318,6 +318,8 @@ def cloud_ai_100_exec_kv( prompts_txt_file_path: Optional[str] = None, device_id: Optional[List[int]] = None, generation_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, enable_debug_logs: bool = False, stream: bool = True, write_io_dir: Optional[str] = None, @@ -384,6 +386,8 @@ def cloud_ai_100_exec_kv( qpc_path=qpc_path, device_id=device_id, ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, enable_debug_logs=enable_debug_logs, write_io_dir=write_io_dir, full_batch_size=full_batch_size, @@ -430,6 +434,8 @@ def __init__( qpc_path: str, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, device_id: Optional[List[int]] = None, enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, @@ -440,6 +446,8 @@ def __init__( activate: bool = True, ) -> None: self._ctx_len = ctx_len + self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill + self.comp_ctx_lengths_decode = comp_ctx_lengths_decode self._write_io_dir = write_io_dir self.is_tlm = is_tlm self.return_pdfs = return_pdfs @@ -802,7 +810,17 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)] inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) + if self.comp_ctx_lengths_prefill is not None: + self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill] + prefill_ccl_id = 0 + inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + for i in range(num_chunks): + if self.comp_ctx_lengths_prefill is not None: + if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]: + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + chunk_inputs = inputs.copy() chunk_inputs["input_ids"] = inputs["input_ids"][ :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len @@ -822,6 +840,19 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i generation_len, ) + def initialize_ccl(self, decode_inputs): + self.list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode] + max_ccl_id = len(self.comp_ctx_lengths_decode) - 1 + max_position_id = np.max(decode_inputs["position_ids"]) + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + + return ccl_id, max_ccl_id + def run_continuous_batching_decode(self, prompt_queue, generation_len): """ Runs continuous batching decode for the given prompt queue and generation length. @@ -853,6 +884,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): # Prepare decode inputs inputs. decode_inputs = self.prepare_decode_inputs() + if self.comp_ctx_lengths_decode is not None: + ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs) + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + while prompt_queue or current_decode_ongoing.any(): outputs = self._session.run(decode_inputs) @@ -890,6 +925,20 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): batch_id_map[decode_batch_id] ] + if self.comp_ctx_lengths_decode is not None: + ###Recalculate ccl_id based on position ids### + # Determine the maximum value of position_ids across all batch elements + max_position_id = np.max(decode_inputs["position_ids"]) + + # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + else: current_decode_ongoing[decode_batch_id] = False else: @@ -902,6 +951,15 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): if self.include_sampler: decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] + if self.comp_ctx_lengths_decode is not None: + # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id + if ( + decode_inputs["position_ids"][decode_batch_id, -1] + >= self.comp_ctx_lengths_decode[ccl_id] - 1 + ): + ccl_id = min(ccl_id + 1, max_ccl_id) + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + generated_id_current_index[decode_batch_id] += 1 return decode_pause_time @@ -928,7 +986,18 @@ def run_decode( self._session.set_buffers({"logits": logits_out_placeholder}) finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id num_token = 0 + + if self.comp_ctx_lengths_decode is not None: + ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs) + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + + cache_index = np.max(decode_inputs["position_ids"]) for num_token in range(1, generation_len): + if self.comp_ctx_lengths_decode is not None: + if cache_index >= self.comp_ctx_lengths_decode[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + if streamer: streamer.put(decode_inputs["input_ids"][0]) outputs = self._session.run(decode_inputs) @@ -940,6 +1009,7 @@ def run_decode( # Prepare inputs for next iteration decode_inputs["input_ids"] = self._fetch_next_token_id(outputs) decode_inputs["position_ids"][:, -1] += 1 + cache_index += 1 self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1] finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id if self.include_sampler: @@ -989,6 +1059,8 @@ def __init__( qpc_path: str, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, device_id: Optional[List[int]] = None, enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, @@ -1002,6 +1074,8 @@ def __init__( qpc_path=qpc_path, full_batch_size=full_batch_size, ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, device_id=device_id, enable_debug_logs=enable_debug_logs, write_io_dir=write_io_dir, @@ -1013,6 +1087,8 @@ def __init__( self._full_batch_size = self._qaic_model.full_batch_size self._tokenizer = self._qaic_model.tokenizer self._ctx_len = ctx_len + self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill + self.comp_ctx_lengths_decode = comp_ctx_lengths_decode self._perf_metrics = None self._prompt_queue = None self._text_streamer = None diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py index 2e8f04f2b..5eb91d142 100644 --- a/QEfficient/generation/vlm_generation.py +++ b/QEfficient/generation/vlm_generation.py @@ -83,6 +83,8 @@ def __init__( vision_qpc_path: str, device_id: Optional[List[int]] = None, ctx_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, full_batch_size: Optional[int] = None, @@ -123,6 +125,8 @@ def __init__( qpc_path=lang_qpc_path, full_batch_size=full_batch_size, ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, device_id=device_id, enable_debug_logs=enable_debug_logs, write_io_dir=write_io_dir, @@ -294,6 +298,11 @@ def _execute_chunked_prefill( outputs = None chunk_image_idx = None + if self.comp_ctx_lengths_prefill is not None: + self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill] + prefill_ccl_id = 0 + lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + for i in range(num_chunks): input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len] position_ids_slice = lang_inputs["position_ids"][ @@ -312,6 +321,13 @@ def _execute_chunked_prefill( if "cross_attention_mask" in lang_inputs: chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"] + if self.comp_ctx_lengths_prefill is not None: + if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]: + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + + chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"] + outputs = self._session.run(chunk_inputs) if "image_idx_output" in outputs: diff --git a/QEfficient/peft/lora/layers.py b/QEfficient/peft/lora/layers.py index 6b75e696f..79abeba77 100644 --- a/QEfficient/peft/lora/layers.py +++ b/QEfficient/peft/lora/layers.py @@ -42,15 +42,15 @@ def forward(self, x: torch.Tensor, lora_ids: torch.Tensor): # multilora implementation: lora_ids other_indices_a = torch.arange(self.lora_a_weights.shape[2]).view(1, 1, -1) selected_lora_a_weights = CtxGatherFuncCB.apply( - self.lora_a_weights, lora_ids, other_indices_a + self.lora_a_weights, lora_ids, other_indices_a, self.lora_a_weights.shape[2] ) # other_indices_b = torch.arange(self.lora_b_weights.shape[2]).view(1, 1, -1) selected_lora_b_weights = CtxGatherFuncCB.apply( - self.lora_b_weights, lora_ids, other_indices_b + self.lora_b_weights, lora_ids, other_indices_b, self.lora_b_weights.shape[2] ) # other_indices_s = torch.arange(self.lora_scalings.shape[2]).view(1, 1, -1) selected_lora_scalings = CtxGatherFuncCB.apply( - self.lora_scalings, lora_ids, other_indices_s + self.lora_scalings, lora_ids, other_indices_s, self.lora_scalings.shape[2] ) # selected_lora_a_weights = selected_lora_a_weights.squeeze(1) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 853567be9..5452589f6 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -40,7 +40,8 @@ def read_only(self, cache_kwargs): k_out, v_out = self.keys, self.values position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) - ctx_len = k_out.shape[2] + ctx_len = cache_kwargs.get("CCL", k_out.shape[2]) + ctx_indices = torch.arange(ctx_len)[None, None, ...] gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit @@ -53,11 +54,11 @@ def read_only(self, cache_kwargs): ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out @@ -137,7 +138,7 @@ def update( k_out, v_out = self.keys, self.values # Gather - ctx_len = k_out.shape[2] + ctx_len = cache_kwargs.get("CCL", k_out.shape[2]) ctx_indices = torch.arange(ctx_len)[None, None, ...] gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit @@ -149,11 +150,11 @@ def update( ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out @@ -414,7 +415,7 @@ def update( k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] # Original Gather - ctx_len = self.key_cache[layer_idx].shape[2] + ctx_len = cache_kwargs.get("CCL", self.key_cache[layer_idx].shape[2]) ctx_indices = torch.arange(ctx_len)[None, None, ...] gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit @@ -426,11 +427,12 @@ def update( all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1 rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices) + rolling_indices = rolling_indices[:ctx_len] final_indices = torch.where( (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices ) - k_out = CtxGatherFunc.apply(k_out, final_indices) - v_out = CtxGatherFunc.apply(v_out, final_indices) + k_out = CtxGatherFunc.apply(k_out, final_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, final_indices, ctx_len) ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out @@ -516,7 +518,8 @@ def update( k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] # Original Gather - ctx_len = min(layer_ctx_len, k_out.shape[2]) + ctx_len = cache_kwargs.get("CCL", k_out.shape[2]) + ctx_len = min(layer_ctx_len, ctx_len) ctx_indices = torch.arange(ctx_len)[None, None, ...] gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit @@ -529,11 +532,12 @@ def update( # Rolling indices for sliding window all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1 rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices) + rolling_indices = rolling_indices[:ctx_len] final_indices = torch.where( (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices ) - k_out = CtxGatherFunc.apply(k_out, final_indices) - v_out = CtxGatherFunc.apply(v_out, final_indices) + k_out = CtxGatherFunc.apply(k_out, final_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, final_indices, ctx_len) ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out @@ -637,7 +641,11 @@ def update( k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] # Original Gather - ctx_len = self.key_cache[layer_idx].shape[2] + if is_sliding_layer: + ctx_len = self.key_cache[layer_idx].shape[2] + else: + ctx_len = cache_kwargs.get("CCL", self.key_cache[layer_idx].shape[2]) + ctx_indices = torch.arange(ctx_len)[None, None, ...] gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit @@ -648,11 +656,11 @@ def update( ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out diff --git a/QEfficient/transformers/models/codegen/modeling_codegen.py b/QEfficient/transformers/models/codegen/modeling_codegen.py index 776bfce43..15efa2ce5 100644 --- a/QEfficient/transformers/models/codegen/modeling_codegen.py +++ b/QEfficient/transformers/models/codegen/modeling_codegen.py @@ -72,6 +72,7 @@ def forward( self, hidden_states: Optional[torch.FloatTensor], layer_past: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -123,7 +124,9 @@ def forward( query = query.permute(0, 2, 1, 3) if layer_past is not None: - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs) # compute self-attention: V x Softmax(QK^T) @@ -147,6 +150,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -245,6 +249,7 @@ def forward( outputs = block( hidden_states, layer_past=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, attention_mask=attention_mask, position_ids=position_ids, @@ -294,6 +299,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -312,6 +318,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, token_type_ids=token_type_ids, batch_index=batch_index, @@ -348,6 +355,7 @@ def forward( self, hidden_states: Optional[torch.FloatTensor], layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -361,6 +369,7 @@ def forward( attn_outputs, attn_weights = self.attn( hidden_states=hidden_states, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index 8f2c3730d..218852b15 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -117,6 +117,7 @@ def forward( attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, layer_past: Optional[Cache] = None, head_mask: Optional[torch.Tensor] = None, @@ -140,7 +141,9 @@ def forward( query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) if layer_past is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) if attention_mask is not None: @@ -172,6 +175,7 @@ def forward( attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, layer_past: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None, head_mask: Optional[torch.Tensor] = None, @@ -195,6 +199,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, alibi=alibi, head_mask=head_mask, @@ -245,6 +250,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -307,6 +313,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, head_mask=head_mask[i], use_cache=use_cache, @@ -352,6 +359,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, @@ -368,6 +376,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, head_mask=head_mask, inputs_embeds=inputs_embeds, diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index eea1e3898..4c64109d8 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -137,6 +137,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -153,7 +154,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -186,6 +189,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -214,6 +218,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -243,6 +248,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -299,6 +305,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -334,6 +341,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -350,6 +358,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index be3ba942d..85bba2989 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -144,6 +144,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -160,8 +161,16 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -194,6 +203,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -226,6 +236,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -266,6 +277,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -338,6 +350,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -381,6 +394,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -404,6 +418,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 20b7036fd..95ee662b4 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -6,7 +6,7 @@ # ----------------------------------------------------------------------------- import copy -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import nn @@ -215,6 +215,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -245,6 +246,8 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin, @@ -253,6 +256,7 @@ def forward( "position_ids": position_ids, "is_sliding": self.is_sliding, "sliding_window_pattern": self.config.sliding_window_pattern, + "CCL": attention_mask.shape[-1], } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -297,6 +301,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -323,6 +328,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -363,6 +369,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -429,6 +436,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -466,6 +474,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -525,6 +534,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -592,7 +602,15 @@ def __init__(self, model): self.config = self.model.config self.lm_head = self.model.lm_head - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape selected = input_ids == self.model.config.image_token_index @@ -603,7 +621,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -620,7 +642,15 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffGemma3DecoderWrapper(self) - def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): + def forward( + self, + input_ids, + position_ids, + pixel_values, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): image_features = self.get_image_features(pixel_values=pixel_values) inputs_embeds = self.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape @@ -632,7 +662,11 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -647,6 +681,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -667,24 +703,55 @@ def get_specializations( "ctx_len": ctx_len, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "sliding_window": self.language_model.config.sliding_window, - "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "sliding_window": self.language_model.config.sliding_window, - "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, - }, - ] + if comp_ctx_lengths_prefill and comp_ctx_lengths_decode: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + } + ) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + } + ) + + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + }, + ] + specializations = {} if kv_offload: @@ -694,7 +761,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -719,6 +786,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): ) lang_dynamic_axes[f"past_{kv}.{i}"] = apply_dynamic_axes + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes @@ -767,7 +837,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 896) else: @@ -813,6 +883,9 @@ def get_dummy_inputs(self, kv_offload: bool = False): seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/gpt2/modeling_gpt2.py b/QEfficient/transformers/models/gpt2/modeling_gpt2.py index d68a65430..59d864907 100644 --- a/QEfficient/transformers/models/gpt2/modeling_gpt2.py +++ b/QEfficient/transformers/models/gpt2/modeling_gpt2.py @@ -65,6 +65,7 @@ def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -118,9 +119,11 @@ def forward( if (past_key_value is not None and not is_cross_attention) or ( past_key_value is not None and is_cross_attention and not is_updated ): + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # save all key/value_layer to cache to be re-used for fast auto-regressive generation # Update the cache_kwargs with position_ids for Cloud AI 100 - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key_states, value_states = curr_past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) @@ -156,6 +159,7 @@ def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -174,6 +178,7 @@ def forward( hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, + comp_ctx_lengths=comp_ctx_lengths, position_ids=position_ids, batch_index=batch_index, head_mask=head_mask, @@ -232,6 +237,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -341,6 +347,7 @@ def forward( outputs = block( hidden_states, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, @@ -392,6 +399,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -418,6 +426,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, diff --git a/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index af233870b..cb6a0f0d0 100644 --- a/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -98,6 +98,7 @@ def forward( self, hidden_states: torch.Tensor, layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -151,8 +152,10 @@ def forward( ) if layer_past is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key, value = curr_past_key_value.update(key, value, self.layer_idx, cache_kwargs) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if self.is_cross_attention: @@ -180,6 +183,7 @@ def forward( self, hidden_states: Optional[Tuple[torch.Tensor]], layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -194,6 +198,7 @@ def forward( attn_outputs = self.attn( hidden_states, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, @@ -242,6 +247,7 @@ def forward( self, input_ids: Optional[torch.Tensor] = None, past_key_values: Optional[list[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, @@ -333,6 +339,7 @@ def forward( outputs = block( hidden_states, layer_past=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, position_ids=position_ids, batch_index=batch_index, attention_mask=attention_mask, @@ -374,6 +381,7 @@ def forward( self, input_ids: Optional[torch.Tensor] = None, past_key_values: Optional[tuple[tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, @@ -399,6 +407,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 62bc849b7..576326921 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -417,6 +417,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, sliding_mask=None, @@ -433,6 +434,8 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin, @@ -442,6 +445,7 @@ def forward( "config": self.config, "is_sliding": self.sliding_window is not None, "sliding_window": past_key_value.sliding_window_len, + "CCL": attention_mask.shape[-1], } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -476,6 +480,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -492,6 +497,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -526,6 +532,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -586,6 +593,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, output_attentions=output_attentions, @@ -619,6 +627,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -670,6 +679,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/gptj/modeling_gptj.py b/QEfficient/transformers/models/gptj/modeling_gptj.py index dc3e5e6d2..da5bd881c 100644 --- a/QEfficient/transformers/models/gptj/modeling_gptj.py +++ b/QEfficient/transformers/models/gptj/modeling_gptj.py @@ -83,6 +83,7 @@ def forward( self, hidden_states: torch.FloatTensor, layer_past: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -134,7 +135,9 @@ def forward( query = query.permute(0, 2, 1, 3) if layer_past is not None: - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs) # compute self-attention: V x Softmax(QK^T) @@ -151,6 +154,7 @@ def forward( self, hidden_states: Optional[torch.FloatTensor], layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -164,6 +168,7 @@ def forward( attn_outputs, attn_weights = self.attn( hidden_states=hidden_states, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, @@ -191,6 +196,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -270,6 +276,7 @@ def forward( outputs = block( hidden_states=hidden_states, layer_past=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, position_ids=position_ids, batch_index=batch_index, @@ -314,6 +321,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -339,6 +347,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index 2a2d47d6d..dd3d6c7f3 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -129,6 +129,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -145,8 +146,16 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -171,6 +180,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -226,6 +236,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -267,6 +278,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -319,6 +331,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index c085f6a5e..07031d7fc 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -123,6 +123,7 @@ def forward( position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -142,6 +143,8 @@ def forward( cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin, @@ -149,6 +152,7 @@ def forward( "cache_position": cache_position, "batch_index": batch_index, "position_ids": position_ids, + "CCL": attention_mask.shape[-1], } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -209,6 +213,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -286,6 +291,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -297,6 +303,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -492,6 +499,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -546,6 +554,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/grok_1/modeling_grok1.py b/QEfficient/transformers/models/grok_1/modeling_grok1.py index 567a8e070..a0f9cd915 100644 --- a/QEfficient/transformers/models/grok_1/modeling_grok1.py +++ b/QEfficient/transformers/models/grok_1/modeling_grok1.py @@ -55,6 +55,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -93,7 +94,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -205,6 +208,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, @@ -235,6 +239,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -277,6 +282,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -351,6 +357,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -395,6 +402,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -441,6 +449,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py index 38d0fe167..96c59325f 100644 --- a/QEfficient/transformers/models/internvl/modeling_internvl.py +++ b/QEfficient/transformers/models/internvl/modeling_internvl.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +from typing import List, Optional + import torch import torch.nn as nn import torch.nn.functional as F @@ -34,7 +36,15 @@ def __init__(self, model): self.config = self.model.language_model.config self.language_model = self.model.language_model - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): input_embeds = self.model.language_model.get_input_embeddings()(input_ids) B, N, C = input_embeds.shape image_input_embeds = input_embeds.reshape(B * N, C) @@ -55,7 +65,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds) inputs_embeds = inputs_embeds.reshape(B, N, C) outputs = self.model.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) return outputs.logits, vision_embeds, image_idx, outputs.past_key_values @@ -74,6 +88,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -104,24 +120,54 @@ def get_specializations( "batched_num_patches": batch_size * num_patches, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "num_patches": num_patches, - "img_size": img_size, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "num_patches": num_patches, - "img_size": img_size, - "vision_size": vision_size, - }, - ] + if comp_ctx_lengths_prefill and comp_ctx_lengths_decode: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + } + ) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + } + ) + + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + }, + ] specializations = {} @@ -132,7 +178,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -146,6 +192,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes @@ -173,7 +222,7 @@ def get_output_names(self, kv_offload: bool = False): return lang_output_names return output_names - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", constants.INTERN_IMG_SIZE) else: @@ -234,6 +283,9 @@ def get_dummy_inputs(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -244,7 +296,15 @@ def get_dummy_inputs(self, kv_offload: bool = False): return inputs - def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + pixel_values, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): input_embeds = self.language_model.get_input_embeddings()(input_ids) vision_embeds = self.extract_feature(pixel_values) B, N, C = input_embeds.shape @@ -266,7 +326,11 @@ def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_val inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds) inputs_embeds = inputs_embeds.reshape(B, N, C) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) next_image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index f2a68f80e..58d174270 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -132,6 +132,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -154,7 +155,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -187,6 +190,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -202,6 +206,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -229,6 +234,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -277,6 +283,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -310,6 +317,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -326,6 +334,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index b7b951101..8c8eb4371 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -470,6 +470,7 @@ def forward( position_embeddings: Tuple[torch.Tensor, torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -503,6 +504,8 @@ def forward( if past_key_value is not None: chunk_position_ids = position_ids + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] if self.use_rope: chunk_position_ids = torch.where( @@ -510,7 +513,11 @@ def forward( ) # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"batch_index": batch_index, "position_ids": chunk_position_ids} + cache_kwargs = { + "batch_index": batch_index, + "position_ids": chunk_position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -543,6 +550,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, @@ -562,6 +570,7 @@ def forward( position_embeddings=position_embeddings, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -615,6 +624,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -682,6 +692,7 @@ def forward( attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -731,6 +742,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -754,6 +766,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -844,6 +857,7 @@ def forward( image_idx, past_key_values, batch_index: Optional[torch.LongTensor] = None, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) selected = input_ids == self.model.config.image_token_index @@ -857,6 +871,7 @@ def forward( inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=True, ) @@ -872,7 +887,15 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffLlama4DecoderWrapper(self) - def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): + def forward( + self, + input_ids, + position_ids, + pixel_values, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): inputs_embeds = self.language_model.get_input_embeddings()(input_ids) vision_feature_layer = self.config.vision_config.vision_feature_layer vision_feature_select_strategy = self.config.vision_config.vision_feature_select_strategy @@ -892,7 +915,11 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -911,6 +938,9 @@ def get_specializations( **compiler_options, ): max_num_tiles = compiler_options.pop("max_num_tiles", None) + comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill", None) + comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode", None) + if max_num_tiles is None: logger.warning( "User should pass `max_num_tiles` to compile API to fix the dynamic axes `pixel_values`, you can get more info by calling get_inputs_info function!, Since its not found setting its value to 17" @@ -957,41 +987,87 @@ def get_specializations( } ] - lang_prefill = { - "batch_size": 1 if continuous_batching else batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - } - if continuous_batching: - lang_prefill["full_batch_size"] = kv_cache_batch_size - else: - lang_prefill["batch_size"] = kv_cache_batch_size - if full_batch_size: - lang_prefill["full_batch_exec_size"] = full_batch_size - - lang_decode = { - "batch_size": full_batch_size if continuous_batching else batch_size, - "seq_len": 1, - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - } - if continuous_batching: - lang_decode["full_batch_size"] = kv_cache_batch_size + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang.append(lang_decode) + else: - lang_decode["batch_size"] = kv_cache_batch_size + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size - lang = [] - lang.append(lang_prefill) - lang.append(lang_decode) + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) specializations = {} @@ -1004,7 +1080,9 @@ def get_specializations( lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -1026,6 +1104,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: b for kv in ["key", "value"]: lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes @@ -1079,7 +1160,9 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False): + def get_dummy_inputs( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 336) else: @@ -1142,6 +1225,10 @@ def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = if continuous_batching: lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 9fd1ed782..5b36b1019 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -89,6 +89,7 @@ def forward( hidden_states: torch.Tensor, position_ids: torch.LongTensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: torch.Tensor = None, batch_index: Optional[torch.LongTensor] = None, ) -> torch.Tensor: @@ -105,8 +106,10 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] kv_seq_len = past_key_value.get_seq_length(self.layer_idx) - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) @@ -155,6 +158,7 @@ def forward( hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, + comp_ctx_lengths, causal_mask, batch_index: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -166,6 +170,7 @@ def forward( hidden_states=hidden_states, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, batch_index=batch_index, ) @@ -201,11 +206,19 @@ def __init__(self, config: QEffLlamaSwiftKVConfig): self.norm_swiftkv = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def _run_swiftkv_layers( - self, hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, causal_mask, batch_index + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + past_key_values, + comp_ctx_lengths, + causal_mask, + batch_index, ) -> torch.Tensor: for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers): layer = self.layers[layer_idx] - hidden_states = layer(hidden_states, position_ids, past_key_values, causal_mask, batch_index) + hidden_states = layer( + hidden_states, position_ids, past_key_values, comp_ctx_lengths, causal_mask, batch_index + ) hidden_states = self.norm(hidden_states) return hidden_states, past_key_values @@ -289,6 +302,7 @@ def forward( input_ids: Optional[torch.Tensor], position_ids: torch.Tensor, past_key_values: List[torch.Tensor], + comp_ctx_lengths: Optional[torch.LongTensor], batch_index: Optional[torch.LongTensor] = None, ): inputs_embeds = self.embed_tokens(input_ids) @@ -328,6 +342,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=False, use_cache=True, @@ -373,7 +388,7 @@ def forward( causal_mask = causal_mask[torch.arange(bsz).reshape(-1, 1), :, last_pos_id, :] hidden_states, next_decoder_cache = self._run_swiftkv_layers( - hidden_states, position_ids, past_key_values, causal_mask, batch_index + hidden_states, position_ids, past_key_values, comp_ctx_lengths, causal_mask, batch_index ) # We can fill the orig_hidden_states with the processed hidden_states here but it's not needed as for next token prediction # we only need the last valid pos_indices hidden_states. @@ -405,9 +420,12 @@ def forward( input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: Optional[Union[List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, ): - hidden_states, output_past_key_values = self.model(input_ids, position_ids, past_key_values, batch_index) + hidden_states, output_past_key_values = self.model( + input_ids, position_ids, past_key_values, comp_ctx_lengths, batch_index + ) logits = self.lm_head(hidden_states) return CausalLMOutputWithPast( loss=None, diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index e260beb05..dc6653db0 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +from typing import List, Optional + import torch import torch.nn as nn import torch.utils.checkpoint @@ -51,7 +53,15 @@ def __init__(self, model): self.language_model = self.model.language_model self.lm_head = self.model.lm_head - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.model.config.image_token_index @@ -65,6 +75,7 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, return_dict=True, ) @@ -83,7 +94,15 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEFFLlavaDecoderWrapper(self) - def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): + def forward( + self, + input_ids, + position_ids, + pixel_values, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): inputs_embeds = self.get_input_embeddings()(input_ids) # Image features image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) @@ -109,6 +128,7 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, ) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -120,7 +140,7 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx) return logits, pixel_values, image_idx, outputs.past_key_values - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads @@ -150,6 +170,10 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): ) ) lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: @@ -166,6 +190,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -187,24 +213,55 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - }, - ] + + if comp_ctx_lengths_prefill and comp_ctx_lengths_decode: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + } + ) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + } + ) + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + }, + ] + specializations = {} if kv_offload: @@ -214,7 +271,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers @@ -230,6 +287,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py index 2fa1d9234..2e4848b6b 100755 --- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py +++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py @@ -6,6 +6,8 @@ # ----------------------------------------------------------------------------- +from typing import List, Optional + import numpy as np import torch import torch.nn as nn @@ -123,7 +125,15 @@ def __init__(self, model): self.language_model = self.model.language_model self.lm_head = self.model.lm_head - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) image_features = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.config.image_token_index @@ -138,6 +148,7 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -154,7 +165,7 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffLlavaNextDecoderWrapper(self) - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads @@ -217,6 +228,10 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): ) ) lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, constants.GRANITEVISION_CTX_LEN - 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -232,6 +247,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -285,30 +302,67 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "image_size_height": image_size_height, - "image_size_width": image_size_width, - "num_patches": num_patches, - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "image_size_height": image_size_height, - "image_size_width": image_size_width, - "num_patches": num_patches, - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - }, - ] + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "image_size_height": image_size_height, + "image_size_width": image_size_width, + "num_patches": num_patches, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + } + ) + + # Remaining elements use comp_ctx_lengths[1:] in a loop + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "image_size_height": image_size_height, + "image_size_width": image_size_width, + "num_patches": num_patches, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + } + ) + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "image_size_height": image_size_height, + "image_size_width": image_size_width, + "num_patches": num_patches, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "image_size_height": image_size_height, + "image_size_width": image_size_width, + "num_patches": num_patches, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + }, + ] + specializations = {} if kv_offload: specializations["vision"] = vision @@ -317,7 +371,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers vision_dynamic_axes = { @@ -332,6 +386,10 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): for i in range(num_layers): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index ca23cc144..30c73ae8b 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -140,6 +140,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -163,7 +164,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -196,6 +199,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -226,6 +230,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -256,6 +261,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -316,6 +322,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -354,6 +361,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -377,6 +385,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index 735eec9e5..694ed4cde 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -106,6 +106,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, use_cache: Optional[bool] = None, @@ -126,6 +127,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, @@ -166,7 +168,15 @@ def __init__(self, model): self.config = self.model.config self.language_model = self.model.language_model - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.model.config.image_token_index @@ -179,6 +189,7 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds=inputs_embeds_1, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, ) # Cast to int32 to avoid ONNXRT issue @@ -198,7 +209,15 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEFFMistral3DecoderWrapper(self) - def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): + def forward( + self, + input_ids, + position_ids, + pixel_values, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): inputs_embeds = self.get_input_embeddings()(input_ids) image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]]).repeat(pixel_values.shape[0], 1) image_features = self.get_image_features( @@ -219,6 +238,7 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, ) # Cast to int32 to avoid ONNXRT issue logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -230,7 +250,7 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val return logits, pixel_values, image_idx, outputs.past_key_values - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) height = self.config.vision_config.image_size @@ -282,6 +302,9 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -298,6 +321,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -323,22 +348,50 @@ def get_specializations( "vision_size": vision_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "image_size": img_size, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "image_size": img_size, - "vision_size": vision_size, - }, - ] + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "image_size": img_size, + "vision_size": vision_size, + } + ) + + # Remaining elements use comp_ctx_lengths[1:] in a loop + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "image_size": img_size, + "vision_size": vision_size, + } + ) + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "image_size": img_size, + "vision_size": vision_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "image_size": img_size, + "vision_size": vision_size, + }, + ] specializations = {} @@ -351,7 +404,7 @@ def get_specializations( lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers @@ -368,6 +421,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index 9b9e3448a..6e61568ac 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -137,6 +137,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -159,7 +160,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -245,6 +248,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -282,6 +286,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -314,6 +319,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -375,6 +381,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, @@ -412,6 +419,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -435,6 +443,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index cb24f1de4..d6fb1dcd2 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -177,6 +177,7 @@ def forward( hidden_states: torch.Tensor, cross_attention_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, use_cache: bool = None, @@ -249,6 +250,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, position_embeddings: torch.Tensor = None, use_cache: bool = False, @@ -278,9 +280,12 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs = { "batch_index": batch_index, "position_ids": position_ids, + "CCL": attention_mask.shape[-1], } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -316,6 +321,7 @@ def forward( full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -350,6 +356,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -379,6 +386,7 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, use_cache: bool = None, @@ -396,13 +404,15 @@ def forward( key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # if we have a new image + new tokens, we only computed key_states on that new image # we still update the cross key states, past_image, new_image. And use it! key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, - {"batch_index": batch_index, "position_ids": position_ids}, + {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]}, ) elif past_key_value is not None: key_states, value_states = ( @@ -448,6 +458,7 @@ def forward( full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -461,6 +472,7 @@ def forward( attention_mask=cross_attention_mask, cross_attention_states=cross_attention_states, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, cache_position=cache_position, ) @@ -594,6 +606,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cross_attention_states: Optional[torch.FloatTensor] = None, cross_attention_mask: Optional[torch.Tensor] = None, @@ -658,6 +671,7 @@ def forward( full_text_row_masked_out_mask=full_text_row_masked_out_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, ) @@ -688,6 +702,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cross_attention_states: Optional[torch.LongTensor] = None, cross_attention_mask: Optional[torch.LongTensor] = None, @@ -707,6 +722,7 @@ def forward( cross_attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, @@ -774,6 +790,7 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, @@ -820,6 +837,7 @@ def forward( cross_attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, inputs_embeds=inputs_embeds, cache_position=cache_position, @@ -853,6 +871,7 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -869,6 +888,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, @@ -879,7 +899,7 @@ def forward( logits = self.lm_head(hidden_states).float() return logits, image_idx, outputs.past_key_values, pixel_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): BS = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE SEQ_LEN = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN CTX_LEN = constants.ONNX_EXPORT_CTX_LEN @@ -943,6 +963,10 @@ def get_dummy_inputs(self, kv_offload: bool = False): lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: @@ -959,6 +983,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -973,22 +999,53 @@ def get_specializations( logger.warning("Setting `img_size=448` as it was neither passed nor found in vision_config") vision = [{"batch_size": batch_size, "max_num_images": max_num_images, "img_size": img_size}] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_images": max_num_images, - "img_size": img_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "max_num_images": max_num_images, - "img_size": img_size, - }, - ] + + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "max_num_images": max_num_images, + "img_size": img_size, + } + ) + + # Remaining elements use comp_ctx_lengths[1:] in a loop + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "max_num_images": max_num_images, + "img_size": img_size, + } + ) + + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + }, + ] + specializations = {} if kv_offload: @@ -998,7 +1055,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): txt_cfg = self.config.get_text_config() num_hidden_layers = txt_cfg.num_hidden_layers cross_attention_layers = txt_cfg.cross_attention_layers @@ -1023,6 +1080,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 60f60c768..5f1ec51e6 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -59,6 +59,7 @@ constants, get_padding_shape_from_config, ) +from QEfficient.utils.check_ccl_specializations import process_ccl_specializations from QEfficient.utils.logging_utils import logger @@ -860,6 +861,7 @@ def __init__( self, model: nn.Module, continuous_batching: bool = False, + qaic_config: Optional[dict] = None, **kwargs, ): """ @@ -881,6 +883,9 @@ def __init__( raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") self.model = model self.config = model.config + + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config) + self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) self.continuous_batching = continuous_batching @@ -902,7 +907,7 @@ def model_name(self) -> str: return mname @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs): """ Load a QEfficient multimodal model for dual QPC from a pretrained HuggingFace model or local path. @@ -928,7 +933,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) - return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, + **kwargs, + ) @property def onnx_path(self): @@ -985,13 +995,21 @@ def export( """ # TODO This is a temporary change as continous batching is enabled only for few models. Once support is added for all the models this exception handing can be removed. try: - inputs = self.model.get_dummy_inputs(kv_offload=True, continuous_batching=self.continuous_batching) + inputs = self.model.get_dummy_inputs( + kv_offload=True, + continuous_batching=self.continuous_batching, + comp_ctx_lengths=self.comp_ctx_lengths_decode, + ) dynamic_axes = self.model.get_onnx_dynamic_axes( - kv_offload=True, continuous_batching=self.continuous_batching + kv_offload=True, + continuous_batching=self.continuous_batching, + comp_ctx_lengths=self.comp_ctx_lengths_decode, ) except TypeError: - inputs = self.model.get_dummy_inputs(kv_offload=True) - dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True) + inputs = self.model.get_dummy_inputs(kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode) + dynamic_axes = self.model.get_onnx_dynamic_axes( + kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode + ) output_names = self.model.get_output_names(kv_offload=True) self.vision_model.export( @@ -1096,10 +1114,17 @@ def compile( output_names = self.model.get_output_names(kv_offload=True) + # For supporting VLLM and Disaggregated with CCL + if "comp_ctx_lengths_prefill" in compiler_options: + self.comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill") + self.comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode") + specializations, compiler_options = self.model.get_specializations( batch_size=batch_size, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, img_size=img_size, kv_offload=True, continuous_batching=self.continuous_batching, @@ -1246,6 +1271,8 @@ def generate( device_id=device_ids, # if device_ids is not None else [0], ctx_len=ctx_len_comp, full_batch_size=fbs, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, ) # Call generate method @@ -1393,11 +1420,23 @@ def kv_offload_generate( lang_session.set_buffers(vision_outputs) + if self.comp_ctx_lengths_prefill is not None: + list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill] + prefill_ccl_id = 0 + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + lang_start = perf_counter() # Run prefill chunk_inputs = lang_inputs.copy() for i in range(num_chunks): + if ( + self.comp_ctx_lengths_prefill is not None + and (i + 1) * prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id] + ): + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + chunk_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] chunk_inputs["position_ids"] = lang_inputs["position_ids"][ ..., i * prefill_seq_len : (i + 1) * prefill_seq_len @@ -1429,8 +1468,25 @@ def kv_offload_generate( streamer.put(lang_inputs["input_ids"][0]) # Decode loop + if self.comp_ctx_lengths_decode is not None: + max_ccl_id = len(self.comp_ctx_lengths_decode) - 1 + list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode] + max_position_id = np.max(lang_inputs["position_ids"]) + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + decode_start = perf_counter() for num_token in range(1, generation_len): + if self.comp_ctx_lengths_decode is not None: + if max_position_id >= self.comp_ctx_lengths_decode[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + outputs = lang_session.run(lang_inputs) # Prepare inputs for next iteration @@ -1480,6 +1536,7 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal def __init__( self, model: nn.Module, + qaic_config: Optional[dict] = None, **kwargs, ): """ @@ -1501,6 +1558,8 @@ def __init__( raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") super().__init__(model, **kwargs) + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config) + # to handle internvl models if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"): self.model.config.llm_config.use_cache = True @@ -1517,6 +1576,7 @@ def __init__( def from_pretrained( cls, pretrained_model_name_or_path, + qaic_config: Optional[dict] = None, *args, **kwargs, ): @@ -1554,7 +1614,12 @@ def from_pretrained( config.vision_config.use_flash_attn = "false" model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, config, *args, **kwargs) - return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, + **kwargs, + ) def export( self, @@ -1576,8 +1641,8 @@ def export( str Path to the generated ONNX graph file. """ - inputs = self.model.get_dummy_inputs() - dynamic_axes = self.model.get_onnx_dynamic_axes() + inputs = self.model.get_dummy_inputs(comp_ctx_lengths=self.comp_ctx_lengths_decode) + dynamic_axes = self.model.get_onnx_dynamic_axes(comp_ctx_lengths=self.comp_ctx_lengths_decode) output_names = self.model.get_output_names() return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) @@ -1651,14 +1716,24 @@ def compile( f"full_batch_size={full_batch_size}, kv_cache_batch_size={kv_cache_batch_size}, num_speculative_tokens={num_speculative_tokens}, " ) + # Infer kv_cache_batch_size if not provided + kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size output_names = self.model.get_output_names() + # For supporting VLLM and Disaggregated with CCL + if "comp_ctx_lengths_prefill" in compiler_options: + self.comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill") + self.comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode") + # Get specializations from modelling file # TODO: expose this via the auto class as well specializations, compiler_options = self.model.get_specializations( batch_size=batch_size, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, + kv_cache_batch_size=kv_cache_batch_size, img_size=img_size, **compiler_options, ) @@ -1677,6 +1752,11 @@ def compile( if output_name.endswith("_RetainedState"): custom_io[output_name] = "float16" if "pixel_values" in output_name else kv_cache_dtype + # TODO this hould be removed once the continous batching is supported for all the models. + compiler_options.pop("continuous_batching", None) + compiler_options.pop("kv_cache_batch_size", None) + compiler_options.pop("full_batch_size", None) + self._compile( onnx_path=onnx_path, compile_dir=compile_dir, @@ -1843,12 +1923,24 @@ def cloud_ai_100_generate( inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) inputs["image_idx"] = np.array([[0]]) + if self.comp_ctx_lengths_prefill is not None: + list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill] + prefill_ccl_id = 0 + inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + qpc_session.activate() chunk_inputs = inputs.copy() prefill_start = perf_counter() # Run prefill for i in range(num_chunks): + if ( + self.comp_ctx_lengths_prefill is not None + and (i + 1) * prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id] + ): + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + chunk_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] outputs = qpc_session.run(chunk_inputs) @@ -1872,8 +1964,25 @@ def cloud_ai_100_generate( inputs.pop("pixel_values") # Decode loop + if self.comp_ctx_lengths_decode is not None: + list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode] + max_ccl_id = len(self.comp_ctx_lengths_decode) - 1 + max_position_id = np.max(inputs["position_ids"]) + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + decode_start = perf_counter() for num_token in range(1, generation_len): + if self.comp_ctx_lengths_decode is not None: + if max_position_id >= self.comp_ctx_lengths_decode[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + outputs = qpc_session.run(inputs) # Prepare inputs for next iteration inputs["input_ids"] = outputs["logits"].argmax(2) @@ -1991,7 +2100,14 @@ class QEFFAutoModelForImageTextToText: _hf_auto_class = AutoModelForImageTextToText - def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, continuous_batching: bool = False, **kwargs): + def __new__( + self, + model: nn.Module, + kv_offload: Optional[bool] = True, + continuous_batching: bool = False, + qaic_config: Optional[dict] = None, + **kwargs, + ): """ Instantiate the appropriate internal class for single or dual QPC mode. @@ -2012,9 +2128,11 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, continuou The wrapped model instance, configured for either dual or single QPC. """ if kv_offload: - return _QEffAutoModelForImageTextToTextDualQPC(model, continuous_batching, **kwargs) + return _QEffAutoModelForImageTextToTextDualQPC( + model, continuous_batching, qaic_config=qaic_config, **kwargs + ) else: - return _QEFFAutoModelForImageTextToTextSingleQPC(model, **kwargs) + return _QEFFAutoModelForImageTextToTextSingleQPC(model, qaic_config=qaic_config, **kwargs) @classmethod @with_replaced_quantizers @@ -2023,6 +2141,7 @@ def from_pretrained( pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, continuous_batching: bool = False, + qaic_config: Optional[dict] = None, **kwargs, ): """ @@ -2063,12 +2182,14 @@ def from_pretrained( logger.warning("Updating low_cpu_mem_usage=False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return cls( model, kv_offload=kv_offload, continuous_batching=continuous_batching, pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, **kwargs, ) @@ -2163,6 +2284,9 @@ def __init__( ) # Set use_cache=True to get KV values as output during ONNX export model.config.use_cache = True + + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config) + super().__init__(model, qaic_config=qaic_config, **kwargs) self.num_layers = model.config.num_hidden_layers self.continuous_batching = continuous_batching @@ -2273,7 +2397,11 @@ def from_pretrained( if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__]( - model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs + model, + kv_offload=kv_offload, + pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, + **kwargs, ) return cls( model, @@ -2329,6 +2457,10 @@ def export(self, export_dir: Optional[str] = None) -> str: "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, } + if self.comp_ctx_lengths_prefill is not None: + example_inputs["comp_ctx_lengths"] = torch.randint(0, 512, (512,), dtype=torch.long) + dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d pkv_dynamic_axes = { 0: "full_batch_size" if self.continuous_batching else "batch_size", @@ -2485,6 +2617,7 @@ def build_prefill_specialization( self, prefill_seq_len: int = 32, ctx_len: int = 128, + comp_ctx_lengths: Optional[int] = None, batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, @@ -2522,6 +2655,8 @@ def build_prefill_specialization( "seq_len": prefill_seq_len, "ctx_len": ctx_len, } + if comp_ctx_lengths is not None: + spec["comp_ctx_lengths"] = comp_ctx_lengths spec["num_logits_to_keep"] = 1 if self.is_tlm else None if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size @@ -2535,6 +2670,7 @@ def build_decode_specialization( self, prefill_seq_len: int = 32, ctx_len: int = 128, + comp_ctx_lengths: Optional[int] = None, batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, @@ -2579,6 +2715,9 @@ def build_decode_specialization( "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1, "ctx_len": ctx_len, } + if comp_ctx_lengths is not None: + spec["comp_ctx_lengths"] = comp_ctx_lengths + spec["num_logits_to_keep"] = (num_speculative_tokens + 1) if self.is_tlm else None if self.continuous_batching: @@ -2681,6 +2820,25 @@ def compile( If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. """ + + # For supporting VLLM and Disaggregated with CCL + if "comp_ctx_lengths_prefill" in compiler_options and "comp_ctx_lengths_decode" in compiler_options: + comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill") + comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode") + if isinstance(comp_ctx_lengths_prefill, str): + import ast + + try: + # Safely evaluate the string to a Python list for disaggregated input + self.comp_ctx_lengths_prefill = ast.literal_eval(comp_ctx_lengths_prefill) + self.comp_ctx_lengths_decode = ast.literal_eval(comp_ctx_lengths_decode) + + except (ValueError, SyntaxError): + raise ValueError("Invalid format for comp_ctx_lengths. Expected a list-like string.") + else: + self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill + self.comp_ctx_lengths_decode = comp_ctx_lengths_decode + # --- Validation --- if prefill_only is not None and not isinstance(prefill_only, bool): raise TypeError("`prefill_only` must be a boolean.") @@ -2711,26 +2869,58 @@ def compile( # --- Specializations --- specializations = [] if prefill_only is None or prefill_only or prefill_seq_len == 1: - specializations.append( - self.build_prefill_specialization( + if self.comp_ctx_lengths_prefill is not None: + # Adding elements from self.comp_ctx_lengths_prefill to prefill_specialization + for i in range(0, len(self.comp_ctx_lengths_prefill)): + specializations.append( + self.build_prefill_specialization( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths=self.comp_ctx_lengths_prefill[i], + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + ) + ) + + else: + specializations.append( + self.build_prefill_specialization( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + ) + ) + + if prefill_only is None or not prefill_only: + if self.comp_ctx_lengths_decode is not None: + # Adding elements from self.comp_ctx_lengths_decode to decode_specialization + for i in range(0, len(self.comp_ctx_lengths_decode)): + decode_spec = self.build_decode_specialization( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths=self.comp_ctx_lengths_decode[i], + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, + ) + if decode_spec: + specializations.append(decode_spec) + + else: + decode_spec = self.build_decode_specialization( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, batch_size=batch_size, kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, ) - ) - if prefill_only is None or not prefill_only: - decode_spec = self.build_decode_specialization( - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - batch_size=batch_size, - kv_cache_batch_size=kv_cache_batch_size, - full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens, - ) - if decode_spec: - specializations.append(decode_spec) + if decode_spec: + specializations.append(decode_spec) # --- Compilation --- kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" @@ -2808,6 +2998,8 @@ def generate( tokenizer=tokenizer, qpc_path=self.qpc_path, prompt=prompts, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, device_id=device_id, generation_len=generation_len, automation=kwargs.pop("automation", False), diff --git a/QEfficient/transformers/models/molmo/modeling_molmo.py b/QEfficient/transformers/models/molmo/modeling_molmo.py index 4f92316ca..48a8fbf7b 100644 --- a/QEfficient/transformers/models/molmo/modeling_molmo.py +++ b/QEfficient/transformers/models/molmo/modeling_molmo.py @@ -243,6 +243,7 @@ def attention( attention_bias: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: bool = False, **kwargs, @@ -278,8 +279,16 @@ def attention( q, k = qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, self.config) if layer_past is not None: + if comp_ctx_lengths is not None: + attention_bias = attention_bias[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_bias.shape[-1], + } k, v = layer_past.update(k, v, self.layer_id, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -311,6 +320,7 @@ def forward( attention_bias: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: bool = False, **kwargs, @@ -334,6 +344,7 @@ def forward( attention_bias, position_ids=position_ids, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, ) @@ -380,6 +391,7 @@ def forward( subsegment_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: bool = False, last_logits_only: bool = False, @@ -496,6 +508,7 @@ def forward( attention_bias=causal_mask, position_ids=position_ids, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, ) @@ -518,6 +531,7 @@ def forward( attention_bias=causal_mask, position_ids=position_ids, layers_past=layers_past, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, ) @@ -574,7 +588,15 @@ def __init__(self, model): # self.language_model = self.model.language_model self.config = self.model.config - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): if input_ids is not None: input_ids = input_ids * (input_ids != -1).to(input_ids.dtype) inputs_embeds = self.model.model.transformer.wte(input_ids) @@ -587,7 +609,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va # inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.model.forward( - input_embeddings=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + input_embeddings=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -608,7 +634,16 @@ def get_qeff_language_decoder(self): """ def forward( - self, pixel_values, image_masks, image_input_idx, valid_idx, input_ids, position_ids, image_idx, past_key_values + self, + pixel_values, + image_masks, + image_input_idx, + valid_idx, + input_ids, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): image_features, _ = self.model.vision_backbone(pixel_values, image_masks) num_image, num_patch = image_features.shape[1:3] @@ -637,7 +672,11 @@ def forward( inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.forward( - input_embeddings=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + input_embeddings=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -651,6 +690,8 @@ def get_specializations( ctx_len: int, num_images: int = None, img_size: int = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, valid_size: int = None, kv_offload: bool = False, **compiler_options, @@ -679,30 +720,77 @@ def get_specializations( } ] - lang_prefill = { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "valid_size": valid_size, - } - - lang_decode = {"batch_size": batch_size, "seq_len": "1", "ctx_len": ctx_len, "valid_size": valid_size} + if comp_ctx_lengths_prefill is not None and comp_ctx_lengths_decode is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang_prefill = { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "valid_size": valid_size, + } + if kv_offload: + values = { + "img_size": img_size, + "img_tile": img_tile, + "num_images": num_images, + "num_patch": num_patch, + } + + for key, value in values.items(): + lang_prefill[key] = value + + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "valid_size": valid_size, + } + if kv_offload: + values = { + "img_size": img_size, + "img_tile": img_tile, + "num_images": num_images, + "num_patch": num_patch, + } + + for key, value in values.items(): + lang_decode[key] = value + + lang.append(lang_decode) - if kv_offload: - values = { - "img_size": img_size, - "img_tile": img_tile, - "num_images": num_images, - "num_patch": num_patch, + else: + lang_prefill = { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "valid_size": valid_size, } - for key, value in values.items(): - lang_prefill[key] = value - lang_decode[key] = value + lang_decode = {"batch_size": batch_size, "seq_len": "1", "ctx_len": ctx_len, "valid_size": valid_size} + + if kv_offload: + values = { + "img_size": img_size, + "img_tile": img_tile, + "num_images": num_images, + "num_patch": num_patch, + } + + for key, value in values.items(): + lang_prefill[key] = value + lang_decode[key] = value + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) - lang = [] - lang.append(lang_prefill) - lang.append(lang_decode) specializations = {} if kv_offload: @@ -712,7 +800,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -731,6 +819,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes @@ -760,7 +851,7 @@ def get_output_names(self, kv_offload: bool = False): return lang_output_names return output_names - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): inputs_shapes = {} inputs_shapes_lang = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) @@ -823,6 +914,9 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/mpt/modeling_mpt.py b/QEfficient/transformers/models/mpt/modeling_mpt.py index 9bf6a4422..16ca54051 100644 --- a/QEfficient/transformers/models/mpt/modeling_mpt.py +++ b/QEfficient/transformers/models/mpt/modeling_mpt.py @@ -39,6 +39,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, ): @@ -51,7 +52,9 @@ def forward( value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale @@ -101,6 +104,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, use_cache: bool = False, output_attentions: bool = False, ): @@ -118,6 +122,7 @@ def forward( batch_index=batch_index, attention_mask=attention_mask, past_key_value=layer_past, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, ) @@ -144,6 +149,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -205,6 +211,7 @@ def forward( outputs = block( hidden_states, layer_past=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, position_ids=position_ids, batch_index=batch_index, @@ -250,6 +257,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -271,6 +279,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py index 0d23729c1..bf946da39 100644 --- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -132,6 +132,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -154,8 +155,10 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -188,6 +191,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -203,6 +207,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -233,6 +238,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -286,6 +292,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -322,6 +329,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -343,6 +351,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/phi/modeling_phi.py b/QEfficient/transformers/models/phi/modeling_phi.py index 18557f1ca..a5e53216a 100644 --- a/QEfficient/transformers/models/phi/modeling_phi.py +++ b/QEfficient/transformers/models/phi/modeling_phi.py @@ -67,6 +67,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -104,8 +105,16 @@ def forward( key_states = torch.cat((key_rot, key_pass), dim=-1) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # Update the cache_kwargs with position_ids for Cloud AI 100 - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -140,6 +149,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, @@ -181,6 +191,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -213,6 +224,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -274,6 +286,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -316,6 +329,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -370,6 +384,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index 4b5234a5a..851395f08 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -140,6 +140,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, position_ids=Optional[torch.Tensor], past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -162,9 +163,12 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs = { "batch_index": batch_index, "position_ids": position_ids, + "CCL": attention_mask.shape[-1], } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -198,6 +202,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -235,6 +240,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -265,6 +271,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -314,6 +321,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -350,6 +358,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -366,6 +375,7 @@ def forward( batch_index=batch_index, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, output_hidden_states=output_hidden_states, diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 24e8df46c..1aca7039d 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -150,6 +150,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -166,7 +167,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -200,6 +203,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -231,6 +235,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -261,6 +266,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -313,6 +319,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -348,6 +355,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -364,6 +372,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 445c15583..7a9ec8ea8 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -485,6 +485,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -511,8 +512,16 @@ def forward( ) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids[0]} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids[0], + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -543,6 +552,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -582,6 +592,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -614,6 +625,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -664,6 +676,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -702,6 +715,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -725,6 +739,7 @@ def forward( position_ids=position_ids, attention_mask=attention_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -774,6 +789,7 @@ def forward( image_idx, past_key_values, batch_index: Optional[torch.LongTensor] = None, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape @@ -788,6 +804,7 @@ def forward( inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=True, ) @@ -807,7 +824,13 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffQwen_2_5_vl_DecoderWrapper(self) - def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False, **kwargs): + def get_dummy_inputs( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + continuous_batching: bool = False, + **kwargs, + ): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) @@ -862,6 +885,9 @@ def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = if continuous_batching: lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -887,6 +913,9 @@ def get_specializations( full_batch_size: Optional[int] = None, **compiler_options, ): + comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill", None) + comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode", None) + if height is None or width is None: height = constants.QWEN2_5_VL_HEIGHT width = constants.QWEN2_5_VL_WIDTH @@ -966,37 +995,77 @@ def smart_resize( "grid_w": grid_w, } ] - lang_prefill = { - "batch_size": 1 if continuous_batching else batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "vision_size": vision_size, - "vision_batch_size": batch_size, - } - if continuous_batching: - lang_prefill["full_batch_size"] = kv_cache_batch_size + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "vision_size": vision_size, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang.append(lang_decode) else: - lang_prefill["batch_size"] = kv_cache_batch_size - if full_batch_size: - lang_prefill["full_batch_exec_size"] = full_batch_size - - lang_decode = { - "batch_size": full_batch_size if continuous_batching else batch_size, - "seq_len": 1, - "ctx_len": ctx_len, - "vision_size": vision_size, - "vision_batch_size": batch_size, - } + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } - if continuous_batching: - lang_decode["full_batch_size"] = kv_cache_batch_size - else: - lang_decode["batch_size"] = kv_cache_batch_size + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size - lang = [] - lang.append(lang_prefill) - lang.append(lang_decode) + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) specializations = {} @@ -1009,7 +1078,9 @@ def smart_resize( lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers @@ -1037,6 +1108,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: b if continuous_batching: lang_dynamic_axes["batch_index"] = {0: "batch_size"} + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index ecdb36019..ccf918c2c 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -151,6 +151,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -167,7 +168,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -201,6 +204,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -232,6 +236,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -262,6 +267,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -314,6 +320,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -349,6 +356,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -367,6 +375,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 591f7c1b0..c8a5ae2fd 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -201,6 +201,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -217,7 +218,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -243,6 +246,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -274,6 +278,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -300,6 +305,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, batch_index: Optional[torch.LongTensor] = None, @@ -342,6 +348,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -369,6 +376,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -385,6 +393,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, batch_index=batch_index, use_cache=use_cache, diff --git a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py index 9a327761d..075b8aedb 100644 --- a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py +++ b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py @@ -69,6 +69,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -84,7 +85,9 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -118,6 +121,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -153,6 +157,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -184,6 +189,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -237,6 +243,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -273,6 +280,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -289,6 +297,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/whisper/modeling_whisper.py b/QEfficient/transformers/models/whisper/modeling_whisper.py index e078493a7..79907818d 100644 --- a/QEfficient/transformers/models/whisper/modeling_whisper.py +++ b/QEfficient/transformers/models/whisper/modeling_whisper.py @@ -55,6 +55,7 @@ def forward( position_ids_layer: torch.Tensor = None, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -99,7 +100,9 @@ def forward( key_states = key_states.transpose(1, 2).contiguous() value_states = value_states.transpose(1, 2).contiguous() if past_key_value is not None: - cache_kwargs = {"position_ids": position_ids_layer} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"position_ids": position_ids_layer, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) @@ -181,6 +184,7 @@ def forward( layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.LongTensor] = None, @@ -215,6 +219,7 @@ def forward( hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -388,6 +393,7 @@ def forward( cross_attn_head_mask=None, position_ids=None, past_key_values=None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds=None, use_cache=None, output_attentions=None, @@ -532,6 +538,7 @@ def forward( layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, use_cache=use_cache, position_ids_layer=position_ids, @@ -643,6 +650,7 @@ def forward( cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, decoder_inputs_embeds=None, use_cache=None, output_attentions=None, @@ -674,6 +682,7 @@ def forward( head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, @@ -719,6 +728,7 @@ def forward( cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, position_ids: Optional[Tuple[torch.LongTensor]] = None, labels: Optional[torch.LongTensor] = None, @@ -740,6 +750,7 @@ def forward( decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, decoder_inputs_embeds=decoder_inputs_embeds, decoder_position_ids=position_ids, use_cache=use_cache, diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py new file mode 100644 index 000000000..308c69554 --- /dev/null +++ b/QEfficient/utils/check_ccl_specializations.py @@ -0,0 +1,51 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + + +def process_ccl_specializations(qaic_config): + if qaic_config is None: + return None, None + ccl_prefill = qaic_config.pop("comp_ctx_lengths_prefill", None) + ccl_decode = qaic_config.pop("comp_ctx_lengths_decode", None) + ctx_len = qaic_config.pop("ctx_len", None) + prefill_seq_len = qaic_config.pop("prefill_seq_len", 128) + + if ccl_prefill is None or ccl_decode is None: + return None, None + + if ctx_len is None: + raise TypeError("`ctx_len` is required when loading the model with CCL.") + + if prefill_seq_len == 1: + # both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them. + ccl_union_all = sorted(set(ccl_prefill + ccl_decode)) + ccl_union_all = [min(x, ctx_len) for x in ccl_union_all] + return ccl_union_all, ccl_union_all + + # Step 1: Cap values to ctx_len + ccl_prefill = [min(x, ctx_len) for x in ccl_prefill] + ccl_decode = [min(x, ctx_len) for x in ccl_decode] + + # Step 2: Remove duplicates within each list + ccl_prefill = list(set(ccl_prefill)) + ccl_decode = list(set(ccl_decode)) + + # Step 3: Ensure no overlap between ccl_prefill and ccl_decode + updated_prefill = [] + for val in ccl_prefill: + while val in ccl_decode or val in updated_prefill: + val -= 1 + if val < 0: + break # Prevent negative values + if val >= 0: + updated_prefill.append(val) + + # Step 4: Sort both lists + updated_prefill.sort() + ccl_decode.sort() + + return updated_prefill, ccl_decode diff --git a/examples/ccl_gpt_oss.py b/examples/ccl_gpt_oss.py new file mode 100644 index 000000000..b211ba914 --- /dev/null +++ b/examples/ccl_gpt_oss.py @@ -0,0 +1,51 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from transformers import AutoTokenizer, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM + +model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 + +ctx_len = 4096 +# In moe models like gpt-oss, since prefill_seq_len=1 both comp_ctx_lengths_prefill and comp_ctx_lengths_decode can share similar lists. +# Set the list of ccl during prefilling process +comp_ctx_lengths_prefill = [512, ctx_len] +# Set the list of ccl during decoding process +comp_ctx_lengths_decode = [512, ctx_len] + + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + "prefill_seq_len": 1, # Passing prefill_seq_len is mandatory for CCL goal in moe models. Currently we can get best perf using PL=1. + }, +) +tokenizer = AutoTokenizer.from_pretrained(model_id) + +onnx_model_path = qeff_model.export() +qpc_path = qeff_model.compile( + prefill_seq_len=1, # Currently we can get best perf using PL=1 i.e. decode-only model, prefill optimizations are being worked on. + ctx_len=ctx_len, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=4, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, +) +print(f"qpc path is {qpc_path}") +streamer = TextStreamer(tokenizer) +exec_info = qeff_model.generate( + tokenizer, + prompts="Who is your creator? and What all you are allowed to do?", + generation_len=256, +) diff --git a/examples/ccl_image_text_to_text_inference.py b/examples/ccl_image_text_to_text_inference.py new file mode 100644 index 000000000..be472f433 --- /dev/null +++ b/examples/ccl_image_text_to_text_inference.py @@ -0,0 +1,137 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +from PIL import Image +from transformers import AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +# Add HuggingFace Token to access the model +HF_TOKEN = "" + + +def run_model( + model_name, + token, + query, + image_url, + kv_offload=False, + prefill_seq_len=32, + ctx_len=512, + comp_ctx_lengths_prefill=None, + comp_ctx_lengths_decode=None, + generation_len=128, + img_size=560, + num_cores=16, + num_devices=1, +): + ## STEP - 1 Load the Processor and Model + + processor = AutoProcessor.from_pretrained(model_name, token=token) + + # `kv_offload` is used to compile the model in a Single QPC or 2 QPCs. + # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. + # The outputs of the Vision Encoder are then passed to the Language model via host in this case. + + model = QEFFAutoModelForImageTextToText.from_pretrained( + model_name, + token=token, + attn_implementation="eager", + kv_offload=kv_offload, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, + ) + + ## STEP - 2 Export & Compile the Model + + model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=False, + ) + + ## STEP - 3 Load and process the inputs for Inference + + image = Image.open(requests.get(image_url, stream=True).raw) + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": query}, + ], + } + ] + input_text = [processor.apply_chat_template(messages, add_generation_prompt=True)] + + inputs = processor( + text=input_text, + images=image, + return_tensors="pt", + add_special_tokens=False, + padding="max_length", + max_length=prefill_seq_len, + ) + + ## STEP - 4 Run Inference on the compiled model + + streamer = TextStreamer(processor.tokenizer) + output_statistics = model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) + print(output_statistics) + + +if __name__ == "__main__": + # Model name and Input parameters + # model_name = "llava-hf/llava-1.5-7b-hf" + model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" + query = "Describe this image." + image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" + + # Compilation parameters for the model + kv_offload = True + prefill_seq_len = 32 + ctx_len = 8192 + generation_len = 128 + # img_size = 336 + img_size = 560 + num_cores = 16 + num_devices = 4 + comp_ctx_lengths_prefill = [4096] + comp_ctx_lengths_decode = [6144, ctx_len] + + run_model( + model_name=model_name, + token=HF_TOKEN, + query=query, + kv_offload=kv_offload, + image_url=image_url, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + generation_len=generation_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + ) + + +""" +Expected Response: + +This image depicts a charming anthropomorphic rabbit standing on a dirt path in front of a picturesque stone cottage, surrounded by a serene landscape. + +The rabbit, with its light brown fur and distinctive long ears, is attired in a stylish blue coat, brown vest, and tan pants, exuding a sense of sophistication. The dirt path, flanked by vibrant flowers and lush greenery, leads to the cottage, which features a thatched roof and a chimney, adding to the rustic charm of the scene. In the background, rolling hills and trees create a breathtaking panorama, while the sky above is a brilliant blue with white clouds, completing the + +""" diff --git a/examples/ccl_llama4_CB_example_vision_lang.py b/examples/ccl_llama4_CB_example_vision_lang.py new file mode 100644 index 000000000..6423ee765 --- /dev/null +++ b/examples/ccl_llama4_CB_example_vision_lang.py @@ -0,0 +1,109 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import transformers +from transformers import AutoConfig, AutoProcessor + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +config.text_config.num_hidden_layers = 4 +config.vision_config.num_hidden_layers = 2 + +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +ctx_len = 4096 +# Set the list of ccl during prefilling process +comp_ctx_lengths_prefill = [3072] +# Set the list of ccl during decoding process +comp_ctx_lengths_decode = [ctx_len] + +continious_batching = True +if continious_batching: + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, + ) + + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + batch_size=1, + full_batch_size=4, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) +else: + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, + ) + + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + batch_size=1, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + +image_urls = [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", +] + +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +exec_info = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + device_ids=[32, 33, 34, 35], + generation_len=100, +) + +# print("Generated texts:", exec_info.generated_texts) +print("Generated IDs:", exec_info.generated_ids) +print(exec_info) diff --git a/examples/ccl_llama4_example.py b/examples/ccl_llama4_example.py new file mode 100644 index 000000000..5da29960f --- /dev/null +++ b/examples/ccl_llama4_example.py @@ -0,0 +1,128 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import transformers +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +config.text_config.num_hidden_layers = 4 +config.vision_config.num_hidden_layers = 2 + +ctx_len = 8192 +# Set the list of ccl during prefilling process +comp_ctx_lengths_prefill = [3072] +# Set the list of ccl during decoding process +comp_ctx_lengths_decode = [4096, ctx_len] + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, + config=config, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +### use skip_vision=Ture, if want to run only text, ow false ### +skip_vision = False + +if skip_vision: + ## Only Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + ) + + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Can you describe the image in detail.", + }, + ], + }, + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3], generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + ## Vision + Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + + ### IMAGE + TEXT ### + image_url = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": image_url}, + {"type": "text", "text": "Can you describe the image in detail."}, + ], + }, + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3], generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + print() diff --git a/examples/ccl_llama4_multi_image_example.py b/examples/ccl_llama4_multi_image_example.py new file mode 100644 index 000000000..33bf07df0 --- /dev/null +++ b/examples/ccl_llama4_multi_image_example.py @@ -0,0 +1,89 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import transformers +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +config.text_config.num_hidden_layers = 4 +config.vision_config.num_hidden_layers = 2 + +ctx_len = 8192 +# Set the list of ccl during prefilling process +comp_ctx_lengths_prefill = [5376] +# Set the list of ccl during decoding process +comp_ctx_lengths_decode = [6144, ctx_len] + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +### For multi-image, the value of max_num_tiles should be the sum of the num_tiles values across all the images ### +qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=34, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, +) + +### Multi_image Prompt ### +image_url_1 = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" +) + + +image_url_2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" + +messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": image_url_1}, + {"type": "image", "url": image_url_2}, + { + "type": "text", + "text": "Analyze the key elements, colors, and objects in the two images. Discuss their similarities, differences, and how they complement or contrast each other. Reflect on the emotions or ideas they convey, considering the context, light, shadow, and composition.", + }, + ], + }, +] + +inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", +) + +inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) +streamer = TextStreamer(tokenizer) +output = qeff_model.generate(inputs=inputs, device_ids=[32, 33, 34, 35], generation_len=100) +print(output.generated_ids) +print(tokenizer.batch_decode(output.generated_ids)) +print(output) diff --git a/examples/ccl_mistral3_example.py b/examples/ccl_mistral3_example.py new file mode 100644 index 000000000..96ed519f5 --- /dev/null +++ b/examples/ccl_mistral3_example.py @@ -0,0 +1,123 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +from PIL import Image +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + + +def run_model( + model_name, + query, + image_url, + kv_offload=False, + prefill_seq_len=128, + ctx_len=4096, + comp_ctx_lengths_prefill=None, + comp_ctx_lengths_decode=None, + generation_len=128, + img_size=1540, + num_cores=16, + num_devices=4, +): + ## STEP - 1 Load the Processor and Model + + processor = AutoProcessor.from_pretrained(model_name) + + # `kv_offload` is used to compile the model in a 2 QPCs.Currently we are not supporting 1 qpc so the flag false is not allowed. + # The `kv_offload` flag should always be set to True. + # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. + # The outputs of the Vision Encoder are then passed to the Language model via host in this case. + + config = AutoConfig.from_pretrained(model_name) + config.vision_config._attn_implementation = "eager" + + model = QEFFAutoModelForImageTextToText.from_pretrained( + model_name, + kv_offload=kv_offload, + config=config, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, + ) + + ## STEP - 2 Export & Compile the Model + + model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=False, + ) + + ## STEP - 3 Load and process the inputs for Inference + + # We are resizing the image to (w x h) (1540 x 1540) so that any image can work on the model irrespective of image dimensssions + # we have a fixed size of height 1540 and width 1540 as defined in the config + + image = Image.open(requests.get(image_url, stream=True).raw) + image = image.resize((1540, 1540)) + + messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": query}]}] + input_text = processor.apply_chat_template(messages, add_generation_prompt=True) + inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt") + + ## STEP - 4 Run Inference on the compiled model + + streamer = TextStreamer(processor.tokenizer) + output = model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) + print(output) + + +if __name__ == "__main__": + # Model name and Input parameters + model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" + + # Please add prompt here + query = "Describe the image" + + # Please pass image url or image path .The format of the image should be jpg. + image_url = "https://www.ilankelman.org/stopsigns/australia.jpg" + + # Compilation parameters for the model + kv_offload = True + prefill_seq_len = 128 + ctx_len = 8192 + generation_len = 128 + num_cores = 16 + num_devices = 4 + comp_ctx_lengths_prefill = [4096] + comp_ctx_lengths_decode = [6144, ctx_len] + + run_model( + model_name=model_name, + query=query, + kv_offload=kv_offload, + image_url=image_url, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + generation_len=generation_len, + num_cores=num_cores, + num_devices=num_devices, + ) + + +""" +Expected Response: +The image depicts a street scene in what appears to be a Chinatown district. The focal point is a traditional Chinese archway, known as a paifang, which is intricately designed with red columns and ornate details. The archway features Chinese characters at the top, which translate to "Chinatown Gate." +In the foreground, there is a red stop sign mounted on a pole. The street is relatively quiet, with a single dark-colored SUV driving through the archway. On either side of the archway, there are stone lion statues, which are common decorative elements in Chinese architecture and symbolize protection. + + +""" diff --git a/examples/ccl_molmo_example.py b/examples/ccl_molmo_example.py new file mode 100644 index 000000000..dd09fa020 --- /dev/null +++ b/examples/ccl_molmo_example.py @@ -0,0 +1,100 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +import torch +import transformers +from PIL import Image +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM + +model_id = "allenai/Molmo-7B-D-0924" +config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) + +# config.num_hidden_layers = 2 + +# load the model +ctx_len = 8192 +comp_ctx_lengths_prefill = [3072] +comp_ctx_lengths_decode = [4096, 8192] + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, + kv_offload=True, + trust_remote_code=True, + config=config, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +### use skip_vision=Ture, if want to run only text, ow false ### +skip_vision = False + +if skip_vision: + ## Only Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + num_cores=16, + num_devices=4, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + ) + + inputs = processor.process(text="Tell me about yourself") + inputs = {k: v.unsqueeze(0) for k, v in inputs.items()} + inputs["input_ids"] = inputs["input_ids"].to(torch.int64) + inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3], generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + ## Vision + Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + num_cores=16, + num_devices=4, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + + ### IMAGE + TEXT ### + image_url = "https://picsum.photos/id/237/536/354" + + image = Image.open(requests.get(image_url, stream=True).raw) + image = image.resize((536, 354)) + + inputs = processor.process(images=[image], text="Can you describe the image in detail.") + + inputs = {k: v.unsqueeze(0) for k, v in inputs.items()} + inputs["pixel_values"] = inputs.pop("images") + inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64) + + valid = inputs["image_input_idx"] > 0 + valid = valid.reshape(1, -1) + inputs["valid_idx"] = torch.nonzero(valid)[:, 1].unsqueeze(0) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3], generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + print() diff --git a/examples/ccl_qwen2_5_vl_CB.py b/examples/ccl_qwen2_5_vl_CB.py new file mode 100644 index 000000000..6954d356f --- /dev/null +++ b/examples/ccl_qwen2_5_vl_CB.py @@ -0,0 +1,81 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +# If we want to enable QBlocking Run below command:, default is without blocking +# ATTENTION_BLOCKING_MODE=q num_q_blocks=2 python -W ignore qwen2_5_vl_example.py + +import transformers +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +## For AWQ model update pytorch version to 2.8.* +model_id = "Qwen/Qwen2.5-VL-32B-Instruct" +config = AutoConfig.from_pretrained(model_id) +# config.text_config.num_hidden_layers = 2 + +ctx_len = 8192 +comp_ctx_lengths_prefill = [4096] +comp_ctx_lengths_decode = [6144, ctx_len] + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +batch_size = 1 +## Vision + Text ## +qeff_model.compile( + batch_size=batch_size, + full_batch_size=4, + prefill_seq_len=128, + ctx_len=8192, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, +) + +image_urls = [ + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", +] + +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +streamer = TextStreamer(tokenizer) +output = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + generation_len=100, +) +print(output.generated_ids) +print(tokenizer.batch_decode(output.generated_ids)) +print(output) diff --git a/examples/ccl_qwen2_5_vl_example.py b/examples/ccl_qwen2_5_vl_example.py new file mode 100644 index 000000000..273a18361 --- /dev/null +++ b/examples/ccl_qwen2_5_vl_example.py @@ -0,0 +1,152 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +# If we want to enable QBlocking Run below command:, default is without blocking +# ATTENTION_BLOCKING_MODE=q num_q_blocks=2 python -W ignore qwen2_5_vl_example.py + +import requests +import transformers +from PIL import Image +from qwen_vl_utils import process_vision_info +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +## For AWQ model update pytorch version to 2.8.* +model_id = "Qwen/Qwen2.5-VL-32B-Instruct" +config = AutoConfig.from_pretrained(model_id) +# config.text_config.num_hidden_layers = 2 + +ctx_len = 8192 +comp_ctx_lengths_prefill = [4096] +comp_ctx_lengths_decode = [6144, ctx_len] + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +### use skip_vision=Ture, if want to run only text, ow false ### +skip_vision = False + +if skip_vision: + ## Only Text ## + + ## Set Batch_Size ## + batch_size = 1 + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=ctx_len, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=False, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Tell me about yourself."}, + ], + }, + ] + + messages = [messages] * batch_size + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100, device_ids=[0, 1, 2, 3]) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + batch_size = 1 + ## Vision + Text ## + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=ctx_len, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + + ### IMAGE + TEXT ### + image_url = "https://picsum.photos/id/237/536/354" + + image = Image.open(requests.get(image_url, stream=True).raw) + + messages_1 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Describe this image."}, + ], + }, + ] + + messages_2 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Describe about the color of the dog."}, + ], + }, + ] + + messages = [messages_2] * batch_size + + texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages] + + image_inputs, video_inputs = process_vision_info(messages) + inputs = processor( + text=texts, + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100, device_ids=[0, 1, 2, 3]) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) diff --git a/examples/compute_context_length.py b/examples/compute_context_length.py new file mode 100644 index 000000000..163261e04 --- /dev/null +++ b/examples/compute_context_length.py @@ -0,0 +1,70 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +## In this example, you can run a model for static and continuous batching with different Compute-Context-Length (CCL) inputs. ## + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM + +## Using optional variable comp_ctx_lengths variable you can pass a list of context lengths for both prefilling and decoding processes. It will run the model with default context length if comp_ctx_lengths=None. ## +## - The first comp_ctx_lengths_prefill list shows the compute-ctx-length list for prefilling process. It will start the prefilling process with the first element in the list and gradually will increase the comp_ctx_lengths based on the position_id of the current prompt chunk. ## +## - The second comp_ctx_lengths_decode list will be used for decoding. During the decoding process, based on the position_id or cache index it will work with the specific compute-context-length in the list. It will start from a proper compute-context-length in the list based on input prompt length and will gradually increase the compute-context-length if the cache index passes the current compute-context-length. ## + +ctx_len = 1024 +comp_ctx_lengths_prefill = [256, 500] # None +comp_ctx_lengths_decode = [512, ctx_len] # None + +model_name = "meta-llama/Llama-3.2-1B" +# model_name = "google/gemma-7b" +# model_name = "tiiuae/falcon-7b-instruct" +# model_name = "google/gemma-2-2b" +# model_name = "ibm-granite/granite-3.1-8b-instruct" +# model_name = "Snowflake/Llama-3.1-SwiftKV-8B-Instruct" +# model_name = "mistralai/Mistral-7B-v0.1" +# model_name = "microsoft/phi-1_5" +# model_name = "microsoft/Phi-3-mini-4k-instruct" +# model_name = "Qwen/Qwen2.5-7B-Instruct" +# model_name = "Qwen/Qwen3-1.7B" +# model_name = "allenai/OLMo-2-0425-1B" +# model_name = "ibm-granite/granite-3.3-2b-base" +# model_name = "ibm-granite/granite-3.2-8b-instruct" +# model_name = "meta-llama/Llama-3.3-70B-Instruct" +# model_name = "Salesforce/codegen-350M-mono" +# model_name = "openai-community/gpt2" +# model_name = "EleutherAI/gpt-j-6b" + +model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + continuous_batching=True, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, # Is required for CCL checkings + }, +) + +# model compilation for either continuous or static batching. For continuous batching full_batch_size is needed. +model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + num_cores=16, + num_devices=1, + mxint8_kv_cache=True, + mxfp6_matmul=True, + full_batch_size=1, +) + +# Create tokenizer and run model.generate and passes the input prompts to it. +tokenizer = AutoTokenizer.from_pretrained(model_name) +model.generate( + prompts=[ + "My name is ", + ], + tokenizer=tokenizer, + generation_len=128, +) diff --git a/examples/gemma3_example/ccl_gemma3_mm.py b/examples/gemma3_example/ccl_gemma3_mm.py new file mode 100644 index 000000000..9bf6e9c5a --- /dev/null +++ b/examples/gemma3_example/ccl_gemma3_mm.py @@ -0,0 +1,121 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import transformers +from transformers import AutoConfig, AutoProcessor + +from QEfficient import QEFFAutoModelForImageTextToText + +# Change model_id to "google/gemma-3-27b-it" for 27B model +model_id = "google/gemma-3-4b-it" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +# config.text_config.num_hidden_layers = 1 +# config.vision_config.num_hidden_layers = 2 +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) +processor = AutoProcessor.from_pretrained(model_id) + +# pass HF_TOKEN if gated model +# For running the model in single QPC approach use kv_offload=False. For Dual QPC approach use kv_offload=True ### +ctx_len = 8192 +comp_ctx_lengths_prefill = [3072] +comp_ctx_lengths_decode = [4096, ctx_len] + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + config=config, + attn_implementation="eager", + kv_offload=True, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, +) + +### use skip_vision=Ture, if want to run only text, or false ### +skip_vision = False + +if skip_vision: + ## Only Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=896, + num_cores=16, + num_devices=4, + mxfp6_matmul=False, + mxint8_kv_cache=False, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + node_precision_info="examples/gemma3_example/fp32_nodes_gemma3_27b.yaml", + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe the transformers architecture in LLMs."}, + ], + }, + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + ## Vision + Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=896, + num_cores=16, + num_devices=4, + mxfp6_matmul=False, + mxint8_kv_cache=False, + aic_enable_depth_first=True, + mos=1, + node_precision_info="examples/gemma3_example/fp32_nodes_gemma3_27b.yaml", + ) + + ### IMAGE + TEXT ### + image_url = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": image_url}, + {"type": "text", "text": "Can you describe the image in detail."}, + ], + }, + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) diff --git a/examples/granite_example/ccl_granite_vision_inference.py b/examples/granite_example/ccl_granite_vision_inference.py new file mode 100644 index 000000000..64ecaf948 --- /dev/null +++ b/examples/granite_example/ccl_granite_vision_inference.py @@ -0,0 +1,129 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +from PIL import Image +from transformers import AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +# Add HuggingFace Token to access the model +HF_TOKEN = "" + + +def run_model( + model_name, + token, + query, + image_url, + kv_offload=False, + prefill_seq_len=5500, + ctx_len=6000, + comp_ctx_lengths_prefill=None, + comp_ctx_lengths_decode=None, + generation_len=128, + img_size=384, + num_cores=16, + num_devices=1, +): + ## STEP - 1 Load the Processor and Model + + processor = AutoProcessor.from_pretrained(model_name, token=token) + + # `kv_offload` is used to compile the model in a 2 QPCs.Currently we are not supporting 1 qpc so the flag false is not allowed. + # The `kv_offload` flag should always be set to True. + # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. + # The outputs of the Vision Encoder are then passed to the Language model via host in this case. + + model = QEFFAutoModelForImageTextToText.from_pretrained( + model_name, + token=token, + kv_offload=kv_offload, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, + ) + + ## STEP - 2 Export & Compile the Model + + model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=False, + ) + + ## STEP - 3 Load and process the inputs for Inference + + # We are resizing the image to (w x h) (1610 x 1109) so that any image can work on the model irrespective of image dimensssions + # we have a fixed size of height 1109 and width 1610 + + image = Image.open(requests.get(image_url, stream=True).raw) + image = image.resize((1610, 1109)) + + messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": query}]}] + input_text = processor.apply_chat_template(messages, add_generation_prompt=True) + inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt") + + ## STEP - 4 Run Inference on the compiled model + + streamer = TextStreamer(processor.tokenizer) + output = model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) + print(output) + + +if __name__ == "__main__": + # Model name and Input parameters + model_name = "ibm-granite/granite-vision-3.2-2b" + + # Please add prompt here + query = "Describe the image" + + # Please pass image url or image path .The format of the image should be jpg. + image_url = "http://images.cocodataset.org/val2017/000000039769.jpg" + + # Compilation parameters for the model + kv_offload = True + prefill_seq_len = 5500 + ctx_len = 8192 + generation_len = 128 + img_size = 384 + num_cores = 16 + num_devices = 4 + ctx_len = 8192 + comp_ctx_lengths_prefill = [5500] + comp_ctx_lengths_decode = [6144, ctx_len] + + run_model( + model_name=model_name, + token=HF_TOKEN, + query=query, + kv_offload=kv_offload, + image_url=image_url, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + generation_len=generation_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + ) + + +""" +Expected Response: + +The image depicts two cats lying on a pink blanket that is spread out on a red couch. The cats are positioned in a relaxed manner, with their bodies stretched out and their heads resting on the blanket. +The cat on the left is a smaller, tabby cat with a mix of black, gray, and white fur. It has a long, slender body and a distinctive tail that is curled up near its tail end. The cat on the right is a larger, +tabby cat with a mix of gray, black, and brown fur. It has + +""" diff --git a/examples/intern_example/ccl_internvl_inference.py b/examples/intern_example/ccl_internvl_inference.py new file mode 100644 index 000000000..827d50c97 --- /dev/null +++ b/examples/intern_example/ccl_internvl_inference.py @@ -0,0 +1,288 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from io import BytesIO +from typing import List + +import requests +import torch +import torch.nn as nn +import torchvision.transforms as T +from PIL import Image +from torchvision.transforms.functional import InterpolationMode +from transformers import AutoTokenizer, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils.logging_utils import logger + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +# Process the input messages to generate prompt for the model. +def get_prompt(messages) -> str: + """Get the prompt for generation.""" + ## Chat template used for InternVL + system_prompt = "<|im_start|>system\n你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。" + sep = "<|im_end|>\n" + + ret = system_prompt + sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + sep + else: + ret += role + return ret + + +# Processor class for InternVL models +class InternProcessor: + """ + InternVL model only has an AutoTokenizer so this class performs the processing tasks similar to an AutoProcessor. + The methods used here are borrowed from the original InternVL modelling files. + "https://huggingface.co/OpenGVLab/InternVL2_5-1B/" + """ + + def __init__(self, model: nn.Module, tokenizer): + self.model = model + image_size = self.model.config.force_image_size or self.model.config.vision_config.image_size + patch_size = self.model.config.vision_config.patch_size + self.template = model.config.template + self.num_image_token = int((image_size // patch_size) ** 2 * (self.model.config.downsample_ratio**2)) + self.tokenizer = tokenizer + + def build_transform(self, input_size): + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD), + ] + ) + return transform + + def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + ) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + # find the closest aspect ratio to the target + target_aspect_ratio = self.find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + def load_image(self, image, input_size=448, max_num=12): + transform = self.build_transform(input_size=input_size) + images = self.dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values + + def __call__( + self, + pixel_values, + question, + messages, + roles, + history=None, + num_patches_list=None, + IMG_START_TOKEN="", + IMG_END_TOKEN="", + IMG_CONTEXT_TOKEN="", + verbose=False, + ) -> str: + if history is None and pixel_values is not None and "" not in question: + question = "\n" + question + if num_patches_list is None: + num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else [] + assert pixel_values is None or len(pixel_values) == sum(num_patches_list) + img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) + self.model.img_context_token_id = img_context_token_id + + messages.append([roles[0], question]) + messages.append([roles[1], None]) + query = get_prompt(messages) + if verbose and pixel_values is not None: + image_bs = pixel_values.shape[0] + logger.info(f"dynamic ViT batch size: {image_bs}") + + for num_patches in num_patches_list: + image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN + query = query.replace("", image_tokens, 1) + return query + + +def run_intern_on_aic( + model_name, + prompt, + image_url, + messages, + roles, + kv_offload=False, + prefill_seq_len=3840, + num_devices=1, + num_cores=16, +): + ## STEP 1 -- LOAD THE MODEL + + # The original Intern-VL model, despite being multimodal, is loaded using `AutoModelForCausalLM` in Huggingface. + # To maintain compatibility, we load this model using `QEFFAutoModelForCausalLM`. + + ctx_len = 8192 + comp_ctx_lengths_prefill = [4096] + comp_ctx_lengths_decode = [6144, ctx_len] + + # model = QEFFAutoModelForCausalLM.from_pretrained(model_name, kv_offload=kv_offload, trust_remote_code=True) + + model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + kv_offload=kv_offload, + trust_remote_code=True, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, + ) + + ## STEP 2 -- EXPORT & COMPILE THE MODEL + + model.compile( + num_cores=num_cores, + num_devices=num_devices, + ctx_len=ctx_len, + prefill_seq_len=prefill_seq_len, + mxfp6_matmul=False, + ) + + ## STEP 3 -- SETUP THE PROCESSOR + + # InternVL doesn't have an AutoProcessor yet, so we will use our own processor class "InternProcessor" + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False) + internProcessor = InternProcessor(model.model, tokenizer) + + ## STEP 4 -- PREPROCESS THE INPUTS + + img = requests.get(image_url, stream=True) + image = Image.open(BytesIO(img.content)).convert("RGB") + + # Images are resized to (1000, 747) for inference + image = image.resize((1000, 747)) + + # preprocess the resized image + pixel_values = internProcessor.load_image(image, max_num=12) + question = "\n" + prompt + query = internProcessor(pixel_values, question, messages, roles) + inputs = tokenizer( + query, return_tensors="pt", padding="max_length", max_length=prefill_seq_len, padding_side="right" + ) + + inputs["pixel_values"] = pixel_values + + ## STEP 5 -- RUN INFERENCE VIA GENERATE FUNCTION + streamer = TextStreamer(tokenizer) + model.generate(inputs=inputs, streamer=streamer, generation_len=128) + + +if __name__ == "__main__": + model_name = "OpenGVLab/InternVL2_5-1B" + + # Chat Template information for prompt preprocessing + messages: List[List[str]] = [] + roles = ("<|im_start|>user\n", "<|im_start|>assistant\n") + + # Inputs for the model + prompt = "Please describe the image in detail." + image_url = "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg" + + ## Compilation parameters + + # `kv_offload` is used to compile the model in a Single QPC or 2 QPCs. + # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. + # The outputs of the Vision Encoder are then passed to the Language model via host in this case. + + kv_offload = True + + # InternVL is an Early-Fusion model that uses placeholder tokens within the input_ids to interleave text_embeddings with + # Image embeddings and generate final input_embeds for outout generation. Hence we need very large prefill_seq_len (3840 in this case) to + # incorporate the memory for the merged embeddings. + + prefill_seq_len = 3840 + num_devices = 4 + num_cores = 16 + + run_intern_on_aic( + model_name=model_name, + prompt=prompt, + image_url=image_url, + messages=messages, + roles=roles, + kv_offload=kv_offload, + prefill_seq_len=prefill_seq_len, + num_devices=num_devices, + num_cores=num_cores, + ) + + +""" +Expected Response: + +The image is a promotional graphic for Microsoft Azure. It features a blue background with a hexagonal pattern on the left side. The hexagons are white and are arranged in a way that suggests a network or connectivity theme. + +On the right side of the image, the Microsoft Azure logo is prominently displayed. The logo consists of the Azure name in white, with the Microsoft logo above it, which includes four colored squares (blue, green, yellow, and red). Below the logo, the word "Azure" is written in large white letters. + +Below the logo, there is text that reads: +- "By Dinesh Kumar Wick +""" diff --git a/examples/qwen3moe_example/ccl_qwen3moe_inference.py b/examples/qwen3moe_example/ccl_qwen3moe_inference.py new file mode 100644 index 000000000..d2fa208df --- /dev/null +++ b/examples/qwen3moe_example/ccl_qwen3moe_inference.py @@ -0,0 +1,48 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils.constants import Constants + +model_name = "Qwen/Qwen3-30B-A3B-Instruct-2507" +""" +# For CB inference, set continuous_batching to True and add full_batch_size,mxfp6,mint8 argument in compile function +# We will use prompt_len=1 for compilation for both cb and non-cb inference +""" + +ctx_len = 1024 +prefill_seq_len = 1 +# In moe models when compiling with prefill_seq_len=1 and non-continuous-batching mode, prefill and decode will share the same specializations. +comp_ctx_lengths_prefill = [256, 512, ctx_len] +comp_ctx_lengths_decode = [256, 512, ctx_len] + +model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + continuous_batching=False, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + "prefill_seq_len": prefill_seq_len, + }, +) + +model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + batch_size=1, + num_cores=16, + num_devices=4, + mxfp6_matmul=True, + mxint8_kv_cache=True, + mos=1, +) +# mos=1, +tokenizer = AutoTokenizer.from_pretrained(model_name) +exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer)