aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp')
-rw-r--r--mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp68
1 files changed, 45 insertions, 23 deletions
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 03ed4d5..b609990 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -36,20 +36,23 @@ namespace {
/// attribute.
template <typename SourceOp, typename TargetOp, bool Constrained,
template <typename, typename> typename AttrConvert =
- AttrConvertPassThrough>
+ AttrConvertPassThrough,
+ bool FailOnUnsupportedFP = false>
struct ConstrainedVectorConvertToLLVMPattern
- : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
- using VectorConvertToLLVMPattern<SourceOp, TargetOp,
- AttrConvert>::VectorConvertToLLVMPattern;
+ : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert,
+ FailOnUnsupportedFP> {
+ using VectorConvertToLLVMPattern<
+ SourceOp, TargetOp, AttrConvert,
+ FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
return failure();
- return VectorConvertToLLVMPattern<SourceOp, TargetOp,
- AttrConvert>::matchAndRewrite(op, adaptor,
- rewriter);
+ return VectorConvertToLLVMPattern<
+ SourceOp, TargetOp, AttrConvert,
+ FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
}
};
@@ -78,7 +81,8 @@ struct IdentityBitcastLowering final
using AddFOpLowering =
VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using AddIOpLowering =
VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
arith::AttrConvertOverflowToLLVM>;
@@ -87,53 +91,67 @@ using BitcastOpLowering =
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
using DivFOpLowering =
VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using DivSIOpLowering =
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
using DivUIOpLowering =
VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
-using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
+using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using ExtSIOpLowering =
VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
using ExtUIOpLowering =
VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
using FPToSIOpLowering =
- VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
+ VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using FPToUIOpLowering =
- VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
+ VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using MaximumFOpLowering =
VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MaxNumFOpLowering =
VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MaxSIOpLowering =
VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
using MaxUIOpLowering =
VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
using MinimumFOpLowering =
VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MinNumFOpLowering =
VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MinSIOpLowering =
VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
using MinUIOpLowering =
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
using MulFOpLowering =
VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MulIOpLowering =
VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
arith::AttrConvertOverflowToLLVM>;
using NegFOpLowering =
VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
using RemFOpLowering =
VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using RemSIOpLowering =
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
using RemUIOpLowering =
@@ -151,21 +169,25 @@ using SIToFPOpLowering =
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
using SubFOpLowering =
VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using SubIOpLowering =
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
arith::AttrConvertOverflowToLLVM>;
using TruncFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
- false>;
+ false, AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
- arith::AttrConverterConstrainedFPToLLVM>;
+ arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using TruncIOpLowering =
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
arith::AttrConvertOverflowToLLVM>;
using UIToFPOpLowering =
- VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
+ VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
//===----------------------------------------------------------------------===//