aboutsummaryrefslogtreecommitdiff
path: root/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp')
-rw-r--r--llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp146
1 files changed, 140 insertions, 6 deletions
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index d770fac..97cfc33 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4033,7 +4033,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) {
};
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
- Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB));
+ Builder.restoreIP(OMPBuilder.createTeams(
+ Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr,
+ /*ThreadLimit=*/nullptr, /*IfExpr=*/nullptr));
OMPBuilder.finalize();
Builder.CreateRetVoid();
@@ -4095,7 +4097,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
Builder.restoreIP(OMPBuilder.createTeams(/*=*/Builder, BodyGenCB,
/*NumTeamsLower=*/nullptr,
/*NumTeamsUpper=*/nullptr,
- /*ThreadLimit=*/F->arg_begin()));
+ /*ThreadLimit=*/F->arg_begin(),
+ /*IfExpr=*/nullptr));
Builder.CreateRetVoid();
OMPBuilder.finalize();
@@ -4144,7 +4147,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsUpper) {
// `num_teams`
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB,
/*NumTeamsLower=*/nullptr,
- /*NumTeamsUpper=*/F->arg_begin()));
+ /*NumTeamsUpper=*/F->arg_begin(),
+ /*ThreadLimit=*/nullptr,
+ /*IfExpr=*/nullptr));
Builder.CreateRetVoid();
OMPBuilder.finalize();
@@ -4197,7 +4202,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsBoth) {
// `F` already has an integer argument, so we use that as upper bound to
// `num_teams`
Builder.restoreIP(
- OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper));
+ OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper,
+ /*ThreadLimit=*/nullptr, /*IfExpr=*/nullptr));
Builder.CreateRetVoid();
OMPBuilder.finalize();
@@ -4255,8 +4261,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
};
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
- Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower,
- NumTeamsUpper, ThreadLimit));
+ Builder.restoreIP(OMPBuilder.createTeams(
+ Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper, ThreadLimit, nullptr));
Builder.CreateRetVoid();
OMPBuilder.finalize();
@@ -4284,6 +4290,134 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));
}
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithIfCondition) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> &Builder = OMPBuilder.Builder;
+ Builder.SetInsertPoint(BB);
+
+ Value *IfExpr = Builder.CreateLoad(Builder.getInt1Ty(),
+ Builder.CreateAlloca(Builder.getInt1Ty()));
+
+ Function *FakeFunction =
+ Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ Builder.CreateCall(FakeFunction, {});
+ };
+
+ // `F` already has an integer argument, so we use that as upper bound to
+ // `num_teams`
+ Builder.restoreIP(OMPBuilder.createTeams(
+ Builder, BodyGenCB, /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr,
+ /*ThreadLimit=*/nullptr, IfExpr));
+
+ Builder.CreateRetVoid();
+ OMPBuilder.finalize();
+
+ ASSERT_FALSE(verifyModule(*M));
+
+ CallInst *PushNumTeamsCallInst =
+ findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
+ ASSERT_NE(PushNumTeamsCallInst, nullptr);
+ Value *NumTeamsLower = PushNumTeamsCallInst->getArgOperand(2);
+ Value *NumTeamsUpper = PushNumTeamsCallInst->getArgOperand(3);
+ Value *ThreadLimit = PushNumTeamsCallInst->getArgOperand(4);
+
+ // Check the lower_bound
+ ASSERT_NE(NumTeamsLower, nullptr);
+ SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLower);
+ ASSERT_NE(NumTeamsLowerSelectInst, nullptr);
+ EXPECT_EQ(NumTeamsLowerSelectInst->getCondition(), IfExpr);
+ EXPECT_EQ(NumTeamsLowerSelectInst->getTrueValue(), Builder.getInt32(0));
+ EXPECT_EQ(NumTeamsLowerSelectInst->getFalseValue(), Builder.getInt32(1));
+
+ // Check the upper_bound
+ ASSERT_NE(NumTeamsUpper, nullptr);
+ SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpper);
+ ASSERT_NE(NumTeamsUpperSelectInst, nullptr);
+ EXPECT_EQ(NumTeamsUpperSelectInst->getCondition(), IfExpr);
+ EXPECT_EQ(NumTeamsUpperSelectInst->getTrueValue(), Builder.getInt32(0));
+ EXPECT_EQ(NumTeamsUpperSelectInst->getFalseValue(), Builder.getInt32(1));
+
+ // Check thread_limit
+ EXPECT_EQ(ThreadLimit, Builder.getInt32(0));
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTeamsWithIfConditionAndNumTeams) {
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+ F->setName("func");
+ IRBuilder<> &Builder = OMPBuilder.Builder;
+ Builder.SetInsertPoint(BB);
+
+ Value *IfExpr = Builder.CreateLoad(
+ Builder.getInt32Ty(), Builder.CreateAlloca(Builder.getInt32Ty()));
+ Value *NumTeamsLower = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(5));
+ Value *NumTeamsUpper =
+ Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10));
+ Value *ThreadLimit = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(20));
+
+ Function *FakeFunction =
+ Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::ExternalLinkage, "fakeFunction", M.get());
+
+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+ Builder.restoreIP(CodeGenIP);
+ Builder.CreateCall(FakeFunction, {});
+ };
+
+ // `F` already has an integer argument, so we use that as upper bound to
+ // `num_teams`
+ Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower,
+ NumTeamsUpper, ThreadLimit, IfExpr));
+
+ Builder.CreateRetVoid();
+ OMPBuilder.finalize();
+
+ ASSERT_FALSE(verifyModule(*M));
+
+ CallInst *PushNumTeamsCallInst =
+ findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
+ ASSERT_NE(PushNumTeamsCallInst, nullptr);
+ Value *NumTeamsLowerArg = PushNumTeamsCallInst->getArgOperand(2);
+ Value *NumTeamsUpperArg = PushNumTeamsCallInst->getArgOperand(3);
+ Value *ThreadLimitArg = PushNumTeamsCallInst->getArgOperand(4);
+
+ // Get the boolean conversion of if expression
+ ASSERT_EQ(IfExpr->getNumUses(), 1U);
+ User *IfExprInst = IfExpr->user_back();
+ ICmpInst *IfExprCmpInst = dyn_cast<ICmpInst>(IfExprInst);
+ ASSERT_NE(IfExprCmpInst, nullptr);
+ EXPECT_EQ(IfExprCmpInst->getPredicate(), ICmpInst::Predicate::ICMP_NE);
+ EXPECT_EQ(IfExprCmpInst->getOperand(0), IfExpr);
+ EXPECT_EQ(IfExprCmpInst->getOperand(1), Builder.getInt32(0));
+
+ // Check the lower_bound
+ ASSERT_NE(NumTeamsLowerArg, nullptr);
+ SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLowerArg);
+ ASSERT_NE(NumTeamsLowerSelectInst, nullptr);
+ EXPECT_EQ(NumTeamsLowerSelectInst->getCondition(), IfExprCmpInst);
+ EXPECT_EQ(NumTeamsLowerSelectInst->getTrueValue(), NumTeamsLower);
+ EXPECT_EQ(NumTeamsLowerSelectInst->getFalseValue(), Builder.getInt32(1));
+
+ // Check the upper_bound
+ ASSERT_NE(NumTeamsUpperArg, nullptr);
+ SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpperArg);
+ ASSERT_NE(NumTeamsUpperSelectInst, nullptr);
+ EXPECT_EQ(NumTeamsUpperSelectInst->getCondition(), IfExprCmpInst);
+ EXPECT_EQ(NumTeamsUpperSelectInst->getTrueValue(), NumTeamsUpper);
+ EXPECT_EQ(NumTeamsUpperSelectInst->getFalseValue(), Builder.getInt32(1));
+
+ // Check thread_limit
+ EXPECT_EQ(ThreadLimitArg, ThreadLimit);
+}
+
/// Returns the single instruction of InstTy type in BB that uses the value V.
/// If there is more than one such instruction, returns null.
template <typename InstTy>