diff options
author | Shraiysh Vaishay <Shraiysh.Vaishay@amd.com> | 2022-05-24 09:53:33 +0530 |
---|---|---|
committer | Shraiysh Vaishay <Shraiysh.Vaishay@amd.com> | 2022-05-24 10:22:11 +0530 |
commit | 7604c59bd2336ebb34f28de3e6c883abbdd3f7c7 (patch) | |
tree | e8d92aa67de581e27590c35691e1064bbc97b2cd /llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp | |
parent | 7f680b260ffe34c648cbd3fd16615d8f5cdab39f (diff) | |
download | llvm-7604c59bd2336ebb34f28de3e6c883abbdd3f7c7.zip llvm-7604c59bd2336ebb34f28de3e6c883abbdd3f7c7.tar.gz llvm-7604c59bd2336ebb34f28de3e6c883abbdd3f7c7.tar.bz2 |
[OpenMP][IRBuilder] `omp task` support
This patch adds basic support for `omp task` to the OpenMPIRBuilder.
The outlined function after code extraction is called from a wrapper function with appropriate arguments. This wrapper function is passed to the runtime calls for task allocation.
This approach is different from the Clang approach - clang directly emits the runtime call to the outlined function. The outlining utility (OutlineInfo) simply outlines the code and generates a function call to the outlined function. After the function has been generated by the outlining utility, there is no easy way to alter the function arguments without meddling with the outlining itself. Hence the wrapper function approach is taken.
Reviewed By: Meinersbur
Differential Revision: https://reviews.llvm.org/D71989
Diffstat (limited to 'llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp')
-rw-r--r-- | llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp | 165 |
1 files changed, 165 insertions, 0 deletions
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 5f7e322..aef5992 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -4412,4 +4412,169 @@ TEST_F(OpenMPIRBuilderTest, EmitMapperCall) { EXPECT_TRUE(MapperCall->getOperand(8)->getType()->isPointerTy()); } +TEST_F(OpenMPIRBuilderTest, CreateTask) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> Builder(BB); + + AllocaInst *ValPtr32 = Builder.CreateAlloca(Builder.getInt32Ty()); + AllocaInst *ValPtr128 = Builder.CreateAlloca(Builder.getInt128Ty()); + Value *Val128 = + Builder.CreateLoad(Builder.getInt128Ty(), ValPtr128, "bodygen.load"); + + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) { + Builder.restoreIP(AllocaIP); + AllocaInst *Local128 = Builder.CreateAlloca(Builder.getInt128Ty(), nullptr, + "bodygen.alloca128"); + + Builder.restoreIP(CodeGenIP); + // Loading and storing captured pointer and values + Builder.CreateStore(Val128, Local128); + Value *Val32 = Builder.CreateLoad(ValPtr32->getAllocatedType(), ValPtr32, + "bodygen.load32"); + + LoadInst *PrivLoad128 = Builder.CreateLoad( + Local128->getAllocatedType(), Local128, "bodygen.local.load128"); + Value *Cmp = Builder.CreateICmpNE( + Val32, Builder.CreateTrunc(PrivLoad128, Val32->getType())); + Instruction *ThenTerm, *ElseTerm; + SplitBlockAndInsertIfThenElse(Cmp, CodeGenIP.getBlock()->getTerminator(), + &ThenTerm, &ElseTerm); + }; + + BasicBlock *AllocaBB = Builder.GetInsertBlock(); + BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split"); + OpenMPIRBuilder::LocationDescription Loc( + InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL); + Builder.restoreIP(OMPBuilder.createTask( + Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), + BodyGenCB)); + OMPBuilder.finalize(); + Builder.CreateRetVoid(); + + EXPECT_FALSE(verifyModule(*M, &errs())); + + CallInst *TaskAllocCall = dyn_cast<CallInst>( + OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc) + ->user_back()); + + // Verify the Ident argument + GlobalVariable *Ident = cast<GlobalVariable>(TaskAllocCall->getArgOperand(0)); + ASSERT_NE(Ident, nullptr); + EXPECT_TRUE(Ident->hasInitializer()); + Constant *Initializer = Ident->getInitializer(); + GlobalVariable *SrcStrGlob = + cast<GlobalVariable>(Initializer->getOperand(4)->stripPointerCasts()); + ASSERT_NE(SrcStrGlob, nullptr); + ConstantDataArray *SrcSrc = + dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer()); + ASSERT_NE(SrcSrc, nullptr); + + // Verify the num_threads argument. + CallInst *GTID = dyn_cast<CallInst>(TaskAllocCall->getArgOperand(1)); + ASSERT_NE(GTID, nullptr); + EXPECT_EQ(GTID->arg_size(), 1U); + EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num"); + + // Verify the flags + // TODO: Check for others flags. Currently testing only for tiedness. + ConstantInt *Flags = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(2)); + ASSERT_NE(Flags, nullptr); + EXPECT_EQ(Flags->getSExtValue(), 1); + + // Verify the data size + ConstantInt *DataSize = + dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(3)); + ASSERT_NE(DataSize, nullptr); + EXPECT_EQ(DataSize->getSExtValue(), 24); // 64-bit pointer + 128-bit integer + + // TODO: Verify size of shared clause variables + + // Verify Wrapper function + Function *WrapperFunc = + dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts()); + ASSERT_NE(WrapperFunc, nullptr); + EXPECT_FALSE(WrapperFunc->isDeclaration()); + CallInst *OutlinedFnCall = dyn_cast<CallInst>(WrapperFunc->begin()->begin()); + ASSERT_NE(OutlinedFnCall, nullptr); + EXPECT_EQ(WrapperFunc->getArg(0)->getType(), Builder.getInt32Ty()); + EXPECT_EQ(OutlinedFnCall->getArgOperand(0), WrapperFunc->getArg(1)); + + // Verify the presence of `trunc` and `icmp` instructions in Outlined function + Function *OutlinedFn = OutlinedFnCall->getCalledFunction(); + ASSERT_NE(OutlinedFn, nullptr); + EXPECT_TRUE(any_of(instructions(OutlinedFn), + [](Instruction &inst) { return isa<TruncInst>(&inst); })); + EXPECT_TRUE(any_of(instructions(OutlinedFn), + [](Instruction &inst) { return isa<ICmpInst>(&inst); })); + + // Verify the execution of the task + CallInst *TaskCall = dyn_cast<CallInst>( + OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task) + ->user_back()); + ASSERT_NE(TaskCall, nullptr); + EXPECT_EQ(TaskCall->getArgOperand(0), Ident); + EXPECT_EQ(TaskCall->getArgOperand(1), GTID); + EXPECT_EQ(TaskCall->getArgOperand(2), TaskAllocCall); + + // Verify that the argument data has been copied + for (User *in : TaskAllocCall->users()) { + if (MemCpyInst *memCpyInst = dyn_cast<MemCpyInst>(in)) + EXPECT_EQ(memCpyInst->getDest(), TaskAllocCall); + } +} + +TEST_F(OpenMPIRBuilderTest, CreateTaskNoArgs) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> Builder(BB); + + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {}; + + BasicBlock *AllocaBB = Builder.GetInsertBlock(); + BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split"); + OpenMPIRBuilder::LocationDescription Loc( + InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL); + Builder.restoreIP(OMPBuilder.createTask( + Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), + BodyGenCB)); + OMPBuilder.finalize(); + Builder.CreateRetVoid(); + + EXPECT_FALSE(verifyModule(*M, &errs())); +} + +TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> Builder(BB); + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {}; + BasicBlock *AllocaBB = Builder.GetInsertBlock(); + BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split"); + OpenMPIRBuilder::LocationDescription Loc( + InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL); + Builder.restoreIP(OMPBuilder.createTask( + Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), BodyGenCB, + /*Tied=*/false)); + OMPBuilder.finalize(); + Builder.CreateRetVoid(); + + // Check for the `Tied` argument + CallInst *TaskAllocCall = dyn_cast<CallInst>( + OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc) + ->user_back()); + ASSERT_NE(TaskAllocCall, nullptr); + ConstantInt *Flags = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(2)); + ASSERT_NE(Flags, nullptr); + EXPECT_EQ(Flags->getZExtValue() & 1U, 0U); + + EXPECT_FALSE(verifyModule(*M, &errs())); +} + } // namespace |