diff options
author | David Majnemer <david.majnemer@gmail.com> | 2023-03-09 23:10:57 +0000 |
---|---|---|
committer | David Majnemer <david.majnemer@gmail.com> | 2023-03-24 20:06:40 +0000 |
commit | 2f086f265bf97fe6543fb199f4ef874ca3522479 (patch) | |
tree | 88227d7be116b19dd22f0a6939872d9be15853f2 /llvm/lib/Support/APFloat.cpp | |
parent | 5a9bad171be5dfdf9430a0f6cbff14d29ca54181 (diff) | |
download | llvm-2f086f265bf97fe6543fb199f4ef874ca3522479.zip llvm-2f086f265bf97fe6543fb199f4ef874ca3522479.tar.gz llvm-2f086f265bf97fe6543fb199f4ef874ca3522479.tar.bz2 |
[APFloat] Add E4M3B11FNUZ
X. Sun et al. (https://dl.acm.org/doi/10.5555/3454287.3454728) published
a paper showing that an FP format with 4 bits of exponent, 3 bits of
significand and an exponent bias of 11 would work quite well for ML
applications.
Google hardware supports a variant of this format where 0x80 is used to
represent NaN, as in the Float8E4M3FNUZ format. Just like the
Float8E4M3FNUZ format, this format does not support -0 and values which
would map to it will become +0.
This format is proposed for inclusion in OpenXLA's StableHLO dialect: https://github.com/openxla/stablehlo/pull/1308
As part of inclusion in that dialect, APFloat needs to know how to
handle this format.
Differential Revision: https://reviews.llvm.org/D146441
Diffstat (limited to 'llvm/lib/Support/APFloat.cpp')
-rw-r--r-- | llvm/lib/Support/APFloat.cpp | 77 |
1 files changed, 75 insertions, 2 deletions
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp index 0505382..97c811a 100644 --- a/llvm/lib/Support/APFloat.cpp +++ b/llvm/lib/Support/APFloat.cpp @@ -60,8 +60,9 @@ enum class fltNonfiniteBehavior { IEEE754, // This behavior is present in the Float8ExMyFN* types (Float8E4M3FN, - // Float8E5M2FNUZ, and Float8E4M3FNUZ). There is no representation for Inf, - // and operations that would ordinarily produce Inf produce NaN instead. + // Float8E5M2FNUZ, Float8E4M3FNUZ, and Float8E4M3B11FNUZ). There is no + // representation for Inf, and operations that would ordinarily produce Inf + // produce NaN instead. // The details of the NaN representation(s) in this form are determined by the // `fltNanEncoding` enum. We treat all NaNs as quiet, as the available // encodings do not distinguish between signalling and quiet NaN. @@ -138,6 +139,13 @@ struct fltSemantics { 8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes}; static const fltSemantics semFloat8E4M3FNUZ = { 7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero}; + static const fltSemantics semFloat8E4M3B11FNUZ = { + 4, + -10, + 4, + 8, + fltNonfiniteBehavior::NanOnly, + fltNanEncoding::NegativeZero}; static const fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80}; static const fltSemantics semBogus = {0, 0, 0, 0}; @@ -201,6 +209,8 @@ struct fltSemantics { return Float8E4M3FN(); case S_Float8E4M3FNUZ: return Float8E4M3FNUZ(); + case S_Float8E4M3B11FNUZ: + return Float8E4M3B11FNUZ(); case S_x87DoubleExtended: return x87DoubleExtended(); } @@ -229,6 +239,8 @@ struct fltSemantics { return S_Float8E4M3FN; else if (&Sem == &llvm::APFloat::Float8E4M3FNUZ()) return S_Float8E4M3FNUZ; + else if (&Sem == &llvm::APFloat::Float8E4M3B11FNUZ()) + return S_Float8E4M3B11FNUZ; else if (&Sem == &llvm::APFloat::x87DoubleExtended()) return S_x87DoubleExtended; else @@ -259,6 +271,9 @@ struct fltSemantics { const fltSemantics &APFloatBase::Float8E4M3FNUZ() { return semFloat8E4M3FNUZ; } + const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() { + return semFloat8E4M3B11FNUZ; + } const fltSemantics &APFloatBase::x87DoubleExtended() { return semX87DoubleExtended; } @@ -3709,6 +3724,33 @@ APInt IEEEFloat::convertFloat8E4M3FNUZAPFloatToAPInt() const { (mysignificand & 0x7))); } +APInt IEEEFloat::convertFloat8E4M3B11FNUZAPFloatToAPInt() const { + assert(semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ); + assert(partCount() == 1); + + uint32_t myexponent, mysignificand; + + if (isFiniteNonZero()) { + myexponent = exponent + 11; // bias + mysignificand = (uint32_t)*significandParts(); + if (myexponent == 1 && !(mysignificand & 0x8)) + myexponent = 0; // denormal + } else if (category == fcZero) { + myexponent = 0; + mysignificand = 0; + } else if (category == fcInfinity) { + myexponent = 0; + mysignificand = 0; + } else { + assert(category == fcNaN && "Unknown category!"); + myexponent = 0; + mysignificand = (uint32_t)*significandParts(); + } + + return APInt(8, (((sign & 1) << 7) | ((myexponent & 0xf) << 3) | + (mysignificand & 0x7))); +} + // This function creates an APInt that is just a bit map of the floating // point constant as it would appear in memory. It is not a conversion, // and treating the result as a normal integer is unlikely to be useful. @@ -3744,6 +3786,9 @@ APInt IEEEFloat::bitcastToAPInt() const { if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FNUZ) return convertFloat8E4M3FNUZAPFloatToAPInt(); + if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ) + return convertFloat8E4M3B11FNUZAPFloatToAPInt(); + assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended && "unknown format!"); return convertF80LongDoubleAPFloatToAPInt(); @@ -4077,6 +4122,32 @@ void IEEEFloat::initFromFloat8E4M3FNUZAPInt(const APInt &api) { } } +void IEEEFloat::initFromFloat8E4M3B11FNUZAPInt(const APInt &api) { + uint32_t i = (uint32_t)*api.getRawData(); + uint32_t myexponent = (i >> 3) & 0xf; + uint32_t mysignificand = i & 0x7; + + initialize(&semFloat8E4M3B11FNUZ); + assert(partCount() == 1); + + sign = i >> 7; + if (myexponent == 0 && mysignificand == 0 && sign == 0) { + makeZero(sign); + } else if (myexponent == 0 && mysignificand == 0 && sign == 1) { + category = fcNaN; + exponent = exponentNaN(); + *significandParts() = mysignificand; + } else { + category = fcNormal; + exponent = myexponent - 11; // bias + *significandParts() = mysignificand; + if (myexponent == 0) // denormal + exponent = -10; + else + *significandParts() |= 0x8; // integer bit + } +} + /// Treat api as containing the bits of a floating point number. void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) { assert(api.getBitWidth() == Sem->sizeInBits); @@ -4102,6 +4173,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) { return initFromFloat8E4M3FNAPInt(api); if (Sem == &semFloat8E4M3FNUZ) return initFromFloat8E4M3FNUZAPInt(api); + if (Sem == &semFloat8E4M3B11FNUZ) + return initFromFloat8E4M3B11FNUZAPInt(api); llvm_unreachable(nullptr); } |