aboutsummaryrefslogtreecommitdiff
path: root/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
diff options
context:
space:
mode:
authorShraiysh Vaishay <Shraiysh.Vaishay@amd.com>2022-05-24 09:53:33 +0530
committerShraiysh Vaishay <Shraiysh.Vaishay@amd.com>2022-05-24 10:22:11 +0530
commit7604c59bd2336ebb34f28de3e6c883abbdd3f7c7 (patch)
treee8d92aa67de581e27590c35691e1064bbc97b2cd /llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
parent7f680b260ffe34c648cbd3fd16615d8f5cdab39f (diff)
downloadllvm-7604c59bd2336ebb34f28de3e6c883abbdd3f7c7.zip
llvm-7604c59bd2336ebb34f28de3e6c883abbdd3f7c7.tar.gz
llvm-7604c59bd2336ebb34f28de3e6c883abbdd3f7c7.tar.bz2
[OpenMP][IRBuilder] `omp task` support
This patch adds basic support for `omp task` to the OpenMPIRBuilder. The outlined function after code extraction is called from a wrapper function with appropriate arguments. This wrapper function is passed to the runtime calls for task allocation. This approach is different from the Clang approach - clang directly emits the runtime call to the outlined function. The outlining utility (OutlineInfo) simply outlines the code and generates a function call to the outlined function. After the function has been generated by the outlining utility, there is no easy way to alter the function arguments without meddling with the outlining itself. Hence the wrapper function approach is taken. Reviewed By: Meinersbur Differential Revision: https://reviews.llvm.org/D71989
Diffstat (limited to 'llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp')
-rw-r--r--llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp165
1 files changed, 165 insertions, 0 deletions
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 5f7e322..aef5992 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4412,4 +4412,169 @@ TEST_F(OpenMPIRBuilderTest, EmitMapperCall) {
EXPECT_TRUE(MapperCall->getOperand(8)->getType()->isPointerTy());
}
+TEST_F(OpenMPIRBuilderTest, CreateTask) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> Builder(BB);
+
+ AllocaInst *ValPtr32 = Builder.CreateAlloca(Builder.getInt32Ty());
+ AllocaInst *ValPtr128 = Builder.CreateAlloca(Builder.getInt128Ty());
+ Value *Val128 =
+ Builder.CreateLoad(Builder.getInt128Ty(), ValPtr128, "bodygen.load");
+
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+ Builder.restoreIP(AllocaIP);
+ AllocaInst *Local128 = Builder.CreateAlloca(Builder.getInt128Ty(), nullptr,
+ "bodygen.alloca128");
+
+ Builder.restoreIP(CodeGenIP);
+ // Loading and storing captured pointer and values
+ Builder.CreateStore(Val128, Local128);
+ Value *Val32 = Builder.CreateLoad(ValPtr32->getAllocatedType(), ValPtr32,
+ "bodygen.load32");
+
+ LoadInst *PrivLoad128 = Builder.CreateLoad(
+ Local128->getAllocatedType(), Local128, "bodygen.local.load128");
+ Value *Cmp = Builder.CreateICmpNE(
+ Val32, Builder.CreateTrunc(PrivLoad128, Val32->getType()));
+ Instruction *ThenTerm, *ElseTerm;
+ SplitBlockAndInsertIfThenElse(Cmp, CodeGenIP.getBlock()->getTerminator(),
+ &ThenTerm, &ElseTerm);
+ };
+
+ BasicBlock *AllocaBB = Builder.GetInsertBlock();
+ BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+ OpenMPIRBuilder::LocationDescription Loc(
+ InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
+ Builder.restoreIP(OMPBuilder.createTask(
+ Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()),
+ BodyGenCB));
+ OMPBuilder.finalize();
+ Builder.CreateRetVoid();
+
+ EXPECT_FALSE(verifyModule(*M, &errs()));
+
+ CallInst *TaskAllocCall = dyn_cast<CallInst>(
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+ ->user_back());
+
+ // Verify the Ident argument
+ GlobalVariable *Ident = cast<GlobalVariable>(TaskAllocCall->getArgOperand(0));
+ ASSERT_NE(Ident, nullptr);
+ EXPECT_TRUE(Ident->hasInitializer());
+ Constant *Initializer = Ident->getInitializer();
+ GlobalVariable *SrcStrGlob =
+ cast<GlobalVariable>(Initializer->getOperand(4)->stripPointerCasts());
+ ASSERT_NE(SrcStrGlob, nullptr);
+ ConstantDataArray *SrcSrc =
+ dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer());
+ ASSERT_NE(SrcSrc, nullptr);
+
+ // Verify the num_threads argument.
+ CallInst *GTID = dyn_cast<CallInst>(TaskAllocCall->getArgOperand(1));
+ ASSERT_NE(GTID, nullptr);
+ EXPECT_EQ(GTID->arg_size(), 1U);
+ EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
+
+ // Verify the flags
+ // TODO: Check for others flags. Currently testing only for tiedness.
+ ConstantInt *Flags = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(2));
+ ASSERT_NE(Flags, nullptr);
+ EXPECT_EQ(Flags->getSExtValue(), 1);
+
+ // Verify the data size
+ ConstantInt *DataSize =
+ dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(3));
+ ASSERT_NE(DataSize, nullptr);
+ EXPECT_EQ(DataSize->getSExtValue(), 24); // 64-bit pointer + 128-bit integer
+
+ // TODO: Verify size of shared clause variables
+
+ // Verify Wrapper function
+ Function *WrapperFunc =
+ dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
+ ASSERT_NE(WrapperFunc, nullptr);
+ EXPECT_FALSE(WrapperFunc->isDeclaration());
+ CallInst *OutlinedFnCall = dyn_cast<CallInst>(WrapperFunc->begin()->begin());
+ ASSERT_NE(OutlinedFnCall, nullptr);
+ EXPECT_EQ(WrapperFunc->getArg(0)->getType(), Builder.getInt32Ty());
+ EXPECT_EQ(OutlinedFnCall->getArgOperand(0), WrapperFunc->getArg(1));
+
+ // Verify the presence of `trunc` and `icmp` instructions in Outlined function
+ Function *OutlinedFn = OutlinedFnCall->getCalledFunction();
+ ASSERT_NE(OutlinedFn, nullptr);
+ EXPECT_TRUE(any_of(instructions(OutlinedFn),
+ [](Instruction &inst) { return isa<TruncInst>(&inst); }));
+ EXPECT_TRUE(any_of(instructions(OutlinedFn),
+ [](Instruction &inst) { return isa<ICmpInst>(&inst); }));
+
+ // Verify the execution of the task
+ CallInst *TaskCall = dyn_cast<CallInst>(
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task)
+ ->user_back());
+ ASSERT_NE(TaskCall, nullptr);
+ EXPECT_EQ(TaskCall->getArgOperand(0), Ident);
+ EXPECT_EQ(TaskCall->getArgOperand(1), GTID);
+ EXPECT_EQ(TaskCall->getArgOperand(2), TaskAllocCall);
+
+ // Verify that the argument data has been copied
+ for (User *in : TaskAllocCall->users()) {
+ if (MemCpyInst *memCpyInst = dyn_cast<MemCpyInst>(in))
+ EXPECT_EQ(memCpyInst->getDest(), TaskAllocCall);
+ }
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTaskNoArgs) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> Builder(BB);
+
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
+
+ BasicBlock *AllocaBB = Builder.GetInsertBlock();
+ BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+ OpenMPIRBuilder::LocationDescription Loc(
+ InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
+ Builder.restoreIP(OMPBuilder.createTask(
+ Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()),
+ BodyGenCB));
+ OMPBuilder.finalize();
+ Builder.CreateRetVoid();
+
+ EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> Builder(BB);
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
+ BasicBlock *AllocaBB = Builder.GetInsertBlock();
+ BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+ OpenMPIRBuilder::LocationDescription Loc(
+ InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
+ Builder.restoreIP(OMPBuilder.createTask(
+ Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), BodyGenCB,
+ /*Tied=*/false));
+ OMPBuilder.finalize();
+ Builder.CreateRetVoid();
+
+ // Check for the `Tied` argument
+ CallInst *TaskAllocCall = dyn_cast<CallInst>(
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+ ->user_back());
+ ASSERT_NE(TaskAllocCall, nullptr);
+ ConstantInt *Flags = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(2));
+ ASSERT_NE(Flags, nullptr);
+ EXPECT_EQ(Flags->getZExtValue() & 1U, 0U);
+
+ EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
} // namespace