diff options
Diffstat (limited to 'flang/lib')
24 files changed, 996 insertions, 273 deletions
diff --git a/flang/lib/Frontend/CompilerInvocation.cpp b/flang/lib/Frontend/CompilerInvocation.cpp index 81610ed..548ca67 100644 --- a/flang/lib/Frontend/CompilerInvocation.cpp +++ b/flang/lib/Frontend/CompilerInvocation.cpp @@ -1425,6 +1425,9 @@ static bool parseFloatingPointArgs(CompilerInvocation &invoc, opts.setFPContractMode(Fortran::common::LangOptions::FPM_Fast); } + if (args.hasArg(clang::driver::options::OPT_fno_fast_real_mod)) + opts.NoFastRealMod = true; + return true; } diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp index d5e0325..0c630d2 100644 --- a/flang/lib/Frontend/FrontendActions.cpp +++ b/flang/lib/Frontend/FrontendActions.cpp @@ -277,6 +277,14 @@ bool CodeGenAction::beginSourceFileAction() { ci.getInvocation().getLangOpts().OpenMPVersion); } + if (ci.getInvocation().getLangOpts().NoFastRealMod) { + mlir::ModuleOp mod = lb.getModule(); + mod.getOperation()->setAttr( + mlir::StringAttr::get(mod.getContext(), + llvm::Twine{"fir.no_fast_real_mod"}), + mlir::BoolAttr::get(mod.getContext(), true)); + } + // Create a parse tree and lower it to FIR parseAndLowerTree(ci, lb); diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 149e51b..780d56f 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -3182,7 +3182,7 @@ private: mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint(); localSymbols.pushScope(); mlir::Value exitCond = genOpenACCConstruct( - *this, bridge.getSemanticsContext(), getEval(), acc); + *this, bridge.getSemanticsContext(), getEval(), acc, localSymbols); const Fortran::parser::OpenACCLoopConstruct *accLoop = std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u); diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 95d0ada..f9b9b850 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -3184,7 +3184,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter, 
Fortran::lower::pft::Evaluation &eval, Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::AccClauseList &accClauseList) { + const Fortran::parser::AccClauseList &accClauseList, + Fortran::lower::SymMap &localSymbols) { mlir::Value ifCond; llvm::SmallVector<mlir::Value> dataOperands; bool addIfPresentAttr = false; @@ -3199,6 +3200,19 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter, } else if (const auto *useDevice = std::get_if<Fortran::parser::AccClause::UseDevice>( &clause.u)) { + // When CUDA Fortran is enabled, extra symbols are used in the host_data + // region. Look for them and bind their values with the symbols in the + // outer scope. + if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) { + const Fortran::parser::AccObjectList &objectList{useDevice->v}; + for (const auto &accObject : objectList.v) { + Fortran::semantics::Symbol &symbol = + getSymbolFromAccObject(accObject); + const Fortran::semantics::Symbol *baseSym = + localSymbols.lookupSymbolByName(symbol.name().ToString()); + localSymbols.copySymbolBinding(*baseSym, symbol); + } + } genDataOperandOperations<mlir::acc::UseDeviceOp>( useDevice->v, converter, semanticsContext, stmtCtx, dataOperands, mlir::acc::DataClause::acc_use_device, @@ -3239,11 +3253,11 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter, hostDataOp.setIfPresentAttr(builder.getUnitAttr()); } -static void -genACC(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semanticsContext, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenACCBlockConstruct &blockConstruct) { +static void genACC(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenACCBlockConstruct &blockConstruct, + Fortran::lower::SymMap &localSymbols) { const auto 
&beginBlockDirective = std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t); const auto &blockDirective = @@ -3273,7 +3287,7 @@ genACC(Fortran::lower::AbstractConverter &converter, accClauseList); } else if (blockDirective.v == llvm::acc::ACCD_host_data) { genACCHostDataOp(converter, currentLocation, eval, semanticsContext, - stmtCtx, accClauseList); + stmtCtx, accClauseList, localSymbols); } } @@ -4647,13 +4661,15 @@ mlir::Value Fortran::lower::genOpenACCConstruct( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenACCConstruct &accConstruct) { + const Fortran::parser::OpenACCConstruct &accConstruct, + Fortran::lower::SymMap &localSymbols) { mlir::Value exitCond; Fortran::common::visit( common::visitors{ [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) { - genACC(converter, semanticsContext, eval, blockConstruct); + genACC(converter, semanticsContext, eval, blockConstruct, + localSymbols); }, [&](const Fortran::parser::OpenACCCombinedConstruct &combinedConstruct) { diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index a96884f..55eda7e 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -431,6 +431,19 @@ bool ClauseProcessor::processNumTasks( return false; } +bool ClauseProcessor::processSizes(StatementContext &stmtCtx, + mlir::omp::SizesClauseOps &result) const { + if (auto *clause = findUniqueClause<omp::clause::Sizes>()) { + result.sizes.reserve(clause->v.size()); + for (const ExprTy &vv : clause->v) + result.sizes.push_back(fir::getBase(converter.genExprValue(vv, stmtCtx))); + + return true; + } + + return false; +} + bool ClauseProcessor::processNumTeams( lower::StatementContext &stmtCtx, mlir::omp::NumTeamsClauseOps &result) const { diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h 
b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 324ea3c..9e352fa 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -66,6 +66,8 @@ public: mlir::omp::LoopRelatedClauseOps &loopResult, mlir::omp::CollapseClauseOps &collapseResult, llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const; + bool processSizes(StatementContext &stmtCtx, + mlir::omp::SizesClauseOps &result) const; bool processDevice(lower::StatementContext &stmtCtx, mlir::omp::DeviceClauseOps &result) const; bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const; diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 1cb3335..9e56c2b 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1982,125 +1982,241 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable, return loopOp; } -static mlir::omp::CanonicalLoopOp -genCanonicalLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, - lower::pft::Evaluation &eval, mlir::Location loc, - const ConstructQueue &queue, - ConstructQueue::const_iterator item, - llvm::ArrayRef<const semantics::Symbol *> ivs, - llvm::omp::Directive directive) { +static void genCanonicalLoopNest( + lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, + mlir::Location loc, const ConstructQueue &queue, + ConstructQueue::const_iterator item, size_t numLoops, + llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) { + assert(loops.empty() && "Expecting empty list to fill"); + assert(numLoops >= 1 && "Expecting at least one loop"); + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - assert(ivs.size() == 1 && "Nested loops not yet implemented"); - const semantics::Symbol *iv = ivs[0]; + mlir::omp::LoopRelatedClauseOps loopInfo; + llvm::SmallVector<const semantics::Symbol *, 3> ivs; + 
collectLoopRelatedInfo(converter, loc, eval, numLoops, loopInfo, ivs); + assert(ivs.size() == numLoops && + "Expected to parse as many loop variables as there are loops"); + + // Steps that follow: + // 1. Emit all of the loop's prologues (compute the tripcount) + // 2. Emit omp.canonical_loop nested inside each other (iteratively) + // 2.1. In the innermost omp.canonical_loop, emit the loop body prologue (in + // the body callback) + // + // Since emitting prologues and body code is split, remember prologue values + // for use when emitting the same loop's epilogues. + llvm::SmallVector<mlir::Value> tripcounts; + llvm::SmallVector<mlir::Value> clis; + llvm::SmallVector<lower::pft::Evaluation *> evals; + llvm::SmallVector<mlir::Type> loopVarTypes; + llvm::SmallVector<mlir::Value> loopStepVars; + llvm::SmallVector<mlir::Value> loopLBVars; + llvm::SmallVector<mlir::Value> blockArgs; + + // Step 1: Loop prologues + // Computing the trip count must happen before entering the outermost loop + lower::pft::Evaluation *innermostEval = &eval.getFirstNestedEvaluation(); + for ([[maybe_unused]] auto iv : ivs) { + if (innermostEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) { + // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct. + // Will need to add special cases for this combination. 
+ TODO(loc, "DO CONCURRENT as canonical loop not supported"); + } + + auto &doLoopEval = innermostEval->getFirstNestedEvaluation(); + evals.push_back(innermostEval); + + // Get the loop bounds (and increment) + // auto &doLoopEval = nestedEval.getFirstNestedEvaluation(); + auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>(); + assert(doStmt && "Expected do loop to be in the nested evaluation"); + auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t); + assert(loopControl.has_value()); + auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u); + assert(bounds && "Expected bounds for canonical loop"); + lower::StatementContext stmtCtx; + mlir::Value loopLBVar = fir::getBase( + converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx)); + mlir::Value loopUBVar = fir::getBase( + converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx)); + mlir::Value loopStepVar = [&]() { + if (bounds->step) { + return fir::getBase( + converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx)); + } - auto &nestedEval = eval.getFirstNestedEvaluation(); - if (nestedEval.getIf<parser::DoConstruct>()->IsDoConcurrent()) { - // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct. Will - // need to add special cases for this combination. - TODO(loc, "DO CONCURRENT as canonical loop not supported"); + // If `step` is not present, assume it is `1`. 
+ auto intTy = firOpBuilder.getI32Type(); + return firOpBuilder.createIntegerConstant(loc, intTy, 1); + }(); + + // Get the integer kind for the loop variable and cast the loop bounds + size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size(); + mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); + loopVarTypes.push_back(loopVarType); + loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar); + loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar); + loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar); + loopLBVars.push_back(loopLBVar); + loopStepVars.push_back(loopStepVar); + + // Start lowering + mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0); + mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1); + mlir::Value isDownwards = firOpBuilder.create<mlir::arith::CmpIOp>( + loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero); + + // Ensure we are counting upwards. If not, negate step and swap lb and ub. + mlir::Value negStep = + firOpBuilder.create<mlir::arith::SubIOp>(loc, zero, loopStepVar); + mlir::Value incr = firOpBuilder.create<mlir::arith::SelectOp>( + loc, isDownwards, negStep, loopStepVar); + mlir::Value lb = firOpBuilder.create<mlir::arith::SelectOp>( + loc, isDownwards, loopUBVar, loopLBVar); + mlir::Value ub = firOpBuilder.create<mlir::arith::SelectOp>( + loc, isDownwards, loopLBVar, loopUBVar); + + // Compute the trip count assuming lb <= ub. This guarantees that the result + // is non-negative and we can use unsigned arithmetic. 
+ mlir::Value span = firOpBuilder.create<mlir::arith::SubIOp>( + loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw); + mlir::Value tcMinusOne = + firOpBuilder.create<mlir::arith::DivUIOp>(loc, span, incr); + mlir::Value tcIfLooping = firOpBuilder.create<mlir::arith::AddIOp>( + loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw); + + // Fall back to 0 if lb > ub + mlir::Value isZeroTC = firOpBuilder.create<mlir::arith::CmpIOp>( + loc, mlir::arith::CmpIPredicate::slt, ub, lb); + mlir::Value tripcount = firOpBuilder.create<mlir::arith::SelectOp>( + loc, isZeroTC, zero, tcIfLooping); + tripcounts.push_back(tripcount); + + // Create the CLI handle. + auto newcli = firOpBuilder.create<mlir::omp::NewCliOp>(loc); + mlir::Value cli = newcli.getResult(); + clis.push_back(cli); + + innermostEval = &*std::next(innermostEval->getNestedEvaluations().begin()); } - // Get the loop bounds (and increment) - auto &doLoopEval = nestedEval.getFirstNestedEvaluation(); - auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>(); - assert(doStmt && "Expected do loop to be in the nested evaluation"); - auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t); - assert(loopControl.has_value()); - auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u); - assert(bounds && "Expected bounds for canonical loop"); - lower::StatementContext stmtCtx; - mlir::Value loopLBVar = fir::getBase( - converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx)); - mlir::Value loopUBVar = fir::getBase( - converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx)); - mlir::Value loopStepVar = [&]() { - if (bounds->step) { - return fir::getBase( - converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx)); - } + // Step 2: Create nested canonical loops + for (auto i : llvm::seq<size_t>(numLoops)) { + bool isInnermost = (i == numLoops - 1); + mlir::Type loopVarType = loopVarTypes[i]; + mlir::Value tripcount = tripcounts[i]; + 
mlir::Value cli = clis[i]; + auto &&eval = evals[i]; + + auto ivCallback = [&, i, isInnermost](mlir::Operation *op) + -> llvm::SmallVector<const Fortran::semantics::Symbol *> { + mlir::Region &region = op->getRegion(0); + + // Create the op's region skeleton (BB taking the iv as argument) + firOpBuilder.createBlock(&region, {}, {loopVarType}, {loc}); + blockArgs.push_back(region.front().getArgument(0)); + + // Step 2.1: Emit body prologue code + // Compute the translation from logical iteration number to the value of + // the loop's iteration variable only in the innermost body. Currently, + // loop transformations do not allow any instruction between loops, but + // this is expected to change. + if (isInnermost) { + assert(blockArgs.size() == numLoops && + "Expecting all block args to have been collected by now"); + for (auto j : llvm::seq<size_t>(numLoops)) { + mlir::Value natIterNum = fir::getBase(blockArgs[j]); + mlir::Value scaled = firOpBuilder.create<mlir::arith::MulIOp>( + loc, natIterNum, loopStepVars[j]); + mlir::Value userVal = firOpBuilder.create<mlir::arith::AddIOp>( + loc, loopLBVars[j], scaled); + + mlir::OpBuilder::InsertPoint insPt = + firOpBuilder.saveInsertionPoint(); + firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); + mlir::Type tempTy = converter.genType(*ivs[j]); + firOpBuilder.restoreInsertionPoint(insPt); + + // Write the loop value into loop variable + mlir::Value cvtVal = firOpBuilder.createConvert(loc, tempTy, userVal); + hlfir::Entity lhs{converter.getSymbolAddress(*ivs[j])}; + lhs = hlfir::derefPointersAndAllocatables(loc, firOpBuilder, lhs); + mlir::Operation *storeOp = + hlfir::AssignOp::create(firOpBuilder, loc, cvtVal, lhs); + firOpBuilder.setInsertionPointAfter(storeOp); + } + } - // If `step` is not present, assume it is `1`. 
- return firOpBuilder.createIntegerConstant(loc, firOpBuilder.getI32Type(), - 1); - }(); + return {ivs[i]}; + }; - // Get the integer kind for the loop variable and cast the loop bounds - size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size(); - mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); - loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar); - loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar); - loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar); - - // Start lowering - mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0); - mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1); - mlir::Value isDownwards = mlir::arith::CmpIOp::create( - firOpBuilder, loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero); - - // Ensure we are counting upwards. If not, negate step and swap lb and ub. - mlir::Value negStep = - mlir::arith::SubIOp::create(firOpBuilder, loc, zero, loopStepVar); - mlir::Value incr = mlir::arith::SelectOp::create( - firOpBuilder, loc, isDownwards, negStep, loopStepVar); - mlir::Value lb = mlir::arith::SelectOp::create(firOpBuilder, loc, isDownwards, - loopUBVar, loopLBVar); - mlir::Value ub = mlir::arith::SelectOp::create(firOpBuilder, loc, isDownwards, - loopLBVar, loopUBVar); - - // Compute the trip count assuming lb <= ub. This guarantees that the result - // is non-negative and we can use unsigned arithmetic. 
- mlir::Value span = mlir::arith::SubIOp::create( - firOpBuilder, loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw); - mlir::Value tcMinusOne = - mlir::arith::DivUIOp::create(firOpBuilder, loc, span, incr); - mlir::Value tcIfLooping = - mlir::arith::AddIOp::create(firOpBuilder, loc, tcMinusOne, one, - ::mlir::arith::IntegerOverflowFlags::nuw); - - // Fall back to 0 if lb > ub - mlir::Value isZeroTC = mlir::arith::CmpIOp::create( - firOpBuilder, loc, mlir::arith::CmpIPredicate::slt, ub, lb); - mlir::Value tripcount = mlir::arith::SelectOp::create( - firOpBuilder, loc, isZeroTC, zero, tcIfLooping); - - // Create the CLI handle. - auto newcli = mlir::omp::NewCliOp::create(firOpBuilder, loc); - mlir::Value cli = newcli.getResult(); - - auto ivCallback = [&](mlir::Operation *op) - -> llvm::SmallVector<const Fortran::semantics::Symbol *> { - mlir::Region &region = op->getRegion(0); - - // Create the op's region skeleton (BB taking the iv as argument) - firOpBuilder.createBlock(&region, {}, {loopVarType}, {loc}); - - // Compute the value of the loop variable from the logical iteration number. 
- mlir::Value natIterNum = fir::getBase(region.front().getArgument(0)); - mlir::Value scaled = - mlir::arith::MulIOp::create(firOpBuilder, loc, natIterNum, loopStepVar); - mlir::Value userVal = - mlir::arith::AddIOp::create(firOpBuilder, loc, loopLBVar, scaled); - - // Write loop value to loop variable - mlir::Operation *storeOp = setLoopVar(converter, loc, userVal, iv); - - firOpBuilder.setInsertionPointAfter(storeOp); - return {iv}; - }; + // Create the omp.canonical_loop operation + auto opGenInfo = OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *eval, + llvm::omp::Directive::OMPD_unknown) + .setGenSkeletonOnly(!isInnermost) + .setClauses(&item->clauses) + .setPrivatize(false) + .setGenRegionEntryCb(ivCallback); + auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>( + std::move(opGenInfo), queue, item, tripcount, cli); + loops.push_back(canonLoop); + + // Insert next loop nested inside last loop + firOpBuilder.setInsertionPoint( + canonLoop.getRegion().back().getTerminator()); + } - // Create the omp.canonical_loop operation - auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>( - OpWithBodyGenInfo(converter, symTable, semaCtx, loc, nestedEval, - directive) - .setClauses(&item->clauses) - .setPrivatize(false) - .setGenRegionEntryCb(ivCallback), - queue, item, tripcount, cli); + firOpBuilder.setInsertionPointAfter(loops.front()); +} + +static void genTileOp(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, + lower::StatementContext &stmtCtx, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc, + const ConstructQueue &queue, + ConstructQueue::const_iterator item) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - firOpBuilder.setInsertionPointAfter(canonLoop); - return canonLoop; + mlir::omp::SizesClauseOps sizesClause; + ClauseProcessor cp(converter, semaCtx, item->clauses); + cp.processSizes(stmtCtx, sizesClause); + + size_t numLoops = 
sizesClause.sizes.size(); + llvm::SmallVector<mlir::omp::CanonicalLoopOp, 3> canonLoops; + canonLoops.reserve(numLoops); + + genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item, + numLoops, canonLoops); + assert((canonLoops.size() == numLoops) && + "Expecting the predetermined number of loops"); + + llvm::SmallVector<mlir::Value, 3> applyees; + applyees.reserve(numLoops); + for (mlir::omp::CanonicalLoopOp l : canonLoops) + applyees.push_back(l.getCli()); + + // Emit the associated loops and create a CLI for each affected loop + llvm::SmallVector<mlir::Value, 3> gridGeneratees; + llvm::SmallVector<mlir::Value, 3> intratileGeneratees; + gridGeneratees.reserve(numLoops); + intratileGeneratees.reserve(numLoops); + for ([[maybe_unused]] auto i : llvm::seq<int>(0, sizesClause.sizes.size())) { + auto gridCLI = firOpBuilder.create<mlir::omp::NewCliOp>(loc); + gridGeneratees.push_back(gridCLI.getResult()); + auto intratileCLI = firOpBuilder.create<mlir::omp::NewCliOp>(loc); + intratileGeneratees.push_back(intratileCLI.getResult()); + } + + llvm::SmallVector<mlir::Value, 6> generatees; + generatees.reserve(2 * numLoops); + generatees.append(gridGeneratees); + generatees.append(intratileGeneratees); + + firOpBuilder.create<mlir::omp::TileOp>(loc, generatees, applyees, + sizesClause.sizes); } static void genUnrollOp(Fortran::lower::AbstractConverter &converter, @@ -2112,22 +2228,22 @@ static void genUnrollOp(Fortran::lower::AbstractConverter &converter, ConstructQueue::const_iterator item) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::omp::LoopRelatedClauseOps loopInfo; - llvm::SmallVector<const semantics::Symbol *> iv; - collectLoopRelatedInfo(converter, loc, eval, item->clauses, loopInfo, iv); - // Clauses for unrolling not yet implemented ClauseProcessor cp(converter, semaCtx, item->clauses); cp.processTODO<clause::Partial, clause::Full>( loc, llvm::omp::Directive::OMPD_unroll); // Emit the associated loop - auto canonLoop = - 
genCanonicalLoopOp(converter, symTable, semaCtx, eval, loc, queue, item, - iv, llvm::omp::Directive::OMPD_unroll); + llvm::SmallVector<mlir::omp::CanonicalLoopOp, 1> canonLoops; + genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item, 1, + canonLoops); + + llvm::SmallVector<mlir::Value, 1> applyees; + for (auto &&canonLoop : canonLoops) + applyees.push_back(canonLoop.getCli()); // Apply unrolling to it - auto cli = canonLoop.getCli(); + auto cli = llvm::getSingleElement(canonLoops).getCli(); mlir::omp::UnrollHeuristicOp::create(firOpBuilder, loc, cli); } @@ -3360,13 +3476,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter, newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); break; - case llvm::omp::Directive::OMPD_tile: { - unsigned version = semaCtx.langOptions().OpenMPVersion; - if (!semaCtx.langOptions().OpenMPSimd) - TODO(loc, "Unhandled loop directive (" + - llvm::omp::getOpenMPDirectiveName(dir, version) + ")"); + case llvm::omp::Directive::OMPD_tile: + genTileOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); break; - } case llvm::omp::Directive::OMPD_unroll: genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); break; diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index 83b7ccb..29cccbd 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -652,7 +652,6 @@ int64_t collectLoopRelatedInfo( mlir::omp::LoopRelatedClauseOps &result, llvm::SmallVectorImpl<const semantics::Symbol *> &iv) { int64_t numCollapse = 1; - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); // Collect the loops to collapse. 
lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation(); @@ -667,6 +666,25 @@ int64_t collectLoopRelatedInfo( numCollapse = collapseValue; } + collectLoopRelatedInfo(converter, currentLocation, eval, numCollapse, result, + iv); + return numCollapse; +} + +void collectLoopRelatedInfo( + lower::AbstractConverter &converter, mlir::Location currentLocation, + lower::pft::Evaluation &eval, int64_t numCollapse, + mlir::omp::LoopRelatedClauseOps &result, + llvm::SmallVectorImpl<const semantics::Symbol *> &iv) { + + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + // Collect the loops to collapse. + lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation(); + if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) { + TODO(currentLocation, "Do Concurrent in Worksharing loop construct"); + } + // Collect sizes from tile directive if present. std::int64_t sizesLengthValue = 0l; if (auto *ompCons{eval.getIf<parser::OpenMPConstruct>()}) { @@ -676,7 +694,7 @@ int64_t collectLoopRelatedInfo( }); } - collapseValue = std::max(collapseValue, sizesLengthValue); + std::int64_t collapseValue = std::max(numCollapse, sizesLengthValue); std::size_t loopVarTypeSize = 0; do { lower::pft::Evaluation *doLoop = @@ -709,8 +727,6 @@ int64_t collectLoopRelatedInfo( } while (collapseValue > 0); convertLoopBounds(converter, currentLocation, result, loopVarTypeSize); - - return numCollapse; } } // namespace omp diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h index 5f191d8..69499f9 100644 --- a/flang/lib/Lower/OpenMP/Utils.h +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -165,6 +165,13 @@ int64_t collectLoopRelatedInfo( mlir::omp::LoopRelatedClauseOps &result, llvm::SmallVectorImpl<const semantics::Symbol *> &iv); +void collectLoopRelatedInfo( + lower::AbstractConverter &converter, mlir::Location currentLocation, + lower::pft::Evaluation &eval, std::int64_t collapseValue, + // const omp::List<omp::Clause> 
&clauses, + mlir::omp::LoopRelatedClauseOps &result, + llvm::SmallVectorImpl<const semantics::Symbol *> &iv); + void collectTileSizesFromOpenMPConstruct( const parser::OpenMPConstruct *ompCons, llvm::SmallVectorImpl<int64_t> &tileSizes, diff --git a/flang/lib/Lower/SymbolMap.cpp b/flang/lib/Lower/SymbolMap.cpp index 080f21e..78529e0 100644 --- a/flang/lib/Lower/SymbolMap.cpp +++ b/flang/lib/Lower/SymbolMap.cpp @@ -45,6 +45,16 @@ Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) { return SymbolBox::None{}; } +const Fortran::semantics::Symbol * +Fortran::lower::SymMap::lookupSymbolByName(llvm::StringRef symName) { + for (auto jmap = symbolMapStack.rbegin(), jend = symbolMapStack.rend(); + jmap != jend; ++jmap) + for (auto const &[sym, symBox] : *jmap) + if (sym->name().ToString() == symName) + return sym; + return nullptr; +} + Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol( Fortran::semantics::SymbolRef symRef) { auto *sym = symRef->HasLocalLocality() ? 
&*symRef : &symRef->GetUltimate(); diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp index 5e6e208..5da27d1 100644 --- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp +++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp @@ -1943,7 +1943,7 @@ void fir::factory::genDimInfoFromBox( return; unsigned rank = fir::getBoxRank(boxType); - assert(rank != 0 && "must be an array of known rank"); + assert(!boxType.isAssumedRank() && "must be an array of known rank"); mlir::Type idxTy = builder.getIndexType(); for (unsigned i = 0; i < rank; ++i) { mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i); diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp index f93eaf7..dbfcae1 100644 --- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -676,6 +676,34 @@ mlir::Value hlfir::genLBound(mlir::Location loc, fir::FirOpBuilder &builder, return dimInfo.getLowerBound(); } +static bool +getExprLengthParameters(mlir::Value expr, + llvm::SmallVectorImpl<mlir::Value> &result) { + if (auto concat = expr.getDefiningOp<hlfir::ConcatOp>()) { + result.push_back(concat.getLength()); + return true; + } + if (auto setLen = expr.getDefiningOp<hlfir::SetLengthOp>()) { + result.push_back(setLen.getLength()); + return true; + } + if (auto elemental = expr.getDefiningOp<hlfir::ElementalOp>()) { + result.append(elemental.getTypeparams().begin(), + elemental.getTypeparams().end()); + return true; + } + if (auto evalInMem = expr.getDefiningOp<hlfir::EvaluateInMemoryOp>()) { + result.append(evalInMem.getTypeparams().begin(), + evalInMem.getTypeparams().end()); + return true; + } + if (auto apply = expr.getDefiningOp<hlfir::ApplyOp>()) { + result.append(apply.getTypeparams().begin(), apply.getTypeparams().end()); + return true; + } + return false; +} + void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder, Entity entity, 
llvm::SmallVectorImpl<mlir::Value> &result) { @@ -688,29 +716,14 @@ void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder, // Going through fir::ExtendedValue would create a temp, // which is not desired for an inquiry. // TODO: make this an interface when adding further character producing ops. - if (auto concat = expr.getDefiningOp<hlfir::ConcatOp>()) { - result.push_back(concat.getLength()); - return; - } else if (auto concat = expr.getDefiningOp<hlfir::SetLengthOp>()) { - result.push_back(concat.getLength()); - return; - } else if (auto asExpr = expr.getDefiningOp<hlfir::AsExprOp>()) { + + if (auto asExpr = expr.getDefiningOp<hlfir::AsExprOp>()) { hlfir::genLengthParameters(loc, builder, hlfir::Entity{asExpr.getVar()}, result); return; - } else if (auto elemental = expr.getDefiningOp<hlfir::ElementalOp>()) { - result.append(elemental.getTypeparams().begin(), - elemental.getTypeparams().end()); - return; - } else if (auto evalInMem = - expr.getDefiningOp<hlfir::EvaluateInMemoryOp>()) { - result.append(evalInMem.getTypeparams().begin(), - evalInMem.getTypeparams().end()); - return; - } else if (auto apply = expr.getDefiningOp<hlfir::ApplyOp>()) { - result.append(apply.getTypeparams().begin(), apply.getTypeparams().end()); - return; } + if (getExprLengthParameters(expr, result)) + return; if (entity.isCharacter()) { result.push_back(hlfir::GetLengthOp::create(builder, loc, expr)); return; @@ -733,6 +746,36 @@ mlir::Value hlfir::genCharLength(mlir::Location loc, fir::FirOpBuilder &builder, return lenParams[0]; } +std::optional<std::int64_t> hlfir::getCharLengthIfConst(hlfir::Entity entity) { + if (!entity.isCharacter()) { + return std::nullopt; + } + if (mlir::isa<hlfir::ExprType>(entity.getType())) { + mlir::Value expr = entity; + if (auto reassoc = expr.getDefiningOp<hlfir::NoReassocOp>()) + expr = reassoc.getVal(); + + if (auto asExpr = expr.getDefiningOp<hlfir::AsExprOp>()) + return getCharLengthIfConst(hlfir::Entity{asExpr.getVar()}); 
+ + llvm::SmallVector<mlir::Value> param; + if (getExprLengthParameters(expr, param)) { + assert(param.size() == 1 && "characters must have one length parameters"); + return fir::getIntIfConstant(param.pop_back_val()); + } + return std::nullopt; + } + + // entity is a var + if (mlir::Value len = tryGettingNonDeferredCharLen(entity)) + return fir::getIntIfConstant(len); + auto charType = + mlir::cast<fir::CharacterType>(entity.getFortranElementType()); + if (charType.hasConstantLen()) + return charType.getLen(); + return std::nullopt; +} + mlir::Value hlfir::genRank(mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity entity, mlir::Type resultType) { if (!entity.isAssumedRank()) diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 71d35e3..de7694f 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -6989,8 +6989,33 @@ mlir::Value IntrinsicLibrary::genMergeBits(mlir::Type resultType, } // MOD +static mlir::Value genFastMod(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value a, mlir::Value p) { + auto fastmathFlags = mlir::arith::FastMathFlags::contract; + auto fastmathAttr = + mlir::arith::FastMathFlagsAttr::get(builder.getContext(), fastmathFlags); + mlir::Value divResult = + mlir::arith::DivFOp::create(builder, loc, a, p, fastmathAttr); + mlir::Type intType = builder.getIntegerType( + a.getType().getIntOrFloatBitWidth(), /*signed=*/true); + mlir::Value intResult = builder.createConvert(loc, intType, divResult); + mlir::Value cnvResult = builder.createConvert(loc, a.getType(), intResult); + mlir::Value mulResult = + mlir::arith::MulFOp::create(builder, loc, cnvResult, p, fastmathAttr); + mlir::Value subResult = + mlir::arith::SubFOp::create(builder, loc, a, mulResult, fastmathAttr); + return subResult; +} + mlir::Value IntrinsicLibrary::genMod(mlir::Type resultType, llvm::ArrayRef<mlir::Value> args) { + auto mod = 
builder.getModule(); + bool dontUseFastRealMod = false; + bool canUseApprox = mlir::arith::bitEnumContainsAny( + builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn); + if (auto attr = mod->getAttrOfType<mlir::BoolAttr>("fir.no_fast_real_mod")) + dontUseFastRealMod = attr.getValue(); + assert(args.size() == 2); if (resultType.isUnsignedInteger()) { mlir::Type signlessType = mlir::IntegerType::get( @@ -7002,9 +7027,16 @@ mlir::Value IntrinsicLibrary::genMod(mlir::Type resultType, if (mlir::isa<mlir::IntegerType>(resultType)) return mlir::arith::RemSIOp::create(builder, loc, args[0], args[1]); - // Use runtime. - return builder.createConvert( - loc, resultType, fir::runtime::genMod(builder, loc, args[0], args[1])); + if (resultType.isFloat() && canUseApprox && !dontUseFastRealMod) { + // Treat MOD as an approximate function and code-gen inline code + // instead of calling into the Fortran runtime library. + return builder.createConvert(loc, resultType, + genFastMod(builder, loc, args[0], args[1])); + } else { + // Use runtime. + return builder.createConvert( + loc, resultType, fir::runtime::genMod(builder, loc, args[0], args[1])); + } } // MODULO diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index d8e36ea..ce8ebaa 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -2284,6 +2284,212 @@ public: } }; +static std::pair<mlir::Value, hlfir::AssociateOp> +getVariable(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value val) { + // If it is an expression - create a variable from it, or forward + // the value otherwise. 
+ hlfir::AssociateOp associate; + if (!mlir::isa<hlfir::ExprType>(val.getType())) + return {val, associate}; + hlfir::Entity entity{val}; + mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr(builder); + associate = hlfir::genAssociateExpr(loc, builder, entity, entity.getType(), + "", byRefAttr); + return {associate.getBase(), associate}; +} + +class IndexOpConversion : public mlir::OpRewritePattern<hlfir::IndexOp> { +public: + using mlir::OpRewritePattern<hlfir::IndexOp>::OpRewritePattern; + + llvm::LogicalResult + matchAndRewrite(hlfir::IndexOp op, + mlir::PatternRewriter &rewriter) const override { + // We simplify only limited cases: + // 1) a substring length shall be known at compile time + // 2) if a substring length is 0 then replace with 1 for forward search, + // or otherwise with the string length + 1 (builder shall const-fold if + // lookup direction is known at compile time). + // 3) for known string length at compile time, if it is + // shorter than substring => replace with zero. + // 4) if a substring length is one => inline as simple search loop + // 5) for forward search with input strings of kind=1 runtime is faster. + // Do not simplify in all the other cases relying on a runtime call. + + fir::FirOpBuilder builder{rewriter, op.getOperation()}; + const mlir::Location &loc = op->getLoc(); + + auto resultTy = op.getType(); + mlir::Value back = op.getBack(); + auto substrLenCst = + hlfir::getCharLengthIfConst(hlfir::Entity{op.getSubstr()}); + if (!substrLenCst) { + return rewriter.notifyMatchFailure( + op, "substring length unknown at compile time"); + } + hlfir::Entity strEntity{op.getStr()}; + auto i1Ty = builder.getI1Type(); + auto idxTy = builder.getIndexType(); + if (*substrLenCst == 0) { + mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1); + // zero length substring. For back search replace with + // strLen+1, or otherwise with 1. 
+ mlir::Value strLen = hlfir::genCharLength(loc, builder, strEntity); + mlir::Value strEnd = mlir::arith::AddIOp::create( + builder, loc, builder.createConvert(loc, idxTy, strLen), oneIdx); + if (back) + back = builder.createConvert(loc, i1Ty, back); + else + back = builder.createIntegerConstant(loc, i1Ty, 0); + mlir::Value result = + mlir::arith::SelectOp::create(builder, loc, back, strEnd, oneIdx); + + rewriter.replaceOp(op, builder.createConvert(loc, resultTy, result)); + return mlir::success(); + } + + if (auto strLenCst = hlfir::getCharLengthIfConst(strEntity)) { + if (*strLenCst < *substrLenCst) { + rewriter.replaceOp(op, builder.createIntegerConstant(loc, resultTy, 0)); + return mlir::success(); + } + if (*strLenCst == 0) { + // both strings have zero length + rewriter.replaceOp(op, builder.createIntegerConstant(loc, resultTy, 1)); + return mlir::success(); + } + } + if (*substrLenCst != 1) { + return rewriter.notifyMatchFailure( + op, "rely on runtime implementation if substring length > 1"); + } + // For forward search and character kind=1 the runtime uses memchr + // which well optimized. But it looks like memchr idiom is not recognized + // in LLVM yet. On a micro-kernel test with strings of length 40 runtime + // had ~2x less execution time vs inlined code. For unknown search direction + // at compile time pessimistically assume "forward". + std::optional<bool> isBack; + if (back) { + if (auto backCst = fir::getIntIfConstant(back)) + isBack = *backCst != 0; + } else { + isBack = false; + } + auto charTy = mlir::cast<fir::CharacterType>( + hlfir::getFortranElementType(op.getSubstr().getType())); + unsigned kind = charTy.getFKind(); + if (kind == 1 && (!isBack || !*isBack)) { + return rewriter.notifyMatchFailure( + op, "rely on runtime implementation for character kind 1"); + } + + // All checks are passed here. Generate single character search loop. 
+ auto [strV, strAssociate] = getVariable(builder, loc, op.getStr()); + auto [substrV, substrAssociate] = getVariable(builder, loc, op.getSubstr()); + hlfir::Entity str{strV}; + hlfir::Entity substr{substrV}; + mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1); + + auto genExtractAndConvertToInt = [&charTy, &idxTy, &oneIdx, + kind](mlir::Location loc, + fir::FirOpBuilder &builder, + hlfir::Entity &charStr, + mlir::Value index) { + auto bits = builder.getKindMap().getCharacterBitsize(kind); + auto intTy = builder.getIntegerType(bits); + auto charLen1Ty = + fir::CharacterType::getSingleton(builder.getContext(), kind); + mlir::Type designatorTy = + fir::ReferenceType::get(charLen1Ty, fir::isa_volatile_type(charTy)); + auto idxAttr = builder.getIntegerAttr(idxTy, 0); + + auto singleChr = hlfir::DesignateOp::create( + builder, loc, designatorTy, charStr, /*component=*/{}, + /*compShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{}, + /*substring=*/mlir::ValueRange{index, index}, + /*complexPart=*/std::nullopt, + /*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{oneIdx}, + fir::FortranVariableFlagsAttr{}); + auto chrVal = fir::LoadOp::create(builder, loc, singleChr); + mlir::Value intVal = fir::ExtractValueOp::create( + builder, loc, intTy, chrVal, builder.getArrayAttr(idxAttr)); + return intVal; + }; + + auto wantChar = genExtractAndConvertToInt(loc, builder, substr, oneIdx); + + // Generate search loop body with the following C equivalent: + // idx_t result = 0; + // idx_t end = strlen + 1; + // char want = substr[0]; + // for (idx_t idx = 1; idx < end; ++idx) { + // if (result == 0) { + // idx_t at = back ? end - idx: idx; + // result = str[at-1] == want ? 
at : result; + // } + // } + mlir::Value strLen = hlfir::genCharLength(loc, builder, strEntity); + if (!back) + back = builder.createIntegerConstant(loc, i1Ty, 0); + else + back = builder.createConvert(loc, i1Ty, back); + mlir::Value strEnd = mlir::arith::AddIOp::create( + builder, loc, builder.createConvert(loc, idxTy, strLen), oneIdx); + mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0); + auto genSearchBody = [&](mlir::Location loc, fir::FirOpBuilder &builder, + mlir::ValueRange index, + mlir::ValueRange reductionArgs) + -> llvm::SmallVector<mlir::Value, 1> { + assert(index.size() == 1 && "expected single loop"); + assert(reductionArgs.size() == 1 && "expected single reduction value"); + mlir::Value inRes = reductionArgs[0]; + auto resEQzero = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroIdx); + + mlir::Value res = + builder + .genIfOp(loc, {idxTy}, resEQzero, + /*withElseRegion=*/true) + .genThen([&]() { + mlir::Value idx = builder.createConvert(loc, idxTy, index[0]); + // offset = back ? 
end - idx : idx; + mlir::Value offset = mlir::arith::SelectOp::create( + builder, loc, back, + mlir::arith::SubIOp::create(builder, loc, strEnd, idx), + idx); + + auto haveChar = + genExtractAndConvertToInt(loc, builder, str, offset); + auto charsEQ = mlir::arith::CmpIOp::create( + builder, loc, mlir::arith::CmpIPredicate::eq, haveChar, + wantChar); + mlir::Value newVal = mlir::arith::SelectOp::create( + builder, loc, charsEQ, offset, inRes); + + fir::ResultOp::create(builder, loc, newVal); + }) + .genElse([&]() { fir::ResultOp::create(builder, loc, inRes); }) + .getResults()[0]; + return {res}; + }; + + llvm::SmallVector<mlir::Value, 1> loopOut = + hlfir::genLoopNestWithReductions(loc, builder, {strLen}, + /*reductionInits=*/{zeroIdx}, + genSearchBody, + /*isUnordered=*/false); + mlir::Value result = builder.createConvert(loc, resultTy, loopOut[0]); + + if (strAssociate) + hlfir::EndAssociateOp::create(builder, loc, strAssociate); + if (substrAssociate) + hlfir::EndAssociateOp::create(builder, loc, substrAssociate); + + rewriter.replaceOp(op, result); + return mlir::success(); + } +}; + template <typename Op> class MatmulConversion : public mlir::OpRewritePattern<Op> { public: @@ -2955,6 +3161,7 @@ public: patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context); patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context); patterns.insert<CmpCharOpConversion>(context); + patterns.insert<IndexOpConversion>(context); patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context); patterns.insert<ReductionConversion<hlfir::CountOp>>(context); patterns.insert<ReductionConversion<hlfir::AnyOp>>(context); diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp index 57be863..e595e61 100644 --- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp +++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp @@ -41,7 +41,9 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include 
"llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringSet.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Support/raw_ostream.h" #include <algorithm> #include <cstddef> #include <iterator> @@ -75,6 +77,112 @@ class MapInfoFinalizationPass /// | | std::map<mlir::Operation *, mlir::Value> localBoxAllocas; + /// Return true if the given path exists in a list of paths. + static bool + containsPath(const llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &paths, + llvm::ArrayRef<int64_t> path) { + return llvm::any_of(paths, [&](const llvm::SmallVector<int64_t> &p) { + return p.size() == path.size() && + std::equal(p.begin(), p.end(), path.begin()); + }); + } + + /// Return true if the given path is already present in + /// op.getMembersIndexAttr(). + static bool mappedIndexPathExists(mlir::omp::MapInfoOp op, + llvm::ArrayRef<int64_t> indexPath) { + if (mlir::ArrayAttr attr = op.getMembersIndexAttr()) { + for (mlir::Attribute list : attr) { + auto listAttr = mlir::cast<mlir::ArrayAttr>(list); + if (listAttr.size() != indexPath.size()) + continue; + bool allEq = true; + for (auto [i, val] : llvm::enumerate(listAttr)) { + if (mlir::cast<mlir::IntegerAttr>(val).getInt() != indexPath[i]) { + allEq = false; + break; + } + } + if (allEq) + return true; + } + } + return false; + } + + /// Build a compact string key for an index path for set-based + /// deduplication. Format: "N:v0,v1,..." where N is the length. + static void buildPathKey(llvm::ArrayRef<int64_t> path, + llvm::SmallString<64> &outKey) { + outKey.clear(); + llvm::raw_svector_ostream os(outKey); + os << path.size() << ':'; + for (size_t i = 0; i < path.size(); ++i) { + if (i) + os << ','; + os << path[i]; + } + } + + /// Create the member map for coordRef and append it (and its index + /// path) to the provided new* vectors, if it is not already present. 
+ void appendMemberMapIfNew( + mlir::omp::MapInfoOp op, fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value coordRef, llvm::ArrayRef<int64_t> indexPath, + llvm::StringRef memberName, + llvm::SmallVectorImpl<mlir::Value> &newMapOpsForFields, + llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &newMemberIndexPaths) { + // Local de-dup within this op invocation. + if (containsPath(newMemberIndexPaths, indexPath)) + return; + // Global de-dup against already present member indices. + if (mappedIndexPathExists(op, indexPath)) + return; + + if (op.getMapperId()) { + mlir::omp::DeclareMapperOp symbol = + mlir::SymbolTable::lookupNearestSymbolFrom< + mlir::omp::DeclareMapperOp>(op, op.getMapperIdAttr()); + assert(symbol && "missing symbol for declare mapper identifier"); + mlir::omp::DeclareMapperInfoOp mapperInfo = symbol.getDeclareMapperInfo(); + // TODO: Probably a way to cache these keys in someway so we don't + // constantly go through the process of rebuilding them on every check, to + // save some cycles, but it can wait for a subsequent patch. 
+ for (auto v : mapperInfo.getMapVars()) { + mlir::omp::MapInfoOp map = + mlir::cast<mlir::omp::MapInfoOp>(v.getDefiningOp()); + if (!map.getMembers().empty() && mappedIndexPathExists(map, indexPath)) + return; + } + } + + builder.setInsertionPoint(op); + fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr( + builder, coordRef, /*isOptional=*/false, loc); + llvm::SmallVector<mlir::Value> bounds = fir::factory::genImplicitBoundsOps< + mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>( + builder, info, + hlfir::translateToExtendedValue(loc, builder, hlfir::Entity{coordRef}) + .first, + /*dataExvIsAssumedSize=*/false, loc); + + mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create( + builder, loc, coordRef.getType(), coordRef, + mlir::TypeAttr::get(fir::unwrapRefType(coordRef.getType())), + op.getMapTypeAttr(), + builder.getAttr<mlir::omp::VariableCaptureKindAttr>( + mlir::omp::VariableCaptureKind::ByRef), + /*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{}, + /*members_index=*/mlir::ArrayAttr{}, bounds, + /*mapperId=*/mlir::FlatSymbolRefAttr(), + builder.getStringAttr(op.getNameAttr().strref() + "." 
+ memberName + + ".implicit_map"), + /*partial_map=*/builder.getBoolAttr(false)); + + newMapOpsForFields.emplace_back(fieldMapOp); + newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end()); + } + /// getMemberUserList gathers all users of a particular MapInfoOp that are /// other MapInfoOp's and places them into the mapMemberUsers list, which /// records the map that the current argument MapInfoOp "op" is part of @@ -363,7 +471,7 @@ class MapInfoFinalizationPass mlir::ArrayAttr newMembersAttr; mlir::SmallVector<mlir::Value> newMembers; llvm::SmallVector<llvm::SmallVector<int64_t>> memberIndices; - bool IsHasDeviceAddr = isHasDeviceAddr(op, target); + bool isHasDeviceAddrFlag = isHasDeviceAddr(op, target); if (!mapMemberUsers.empty() || !op.getMembers().empty()) getMemberIndicesAsVectors( @@ -406,7 +514,7 @@ class MapInfoFinalizationPass mapUser.parent.getMembersMutable().assign(newMemberOps); mapUser.parent.setMembersIndexAttr( builder.create2DI64ArrayAttr(memberIndices)); - } else if (!IsHasDeviceAddr) { + } else if (!isHasDeviceAddrFlag) { auto baseAddr = genBaseAddrMap(descriptor, op.getBounds(), op.getMapType(), builder); newMembers.push_back(baseAddr); @@ -429,7 +537,7 @@ class MapInfoFinalizationPass // The contents of the descriptor (the base address in particular) will // remain unchanged though. uint64_t mapType = op.getMapType(); - if (IsHasDeviceAddr) { + if (isHasDeviceAddrFlag) { mapType |= llvm::to_underlying( llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS); } @@ -701,94 +809,134 @@ class MapInfoFinalizationPass auto recordType = mlir::cast<fir::RecordType>(underlyingType); llvm::SmallVector<mlir::Value> newMapOpsForFields; - llvm::SmallVector<int64_t> fieldIndicies; + llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndexPaths; + // 1) Handle direct top-level allocatable fields. 
for (auto fieldMemTyPair : recordType.getTypeList()) { auto &field = fieldMemTyPair.first; auto memTy = fieldMemTyPair.second; - bool shouldMapField = - llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) { - if (!fir::isAllocatableType(memTy)) - return false; - - auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp); - if (!designateOp) - return false; - - return designateOp.getComponent() && - designateOp.getComponent()->strref() == field; - }) != mapVarForwardSlice.end(); - - // TODO Handle recursive record types. Adapting - // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR - // entities might be helpful here. - - if (!shouldMapField) + if (!fir::isAllocatableType(memTy)) continue; - int32_t fieldIdx = recordType.getFieldIndex(field); - bool alreadyMapped = [&]() { - if (op.getMembersIndexAttr()) - for (auto indexList : op.getMembersIndexAttr()) { - auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList); - if (indexListAttr.size() == 1 && - mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() == - fieldIdx) - return true; - } - - return false; - }(); - - if (alreadyMapped) + bool referenced = llvm::any_of(mapVarForwardSlice, [&](auto *opv) { + auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv); + return designateOp && designateOp.getComponent() && + designateOp.getComponent()->strref() == field; + }); + if (!referenced) continue; + int32_t fieldIdx = recordType.getFieldIndex(field); builder.setInsertionPoint(op); fir::IntOrValue idxConst = mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx); auto fieldCoord = fir::CoordinateOp::create( builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(), llvm::SmallVector<fir::IntOrValue, 1>{idxConst}); - fir::factory::AddrAndBoundsInfo info = - fir::factory::getDataOperandBaseAddr( - builder, fieldCoord, /*isOptional=*/false, op.getLoc()); - llvm::SmallVector<mlir::Value> bounds = - fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp, - 
mlir::omp::MapBoundsType>( - builder, info, - hlfir::translateToExtendedValue(op.getLoc(), builder, - hlfir::Entity{fieldCoord}) - .first, - /*dataExvIsAssumedSize=*/false, op.getLoc()); - - mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create( - builder, op.getLoc(), fieldCoord.getResult().getType(), - fieldCoord.getResult(), - mlir::TypeAttr::get( - fir::unwrapRefType(fieldCoord.getResult().getType())), - op.getMapTypeAttr(), - builder.getAttr<mlir::omp::VariableCaptureKindAttr>( - mlir::omp::VariableCaptureKind::ByRef), - /*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{}, - /*members_index=*/mlir::ArrayAttr{}, bounds, - /*mapperId=*/mlir::FlatSymbolRefAttr(), - builder.getStringAttr(op.getNameAttr().strref() + "." + field + - ".implicit_map"), - /*partial_map=*/builder.getBoolAttr(false)); - newMapOpsForFields.emplace_back(fieldMapOp); - fieldIndicies.emplace_back(fieldIdx); + int64_t fieldIdx64 = static_cast<int64_t>(fieldIdx); + llvm::SmallVector<int64_t, 1> idxPath{fieldIdx64}; + appendMemberMapIfNew(op, builder, op.getLoc(), fieldCoord, idxPath, + field, newMapOpsForFields, newMemberIndexPaths); + } + + // Handle nested allocatable fields along any component chain + // referenced in the region via HLFIR designates. 
+ llvm::SmallVector<llvm::SmallVector<int64_t>> seenIndexPaths; + for (mlir::Operation *sliceOp : mapVarForwardSlice) { + auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp); + if (!designateOp || !designateOp.getComponent()) + continue; + llvm::SmallVector<llvm::StringRef> compPathReversed; + compPathReversed.push_back(designateOp.getComponent()->strref()); + mlir::Value curBase = designateOp.getMemref(); + bool rootedAtMapArg = false; + while (true) { + if (auto parentDes = curBase.getDefiningOp<hlfir::DesignateOp>()) { + if (!parentDes.getComponent()) + break; + compPathReversed.push_back(parentDes.getComponent()->strref()); + curBase = parentDes.getMemref(); + continue; + } + if (auto decl = curBase.getDefiningOp<hlfir::DeclareOp>()) { + if (auto barg = + mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref())) + rootedAtMapArg = (barg == opBlockArg); + } else if (auto blockArg = + mlir::dyn_cast_or_null<mlir::BlockArgument>( + curBase)) { + rootedAtMapArg = (blockArg == opBlockArg); + } + break; + } + // Only process nested paths (2+ components). Single-component paths + // for direct fields are handled above. 
+ if (!rootedAtMapArg || compPathReversed.size() < 2) + continue; + builder.setInsertionPoint(op); + llvm::SmallVector<int64_t> indexPath; + mlir::Type curTy = underlyingType; + mlir::Value coordRef = op.getVarPtr(); + bool validPath = true; + for (llvm::StringRef compName : llvm::reverse(compPathReversed)) { + auto recTy = mlir::dyn_cast<fir::RecordType>(curTy); + if (!recTy) { + validPath = false; + break; + } + int32_t idx = recTy.getFieldIndex(compName); + if (idx < 0) { + validPath = false; + break; + } + indexPath.push_back(idx); + mlir::Type memTy = recTy.getType(idx); + fir::IntOrValue idxConst = + mlir::IntegerAttr::get(builder.getI32Type(), idx); + coordRef = fir::CoordinateOp::create( + builder, op.getLoc(), builder.getRefType(memTy), coordRef, + llvm::SmallVector<fir::IntOrValue, 1>{idxConst}); + curTy = memTy; + } + if (!validPath) + continue; + if (auto finalRefTy = + mlir::dyn_cast<fir::ReferenceType>(coordRef.getType())) { + mlir::Type eleTy = finalRefTy.getElementType(); + if (fir::isAllocatableType(eleTy)) { + if (!containsPath(seenIndexPaths, indexPath)) { + seenIndexPaths.emplace_back(indexPath.begin(), indexPath.end()); + appendMemberMapIfNew(op, builder, op.getLoc(), coordRef, + indexPath, compPathReversed.front(), + newMapOpsForFields, newMemberIndexPaths); + } + } + } } if (newMapOpsForFields.empty()) return mlir::WalkResult::advance(); - op.getMembersMutable().append(newMapOpsForFields); + // Deduplicate by index path to avoid emitting duplicate members for + // the same component. Use a set-based key to keep this near O(n). 
+ llvm::SmallVector<mlir::Value> dedupMapOps; + llvm::SmallVector<llvm::SmallVector<int64_t>> dedupIndexPaths; + llvm::StringSet<> seenKeys; + for (auto [i, mapOp] : llvm::enumerate(newMapOpsForFields)) { + const auto &path = newMemberIndexPaths[i]; + llvm::SmallString<64> key; + buildPathKey(path, key); + if (seenKeys.contains(key)) + continue; + seenKeys.insert(key); + dedupMapOps.push_back(mapOp); + dedupIndexPaths.emplace_back(path.begin(), path.end()); + } + op.getMembersMutable().append(dedupMapOps); llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices; - mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr(); - - if (oldMembersIdxAttr) - for (mlir::Attribute indexList : oldMembersIdxAttr) { + if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr()) + for (mlir::Attribute indexList : oldAttr) { llvm::SmallVector<int64_t> listVec; for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList)) @@ -796,10 +944,8 @@ class MapInfoFinalizationPass newMemberIndices.emplace_back(std::move(listVec)); } - - for (int64_t newFieldIdx : fieldIndicies) - newMemberIndices.emplace_back( - llvm::SmallVector<int64_t>(1, newFieldIdx)); + for (auto &path : dedupIndexPaths) + newMemberIndices.emplace_back(path); op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices)); op.setPartialMap(true); diff --git a/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp b/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp index bdf7e4a..e006d2e 100644 --- a/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp +++ b/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp @@ -285,11 +285,16 @@ mlir::LLVM::DIModuleAttr AddDebugInfoPass::getOrCreateModuleAttr( if (auto iter{moduleMap.find(name)}; iter != moduleMap.end()) { modAttr = iter->getValue(); } else { + // When decl is true, it means that module is only being used in this + // compilation unit and it is defined elsewhere. 
But if the file/line/scope + // fields are valid, the module is not merged with its definition and is + // considered different. So we only set those fields when decl is false. modAttr = mlir::LLVM::DIModuleAttr::get( - context, fileAttr, scope, mlir::StringAttr::get(context, name), + context, decl ? nullptr : fileAttr, decl ? nullptr : scope, + mlir::StringAttr::get(context, name), /* configMacros */ mlir::StringAttr(), /* includePath */ mlir::StringAttr(), - /* apinotes */ mlir::StringAttr(), line, decl); + /* apinotes */ mlir::StringAttr(), decl ? 0 : line, decl); moduleMap[name] = modAttr; } return modAttr; diff --git a/flang/lib/Parser/parsing.cpp b/flang/lib/Parser/parsing.cpp index 8a8c6ef..2df6881 100644 --- a/flang/lib/Parser/parsing.cpp +++ b/flang/lib/Parser/parsing.cpp @@ -85,6 +85,7 @@ const SourceFile *Parsing::Prescan(const std::string &path, Options options) { if (options.features.IsEnabled(LanguageFeature::OpenACC) || (options.prescanAndReformat && noneOfTheAbove)) { prescanner.AddCompilerDirectiveSentinel("$acc"); + prescanner.AddCompilerDirectiveSentinel("@acc"); } if (options.features.IsEnabled(LanguageFeature::OpenMP) || (options.prescanAndReformat && noneOfTheAbove)) { diff --git a/flang/lib/Parser/prescan.cpp b/flang/lib/Parser/prescan.cpp index 865c149..66e5b2c 100644 --- a/flang/lib/Parser/prescan.cpp +++ b/flang/lib/Parser/prescan.cpp @@ -147,6 +147,11 @@ void Prescanner::Statement() { directiveSentinel_[4] == '\0') { // CUDA conditional compilation line. condOffset = 5; + } else if (directiveSentinel_[0] == '@' && directiveSentinel_[1] == 'a' && + directiveSentinel_[2] == 'c' && directiveSentinel_[3] == 'c' && + directiveSentinel_[4] == '\0') { + // OpenACC conditional compilation line. 
+ condOffset = 5; } if (condOffset && !preprocessingOnly_) { at_ += *condOffset, column_ += *condOffset; diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp index 1049a6d2..7b88100 100644 --- a/flang/lib/Semantics/check-declarations.cpp +++ b/flang/lib/Semantics/check-declarations.cpp @@ -1189,7 +1189,8 @@ void CheckHelper::CheckObjectEntity( } } else if (!subpDetails && symbol.owner().kind() != Scope::Kind::Module && symbol.owner().kind() != Scope::Kind::MainProgram && - symbol.owner().kind() != Scope::Kind::BlockConstruct) { + symbol.owner().kind() != Scope::Kind::BlockConstruct && + symbol.owner().kind() != Scope::Kind::OpenACCConstruct) { messages_.Say( "ATTRIBUTES(%s) may apply only to module, host subprogram, block, or device subprogram data"_err_en_US, parser::ToUpperCaseLetters(common::EnumToString(attr))); diff --git a/flang/lib/Semantics/check-directive-structure.h b/flang/lib/Semantics/check-directive-structure.h index b1bf3e5..bd78d3c 100644 --- a/flang/lib/Semantics/check-directive-structure.h +++ b/flang/lib/Semantics/check-directive-structure.h @@ -383,7 +383,8 @@ protected: const C &clause, const parser::ScalarIntConstantExpr &i); void RequiresPositiveParameter(const C &clause, - const parser::ScalarIntExpr &i, llvm::StringRef paramName = "parameter"); + const parser::ScalarIntExpr &i, llvm::StringRef paramName = "parameter", + bool allowZero = true); void OptionalConstantPositiveParameter( const C &clause, const std::optional<parser::ScalarIntConstantExpr> &o); @@ -657,9 +658,9 @@ void DirectiveStructureChecker<D, C, PC, ClauseEnumSize>::SayNotMatching( template <typename D, typename C, typename PC, std::size_t ClauseEnumSize> void DirectiveStructureChecker<D, C, PC, ClauseEnumSize>::RequiresPositiveParameter(const C &clause, - const parser::ScalarIntExpr &i, llvm::StringRef paramName) { + const parser::ScalarIntExpr &i, llvm::StringRef paramName, bool allowZero) { if (const auto v{GetIntValue(i)}) { 
- if (*v < 0) { + if (*v < (allowZero ? 0 : 1)) { context_.Say(GetContext().clauseSource, "The %s of the %s clause must be " "a positive integer expression"_err_en_US, diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp index e224e06..c0c41c1 100644 --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -1361,9 +1361,19 @@ void OmpStructureChecker::Enter(const parser::OpenMPDeclareSimdConstruct &x) { return; } + auto isValidSymbol{[](const Symbol *sym) { + if (IsProcedure(*sym) || IsFunction(*sym)) { + return true; + } + if (const Symbol *owner{GetScopingUnit(sym->owner()).symbol()}) { + return IsProcedure(*owner) || IsFunction(*owner); + } + return false; + }}; + const parser::OmpArgument &arg{args.v.front()}; if (auto *sym{GetArgumentSymbol(arg)}) { - if (!IsProcedure(*sym) && !IsFunction(*sym)) { + if (!isValidSymbol(sym)) { auto &msg{context_.Say(arg.source, "The name '%s' should refer to a procedure"_err_en_US, sym->name())}; if (sym->test(Symbol::Flag::Implicit)) { @@ -3135,6 +3145,13 @@ void OmpStructureChecker::Enter(const parser::OmpClause &x) { } } +void OmpStructureChecker::Enter(const parser::OmpClause::Sizes &c) { + CheckAllowedClause(llvm::omp::Clause::OMPC_sizes); + for (const parser::Cosubscript &v : c.v) + RequiresPositiveParameter(llvm::omp::Clause::OMPC_sizes, v, + /*paramName=*/"parameter", /*allowZero=*/false); +} + // Following clauses do not have a separate node in parse-tree.h. 
CHECK_SIMPLE_CLAUSE(Absent, OMPC_absent) CHECK_SIMPLE_CLAUSE(Affinity, OMPC_affinity) @@ -3176,7 +3193,6 @@ CHECK_SIMPLE_CLAUSE(Notinbranch, OMPC_notinbranch) CHECK_SIMPLE_CLAUSE(Partial, OMPC_partial) CHECK_SIMPLE_CLAUSE(ProcBind, OMPC_proc_bind) CHECK_SIMPLE_CLAUSE(Simd, OMPC_simd) -CHECK_SIMPLE_CLAUSE(Sizes, OMPC_sizes) CHECK_SIMPLE_CLAUSE(Permutation, OMPC_permutation) CHECK_SIMPLE_CLAUSE(Uniform, OMPC_uniform) CHECK_SIMPLE_CLAUSE(Unknown, OMPC_unknown) diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp index 35b7718..a8ec4d6 100644 --- a/flang/lib/Semantics/openmp-utils.cpp +++ b/flang/lib/Semantics/openmp-utils.cpp @@ -41,6 +41,24 @@ namespace Fortran::semantics::omp { using namespace Fortran::parser::omp; +const Scope &GetScopingUnit(const Scope &scope) { + const Scope *iter{&scope}; + for (; !iter->IsTopLevel(); iter = &iter->parent()) { + switch (iter->kind()) { + case Scope::Kind::BlockConstruct: + case Scope::Kind::BlockData: + case Scope::Kind::DerivedType: + case Scope::Kind::MainProgram: + case Scope::Kind::Module: + case Scope::Kind::Subprogram: + return *iter; + default: + break; + } + } + return *iter; +} + SourcedActionStmt GetActionStmt(const parser::ExecutionPartConstruct *x) { if (x == nullptr) { return SourcedActionStmt{}; diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index bd7b8ac..02fcf02 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -328,6 +328,11 @@ public: return false; } + bool Pre(const parser::AccClause::UseDevice &x) { + ResolveAccObjectList(x.v, Symbol::Flag::AccUseDevice); + return false; + } + void Post(const parser::Name &); private: @@ -379,24 +384,6 @@ public: explicit OmpAttributeVisitor(SemanticsContext &context) : DirectiveAttributeVisitor(context) {} - static const Scope &scopingUnit(const Scope &scope) { - const Scope *iter{&scope}; - for (; !iter->IsTopLevel(); iter = 
&iter->parent()) { - switch (iter->kind()) { - case Scope::Kind::BlockConstruct: - case Scope::Kind::BlockData: - case Scope::Kind::DerivedType: - case Scope::Kind::MainProgram: - case Scope::Kind::Module: - case Scope::Kind::Subprogram: - return *iter; - default: - break; - } - } - return *iter; - } - template <typename A> void Walk(const A &x) { parser::Walk(x, *this); } template <typename A> bool Pre(const A &) { return true; } template <typename A> void Post(const A &) {} @@ -2303,14 +2290,17 @@ void OmpAttributeVisitor::CheckPerfectNestAndRectangularLoop( } auto checkPerfectNest = [&, this]() { - auto blockSize = block.size(); - if (blockSize <= 1) + if (block.empty()) return; + auto last = block.end(); + --last; - if (parser::Unwrap<parser::ContinueStmt>(x)) - blockSize -= 1; + // A trailing CONTINUE is not considered part of the loop body + if (parser::Unwrap<parser::ContinueStmt>(*last)) + --last; - if (blockSize <= 1) + // In a perfectly nested loop, the nested loop must be the only statement + if (last == block.begin()) return; // Non-perfectly nested loop @@ -2431,10 +2421,18 @@ void OmpAttributeVisitor::PrivatizeAssociatedLoopIndexAndCheckLoopLevel( void OmpAttributeVisitor::CheckAssocLoopLevel( std::int64_t level, const parser::OmpClause *clause) { if (clause && level != 0) { - context_.Say(clause->source, - "The value of the parameter in the COLLAPSE or ORDERED clause must" - " not be larger than the number of nested loops" - " following the construct."_err_en_US); + switch (clause->Id()) { + case llvm::omp::OMPC_sizes: + context_.Say(clause->source, + "The SIZES clause has more entries than there are nested canonical loops."_err_en_US); + break; + default: + context_.Say(clause->source, + "The value of the parameter in the COLLAPSE or ORDERED clause must" + " not be larger than the number of nested loops" + " following the construct."_err_en_US); + break; + } } } @@ -3086,8 +3084,8 @@ void OmpAttributeVisitor::ResolveOmpDesignator( checkScope = 
ompFlag == Symbol::Flag::OmpExecutableAllocateDirective; } if (checkScope) { - if (scopingUnit(GetContext().scope) != - scopingUnit(symbol->GetUltimate().owner())) { + if (omp::GetScopingUnit(GetContext().scope) != + omp::GetScopingUnit(symbol->GetUltimate().owner())) { context_.Say(designator.source, // 2.15.3 "List items must be declared in the same scoping unit in which the %s directive appears"_err_en_US, parser::ToUpperCaseLetters( diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index d1150a9..5041a6a 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -1387,6 +1387,8 @@ private: // Create scopes for OpenACC constructs class AccVisitor : public virtual DeclarationVisitor { public: + explicit AccVisitor(SemanticsContext &context) : context_{context} {} + void AddAccSourceRange(const parser::CharBlock &); static bool NeedsScope(const parser::OpenACCBlockConstruct &); @@ -1395,6 +1397,7 @@ public: void Post(const parser::OpenACCBlockConstruct &); bool Pre(const parser::OpenACCCombinedConstruct &); void Post(const parser::OpenACCCombinedConstruct &); + bool Pre(const parser::AccClause::UseDevice &x); bool Pre(const parser::AccBeginBlockDirective &x) { AddAccSourceRange(x.source); return true; @@ -1430,6 +1433,11 @@ public: void Post(const parser::AccBeginLoopDirective &x) { messageHandler().set_currStmtSource(std::nullopt); } + + void CopySymbolWithDevice(const parser::Name *name); + +private: + SemanticsContext &context_; }; bool AccVisitor::NeedsScope(const parser::OpenACCBlockConstruct &x) { @@ -1459,6 +1467,60 @@ bool AccVisitor::Pre(const parser::OpenACCBlockConstruct &x) { return true; } +void AccVisitor::CopySymbolWithDevice(const parser::Name *name) { + // When CUDA Fortran is enabled together with OpenACC, new + // symbols are created for the one appearing in the use_device + // clause. These new symbols have the CUDA Fortran device + // attribute. 
+ if (context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA)) { + name->symbol = currScope().CopySymbol(*name->symbol); + if (auto *object{name->symbol->detailsIf<ObjectEntityDetails>()}) { + object->set_cudaDataAttr(common::CUDADataAttr::Device); + } + } +} + +bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) { + for (const auto &accObject : x.v.v) { + common::visit( + common::visitors{ + [&](const parser::Designator &designator) { + if (const auto *name{ + semantics::getDesignatorNameIfDataRef(designator)}) { + Symbol *prev{currScope().FindSymbol(name->source)}; + if (prev != name->symbol) { + name->symbol = prev; + } + CopySymbolWithDevice(name); + } else { + if (const auto *dataRef{ + std::get_if<parser::DataRef>(&designator.u)}) { + using ElementIndirection = + common::Indirection<parser::ArrayElement>; + if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) { + const parser::ArrayElement &arrayElement{ind->value()}; + Walk(arrayElement.subscripts); + const parser::DataRef &base{arrayElement.base}; + if (auto *name{std::get_if<parser::Name>(&base.u)}) { + Symbol *prev{currScope().FindSymbol(name->source)}; + if (prev != name->symbol) { + name->symbol = prev; + } + CopySymbolWithDevice(name); + } + } + } + } + }, + [&](const parser::Name &name) { + // TODO: common block in use_device? + }, + }, + accObject.u); + } + return false; +} + void AccVisitor::Post(const parser::OpenACCBlockConstruct &x) { if (NeedsScope(x)) { PopScope(); @@ -2038,7 +2100,8 @@ public: ResolveNamesVisitor( SemanticsContext &context, ImplicitRulesMap &rules, Scope &top) - : BaseVisitor{context, *this, rules}, topScope_{top} { + : BaseVisitor{context, *this, rules}, AccVisitor(context), + topScope_{top} { PushScope(top); } |