aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorKai Sasaki <lewuathe@gmail.com>2024-03-25 10:59:42 +0900
committerGitHub <noreply@github.com>2024-03-25 10:59:42 +0900
commit7d2d8e2a7245e4e64da22cb3c422ea3be5a0bf0a (patch)
treef970950812bbc3afcab1d580333bde34b363a7ad /mlir
parent230b1895c493c511c11541af3b5bc819887c82a8 (diff)
downloadllvm-7d2d8e2a7245e4e64da22cb3c422ea3be5a0bf0a.zip
llvm-7d2d8e2a7245e4e64da22cb3c422ea3be5a0bf0a.tar.gz
llvm-7d2d8e2a7245e4e64da22cb3c422ea3be5a0bf0a.tar.bz2
[mlir][complex] Fastmath flag for the trigonometric ops in complex (#85563)
Support Fastmath flag to convert trigonometric ops in the complex dialect. See: https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp50
-rw-r--r--mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir46
2 files changed, 75 insertions, 21 deletions
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 7672927..17f64f1 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -196,6 +196,7 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
auto loc = op.getLoc();
auto type = cast<ComplexType>(adaptor.getComplex().getType());
auto elementType = cast<FloatType>(type.getElementType());
+ arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
Value real =
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
@@ -207,14 +208,14 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
// implementation in the subclass to combine them.
Value half = rewriter.create<arith::ConstantOp>(
loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
- Value exp = rewriter.create<math::ExpOp>(loc, imag);
- Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
- Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
- Value sin = rewriter.create<math::SinOp>(loc, real);
- Value cos = rewriter.create<math::CosOp>(loc, real);
+ Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
+ Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
+ Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
+ Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
+ Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
auto resultPair =
- combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
+ combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
resultPair.second);
@@ -223,15 +224,17 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
virtual std::pair<Value, Value>
combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
- Value cos, ConversionPatternRewriter &rewriter) const = 0;
+ Value cos, ConversionPatternRewriter &rewriter,
+ arith::FastMathFlagsAttr fmf) const = 0;
};
struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
- std::pair<Value, Value>
- combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
- Value cos, ConversionPatternRewriter &rewriter) const override {
+ std::pair<Value, Value> combine(Location loc, Value scaledExp,
+ Value reciprocalExp, Value sin, Value cos,
+ ConversionPatternRewriter &rewriter,
+ arith::FastMathFlagsAttr fmf) const override {
// Complex cosine is defined as;
// cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
// Plugging in:
@@ -241,10 +244,12 @@ struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
// We get:
// Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
// Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
- Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
- Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
- Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
- Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
+ Value sum =
+ rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
+ Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
+ Value diff =
+ rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
+ Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
return {resultReal, resultImag};
}
};
@@ -813,9 +818,10 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
- std::pair<Value, Value>
- combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
- Value cos, ConversionPatternRewriter &rewriter) const override {
+ std::pair<Value, Value> combine(Location loc, Value scaledExp,
+ Value reciprocalExp, Value sin, Value cos,
+ ConversionPatternRewriter &rewriter,
+ arith::FastMathFlagsAttr fmf) const override {
// Complex sine is defined as;
// sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
// Plugging in:
@@ -825,10 +831,12 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
// We get:
// Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
// Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
- Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
- Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
- Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
- Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
+ Value sum =
+ rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
+ Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
+ Value diff =
+ rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
+ Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf);
return {resultReal, resultImag};
}
};
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 5918ff2..bac94aa 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1834,3 +1834,49 @@ func.func @complex_sqrt_with_fmf(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[VAR40:.*]] = arith.select %[[VAR38]], %cst, %[[VAR32]] : f32
// CHECK: %[[VAR41:.*]] = complex.create %[[VAR39]], %[[VAR40]] : complex<f32>
// CHECK: return %[[VAR41]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_cos_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_cos_with_fmf(%arg: complex<f32>) -> complex<f32> {
+ %cos = complex.cos %arg fastmath<nnan,contract> : complex<f32>
+ return %cos : complex<f32>
+}
+// CHECK-DAG: %[[REAL:.*]] = complex.re %[[ARG]]
+// CHECK-DAG: %[[IMAG:.*]] = complex.im %[[ARG]]
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[EXP:.*]] = math.exp %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[HALF_EXP:.*]] = arith.mulf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[SIN:.*]] = math.sin %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[COS:.*]] = math.cos %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_SUM:.*]] = arith.addf %[[HALF_REXP]], %[[HALF_EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[COS]] fastmath<nnan,contract>
+// CHECK-DAG: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_REXP]], %[[HALF_EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[SIN]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]]
+
+// -----
+
+// CHECK-LABEL: func @complex_sin_with_fmf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_sin_with_fmf(%arg: complex<f32>) -> complex<f32> {
+ %cos = complex.sin %arg fastmath<nnan,contract> : complex<f32>
+ return %cos : complex<f32>
+}
+// CHECK-DAG: %[[REAL:.*]] = complex.re %[[ARG]]
+// CHECK-DAG: %[[IMAG:.*]] = complex.im %[[ARG]]
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[EXP:.*]] = math.exp %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[HALF_EXP:.*]] = arith.mulf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[SIN:.*]] = math.sin %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[COS:.*]] = math.cos %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_SUM:.*]] = arith.addf %[[HALF_EXP]], %[[HALF_REXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[SIN]] fastmath<nnan,contract>
+// CHECK-DAG: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_EXP]], %[[HALF_REXP]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[COS]] fastmath<nnan,contract>
+// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]]