-
Notifications
You must be signed in to change notification settings - Fork 606
Open
Description
PyTorch in-place operations (e.g. `add_`) are lowered to Linalg ops that allocate a fresh output tensor (`tensor.empty`) instead of using destination-passing style (DPS) with the mutated input as the destination. This breaks the in-place contract and causes unnecessary memory allocations. The lowering should instead map in-place ops to ops whose destination (`outs`) operand is the original tensor.
import torch
class SimpleModel(torch.nn.Module):
    """Minimal repro module whose ``forward`` performs an in-place add.

    ``forward`` calls ``Tensor.add_`` on its input, so the input tensor is
    overwritten with ``x + x`` and that same tensor object is returned.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # In-place op under test: x <- x + x; add_ returns x itself.
        return x.add_(x)
# Repro: export the in-place model with torch.export, then lower it to
# linalg-on-tensors IR with torch-mlir's FX importer.
import torch_mlir.fx as fx
from torch.export import export

model = SimpleModel()
input1_tensor = torch.full((32, 32), 2.0)

exported_program = export(model, (input1_tensor,))
# The printed program shows how torch.export represents the in-place add_.
print(exported_program)

print("step3: fx.export_and_import")
with torch.no_grad():
    # NOTE(review): an ExportedProgram is passed here, so the example-input
    # tuple below is presumably unused by export_and_import — confirm.
    module = fx.export_and_import(
        exported_program,
        (input1_tensor,),
        output_type="linalg-on-tensors",
        func_name="forward",
        enable_graph_printing=True,
    )
print(module)
This code generates the following Linalg IR:
// Lowering actually produced by fx.export_and_import (linalg-on-tensors).
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @forward(%arg0: tensor<32x32xf32>) -> (tensor<32x32xf32>, tensor<32x32xf32>) {
// A fresh destination is created with tensor.empty instead of passing %arg0 as the outs operand.
%0 = tensor.empty() : tensor<32x32xf32>
// Elementwise %arg0 + %arg0, written into the new %0 destination.
%1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg0 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%0 : tensor<32x32xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%2 = arith.addf %in, %in_0 : f32
linalg.yield %2 : f32
} -> tensor<32x32xf32>
// The same value is returned twice. Presumably one result represents the mutated input
// and the other the user-visible return of add_ — TODO confirm against the exported graph signature.
return %1, %1 : tensor<32x32xf32>, tensor<32x32xf32>
}
}
However, the semantically correct lowering would be:
// Lowering the reporter expects: destination-passing style with %arg0 as the destination.
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @forward(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> {
// %arg0 is passed as the outs (destination) operand, so no tensor.empty is emitted.
%0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg0 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%arg0: tensor<32x32xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%2 = arith.addf %in, %in_0 : f32
linalg.yield %2 : f32
} -> tensor<32x32xf32>
// A single result, instead of the duplicated pair in the generated IR.
return %0 : tensor<32x32xf32>
}
}
Is there a reason the lowering is generated this way?
Metadata
Metadata
Assignees
Labels
No labels