aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Support/APFloat.cpp
diff options
context:
space:
mode:
authorDurgadoss R <durgadossr@nvidia.com>2024-10-02 23:04:21 +0530
committerGitHub <noreply@github.com>2024-10-02 23:04:21 +0530
commit99f527d2807b5a14dc7ee64d15405f09e95ee9f2 (patch)
tree1e614683c82d2eedfbea8cf3ec142bc05bf045c2 /llvm/lib/Support/APFloat.cpp
parent5e92bfe97fe0f72f3052df53f813d8dcbb7038d3 (diff)
downloadllvm-99f527d2807b5a14dc7ee64d15405f09e95ee9f2.zip
llvm-99f527d2807b5a14dc7ee64d15405f09e95ee9f2.tar.gz
llvm-99f527d2807b5a14dc7ee64d15405f09e95ee9f2.tar.bz2
[APFloat] Add APFloat support for E8M0 type (#107127)
This patch adds an APFloat type for unsigned E8M0 format. This format is used for representing the "scale-format" in the MX specification: (section 5.4) https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf This format does not support {Inf, denorms, zeroes}. Like FP32, this format's exponents are 8-bits (all bits here) and the bias value is 127. However, it differs from IEEE-FP32 in that the minExponent is -127 (instead of -126). There are updates done in the APFloat utility functions to handle these constraints for this format. * The bias calculation is different and convertIEEE* APIs are updated to handle this. * Since there are no significand bits, the isSignificandAll{Zeroes/Ones} methods are updated accordingly. * Although the format does not have any precision, the precision bit in the fltSemantics is set to 1 for consistency with APFloat's internal representation. * Many utility functions are updated to handle the fact that this format does not support Zero. * Provide a separate initFromAPInt() implementation to handle the quirks of the format. * Add specific tests to verify the range of values for this format. Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
Diffstat (limited to 'llvm/lib/Support/APFloat.cpp')
-rw-r--r--llvm/lib/Support/APFloat.cpp198
1 files changed, 169 insertions, 29 deletions
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index dee917f..03413f6 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -119,6 +119,13 @@ struct fltSemantics {
fltNonfiniteBehavior nonFiniteBehavior = fltNonfiniteBehavior::IEEE754;
fltNanEncoding nanEncoding = fltNanEncoding::IEEE;
+
+ /* Whether this semantics has an encoding for Zero */
+ bool hasZero = true;
+
+ /* Whether this semantics can represent signed values */
+ bool hasSignedRepr = true;
+
// Returns true if any number described by this semantics can be precisely
// represented by the specified semantics. Does not take into account
// the value of fltNonfiniteBehavior.
@@ -145,6 +152,10 @@ static constexpr fltSemantics semFloat8E4M3B11FNUZ = {
4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E3M4 = {3, -2, 5, 8};
static constexpr fltSemantics semFloatTF32 = {127, -126, 11, 19};
+static constexpr fltSemantics semFloat8E8M0FNU = {
+ 127, -127, 1, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes,
+ false, false};
+
static constexpr fltSemantics semFloat6E3M2FN = {
4, -2, 3, 6, fltNonfiniteBehavior::FiniteOnly};
static constexpr fltSemantics semFloat6E2M3FN = {
@@ -222,6 +233,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) {
return Float8E3M4();
case S_FloatTF32:
return FloatTF32();
+ case S_Float8E8M0FNU:
+ return Float8E8M0FNU();
case S_Float6E3M2FN:
return Float6E3M2FN();
case S_Float6E2M3FN:
@@ -264,6 +277,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
return S_Float8E3M4;
else if (&Sem == &llvm::APFloat::FloatTF32())
return S_FloatTF32;
+ else if (&Sem == &llvm::APFloat::Float8E8M0FNU())
+ return S_Float8E8M0FNU;
else if (&Sem == &llvm::APFloat::Float6E3M2FN())
return S_Float6E3M2FN;
else if (&Sem == &llvm::APFloat::Float6E2M3FN())
@@ -294,6 +309,7 @@ const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
}
const fltSemantics &APFloatBase::Float8E3M4() { return semFloat8E3M4; }
const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; }
+const fltSemantics &APFloatBase::Float8E8M0FNU() { return semFloat8E8M0FNU; }
const fltSemantics &APFloatBase::Float6E3M2FN() { return semFloat6E3M2FN; }
const fltSemantics &APFloatBase::Float6E2M3FN() { return semFloat6E2M3FN; }
const fltSemantics &APFloatBase::Float4E2M1FN() { return semFloat4E2M1FN; }
@@ -396,7 +412,8 @@ static inline Error createError(const Twine &Err) {
}
static constexpr inline unsigned int partCountForBits(unsigned int bits) {
- return ((bits) + APFloatBase::integerPartWidth - 1) / APFloatBase::integerPartWidth;
+ return std::max(1u, (bits + APFloatBase::integerPartWidth - 1) /
+ APFloatBase::integerPartWidth);
}
/* Returns 0U-9U. Return values >= 10U are not digits. */
@@ -918,6 +935,10 @@ void IEEEFloat::makeNaN(bool SNaN, bool Negative, const APInt *fill) {
if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly)
llvm_unreachable("This floating point format does not support NaN");
+ if (Negative && !semantics->hasSignedRepr)
+ llvm_unreachable(
+ "This floating point format does not support signed values");
+
category = fcNaN;
sign = Negative;
exponent = exponentNaN();
@@ -955,7 +976,8 @@ void IEEEFloat::makeNaN(bool SNaN, bool Negative, const APInt *fill) {
significand[part] = 0;
}
- unsigned QNaNBit = semantics->precision - 2;
+ unsigned QNaNBit =
+ (semantics->precision >= 2) ? (semantics->precision - 2) : 0;
if (SNaN) {
// We always have to clear the QNaN bit to make it an SNaN.
@@ -1025,6 +1047,19 @@ bool IEEEFloat::isSmallestNormalized() const {
isSignificandAllZerosExceptMSB();
}
+unsigned int IEEEFloat::getNumHighBits() const {
+ const unsigned int PartCount = partCountForBits(semantics->precision);
+ const unsigned int Bits = PartCount * integerPartWidth;
+
+ // Compute how many bits are used in the final word.
+ // When precision is just 1, it represents the 'Pth'
+ // Precision bit and not the actual significand bit.
+ const unsigned int NumHighBits = (semantics->precision > 1)
+ ? (Bits - semantics->precision + 1)
+ : (Bits - semantics->precision);
+ return NumHighBits;
+}
+
bool IEEEFloat::isSignificandAllOnes() const {
// Test if the significand excluding the integral bit is all ones. This allows
// us to test for binade boundaries.
@@ -1035,13 +1070,12 @@ bool IEEEFloat::isSignificandAllOnes() const {
return false;
// Set the unused high bits to all ones when we compare.
- const unsigned NumHighBits =
- PartCount*integerPartWidth - semantics->precision + 1;
+ const unsigned NumHighBits = getNumHighBits();
assert(NumHighBits <= integerPartWidth && NumHighBits > 0 &&
"Can not have more high bits to fill than integerPartWidth");
const integerPart HighBitFill =
~integerPart(0) << (integerPartWidth - NumHighBits);
- if (~(Parts[PartCount - 1] | HighBitFill))
+ if ((semantics->precision <= 1) || (~(Parts[PartCount - 1] | HighBitFill)))
return false;
return true;
@@ -1062,8 +1096,7 @@ bool IEEEFloat::isSignificandAllOnesExceptLSB() const {
}
// Set the unused high bits to all ones when we compare.
- const unsigned NumHighBits =
- PartCount * integerPartWidth - semantics->precision + 1;
+ const unsigned NumHighBits = getNumHighBits();
assert(NumHighBits <= integerPartWidth && NumHighBits > 0 &&
"Can not have more high bits to fill than integerPartWidth");
const integerPart HighBitFill = ~integerPart(0)
@@ -1085,13 +1118,12 @@ bool IEEEFloat::isSignificandAllZeros() const {
return false;
// Compute how many bits are used in the final word.
- const unsigned NumHighBits =
- PartCount*integerPartWidth - semantics->precision + 1;
+ const unsigned NumHighBits = getNumHighBits();
assert(NumHighBits < integerPartWidth && "Can not have more high bits to "
"clear than integerPartWidth");
const integerPart HighBitMask = ~integerPart(0) >> NumHighBits;
- if (Parts[PartCount - 1] & HighBitMask)
+ if ((semantics->precision > 1) && (Parts[PartCount - 1] & HighBitMask))
return false;
return true;
@@ -1106,25 +1138,26 @@ bool IEEEFloat::isSignificandAllZerosExceptMSB() const {
return false;
}
- const unsigned NumHighBits =
- PartCount * integerPartWidth - semantics->precision + 1;
- return Parts[PartCount - 1] == integerPart(1)
- << (integerPartWidth - NumHighBits);
+ const unsigned NumHighBits = getNumHighBits();
+ const integerPart MSBMask = integerPart(1)
+ << (integerPartWidth - NumHighBits);
+ return ((semantics->precision <= 1) || (Parts[PartCount - 1] == MSBMask));
}
bool IEEEFloat::isLargest() const {
+ bool IsMaxExp = isFiniteNonZero() && exponent == semantics->maxExponent;
if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly &&
semantics->nanEncoding == fltNanEncoding::AllOnes) {
// The largest number by magnitude in our format will be the floating point
// number with maximum exponent and with significand that is all ones except
// the LSB.
- return isFiniteNonZero() && exponent == semantics->maxExponent &&
- isSignificandAllOnesExceptLSB();
+ return (IsMaxExp && APFloat::hasSignificand(*semantics))
+ ? isSignificandAllOnesExceptLSB()
+ : IsMaxExp;
} else {
// The largest number by magnitude in our format will be the floating point
// number with maximum exponent and with significand that is all ones.
- return isFiniteNonZero() && exponent == semantics->maxExponent &&
- isSignificandAllOnes();
+ return IsMaxExp && isSignificandAllOnes();
}
}
@@ -1165,7 +1198,13 @@ IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics, integerPart value) {
IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics) {
initialize(&ourSemantics);
- makeZero(false);
+ // The Float8E8MOFNU format does not have a representation
+ // for zero. So, use the closest representation instead.
+ // Moreover, the all-zero encoding represents a valid
+ // normal value (which is the smallestNormalized here).
+ // Hence, we call makeSmallestNormalized (where category is
+ // 'fcNormal') instead of makeZero (where category is 'fcZero').
+ ourSemantics.hasZero ? makeZero(false) : makeSmallestNormalized(false);
}
// Delegate to the previous constructor, because later copy constructor may
@@ -1245,7 +1284,8 @@ IEEEFloat::integerPart IEEEFloat::subtractSignificand(const IEEEFloat &rhs,
on to the full-precision result of the multiplication. Returns the
lost fraction. */
lostFraction IEEEFloat::multiplySignificand(const IEEEFloat &rhs,
- IEEEFloat addend) {
+ IEEEFloat addend,
+ bool ignoreAddend) {
unsigned int omsb; // One, not zero, based MSB.
unsigned int partsCount, newPartsCount, precision;
integerPart *lhsSignificand;
@@ -1289,7 +1329,7 @@ lostFraction IEEEFloat::multiplySignificand(const IEEEFloat &rhs,
// toward left by two bits, and adjust exponent accordingly.
exponent += 2;
- if (addend.isNonZero()) {
+ if (!ignoreAddend && addend.isNonZero()) {
// The intermediate result of the multiplication has "2 * precision"
// signicant bit; adjust the addend to be consistent with mul result.
//
@@ -1377,7 +1417,12 @@ lostFraction IEEEFloat::multiplySignificand(const IEEEFloat &rhs,
}
lostFraction IEEEFloat::multiplySignificand(const IEEEFloat &rhs) {
- return multiplySignificand(rhs, IEEEFloat(*semantics));
+ // When the given semantics has zero, the addend here is a zero.
+ // i.e . it belongs to the 'fcZero' category.
+ // But when the semantics does not support zero, we need to
+ // explicitly convey that this addend should be ignored
+ // for multiplication.
+ return multiplySignificand(rhs, IEEEFloat(*semantics), !semantics->hasZero);
}
/* Multiply the significands of LHS and RHS to DST. */
@@ -1483,7 +1528,8 @@ lostFraction IEEEFloat::shiftSignificandRight(unsigned int bits) {
/* Shift the significand left BITS bits, subtract BITS from its exponent. */
void IEEEFloat::shiftSignificandLeft(unsigned int bits) {
- assert(bits < semantics->precision);
+ assert(bits < semantics->precision ||
+ (semantics->precision == 1 && bits <= 1));
if (bits) {
unsigned int partsCount = partCount();
@@ -1678,6 +1724,8 @@ IEEEFloat::opStatus IEEEFloat::normalize(roundingMode rounding_mode,
category = fcZero;
if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
sign = false;
+ if (!semantics->hasZero)
+ makeSmallestNormalized(false);
}
return opOK;
@@ -1729,6 +1777,11 @@ IEEEFloat::opStatus IEEEFloat::normalize(roundingMode rounding_mode,
category = fcZero;
if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
sign = false;
+ // This condition handles the case where the semantics
+ // does not have zero but uses the all-zero encoding
+ // to represent the smallest normal value.
+ if (!semantics->hasZero)
+ makeSmallestNormalized(false);
}
/* The fcZero case is a denormal that underflowed to zero. */
@@ -1807,6 +1860,10 @@ lostFraction IEEEFloat::addOrSubtractSignificand(const IEEEFloat &rhs,
/* Subtraction is more subtle than one might naively expect. */
if (subtract) {
+ if ((bits < 0) && !semantics->hasSignedRepr)
+ llvm_unreachable(
+ "This floating point format does not support signed values");
+
IEEEFloat temp_rhs(rhs);
if (bits == 0)
@@ -2252,6 +2309,17 @@ IEEEFloat::opStatus IEEEFloat::mod(const IEEEFloat &rhs) {
V.sign = sign;
fs = subtract(V, rmNearestTiesToEven);
+
+ // When the semantics supports zero, this loop's
+ // exit-condition is handled by the 'isFiniteNonZero'
+ // category check above. However, when the semantics
+ // does not have 'fcZero' and we have reached the
+ // minimum possible value, (and any further subtract
+ // will underflow to the same value) explicitly
+ // provide an exit-path here.
+ if (!semantics->hasZero && this->isSmallest())
+ break;
+
assert(fs==opOK);
}
if (isZero()) {
@@ -2606,6 +2674,8 @@ IEEEFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics,
fs = opOK;
}
+ if (category == fcZero && !semantics->hasZero)
+ makeSmallestNormalized(false);
return fs;
}
@@ -3070,6 +3140,8 @@ IEEEFloat::convertFromDecimalString(StringRef str, roundingMode rounding_mode) {
fs = opOK;
if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
sign = false;
+ if (!semantics->hasZero)
+ makeSmallestNormalized(false);
/* Check whether the normalized exponent is high enough to overflow
max during the log-rebasing in the max-exponent check below. */
@@ -3237,6 +3309,10 @@ IEEEFloat::convertFromString(StringRef str, roundingMode rounding_mode) {
StringRef::iterator p = str.begin();
size_t slen = str.size();
sign = *p == '-' ? 1 : 0;
+ if (sign && !semantics->hasSignedRepr)
+ llvm_unreachable(
+ "This floating point format does not support signed values");
+
if (*p == '-' || *p == '+') {
p++;
slen--;
@@ -3533,15 +3609,16 @@ APInt IEEEFloat::convertPPCDoubleDoubleAPFloatToAPInt() const {
template <const fltSemantics &S>
APInt IEEEFloat::convertIEEEFloatToAPInt() const {
assert(semantics == &S);
-
- constexpr int bias = -(S.minExponent - 1);
+ const int bias =
+ (semantics == &semFloat8E8M0FNU) ? -S.minExponent : -(S.minExponent - 1);
constexpr unsigned int trailing_significand_bits = S.precision - 1;
constexpr int integer_bit_part = trailing_significand_bits / integerPartWidth;
constexpr integerPart integer_bit =
integerPart{1} << (trailing_significand_bits % integerPartWidth);
constexpr uint64_t significand_mask = integer_bit - 1;
constexpr unsigned int exponent_bits =
- S.sizeInBits - 1 - trailing_significand_bits;
+ trailing_significand_bits ? (S.sizeInBits - 1 - trailing_significand_bits)
+ : S.sizeInBits;
static_assert(exponent_bits < 64);
constexpr uint64_t exponent_mask = (uint64_t{1} << exponent_bits) - 1;
@@ -3557,6 +3634,8 @@ APInt IEEEFloat::convertIEEEFloatToAPInt() const {
!(significandParts()[integer_bit_part] & integer_bit))
myexponent = 0; // denormal
} else if (category == fcZero) {
+ if (!S.hasZero)
+ llvm_unreachable("semantics does not support zero!");
myexponent = ::exponentZero(S) + bias;
mysignificand.fill(0);
} else if (category == fcInfinity) {
@@ -3659,6 +3738,11 @@ APInt IEEEFloat::convertFloatTF32APFloatToAPInt() const {
return convertIEEEFloatToAPInt<semFloatTF32>();
}
+APInt IEEEFloat::convertFloat8E8M0FNUAPFloatToAPInt() const {
+ assert(partCount() == 1);
+ return convertIEEEFloatToAPInt<semFloat8E8M0FNU>();
+}
+
APInt IEEEFloat::convertFloat6E3M2FNAPFloatToAPInt() const {
assert(partCount() == 1);
return convertIEEEFloatToAPInt<semFloat6E3M2FN>();
@@ -3721,6 +3805,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
if (semantics == (const llvm::fltSemantics *)&semFloatTF32)
return convertFloatTF32APFloatToAPInt();
+ if (semantics == (const llvm::fltSemantics *)&semFloat8E8M0FNU)
+ return convertFloat8E8M0FNUAPFloatToAPInt();
+
if (semantics == (const llvm::fltSemantics *)&semFloat6E3M2FN)
return convertFloat6E3M2FNAPFloatToAPInt();
@@ -3819,6 +3906,40 @@ void IEEEFloat::initFromPPCDoubleDoubleAPInt(const APInt &api) {
}
}
+// The E8M0 format has the following characteristics:
+// It is an 8-bit unsigned format with only exponents (no actual significand).
+// No encodings for {zero, infinities or denorms}.
+// NaN is represented by all 1's.
+// Bias is 127.
+void IEEEFloat::initFromFloat8E8M0FNUAPInt(const APInt &api) {
+ const uint64_t exponent_mask = 0xff;
+ uint64_t val = api.getRawData()[0];
+ uint64_t myexponent = (val & exponent_mask);
+
+ initialize(&semFloat8E8M0FNU);
+ assert(partCount() == 1);
+
+ // This format has unsigned representation only
+ sign = 0;
+
+ // Set the significand
+ // This format does not have any significand but the 'Pth' precision bit is
+ // always set to 1 for consistency in APFloat's internal representation.
+ uint64_t mysignificand = 1;
+ significandParts()[0] = mysignificand;
+
+ // This format can either have a NaN or fcNormal
+ // All 1's i.e. 255 is a NaN
+ if (val == exponent_mask) {
+ category = fcNaN;
+ exponent = exponentNaN();
+ return;
+ }
+ // Handle fcNormal...
+ category = fcNormal;
+ exponent = myexponent - 127; // 127 is bias
+ return;
+}
template <const fltSemantics &S>
void IEEEFloat::initFromIEEEAPInt(const APInt &api) {
assert(api.getBitWidth() == S.sizeInBits);
@@ -3999,6 +4120,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
return initFromFloat8E3M4APInt(api);
if (Sem == &semFloatTF32)
return initFromFloatTF32APInt(api);
+ if (Sem == &semFloat8E8M0FNU)
+ return initFromFloat8E8M0FNUAPInt(api);
if (Sem == &semFloat6E3M2FN)
return initFromFloat6E3M2FNAPInt(api);
if (Sem == &semFloat6E2M3FN)
@@ -4012,6 +4135,9 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
/// Make this number the largest magnitude normal number in the given
/// semantics.
void IEEEFloat::makeLargest(bool Negative) {
+ if (Negative && !semantics->hasSignedRepr)
+ llvm_unreachable(
+ "This floating point format does not support signed values");
// We want (in interchange format):
// sign = {Negative}
// exponent = 1..10
@@ -4032,15 +4158,18 @@ void IEEEFloat::makeLargest(bool Negative) {
significand[PartCount - 1] = (NumUnusedHighBits < integerPartWidth)
? (~integerPart(0) >> NumUnusedHighBits)
: 0;
-
if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly &&
- semantics->nanEncoding == fltNanEncoding::AllOnes)
+ semantics->nanEncoding == fltNanEncoding::AllOnes &&
+ (semantics->precision > 1))
significand[0] &= ~integerPart(1);
}
/// Make this number the smallest magnitude denormal number in the given
/// semantics.
void IEEEFloat::makeSmallest(bool Negative) {
+ if (Negative && !semantics->hasSignedRepr)
+ llvm_unreachable(
+ "This floating point format does not support signed values");
// We want (in interchange format):
// sign = {Negative}
// exponent = 0..0
@@ -4052,6 +4181,9 @@ void IEEEFloat::makeSmallest(bool Negative) {
}
void IEEEFloat::makeSmallestNormalized(bool Negative) {
+ if (Negative && !semantics->hasSignedRepr)
+ llvm_unreachable(
+ "This floating point format does not support signed values");
// We want (in interchange format):
// sign = {Negative}
// exponent = 0..0
@@ -4509,6 +4641,8 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) {
exponent = 0;
if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
sign = false;
+ if (!semantics->hasZero)
+ makeSmallestNormalized(false);
break;
}
@@ -4574,7 +4708,10 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) {
// the integral bit to 1, and increment the exponent. If we have a
// denormal always increment since moving denormals and the numbers in the
// smallest normal binade have the same exponent in our representation.
- bool WillCrossBinadeBoundary = !isDenormal() && isSignificandAllOnes();
+ // If there are only exponents, any increment always crosses the
+ // BinadeBoundary.
+ bool WillCrossBinadeBoundary = !APFloat::hasSignificand(*semantics) ||
+ (!isDenormal() && isSignificandAllOnes());
if (WillCrossBinadeBoundary) {
integerPart *Parts = significandParts();
@@ -4626,6 +4763,9 @@ void IEEEFloat::makeInf(bool Negative) {
}
void IEEEFloat::makeZero(bool Negative) {
+ if (!semantics->hasZero)
+ llvm_unreachable("This floating point format does not support Zero");
+
category = fcZero;
sign = Negative;
if (semantics->nanEncoding == fltNanEncoding::NegativeZero) {