aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCorentin Ferry <corentin.ferry@amd.com>2024-06-19 09:19:33 +0200
committerGitHub <noreply@github.com>2024-06-19 09:19:33 +0200
commit519175c3f5d844bac0cf3173396dc41db2873e1d (patch)
tree9c8757f707407e128fac1b3fcc21c97af872de81
parent8af86025af2456c70c84aec309cca9a069124671 (diff)
downloadllvm-519175c3f5d844bac0cf3173396dc41db2873e1d.zip
llvm-519175c3f5d844bac0cf3173396dc41db2873e1d.tar.gz
llvm-519175c3f5d844bac0cf3173396dc41db2873e1d.tar.bz2
[mlir][emitc] Refactor ArithToEmitC: perform sign adaptation, type conversions / cast insertion in a single place (#95789)
Factor EmitC type signedness adaptation and cast operations in ArithToEmitC using adaptValueType and adaptIntegralTypeSignedness.
-rw-r--r--mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp78
-rw-r--r--mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir7
2 files changed, 33 insertions, 52 deletions
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 27913df..93717e3 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -270,19 +270,11 @@ public:
bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
- Type arithmeticType = type;
- if (type.isUnsignedInteger() != needsUnsigned) {
- arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
- /*isSigned=*/!needsUnsigned);
- }
- Value lhs = adaptor.getLhs();
- Value rhs = adaptor.getRhs();
- if (arithmeticType != type) {
- lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
- lhs);
- rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
- rhs);
- }
+
+ Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
+ Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+ Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
+
rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
return success();
}
@@ -356,37 +348,26 @@ public:
return success();
}
- bool isTruncation = operandType.getIntOrFloatBitWidth() >
- opReturnType.getIntOrFloatBitWidth();
+ bool isTruncation =
+ (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
+ operandType.getIntOrFloatBitWidth() >
+ opReturnType.getIntOrFloatBitWidth());
bool doUnsigned = castToUnsigned || isTruncation;
- Type castType = opReturnType;
- // If the op is a ui variant and the type wanted as
- // return type isn't unsigned, we need to issue an unsigned type to do
- // the conversion.
- if (castType.isUnsignedInteger() != doUnsigned) {
- castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
- /*isSigned=*/!doUnsigned);
- }
+ // Adapt the signedness of the result (bitwidth-preserving cast)
+ // This is needed e.g., if the return type is signless.
+ Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
- Value actualOp = adaptor.getIn();
- // Adapt the signedness of the operand if necessary
- if (operandType.isUnsignedInteger() != doUnsigned) {
- Type correctSignednessType =
- rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
- /*isSigned=*/!doUnsigned);
- actualOp = rewriter.template create<emitc::CastOp>(
- op.getLoc(), correctSignednessType, actualOp);
- }
+ // Adapt the signedness of the operand (bitwidth-preserving cast)
+ Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
+ Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
- auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
- actualOp);
+ // Actual cast (may change bitwidth)
+ auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
+ castDestType, actualOp);
// Cast to the expected output type
- if (castType != opReturnType) {
- result = rewriter.template create<emitc::CastOp>(op.getLoc(),
- opReturnType, result);
- }
+ auto result = adaptValueType(cast, rewriter, opReturnType);
rewriter.replaceOp(op, result);
return success();
@@ -438,8 +419,6 @@ public:
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
}
- Value lhs = adaptor.getLhs();
- Value rhs = adaptor.getRhs();
Type arithmeticType = type;
if ((type.isSignlessInteger() || type.isSignedInteger()) &&
!bitEnumContainsAll(op.getOverflowFlags(),
@@ -449,20 +428,15 @@ public:
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
/*isSigned=*/false);
}
- if (arithmeticType != type) {
- lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
- lhs);
- rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
- rhs);
- }
- Value result = rewriter.template create<EmitCOp>(op.getLoc(),
- arithmeticType, lhs, rhs);
+ Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
+ Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
+
+ Value arithmeticResult = rewriter.template create<EmitCOp>(
+ op.getLoc(), arithmeticType, lhs, rhs);
+
+ Value result = adaptValueType(arithmeticResult, rewriter, type);
- if (arithmeticType != type) {
- result =
- rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
- }
rewriter.replaceOp(op, result);
return success();
}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 667ff79..0289b7d 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -477,6 +477,13 @@ func.func @arith_trunci(%arg0: i32) -> i8 {
// CHECK: emitc.cast %[[Trunc]] : ui8 to i8
%truncd = arith.trunci %arg0 : i32 to i8
+ // CHECK: %[[Const:.*]] = "emitc.constant"
+ // CHECK-SAME: value = 1
+ // CHECK-SAME: () -> i32
+ // CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32
+ // CHECK: %[[Conv:.*]] = emitc.cast %[[AndOne]] : i32 to i1
+ %bool = arith.trunci %arg0 : i32 to i1
+
return %truncd : i8
}