diff options
Diffstat (limited to 'llvm/lib/IR/ConstantFold.cpp')
-rw-r--r-- | llvm/lib/IR/ConstantFold.cpp | 23 |
1 files changed, 11 insertions, 12 deletions
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp index 07292e5..c4465d9 100644 --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -876,22 +876,22 @@ Constant *llvm::ConstantFoldInsertElementInstruction(Constant *Val, return ConstantVector::get(Result); } -Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1, - Constant *V2, - Constant *Mask) { - ElementCount MaskEltCount = Mask->getType()->getVectorElementCount(); +Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1, Constant *V2, + ArrayRef<int> Mask) { + unsigned MaskNumElts = Mask.size(); + ElementCount MaskEltCount = {MaskNumElts, + V1->getType()->getVectorIsScalable()}; Type *EltTy = V1->getType()->getVectorElementType(); // Undefined shuffle mask -> undefined value. - if (isa<UndefValue>(Mask)) - return UndefValue::get(VectorType::get(EltTy, MaskEltCount)); - - // Don't break the bitcode reader hack. - if (isa<ConstantExpr>(Mask)) return nullptr; + if (all_of(Mask, [](int Elt) { return Elt == UndefMaskElem; })) { + return UndefValue::get(VectorType::get(EltTy, MaskNumElts)); + } // If the mask is all zeros this is a splat, no need to go through all // elements. - if (isa<ConstantAggregateZero>(Mask) && !MaskEltCount.Scalable) { + if (all_of(Mask, [](int Elt) { return Elt == 0; }) && + !MaskEltCount.Scalable) { Type *Ty = IntegerType::get(V1->getContext(), 32); Constant *Elt = ConstantExpr::getExtractElement(V1, ConstantInt::get(Ty, 0)); @@ -903,13 +903,12 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1, if (ValTy->isScalable()) return nullptr; - unsigned MaskNumElts = MaskEltCount.Min; unsigned SrcNumElts = V1->getType()->getVectorNumElements(); // Loop over the shuffle mask, evaluating each element. SmallVector<Constant*, 32> Result; for (unsigned i = 0; i != MaskNumElts; ++i) { - int Elt = ShuffleVectorInst::getMaskValue(Mask, i); + int Elt = Mask[i]; if (Elt == -1) { Result.push_back(UndefValue::get(EltTy)); continue; |