diff options
author | lewuathe <lewuathe@me.com> | 2022-06-21 08:29:02 +0900 |
---|---|---|
committer | lewuathe <lewuathe@me.com> | 2022-06-21 08:38:07 +0900 |
commit | 0bae40eff6a7b48e00ab5c8f0fc510823a1ef6a0 (patch) | |
tree | 14f25aef1a5e363a9572348dd0e63035f4e2d4c4 | |
parent | 8c6e138aa893bb88fc3d5d449e42082741f0e2a2 (diff) | |
download | llvm-0bae40eff6a7b48e00ab5c8f0fc510823a1ef6a0.zip llvm-0bae40eff6a7b48e00ab5c8f0fc510823a1ef6a0.tar.gz llvm-0bae40eff6a7b48e00ab5c8f0fc510823a1ef6a0.tar.bz2 |
[mlir][math] Lower cos,sin to libm
Lower math.cos and math.sin to libm
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D128028
-rw-r--r-- | mlir/lib/Conversion/MathToLibm/MathToLibm.cpp | 10 | ||||
-rw-r--r-- | mlir/test/Conversion/MathToLibm/convert-to-libm.mlir | 28 |
2 files changed, 36 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp index 78835e1..d209e8d 100644 --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -141,9 +141,11 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op, void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>, - VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit); + VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>, + VecOpToScalarOp<math::SinOp>>(patterns.getContext(), benefit); patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>, - PromoteOpToF32<math::TanhOp>>(patterns.getContext(), benefit); + PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>, + PromoteOpToF32<math::SinOp>>(patterns.getContext(), benefit); patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(), "atan2f", "atan2", benefit); patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff", @@ -154,6 +156,10 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, "tanh", benefit); patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(), "roundf", "round", benefit); + patterns.add<ScalarOpToLibmCall<math::CosOp>>(patterns.getContext(), "cosf", + "cos", benefit); + patterns.add<ScalarOpToLibmCall<math::SinOp>>(patterns.getContext(), "sinf", + "sin", benefit); } namespace { diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir index cb09988..4028532 100644 --- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir +++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir @@ -10,6 +10,10 @@ // CHECK-DAG: @tanhf(f32) -> f32 // CHECK-DAG: @round(f64) -> f64 // CHECK-DAG: @roundf(f32) -> f32 +// CHECK-DAG: @cos(f64) -> f64 +// CHECK-DAG: @cosf(f32) -> f32 +// CHECK-DAG: @sin(f64) -> f64 +// CHECK-DAG: @sinf(f32) -> f32 // CHECK-LABEL: func @tanh_caller // CHECK-SAME: %[[FLOAT:.*]]: f32 @@ -129,3 +133,27 @@ func.func @round_caller(%float: f32, %double: f64) -> (f32, f64) { // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] return %float_result, %double_result : f32, f64 } + +// CHECK-LABEL: func @cos_caller +// CHECK-SAME: %[[FLOAT:.*]]: f32 +// CHECK-SAME: %[[DOUBLE:.*]]: f64 +func.func @cos_caller(%float: f32, %double: f64) -> (f32, f64) { + // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @cosf(%[[FLOAT]]) : (f32) -> f32 + %float_result = math.cos %float : f32 + // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cos(%[[DOUBLE]]) : (f64) -> f64 + %double_result = math.cos %double : f64 + // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] + return %float_result, %double_result : f32, f64 +} + +// CHECK-LABEL: func @sin_caller +// CHECK-SAME: %[[FLOAT:.*]]: f32 +// CHECK-SAME: %[[DOUBLE:.*]]: f64 +func.func @sin_caller(%float: f32, %double: f64) -> (f32, f64) { + // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @sinf(%[[FLOAT]]) : (f32) -> f32 + %float_result = math.sin %float : f32 + // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @sin(%[[DOUBLE]]) : (f64) -> f64 + %double_result = math.sin %double : f64 + // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] + return %float_result, %double_result : f32, f64 +} |