Diffstat (limited to 'flang/lib/Optimizer')
-rw-r--r--  flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp | 294
1 file changed, 220 insertions(+), 74 deletions(-)
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 57be863..e595e61 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -41,7 +41,9 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/StringSet.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstddef>
#include <iterator>
@@ -75,6 +77,112 @@ class MapInfoFinalizationPass
/// | |
std::map<mlir::Operation *, mlir::Value> localBoxAllocas;
+ /// Return true if the given path exists in a list of paths.
+ static bool
+ containsPath(const llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &paths,
+ llvm::ArrayRef<int64_t> path) {
+ return llvm::any_of(paths, [&](const llvm::SmallVector<int64_t> &p) {
+ return p.size() == path.size() &&
+ std::equal(p.begin(), p.end(), path.begin());
+ });
+ }
+
+ /// Return true if the given path is already present in
+ /// op.getMembersIndexAttr().
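+ /// For example, if the attribute holds [[0], [1, 2]], the path {1, 2} is
+ /// reported as present, while {1} is not.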
+ static bool mappedIndexPathExists(mlir::omp::MapInfoOp op,
+ llvm::ArrayRef<int64_t> indexPath) {
+ if (mlir::ArrayAttr attr = op.getMembersIndexAttr()) {
+ for (mlir::Attribute list : attr) {
+ auto listAttr = mlir::cast<mlir::ArrayAttr>(list);
+ if (listAttr.size() != indexPath.size())
+ continue;
+ bool allEq = true;
+ for (auto [i, val] : llvm::enumerate(listAttr)) {
+ if (mlir::cast<mlir::IntegerAttr>(val).getInt() != indexPath[i]) {
+ allEq = false;
+ break;
+ }
+ }
+ if (allEq)
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /// Build a compact string key for an index path for set-based
+ /// deduplication. Format: "N:v0,v1,..." where N is the length.
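+ /// For example, the path {7, 0, 3} is encoded as "3:7,0,3".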
+ static void buildPathKey(llvm::ArrayRef<int64_t> path,
+ llvm::SmallString<64> &outKey) {
+ outKey.clear();
+ llvm::raw_svector_ostream os(outKey);
+ os << path.size() << ':';
+ for (size_t i = 0; i < path.size(); ++i) {
+ if (i)
+ os << ',';
+ os << path[i];
+ }
+ }
+
+ /// Create the member map for coordRef and append it (and its index
+ /// path) to the provided new* vectors, if it is not already present.
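+ /// Paths already mapped on the op, or covered by its associated declare
+ /// mapper (if any), are skipped.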
+ void appendMemberMapIfNew(
+ mlir::omp::MapInfoOp op, fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::Value coordRef, llvm::ArrayRef<int64_t> indexPath,
+ llvm::StringRef memberName,
+ llvm::SmallVectorImpl<mlir::Value> &newMapOpsForFields,
+ llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &newMemberIndexPaths) {
+ // Local de-dup within this op invocation.
+ if (containsPath(newMemberIndexPaths, indexPath))
+ return;
+ // Global de-dup against already present member indices.
+ if (mappedIndexPathExists(op, indexPath))
+ return;
+
+ if (op.getMapperId()) {
+ mlir::omp::DeclareMapperOp symbol =
+ mlir::SymbolTable::lookupNearestSymbolFrom<
+ mlir::omp::DeclareMapperOp>(op, op.getMapperIdAttr());
+ assert(symbol && "missing symbol for declare mapper identifier");
+ mlir::omp::DeclareMapperInfoOp mapperInfo = symbol.getDeclareMapperInfo();
+ // TODO: There is probably a way to cache these lookups so we don't redo
+ // them on every check and save some cycles, but that can wait for a
+ // subsequent patch.
+ for (auto v : mapperInfo.getMapVars()) {
+ mlir::omp::MapInfoOp map =
+ mlir::cast<mlir::omp::MapInfoOp>(v.getDefiningOp());
+ if (!map.getMembers().empty() && mappedIndexPathExists(map, indexPath))
+ return;
+ }
+ }
+
+ builder.setInsertionPoint(op);
+ fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
+ builder, coordRef, /*isOptional=*/false, loc);
+ llvm::SmallVector<mlir::Value> bounds = fir::factory::genImplicitBoundsOps<
+ mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
+ builder, info,
+ hlfir::translateToExtendedValue(loc, builder, hlfir::Entity{coordRef})
+ .first,
+ /*dataExvIsAssumedSize=*/false, loc);
+
+ mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
+ builder, loc, coordRef.getType(), coordRef,
+ mlir::TypeAttr::get(fir::unwrapRefType(coordRef.getType())),
+ op.getMapTypeAttr(),
+ builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+ mlir::omp::VariableCaptureKind::ByRef),
+ /*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
+ /*members_index=*/mlir::ArrayAttr{}, bounds,
+ /*mapperId=*/mlir::FlatSymbolRefAttr(),
+ builder.getStringAttr(op.getNameAttr().strref() + "." + memberName +
+ ".implicit_map"),
+ /*partial_map=*/builder.getBoolAttr(false));
+
+ newMapOpsForFields.emplace_back(fieldMapOp);
+ newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
+ }
+
/// getMemberUserList gathers all users of a particular MapInfoOp that are
/// other MapInfoOp's and places them into the mapMemberUsers list, which
/// records the map that the current argument MapInfoOp "op" is part of
@@ -363,7 +471,7 @@ class MapInfoFinalizationPass
mlir::ArrayAttr newMembersAttr;
mlir::SmallVector<mlir::Value> newMembers;
llvm::SmallVector<llvm::SmallVector<int64_t>> memberIndices;
- bool IsHasDeviceAddr = isHasDeviceAddr(op, target);
+ bool isHasDeviceAddrFlag = isHasDeviceAddr(op, target);
if (!mapMemberUsers.empty() || !op.getMembers().empty())
getMemberIndicesAsVectors(
@@ -406,7 +514,7 @@ class MapInfoFinalizationPass
mapUser.parent.getMembersMutable().assign(newMemberOps);
mapUser.parent.setMembersIndexAttr(
builder.create2DI64ArrayAttr(memberIndices));
- } else if (!IsHasDeviceAddr) {
+ } else if (!isHasDeviceAddrFlag) {
auto baseAddr =
genBaseAddrMap(descriptor, op.getBounds(), op.getMapType(), builder);
newMembers.push_back(baseAddr);
@@ -429,7 +537,7 @@ class MapInfoFinalizationPass
// The contents of the descriptor (the base address in particular) will
// remain unchanged though.
uint64_t mapType = op.getMapType();
- if (IsHasDeviceAddr) {
+ if (isHasDeviceAddrFlag) {
mapType |= llvm::to_underlying(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
}
@@ -701,94 +809,134 @@ class MapInfoFinalizationPass
auto recordType = mlir::cast<fir::RecordType>(underlyingType);
llvm::SmallVector<mlir::Value> newMapOpsForFields;
- llvm::SmallVector<int64_t> fieldIndicies;
+ llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndexPaths;
+ // 1) Handle direct top-level allocatable fields.
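+ // For example, given `type t; real, allocatable :: a(:); end type t`, a
+ // reference to `v%a` inside the target region adds an implicit member map
+ // for the `a` component here.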
for (auto fieldMemTyPair : recordType.getTypeList()) {
auto &field = fieldMemTyPair.first;
auto memTy = fieldMemTyPair.second;
- bool shouldMapField =
- llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
- if (!fir::isAllocatableType(memTy))
- return false;
-
- auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
- if (!designateOp)
- return false;
-
- return designateOp.getComponent() &&
- designateOp.getComponent()->strref() == field;
- }) != mapVarForwardSlice.end();
-
- // TODO Handle recursive record types. Adapting
- // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
- // entities might be helpful here.
-
- if (!shouldMapField)
+ if (!fir::isAllocatableType(memTy))
continue;
- int32_t fieldIdx = recordType.getFieldIndex(field);
- bool alreadyMapped = [&]() {
- if (op.getMembersIndexAttr())
- for (auto indexList : op.getMembersIndexAttr()) {
- auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
- if (indexListAttr.size() == 1 &&
- mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
- fieldIdx)
- return true;
- }
-
- return false;
- }();
-
- if (alreadyMapped)
+ bool referenced = llvm::any_of(mapVarForwardSlice, [&](auto *opv) {
+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
+ return designateOp && designateOp.getComponent() &&
+ designateOp.getComponent()->strref() == field;
+ });
+ if (!referenced)
continue;
+ int32_t fieldIdx = recordType.getFieldIndex(field);
builder.setInsertionPoint(op);
fir::IntOrValue idxConst =
mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
auto fieldCoord = fir::CoordinateOp::create(
builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
- fir::factory::AddrAndBoundsInfo info =
- fir::factory::getDataOperandBaseAddr(
- builder, fieldCoord, /*isOptional=*/false, op.getLoc());
- llvm::SmallVector<mlir::Value> bounds =
- fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
- mlir::omp::MapBoundsType>(
- builder, info,
- hlfir::translateToExtendedValue(op.getLoc(), builder,
- hlfir::Entity{fieldCoord})
- .first,
- /*dataExvIsAssumedSize=*/false, op.getLoc());
-
- mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
- builder, op.getLoc(), fieldCoord.getResult().getType(),
- fieldCoord.getResult(),
- mlir::TypeAttr::get(
- fir::unwrapRefType(fieldCoord.getResult().getType())),
- op.getMapTypeAttr(),
- builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
- mlir::omp::VariableCaptureKind::ByRef),
- /*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
- /*members_index=*/mlir::ArrayAttr{}, bounds,
- /*mapperId=*/mlir::FlatSymbolRefAttr(),
- builder.getStringAttr(op.getNameAttr().strref() + "." + field +
- ".implicit_map"),
- /*partial_map=*/builder.getBoolAttr(false));
- newMapOpsForFields.emplace_back(fieldMapOp);
- fieldIndicies.emplace_back(fieldIdx);
+ int64_t fieldIdx64 = static_cast<int64_t>(fieldIdx);
+ llvm::SmallVector<int64_t, 1> idxPath{fieldIdx64};
+ appendMemberMapIfNew(op, builder, op.getLoc(), fieldCoord, idxPath,
+ field, newMapOpsForFields, newMemberIndexPaths);
+ }
+
+ // 2) Handle nested allocatable fields along any component chain
+ // referenced in the region via HLFIR designates.
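+ // For example, a reference to `v%inner%payload`, where `payload` is an
+ // allocatable component of the derived type of `inner`, is walked back to
+ // the mapped variable, converted to a field-index path, and mapped
+ // implicitly.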
+ llvm::SmallVector<llvm::SmallVector<int64_t>> seenIndexPaths;
+ for (mlir::Operation *sliceOp : mapVarForwardSlice) {
+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+ if (!designateOp || !designateOp.getComponent())
+ continue;
+ llvm::SmallVector<llvm::StringRef> compPathReversed;
+ compPathReversed.push_back(designateOp.getComponent()->strref());
+ mlir::Value curBase = designateOp.getMemref();
+ bool rootedAtMapArg = false;
+ while (true) {
+ if (auto parentDes = curBase.getDefiningOp<hlfir::DesignateOp>()) {
+ if (!parentDes.getComponent())
+ break;
+ compPathReversed.push_back(parentDes.getComponent()->strref());
+ curBase = parentDes.getMemref();
+ continue;
+ }
+ if (auto decl = curBase.getDefiningOp<hlfir::DeclareOp>()) {
+ if (auto barg =
+ mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref()))
+ rootedAtMapArg = (barg == opBlockArg);
+ } else if (auto blockArg =
+ mlir::dyn_cast_or_null<mlir::BlockArgument>(
+ curBase)) {
+ rootedAtMapArg = (blockArg == opBlockArg);
+ }
+ break;
+ }
+ // Only process nested paths (2+ components). Single-component paths
+ // for direct fields are handled above.
+ if (!rootedAtMapArg || compPathReversed.size() < 2)
+ continue;
+ builder.setInsertionPoint(op);
+ llvm::SmallVector<int64_t> indexPath;
+ mlir::Type curTy = underlyingType;
+ mlir::Value coordRef = op.getVarPtr();
+ bool validPath = true;
+ for (llvm::StringRef compName : llvm::reverse(compPathReversed)) {
+ auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
+ if (!recTy) {
+ validPath = false;
+ break;
+ }
+ int32_t idx = recTy.getFieldIndex(compName);
+ if (idx < 0) {
+ validPath = false;
+ break;
+ }
+ indexPath.push_back(idx);
+ mlir::Type memTy = recTy.getType(idx);
+ fir::IntOrValue idxConst =
+ mlir::IntegerAttr::get(builder.getI32Type(), idx);
+ coordRef = fir::CoordinateOp::create(
+ builder, op.getLoc(), builder.getRefType(memTy), coordRef,
+ llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
+ curTy = memTy;
+ }
+ if (!validPath)
+ continue;
+ if (auto finalRefTy =
+ mlir::dyn_cast<fir::ReferenceType>(coordRef.getType())) {
+ mlir::Type eleTy = finalRefTy.getElementType();
+ if (fir::isAllocatableType(eleTy)) {
+ if (!containsPath(seenIndexPaths, indexPath)) {
+ seenIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
+ appendMemberMapIfNew(op, builder, op.getLoc(), coordRef,
+ indexPath, compPathReversed.front(),
+ newMapOpsForFields, newMemberIndexPaths);
+ }
+ }
+ }
}
if (newMapOpsForFields.empty())
return mlir::WalkResult::advance();
- op.getMembersMutable().append(newMapOpsForFields);
+ // Deduplicate by index path to avoid emitting duplicate members for
+ // the same component. Use a set-based key to keep this near O(n).
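+ // Keys are the compact strings produced by buildPathKey, e.g. {1, 2}
+ // becomes "2:1,2".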
+ llvm::SmallVector<mlir::Value> dedupMapOps;
+ llvm::SmallVector<llvm::SmallVector<int64_t>> dedupIndexPaths;
+ llvm::StringSet<> seenKeys;
+ for (auto [i, mapOp] : llvm::enumerate(newMapOpsForFields)) {
+ const auto &path = newMemberIndexPaths[i];
+ llvm::SmallString<64> key;
+ buildPathKey(path, key);
+ if (seenKeys.contains(key))
+ continue;
+ seenKeys.insert(key);
+ dedupMapOps.push_back(mapOp);
+ dedupIndexPaths.emplace_back(path.begin(), path.end());
+ }
+ op.getMembersMutable().append(dedupMapOps);
llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
- mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
-
- if (oldMembersIdxAttr)
- for (mlir::Attribute indexList : oldMembersIdxAttr) {
+ if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr())
+ for (mlir::Attribute indexList : oldAttr) {
llvm::SmallVector<int64_t> listVec;
for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
@@ -796,10 +944,8 @@ class MapInfoFinalizationPass
newMemberIndices.emplace_back(std::move(listVec));
}
-
- for (int64_t newFieldIdx : fieldIndicies)
- newMemberIndices.emplace_back(
- llvm::SmallVector<int64_t>(1, newFieldIdx));
+ for (auto &path : dedupIndexPaths)
+ newMemberIndices.emplace_back(path);
op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
op.setPartialMap(true);