diff options
author | Hugo Trachino <hugo.trachino@huawei.com> | 2024-06-18 14:24:03 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-18 14:24:03 +0100 |
commit | 74941d053cd3cbe6a49a4b3387e21bd139377cee (patch) | |
tree | 20b3e77b3782614e714ff340c202055cfc2d5477 /mlir | |
parent | 65b0301943e64d7841e11f047a1a9fbd15f28037 (diff) | |
download | llvm-74941d053cd3cbe6a49a4b3387e21bd139377cee.zip llvm-74941d053cd3cbe6a49a4b3387e21bd139377cee.tar.gz llvm-74941d053cd3cbe6a49a4b3387e21bd139377cee.tar.bz2 |
[MLIR][Vector] Implement XferOp To {Load|Store}Lowering as MaskableOpRewritePattern (#92892)
Implements `TransferReadToVectorLoadLowering` and
`TransferWriteToVectorStoreLowering` as a `MaskableOpRewritePattern`.
This allows the patterns to exit gracefully when run on an xferOp located
inside a `vector::MaskOp`, instead of failing because the pattern generated
multiple ops inside the MaskOp with `error: 'vector.mask' op expects only
one operation to mask`.
Split off from https://github.com/llvm/llvm-project/pull/90835
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 56 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir | 17 |
2 files changed, 48 insertions, 25 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 6bfb2eb..c31c514 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -429,20 +429,24 @@ namespace { /// result type. /// - The permutation map doesn't perform permutation (broadcasting is allowed). struct TransferReadToVectorLoadLowering - : public OpRewritePattern<vector::TransferReadOp> { + : public MaskableOpRewritePattern<vector::TransferReadOp> { TransferReadToVectorLoadLowering(MLIRContext *context, std::optional<unsigned> maxRank, PatternBenefit benefit = 1) - : OpRewritePattern<vector::TransferReadOp>(context, benefit), + : MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit), maxTransferRank(maxRank) {} - LogicalResult matchAndRewrite(vector::TransferReadOp read, - PatternRewriter &rewriter) const override { + FailureOr<mlir::Value> + matchAndRewriteMaskableOp(vector::TransferReadOp read, + MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override { if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) { return rewriter.notifyMatchFailure( read, "vector type is greater than max transfer rank"); } + if (maskOp) + return rewriter.notifyMatchFailure(read, "Masked case not supported"); SmallVector<unsigned> broadcastedDims; // Permutations are handled by VectorToSCF or // populateVectorTransferPermutationMapLoweringPatterns. @@ -485,7 +489,7 @@ struct TransferReadToVectorLoadLowering return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask"); // Create vector load op. - Operation *loadOp; + Operation *res; if (read.getMask()) { if (read.getVectorType().getRank() != 1) // vector.maskedload operates on 1-D vectors. 
@@ -495,24 +499,20 @@ struct TransferReadToVectorLoadLowering Value fill = rewriter.create<vector::SplatOp>( read.getLoc(), unbroadcastedVectorType, read.getPadding()); - loadOp = rewriter.create<vector::MaskedLoadOp>( + res = rewriter.create<vector::MaskedLoadOp>( read.getLoc(), unbroadcastedVectorType, read.getSource(), read.getIndices(), read.getMask(), fill); } else { - loadOp = rewriter.create<vector::LoadOp>( + res = rewriter.create<vector::LoadOp>( read.getLoc(), unbroadcastedVectorType, read.getSource(), read.getIndices()); } // Insert a broadcasting op if required. - if (!broadcastedDims.empty()) { - rewriter.replaceOpWithNewOp<vector::BroadcastOp>( - read, read.getVectorType(), loadOp->getResult(0)); - } else { - rewriter.replaceOp(read, loadOp->getResult(0)); - } - - return success(); + if (!broadcastedDims.empty()) + res = rewriter.create<vector::BroadcastOp>( + read.getLoc(), read.getVectorType(), res->getResult(0)); + return res->getResult(0); } std::optional<unsigned> maxTransferRank; @@ -581,19 +581,23 @@ struct VectorStoreToMemrefStoreLowering /// - The permutation map is the minor identity map (neither permutation nor /// broadcasting is allowed). 
struct TransferWriteToVectorStoreLowering - : public OpRewritePattern<vector::TransferWriteOp> { + : public MaskableOpRewritePattern<vector::TransferWriteOp> { TransferWriteToVectorStoreLowering(MLIRContext *context, std::optional<unsigned> maxRank, PatternBenefit benefit = 1) - : OpRewritePattern<vector::TransferWriteOp>(context, benefit), + : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit), maxTransferRank(maxRank) {} - LogicalResult matchAndRewrite(vector::TransferWriteOp write, - PatternRewriter &rewriter) const override { + FailureOr<mlir::Value> + matchAndRewriteMaskableOp(vector::TransferWriteOp write, + MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override { if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) { return rewriter.notifyMatchFailure( write, "vector type is greater than max transfer rank"); } + if (maskOp) + return rewriter.notifyMatchFailure(write, "Masked case not supported"); // Permutations are handled by VectorToSCF or // populateVectorTransferPermutationMapLoweringPatterns. @@ -645,14 +649,16 @@ struct TransferWriteToVectorStoreLowering << write; }); - rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>( - write, write.getSource(), write.getIndices(), write.getMask(), - write.getVector()); + rewriter.create<vector::MaskedStoreOp>( + write.getLoc(), write.getSource(), write.getIndices(), + write.getMask(), write.getVector()); } else { - rewriter.replaceOpWithNewOp<vector::StoreOp>( - write, write.getVector(), write.getSource(), write.getIndices()); + rewriter.create<vector::StoreOp>(write.getLoc(), write.getVector(), + write.getSource(), write.getIndices()); } - return success(); + // There's no return value for StoreOps. Use Value() to signal success to + // matchAndRewrite. 
+ return Value(); } std::optional<unsigned> maxTransferRank; diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index 2f2bdca..d169e6d 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -51,6 +51,23 @@ func.func @transfer_to_load(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> return %res : vector<4xf32> } +// Masked transfer_read/write inside are NOT lowered to vector.load/store +// CHECK-LABEL: func @masked_transfer_to_load( +// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, +// CHECK-SAME: %[[IDX:.*]]: index, +// CHECK-SAME: %[[MASK:.*]]: vector<4xi1>) -> memref<8x8xf32> +// CHECK-NOT: vector.load +// CHECK-NOT: vector.store +// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %arg0[%[[IDX]], %[[IDX]]]{{.*}} : memref<8x8xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32> +// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1> + +func.func @masked_transfer_to_load(%mem : memref<8x8xf32>, %i : index, %mask : vector<4xi1>) -> memref<8x8xf32> { + %cf0 = arith.constant 0.0 : f32 + %read = vector.mask %mask {vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>} : vector<4xi1> -> vector<4xf32> + vector.mask %mask {vector.transfer_write %read, %mem[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1> + return %mem : memref<8x8xf32> +} + // n-D results are also supported. // CHECK-LABEL: func @transfer_2D( // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, |