author | Andrzej Warzyński <andrzej.warzynski@arm.com> | 2023-12-05 08:35:58 +0000
---|---|---
committer | GitHub <noreply@github.com> | 2023-12-05 08:35:58 +0000
commit | 2eb9e33cc57d5acc2232d468a99f0e35c8f583dc (patch) |
tree | d0d3f040e246ba0c22c79d16ca43b2bb521d2123 /mlir |
parent | e8dbe945f39f2249fe24e0d62ec8ac998e853c2b (diff) |
[mlir][Vector] Update patterns for flattening vector.xfer Ops (2/N) (#73523)
Updates patterns for flattening `vector.transfer_read` by relaxing the
requirement that the "collapsed" indices are all zero. This enables
collapsing cases like this one:
```mlir
%2 = vector.transfer_read %arg4[%c0, %arg0, %arg1, %c0] ... :
memref<1x43x4x6xi32>, vector<1x2x6xi32>
```
Previously, only the following case would be considered for collapsing
(where all indices are 0):
```mlir
%2 = vector.transfer_read %arg4[%c0, %c0, %c0, %c0] ... :
memref<1x43x4x6xi32>, vector<1x2x6xi32>
```
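For reference, after this change the first example above is rewritten roughly as follows. This is a sketch assembled from the test added in this patch; `%collapsed`, `%offset`, `%flat`, and the padding value `%pad` are illustrative names, not taken from the commit:
```mlir
// Sketch only: the source memref is collapsed, the non-zero indices are
// folded into a single trailing offset, and the flat read is reshaped back.
%collapsed = memref.collapse_shape %arg4 [[0], [1, 2, 3]]
    : memref<1x43x4x6xi32> into memref<1x1032xi32>
// Trailing offset for the indices [%c0, %arg0, %arg1, %c0], computed by the
// new index-linearization logic added in this patch.
%offset = affine.apply affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>()[%arg1, %arg0]
%flat = vector.transfer_read %collapsed[%c0, %offset], %pad {in_bounds = [true]}
    : memref<1x1032xi32>, vector<12xi32>
%2 = vector.shape_cast %flat : vector<12xi32> to vector<1x2x6xi32>
```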
Also adds some new comments and renames the `firstContiguousInnerDim`
variable to `firstDimToCollapse`, which better matches its actual
meaning.
Similar updates for `vector.transfer_write` will be implemented in a
follow-up patch.
Diffstat (limited to 'mlir')
4 files changed, 129 insertions, 11 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index aab7075..ed42e65 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -511,6 +511,8 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
 /// Checks that the indices corresponding to dimensions starting at
 /// `firstDimToCollapse` are constant 0, and writes to `outIndices`
 /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
+/// TODO: Extract the logic that writes to outIndices so that this method
+/// simply checks one pre-condition.
 static LogicalResult
 checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                  SmallVector<Value> &outIndices) {
@@ -542,18 +544,18 @@ class FlattenContiguousRowMajorTransferReadPattern
     auto loc = transferReadOp.getLoc();
     Value vector = transferReadOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
-    Value source = transferReadOp.getSource();
+    auto source = transferReadOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+    // 0. Check pre-conditions
     // Contiguity check is valid on tensors only.
     if (!sourceType)
       return failure();
+    // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
-      // Already 0D/1D, nothing to do.
       return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstContiguousInnerDim =
-        sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
       return failure();
@@ -561,26 +563,81 @@ class FlattenContiguousRowMajorTransferReadPattern
       return failure();
     if (transferReadOp.getMask())
       return failure();
+
     SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
-                                                firstContiguousInnerDim,
-                                                collapsedIndices)))
-      return failure();
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+
+    // 1. Collapse the source memref
     Value collapsedSource =
-        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
         dyn_cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
-    assert(collapsedRank == firstContiguousInnerDim + 1);
+    assert(collapsedRank == firstDimToCollapse + 1);
+
+    // 2. Generate input args for a new vector.transfer_read that will read
+    // from the collapsed memref.
+    // 2.1. New dim exprs + affine map
     SmallVector<AffineExpr, 1> dimExprs{
-        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
+    // 2.2 New indices
+    // If all the collapsed indices are zero then no extra logic is needed.
+    // Otherwise, a new offset/index has to be computed.
+    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+                                                firstDimToCollapse,
+                                                collapsedIndices))) {
+      // Copy all the leading indices
+      collapsedIndices = transferReadOp.getIndices();
+      collapsedIndices.resize(firstDimToCollapse);
+
+      // Compute the remaining trailing index/offset required for reading from
+      // the collapsed memref:
+      //
+      //    offset = 0
+      //    for (i = firstDimToCollapse; i < outputRank; ++i)
+      //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
+      //
+      // For this example:
+      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
+      //     memref<1x43x2xi32>, vector<1x2xi32>
+      // which would be collapsed to:
+      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
+      //     memref<1x86xi32>, vector<2xi32>
+      // one would get the following offset:
+      //    %offset = %arg0 * 43
+      AffineExpr offsetExpr, idxExpr;
+      bindSymbols(rewriter.getContext(), offsetExpr, idxExpr);
+
+      int64_t outputRank = transferReadOp.getIndices().size();
+      OpFoldResult offset =
+          rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+
+      for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
+        int64_t dim = dyn_cast<ShapedType>(source.getType()).getDimSize(i);
+        offset = affine::makeComposedFoldedAffineApply(
+            rewriter, loc, offsetExpr + dim * idxExpr,
+            {offset, transferReadOp.getIndices()[i]});
+      }
+      if (offset.is<Value>()) {
+        collapsedIndices.push_back(offset.get<Value>());
+      } else {
+        collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
+            loc, *getConstantIntValue(offset)));
+      }
+    }
+
+    // 3. Create new vector.transfer_read that reads from the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
         loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
     flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+    // 4. Replace the old transfer_read with the new one reading from the
+    // collapsed shape
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index ac0fe64..2ad992a 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -265,6 +265,11 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
     return false;
   auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
 
+  // TODO: Add support for memref with trailing dynamic shapes. Memrefs
+  // with leading dynamic dimensions are already supported.
+  if (ShapedType::isDynamicShape(memrefShape))
+    return false;
+
   // Cond 1: A contiguous memref will always have a unit trailing stride.
   if (strides.back() != 1)
     return false;
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 2ffe85b..603792e 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -41,6 +41,61 @@ func.func @transfer_read_dims_mismatch_contiguous(
 
 // -----
 
+func.func @transfer_read_dims_mismatch_non_zero_indices(
+    %idx_1: index,
+    %idx_2: index,
+    %m_in: memref<1x43x4x6xi32>,
+    %m_out: memref<1x2x6xi32>) {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+    memref<1x43x4x6xi32>, vector<1x2x6xi32>
+  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x2x6xi32>
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>
+
+// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
+// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
+// CHECK-SAME:      %[[M_IN:.*]]: memref<1x43x4x6xi32>,
+// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK:           %[[C_0:.*]] = arith.constant 0 : i32
+// CHECK:           %[[C_0_IDX:.*]] = arith.constant 0 : index
+// CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_2]], %[[IDX_1]]]
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK:           %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
+// CHECK:           vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
+
+// -----
+
+func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+    %idx_1: index,
+    %idx_2: index,
+    %m_in: memref<1x?x4x6xi32>,
+    %m_out: memref<1x2x6xi32>) {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+    memref<1x?x4x6xi32>, vector<1x2x6xi32>
+  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x2x6xi32>
+  return
+}
+
+// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
+// CHECK-SAME:      %[[M_IN:.*]]: memref<1x?x4x6xi32>,
+// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[M_IN]]{{.*}} : memref<1x?x4x6xi32>, vector<1x2x6xi32>
+// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_OUT]]{{.*}} : memref<1x2x6xi32> into memref<12xi32>
+// CHECK:           %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK:           vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
+
+// -----
+
 func.func @transfer_read_dims_mismatch_non_contiguous(
     %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
   %c0 = arith.constant 0 : index
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index feb716c..86b8d5f 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -454,6 +454,7 @@ struct TestFlattenVectorTransferPatterns
   }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<memref::MemRefDialect>();
+    registry.insert<affine::AffineDialect>();
   }
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
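These patterns are exposed via `vector::populateFlattenVectorTransferPatterns`, which the test pass above registers. Below is a minimal sketch of driving them programmatically; it is not part of this patch, and it assumes that entry point plus the greedy rewrite driver, with error handling elided:

```cpp
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Sketch only: collect the vector.transfer flattening patterns and apply
// them greedily until fixpoint. The patterns themselves bail out on masked,
// out-of-bounds, or non-contiguous transfers, so running them
// unconditionally is safe.
static LogicalResult flattenVectorTransfers(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  vector::populateFlattenVectorTransferPatterns(patterns);
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}
```

Note that, as of this patch, a pass doing this must also depend on `affine::AffineDialect` (in addition to `memref::MemRefDialect`), since the rewrite can now create `affine.apply` ops; that is exactly the registration added to the test pass above.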