diff options
3 files changed, 161 insertions, 42 deletions
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h index 2ca5562..e1e6a033 100644 --- a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h @@ -14,6 +14,37 @@ namespace mlir { namespace tensor { +/// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use +/// when combining a producer slice **into** a consumer slice. +/// +/// This function performs the following computation: +/// - Combined offsets = consumer_offsets * producer_strides + producer_offsets +/// - Combined sizes = consumer_sizes +/// - Combined strides = producer_strides * consumer_strides +LogicalResult +mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc, +                            ArrayRef<OpFoldResult> producerOffsets, +                            ArrayRef<OpFoldResult> producerSizes, +                            ArrayRef<OpFoldResult> producerStrides, +                            const llvm::SmallBitVector &droppedProducerDims, +                            ArrayRef<OpFoldResult> consumerOffsets, +                            ArrayRef<OpFoldResult> consumerSizes, +                            ArrayRef<OpFoldResult> consumerStrides, +                            SmallVector<OpFoldResult> &combinedOffsets, +                            SmallVector<OpFoldResult> &combinedSizes, +                            SmallVector<OpFoldResult> &combinedStrides); + +/// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use +/// when combining a `producer` slice op **into** a `consumer` slice op.
+LogicalResult +mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc, + OffsetSizeAndStrideOpInterface producer, + OffsetSizeAndStrideOpInterface consumer, + const llvm::SmallBitVector &droppedProducerDims, + SmallVector<OpFoldResult> &combinedOffsets, + SmallVector<OpFoldResult> &combinedSizes, + SmallVector<OpFoldResult> &combinedStrides); + //===----------------------------------------------------------------------===// // Extract slice from `tensor.collapse_shape` //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp index 48977a9..e448944 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" @@ -17,29 +17,101 @@ using namespace mlir; using namespace mlir::tensor; -/// Adds each corresponding pair of offsets in `offsets1` and `offsets2` and -/// returns the results. 
-static SmallVector<OpFoldResult> mergeOffsets(Location loc, -                                              ArrayRef<OpFoldResult> offsets1, -                                              ArrayRef<OpFoldResult> offsets2, -                                              OpBuilder &builder) { -  SmallVector<OpFoldResult> foldedOffsets; -  assert(offsets1.size() == offsets2.size()); -  foldedOffsets.reserve(offsets1.size()); - -  AffineExpr dim1, dim2; -  bindDims(builder.getContext(), dim1, dim2); - -  for (const auto &pair : llvm::zip(offsets1, offsets2)) { -    auto offset0 = -        getValueOrCreateConstantIndexOp(builder, loc, std::get<0>(pair)); -    auto offset1 = -        getValueOrCreateConstantIndexOp(builder, loc, std::get<1>(pair)); -    auto foldedOffset = -        makeComposedAffineApply(builder, loc, dim1 + dim2, {offset0, offset1}); -    foldedOffsets.push_back(foldedOffset.getResult()); +/// Creates an AffineExpr from `ofr`: if the OpFoldResult is a Value, creates an +/// AffineSymbolExpr and appends it to `symbols`; otherwise creates an +/// AffineConstantExpr. +static AffineExpr getAffineExpr(OpFoldResult ofr, +                                SmallVector<OpFoldResult> &symbols) { +  if (auto attr = ofr.dyn_cast<Attribute>()) { +    return getAffineConstantExpr(attr.cast<IntegerAttr>().getInt(), +                                 attr.getContext()); } - return foldedOffsets; + Value v = ofr.get<Value>(); + AffineExpr expr = getAffineSymbolExpr(symbols.size(), v.getContext()); + symbols.push_back(v); + return expr; +} + +/// Builds the AffineExpr incrementally for arithmetic operations. +static AffineExpr add(AffineExpr expr, OpFoldResult ofr, +                      SmallVector<OpFoldResult> &symbols) { +  return expr + getAffineExpr(ofr, symbols); +} +static AffineExpr mul(OpFoldResult lhs, OpFoldResult rhs, +                      SmallVector<OpFoldResult> &symbols) { +  return getAffineExpr(lhs, symbols) * getAffineExpr(rhs, symbols); +} + +/// Converts an AffineExpr to OpFoldResult by generating an `affine.apply` +/// op and folding it.
+static OpFoldResult getOpFoldResult(OpBuilder &builder, Location loc, + AffineExpr expr, + SmallVector<OpFoldResult> &symbols) { + AffineMap m = AffineMap::get(0, symbols.size(), expr); + return makeComposedFoldedAffineApply(builder, loc, m, symbols); +} + +LogicalResult tensor::mergeOffsetsSizesAndStrides( + OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> producerOffsets, + ArrayRef<OpFoldResult> producerSizes, + ArrayRef<OpFoldResult> producerStrides, + const llvm::SmallBitVector &droppedProducerDims, + ArrayRef<OpFoldResult> consumerOffsets, + ArrayRef<OpFoldResult> consumerSizes, + ArrayRef<OpFoldResult> consumerStrides, + SmallVector<OpFoldResult> &combinedOffsets, + SmallVector<OpFoldResult> &combinedSizes, + SmallVector<OpFoldResult> &combinedStrides) { + combinedOffsets.resize(producerOffsets.size()); + combinedSizes.resize(producerOffsets.size()); + combinedStrides.resize(producerOffsets.size()); + unsigned consumerPos = 0; + for (auto i : llvm::seq<unsigned>(0, producerOffsets.size())) { + if (droppedProducerDims.test(i)) { + // For dropped dims, get the values from the producer. + combinedOffsets[i] = producerOffsets[i]; + combinedSizes[i] = producerSizes[i]; + combinedStrides[i] = producerStrides[i]; + continue; + } + SmallVector<OpFoldResult> offsetSymbols, strideSymbols; + // The combined offset is computed as + // producer_offset + consumer_offset * producer_strides. + combinedOffsets[i] = + getOpFoldResult(builder, loc, + add(mul(consumerOffsets[consumerPos], + producerStrides[i], offsetSymbols), + producerOffsets[i], offsetSymbols), + offsetSymbols); + combinedSizes[i] = consumerSizes[consumerPos]; + // The combined stride is computed as + // consumer_stride * producer_stride. 
+ combinedStrides[i] = getOpFoldResult( + builder, loc, + mul(consumerStrides[consumerPos], producerStrides[i], strideSymbols), + strideSymbols); + consumerPos++; + } + return success(); +} + +LogicalResult tensor::mergeOffsetsSizesAndStrides( + OpBuilder &builder, Location loc, OffsetSizeAndStrideOpInterface producer, + OffsetSizeAndStrideOpInterface consumer, + const llvm::SmallBitVector &droppedProducerDims, + SmallVector<OpFoldResult> &combinedOffsets, + SmallVector<OpFoldResult> &combinedSizes, + SmallVector<OpFoldResult> &combinedStrides) { + SmallVector<OpFoldResult> consumerOffsets = consumer.getMixedOffsets(); + SmallVector<OpFoldResult> consumerSizes = consumer.getMixedSizes(); + SmallVector<OpFoldResult> consumerStrides = consumer.getMixedStrides(); + SmallVector<OpFoldResult> producerOffsets = producer.getMixedOffsets(); + SmallVector<OpFoldResult> producerSizes = producer.getMixedSizes(); + SmallVector<OpFoldResult> producerStrides = producer.getMixedStrides(); + return tensor::mergeOffsetsSizesAndStrides( + builder, loc, producerOffsets, producerSizes, producerStrides, + droppedProducerDims, consumerOffsets, consumerSizes, consumerStrides, + combinedOffsets, combinedSizes, combinedStrides); } namespace { @@ -53,24 +125,15 @@ struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> { if (!prevOp) return failure(); - if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride()) + SmallVector<OpFoldResult> newOffsets, newSizes, newStrides; + if (failed(mergeOffsetsSizesAndStrides(rewriter, nextOp.getLoc(), prevOp, + nextOp, prevOp.getDroppedDims(), + newOffsets, newSizes, newStrides))) return failure(); - auto prevResultType = prevOp.getType().cast<ShapedType>(); - if (prevOp.getSourceType().getRank() != prevResultType.getRank()) - return rewriter.notifyMatchFailure( - prevOp, "rank-reducing producder case unimplemented"); - - Location loc = nextOp.getLoc(); - - SmallVector<OpFoldResult> prevOffsets = prevOp.getMixedOffsets(); - 
SmallVector<OpFoldResult> nextOffsets = nextOp.getMixedOffsets(); - SmallVector<OpFoldResult> foldedOffsets = - mergeOffsets(loc, prevOffsets, nextOffsets, rewriter); - - rewriter.replaceOpWithNewOp<ExtractSliceOp>( - nextOp, nextOp.getType(), prevOp.getSource(), foldedOffsets, - nextOp.getMixedSizes(), nextOp.getMixedStrides()); + rewriter.replaceOpWithNewOp<ExtractSliceOp>(nextOp, nextOp.getType(), + prevOp.getSource(), newOffsets, + newSizes, newStrides); return success(); } }; diff --git a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir index 45a3f37..f5d77f6 100644 --- a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir +++ b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir @@ -9,10 +9,12 @@ func.func @extract_slice_same_rank( // CHECK-LABEL: func.func @extract_slice_same_rank // CHECK-SAME: (%[[SOURCE:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index) -// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET0]], %[[OFFSET1]]] +// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET1]], %[[OFFSET0]]] // CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][7, 9, 11, %[[OFFSET]]] [8, 16, 32, %[[SIZE1]]] [1, 1, 1, 1] // CHECK: return %[[EXTRACT]] : tensor<8x16x32x?xf32> +// ----- + func.func @extract_slice_rank_reducing_consumer( %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<16x?xf32> { %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32> @@ -23,6 +25,8 @@ func.func @extract_slice_rank_reducing_consumer( // CHECK-LABEL: func.func @extract_slice_rank_reducing_consumer // CHECK: tensor.extract_slice %{{.+}}[7, 9, 11, %{{.+}}] [1, 16, 1, %{{.+}}] [1, 1, 1, 1] : 
tensor<?x?x?x?xf32> to tensor<16x?xf32> +// ----- + func.func @extract_slice_rank_reducing_producer( %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x?xf32> { %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [1, 128, 1, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x?xf32> @@ -30,8 +34,27 @@ func.func @extract_slice_rank_reducing_producer( return %1: tensor<8x?xf32> } -// CHECK-LABEL: func.func @extract_slice_rank_reducing_producer -// CHECK-COUNT-2: tensor.extract_slice +// CHECK-LABEL: func.func @extract_slice_rank_reducing_producer +// CHECK-SAME: (%[[SRC:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index) +// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET1]], %[[OFFSET0]]] +// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][0, 8, 2, %[[OFFSET]]] [1, 8, 1, %[[SIZE1]]] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<8x?xf32> +// CHECK: return %[[EXTRACT]] : tensor<8x?xf32> + +// ----- + +func.func @extract_slice_non_one_stride( + %src: tensor<?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index, %stride0: index, %stride1: index) -> tensor<?xf32> { + %0 = tensor.extract_slice %src[%offset0] [%size0] [%stride0] : tensor<?xf32> to tensor<?xf32> + %1 = tensor.extract_slice %0[%offset1] [%size1] [%stride1] : tensor<?xf32> to tensor<?xf32> + return %1: tensor<?xf32> +} + +// CHECK-LABEL: func.func @extract_slice_non_one_stride +// CHECK-SAME: (%[[SRC:.+]]: tensor<?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index, %[[STRIDE0:.+]]: index, %[[STRIDE1:.+]]: index) +// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>()[%[[OFFSET1]], %[[STRIDE0]], %[[OFFSET0]]] +// CHECK: %[[STRIDE:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%[[STRIDE1]], %[[STRIDE0]]] +// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice 
%[[SRC]][%[[OFFSET]]] [%[[SIZE1]]] [%[[STRIDE]]] : tensor<?xf32> to tensor<?xf32> +// CHECK: return %[[EXTRACT]] : tensor<?xf32> // ----- @@ -47,6 +70,8 @@ func.func @insert_slice_rank_reducing( // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SRC]] into %[[DST]][6, 7, 8, %[[IDX]]] [1, 1, 16, 1] [1, 1, 1, 1] // CHECK: return %[[INSERT]] +// ----- + func.func @insert_slice_rank_reducing_dynamic_shape( %dst: tensor<128x128x128x128xf32>, %mid: tensor<1x?x1xf32>, %src: tensor<?xf32>, %offset: index, %size: index) -> tensor<128x128x128x128xf32> { %0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, %size, 1] [1, 1, 1] : tensor<?xf32> into tensor<1x?x1xf32> |