aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
blob: 1e3b377ab85c72ef2cc20e2b0841326c4357662a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
//===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
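// Swap a `tensor.extract_slice` with the producer of the source if the producer
// implements the `TilingInterface`. When used in conjunction with tiling this
// effectively tiles + fuses the producer with its consumer. Symmetrically,
// `tensor.insert_slice` ops paired with operands of a `TilingInterface`
// consumer can be replaced by a tiled implementation of that consumer computed
// from the offsets and sizes of the inserted slices.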
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "tensor-swap-slices"

using namespace mlir;

/// Computes, through `TilingInterface::generateResultTileValue`, the tile of
/// `producer` that `sliceOp` extracts. The tile uses the slice's offsets and
/// sizes.
///
/// Returns failure when:
///   * the op owning `producer` does not implement `TilingInterface`,
///   * any stride of `sliceOp` is not the constant 1, or
///   * the producer fails to generate the tile.
///
/// If `sliceOp` is rank-reducing, an extra rank-reducing `tensor.extract_slice`
/// is created on top of the tile so that the returned value has the same type
/// as `sliceOp`. Uses of `sliceOp` are not replaced here; the caller receives
/// the `TilingResult` and decides what to do with it.
FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
    OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
  auto tilingProducer = dyn_cast<TilingInterface>(producer.getOwner());
  if (!tilingProducer)
    return failure();

  // Bail out on any non-unit stride: `TilingInterface` currently only supports
  // strides being 1.
  const bool allUnitStrides =
      llvm::all_of(sliceOp.getMixedStrides(), isOneInteger);
  if (!allUnitStrides)
    return failure();

  FailureOr<TilingResult> tiled = tilingProducer.generateResultTileValue(
      builder, producer.getResultNumber(), sliceOp.getMixedOffsets(),
      sliceOp.getMixedSizes());
  if (failed(tiled))
    return failure();

  // A non-rank-reducing slice already has the type of the generated tile.
  if (sliceOp.getDroppedDims().none())
    return *tiled;

  // The tile keeps the full rank of the slice source. Take a slice that covers
  // the whole tile (zero offsets, unit strides, the original sizes) but drops
  // the same dimensions as `sliceOp`, which recovers `sliceOp`'s result type.
  assert(tiled->tiledValues.size() == 1 &&
         "expected only a single tiled result value to replace the extract "
         "slice");
  const int64_t sourceRank = sliceOp.getSourceType().getRank();
  SmallVector<OpFoldResult> zeroOffsets(sourceRank, builder.getIndexAttr(0));
  SmallVector<OpFoldResult> unitStrides(sourceRank, builder.getIndexAttr(1));
  tiled->tiledValues[0] = tensor::ExtractSliceOp::create(
      builder, sliceOp.getLoc(), sliceOp.getType(), tiled->tiledValues[0],
      zeroOffsets, sliceOp.getMixedSizes(), unitStrides);
  return *tiled;
}

/// Tiles the single consumer op that owns all of `consumerOperands`. The
/// operand tiles are described by the matching `tensor.insert_slice` ops in
/// `sliceOps`: operand `consumerOperands[i]` is tiled using the offsets and
/// sizes of `sliceOps[i]`.
///
/// Returns failure (with a debug message under `-debug-only=tensor-swap-slices`)
/// when:
///   * `sliceOps` is empty,
///   * `sliceOps` and `consumerOperands` have different lengths,
///   * the owner of the operands does not implement `TilingInterface`,
///   * the operands do not all belong to the same operation,
///   * any slice has a stride that is not the constant 1, or
///   * `TilingInterface::getTiledImplementationFromOperandTiles` fails.
FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
    OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps,
    ArrayRef<OpOperand *> consumerOperands) {
  if (sliceOps.empty()) {
    LLVM_DEBUG(
        { llvm::dbgs() << "expected candidate slices list to be non-empty\n"; });
    return failure();
  }
  if (sliceOps.size() != consumerOperands.size()) {
    LLVM_DEBUG({
      llvm::dbgs()
          << "expected as many operands as the number of slices passed\n";
    });
    return failure();
  }
  auto consumerOp =
      dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
  if (!consumerOp) {
    LLVM_DEBUG({
      llvm::dbgs() << "expected consumer operation to implement "
                      "TilingInterface\n";
    });
    return failure();
  }
  for (auto opOperand : consumerOperands.drop_front()) {
    if (opOperand->getOwner() != consumerOp) {
      LLVM_DEBUG({
        llvm::dbgs()
            << "expected all consumer operands to be from the same operation\n";
      });
      return failure();
    }
  }

  auto consumerOperandNums = llvm::map_to_vector(
      consumerOperands, [](OpOperand *opOperand) -> unsigned {
        return opOperand->getOperandNumber();
      });

  // Gather the per-operand tile positions from the insert_slice ops.
  SmallVector<SmallVector<OpFoldResult>> allOffsets;
  SmallVector<SmallVector<OpFoldResult>> allSizes;
  allOffsets.reserve(sliceOps.size());
  allSizes.reserve(sliceOps.size());
  for (auto sliceOp : sliceOps) {
    // `TilingInterface` currently only supports strides being 1.
    if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger)) {
      LLVM_DEBUG({
        llvm::dbgs() << "expected all insert_slice strides to be 1\n";
      });
      return failure();
    }
    allOffsets.push_back(sliceOp.getMixedOffsets());
    allSizes.push_back(sliceOp.getMixedSizes());
  }

  // Failure from the interface propagates unchanged to the caller.
  return consumerOp.getTiledImplementationFromOperandTiles(
      builder, consumerOperandNums, allOffsets, allSizes);
}