aboutsummaryrefslogtreecommitdiff
path: root/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp')
-rw-r--r--llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp65
1 files changed, 65 insertions, 0 deletions
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index aa120c1..92a118b 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5040,6 +5040,71 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
EXPECT_FALSE(verifyModule(*M, &errs()));
}
+TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> Builder(BB);
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
+ IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
+ BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+ Builder.SetInsertPoint(BodyBB);
+ Value *IfCondition = Builder.CreateICmp(
+ CmpInst::Predicate::ICMP_EQ, F->getArg(0),
+ ConstantInt::get(Type::getInt32Ty(M->getContext()), 0U));
+ OpenMPIRBuilder::LocationDescription Loc(Builder.saveIP(), DL);
+ Builder.restoreIP(OMPBuilder.createTask(Loc, AllocaIP, BodyGenCB,
+ /*Tied=*/false, /*Final=*/nullptr,
+ IfCondition));
+ OMPBuilder.finalize();
+ Builder.CreateRetVoid();
+
+ EXPECT_FALSE(verifyModule(*M, &errs()));
+
+ CallInst *TaskAllocCall = dyn_cast<CallInst>(
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+ ->user_back());
+ ASSERT_NE(TaskAllocCall, nullptr);
+
+ // Check the branching is based on the if condition argument.
+ BranchInst *IfConditionBranchInst =
+ dyn_cast<BranchInst>(TaskAllocCall->getParent()->getTerminator());
+ ASSERT_NE(IfConditionBranchInst, nullptr);
+ ASSERT_TRUE(IfConditionBranchInst->isConditional());
+ EXPECT_EQ(IfConditionBranchInst->getCondition(), IfCondition);
+
+ // Check that the `__kmpc_omp_task` executes only in the then branch.
+ CallInst *TaskCall = dyn_cast<CallInst>(
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task)
+ ->user_back());
+ ASSERT_NE(TaskCall, nullptr);
+ EXPECT_EQ(TaskCall->getParent(), IfConditionBranchInst->getSuccessor(0));
+
+ // Check that the OpenMP Runtime Functions specific to `if` clause execute
+ // only in the else branch. Also check that the function call is between the
+ // `__kmpc_omp_task_begin_if0` and `__kmpc_omp_task_complete_if0` calls.
+ CallInst *TaskBeginIfCall = dyn_cast<CallInst>(
+ OMPBuilder
+ .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0)
+ ->user_back());
+ CallInst *TaskCompleteCall = dyn_cast<CallInst>(
+ OMPBuilder
+ .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0)
+ ->user_back());
+ ASSERT_NE(TaskBeginIfCall, nullptr);
+ ASSERT_NE(TaskCompleteCall, nullptr);
+ Function *WrapperFunc =
+ dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
+ ASSERT_NE(WrapperFunc, nullptr);
+ CallInst *WrapperFuncCall = dyn_cast<CallInst>(WrapperFunc->user_back());
+ ASSERT_NE(WrapperFuncCall, nullptr);
+ EXPECT_EQ(TaskBeginIfCall->getParent(),
+ IfConditionBranchInst->getSuccessor(1));
+ EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), WrapperFuncCall);
+ EXPECT_EQ(WrapperFuncCall->getNextNonDebugInstruction(), TaskCompleteCall);
+}
+
TEST_F(OpenMPIRBuilderTest, CreateTaskgroup) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
OpenMPIRBuilder OMPBuilder(*M);