aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
diff options
context:
space:
mode:
authorSlava Zakharin <szakharin@nvidia.com>2025-06-19 10:13:58 -0700
committerGitHub <noreply@github.com>2025-06-19 10:13:58 -0700
commitc0c71463f6bca05eb4540b68cdcbd17c916562c9 (patch)
tree06bff7632e602f88a2fcb3031b83381c05ef9bdb /llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
parent6ce86538c11b3ef93a2a8df3bd4f817a724f42bd (diff)
downloadllvm-c0c71463f6bca05eb4540b68cdcbd17c916562c9.zip
llvm-c0c71463f6bca05eb4540b68cdcbd17c916562c9.tar.gz
llvm-c0c71463f6bca05eb4540b68cdcbd17c916562c9.tar.bz2
[InstCombine] Optimize sub(sext(add(x,y)),sext(add(x,z))). (#144174)
This pattern can be often met in Flang generated LLVM IR, for example, for the counts of the loops generated for array expressions like: `a(x:x+y)` or `a(x+z:x+z)` or their variations. In order to compute the loop count, Flang needs to subtract the lower bound of the array slice from the upper bound of the array slice. To avoid the sign wraps, it sign extends the original values (that may be of any user data type) to `i64`. This peephole is really helpful in CPU2017/548.exchange2, where we have multiple following statements like this: ``` block(row+1:row+2, 7:9, i7) = block(row+1:row+2, 7:9, i7) - 10 ``` While this is just a 2x3 iterations loop nest, LLVM cannot figure it out, ending up vectorizing the inner loop really hard (with a vector epilog and scalar remainder). This, in turn, causes problems for LSR that ends up creating too many loop-carried values in the loop containing the above statement, which are then causing too many spills/reloads. Alive2: https://alive2.llvm.org/ce/z/gLgfYX Related to #143219.
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp51
1 files changed, 48 insertions, 3 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 0a3837f..418302d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1896,7 +1896,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
{Sub, Builder.getFalse()});
Value *Ret = Builder.CreateSub(
ConstantInt::get(A->getType(), A->getType()->getScalarSizeInBits()),
- Ctlz, "", /*HasNUW*/ true, /*HasNSW*/ true);
+ Ctlz, "", /*HasNUW=*/true, /*HasNSW=*/true);
return replaceInstUsesWith(I, Builder.CreateZExtOrTrunc(Ret, I.getType()));
}
@@ -2363,8 +2363,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
OverflowingBinaryOperator *LHSSub = cast<OverflowingBinaryOperator>(Op0);
bool HasNUW = I.hasNoUnsignedWrap() && LHSSub->hasNoUnsignedWrap();
bool HasNSW = HasNUW && I.hasNoSignedWrap() && LHSSub->hasNoSignedWrap();
- Value *Add = Builder.CreateAdd(Y, Op1, "", /* HasNUW */ HasNUW,
- /* HasNSW */ HasNSW);
+ Value *Add = Builder.CreateAdd(Y, Op1, "", /*HasNUW=*/HasNUW,
+ /*HasNSW=*/HasNSW);
BinaryOperator *Sub = BinaryOperator::CreateSub(X, Add);
Sub->setHasNoUnsignedWrap(HasNUW);
Sub->setHasNoSignedWrap(HasNSW);
@@ -2835,6 +2835,51 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
return Res;
+ // (sub (sext (add nsw (X, Y)), sext (X))) --> (sext (Y))
+ if (match(Op1, m_SExtLike(m_Value(X))) &&
+ match(Op0, m_SExtLike(m_c_NSWAdd(m_Specific(X), m_Value(Y))))) {
+ Value *SExtY = Builder.CreateSExt(Y, I.getType());
+ return replaceInstUsesWith(I, SExtY);
+ }
+
+ // (sub[ nsw] (sext (add nsw (X, Y)), sext (add nsw (X, Z)))) -->
+ // --> (sub[ nsw] (sext (Y), sext (Z)))
+ {
+ Value *Z, *Add0, *Add1;
+ if (match(Op0, m_SExtLike(m_Value(Add0))) &&
+ match(Op1, m_SExtLike(m_Value(Add1))) &&
+ ((match(Add0, m_NSWAdd(m_Value(X), m_Value(Y))) &&
+ match(Add1, m_c_NSWAdd(m_Specific(X), m_Value(Z)))) ||
+ (match(Add0, m_NSWAdd(m_Value(Y), m_Value(X))) &&
+ match(Add1, m_c_NSWAdd(m_Specific(X), m_Value(Z)))))) {
+ unsigned NumOfNewInstrs = 0;
+ // Non-constant Y, Z require new SExt.
+ NumOfNewInstrs += !isa<Constant>(Y) ? 1 : 0;
+ NumOfNewInstrs += !isa<Constant>(Z) ? 1 : 0;
+ // Check if we can trade some of the old instructions for the new ones.
+ unsigned NumOfDeadInstrs = 0;
+ if (Op0->hasOneUse()) {
+ // If Op0 (sext) has multiple uses, then we keep it
+ // and the add that it uses, otherwise, we can remove
+ // the sext and probably the add (depending on the number of its uses).
+ ++NumOfDeadInstrs;
+ NumOfDeadInstrs += Add0->hasOneUse() ? 1 : 0;
+ }
+ if (Op1->hasOneUse()) {
+ ++NumOfDeadInstrs;
+ NumOfDeadInstrs += Add1->hasOneUse() ? 1 : 0;
+ }
+ if (NumOfDeadInstrs >= NumOfNewInstrs) {
+ Value *SExtY = Builder.CreateSExt(Y, I.getType());
+ Value *SExtZ = Builder.CreateSExt(Z, I.getType());
+ Value *Sub = Builder.CreateSub(SExtY, SExtZ, "",
+ /*HasNUW=*/false,
+ /*HasNSW=*/I.hasNoSignedWrap());
+ return replaceInstUsesWith(I, Sub);
+ }
+ }
+ }
+
return TryToNarrowDeduceFlags();
}