Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions tools/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,37 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
In the case of guard failures due to some PyTorch kernel implementations, we also
try to re-export the graph by expressing them as runtime assert nodes
"""
print(
f"Exporting model with min_seq_len={min_seq_len} and max_seq_len={max_seq_len}"
)
with torch.no_grad():
        # max=1024 has constraint violation error. https://github.com/pytorch/pytorch/issues/125604
seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len)
position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device)
try:
print("Trying to export the model using torch.export.export()..")
print("Trying to export the model using torch.export._trace._export()..")
# strict=False only enables aotautograd tracing and excludes dynamo.
ep = torch.export.export(
ep = torch.export._trace._export(
model,
args=(inputs,),
kwargs={"position_ids": position_ids},
dynamic_shapes=({1: seq_len}, {1: seq_len}),
strict=False,
allow_complex_guards_as_runtime_asserts=True,
)

except:
print(
"Trying torch.export._trace._export to trace the graph since torch.export.export() failed"
"Trying torch.export.export to trace the graph since torch.export._trace._export() failed"
)
# This API is used to express the constraint violation guards as asserts in the graph.
ep = torch.export._trace._export(

ep = torch.export.export(
model,
args=(inputs,),
kwargs={"position_ids": position_ids},
dynamic_shapes=({1: seq_len}, {1: seq_len}),
strict=False,
prefer_deferred_runtime_asserts_over_guards=True,
)

return ep
Expand Down Expand Up @@ -223,6 +228,7 @@ def time_generate(
"""
timings = []
for _ in range(iterations):
print(f"Iteration {_} of {iterations}")
start_time = timeit.default_timer()
_ = generate_fn(model, inputs, output_seq_length, eos_token_id)
torch.cuda.synchronize()
Expand Down
Loading