aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
blob: ff003e486d21c87c665462deafd1b7c8c7a947b9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
//===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {
/// Folds a tensor.extract_slice that reads from the result of another
/// tensor.extract_slice into a single tensor.extract_slice that reads from the
/// source of the first one.
// TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  /// Matches `nextOp` when its source is defined by an extract_slice and
  /// replaces it with one extract_slice on that producer's source. Fails
  /// without touching the IR if there is no such producer or if the
  /// offsets/sizes/strides of the two ops cannot be merged.
  LogicalResult matchAndRewrite(ExtractSliceOp nextOp,
                                PatternRewriter &rewriter) const override {
    // The pattern only applies to a slice taken out of another slice.
    ExtractSliceOp producerOp =
        nextOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!producerOp)
      return failure();

    // Let the affine helper combine the slice parameters of both ops. The
    // dims dropped by the producer are passed along so the helper can account
    // for the rank reduction between the two ops.
    SmallVector<OpFoldResult> mergedOffsets;
    SmallVector<OpFoldResult> mergedSizes;
    SmallVector<OpFoldResult> mergedStrides;
    LogicalResult mergeStatus = affine::mergeOffsetsSizesAndStrides(
        rewriter, nextOp.getLoc(), producerOp, nextOp,
        producerOp.getDroppedDims(), mergedOffsets, mergedSizes,
        mergedStrides);
    if (failed(mergeStatus))
      return failure();

    // The replacement keeps the result type of the op being rewritten.
    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        nextOp, nextOp.getType(), producerOp.getSource(), mergedOffsets,
        mergedSizes, mergedStrides);
    return success();
  }
};

/// Folds an insert-like op (`OpTy`) whose source is produced by a
/// tensor.insert_slice into a single `OpTy` that inserts the source of that
/// producer directly.
// TODO: move to FoldTensorSubsetOps and unify APIs with FoldMemRefAliasOps.
template <typename OpTy>
struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  /// Matches `nextOp` when its source comes from a unit-stride, statically
  /// shaped tensor.insert_slice whose source type is a rank-reduced version of
  /// its destination type; `nextOp` must have unit strides as well. The
  /// replacement reuses the destination, offsets, sizes and strides of
  /// `nextOp`.
  LogicalResult matchAndRewrite(OpTy nextOp,
                                PatternRewriter &rewriter) const override {
    InsertSliceOp producerOp =
        nextOp.getSource().template getDefiningOp<InsertSliceOp>();
    if (!producerOp)
      return failure();

    // Both ops are required to have unit strides.
    const bool allUnitStrides =
        producerOp.hasUnitStride() && nextOp.hasUnitStride();
    if (!allUnitStrides)
      return failure();

    auto producerSourceType = producerOp.getSourceType();
    auto producerDestType = producerOp.getDestType();

    // The producer must be rank reducing so that what it inserts covers the
    // full tensor that `nextOp` inserts.
    if (isRankReducedType(producerDestType, producerSourceType) !=
        SliceVerificationResult::Success)
      return failure();

    // The rank-reduction check above also accepts dynamic dimensions, e.g.
    // inserting <?xf32> into <1x?x1xf32>. In that situation there is no
    // guarantee that the dynamic size covers the full tensor, so give up.
    const bool producerFullyStatic =
        producerSourceType.hasStaticShape() && producerDestType.hasStaticShape();
    if (!producerFullyStatic)
      return failure();

    rewriter.replaceOpWithNewOp<OpTy>(
        nextOp, producerOp.getSource(), nextOp.getDest(),
        nextOp.getMixedOffsets(), nextOp.getMixedSizes(),
        nextOp.getMixedStrides());
    return success();
  }
};

/// Drop redundant rank expansion of insert_slice that are directly followed
/// by extract_slice. E.g.:
/// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
/// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
///     : tensor<1x1x5x10xf32> to tensor<2x2xf32>
///
/// The extract_slice is rewritten to read directly from the source of the
/// insert_slice, keeping only the offsets/sizes/strides of the dims it does
/// not drop, and the insert_slice is erased.
struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
    : public OpRewritePattern<ExtractSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  /// Matches a rank-reducing `extractSliceOp` whose source is a cast-like,
  /// single-use tensor.insert_slice that expands exactly the dims dropped by
  /// the extract_slice. Returns failure without modifying the IR otherwise.
  LogicalResult matchAndRewrite(ExtractSliceOp extractSliceOp,
                                PatternRewriter &rewriter) const override {
    // Nothing to do if no dims are dropped.
    llvm::SmallBitVector droppedDims = extractSliceOp.getDroppedDims();
    if (droppedDims.none())
      return failure();

    // Look for tensor.insert_slice op that has an inverse rank expansion.
    auto insertSliceOp =
        extractSliceOp.getSource().getDefiningOp<InsertSliceOp>();
    if (!insertSliceOp)
      return failure();
    llvm::SmallBitVector expandedDims = insertSliceOp.getDroppedDims();

    // TODO: This could be extended to support cases where the dropped dims are
    // a subset of the expanded dims.
    if (expandedDims != droppedDims)
      return failure();

    // The tensor.insert_slice may not be redundant if it has multiple users.
    if (!insertSliceOp->hasOneUse())
      return failure();

    // Only consider tensor.insert_slice ops that are pure rank-reductions.
    // I.e., no elements are taken from the destination.
    if (!isCastLikeInsertSliceOp(insertSliceOp))
      return failure();

    // Extract directly from the source.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(extractSliceOp);

    // Query the slice parameters once up front: they do not change while
    // iterating, so there is no reason to rebuild them for every dimension.
    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();

    // Keep only the entries of the dims that survive the rank reduction.
    SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
    for (int64_t i = 0, e = extractSliceOp.getSourceType().getRank(); i < e;
         ++i) {
      if (droppedDims.test(i))
        continue;
      newOffsets.push_back(mixedOffsets[i]);
      newSizes.push_back(mixedSizes[i]);
      newStrides.push_back(mixedStrides[i]);
    }
    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        extractSliceOp, /*source=*/insertSliceOp.getSource(), newOffsets,
        newSizes, newStrides);
    // The only user of the insert_slice was the op just replaced.
    rewriter.eraseOp(insertSliceOp);
    return success();
  }
};

