diff options
author | Durgadoss R <durgadossr@nvidia.com> | 2024-10-02 23:04:21 +0530 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-02 23:04:21 +0530 |
commit | 99f527d2807b5a14dc7ee64d15405f09e95ee9f2 (patch) | |
tree | 1e614683c82d2eedfbea8cf3ec142bc05bf045c2 /llvm/lib/Support/APFloat.cpp | |
parent | 5e92bfe97fe0f72f3052df53f813d8dcbb7038d3 (diff) | |
download | llvm-99f527d2807b5a14dc7ee64d15405f09e95ee9f2.zip llvm-99f527d2807b5a14dc7ee64d15405f09e95ee9f2.tar.gz llvm-99f527d2807b5a14dc7ee64d15405f09e95ee9f2.tar.bz2 |
[APFloat] Add APFloat support for E8M0 type (#107127)
This patch adds an APFloat type for unsigned E8M0 format. This format is
used for representing the "scale-format" in the MX specification:
(section 5.4)
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
This format does not support {Inf, denorms, zeroes}. Like FP32, this
format's exponents are 8-bits (all bits here) and the bias value is 127.
However, it differs from IEEE-FP32 in that the minExponent is -127
(instead of -126). There are updates done in the APFloat utility
functions to handle these constraints for this format.
* The bias calculation is different and convertIEEE* APIs are updated to
handle this.
* Since there are no significand bits, the isSignificandAll{Zeroes/Ones}
methods are updated accordingly.
* Although the format does not have any precision, the precision bit in
the fltSemantics is set to 1 for consistency with APFloat's internal
representation.
* Many utility functions are updated to handle the fact that this format
does not support Zero.
* Provide a separate initFromAPInt() implementation to handle the quirks
of the format.
* Add specific tests to verify the range of values for this format.
Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
Diffstat (limited to 'llvm/lib/Support/APFloat.cpp')
-rw-r--r-- | llvm/lib/Support/APFloat.cpp | 198 |
1 files changed, 169 insertions, 29 deletions
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp index dee917f..03413f6 100644 --- a/llvm/lib/Support/APFloat.cpp +++ b/llvm/lib/Support/APFloat.cpp @@ -119,6 +119,13 @@ struct fltSemantics { fltNonfiniteBehavior nonFiniteBehavior = fltNonfiniteBehavior::IEEE754; fltNanEncoding nanEncoding = fltNanEncoding::IEEE; + + /* Whether this semantics has an encoding for Zero */ + bool hasZero = true; + + /* Whether this semantics can represent signed values */ + bool hasSignedRepr = true; + // Returns true if any number described by this semantics can be precisely // represented by the specified semantics. Does not take into account // the value of fltNonfiniteBehavior. @@ -145,6 +152,10 @@ static constexpr fltSemantics semFloat8E4M3B11FNUZ = { 4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero}; static constexpr fltSemantics semFloat8E3M4 = {3, -2, 5, 8}; static constexpr fltSemantics semFloatTF32 = {127, -126, 11, 19}; +static constexpr fltSemantics semFloat8E8M0FNU = { + 127, -127, 1, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes, + false, false}; + static constexpr fltSemantics semFloat6E3M2FN = { 4, -2, 3, 6, fltNonfiniteBehavior::FiniteOnly}; static constexpr fltSemantics semFloat6E2M3FN = { @@ -222,6 +233,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) { return Float8E3M4(); case S_FloatTF32: return FloatTF32(); + case S_Float8E8M0FNU: + return Float8E8M0FNU(); case S_Float6E3M2FN: return Float6E3M2FN(); case S_Float6E2M3FN: @@ -264,6 +277,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) { return S_Float8E3M4; else if (&Sem == &llvm::APFloat::FloatTF32()) return S_FloatTF32; + else if (&Sem == &llvm::APFloat::Float8E8M0FNU()) + return S_Float8E8M0FNU; else if (&Sem == &llvm::APFloat::Float6E3M2FN()) return S_Float6E3M2FN; else if (&Sem == &llvm::APFloat::Float6E2M3FN()) @@ -294,6 +309,7 @@ const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() { } const fltSemantics &APFloatBase::Float8E3M4() { return semFloat8E3M4; } const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; } +const fltSemantics &APFloatBase::Float8E8M0FNU() { return semFloat8E8M0FNU; } const fltSemantics &APFloatBase::Float6E3M2FN() { return semFloat6E3M2FN; } const fltSemantics &APFloatBase::Float6E2M3FN() { return semFloat6E2M3FN; } const fltSemantics &APFloatBase::Float4E2M1FN() { return semFloat4E2M1FN; } @@ -396,7 +412,8 @@ static inline Error createError(const Twine &Err) { } static constexpr inline unsigned int partCountForBits(unsigned int bits) { - return ((bits) + APFloatBase::integerPartWidth - 1) / APFloatBase::integerPartWidth; + return std::max(1u, (bits + APFloatBase::integerPartWidth - 1) / + APFloatBase::integerPartWidth); } /* Returns 0U-9U. Return values >= 10U are not digits. */ @@ -918,6 +935,10 @@ void IEEEFloat::makeNaN(bool SNaN, bool Negative, const APInt *fill) { if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly) llvm_unreachable("This floating point format does not support NaN"); + if (Negative && !semantics->hasSignedRepr) + llvm_unreachable( + "This floating point format does not support signed values"); + category = fcNaN; sign = Negative; exponent = exponentNaN(); @@ -955,7 +976,8 @@ void IEEEFloat::makeNaN(bool SNaN, bool Negative, const APInt *fill) { significand[part] = 0; } - unsigned QNaNBit = semantics->precision - 2; + unsigned QNaNBit = + (semantics->precision >= 2) ? (semantics->precision - 2) : 0; if (SNaN) { // We always have to clear the QNaN bit to make it an SNaN. @@ -1025,6 +1047,19 @@ bool IEEEFloat::isSmallestNormalized() const { isSignificandAllZerosExceptMSB(); } +unsigned int IEEEFloat::getNumHighBits() const { + const unsigned int PartCount = partCountForBits(semantics->precision); + const unsigned int Bits = PartCount * integerPartWidth; + + // Compute how many bits are used in the final word. + // When precision is just 1, it represents the 'Pth' + // Precision bit and not the actual significand bit. + const unsigned int NumHighBits = (semantics->precision > 1) + ? (Bits - semantics->precision + 1) + : (Bits - semantics->precision); + return NumHighBits; +} + bool IEEEFloat::isSignificandAllOnes() const { // Test if the significand excluding the integral bit is all ones. This allows // us to test for binade boundaries. @@ -1035,13 +1070,12 @@ bool IEEEFloat::isSignificandAllOnes() const { return false; // Set the unused high bits to all ones when we compare. - const unsigned NumHighBits = - PartCount*integerPartWidth - semantics->precision + 1; + const unsigned NumHighBits = getNumHighBits(); assert(NumHighBits <= integerPartWidth && NumHighBits > 0 && "Can not have more high bits to fill than integerPartWidth"); const integerPart HighBitFill = ~integerPart(0) << (integerPartWidth - NumHighBits); - if (~(Parts[PartCount - 1] | HighBitFill)) + if ((semantics->precision <= 1) || (~(Parts[PartCount - 1] | HighBitFill))) return false; return true; @@ -1062,8 +1096,7 @@ bool IEEEFloat::isSignificandAllOnesExceptLSB() const { } // Set the unused high bits to all ones when we compare. - const unsigned NumHighBits = - PartCount * integerPartWidth - semantics->precision + 1; + const unsigned NumHighBits = getNumHighBits(); assert(NumHighBits <= integerPartWidth && NumHighBits > 0 && "Can not have more high bits to fill than integerPartWidth"); const integerPart HighBitFill = ~integerPart(0) @@ -1085,13 +1118,12 @@ bool IEEEFloat::isSignificandAllZeros() const { return false; // Compute how many bits are used in the final word. - const unsigned NumHighBits = - PartCount*integerPartWidth - semantics->precision + 1; + const unsigned NumHighBits = getNumHighBits(); assert(NumHighBits < integerPartWidth && "Can not have more high bits to " "clear than integerPartWidth"); const integerPart HighBitMask = ~integerPart(0) >> NumHighBits; - if (Parts[PartCount - 1] & HighBitMask) + if ((semantics->precision > 1) && (Parts[PartCount - 1] & HighBitMask)) return false; return true; @@ -1106,25 +1138,26 @@ bool IEEEFloat::isSignificandAllZerosExceptMSB() const { return false; } - const unsigned NumHighBits = - PartCount * integerPartWidth - semantics->precision + 1; - return Parts[PartCount - 1] == integerPart(1) - << (integerPartWidth - NumHighBits); + const unsigned NumHighBits = getNumHighBits(); + const integerPart MSBMask = integerPart(1) + << (integerPartWidth - NumHighBits); + return ((semantics->precision <= 1) || (Parts[PartCount - 1] == MSBMask)); } bool IEEEFloat::isLargest() const { + bool IsMaxExp = isFiniteNonZero() && exponent == semantics->maxExponent; if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly && semantics->nanEncoding == fltNanEncoding::AllOnes) { // The largest number by magnitude in our format will be the floating point // number with maximum exponent and with significand that is all ones except // the LSB. - return isFiniteNonZero() && exponent == semantics->maxExponent && - isSignificandAllOnesExceptLSB(); + return (IsMaxExp && APFloat::hasSignificand(*semantics)) + ? isSignificandAllOnesExceptLSB() + : IsMaxExp; } else { // The largest number by magnitude in our format will be the floating point // number with maximum exponent and with significand that is all ones. - return isFiniteNonZero() && exponent == semantics->maxExponent && - isSignificandAllOnes(); + return IsMaxExp && isSignificandAllOnes(); } } @@ -1165,7 +1198,13 @@ IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics, integerPart value) { IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics) { initialize(&ourSemantics); - makeZero(false); + // The Float8E8MOFNU format does not have a representation + // for zero. So, use the closest representation instead. + // Moreover, the all-zero encoding represents a valid + // normal value (which is the smallestNormalized here). + // Hence, we call makeSmallestNormalized (where category is + // 'fcNormal') instead of makeZero (where category is 'fcZero'). + ourSemantics.hasZero ? makeZero(false) : makeSmallestNormalized(false); } // Delegate to the previous constructor, because later copy constructor may @@ -1245,7 +1284,8 @@ IEEEFloat::integerPart IEEEFloat::subtractSignificand(const IEEEFloat &rhs, on to the full-precision result of the multiplication. Returns the lost fraction. */ lostFraction IEEEFloat::multiplySignificand(const IEEEFloat &rhs, - IEEEFloat addend) { + IEEEFloat addend, + bool ignoreAddend) { unsigned int omsb; // One, not zero, based MSB. unsigned int partsCount, newPartsCount, precision; integerPart *lhsSignificand; @@ -1289,7 +1329,7 @@ lostFraction IEEEFloat::multiplySignificand(const IEEEFloat &rhs, // toward left by two bits, and adjust exponent accordingly. exponent += 2; - if (addend.isNonZero()) { + if (!ignoreAddend && addend.isNonZero()) { // The intermediate result of the multiplication has "2 * precision" // signicant bit; adjust the addend to be consistent with mul result. // @@ -1377,7 +1417,12 @@ lostFraction IEEEFloat::multiplySignificand(const IEEEFloat &rhs, } lostFraction IEEEFloat::multiplySignificand(const IEEEFloat &rhs) { - return multiplySignificand(rhs, IEEEFloat(*semantics)); + // When the given semantics has zero, the addend here is a zero. + // i.e . it belongs to the 'fcZero' category. + // But when the semantics does not support zero, we need to + // explicitly convey that this addend should be ignored + // for multiplication. + return multiplySignificand(rhs, IEEEFloat(*semantics), !semantics->hasZero); } /* Multiply the significands of LHS and RHS to DST. */ @@ -1483,7 +1528,8 @@ lostFraction IEEEFloat::shiftSignificandRight(unsigned int bits) { /* Shift the significand left BITS bits, subtract BITS from its exponent. */ void IEEEFloat::shiftSignificandLeft(unsigned int bits) { - assert(bits < semantics->precision); + assert(bits < semantics->precision || + (semantics->precision == 1 && bits <= 1)); if (bits) { unsigned int partsCount = partCount(); @@ -1678,6 +1724,8 @@ IEEEFloat::opStatus IEEEFloat::normalize(roundingMode rounding_mode, category = fcZero; if (semantics->nanEncoding == fltNanEncoding::NegativeZero) sign = false; + if (!semantics->hasZero) + makeSmallestNormalized(false); } return opOK; @@ -1729,6 +1777,11 @@ IEEEFloat::opStatus IEEEFloat::normalize(roundingMode rounding_mode, category = fcZero; if (semantics->nanEncoding == fltNanEncoding::NegativeZero) sign = false; + // This condition handles the case where the semantics + // does not have zero but uses the all-zero encoding + // to represent the smallest normal value. + if (!semantics->hasZero) + makeSmallestNormalized(false); } /* The fcZero case is a denormal that underflowed to zero. */ @@ -1807,6 +1860,10 @@ lostFraction IEEEFloat::addOrSubtractSignificand(const IEEEFloat &rhs, /* Subtraction is more subtle than one might naively expect. */ if (subtract) { + if ((bits < 0) && !semantics->hasSignedRepr) + llvm_unreachable( + "This floating point format does not support signed values"); + IEEEFloat temp_rhs(rhs); if (bits == 0) @@ -2252,6 +2309,17 @@ IEEEFloat::opStatus IEEEFloat::mod(const IEEEFloat &rhs) { V.sign = sign; fs = subtract(V, rmNearestTiesToEven); + + // When the semantics supports zero, this loop's + // exit-condition is handled by the 'isFiniteNonZero' + // category check above. However, when the semantics + // does not have 'fcZero' and we have reached the + // minimum possible value, (and any further subtract + // will underflow to the same value) explicitly + // provide an exit-path here. + if (!semantics->hasZero && this->isSmallest()) + break; + assert(fs==opOK); } if (isZero()) { @@ -2606,6 +2674,8 @@ IEEEFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics, fs = opOK; } + if (category == fcZero && !semantics->hasZero) + makeSmallestNormalized(false); return fs; } @@ -3070,6 +3140,8 @@ IEEEFloat::convertFromDecimalString(StringRef str, roundingMode rounding_mode) { fs = opOK; if (semantics->nanEncoding == fltNanEncoding::NegativeZero) sign = false; + if (!semantics->hasZero) + makeSmallestNormalized(false); /* Check whether the normalized exponent is high enough to overflow max during the log-rebasing in the max-exponent check below. */ @@ -3237,6 +3309,10 @@ IEEEFloat::convertFromString(StringRef str, roundingMode rounding_mode) { StringRef::iterator p = str.begin(); size_t slen = str.size(); sign = *p == '-' ? 1 : 0; + if (sign && !semantics->hasSignedRepr) + llvm_unreachable( + "This floating point format does not support signed values"); + if (*p == '-' || *p == '+') { p++; slen--; @@ -3533,15 +3609,16 @@ APInt IEEEFloat::convertPPCDoubleDoubleAPFloatToAPInt() const { template <const fltSemantics &S> APInt IEEEFloat::convertIEEEFloatToAPInt() const { assert(semantics == &S); - - constexpr int bias = -(S.minExponent - 1); + const int bias = + (semantics == &semFloat8E8M0FNU) ? -S.minExponent : -(S.minExponent - 1); constexpr unsigned int trailing_significand_bits = S.precision - 1; constexpr int integer_bit_part = trailing_significand_bits / integerPartWidth; constexpr integerPart integer_bit = integerPart{1} << (trailing_significand_bits % integerPartWidth); constexpr uint64_t significand_mask = integer_bit - 1; constexpr unsigned int exponent_bits = - S.sizeInBits - 1 - trailing_significand_bits; + trailing_significand_bits ? (S.sizeInBits - 1 - trailing_significand_bits) + : S.sizeInBits; static_assert(exponent_bits < 64); constexpr uint64_t exponent_mask = (uint64_t{1} << exponent_bits) - 1; @@ -3557,6 +3634,8 @@ APInt IEEEFloat::convertIEEEFloatToAPInt() const { !(significandParts()[integer_bit_part] & integer_bit)) myexponent = 0; // denormal } else if (category == fcZero) { + if (!S.hasZero) + llvm_unreachable("semantics does not support zero!"); myexponent = ::exponentZero(S) + bias; mysignificand.fill(0); } else if (category == fcInfinity) { @@ -3659,6 +3738,11 @@ APInt IEEEFloat::convertFloatTF32APFloatToAPInt() const { return convertIEEEFloatToAPInt<semFloatTF32>(); } +APInt IEEEFloat::convertFloat8E8M0FNUAPFloatToAPInt() const { + assert(partCount() == 1); + return convertIEEEFloatToAPInt<semFloat8E8M0FNU>(); +} + APInt IEEEFloat::convertFloat6E3M2FNAPFloatToAPInt() const { assert(partCount() == 1); return convertIEEEFloatToAPInt<semFloat6E3M2FN>(); @@ -3721,6 +3805,9 @@ APInt IEEEFloat::bitcastToAPInt() const { if (semantics == (const llvm::fltSemantics *)&semFloatTF32) return convertFloatTF32APFloatToAPInt(); + if (semantics == (const llvm::fltSemantics *)&semFloat8E8M0FNU) + return convertFloat8E8M0FNUAPFloatToAPInt(); + if (semantics == (const llvm::fltSemantics *)&semFloat6E3M2FN) return convertFloat6E3M2FNAPFloatToAPInt(); @@ -3819,6 +3906,40 @@ void IEEEFloat::initFromPPCDoubleDoubleAPInt(const APInt &api) { } } +// The E8M0 format has the following characteristics: +// It is an 8-bit unsigned format with only exponents (no actual significand). +// No encodings for {zero, infinities or denorms}. +// NaN is represented by all 1's. +// Bias is 127. +void IEEEFloat::initFromFloat8E8M0FNUAPInt(const APInt &api) { + const uint64_t exponent_mask = 0xff; + uint64_t val = api.getRawData()[0]; + uint64_t myexponent = (val & exponent_mask); + + initialize(&semFloat8E8M0FNU); + assert(partCount() == 1); + + // This format has unsigned representation only + sign = 0; + + // Set the significand + // This format does not have any significand but the 'Pth' precision bit is + // always set to 1 for consistency in APFloat's internal representation. + uint64_t mysignificand = 1; + significandParts()[0] = mysignificand; + + // This format can either have a NaN or fcNormal + // All 1's i.e. 255 is a NaN + if (val == exponent_mask) { + category = fcNaN; + exponent = exponentNaN(); + return; + } + // Handle fcNormal... + category = fcNormal; + exponent = myexponent - 127; // 127 is bias + return; +} template <const fltSemantics &S> void IEEEFloat::initFromIEEEAPInt(const APInt &api) { assert(api.getBitWidth() == S.sizeInBits); @@ -3999,6 +4120,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) { return initFromFloat8E3M4APInt(api); if (Sem == &semFloatTF32) return initFromFloatTF32APInt(api); + if (Sem == &semFloat8E8M0FNU) + return initFromFloat8E8M0FNUAPInt(api); if (Sem == &semFloat6E3M2FN) return initFromFloat6E3M2FNAPInt(api); if (Sem == &semFloat6E2M3FN) @@ -4012,6 +4135,9 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) { /// Make this number the largest magnitude normal number in the given /// semantics. void IEEEFloat::makeLargest(bool Negative) { + if (Negative && !semantics->hasSignedRepr) + llvm_unreachable( + "This floating point format does not support signed values"); // We want (in interchange format): // sign = {Negative} // exponent = 1..10 @@ -4032,15 +4158,18 @@ void IEEEFloat::makeLargest(bool Negative) { significand[PartCount - 1] = (NumUnusedHighBits < integerPartWidth) ? (~integerPart(0) >> NumUnusedHighBits) : 0; - if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly && - semantics->nanEncoding == fltNanEncoding::AllOnes) + semantics->nanEncoding == fltNanEncoding::AllOnes && + (semantics->precision > 1)) significand[0] &= ~integerPart(1); } /// Make this number the smallest magnitude denormal number in the given /// semantics. void IEEEFloat::makeSmallest(bool Negative) { + if (Negative && !semantics->hasSignedRepr) + llvm_unreachable( + "This floating point format does not support signed values"); // We want (in interchange format): // sign = {Negative} // exponent = 0..0 @@ -4052,6 +4181,9 @@ void IEEEFloat::makeSmallest(bool Negative) { } void IEEEFloat::makeSmallestNormalized(bool Negative) { + if (Negative && !semantics->hasSignedRepr) + llvm_unreachable( + "This floating point format does not support signed values"); // We want (in interchange format): // sign = {Negative} // exponent = 0..0 @@ -4509,6 +4641,8 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) { exponent = 0; if (semantics->nanEncoding == fltNanEncoding::NegativeZero) sign = false; + if (!semantics->hasZero) + makeSmallestNormalized(false); break; } @@ -4574,7 +4708,10 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) { // the integral bit to 1, and increment the exponent. If we have a // denormal always increment since moving denormals and the numbers in the // smallest normal binade have the same exponent in our representation. - bool WillCrossBinadeBoundary = !isDenormal() && isSignificandAllOnes(); + // If there are only exponents, any increment always crosses the + // BinadeBoundary. + bool WillCrossBinadeBoundary = !APFloat::hasSignificand(*semantics) || + (!isDenormal() && isSignificandAllOnes()); if (WillCrossBinadeBoundary) { integerPart *Parts = significandParts(); @@ -4626,6 +4763,9 @@ void IEEEFloat::makeInf(bool Negative) { } void IEEEFloat::makeZero(bool Negative) { + if (!semantics->hasZero) + llvm_unreachable("This floating point format does not support Zero"); + category = fcZero; sign = Negative; if (semantics->nanEncoding == fltNanEncoding::NegativeZero) { |