//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// NOTE(review): in this copy of the file every angle-bracketed argument list
// appears to be missing (the bare `#include` below, `template` headers with no
// parameter list, `OpConversionPattern`, `cast(...)`,
// `replaceOpWithNewOp(...)`, `std::pair`, `ArrayRef`, `SmallVector`). The code
// tokens are reproduced as found; restore the missing arguments from the
// canonical source before building.

#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"

#include "mlir/Conversion/ComplexCommon/DivisionConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include

namespace mlir {
#define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARDPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// Selects which quantity computeAbs produces: |z|, sqrt(|z|) or 1/sqrt(|z|).
enum class AbsFn { abs, sqrt, rsqrt };

// Returns the absolute value, its square root or its reciprocal square root.
//
// With max = max(|real|, |imag|), min = min(|real|, |imag|) and
// ratio = min / max, the emitted ops compute
//   abs:   max * sqrt(ratio^2 + 1)
//   sqrt:  sqrt(max) * (ratio^2 + 1)^0.25
//   rsqrt: rsqrt(max) * sqrt(rsqrt(ratio^2 + 1))
// If that value is NaN, `min` is returned instead (for rsqrt, the already
// rsqrt'ed `min`).
//
// `real`/`imag` are the two components, `fmf` the fast-math flags of the op
// being lowered, `b` the builder used to emit ops and `fn` the variant.
Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
                 ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
  Value one = arith::ConstantOp::create(b, real.getType(),
                                        b.getFloatAttr(real.getType(), 1.0));

  Value absReal = math::AbsFOp::create(b, real, fmf);
  Value absImag = math::AbsFOp::create(b, imag, fmf);

  Value max = arith::MaximumFOp::create(b, absReal, absImag, fmf);
  Value min = arith::MinimumFOp::create(b, absReal, absImag, fmf);

  // The lowering below requires NaNs and infinities to work correctly.
  // Hence the nnan/ninf flags are cleared for every op emitted from here on.
  arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
      fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
  Value ratio = arith::DivFOp::create(b, min, max, fmfWithNaNInf);
  Value ratioSq = arith::MulFOp::create(b, ratio, ratio, fmfWithNaNInf);
  Value ratioSqPlusOne = arith::AddFOp::create(b, ratioSq, one, fmfWithNaNInf);
  Value result;

  // For rsqrt, replace each factor by its reciprocal square root and then
  // fall through to the generic `else` branch below.
  if (fn == AbsFn::rsqrt) {
    ratioSqPlusOne = math::RsqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
    min = math::RsqrtOp::create(b, min, fmfWithNaNInf);
    max = math::RsqrtOp::create(b, max, fmfWithNaNInf);
  }

  if (fn == AbsFn::sqrt) {
    Value quarter = arith::ConstantOp::create(
        b, real.getType(), b.getFloatAttr(real.getType(), 0.25));
    // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
    Value sqrt = math::SqrtOp::create(b, max, fmfWithNaNInf);
    Value p025 =
        math::PowFOp::create(b, ratioSqPlusOne, quarter, fmfWithNaNInf);
    result = arith::MulFOp::create(b, sqrt, p025, fmfWithNaNInf);
  } else {
    Value sqrt = math::SqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
    result = arith::MulFOp::create(b, max, sqrt, fmfWithNaNInf);
  }

  // An unordered self-comparison is true only for NaN; in that case fall back
  // to `min`.
  Value isNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, result,
                                      result, fmfWithNaNInf);
  return arith::SelectOp::create(b, isNaN, min, result);
}

// Lowers complex.abs by splitting the operand into its real and imaginary
// parts and replacing the op with the value built by computeAbs.
struct AbsOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();

    Value real = complex::ReOp::create(b, adaptor.getComplex());
    Value imag = complex::ImOp::create(b, adaptor.getComplex());
    rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));

    return success();
  }
};

// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
//
// Here y is `lhs` and x is `rhs`. The expansion is built from other complex
// ops (mul, add, sqrt, div, log), which are left for their own patterns.
struct Atan2OpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    auto type = cast(op.getType());
    Type elementType = type.getElementType();
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();

    // sqrt(x**2 + y**2)
    Value rhsSquared = complex::MulOp::create(b, type, rhs, rhs, fmf);
    Value lhsSquared = complex::MulOp::create(b, type, lhs, lhs, fmf);
    Value rhsSquaredPlusLhsSquared =
        complex::AddOp::create(b, type, rhsSquared, lhsSquared, fmf);
    Value sqrtOfRhsSquaredPlusLhsSquared =
        complex::SqrtOp::create(b, type, rhsSquaredPlusLhsSquared, fmf);

    // x + i * y, with i built as the complex constant (0, 1).
    Value zero =
        arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
    Value one = arith::ConstantOp::create(b, elementType,
                                          b.getFloatAttr(elementType, 1));
    Value i = complex::CreateOp::create(b, type, zero, one);
    Value iTimesLhs = complex::MulOp::create(b, i, lhs, fmf);
    Value rhsPlusILhs = complex::AddOp::create(b, rhs, iTimesLhs, fmf);

    Value divResult = complex::DivOp::create(
        b, rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
    Value logResult = complex::LogOp::create(b, divResult, fmf);

    // -i is built as the complex constant (0, -1).
    Value negativeOne = arith::ConstantOp::create(
        b, elementType, b.getFloatAttr(elementType, -1));
    Value negativeI = complex::CreateOp::create(b, type, zero, negativeOne);

    // NOTE(review): the op type created here was lost with the template
    // argument; per the formula above it should multiply -i by the log.
    rewriter.replaceOpWithNewOp(op, negativeI, logResult, fmf);
    return success();
  }
};

