From 4d6b9921b3801709dca9245b5b4d7c17944a036f Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Wed, 19 Jun 2024 12:52:53 +0100 Subject: [mlir][ArmSME] Fold MoveTileSliceToVector + TransferWrite to StoreTileSlice (#95907) --- .../Conversion/VectorToArmSME/VectorToArmSME.cpp | 67 ++++++++++++++++++++-- .../Conversion/VectorToArmSME/unsupported.mlir | 12 ++++ .../VectorToArmSME/vector-to-arm-sme.mlir | 44 ++++++++++++++ 3 files changed, 117 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index c2f1584..56ae46a 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -666,14 +666,69 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern { } }; +/// Folds a MoveTileSliceToVectorOp + TransferWriteOp to a StoreTileSliceOp. +/// +/// BEFORE: +/// ```mlir +/// %slice = arm_sme.move_tile_slice_to_vector %tile[%index] +/// : vector<[4]xf32> from vector<[4]x[4]xf32> +/// vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]} +/// : vector<[4]xf32>, memref +/// ``` +/// AFTER: +/// ```mlir +/// arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j] +/// : memref, vector<[4]xi1>, vector<[4]x[4]xf32> +/// ``` +struct FoldTransferWriteOfExtractTileSlice + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, + PatternRewriter &rewriter) const final { + if (!isa(writeOp.getSource().getType())) + return rewriter.notifyMatchFailure(writeOp, "destination not a memref"); + + if (writeOp.hasOutOfBoundsDim()) + return rewriter.notifyMatchFailure(writeOp, + "not inbounds transfer write"); + + auto moveTileSlice = + writeOp.getVector().getDefiningOp(); + if (!moveTileSlice) + return rewriter.notifyMatchFailure( + writeOp, "vector to store not from MoveTileSliceToVectorOp"); + + AffineMap map = writeOp.getPermutationMap(); + if (!map.isMinorIdentity()) + return rewriter.notifyMatchFailure(writeOp, + "unsupported permutation map"); + + Value mask = writeOp.getMask(); + if (!mask) { + auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type()); + mask = rewriter.create( + writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true)); + } + + rewriter.replaceOpWithNewOp( + writeOp, moveTileSlice.getTile(), moveTileSlice.getTileSliceIndex(), + mask, writeOp.getSource(), writeOp.getIndices(), + moveTileSlice.getLayout()); + return success(); + } +}; + } // namespace void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx) { - patterns.add(&ctx); + patterns + .add( + &ctx); } diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir index 35089eb..8ed52cd 100644 --- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir +++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir @@ -145,6 +145,18 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest return } +// ----- + +// CHECK-LABEL: func.func @transfer_write_slice_unsupported_permutation +// CHECK-NOT: arm_sme.store_tile_slice +func.func @transfer_write_slice_unsupported_permutation(%vector: vector<[4]x[4]xf32>, %dest : memref, %slice_index: index) { + %c0 = arith.constant 0 : index + %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32> + vector.transfer_write %slice, %dest[%slice_index, %c0] { permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true] }: vector<[4]xf32>, memref + return +} + + //===----------------------------------------------------------------------===// // vector.outerproduct //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir index f22b6de..8aeffb0 100644 --- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir +++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir @@ -334,6 +334,50 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb return } +// ----- + +// CHECK-LABEL: func.func @transfer_write_slice( +// CHECK-SAME: %[[VECTOR:.*]]: vector<[4]x[4]xf32>, +// CHECK-SAME: %[[DEST:.*]]: memref, +// CHECK-SAME: %[[INDEX:.*]]: index) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[MASK:.*]] = arith.constant dense : vector<[4]xi1> +// CHECK: arm_sme.store_tile_slice %[[VECTOR]], %[[INDEX]], %[[MASK]], %[[DEST]][%[[INDEX]], %[[C0]]] : memref, vector<[4]xi1>, vector<[4]x[4]xf32> +func.func @transfer_write_slice(%vector: vector<[4]x[4]xf32>, %dest : memref, %slice_index: index) { + %c0 = arith.constant 0 : index + %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32> + vector.transfer_write %slice, %dest[%slice_index, %c0] { in_bounds = [true] }: vector<[4]xf32>, memref + return +} + +// ----- + +// CHECK-LABEL: func.func @transfer_write_slice_with_mask( +// CHECK-SAME: %[[VECTOR:.*]]: vector<[4]x[4]xf32>, +// CHECK-SAME: %[[DEST:.*]]: memref, +// CHECK-SAME: %[[MASK:.*]]: vector<[4]xi1>, +// CHECK-SAME: %[[INDEX:.*]]: index) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: arm_sme.store_tile_slice %[[VECTOR]], %[[INDEX]], %[[MASK]], %[[DEST]][%[[INDEX]], %[[C0]]] : memref, vector<[4]xi1>, vector<[4]x[4]xf32> +func.func @transfer_write_slice_with_mask(%vector: vector<[4]x[4]xf32>, %dest : memref, %mask: vector<[4]xi1>, %slice_index: index) { + %c0 = arith.constant 0 : index + %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32> + vector.transfer_write %slice, %dest[%slice_index, %c0], %mask { in_bounds = [true] }: vector<[4]xf32>, memref + return +} + +// ----- + +// CHECK-LABEL: func.func @transfer_write_vertical_slice +// CHECK: arm_sme.store_tile_slice {{.*}} layout +func.func @transfer_write_vertical_slice(%vector: vector<[4]x[4]xf32>, %dest : memref, %slice_index: index) { + %c0 = arith.constant 0 : index + %slice = arm_sme.move_tile_slice_to_vector %vector[%slice_index] layout + : vector<[4]xf32> from vector<[4]x[4]xf32> + vector.transfer_write %slice, %dest[%slice_index, %c0] { in_bounds = [true] }: vector<[4]xf32>, memref + return +} + //===----------------------------------------------------------------------===// // vector.broadcast //===----------------------------------------------------------------------===// -- cgit v1.1