diff options
Diffstat (limited to 'clang/lib/Sema')
-rw-r--r-- | clang/lib/Sema/SemaOpenMP.cpp | 197 |
1 files changed, 140 insertions, 57 deletions
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index 6110e52..bab61e8 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -9815,6 +9815,25 @@ static Stmt *buildPreInits(ASTContext &Context, return nullptr; } +/// Append the \p Item or the content of a CompoundStmt to the list \p +/// TargetList. +/// +/// A CompoundStmt is used as container in case multiple statements need to be +/// stored in lieu of using an explicit list. Flattening is necessary because +/// contained DeclStmts need to be visible after the execution of the list. Used +/// for OpenMP pre-init declarations/statements. +static void appendFlattendedStmtList(SmallVectorImpl<Stmt *> &TargetList, + Stmt *Item) { + // nullptr represents an empty list. + if (!Item) + return; + + if (auto *CS = dyn_cast<CompoundStmt>(Item)) + llvm::append_range(TargetList, CS->body()); + else + TargetList.push_back(Item); +} + /// Build preinits statement for the given declarations. static Stmt * buildPreInits(ASTContext &Context, @@ -9828,6 +9847,17 @@ buildPreInits(ASTContext &Context, return nullptr; } +/// Build pre-init statement for the given statements. +static Stmt *buildPreInits(ASTContext &Context, ArrayRef<Stmt *> PreInits) { + if (PreInits.empty()) + return nullptr; + + SmallVector<Stmt *> Stmts; + for (Stmt *S : PreInits) + appendFlattendedStmtList(Stmts, S); + return CompoundStmt::Create(Context, PreInits, FPOptionsOverride(), {}, {}); +} + /// Build postupdate expression for the given list of postupdates expressions. static Expr *buildPostUpdate(Sema &S, ArrayRef<Expr *> PostUpdates) { Expr *PostUpdate = nullptr; @@ -9924,11 +9954,21 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr, Stmt *DependentPreInits = Transform->getPreInits(); if (!DependentPreInits) return; - for (Decl *C : cast<DeclStmt>(DependentPreInits)->getDeclGroup()) { - auto *D = cast<VarDecl>(C); - DeclRefExpr *Ref = buildDeclRefExpr(SemaRef, D, D->getType(), - Transform->getBeginLoc()); - Captures[Ref] = Ref; + + // Search for pre-init declared variables that need to be captured + // to be referenceable inside the directive. + SmallVector<Stmt *> Constituents; + appendFlattendedStmtList(Constituents, DependentPreInits); + for (Stmt *S : Constituents) { + if (auto *DC = dyn_cast<DeclStmt>(S)) { + for (Decl *C : DC->decls()) { + auto *D = cast<VarDecl>(C); + DeclRefExpr *Ref = buildDeclRefExpr( + SemaRef, D, D->getType().getNonReferenceType(), + Transform->getBeginLoc()); + Captures[Ref] = Ref; + } + } } })) return 0; @@ -15059,9 +15099,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeSimdDirective( bool SemaOpenMP::checkTransformableLoopNest( OpenMPDirectiveKind Kind, Stmt *AStmt, int NumLoops, SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers, - Stmt *&Body, - SmallVectorImpl<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>> - &OriginalInits) { + Stmt *&Body, SmallVectorImpl<SmallVector<Stmt *, 0>> &OriginalInits) { OriginalInits.emplace_back(); bool Result = OMPLoopBasedDirective::doForAllLoops( AStmt->IgnoreContainers(), /*TryImperfectlyNestedLoops=*/false, NumLoops, @@ -15095,16 +15133,70 @@ bool SemaOpenMP::checkTransformableLoopNest( DependentPreInits = Dir->getPreInits(); else llvm_unreachable("Unhandled loop transformation"); - if (!DependentPreInits) - return; - llvm::append_range(OriginalInits.back(), - cast<DeclStmt>(DependentPreInits)->getDeclGroup()); + + appendFlattendedStmtList(OriginalInits.back(), DependentPreInits); }); assert(OriginalInits.back().empty() && "No preinit after innermost loop"); OriginalInits.pop_back(); return Result; } +/// Add preinit statements that need to be propageted from the selected loop. +static void addLoopPreInits(ASTContext &Context, + OMPLoopBasedDirective::HelperExprs &LoopHelper, + Stmt *LoopStmt, ArrayRef<Stmt *> OriginalInit, + SmallVectorImpl<Stmt *> &PreInits) { + + // For range-based for-statements, ensure that their syntactic sugar is + // executed by adding them as pre-init statements. + if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt)) { + Stmt *RangeInit = CXXRangeFor->getInit(); + if (RangeInit) + PreInits.push_back(RangeInit); + + DeclStmt *RangeStmt = CXXRangeFor->getRangeStmt(); + PreInits.push_back(new (Context) DeclStmt(RangeStmt->getDeclGroup(), + RangeStmt->getBeginLoc(), + RangeStmt->getEndLoc())); + + DeclStmt *RangeEnd = CXXRangeFor->getEndStmt(); + PreInits.push_back(new (Context) DeclStmt(RangeEnd->getDeclGroup(), + RangeEnd->getBeginLoc(), + RangeEnd->getEndLoc())); + } + + llvm::append_range(PreInits, OriginalInit); + + // List of OMPCapturedExprDecl, for __begin, __end, and NumIterations + if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits)) { + PreInits.push_back(new (Context) DeclStmt( + PI->getDeclGroup(), PI->getBeginLoc(), PI->getEndLoc())); + } + + // Gather declarations for the data members used as counters. + for (Expr *CounterRef : LoopHelper.Counters) { + auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl(); + if (isa<OMPCapturedExprDecl>(CounterDecl)) + PreInits.push_back(new (Context) DeclStmt( + DeclGroupRef(CounterDecl), SourceLocation(), SourceLocation())); + } +} + +/// Collect the loop statements (ForStmt or CXXRangeForStmt) of the affected +/// loop of a construct. +static void collectLoopStmts(Stmt *AStmt, MutableArrayRef<Stmt *> LoopStmts) { + size_t NumLoops = LoopStmts.size(); + OMPLoopBasedDirective::doForAllLoops( + AStmt, /*TryImperfectlyNestedLoops=*/false, NumLoops, + [LoopStmts](unsigned Cnt, Stmt *CurStmt) { + assert(!LoopStmts[Cnt] && "Loop statement must not yet be assigned"); + LoopStmts[Cnt] = CurStmt; + return false; + }); + assert(!is_contained(LoopStmts, nullptr) && + "Expecting a loop statement for each affected loop"); +} + StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses, Stmt *AStmt, SourceLocation StartLoc, @@ -15126,8 +15218,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses, // Verify and diagnose loop nest. SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops); Stmt *Body = nullptr; - SmallVector<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>, 4> - OriginalInits; + SmallVector<SmallVector<Stmt *, 0>, 4> OriginalInits; if (!checkTransformableLoopNest(OMPD_tile, AStmt, NumLoops, LoopHelpers, Body, OriginalInits)) return StmtError(); @@ -15144,7 +15235,11 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses, "Expecting loop iteration space dimensionality to match number of " "affected loops"); - SmallVector<Decl *, 4> PreInits; + // Collect all affected loop statements. + SmallVector<Stmt *> LoopStmts(NumLoops, nullptr); + collectLoopStmts(AStmt, LoopStmts); + + SmallVector<Stmt *, 4> PreInits; CaptureVars CopyTransformer(SemaRef); // Create iteration variables for the generated loops. @@ -15184,20 +15279,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses, &SemaRef.PP.getIdentifierTable().get(TileCntName)); TileIndVars[I] = TileCntDecl; } - for (auto &P : OriginalInits[I]) { - if (auto *D = P.dyn_cast<Decl *>()) - PreInits.push_back(D); - else if (auto *PI = dyn_cast_or_null<DeclStmt>(P.dyn_cast<Stmt *>())) - PreInits.append(PI->decl_begin(), PI->decl_end()); - } - if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits)) - PreInits.append(PI->decl_begin(), PI->decl_end()); - // Gather declarations for the data members used as counters. - for (Expr *CounterRef : LoopHelper.Counters) { - auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl(); - if (isa<OMPCapturedExprDecl>(CounterDecl)) - PreInits.push_back(CounterDecl); - } + + addLoopPreInits(Context, LoopHelper, LoopStmts[I], OriginalInits[I], + PreInits); } // Once the original iteration values are set, append the innermost body. @@ -15246,19 +15330,20 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses, OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers[I]; Expr *NumIterations = LoopHelper.NumIterations; auto *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters[0]); - QualType CntTy = OrigCntVar->getType(); + QualType IVTy = NumIterations->getType(); + Stmt *LoopStmt = LoopStmts[I]; // Commonly used variables. One of the constraints of an AST is that every // node object must appear at most once, hence we define lamdas that create // a new AST node at every use. - auto MakeTileIVRef = [&SemaRef = this->SemaRef, &TileIndVars, I, CntTy, + auto MakeTileIVRef = [&SemaRef = this->SemaRef, &TileIndVars, I, IVTy, OrigCntVar]() { - return buildDeclRefExpr(SemaRef, TileIndVars[I], CntTy, + return buildDeclRefExpr(SemaRef, TileIndVars[I], IVTy, OrigCntVar->getExprLoc()); }; - auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, CntTy, + auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, IVTy, OrigCntVar]() { - return buildDeclRefExpr(SemaRef, FloorIndVars[I], CntTy, + return buildDeclRefExpr(SemaRef, FloorIndVars[I], IVTy, OrigCntVar->getExprLoc()); }; @@ -15320,6 +15405,8 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses, // further into the inner loop. SmallVector<Stmt *, 4> BodyParts; BodyParts.append(LoopHelper.Updates.begin(), LoopHelper.Updates.end()); + if (auto *SourceCXXFor = dyn_cast<CXXForRangeStmt>(LoopStmt)) + BodyParts.push_back(SourceCXXFor->getLoopVarStmt()); BodyParts.push_back(Inner); Inner = CompoundStmt::Create(Context, BodyParts, FPOptionsOverride(), Inner->getBeginLoc(), Inner->getEndLoc()); @@ -15334,12 +15421,14 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses, auto &LoopHelper = LoopHelpers[I]; Expr *NumIterations = LoopHelper.NumIterations; DeclRefExpr *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters[0]); - QualType CntTy = OrigCntVar->getType(); + QualType IVTy = NumIterations->getType(); - // Commonly used variables. - auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, CntTy, + // Commonly used variables. One of the constraints of an AST is that every + // node object must appear at most once, hence we define lamdas that create + // a new AST node at every use. + auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, IVTy, OrigCntVar]() { - return buildDeclRefExpr(SemaRef, FloorIndVars[I], CntTy, + return buildDeclRefExpr(SemaRef, FloorIndVars[I], IVTy, OrigCntVar->getExprLoc()); }; @@ -15405,8 +15494,7 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses, Stmt *Body = nullptr; SmallVector<OMPLoopBasedDirective::HelperExprs, NumLoops> LoopHelpers( NumLoops); - SmallVector<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>, NumLoops + 1> - OriginalInits; + SmallVector<SmallVector<Stmt *, 0>, NumLoops + 1> OriginalInits; if (!checkTransformableLoopNest(OMPD_unroll, AStmt, NumLoops, LoopHelpers, Body, OriginalInits)) return StmtError(); @@ -15418,6 +15506,10 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses, return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, NumGeneratedLoops, nullptr, nullptr); + assert(LoopHelpers.size() == NumLoops && + "Expecting a single-dimensional loop iteration space"); + assert(OriginalInits.size() == NumLoops && + "Expecting a single-dimensional loop iteration space"); OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front(); if (FullClause) { @@ -15481,24 +15573,13 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses, // of a canonical loop nest where these PreInits are emitted before the // outermost directive. + // Find the loop statement. + Stmt *LoopStmt = nullptr; + collectLoopStmts(AStmt, {LoopStmt}); + // Determine the PreInit declarations. - SmallVector<Decl *, 4> PreInits; - assert(OriginalInits.size() == 1 && - "Expecting a single-dimensional loop iteration space"); - for (auto &P : OriginalInits[0]) { - if (auto *D = P.dyn_cast<Decl *>()) - PreInits.push_back(D); - else if (auto *PI = dyn_cast_or_null<DeclStmt>(P.dyn_cast<Stmt *>())) - PreInits.append(PI->decl_begin(), PI->decl_end()); - } - if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits)) - PreInits.append(PI->decl_begin(), PI->decl_end()); - // Gather declarations for the data members used as counters. - for (Expr *CounterRef : LoopHelper.Counters) { - auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl(); - if (isa<OMPCapturedExprDecl>(CounterDecl)) - PreInits.push_back(CounterDecl); - } + SmallVector<Stmt *, 4> PreInits; + addLoopPreInits(Context, LoopHelper, LoopStmt, OriginalInits[0], PreInits); auto *IterationVarRef = cast<DeclRefExpr>(LoopHelper.IterationVarRef); QualType IVTy = IterationVarRef->getType(); @@ -15604,6 +15685,8 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses, // Inner For statement. SmallVector<Stmt *> InnerBodyStmts; InnerBodyStmts.append(LoopHelper.Updates.begin(), LoopHelper.Updates.end()); + if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt)) + InnerBodyStmts.push_back(CXXRangeFor->getLoopVarStmt()); InnerBodyStmts.push_back(Body); CompoundStmt *InnerBody = CompoundStmt::Create(getASTContext(), InnerBodyStmts, FPOptionsOverride(), |