Skip to content

Conversation

jtuyls
Copy link
Contributor

@jtuyls jtuyls commented Oct 21, 2025

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Oct 21, 2025

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Jorn Tuyls (jtuyls)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/164438.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp (+13)
  • (modified) mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir (+14)
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index 11400de35e430..a15bf891dd596 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -59,6 +59,17 @@ struct DimOpInterface
   }
 };
 
+struct ExpandShapeOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
+                                                   memref::ExpandShapeOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto expandOp = cast<memref::ExpandShapeOp>(op);
+    assert(value == expandOp.getResult() && "invalid value");
+    cstr.bound(value)[dim] == expandOp.getOutputShape()[dim];
+  }
+};
+
 struct GetGlobalOpInterface
     : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
                                                    GetGlobalOp> {
@@ -123,6 +134,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
         memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
     memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
     memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+    memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
+        *ctx);
     memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
     memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
     memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
index 8bd7ae8df9049..ac1f22b68b1e1 100644
--- a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
@@ -63,6 +63,20 @@ func.func @memref_dim_all_positive(%m: memref<?xf32>, %x: index) {
 
 // -----
 
+// CHECK-LABEL: func @memref_expand(
+//  CHECK-SAME:     %[[m:[a-zA-Z0-9]+]]: memref<?xf32>
+//  CHECK-SAME:     %[[sz:[a-zA-Z0-9]+]]: index
+//       CHECK:   %[[c4:.*]] = arith.constant 4 : index
+//       CHECK:   return %[[sz]], %[[c4]]
+func.func @memref_expand(%m: memref<?xf32>, %sz: index) -> (index, index) {
+  %0 = memref.expand_shape %m [[0, 1]] output_shape [%sz, 4]: memref<?xf32> into memref<?x4xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<?x4xf32>) -> (index)
+  %2 = "test.reify_bound"(%0) {dim = 1} : (memref<?x4xf32>) -> (index)
+  return %1, %2 : index, index
+}
+
+// -----
+
 // CHECK-LABEL: func @memref_get_global(
 //       CHECK:   %[[c4:.*]] = arith.constant 4 : index
 //       CHECK:   return %[[c4]]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants