aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJianjian Guan <jacquesguan@me.com>2025-01-16 14:57:43 +0800
committerGitHub <noreply@github.com>2025-01-16 14:57:43 +0800
commitf9a80062470daf94e07f65f9dd23df6a4f2946a2 (patch)
tree0e0785ee11da001c487b46cae4c30c3aef23bd25
parent0195ec452e16a0ff4b4f4ff2e2ea5a1dd5a20563 (diff)
downloadllvm-f9a80062470daf94e07f65f9dd23df6a4f2946a2.zip
llvm-f9a80062470daf94e07f65f9dd23df6a4f2946a2.tar.gz
llvm-f9a80062470daf94e07f65f9dd23df6a4f2946a2.tar.bz2
[mlir][emitc] Support convert arith.extf and arith.truncf to emitc (#121184)
-rw-r--r--mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp41
-rw-r--r--mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir40
-rw-r--r--mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir26
3 files changed, 106 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index ccbc166..359d7b2 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -733,6 +733,43 @@ public:
}
};
+// Floating-point to floating-point conversions.
+template <typename CastOp>
+class FpCastOpConversion : public OpConversionPattern<CastOp> {
+public:
+ FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
+ : OpConversionPattern<CastOp>(typeConverter, context) {}
+
+ LogicalResult
+ matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Vectors in particular are not supported.
+ Type operandType = adaptor.getIn().getType();
+ if (!emitc::isSupportedFloatType(operandType))
+ return rewriter.notifyMatchFailure(castOp,
+ "unsupported cast source type");
+ if (auto roundingModeOp =
+ dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) {
+ // Only supporting default rounding mode as of now.
+ if (roundingModeOp.getRoundingModeAttr())
+ return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode");
+ }
+
+ Type dstType = this->getTypeConverter()->convertType(castOp.getType());
+ if (!dstType)
+ return rewriter.notifyMatchFailure(castOp, "type conversion failed");
+
+ if (!emitc::isSupportedFloatType(dstType))
+ return rewriter.notifyMatchFailure(castOp,
+ "unsupported cast destination type");
+
+ Value fpCastOperand = adaptor.getIn();
+ rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
+
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -778,7 +815,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
ItoFCastOpConversion<arith::SIToFPOp>,
ItoFCastOpConversion<arith::UIToFPOp>,
FtoICastOpConversion<arith::FPToSIOp>,
- FtoICastOpConversion<arith::FPToUIOp>
+ FtoICastOpConversion<arith::FPToUIOp>,
+ FpCastOpConversion<arith::ExtFOp>,
+ FpCastOpConversion<arith::TruncFOp>
>(typeConverter, ctx);
// clang-format on
}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index b866904..9850f33 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -149,3 +149,43 @@ func.func @arith_remui_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) -> vec
%divui = arith.remui %arg0, %arg1 : vector<5xi32>
return %divui: vector<5xi32>
}
+
+// -----
+
+func.func @arith_truncf(%arg0: f64) -> f32 {
+ // expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
+ %truncd = arith.truncf %arg0 to_nearest_away : f64 to f32
+ return %truncd : f32
+}
+
+// -----
+
+func.func @arith_extf_f128(%arg0: f32) -> f128 {
+ // expected-error @+1 {{failed to legalize operation 'arith.extf'}}
+ %extd = arith.extf %arg0 : f32 to f128
+ return %extd : f128
+}
+
+// -----
+
+func.func @arith_truncf_f128(%arg0: f128) -> f32 {
+ // expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
+ %truncd = arith.truncf %arg0 : f128 to f32
+ return %truncd : f32
+}
+
+// -----
+
+func.func @arith_extf_vector(%arg0: vector<4xf32>) -> vector<4xf64> {
+ // expected-error @+1 {{failed to legalize operation 'arith.extf'}}
+ %extd = arith.extf %arg0 : vector<4xf32> to vector<4xf64>
+ return %extd : vector<4xf64>
+}
+
+// -----
+
+func.func @arith_truncf_vector(%arg0: vector<4xf64>) -> vector<4xf32> {
+ // expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
+ %truncd = arith.truncf %arg0 : vector<4xf64> to vector<4xf32>
+ return %truncd : vector<4xf32>
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 1728c3a..4e3d108 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -739,3 +739,29 @@ func.func @arith_divui_remui(%arg0: i32, %arg1: i32) -> i32 {
return %div : i32
}
+
+// -----
+
+func.func @arith_extf(%arg0: f16) -> f64 {
+ // CHECK-LABEL: arith_extf
+ // CHECK-SAME: (%[[Arg0:[^ ]*]]: f16)
+ // CHECK: %[[Extd0:.*]] = emitc.cast %[[Arg0]] : f16 to f32
+ %extd0 = arith.extf %arg0 : f16 to f32
+ // CHECK: %[[Extd1:.*]] = emitc.cast %[[Extd0]] : f32 to f64
+ %extd1 = arith.extf %extd0 : f32 to f64
+
+ return %extd1 : f64
+}
+
+// -----
+
+func.func @arith_truncf(%arg0: f64) -> f16 {
+ // CHECK-LABEL: arith_truncf
+ // CHECK-SAME: (%[[Arg0:[^ ]*]]: f64)
+ // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32
+ %truncd0 = arith.truncf %arg0 : f64 to f32
+ // CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16
+ %truncd1 = arith.truncf %truncd0 : f32 to f16
+
+ return %truncd1 : f16
+}