diff options
Diffstat (limited to 'flang/lib/Lower/OpenACC.cpp')
-rw-r--r-- | flang/lib/Lower/OpenACC.cpp | 318 |
1 files changed, 245 insertions, 73 deletions
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 4a9e494..742f58f 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -20,6 +20,7 @@ #include "flang/Lower/PFTBuilder.h" #include "flang/Lower/StatementContext.h" #include "flang/Lower/Support/Utils.h" +#include "flang/Lower/SymbolMap.h" #include "flang/Optimizer/Builder/BoxValue.h" #include "flang/Optimizer/Builder/Complex.h" #include "flang/Optimizer/Builder/FIRBuilder.h" @@ -33,6 +34,7 @@ #include "flang/Semantics/scope.h" #include "flang/Semantics/tools.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" @@ -60,6 +62,16 @@ static llvm::cl::opt<bool> lowerDoLoopToAccLoop( llvm::cl::desc("Whether to lower do loops as `acc.loop` operations."), llvm::cl::init(true)); +static llvm::cl::opt<bool> enableSymbolRemapping( + "openacc-remap-symbols", + llvm::cl::desc("Whether to remap symbols that appears in data clauses."), + llvm::cl::init(true)); + +static llvm::cl::opt<bool> enableDevicePtrRemap( + "openacc-remap-device-ptr-symbols", + llvm::cl::desc("sub-option of openacc-remap-symbols for deviceptr clause"), + llvm::cl::init(false)); + // Special value for * passed in device_type or gang clauses. static constexpr std::int64_t starCst = -1; @@ -624,17 +636,19 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter, } template <typename Op> -static void -genDataOperandOperations(const Fortran::parser::AccObjectList &objectList, - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semanticsContext, - Fortran::lower::StatementContext &stmtCtx, - llvm::SmallVectorImpl<mlir::Value> &dataOperands, - mlir::acc::DataClause dataClause, bool structured, - bool implicit, llvm::ArrayRef<mlir::Value> async, - llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes, - llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes, - bool setDeclareAttr = false) { +static void genDataOperandOperations( + const Fortran::parser::AccObjectList &objectList, + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::lower::StatementContext &stmtCtx, + llvm::SmallVectorImpl<mlir::Value> &dataOperands, + mlir::acc::DataClause dataClause, bool structured, bool implicit, + llvm::ArrayRef<mlir::Value> async, + llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes, + llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes, + bool setDeclareAttr = false, + llvm::SmallVectorImpl<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> + *symbolPairs = nullptr) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; const bool unwrapBoxAddr = true; @@ -655,6 +669,9 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList, /*strideIncludeLowerExtent=*/strideIncludeLowerExtent); LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs())); + bool isWholeSymbol = + !designator || Fortran::evaluate::UnwrapWholeSymbolDataRef(*designator); + // If the input value is optional and is not a descriptor, we use the // rawInput directly. mlir::Value baseAddr = ((fir::unwrapRefType(info.addr.getType()) != @@ -668,6 +685,11 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList, asyncOnlyDeviceTypes, unwrapBoxAddr, info.isPresent); dataOperands.push_back(op.getAccVar()); + // Track the symbol and its corresponding mlir::Value if requested + if (symbolPairs && isWholeSymbol) + symbolPairs->emplace_back(op.getAccVar(), + Fortran::semantics::SymbolRef(symbol)); + // For UseDeviceOp, if operand is one of a pair resulting from a // declare operation, create a UseDeviceOp for the other operand as well. if constexpr (std::is_same_v<Op, mlir::acc::UseDeviceOp>) { @@ -681,6 +703,8 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList, asyncDeviceTypes, asyncOnlyDeviceTypes, unwrapBoxAddr, info.isPresent); dataOperands.push_back(op.getAccVar()); + // Not adding this to symbolPairs because it only make sense to + // map the symbol to a single value. } } } @@ -1264,7 +1288,9 @@ static void genPrivatizationRecipes( llvm::SmallVector<mlir::Attribute> &privatizationRecipes, llvm::ArrayRef<mlir::Value> async, llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes, - llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes) { + llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes, + llvm::SmallVectorImpl<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> + *symbolPairs = nullptr) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext}; for (const auto &accObject : objectList.v) { @@ -1284,6 +1310,9 @@ static void genPrivatizationRecipes( /*strideIncludeLowerExtent=*/strideIncludeLowerExtent); LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs())); + bool isWholeSymbol = + !designator || Fortran::evaluate::UnwrapWholeSymbolDataRef(*designator); + RecipeOp recipe; mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType()); if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) { @@ -1297,6 +1326,11 @@ static void genPrivatizationRecipes( /*implicit=*/false, mlir::acc::DataClause::acc_private, retTy, async, asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true); dataOperands.push_back(op.getAccVar()); + + // Track the symbol and its corresponding mlir::Value if requested + if (symbolPairs && isWholeSymbol) + symbolPairs->emplace_back(op.getAccVar(), + Fortran::semantics::SymbolRef(symbol)); } else { std::string suffix = areAllBoundConstant(bounds) ? getBoundsString(bounds) : ""; @@ -1310,6 +1344,11 @@ static void genPrivatizationRecipes( async, asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true); dataOperands.push_back(op.getAccVar()); + + // Track the symbol and its corresponding mlir::Value if requested + if (symbolPairs && isWholeSymbol) + symbolPairs->emplace_back(op.getAccVar(), + Fortran::semantics::SymbolRef(symbol)); } privatizationRecipes.push_back(mlir::SymbolRefAttr::get( builder.getContext(), recipe.getSymName().str())); @@ -1949,15 +1988,16 @@ mlir::Type getTypeFromIvTypeSize(fir::FirOpBuilder &builder, return builder.getIntegerType(ivTypeSize * 8); } -static void -privatizeIv(Fortran::lower::AbstractConverter &converter, - const Fortran::semantics::Symbol &sym, mlir::Location loc, - llvm::SmallVector<mlir::Type> &ivTypes, - llvm::SmallVector<mlir::Location> &ivLocs, - llvm::SmallVector<mlir::Value> &privateOperands, - llvm::SmallVector<mlir::Value> &ivPrivate, - llvm::SmallVector<mlir::Attribute> &privatizationRecipes, - bool isDoConcurrent = false) { +static void privatizeIv( + Fortran::lower::AbstractConverter &converter, + const Fortran::semantics::Symbol &sym, mlir::Location loc, + llvm::SmallVector<mlir::Type> &ivTypes, + llvm::SmallVector<mlir::Location> &ivLocs, + llvm::SmallVector<mlir::Value> &privateOperands, + llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> + &ivPrivate, + llvm::SmallVector<mlir::Attribute> &privatizationRecipes, + bool isDoConcurrent = false) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); mlir::Type ivTy = getTypeFromIvTypeSize(builder, sym); @@ -2001,15 +2041,8 @@ privatizeIv(Fortran::lower::AbstractConverter &converter, builder.getContext(), recipe.getSymName().str())); } - // Map the new private iv to its symbol for the scope of the loop. bindSymbol - // might create a hlfir.declare op, if so, we map its result in order to - // use the sym value in the scope. - converter.bindSymbol(sym, mlir::acc::getAccVar(privateOp)); - auto privateValue = converter.getSymbolAddress(sym); - if (auto declareOp = - mlir::dyn_cast<hlfir::DeclareOp>(privateValue.getDefiningOp())) - privateValue = declareOp.getResults()[0]; - ivPrivate.push_back(privateValue); + ivPrivate.emplace_back(mlir::acc::getAccVar(privateOp), + Fortran::semantics::SymbolRef(sym)); } static void determineDefaultLoopParMode( @@ -2088,7 +2121,8 @@ static void processDoLoopBounds( llvm::SmallVector<mlir::Value> &upperbounds, llvm::SmallVector<mlir::Value> &steps, llvm::SmallVector<mlir::Value> &privateOperands, - llvm::SmallVector<mlir::Value> &ivPrivate, + llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> + &ivPrivate, llvm::SmallVector<mlir::Attribute> &privatizationRecipes, llvm::SmallVector<mlir::Type> &ivTypes, llvm::SmallVector<mlir::Location> &ivLocs, @@ -2178,26 +2212,122 @@ static void processDoLoopBounds( } } -static mlir::acc::LoopOp -buildACCLoopOp(Fortran::lower::AbstractConverter &converter, - mlir::Location currentLocation, - Fortran::semantics::SemanticsContext &semanticsContext, - Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::DoConstruct &outerDoConstruct, - Fortran::lower::pft::Evaluation &eval, - llvm::SmallVector<mlir::Value> &privateOperands, - llvm::SmallVector<mlir::Attribute> &privatizationRecipes, - llvm::SmallVector<mlir::Value> &gangOperands, - llvm::SmallVector<mlir::Value> &workerNumOperands, - llvm::SmallVector<mlir::Value> &vectorOperands, - llvm::SmallVector<mlir::Value> &tileOperands, - llvm::SmallVector<mlir::Value> &cacheOperands, - llvm::SmallVector<mlir::Value> &reductionOperands, - llvm::SmallVector<mlir::Type> &retTy, mlir::Value yieldValue, - uint64_t loopsToProcess) { +/// Remap symbols that appeared in OpenACC data clauses to use the results of +/// the corresponding data operations. This allows isolating symbol accesses +/// inside the OpenACC region from accesses in the host and other regions while +/// preserving Fortran information about the symbols for optimizations. +template <typename RegionOp> +static void remapDataOperandSymbols( + Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &builder, + RegionOp ®ionOp, + const llvm::SmallVector< + std::pair<mlir::Value, Fortran::semantics::SymbolRef>> + &dataOperandSymbolPairs) { + if (!enableSymbolRemapping || dataOperandSymbolPairs.empty()) + return; + + // Map Symbols that appeared inside data clauses to a new hlfir.declare whose + // input is the acc data operation result. + // This allows isolating all the symbol accesses inside the compute region + // from accesses in the host and other regions while preserving the Fortran + // information about the symbols for Fortran specific optimizations inside the + // region. + Fortran::lower::SymMap &symbolMap = converter.getSymbolMap(); + mlir::OpBuilder::InsertionGuard insertGuard(builder); + builder.setInsertionPointToStart(®ionOp.getRegion().front()); + llvm::SmallPtrSet<const Fortran::semantics::Symbol *, 8> seenSymbols; + mlir::IRMapping mapper; + for (auto [value, symbol] : dataOperandSymbolPairs) { + + // If A symbol appears on several data clause, just map it to the first + // result (all data operations results for a symbol are pointing same + // memory, so it does not matter which one is used). + if (seenSymbols.contains(&symbol.get())) + continue; + seenSymbols.insert(&symbol.get()); + std::optional<fir::FortranVariableOpInterface> hostDef = + symbolMap.lookupVariableDefinition(symbol); + assert(hostDef.has_value() && llvm::isa<hlfir::DeclareOp>(*hostDef) && + "expected symbol to be mapped to hlfir.declare"); + auto hostDeclare = llvm::cast<hlfir::DeclareOp>(*hostDef); + // Replace base input and DummyScope inputs. + mlir::Value hostInput = hostDeclare.getMemref(); + mlir::Type hostType = hostInput.getType(); + mlir::Type computeType = value.getType(); + if (hostType == computeType) { + mapper.map(hostInput, value); + } else if (llvm::isa<fir::BaseBoxType>(computeType)) { + assert(!llvm::isa<fir::BaseBoxType>(hostType) && + "box type mismatch between compute region variable and " + "hlfir.declare input unexpected"); + if (Fortran::semantics::IsOptional(symbol)) + TODO(regionOp.getLoc(), + "remapping OPTIONAL symbol in OpenACC compute region"); + auto rawValue = + fir::BoxAddrOp::create(builder, regionOp.getLoc(), hostType, value); + mapper.map(hostInput, rawValue); + } else { + assert(!llvm::isa<fir::BaseBoxType>(hostType) && + "compute region variable should not be raw address when host " + "hlfir.declare input was a box"); + assert(fir::isBoxAddress(hostType) == fir::isBoxAddress(computeType) && + "compute region variable should be a pointer/allocatable if and " + "only if host is"); + assert(fir::isa_ref_type(hostType) && fir::isa_ref_type(computeType) && + "compute region variable and host variable should both be raw " + "addresses"); + mlir::Value cast = + builder.createConvert(regionOp.getLoc(), hostType, value); + mapper.map(hostInput, cast); + } + if (mlir::Value dummyScope = hostDeclare.getDummyScope()) { + // Copy the dummy scope into the region so that aliasing rules about + // Fortran dummies are understood inside the region and the abstract dummy + // scope type does not have to cross the OpenACC compute region boundary. + if (!mapper.contains(dummyScope)) { + mlir::Operation *hostDummyScopeOp = dummyScope.getDefiningOp(); + assert(hostDummyScopeOp && + "dummyScope defining operation must be visible in lowering"); + (void)builder.clone(*hostDummyScopeOp, mapper); + } + } + + mlir::Operation *computeDef = + builder.clone(*hostDeclare.getOperation(), mapper); + + // The input box already went through an hlfir.declare. It has the correct + // local lower bounds and attribute. Do not generate a new fir.rebox. + if (llvm::isa<fir::BaseBoxType>(hostDeclare.getMemref().getType())) + llvm::cast<hlfir::DeclareOp>(*computeDef).setSkipRebox(true); + + symbolMap.addVariableDefinition( + symbol, llvm::cast<fir::FortranVariableOpInterface>(computeDef)); + } +} + +static mlir::acc::LoopOp buildACCLoopOp( + Fortran::lower::AbstractConverter &converter, + mlir::Location currentLocation, + Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::lower::StatementContext &stmtCtx, + const Fortran::parser::DoConstruct &outerDoConstruct, + Fortran::lower::pft::Evaluation &eval, + llvm::SmallVector<mlir::Value> &privateOperands, + llvm::SmallVector<mlir::Attribute> &privatizationRecipes, + llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> + &dataOperandSymbolPairs, + llvm::SmallVector<mlir::Value> &gangOperands, + llvm::SmallVector<mlir::Value> &workerNumOperands, + llvm::SmallVector<mlir::Value> &vectorOperands, + llvm::SmallVector<mlir::Value> &tileOperands, + llvm::SmallVector<mlir::Value> &cacheOperands, + llvm::SmallVector<mlir::Value> &reductionOperands, + llvm::SmallVector<mlir::Type> &retTy, mlir::Value yieldValue, + uint64_t loopsToProcess) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); - llvm::SmallVector<mlir::Value> ivPrivate; + llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> + ivPrivate; llvm::SmallVector<mlir::Type> ivTypes; llvm::SmallVector<mlir::Location> ivLocs; llvm::SmallVector<bool> inclusiveBounds; @@ -2231,10 +2361,22 @@ buildACCLoopOp(Fortran::lower::AbstractConverter &converter, builder, builder.getFusedLoc(locs), currentLocation, eval, operands, operandSegments, /*outerCombined=*/false, retTy, yieldValue, ivTypes, ivLocs); - - for (auto [arg, value] : llvm::zip( - loopOp.getLoopRegions().front()->front().getArguments(), ivPrivate)) - fir::StoreOp::create(builder, currentLocation, arg, value); + // Ensure the iv symbol is mapped to private iv SSA value for the scope of + // the loop even if it did not appear explicitly in a PRIVATE clause (if it + // appeared explicitly in such clause, that is also fine because duplicates + // in the list are ignored). + dataOperandSymbolPairs.append(ivPrivate.begin(), ivPrivate.end()); + // Remap symbols from data clauses to use data operation results + remapDataOperandSymbols(converter, builder, loopOp, dataOperandSymbolPairs); + + for (auto [arg, iv] : + llvm::zip(loopOp.getLoopRegions().front()->front().getArguments(), + ivPrivate)) { + // Store block argument to the related iv private variable. + mlir::Value privateValue = + converter.getSymbolAddress(std::get<Fortran::semantics::SymbolRef>(iv)); + fir::StoreOp::create(builder, currentLocation, arg, privateValue); + } loopOp.setInclusiveUpperbound(inclusiveBounds); @@ -2260,6 +2402,10 @@ static mlir::acc::LoopOp createLoopOp( llvm::SmallVector<int32_t> tileOperandsSegments, gangOperandsSegments; llvm::SmallVector<int64_t> collapseValues; + // Vector to track mlir::Value results and their corresponding Fortran symbols + llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> + dataOperandSymbolPairs; + llvm::SmallVector<mlir::Attribute> gangArgTypes; llvm::SmallVector<mlir::Attribute> seqDeviceTypes, independentDeviceTypes, autoDeviceTypes, vectorOperandsDeviceTypes, workerNumOperandsDeviceTypes, @@ -2380,7 +2526,8 @@ static mlir::acc::LoopOp createLoopOp( genPrivatizationRecipes<mlir::acc::PrivateRecipeOp>( privateClause->v, converter, semanticsContext, stmtCtx, privateOperands, privatizationRecipes, /*async=*/{}, - /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); + /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}, + &dataOperandSymbolPairs); } else if (const auto *reductionClause = std::get_if<Fortran::parser::AccClause::Reduction>( &clause.u)) { @@ -2436,9 +2583,9 @@ static mlir::acc::LoopOp createLoopOp( Fortran::lower::getLoopCountForCollapseAndTile(accClauseList); auto loopOp = buildACCLoopOp( converter, currentLocation, semanticsContext, stmtCtx, outerDoConstruct, - eval, privateOperands, privatizationRecipes, gangOperands, - workerNumOperands, vectorOperands, tileOperands, cacheOperands, - reductionOperands, retTy, yieldValue, loopsToProcess); + eval, privateOperands, privatizationRecipes, dataOperandSymbolPairs, + gangOperands, workerNumOperands, vectorOperands, tileOperands, + cacheOperands, reductionOperands, retTy, yieldValue, loopsToProcess); if (!gangDeviceTypes.empty()) loopOp.setGangAttr(builder.getArrayAttr(gangDeviceTypes)); @@ -2568,7 +2715,9 @@ static void genDataOperandOperationsWithModifier( llvm::ArrayRef<mlir::Value> async, llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes, llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes, - bool setDeclareAttr = false) { + bool setDeclareAttr = false, + llvm::SmallVectorImpl<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> + *symbolPairs = nullptr) { const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v; const auto &accObjectList = std::get<Fortran::parser::AccObjectList>(listWithModifier.t); @@ -2581,7 +2730,7 @@ static void genDataOperandOperationsWithModifier( stmtCtx, dataClauseOperands, dataClause, /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes, asyncOnlyDeviceTypes, - setDeclareAttr); + setDeclareAttr, symbolPairs); } template <typename Op> @@ -2612,6 +2761,10 @@ static Op createComputeOp( llvm::SmallVector<mlir::Attribute> privatizationRecipes, firstPrivatizationRecipes, reductionRecipes; + // Vector to track mlir::Value results and their corresponding Fortran symbols + llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> + dataOperandSymbolPairs; + // Self clause has optional values but can be present with // no value as well. When there is no value, the op has an attribute to // represent the clause. @@ -2732,7 +2885,8 @@ static Op createComputeOp( copyClause->v, converter, semanticsContext, stmtCtx, dataClauseOperands, mlir::acc::DataClause::acc_copy, /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes, - asyncOnlyDeviceTypes); + asyncOnlyDeviceTypes, /*setDeclareAttr=*/false, + &dataOperandSymbolPairs); copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart, dataClauseOperands.end()); } else if (const auto *copyinClause = @@ -2744,7 +2898,8 @@ static Op createComputeOp( Fortran::parser::AccDataModifier::Modifier::ReadOnly, dataClauseOperands, mlir::acc::DataClause::acc_copyin, mlir::acc::DataClause::acc_copyin_readonly, async, asyncDeviceTypes, - asyncOnlyDeviceTypes); + asyncOnlyDeviceTypes, /*setDeclareAttr=*/false, + &dataOperandSymbolPairs); copyinEntryOperands.append(dataClauseOperands.begin() + crtDataStart, dataClauseOperands.end()); } else if (const auto *copyoutClause = @@ -2757,7 +2912,8 @@ static Op createComputeOp( Fortran::parser::AccDataModifier::Modifier::ReadOnly, dataClauseOperands, mlir::acc::DataClause::acc_copyout, mlir::acc::DataClause::acc_copyout_zero, async, asyncDeviceTypes, - asyncOnlyDeviceTypes); + asyncOnlyDeviceTypes, /*setDeclareAttr=*/false, + &dataOperandSymbolPairs); copyoutEntryOperands.append(dataClauseOperands.begin() + crtDataStart, dataClauseOperands.end()); } else if (const auto *createClause = @@ -2769,7 +2925,8 @@ static Op createComputeOp( Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands, mlir::acc::DataClause::acc_create, mlir::acc::DataClause::acc_create_zero, async, asyncDeviceTypes, - asyncOnlyDeviceTypes); + asyncOnlyDeviceTypes, /*setDeclareAttr=*/false, + &dataOperandSymbolPairs); createEntryOperands.append(dataClauseOperands.begin() + crtDataStart, dataClauseOperands.end()); } else if (const auto *noCreateClause = @@ -2780,7 +2937,8 @@ static Op createComputeOp( noCreateClause->v, converter, semanticsContext, stmtCtx, dataClauseOperands, mlir::acc::DataClause::acc_no_create, /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes, - asyncOnlyDeviceTypes); + asyncOnlyDeviceTypes, /*setDeclareAttr=*/false, + &dataOperandSymbolPairs); nocreateEntryOperands.append(dataClauseOperands.begin() + crtDataStart, dataClauseOperands.end()); } else if (const auto *presentClause = @@ -2791,17 +2949,21 @@ static Op createComputeOp( presentClause->v, converter, semanticsContext, stmtCtx, dataClauseOperands, mlir::acc::DataClause::acc_present, /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes, - asyncOnlyDeviceTypes); + asyncOnlyDeviceTypes, /*setDeclareAttr=*/false, + &dataOperandSymbolPairs); presentEntryOperands.append(dataClauseOperands.begin() + crtDataStart, dataClauseOperands.end()); } else if (const auto *devicePtrClause = std::get_if<Fortran::parser::AccClause::Deviceptr>( &clause.u)) { + llvm::SmallVectorImpl< + std::pair<mlir::Value, Fortran::semantics::SymbolRef>> *symPairs = + enableDevicePtrRemap ? &dataOperandSymbolPairs : nullptr; genDataOperandOperations<mlir::acc::DevicePtrOp>( devicePtrClause->v, converter, semanticsContext, stmtCtx, dataClauseOperands, mlir::acc::DataClause::acc_deviceptr, /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes, - asyncOnlyDeviceTypes); + asyncOnlyDeviceTypes, /*setDeclareAttr=*/false, symPairs); } else if (const auto *attachClause = std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) { auto crtDataStart = dataClauseOperands.size(); @@ -2809,7 +2971,8 @@ static Op createComputeOp( attachClause->v, converter, semanticsContext, stmtCtx, dataClauseOperands, mlir::acc::DataClause::acc_attach, /*structured=*/true, /*implicit=*/false, async, asyncDeviceTypes, - asyncOnlyDeviceTypes); + asyncOnlyDeviceTypes, /*setDeclareAttr=*/false, + &dataOperandSymbolPairs); attachEntryOperands.append(dataClauseOperands.begin() + crtDataStart, dataClauseOperands.end()); } else if (const auto *privateClause = @@ -2819,14 +2982,14 @@ static Op createComputeOp( genPrivatizationRecipes<mlir::acc::PrivateRecipeOp>( privateClause->v, converter, semanticsContext, stmtCtx, privateOperands, privatizationRecipes, async, asyncDeviceTypes, - asyncOnlyDeviceTypes); + asyncOnlyDeviceTypes, &dataOperandSymbolPairs); } else if (const auto *firstprivateClause = std::get_if<Fortran::parser::AccClause::Firstprivate>( &clause.u)) { genPrivatizationRecipes<mlir::acc::FirstprivateRecipeOp>( firstprivateClause->v, converter, semanticsContext, stmtCtx, firstprivateOperands, firstPrivatizationRecipes, async, - asyncDeviceTypes, asyncOnlyDeviceTypes); + asyncDeviceTypes, asyncOnlyDeviceTypes, &dataOperandSymbolPairs); } else if (const auto *reductionClause = std::get_if<Fortran::parser::AccClause::Reduction>( &clause.u)) { @@ -2846,7 +3009,8 @@ static Op createComputeOp( converter, semanticsContext, stmtCtx, dataClauseOperands, mlir::acc::DataClause::acc_reduction, /*structured=*/true, /*implicit=*/true, async, asyncDeviceTypes, - asyncOnlyDeviceTypes); + asyncOnlyDeviceTypes, /*setDeclareAttr=*/false, + &dataOperandSymbolPairs); copyEntryOperands.append(dataClauseOperands.begin() + crtDataStart, dataClauseOperands.end()); } @@ -2945,6 +3109,11 @@ static Op createComputeOp( computeOp.setCombinedAttr(builder.getUnitAttr()); auto insPt = builder.saveInsertionPoint(); + + // Remap symbols from data clauses to use data operation results + remapDataOperandSymbols(converter, builder, computeOp, + dataOperandSymbolPairs); + builder.setInsertionPointAfter(computeOp); // Create the exit operations after the region. @@ -4921,6 +5090,8 @@ mlir::Operation *Fortran::lower::genOpenACCLoopFromDoConstruct( reductionOperands; llvm::SmallVector<mlir::Attribute> privatizationRecipes; llvm::SmallVector<mlir::Type> retTy; + llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> + dataOperandSymbolPairs; mlir::Value yieldValue; uint64_t loopsToProcess = 1; // Single loop construct @@ -4929,9 +5100,10 @@ mlir::Operation *Fortran::lower::genOpenACCLoopFromDoConstruct( Fortran::lower::StatementContext stmtCtx; auto loopOp = buildACCLoopOp( converter, converter.getCurrentLocation(), semanticsContext, stmtCtx, - doConstruct, eval, privateOperands, privatizationRecipes, gangOperands, - workerNumOperands, vectorOperands, tileOperands, cacheOperands, - reductionOperands, retTy, yieldValue, loopsToProcess); + doConstruct, eval, privateOperands, privatizationRecipes, + dataOperandSymbolPairs, gangOperands, workerNumOperands, vectorOperands, + tileOperands, cacheOperands, reductionOperands, retTy, yieldValue, + loopsToProcess); fir::FirOpBuilder &builder = converter.getFirOpBuilder(); if (!privatizationRecipes.empty()) |