diff options
Diffstat (limited to 'mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp')
| -rw-r--r-- | mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 68 |
1 files changed, 45 insertions, 23 deletions
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 03ed4d5..b609990 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -36,20 +36,23 @@ namespace { /// attribute. template <typename SourceOp, typename TargetOp, bool Constrained, template <typename, typename> typename AttrConvert = - AttrConvertPassThrough> + AttrConvertPassThrough, + bool FailOnUnsupportedFP = false> struct ConstrainedVectorConvertToLLVMPattern - : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> { - using VectorConvertToLLVMPattern<SourceOp, TargetOp, - AttrConvert>::VectorConvertToLLVMPattern; + : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert, + FailOnUnsupportedFP> { + using VectorConvertToLLVMPattern< + SourceOp, TargetOp, AttrConvert, + FailOnUnsupportedFP>::VectorConvertToLLVMPattern; LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (Constrained != static_cast<bool>(op.getRoundingModeAttr())) return failure(); - return VectorConvertToLLVMPattern<SourceOp, TargetOp, - AttrConvert>::matchAndRewrite(op, adaptor, - rewriter); + return VectorConvertToLLVMPattern< + SourceOp, TargetOp, AttrConvert, + FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter); } }; @@ -78,7 +81,8 @@ struct IdentityBitcastLowering final using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp, arith::AttrConvertOverflowToLLVM>; @@ -87,53 +91,67 @@ using BitcastOpLowering = VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>; using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using DivSIOpLowering = VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>; using DivUIOpLowering = VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>; -using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>; +using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp, + AttrConvertPassThrough, + /*FailOnUnsupportedFP=*/true>; using ExtSIOpLowering = VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>; using ExtUIOpLowering = VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>; using FPToSIOpLowering = - VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>; + VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp, + AttrConvertPassThrough, + /*FailOnUnsupportedFP=*/true>; using FPToUIOpLowering = - VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>; + VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp, + AttrConvertPassThrough, + /*FailOnUnsupportedFP=*/true>; using MaximumFOpLowering = VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MaxNumFOpLowering = VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MaxSIOpLowering = VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>; using MaxUIOpLowering = VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>; using MinimumFOpLowering = VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MinNumFOpLowering = VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MinSIOpLowering = VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>; using MinUIOpLowering = VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>; using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp, arith::AttrConvertOverflowToLLVM>; using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>; using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using RemSIOpLowering = VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>; using RemUIOpLowering = @@ -151,21 +169,25 @@ using SIToFPOpLowering = VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>; using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp, arith::AttrConvertOverflowToLLVM>; using TruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp, - false>; + false, AttrConvertPassThrough, + /*FailOnUnsupportedFP=*/true>; using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern< arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true, - arith::AttrConverterConstrainedFPToLLVM>; + arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>; using TruncIOpLowering = VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp, arith::AttrConvertOverflowToLLVM>; using UIToFPOpLowering = - VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>; + VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp, + AttrConvertPassThrough, + /*FailOnUnsupportedFP=*/true>; using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>; //===----------------------------------------------------------------------===// |
