-
Notifications
You must be signed in to change notification settings - Fork 606
Open
Description
- I am trying to generate a `torch.aten.avg_pool1d` test case with `fx.export_and_import`. To do this, I created a test file `avg_pool1d.py` whose module calls `F.avg_pool1d(x, kernel_size=3, stride=2)`:
def forward(self, x):
return F.avg_pool1d(x, kernel_size=3, stride=2)
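- Actual output (Torch-dialect MLIR): the importer emits a `torch.aten.avg_pool2d` operator instead. It adds a dummy dimension with `torch.aten.unsqueeze` before the pool and removes it with `torch.aten.squeeze.dim` afterwards: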
(py311-source) root@998ee80b761b:/home/zhongyunde/torch-mlir/test/python/fx_importer# python avg_pool1d.py
test_import_frozen_exported_program
-----------------------------------
module {
func.func @main(%arg0: !torch.vtensor<[2,3,10],f32>) -> !torch.vtensor<[2,3,4],f32> {
%int-2 = torch.constant.int -2
%0 = torch.aten.unsqueeze %arg0, %int-2 : !torch.vtensor<[2,3,10],f32>, !torch.int -> !torch.vtensor<[2,3,1,10],f32>
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%1 = torch.prim.ListConstruct %int1, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%int1_0 = torch.constant.int 1
%int2 = torch.constant.int 2
%2 = torch.prim.ListConstruct %int1_0, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%int0 = torch.constant.int 0
%int0_1 = torch.constant.int 0
%3 = torch.prim.ListConstruct %int0, %int0_1 : (!torch.int, !torch.int) -> !torch.list<int>
%false = torch.constant.bool false
%true = torch.constant.bool true
%none = torch.constant.none
%4 = torch.aten.avg_pool2d %0, %1, %2, %3, %false, %true, %none : !torch.vtensor<[2,3,1,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,3,1,4],f32>
%int-2_2 = torch.constant.int -2
%5 = torch.aten.squeeze.dim %4, %int-2_2 : !torch.vtensor<[2,3,1,4],f32>, !torch.int -> !torch.vtensor<[2,3,4],f32>
return %5 : !torch.vtensor<[2,3,4],f32>
}
}
- Expected output: a single `torch.aten.avg_pool1d` operator applied directly to the 3-D input, with no `unsqueeze`/`squeeze` round-trip:
module {
func.func @main(%arg0: !torch.vtensor<[2,3,10],f32>) -> !torch.vtensor<[2,3,4],f32> {
%true = torch.constant.bool true
%false = torch.constant.bool false
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %true : !torch.vtensor<[2,3,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[2,3,4],f32>
return %3 : !torch.vtensor<[2,3,4],f32>
}
}
- torch-mlir version:
(py311-source) root@998ee80b761b:/home/zhongyunde/torch-mlir/test/python/fx_importer# pip list | grep mlir
torch-mlir 20241002.240
Metadata
Metadata
Assignees
Labels
No labels