diff options
author | Tina Jung <tinamaria.jung@amd.com> | 2024-06-18 16:56:56 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-18 16:56:56 +0200 |
commit | ffc31d3221e2ebe1f5b1e5c846dcde27cb326616 (patch) | |
tree | 880b771ff139b39ad5db7bc967234d5a42f6142f /mlir | |
parent | 6be6c3a37be46ebefa967b66e398d8ea9ed4ffe8 (diff) | |
download | llvm-ffc31d3221e2ebe1f5b1e5c846dcde27cb326616.zip llvm-ffc31d3221e2ebe1f5b1e5c846dcde27cb326616.tar.gz llvm-ffc31d3221e2ebe1f5b1e5c846dcde27cb326616.tar.bz2 |
[mlir][emitc] arith.negf to EmitC conversion (#95372)
Lower arith.negf to the unary minus in EmitC.
Diffstat (limited to 'mlir')
3 files changed, 64 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 74f0f61..27913df 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -288,6 +288,34 @@ public: } }; +class NegFOpConversion : public OpConversionPattern<arith::NegFOp> { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto adaptedOp = adaptor.getOperand(); + auto adaptedOpType = adaptedOp.getType(); + + if (isa<TensorType>(adaptedOpType) || isa<VectorType>(adaptedOpType)) { + return rewriter.notifyMatchFailure( + op.getLoc(), + "negf currently only supports scalar types, not vectors or tensors"); + } + + if (!emitc::isSupportedFloatType(adaptedOpType)) { + return rewriter.notifyMatchFailure( + op.getLoc(), "floating-point type is not supported by EmitC"); + } + + rewriter.replaceOpWithNewOp<emitc::UnaryMinusOp>(op, adaptedOpType, + adaptedOp); + return success(); + } +}; + template <typename ArithOp, bool castToUnsigned> class CastConversion : public OpConversionPattern<ArithOp> { public: @@ -621,6 +649,7 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>, CmpFOpConversion, CmpIOpConversion, + NegFOpConversion, SelectOpConversion, // Truncation is guaranteed for unsigned types. UnsignedCastConversion<arith::TruncIOp>, diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir index c072891..caef040 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir @@ -81,6 +81,30 @@ func.func @arith_cmpf_tensor(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tens // ----- +func.func @arith_negf_f80(%arg0: f80) -> f80 { + // expected-error @+1 {{failed to legalize operation 'arith.negf'}} + %n = arith.negf %arg0 : f80 + return %n: f80 +} + +// ----- + +func.func @arith_negf_tensor(%arg0: tensor<5xf32>) -> tensor<5xf32> { + // expected-error @+1 {{failed to legalize operation 'arith.negf'}} + %n = arith.negf %arg0 : tensor<5xf32> + return %n: tensor<5xf32> +} + +// ----- + +func.func @arith_negf_vector(%arg0: vector<5xf32>) -> vector<5xf32> { + // expected-error @+1 {{failed to legalize operation 'arith.negf'}} + %n = arith.negf %arg0 : vector<5xf32> + return %n: vector<5xf32> +} + +// ----- + func.func @arith_extsi_i1_to_i32(%arg0: i1) { // expected-error @+1 {{failed to legalize operation 'arith.extsi'}} %idx = arith.extsi %arg0 : i1 to i32 diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 71f1a6a..667ff79 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -422,6 +422,17 @@ func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) { // ----- +func.func @arith_negf(%arg0: f32) -> f32 { + // CHECK-LABEL: arith_negf + // CHECK-SAME: %[[Arg0:[^ ]*]]: f32 + // CHECK: %[[N:[^ ]*]] = emitc.unary_minus %[[Arg0]] : (f32) -> f32 + %n = arith.negf %arg0 : f32 + // CHECK: return %[[N]] + return %n: f32 +} + +// ----- + func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) { // CHECK: emitc.cast %arg0 : f32 to i32 %0 = arith.fptosi %arg0 : f32 to i32 |