diff options
author | Thomas Raoux <thomasraoux@google.com> | 2020-03-30 10:36:21 -0700 |
---|---|---|
committer | Thomas Raoux <thomasraoux@google.com> | 2020-03-30 11:27:09 -0700 |
commit | 3ea0774b13a538759aa1a68f30130d18ddb0d3f2 (patch) | |
tree | f19bd3ecf25f6fd958a60bc909e185b98b4eea2f /llvm/lib/IR/ConstantFold.cpp | |
parent | 8242509a49e019c0279305d152c7ab2b9cdc2d0d (diff) | |
download | llvm-3ea0774b13a538759aa1a68f30130d18ddb0d3f2.zip llvm-3ea0774b13a538759aa1a68f30130d18ddb0d3f2.tar.gz llvm-3ea0774b13a538759aa1a68f30130d18ddb0d3f2.tar.bz2 |
[ConstantFold][NFC] Compile time optimization for large vectors
Optimize the common case of splat vector constant. For large vector
going through all elements is expensive. For splatr/broadcast cases we
can skip going through all elements.
Differential Revision: https://reviews.llvm.org/D76664
Diffstat (limited to 'llvm/lib/IR/ConstantFold.cpp')
-rw-r--r-- | llvm/lib/IR/ConstantFold.cpp | 44 |
1 files changed, 42 insertions, 2 deletions
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp index 3e2e74c..07292e5 100644 --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -60,6 +60,11 @@ static Constant *BitCastConstantVector(Constant *CV, VectorType *DstTy) { return nullptr; Type *DstEltTy = DstTy->getElementType(); + // Fast path for splatted constants. + if (Constant *Splat = CV->getSplatValue()) { + return ConstantVector::getSplat(DstTy->getVectorElementCount(), + ConstantExpr::getBitCast(Splat, DstEltTy)); + } SmallVector<Constant*, 16> Result; Type *Ty = IntegerType::get(CV->getContext(), 32); @@ -577,9 +582,15 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) && DestTy->isVectorTy() && DestTy->getVectorNumElements() == V->getType()->getVectorNumElements()) { - SmallVector<Constant*, 16> res; VectorType *DestVecTy = cast<VectorType>(DestTy); Type *DstEltTy = DestVecTy->getElementType(); + // Fast path for splatted constants. + if (Constant *Splat = V->getSplatValue()) { + return ConstantVector::getSplat( + DestTy->getVectorElementCount(), + ConstantExpr::getCast(opc, Splat, DstEltTy)); + } + SmallVector<Constant *, 16> res; Type *Ty = IntegerType::get(V->getContext(), 32); for (unsigned i = 0, e = V->getType()->getVectorNumElements(); i != e; ++i) { Constant *C = @@ -878,6 +889,14 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1, // Don't break the bitcode reader hack. if (isa<ConstantExpr>(Mask)) return nullptr; + // If the mask is all zeros this is a splat, no need to go through all + // elements. + if (isa<ConstantAggregateZero>(Mask) && !MaskEltCount.Scalable) { + Type *Ty = IntegerType::get(V1->getContext(), 32); + Constant *Elt = + ConstantExpr::getExtractElement(V1, ConstantInt::get(Ty, 0)); + return ConstantVector::getSplat(MaskEltCount, Elt); + } // Do not iterate on scalable vector. The num of elements is unknown at // compile-time. VectorType *ValTy = cast<VectorType>(V1->getType()); @@ -993,10 +1012,15 @@ Constant *llvm::ConstantFoldUnaryInstruction(unsigned Opcode, Constant *C) { // compile-time. if (IsScalableVector) return nullptr; + Type *Ty = IntegerType::get(VTy->getContext(), 32); + // Fast path for splatted constants. + if (Constant *Splat = C->getSplatValue()) { + Constant *Elt = ConstantExpr::get(Opcode, Splat); + return ConstantVector::getSplat(VTy->getElementCount(), Elt); + } // Fold each element and create a vector constant from those constants. SmallVector<Constant*, 16> Result; - Type *Ty = IntegerType::get(VTy->getContext(), 32); for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { Constant *ExtractIdx = ConstantInt::get(Ty, i); Constant *Elt = ConstantExpr::getExtractElement(C, ExtractIdx); @@ -1357,6 +1381,16 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1, // compile-time. if (IsScalableVector) return nullptr; + // Fast path for splatted constants. + if (Constant *C2Splat = C2->getSplatValue()) { + if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue()) + return UndefValue::get(VTy); + if (Constant *C1Splat = C1->getSplatValue()) { + return ConstantVector::getSplat( + VTy->getVectorElementCount(), + ConstantExpr::get(Opcode, C1Splat, C2Splat)); + } + } // Fold each element and create a vector constant from those constants. SmallVector<Constant*, 16> Result; @@ -1975,6 +2009,12 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, // compile-time. if (C1->getType()->getVectorIsScalable()) return nullptr; + // Fast path for splatted constants. + if (Constant *C1Splat = C1->getSplatValue()) + if (Constant *C2Splat = C2->getSplatValue()) + return ConstantVector::getSplat( + C1->getType()->getVectorElementCount(), + ConstantExpr::getCompare(pred, C1Splat, C2Splat)); // If we can constant fold the comparison of each element, constant fold // the whole vector comparison. |