/// Drop redundant rank expansion of insert_slice that directly follows
/// extract_slice.
///
/// This can be done when the insert_slice op purely expands ranks (adds unit
/// dims) and the extract_slice drops corresponding unit dims. For example:
///
/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
///     : tensor<2x8xf32> to tensor<8xf32>
/// %inserted_slice = tensor.insert_slice %extracted_slice
///     into %dest[0, 0] [1, 8] [1, 1]
///     : tensor<8xf32> into tensor<1x8xf32>
///
/// can be folded into:
///
/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
///     : tensor<2x8xf32> to tensor<1x8xf32>
struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
    : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  /// Matches a cast-like `insertSliceOp` whose source is a single-use
  /// tensor.extract_slice, and whose dropped-dim pattern can be lined up with
  /// the one of that extract_slice. On success the insert_slice is replaced by
  /// an extract_slice with the same offsets/sizes/strides as the original
  /// extract_slice but producing the insert_slice result type, and the
  /// original extract_slice is erased.
  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    tensor::ExtractSliceOp producerOp =
        insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
    if (!producerOp) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "source is not extract_slice");
    }

    // The extract_slice gets erased below, so nothing else may use it.
    if (!producerOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "source has multi-uses");
    }

    // The insert_slice must do nothing but add unit dims.
    if (!isCastLikeInsertSliceOp(insertSliceOp)) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "insert_slice is not cast-like");
    }

    llvm::SmallBitVector extractDropped = producerOp.getDroppedDims();
    llvm::SmallBitVector insertDropped = insertSliceOp.getDroppedDims();
    const unsigned numExtractDims = extractDropped.size();
    const unsigned numInsertDims = insertDropped.size();

    // The insert_slice must not expand to more dims than the extract_slice
    // started from.
    if (numExtractDims < numInsertDims) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "insert_slice expands more dims");
    }

    // Greedily line up the insert_slice dims with the extract_slice dims:
    // walk the extract_slice dims left to right and pair the current
    // insert_slice dim with the first extract_slice dim that agrees on being
    // dropped or kept. The walk stops early once every insert_slice dim has
    // found a partner.
    unsigned insertIdx = 0;
    for (unsigned extractIdx = 0;
         extractIdx < numExtractDims && insertIdx < numInsertDims;
         ++extractIdx) {
      const bool extractDrops = extractDropped.test(extractIdx);
      const bool insertDrops = insertDropped.test(insertIdx);

      // Same drop/keep state on both sides: paired, move on to the next
      // insert_slice dim.
      if (extractDrops == insertDrops) {
        ++insertIdx;
        continue;
      }

      // The states differ. If the insert_slice side is the dropped one, the
      // extract_slice keeps a dim where a dropped one is needed, so there are
      // not enough extract dropped dims to cover the insert dropped dims.
      if (insertDrops) {
        return rewriter.notifyMatchFailure(insertSliceOp,
                                           "insert_slice drops more unit dims");
      }

      // Otherwise the extract_slice drops a dim that the insert_slice keeps:
      // skip it and try the next extract_slice dim against the same
      // insert_slice dim.
    }

    // Some insert_slice dims were left without a partner.
    if (insertIdx != numInsertDims) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "insert_slice has unmatched dims");
    }

    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        insertSliceOp, insertSliceOp.getType(), producerOp.getSource(),
        producerOp.getMixedOffsets(), producerOp.getMixedSizes(),
        producerOp.getMixedStrides());
    rewriter.eraseOp(producerOp);

    return success();
  }
};
} // namespace

/// Adds the patterns that merge consecutive extract_slice ops and consecutive
/// insert_slice / parallel_insert_slice ops to `patterns`.
void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<MergeConsecutiveExtractSlice>(context);
  patterns.add<MergeConsecutiveInsertSlice<InsertSliceOp>>(context);
  patterns.add<MergeConsecutiveInsertSlice<ParallelInsertSliceOp>>(context);
}

/// Adds to `patterns` the two patterns that drop a redundant insert_slice rank
/// expansion: one rooted on extract_slice-of-insert_slice and one rooted on
/// insert_slice-of-extract_slice.
void mlir::tensor::populateDropRedundantInsertSliceRankExpansionPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice>(context);
  patterns.add<DropRedundantRankExpansionOnInsertSliceOfExtractSlice>(context);
}