aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Lower/OpenACC.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Lower/OpenACC.cpp')
-rw-r--r--flang/lib/Lower/OpenACC.cpp380
1 files changed, 285 insertions, 95 deletions
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 4a9e494..62e5c0c 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -20,6 +20,7 @@
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/Support/Utils.h"
+#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
@@ -33,6 +34,7 @@
#include "flang/Semantics/scope.h"
#include "flang/Semantics/tools.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
@@ -60,6 +62,16 @@ static llvm::cl::opt<bool> lowerDoLoopToAccLoop(
llvm::cl::desc("Whether to lower do loops as `acc.loop` operations."),
llvm::cl::init(true));
+static llvm::cl::opt<bool> enableSymbolRemapping(
+ "openacc-remap-symbols",
+ llvm::cl::desc("Whether to remap symbols that appears in data clauses."),
+ llvm::cl::init(true));
+
+static llvm::cl::opt<bool> enableDevicePtrRemap(
+ "openacc-remap-device-ptr-symbols",
+ llvm::cl::desc("sub-option of openacc-remap-symbols for deviceptr clause"),
+ llvm::cl::init(false));
+
// Special value for * passed in device_type or gang clauses.
static constexpr std::int64_t starCst = -1;
@@ -624,17 +636,19 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
}
template <typename Op>
-static void
-genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semanticsContext,
- Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &dataOperands,
- mlir::acc::DataClause dataClause, bool structured,
- bool implicit, llvm::ArrayRef<mlir::Value> async,
- llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
- llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes,
- bool setDeclareAttr = false) {
+static void genDataOperandOperations(
+ const Fortran::parser::AccObjectList &objectList,
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semanticsContext,
+ Fortran::lower::StatementContext &stmtCtx,
+ llvm::SmallVectorImpl<mlir::Value> &dataOperands,
+ mlir::acc::DataClause dataClause, bool structured, bool implicit,
+ llvm::ArrayRef<mlir::Value> async,
+ llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
+ llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes,
+ bool setDeclareAttr = false,
+ llvm::SmallVectorImpl<std::pair<mlir::Value, Fortran::semantics::SymbolRef>>
+ *symbolPairs = nullptr) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
const bool unwrapBoxAddr = true;
@@ -655,6 +669,9 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
/*strideIncludeLowerExtent=*/strideIncludeLowerExtent);
LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
+ bool isWholeSymbol =
+ !designator || Fortran::evaluate::UnwrapWholeSymbolDataRef(*designator);
+
// If the input value is optional and is not a descriptor, we use the
// rawInput directly.
mlir::Value baseAddr = ((fir::unwrapRefType(info.addr.getType()) !=
@@ -668,6 +685,11 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
asyncOnlyDeviceTypes, unwrapBoxAddr, info.isPresent);
dataOperands.push_back(op.getAccVar());
+ // Track the symbol and its corresponding mlir::Value if requested
+ if (symbolPairs && isWholeSymbol)
+ symbolPairs->emplace_back(op.getAccVar(),
+ Fortran::semantics::SymbolRef(symbol));
+
// For UseDeviceOp, if operand is one of a pair resulting from a
// declare operation, create a UseDeviceOp for the other operand as well.
if constexpr (std::is_same_v<Op, mlir::acc::UseDeviceOp>) {
@@ -681,6 +703,8 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
asyncDeviceTypes, asyncOnlyDeviceTypes,
unwrapBoxAddr, info.isPresent);
dataOperands.push_back(op.getAccVar());
+ // Not adding this to symbolPairs because it only make sense to
+ // map the symbol to a single value.
}
}
}
@@ -1264,7 +1288,9 @@ static void genPrivatizationRecipes(
llvm::SmallVector<mlir::Attribute> &privatizationRecipes,
llvm::ArrayRef<mlir::Value> async,
llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
- llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes) {
+ llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes,
+ llvm::SmallVectorImpl<std::pair<mlir::Value, Fortran::semantics::SymbolRef>>
+ *symbolPairs = nullptr) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
@@ -1284,6 +1310,9 @@ static void genPrivatizationRecipes(
/*strideIncludeLowerExtent=*/strideIncludeLowerExtent);
LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
+ bool isWholeSymbol =
+ !designator || Fortran::evaluate::UnwrapWholeSymbolDataRef(*designator);
+
RecipeOp recipe;
mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType());
if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) {
@@ -1297,6 +1326,11 @@ static void genPrivatizationRecipes(
/*implicit=*/false, mlir::acc::DataClause::acc_private, retTy, async,
asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true);
dataOperands.push_back(op.getAccVar());
+
+ // Track the symbol and its corresponding mlir::Value if requested
+ if (symbolPairs && isWholeSymbol)
+ symbolPairs->emplace_back(op.getAccVar(),
+ Fortran::semantics::SymbolRef(symbol));
} else {
std::string suffix =
areAllBoundConstant(bounds) ? getBoundsString(bounds) : "";
@@ -1310,6 +1344,11 @@ static void genPrivatizationRecipes(
async, asyncDeviceTypes, asyncOnlyDeviceTypes,
/*unwrapBoxAddr=*/true);
dataOperands.push_back(op.getAccVar());
+
+ // Track the symbol and its corresponding mlir::Value if requested
+ if (symbolPairs && isWholeSymbol)
+ symbolPairs->emplace_back(op.getAccVar(),
+ Fortran::semantics::SymbolRef(symbol));
}
privatizationRecipes.push_back(mlir::SymbolRefAttr::get(
builder.getContext(), recipe.getSymName().str()));
@@ -1949,15 +1988,16 @@ mlir::Type getTypeFromIvTypeSize(fir::FirOpBuilder &builder,
return builder.getIntegerType(ivTypeSize * 8);
}
-static void
-privatizeIv(Fortran::lower::AbstractConverter &converter,
- const Fortran::semantics::Symbol &sym, mlir::Location loc,
- llvm::SmallVector<mlir::Type> &ivTypes,
- llvm::SmallVector<mlir::Location> &ivLocs,
- llvm::SmallVector<mlir::Value> &privateOperands,
- llvm::SmallVector<mlir::Value> &ivPrivate,
- llvm::SmallVector<mlir::Attribute> &privatizationRecipes,
- bool isDoConcurrent = false) {
+static void privatizeIv(
+ Fortran::lower::AbstractConverter &converter,
+ const Fortran::semantics::Symbol &sym, mlir::Location loc,
+ llvm::SmallVector<mlir::Type> &ivTypes,
+ llvm::SmallVector<mlir::Location> &ivLocs,
+ llvm::SmallVector<mlir::Value> &privateOperands,
+ llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>>
+ &ivPrivate,
+ llvm::SmallVector<mlir::Attribute> &privatizationRecipes,
+ bool isDoConcurrent = false) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::Type ivTy = getTypeFromIvTypeSize(builder, sym);
@@ -2001,15 +2041,8 @@ privatizeIv(Fortran::lower::AbstractConverter &converter,
builder.getContext(), recipe.getSymName().str()));
}
- // Map the new private iv to its symbol for the scope of the loop. bindSymbol
- // might create a hlfir.declare op, if so, we map its result in order to
- // use the sym value in the scope.
- converter.bindSymbol(sym, mlir::acc::getAccVar(privateOp));
- auto privateValue = converter.getSymbolAddress(sym);
- if (auto declareOp =
- mlir::dyn_cast<hlfir::DeclareOp>(privateValue.getDefiningOp()))
- privateValue = declareOp.getResults()[0];
- ivPrivate.push_back(privateValue);
+ ivPrivate.emplace_back(mlir::acc::getAccVar(privateOp),
+ Fortran::semantics::SymbolRef(sym));
}
static void determineDefaultLoopParMode(
@@ -2088,7 +2121,8 @@ static void processDoLoopBounds(
llvm::SmallVector<mlir::Value> &upperbounds,
llvm::SmallVector<mlir::Value> &steps,
llvm::SmallVector<mlir::Value> &privateOperands,
- llvm::SmallVector<mlir::Value> &ivPrivate,
+ llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>>
+ &ivPrivate,
llvm::SmallVector<mlir::Attribute> &privatizationRecipes,
llvm::SmallVector<mlir::Type> &ivTypes,
llvm::SmallVector<mlir::Location> &ivLocs,
@@ -2144,11 +2178,25 @@ static void processDoLoopBounds(
locs.push_back(converter.genLocation(
Fortran::parser::FindSourceLocation(outerDoConstruct)));
} else {
- auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>();
- assert(doCons && "expect do construct");
- loopControl = &*doCons->GetLoopControl();
+ // Safely locate the next inner DoConstruct within this eval.
+ const Fortran::parser::DoConstruct *innerDo = nullptr;
+ if (crtEval && crtEval->hasNestedEvaluations()) {
+ for (Fortran::lower::pft::Evaluation &child :
+ crtEval->getNestedEvaluations()) {
+ if (auto *stmt = child.getIf<Fortran::parser::DoConstruct>()) {
+ innerDo = stmt;
+ // Prepare to descend for the next iteration
+ crtEval = &child;
+ break;
+ }
+ }
+ }
+ if (!innerDo)
+ break; // No deeper loop; stop collecting collapsed bounds.
+
+ loopControl = &*innerDo->GetLoopControl();
locs.push_back(converter.genLocation(
- Fortran::parser::FindSourceLocation(*doCons)));
+ Fortran::parser::FindSourceLocation(*innerDo)));
}
const Fortran::parser::LoopControl::Bounds *bounds =
@@ -2172,32 +2220,127 @@ static void processDoLoopBounds(
inclusiveBounds.push_back(true);
- if (i < loopsToProcess - 1)
- crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
+ // crtEval already updated when descending; no blind increment here.
}
}
}
-static mlir::acc::LoopOp
-buildACCLoopOp(Fortran::lower::AbstractConverter &converter,
- mlir::Location currentLocation,
- Fortran::semantics::SemanticsContext &semanticsContext,
- Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::DoConstruct &outerDoConstruct,
- Fortran::lower::pft::Evaluation &eval,
- llvm::SmallVector<mlir::Value> &privateOperands,
- llvm::SmallVector<mlir::Attribute> &privatizationRecipes,
- llvm::SmallVector<mlir::Value> &gangOperands,
- llvm::SmallVector<mlir::Value> &workerNumOperands,
- llvm::SmallVector<mlir::Value> &vectorOperands,
- llvm::SmallVector<mlir::Value> &tileOperands,
- llvm::SmallVector<mlir::Value> &cacheOperands,
- llvm::SmallVector<mlir::Value> &reductionOperands,
- llvm::SmallVector<mlir::Type> &retTy, mlir::Value yieldValue,
- uint64_t loopsToProcess) {
+/// Remap symbols that appeared in OpenACC data clauses to use the results of
+/// the corresponding data operations. This allows isolating symbol accesses
+/// inside the OpenACC region from accesses in the host and other regions while
+/// preserving Fortran information about the symbols for optimizations.
+template <typename RegionOp>
+static void remapDataOperandSymbols(
+ Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &builder,
+ RegionOp &regionOp,
+ const llvm::SmallVector<
+ std::pair<mlir::Value, Fortran::semantics::SymbolRef>>
+ &dataOperandSymbolPairs) {
+ if (!enableSymbolRemapping || dataOperandSymbolPairs.empty())
+ return;
+
+ // Map Symbols that appeared inside data clauses to a new hlfir.declare whose
+ // input is the acc data operation result.
+ // This allows isolating all the symbol accesses inside the compute region
+ // from accesses in the host and other regions while preserving the Fortran
+ // information about the symbols for Fortran specific optimizations inside the
+ // region.
+ Fortran::lower::SymMap &symbolMap = converter.getSymbolMap();
+ mlir::OpBuilder::InsertionGuard insertGuard(builder);
+ builder.setInsertionPointToStart(&regionOp.getRegion().front());
+ llvm::SmallPtrSet<const Fortran::semantics::Symbol *, 8> seenSymbols;
+ mlir::IRMapping mapper;
+ for (auto [value, symbol] : dataOperandSymbolPairs) {
+
+ // If A symbol appears on several data clause, just map it to the first
+ // result (all data operations results for a symbol are pointing same
+ // memory, so it does not matter which one is used).
+ if (seenSymbols.contains(&symbol.get()))
+ continue;
+ seenSymbols.insert(&symbol.get());
+ std::optional<fir::FortranVariableOpInterface> hostDef =
+ symbolMap.lookupVariableDefinition(symbol);
+ assert(hostDef.has_value() && llvm::isa<hlfir::DeclareOp>(*hostDef) &&
+ "expected symbol to be mapped to hlfir.declare");
+ auto hostDeclare = llvm::cast<hlfir::DeclareOp>(*hostDef);
+ // Replace base input and DummyScope inputs.
+ mlir::Value hostInput = hostDeclare.getMemref();
+ mlir::Type hostType = hostInput.getType();
+ mlir::Type computeType = value.getType();
+ if (hostType == computeType) {
+ mapper.map(hostInput, value);
+ } else if (llvm::isa<fir::BaseBoxType>(computeType)) {
+ assert(!llvm::isa<fir::BaseBoxType>(hostType) &&
+ "box type mismatch between compute region variable and "
+ "hlfir.declare input unexpected");
+ if (Fortran::semantics::IsOptional(symbol))
+ TODO(regionOp.getLoc(),
+ "remapping OPTIONAL symbol in OpenACC compute region");
+ auto rawValue =
+ fir::BoxAddrOp::create(builder, regionOp.getLoc(), hostType, value);
+ mapper.map(hostInput, rawValue);
+ } else {
+ assert(!llvm::isa<fir::BaseBoxType>(hostType) &&
+ "compute region variable should not be raw address when host "
+ "hlfir.declare input was a box");
+ assert(fir::isBoxAddress(hostType) == fir::isBoxAddress(computeType) &&
+ "compute region variable should be a pointer/allocatable if and "
+ "only if host is");
+ assert(fir::isa_ref_type(hostType) && fir::isa_ref_type(computeType) &&
+ "compute region variable and host variable should both be raw "
+ "addresses");
+ mlir::Value cast =
+ builder.createConvert(regionOp.getLoc(), hostType, value);
+ mapper.map(hostInput, cast);
+ }
+ if (mlir::Value dummyScope = hostDeclare.getDummyScope()) {
+ // Copy the dummy scope into the region so that aliasing rules about
+ // Fortran dummies are understood inside the region and the abstract dummy
+ // scope type does not have to cross the OpenACC compute region boundary.
+ if (!mapper.contains(dummyScope)) {
+ mlir::Operation *hostDummyScopeOp = dummyScope.getDefiningOp();
+ assert(hostDummyScopeOp &&
+ "dummyScope defining operation must be visible in lowering");
+ (void)builder.clone(*hostDummyScopeOp, mapper);
+ }
+ }
+
+ mlir::Operation *computeDef =
+ builder.clone(*hostDeclare.getOperation(), mapper);
+
+ // The input box already went through an hlfir.declare. It has the correct
+ // local lower bounds and attribute. Do not generate a new fir.rebox.
+ if (llvm::isa<fir::BaseBoxType>(hostDeclare.getMemref().getType()))
+ llvm::cast<hlfir::DeclareOp>(*computeDef).setSkipRebox(true);
+
+ symbolMap.addVariableDefinition(
+ symbol, llvm::cast<fir::FortranVariableOpInterface>(computeDef));
+ }
+}
+
+static mlir::acc::LoopOp buildACCLoopOp(
+ Fortran::lower::AbstractConverter &converter,
+ mlir::Location currentLocation,
+ Fortran::semantics::SemanticsContext &semanticsContext,
+ Fortran::lower::StatementContext &stmtCtx,
+ const Fortran::parser::DoConstruct &outerDoConstruct,
+ Fortran::lower::pft::Evaluation &eval,
+ llvm::SmallVector<mlir::Value> &privateOperands,
+ llvm::SmallVector<mlir::Attribute> &privatizationRecipes,
+ llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>>
+ &dataOperandSymbolPairs,
+ llvm::SmallVector<mlir::Value> &gangOperands,
+ llvm::SmallVector<mlir::Value> &workerNumOperands,
+ llvm::SmallVector<mlir::Value> &vectorOperands,
+ llvm::SmallVector<mlir::Value> &tileOperands,
+ llvm::SmallVector<mlir::Value> &cacheOperands,
+ llvm::SmallVector<mlir::Value> &reductionOperands,
+ llvm::SmallVector<mlir::Type> &retTy, mlir::Value yieldValue,
+ uint64_t loopsToProcess) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
- llvm::SmallVector<mlir::Value> ivPrivate;
+ llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>>
+ ivPrivate;
llvm::SmallVector<mlir::Type> ivTypes;
llvm::SmallVector<mlir::Location> ivLocs;
llvm::SmallVector<bool> inclusiveBounds;
@@ -2231,10 +2374,22 @@ buildACCLoopOp(Fortran::lower::AbstractConverter &converter,
builder, builder.getFusedLoc(locs), currentLocation, eval, operands,
operandSegments, /*outerCombined=*/false, retTy, yieldValue, ivTypes,
ivLocs);
-
- for (auto [arg, value] : llvm::zip(
- loopOp.getLoopRegions().front()->front().getArguments(), ivPrivate))
- fir::StoreOp::create(builder, currentLocation, arg, value);
+ // Ensure the iv symbol is mapped to private iv SSA value for the scope of
+ // the loop even if it did not appear explicitly in a PRIVATE clause (if it
+ // appeared explicitly in such clause, that is also fine because duplicates
+ // in the list are ignored).
+ dataOperandSymbolPairs.append(ivPrivate.begin(), ivPrivate.end());
+ // Remap symbols from data clauses to use data operation results
+ remapDataOperandSymbols(converter, builder, loopOp, dataOperandSymbolPairs);
+
+ for (auto [arg, iv] :
+ llvm::zip(loopOp.getLoopRegions().front()->front().getArguments(),
+ ivPrivate)) {
+ // Store block argument to the related iv private variable.
+ mlir::Value privateValue =
+ converter.getSymbolAddress(std::get<Fortran::semantics::SymbolRef>(iv));
+ fir::StoreOp::create(builder, currentLocation, arg, privateValue);
+ }
loopOp.setInclusiveUpperbound(inclusiveBounds);
@@ -2260,6 +2415,10 @@ static mlir::acc::LoopOp createLoopOp(
llvm::SmallVector<int32_t> tileOperandsSegments, gangOperandsSegments;
llvm::SmallVector<int64_t> collapseValues;
+ // Vector to track mlir::Value results and their corresponding Fortran symbols
+ llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>>
+ dataOperandSymbolPairs;
+
llvm::SmallVector<mlir::Attribute> gangArgTypes;
llvm::SmallVector<mlir::Attribute> seqDeviceTypes, independentDeviceTypes,
autoDeviceTypes, vectorOperandsDeviceTypes, workerNumOperandsDeviceTypes,
@@ -2380,7 +2539,8 @@ static mlir::acc::LoopOp createLoopOp(
genPrivatizationRecipes<mlir::acc::PrivateRecipeOp>(
privateClause->v, converter, semanticsContext, stmtCtx,
privateOperands, privatizationRecipes, /*async=*/{},
- /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
+ /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{},
+ &dataOperandSymbolPairs);
} else if (const auto *reductionClause =
std::get_if<Fortran::parser::AccClause::Reduction>(
&clause.u)) {
@@ -2406,10 +2566,6 @@ static mlir::acc::LoopOp createLoopOp(
std::get_if<Fortran::parser::AccClause::Collapse>(
&clause.u)) {
const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
- const auto &force = std::get<bool>(arg.t);
- if (force)
- TODO(clauseLocation, "OpenACC collapse force modifier");
-
const auto &intExpr =
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
const auto *expr = Fortran::semantics::GetExpr(intExpr);
@@ -2436,9 +2592,9 @@ static mlir::acc::LoopOp createLoopOp(
Fortran::lower::getLoopCountForCollapseAndTile(accClauseList);
auto loopOp = buildACCLoopOp(
converter, currentLocation, semanticsContext, stmtCtx, outerDoConstruct,
- eval, privateOperands, privatizationRecipes, gangOperands,
- workerNumOperands, vectorOperands, tileOperands, cacheOperands,
- reductionOperands, retTy, yieldValue, loopsToProcess);
+ eval, privateOperands, privatizationRecipes, dataOperandSymbolPairs,
+ gangOperands, workerNumOperands, vectorOperands, tileOperands,
+ cacheOperands, reductionOperands, retTy, yieldValue, loopsToProcess);
if (!gangDeviceTypes.empty())
loopOp.setGangAttr(builder.getArrayAttr(gangDeviceTypes));
@@ -2568,7 +2724,9 @@ static void genDataOperandOperationsWithModifier(
llvm::ArrayRef<mlir::Value> async,
llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes,
llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes,
- bool setDeclareAttr = false) {
+ bool setDeclareAttr = false,
+ llvm::SmallVectorImpl<std::pair<mlir::Value, Fortran::semantics::SymbolRef>>
+ *symbolPairs = nullptr) {
const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v;
const auto &accObjectList =
std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
@@ -2581,7 +2739,7 @@ static void genDataOperandOperationsWithModifier(
stmtCtx, dataClauseOperands, dataClause,
/*structured=*/true, /*implicit=*/false, async,
asyncDeviceTypes, asyncOnlyDeviceTypes,
- setDeclareAttr);
+ setDeclareAttr, symbolPairs);
}
template <typename Op>
@@ -2612,6 +2770,10 @@ static Op createComputeOp(
llvm::SmallVector<mlir::Attribute> privatizationRecipes,
firstPrivatizationRecipes, reductionRecipes;
+ // Vector to track mlir::Value results and their corresponding Fortran symbols
+ llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>>
+ dataOperandSymbolPairs;
+
// Self clause has optional values but can be present with
// no value as well. When there is no value, the op has an attribute to
// represent the clause.
@@ -2732,7 +2894,8 @@ static Op createComputeOp(
copyClause->v, converter, semanticsContext, stmtCtx,
dataClauseOperands, mlir::acc::DataClause::acc_copy,
/*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
- asyncOnlyDeviceTypes);
+ asyncOnlyDeviceTypes, /*setDeclareAttr=*/false,
+ &dataOperandSymbolPairs);
copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
dataClauseOperands.end());
} else if (const auto *copyinClause =
@@ -2744,7 +2907,8 @@ static Op createComputeOp(
Fortran::parser::AccDataModifier::Modifier::ReadOnly,
dataClauseOperands, mlir::acc::DataClause::acc_copyin,
mlir::acc::DataClause::acc_copyin_readonly, async, asyncDeviceTypes,
- asyncOnlyDeviceTypes);
+ asyncOnlyDeviceTypes, /*setDeclareAttr=*/false,
+ &dataOperandSymbolPairs);
copyinEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
dataClauseOperands.end());
} else if (const auto *copyoutClause =
@@ -2757,7 +2921,8 @@ static Op createComputeOp(
Fortran::parser::AccDataModifier::Modifier::ReadOnly,
dataClauseOperands, mlir::acc::DataClause::acc_copyout,
mlir::acc::DataClause::acc_copyout_zero, async, asyncDeviceTypes,
- asyncOnlyDeviceTypes);
+ asyncOnlyDeviceTypes, /*setDeclareAttr=*/false,
+ &dataOperandSymbolPairs);
copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
dataClauseOperands.end());
} else if (const auto *createClause =
@@ -2769,7 +2934,8 @@ static Op createComputeOp(
Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands,
mlir::acc::DataClause::acc_create,
mlir::acc::DataClause::acc_create_zero, async, asyncDeviceTypes,
- asyncOnlyDeviceTypes);
+ asyncOnlyDeviceTypes, /*setDeclareAttr=*/false,
+ &dataOperandSymbolPairs);
createEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
dataClauseOperands.end());
} else if (const auto *noCreateClause =
@@ -2780,7 +2946,8 @@ static Op createComputeOp(
noCreateClause->v, converter, semanticsContext, stmtCtx,
dataClauseOperands, mlir::acc::DataClause::acc_no_create,
/*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
- asyncOnlyDeviceTypes);
+ asyncOnlyDeviceTypes, /*setDeclareAttr=*/false,
+ &dataOperandSymbolPairs);
nocreateEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
dataClauseOperands.end());
} else if (const auto *presentClause =
@@ -2791,17 +2958,21 @@ static Op createComputeOp(
presentClause->v, converter, semanticsContext, stmtCtx,
dataClauseOperands, mlir::acc::DataClause::acc_present,
/*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
- asyncOnlyDeviceTypes);
+ asyncOnlyDeviceTypes, /*setDeclareAttr=*/false,
+ &dataOperandSymbolPairs);
presentEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
dataClauseOperands.end());
} else if (const auto *devicePtrClause =
std::get_if<Fortran::parser::AccClause::Deviceptr>(
&clause.u)) {
+ llvm::SmallVectorImpl<
+ std::pair<mlir::Value, Fortran::semantics::SymbolRef>> *symPairs =
+ enableDevicePtrRemap ? &dataOperandSymbolPairs : nullptr;
genDataOperandOperations<mlir::acc::DevicePtrOp>(
devicePtrClause->v, converter, semanticsContext, stmtCtx,
dataClauseOperands, mlir::acc::DataClause::acc_deviceptr,
/*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
- asyncOnlyDeviceTypes);
+ asyncOnlyDeviceTypes, /*setDeclareAttr=*/false, symPairs);
} else if (const auto *attachClause =
std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
auto crtDataStart = dataClauseOperands.size();
@@ -2809,7 +2980,8 @@ static Op createComputeOp(
attachClause->v, converter, semanticsContext, stmtCtx,
dataClauseOperands, mlir::acc::DataClause::acc_attach,
/*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes,
- asyncOnlyDeviceTypes);
+ asyncOnlyDeviceTypes, /*setDeclareAttr=*/false,
+ &dataOperandSymbolPairs);
attachEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
dataClauseOperands.end());
} else if (const auto *privateClause =
@@ -2819,14 +2991,14 @@ static Op createComputeOp(
genPrivatizationRecipes<mlir::acc::PrivateRecipeOp>(
privateClause->v, converter, semanticsContext, stmtCtx,
privateOperands, privatizationRecipes, async, asyncDeviceTypes,
- asyncOnlyDeviceTypes);
+ asyncOnlyDeviceTypes, &dataOperandSymbolPairs);
} else if (const auto *firstprivateClause =
std::get_if<Fortran::parser::AccClause::Firstprivate>(
&clause.u)) {
genPrivatizationRecipes<mlir::acc::FirstprivateRecipeOp>(
firstprivateClause->v, converter, semanticsContext, stmtCtx,
firstprivateOperands, firstPrivatizationRecipes, async,
- asyncDeviceTypes, asyncOnlyDeviceTypes);
+ asyncDeviceTypes, asyncOnlyDeviceTypes, &dataOperandSymbolPairs);
} else if (const auto *reductionClause =
std::get_if<Fortran::parser::AccClause::Reduction>(
&clause.u)) {
@@ -2846,7 +3018,8 @@ static Op createComputeOp(
converter, semanticsContext, stmtCtx, dataClauseOperands,
mlir::acc::DataClause::acc_reduction,
/*structured=*/true, /*implicit=*/true, async, asyncDeviceTypes,
- asyncOnlyDeviceTypes);
+ asyncOnlyDeviceTypes, /*setDeclareAttr=*/false,
+ &dataOperandSymbolPairs);
copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart,
dataClauseOperands.end());
}
@@ -2945,6 +3118,11 @@ static Op createComputeOp(
computeOp.setCombinedAttr(builder.getUnitAttr());
auto insPt = builder.saveInsertionPoint();
+
+ // Remap symbols from data clauses to use data operation results
+ remapDataOperandSymbols(converter, builder, computeOp,
+ dataOperandSymbolPairs);
+
builder.setInsertionPointAfter(computeOp);
// Create the exit operations after the region.
@@ -4860,25 +5038,34 @@ void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
uint64_t Fortran::lower::getLoopCountForCollapseAndTile(
const Fortran::parser::AccClauseList &clauseList) {
- uint64_t collapseLoopCount = 1;
+ uint64_t collapseLoopCount = getCollapseSizeAndForce(clauseList).first;
uint64_t tileLoopCount = 1;
for (const Fortran::parser::AccClause &clause : clauseList.v) {
- if (const auto *collapseClause =
- std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
- const parser::AccCollapseArg &arg = collapseClause->v;
- const auto &collapseValue{std::get<parser::ScalarIntConstantExpr>(arg.t)};
- collapseLoopCount = *Fortran::semantics::GetIntValue(collapseValue);
- }
if (const auto *tileClause =
std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
const parser::AccTileExprList &tileExprList = tileClause->v;
- const std::list<parser::AccTileExpr> &listTileExpr = tileExprList.v;
- tileLoopCount = listTileExpr.size();
+ tileLoopCount = tileExprList.v.size();
}
}
- if (tileLoopCount > collapseLoopCount)
- return tileLoopCount;
- return collapseLoopCount;
+ return tileLoopCount > collapseLoopCount ? tileLoopCount : collapseLoopCount;
+}
+
+std::pair<uint64_t, bool> Fortran::lower::getCollapseSizeAndForce(
+ const Fortran::parser::AccClauseList &clauseList) {
+ uint64_t size = 1;
+ bool force = false;
+ for (const Fortran::parser::AccClause &clause : clauseList.v) {
+ if (const auto *collapseClause =
+ std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
+ const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
+ force = std::get<bool>(arg.t);
+ const auto &collapseValue =
+ std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
+ size = *Fortran::semantics::GetIntValue(collapseValue);
+ break;
+ }
+ }
+ return {size, force};
}
/// Create an ACC loop operation for a DO construct when inside ACC compute
@@ -4921,6 +5108,8 @@ mlir::Operation *Fortran::lower::genOpenACCLoopFromDoConstruct(
reductionOperands;
llvm::SmallVector<mlir::Attribute> privatizationRecipes;
llvm::SmallVector<mlir::Type> retTy;
+ llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>>
+ dataOperandSymbolPairs;
mlir::Value yieldValue;
uint64_t loopsToProcess = 1; // Single loop construct
@@ -4929,9 +5118,10 @@ mlir::Operation *Fortran::lower::genOpenACCLoopFromDoConstruct(
Fortran::lower::StatementContext stmtCtx;
auto loopOp = buildACCLoopOp(
converter, converter.getCurrentLocation(), semanticsContext, stmtCtx,
- doConstruct, eval, privateOperands, privatizationRecipes, gangOperands,
- workerNumOperands, vectorOperands, tileOperands, cacheOperands,
- reductionOperands, retTy, yieldValue, loopsToProcess);
+ doConstruct, eval, privateOperands, privatizationRecipes,
+ dataOperandSymbolPairs, gangOperands, workerNumOperands, vectorOperands,
+ tileOperands, cacheOperands, reductionOperands, retTy, yieldValue,
+ loopsToProcess);
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
if (!privatizationRecipes.empty())