diff options
author | erichkeane <ekeane@nvidia.com> | 2024-12-17 07:39:20 -0800 |
---|---|---|
committer | erichkeane <ekeane@nvidia.com> | 2024-12-18 15:06:01 -0800 |
commit | e34cc7c99375c43e1698c78ec9150fa40c88d486 (patch) | |
tree | ac89a9bd53befe50c71804c750dc340c65a06719 /clang/lib | |
parent | e0526b0780f56eede09b05a859a93626ecdc6e4d (diff) | |
download | llvm-e34cc7c99375c43e1698c78ec9150fa40c88d486.zip llvm-e34cc7c99375c43e1698c78ec9150fa40c88d486.tar.gz llvm-e34cc7c99375c43e1698c78ec9150fa40c88d486.tar.bz2 |
[OpenACC] Implement 'wait' construct
The arguments to this are the same as for the 'wait' clause, so this
reuses all of that infrastructure. So all this has to do is support a
pair of clauses that are already implemented (if and async), plus create
an AST node. This patch does so, and adds proper testing.
Diffstat (limited to 'clang/lib')
-rw-r--r-- | clang/lib/AST/StmtOpenACC.cpp | 29 | ||||
-rw-r--r-- | clang/lib/AST/StmtPrinter.cpp | 28 | ||||
-rw-r--r-- | clang/lib/AST/StmtProfile.cpp | 8 | ||||
-rw-r--r-- | clang/lib/AST/TextNodeDumper.cpp | 4 | ||||
-rw-r--r-- | clang/lib/CodeGen/CGStmt.cpp | 3 | ||||
-rw-r--r-- | clang/lib/CodeGen/CodeGenFunction.h | 5 | ||||
-rw-r--r-- | clang/lib/Parse/ParseOpenACC.cpp | 22 | ||||
-rw-r--r-- | clang/lib/Sema/SemaExceptionSpec.cpp | 1 | ||||
-rw-r--r-- | clang/lib/Sema/SemaOpenACC.cpp | 86 | ||||
-rw-r--r-- | clang/lib/Sema/TreeTransform.h | 86 | ||||
-rw-r--r-- | clang/lib/Serialization/ASTReaderStmt.cpp | 22 | ||||
-rw-r--r-- | clang/lib/Serialization/ASTWriterStmt.cpp | 14 | ||||
-rw-r--r-- | clang/lib/StaticAnalyzer/Core/ExprEngine.cpp | 1 |
13 files changed, 248 insertions, 61 deletions
diff --git a/clang/lib/AST/StmtOpenACC.cpp b/clang/lib/AST/StmtOpenACC.cpp index fb73dfb..6d9f267 100644 --- a/clang/lib/AST/StmtOpenACC.cpp +++ b/clang/lib/AST/StmtOpenACC.cpp @@ -196,3 +196,32 @@ OpenACCHostDataConstruct *OpenACCHostDataConstruct::Create( Clauses, StructuredBlock); return Inst; } + +OpenACCWaitConstruct *OpenACCWaitConstruct::CreateEmpty(const ASTContext &C, + unsigned NumExprs, + unsigned NumClauses) { + void *Mem = C.Allocate( + OpenACCWaitConstruct::totalSizeToAlloc<Expr *, OpenACCClause *>( + NumExprs, NumClauses)); + + auto *Inst = new (Mem) OpenACCWaitConstruct(NumExprs, NumClauses); + return Inst; +} + +OpenACCWaitConstruct *OpenACCWaitConstruct::Create( + const ASTContext &C, SourceLocation Start, SourceLocation DirectiveLoc, + SourceLocation LParenLoc, Expr *DevNumExpr, SourceLocation QueuesLoc, + ArrayRef<Expr *> QueueIdExprs, SourceLocation RParenLoc, SourceLocation End, + ArrayRef<const OpenACCClause *> Clauses) { + + assert(llvm::all_of(QueueIdExprs, [](Expr *E) { return E != nullptr; })); + + void *Mem = C.Allocate( + OpenACCWaitConstruct::totalSizeToAlloc<Expr *, OpenACCClause *>( + QueueIdExprs.size() + 1, Clauses.size())); + + auto *Inst = new (Mem) + OpenACCWaitConstruct(Start, DirectiveLoc, LParenLoc, DevNumExpr, + QueuesLoc, QueueIdExprs, RParenLoc, End, Clauses); + return Inst; +} diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp index 488419a..ecc9b6e3 100644 --- a/clang/lib/AST/StmtPrinter.cpp +++ b/clang/lib/AST/StmtPrinter.cpp @@ -1238,6 +1238,34 @@ void StmtPrinter::VisitOpenACCHostDataConstruct(OpenACCHostDataConstruct *S) { PrintStmt(S->getStructuredBlock()); } +void StmtPrinter::VisitOpenACCWaitConstruct(OpenACCWaitConstruct *S) { + Indent() << "#pragma acc wait"; + if (!S->getLParenLoc().isInvalid()) { + OS << "("; + if (S->hasDevNumExpr()) { + OS << "devnum: "; + S->getDevNumExpr()->printPretty(OS, nullptr, Policy); + OS << " : "; + } + + if (S->hasQueuesTag()) + OS << "queues: "; + + llvm::interleaveComma(S->getQueueIdExprs(), OS, [&](const Expr *E) { + E->printPretty(OS, nullptr, Policy); + }); + + OS << ")"; + } + + if (!S->clauses().empty()) { + OS << ' '; + OpenACCClausePrinter Printer(OS, Policy); + Printer.VisitClauseList(S->clauses()); + } + OS << '\n'; +} + //===----------------------------------------------------------------------===// // Expr printing methods. //===----------------------------------------------------------------------===// diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index 1fb2387..fccd97d 100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -2743,6 +2743,14 @@ void StmtProfiler::VisitOpenACCHostDataConstruct( P.VisitOpenACCClauseList(S->clauses()); } +void StmtProfiler::VisitOpenACCWaitConstruct(const OpenACCWaitConstruct *S) { + // VisitStmt covers 'children', so the exprs inside of it are covered. + VisitStmt(S); + + OpenACCClauseProfiler P{*this}; + P.VisitOpenACCClauseList(S->clauses()); +} + void StmtProfiler::VisitHLSLOutArgExpr(const HLSLOutArgExpr *S) { VisitStmt(S); } diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp index b5af10d..7cdffbe 100644 --- a/clang/lib/AST/TextNodeDumper.cpp +++ b/clang/lib/AST/TextNodeDumper.cpp @@ -2960,6 +2960,10 @@ void TextNodeDumper::VisitOpenACCHostDataConstruct( OS << " " << S->getDirectiveKind(); } +void TextNodeDumper::VisitOpenACCWaitConstruct(const OpenACCWaitConstruct *S) { + OS << " " << S->getDirectiveKind(); +} + void TextNodeDumper::VisitEmbedExpr(const EmbedExpr *S) { AddChild("begin", [=] { OS << S->getStartingElementPos(); }); AddChild("number of elements", [=] { OS << S->getDataElementCount(); }); diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp index 6c7a594..6c604f4 100644 --- a/clang/lib/CodeGen/CGStmt.cpp +++ b/clang/lib/CodeGen/CGStmt.cpp @@ -470,6 +470,9 @@ void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef<const Attr *> Attrs) { case Stmt::OpenACCHostDataConstructClass: EmitOpenACCHostDataConstruct(cast<OpenACCHostDataConstruct>(*S)); break; + case Stmt::OpenACCWaitConstructClass: + EmitOpenACCWaitConstruct(cast<OpenACCWaitConstruct>(*S)); + break; } } diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 092d553..847999c 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -4118,6 +4118,11 @@ public: EmitStmt(S.getStructuredBlock()); } + void EmitOpenACCWaitConstruct(const OpenACCWaitConstruct &S) { + // TODO OpenACC: Implement this. It is currently implemented as a 'no-op', + // but in the future we will implement some sort of IR. + } + //===--------------------------------------------------------------------===// // LValue Expression Emission //===--------------------------------------------------------------------===// diff --git a/clang/lib/Parse/ParseOpenACC.cpp b/clang/lib/Parse/ParseOpenACC.cpp index af175a4..5da34a2 100644 --- a/clang/lib/Parse/ParseOpenACC.cpp +++ b/clang/lib/Parse/ParseOpenACC.cpp @@ -573,6 +573,7 @@ bool doesDirectiveHaveAssociatedStmt(OpenACCDirectiveKind DirKind) { default: case OpenACCDirectiveKind::EnterData: case OpenACCDirectiveKind::ExitData: + case OpenACCDirectiveKind::Wait: return false; case OpenACCDirectiveKind::Parallel: case OpenACCDirectiveKind::Serial: @@ -604,6 +605,7 @@ unsigned getOpenACCScopeFlags(OpenACCDirectiveKind DirKind) { case OpenACCDirectiveKind::EnterData: case OpenACCDirectiveKind::ExitData: case OpenACCDirectiveKind::HostData: + case OpenACCDirectiveKind::Wait: return 0; case OpenACCDirectiveKind::Invalid: llvm_unreachable("Shouldn't be creating a scope for an invalid construct"); @@ -1288,7 +1290,8 @@ Parser::ParseOpenACCWaitArgument(SourceLocation Loc, bool IsDirective) { return Result; } - Result.QueueIdExprs.push_back(Res.first.get()); + if (Res.first.isUsable()) + Result.QueueIdExprs.push_back(Res.first.get()); } return Result; @@ -1422,6 +1425,7 @@ Parser::ParseOpenACCDirective() { SourceLocation StartLoc = ConsumeAnnotationToken(); SourceLocation DirLoc = getCurToken().getLocation(); OpenACCDirectiveKind DirKind = ParseOpenACCDirectiveKind(*this); + Parser::OpenACCWaitParseInfo WaitInfo; getActions().OpenACC().ActOnConstruct(DirKind, DirLoc); @@ -1462,7 +1466,8 @@ Parser::ParseOpenACCDirective() { break; case OpenACCDirectiveKind::Wait: // OpenACC has an optional paren-wrapped 'wait-argument'. - if (ParseOpenACCWaitArgument(DirLoc, /*IsDirective=*/true).Failed) + WaitInfo = ParseOpenACCWaitArgument(DirLoc, /*IsDirective=*/true); + if (WaitInfo.Failed) T.skipToEnd(); else T.consumeClose(); @@ -1476,8 +1481,14 @@ Parser::ParseOpenACCDirective() { } // Parses the list of clauses, if present, plus set up return value. - OpenACCDirectiveParseInfo ParseInfo{DirKind, StartLoc, DirLoc, - SourceLocation{}, + OpenACCDirectiveParseInfo ParseInfo{DirKind, + StartLoc, + DirLoc, + T.getOpenLocation(), + T.getCloseLocation(), + /*EndLoc=*/SourceLocation{}, + WaitInfo.QueuesLoc, + WaitInfo.getAllExprs(), ParseOpenACCClauseList(DirKind)}; assert(Tok.is(tok::annot_pragma_openacc_end) && @@ -1529,6 +1540,7 @@ StmtResult Parser::ParseOpenACCDirectiveStmt() { } return getActions().OpenACC().ActOnEndStmtDirective( - DirInfo.DirKind, DirInfo.StartLoc, DirInfo.DirLoc, DirInfo.EndLoc, + DirInfo.DirKind, DirInfo.StartLoc, DirInfo.DirLoc, DirInfo.LParenLoc, + DirInfo.MiscLoc, DirInfo.Exprs, DirInfo.RParenLoc, DirInfo.EndLoc, DirInfo.Clauses, AssocStmt); } diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp index 2be6af2..505cc5e 100644 --- a/clang/lib/Sema/SemaExceptionSpec.cpp +++ b/clang/lib/Sema/SemaExceptionSpec.cpp @@ -1398,6 +1398,7 @@ CanThrowResult Sema::canThrow(const Stmt *S) { case Expr::HLSLOutArgExprClass: case Stmt::OpenACCEnterDataConstructClass: case Stmt::OpenACCExitDataConstructClass: + case Stmt::OpenACCWaitConstructClass: // These expressions can never throw. return CT_Cannot; diff --git a/clang/lib/Sema/SemaOpenACC.cpp b/clang/lib/Sema/SemaOpenACC.cpp index 11c1835..aa9097b 100644 --- a/clang/lib/Sema/SemaOpenACC.cpp +++ b/clang/lib/Sema/SemaOpenACC.cpp @@ -41,6 +41,7 @@ bool diagnoseConstructAppertainment(SemaOpenACC &S, OpenACCDirectiveKind K, case OpenACCDirectiveKind::EnterData: case OpenACCDirectiveKind::ExitData: case OpenACCDirectiveKind::HostData: + case OpenACCDirectiveKind::Wait: if (!IsStmt) return S.Diag(StartLoc, diag::err_acc_construct_appertainment) << K; break; @@ -566,6 +567,16 @@ bool checkValidAfterDeviceType( return true; } +// A temporary function that helps implement the 'not implemented' check at the +// top of each clause checking function. This should only be used in conjunction +// with the one being currently implemented/only updated after the entire +// construct has been implemented. +bool isDirectiveKindImplemented(OpenACCDirectiveKind DK) { + return isOpenACCComputeDirectiveKind(DK) || + isOpenACCCombinedDirectiveKind(DK) || isOpenACCDataDirectiveKind(DK) || + DK == OpenACCDirectiveKind::Loop || DK == OpenACCDirectiveKind::Wait; +} + class SemaOpenACCClauseVisitor { SemaOpenACC &SemaRef; ASTContext &Ctx; @@ -680,9 +691,7 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitIfClause( // constructs, and 'compute'/'combined'/'data' constructs are the only // constructs that can do anything with this yet, so skip/treat as // unimplemented in this case. - if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCDataDirectiveKind(Clause.getDirectiveKind())) + if (!isDirectiveKindImplemented(Clause.getDirectiveKind())) return isNotImplemented(); // There is no prose in the standard that says duplicates aren't allowed, @@ -717,8 +726,7 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitSelfClause( // Restrictions only properly implemented on 'compute' constructs, and // 'compute' constructs are the only construct that can do anything with // this yet, so skip/treat as unimplemented in this case. - if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind())) + if (!isDirectiveKindImplemented(Clause.getDirectiveKind())) return isNotImplemented(); // TODO OpenACC: When we implement this for 'update', this takes a @@ -915,9 +923,7 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitAsyncClause( // constructs, and 'compute'/'combined'/'data' constructs are the only // construct that can do anything with this yet, so skip/treat as // unimplemented in this case. - if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCDataDirectiveKind(Clause.getDirectiveKind())) + if (!isDirectiveKindImplemented(Clause.getDirectiveKind())) return isNotImplemented(); // There is no prose in the standard that says duplicates aren't allowed, @@ -973,9 +979,7 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitPresentClause( // constructs, and 'compute'/'combined'/'data' constructs are the only // construct that can do anything with this yet, so skip/treat as // unimplemented in this case. - if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCDataDirectiveKind(Clause.getDirectiveKind())) + if (!isDirectiveKindImplemented(Clause.getDirectiveKind())) return isNotImplemented(); // ActOnVar ensured that everything is a valid variable reference, so there // really isn't anything to do here. GCC does some duplicate-finding, though @@ -992,9 +996,7 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitCopyClause( // constructs, and 'compute'/'combined'/'data' constructs are the only // construct that can do anything with this yet, so skip/treat as // unimplemented in this case. - if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCDataDirectiveKind(Clause.getDirectiveKind())) + if (!isDirectiveKindImplemented(Clause.getDirectiveKind())) return isNotImplemented(); // ActOnVar ensured that everything is a valid variable reference, so there // really isn't anything to do here. GCC does some duplicate-finding, though @@ -1011,9 +1013,7 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitCopyInClause( // constructs, and 'compute'/'combined'/'data' constructs are the only // construct that can do anything with this yet, so skip/treat as // unimplemented in this case. - if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCDataDirectiveKind(Clause.getDirectiveKind())) + if (!isDirectiveKindImplemented(Clause.getDirectiveKind())) return isNotImplemented(); // ActOnVar ensured that everything is a valid variable reference, so there // really isn't anything to do here. GCC does some duplicate-finding, though @@ -1030,9 +1030,7 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitCopyOutClause( // constructs, and 'compute'/'combined'/'data' constructs are the only // construct that can do anything with this yet, so skip/treat as // unimplemented in this case. - if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCDataDirectiveKind(Clause.getDirectiveKind())) + if (!isDirectiveKindImplemented(Clause.getDirectiveKind())) return isNotImplemented(); // ActOnVar ensured that everything is a valid variable reference, so there // really isn't anything to do here. GCC does some duplicate-finding, though @@ -1109,9 +1107,7 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitDevicePtrClause( // constructs, and 'compute'/'combined'/'data' constructs are the only // construct that can do anything with this yet, so skip/treat as // unimplemented in this case. - if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCDataDirectiveKind(Clause.getDirectiveKind())) + if (!isDirectiveKindImplemented(Clause.getDirectiveKind())) return isNotImplemented(); // ActOnVar ensured that everything is a valid variable reference, but we @@ -1134,9 +1130,7 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitWaitClause( // constructs, and 'compute'/'combined'/'data' constructs are the only // construct that can do anything with this yet, so skip/treat as // unimplemented in this case. - if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind()) && - !isOpenACCDataDirectiveKind(Clause.getDirectiveKind())) + if (!isDirectiveKindImplemented(Clause.getDirectiveKind())) return isNotImplemented(); return OpenACCWaitClause::Create( @@ -1150,10 +1144,7 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitDeviceTypeClause( // 'loop' constructs, and 'compute'/'combined'/'data'/'loop' constructs are // the only construct that can do anything with this yet, so skip/treat as // unimplemented in this case. - if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()) && - Clause.getDirectiveKind() != OpenACCDirectiveKind::Loop && - Clause.getDirectiveKind() != OpenACCDirectiveKind::Data && - !isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind())) + if (!isDirectiveKindImplemented(Clause.getDirectiveKind())) return isNotImplemented(); // TODO OpenACC: Once we get enough of the CodeGen implemented that we have @@ -1347,8 +1338,7 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitVectorClause( // Restrictions only properly implemented on 'loop'/'combined' constructs, and // it is the only construct that can do anything with this, so skip/treat as // unimplemented for the routine constructs. - if (Clause.getDirectiveKind() != OpenACCDirectiveKind::Loop && - !isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind())) + if (!isDirectiveKindImplemented(Clause.getDirectiveKind())) return isNotImplemented(); Expr *IntExpr = @@ -1446,8 +1436,7 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitWorkerClause( // Restrictions only properly implemented on 'loop'/'combined' constructs, and // it is the only construct that can do anything with this, so skip/treat as // unimplemented for the routine constructs. - if (Clause.getDirectiveKind() != OpenACCDirectiveKind::Loop && - !isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind())) + if (!isDirectiveKindImplemented(Clause.getDirectiveKind())) return isNotImplemented(); Expr *IntExpr = @@ -1559,8 +1548,7 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitGangClause( // Restrictions only properly implemented on 'loop' constructs, and it is // the only construct that can do anything with this, so skip/treat as // unimplemented for the combined constructs. - if (Clause.getDirectiveKind() != OpenACCDirectiveKind::Loop && - !isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind())) + if (!isDirectiveKindImplemented(Clause.getDirectiveKind())) return isNotImplemented(); // OpenACC 3.3 Section 2.9.11: A reduction clause may not appear on a loop @@ -1691,7 +1679,7 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitFinalizeClause( OpenACCClause *SemaOpenACCClauseVisitor::VisitIfPresentClause( SemaOpenACC::OpenACCParsedClause &Clause) { - if (Clause.getDirectiveKind() != OpenACCDirectiveKind::HostData) + if (!isDirectiveKindImplemented(Clause.getDirectiveKind())) return isNotImplemented(); // There isn't anything to do here, this is only valid on one construct, and // has no associated rules. @@ -1704,7 +1692,7 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitSeqClause( // Restrictions only properly implemented on 'loop' constructs and combined , // and it is the only construct that can do anything with this, so skip/treat // as unimplemented for the routine constructs. - if (Clause.getDirectiveKind() == OpenACCDirectiveKind::Routine) + if (!isDirectiveKindImplemented(Clause.getDirectiveKind())) return isNotImplemented(); // OpenACC 3.3 2.9: @@ -1879,6 +1867,7 @@ bool PreserveLoopRAIIDepthInAssociatedStmtRAII(OpenACCDirectiveKind DK) { return true; case OpenACCDirectiveKind::EnterData: case OpenACCDirectiveKind::ExitData: + case OpenACCDirectiveKind::Wait: llvm_unreachable("Doesn't have an associated stmt"); default: case OpenACCDirectiveKind::Invalid: @@ -2308,6 +2297,10 @@ void SemaOpenACC::ActOnConstruct(OpenACCDirectiveKind K, // Nothing to do here, there is no real legalization that needs to happen // here as these constructs do not take any arguments. break; + case OpenACCDirectiveKind::Wait: + // Nothing really to do here, the arguments to the 'wait' should have + // already been handled by the time we get here. + break; default: Diag(DirLoc, diag::warn_acc_construct_unimplemented) << K; break; @@ -3637,12 +3630,11 @@ bool SemaOpenACC::ActOnStartStmtDirective( return diagnoseConstructAppertainment(*this, K, StartLoc, /*IsStmt=*/true); } -StmtResult SemaOpenACC::ActOnEndStmtDirective(OpenACCDirectiveKind K, - SourceLocation StartLoc, - SourceLocation DirLoc, - SourceLocation EndLoc, - ArrayRef<OpenACCClause *> Clauses, - StmtResult AssocStmt) { +StmtResult SemaOpenACC::ActOnEndStmtDirective( + OpenACCDirectiveKind K, SourceLocation StartLoc, SourceLocation DirLoc, + SourceLocation LParenLoc, SourceLocation MiscLoc, ArrayRef<Expr *> Exprs, + SourceLocation RParenLoc, SourceLocation EndLoc, + ArrayRef<OpenACCClause *> Clauses, StmtResult AssocStmt) { switch (K) { default: return StmtEmpty(); @@ -3685,6 +3677,11 @@ StmtResult SemaOpenACC::ActOnEndStmtDirective(OpenACCDirectiveKind K, getASTContext(), StartLoc, DirLoc, EndLoc, Clauses, AssocStmt.isUsable() ? AssocStmt.get() : nullptr); } + case OpenACCDirectiveKind::Wait: { + return OpenACCWaitConstruct::Create( + getASTContext(), StartLoc, DirLoc, LParenLoc, Exprs.front(), MiscLoc, + Exprs.drop_front(), RParenLoc, EndLoc, Clauses); + } } llvm_unreachable("Unhandled case in directive handling?"); } @@ -3697,6 +3694,7 @@ StmtResult SemaOpenACC::ActOnAssociatedStmt( llvm_unreachable("Unimplemented associated statement application"); case OpenACCDirectiveKind::EnterData: case OpenACCDirectiveKind::ExitData: + case OpenACCDirectiveKind::Wait: llvm_unreachable( "these don't have associated statements, so shouldn't get here"); case OpenACCDirectiveKind::Parallel: diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h index 04167e7..c097465 100644 --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -4087,8 +4087,9 @@ public: SourceLocation EndLoc, ArrayRef<OpenACCClause *> Clauses, StmtResult StrBlock) { - return getSema().OpenACC().ActOnEndStmtDirective(K, BeginLoc, DirLoc, - EndLoc, Clauses, StrBlock); + return getSema().OpenACC().ActOnEndStmtDirective( + K, BeginLoc, DirLoc, SourceLocation{}, SourceLocation{}, {}, + SourceLocation{}, EndLoc, Clauses, StrBlock); } StmtResult RebuildOpenACCLoopConstruct(SourceLocation BeginLoc, @@ -4097,7 +4098,8 @@ public: ArrayRef<OpenACCClause *> Clauses, StmtResult Loop) { return getSema().OpenACC().ActOnEndStmtDirective( - OpenACCDirectiveKind::Loop, BeginLoc, DirLoc, EndLoc, Clauses, Loop); + OpenACCDirectiveKind::Loop, BeginLoc, DirLoc, SourceLocation{}, + SourceLocation{}, {}, SourceLocation{}, EndLoc, Clauses, Loop); } StmtResult RebuildOpenACCCombinedConstruct(OpenACCDirectiveKind K, @@ -4106,8 +4108,9 @@ public: SourceLocation EndLoc, ArrayRef<OpenACCClause *> Clauses, StmtResult Loop) { - return getSema().OpenACC().ActOnEndStmtDirective(K, BeginLoc, DirLoc, - EndLoc, Clauses, Loop); + return getSema().OpenACC().ActOnEndStmtDirective( + K, BeginLoc, DirLoc, SourceLocation{}, SourceLocation{}, {}, + SourceLocation{}, EndLoc, Clauses, Loop); } StmtResult RebuildOpenACCDataConstruct(SourceLocation BeginLoc, @@ -4115,9 +4118,9 @@ public: SourceLocation EndLoc, ArrayRef<OpenACCClause *> Clauses, StmtResult StrBlock) { - return getSema().OpenACC().ActOnEndStmtDirective(OpenACCDirectiveKind::Data, - BeginLoc, DirLoc, EndLoc, - Clauses, StrBlock); + return getSema().OpenACC().ActOnEndStmtDirective( + OpenACCDirectiveKind::Data, BeginLoc, DirLoc, SourceLocation{}, + SourceLocation{}, {}, SourceLocation{}, EndLoc, Clauses, StrBlock); } StmtResult @@ -4125,7 +4128,8 @@ public: SourceLocation DirLoc, SourceLocation EndLoc, ArrayRef<OpenACCClause *> Clauses) { return getSema().OpenACC().ActOnEndStmtDirective( - OpenACCDirectiveKind::EnterData, BeginLoc, DirLoc, EndLoc, Clauses, {}); + OpenACCDirectiveKind::EnterData, BeginLoc, DirLoc, SourceLocation{}, + SourceLocation{}, {}, SourceLocation{}, EndLoc, Clauses, {}); } StmtResult @@ -4133,7 +4137,8 @@ public: SourceLocation DirLoc, SourceLocation EndLoc, ArrayRef<OpenACCClause *> Clauses) { return getSema().OpenACC().ActOnEndStmtDirective( - OpenACCDirectiveKind::ExitData, BeginLoc, DirLoc, EndLoc, Clauses, {}); + OpenACCDirectiveKind::ExitData, BeginLoc, DirLoc, SourceLocation{}, + SourceLocation{}, {}, SourceLocation{}, EndLoc, Clauses, {}); } StmtResult RebuildOpenACCHostDataConstruct(SourceLocation BeginLoc, @@ -4142,8 +4147,21 @@ public: ArrayRef<OpenACCClause *> Clauses, StmtResult StrBlock) { return getSema().OpenACC().ActOnEndStmtDirective( - OpenACCDirectiveKind::HostData, BeginLoc, DirLoc, EndLoc, Clauses, - StrBlock); + OpenACCDirectiveKind::HostData, BeginLoc, DirLoc, SourceLocation{}, + SourceLocation{}, {}, SourceLocation{}, EndLoc, Clauses, StrBlock); + } + + StmtResult RebuildOpenACCWaitConstruct( + SourceLocation BeginLoc, SourceLocation DirLoc, SourceLocation LParenLoc, + Expr *DevNumExpr, SourceLocation QueuesLoc, ArrayRef<Expr *> QueueIdExprs, + SourceLocation RParenLoc, SourceLocation EndLoc, + ArrayRef<OpenACCClause *> Clauses) { + llvm::SmallVector<Expr *> Exprs; + Exprs.push_back(DevNumExpr); + Exprs.insert(Exprs.end(), QueueIdExprs.begin(), QueueIdExprs.end()); + return getSema().OpenACC().ActOnEndStmtDirective( + OpenACCDirectiveKind::Wait, BeginLoc, DirLoc, LParenLoc, QueuesLoc, + Exprs, RParenLoc, EndLoc, Clauses, {}); } ExprResult RebuildOpenACCAsteriskSizeExpr(SourceLocation AsteriskLoc) { @@ -12331,6 +12349,50 @@ StmtResult TreeTransform<Derived>::TransformOpenACCHostDataConstruct( } template <typename Derived> +StmtResult +TreeTransform<Derived>::TransformOpenACCWaitConstruct(OpenACCWaitConstruct *C) { + getSema().OpenACC().ActOnConstruct(C->getDirectiveKind(), C->getBeginLoc()); + + ExprResult DevNumExpr; + if (C->hasDevNumExpr()) { + DevNumExpr = getDerived().TransformExpr(C->getDevNumExpr()); + + if (DevNumExpr.isUsable()) + DevNumExpr = getSema().OpenACC().ActOnIntExpr( + OpenACCDirectiveKind::Wait, OpenACCClauseKind::Invalid, + C->getBeginLoc(), DevNumExpr.get()); + } + + llvm::SmallVector<Expr *> QueueIdExprs; + + for (Expr *QE : C->getQueueIdExprs()) { + assert(QE && "Null queue id expr?"); + ExprResult NewEQ = getDerived().TransformExpr(QE); + + if (!NewEQ.isUsable()) + break; + NewEQ = getSema().OpenACC().ActOnIntExpr(OpenACCDirectiveKind::Wait, + OpenACCClauseKind::Invalid, + C->getBeginLoc(), NewEQ.get()); + if (NewEQ.isUsable()) + QueueIdExprs.push_back(NewEQ.get()); + } + + llvm::SmallVector<OpenACCClause *> TransformedClauses = + getDerived().TransformOpenACCClauseList(C->getDirectiveKind(), + C->clauses()); + + if (getSema().OpenACC().ActOnStartStmtDirective( + C->getDirectiveKind(), C->getBeginLoc(), TransformedClauses)) + return StmtError(); + + return getDerived().RebuildOpenACCWaitConstruct( + C->getBeginLoc(), C->getDirectiveLoc(), C->getLParenLoc(), + DevNumExpr.isUsable() ? DevNumExpr.get() : nullptr, C->getQueuesLoc(), + QueueIdExprs, C->getRParenLoc(), C->getEndLoc(), TransformedClauses); +} + +template <typename Derived> ExprResult TreeTransform<Derived>::TransformOpenACCAsteriskSizeExpr( OpenACCAsteriskSizeExpr *E) { if (getDerived().AlwaysRebuild()) diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp index 21ad6c5..8fe0412 100644 --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -2870,6 +2870,22 @@ void ASTStmtReader::VisitOpenACCHostDataConstruct(OpenACCHostDataConstruct *S) { VisitOpenACCAssociatedStmtConstruct(S); } +void ASTStmtReader::VisitOpenACCWaitConstruct(OpenACCWaitConstruct *S) { + VisitStmt(S); + // Consume the count of Expressions. + (void)Record.readInt(); + VisitOpenACCConstructStmt(S); + S->LParenLoc = Record.readSourceLocation(); + S->RParenLoc = Record.readSourceLocation(); + S->QueuesLoc = Record.readSourceLocation(); + + for (unsigned I = 0; I < S->NumExprs; ++I) { + S->getExprPtr()[I] = cast_if_present<Expr>(Record.readSubStmt()); + assert((I == 0 || S->getExprPtr()[I] != nullptr) && + "Only first expression should be null"); + } +} + //===----------------------------------------------------------------------===// // HLSL Constructs/Directives. //===----------------------------------------------------------------------===// @@ -4365,6 +4381,12 @@ Stmt *ASTReader::ReadStmtFromStream(ModuleFile &F) { S = OpenACCHostDataConstruct::CreateEmpty(Context, NumClauses); break; } + case STMT_OPENACC_WAIT_CONSTRUCT: { + unsigned NumExprs = Record[ASTStmtReader::NumStmtFields]; + unsigned NumClauses = Record[ASTStmtReader::NumStmtFields + 1]; + S = OpenACCWaitConstruct::CreateEmpty(Context, NumExprs, NumClauses); + break; + } case EXPR_REQUIRES: { unsigned numLocalParameters = Record[ASTStmtReader::NumExprFields]; unsigned numRequirement = Record[ASTStmtReader::NumExprFields + 1]; diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp index e55cbe1..f13443d 100644 --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -2951,6 +2951,20 @@ void ASTStmtWriter::VisitOpenACCHostDataConstruct(OpenACCHostDataConstruct *S) { Code = serialization::STMT_OPENACC_HOST_DATA_CONSTRUCT; } +void ASTStmtWriter::VisitOpenACCWaitConstruct(OpenACCWaitConstruct *S) { + VisitStmt(S); + Record.push_back(S->getExprs().size()); + VisitOpenACCConstructStmt(S); + Record.AddSourceLocation(S->LParenLoc); + Record.AddSourceLocation(S->RParenLoc); + Record.AddSourceLocation(S->QueuesLoc); + + for(Expr *E : S->getExprs()) + Record.AddStmt(E); + + Code = serialization::STMT_OPENACC_WAIT_CONSTRUCT; +} + //===----------------------------------------------------------------------===// // HLSL Constructs/Directives. //===----------------------------------------------------------------------===// diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp index ae43c595..0a74a80 100644 --- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp +++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp @@ -1829,6 +1829,7 @@ void ExprEngine::Visit(const Stmt *S, ExplodedNode *Pred, case Stmt::OpenACCEnterDataConstructClass: case Stmt::OpenACCExitDataConstructClass: case Stmt::OpenACCHostDataConstructClass: + case Stmt::OpenACCWaitConstructClass: case Stmt::OMPUnrollDirectiveClass: case Stmt::OMPMetaDirectiveClass: case Stmt::HLSLOutArgExprClass: { |