aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
diff options
context:
space:
mode:
authorRoman Lebedev <lebedev.ri@gmail.com>2022-01-10 20:49:41 +0300
committerRoman Lebedev <lebedev.ri@gmail.com>2022-01-10 20:51:26 +0300
commit82fb4f4b223d78e86647f3576e41e3086ab42cd5 (patch)
tree3bd0b9179c7d7718e67a38352f7f152b6e710fb8 /llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
parent07a0b0ee94880cc193f3c63ebe3c4662c3123606 (diff)
downloadllvm-82fb4f4b223d78e86647f3576e41e3086ab42cd5.zip
llvm-82fb4f4b223d78e86647f3576e41e3086ab42cd5.tar.gz
llvm-82fb4f4b223d78e86647f3576e41e3086ab42cd5.tar.bz2
[SCEV] Sequential/in-order `UMin` expression
As discussed in https://github.com/llvm/llvm-project/issues/53020 / https://reviews.llvm.org/D116692, SCEV is forbidden from reasoning about 'backedge taken count' if the branch condition is a poison-safe logical operation, which is conservatively correct, but is severely limiting. Instead, we should have a way to express those poison blocking properties in SCEV expressions. The proposed semantics is: ``` Sequential/in-order min/max SCEV expressions are non-commutative variants of commutative min/max SCEV expressions. If none of their operands are poison, then they are functionally equivalent, otherwise, if the operand that represents the saturation point* of given expression, comes before the first poison operand, then the whole expression is not poison, but is said saturation point. ``` * saturation point - the maximal/minimal possible integer value for the given type The lowering is straight-forward: ``` compare each operand to the saturation point, perform sequential in-order logical-or (poison-safe!) ordered reduction over those checks, and if reduction returned true then return saturation point else return the naive min/max reduction over the operands ``` https://alive2.llvm.org/ce/z/Q7jxvH (2 ops) https://alive2.llvm.org/ce/z/QCRrhk (3 ops) Note that we don't need to check the last operand: https://alive2.llvm.org/ce/z/abvHQS Note that this is not commutative: https://alive2.llvm.org/ce/z/FK9e97 That allows us to handle the patterns in question. Reviewed By: nikic, reames Differential Revision: https://reviews.llvm.org/D116766
Diffstat (limited to 'llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp')
-rw-r--r--llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp64
1 files changed, 58 insertions, 6 deletions
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 68edfe0..b41b634 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -1671,7 +1671,7 @@ Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) {
return Builder.CreateSExt(V, Ty);
}
-Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) {
+Value *SCEVExpander::expandSMaxExpr(const SCEVNAryExpr *S) {
Value *LHS = expand(S->getOperand(S->getNumOperands()-1));
Type *Ty = LHS->getType();
for (int i = S->getNumOperands()-2; i >= 0; --i) {
@@ -1700,7 +1700,7 @@ Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) {
return LHS;
}
-Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) {
+Value *SCEVExpander::expandUMaxExpr(const SCEVNAryExpr *S) {
Value *LHS = expand(S->getOperand(S->getNumOperands()-1));
Type *Ty = LHS->getType();
for (int i = S->getNumOperands()-2; i >= 0; --i) {
@@ -1729,7 +1729,7 @@ Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) {
return LHS;
}
-Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) {
+Value *SCEVExpander::expandSMinExpr(const SCEVNAryExpr *S) {
Value *LHS = expand(S->getOperand(S->getNumOperands() - 1));
Type *Ty = LHS->getType();
for (int i = S->getNumOperands() - 2; i >= 0; --i) {
@@ -1758,7 +1758,7 @@ Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) {
return LHS;
}
-Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) {
+Value *SCEVExpander::expandUMinExpr(const SCEVNAryExpr *S) {
Value *LHS = expand(S->getOperand(S->getNumOperands() - 1));
Type *Ty = LHS->getType();
for (int i = S->getNumOperands() - 2; i >= 0; --i) {
@@ -1787,6 +1787,40 @@ Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) {
return LHS;
}
+Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) {
+ return expandSMaxExpr(S);
+}
+
+Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) {
+ return expandUMaxExpr(S);
+}
+
+Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) {
+ return expandSMinExpr(S);
+}
+
+Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) {
+ return expandUMinExpr(S);
+}
+
+Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) {
+ SmallVector<Value *> Ops;
+ for (const SCEV *Op : S->operands())
+ Ops.emplace_back(expand(Op));
+
+ Value *SaturationPoint =
+ MinMaxIntrinsic::getSaturationPoint(Intrinsic::umin, S->getType());
+
+ SmallVector<Value *> OpIsZero;
+ for (Value *Op : ArrayRef<Value *>(Ops).drop_back())
+ OpIsZero.emplace_back(Builder.CreateICmpEQ(Op, SaturationPoint));
+
+ Value *AnyOpIsZero = Builder.CreateLogicalOr(OpIsZero);
+
+ Value *NaiveUMin = expandUMinExpr(S);
+ return Builder.CreateSelect(AnyOpIsZero, SaturationPoint, NaiveUMin);
+}
+
Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty,
Instruction *IP, bool Root) {
setInsertPoint(IP);
@@ -2271,10 +2305,27 @@ template<typename T> static InstructionCost costAndCollectOperands(
case scSMaxExpr:
case scUMaxExpr:
case scSMinExpr:
- case scUMinExpr: {
+ case scUMinExpr:
+ case scSequentialUMinExpr: {
// FIXME: should this ask the cost for Intrinsic's?
+ // The reduction tree.
Cost += CmpSelCost(Instruction::ICmp, S->getNumOperands() - 1, 0, 1);
Cost += CmpSelCost(Instruction::Select, S->getNumOperands() - 1, 0, 2);
+ switch (S->getSCEVType()) {
+ case scSequentialUMinExpr: {
+ // The safety net against poison.
+ // FIXME: this is broken.
+ Cost += CmpSelCost(Instruction::ICmp, S->getNumOperands() - 1, 0, 0);
+ Cost += ArithCost(Instruction::Or,
+ S->getNumOperands() > 2 ? S->getNumOperands() - 2 : 0);
+ Cost += CmpSelCost(Instruction::Select, 1, 0, 1);
+ break;
+ }
+ default:
+ assert(!isa<SCEVSequentialMinMaxExpr>(S) &&
+ "Unhandled SCEV expression type?");
+ break;
+ }
break;
}
case scAddRecExpr: {
@@ -2399,7 +2450,8 @@ bool SCEVExpander::isHighCostExpansionHelper(
case scUMaxExpr:
case scSMaxExpr:
case scUMinExpr:
- case scSMinExpr: {
+ case scSMinExpr:
+ case scSequentialUMinExpr: {
assert(cast<SCEVNAryExpr>(S)->getNumOperands() > 1 &&
"Nary expr should have more than 1 operand.");
// The simple nary expr will require one less op (or pair of ops)