aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/IR/Constants.cpp
diff options
context:
space:
mode:
authorHuihui Zhang <huihuiz@quicinc.com>2020-03-12 13:15:34 -0700
committerHuihui Zhang <huihuiz@quicinc.com>2020-03-12 13:22:41 -0700
commit118abf20173899e9e1667db1a9c850dc5570b6ae (patch)
treedba40cf35b91a20a47755929161a1021642889e4 /llvm/lib/IR/Constants.cpp
parente91feeed21ee16abdb73f6e8cd471a253136e2cf (diff)
downloadllvm-118abf20173899e9e1667db1a9c850dc5570b6ae.zip
llvm-118abf20173899e9e1667db1a9c850dc5570b6ae.tar.gz
llvm-118abf20173899e9e1667db1a9c850dc5570b6ae.tar.bz2
[SVE] Update API ConstantVector::getSplat() to use ElementCount.
Summary: Support ConstantInt::get() and Constant::getAllOnesValue() for scalable vector type, this requires ConstantVector::getSplat() to take in 'ElementCount', instead of 'unsigned' number of element count. This change is needed for D73753. Reviewers: sdesmalen, efriedma, apazos, spatel, huntergr, willlovett Reviewed By: efriedma Subscribers: tschuett, hiraditya, rkruppe, psnobl, cfe-commits, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D74386
Diffstat (limited to 'llvm/lib/IR/Constants.cpp')
-rw-r--r--llvm/lib/IR/Constants.cpp86
1 files changed, 53 insertions, 33 deletions
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 399bd41..eb0e589 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -370,7 +370,7 @@ Constant *Constant::getIntegerValue(Type *Ty, const APInt &V) {
// Broadcast a scalar to a vector, if necessary.
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- C = ConstantVector::getSplat(VTy->getNumElements(), C);
+ C = ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -387,7 +387,7 @@ Constant *Constant::getAllOnesValue(Type *Ty) {
}
VectorType *VTy = cast<VectorType>(Ty);
- return ConstantVector::getSplat(VTy->getNumElements(),
+ return ConstantVector::getSplat(VTy->getElementCount(),
getAllOnesValue(VTy->getElementType()));
}
@@ -681,7 +681,7 @@ Constant *ConstantInt::getTrue(Type *Ty) {
assert(Ty->isIntOrIntVectorTy(1) && "Type not i1 or vector of i1.");
ConstantInt *TrueC = ConstantInt::getTrue(Ty->getContext());
if (auto *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), TrueC);
+ return ConstantVector::getSplat(VTy->getElementCount(), TrueC);
return TrueC;
}
@@ -689,7 +689,7 @@ Constant *ConstantInt::getFalse(Type *Ty) {
assert(Ty->isIntOrIntVectorTy(1) && "Type not i1 or vector of i1.");
ConstantInt *FalseC = ConstantInt::getFalse(Ty->getContext());
if (auto *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), FalseC);
+ return ConstantVector::getSplat(VTy->getElementCount(), FalseC);
return FalseC;
}
@@ -712,7 +712,7 @@ Constant *ConstantInt::get(Type *Ty, uint64_t V, bool isSigned) {
// For vectors, broadcast the value.
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -736,7 +736,7 @@ Constant *ConstantInt::get(Type *Ty, const APInt& V) {
// For vectors, broadcast the value.
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -781,7 +781,7 @@ Constant *ConstantFP::get(Type *Ty, double V) {
// For vectors, broadcast the value.
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -793,7 +793,7 @@ Constant *ConstantFP::get(Type *Ty, const APFloat &V) {
// For vectors, broadcast the value.
if (auto *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -806,7 +806,7 @@ Constant *ConstantFP::get(Type *Ty, StringRef Str) {
// For vectors, broadcast the value.
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -817,7 +817,7 @@ Constant *ConstantFP::getNaN(Type *Ty, bool Negative, uint64_t Payload) {
Constant *C = get(Ty->getContext(), NaN);
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -828,7 +828,7 @@ Constant *ConstantFP::getQNaN(Type *Ty, bool Negative, APInt *Payload) {
Constant *C = get(Ty->getContext(), NaN);
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -839,7 +839,7 @@ Constant *ConstantFP::getSNaN(Type *Ty, bool Negative, APInt *Payload) {
Constant *C = get(Ty->getContext(), NaN);
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -850,7 +850,7 @@ Constant *ConstantFP::getNegativeZero(Type *Ty) {
Constant *C = get(Ty->getContext(), NegZero);
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -898,7 +898,7 @@ Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative));
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -1204,15 +1204,35 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
return nullptr;
}
-Constant *ConstantVector::getSplat(unsigned NumElts, Constant *V) {
- // If this splat is compatible with ConstantDataVector, use it instead of
- // ConstantVector.
- if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) &&
- ConstantDataSequential::isElementTypeCompatible(V->getType()))
- return ConstantDataVector::getSplat(NumElts, V);
+Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
+ if (!EC.Scalable) {
+ // If this splat is compatible with ConstantDataVector, use it instead of
+ // ConstantVector.
+ if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) &&
+ ConstantDataSequential::isElementTypeCompatible(V->getType()))
+ return ConstantDataVector::getSplat(EC.Min, V);
- SmallVector<Constant*, 32> Elts(NumElts, V);
- return get(Elts);
+ SmallVector<Constant *, 32> Elts(EC.Min, V);
+ return get(Elts);
+ }
+
+ Type *VTy = VectorType::get(V->getType(), EC);
+
+ if (V->isNullValue())
+ return ConstantAggregateZero::get(VTy);
+ else if (isa<UndefValue>(V))
+ return UndefValue::get(VTy);
+
+ Type *I32Ty = Type::getInt32Ty(VTy->getContext());
+
+ // Move scalar into vector.
+ Constant *UndefV = UndefValue::get(VTy);
+ V = ConstantExpr::getInsertElement(UndefV, V, ConstantInt::get(I32Ty, 0));
+ // Build shuffle mask to perform the splat.
+ Type *MaskTy = VectorType::get(I32Ty, EC);
+ Constant *Zeros = ConstantAggregateZero::get(MaskTy);
+ // Splat.
+ return ConstantExpr::getShuffleVector(V, UndefV, Zeros);
}
ConstantTokenNone *ConstantTokenNone::get(LLVMContext &Context) {
@@ -2098,15 +2118,15 @@ Constant *ConstantExpr::getGetElementPtr(Type *Ty, Constant *C,
unsigned AS = C->getType()->getPointerAddressSpace();
Type *ReqTy = DestTy->getPointerTo(AS);
- unsigned NumVecElts = 0;
- if (C->getType()->isVectorTy())
- NumVecElts = C->getType()->getVectorNumElements();
+ ElementCount EltCount = {0, false};
+ if (VectorType *VecTy = dyn_cast<VectorType>(C->getType()))
+ EltCount = VecTy->getElementCount();
else for (auto Idx : Idxs)
- if (Idx->getType()->isVectorTy())
- NumVecElts = Idx->getType()->getVectorNumElements();
+ if (VectorType *VecTy = dyn_cast<VectorType>(Idx->getType()))
+ EltCount = VecTy->getElementCount();
- if (NumVecElts)
- ReqTy = VectorType::get(ReqTy, NumVecElts);
+ if (EltCount.Min != 0)
+ ReqTy = VectorType::get(ReqTy, EltCount);
if (OnlyIfReducedTy == ReqTy)
return nullptr;
@@ -2117,12 +2137,12 @@ Constant *ConstantExpr::getGetElementPtr(Type *Ty, Constant *C,
ArgVec.push_back(C);
for (unsigned i = 0, e = Idxs.size(); i != e; ++i) {
assert((!Idxs[i]->getType()->isVectorTy() ||
- Idxs[i]->getType()->getVectorNumElements() == NumVecElts) &&
+ Idxs[i]->getType()->getVectorElementCount() == EltCount) &&
"getelementptr index type missmatch");
Constant *Idx = cast<Constant>(Idxs[i]);
- if (NumVecElts && !Idxs[i]->getType()->isVectorTy())
- Idx = ConstantVector::getSplat(NumVecElts, Idx);
+ if (EltCount.Min != 0 && !Idxs[i]->getType()->isVectorTy())
+ Idx = ConstantVector::getSplat(EltCount, Idx);
ArgVec.push_back(Idx);
}
@@ -2759,7 +2779,7 @@ Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) {
return getFP(V->getContext(), Elts);
}
}
- return ConstantVector::getSplat(NumElts, V);
+ return ConstantVector::getSplat({NumElts, false}, V);
}