diff options
author | Huihui Zhang <huihuiz@quicinc.com> | 2020-03-12 13:15:34 -0700 |
---|---|---|
committer | Huihui Zhang <huihuiz@quicinc.com> | 2020-03-12 13:22:41 -0700 |
commit | 118abf20173899e9e1667db1a9c850dc5570b6ae (patch) | |
tree | dba40cf35b91a20a47755929161a1021642889e4 /llvm/lib/IR/Constants.cpp | |
parent | e91feeed21ee16abdb73f6e8cd471a253136e2cf (diff) | |
download | llvm-118abf20173899e9e1667db1a9c850dc5570b6ae.zip llvm-118abf20173899e9e1667db1a9c850dc5570b6ae.tar.gz llvm-118abf20173899e9e1667db1a9c850dc5570b6ae.tar.bz2 |
[SVE] Update API ConstantVector::getSplat() to use ElementCount.
Summary:
Support ConstantInt::get() and Constant::getAllOnesValue() for scalable
vector type, this requires ConstantVector::getSplat() to take in 'ElementCount',
instead of 'unsigned' number of element count.
This change is needed for D73753.
Reviewers: sdesmalen, efriedma, apazos, spatel, huntergr, willlovett
Reviewed By: efriedma
Subscribers: tschuett, hiraditya, rkruppe, psnobl, cfe-commits, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D74386
Diffstat (limited to 'llvm/lib/IR/Constants.cpp')
-rw-r--r-- | llvm/lib/IR/Constants.cpp | 86 |
1 files changed, 53 insertions, 33 deletions
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index 399bd41..eb0e589 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -370,7 +370,7 @@ Constant *Constant::getIntegerValue(Type *Ty, const APInt &V) { // Broadcast a scalar to a vector, if necessary. if (VectorType *VTy = dyn_cast<VectorType>(Ty)) - C = ConstantVector::getSplat(VTy->getNumElements(), C); + C = ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -387,7 +387,7 @@ Constant *Constant::getAllOnesValue(Type *Ty) { } VectorType *VTy = cast<VectorType>(Ty); - return ConstantVector::getSplat(VTy->getNumElements(), + return ConstantVector::getSplat(VTy->getElementCount(), getAllOnesValue(VTy->getElementType())); } @@ -681,7 +681,7 @@ Constant *ConstantInt::getTrue(Type *Ty) { assert(Ty->isIntOrIntVectorTy(1) && "Type not i1 or vector of i1."); ConstantInt *TrueC = ConstantInt::getTrue(Ty->getContext()); if (auto *VTy = dyn_cast<VectorType>(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), TrueC); + return ConstantVector::getSplat(VTy->getElementCount(), TrueC); return TrueC; } @@ -689,7 +689,7 @@ Constant *ConstantInt::getFalse(Type *Ty) { assert(Ty->isIntOrIntVectorTy(1) && "Type not i1 or vector of i1."); ConstantInt *FalseC = ConstantInt::getFalse(Ty->getContext()); if (auto *VTy = dyn_cast<VectorType>(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), FalseC); + return ConstantVector::getSplat(VTy->getElementCount(), FalseC); return FalseC; } @@ -712,7 +712,7 @@ Constant *ConstantInt::get(Type *Ty, uint64_t V, bool isSigned) { // For vectors, broadcast the value. if (VectorType *VTy = dyn_cast<VectorType>(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -736,7 +736,7 @@ Constant *ConstantInt::get(Type *Ty, const APInt& V) { // For vectors, broadcast the value. if (VectorType *VTy = dyn_cast<VectorType>(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -781,7 +781,7 @@ Constant *ConstantFP::get(Type *Ty, double V) { // For vectors, broadcast the value. if (VectorType *VTy = dyn_cast<VectorType>(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -793,7 +793,7 @@ Constant *ConstantFP::get(Type *Ty, const APFloat &V) { // For vectors, broadcast the value. if (auto *VTy = dyn_cast<VectorType>(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -806,7 +806,7 @@ Constant *ConstantFP::get(Type *Ty, StringRef Str) { // For vectors, broadcast the value. if (VectorType *VTy = dyn_cast<VectorType>(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -817,7 +817,7 @@ Constant *ConstantFP::getNaN(Type *Ty, bool Negative, uint64_t Payload) { Constant *C = get(Ty->getContext(), NaN); if (VectorType *VTy = dyn_cast<VectorType>(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -828,7 +828,7 @@ Constant *ConstantFP::getQNaN(Type *Ty, bool Negative, APInt *Payload) { Constant *C = get(Ty->getContext(), NaN); if (VectorType *VTy = dyn_cast<VectorType>(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -839,7 +839,7 @@ Constant *ConstantFP::getSNaN(Type *Ty, bool Negative, APInt *Payload) { Constant *C = get(Ty->getContext(), NaN); if (VectorType *VTy = dyn_cast<VectorType>(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -850,7 +850,7 @@ Constant *ConstantFP::getNegativeZero(Type *Ty) { Constant *C = get(Ty->getContext(), NegZero); if (VectorType *VTy = dyn_cast<VectorType>(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -898,7 +898,7 @@ Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) { Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative)); if (VectorType *VTy = dyn_cast<VectorType>(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -1204,15 +1204,35 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) { return nullptr; } -Constant *ConstantVector::getSplat(unsigned NumElts, Constant *V) { - // If this splat is compatible with ConstantDataVector, use it instead of - // ConstantVector. - if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) && - ConstantDataSequential::isElementTypeCompatible(V->getType())) - return ConstantDataVector::getSplat(NumElts, V); +Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) { + if (!EC.Scalable) { + // If this splat is compatible with ConstantDataVector, use it instead of + // ConstantVector. + if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) && + ConstantDataSequential::isElementTypeCompatible(V->getType())) + return ConstantDataVector::getSplat(EC.Min, V); - SmallVector<Constant*, 32> Elts(NumElts, V); - return get(Elts); + SmallVector<Constant *, 32> Elts(EC.Min, V); + return get(Elts); + } + + Type *VTy = VectorType::get(V->getType(), EC); + + if (V->isNullValue()) + return ConstantAggregateZero::get(VTy); + else if (isa<UndefValue>(V)) + return UndefValue::get(VTy); + + Type *I32Ty = Type::getInt32Ty(VTy->getContext()); + + // Move scalar into vector. + Constant *UndefV = UndefValue::get(VTy); + V = ConstantExpr::getInsertElement(UndefV, V, ConstantInt::get(I32Ty, 0)); + // Build shuffle mask to perform the splat. + Type *MaskTy = VectorType::get(I32Ty, EC); + Constant *Zeros = ConstantAggregateZero::get(MaskTy); + // Splat. + return ConstantExpr::getShuffleVector(V, UndefV, Zeros); } ConstantTokenNone *ConstantTokenNone::get(LLVMContext &Context) { @@ -2098,15 +2118,15 @@ Constant *ConstantExpr::getGetElementPtr(Type *Ty, Constant *C, unsigned AS = C->getType()->getPointerAddressSpace(); Type *ReqTy = DestTy->getPointerTo(AS); - unsigned NumVecElts = 0; - if (C->getType()->isVectorTy()) - NumVecElts = C->getType()->getVectorNumElements(); + ElementCount EltCount = {0, false}; + if (VectorType *VecTy = dyn_cast<VectorType>(C->getType())) + EltCount = VecTy->getElementCount(); else for (auto Idx : Idxs) - if (Idx->getType()->isVectorTy()) - NumVecElts = Idx->getType()->getVectorNumElements(); + if (VectorType *VecTy = dyn_cast<VectorType>(Idx->getType())) + EltCount = VecTy->getElementCount(); - if (NumVecElts) - ReqTy = VectorType::get(ReqTy, NumVecElts); + if (EltCount.Min != 0) + ReqTy = VectorType::get(ReqTy, EltCount); if (OnlyIfReducedTy == ReqTy) return nullptr; @@ -2117,12 +2137,12 @@ Constant *ConstantExpr::getGetElementPtr(Type *Ty, Constant *C, ArgVec.push_back(C); for (unsigned i = 0, e = Idxs.size(); i != e; ++i) { assert((!Idxs[i]->getType()->isVectorTy() || - Idxs[i]->getType()->getVectorNumElements() == NumVecElts) && + Idxs[i]->getType()->getVectorElementCount() == EltCount) && "getelementptr index type missmatch"); Constant *Idx = cast<Constant>(Idxs[i]); - if (NumVecElts && !Idxs[i]->getType()->isVectorTy()) - Idx = ConstantVector::getSplat(NumVecElts, Idx); + if (EltCount.Min != 0 && !Idxs[i]->getType()->isVectorTy()) + Idx = ConstantVector::getSplat(EltCount, Idx); ArgVec.push_back(Idx); } @@ -2759,7 +2779,7 @@ Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) { return getFP(V->getContext(), Elts); } } - return ConstantVector::getSplat(NumElts, V); + return ConstantVector::getSplat({NumElts, false}, V); } |