aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Lower/OpenMP/OpenMP.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Lower/OpenMP/OpenMP.cpp')
-rw-r--r--flang/lib/Lower/OpenMP/OpenMP.cpp360
1 files changed, 236 insertions, 124 deletions
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 1cb3335..9e56c2b 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1982,125 +1982,241 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return loopOp;
}
-static mlir::omp::CanonicalLoopOp
-genCanonicalLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
- semantics::SemanticsContext &semaCtx,
- lower::pft::Evaluation &eval, mlir::Location loc,
- const ConstructQueue &queue,
- ConstructQueue::const_iterator item,
- llvm::ArrayRef<const semantics::Symbol *> ivs,
- llvm::omp::Directive directive) {
+static void genCanonicalLoopNest(
+ lower::AbstractConverter &converter, lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+ mlir::Location loc, const ConstructQueue &queue,
+ ConstructQueue::const_iterator item, size_t numLoops,
+ llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) {
+ assert(loops.empty() && "Expecting empty list to fill");
+ assert(numLoops >= 1 && "Expecting at least one loop");
+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- assert(ivs.size() == 1 && "Nested loops not yet implemented");
- const semantics::Symbol *iv = ivs[0];
+ mlir::omp::LoopRelatedClauseOps loopInfo;
+ llvm::SmallVector<const semantics::Symbol *, 3> ivs;
+ collectLoopRelatedInfo(converter, loc, eval, numLoops, loopInfo, ivs);
+ assert(ivs.size() == numLoops &&
+ "Expected to parse as many loop variables as there are loops");
+
+ // Steps that follow:
+ // 1. Emit all of the loop's prologues (compute the tripcount)
+ // 2. Emit omp.canonical_loop nested inside each other (iteratively)
+ // 2.1. In the innermost omp.canonical_loop, emit the loop body prologue (in
+ // the body callback)
+ //
+ // Since emitting prologues and body code is split, remember prologue values
+ // for use when emitting the same loop's epilogues.
+ llvm::SmallVector<mlir::Value> tripcounts;
+ llvm::SmallVector<mlir::Value> clis;
+ llvm::SmallVector<lower::pft::Evaluation *> evals;
+ llvm::SmallVector<mlir::Type> loopVarTypes;
+ llvm::SmallVector<mlir::Value> loopStepVars;
+ llvm::SmallVector<mlir::Value> loopLBVars;
+ llvm::SmallVector<mlir::Value> blockArgs;
+
+ // Step 1: Loop prologues
+ // Computing the trip count must happen before entering the outermost loop
+ lower::pft::Evaluation *innermostEval = &eval.getFirstNestedEvaluation();
+ for ([[maybe_unused]] auto iv : ivs) {
+ if (innermostEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
+ // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct.
+ // Will need to add special cases for this combination.
+ TODO(loc, "DO CONCURRENT as canonical loop not supported");
+ }
+
+ auto &doLoopEval = innermostEval->getFirstNestedEvaluation();
+ evals.push_back(innermostEval);
+
+ // Get the loop bounds (and increment)
+ // auto &doLoopEval = nestedEval.getFirstNestedEvaluation();
+ auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>();
+ assert(doStmt && "Expected do loop to be in the nested evaluation");
+ auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t);
+ assert(loopControl.has_value());
+ auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
+ assert(bounds && "Expected bounds for canonical loop");
+ lower::StatementContext stmtCtx;
+ mlir::Value loopLBVar = fir::getBase(
+ converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx));
+ mlir::Value loopUBVar = fir::getBase(
+ converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx));
+ mlir::Value loopStepVar = [&]() {
+ if (bounds->step) {
+ return fir::getBase(
+ converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx));
+ }
- auto &nestedEval = eval.getFirstNestedEvaluation();
- if (nestedEval.getIf<parser::DoConstruct>()->IsDoConcurrent()) {
- // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct. Will
- // need to add special cases for this combination.
- TODO(loc, "DO CONCURRENT as canonical loop not supported");
+ // If `step` is not present, assume it is `1`.
+ auto intTy = firOpBuilder.getI32Type();
+ return firOpBuilder.createIntegerConstant(loc, intTy, 1);
+ }();
+
+ // Get the integer kind for the loop variable and cast the loop bounds
+ size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size();
+ mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
+ loopVarTypes.push_back(loopVarType);
+ loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar);
+ loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar);
+ loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar);
+ loopLBVars.push_back(loopLBVar);
+ loopStepVars.push_back(loopStepVar);
+
+ // Start lowering
+ mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0);
+ mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1);
+ mlir::Value isDownwards = firOpBuilder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero);
+
+ // Ensure we are counting upwards. If not, negate step and swap lb and ub.
+ mlir::Value negStep =
+ firOpBuilder.create<mlir::arith::SubIOp>(loc, zero, loopStepVar);
+ mlir::Value incr = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, isDownwards, negStep, loopStepVar);
+ mlir::Value lb = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, isDownwards, loopUBVar, loopLBVar);
+ mlir::Value ub = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, isDownwards, loopLBVar, loopUBVar);
+
+ // Compute the trip count assuming lb <= ub. This guarantees that the result
+ // is non-negative and we can use unsigned arithmetic.
+ mlir::Value span = firOpBuilder.create<mlir::arith::SubIOp>(
+ loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw);
+ mlir::Value tcMinusOne =
+ firOpBuilder.create<mlir::arith::DivUIOp>(loc, span, incr);
+ mlir::Value tcIfLooping = firOpBuilder.create<mlir::arith::AddIOp>(
+ loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw);
+
+ // Fall back to 0 if lb > ub
+ mlir::Value isZeroTC = firOpBuilder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::slt, ub, lb);
+ mlir::Value tripcount = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, isZeroTC, zero, tcIfLooping);
+ tripcounts.push_back(tripcount);
+
+ // Create the CLI handle.
+ auto newcli = firOpBuilder.create<mlir::omp::NewCliOp>(loc);
+ mlir::Value cli = newcli.getResult();
+ clis.push_back(cli);
+
+ innermostEval = &*std::next(innermostEval->getNestedEvaluations().begin());
}
- // Get the loop bounds (and increment)
- auto &doLoopEval = nestedEval.getFirstNestedEvaluation();
- auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>();
- assert(doStmt && "Expected do loop to be in the nested evaluation");
- auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t);
- assert(loopControl.has_value());
- auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
- assert(bounds && "Expected bounds for canonical loop");
- lower::StatementContext stmtCtx;
- mlir::Value loopLBVar = fir::getBase(
- converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx));
- mlir::Value loopUBVar = fir::getBase(
- converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx));
- mlir::Value loopStepVar = [&]() {
- if (bounds->step) {
- return fir::getBase(
- converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx));
- }
+ // Step 2: Create nested canoncial loops
+ for (auto i : llvm::seq<size_t>(numLoops)) {
+ bool isInnermost = (i == numLoops - 1);
+ mlir::Type loopVarType = loopVarTypes[i];
+ mlir::Value tripcount = tripcounts[i];
+ mlir::Value cli = clis[i];
+ auto &&eval = evals[i];
+
+ auto ivCallback = [&, i, isInnermost](mlir::Operation *op)
+ -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
+ mlir::Region &region = op->getRegion(0);
+
+ // Create the op's region skeleton (BB taking the iv as argument)
+ firOpBuilder.createBlock(&region, {}, {loopVarType}, {loc});
+ blockArgs.push_back(region.front().getArgument(0));
+
+ // Step 2.1: Emit body prologue code
+ // Compute the translation from logical iteration number to the value of
+ // the loop's iteration variable only in the innermost body. Currently,
+ // loop transformations do not allow any instruction between loops, but
+ // this will change with
+ if (isInnermost) {
+ assert(blockArgs.size() == numLoops &&
+ "Expecting all block args to have been collected by now");
+ for (auto j : llvm::seq<size_t>(numLoops)) {
+ mlir::Value natIterNum = fir::getBase(blockArgs[j]);
+ mlir::Value scaled = firOpBuilder.create<mlir::arith::MulIOp>(
+ loc, natIterNum, loopStepVars[j]);
+ mlir::Value userVal = firOpBuilder.create<mlir::arith::AddIOp>(
+ loc, loopLBVars[j], scaled);
+
+ mlir::OpBuilder::InsertPoint insPt =
+ firOpBuilder.saveInsertionPoint();
+ firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
+ mlir::Type tempTy = converter.genType(*ivs[j]);
+ firOpBuilder.restoreInsertionPoint(insPt);
+
+ // Write the loop value into loop variable
+ mlir::Value cvtVal = firOpBuilder.createConvert(loc, tempTy, userVal);
+ hlfir::Entity lhs{converter.getSymbolAddress(*ivs[j])};
+ lhs = hlfir::derefPointersAndAllocatables(loc, firOpBuilder, lhs);
+ mlir::Operation *storeOp =
+ hlfir::AssignOp::create(firOpBuilder, loc, cvtVal, lhs);
+ firOpBuilder.setInsertionPointAfter(storeOp);
+ }
+ }
- // If `step` is not present, assume it is `1`.
- return firOpBuilder.createIntegerConstant(loc, firOpBuilder.getI32Type(),
- 1);
- }();
+ return {ivs[i]};
+ };
- // Get the integer kind for the loop variable and cast the loop bounds
- size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size();
- mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
- loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar);
- loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar);
- loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar);
-
- // Start lowering
- mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0);
- mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1);
- mlir::Value isDownwards = mlir::arith::CmpIOp::create(
- firOpBuilder, loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero);
-
- // Ensure we are counting upwards. If not, negate step and swap lb and ub.
- mlir::Value negStep =
- mlir::arith::SubIOp::create(firOpBuilder, loc, zero, loopStepVar);
- mlir::Value incr = mlir::arith::SelectOp::create(
- firOpBuilder, loc, isDownwards, negStep, loopStepVar);
- mlir::Value lb = mlir::arith::SelectOp::create(firOpBuilder, loc, isDownwards,
- loopUBVar, loopLBVar);
- mlir::Value ub = mlir::arith::SelectOp::create(firOpBuilder, loc, isDownwards,
- loopLBVar, loopUBVar);
-
- // Compute the trip count assuming lb <= ub. This guarantees that the result
- // is non-negative and we can use unsigned arithmetic.
- mlir::Value span = mlir::arith::SubIOp::create(
- firOpBuilder, loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw);
- mlir::Value tcMinusOne =
- mlir::arith::DivUIOp::create(firOpBuilder, loc, span, incr);
- mlir::Value tcIfLooping =
- mlir::arith::AddIOp::create(firOpBuilder, loc, tcMinusOne, one,
- ::mlir::arith::IntegerOverflowFlags::nuw);
-
- // Fall back to 0 if lb > ub
- mlir::Value isZeroTC = mlir::arith::CmpIOp::create(
- firOpBuilder, loc, mlir::arith::CmpIPredicate::slt, ub, lb);
- mlir::Value tripcount = mlir::arith::SelectOp::create(
- firOpBuilder, loc, isZeroTC, zero, tcIfLooping);
-
- // Create the CLI handle.
- auto newcli = mlir::omp::NewCliOp::create(firOpBuilder, loc);
- mlir::Value cli = newcli.getResult();
-
- auto ivCallback = [&](mlir::Operation *op)
- -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
- mlir::Region &region = op->getRegion(0);
-
- // Create the op's region skeleton (BB taking the iv as argument)
- firOpBuilder.createBlock(&region, {}, {loopVarType}, {loc});
-
- // Compute the value of the loop variable from the logical iteration number.
- mlir::Value natIterNum = fir::getBase(region.front().getArgument(0));
- mlir::Value scaled =
- mlir::arith::MulIOp::create(firOpBuilder, loc, natIterNum, loopStepVar);
- mlir::Value userVal =
- mlir::arith::AddIOp::create(firOpBuilder, loc, loopLBVar, scaled);
-
- // Write loop value to loop variable
- mlir::Operation *storeOp = setLoopVar(converter, loc, userVal, iv);
-
- firOpBuilder.setInsertionPointAfter(storeOp);
- return {iv};
- };
+ // Create the omp.canonical_loop operation
+ auto opGenInfo = OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *eval,
+ llvm::omp::Directive::OMPD_unknown)
+ .setGenSkeletonOnly(!isInnermost)
+ .setClauses(&item->clauses)
+ .setPrivatize(false)
+ .setGenRegionEntryCb(ivCallback);
+ auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>(
+ std::move(opGenInfo), queue, item, tripcount, cli);
+ loops.push_back(canonLoop);
+
+ // Insert next loop nested inside last loop
+ firOpBuilder.setInsertionPoint(
+ canonLoop.getRegion().back().getTerminator());
+ }
- // Create the omp.canonical_loop operation
- auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>(
- OpWithBodyGenInfo(converter, symTable, semaCtx, loc, nestedEval,
- directive)
- .setClauses(&item->clauses)
- .setPrivatize(false)
- .setGenRegionEntryCb(ivCallback),
- queue, item, tripcount, cli);
+ firOpBuilder.setInsertionPointAfter(loops.front());
+}
+
+static void genTileOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::SymMap &symTable,
+ lower::StatementContext &stmtCtx,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const ConstructQueue &queue,
+ ConstructQueue::const_iterator item) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- firOpBuilder.setInsertionPointAfter(canonLoop);
- return canonLoop;
+ mlir::omp::SizesClauseOps sizesClause;
+ ClauseProcessor cp(converter, semaCtx, item->clauses);
+ cp.processSizes(stmtCtx, sizesClause);
+
+ size_t numLoops = sizesClause.sizes.size();
+ llvm::SmallVector<mlir::omp::CanonicalLoopOp, 3> canonLoops;
+ canonLoops.reserve(numLoops);
+
+ genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item,
+ numLoops, canonLoops);
+ assert((canonLoops.size() == numLoops) &&
+ "Expecting the predetermined number of loops");
+
+ llvm::SmallVector<mlir::Value, 3> applyees;
+ applyees.reserve(numLoops);
+ for (mlir::omp::CanonicalLoopOp l : canonLoops)
+ applyees.push_back(l.getCli());
+
+ // Emit the associated loops and create a CLI for each affected loop
+ llvm::SmallVector<mlir::Value, 3> gridGeneratees;
+ llvm::SmallVector<mlir::Value, 3> intratileGeneratees;
+ gridGeneratees.reserve(numLoops);
+ intratileGeneratees.reserve(numLoops);
+ for ([[maybe_unused]] auto i : llvm::seq<int>(0, sizesClause.sizes.size())) {
+ auto gridCLI = firOpBuilder.create<mlir::omp::NewCliOp>(loc);
+ gridGeneratees.push_back(gridCLI.getResult());
+ auto intratileCLI = firOpBuilder.create<mlir::omp::NewCliOp>(loc);
+ intratileGeneratees.push_back(intratileCLI.getResult());
+ }
+
+ llvm::SmallVector<mlir::Value, 6> generatees;
+ generatees.reserve(2 * numLoops);
+ generatees.append(gridGeneratees);
+ generatees.append(intratileGeneratees);
+
+ firOpBuilder.create<mlir::omp::TileOp>(loc, generatees, applyees,
+ sizesClause.sizes);
}
static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
@@ -2112,22 +2228,22 @@ static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
ConstructQueue::const_iterator item) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- mlir::omp::LoopRelatedClauseOps loopInfo;
- llvm::SmallVector<const semantics::Symbol *> iv;
- collectLoopRelatedInfo(converter, loc, eval, item->clauses, loopInfo, iv);
-
// Clauses for unrolling not yet implemnted
ClauseProcessor cp(converter, semaCtx, item->clauses);
cp.processTODO<clause::Partial, clause::Full>(
loc, llvm::omp::Directive::OMPD_unroll);
// Emit the associated loop
- auto canonLoop =
- genCanonicalLoopOp(converter, symTable, semaCtx, eval, loc, queue, item,
- iv, llvm::omp::Directive::OMPD_unroll);
+ llvm::SmallVector<mlir::omp::CanonicalLoopOp, 1> canonLoops;
+ genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item, 1,
+ canonLoops);
+
+ llvm::SmallVector<mlir::Value, 1> applyees;
+ for (auto &&canonLoop : canonLoops)
+ applyees.push_back(canonLoop.getCli());
// Apply unrolling to it
- auto cli = canonLoop.getCli();
+ auto cli = llvm::getSingleElement(canonLoops).getCli();
mlir::omp::UnrollHeuristicOp::create(firOpBuilder, loc, cli);
}
@@ -3360,13 +3476,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
item);
break;
- case llvm::omp::Directive::OMPD_tile: {
- unsigned version = semaCtx.langOptions().OpenMPVersion;
- if (!semaCtx.langOptions().OpenMPSimd)
- TODO(loc, "Unhandled loop directive (" +
- llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
+ case llvm::omp::Directive::OMPD_tile:
+ genTileOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
break;
- }
case llvm::omp::Directive::OMPD_unroll:
genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
break;