diff options
3 files changed, 162 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 1447b18..0be3d76 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Tools/PDLL/AST/Types.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -112,6 +113,93 @@ public: } }; +template <typename ArithOp, bool castToUnsigned> +class CastConversion : public OpConversionPattern<ArithOp> { +public: + using OpConversionPattern<ArithOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Type opReturnType = this->getTypeConverter()->convertType(op.getType()); + if (!isa_and_nonnull<IntegerType>(opReturnType)) + return rewriter.notifyMatchFailure(op, "expected integer result type"); + + if (adaptor.getOperands().size() != 1) { + return rewriter.notifyMatchFailure( + op, "CastConversion only supports unary ops"); + } + + Type operandType = adaptor.getIn().getType(); + if (!isa_and_nonnull<IntegerType>(operandType)) + return rewriter.notifyMatchFailure(op, "expected integer operand type"); + + // Signed (sign-extending) casts from i1 are not supported. + if (operandType.isInteger(1) && !castToUnsigned) + return rewriter.notifyMatchFailure(op, + "operation not supported on i1 type"); + + // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is + // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives + // truncation. + if (opReturnType.isInteger(1)) { + auto constOne = rewriter.create<emitc::ConstantOp>( + op.getLoc(), operandType, rewriter.getIntegerAttr(operandType, 1)); + auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>( + op.getLoc(), operandType, adaptor.getIn(), constOne); + rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType, + oneAndOperand); + return success(); + } + + bool isTruncation = operandType.getIntOrFloatBitWidth() > + opReturnType.getIntOrFloatBitWidth(); + bool doUnsigned = castToUnsigned || isTruncation; + + Type castType = opReturnType; + // If the op is a ui variant and the type wanted as + // return type isn't unsigned, we need to issue an unsigned type to do + // the conversion. + if (castType.isUnsignedInteger() != doUnsigned) { + castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(), + /*isSigned=*/!doUnsigned); + } + + Value actualOp = adaptor.getIn(); + // Adapt the signedness of the operand if necessary + if (operandType.isUnsignedInteger() != doUnsigned) { + Type correctSignednessType = + rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(), + /*isSigned=*/!doUnsigned); + actualOp = rewriter.template create<emitc::CastOp>( + op.getLoc(), correctSignednessType, actualOp); + } + + auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType, + actualOp); + + // Cast to the expected output type + if (castType != opReturnType) { + result = rewriter.template create<emitc::CastOp>(op.getLoc(), + opReturnType, result); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; + +template <typename ArithOp> +class UnsignedCastConversion : public CastConversion<ArithOp, true> { + using CastConversion<ArithOp, true>::CastConversion; +}; + +template <typename ArithOp> +class SignedCastConversion : public CastConversion<ArithOp, false> { + using CastConversion<ArithOp, false>::CastConversion; +}; + template <typename ArithOp, typename EmitCOp> class ArithOpConversion final : public OpConversionPattern<ArithOp> { public: @@ -313,6 +401,10 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, IntegerOpConversion<arith::SubIOp, emitc::SubOp>, CmpIOpConversion, SelectOpConversion, + // Truncation is guaranteed for unsigned types. + UnsignedCastConversion<arith::TruncIOp>, + SignedCastConversion<arith::ExtSIOp>, + UnsignedCastConversion<arith::ExtUIOp>, ItoFCastOpConversion<arith::SIToFPOp>, ItoFCastOpConversion<arith::UIToFPOp>, FtoICastOpConversion<arith::FPToSIOp>, diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir index 66dfa8f..97e4593 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir @@ -63,3 +63,10 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 { return %t: i1 } +// ----- + +func.func @arith_extsi_i1_to_i32(%arg0: i1) { + // expected-error @+1 {{failed to legalize operation 'arith.extsi'}} + %idx = arith.extsi %arg0 : i1 to i32 + return +} diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 79fecd6..b453b69 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -177,3 +177,66 @@ func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) { return } + +// ----- + +func.func @arith_trunci(%arg0: i32) -> i8 { + // CHECK-LABEL: arith_trunci + // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32) + // CHECK: %[[CastUI:.*]] = emitc.cast %[[Arg0]] : i32 to ui32 + // CHECK: %[[Trunc:.*]] = emitc.cast %[[CastUI]] : ui32 to ui8 + // CHECK: emitc.cast %[[Trunc]] : ui8 to i8 + %truncd = arith.trunci %arg0 : i32 to i8 + + return %truncd : i8 +} + +// ----- + +func.func @arith_trunci_to_i1(%arg0: i32) -> i1 { + // CHECK-LABEL: arith_trunci_to_i1 + // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32) + // CHECK: %[[Const:.*]] = "emitc.constant" + // CHECK-SAME: value = 1 + // CHECK: %[[And:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32 + // CHECK: emitc.cast %[[And]] : i32 to i1 + %truncd = arith.trunci %arg0 : i32 to i1 + + return %truncd : i1 +} + +// ----- + +func.func @arith_extsi(%arg0: i32) { + // CHECK-LABEL: arith_extsi + // CHECK-SAME: ([[Arg0:[^ ]*]]: i32) + // CHECK: emitc.cast [[Arg0]] : i32 to i64 + %extd = arith.extsi %arg0 : i32 to i64 + + return +} + +// ----- + +func.func @arith_extui(%arg0: i32) { + // CHECK-LABEL: arith_extui + // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32) + // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32 + // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to ui64 + // CHECK: emitc.cast %[[Conv1]] : ui64 to i64 + %extd = arith.extui %arg0 : i32 to i64 + + return +} + +// ----- + +func.func @arith_extui_i1_to_i32(%arg0: i1) { + // CHECK-LABEL: arith_extui_i1_to_i32 + // CHECK-SAME: (%[[Arg0:[^ ]*]]: i1) + // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i1 to ui1 + // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui1 to ui32 + // CHECK: emitc.cast %[[Conv1]] : ui32 to i32 + %idx = arith.extui %arg0 : i1 to i32 + return +} |
