author | Han-Chung Wang <hanhan0912@gmail.com> | 2024-01-19 03:15:13 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-19 03:15:13 -0800 |
commit | 12b676de728ee9046ac5fea49e27b9bf1cde4a70 (patch) | |
tree | f0aa8f6e50b3c5052a11b05a5d29e5020f96f7fb /mlir | |
parent | 2c78f3b86007fbf56a6f40b647b5cb757c082215 (diff) | |
[mlir][vector] Drop innermost unit dims on transfer_write. (#78554)
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 219 |
-rw-r--r-- | mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir | 53 |
2 files changed, 218 insertions, 54 deletions
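
In short, the new `DropInnerMostUnitDimsTransferWrite` pattern rewrites a `vector.transfer_write` whose trailing dims are contiguous unit dims into a rank-reducing `memref.subview`, a `vector.shape_cast`, and a lower-rank write. The IR excerpt below mirrors the `@drop_two_inner_most_dim_for_transfer_write` test added by this patch; the SSA names (`%vec`, `%dest`, `%i`, `%c0`) are illustrative placeholders, not names taken from the patch.

```mlir
// Before: the two innermost dims of the vector and the memref are unit-sized
// and contiguous (stride 1).
vector.transfer_write %vec, %dest[%c0, %i, %c0, %c0, %c0]
    {in_bounds = [true, true, true, true, true]}
    : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>

// After: a rank-reducing subview plus a shape_cast fold the unit dims away,
// so the write itself becomes rank 3.
%subview = memref.subview %dest[0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
    : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
%cast = vector.shape_cast %vec : vector<1x16x16x1x1xf32> to vector<1x16x16xf32>
vector.transfer_write %cast, %subview[%c0, %i, %c0]
    {in_bounds = [true, true, true]}
    : vector<1x16x16xf32>, memref<1x512x16xf32>
```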
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bd02c07..9c734e8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1152,8 +1152,78 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
   }
 };
 
-// Drop inner most contiguous unit dimensions from transfer_read operand.
-class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
+/// Returns the number of dims that can be folded away from transfer ops. It
+/// returns a failure if it cannot determine the number of dims to be folded.
+/// Example 1: it returns "2" if `srcType` is memref<512x16x1x1xf32> and
+/// `vectorType` is vector<16x16x1x1xf32>, because the two innermost dims
+/// can be dropped by memref.subview ops.
+/// Example 2: it returns "1" if `srcType` is the same memref type with
+/// [8192, 16, 8, 1] strides.
+static FailureOr<size_t>
+getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
+  SmallVector<int64_t> srcStrides;
+  int64_t srcOffset;
+  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+    return failure();
+
+  // According to vector.transfer_read/write semantics, the vector can be a
+  // slice. Thus, we have to offset the check index with `rankDiff` in
+  // `srcStrides` and source dim sizes.
+  size_t result = 0;
+  int rankDiff = srcType.getRank() - vectorType.getRank();
+  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
+    // Check that the inner dim size is 1 for both memref type and vector
+    // slice. It can be folded only if they are 1 and the stride is 1.
+    int dim = vectorType.getRank() - i - 1;
+    if (srcStrides[dim + rankDiff] != 1 ||
+        srcType.getDimSize(dim + rankDiff) != 1 ||
+        vectorType.getDimSize(dim) != 1)
+      break;
+    result++;
+  }
+  return result;
+}
+
+/// Returns a MemRef type that drops the inner `dimsToDrop` dimensions from
+/// `srcType`. E.g., if `srcType` is memref<512x16x1x1xf32> and `dimsToDrop` is
+/// two, it returns a memref<512x16xf32> type.
+static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder,
+                                                     MemRefType srcType,
+                                                     size_t dimsToDrop) {
+  MemRefType resultMemrefType;
+  MemRefLayoutAttrInterface layout = srcType.getLayout();
+  if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
+    return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                           srcType.getElementType(), nullptr,
+                           srcType.getMemorySpace());
+  }
+  MemRefLayoutAttrInterface updatedLayout;
+  if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
+    auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
+    updatedLayout = StridedLayoutAttr::get(strided.getContext(),
+                                           strided.getOffset(), strides);
+    return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                           srcType.getElementType(), updatedLayout,
+                           srcType.getMemorySpace());
+  }
+
+  // Non-strided layout case.
+  AffineMap map = srcType.getLayout().getAffineMap();
+  int numSymbols = map.getNumSymbols();
+  for (size_t i = 0; i < dimsToDrop; ++i) {
+    int dim = srcType.getRank() - i - 1;
+    map = map.replace(builder.getAffineDimExpr(dim),
+                      builder.getAffineConstantExpr(0), map.getNumDims() - 1,
+                      numSymbols);
+  }
+  return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                         srcType.getElementType(), updatedLayout,
+                         srcType.getMemorySpace());
+}
+
+/// Drop inner most contiguous unit dimensions from transfer_read operand.
+class DropInnerMostUnitDimsTransferRead
+    : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
@@ -1177,29 +1247,12 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
     if (targetType.getRank() <= 1)
       return failure();
 
-    SmallVector<int64_t> srcStrides;
-    int64_t srcOffset;
-    if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
-      return failure();
-
-    // According to vector.transfer_read semantics, the result can be a slice.
-    // It pads the indices with `1` starting from beginning. Thus, we have to
-    // offset the check index with `rankDiff` in `srcStrides` and source dim
-    // sizes.
-    size_t dimsToDrop = 0;
-    int rankDiff = srcType.getRank() - targetType.getRank();
-    for (int64_t i = 0, e = targetType.getRank(); i < e; ++i) {
-      // Check that the inner dim size is 1 for both memref/tensor type and
-      // vector slice. It can be folded only if they are 1 and the stride is 1.
-      int dim = targetType.getRank() - i - 1;
-      if (srcStrides[dim + rankDiff] == 1 &&
-          srcType.getDimSize(dim + rankDiff) == 1 &&
-          targetType.getDimSize(dim) == 1) {
-        dimsToDrop++;
-      } else {
-        break;
-      }
-    }
+    FailureOr<size_t> maybeDimsToDrop =
+        getTransferFoldableInnerUnitDims(srcType, targetType);
+    if (failed(maybeDimsToDrop))
+      return failure();
+
+    size_t dimsToDrop = maybeDimsToDrop.value();
     if (dimsToDrop == 0)
       return failure();
 
@@ -1207,35 +1260,9 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                         targetType.getElementType());
 
-    MemRefType resultMemrefType;
-    MemRefLayoutAttrInterface layout = srcType.getLayout();
-    if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
-      resultMemrefType = MemRefType::get(
-          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
-          nullptr, srcType.getMemorySpace());
-    } else {
-      MemRefLayoutAttrInterface updatedLayout;
-      if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
-        auto strides =
-            llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
-        updatedLayout = StridedLayoutAttr::get(strided.getContext(),
-                                               strided.getOffset(), strides);
-      } else {
-        AffineMap map = srcType.getLayout().getAffineMap();
-        int numSymbols = map.getNumSymbols();
-        for (size_t i = 0; i < dimsToDrop; ++i) {
-          int dim = srcType.getRank() - i - 1;
-          map = map.replace(rewriter.getAffineDimExpr(dim),
-                            rewriter.getAffineConstantExpr(0),
-                            map.getNumDims() - 1, numSymbols);
-        }
-      }
-      resultMemrefType = MemRefType::get(
-          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
-          updatedLayout, srcType.getMemorySpace());
-    }
-    auto loc = readOp.getLoc();
+    MemRefType resultMemrefType =
+        getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
+    auto loc = readOp.getLoc();
 
     SmallVector<int64_t> offsets(srcType.getRank(), 0);
     SmallVector<int64_t> strides(srcType.getRank(), 1);
@@ -1261,6 +1288,88 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
   }
 };
 
+/// Drop inner most contiguous unit dimensions from transfer_write operand.
+/// E.g.,
+///    vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
+///      {in_bounds = [true, true, true, true, true]}
+///      : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
+///
+/// will be replaced with
+///
+///    %subview = memref.subview %arg0
+///      [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
+///      : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
+///    %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
+///      to vector<1x16x16xf32>
+///    vector.transfer_write %0, %subview[%c0, %arg2, %c0]
+///      {in_bounds = [true, true, true]}
+///      : vector<1x16x16xf32>, memref<1x512x16xf32>
+class DropInnerMostUnitDimsTransferWrite
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (writeOp.getTransferRank() == 0)
+      return failure();
+
+    // TODO: support mask.
+    if (writeOp.getMask())
+      return failure();
+
+    auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
+    if (!srcType || !srcType.hasStaticShape())
+      return failure();
+
+    if (!writeOp.getPermutationMap().isMinorIdentity())
+      return failure();
+
+    auto targetType = writeOp.getVectorType();
+    if (targetType.getRank() <= 1)
+      return failure();
+
+    FailureOr<size_t> maybeDimsToDrop =
+        getTransferFoldableInnerUnitDims(srcType, targetType);
+    if (failed(maybeDimsToDrop))
+      return failure();
+
+    size_t dimsToDrop = maybeDimsToDrop.value();
+    if (dimsToDrop == 0)
+      return failure();
+
+    auto resultTargetVecType =
+        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
+                        targetType.getElementType());
+
+    MemRefType resultMemrefType =
+        getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
+    SmallVector<int64_t> offsets(srcType.getRank(), 0);
+    SmallVector<int64_t> strides(srcType.getRank(), 1);
+    ArrayAttr inBoundsAttr =
+        writeOp.getInBounds()
+            ? rewriter.getArrayAttr(
+                  writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
+            : ArrayAttr();
+
+    Location loc = writeOp.getLoc();
+    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
+        loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(),
+        strides);
+    auto permMap = getTransferMinorIdentityMap(
+        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
+
+    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+        loc, resultTargetVecType, writeOp.getVector());
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        writeOp, shapeCast, rankedReducedView,
+        writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
+        // TODO: support mask.
+        /*mask=*/Value(), inBoundsAttr);
+    return success();
+  }
+};
+
 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
 /// semantics to a contraction suitable for MMT (matrix matrix multiplication
 /// with the RHS transposed) lowering.
@@ -1696,7 +1805,9 @@ void mlir::vector::populateVectorReductionToContractPatterns(
 void mlir::vector::
     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
+  patterns.add<DropInnerMostUnitDimsTransferRead,
+               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
+                                                   benefit);
 }
 
 void mlir::vector::populateSinkVectorBroadcastPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 0d2743b..d6d69c8 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -76,3 +76,56 @@ func.func @contiguous_inner_most_dim_out_of_bounds_2d(%arg0: memref<1x1xf32>) ->
 // CHECK-NOT:   memref.subview
 // CHECK:       %[[READ:.+]] = vector.transfer_read %[[SRC]]
 // CHECK:       return %[[READ]] : vector<4x8xf32>
+
+// -----
+
+func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
+    {in_bounds = [true, true, true, true, true]}
+    : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
+  return
+}
+// CHECK:      func.func @drop_two_inner_most_dim_for_transfer_write
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME:     memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
+// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1x1xf32> to vector<1x16x16xf32>
+// CHECK:        vector.transfer_write %[[CAST]], %[[SUBVIEW]]
+// CHECK-SAME:     [%[[C0]], %[[IDX]], %[[C0]]]
+
+// -----
+
+func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]
+    {in_bounds = [true, true, true, true]}
+    : vector<1x16x16x1xf32>, memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
+  return
+}
+// CHECK:      func.func @drop_inner_most_dim_for_transfer_write
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME:     memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>
+// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
+// CHECK:        vector.transfer_write %[[CAST]], %[[SUBVIEW]]
+// CHECK-SAME:     [%[[C0]], %[[IDX]], %[[C0]]]
+
+// -----
+
+func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]
+    {in_bounds = [true, true, true]}
+    : vector<16x16x1xf32>, memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>
+  return
+}
+// The innermost unit dims cannot be dropped if the strides are not ones.
+// CHECK:     func.func @non_unit_strides
+// CHECK-NOT:   memref.subview
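
As a further illustration of `getTransferFoldableInnerUnitDims` (its "Example 2" above), a dim is foldable only when the memref size, the vector size, and the stride are all 1, so a layout with strides `[8192, 16, 8, 1]` lets just one trailing dim be dropped. The function below is a hypothetical input sketched for this note (it is not part of the patch's test suite), and the expected rewrite is derived from the pattern's logic rather than from a recorded FileCheck run.

```mlir
// Hypothetical input: only the innermost unit dim has stride 1, so only one
// dim can be folded away.
func.func @drop_single_inner_unit_dim(
    %dest: memref<512x16x1x1xf32, strided<[8192, 16, 8, 1], offset: ?>>,
    %vec: vector<16x16x1x1xf32>, %i: index) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %dest[%i, %c0, %c0, %c0]
      {in_bounds = [true, true, true, true]}
      : vector<16x16x1x1xf32>,
        memref<512x16x1x1xf32, strided<[8192, 16, 8, 1], offset: ?>>
  return
}
// Expected rewrite (sketch): a rank-reducing memref.subview to
// memref<512x16x1xf32, strided<[8192, 16, 8], offset: ?>>, a vector.shape_cast
// to vector<16x16x1xf32>, and a rank-3 transfer_write onto the subview with
// indices [%i, %c0, %c0] and in_bounds = [true, true, true].
```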