aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/IR/Constants.cpp
diff options
context:
space:
mode:
authorTies Stuij <ties.stuij@arm.com>2020-03-31 23:49:38 +0100
committerTies Stuij <ties.stuij@arm.com>2020-05-15 14:43:43 +0100
commit8c24f33158d81d5f4b0c5d27c2f07396f0f1484b (patch)
treeb78f8dec4d437ddaad0c62f98ef087f19da271bf /llvm/lib/IR/Constants.cpp
parent7063a83a7cca45a9b12a7e447c90abe681f6ebaf (diff)
downloadllvm-8c24f33158d81d5f4b0c5d27c2f07396f0f1484b.zip
llvm-8c24f33158d81d5f4b0c5d27c2f07396f0f1484b.tar.gz
llvm-8c24f33158d81d5f4b0c5d27c2f07396f0f1484b.tar.bz2
[IR][BFloat] Add BFloat IR type
Summary: The BFloat IR type is introduced to provide support for, initially, the BFloat16 datatype introduced with the Armv8.6 architecture (optional from Armv8.2 onwards). It has an 8-bit exponent and a 7-bit mantissa and behaves like an IEEE 754 floating point IR type. This is part of a patch series upstreaming Armv8.6 features. Subsequent patches will upstream intrinsics support and C-lang support for BFloat. Reviewers: SjoerdMeijer, rjmccall, rsmith, liutianle, RKSimon, craig.topper, jfb, LukeGeeson, sdesmalen, deadalnix, ctetreau Subscribers: hiraditya, llvm-commits, danielkiss, arphaman, kristof.beyls, dexonsmith Tags: #llvm Differential Revision: https://reviews.llvm.org/D78190
Diffstat (limited to 'llvm/lib/IR/Constants.cpp')
-rw-r--r--llvm/lib/IR/Constants.cpp106
1 files changed, 71 insertions, 35 deletions
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 5a3c6a4..88971d8 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -332,6 +332,9 @@ Constant *Constant::getNullValue(Type *Ty) {
case Type::HalfTyID:
return ConstantFP::get(Ty->getContext(),
APFloat::getZero(APFloat::IEEEhalf()));
+ case Type::BFloatTyID:
+ return ConstantFP::get(Ty->getContext(),
+ APFloat::getZero(APFloat::BFloat()));
case Type::FloatTyID:
return ConstantFP::get(Ty->getContext(),
APFloat::getZero(APFloat::IEEEsingle()));
@@ -386,8 +389,8 @@ Constant *Constant::getAllOnesValue(Type *Ty) {
APInt::getAllOnesValue(ITy->getBitWidth()));
if (Ty->isFloatingPointTy()) {
- APFloat FL = APFloat::getAllOnesValue(Ty->getPrimitiveSizeInBits(),
- !Ty->isPPC_FP128Ty());
+ APFloat FL = APFloat::getAllOnesValue(Ty->getFltSemantics(),
+ Ty->getPrimitiveSizeInBits());
return ConstantFP::get(Ty->getContext(), FL);
}
@@ -763,6 +766,8 @@ void ConstantInt::destroyConstantImpl() {
static const fltSemantics *TypeToFloatSemantics(Type *Ty) {
if (Ty->isHalfTy())
return &APFloat::IEEEhalf();
+ if (Ty->isBFloatTy())
+ return &APFloat::BFloat();
if (Ty->isFloatTy())
return &APFloat::IEEEsingle();
if (Ty->isDoubleTy())
@@ -880,6 +885,8 @@ ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) {
Type *Ty;
if (&V.getSemantics() == &APFloat::IEEEhalf())
Ty = Type::getHalfTy(Context);
+ else if (&V.getSemantics() == &APFloat::BFloat())
+ Ty = Type::getBFloatTy(Context);
else if (&V.getSemantics() == &APFloat::IEEEsingle())
Ty = Type::getFloatTy(Context);
else if (&V.getSemantics() == &APFloat::IEEEdouble())
@@ -1029,7 +1036,7 @@ static Constant *getFPSequenceIfElementsMatch(ArrayRef<Constant *> V) {
Elts.push_back(CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
else
return nullptr;
- return SequentialTy::getFP(V[0]->getContext(), Elts);
+ return SequentialTy::getFP(V[0]->getType(), Elts);
}
template <typename SequenceTy>
@@ -1048,7 +1055,7 @@ static Constant *getSequenceIfElementsMatch(Constant *C,
else if (CI->getType()->isIntegerTy(64))
return getIntSequenceIfElementsMatch<SequenceTy, uint64_t>(V);
} else if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
- if (CFP->getType()->isHalfTy())
+ if (CFP->getType()->isHalfTy() || CFP->getType()->isBFloatTy())
return getFPSequenceIfElementsMatch<SequenceTy, uint16_t>(V);
else if (CFP->getType()->isFloatTy())
return getFPSequenceIfElementsMatch<SequenceTy, uint32_t>(V);
@@ -1421,6 +1428,12 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) {
Val2.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &losesInfo);
return !losesInfo;
}
+ case Type::BFloatTyID: {
+ if (&Val2.getSemantics() == &APFloat::BFloat())
+ return true;
+ Val2.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &losesInfo);
+ return !losesInfo;
+ }
case Type::FloatTyID: {
if (&Val2.getSemantics() == &APFloat::IEEEsingle())
return true;
@@ -1429,6 +1442,7 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) {
}
case Type::DoubleTyID: {
if (&Val2.getSemantics() == &APFloat::IEEEhalf() ||
+ &Val2.getSemantics() == &APFloat::BFloat() ||
&Val2.getSemantics() == &APFloat::IEEEsingle() ||
&Val2.getSemantics() == &APFloat::IEEEdouble())
return true;
@@ -1437,16 +1451,19 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) {
}
case Type::X86_FP80TyID:
return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+ &Val2.getSemantics() == &APFloat::BFloat() ||
&Val2.getSemantics() == &APFloat::IEEEsingle() ||
&Val2.getSemantics() == &APFloat::IEEEdouble() ||
&Val2.getSemantics() == &APFloat::x87DoubleExtended();
case Type::FP128TyID:
return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+ &Val2.getSemantics() == &APFloat::BFloat() ||
&Val2.getSemantics() == &APFloat::IEEEsingle() ||
&Val2.getSemantics() == &APFloat::IEEEdouble() ||
&Val2.getSemantics() == &APFloat::IEEEquad();
case Type::PPC_FP128TyID:
return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+ &Val2.getSemantics() == &APFloat::BFloat() ||
&Val2.getSemantics() == &APFloat::IEEEsingle() ||
&Val2.getSemantics() == &APFloat::IEEEdouble() ||
&Val2.getSemantics() == &APFloat::PPCDoubleDouble();
@@ -2562,7 +2579,8 @@ StringRef ConstantDataSequential::getRawDataValues() const {
}
bool ConstantDataSequential::isElementTypeCompatible(Type *Ty) {
- if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) return true;
+ if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() || Ty->isDoubleTy())
+ return true;
if (auto *IT = dyn_cast<IntegerType>(Ty)) {
switch (IT->getBitWidth()) {
case 8:
@@ -2680,26 +2698,29 @@ void ConstantDataSequential::destroyConstantImpl() {
Next = nullptr;
}
-/// getFP() constructors - Return a constant with array type with an element
-/// count and element type of float with precision matching the number of
-/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
-/// double for 64bits) Note that this can return a ConstantAggregateZero
-/// object.
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
- ArrayRef<uint16_t> Elts) {
- Type *Ty = ArrayType::get(Type::getHalfTy(Context), Elts.size());
+/// getFP() constructors - Return a constant of array type with a float
+/// element type taken from argument `ElementType', and count taken from
+/// argument `Elts'. The amount of bits of the contained type must match the
+/// number of bits of the type contained in the passed in ArrayRef.
+/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+/// that this can return a ConstantAggregateZero object.
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint16_t> Elts) {
+ assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) &&
+ "Element type is not a 16-bit float type");
+ Type *Ty = ArrayType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 2), Ty);
}
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
- ArrayRef<uint32_t> Elts) {
- Type *Ty = ArrayType::get(Type::getFloatTy(Context), Elts.size());
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint32_t> Elts) {
+ assert(ElementType->isFloatTy() && "Element type is not a 32-bit float type");
+ Type *Ty = ArrayType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 4), Ty);
}
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
- ArrayRef<uint64_t> Elts) {
- Type *Ty = ArrayType::get(Type::getDoubleTy(Context), Elts.size());
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint64_t> Elts) {
+ assert(ElementType->isDoubleTy() &&
+ "Element type is not a 64-bit float type");
+ Type *Ty = ArrayType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 8), Ty);
}
@@ -2751,26 +2772,32 @@ Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<double> Elts) {
return getImpl(StringRef(Data, Elts.size() * 8), Ty);
}
-/// getFP() constructors - Return a constant with vector type with an element
-/// count and element type of float with the precision matching the number of
-/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
-/// double for 64bits) Note that this can return a ConstantAggregateZero
-/// object.
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+/// getFP() constructors - Return a constant of vector type with a float
+/// element type taken from argument `ElementType', and count taken from
+/// argument `Elts'. The amount of bits of the contained type must match the
+/// number of bits of the type contained in the passed in ArrayRef.
+/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+/// that this can return a ConstantAggregateZero object.
+Constant *ConstantDataVector::getFP(Type *ElementType,
ArrayRef<uint16_t> Elts) {
- Type *Ty = VectorType::get(Type::getHalfTy(Context), Elts.size());
+ assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) &&
+ "Element type is not a 16-bit float type");
+ Type *Ty = VectorType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 2), Ty);
}
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+Constant *ConstantDataVector::getFP(Type *ElementType,
ArrayRef<uint32_t> Elts) {
- Type *Ty = VectorType::get(Type::getFloatTy(Context), Elts.size());
+ assert(ElementType->isFloatTy() && "Element type is not a 32-bit float type");
+ Type *Ty = VectorType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 4), Ty);
}
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+Constant *ConstantDataVector::getFP(Type *ElementType,
ArrayRef<uint64_t> Elts) {
- Type *Ty = VectorType::get(Type::getDoubleTy(Context), Elts.size());
+ assert(ElementType->isDoubleTy() &&
+ "Element type is not a 64-bit float type");
+ Type *Ty = VectorType::get(ElementType, Elts.size());
const char *Data = reinterpret_cast<const char *>(Elts.data());
return getImpl(StringRef(Data, Elts.size() * 8), Ty);
}
@@ -2800,17 +2827,22 @@ Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) {
if (CFP->getType()->isHalfTy()) {
SmallVector<uint16_t, 16> Elts(
NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
- return getFP(V->getContext(), Elts);
+ return getFP(V->getType(), Elts);
+ }
+ if (CFP->getType()->isBFloatTy()) {
+ SmallVector<uint16_t, 16> Elts(
+ NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
+ return getFP(V->getType(), Elts);
}
if (CFP->getType()->isFloatTy()) {
SmallVector<uint32_t, 16> Elts(
NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
- return getFP(V->getContext(), Elts);
+ return getFP(V->getType(), Elts);
}
if (CFP->getType()->isDoubleTy()) {
SmallVector<uint64_t, 16> Elts(
NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
- return getFP(V->getContext(), Elts);
+ return getFP(V->getType(), Elts);
}
}
return ConstantVector::getSplat({NumElts, false}, V);
@@ -2875,6 +2907,10 @@ APFloat ConstantDataSequential::getElementAsAPFloat(unsigned Elt) const {
auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr);
return APFloat(APFloat::IEEEhalf(), APInt(16, EltVal));
}
+ case Type::BFloatTyID: {
+ auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr);
+ return APFloat(APFloat::BFloat(), APInt(16, EltVal));
+ }
case Type::FloatTyID: {
auto EltVal = *reinterpret_cast<const uint32_t *>(EltPtr);
return APFloat(APFloat::IEEEsingle(), APInt(32, EltVal));
@@ -2899,8 +2935,8 @@ double ConstantDataSequential::getElementAsDouble(unsigned Elt) const {
}
Constant *ConstantDataSequential::getElementAsConstant(unsigned Elt) const {
- if (getElementType()->isHalfTy() || getElementType()->isFloatTy() ||
- getElementType()->isDoubleTy())
+ if (getElementType()->isHalfTy() || getElementType()->isBFloatTy() ||
+ getElementType()->isFloatTy() || getElementType()->isDoubleTy())
return ConstantFP::get(getContext(), getElementAsAPFloat(Elt));
return ConstantInt::get(getElementType(), getElementAsInteger(Elt));