aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp')
-rw-r--r--mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp141
1 files changed, 86 insertions, 55 deletions
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 434d7df..265293b 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
return builder.getF32FloatAttr(dstVal.convertToFloat());
}
+// Get in IntegerAttr from FloatAttr while preserving the bits.
+// Useful for converting float constants to integer constants while preserving
+// the bits.
+static IntegerAttr
+getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+ ConversionPatternRewriter &rewriter) {
+ APFloat floatVal = floatAttr.getValue();
+ APInt intVal = floatVal.bitcastToAPInt();
+ return rewriter.getIntegerAttr(dstType, intVal);
+}
+
/// Returns true if the given `type` is a boolean scalar or vector type.
static bool isBoolScalarOrVector(Type type) {
assert(type && "Not a valid type");
@@ -117,12 +128,12 @@ static Value getScalarOrVectorConstInt(Type type, uint64_t value,
if (auto vectorType = dyn_cast<VectorType>(type)) {
Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
auto attr = SplatElementsAttr::get(vectorType, element);
- return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
+ return spirv::ConstantOp::create(builder, loc, vectorType, attr);
}
if (auto intType = dyn_cast<IntegerType>(type))
- return builder.create<spirv::ConstantOp>(
- loc, type, builder.getIntegerAttr(type, value));
+ return spirv::ConstantOp::create(builder, loc, type,
+ builder.getIntegerAttr(type, value));
return nullptr;
}
@@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final
SmallVector<Attribute, 8> elements;
if (isa<FloatType>(srcElemType)) {
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
- FloatAttr dstAttr =
- convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
+ Attribute dstAttr = nullptr;
+ // Handle 8-bit float conversion to 8-bit integer.
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcElemType.getIntOrFloatBitWidth() == 8 &&
+ isa<IntegerType>(dstElemType)) {
+ dstAttr =
+ getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
+ } else {
+ dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
+ rewriter);
+ }
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
@@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final
// Floating-point types.
if (isa<FloatType>(srcType)) {
auto srcAttr = cast<FloatAttr>(cstAttr);
- auto dstAttr = srcAttr;
+ Attribute dstAttr = srcAttr;
// Floating-point types not supported in the target environment are all
// converted to float type.
- if (srcType != dstType) {
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
+ dstType.getIntOrFloatBitWidth() == 8) {
+ // If the source is an 8-bit float, convert it to a 8-bit integer.
+ dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
+ if (!dstAttr)
+ return failure();
+ } else if (srcType != dstType) {
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
if (!dstAttr)
return failure();
@@ -418,18 +447,19 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
Type type = lhs.getType();
// Calculate the remainder with spirv.UMod.
- Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
- Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
- Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
+ Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs);
+ Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs);
+ Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
// Fix the sign.
Value isPositive;
if (lhs == signOperand)
- isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
+ isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs);
else
- isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
- Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
- return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
+ isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs);
+ Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
+ return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
+ absNegate);
}
/// Converts arith.remsi to GLSL SPIR-V ops.
@@ -601,13 +631,13 @@ struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
Value allOnes;
if (auto intTy = dyn_cast<IntegerType>(dstType)) {
unsigned componentBitwidth = intTy.getWidth();
- allOnes = rewriter.create<spirv::ConstantOp>(
- loc, intTy,
+ allOnes = spirv::ConstantOp::create(
+ rewriter, loc, intTy,
rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
} else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
- allOnes = rewriter.create<spirv::ConstantOp>(
- loc, vectorTy,
+ allOnes = spirv::ConstantOp::create(
+ rewriter, loc, vectorTy,
SplatElementsAttr::get(vectorTy,
APInt::getAllOnes(componentBitwidth)));
} else {
@@ -653,8 +683,8 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
// First shift left to sequeeze out all leading bits beyond the original
// bitwidth. Here we need to use the original source and result type's
// bitwidth.
- auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>(
- op.getLoc(), dstType, adaptor.getIn(), shiftSize);
+ auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
+ rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
// Then we perform arithmetic right shift to make sure we have the right
// sign bits for negative values.
@@ -757,9 +787,9 @@ struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
auto srcType = adaptor.getOperands().front().getType();
// Check if (x & 1) == 1.
Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
- Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
- loc, srcType, adaptor.getOperands()[0], mask);
- Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
+ Value maskedSrc = spirv::BitwiseAndOp::create(
+ rewriter, loc, srcType, adaptor.getOperands()[0], mask);
+ Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
@@ -914,9 +944,9 @@ public:
if (auto vectorType = dyn_cast<VectorType>(dstType))
type = VectorType::get(vectorType.getShape(), type);
Value extLhs =
- rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
+ arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
Value extRhs =
- rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
+ arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
extRhs);
@@ -1067,12 +1097,12 @@ public:
replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
}
} else {
- Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
- Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+ Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
+ Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
- replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
+ replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
if (op.getPredicate() == arith::CmpFPredicate::ORD)
- replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
+ replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
}
rewriter.replaceOp(op, replace);
@@ -1094,17 +1124,17 @@ public:
ConversionPatternRewriter &rewriter) const override {
Type dstElemTy = adaptor.getLhs().getType();
Location loc = op->getLoc();
- Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
- adaptor.getRhs());
+ Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
+ adaptor.getRhs());
- Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
- loc, result, llvm::ArrayRef(0));
- Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
- loc, result, llvm::ArrayRef(1));
+ Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result,
+ llvm::ArrayRef(0));
+ Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result,
+ llvm::ArrayRef(1));
// Convert the carry value to boolean.
Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
- Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
+ Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
rewriter.replaceOp(op, {sumResult, carryResult});
return success();
@@ -1125,12 +1155,12 @@ public:
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value result =
- rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
+ SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
- Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
- llvm::ArrayRef(0));
- Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
- llvm::ArrayRef(1));
+ Value low = spirv::CompositeExtractOp::create(rewriter, loc, result,
+ llvm::ArrayRef(0));
+ Value high = spirv::CompositeExtractOp::create(rewriter, loc, result,
+ llvm::ArrayRef(1));
rewriter.replaceOp(op, {low, high});
return success();
@@ -1183,20 +1213,20 @@ public:
Location loc = op.getLoc();
Value spirvOp =
- rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+ SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
rewriter.replaceOp(op, spirvOp);
return success();
}
- Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
- Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+ Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
+ Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
- Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
- adaptor.getLhs(), spirvOp);
- Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
- adaptor.getRhs(), select1);
+ Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
+ adaptor.getLhs(), spirvOp);
+ Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
+ adaptor.getRhs(), select1);
rewriter.replaceOp(op, select2);
return success();
@@ -1237,7 +1267,7 @@ public:
Location loc = op.getLoc();
Value spirvOp =
- rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+ SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
if (!shouldInsertNanGuards<SPIRVOp>() ||
bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
@@ -1245,13 +1275,13 @@ public:
return success();
}
- Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
- Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+ Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
+ Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
- Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
- adaptor.getRhs(), spirvOp);
- Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
- adaptor.getLhs(), select1);
+ Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
+ adaptor.getRhs(), spirvOp);
+ Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
+ adaptor.getLhs(), select1);
rewriter.replaceOp(op, select2);
return success();
@@ -1351,6 +1381,7 @@ struct ConvertArithToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull