diff --git a/tools/llm/utils.py b/tools/llm/utils.py
index b9e3506f4b..f30ee3cf48 100644
--- a/tools/llm/utils.py
+++ b/tools/llm/utils.py
@@ -16,32 +16,37 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
     In the case of guard failures due to some PyTorch kernel implements, we
     also try to re-export the graph by expressing them as runtime assert nodes
     """
+    print(
+        f"Exporting model with min_seq_len={min_seq_len} and max_seq_len={max_seq_len}"
+    )
     with torch.no_grad():
         # max=1024 has contraint violation error. https://github.com/pytorch/pytorch/issues/125604
         seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len)
         position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device)
         try:
-            print("Trying to export the model using torch.export.export()..")
+            print("Trying to export the model using torch.export._trace._export()..")
             # strict=False only enables aotautograd tracing and excludes dynamo.
-            ep = torch.export.export(
+            ep = torch.export._trace._export(
                 model,
                 args=(inputs,),
                 kwargs={"position_ids": position_ids},
                 dynamic_shapes=({1: seq_len}, {1: seq_len}),
                 strict=False,
+                allow_complex_guards_as_runtime_asserts=True,
             )
+
         except:
             print(
-                "Trying torch.export._trace._export to trace the graph since torch.export.export() failed"
+                "Trying torch.export.export to trace the graph since torch.export._trace._export() failed"
             )
             # This API is used to express the constraint violation guards as asserts in the graph.
-            ep = torch.export._trace._export(
+
+            ep = torch.export.export(
                 model,
                 args=(inputs,),
                 kwargs={"position_ids": position_ids},
                 dynamic_shapes=({1: seq_len}, {1: seq_len}),
                 strict=False,
-                prefer_deferred_runtime_asserts_over_guards=True,
             )

     return ep
@@ -223,6 +228,7 @@ def time_generate(
     """
     timings = []
     for _ in range(iterations):
+        print(f"Iteration {_} of {iterations}")
         start_time = timeit.default_timer()
         _ = generate_fn(model, inputs, output_seq_length, eos_token_id)
         torch.cuda.synchronize()