diff options
Diffstat (limited to 'flang/lib/Lower/OpenACC.cpp')
-rw-r--r-- | flang/lib/Lower/OpenACC.cpp | 62 |
1 files changed, 40 insertions, 22 deletions
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 742f58f..62e5c0c 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -2178,11 +2178,25 @@ static void processDoLoopBounds( locs.push_back(converter.genLocation( Fortran::parser::FindSourceLocation(outerDoConstruct))); } else { - auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>(); - assert(doCons && "expect do construct"); - loopControl = &*doCons->GetLoopControl(); + // Safely locate the next inner DoConstruct within this eval. + const Fortran::parser::DoConstruct *innerDo = nullptr; + if (crtEval && crtEval->hasNestedEvaluations()) { + for (Fortran::lower::pft::Evaluation &child : + crtEval->getNestedEvaluations()) { + if (auto *stmt = child.getIf<Fortran::parser::DoConstruct>()) { + innerDo = stmt; + // Prepare to descend for the next iteration + crtEval = &child; + break; + } + } + } + if (!innerDo) + break; // No deeper loop; stop collecting collapsed bounds. + + loopControl = &*innerDo->GetLoopControl(); locs.push_back(converter.genLocation( - Fortran::parser::FindSourceLocation(*doCons))); + Fortran::parser::FindSourceLocation(*innerDo))); } const Fortran::parser::LoopControl::Bounds *bounds = @@ -2206,8 +2220,7 @@ static void processDoLoopBounds( inclusiveBounds.push_back(true); - if (i < loopsToProcess - 1) - crtEval = &*std::next(crtEval->getNestedEvaluations().begin()); + // crtEval already updated when descending; no blind increment here. } } } @@ -2553,10 +2566,6 @@ static mlir::acc::LoopOp createLoopOp( std::get_if<Fortran::parser::AccClause::Collapse>( &clause.u)) { const Fortran::parser::AccCollapseArg &arg = collapseClause->v; - const auto &force = std::get<bool>(arg.t); - if (force) - TODO(clauseLocation, "OpenACC collapse force modifier"); - const auto &intExpr = std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t); const auto *expr = Fortran::semantics::GetExpr(intExpr); @@ -5029,25 +5038,34 @@ void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder, uint64_t Fortran::lower::getLoopCountForCollapseAndTile( const Fortran::parser::AccClauseList &clauseList) { - uint64_t collapseLoopCount = 1; + uint64_t collapseLoopCount = getCollapseSizeAndForce(clauseList).first; uint64_t tileLoopCount = 1; for (const Fortran::parser::AccClause &clause : clauseList.v) { - if (const auto *collapseClause = - std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) { - const parser::AccCollapseArg &arg = collapseClause->v; - const auto &collapseValue{std::get<parser::ScalarIntConstantExpr>(arg.t)}; - collapseLoopCount = *Fortran::semantics::GetIntValue(collapseValue); - } if (const auto *tileClause = std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) { const parser::AccTileExprList &tileExprList = tileClause->v; - const std::list<parser::AccTileExpr> &listTileExpr = tileExprList.v; - tileLoopCount = listTileExpr.size(); + tileLoopCount = tileExprList.v.size(); + } + } + return tileLoopCount > collapseLoopCount ? tileLoopCount : collapseLoopCount; +} + +std::pair<uint64_t, bool> Fortran::lower::getCollapseSizeAndForce( + const Fortran::parser::AccClauseList &clauseList) { + uint64_t size = 1; + bool force = false; + for (const Fortran::parser::AccClause &clause : clauseList.v) { + if (const auto *collapseClause = + std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) { + const Fortran::parser::AccCollapseArg &arg = collapseClause->v; + force = std::get<bool>(arg.t); + const auto &collapseValue = + std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t); + size = *Fortran::semantics::GetIntValue(collapseValue); + break; } } - if (tileLoopCount > collapseLoopCount) - return tileLoopCount; - return collapseLoopCount; + return {size, force}; } /// Create an ACC loop operation for a DO construct when inside ACC compute |