diff options
Diffstat (limited to 'mlir/lib/Conversion/ComplexToStandard')
| -rw-r--r-- | mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp | 54 |
1 files changed, 41 insertions, 13 deletions
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 0fe7239..9e46b7d 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -313,25 +313,53 @@ private: struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { using OpConversionPattern<complex::ExpOp>::OpConversionPattern; + // exp(x+I*y) = exp(x)*(cos(y)+I*sin(y)) + // Handle special cases as StableHLO implementation does: + // 1. When b == 0, set imag(exp(z)) = 0 + // 2. When exp(x) == inf, use exp(x/2)*(cos(y)+I*sin(y))*exp(x/2) LogicalResult matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto type = cast<ComplexType>(adaptor.getComplex().getType()); - auto elementType = cast<FloatType>(type.getElementType()); - arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); - - Value real = - complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); - Value imag = - complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); - Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue()); - Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue()); + auto ET = cast<FloatType>(type.getElementType()); + arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); + const auto &floatSemantics = ET.getFloatSemantics(); + ImplicitLocOpBuilder b(loc, rewriter); + + Value x = complex::ReOp::create(b, ET, adaptor.getComplex()); + Value y = complex::ImOp::create(b, ET, adaptor.getComplex()); + Value zero = arith::ConstantOp::create(b, ET, b.getZeroAttr(ET)); + Value half = arith::ConstantOp::create(b, ET, b.getFloatAttr(ET, 0.5)); + Value inf = arith::ConstantOp::create( + b, ET, b.getFloatAttr(ET, APFloat::getInf(floatSemantics))); + + Value exp = math::ExpOp::create(b, x, fmf); + Value xHalf = arith::MulFOp::create(b, x, half, fmf); + Value expHalf = math::ExpOp::create(b, xHalf, fmf); + Value cos = math::CosOp::create(b, y, fmf); + Value sin = math::SinOp::create(b, y, fmf); + + Value expIsInf = + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, exp, inf, fmf); + Value yIsZero = + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, y, zero); + + // Real path: select between exp(x)*cos(y) and exp(x/2)*cos(y)*exp(x/2) + Value realNormal = arith::MulFOp::create(b, exp, cos, fmf); + Value expHalfCos = arith::MulFOp::create(b, expHalf, cos, fmf); + Value realOverflow = arith::MulFOp::create(b, expHalfCos, expHalf, fmf); Value resultReal = - arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue()); - Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue()); - Value resultImag = - arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue()); + arith::SelectOp::create(b, expIsInf, realOverflow, realNormal); + + // Imaginary part: if y == 0 return 0 else select between exp(x)*sin(y) and + // exp(x/2)*sin(y)*exp(x/2) + Value imagNormal = arith::MulFOp::create(b, exp, sin, fmf); + Value expHalfSin = arith::MulFOp::create(b, expHalf, sin, fmf); + Value imagOverflow = arith::MulFOp::create(b, expHalfSin, expHalf, fmf); + Value imagNonZero = + arith::SelectOp::create(b, expIsInf, imagOverflow, imagNormal); + Value resultImag = arith::SelectOp::create(b, yIsZero, zero, imagNonZero); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, resultImag); |
