aboutsummaryrefslogtreecommitdiff
path: root/clang/lib/CodeGen
diff options
context:
space:
mode:
authorSameer Sahasrabuddhe <sameer.sahasrabuddhe@amd.com>2025-01-06 21:34:11 +0530
committerGitHub <noreply@github.com>2025-01-06 21:34:11 +0530
commitdf67e37e37a7862e1e67f52e01f0c9a019477930 (patch)
treed05a362f6a80a64d565e79833639dd070f772b30 /clang/lib/CodeGen
parent4ebfd43cf008b941d88a61a2c549e9a5291ee017 (diff)
downloadllvm-df67e37e37a7862e1e67f52e01f0c9a019477930.zip
llvm-df67e37e37a7862e1e67f52e01f0c9a019477930.tar.gz
llvm-df67e37e37a7862e1e67f52e01f0c9a019477930.tar.bz2
[clang][NFC] clean up the handling of convergence control tokens (#121738)
Diffstat (limited to 'clang/lib/CodeGen')
-rw-r--r--clang/lib/CodeGen/CGCall.cpp4
-rw-r--r--clang/lib/CodeGen/CGStmt.cpp46
-rw-r--r--clang/lib/CodeGen/CodeGenFunction.h23
3 files changed, 29 insertions, 44 deletions
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index f139c30..89e2eac 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4871,7 +4871,7 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
call->setCallingConv(getRuntimeCC());
if (CGM.shouldEmitConvergenceTokens() && call->isConvergent())
- return addControlledConvergenceToken(call);
+ return cast<llvm::CallInst>(addConvergenceControlToken(call));
return call;
}
@@ -5787,7 +5787,7 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
CI->setName("call");
if (CGM.shouldEmitConvergenceTokens() && CI->isConvergent())
- CI = addControlledConvergenceToken(CI);
+ CI = addConvergenceControlToken(CI);
// Update largest vector width from the return type.
LargestVectorWidth =
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 3974739..7904e17 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -1024,8 +1024,8 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
EmitBlock(LoopHeader.getBlock());
if (CGM.shouldEmitConvergenceTokens())
- ConvergenceTokenStack.push_back(emitConvergenceLoopToken(
- LoopHeader.getBlock(), ConvergenceTokenStack.back()));
+ ConvergenceTokenStack.push_back(
+ emitConvergenceLoopToken(LoopHeader.getBlock()));
// Create an exit block for when the condition fails, which will
// also become the break target.
@@ -1152,8 +1152,7 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
EmitBlockWithFallThrough(LoopBody, &S);
if (CGM.shouldEmitConvergenceTokens())
- ConvergenceTokenStack.push_back(
- emitConvergenceLoopToken(LoopBody, ConvergenceTokenStack.back()));
+ ConvergenceTokenStack.push_back(emitConvergenceLoopToken(LoopBody));
{
RunCleanupsScope BodyScope(*this);
@@ -1231,8 +1230,7 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
EmitBlock(CondBlock);
if (CGM.shouldEmitConvergenceTokens())
- ConvergenceTokenStack.push_back(
- emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
+ ConvergenceTokenStack.push_back(emitConvergenceLoopToken(CondBlock));
const SourceRange &R = S.getSourceRange();
LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
@@ -1369,8 +1367,7 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
EmitBlock(CondBlock);
if (CGM.shouldEmitConvergenceTokens())
- ConvergenceTokenStack.push_back(
- emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
+ ConvergenceTokenStack.push_back(emitConvergenceLoopToken(CondBlock));
const SourceRange &R = S.getSourceRange();
LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
@@ -3245,35 +3242,32 @@ CodeGenFunction::GenerateCapturedStmtFunction(const CapturedStmt &S) {
return F;
}
-namespace {
// Returns the first convergence entry/loop/anchor instruction found in |BB|.
// std::nullptr otherwise.
-llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
+static llvm::ConvergenceControlInst *getConvergenceToken(llvm::BasicBlock *BB) {
for (auto &I : *BB) {
- auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
- if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID()))
- return II;
+ if (auto *CI = dyn_cast<llvm::ConvergenceControlInst>(&I))
+ return CI;
}
return nullptr;
}
-} // namespace
-
llvm::CallBase *
-CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
- llvm::Value *ParentToken) {
+CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input) {
+ llvm::ConvergenceControlInst *ParentToken = ConvergenceTokenStack.back();
+ assert(ParentToken);
+
llvm::Value *bundleArgs[] = {ParentToken};
llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
- auto Output = llvm::CallBase::addOperandBundle(
+ auto *Output = llvm::CallBase::addOperandBundle(
Input, llvm::LLVMContext::OB_convergencectrl, OB, Input->getIterator());
Input->replaceAllUsesWith(Output);
Input->eraseFromParent();
return Output;
}
-llvm::IntrinsicInst *
-CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
- llvm::Value *ParentToken) {
+llvm::ConvergenceControlInst *
+CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB) {
CGBuilderTy::InsertPoint IP = Builder.saveIP();
if (BB->empty())
Builder.SetInsertPoint(BB);
@@ -3284,14 +3278,14 @@ CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
llvm::Intrinsic::experimental_convergence_loop, {}, {});
Builder.restoreIP(IP);
- llvm::CallBase *I = addConvergenceControlToken(CB, ParentToken);
- return cast<llvm::IntrinsicInst>(I);
+ CB = addConvergenceControlToken(CB);
+ return cast<llvm::ConvergenceControlInst>(CB);
}
-llvm::IntrinsicInst *
+llvm::ConvergenceControlInst *
CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
llvm::BasicBlock *BB = &F->getEntryBlock();
- llvm::IntrinsicInst *Token = getConvergenceToken(BB);
+ llvm::ConvergenceControlInst *Token = getConvergenceToken(BB);
if (Token)
return Token;
@@ -3306,5 +3300,5 @@ CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
assert(isa<llvm::IntrinsicInst>(I));
Builder.restoreIP(IP);
- return cast<llvm::IntrinsicInst>(I);
+ return cast<llvm::ConvergenceControlInst>(I);
}
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 1a5c42f..46f2679 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -315,7 +315,7 @@ public:
SmallVector<const BinaryOperator *, 16> MCDCLogOpStack;
/// Stack to track the controlled convergence tokens.
- SmallVector<llvm::IntrinsicInst *, 4> ConvergenceTokenStack;
+ SmallVector<llvm::ConvergenceControlInst *, 4> ConvergenceTokenStack;
/// Number of nested loop to be consumed by the last surrounding
/// loop-associated directive.
@@ -5234,29 +5234,20 @@ public:
llvm::Value *emitBoolVecConversion(llvm::Value *SrcVec,
unsigned NumElementsDst,
const llvm::Twine &Name = "");
- // Adds a convergence_ctrl token to |Input| and emits the required parent
- // convergence instructions.
- template <typename CallType>
- CallType *addControlledConvergenceToken(CallType *Input) {
- return cast<CallType>(
- addConvergenceControlToken(Input, ConvergenceTokenStack.back()));
- }
private:
// Emits a convergence_loop instruction for the given |BB|, with |ParentToken|
// as it's parent convergence instr.
- llvm::IntrinsicInst *emitConvergenceLoopToken(llvm::BasicBlock *BB,
- llvm::Value *ParentToken);
+ llvm::ConvergenceControlInst *emitConvergenceLoopToken(llvm::BasicBlock *BB);
+
// Adds a convergence_ctrl token with |ParentToken| as parent convergence
// instr to the call |Input|.
- llvm::CallBase *addConvergenceControlToken(llvm::CallBase *Input,
- llvm::Value *ParentToken);
+ llvm::CallBase *addConvergenceControlToken(llvm::CallBase *Input);
+
// Find the convergence_entry instruction |F|, or emits ones if none exists.
// Returns the convergence instruction.
- llvm::IntrinsicInst *getOrEmitConvergenceEntryToken(llvm::Function *F);
- // Find the convergence_loop instruction for the loop defined by |LI|, or
- // emits one if none exists. Returns the convergence instruction.
- llvm::IntrinsicInst *getOrEmitConvergenceLoopToken(const LoopInfo *LI);
+ llvm::ConvergenceControlInst *
+ getOrEmitConvergenceEntryToken(llvm::Function *F);
private:
llvm::MDNode *getRangeForLoadFromType(QualType Ty);