12 changes: 12 additions & 0 deletions QEfficient/cloud/infer.py
@@ -340,6 +340,18 @@ def main(
"--prompt-len", "--prompt_len", default=32, type=int, help="Sequence length for text generation."
)
parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int, help="Context length for text generation.")
parser.add_argument(
"--comp-ctx-lengths-prefill",
type=lambda comp_ctx_lengths_prefill: [int(x) for x in comp_ctx_lengths_prefill.split(",")],
default=[512],
help="Define ccl list in csv format (e.g.,--comp-ctx-lengths 512,1024,2048).",
)
parser.add_argument(
"--comp-ctx-lengths-decode",
type=lambda comp_ctx_lengths_decode: [int(x) for x in comp_ctx_lengths_decode.split(",")],
default=[2048],
help="Define ccl list in csv format (e.g.,--comp-ctx-lengths 512,1024,2048).",
)
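For reference, a minimal standalone sketch (not part of this diff) of how the lambda-based `type` above turns the CSV flag values into integer lists; the flag values here are made up for illustration:

```python
# Minimal sketch of the CSV-to-list parsing used by the new flags (illustrative only).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--comp-ctx-lengths-prefill",
    type=lambda s: [int(x) for x in s.split(",")],
    default=[512],
)
parser.add_argument(
    "--comp-ctx-lengths-decode",
    type=lambda s: [int(x) for x in s.split(",")],
    default=[2048],
)

args = parser.parse_args(
    ["--comp-ctx-lengths-prefill", "256,512", "--comp-ctx-lengths-decode", "512,1024,2048"]
)
print(args.comp_ctx_lengths_prefill)  # [256, 512]
print(args.comp_ctx_lengths_decode)   # [512, 1024, 2048]
```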
parser.add_argument(
"--mxfp6",
"--mxfp6_matmul",
16 changes: 11 additions & 5 deletions QEfficient/customop/ctx_scatter_gather.py
@@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
def CtxGather(
data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
) -> onnxscript.FLOAT:
# Create a shape tensor based on comp_ctx_len
shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0)

# Directly use the shape tensor without validation
ctx_indices = ops.Expand(ctx_indices, shape_tensor)
ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
return ops.GatherND(data, ctx_indices, batch_dims=2)

@@ -127,7 +133,7 @@ class CtxGatherFunc(torch.autograd.Function):
"""

@staticmethod
def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
return data[batch_indices, head_indices, ctx_indices]
@@ -137,5 +143,5 @@ def setup_context(ctx, inputs, outputs):
pass

@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value:
return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data)
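As an aside, a small eager-mode sketch of the gather that `CtxGatherFunc.forward` performs (assumed shapes; in the exported ONNX graph the new `comp_ctx_len` input bounds how many positions the indices are expanded to):

```python
# Standalone sketch of the advanced-indexing gather in CtxGatherFunc.forward (illustrative only).
import torch

batch, heads, ctx_len, head_dim = 2, 4, 8, 16
comp_ctx_len = 4  # assumed compute-context-length; smaller than the allocated ctx_len

data = torch.randn(batch, heads, ctx_len, head_dim)      # e.g. one KV-cache tensor
ctx_indices = torch.arange(comp_ctx_len).view(1, 1, -1)  # only the first comp_ctx_len positions

batch_indices = torch.arange(batch).view(-1, 1, 1)
head_indices = torch.arange(heads).view(1, -1, 1)
gathered = data[batch_indices, head_indices, ctx_indices]

print(gathered.shape)  # torch.Size([2, 4, 4, 16]): attention only sees comp_ctx_len positions
```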
18 changes: 12 additions & 6 deletions QEfficient/customop/ctx_scatter_gather_cb.py
@@ -97,16 +97,20 @@ def symbolic(

@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGatherCB(
data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
) -> onnxscript.FLOAT:
batch_size = ops.Gather(ops.Shape(batch_index), [0])
num_heads = ops.Gather(ops.Shape(data), [1])
ctx_len = ops.Gather(ops.Shape(data), [2])
# Use the compute-context-length (CCL) instead of the full context length, so both the gather and the subsequent attention computation are bounded by CCL.
ctx_len = ops.Reshape(comp_ctx_len, [1])

# Expanded shape to create indices
zero = ops.Constant(value_ints=[0])
one = ops.Constant(value_ints=[1])
exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)
# exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)
exp_shape = ops.Concat(
ops.Reshape(batch_size, [1]), ops.Reshape(num_heads, [1]), ops.Reshape(ctx_len, [1]), one, axis=0
)

# Create indices
batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape)
@@ -119,7 +123,7 @@ def CtxGatherCB(

class CtxGatherFuncCB(torch.autograd.Function):
@staticmethod
def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor):
def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
batch_indices = batch_index.view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
return data[batch_indices, head_indices, ctx_indices]
@@ -129,8 +133,10 @@ def setup_context(ctx, inputs, outputs):
pass

@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value:
return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data)
def symbolic(
g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int
) -> torch.Value:
return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data)
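A rough NumPy analogue of the `CtxGatherCB` gather for the continuous-batching case (a sketch with assumed shapes): `batch_index` selects rows of the shared cache, and the expanded index shape is `[batch_size, num_heads, comp_ctx_len, 1]`, so only `comp_ctx_len` positions are read:

```python
# Rough NumPy analogue of the CtxGatherCB gather (illustrative only).
import numpy as np

full_batch, heads, ctx_len, head_dim = 4, 2, 8, 16
comp_ctx_len = 4  # assumed compute-context-length

data = np.random.randn(full_batch, heads, ctx_len, head_dim)  # shared KV cache
batch_index = np.array([[2], [0]])                            # cache rows active this step
ctx_indices = np.arange(comp_ctx_len).reshape(1, 1, -1)       # bounded by comp_ctx_len, not ctx_len

batch_indices = batch_index.reshape(-1, 1, 1)
head_indices = np.arange(heads).reshape(1, -1, 1)
gathered = data[batch_indices, head_indices, ctx_indices]

print(gathered.shape)  # (2, 2, 4, 16): [active batches, heads, comp_ctx_len, head_dim]
```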


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
76 changes: 76 additions & 0 deletions QEfficient/generation/text_generation_inference.py
@@ -318,6 +318,8 @@ def cloud_ai_100_exec_kv(
prompts_txt_file_path: Optional[str] = None,
device_id: Optional[List[int]] = None,
generation_len: Optional[int] = None,
comp_ctx_lengths_prefill: Optional[List[int]] = None,
comp_ctx_lengths_decode: Optional[List[int]] = None,
enable_debug_logs: bool = False,
stream: bool = True,
write_io_dir: Optional[str] = None,
@@ -384,6 +386,8 @@ def cloud_ai_100_exec_kv(
qpc_path=qpc_path,
device_id=device_id,
ctx_len=ctx_len,
comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
comp_ctx_lengths_decode=comp_ctx_lengths_decode,
enable_debug_logs=enable_debug_logs,
write_io_dir=write_io_dir,
full_batch_size=full_batch_size,
@@ -430,6 +434,8 @@ def __init__(
qpc_path: str,
full_batch_size: Optional[int] = None,
ctx_len: Optional[int] = None,
comp_ctx_lengths_prefill: Optional[List[int]] = None,
comp_ctx_lengths_decode: Optional[List[int]] = None,
device_id: Optional[List[int]] = None,
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
@@ -440,6 +446,8 @@ def __init__(
activate: bool = True,
) -> None:
self._ctx_len = ctx_len
self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
self._write_io_dir = write_io_dir
self.is_tlm = is_tlm
self.return_pdfs = return_pdfs
@@ -802,7 +810,17 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)

if self.comp_ctx_lengths_prefill is not None:
self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
prefill_ccl_id = 0
inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]

for i in range(num_chunks):
if self.comp_ctx_lengths_prefill is not None:
if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]:
prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1)
inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]

chunk_inputs = inputs.copy()
chunk_inputs["input_ids"] = inputs["input_ids"][
:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
generation_len,
)

def initialize_ccl(self, decode_inputs):
self.list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode]
max_ccl_id = len(self.comp_ctx_lengths_decode) - 1
max_position_id = np.max(decode_inputs["position_ids"])
ccl_id_initial = 0
ccl_id = ccl_id_initial
for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)):
if max_position_id < self.comp_ctx_lengths_decode[i]:
ccl_id = i
break

return ccl_id, max_ccl_id
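A plain-Python sketch (illustrative, with made-up values) of the bucket selection `initialize_ccl` performs: pick the first configured decode CCL that still covers the largest position id, and also return the index of the last bucket so later steps can clamp against it:

```python
# Sketch of the CCL bucket selection in initialize_ccl (illustrative only).
def select_ccl_id(max_position_id, comp_ctx_lengths_decode):
    ccl_id = 0
    for i, ccl in enumerate(comp_ctx_lengths_decode):
        if max_position_id < ccl:
            ccl_id = i
            break
    max_ccl_id = len(comp_ctx_lengths_decode) - 1
    return ccl_id, max_ccl_id

print(select_ccl_id(300, [512, 1024, 2048]))  # (0, 2): 300 fits in the 512 bucket
print(select_ccl_id(700, [512, 1024, 2048]))  # (1, 2): 700 needs the 1024 bucket
```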

def run_continuous_batching_decode(self, prompt_queue, generation_len):
"""
Runs continuous batching decode for the given prompt queue and generation length.
@@ -853,6 +884,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
# Prepare decode inputs.
decode_inputs = self.prepare_decode_inputs()

if self.comp_ctx_lengths_decode is not None:
ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]

while prompt_queue or current_decode_ongoing.any():
outputs = self._session.run(decode_inputs)

@@ -890,6 +925,20 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
batch_id_map[decode_batch_id]
]

if self.comp_ctx_lengths_decode is not None:
# Recalculate ccl_id based on position_ids
# Determine the maximum value of position_ids across all batch elements
max_position_id = np.max(decode_inputs["position_ids"])

# Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
ccl_id_initial = 0
ccl_id = ccl_id_initial
for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)):
if max_position_id < self.comp_ctx_lengths_decode[i]:
ccl_id = i
break
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]

else:
current_decode_ongoing[decode_batch_id] = False
else:
@@ -902,6 +951,15 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
if self.include_sampler:
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

if self.comp_ctx_lengths_decode is not None:
# Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
if (
decode_inputs["position_ids"][decode_batch_id, -1]
>= self.comp_ctx_lengths_decode[ccl_id] - 1
):
ccl_id = min(ccl_id + 1, max_ccl_id)
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]

generated_id_current_index[decode_batch_id] += 1

return decode_pause_time
@@ -928,7 +986,18 @@ def run_decode(
self._session.set_buffers({"logits": logits_out_placeholder})
finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
num_token = 0

if self.comp_ctx_lengths_decode is not None:
ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]

cache_index = np.max(decode_inputs["position_ids"])
for num_token in range(1, generation_len):
if self.comp_ctx_lengths_decode is not None:
if cache_index >= self.comp_ctx_lengths_decode[ccl_id] - 1:
ccl_id = min(ccl_id + 1, max_ccl_id)
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]

if streamer:
streamer.put(decode_inputs["input_ids"][0])
outputs = self._session.run(decode_inputs)
@@ -940,6 +1009,7 @@ def run_decode(
# Prepare inputs for next iteration
decode_inputs["input_ids"] = self._fetch_next_token_id(outputs)
decode_inputs["position_ids"][:, -1] += 1
cache_index += 1
self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
if self.include_sampler:
@@ -989,6 +1059,8 @@ def __init__(
qpc_path: str,
full_batch_size: Optional[int] = None,
ctx_len: Optional[int] = None,
comp_ctx_lengths_prefill: Optional[List[int]] = None,
comp_ctx_lengths_decode: Optional[List[int]] = None,
device_id: Optional[List[int]] = None,
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
@@ -1002,6 +1074,8 @@ def __init__(
qpc_path=qpc_path,
full_batch_size=full_batch_size,
ctx_len=ctx_len,
comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
comp_ctx_lengths_decode=comp_ctx_lengths_decode,
device_id=device_id,
enable_debug_logs=enable_debug_logs,
write_io_dir=write_io_dir,
@@ -1013,6 +1087,8 @@ def __init__(
self._full_batch_size = self._qaic_model.full_batch_size
self._tokenizer = self._qaic_model.tokenizer
self._ctx_len = ctx_len
self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
self._perf_metrics = None
self._prompt_queue = None
self._text_streamer = None
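For orientation, a hypothetical end-to-end call showing where the new lists plug in; the model name and QPC path are placeholders, and the leading arguments follow the existing `cloud_ai_100_exec_kv` signature, which is only partially visible in this diff:

```python
# Hypothetical usage of the extended cloud_ai_100_exec_kv API (illustrative only).
from transformers import AutoTokenizer
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
cloud_ai_100_exec_kv(
    tokenizer=tokenizer,
    qpc_path="/path/to/qpcs",                   # placeholder path to a compiled QPC
    prompt="My name is",
    generation_len=256,
    comp_ctx_lengths_prefill=[512],             # prefill compute window(s)
    comp_ctx_lengths_decode=[512, 1024, 2048],  # decode window grows with the position ids
)
```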
16 changes: 16 additions & 0 deletions QEfficient/generation/vlm_generation.py
@@ -83,6 +83,8 @@ def __init__(
vision_qpc_path: str,
device_id: Optional[List[int]] = None,
ctx_len: Optional[int] = None,
comp_ctx_lengths_prefill: Optional[List[int]] = None,
comp_ctx_lengths_decode: Optional[List[int]] = None,
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
full_batch_size: Optional[int] = None,
@@ -123,6 +125,8 @@ def __init__(
qpc_path=lang_qpc_path,
full_batch_size=full_batch_size,
ctx_len=ctx_len,
comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
comp_ctx_lengths_decode=comp_ctx_lengths_decode,
device_id=device_id,
enable_debug_logs=enable_debug_logs,
write_io_dir=write_io_dir,
@@ -294,6 +298,11 @@ def _execute_chunked_prefill(
outputs = None
chunk_image_idx = None

if self.comp_ctx_lengths_prefill is not None:
self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
prefill_ccl_id = 0
lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]

for i in range(num_chunks):
input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len]
position_ids_slice = lang_inputs["position_ids"][
@@ -312,6 +321,13 @@ def _execute_chunked_prefill(
if "cross_attention_mask" in lang_inputs:
chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"]

if self.comp_ctx_lengths_prefill is not None:
if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]:
prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1)
lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]

chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"]

outputs = self._session.run(chunk_inputs)

if "image_idx_output" in outputs:
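As a worked example of the chunk-to-CCL bump in `_execute_chunked_prefill` (assumed values, not from the diff): with a prefill sequence length of 128 and prefill CCLs of [256, 512], the first two chunks run with the 256-token window, and the check `(i + 1) * 128 > 256` first fires at chunk 2:

```python
# Worked sketch of the prefill CCL bump (assumed values, illustrative only).
prefill_seq_len = 128
comp_ctx_lengths_prefill = [256, 512]
num_chunks = 6  # e.g. a 768-token prompt

prefill_ccl_id = 0
for i in range(num_chunks):
    if (i + 1) * prefill_seq_len > comp_ctx_lengths_prefill[prefill_ccl_id]:
        prefill_ccl_id = min(prefill_ccl_id + 1, len(comp_ctx_lengths_prefill) - 1)
    print(i, comp_ctx_lengths_prefill[prefill_ccl_id])
# chunks 0-1 -> 256, chunks 2-5 -> 512 (clamped to the last configured CCL)
```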
6 changes: 3 additions & 3 deletions QEfficient/peft/lora/layers.py
@@ -42,15 +42,15 @@ def forward(self, x: torch.Tensor, lora_ids: torch.Tensor):
# multilora implementation: lora_ids <batch_size, 1>
other_indices_a = torch.arange(self.lora_a_weights.shape[2]).view(1, 1, -1)
selected_lora_a_weights = CtxGatherFuncCB.apply(
self.lora_a_weights, lora_ids, other_indices_a
self.lora_a_weights, lora_ids, other_indices_a, self.lora_a_weights.shape[2]
) # <num_loras, 1, feature, r>
other_indices_b = torch.arange(self.lora_b_weights.shape[2]).view(1, 1, -1)
selected_lora_b_weights = CtxGatherFuncCB.apply(
self.lora_b_weights, lora_ids, other_indices_b
self.lora_b_weights, lora_ids, other_indices_b, self.lora_b_weights.shape[2]
) # <num_loras, 1, r, feature>
other_indices_s = torch.arange(self.lora_scalings.shape[2]).view(1, 1, -1)
selected_lora_scalings = CtxGatherFuncCB.apply(
self.lora_scalings, lora_ids, other_indices_s
self.lora_scalings, lora_ids, other_indices_s, self.lora_scalings.shape[2]
) # <num_loras, 1, 1, 1>

selected_lora_a_weights = selected_lora_a_weights.squeeze(1)
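In the LoRA layers the gather is not length-restricted, so the full size of dimension 2 is passed as the new `comp_ctx_len` argument; a small sketch (assumed shapes) of why that reproduces the previous gather-everything behaviour in eager mode:

```python
# Sketch: passing the full dim-2 size keeps the old gather-everything behaviour (illustrative only).
import torch

num_loras, feature, r = 3, 8, 4
lora_a_weights = torch.randn(num_loras, 1, feature, r)

lora_ids = torch.tensor([[1], [2]])                                     # <batch_size, 1>
other_indices_a = torch.arange(lora_a_weights.shape[2]).view(1, 1, -1)  # every position of dim 2

batch_indices = lora_ids.view(-1, 1, 1)
head_indices = torch.arange(lora_a_weights.shape[1]).view(1, -1, 1)
selected = lora_a_weights[batch_indices, head_indices, other_indices_a]

print(selected.shape)  # torch.Size([2, 1, 8, 4]): per-sample adapters, full feature dimension
```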