diff options
Diffstat (limited to 'flang/lib/Lower/OpenMP/OpenMP.cpp')
-rw-r--r-- | flang/lib/Lower/OpenMP/OpenMP.cpp | 360 |
1 files changed, 236 insertions, 124 deletions
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 1cb3335..9e56c2b 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1982,125 +1982,241 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable, return loopOp; } -static mlir::omp::CanonicalLoopOp -genCanonicalLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, - lower::pft::Evaluation &eval, mlir::Location loc, - const ConstructQueue &queue, - ConstructQueue::const_iterator item, - llvm::ArrayRef<const semantics::Symbol *> ivs, - llvm::omp::Directive directive) { +static void genCanonicalLoopNest( + lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, + mlir::Location loc, const ConstructQueue &queue, + ConstructQueue::const_iterator item, size_t numLoops, + llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) { + assert(loops.empty() && "Expecting empty list to fill"); + assert(numLoops >= 1 && "Expecting at least one loop"); + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - assert(ivs.size() == 1 && "Nested loops not yet implemented"); - const semantics::Symbol *iv = ivs[0]; + mlir::omp::LoopRelatedClauseOps loopInfo; + llvm::SmallVector<const semantics::Symbol *, 3> ivs; + collectLoopRelatedInfo(converter, loc, eval, numLoops, loopInfo, ivs); + assert(ivs.size() == numLoops && + "Expected to parse as many loop variables as there are loops"); + + // Steps that follow: + // 1. Emit all of the loop's prologues (compute the tripcount) + // 2. Emit omp.canonical_loop nested inside each other (iteratively) + // 2.1. In the innermost omp.canonical_loop, emit the loop body prologue (in + // the body callback) + // + // Since emitting prologues and body code is split, remember prologue values + // for use when emitting the same loop's epilogues. + llvm::SmallVector<mlir::Value> tripcounts; + llvm::SmallVector<mlir::Value> clis; + llvm::SmallVector<lower::pft::Evaluation *> evals; + llvm::SmallVector<mlir::Type> loopVarTypes; + llvm::SmallVector<mlir::Value> loopStepVars; + llvm::SmallVector<mlir::Value> loopLBVars; + llvm::SmallVector<mlir::Value> blockArgs; + + // Step 1: Loop prologues + // Computing the trip count must happen before entering the outermost loop + lower::pft::Evaluation *innermostEval = &eval.getFirstNestedEvaluation(); + for ([[maybe_unused]] auto iv : ivs) { + if (innermostEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) { + // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct. + // Will need to add special cases for this combination. + TODO(loc, "DO CONCURRENT as canonical loop not supported"); + } + + auto &doLoopEval = innermostEval->getFirstNestedEvaluation(); + evals.push_back(innermostEval); + + // Get the loop bounds (and increment) + // auto &doLoopEval = nestedEval.getFirstNestedEvaluation(); + auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>(); + assert(doStmt && "Expected do loop to be in the nested evaluation"); + auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t); + assert(loopControl.has_value()); + auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u); + assert(bounds && "Expected bounds for canonical loop"); + lower::StatementContext stmtCtx; + mlir::Value loopLBVar = fir::getBase( + converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx)); + mlir::Value loopUBVar = fir::getBase( + converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx)); + mlir::Value loopStepVar = [&]() { + if (bounds->step) { + return fir::getBase( + converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx)); + } - auto &nestedEval = eval.getFirstNestedEvaluation(); - if (nestedEval.getIf<parser::DoConstruct>()->IsDoConcurrent()) { - // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct. Will - // need to add special cases for this combination. - TODO(loc, "DO CONCURRENT as canonical loop not supported"); + // If `step` is not present, assume it is `1`. + auto intTy = firOpBuilder.getI32Type(); + return firOpBuilder.createIntegerConstant(loc, intTy, 1); + }(); + + // Get the integer kind for the loop variable and cast the loop bounds + size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size(); + mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); + loopVarTypes.push_back(loopVarType); + loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar); + loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar); + loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar); + loopLBVars.push_back(loopLBVar); + loopStepVars.push_back(loopStepVar); + + // Start lowering + mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0); + mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1); + mlir::Value isDownwards = firOpBuilder.create<mlir::arith::CmpIOp>( + loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero); + + // Ensure we are counting upwards. If not, negate step and swap lb and ub. + mlir::Value negStep = + firOpBuilder.create<mlir::arith::SubIOp>(loc, zero, loopStepVar); + mlir::Value incr = firOpBuilder.create<mlir::arith::SelectOp>( + loc, isDownwards, negStep, loopStepVar); + mlir::Value lb = firOpBuilder.create<mlir::arith::SelectOp>( + loc, isDownwards, loopUBVar, loopLBVar); + mlir::Value ub = firOpBuilder.create<mlir::arith::SelectOp>( + loc, isDownwards, loopLBVar, loopUBVar); + + // Compute the trip count assuming lb <= ub. This guarantees that the result + // is non-negative and we can use unsigned arithmetic. + mlir::Value span = firOpBuilder.create<mlir::arith::SubIOp>( + loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw); + mlir::Value tcMinusOne = + firOpBuilder.create<mlir::arith::DivUIOp>(loc, span, incr); + mlir::Value tcIfLooping = firOpBuilder.create<mlir::arith::AddIOp>( + loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw); + + // Fall back to 0 if lb > ub + mlir::Value isZeroTC = firOpBuilder.create<mlir::arith::CmpIOp>( + loc, mlir::arith::CmpIPredicate::slt, ub, lb); + mlir::Value tripcount = firOpBuilder.create<mlir::arith::SelectOp>( + loc, isZeroTC, zero, tcIfLooping); + tripcounts.push_back(tripcount); + + // Create the CLI handle. + auto newcli = firOpBuilder.create<mlir::omp::NewCliOp>(loc); + mlir::Value cli = newcli.getResult(); + clis.push_back(cli); + + innermostEval = &*std::next(innermostEval->getNestedEvaluations().begin()); } - // Get the loop bounds (and increment) - auto &doLoopEval = nestedEval.getFirstNestedEvaluation(); - auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>(); - assert(doStmt && "Expected do loop to be in the nested evaluation"); - auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t); - assert(loopControl.has_value()); - auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u); - assert(bounds && "Expected bounds for canonical loop"); - lower::StatementContext stmtCtx; - mlir::Value loopLBVar = fir::getBase( - converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx)); - mlir::Value loopUBVar = fir::getBase( - converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx)); - mlir::Value loopStepVar = [&]() { - if (bounds->step) { - return fir::getBase( - converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx)); - } + // Step 2: Create nested canoncial loops + for (auto i : llvm::seq<size_t>(numLoops)) { + bool isInnermost = (i == numLoops - 1); + mlir::Type loopVarType = loopVarTypes[i]; + mlir::Value tripcount = tripcounts[i]; + mlir::Value cli = clis[i]; + auto &&eval = evals[i]; + + auto ivCallback = [&, i, isInnermost](mlir::Operation *op) + -> llvm::SmallVector<const Fortran::semantics::Symbol *> { + mlir::Region ®ion = op->getRegion(0); + + // Create the op's region skeleton (BB taking the iv as argument) + firOpBuilder.createBlock(®ion, {}, {loopVarType}, {loc}); + blockArgs.push_back(region.front().getArgument(0)); + + // Step 2.1: Emit body prologue code + // Compute the translation from logical iteration number to the value of + // the loop's iteration variable only in the innermost body. Currently, + // loop transformations do not allow any instruction between loops, but + // this will change with + if (isInnermost) { + assert(blockArgs.size() == numLoops && + "Expecting all block args to have been collected by now"); + for (auto j : llvm::seq<size_t>(numLoops)) { + mlir::Value natIterNum = fir::getBase(blockArgs[j]); + mlir::Value scaled = firOpBuilder.create<mlir::arith::MulIOp>( + loc, natIterNum, loopStepVars[j]); + mlir::Value userVal = firOpBuilder.create<mlir::arith::AddIOp>( + loc, loopLBVars[j], scaled); + + mlir::OpBuilder::InsertPoint insPt = + firOpBuilder.saveInsertionPoint(); + firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); + mlir::Type tempTy = converter.genType(*ivs[j]); + firOpBuilder.restoreInsertionPoint(insPt); + + // Write the loop value into loop variable + mlir::Value cvtVal = firOpBuilder.createConvert(loc, tempTy, userVal); + hlfir::Entity lhs{converter.getSymbolAddress(*ivs[j])}; + lhs = hlfir::derefPointersAndAllocatables(loc, firOpBuilder, lhs); + mlir::Operation *storeOp = + hlfir::AssignOp::create(firOpBuilder, loc, cvtVal, lhs); + firOpBuilder.setInsertionPointAfter(storeOp); + } + } - // If `step` is not present, assume it is `1`. - return firOpBuilder.createIntegerConstant(loc, firOpBuilder.getI32Type(), - 1); - }(); + return {ivs[i]}; + }; - // Get the integer kind for the loop variable and cast the loop bounds - size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size(); - mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); - loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar); - loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar); - loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar); - - // Start lowering - mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0); - mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1); - mlir::Value isDownwards = mlir::arith::CmpIOp::create( - firOpBuilder, loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero); - - // Ensure we are counting upwards. If not, negate step and swap lb and ub. - mlir::Value negStep = - mlir::arith::SubIOp::create(firOpBuilder, loc, zero, loopStepVar); - mlir::Value incr = mlir::arith::SelectOp::create( - firOpBuilder, loc, isDownwards, negStep, loopStepVar); - mlir::Value lb = mlir::arith::SelectOp::create(firOpBuilder, loc, isDownwards, - loopUBVar, loopLBVar); - mlir::Value ub = mlir::arith::SelectOp::create(firOpBuilder, loc, isDownwards, - loopLBVar, loopUBVar); - - // Compute the trip count assuming lb <= ub. This guarantees that the result - // is non-negative and we can use unsigned arithmetic. - mlir::Value span = mlir::arith::SubIOp::create( - firOpBuilder, loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw); - mlir::Value tcMinusOne = - mlir::arith::DivUIOp::create(firOpBuilder, loc, span, incr); - mlir::Value tcIfLooping = - mlir::arith::AddIOp::create(firOpBuilder, loc, tcMinusOne, one, - ::mlir::arith::IntegerOverflowFlags::nuw); - - // Fall back to 0 if lb > ub - mlir::Value isZeroTC = mlir::arith::CmpIOp::create( - firOpBuilder, loc, mlir::arith::CmpIPredicate::slt, ub, lb); - mlir::Value tripcount = mlir::arith::SelectOp::create( - firOpBuilder, loc, isZeroTC, zero, tcIfLooping); - - // Create the CLI handle. - auto newcli = mlir::omp::NewCliOp::create(firOpBuilder, loc); - mlir::Value cli = newcli.getResult(); - - auto ivCallback = [&](mlir::Operation *op) - -> llvm::SmallVector<const Fortran::semantics::Symbol *> { - mlir::Region ®ion = op->getRegion(0); - - // Create the op's region skeleton (BB taking the iv as argument) - firOpBuilder.createBlock(®ion, {}, {loopVarType}, {loc}); - - // Compute the value of the loop variable from the logical iteration number. - mlir::Value natIterNum = fir::getBase(region.front().getArgument(0)); - mlir::Value scaled = - mlir::arith::MulIOp::create(firOpBuilder, loc, natIterNum, loopStepVar); - mlir::Value userVal = - mlir::arith::AddIOp::create(firOpBuilder, loc, loopLBVar, scaled); - - // Write loop value to loop variable - mlir::Operation *storeOp = setLoopVar(converter, loc, userVal, iv); - - firOpBuilder.setInsertionPointAfter(storeOp); - return {iv}; - }; + // Create the omp.canonical_loop operation + auto opGenInfo = OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *eval, + llvm::omp::Directive::OMPD_unknown) + .setGenSkeletonOnly(!isInnermost) + .setClauses(&item->clauses) + .setPrivatize(false) + .setGenRegionEntryCb(ivCallback); + auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>( + std::move(opGenInfo), queue, item, tripcount, cli); + loops.push_back(canonLoop); + + // Insert next loop nested inside last loop + firOpBuilder.setInsertionPoint( + canonLoop.getRegion().back().getTerminator()); + } - // Create the omp.canonical_loop operation - auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>( - OpWithBodyGenInfo(converter, symTable, semaCtx, loc, nestedEval, - directive) - .setClauses(&item->clauses) - .setPrivatize(false) - .setGenRegionEntryCb(ivCallback), - queue, item, tripcount, cli); + firOpBuilder.setInsertionPointAfter(loops.front()); +} + +static void genTileOp(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, + lower::StatementContext &stmtCtx, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc, + const ConstructQueue &queue, + ConstructQueue::const_iterator item) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - firOpBuilder.setInsertionPointAfter(canonLoop); - return canonLoop; + mlir::omp::SizesClauseOps sizesClause; + ClauseProcessor cp(converter, semaCtx, item->clauses); + cp.processSizes(stmtCtx, sizesClause); + + size_t numLoops = sizesClause.sizes.size(); + llvm::SmallVector<mlir::omp::CanonicalLoopOp, 3> canonLoops; + canonLoops.reserve(numLoops); + + genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item, + numLoops, canonLoops); + assert((canonLoops.size() == numLoops) && + "Expecting the predetermined number of loops"); + + llvm::SmallVector<mlir::Value, 3> applyees; + applyees.reserve(numLoops); + for (mlir::omp::CanonicalLoopOp l : canonLoops) + applyees.push_back(l.getCli()); + + // Emit the associated loops and create a CLI for each affected loop + llvm::SmallVector<mlir::Value, 3> gridGeneratees; + llvm::SmallVector<mlir::Value, 3> intratileGeneratees; + gridGeneratees.reserve(numLoops); + intratileGeneratees.reserve(numLoops); + for ([[maybe_unused]] auto i : llvm::seq<int>(0, sizesClause.sizes.size())) { + auto gridCLI = firOpBuilder.create<mlir::omp::NewCliOp>(loc); + gridGeneratees.push_back(gridCLI.getResult()); + auto intratileCLI = firOpBuilder.create<mlir::omp::NewCliOp>(loc); + intratileGeneratees.push_back(intratileCLI.getResult()); + } + + llvm::SmallVector<mlir::Value, 6> generatees; + generatees.reserve(2 * numLoops); + generatees.append(gridGeneratees); + generatees.append(intratileGeneratees); + + firOpBuilder.create<mlir::omp::TileOp>(loc, generatees, applyees, + sizesClause.sizes); } static void genUnrollOp(Fortran::lower::AbstractConverter &converter, @@ -2112,22 +2228,22 @@ static void genUnrollOp(Fortran::lower::AbstractConverter &converter, ConstructQueue::const_iterator item) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::omp::LoopRelatedClauseOps loopInfo; - llvm::SmallVector<const semantics::Symbol *> iv; - collectLoopRelatedInfo(converter, loc, eval, item->clauses, loopInfo, iv); - // Clauses for unrolling not yet implemnted ClauseProcessor cp(converter, semaCtx, item->clauses); cp.processTODO<clause::Partial, clause::Full>( loc, llvm::omp::Directive::OMPD_unroll); // Emit the associated loop - auto canonLoop = - genCanonicalLoopOp(converter, symTable, semaCtx, eval, loc, queue, item, - iv, llvm::omp::Directive::OMPD_unroll); + llvm::SmallVector<mlir::omp::CanonicalLoopOp, 1> canonLoops; + genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item, 1, + canonLoops); + + llvm::SmallVector<mlir::Value, 1> applyees; + for (auto &&canonLoop : canonLoops) + applyees.push_back(canonLoop.getCli()); // Apply unrolling to it - auto cli = canonLoop.getCli(); + auto cli = llvm::getSingleElement(canonLoops).getCli(); mlir::omp::UnrollHeuristicOp::create(firOpBuilder, loc, cli); } @@ -3360,13 +3476,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter, newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); break; - case llvm::omp::Directive::OMPD_tile: { - unsigned version = semaCtx.langOptions().OpenMPVersion; - if (!semaCtx.langOptions().OpenMPSimd) - TODO(loc, "Unhandled loop directive (" + - llvm::omp::getOpenMPDirectiveName(dir, version) + ")"); + case llvm::omp::Directive::OMPD_tile: + genTileOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); break; - } case llvm::omp::Directive::OMPD_unroll: genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); break; |