aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h7
-rw-r--r--llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp123
-rw-r--r--llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp81
3 files changed, 211 insertions, 0 deletions
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 523a071..1699ed3 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1889,6 +1889,13 @@ public:
BodyGenCallbackTy BodyGenCB,
FinalizeCallbackTy FiniCB);
+ /// Generator for `#omp teams`
+ ///
+ /// \param Loc The location where the teams construct was encountered.
+ /// \param BodyGenCB Callback that will generate the region code.
+ InsertPointTy createTeams(const LocationDescription &Loc,
+ BodyGenCallbackTy BodyGenCB);
+
/// Generate conditional branch and relevant BasicBlocks through which private
/// threads copy the 'copyin' variables from Master copy to threadprivate
/// copies.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 1ace7d5..8a5f4cb 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5735,6 +5735,129 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
return Builder.saveIP();
}
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
+ BodyGenCallbackTy BodyGenCB) {
+ if (!updateToLocation(Loc))
+ return InsertPointTy();
+
+ uint32_t SrcLocStrSize;
+ Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+ Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+ Function *CurrentFunction = Builder.GetInsertBlock()->getParent();
+
+ // Outer allocation basicblock is the entry block of the current function.
+ BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
+ if (&OuterAllocaBB == Builder.GetInsertBlock()) {
+ BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry");
+ Builder.SetInsertPoint(BodyBB, BodyBB->begin());
+ }
+
+ // The current basic block is split into four basic blocks. After outlining,
+ // they will be mapped as follows:
+ // ```
+ // def current_fn() {
+ // current_basic_block:
+ // br label %teams.exit
+ // teams.exit:
+ // ; instructions after teams
+ // }
+ //
+ // def outlined_fn() {
+ // teams.alloca:
+ // br label %teams.body
+ // teams.body:
+ // ; instructions within teams body
+ // }
+ // ```
+ BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, "teams.exit");
+ BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.body");
+ BasicBlock *AllocaBB =
+ splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
+
+ OutlineInfo OI;
+ OI.EntryBB = AllocaBB;
+ OI.ExitBB = ExitBB;
+ OI.OuterAllocaBB = &OuterAllocaBB;
+ OI.PostOutlineCB = [this, Ident](Function &OutlinedFn) {
+ // The input IR here looks like the following-
+ // ```
+ // func @current_fn() {
+ // outlined_fn(%args)
+ // }
+ // func @outlined_fn(%args) { ... }
+ // ```
+ //
+ // This is changed to the following-
+ //
+ // ```
+ // func @current_fn() {
+ // runtime_call(..., wrapper_fn, ...)
+ // }
+ // func @wrapper_fn(..., %args) {
+ // outlined_fn(%args)
+ // }
+ // func @outlined_fn(%args) { ... }
+ // ```
+
+ // The stale call instruction will be replaced with a new call instruction
+ // for runtime call with a wrapper function.
+
+ assert(OutlinedFn.getNumUses() == 1 &&
+ "there must be a single user for the outlined function");
+ CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
+
+ // Create the wrapper function.
+ SmallVector<Type *> WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()};
+ for (auto &Arg : OutlinedFn.args())
+ WrapperArgTys.push_back(Arg.getType());
+ FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
+ (Twine(OutlinedFn.getName()) + ".teams").str(),
+ FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false));
+ Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
+ WrapperFunc->getArg(0)->setName("global_tid");
+ WrapperFunc->getArg(1)->setName("bound_tid");
+ if (WrapperFunc->arg_size() > 2)
+ WrapperFunc->getArg(2)->setName("data");
+
+ // Emit the body of the wrapper function - just a call to outlined function
+ // and return statement.
+ BasicBlock *WrapperEntryBB =
+ BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc);
+ Builder.SetInsertPoint(WrapperEntryBB);
+ SmallVector<Value *> Args;
+ for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++)
+ Args.push_back(WrapperFunc->getArg(ArgIndex));
+ Builder.CreateCall(&OutlinedFn, Args);
+ Builder.CreateRetVoid();
+
+ OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline);
+
+ // Call to the runtime function for teams in the current function.
+ assert(StaleCI && "Error while outlining - no CallInst user found for the "
+ "outlined function.");
+ Builder.SetInsertPoint(StaleCI);
+ Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc};
+ for (Use &Arg : StaleCI->args())
+ Args.push_back(Arg);
+ Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
+ omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
+ Args);
+ StaleCI->eraseFromParent();
+ };
+
+ addOutlineInfo(std::move(OI));
+
+ // Generate the body of teams.
+ InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
+ InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
+ BodyGenCB(AllocaIP, CodeGenIP);
+
+ Builder.SetInsertPoint(ExitBB, ExitBB->begin());
+
+ return Builder.saveIP();
+}
+
GlobalVariable *
OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
std::string VarName) {
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 2026824..fd524f6 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4001,6 +4001,87 @@ TEST_F(OpenMPIRBuilderTest, OMPAtomicCompareCapture) {
EXPECT_FALSE(verifyModule(*M, &errs()));
}
+TEST_F(OpenMPIRBuilderTest, CreateTeams) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> Builder(BB);
+
+ AllocaInst *ValPtr32 = Builder.CreateAlloca(Builder.getInt32Ty());
+ AllocaInst *ValPtr128 = Builder.CreateAlloca(Builder.getInt128Ty());
+ Value *Val128 = Builder.CreateLoad(Builder.getInt128Ty(), ValPtr128, "load");
+
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+ Builder.restoreIP(AllocaIP);
+ AllocaInst *Local128 = Builder.CreateAlloca(Builder.getInt128Ty(), nullptr,
+ "bodygen.alloca128");
+
+ Builder.restoreIP(CodeGenIP);
+ // Loading and storing captured pointer and values
+ Builder.CreateStore(Val128, Local128);
+ Value *Val32 = Builder.CreateLoad(ValPtr32->getAllocatedType(), ValPtr32,
+ "bodygen.load32");
+
+ LoadInst *PrivLoad128 = Builder.CreateLoad(
+ Local128->getAllocatedType(), Local128, "bodygen.local.load128");
+ Value *Cmp = Builder.CreateICmpNE(
+ Val32, Builder.CreateTrunc(PrivLoad128, Val32->getType()));
+ Instruction *ThenTerm, *ElseTerm;
+ SplitBlockAndInsertIfThenElse(Cmp, CodeGenIP.getBlock()->getTerminator(),
+ &ThenTerm, &ElseTerm);
+ };
+
+ OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+ Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB));
+
+ OMPBuilder.finalize();
+ Builder.CreateRetVoid();
+
+ EXPECT_FALSE(verifyModule(*M, &errs()));
+
+ CallInst *TeamsForkCall = dyn_cast<CallInst>(
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams)
+ ->user_back());
+
+ // Verify the Ident argument
+ GlobalVariable *Ident = cast<GlobalVariable>(TeamsForkCall->getArgOperand(0));
+ ASSERT_NE(Ident, nullptr);
+ EXPECT_TRUE(Ident->hasInitializer());
+ Constant *Initializer = Ident->getInitializer();
+ GlobalVariable *SrcStrGlob =
+ cast<GlobalVariable>(Initializer->getOperand(4)->stripPointerCasts());
+ ASSERT_NE(SrcStrGlob, nullptr);
+ ConstantDataArray *SrcSrc =
+ dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer());
+ ASSERT_NE(SrcSrc, nullptr);
+
+ // Verify the outlined function signature.
+ Function *WrapperFn =
+ dyn_cast<Function>(TeamsForkCall->getArgOperand(2)->stripPointerCasts());
+ ASSERT_NE(WrapperFn, nullptr);
+ EXPECT_FALSE(WrapperFn->isDeclaration());
+ EXPECT_TRUE(WrapperFn->arg_size() >= 3);
+ EXPECT_EQ(WrapperFn->getArg(0)->getType(), Builder.getPtrTy()); // global_tid
+ EXPECT_EQ(WrapperFn->getArg(1)->getType(), Builder.getPtrTy()); // bound_tid
+ EXPECT_EQ(WrapperFn->getArg(2)->getType(),
+ Builder.getPtrTy()); // captured args
+
+ // Check for TruncInst and ICmpInst in the outlined function.
+ inst_range Instructions = instructions(WrapperFn);
+ auto OutlinedFnInst = find_if(
+ Instructions, [](Instruction &Inst) { return isa<CallInst>(&Inst); });
+ ASSERT_NE(OutlinedFnInst, Instructions.end());
+ CallInst *OutlinedFnCI = dyn_cast<CallInst>(&*OutlinedFnInst);
+ ASSERT_NE(OutlinedFnCI, nullptr);
+ Function *OutlinedFn = OutlinedFnCI->getCalledFunction();
+
+ EXPECT_TRUE(any_of(instructions(OutlinedFn),
+ [](Instruction &inst) { return isa<TruncInst>(&inst); }));
+ EXPECT_TRUE(any_of(instructions(OutlinedFn),
+ [](Instruction &inst) { return isa<ICmpInst>(&inst); }));
+}
+
/// Returns the single instruction of InstTy type in BB that uses the value V.
/// If there is more than one such instruction, returns null.
template <typename InstTy>