diff options
author | Shilei Tian <i@tianshilei.me> | 2024-08-06 10:55:15 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-06 10:55:15 -0400 |
commit | cee594cf36dc6c737df61e5417a98e09d807bd06 (patch) | |
tree | a26fe182f696ec325844812beebb69b7afe45d93 /clang/lib/Sema/SemaOpenMP.cpp | |
parent | f0178d881ce61e82b49fa63dcd023eed57c0804b (diff) | |
download | llvm-cee594cf36dc6c737df61e5417a98e09d807bd06.zip llvm-cee594cf36dc6c737df61e5417a98e09d807bd06.tar.gz llvm-cee594cf36dc6c737df61e5417a98e09d807bd06.tar.bz2 |
[Clang][Sema][OpenMP] Allow `num_teams` to accept multiple expressions (#99732)
By the OpenMP standard, `num_teams` clause can only accept one
expression (for now). In this patch, we extend it to allow to accept
multiple expressions when it is used with `target teams ompx_bare`
construct. This will allow to launch a multi-dim grid, same as CUDA/HIP.
Diffstat (limited to 'clang/lib/Sema/SemaOpenMP.cpp')
-rw-r--r-- | clang/lib/Sema/SemaOpenMP.cpp | 91 |
1 files changed, 73 insertions, 18 deletions
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index 9b60afd..7d814e6b 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -13034,6 +13034,25 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetUpdateDirective( Clauses, AStmt); } +/// This checks whether a \p ClauseType clause \p C has at most \p Max +/// expression. If not, a diag of number \p Diag will be emitted. +template <typename ClauseType> +static bool checkNumExprsInClause(SemaBase &SemaRef, + ArrayRef<OMPClause *> Clauses, + unsigned MaxNum, unsigned Diag) { + auto ClauseItr = llvm::find_if(Clauses, llvm::IsaPred<ClauseType>); + if (ClauseItr == Clauses.end()) + return true; + const auto *C = cast<ClauseType>(*ClauseItr); + auto VarList = C->getVarRefs(); + if (VarList.size() > MaxNum) { + SemaRef.Diag(VarList[MaxNum]->getBeginLoc(), Diag) + << getOpenMPClauseName(C->getClauseKind()); + return false; + } + return true; +} + StmtResult SemaOpenMP::ActOnOpenMPTeamsDirective(ArrayRef<OMPClause *> Clauses, Stmt *AStmt, SourceLocation StartLoc, @@ -13041,6 +13060,10 @@ StmtResult SemaOpenMP::ActOnOpenMPTeamsDirective(ArrayRef<OMPClause *> Clauses, if (!AStmt) return StmtError(); + if (!checkNumExprsInClause<OMPNumTeamsClause>( + *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed)) + return StmtError(); + // Report affected OpenMP target offloading behavior when in HIP lang-mode. if (getLangOpts().HIP && (DSAStack->getParentDirective() == OMPD_target)) Diag(StartLoc, diag::warn_hip_omp_target_directives); @@ -13815,6 +13838,14 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDirective( return StmtError(); } + unsigned ClauseMaxNumExprs = HasBareClause ? 3 : 1; + unsigned DiagNo = HasBareClause + ? diag::err_ompx_more_than_three_expr_not_allowed + : diag::err_omp_multi_expr_not_allowed; + if (!checkNumExprsInClause<OMPNumTeamsClause>(*this, Clauses, + ClauseMaxNumExprs, DiagNo)) + return StmtError(); + return OMPTargetTeamsDirective::Create(getASTContext(), StartLoc, EndLoc, Clauses, AStmt); } @@ -13825,6 +13856,10 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeDirective( if (!AStmt) return StmtError(); + if (!checkNumExprsInClause<OMPNumTeamsClause>( + *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed)) + return StmtError(); + CapturedStmt *CS = setBranchProtectedScope(SemaRef, OMPD_target_teams_distribute, AStmt); @@ -13851,6 +13886,10 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeParallelForDirective( if (!AStmt) return StmtError(); + if (!checkNumExprsInClause<OMPNumTeamsClause>( + *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed)) + return StmtError(); + CapturedStmt *CS = setBranchProtectedScope( SemaRef, OMPD_target_teams_distribute_parallel_for, AStmt); @@ -13878,6 +13917,10 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeParallelForSimdDirective( if (!AStmt) return StmtError(); + if (!checkNumExprsInClause<OMPNumTeamsClause>( + *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed)) + return StmtError(); + CapturedStmt *CS = setBranchProtectedScope( SemaRef, OMPD_target_teams_distribute_parallel_for_simd, AStmt); @@ -13908,6 +13951,10 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeSimdDirective( if (!AStmt) return StmtError(); + if (!checkNumExprsInClause<OMPNumTeamsClause>( + *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed)) + return StmtError(); + CapturedStmt *CS = setBranchProtectedScope( SemaRef, OMPD_target_teams_distribute_simd, AStmt); @@ -14955,9 +15002,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, case OMPC_ordered: Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr); break; - case OMPC_num_teams: - Res = ActOnOpenMPNumTeamsClause(Expr, StartLoc, LParenLoc, EndLoc); - break; case OMPC_thread_limit: Res = ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc); break; @@ -15064,6 +15108,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, case OMPC_affinity: case OMPC_when: case OMPC_bind: + case OMPC_num_teams: default: llvm_unreachable("Clause is not allowed."); } @@ -16927,6 +16972,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind, static_cast<OpenMPDoacrossClauseModifier>(ExtraModifier), ExtraModifierLoc, ColonLoc, VarList, StartLoc, LParenLoc, EndLoc); break; + case OMPC_num_teams: + Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc); + break; case OMPC_if: case OMPC_depobj: case OMPC_final: @@ -16957,7 +17005,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind, case OMPC_device: case OMPC_threads: case OMPC_simd: - case OMPC_num_teams: case OMPC_thread_limit: case OMPC_priority: case OMPC_grainsize: @@ -21834,32 +21881,40 @@ const ValueDecl *SemaOpenMP::getOpenMPDeclareMapperVarName() const { return cast<DeclRefExpr>(DSAStack->getDeclareMapperVarRef())->getDecl(); } -OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(Expr *NumTeams, +OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList, SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc) { - Expr *ValExpr = NumTeams; - Stmt *HelperValStmt = nullptr; - - // OpenMP [teams Constrcut, Restrictions] - // The num_teams expression must evaluate to a positive integer value. - if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams, - /*StrictlyPositive=*/true)) + if (VarList.empty()) return nullptr; + for (Expr *ValExpr : VarList) { + // OpenMP [teams Constrcut, Restrictions] + // The num_teams expression must evaluate to a positive integer value. + if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams, + /*StrictlyPositive=*/true)) + return nullptr; + } + OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective(); OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause( DKind, OMPC_num_teams, getLangOpts().OpenMP); - if (CaptureRegion != OMPD_unknown && - !SemaRef.CurContext->isDependentContext()) { + if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext()) + return OMPNumTeamsClause::Create(getASTContext(), CaptureRegion, StartLoc, + LParenLoc, EndLoc, VarList, + /*PreInit=*/nullptr); + + llvm::MapVector<const Expr *, DeclRefExpr *> Captures; + SmallVector<Expr *, 3> Vars; + for (Expr *ValExpr : VarList) { ValExpr = SemaRef.MakeFullExpr(ValExpr).get(); - llvm::MapVector<const Expr *, DeclRefExpr *> Captures; ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get(); - HelperValStmt = buildPreInits(getASTContext(), Captures); + Vars.push_back(ValExpr); } - return new (getASTContext()) OMPNumTeamsClause( - ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc); + Stmt *PreInit = buildPreInits(getASTContext(), Captures); + return OMPNumTeamsClause::Create(getASTContext(), CaptureRegion, StartLoc, + LParenLoc, EndLoc, Vars, PreInit); } OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit, |