diff options
author | Ties Stuij <ties.stuij@arm.com> | 2020-03-31 23:49:38 +0100 |
---|---|---|
committer | Ties Stuij <ties.stuij@arm.com> | 2020-05-15 14:43:43 +0100 |
commit | 8c24f33158d81d5f4b0c5d27c2f07396f0f1484b (patch) | |
tree | b78f8dec4d437ddaad0c62f98ef087f19da271bf /llvm/lib/IR/Constants.cpp | |
parent | 7063a83a7cca45a9b12a7e447c90abe681f6ebaf (diff) | |
download | llvm-8c24f33158d81d5f4b0c5d27c2f07396f0f1484b.zip llvm-8c24f33158d81d5f4b0c5d27c2f07396f0f1484b.tar.gz llvm-8c24f33158d81d5f4b0c5d27c2f07396f0f1484b.tar.bz2 |
[IR][BFloat] Add BFloat IR type
Summary:
The BFloat IR type is introduced to provide support for, initially, the BFloat16
datatype introduced with the Armv8.6 architecture (optional from Armv8.2
onwards). It has an 8-bit exponent and a 7-bit mantissa and behaves like an IEEE
754 floating point IR type.
This is part of a patch series upstreaming Armv8.6 features. Subsequent patches
will upstream intrinsics support and C-lang support for BFloat.
Reviewers: SjoerdMeijer, rjmccall, rsmith, liutianle, RKSimon, craig.topper, jfb, LukeGeeson, sdesmalen, deadalnix, ctetreau
Subscribers: hiraditya, llvm-commits, danielkiss, arphaman, kristof.beyls, dexonsmith
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D78190
Diffstat (limited to 'llvm/lib/IR/Constants.cpp')
-rw-r--r-- | llvm/lib/IR/Constants.cpp | 106 |
1 files changed, 71 insertions, 35 deletions
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index 5a3c6a4..88971d8 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -332,6 +332,9 @@ Constant *Constant::getNullValue(Type *Ty) { case Type::HalfTyID: return ConstantFP::get(Ty->getContext(), APFloat::getZero(APFloat::IEEEhalf())); + case Type::BFloatTyID: + return ConstantFP::get(Ty->getContext(), + APFloat::getZero(APFloat::BFloat())); case Type::FloatTyID: return ConstantFP::get(Ty->getContext(), APFloat::getZero(APFloat::IEEEsingle())); @@ -386,8 +389,8 @@ Constant *Constant::getAllOnesValue(Type *Ty) { APInt::getAllOnesValue(ITy->getBitWidth())); if (Ty->isFloatingPointTy()) { - APFloat FL = APFloat::getAllOnesValue(Ty->getPrimitiveSizeInBits(), - !Ty->isPPC_FP128Ty()); + APFloat FL = APFloat::getAllOnesValue(Ty->getFltSemantics(), + Ty->getPrimitiveSizeInBits()); return ConstantFP::get(Ty->getContext(), FL); } @@ -763,6 +766,8 @@ void ConstantInt::destroyConstantImpl() { static const fltSemantics *TypeToFloatSemantics(Type *Ty) { if (Ty->isHalfTy()) return &APFloat::IEEEhalf(); + if (Ty->isBFloatTy()) + return &APFloat::BFloat(); if (Ty->isFloatTy()) return &APFloat::IEEEsingle(); if (Ty->isDoubleTy()) @@ -880,6 +885,8 @@ ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) { Type *Ty; if (&V.getSemantics() == &APFloat::IEEEhalf()) Ty = Type::getHalfTy(Context); + else if (&V.getSemantics() == &APFloat::BFloat()) + Ty = Type::getBFloatTy(Context); else if (&V.getSemantics() == &APFloat::IEEEsingle()) Ty = Type::getFloatTy(Context); else if (&V.getSemantics() == &APFloat::IEEEdouble()) @@ -1029,7 +1036,7 @@ static Constant *getFPSequenceIfElementsMatch(ArrayRef<Constant *> V) { Elts.push_back(CFP->getValueAPF().bitcastToAPInt().getLimitedValue()); else return nullptr; - return SequentialTy::getFP(V[0]->getContext(), Elts); + return SequentialTy::getFP(V[0]->getType(), Elts); } template <typename SequenceTy> @@ -1048,7 +1055,7 @@ static Constant *getSequenceIfElementsMatch(Constant *C, else if (CI->getType()->isIntegerTy(64)) return getIntSequenceIfElementsMatch<SequenceTy, uint64_t>(V); } else if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) { - if (CFP->getType()->isHalfTy()) + if (CFP->getType()->isHalfTy() || CFP->getType()->isBFloatTy()) return getFPSequenceIfElementsMatch<SequenceTy, uint16_t>(V); else if (CFP->getType()->isFloatTy()) return getFPSequenceIfElementsMatch<SequenceTy, uint32_t>(V); @@ -1421,6 +1428,12 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) { Val2.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &losesInfo); return !losesInfo; } + case Type::BFloatTyID: { + if (&Val2.getSemantics() == &APFloat::BFloat()) + return true; + Val2.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &losesInfo); + return !losesInfo; + } case Type::FloatTyID: { if (&Val2.getSemantics() == &APFloat::IEEEsingle()) return true; @@ -1429,6 +1442,7 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) { } case Type::DoubleTyID: { if (&Val2.getSemantics() == &APFloat::IEEEhalf() || + &Val2.getSemantics() == &APFloat::BFloat() || &Val2.getSemantics() == &APFloat::IEEEsingle() || &Val2.getSemantics() == &APFloat::IEEEdouble()) return true; @@ -1437,16 +1451,19 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat& Val) { } case Type::X86_FP80TyID: return &Val2.getSemantics() == &APFloat::IEEEhalf() || + &Val2.getSemantics() == &APFloat::BFloat() || &Val2.getSemantics() == &APFloat::IEEEsingle() || &Val2.getSemantics() == &APFloat::IEEEdouble() || &Val2.getSemantics() == &APFloat::x87DoubleExtended(); case Type::FP128TyID: return &Val2.getSemantics() == &APFloat::IEEEhalf() || + &Val2.getSemantics() == &APFloat::BFloat() || &Val2.getSemantics() == &APFloat::IEEEsingle() || &Val2.getSemantics() == &APFloat::IEEEdouble() || &Val2.getSemantics() == &APFloat::IEEEquad(); case Type::PPC_FP128TyID: return &Val2.getSemantics() == &APFloat::IEEEhalf() || + &Val2.getSemantics() == &APFloat::BFloat() || &Val2.getSemantics() == &APFloat::IEEEsingle() || &Val2.getSemantics() == &APFloat::IEEEdouble() || &Val2.getSemantics() == &APFloat::PPCDoubleDouble(); @@ -2562,7 +2579,8 @@ StringRef ConstantDataSequential::getRawDataValues() const { } bool ConstantDataSequential::isElementTypeCompatible(Type *Ty) { - if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) return true; + if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() || Ty->isDoubleTy()) + return true; if (auto *IT = dyn_cast<IntegerType>(Ty)) { switch (IT->getBitWidth()) { case 8: @@ -2680,26 +2698,29 @@ void ConstantDataSequential::destroyConstantImpl() { Next = nullptr; } -/// getFP() constructors - Return a constant with array type with an element -/// count and element type of float with precision matching the number of -/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits, -/// double for 64bits) Note that this can return a ConstantAggregateZero -/// object. -Constant *ConstantDataArray::getFP(LLVMContext &Context, - ArrayRef<uint16_t> Elts) { - Type *Ty = ArrayType::get(Type::getHalfTy(Context), Elts.size()); +/// getFP() constructors - Return a constant of array type with a float +/// element type taken from argument `ElementType', and count taken from +/// argument `Elts'. The amount of bits of the contained type must match the +/// number of bits of the type contained in the passed in ArrayRef. +/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note +/// that this can return a ConstantAggregateZero object. +Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint16_t> Elts) { + assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) && + "Element type is not a 16-bit float type"); + Type *Ty = ArrayType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast<const char *>(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 2), Ty); } -Constant *ConstantDataArray::getFP(LLVMContext &Context, - ArrayRef<uint32_t> Elts) { - Type *Ty = ArrayType::get(Type::getFloatTy(Context), Elts.size()); +Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint32_t> Elts) { + assert(ElementType->isFloatTy() && "Element type is not a 32-bit float type"); + Type *Ty = ArrayType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast<const char *>(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 4), Ty); } -Constant *ConstantDataArray::getFP(LLVMContext &Context, - ArrayRef<uint64_t> Elts) { - Type *Ty = ArrayType::get(Type::getDoubleTy(Context), Elts.size()); +Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint64_t> Elts) { + assert(ElementType->isDoubleTy() && + "Element type is not a 64-bit float type"); + Type *Ty = ArrayType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast<const char *>(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 8), Ty); } @@ -2751,26 +2772,32 @@ Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<double> Elts) { return getImpl(StringRef(Data, Elts.size() * 8), Ty); } -/// getFP() constructors - Return a constant with vector type with an element -/// count and element type of float with the precision matching the number of -/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits, -/// double for 64bits) Note that this can return a ConstantAggregateZero -/// object. -Constant *ConstantDataVector::getFP(LLVMContext &Context, +/// getFP() constructors - Return a constant of vector type with a float +/// element type taken from argument `ElementType', and count taken from +/// argument `Elts'. The amount of bits of the contained type must match the +/// number of bits of the type contained in the passed in ArrayRef. +/// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note +/// that this can return a ConstantAggregateZero object. +Constant *ConstantDataVector::getFP(Type *ElementType, ArrayRef<uint16_t> Elts) { - Type *Ty = VectorType::get(Type::getHalfTy(Context), Elts.size()); + assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) && + "Element type is not a 16-bit float type"); + Type *Ty = VectorType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast<const char *>(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 2), Ty); } -Constant *ConstantDataVector::getFP(LLVMContext &Context, +Constant *ConstantDataVector::getFP(Type *ElementType, ArrayRef<uint32_t> Elts) { - Type *Ty = VectorType::get(Type::getFloatTy(Context), Elts.size()); + assert(ElementType->isFloatTy() && "Element type is not a 32-bit float type"); + Type *Ty = VectorType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast<const char *>(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 4), Ty); } -Constant *ConstantDataVector::getFP(LLVMContext &Context, +Constant *ConstantDataVector::getFP(Type *ElementType, ArrayRef<uint64_t> Elts) { - Type *Ty = VectorType::get(Type::getDoubleTy(Context), Elts.size()); + assert(ElementType->isDoubleTy() && + "Element type is not a 64-bit float type"); + Type *Ty = VectorType::get(ElementType, Elts.size()); const char *Data = reinterpret_cast<const char *>(Elts.data()); return getImpl(StringRef(Data, Elts.size() * 8), Ty); } @@ -2800,17 +2827,22 @@ Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) { if (CFP->getType()->isHalfTy()) { SmallVector<uint16_t, 16> Elts( NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue()); - return getFP(V->getContext(), Elts); + return getFP(V->getType(), Elts); + } + if (CFP->getType()->isBFloatTy()) { + SmallVector<uint16_t, 16> Elts( + NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue()); + return getFP(V->getType(), Elts); } if (CFP->getType()->isFloatTy()) { SmallVector<uint32_t, 16> Elts( NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue()); - return getFP(V->getContext(), Elts); + return getFP(V->getType(), Elts); } if (CFP->getType()->isDoubleTy()) { SmallVector<uint64_t, 16> Elts( NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue()); - return getFP(V->getContext(), Elts); + return getFP(V->getType(), Elts); } } return ConstantVector::getSplat({NumElts, false}, V); @@ -2875,6 +2907,10 @@ APFloat ConstantDataSequential::getElementAsAPFloat(unsigned Elt) const { auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr); return APFloat(APFloat::IEEEhalf(), APInt(16, EltVal)); } + case Type::BFloatTyID: { + auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr); + return APFloat(APFloat::BFloat(), APInt(16, EltVal)); + } case Type::FloatTyID: { auto EltVal = *reinterpret_cast<const uint32_t *>(EltPtr); return APFloat(APFloat::IEEEsingle(), APInt(32, EltVal)); @@ -2899,8 +2935,8 @@ double ConstantDataSequential::getElementAsDouble(unsigned Elt) const { } Constant *ConstantDataSequential::getElementAsConstant(unsigned Elt) const { - if (getElementType()->isHalfTy() || getElementType()->isFloatTy() || - getElementType()->isDoubleTy()) + if (getElementType()->isHalfTy() || getElementType()->isBFloatTy() || + getElementType()->isFloatTy() || getElementType()->isDoubleTy()) return ConstantFP::get(getContext(), getElementAsAPFloat(Elt)); return ConstantInt::get(getElementType(), getElementAsInteger(Elt)); |