From 312cb34da6a5529fbfaa1be62f1aa9bbb26ce506 Mon Sep 17 00:00:00 2001 From: Yingwei Zheng Date: Sat, 9 Dec 2023 16:45:48 +0800 Subject: [Reassociate] Preserve NUW flags after expr tree rewriting (#72360) Alive2: https://alive2.llvm.org/ce/z/38KiC_ --- llvm/include/llvm/Transforms/Scalar/Reassociate.h | 3 +- llvm/lib/Transforms/Scalar/Reassociate.cpp | 31 ++++++++++++----- llvm/test/Transforms/Reassociate/local-cse.ll | 40 +++++++++++----------- .../test/Transforms/Reassociate/reassoc-mul-nuw.ll | 34 ++++++++++++++++++ 4 files changed, 78 insertions(+), 30 deletions(-) create mode 100644 llvm/test/Transforms/Reassociate/reassoc-mul-nuw.ll diff --git a/llvm/include/llvm/Transforms/Scalar/Reassociate.h b/llvm/include/llvm/Transforms/Scalar/Reassociate.h index 28794d2..7e47f8a 100644 --- a/llvm/include/llvm/Transforms/Scalar/Reassociate.h +++ b/llvm/include/llvm/Transforms/Scalar/Reassociate.h @@ -102,7 +102,8 @@ private: void canonicalizeOperands(Instruction *I); void ReassociateExpression(BinaryOperator *I); void RewriteExprTree(BinaryOperator *I, - SmallVectorImpl &Ops); + SmallVectorImpl &Ops, + bool HasNUW); Value *OptimizeExpression(BinaryOperator *I, SmallVectorImpl &Ops); Value *OptimizeAdd(Instruction *I, diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp index 42e979d..818c7b4 100644 --- a/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -466,7 +466,8 @@ using RepeatedValue = std::pair; /// type and thus make the expression bigger. 
static bool LinearizeExprTree(Instruction *I, SmallVectorImpl<RepeatedValue> &Ops, - ReassociatePass::OrderedSet &ToRedo) { + ReassociatePass::OrderedSet &ToRedo, + bool &HasNUW) { assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) && "Expected a UnaryOperator or BinaryOperator!"); LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n'); @@ -515,6 +516,9 @@ static bool LinearizeExprTree(Instruction *I, std::pair<Instruction*, APInt> P = Worklist.pop_back_val(); I = P.first; // We examine the operands of this binary operator. + if (isa<OverflowingBinaryOperator>(I)) + HasNUW &= I->hasNoUnsignedWrap(); + for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands. Value *Op = I->getOperand(OpIdx); APInt Weight = P.second; // Number of paths to this operand. @@ -657,7 +661,8 @@ static bool LinearizeExprTree(Instruction *I, /// Now that the operands for this expression tree are /// linearized and optimized, emit them in-order. void ReassociatePass::RewriteExprTree(BinaryOperator *I, - SmallVectorImpl<ValueEntry> &Ops) { + SmallVectorImpl<ValueEntry> &Ops, + bool HasNUW) { assert(Ops.size() > 1 && "Single values should be used directly!"); // Since our optimizations should never increase the number of operations, the @@ -814,14 +819,20 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, if (ExpressionChangedStart) { bool ClearFlags = true; do { - // Preserve FastMathFlags. + // Preserve flags. if (ClearFlags) { if (isa<FPMathOperator>(I)) { FastMathFlags Flags = I->getFastMathFlags(); ExpressionChangedStart->clearSubclassOptionalData(); ExpressionChangedStart->setFastMathFlags(Flags); - } else + } else { ExpressionChangedStart->clearSubclassOptionalData(); + // Note that it doesn't hold for mul if one of the operands is zero. + // TODO: We can preserve NUW flag if we prove that all mul operands + // are non-zero. 
+ if (HasNUW && ExpressionChangedStart->getOpcode() == Instruction::Add) + ExpressionChangedStart->setHasNoUnsignedWrap(); + } } if (ExpressionChangedStart == ExpressionChangedEnd) @@ -1175,7 +1186,8 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) { return nullptr; SmallVector<RepeatedValue, 8> Tree; - MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts); + bool HasNUW = true; + MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, HasNUW); SmallVector<ValueEntry, 8> Factors; Factors.reserve(Tree.size()); for (unsigned i = 0, e = Tree.size(); i != e; ++i) { @@ -1217,7 +1229,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) { if (!FoundFactor) { // Make sure to restore the operands to the expression tree. - RewriteExprTree(BO, Factors); + RewriteExprTree(BO, Factors, HasNUW); return nullptr; } @@ -1229,7 +1241,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) { RedoInsts.insert(BO); V = Factors[0].Op; } else { - RewriteExprTree(BO, Factors); + RewriteExprTree(BO, Factors, HasNUW); V = BO; } @@ -2354,7 +2366,8 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { // First, walk the expression tree, linearizing the tree, collecting the // operand information. SmallVector<RepeatedValue, 8> Tree; - MadeChange |= LinearizeExprTree(I, Tree, RedoInsts); + bool HasNUW = true; + MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, HasNUW); SmallVector<ValueEntry, 8> Ops; Ops.reserve(Tree.size()); for (const RepeatedValue &E : Tree) @@ -2547,7 +2560,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { dbgs() << '\n'); // Now that we ordered and optimized the expressions, splat them back into // the expression tree, removing any unneeded nodes. 
- RewriteExprTree(I, Ops); + RewriteExprTree(I, Ops, HasNUW); } void diff --git a/llvm/test/Transforms/Reassociate/local-cse.ll b/llvm/test/Transforms/Reassociate/local-cse.ll index 1609cb1..4d0467e 100644 --- a/llvm/test/Transforms/Reassociate/local-cse.ll +++ b/llvm/test/Transforms/Reassociate/local-cse.ll @@ -26,16 +26,16 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64 ; LOCAL_CSE-LABEL: define void @chain_spanning_several_blocks ; LOCAL_CSE-SAME: (i64 [[INV1:%.*]], i64 [[INV2:%.*]], i64 [[INV3:%.*]], i64 [[INV4:%.*]], i64 [[INV5:%.*]]) { ; LOCAL_CSE-NEXT: bb1: -; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add i64 [[INV2]], [[INV1]] +; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[INV2]], [[INV1]] ; LOCAL_CSE-NEXT: br label [[BB2:%.*]] ; LOCAL_CSE: bb2: ; LOCAL_CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val() -; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV4]] -; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add i64 [[CHAIN_A1]], [[VAL_BB2]] -; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add i64 [[CHAIN_A0]], [[INV5]] -; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add i64 [[CHAIN_B1]], [[VAL_BB2]] -; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add i64 [[INV3]], [[INV1]] -; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add i64 [[CHAIN_C0]], [[VAL_BB2]] +; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4]] +; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]] +; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5]] +; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]] +; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add nuw i64 [[INV3]], [[INV1]] +; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[VAL_BB2]] ; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]]) ; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]]) ; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]]) @@ -47,11 +47,11 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64 ; 
CSE-NEXT: br label [[BB2:%.*]] ; CSE: bb2: ; CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val() -; CSE-NEXT: [[CHAIN_A0:%.*]] = add i64 [[VAL_BB2]], [[INV1]] -; CSE-NEXT: [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV2]] +; CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1]] +; CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2]] ; CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4]] ; CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5]] -; CSE-NEXT: [[CHAIN_C1:%.*]] = add i64 [[CHAIN_A0]], [[INV3]] +; CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3]] ; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]]) ; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]]) ; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]]) @@ -90,19 +90,19 @@ define void @chain_spanning_several_blocks_no_entry_anchor() { ; LOCAL_CSE-NEXT: br label [[BB1:%.*]] ; LOCAL_CSE: bb1: ; LOCAL_CSE-NEXT: [[INV1_BB1:%.*]] = call i64 @get_val() -; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add i64 [[INV1_BB1]], [[INV2_BB0]] +; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[INV1_BB1]], [[INV2_BB0]] ; LOCAL_CSE-NEXT: br label [[BB2:%.*]] ; LOCAL_CSE: bb2: ; LOCAL_CSE-NEXT: [[INV3_BB2:%.*]] = call i64 @get_val() ; LOCAL_CSE-NEXT: [[INV4_BB2:%.*]] = call i64 @get_val() ; LOCAL_CSE-NEXT: [[INV5_BB2:%.*]] = call i64 @get_val() ; LOCAL_CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val() -; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV4_BB2]] -; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add i64 [[CHAIN_A1]], [[VAL_BB2]] -; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add i64 [[CHAIN_A0]], [[INV5_BB2]] -; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add i64 [[CHAIN_B1]], [[VAL_BB2]] -; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add i64 [[VAL_BB2]], [[INV1_BB1]] -; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add i64 [[CHAIN_C0]], [[INV3_BB2]] +; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4_BB2]] +; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]] +; 
LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5_BB2]] +; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]] +; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]] +; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[INV3_BB2]] ; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]]) ; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]]) ; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]]) @@ -120,11 +120,11 @@ define void @chain_spanning_several_blocks_no_entry_anchor() { ; CSE-NEXT: [[INV4_BB2:%.*]] = call i64 @get_val() ; CSE-NEXT: [[INV5_BB2:%.*]] = call i64 @get_val() ; CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val() -; CSE-NEXT: [[CHAIN_A0:%.*]] = add i64 [[VAL_BB2]], [[INV1_BB1]] -; CSE-NEXT: [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV2_BB0]] +; CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]] +; CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2_BB0]] ; CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4_BB2]] ; CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5_BB2]] -; CSE-NEXT: [[CHAIN_C1:%.*]] = add i64 [[CHAIN_A0]], [[INV3_BB2]] +; CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3_BB2]] ; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]]) ; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]]) ; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]]) diff --git a/llvm/test/Transforms/Reassociate/reassoc-mul-nuw.ll b/llvm/test/Transforms/Reassociate/reassoc-mul-nuw.ll new file mode 100644 index 0000000..682fad8 --- /dev/null +++ b/llvm/test/Transforms/Reassociate/reassoc-mul-nuw.ll @@ -0,0 +1,34 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4 +; RUN: opt < %s -passes=reassociate -S | FileCheck %s + +; We cannot preserve nuw flags for mul +define i4 @nuw_preserve_negative(i4 %a, i4 %b, i4 %c) { +; CHECK-LABEL: define i4 @nuw_preserve_negative( +; CHECK-SAME: 
i4 [[A:%.*]], i4 [[B:%.*]], i4 [[C:%.*]]) { +; CHECK-NEXT: [[V0:%.*]] = mul i4 [[B]], [[A]] +; CHECK-NEXT: [[V1:%.*]] = mul i4 [[V0]], [[C]] +; CHECK-NEXT: ret i4 [[V1]] +; + %v0 = mul nuw i4 %a, %c + %v1 = mul nuw i4 %v0, %b + ret i4 %v1 +} + +; TODO: we can add nuw flags if we know all operands are non-zero. +define i4 @nuw_preserve_non_zero(i4 %a, i4 %b, i4 %c) { +; CHECK-LABEL: define i4 @nuw_preserve_non_zero( +; CHECK-SAME: i4 [[A:%.*]], i4 [[B:%.*]], i4 [[C:%.*]]) { +; CHECK-NEXT: [[A0:%.*]] = add nuw i4 [[A]], 1 +; CHECK-NEXT: [[B0:%.*]] = add nuw i4 [[B]], 1 +; CHECK-NEXT: [[C0:%.*]] = add nuw i4 [[C]], 1 +; CHECK-NEXT: [[V0:%.*]] = mul i4 [[B0]], [[A0]] +; CHECK-NEXT: [[V1:%.*]] = mul i4 [[V0]], [[C0]] +; CHECK-NEXT: ret i4 [[V1]] +; + %a0 = add nuw i4 %a, 1 + %b0 = add nuw i4 %b, 1 + %c0 = add nuw i4 %c, 1 + %v0 = mul nuw i4 %a0, %c0 + %v1 = mul nuw i4 %v0, %b0 + ret i4 %v1 +} -- cgit v1.1