// Lowers a complex comparison by comparing the real parts and the imaginary
// parts separately with predicate `p` and combining the two results.
// NOTE(review): `ResultCombiner` picks between arith::AndIOp and arith::OrIOp,
// but its selecting condition and the op used in `replaceOpWithNewOp` were
// lost with the template arguments; confirm against the canonical source.
template
struct ComparisonOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;
  using ResultCombiner =
      std::conditional_t::value,
                         arith::AndIOp, arith::OrIOp>;

  LogicalResult
  matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    // `type` is the element type of the complex operands.
    auto type = cast(adaptor.getLhs().getType()).getElementType();

    Value realLhs =
        complex::ReOp::create(rewriter, loc, type, adaptor.getLhs());
    Value imagLhs =
        complex::ImOp::create(rewriter, loc, type, adaptor.getLhs());
    Value realRhs =
        complex::ReOp::create(rewriter, loc, type, adaptor.getRhs());
    Value imagRhs =
        complex::ImOp::create(rewriter, loc, type, adaptor.getRhs());
    Value realComparison =
        arith::CmpFOp::create(rewriter, loc, p, realLhs, realRhs);
    Value imagComparison =
        arith::CmpFOp::create(rewriter, loc, p, imagLhs, imagRhs);

    rewriter.replaceOpWithNewOp(op, realComparison, imagComparison);
    return success();
  }
};

// Default conversion which applies the BinaryStandardOp separately on the real
// and imaginary parts. Can for example be used for complex::AddOp and
// complex::SubOp.
// The fast-math flags of the complex op are forwarded to both scalar ops.
template
struct BinaryComplexOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = cast(adaptor.getLhs().getType());
    auto elementType = cast(type.getElementType());
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

    Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
    Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
    Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
                                                realRhs, fmf.getValue());
    Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
    Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
    Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
                                                imagRhs, fmf.getValue());
    rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag);
    return success();
  }
};

// Shared base for the complex sin/cos lowerings. matchAndRewrite emits the
// common scalar pieces for an operand x + iy, namely
//   scaledExp = 0.5 * exp(y), reciprocalExp = 0.5 / exp(y), sin(x), cos(x),
// and delegates to the pure-virtual `combine` to turn them into the
// (real, imaginary) pair of the result.
template
struct TrigonometricOpConversion : public OpConversionPattern {
  using OpAdaptor = typename OpConversionPattern::OpAdaptor;

  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

    Value real =
        complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
    Value imag =
        complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());

    // Trigonometric ops use a set of common building blocks to convert to real
    // ops. Here we create these building blocks and call into an op-specific
    // implementation in the subclass to combine them.
    Value half = arith::ConstantOp::create(
        rewriter, loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
    Value exp = math::ExpOp::create(rewriter, loc, imag, fmf);
    Value scaledExp = arith::MulFOp::create(rewriter, loc, half, exp, fmf);
    Value reciprocalExp = arith::DivFOp::create(rewriter, loc, half, exp, fmf);
    Value sin = math::SinOp::create(rewriter, loc, real, fmf);
    Value cos = math::CosOp::create(rewriter, loc, real, fmf);

    auto resultPair =
        combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);

    rewriter.replaceOpWithNewOp(op, type, resultPair.first, resultPair.second);
    return success();
  }

  // Returns {real part, imaginary part} of the result, built from the shared
  // pieces described on the struct.
  virtual std::pair
  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
          Value cos, ConversionPatternRewriter &rewriter,
          arith::FastMathFlagsAttr fmf) const = 0;
};

// Complex cosine in terms of the TrigonometricOpConversion building blocks.
struct CosOpConversion : public TrigonometricOpConversion {
  using TrigonometricOpConversion::TrigonometricOpConversion;

  std::pair combine(Location loc, Value scaledExp, Value reciprocalExp,
                    Value sin, Value cos, ConversionPatternRewriter &rewriter,
                    arith::FastMathFlagsAttr fmf) const override {
    // Complex cosine is defined as:
    //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
    // Plugging in:
    //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
    // and defining t := exp(y)
    // We get:
    //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
    //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
    Value sum =
        arith::AddFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
    Value resultReal = arith::MulFOp::create(rewriter, loc, sum, cos, fmf);
    Value diff =
        arith::SubFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
    Value resultImag = arith::MulFOp::create(rewriter, loc, diff, sin, fmf);
    return {resultReal, resultImag};
  }
};

// Lowers complex.div. The scalar expansion is delegated to the helpers from
// DivisionConverter.h, chosen by the `complexRange` given at construction:
// `basic`/`none` use the algebraic form, `improved` uses range reduction.
struct DivOpConversion : public OpConversionPattern {
  DivOpConversion(MLIRContext *context, complex::ComplexRangeFlags target)
      : OpConversionPattern(context), complexRange(target) {}

  // NOTE(review): the inherited constructors do not initialise
  // `complexRange`; constructing the pattern through them would leave it
  // indeterminate.
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = cast(adaptor.getLhs().getType());
    auto elementType = cast(type.getElementType());
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

