aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Utils/OpenMP.cpp
diff options
context:
space:
mode:
authorergawy <kareem.ergawy@amd.com>2026-02-05 06:39:34 -0600
committerergawy <kareem.ergawy@amd.com>2026-02-05 06:42:29 -0600
commit195a07d01b63c228dd07e0534322da4983094521 (patch)
tree32dfa3fca60dd2a4f4de5f7298d12538336d14e8 /flang/lib/Utils/OpenMP.cpp
parentd1598c96e06ad8c21ff2d949d9c46e564f1b7191 (diff)
downloadllvm-users/ergawy/implicit_maps_dc.tar.gz
llvm-users/ergawy/implicit_maps_dc.tar.bz2
llvm-users/ergawy/implicit_maps_dc.zip
[flang][OpenMP][DoConcurrent] Emit declare mapper for recordsusers/ergawy/implicit_maps_dc
Extends `do concurrent` device support by emitting compiler-generated declare mapper ops for live-ins whose types are record types and have allocatable members.
Diffstat (limited to 'flang/lib/Utils/OpenMP.cpp')
-rw-r--r--flang/lib/Utils/OpenMP.cpp97
1 files changed, 97 insertions, 0 deletions
diff --git a/flang/lib/Utils/OpenMP.cpp b/flang/lib/Utils/OpenMP.cpp
index c2036c4a383f..a0a67249f7f0 100644
--- a/flang/lib/Utils/OpenMP.cpp
+++ b/flang/lib/Utils/OpenMP.cpp
@@ -155,4 +155,101 @@ void cloneOrMapRegionOutsiders(
mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
}
}
+
+mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
+ fir::FirOpBuilder &firOpBuilder, mlir::Location loc,
+ fir::RecordType recordType, llvm::StringRef mapperNameStr,
+ RecordMemberMapperMangler mangler) {
+ if (mapperNameStr.empty())
+ return {};
+
+ mlir::ModuleOp moduleOp = firOpBuilder.getModule();
+ if (moduleOp.lookupSymbol(mapperNameStr))
+ return mlir::FlatSymbolRefAttr::get(
+ firOpBuilder.getContext(), mapperNameStr);
+
+ mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+
+ firOpBuilder.setInsertionPointToStart(moduleOp.getBody());
+ auto declMapperOp = mlir::omp::DeclareMapperOp::create(
+ firOpBuilder, loc, mapperNameStr, recordType);
+ auto &region = declMapperOp.getRegion();
+ firOpBuilder.createBlock(&region);
+ auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);
+
+ auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg,
+ /*uniq_name=*/"");
+
+ const auto genBoundsOps = [&](mlir::Value mapVal,
+ llvm::SmallVectorImpl<mlir::Value> &bounds) {
+ fir::ExtendedValue extVal = hlfir::translateToExtendedValue(mapVal.getLoc(),
+ firOpBuilder, hlfir::Entity{mapVal},
+ /*contiguousHint=*/true)
+ .first;
+ fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
+ firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
+ bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+ mlir::omp::MapBoundsType>(firOpBuilder, info, extVal,
+ /*dataExvIsAssumedSize=*/false, mapVal.getLoc());
+ };
+
+ const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName,
+ mlir::Type fieldTy, mlir::Type recType) {
+ mlir::Value field = fir::FieldIndexOp::create(firOpBuilder, loc,
+ fir::FieldType::get(recType.getContext()), fieldName, recType,
+ fir::getTypeParams(rec));
+ return fir::CoordinateOp::create(
+ firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field);
+ };
+
+ llvm::SmallVector<mlir::Value> clauseMapVars;
+ llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
+ llvm::SmallVector<mlir::Value> memberMapOps;
+
+ mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to |
+ mlir::omp::ClauseMapFlags::from | mlir::omp::ClauseMapFlags::implicit;
+ mlir::omp::VariableCaptureKind captureKind =
+ mlir::omp::VariableCaptureKind::ByRef;
+
+ for (const auto &entry : llvm::enumerate(recordType.getTypeList())) {
+ const auto &memberName = entry.value().first;
+ const auto &memberType = entry.value().second;
+ mlir::FlatSymbolRefAttr mapperId;
+ if (auto recType = mlir::dyn_cast<fir::RecordType>(
+ fir::getFortranElementType(memberType))) {
+ std::string mapperIdName =
+ recType.getName().str() + llvm::omp::OmpDefaultMapperName;
+ mangler(mapperIdName, memberName);
+ mapperId = getOrGenImplicitDefaultDeclareMapper(
+ firOpBuilder, loc, recType, mapperIdName, mangler);
+ }
+
+ auto ref =
+ getFieldRef(declareOp.getBase(), memberName, memberType, recordType);
+ llvm::SmallVector<mlir::Value> bounds;
+ genBoundsOps(ref, bounds);
+ mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp(firOpBuilder,
+ loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"", bounds,
+ /*members=*/{},
+ /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(),
+ /*partialMap=*/false, mapperId);
+ memberMapOps.emplace_back(mapOp);
+ memberPlacementIndices.emplace_back(
+ llvm::SmallVector<int64_t>{(int64_t)entry.index()});
+ }
+
+ llvm::SmallVector<mlir::Value> bounds;
+ genBoundsOps(declareOp.getOriginalBase(), bounds);
+ mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit;
+ mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp(
+ firOpBuilder, loc, declareOp.getOriginalBase(),
+ /*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
+ firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag,
+ captureKind, declareOp.getType(0),
+ /*partialMap=*/true);
+
+ clauseMapVars.emplace_back(mapOp);
+ mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars);
+ return mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), mapperNameStr);
+}
} // namespace Fortran::utils::openmp