aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp')
-rw-r--r--mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp707
1 files changed, 360 insertions, 347 deletions
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 0c832c4..5ad514d 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -12,7 +12,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include <type_traits>
@@ -31,44 +30,45 @@ enum class AbsFn { abs, sqrt, rsqrt };
// Returns the absolute value, its square root or its reciprocal square root.
Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
- Value one = b.create<arith::ConstantOp>(real.getType(),
- b.getFloatAttr(real.getType(), 1.0));
+ Value one = arith::ConstantOp::create(b, real.getType(),
+ b.getFloatAttr(real.getType(), 1.0));
- Value absReal = b.create<math::AbsFOp>(real, fmf);
- Value absImag = b.create<math::AbsFOp>(imag, fmf);
+ Value absReal = math::AbsFOp::create(b, real, fmf);
+ Value absImag = math::AbsFOp::create(b, imag, fmf);
- Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
- Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
+ Value max = arith::MaximumFOp::create(b, absReal, absImag, fmf);
+ Value min = arith::MinimumFOp::create(b, absReal, absImag, fmf);
// The lowering below requires NaNs and infinities to work correctly.
arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
- Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf);
- Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf);
- Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf);
+ Value ratio = arith::DivFOp::create(b, min, max, fmfWithNaNInf);
+ Value ratioSq = arith::MulFOp::create(b, ratio, ratio, fmfWithNaNInf);
+ Value ratioSqPlusOne = arith::AddFOp::create(b, ratioSq, one, fmfWithNaNInf);
Value result;
if (fn == AbsFn::rsqrt) {
- ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
- min = b.create<math::RsqrtOp>(min, fmfWithNaNInf);
- max = b.create<math::RsqrtOp>(max, fmfWithNaNInf);
+ ratioSqPlusOne = math::RsqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
+ min = math::RsqrtOp::create(b, min, fmfWithNaNInf);
+ max = math::RsqrtOp::create(b, max, fmfWithNaNInf);
}
if (fn == AbsFn::sqrt) {
- Value quarter = b.create<arith::ConstantOp>(
- real.getType(), b.getFloatAttr(real.getType(), 0.25));
+ Value quarter = arith::ConstantOp::create(
+ b, real.getType(), b.getFloatAttr(real.getType(), 0.25));
// sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
- Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf);
- Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf);
- result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf);
+ Value sqrt = math::SqrtOp::create(b, max, fmfWithNaNInf);
+ Value p025 =
+ math::PowFOp::create(b, ratioSqPlusOne, quarter, fmfWithNaNInf);
+ result = arith::MulFOp::create(b, sqrt, p025, fmfWithNaNInf);
} else {
- Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
- result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf);
+ Value sqrt = math::SqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
+ result = arith::MulFOp::create(b, max, sqrt, fmfWithNaNInf);
}
- Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result,
- result, fmfWithNaNInf);
- return b.create<arith::SelectOp>(isNaN, min, result);
+ Value isNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, result,
+ result, fmfWithNaNInf);
+ return arith::SelectOp::create(b, isNaN, min, result);
}
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
@@ -81,8 +81,8 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
- Value real = b.create<complex::ReOp>(adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(adaptor.getComplex());
+ Value real = complex::ReOp::create(b, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, adaptor.getComplex());
rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
return success();
@@ -105,28 +105,28 @@ struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
Value lhs = adaptor.getLhs();
Value rhs = adaptor.getRhs();
- Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf);
- Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf);
+ Value rhsSquared = complex::MulOp::create(b, type, rhs, rhs, fmf);
+ Value lhsSquared = complex::MulOp::create(b, type, lhs, lhs, fmf);
Value rhsSquaredPlusLhsSquared =
- b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
+ complex::AddOp::create(b, type, rhsSquared, lhsSquared, fmf);
Value sqrtOfRhsSquaredPlusLhsSquared =
- b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
+ complex::SqrtOp::create(b, type, rhsSquaredPlusLhsSquared, fmf);
Value zero =
- b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
- Value one = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 1));
- Value i = b.create<complex::CreateOp>(type, zero, one);
- Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf);
- Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf);
+ arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
+ Value one = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 1));
+ Value i = complex::CreateOp::create(b, type, zero, one);
+ Value iTimesLhs = complex::MulOp::create(b, i, lhs, fmf);
+ Value rhsPlusILhs = complex::AddOp::create(b, rhs, iTimesLhs, fmf);
- Value divResult = b.create<complex::DivOp>(
- rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
- Value logResult = b.create<complex::LogOp>(divResult, fmf);
+ Value divResult = complex::DivOp::create(
+ b, rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
+ Value logResult = complex::LogOp::create(b, divResult, fmf);
- Value negativeOne = b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, -1));
- Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
+ Value negativeOne = arith::ConstantOp::create(
+ b, elementType, b.getFloatAttr(elementType, -1));
+ Value negativeI = complex::CreateOp::create(b, type, zero, negativeOne);
rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
return success();
@@ -146,14 +146,18 @@ struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
auto loc = op.getLoc();
auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
- Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
- Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
- Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
- Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
+ Value realLhs =
+ complex::ReOp::create(rewriter, loc, type, adaptor.getLhs());
+ Value imagLhs =
+ complex::ImOp::create(rewriter, loc, type, adaptor.getLhs());
+ Value realRhs =
+ complex::ReOp::create(rewriter, loc, type, adaptor.getRhs());
+ Value imagRhs =
+ complex::ImOp::create(rewriter, loc, type, adaptor.getRhs());
Value realComparison =
- rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
+ arith::CmpFOp::create(rewriter, loc, p, realLhs, realRhs);
Value imagComparison =
- rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
+ arith::CmpFOp::create(rewriter, loc, p, imagLhs, imagRhs);
rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
imagComparison);
@@ -176,14 +180,14 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
- Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
- Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
- Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
- fmf.getValue());
- Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
- Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
- Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
- fmf.getValue());
+ Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
+ Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
+ Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
+ realRhs, fmf.getValue());
+ Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
+ Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
+ Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
+ imagRhs, fmf.getValue());
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
@@ -205,20 +209,20 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value real =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
Value imag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
// Trigonometric ops use a set of common building blocks to convert to real
// ops. Here we create these building blocks and call into an op-specific
// implementation in the subclass to combine them.
- Value half = rewriter.create<arith::ConstantOp>(
- loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
- Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
- Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
- Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
- Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
- Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
+ Value half = arith::ConstantOp::create(
+ rewriter, loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
+ Value exp = math::ExpOp::create(rewriter, loc, imag, fmf);
+ Value scaledExp = arith::MulFOp::create(rewriter, loc, half, exp, fmf);
+ Value reciprocalExp = arith::DivFOp::create(rewriter, loc, half, exp, fmf);
+ Value sin = math::SinOp::create(rewriter, loc, real, fmf);
+ Value cos = math::CosOp::create(rewriter, loc, real, fmf);
auto resultPair =
combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
@@ -251,11 +255,11 @@ struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
// Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
// Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
Value sum =
- rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
- Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
+ arith::AddFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
+ Value resultReal = arith::MulFOp::create(rewriter, loc, sum, cos, fmf);
Value diff =
- rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
- Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
+ arith::SubFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
+ Value resultImag = arith::MulFOp::create(rewriter, loc, diff, sin, fmf);
return {resultReal, resultImag};
}
};
@@ -275,13 +279,13 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value lhsReal =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getLhs());
Value lhsImag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getLhs());
Value rhsReal =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getRhs());
Value rhsImag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getRhs());
Value resultReal, resultImag;
@@ -318,16 +322,16 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value real =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
Value imag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
- Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
- Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
+ Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue());
+ Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue());
Value resultReal =
- rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
- Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
+ arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue());
+ Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue());
Value resultImag =
- rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
+ arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue());
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
@@ -340,11 +344,11 @@ Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
arith::FastMathFlagsAttr fmf) {
auto argType = mlir::cast<FloatType>(arg.getType());
Value poly =
- b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0]));
+ arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[0]));
for (unsigned i = 1; i < coefficients.size(); ++i) {
- poly = b.create<math::FmaOp>(
- poly, arg,
- b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[i])),
+ poly = math::FmaOp::create(
+ b, poly, arg,
+ arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[i])),
fmf);
}
return poly;
@@ -365,26 +369,26 @@ struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value real = b.create<complex::ReOp>(adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(adaptor.getComplex());
+ Value real = complex::ReOp::create(b, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, adaptor.getComplex());
- Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0));
- Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0));
+ Value zero = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 0.0));
+ Value one = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 1.0));
- Value expm1Real = b.create<math::ExpM1Op>(real, fmf);
- Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf);
+ Value expm1Real = math::ExpM1Op::create(b, real, fmf);
+ Value expReal = arith::AddFOp::create(b, expm1Real, one, fmf);
- Value sinImag = b.create<math::SinOp>(imag, fmf);
+ Value sinImag = math::SinOp::create(b, imag, fmf);
Value cosm1Imag = emitCosm1(imag, fmf, b);
- Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf);
+ Value cosImag = arith::AddFOp::create(b, cosm1Imag, one, fmf);
- Value realResult = b.create<arith::AddFOp>(
- b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
+ Value realResult = arith::AddFOp::create(
+ b, arith::MulFOp::create(b, expm1Real, cosImag, fmf), cosm1Imag, fmf);
- Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
- zero, fmf.getValue());
- Value imagResult = b.create<arith::SelectOp>(
- imagIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf));
+ Value imagIsZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag,
+ zero, fmf.getValue());
+ Value imagResult = arith::SelectOp::create(
+ b, imagIsZero, zero, arith::MulFOp::create(b, expReal, sinImag, fmf));
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
imagResult);
@@ -395,8 +399,8 @@ private:
Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
ImplicitLocOpBuilder &b) const {
auto argType = mlir::cast<FloatType>(arg.getType());
- auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5));
- auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0));
+ auto negHalf = arith::ConstantOp::create(b, b.getFloatAttr(argType, -0.5));
+ auto negOne = arith::ConstantOp::create(b, b.getFloatAttr(argType, -1.0));
// Algorithm copied from cephes cosm1.
SmallVector<double, 7> kCoeffs{
@@ -405,23 +409,23 @@ private:
2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
4.1666666666666666609054E-2,
};
- Value cos = b.create<math::CosOp>(arg, fmf);
- Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf);
+ Value cos = math::CosOp::create(b, arg, fmf);
+ Value forLargeArg = arith::AddFOp::create(b, cos, negOne, fmf);
- Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf);
- Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf);
+ Value argPow2 = arith::MulFOp::create(b, arg, arg, fmf);
+ Value argPow4 = arith::MulFOp::create(b, argPow2, argPow2, fmf);
Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
auto forSmallArg =
- b.create<arith::AddFOp>(b.create<arith::MulFOp>(argPow4, poly, fmf),
- b.create<arith::MulFOp>(negHalf, argPow2, fmf));
+ arith::AddFOp::create(b, arith::MulFOp::create(b, argPow4, poly, fmf),
+ arith::MulFOp::create(b, negHalf, argPow2, fmf));
// (pi/4)^2 is approximately 0.61685
Value piOver4Pow2 =
- b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685));
- Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
- piOver4Pow2, fmf.getValue());
- return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
+ arith::ConstantOp::create(b, b.getFloatAttr(argType, 0.61685));
+ Value cond = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, argPow2,
+ piOver4Pow2, fmf.getValue());
+ return arith::SelectOp::create(b, cond, forLargeArg, forSmallArg);
}
};
@@ -436,13 +440,13 @@ struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(),
- fmf.getValue());
- Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue());
- Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+ Value abs = complex::AbsOp::create(b, elementType, adaptor.getComplex(),
+ fmf.getValue());
+ Value resultReal = math::LogOp::create(b, elementType, abs, fmf.getValue());
+ Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
Value resultImag =
- b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
+ math::Atan2Op::create(b, elementType, imag, real, fmf.getValue());
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
@@ -460,40 +464,42 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value real = b.create<complex::ReOp>(adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(adaptor.getComplex());
+ Value real = complex::ReOp::create(b, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, adaptor.getComplex());
- Value half = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 0.5));
- Value one = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 1));
- Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf);
- Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf);
- Value absImag = b.create<math::AbsFOp>(imag, fmf);
+ Value half = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 0.5));
+ Value one = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 1));
+ Value realPlusOne = arith::AddFOp::create(b, real, one, fmf);
+ Value absRealPlusOne = math::AbsFOp::create(b, realPlusOne, fmf);
+ Value absImag = math::AbsFOp::create(b, imag, fmf);
- Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
- Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
+ Value maxAbs = arith::MaximumFOp::create(b, absRealPlusOne, absImag, fmf);
+ Value minAbs = arith::MinimumFOp::create(b, absRealPlusOne, absImag, fmf);
- Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
- realPlusOne, absImag, fmf);
- Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf);
+ Value useReal = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT,
+ realPlusOne, absImag, fmf);
+ Value maxMinusOne = arith::SubFOp::create(b, maxAbs, one, fmf);
Value maxAbsOfRealPlusOneAndImagMinusOne =
- b.create<arith::SelectOp>(useReal, real, maxMinusOne);
+ arith::SelectOp::create(b, useReal, real, maxMinusOne);
arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
- Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf);
+ Value minMaxRatio = arith::DivFOp::create(b, minAbs, maxAbs, fmfWithNaNInf);
Value logOfMaxAbsOfRealPlusOneAndImag =
- b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
- Value logOfSqrtPart = b.create<math::Log1pOp>(
- b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf),
+ math::Log1pOp::create(b, maxAbsOfRealPlusOneAndImagMinusOne, fmf);
+ Value logOfSqrtPart = math::Log1pOp::create(
+ b, arith::MulFOp::create(b, minMaxRatio, minMaxRatio, fmfWithNaNInf),
fmfWithNaNInf);
- Value r = b.create<arith::AddFOp>(
- b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf),
+ Value r = arith::AddFOp::create(
+ b, arith::MulFOp::create(b, half, logOfSqrtPart, fmfWithNaNInf),
logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
- Value resultReal = b.create<arith::SelectOp>(
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf),
+ Value resultReal = arith::SelectOp::create(
+ b,
+ arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, r, r,
+ fmfWithNaNInf),
minAbs, r);
- Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
+ Value resultImag = math::Atan2Op::create(b, imag, realPlusOne, fmf);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
@@ -511,22 +517,22 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
auto elementType = cast<FloatType>(type.getElementType());
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
auto fmfValue = fmf.getValue();
- Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
- Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
- Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
- Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
+ Value lhsReal = complex::ReOp::create(b, elementType, adaptor.getLhs());
+ Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs());
+ Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs());
+ Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs());
Value lhsRealTimesRhsReal =
- b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
+ arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
Value lhsImagTimesRhsImag =
- b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
- Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
- lhsImagTimesRhsImag, fmfValue);
+ arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
+ Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
+ lhsImagTimesRhsImag, fmfValue);
Value lhsImagTimesRhsReal =
- b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
+ arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
Value lhsRealTimesRhsImag =
- b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
- Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
- lhsRealTimesRhsImag, fmfValue);
+ arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
+ Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
+ lhsRealTimesRhsImag, fmfValue);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
return success();
}
@@ -543,11 +549,11 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
auto elementType = cast<FloatType>(type.getElementType());
Value real =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
Value imag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
- Value negReal = rewriter.create<arith::NegFOp>(loc, real);
- Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
+ Value negReal = arith::NegFOp::create(rewriter, loc, real);
+ Value negImag = arith::NegFOp::create(rewriter, loc, imag);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
return success();
}
@@ -570,11 +576,11 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
// Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
// Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
Value sum =
- rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
- Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
+ arith::AddFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
+ Value resultReal = arith::MulFOp::create(rewriter, loc, sum, sin, fmf);
Value diff =
- rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
- Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf);
+ arith::SubFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
+ Value resultImag = arith::MulFOp::create(rewriter, loc, diff, cos, fmf);
return {resultReal, resultImag};
}
};
@@ -593,64 +599,65 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
auto cst = [&](APFloat v) {
- return b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, v));
+ return arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, v));
};
const auto &floatSemantics = elementType.getFloatSemantics();
Value zero = cst(APFloat::getZero(floatSemantics));
- Value half = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 0.5));
+ Value half = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 0.5));
- Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+ Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
- Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
- Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
- Value cos = b.create<math::CosOp>(sqrtArg, fmf);
- Value sin = b.create<math::SinOp>(sqrtArg, fmf);
+ Value argArg = math::Atan2Op::create(b, imag, real, fmf);
+ Value sqrtArg = arith::MulFOp::create(b, argArg, half, fmf);
+ Value cos = math::CosOp::create(b, sqrtArg, fmf);
+ Value sin = math::SinOp::create(b, sqrtArg, fmf);
// sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
// 0 * inf.
Value sinIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, sin, zero, fmf);
- Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf);
- Value resultImag = b.create<arith::SelectOp>(
- sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf));
+ Value resultReal = arith::MulFOp::create(b, absSqrt, cos, fmf);
+ Value resultImag = arith::SelectOp::create(
+ b, sinIsZero, zero, arith::MulFOp::create(b, absSqrt, sin, fmf));
if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
arith::FastMathFlags::ninf)) {
Value inf = cst(APFloat::getInf(floatSemantics));
Value negInf = cst(APFloat::getInf(floatSemantics, true));
Value nan = cst(APFloat::getNaN(floatSemantics));
- Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf);
+ Value absImag = math::AbsFOp::create(b, elementType, imag, fmf);
- Value absImagIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
- Value absImagIsNotInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
+ Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ absImag, inf, fmf);
+ Value absImagIsNotInf = arith::CmpFOp::create(
+ b, arith::CmpFPredicate::ONE, absImag, inf, fmf);
Value realIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
- Value realIsNegInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, inf, fmf);
+ Value realIsNegInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ real, negInf, fmf);
- resultReal = b.create<arith::SelectOp>(
- b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
+ resultReal = arith::SelectOp::create(
+ b, arith::AndIOp::create(b, realIsNegInf, absImagIsNotInf), zero,
resultReal);
- resultReal = b.create<arith::SelectOp>(
- b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
+ resultReal = arith::SelectOp::create(
+ b, arith::OrIOp::create(b, absImagIsInf, realIsInf), inf, resultReal);
- Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf);
- resultImag = b.create<arith::SelectOp>(
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
+ Value imagSignInf = math::CopySignOp::create(b, inf, imag, fmf);
+ resultImag = arith::SelectOp::create(
+ b,
+ arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, absSqrt, absSqrt),
nan, resultImag);
- resultImag = b.create<arith::SelectOp>(
- b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
+ resultImag = arith::SelectOp::create(
+ b, arith::OrIOp::create(b, absImagIsInf, realIsNegInf), imagSignInf,
resultImag);
}
Value resultIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
- resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal);
- resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
+ resultReal = arith::SelectOp::create(b, resultIsZero, zero, resultReal);
+ resultImag = arith::SelectOp::create(b, resultIsZero, zero, resultImag);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
@@ -669,19 +676,20 @@ struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
- Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+ Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
Value zero =
- b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+ arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
Value realIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero);
Value imagIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
- Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
- auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
- Value realSign = b.create<arith::DivFOp>(real, abs, fmf);
- Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf);
- Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero);
+ Value isZero = arith::AndIOp::create(b, realIsZero, imagIsZero);
+ auto abs =
+ complex::AbsOp::create(b, elementType, adaptor.getComplex(), fmf);
+ Value realSign = arith::DivFOp::create(b, real, abs, fmf);
+ Value imagSign = arith::DivFOp::create(b, imag, abs, fmf);
+ Value sign = complex::CreateOp::create(b, type, realSign, imagSign);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
adaptor.getComplex(), sign);
return success();
@@ -703,84 +711,84 @@ struct TanTanhOpConversion : public OpConversionPattern<Op> {
const auto &floatSemantics = elementType.getFloatSemantics();
Value real =
- b.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+ complex::ReOp::create(b, loc, elementType, adaptor.getComplex());
Value imag =
- b.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
- Value negOne = b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, -1.0));
+ complex::ImOp::create(b, loc, elementType, adaptor.getComplex());
+ Value negOne = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, -1.0));
if constexpr (std::is_same_v<Op, complex::TanOp>) {
// tan(x+yi) = -i*tanh(-y + xi)
std::swap(real, imag);
- real = b.create<arith::MulFOp>(real, negOne, fmf);
+ real = arith::MulFOp::create(b, real, negOne, fmf);
}
auto cst = [&](APFloat v) {
- return b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, v));
+ return arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, v));
};
Value inf = cst(APFloat::getInf(floatSemantics));
- Value four = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 4.0));
- Value twoReal = b.create<arith::AddFOp>(real, real, fmf);
- Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf);
-
- Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf);
- Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf);
- Value realNum =
- b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
-
- Value cosImag = b.create<math::CosOp>(imag, fmf);
- Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf);
- Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf);
- Value sinImag = b.create<math::SinOp>(imag, fmf);
-
- Value imagNum = b.create<arith::MulFOp>(
- four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
-
- Value expSumMinusTwo =
- b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
+ Value four = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 4.0));
+ Value twoReal = arith::AddFOp::create(b, real, real, fmf);
+ Value negTwoReal = arith::MulFOp::create(b, negOne, twoReal, fmf);
+
+ Value expTwoRealMinusOne = math::ExpM1Op::create(b, twoReal, fmf);
+ Value expNegTwoRealMinusOne = math::ExpM1Op::create(b, negTwoReal, fmf);
+ Value realNum = arith::SubFOp::create(b, expTwoRealMinusOne,
+ expNegTwoRealMinusOne, fmf);
+
+ Value cosImag = math::CosOp::create(b, imag, fmf);
+ Value cosImagSq = arith::MulFOp::create(b, cosImag, cosImag, fmf);
+ Value twoCosTwoImagPlusOne = arith::MulFOp::create(b, cosImagSq, four, fmf);
+ Value sinImag = math::SinOp::create(b, imag, fmf);
+
+ Value imagNum = arith::MulFOp::create(
+ b, four, arith::MulFOp::create(b, cosImag, sinImag, fmf), fmf);
+
+ Value expSumMinusTwo = arith::AddFOp::create(b, expTwoRealMinusOne,
+ expNegTwoRealMinusOne, fmf);
Value denom =
- b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
+ arith::AddFOp::create(b, expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
- Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
- expSumMinusTwo, inf, fmf);
- Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf);
+ Value isInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ expSumMinusTwo, inf, fmf);
+ Value realLimit = math::CopySignOp::create(b, negOne, real, fmf);
- Value resultReal = b.create<arith::SelectOp>(
- isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf));
- Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf);
+ Value resultReal = arith::SelectOp::create(
+ b, isInf, realLimit, arith::DivFOp::create(b, realNum, denom, fmf));
+ Value resultImag = arith::DivFOp::create(b, imagNum, denom, fmf);
if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
arith::FastMathFlags::ninf)) {
- Value absReal = b.create<math::AbsFOp>(real, fmf);
- Value zero = b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, 0.0));
+ Value absReal = math::AbsFOp::create(b, real, fmf);
+ Value zero = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, 0.0));
Value nan = cst(APFloat::getNaN(floatSemantics));
- Value absRealIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
+ Value absRealIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ absReal, inf, fmf);
Value imagIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
- Value absRealIsNotInf = b.create<arith::XOrIOp>(
- absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1));
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
+ Value absRealIsNotInf = arith::XOrIOp::create(
+ b, absRealIsInf, arith::ConstantIntOp::create(b, true, /*width=*/1));
- Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
- imagNum, imagNum, fmf);
+ Value imagNumIsNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO,
+ imagNum, imagNum, fmf);
Value resultRealIsNaN =
- b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
- Value resultImagIsZero = b.create<arith::OrIOp>(
- imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
+ arith::AndIOp::create(b, imagNumIsNaN, absRealIsNotInf);
+ Value resultImagIsZero = arith::OrIOp::create(
+ b, imagIsZero, arith::AndIOp::create(b, absRealIsInf, imagNumIsNaN));
- resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
+ resultReal = arith::SelectOp::create(b, resultRealIsNaN, nan, resultReal);
resultImag =
- b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
+ arith::SelectOp::create(b, resultImagIsZero, zero, resultImag);
}
if constexpr (std::is_same_v<Op, complex::TanOp>) {
// tan(x+yi) = -i*tanh(-y + xi)
std::swap(resultReal, resultImag);
- resultImag = b.create<arith::MulFOp>(resultImag, negOne, fmf);
+ resultImag = arith::MulFOp::create(b, resultImag, negOne, fmf);
}
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
@@ -799,10 +807,10 @@ struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
auto type = cast<ComplexType>(adaptor.getComplex().getType());
auto elementType = cast<FloatType>(type.getElementType());
Value real =
- rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+ complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
Value imag =
- rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
- Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
+ complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
+ Value negImag = arith::NegFOp::create(rewriter, loc, elementType, imag);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
@@ -818,97 +826,102 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
arith::FastMathFlags fmf) {
auto elementType = cast<FloatType>(type.getElementType());
- Value a = builder.create<complex::ReOp>(lhs);
- Value b = builder.create<complex::ImOp>(lhs);
+ Value a = complex::ReOp::create(builder, lhs);
+ Value b = complex::ImOp::create(builder, lhs);
- Value abs = builder.create<complex::AbsOp>(lhs, fmf);
- Value absToC = builder.create<math::PowFOp>(abs, c, fmf);
+ Value abs = complex::AbsOp::create(builder, lhs, fmf);
+ Value absToC = math::PowFOp::create(builder, abs, c, fmf);
- Value negD = builder.create<arith::NegFOp>(d, fmf);
- Value argLhs = builder.create<math::Atan2Op>(b, a, fmf);
- Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf);
- Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf);
+ Value negD = arith::NegFOp::create(builder, d, fmf);
+ Value argLhs = math::Atan2Op::create(builder, b, a, fmf);
+ Value negDArgLhs = arith::MulFOp::create(builder, negD, argLhs, fmf);
+ Value expNegDArgLhs = math::ExpOp::create(builder, negDArgLhs, fmf);
- Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
- Value lnAbs = builder.create<math::LogOp>(abs, fmf);
- Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf);
- Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf);
- Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
- Value cosQ = builder.create<math::CosOp>(q, fmf);
- Value sinQ = builder.create<math::SinOp>(q, fmf);
+ Value coeff = arith::MulFOp::create(builder, absToC, expNegDArgLhs, fmf);
+ Value lnAbs = math::LogOp::create(builder, abs, fmf);
+ Value cArgLhs = arith::MulFOp::create(builder, c, argLhs, fmf);
+ Value dLnAbs = arith::MulFOp::create(builder, d, lnAbs, fmf);
+ Value q = arith::AddFOp::create(builder, cArgLhs, dLnAbs, fmf);
+ Value cosQ = math::CosOp::create(builder, q, fmf);
+ Value sinQ = math::SinOp::create(builder, q, fmf);
- Value inf = builder.create<arith::ConstantOp>(
- elementType,
+ Value inf = arith::ConstantOp::create(
+ builder, elementType,
builder.getFloatAttr(elementType,
APFloat::getInf(elementType.getFloatSemantics())));
- Value zero = builder.create<arith::ConstantOp>(
- elementType, builder.getFloatAttr(elementType, 0.0));
- Value one = builder.create<arith::ConstantOp>(
- elementType, builder.getFloatAttr(elementType, 1.0));
- Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
- Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
- Value complexInf = builder.create<complex::CreateOp>(type, inf, zero);
+ Value zero = arith::ConstantOp::create(
+ builder, elementType, builder.getFloatAttr(elementType, 0.0));
+ Value one = arith::ConstantOp::create(builder, elementType,
+ builder.getFloatAttr(elementType, 1.0));
+ Value complexOne = complex::CreateOp::create(builder, type, one, zero);
+ Value complexZero = complex::CreateOp::create(builder, type, zero, zero);
+ Value complexInf = complex::CreateOp::create(builder, type, inf, zero);
// Case 0:
// d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
// Branch Cuts for Complex Elementary Functions or Much Ado About
// Nothing's Sign Bit, W. Kahan, Section 10.
Value absEqZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, abs, zero, fmf);
Value dEqZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, d, zero, fmf);
Value cEqZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, c, zero, fmf);
Value bEqZero =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, b, zero, fmf);
Value zeroLeC =
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
- Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf);
- Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf);
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OLE, zero, c, fmf);
+ Value coeffCosQ = arith::MulFOp::create(builder, coeff, cosQ, fmf);
+ Value coeffSinQ = arith::MulFOp::create(builder, coeff, sinQ, fmf);
Value complexOneOrZero =
- builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero);
+ arith::SelectOp::create(builder, cEqZero, complexOne, complexZero);
Value coeffCosSin =
- builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
- Value cutoff0 = builder.create<arith::SelectOp>(
- builder.create<arith::AndIOp>(
- builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
+ complex::CreateOp::create(builder, type, coeffCosQ, coeffSinQ);
+ Value cutoff0 = arith::SelectOp::create(
+ builder,
+ arith::AndIOp::create(
+ builder, arith::AndIOp::create(builder, absEqZero, dEqZero), zeroLeC),
complexOneOrZero, coeffCosSin);
// Case 1:
// x^0 is defined to be 1 for any x, see
// Branch Cuts for Complex Elementary Functions or Much Ado About
// Nothing's Sign Bit, W. Kahan, Section 10.
- Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero);
+ Value rhsEqZero = arith::AndIOp::create(builder, cEqZero, dEqZero);
Value cutoff1 =
- builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);
+ arith::SelectOp::create(builder, rhsEqZero, complexOne, cutoff0);
// Case 2:
// 1^(c + d*i) = 1 + 0*i
- Value lhsEqOne = builder.create<arith::AndIOp>(
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
+ Value lhsEqOne = arith::AndIOp::create(
+ builder,
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, one, fmf),
bEqZero);
Value cutoff2 =
- builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
+ arith::SelectOp::create(builder, lhsEqOne, complexOne, cutoff1);
// Case 3:
// inf^(c + 0*i) = inf + 0*i, c > 0
- Value lhsEqInf = builder.create<arith::AndIOp>(
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
+ Value lhsEqInf = arith::AndIOp::create(
+ builder,
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, inf, fmf),
bEqZero);
- Value rhsGt0 = builder.create<arith::AndIOp>(
- dEqZero,
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
- Value cutoff3 = builder.create<arith::SelectOp>(
- builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
+ Value rhsGt0 = arith::AndIOp::create(
+ builder, dEqZero,
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, c, zero, fmf));
+ Value cutoff3 = arith::SelectOp::create(
+ builder, arith::AndIOp::create(builder, lhsEqInf, rhsGt0), complexInf,
+ cutoff2);
// Case 4:
// inf^(c + 0*i) = 0 + 0*i, c < 0
- Value rhsLt0 = builder.create<arith::AndIOp>(
- dEqZero,
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
- Value cutoff4 = builder.create<arith::SelectOp>(
- builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
+ Value rhsLt0 = arith::AndIOp::create(
+ builder, dEqZero,
+ arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, c, zero, fmf));
+ Value cutoff4 = arith::SelectOp::create(
+ builder, arith::AndIOp::create(builder, lhsEqInf, rhsLt0), complexZero,
+ cutoff3);
return cutoff4;
}
@@ -923,8 +936,8 @@ struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
auto type = cast<ComplexType>(adaptor.getLhs().getType());
auto elementType = cast<FloatType>(type.getElementType());
- Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
- Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
+ Value c = complex::ReOp::create(builder, elementType, adaptor.getRhs());
+ Value d = complex::ImOp::create(builder, elementType, adaptor.getRhs());
rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
c, d, op.getFastmath())});
@@ -945,64 +958,64 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
auto cst = [&](APFloat v) {
- return b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, v));
+ return arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, v));
};
const auto &floatSemantics = elementType.getFloatSemantics();
Value zero = cst(APFloat::getZero(floatSemantics));
Value inf = cst(APFloat::getInf(floatSemantics));
- Value negHalf = b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, -0.5));
+ Value negHalf = arith::ConstantOp::create(
+ b, elementType, b.getFloatAttr(elementType, -0.5));
Value nan = cst(APFloat::getNaN(floatSemantics));
- Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+ Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
+ Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
- Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
- Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
- Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
- Value sin = b.create<math::SinOp>(rsqrtArg, fmf);
+ Value argArg = math::Atan2Op::create(b, imag, real, fmf);
+ Value rsqrtArg = arith::MulFOp::create(b, argArg, negHalf, fmf);
+ Value cos = math::CosOp::create(b, rsqrtArg, fmf);
+ Value sin = math::SinOp::create(b, rsqrtArg, fmf);
- Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
- Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);
+ Value resultReal = arith::MulFOp::create(b, absRsqrt, cos, fmf);
+ Value resultImag = arith::MulFOp::create(b, absRsqrt, sin, fmf);
if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
arith::FastMathFlags::ninf)) {
- Value negOne = b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, -1));
+ Value negOne = arith::ConstantOp::create(b, elementType,
+ b.getFloatAttr(elementType, -1));
- Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
- Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
+ Value realSignedZero = math::CopySignOp::create(b, zero, real, fmf);
+ Value imagSignedZero = math::CopySignOp::create(b, zero, imag, fmf);
Value negImagSignedZero =
- b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);
+ arith::MulFOp::create(b, negOne, imagSignedZero, fmf);
- Value absReal = b.create<math::AbsFOp>(real, fmf);
- Value absImag = b.create<math::AbsFOp>(imag, fmf);
+ Value absReal = math::AbsFOp::create(b, real, fmf);
+ Value absImag = math::AbsFOp::create(b, imag, fmf);
- Value absImagIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
+ Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ absImag, inf, fmf);
Value realIsNan =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
- Value realIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
- Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, real, real, fmf);
+ Value realIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
+ absReal, inf, fmf);
+ Value inIsNanInf = arith::AndIOp::create(b, absImagIsInf, realIsNan);
- Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);
+ Value resultIsZero = arith::OrIOp::create(b, inIsNanInf, realIsInf);
resultReal =
- b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
- resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
- resultImag);
+ arith::SelectOp::create(b, resultIsZero, realSignedZero, resultReal);
+ resultImag = arith::SelectOp::create(b, resultIsZero, negImagSignedZero,
+ resultImag);
}
Value isRealZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero, fmf);
Value isImagZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
- Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
+ Value isZero = arith::AndIOp::create(b, isRealZero, isImagZero);
- resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
- resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);
+ resultReal = arith::SelectOp::create(b, isZero, inf, resultReal);
+ resultImag = arith::SelectOp::create(b, isZero, nan, resultImag);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
@@ -1021,9 +1034,9 @@ struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value real =
- rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
+ complex::ReOp::create(rewriter, loc, type, adaptor.getComplex());
Value imag =
- rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
+ complex::ImOp::create(rewriter, loc, type, adaptor.getComplex());
rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);