diff options
| author | ergawy <kareem.ergawy@amd.com> | 2026-02-05 06:39:34 -0600 |
|---|---|---|
| committer | ergawy <kareem.ergawy@amd.com> | 2026-02-05 06:42:29 -0600 |
| commit | 195a07d01b63c228dd07e0534322da4983094521 (patch) | |
| tree | 32dfa3fca60dd2a4f4de5f7298d12538336d14e8 | |
| parent | d1598c96e06ad8c21ff2d949d9c46e564f1b7191 (diff) | |
| download | llvm-users/ergawy/implicit_maps_dc.tar.gz llvm-users/ergawy/implicit_maps_dc.tar.bz2 llvm-users/ergawy/implicit_maps_dc.zip | |
[flang][OpenMP][DoConcurrent] Emit declare mapper for recordsusers/ergawy/implicit_maps_dc
Extends `do concurrent` device support by emitting compiler-generated
declare mapper ops for live-ins whose types are record types and have
allocatable members.
| -rw-r--r-- | flang/include/flang/Utils/OpenMP.h | 9 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 12 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/OpenMP.cpp | 15 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/Utils.cpp | 107 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/Utils.h | 4 | ||||
| -rw-r--r-- | flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp | 37 | ||||
| -rw-r--r-- | flang/lib/Utils/OpenMP.cpp | 97 | ||||
| -rw-r--r-- | flang/test/Transforms/DoConcurrent/implicit_mapper.f90 | 28 |
8 files changed, 194 insertions, 115 deletions
diff --git a/flang/include/flang/Utils/OpenMP.h b/flang/include/flang/Utils/OpenMP.h index bad0abb6f578..e8627347fd57 100644 --- a/flang/include/flang/Utils/OpenMP.h +++ b/flang/include/flang/Utils/OpenMP.h @@ -13,6 +13,7 @@ namespace fir { class FirOpBuilder; +class RecordType; } // namespace fir namespace Fortran::utils::openmp { @@ -59,6 +60,14 @@ mlir::Value mapTemporaryValue(fir::FirOpBuilder &firOpBuilder, /// maps. void cloneOrMapRegionOutsiders( fir::FirOpBuilder &firOpBuilder, mlir::omp::TargetOp targetOp); + +using RecordMemberMapperMangler = + std::function<void(std::string &mapperId, llvm::StringRef memberName)>; + +mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper( + fir::FirOpBuilder &firOpBuilder, mlir::Location loc, + fir::RecordType recordType, llvm::StringRef mapperNameStr, + RecordMemberMapperMangler mangler = {}); } // namespace Fortran::utils::openmp #endif // FORTRAN_UTILS_OPENMP_H_ diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index b1973a3b8bf0..8edb7cd5d193 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -1352,8 +1352,16 @@ void ClauseProcessor::processMapObjects( if (!recordType) return mlir::FlatSymbolRefAttr(); - return getOrGenImplicitDefaultDeclareMapper(converter, clauseLocation, - recordType, mapperIdName); + return utils::openmp::getOrGenImplicitDefaultDeclareMapper( + converter.getFirOpBuilder(), clauseLocation, recordType, mapperIdName, + [&](std::string &mapperIdName, llvm::StringRef memberName) { + if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName)) + mapperIdName = converter.mangleName(mapperIdName, sym->owner()); + else if (auto *memberSym = + converter.getCurrentScope().FindSymbol(memberName.str())) + mapperIdName = + converter.mangleName(mapperIdName, memberSym->owner()); + }); }; auto getDefaultMapperID = diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index df89cbe46a5c..17d466c6340d 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -2792,7 +2792,20 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, if (auto recordType = mlir::dyn_cast_or_null<fir::RecordType>( converter.genType(*typeSpec))) mapperId = getOrGenImplicitDefaultDeclareMapper( - converter, loc, recordType, mapperIdName); + converter.getFirOpBuilder(), loc, recordType, + mapperIdName, + [&](std::string &mapperIdName, + llvm::StringRef memberName) { + if (auto *sym = converter.getCurrentScope().FindSymbol( + mapperIdName)) + mapperIdName = + converter.mangleName(mapperIdName, sym->owner()); + else if (auto *memberSym = + converter.getCurrentScope().FindSymbol( + memberName.str())) + mapperIdName = converter.mangleName( + mapperIdName, memberSym->owner()); + }); } else { mapperId = mlir::FlatSymbolRefAttr::get( &converter.getMLIRContext(), mapperIdName); diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index dce858085666..1bf5117168a7 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -67,113 +67,6 @@ llvm::cl::opt<bool> treatIndexAsSection( namespace Fortran { namespace lower { namespace omp { - -mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper( - lower::AbstractConverter &converter, mlir::Location loc, - fir::RecordType recordType, llvm::StringRef mapperNameStr) { - if (mapperNameStr.empty()) - return {}; - - if (converter.getModuleOp().lookupSymbol(mapperNameStr)) - return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), - mapperNameStr); - - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::OpBuilder::InsertionGuard guard(firOpBuilder); - - firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody()); - auto declMapperOp = mlir::omp::DeclareMapperOp::create( - firOpBuilder, loc, mapperNameStr, recordType); - auto ®ion = declMapperOp.getRegion(); - firOpBuilder.createBlock(®ion); - auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc); - - auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg, - /*uniq_name=*/""); - - const auto genBoundsOps = [&](mlir::Value mapVal, - llvm::SmallVectorImpl<mlir::Value> &bounds) { - fir::ExtendedValue extVal = - hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder, - hlfir::Entity{mapVal}, - /*contiguousHint=*/true) - .first; - fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr( - firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc()); - bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp, - mlir::omp::MapBoundsType>( - firOpBuilder, info, extVal, - /*dataExvIsAssumedSize=*/false, mapVal.getLoc()); - }; - - const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName, - mlir::Type fieldTy, mlir::Type recType) { - mlir::Value field = fir::FieldIndexOp::create( - firOpBuilder, loc, fir::FieldType::get(recType.getContext()), fieldName, - recType, fir::getTypeParams(rec)); - return fir::CoordinateOp::create( - firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field); - }; - - llvm::SmallVector<mlir::Value> clauseMapVars; - llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices; - llvm::SmallVector<mlir::Value> memberMapOps; - - mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to | - mlir::omp::ClauseMapFlags::from | - mlir::omp::ClauseMapFlags::implicit; - mlir::omp::VariableCaptureKind captureKind = - mlir::omp::VariableCaptureKind::ByRef; - - for (const auto &entry : llvm::enumerate(recordType.getTypeList())) { - const auto &memberName = entry.value().first; - const auto &memberType = entry.value().second; - mlir::FlatSymbolRefAttr mapperId; - if (auto recType = mlir::dyn_cast<fir::RecordType>( - fir::getFortranElementType(memberType))) { - std::string mapperIdName = - recType.getName().str() + llvm::omp::OmpDefaultMapperName; - if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName)) - mapperIdName = converter.mangleName(mapperIdName, sym->owner()); - else if (auto *memberSym = - converter.getCurrentScope().FindSymbol(memberName)) - mapperIdName = converter.mangleName(mapperIdName, memberSym->owner()); - - mapperId = getOrGenImplicitDefaultDeclareMapper(converter, loc, recType, - mapperIdName); - } - - auto ref = - getFieldRef(declareOp.getBase(), memberName, memberType, recordType); - llvm::SmallVector<mlir::Value> bounds; - genBoundsOps(ref, bounds); - mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp( - firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"", - bounds, - /*members=*/{}, - /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(), - /*partialMap=*/false, mapperId); - memberMapOps.emplace_back(mapOp); - memberPlacementIndices.emplace_back( - llvm::SmallVector<int64_t>{(int64_t)entry.index()}); - } - - llvm::SmallVector<mlir::Value> bounds; - genBoundsOps(declareOp.getOriginalBase(), bounds); - mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit; - mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp( - firOpBuilder, loc, declareOp.getOriginalBase(), - /*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps, - firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag, - captureKind, declareOp.getType(0), - /*partialMap=*/true); - - clauseMapVars.emplace_back(mapOp); - mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars); - return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), - mapperNameStr); -} - bool requiresImplicitDefaultDeclareMapper( const semantics::DerivedTypeSpec &typeSpec) { // ISO C interoperable types (e.g., c_ptr, c_funptr) must always have implicit diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h index 8a68ff8bd3bd..0b651572e087 100644 --- a/flang/lib/Lower/OpenMP/Utils.h +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -137,10 +137,6 @@ mlir::Value createParentSymAndGenIntermediateMaps( OmpMapParentAndMemberData &parentMemberIndices, llvm::StringRef asFortran, mlir::omp::ClauseMapFlags mapTypeBits); -mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper( - Fortran::lower::AbstractConverter &converter, mlir::Location loc, - fir::RecordType recordType, llvm::StringRef mapperNameStr); - bool requiresImplicitDefaultDeclareMapper( const semantics::DerivedTypeSpec &typeSpec); diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp index ff346e79276c..ce3659660290 100644 --- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp @@ -22,6 +22,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Frontend/OpenMP/OMPConstants.h" namespace flangomp { #define GEN_PASS_DEF_DOCONCURRENTCONVERSIONPASS @@ -583,12 +584,46 @@ private: llvm::SmallVector<mlir::Value> boundsOps; genBoundsOps(builder, liveIn, rawAddr, boundsOps); + auto asRecordType = [&](mlir::Type eleType) { + fir::RecordType recordType = mlir::dyn_cast<fir::RecordType>(eleType); + + if (auto seqType = mlir::dyn_cast<fir::SequenceType>(eleType)) + recordType = mlir::dyn_cast<fir::RecordType>(seqType.getElementType()); + + return recordType; + }; + + fir::RecordType recordType = asRecordType(eleType); + + bool requiresImplcitMapper = [&]() { + if (!recordType) + return false; + + for (auto [fieldName, fieldType] : recordType.getTypeList()) { + if (fir::isAllocatableType(fieldType)) + return true; + + if (asRecordType(fieldType)) + TODO(liveIn.getLoc(), "Nested record types are not supported yet."); + } + + return false; + }(); + + mlir::FlatSymbolRefAttr mapperId; + if (requiresImplcitMapper) { + std::string mapperIdName = + recordType.getName().str() + llvm::omp::OmpDefaultMapperName; + mapperId = Fortran::utils::openmp::getOrGenImplicitDefaultDeclareMapper( + builder, liveIn.getLoc(), recordType, mapperIdName); + } + return Fortran::utils::openmp::createMapInfoOp( builder, liveIn.getLoc(), rawAddr, /*varPtrPtr=*/{}, name.str(), boundsOps, /*members=*/{}, /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, - rawAddr.getType()); + rawAddr.getType(), /*partialMap=*/false, mapperId); } mlir::omp::TargetOp diff --git a/flang/lib/Utils/OpenMP.cpp b/flang/lib/Utils/OpenMP.cpp index c2036c4a383f..a0a67249f7f0 100644 --- a/flang/lib/Utils/OpenMP.cpp +++ b/flang/lib/Utils/OpenMP.cpp @@ -155,4 +155,101 @@ void cloneOrMapRegionOutsiders( mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove); } } + +mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper( + fir::FirOpBuilder &firOpBuilder, mlir::Location loc, + fir::RecordType recordType, llvm::StringRef mapperNameStr, + RecordMemberMapperMangler mangler) { + if (mapperNameStr.empty()) + return {}; + + mlir::ModuleOp moduleOp = firOpBuilder.getModule(); + if (moduleOp.lookupSymbol(mapperNameStr)) + return mlir::FlatSymbolRefAttr::get( + firOpBuilder.getContext(), mapperNameStr); + + mlir::OpBuilder::InsertionGuard guard(firOpBuilder); + + firOpBuilder.setInsertionPointToStart(moduleOp.getBody()); + auto declMapperOp = mlir::omp::DeclareMapperOp::create( + firOpBuilder, loc, mapperNameStr, recordType); + auto ®ion = declMapperOp.getRegion(); + firOpBuilder.createBlock(®ion); + auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc); + + auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg, + /*uniq_name=*/""); + + const auto genBoundsOps = [&](mlir::Value mapVal, + llvm::SmallVectorImpl<mlir::Value> &bounds) { + fir::ExtendedValue extVal = hlfir::translateToExtendedValue(mapVal.getLoc(), + firOpBuilder, hlfir::Entity{mapVal}, + /*contiguousHint=*/true) + .first; + fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr( + firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc()); + bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp, + mlir::omp::MapBoundsType>(firOpBuilder, info, extVal, + /*dataExvIsAssumedSize=*/false, mapVal.getLoc()); + }; + + const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName, + mlir::Type fieldTy, mlir::Type recType) { + mlir::Value field = fir::FieldIndexOp::create(firOpBuilder, loc, + fir::FieldType::get(recType.getContext()), fieldName, recType, + fir::getTypeParams(rec)); + return fir::CoordinateOp::create( + firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field); + }; + + llvm::SmallVector<mlir::Value> clauseMapVars; + llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices; + llvm::SmallVector<mlir::Value> memberMapOps; + + mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to | + mlir::omp::ClauseMapFlags::from | mlir::omp::ClauseMapFlags::implicit; + mlir::omp::VariableCaptureKind captureKind = + mlir::omp::VariableCaptureKind::ByRef; + + for (const auto &entry : llvm::enumerate(recordType.getTypeList())) { + const auto &memberName = entry.value().first; + const auto &memberType = entry.value().second; + mlir::FlatSymbolRefAttr mapperId; + if (auto recType = mlir::dyn_cast<fir::RecordType>( + fir::getFortranElementType(memberType))) { + std::string mapperIdName = + recType.getName().str() + llvm::omp::OmpDefaultMapperName; + mangler(mapperIdName, memberName); + mapperId = getOrGenImplicitDefaultDeclareMapper( + firOpBuilder, loc, recType, mapperIdName, mangler); + } + + auto ref = + getFieldRef(declareOp.getBase(), memberName, memberType, recordType); + llvm::SmallVector<mlir::Value> bounds; + genBoundsOps(ref, bounds); + mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp(firOpBuilder, + loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"", bounds, + /*members=*/{}, + /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(), + /*partialMap=*/false, mapperId); + memberMapOps.emplace_back(mapOp); + memberPlacementIndices.emplace_back( + llvm::SmallVector<int64_t>{(int64_t)entry.index()}); + } + + llvm::SmallVector<mlir::Value> bounds; + genBoundsOps(declareOp.getOriginalBase(), bounds); + mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit; + mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp( + firOpBuilder, loc, declareOp.getOriginalBase(), + /*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps, + firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag, + captureKind, declareOp.getType(0), + /*partialMap=*/true); + + clauseMapVars.emplace_back(mapOp); + mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars); + return mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), mapperNameStr); +} } // namespace Fortran::utils::openmp diff --git a/flang/test/Transforms/DoConcurrent/implicit_mapper.f90 b/flang/test/Transforms/DoConcurrent/implicit_mapper.f90 new file mode 100644 index 000000000000..f77c0cd0320c --- /dev/null +++ b/flang/test/Transforms/DoConcurrent/implicit_mapper.f90 @@ -0,0 +1,28 @@ +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-to-openmp=device %s -o - \ +! RUN: | FileCheck %s + +module record_with_alloc_mod + implicit none + public :: record_with_alloc + + type record_with_alloc + real, allocatable :: values_(:) + end type +end module record_with_alloc_mod + +subroutine random_inputs() + use record_with_alloc_mod, only : record_with_alloc + implicit none + type(record_with_alloc) :: inputs(2) + integer :: i + + do concurrent(i=1:10) + inputs(1)%values_ = [1,2,3,4] + end do +end subroutine + +! CHECK: omp.declare_mapper @[[MAPPER_NAME:.*record_with_alloc_omp_default_mapper]] : !fir.type<{{.*}}record_with_alloc{{.*}}> + +! CHECK: func.func @{{.*}}random_inputs() +! CHECK: %[[ARR_DECL:.*]]:2 = hlfir.declare {{.*}} {uniq_name = "{{.*}}inputs"} +! CHECK: omp.map.info var_ptr(%[[ARR_DECL]]#1 : {{.*}}) {{.*}} mapper(@[[MAPPER_NAME]]) |
