aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Optimizer
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Optimizer')
-rw-r--r--flang/lib/Optimizer/CodeGen/CodeGen.cpp16
-rw-r--r--flang/lib/Optimizer/OpenACC/CMakeLists.txt1
-rw-r--r--flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp79
-rw-r--r--flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp191
-rw-r--r--flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt12
-rw-r--r--flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp100
-rw-r--r--flang/lib/Optimizer/Support/Utils.cpp10
7 files changed, 399 insertions, 10 deletions
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 0afb295..70bb43a2 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -176,6 +176,19 @@ struct AddrOfOpConversion : public fir::FIROpConversion<fir::AddrOfOp> {
llvm::LogicalResult
matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
+
+ if (auto gpuMod = addr->getParentOfType<mlir::gpu::GPUModuleOp>()) {
+ auto global = gpuMod.lookupSymbol<mlir::LLVM::GlobalOp>(addr.getSymbol());
+ replaceWithAddrOfOrASCast(
+ rewriter, addr->getLoc(),
+ global ? global.getAddrSpace() : getGlobalAddressSpace(rewriter),
+ getProgramAddressSpace(rewriter),
+ global ? global.getSymName()
+ : addr.getSymbol().getRootReference().getValue(),
+ convertType(addr.getType()), addr);
+ return mlir::success();
+ }
+
auto global = addr->getParentOfType<mlir::ModuleOp>()
.lookupSymbol<mlir::LLVM::GlobalOp>(addr.getSymbol());
replaceWithAddrOfOrASCast(
@@ -3231,7 +3244,8 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
if (global.getDataAttr() &&
*global.getDataAttr() == cuf::DataAttribute::Constant)
- TODO(global.getLoc(), "CUDA Fortran CONSTANT variable code generation");
+ g.setAddrSpace(
+ static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Constant));
rewriter.eraseOp(global);
return mlir::success();
diff --git a/flang/lib/Optimizer/OpenACC/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/CMakeLists.txt
index fc23e64..790b9fd 100644
--- a/flang/lib/Optimizer/OpenACC/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenACC/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(Support)
+add_subdirectory(Transforms)
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
index 89aa010..9bf10b5 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
@@ -21,6 +21,7 @@
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/Support/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/BuiltinOps.h"
@@ -352,6 +353,14 @@ getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
// calculation op.
mlir::Value baseRef =
llvm::TypeSwitch<mlir::Operation *, mlir::Value>(op)
+ .Case<fir::DeclareOp>([&](auto op) {
+ // If this declare binds a view with an underlying storage operand,
+ // treat that storage as the base reference. Otherwise, fall back
+ // to the declared memref.
+ if (auto storage = op.getStorage())
+ return storage;
+ return mlir::Value(varPtr);
+ })
.Case<hlfir::DesignateOp>([&](auto op) {
// Get the base object.
return op.getMemref();
@@ -548,14 +557,27 @@ template <typename Ty>
mlir::Value OpenACCMappableModel<Ty>::generatePrivateInit(
mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
mlir::TypedValue<mlir::acc::MappableType> var, llvm::StringRef varName,
- mlir::ValueRange extents, mlir::Value initVal) const {
+ mlir::ValueRange extents, mlir::Value initVal, bool &needsDestroy) const {
+ needsDestroy = false;
mlir::Value retVal;
mlir::Type unwrappedTy = fir::unwrapRefType(type);
mlir::ModuleOp mod = builder.getInsertionBlock()
->getParent()
->getParentOfType<mlir::ModuleOp>();
- fir::FirOpBuilder firBuilder(builder, mod);
+ if (auto recType = llvm::dyn_cast<fir::RecordType>(
+ fir::getFortranElementType(unwrappedTy))) {
+ // Need to make deep copies of allocatable components.
+ if (fir::isRecordWithAllocatableMember(recType))
+ TODO(loc,
+ "OpenACC: privatizing derived type with allocatable components");
+ // Need to decide if user assignment/final routine should be called.
+ if (fir::isRecordWithFinalRoutine(recType, mod).value_or(false))
+ TODO(loc, "OpenACC: privatizing derived type with user assignment or "
+ "final routine ");
+ }
+
+ fir::FirOpBuilder firBuilder(builder, mod);
auto getDeclareOpForType = [&](mlir::Type ty) -> hlfir::DeclareOp {
auto alloca = fir::AllocaOp::create(firBuilder, loc, ty);
return hlfir::DeclareOp::create(firBuilder, loc, alloca, varName);
@@ -615,9 +637,11 @@ mlir::Value OpenACCMappableModel<Ty>::generatePrivateInit(
mlir::Value firClass =
fir::EmboxOp::create(builder, loc, boxTy, allocatedScalar);
fir::StoreOp::create(builder, loc, firClass, retVal);
+ needsDestroy = true;
} else if (mlir::isa<fir::SequenceType>(innerTy)) {
hlfir::Entity source = hlfir::Entity{var};
- auto [temp, cleanup] = hlfir::createTempFromMold(loc, firBuilder, source);
+ auto [temp, cleanupFlag] =
+ hlfir::createTempFromMold(loc, firBuilder, source);
if (fir::isa_ref_type(type)) {
// When the temp is created - it is not a reference - thus we can
// end up with a type inconsistency. Therefore ensure storage is created
@@ -636,6 +660,9 @@ mlir::Value OpenACCMappableModel<Ty>::generatePrivateInit(
} else {
retVal = temp;
}
+ // If heap was allocated, a destroy is required later.
+ if (cleanupFlag)
+ needsDestroy = true;
} else {
TODO(loc, "Unsupported boxed type for OpenACC private-like recipe");
}
@@ -667,23 +694,61 @@ template mlir::Value
OpenACCMappableModel<fir::BaseBoxType>::generatePrivateInit(
mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
mlir::TypedValue<mlir::acc::MappableType> var, llvm::StringRef varName,
- mlir::ValueRange extents, mlir::Value initVal) const;
+ mlir::ValueRange extents, mlir::Value initVal, bool &needsDestroy) const;
template mlir::Value
OpenACCMappableModel<fir::ReferenceType>::generatePrivateInit(
mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
mlir::TypedValue<mlir::acc::MappableType> var, llvm::StringRef varName,
- mlir::ValueRange extents, mlir::Value initVal) const;
+ mlir::ValueRange extents, mlir::Value initVal, bool &needsDestroy) const;
template mlir::Value OpenACCMappableModel<fir::HeapType>::generatePrivateInit(
mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
mlir::TypedValue<mlir::acc::MappableType> var, llvm::StringRef varName,
- mlir::ValueRange extents, mlir::Value initVal) const;
+ mlir::ValueRange extents, mlir::Value initVal, bool &needsDestroy) const;
template mlir::Value
OpenACCMappableModel<fir::PointerType>::generatePrivateInit(
mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
mlir::TypedValue<mlir::acc::MappableType> var, llvm::StringRef varName,
- mlir::ValueRange extents, mlir::Value initVal) const;
+ mlir::ValueRange extents, mlir::Value initVal, bool &needsDestroy) const;
+
+template <typename Ty>
+bool OpenACCMappableModel<Ty>::generatePrivateDestroy(
+ mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::Value privatized) const {
+ mlir::Type unwrappedTy = fir::unwrapRefType(type);
+ // For boxed scalars allocated with AllocMem during init, free the heap.
+ if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(unwrappedTy)) {
+ mlir::Value boxVal = privatized;
+ if (fir::isa_ref_type(boxVal.getType()))
+ boxVal = fir::LoadOp::create(builder, loc, boxVal);
+ mlir::Value addr = fir::BoxAddrOp::create(builder, loc, boxVal);
+ // FreeMem only accepts fir.heap and this may not be represented in the box
+ // type if the privatized entity is not an allocatable.
+ mlir::Type heapType =
+ fir::HeapType::get(fir::unwrapRefType(addr.getType()));
+ if (heapType != addr.getType())
+ addr = fir::ConvertOp::create(builder, loc, heapType, addr);
+ fir::FreeMemOp::create(builder, loc, addr);
+ return true;
+ }
+
+ // Nothing to do for other categories by default, they are stack allocated.
+ return true;
+}
+
+template bool OpenACCMappableModel<fir::BaseBoxType>::generatePrivateDestroy(
+ mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::Value privatized) const;
+template bool OpenACCMappableModel<fir::ReferenceType>::generatePrivateDestroy(
+ mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::Value privatized) const;
+template bool OpenACCMappableModel<fir::HeapType>::generatePrivateDestroy(
+ mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::Value privatized) const;
+template bool OpenACCMappableModel<fir::PointerType>::generatePrivateDestroy(
+ mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::Value privatized) const;
} // namespace fir::acc
diff --git a/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp b/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp
new file mode 100644
index 0000000..4840a99
--- /dev/null
+++ b/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp
@@ -0,0 +1,191 @@
+//===- ACCRecipeBufferization.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Bufferize OpenACC recipes that yield fir.box<T> to operate on
+// fir.ref<fir.box<T>> and update uses accordingly.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/OpenACC/Passes.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Visitors.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace fir::acc {
+#define GEN_PASS_DEF_ACCRECIPEBUFFERIZATION
+#include "flang/Optimizer/OpenACC/Passes.h.inc"
+} // namespace fir::acc
+
+namespace {
+
+class BufferizeInterface {
+public:
+ static std::optional<mlir::Type> mustBufferize(mlir::Type recipeType) {
+ if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(recipeType))
+ return fir::ReferenceType::get(boxTy);
+ return std::nullopt;
+ }
+
+ static mlir::Operation *load(mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::Value value) {
+ return builder.create<fir::LoadOp>(loc, value);
+ }
+
+ static mlir::Value placeInMemory(mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::Value value) {
+ auto alloca = builder.create<fir::AllocaOp>(loc, value.getType());
+ builder.create<fir::StoreOp>(loc, value, alloca);
+ return alloca;
+ }
+};
+
+static void bufferizeRegionArgsAndYields(mlir::Region &region,
+ mlir::Location loc, mlir::Type oldType,
+ mlir::Type newType) {
+ if (region.empty())
+ return;
+
+ mlir::OpBuilder builder(&region);
+ for (mlir::BlockArgument arg : region.getArguments()) {
+ if (arg.getType() == oldType) {
+ arg.setType(newType);
+ if (!arg.use_empty()) {
+ mlir::Operation *loadOp = BufferizeInterface::load(builder, loc, arg);
+ arg.replaceAllUsesExcept(loadOp->getResult(0), loadOp);
+ }
+ }
+ }
+ if (auto yield =
+ llvm::dyn_cast<mlir::acc::YieldOp>(region.back().getTerminator())) {
+ llvm::SmallVector<mlir::Value> newOperands;
+ newOperands.reserve(yield.getNumOperands());
+ bool changed = false;
+ for (mlir::Value oldYieldArg : yield.getOperands()) {
+ if (oldYieldArg.getType() == oldType) {
+ builder.setInsertionPoint(yield);
+ mlir::Value alloca =
+ BufferizeInterface::placeInMemory(builder, loc, oldYieldArg);
+ newOperands.push_back(alloca);
+ changed = true;
+ } else {
+ newOperands.push_back(oldYieldArg);
+ }
+ }
+ if (changed)
+ yield->setOperands(newOperands);
+ }
+}
+
+static void updateRecipeUse(mlir::ArrayAttr recipes, mlir::ValueRange operands,
+ llvm::StringRef recipeSymName,
+ mlir::Operation *computeOp) {
+ if (!recipes)
+ return;
+ for (auto [recipeSym, oldRes] : llvm::zip(recipes, operands)) {
+ if (llvm::cast<mlir::SymbolRefAttr>(recipeSym).getLeafReference() !=
+ recipeSymName)
+ continue;
+
+ mlir::Operation *dataOp = oldRes.getDefiningOp();
+ assert(dataOp && "dataOp must be paired with computeOp");
+ mlir::Location loc = dataOp->getLoc();
+ mlir::OpBuilder builder(dataOp);
+ llvm::TypeSwitch<mlir::Operation *, void>(dataOp)
+ .Case<mlir::acc::PrivateOp, mlir::acc::FirstprivateOp,
+ mlir::acc::ReductionOp>([&](auto privateOp) {
+ builder.setInsertionPointAfterValue(privateOp.getVar());
+ mlir::Value alloca = BufferizeInterface::placeInMemory(
+ builder, loc, privateOp.getVar());
+ privateOp.getVarMutable().assign(alloca);
+ privateOp.getAccVar().setType(alloca.getType());
+ });
+
+ llvm::SmallVector<mlir::Operation *> users(oldRes.getUsers().begin(),
+ oldRes.getUsers().end());
+ for (mlir::Operation *useOp : users) {
+ if (useOp == computeOp)
+ continue;
+ builder.setInsertionPoint(useOp);
+ mlir::Operation *load = BufferizeInterface::load(builder, loc, oldRes);
+ useOp->replaceUsesOfWith(oldRes, load->getResult(0));
+ }
+ }
+}
+
+class ACCRecipeBufferization
+ : public fir::acc::impl::ACCRecipeBufferizationBase<
+ ACCRecipeBufferization> {
+public:
+ void runOnOperation() override {
+ mlir::ModuleOp module = getOperation();
+
+ llvm::SmallVector<llvm::StringRef> recipeNames;
+ module.walk([&](mlir::Operation *recipe) {
+ llvm::TypeSwitch<mlir::Operation *, void>(recipe)
+ .Case<mlir::acc::PrivateRecipeOp, mlir::acc::FirstprivateRecipeOp,
+ mlir::acc::ReductionRecipeOp>([&](auto recipe) {
+ mlir::Type oldType = recipe.getType();
+ auto bufferizedType =
+ BufferizeInterface::mustBufferize(recipe.getType());
+ if (!bufferizedType)
+ return;
+ recipe.setTypeAttr(mlir::TypeAttr::get(*bufferizedType));
+ mlir::Location loc = recipe.getLoc();
+ using RecipeOp = decltype(recipe);
+ bufferizeRegionArgsAndYields(recipe.getInitRegion(), loc, oldType,
+ *bufferizedType);
+ if constexpr (std::is_same_v<RecipeOp,
+ mlir::acc::FirstprivateRecipeOp>)
+ bufferizeRegionArgsAndYields(recipe.getCopyRegion(), loc, oldType,
+ *bufferizedType);
+ if constexpr (std::is_same_v<RecipeOp,
+ mlir::acc::ReductionRecipeOp>)
+ bufferizeRegionArgsAndYields(recipe.getCombinerRegion(), loc,
+ oldType, *bufferizedType);
+ bufferizeRegionArgsAndYields(recipe.getDestroyRegion(), loc,
+ oldType, *bufferizedType);
+ recipeNames.push_back(recipe.getSymName());
+ });
+ });
+ if (recipeNames.empty())
+ return;
+
+ module.walk([&](mlir::Operation *op) {
+ llvm::TypeSwitch<mlir::Operation *, void>(op)
+ .Case<mlir::acc::LoopOp, mlir::acc::ParallelOp, mlir::acc::SerialOp>(
+ [&](auto computeOp) {
+ for (llvm::StringRef recipeName : recipeNames) {
+ if (computeOp.getPrivatizationRecipes())
+ updateRecipeUse(computeOp.getPrivatizationRecipesAttr(),
+ computeOp.getPrivateOperands(), recipeName,
+ op);
+ if (computeOp.getFirstprivatizationRecipes())
+ updateRecipeUse(
+ computeOp.getFirstprivatizationRecipesAttr(),
+ computeOp.getFirstprivateOperands(), recipeName, op);
+ if (computeOp.getReductionRecipes())
+ updateRecipeUse(computeOp.getReductionRecipesAttr(),
+ computeOp.getReductionOperands(),
+ recipeName, op);
+ }
+ });
+ });
+ }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::acc::createACCRecipeBufferizationPass() {
+ return std::make_unique<ACCRecipeBufferization>();
+}
diff --git a/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..2427da0
--- /dev/null
+++ b/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_flang_library(FIROpenACCTransforms
+ ACCRecipeBufferization.cpp
+
+ DEPENDS
+ FIROpenACCPassesIncGen
+
+ LINK_LIBS
+ MLIRIR
+ MLIRPass
+ FIRDialect
+ MLIROpenACCDialect
+)
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 260e525..2bbd803 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -40,6 +40,7 @@
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -128,6 +129,17 @@ class MapInfoFinalizationPass
}
}
+ /// Return true if the module has an OpenMP requires clause that includes
+ /// unified_shared_memory.
+ static bool moduleRequiresUSM(mlir::ModuleOp module) {
+ assert(module && "invalid module");
+ if (auto req = module->getAttrOfType<mlir::omp::ClauseRequiresAttr>(
+ "omp.requires"))
+ return mlir::omp::bitEnumContainsAll(
+ req.getValue(), mlir::omp::ClauseRequires::unified_shared_memory);
+ return false;
+ }
+
/// Create the member map for coordRef and append it (and its index
/// path) to the provided new* vectors, if it is not already present.
void appendMemberMapIfNew(
@@ -425,8 +437,12 @@ class MapInfoFinalizationPass
mapFlags flags = mapFlags::OMP_MAP_TO |
(mapFlags(mapTypeFlag) &
- (mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_CLOSE |
- mapFlags::OMP_MAP_ALWAYS));
+ (mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_ALWAYS));
+ // For unified_shared_memory, we additionally add `CLOSE` on the descriptor
+ // to ensure device-local placement where required by tests relying on USM +
+ // close semantics.
+ if (moduleRequiresUSM(target->getParentOfType<mlir::ModuleOp>()))
+ flags |= mapFlags::OMP_MAP_CLOSE;
return llvm::to_underlying(flags);
}
@@ -518,6 +534,75 @@ class MapInfoFinalizationPass
return newMapInfoOp;
}
+ // Expand mappings of type(C_PTR) to map their `__address` field explicitly
+ // as a single pointer-sized member (USM-gated at callsite). This helps in
+ // USM scenarios to ensure the pointer-sized mapping is used.
+ mlir::omp::MapInfoOp genCptrMemberMap(mlir::omp::MapInfoOp op,
+ fir::FirOpBuilder &builder) {
+ if (!op.getMembers().empty())
+ return op;
+
+ mlir::Type varTy = fir::unwrapRefType(op.getVarPtr().getType());
+ if (!mlir::isa<fir::RecordType>(varTy))
+ return op;
+ auto recTy = mlir::cast<fir::RecordType>(varTy);
+ // If not a builtin C_PTR record, skip.
+ if (!recTy.getName().ends_with("__builtin_c_ptr"))
+ return op;
+
+ // Find the index of the c_ptr address component named "__address".
+ int32_t fieldIdx = recTy.getFieldIndex("__address");
+ if (fieldIdx < 0)
+ return op;
+
+ mlir::Location loc = op.getVarPtr().getLoc();
+ mlir::Type memTy = recTy.getType(fieldIdx);
+ fir::IntOrValue idxConst =
+ mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
+ mlir::Value coord = fir::CoordinateOp::create(
+ builder, loc, builder.getRefType(memTy), op.getVarPtr(),
+ llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
+
+ // Child for the `__address` member.
+ llvm::SmallVector<llvm::SmallVector<int64_t>> memberIdx = {{0}};
+ mlir::ArrayAttr newMembersAttr = builder.create2DI64ArrayAttr(memberIdx);
+ // Force CLOSE in USM paths so the pointer gets device-local placement
+ // when required by tests relying on USM + close semantics.
+ uint64_t mapTypeVal =
+ op.getMapType() |
+ llvm::to_underlying(
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
+ mlir::IntegerAttr mapTypeAttr = builder.getIntegerAttr(
+ builder.getIntegerType(64, /*isSigned=*/false), mapTypeVal);
+
+ mlir::omp::MapInfoOp memberMap = mlir::omp::MapInfoOp::create(
+ builder, loc, coord.getType(), coord,
+ mlir::TypeAttr::get(fir::unwrapRefType(coord.getType())), mapTypeAttr,
+ builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+ mlir::omp::VariableCaptureKind::ByRef),
+ /*varPtrPtr=*/mlir::Value{},
+ /*members=*/llvm::SmallVector<mlir::Value>{},
+ /*member_index=*/mlir::ArrayAttr{},
+ /*bounds=*/op.getBounds(),
+ /*mapperId=*/mlir::FlatSymbolRefAttr(),
+ /*name=*/op.getNameAttr(),
+ /*partial_map=*/builder.getBoolAttr(false));
+
+ // Rebuild the parent as a container with the `__address` member.
+ mlir::omp::MapInfoOp newParent = mlir::omp::MapInfoOp::create(
+ builder, op.getLoc(), op.getResult().getType(), op.getVarPtr(),
+ op.getVarTypeAttr(), mapTypeAttr, op.getMapCaptureTypeAttr(),
+ /*varPtrPtr=*/mlir::Value{},
+ /*members=*/llvm::SmallVector<mlir::Value>{memberMap},
+ /*member_index=*/newMembersAttr,
+ /*bounds=*/llvm::SmallVector<mlir::Value>{},
+ /*mapperId=*/mlir::FlatSymbolRefAttr(), op.getNameAttr(),
+ /*partial_map=*/builder.getBoolAttr(false));
+ op.replaceAllUsesWith(newParent.getResult());
+ op->erase();
+ return newParent;
+ }
+
mlir::omp::MapInfoOp genDescriptorMemberMaps(mlir::omp::MapInfoOp op,
fir::FirOpBuilder &builder,
mlir::Operation *target) {
@@ -1169,6 +1254,17 @@ class MapInfoFinalizationPass
genBoxcharMemberMap(op, builder);
});
+ // Expand type(C_PTR) only when unified_shared_memory is required,
+ // to ensure device-visible pointer size/behavior in USM scenarios
+ // without changing default expectations elsewhere.
+ func->walk([&](mlir::omp::MapInfoOp op) {
+ // Only expand C_PTR members when unified_shared_memory is required.
+ if (!moduleRequiresUSM(func->getParentOfType<mlir::ModuleOp>()))
+ return;
+ builder.setInsertionPoint(op);
+ genCptrMemberMap(op, builder);
+ });
+
func->walk([&](mlir::omp::MapInfoOp op) {
// TODO: Currently only supports a single user for the MapInfoOp. This
// is fine for the moment, as the Fortran frontend will generate a
diff --git a/flang/lib/Optimizer/Support/Utils.cpp b/flang/lib/Optimizer/Support/Utils.cpp
index c71642c..92390e4a 100644
--- a/flang/lib/Optimizer/Support/Utils.cpp
+++ b/flang/lib/Optimizer/Support/Utils.cpp
@@ -51,6 +51,16 @@ std::optional<llvm::ArrayRef<int64_t>> fir::getComponentLowerBoundsIfNonDefault(
return std::nullopt;
}
+std::optional<bool>
+fir::isRecordWithFinalRoutine(fir::RecordType recordType, mlir::ModuleOp module,
+ const mlir::SymbolTable *symbolTable) {
+ fir::TypeInfoOp typeInfo =
+ fir::lookupTypeInfoOp(recordType, module, symbolTable);
+ if (!typeInfo)
+ return std::nullopt;
+ return !typeInfo.getNoFinal();
+}
+
mlir::LLVM::ConstantOp
fir::genConstantIndex(mlir::Location loc, mlir::Type ity,
mlir::ConversionPatternRewriter &rewriter,