diff options
Diffstat (limited to 'flang/lib/Lower/Bridge.cpp')
-rw-r--r-- | flang/lib/Lower/Bridge.cpp | 70 |
1 files changed, 67 insertions, 3 deletions
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 780d56f..50a687c 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -643,6 +643,8 @@ public: return localSymbols.lookupStorage(sym); } + Fortran::lower::SymMap &getSymbolMap() override final { return localSymbols; } + void overrideExprValues(const Fortran::lower::ExprToValueMap *map) override final { exprValueOverrides = map; @@ -3190,15 +3192,20 @@ private: std::get_if<Fortran::parser::OpenACCCombinedConstruct>(&acc.u); Fortran::lower::pft::Evaluation *curEval = &getEval(); + // Determine collapse depth/force and loopCount + bool collapseForce = false; + uint64_t collapseDepth = 1; + uint64_t loopCount = 1; if (accLoop || accCombined) { - uint64_t loopCount; if (accLoop) { const Fortran::parser::AccBeginLoopDirective &beginLoopDir = std::get<Fortran::parser::AccBeginLoopDirective>(accLoop->t); const Fortran::parser::AccClauseList &clauseList = std::get<Fortran::parser::AccClauseList>(beginLoopDir.t); loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList); + std::tie(collapseDepth, collapseForce) = + Fortran::lower::getCollapseSizeAndForce(clauseList); } else if (accCombined) { const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir = std::get<Fortran::parser::AccBeginCombinedDirective>( @@ -3206,6 +3213,8 @@ private: const Fortran::parser::AccClauseList &clauseList = std::get<Fortran::parser::AccClauseList>(beginCombinedDir.t); loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList); + std::tie(collapseDepth, collapseForce) = + Fortran::lower::getCollapseSizeAndForce(clauseList); } if (curEval->lowerAsStructured()) { @@ -3215,8 +3224,63 @@ private: } } - for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) - genFIR(e); + const bool isStructured = curEval && curEval->lowerAsStructured(); + if (isStructured && collapseForce && collapseDepth > 1) { + // force: collect prologue/epilogue for the first collapseDepth nested + // loops and sink them into the innermost loop body at that depth + llvm::SmallVector<Fortran::lower::pft::Evaluation *> prologue, epilogue; + Fortran::lower::pft::Evaluation *parent = &getEval(); + Fortran::lower::pft::Evaluation *innermostLoopEval = nullptr; + for (uint64_t lvl = 0; lvl + 1 < collapseDepth; ++lvl) { + epilogue.clear(); + auto &kids = parent->getNestedEvaluations(); + // Collect all non-loop statements before the next inner loop as + // prologue, then mark remaining siblings as epilogue and descend into + // the inner loop. + Fortran::lower::pft::Evaluation *childLoop = nullptr; + for (auto it = kids.begin(); it != kids.end(); ++it) { + if (it->getIf<Fortran::parser::DoConstruct>()) { + childLoop = &*it; + for (auto it2 = std::next(it); it2 != kids.end(); ++it2) + epilogue.push_back(&*it2); + break; + } + prologue.push_back(&*it); + } + // Semantics guarantees collapseDepth does not exceed nest depth + // so childLoop must be found here. + assert(childLoop && "Expected inner DoConstruct for collapse"); + parent = childLoop; + innermostLoopEval = childLoop; + } + + // Track sunk evaluations (avoid double-lowering) + llvm::SmallPtrSet<const Fortran::lower::pft::Evaluation *, 16> sunk; + for (auto *e : prologue) + sunk.insert(e); + for (auto *e : epilogue) + sunk.insert(e); + + auto sink = + [&](llvm::SmallVector<Fortran::lower::pft::Evaluation *> &lst) { + for (auto *e : lst) + genFIR(*e); + }; + + sink(prologue); + + // Lower innermost loop body, skipping sunk + for (Fortran::lower::pft::Evaluation &e : + innermostLoopEval->getNestedEvaluations()) + if (!sunk.contains(&e)) + genFIR(e); + + sink(epilogue); + } else { + // Normal lowering + for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) + genFIR(e); + } localSymbols.popScope(); builder->restoreInsertionPoint(insertPt); |