diff options
Diffstat (limited to 'mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp')
-rw-r--r-- | mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index 853f454..229e40e 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -121,6 +121,38 @@ using CountTrailingZerosOpLowering = LLVM::CountTrailingZerosOp>; using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>; +// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops. +struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> { + using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const LLVMTypeConverter &typeConverter = *this->getTypeConverter(); + mlir::Location loc = op.getLoc(); + mlir::Type operandType = adaptor.getOperand().getType(); + mlir::Type llvmOperandType = typeConverter.convertType(operandType); + mlir::Type sinType = typeConverter.convertType(op.getSin().getType()); + mlir::Type cosType = typeConverter.convertType(op.getCos().getType()); + if (!llvmOperandType || !sinType || !cosType) + return failure(); + + ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op); + + auto structType = LLVM::LLVMStructType::getLiteral( + rewriter.getContext(), {llvmOperandType, llvmOperandType}); + + auto sincosOp = rewriter.create<LLVM::SincosOp>( + loc, structType, adaptor.getOperand(), attrs.getAttrs()); + + auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0); + auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1); + + rewriter.replaceOp(op, {sinValue, cosValue}); + return success(); + } +}; + // A `expm1` is converted into `exp - 1`. struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> { using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern; @@ -393,6 +425,7 @@ void mlir::populateMathToLLVMConversionPatterns( RoundEvenOpLowering, RoundOpLowering, RsqrtOpLowering, + SincosOpLowering, SinOpLowering, SinhOpLowering, ASinOpLowering, |