about | summary | refs | log | tree | commit | diff
path: root/flang/lib/Lower/OpenACC.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Lower/OpenACC.cpp')
-rw-r--r-- flang/lib/Lower/OpenACC.cpp 62
1 files changed, 40 insertions, 22 deletions
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 742f58f..62e5c0c 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2178,11 +2178,25 @@ static void processDoLoopBounds(
locs.push_back(converter.genLocation(
Fortran::parser::FindSourceLocation(outerDoConstruct)));
} else {
- auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>();
- assert(doCons && "expect do construct");
- loopControl = &*doCons->GetLoopControl();
+ // Safely locate the next inner DoConstruct within this eval.
+ const Fortran::parser::DoConstruct *innerDo = nullptr;
+ if (crtEval && crtEval->hasNestedEvaluations()) {
+ for (Fortran::lower::pft::Evaluation &child :
+ crtEval->getNestedEvaluations()) {
+ if (auto *stmt = child.getIf<Fortran::parser::DoConstruct>()) {
+ innerDo = stmt;
+ // Prepare to descend for the next iteration
+ crtEval = &child;
+ break;
+ }
+ }
+ }
+ if (!innerDo)
+ break; // No deeper loop; stop collecting collapsed bounds.
+
+ loopControl = &*innerDo->GetLoopControl();
locs.push_back(converter.genLocation(
- Fortran::parser::FindSourceLocation(*doCons)));
+ Fortran::parser::FindSourceLocation(*innerDo)));
}
const Fortran::parser::LoopControl::Bounds *bounds =
@@ -2206,8 +2220,7 @@ static void processDoLoopBounds(
inclusiveBounds.push_back(true);
- if (i < loopsToProcess - 1)
- crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
+ // crtEval already updated when descending; no blind increment here.
}
}
}
@@ -2553,10 +2566,6 @@ static mlir::acc::LoopOp createLoopOp(
std::get_if<Fortran::parser::AccClause::Collapse>(
&clause.u)) {
const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
- const auto &force = std::get<bool>(arg.t);
- if (force)
- TODO(clauseLocation, "OpenACC collapse force modifier");
-
const auto &intExpr =
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
const auto *expr = Fortran::semantics::GetExpr(intExpr);
@@ -5029,25 +5038,34 @@ void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
uint64_t Fortran::lower::getLoopCountForCollapseAndTile(
const Fortran::parser::AccClauseList &clauseList) {
- uint64_t collapseLoopCount = 1;
+ uint64_t collapseLoopCount = getCollapseSizeAndForce(clauseList).first;
uint64_t tileLoopCount = 1;
for (const Fortran::parser::AccClause &clause : clauseList.v) {
- if (const auto *collapseClause =
- std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
- const parser::AccCollapseArg &arg = collapseClause->v;
- const auto &collapseValue{std::get<parser::ScalarIntConstantExpr>(arg.t)};
- collapseLoopCount = *Fortran::semantics::GetIntValue(collapseValue);
- }
if (const auto *tileClause =
std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
const parser::AccTileExprList &tileExprList = tileClause->v;
- const std::list<parser::AccTileExpr> &listTileExpr = tileExprList.v;
- tileLoopCount = listTileExpr.size();
+ tileLoopCount = tileExprList.v.size();
+ }
+ }
+ return tileLoopCount > collapseLoopCount ? tileLoopCount : collapseLoopCount;
+}
+
+std::pair<uint64_t, bool> Fortran::lower::getCollapseSizeAndForce(
+ const Fortran::parser::AccClauseList &clauseList) {
+ uint64_t size = 1;
+ bool force = false;
+ for (const Fortran::parser::AccClause &clause : clauseList.v) {
+ if (const auto *collapseClause =
+ std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
+ const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
+ force = std::get<bool>(arg.t);
+ const auto &collapseValue =
+ std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
+ size = *Fortran::semantics::GetIntValue(collapseValue);
+ break;
}
}
- if (tileLoopCount > collapseLoopCount)
- return tileLoopCount;
- return collapseLoopCount;
+ return {size, force};
}
/// Create an ACC loop operation for a DO construct when inside ACC compute