    Value lhsReal =
        complex::ReOp::create(rewriter, loc, elementType, adaptor.getLhs());
    Value lhsImag =
        complex::ImOp::create(rewriter, loc, elementType, adaptor.getLhs());
    Value rhsReal =
        complex::ReOp::create(rewriter, loc, elementType, adaptor.getRhs());
    Value rhsImag =
        complex::ImOp::create(rewriter, loc, elementType, adaptor.getRhs());

    // Filled in through the out-pointers by whichever helper runs.
    // NOTE(review): there is no final `else`; if `complexRange` held any
    // other value both would stay default-constructed.
    Value resultReal, resultImag;

    if (complexRange == complex::ComplexRangeFlags::basic ||
        complexRange == complex::ComplexRangeFlags::none) {
      mlir::complex::convertDivToStandardUsingAlgebraic(
          rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
          &resultImag);
    } else if (complexRange == complex::ComplexRangeFlags::improved) {
      mlir::complex::convertDivToStandardUsingRangeReduction(
          rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
          &resultImag);
    }

    rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag);

    return success();
  }

private:
  // Selects the division algorithm; see matchAndRewrite.
  complex::ComplexRangeFlags complexRange;
};

// Lowers complex.exp using
//   exp(a + bi) = exp(a) * cos(b) + i * exp(a) * sin(b).
struct ExpOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

    Value real =
        complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
    Value imag =
        complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
    Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue());
    Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue());
    Value resultReal =
        arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue());
    Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue());
    Value resultImag =
        arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue());

    rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag);
    return success();
  }
};

// Emits a Horner-scheme evaluation of a polynomial in `arg` as a chain of
// math.fma ops: poly = poly * arg + coefficients[i]. `coefficients` is
// ordered from the highest-degree term down to the constant term.
// `coefficients[0]` is read unconditionally, so the array must be non-empty.
Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
                         ArrayRef coefficients,
                         arith::FastMathFlagsAttr fmf) {
  auto argType = mlir::cast(arg.getType());
  Value poly =
      arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[0]));
  for (unsigned i = 1; i < coefficients.size(); ++i) {
    poly = math::FmaOp::create(
        b, poly, arg,
        arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[i])),
        fmf);
  }
  return poly;
}

// Lowers complex.expm1; see the derivation above matchAndRewrite.
struct Expm1OpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
  //            [handle inaccuracies when a and/or b are small]
  //            = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
  //            = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
  LogicalResult
  matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getType();
    auto elemType = mlir::cast(type.getElementType());

    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value real = complex::ReOp::create(b, adaptor.getComplex());
    Value imag = complex::ImOp::create(b, adaptor.getComplex());

    Value zero = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 0.0));
    Value one = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 1.0));

    // e^a and cos(b) are recovered from expm1(a) and cosm1(b) by adding one.
    Value expm1Real = math::ExpM1Op::create(b, real, fmf);
    Value expReal = arith::AddFOp::create(b, expm1Real, one, fmf);

    Value sinImag = math::SinOp::create(b, imag, fmf);
    Value cosm1Imag = emitCosm1(imag, fmf, b);
    Value cosImag = arith::AddFOp::create(b, cosm1Imag, one, fmf);

    Value realResult = arith::AddFOp::create(
        b, arith::MulFOp::create(b, expm1Real, cosImag, fmf), cosm1Imag, fmf);

    // The imaginary result is forced to zero when the imaginary input
    // compares equal to zero, instead of using e^a * sin(b).
    Value imagIsZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag,
                                             zero, fmf.getValue());
    Value imagResult = arith::SelectOp::create(
        b, imagIsZero, zero, arith::MulFOp::create(b, expReal, sinImag, fmf));

    rewriter.replaceOpWithNewOp(op, type, realResult, imagResult);
    return success();
  }

private:
  // Emits cos(arg) - 1. For arg^2 >= 0.61685 the direct form cos(arg) - 1 is
  // selected; otherwise the polynomial form
  //   arg^4 * P(arg^2) - 0.5 * arg^2
  // is selected. Both forms are always emitted; an arith.select picks one.
  Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
                  ImplicitLocOpBuilder &b) const {
    auto argType = mlir::cast(arg.getType());
    auto negHalf = arith::ConstantOp::create(b, b.getFloatAttr(argType, -0.5));
    auto negOne = arith::ConstantOp::create(b, b.getFloatAttr(argType, -1.0));

    // Algorithm copied from cephes cosm1.
    SmallVector kCoeffs{
        4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
        2.0876754287081521758361E-9,  -2.7557319214999787979814E-7,
        2.4801587301570552304991E-5,  -1.3888888888888872993737E-3,
        4.1666666666666666609054E-2,
    };
    Value cos = math::CosOp::create(b, arg, fmf);
    Value forLargeArg = arith::AddFOp::create(b, cos, negOne, fmf);

    Value argPow2 = arith::MulFOp::create(b, arg, arg, fmf);
    Value argPow4 = arith::MulFOp::create(b, argPow2, argPow2, fmf);
    Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);

    auto forSmallArg =
        arith::AddFOp::create(b, arith::MulFOp::create(b, argPow4, poly, fmf),
                              arith::MulFOp::create(b, negHalf, argPow2, fmf));

    // (pi/4)^2 is approximately 0.61685
    Value piOver4Pow2 =
        arith::ConstantOp::create(b, b.getFloatAttr(argType, 0.61685));
    Value cond = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, argPow2,
                                       piOver4Pow2, fmf.getValue());
    return arith::SelectOp::create(b, cond, forLargeArg, forSmallArg);
  }
};

