aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/IR/ConstantFold.cpp
diff options
context:
space:
mode:
authorThomas Raoux <thomasraoux@google.com>2020-03-30 10:36:21 -0700
committerThomas Raoux <thomasraoux@google.com>2020-03-30 11:27:09 -0700
commit3ea0774b13a538759aa1a68f30130d18ddb0d3f2 (patch)
treef19bd3ecf25f6fd958a60bc909e185b98b4eea2f /llvm/lib/IR/ConstantFold.cpp
parent8242509a49e019c0279305d152c7ab2b9cdc2d0d (diff)
downloadllvm-3ea0774b13a538759aa1a68f30130d18ddb0d3f2.zip
llvm-3ea0774b13a538759aa1a68f30130d18ddb0d3f2.tar.gz
llvm-3ea0774b13a538759aa1a68f30130d18ddb0d3f2.tar.bz2
[ConstantFold][NFC] Compile time optimization for large vectors
Optimize the common case of splat vector constant. For large vector going through all elements is expensive. For splatr/broadcast cases we can skip going through all elements. Differential Revision: https://reviews.llvm.org/D76664
Diffstat (limited to 'llvm/lib/IR/ConstantFold.cpp')
-rw-r--r--llvm/lib/IR/ConstantFold.cpp44
1 files changed, 42 insertions, 2 deletions
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index 3e2e74c..07292e5 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -60,6 +60,11 @@ static Constant *BitCastConstantVector(Constant *CV, VectorType *DstTy) {
return nullptr;
Type *DstEltTy = DstTy->getElementType();
+ // Fast path for splatted constants.
+ if (Constant *Splat = CV->getSplatValue()) {
+ return ConstantVector::getSplat(DstTy->getVectorElementCount(),
+ ConstantExpr::getBitCast(Splat, DstEltTy));
+ }
SmallVector<Constant*, 16> Result;
Type *Ty = IntegerType::get(CV->getContext(), 32);
@@ -577,9 +582,15 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
DestTy->isVectorTy() &&
DestTy->getVectorNumElements() == V->getType()->getVectorNumElements()) {
- SmallVector<Constant*, 16> res;
VectorType *DestVecTy = cast<VectorType>(DestTy);
Type *DstEltTy = DestVecTy->getElementType();
+ // Fast path for splatted constants.
+ if (Constant *Splat = V->getSplatValue()) {
+ return ConstantVector::getSplat(
+ DestTy->getVectorElementCount(),
+ ConstantExpr::getCast(opc, Splat, DstEltTy));
+ }
+ SmallVector<Constant *, 16> res;
Type *Ty = IntegerType::get(V->getContext(), 32);
for (unsigned i = 0, e = V->getType()->getVectorNumElements(); i != e; ++i) {
Constant *C =
@@ -878,6 +889,14 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1,
// Don't break the bitcode reader hack.
if (isa<ConstantExpr>(Mask)) return nullptr;
+ // If the mask is all zeros this is a splat, no need to go through all
+ // elements.
+ if (isa<ConstantAggregateZero>(Mask) && !MaskEltCount.Scalable) {
+ Type *Ty = IntegerType::get(V1->getContext(), 32);
+ Constant *Elt =
+ ConstantExpr::getExtractElement(V1, ConstantInt::get(Ty, 0));
+ return ConstantVector::getSplat(MaskEltCount, Elt);
+ }
// Do not iterate on scalable vector. The num of elements is unknown at
// compile-time.
VectorType *ValTy = cast<VectorType>(V1->getType());
@@ -993,10 +1012,15 @@ Constant *llvm::ConstantFoldUnaryInstruction(unsigned Opcode, Constant *C) {
// compile-time.
if (IsScalableVector)
return nullptr;
+ Type *Ty = IntegerType::get(VTy->getContext(), 32);
+ // Fast path for splatted constants.
+ if (Constant *Splat = C->getSplatValue()) {
+ Constant *Elt = ConstantExpr::get(Opcode, Splat);
+ return ConstantVector::getSplat(VTy->getElementCount(), Elt);
+ }
// Fold each element and create a vector constant from those constants.
SmallVector<Constant*, 16> Result;
- Type *Ty = IntegerType::get(VTy->getContext(), 32);
for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
Constant *ExtractIdx = ConstantInt::get(Ty, i);
Constant *Elt = ConstantExpr::getExtractElement(C, ExtractIdx);
@@ -1357,6 +1381,16 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
// compile-time.
if (IsScalableVector)
return nullptr;
+ // Fast path for splatted constants.
+ if (Constant *C2Splat = C2->getSplatValue()) {
+ if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue())
+ return UndefValue::get(VTy);
+ if (Constant *C1Splat = C1->getSplatValue()) {
+ return ConstantVector::getSplat(
+ VTy->getVectorElementCount(),
+ ConstantExpr::get(Opcode, C1Splat, C2Splat));
+ }
+ }
// Fold each element and create a vector constant from those constants.
SmallVector<Constant*, 16> Result;
@@ -1975,6 +2009,12 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
// compile-time.
if (C1->getType()->getVectorIsScalable())
return nullptr;
+ // Fast path for splatted constants.
+ if (Constant *C1Splat = C1->getSplatValue())
+ if (Constant *C2Splat = C2->getSplatValue())
+ return ConstantVector::getSplat(
+ C1->getType()->getVectorElementCount(),
+ ConstantExpr::getCompare(pred, C1Splat, C2Splat));
// If we can constant fold the comparison of each element, constant fold
// the whole vector comparison.