diff options
-rw-r--r-- | llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h | 7 | ||||
-rw-r--r-- | llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 123 | ||||
-rw-r--r-- | llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp | 81 |
3 files changed, 211 insertions, 0 deletions
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 523a071..1699ed3 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -1889,6 +1889,13 @@ public: BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB); + /// Generator for `#omp teams` + /// + /// \param Loc The location where the teams construct was encountered. + /// \param BodyGenCB Callback that will generate the region code. + InsertPointTy createTeams(const LocationDescription &Loc, + BodyGenCallbackTy BodyGenCB); + /// Generate conditional branch and relevant BasicBlocks through which private /// threads copy the 'copyin' variables from Master copy to threadprivate /// copies. diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 1ace7d5..8a5f4cb 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -5735,6 +5735,129 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare( return Builder.saveIP(); } +OpenMPIRBuilder::InsertPointTy +OpenMPIRBuilder::createTeams(const LocationDescription &Loc, + BodyGenCallbackTy BodyGenCB) { + if (!updateToLocation(Loc)) + return InsertPointTy(); + + uint32_t SrcLocStrSize; + Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize); + Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize); + Function *CurrentFunction = Builder.GetInsertBlock()->getParent(); + + // Outer allocation basicblock is the entry block of the current function. + BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock(); + if (&OuterAllocaBB == Builder.GetInsertBlock()) { + BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry"); + Builder.SetInsertPoint(BodyBB, BodyBB->begin()); + } + + // The current basic block is split into four basic blocks. After outlining, + // they will be mapped as follows: + // ``` + // def current_fn() { + // current_basic_block: + // br label %teams.exit + // teams.exit: + // ; instructions after teams + // } + // + // def outlined_fn() { + // teams.alloca: + // br label %teams.body + // teams.body: + // ; instructions within teams body + // } + // ``` + BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, "teams.exit"); + BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.body"); + BasicBlock *AllocaBB = + splitBB(Builder, /*CreateBranch=*/true, "teams.alloca"); + + OutlineInfo OI; + OI.EntryBB = AllocaBB; + OI.ExitBB = ExitBB; + OI.OuterAllocaBB = &OuterAllocaBB; + OI.PostOutlineCB = [this, Ident](Function &OutlinedFn) { + // The input IR here looks like the following- + // ``` + // func @current_fn() { + // outlined_fn(%args) + // } + // func @outlined_fn(%args) { ... } + // ``` + // + // This is changed to the following- + // + // ``` + // func @current_fn() { + // runtime_call(..., wrapper_fn, ...) + // } + // func @wrapper_fn(..., %args) { + // outlined_fn(%args) + // } + // func @outlined_fn(%args) { ... } + // ``` + + // The stale call instruction will be replaced with a new call instruction + // for runtime call with a wrapper function. + + assert(OutlinedFn.getNumUses() == 1 && + "there must be a single user for the outlined function"); + CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back()); + + // Create the wrapper function. + SmallVector<Type *> WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()}; + for (auto &Arg : OutlinedFn.args()) + WrapperArgTys.push_back(Arg.getType()); + FunctionCallee WrapperFuncVal = M.getOrInsertFunction( + (Twine(OutlinedFn.getName()) + ".teams").str(), + FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false)); + Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee()); + WrapperFunc->getArg(0)->setName("global_tid"); + WrapperFunc->getArg(1)->setName("bound_tid"); + if (WrapperFunc->arg_size() > 2) + WrapperFunc->getArg(2)->setName("data"); + + // Emit the body of the wrapper function - just a call to outlined function + // and return statement. + BasicBlock *WrapperEntryBB = + BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc); + Builder.SetInsertPoint(WrapperEntryBB); + SmallVector<Value *> Args; + for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++) + Args.push_back(WrapperFunc->getArg(ArgIndex)); + Builder.CreateCall(&OutlinedFn, Args); + Builder.CreateRetVoid(); + + OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline); + + // Call to the runtime function for teams in the current function. + assert(StaleCI && "Error while outlining - no CallInst user found for the " + "outlined function."); + Builder.SetInsertPoint(StaleCI); + Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc}; + for (Use &Arg : StaleCI->args()) + Args.push_back(Arg); + Builder.CreateCall(getOrCreateRuntimeFunctionPtr( + omp::RuntimeFunction::OMPRTL___kmpc_fork_teams), + Args); + StaleCI->eraseFromParent(); + }; + + addOutlineInfo(std::move(OI)); + + // Generate the body of teams. + InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin()); + InsertPointTy CodeGenIP(BodyBB, BodyBB->begin()); + BodyGenCB(AllocaIP, CodeGenIP); + + Builder.SetInsertPoint(ExitBB, ExitBB->begin()); + + return Builder.saveIP(); +} + GlobalVariable * OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names, std::string VarName) { diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 2026824..fd524f6 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -4001,6 +4001,87 @@ TEST_F(OpenMPIRBuilderTest, OMPAtomicCompareCapture) { EXPECT_FALSE(verifyModule(*M, &errs())); } +TEST_F(OpenMPIRBuilderTest, CreateTeams) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> Builder(BB); + + AllocaInst *ValPtr32 = Builder.CreateAlloca(Builder.getInt32Ty()); + AllocaInst *ValPtr128 = Builder.CreateAlloca(Builder.getInt128Ty()); + Value *Val128 = Builder.CreateLoad(Builder.getInt128Ty(), ValPtr128, "load"); + + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) { + Builder.restoreIP(AllocaIP); + AllocaInst *Local128 = Builder.CreateAlloca(Builder.getInt128Ty(), nullptr, + "bodygen.alloca128"); + + Builder.restoreIP(CodeGenIP); + // Loading and storing captured pointer and values + Builder.CreateStore(Val128, Local128); + Value *Val32 = Builder.CreateLoad(ValPtr32->getAllocatedType(), ValPtr32, + "bodygen.load32"); + + LoadInst *PrivLoad128 = Builder.CreateLoad( + Local128->getAllocatedType(), Local128, "bodygen.local.load128"); + Value *Cmp = Builder.CreateICmpNE( + Val32, Builder.CreateTrunc(PrivLoad128, Val32->getType())); + Instruction *ThenTerm, *ElseTerm; + SplitBlockAndInsertIfThenElse(Cmp, CodeGenIP.getBlock()->getTerminator(), + &ThenTerm, &ElseTerm); + }; + + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB)); + + OMPBuilder.finalize(); + Builder.CreateRetVoid(); + + EXPECT_FALSE(verifyModule(*M, &errs())); + + CallInst *TeamsForkCall = dyn_cast<CallInst>( + OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams) + ->user_back()); + + // Verify the Ident argument + GlobalVariable *Ident = cast<GlobalVariable>(TeamsForkCall->getArgOperand(0)); + ASSERT_NE(Ident, nullptr); + EXPECT_TRUE(Ident->hasInitializer()); + Constant *Initializer = Ident->getInitializer(); + GlobalVariable *SrcStrGlob = + cast<GlobalVariable>(Initializer->getOperand(4)->stripPointerCasts()); + ASSERT_NE(SrcStrGlob, nullptr); + ConstantDataArray *SrcSrc = + dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer()); + ASSERT_NE(SrcSrc, nullptr); + + // Verify the outlined function signature. + Function *WrapperFn = + dyn_cast<Function>(TeamsForkCall->getArgOperand(2)->stripPointerCasts()); + ASSERT_NE(WrapperFn, nullptr); + EXPECT_FALSE(WrapperFn->isDeclaration()); + EXPECT_TRUE(WrapperFn->arg_size() >= 3); + EXPECT_EQ(WrapperFn->getArg(0)->getType(), Builder.getPtrTy()); // global_tid + EXPECT_EQ(WrapperFn->getArg(1)->getType(), Builder.getPtrTy()); // bound_tid + EXPECT_EQ(WrapperFn->getArg(2)->getType(), + Builder.getPtrTy()); // captured args + + // Check for TruncInst and ICmpInst in the outlined function. + inst_range Instructions = instructions(WrapperFn); + auto OutlinedFnInst = find_if( + Instructions, [](Instruction &Inst) { return isa<CallInst>(&Inst); }); + ASSERT_NE(OutlinedFnInst, Instructions.end()); + CallInst *OutlinedFnCI = dyn_cast<CallInst>(&*OutlinedFnInst); + ASSERT_NE(OutlinedFnCI, nullptr); + Function *OutlinedFn = OutlinedFnCI->getCalledFunction(); + + EXPECT_TRUE(any_of(instructions(OutlinedFn), + [](Instruction &inst) { return isa<TruncInst>(&inst); })); + EXPECT_TRUE(any_of(instructions(OutlinedFn), + [](Instruction &inst) { return isa<ICmpInst>(&inst); })); +} + /// Returns the single instruction of InstTy type in BB that uses the value V. /// If there is more than one such instruction, returns null. template <typename InstTy> |