aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorlewuathe <lewuathe@me.com>2022-06-21 08:29:02 +0900
committerlewuathe <lewuathe@me.com>2022-06-21 08:38:07 +0900
commit0bae40eff6a7b48e00ab5c8f0fc510823a1ef6a0 (patch)
tree14f25aef1a5e363a9572348dd0e63035f4e2d4c4
parent8c6e138aa893bb88fc3d5d449e42082741f0e2a2 (diff)
downloadllvm-0bae40eff6a7b48e00ab5c8f0fc510823a1ef6a0.zip
llvm-0bae40eff6a7b48e00ab5c8f0fc510823a1ef6a0.tar.gz
llvm-0bae40eff6a7b48e00ab5c8f0fc510823a1ef6a0.tar.bz2
[mlir][math] Lower cos,sin to libm
Lower math.cos and math.sin to libm Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D128028
-rw-r--r--mlir/lib/Conversion/MathToLibm/MathToLibm.cpp10
-rw-r--r--mlir/test/Conversion/MathToLibm/convert-to-libm.mlir28
2 files changed, 36 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 78835e1..d209e8d 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -141,9 +141,11 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
- VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
+ VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
+ VecOpToScalarOp<math::SinOp>>(patterns.getContext(), benefit);
patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
- PromoteOpToF32<math::TanhOp>>(patterns.getContext(), benefit);
+ PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>,
+ PromoteOpToF32<math::SinOp>>(patterns.getContext(), benefit);
patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
"atan2f", "atan2", benefit);
patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
@@ -154,6 +156,10 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
"tanh", benefit);
patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(),
"roundf", "round", benefit);
+ patterns.add<ScalarOpToLibmCall<math::CosOp>>(patterns.getContext(), "cosf",
+ "cos", benefit);
+ patterns.add<ScalarOpToLibmCall<math::SinOp>>(patterns.getContext(), "sinf",
+ "sin", benefit);
}
namespace {
diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index cb09988..4028532 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -10,6 +10,10 @@
// CHECK-DAG: @tanhf(f32) -> f32
// CHECK-DAG: @round(f64) -> f64
// CHECK-DAG: @roundf(f32) -> f32
+// CHECK-DAG: @cos(f64) -> f64
+// CHECK-DAG: @cosf(f32) -> f32
+// CHECK-DAG: @sin(f64) -> f64
+// CHECK-DAG: @sinf(f32) -> f32
// CHECK-LABEL: func @tanh_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
@@ -129,3 +133,27 @@ func.func @round_caller(%float: f32, %double: f64) -> (f32, f64) {
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
return %float_result, %double_result : f32, f64
}
+
+// CHECK-LABEL: func @cos_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @cos_caller(%float: f32, %double: f64) -> (f32, f64) {
+ // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @cosf(%[[FLOAT]]) : (f32) -> f32
+ %float_result = math.cos %float : f32
+ // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cos(%[[DOUBLE]]) : (f64) -> f64
+ %double_result = math.cos %double : f64
+ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+ return %float_result, %double_result : f32, f64
+}
+
+// CHECK-LABEL: func @sin_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @sin_caller(%float: f32, %double: f64) -> (f32, f64) {
+ // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @sinf(%[[FLOAT]]) : (f32) -> f32
+ %float_result = math.sin %float : f32
+ // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @sin(%[[DOUBLE]]) : (f64) -> f64
+ %double_result = math.sin %double : f64
+ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+ return %float_result, %double_result : f32, f64
+}