aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp
blob: 00a076b6e974602af1f679b249da2331b52aa659 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting those linalg named ops that are essentially
// elementwise e.g. `linalg.exp`, to `linalg.elementwise`. This allows further
// optimization on `linalg.elementwise` such as folding transpose, broadcast.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-named-to-elementwise"

namespace {
/// Returns the `ElementwiseKind` that a `linalg.elementwise` op must carry to
/// express the same computation as the named linalg op `op`.
///
/// Only the named ops listed in the switch below are supported. Any other op
/// is a programming error and reaches `llvm_unreachable`.
ElementwiseKind getKind(Operation *op) {
  // Makes a case handler that ignores the matched op and yields `kind`.
  auto yields = [](ElementwiseKind kind) {
    return [kind](auto) { return kind; };
  };

  return llvm::TypeSwitch<Operation *, ElementwiseKind>(op)
      .Case<SelectOp>(yields(ElementwiseKind::select))
      .Case<AddOp>(yields(ElementwiseKind::add))
      .Case<SubOp>(yields(ElementwiseKind::sub))
      .Case<MulOp>(yields(ElementwiseKind::mul))
      .Case<DivOp>(yields(ElementwiseKind::div))
      .Case<DivUnsignedOp>(yields(ElementwiseKind::div_unsigned))
      .Case<PowFOp>(yields(ElementwiseKind::powf))
      .Case<ExpOp>(yields(ElementwiseKind::exp))
      .Case<LogOp>(yields(ElementwiseKind::log))
      .Case<AbsOp>(yields(ElementwiseKind::abs))
      .Case<CeilOp>(yields(ElementwiseKind::ceil))
      .Case<FloorOp>(yields(ElementwiseKind::floor))
      .Case<NegFOp>(yields(ElementwiseKind::negf))
      .Case<ReciprocalOp>(yields(ElementwiseKind::reciprocal))
      .Case<RoundOp>(yields(ElementwiseKind::round))
      .Case<SqrtOp>(yields(ElementwiseKind::sqrt))
      .Case<RsqrtOp>(yields(ElementwiseKind::rsqrt))
      .Case<SquareOp>(yields(ElementwiseKind::square))
      .Case<TanhOp>(yields(ElementwiseKind::tanh))
      .Case<ErfOp>(yields(ElementwiseKind::erf))
      .Default([](Operation *) -> ElementwiseKind {
        llvm_unreachable("unhandled case in named to elementwise");
      });
}

/// Rewrites a named linalg op of type `NamedOpTy` (e.g. `linalg.exp`) into a
/// `linalg.elementwise` op.
///
/// The replacement has the following properties:
///   - It takes the original op's DPS inputs and inits.
///   - Its `kind` attribute comes from `getKind`.
///   - It reuses the named op's `indexing_maps`.
/// The rewrite is unconditional: every matched op is replaced.
template <typename NamedOpTy>
struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
  using OpRewritePattern<NamedOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(NamedOpTy op,
                                PatternRewriter &rewriter) const override {
    MLIRContext *ctx = op.getContext();
    NamedAttribute kindEntry = rewriter.getNamedAttr(
        "kind", ElementwiseKindAttr::get(ctx, getKind(op)));
    NamedAttribute mapsEntry =
        rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps());
    SmallVector<NamedAttribute> attrs = {kindEntry, mapsEntry};

    // Build the elementwise op first, then replace the named op's results
    // with its results.
    rewriter.replaceOpWithNewOp<ElementwiseOp>(op, op.getDpsInputs(),
                                               op.getDpsInits(), attrs);
    return success();
  }
};
} // namespace

void mlir::linalg::populateLinalgNamedToElementwisePatterns(
    RewritePatternSet &patterns) {
  patterns.add<NamedToElementwisePattern<SelectOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<AddOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<SubOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<MulOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<DivOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<DivUnsignedOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<PowFOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<ExpOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<LogOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<AbsOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<CeilOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<FloorOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<NegFOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<ReciprocalOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<RoundOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<SqrtOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<RsqrtOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<SquareOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<TanhOp>>(patterns.getContext());
  patterns.add<NamedToElementwisePattern<ErfOp>>(patterns.getContext());
}