aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Analysis/InstructionSimplify.cpp48
-rw-r--r--llvm/lib/Analysis/VectorUtils.cpp52
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp67
3 files changed, 109 insertions, 58 deletions
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index ad85a3e..fa42b48 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -24,6 +24,7 @@
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
@@ -3555,6 +3556,47 @@ Value *llvm::SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs,
RecursionLimit);
}
+/// SimplifyExtractElementInst - Given operands for an ExtractElementInst, see if we
+/// can fold the result. If not, this returns null.
+static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const Query &,
+ unsigned) {
+ if (auto *CVec = dyn_cast<Constant>(Vec)) {
+ if (auto *CIdx = dyn_cast<Constant>(Idx))
+ return ConstantFoldExtractElementInstruction(CVec, CIdx);
+
+ // The index is not relevant if our vector is a splat.
+ if (auto *Splat = CVec->getSplatValue())
+ return Splat;
+
+ if (isa<UndefValue>(Vec))
+ return UndefValue::get(Vec->getType()->getVectorElementType());
+ }
+
+ // If extracting a specified index from the vector, see if we can recursively
+ // find a previously computed scalar that was inserted into the vector.
+ if (auto *IdxC = dyn_cast<ConstantInt>(Idx)) {
+ unsigned IndexVal = IdxC->getZExtValue();
+ unsigned VectorWidth = Vec->getType()->getVectorNumElements();
+
+ // If this is extracting an invalid index, turn this into undef, to avoid
+ // crashing the code below.
+ if (IndexVal >= VectorWidth)
+ return UndefValue::get(Vec->getType()->getVectorElementType());
+
+ if (Value *Elt = findScalarElement(Vec, IndexVal))
+ return Elt;
+ }
+
+ return nullptr;
+}
+
+Value *llvm::SimplifyExtractElementInst(
+ Value *Vec, Value *Idx, const DataLayout &DL, const TargetLibraryInfo *TLI,
+ const DominatorTree *DT, AssumptionCache *AC, const Instruction *CxtI) {
+ return ::SimplifyExtractElementInst(Vec, Idx, Query(DL, TLI, DT, AC, CxtI),
+ RecursionLimit);
+}
+
/// SimplifyPHINode - See if we can fold the given phi. If not, returns null.
static Value *SimplifyPHINode(PHINode *PN, const Query &Q) {
// If all of the PHI's incoming values are the same then replace the PHI node
@@ -3970,6 +4012,12 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL,
EVI->getIndices(), DL, TLI, DT, AC, I);
break;
}
+ case Instruction::ExtractElement: {
+ auto *EEI = cast<ExtractElementInst>(I);
+ Result = SimplifyExtractElementInst(
+ EEI->getVectorOperand(), EEI->getIndexOperand(), DL, TLI, DT, AC, I);
+ break;
+ }
case Instruction::PHI:
Result = SimplifyPHINode(cast<PHINode>(I), Query(DL, TLI, DT, AC, I));
break;
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index eab5887..67f68dc 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -357,3 +357,55 @@ llvm::Value *llvm::getStrideFromPointer(llvm::Value *Ptr, ScalarEvolution *SE,
return Stride;
}
+
+/// \brief Given a vector and an element number, see if the scalar value is
+/// already around as a register, for example if it were inserted then extracted
+/// from the vector.
+llvm::Value *llvm::findScalarElement(llvm::Value *V, unsigned EltNo) {
+ assert(V->getType()->isVectorTy() && "Not looking at a vector?");
+ VectorType *VTy = cast<VectorType>(V->getType());
+ unsigned Width = VTy->getNumElements();
+ if (EltNo >= Width) // Out of range access.
+ return UndefValue::get(VTy->getElementType());
+
+ if (Constant *C = dyn_cast<Constant>(V))
+ return C->getAggregateElement(EltNo);
+
+ if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) {
+ // If this is an insert to a variable element, we don't know what it is.
+ if (!isa<ConstantInt>(III->getOperand(2)))
+ return nullptr;
+ unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue();
+
+ // If this is an insert to the element we are looking for, return the
+ // inserted value.
+ if (EltNo == IIElt)
+ return III->getOperand(1);
+
+ // Otherwise, the insertelement doesn't modify the value, recurse on its
+ // vector input.
+ return findScalarElement(III->getOperand(0), EltNo);
+ }
+
+ if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) {
+ unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements();
+ int InEl = SVI->getMaskValue(EltNo);
+ if (InEl < 0)
+ return UndefValue::get(VTy->getElementType());
+ if (InEl < (int)LHSWidth)
+ return findScalarElement(SVI->getOperand(0), InEl);
+ return findScalarElement(SVI->getOperand(1), InEl - LHSWidth);
+ }
+
+ // Extract a value from a vector add operation with a constant zero.
+ Value *Val = nullptr; Constant *Con = nullptr;
+ if (match(V,
+ llvm::PatternMatch::m_Add(llvm::PatternMatch::m_Value(Val),
+ llvm::PatternMatch::m_Constant(Con)))) {
+ if (Con->getAggregateElement(EltNo)->isNullValue())
+ return findScalarElement(Val, EltNo);
+ }
+
+ // Otherwise, we don't know.
+ return nullptr;
+}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 24446c8..2730472 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -14,6 +14,8 @@
#include "InstCombineInternal.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/PatternMatch.h"
using namespace llvm;
using namespace PatternMatch;
@@ -60,56 +62,6 @@ static bool CheapToScalarize(Value *V, bool isConstant) {
return false;
}
-/// FindScalarElement - Given a vector and an element number, see if the scalar
-/// value is already around as a register, for example if it were inserted then
-/// extracted from the vector.
-static Value *FindScalarElement(Value *V, unsigned EltNo) {
- assert(V->getType()->isVectorTy() && "Not looking at a vector?");
- VectorType *VTy = cast<VectorType>(V->getType());
- unsigned Width = VTy->getNumElements();
- if (EltNo >= Width) // Out of range access.
- return UndefValue::get(VTy->getElementType());
-
- if (Constant *C = dyn_cast<Constant>(V))
- return C->getAggregateElement(EltNo);
-
- if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) {
- // If this is an insert to a variable element, we don't know what it is.
- if (!isa<ConstantInt>(III->getOperand(2)))
- return nullptr;
- unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue();
-
- // If this is an insert to the element we are looking for, return the
- // inserted value.
- if (EltNo == IIElt)
- return III->getOperand(1);
-
- // Otherwise, the insertelement doesn't modify the value, recurse on its
- // vector input.
- return FindScalarElement(III->getOperand(0), EltNo);
- }
-
- if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) {
- unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements();
- int InEl = SVI->getMaskValue(EltNo);
- if (InEl < 0)
- return UndefValue::get(VTy->getElementType());
- if (InEl < (int)LHSWidth)
- return FindScalarElement(SVI->getOperand(0), InEl);
- return FindScalarElement(SVI->getOperand(1), InEl - LHSWidth);
- }
-
- // Extract a value from a vector add operation with a constant zero.
- Value *Val = nullptr; Constant *Con = nullptr;
- if (match(V, m_Add(m_Value(Val), m_Constant(Con)))) {
- if (Con->getAggregateElement(EltNo)->isNullValue())
- return FindScalarElement(Val, EltNo);
- }
-
- // Otherwise, we don't know.
- return nullptr;
-}
-
// If we have a PHI node with a vector type that has only 2 uses: feed
// itself and be an operand of extractelement at a constant location,
// try to replace the PHI of the vector type with a PHI of a scalar type.
@@ -178,6 +130,10 @@ Instruction *InstCombiner::scalarizePHI(ExtractElementInst &EI, PHINode *PN) {
}
Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) {
+ if (Value *V = SimplifyExtractElementInst(
+ EI.getVectorOperand(), EI.getIndexOperand(), DL, TLI, DT, AC))
+ return ReplaceInstUsesWith(EI, V);
+
// If vector val is constant with all elements the same, replace EI with
// that element. We handle a known element # below.
if (Constant *C = dyn_cast<Constant>(EI.getOperand(0)))
@@ -190,10 +146,8 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) {
unsigned IndexVal = IdxC->getZExtValue();
unsigned VectorWidth = EI.getVectorOperandType()->getNumElements();
- // If this is extracting an invalid index, turn this into undef, to avoid
- // crashing the code below.
- if (IndexVal >= VectorWidth)
- return ReplaceInstUsesWith(EI, UndefValue::get(EI.getType()));
+ // InstSimplify handles cases where the index is invalid.
+ assert(IndexVal < VectorWidth);
// This instruction only demands the single element from the input vector.
// If the input vector has a single use, simplify it based on this use
@@ -209,16 +163,13 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) {
}
}
- if (Value *Elt = FindScalarElement(EI.getOperand(0), IndexVal))
- return ReplaceInstUsesWith(EI, Elt);
-
// If the this extractelement is directly using a bitcast from a vector of
// the same number of elements, see if we can find the source element from
// it. In this case, we will end up needing to bitcast the scalars.
if (BitCastInst *BCI = dyn_cast<BitCastInst>(EI.getOperand(0))) {
if (VectorType *VT = dyn_cast<VectorType>(BCI->getOperand(0)->getType()))
if (VT->getNumElements() == VectorWidth)
- if (Value *Elt = FindScalarElement(BCI->getOperand(0), IndexVal))
+ if (Value *Elt = findScalarElement(BCI->getOperand(0), IndexVal))
return new BitCastInst(Elt, EI.getType());
}