diff options
Diffstat (limited to 'flang/lib/Lower/OpenMP/OpenMP.cpp')
-rw-r--r-- | flang/lib/Lower/OpenMP/OpenMP.cpp | 256 |
1 files changed, 215 insertions, 41 deletions
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index fcb20fd..4c2d7bad 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -31,6 +31,7 @@ #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Parser/characters.h" +#include "flang/Parser/openmp-utils.h" #include "flang/Parser/parse-tree.h" #include "flang/Semantics/openmp-directive-sets.h" #include "flang/Semantics/tools.h" @@ -446,7 +447,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, llvm::omp::Directive dir; auto &nested = parent.getFirstNestedEvaluation(); if (const auto *ompEval = nested.getIf<parser::OpenMPConstruct>()) - dir = extractOmpDirective(*ompEval); + dir = parser::omp::GetOmpDirectiveName(*ompEval).v; else return std::nullopt; @@ -486,7 +487,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter); assert(hostInfo && "expected HOST_EVAL info structure"); - switch (extractOmpDirective(*ompEval)) { + switch (parser::omp::GetOmpDirectiveName(*ompEval).v) { case OMPD_teams_distribute_parallel_do: case OMPD_teams_distribute_parallel_do_simd: cp.processThreadLimit(stmtCtx, hostInfo->ops); @@ -547,7 +548,8 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, const auto *ompEval = eval.getIf<parser::OpenMPConstruct>(); assert(ompEval && - llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) && + llvm::omp::allTargetSet.test( + parser::omp::GetOmpDirectiveName(*ompEval).v) && "expected TARGET construct evaluation"); (void)ompEval; @@ -642,8 +644,8 @@ static void threadPrivatizeVars(lower::AbstractConverter &converter, op = declOp.getMemref().getDefiningOp(); if (mlir::isa<mlir::omp::ThreadprivateOp>(op)) symValue = mlir::dyn_cast<mlir::omp::ThreadprivateOp>(op).getSymAddr(); - return firOpBuilder.create<mlir::omp::ThreadprivateOp>( - currentLocation, 
symValue.getType(), symValue); + return mlir::omp::ThreadprivateOp::create(firOpBuilder, currentLocation, + symValue.getType(), symValue); }; llvm::SetVector<const semantics::Symbol *> threadprivateSyms; @@ -710,7 +712,7 @@ createAndSetPrivatizedLoopVar(lower::AbstractConverter &converter, lhs = hlfir::derefPointersAndAllocatables(loc, firOpBuilder, lhs); mlir::Operation *storeOp = - firOpBuilder.create<hlfir::AssignOp>(loc, cvtVal, lhs); + hlfir::AssignOp::create(firOpBuilder, loc, cvtVal, lhs); return storeOp; } @@ -1156,8 +1158,8 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info, fir::FirOpBuilder &firOpBuilder = info.converter.getFirOpBuilder(); auto insertMarker = [](fir::FirOpBuilder &builder) { - mlir::Value undef = builder.create<fir::UndefOp>(builder.getUnknownLoc(), - builder.getIndexType()); + mlir::Value undef = fir::UndefOp::create(builder, builder.getUnknownLoc(), + builder.getIndexType()); return undef.getDefiningOp(); }; @@ -1271,7 +1273,7 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info, mlir::Block *exit = firOpBuilder.createBlock(&region); for (mlir::Block *b : exits) { firOpBuilder.setInsertionPointToEnd(b); - firOpBuilder.create<mlir::cf::BranchOp>(info.loc, exit); + mlir::cf::BranchOp::create(firOpBuilder, info.loc, exit); } return exit; }; @@ -1332,8 +1334,8 @@ static void genBodyOfTargetDataOp( // Remembering the position for further insertion is important since // there are hlfir.declares inserted above while setting block arguments // and new code from the body should be inserted after that. - mlir::Value undefMarker = firOpBuilder.create<fir::UndefOp>( - dataOp.getLoc(), firOpBuilder.getIndexType()); + mlir::Value undefMarker = fir::UndefOp::create(firOpBuilder, dataOp.getLoc(), + firOpBuilder.getIndexType()); // Create blocks for unstructured regions. This has to be done since // blocks are initially allocated with the function as the parent region. 
@@ -1342,7 +1344,7 @@ static void genBodyOfTargetDataOp( firOpBuilder, eval.getNestedEvaluations()); } - firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation); + mlir::omp::TerminatorOp::create(firOpBuilder, currentLocation); // Set the insertion point after the marker. firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp()); @@ -1496,8 +1498,8 @@ static void genBodyOfTargetOp( insertIndex, copyVal.getType(), copyVal.getLoc()); firOpBuilder.setInsertionPointToStart(entryBlock); - auto loadOp = firOpBuilder.create<fir::LoadOp>(clonedValArg.getLoc(), - clonedValArg); + auto loadOp = fir::LoadOp::create(firOpBuilder, clonedValArg.getLoc(), + clonedValArg); val.replaceUsesWithIf(loadOp->getResult(0), [entryBlock](mlir::OpOperand &use) { return use.getOwner()->getBlock() == entryBlock; @@ -1513,8 +1515,8 @@ static void genBodyOfTargetOp( // marker will be deleted since there are not uses. // In the HLFIR flow there are hlfir.declares inserted above while // setting block arguments. - mlir::Value undefMarker = firOpBuilder.create<fir::UndefOp>( - targetOp.getLoc(), firOpBuilder.getIndexType()); + mlir::Value undefMarker = fir::UndefOp::create( + firOpBuilder, targetOp.getLoc(), firOpBuilder.getIndexType()); // Create blocks for unstructured regions. This has to be done since // blocks are initially allocated with the function as the parent region. @@ -1524,7 +1526,7 @@ static void genBodyOfTargetOp( firOpBuilder, eval.getNestedEvaluations()); } - firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation); + mlir::omp::TerminatorOp::create(firOpBuilder, currentLocation); // Create the insertion point after the marker. firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp()); @@ -1570,7 +1572,7 @@ static OpTy genWrapperOp(lower::AbstractConverter &converter, fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); // Create wrapper. 
- auto op = firOpBuilder.create<OpTy>(loc, clauseOps); + auto op = OpTy::create(firOpBuilder, loc, clauseOps); // Create entry block with arguments. genEntryBlock(firOpBuilder, args, op.getRegion()); @@ -1983,7 +1985,7 @@ genCriticalOp(lower::AbstractConverter &converter, lower::SymMap &symTable, clauseOps, nameStr); mlir::OpBuilder modBuilder(mod.getBodyRegion()); - global = modBuilder.create<mlir::omp::CriticalDeclareOp>(loc, clauseOps); + global = mlir::omp::CriticalDeclareOp::create(modBuilder, loc, clauseOps); } nameAttr = mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), global.getSymName()); @@ -2069,6 +2071,163 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable, return loopOp; } +static mlir::omp::CanonicalLoopOp +genCanonicalLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, + lower::pft::Evaluation &eval, mlir::Location loc, + const ConstructQueue &queue, + ConstructQueue::const_iterator item, + llvm::ArrayRef<const semantics::Symbol *> ivs, + llvm::omp::Directive directive, DataSharingProcessor &dsp) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + assert(ivs.size() == 1 && "Nested loops not yet implemented"); + const semantics::Symbol *iv = ivs[0]; + + auto &nestedEval = eval.getFirstNestedEvaluation(); + if (nestedEval.getIf<parser::DoConstruct>()->IsDoConcurrent()) { + // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct. Will + // need to add special cases for this combination. 
+ TODO(loc, "DO CONCURRENT as canonical loop not supported"); + } + + // Get the loop bounds (and increment) + auto &doLoopEval = nestedEval.getFirstNestedEvaluation(); + auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>(); + assert(doStmt && "Expected do loop to be in the nested evaluation"); + auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t); + assert(loopControl.has_value()); + auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u); + assert(bounds && "Expected bounds for canonical loop"); + lower::StatementContext stmtCtx; + mlir::Value loopLBVar = fir::getBase( + converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx)); + mlir::Value loopUBVar = fir::getBase( + converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx)); + mlir::Value loopStepVar = [&]() { + if (bounds->step) { + return fir::getBase( + converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx)); + } + + // If `step` is not present, assume it is `1`. + return firOpBuilder.createIntegerConstant(loc, firOpBuilder.getI32Type(), + 1); + }(); + + // Get the integer kind for the loop variable and cast the loop bounds + size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size(); + mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); + loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar); + loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar); + loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar); + + // Start lowering + mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0); + mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1); + mlir::Value isDownwards = firOpBuilder.create<mlir::arith::CmpIOp>( + loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero); + + // Ensure we are counting upwards. If not, negate step and swap lb and ub. 
+ mlir::Value negStep = + firOpBuilder.create<mlir::arith::SubIOp>(loc, zero, loopStepVar); + mlir::Value incr = firOpBuilder.create<mlir::arith::SelectOp>( + loc, isDownwards, negStep, loopStepVar); + mlir::Value lb = firOpBuilder.create<mlir::arith::SelectOp>( + loc, isDownwards, loopUBVar, loopLBVar); + mlir::Value ub = firOpBuilder.create<mlir::arith::SelectOp>( + loc, isDownwards, loopLBVar, loopUBVar); + + // Compute the trip count assuming lb <= ub. This guarantees that the result + // is non-negative and we can use unsigned arithmetic. + mlir::Value span = firOpBuilder.create<mlir::arith::SubIOp>( + loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw); + mlir::Value tcMinusOne = + firOpBuilder.create<mlir::arith::DivUIOp>(loc, span, incr); + mlir::Value tcIfLooping = firOpBuilder.create<mlir::arith::AddIOp>( + loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw); + + // Fall back to 0 if lb > ub + mlir::Value isZeroTC = firOpBuilder.create<mlir::arith::CmpIOp>( + loc, mlir::arith::CmpIPredicate::slt, ub, lb); + mlir::Value tripcount = firOpBuilder.create<mlir::arith::SelectOp>( + loc, isZeroTC, zero, tcIfLooping); + + // Create the CLI handle. + auto newcli = firOpBuilder.create<mlir::omp::NewCliOp>(loc); + mlir::Value cli = newcli.getResult(); + + auto ivCallback = [&](mlir::Operation *op) + -> llvm::SmallVector<const Fortran::semantics::Symbol *> { + mlir::Region &region = op->getRegion(0); + + // Create the op's region skeleton (BB taking the iv as argument) + firOpBuilder.createBlock(&region, {}, {loopVarType}, {loc}); + + // Compute the value of the loop variable from the logical iteration number. 
+ mlir::Value natIterNum = fir::getBase(region.front().getArgument(0)); + mlir::Value scaled = + firOpBuilder.create<mlir::arith::MulIOp>(loc, natIterNum, loopStepVar); + mlir::Value userVal = + firOpBuilder.create<mlir::arith::AddIOp>(loc, loopLBVar, scaled); + + // The argument is not currently in memory, so make a temporary for the + // argument, and store it there, then bind that location to the argument. + mlir::Operation *storeOp = + createAndSetPrivatizedLoopVar(converter, loc, userVal, iv); + + firOpBuilder.setInsertionPointAfter(storeOp); + return {iv}; + }; + + // Create the omp.canonical_loop operation + auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>( + OpWithBodyGenInfo(converter, symTable, semaCtx, loc, nestedEval, + directive) + .setClauses(&item->clauses) + .setDataSharingProcessor(&dsp) + .setGenRegionEntryCb(ivCallback), + queue, item, tripcount, cli); + + firOpBuilder.setInsertionPointAfter(canonLoop); + return canonLoop; +} + +static void genUnrollOp(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, + lower::StatementContext &stmtCtx, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + mlir::Location loc, const ConstructQueue &queue, + ConstructQueue::const_iterator item) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + mlir::omp::LoopRelatedClauseOps loopInfo; + llvm::SmallVector<const semantics::Symbol *> iv; + collectLoopRelatedInfo(converter, loc, eval, item->clauses, loopInfo, iv); + + // Clauses for unrolling not yet implemented + ClauseProcessor cp(converter, semaCtx, item->clauses); + cp.processTODO<clause::Partial, clause::Full>( + loc, llvm::omp::Directive::OMPD_unroll); + + // Even though unroll does not support data-sharing clauses, this is + // required to fill the symbol table. 
+ DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, + /*shouldCollectPreDeterminedSymbols=*/true, + /*useDelayedPrivatization=*/false, symTable); + dsp.processStep1(); + + // Emit the associated loop + auto canonLoop = + genCanonicalLoopOp(converter, symTable, semaCtx, eval, loc, queue, item, + iv, llvm::omp::Directive::OMPD_unroll, dsp); + + // Apply unrolling to it + auto cli = canonLoop.getCli(); + firOpBuilder.create<mlir::omp::UnrollHeuristicOp>(loc, cli); +} + static mlir::omp::MaskedOp genMaskedOp(lower::AbstractConverter &converter, lower::SymMap &symTable, lower::StatementContext &stmtCtx, @@ -2201,7 +2360,7 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, } // SECTIONS construct. - auto sectionsOp = builder.create<mlir::omp::SectionsOp>(loc, clauseOps); + auto sectionsOp = mlir::omp::SectionsOp::create(builder, loc, clauseOps); // Create entry block with reduction variables as arguments. EntryBlockArgs args; @@ -2277,7 +2436,7 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, // races on post-update of lastprivate variables when `nowait` // clause is present. 
if (clauseOps.nowait && !lastprivates.empty()) - builder.create<mlir::omp::BarrierOp>(loc); + mlir::omp::BarrierOp::create(builder, loc); return sectionsOp; } @@ -2429,7 +2588,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, }; lower::pft::visitAllSymbols(eval, captureImplicitMap); - auto targetOp = firOpBuilder.create<mlir::omp::TargetOp>(loc, clauseOps); + auto targetOp = mlir::omp::TargetOp::create(firOpBuilder, loc, clauseOps); llvm::SmallVector<mlir::Value> hasDeviceAddrBaseValues, mapBaseValues; extractMappedBaseValues(clauseOps.hasDeviceAddrVars, hasDeviceAddrBaseValues); @@ -2509,7 +2668,7 @@ static OpTy genTargetEnterExitUpdateDataOp( genTargetEnterExitUpdateDataClauses(converter, semaCtx, symTable, stmtCtx, item->clauses, loc, directive, clauseOps); - return firOpBuilder.create<OpTy>(loc, clauseOps); + return OpTy::create(firOpBuilder, loc, clauseOps); } static mlir::omp::TaskOp @@ -3249,12 +3408,14 @@ static void genOMPDispatch(lower::AbstractConverter &converter, newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); break; - case llvm::omp::Directive::OMPD_tile: - case llvm::omp::Directive::OMPD_unroll: { + case llvm::omp::Directive::OMPD_tile: { unsigned version = semaCtx.langOptions().OpenMPVersion; TODO(loc, "Unhandled loop directive (" + llvm::omp::getOpenMPDirectiveName(dir, version) + ")"); } + case llvm::omp::Directive::OMPD_unroll: + genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); + break; // case llvm::omp::Directive::OMPD_workdistribute: case llvm::omp::Directive::OMPD_workshare: newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc, @@ -3342,8 +3503,8 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody()); auto mlirType = converter.genType(varType.declTypeSpec->derivedTypeSpec()); - auto declMapperOp = firOpBuilder.create<mlir::omp::DeclareMapperOp>( 
- loc, mapperNameStr, mlirType); + auto declMapperOp = mlir::omp::DeclareMapperOp::create( + firOpBuilder, loc, mapperNameStr, mlirType); auto &region = declMapperOp.getRegion(); firOpBuilder.createBlock(&region); auto varVal = region.addArgument(firOpBuilder.getRefType(mlirType), loc); @@ -3356,7 +3517,7 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, List<Clause> clauses = makeClauses(*clauseList, semaCtx); ClauseProcessor cp(converter, semaCtx, clauses); cp.processMap(loc, stmtCtx, clauseOps); - firOpBuilder.create<mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars); + mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseOps.mapVars); } static void @@ -3690,12 +3851,26 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, if (auto *ompNestedLoopCons{ std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>( &*optLoopCons)}) { - genOMP(converter, symTable, semaCtx, eval, ompNestedLoopCons->value()); + llvm::omp::Directive nestedDirective = + parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v; + switch (nestedDirective) { + case llvm::omp::Directive::OMPD_tile: + // Emit the omp.loop_nest with annotation for tiling + genOMP(converter, symTable, semaCtx, eval, ompNestedLoopCons->value()); + break; + default: { + unsigned version = semaCtx.langOptions().OpenMPVersion; + TODO(currentLocation, + "Applying a loop-associated on the loop generated by the " + + llvm::omp::getOpenMPDirectiveName(nestedDirective, version) + + " construct"); + } + } } } llvm::omp::Directive directive = - std::get<parser::OmpLoopDirective>(beginLoopDirective.t).v; + parser::omp::GetOmpDirectiveName(beginLoopDirective).v; const parser::CharBlock &source = std::get<parser::OmpLoopDirective>(beginLoopDirective.t).source; ConstructQueue queue{ @@ -3758,8 +3933,8 @@ mlir::Operation *Fortran::lower::genOpenMPTerminator(fir::FirOpBuilder &builder, mlir::Location loc) { if (mlir::isa<mlir::omp::AtomicUpdateOp, 
mlir::omp::DeclareReductionOp, mlir::omp::LoopNestOp>(op)) - return builder.create<mlir::omp::YieldOp>(loc); - return builder.create<mlir::omp::TerminatorOp>(loc); + return mlir::omp::YieldOp::create(builder, loc); + return mlir::omp::TerminatorOp::create(builder, loc); } void Fortran::lower::genOpenMPConstruct(lower::AbstractConverter &converter, @@ -3819,9 +3994,8 @@ void Fortran::lower::genThreadprivateOp(lower::AbstractConverter &converter, return; } // Generate ThreadprivateOp and rebind the common block. - mlir::Value commonThreadprivateValue = - firOpBuilder.create<mlir::omp::ThreadprivateOp>( - currentLocation, commonValue.getType(), commonValue); + mlir::Value commonThreadprivateValue = mlir::omp::ThreadprivateOp::create( + firOpBuilder, currentLocation, commonValue.getType(), commonValue); converter.bindSymbol(*common, commonThreadprivateValue); // Generate the threadprivate value for the common block member. symThreadprivateValue = genCommonBlockMember(converter, currentLocation, @@ -3841,10 +4015,10 @@ void Fortran::lower::genThreadprivateOp(lower::AbstractConverter &converter, global = globalInitialization(converter, firOpBuilder, sym, var, currentLocation); - mlir::Value symValue = firOpBuilder.create<fir::AddrOfOp>( - currentLocation, global.resultType(), global.getSymbol()); - symThreadprivateValue = firOpBuilder.create<mlir::omp::ThreadprivateOp>( - currentLocation, symValue.getType(), symValue); + mlir::Value symValue = fir::AddrOfOp::create( + firOpBuilder, currentLocation, global.resultType(), global.getSymbol()); + symThreadprivateValue = mlir::omp::ThreadprivateOp::create( + firOpBuilder, currentLocation, symValue.getType(), symValue); } else { mlir::Value symValue = converter.getSymbolAddress(sym); @@ -3859,8 +4033,8 @@ void Fortran::lower::genThreadprivateOp(lower::AbstractConverter &converter, if (mlir::isa<mlir::omp::ThreadprivateOp>(op)) return; - symThreadprivateValue = firOpBuilder.create<mlir::omp::ThreadprivateOp>( - currentLocation, 
symValue.getType(), symValue); + symThreadprivateValue = mlir::omp::ThreadprivateOp::create( + firOpBuilder, currentLocation, symValue.getType(), symValue); } fir::ExtendedValue sexv = converter.getSymbolExtendedValue(sym); |