aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Support/APFloat.cpp
diff options
context:
space:
mode:
authorDavid Majnemer <david.majnemer@gmail.com>2023-03-09 23:10:57 +0000
committerDavid Majnemer <david.majnemer@gmail.com>2023-03-24 20:06:40 +0000
commit2f086f265bf97fe6543fb199f4ef874ca3522479 (patch)
tree88227d7be116b19dd22f0a6939872d9be15853f2 /llvm/lib/Support/APFloat.cpp
parent5a9bad171be5dfdf9430a0f6cbff14d29ca54181 (diff)
downloadllvm-2f086f265bf97fe6543fb199f4ef874ca3522479.zip
llvm-2f086f265bf97fe6543fb199f4ef874ca3522479.tar.gz
llvm-2f086f265bf97fe6543fb199f4ef874ca3522479.tar.bz2
[APFloat] Add E4M3B11FNUZ
X. Sun et al. (https://dl.acm.org/doi/10.5555/3454287.3454728) published a paper showing that an FP format with 4 bits of exponent, 3 bits of significand and an exponent bias of 11 would work quite well for ML applications. Google hardware supports a variant of this format where 0x80 is used to represent NaN, as in the Float8E4M3FNUZ format. Just like the Float8E4M3FNUZ format, this format does not support -0 and values which would map to it will become +0. This format is proposed for inclusion in OpenXLA's StableHLO dialect: https://github.com/openxla/stablehlo/pull/1308 As part of inclusion in that dialect, APFloat needs to know how to handle this format. Differential Revision: https://reviews.llvm.org/D146441
Diffstat (limited to 'llvm/lib/Support/APFloat.cpp')
-rw-r--r--llvm/lib/Support/APFloat.cpp77
1 files changed, 75 insertions, 2 deletions
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index 0505382..97c811a 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -60,8 +60,9 @@ enum class fltNonfiniteBehavior {
IEEE754,
// This behavior is present in the Float8ExMyFN* types (Float8E4M3FN,
- // Float8E5M2FNUZ, and Float8E4M3FNUZ). There is no representation for Inf,
- // and operations that would ordinarily produce Inf produce NaN instead.
+ // Float8E5M2FNUZ, Float8E4M3FNUZ, and Float8E4M3B11FNUZ). There is no
+ // representation for Inf, and operations that would ordinarily produce Inf
+ // produce NaN instead.
// The details of the NaN representation(s) in this form are determined by the
// `fltNanEncoding` enum. We treat all NaNs as quiet, as the available
// encodings do not distinguish between signalling and quiet NaN.
@@ -138,6 +139,13 @@ struct fltSemantics {
8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes};
static const fltSemantics semFloat8E4M3FNUZ = {
7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
+ static const fltSemantics semFloat8E4M3B11FNUZ = { // Semantics descriptor for the 8-bit E4M3 format with exponent bias 11.
+ 4, // presumably the maximum exponent (largest exponent field 15 minus bias 11) -- field order is not visible here, confirm
+ -10, // presumably the minimum exponent; matches the exponent assigned to denormals when decoding
+ 4, // presumably the precision: 3 stored significand bits plus the integer bit (0x8)
+ 8, // presumably the total encoding width in bits
+ fltNonfiniteBehavior::NanOnly, // no Inf representation; see the comment on fltNonfiniteBehavior
+ fltNanEncoding::NegativeZero}; // the bit pattern 0x80 (sign set, all other bits zero) is decoded as NaN
static const fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80};
static const fltSemantics semBogus = {0, 0, 0, 0};
@@ -201,6 +209,8 @@ struct fltSemantics {
return Float8E4M3FN();
case S_Float8E4M3FNUZ:
return Float8E4M3FNUZ();
+ case S_Float8E4M3B11FNUZ:
+ return Float8E4M3B11FNUZ();
case S_x87DoubleExtended:
return x87DoubleExtended();
}
@@ -229,6 +239,8 @@ struct fltSemantics {
return S_Float8E4M3FN;
else if (&Sem == &llvm::APFloat::Float8E4M3FNUZ())
return S_Float8E4M3FNUZ;
+ else if (&Sem == &llvm::APFloat::Float8E4M3B11FNUZ())
+ return S_Float8E4M3B11FNUZ;
else if (&Sem == &llvm::APFloat::x87DoubleExtended())
return S_x87DoubleExtended;
else
@@ -259,6 +271,9 @@ struct fltSemantics {
const fltSemantics &APFloatBase::Float8E4M3FNUZ() {
return semFloat8E4M3FNUZ;
}
+ const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() { // Returns a reference to the file-static Float8E4M3B11FNUZ semantics object.
+ return semFloat8E4M3B11FNUZ;
+ }
const fltSemantics &APFloatBase::x87DoubleExtended() {
return semX87DoubleExtended;
}
@@ -3709,6 +3724,33 @@ APInt IEEEFloat::convertFloat8E4M3FNUZAPFloatToAPInt() const {
(mysignificand & 0x7)));
}
+APInt IEEEFloat::convertFloat8E4M3B11FNUZAPFloatToAPInt() const { // Encodes this value as its 8-bit Float8E4M3B11FNUZ bit pattern.
+ assert(semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ);
+ assert(partCount() == 1);
+ // Layout of the returned byte: bit 7 = sign, bits 6-3 = biased exponent, bits 2-0 = stored significand.
+ uint32_t myexponent, mysignificand;
+ // Every branch below assigns both fields before they are packed.
+ if (isFiniteNonZero()) {
+ myexponent = exponent + 11; // add the exponent bias of 11
+ mysignificand = (uint32_t)*significandParts();
+ if (myexponent == 1 && !(mysignificand & 0x8)) // smallest exponent with the integer bit (0x8) clear
+ myexponent = 0; // denormal: stored with an exponent field of zero
+ } else if (category == fcZero) {
+ myexponent = 0; // NOTE(review): if `sign` were set here the result would be 0x80, which decodes as NaN -- presumably zeros in this format are always positive; confirm
+ mysignificand = 0;
+ } else if (category == fcInfinity) {
+ myexponent = 0; // NOTE(review): the format is NanOnly (no Inf encoding), so this branch is presumably unreachable -- confirm
+ mysignificand = 0;
+ } else {
+ assert(category == fcNaN && "Unknown category!");
+ myexponent = 0; // NaN is encoded with a zero exponent field
+ mysignificand = (uint32_t)*significandParts(); // NOTE(review): decoding expects NaN to be exactly 0x80, so this presumably relies on the low 3 bits being zero and `sign` being set -- confirm
+ }
+ // Mask each field to its width and assemble the byte.
+ return APInt(8, (((sign & 1) << 7) | ((myexponent & 0xf) << 3) |
+ (mysignificand & 0x7)));
+}
+
// This function creates an APInt that is just a bit map of the floating
// point constant as it would appear in memory. It is not a conversion,
// and treating the result as a normal integer is unlikely to be useful.
@@ -3744,6 +3786,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FNUZ)
return convertFloat8E4M3FNUZAPFloatToAPInt();
+ if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ)
+ return convertFloat8E4M3B11FNUZAPFloatToAPInt();
+
assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended &&
"unknown format!");
return convertF80LongDoubleAPFloatToAPInt();
@@ -4077,6 +4122,32 @@ void IEEEFloat::initFromFloat8E4M3FNUZAPInt(const APInt &api) {
}
}
+void IEEEFloat::initFromFloat8E4M3B11FNUZAPInt(const APInt &api) { // Initializes this object from the Float8E4M3B11FNUZ bit pattern held in `api`.
+ uint32_t i = (uint32_t)*api.getRawData();
+ uint32_t myexponent = (i >> 3) & 0xf; // bits 6-3: biased exponent field
+ uint32_t mysignificand = i & 0x7; // bits 2-0: stored significand (no integer bit)
+ // Switch this object to the target semantics before filling in its fields.
+ initialize(&semFloat8E4M3B11FNUZ);
+ assert(partCount() == 1);
+ // Bit 7 is the sign; together with all-zero exponent/significand it selects zero (0x00) or NaN (0x80).
+ sign = i >> 7;
+ if (myexponent == 0 && mysignificand == 0 && sign == 0) {
+ makeZero(sign); // 0x00: zero (sign is 0 on this path)
+ } else if (myexponent == 0 && mysignificand == 0 && sign == 1) {
+ category = fcNaN; // 0x80: the pattern that would otherwise be negative zero is NaN
+ exponent = exponentNaN();
+ *significandParts() = mysignificand; // always 0 on this path
+ } else {
+ category = fcNormal;
+ exponent = myexponent - 11; // remove the exponent bias of 11
+ *significandParts() = mysignificand;
+ if (myexponent == 0) // denormal: exponent field is zero, no implicit integer bit
+ exponent = -10; // denormals use the fixed exponent -10
+ else
+ *significandParts() |= 0x8; // normal value: restore the implicit integer bit
+ }
+}
+
/// Treat api as containing the bits of a floating point number.
void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
assert(api.getBitWidth() == Sem->sizeInBits);
@@ -4102,6 +4173,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
return initFromFloat8E4M3FNAPInt(api);
if (Sem == &semFloat8E4M3FNUZ)
return initFromFloat8E4M3FNUZAPInt(api);
+ if (Sem == &semFloat8E4M3B11FNUZ)
+ return initFromFloat8E4M3B11FNUZAPInt(api);
llvm_unreachable(nullptr);
}