diff options
Diffstat (limited to 'mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp')
-rw-r--r-- | mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp | 707 |
1 files changed, 360 insertions, 347 deletions
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 0c832c4..5ad514d 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include <type_traits> @@ -31,44 +30,45 @@ enum class AbsFn { abs, sqrt, rsqrt }; // Returns the absolute value, its square root or its reciprocal square root. Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf, ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) { - Value one = b.create<arith::ConstantOp>(real.getType(), - b.getFloatAttr(real.getType(), 1.0)); + Value one = arith::ConstantOp::create(b, real.getType(), + b.getFloatAttr(real.getType(), 1.0)); - Value absReal = b.create<math::AbsFOp>(real, fmf); - Value absImag = b.create<math::AbsFOp>(imag, fmf); + Value absReal = math::AbsFOp::create(b, real, fmf); + Value absImag = math::AbsFOp::create(b, imag, fmf); - Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf); - Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf); + Value max = arith::MaximumFOp::create(b, absReal, absImag, fmf); + Value min = arith::MinimumFOp::create(b, absReal, absImag, fmf); // The lowering below requires NaNs and infinities to work correctly. arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); - Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf); - Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf); - Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf); + Value ratio = arith::DivFOp::create(b, min, max, fmfWithNaNInf); + Value ratioSq = arith::MulFOp::create(b, ratio, ratio, fmfWithNaNInf); + Value ratioSqPlusOne = arith::AddFOp::create(b, ratioSq, one, fmfWithNaNInf); Value result; if (fn == AbsFn::rsqrt) { - ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf); - min = b.create<math::RsqrtOp>(min, fmfWithNaNInf); - max = b.create<math::RsqrtOp>(max, fmfWithNaNInf); + ratioSqPlusOne = math::RsqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf); + min = math::RsqrtOp::create(b, min, fmfWithNaNInf); + max = math::RsqrtOp::create(b, max, fmfWithNaNInf); } if (fn == AbsFn::sqrt) { - Value quarter = b.create<arith::ConstantOp>( - real.getType(), b.getFloatAttr(real.getType(), 0.25)); + Value quarter = arith::ConstantOp::create( + b, real.getType(), b.getFloatAttr(real.getType(), 0.25)); // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily. - Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf); - Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf); - result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf); + Value sqrt = math::SqrtOp::create(b, max, fmfWithNaNInf); + Value p025 = + math::PowFOp::create(b, ratioSqPlusOne, quarter, fmfWithNaNInf); + result = arith::MulFOp::create(b, sqrt, p025, fmfWithNaNInf); } else { - Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf); - result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf); + Value sqrt = math::SqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf); + result = arith::MulFOp::create(b, max, sqrt, fmfWithNaNInf); } - Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, - result, fmfWithNaNInf); - return b.create<arith::SelectOp>(isNaN, min, result); + Value isNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, result, + result, fmfWithNaNInf); + return arith::SelectOp::create(b, isNaN, min, result); } struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { @@ -81,8 +81,8 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> { arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); - Value real = b.create<complex::ReOp>(adaptor.getComplex()); - Value imag = b.create<complex::ImOp>(adaptor.getComplex()); + Value real = complex::ReOp::create(b, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, adaptor.getComplex()); rewriter.replaceOp(op, computeAbs(real, imag, fmf, b)); return success(); @@ -105,28 +105,28 @@ struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> { Value lhs = adaptor.getLhs(); Value rhs = adaptor.getRhs(); - Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf); - Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf); + Value rhsSquared = complex::MulOp::create(b, type, rhs, rhs, fmf); + Value lhsSquared = complex::MulOp::create(b, type, lhs, lhs, fmf); Value rhsSquaredPlusLhsSquared = - b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf); + complex::AddOp::create(b, type, rhsSquared, lhsSquared, fmf); Value sqrtOfRhsSquaredPlusLhsSquared = - b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf); + complex::SqrtOp::create(b, type, rhsSquaredPlusLhsSquared, fmf); Value zero = - b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); - Value one = b.create<arith::ConstantOp>(elementType, - b.getFloatAttr(elementType, 1)); - Value i = b.create<complex::CreateOp>(type, zero, one); - Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf); - Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf); + arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType)); + Value one = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 1)); + Value i = complex::CreateOp::create(b, type, zero, one); + Value iTimesLhs = complex::MulOp::create(b, i, lhs, fmf); + Value rhsPlusILhs = complex::AddOp::create(b, rhs, iTimesLhs, fmf); - Value divResult = b.create<complex::DivOp>( - rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf); - Value logResult = b.create<complex::LogOp>(divResult, fmf); + Value divResult = complex::DivOp::create( + b, rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf); + Value logResult = complex::LogOp::create(b, divResult, fmf); - Value negativeOne = b.create<arith::ConstantOp>( - elementType, b.getFloatAttr(elementType, -1)); - Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne); + Value negativeOne = arith::ConstantOp::create( + b, elementType, b.getFloatAttr(elementType, -1)); + Value negativeI = complex::CreateOp::create(b, type, zero, negativeOne); rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf); return success(); @@ -146,14 +146,18 @@ struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> { auto loc = op.getLoc(); auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType(); - Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs()); - Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs()); - Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs()); - Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs()); + Value realLhs = + complex::ReOp::create(rewriter, loc, type, adaptor.getLhs()); + Value imagLhs = + complex::ImOp::create(rewriter, loc, type, adaptor.getLhs()); + Value realRhs = + complex::ReOp::create(rewriter, loc, type, adaptor.getRhs()); + Value imagRhs = + complex::ImOp::create(rewriter, loc, type, adaptor.getRhs()); Value realComparison = - rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs); + arith::CmpFOp::create(rewriter, loc, p, realLhs, realRhs); Value imagComparison = - rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs); + arith::CmpFOp::create(rewriter, loc, p, imagLhs, imagRhs); rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison, imagComparison); @@ -176,14 +180,14 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); - Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs()); - Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs()); - Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs, - fmf.getValue()); - Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs()); - Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs()); - Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs, - fmf.getValue()); + Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs()); + Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs()); + Value resultReal = BinaryStandardOp::create(b, elementType, realLhs, + realRhs, fmf.getValue()); + Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs()); + Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs()); + Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs, + imagRhs, fmf.getValue()); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, resultImag); return success(); @@ -205,20 +209,20 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value real = - rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); Value imag = - rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); // Trigonometric ops use a set of common building blocks to convert to real // ops. Here we create these building blocks and call into an op-specific // implementation in the subclass to combine them. - Value half = rewriter.create<arith::ConstantOp>( - loc, elementType, rewriter.getFloatAttr(elementType, 0.5)); - Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf); - Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf); - Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf); - Value sin = rewriter.create<math::SinOp>(loc, real, fmf); - Value cos = rewriter.create<math::CosOp>(loc, real, fmf); + Value half = arith::ConstantOp::create( + rewriter, loc, elementType, rewriter.getFloatAttr(elementType, 0.5)); + Value exp = math::ExpOp::create(rewriter, loc, imag, fmf); + Value scaledExp = arith::MulFOp::create(rewriter, loc, half, exp, fmf); + Value reciprocalExp = arith::DivFOp::create(rewriter, loc, half, exp, fmf); + Value sin = math::SinOp::create(rewriter, loc, real, fmf); + Value cos = math::CosOp::create(rewriter, loc, real, fmf); auto resultPair = combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf); @@ -251,11 +255,11 @@ struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> { // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x Value sum = - rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf); - Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf); + arith::AddFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf); + Value resultReal = arith::MulFOp::create(rewriter, loc, sum, cos, fmf); Value diff = - rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf); - Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf); + arith::SubFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf); + Value resultImag = arith::MulFOp::create(rewriter, loc, diff, sin, fmf); return {resultReal, resultImag}; } }; @@ -275,13 +279,13 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value lhsReal = - rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getLhs()); Value lhsImag = - rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs()); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getLhs()); Value rhsReal = - rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getRhs()); Value rhsImag = - rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs()); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getRhs()); Value resultReal, resultImag; @@ -318,16 +322,16 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value real = - rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); Value imag = - rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); - Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue()); - Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue()); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); + Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue()); + Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue()); Value resultReal = - rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue()); - Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue()); + arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue()); + Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue()); Value resultImag = - rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue()); + arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue()); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, resultImag); @@ -340,11 +344,11 @@ Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg, arith::FastMathFlagsAttr fmf) { auto argType = mlir::cast<FloatType>(arg.getType()); Value poly = - b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0])); + arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[0])); for (unsigned i = 1; i < coefficients.size(); ++i) { - poly = b.create<math::FmaOp>( - poly, arg, - b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[i])), + poly = math::FmaOp::create( + b, poly, arg, + arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[i])), fmf); } return poly; @@ -365,26 +369,26 @@ struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value real = b.create<complex::ReOp>(adaptor.getComplex()); - Value imag = b.create<complex::ImOp>(adaptor.getComplex()); + Value real = complex::ReOp::create(b, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, adaptor.getComplex()); - Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0)); - Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0)); + Value zero = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 0.0)); + Value one = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 1.0)); - Value expm1Real = b.create<math::ExpM1Op>(real, fmf); - Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf); + Value expm1Real = math::ExpM1Op::create(b, real, fmf); + Value expReal = arith::AddFOp::create(b, expm1Real, one, fmf); - Value sinImag = b.create<math::SinOp>(imag, fmf); + Value sinImag = math::SinOp::create(b, imag, fmf); Value cosm1Imag = emitCosm1(imag, fmf, b); - Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf); + Value cosImag = arith::AddFOp::create(b, cosm1Imag, one, fmf); - Value realResult = b.create<arith::AddFOp>( - b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf); + Value realResult = arith::AddFOp::create( + b, arith::MulFOp::create(b, expm1Real, cosImag, fmf), cosm1Imag, fmf); - Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, - zero, fmf.getValue()); - Value imagResult = b.create<arith::SelectOp>( - imagIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf)); + Value imagIsZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, + zero, fmf.getValue()); + Value imagResult = arith::SelectOp::create( + b, imagIsZero, zero, arith::MulFOp::create(b, expReal, sinImag, fmf)); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult, imagResult); @@ -395,8 +399,8 @@ private: Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf, ImplicitLocOpBuilder &b) const { auto argType = mlir::cast<FloatType>(arg.getType()); - auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5)); - auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0)); + auto negHalf = arith::ConstantOp::create(b, b.getFloatAttr(argType, -0.5)); + auto negOne = arith::ConstantOp::create(b, b.getFloatAttr(argType, -1.0)); // Algorithm copied from cephes cosm1. SmallVector<double, 7> kCoeffs{ @@ -405,23 +409,23 @@ private: 2.4801587301570552304991E-5, -1.3888888888888872993737E-3, 4.1666666666666666609054E-2, }; - Value cos = b.create<math::CosOp>(arg, fmf); - Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf); + Value cos = math::CosOp::create(b, arg, fmf); + Value forLargeArg = arith::AddFOp::create(b, cos, negOne, fmf); - Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf); - Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf); + Value argPow2 = arith::MulFOp::create(b, arg, arg, fmf); + Value argPow4 = arith::MulFOp::create(b, argPow2, argPow2, fmf); Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf); auto forSmallArg = - b.create<arith::AddFOp>(b.create<arith::MulFOp>(argPow4, poly, fmf), - b.create<arith::MulFOp>(negHalf, argPow2, fmf)); + arith::AddFOp::create(b, arith::MulFOp::create(b, argPow4, poly, fmf), + arith::MulFOp::create(b, negHalf, argPow2, fmf)); // (pi/4)^2 is approximately 0.61685 Value piOver4Pow2 = - b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685)); - Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2, - piOver4Pow2, fmf.getValue()); - return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg); + arith::ConstantOp::create(b, b.getFloatAttr(argType, 0.61685)); + Value cond = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, argPow2, + piOver4Pow2, fmf.getValue()); + return arith::SelectOp::create(b, cond, forLargeArg, forSmallArg); } }; @@ -436,13 +440,13 @@ struct LogOpConversion : public OpConversionPattern<complex::LogOp> { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), - fmf.getValue()); - Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue()); - Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); - Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); + Value abs = complex::AbsOp::create(b, elementType, adaptor.getComplex(), + fmf.getValue()); + Value resultReal = math::LogOp::create(b, elementType, abs, fmf.getValue()); + Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value resultImag = - b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue()); + math::Atan2Op::create(b, elementType, imag, real, fmf.getValue()); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, resultImag); return success(); @@ -460,40 +464,42 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> { arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value real = b.create<complex::ReOp>(adaptor.getComplex()); - Value imag = b.create<complex::ImOp>(adaptor.getComplex()); + Value real = complex::ReOp::create(b, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, adaptor.getComplex()); - Value half = b.create<arith::ConstantOp>(elementType, - b.getFloatAttr(elementType, 0.5)); - Value one = b.create<arith::ConstantOp>(elementType, - b.getFloatAttr(elementType, 1)); - Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf); - Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf); - Value absImag = b.create<math::AbsFOp>(imag, fmf); + Value half = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 0.5)); + Value one = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 1)); + Value realPlusOne = arith::AddFOp::create(b, real, one, fmf); + Value absRealPlusOne = math::AbsFOp::create(b, realPlusOne, fmf); + Value absImag = math::AbsFOp::create(b, imag, fmf); - Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf); - Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf); + Value maxAbs = arith::MaximumFOp::create(b, absRealPlusOne, absImag, fmf); + Value minAbs = arith::MinimumFOp::create(b, absRealPlusOne, absImag, fmf); - Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, - realPlusOne, absImag, fmf); - Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf); + Value useReal = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, + realPlusOne, absImag, fmf); + Value maxMinusOne = arith::SubFOp::create(b, maxAbs, one, fmf); Value maxAbsOfRealPlusOneAndImagMinusOne = - b.create<arith::SelectOp>(useReal, real, maxMinusOne); + arith::SelectOp::create(b, useReal, real, maxMinusOne); arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear( fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf); - Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf); + Value minMaxRatio = arith::DivFOp::create(b, minAbs, maxAbs, fmfWithNaNInf); Value logOfMaxAbsOfRealPlusOneAndImag = - b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf); - Value logOfSqrtPart = b.create<math::Log1pOp>( - b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf), + math::Log1pOp::create(b, maxAbsOfRealPlusOneAndImagMinusOne, fmf); + Value logOfSqrtPart = math::Log1pOp::create( + b, arith::MulFOp::create(b, minMaxRatio, minMaxRatio, fmfWithNaNInf), fmfWithNaNInf); - Value r = b.create<arith::AddFOp>( - b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf), + Value r = arith::AddFOp::create( + b, arith::MulFOp::create(b, half, logOfSqrtPart, fmfWithNaNInf), logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf); - Value resultReal = b.create<arith::SelectOp>( - b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf), + Value resultReal = arith::SelectOp::create( + b, + arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, r, r, + fmfWithNaNInf), minAbs, r); - Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf); + Value resultImag = math::Atan2Op::create(b, imag, realPlusOne, fmf); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, resultImag); return success(); @@ -511,22 +517,22 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> { auto elementType = cast<FloatType>(type.getElementType()); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); auto fmfValue = fmf.getValue(); - Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs()); - Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs()); - Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs()); - Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs()); + Value lhsReal = complex::ReOp::create(b, elementType, adaptor.getLhs()); + Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs()); + Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs()); + Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs()); Value lhsRealTimesRhsReal = - b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue); + arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue); Value lhsImagTimesRhsImag = - b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue); - Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal, - lhsImagTimesRhsImag, fmfValue); + arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue); + Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal, + lhsImagTimesRhsImag, fmfValue); Value lhsImagTimesRhsReal = - b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue); + arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue); Value lhsRealTimesRhsImag = - b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue); - Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal, - lhsRealTimesRhsImag, fmfValue); + arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue); + Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal, + lhsRealTimesRhsImag, fmfValue); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag); return success(); } @@ -543,11 +549,11 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> { auto elementType = cast<FloatType>(type.getElementType()); Value real = - rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); Value imag = - rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); - Value negReal = rewriter.create<arith::NegFOp>(loc, real); - Value negImag = rewriter.create<arith::NegFOp>(loc, imag); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); + Value negReal = arith::NegFOp::create(rewriter, loc, real); + Value negImag = arith::NegFOp::create(rewriter, loc, imag); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag); return success(); } @@ -570,11 +576,11 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> { // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x Value sum = - rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf); - Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf); + arith::AddFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf); + Value resultReal = arith::MulFOp::create(rewriter, loc, sum, sin, fmf); Value diff = - rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf); - Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf); + arith::SubFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf); + Value resultImag = arith::MulFOp::create(rewriter, loc, diff, cos, fmf); return {resultReal, resultImag}; } }; @@ -593,64 +599,65 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> { arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); auto cst = [&](APFloat v) { - return b.create<arith::ConstantOp>(elementType, - b.getFloatAttr(elementType, v)); + return arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, v)); }; const auto &floatSemantics = elementType.getFloatSemantics(); Value zero = cst(APFloat::getZero(floatSemantics)); - Value half = b.create<arith::ConstantOp>(elementType, - b.getFloatAttr(elementType, 0.5)); + Value half = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 0.5)); - Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); - Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); + Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt); - Value argArg = b.create<math::Atan2Op>(imag, real, fmf); - Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf); - Value cos = b.create<math::CosOp>(sqrtArg, fmf); - Value sin = b.create<math::SinOp>(sqrtArg, fmf); + Value argArg = math::Atan2Op::create(b, imag, real, fmf); + Value sqrtArg = arith::MulFOp::create(b, argArg, half, fmf); + Value cos = math::CosOp::create(b, sqrtArg, fmf); + Value sin = math::SinOp::create(b, sqrtArg, fmf); // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply // 0 * inf. Value sinIsZero = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, sin, zero, fmf); - Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf); - Value resultImag = b.create<arith::SelectOp>( - sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf)); + Value resultReal = arith::MulFOp::create(b, absSqrt, cos, fmf); + Value resultImag = arith::SelectOp::create( + b, sinIsZero, zero, arith::MulFOp::create(b, absSqrt, sin, fmf)); if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf)) { Value inf = cst(APFloat::getInf(floatSemantics)); Value negInf = cst(APFloat::getInf(floatSemantics, true)); Value nan = cst(APFloat::getNaN(floatSemantics)); - Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf); + Value absImag = math::AbsFOp::create(b, elementType, imag, fmf); - Value absImagIsInf = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf); - Value absImagIsNotInf = - b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf); + Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + absImag, inf, fmf); + Value absImagIsNotInf = arith::CmpFOp::create( + b, arith::CmpFPredicate::ONE, absImag, inf, fmf); Value realIsInf = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf); - Value realIsNegInf = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, inf, fmf); + Value realIsNegInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + real, negInf, fmf); - resultReal = b.create<arith::SelectOp>( - b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero, + resultReal = arith::SelectOp::create( + b, arith::AndIOp::create(b, realIsNegInf, absImagIsNotInf), zero, resultReal); - resultReal = b.create<arith::SelectOp>( - b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal); + resultReal = arith::SelectOp::create( + b, arith::OrIOp::create(b, absImagIsInf, realIsInf), inf, resultReal); - Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf); - resultImag = b.create<arith::SelectOp>( - b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt), + Value imagSignInf = math::CopySignOp::create(b, inf, imag, fmf); + resultImag = arith::SelectOp::create( + b, + arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, absSqrt, absSqrt), nan, resultImag); - resultImag = b.create<arith::SelectOp>( - b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf, + resultImag = arith::SelectOp::create( + b, arith::OrIOp::create(b, absImagIsInf, realIsNegInf), imagSignInf, resultImag); } Value resultIsZero = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf); - resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal); - resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absSqrt, zero, fmf); + resultReal = arith::SelectOp::create(b, resultIsZero, zero, resultReal); + resultImag = arith::SelectOp::create(b, resultIsZero, zero, resultImag); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, resultImag); @@ -669,19 +676,20 @@ struct SignOpConversion : public OpConversionPattern<complex::SignOp> { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); - Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); - Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); + Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value zero = - b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType)); + arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType)); Value realIsZero = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero); Value imagIsZero = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero); - Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero); - auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf); - Value realSign = b.create<arith::DivFOp>(real, abs, fmf); - Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf); - Value sign = b.create<complex::CreateOp>(type, realSign, imagSign); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero); + Value isZero = arith::AndIOp::create(b, realIsZero, imagIsZero); + auto abs = + complex::AbsOp::create(b, elementType, adaptor.getComplex(), fmf); + Value realSign = arith::DivFOp::create(b, real, abs, fmf); + Value imagSign = arith::DivFOp::create(b, imag, abs, fmf); + Value sign = complex::CreateOp::create(b, type, realSign, imagSign); rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero, adaptor.getComplex(), sign); return success(); @@ -703,84 +711,84 @@ struct TanTanhOpConversion : public OpConversionPattern<Op> { const auto &floatSemantics = elementType.getFloatSemantics(); Value real = - b.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(b, loc, elementType, adaptor.getComplex()); Value imag = - b.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); - Value negOne = b.create<arith::ConstantOp>( - elementType, b.getFloatAttr(elementType, -1.0)); + complex::ImOp::create(b, loc, elementType, adaptor.getComplex()); + Value negOne = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, -1.0)); if constexpr (std::is_same_v<Op, complex::TanOp>) { // tan(x+yi) = -i*tanh(-y + xi) std::swap(real, imag); - real = b.create<arith::MulFOp>(real, negOne, fmf); + real = arith::MulFOp::create(b, real, negOne, fmf); } auto cst = [&](APFloat v) { - return b.create<arith::ConstantOp>(elementType, - b.getFloatAttr(elementType, v)); + return arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, v)); }; Value inf = cst(APFloat::getInf(floatSemantics)); - Value four = b.create<arith::ConstantOp>(elementType, - b.getFloatAttr(elementType, 4.0)); - Value twoReal = b.create<arith::AddFOp>(real, real, fmf); - Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf); - - Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf); - Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf); - Value realNum = - b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); - - Value cosImag = b.create<math::CosOp>(imag, fmf); - Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf); - Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf); - Value sinImag = b.create<math::SinOp>(imag, fmf); - - Value imagNum = b.create<arith::MulFOp>( - four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf); - - Value expSumMinusTwo = - b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf); + Value four = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 4.0)); + Value twoReal = arith::AddFOp::create(b, real, real, fmf); + Value negTwoReal = arith::MulFOp::create(b, negOne, twoReal, fmf); + + Value expTwoRealMinusOne = math::ExpM1Op::create(b, twoReal, fmf); + Value expNegTwoRealMinusOne = math::ExpM1Op::create(b, negTwoReal, fmf); + Value realNum = arith::SubFOp::create(b, expTwoRealMinusOne, + expNegTwoRealMinusOne, fmf); + + Value cosImag = math::CosOp::create(b, imag, fmf); + Value cosImagSq = arith::MulFOp::create(b, cosImag, cosImag, fmf); + Value twoCosTwoImagPlusOne = arith::MulFOp::create(b, cosImagSq, four, fmf); + Value sinImag = math::SinOp::create(b, imag, fmf); + + Value imagNum = arith::MulFOp::create( + b, four, arith::MulFOp::create(b, cosImag, sinImag, fmf), fmf); + + Value expSumMinusTwo = arith::AddFOp::create(b, expTwoRealMinusOne, + expNegTwoRealMinusOne, fmf); Value denom = - b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf); + arith::AddFOp::create(b, expSumMinusTwo, twoCosTwoImagPlusOne, fmf); - Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, - expSumMinusTwo, inf, fmf); - Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf); + Value isInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + expSumMinusTwo, inf, fmf); + Value realLimit = math::CopySignOp::create(b, negOne, real, fmf); - Value resultReal = b.create<arith::SelectOp>( - isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf)); - Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf); + Value resultReal = arith::SelectOp::create( + b, isInf, realLimit, arith::DivFOp::create(b, realNum, denom, fmf)); + Value resultImag = arith::DivFOp::create(b, imagNum, denom, fmf); if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf)) { - Value absReal = b.create<math::AbsFOp>(real, fmf); - Value zero = b.create<arith::ConstantOp>( - elementType, b.getFloatAttr(elementType, 0.0)); + Value absReal = math::AbsFOp::create(b, real, fmf); + Value zero = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, 0.0)); Value nan = cst(APFloat::getNaN(floatSemantics)); - Value absRealIsInf = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf); + Value absRealIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + absReal, inf, fmf); Value imagIsZero = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf); - Value absRealIsNotInf = b.create<arith::XOrIOp>( - absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1)); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf); + Value absRealIsNotInf = arith::XOrIOp::create( + b, absRealIsInf, arith::ConstantIntOp::create(b, true, /*width=*/1)); - Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, - imagNum, imagNum, fmf); + Value imagNumIsNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, + imagNum, imagNum, fmf); Value resultRealIsNaN = - b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf); - Value resultImagIsZero = b.create<arith::OrIOp>( - imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN)); + arith::AndIOp::create(b, imagNumIsNaN, absRealIsNotInf); + Value resultImagIsZero = arith::OrIOp::create( + b, imagIsZero, arith::AndIOp::create(b, absRealIsInf, imagNumIsNaN)); - resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal); + resultReal = arith::SelectOp::create(b, resultRealIsNaN, nan, resultReal); resultImag = - b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag); + arith::SelectOp::create(b, resultImagIsZero, zero, resultImag); } if constexpr (std::is_same_v<Op, complex::TanOp>) { // tan(x+yi) = -i*tanh(-y + xi) std::swap(resultReal, resultImag); - resultImag = b.create<arith::MulFOp>(resultImag, negOne, fmf); + resultImag = arith::MulFOp::create(b, resultImag, negOne, fmf); } rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, @@ -799,10 +807,10 @@ struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> { auto type = cast<ComplexType>(adaptor.getComplex().getType()); auto elementType = cast<FloatType>(type.getElementType()); Value real = - rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); Value imag = - rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex()); - Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag); + complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); + Value negImag = arith::NegFOp::create(rewriter, loc, elementType, imag); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag); @@ -818,97 +826,102 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder, arith::FastMathFlags fmf) { auto elementType = cast<FloatType>(type.getElementType()); - Value a = builder.create<complex::ReOp>(lhs); - Value b = builder.create<complex::ImOp>(lhs); + Value a = complex::ReOp::create(builder, lhs); + Value b = complex::ImOp::create(builder, lhs); - Value abs = builder.create<complex::AbsOp>(lhs, fmf); - Value absToC = builder.create<math::PowFOp>(abs, c, fmf); + Value abs = complex::AbsOp::create(builder, lhs, fmf); + Value absToC = math::PowFOp::create(builder, abs, c, fmf); - Value negD = builder.create<arith::NegFOp>(d, fmf); - Value argLhs = builder.create<math::Atan2Op>(b, a, fmf); - Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf); - Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf); + Value negD = arith::NegFOp::create(builder, d, fmf); + Value argLhs = math::Atan2Op::create(builder, b, a, fmf); + Value negDArgLhs = arith::MulFOp::create(builder, negD, argLhs, fmf); + Value expNegDArgLhs = math::ExpOp::create(builder, negDArgLhs, fmf); - Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf); - Value lnAbs = builder.create<math::LogOp>(abs, fmf); - Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf); - Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf); - Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf); - Value cosQ = builder.create<math::CosOp>(q, fmf); - Value sinQ = builder.create<math::SinOp>(q, fmf); + Value coeff = arith::MulFOp::create(builder, absToC, expNegDArgLhs, fmf); + Value lnAbs = math::LogOp::create(builder, abs, fmf); + Value cArgLhs = arith::MulFOp::create(builder, c, argLhs, fmf); + Value dLnAbs = arith::MulFOp::create(builder, d, lnAbs, fmf); + Value q = arith::AddFOp::create(builder, cArgLhs, dLnAbs, fmf); + Value cosQ = math::CosOp::create(builder, q, fmf); + Value sinQ = math::SinOp::create(builder, q, fmf); - Value inf = builder.create<arith::ConstantOp>( - elementType, + Value inf = arith::ConstantOp::create( + builder, elementType, builder.getFloatAttr(elementType, APFloat::getInf(elementType.getFloatSemantics()))); - Value zero = builder.create<arith::ConstantOp>( - elementType, builder.getFloatAttr(elementType, 0.0)); - Value one = builder.create<arith::ConstantOp>( - elementType, builder.getFloatAttr(elementType, 1.0)); - Value complexOne = builder.create<complex::CreateOp>(type, one, zero); - Value complexZero = builder.create<complex::CreateOp>(type, zero, zero); - Value complexInf = builder.create<complex::CreateOp>(type, inf, zero); + Value zero = arith::ConstantOp::create( + builder, elementType, builder.getFloatAttr(elementType, 0.0)); + Value one = arith::ConstantOp::create(builder, elementType, + builder.getFloatAttr(elementType, 1.0)); + Value complexOne = complex::CreateOp::create(builder, type, one, zero); + Value complexZero = complex::CreateOp::create(builder, type, zero, zero); + Value complexInf = complex::CreateOp::create(builder, type, inf, zero); // Case 0: // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see // Branch Cuts for Complex Elementary Functions or Much Ado About // Nothing's Sign Bit, W. Kahan, Section 10. Value absEqZero = - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, abs, zero, fmf); Value dEqZero = - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, d, zero, fmf); Value cEqZero = - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, c, zero, fmf); Value bEqZero = - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, b, zero, fmf); Value zeroLeC = - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf); - Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf); - Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf); + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLE, zero, c, fmf); + Value coeffCosQ = arith::MulFOp::create(builder, coeff, cosQ, fmf); + Value coeffSinQ = arith::MulFOp::create(builder, coeff, sinQ, fmf); Value complexOneOrZero = - builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero); + arith::SelectOp::create(builder, cEqZero, complexOne, complexZero); Value coeffCosSin = - builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ); - Value cutoff0 = builder.create<arith::SelectOp>( - builder.create<arith::AndIOp>( - builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC), + complex::CreateOp::create(builder, type, coeffCosQ, coeffSinQ); + Value cutoff0 = arith::SelectOp::create( + builder, + arith::AndIOp::create( + builder, arith::AndIOp::create(builder, absEqZero, dEqZero), zeroLeC), complexOneOrZero, coeffCosSin); // Case 1: // x^0 is defined to be 1 for any x, see // Branch Cuts for Complex Elementary Functions or Much Ado About // Nothing's Sign Bit, W. Kahan, Section 10. - Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero); + Value rhsEqZero = arith::AndIOp::create(builder, cEqZero, dEqZero); Value cutoff1 = - builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0); + arith::SelectOp::create(builder, rhsEqZero, complexOne, cutoff0); // Case 2: // 1^(c + d*i) = 1 + 0*i - Value lhsEqOne = builder.create<arith::AndIOp>( - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf), + Value lhsEqOne = arith::AndIOp::create( + builder, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, one, fmf), bEqZero); Value cutoff2 = - builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1); + arith::SelectOp::create(builder, lhsEqOne, complexOne, cutoff1); // Case 3: // inf^(c + 0*i) = inf + 0*i, c > 0 - Value lhsEqInf = builder.create<arith::AndIOp>( - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf), + Value lhsEqInf = arith::AndIOp::create( + builder, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, inf, fmf), bEqZero); - Value rhsGt0 = builder.create<arith::AndIOp>( - dEqZero, - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf)); - Value cutoff3 = builder.create<arith::SelectOp>( - builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2); + Value rhsGt0 = arith::AndIOp::create( + builder, dEqZero, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, c, zero, fmf)); + Value cutoff3 = arith::SelectOp::create( + builder, arith::AndIOp::create(builder, lhsEqInf, rhsGt0), complexInf, + cutoff2); // Case 4: // inf^(c + 0*i) = 0 + 0*i, c < 0 - Value rhsLt0 = builder.create<arith::AndIOp>( - dEqZero, - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf)); - Value cutoff4 = builder.create<arith::SelectOp>( - builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3); + Value rhsLt0 = arith::AndIOp::create( + builder, dEqZero, + arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, c, zero, fmf)); + Value cutoff4 = arith::SelectOp::create( + builder, arith::AndIOp::create(builder, lhsEqInf, rhsLt0), complexZero, + cutoff3); return cutoff4; } @@ -923,8 +936,8 @@ struct PowOpConversion : public OpConversionPattern<complex::PowOp> { auto type = cast<ComplexType>(adaptor.getLhs().getType()); auto elementType = cast<FloatType>(type.getElementType()); - Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs()); - Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs()); + Value c = complex::ReOp::create(builder, elementType, adaptor.getRhs()); + Value d = complex::ImOp::create(builder, elementType, adaptor.getRhs()); rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(), c, d, op.getFastmath())}); @@ -945,64 +958,64 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> { arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); auto cst = [&](APFloat v) { - return b.create<arith::ConstantOp>(elementType, - b.getFloatAttr(elementType, v)); + return arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, v)); }; const auto &floatSemantics = elementType.getFloatSemantics(); Value zero = cst(APFloat::getZero(floatSemantics)); Value inf = cst(APFloat::getInf(floatSemantics)); - Value negHalf = b.create<arith::ConstantOp>( - elementType, b.getFloatAttr(elementType, -0.5)); + Value negHalf = arith::ConstantOp::create( + b, elementType, b.getFloatAttr(elementType, -0.5)); Value nan = cst(APFloat::getNaN(floatSemantics)); - Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); - Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); + Value real = complex::ReOp::create(b, elementType, adaptor.getComplex()); + Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex()); Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt); - Value argArg = b.create<math::Atan2Op>(imag, real, fmf); - Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf); - Value cos = b.create<math::CosOp>(rsqrtArg, fmf); - Value sin = b.create<math::SinOp>(rsqrtArg, fmf); + Value argArg = math::Atan2Op::create(b, imag, real, fmf); + Value rsqrtArg = arith::MulFOp::create(b, argArg, negHalf, fmf); + Value cos = math::CosOp::create(b, rsqrtArg, fmf); + Value sin = math::SinOp::create(b, rsqrtArg, fmf); - Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf); - Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf); + Value resultReal = arith::MulFOp::create(b, absRsqrt, cos, fmf); + Value resultImag = arith::MulFOp::create(b, absRsqrt, sin, fmf); if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf)) { - Value negOne = b.create<arith::ConstantOp>( - elementType, b.getFloatAttr(elementType, -1)); + Value negOne = arith::ConstantOp::create(b, elementType, + b.getFloatAttr(elementType, -1)); - Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf); - Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf); + Value realSignedZero = math::CopySignOp::create(b, zero, real, fmf); + Value imagSignedZero = math::CopySignOp::create(b, zero, imag, fmf); Value negImagSignedZero = - b.create<arith::MulFOp>(negOne, imagSignedZero, fmf); + arith::MulFOp::create(b, negOne, imagSignedZero, fmf); - Value absReal = b.create<math::AbsFOp>(real, fmf); - Value absImag = b.create<math::AbsFOp>(imag, fmf); + Value absReal = math::AbsFOp::create(b, real, fmf); + Value absImag = math::AbsFOp::create(b, imag, fmf); - Value absImagIsInf = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf); + Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + absImag, inf, fmf); Value realIsNan = - b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf); - Value realIsInf = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf); - Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan); + arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, real, real, fmf); + Value realIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, + absReal, inf, fmf); + Value inIsNanInf = arith::AndIOp::create(b, absImagIsInf, realIsNan); - Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf); + Value resultIsZero = arith::OrIOp::create(b, inIsNanInf, realIsInf); resultReal = - b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal); - resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero, - resultImag); + arith::SelectOp::create(b, resultIsZero, realSignedZero, resultReal); + resultImag = arith::SelectOp::create(b, resultIsZero, negImagSignedZero, + resultImag); } Value isRealZero = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero, fmf); Value isImagZero = - b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf); - Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero); + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf); + Value isZero = arith::AndIOp::create(b, isRealZero, isImagZero); - resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal); - resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag); + resultReal = arith::SelectOp::create(b, isZero, inf, resultReal); + resultImag = arith::SelectOp::create(b, isZero, nan, resultImag); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, resultImag); @@ -1021,9 +1034,9 @@ struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> { arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); Value real = - rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex()); + complex::ReOp::create(rewriter, loc, type, adaptor.getComplex()); Value imag = - rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex()); + complex::ImOp::create(rewriter, loc, type, adaptor.getComplex()); rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf); |