aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
blob: cf60a048f782c6aa0caae292cc6b8229591f6098 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
//===- AttrToLLVMConverter.cpp - Arith attributes conversion to LLVM ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"

using namespace mlir;

/// Translates a set of arith fastmath flags into the equivalent set of LLVM
/// dialect fastmath flags. Each arith flag present in \p arithFMF sets its
/// one-to-one LLVM counterpart; an empty input yields an empty flag set.
LLVM::FastmathFlags
mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
  using FlagMapping = std::pair<arith::FastMathFlags, LLVM::FastmathFlags>;
  // One entry per arith flag, paired with the LLVM flag it maps onto.
  const FlagMapping flagTable[] = {
      {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan},
      {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf},
      {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz},
      {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp},
      {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
      {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
      {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};

  LLVM::FastmathFlags result{};
  for (const FlagMapping &mapping : flagTable)
    if (bitEnumContainsAny(arithFMF, mapping.first))
      result = result | mapping.second;
  return result;
}

/// Wraps the LLVM dialect equivalent of an arith fastmath attribute in an
/// LLVM::FastmathFlagsAttr created in the same MLIRContext as \p fmfAttr.
/// \p fmfAttr must be non-null.
LLVM::FastmathFlagsAttr
mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
  // Match convertArithRoundingModeAttrToLLVM: report a null attribute with a
  // clear message instead of dereferencing it in getValue()/getContext().
  assert(fmfAttr && "Expecting valid attribute");
  arith::FastMathFlags arithFMF = fmfAttr.getValue();
  return LLVM::FastmathFlagsAttr::get(
      fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
}

/// Translates arith integer overflow flags (nsw/nuw) into the equivalent
/// LLVM dialect overflow flags. Flags absent from \p arithFlags are left
/// unset in the result, so an empty input yields an empty flag set.
LLVM::IntegerOverflowFlags mlir::arith::convertArithOverflowFlagsToLLVM(
    arith::IntegerOverflowFlags arithFlags) {
  using OverflowMapping =
      std::pair<arith::IntegerOverflowFlags, LLVM::IntegerOverflowFlags>;
  const OverflowMapping overflowTable[] = {
      {arith::IntegerOverflowFlags::nsw, LLVM::IntegerOverflowFlags::nsw},
      {arith::IntegerOverflowFlags::nuw, LLVM::IntegerOverflowFlags::nuw}};

  LLVM::IntegerOverflowFlags converted{};
  for (const OverflowMapping &mapping : overflowTable)
    if (bitEnumContainsAny(arithFlags, mapping.first))
      converted = converted | mapping.second;
  return converted;
}

/// Maps an arith rounding mode onto the corresponding LLVM dialect rounding
/// mode. Every arith::RoundingMode enumerator is handled explicitly, so
/// control only reaches the end of the switch for an out-of-range value.
LLVM::RoundingMode
mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) {
  using ArithMode = arith::RoundingMode;
  using LLVMMode = LLVM::RoundingMode;
  // Kept as a switch (no default) so the compiler flags any newly added
  // arith rounding mode that is not mapped here.
  switch (roundingMode) {
  case ArithMode::to_nearest_even:
    return LLVMMode::NearestTiesToEven;
  case ArithMode::to_nearest_away:
    return LLVMMode::NearestTiesToAway;
  case ArithMode::toward_zero:
    return LLVMMode::TowardZero;
  case ArithMode::upward:
    return LLVMMode::TowardPositive;
  case ArithMode::downward:
    return LLVMMode::TowardNegative;
  }
  llvm_unreachable("Unhandled rounding mode");
}

/// Wraps the LLVM dialect equivalent of an arith rounding-mode attribute in
/// an LLVM::RoundingModeAttr created in the same MLIRContext.
/// \p roundingModeAttr must be non-null.
LLVM::RoundingModeAttr mlir::arith::convertArithRoundingModeAttrToLLVM(
    arith::RoundingModeAttr roundingModeAttr) {
  assert(roundingModeAttr && "Expecting valid attribute");
  MLIRContext *ctx = roundingModeAttr.getContext();
  LLVM::RoundingMode llvmMode =
      convertArithRoundingModeToLLVM(roundingModeAttr.getValue());
  return LLVM::RoundingModeAttr::get(ctx, llvmMode);
}

/// Returns the floating-point exception behavior attribute used by default
/// when lowering to LLVM: FPExceptionBehavior::Ignore, created in \p context.
LLVM::FPExceptionBehaviorAttr
mlir::arith::getLLVMDefaultFPExceptionBehavior(MLIRContext &context) {
  const LLVM::FPExceptionBehavior defaultBehavior =
      LLVM::FPExceptionBehavior::Ignore;
  return LLVM::FPExceptionBehaviorAttr::get(&context, defaultBehavior);
}