// Lowers complex.log using
//   log(z) = log(|z|) + i * atan2(imag, real).
// |z| is emitted as a complex.abs op and left for its own pattern.
struct LogOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value abs = complex::AbsOp::create(b, elementType, adaptor.getComplex(),
                                       fmf.getValue());
    Value resultReal = math::LogOp::create(b, elementType, abs, fmf.getValue());
    Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
    Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
    Value resultImag =
        math::Atan2Op::create(b, elementType, imag, real, fmf.getValue());
    rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag);
    return success();
  }
};

// Lowers complex.log1p, i.e. log(1 + z) for z = real + i * imag.
//
// With maxAbs/minAbs the larger/smaller of |real + 1| and |imag|, the real
// part is emitted as
//   0.5 * log1p((minAbs / maxAbs)^2) + log1p(m)
// where m is `real` when real + 1 > |imag| and maxAbs - 1 otherwise. If that
// sum is NaN, minAbs is used instead. The imaginary part is
// atan2(imag, real + 1).
struct Log1pOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());
    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value real = complex::ReOp::create(b, adaptor.getComplex());
    Value imag = complex::ImOp::create(b, adaptor.getComplex());

    Value half = arith::ConstantOp::create(b, elementType,
                                           b.getFloatAttr(elementType, 0.5));
    Value one = arith::ConstantOp::create(b, elementType,
                                          b.getFloatAttr(elementType, 1));
    Value realPlusOne = arith::AddFOp::create(b, real, one, fmf);
    Value absRealPlusOne = math::AbsFOp::create(b, realPlusOne, fmf);
    Value absImag = math::AbsFOp::create(b, imag, fmf);

    Value maxAbs = arith::MaximumFOp::create(b, absRealPlusOne, absImag, fmf);
    Value minAbs = arith::MinimumFOp::create(b, absRealPlusOne, absImag, fmf);

    // When real + 1 is the larger (positive) component, `real` itself is
    // passed to log1p rather than (real + 1) - 1.
    Value useReal = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT,
                                          realPlusOne, absImag, fmf);
    Value maxMinusOne = arith::SubFOp::create(b, maxAbs, one, fmf);
    Value maxAbsOfRealPlusOneAndImagMinusOne =
        arith::SelectOp::create(b, useReal, real, maxMinusOne);
    // nnan/ninf are cleared for the ops below, which end in a NaN check.
    arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
        fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
    Value minMaxRatio = arith::DivFOp::create(b, minAbs, maxAbs, fmfWithNaNInf);
    Value logOfMaxAbsOfRealPlusOneAndImag =
        math::Log1pOp::create(b, maxAbsOfRealPlusOneAndImagMinusOne, fmf);
    Value logOfSqrtPart = math::Log1pOp::create(
        b, arith::MulFOp::create(b, minMaxRatio, minMaxRatio, fmfWithNaNInf),
        fmfWithNaNInf);
    Value r = arith::AddFOp::create(
        b, arith::MulFOp::create(b, half, logOfSqrtPart, fmfWithNaNInf),
        logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
    // Unordered self-comparison: true only when r is NaN.
    Value resultReal = arith::SelectOp::create(
        b,
        arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, r, r,
                              fmfWithNaNInf),
        minAbs, r);
    Value resultImag = math::Atan2Op::create(b, imag, realPlusOne, fmf);
    rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag);
    return success();
  }
};

// Lowers complex.mul using the algebraic form
//   (a + bi)(c + di) = (ac - bd) + (bc + ad)i.
struct MulOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = cast(adaptor.getLhs().getType());
    auto elementType = cast(type.getElementType());
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
    auto fmfValue = fmf.getValue();

    Value lhsReal = complex::ReOp::create(b, elementType, adaptor.getLhs());
    Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs());
    Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs());
    Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs());

    // Real part: ac - bd.
    Value lhsRealTimesRhsReal =
        arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
    Value lhsImagTimesRhsImag =
        arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
    Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
                                       lhsImagTimesRhsImag, fmfValue);

    // Imaginary part: bc + ad.
    Value lhsImagTimesRhsReal =
        arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
    Value lhsRealTimesRhsImag =
        arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
    Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
                                       lhsRealTimesRhsImag, fmfValue);

    rewriter.replaceOpWithNewOp(op, type, real, imag);
    return success();
  }
};

// Lowers complex.neg by negating the real and imaginary parts separately.
struct NegOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());

    Value real =
        complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
    Value imag =
        complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
    Value negReal = arith::NegFOp::create(rewriter, loc, real);
    Value negImag = arith::NegFOp::create(rewriter, loc, imag);
    rewriter.replaceOpWithNewOp(op, type, negReal, negImag);
    return success();
  }
};

// Complex sine in terms of the TrigonometricOpConversion building blocks.
struct SinOpConversion : public TrigonometricOpConversion {
  using TrigonometricOpConversion::TrigonometricOpConversion;

