aboutsummaryrefslogtreecommitdiff
path: root/clang/lib/Sema/SemaOpenMP.cpp
diff options
context:
space:
mode:
authorShilei Tian <i@tianshilei.me>2024-08-06 10:55:15 -0400
committerGitHub <noreply@github.com>2024-08-06 10:55:15 -0400
commitcee594cf36dc6c737df61e5417a98e09d807bd06 (patch)
treea26fe182f696ec325844812beebb69b7afe45d93 /clang/lib/Sema/SemaOpenMP.cpp
parentf0178d881ce61e82b49fa63dcd023eed57c0804b (diff)
downloadllvm-cee594cf36dc6c737df61e5417a98e09d807bd06.zip
llvm-cee594cf36dc6c737df61e5417a98e09d807bd06.tar.gz
llvm-cee594cf36dc6c737df61e5417a98e09d807bd06.tar.bz2
[Clang][Sema][OpenMP] Allow `num_teams` to accept multiple expressions (#99732)
By the OpenMP standard, `num_teams` clause can only accept one expression (for now). In this patch, we extend it to allow to accept multiple expressions when it is used with `target teams ompx_bare` construct. This will allow to launch a multi-dim grid, same as CUDA/HIP.
Diffstat (limited to 'clang/lib/Sema/SemaOpenMP.cpp')
-rw-r--r--clang/lib/Sema/SemaOpenMP.cpp91
1 files changed, 73 insertions, 18 deletions
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 9b60afd..7d814e6b 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -13034,6 +13034,25 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetUpdateDirective(
Clauses, AStmt);
}
+/// This checks whether a \p ClauseType clause \p C has at most \p Max
+/// expression. If not, a diag of number \p Diag will be emitted.
+template <typename ClauseType>
+static bool checkNumExprsInClause(SemaBase &SemaRef,
+ ArrayRef<OMPClause *> Clauses,
+ unsigned MaxNum, unsigned Diag) {
+ auto ClauseItr = llvm::find_if(Clauses, llvm::IsaPred<ClauseType>);
+ if (ClauseItr == Clauses.end())
+ return true;
+ const auto *C = cast<ClauseType>(*ClauseItr);
+ auto VarList = C->getVarRefs();
+ if (VarList.size() > MaxNum) {
+ SemaRef.Diag(VarList[MaxNum]->getBeginLoc(), Diag)
+ << getOpenMPClauseName(C->getClauseKind());
+ return false;
+ }
+ return true;
+}
+
StmtResult SemaOpenMP::ActOnOpenMPTeamsDirective(ArrayRef<OMPClause *> Clauses,
Stmt *AStmt,
SourceLocation StartLoc,
@@ -13041,6 +13060,10 @@ StmtResult SemaOpenMP::ActOnOpenMPTeamsDirective(ArrayRef<OMPClause *> Clauses,
if (!AStmt)
return StmtError();
+ if (!checkNumExprsInClause<OMPNumTeamsClause>(
+ *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed))
+ return StmtError();
+
// Report affected OpenMP target offloading behavior when in HIP lang-mode.
if (getLangOpts().HIP && (DSAStack->getParentDirective() == OMPD_target))
Diag(StartLoc, diag::warn_hip_omp_target_directives);
@@ -13815,6 +13838,14 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDirective(
return StmtError();
}
+ unsigned ClauseMaxNumExprs = HasBareClause ? 3 : 1;
+ unsigned DiagNo = HasBareClause
+ ? diag::err_ompx_more_than_three_expr_not_allowed
+ : diag::err_omp_multi_expr_not_allowed;
+ if (!checkNumExprsInClause<OMPNumTeamsClause>(*this, Clauses,
+ ClauseMaxNumExprs, DiagNo))
+ return StmtError();
+
return OMPTargetTeamsDirective::Create(getASTContext(), StartLoc, EndLoc,
Clauses, AStmt);
}
@@ -13825,6 +13856,10 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeDirective(
if (!AStmt)
return StmtError();
+ if (!checkNumExprsInClause<OMPNumTeamsClause>(
+ *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed))
+ return StmtError();
+
CapturedStmt *CS =
setBranchProtectedScope(SemaRef, OMPD_target_teams_distribute, AStmt);
@@ -13851,6 +13886,10 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeParallelForDirective(
if (!AStmt)
return StmtError();
+ if (!checkNumExprsInClause<OMPNumTeamsClause>(
+ *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed))
+ return StmtError();
+
CapturedStmt *CS = setBranchProtectedScope(
SemaRef, OMPD_target_teams_distribute_parallel_for, AStmt);
@@ -13878,6 +13917,10 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeParallelForSimdDirective(
if (!AStmt)
return StmtError();
+ if (!checkNumExprsInClause<OMPNumTeamsClause>(
+ *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed))
+ return StmtError();
+
CapturedStmt *CS = setBranchProtectedScope(
SemaRef, OMPD_target_teams_distribute_parallel_for_simd, AStmt);
@@ -13908,6 +13951,10 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeSimdDirective(
if (!AStmt)
return StmtError();
+ if (!checkNumExprsInClause<OMPNumTeamsClause>(
+ *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed))
+ return StmtError();
+
CapturedStmt *CS = setBranchProtectedScope(
SemaRef, OMPD_target_teams_distribute_simd, AStmt);
@@ -14955,9 +15002,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
case OMPC_ordered:
Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr);
break;
- case OMPC_num_teams:
- Res = ActOnOpenMPNumTeamsClause(Expr, StartLoc, LParenLoc, EndLoc);
- break;
case OMPC_thread_limit:
Res = ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc);
break;
@@ -15064,6 +15108,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
case OMPC_affinity:
case OMPC_when:
case OMPC_bind:
+ case OMPC_num_teams:
default:
llvm_unreachable("Clause is not allowed.");
}
@@ -16927,6 +16972,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
static_cast<OpenMPDoacrossClauseModifier>(ExtraModifier),
ExtraModifierLoc, ColonLoc, VarList, StartLoc, LParenLoc, EndLoc);
break;
+ case OMPC_num_teams:
+ Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc);
+ break;
case OMPC_if:
case OMPC_depobj:
case OMPC_final:
@@ -16957,7 +17005,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
case OMPC_device:
case OMPC_threads:
case OMPC_simd:
- case OMPC_num_teams:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
@@ -21834,32 +21881,40 @@ const ValueDecl *SemaOpenMP::getOpenMPDeclareMapperVarName() const {
return cast<DeclRefExpr>(DSAStack->getDeclareMapperVarRef())->getDecl();
}
-OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(Expr *NumTeams,
+OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc) {
- Expr *ValExpr = NumTeams;
- Stmt *HelperValStmt = nullptr;
-
- // OpenMP [teams Constrcut, Restrictions]
- // The num_teams expression must evaluate to a positive integer value.
- if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
- /*StrictlyPositive=*/true))
+ if (VarList.empty())
return nullptr;
+ for (Expr *ValExpr : VarList) {
+ // OpenMP [teams Constrcut, Restrictions]
+ // The num_teams expression must evaluate to a positive integer value.
+ if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
+ /*StrictlyPositive=*/true))
+ return nullptr;
+ }
+
OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause(
DKind, OMPC_num_teams, getLangOpts().OpenMP);
- if (CaptureRegion != OMPD_unknown &&
- !SemaRef.CurContext->isDependentContext()) {
+ if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext())
+ return OMPNumTeamsClause::Create(getASTContext(), CaptureRegion, StartLoc,
+ LParenLoc, EndLoc, VarList,
+ /*PreInit=*/nullptr);
+
+ llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
+ SmallVector<Expr *, 3> Vars;
+ for (Expr *ValExpr : VarList) {
ValExpr = SemaRef.MakeFullExpr(ValExpr).get();
- llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get();
- HelperValStmt = buildPreInits(getASTContext(), Captures);
+ Vars.push_back(ValExpr);
}
- return new (getASTContext()) OMPNumTeamsClause(
- ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
+ Stmt *PreInit = buildPreInits(getASTContext(), Captures);
+ return OMPNumTeamsClause::Create(getASTContext(), CaptureRegion, StartLoc,
+ LParenLoc, EndLoc, Vars, PreInit);
}
OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,