aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
blob: b1c0c3b161b2040c2d29052620ef3b0c768a0cc7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
//===- FoldIntoElementwise.cpp - Fold Ops into elementwise if possible ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements folding of producer ops into the indexing maps of a
// consuming linalg.elementwise op. Currently only linalg.transpose producers
// are folded.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDINTOELEMENTWISEPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-fold-into-elementwise"

namespace {
/// Folds `linalg.transpose` producers of `linalg.elementwise` inputs into the
/// elementwise op's indexing maps.
///
/// An input operand `%t`, produced by `linalg.transpose ins(%x ...)`, is read
/// by the elementwise op through its indexing map `M`, i.e. as `%t[M(i)]`.
/// The transpose computes `%t[j] = %x[T(j)]`, where `T` is the transpose's
/// input indexing map. Hence the same element is `%x[T(M(i))]`, and the
/// operand can be replaced by `%x` with the composed map `T o M`. For an
/// identity `M` this reduces to `T`. Non-identity `M` (e.g. broadcast-style
/// maps, or maps produced by a previous application of this pattern) are
/// handled by the same composition.
///
/// The pattern fails to match when no input operand is produced by a
/// transpose. The transpose op itself is not erased here; if it becomes
/// unused it can be cleaned up by the rewrite driver or a later DCE.
struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
  using OpRewritePattern<ElementwiseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ElementwiseOp op,
                                PatternRewriter &rewriter) const override {
    bool changed = false;
    SmallVector<Value> newIns;
    SmallVector<AffineMap> newMaps;
    for (OpOperand *operand : op.getDpsInputOperands()) {
      AffineMap map = op.getMatchingIndexingMap(operand);
      auto transposeOp = operand->get().getDefiningOp<TransposeOp>();

      if (!transposeOp) {
        // Not produced by a transpose: keep the original operand and its map.
        newIns.push_back(operand->get());
        newMaps.push_back(map);
        continue;
      }
      // The transpose's input indexing map (its inverse permutation map)
      // takes coordinates of the transposed value to coordinates of the
      // transpose's input.
      AffineMap transposeInputMap = transposeOp.getMatchingIndexingMap(
          transposeOp.getDpsInputOperand(0));
      newIns.push_back(transposeOp.getInput());
      // `a.compose(b)` is `a(b(i))`: first map the iteration point to the
      // transposed value's coordinates, then to the transpose input's.
      newMaps.push_back(transposeInputMap.compose(map));
      changed = true;
    }
    if (!changed)
      return failure();
    // The init operand is unchanged, so its indexing map is carried over.
    newMaps.push_back(op.getIndexingMapsArray().back());

    rewriter.replaceOpWithNewOp<ElementwiseOp>(
        op, newIns, op.getDpsInits()[0], op.getKindAttr(),
        rewriter.getAffineMapArrayAttr(newMaps));
    return success();
  }
};

/// Pass wrapper that greedily applies the fold-into-elementwise patterns to
/// the operation the pass is anchored on.
struct LinalgFoldIntoElementwisePass
    : public impl::LinalgFoldIntoElementwisePassBase<
          LinalgFoldIntoElementwisePass> {
  using impl::LinalgFoldIntoElementwisePassBase<
      LinalgFoldIntoElementwisePass>::LinalgFoldIntoElementwisePassBase;

  /// Collects the patterns and runs the greedy rewrite driver on the anchor
  /// op; the pass fails if the driver reports failure.
  void runOnOperation() override {
    Operation *root = getOperation();
    MLIRContext *ctx = root->getContext();

    RewritePatternSet foldPatterns(ctx);
    populateLinalgFoldIntoElementwisePatterns(foldPatterns);

    LogicalResult driverResult =
        applyPatternsGreedily(root, std::move(foldPatterns));
    if (succeeded(driverResult))
      return;
    signalPassFailure();
  }
};
} // namespace

/// Registers the patterns that fold producer ops into linalg.elementwise
/// indexing maps (currently only FoldTransposePattern) into `patterns`.
void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<FoldTransposePattern>(context);
}