  std::pair combine(Location loc, Value scaledExp, Value reciprocalExp,
                    Value sin, Value cos, ConversionPatternRewriter &rewriter,
                    arith::FastMathFlagsAttr fmf) const override {
    // Complex sine is defined as:
    //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
    // Plugging in:
    //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
    // and defining t := exp(y)
    // We get:
    //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
    //   Im(sin(x + iy)) = (0.5*t - 0.5/t) * cos x
    Value sum =
        arith::AddFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
    Value resultReal = arith::MulFOp::create(rewriter, loc, sum, sin, fmf);
    Value diff =
        arith::SubFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
    Value resultImag = arith::MulFOp::create(rewriter, loc, diff, cos, fmf);
    return {resultReal, resultImag};
  }
};

// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
// NOTE(review): in this copy of the file every angle-bracketed argument list
// appears to be missing (`template` headers with no parameter list,
// `OpConversionPattern`, `cast(...)`, `replaceOpWithNewOp(...)`,
// `std::is_same_v`). The code tokens below are reproduced as found; restore
// the missing arguments from the canonical source before building.

// Lowers complex.sqrt in polar form:
//   sqrt(z) = sqrt(|z|) * (cos(arg(z) / 2) + i * sin(arg(z) / 2)),
// with sqrt(|z|) taken from computeAbs(..., AbsFn::sqrt) and arg(z) from
// atan2(imag, real). Unless both nnan and ninf are set, explicit selects
// override the result for infinite and NaN inputs. Finally both parts are
// forced to zero when sqrt(|z|) compares equal to zero.
struct SqrtOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    auto type = cast(op.getType());
    auto elementType = cast(type.getElementType());
    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();

    // Builds a constant of the element type from an APFloat.
    auto cst = [&](APFloat v) {
      return arith::ConstantOp::create(b, elementType,
                                       b.getFloatAttr(elementType, v));
    };
    const auto &floatSemantics = elementType.getFloatSemantics();
    Value zero = cst(APFloat::getZero(floatSemantics));
    Value half = arith::ConstantOp::create(b, elementType,
                                           b.getFloatAttr(elementType, 0.5));

    Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
    Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
    Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
    Value argArg = math::Atan2Op::create(b, imag, real, fmf);
    Value sqrtArg = arith::MulFOp::create(b, argArg, half, fmf);
    Value cos = math::CosOp::create(b, sqrtArg, fmf);
    Value sin = math::SinOp::create(b, sqrtArg, fmf);
    // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
    // 0 * inf.
    Value sinIsZero =
        arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, sin, zero, fmf);

    Value resultReal = arith::MulFOp::create(b, absSqrt, cos, fmf);
    Value resultImag = arith::SelectOp::create(
        b, sinIsZero, zero, arith::MulFOp::create(b, absSqrt, sin, fmf));
    // Special-case handling, skipped only when the flags promise that neither
    // NaNs nor infinities occur.
    if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
                                            arith::FastMathFlags::ninf)) {
      Value inf = cst(APFloat::getInf(floatSemantics));
      Value negInf = cst(APFloat::getInf(floatSemantics, true));
      Value nan = cst(APFloat::getNaN(floatSemantics));

      Value absImag = math::AbsFOp::create(b, elementType, imag, fmf);

      Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
                                                 absImag, inf, fmf);
      Value absImagIsNotInf = arith::CmpFOp::create(
          b, arith::CmpFPredicate::ONE, absImag, inf, fmf);
      Value realIsInf =
          arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, inf, fmf);
      Value realIsNegInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
                                                 real, negInf, fmf);

      // Real part: 0 for real == -inf with |imag| ordered and not inf; then
      // +inf when |imag| == inf or real == +inf (this select is applied last,
      // so it wins).
      resultReal = arith::SelectOp::create(
          b, arith::AndIOp::create(b, realIsNegInf, absImagIsNotInf), zero,
          resultReal);
      resultReal = arith::SelectOp::create(
          b, arith::OrIOp::create(b, absImagIsInf, realIsInf), inf, resultReal);

      // Imaginary part: NaN when sqrt(|z|) is NaN; then inf carrying the sign
      // of imag when |imag| == inf or real == -inf (applied last, so it wins).
      Value imagSignInf = math::CopySignOp::create(b, inf, imag, fmf);
      resultImag = arith::SelectOp::create(
          b,
          arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, absSqrt, absSqrt),
          nan, resultImag);
      resultImag = arith::SelectOp::create(
          b, arith::OrIOp::create(b, absImagIsInf, realIsNegInf), imagSignInf,
          resultImag);
    }

    Value resultIsZero =
        arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
    resultReal = arith::SelectOp::create(b, resultIsZero, zero, resultReal);
    resultImag = arith::SelectOp::create(b, resultIsZero, zero, resultImag);

    rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag);
    return success();
  }
};

// Lowers complex.sign: the result is the input itself when both its real and
// imaginary parts compare equal to zero, and (real / |z|, imag / |z|)
// otherwise. |z| is emitted as a complex.abs op and left for its own pattern.
struct SignOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

    Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
    Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
    Value zero =
        arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
    // These two comparisons are created without the op's fast-math flags.
    Value realIsZero =
        arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero);
    Value imagIsZero =
        arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero);
    Value isZero = arith::AndIOp::create(b, realIsZero, imagIsZero);
    auto abs =
        complex::AbsOp::create(b, elementType, adaptor.getComplex(), fmf);
    Value realSign = arith::DivFOp::create(b, real, abs, fmf);
    Value imagSign = arith::DivFOp::create(b, imag, abs, fmf);
    Value sign = complex::CreateOp::create(b, type, realSign, imagSign);
    rewriter.replaceOpWithNewOp(op, isZero, adaptor.getComplex(), sign);
    return success();
  }
};

