aboutsummaryrefslogtreecommitdiff
path: root/flang/lib
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib')
-rw-r--r--flang/lib/Frontend/CompilerInvocation.cpp3
-rw-r--r--flang/lib/Frontend/FrontendActions.cpp8
-rw-r--r--flang/lib/Lower/Bridge.cpp2
-rw-r--r--flang/lib/Lower/OpenACC.cpp34
-rw-r--r--flang/lib/Lower/OpenMP/ClauseProcessor.cpp13
-rw-r--r--flang/lib/Lower/OpenMP/ClauseProcessor.h2
-rw-r--r--flang/lib/Lower/OpenMP/OpenMP.cpp360
-rw-r--r--flang/lib/Lower/OpenMP/Utils.cpp24
-rw-r--r--flang/lib/Lower/OpenMP/Utils.h7
-rw-r--r--flang/lib/Lower/SymbolMap.cpp10
-rw-r--r--flang/lib/Optimizer/Builder/FIRBuilder.cpp2
-rw-r--r--flang/lib/Optimizer/Builder/HLFIRTools.cpp81
-rw-r--r--flang/lib/Optimizer/Builder/IntrinsicCall.cpp38
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp207
-rw-r--r--flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp294
-rw-r--r--flang/lib/Optimizer/Transforms/AddDebugInfo.cpp9
-rw-r--r--flang/lib/Parser/parsing.cpp1
-rw-r--r--flang/lib/Parser/prescan.cpp5
-rw-r--r--flang/lib/Semantics/check-declarations.cpp3
-rw-r--r--flang/lib/Semantics/check-directive-structure.h7
-rw-r--r--flang/lib/Semantics/check-omp-structure.cpp20
-rw-r--r--flang/lib/Semantics/openmp-utils.cpp18
-rw-r--r--flang/lib/Semantics/resolve-directives.cpp56
-rw-r--r--flang/lib/Semantics/resolve-names.cpp65
24 files changed, 996 insertions, 273 deletions
diff --git a/flang/lib/Frontend/CompilerInvocation.cpp b/flang/lib/Frontend/CompilerInvocation.cpp
index 81610ed..548ca67 100644
--- a/flang/lib/Frontend/CompilerInvocation.cpp
+++ b/flang/lib/Frontend/CompilerInvocation.cpp
@@ -1425,6 +1425,9 @@ static bool parseFloatingPointArgs(CompilerInvocation &invoc,
opts.setFPContractMode(Fortran::common::LangOptions::FPM_Fast);
}
+ if (args.hasArg(clang::driver::options::OPT_fno_fast_real_mod))
+ opts.NoFastRealMod = true;
+
return true;
}
diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp
index d5e0325..0c630d2 100644
--- a/flang/lib/Frontend/FrontendActions.cpp
+++ b/flang/lib/Frontend/FrontendActions.cpp
@@ -277,6 +277,14 @@ bool CodeGenAction::beginSourceFileAction() {
ci.getInvocation().getLangOpts().OpenMPVersion);
}
+ if (ci.getInvocation().getLangOpts().NoFastRealMod) {
+ mlir::ModuleOp mod = lb.getModule();
+ mod.getOperation()->setAttr(
+ mlir::StringAttr::get(mod.getContext(),
+ llvm::Twine{"fir.no_fast_real_mod"}),
+ mlir::BoolAttr::get(mod.getContext(), true));
+ }
+
// Create a parse tree and lower it to FIR
parseAndLowerTree(ci, lb);
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 149e51b..780d56f 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3182,7 +3182,7 @@ private:
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
localSymbols.pushScope();
mlir::Value exitCond = genOpenACCConstruct(
- *this, bridge.getSemanticsContext(), getEval(), acc);
+ *this, bridge.getSemanticsContext(), getEval(), acc, localSymbols);
const Fortran::parser::OpenACCLoopConstruct *accLoop =
std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 95d0ada..f9b9b850 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3184,7 +3184,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::AccClauseList &accClauseList) {
+ const Fortran::parser::AccClauseList &accClauseList,
+ Fortran::lower::SymMap &localSymbols) {
mlir::Value ifCond;
llvm::SmallVector<mlir::Value> dataOperands;
bool addIfPresentAttr = false;
@@ -3199,6 +3200,19 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *useDevice =
std::get_if<Fortran::parser::AccClause::UseDevice>(
&clause.u)) {
+ // When CUDA Fotran is enabled, extra symbols are used in the host_data
+ // region. Look for them and bind their values with the symbols in the
+ // outer scope.
+ if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
+ const Fortran::parser::AccObjectList &objectList{useDevice->v};
+ for (const auto &accObject : objectList.v) {
+ Fortran::semantics::Symbol &symbol =
+ getSymbolFromAccObject(accObject);
+ const Fortran::semantics::Symbol *baseSym =
+ localSymbols.lookupSymbolByName(symbol.name().ToString());
+ localSymbols.copySymbolBinding(*baseSym, symbol);
+ }
+ }
genDataOperandOperations<mlir::acc::UseDeviceOp>(
useDevice->v, converter, semanticsContext, stmtCtx, dataOperands,
mlir::acc::DataClause::acc_use_device,
@@ -3239,11 +3253,11 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
hostDataOp.setIfPresentAttr(builder.getUnitAttr());
}
-static void
-genACC(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semanticsContext,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
+static void genACC(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semanticsContext,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenACCBlockConstruct &blockConstruct,
+ Fortran::lower::SymMap &localSymbols) {
const auto &beginBlockDirective =
std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
const auto &blockDirective =
@@ -3273,7 +3287,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
accClauseList);
} else if (blockDirective.v == llvm::acc::ACCD_host_data) {
genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
- stmtCtx, accClauseList);
+ stmtCtx, accClauseList, localSymbols);
}
}
@@ -4647,13 +4661,15 @@ mlir::Value Fortran::lower::genOpenACCConstruct(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenACCConstruct &accConstruct) {
+ const Fortran::parser::OpenACCConstruct &accConstruct,
+ Fortran::lower::SymMap &localSymbols) {
mlir::Value exitCond;
Fortran::common::visit(
common::visitors{
[&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
- genACC(converter, semanticsContext, eval, blockConstruct);
+ genACC(converter, semanticsContext, eval, blockConstruct,
+ localSymbols);
},
[&](const Fortran::parser::OpenACCCombinedConstruct
&combinedConstruct) {
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index a96884f..55eda7e 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -431,6 +431,19 @@ bool ClauseProcessor::processNumTasks(
return false;
}
+bool ClauseProcessor::processSizes(StatementContext &stmtCtx,
+ mlir::omp::SizesClauseOps &result) const {
+ if (auto *clause = findUniqueClause<omp::clause::Sizes>()) {
+ result.sizes.reserve(clause->v.size());
+ for (const ExprTy &vv : clause->v)
+ result.sizes.push_back(fir::getBase(converter.genExprValue(vv, stmtCtx)));
+
+ return true;
+ }
+
+ return false;
+}
+
bool ClauseProcessor::processNumTeams(
lower::StatementContext &stmtCtx,
mlir::omp::NumTeamsClauseOps &result) const {
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 324ea3c..9e352fa 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -66,6 +66,8 @@ public:
mlir::omp::LoopRelatedClauseOps &loopResult,
mlir::omp::CollapseClauseOps &collapseResult,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
+ bool processSizes(StatementContext &stmtCtx,
+ mlir::omp::SizesClauseOps &result) const;
bool processDevice(lower::StatementContext &stmtCtx,
mlir::omp::DeviceClauseOps &result) const;
bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 1cb3335..9e56c2b 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1982,125 +1982,241 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return loopOp;
}
-static mlir::omp::CanonicalLoopOp
-genCanonicalLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
- semantics::SemanticsContext &semaCtx,
- lower::pft::Evaluation &eval, mlir::Location loc,
- const ConstructQueue &queue,
- ConstructQueue::const_iterator item,
- llvm::ArrayRef<const semantics::Symbol *> ivs,
- llvm::omp::Directive directive) {
+static void genCanonicalLoopNest(
+ lower::AbstractConverter &converter, lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+ mlir::Location loc, const ConstructQueue &queue,
+ ConstructQueue::const_iterator item, size_t numLoops,
+ llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) {
+ assert(loops.empty() && "Expecting empty list to fill");
+ assert(numLoops >= 1 && "Expecting at least one loop");
+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- assert(ivs.size() == 1 && "Nested loops not yet implemented");
- const semantics::Symbol *iv = ivs[0];
+ mlir::omp::LoopRelatedClauseOps loopInfo;
+ llvm::SmallVector<const semantics::Symbol *, 3> ivs;
+ collectLoopRelatedInfo(converter, loc, eval, numLoops, loopInfo, ivs);
+ assert(ivs.size() == numLoops &&
+ "Expected to parse as many loop variables as there are loops");
+
+ // Steps that follow:
+ // 1. Emit all of the loop's prologues (compute the tripcount)
+ // 2. Emit omp.canonical_loop nested inside each other (iteratively)
+ // 2.1. In the innermost omp.canonical_loop, emit the loop body prologue (in
+ // the body callback)
+ //
+ // Since emitting prologues and body code is split, remember prologue values
+ // for use when emitting the same loop's epilogues.
+ llvm::SmallVector<mlir::Value> tripcounts;
+ llvm::SmallVector<mlir::Value> clis;
+ llvm::SmallVector<lower::pft::Evaluation *> evals;
+ llvm::SmallVector<mlir::Type> loopVarTypes;
+ llvm::SmallVector<mlir::Value> loopStepVars;
+ llvm::SmallVector<mlir::Value> loopLBVars;
+ llvm::SmallVector<mlir::Value> blockArgs;
+
+ // Step 1: Loop prologues
+ // Computing the trip count must happen before entering the outermost loop
+ lower::pft::Evaluation *innermostEval = &eval.getFirstNestedEvaluation();
+ for ([[maybe_unused]] auto iv : ivs) {
+ if (innermostEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
+ // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct.
+ // Will need to add special cases for this combination.
+ TODO(loc, "DO CONCURRENT as canonical loop not supported");
+ }
+
+ auto &doLoopEval = innermostEval->getFirstNestedEvaluation();
+ evals.push_back(innermostEval);
+
+ // Get the loop bounds (and increment)
+ // auto &doLoopEval = nestedEval.getFirstNestedEvaluation();
+ auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>();
+ assert(doStmt && "Expected do loop to be in the nested evaluation");
+ auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t);
+ assert(loopControl.has_value());
+ auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
+ assert(bounds && "Expected bounds for canonical loop");
+ lower::StatementContext stmtCtx;
+ mlir::Value loopLBVar = fir::getBase(
+ converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx));
+ mlir::Value loopUBVar = fir::getBase(
+ converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx));
+ mlir::Value loopStepVar = [&]() {
+ if (bounds->step) {
+ return fir::getBase(
+ converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx));
+ }
- auto &nestedEval = eval.getFirstNestedEvaluation();
- if (nestedEval.getIf<parser::DoConstruct>()->IsDoConcurrent()) {
- // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct. Will
- // need to add special cases for this combination.
- TODO(loc, "DO CONCURRENT as canonical loop not supported");
+ // If `step` is not present, assume it is `1`.
+ auto intTy = firOpBuilder.getI32Type();
+ return firOpBuilder.createIntegerConstant(loc, intTy, 1);
+ }();
+
+ // Get the integer kind for the loop variable and cast the loop bounds
+ size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size();
+ mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
+ loopVarTypes.push_back(loopVarType);
+ loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar);
+ loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar);
+ loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar);
+ loopLBVars.push_back(loopLBVar);
+ loopStepVars.push_back(loopStepVar);
+
+ // Start lowering
+ mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0);
+ mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1);
+ mlir::Value isDownwards = firOpBuilder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero);
+
+ // Ensure we are counting upwards. If not, negate step and swap lb and ub.
+ mlir::Value negStep =
+ firOpBuilder.create<mlir::arith::SubIOp>(loc, zero, loopStepVar);
+ mlir::Value incr = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, isDownwards, negStep, loopStepVar);
+ mlir::Value lb = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, isDownwards, loopUBVar, loopLBVar);
+ mlir::Value ub = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, isDownwards, loopLBVar, loopUBVar);
+
+ // Compute the trip count assuming lb <= ub. This guarantees that the result
+ // is non-negative and we can use unsigned arithmetic.
+ mlir::Value span = firOpBuilder.create<mlir::arith::SubIOp>(
+ loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw);
+ mlir::Value tcMinusOne =
+ firOpBuilder.create<mlir::arith::DivUIOp>(loc, span, incr);
+ mlir::Value tcIfLooping = firOpBuilder.create<mlir::arith::AddIOp>(
+ loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw);
+
+ // Fall back to 0 if lb > ub
+ mlir::Value isZeroTC = firOpBuilder.create<mlir::arith::CmpIOp>(
+ loc, mlir::arith::CmpIPredicate::slt, ub, lb);
+ mlir::Value tripcount = firOpBuilder.create<mlir::arith::SelectOp>(
+ loc, isZeroTC, zero, tcIfLooping);
+ tripcounts.push_back(tripcount);
+
+ // Create the CLI handle.
+ auto newcli = firOpBuilder.create<mlir::omp::NewCliOp>(loc);
+ mlir::Value cli = newcli.getResult();
+ clis.push_back(cli);
+
+ innermostEval = &*std::next(innermostEval->getNestedEvaluations().begin());
}
- // Get the loop bounds (and increment)
- auto &doLoopEval = nestedEval.getFirstNestedEvaluation();
- auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>();
- assert(doStmt && "Expected do loop to be in the nested evaluation");
- auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t);
- assert(loopControl.has_value());
- auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
- assert(bounds && "Expected bounds for canonical loop");
- lower::StatementContext stmtCtx;
- mlir::Value loopLBVar = fir::getBase(
- converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx));
- mlir::Value loopUBVar = fir::getBase(
- converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx));
- mlir::Value loopStepVar = [&]() {
- if (bounds->step) {
- return fir::getBase(
- converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx));
- }
+ // Step 2: Create nested canoncial loops
+ for (auto i : llvm::seq<size_t>(numLoops)) {
+ bool isInnermost = (i == numLoops - 1);
+ mlir::Type loopVarType = loopVarTypes[i];
+ mlir::Value tripcount = tripcounts[i];
+ mlir::Value cli = clis[i];
+ auto &&eval = evals[i];
+
+ auto ivCallback = [&, i, isInnermost](mlir::Operation *op)
+ -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
+ mlir::Region &region = op->getRegion(0);
+
+ // Create the op's region skeleton (BB taking the iv as argument)
+ firOpBuilder.createBlock(&region, {}, {loopVarType}, {loc});
+ blockArgs.push_back(region.front().getArgument(0));
+
+ // Step 2.1: Emit body prologue code
+ // Compute the translation from logical iteration number to the value of
+ // the loop's iteration variable only in the innermost body. Currently,
+ // loop transformations do not allow any instruction between loops, but
+ // this will change with
+ if (isInnermost) {
+ assert(blockArgs.size() == numLoops &&
+ "Expecting all block args to have been collected by now");
+ for (auto j : llvm::seq<size_t>(numLoops)) {
+ mlir::Value natIterNum = fir::getBase(blockArgs[j]);
+ mlir::Value scaled = firOpBuilder.create<mlir::arith::MulIOp>(
+ loc, natIterNum, loopStepVars[j]);
+ mlir::Value userVal = firOpBuilder.create<mlir::arith::AddIOp>(
+ loc, loopLBVars[j], scaled);
+
+ mlir::OpBuilder::InsertPoint insPt =
+ firOpBuilder.saveInsertionPoint();
+ firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
+ mlir::Type tempTy = converter.genType(*ivs[j]);
+ firOpBuilder.restoreInsertionPoint(insPt);
+
+ // Write the loop value into loop variable
+ mlir::Value cvtVal = firOpBuilder.createConvert(loc, tempTy, userVal);
+ hlfir::Entity lhs{converter.getSymbolAddress(*ivs[j])};
+ lhs = hlfir::derefPointersAndAllocatables(loc, firOpBuilder, lhs);
+ mlir::Operation *storeOp =
+ hlfir::AssignOp::create(firOpBuilder, loc, cvtVal, lhs);
+ firOpBuilder.setInsertionPointAfter(storeOp);
+ }
+ }
- // If `step` is not present, assume it is `1`.
- return firOpBuilder.createIntegerConstant(loc, firOpBuilder.getI32Type(),
- 1);
- }();
+ return {ivs[i]};
+ };
- // Get the integer kind for the loop variable and cast the loop bounds
- size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size();
- mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
- loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar);
- loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar);
- loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar);
-
- // Start lowering
- mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0);
- mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1);
- mlir::Value isDownwards = mlir::arith::CmpIOp::create(
- firOpBuilder, loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero);
-
- // Ensure we are counting upwards. If not, negate step and swap lb and ub.
- mlir::Value negStep =
- mlir::arith::SubIOp::create(firOpBuilder, loc, zero, loopStepVar);
- mlir::Value incr = mlir::arith::SelectOp::create(
- firOpBuilder, loc, isDownwards, negStep, loopStepVar);
- mlir::Value lb = mlir::arith::SelectOp::create(firOpBuilder, loc, isDownwards,
- loopUBVar, loopLBVar);
- mlir::Value ub = mlir::arith::SelectOp::create(firOpBuilder, loc, isDownwards,
- loopLBVar, loopUBVar);
-
- // Compute the trip count assuming lb <= ub. This guarantees that the result
- // is non-negative and we can use unsigned arithmetic.
- mlir::Value span = mlir::arith::SubIOp::create(
- firOpBuilder, loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw);
- mlir::Value tcMinusOne =
- mlir::arith::DivUIOp::create(firOpBuilder, loc, span, incr);
- mlir::Value tcIfLooping =
- mlir::arith::AddIOp::create(firOpBuilder, loc, tcMinusOne, one,
- ::mlir::arith::IntegerOverflowFlags::nuw);
-
- // Fall back to 0 if lb > ub
- mlir::Value isZeroTC = mlir::arith::CmpIOp::create(
- firOpBuilder, loc, mlir::arith::CmpIPredicate::slt, ub, lb);
- mlir::Value tripcount = mlir::arith::SelectOp::create(
- firOpBuilder, loc, isZeroTC, zero, tcIfLooping);
-
- // Create the CLI handle.
- auto newcli = mlir::omp::NewCliOp::create(firOpBuilder, loc);
- mlir::Value cli = newcli.getResult();
-
- auto ivCallback = [&](mlir::Operation *op)
- -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
- mlir::Region &region = op->getRegion(0);
-
- // Create the op's region skeleton (BB taking the iv as argument)
- firOpBuilder.createBlock(&region, {}, {loopVarType}, {loc});
-
- // Compute the value of the loop variable from the logical iteration number.
- mlir::Value natIterNum = fir::getBase(region.front().getArgument(0));
- mlir::Value scaled =
- mlir::arith::MulIOp::create(firOpBuilder, loc, natIterNum, loopStepVar);
- mlir::Value userVal =
- mlir::arith::AddIOp::create(firOpBuilder, loc, loopLBVar, scaled);
-
- // Write loop value to loop variable
- mlir::Operation *storeOp = setLoopVar(converter, loc, userVal, iv);
-
- firOpBuilder.setInsertionPointAfter(storeOp);
- return {iv};
- };
+ // Create the omp.canonical_loop operation
+ auto opGenInfo = OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *eval,
+ llvm::omp::Directive::OMPD_unknown)
+ .setGenSkeletonOnly(!isInnermost)
+ .setClauses(&item->clauses)
+ .setPrivatize(false)
+ .setGenRegionEntryCb(ivCallback);
+ auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>(
+ std::move(opGenInfo), queue, item, tripcount, cli);
+ loops.push_back(canonLoop);
+
+ // Insert next loop nested inside last loop
+ firOpBuilder.setInsertionPoint(
+ canonLoop.getRegion().back().getTerminator());
+ }
- // Create the omp.canonical_loop operation
- auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>(
- OpWithBodyGenInfo(converter, symTable, semaCtx, loc, nestedEval,
- directive)
- .setClauses(&item->clauses)
- .setPrivatize(false)
- .setGenRegionEntryCb(ivCallback),
- queue, item, tripcount, cli);
+ firOpBuilder.setInsertionPointAfter(loops.front());
+}
+
+static void genTileOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::SymMap &symTable,
+ lower::StatementContext &stmtCtx,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const ConstructQueue &queue,
+ ConstructQueue::const_iterator item) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- firOpBuilder.setInsertionPointAfter(canonLoop);
- return canonLoop;
+ mlir::omp::SizesClauseOps sizesClause;
+ ClauseProcessor cp(converter, semaCtx, item->clauses);
+ cp.processSizes(stmtCtx, sizesClause);
+
+ size_t numLoops = sizesClause.sizes.size();
+ llvm::SmallVector<mlir::omp::CanonicalLoopOp, 3> canonLoops;
+ canonLoops.reserve(numLoops);
+
+ genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item,
+ numLoops, canonLoops);
+ assert((canonLoops.size() == numLoops) &&
+ "Expecting the predetermined number of loops");
+
+ llvm::SmallVector<mlir::Value, 3> applyees;
+ applyees.reserve(numLoops);
+ for (mlir::omp::CanonicalLoopOp l : canonLoops)
+ applyees.push_back(l.getCli());
+
+ // Emit the associated loops and create a CLI for each affected loop
+ llvm::SmallVector<mlir::Value, 3> gridGeneratees;
+ llvm::SmallVector<mlir::Value, 3> intratileGeneratees;
+ gridGeneratees.reserve(numLoops);
+ intratileGeneratees.reserve(numLoops);
+ for ([[maybe_unused]] auto i : llvm::seq<int>(0, sizesClause.sizes.size())) {
+ auto gridCLI = firOpBuilder.create<mlir::omp::NewCliOp>(loc);
+ gridGeneratees.push_back(gridCLI.getResult());
+ auto intratileCLI = firOpBuilder.create<mlir::omp::NewCliOp>(loc);
+ intratileGeneratees.push_back(intratileCLI.getResult());
+ }
+
+ llvm::SmallVector<mlir::Value, 6> generatees;
+ generatees.reserve(2 * numLoops);
+ generatees.append(gridGeneratees);
+ generatees.append(intratileGeneratees);
+
+ firOpBuilder.create<mlir::omp::TileOp>(loc, generatees, applyees,
+ sizesClause.sizes);
}
static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
@@ -2112,22 +2228,22 @@ static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
ConstructQueue::const_iterator item) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- mlir::omp::LoopRelatedClauseOps loopInfo;
- llvm::SmallVector<const semantics::Symbol *> iv;
- collectLoopRelatedInfo(converter, loc, eval, item->clauses, loopInfo, iv);
-
// Clauses for unrolling not yet implemnted
ClauseProcessor cp(converter, semaCtx, item->clauses);
cp.processTODO<clause::Partial, clause::Full>(
loc, llvm::omp::Directive::OMPD_unroll);
// Emit the associated loop
- auto canonLoop =
- genCanonicalLoopOp(converter, symTable, semaCtx, eval, loc, queue, item,
- iv, llvm::omp::Directive::OMPD_unroll);
+ llvm::SmallVector<mlir::omp::CanonicalLoopOp, 1> canonLoops;
+ genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item, 1,
+ canonLoops);
+
+ llvm::SmallVector<mlir::Value, 1> applyees;
+ for (auto &&canonLoop : canonLoops)
+ applyees.push_back(canonLoop.getCli());
// Apply unrolling to it
- auto cli = canonLoop.getCli();
+ auto cli = llvm::getSingleElement(canonLoops).getCli();
mlir::omp::UnrollHeuristicOp::create(firOpBuilder, loc, cli);
}
@@ -3360,13 +3476,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
item);
break;
- case llvm::omp::Directive::OMPD_tile: {
- unsigned version = semaCtx.langOptions().OpenMPVersion;
- if (!semaCtx.langOptions().OpenMPSimd)
- TODO(loc, "Unhandled loop directive (" +
- llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
+ case llvm::omp::Directive::OMPD_tile:
+ genTileOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
break;
- }
case llvm::omp::Directive::OMPD_unroll:
genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
break;
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 83b7ccb..29cccbd 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -652,7 +652,6 @@ int64_t collectLoopRelatedInfo(
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
int64_t numCollapse = 1;
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
// Collect the loops to collapse.
lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
@@ -667,6 +666,25 @@ int64_t collectLoopRelatedInfo(
numCollapse = collapseValue;
}
+ collectLoopRelatedInfo(converter, currentLocation, eval, numCollapse, result,
+ iv);
+ return numCollapse;
+}
+
+void collectLoopRelatedInfo(
+ lower::AbstractConverter &converter, mlir::Location currentLocation,
+ lower::pft::Evaluation &eval, int64_t numCollapse,
+ mlir::omp::LoopRelatedClauseOps &result,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
+
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+ // Collect the loops to collapse.
+ lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
+ if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
+ TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
+ }
+
// Collect sizes from tile directive if present.
std::int64_t sizesLengthValue = 0l;
if (auto *ompCons{eval.getIf<parser::OpenMPConstruct>()}) {
@@ -676,7 +694,7 @@ int64_t collectLoopRelatedInfo(
});
}
- collapseValue = std::max(collapseValue, sizesLengthValue);
+ std::int64_t collapseValue = std::max(numCollapse, sizesLengthValue);
std::size_t loopVarTypeSize = 0;
do {
lower::pft::Evaluation *doLoop =
@@ -709,8 +727,6 @@ int64_t collectLoopRelatedInfo(
} while (collapseValue > 0);
convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
-
- return numCollapse;
}
} // namespace omp
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 5f191d8..69499f9 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -165,6 +165,13 @@ int64_t collectLoopRelatedInfo(
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
+void collectLoopRelatedInfo(
+ lower::AbstractConverter &converter, mlir::Location currentLocation,
+ lower::pft::Evaluation &eval, std::int64_t collapseValue,
+ // const omp::List<omp::Clause> &clauses,
+ mlir::omp::LoopRelatedClauseOps &result,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
+
void collectTileSizesFromOpenMPConstruct(
const parser::OpenMPConstruct *ompCons,
llvm::SmallVectorImpl<int64_t> &tileSizes,
diff --git a/flang/lib/Lower/SymbolMap.cpp b/flang/lib/Lower/SymbolMap.cpp
index 080f21e..78529e0 100644
--- a/flang/lib/Lower/SymbolMap.cpp
+++ b/flang/lib/Lower/SymbolMap.cpp
@@ -45,6 +45,16 @@ Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) {
return SymbolBox::None{};
}
+const Fortran::semantics::Symbol *
+Fortran::lower::SymMap::lookupSymbolByName(llvm::StringRef symName) {
+ for (auto jmap = symbolMapStack.rbegin(), jend = symbolMapStack.rend();
+ jmap != jend; ++jmap)
+ for (auto const &[sym, symBox] : *jmap)
+ if (sym->name().ToString() == symName)
+ return sym;
+ return nullptr;
+}
+
Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol(
Fortran::semantics::SymbolRef symRef) {
auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 5e6e208..5da27d1 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -1943,7 +1943,7 @@ void fir::factory::genDimInfoFromBox(
return;
unsigned rank = fir::getBoxRank(boxType);
- assert(rank != 0 && "must be an array of known rank");
+ assert(!boxType.isAssumedRank() && "must be an array of known rank");
mlir::Type idxTy = builder.getIndexType();
for (unsigned i = 0; i < rank; ++i) {
mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index f93eaf7..dbfcae1 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -676,6 +676,34 @@ mlir::Value hlfir::genLBound(mlir::Location loc, fir::FirOpBuilder &builder,
return dimInfo.getLowerBound();
}
+static bool
+getExprLengthParameters(mlir::Value expr,
+ llvm::SmallVectorImpl<mlir::Value> &result) {
+ if (auto concat = expr.getDefiningOp<hlfir::ConcatOp>()) {
+ result.push_back(concat.getLength());
+ return true;
+ }
+ if (auto setLen = expr.getDefiningOp<hlfir::SetLengthOp>()) {
+ result.push_back(setLen.getLength());
+ return true;
+ }
+ if (auto elemental = expr.getDefiningOp<hlfir::ElementalOp>()) {
+ result.append(elemental.getTypeparams().begin(),
+ elemental.getTypeparams().end());
+ return true;
+ }
+ if (auto evalInMem = expr.getDefiningOp<hlfir::EvaluateInMemoryOp>()) {
+ result.append(evalInMem.getTypeparams().begin(),
+ evalInMem.getTypeparams().end());
+ return true;
+ }
+ if (auto apply = expr.getDefiningOp<hlfir::ApplyOp>()) {
+ result.append(apply.getTypeparams().begin(), apply.getTypeparams().end());
+ return true;
+ }
+ return false;
+}
+
void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
Entity entity,
llvm::SmallVectorImpl<mlir::Value> &result) {
@@ -688,29 +716,14 @@ void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
// Going through fir::ExtendedValue would create a temp,
// which is not desired for an inquiry.
// TODO: make this an interface when adding further character producing ops.
- if (auto concat = expr.getDefiningOp<hlfir::ConcatOp>()) {
- result.push_back(concat.getLength());
- return;
- } else if (auto concat = expr.getDefiningOp<hlfir::SetLengthOp>()) {
- result.push_back(concat.getLength());
- return;
- } else if (auto asExpr = expr.getDefiningOp<hlfir::AsExprOp>()) {
+
+ if (auto asExpr = expr.getDefiningOp<hlfir::AsExprOp>()) {
hlfir::genLengthParameters(loc, builder, hlfir::Entity{asExpr.getVar()},
result);
return;
- } else if (auto elemental = expr.getDefiningOp<hlfir::ElementalOp>()) {
- result.append(elemental.getTypeparams().begin(),
- elemental.getTypeparams().end());
- return;
- } else if (auto evalInMem =
- expr.getDefiningOp<hlfir::EvaluateInMemoryOp>()) {
- result.append(evalInMem.getTypeparams().begin(),
- evalInMem.getTypeparams().end());
- return;
- } else if (auto apply = expr.getDefiningOp<hlfir::ApplyOp>()) {
- result.append(apply.getTypeparams().begin(), apply.getTypeparams().end());
- return;
}
+ if (getExprLengthParameters(expr, result))
+ return;
if (entity.isCharacter()) {
result.push_back(hlfir::GetLengthOp::create(builder, loc, expr));
return;
@@ -733,6 +746,36 @@ mlir::Value hlfir::genCharLength(mlir::Location loc, fir::FirOpBuilder &builder,
return lenParams[0];
}
+std::optional<std::int64_t> hlfir::getCharLengthIfConst(hlfir::Entity entity) {
+ if (!entity.isCharacter()) {
+ return std::nullopt;
+ }
+ if (mlir::isa<hlfir::ExprType>(entity.getType())) {
+ mlir::Value expr = entity;
+ if (auto reassoc = expr.getDefiningOp<hlfir::NoReassocOp>())
+ expr = reassoc.getVal();
+
+ if (auto asExpr = expr.getDefiningOp<hlfir::AsExprOp>())
+ return getCharLengthIfConst(hlfir::Entity{asExpr.getVar()});
+
+ llvm::SmallVector<mlir::Value> param;
+ if (getExprLengthParameters(expr, param)) {
+ assert(param.size() == 1 && "characters must have one length parameters");
+ return fir::getIntIfConstant(param.pop_back_val());
+ }
+ return std::nullopt;
+ }
+
+ // entity is a var
+ if (mlir::Value len = tryGettingNonDeferredCharLen(entity))
+ return fir::getIntIfConstant(len);
+ auto charType =
+ mlir::cast<fir::CharacterType>(entity.getFortranElementType());
+ if (charType.hasConstantLen())
+ return charType.getLen();
+ return std::nullopt;
+}
+
mlir::Value hlfir::genRank(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity entity, mlir::Type resultType) {
if (!entity.isAssumedRank())
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 71d35e3..de7694f 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -6989,8 +6989,33 @@ mlir::Value IntrinsicLibrary::genMergeBits(mlir::Type resultType,
}
// MOD
+static mlir::Value genFastMod(fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::Value a, mlir::Value p) {
+ auto fastmathFlags = mlir::arith::FastMathFlags::contract;
+ auto fastmathAttr =
+ mlir::arith::FastMathFlagsAttr::get(builder.getContext(), fastmathFlags);
+ mlir::Value divResult =
+ mlir::arith::DivFOp::create(builder, loc, a, p, fastmathAttr);
+ mlir::Type intType = builder.getIntegerType(
+ a.getType().getIntOrFloatBitWidth(), /*signed=*/true);
+ mlir::Value intResult = builder.createConvert(loc, intType, divResult);
+ mlir::Value cnvResult = builder.createConvert(loc, a.getType(), intResult);
+ mlir::Value mulResult =
+ mlir::arith::MulFOp::create(builder, loc, cnvResult, p, fastmathAttr);
+ mlir::Value subResult =
+ mlir::arith::SubFOp::create(builder, loc, a, mulResult, fastmathAttr);
+ return subResult;
+}
+
mlir::Value IntrinsicLibrary::genMod(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
+ auto mod = builder.getModule();
+ bool dontUseFastRealMod = false;
+ bool canUseApprox = mlir::arith::bitEnumContainsAny(
+ builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn);
+ if (auto attr = mod->getAttrOfType<mlir::BoolAttr>("fir.no_fast_real_mod"))
+ dontUseFastRealMod = attr.getValue();
+
assert(args.size() == 2);
if (resultType.isUnsignedInteger()) {
mlir::Type signlessType = mlir::IntegerType::get(
@@ -7002,9 +7027,16 @@ mlir::Value IntrinsicLibrary::genMod(mlir::Type resultType,
if (mlir::isa<mlir::IntegerType>(resultType))
return mlir::arith::RemSIOp::create(builder, loc, args[0], args[1]);
- // Use runtime.
- return builder.createConvert(
- loc, resultType, fir::runtime::genMod(builder, loc, args[0], args[1]));
+ if (resultType.isFloat() && canUseApprox && !dontUseFastRealMod) {
+ // Treat MOD as an approximate function and code-gen inline code
+ // instead of calling into the Fortran runtime library.
+ return builder.createConvert(loc, resultType,
+ genFastMod(builder, loc, args[0], args[1]));
+ } else {
+ // Use runtime.
+ return builder.createConvert(
+ loc, resultType, fir::runtime::genMod(builder, loc, args[0], args[1]));
+ }
}
// MODULO
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index d8e36ea..ce8ebaa 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -2284,6 +2284,212 @@ public:
}
};
+static std::pair<mlir::Value, hlfir::AssociateOp>
+getVariable(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value val) {
+ // If it is an expression - create a variable from it, or forward
+ // the value otherwise.
+ hlfir::AssociateOp associate;
+ if (!mlir::isa<hlfir::ExprType>(val.getType()))
+ return {val, associate};
+ hlfir::Entity entity{val};
+ mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr(builder);
+ associate = hlfir::genAssociateExpr(loc, builder, entity, entity.getType(),
+ "", byRefAttr);
+ return {associate.getBase(), associate};
+}
+
+class IndexOpConversion : public mlir::OpRewritePattern<hlfir::IndexOp> {
+public:
+ using mlir::OpRewritePattern<hlfir::IndexOp>::OpRewritePattern;
+
+ llvm::LogicalResult
+ matchAndRewrite(hlfir::IndexOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ // We simplify only limited cases:
+ // 1) a substring length shall be known at compile time
+ // 2) if a substring length is 0 then replace with 1 for forward search,
+ // or otherwise with the string length + 1 (builder shall const-fold if
+ // lookup direction is known at compile time).
+ // 3) for known string length at compile time, if it is
+ // shorter than substring => replace with zero.
+ // 4) if a substring length is one => inline as simple search loop
+ // 5) for forward search with input strings of kind=1 runtime is faster.
+ // Do not simplify in all the other cases relying on a runtime call.
+
+ fir::FirOpBuilder builder{rewriter, op.getOperation()};
+ const mlir::Location &loc = op->getLoc();
+
+ auto resultTy = op.getType();
+ mlir::Value back = op.getBack();
+ auto substrLenCst =
+ hlfir::getCharLengthIfConst(hlfir::Entity{op.getSubstr()});
+ if (!substrLenCst) {
+ return rewriter.notifyMatchFailure(
+ op, "substring length unknown at compile time");
+ }
+ hlfir::Entity strEntity{op.getStr()};
+ auto i1Ty = builder.getI1Type();
+ auto idxTy = builder.getIndexType();
+ if (*substrLenCst == 0) {
+ mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
+ // zero length substring. For back search replace with
+ // strLen+1, or otherwise with 1.
+ mlir::Value strLen = hlfir::genCharLength(loc, builder, strEntity);
+ mlir::Value strEnd = mlir::arith::AddIOp::create(
+ builder, loc, builder.createConvert(loc, idxTy, strLen), oneIdx);
+ if (back)
+ back = builder.createConvert(loc, i1Ty, back);
+ else
+ back = builder.createIntegerConstant(loc, i1Ty, 0);
+ mlir::Value result =
+ mlir::arith::SelectOp::create(builder, loc, back, strEnd, oneIdx);
+
+ rewriter.replaceOp(op, builder.createConvert(loc, resultTy, result));
+ return mlir::success();
+ }
+
+ if (auto strLenCst = hlfir::getCharLengthIfConst(strEntity)) {
+ if (*strLenCst < *substrLenCst) {
+ rewriter.replaceOp(op, builder.createIntegerConstant(loc, resultTy, 0));
+ return mlir::success();
+ }
+ if (*strLenCst == 0) {
+ // both strings have zero length
+ rewriter.replaceOp(op, builder.createIntegerConstant(loc, resultTy, 1));
+ return mlir::success();
+ }
+ }
+ if (*substrLenCst != 1) {
+ return rewriter.notifyMatchFailure(
+ op, "rely on runtime implementation if substring length > 1");
+ }
+ // For forward search and character kind=1 the runtime uses memchr
+ // which well optimized. But it looks like memchr idiom is not recognized
+ // in LLVM yet. On a micro-kernel test with strings of length 40 runtime
+ // had ~2x less execution time vs inlined code. For unknown search direction
+ // at compile time pessimistically assume "forward".
+ std::optional<bool> isBack;
+ if (back) {
+ if (auto backCst = fir::getIntIfConstant(back))
+ isBack = *backCst != 0;
+ } else {
+ isBack = false;
+ }
+ auto charTy = mlir::cast<fir::CharacterType>(
+ hlfir::getFortranElementType(op.getSubstr().getType()));
+ unsigned kind = charTy.getFKind();
+ if (kind == 1 && (!isBack || !*isBack)) {
+ return rewriter.notifyMatchFailure(
+ op, "rely on runtime implementation for character kind 1");
+ }
+
+ // All checks are passed here. Generate single character search loop.
+ auto [strV, strAssociate] = getVariable(builder, loc, op.getStr());
+ auto [substrV, substrAssociate] = getVariable(builder, loc, op.getSubstr());
+ hlfir::Entity str{strV};
+ hlfir::Entity substr{substrV};
+ mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
+
+ auto genExtractAndConvertToInt = [&charTy, &idxTy, &oneIdx,
+ kind](mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ hlfir::Entity &charStr,
+ mlir::Value index) {
+ auto bits = builder.getKindMap().getCharacterBitsize(kind);
+ auto intTy = builder.getIntegerType(bits);
+ auto charLen1Ty =
+ fir::CharacterType::getSingleton(builder.getContext(), kind);
+ mlir::Type designatorTy =
+ fir::ReferenceType::get(charLen1Ty, fir::isa_volatile_type(charTy));
+ auto idxAttr = builder.getIntegerAttr(idxTy, 0);
+
+ auto singleChr = hlfir::DesignateOp::create(
+ builder, loc, designatorTy, charStr, /*component=*/{},
+ /*compShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{},
+ /*substring=*/mlir::ValueRange{index, index},
+ /*complexPart=*/std::nullopt,
+ /*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{oneIdx},
+ fir::FortranVariableFlagsAttr{});
+ auto chrVal = fir::LoadOp::create(builder, loc, singleChr);
+ mlir::Value intVal = fir::ExtractValueOp::create(
+ builder, loc, intTy, chrVal, builder.getArrayAttr(idxAttr));
+ return intVal;
+ };
+
+ auto wantChar = genExtractAndConvertToInt(loc, builder, substr, oneIdx);
+
+ // Generate search loop body with the following C equivalent:
+ // idx_t result = 0;
+ // idx_t end = strlen + 1;
+ // char want = substr[0];
+ // for (idx_t idx = 1; idx < end; ++idx) {
+ // if (result == 0) {
+ // idx_t at = back ? end - idx: idx;
+ // result = str[at-1] == want ? at : result;
+ // }
+ // }
+ mlir::Value strLen = hlfir::genCharLength(loc, builder, strEntity);
+ if (!back)
+ back = builder.createIntegerConstant(loc, i1Ty, 0);
+ else
+ back = builder.createConvert(loc, i1Ty, back);
+ mlir::Value strEnd = mlir::arith::AddIOp::create(
+ builder, loc, builder.createConvert(loc, idxTy, strLen), oneIdx);
+ mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
+ auto genSearchBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
+ mlir::ValueRange index,
+ mlir::ValueRange reductionArgs)
+ -> llvm::SmallVector<mlir::Value, 1> {
+ assert(index.size() == 1 && "expected single loop");
+ assert(reductionArgs.size() == 1 && "expected single reduction value");
+ mlir::Value inRes = reductionArgs[0];
+ auto resEQzero = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroIdx);
+
+ mlir::Value res =
+ builder
+ .genIfOp(loc, {idxTy}, resEQzero,
+ /*withElseRegion=*/true)
+ .genThen([&]() {
+ mlir::Value idx = builder.createConvert(loc, idxTy, index[0]);
+ // offset = back ? end - idx : idx;
+ mlir::Value offset = mlir::arith::SelectOp::create(
+ builder, loc, back,
+ mlir::arith::SubIOp::create(builder, loc, strEnd, idx),
+ idx);
+
+ auto haveChar =
+ genExtractAndConvertToInt(loc, builder, str, offset);
+ auto charsEQ = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::eq, haveChar,
+ wantChar);
+ mlir::Value newVal = mlir::arith::SelectOp::create(
+ builder, loc, charsEQ, offset, inRes);
+
+ fir::ResultOp::create(builder, loc, newVal);
+ })
+ .genElse([&]() { fir::ResultOp::create(builder, loc, inRes); })
+ .getResults()[0];
+ return {res};
+ };
+
+ llvm::SmallVector<mlir::Value, 1> loopOut =
+ hlfir::genLoopNestWithReductions(loc, builder, {strLen},
+ /*reductionInits=*/{zeroIdx},
+ genSearchBody,
+ /*isUnordered=*/false);
+ mlir::Value result = builder.createConvert(loc, resultTy, loopOut[0]);
+
+ if (strAssociate)
+ hlfir::EndAssociateOp::create(builder, loc, strAssociate);
+ if (substrAssociate)
+ hlfir::EndAssociateOp::create(builder, loc, substrAssociate);
+
+ rewriter.replaceOp(op, result);
+ return mlir::success();
+ }
+};
+
template <typename Op>
class MatmulConversion : public mlir::OpRewritePattern<Op> {
public:
@@ -2955,6 +3161,7 @@ public:
patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context);
patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context);
patterns.insert<CmpCharOpConversion>(context);
+ patterns.insert<IndexOpConversion>(context);
patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context);
patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 57be863..e595e61 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -41,7 +41,9 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/StringSet.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstddef>
#include <iterator>
@@ -75,6 +77,112 @@ class MapInfoFinalizationPass
/// | |
std::map<mlir::Operation *, mlir::Value> localBoxAllocas;
+ /// Return true if the given path exists in a list of paths.
+ static bool
+ containsPath(const llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &paths,
+ llvm::ArrayRef<int64_t> path) {
+ return llvm::any_of(paths, [&](const llvm::SmallVector<int64_t> &p) {
+ return p.size() == path.size() &&
+ std::equal(p.begin(), p.end(), path.begin());
+ });
+ }
+
+ /// Return true if the given path is already present in
+ /// op.getMembersIndexAttr().
+ static bool mappedIndexPathExists(mlir::omp::MapInfoOp op,
+ llvm::ArrayRef<int64_t> indexPath) {
+ if (mlir::ArrayAttr attr = op.getMembersIndexAttr()) {
+ for (mlir::Attribute list : attr) {
+ auto listAttr = mlir::cast<mlir::ArrayAttr>(list);
+ if (listAttr.size() != indexPath.size())
+ continue;
+ bool allEq = true;
+ for (auto [i, val] : llvm::enumerate(listAttr)) {
+ if (mlir::cast<mlir::IntegerAttr>(val).getInt() != indexPath[i]) {
+ allEq = false;
+ break;
+ }
+ }
+ if (allEq)
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /// Build a compact string key for an index path for set-based
+ /// deduplication. Format: "N:v0,v1,..." where N is the length.
+ static void buildPathKey(llvm::ArrayRef<int64_t> path,
+ llvm::SmallString<64> &outKey) {
+ outKey.clear();
+ llvm::raw_svector_ostream os(outKey);
+ os << path.size() << ':';
+ for (size_t i = 0; i < path.size(); ++i) {
+ if (i)
+ os << ',';
+ os << path[i];
+ }
+ }
+
+ /// Create the member map for coordRef and append it (and its index
+ /// path) to the provided new* vectors, if it is not already present.
+ void appendMemberMapIfNew(
+ mlir::omp::MapInfoOp op, fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::Value coordRef, llvm::ArrayRef<int64_t> indexPath,
+ llvm::StringRef memberName,
+ llvm::SmallVectorImpl<mlir::Value> &newMapOpsForFields,
+ llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &newMemberIndexPaths) {
+ // Local de-dup within this op invocation.
+ if (containsPath(newMemberIndexPaths, indexPath))
+ return;
+ // Global de-dup against already present member indices.
+ if (mappedIndexPathExists(op, indexPath))
+ return;
+
+ if (op.getMapperId()) {
+ mlir::omp::DeclareMapperOp symbol =
+ mlir::SymbolTable::lookupNearestSymbolFrom<
+ mlir::omp::DeclareMapperOp>(op, op.getMapperIdAttr());
+ assert(symbol && "missing symbol for declare mapper identifier");
+ mlir::omp::DeclareMapperInfoOp mapperInfo = symbol.getDeclareMapperInfo();
+ // TODO: Probably a way to cache these keys in someway so we don't
+ // constantly go through the process of rebuilding them on every check, to
+ // save some cycles, but it can wait for a subsequent patch.
+ for (auto v : mapperInfo.getMapVars()) {
+ mlir::omp::MapInfoOp map =
+ mlir::cast<mlir::omp::MapInfoOp>(v.getDefiningOp());
+ if (!map.getMembers().empty() && mappedIndexPathExists(map, indexPath))
+ return;
+ }
+ }
+
+ builder.setInsertionPoint(op);
+ fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
+ builder, coordRef, /*isOptional=*/false, loc);
+ llvm::SmallVector<mlir::Value> bounds = fir::factory::genImplicitBoundsOps<
+ mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
+ builder, info,
+ hlfir::translateToExtendedValue(loc, builder, hlfir::Entity{coordRef})
+ .first,
+ /*dataExvIsAssumedSize=*/false, loc);
+
+ mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
+ builder, loc, coordRef.getType(), coordRef,
+ mlir::TypeAttr::get(fir::unwrapRefType(coordRef.getType())),
+ op.getMapTypeAttr(),
+ builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+ mlir::omp::VariableCaptureKind::ByRef),
+ /*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
+ /*members_index=*/mlir::ArrayAttr{}, bounds,
+ /*mapperId=*/mlir::FlatSymbolRefAttr(),
+ builder.getStringAttr(op.getNameAttr().strref() + "." + memberName +
+ ".implicit_map"),
+ /*partial_map=*/builder.getBoolAttr(false));
+
+ newMapOpsForFields.emplace_back(fieldMapOp);
+ newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
+ }
+
/// getMemberUserList gathers all users of a particular MapInfoOp that are
/// other MapInfoOp's and places them into the mapMemberUsers list, which
/// records the map that the current argument MapInfoOp "op" is part of
@@ -363,7 +471,7 @@ class MapInfoFinalizationPass
mlir::ArrayAttr newMembersAttr;
mlir::SmallVector<mlir::Value> newMembers;
llvm::SmallVector<llvm::SmallVector<int64_t>> memberIndices;
- bool IsHasDeviceAddr = isHasDeviceAddr(op, target);
+ bool isHasDeviceAddrFlag = isHasDeviceAddr(op, target);
if (!mapMemberUsers.empty() || !op.getMembers().empty())
getMemberIndicesAsVectors(
@@ -406,7 +514,7 @@ class MapInfoFinalizationPass
mapUser.parent.getMembersMutable().assign(newMemberOps);
mapUser.parent.setMembersIndexAttr(
builder.create2DI64ArrayAttr(memberIndices));
- } else if (!IsHasDeviceAddr) {
+ } else if (!isHasDeviceAddrFlag) {
auto baseAddr =
genBaseAddrMap(descriptor, op.getBounds(), op.getMapType(), builder);
newMembers.push_back(baseAddr);
@@ -429,7 +537,7 @@ class MapInfoFinalizationPass
// The contents of the descriptor (the base address in particular) will
// remain unchanged though.
uint64_t mapType = op.getMapType();
- if (IsHasDeviceAddr) {
+ if (isHasDeviceAddrFlag) {
mapType |= llvm::to_underlying(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
}
@@ -701,94 +809,134 @@ class MapInfoFinalizationPass
auto recordType = mlir::cast<fir::RecordType>(underlyingType);
llvm::SmallVector<mlir::Value> newMapOpsForFields;
- llvm::SmallVector<int64_t> fieldIndicies;
+ llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndexPaths;
+ // 1) Handle direct top-level allocatable fields.
for (auto fieldMemTyPair : recordType.getTypeList()) {
auto &field = fieldMemTyPair.first;
auto memTy = fieldMemTyPair.second;
- bool shouldMapField =
- llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
- if (!fir::isAllocatableType(memTy))
- return false;
-
- auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
- if (!designateOp)
- return false;
-
- return designateOp.getComponent() &&
- designateOp.getComponent()->strref() == field;
- }) != mapVarForwardSlice.end();
-
- // TODO Handle recursive record types. Adapting
- // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
- // entities might be helpful here.
-
- if (!shouldMapField)
+ if (!fir::isAllocatableType(memTy))
continue;
- int32_t fieldIdx = recordType.getFieldIndex(field);
- bool alreadyMapped = [&]() {
- if (op.getMembersIndexAttr())
- for (auto indexList : op.getMembersIndexAttr()) {
- auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
- if (indexListAttr.size() == 1 &&
- mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
- fieldIdx)
- return true;
- }
-
- return false;
- }();
-
- if (alreadyMapped)
+ bool referenced = llvm::any_of(mapVarForwardSlice, [&](auto *opv) {
+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
+ return designateOp && designateOp.getComponent() &&
+ designateOp.getComponent()->strref() == field;
+ });
+ if (!referenced)
continue;
+ int32_t fieldIdx = recordType.getFieldIndex(field);
builder.setInsertionPoint(op);
fir::IntOrValue idxConst =
mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
auto fieldCoord = fir::CoordinateOp::create(
builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
- fir::factory::AddrAndBoundsInfo info =
- fir::factory::getDataOperandBaseAddr(
- builder, fieldCoord, /*isOptional=*/false, op.getLoc());
- llvm::SmallVector<mlir::Value> bounds =
- fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
- mlir::omp::MapBoundsType>(
- builder, info,
- hlfir::translateToExtendedValue(op.getLoc(), builder,
- hlfir::Entity{fieldCoord})
- .first,
- /*dataExvIsAssumedSize=*/false, op.getLoc());
-
- mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
- builder, op.getLoc(), fieldCoord.getResult().getType(),
- fieldCoord.getResult(),
- mlir::TypeAttr::get(
- fir::unwrapRefType(fieldCoord.getResult().getType())),
- op.getMapTypeAttr(),
- builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
- mlir::omp::VariableCaptureKind::ByRef),
- /*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
- /*members_index=*/mlir::ArrayAttr{}, bounds,
- /*mapperId=*/mlir::FlatSymbolRefAttr(),
- builder.getStringAttr(op.getNameAttr().strref() + "." + field +
- ".implicit_map"),
- /*partial_map=*/builder.getBoolAttr(false));
- newMapOpsForFields.emplace_back(fieldMapOp);
- fieldIndicies.emplace_back(fieldIdx);
+ int64_t fieldIdx64 = static_cast<int64_t>(fieldIdx);
+ llvm::SmallVector<int64_t, 1> idxPath{fieldIdx64};
+ appendMemberMapIfNew(op, builder, op.getLoc(), fieldCoord, idxPath,
+ field, newMapOpsForFields, newMemberIndexPaths);
+ }
+
+ // Handle nested allocatable fields along any component chain
+ // referenced in the region via HLFIR designates.
+ llvm::SmallVector<llvm::SmallVector<int64_t>> seenIndexPaths;
+ for (mlir::Operation *sliceOp : mapVarForwardSlice) {
+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+ if (!designateOp || !designateOp.getComponent())
+ continue;
+ llvm::SmallVector<llvm::StringRef> compPathReversed;
+ compPathReversed.push_back(designateOp.getComponent()->strref());
+ mlir::Value curBase = designateOp.getMemref();
+ bool rootedAtMapArg = false;
+ while (true) {
+ if (auto parentDes = curBase.getDefiningOp<hlfir::DesignateOp>()) {
+ if (!parentDes.getComponent())
+ break;
+ compPathReversed.push_back(parentDes.getComponent()->strref());
+ curBase = parentDes.getMemref();
+ continue;
+ }
+ if (auto decl = curBase.getDefiningOp<hlfir::DeclareOp>()) {
+ if (auto barg =
+ mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref()))
+ rootedAtMapArg = (barg == opBlockArg);
+ } else if (auto blockArg =
+ mlir::dyn_cast_or_null<mlir::BlockArgument>(
+ curBase)) {
+ rootedAtMapArg = (blockArg == opBlockArg);
+ }
+ break;
+ }
+ // Only process nested paths (2+ components). Single-component paths
+ // for direct fields are handled above.
+ if (!rootedAtMapArg || compPathReversed.size() < 2)
+ continue;
+ builder.setInsertionPoint(op);
+ llvm::SmallVector<int64_t> indexPath;
+ mlir::Type curTy = underlyingType;
+ mlir::Value coordRef = op.getVarPtr();
+ bool validPath = true;
+ for (llvm::StringRef compName : llvm::reverse(compPathReversed)) {
+ auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
+ if (!recTy) {
+ validPath = false;
+ break;
+ }
+ int32_t idx = recTy.getFieldIndex(compName);
+ if (idx < 0) {
+ validPath = false;
+ break;
+ }
+ indexPath.push_back(idx);
+ mlir::Type memTy = recTy.getType(idx);
+ fir::IntOrValue idxConst =
+ mlir::IntegerAttr::get(builder.getI32Type(), idx);
+ coordRef = fir::CoordinateOp::create(
+ builder, op.getLoc(), builder.getRefType(memTy), coordRef,
+ llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
+ curTy = memTy;
+ }
+ if (!validPath)
+ continue;
+ if (auto finalRefTy =
+ mlir::dyn_cast<fir::ReferenceType>(coordRef.getType())) {
+ mlir::Type eleTy = finalRefTy.getElementType();
+ if (fir::isAllocatableType(eleTy)) {
+ if (!containsPath(seenIndexPaths, indexPath)) {
+ seenIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
+ appendMemberMapIfNew(op, builder, op.getLoc(), coordRef,
+ indexPath, compPathReversed.front(),
+ newMapOpsForFields, newMemberIndexPaths);
+ }
+ }
+ }
}
if (newMapOpsForFields.empty())
return mlir::WalkResult::advance();
- op.getMembersMutable().append(newMapOpsForFields);
+ // Deduplicate by index path to avoid emitting duplicate members for
+ // the same component. Use a set-based key to keep this near O(n).
+ llvm::SmallVector<mlir::Value> dedupMapOps;
+ llvm::SmallVector<llvm::SmallVector<int64_t>> dedupIndexPaths;
+ llvm::StringSet<> seenKeys;
+ for (auto [i, mapOp] : llvm::enumerate(newMapOpsForFields)) {
+ const auto &path = newMemberIndexPaths[i];
+ llvm::SmallString<64> key;
+ buildPathKey(path, key);
+ if (seenKeys.contains(key))
+ continue;
+ seenKeys.insert(key);
+ dedupMapOps.push_back(mapOp);
+ dedupIndexPaths.emplace_back(path.begin(), path.end());
+ }
+ op.getMembersMutable().append(dedupMapOps);
llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
- mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
-
- if (oldMembersIdxAttr)
- for (mlir::Attribute indexList : oldMembersIdxAttr) {
+ if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr())
+ for (mlir::Attribute indexList : oldAttr) {
llvm::SmallVector<int64_t> listVec;
for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
@@ -796,10 +944,8 @@ class MapInfoFinalizationPass
newMemberIndices.emplace_back(std::move(listVec));
}
-
- for (int64_t newFieldIdx : fieldIndicies)
- newMemberIndices.emplace_back(
- llvm::SmallVector<int64_t>(1, newFieldIdx));
+ for (auto &path : dedupIndexPaths)
+ newMemberIndices.emplace_back(path);
op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
op.setPartialMap(true);
diff --git a/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp b/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp
index bdf7e4a..e006d2e 100644
--- a/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp
+++ b/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp
@@ -285,11 +285,16 @@ mlir::LLVM::DIModuleAttr AddDebugInfoPass::getOrCreateModuleAttr(
if (auto iter{moduleMap.find(name)}; iter != moduleMap.end()) {
modAttr = iter->getValue();
} else {
+ // When decl is true, it means that module is only being used in this
+ // compilation unit and it is defined elsewhere. But if the file/line/scope
+ // fields are valid, the module is not merged with its definition and is
+ // considered different. So we only set those fields when decl is false.
modAttr = mlir::LLVM::DIModuleAttr::get(
- context, fileAttr, scope, mlir::StringAttr::get(context, name),
+ context, decl ? nullptr : fileAttr, decl ? nullptr : scope,
+ mlir::StringAttr::get(context, name),
/* configMacros */ mlir::StringAttr(),
/* includePath */ mlir::StringAttr(),
- /* apinotes */ mlir::StringAttr(), line, decl);
+ /* apinotes */ mlir::StringAttr(), decl ? 0 : line, decl);
moduleMap[name] = modAttr;
}
return modAttr;
diff --git a/flang/lib/Parser/parsing.cpp b/flang/lib/Parser/parsing.cpp
index 8a8c6ef..2df6881 100644
--- a/flang/lib/Parser/parsing.cpp
+++ b/flang/lib/Parser/parsing.cpp
@@ -85,6 +85,7 @@ const SourceFile *Parsing::Prescan(const std::string &path, Options options) {
if (options.features.IsEnabled(LanguageFeature::OpenACC) ||
(options.prescanAndReformat && noneOfTheAbove)) {
prescanner.AddCompilerDirectiveSentinel("$acc");
+ prescanner.AddCompilerDirectiveSentinel("@acc");
}
if (options.features.IsEnabled(LanguageFeature::OpenMP) ||
(options.prescanAndReformat && noneOfTheAbove)) {
diff --git a/flang/lib/Parser/prescan.cpp b/flang/lib/Parser/prescan.cpp
index 865c149..66e5b2c 100644
--- a/flang/lib/Parser/prescan.cpp
+++ b/flang/lib/Parser/prescan.cpp
@@ -147,6 +147,11 @@ void Prescanner::Statement() {
directiveSentinel_[4] == '\0') {
// CUDA conditional compilation line.
condOffset = 5;
+ } else if (directiveSentinel_[0] == '@' && directiveSentinel_[1] == 'a' &&
+ directiveSentinel_[2] == 'c' && directiveSentinel_[3] == 'c' &&
+ directiveSentinel_[4] == '\0') {
+ // OpenACC conditional compilation line.
+ condOffset = 5;
}
if (condOffset && !preprocessingOnly_) {
at_ += *condOffset, column_ += *condOffset;
diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp
index 1049a6d2..7b88100 100644
--- a/flang/lib/Semantics/check-declarations.cpp
+++ b/flang/lib/Semantics/check-declarations.cpp
@@ -1189,7 +1189,8 @@ void CheckHelper::CheckObjectEntity(
}
} else if (!subpDetails && symbol.owner().kind() != Scope::Kind::Module &&
symbol.owner().kind() != Scope::Kind::MainProgram &&
- symbol.owner().kind() != Scope::Kind::BlockConstruct) {
+ symbol.owner().kind() != Scope::Kind::BlockConstruct &&
+ symbol.owner().kind() != Scope::Kind::OpenACCConstruct) {
messages_.Say(
"ATTRIBUTES(%s) may apply only to module, host subprogram, block, or device subprogram data"_err_en_US,
parser::ToUpperCaseLetters(common::EnumToString(attr)));
diff --git a/flang/lib/Semantics/check-directive-structure.h b/flang/lib/Semantics/check-directive-structure.h
index b1bf3e5..bd78d3c 100644
--- a/flang/lib/Semantics/check-directive-structure.h
+++ b/flang/lib/Semantics/check-directive-structure.h
@@ -383,7 +383,8 @@ protected:
const C &clause, const parser::ScalarIntConstantExpr &i);
void RequiresPositiveParameter(const C &clause,
- const parser::ScalarIntExpr &i, llvm::StringRef paramName = "parameter");
+ const parser::ScalarIntExpr &i, llvm::StringRef paramName = "parameter",
+ bool allowZero = true);
void OptionalConstantPositiveParameter(
const C &clause, const std::optional<parser::ScalarIntConstantExpr> &o);
@@ -657,9 +658,9 @@ void DirectiveStructureChecker<D, C, PC, ClauseEnumSize>::SayNotMatching(
template <typename D, typename C, typename PC, std::size_t ClauseEnumSize>
void DirectiveStructureChecker<D, C, PC,
ClauseEnumSize>::RequiresPositiveParameter(const C &clause,
- const parser::ScalarIntExpr &i, llvm::StringRef paramName) {
+ const parser::ScalarIntExpr &i, llvm::StringRef paramName, bool allowZero) {
if (const auto v{GetIntValue(i)}) {
- if (*v < 0) {
+ if (*v < (allowZero ? 0 : 1)) {
context_.Say(GetContext().clauseSource,
"The %s of the %s clause must be "
"a positive integer expression"_err_en_US,
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index e224e06..c0c41c1 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -1361,9 +1361,19 @@ void OmpStructureChecker::Enter(const parser::OpenMPDeclareSimdConstruct &x) {
return;
}
+ auto isValidSymbol{[](const Symbol *sym) {
+ if (IsProcedure(*sym) || IsFunction(*sym)) {
+ return true;
+ }
+ if (const Symbol *owner{GetScopingUnit(sym->owner()).symbol()}) {
+ return IsProcedure(*owner) || IsFunction(*owner);
+ }
+ return false;
+ }};
+
const parser::OmpArgument &arg{args.v.front()};
if (auto *sym{GetArgumentSymbol(arg)}) {
- if (!IsProcedure(*sym) && !IsFunction(*sym)) {
+ if (!isValidSymbol(sym)) {
auto &msg{context_.Say(arg.source,
"The name '%s' should refer to a procedure"_err_en_US, sym->name())};
if (sym->test(Symbol::Flag::Implicit)) {
@@ -3135,6 +3145,13 @@ void OmpStructureChecker::Enter(const parser::OmpClause &x) {
}
}
+void OmpStructureChecker::Enter(const parser::OmpClause::Sizes &c) {
+ CheckAllowedClause(llvm::omp::Clause::OMPC_sizes);
+ for (const parser::Cosubscript &v : c.v)
+ RequiresPositiveParameter(llvm::omp::Clause::OMPC_sizes, v,
+ /*paramName=*/"parameter", /*allowZero=*/false);
+}
+
// Following clauses do not have a separate node in parse-tree.h.
CHECK_SIMPLE_CLAUSE(Absent, OMPC_absent)
CHECK_SIMPLE_CLAUSE(Affinity, OMPC_affinity)
@@ -3176,7 +3193,6 @@ CHECK_SIMPLE_CLAUSE(Notinbranch, OMPC_notinbranch)
CHECK_SIMPLE_CLAUSE(Partial, OMPC_partial)
CHECK_SIMPLE_CLAUSE(ProcBind, OMPC_proc_bind)
CHECK_SIMPLE_CLAUSE(Simd, OMPC_simd)
-CHECK_SIMPLE_CLAUSE(Sizes, OMPC_sizes)
CHECK_SIMPLE_CLAUSE(Permutation, OMPC_permutation)
CHECK_SIMPLE_CLAUSE(Uniform, OMPC_uniform)
CHECK_SIMPLE_CLAUSE(Unknown, OMPC_unknown)
diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp
index 35b7718..a8ec4d6 100644
--- a/flang/lib/Semantics/openmp-utils.cpp
+++ b/flang/lib/Semantics/openmp-utils.cpp
@@ -41,6 +41,24 @@
namespace Fortran::semantics::omp {
using namespace Fortran::parser::omp;
+const Scope &GetScopingUnit(const Scope &scope) {
+ const Scope *iter{&scope};
+ for (; !iter->IsTopLevel(); iter = &iter->parent()) {
+ switch (iter->kind()) {
+ case Scope::Kind::BlockConstruct:
+ case Scope::Kind::BlockData:
+ case Scope::Kind::DerivedType:
+ case Scope::Kind::MainProgram:
+ case Scope::Kind::Module:
+ case Scope::Kind::Subprogram:
+ return *iter;
+ default:
+ break;
+ }
+ }
+ return *iter;
+}
+
SourcedActionStmt GetActionStmt(const parser::ExecutionPartConstruct *x) {
if (x == nullptr) {
return SourcedActionStmt{};
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index bd7b8ac..02fcf02 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -328,6 +328,11 @@ public:
return false;
}
+ bool Pre(const parser::AccClause::UseDevice &x) {
+ ResolveAccObjectList(x.v, Symbol::Flag::AccUseDevice);
+ return false;
+ }
+
void Post(const parser::Name &);
private:
@@ -379,24 +384,6 @@ public:
explicit OmpAttributeVisitor(SemanticsContext &context)
: DirectiveAttributeVisitor(context) {}
- static const Scope &scopingUnit(const Scope &scope) {
- const Scope *iter{&scope};
- for (; !iter->IsTopLevel(); iter = &iter->parent()) {
- switch (iter->kind()) {
- case Scope::Kind::BlockConstruct:
- case Scope::Kind::BlockData:
- case Scope::Kind::DerivedType:
- case Scope::Kind::MainProgram:
- case Scope::Kind::Module:
- case Scope::Kind::Subprogram:
- return *iter;
- default:
- break;
- }
- }
- return *iter;
- }
-
template <typename A> void Walk(const A &x) { parser::Walk(x, *this); }
template <typename A> bool Pre(const A &) { return true; }
template <typename A> void Post(const A &) {}
@@ -2303,14 +2290,17 @@ void OmpAttributeVisitor::CheckPerfectNestAndRectangularLoop(
}
auto checkPerfectNest = [&, this]() {
- auto blockSize = block.size();
- if (blockSize <= 1)
+ if (block.empty())
return;
+ auto last = block.end();
+ --last;
- if (parser::Unwrap<parser::ContinueStmt>(x))
- blockSize -= 1;
+ // A trailing CONTINUE is not considered part of the loop body
+ if (parser::Unwrap<parser::ContinueStmt>(*last))
+ --last;
- if (blockSize <= 1)
+ // In a perfectly nested loop, the nested loop must be the only statement
+ if (last == block.begin())
return;
// Non-perfectly nested loop
@@ -2431,10 +2421,18 @@ void OmpAttributeVisitor::PrivatizeAssociatedLoopIndexAndCheckLoopLevel(
void OmpAttributeVisitor::CheckAssocLoopLevel(
std::int64_t level, const parser::OmpClause *clause) {
if (clause && level != 0) {
- context_.Say(clause->source,
- "The value of the parameter in the COLLAPSE or ORDERED clause must"
- " not be larger than the number of nested loops"
- " following the construct."_err_en_US);
+ switch (clause->Id()) {
+ case llvm::omp::OMPC_sizes:
+ context_.Say(clause->source,
+ "The SIZES clause has more entries than there are nested canonical loops."_err_en_US);
+ break;
+ default:
+ context_.Say(clause->source,
+ "The value of the parameter in the COLLAPSE or ORDERED clause must"
+ " not be larger than the number of nested loops"
+ " following the construct."_err_en_US);
+ break;
+ }
}
}
@@ -3086,8 +3084,8 @@ void OmpAttributeVisitor::ResolveOmpDesignator(
checkScope = ompFlag == Symbol::Flag::OmpExecutableAllocateDirective;
}
if (checkScope) {
- if (scopingUnit(GetContext().scope) !=
- scopingUnit(symbol->GetUltimate().owner())) {
+ if (omp::GetScopingUnit(GetContext().scope) !=
+ omp::GetScopingUnit(symbol->GetUltimate().owner())) {
context_.Say(designator.source, // 2.15.3
"List items must be declared in the same scoping unit in which the %s directive appears"_err_en_US,
parser::ToUpperCaseLetters(
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index d1150a9..5041a6a 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -1387,6 +1387,8 @@ private:
// Create scopes for OpenACC constructs
class AccVisitor : public virtual DeclarationVisitor {
public:
+ explicit AccVisitor(SemanticsContext &context) : context_{context} {}
+
void AddAccSourceRange(const parser::CharBlock &);
static bool NeedsScope(const parser::OpenACCBlockConstruct &);
@@ -1395,6 +1397,7 @@ public:
void Post(const parser::OpenACCBlockConstruct &);
bool Pre(const parser::OpenACCCombinedConstruct &);
void Post(const parser::OpenACCCombinedConstruct &);
+ bool Pre(const parser::AccClause::UseDevice &x);
bool Pre(const parser::AccBeginBlockDirective &x) {
AddAccSourceRange(x.source);
return true;
@@ -1430,6 +1433,11 @@ public:
void Post(const parser::AccBeginLoopDirective &x) {
messageHandler().set_currStmtSource(std::nullopt);
}
+
+ void CopySymbolWithDevice(const parser::Name *name);
+
+private:
+ SemanticsContext &context_;
};
bool AccVisitor::NeedsScope(const parser::OpenACCBlockConstruct &x) {
@@ -1459,6 +1467,60 @@ bool AccVisitor::Pre(const parser::OpenACCBlockConstruct &x) {
return true;
}
+void AccVisitor::CopySymbolWithDevice(const parser::Name *name) {
+ // When CUDA Fortran is enabled together with OpenACC, new
+ // symbols are created for the one appearing in the use_device
+ // clause. These new symbols have the CUDA Fortran device
+ // attribute.
+ if (context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA)) {
+ name->symbol = currScope().CopySymbol(*name->symbol);
+ if (auto *object{name->symbol->detailsIf<ObjectEntityDetails>()}) {
+ object->set_cudaDataAttr(common::CUDADataAttr::Device);
+ }
+ }
+}
+
+bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) {
+ for (const auto &accObject : x.v.v) {
+ common::visit(
+ common::visitors{
+ [&](const parser::Designator &designator) {
+ if (const auto *name{
+ semantics::getDesignatorNameIfDataRef(designator)}) {
+ Symbol *prev{currScope().FindSymbol(name->source)};
+ if (prev != name->symbol) {
+ name->symbol = prev;
+ }
+ CopySymbolWithDevice(name);
+ } else {
+ if (const auto *dataRef{
+ std::get_if<parser::DataRef>(&designator.u)}) {
+ using ElementIndirection =
+ common::Indirection<parser::ArrayElement>;
+ if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) {
+ const parser::ArrayElement &arrayElement{ind->value()};
+ Walk(arrayElement.subscripts);
+ const parser::DataRef &base{arrayElement.base};
+ if (auto *name{std::get_if<parser::Name>(&base.u)}) {
+ Symbol *prev{currScope().FindSymbol(name->source)};
+ if (prev != name->symbol) {
+ name->symbol = prev;
+ }
+ CopySymbolWithDevice(name);
+ }
+ }
+ }
+ }
+ },
+ [&](const parser::Name &name) {
+ // TODO: common block in use_device?
+ },
+ },
+ accObject.u);
+ }
+ return false;
+}
+
void AccVisitor::Post(const parser::OpenACCBlockConstruct &x) {
if (NeedsScope(x)) {
PopScope();
@@ -2038,7 +2100,8 @@ public:
ResolveNamesVisitor(
SemanticsContext &context, ImplicitRulesMap &rules, Scope &top)
- : BaseVisitor{context, *this, rules}, topScope_{top} {
+ : BaseVisitor{context, *this, rules}, AccVisitor(context),
+ topScope_{top} {
PushScope(top);
}