aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Benjamin Maxwell <benjamin.maxwell@arm.com> 2024-06-19 12:52:53 +0100
committer GitHub <noreply@github.com> 2024-06-19 12:52:53 +0100
commit 4d6b9921b3801709dca9245b5b4d7c17944a036f (patch)
tree 5ee361903b9bde6c21f238009b9743ac19d9f876
parent 5dde4951ae16283fffad40f84bc8ae4149766782 (diff)
download llvm-4d6b9921b3801709dca9245b5b4d7c17944a036f.zip
llvm-4d6b9921b3801709dca9245b5b4d7c17944a036f.tar.gz
llvm-4d6b9921b3801709dca9245b5b4d7c17944a036f.tar.bz2
[mlir][ArmSME] Fold MoveTileSliceToVector + TransferWrite to StoreTileSlice (#95907)
-rw-r--r-- mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp 67
-rw-r--r-- mlir/test/Conversion/VectorToArmSME/unsupported.mlir 12
-rw-r--r-- mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir 44
3 files changed, 117 insertions, 6 deletions
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index c2f1584..56ae46a 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -666,14 +666,69 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
}
};
+/// Folds a MoveTileSliceToVectorOp + TransferWriteOp to a StoreTileSliceOp.
+///
+/// The fold only applies when the transfer write targets a memref, has no
+/// out-of-bounds dimensions, and uses a minor-identity permutation map. If the
+/// write is unmasked, an all-true mask constant is created for the store. The
+/// layout of the MoveTileSliceToVectorOp is forwarded to the StoreTileSliceOp.
+///
+/// BEFORE:
+/// ```mlir
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%index]
+/// : vector<[4]xf32> from vector<[4]x[4]xf32>
+/// vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
+/// : vector<[4]xf32>, memref<?x?xf32>
+/// ```
+/// AFTER:
+/// ```mlir
+/// arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
+/// : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+/// ```
+struct FoldTransferWriteOfExtractTileSlice
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+ // Returns success() after replacing `writeOp` with a StoreTileSliceOp, or a
+ // match failure (with a diagnostic message) if any precondition is not met.
+ LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+ PatternRewriter &rewriter) const final {
+ // Only writes whose destination is a memref are handled by this pattern.
+ if (!isa<MemRefType>(writeOp.getSource().getType()))
+ return rewriter.notifyMatchFailure(writeOp, "destination not a memref");
+
+ // Writes with any out-of-bounds dimension are rejected.
+ if (writeOp.hasOutOfBoundsDim())
+ return rewriter.notifyMatchFailure(writeOp,
+ "not inbounds transfer write");
+
+ // The vector being written must be defined by a MoveTileSliceToVectorOp;
+ // its tile, slice index, and layout are reused for the store below.
+ auto moveTileSlice =
+ writeOp.getVector().getDefiningOp<arm_sme::MoveTileSliceToVectorOp>();
+ if (!moveTileSlice)
+ return rewriter.notifyMatchFailure(
+ writeOp, "vector to store not from MoveTileSliceToVectorOp");
+
+ // Any permutation map other than a minor identity is rejected.
+ AffineMap map = writeOp.getPermutationMap();
+ if (!map.isMinorIdentity())
+ return rewriter.notifyMatchFailure(writeOp,
+ "unsupported permutation map");
+
+ // The store is always given a mask operand: for an unmasked write, create
+ // an all-true constant with the written vector's shape and i1 elements.
+ Value mask = writeOp.getMask();
+ if (!mask) {
+ auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type());
+ mask = rewriter.create<arith::ConstantOp>(
+ writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true));
+ }
+
+ // Replace the write with a store taken straight from the tile, reusing the
+ // write's destination and indices. The MoveTileSliceToVectorOp is not
+ // erased here (NOTE(review): presumably cleaned up elsewhere once unused).
+ rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
+ writeOp, moveTileSlice.getTile(), moveTileSlice.getTileSliceIndex(),
+ mask, writeOp.getSource(), writeOp.getIndices(),
+ moveTileSlice.getLayout());
+ return success();
+ }
+};
+
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
- TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
- TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
- VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
- VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
- VectorPrintToArmSMELowering>(&ctx);
+ patterns
+ .add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+ TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
+ TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
+ VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
+ VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
+ VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice>(
+ &ctx);
}
diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
index 35089eb..8ed52cd 100644
--- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
@@ -145,6 +145,18 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
return
}
+// -----
+
+// A slice write whose permutation map `(d0, d1) -> (d0)` is not a minor
+// identity: no arm_sme.store_tile_slice is expected in the output.
+// CHECK-LABEL: func.func @transfer_write_slice_unsupported_permutation
+// CHECK-NOT: arm_sme.store_tile_slice
+func.func @transfer_write_slice_unsupported_permutation(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
+ %c0 = arith.constant 0 : index
+ %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
+ vector.transfer_write %slice, %dest[%slice_index, %c0] { permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
+ return
+}
+
+
//===----------------------------------------------------------------------===//
// vector.outerproduct
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index f22b6de..8aeffb0 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -334,6 +334,50 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb
return
}
+// -----
+
+// An unmasked, in-bounds write of a slice extracted from a 2-D scalable
+// vector: expected to become a single arm_sme.store_tile_slice that uses an
+// all-true mask constant.
+// CHECK-LABEL: func.func @transfer_write_slice(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[INDEX:.*]]: index) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK: arm_sme.store_tile_slice %[[VECTOR]], %[[INDEX]], %[[MASK]], %[[DEST]][%[[INDEX]], %[[C0]]] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+func.func @transfer_write_slice(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
+ %c0 = arith.constant 0 : index
+ %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
+ vector.transfer_write %slice, %dest[%slice_index, %c0] { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
+ return
+}
+
+// -----
+
+// Same as the unmasked slice write, but with an explicit mask operand: the
+// write's own mask is expected to be passed through to
+// arm_sme.store_tile_slice.
+// CHECK-LABEL: func.func @transfer_write_slice_with_mask(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[4]xi1>,
+// CHECK-SAME: %[[INDEX:.*]]: index) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.store_tile_slice %[[VECTOR]], %[[INDEX]], %[[MASK]], %[[DEST]][%[[INDEX]], %[[C0]]] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+func.func @transfer_write_slice_with_mask(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %mask: vector<[4]xi1>, %slice_index: index) {
+ %c0 = arith.constant 0 : index
+ %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
+ vector.transfer_write %slice, %dest[%slice_index, %c0], %mask { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
+ return
+}
+
+// -----
+
+// The slice comes from an arm_sme.move_tile_slice_to_vector with a vertical
+// layout: the resulting arm_sme.store_tile_slice is expected to carry the
+// same vertical layout.
+// CHECK-LABEL: func.func @transfer_write_vertical_slice
+// CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical>
+func.func @transfer_write_vertical_slice(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
+ %c0 = arith.constant 0 : index
+ %slice = arm_sme.move_tile_slice_to_vector %vector[%slice_index] layout<vertical>
+ : vector<[4]xf32> from vector<[4]x[4]xf32>
+ vector.transfer_write %slice, %dest[%slice_index, %c0] { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
+ return
+}
+
//===----------------------------------------------------------------------===//
// vector.broadcast
//===----------------------------------------------------------------------===//