aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDave Pagan <dave.pagan@amd.com>2023-12-07 16:54:27 -0600
committerDave Pagan <dave.pagan@amd.com>2024-03-19 10:48:13 -0500
commit85c9855fbab9a99056fc47ee05657aa3242e1829 (patch)
treec7ab3521640b4c8275e636f3fb079032e2195d18
parentbb118c4435b1b26d5d14188d6342e3b7356ba282 (diff)
downloadllvm-85c9855fbab9a99056fc47ee05657aa3242e1829.zip
llvm-85c9855fbab9a99056fc47ee05657aa3242e1829.tar.gz
llvm-85c9855fbab9a99056fc47ee05657aa3242e1829.tar.bz2
Moved check for whether a 'target teams loop' construct can potentially
be considered equivalent to 'target teams distribute parallel for' from CodeGen to Sema.
-rw-r--r--clang/include/clang/AST/StmtOpenMP.h11
-rw-r--r--clang/include/clang/Sema/Sema.h4
-rw-r--r--clang/lib/AST/StmtOpenMP.cpp3
-rw-r--r--clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp4
-rw-r--r--clang/lib/CodeGen/CGStmtOpenMP.cpp9
-rw-r--r--clang/lib/CodeGen/CodeGenModule.cpp76
-rw-r--r--clang/lib/CodeGen/CodeGenModule.h2
-rw-r--r--clang/lib/Sema/SemaOpenMP.cpp75
-rw-r--r--clang/lib/Serialization/ASTReaderStmt.cpp1
-rw-r--r--clang/lib/Serialization/ASTWriterStmt.cpp1
10 files changed, 101 insertions, 85 deletions
diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index 3cb3c10..f735fa5 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -6109,6 +6109,8 @@ public:
class OMPTargetTeamsGenericLoopDirective final : public OMPLoopDirective {
friend class ASTStmtReader;
friend class OMPExecutableDirective;
+ /// true if loop directive's associated loop can be a parallel for.
+ bool CanBeParallelFor = false;
/// Build directive with the given start and end location.
///
/// \param StartLoc Starting location of the directive kind.
@@ -6131,6 +6133,9 @@ class OMPTargetTeamsGenericLoopDirective final : public OMPLoopDirective {
llvm::omp::OMPD_target_teams_loop, SourceLocation(),
SourceLocation(), CollapsedNum) {}
+ /// Set whether associated loop can be a parallel for.
+ void setCanBeParallelFor(bool ParFor) { CanBeParallelFor = ParFor; }
+
public:
/// Creates directive with a list of \p Clauses.
///
@@ -6145,7 +6150,7 @@ public:
static OMPTargetTeamsGenericLoopDirective *
Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
unsigned CollapsedNum, ArrayRef<OMPClause *> Clauses,
- Stmt *AssociatedStmt, const HelperExprs &Exprs);
+ Stmt *AssociatedStmt, const HelperExprs &Exprs, bool CanBeParallelFor);
/// Creates an empty directive with the place
/// for \a NumClauses clauses.
@@ -6159,6 +6164,10 @@ public:
unsigned CollapsedNum,
EmptyShell);
+ /// Return true if current loop directive's associated loop can be a
+ /// parallel for.
+ bool canBeParallelFor() const { return CanBeParallelFor; }
+
static bool classof(const Stmt *T) {
return T->getStmtClass() == OMPTargetTeamsGenericLoopDirectiveClass;
}
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 95ea5eb..e2d36d8 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -10219,6 +10219,10 @@ public:
bool isInstantiationRecord() const;
};
+ /// [target] teams loop is equivalent to parallel for if associated loop
+ /// nest meets certain critera.
+ bool teamsLoopCanBeParallelFor(Stmt *Astmt);
+
/// A stack object to be created when performing template
/// instantiation.
///
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index 426b358..d8519b2 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -2431,7 +2431,7 @@ OMPTeamsGenericLoopDirective::CreateEmpty(const ASTContext &C,
OMPTargetTeamsGenericLoopDirective *OMPTargetTeamsGenericLoopDirective::Create(
const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
unsigned CollapsedNum, ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
- const HelperExprs &Exprs) {
+ const HelperExprs &Exprs, bool CanBeParallelFor) {
auto *Dir = createDirective<OMPTargetTeamsGenericLoopDirective>(
C, Clauses, AssociatedStmt,
numLoopChildren(CollapsedNum, OMPD_target_teams_loop), StartLoc, EndLoc,
@@ -2473,6 +2473,7 @@ OMPTargetTeamsGenericLoopDirective *OMPTargetTeamsGenericLoopDirective::Create(
Dir->setCombinedNextUpperBound(Exprs.DistCombinedFields.NUB);
Dir->setCombinedDistCond(Exprs.DistCombinedFields.DistCond);
Dir->setCombinedParForInDistCond(Exprs.DistCombinedFields.ParForInDistCond);
+ Dir->setCanBeParallelFor(CanBeParallelFor);
return Dir;
}
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
index 32119ad..4ae8706 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
@@ -661,7 +661,9 @@ static bool supportsSPMDExecutionMode(CodeGenModule &CGM,
case OMPD_target_teams_loop:
// Whether this is true or not depends on how the directive will
// eventually be emitted.
- return CGM.teamsLoopCanBeParallelFor(D);
+ if (auto *TTLD = dyn_cast<OMPTargetTeamsGenericLoopDirective>(&D))
+ return TTLD->canBeParallelFor();
+ return false;
case OMPD_parallel:
case OMPD_for:
case OMPD_parallel_for:
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index 0a4c1f9..f4bf2e1 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -1432,9 +1432,12 @@ void CodeGenFunction::EmitOMPReductionClauseFinal(
*this, D.getBeginLoc(),
isOpenMPWorksharingDirective(D.getDirectiveKind()));
}
+ bool TeamsLoopCanBeParallel = false;
+ if (auto *TTLD = dyn_cast<OMPTargetTeamsGenericLoopDirective>(&D))
+ TeamsLoopCanBeParallel = TTLD->canBeParallelFor();
bool WithNowait = D.getSingleClause<OMPNowaitClause>() ||
isOpenMPParallelDirective(D.getDirectiveKind()) ||
- CGM.teamsLoopCanBeParallelFor(D) ||
+ TeamsLoopCanBeParallel ||
ReductionKind == OMPD_simd;
bool SimpleReduction = ReductionKind == OMPD_simd;
// Emit nowait reduction if nowait clause is present or directive is a
@@ -8014,7 +8017,7 @@ static void emitTargetTeamsGenericLoopRegionAsDistribute(
void CodeGenFunction::EmitOMPTargetTeamsGenericLoopDirective(
const OMPTargetTeamsGenericLoopDirective &S) {
auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
- if (CGF.CGM.teamsLoopCanBeParallelFor(S))
+ if (S.canBeParallelFor())
emitTargetTeamsGenericLoopRegionAsParallel(CGF, Action, S);
else
emitTargetTeamsGenericLoopRegionAsDistribute(CGF, Action, S);
@@ -8027,7 +8030,7 @@ void CodeGenFunction::EmitOMPTargetTeamsGenericLoopDeviceFunction(
const OMPTargetTeamsGenericLoopDirective &S) {
// Emit SPMD target parallel loop region as a standalone region.
auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
- if (CGF.CGM.teamsLoopCanBeParallelFor(S))
+ if (S.canBeParallelFor())
emitTargetTeamsGenericLoopRegionAsParallel(CGF, Action, S);
else
emitTargetTeamsGenericLoopRegionAsDistribute(CGF, Action, S);
diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp
index eb72bcb..31ebd16b 100644
--- a/clang/lib/CodeGen/CodeGenModule.cpp
+++ b/clang/lib/CodeGen/CodeGenModule.cpp
@@ -7579,82 +7579,6 @@ void CodeGenModule::printPostfixForExternalizedDecl(llvm::raw_ostream &OS,
}
}
-namespace {
-/// A 'teams loop' with a nested 'loop bind(parallel)' or generic function
-/// call in the associated loop-nest cannot be a 'parllel for'.
-class TeamsLoopChecker final : public ConstStmtVisitor<TeamsLoopChecker> {
-public:
- TeamsLoopChecker(CodeGenModule &CGM)
- : CGM(CGM), TeamsLoopCanBeParallelFor{true} {}
- bool teamsLoopCanBeParallelFor() const { return TeamsLoopCanBeParallelFor; }
- // Is there a nested OpenMP loop bind(parallel)
- void VisitOMPExecutableDirective(const OMPExecutableDirective *D) {
- if (D->getDirectiveKind() == llvm::omp::Directive::OMPD_loop) {
- if (const auto *C = D->getSingleClause<OMPBindClause>())
- if (C->getBindKind() == OMPC_BIND_parallel) {
- TeamsLoopCanBeParallelFor = false;
- // No need to continue visiting any more
- return;
- }
- }
- for (const Stmt *Child : D->children())
- if (Child)
- Visit(Child);
- }
-
- void VisitCallExpr(const CallExpr *C) {
- // Function calls inhibit parallel loop translation of 'target teams loop'
- // unless the assume-no-nested-parallelism flag has been specified.
- // OpenMP API runtime library calls do not inhibit parallel loop
- // translation, regardless of the assume-no-nested-parallelism.
- if (C) {
- bool IsOpenMPAPI = false;
- auto *FD = dyn_cast_or_null<FunctionDecl>(C->getCalleeDecl());
- if (FD) {
- std::string Name = FD->getNameInfo().getAsString();
- IsOpenMPAPI = Name.find("omp_") == 0;
- }
- TeamsLoopCanBeParallelFor =
- IsOpenMPAPI || CGM.getLangOpts().OpenMPNoNestedParallelism;
- if (!TeamsLoopCanBeParallelFor)
- return;
- }
- for (const Stmt *Child : C->children())
- if (Child)
- Visit(Child);
- }
-
- void VisitCapturedStmt(const CapturedStmt *S) {
- if (!S)
- return;
- Visit(S->getCapturedDecl()->getBody());
- }
-
- void VisitStmt(const Stmt *S) {
- if (!S)
- return;
- for (const Stmt *Child : S->children())
- if (Child)
- Visit(Child);
- }
-
-private:
- CodeGenModule &CGM;
- bool TeamsLoopCanBeParallelFor;
-};
-} // namespace
-
-/// Determine if 'teams loop' can be emitted using 'parallel for'.
-bool CodeGenModule::teamsLoopCanBeParallelFor(const OMPExecutableDirective &D) {
- if (D.getDirectiveKind() != llvm::omp::Directive::OMPD_target_teams_loop)
- return false;
- assert(D.hasAssociatedStmt() &&
- "Loop directive must have associated statement.");
- TeamsLoopChecker Checker(*this);
- Checker.Visit(D.getAssociatedStmt());
- return Checker.teamsLoopCanBeParallelFor();
-}
-
void CodeGenModule::emitTargetTeamsLoopCodegenStatus(
std::string StatusMsg, const OMPExecutableDirective &D, bool IsDevice) {
#ifndef NDEBUG
diff --git a/clang/lib/CodeGen/CodeGenModule.h b/clang/lib/CodeGen/CodeGenModule.h
index df927fe..7e54a19 100644
--- a/clang/lib/CodeGen/CodeGenModule.h
+++ b/clang/lib/CodeGen/CodeGenModule.h
@@ -1532,8 +1532,6 @@ public:
LValueBaseInfo *BaseInfo = nullptr,
TBAAAccessInfo *TBAAInfo = nullptr);
bool stopAutoInit();
- /// Determine if 'teams loop' can be emitted using 'parallel for'.
- bool teamsLoopCanBeParallelFor(const OMPExecutableDirective &D);
/// Print the postfix for externalized static variable or kernels for single
/// source offloading languages CUDA and HIP. The unique postfix is created
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 31fe8ba..e6ec3e8 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -6137,6 +6137,78 @@ processImplicitMapsWithDefaultMappers(Sema &S, DSAStackTy *Stack,
}
}
+namespace {
+/// A 'teams loop' with a nested 'loop bind(parallel)' or generic function
+/// call in the associated loop-nest cannot be a 'parallel for'.
+class TeamsLoopChecker final
+ : public ConstStmtVisitor<TeamsLoopChecker> {
+ Sema &SemaRef;
+public:
+ bool teamsLoopCanBeParallelFor() const { return TeamsLoopCanBeParallelFor; }
+
+ // Is there a nested OpenMP loop bind(parallel)
+ void VisitOMPExecutableDirective(const OMPExecutableDirective *D) {
+ if (D->getDirectiveKind() == llvm::omp::Directive::OMPD_loop) {
+ if (const auto *C = D->getSingleClause<OMPBindClause>())
+ if (C->getBindKind() == OMPC_BIND_parallel) {
+ TeamsLoopCanBeParallelFor = false;
+ // No need to continue visiting any more
+ return;
+ }
+ }
+ for (const Stmt *Child : D->children())
+ if (Child)
+ Visit(Child);
+ }
+
+ void VisitCallExpr(const CallExpr *C) {
+ // Function calls inhibit parallel loop translation of 'target teams loop'
+ // unless the assume-no-nested-parallelism flag has been specified.
+ // OpenMP API runtime library calls do not inhibit parallel loop
+ // translation, regardless of the assume-no-nested-parallelism.
+ if (C) {
+ bool IsOpenMPAPI = false;
+ auto *FD = dyn_cast_or_null<FunctionDecl>(C->getCalleeDecl());
+ if (FD) {
+ std::string Name = FD->getNameInfo().getAsString();
+ IsOpenMPAPI = Name.find("omp_") == 0;
+ }
+ TeamsLoopCanBeParallelFor =
+ IsOpenMPAPI || SemaRef.getLangOpts().OpenMPNoNestedParallelism;
+ if (!TeamsLoopCanBeParallelFor)
+ return;
+ }
+ for (const Stmt *Child : C->children())
+ if (Child)
+ Visit(Child);
+ }
+
+ void VisitCapturedStmt(const CapturedStmt *S) {
+ if (!S)
+ return;
+ Visit(S->getCapturedDecl()->getBody());
+ }
+
+ void VisitStmt(const Stmt *S) {
+ if (!S)
+ return;
+ for (const Stmt *Child : S->children())
+ if (Child)
+ Visit(Child);
+ }
+ explicit TeamsLoopChecker(Sema &SemaRef)
+ : SemaRef(SemaRef), TeamsLoopCanBeParallelFor(true) {}
+private:
+ bool TeamsLoopCanBeParallelFor;
+};
+} // namespace
+
+bool Sema::teamsLoopCanBeParallelFor(Stmt *AStmt) {
+ TeamsLoopChecker Checker(*this);
+ Checker.Visit(AStmt);
+ return Checker.teamsLoopCanBeParallelFor();
+}
+
bool Sema::mapLoopConstruct(llvm::SmallVector<OMPClause *> &ClausesWithoutBind,
ArrayRef<OMPClause *> Clauses,
OpenMPBindClauseKind &BindKind,
@@ -10897,7 +10969,8 @@ StmtResult Sema::ActOnOpenMPTargetTeamsGenericLoopDirective(
setFunctionHasBranchProtectedScope();
return OMPTargetTeamsGenericLoopDirective::Create(
- Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B);
+ Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B,
+ teamsLoopCanBeParallelFor(AStmt));
}
StmtResult Sema::ActOnOpenMPParallelGenericLoopDirective(
diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index 674ed47..8db1ec7 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2776,6 +2776,7 @@ void ASTStmtReader::VisitOMPTeamsGenericLoopDirective(
void ASTStmtReader::VisitOMPTargetTeamsGenericLoopDirective(
OMPTargetTeamsGenericLoopDirective *D) {
VisitOMPLoopDirective(D);
+ D->setCanBeParallelFor(Record.readBool());
}
void ASTStmtReader::VisitOMPParallelGenericLoopDirective(
diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index 7ce48fe..e850329 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2823,6 +2823,7 @@ void ASTStmtWriter::VisitOMPTeamsGenericLoopDirective(
void ASTStmtWriter::VisitOMPTargetTeamsGenericLoopDirective(
OMPTargetTeamsGenericLoopDirective *D) {
VisitOMPLoopDirective(D);
+ Record.writeBool(D->canBeParallelFor());
Code = serialization::STMT_OMP_TARGET_TEAMS_GENERIC_LOOP_DIRECTIVE;
}