diff options
Diffstat (limited to 'mlir/lib/Conversion/ArithToLLVM')
| -rw-r--r-- | mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 77 |
1 files changed, 52 insertions, 25 deletions
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index ba57155..220826d 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/TypeUtilities.h" @@ -36,20 +37,23 @@ namespace { /// attribute. template <typename SourceOp, typename TargetOp, bool Constrained, template <typename, typename> typename AttrConvert = - AttrConvertPassThrough> + AttrConvertPassThrough, + bool FailOnUnsupportedFP = false> struct ConstrainedVectorConvertToLLVMPattern - : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> { - using VectorConvertToLLVMPattern<SourceOp, TargetOp, - AttrConvert>::VectorConvertToLLVMPattern; + : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert, + FailOnUnsupportedFP> { + using VectorConvertToLLVMPattern< + SourceOp, TargetOp, AttrConvert, + FailOnUnsupportedFP>::VectorConvertToLLVMPattern; LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (Constrained != static_cast<bool>(op.getRoundingModeAttr())) return failure(); - return VectorConvertToLLVMPattern<SourceOp, TargetOp, - AttrConvert>::matchAndRewrite(op, adaptor, - rewriter); + return VectorConvertToLLVMPattern< + SourceOp, TargetOp, AttrConvert, + FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter); } }; @@ -78,7 +82,8 @@ struct IdentityBitcastLowering final using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp, arith::AttrConvertOverflowToLLVM>; @@ -87,53 +92,67 @@ using BitcastOpLowering = VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>; using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using DivSIOpLowering = VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>; using DivUIOpLowering = VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>; -using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>; +using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp, + AttrConvertPassThrough, + /*FailOnUnsupportedFP=*/true>; using ExtSIOpLowering = VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>; using ExtUIOpLowering = VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>; using FPToSIOpLowering = - VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>; + VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp, + AttrConvertPassThrough, + /*FailOnUnsupportedFP=*/true>; using FPToUIOpLowering = - VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>; + VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp, + AttrConvertPassThrough, + /*FailOnUnsupportedFP=*/true>; using MaximumFOpLowering = VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MaxNumFOpLowering = VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MaxSIOpLowering = VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>; using MaxUIOpLowering = VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>; using MinimumFOpLowering = VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MinNumFOpLowering = VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MinSIOpLowering = VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>; using MinUIOpLowering = VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>; using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp, arith::AttrConvertOverflowToLLVM>; using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>; using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using RemSIOpLowering = VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>; using RemUIOpLowering = @@ -151,21 +170,25 @@ using SIToFPOpLowering = VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>; using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp, - arith::AttrConvertFastMathToLLVM>; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp, arith::AttrConvertOverflowToLLVM>; using TruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp, - false>; + false, AttrConvertPassThrough, + /*FailOnUnsupportedFP=*/true>; using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern< arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true, - arith::AttrConverterConstrainedFPToLLVM>; + arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>; using TruncIOpLowering = VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp, arith::AttrConvertOverflowToLLVM>; using UIToFPOpLowering = - VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>; + VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp, + AttrConvertPassThrough, + /*FailOnUnsupportedFP=*/true>; using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>; //===----------------------------------------------------------------------===// @@ -240,8 +263,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> { struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - using Adaptor = - typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor; + using Adaptor = ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor; LogicalResult matchAndRewrite(arith::SelectOp op, Adaptor adaptor, @@ -259,6 +281,7 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), op->getAttrs(), + /*propAttr=*/Attribute{}, *getTypeConverter(), rewriter); } @@ -460,6 +483,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, LogicalResult CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(), + op.getLhs().getType())) + return rewriter.notifyMatchFailure(op, "unsupported floating point type"); + Type operandType = adaptor.getLhs().getType(); Type resultType = op.getResult().getType(); LLVM::FastmathFlags fmf = |
