author    Durgadoss R <durgadossr@nvidia.com>  2024-10-02 23:04:21 +0530
committer GitHub <noreply@github.com>          2024-10-02 23:04:21 +0530
commit    99f527d2807b5a14dc7ee64d15405f09e95ee9f2 (patch)
tree      1e614683c82d2eedfbea8cf3ec142bc05bf045c2 /llvm
parent    5e92bfe97fe0f72f3052df53f813d8dcbb7038d3 (diff)
[APFloat] Add APFloat support for E8M0 type (#107127)
This patch adds an APFloat type for the unsigned E8M0 format. This format is used
to represent the "scale-format" in the MX specification (section 5.4):
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

This format does not support Inf, denormals, or zeroes. Like FP32, this format's
exponent is 8 bits (all of the bits here) and the bias value is 127. However, it
differs from IEEE-FP32 in that the minExponent is -127 (instead of -126). The
APFloat utility functions are updated to handle these constraints:

* The bias calculation is different, and the convertIEEE* APIs are updated to
  handle this.
* Since there are no significand bits, the isSignificandAll{Zeroes/Ones} methods
  are updated accordingly.
* Although the format does not have any precision, the precision bit in the
  fltSemantics is set to 1 for consistency with APFloat's internal representation.
* Many utility functions are updated to handle the fact that this format does not
  support Zero.
* Provide a separate initFromAPInt() implementation to handle the quirks of the
  format.
* Add specific tests to verify the range of values for this format.

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
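As a quick orientation, here is a minimal decoding sketch of the rules described
above (editorial and illustrative, not part of the patch; the helper name
decodeE8M0 is made up): an encoding E in [0, 254] denotes the value 2^(E - 127),
and the all-ones encoding 0xFF is the single NaN, so there is no zero, infinity,
or denormal.

    // Illustrative decode of one E8M0 byte, assuming the rules above.
    #include <cmath>
    #include <cstdint>
    #include <optional>

    std::optional<double> decodeE8M0(uint8_t Encoding) {
      if (Encoding == 0xFF)                          // all bits set: the only NaN
        return std::nullopt;
      return std::ldexp(1.0, int(Encoding) - 127);   // 2^(E - 127); bias is 127
    }

    // decodeE8M0(0x00) == 0x1.0p-127 (the smallest value; there is no zero)
    // decodeE8M0(0x7F) == 1.0
    // decodeE8M0(0xFE) == 0x1.0p+127 (the largest value)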
Diffstat (limited to 'llvm')
-rw-r--r--  llvm/include/llvm/ADT/APFloat.h     |  24
-rw-r--r--  llvm/lib/Support/APFloat.cpp        | 198
-rw-r--r--  llvm/unittests/ADT/APFloatTest.cpp  | 461
3 files changed, 653 insertions, 30 deletions
diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index acb3b2e..40131c8 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -195,6 +195,13 @@ struct APFloatBase {
// improved range compared to half (16-bit) formats, at (potentially)
// greater throughput than single precision (32-bit) formats.
S_FloatTF32,
+ // 8-bit floating point number with (all the) 8 bits for the exponent
+ // like in FP32. There are no zeroes, no infinities, and no denormal values.
+ // This format has unsigned representation only. (U -> Unsigned only).
+ // NaN is represented with all bits set to 1. Bias is 127.
+ // This format represents the scale data type in the MX specification from:
+ // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+ S_Float8E8M0FNU,
// 6-bit floating point number with bit layout S1E3M2. Unlike IEEE-754
// types, there are no infinity or NaN values. The format is detailed in
// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
@@ -229,6 +236,7 @@ struct APFloatBase {
static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE;
static const fltSemantics &Float8E3M4() LLVM_READNONE;
static const fltSemantics &FloatTF32() LLVM_READNONE;
+ static const fltSemantics &Float8E8M0FNU() LLVM_READNONE;
static const fltSemantics &Float6E3M2FN() LLVM_READNONE;
static const fltSemantics &Float6E2M3FN() LLVM_READNONE;
static const fltSemantics &Float4E2M1FN() LLVM_READNONE;
@@ -581,7 +589,8 @@ private:
integerPart addSignificand(const IEEEFloat &);
integerPart subtractSignificand(const IEEEFloat &, integerPart);
lostFraction addOrSubtractSignificand(const IEEEFloat &, bool subtract);
- lostFraction multiplySignificand(const IEEEFloat &, IEEEFloat);
+ lostFraction multiplySignificand(const IEEEFloat &, IEEEFloat,
+ bool ignoreAddend = false);
lostFraction multiplySignificand(const IEEEFloat&);
lostFraction divideSignificand(const IEEEFloat &);
void incrementSignificand();
@@ -591,6 +600,7 @@ private:
unsigned int significandLSB() const;
unsigned int significandMSB() const;
void zeroSignificand();
+ unsigned int getNumHighBits() const;
/// Return true if the significand excluding the integral bit is all ones.
bool isSignificandAllOnes() const;
bool isSignificandAllOnesExceptLSB() const;
@@ -652,6 +662,7 @@ private:
APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const;
APInt convertFloat8E3M4APFloatToAPInt() const;
APInt convertFloatTF32APFloatToAPInt() const;
+ APInt convertFloat8E8M0FNUAPFloatToAPInt() const;
APInt convertFloat6E3M2FNAPFloatToAPInt() const;
APInt convertFloat6E2M3FNAPFloatToAPInt() const;
APInt convertFloat4E2M1FNAPFloatToAPInt() const;
@@ -672,6 +683,7 @@ private:
void initFromFloat8E4M3B11FNUZAPInt(const APInt &api);
void initFromFloat8E3M4APInt(const APInt &api);
void initFromFloatTF32APInt(const APInt &api);
+ void initFromFloat8E8M0FNUAPInt(const APInt &api);
void initFromFloat6E3M2FNAPInt(const APInt &api);
void initFromFloat6E2M3FNAPInt(const APInt &api);
void initFromFloat4E2M1FNAPInt(const APInt &api);
@@ -1079,6 +1091,9 @@ public:
/// \param Semantics - type float semantics
static APFloat getAllOnesValue(const fltSemantics &Semantics);
+ /// Returns true if the given semantics supports either NaN or Infinity.
+ ///
+ /// \param Sem - type float semantics
static bool hasNanOrInf(const fltSemantics &Sem) {
switch (SemanticsToEnum(Sem)) {
default:
@@ -1091,6 +1106,13 @@ public:
}
}
+ /// Returns true if the given semantics has an actual significand.
+ ///
+ /// \param Sem - type float semantics
+ static bool hasSignificand(const fltSemantics &Sem) {
+ return &Sem != &Float8E8M0FNU();
+ }
+
/// Used to insert APFloat objects, or objects that contain APFloat objects,
/// into FoldingSets.
void Profile(FoldingSetNodeID &NID) const;
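A small usage sketch of the semantics and queries declared above (editorial and
illustrative, not part of the patch; the function name e8m0HeaderDemo is made up).
The values in the comments follow from the format's definition and from the unit
tests further down.

    #include "llvm/ADT/APFloat.h"
    using namespace llvm;

    void e8m0HeaderDemo() {
      const fltSemantics &Sem = APFloat::Float8E8M0FNU();
      APFloat Largest = APFloat::getLargest(Sem);    // 2^127  (encoding 0xFE)
      APFloat Smallest = APFloat::getSmallest(Sem);  // 2^-127 (encoding 0x00)
      bool NanOrInf = APFloat::hasNanOrInf(Sem);     // true: the all-ones encoding is NaN
      bool HasSig = APFloat::hasSignificand(Sem);    // false: exponent-only format
      (void)Largest; (void)Smallest; (void)NanOrInf; (void)HasSig;
    }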
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index dee917f..03413f6 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -119,6 +119,13 @@ struct fltSemantics {
fltNonfiniteBehavior nonFiniteBehavior = fltNonfiniteBehavior::IEEE754;
fltNanEncoding nanEncoding = fltNanEncoding::IEEE;
+
+ /* Whether this semantics has an encoding for Zero */
+ bool hasZero = true;
+
+ /* Whether this semantics can represent signed values */
+ bool hasSignedRepr = true;
+
// Returns true if any number described by this semantics can be precisely
// represented by the specified semantics. Does not take into account
// the value of fltNonfiniteBehavior.
@@ -145,6 +152,10 @@ static constexpr fltSemantics semFloat8E4M3B11FNUZ = {
4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E3M4 = {3, -2, 5, 8};
static constexpr fltSemantics semFloatTF32 = {127, -126, 11, 19};
+static constexpr fltSemantics semFloat8E8M0FNU = {
+ 127, -127, 1, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes,
+ false, false};
+
static constexpr fltSemantics semFloat6E3M2FN = {
4, -2, 3, 6, fltNonfiniteBehavior::FiniteOnly};
static constexpr fltSemantics semFloat6E2M3FN = {
@@ -222,6 +233,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) {
return Float8E3M4();
case S_FloatTF32:
return FloatTF32();
+ case S_Float8E8M0FNU:
+ return Float8E8M0FNU();
case S_Float6E3M2FN:
return Float6E3M2FN();
case S_Float6E2M3FN:
@@ -264,6 +277,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
return S_Float8E3M4;
else if (&Sem == &llvm::APFloat::FloatTF32())
return S_FloatTF32;
+ else if (&Sem == &llvm::APFloat::Float8E8M0FNU())
+ return S_Float8E8M0FNU;
else if (&Sem == &llvm::APFloat::Float6E3M2FN())
return S_Float6E3M2FN;
else if (&Sem == &llvm::APFloat::Float6E2M3FN())
@@ -294,6 +309,7 @@ const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
}
const fltSemantics &APFloatBase::Float8E3M4() { return semFloat8E3M4; }
const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; }
+const fltSemantics &APFloatBase::Float8E8M0FNU() { return semFloat8E8M0FNU; }
const fltSemantics &APFloatBase::Float6E3M2FN() { return semFloat6E3M2FN; }
const fltSemantics &APFloatBase::Float6E2M3FN() { return semFloat6E2M3FN; }
const fltSemantics &APFloatBase::Float4E2M1FN() { return semFloat4E2M1FN; }
@@ -396,7 +412,8 @@ static inline Error createError(const Twine &Err) {
}
static constexpr inline unsigned int partCountForBits(unsigned int bits) {
- return ((bits) + APFloatBase::integerPartWidth - 1) / APFloatBase::integerPartWidth;
+ return std::max(1u, (bits + APFloatBase::integerPartWidth - 1) /
+ APFloatBase::integerPartWidth);
}
/* Returns 0U-9U. Return values >= 10U are not digits. */
@@ -918,6 +935,10 @@ void IEEEFloat::makeNaN(bool SNaN, bool Negative, const APInt *fill) {
if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly)
llvm_unreachable("This floating point format does not support NaN");
+ if (Negative && !semantics->hasSignedRepr)
+ llvm_unreachable(
+ "This floating point format does not support signed values");
+
category = fcNaN;
sign = Negative;
exponent = exponentNaN();
@@ -955,7 +976,8 @@ void IEEEFloat::makeNaN(bool SNaN, bool Negative, const APInt *fill) {
significand[part] = 0;
}
- unsigned QNaNBit = semantics->precision - 2;
+ unsigned QNaNBit =
+ (semantics->precision >= 2) ? (semantics->precision - 2) : 0;
if (SNaN) {
// We always have to clear the QNaN bit to make it an SNaN.
@@ -1025,6 +1047,19 @@ bool IEEEFloat::isSmallestNormalized() const {
isSignificandAllZerosExceptMSB();
}
+unsigned int IEEEFloat::getNumHighBits() const {
+ const unsigned int PartCount = partCountForBits(semantics->precision);
+ const unsigned int Bits = PartCount * integerPartWidth;
+
+ // Compute how many bits are used in the final word.
+ // When precision is just 1, it represents the 'Pth'
+ // Precision bit and not the actual significand bit.
+ const unsigned int NumHighBits = (semantics->precision > 1)
+ ? (Bits - semantics->precision + 1)
+ : (Bits - semantics->precision);
+ return NumHighBits;
+}
+
bool IEEEFloat::isSignificandAllOnes() const {
// Test if the significand excluding the integral bit is all ones. This allows
// us to test for binade boundaries.
@@ -1035,13 +1070,12 @@ bool IEEEFloat::isSignificandAllOnes() const {
return false;
// Set the unused high bits to all ones when we compare.
- const unsigned NumHighBits =
- PartCount*integerPartWidth - semantics->precision + 1;
+ const unsigned NumHighBits = getNumHighBits();
assert(NumHighBits <= integerPartWidth && NumHighBits > 0 &&
"Can not have more high bits to fill than integerPartWidth");
const integerPart HighBitFill =
~integerPart(0) << (integerPartWidth - NumHighBits);
- if (~(Parts[PartCount - 1] | HighBitFill))
+ if ((semantics->precision <= 1) || (~(Parts[PartCount - 1] | HighBitFill)))
return false;
return true;
@@ -1062,8 +1096,7 @@ bool IEEEFloat::isSignificandAllOnesExceptLSB() const {
}
// Set the unused high bits to all ones when we compare.
- const unsigned NumHighBits =
- PartCount * integerPartWidth - semantics->precision + 1;
+ const unsigned NumHighBits = getNumHighBits();
assert(NumHighBits <= integerPartWidth && NumHighBits > 0 &&
"Can not have more high bits to fill than integerPartWidth");
const integerPart HighBitFill = ~integerPart(0)
@@ -1085,13 +1118,12 @@ bool IEEEFloat::isSignificandAllZeros() const {
return false;
// Compute how many bits are used in the final word.
- const unsigned NumHighBits =
- PartCount*integerPartWidth - semantics->precision + 1;
+ const unsigned NumHighBits = getNumHighBits();
assert(NumHighBits < integerPartWidth && "Can not have more high bits to "
"clear than integerPartWidth");
const integerPart HighBitMask = ~integerPart(0) >> NumHighBits;
- if (Parts[PartCount - 1] & HighBitMask)
+ if ((semantics->precision > 1) && (Parts[PartCount - 1] & HighBitMask))
return false;
return true;
@@ -1106,25 +1138,26 @@ bool IEEEFloat::isSignificandAllZerosExceptMSB() const {
return false;
}
- const unsigned NumHighBits =
- PartCount * integerPartWidth - semantics->precision + 1;
- return Parts[PartCount - 1] == integerPart(1)
- << (integerPartWidth - NumHighBits);
+ const unsigned NumHighBits = getNumHighBits();
+ const integerPart MSBMask = integerPart(1)
+ << (integerPartWidth - NumHighBits);
+ return ((semantics->precision <= 1) || (Parts[PartCount - 1] == MSBMask));
}
bool IEEEFloat::isLargest() const {
+ bool IsMaxExp = isFiniteNonZero() && exponent == semantics->maxExponent;
if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly &&
semantics->nanEncoding == fltNanEncoding::AllOnes) {
// The largest number by magnitude in our format will be the floating point
// number with maximum exponent and with significand that is all ones except
// the LSB.
- return isFiniteNonZero() && exponent == semantics->maxExponent &&
- isSignificandAllOnesExceptLSB();
+ return (IsMaxExp && APFloat::hasSignificand(*semantics))
+ ? isSignificandAllOnesExceptLSB()
+ : IsMaxExp;
} else {
// The largest number by magnitude in our format will be the floating point
// number with maximum exponent and with significand that is all ones.
- return isFiniteNonZero() && exponent == semantics->maxExponent &&
- isSignificandAllOnes();
+ return IsMaxExp && isSignificandAllOnes();
}
}
@@ -1165,7 +1198,13 @@ IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics, integerPart value) {
IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics) {
initialize(&ourSemantics);
- makeZero(false);
+ // The Float8E8M0FNU format does not have a representation
+ // for zero. So, use the closest representation instead.
+ // Moreover, the all-zero encoding represents a valid
+ // normal value (which is the smallestNormalized here).
+ // Hence, we call makeSmallestNormalized (where category is
+ // 'fcNormal') instead of makeZero (where category is 'fcZero').
+ ourSemantics.hasZero ? makeZero(false) : makeSmallestNormalized(false);
}
// Delegate to the previous constructor, because later copy constructor may
@@ -1245,7 +1284,8 @@ IEEEFloat::integerPart IEEEFloat::subtractSignificand(const IEEEFloat &rhs,
on to the full-precision result of the multiplication. Returns the
lost fraction. */
lostFraction IEEEFloat::multiplySignificand(const IEEEFloat &rhs,
- IEEEFloat addend) {
+ IEEEFloat addend,
+ bool ignoreAddend) {
unsigned int omsb; // One, not zero, based MSB.
unsigned int partsCount, newPartsCount, precision;
integerPart *lhsSignificand;
@@ -1289,7 +1329,7 @@ lostFraction IEEEFloat::multiplySignificand(const IEEEFloat &rhs,
// toward left by two bits, and adjust exponent accordingly.
exponent += 2;
- if (addend.isNonZero()) {
+ if (!ignoreAddend && addend.isNonZero()) {
// The intermediate result of the multiplication has "2 * precision"
// signicant bit; adjust the addend to be consistent with mul result.
//
@@ -1377,7 +1417,12 @@ lostFraction IEEEFloat::multiplySignificand(const IEEEFloat &rhs,
}
lostFraction IEEEFloat::multiplySignificand(const IEEEFloat &rhs) {
- return multiplySignificand(rhs, IEEEFloat(*semantics));
+ // When the given semantics has a zero, the addend here is a zero,
+ // i.e., it belongs to the 'fcZero' category.
+ // But when the semantics does not support zero, we need to
+ // explicitly convey that this addend should be ignored
+ // for multiplication.
+ return multiplySignificand(rhs, IEEEFloat(*semantics), !semantics->hasZero);
}
/* Multiply the significands of LHS and RHS to DST. */
@@ -1483,7 +1528,8 @@ lostFraction IEEEFloat::shiftSignificandRight(unsigned int bits) {
/* Shift the significand left BITS bits, subtract BITS from its exponent. */
void IEEEFloat::shiftSignificandLeft(unsigned int bits) {
- assert(bits < semantics->precision);
+ assert(bits < semantics->precision ||
+ (semantics->precision == 1 && bits <= 1));
if (bits) {
unsigned int partsCount = partCount();
@@ -1678,6 +1724,8 @@ IEEEFloat::opStatus IEEEFloat::normalize(roundingMode rounding_mode,
category = fcZero;
if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
sign = false;
+ if (!semantics->hasZero)
+ makeSmallestNormalized(false);
}
return opOK;
@@ -1729,6 +1777,11 @@ IEEEFloat::opStatus IEEEFloat::normalize(roundingMode rounding_mode,
category = fcZero;
if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
sign = false;
+ // This condition handles the case where the semantics
+ // does not have zero but uses the all-zero encoding
+ // to represent the smallest normal value.
+ if (!semantics->hasZero)
+ makeSmallestNormalized(false);
}
/* The fcZero case is a denormal that underflowed to zero. */
@@ -1807,6 +1860,10 @@ lostFraction IEEEFloat::addOrSubtractSignificand(const IEEEFloat &rhs,
/* Subtraction is more subtle than one might naively expect. */
if (subtract) {
+ if ((bits < 0) && !semantics->hasSignedRepr)
+ llvm_unreachable(
+ "This floating point format does not support signed values");
+
IEEEFloat temp_rhs(rhs);
if (bits == 0)
@@ -2252,6 +2309,17 @@ IEEEFloat::opStatus IEEEFloat::mod(const IEEEFloat &rhs) {
V.sign = sign;
fs = subtract(V, rmNearestTiesToEven);
+
+ // When the semantics supports zero, this loop's
+ // exit-condition is handled by the 'isFiniteNonZero'
+ // category check above. However, when the semantics
+ // does not have 'fcZero' and we have reached the
+ // minimum possible value (any further subtract
+ // would underflow to the same value), explicitly
+ // provide an exit-path here.
+ if (!semantics->hasZero && this->isSmallest())
+ break;
+
assert(fs==opOK);
}
if (isZero()) {
@@ -2606,6 +2674,8 @@ IEEEFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics,
fs = opOK;
}
+ if (category == fcZero && !semantics->hasZero)
+ makeSmallestNormalized(false);
return fs;
}
@@ -3070,6 +3140,8 @@ IEEEFloat::convertFromDecimalString(StringRef str, roundingMode rounding_mode) {
fs = opOK;
if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
sign = false;
+ if (!semantics->hasZero)
+ makeSmallestNormalized(false);
/* Check whether the normalized exponent is high enough to overflow
max during the log-rebasing in the max-exponent check below. */
@@ -3237,6 +3309,10 @@ IEEEFloat::convertFromString(StringRef str, roundingMode rounding_mode) {
StringRef::iterator p = str.begin();
size_t slen = str.size();
sign = *p == '-' ? 1 : 0;
+ if (sign && !semantics->hasSignedRepr)
+ llvm_unreachable(
+ "This floating point format does not support signed values");
+
if (*p == '-' || *p == '+') {
p++;
slen--;
@@ -3533,15 +3609,16 @@ APInt IEEEFloat::convertPPCDoubleDoubleAPFloatToAPInt() const {
template <const fltSemantics &S>
APInt IEEEFloat::convertIEEEFloatToAPInt() const {
assert(semantics == &S);
-
- constexpr int bias = -(S.minExponent - 1);
+ const int bias =
+ (semantics == &semFloat8E8M0FNU) ? -S.minExponent : -(S.minExponent - 1);
constexpr unsigned int trailing_significand_bits = S.precision - 1;
constexpr int integer_bit_part = trailing_significand_bits / integerPartWidth;
constexpr integerPart integer_bit =
integerPart{1} << (trailing_significand_bits % integerPartWidth);
constexpr uint64_t significand_mask = integer_bit - 1;
constexpr unsigned int exponent_bits =
- S.sizeInBits - 1 - trailing_significand_bits;
+ trailing_significand_bits ? (S.sizeInBits - 1 - trailing_significand_bits)
+ : S.sizeInBits;
static_assert(exponent_bits < 64);
constexpr uint64_t exponent_mask = (uint64_t{1} << exponent_bits) - 1;
@@ -3557,6 +3634,8 @@ APInt IEEEFloat::convertIEEEFloatToAPInt() const {
!(significandParts()[integer_bit_part] & integer_bit))
myexponent = 0; // denormal
} else if (category == fcZero) {
+ if (!S.hasZero)
+ llvm_unreachable("semantics does not support zero!");
myexponent = ::exponentZero(S) + bias;
mysignificand.fill(0);
} else if (category == fcInfinity) {
@@ -3659,6 +3738,11 @@ APInt IEEEFloat::convertFloatTF32APFloatToAPInt() const {
return convertIEEEFloatToAPInt<semFloatTF32>();
}
+APInt IEEEFloat::convertFloat8E8M0FNUAPFloatToAPInt() const {
+ assert(partCount() == 1);
+ return convertIEEEFloatToAPInt<semFloat8E8M0FNU>();
+}
+
APInt IEEEFloat::convertFloat6E3M2FNAPFloatToAPInt() const {
assert(partCount() == 1);
return convertIEEEFloatToAPInt<semFloat6E3M2FN>();
@@ -3721,6 +3805,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
if (semantics == (const llvm::fltSemantics *)&semFloatTF32)
return convertFloatTF32APFloatToAPInt();
+ if (semantics == (const llvm::fltSemantics *)&semFloat8E8M0FNU)
+ return convertFloat8E8M0FNUAPFloatToAPInt();
+
if (semantics == (const llvm::fltSemantics *)&semFloat6E3M2FN)
return convertFloat6E3M2FNAPFloatToAPInt();
@@ -3819,6 +3906,40 @@ void IEEEFloat::initFromPPCDoubleDoubleAPInt(const APInt &api) {
}
}
+// The E8M0 format has the following characteristics:
+// It is an 8-bit unsigned format with only exponents (no actual significand).
+// No encodings for {zero, infinities or denorms}.
+// NaN is represented by all 1's.
+// Bias is 127.
+void IEEEFloat::initFromFloat8E8M0FNUAPInt(const APInt &api) {
+ const uint64_t exponent_mask = 0xff;
+ uint64_t val = api.getRawData()[0];
+ uint64_t myexponent = (val & exponent_mask);
+
+ initialize(&semFloat8E8M0FNU);
+ assert(partCount() == 1);
+
+ // This format has unsigned representation only
+ sign = 0;
+
+ // Set the significand
+ // This format does not have any significand but the 'Pth' precision bit is
+ // always set to 1 for consistency in APFloat's internal representation.
+ uint64_t mysignificand = 1;
+ significandParts()[0] = mysignificand;
+
+ // This format can either have a NaN or fcNormal
+ // All 1's i.e. 255 is a NaN
+ if (val == exponent_mask) {
+ category = fcNaN;
+ exponent = exponentNaN();
+ return;
+ }
+ // Handle fcNormal...
+ category = fcNormal;
+ exponent = myexponent - 127; // 127 is bias
+ return;
+}
template <const fltSemantics &S>
void IEEEFloat::initFromIEEEAPInt(const APInt &api) {
assert(api.getBitWidth() == S.sizeInBits);
@@ -3999,6 +4120,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
return initFromFloat8E3M4APInt(api);
if (Sem == &semFloatTF32)
return initFromFloatTF32APInt(api);
+ if (Sem == &semFloat8E8M0FNU)
+ return initFromFloat8E8M0FNUAPInt(api);
if (Sem == &semFloat6E3M2FN)
return initFromFloat6E3M2FNAPInt(api);
if (Sem == &semFloat6E2M3FN)
@@ -4012,6 +4135,9 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
/// Make this number the largest magnitude normal number in the given
/// semantics.
void IEEEFloat::makeLargest(bool Negative) {
+ if (Negative && !semantics->hasSignedRepr)
+ llvm_unreachable(
+ "This floating point format does not support signed values");
// We want (in interchange format):
// sign = {Negative}
// exponent = 1..10
@@ -4032,15 +4158,18 @@ void IEEEFloat::makeLargest(bool Negative) {
significand[PartCount - 1] = (NumUnusedHighBits < integerPartWidth)
? (~integerPart(0) >> NumUnusedHighBits)
: 0;
-
if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly &&
- semantics->nanEncoding == fltNanEncoding::AllOnes)
+ semantics->nanEncoding == fltNanEncoding::AllOnes &&
+ (semantics->precision > 1))
significand[0] &= ~integerPart(1);
}
/// Make this number the smallest magnitude denormal number in the given
/// semantics.
void IEEEFloat::makeSmallest(bool Negative) {
+ if (Negative && !semantics->hasSignedRepr)
+ llvm_unreachable(
+ "This floating point format does not support signed values");
// We want (in interchange format):
// sign = {Negative}
// exponent = 0..0
@@ -4052,6 +4181,9 @@ void IEEEFloat::makeSmallest(bool Negative) {
}
void IEEEFloat::makeSmallestNormalized(bool Negative) {
+ if (Negative && !semantics->hasSignedRepr)
+ llvm_unreachable(
+ "This floating point format does not support signed values");
// We want (in interchange format):
// sign = {Negative}
// exponent = 0..0
@@ -4509,6 +4641,8 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) {
exponent = 0;
if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
sign = false;
+ if (!semantics->hasZero)
+ makeSmallestNormalized(false);
break;
}
@@ -4574,7 +4708,10 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) {
// the integral bit to 1, and increment the exponent. If we have a
// denormal always increment since moving denormals and the numbers in the
// smallest normal binade have the same exponent in our representation.
- bool WillCrossBinadeBoundary = !isDenormal() && isSignificandAllOnes();
+ // If there are only exponents, any increment always crosses the
+ // binade boundary.
+ bool WillCrossBinadeBoundary = !APFloat::hasSignificand(*semantics) ||
+ (!isDenormal() && isSignificandAllOnes());
if (WillCrossBinadeBoundary) {
integerPart *Parts = significandParts();
@@ -4626,6 +4763,9 @@ void IEEEFloat::makeInf(bool Negative) {
}
void IEEEFloat::makeZero(bool Negative) {
+ if (!semantics->hasZero)
+ llvm_unreachable("This floating point format does not support Zero");
+
category = fcZero;
sign = Negative;
if (semantics->nanEncoding == fltNanEncoding::NegativeZero) {
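The recurring theme in the APFloat.cpp changes above is that anything which would
normally land on zero is redirected to the smallest normalized value, 2^-127, since
the format has no zero encoding. A short sketch of that behaviour through the public
API (editorial and illustrative, not part of the patch; it mirrors the
ConvertDoubleToE8M0FNU test below, and the function name is made up):

    #include "llvm/ADT/APFloat.h"
    using namespace llvm;

    void e8m0NoZeroDemo() {
      bool LosesInfo = false;
      APFloat V(APFloat::IEEEdouble(), "0.0");
      V.convert(APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven, &LosesInfo);
      // V.isSmallestNormalized() is now true and V.convertToDouble() == 0x1.0p-127:
      // the zero input collapses to the closest representable value.
    }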
diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp
index cd8a00f..6008f00 100644
--- a/llvm/unittests/ADT/APFloatTest.cpp
+++ b/llvm/unittests/ADT/APFloatTest.cpp
@@ -824,6 +824,11 @@ TEST(APFloatTest, IsSmallestNormalized) {
const fltSemantics &Semantics =
APFloat::EnumToSemantics(static_cast<APFloat::Semantics>(I));
+ // For Float8E8M0FNU format, the below cases are tested
+ // through Float8E8M0FNUSmallest and Float8E8M0FNUNext tests.
+ if (I == APFloat::S_Float8E8M0FNU)
+ continue;
+
EXPECT_FALSE(APFloat::getZero(Semantics, false).isSmallestNormalized());
EXPECT_FALSE(APFloat::getZero(Semantics, true).isSmallestNormalized());
@@ -1917,6 +1922,57 @@ TEST(DoubleAPFloatTest, isInteger) {
EXPECT_FALSE(T3.isInteger());
}
+// Test to check if the full range of Float8E8M0FNU
+// values are being represented correctly.
+TEST(APFloatTest, Float8E8M0FNUValues) {
+ // High end of the range
+ auto test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p127");
+ EXPECT_EQ(0x1.0p127, test.convertToDouble());
+
+ test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p126");
+ EXPECT_EQ(0x1.0p126, test.convertToDouble());
+
+ test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p125");
+ EXPECT_EQ(0x1.0p125, test.convertToDouble());
+
+ // tests the fix in makeLargest()
+ test = APFloat::getLargest(APFloat::Float8E8M0FNU());
+ EXPECT_EQ(0x1.0p127, test.convertToDouble());
+
+ // tests overflow to nan
+ APFloat nan = APFloat(APFloat::Float8E8M0FNU(), "nan");
+ test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p128");
+ EXPECT_TRUE(test.bitwiseIsEqual(nan));
+
+ // Mid of the range
+ test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p0");
+ EXPECT_EQ(1.0, test.convertToDouble());
+
+ test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p1");
+ EXPECT_EQ(2.0, test.convertToDouble());
+
+ test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p2");
+ EXPECT_EQ(4.0, test.convertToDouble());
+
+ // Low end of the range
+ test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p-125");
+ EXPECT_EQ(0x1.0p-125, test.convertToDouble());
+
+ test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p-126");
+ EXPECT_EQ(0x1.0p-126, test.convertToDouble());
+
+ test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p-127");
+ EXPECT_EQ(0x1.0p-127, test.convertToDouble());
+
+ // Smallest value
+ test = APFloat::getSmallest(APFloat::Float8E8M0FNU());
+ EXPECT_EQ(0x1.0p-127, test.convertToDouble());
+
+ // Value below the smallest, but clamped to the smallest
+ test = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p-128");
+ EXPECT_EQ(0x1.0p-127, test.convertToDouble());
+}
+
TEST(APFloatTest, getLargest) {
EXPECT_EQ(3.402823466e+38f, APFloat::getLargest(APFloat::IEEEsingle()).convertToFloat());
EXPECT_EQ(1.7976931348623158e+308, APFloat::getLargest(APFloat::IEEEdouble()).convertToDouble());
@@ -1929,6 +1985,8 @@ TEST(APFloatTest, getLargest) {
30, APFloat::getLargest(APFloat::Float8E4M3B11FNUZ()).convertToDouble());
EXPECT_EQ(3.40116213421e+38f,
APFloat::getLargest(APFloat::FloatTF32()).convertToFloat());
+ EXPECT_EQ(1.701411834e+38f,
+ APFloat::getLargest(APFloat::Float8E8M0FNU()).convertToDouble());
EXPECT_EQ(28, APFloat::getLargest(APFloat::Float6E3M2FN()).convertToDouble());
EXPECT_EQ(7.5,
APFloat::getLargest(APFloat::Float6E2M3FN()).convertToDouble());
@@ -2012,6 +2070,13 @@ TEST(APFloatTest, getSmallest) {
EXPECT_TRUE(test.isFiniteNonZero());
EXPECT_TRUE(test.isDenormal());
EXPECT_TRUE(test.bitwiseIsEqual(expected));
+
+ test = APFloat::getSmallest(APFloat::Float8E8M0FNU());
+ expected = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p-127");
+ EXPECT_FALSE(test.isNegative());
+ EXPECT_TRUE(test.isFiniteNonZero());
+ EXPECT_FALSE(test.isDenormal());
+ EXPECT_TRUE(test.bitwiseIsEqual(expected));
}
TEST(APFloatTest, getSmallestNormalized) {
@@ -2118,6 +2183,14 @@ TEST(APFloatTest, getSmallestNormalized) {
EXPECT_FALSE(test.isDenormal());
EXPECT_TRUE(test.bitwiseIsEqual(expected));
EXPECT_TRUE(test.isSmallestNormalized());
+
+ test = APFloat::getSmallestNormalized(APFloat::Float8E8M0FNU(), false);
+ expected = APFloat(APFloat::Float8E8M0FNU(), "0x1.0p-127");
+ EXPECT_FALSE(test.isNegative());
+ EXPECT_TRUE(test.isFiniteNonZero());
+ EXPECT_FALSE(test.isDenormal());
+ EXPECT_TRUE(test.bitwiseIsEqual(expected));
+ EXPECT_TRUE(test.isSmallestNormalized());
}
TEST(APFloatTest, getZero) {
@@ -5326,6 +5399,104 @@ TEST(APFloatTest, Float8ExhaustivePair) {
}
}
+TEST(APFloatTest, Float8E8M0FNUExhaustivePair) {
+ // Test each pair of 8-bit values for Float8E8M0FNU format
+ APFloat::Semantics Sem = APFloat::S_Float8E8M0FNU;
+ const llvm::fltSemantics &S = APFloat::EnumToSemantics(Sem);
+ for (int i = 0; i < 256; i++) {
+ for (int j = 0; j < 256; j++) {
+ SCOPED_TRACE("sem=" + std::to_string(Sem) + ",i=" + std::to_string(i) +
+ ",j=" + std::to_string(j));
+ APFloat x(S, APInt(8, i));
+ APFloat y(S, APInt(8, j));
+
+ bool losesInfo;
+ APFloat xd = x;
+ xd.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
+ &losesInfo);
+ EXPECT_FALSE(losesInfo);
+ APFloat yd = y;
+ yd.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
+ &losesInfo);
+ EXPECT_FALSE(losesInfo);
+
+ // Add
+ APFloat z = x;
+ z.add(y, APFloat::rmNearestTiesToEven);
+ APFloat zd = xd;
+ zd.add(yd, APFloat::rmNearestTiesToEven);
+ zd.convert(S, APFloat::rmNearestTiesToEven, &losesInfo);
+ EXPECT_TRUE(z.bitwiseIsEqual(zd))
+ << "sem=" << Sem << ", i=" << i << ", j=" << j;
+
+ // Subtract
+ if (i >= j) {
+ z = x;
+ z.subtract(y, APFloat::rmNearestTiesToEven);
+ zd = xd;
+ zd.subtract(yd, APFloat::rmNearestTiesToEven);
+ zd.convert(S, APFloat::rmNearestTiesToEven, &losesInfo);
+ EXPECT_TRUE(z.bitwiseIsEqual(zd))
+ << "sem=" << Sem << ", i=" << i << ", j=" << j;
+ }
+
+ // Multiply
+ z = x;
+ z.multiply(y, APFloat::rmNearestTiesToEven);
+ zd = xd;
+ zd.multiply(yd, APFloat::rmNearestTiesToEven);
+ zd.convert(S, APFloat::rmNearestTiesToEven, &losesInfo);
+ EXPECT_TRUE(z.bitwiseIsEqual(zd))
+ << "sem=" << Sem << ", i=" << i << ", j=" << j;
+
+ // Divide
+ z = x;
+ z.divide(y, APFloat::rmNearestTiesToEven);
+ zd = xd;
+ zd.divide(yd, APFloat::rmNearestTiesToEven);
+ zd.convert(S, APFloat::rmNearestTiesToEven, &losesInfo);
+ EXPECT_TRUE(z.bitwiseIsEqual(zd))
+ << "sem=" << Sem << ", i=" << i << ", j=" << j;
+
+ // Mod
+ z = x;
+ z.mod(y);
+ zd = xd;
+ zd.mod(yd);
+ zd.convert(S, APFloat::rmNearestTiesToEven, &losesInfo);
+ EXPECT_TRUE(z.bitwiseIsEqual(zd))
+ << "sem=" << Sem << ", i=" << i << ", j=" << j;
+ APFloat mod_cached = z;
+ // When one of them is a NaN, the result is a NaN.
+ // When i < j, the mod is 'i' since it is the smaller
+ // number. Otherwise the mod is always zero since
+ // both x and y are powers-of-two in this format.
+ // Since this format does not support zero and it is
+ // represented as the smallest normalized value, we
+ // test for isSmallestNormalized().
+ if (i == 255 || j == 255)
+ EXPECT_TRUE(z.isNaN());
+ else if (i >= j)
+ EXPECT_TRUE(z.isSmallestNormalized());
+ else
+ EXPECT_TRUE(z.bitwiseIsEqual(x));
+
+ // Remainder
+ z = x;
+ z.remainder(y);
+ zd = xd;
+ zd.remainder(yd);
+ zd.convert(S, APFloat::rmNearestTiesToEven, &losesInfo);
+ EXPECT_TRUE(z.bitwiseIsEqual(zd))
+ << "sem=" << Sem << ", i=" << i << ", j=" << j;
+ // Since this format has only exponents (i.e. no precision)
+ // we expect the remainder and mod to provide the same results.
+ EXPECT_TRUE(z.bitwiseIsEqual(mod_cached))
+ << "sem=" << Sem << ", i=" << i << ", j=" << j;
+ }
+ }
+}
+
TEST(APFloatTest, Float6ExhaustivePair) {
// Test each pair of 6-bit floats with non-standard semantics
for (APFloat::Semantics Sem :
@@ -5801,6 +5972,46 @@ TEST(APFloatTest, Float8E4M3FNExhaustive) {
}
}
+TEST(APFloatTest, Float8E8M0FNUExhaustive) {
+ // Test each of the 256 Float8E8M0FNU values.
+ for (int i = 0; i < 256; i++) {
+ APFloat test(APFloat::Float8E8M0FNU(), APInt(8, i));
+ SCOPED_TRACE("i=" + std::to_string(i));
+
+ // isLargest
+ if (i == 254) {
+ EXPECT_TRUE(test.isLargest());
+ EXPECT_EQ(abs(test).convertToDouble(), 0x1.0p127);
+ } else {
+ EXPECT_FALSE(test.isLargest());
+ }
+
+ // isSmallest
+ if (i == 0) {
+ EXPECT_TRUE(test.isSmallest());
+ EXPECT_EQ(abs(test).convertToDouble(), 0x1.0p-127);
+ } else {
+ EXPECT_FALSE(test.isSmallest());
+ }
+
+ // convert to Double
+ bool losesInfo;
+ std::string val = std::to_string(i - 127); // 127 is the bias
+ llvm::SmallString<16> str("0x1.0p");
+ str += val;
+ APFloat test2(APFloat::IEEEdouble(), str);
+
+ APFloat::opStatus status = test.convert(
+ APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &losesInfo);
+ EXPECT_EQ(status, APFloat::opOK);
+ EXPECT_FALSE(losesInfo);
+ if (i == 255)
+ EXPECT_TRUE(test.isNaN());
+ else
+ EXPECT_EQ(test.convertToDouble(), test2.convertToDouble());
+ }
+}
+
TEST(APFloatTest, Float8E5M2FNUZNext) {
APFloat test(APFloat::Float8E5M2FNUZ(), APFloat::uninitialized);
APFloat expected(APFloat::Float8E5M2FNUZ(), APFloat::uninitialized);
@@ -7077,6 +7288,12 @@ TEST(APFloatTest, getExactLog2) {
auto SemEnum = static_cast<APFloat::Semantics>(I);
const fltSemantics &Semantics = APFloat::EnumToSemantics(SemEnum);
+ // For the Float8E8M0FNU format, the below cases along
+ // with some more corner cases are tested through
+ // Float8E8M0FNUGetExactLog2.
+ if (I == APFloat::S_Float8E8M0FNU)
+ continue;
+
APFloat One(Semantics, "1.0");
if (I == APFloat::S_PPCDoubleDouble) {
@@ -7146,6 +7363,250 @@ TEST(APFloatTest, getExactLog2) {
}
}
+TEST(APFloatTest, Float8E8M0FNUGetZero) {
+#ifdef GTEST_HAS_DEATH_TEST
+#ifndef NDEBUG
+ EXPECT_DEATH(APFloat::getZero(APFloat::Float8E8M0FNU(), false),
+ "This floating point format does not support Zero");
+ EXPECT_DEATH(APFloat::getZero(APFloat::Float8E8M0FNU(), true),
+ "This floating point format does not support Zero");
+#endif
+#endif
+}
+
+TEST(APFloatTest, Float8E8M0FNUGetSignedValues) {
+#ifdef GTEST_HAS_DEATH_TEST
+#ifndef NDEBUG
+ EXPECT_DEATH(APFloat(APFloat::Float8E8M0FNU(), "-64"),
+ "This floating point format does not support signed values");
+ EXPECT_DEATH(APFloat(APFloat::Float8E8M0FNU(), "-0x1.0p128"),
+ "This floating point format does not support signed values");
+ EXPECT_DEATH(APFloat(APFloat::Float8E8M0FNU(), "-inf"),
+ "This floating point format does not support signed values");
+ EXPECT_DEATH(APFloat::getNaN(APFloat::Float8E8M0FNU(), true),
+ "This floating point format does not support signed values");
+ EXPECT_DEATH(APFloat::getInf(APFloat::Float8E8M0FNU(), true),
+ "This floating point format does not support signed values");
+ EXPECT_DEATH(APFloat::getSmallest(APFloat::Float8E8M0FNU(), true),
+ "This floating point format does not support signed values");
+ EXPECT_DEATH(APFloat::getSmallestNormalized(APFloat::Float8E8M0FNU(), true),
+ "This floating point format does not support signed values");
+ EXPECT_DEATH(APFloat::getLargest(APFloat::Float8E8M0FNU(), true),
+ "This floating point format does not support signed values");
+ APFloat x = APFloat(APFloat::Float8E8M0FNU(), "4");
+ APFloat y = APFloat(APFloat::Float8E8M0FNU(), "8");
+ EXPECT_DEATH(x.subtract(y, APFloat::rmNearestTiesToEven),
+ "This floating point format does not support signed values");
+#endif
+#endif
+}
+
+TEST(APFloatTest, Float8E8M0FNUGetInf) {
+ // The E8M0 format does not support infinity and the
+ // all ones representation is treated as NaN.
+ APFloat t = APFloat::getInf(APFloat::Float8E8M0FNU());
+ EXPECT_TRUE(t.isNaN());
+ EXPECT_FALSE(t.isInfinity());
+}
+
+TEST(APFloatTest, Float8E8M0FNUFromString) {
+ // Exactly representable
+ EXPECT_EQ(64, APFloat(APFloat::Float8E8M0FNU(), "64").convertToDouble());
+ // Overflow to NaN
+ EXPECT_TRUE(APFloat(APFloat::Float8E8M0FNU(), "0x1.0p128").isNaN());
+ // Inf converted to NaN
+ EXPECT_TRUE(APFloat(APFloat::Float8E8M0FNU(), "inf").isNaN());
+ // NaN converted to NaN
+ EXPECT_TRUE(APFloat(APFloat::Float8E8M0FNU(), "nan").isNaN());
+}
+
+TEST(APFloatTest, Float8E8M0FNUDivideByZero) {
+ APFloat x(APFloat::Float8E8M0FNU(), "1");
+ APFloat zero(APFloat::Float8E8M0FNU(), "0");
+ x.divide(zero, APFloat::rmNearestTiesToEven);
+
+ // Zero is represented as the smallest normalized value
+ // in this format, i.e., 2^-127.
+ // This tests the fix in convertFromDecimalString() function.
+ EXPECT_EQ(0x1.0p-127, zero.convertToDouble());
+
+ // [1 / (2^-127)] = 2^127
+ EXPECT_EQ(0x1.0p127, x.convertToDouble());
+}
+
+TEST(APFloatTest, Float8E8M0FNUGetExactLog2) {
+ const fltSemantics &Semantics = APFloat::Float8E8M0FNU();
+ APFloat One(Semantics, "1.0");
+ EXPECT_EQ(0, One.getExactLog2());
+
+ // In the Float8E8M0FNU format, 3 is rounded-up to 4.
+ // So, we expect 2 as the result.
+ EXPECT_EQ(2, APFloat(Semantics, "3.0").getExactLog2());
+ EXPECT_EQ(2, APFloat(Semantics, "3.0").getExactLog2Abs());
+
+ // In the Float8E8M0FNU format, 5 is rounded-down to 4.
+ // So, we expect 2 as the result.
+ EXPECT_EQ(2, APFloat(Semantics, "5.0").getExactLog2());
+ EXPECT_EQ(2, APFloat(Semantics, "5.0").getExactLog2Abs());
+
+ // Exact power-of-two value.
+ EXPECT_EQ(3, APFloat(Semantics, "8.0").getExactLog2());
+ EXPECT_EQ(3, APFloat(Semantics, "8.0").getExactLog2Abs());
+
+ // Negative exponent value.
+ EXPECT_EQ(-2, APFloat(Semantics, "0.25").getExactLog2());
+ EXPECT_EQ(-2, APFloat(Semantics, "0.25").getExactLog2Abs());
+
+ int MinExp = APFloat::semanticsMinExponent(Semantics);
+ int MaxExp = APFloat::semanticsMaxExponent(Semantics);
+ int Precision = APFloat::semanticsPrecision(Semantics);
+
+ // Values below minExp get capped to minExp.
+ EXPECT_EQ(-127,
+ scalbn(One, MinExp - Precision - 1, APFloat::rmNearestTiesToEven)
+ .getExactLog2());
+ EXPECT_EQ(-127, scalbn(One, MinExp - Precision, APFloat::rmNearestTiesToEven)
+ .getExactLog2());
+
+ // Values above the maxExp overflow to NaN, and getExactLog2() returns
+ // INT_MIN for these cases.
+ EXPECT_EQ(
+ INT_MIN,
+ scalbn(One, MaxExp + 1, APFloat::rmNearestTiesToEven).getExactLog2());
+
+ // This format can represent [minExp, maxExp].
+ // So, the result is the same as the 'Exp' of the scalbn.
+ for (int i = MinExp - Precision + 1; i <= MaxExp; ++i) {
+ EXPECT_EQ(i, scalbn(One, i, APFloat::rmNearestTiesToEven).getExactLog2());
+ }
+}
+
+TEST(APFloatTest, Float8E8M0FNUSmallest) {
+ APFloat test(APFloat::getSmallest(APFloat::Float8E8M0FNU()));
+ EXPECT_EQ(0x1.0p-127, test.convertToDouble());
+
+ // For the E8M0 format, there are no denorms.
+ // So, the value from getSmallest() is also the smallest normalized value.
+ EXPECT_TRUE(test.isSmallestNormalized());
+ EXPECT_EQ(fcPosNormal, test.classify());
+
+ test = APFloat::getAllOnesValue(APFloat::Float8E8M0FNU());
+ EXPECT_FALSE(test.isSmallestNormalized());
+ EXPECT_TRUE(test.isNaN());
+}
+
+TEST(APFloatTest, Float8E8M0FNUNext) {
+ APFloat test(APFloat::getSmallest(APFloat::Float8E8M0FNU()));
+ // Increment of 1 should reach 2^-126
+ EXPECT_EQ(APFloat::opOK, test.next(false));
+ EXPECT_FALSE(test.isSmallestNormalized());
+ EXPECT_EQ(0x1.0p-126, test.convertToDouble());
+
+ // Decrement of 1, again, should reach 2^-127
+ // i.e. smallest normalized
+ EXPECT_EQ(APFloat::opOK, test.next(true));
+ EXPECT_TRUE(test.isSmallestNormalized());
+
+ // Decrement again, but gets capped at the smallest normalized
+ EXPECT_EQ(APFloat::opOK, test.next(true));
+ EXPECT_TRUE(test.isSmallestNormalized());
+}
+
+TEST(APFloatTest, Float8E8M0FNUFMA) {
+ APFloat f1(APFloat::Float8E8M0FNU(), "4.0");
+ APFloat f2(APFloat::Float8E8M0FNU(), "2.0");
+ APFloat f3(APFloat::Float8E8M0FNU(), "8.0");
+
+ // Exact value: 4*2 + 8 = 16.
+ f1.fusedMultiplyAdd(f2, f3, APFloat::rmNearestTiesToEven);
+ EXPECT_EQ(16.0, f1.convertToDouble());
+
+ // 4*2 + 4 = 12 but it gets rounded-up to 16.
+ f1 = APFloat(APFloat::Float8E8M0FNU(), "4.0");
+ f1.fusedMultiplyAdd(f2, f1, APFloat::rmNearestTiesToEven);
+ EXPECT_EQ(16.0, f1.convertToDouble());
+
+ // 4*2 + 2 = 10 but it gets rounded-down to 8.
+ f1 = APFloat(APFloat::Float8E8M0FNU(), "4.0");
+ f1.fusedMultiplyAdd(f2, f2, APFloat::rmNearestTiesToEven);
+ EXPECT_EQ(8.0, f1.convertToDouble());
+
+ // All of them using the same value.
+ f1 = APFloat(APFloat::Float8E8M0FNU(), "1.0");
+ f1.fusedMultiplyAdd(f1, f1, APFloat::rmNearestTiesToEven);
+ EXPECT_EQ(2.0, f1.convertToDouble());
+}
+
+TEST(APFloatTest, ConvertDoubleToE8M0FNU) {
+ bool losesInfo;
+ APFloat test(APFloat::IEEEdouble(), "1.0");
+ APFloat::opStatus status = test.convert(
+ APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven, &losesInfo);
+ EXPECT_EQ(1.0, test.convertToDouble());
+ EXPECT_FALSE(losesInfo);
+ EXPECT_EQ(status, APFloat::opOK);
+
+ // For E8M0, zero encoding is represented as the smallest normalized value.
+ test = APFloat(APFloat::IEEEdouble(), "0.0");
+ status = test.convert(APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven,
+ &losesInfo);
+ EXPECT_TRUE(test.isSmallestNormalized());
+ EXPECT_EQ(0x1.0p-127, test.convertToDouble());
+ EXPECT_FALSE(losesInfo);
+ EXPECT_EQ(status, APFloat::opOK);
+
+ // Test that the conversion of a power-of-two value is precise.
+ test = APFloat(APFloat::IEEEdouble(), "8.0");
+ status = test.convert(APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven,
+ &losesInfo);
+ EXPECT_EQ(8.0f, test.convertToDouble());
+ EXPECT_FALSE(losesInfo);
+ EXPECT_EQ(status, APFloat::opOK);
+
+ // Test to check round-down conversion to power-of-two.
+ // The fractional part of 9 is "001" (i.e. 1.125x2^3=9).
+ test = APFloat(APFloat::IEEEdouble(), "9.0");
+ status = test.convert(APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven,
+ &losesInfo);
+ EXPECT_EQ(8.0f, test.convertToDouble());
+ EXPECT_TRUE(losesInfo);
+ EXPECT_EQ(status, APFloat::opInexact);
+
+ // Test to check round-up conversion to power-of-two.
+ // The fractional part of 13 is "101" (i.e. 1.625x2^3=13).
+ test = APFloat(APFloat::IEEEdouble(), "13.0");
+ status = test.convert(APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven,
+ &losesInfo);
+ EXPECT_EQ(16.0f, test.convertToDouble());
+ EXPECT_TRUE(losesInfo);
+ EXPECT_EQ(status, APFloat::opInexact);
+
+ // Test to check round-up conversion to power-of-two.
+ // The fractional part of 12 is "100" (i.e. 1.5x2^3=12).
+ test = APFloat(APFloat::IEEEdouble(), "12.0");
+ status = test.convert(APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven,
+ &losesInfo);
+ EXPECT_EQ(16.0f, test.convertToDouble());
+ EXPECT_TRUE(losesInfo);
+ EXPECT_EQ(status, APFloat::opInexact);
+
+ // Overflow to NaN.
+ test = APFloat(APFloat::IEEEdouble(), "0x1.0p128");
+ status = test.convert(APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven,
+ &losesInfo);
+ EXPECT_TRUE(test.isNaN());
+ EXPECT_TRUE(losesInfo);
+ EXPECT_EQ(status, APFloat::opOverflow | APFloat::opInexact);
+
+ // Underflow to smallest normalized value.
+ test = APFloat(APFloat::IEEEdouble(), "0x1.0p-128");
+ status = test.convert(APFloat::Float8E8M0FNU(), APFloat::rmNearestTiesToEven,
+ &losesInfo);
+ EXPECT_TRUE(test.isSmallestNormalized());
+ EXPECT_TRUE(losesInfo);
+ EXPECT_EQ(status, APFloat::opUnderflow | APFloat::opInexact);
+}
+
TEST(APFloatTest, Float6E3M2FNFromString) {
// Exactly representable
EXPECT_EQ(28, APFloat(APFloat::Float6E3M2FN(), "28").convertToDouble());