aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/IR/Constants.cpp26
-rw-r--r--llvm/lib/IR/Instruction.cpp2
-rw-r--r--llvm/lib/Transforms/Scalar/Reassociate.cpp2
-rw-r--r--llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp14
4 files changed, 39 insertions, 5 deletions
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index bc55d5b..a38b912 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -2556,6 +2556,32 @@ Constant *ConstantExpr::getBinOpIdentity(unsigned Opcode, Type *Ty,
}
}
+Constant *ConstantExpr::getIntrinsicIdentity(Intrinsic::ID ID, Type *Ty) {
+ switch (ID) {
+ case Intrinsic::umax:
+ return Constant::getNullValue(Ty);
+ case Intrinsic::umin:
+ return Constant::getAllOnesValue(Ty);
+ case Intrinsic::smax:
+ return Constant::getIntegerValue(
+ Ty, APInt::getSignedMinValue(Ty->getIntegerBitWidth()));
+ case Intrinsic::smin:
+ return Constant::getIntegerValue(
+ Ty, APInt::getSignedMaxValue(Ty->getIntegerBitWidth()));
+ default:
+ return nullptr;
+ }
+}
+
+Constant *ConstantExpr::getIdentity(Instruction *I, Type *Ty,
+ bool AllowRHSConstant, bool NSZ) {
+ if (I->isBinaryOp())
+ return getBinOpIdentity(I->getOpcode(), Ty, AllowRHSConstant, NSZ);
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I))
+ return getIntrinsicIdentity(II->getIntrinsicID(), Ty);
+ return nullptr;
+}
+
Constant *ConstantExpr::getBinOpAbsorber(unsigned Opcode, Type *Ty) {
switch (Opcode) {
default:
diff --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
index 4b53498..a6e43de 100644
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -1091,6 +1091,8 @@ const DebugLoc &Instruction::getStableDebugLoc() const {
}
bool Instruction::isAssociative() const {
+ if (auto *II = dyn_cast<IntrinsicInst>(this))
+ return II->isAssociative();
unsigned Opcode = getOpcode();
if (isAssociative(Opcode))
return true;
diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index 0d55c72..d3f6d24 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -2554,7 +2554,7 @@ ReassociatePass::BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT) {
// Make a "pairmap" of how often each operand pair occurs.
for (BasicBlock *BI : RPOT) {
for (Instruction &I : *BI) {
- if (!I.isAssociative())
+ if (!I.isAssociative() || !I.isBinaryOp())
continue;
// Ignore nodes that aren't at the root of trees.
diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
index 7b850f0..c6e8505 100644
--- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
@@ -369,8 +369,14 @@ static bool canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) {
if (!I->isAssociative() || !I->isCommutative())
return false;
- assert(I->getNumOperands() == 2 &&
- "Associative/commutative operations should have 2 args!");
+ assert(I->getNumOperands() >= 2 &&
+ "Associative/commutative operations should have at least 2 args!");
+
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+ // Accumulators must have an identity.
+ if (!ConstantExpr::getIntrinsicIdentity(II->getIntrinsicID(), I->getType()))
+ return false;
+ }
// Exactly one operand should be the result of the call instruction.
if ((I->getOperand(0) == CI && I->getOperand(1) == CI) ||
@@ -569,8 +575,8 @@ void TailRecursionEliminator::insertAccumulator(Instruction *AccRecInstr) {
for (pred_iterator PI = PB; PI != PE; ++PI) {
BasicBlock *P = *PI;
if (P == &F.getEntryBlock()) {
- Constant *Identity = ConstantExpr::getBinOpIdentity(
- AccRecInstr->getOpcode(), AccRecInstr->getType());
+ Constant *Identity =
+ ConstantExpr::getIdentity(AccRecInstr, AccRecInstr->getType());
AccPN->addIncoming(Identity, P);
} else {
AccPN->addIncoming(AccPN, P);