aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/ArithToLLVM
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/ArithToLLVM')
-rw-r--r--mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp77
1 files changed, 52 insertions, 25 deletions
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index ba57155..220826d 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
@@ -36,20 +37,23 @@ namespace {
/// attribute.
template <typename SourceOp, typename TargetOp, bool Constrained,
template <typename, typename> typename AttrConvert =
- AttrConvertPassThrough>
+ AttrConvertPassThrough,
+ bool FailOnUnsupportedFP = false>
struct ConstrainedVectorConvertToLLVMPattern
- : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
- using VectorConvertToLLVMPattern<SourceOp, TargetOp,
- AttrConvert>::VectorConvertToLLVMPattern;
+ : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert,
+ FailOnUnsupportedFP> {
+ using VectorConvertToLLVMPattern<
+ SourceOp, TargetOp, AttrConvert,
+ FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
return failure();
- return VectorConvertToLLVMPattern<SourceOp, TargetOp,
- AttrConvert>::matchAndRewrite(op, adaptor,
- rewriter);
+ return VectorConvertToLLVMPattern<
+ SourceOp, TargetOp, AttrConvert,
+ FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
}
};
@@ -78,7 +82,8 @@ struct IdentityBitcastLowering final
using AddFOpLowering =
VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using AddIOpLowering =
VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
arith::AttrConvertOverflowToLLVM>;
@@ -87,53 +92,67 @@ using BitcastOpLowering =
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
using DivFOpLowering =
VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using DivSIOpLowering =
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
using DivUIOpLowering =
VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
-using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
+using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using ExtSIOpLowering =
VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
using ExtUIOpLowering =
VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
using FPToSIOpLowering =
- VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
+ VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using FPToUIOpLowering =
- VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
+ VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using MaximumFOpLowering =
VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MaxNumFOpLowering =
VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MaxSIOpLowering =
VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
using MaxUIOpLowering =
VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
using MinimumFOpLowering =
VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MinNumFOpLowering =
VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MinSIOpLowering =
VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
using MinUIOpLowering =
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
using MulFOpLowering =
VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using MulIOpLowering =
VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
arith::AttrConvertOverflowToLLVM>;
using NegFOpLowering =
VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
using RemFOpLowering =
VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using RemSIOpLowering =
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
using RemUIOpLowering =
@@ -151,21 +170,25 @@ using SIToFPOpLowering =
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
using SubFOpLowering =
VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
- arith::AttrConvertFastMathToLLVM>;
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using SubIOpLowering =
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
arith::AttrConvertOverflowToLLVM>;
using TruncFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
- false>;
+ false, AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
- arith::AttrConverterConstrainedFPToLLVM>;
+ arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using TruncIOpLowering =
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
arith::AttrConvertOverflowToLLVM>;
using UIToFPOpLowering =
- VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
+ VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/true>;
using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
//===----------------------------------------------------------------------===//
@@ -240,8 +263,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
- using Adaptor =
- typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
+ using Adaptor = ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
LogicalResult
matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
@@ -259,6 +281,7 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
adaptor.getOperands(), op->getAttrs(),
+ /*propAttr=*/Attribute{},
*getTypeConverter(), rewriter);
}
@@ -460,6 +483,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
LogicalResult
CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
+ if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(),
+ op.getLhs().getType()))
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+
Type operandType = adaptor.getLhs().getType();
Type resultType = op.getResult().getType();
LLVM::FastmathFlags fmf =