// Shared lowering for complex tan and tanh. The tanh form is computed for
// x + yi as
//   real = (expm1(2x) - expm1(-2x)) / denom
//   imag = 4 * cos(y) * sin(y) / denom
//   denom = expm1(2x) + expm1(-2x) + 4 * cos(y)^2
// and tan is mapped onto it through tan(x+yi) = -i*tanh(-y + xi) by swapping
// and negating components before and after.
// NOTE(review): the conditions of the two `if constexpr` checks lost their
// template arguments; presumably they test whether `Op` is the tan op.
template
struct TanTanhOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto loc = op.getLoc();
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());
    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
    const auto &floatSemantics = elementType.getFloatSemantics();

    Value real =
        complex::ReOp::create(b, loc, elementType, adaptor.getComplex());
    Value imag =
        complex::ImOp::create(b, loc, elementType, adaptor.getComplex());
    Value negOne = arith::ConstantOp::create(b, elementType,
                                             b.getFloatAttr(elementType, -1.0));

    if constexpr (std::is_same_v) {
      // tan(x+yi) = -i*tanh(-y + xi)
      std::swap(real, imag);
      real = arith::MulFOp::create(b, real, negOne, fmf);
    }

    // Builds a constant of the element type from an APFloat.
    auto cst = [&](APFloat v) {
      return arith::ConstantOp::create(b, elementType,
                                       b.getFloatAttr(elementType, v));
    };
    Value inf = cst(APFloat::getInf(floatSemantics));
    Value four = arith::ConstantOp::create(b, elementType,
                                           b.getFloatAttr(elementType, 4.0));
    Value twoReal = arith::AddFOp::create(b, real, real, fmf);
    Value negTwoReal = arith::MulFOp::create(b, negOne, twoReal, fmf);

    Value expTwoRealMinusOne = math::ExpM1Op::create(b, twoReal, fmf);
    Value expNegTwoRealMinusOne = math::ExpM1Op::create(b, negTwoReal, fmf);
    Value realNum = arith::SubFOp::create(b, expTwoRealMinusOne,
                                          expNegTwoRealMinusOne, fmf);

    Value cosImag = math::CosOp::create(b, imag, fmf);
    Value cosImagSq = arith::MulFOp::create(b, cosImag, cosImag, fmf);
    // Holds 4 * cos(imag)^2, which equals 2 * cos(2 * imag) + 2.
    Value twoCosTwoImagPlusOne = arith::MulFOp::create(b, cosImagSq, four, fmf);
    Value sinImag = math::SinOp::create(b, imag, fmf);

    Value imagNum = arith::MulFOp::create(
        b, four, arith::MulFOp::create(b, cosImag, sinImag, fmf), fmf);

    Value expSumMinusTwo = arith::AddFOp::create(b, expTwoRealMinusOne,
                                                 expNegTwoRealMinusOne, fmf);
    Value denom =
        arith::AddFOp::create(b, expSumMinusTwo, twoCosTwoImagPlusOne, fmf);

    // When the exponential sum overflows to +inf, the real part is replaced
    // by a value with the magnitude of `negOne` and the sign of `real`.
    Value isInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
                                        expSumMinusTwo, inf, fmf);
    Value realLimit = math::CopySignOp::create(b, negOne, real, fmf);

    Value resultReal = arith::SelectOp::create(
        b, isInf, realLimit, arith::DivFOp::create(b, realNum, denom, fmf));
    Value resultImag = arith::DivFOp::create(b, imagNum, denom, fmf);

    // Special-case handling, skipped only when the flags promise that neither
    // NaNs nor infinities occur.
    if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
                                            arith::FastMathFlags::ninf)) {
      Value absReal = math::AbsFOp::create(b, real, fmf);
      Value zero = arith::ConstantOp::create(b, elementType,
                                             b.getFloatAttr(elementType, 0.0));
      Value nan = cst(APFloat::getNaN(floatSemantics));

      Value absRealIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
                                                 absReal, inf, fmf);
      Value imagIsZero =
          arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
      // Logical NOT of absRealIsInf, expressed as XOR with the i1 constant 1.
      Value absRealIsNotInf = arith::XOrIOp::create(
          b, absRealIsInf, arith::ConstantIntOp::create(b, true, /*width=*/1));

      Value imagNumIsNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO,
                                                 imagNum, imagNum, fmf);
      Value resultRealIsNaN =
          arith::AndIOp::create(b, imagNumIsNaN, absRealIsNotInf);
      Value resultImagIsZero = arith::OrIOp::create(
          b, imagIsZero, arith::AndIOp::create(b, absRealIsInf, imagNumIsNaN));

      resultReal = arith::SelectOp::create(b, resultRealIsNaN, nan, resultReal);
      resultImag =
          arith::SelectOp::create(b, resultImagIsZero, zero, resultImag);
    }

    if constexpr (std::is_same_v) {
      // tan(x+yi) = -i*tanh(-y + xi)
      std::swap(resultReal, resultImag);
      resultImag = arith::MulFOp::create(b, resultImag, negOne, fmf);
    }

    rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag);
    return success();
  }
};

