diff options
| author | Jianjian Guan <jacquesguan@me.com> | 2025-01-16 14:57:43 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-16 14:57:43 +0800 |
| commit | f9a80062470daf94e07f65f9dd23df6a4f2946a2 (patch) | |
| tree | 0e0785ee11da001c487b46cae4c30c3aef23bd25 | |
| parent | 0195ec452e16a0ff4b4f4ff2e2ea5a1dd5a20563 (diff) | |
| download | llvm-f9a80062470daf94e07f65f9dd23df6a4f2946a2.zip llvm-f9a80062470daf94e07f65f9dd23df6a4f2946a2.tar.gz llvm-f9a80062470daf94e07f65f9dd23df6a4f2946a2.tar.bz2 | |
[mlir][emitc] Support convert arith.extf and arith.truncf to emitc (#121184)
3 files changed, 106 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index ccbc166..359d7b2 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -733,6 +733,43 @@ public: } }; +// Floating-point to floating-point conversions. +template <typename CastOp> +class FpCastOpConversion : public OpConversionPattern<CastOp> { +public: + FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern<CastOp>(typeConverter, context) {} + + LogicalResult + matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Vectors in particular are not supported. + Type operandType = adaptor.getIn().getType(); + if (!emitc::isSupportedFloatType(operandType)) + return rewriter.notifyMatchFailure(castOp, + "unsupported cast source type"); + if (auto roundingModeOp = + dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) { + // Only supporting default rounding mode as of now. + if (roundingModeOp.getRoundingModeAttr()) + return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode"); + } + + Type dstType = this->getTypeConverter()->convertType(castOp.getType()); + if (!dstType) + return rewriter.notifyMatchFailure(castOp, "type conversion failed"); + + if (!emitc::isSupportedFloatType(dstType)) + return rewriter.notifyMatchFailure(castOp, + "unsupported cast destination type"); + + Value fpCastOperand = adaptor.getIn(); + rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand); + + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -778,7 +815,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, ItoFCastOpConversion<arith::SIToFPOp>, ItoFCastOpConversion<arith::UIToFPOp>, FtoICastOpConversion<arith::FPToSIOp>, - FtoICastOpConversion<arith::FPToUIOp> + FtoICastOpConversion<arith::FPToUIOp>, + FpCastOpConversion<arith::ExtFOp>, + FpCastOpConversion<arith::TruncFOp> >(typeConverter, ctx); // clang-format on } diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir index b866904..9850f33 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir @@ -149,3 +149,43 @@ func.func @arith_remui_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) -> vec %divui = arith.remui %arg0, %arg1 : vector<5xi32> return %divui: vector<5xi32> } + +// ----- + +func.func @arith_truncf(%arg0: f64) -> f32 { + // expected-error @+1 {{failed to legalize operation 'arith.truncf'}} + %truncd = arith.truncf %arg0 to_nearest_away : f64 to f32 + return %truncd : f32 +} + +// ----- + +func.func @arith_extf_f128(%arg0: f32) -> f128 { + // expected-error @+1 {{failed to legalize operation 'arith.extf'}} + %extd = arith.extf %arg0 : f32 to f128 + return %extd : f128 +} + +// ----- + +func.func @arith_truncf_f128(%arg0: f128) -> f32 { + // expected-error @+1 {{failed to legalize operation 'arith.truncf'}} + %truncd = arith.truncf %arg0 : f128 to f32 + return %truncd : f32 +} + +// ----- + +func.func @arith_extf_vector(%arg0: vector<4xf32>) -> vector<4xf64> { + // expected-error @+1 {{failed to legalize operation 'arith.extf'}} + %extd = arith.extf %arg0 : vector<4xf32> to vector<4xf64> + return %extd : vector<4xf64> +} + +// ----- + +func.func @arith_truncf_vector(%arg0: vector<4xf64>) -> vector<4xf32> { + // expected-error @+1 {{failed to legalize operation 'arith.truncf'}} + %truncd = arith.truncf %arg0 : vector<4xf64> to vector<4xf32> + return %truncd : vector<4xf32> +} diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 1728c3a..4e3d108 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -739,3 +739,29 @@ func.func @arith_divui_remui(%arg0: i32, %arg1: i32) -> i32 { return %div : i32 } + +// ----- + +func.func @arith_extf(%arg0: f16) -> f64 { + // CHECK-LABEL: arith_extf + // CHECK-SAME: (%[[Arg0:[^ ]*]]: f16) + // CHECK: %[[Extd0:.*]] = emitc.cast %[[Arg0]] : f16 to f32 + %extd0 = arith.extf %arg0 : f16 to f32 + // CHECK: %[[Extd1:.*]] = emitc.cast %[[Extd0]] : f32 to f64 + %extd1 = arith.extf %extd0 : f32 to f64 + + return %extd1 : f64 +} + +// ----- + +func.func @arith_truncf(%arg0: f64) -> f16 { + // CHECK-LABEL: arith_truncf + // CHECK-SAME: (%[[Arg0:[^ ]*]]: f64) + // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32 + %truncd0 = arith.truncf %arg0 : f64 to f32 + // CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16 + %truncd1 = arith.truncf %truncd0 : f32 to f16 + + return %truncd1 : f16 +} |
