aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp
blob: 9a0651a5445e60d753d0571dd3bd2c1c19bcbe50 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
//===- MathToEmitC.cpp - Math to EmitC Patterns -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {
/// Rewrite pattern that replaces a `math` dialect op of type `OpType` with an
/// `emitc.call_opaque` op.
///
/// The pattern stores the base name of the function to call and the language
/// target; `matchAndRewrite` combines the two to form the final callee name.
template <typename OpType>
class LowerToEmitCCallOpaque : public OpRewritePattern<OpType> {
public:
  /// Creates the pattern.
  ///
  /// \param context        MLIR context forwarded to `OpRewritePattern`.
  /// \param calleeStr      Base name of the function to call (e.g. "floor").
  /// \param languageTarget Language target used to adjust the callee name.
  LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr,
                         emitc::LanguageTarget languageTarget)
      : OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)),
        languageTarget(languageTarget) {}

  /// Replaces `op` with an `emitc.call_opaque`; defined out of line below.
  LogicalResult matchAndRewrite(OpType op,
                                PatternRewriter &rewriter) const override;

private:
  /// Base callee name, before any target-specific prefix or suffix is applied.
  std::string calleeStr;
  /// Language target that decides how the callee name is adjusted.
  emitc::LanguageTarget languageTarget;
};

/// Replaces `op` with an `emitc.call_opaque` that forwards the op's operands
/// and produces the op's result type.
///
/// The match fails unless every operand type and every result type is f32 or
/// f64. The callee name is derived from `calleeStr`:
///   - cpp11: "std::" is prepended.
///   - c99:   "f" is appended when the first operand is f32.
///   - any other target: the base name is used unchanged.
template <typename OpType>
LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
    OpType op, PatternRewriter &rewriter) const {
  const auto &isF32OrF64 = llvm::IsaPred<Float32Type, Float64Type>;
  if (!(llvm::all_of(op->getOperandTypes(), isF32OrF64) &&
        llvm::all_of(op->getResultTypes(), isF32OrF64)))
    return rewriter.notifyMatchFailure(
        op.getLoc(),
        "expected all operands and results to be of type f32 or f64");

  // Start from the base name and adjust it for the requested target.
  std::string callee = calleeStr;
  switch (languageTarget) {
  case emitc::LanguageTarget::cpp11:
    callee = "std::" + calleeStr;
    break;
  case emitc::LanguageTarget::c99:
    // Only the first operand is inspected to pick the `f`-suffixed variant.
    if (op->getOperandTypes()[0].isF32())
      callee = calleeStr + "f";
    break;
  default:
    break;
  }

  rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(op, op.getType(), callee,
                                                   op->getOperands());
  return success();
}

} // namespace

// Adds to `patterns` one `LowerToEmitCCallOpaque` pattern per supported
// `math` operation, so that each is replaced by an `emitc.call_opaque`. Every
// pattern receives the base name of the corresponding <math.h> function
// together with `languageTarget`.
void mlir::populateConvertMathToEmitCPatterns(
    RewritePatternSet &patterns, emitc::LanguageTarget languageTarget) {
  MLIRContext *ctx = patterns.getContext();
  patterns
      .insert<LowerToEmitCCallOpaque<math::FloorOp>>(ctx, "floor",
                                                     languageTarget)
      .insert<LowerToEmitCCallOpaque<math::RoundOp>>(ctx, "round",
                                                     languageTarget)
      .insert<LowerToEmitCCallOpaque<math::ExpOp>>(ctx, "exp", languageTarget)
      .insert<LowerToEmitCCallOpaque<math::CosOp>>(ctx, "cos", languageTarget)
      .insert<LowerToEmitCCallOpaque<math::SinOp>>(ctx, "sin", languageTarget)
      .insert<LowerToEmitCCallOpaque<math::AcosOp>>(ctx, "acos",
                                                    languageTarget)
      .insert<LowerToEmitCCallOpaque<math::AsinOp>>(ctx, "asin",
                                                    languageTarget)
      .insert<LowerToEmitCCallOpaque<math::Atan2Op>>(ctx, "atan2",
                                                     languageTarget)
      .insert<LowerToEmitCCallOpaque<math::CeilOp>>(ctx, "ceil",
                                                    languageTarget)
      .insert<LowerToEmitCCallOpaque<math::AbsFOp>>(ctx, "fabs",
                                                    languageTarget)
      .insert<LowerToEmitCCallOpaque<math::PowFOp>>(ctx, "pow",
                                                    languageTarget);
}