// Lowers complex.conj by keeping the real part and negating the imaginary
// part.
struct ConjOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());
    Value real =
        complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
    Value imag =
        complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
    Value negImag = arith::NegFOp::create(rewriter, loc, elementType, imag);

    rewriter.replaceOpWithNewOp(op, type, real, negImag);
    return success();
  }
};

/// Converts lhs^y = (a+bi)^(c+di) to
///    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
///    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
///
/// In the code, (a*a+b*b)^(0.5c) is emitted as |lhs|^c and
/// 0.5d*ln(a*a+b*b) as d*ln(|lhs|), with |lhs| coming from a complex.abs op.
/// `c` and `d` are the real and imaginary parts of the exponent. The general
/// result is then overridden by a chain of selects for the special cases
/// listed below; later cases take precedence over earlier ones.
static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
                                 ComplexType type, Value lhs, Value c, Value d,
                                 arith::FastMathFlags fmf) {
  auto elementType = cast(type.getElementType());

  Value a = complex::ReOp::create(builder, lhs);
  Value b = complex::ImOp::create(builder, lhs);

  // coeff = |lhs|^c * exp(-d * arg(lhs))
  Value abs = complex::AbsOp::create(builder, lhs, fmf);
  Value absToC = math::PowFOp::create(builder, abs, c, fmf);

  Value negD = arith::NegFOp::create(builder, d, fmf);
  Value argLhs = math::Atan2Op::create(builder, b, a, fmf);
  Value negDArgLhs = arith::MulFOp::create(builder, negD, argLhs, fmf);
  Value expNegDArgLhs = math::ExpOp::create(builder, negDArgLhs, fmf);

  Value coeff = arith::MulFOp::create(builder, absToC, expNegDArgLhs, fmf);
  // q = c * arg(lhs) + d * ln(|lhs|)
  Value lnAbs = math::LogOp::create(builder, abs, fmf);
  Value cArgLhs = arith::MulFOp::create(builder, c, argLhs, fmf);
  Value dLnAbs = arith::MulFOp::create(builder, d, lnAbs, fmf);
  Value q = arith::AddFOp::create(builder, cArgLhs, dLnAbs, fmf);
  Value cosQ = math::CosOp::create(builder, q, fmf);
  Value sinQ = math::SinOp::create(builder, q, fmf);

  Value inf = arith::ConstantOp::create(
      builder, elementType,
      builder.getFloatAttr(elementType,
                           APFloat::getInf(elementType.getFloatSemantics())));
  Value zero = arith::ConstantOp::create(
      builder, elementType, builder.getFloatAttr(elementType, 0.0));
  Value one = arith::ConstantOp::create(builder, elementType,
                                        builder.getFloatAttr(elementType, 1.0));
  Value complexOne = complex::CreateOp::create(builder, type, one, zero);
  Value complexZero = complex::CreateOp::create(builder, type, zero, zero);
  Value complexInf = complex::CreateOp::create(builder, type, inf, zero);

  // Case 0:
  // When |lhs| is 0, d is 0 and c >= 0: the result is 0 if c > 0, and 1 if
  // c is 0, since 0^0 is defined to be 1.0, see
  // Branch Cuts for Complex Elementary Functions or Much Ado About
  // Nothing's Sign Bit, W. Kahan, Section 10.
  Value absEqZero = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
                                          abs, zero, fmf);
  Value dEqZero = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, d,
                                        zero, fmf);
  Value cEqZero = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, c,
                                        zero, fmf);
  Value bEqZero = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, b,
                                        zero, fmf);

  Value zeroLeC = arith::CmpFOp::create(builder, arith::CmpFPredicate::OLE,
                                        zero, c, fmf);
  Value coeffCosQ = arith::MulFOp::create(builder, coeff, cosQ, fmf);
  Value coeffSinQ = arith::MulFOp::create(builder, coeff, sinQ, fmf);
  Value complexOneOrZero =
      arith::SelectOp::create(builder, cEqZero, complexOne, complexZero);
  // The general (non-special-case) result.
  Value coeffCosSin =
      complex::CreateOp::create(builder, type, coeffCosQ, coeffSinQ);
  Value cutoff0 = arith::SelectOp::create(
      builder,
      arith::AndIOp::create(
          builder, arith::AndIOp::create(builder, absEqZero, dEqZero), zeroLeC),
      complexOneOrZero, coeffCosSin);

  // Case 1:
  // x^0 is defined to be 1 for any x, see
  // Branch Cuts for Complex Elementary Functions or Much Ado About
  // Nothing's Sign Bit, W. Kahan, Section 10.
  Value rhsEqZero = arith::AndIOp::create(builder, cEqZero, dEqZero);
  Value cutoff1 =
      arith::SelectOp::create(builder, rhsEqZero, complexOne, cutoff0);

  // Case 2:
  // 1^(c + d*i) = 1 + 0*i
  Value lhsEqOne = arith::AndIOp::create(
      builder,
      arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, one, fmf),
      bEqZero);
  Value cutoff2 =
      arith::SelectOp::create(builder, lhsEqOne, complexOne, cutoff1);

  // Case 3:
  // inf^(c + 0*i) = inf + 0*i, c > 0
  Value lhsEqInf = arith::AndIOp::create(
      builder,
      arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, inf, fmf),
      bEqZero);
  Value rhsGt0 = arith::AndIOp::create(
      builder, dEqZero,
      arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, c, zero, fmf));
  Value cutoff3 = arith::SelectOp::create(
      builder, arith::AndIOp::create(builder, lhsEqInf, rhsGt0), complexInf,
      cutoff2);

  // Case 4:
  // inf^(c + 0*i) = 0 + 0*i, c < 0
  Value rhsLt0 = arith::AndIOp::create(
      builder, dEqZero,
      arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, c, zero, fmf));
  Value cutoff4 = arith::SelectOp::create(
      builder, arith::AndIOp::create(builder, lhsEqInf, rhsLt0), complexZero,
      cutoff3);

  return cutoff4;
}

