aboutsummaryrefslogtreecommitdiff
path: root/clang/lib/Sema/SemaOpenMP.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'clang/lib/Sema/SemaOpenMP.cpp')
-rw-r--r--clang/lib/Sema/SemaOpenMP.cpp833
1 files changed, 804 insertions, 29 deletions
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 48e06d1..0fa21e8 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -2490,7 +2490,8 @@ VarDecl *SemaOpenMP::isOpenMPCapturedDecl(ValueDecl *D, bool CheckScopeInfo,
DSAStackTy::DSAVarData DVarTop =
DSAStack->getTopDSA(D, DSAStack->isClauseParsingMode());
if (DVarTop.CKind != OMPC_unknown && isOpenMPPrivate(DVarTop.CKind) &&
- (!VD || VD->hasLocalStorage() || !DVarTop.AppliedToPointee))
+ (!VD || VD->hasLocalStorage() ||
+ !(DVarTop.AppliedToPointee && DVarTop.CKind != OMPC_reduction)))
return VD ? VD : cast<VarDecl>(DVarTop.PrivateCopy->getDecl());
// Threadprivate variables must not be captured.
if (isOpenMPThreadPrivate(DVarTop.CKind))
@@ -4569,6 +4570,7 @@ void SemaOpenMP::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind,
case OMPD_unroll:
case OMPD_reverse:
case OMPD_interchange:
+ case OMPD_fuse:
case OMPD_assume:
break;
default:
@@ -6410,6 +6412,10 @@ StmtResult SemaOpenMP::ActOnOpenMPExecutableDirective(
Res = ActOnOpenMPInterchangeDirective(ClausesWithImplicit, AStmt, StartLoc,
EndLoc);
break;
+ case OMPD_fuse:
+ Res =
+ ActOnOpenMPFuseDirective(ClausesWithImplicit, AStmt, StartLoc, EndLoc);
+ break;
case OMPD_for:
Res = ActOnOpenMPForDirective(ClausesWithImplicit, AStmt, StartLoc, EndLoc,
VarsWithInheritedDSA);
@@ -9488,7 +9494,9 @@ static bool checkOpenMPIterationSpace(
// sharing attributes.
VarsWithImplicitDSA.erase(LCDecl);
- assert(isOpenMPLoopDirective(DKind) && "DSA for non-loop vars");
+ assert((isOpenMPLoopDirective(DKind) ||
+ isOpenMPCanonicalLoopSequenceTransformationDirective(DKind)) &&
+ "DSA for non-loop vars");
// Check test-expr.
HasErrors |= ISC.checkAndSetCond(For ? For->getCond() : CXXFor->getCond());
@@ -9916,7 +9924,8 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr,
unsigned NumLoops = std::max(OrderedLoopCount, NestedLoopCount);
SmallVector<LoopIterationSpace, 4> IterSpaces(NumLoops);
if (!OMPLoopBasedDirective::doForAllLoops(
- AStmt->IgnoreContainers(!isOpenMPLoopTransformationDirective(DKind)),
+ AStmt->IgnoreContainers(
+ !isOpenMPCanonicalLoopNestTransformationDirective(DKind)),
SupportsNonPerfectlyNested, NumLoops,
[DKind, &SemaRef, &DSA, NumLoops, NestedLoopCount,
CollapseLoopCountExpr, OrderedLoopCountExpr, &VarsWithImplicitDSA,
@@ -9938,8 +9947,7 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr,
}
return false;
},
- [&SemaRef,
- &Captures](OMPCanonicalLoopNestTransformationDirective *Transform) {
+ [&SemaRef, &Captures](OMPLoopTransformationDirective *Transform) {
Stmt *DependentPreInits = Transform->getPreInits();
if (!DependentPreInits)
return;
@@ -9954,7 +9962,8 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr,
auto *D = cast<VarDecl>(C);
DeclRefExpr *Ref = buildDeclRefExpr(
SemaRef, D, D->getType().getNonReferenceType(),
- Transform->getBeginLoc());
+ cast<OMPExecutableDirective>(Transform->getDirective())
+ ->getBeginLoc());
Captures[Ref] = Ref;
}
}
@@ -14404,10 +14413,34 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeSimdDirective(
getASTContext(), StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B);
}
+/// Updates OriginalInits by checking Transform against loop transformation
+/// directives and appending their pre-inits if a match is found.
+static void updatePreInits(OMPLoopTransformationDirective *Transform,
+ SmallVectorImpl<Stmt *> &PreInits) {
+ Stmt *Dir = Transform->getDirective();
+ switch (Dir->getStmtClass()) {
+#define STMT(CLASS, PARENT)
+#define ABSTRACT_STMT(CLASS)
+#define COMMON_OMP_LOOP_TRANSFORMATION(CLASS, PARENT) \
+ case Stmt::CLASS##Class: \
+ appendFlattenedStmtList(PreInits, \
+ static_cast<const CLASS *>(Dir)->getPreInits()); \
+ break;
+#define OMPCANONICALLOOPNESTTRANSFORMATIONDIRECTIVE(CLASS, PARENT) \
+ COMMON_OMP_LOOP_TRANSFORMATION(CLASS, PARENT)
+#define OMPCANONICALLOOPSEQUENCETRANSFORMATIONDIRECTIVE(CLASS, PARENT) \
+ COMMON_OMP_LOOP_TRANSFORMATION(CLASS, PARENT)
+#include "clang/AST/StmtNodes.inc"
+#undef COMMON_OMP_LOOP_TRANSFORMATION
+ default:
+ llvm_unreachable("Not a loop transformation");
+ }
+}
+
bool SemaOpenMP::checkTransformableLoopNest(
OpenMPDirectiveKind Kind, Stmt *AStmt, int NumLoops,
SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers,
- Stmt *&Body, SmallVectorImpl<SmallVector<Stmt *, 0>> &OriginalInits) {
+ Stmt *&Body, SmallVectorImpl<SmallVector<Stmt *>> &OriginalInits) {
OriginalInits.emplace_back();
bool Result = OMPLoopBasedDirective::doForAllLoops(
AStmt->IgnoreContainers(), /*TryImperfectlyNestedLoops=*/false, NumLoops,
@@ -14433,29 +14466,268 @@ bool SemaOpenMP::checkTransformableLoopNest(
OriginalInits.emplace_back();
return false;
},
- [&OriginalInits](OMPLoopBasedDirective *Transform) {
- Stmt *DependentPreInits;
- if (auto *Dir = dyn_cast<OMPTileDirective>(Transform))
- DependentPreInits = Dir->getPreInits();
- else if (auto *Dir = dyn_cast<OMPStripeDirective>(Transform))
- DependentPreInits = Dir->getPreInits();
- else if (auto *Dir = dyn_cast<OMPUnrollDirective>(Transform))
- DependentPreInits = Dir->getPreInits();
- else if (auto *Dir = dyn_cast<OMPReverseDirective>(Transform))
- DependentPreInits = Dir->getPreInits();
- else if (auto *Dir = dyn_cast<OMPInterchangeDirective>(Transform))
- DependentPreInits = Dir->getPreInits();
- else
- llvm_unreachable("Unhandled loop transformation");
-
- appendFlattenedStmtList(OriginalInits.back(), DependentPreInits);
+ [&OriginalInits](OMPLoopTransformationDirective *Transform) {
+ updatePreInits(Transform, OriginalInits.back());
});
assert(OriginalInits.back().empty() && "No preinit after innermost loop");
OriginalInits.pop_back();
return Result;
}
-/// Add preinit statements that need to be propageted from the selected loop.
+/// Counts the total number of OpenMP canonical nested loops, including the
+/// outermost loop (the original loop). PRECONDITION of this visitor is that it
+/// must be invoked from the original loop to be analyzed. The traversal stops
+/// for Decl's and Expr's given that they may contain inner loops that must not
+/// be counted.
+///
+/// Example AST structure for the code:
+///
+/// int main() {
+/// #pragma omp fuse
+/// {
+/// for (int i = 0; i < 100; i++) { <-- Outer loop
+/// []() {
+/// for(int j = 0; j < 100; j++) {} <-- NOT A LOOP (1)
+/// };
+/// for(int j = 0; j < 5; ++j) {} <-- Inner loop
+/// }
+/// for (int r = 0; i < 100; i++) { <-- Outer loop
+/// struct LocalClass {
+/// void bar() {
+/// for(int j = 0; j < 100; j++) {} <-- NOT A LOOP (2)
+/// }
+/// };
+/// for(int k = 0; k < 10; ++k) {} <-- Inner loop
+/// {x = 5; for(k = 0; k < 10; ++k) x += k; x}; <-- NOT A LOOP (3)
+/// }
+/// }
+/// }
+/// (1) because in a different function (here: a lambda)
+/// (2) because in a different function (here: class method)
+/// (3) because considered to be intervening-code of non-perfectly nested loop
+/// Result: Loop 'i' contains 2 loops, Loop 'r' also contains 2 loops.
+class NestedLoopCounterVisitor final : public DynamicRecursiveASTVisitor {
+private:
+ unsigned NestedLoopCount = 0;
+
+public:
+ explicit NestedLoopCounterVisitor() = default;
+
+ unsigned getNestedLoopCount() const { return NestedLoopCount; }
+
+ bool VisitForStmt(ForStmt *FS) override {
+ ++NestedLoopCount;
+ return true;
+ }
+
+ bool VisitCXXForRangeStmt(CXXForRangeStmt *FRS) override {
+ ++NestedLoopCount;
+ return true;
+ }
+
+ bool TraverseStmt(Stmt *S) override {
+ if (!S)
+ return true;
+
+ // Skip traversal of all expressions, including special cases like
+ // LambdaExpr, StmtExpr, BlockExpr, and RequiresExpr. These expressions
+ // may contain inner statements (and even loops), but they are not part
+ // of the syntactic body of the surrounding loop structure.
+ // Therefore must not be counted.
+ if (isa<Expr>(S))
+ return true;
+
+ // Only recurse into CompoundStmt (block {}) and loop bodies.
+ if (isa<CompoundStmt, ForStmt, CXXForRangeStmt>(S)) {
+ return DynamicRecursiveASTVisitor::TraverseStmt(S);
+ }
+
+ // Stop traversal of the rest of statements, that break perfect
+ // loop nesting, such as control flow (IfStmt, SwitchStmt...).
+ return true;
+ }
+
+ bool TraverseDecl(Decl *D) override {
+ // Stop in the case of finding a declaration, it is not important
+ // in order to find nested loops (Possible CXXRecordDecl, RecordDecl,
+ // FunctionDecl...).
+ return true;
+ }
+};
+
+bool SemaOpenMP::analyzeLoopSequence(Stmt *LoopSeqStmt,
+ LoopSequenceAnalysis &SeqAnalysis,
+ ASTContext &Context,
+ OpenMPDirectiveKind Kind) {
+ VarsWithInheritedDSAType TmpDSA;
+ // Helper Lambda to handle storing initialization and body statements for
+ // both ForStmt and CXXForRangeStmt.
+ auto StoreLoopStatements = [](LoopAnalysis &Analysis, Stmt *LoopStmt) {
+ if (auto *For = dyn_cast<ForStmt>(LoopStmt)) {
+ Analysis.OriginalInits.push_back(For->getInit());
+ Analysis.TheForStmt = For;
+ } else {
+ auto *CXXFor = cast<CXXForRangeStmt>(LoopStmt);
+ Analysis.OriginalInits.push_back(CXXFor->getBeginStmt());
+ Analysis.TheForStmt = CXXFor;
+ }
+ };
+
+ // Helper lambda functions to encapsulate the processing of different
+ // derivations of the canonical loop sequence grammar
+ // Modularized code for handling loop generation and transformations.
+ auto AnalyzeLoopGeneration = [&](Stmt *Child) {
+ auto *LoopTransform = cast<OMPLoopTransformationDirective>(Child);
+ Stmt *TransformedStmt = LoopTransform->getTransformedStmt();
+ unsigned NumGeneratedTopLevelLoops =
+ LoopTransform->getNumGeneratedTopLevelLoops();
+ // Handle the case where transformed statement is not available due to
+ // dependent contexts
+ if (!TransformedStmt) {
+ if (NumGeneratedTopLevelLoops > 0) {
+ SeqAnalysis.LoopSeqSize += NumGeneratedTopLevelLoops;
+ return true;
+ }
+ // Unroll full (0 loops produced)
+ Diag(Child->getBeginLoc(), diag::err_omp_not_for)
+ << 0 << getOpenMPDirectiveName(Kind);
+ return false;
+ }
+ // Handle loop transformations with multiple loop nests
+ // Unroll full
+ if (!NumGeneratedTopLevelLoops) {
+ Diag(Child->getBeginLoc(), diag::err_omp_not_for)
+ << 0 << getOpenMPDirectiveName(Kind);
+ return false;
+ }
+ // Loop transformatons such as split or loopranged fuse
+ if (NumGeneratedTopLevelLoops > 1) {
+ // Get the preinits related to this loop sequence generating
+ // loop transformation (i.e loopranged fuse, split...)
+ // These preinits differ slightly from regular inits/pre-inits related
+ // to single loop generating loop transformations (interchange, unroll)
+ // given that they are not bounded to a particular loop nest
+ // so they need to be treated independently
+ updatePreInits(LoopTransform, SeqAnalysis.LoopSequencePreInits);
+ return analyzeLoopSequence(TransformedStmt, SeqAnalysis, Context, Kind);
+ }
+ // Vast majority: (Tile, Unroll, Stripe, Reverse, Interchange, Fuse all)
+ // Process the transformed loop statement
+ LoopAnalysis &NewTransformedSingleLoop =
+ SeqAnalysis.Loops.emplace_back(Child);
+ unsigned IsCanonical = checkOpenMPLoop(
+ Kind, nullptr, nullptr, TransformedStmt, SemaRef, *DSAStack, TmpDSA,
+ NewTransformedSingleLoop.HelperExprs);
+
+ if (!IsCanonical)
+ return false;
+
+ StoreLoopStatements(NewTransformedSingleLoop, TransformedStmt);
+ updatePreInits(LoopTransform, NewTransformedSingleLoop.TransformsPreInits);
+
+ SeqAnalysis.LoopSeqSize++;
+ return true;
+ };
+
+ // Modularized code for handling regular canonical loops.
+ auto AnalyzeRegularLoop = [&](Stmt *Child) {
+ LoopAnalysis &NewRegularLoop = SeqAnalysis.Loops.emplace_back(Child);
+ unsigned IsCanonical =
+ checkOpenMPLoop(Kind, nullptr, nullptr, Child, SemaRef, *DSAStack,
+ TmpDSA, NewRegularLoop.HelperExprs);
+
+ if (!IsCanonical)
+ return false;
+
+ StoreLoopStatements(NewRegularLoop, Child);
+ NestedLoopCounterVisitor NLCV;
+ NLCV.TraverseStmt(Child);
+ return true;
+ };
+
+ // High level grammar validation.
+ for (Stmt *Child : LoopSeqStmt->children()) {
+ if (!Child)
+ continue;
+ // Skip over non-loop-sequence statements.
+ if (!LoopSequenceAnalysis::isLoopSequenceDerivation(Child)) {
+ Child = Child->IgnoreContainers();
+ // Ignore empty compound statement.
+ if (!Child)
+ continue;
+ // In the case of a nested loop sequence ignoring containers would not
+ // be enough, a recurisve transversal of the loop sequence is required.
+ if (isa<CompoundStmt>(Child)) {
+ if (!analyzeLoopSequence(Child, SeqAnalysis, Context, Kind))
+ return false;
+ // Already been treated, skip this children
+ continue;
+ }
+ }
+ // Regular loop sequence handling.
+ if (LoopSequenceAnalysis::isLoopSequenceDerivation(Child)) {
+ if (LoopAnalysis::isLoopTransformation(Child)) {
+ if (!AnalyzeLoopGeneration(Child))
+ return false;
+ // AnalyzeLoopGeneration updates SeqAnalysis.LoopSeqSize accordingly.
+ } else {
+ if (!AnalyzeRegularLoop(Child))
+ return false;
+ SeqAnalysis.LoopSeqSize++;
+ }
+ } else {
+ // Report error for invalid statement inside canonical loop sequence.
+ Diag(Child->getBeginLoc(), diag::err_omp_not_for)
+ << 0 << getOpenMPDirectiveName(Kind);
+ return false;
+ }
+ }
+ return true;
+}
+
+bool SemaOpenMP::checkTransformableLoopSequence(
+ OpenMPDirectiveKind Kind, Stmt *AStmt, LoopSequenceAnalysis &SeqAnalysis,
+ ASTContext &Context) {
+ // Following OpenMP 6.0 API Specification, a Canonical Loop Sequence follows
+ // the grammar:
+ //
+ // canonical-loop-sequence:
+ // {
+ // loop-sequence+
+ // }
+ // where loop-sequence can be any of the following:
+ // 1. canonical-loop-sequence
+ // 2. loop-nest
+ // 3. loop-sequence-generating-construct (i.e OMPLoopTransformationDirective)
+ //
+ // To recognise and traverse this structure the helper function
+ // analyzeLoopSequence serves as the recurisve entry point
+ // and tries to match the input AST to the canonical loop sequence grammar
+ // structure. This function will perform both a semantic and syntactical
+ // analysis of the given statement according to OpenMP 6.0 definition of
+ // the aforementioned canonical loop sequence.
+
+ // We expect an outer compound statement.
+ if (!isa<CompoundStmt>(AStmt)) {
+ Diag(AStmt->getBeginLoc(), diag::err_omp_not_a_loop_sequence)
+ << getOpenMPDirectiveName(Kind);
+ return false;
+ }
+
+ // Recursive entry point to process the main loop sequence
+ if (!analyzeLoopSequence(AStmt, SeqAnalysis, Context, Kind))
+ return false;
+
+ // Diagnose an empty loop sequence.
+ if (!SeqAnalysis.LoopSeqSize) {
+ Diag(AStmt->getBeginLoc(), diag::err_omp_empty_loop_sequence)
+ << getOpenMPDirectiveName(Kind);
+ return false;
+ }
+ return true;
+}
+
+/// Add preinit statements that need to be propagated from the selected loop.
static void addLoopPreInits(ASTContext &Context,
OMPLoopBasedDirective::HelperExprs &LoopHelper,
Stmt *LoopStmt, ArrayRef<Stmt *> OriginalInit,
@@ -14540,7 +14812,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
// Verify and diagnose loop nest.
SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops);
Stmt *Body = nullptr;
- SmallVector<SmallVector<Stmt *, 0>, 4> OriginalInits;
+ SmallVector<SmallVector<Stmt *>, 4> OriginalInits;
if (!checkTransformableLoopNest(OMPD_tile, AStmt, NumLoops, LoopHelpers, Body,
OriginalInits))
return StmtError();
@@ -14817,7 +15089,7 @@ StmtResult SemaOpenMP::ActOnOpenMPStripeDirective(ArrayRef<OMPClause *> Clauses,
// Verify and diagnose loop nest.
SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops);
Stmt *Body = nullptr;
- SmallVector<SmallVector<Stmt *, 0>, 4> OriginalInits;
+ SmallVector<SmallVector<Stmt *>, 4> OriginalInits;
if (!checkTransformableLoopNest(OMPD_stripe, AStmt, NumLoops, LoopHelpers,
Body, OriginalInits))
return StmtError();
@@ -15078,7 +15350,7 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
Stmt *Body = nullptr;
SmallVector<OMPLoopBasedDirective::HelperExprs, NumLoops> LoopHelpers(
NumLoops);
- SmallVector<SmallVector<Stmt *, 0>, NumLoops + 1> OriginalInits;
+ SmallVector<SmallVector<Stmt *>, NumLoops + 1> OriginalInits;
if (!checkTransformableLoopNest(OMPD_unroll, AStmt, NumLoops, LoopHelpers,
Body, OriginalInits))
return StmtError();
@@ -15348,7 +15620,7 @@ StmtResult SemaOpenMP::ActOnOpenMPReverseDirective(Stmt *AStmt,
Stmt *Body = nullptr;
SmallVector<OMPLoopBasedDirective::HelperExprs, NumLoops> LoopHelpers(
NumLoops);
- SmallVector<SmallVector<Stmt *, 0>, NumLoops + 1> OriginalInits;
+ SmallVector<SmallVector<Stmt *>, NumLoops + 1> OriginalInits;
if (!checkTransformableLoopNest(OMPD_reverse, AStmt, NumLoops, LoopHelpers,
Body, OriginalInits))
return StmtError();
@@ -15540,7 +15812,7 @@ StmtResult SemaOpenMP::ActOnOpenMPInterchangeDirective(
// Verify and diagnose loop nest.
SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops);
Stmt *Body = nullptr;
- SmallVector<SmallVector<Stmt *, 0>, 2> OriginalInits;
+ SmallVector<SmallVector<Stmt *>, 2> OriginalInits;
if (!checkTransformableLoopNest(OMPD_interchange, AStmt, NumLoops,
LoopHelpers, Body, OriginalInits))
return StmtError();
@@ -15716,6 +15988,484 @@ StmtResult SemaOpenMP::ActOnOpenMPInterchangeDirective(
buildPreInits(Context, PreInits));
}
+StmtResult SemaOpenMP::ActOnOpenMPFuseDirective(ArrayRef<OMPClause *> Clauses,
+ Stmt *AStmt,
+ SourceLocation StartLoc,
+ SourceLocation EndLoc) {
+
+ ASTContext &Context = getASTContext();
+ DeclContext *CurrContext = SemaRef.CurContext;
+ Scope *CurScope = SemaRef.getCurScope();
+ CaptureVars CopyTransformer(SemaRef);
+
+ // Ensure the structured block is not empty
+ if (!AStmt)
+ return StmtError();
+
+ // Defer transformation in dependent contexts
+ // The NumLoopNests argument is set to a placeholder 1 (even though
+ // using looprange fuse could yield up to 3 top level loop nests)
+ // because a dependent context could prevent determining its true value
+ if (CurrContext->isDependentContext())
+ return OMPFuseDirective::Create(Context, StartLoc, EndLoc, Clauses,
+ /* NumLoops */ 1, AStmt, nullptr, nullptr);
+
+ // Validate that the potential loop sequence is transformable for fusion
+ // Also collect the HelperExprs, Loop Stmts, Inits, and Number of loops
+ LoopSequenceAnalysis SeqAnalysis;
+ if (!checkTransformableLoopSequence(OMPD_fuse, AStmt, SeqAnalysis, Context))
+ return StmtError();
+
+ // SeqAnalysis.LoopSeqSize exists mostly to handle dependent contexts,
+ // otherwise it must be the same as SeqAnalysis.Loops.size().
+ assert(SeqAnalysis.LoopSeqSize == SeqAnalysis.Loops.size() &&
+ "Inconsistent size of the loop sequence and the number of loops "
+ "found in the sequence");
+
+ // Handle clauses, which can be any of the following: [looprange, apply]
+ const auto *LRC =
+ OMPExecutableDirective::getSingleClause<OMPLoopRangeClause>(Clauses);
+
+ // The clause arguments are invalidated if any error arises
+ // such as non-constant or non-positive arguments
+ if (LRC && (!LRC->getFirst() || !LRC->getCount()))
+ return StmtError();
+
+ // Delayed semantic check of LoopRange constraint
+ // Evaluates the loop range arguments and returns the first and count values
+ auto EvaluateLoopRangeArguments = [&Context](Expr *First, Expr *Count,
+ uint64_t &FirstVal,
+ uint64_t &CountVal) {
+ llvm::APSInt FirstInt = First->EvaluateKnownConstInt(Context);
+ llvm::APSInt CountInt = Count->EvaluateKnownConstInt(Context);
+ FirstVal = FirstInt.getZExtValue();
+ CountVal = CountInt.getZExtValue();
+ };
+
+ // OpenMP [6.0, Restrictions]
+ // first + count - 1 must not evaluate to a value greater than the
+ // loop sequence length of the associated canonical loop sequence.
+ auto ValidLoopRange = [](uint64_t FirstVal, uint64_t CountVal,
+ unsigned NumLoops) -> bool {
+ return FirstVal + CountVal - 1 <= NumLoops;
+ };
+ uint64_t FirstVal = 1, CountVal = 0, LastVal = SeqAnalysis.LoopSeqSize;
+
+ // Validates the loop range after evaluating the semantic information
+ // and ensures that the range is valid for the given loop sequence size.
+ // Expressions are evaluated at compile time to obtain constant values.
+ if (LRC) {
+ EvaluateLoopRangeArguments(LRC->getFirst(), LRC->getCount(), FirstVal,
+ CountVal);
+ if (CountVal == 1)
+ SemaRef.Diag(LRC->getCountLoc(), diag::warn_omp_redundant_fusion)
+ << getOpenMPDirectiveName(OMPD_fuse);
+
+ if (!ValidLoopRange(FirstVal, CountVal, SeqAnalysis.LoopSeqSize)) {
+ SemaRef.Diag(LRC->getFirstLoc(), diag::err_omp_invalid_looprange)
+ << getOpenMPDirectiveName(OMPD_fuse) << FirstVal
+ << (FirstVal + CountVal - 1) << SeqAnalysis.LoopSeqSize;
+ return StmtError();
+ }
+
+ LastVal = FirstVal + CountVal - 1;
+ }
+
+ // Complete fusion generates a single canonical loop nest
+ // However looprange clause may generate several loop nests
+ unsigned NumGeneratedTopLevelLoops =
+ LRC ? SeqAnalysis.LoopSeqSize - CountVal + 1 : 1;
+
+ // Emit a warning for redundant loop fusion when the sequence contains only
+ // one loop.
+ if (SeqAnalysis.LoopSeqSize == 1)
+ SemaRef.Diag(AStmt->getBeginLoc(), diag::warn_omp_redundant_fusion)
+ << getOpenMPDirectiveName(OMPD_fuse);
+
+ // Select the type with the largest bit width among all induction variables
+ QualType IVType =
+ SeqAnalysis.Loops[FirstVal - 1].HelperExprs.IterationVarRef->getType();
+ for (unsigned I : llvm::seq<unsigned>(FirstVal, LastVal)) {
+ QualType CurrentIVType =
+ SeqAnalysis.Loops[I].HelperExprs.IterationVarRef->getType();
+ if (Context.getTypeSize(CurrentIVType) > Context.getTypeSize(IVType)) {
+ IVType = CurrentIVType;
+ }
+ }
+ uint64_t IVBitWidth = Context.getIntWidth(IVType);
+
+ // Create pre-init declarations for all loops lower bounds, upper bounds,
+ // strides and num-iterations for every top level loop in the fusion
+ SmallVector<VarDecl *, 4> LBVarDecls;
+ SmallVector<VarDecl *, 4> STVarDecls;
+ SmallVector<VarDecl *, 4> NIVarDecls;
+ SmallVector<VarDecl *, 4> UBVarDecls;
+ SmallVector<VarDecl *, 4> IVVarDecls;
+
+ // Helper lambda to create variables for bounds, strides, and other
+ // expressions. Generates both the variable declaration and the corresponding
+ // initialization statement.
+ auto CreateHelperVarAndStmt =
+ [&, &SemaRef = SemaRef](Expr *ExprToCopy, const std::string &BaseName,
+ unsigned I, bool NeedsNewVD = false) {
+ Expr *TransformedExpr =
+ AssertSuccess(CopyTransformer.TransformExpr(ExprToCopy));
+ if (!TransformedExpr)
+ return std::pair<VarDecl *, StmtResult>(nullptr, StmtError());
+
+ auto Name = (Twine(".omp.") + BaseName + std::to_string(I)).str();
+
+ VarDecl *VD;
+ if (NeedsNewVD) {
+ VD = buildVarDecl(SemaRef, SourceLocation(), IVType, Name);
+ SemaRef.AddInitializerToDecl(VD, TransformedExpr, false);
+ } else {
+ // Create a unique variable name
+ DeclRefExpr *DRE = cast<DeclRefExpr>(TransformedExpr);
+ VD = cast<VarDecl>(DRE->getDecl());
+ VD->setDeclName(&SemaRef.PP.getIdentifierTable().get(Name));
+ }
+ // Create the corresponding declaration statement
+ StmtResult DeclStmt = new (Context) class DeclStmt(
+ DeclGroupRef(VD), SourceLocation(), SourceLocation());
+ return std::make_pair(VD, DeclStmt);
+ };
+
+ // PreInits hold a sequence of variable declarations that must be executed
+ // before the fused loop begins. These include bounds, strides, and other
+ // helper variables required for the transformation. Other loop transforms
+ // also contain their own preinits
+ SmallVector<Stmt *> PreInits;
+
+ // Update the general preinits using the preinits generated by loop sequence
+ // generating loop transformations. These preinits differ slightly from
+ // single-loop transformation preinits, as they can be detached from a
+ // specific loop inside multiple generated loop nests. This happens
+ // because certain helper variables, like '.omp.fuse.max', are introduced to
+ // handle fused iteration spaces and may not be directly tied to a single
+ // original loop. The preinit structure must ensure that hidden variables
+ // like '.omp.fuse.max' are still properly handled.
+ // Transformations that apply this concept: Loopranged Fuse, Split
+ llvm::append_range(PreInits, SeqAnalysis.LoopSequencePreInits);
+
+ // Process each single loop to generate and collect declarations
+ // and statements for all helper expressions related to
+ // particular single loop nests
+
+ // Also In the case of the fused loops, we keep track of their original
+ // inits by appending them to their preinits statement, and in the case of
+ // transformations, also append their preinits (which contain the original
+ // loop initialization statement or other statements)
+
+ // Firstly we need to set TransformIndex to match the begining of the
+ // looprange section
+ unsigned int TransformIndex = 0;
+ for (unsigned I : llvm::seq<unsigned>(FirstVal - 1)) {
+ if (SeqAnalysis.Loops[I].isLoopTransformation())
+ ++TransformIndex;
+ }
+
+ for (unsigned int I = FirstVal - 1, J = 0; I < LastVal; ++I, ++J) {
+ if (SeqAnalysis.Loops[I].isRegularLoop()) {
+ addLoopPreInits(Context, SeqAnalysis.Loops[I].HelperExprs,
+ SeqAnalysis.Loops[I].TheForStmt,
+ SeqAnalysis.Loops[I].OriginalInits, PreInits);
+ } else if (SeqAnalysis.Loops[I].isLoopTransformation()) {
+ // For transformed loops, insert both pre-inits and original inits.
+ // Order matters: pre-inits may define variables used in the original
+ // inits such as upper bounds...
+ SmallVector<Stmt *> &TransformPreInit =
+ SeqAnalysis.Loops[TransformIndex++].TransformsPreInits;
+ llvm::append_range(PreInits, TransformPreInit);
+
+ addLoopPreInits(Context, SeqAnalysis.Loops[I].HelperExprs,
+ SeqAnalysis.Loops[I].TheForStmt,
+ SeqAnalysis.Loops[I].OriginalInits, PreInits);
+ }
+ auto [UBVD, UBDStmt] =
+ CreateHelperVarAndStmt(SeqAnalysis.Loops[I].HelperExprs.UB, "ub", J);
+ auto [LBVD, LBDStmt] =
+ CreateHelperVarAndStmt(SeqAnalysis.Loops[I].HelperExprs.LB, "lb", J);
+ auto [STVD, STDStmt] =
+ CreateHelperVarAndStmt(SeqAnalysis.Loops[I].HelperExprs.ST, "st", J);
+ auto [NIVD, NIDStmt] = CreateHelperVarAndStmt(
+ SeqAnalysis.Loops[I].HelperExprs.NumIterations, "ni", J, true);
+ auto [IVVD, IVDStmt] = CreateHelperVarAndStmt(
+ SeqAnalysis.Loops[I].HelperExprs.IterationVarRef, "iv", J);
+
+ assert(LBVD && STVD && NIVD && IVVD &&
+ "OpenMP Fuse Helper variables creation failed");
+
+ UBVarDecls.push_back(UBVD);
+ LBVarDecls.push_back(LBVD);
+ STVarDecls.push_back(STVD);
+ NIVarDecls.push_back(NIVD);
+ IVVarDecls.push_back(IVVD);
+
+ PreInits.push_back(LBDStmt.get());
+ PreInits.push_back(STDStmt.get());
+ PreInits.push_back(NIDStmt.get());
+ PreInits.push_back(IVDStmt.get());
+ }
+
+ auto MakeVarDeclRef = [&SemaRef = this->SemaRef](VarDecl *VD) {
+ return buildDeclRefExpr(SemaRef, VD, VD->getType(), VD->getLocation(),
+ false);
+ };
+
+ // Following up the creation of the final fused loop will be performed
+ // which has the following shape (considering the selected loops):
+ //
+ // for (fuse.index = 0; fuse.index < max(ni0, ni1..., nik); ++fuse.index) {
+ // if (fuse.index < ni0){
+ // iv0 = lb0 + st0 * fuse.index;
+ // original.index0 = iv0
+ // body(0);
+ // }
+ // if (fuse.index < ni1){
+ // iv1 = lb1 + st1 * fuse.index;
+ // original.index1 = iv1
+ // body(1);
+ // }
+ //
+ // ...
+ //
+ // if (fuse.index < nik){
+ // ivk = lbk + stk * fuse.index;
+ // original.indexk = ivk
+ // body(k); Expr *InitVal = IntegerLiteral::Create(Context,
+ // llvm::APInt(IVWidth, 0),
+ // }
+
+ // 1. Create the initialized fuse index
+ StringRef IndexName = ".omp.fuse.index";
+ Expr *InitVal = IntegerLiteral::Create(Context, llvm::APInt(IVBitWidth, 0),
+ IVType, SourceLocation());
+ VarDecl *IndexDecl =
+ buildVarDecl(SemaRef, {}, IVType, IndexName, nullptr, nullptr);
+ SemaRef.AddInitializerToDecl(IndexDecl, InitVal, false);
+ StmtResult InitStmt = new (Context)
+ DeclStmt(DeclGroupRef(IndexDecl), SourceLocation(), SourceLocation());
+
+ if (!InitStmt.isUsable())
+ return StmtError();
+
+ auto MakeIVRef = [&SemaRef = this->SemaRef, IndexDecl, IVType,
+ Loc = InitVal->getExprLoc()]() {
+ return buildDeclRefExpr(SemaRef, IndexDecl, IVType, Loc, false);
+ };
+
+ // 2. Iteratively compute the max number of logical iterations Max(NI_1, NI_2,
+ // ..., NI_k)
+ //
+ // This loop accumulates the maximum value across multiple expressions,
+ // ensuring each step constructs a unique AST node for correctness. By using
+ // intermediate temporary variables and conditional operators, we maintain
+ // distinct nodes and avoid duplicating subtrees, For instance, max(a,b,c):
+ // omp.temp0 = max(a, b)
+ // omp.temp1 = max(omp.temp0, c)
+ // omp.fuse.max = max(omp.temp1, omp.temp0)
+
+ ExprResult MaxExpr;
+ // I is the range of loops in the sequence that we fuse.
+ for (unsigned I = FirstVal - 1, J = 0; I < LastVal; ++I, ++J) {
+ DeclRefExpr *NIRef = MakeVarDeclRef(NIVarDecls[J]);
+ QualType NITy = NIRef->getType();
+
+ if (MaxExpr.isUnset()) {
+ // Initialize MaxExpr with the first NI expression
+ MaxExpr = NIRef;
+ } else {
+ // Create a new acummulator variable t_i = MaxExpr
+ std::string TempName = (Twine(".omp.temp.") + Twine(J)).str();
+ VarDecl *TempDecl =
+ buildVarDecl(SemaRef, {}, NITy, TempName, nullptr, nullptr);
+ TempDecl->setInit(MaxExpr.get());
+ DeclRefExpr *TempRef =
+ buildDeclRefExpr(SemaRef, TempDecl, NITy, SourceLocation(), false);
+ DeclRefExpr *TempRef2 =
+ buildDeclRefExpr(SemaRef, TempDecl, NITy, SourceLocation(), false);
+ // Add a DeclStmt to PreInits to ensure the variable is declared.
+ StmtResult TempStmt = new (Context)
+ DeclStmt(DeclGroupRef(TempDecl), SourceLocation(), SourceLocation());
+
+ if (!TempStmt.isUsable())
+ return StmtError();
+ PreInits.push_back(TempStmt.get());
+
+ // Build MaxExpr <-(MaxExpr > NIRef ? MaxExpr : NIRef)
+ ExprResult Comparison =
+ SemaRef.BuildBinOp(nullptr, SourceLocation(), BO_GT, TempRef, NIRef);
+ // Handle any errors in Comparison creation
+ if (!Comparison.isUsable())
+ return StmtError();
+
+ DeclRefExpr *NIRef2 = MakeVarDeclRef(NIVarDecls[J]);
+ // Update MaxExpr using a conditional expression to hold the max value
+ MaxExpr = new (Context) ConditionalOperator(
+ Comparison.get(), SourceLocation(), TempRef2, SourceLocation(),
+ NIRef2->getExprStmt(), NITy, VK_LValue, OK_Ordinary);
+
+ if (!MaxExpr.isUsable())
+ return StmtError();
+ }
+ }
+ if (!MaxExpr.isUsable())
+ return StmtError();
+
+ // 3. Declare the max variable
+ const std::string MaxName = Twine(".omp.fuse.max").str();
+ VarDecl *MaxDecl =
+ buildVarDecl(SemaRef, {}, IVType, MaxName, nullptr, nullptr);
+ MaxDecl->setInit(MaxExpr.get());
+ DeclRefExpr *MaxRef = buildDeclRefExpr(SemaRef, MaxDecl, IVType, {}, false);
+ StmtResult MaxStmt = new (Context)
+ DeclStmt(DeclGroupRef(MaxDecl), SourceLocation(), SourceLocation());
+
+ if (MaxStmt.isInvalid())
+ return StmtError();
+ PreInits.push_back(MaxStmt.get());
+
+ // 4. Create condition Expr: index < n_max
+ ExprResult CondExpr = SemaRef.BuildBinOp(CurScope, SourceLocation(), BO_LT,
+ MakeIVRef(), MaxRef);
+ if (!CondExpr.isUsable())
+ return StmtError();
+
+ // 5. Increment Expr: ++index
+ ExprResult IncrExpr =
+ SemaRef.BuildUnaryOp(CurScope, SourceLocation(), UO_PreInc, MakeIVRef());
+ if (!IncrExpr.isUsable())
+ return StmtError();
+
+ // 6. Build the Fused Loop Body
+ // The final fused loop iterates over the maximum logical range. Inside the
+ // loop, each original loop's index is calculated dynamically, and its body
+ // is executed conditionally.
+ //
+ // Each sub-loop's body is guarded by a conditional statement to ensure
+ // it executes only within its logical iteration range:
+ //
+ // if (fuse.index < ni_k){
+ // iv_k = lb_k + st_k * fuse.index;
+ // original.index = iv_k
+ // body(k);
+ // }
+
+ CompoundStmt *FusedBody = nullptr;
+ SmallVector<Stmt *, 4> FusedBodyStmts;
+ for (unsigned I = FirstVal - 1, J = 0; I < LastVal; ++I, ++J) {
+ // Assingment of the original sub-loop index to compute the logical index
+ // IV_k = LB_k + omp.fuse.index * ST_k
+ ExprResult IdxExpr =
+ SemaRef.BuildBinOp(CurScope, SourceLocation(), BO_Mul,
+ MakeVarDeclRef(STVarDecls[J]), MakeIVRef());
+ if (!IdxExpr.isUsable())
+ return StmtError();
+ IdxExpr = SemaRef.BuildBinOp(CurScope, SourceLocation(), BO_Add,
+ MakeVarDeclRef(LBVarDecls[J]), IdxExpr.get());
+
+ if (!IdxExpr.isUsable())
+ return StmtError();
+ IdxExpr = SemaRef.BuildBinOp(CurScope, SourceLocation(), BO_Assign,
+ MakeVarDeclRef(IVVarDecls[J]), IdxExpr.get());
+ if (!IdxExpr.isUsable())
+ return StmtError();
+
+ // Update the original i_k = IV_k
+ SmallVector<Stmt *, 4> BodyStmts;
+ BodyStmts.push_back(IdxExpr.get());
+ llvm::append_range(BodyStmts, SeqAnalysis.Loops[I].HelperExprs.Updates);
+
+ // If the loop is a CXXForRangeStmt then the iterator variable is needed
+ if (auto *SourceCXXFor =
+ dyn_cast<CXXForRangeStmt>(SeqAnalysis.Loops[I].TheForStmt))
+ BodyStmts.push_back(SourceCXXFor->getLoopVarStmt());
+
+ Stmt *Body =
+ (isa<ForStmt>(SeqAnalysis.Loops[I].TheForStmt))
+ ? cast<ForStmt>(SeqAnalysis.Loops[I].TheForStmt)->getBody()
+ : cast<CXXForRangeStmt>(SeqAnalysis.Loops[I].TheForStmt)->getBody();
+ BodyStmts.push_back(Body);
+
+ CompoundStmt *CombinedBody =
+ CompoundStmt::Create(Context, BodyStmts, FPOptionsOverride(),
+ SourceLocation(), SourceLocation());
+ ExprResult Condition =
+ SemaRef.BuildBinOp(CurScope, SourceLocation(), BO_LT, MakeIVRef(),
+ MakeVarDeclRef(NIVarDecls[J]));
+
+ if (!Condition.isUsable())
+ return StmtError();
+
+ IfStmt *IfStatement = IfStmt::Create(
+ Context, SourceLocation(), IfStatementKind::Ordinary, nullptr, nullptr,
+ Condition.get(), SourceLocation(), SourceLocation(), CombinedBody,
+ SourceLocation(), nullptr);
+
+ FusedBodyStmts.push_back(IfStatement);
+ }
+ FusedBody = CompoundStmt::Create(Context, FusedBodyStmts, FPOptionsOverride(),
+ SourceLocation(), SourceLocation());
+
+ // 7. Construct the final fused loop
+ ForStmt *FusedForStmt = new (Context)
+ ForStmt(Context, InitStmt.get(), CondExpr.get(), nullptr, IncrExpr.get(),
+ FusedBody, InitStmt.get()->getBeginLoc(), SourceLocation(),
+ IncrExpr.get()->getEndLoc());
+
+ // In the case of looprange, the result of fuse won't simply
+ // be a single loop (ForStmt), but rather a loop sequence
+ // (CompoundStmt) of 3 parts: the pre-fusion loops, the fused loop
+ // and the post-fusion loops, preserving its original order.
+ //
+ // Note: If looprange clause produces a single fused loop nest then
+ // this compound statement wrapper is unnecessary (Therefore this
+ // treatment is skipped)
+
+ Stmt *FusionStmt = FusedForStmt;
+ if (LRC && CountVal != SeqAnalysis.LoopSeqSize) {
+ SmallVector<Stmt *, 4> FinalLoops;
+
+ // Reset the transform index
+ TransformIndex = 0;
+
+ // Collect all non-fused loops before and after the fused region.
+ // Pre-fusion and post-fusion loops are inserted in order exploiting their
+ // symmetry, along with their corresponding transformation pre-inits if
+ // needed. The fused loop is added between the two regions.
+ for (unsigned I : llvm::seq<unsigned>(SeqAnalysis.LoopSeqSize)) {
+ if (I >= FirstVal - 1 && I < FirstVal + CountVal - 1) {
+ // Update the Transformation counter to skip already treated
+ // loop transformations
+ if (!SeqAnalysis.Loops[I].isLoopTransformation())
+ ++TransformIndex;
+ continue;
+ }
+
+ // No need to handle:
+ // Regular loops: they are kept intact as-is.
+ // Loop-sequence-generating transformations: already handled earlier.
+ // Only TransformSingleLoop requires inserting pre-inits here
+ if (SeqAnalysis.Loops[I].isRegularLoop()) {
+ const auto &TransformPreInit =
+ SeqAnalysis.Loops[TransformIndex++].TransformsPreInits;
+ if (!TransformPreInit.empty())
+ llvm::append_range(PreInits, TransformPreInit);
+ }
+
+ FinalLoops.push_back(SeqAnalysis.Loops[I].TheForStmt);
+ }
+
+ FinalLoops.insert(FinalLoops.begin() + (FirstVal - 1), FusedForStmt);
+ FusionStmt = CompoundStmt::Create(Context, FinalLoops, FPOptionsOverride(),
+ SourceLocation(), SourceLocation());
+ }
+ return OMPFuseDirective::Create(Context, StartLoc, EndLoc, Clauses,
+ NumGeneratedTopLevelLoops, AStmt, FusionStmt,
+ buildPreInits(Context, PreInits));
+}
+
OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
Expr *Expr,
SourceLocation StartLoc,
@@ -16887,6 +17637,31 @@ OMPClause *SemaOpenMP::ActOnOpenMPPartialClause(Expr *FactorExpr,
FactorExpr);
}
+OMPClause *SemaOpenMP::ActOnOpenMPLoopRangeClause(
+ Expr *First, Expr *Count, SourceLocation StartLoc, SourceLocation LParenLoc,
+ SourceLocation FirstLoc, SourceLocation CountLoc, SourceLocation EndLoc) {
+
+ // OpenMP [6.0, Restrictions]
+ // First and Count must be integer expressions with positive value
+ ExprResult FirstVal =
+ VerifyPositiveIntegerConstantInClause(First, OMPC_looprange);
+ if (FirstVal.isInvalid())
+ First = nullptr;
+
+ ExprResult CountVal =
+ VerifyPositiveIntegerConstantInClause(Count, OMPC_looprange);
+ if (CountVal.isInvalid())
+ Count = nullptr;
+
+ // OpenMP [6.0, Restrictions]
+ // first + count - 1 must not evaluate to a value greater than the
+ // loop sequence length of the associated canonical loop sequence.
+ // This check must be performed afterwards due to the delayed
+ // parsing and computation of the associated loop sequence
+ return OMPLoopRangeClause::Create(getASTContext(), StartLoc, LParenLoc,
+ FirstLoc, CountLoc, EndLoc, First, Count);
+}
+
OMPClause *SemaOpenMP::ActOnOpenMPAlignClause(Expr *A, SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc) {