diff options
Diffstat (limited to 'mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp')
-rw-r--r-- | mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 141 |
1 files changed, 86 insertions, 55 deletions
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 434d7df..265293b 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, return builder.getF32FloatAttr(dstVal.convertToFloat()); } +// Get in IntegerAttr from FloatAttr while preserving the bits. +// Useful for converting float constants to integer constants while preserving +// the bits. +static IntegerAttr +getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, + ConversionPatternRewriter &rewriter) { + APFloat floatVal = floatAttr.getValue(); + APInt intVal = floatVal.bitcastToAPInt(); + return rewriter.getIntegerAttr(dstType, intVal); +} + /// Returns true if the given `type` is a boolean scalar or vector type. static bool isBoolScalarOrVector(Type type) { assert(type && "Not a valid type"); @@ -117,12 +128,12 @@ static Value getScalarOrVectorConstInt(Type type, uint64_t value, if (auto vectorType = dyn_cast<VectorType>(type)) { Attribute element = IntegerAttr::get(vectorType.getElementType(), value); auto attr = SplatElementsAttr::get(vectorType, element); - return builder.create<spirv::ConstantOp>(loc, vectorType, attr); + return spirv::ConstantOp::create(builder, loc, vectorType, attr); } if (auto intType = dyn_cast<IntegerType>(type)) - return builder.create<spirv::ConstantOp>( - loc, type, builder.getIntegerAttr(type, value)); + return spirv::ConstantOp::create(builder, loc, type, + builder.getIntegerAttr(type, value)); return nullptr; } @@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final SmallVector<Attribute, 8> elements; if (isa<FloatType>(srcElemType)) { for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { - FloatAttr dstAttr = - convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter); + Attribute dstAttr = nullptr; + // Handle 8-bit float conversion to 8-bit integer. + auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcElemType.getIntOrFloatBitWidth() == 8 && + isa<IntegerType>(dstElemType)) { + dstAttr = + getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter); + } else { + dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), + rewriter); + } if (!dstAttr) return failure(); elements.push_back(dstAttr); @@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final // Floating-point types. if (isa<FloatType>(srcType)) { auto srcAttr = cast<FloatAttr>(cstAttr); - auto dstAttr = srcAttr; + Attribute dstAttr = srcAttr; // Floating-point types not supported in the target environment are all // converted to float type. - if (srcType != dstType) { + auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) && + dstType.getIntOrFloatBitWidth() == 8) { + // If the source is an 8-bit float, convert it to a 8-bit integer. + dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter); + if (!dstAttr) + return failure(); + } else if (srcType != dstType) { dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter); if (!dstAttr) return failure(); @@ -418,18 +447,19 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, Type type = lhs.getType(); // Calculate the remainder with spirv.UMod. - Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs); - Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs); - Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs); + Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs); + Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs); + Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs); // Fix the sign. Value isPositive; if (lhs == signOperand) - isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs); + isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs); else - isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs); - Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs); - return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate); + isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs); + Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs); + return spirv::SelectOp::create(builder, loc, type, isPositive, abs, + absNegate); } /// Converts arith.remsi to GLSL SPIR-V ops. @@ -601,13 +631,13 @@ struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> { Value allOnes; if (auto intTy = dyn_cast<IntegerType>(dstType)) { unsigned componentBitwidth = intTy.getWidth(); - allOnes = rewriter.create<spirv::ConstantOp>( - loc, intTy, + allOnes = spirv::ConstantOp::create( + rewriter, loc, intTy, rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth))); } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) { unsigned componentBitwidth = vectorTy.getElementTypeBitWidth(); - allOnes = rewriter.create<spirv::ConstantOp>( - loc, vectorTy, + allOnes = spirv::ConstantOp::create( + rewriter, loc, vectorTy, SplatElementsAttr::get(vectorTy, APInt::getAllOnes(componentBitwidth))); } else { @@ -653,8 +683,8 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> { // First shift left to sequeeze out all leading bits beyond the original // bitwidth. Here we need to use the original source and result type's // bitwidth. - auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>( - op.getLoc(), dstType, adaptor.getIn(), shiftSize); + auto shiftLOp = spirv::ShiftLeftLogicalOp::create( + rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize); // Then we perform arithmetic right shift to make sure we have the right // sign bits for negative values. @@ -757,9 +787,9 @@ struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { auto srcType = adaptor.getOperands().front().getType(); // Check if (x & 1) == 1. Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); - Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>( - loc, srcType, adaptor.getOperands()[0], mask); - Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask); + Value maskedSrc = spirv::BitwiseAndOp::create( + rewriter, loc, srcType, adaptor.getOperands()[0], mask); + Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); @@ -914,9 +944,9 @@ public: if (auto vectorType = dyn_cast<VectorType>(dstType)) type = VectorType::get(vectorType.getShape(), type); Value extLhs = - rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs()); + arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs()); Value extRhs = - rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs()); + arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs()); rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs, extRhs); @@ -1067,12 +1097,12 @@ public: replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter); } } else { - Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan); + replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan); if (op.getPredicate() == arith::CmpFPredicate::ORD) - replace = rewriter.create<spirv::LogicalNotOp>(loc, replace); + replace = spirv::LogicalNotOp::create(rewriter, loc, replace); } rewriter.replaceOp(op, replace); @@ -1094,17 +1124,17 @@ public: ConversionPatternRewriter &rewriter) const override { Type dstElemTy = adaptor.getLhs().getType(); Location loc = op->getLoc(); - Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(), - adaptor.getRhs()); + Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(), + adaptor.getRhs()); - Value sumResult = rewriter.create<spirv::CompositeExtractOp>( - loc, result, llvm::ArrayRef(0)); - Value carryValue = rewriter.create<spirv::CompositeExtractOp>( - loc, result, llvm::ArrayRef(1)); + Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(0)); + Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(1)); // Convert the carry value to boolean. Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter); - Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one); + Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one); rewriter.replaceOp(op, {sumResult, carryResult}); return success(); @@ -1125,12 +1155,12 @@ public: ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value result = - rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs()); + SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs()); - Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result, - llvm::ArrayRef(0)); - Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result, - llvm::ArrayRef(1)); + Value low = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(0)); + Value high = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(1)); rewriter.replaceOp(op, {low, high}); return success(); @@ -1183,20 +1213,20 @@ public: Location loc = op.getLoc(); Value spirvOp = - rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands()); + SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands()); if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { rewriter.replaceOp(op, spirvOp); return success(); } - Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan, - adaptor.getLhs(), spirvOp); - Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan, - adaptor.getRhs(), select1); + Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan, + adaptor.getLhs(), spirvOp); + Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan, + adaptor.getRhs(), select1); rewriter.replaceOp(op, select2); return success(); @@ -1237,7 +1267,7 @@ public: Location loc = op.getLoc(); Value spirvOp = - rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands()); + SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands()); if (!shouldInsertNanGuards<SPIRVOp>() || bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { @@ -1245,13 +1275,13 @@ public: return success(); } - Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan, - adaptor.getRhs(), spirvOp); - Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan, - adaptor.getLhs(), select1); + Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan, + adaptor.getRhs(), spirvOp); + Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan, + adaptor.getLhs(), select1); rewriter.replaceOp(op, select2); return success(); @@ -1351,6 +1381,7 @@ struct ConvertArithToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // Use UnrealizedConversionCast as the bridge so that we don't need to pull |