Diffstat (limited to 'mlir')
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp |  90
-rw-r--r--  mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir       | 172
2 files changed, 239 insertions, 23 deletions
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index db085b3..96daf4c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -518,10 +518,25 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
           loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
-      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(
+      .Case([&](affine::AffineLoadOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
             loadOp, expandShapeOp.getViewSource(), sourceIndices);
       })
+      .Case([&](memref::LoadOp op) {
+        rewriter.replaceOpWithNewOp<memref::LoadOp>(
+            loadOp, expandShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::LoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::LoadOp>(
+            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::MaskedLoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
+            op.getMask(), op.getPassThru());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
 }
@@ -551,10 +566,25 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
           loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
-      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(
+      .Case([&](affine::AffineLoadOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
             loadOp, collapseShapeOp.getViewSource(), sourceIndices);
       })
+      .Case([&](memref::LoadOp op) {
+        rewriter.replaceOpWithNewOp<memref::LoadOp>(
+            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::LoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::LoadOp>(
+            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
+            op.getNontemporal());
+      })
+      .Case([&](vector::MaskedLoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
+            op.getMask(), op.getPassThru());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
 }
@@ -651,10 +681,25 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
           storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(storeOp, storeOp.getValue(),
-                                                  expandShapeOp.getViewSource(),
-                                                  sourceIndices);
+      .Case([&](affine::AffineStoreOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
+            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
+            sourceIndices);
       })
+      .Case([&](memref::StoreOp op) {
+        rewriter.replaceOpWithNewOp<memref::StoreOp>(
+            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::StoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::StoreOp>(
+            op, op.getValueToStore(), expandShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::MaskedStoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
+            op.getValueToStore());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
@@ -685,11 +730,26 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
           storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
-        rewriter.replaceOpWithNewOp<decltype(op)>(
-            storeOp, storeOp.getValue(), collapseShapeOp.getViewSource(),
+      .Case([&](affine::AffineStoreOp op) {
+        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
+            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
             sourceIndices);
       })
+      .Case([&](memref::StoreOp op) {
+        rewriter.replaceOpWithNewOp<memref::StoreOp>(
+            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::StoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::StoreOp>(
+            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
+            sourceIndices, op.getNontemporal());
+      })
+      .Case([&](vector::MaskedStoreOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
+            op.getValueToStore());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
 }
@@ -763,12 +823,20 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
                StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
                LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
                LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
+               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
+               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
                StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
                StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
+               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
+               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
                LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
                LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
+               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
+               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
                StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
                StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
+               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
+               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
                SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
       patterns.getContext());
 }
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index e49dff4..327cacf 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -473,10 +473,10 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
 func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
   %c0 = arith.constant 0 : index
   %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-  %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return %[[VAL1]] : f32
 
 // -----
@@ -487,11 +487,11 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
   %c0 = arith.constant 0 : index
   %c1f32 = arith.constant 1.0 : f32
   %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-  memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return
 }
 // CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return
 
 // -----
@@ -819,14 +819,14 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
 // -----
 
-func.func @fold_vector_load(
+func.func @fold_vector_load_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   %1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
   return %1 : vector<12x32xf32>
 }
-// CHECK: func @fold_vector_load
+// CHECK: func @fold_vector_load_subview
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -834,14 +834,14 @@
-func.func @fold_vector_maskedload(
+func.func @fold_vector_maskedload_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
   return %1 : vector<32xf32>
 }
-// CHECK: func @fold_vector_maskedload
+// CHECK: func @fold_vector_maskedload_subview
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -851,14 +851,14 @@
-func.func @fold_vector_store(
+func.func @fold_vector_store_subview(
  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
   return
 }
-// CHECK: func @fold_vector_store
+// CHECK: func @fold_vector_store_subview
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -868,14 +868,14 @@
-func.func @fold_vector_maskedstore(
+func.func @fold_vector_maskedstore_subview(
  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32>
   return
 }
-// CHECK: func @fold_vector_maskedstore
+// CHECK: func @fold_vector_maskedstore_subview
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -883,3 +883,151 @@ func.func @fold_vector_maskedstore(
 // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
 // CHECK: vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
 // CHECK: return
+
+// -----
+
+func.func @fold_vector_load_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.load %0[%arg1, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_load_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.load %[[ARG0]][%[[IDX]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedload_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_maskedload_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.maskedload %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_store_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %val : vector<8xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  vector.store %val, %0[%arg1, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>
+  return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_store_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.store %{{.*}}, %[[ARG0]][%[[IDX]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedstore_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
+  return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-LABEL: func @fold_vector_maskedstore_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: vector.maskedstore %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_load_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  %1 = vector.load %0[%arg1] {nontemporal = true} : memref<32xf32>, vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_load_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.load %[[ARG0]][%[[IDX]], %[[IDX1]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedload_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_maskedload_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.maskedload %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_store_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %val : vector<8xf32>) {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  vector.store %val, %0[%arg1] {nontemporal = true} : memref<32xf32>, vector<8xf32>
+  return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_store_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.store %{{.*}}, %[[ARG0]][%[[IDX]], %[[IDX1]]] {nontemporal = true}
+
+// -----
+
+func.func @fold_vector_maskedstore_collapse_shape(
+  %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32>
+  return
+}
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 8)>
+// CHECK-LABEL: func @fold_vector_maskedstore_collapse_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[IDX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]]]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK: vector.maskedstore %[[ARG0]][%[[IDX]], %[[IDX1]]], %[[ARG3]], %[[ARG4]]
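
In effect, the new patterns let the fold-memref-alias-ops rewrites replace a vector load/store that goes through memref.expand_shape or memref.collapse_shape with one that addresses the original memref, linearizing or delinearizing the indices via affine.apply. A minimal before/after sketch, distilled from the @fold_vector_load_expand_shape test above (the value names %src, %e, %i, %idx, %v are illustrative, not from the commit):

  // Before folding: the load goes through the expanded view.
  %e = memref.expand_shape %src [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
  %v = vector.load %e[%i, %c0] {nontemporal = true} : memref<4x8xf32>, vector<8xf32>

  // After folding: the index is linearized and the load hits the source memref directly,
  // with the nontemporal flag preserved.
  %idx = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%i]
  %v = vector.load %src[%idx] {nontemporal = true} : memref<32xf32>, vector<8xf32>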