author    Han-Chung Wang <hanhan0912@gmail.com>  2024-01-19 03:15:13 -0800
committer GitHub <noreply@github.com>  2024-01-19 03:15:13 -0800
commit    12b676de728ee9046ac5fea49e27b9bf1cde4a70 (patch)
tree      f0aa8f6e50b3c5052a11b05a5d29e5020f96f7fb /mlir
parent    2c78f3b86007fbf56a6f40b647b5cb757c082215 (diff)
[mlir][vector] Drop innermost unit dims on transfer_write. (#78554)
Diffstat (limited to 'mlir')
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp                 | 219
-rw-r--r--  mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir  |  53
2 files changed, 218 insertions, 54 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bd02c07..9c734e8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1152,8 +1152,78 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
}
};
-// Drop inner most contiguous unit dimensions from transfer_read operand.
-class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
+/// Returns the number of inner dims that can be folded away from transfer
+/// ops; it returns failure if it cannot determine the number of dims to fold.
+/// Example 1: it returns "2" if `srcType` is memref<512x16x1x1xf32> and
+/// `vectorType` is vector<16x16x1x1xf32>, because the two innermost dims can
+/// be dropped by memref.subview ops.
+/// Example 2: it returns "1" if `srcType` is the same memref type but with
+/// strides [8192, 16, 8, 1].
+static FailureOr<size_t>
+getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
+ SmallVector<int64_t> srcStrides;
+ int64_t srcOffset;
+ if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+ return failure();
+
+  // According to vector.transfer_read/write semantics, the vector can be a
+  // lower-rank slice of the source. Thus, the check index into `srcStrides`
+  // and the source dim sizes has to be offset by `rankDiff`.
+ size_t result = 0;
+ int rankDiff = srcType.getRank() - vectorType.getRank();
+ for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
+    // A dim can be folded only if the memref dim size, the vector dim size,
+    // and the stride are all 1.
+ int dim = vectorType.getRank() - i - 1;
+ if (srcStrides[dim + rankDiff] != 1 ||
+ srcType.getDimSize(dim + rankDiff) != 1 ||
+ vectorType.getDimSize(dim) != 1)
+ break;
+ result++;
+ }
+ return result;
+}
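As a concrete illustration of the `rankDiff` handling, consider this hypothetical IR (the function and value names are invented, not taken from this patch's tests). The vector is a rank-3 slice of a rank-4 source, so `rankDiff` is 1; the helper would count two foldable unit dims and then stop at the size-16 dims:

  func.func @fold_slice_read(%src: memref<512x16x1x1xf32>, %i: index) -> vector<16x1x1xf32> {
    %c0 = arith.constant 0 : index
    %pad = arith.constant 0.0 : f32
    // Default (minor identity) permutation map; in_bounds has one entry per
    // vector dim. Trailing dims line up after offsetting by rankDiff = 1.
    %v = vector.transfer_read %src[%i, %c0, %c0, %c0], %pad
        {in_bounds = [true, true, true]}
        : memref<512x16x1x1xf32>, vector<16x1x1xf32>
    return %v : vector<16x1x1xf32>
  }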
+
+/// Returns a MemRef type that drops inner `dimsToDrop` dimensions from
+/// `srcType`. E.g., if `srcType` is memref<512x16x1x1xf32> and `dimsToDrop` is
+/// two, it returns memref<512x16xf32>.
+static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder,
+ MemRefType srcType,
+ size_t dimsToDrop) {
+ MemRefLayoutAttrInterface layout = srcType.getLayout();
+ if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
+ return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+ srcType.getElementType(), nullptr,
+ srcType.getMemorySpace());
+ }
+ MemRefLayoutAttrInterface updatedLayout;
+ if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
+ auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
+ updatedLayout = StridedLayoutAttr::get(strided.getContext(),
+ strided.getOffset(), strides);
+ return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+ srcType.getElementType(), updatedLayout,
+ srcType.getMemorySpace());
+ }
+
+  // Non-strided (affine map) layout case: pin each dropped trailing unit dim
+  // to 0 in the layout map.
+ AffineMap map = srcType.getLayout().getAffineMap();
+ int numSymbols = map.getNumSymbols();
+ for (size_t i = 0; i < dimsToDrop; ++i) {
+ int dim = srcType.getRank() - i - 1;
+ map = map.replace(builder.getAffineDimExpr(dim),
+ builder.getAffineConstantExpr(0), map.getNumDims() - 1,
+ numSymbols);
+ }
+  updatedLayout = AffineMapAttr::get(map);
+  return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                         srcType.getElementType(), updatedLayout,
+                         srcType.getMemorySpace());
+}
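For the strided branch, the computation amounts to truncating both the shape and the strides. A sketch of how the result surfaces in IR, consistent with the `drop_inner_most_dim_for_transfer_write` test added below (`%dest` is a placeholder for a value of the source type):

  // Dropping one trailing unit dim drops the trailing stride as well; the
  // dynamic offset is carried over unchanged.
  %subview = memref.subview %dest[0, 0, 0, 0] [1, 512, 16, 1] [1, 1, 1, 1]
      : memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
      to memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>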
+
+/// Drop the innermost contiguous unit dimensions from the transfer_read
+/// operand.
+class DropInnerMostUnitDimsTransferRead
+ : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
@@ -1177,29 +1247,12 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
if (targetType.getRank() <= 1)
return failure();
- SmallVector<int64_t> srcStrides;
- int64_t srcOffset;
- if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
- return failure();
-
- // According to vector.transfer_read semantics, the result can be a slice.
- // It pads the indices with `1` starting from beginning. Thus, we have to
- // offset the check index with `rankDiff` in `srcStrides` and source dim
- // sizes.
- size_t dimsToDrop = 0;
- int rankDiff = srcType.getRank() - targetType.getRank();
- for (int64_t i = 0, e = targetType.getRank(); i < e; ++i) {
- // Check that the inner dim size is 1 for both memref/tensor type and
- // vector slice. It can be folded only if they are 1 and the stride is 1.
- int dim = targetType.getRank() - i - 1;
- if (srcStrides[dim + rankDiff] == 1 &&
- srcType.getDimSize(dim + rankDiff) == 1 &&
- targetType.getDimSize(dim) == 1) {
- dimsToDrop++;
- } else {
- break;
- }
- }
+ FailureOr<size_t> maybeDimsToDrop =
+ getTransferFoldableInnerUnitDims(srcType, targetType);
+ if (failed(maybeDimsToDrop))
+ return failure();
+
+ size_t dimsToDrop = maybeDimsToDrop.value();
if (dimsToDrop == 0)
return failure();
@@ -1207,35 +1260,9 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
targetType.getElementType());
- MemRefType resultMemrefType;
- MemRefLayoutAttrInterface layout = srcType.getLayout();
- if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
- resultMemrefType = MemRefType::get(
- srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
- nullptr, srcType.getMemorySpace());
- } else {
- MemRefLayoutAttrInterface updatedLayout;
- if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
- auto strides =
- llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
- updatedLayout = StridedLayoutAttr::get(strided.getContext(),
- strided.getOffset(), strides);
- } else {
- AffineMap map = srcType.getLayout().getAffineMap();
- int numSymbols = map.getNumSymbols();
- for (size_t i = 0; i < dimsToDrop; ++i) {
- int dim = srcType.getRank() - i - 1;
- map = map.replace(rewriter.getAffineDimExpr(dim),
- rewriter.getAffineConstantExpr(0),
- map.getNumDims() - 1, numSymbols);
- }
- }
- resultMemrefType = MemRefType::get(
- srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
- updatedLayout, srcType.getMemorySpace());
- }
-
auto loc = readOp.getLoc();
+ MemRefType resultMemrefType =
+ getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
SmallVector<int64_t> offsets(srcType.getRank(), 0);
SmallVector<int64_t> strides(srcType.getRank(), 1);
@@ -1261,6 +1288,88 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
}
};
+/// Drop the innermost contiguous unit dimensions from the transfer_write
+/// operand.
+/// E.g.,
+/// vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
+/// {in_bounds = [true, true, true, true, true]}
+/// : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
+///
+/// will be replaced with
+///
+/// %subview = memref.subview %arg0
+/// [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
+/// : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
+/// %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
+/// to vector<1x16x16xf32>
+/// vector.transfer_write %0, %subview[%c0, %arg2, %c0]
+/// {in_bounds = [true, true, true]}
+/// : vector<1x16x16xf32>, memref<1x512x16xf32>
+class DropInnerMostUnitDimsTransferWrite
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+ PatternRewriter &rewriter) const override {
+ // TODO: support 0-d corner case.
+ if (writeOp.getTransferRank() == 0)
+ return failure();
+
+ // TODO: support mask.
+ if (writeOp.getMask())
+ return failure();
+
+ auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
+ if (!srcType || !srcType.hasStaticShape())
+ return failure();
+
+ if (!writeOp.getPermutationMap().isMinorIdentity())
+ return failure();
+
+ auto targetType = writeOp.getVectorType();
+ if (targetType.getRank() <= 1)
+ return failure();
+
+ FailureOr<size_t> maybeDimsToDrop =
+ getTransferFoldableInnerUnitDims(srcType, targetType);
+ if (failed(maybeDimsToDrop))
+ return failure();
+
+ size_t dimsToDrop = maybeDimsToDrop.value();
+ if (dimsToDrop == 0)
+ return failure();
+
+ auto resultTargetVecType =
+ VectorType::get(targetType.getShape().drop_back(dimsToDrop),
+ targetType.getElementType());
+
+ MemRefType resultMemrefType =
+ getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
+ SmallVector<int64_t> offsets(srcType.getRank(), 0);
+ SmallVector<int64_t> strides(srcType.getRank(), 1);
+ ArrayAttr inBoundsAttr =
+ writeOp.getInBounds()
+ ? rewriter.getArrayAttr(
+ writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
+ : ArrayAttr();
+
+ Location loc = writeOp.getLoc();
+ Value rankedReducedView = rewriter.create<memref::SubViewOp>(
+ loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(),
+ strides);
+ auto permMap = getTransferMinorIdentityMap(
+ cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
+
+ auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, resultTargetVecType, writeOp.getVector());
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ writeOp, shapeCast, rankedReducedView,
+ writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
+ // TODO: support mask.
+ /*mask=*/Value(), inBoundsAttr);
+ return success();
+ }
+};
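As with the read pattern, the vector may be a lower-rank slice of the memref. A hypothetical input (value names invented; this case is not among the tests below) that the pattern would rewrite by folding the two trailing unit dims, shape-casting the vector to vector<16xf32> and writing it into a memref<512x16xf32> subview:

  %c0 = arith.constant 0 : index
  // rankDiff = 1; the two trailing unit dims of both types are foldable.
  vector.transfer_write %v, %dst[%i, %c0, %c0, %c0]
      {in_bounds = [true, true, true]}
      : vector<16x1x1xf32>, memref<512x16x1x1xf32>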
+
/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction suitable for MMT (matrix matrix multiplication
/// with the RHS transposed) lowering.
@@ -1696,7 +1805,9 @@ void mlir::vector::populateVectorReductionToContractPatterns(
void mlir::vector::
populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
+ patterns.add<DropInnerMostUnitDimsTransferRead,
+ DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
+ benefit);
}
void mlir::vector::populateSinkVectorBroadcastPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 0d2743b..d6d69c8 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -76,3 +76,56 @@ func.func @contiguous_inner_most_dim_out_of_bounds_2d(%arg0: memref<1x1xf32>) ->
// CHECK-NOT: memref.subview
// CHECK: %[[READ:.+]] = vector.transfer_read %[[SRC]]
// CHECK: return %[[READ]] : vector<4x8xf32>
+
+// -----
+
+func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
+ {in_bounds = [true, true, true, true, true]}
+ : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
+ return
+}
+// CHECK: func.func @drop_two_inner_most_dim_for_transfer_write
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME: memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1x1xf32> to vector<1x16x16xf32>
+// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
+// CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]]
+
+// -----
+
+func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]
+ {in_bounds = [true, true, true, true]}
+ : vector<1x16x16x1xf32>, memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
+ return
+}
+// CHECK: func.func @drop_inner_most_dim_for_transfer_write
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
+// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
+// CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]]
+
+// -----
+
+func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]
+ {in_bounds = [true, true, true]}
+ : vector<16x16x1xf32>, memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>
+ return
+}
+// The innermost unit dims cannot be dropped if the corresponding strides are
+// not 1.
+// CHECK: func.func @non_unit_strides
+// CHECK-NOT: memref.subview