aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Lower/Bridge.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Lower/Bridge.cpp')
-rw-r--r--flang/lib/Lower/Bridge.cpp70
1 files changed, 67 insertions, 3 deletions
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 780d56f..50a687c 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -643,6 +643,8 @@ public:
return localSymbols.lookupStorage(sym);
}
+ Fortran::lower::SymMap &getSymbolMap() override final { return localSymbols; }
+
void
overrideExprValues(const Fortran::lower::ExprToValueMap *map) override final {
exprValueOverrides = map;
@@ -3190,15 +3192,20 @@ private:
std::get_if<Fortran::parser::OpenACCCombinedConstruct>(&acc.u);
Fortran::lower::pft::Evaluation *curEval = &getEval();
+ // Determine collapse depth/force and loopCount
+ bool collapseForce = false;
+ uint64_t collapseDepth = 1;
+ uint64_t loopCount = 1;
if (accLoop || accCombined) {
- uint64_t loopCount;
if (accLoop) {
const Fortran::parser::AccBeginLoopDirective &beginLoopDir =
std::get<Fortran::parser::AccBeginLoopDirective>(accLoop->t);
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginLoopDir.t);
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
+ std::tie(collapseDepth, collapseForce) =
+ Fortran::lower::getCollapseSizeAndForce(clauseList);
} else if (accCombined) {
const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir =
std::get<Fortran::parser::AccBeginCombinedDirective>(
@@ -3206,6 +3213,8 @@ private:
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginCombinedDir.t);
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
+ std::tie(collapseDepth, collapseForce) =
+ Fortran::lower::getCollapseSizeAndForce(clauseList);
}
if (curEval->lowerAsStructured()) {
@@ -3215,8 +3224,63 @@ private:
}
}
- for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
- genFIR(e);
+ const bool isStructured = curEval && curEval->lowerAsStructured();
+ if (isStructured && collapseForce && collapseDepth > 1) {
+ // force: collect prologue/epilogue for the first collapseDepth nested
+ // loops and sink them into the innermost loop body at that depth
+ llvm::SmallVector<Fortran::lower::pft::Evaluation *> prologue, epilogue;
+ Fortran::lower::pft::Evaluation *parent = &getEval();
+ Fortran::lower::pft::Evaluation *innermostLoopEval = nullptr;
+ for (uint64_t lvl = 0; lvl + 1 < collapseDepth; ++lvl) {
+ epilogue.clear();
+ auto &kids = parent->getNestedEvaluations();
+ // Collect all non-loop statements before the next inner loop as
+ // prologue, then mark remaining siblings as epilogue and descend into
+ // the inner loop.
+ Fortran::lower::pft::Evaluation *childLoop = nullptr;
+ for (auto it = kids.begin(); it != kids.end(); ++it) {
+ if (it->getIf<Fortran::parser::DoConstruct>()) {
+ childLoop = &*it;
+ for (auto it2 = std::next(it); it2 != kids.end(); ++it2)
+ epilogue.push_back(&*it2);
+ break;
+ }
+ prologue.push_back(&*it);
+ }
+ // Semantics guarantees collapseDepth does not exceed nest depth
+ // so childLoop must be found here.
+ assert(childLoop && "Expected inner DoConstruct for collapse");
+ parent = childLoop;
+ innermostLoopEval = childLoop;
+ }
+
+ // Track sunk evaluations (avoid double-lowering)
+ llvm::SmallPtrSet<const Fortran::lower::pft::Evaluation *, 16> sunk;
+ for (auto *e : prologue)
+ sunk.insert(e);
+ for (auto *e : epilogue)
+ sunk.insert(e);
+
+ auto sink =
+ [&](llvm::SmallVector<Fortran::lower::pft::Evaluation *> &lst) {
+ for (auto *e : lst)
+ genFIR(*e);
+ };
+
+ sink(prologue);
+
+ // Lower innermost loop body, skipping sunk
+ for (Fortran::lower::pft::Evaluation &e :
+ innermostLoopEval->getNestedEvaluations())
+ if (!sunk.contains(&e))
+ genFIR(e);
+
+ sink(epilogue);
+ } else {
+ // Normal lowering
+ for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
+ genFIR(e);
+ }
localSymbols.popScope();
builder->restoreInsertionPoint(insertPt);