aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp')
-rw-r--r--mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp38
1 files changed, 34 insertions, 4 deletions
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index d43e681..265293b 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
return builder.getF32FloatAttr(dstVal.convertToFloat());
}
+// Get in IntegerAttr from FloatAttr while preserving the bits.
+// Useful for converting float constants to integer constants while preserving
+// the bits.
+static IntegerAttr
+getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+ ConversionPatternRewriter &rewriter) {
+ APFloat floatVal = floatAttr.getValue();
+ APInt intVal = floatVal.bitcastToAPInt();
+ return rewriter.getIntegerAttr(dstType, intVal);
+}
+
/// Returns true if the given `type` is a boolean scalar or vector type.
static bool isBoolScalarOrVector(Type type) {
assert(type && "Not a valid type");
@@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final
SmallVector<Attribute, 8> elements;
if (isa<FloatType>(srcElemType)) {
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
- FloatAttr dstAttr =
- convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
+ Attribute dstAttr = nullptr;
+ // Handle 8-bit float conversion to 8-bit integer.
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcElemType.getIntOrFloatBitWidth() == 8 &&
+ isa<IntegerType>(dstElemType)) {
+ dstAttr =
+ getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
+ } else {
+ dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
+ rewriter);
+ }
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
@@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final
// Floating-point types.
if (isa<FloatType>(srcType)) {
auto srcAttr = cast<FloatAttr>(cstAttr);
- auto dstAttr = srcAttr;
+ Attribute dstAttr = srcAttr;
// Floating-point types not supported in the target environment are all
// converted to float type.
- if (srcType != dstType) {
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
+ dstType.getIntOrFloatBitWidth() == 8) {
+ // If the source is an 8-bit float, convert it to a 8-bit integer.
+ dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
+ if (!dstAttr)
+ return failure();
+ } else if (srcType != dstType) {
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
if (!dstAttr)
return failure();
@@ -1352,6 +1381,7 @@ struct ConvertArithToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull