aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib
diff options
context:
space:
mode:
author: Durgadoss R <durgadossr@nvidia.com>, 2024-06-14 14:17:37 +0530
committer: GitHub <noreply@github.com>, 2024-06-14 14:17:37 +0530
commit: 880d37038c7bbff53ef02c9d6b01cbbc87875243 (patch)
tree: ed2d892ed0615eb537efd84a5356e83ec59948a3 /llvm/lib
parent: e83adfe59632d2e2f8ff26db33087ba7fb754485 (diff)
download: llvm-880d37038c7bbff53ef02c9d6b01cbbc87875243.zip
llvm-880d37038c7bbff53ef02c9d6b01cbbc87875243.tar.gz
llvm-880d37038c7bbff53ef02c9d6b01cbbc87875243.tar.bz2
[APFloat] Add APFloat support for FP4 data type (#95392)
This patch adds APFloat type support for the E2M1 FP4 data type. The definitions for this format are detailed in section 5.3.3 of the OCP Microscaling Formats (MX) v1.0 specification, which can be accessed here: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
Diffstat (limited to 'llvm/lib')
-rw-r--r-- llvm/lib/Support/APFloat.cpp 25
1 files changed, 23 insertions, 2 deletions
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index 1209bf7..47618bc 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -69,8 +69,8 @@ enum class fltNonfiniteBehavior {
// encodings do not distinguish between signalling and quiet NaN.
NanOnly,
- // This behavior is present in Float6E3M2FN and Float6E2M3FN types,
- // which do not support Inf or NaN values.
+ // This behavior is present in Float6E3M2FN, Float6E2M3FN, and
+ // Float4E2M1FN types, which do not support Inf or NaN values.
FiniteOnly,
};
@@ -147,6 +147,8 @@ static constexpr fltSemantics semFloat6E3M2FN = {
4, -2, 3, 6, fltNonfiniteBehavior::FiniteOnly};
static constexpr fltSemantics semFloat6E2M3FN = {
2, 0, 4, 6, fltNonfiniteBehavior::FiniteOnly};
+static constexpr fltSemantics semFloat4E2M1FN = {
+ 2, 0, 2, 4, fltNonfiniteBehavior::FiniteOnly};
static constexpr fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80};
static constexpr fltSemantics semBogus = {0, 0, 0, 0};
@@ -218,6 +220,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) {
return Float6E3M2FN();
case S_Float6E2M3FN:
return Float6E2M3FN();
+ case S_Float4E2M1FN:
+ return Float4E2M1FN();
case S_x87DoubleExtended:
return x87DoubleExtended();
}
@@ -254,6 +258,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
return S_Float6E3M2FN;
else if (&Sem == &llvm::APFloat::Float6E2M3FN())
return S_Float6E2M3FN;
+ else if (&Sem == &llvm::APFloat::Float4E2M1FN())
+ return S_Float4E2M1FN;
else if (&Sem == &llvm::APFloat::x87DoubleExtended())
return S_x87DoubleExtended;
else
@@ -278,6 +284,7 @@ const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; }
const fltSemantics &APFloatBase::Float6E3M2FN() { return semFloat6E3M2FN; }
const fltSemantics &APFloatBase::Float6E2M3FN() { return semFloat6E2M3FN; }
+const fltSemantics &APFloatBase::Float4E2M1FN() { return semFloat4E2M1FN; }
const fltSemantics &APFloatBase::x87DoubleExtended() {
return semX87DoubleExtended;
}
@@ -3640,6 +3647,11 @@ APInt IEEEFloat::convertFloat6E2M3FNAPFloatToAPInt() const {
return convertIEEEFloatToAPInt<semFloat6E2M3FN>();
}
+APInt IEEEFloat::convertFloat4E2M1FNAPFloatToAPInt() const {
+ assert(partCount() == 1);
+ return convertIEEEFloatToAPInt<semFloat4E2M1FN>();
+}
+
// This function creates an APInt that is just a bit map of the floating
// point constant as it would appear in memory. It is not a conversion,
// and treating the result as a normal integer is unlikely to be useful.
@@ -3687,6 +3699,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
if (semantics == (const llvm::fltSemantics *)&semFloat6E2M3FN)
return convertFloat6E2M3FNAPFloatToAPInt();
+ if (semantics == (const llvm::fltSemantics *)&semFloat4E2M1FN)
+ return convertFloat4E2M1FNAPFloatToAPInt();
+
assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended &&
"unknown format!");
return convertF80LongDoubleAPFloatToAPInt();
@@ -3911,6 +3926,10 @@ void IEEEFloat::initFromFloat6E2M3FNAPInt(const APInt &api) {
initFromIEEEAPInt<semFloat6E2M3FN>(api);
}
+void IEEEFloat::initFromFloat4E2M1FNAPInt(const APInt &api) {
+ initFromIEEEAPInt<semFloat4E2M1FN>(api);
+}
+
/// Treat api as containing the bits of a floating point number.
void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
assert(api.getBitWidth() == Sem->sizeInBits);
@@ -3944,6 +3963,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
return initFromFloat6E3M2FNAPInt(api);
if (Sem == &semFloat6E2M3FN)
return initFromFloat6E2M3FNAPInt(api);
+ if (Sem == &semFloat4E2M1FN)
+ return initFromFloat4E2M1FNAPInt(api);
llvm_unreachable(nullptr);
}