diff options
Diffstat (limited to 'flang/lib/Utils/OpenMP.cpp')
| -rw-r--r-- | flang/lib/Utils/OpenMP.cpp | 97 |
1 files changed, 97 insertions, 0 deletions
diff --git a/flang/lib/Utils/OpenMP.cpp b/flang/lib/Utils/OpenMP.cpp index c2036c4a383f..a0a67249f7f0 100644 --- a/flang/lib/Utils/OpenMP.cpp +++ b/flang/lib/Utils/OpenMP.cpp @@ -155,4 +155,101 @@ void cloneOrMapRegionOutsiders( mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove); } } + +mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper( + fir::FirOpBuilder &firOpBuilder, mlir::Location loc, + fir::RecordType recordType, llvm::StringRef mapperNameStr, + RecordMemberMapperMangler mangler) { + if (mapperNameStr.empty()) + return {}; + + mlir::ModuleOp moduleOp = firOpBuilder.getModule(); + if (moduleOp.lookupSymbol(mapperNameStr)) + return mlir::FlatSymbolRefAttr::get( + firOpBuilder.getContext(), mapperNameStr); + + mlir::OpBuilder::InsertionGuard guard(firOpBuilder); + + firOpBuilder.setInsertionPointToStart(moduleOp.getBody()); + auto declMapperOp = mlir::omp::DeclareMapperOp::create( + firOpBuilder, loc, mapperNameStr, recordType); + auto ®ion = declMapperOp.getRegion(); + firOpBuilder.createBlock(®ion); + auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc); + + auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg, + /*uniq_name=*/""); + + const auto genBoundsOps = [&](mlir::Value mapVal, + llvm::SmallVectorImpl<mlir::Value> &bounds) { + fir::ExtendedValue extVal = hlfir::translateToExtendedValue(mapVal.getLoc(), + firOpBuilder, hlfir::Entity{mapVal}, + /*contiguousHint=*/true) + .first; + fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr( + firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc()); + bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp, + mlir::omp::MapBoundsType>(firOpBuilder, info, extVal, + /*dataExvIsAssumedSize=*/false, mapVal.getLoc()); + }; + + const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName, + mlir::Type fieldTy, mlir::Type recType) { + mlir::Value field = fir::FieldIndexOp::create(firOpBuilder, loc, + fir::FieldType::get(recType.getContext()), fieldName, recType, + fir::getTypeParams(rec)); + return fir::CoordinateOp::create( + firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field); + }; + + llvm::SmallVector<mlir::Value> clauseMapVars; + llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices; + llvm::SmallVector<mlir::Value> memberMapOps; + + mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to | + mlir::omp::ClauseMapFlags::from | mlir::omp::ClauseMapFlags::implicit; + mlir::omp::VariableCaptureKind captureKind = + mlir::omp::VariableCaptureKind::ByRef; + + for (const auto &entry : llvm::enumerate(recordType.getTypeList())) { + const auto &memberName = entry.value().first; + const auto &memberType = entry.value().second; + mlir::FlatSymbolRefAttr mapperId; + if (auto recType = mlir::dyn_cast<fir::RecordType>( + fir::getFortranElementType(memberType))) { + std::string mapperIdName = + recType.getName().str() + llvm::omp::OmpDefaultMapperName; + mangler(mapperIdName, memberName); + mapperId = getOrGenImplicitDefaultDeclareMapper( + firOpBuilder, loc, recType, mapperIdName, mangler); + } + + auto ref = + getFieldRef(declareOp.getBase(), memberName, memberType, recordType); + llvm::SmallVector<mlir::Value> bounds; + genBoundsOps(ref, bounds); + mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp(firOpBuilder, + loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"", bounds, + /*members=*/{}, + /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(), + /*partialMap=*/false, mapperId); + memberMapOps.emplace_back(mapOp); + memberPlacementIndices.emplace_back( + llvm::SmallVector<int64_t>{(int64_t)entry.index()}); + } + + llvm::SmallVector<mlir::Value> bounds; + genBoundsOps(declareOp.getOriginalBase(), bounds); + mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit; + mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp( + firOpBuilder, loc, declareOp.getOriginalBase(), + /*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps, + firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag, + captureKind, declareOp.getType(0), + /*partialMap=*/true); + + clauseMapVars.emplace_back(mapOp); + mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars); + return mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), mapperNameStr); +} } // namespace Fortran::utils::openmp |
