aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--llvm/include/llvm/IR/Constant.h14
-rw-r--r--llvm/lib/Analysis/ValueTracking.cpp7
-rw-r--r--llvm/lib/IR/ConstantFold.cpp2
-rw-r--r--llvm/lib/IR/Constants.cpp25
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp5
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp4
-rw-r--r--llvm/unittests/Analysis/ValueTrackingTest.cpp24
-rw-r--r--llvm/unittests/IR/ConstantsTest.cpp37
8 files changed, 98 insertions, 20 deletions
diff --git a/llvm/include/llvm/IR/Constant.h b/llvm/include/llvm/IR/Constant.h
index 97650c2..0190aca 100644
--- a/llvm/include/llvm/IR/Constant.h
+++ b/llvm/include/llvm/IR/Constant.h
@@ -101,11 +101,15 @@ public:
/// lane, the constants still match.
bool isElementWiseEqual(Value *Y) const;
- /// Return true if this is a vector constant that includes any undefined
- /// elements. Since it is impossible to inspect a scalable vector element-
- /// wise at compile time, this function returns true only if the entire
- /// vector is undef
- bool containsUndefElement() const;
+ /// Return true if this is a vector constant that includes any undef or
+ /// poison elements. Since it is impossible to inspect a scalable vector
+ /// element- wise at compile time, this function returns true only if the
+ /// entire vector is undef or poison.
+ bool containsUndefOrPoisonElement() const;
+
+ /// Return true if this is a vector constant that includes any poison
+ /// elements.
+ bool containsPoisonElement() const;
/// Return true if this is a fixed width vector constant that includes
/// any constant expressions.
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index e15d4f0..1c75c5f 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -4895,7 +4895,8 @@ static bool isGuaranteedNotToBeUndefOrPoison(const Value *V,
return true;
if (C->getType()->isVectorTy() && !isa<ConstantExpr>(C))
- return (PoisonOnly || !C->containsUndefElement()) &&
+ return (PoisonOnly ? !C->containsPoisonElement()
+ : !C->containsUndefOrPoisonElement()) &&
!C->containsConstantExpression();
}
@@ -5636,10 +5637,10 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
// elements because those can not be back-propagated for analysis.
Value *OutputZeroVal = nullptr;
if (match(TrueVal, m_AnyZeroFP()) && !match(FalseVal, m_AnyZeroFP()) &&
- !cast<Constant>(TrueVal)->containsUndefElement())
+ !cast<Constant>(TrueVal)->containsUndefOrPoisonElement())
OutputZeroVal = TrueVal;
else if (match(FalseVal, m_AnyZeroFP()) && !match(TrueVal, m_AnyZeroFP()) &&
- !cast<Constant>(FalseVal)->containsUndefElement())
+ !cast<Constant>(FalseVal)->containsUndefOrPoisonElement())
OutputZeroVal = FalseVal;
if (OutputZeroVal) {
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index 4774568..03cb108 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -811,7 +811,7 @@ Constant *llvm::ConstantFoldSelectInstruction(Constant *Cond,
return true;
if (C->getType()->isVectorTy())
- return !C->containsUndefElement() && !C->containsConstantExpression();
+ return !C->containsPoisonElement() && !C->containsConstantExpression();
// TODO: Recursively analyze aggregates or other constants.
return false;
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index a38302d..5aa819d 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -304,31 +304,42 @@ bool Constant::isElementWiseEqual(Value *Y) const {
return isa<UndefValue>(CmpEq) || match(CmpEq, m_One());
}
-bool Constant::containsUndefElement() const {
- if (auto *VTy = dyn_cast<VectorType>(getType())) {
- if (isa<UndefValue>(this))
+static bool
+containsUndefinedElement(const Constant *C,
+ function_ref<bool(const Constant *)> HasFn) {
+ if (auto *VTy = dyn_cast<VectorType>(C->getType())) {
+ if (HasFn(C))
return true;
- if (isa<ConstantAggregateZero>(this))
+ if (isa<ConstantAggregateZero>(C))
return false;
- if (isa<ScalableVectorType>(getType()))
+ if (isa<ScalableVectorType>(C->getType()))
return false;
for (unsigned i = 0, e = cast<FixedVectorType>(VTy)->getNumElements();
i != e; ++i)
- if (isa<UndefValue>(getAggregateElement(i)))
+ if (HasFn(C->getAggregateElement(i)))
return true;
}
return false;
}
+bool Constant::containsUndefOrPoisonElement() const {
+ return containsUndefinedElement(
+ this, [&](const auto *C) { return isa<UndefValue>(C); });
+}
+
+bool Constant::containsPoisonElement() const {
+ return containsUndefinedElement(
+ this, [&](const auto *C) { return isa<PoisonValue>(C); });
+}
+
bool Constant::containsConstantExpression() const {
if (auto *VTy = dyn_cast<FixedVectorType>(getType())) {
for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i)
if (isa<ConstantExpr>(getAggregateElement(i)))
return true;
}
-
return false;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 87d4b40..0887779 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3370,7 +3370,7 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
Type *OpTy = M->getType();
auto *VecC = dyn_cast<Constant>(M);
auto *OpVTy = dyn_cast<FixedVectorType>(OpTy);
- if (OpVTy && VecC && VecC->containsUndefElement()) {
+ if (OpVTy && VecC && VecC->containsUndefOrPoisonElement()) {
Constant *SafeReplacementConstant = nullptr;
for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) {
if (!isa<UndefValue>(VecC->getAggregateElement(i))) {
@@ -5259,7 +5259,8 @@ InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
// It may not be safe to change a compare predicate in the presence of
// undefined elements, so replace those elements with the first safe constant
// that we found.
- if (C->containsUndefElement()) {
+ // TODO: in case of poison, it is safe; let's replace undefs only.
+ if (C->containsUndefOrPoisonElement()) {
assert(SafeReplacementConstant && "Replacement constant not set");
C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
index 494c58e..7718c8b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
@@ -239,8 +239,8 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) {
// While this is normally not behind a use-check,
// let's consider division to be special since it's costly.
if (auto *Op1C = dyn_cast<Constant>(I->getOperand(1))) {
- if (!Op1C->containsUndefElement() && Op1C->isNotMinSignedValue() &&
- Op1C->isNotOneValue()) {
+ if (!Op1C->containsUndefOrPoisonElement() &&
+ Op1C->isNotMinSignedValue() && Op1C->isNotOneValue()) {
Value *BO =
Builder.CreateSDiv(I->getOperand(0), ConstantExpr::getNeg(Op1C),
I->getName() + ".neg");
diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp
index 0d65774..d70fd6e 100644
--- a/llvm/unittests/Analysis/ValueTrackingTest.cpp
+++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp
@@ -888,6 +888,30 @@ TEST_F(ValueTrackingTest, isGuaranteedNotToBeUndefOrPoison) {
EXPECT_EQ(isGuaranteedNotToBeUndefOrPoison(PoisonValue::get(IntegerType::get(Context, 8))), false);
EXPECT_EQ(isGuaranteedNotToBePoison(UndefValue::get(IntegerType::get(Context, 8))), true);
EXPECT_EQ(isGuaranteedNotToBePoison(PoisonValue::get(IntegerType::get(Context, 8))), false);
+
+ Type *Int32Ty = Type::getInt32Ty(Context);
+ Constant *CU = UndefValue::get(Int32Ty);
+ Constant *CP = PoisonValue::get(Int32Ty);
+ Constant *C1 = ConstantInt::get(Int32Ty, 1);
+ Constant *C2 = ConstantInt::get(Int32Ty, 2);
+
+ {
+ Constant *V1 = ConstantVector::get({C1, C2});
+ EXPECT_TRUE(isGuaranteedNotToBeUndefOrPoison(V1));
+ EXPECT_TRUE(isGuaranteedNotToBePoison(V1));
+ }
+
+ {
+ Constant *V2 = ConstantVector::get({C1, CU});
+ EXPECT_FALSE(isGuaranteedNotToBeUndefOrPoison(V2));
+ EXPECT_TRUE(isGuaranteedNotToBePoison(V2));
+ }
+
+ {
+ Constant *V3 = ConstantVector::get({C1, CP});
+ EXPECT_FALSE(isGuaranteedNotToBeUndefOrPoison(V3));
+ EXPECT_FALSE(isGuaranteedNotToBePoison(V3));
+ }
}
TEST_F(ValueTrackingTest, isGuaranteedNotToBeUndefOrPoison_assume) {
diff --git a/llvm/unittests/IR/ConstantsTest.cpp b/llvm/unittests/IR/ConstantsTest.cpp
index 9dd1ba8..afae154 100644
--- a/llvm/unittests/IR/ConstantsTest.cpp
+++ b/llvm/unittests/IR/ConstantsTest.cpp
@@ -585,6 +585,43 @@ TEST(ConstantsTest, FoldGlobalVariablePtr) {
Instruction::And, TheConstantExpr, TheConstant)->isNullValue());
}
+// Check that containsUndefOrPoisonElement and containsPoisonElement is working
+// great
+
+TEST(ConstantsTest, containsUndefElemTest) {
+ LLVMContext Context;
+
+ Type *Int32Ty = Type::getInt32Ty(Context);
+ Constant *CU = UndefValue::get(Int32Ty);
+ Constant *CP = PoisonValue::get(Int32Ty);
+ Constant *C1 = ConstantInt::get(Int32Ty, 1);
+ Constant *C2 = ConstantInt::get(Int32Ty, 2);
+
+ {
+ Constant *V1 = ConstantVector::get({C1, C2});
+ EXPECT_FALSE(V1->containsUndefOrPoisonElement());
+ EXPECT_FALSE(V1->containsPoisonElement());
+ }
+
+ {
+ Constant *V2 = ConstantVector::get({C1, CU});
+ EXPECT_TRUE(V2->containsUndefOrPoisonElement());
+ EXPECT_FALSE(V2->containsPoisonElement());
+ }
+
+ {
+ Constant *V3 = ConstantVector::get({C1, CP});
+ EXPECT_TRUE(V3->containsUndefOrPoisonElement());
+ EXPECT_TRUE(V3->containsPoisonElement());
+ }
+
+ {
+ Constant *V4 = ConstantVector::get({CU, CP});
+ EXPECT_TRUE(V4->containsUndefOrPoisonElement());
+ EXPECT_TRUE(V4->containsPoisonElement());
+ }
+}
+
// Check that undefined elements in vector constants are matched
// correctly for both integer and floating-point types. Just don't
// crash on vectors of pointers (could be handled?).