diff options
author | Benjamin Maxwell <benjamin.maxwell@arm.com> | 2024-06-19 12:52:53 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-19 12:52:53 +0100 |
commit | 4d6b9921b3801709dca9245b5b4d7c17944a036f (patch) | |
tree | 5ee361903b9bde6c21f238009b9743ac19d9f876 /mlir | |
parent | 5dde4951ae16283fffad40f84bc8ae4149766782 (diff) | |
download | llvm-4d6b9921b3801709dca9245b5b4d7c17944a036f.zip llvm-4d6b9921b3801709dca9245b5b4d7c17944a036f.tar.gz llvm-4d6b9921b3801709dca9245b5b4d7c17944a036f.tar.bz2 |
[mlir][ArmSME] Fold MoveTileSliceToVector + TransferWrite to StoreTileSlice (#95907)
Diffstat (limited to 'mlir')
3 files changed, 117 insertions, 6 deletions
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index c2f1584..56ae46a 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -666,14 +666,69 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
   }
 };
 
+/// Folds a MoveTileSliceToVectorOp + TransferWriteOp to a StoreTileSliceOp.
+///
+/// BEFORE:
+/// ```mlir
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%index]
+///   : vector<[4]xf32> from vector<[4]x[4]xf32>
+/// vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
+///   : vector<[4]xf32>, memref<?x?xf32>
+/// ```
+/// AFTER:
+/// ```mlir
+/// arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
+///   : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+/// ```
+struct FoldTransferWriteOfExtractTileSlice
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const final {
+    if (!isa<MemRefType>(writeOp.getSource().getType()))
+      return rewriter.notifyMatchFailure(writeOp, "destination not a memref");
+
+    if (writeOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "not inbounds transfer write");
+
+    auto moveTileSlice =
+        writeOp.getVector().getDefiningOp<arm_sme::MoveTileSliceToVectorOp>();
+    if (!moveTileSlice)
+      return rewriter.notifyMatchFailure(
+          writeOp, "vector to store not from MoveTileSliceToVectorOp");
+
+    AffineMap map = writeOp.getPermutationMap();
+    if (!map.isMinorIdentity())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "unsupported permutation map");
+
+    Value mask = writeOp.getMask();
+    if (!mask) {
+      auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type());
+      mask = rewriter.create<arith::ConstantOp>(
+          writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true));
+    }
+
+    rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
+        writeOp, moveTileSlice.getTile(), moveTileSlice.getTileSliceIndex(),
+        mask, writeOp.getSource(), writeOp.getIndices(),
+        moveTileSlice.getLayout());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
-  patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
-               TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
-               TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
-               VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
-               VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
-               VectorPrintToArmSMELowering>(&ctx);
+  patterns
+      .add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+           TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
+           TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
+           VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
+           VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
+           VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice>(
+          &ctx);
 }
diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
index 35089eb..8ed52cd 100644
--- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
@@ -145,6 +145,18 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
   return
 }
 
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_slice_unsupported_permutation
+// CHECK-NOT: arm_sme.store_tile_slice
+func.func @transfer_write_slice_unsupported_permutation(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
+  %c0 = arith.constant 0 : index
+  %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  vector.transfer_write %slice, %dest[%slice_index, %c0] { permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
+  return
+}
+
+
 //===----------------------------------------------------------------------===//
 // vector.outerproduct
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index f22b6de..8aeffb0 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -334,6 +334,50 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb
   return
 }
 
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_slice(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[INDEX:.*]]: index) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK: arm_sme.store_tile_slice %[[VECTOR]], %[[INDEX]], %[[MASK]], %[[DEST]][%[[INDEX]], %[[C0]]] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+func.func @transfer_write_slice(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
+  %c0 = arith.constant 0 : index
+  %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  vector.transfer_write %slice, %dest[%slice_index, %c0] { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_slice_with_mask(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[4]xi1>,
+// CHECK-SAME: %[[INDEX:.*]]: index) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.store_tile_slice %[[VECTOR]], %[[INDEX]], %[[MASK]], %[[DEST]][%[[INDEX]], %[[C0]]] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+func.func @transfer_write_slice_with_mask(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %mask: vector<[4]xi1>, %slice_index: index) {
+  %c0 = arith.constant 0 : index
+  %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  vector.transfer_write %slice, %dest[%slice_index, %c0], %mask { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @transfer_write_vertical_slice
+// CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical>
+func.func @transfer_write_vertical_slice(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
+  %c0 = arith.constant 0 : index
+  %slice = arm_sme.move_tile_slice_to_vector %vector[%slice_index] layout<vertical>
+    : vector<[4]xf32> from vector<[4]x[4]xf32>
+  vector.transfer_write %slice, %dest[%slice_index, %c0] { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // vector.broadcast
 //===----------------------------------------------------------------------===//