author    Kunwar Grover <groverkss@gmail.com>    2024-06-12 15:36:16 +0100
committer GitHub <noreply@github.com>            2024-06-12 15:36:16 +0100
commit    57e4360836f421a2c6131de51e3845620c6aea76 (patch)
tree      df39c8527a0f5903891dfac5c089957da5692fab /mlir
parent    3e3b7c70f52fa020557a42a4276b9105d75044a0 (diff)
[mlir][memref] Add memref alias folders for expand/collapse_shape for vector load/store (#95223)
This patch adds patterns to fold memref.expand_shape/memref.collapse_shape aliases feeding into vector.load/vector.store and vector.maskedload/vector.maskedstore.
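
For illustration, a minimal sketch of the expand_shape case, adapted from the new test cases below (the function and value names here are illustrative, not part of the patch). Before folding, a vector.load goes through the expanded view:

func.func @example(%src : memref<32xf32>, %i : index) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %view = memref.expand_shape %src [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
  %v = vector.load %view[%i, %c0] : memref<4x8xf32>, vector<8xf32>
  return %v : vector<8xf32>
}

After applying the patterns from populateFoldMemRefAliasOpPatterns, the load reads the original memref directly, with the index linearized through an affine map (here i * 8 + 0):

#map = affine_map<()[s0] -> (s0 * 8)>
func.func @example(%src : memref<32xf32>, %i : index) -> vector<8xf32> {
  %idx = affine.apply #map()[%i]
  %v = vector.load %src[%idx] : memref<32xf32>, vector<8xf32>
  return %v : vector<8xf32>
}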
Diffstat (limited to 'mlir')
-rw-r--r-- mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp |  90
-rw-r--r-- mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir       | 172
2 files changed, 239 insertions, 23 deletions
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index db085b3..96daf4c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -518,10 +518,25 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
return failure();
llvm::TypeSwitch<Operation *, void>(loadOp)
- .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
- rewriter.replaceOpWithNewOp<decltype(op)>(
+ .Case([&](affine::AffineLoadOp op) {
+ rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
loadOp, expandShapeOp.getViewSource(), sourceIndices);
})
+ .Case([&](memref::LoadOp op) {
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(
+ loadOp, expandShapeOp.getViewSource(), sourceIndices,
+ op.getNontemporal());
+ })
+ .Case([&](vector::LoadOp op) {
+ rewriter.replaceOpWithNewOp<vector::LoadOp>(
+ op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
+ op.getNontemporal());
+ })
+ .Case([&](vector::MaskedLoadOp op) {
+ rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+ op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
+ op.getMask(), op.getPassThru());
+ })
.Default([](Operation *) { llvm_unreachable("unexpected operation."); });
return success();
}
@@ -551,10 +566,25 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
return failure();
llvm::TypeSwitch<Operation *, void>(loadOp)
- .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
- rewriter.replaceOpWithNewOp<decltype(op)>(
+ .Case([&](affine::AffineLoadOp op) {
+ rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
loadOp, collapseShapeOp.getViewSource(), sourceIndices);
})
+ .Case([&](memref::LoadOp op) {
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(
+ loadOp, collapseShapeOp.getViewSource(), sourceIndices,
+ op.getNontemporal());
+ })
+ .Case([&](vector::LoadOp op) {
+ rewriter.replaceOpWithNewOp<vector::LoadOp>(
+ op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
+ op.getNontemporal());
+ })
+ .Case([&](vector::MaskedLoadOp op) {
+ rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+ op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
+ op.getMask(), op.getPassThru());
+ })
.Default([](Operation *) { llvm_unreachable("unexpected operation."); });
return success();
}
@@ -651,10 +681,25 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
return failure();
llvm::TypeSwitch<Operation *, void>(storeOp)
- .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
- rewriter.replaceOpWithNewOp<decltype(op)>(storeOp, storeOp.getValue(),
- expandShapeOp.getViewSource(),
- sourceIndices);
+ .Case([&](affine::AffineStoreOp op) {
+ rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
+ storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
+ sourceIndices);
+ })
+ .Case([&](memref::StoreOp op) {
+ rewriter.replaceOpWithNewOp<memref::StoreOp>(
+ storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
+ sourceIndices, op.getNontemporal());
+ })
+ .Case([&](vector::StoreOp op) {
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ op, op.getValueToStore(), expandShapeOp.getViewSource(),
+ sourceIndices, op.getNontemporal());
+ })
+ .Case([&](vector::MaskedStoreOp op) {
+ rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+ op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
+ op.getValueToStore());
})
.Default([](Operation *) { llvm_unreachable("unexpected operation."); });
return success();
@@ -685,11 +730,26 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
return failure();
llvm::TypeSwitch<Operation *, void>(storeOp)
- .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
- rewriter.replaceOpWithNewOp<decltype(op)>(
- storeOp, storeOp.getValue(), collapseShapeOp.getViewSource(),
+ .Case([&](affine::AffineStoreOp op) {
+ rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
+ storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
sourceIndices);
})
+ .Case([&](memref::StoreOp op) {
+ rewriter.replaceOpWithNewOp<memref::StoreOp>(
+ storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
+ sourceIndices, op.getNontemporal());
+ })
+ .Case([&](vector::StoreOp op) {
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ op, op.getValueToStore(), collapseShapeOp.getViewSource(),
+ sourceIndices, op.getNontemporal());
+ })
+ .Case([&](vector::MaskedStoreOp op) {
+ rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+ op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
+ op.getValueToStore());
+ })
.Default([](Operation *) { llvm_unreachable("unexpected operation."); });
return success();
}
@@ -763,12 +823,20 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
+ LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
+ LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
+ StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
+ StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
+ LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
+ LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
+ StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
+ StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index e49dff4..327cacf 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -473,10 +473,10 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
%c0 = arith.constant 0 : index
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
- %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+ %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
return %0 : f32
}
-// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return %[[VAL1]] : f32
// -----
@@ -487,11 +487,11 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
%c0 = arith.constant 0 : index
%c1f32 = arith.constant 1.0 : f32
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
- memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+ memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
return
}
// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return
// -----
@@ -819,14 +819,14 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
// -----
-func.func @fold_vector_load(
+func.func @fold_vector_load_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
%1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
return %1 : vector<12x32xf32>
}
-// CHECK: func @fold_vector_load
+// CHECK: func @fold_vector_load_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -834,14 +834,14 @@ func.func @fold_vector_load(
// -----
-func.func @fold_vector_maskedload(
+func.func @fold_vector_maskedload_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
%1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
return %1 : vector<32xf32>
}
-// CHECK: func @fold_vector_maskedload
+// CHECK: func @fold_vector_maskedload_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -851,14 +851,14 @@ func.func @fold_vector_maskedload(
// -----
-func.func @fold_vector_store(
+func.func @fold_vector_store_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
return
}
-// CHECK: func @fold_vector_store
+// CHECK: func @fold_vector_store_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -868,14 +868,14 @@ func.func @fold_vector_store(
// -----
-func.func @fold_vector_maskedstore(
+func.func @fold_vector_maskedstore_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32>
return
}
-// CHECK: func @fold_vector_maskedstore
+// CHECK: func @fold_vector_maskedstore_subview
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -883,3 +883,151 @@ func.func @fold_vector_maskedstore(
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
// CHECK: vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
// CHECK: return
+
+// -----
+
+func.func @fold_vector_load_expand_shape(
+ %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ %1 = vector.load %0[%arg1, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>
+ return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_load_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.load %[[ARG0]][%[[IDX]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedload_expand_shape(
+ %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+ return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_maskedload_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.maskedload %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_store_expand_shape(
+ %arg0 : memref<32xf32>, %arg1 : index, %val : vector<8xf32>) {
+ %c0 = arith.constant 0 : index
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ vector.store %val, %0[%arg1, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>
+ return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_store_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.store %{{.*}}, %[[ARG0]][%[[IDX]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedstore_expand_shape(
+ %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+ %c0 = arith.constant 0 : index
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
+ return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_maskedstore_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.maskedstore %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_load_collapse_shape(
+ %arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
+ %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+ %1 = vector.load %0[%arg1] {nontemporal = true} : memref<32xf32>, vector<8xf32>
+ return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_load_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.load %[[ARG0]][%[[IDX]], %[[IDX1]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedload_collapse_shape(
+ %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+ %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+ %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+ return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_maskedload_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.maskedload %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_store_collapse_shape(
+ %arg0 : memref<4x8xf32>, %arg1 : index, %val : vector<8xf32>) {
+ %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+ vector.store %val, %0[%arg1] {nontemporal = true} : memref<32xf32>, vector<8xf32>
+ return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_store_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.store %{{.*}}, %[[ARG0]][%[[IDX]], %[[IDX1]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedstore_collapse_shape(
+ %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+ %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+ vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32>
+ return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_maskedstore_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.maskedstore %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]