aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
blob: 9f36e5c369d066c27274637dd55d018b23d5629b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
//===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
#define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"

namespace mlir {

/// Detection trait for the `getFastmath` instance method.
///
/// Declared directly in the enclosing namespace rather than in an anonymous
/// namespace: this is a header, and an anonymous namespace would create a
/// distinct entity in every translation unit that the (identically defined)
/// `OpToFuncCallLowering` template then refers to, which is an ODR hazard.
template <typename T>
using has_get_fastmath_t = decltype(std::declval<T>().getFastmath());

/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
/// `f32ApproxFunc` or `f16Func` or `i32Func` depending on the element type and
/// the fastMathFlag of that Op, if present. The function declaration is added
/// in case it was not added before.
///
/// If the input values are of bf16 type (or f16 type if f16Func is empty), the
/// value is first cast to f32, the function called and then the result cast
/// back.
///
/// If SourceOp produces an `i1`, the function is declared to return `i32` (the
/// ABI representation of booleans) and the call result is compared against
/// zero to recover the `i1`.
///
/// Example with NVVM:
///   %exp_f32 = math.exp %arg_f32 : f32
///
/// will be transformed into
///   llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
///
/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
/// to the approximate calculation function.
///
/// Also example with NVVM:
///   %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
///
/// will be transformed into
///   llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
///
/// Final example with NVVM:
///   %pow_f32 = math.fpowi %arg_f32, %arg_i32
///
/// will be transformed into
///   llvm.call @__nv_powif(%arg_f32, %arg_i32) : (f32, i32) -> f32
///
/// When no function name is configured for the (possibly promoted) type of the
/// first operand, the pattern fails without modifying the IR.
template <typename SourceOp>
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
  explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
                                StringRef f32Func, StringRef f64Func,
                                StringRef f32ApproxFunc, StringRef f16Func,
                                StringRef i32Func = "",
                                PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<SourceOp>(lowering, benefit), f32Func(f32Func),
        f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func),
        i32Func(i32Func) {}

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    using LLVM::LLVMFuncOp;

    static_assert(
        std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
        "expected single result op");

    bool isResultBool = op->getResultTypes().front().isInteger(1);
    if constexpr (!std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
                                   SourceOp>::value) {
      assert(op->getNumOperands() > 0 &&
             "expected op to take at least one operand");
      assert((op->getResultTypes().front() == op->getOperand(0).getType() ||
              isResultBool) &&
             "expected op with same operand and result types");
    }

    if (!op->template getParentOfType<FunctionOpInterface>()) {
      return rewriter.notifyMatchFailure(
          op, "expected op to be within a function region");
    }

    ValueRange operands = adaptor.getOperands();
    if (operands.empty())
      return rewriter.notifyMatchFailure(op, "expected at least one operand");

    // Select the callee before creating any operation: a failing pattern must
    // leave the IR untouched, so the f32 promotion casts are only materialized
    // once a function is known to exist for the promoted operand type.
    Type originalType = operands.front().getType();
    Type castedOperandType = getCastedType(originalType);
    StringRef funcName = getFunctionName(castedOperandType, op);
    if (funcName.empty())
      return rewriter.notifyMatchFailure(
          op, "no function configured for the operand type");

    SmallVector<Value, 1> castedOperands;
    castedOperands.reserve(operands.size());
    for (Value operand : operands)
      castedOperands.push_back(maybeCast(operand, rewriter));

    // At ABI level, booleans are treated as i32.
    Type resultType =
        isResultBool ? rewriter.getIntegerType(32) : castedOperandType;
    Type funcType = getFunctionType(resultType, castedOperands);

    LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
    auto callOp =
        LLVM::CallOp::create(rewriter, op->getLoc(), funcOp, castedOperands);

    // Boolean results are handled first so the i32 ABI value is always turned
    // back into an i1, even if the first operand itself happens to be i32.
    // Zero is interpreted as false and any non-zero value as true. Since there
    // is no guarantee of a specific value being used to indicate true, compare
    // for inequality with zero (rather than truncate or shift).
    if (isResultBool) {
      Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(),
                                            rewriter.getIntegerType(32),
                                            rewriter.getI32IntegerAttr(0));
      Value truncated =
          LLVM::ICmpOp::create(rewriter, op->getLoc(), LLVM::ICmpPredicate::ne,
                               callOp.getResult(), zero);
      rewriter.replaceOp(op, {truncated});
      return success();
    }

    if (resultType == originalType) {
      rewriter.replaceOp(op, {callOp.getResult()});
      return success();
    }

    // The operand was promoted to f32; narrow the result back to its type.
    assert(callOp.getResult().getType().isF32() &&
           "only f32 types are supposed to be truncated back");
    Value truncated = LLVM::FPTruncOp::create(rewriter, op->getLoc(),
                                              originalType, callOp.getResult());
    rewriter.replaceOp(op, {truncated});
    return success();
  }

  /// Returns the type an operand of type `type` is passed as to the called
  /// function: bf16 (and f16 when `f16Func` is empty) is promoted to f32, and
  /// every other type is passed through unchanged. Creates no IR.
  Type getCastedType(Type type) const {
    bool promoteToF32 = isa<BFloat16Type>(type) ||
                        (isa<Float16Type>(type) && f16Func.empty());
    if (!promoteToF32)
      return type;
    return Float32Type::get(type.getContext());
  }

  /// Extends `operand` to f32 with `llvm.fpext` when `getCastedType` requires
  /// a promotion; otherwise returns `operand` unchanged.
  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
    Type type = operand.getType();
    Type castedType = getCastedType(type);
    if (castedType == type)
      return operand;

    return LLVM::FPExtOp::create(rewriter, operand.getLoc(), castedType,
                                 operand);
  }

  /// Builds the LLVM function type `(operand types...) -> resultType`.
  Type getFunctionType(Type resultType, ValueRange operands) const {
    SmallVector<Type> operandTypes(operands.getTypes());
    return LLVM::LLVMFunctionType::get(resultType, operandTypes);
  }

  /// Returns the `llvm.func` named `funcName` visible from `op`, or declares a
  /// new one with type `funcType` right before the function enclosing `op`.
  LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
                                     Operation *op) const {
    using LLVM::LLVMFuncOp;

    auto funcAttr = StringAttr::get(op->getContext(), funcName);
    auto funcOp =
        SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
    if (funcOp)
      return funcOp;

    auto parentFunc = op->getParentOfType<FunctionOpInterface>();
    assert(parentFunc && "expected there to be a parent function");
    OpBuilder b(parentFunc);

    // Create a valid global location removing any metadata attached to the
    // location as debug info metadata inside of a function cannot be used
    // outside of that function.
    auto globalloc = op->getLoc()->findInstanceOfOrUnknown<FileLineColLoc>();
    return LLVMFuncOp::create(b, globalloc, funcName, funcType);
  }

  /// Picks the configured function name for operand type `type`. An empty
  /// result means no function is available for that type.
  StringRef getFunctionName(Type type, SourceOp op) const {
    bool useApprox = false;
    if constexpr (llvm::is_detected<has_get_fastmath_t, SourceOp>::value) {
      arith::FastMathFlags flag = op.getFastmath();
      // Only the `afn` bit is tested; per the class comment, `fast` also
      // selects the approximate function.
      useApprox = (static_cast<uint32_t>(arith::FastMathFlags::afn) &
                   static_cast<uint32_t>(flag)) != 0 &&
                  !f32ApproxFunc.empty();
    }

    if (isa<Float16Type>(type))
      return f16Func;
    if (isa<Float32Type>(type)) {
      if (useApprox)
        return f32ApproxFunc;
      return f32Func;
    }
    if (isa<Float64Type>(type))
      return f64Func;

    if (type.isInteger(32))
      return i32Func;
    return "";
  }

  // Function names used per operand type; an empty name disables that type.
  const std::string f32Func;
  const std::string f64Func;
  const std::string f32ApproxFunc;
  const std::string f16Func;
  const std::string i32Func;
};

} // namespace mlir

#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_