diff options
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r-- | llvm/lib/Transforms/IPO/FunctionAttrs.cpp | 62 | ||||
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 11 |
2 files changed, 35 insertions, 38 deletions
diff --git a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp index 2eec438..c321afb 100644 --- a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp @@ -903,49 +903,37 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes, return MadeChange; } -/// Removes convergent attributes where we can prove that none of the SCC's -/// callees are themselves convergent. Returns true if successful at removing -/// the attribute. +/// Remove the convergent attribute from all functions in the SCC if every +/// callsite within the SCC is not convergent (except for calls to functions +/// within the SCC). Returns true if changes were made. static bool removeConvergentAttrs(const SCCNodeSet &SCCNodes) { - // Determines whether a function can be made non-convergent, ignoring all - // other functions in SCC. (A function can *actually* be made non-convergent - // only if all functions in its SCC can be made convergent.) - auto CanRemoveConvergent = [&](Function *F) { - if (!F->isConvergent()) - return true; - - // Can't remove convergent from declarations. - if (F->isDeclaration()) - return false; - - for (Instruction &I : instructions(*F)) - if (auto CS = CallSite(&I)) { - // Can't remove convergent if any of F's callees -- ignoring functions - // in the SCC itself -- are convergent. This needs to consider both - // function calls and intrinsic calls. We also assume indirect calls - // might call a convergent function. - // FIXME: We should revisit this when we put convergent onto calls - // instead of functions so that indirect calls which should be - // convergent are required to be marked as such. - Function *Callee = CS.getCalledFunction(); - if (!Callee || (SCCNodes.count(Callee) == 0 && Callee->isConvergent())) - return false; - } - - return true; - }; + // No point checking if none of SCCNodes is convergent. + if (!llvm::any_of(SCCNodes, [](Function *F) { return F->isConvergent(); })) + return false; - // We can remove the convergent attr from functions in the SCC if they all - // can be made non-convergent (because they call only non-convergent - // functions, other than each other). - if (!llvm::all_of(SCCNodes, CanRemoveConvergent)) + // Can't remove convergent from function declarations. + if (llvm::any_of(SCCNodes, [](Function *F) { return F->isDeclaration(); })) return false; - // If we got here, all of the SCC's callees are non-convergent. Therefore all - // of the SCC's functions can be marked as non-convergent. + // Can't remove convergent if any of our functions has a convergent call to a + // function not in the SCC. + for (Function *F : SCCNodes) + for (Instruction &I : instructions(*F)) { + CallSite CS(&I); + // Bail if is CS a convergent call to a function not in the SCC. + if (CS && CS.isConvergent() && + SCCNodes.count(CS.getCalledFunction()) == 0) + return false; + } + + // If we got here, all of the calls the SCC makes to functions not in the SCC + // are non-convergent. Therefore all of the SCC's functions can also be made + // non-convergent. We'll remove the attr from the callsites in + // InstCombineCalls. for (Function *F : SCCNodes) { if (F->isConvergent()) - DEBUG(dbgs() << "Removing convergent attr from " << F->getName() << "\n"); + DEBUG(dbgs() << "Removing convergent attr from fn " << F->getName() + << "\n"); F->setNotConvergent(); } return true; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 249553a..71199a4 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -2070,7 +2070,15 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { if (!isa<Function>(Callee) && transformConstExprCastCall(CS)) return nullptr; - if (Function *CalleeF = dyn_cast<Function>(Callee)) + if (Function *CalleeF = dyn_cast<Function>(Callee)) { + // Remove the convergent attr on calls when the callee is not convergent. + if (CS.isConvergent() && !CalleeF->isConvergent()) { + DEBUG(dbgs() << "Removing convergent attr from instr " + << CS.getInstruction() << "\n"); + CS.setNotConvergent(); + return CS.getInstruction(); + } + // If the call and callee calling conventions don't match, this call must // be unreachable, as the call is undefined. if (CalleeF->getCallingConv() != CS.getCallingConv() && @@ -2095,6 +2103,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { Constant::getNullValue(CalleeF->getType())); return nullptr; } + } if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) { // If CS does not return void then replaceAllUsesWith undef. |