aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Lower/OpenACC.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Lower/OpenACC.cpp')
-rw-r--r--flang/lib/Lower/OpenACC.cpp55
1 files changed, 47 insertions, 8 deletions
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index af4f420..1fc59c7 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2366,6 +2366,23 @@ static void processDoLoopBounds(
}
}
+static void remapCommonBlockMember(
+ Fortran::lower::AbstractConverter &converter, mlir::Location loc,
+ const Fortran::semantics::Symbol &member,
+ mlir::Value newCommonBlockBaseAddress,
+ const Fortran::semantics::Symbol &commonBlockSymbol,
+ llvm::SmallPtrSetImpl<const Fortran::semantics::Symbol *> &seenSymbols) {
+ if (seenSymbols.contains(&member))
+ return;
+ mlir::Value accMemberValue = Fortran::lower::genCommonBlockMember(
+ converter, loc, member, newCommonBlockBaseAddress,
+ commonBlockSymbol.size());
+ fir::ExtendedValue hostExv = converter.getSymbolExtendedValue(member);
+ fir::ExtendedValue accExv = fir::substBase(hostExv, accMemberValue);
+ converter.bindSymbol(member, accExv);
+ seenSymbols.insert(&member);
+}
+
/// Remap symbols that appeared in OpenACC data clauses to use the results of
/// the corresponding data operations. This allows isolating symbol accesses
/// inside the OpenACC region from accesses in the host and other regions while
@@ -2391,14 +2408,39 @@ static void remapDataOperandSymbols(
builder.setInsertionPointToStart(&regionOp.getRegion().front());
llvm::SmallPtrSet<const Fortran::semantics::Symbol *, 8> seenSymbols;
mlir::IRMapping mapper;
+ mlir::Location loc = regionOp.getLoc();
for (auto [value, symbol] : dataOperandSymbolPairs) {
-
- // If A symbol appears on several data clause, just map it to the first
+ // If a symbol appears on several data clause, just map it to the first
// result (all data operations results for a symbol are pointing same
// memory, so it does not matter which one is used).
if (seenSymbols.contains(&symbol.get()))
continue;
seenSymbols.insert(&symbol.get());
+ // When a common block appears in a directive, remap its members.
+ // Note: this will instantiate all common block members even if they are not
+ // used inside the region. If hlfir.declare DCE is not made possible, this
+ // could be improved to reduce IR noise.
+ if (const auto *commonBlock = symbol->template detailsIf<
+ Fortran::semantics::CommonBlockDetails>()) {
+ const Fortran::semantics::Scope &commonScope = symbol->owner();
+ if (commonScope.equivalenceSets().empty()) {
+ for (auto member : commonBlock->objects())
+ remapCommonBlockMember(converter, loc, *member, value, *symbol,
+ seenSymbols);
+ } else {
+ // Objects equivalenced with common block members still belong to the
+ // common block storage even if they are not part of the common block
+ // declaration. The easiest and most robust way to find all symbols
+ // belonging to the common block is to loop through the scope symbols
+ // and check if they belong to the common.
+ for (const auto &scopeSymbol : commonScope)
+ if (Fortran::semantics::FindCommonBlockContaining(
+ *scopeSymbol.second) == &symbol.get())
+ remapCommonBlockMember(converter, loc, *scopeSymbol.second, value,
+ *symbol, seenSymbols);
+ }
+ continue;
+ }
std::optional<fir::FortranVariableOpInterface> hostDef =
symbolMap.lookupVariableDefinition(symbol);
assert(hostDef.has_value() && llvm::isa<hlfir::DeclareOp>(*hostDef) &&
@@ -2415,10 +2457,8 @@ static void remapDataOperandSymbols(
"box type mismatch between compute region variable and "
"hlfir.declare input unexpected");
if (Fortran::semantics::IsOptional(symbol))
- TODO(regionOp.getLoc(),
- "remapping OPTIONAL symbol in OpenACC compute region");
- auto rawValue =
- fir::BoxAddrOp::create(builder, regionOp.getLoc(), hostType, value);
+ TODO(loc, "remapping OPTIONAL symbol in OpenACC compute region");
+ auto rawValue = fir::BoxAddrOp::create(builder, loc, hostType, value);
mapper.map(hostInput, rawValue);
} else {
assert(!llvm::isa<fir::BaseBoxType>(hostType) &&
@@ -2430,8 +2470,7 @@ static void remapDataOperandSymbols(
assert(fir::isa_ref_type(hostType) && fir::isa_ref_type(computeType) &&
"compute region variable and host variable should both be raw "
"addresses");
- mlir::Value cast =
- builder.createConvert(regionOp.getLoc(), hostType, value);
+ mlir::Value cast = builder.createConvert(loc, hostType, value);
mapper.map(hostInput, cast);
}
if (mlir::Value dummyScope = hostDeclare.getDummyScope()) {