// Lowers complex.pow by splitting the exponent into its real part `c` and
// imaginary part `d` and delegating to powOpConversionImpl.
struct PowOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
    auto type = cast(adaptor.getLhs().getType());
    auto elementType = cast(type.getElementType());

    Value c = complex::ReOp::create(builder, elementType, adaptor.getRhs());
    Value d = complex::ImOp::create(builder, elementType, adaptor.getRhs());

    rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
                                                c, d, op.getFastmath())});
    return success();
  }
};

struct RsqrtOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter
&rewriter) const override { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto type = cast(adaptor.getComplex().getType()); auto elementType = cast(type.getElementType()); arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); auto cst = [&](APFloat v) { return arith::ConstantOp::create(b, elementType, b.getFloatAttr(elementType, v)); }; const auto &floatSemantics = elementType.getFloatSemantics(); Value zero = cst(APFloat::getZero(floatSemantics)); Value inf = cst(APFloat::getInf(floatSemantics)); Value negHalf = arith::ConstantOp::create( b, elementType, b.getFloatAttr(elementType, -0.5)); Value nan = cst(APFloat::getNaN(floatSemantics)); Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt); Value argArg = math::Atan2Op::create(b, imag, real, fmf); Value rsqrtArg = arith::MulFOp::create(b, argArg, negHalf, fmf); Value cos = math::CosOp::create(b, rsqrtArg, fmf); Value sin = math::SinOp::create(b, rsqrtArg, fmf); Value resultReal = arith::MulFOp::create(b, absRsqrt, cos, fmf); Value resultImag = arith::MulFOp::create(b, absRsqrt, sin, fmf); if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf)) { Value negOne = arith::ConstantOp::create(b, elementType, b.getFloatAttr(elementType, -1)); Value realSignedZero = math::CopySignOp::create(b, zero, real, fmf); Value imagSignedZero = math::CopySignOp::create(b, zero, imag, fmf); Value negImagSignedZero = arith::MulFOp::create(b, negOne, imagSignedZero, fmf); Value absReal = math::AbsFOp::create(b, real, fmf); Value absImag = math::AbsFOp::create(b, imag, fmf); Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absImag, inf, fmf); Value realIsNan = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, real, real, fmf); Value realIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, 
absReal, inf, fmf); Value inIsNanInf = arith::AndIOp::create(b, absImagIsInf, realIsNan); Value resultIsZero = arith::OrIOp::create(b, inIsNanInf, realIsInf); resultReal = arith::SelectOp::create(b, resultIsZero, realSignedZero, resultReal); resultImag = arith::SelectOp::create(b, resultIsZero, negImagSignedZero, resultImag); } Value isRealZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero, fmf); Value isImagZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf); Value isZero = arith::AndIOp::create(b, isRealZero, isImagZero); resultReal = arith::SelectOp::create(b, isZero, inf, resultReal); resultImag = arith::SelectOp::create(b, isZero, nan, resultImag); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); return success(); } }; struct AngleOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto type = op.getType(); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value real = complex::ReOp::create(rewriter, loc, type, adaptor.getComplex()); Value imag = complex::ImOp::create(rewriter, loc, type, adaptor.getComplex()); rewriter.replaceOpWithNewOp(op, imag, real, fmf); return success(); } }; } // namespace void mlir::populateComplexToStandardConversionPatterns( RewritePatternSet &patterns, complex::ComplexRangeFlags complexRange) { // clang-format off patterns.add< AbsOpConversion, AngleOpConversion, Atan2OpConversion, BinaryComplexOpConversion, BinaryComplexOpConversion, ComparisonOpConversion, ComparisonOpConversion, ConjOpConversion, CosOpConversion, ExpOpConversion, Expm1OpConversion, Log1pOpConversion, LogOpConversion, MulOpConversion, NegOpConversion, SignOpConversion, SinOpConversion, SqrtOpConversion, TanTanhOpConversion, TanTanhOpConversion, PowOpConversion, RsqrtOpConversion >(patterns.getContext()); 
patterns.add(patterns.getContext(), complexRange); // clang-format on } namespace { struct ConvertComplexToStandardPass : public impl::ConvertComplexToStandardPassBase< ConvertComplexToStandardPass> { using Base::Base; void runOnOperation() override; }; void ConvertComplexToStandardPass::runOnOperation() { // Convert to the Standard dialect using the converter defined above. RewritePatternSet patterns(&getContext()); populateComplexToStandardConversionPatterns(patterns, complexRange); ConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalOp(); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } } // namespace