aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp')
-rw-r--r--mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp33
1 files changed, 33 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 853f454..229e40e 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -121,6 +121,38 @@ using CountTrailingZerosOpLowering =
LLVM::CountTrailingZerosOp>;
using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
+// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops.
+struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
+ using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
+ mlir::Location loc = op.getLoc();
+ mlir::Type operandType = adaptor.getOperand().getType();
+ mlir::Type llvmOperandType = typeConverter.convertType(operandType);
+ mlir::Type sinType = typeConverter.convertType(op.getSin().getType());
+ mlir::Type cosType = typeConverter.convertType(op.getCos().getType());
+ if (!llvmOperandType || !sinType || !cosType)
+ return failure();
+
+ ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);
+
+ auto structType = LLVM::LLVMStructType::getLiteral(
+ rewriter.getContext(), {llvmOperandType, llvmOperandType});
+
+ auto sincosOp = rewriter.create<LLVM::SincosOp>(
+ loc, structType, adaptor.getOperand(), attrs.getAttrs());
+
+ auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
+ auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
+
+ rewriter.replaceOp(op, {sinValue, cosValue});
+ return success();
+ }
+};
+
// A `expm1` is converted into `exp - 1`.
struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
@@ -393,6 +425,7 @@ void mlir::populateMathToLLVMConversionPatterns(
RoundEvenOpLowering,
RoundOpLowering,
RsqrtOpLowering,
+ SincosOpLowering,
SinOpLowering,
SinhOpLowering,
ASinOpLowering,