diff options
author | Reed <reedwm@google.com> | 2022-11-15 20:11:50 +0100 |
---|---|---|
committer | Benjamin Kramer <benny.kra@googlemail.com> | 2022-11-15 20:26:42 +0100 |
commit | 88eb3c62f25d820a80bc73af786002e1fc4bbbba (patch) | |
tree | db4b36d7e6fe79c8ddbe1de3c1cddce4823a48aa /llvm/lib/Support/APFloat.cpp | |
parent | b2d9e08c4428c1d1c65487d6fa4d6bdf80c32450 (diff) | |
download | llvm-88eb3c62f25d820a80bc73af786002e1fc4bbbba.zip llvm-88eb3c62f25d820a80bc73af786002e1fc4bbbba.tar.gz llvm-88eb3c62f25d820a80bc73af786002e1fc4bbbba.tar.bz2 |
Add FP8 E4M3 support to APFloat.
NVIDIA, ARM, and Intel recently introduced two new FP8 formats, as described in the paper: https://arxiv.org/abs/2209.05433. The first of the two FP8 dtypes, E5M2, was added in https://reviews.llvm.org/D133823. This change adds the second of the two: E4M3.
There is an RFC for adding the FP8 dtypes here: https://discourse.llvm.org/t/rfc-add-apfloat-and-mlir-type-support-for-fp8-e5m2/65279. I spoke with the RFC's author, Stella, and she gave me the go ahead to implement the E4M3 type. The name of the E4M3 type in APFloat is Float8E4M3FN, as discussed in the RFC. The "FN" means only Finite and NaN values are supported.
Unlike E5M2, E4M3 has different behavior from IEEE types in regards to Inf and NaN values. There are no Inf values, and NaN is represented when the exponent and mantissa bits are all 1s. To represent these differences in APFloat, I added an enum field, fltNonfiniteBehavior, to the fltSemantics struct. The possible enum values are IEEE754 and NanOnly. Only Float8E4M3FN has the NanOnly behavior.
After this change is submitted, I plan on adding the Float8E4M3FN type to MLIR, in the same way as E5M2 was added in https://reviews.llvm.org/D133823.
Reviewed By: bkramer
Differential Revision: https://reviews.llvm.org/D137760
Diffstat (limited to 'llvm/lib/Support/APFloat.cpp')
-rw-r--r-- | llvm/lib/Support/APFloat.cpp | 223 |
1 files changed, 201 insertions, 22 deletions
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp index 68063bb..22dd40c 100644 --- a/llvm/lib/Support/APFloat.cpp +++ b/llvm/lib/Support/APFloat.cpp @@ -50,6 +50,23 @@ using namespace llvm; static_assert(APFloatBase::integerPartWidth % 4 == 0, "Part width must be divisible by 4!"); namespace llvm { + + // How the nonfinite values Inf and NaN are represented. + enum class fltNonfiniteBehavior { + // Represents standard IEEE 754 behavior. A value is nonfinite if the + // exponent field is all 1s. In such cases, a value is Inf if the + // significand bits are all zero, and NaN otherwise + IEEE754, + + // Only the Float8E5M2 has this behavior. There is no Inf representation. A + // value is NaN if the exponent field and the mantissa field are all 1s. + // This behavior matches the FP8 E4M3 type described in + // https://arxiv.org/abs/2209.05433. We treat both signed and unsigned NaNs + // as non-signalling, although the paper does not state whether the NaN + // values are signalling or not. + NanOnly, + }; + /* Represents floating point arithmetic semantics. */ struct fltSemantics { /* The largest E such that 2^E is representable; this matches the @@ -67,8 +84,11 @@ namespace llvm { /* Number of bits actually used in the semantics. */ unsigned int sizeInBits; + fltNonfiniteBehavior nonFiniteBehavior = fltNonfiniteBehavior::IEEE754; + // Returns true if any number described by this semantics can be precisely - // represented by the specified semantics. + // represented by the specified semantics. Does not take into account + // the value of fltNonfiniteBehavior. bool isRepresentableBy(const fltSemantics &S) const { return maxExponent <= S.maxExponent && minExponent >= S.minExponent && precision <= S.precision; @@ -81,6 +101,8 @@ namespace llvm { static const fltSemantics semIEEEdouble = {1023, -1022, 53, 64}; static const fltSemantics semIEEEquad = {16383, -16382, 113, 128}; static const fltSemantics semFloat8E5M2 = {15, -14, 3, 8}; + static const fltSemantics semFloat8E4M3FN = {8, -6, 4, 8, + fltNonfiniteBehavior::NanOnly}; static const fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80}; static const fltSemantics semBogus = {0, 0, 0, 0}; @@ -138,6 +160,8 @@ namespace llvm { return PPCDoubleDouble(); case S_Float8E5M2: return Float8E5M2(); + case S_Float8E4M3FN: + return Float8E4M3FN(); case S_x87DoubleExtended: return x87DoubleExtended(); } @@ -160,6 +184,8 @@ namespace llvm { return S_PPCDoubleDouble; else if (&Sem == &llvm::APFloat::Float8E5M2()) return S_Float8E5M2; + else if (&Sem == &llvm::APFloat::Float8E4M3FN()) + return S_Float8E4M3FN; else if (&Sem == &llvm::APFloat::x87DoubleExtended()) return S_x87DoubleExtended; else @@ -183,6 +209,7 @@ namespace llvm { return semPPCDoubleDouble; } const fltSemantics &APFloatBase::Float8E5M2() { return semFloat8E5M2; } + const fltSemantics &APFloatBase::Float8E4M3FN() { return semFloat8E4M3FN; } const fltSemantics &APFloatBase::x87DoubleExtended() { return semX87DoubleExtended; } @@ -769,6 +796,15 @@ void IEEEFloat::makeNaN(bool SNaN, bool Negative, const APInt *fill) { integerPart *significand = significandParts(); unsigned numParts = partCount(); + APInt fill_storage; + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) { + // The only NaN representation is where the mantissa is all 1s, which is + // non-signalling. + SNaN = false; + fill_storage = APInt::getAllOnes(semantics->precision - 1); + fill = &fill_storage; + } + // Set the significand bits to the fill. if (!fill || fill->getNumWords() < numParts) APInt::tcSet(significand, 0, numParts); @@ -869,6 +905,33 @@ bool IEEEFloat::isSignificandAllOnes() const { return true; } +bool IEEEFloat::isSignificandAllOnesExceptLSB() const { + // Test if the significand excluding the integral bit is all ones except for + // the least significant bit. + const integerPart *Parts = significandParts(); + + if (Parts[0] & 1) + return false; + + const unsigned PartCount = partCountForBits(semantics->precision); + for (unsigned i = 0; i < PartCount - 1; i++) { + if (~Parts[i] & ~unsigned{!i}) + return false; + } + + // Set the unused high bits to all ones when we compare. + const unsigned NumHighBits = + PartCount * integerPartWidth - semantics->precision + 1; + assert(NumHighBits <= integerPartWidth && NumHighBits > 0 && + "Can not have more high bits to fill than integerPartWidth"); + const integerPart HighBitFill = ~integerPart(0) + << (integerPartWidth - NumHighBits); + if (~(Parts[PartCount - 1] | HighBitFill | 0x1)) + return false; + + return true; +} + bool IEEEFloat::isSignificandAllZeros() const { // Test if the significand excluding the integral bit is all zeros. This // allows us to test for binade boundaries. @@ -893,10 +956,18 @@ bool IEEEFloat::isSignificandAllZeros() const { } bool IEEEFloat::isLargest() const { - // The largest number by magnitude in our format will be the floating point - // number with maximum exponent and with significand that is all ones. - return isFiniteNonZero() && exponent == semantics->maxExponent - && isSignificandAllOnes(); + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) { + // The largest number by magnitude in our format will be the floating point + // number with maximum exponent and with significand that is all ones except + // the LSB. + return isFiniteNonZero() && exponent == semantics->maxExponent && + isSignificandAllOnesExceptLSB(); + } else { + // The largest number by magnitude in our format will be the floating point + // number with maximum exponent and with significand that is all ones. + return isFiniteNonZero() && exponent == semantics->maxExponent && + isSignificandAllOnes(); + } } bool IEEEFloat::isInteger() const { @@ -1315,7 +1386,10 @@ IEEEFloat::opStatus IEEEFloat::handleOverflow(roundingMode rounding_mode) { rounding_mode == rmNearestTiesToAway || (rounding_mode == rmTowardPositive && !sign) || (rounding_mode == rmTowardNegative && sign)) { - category = fcInfinity; + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) + makeNaN(false, sign); + else + category = fcInfinity; return (opStatus) (opOverflow | opInexact); } @@ -1324,6 +1398,8 @@ IEEEFloat::opStatus IEEEFloat::handleOverflow(roundingMode rounding_mode) { exponent = semantics->maxExponent; tcSetLeastSignificantBits(significandParts(), partCount(), semantics->precision); + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) + APInt::tcClearBit(significandParts(), 0); return opInexact; } @@ -1423,6 +1499,10 @@ IEEEFloat::opStatus IEEEFloat::normalize(roundingMode rounding_mode, } } + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly && + exponent == semantics->maxExponent && isSignificandAllOnes()) + return handleOverflow(rounding_mode); + /* Now round the number according to rounding_mode given the lost fraction. */ @@ -1459,6 +1539,10 @@ IEEEFloat::opStatus IEEEFloat::normalize(roundingMode rounding_mode, return opInexact; } + + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly && + exponent == semantics->maxExponent && isSignificandAllOnes()) + return handleOverflow(rounding_mode); } /* The normal case - we were and are not denormal, and any @@ -1679,7 +1763,10 @@ IEEEFloat::opStatus IEEEFloat::divideSpecials(const IEEEFloat &rhs) { return opOK; case PackCategoriesIntoKey(fcNormal, fcZero): - category = fcInfinity; + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) + makeNaN(false, sign); + else + category = fcInfinity; return opDivByZero; case PackCategoriesIntoKey(fcInfinity, fcInfinity): @@ -1965,9 +2052,12 @@ IEEEFloat::opStatus IEEEFloat::mod(const IEEEFloat &rhs) { while (isFiniteNonZero() && rhs.isFiniteNonZero() && compareAbsoluteValue(rhs) != cmpLessThan) { - IEEEFloat V = scalbn(rhs, ilogb(*this) - ilogb(rhs), rmNearestTiesToEven); - if (compareAbsoluteValue(V) == cmpLessThan) - V = scalbn(V, -1, rmNearestTiesToEven); + int Exp = ilogb(*this) - ilogb(rhs); + IEEEFloat V = scalbn(rhs, Exp, rmNearestTiesToEven); + // V can overflow to NaN with fltNonfiniteBehavior::NanOnly, so explicitly + // check for it. + if (V.isNaN() || compareAbsoluteValue(V) == cmpLessThan) + V = scalbn(rhs, Exp - 1, rmNearestTiesToEven); V.sign = sign; fs = subtract(V, rmNearestTiesToEven); @@ -2194,6 +2284,7 @@ IEEEFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics, opStatus fs; int shift; const fltSemantics &fromSemantics = *semantics; + bool is_signaling = isSignaling(); lostFraction = lfExactlyZero; newPartCount = partCountForBits(toSemantics.precision + 1); @@ -2235,7 +2326,9 @@ IEEEFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics, } // If this is a truncation, perform the shift before we narrow the storage. - if (shift < 0 && (isFiniteNonZero() || category==fcNaN)) + if (shift < 0 && (isFiniteNonZero() || + (category == fcNaN && semantics->nonFiniteBehavior != + fltNonfiniteBehavior::NanOnly))) lostFraction = shiftRight(significandParts(), oldPartCount, -shift); // Fix the storage so it can hold to new value. @@ -2269,6 +2362,13 @@ IEEEFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics, fs = normalize(rounding_mode, lostFraction); *losesInfo = (fs != opOK); } else if (category == fcNaN) { + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) { + *losesInfo = + fromSemantics.nonFiniteBehavior != fltNonfiniteBehavior::NanOnly; + makeNaN(false, sign); + return is_signaling ? opInvalidOp : opOK; + } + *losesInfo = lostFraction != lfExactlyZero || X86SpecialNan; // For x87 extended precision, we want to make a NaN, not a special NaN if @@ -2279,12 +2379,17 @@ IEEEFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics, // Convert of sNaN creates qNaN and raises an exception (invalid op). // This also guarantees that a sNaN does not become Inf on a truncation // that loses all payload bits. - if (isSignaling()) { + if (is_signaling) { makeQuiet(); fs = opInvalidOp; } else { fs = opOK; } + } else if (category == fcInfinity && + semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) { + makeNaN(false, sign); + *losesInfo = true; + fs = opInexact; } else { *losesInfo = false; fs = opOK; @@ -3382,6 +3487,33 @@ APInt IEEEFloat::convertFloat8E5M2APFloatToAPInt() const { (mysignificand & 0x3))); } +APInt IEEEFloat::convertFloat8E4M3FNAPFloatToAPInt() const { + assert(semantics == (const llvm::fltSemantics *)&semFloat8E4M3FN); + assert(partCount() == 1); + + uint32_t myexponent, mysignificand; + + if (isFiniteNonZero()) { + myexponent = exponent + 7; // bias + mysignificand = (uint32_t)*significandParts(); + if (myexponent == 1 && !(mysignificand & 0x8)) + myexponent = 0; // denormal + } else if (category == fcZero) { + myexponent = 0; + mysignificand = 0; + } else if (category == fcInfinity) { + myexponent = 0xf; + mysignificand = 0; + } else { + assert(category == fcNaN && "Unknown category!"); + myexponent = 0xf; + mysignificand = (uint32_t)*significandParts(); + } + + return APInt(8, (((sign & 1) << 7) | ((myexponent & 0xf) << 3) | + (mysignificand & 0x7))); +} + // This function creates an APInt that is just a bit map of the floating // point constant as it would appear in memory. It is not a conversion, // and treating the result as a normal integer is unlikely to be useful. @@ -3408,6 +3540,9 @@ APInt IEEEFloat::bitcastToAPInt() const { if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2) return convertFloat8E5M2APFloatToAPInt(); + if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FN) + return convertFloat8E4M3FNAPFloatToAPInt(); + assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended && "unknown format!"); return convertF80LongDoubleAPFloatToAPInt(); @@ -3663,10 +3798,33 @@ void IEEEFloat::initFromFloat8E5M2APInt(const APInt &api) { } } -/// Treat api as containing the bits of a floating point number. Currently -/// we infer the floating point type from the size of the APInt. The -/// isIEEE argument distinguishes between PPC128 and IEEE128 (not meaningful -/// when the size is anything else). +void IEEEFloat::initFromFloat8E4M3FNAPInt(const APInt &api) { + uint32_t i = (uint32_t)*api.getRawData(); + uint32_t myexponent = (i >> 3) & 0xf; + uint32_t mysignificand = i & 0x7; + + initialize(&semFloat8E4M3FN); + assert(partCount() == 1); + + sign = i >> 7; + if (myexponent == 0 && mysignificand == 0) { + makeZero(sign); + } else if (myexponent == 0xf && mysignificand == 7) { + category = fcNaN; + exponent = exponentNaN(); + *significandParts() = mysignificand; + } else { + category = fcNormal; + exponent = myexponent - 7; // bias + *significandParts() = mysignificand; + if (myexponent == 0) // denormal + exponent = -6; + else + *significandParts() |= 0x8; // integer bit + } +} + +/// Treat api as containing the bits of a floating point number. void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) { assert(api.getBitWidth() == Sem->sizeInBits); if (Sem == &semIEEEhalf) @@ -3685,6 +3843,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) { return initFromPPCDoubleDoubleAPInt(api); if (Sem == &semFloat8E5M2) return initFromFloat8E5M2APInt(api); + if (Sem == &semFloat8E4M3FN) + return initFromFloat8E4M3FNAPInt(api); llvm_unreachable(nullptr); } @@ -3712,6 +3872,9 @@ void IEEEFloat::makeLargest(bool Negative) { significand[PartCount - 1] = (NumUnusedHighBits < integerPartWidth) ? (~integerPart(0) >> NumUnusedHighBits) : 0; + + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) + significand[0] &= ~integerPart(1); } /// Make this number the smallest magnitude denormal number in the given @@ -4085,6 +4248,8 @@ bool IEEEFloat::getExactInverse(APFloat *inv) const { bool IEEEFloat::isSignaling() const { if (!isNaN()) return false; + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) + return false; // IEEE-754R 2008 6.2.1: A signaling NaN bit string should be encoded with the // first bit of the trailing significand being 0. @@ -4135,12 +4300,18 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) { break; } - // nextUp(getLargest()) == INFINITY if (isLargest() && !isNegative()) { - APInt::tcSet(significandParts(), 0, partCount()); - category = fcInfinity; - exponent = semantics->maxExponent + 1; - break; + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) { + // nextUp(getLargest()) == NAN + makeNaN(); + break; + } else { + // nextUp(getLargest()) == INFINITY + APInt::tcSet(significandParts(), 0, partCount()); + category = fcInfinity; + exponent = semantics->maxExponent + 1; + break; + } } // nextUp(normal) == normal + inc. @@ -4212,6 +4383,8 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) { } APFloatBase::ExponentType IEEEFloat::exponentNaN() const { + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) + return semantics->maxExponent; return semantics->maxExponent + 1; } @@ -4224,6 +4397,11 @@ APFloatBase::ExponentType IEEEFloat::exponentZero() const { } void IEEEFloat::makeInf(bool Negative) { + if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly) { + // There is no Inf, so make NaN instead. + makeNaN(false, Negative); + return; + } category = fcInfinity; sign = Negative; exponent = exponentInf(); @@ -4239,7 +4417,8 @@ void IEEEFloat::makeZero(bool Negative) { void IEEEFloat::makeQuiet() { assert(isNaN()); - APInt::tcSetBit(significandParts(), semantics->precision - 2); + if (semantics->nonFiniteBehavior != fltNonfiniteBehavior::NanOnly) + APInt::tcSetBit(significandParts(), semantics->precision - 2); } int ilogb(const IEEEFloat &Arg) { |