diff options
Diffstat (limited to 'flang/lib/Optimizer')
-rw-r--r-- | flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp | 294 |
1 files changed, 220 insertions, 74 deletions
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp index 57be863..e595e61 100644 --- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp +++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp @@ -41,7 +41,9 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringSet.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Support/raw_ostream.h" #include <algorithm> #include <cstddef> #include <iterator> @@ -75,6 +77,112 @@ class MapInfoFinalizationPass /// | | std::map<mlir::Operation *, mlir::Value> localBoxAllocas; + /// Return true if the given path exists in a list of paths. + static bool + containsPath(const llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &paths, + llvm::ArrayRef<int64_t> path) { + return llvm::any_of(paths, [&](const llvm::SmallVector<int64_t> &p) { + return p.size() == path.size() && + std::equal(p.begin(), p.end(), path.begin()); + }); + } + + /// Return true if the given path is already present in + /// op.getMembersIndexAttr(). + static bool mappedIndexPathExists(mlir::omp::MapInfoOp op, + llvm::ArrayRef<int64_t> indexPath) { + if (mlir::ArrayAttr attr = op.getMembersIndexAttr()) { + for (mlir::Attribute list : attr) { + auto listAttr = mlir::cast<mlir::ArrayAttr>(list); + if (listAttr.size() != indexPath.size()) + continue; + bool allEq = true; + for (auto [i, val] : llvm::enumerate(listAttr)) { + if (mlir::cast<mlir::IntegerAttr>(val).getInt() != indexPath[i]) { + allEq = false; + break; + } + } + if (allEq) + return true; + } + } + return false; + } + + /// Build a compact string key for an index path for set-based + /// deduplication. Format: "N:v0,v1,..." where N is the length. + static void buildPathKey(llvm::ArrayRef<int64_t> path, + llvm::SmallString<64> &outKey) { + outKey.clear(); + llvm::raw_svector_ostream os(outKey); + os << path.size() << ':'; + for (size_t i = 0; i < path.size(); ++i) { + if (i) + os << ','; + os << path[i]; + } + } + + /// Create the member map for coordRef and append it (and its index + /// path) to the provided new* vectors, if it is not already present. + void appendMemberMapIfNew( + mlir::omp::MapInfoOp op, fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value coordRef, llvm::ArrayRef<int64_t> indexPath, + llvm::StringRef memberName, + llvm::SmallVectorImpl<mlir::Value> &newMapOpsForFields, + llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &newMemberIndexPaths) { + // Local de-dup within this op invocation. + if (containsPath(newMemberIndexPaths, indexPath)) + return; + // Global de-dup against already present member indices. + if (mappedIndexPathExists(op, indexPath)) + return; + + if (op.getMapperId()) { + mlir::omp::DeclareMapperOp symbol = + mlir::SymbolTable::lookupNearestSymbolFrom< + mlir::omp::DeclareMapperOp>(op, op.getMapperIdAttr()); + assert(symbol && "missing symbol for declare mapper identifier"); + mlir::omp::DeclareMapperInfoOp mapperInfo = symbol.getDeclareMapperInfo(); + // TODO: Probably a way to cache these keys in someway so we don't + // constantly go through the process of rebuilding them on every check, to + // save some cycles, but it can wait for a subsequent patch. + for (auto v : mapperInfo.getMapVars()) { + mlir::omp::MapInfoOp map = + mlir::cast<mlir::omp::MapInfoOp>(v.getDefiningOp()); + if (!map.getMembers().empty() && mappedIndexPathExists(map, indexPath)) + return; + } + } + + builder.setInsertionPoint(op); + fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr( + builder, coordRef, /*isOptional=*/false, loc); + llvm::SmallVector<mlir::Value> bounds = fir::factory::genImplicitBoundsOps< + mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>( + builder, info, + hlfir::translateToExtendedValue(loc, builder, hlfir::Entity{coordRef}) + .first, + /*dataExvIsAssumedSize=*/false, loc); + + mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create( + builder, loc, coordRef.getType(), coordRef, + mlir::TypeAttr::get(fir::unwrapRefType(coordRef.getType())), + op.getMapTypeAttr(), + builder.getAttr<mlir::omp::VariableCaptureKindAttr>( + mlir::omp::VariableCaptureKind::ByRef), + /*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{}, + /*members_index=*/mlir::ArrayAttr{}, bounds, + /*mapperId=*/mlir::FlatSymbolRefAttr(), + builder.getStringAttr(op.getNameAttr().strref() + "." + memberName + + ".implicit_map"), + /*partial_map=*/builder.getBoolAttr(false)); + + newMapOpsForFields.emplace_back(fieldMapOp); + newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end()); + } + /// getMemberUserList gathers all users of a particular MapInfoOp that are /// other MapInfoOp's and places them into the mapMemberUsers list, which /// records the map that the current argument MapInfoOp "op" is part of @@ -363,7 +471,7 @@ class MapInfoFinalizationPass mlir::ArrayAttr newMembersAttr; mlir::SmallVector<mlir::Value> newMembers; llvm::SmallVector<llvm::SmallVector<int64_t>> memberIndices; - bool IsHasDeviceAddr = isHasDeviceAddr(op, target); + bool isHasDeviceAddrFlag = isHasDeviceAddr(op, target); if (!mapMemberUsers.empty() || !op.getMembers().empty()) getMemberIndicesAsVectors( @@ -406,7 +514,7 @@ class MapInfoFinalizationPass mapUser.parent.getMembersMutable().assign(newMemberOps); mapUser.parent.setMembersIndexAttr( builder.create2DI64ArrayAttr(memberIndices)); - } else if (!IsHasDeviceAddr) { + } else if (!isHasDeviceAddrFlag) { auto baseAddr = genBaseAddrMap(descriptor, op.getBounds(), op.getMapType(), builder); newMembers.push_back(baseAddr); @@ -429,7 +537,7 @@ class MapInfoFinalizationPass // The contents of the descriptor (the base address in particular) will // remain unchanged though. uint64_t mapType = op.getMapType(); - if (IsHasDeviceAddr) { + if (isHasDeviceAddrFlag) { mapType |= llvm::to_underlying( llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS); } @@ -701,94 +809,134 @@ class MapInfoFinalizationPass auto recordType = mlir::cast<fir::RecordType>(underlyingType); llvm::SmallVector<mlir::Value> newMapOpsForFields; - llvm::SmallVector<int64_t> fieldIndicies; + llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndexPaths; + // 1) Handle direct top-level allocatable fields. for (auto fieldMemTyPair : recordType.getTypeList()) { auto &field = fieldMemTyPair.first; auto memTy = fieldMemTyPair.second; - bool shouldMapField = - llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) { - if (!fir::isAllocatableType(memTy)) - return false; - - auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp); - if (!designateOp) - return false; - - return designateOp.getComponent() && - designateOp.getComponent()->strref() == field; - }) != mapVarForwardSlice.end(); - - // TODO Handle recursive record types. Adapting - // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR - // entities might be helpful here. - - if (!shouldMapField) + if (!fir::isAllocatableType(memTy)) continue; - int32_t fieldIdx = recordType.getFieldIndex(field); - bool alreadyMapped = [&]() { - if (op.getMembersIndexAttr()) - for (auto indexList : op.getMembersIndexAttr()) { - auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList); - if (indexListAttr.size() == 1 && - mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() == - fieldIdx) - return true; - } - - return false; - }(); - - if (alreadyMapped) + bool referenced = llvm::any_of(mapVarForwardSlice, [&](auto *opv) { + auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv); + return designateOp && designateOp.getComponent() && + designateOp.getComponent()->strref() == field; + }); + if (!referenced) continue; + int32_t fieldIdx = recordType.getFieldIndex(field); builder.setInsertionPoint(op); fir::IntOrValue idxConst = mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx); auto fieldCoord = fir::CoordinateOp::create( builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(), llvm::SmallVector<fir::IntOrValue, 1>{idxConst}); - fir::factory::AddrAndBoundsInfo info = - fir::factory::getDataOperandBaseAddr( - builder, fieldCoord, /*isOptional=*/false, op.getLoc()); - llvm::SmallVector<mlir::Value> bounds = - fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp, - mlir::omp::MapBoundsType>( - builder, info, - hlfir::translateToExtendedValue(op.getLoc(), builder, - hlfir::Entity{fieldCoord}) - .first, - /*dataExvIsAssumedSize=*/false, op.getLoc()); - - mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create( - builder, op.getLoc(), fieldCoord.getResult().getType(), - fieldCoord.getResult(), - mlir::TypeAttr::get( - fir::unwrapRefType(fieldCoord.getResult().getType())), - op.getMapTypeAttr(), - builder.getAttr<mlir::omp::VariableCaptureKindAttr>( - mlir::omp::VariableCaptureKind::ByRef), - /*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{}, - /*members_index=*/mlir::ArrayAttr{}, bounds, - /*mapperId=*/mlir::FlatSymbolRefAttr(), - builder.getStringAttr(op.getNameAttr().strref() + "." + field + - ".implicit_map"), - /*partial_map=*/builder.getBoolAttr(false)); - newMapOpsForFields.emplace_back(fieldMapOp); - fieldIndicies.emplace_back(fieldIdx); + int64_t fieldIdx64 = static_cast<int64_t>(fieldIdx); + llvm::SmallVector<int64_t, 1> idxPath{fieldIdx64}; + appendMemberMapIfNew(op, builder, op.getLoc(), fieldCoord, idxPath, + field, newMapOpsForFields, newMemberIndexPaths); + } + + // Handle nested allocatable fields along any component chain + // referenced in the region via HLFIR designates. + llvm::SmallVector<llvm::SmallVector<int64_t>> seenIndexPaths; + for (mlir::Operation *sliceOp : mapVarForwardSlice) { + auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp); + if (!designateOp || !designateOp.getComponent()) + continue; + llvm::SmallVector<llvm::StringRef> compPathReversed; + compPathReversed.push_back(designateOp.getComponent()->strref()); + mlir::Value curBase = designateOp.getMemref(); + bool rootedAtMapArg = false; + while (true) { + if (auto parentDes = curBase.getDefiningOp<hlfir::DesignateOp>()) { + if (!parentDes.getComponent()) + break; + compPathReversed.push_back(parentDes.getComponent()->strref()); + curBase = parentDes.getMemref(); + continue; + } + if (auto decl = curBase.getDefiningOp<hlfir::DeclareOp>()) { + if (auto barg = + mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref())) + rootedAtMapArg = (barg == opBlockArg); + } else if (auto blockArg = + mlir::dyn_cast_or_null<mlir::BlockArgument>( + curBase)) { + rootedAtMapArg = (blockArg == opBlockArg); + } + break; + } + // Only process nested paths (2+ components). Single-component paths + // for direct fields are handled above. + if (!rootedAtMapArg || compPathReversed.size() < 2) + continue; + builder.setInsertionPoint(op); + llvm::SmallVector<int64_t> indexPath; + mlir::Type curTy = underlyingType; + mlir::Value coordRef = op.getVarPtr(); + bool validPath = true; + for (llvm::StringRef compName : llvm::reverse(compPathReversed)) { + auto recTy = mlir::dyn_cast<fir::RecordType>(curTy); + if (!recTy) { + validPath = false; + break; + } + int32_t idx = recTy.getFieldIndex(compName); + if (idx < 0) { + validPath = false; + break; + } + indexPath.push_back(idx); + mlir::Type memTy = recTy.getType(idx); + fir::IntOrValue idxConst = + mlir::IntegerAttr::get(builder.getI32Type(), idx); + coordRef = fir::CoordinateOp::create( + builder, op.getLoc(), builder.getRefType(memTy), coordRef, + llvm::SmallVector<fir::IntOrValue, 1>{idxConst}); + curTy = memTy; + } + if (!validPath) + continue; + if (auto finalRefTy = + mlir::dyn_cast<fir::ReferenceType>(coordRef.getType())) { + mlir::Type eleTy = finalRefTy.getElementType(); + if (fir::isAllocatableType(eleTy)) { + if (!containsPath(seenIndexPaths, indexPath)) { + seenIndexPaths.emplace_back(indexPath.begin(), indexPath.end()); + appendMemberMapIfNew(op, builder, op.getLoc(), coordRef, + indexPath, compPathReversed.front(), + newMapOpsForFields, newMemberIndexPaths); + } + } + } } if (newMapOpsForFields.empty()) return mlir::WalkResult::advance(); - op.getMembersMutable().append(newMapOpsForFields); + // Deduplicate by index path to avoid emitting duplicate members for + // the same component. Use a set-based key to keep this near O(n). + llvm::SmallVector<mlir::Value> dedupMapOps; + llvm::SmallVector<llvm::SmallVector<int64_t>> dedupIndexPaths; + llvm::StringSet<> seenKeys; + for (auto [i, mapOp] : llvm::enumerate(newMapOpsForFields)) { + const auto &path = newMemberIndexPaths[i]; + llvm::SmallString<64> key; + buildPathKey(path, key); + if (seenKeys.contains(key)) + continue; + seenKeys.insert(key); + dedupMapOps.push_back(mapOp); + dedupIndexPaths.emplace_back(path.begin(), path.end()); + } + op.getMembersMutable().append(dedupMapOps); llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices; - mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr(); - - if (oldMembersIdxAttr) - for (mlir::Attribute indexList : oldMembersIdxAttr) { + if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr()) + for (mlir::Attribute indexList : oldAttr) { llvm::SmallVector<int64_t> listVec; for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList)) @@ -796,10 +944,8 @@ class MapInfoFinalizationPass newMemberIndices.emplace_back(std::move(listVec)); } - - for (int64_t newFieldIdx : fieldIndicies) - newMemberIndices.emplace_back( - llvm::SmallVector<int64_t>(1, newFieldIdx)); + for (auto &path : dedupIndexPaths) + newMemberIndices.emplace_back(path); op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices)); op.setPartialMap(true); |