diff options
author | Roman Lebedev <lebedev.ri@gmail.com> | 2022-01-10 20:49:41 +0300 |
---|---|---|
committer | Roman Lebedev <lebedev.ri@gmail.com> | 2022-01-10 20:51:26 +0300 |
commit | 82fb4f4b223d78e86647f3576e41e3086ab42cd5 (patch) | |
tree | 3bd0b9179c7d7718e67a38352f7f152b6e710fb8 /llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp | |
parent | 07a0b0ee94880cc193f3c63ebe3c4662c3123606 (diff) | |
download | llvm-82fb4f4b223d78e86647f3576e41e3086ab42cd5.zip llvm-82fb4f4b223d78e86647f3576e41e3086ab42cd5.tar.gz llvm-82fb4f4b223d78e86647f3576e41e3086ab42cd5.tar.bz2 |
[SCEV] Sequential/in-order `UMin` expression
As discussed in https://github.com/llvm/llvm-project/issues/53020 / https://reviews.llvm.org/D116692,
SCEV is forbidden from reasoning about 'backedge taken count'
if the branch condition is a poison-safe logical operation,
which is conservatively correct, but is severely limiting.
Instead, we should have a way to express those
poison blocking properties in SCEV expressions.
The proposed semantics is:
```
Sequential/in-order min/max SCEV expressions are non-commutative variants
of commutative min/max SCEV expressions. If none of their operands
are poison, then they are functionally equivalent, otherwise,
if the operand that represents the saturation point* of given expression,
comes before the first poison operand, then the whole expression is not poison,
but is said saturation point.
```
* saturation point - the maximal/minimal possible integer value for the given type
The lowering is straight-forward:
```
compare each operand to the saturation point,
perform sequential in-order logical-or (poison-safe!) ordered reduction
over those checks, and if reduction returned true then return
saturation point else return the naive min/max reduction over the operands
```
https://alive2.llvm.org/ce/z/Q7jxvH (2 ops)
https://alive2.llvm.org/ce/z/QCRrhk (3 ops)
Note that we don't need to check the last operand: https://alive2.llvm.org/ce/z/abvHQS
Note that this is not commutative: https://alive2.llvm.org/ce/z/FK9e97
That allows us to handle the patterns in question.
Reviewed By: nikic, reames
Differential Revision: https://reviews.llvm.org/D116766
Diffstat (limited to 'llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp | 64 |
1 files changed, 58 insertions, 6 deletions
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp index 68edfe0..b41b634 100644 --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -1671,7 +1671,7 @@ Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) { return Builder.CreateSExt(V, Ty); } -Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) { +Value *SCEVExpander::expandSMaxExpr(const SCEVNAryExpr *S) { Value *LHS = expand(S->getOperand(S->getNumOperands()-1)); Type *Ty = LHS->getType(); for (int i = S->getNumOperands()-2; i >= 0; --i) { @@ -1700,7 +1700,7 @@ Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) { return LHS; } -Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) { +Value *SCEVExpander::expandUMaxExpr(const SCEVNAryExpr *S) { Value *LHS = expand(S->getOperand(S->getNumOperands()-1)); Type *Ty = LHS->getType(); for (int i = S->getNumOperands()-2; i >= 0; --i) { @@ -1729,7 +1729,7 @@ Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) { return LHS; } -Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) { +Value *SCEVExpander::expandSMinExpr(const SCEVNAryExpr *S) { Value *LHS = expand(S->getOperand(S->getNumOperands() - 1)); Type *Ty = LHS->getType(); for (int i = S->getNumOperands() - 2; i >= 0; --i) { @@ -1758,7 +1758,7 @@ Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) { return LHS; } -Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) { +Value *SCEVExpander::expandUMinExpr(const SCEVNAryExpr *S) { Value *LHS = expand(S->getOperand(S->getNumOperands() - 1)); Type *Ty = LHS->getType(); for (int i = S->getNumOperands() - 2; i >= 0; --i) { @@ -1787,6 +1787,40 @@ Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) { return LHS; } +Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) { + return expandSMaxExpr(S); +} + +Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) { + return expandUMaxExpr(S); +} + +Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) { + return expandSMinExpr(S); +} + +Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) { + return expandUMinExpr(S); +} + +Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) { + SmallVector<Value *> Ops; + for (const SCEV *Op : S->operands()) + Ops.emplace_back(expand(Op)); + + Value *SaturationPoint = + MinMaxIntrinsic::getSaturationPoint(Intrinsic::umin, S->getType()); + + SmallVector<Value *> OpIsZero; + for (Value *Op : ArrayRef<Value *>(Ops).drop_back()) + OpIsZero.emplace_back(Builder.CreateICmpEQ(Op, SaturationPoint)); + + Value *AnyOpIsZero = Builder.CreateLogicalOr(OpIsZero); + + Value *NaiveUMin = expandUMinExpr(S); + return Builder.CreateSelect(AnyOpIsZero, SaturationPoint, NaiveUMin); +} + Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, Instruction *IP, bool Root) { setInsertPoint(IP); @@ -2271,10 +2305,27 @@ template<typename T> static InstructionCost costAndCollectOperands( case scSMaxExpr: case scUMaxExpr: case scSMinExpr: - case scUMinExpr: { + case scUMinExpr: + case scSequentialUMinExpr: { // FIXME: should this ask the cost for Intrinsic's? + // The reduction tree. Cost += CmpSelCost(Instruction::ICmp, S->getNumOperands() - 1, 0, 1); Cost += CmpSelCost(Instruction::Select, S->getNumOperands() - 1, 0, 2); + switch (S->getSCEVType()) { + case scSequentialUMinExpr: { + // The safety net against poison. + // FIXME: this is broken. + Cost += CmpSelCost(Instruction::ICmp, S->getNumOperands() - 1, 0, 0); + Cost += ArithCost(Instruction::Or, + S->getNumOperands() > 2 ? S->getNumOperands() - 2 : 0); + Cost += CmpSelCost(Instruction::Select, 1, 0, 1); + break; + } + default: + assert(!isa<SCEVSequentialMinMaxExpr>(S) && + "Unhandled SCEV expression type?"); + break; + } break; } case scAddRecExpr: { @@ -2399,7 +2450,8 @@ bool SCEVExpander::isHighCostExpansionHelper( case scUMaxExpr: case scSMaxExpr: case scUMinExpr: - case scSMinExpr: { + case scSMinExpr: + case scSequentialUMinExpr: { assert(cast<SCEVNAryExpr>(S)->getNumOperands() > 1 && "Nary expr should have more than 1 operand."); // The simple nary expr will require one less op (or pair of ops) |