aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Optimizer
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Optimizer')
-rw-r--r--flang/lib/Optimizer/Analysis/AliasAnalysis.cpp362
-rw-r--r--flang/lib/Optimizer/Analysis/ArraySectionAnalyzer.cpp300
-rw-r--r--flang/lib/Optimizer/Analysis/CMakeLists.txt4
-rw-r--r--flang/lib/Optimizer/Analysis/TBAAForest.cpp9
-rw-r--r--flang/lib/Optimizer/Builder/CMakeLists.txt2
-rw-r--r--flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp1722
-rw-r--r--flang/lib/Optimizer/Builder/CUFCommon.cpp64
-rw-r--r--flang/lib/Optimizer/Builder/FIRBuilder.cpp80
-rw-r--r--flang/lib/Optimizer/Builder/HLFIRTools.cpp121
-rw-r--r--flang/lib/Optimizer/Builder/IntrinsicCall.cpp1210
-rw-r--r--flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp1
-rw-r--r--flang/lib/Optimizer/Builder/Runtime/Allocatable.cpp7
-rw-r--r--flang/lib/Optimizer/Builder/Runtime/Character.cpp36
-rw-r--r--flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp40
-rw-r--r--flang/lib/Optimizer/Builder/Runtime/Main.cpp2
-rw-r--r--flang/lib/Optimizer/Builder/Runtime/Reduction.cpp2
-rw-r--r--flang/lib/Optimizer/Builder/TemporaryStorage.cpp8
-rw-r--r--flang/lib/Optimizer/CodeGen/CodeGen.cpp175
-rw-r--r--flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp23
-rw-r--r--flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp1
-rw-r--r--flang/lib/Optimizer/CodeGen/PassDetail.h2
-rw-r--r--flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp7
-rw-r--r--flang/lib/Optimizer/CodeGen/Target.cpp57
-rw-r--r--flang/lib/Optimizer/CodeGen/TargetRewrite.cpp73
-rw-r--r--flang/lib/Optimizer/CodeGen/TypeConverter.cpp4
-rw-r--r--flang/lib/Optimizer/Dialect/CMakeLists.txt2
-rw-r--r--flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp14
-rw-r--r--flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp23
-rw-r--r--flang/lib/Optimizer/Dialect/FIROperationMoveOpInterface.cpp49
-rw-r--r--flang/lib/Optimizer/Dialect/FIROps.cpp357
-rw-r--r--flang/lib/Optimizer/Dialect/FIRType.cpp34
-rw-r--r--flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt11
-rw-r--r--flang/lib/Optimizer/Dialect/MIF/MIFOps.cpp59
-rw-r--r--flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp4
-rw-r--r--flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp25
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp18
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp47
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp60
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp381
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp531
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.h56
-rw-r--r--flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp68
-rw-r--r--flang/lib/Optimizer/OpenACC/Analysis/CMakeLists.txt24
-rw-r--r--flang/lib/Optimizer/OpenACC/Analysis/FIROpenACCSupportAnalysis.cpp56
-rw-r--r--flang/lib/Optimizer/OpenACC/CMakeLists.txt1
-rw-r--r--flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt8
-rw-r--r--flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp227
-rw-r--r--flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp1089
-rw-r--r--flang/lib/Optimizer/OpenACC/Support/FIROpenACCUtils.cpp655
-rw-r--r--flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp62
-rw-r--r--flang/lib/Optimizer/OpenACC/Transforms/ACCInitializeFIRAnalyses.cpp56
-rw-r--r--flang/lib/Optimizer/OpenACC/Transforms/ACCOptimizeFirstprivateMap.cpp193
-rw-r--r--flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp59
-rw-r--r--flang/lib/Optimizer/OpenACC/Transforms/ACCUseDeviceCanonicalizer.cpp400
-rw-r--r--flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt11
-rw-r--r--flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp15
-rw-r--r--flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp45
-rw-r--r--flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp387
-rw-r--r--flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp2
-rw-r--r--flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp159
-rw-r--r--flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp246
-rw-r--r--flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp29
-rw-r--r--flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp137
-rw-r--r--flang/lib/Optimizer/OpenMP/Support/CMakeLists.txt1
-rw-r--r--flang/lib/Optimizer/OpenMP/Support/FIROpenMPOpsInterfaces.cpp102
-rw-r--r--flang/lib/Optimizer/OpenMP/Support/RegisterOpenMPExtensions.cpp1
-rw-r--r--flang/lib/Optimizer/Passes/CommandLineOpts.cpp1
-rw-r--r--flang/lib/Optimizer/Passes/Pipelines.cpp33
-rw-r--r--flang/lib/Optimizer/Support/CMakeLists.txt2
-rw-r--r--flang/lib/Optimizer/Support/Utils.cpp6
-rw-r--r--flang/lib/Optimizer/Transforms/AddAliasTags.cpp92
-rw-r--r--flang/lib/Optimizer/Transforms/AddDebugInfo.cpp449
-rw-r--r--flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp4
-rw-r--r--flang/lib/Optimizer/Transforms/CMakeLists.txt53
-rw-r--r--flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp (renamed from flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp)0
-rw-r--r--flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp445
-rw-r--r--flang/lib/Optimizer/Transforms/CUDA/CUFComputeSharedMemoryOffsetsAndSize.cpp (renamed from flang/lib/Optimizer/Transforms/CUFComputeSharedMemoryOffsetsAndSize.cpp)92
-rw-r--r--flang/lib/Optimizer/Transforms/CUDA/CUFDeviceFuncTransform.cpp250
-rw-r--r--flang/lib/Optimizer/Transforms/CUDA/CUFDeviceGlobal.cpp (renamed from flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp)0
-rw-r--r--flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp103
-rw-r--r--flang/lib/Optimizer/Transforms/CUDA/CUFGPUToLLVMConversion.cpp (renamed from flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp)7
-rw-r--r--flang/lib/Optimizer/Transforms/CUDA/CUFLaunchAttachAttr.cpp70
-rw-r--r--flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp (renamed from flang/lib/Optimizer/Transforms/CUFOpConversion.cpp)494
-rw-r--r--flang/lib/Optimizer/Transforms/CUDA/CUFOpConversionLate.cpp120
-rw-r--r--flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp153
-rw-r--r--flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp40
-rw-r--r--flang/lib/Optimizer/Transforms/FIRToMemRef.cpp1061
-rw-r--r--flang/lib/Optimizer/Transforms/FIRToSCF.cpp134
-rw-r--r--flang/lib/Optimizer/Transforms/FunctionAttr.cpp8
-rw-r--r--flang/lib/Optimizer/Transforms/LoopInvariantCodeMotion.cpp323
-rw-r--r--flang/lib/Optimizer/Transforms/MIFOpConversion.cpp398
-rw-r--r--flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp4
-rw-r--r--flang/lib/Optimizer/Transforms/SetRuntimeCallAttributes.cpp5
-rw-r--r--flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp13
-rw-r--r--flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp1
-rw-r--r--flang/lib/Optimizer/Transforms/VScaleAttr.cpp25
96 files changed, 11253 insertions, 3149 deletions
diff --git a/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp b/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
index 73ddd1f..0eb00e2 100644
--- a/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
+++ b/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Analysis/AliasAnalysis.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
@@ -21,12 +22,38 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
#define DEBUG_TYPE "fir-alias-analysis"
+llvm::cl::opt<bool> supportCrayPointers(
+ "unsafe-cray-pointers",
+ llvm::cl::desc("Support Cray POINTERs that ALIAS with non-TARGET data"),
+ llvm::cl::init(false));
+
+// Inspect for value-scoped Allocate effects and determine whether
+// 'candidate' is a new allocation. Returns SourceKind::Allocate if a
+// MemAlloc effect is attached
+static fir::AliasAnalysis::SourceKind
+classifyAllocateFromEffects(mlir::Operation *op, mlir::Value candidate) {
+ if (!op)
+ return fir::AliasAnalysis::SourceKind::Unknown;
+ auto interface = llvm::dyn_cast<mlir::MemoryEffectOpInterface>(op);
+ if (!interface)
+ return fir::AliasAnalysis::SourceKind::Unknown;
+ llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 4> effects;
+ interface.getEffects(effects);
+ for (mlir::MemoryEffects::EffectInstance &e : effects) {
+ if (mlir::isa<mlir::MemoryEffects::Allocate>(e.getEffect()) &&
+ e.getValue() && e.getValue() == candidate)
+ return fir::AliasAnalysis::SourceKind::Allocate;
+ }
+ return fir::AliasAnalysis::SourceKind::Unknown;
+}
+
//===----------------------------------------------------------------------===//
// AliasAnalysis: alias
//===----------------------------------------------------------------------===//
@@ -40,15 +67,28 @@ getAttrsFromVariable(fir::FortranVariableOpInterface var) {
attrs.set(fir::AliasAnalysis::Attribute::Pointer);
if (var.isIntentIn())
attrs.set(fir::AliasAnalysis::Attribute::IntentIn);
+ if (var.isCrayPointer())
+ attrs.set(fir::AliasAnalysis::Attribute::CrayPointer);
+ if (var.isCrayPointee())
+ attrs.set(fir::AliasAnalysis::Attribute::CrayPointee);
return attrs;
}
-static bool hasGlobalOpTargetAttr(mlir::Value v, fir::AddrOfOp op) {
- auto globalOpName =
- mlir::OperationName(fir::GlobalOp::getOperationName(), op->getContext());
- return fir::valueHasFirAttribute(
- v, fir::GlobalOp::getTargetAttrName(globalOpName));
+bool fir::AliasAnalysis::symbolMayHaveTargetAttr(mlir::SymbolRefAttr symbol,
+ mlir::Operation *from) {
+ assert(from);
+
+ // If we cannot find the nearest SymbolTable assume the worst.
+ const mlir::SymbolTable *symTab = getNearestSymbolTable(from);
+ if (!symTab)
+ return true;
+
+ if (auto globalOp = symTab->lookup<fir::GlobalOp>(symbol.getLeafReference()))
+ return globalOp.getTarget().value_or(false);
+
+ // If the symbol is not defined by fir.global assume the worst.
+ return true;
}
static bool isEvaluateInMemoryBlockArg(mlir::Value v) {
@@ -118,6 +158,18 @@ bool AliasAnalysis::Source::isPointer() const {
return attributes.test(Attribute::Pointer);
}
+bool AliasAnalysis::Source::isCrayPointee() const {
+ return attributes.test(Attribute::CrayPointee);
+}
+
+bool AliasAnalysis::Source::isCrayPointer() const {
+ return attributes.test(Attribute::CrayPointer);
+}
+
+bool AliasAnalysis::Source::isCrayPointerOrPointee() const {
+ return isCrayPointer() || isCrayPointee();
+}
+
bool AliasAnalysis::Source::isDummyArgument() const {
if (auto v = origin.u.dyn_cast<mlir::Value>()) {
return fir::isDummyArgument(v);
@@ -175,6 +227,34 @@ bool AliasAnalysis::Source::mayBeActualArgWithPtr(
return false;
}
+// Return true if the two locations cannot alias based
+// on the access data type, e.g. an address of a descriptor
+// cannot alias with an address of data (unless the data
+// may contain a descriptor).
+static bool noAliasBasedOnType(mlir::Value lhs, mlir::Value rhs) {
+ mlir::Type lhsType = lhs.getType();
+ mlir::Type rhsType = rhs.getType();
+ if (!fir::isa_ref_type(lhsType) || !fir::isa_ref_type(rhsType))
+ return false;
+ mlir::Type lhsElemType = fir::unwrapRefType(lhsType);
+ mlir::Type rhsElemType = fir::unwrapRefType(rhsType);
+ if (mlir::isa<fir::BaseBoxType>(lhsElemType) !=
+ mlir::isa<fir::BaseBoxType>(rhsElemType)) {
+ // One of the types is fir.box and another is not.
+ mlir::Type nonBoxType;
+ if (mlir::isa<fir::BaseBoxType>(lhsElemType))
+ nonBoxType = rhsElemType;
+ else
+ nonBoxType = lhsElemType;
+
+ if (!fir::isRecordWithDescriptorMember(nonBoxType)) {
+ LLVM_DEBUG(llvm::dbgs() << " no alias based on the access types\n");
+ return true;
+ }
+ }
+ return false;
+}
+
AliasResult AliasAnalysis::alias(mlir::Value lhs, mlir::Value rhs) {
// A wrapper around alias(Source lhsSrc, Source rhsSrc, mlir::Value lhs,
// mlir::Value rhs) This allows a user to provide Source that may be obtained
@@ -196,6 +276,10 @@ AliasResult AliasAnalysis::alias(Source lhsSrc, Source rhsSrc, mlir::Value lhs,
llvm::dbgs() << " rhs: " << rhs << "\n";
llvm::dbgs() << " rhsSrc: " << rhsSrc << "\n";);
+ // Disambiguate data and descriptors addresses.
+ if (noAliasBasedOnType(lhs, rhs))
+ return AliasResult::NoAlias;
+
// Indirect case currently not handled. Conservatively assume
// it aliases with everything
if (lhsSrc.kind >= SourceKind::Indirect ||
@@ -204,6 +288,15 @@ AliasResult AliasAnalysis::alias(Source lhsSrc, Source rhsSrc, mlir::Value lhs,
return AliasResult::MayAlias;
}
+ // Cray pointers/pointees can alias with anything via LOC.
+ if (supportCrayPointers) {
+ if (lhsSrc.isCrayPointerOrPointee() || rhsSrc.isCrayPointerOrPointee()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << " aliasing because of Cray pointer/pointee\n");
+ return AliasResult::MayAlias;
+ }
+ }
+
if (lhsSrc.kind == rhsSrc.kind) {
// If the kinds and origins are the same, then lhs and rhs must alias unless
// either source is approximate. Approximate sources are for parts of the
@@ -214,6 +307,17 @@ AliasResult AliasAnalysis::alias(Source lhsSrc, Source rhsSrc, mlir::Value lhs,
<< " aliasing because same source kind and origin\n");
if (approximateSource)
return AliasResult::MayAlias;
+ // One should be careful about relying on MustAlias.
+ // The LLVM definition implies that the two MustAlias
+ // memory objects start at exactly the same location.
+ // With Fortran array slices two objects may have
+ // the same starting location, but otherwise represent
+ // partially overlapping memory locations, e.g.:
+ // integer :: a(10)
+ // ... a(5:1:-1) ! starts at a(5) and addresses a(5), ..., a(1)
+ // ... a(5:10:1) ! starts at a(5) and addresses a(5), ..., a(10)
+ // The current implementation of FIR alias analysis will always
+ // return MayAlias for such cases.
return AliasResult::MustAlias;
}
// If one value is the address of a composite, and if the other value is the
@@ -287,6 +391,12 @@ AliasResult AliasAnalysis::alias(Source lhsSrc, Source rhsSrc, mlir::Value lhs,
// of non-data is included below.
if (src1->isTargetOrPointer() && src2->isTargetOrPointer() &&
src1->isData() && src2->isData()) {
+ // Two distinct TARGET globals may not alias.
+ if (!src1->isPointer() && !src2->isPointer() &&
+ src1->kind == SourceKind::Global && src2->kind == SourceKind::Global &&
+ src1->origin.u != src2->origin.u) {
+ return AliasResult::NoAlias;
+ }
LLVM_DEBUG(llvm::dbgs() << " aliasing because of target or pointer\n");
return AliasResult::MayAlias;
}
@@ -400,7 +510,8 @@ static ModRefResult getCallModRef(fir::CallOp call, mlir::Value var) {
// TODO: limit to Fortran functions??
// 1. Detect variables that can be accessed indirectly.
fir::AliasAnalysis aliasAnalysis;
- fir::AliasAnalysis::Source varSrc = aliasAnalysis.getSource(var);
+ fir::AliasAnalysis::Source varSrc =
+ aliasAnalysis.getSource(var, /*getLastInstantiationPoint=*/true);
// If the variable is not a user variable, we cannot safely assume that
// Fortran semantics apply (e.g., a bare alloca/allocmem result may very well
// be placed in an allocatable/pointer descriptor and escape).
@@ -430,6 +541,7 @@ static ModRefResult getCallModRef(fir::CallOp call, mlir::Value var) {
// At that stage, it has been ruled out that local (including the saved ones)
// and dummy cannot be indirectly accessed in the call.
if (varSrc.kind != fir::AliasAnalysis::SourceKind::Allocate &&
+ varSrc.kind != fir::AliasAnalysis::SourceKind::Argument &&
!varSrc.isDummyArgument()) {
if (varSrc.kind != fir::AliasAnalysis::SourceKind::Global ||
!isSavedLocal(varSrc))
@@ -450,25 +562,43 @@ static ModRefResult getCallModRef(fir::CallOp call, mlir::Value var) {
return ModRefResult::getNoModRef();
}
-/// This is mostly inspired by MLIR::LocalAliasAnalysis with 2 notable
-/// differences 1) Regions are not handled here but will be handled by a data
-/// flow analysis to come 2) Allocate and Free effects are considered
-/// modifying
+/// This is mostly inspired by MLIR::LocalAliasAnalysis, except that
+/// fir.call's are handled in a special way.
ModRefResult AliasAnalysis::getModRef(Operation *op, Value location) {
- MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
- if (!interface) {
- if (auto call = llvm::dyn_cast<fir::CallOp>(op))
- return getCallModRef(call, location);
- return ModRefResult::getModAndRef();
- }
+ if (auto call = llvm::dyn_cast<fir::CallOp>(op))
+ return getCallModRef(call, location);
// Build a ModRefResult by merging the behavior of the effects of this
// operation.
+ ModRefResult result = ModRefResult::getNoModRef();
+ MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
+ if (op->hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) {
+ for (mlir::Region &region : op->getRegions()) {
+ result = result.merge(getModRef(region, location));
+ if (result.isModAndRef())
+ break;
+ }
+
+ // In MLIR, RecursiveMemoryEffects can be combined with
+ // MemoryEffectOpInterface to describe extra effects on top of the
+ // effects of the nested operations. However, the presence of
+ // RecursiveMemoryEffects and the absence of MemoryEffectOpInterface
+ // implies the operation has no other memory effects than the one of its
+ // nested operations.
+ if (!interface)
+ return result;
+ }
+
+ if (!interface || result.isModAndRef())
+ return ModRefResult::getModAndRef();
+
SmallVector<MemoryEffects::EffectInstance> effects;
interface.getEffects(effects);
- ModRefResult result = ModRefResult::getNoModRef();
for (const MemoryEffects::EffectInstance &effect : effects) {
+ // MemAlloc and MemFree are not mod-ref effects.
+ if (isa<MemoryEffects::Allocate, MemoryEffects::Free>(effect.getEffect()))
+ continue;
// Check for an alias between the effect and our memory location.
AliasResult aliasResult = AliasResult::MayAlias;
@@ -495,22 +625,6 @@ ModRefResult AliasAnalysis::getModRef(mlir::Region &region,
mlir::Value location) {
ModRefResult result = ModRefResult::getNoModRef();
for (mlir::Operation &op : region.getOps()) {
- if (op.hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) {
- for (mlir::Region &subRegion : op.getRegions()) {
- result = result.merge(getModRef(subRegion, location));
- // Fast return is already mod and ref.
- if (result.isModAndRef())
- return result;
- }
- // In MLIR, RecursiveMemoryEffects can be combined with
- // MemoryEffectOpInterface to describe extra effects on top of the
- // effects of the nested operations. However, the presence of
- // RecursiveMemoryEffects and the absence of MemoryEffectOpInterface
- // implies the operation has no other memory effects than the one of its
- // nested operations.
- if (!mlir::isa<mlir::MemoryEffectOpInterface>(op))
- continue;
- }
result = result.merge(getModRef(&op, location));
if (result.isModAndRef())
return result;
@@ -534,13 +648,28 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v,
Source::Attributes attributes;
mlir::Operation *instantiationPoint{nullptr};
while (defOp && !breakFromLoop) {
- ty = defOp->getResultTypes()[0];
+ // Value-scoped allocation detection via effects.
+ if (classifyAllocateFromEffects(defOp, v) == SourceKind::Allocate) {
+ type = SourceKind::Allocate;
+ break;
+ }
+ // Operations may have multiple results, so we need to analyze
+ // the result for which the source is queried.
+ auto opResult = mlir::cast<OpResult>(v);
+ assert(opResult.getOwner() == defOp && "v must be a result of defOp");
+ ty = opResult.getType();
llvm::TypeSwitch<Operation *>(defOp)
- .Case<hlfir::AsExprOp>([&](auto op) {
+ .Case([&](hlfir::AsExprOp op) {
+ // TODO: we should probably always report hlfir.as_expr
+ // as a unique source, and let the codegen decide whether
+ // to use the original buffer or create a copy.
v = op.getVar();
defOp = v.getDefiningOp();
})
- .Case<hlfir::AssociateOp>([&](auto op) {
+ .Case([&](hlfir::AssociateOp op) {
+ assert(opResult != op.getMustFreeStrorageFlag() &&
+ "MustFreeStorageFlag result is not an aliasing candidate");
+
mlir::Value source = op.getSource();
if (fir::isa_trivial(source.getType())) {
// Trivial values will always use distinct temp memory,
@@ -554,17 +683,7 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v,
defOp = v.getDefiningOp();
}
})
- .Case<fir::AllocaOp, fir::AllocMemOp>([&](auto op) {
- // Unique memory allocation.
- type = SourceKind::Allocate;
- breakFromLoop = true;
- })
- .Case<fir::ConvertOp>([&](auto op) {
- // Skip ConvertOp's and track further through the operand.
- v = op->getOperand(0);
- defOp = v.getDefiningOp();
- })
- .Case<fir::PackArrayOp>([&](auto op) {
+ .Case([&](fir::PackArrayOp op) {
// The packed array is not distinguishable from the original
// array, so skip PackArrayOp and track further through
// the array operand.
@@ -572,29 +691,7 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v,
defOp = v.getDefiningOp();
approximateSource = true;
})
- .Case<fir::BoxAddrOp>([&](auto op) {
- v = op->getOperand(0);
- defOp = v.getDefiningOp();
- if (mlir::isa<fir::BaseBoxType>(v.getType()))
- followBoxData = true;
- })
- .Case<fir::ArrayCoorOp, fir::CoordinateOp>([&](auto op) {
- if (isPointerReference(ty))
- attributes.set(Attribute::Pointer);
- v = op->getOperand(0);
- defOp = v.getDefiningOp();
- if (mlir::isa<fir::BaseBoxType>(v.getType()))
- followBoxData = true;
- approximateSource = true;
- })
- .Case<fir::EmboxOp, fir::ReboxOp>([&](auto op) {
- if (followBoxData) {
- v = op->getOperand(0);
- defOp = v.getDefiningOp();
- } else
- breakFromLoop = true;
- })
- .Case<fir::LoadOp>([&](auto op) {
+ .Case([&](fir::LoadOp op) {
// If load is inside target and it points to mapped item,
// continue tracking.
Operation *loadMemrefOp = op.getMemref().getDefiningOp();
@@ -623,21 +720,35 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v,
isCapturedInInternalProcedure |=
boxSrc.isCapturedInInternalProcedure;
+ if (getLastInstantiationPoint) {
+ if (!instantiationPoint)
+ instantiationPoint = boxSrc.origin.instantiationPoint;
+ } else {
+ instantiationPoint = boxSrc.origin.instantiationPoint;
+ }
+
global = llvm::dyn_cast<mlir::SymbolRefAttr>(boxSrc.origin.u);
if (global) {
type = SourceKind::Global;
} else {
auto def = llvm::cast<mlir::Value>(boxSrc.origin.u);
- // TODO: Add support to fir.allocmem
- if (auto allocOp = def.template getDefiningOp<fir::AllocaOp>()) {
- v = def;
- defOp = v.getDefiningOp();
- type = SourceKind::Allocate;
- } else if (isDummyArgument(def)) {
- defOp = nullptr;
- v = def;
- } else {
- type = SourceKind::Indirect;
+ bool classified = false;
+ if (auto defDefOp = def.getDefiningOp()) {
+ if (classifyAllocateFromEffects(defDefOp, def) ==
+ SourceKind::Allocate) {
+ v = def;
+ defOp = defDefOp;
+ type = SourceKind::Allocate;
+ classified = true;
+ }
+ }
+ if (!classified) {
+ if (isDummyArgument(def)) {
+ defOp = nullptr;
+ v = def;
+ } else {
+ type = SourceKind::Indirect;
+ }
}
}
breakFromLoop = true;
@@ -647,28 +758,39 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v,
type = SourceKind::Indirect;
breakFromLoop = true;
})
- .Case<fir::AddrOfOp>([&](auto op) {
+ .Case<fir::AddrOfOp, cuf::DeviceAddressOp>([&](auto op) {
// Address of a global scope object.
ty = v.getType();
type = SourceKind::Global;
-
- if (hasGlobalOpTargetAttr(v, op))
- attributes.set(Attribute::Target);
-
// TODO: Take followBoxData into account when setting the pointer
// attribute
if (isPointerReference(ty))
attributes.set(Attribute::Pointer);
- global = llvm::cast<fir::AddrOfOp>(op).getSymbol();
+
+ if constexpr (std::is_same_v<std::decay_t<decltype(op)>,
+ fir::AddrOfOp>)
+ global = op.getSymbol();
+ else if constexpr (std::is_same_v<std::decay_t<decltype(op)>,
+ cuf::DeviceAddressOp>)
+ global = op.getHostSymbol();
+ else
+ llvm_unreachable("unexpected operation");
+
+ if (symbolMayHaveTargetAttr(global, op))
+ attributes.set(Attribute::Target);
+
breakFromLoop = true;
})
.Case<hlfir::DeclareOp, fir::DeclareOp>([&](auto op) {
+ // The declare operations support FortranObjectViewOpInterface,
+ // but their handling is more complex. Maybe we can find better
+ // abstractions to handle them in a general fashion.
bool isPrivateItem = false;
if (omp::BlockArgOpenMPOpInterface argIface =
dyn_cast<omp::BlockArgOpenMPOpInterface>(op->getParentOp())) {
Value ompValArg;
llvm::TypeSwitch<Operation *>(op->getParentOp())
- .template Case<omp::TargetOp>([&](auto targetOp) {
+ .Case([&](omp::TargetOp targetOp) {
// If declare operation is inside omp target region,
// continue alias analysis outside the target region
for (auto [opArg, blockArg] : llvm::zip_equal(
@@ -713,7 +835,7 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v,
// currently provide any useful information. The host associated
// access will end up dereferencing the host association tuple,
// so we may as well stop right now.
- v = defOp->getResult(0);
+ v = opResult;
// TODO: if the host associated variable is a dummy argument
// of the host, I think, we can treat it as SourceKind::Argument
// for the purpose of alias analysis inside the internal procedure.
@@ -748,21 +870,45 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v,
v = op.getMemref();
defOp = v.getDefiningOp();
})
- .Case<hlfir::DesignateOp>([&](auto op) {
- auto varIf = llvm::cast<fir::FortranVariableOpInterface>(defOp);
- attributes |= getAttrsFromVariable(varIf);
- // Track further through the memory indexed into
- // => if the source arrays/structures don't alias then nor do the
- // results of hlfir.designate
- v = op.getMemref();
+ .Case([&](fir::FortranObjectViewOpInterface op) {
+ // This case must be located after the cases for concrete
+ // operations that support FortraObjectViewOpInterface,
+ // so that their special handling kicks in.
+
+ // fir.embox/rebox case: this is the only case where we check
+ // for followBoxData.
+ // TODO: it looks like we do not have LIT tests that fail
+ // upon removal of the followBoxData code. We should come up
+ // with a test or remove this code.
+ if (!followBoxData &&
+ (mlir::isa<fir::EmboxOp>(op) || mlir::isa<fir::ReboxOp>(op))) {
+ breakFromLoop = true;
+ return;
+ }
+
+ // Collect attributes from FortranVariableOpInterface operations.
+ if (auto varIf =
+ mlir::dyn_cast<fir::FortranVariableOpInterface>(defOp))
+ attributes |= getAttrsFromVariable(varIf);
+ // Set Pointer attribute based on the reference type.
+ if (isPointerReference(ty))
+ attributes.set(Attribute::Pointer);
+
+ // Update v to point to the operand that represents the object
+ // referenced by the operation's result.
+ v = op.getViewSource(opResult);
defOp = v.getDefiningOp();
- // TODO: there will be some cases which provably don't alias if one
- // takes into account the component or indices, which are currently
- // ignored here - leading to false positives
- // because of this limitation, we need to make sure we never return
- // MustAlias after going through a designate operation
- approximateSource = true;
- if (mlir::isa<fir::BaseBoxType>(v.getType()))
+      // If the input or the resulting object references are offset,
+ // then set approximateSource.
+ auto offset = op.getViewOffset(opResult);
+ if (!offset || *offset != 0)
+ approximateSource = true;
+
+ // If the source is a box, and the result is not a box,
+ // then this is one of the box "unpacking" operations,
+ // so we should set followBoxData.
+ if (mlir::isa<fir::BaseBoxType>(v.getType()) &&
+ !mlir::isa<fir::BaseBoxType>(ty))
followBoxData = true;
})
.Default([&](auto op) {
@@ -803,4 +949,16 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v,
isCapturedInInternalProcedure};
}
+const mlir::SymbolTable *
+fir::AliasAnalysis::getNearestSymbolTable(mlir::Operation *from) {
+ assert(from);
+ Operation *symTabOp = mlir::SymbolTable::getNearestSymbolTable(from);
+ if (!symTabOp)
+ return nullptr;
+ auto it = symTabMap.find(symTabOp);
+ if (it != symTabMap.end())
+ return &it->second;
+ return &symTabMap.try_emplace(symTabOp, symTabOp).first->second;
+}
+
} // namespace fir
diff --git a/flang/lib/Optimizer/Analysis/ArraySectionAnalyzer.cpp b/flang/lib/Optimizer/Analysis/ArraySectionAnalyzer.cpp
new file mode 100644
index 0000000..f5ee298
--- /dev/null
+++ b/flang/lib/Optimizer/Analysis/ArraySectionAnalyzer.cpp
@@ -0,0 +1,300 @@
+//===- ArraySectionAnalyzer.cpp - Analyze array sections ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Analysis/ArraySectionAnalyzer.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "array-section-analyzer"
+
+using namespace fir;
+
+ArraySectionAnalyzer::SectionDesc::SectionDesc(mlir::Value lb, mlir::Value ub,
+ mlir::Value stride)
+ : lb(lb), ub(ub), stride(stride) {
+ assert(lb && "lower bound or index must be specified");
+ normalize();
+}
+
+void ArraySectionAnalyzer::SectionDesc::normalize() {
+ if (!ub)
+ ub = lb;
+ if (lb == ub)
+ stride = nullptr;
+ if (stride)
+ if (auto val = fir::getIntIfConstant(stride))
+ if (*val == 1)
+ stride = nullptr;
+}
+
+bool ArraySectionAnalyzer::SectionDesc::operator==(
+ const SectionDesc &other) const {
+ return lb == other.lb && ub == other.ub && stride == other.stride;
+}
+
+ArraySectionAnalyzer::SectionDesc
+ArraySectionAnalyzer::readSectionDesc(mlir::Operation::operand_iterator &it,
+ bool isTriplet) {
+ if (isTriplet)
+ return {*it++, *it++, *it++};
+ return {*it++, nullptr, nullptr};
+}
+
+std::pair<mlir::Value, mlir::Value>
+ArraySectionAnalyzer::getOrderedBounds(const SectionDesc &desc) {
+ mlir::Value stride = desc.stride;
+ // Null stride means stride=1.
+ if (!stride)
+ return {desc.lb, desc.ub};
+ // Reverse the bounds, if stride is negative.
+ if (auto val = fir::getIntIfConstant(stride)) {
+ if (*val >= 0)
+ return {desc.lb, desc.ub};
+ else
+ return {desc.ub, desc.lb};
+ }
+
+ return {nullptr, nullptr};
+}
+
+bool ArraySectionAnalyzer::areDisjointSections(const SectionDesc &desc1,
+ const SectionDesc &desc2) {
+ auto [lb1, ub1] = getOrderedBounds(desc1);
+ auto [lb2, ub2] = getOrderedBounds(desc2);
+ if (!lb1 || !lb2)
+ return false;
+ // Note that this comparison must be made on the ordered bounds,
+ // otherwise 'a(x:y:1) = a(z:x-1:-1) + 1' may be incorrectly treated
+ // as not overlapping (x=2, y=10, z=9).
+ if (isLess(ub1, lb2) || isLess(ub2, lb1))
+ return true;
+ return false;
+}
+
+bool ArraySectionAnalyzer::areIdenticalSections(const SectionDesc &desc1,
+ const SectionDesc &desc2) {
+ if (desc1 == desc2)
+ return true;
+ return false;
+}
+
+ArraySectionAnalyzer::SlicesOverlapKind
+ArraySectionAnalyzer::analyze(mlir::Value ref1, mlir::Value ref2) {
+ if (ref1 == ref2)
+ return SlicesOverlapKind::DefinitelyIdentical;
+
+ auto des1 = ref1.getDefiningOp<hlfir::DesignateOp>();
+ auto des2 = ref2.getDefiningOp<hlfir::DesignateOp>();
+ // We only support a pair of designators right now.
+ if (!des1 || !des2)
+ return SlicesOverlapKind::Unknown;
+
+ if (des1.getMemref() != des2.getMemref()) {
+ // If the bases are different, then there is unknown overlap.
+ LLVM_DEBUG(llvm::dbgs() << "No identical base for:\n"
+ << des1 << "and:\n"
+ << des2 << "\n");
+ return SlicesOverlapKind::Unknown;
+ }
+
+ // Require all components of the designators to be the same.
+ // It might be too strict, e.g. we may probably allow for
+ // different type parameters.
+ if (des1.getComponent() != des2.getComponent() ||
+ des1.getComponentShape() != des2.getComponentShape() ||
+ des1.getSubstring() != des2.getSubstring() ||
+ des1.getComplexPart() != des2.getComplexPart() ||
+ des1.getTypeparams() != des2.getTypeparams()) {
+ LLVM_DEBUG(llvm::dbgs() << "Different designator specs for:\n"
+ << des1 << "and:\n"
+ << des2 << "\n");
+ return SlicesOverlapKind::Unknown;
+ }
+
+ // Analyze the subscripts.
+ auto des1It = des1.getIndices().begin();
+ auto des2It = des2.getIndices().begin();
+ bool identicalTriplets = true;
+ bool identicalIndices = true;
+ for (auto [isTriplet1, isTriplet2] :
+ llvm::zip(des1.getIsTriplet(), des2.getIsTriplet())) {
+ SectionDesc desc1 = readSectionDesc(des1It, isTriplet1);
+ SectionDesc desc2 = readSectionDesc(des2It, isTriplet2);
+
+ // See if we can prove that any of the sections do not overlap.
+ // This is mostly a Polyhedron/nf performance hack that looks for
+ // particular relations between the lower and upper bounds
+ // of the array sections, e.g. for any positive constant C:
+ // X:Y does not overlap with (Y+C):Z
+ // X:Y does not overlap with Z:(X-C)
+ if (areDisjointSections(desc1, desc2))
+ return SlicesOverlapKind::DefinitelyDisjoint;
+
+ if (!areIdenticalSections(desc1, desc2)) {
+ if (isTriplet1 || isTriplet2) {
+ // For example:
+ // hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %0)
+ // hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %1)
+ //
+      // If all the triplets (section specifiers) are the same, then
+ // we do not care if %0 is equal to %1 - the slices are either
+ // identical or completely disjoint.
+ //
+ // Also, treat these as identical sections:
+ // hlfir.designate %6#0 (%c2:%c2:%c1)
+ // hlfir.designate %6#0 (%c2)
+ identicalTriplets = false;
+ LLVM_DEBUG(llvm::dbgs() << "Triplet mismatch for:\n"
+ << des1 << "and:\n"
+ << des2 << "\n");
+ } else {
+ identicalIndices = false;
+ LLVM_DEBUG(llvm::dbgs() << "Indices mismatch for:\n"
+ << des1 << "and:\n"
+ << des2 << "\n");
+ }
+ }
+ }
+
+ if (identicalTriplets) {
+ if (identicalIndices)
+ return SlicesOverlapKind::DefinitelyIdentical;
+ else
+ return SlicesOverlapKind::EitherIdenticalOrDisjoint;
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "Different sections for:\n"
+ << des1 << "and:\n"
+ << des2 << "\n");
+ return SlicesOverlapKind::Unknown;
+}
+
+bool ArraySectionAnalyzer::isLess(mlir::Value v1, mlir::Value v2) {
+ auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
+ auto *op = v.getDefiningOp();
+ while (auto conv = mlir::dyn_cast_or_null<fir::ConvertOp>(op))
+ op = conv.getValue().getDefiningOp();
+ return op;
+ };
+
+ auto isPositiveConstant = [](mlir::Value v) -> bool {
+ if (auto val = fir::getIntIfConstant(v))
+ return *val > 0;
+ return false;
+ };
+
+ auto *op1 = removeConvert(v1);
+ auto *op2 = removeConvert(v2);
+ if (!op1 || !op2)
+ return false;
+
+ // Check if they are both constants.
+ if (auto val1 = fir::getIntIfConstant(op1->getResult(0)))
+ if (auto val2 = fir::getIntIfConstant(op2->getResult(0)))
+ return *val1 < *val2;
+
+ // Handle some variable cases (C > 0):
+ // v2 = v1 + C
+ // v2 = C + v1
+ // v1 = v2 - C
+ if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
+ if ((addi.getLhs().getDefiningOp() == op1 &&
+ isPositiveConstant(addi.getRhs())) ||
+ (addi.getRhs().getDefiningOp() == op1 &&
+ isPositiveConstant(addi.getLhs())))
+ return true;
+ if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
+ if (subi.getLhs().getDefiningOp() == op2 &&
+ isPositiveConstant(subi.getRhs()))
+ return true;
+ return false;
+}
+
+/// Returns the array indices for the given hlfir.designate.
+/// It recognizes the computations used to transform the one-based indices
+/// into the array's lb-based indices, and returns the one-based indices
+/// in these cases.
+static llvm::SmallVector<mlir::Value>
+getDesignatorIndices(hlfir::DesignateOp designate) {
+ mlir::Value memref = designate.getMemref();
+
+ // If the object is a box, then the indices may be adjusted
+ // according to the box's lower bound(s). Scan through
+ // the computations to try to find the one-based indices.
+ if (mlir::isa<fir::BaseBoxType>(memref.getType())) {
+ // Look for the following pattern:
+ // %13 = fir.load %12 : !fir.ref<!fir.box<...>
+ // %14:3 = fir.box_dims %13, %c0 : (!fir.box<...>, index) -> ...
+ // %17 = arith.subi %14#0, %c1 : index
+ // %18 = arith.addi %arg2, %17 : index
+ // %19 = hlfir.designate %13 (%18) : (!fir.box<...>, index) -> ...
+ //
+ // %arg2 is a one-based index.
+
+ auto isNormalizedLb = [memref](mlir::Value v, unsigned dim) {
+ // Return true, if v and dim are such that:
+ // %14:3 = fir.box_dims %13, %dim : (!fir.box<...>, index) -> ...
+ // %17 = arith.subi %14#0, %c1 : index
+ // %19 = hlfir.designate %13 (...) : (!fir.box<...>, index) -> ...
+ if (auto subOp =
+ mlir::dyn_cast_or_null<mlir::arith::SubIOp>(v.getDefiningOp())) {
+ auto cst = fir::getIntIfConstant(subOp.getRhs());
+ if (!cst || *cst != 1)
+ return false;
+ if (auto dimsOp = mlir::dyn_cast_or_null<fir::BoxDimsOp>(
+ subOp.getLhs().getDefiningOp())) {
+ if (memref != dimsOp.getVal() ||
+ dimsOp.getResult(0) != subOp.getLhs())
+ return false;
+ auto dimsOpDim = fir::getIntIfConstant(dimsOp.getDim());
+ return dimsOpDim && dimsOpDim == dim;
+ }
+ }
+ return false;
+ };
+
+ llvm::SmallVector<mlir::Value> newIndices;
+ for (auto index : llvm::enumerate(designate.getIndices())) {
+ if (auto addOp = mlir::dyn_cast_or_null<mlir::arith::AddIOp>(
+ index.value().getDefiningOp())) {
+ for (unsigned opNum = 0; opNum < 2; ++opNum)
+ if (isNormalizedLb(addOp->getOperand(opNum), index.index())) {
+ newIndices.push_back(addOp->getOperand((opNum + 1) % 2));
+ break;
+ }
+
+ // If new one-based index was not added, exit early.
+ if (newIndices.size() <= index.index())
+ break;
+ }
+ }
+
+ // If any of the indices is not adjusted to the array's lb,
+ // then return the original designator indices.
+ if (newIndices.size() != designate.getIndices().size())
+ return designate.getIndices();
+
+ return newIndices;
+ }
+
+ return designate.getIndices();
+}
+
+bool fir::ArraySectionAnalyzer::isDesignatingArrayInOrder(
+ hlfir::DesignateOp designate, hlfir::ElementalOpInterface elemental) {
+
+ auto indices = getDesignatorIndices(designate);
+ auto elementalIndices = elemental.getIndices();
+ if (indices.size() == elementalIndices.size())
+ return std::equal(indices.begin(), indices.end(), elementalIndices.begin(),
+ elementalIndices.end());
+ return false;
+}
diff --git a/flang/lib/Optimizer/Analysis/CMakeLists.txt b/flang/lib/Optimizer/Analysis/CMakeLists.txt
index 4d4ad88..398a6d3 100644
--- a/flang/lib/Optimizer/Analysis/CMakeLists.txt
+++ b/flang/lib/Optimizer/Analysis/CMakeLists.txt
@@ -1,14 +1,16 @@
add_flang_library(FIRAnalysis
AliasAnalysis.cpp
+ ArraySectionAnalyzer.cpp
TBAAForest.cpp
DEPENDS
+ CUFDialect
FIRDialect
FIRSupport
HLFIRDialect
LINK_LIBS
- FIRBuilder
+ CUFDialect
FIRDialect
FIRSupport
HLFIRDialect
diff --git a/flang/lib/Optimizer/Analysis/TBAAForest.cpp b/flang/lib/Optimizer/Analysis/TBAAForest.cpp
index 44a0348..7154785 100644
--- a/flang/lib/Optimizer/Analysis/TBAAForest.cpp
+++ b/flang/lib/Optimizer/Analysis/TBAAForest.cpp
@@ -66,12 +66,9 @@ fir::TBAATree::TBAATree(mlir::LLVM::TBAATypeDescriptorAttr anyAccess,
mlir::LLVM::TBAATypeDescriptorAttr dataRoot,
mlir::LLVM::TBAATypeDescriptorAttr boxMemberTypeDesc)
: targetDataTree(dataRoot.getContext(), "target data", dataRoot),
- globalDataTree(dataRoot.getContext(), "global data",
- targetDataTree.getRoot()),
- allocatedDataTree(dataRoot.getContext(), "allocated data",
- targetDataTree.getRoot()),
+ globalDataTree(dataRoot.getContext(), "global data", dataRoot),
+ allocatedDataTree(dataRoot.getContext(), "allocated data", dataRoot),
dummyArgDataTree(dataRoot.getContext(), "dummy arg data", dataRoot),
- directDataTree(dataRoot.getContext(), "direct data",
- targetDataTree.getRoot()),
+ directDataTree(dataRoot.getContext(), "direct data", dataRoot),
anyAccessDesc(anyAccess), boxMemberTypeDesc(boxMemberTypeDesc),
anyDataTypeDesc(dataRoot) {}
diff --git a/flang/lib/Optimizer/Builder/CMakeLists.txt b/flang/lib/Optimizer/Builder/CMakeLists.txt
index 1f95259..d966c52 100644
--- a/flang/lib/Optimizer/Builder/CMakeLists.txt
+++ b/flang/lib/Optimizer/Builder/CMakeLists.txt
@@ -5,6 +5,7 @@ add_flang_library(FIRBuilder
BoxValue.cpp
Character.cpp
Complex.cpp
+ CUDAIntrinsicCall.cpp
CUFCommon.cpp
DoLoopHelper.cpp
FIRBuilder.cpp
@@ -46,6 +47,7 @@ add_flang_library(FIRBuilder
LINK_LIBS
CUFAttrs
CUFDialect
+ FIRAnalysis
FIRDialect
FIRDialectSupport
FIRSupport
diff --git a/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
new file mode 100644
index 0000000..fe2db46
--- /dev/null
+++ b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
@@ -0,0 +1,1722 @@
+//===-- CUDAIntrinsicCall.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Helper routines for constructing the FIR dialect of MLIR for CUDA
+// intrinsics. Extensive use of MLIR interfaces and MLIR's coding style
+// (https://mlir.llvm.org/getting_started/DeveloperGuide/) is used in this
+// module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/CUDAIntrinsicCall.h"
+#include "flang/Evaluate/common.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/MutableBox.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace fir {
+
+using CI = CUDAIntrinsicLibrary;
+
+static const char __ldca_i4x4[] = "__ldca_i4x4_";
+static const char __ldca_i8x2[] = "__ldca_i8x2_";
+static const char __ldca_r2x2[] = "__ldca_r2x2_";
+static const char __ldca_r4x4[] = "__ldca_r4x4_";
+static const char __ldca_r8x2[] = "__ldca_r8x2_";
+static const char __ldcg_i4x4[] = "__ldcg_i4x4_";
+static const char __ldcg_i8x2[] = "__ldcg_i8x2_";
+static const char __ldcg_r2x2[] = "__ldcg_r2x2_";
+static const char __ldcg_r4x4[] = "__ldcg_r4x4_";
+static const char __ldcg_r8x2[] = "__ldcg_r8x2_";
+static const char __ldcs_i4x4[] = "__ldcs_i4x4_";
+static const char __ldcs_i8x2[] = "__ldcs_i8x2_";
+static const char __ldcs_r2x2[] = "__ldcs_r2x2_";
+static const char __ldcs_r4x4[] = "__ldcs_r4x4_";
+static const char __ldcs_r8x2[] = "__ldcs_r8x2_";
+static const char __ldcv_i4x4[] = "__ldcv_i4x4_";
+static const char __ldcv_i8x2[] = "__ldcv_i8x2_";
+static const char __ldcv_r2x2[] = "__ldcv_r2x2_";
+static const char __ldcv_r4x4[] = "__ldcv_r4x4_";
+static const char __ldcv_r8x2[] = "__ldcv_r8x2_";
+static const char __ldlu_i4x4[] = "__ldlu_i4x4_";
+static const char __ldlu_i8x2[] = "__ldlu_i8x2_";
+static const char __ldlu_r2x2[] = "__ldlu_r2x2_";
+static const char __ldlu_r4x4[] = "__ldlu_r4x4_";
+static const char __ldlu_r8x2[] = "__ldlu_r8x2_";
+
+static constexpr unsigned kTMAAlignment = 16;
+
+// CUDA specific intrinsic handlers.
+static constexpr IntrinsicHandler cudaHandlers[]{
+ {"__ldca_i4x4",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldca_i4x4, 4>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldca_i8x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldca_i8x2, 2>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldca_r2x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldca_r2x2, 2>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldca_r4x4",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldca_r4x4, 4>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldca_r8x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldca_r8x2, 2>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcg_i4x4",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldcg_i4x4, 4>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcg_i8x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldcg_i8x2, 2>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcg_r2x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldcg_r2x2, 2>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcg_r4x4",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldcg_r4x4, 4>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcg_r8x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldcg_r8x2, 2>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcs_i4x4",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldcs_i4x4, 4>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcs_i8x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldcs_i8x2, 2>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcs_r2x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldcs_r2x2, 2>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcs_r4x4",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldcs_r4x4, 4>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcs_r8x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldcs_r8x2, 2>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcv_i4x4",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldcv_i4x4, 4>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcv_i8x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldcv_i8x2, 2>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcv_r2x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldcv_r2x2, 2>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcv_r4x4",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldcv_r4x4, 4>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcv_r8x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldcv_r8x2, 2>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldlu_i4x4",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldlu_i4x4, 4>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldlu_i8x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldlu_i8x2, 2>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldlu_r2x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldlu_r2x2, 2>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldlu_r4x4",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldlu_r4x4, 4>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldlu_r8x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genLDXXFunc<__ldlu_r8x2, 2>),
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"all_sync",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genVoteSync<mlir::NVVM::VoteSyncKind::all>),
+ {{{"mask", asValue}, {"pred", asValue}}},
+ /*isElemental=*/false},
+ {"any_sync",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genVoteSync<mlir::NVVM::VoteSyncKind::any>),
+ {{{"mask", asValue}, {"pred", asValue}}},
+ /*isElemental=*/false},
+ {"atomicadd_r4x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genAtomicAddVector<2>),
+ {{{"a", asAddr}, {"v", asAddr}}},
+ false},
+ {"atomicadd_r4x4",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genAtomicAddVector4x4),
+ {{{"a", asAddr}, {"v", asAddr}}},
+ false},
+ {"atomicaddd",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicAdd),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicaddf",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicAdd),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicaddi",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicAdd),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicaddl",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicAdd),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicaddr2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicAddR2),
+ {{{"a", asAddr}, {"v", asAddr}}},
+ false},
+ {"atomicaddvector_r2x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genAtomicAddVector<2>),
+ {{{"a", asAddr}, {"v", asAddr}}},
+ false},
+ {"atomicaddvector_r4x2",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genAtomicAddVector<2>),
+ {{{"a", asAddr}, {"v", asAddr}}},
+ false},
+ {"atomicandi",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicAnd),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomiccasd",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicCas),
+ {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
+ false},
+ {"atomiccasf",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicCas),
+ {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
+ false},
+ {"atomiccasi",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicCas),
+ {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
+ false},
+ {"atomiccasul",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicCas),
+ {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
+ false},
+ {"atomicdeci",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicDec),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicexchd",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicExch),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicexchf",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicExch),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicexchi",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicExch),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicexchul",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicExch),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicinci",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicInc),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicmaxd",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicMax),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicmaxf",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicMax),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicmaxi",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicMax),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicmaxl",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicMax),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicmind",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicMin),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicminf",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicMin),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicmini",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicMin),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicminl",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicMin),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicori",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicOr),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicsubd",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicSub),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicsubf",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicSub),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicsubi",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicSub),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicsubl",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genAtomicSub),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"atomicxori",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(&CI::genAtomicXor),
+ {{{"a", asAddr}, {"v", asValue}}},
+ false},
+ {"ballot_sync",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genVoteSync<mlir::NVVM::VoteSyncKind::ballot>),
+ {{{"mask", asValue}, {"pred", asValue}}},
+ /*isElemental=*/false},
+ {"barrier_arrive",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genBarrierArrive),
+ {{{"barrier", asAddr}}},
+ /*isElemental=*/false},
+ {"barrier_arrive_cnt",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genBarrierArriveCnt),
+ {{{"barrier", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
+ {"barrier_init",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genBarrierInit),
+ {{{"barrier", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
+ {"barrier_try_wait",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genBarrierTryWait),
+ {{{"barrier", asAddr}, {"token", asValue}}},
+ /*isElemental=*/false},
+ {"barrier_try_wait_sleep",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genBarrierTryWaitSleep),
+ {{{"barrier", asAddr}, {"token", asValue}, {"ns", asValue}}},
+ /*isElemental=*/false},
+ {"clock",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genNVVMTime<mlir::NVVM::ClockOp>),
+ {},
+ /*isElemental=*/false},
+ {"clock64",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genNVVMTime<mlir::NVVM::Clock64Op>),
+ {},
+ /*isElemental=*/false},
+ {"cluster_block_index",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genClusterBlockIndex),
+ {},
+ /*isElemental=*/false},
+ {"cluster_dim_blocks",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genClusterDimBlocks),
+ {},
+ /*isElemental=*/false},
+ {"fence_proxy_async",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genFenceProxyAsync),
+ {},
+ /*isElemental=*/false},
+ {"globaltimer",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genNVVMTime<mlir::NVVM::GlobalTimerOp>),
+ {},
+ /*isElemental=*/false},
+ {"match_all_syncjd",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genMatchAllSync),
+ {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
+ /*isElemental=*/false},
+ {"match_all_syncjf",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genMatchAllSync),
+ {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
+ /*isElemental=*/false},
+ {"match_all_syncjj",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genMatchAllSync),
+ {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
+ /*isElemental=*/false},
+ {"match_all_syncjx",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genMatchAllSync),
+ {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
+ /*isElemental=*/false},
+ {"match_any_syncjd",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genMatchAnySync),
+ {{{"mask", asValue}, {"value", asValue}}},
+ /*isElemental=*/false},
+ {"match_any_syncjf",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genMatchAnySync),
+ {{{"mask", asValue}, {"value", asValue}}},
+ /*isElemental=*/false},
+ {"match_any_syncjj",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genMatchAnySync),
+ {{{"mask", asValue}, {"value", asValue}}},
+ /*isElemental=*/false},
+ {"match_any_syncjx",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genMatchAnySync),
+ {{{"mask", asValue}, {"value", asValue}}},
+ /*isElemental=*/false},
+ {"syncthreads",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genSyncThreads),
+ {},
+ /*isElemental=*/false},
+ {"syncthreads_and_i4",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genSyncThreadsAnd),
+ {},
+ /*isElemental=*/false},
+ {"syncthreads_and_l4",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genSyncThreadsAnd),
+ {},
+ /*isElemental=*/false},
+ {"syncthreads_count_i4",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genSyncThreadsCount),
+ {},
+ /*isElemental=*/false},
+ {"syncthreads_count_l4",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genSyncThreadsCount),
+ {},
+ /*isElemental=*/false},
+ {"syncthreads_or_i4",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genSyncThreadsOr),
+ {},
+ /*isElemental=*/false},
+ {"syncthreads_or_l4",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genSyncThreadsOr),
+ {},
+ /*isElemental=*/false},
+ {"syncwarp",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(&CI::genSyncWarp),
+ {},
+ /*isElemental=*/false},
+ {"this_cluster",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genThisCluster),
+ {},
+ /*isElemental=*/false},
+ {"this_grid",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genThisGrid),
+ {},
+ /*isElemental=*/false},
+ {"this_thread_block",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genThisThreadBlock),
+ {},
+ /*isElemental=*/false},
+ {"this_warp",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(&CI::genThisWarp),
+ {},
+ /*isElemental=*/false},
+ {"threadfence",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genThreadFence<mlir::NVVM::MemScopeKind::GPU>),
+ {},
+ /*isElemental=*/false},
+ {"threadfence_block",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genThreadFence<mlir::NVVM::MemScopeKind::CTA>),
+ {},
+ /*isElemental=*/false},
+ {"threadfence_system",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genThreadFence<mlir::NVVM::MemScopeKind::SYS>),
+ {},
+ /*isElemental=*/false},
+ {"tma_bulk_commit_group",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genTMABulkCommitGroup),
+ {{}},
+ /*isElemental=*/false},
+ {"tma_bulk_g2s",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(&CI::genTMABulkG2S),
+ {{{"barrier", asAddr},
+ {"src", asAddr},
+ {"dst", asAddr},
+ {"nbytes", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_ldc4",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genTMABulkLoadC4),
+ {{{"barrier", asAddr},
+ {"src", asAddr},
+ {"dst", asAddr},
+ {"nelems", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_ldc8",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genTMABulkLoadC8),
+ {{{"barrier", asAddr},
+ {"src", asAddr},
+ {"dst", asAddr},
+ {"nelems", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_ldi4",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genTMABulkLoadI4),
+ {{{"barrier", asAddr},
+ {"src", asAddr},
+ {"dst", asAddr},
+ {"nelems", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_ldi8",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genTMABulkLoadI8),
+ {{{"barrier", asAddr},
+ {"src", asAddr},
+ {"dst", asAddr},
+ {"nelems", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_ldr2",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genTMABulkLoadR2),
+ {{{"barrier", asAddr},
+ {"src", asAddr},
+ {"dst", asAddr},
+ {"nelems", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_ldr4",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genTMABulkLoadR4),
+ {{{"barrier", asAddr},
+ {"src", asAddr},
+ {"dst", asAddr},
+ {"nelems", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_ldr8",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genTMABulkLoadR8),
+ {{{"barrier", asAddr},
+ {"src", asAddr},
+ {"dst", asAddr},
+ {"nelems", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_s2g",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(&CI::genTMABulkS2G),
+ {{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_store_c4",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genTMABulkStoreC4),
+ {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_store_c8",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genTMABulkStoreC8),
+ {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_store_i4",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genTMABulkStoreI4),
+ {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_store_i8",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genTMABulkStoreI8),
+ {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_store_r2",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genTMABulkStoreR2),
+ {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_store_r4",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genTMABulkStoreR4),
+ {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_store_r8",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genTMABulkStoreR8),
+ {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_wait_group",
+ static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
+ &CI::genTMABulkWaitGroup),
+ {{}},
+ /*isElemental=*/false},
+};
+
+// Compile-time check that a handler table is sorted by strictly ascending
+// handler name, so binary search (llvm::lower_bound) is valid on it.
+template <std::size_t N>
+static constexpr bool isSorted(const IntrinsicHandler (&array)[N]) {
+  // Replace by std::is_sorted when C++20 is default (it becomes constexpr
+  // in C++20).
+  const IntrinsicHandler *lastSeen{nullptr};
+  bool isSorted{true};
+  for (const auto &x : array) {
+    if (lastSeen)
+      isSorted &= std::string_view{lastSeen->name} < std::string_view{x.name};
+    lastSeen = &x;
+  }
+  return isSorted;
+}
+static_assert(isSorted(cudaHandlers) && "map must be sorted");
+
+// Look up a CUDA intrinsic handler by name in the sorted handler table.
+// Returns nullptr when the name has no dedicated CUDA handler.
+const IntrinsicHandler *findCUDAIntrinsicHandler(llvm::StringRef name) {
+  const IntrinsicHandler *candidate = llvm::lower_bound(
+      cudaHandlers, name,
+      [](const IntrinsicHandler &handler, llvm::StringRef target) {
+        return target.compare(handler.name) > 0;
+      });
+  if (candidate == std::end(cudaHandlers) || candidate->name != name)
+    return nullptr;
+  return candidate;
+}
+
+// Reinterpret a FIR reference as a generic LLVM pointer, then cast it into
+// the requested NVVM address space.
+static mlir::Value convertPtrToNVVMSpace(fir::FirOpBuilder &builder,
+                                         mlir::Location loc,
+                                         mlir::Value barrier,
+                                         mlir::NVVM::NVVMMemorySpace space) {
+  mlir::MLIRContext *ctx = builder.getContext();
+  auto genericPtrTy = mlir::LLVM::LLVMPointerType::get(ctx);
+  auto spacedPtrTy =
+      mlir::LLVM::LLVMPointerType::get(ctx, static_cast<unsigned>(space));
+  mlir::Value genericPtr =
+      fir::ConvertOp::create(builder, loc, genericPtrTy, barrier);
+  return mlir::LLVM::AddrSpaceCastOp::create(builder, loc, spacedPtrTy,
+                                             genericPtr);
+}
+
+// Emit a seq_cst llvm.atomicrmw of kind `binOp` on the address `arg0` with
+// operand `arg1`. The address is first converted to a generic LLVM pointer.
+// Note: `loc` is taken by value — mlir::Location is a cheap value type and
+// passing by non-const reference needlessly forbids rvalue/const callers.
+static mlir::Value genAtomBinOp(fir::FirOpBuilder &builder, mlir::Location loc,
+                                mlir::LLVM::AtomicBinOp binOp, mlir::Value arg0,
+                                mlir::Value arg1) {
+  auto llvmPointerType = mlir::LLVM::LLVMPointerType::get(builder.getContext());
+  arg0 = builder.createConvert(loc, llvmPointerType, arg0);
+  return mlir::LLVM::AtomicRMWOp::create(builder, loc, binOp, arg0, arg1,
+                                         mlir::LLVM::AtomicOrdering::seq_cst);
+}
+
+// ATOMICADD
+mlir::Value
+CUDAIntrinsicLibrary::genAtomicAdd(mlir::Type resultType,
+                                   llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  // Integer operands use the integer RMW add; otherwise use fadd.
+  if (mlir::isa<mlir::IntegerType>(args[1].getType()))
+    return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::add, args[0],
+                        args[1]);
+  return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::fadd, args[0],
+                      args[1]);
+}
+
+// ATOMICADD on a pair of REAL(2) values: load two f16 components, pack
+// them into a <2 x f16> vector, perform one vector atomic fadd, and return
+// the result bitcast to a single i32.
+fir::ExtendedValue
+CUDAIntrinsicLibrary::genAtomicAddR2(mlir::Type resultType,
+                                     llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 2);
+
+  mlir::Value a = fir::getBase(args[0]);
+
+  if (mlir::isa<fir::BaseBoxType>(a.getType())) {
+    a = fir::BoxAddrOp::create(builder, loc, a);
+  }
+
+  // Use the member `loc` throughout: a local `loc{builder.getUnknownLoc()}`
+  // previously shadowed it here, dropping source locations from the IR.
+  auto f16Ty = builder.getF16Type();
+  auto i32Ty = builder.getI32Type();
+  auto vecF16Ty = mlir::VectorType::get({2}, f16Ty);
+  mlir::Type idxTy = builder.getIndexType();
+  auto f16RefTy = fir::ReferenceType::get(f16Ty);
+  auto zero = builder.createIntegerConstant(loc, idxTy, 0);
+  auto one = builder.createIntegerConstant(loc, idxTy, 1);
+  // Load the two f16 components of args[1].
+  auto v1Coord = fir::CoordinateOp::create(builder, loc, f16RefTy,
+                                           fir::getBase(args[1]), zero);
+  auto v2Coord = fir::CoordinateOp::create(builder, loc, f16RefTy,
+                                           fir::getBase(args[1]), one);
+  auto v1 = fir::LoadOp::create(builder, loc, v1Coord);
+  auto v2 = fir::LoadOp::create(builder, loc, v2Coord);
+  mlir::Value undef = mlir::LLVM::UndefOp::create(builder, loc, vecF16Ty);
+  mlir::Value vec1 = mlir::LLVM::InsertElementOp::create(
+      builder, loc, undef, v1, builder.createIntegerConstant(loc, i32Ty, 0));
+  mlir::Value vec2 = mlir::LLVM::InsertElementOp::create(
+      builder, loc, vec1, v2, builder.createIntegerConstant(loc, i32Ty, 1));
+  auto res = genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::fadd, a, vec2);
+  // Reinterpret the <2 x f16> result as one 32-bit value.
+  auto i32VecTy = mlir::VectorType::get({1}, i32Ty);
+  mlir::Value vecI32 =
+      mlir::vector::BitCastOp::create(builder, loc, i32VecTy, res);
+  return mlir::vector::ExtractOp::create(builder, loc, vecI32,
+                                         mlir::ArrayRef<int64_t>{0});
+}
+
+// ATOMICADDVECTOR
+// Vector flavor of ATOMICADD: loads `extent` elements from args[1], packs
+// them into a vector, performs a single vector atomic fadd on the address
+// args[0], and returns the previous values as a stack array.
+template <int extent>
+fir::ExtendedValue CUDAIntrinsicLibrary::genAtomicAddVector(
+    mlir::Type resultType, llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 2);
+  mlir::Value res = fir::AllocaOp::create(
+      builder, loc, fir::SequenceType::get({extent}, resultType));
+  mlir::Value a = fir::getBase(args[0]);
+  // If the target is boxed, operate on its raw address.
+  if (mlir::isa<fir::BaseBoxType>(a.getType())) {
+    a = fir::BoxAddrOp::create(builder, loc, a);
+  }
+  auto vecTy = mlir::VectorType::get({extent}, resultType);
+  auto refTy = fir::ReferenceType::get(resultType);
+  mlir::Type i32Ty = builder.getI32Type();
+  mlir::Type idxTy = builder.getIndexType();
+
+  // Extract the values from the array.
+  llvm::SmallVector<mlir::Value> values;
+  for (unsigned i = 0; i < extent; ++i) {
+    mlir::Value pos = builder.createIntegerConstant(loc, idxTy, i);
+    mlir::Value coord = fir::CoordinateOp::create(builder, loc, refTy,
+                                                  fir::getBase(args[1]), pos);
+    mlir::Value value = fir::LoadOp::create(builder, loc, coord);
+    values.push_back(value);
+  }
+  // Pack extracted values into a vector to call the atomic add.
+  mlir::Value undef = mlir::LLVM::UndefOp::create(builder, loc, vecTy);
+  for (unsigned i = 0; i < extent; ++i) {
+    mlir::Value insert = mlir::LLVM::InsertElementOp::create(
+        builder, loc, undef, values[i],
+        builder.createIntegerConstant(loc, i32Ty, i));
+    undef = insert;
+  }
+  // Atomic operation with a vector of values.
+  mlir::Value add =
+      genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::fadd, a, undef);
+  // Store results in the result array.
+  for (unsigned i = 0; i < extent; ++i) {
+    mlir::Value r = mlir::LLVM::ExtractElementOp::create(
+        builder, loc, add, builder.createIntegerConstant(loc, i32Ty, i));
+    mlir::Value c = fir::CoordinateOp::create(
+        builder, loc, refTy, res, builder.createIntegerConstant(loc, idxTy, i));
+    fir::StoreOp::create(builder, loc, r, c);
+  }
+  mlir::Value ext = builder.createIntegerConstant(loc, idxTy, extent);
+  return fir::ArrayBoxValue(res, {ext});
+}
+
+// ATOMICADDVECTOR4x4
+// Four-element f32 atomic add emitted as a single inline `atom.add.v4.f32`
+// PTX instruction; the four returned scalars are repacked into a vector
+// and returned as one 128-bit value.
+fir::ExtendedValue CUDAIntrinsicLibrary::genAtomicAddVector4x4(
+    mlir::Type resultType, llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 2);
+  mlir::Value a = fir::getBase(args[0]);
+  if (mlir::isa<fir::BaseBoxType>(a.getType()))
+    a = fir::BoxAddrOp::create(builder, loc, a);
+
+  const unsigned extent = 4;
+  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
+  mlir::Value ptr = builder.createConvert(loc, llvmPtrTy, a);
+  mlir::Type f32Ty = builder.getF32Type();
+  mlir::Type idxTy = builder.getIndexType();
+  mlir::Type refTy = fir::ReferenceType::get(f32Ty);
+  // Load the four f32 addends from args[1].
+  llvm::SmallVector<mlir::Value> values;
+  for (unsigned i = 0; i < extent; ++i) {
+    mlir::Value pos = builder.createIntegerConstant(loc, idxTy, i);
+    mlir::Value coord = fir::CoordinateOp::create(builder, loc, refTy,
+                                                  fir::getBase(args[1]), pos);
+    mlir::Value value = fir::LoadOp::create(builder, loc, coord);
+    values.push_back(value);
+  }
+
+  auto inlinePtx = mlir::NVVM::InlinePtxOp::create(
+      builder, loc, {f32Ty, f32Ty, f32Ty, f32Ty},
+      {ptr, values[0], values[1], values[2], values[3]}, {},
+      "atom.add.v4.f32 {%0, %1, %2, %3}, [%4], {%5, %6, %7, %8};", {});
+
+  llvm::SmallVector<mlir::Value> results;
+  results.push_back(inlinePtx.getResult(0));
+  results.push_back(inlinePtx.getResult(1));
+  results.push_back(inlinePtx.getResult(2));
+  results.push_back(inlinePtx.getResult(3));
+
+  // Pack the four scalar results into <4 x f32>.
+  mlir::Type vecF32Ty = mlir::VectorType::get({extent}, f32Ty);
+  mlir::Value undef = mlir::LLVM::UndefOp::create(builder, loc, vecF32Ty);
+  mlir::Type i32Ty = builder.getI32Type();
+  for (unsigned i = 0; i < extent; ++i)
+    undef = mlir::LLVM::InsertElementOp::create(
+        builder, loc, undef, results[i],
+        builder.createIntegerConstant(loc, i32Ty, i));
+
+  // Reinterpret the <4 x f32> as one 128-bit integer result.
+  auto i128Ty = builder.getIntegerType(128);
+  auto i128VecTy = mlir::VectorType::get({1}, i128Ty);
+  mlir::Value vec128 =
+      mlir::vector::BitCastOp::create(builder, loc, i128VecTy, undef);
+  return mlir::vector::ExtractOp::create(builder, loc, vec128,
+                                         mlir::ArrayRef<int64_t>{0});
+}
+
+// ATOMICAND — integer-only atomic bitwise AND.
+mlir::Value
+CUDAIntrinsicLibrary::genAtomicAnd(mlir::Type resultType,
+                                   llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  assert(mlir::isa<mlir::IntegerType>(args[1].getType()));
+  return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::_and, args[0],
+                      args[1]);
+}
+
+// ATOMICOR — integer-only atomic bitwise OR.
+mlir::Value
+CUDAIntrinsicLibrary::genAtomicOr(mlir::Type resultType,
+                                  llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  assert(mlir::isa<mlir::IntegerType>(args[1].getType()));
+  return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::_or, args[0],
+                      args[1]);
+}
+
+// ATOMICCAS
+// Compare-and-swap: emits llvm.cmpxchg on args[0] with expected value
+// args[1] and replacement args[2]; returns the success flag converted to
+// the Fortran result type.
+fir::ExtendedValue
+CUDAIntrinsicLibrary::genAtomicCas(mlir::Type resultType,
+                                   llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 3);
+  auto successOrdering = mlir::LLVM::AtomicOrdering::acq_rel;
+  auto failureOrdering = mlir::LLVM::AtomicOrdering::monotonic;
+  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(resultType.getContext());
+
+  mlir::Value arg0 = fir::getBase(args[0]);
+  mlir::Value arg1 = fir::getBase(args[1]);
+  mlir::Value arg2 = fir::getBase(args[2]);
+
+  // cmpxchg operates on integers: reinterpret f32/f64 operands as the
+  // same-width integer (bitcast, not value conversion).
+  auto bitCastFloat = [&](mlir::Value arg) -> mlir::Value {
+    if (mlir::isa<mlir::Float32Type>(arg.getType()))
+      return mlir::LLVM::BitcastOp::create(builder, loc, builder.getI32Type(),
+                                           arg);
+    if (mlir::isa<mlir::Float64Type>(arg.getType()))
+      return mlir::LLVM::BitcastOp::create(builder, loc, builder.getI64Type(),
+                                           arg);
+    return arg;
+  };
+
+  arg1 = bitCastFloat(arg1);
+  arg2 = bitCastFloat(arg2);
+
+  if (arg1.getType() != arg2.getType()) {
+    // arg1 and arg2 need to have the same type in AtomicCmpXchgOp.
+    arg2 = builder.createConvert(loc, arg1.getType(), arg2);
+  }
+
+  auto address =
+      mlir::UnrealizedConversionCastOp::create(builder, loc, llvmPtrTy, arg0)
+          .getResult(0);
+  auto cmpxchg = mlir::LLVM::AtomicCmpXchgOp::create(
+      builder, loc, address, arg1, arg2, successOrdering, failureOrdering);
+  // Field 1 of the cmpxchg result struct is the success bit.
+  mlir::Value boolResult =
+      mlir::LLVM::ExtractValueOp::create(builder, loc, cmpxchg, 1);
+  return builder.createConvert(loc, resultType, boolResult);
+}
+
+// ATOMICDEC — unsigned decrement with wrap-around (llvm udec_wrap).
+mlir::Value
+CUDAIntrinsicLibrary::genAtomicDec(mlir::Type resultType,
+                                   llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  assert(mlir::isa<mlir::IntegerType>(args[1].getType()));
+  return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::udec_wrap,
+                      args[0], args[1]);
+}
+
+// ATOMICEXCH — atomic exchange of an int or float value.
+fir::ExtendedValue
+CUDAIntrinsicLibrary::genAtomicExch(mlir::Type resultType,
+                                    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 2);
+  mlir::Value address = fir::getBase(args[0]);
+  mlir::Value value = fir::getBase(args[1]);
+  assert(value.getType().isIntOrFloat());
+  return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::xchg, address,
+                      value);
+}
+
+// ATOMICINC — unsigned increment with wrap-around (llvm uinc_wrap).
+mlir::Value
+CUDAIntrinsicLibrary::genAtomicInc(mlir::Type resultType,
+                                   llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  assert(mlir::isa<mlir::IntegerType>(args[1].getType()));
+  return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::uinc_wrap,
+                      args[0], args[1]);
+}
+
+// ATOMICMAX — signed integer max for integers, floating-point max otherwise.
+mlir::Value
+CUDAIntrinsicLibrary::genAtomicMax(mlir::Type resultType,
+                                   llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  if (mlir::isa<mlir::IntegerType>(args[1].getType()))
+    return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::max, args[0],
+                        args[1]);
+  return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::fmax, args[0],
+                      args[1]);
+}
+
+// ATOMICMIN — signed integer min for integers, floating-point min otherwise.
+mlir::Value
+CUDAIntrinsicLibrary::genAtomicMin(mlir::Type resultType,
+                                   llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  if (mlir::isa<mlir::IntegerType>(args[1].getType()))
+    return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::min, args[0],
+                        args[1]);
+  return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::fmin, args[0],
+                      args[1]);
+}
+
+// ATOMICSUB — integer sub for integers, floating-point fsub otherwise.
+mlir::Value
+CUDAIntrinsicLibrary::genAtomicSub(mlir::Type resultType,
+                                   llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  if (mlir::isa<mlir::IntegerType>(args[1].getType()))
+    return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::sub, args[0],
+                        args[1]);
+  return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::fsub, args[0],
+                      args[1]);
+}
+
+// ATOMICXOR — atomic bitwise XOR on args[0] with args[1].
+fir::ExtendedValue
+CUDAIntrinsicLibrary::genAtomicXor(mlir::Type resultType,
+                                   llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 2);
+  mlir::Value address = fir::getBase(args[0]);
+  mlir::Value value = fir::getBase(args[1]);
+  return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::_xor, address,
+                      value);
+}
+
+// BARRIER_ARRIVE
+// Arrive on an mbarrier that lives in CTA shared memory; returns the
+// arrival result.
+mlir::Value
+CUDAIntrinsicLibrary::genBarrierArrive(mlir::Type resultType,
+                                       llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 1);
+  mlir::Value sharedBarrier = convertPtrToNVVMSpace(
+      builder, loc, args[0], mlir::NVVM::NVVMMemorySpace::Shared);
+  auto arriveOp = mlir::NVVM::MBarrierArriveOp::create(builder, loc,
+                                                       resultType,
+                                                       sharedBarrier);
+  return arriveOp.getResult(0);
+}
+
+// BARRIER_ARRIVE_CNT
+// Arrive on a shared-memory mbarrier while registering an expected
+// transaction count (args[1]) via inline PTX mbarrier.arrive.expect_tx;
+// returns the instruction's result.
+mlir::Value
+CUDAIntrinsicLibrary::genBarrierArriveCnt(mlir::Type resultType,
+                                          llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  mlir::Value barrier = convertPtrToNVVMSpace(
+      builder, loc, args[0], mlir::NVVM::NVVMMemorySpace::Shared);
+  return mlir::NVVM::InlinePtxOp::create(builder, loc, {resultType},
+                                         {barrier, args[1]}, {},
+                                         "mbarrier.arrive.expect_tx.release."
+                                         "cta.shared::cta.b64 %0, [%1], %2;",
+                                         {})
+      .getResult(0);
+}
+
+// BARRIER_INIT
+// Initialize a shared-memory mbarrier with the count in args[1], then
+// issue an async-shared proxy fence so the initialization is ordered with
+// subsequent async proxy operations.
+void CUDAIntrinsicLibrary::genBarrierInit(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 2);
+  mlir::Value barrier = convertPtrToNVVMSpace(
+      builder, loc, fir::getBase(args[0]), mlir::NVVM::NVVMMemorySpace::Shared);
+  mlir::NVVM::MBarrierInitOp::create(builder, loc, barrier,
+                                     fir::getBase(args[1]), {});
+  auto kind = mlir::NVVM::ProxyKindAttr::get(
+      builder.getContext(), mlir::NVVM::ProxyKind::async_shared);
+  auto space = mlir::NVVM::SharedSpaceAttr::get(
+      builder.getContext(), mlir::NVVM::SharedSpace::shared_cta);
+  mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space);
+}
+
+// BARRIER_TRY_WAIT
+// Loop on mbarrier.try_wait (with a 1e6 time-limit operand per attempt)
+// until the wait succeeds; built as an scf.while whose condition re-checks
+// the flag produced by the inline PTX in the loop body.
+mlir::Value
+CUDAIntrinsicLibrary::genBarrierTryWait(mlir::Type resultType,
+                                        llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  // Flag temp initialized to 0 ("not complete") seeds the loop.
+  mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
+  mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0);
+  fir::StoreOp::create(builder, loc, zero, res);
+  mlir::Value ns =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 1000000);
+  mlir::Value load = fir::LoadOp::create(builder, loc, res);
+  auto whileOp = mlir::scf::WhileOp::create(
+      builder, loc, mlir::TypeRange{resultType}, mlir::ValueRange{load});
+  // Before-region: continue while the flag is still 0.
+  mlir::Block *beforeBlock = builder.createBlock(&whileOp.getBefore());
+  mlir::Value beforeArg = beforeBlock->addArgument(resultType, loc);
+  builder.setInsertionPointToStart(beforeBlock);
+  mlir::Value condition = mlir::arith::CmpIOp::create(
+      builder, loc, mlir::arith::CmpIPredicate::eq, beforeArg, zero);
+  mlir::scf::ConditionOp::create(builder, loc, condition, beforeArg);
+  // After-region: one try_wait attempt, yielding its success flag.
+  mlir::Block *afterBlock = builder.createBlock(&whileOp.getAfter());
+  afterBlock->addArgument(resultType, loc);
+  builder.setInsertionPointToStart(afterBlock);
+  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
+  auto barrier = builder.createConvert(loc, llvmPtrTy, args[0]);
+  mlir::Value ret = mlir::NVVM::InlinePtxOp::create(
+      builder, loc, {resultType}, {barrier, args[1], ns}, {},
+      "{\n"
+      " .reg .pred p;\n"
+      " mbarrier.try_wait.shared.b64 p, [%1], %2, %3;\n"
+      " selp.b32 %0, 1, 0, p;\n"
+      "}",
+      {})
+      .getResult(0);
+  mlir::scf::YieldOp::create(builder, loc, ret);
+  builder.setInsertionPointAfter(whileOp);
+  return whileOp.getResult(0);
+}
+
+// BARRIER_TRY_WAIT_SLEEP
+// Single mbarrier.try_wait attempt with a caller-provided time-limit
+// operand (args[2]); yields 1 on success, 0 otherwise.
+mlir::Value
+CUDAIntrinsicLibrary::genBarrierTryWaitSleep(mlir::Type resultType,
+                                             llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 3);
+  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
+  auto barrier = builder.createConvert(loc, llvmPtrTy, args[0]);
+  return mlir::NVVM::InlinePtxOp::create(
+             builder, loc, {resultType}, {barrier, args[1], args[2]}, {},
+             "{\n"
+             " .reg .pred p;\n"
+             " mbarrier.try_wait.shared.b64 p, [%1], %2, %3;\n"
+             " selp.b32 %0, 1, 0, p;\n"
+             "}",
+             {})
+      .getResult(0);
+}
+
+// Store `dim` into field number `fieldPos` of the record value `base`
+// (emits fir.field_index + fir.coordinate_of + fir.store).
+static void insertValueAtPos(fir::FirOpBuilder &builder, mlir::Location loc,
+                             fir::RecordType recTy, mlir::Value base,
+                             mlir::Value dim, unsigned fieldPos) {
+  auto [fieldName, fieldTy] = recTy.getTypeList()[fieldPos];
+  mlir::Value fieldIndex = fir::FieldIndexOp::create(
+      builder, loc, fir::FieldType::get(base.getContext()), fieldName, recTy,
+      /*typeParams=*/mlir::ValueRange{});
+  mlir::Value fieldRef = fir::CoordinateOp::create(
+      builder, loc, builder.getRefType(fieldTy), base, fieldIndex);
+  fir::StoreOp::create(builder, loc, dim, fieldRef);
+}
+
+// CLUSTER_BLOCK_INDEX
+// Fill the three fields of the derived-type result with the 1-based
+// (x, y, z) index of this block within its cluster.
+mlir::Value
+CUDAIntrinsicLibrary::genClusterBlockIndex(mlir::Type resultType,
+                                           llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 0);
+  auto recTy = mlir::cast<fir::RecordType>(resultType);
+  assert(recTy && "RecordType expected");
+  mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
+  mlir::Type i32Ty = builder.getI32Type();
+  // NVVM indices are 0-based; the Fortran result is 1-based.
+  mlir::Value x = mlir::NVVM::BlockInClusterIdXOp::create(builder, loc, i32Ty);
+  mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
+  x = mlir::arith::AddIOp::create(builder, loc, x, one);
+  insertValueAtPos(builder, loc, recTy, res, x, 0);
+  mlir::Value y = mlir::NVVM::BlockInClusterIdYOp::create(builder, loc, i32Ty);
+  y = mlir::arith::AddIOp::create(builder, loc, y, one);
+  insertValueAtPos(builder, loc, recTy, res, y, 1);
+  mlir::Value z = mlir::NVVM::BlockInClusterIdZOp::create(builder, loc, i32Ty);
+  z = mlir::arith::AddIOp::create(builder, loc, z, one);
+  insertValueAtPos(builder, loc, recTy, res, z, 2);
+  return res;
+}
+
+// CLUSTER_DIM_BLOCKS
+// Fill the three fields of the derived-type result with the cluster's
+// block dimensions along x, y, and z.
+mlir::Value
+CUDAIntrinsicLibrary::genClusterDimBlocks(mlir::Type resultType,
+                                          llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 0);
+  auto recTy = mlir::cast<fir::RecordType>(resultType);
+  assert(recTy && "RecordType expected");
+  mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
+  mlir::Type i32Ty = builder.getI32Type();
+  mlir::Value x = mlir::NVVM::ClusterDimBlocksXOp::create(builder, loc, i32Ty);
+  insertValueAtPos(builder, loc, recTy, res, x, 0);
+  mlir::Value y = mlir::NVVM::ClusterDimBlocksYOp::create(builder, loc, i32Ty);
+  insertValueAtPos(builder, loc, recTy, res, y, 1);
+  mlir::Value z = mlir::NVVM::ClusterDimBlocksZOp::create(builder, loc, i32Ty);
+  insertValueAtPos(builder, loc, recTy, res, z, 2);
+  return res;
+}
+
+// FENCE_PROXY_ASYNC
+// Emit an async-shared proxy fence scoped to CTA shared memory.
+void CUDAIntrinsicLibrary::genFenceProxyAsync(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 0);
+  mlir::MLIRContext *ctx = builder.getContext();
+  auto kindAttr =
+      mlir::NVVM::ProxyKindAttr::get(ctx, mlir::NVVM::ProxyKind::async_shared);
+  auto spaceAttr = mlir::NVVM::SharedSpaceAttr::get(
+      ctx, mlir::NVVM::SharedSpace::shared_cta);
+  mlir::NVVM::FenceProxyOp::create(builder, loc, kindAttr, spaceAttr);
+}
+
+// __LDCA, __LDCS, __LDLU, __LDCV
+// Lower a cache-hint load intrinsic to a call of the function named
+// `fctName`, passing the stack result buffer (`extent` elements) and the
+// source address by reference; returns the buffer as an array value.
+template <const char *fctName, int extent>
+fir::ExtendedValue
+CUDAIntrinsicLibrary::genLDXXFunc(mlir::Type resultType,
+                                  llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 1);
+  mlir::Type resTy = fir::SequenceType::get(extent, resultType);
+  mlir::Value arg = fir::getBase(args[0]);
+  mlir::Value res = fir::AllocaOp::create(builder, loc, resTy);
+  // If the argument is boxed, extract the raw address first.
+  if (mlir::isa<fir::BaseBoxType>(arg.getType()))
+    arg = fir::BoxAddrOp::create(builder, loc, arg);
+  mlir::Type refResTy = fir::ReferenceType::get(resTy);
+  mlir::FunctionType ftype =
+      mlir::FunctionType::get(arg.getContext(), {refResTy, refResTy}, {});
+  auto funcOp = builder.createFunction(loc, fctName, ftype);
+  llvm::SmallVector<mlir::Value> funcArgs;
+  funcArgs.push_back(res);
+  funcArgs.push_back(arg);
+  fir::CallOp::create(builder, loc, funcOp, funcArgs);
+  mlir::Value ext =
+      builder.createIntegerConstant(loc, builder.getIndexType(), extent);
+  return fir::ArrayBoxValue(res, {ext});
+}
+
+// CLOCK, CLOCK64, GLOBALTIMER
+// Shared lowering for the zero-argument NVVM timer reads; OpTy selects
+// which counter op is emitted.
+template <typename OpTy>
+mlir::Value
+CUDAIntrinsicLibrary::genNVVMTime(mlir::Type resultType,
+                                  llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 0 && "expect no arguments");
+  auto timeOp = OpTy::create(builder, loc, resultType);
+  return timeOp.getResult();
+}
+
+// MATCH_ALL_SYNC
+// Emit nvvm match.sync (kind all): returns the match value and stores the
+// predicate, zero-extended to the result type, through args[2].
+mlir::Value
+CUDAIntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
+                                      llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 3);
+  bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
+
+  mlir::Type i1Ty = builder.getI1Type();
+  mlir::MLIRContext *context = builder.getContext();
+
+  // f32/f64 operands are converted to a same-width integer before matching.
+  mlir::Value arg1 = args[1];
+  if (arg1.getType().isF32() || arg1.getType().isF64())
+    arg1 = fir::ConvertOp::create(
+        builder, loc, is32 ? builder.getI32Type() : builder.getI64Type(), arg1);
+
+  // match.sync all returns a {value, predicate} struct.
+  mlir::Type retTy =
+      mlir::LLVM::LLVMStructType::getLiteral(context, {resultType, i1Ty});
+  auto match =
+      mlir::NVVM::MatchSyncOp::create(builder, loc, retTy, args[0], arg1,
+                                      mlir::NVVM::MatchSyncKind::all)
+          .getResult();
+  auto value = mlir::LLVM::ExtractValueOp::create(builder, loc, match, 0);
+  auto pred = mlir::LLVM::ExtractValueOp::create(builder, loc, match, 1);
+  auto conv = mlir::LLVM::ZExtOp::create(builder, loc, resultType, pred);
+  fir::StoreOp::create(builder, loc, conv, args[2]);
+  return value;
+}
+
+// MATCH_ANY_SYNC
+// Emit nvvm match.sync (kind any) and return its result directly.
+mlir::Value
+CUDAIntrinsicLibrary::genMatchAnySync(mlir::Type resultType,
+                                      llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
+
+  // f32/f64 operands are converted to a same-width integer before matching.
+  mlir::Value arg1 = args[1];
+  if (arg1.getType().isF32() || arg1.getType().isF64())
+    arg1 = fir::ConvertOp::create(
+        builder, loc, is32 ? builder.getI32Type() : builder.getI64Type(), arg1);
+
+  return mlir::NVVM::MatchSyncOp::create(builder, loc, resultType, args[0],
+                                         arg1, mlir::NVVM::MatchSyncKind::any)
+      .getResult();
+}
+
+// SYNCTHREADS
+// Lower SYNCTHREADS to nvvm.barrier0.
+void CUDAIntrinsicLibrary::genSyncThreads(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  mlir::NVVM::Barrier0Op::create(builder, loc);
+}
+
+// SYNCTHREADS_AND / SYNCTHREADS_COUNT / SYNCTHREADS_OR
+// Shared lowering: convert the operand to i32 and emit an NVVM barrier
+// carrying the requested reduction kind.
+static mlir::Value genSyncThreadsReduction(fir::FirOpBuilder &builder,
+                                           mlir::Location loc,
+                                           mlir::Type resultType,
+                                           mlir::Value arg,
+                                           mlir::NVVM::BarrierReduction kind) {
+  mlir::Value value = builder.createConvert(loc, builder.getI32Type(), arg);
+  auto barrier = mlir::NVVM::BarrierOp::create(
+      builder, loc, resultType, {}, {},
+      mlir::NVVM::BarrierReductionAttr::get(builder.getContext(), kind),
+      value);
+  return barrier.getResult(0);
+}
+
+// SYNCTHREADS_AND
+mlir::Value
+CUDAIntrinsicLibrary::genSyncThreadsAnd(mlir::Type resultType,
+                                        llvm::ArrayRef<mlir::Value> args) {
+  return genSyncThreadsReduction(builder, loc, resultType, args[0],
+                                 mlir::NVVM::BarrierReduction::AND);
+}
+
+// SYNCTHREADS_COUNT
+mlir::Value
+CUDAIntrinsicLibrary::genSyncThreadsCount(mlir::Type resultType,
+                                          llvm::ArrayRef<mlir::Value> args) {
+  return genSyncThreadsReduction(builder, loc, resultType, args[0],
+                                 mlir::NVVM::BarrierReduction::POPC);
+}
+
+// SYNCTHREADS_OR
+mlir::Value
+CUDAIntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType,
+                                       llvm::ArrayRef<mlir::Value> args) {
+  return genSyncThreadsReduction(builder, loc, resultType, args[0],
+                                 mlir::NVVM::BarrierReduction::OR);
+}
+
+// SYNCWARP
+// Lower SYNCWARP to nvvm.sync.warp with the caller-provided mask (args[0]).
+void CUDAIntrinsicLibrary::genSyncWarp(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 1);
+  mlir::NVVM::SyncWarpOp::create(builder, loc, fir::getBase(args[0]));
+}
+
+// THIS_CLUSTER
+// Build the derived-type result: field 1 receives the cluster dimension
+// (nvvm cluster.dim), field 2 the 1-based cluster id. Uses the shared
+// insertValueAtPos helper for the field stores (same ops as before).
+mlir::Value
+CUDAIntrinsicLibrary::genThisCluster(mlir::Type resultType,
+                                     llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 0);
+  auto recTy = mlir::cast<fir::RecordType>(resultType);
+  assert(recTy && "RecordType expected");
+  mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
+  mlir::Type i32Ty = builder.getI32Type();
+
+  // SIZE
+  mlir::Value size = mlir::NVVM::ClusterDim::create(builder, loc, i32Ty);
+  insertValueAtPos(builder, loc, recTy, res, size, 1);
+
+  // RANK: NVVM cluster ids are 0-based; the Fortran rank is 1-based.
+  mlir::Value rank = mlir::NVVM::ClusterId::create(builder, loc, i32Ty);
+  mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
+  rank = mlir::arith::AddIOp::create(builder, loc, rank, one);
+  insertValueAtPos(builder, loc, recTy, res, rank, 2);
+
+  return res;
+}
+
+// THIS_GRID
+// Build the derived-type result: field 1 is the total number of threads in
+// the grid, field 2 the 1-based linearized rank of the current thread.
+// Field stores go through the shared insertValueAtPos helper (same ops).
+mlir::Value
+CUDAIntrinsicLibrary::genThisGrid(mlir::Type resultType,
+                                  llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 0);
+  auto recTy = mlir::cast<fir::RecordType>(resultType);
+  assert(recTy && "RecordType expected");
+  mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
+  mlir::Type i32Ty = builder.getI32Type();
+
+  mlir::Value threadIdX = mlir::NVVM::ThreadIdXOp::create(builder, loc, i32Ty);
+  mlir::Value threadIdY = mlir::NVVM::ThreadIdYOp::create(builder, loc, i32Ty);
+  mlir::Value threadIdZ = mlir::NVVM::ThreadIdZOp::create(builder, loc, i32Ty);
+
+  mlir::Value blockIdX = mlir::NVVM::BlockIdXOp::create(builder, loc, i32Ty);
+  mlir::Value blockIdY = mlir::NVVM::BlockIdYOp::create(builder, loc, i32Ty);
+  mlir::Value blockIdZ = mlir::NVVM::BlockIdZOp::create(builder, loc, i32Ty);
+
+  mlir::Value blockDimX = mlir::NVVM::BlockDimXOp::create(builder, loc, i32Ty);
+  mlir::Value blockDimY = mlir::NVVM::BlockDimYOp::create(builder, loc, i32Ty);
+  mlir::Value blockDimZ = mlir::NVVM::BlockDimZOp::create(builder, loc, i32Ty);
+  mlir::Value gridDimX = mlir::NVVM::GridDimXOp::create(builder, loc, i32Ty);
+  mlir::Value gridDimY = mlir::NVVM::GridDimYOp::create(builder, loc, i32Ty);
+  mlir::Value gridDimZ = mlir::NVVM::GridDimZOp::create(builder, loc, i32Ty);
+
+  // this_grid.size = ((blockDim.z * gridDim.z) * (blockDim.y * gridDim.y)) *
+  //   (blockDim.x * gridDim.x);
+  mlir::Value resZ =
+      mlir::arith::MulIOp::create(builder, loc, blockDimZ, gridDimZ);
+  mlir::Value resY =
+      mlir::arith::MulIOp::create(builder, loc, blockDimY, gridDimY);
+  mlir::Value resX =
+      mlir::arith::MulIOp::create(builder, loc, blockDimX, gridDimX);
+  mlir::Value resZY = mlir::arith::MulIOp::create(builder, loc, resZ, resY);
+  mlir::Value size = mlir::arith::MulIOp::create(builder, loc, resZY, resX);
+
+  // tmp = ((blockIdx.z * gridDim.y * gridDim.x) + (blockIdx.y * gridDim.x)) +
+  //   blockIdx.x;
+  // this_group.rank = tmp * ((blockDim.x * blockDim.y) * blockDim.z) +
+  //   ((threadIdx.z * blockDim.y) * blockDim.x) +
+  //   (threadIdx.y * blockDim.x) + threadIdx.x + 1;
+  mlir::Value r1 =
+      mlir::arith::MulIOp::create(builder, loc, blockIdZ, gridDimY);
+  mlir::Value r2 = mlir::arith::MulIOp::create(builder, loc, r1, gridDimX);
+  mlir::Value r3 =
+      mlir::arith::MulIOp::create(builder, loc, blockIdY, gridDimX);
+  mlir::Value r2r3 = mlir::arith::AddIOp::create(builder, loc, r2, r3);
+  mlir::Value tmp = mlir::arith::AddIOp::create(builder, loc, r2r3, blockIdX);
+
+  mlir::Value bXbY =
+      mlir::arith::MulIOp::create(builder, loc, blockDimX, blockDimY);
+  mlir::Value bXbYbZ =
+      mlir::arith::MulIOp::create(builder, loc, bXbY, blockDimZ);
+  mlir::Value tZbY =
+      mlir::arith::MulIOp::create(builder, loc, threadIdZ, blockDimY);
+  mlir::Value tZbYbX =
+      mlir::arith::MulIOp::create(builder, loc, tZbY, blockDimX);
+  mlir::Value tYbX =
+      mlir::arith::MulIOp::create(builder, loc, threadIdY, blockDimX);
+  mlir::Value rank = mlir::arith::MulIOp::create(builder, loc, tmp, bXbYbZ);
+  rank = mlir::arith::AddIOp::create(builder, loc, rank, tZbYbX);
+  rank = mlir::arith::AddIOp::create(builder, loc, rank, tYbX);
+  rank = mlir::arith::AddIOp::create(builder, loc, rank, threadIdX);
+  mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
+  rank = mlir::arith::AddIOp::create(builder, loc, rank, one);
+
+  // Fields 1 and 2 of the result hold size and rank respectively.
+  insertValueAtPos(builder, loc, recTy, res, size, 1);
+  insertValueAtPos(builder, loc, recTy, res, rank, 2);
+  return res;
+}
+
+// THIS_THREAD_BLOCK
+// Build the derived-type result: field 1 is the number of threads in the
+// block, field 2 the 1-based linearized rank of the current thread within
+// the block. Field stores use the shared insertValueAtPos helper.
+mlir::Value
+CUDAIntrinsicLibrary::genThisThreadBlock(mlir::Type resultType,
+                                         llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 0);
+  auto recTy = mlir::cast<fir::RecordType>(resultType);
+  assert(recTy && "RecordType expected");
+  mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
+  mlir::Type i32Ty = builder.getI32Type();
+
+  // this_thread_block%size = blockDim.z * blockDim.y * blockDim.x;
+  mlir::Value blockDimX = mlir::NVVM::BlockDimXOp::create(builder, loc, i32Ty);
+  mlir::Value blockDimY = mlir::NVVM::BlockDimYOp::create(builder, loc, i32Ty);
+  mlir::Value blockDimZ = mlir::NVVM::BlockDimZOp::create(builder, loc, i32Ty);
+  mlir::Value size =
+      mlir::arith::MulIOp::create(builder, loc, blockDimZ, blockDimY);
+  size = mlir::arith::MulIOp::create(builder, loc, size, blockDimX);
+
+  // this_thread_block%rank = ((threadIdx.z * blockDim.y) * blockDim.x) +
+  //   (threadIdx.y * blockDim.x) + threadIdx.x + 1;
+  mlir::Value threadIdX = mlir::NVVM::ThreadIdXOp::create(builder, loc, i32Ty);
+  mlir::Value threadIdY = mlir::NVVM::ThreadIdYOp::create(builder, loc, i32Ty);
+  mlir::Value threadIdZ = mlir::NVVM::ThreadIdZOp::create(builder, loc, i32Ty);
+  mlir::Value r1 =
+      mlir::arith::MulIOp::create(builder, loc, threadIdZ, blockDimY);
+  mlir::Value r2 = mlir::arith::MulIOp::create(builder, loc, r1, blockDimX);
+  mlir::Value r3 =
+      mlir::arith::MulIOp::create(builder, loc, threadIdY, blockDimX);
+  mlir::Value r2r3 = mlir::arith::AddIOp::create(builder, loc, r2, r3);
+  mlir::Value rank = mlir::arith::AddIOp::create(builder, loc, r2r3, threadIdX);
+  mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
+  rank = mlir::arith::AddIOp::create(builder, loc, rank, one);
+
+  // Fields 1 and 2 of the result hold size and rank respectively.
+  insertValueAtPos(builder, loc, recTy, res, size, 1);
+  insertValueAtPos(builder, loc, recTy, res, rank, 2);
+  return res;
+}
+
+// THIS_WARP
+mlir::Value
+CUDAIntrinsicLibrary::genThisWarp(mlir::Type resultType,
+ llvm::ArrayRef<mlir::Value> args) {
+ assert(args.size() == 0);
+ auto recTy = mlir::cast<fir::RecordType>(resultType);
+ assert(recTy && "RecordType expepected");
+ mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
+ mlir::Type i32Ty = builder.getI32Type();
+
+ // coalesced_group%size = 32
+ mlir::Value size = builder.createIntegerConstant(loc, i32Ty, 32);
+ auto sizeFieldName = recTy.getTypeList()[1].first;
+ mlir::Type sizeFieldTy = recTy.getTypeList()[1].second;
+ mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
+ mlir::Value sizeFieldIndex = fir::FieldIndexOp::create(
+ builder, loc, fieldIndexType, sizeFieldName, recTy,
+ /*typeParams=*/mlir::ValueRange{});
+ mlir::Value sizeCoord = fir::CoordinateOp::create(
+ builder, loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex);
+ fir::StoreOp::create(builder, loc, size, sizeCoord);
+
+ // coalesced_group%rank = threadIdx.x & 31 + 1
+ mlir::Value threadIdX = mlir::NVVM::ThreadIdXOp::create(builder, loc, i32Ty);
+ mlir::Value mask = builder.createIntegerConstant(loc, i32Ty, 31);
+ mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
+ mlir::Value masked =
+ mlir::arith::AndIOp::create(builder, loc, threadIdX, mask);
+ mlir::Value rank = mlir::arith::AddIOp::create(builder, loc, masked, one);
+ auto rankFieldName = recTy.getTypeList()[2].first;
+ mlir::Type rankFieldTy = recTy.getTypeList()[2].second;
+ mlir::Value rankFieldIndex = fir::FieldIndexOp::create(
+ builder, loc, fieldIndexType, rankFieldName, recTy,
+ /*typeParams=*/mlir::ValueRange{});
+ mlir::Value rankCoord = fir::CoordinateOp::create(
+ builder, loc, builder.getRefType(rankFieldTy), res, rankFieldIndex);
+ fir::StoreOp::create(builder, loc, rank, rankCoord);
+ return res;
+}
+
+// THREADFENCE, THREADFENCE_BLOCK, THREADFENCE_SYSTEM
+template <mlir::NVVM::MemScopeKind scope>
+void CUDAIntrinsicLibrary::genThreadFence(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 0);
+ mlir::NVVM::MembarOp::create(builder, loc, scope);
+}
+
+// TMA_BULK_COMMIT_GROUP
+void CUDAIntrinsicLibrary::genTMABulkCommitGroup(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 0);
+ mlir::NVVM::CpAsyncBulkCommitGroupOp::create(builder, loc);
+}
+
// TMA_BULK_G2S
// Lower a TMA bulk copy from global memory to cluster shared memory:
//   args[0]: mbarrier (in shared memory) tracking completion
//   args[1]: source (global memory)
//   args[2]: destination (shared::cluster memory)
//   args[3]: transfer size -- presumably a byte count, matching the NVVM
//            op's size operand; TODO confirm against the CUDA Fortran
//            interface.
void CUDAIntrinsicLibrary::genTMABulkG2S(
    llvm::ArrayRef<fir::ExtendedValue> args) {
  assert(args.size() == 4);
  // Cast each pointer into the NVVM address space the op expects.
  mlir::Value barrier = convertPtrToNVVMSpace(
      builder, loc, fir::getBase(args[0]), mlir::NVVM::NVVMMemorySpace::Shared);
  mlir::Value dst =
      convertPtrToNVVMSpace(builder, loc, fir::getBase(args[2]),
                            mlir::NVVM::NVVMMemorySpace::SharedCluster);
  mlir::Value src = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[1]),
                                          mlir::NVVM::NVVMMemorySpace::Global);
  mlir::NVVM::CpAsyncBulkGlobalToSharedClusterOp::create(
      builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {});
}
+
+static void setAlignment(mlir::Value ptr, unsigned alignment) {
+ if (auto declareOp = mlir::dyn_cast<hlfir::DeclareOp>(ptr.getDefiningOp()))
+ if (auto sharedOp = mlir::dyn_cast<cuf::SharedMemoryOp>(
+ declareOp.getMemref().getDefiningOp()))
+ sharedOp.setAlignment(alignment);
+}
+
// Emit the PTX sequence for a TMA bulk load of `nelem` elements of
// `eleSize` bytes each from global memory `src` into shared memory `dst`,
// with completion tracked by the shared-memory mbarrier `barrier`.
static void genTMABulkLoad(fir::FirOpBuilder &builder, mlir::Location loc,
                           mlir::Value barrier, mlir::Value src,
                           mlir::Value dst, mlir::Value nelem,
                           mlir::Value eleSize) {
  // Total transfer size in bytes.
  mlir::Value size = mlir::arith::MulIOp::create(builder, loc, nelem, eleSize);
  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
  barrier = builder.createConvert(loc, llvmPtrTy, barrier);
  // The TMA engine requires an aligned shared-memory destination; tag the
  // underlying cuf.shared_memory allocation (if any) accordingly.
  setAlignment(dst, kTMAAlignment);
  dst = builder.createConvert(loc, llvmPtrTy, dst);
  src = builder.createConvert(loc, llvmPtrTy, src);
  // Issue the bulk copy, then tell the mbarrier how many bytes to expect.
  mlir::NVVM::InlinePtxOp::create(
      builder, loc, mlir::TypeRange{}, {dst, src, size, barrier}, {},
      "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], "
      "[%1], %2, [%3];",
      {});
  mlir::NVVM::InlinePtxOp::create(
      builder, loc, mlir::TypeRange{}, {barrier, size}, {},
      "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;", {});
}
+
+// TMA_BULK_LOADC4
+void CUDAIntrinsicLibrary::genTMABulkLoadC4(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 4);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 8);
+ genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADC8
+void CUDAIntrinsicLibrary::genTMABulkLoadC8(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 4);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 16);
+ genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADI4
+void CUDAIntrinsicLibrary::genTMABulkLoadI4(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 4);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 4);
+ genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADI8
+void CUDAIntrinsicLibrary::genTMABulkLoadI8(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 4);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 8);
+ genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADR2
+void CUDAIntrinsicLibrary::genTMABulkLoadR2(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 4);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 2);
+ genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADR4
+void CUDAIntrinsicLibrary::genTMABulkLoadR4(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 4);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 4);
+ genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADR8
+void CUDAIntrinsicLibrary::genTMABulkLoadR8(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 4);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 8);
+ genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
// TMA_BULK_S2G
// Lower a TMA bulk copy from shared memory (args[0]) to global memory
// (args[1]) of args[2] bytes (presumably -- the value is passed straight
// through as the NVVM op's size operand; TODO confirm units against the
// CUDA Fortran interface), then commit the copy group and wait for it to
// drain before returning.
void CUDAIntrinsicLibrary::genTMABulkS2G(
    llvm::ArrayRef<fir::ExtendedValue> args) {
  assert(args.size() == 3);
  mlir::Value src = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[0]),
                                          mlir::NVVM::NVVMMemorySpace::Shared);
  mlir::Value dst = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[1]),
                                          mlir::NVVM::NVVMMemorySpace::Global);
  mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create(
      builder, loc, dst, src, fir::getBase(args[2]), {}, {});

  // Commit the copy as a group and wait until 0 groups remain pending.
  mlir::NVVM::InlinePtxOp::create(builder, loc, mlir::TypeRange{}, {}, {},
                                  "cp.async.bulk.commit_group;", {});
  mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc,
                                             builder.getI32IntegerAttr(0), {});
}
+
// Emit a TMA bulk store of `count` elements of `eleSize` bytes each from
// shared memory `src` to global memory `dst`, then commit the copy group
// and wait for completion (synchronous from the caller's point of view).
static void genTMABulkStore(fir::FirOpBuilder &builder, mlir::Location loc,
                            mlir::Value src, mlir::Value dst, mlir::Value count,
                            mlir::Value eleSize) {
  // Total transfer size in bytes.
  mlir::Value size = mlir::arith::MulIOp::create(builder, loc, eleSize, count);
  // The TMA engine requires an aligned shared-memory source; tag the
  // underlying cuf.shared_memory allocation (if any) accordingly.
  setAlignment(src, kTMAAlignment);
  src = convertPtrToNVVMSpace(builder, loc, src,
                              mlir::NVVM::NVVMMemorySpace::Shared);
  dst = convertPtrToNVVMSpace(builder, loc, dst,
                              mlir::NVVM::NVVMMemorySpace::Global);
  mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create(builder, loc, dst, src,
                                                     size, {}, {});
  mlir::NVVM::InlinePtxOp::create(builder, loc, mlir::TypeRange{}, {}, {},
                                  "cp.async.bulk.commit_group;", {});
  mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc,
                                             builder.getI32IntegerAttr(0), {});
}
+
+// TMA_BULK_STORE_C4
+void CUDAIntrinsicLibrary::genTMABulkStoreC4(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 3);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 8);
+ genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_C8
+void CUDAIntrinsicLibrary::genTMABulkStoreC8(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 3);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 16);
+ genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_I4
+void CUDAIntrinsicLibrary::genTMABulkStoreI4(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 3);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 4);
+ genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_I8
+void CUDAIntrinsicLibrary::genTMABulkStoreI8(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 3);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 8);
+ genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_R2
+void CUDAIntrinsicLibrary::genTMABulkStoreR2(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 3);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 2);
+ genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_R4
+void CUDAIntrinsicLibrary::genTMABulkStoreR4(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 3);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 4);
+ genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_R8
+void CUDAIntrinsicLibrary::genTMABulkStoreR8(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 3);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 8);
+ genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_WAIT_GROUP
+void CUDAIntrinsicLibrary::genTMABulkWaitGroup(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 0);
+ auto group = builder.getIntegerAttr(builder.getI32Type(), 0);
+ mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc, group, {});
+}
+
+// ALL_SYNC, ANY_SYNC, BALLOT_SYNC
+template <mlir::NVVM::VoteSyncKind kind>
+mlir::Value
+CUDAIntrinsicLibrary::genVoteSync(mlir::Type resultType,
+ llvm::ArrayRef<mlir::Value> args) {
+ assert(args.size() == 2);
+ mlir::Value arg1 =
+ fir::ConvertOp::create(builder, loc, builder.getI1Type(), args[1]);
+ mlir::Type resTy = kind == mlir::NVVM::VoteSyncKind::ballot
+ ? builder.getI32Type()
+ : builder.getI1Type();
+ auto voteRes =
+ mlir::NVVM::VoteSyncOp::create(builder, loc, resTy, args[0], arg1, kind)
+ .getResult();
+ return fir::ConvertOp::create(builder, loc, resultType, voteRes);
+}
+
+} // namespace fir
diff --git a/flang/lib/Optimizer/Builder/CUFCommon.cpp b/flang/lib/Optimizer/Builder/CUFCommon.cpp
index cf7588f..2266f4d 100644
--- a/flang/lib/Optimizer/Builder/CUFCommon.cpp
+++ b/flang/lib/Optimizer/Builder/CUFCommon.cpp
@@ -9,6 +9,7 @@
#include "flang/Optimizer/Builder/CUFCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
@@ -91,3 +92,66 @@ void cuf::genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder) {
}
}
}
+
+int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type,
+ fir::KindMapping &kindMap,
+ bool emitErrorOnFailure) {
+ auto eleTy = fir::unwrapSequenceType(type);
+ if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
+ return t.getWidth() / 8;
+ if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
+ return t.getWidth() / 8;
+ if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
+ return kindMap.getLogicalBitsize(t.getFKind()) / 8;
+ if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
+ int elemSize =
+ mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
+ return 2 * elemSize;
+ }
+ if (auto t{mlir::dyn_cast<fir::CharacterType>(eleTy)})
+ return kindMap.getCharacterBitsize(t.getFKind()) / 8;
+ if (emitErrorOnFailure)
+ mlir::emitError(loc, "unsupported type");
+ return 0;
+}
+
+mlir::Value cuf::computeElementCount(mlir::PatternRewriter &rewriter,
+ mlir::Location loc,
+ mlir::Value shapeOperand,
+ mlir::Type seqType,
+ mlir::Type targetType) {
+ if (shapeOperand) {
+ // Dynamic extent - extract from shape operand
+ llvm::SmallVector<mlir::Value> extents;
+ if (auto shapeOp =
+ mlir::dyn_cast<fir::ShapeOp>(shapeOperand.getDefiningOp())) {
+ extents = shapeOp.getExtents();
+ } else if (auto shapeShiftOp = mlir::dyn_cast<fir::ShapeShiftOp>(
+ shapeOperand.getDefiningOp())) {
+ for (auto i : llvm::enumerate(shapeShiftOp.getPairs()))
+ if (i.index() & 1)
+ extents.push_back(i.value());
+ }
+
+ if (extents.empty())
+ return mlir::Value();
+
+ // Compute total element count by multiplying all dimensions
+ mlir::Value count =
+ fir::ConvertOp::create(rewriter, loc, targetType, extents[0]);
+ for (unsigned i = 1; i < extents.size(); ++i) {
+ auto operand =
+ fir::ConvertOp::create(rewriter, loc, targetType, extents[i]);
+ count = mlir::arith::MulIOp::create(rewriter, loc, count, operand);
+ }
+ return count;
+ } else {
+ // Static extent - use constant array size
+ if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(seqType)) {
+ mlir::IntegerAttr attr =
+ rewriter.getIntegerAttr(targetType, seqTy.getConstantArraySize());
+ return mlir::arith::ConstantOp::create(rewriter, loc, targetType, attr);
+ }
+ }
+ return mlir::Value();
+}
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 5da27d1..6a9c84f 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Analysis/AliasAnalysis.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/Builder/Complex.h"
@@ -427,7 +428,8 @@ mlir::Value fir::FirOpBuilder::genTempDeclareOp(
builder, loc, memref.getType(), memref, shape, typeParams,
/*dummy_scope=*/nullptr,
/*storage=*/nullptr,
- /*storage_offset=*/0, nameAttr, fortranAttrs, cuf::DataAttributeAttr{});
+ /*storage_offset=*/0, nameAttr, fortranAttrs, cuf::DataAttributeAttr{},
+ /*dummy_arg_no=*/mlir::IntegerAttr{});
}
mlir::Value fir::FirOpBuilder::genStackSave(mlir::Location loc) {
@@ -858,21 +860,32 @@ mlir::Value fir::FirOpBuilder::genIsNullAddr(mlir::Location loc,
mlir::arith::CmpIPredicate::eq);
}
-mlir::Value fir::FirOpBuilder::genExtentFromTriplet(mlir::Location loc,
- mlir::Value lb,
- mlir::Value ub,
- mlir::Value step,
- mlir::Type type) {
+template <typename OpTy, typename... Args>
+static mlir::Value createAndMaybeFold(bool fold, fir::FirOpBuilder &builder,
+ mlir::Location loc, Args &&...args) {
+ if (fold)
+ return builder.createOrFold<OpTy>(loc, std::forward<Args>(args)...);
+ return OpTy::create(builder, loc, std::forward<Args>(args)...);
+}
+
+mlir::Value
+fir::FirOpBuilder::genExtentFromTriplet(mlir::Location loc, mlir::Value lb,
+ mlir::Value ub, mlir::Value step,
+ mlir::Type type, bool fold) {
auto zero = createIntegerConstant(loc, type, 0);
lb = createConvert(loc, type, lb);
ub = createConvert(loc, type, ub);
step = createConvert(loc, type, step);
- auto diff = mlir::arith::SubIOp::create(*this, loc, ub, lb);
- auto add = mlir::arith::AddIOp::create(*this, loc, diff, step);
- auto div = mlir::arith::DivSIOp::create(*this, loc, add, step);
- auto cmp = mlir::arith::CmpIOp::create(
- *this, loc, mlir::arith::CmpIPredicate::sgt, div, zero);
- return mlir::arith::SelectOp::create(*this, loc, cmp, div, zero);
+
+ auto diff = createAndMaybeFold<mlir::arith::SubIOp>(fold, *this, loc, ub, lb);
+ auto add =
+ createAndMaybeFold<mlir::arith::AddIOp>(fold, *this, loc, diff, step);
+ auto div =
+ createAndMaybeFold<mlir::arith::DivSIOp>(fold, *this, loc, add, step);
+ auto cmp = createAndMaybeFold<mlir::arith::CmpIOp>(
+ fold, *this, loc, mlir::arith::CmpIPredicate::sgt, div, zero);
+ return createAndMaybeFold<mlir::arith::SelectOp>(fold, *this, loc, cmp, div,
+ zero);
}
mlir::Value fir::FirOpBuilder::genAbsentOp(mlir::Location loc,
@@ -1392,12 +1405,10 @@ fir::ExtendedValue fir::factory::arraySectionElementToExtendedValue(
return fir::factory::componentToExtendedValue(builder, loc, element);
}
-void fir::factory::genScalarAssignment(fir::FirOpBuilder &builder,
- mlir::Location loc,
- const fir::ExtendedValue &lhs,
- const fir::ExtendedValue &rhs,
- bool needFinalization,
- bool isTemporaryLHS) {
+void fir::factory::genScalarAssignment(
+ fir::FirOpBuilder &builder, mlir::Location loc,
+ const fir::ExtendedValue &lhs, const fir::ExtendedValue &rhs,
+ bool needFinalization, bool isTemporaryLHS, mlir::ArrayAttr accessGroups) {
assert(lhs.rank() == 0 && rhs.rank() == 0 && "must be scalars");
auto type = fir::unwrapSequenceType(
fir::unwrapPassByRefType(fir::getBase(lhs).getType()));
@@ -1419,7 +1430,9 @@ void fir::factory::genScalarAssignment(fir::FirOpBuilder &builder,
mlir::Value lhsAddr = fir::getBase(lhs);
rhsVal = builder.createConvert(loc, fir::unwrapRefType(lhsAddr.getType()),
rhsVal);
- fir::StoreOp::create(builder, loc, rhsVal, lhsAddr);
+ fir::StoreOp store = fir::StoreOp::create(builder, loc, rhsVal, lhsAddr);
+ if (accessGroups)
+ store.setAccessGroupsAttr(accessGroups);
}
}
@@ -1554,8 +1567,15 @@ void fir::factory::genRecordAssignment(fir::FirOpBuilder &builder,
mlir::isa<fir::BaseBoxType>(fir::getBase(rhs).getType());
auto recTy = mlir::dyn_cast<fir::RecordType>(baseTy);
assert(recTy && "must be a record type");
+
+ // Use alias analysis to guard the fast path.
+ fir::AliasAnalysis aa;
+ // Aliased SEQUENCE types must take the conservative (slow) path.
+ bool disjoint = isTemporaryLHS || !recTy.isSequence() ||
+ (aa.alias(fir::getBase(lhs), fir::getBase(rhs)) ==
+ mlir::AliasResult::NoAlias);
if ((needFinalization && mayHaveFinalizer(recTy, builder)) ||
- hasBoxOperands || !recordTypeCanBeMemCopied(recTy)) {
+ hasBoxOperands || !recordTypeCanBeMemCopied(recTy) || !disjoint) {
auto to = fir::getBase(builder.createBox(loc, lhs));
auto from = fir::getBase(builder.createBox(loc, rhs));
// The runtime entry point may modify the LHS descriptor if it is
@@ -1670,6 +1690,26 @@ mlir::Value fir::factory::createZeroValue(fir::FirOpBuilder &builder,
"numeric or logical type");
}
+mlir::Value fir::factory::createOneValue(fir::FirOpBuilder &builder,
+ mlir::Location loc, mlir::Type type) {
+ mlir::Type i1 = builder.getIntegerType(1);
+ if (mlir::isa<fir::LogicalType>(type) || type == i1)
+ return builder.createConvert(loc, type, builder.createBool(loc, true));
+ if (fir::isa_integer(type))
+ return builder.createIntegerConstant(loc, type, 1);
+ if (fir::isa_real(type))
+ return builder.createRealOneConstant(loc, type);
+ if (fir::isa_complex(type)) {
+ fir::factory::Complex complexHelper(builder, loc);
+ mlir::Type partType = complexHelper.getComplexPartType(type);
+ mlir::Value realPart = builder.createRealOneConstant(loc, partType);
+ mlir::Value imagPart = builder.createRealZeroConstant(loc, partType);
+ return complexHelper.createComplex(type, realPart, imagPart);
+ }
+ fir::emitFatalError(loc, "internal: trying to generate one value of non "
+ "numeric or logical type");
+}
+
std::optional<std::int64_t>
fir::factory::getExtentFromTriplet(mlir::Value lb, mlir::Value ub,
mlir::Value stride) {
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 93dfc57..3355bf1 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -250,7 +250,7 @@ hlfir::genDeclare(mlir::Location loc, fir::FirOpBuilder &builder,
const fir::ExtendedValue &exv, llvm::StringRef name,
fir::FortranVariableFlagsAttr flags, mlir::Value dummyScope,
mlir::Value storage, std::uint64_t storageOffset,
- cuf::DataAttributeAttr dataAttr) {
+ cuf::DataAttributeAttr dataAttr, unsigned dummyArgNo) {
mlir::Value base = fir::getBase(exv);
assert(fir::conformsWithPassByRef(base.getType()) &&
@@ -281,7 +281,7 @@ hlfir::genDeclare(mlir::Location loc, fir::FirOpBuilder &builder,
[](const auto &) {});
auto declareOp = hlfir::DeclareOp::create(
builder, loc, base, name, shapeOrShift, lenParams, dummyScope, storage,
- storageOffset, flags, dataAttr);
+ storageOffset, flags, dataAttr, dummyArgNo);
return mlir::cast<fir::FortranVariableOpInterface>(declareOp.getOperation());
}
@@ -402,9 +402,9 @@ hlfir::Entity hlfir::genVariableBox(mlir::Location loc,
fir::BoxType::get(var.getElementOrSequenceType(), isVolatile);
if (forceBoxType) {
boxType = forceBoxType;
- mlir::Type baseType =
- fir::ReferenceType::get(fir::unwrapRefType(forceBoxType.getEleTy()));
- addr = builder.createConvert(loc, baseType, addr);
+ mlir::Type baseType = fir::ReferenceType::get(
+ fir::unwrapRefType(forceBoxType.getEleTy()), forceBoxType.isVolatile());
+ addr = builder.createConvertWithVolatileCast(loc, baseType, addr);
}
auto embox = fir::EmboxOp::create(builder, loc, boxType, addr, shape,
/*slice=*/mlir::Value{}, typeParams);
@@ -1392,6 +1392,79 @@ bool hlfir::elementalOpMustProduceTemp(hlfir::ElementalOp elemental) {
return false;
}
// Assign a scalar value to `lhs`. When `combiner` is non-null the stored
// value is combiner(lhs, rhs) (compound/reduction-style update); otherwise
// it is `rhs` itself. Any fir.load emitted here and the final hlfir.assign
// are tagged with `accessGroups` (when provided) so downstream passes can
// attach loop access-group metadata.
static void combineAndStoreElement(
    mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity lhs,
    hlfir::Entity rhs, bool temporaryLHS,
    std::function<hlfir::Entity(mlir::Location, fir::FirOpBuilder &,
                                hlfir::Entity, hlfir::Entity)> *combiner,
    mlir::ArrayAttr accessGroups) {
  hlfir::Entity valueToAssign = hlfir::loadTrivialScalar(loc, builder, rhs);
  // loadTrivialScalar only emits a load for trivial scalars, so the
  // defining op may not be a fir.load; tag it only when it is.
  if (accessGroups)
    if (auto load = valueToAssign.getDefiningOp<fir::LoadOp>())
      load.setAccessGroupsAttr(accessGroups);
  if (combiner) {
    hlfir::Entity lhsValue = hlfir::loadTrivialScalar(loc, builder, lhs);
    if (accessGroups)
      if (auto load = lhsValue.getDefiningOp<fir::LoadOp>())
        load.setAccessGroupsAttr(accessGroups);
    valueToAssign = (*combiner)(loc, builder, lhsValue, valueToAssign);
  }
  auto assign = hlfir::AssignOp::create(builder, loc, valueToAssign, lhs,
                                        /*realloc=*/false,
                                        /*keep_lhs_length_if_realloc=*/false,
                                        /*temporary_lhs=*/temporaryLHS);
  if (accessGroups)
    assign->setAttr(fir::getAccessGroupsAttrName(), accessGroups);
}
+
// Expand `lhs = rhs` (rhs possibly scalar) into an elemental loop nest,
// assuming lhs and rhs do not alias. The loops are unordered; when
// `emitWorkshareLoop` is set the loop nest is emitted in workshare form.
// `combiner` and `accessGroups` are forwarded to the per-element store
// (see combineAndStoreElement).
void hlfir::genNoAliasArrayAssignment(
    mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity rhs,
    hlfir::Entity lhs, bool emitWorkshareLoop, bool temporaryLHS,
    std::function<hlfir::Entity(mlir::Location, fir::FirOpBuilder &,
                                hlfir::Entity, hlfir::Entity)> *combiner,
    mlir::ArrayAttr accessGroups) {
  // The loop body changes the insertion point; restore it on exit.
  mlir::OpBuilder::InsertionGuard guard(builder);
  rhs = hlfir::derefPointersAndAllocatables(loc, builder, rhs);
  lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs);
  mlir::Value lhsShape = hlfir::genShape(loc, builder, lhs);
  llvm::SmallVector<mlir::Value> extents =
      hlfir::getIndexExtents(loc, builder, lhsShape);
  if (rhs.isArray()) {
    // Prefer compile-time-known extents from either side when available.
    mlir::Value rhsShape = hlfir::genShape(loc, builder, rhs);
    llvm::SmallVector<mlir::Value> rhsExtents =
        hlfir::getIndexExtents(loc, builder, rhsShape);
    extents = fir::factory::deduceOptimalExtents(extents, rhsExtents);
  }
  hlfir::LoopNest loopNest =
      hlfir::genLoopNest(loc, builder, extents,
                         /*isUnordered=*/true, emitWorkshareLoop);
  builder.setInsertionPointToStart(loopNest.body);
  // A scalar rhs is broadcast: getElementAt returns it unchanged.
  auto rhsArrayElement =
      hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);
  rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement);
  auto lhsArrayElement =
      hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
  combineAndStoreElement(loc, builder, lhsArrayElement, rhsArrayElement,
                         temporaryLHS, combiner, accessGroups);
}
+
+void hlfir::genNoAliasAssignment(
+ mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity rhs,
+ hlfir::Entity lhs, bool emitWorkshareLoop, bool temporaryLHS,
+ std::function<hlfir::Entity(mlir::Location, fir::FirOpBuilder &,
+ hlfir::Entity, hlfir::Entity)> *combiner,
+ mlir::ArrayAttr accessGroups) {
+ if (lhs.isArray()) {
+ genNoAliasArrayAssignment(loc, builder, rhs, lhs, emitWorkshareLoop,
+ temporaryLHS, combiner, accessGroups);
+ return;
+ }
+ rhs = hlfir::derefPointersAndAllocatables(loc, builder, rhs);
+ lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs);
+ combineAndStoreElement(loc, builder, lhs, rhs, temporaryLHS, combiner,
+ accessGroups);
+}
+
std::pair<hlfir::Entity, bool>
hlfir::createTempFromMold(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity mold) {
@@ -1624,25 +1697,38 @@ hlfir::genExtentsVector(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity hlfir::gen1DSection(mlir::Location loc,
fir::FirOpBuilder &builder,
hlfir::Entity array, int64_t dim,
- mlir::ArrayRef<mlir::Value> lbounds,
mlir::ArrayRef<mlir::Value> extents,
mlir::ValueRange oneBasedIndices,
mlir::ArrayRef<mlir::Value> typeParams) {
assert(array.isVariable() && "array must be a variable");
assert(dim > 0 && dim <= array.getRank() && "invalid dim number");
+ llvm::SmallVector<mlir::Value> lbounds =
+ getNonDefaultLowerBounds(loc, builder, array);
mlir::Value one =
builder.createIntegerConstant(loc, builder.getIndexType(), 1);
hlfir::DesignateOp::Subscripts subscripts;
unsigned indexId = 0;
for (int i = 0; i < array.getRank(); ++i) {
if (i == dim - 1) {
- mlir::Value ubound = genUBound(loc, builder, lbounds[i], extents[i], one);
- subscripts.emplace_back(
- hlfir::DesignateOp::Triplet{lbounds[i], ubound, one});
+ // (...,:, ..)
+ if (lbounds.empty()) {
+ subscripts.emplace_back(
+ hlfir::DesignateOp::Triplet{one, extents[i], one});
+ } else {
+ mlir::Value ubound =
+ genUBound(loc, builder, lbounds[i], extents[i], one);
+ subscripts.emplace_back(
+ hlfir::DesignateOp::Triplet{lbounds[i], ubound, one});
+ }
} else {
- mlir::Value index =
- genUBound(loc, builder, lbounds[i], oneBasedIndices[indexId++], one);
- subscripts.emplace_back(index);
+ // (...,lb + one_based_index - 1, ..)
+ if (lbounds.empty()) {
+ subscripts.emplace_back(oneBasedIndices[indexId++]);
+ } else {
+ mlir::Value index = genUBound(loc, builder, lbounds[i],
+ oneBasedIndices[indexId++], one);
+ subscripts.emplace_back(index);
+ }
}
}
mlir::Value sectionShape =
@@ -1710,9 +1796,10 @@ bool hlfir::isSimplyContiguous(mlir::Value base, bool checkWhole) {
return false;
return mlir::TypeSwitch<mlir::Operation *, bool>(def)
- .Case<fir::EmboxOp>(
- [&](auto op) { return fir::isContiguousEmbox(op, checkWhole); })
- .Case<fir::ReboxOp>([&](auto op) {
+ .Case([&](fir::EmboxOp op) {
+ return fir::isContiguousEmbox(op, checkWhole);
+ })
+ .Case([&](fir::ReboxOp op) {
hlfir::Entity box{op.getBox()};
return fir::reboxPreservesContinuity(
op, box.mayHaveNonDefaultLowerBounds(), checkWhole) &&
@@ -1721,7 +1808,7 @@ bool hlfir::isSimplyContiguous(mlir::Value base, bool checkWhole) {
.Case<fir::DeclareOp, hlfir::DeclareOp>([&](auto op) {
return isSimplyContiguous(op.getMemref(), checkWhole);
})
- .Case<fir::ConvertOp>(
- [&](auto op) { return isSimplyContiguous(op.getValue()); })
+ .Case(
+ [&](fir::ConvertOp op) { return isSimplyContiguous(op.getValue()); })
.Default([](auto &&) { return false; });
}
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index ec0c802..d3c6739 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -15,7 +15,9 @@
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Common/static-multimap-view.h"
+#include "flang/Lower/AbstractConverter.h"
#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/CUDAIntrinsicCall.h"
#include "flang/Optimizer/Builder/CUFCommon.h"
#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/Builder/Complex.h"
@@ -90,6 +92,11 @@ static bool isStaticallyAbsent(llvm::ArrayRef<mlir::Value> args,
size_t argIndex) {
return args.size() <= argIndex || !args[argIndex];
}
+static bool isOptional(mlir::Value value) {
+ auto varIface = mlir::dyn_cast_or_null<fir::FortranVariableOpInterface>(
+ value.getDefiningOp());
+ return varIface && varIface.isOptional();
+}
/// Test if an ExtendedValue is present. This is used to test if an intrinsic
/// argument is present at compile time. This does not imply that the related
@@ -107,34 +114,6 @@ using I = IntrinsicLibrary;
/// argument is an optional variable in the current scope).
static constexpr bool handleDynamicOptional = true;
-/// TODO: Move all CUDA Fortran intrinsic handlers into its own file similar to
-/// PPC.
-static const char __ldca_i4x4[] = "__ldca_i4x4_";
-static const char __ldca_i8x2[] = "__ldca_i8x2_";
-static const char __ldca_r2x2[] = "__ldca_r2x2_";
-static const char __ldca_r4x4[] = "__ldca_r4x4_";
-static const char __ldca_r8x2[] = "__ldca_r8x2_";
-static const char __ldcg_i4x4[] = "__ldcg_i4x4_";
-static const char __ldcg_i8x2[] = "__ldcg_i8x2_";
-static const char __ldcg_r2x2[] = "__ldcg_r2x2_";
-static const char __ldcg_r4x4[] = "__ldcg_r4x4_";
-static const char __ldcg_r8x2[] = "__ldcg_r8x2_";
-static const char __ldcs_i4x4[] = "__ldcs_i4x4_";
-static const char __ldcs_i8x2[] = "__ldcs_i8x2_";
-static const char __ldcs_r2x2[] = "__ldcs_r2x2_";
-static const char __ldcs_r4x4[] = "__ldcs_r4x4_";
-static const char __ldcs_r8x2[] = "__ldcs_r8x2_";
-static const char __ldcv_i4x4[] = "__ldcv_i4x4_";
-static const char __ldcv_i8x2[] = "__ldcv_i8x2_";
-static const char __ldcv_r2x2[] = "__ldcv_r2x2_";
-static const char __ldcv_r4x4[] = "__ldcv_r4x4_";
-static const char __ldcv_r8x2[] = "__ldcv_r8x2_";
-static const char __ldlu_i4x4[] = "__ldlu_i4x4_";
-static const char __ldlu_i8x2[] = "__ldlu_i8x2_";
-static const char __ldlu_r2x2[] = "__ldlu_r2x2_";
-static const char __ldlu_r4x4[] = "__ldlu_r4x4_";
-static const char __ldlu_r8x2[] = "__ldlu_r8x2_";
-
/// Table that drives the fir generation depending on the intrinsic or intrinsic
/// module procedure one to one mapping with Fortran arguments. If no mapping is
/// defined here for a generic intrinsic, genRuntimeCall will be called
@@ -143,106 +122,6 @@ static const char __ldlu_r8x2[] = "__ldlu_r8x2_";
/// argument must not be lowered by value. In which case, the lowering rules
/// should be provided for all the intrinsic arguments for completeness.
static constexpr IntrinsicHandler handlers[]{
- {"__ldca_i4x4",
- &I::genCUDALDXXFunc<__ldca_i4x4, 4>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldca_i8x2",
- &I::genCUDALDXXFunc<__ldca_i8x2, 2>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldca_r2x2",
- &I::genCUDALDXXFunc<__ldca_r2x2, 2>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldca_r4x4",
- &I::genCUDALDXXFunc<__ldca_r4x4, 4>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldca_r8x2",
- &I::genCUDALDXXFunc<__ldca_r8x2, 2>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldcg_i4x4",
- &I::genCUDALDXXFunc<__ldcg_i4x4, 4>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldcg_i8x2",
- &I::genCUDALDXXFunc<__ldcg_i8x2, 2>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldcg_r2x2",
- &I::genCUDALDXXFunc<__ldcg_r2x2, 2>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldcg_r4x4",
- &I::genCUDALDXXFunc<__ldcg_r4x4, 4>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldcg_r8x2",
- &I::genCUDALDXXFunc<__ldcg_r8x2, 2>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldcs_i4x4",
- &I::genCUDALDXXFunc<__ldcs_i4x4, 4>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldcs_i8x2",
- &I::genCUDALDXXFunc<__ldcs_i8x2, 2>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldcs_r2x2",
- &I::genCUDALDXXFunc<__ldcs_r2x2, 2>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldcs_r4x4",
- &I::genCUDALDXXFunc<__ldcs_r4x4, 4>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldcs_r8x2",
- &I::genCUDALDXXFunc<__ldcs_r8x2, 2>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldcv_i4x4",
- &I::genCUDALDXXFunc<__ldcv_i4x4, 4>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldcv_i8x2",
- &I::genCUDALDXXFunc<__ldcv_i8x2, 2>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldcv_r2x2",
- &I::genCUDALDXXFunc<__ldcv_r2x2, 2>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldcv_r4x4",
- &I::genCUDALDXXFunc<__ldcv_r4x4, 4>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldcv_r8x2",
- &I::genCUDALDXXFunc<__ldcv_r8x2, 2>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldlu_i4x4",
- &I::genCUDALDXXFunc<__ldlu_i4x4, 4>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldlu_i8x2",
- &I::genCUDALDXXFunc<__ldlu_i8x2, 2>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldlu_r2x2",
- &I::genCUDALDXXFunc<__ldlu_r2x2, 2>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldlu_r4x4",
- &I::genCUDALDXXFunc<__ldlu_r4x4, 4>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
- {"__ldlu_r8x2",
- &I::genCUDALDXXFunc<__ldlu_r8x2, 2>,
- {{{"a", asAddr}}},
- /*isElemental=*/false},
{"abort", &I::genAbort},
{"abs", &I::genAbs},
{"achar", &I::genChar},
@@ -262,10 +141,6 @@ static constexpr IntrinsicHandler handlers[]{
&I::genAll,
{{{"mask", asAddr}, {"dim", asValue}}},
/*isElemental=*/false},
- {"all_sync",
- &I::genVoteSync<mlir::NVVM::VoteSyncKind::all>,
- {{{"mask", asValue}, {"pred", asValue}}},
- /*isElemental=*/false},
{"allocated",
&I::genAllocated,
{{{"array", asInquired}, {"scalar", asInquired}}},
@@ -275,10 +150,6 @@ static constexpr IntrinsicHandler handlers[]{
&I::genAny,
{{{"mask", asAddr}, {"dim", asValue}}},
/*isElemental=*/false},
- {"any_sync",
- &I::genVoteSync<mlir::NVVM::VoteSyncKind::any>,
- {{{"mask", asValue}, {"pred", asValue}}},
- /*isElemental=*/false},
{"asind", &I::genAsind},
{"asinpi", &I::genAsinpi},
{"associated",
@@ -289,75 +160,6 @@ static constexpr IntrinsicHandler handlers[]{
{"atan2pi", &I::genAtanpi},
{"atand", &I::genAtand},
{"atanpi", &I::genAtanpi},
- {"atomicaddd", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicaddf", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicaddi", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicaddl", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicandi", &I::genAtomicAnd, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomiccasd",
- &I::genAtomicCas,
- {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
- false},
- {"atomiccasf",
- &I::genAtomicCas,
- {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
- false},
- {"atomiccasi",
- &I::genAtomicCas,
- {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
- false},
- {"atomiccasul",
- &I::genAtomicCas,
- {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
- false},
- {"atomicdeci", &I::genAtomicDec, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicexchd",
- &I::genAtomicExch,
- {{{"a", asAddr}, {"v", asValue}}},
- false},
- {"atomicexchf",
- &I::genAtomicExch,
- {{{"a", asAddr}, {"v", asValue}}},
- false},
- {"atomicexchi",
- &I::genAtomicExch,
- {{{"a", asAddr}, {"v", asValue}}},
- false},
- {"atomicexchul",
- &I::genAtomicExch,
- {{{"a", asAddr}, {"v", asValue}}},
- false},
- {"atomicinci", &I::genAtomicInc, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicmaxd", &I::genAtomicMax, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicmaxf", &I::genAtomicMax, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicmaxi", &I::genAtomicMax, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicmaxl", &I::genAtomicMax, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicmind", &I::genAtomicMin, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicminf", &I::genAtomicMin, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicmini", &I::genAtomicMin, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicminl", &I::genAtomicMin, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicori", &I::genAtomicOr, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicsubd", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicsubf", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicsubi", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicsubl", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"atomicxori", &I::genAtomicXor, {{{"a", asAddr}, {"v", asValue}}}, false},
- {"ballot_sync",
- &I::genVoteSync<mlir::NVVM::VoteSyncKind::ballot>,
- {{{"mask", asValue}, {"pred", asValue}}},
- /*isElemental=*/false},
- {"barrier_arrive",
- &I::genBarrierArrive,
- {{{"barrier", asAddr}}},
- /*isElemental=*/false},
- {"barrier_arrive_cnt",
- &I::genBarrierArriveCnt,
- {{{"barrier", asAddr}, {"count", asValue}}},
- /*isElemental=*/false},
- {"barrier_init",
- &I::genBarrierInit,
- {{{"barrier", asAddr}, {"count", asValue}}},
- /*isElemental=*/false},
{"bessel_jn",
&I::genBesselJn,
{{{"n1", asValue}, {"n2", asValue}, {"x", asValue}}},
@@ -391,6 +193,12 @@ static constexpr IntrinsicHandler handlers[]{
&I::genCFProcPointer,
{{{"cptr", asValue}, {"fptr", asInquired}}},
/*isElemental=*/false},
+ {"c_f_strpointer",
+ &I::genCFStrPointer,
+ {{{"cstrptr_or_cstrarray", asValue},
+ {"fstrptr", asInquired},
+ {"nchars", asValue, handleDynamicOptional}}},
+ /*isElemental=*/false},
{"c_funloc", &I::genCFunLoc, {{{"x", asBox}}}, /*isElemental=*/false},
{"c_loc", &I::genCLoc, {{{"x", asBox}}}, /*isElemental=*/false},
{"c_ptr_eq", &I::genCPtrCompare<mlir::arith::CmpIPredicate::eq>},
@@ -401,11 +209,6 @@ static constexpr IntrinsicHandler handlers[]{
&I::genChdir,
{{{"name", asAddr}, {"status", asAddr, handleDynamicOptional}}},
/*isElemental=*/false},
- {"clock", &I::genNVVMTime<mlir::NVVM::ClockOp>, {}, /*isElemental=*/false},
- {"clock64",
- &I::genNVVMTime<mlir::NVVM::Clock64Op>,
- {},
- /*isElemental=*/false},
{"cmplx",
&I::genCmplx,
{{{"x", asValue}, {"y", asValue, handleDynamicOptional}}}},
@@ -502,9 +305,9 @@ static constexpr IntrinsicHandler handlers[]{
&I::genExtendsTypeOf,
{{{"a", asBox}, {"mold", asBox}}},
/*isElemental=*/false},
- {"fence_proxy_async",
- &I::genFenceProxyAsync,
- {},
+ {"f_c_string",
+ &I::genFCString,
+ {{{"string", asAddr}, {"asis", asValue, handleDynamicOptional}}},
/*isElemental=*/false},
{"findloc",
&I::genFindloc,
@@ -516,6 +319,10 @@ static constexpr IntrinsicHandler handlers[]{
{"back", asValue, handleDynamicOptional}}},
/*isElemental=*/false},
{"floor", &I::genFloor},
+ {"flush",
+ &I::genFlush,
+ {{{"unit", asAddr}}},
+ /*isElemental=*/false},
{"fraction", &I::genFraction},
{"free", &I::genFree},
{"fseek",
@@ -553,6 +360,10 @@ static constexpr IntrinsicHandler handlers[]{
{"trim_name", asAddr, handleDynamicOptional},
{"errmsg", asBox, handleDynamicOptional}}},
/*isElemental=*/false},
+ {"get_team",
+ &I::genGetTeam,
+ {{{"level", asValue, handleDynamicOptional}}},
+ /*isElemental=*/false},
{"getcwd",
&I::genGetCwd,
{{{"c", asBox}, {"status", asAddr, handleDynamicOptional}}},
@@ -560,10 +371,6 @@ static constexpr IntrinsicHandler handlers[]{
{"getgid", &I::genGetGID},
{"getpid", &I::genGetPID},
{"getuid", &I::genGetUID},
- {"globaltimer",
- &I::genNVVMTime<mlir::NVVM::GlobalTimerOp>,
- {},
- /*isElemental=*/false},
{"hostnm",
&I::genHostnm,
{{{"c", asBox}, {"status", asAddr, handleDynamicOptional}}},
@@ -703,6 +510,10 @@ static constexpr IntrinsicHandler handlers[]{
{"dim", asValue},
{"mask", asBox, handleDynamicOptional}}},
/*isElemental=*/false},
+ {"irand",
+ &I::genIrand,
+ {{{"i", asAddr, handleDynamicOptional}}},
+ /*isElemental=*/false},
{"is_contiguous",
&I::genIsContiguous,
{{{"array", asBox}}},
@@ -731,38 +542,6 @@ static constexpr IntrinsicHandler handlers[]{
{"malloc", &I::genMalloc},
{"maskl", &I::genMask<mlir::arith::ShLIOp>},
{"maskr", &I::genMask<mlir::arith::ShRUIOp>},
- {"match_all_syncjd",
- &I::genMatchAllSync,
- {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
- /*isElemental=*/false},
- {"match_all_syncjf",
- &I::genMatchAllSync,
- {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
- /*isElemental=*/false},
- {"match_all_syncjj",
- &I::genMatchAllSync,
- {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
- /*isElemental=*/false},
- {"match_all_syncjx",
- &I::genMatchAllSync,
- {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
- /*isElemental=*/false},
- {"match_any_syncjd",
- &I::genMatchAnySync,
- {{{"mask", asValue}, {"value", asValue}}},
- /*isElemental=*/false},
- {"match_any_syncjf",
- &I::genMatchAnySync,
- {{{"mask", asValue}, {"value", asValue}}},
- /*isElemental=*/false},
- {"match_any_syncjj",
- &I::genMatchAnySync,
- {{{"mask", asValue}, {"value", asValue}}},
- /*isElemental=*/false},
- {"match_any_syncjx",
- &I::genMatchAnySync,
- {{{"mask", asValue}, {"value", asValue}}},
- /*isElemental=*/false},
{"matmul",
&I::genMatmul,
{{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
@@ -861,6 +640,10 @@ static constexpr IntrinsicHandler handlers[]{
&I::genPutenv,
{{{"str", asAddr}, {"status", asAddr, handleDynamicOptional}}},
/*isElemental=*/false},
+ {"rand",
+ &I::genRand,
+ {{{"i", asAddr, handleDynamicOptional}}},
+ /*isElemental=*/false},
{"random_init",
&I::genRandomInit,
{{{"repeatable", asValue}, {"image_distinct", asValue}}},
@@ -955,6 +738,10 @@ static constexpr IntrinsicHandler handlers[]{
{"shifta", &I::genShiftA},
{"shiftl", &I::genShift<mlir::arith::ShLIOp>},
{"shiftr", &I::genShift<mlir::arith::ShRUIOp>},
+ {"show_descriptor",
+ &I::genShowDescriptor,
+ {{{"d", asInquired}}},
+ /*isElemental=*/false},
{"sign", &I::genSign},
{"signal",
&I::genSignalSubroutine,
@@ -988,11 +775,6 @@ static constexpr IntrinsicHandler handlers[]{
{"dim", asValue},
{"mask", asBox, handleDynamicOptional}}},
/*isElemental=*/false},
- {"syncthreads", &I::genSyncThreads, {}, /*isElemental=*/false},
- {"syncthreads_and", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
- {"syncthreads_count", &I::genSyncThreadsCount, {}, /*isElemental=*/false},
- {"syncthreads_or", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
- {"syncwarp", &I::genSyncWarp, {}, /*isElemental=*/false},
{"system",
&I::genSystem,
{{{"command", asBox}, {"exitstat", asBox, handleDynamicOptional}}},
@@ -1003,38 +785,17 @@ static constexpr IntrinsicHandler handlers[]{
/*isElemental=*/false},
{"tand", &I::genTand},
{"tanpi", &I::genTanpi},
- {"this_grid", &I::genThisGrid, {}, /*isElemental=*/false},
+ {"team_number",
+ &I::genTeamNumber,
+ {{{"team", asBox, handleDynamicOptional}}},
+ /*isElemental=*/false},
{"this_image",
&I::genThisImage,
{{{"coarray", asBox},
{"dim", asAddr},
{"team", asBox, handleDynamicOptional}}},
/*isElemental=*/false},
- {"this_thread_block", &I::genThisThreadBlock, {}, /*isElemental=*/false},
- {"this_warp", &I::genThisWarp, {}, /*isElemental=*/false},
- {"threadfence", &I::genThreadFence, {}, /*isElemental=*/false},
- {"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false},
- {"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false},
{"time", &I::genTime, {}, /*isElemental=*/false},
- {"tma_bulk_commit_group",
- &I::genTMABulkCommitGroup,
- {{}},
- /*isElemental=*/false},
- {"tma_bulk_g2s",
- &I::genTMABulkG2S,
- {{{"barrier", asAddr},
- {"src", asAddr},
- {"dst", asAddr},
- {"nbytes", asValue}}},
- /*isElemental=*/false},
- {"tma_bulk_s2g",
- &I::genTMABulkS2G,
- {{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}},
- /*isElemental=*/false},
- {"tma_bulk_wait_group",
- &I::genTMABulkWaitGroup,
- {{}},
- /*isElemental=*/false},
{"trailz", &I::genTrailz},
{"transfer",
&I::genTransfer,
@@ -1758,8 +1519,10 @@ static constexpr MathOperation mathOperations[] = {
genComplexMathOp<mlir::complex::SinOp>},
{"sin", RTNAME_STRING(CSinF128), FuncTypeComplex16Complex16,
genLibF128Call},
- {"sinh", "sinhf", genFuncType<Ty::Real<4>, Ty::Real<4>>, genLibCall},
- {"sinh", "sinh", genFuncType<Ty::Real<8>, Ty::Real<8>>, genLibCall},
+ {"sinh", "sinhf", genFuncType<Ty::Real<4>, Ty::Real<4>>,
+ genMathOp<mlir::math::SinhOp>},
+ {"sinh", "sinh", genFuncType<Ty::Real<8>, Ty::Real<8>>,
+ genMathOp<mlir::math::SinhOp>},
{"sinh", RTNAME_STRING(SinhF128), FuncTypeReal16Real16, genLibF128Call},
{"sinh", "csinhf", genFuncType<Ty::Complex<4>, Ty::Complex<4>>, genLibCall},
{"sinh", "csinh", genFuncType<Ty::Complex<8>, Ty::Complex<8>>, genLibCall},
@@ -2124,6 +1887,9 @@ lookupIntrinsicHandler(fir::FirOpBuilder &builder,
if (isPPCTarget)
if (const IntrinsicHandler *ppcHandler = findPPCIntrinsicHandler(name))
return std::make_optional<IntrinsicHandlerEntry>(ppcHandler);
+ // TODO: Look for CUDA intrinsic handlers only if CUDA is enabled.
+ if (const IntrinsicHandler *cudaHandler = findCUDAIntrinsicHandler(name))
+ return std::make_optional<IntrinsicHandlerEntry>(cudaHandler);
// Subroutines should have a handler.
if (!resultType)
return std::nullopt;
@@ -3010,157 +2776,6 @@ mlir::Value IntrinsicLibrary::genAtanpi(mlir::Type resultType,
return mlir::arith::MulFOp::create(builder, loc, atan, factor);
}
-static mlir::Value genAtomBinOp(fir::FirOpBuilder &builder, mlir::Location &loc,
- mlir::LLVM::AtomicBinOp binOp, mlir::Value arg0,
- mlir::Value arg1) {
- auto llvmPointerType = mlir::LLVM::LLVMPointerType::get(builder.getContext());
- arg0 = builder.createConvert(loc, llvmPointerType, arg0);
- return mlir::LLVM::AtomicRMWOp::create(builder, loc, binOp, arg0, arg1,
- mlir::LLVM::AtomicOrdering::seq_cst);
-}
-
-mlir::Value IntrinsicLibrary::genAtomicAdd(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 2);
-
- mlir::LLVM::AtomicBinOp binOp =
- mlir::isa<mlir::IntegerType>(args[1].getType())
- ? mlir::LLVM::AtomicBinOp::add
- : mlir::LLVM::AtomicBinOp::fadd;
- return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
-}
-
-mlir::Value IntrinsicLibrary::genAtomicSub(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 2);
-
- mlir::LLVM::AtomicBinOp binOp =
- mlir::isa<mlir::IntegerType>(args[1].getType())
- ? mlir::LLVM::AtomicBinOp::sub
- : mlir::LLVM::AtomicBinOp::fsub;
- return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
-}
-
-mlir::Value IntrinsicLibrary::genAtomicAnd(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 2);
- assert(mlir::isa<mlir::IntegerType>(args[1].getType()));
-
- mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::_and;
- return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
-}
-
-mlir::Value IntrinsicLibrary::genAtomicOr(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 2);
- assert(mlir::isa<mlir::IntegerType>(args[1].getType()));
-
- mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::_or;
- return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
-}
-
-// ATOMICCAS
-fir::ExtendedValue
-IntrinsicLibrary::genAtomicCas(mlir::Type resultType,
- llvm::ArrayRef<fir::ExtendedValue> args) {
- assert(args.size() == 3);
- auto successOrdering = mlir::LLVM::AtomicOrdering::acq_rel;
- auto failureOrdering = mlir::LLVM::AtomicOrdering::monotonic;
- auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(resultType.getContext());
-
- mlir::Value arg0 = fir::getBase(args[0]);
- mlir::Value arg1 = fir::getBase(args[1]);
- mlir::Value arg2 = fir::getBase(args[2]);
-
- auto bitCastFloat = [&](mlir::Value arg) -> mlir::Value {
- if (mlir::isa<mlir::Float32Type>(arg.getType()))
- return mlir::LLVM::BitcastOp::create(builder, loc, builder.getI32Type(),
- arg);
- if (mlir::isa<mlir::Float64Type>(arg.getType()))
- return mlir::LLVM::BitcastOp::create(builder, loc, builder.getI64Type(),
- arg);
- return arg;
- };
-
- arg1 = bitCastFloat(arg1);
- arg2 = bitCastFloat(arg2);
-
- if (arg1.getType() != arg2.getType()) {
- // arg1 and arg2 need to have the same type in AtomicCmpXchgOp.
- arg2 = builder.createConvert(loc, arg1.getType(), arg2);
- }
-
- auto address =
- mlir::UnrealizedConversionCastOp::create(builder, loc, llvmPtrTy, arg0)
- .getResult(0);
- auto cmpxchg = mlir::LLVM::AtomicCmpXchgOp::create(
- builder, loc, address, arg1, arg2, successOrdering, failureOrdering);
- return mlir::LLVM::ExtractValueOp::create(builder, loc, cmpxchg, 1);
-}
-
-mlir::Value IntrinsicLibrary::genAtomicDec(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 2);
- assert(mlir::isa<mlir::IntegerType>(args[1].getType()));
-
- mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::udec_wrap;
- return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
-}
-
-// ATOMICEXCH
-fir::ExtendedValue
-IntrinsicLibrary::genAtomicExch(mlir::Type resultType,
- llvm::ArrayRef<fir::ExtendedValue> args) {
- assert(args.size() == 2);
- mlir::Value arg0 = fir::getBase(args[0]);
- mlir::Value arg1 = fir::getBase(args[1]);
- assert(arg1.getType().isIntOrFloat());
-
- mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::xchg;
- return genAtomBinOp(builder, loc, binOp, arg0, arg1);
-}
-
-mlir::Value IntrinsicLibrary::genAtomicInc(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 2);
- assert(mlir::isa<mlir::IntegerType>(args[1].getType()));
-
- mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::uinc_wrap;
- return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
-}
-
-mlir::Value IntrinsicLibrary::genAtomicMax(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 2);
-
- mlir::LLVM::AtomicBinOp binOp =
- mlir::isa<mlir::IntegerType>(args[1].getType())
- ? mlir::LLVM::AtomicBinOp::max
- : mlir::LLVM::AtomicBinOp::fmax;
- return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
-}
-
-mlir::Value IntrinsicLibrary::genAtomicMin(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 2);
-
- mlir::LLVM::AtomicBinOp binOp =
- mlir::isa<mlir::IntegerType>(args[1].getType())
- ? mlir::LLVM::AtomicBinOp::min
- : mlir::LLVM::AtomicBinOp::fmin;
- return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
-}
-
-// ATOMICXOR
-fir::ExtendedValue
-IntrinsicLibrary::genAtomicXor(mlir::Type resultType,
- llvm::ArrayRef<fir::ExtendedValue> args) {
- assert(args.size() == 2);
- mlir::Value arg0 = fir::getBase(args[0]);
- mlir::Value arg1 = fir::getBase(args[1]);
- return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::_xor, arg0, arg1);
-}
-
// ASSOCIATED
fir::ExtendedValue
IntrinsicLibrary::genAssociated(mlir::Type resultType,
@@ -3212,63 +2827,6 @@ IntrinsicLibrary::genAssociated(mlir::Type resultType,
return fir::runtime::genAssociated(builder, loc, pointerBox, targetBox);
}
-static mlir::Value convertPtrToNVVMSpace(fir::FirOpBuilder &builder,
- mlir::Location loc,
- mlir::Value barrier,
- mlir::NVVM::NVVMMemorySpace space) {
- mlir::Value llvmPtr = fir::ConvertOp::create(
- builder, loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()),
- barrier);
- mlir::Value addrCast = mlir::LLVM::AddrSpaceCastOp::create(
- builder, loc,
- mlir::LLVM::LLVMPointerType::get(builder.getContext(),
- static_cast<unsigned>(space)),
- llvmPtr);
- return addrCast;
-}
-
-// BARRIER_ARRIVE (CUDA)
-mlir::Value
-IntrinsicLibrary::genBarrierArrive(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 1);
- mlir::Value barrier = convertPtrToNVVMSpace(
- builder, loc, args[0], mlir::NVVM::NVVMMemorySpace::Shared);
- return mlir::NVVM::MBarrierArriveSharedOp::create(builder, loc, resultType,
- barrier)
- .getResult();
-}
-
-// BARRIER_ARRIBVE_CNT (CUDA)
-mlir::Value
-IntrinsicLibrary::genBarrierArriveCnt(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 2);
- mlir::Value barrier = convertPtrToNVVMSpace(
- builder, loc, args[0], mlir::NVVM::NVVMMemorySpace::Shared);
- mlir::Value token = fir::AllocaOp::create(builder, loc, resultType);
- // TODO: the MBarrierArriveExpectTxOp is not taking the state argument and
- // currently just the sink symbol `_`.
- // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
- mlir::NVVM::MBarrierArriveExpectTxOp::create(builder, loc, barrier, args[1],
- {});
- return fir::LoadOp::create(builder, loc, token);
-}
-
-// BARRIER_INIT (CUDA)
-void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) {
- assert(args.size() == 2);
- mlir::Value barrier = convertPtrToNVVMSpace(
- builder, loc, fir::getBase(args[0]), mlir::NVVM::NVVMMemorySpace::Shared);
- mlir::NVVM::MBarrierInitSharedOp::create(builder, loc, barrier,
- fir::getBase(args[1]), {});
- auto kind = mlir::NVVM::ProxyKindAttr::get(
- builder.getContext(), mlir::NVVM::ProxyKind::async_shared);
- auto space = mlir::NVVM::SharedSpaceAttr::get(
- builder.getContext(), mlir::NVVM::SharedSpace::shared_cta);
- mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space);
-}
-
// BESSEL_JN
fir::ExtendedValue
IntrinsicLibrary::genBesselJn(mlir::Type resultType,
@@ -3516,11 +3074,23 @@ static mlir::Value getAddrFromBox(fir::FirOpBuilder &builder,
return addr;
}
+static void clocDeviceArgRewrite(fir::ExtendedValue arg) {
+ // Special case for device address in c_loc.
+ if (auto emboxOp = mlir::dyn_cast_or_null<fir::EmboxOp>(
+ fir::getBase(arg).getDefiningOp()))
+ if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
+ emboxOp.getMemref().getDefiningOp()))
+ if (declareOp.getDataAttr() &&
+ declareOp.getDataAttr() == cuf::DataAttribute::Device)
+ emboxOp.getMemrefMutable().assign(declareOp.getMemref());
+}
+
static fir::ExtendedValue
genCLocOrCFunLoc(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type resultType, llvm::ArrayRef<fir::ExtendedValue> args,
bool isFunc = false, bool isDevLoc = false) {
assert(args.size() == 1);
+ clocDeviceArgRewrite(args[0]);
mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
mlir::Value resAddr;
if (isDevLoc)
@@ -3686,6 +3256,99 @@ void IntrinsicLibrary::genCFProcPointer(
fir::StoreOp::create(builder, loc, cptrBox, fptr);
}
+// C_F_STRPOINTER
+void IntrinsicLibrary::genCFStrPointer(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 3);
+
+ mlir::Value cStrAddr;
+ mlir::Value strLen;
+
+ const mlir::Value firstArg = fir::getBase(args[0]);
+ const mlir::Type firstArgType = fir::unwrapRefType(firstArg.getType());
+ const bool isCstrptr = mlir::isa<fir::RecordType>(firstArgType);
+
+ if (isCstrptr) {
+ // CSTRPTR form: Extract address from C_PTR
+ cStrAddr = fir::factory::genCPtrOrCFunptrValue(builder, loc, firstArg);
+
+ assert(isStaticallyPresent(args[2]));
+ mlir::Value nchars = fir::getBase(args[2]);
+ if (fir::isa_ref_type(nchars.getType())) {
+ strLen = fir::LoadOp::create(builder, loc, nchars);
+ } else {
+ strLen = nchars;
+ }
+ } else {
+ // CSTRARRAY form: Get address from CHARACTER array
+ if (const auto boxCharTy =
+ mlir::dyn_cast<fir::BoxCharType>(firstArg.getType())) {
+ const auto charTy = mlir::cast<fir::CharacterType>(boxCharTy.getEleTy());
+ const auto addrTy = builder.getRefType(charTy);
+ auto unboxed = fir::UnboxCharOp::create(
+ builder, loc, mlir::TypeRange{addrTy, builder.getIndexType()},
+ firstArg);
+ cStrAddr = unboxed.getResult(0);
+ } else if (mlir::isa<fir::BoxType>(firstArg.getType())) {
+ cStrAddr = fir::BoxAddrOp::create(builder, loc, firstArg);
+ } else {
+ cStrAddr = firstArg;
+ }
+
+ // Handle optional NCHARS argument
+ if (isStaticallyPresent(args[2])) {
+ mlir::Value nchars = fir::getBase(args[2]);
+ if (fir::isa_ref_type(nchars.getType())) {
+ strLen = fir::LoadOp::create(builder, loc, nchars);
+ } else {
+ strLen = nchars;
+ }
+ } else {
+ const mlir::Type i8PtrTy = builder.getRefType(builder.getIntegerType(8));
+ const mlir::Value strPtr = builder.createConvert(loc, i8PtrTy, cStrAddr);
+
+ const mlir::Type i64Ty = builder.getIntegerType(64);
+ const mlir::FunctionType strlenType =
+ mlir::FunctionType::get(builder.getContext(), {i8PtrTy}, {i64Ty});
+
+ mlir::func::FuncOp strlenFunc = builder.getNamedFunction("strlen");
+ if (!strlenFunc) {
+ strlenFunc = builder.createFunction(loc, "strlen", strlenType);
+ strlenFunc->setAttr(
+ fir::getSymbolAttrName(),
+ mlir::StringAttr::get(builder.getContext(), "strlen"));
+ }
+ auto call = fir::CallOp::create(builder, loc, strlenFunc, {strPtr});
+ strLen = call.getResult(0);
+ }
+ }
+
+ // Handle FSTRPTR (second argument)
+ const auto *fStrPtr = args[1].getBoxOf<fir::MutableBoxValue>();
+ assert(fStrPtr && "FSTRPTR must be a pointer");
+
+ const mlir::Value lenIdx =
+ builder.createConvert(loc, builder.getIndexType(), strLen);
+
+ const mlir::Type charPtrType = fir::PointerType::get(fir::CharacterType::get(
+ builder.getContext(), 1, fir::CharacterType::unknownLen()));
+ const mlir::Value charPtr = builder.createConvert(loc, charPtrType, cStrAddr);
+
+ const fir::CharBoxValue charBox{charPtr, lenIdx};
+ fir::factory::associateMutableBox(builder, loc, *fStrPtr, charBox,
+ /*lbounds=*/mlir::ValueRange{});
+
+ // CUDA synchronization if needed
+ if (auto declare = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
+ fStrPtr->getAddr().getDefiningOp()))
+ if (declare.getMemref().getDefiningOp() &&
+ mlir::isa<fir::AddrOfOp>(declare.getMemref().getDefiningOp()))
+ if (cuf::isRegisteredDeviceAttr(declare.getDataAttr()) &&
+ !cuf::isCUDADeviceContext(builder.getRegion()))
+ fir::runtime::cuda::genSyncGlobalDescriptor(builder, loc,
+ declare.getMemref());
+}
+
// C_FUNLOC
fir::ExtendedValue
IntrinsicLibrary::genCFunLoc(mlir::Type resultType,
@@ -3990,30 +3653,6 @@ IntrinsicLibrary::genCshift(mlir::Type resultType,
return readAndAddCleanUp(resultMutableBox, resultType, "CSHIFT");
}
-// __LDCA, __LDCS, __LDLU, __LDCV
-template <const char *fctName, int extent>
-fir::ExtendedValue
-IntrinsicLibrary::genCUDALDXXFunc(mlir::Type resultType,
- llvm::ArrayRef<fir::ExtendedValue> args) {
- assert(args.size() == 1);
- mlir::Type resTy = fir::SequenceType::get(extent, resultType);
- mlir::Value arg = fir::getBase(args[0]);
- mlir::Value res = fir::AllocaOp::create(builder, loc, resTy);
- if (mlir::isa<fir::BaseBoxType>(arg.getType()))
- arg = fir::BoxAddrOp::create(builder, loc, arg);
- mlir::Type refResTy = fir::ReferenceType::get(resTy);
- mlir::FunctionType ftype =
- mlir::FunctionType::get(arg.getContext(), {refResTy, refResTy}, {});
- auto funcOp = builder.createFunction(loc, fctName, ftype);
- llvm::SmallVector<mlir::Value> funcArgs;
- funcArgs.push_back(res);
- funcArgs.push_back(arg);
- fir::CallOp::create(builder, loc, funcOp, funcArgs);
- mlir::Value ext =
- builder.createIntegerConstant(loc, builder.getIndexType(), extent);
- return fir::ArrayBoxValue(res, {ext});
-}
-
// DATE_AND_TIME
void IntrinsicLibrary::genDateAndTime(llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 4 && "date_and_time has 4 args");
@@ -4317,9 +3956,6 @@ void IntrinsicLibrary::genExit(llvm::ArrayRef<fir::ExtendedValue> args) {
EXIT_SUCCESS)
: fir::getBase(args[0]);
- assert(status.getType() == builder.getDefaultIntegerType() &&
- "STATUS parameter must be an INTEGER of default kind");
-
fir::runtime::genExit(builder, loc, status);
}
@@ -4346,15 +3982,30 @@ IntrinsicLibrary::genExtendsTypeOf(mlir::Type resultType,
fir::getBase(args[1])));
}
-// FENCE_PROXY_ASYNC (CUDA)
-void IntrinsicLibrary::genFenceProxyAsync(
- llvm::ArrayRef<fir::ExtendedValue> args) {
- assert(args.size() == 0);
- auto kind = mlir::NVVM::ProxyKindAttr::get(
- builder.getContext(), mlir::NVVM::ProxyKind::async_shared);
- auto space = mlir::NVVM::SharedSpaceAttr::get(
- builder.getContext(), mlir::NVVM::SharedSpace::shared_cta);
- mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space);
+// F_C_STRING
+fir::ExtendedValue
+IntrinsicLibrary::genFCString(mlir::Type resultType,
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() >= 1 && args.size() <= 2);
+
+ mlir::Value string = builder.createBox(loc, args[0]);
+
+ // Handle optional ASIS argument
+ mlir::Value asis = isStaticallyAbsent(args, 1)
+ ? builder.createBool(loc, false)
+ : fir::getBase(args[1]);
+
+ // Create mutable fir.box to be passed to the runtime for the result.
+ fir::MutableBoxValue resultMutableBox =
+ fir::factory::createTempMutableBox(builder, loc, resultType);
+ mlir::Value resultIrBox =
+ fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
+
+ fir::runtime::genFCString(builder, loc, resultIrBox, string, asis);
+
+ // Read result from mutable fir.box and add it to the list of temps to be
+ // finalized by the StatementContext.
+ return readAndAddCleanUp(resultMutableBox, resultType, "F_C_STRING");
}
// FINDLOC
@@ -4439,6 +4090,40 @@ mlir::Value IntrinsicLibrary::genFloor(mlir::Type resultType,
return builder.createConvert(loc, resultType, floor);
}
+// FLUSH
+void IntrinsicLibrary::genFlush(llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 1);
+
+ mlir::Value unit;
+ if (isStaticallyAbsent(args[0]))
+    // Give a sentinel value of `-1` on the `()` case.
+ unit = builder.createIntegerConstant(loc, builder.getI32Type(), -1);
+ else {
+ unit = fir::getBase(args[0]);
+ if (isOptional(unit)) {
+ mlir::Value isPresent =
+ fir::IsPresentOp::create(builder, loc, builder.getI1Type(), unit);
+ unit = builder
+ .genIfOp(loc, builder.getI32Type(), isPresent,
+ /*withElseRegion=*/true)
+ .genThen([&]() {
+ mlir::Value loaded = fir::LoadOp::create(builder, loc, unit);
+ fir::ResultOp::create(builder, loc, loaded);
+ })
+ .genElse([&]() {
+ mlir::Value negOne = builder.createIntegerConstant(
+ loc, builder.getI32Type(), -1);
+ fir::ResultOp::create(builder, loc, negOne);
+ })
+ .getResults()[0];
+ } else {
+ unit = fir::LoadOp::create(builder, loc, unit);
+ }
+ }
+
+ fir::runtime::genFlush(builder, loc, unit);
+}
+
// FRACTION
mlir::Value IntrinsicLibrary::genFraction(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
@@ -4518,6 +4203,15 @@ IntrinsicLibrary::genFtell(std::optional<mlir::Type> resultType,
}
}
+// GET_TEAM
+mlir::Value IntrinsicLibrary::genGetTeam(mlir::Type resultType,
+ llvm::ArrayRef<mlir::Value> args) {
+ converter->checkCoarrayEnabled();
+ assert(args.size() == 1);
+ return mif::GetTeamOp::create(builder, loc, fir::BoxType::get(resultType),
+ /*level*/ args[0]);
+}
+
// GETCWD
fir::ExtendedValue
IntrinsicLibrary::genGetCwd(std::optional<mlir::Type> resultType,
@@ -6603,6 +6297,20 @@ IntrinsicLibrary::genIparity(mlir::Type resultType,
"IPARITY", resultType, args);
}
+// IRAND
+fir::ExtendedValue
+IntrinsicLibrary::genIrand(mlir::Type resultType,
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 1);
+ mlir::Value i =
+ isStaticallyPresent(args[0])
+ ? fir::getBase(args[0])
+ : fir::AbsentOp::create(builder, loc,
+ builder.getRefType(builder.getI32Type()))
+ .getResult();
+ return fir::runtime::genIrand(builder, loc, i);
+}
+
// IS_CONTIGUOUS
fir::ExtendedValue
IntrinsicLibrary::genIsContiguous(mlir::Type resultType,
@@ -6786,12 +6494,6 @@ IntrinsicLibrary::genCharacterCompare(mlir::Type resultType,
fir::getBase(args[1]), fir::getLen(args[1]));
}
-static bool isOptional(mlir::Value value) {
- auto varIface = mlir::dyn_cast_or_null<fir::FortranVariableOpInterface>(
- value.getDefiningOp());
- return varIface && varIface.isOptional();
-}
-
// LOC
fir::ExtendedValue
IntrinsicLibrary::genLoc(mlir::Type resultType,
@@ -6867,67 +6569,6 @@ mlir::Value IntrinsicLibrary::genMask(mlir::Type resultType,
return result;
}
-// MATCH_ALL_SYNC
-mlir::Value
-IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 3);
- bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
-
- mlir::Type i1Ty = builder.getI1Type();
- mlir::MLIRContext *context = builder.getContext();
-
- mlir::Value arg1 = args[1];
- if (arg1.getType().isF32() || arg1.getType().isF64())
- arg1 = fir::ConvertOp::create(
- builder, loc, is32 ? builder.getI32Type() : builder.getI64Type(), arg1);
-
- mlir::Type retTy =
- mlir::LLVM::LLVMStructType::getLiteral(context, {resultType, i1Ty});
- auto match =
- mlir::NVVM::MatchSyncOp::create(builder, loc, retTy, args[0], arg1,
- mlir::NVVM::MatchSyncKind::all)
- .getResult();
- auto value = mlir::LLVM::ExtractValueOp::create(builder, loc, match, 0);
- auto pred = mlir::LLVM::ExtractValueOp::create(builder, loc, match, 1);
- auto conv = mlir::LLVM::ZExtOp::create(builder, loc, resultType, pred);
- fir::StoreOp::create(builder, loc, conv, args[2]);
- return value;
-}
-
-// ALL_SYNC, ANY_SYNC, BALLOT_SYNC
-template <mlir::NVVM::VoteSyncKind kind>
-mlir::Value IntrinsicLibrary::genVoteSync(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 2);
- mlir::Value arg1 =
- fir::ConvertOp::create(builder, loc, builder.getI1Type(), args[1]);
- mlir::Type resTy = kind == mlir::NVVM::VoteSyncKind::ballot
- ? builder.getI32Type()
- : builder.getI1Type();
- auto voteRes =
- mlir::NVVM::VoteSyncOp::create(builder, loc, resTy, args[0], arg1, kind)
- .getResult();
- return fir::ConvertOp::create(builder, loc, resultType, voteRes);
-}
-
-// MATCH_ANY_SYNC
-mlir::Value
-IntrinsicLibrary::genMatchAnySync(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 2);
- bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
-
- mlir::Value arg1 = args[1];
- if (arg1.getType().isF32() || arg1.getType().isF64())
- arg1 = fir::ConvertOp::create(
- builder, loc, is32 ? builder.getI32Type() : builder.getI64Type(), arg1);
-
- return mlir::NVVM::MatchSyncOp::create(builder, loc, resultType, args[0],
- arg1, mlir::NVVM::MatchSyncKind::any)
- .getResult();
-}
-
// MATMUL
fir::ExtendedValue
IntrinsicLibrary::genMatmul(mlir::Type resultType,
@@ -7075,11 +6716,9 @@ static mlir::Value genFastMod(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value IntrinsicLibrary::genMod(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
auto mod = builder.getModule();
- bool dontUseFastRealMod = false;
- bool canUseApprox = mlir::arith::bitEnumContainsAny(
- builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn);
- if (auto attr = mod->getAttrOfType<mlir::BoolAttr>("fir.no_fast_real_mod"))
- dontUseFastRealMod = attr.getValue();
+ bool useFastRealMod = false;
+ if (auto attr = mod->getAttrOfType<mlir::BoolAttr>("fir.fast_real_mod"))
+ useFastRealMod = attr.getValue();
assert(args.size() == 2);
if (resultType.isUnsignedInteger()) {
@@ -7092,7 +6731,7 @@ mlir::Value IntrinsicLibrary::genMod(mlir::Type resultType,
if (mlir::isa<mlir::IntegerType>(resultType))
return mlir::arith::RemSIOp::create(builder, loc, args[0], args[1]);
- if (resultType.isFloat() && canUseApprox && !dontUseFastRealMod) {
+ if (resultType.isFloat() && useFastRealMod) {
// Treat MOD as an approximate function and code-gen inline code
// instead of calling into the Fortran runtime library.
return builder.createConvert(loc, resultType,
@@ -7545,14 +7184,6 @@ IntrinsicLibrary::genNumImages(mlir::Type resultType,
return mif::NumImagesOp::create(builder, loc).getResult();
}
-// CLOCK, CLOCK64, GLOBALTIMER
-template <typename OpTy>
-mlir::Value IntrinsicLibrary::genNVVMTime(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 0 && "expect no arguments");
- return OpTy::create(builder, loc, resultType).getResult();
-}
-
// PACK
fir::ExtendedValue
IntrinsicLibrary::genPack(mlir::Type resultType,
@@ -7706,6 +7337,19 @@ IntrinsicLibrary::genPutenv(std::optional<mlir::Type> resultType,
return {};
}
+// RAND
+fir::ExtendedValue
+IntrinsicLibrary::genRand(mlir::Type, llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 1);
+ mlir::Value i =
+ isStaticallyPresent(args[0])
+ ? fir::getBase(args[0])
+ : fir::AbsentOp::create(builder, loc,
+ builder.getRefType(builder.getI32Type()))
+ .getResult();
+ return fir::runtime::genRand(builder, loc, i);
+}
+
// RANDOM_INIT
void IntrinsicLibrary::genRandomInit(llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 2);
@@ -8371,6 +8015,47 @@ mlir::Value IntrinsicLibrary::genShiftA(mlir::Type resultType,
return result;
}
+void IntrinsicLibrary::genShowDescriptor(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 1 && "expected single argument for show_descriptor");
+ const mlir::Value arg = fir::getBase(args[0]);
+
+ // Use consistent !fir.ref<!fir.box<none>> argument type
+ auto targetType = fir::BoxType::get(builder.getNoneType());
+ auto targetRefType = fir::ReferenceType::get(targetType);
+
+ mlir::Value descrAddr = nullptr;
+ if (fir::isBoxAddress(arg.getType())) {
+ // If it's already a reference to a box, convert it to correct type and
+ // pass it directly
+ descrAddr = builder.createConvert(loc, targetRefType, arg);
+ } else {
+ // At this point, arg is either SSA descriptor or a non-descriptor entity.
+ // If necessary, wrap non-descriptor entity in a descriptor.
+ mlir::Value descriptor = nullptr;
+ if (fir::isa_box_type(arg.getType())) {
+ descriptor = arg;
+ } else if (fir::isa_ref_type(arg.getType())) {
+ // Note: here use full extended value args[0]
+ descriptor = builder.createBox(loc, args[0]);
+ } else {
+ // arg is a value (e.g. constant), spill it to a temporary
+ // because createBox expects a memory reference.
+ mlir::Value temp = builder.createTemporary(loc, arg.getType());
+ builder.createStoreWithConvert(loc, arg, temp);
+
+ // Note: here use full extended value args[0]
+ descriptor = builder.createBox(loc, fir::substBase(args[0], temp));
+ }
+
+ // Spill it to the stack
+ descrAddr = builder.createTemporary(loc, targetType);
+ builder.createStoreWithConvert(loc, descriptor, descrAddr);
+ }
+
+ fir::runtime::genShowDescriptor(builder, loc, descrAddr);
+}
+
// SIGNAL
void IntrinsicLibrary::genSignalSubroutine(
llvm::ArrayRef<fir::ExtendedValue> args) {
@@ -8527,90 +8212,16 @@ mlir::Value IntrinsicLibrary::genTanpi(mlir::Type resultType,
return getRuntimeCallGenerator("tan", ftype)(builder, loc, {arg});
}
-// THIS_GRID
-mlir::Value IntrinsicLibrary::genThisGrid(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 0);
- auto recTy = mlir::cast<fir::RecordType>(resultType);
- assert(recTy && "RecordType expepected");
- mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
- mlir::Type i32Ty = builder.getI32Type();
+// TEAM_NUMBER
+fir::ExtendedValue
+IntrinsicLibrary::genTeamNumber(mlir::Type resultType,
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ converter->checkCoarrayEnabled();
+ assert(args.size() == 1);
- mlir::Value threadIdX = mlir::NVVM::ThreadIdXOp::create(builder, loc, i32Ty);
- mlir::Value threadIdY = mlir::NVVM::ThreadIdYOp::create(builder, loc, i32Ty);
- mlir::Value threadIdZ = mlir::NVVM::ThreadIdZOp::create(builder, loc, i32Ty);
-
- mlir::Value blockIdX = mlir::NVVM::BlockIdXOp::create(builder, loc, i32Ty);
- mlir::Value blockIdY = mlir::NVVM::BlockIdYOp::create(builder, loc, i32Ty);
- mlir::Value blockIdZ = mlir::NVVM::BlockIdZOp::create(builder, loc, i32Ty);
-
- mlir::Value blockDimX = mlir::NVVM::BlockDimXOp::create(builder, loc, i32Ty);
- mlir::Value blockDimY = mlir::NVVM::BlockDimYOp::create(builder, loc, i32Ty);
- mlir::Value blockDimZ = mlir::NVVM::BlockDimZOp::create(builder, loc, i32Ty);
- mlir::Value gridDimX = mlir::NVVM::GridDimXOp::create(builder, loc, i32Ty);
- mlir::Value gridDimY = mlir::NVVM::GridDimYOp::create(builder, loc, i32Ty);
- mlir::Value gridDimZ = mlir::NVVM::GridDimZOp::create(builder, loc, i32Ty);
-
- // this_grid.size = ((blockDim.z * gridDim.z) * (blockDim.y * gridDim.y)) *
- // (blockDim.x * gridDim.x);
- mlir::Value resZ =
- mlir::arith::MulIOp::create(builder, loc, blockDimZ, gridDimZ);
- mlir::Value resY =
- mlir::arith::MulIOp::create(builder, loc, blockDimY, gridDimY);
- mlir::Value resX =
- mlir::arith::MulIOp::create(builder, loc, blockDimX, gridDimX);
- mlir::Value resZY = mlir::arith::MulIOp::create(builder, loc, resZ, resY);
- mlir::Value size = mlir::arith::MulIOp::create(builder, loc, resZY, resX);
-
- // tmp = ((blockIdx.z * gridDim.y * gridDim.x) + (blockIdx.y * gridDim.x)) +
- // blockIdx.x;
- // this_group.rank = tmp * ((blockDim.x * blockDim.y) * blockDim.z) +
- // ((threadIdx.z * blockDim.y) * blockDim.x) +
- // (threadIdx.y * blockDim.x) + threadIdx.x + 1;
- mlir::Value r1 =
- mlir::arith::MulIOp::create(builder, loc, blockIdZ, gridDimY);
- mlir::Value r2 = mlir::arith::MulIOp::create(builder, loc, r1, gridDimX);
- mlir::Value r3 =
- mlir::arith::MulIOp::create(builder, loc, blockIdY, gridDimX);
- mlir::Value r2r3 = mlir::arith::AddIOp::create(builder, loc, r2, r3);
- mlir::Value tmp = mlir::arith::AddIOp::create(builder, loc, r2r3, blockIdX);
-
- mlir::Value bXbY =
- mlir::arith::MulIOp::create(builder, loc, blockDimX, blockDimY);
- mlir::Value bXbYbZ =
- mlir::arith::MulIOp::create(builder, loc, bXbY, blockDimZ);
- mlir::Value tZbY =
- mlir::arith::MulIOp::create(builder, loc, threadIdZ, blockDimY);
- mlir::Value tZbYbX =
- mlir::arith::MulIOp::create(builder, loc, tZbY, blockDimX);
- mlir::Value tYbX =
- mlir::arith::MulIOp::create(builder, loc, threadIdY, blockDimX);
- mlir::Value rank = mlir::arith::MulIOp::create(builder, loc, tmp, bXbYbZ);
- rank = mlir::arith::AddIOp::create(builder, loc, rank, tZbYbX);
- rank = mlir::arith::AddIOp::create(builder, loc, rank, tYbX);
- rank = mlir::arith::AddIOp::create(builder, loc, rank, threadIdX);
- mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
- rank = mlir::arith::AddIOp::create(builder, loc, rank, one);
-
- auto sizeFieldName = recTy.getTypeList()[1].first;
- mlir::Type sizeFieldTy = recTy.getTypeList()[1].second;
- mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
- mlir::Value sizeFieldIndex = fir::FieldIndexOp::create(
- builder, loc, fieldIndexType, sizeFieldName, recTy,
- /*typeParams=*/mlir::ValueRange{});
- mlir::Value sizeCoord = fir::CoordinateOp::create(
- builder, loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex);
- fir::StoreOp::create(builder, loc, size, sizeCoord);
-
- auto rankFieldName = recTy.getTypeList()[2].first;
- mlir::Type rankFieldTy = recTy.getTypeList()[2].second;
- mlir::Value rankFieldIndex = fir::FieldIndexOp::create(
- builder, loc, fieldIndexType, rankFieldName, recTy,
- /*typeParams=*/mlir::ValueRange{});
- mlir::Value rankCoord = fir::CoordinateOp::create(
- builder, loc, builder.getRefType(rankFieldTy), res, rankFieldIndex);
- fir::StoreOp::create(builder, loc, rank, rankCoord);
- return res;
+ mlir::Value res = mif::TeamNumberOp::create(builder, loc,
+ /*team*/ fir::getBase(args[0]));
+ return builder.createConvert(loc, resultType, res);
}
// THIS_IMAGE
@@ -8628,99 +8239,6 @@ IntrinsicLibrary::genThisImage(mlir::Type resultType,
return builder.createConvert(loc, resultType, res);
}
-// THIS_THREAD_BLOCK
-mlir::Value
-IntrinsicLibrary::genThisThreadBlock(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 0);
- auto recTy = mlir::cast<fir::RecordType>(resultType);
- assert(recTy && "RecordType expepected");
- mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
- mlir::Type i32Ty = builder.getI32Type();
-
- // this_thread_block%size = blockDim.z * blockDim.y * blockDim.x;
- mlir::Value blockDimX = mlir::NVVM::BlockDimXOp::create(builder, loc, i32Ty);
- mlir::Value blockDimY = mlir::NVVM::BlockDimYOp::create(builder, loc, i32Ty);
- mlir::Value blockDimZ = mlir::NVVM::BlockDimZOp::create(builder, loc, i32Ty);
- mlir::Value size =
- mlir::arith::MulIOp::create(builder, loc, blockDimZ, blockDimY);
- size = mlir::arith::MulIOp::create(builder, loc, size, blockDimX);
-
- // this_thread_block%rank = ((threadIdx.z * blockDim.y) * blockDim.x) +
- // (threadIdx.y * blockDim.x) + threadIdx.x + 1;
- mlir::Value threadIdX = mlir::NVVM::ThreadIdXOp::create(builder, loc, i32Ty);
- mlir::Value threadIdY = mlir::NVVM::ThreadIdYOp::create(builder, loc, i32Ty);
- mlir::Value threadIdZ = mlir::NVVM::ThreadIdZOp::create(builder, loc, i32Ty);
- mlir::Value r1 =
- mlir::arith::MulIOp::create(builder, loc, threadIdZ, blockDimY);
- mlir::Value r2 = mlir::arith::MulIOp::create(builder, loc, r1, blockDimX);
- mlir::Value r3 =
- mlir::arith::MulIOp::create(builder, loc, threadIdY, blockDimX);
- mlir::Value r2r3 = mlir::arith::AddIOp::create(builder, loc, r2, r3);
- mlir::Value rank = mlir::arith::AddIOp::create(builder, loc, r2r3, threadIdX);
- mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
- rank = mlir::arith::AddIOp::create(builder, loc, rank, one);
-
- auto sizeFieldName = recTy.getTypeList()[1].first;
- mlir::Type sizeFieldTy = recTy.getTypeList()[1].second;
- mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
- mlir::Value sizeFieldIndex = fir::FieldIndexOp::create(
- builder, loc, fieldIndexType, sizeFieldName, recTy,
- /*typeParams=*/mlir::ValueRange{});
- mlir::Value sizeCoord = fir::CoordinateOp::create(
- builder, loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex);
- fir::StoreOp::create(builder, loc, size, sizeCoord);
-
- auto rankFieldName = recTy.getTypeList()[2].first;
- mlir::Type rankFieldTy = recTy.getTypeList()[2].second;
- mlir::Value rankFieldIndex = fir::FieldIndexOp::create(
- builder, loc, fieldIndexType, rankFieldName, recTy,
- /*typeParams=*/mlir::ValueRange{});
- mlir::Value rankCoord = fir::CoordinateOp::create(
- builder, loc, builder.getRefType(rankFieldTy), res, rankFieldIndex);
- fir::StoreOp::create(builder, loc, rank, rankCoord);
- return res;
-}
-
-// THIS_WARP
-mlir::Value IntrinsicLibrary::genThisWarp(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- assert(args.size() == 0);
- auto recTy = mlir::cast<fir::RecordType>(resultType);
- assert(recTy && "RecordType expepected");
- mlir::Value res = fir::AllocaOp::create(builder, loc, resultType);
- mlir::Type i32Ty = builder.getI32Type();
-
- // coalesced_group%size = 32
- mlir::Value size = builder.createIntegerConstant(loc, i32Ty, 32);
- auto sizeFieldName = recTy.getTypeList()[1].first;
- mlir::Type sizeFieldTy = recTy.getTypeList()[1].second;
- mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
- mlir::Value sizeFieldIndex = fir::FieldIndexOp::create(
- builder, loc, fieldIndexType, sizeFieldName, recTy,
- /*typeParams=*/mlir::ValueRange{});
- mlir::Value sizeCoord = fir::CoordinateOp::create(
- builder, loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex);
- fir::StoreOp::create(builder, loc, size, sizeCoord);
-
- // coalesced_group%rank = threadIdx.x & 31 + 1
- mlir::Value threadIdX = mlir::NVVM::ThreadIdXOp::create(builder, loc, i32Ty);
- mlir::Value mask = builder.createIntegerConstant(loc, i32Ty, 31);
- mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
- mlir::Value masked =
- mlir::arith::AndIOp::create(builder, loc, threadIdX, mask);
- mlir::Value rank = mlir::arith::AddIOp::create(builder, loc, masked, one);
- auto rankFieldName = recTy.getTypeList()[2].first;
- mlir::Type rankFieldTy = recTy.getTypeList()[2].second;
- mlir::Value rankFieldIndex = fir::FieldIndexOp::create(
- builder, loc, fieldIndexType, rankFieldName, recTy,
- /*typeParams=*/mlir::ValueRange{});
- mlir::Value rankCoord = fir::CoordinateOp::create(
- builder, loc, builder.getRefType(rankFieldTy), res, rankFieldIndex);
- fir::StoreOp::create(builder, loc, rank, rankCoord);
- return res;
-}
-
// TRAILZ
mlir::Value IntrinsicLibrary::genTrailz(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
@@ -8942,59 +8460,6 @@ IntrinsicLibrary::genSum(mlir::Type resultType,
resultType, args);
}
-// SYNCTHREADS
-void IntrinsicLibrary::genSyncThreads(llvm::ArrayRef<fir::ExtendedValue> args) {
- mlir::NVVM::Barrier0Op::create(builder, loc);
-}
-
-// SYNCTHREADS_AND
-mlir::Value
-IntrinsicLibrary::genSyncThreadsAnd(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.and";
- mlir::MLIRContext *context = builder.getContext();
- mlir::FunctionType ftype =
- mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
- auto funcOp = builder.createFunction(loc, funcName, ftype);
- return fir::CallOp::create(builder, loc, funcOp, args).getResult(0);
-}
-
-// SYNCTHREADS_COUNT
-mlir::Value
-IntrinsicLibrary::genSyncThreadsCount(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.popc";
- mlir::MLIRContext *context = builder.getContext();
- mlir::FunctionType ftype =
- mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
- auto funcOp = builder.createFunction(loc, funcName, ftype);
- return fir::CallOp::create(builder, loc, funcOp, args).getResult(0);
-}
-
-// SYNCTHREADS_OR
-mlir::Value
-IntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType,
- llvm::ArrayRef<mlir::Value> args) {
- constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.or";
- mlir::MLIRContext *context = builder.getContext();
- mlir::FunctionType ftype =
- mlir::FunctionType::get(context, {resultType}, {args[0].getType()});
- auto funcOp = builder.createFunction(loc, funcName, ftype);
- return fir::CallOp::create(builder, loc, funcOp, args).getResult(0);
-}
-
-// SYNCWARP
-void IntrinsicLibrary::genSyncWarp(llvm::ArrayRef<fir::ExtendedValue> args) {
- assert(args.size() == 1);
- constexpr llvm::StringLiteral funcName = "llvm.nvvm.bar.warp.sync";
- mlir::Value mask = fir::getBase(args[0]);
- mlir::FunctionType funcType =
- mlir::FunctionType::get(builder.getContext(), {mask.getType()}, {});
- auto funcOp = builder.createFunction(loc, funcName, funcType);
- llvm::SmallVector<mlir::Value> argsList{mask};
- fir::CallOp::create(builder, loc, funcOp, argsList);
-}
-
// SYSTEM
fir::ExtendedValue
IntrinsicLibrary::genSystem(std::optional<mlir::Type> resultType,
@@ -9126,38 +8591,6 @@ IntrinsicLibrary::genTranspose(mlir::Type resultType,
return readAndAddCleanUp(resultMutableBox, resultType, "TRANSPOSE");
}
-// THREADFENCE
-void IntrinsicLibrary::genThreadFence(llvm::ArrayRef<fir::ExtendedValue> args) {
- constexpr llvm::StringLiteral funcName = "llvm.nvvm.membar.gl";
- mlir::FunctionType funcType =
- mlir::FunctionType::get(builder.getContext(), {}, {});
- auto funcOp = builder.createFunction(loc, funcName, funcType);
- llvm::SmallVector<mlir::Value> noArgs;
- fir::CallOp::create(builder, loc, funcOp, noArgs);
-}
-
-// THREADFENCE_BLOCK
-void IntrinsicLibrary::genThreadFenceBlock(
- llvm::ArrayRef<fir::ExtendedValue> args) {
- constexpr llvm::StringLiteral funcName = "llvm.nvvm.membar.cta";
- mlir::FunctionType funcType =
- mlir::FunctionType::get(builder.getContext(), {}, {});
- auto funcOp = builder.createFunction(loc, funcName, funcType);
- llvm::SmallVector<mlir::Value> noArgs;
- fir::CallOp::create(builder, loc, funcOp, noArgs);
-}
-
-// THREADFENCE_SYSTEM
-void IntrinsicLibrary::genThreadFenceSystem(
- llvm::ArrayRef<fir::ExtendedValue> args) {
- constexpr llvm::StringLiteral funcName = "llvm.nvvm.membar.sys";
- mlir::FunctionType funcType =
- mlir::FunctionType::get(builder.getContext(), {}, {});
- auto funcOp = builder.createFunction(loc, funcName, funcType);
- llvm::SmallVector<mlir::Value> noArgs;
- fir::CallOp::create(builder, loc, funcOp, noArgs);
-}
-
// TIME
mlir::Value IntrinsicLibrary::genTime(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
@@ -9166,46 +8599,6 @@ mlir::Value IntrinsicLibrary::genTime(mlir::Type resultType,
fir::runtime::genTime(builder, loc));
}
-// TMA_BULK_COMMIT_GROUP (CUDA)
-void IntrinsicLibrary::genTMABulkCommitGroup(
- llvm::ArrayRef<fir::ExtendedValue> args) {
- assert(args.size() == 0);
- mlir::NVVM::CpAsyncBulkCommitGroupOp::create(builder, loc);
-}
-
-// TMA_BULK_G2S (CUDA)
-void IntrinsicLibrary::genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue> args) {
- assert(args.size() == 4);
- mlir::Value barrier = convertPtrToNVVMSpace(
- builder, loc, fir::getBase(args[0]), mlir::NVVM::NVVMMemorySpace::Shared);
- mlir::Value dst =
- convertPtrToNVVMSpace(builder, loc, fir::getBase(args[2]),
- mlir::NVVM::NVVMMemorySpace::SharedCluster);
- mlir::Value src = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[1]),
- mlir::NVVM::NVVMMemorySpace::Global);
- mlir::NVVM::CpAsyncBulkGlobalToSharedClusterOp::create(
- builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {});
-}
-
-// TMA_BULK_S2G (CUDA)
-void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) {
- assert(args.size() == 3);
- mlir::Value src = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[0]),
- mlir::NVVM::NVVMMemorySpace::Shared);
- mlir::Value dst = convertPtrToNVVMSpace(builder, loc, fir::getBase(args[1]),
- mlir::NVVM::NVVMMemorySpace::Global);
- mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create(
- builder, loc, dst, src, fir::getBase(args[2]), {}, {});
-}
-
-// TMA_BULK_WAIT_GROUP (CUDA)
-void IntrinsicLibrary::genTMABulkWaitGroup(
- llvm::ArrayRef<fir::ExtendedValue> args) {
- assert(args.size() == 0);
- auto group = builder.getIntegerAttr(builder.getI32Type(), 0);
- mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc, group, {});
-}
-
// TRIM
fir::ExtendedValue
IntrinsicLibrary::genTrim(mlir::Type resultType,
@@ -9620,6 +9013,9 @@ getIntrinsicArgumentLowering(llvm::StringRef specificName) {
if (const IntrinsicHandler *ppcHandler = findPPCIntrinsicHandler(name))
if (!ppcHandler->argLoweringRules.hasDefaultRules())
return &ppcHandler->argLoweringRules;
+ if (const IntrinsicHandler *cudaHandler = findCUDAIntrinsicHandler(name))
+ if (!cudaHandler->argLoweringRules.hasDefaultRules())
+ return &cudaHandler->argLoweringRules;
return nullptr;
}
diff --git a/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp b/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
index 265e268..5a4e517 100644
--- a/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
@@ -15,6 +15,7 @@
#include "flang/Optimizer/Builder/PPCIntrinsicCall.h"
#include "flang/Evaluate/common.h"
+#include "flang/Lower/AbstractConverter.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/MutableBox.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
diff --git a/flang/lib/Optimizer/Builder/Runtime/Allocatable.cpp b/flang/lib/Optimizer/Builder/Runtime/Allocatable.cpp
index cc9f828..89f5f45 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Allocatable.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Allocatable.cpp
@@ -86,8 +86,9 @@ void fir::runtime::genAllocatableAllocate(fir::FirOpBuilder &builder,
mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType());
errMsg = fir::AbsentOp::create(builder, loc, boxNoneTy).getResult();
}
- llvm::SmallVector<mlir::Value> args{
- fir::runtime::createArguments(builder, loc, fTy, desc, asyncObject,
- hasStat, errMsg, sourceFile, sourceLine)};
+ mlir::Value deviceInit = builder.createBool(loc, false);
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, desc, asyncObject, hasStat, errMsg, sourceFile,
+ sourceLine, deviceInit)};
fir::CallOp::create(builder, loc, func, args);
}
diff --git a/flang/lib/Optimizer/Builder/Runtime/Character.cpp b/flang/lib/Optimizer/Builder/Runtime/Character.cpp
index 540ecba..e297125 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Character.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Character.cpp
@@ -94,27 +94,34 @@ fir::runtime::genCharCompare(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::arith::CmpIPredicate cmp,
mlir::Value lhsBuff, mlir::Value lhsLen,
mlir::Value rhsBuff, mlir::Value rhsLen) {
- mlir::func::FuncOp beginFunc;
- switch (discoverKind(lhsBuff.getType())) {
+ int lhsKind = discoverKind(lhsBuff.getType());
+ int rhsKind = discoverKind(rhsBuff.getType());
+ if (lhsKind != rhsKind) {
+ fir::emitFatalError(loc, "runtime does not support comparison of different "
+ "CHARACTER kind values");
+ }
+ mlir::func::FuncOp func;
+ switch (lhsKind) {
case 1:
- beginFunc = fir::runtime::getRuntimeFunc<mkRTKey(CharacterCompareScalar1)>(
+ func = fir::runtime::getRuntimeFunc<mkRTKey(CharacterCompareScalar1)>(
loc, builder);
break;
case 2:
- beginFunc = fir::runtime::getRuntimeFunc<mkRTKey(CharacterCompareScalar2)>(
+ func = fir::runtime::getRuntimeFunc<mkRTKey(CharacterCompareScalar2)>(
loc, builder);
break;
case 4:
- beginFunc = fir::runtime::getRuntimeFunc<mkRTKey(CharacterCompareScalar4)>(
+ func = fir::runtime::getRuntimeFunc<mkRTKey(CharacterCompareScalar4)>(
loc, builder);
break;
default:
- llvm_unreachable("runtime does not support CHARACTER KIND");
+ fir::emitFatalError(
+ loc, "unsupported CHARACTER kind value. Runtime expects 1, 2, or 4.");
}
- auto fTy = beginFunc.getFunctionType();
+ auto fTy = func.getFunctionType();
auto args = fir::runtime::createArguments(builder, loc, fTy, lhsBuff, rhsBuff,
lhsLen, rhsLen);
- auto tri = fir::CallOp::create(builder, loc, beginFunc, args).getResult(0);
+ auto tri = fir::CallOp::create(builder, loc, func, args).getResult(0);
auto zero = builder.createIntegerConstant(loc, tri.getType(), 0);
return mlir::arith::CmpIOp::create(builder, loc, cmp, tri, zero);
}
@@ -140,6 +147,19 @@ mlir::Value fir::runtime::genCharCompare(fir::FirOpBuilder &builder,
rhsBuffer, fir::getLen(rhs));
}
+void fir::runtime::genFCString(fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::Value resultBox, mlir::Value stringBox,
+ mlir::Value asis) {
+ auto func = fir::runtime::getRuntimeFunc<mkRTKey(FCString)>(loc, builder);
+ auto fTy = func.getFunctionType();
+ auto sourceFile = fir::factory::locationToFilename(builder, loc);
+ auto sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
+ auto args = fir::runtime::createArguments(
+ builder, loc, fTy, resultBox, stringBox, asis, sourceFile, sourceLine);
+ fir::CallOp::create(builder, loc, func, args);
+}
+
mlir::Value fir::runtime::genIndex(fir::FirOpBuilder &builder,
mlir::Location loc, int kind,
mlir::Value stringBase,
diff --git a/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp b/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp
index 110b1b2..a5f16f8 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Intrinsics.cpp
@@ -137,6 +137,15 @@ void fir::runtime::genEtime(fir::FirOpBuilder &builder, mlir::Location loc,
fir::CallOp::create(builder, loc, runtimeFunc, args);
}
+void fir::runtime::genFlush(fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::Value unit) {
+ auto runtimeFunc = fir::runtime::getRuntimeFunc<mkRTKey(Flush)>(loc, builder);
+ llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
+ builder, loc, runtimeFunc.getFunctionType(), unit);
+
+ fir::CallOp::create(builder, loc, runtimeFunc, args);
+}
+
void fir::runtime::genFree(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value ptr) {
auto runtimeFunc = fir::runtime::getRuntimeFunc<mkRTKey(Free)>(loc, builder);
@@ -461,3 +470,34 @@ mlir::Value fir::runtime::genChdir(fir::FirOpBuilder &builder,
fir::runtime::createArguments(builder, loc, func.getFunctionType(), name);
return fir::CallOp::create(builder, loc, func, args).getResult(0);
}
+
+mlir::Value fir::runtime::genIrand(fir::FirOpBuilder &builder,
+ mlir::Location loc, mlir::Value i) {
+ auto runtimeFunc = fir::runtime::getRuntimeFunc<mkRTKey(Irand)>(loc, builder);
+ mlir::FunctionType runtimeFuncTy = runtimeFunc.getFunctionType();
+
+ llvm::SmallVector<mlir::Value> args =
+ fir::runtime::createArguments(builder, loc, runtimeFuncTy, i);
+ return fir::CallOp::create(builder, loc, runtimeFunc, args).getResult(0);
+}
+
+mlir::Value fir::runtime::genRand(fir::FirOpBuilder &builder,
+ mlir::Location loc, mlir::Value i) {
+ auto runtimeFunc = fir::runtime::getRuntimeFunc<mkRTKey(Rand)>(loc, builder);
+ mlir::FunctionType runtimeFuncTy = runtimeFunc.getFunctionType();
+
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, runtimeFuncTy.getInput(2));
+
+ llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
+ builder, loc, runtimeFuncTy, i, sourceFile, sourceLine);
+ return fir::CallOp::create(builder, loc, runtimeFunc, args).getResult(0);
+}
+
+void fir::runtime::genShowDescriptor(fir::FirOpBuilder &builder,
+ mlir::Location loc, mlir::Value descAddr) {
+ mlir::func::FuncOp func{
+ fir::runtime::getRuntimeFunc<mkRTKey(ShowDescriptor)>(loc, builder)};
+ fir::CallOp::create(builder, loc, func, descAddr);
+}
diff --git a/flang/lib/Optimizer/Builder/Runtime/Main.cpp b/flang/lib/Optimizer/Builder/Runtime/Main.cpp
index 9ce5e17..2b748de 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Main.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Main.cpp
@@ -74,8 +74,8 @@ void fir::runtime::genMain(
mif::InitOp::create(builder, loc);
fir::CallOp::create(builder, loc, qqMainFn);
- fir::CallOp::create(builder, loc, stopFn);
mlir::Value ret = builder.createIntegerConstant(loc, argcTy, 0);
+ fir::CallOp::create(builder, loc, stopFn);
mlir::func::ReturnOp::create(builder, loc, ret);
}
diff --git a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
index 157d435..343d848 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
@@ -1841,7 +1841,7 @@ mlir::Value fir::runtime::genReduce(fir::FirOpBuilder &builder,
assert((fir::isa_real(eleTy) || fir::isa_integer(eleTy) ||
mlir::isa<fir::LogicalType>(eleTy)) &&
- "expect real, interger or logical");
+ "expect real, integer or logical");
auto [cat, kind] = fir::mlirTypeToCategoryKind(loc, eleTy);
mlir::func::FuncOp func;
diff --git a/flang/lib/Optimizer/Builder/TemporaryStorage.cpp b/flang/lib/Optimizer/Builder/TemporaryStorage.cpp
index 7e329e3..5db40af 100644
--- a/flang/lib/Optimizer/Builder/TemporaryStorage.cpp
+++ b/flang/lib/Optimizer/Builder/TemporaryStorage.cpp
@@ -258,13 +258,9 @@ void fir::factory::AnyVariableStack::pushValue(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::Value variable) {
hlfir::Entity entity{variable};
- mlir::Type storageElementType =
- hlfir::getFortranElementType(retValueBox.getType());
- auto [box, maybeCleanUp] =
- hlfir::convertToBox(loc, builder, entity, storageElementType);
+ mlir::Value box =
+ hlfir::genVariableBox(loc, builder, entity, entity.getBoxType());
fir::runtime::genPushDescriptor(loc, builder, opaquePtr, fir::getBase(box));
- if (maybeCleanUp)
- (*maybeCleanUp)();
}
void fir::factory::AnyVariableStack::resetFetchPosition(
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 70bb43a2..6257017 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -39,6 +39,7 @@
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
@@ -680,6 +681,22 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
llvmCall.setResAttrsAttr(resAttrs);
+ if (auto inlineAttr = call.getInlineAttrAttr()) {
+ llvmCall->removeAttr("inline_attr");
+ if (inlineAttr.getValue() == fir::FortranInlineEnum::no_inline) {
+ llvmCall.setNoInlineAttr(rewriter.getUnitAttr());
+ } else if (inlineAttr.getValue() == fir::FortranInlineEnum::inline_hint) {
+ llvmCall.setInlineHintAttr(rewriter.getUnitAttr());
+ } else if (inlineAttr.getValue() ==
+ fir::FortranInlineEnum::always_inline) {
+ llvmCall.setAlwaysInlineAttr(rewriter.getUnitAttr());
+ }
+ }
+
+ if (std::optional<mlir::ArrayAttr> optionalAccessGroups =
+ call.getAccessGroups())
+ llvmCall.setAccessGroups(*optionalAccessGroups);
+
if (memAttr)
llvmCall.setMemoryEffectsAttr(
mlir::cast<mlir::LLVM::MemoryEffectsAttr>(memAttr));
@@ -749,6 +766,44 @@ struct VolatileCastOpConversion
}
};
+/// Lower `fir.assumed_size_extent` to constant -1 of index type.
+struct AssumedSizeExtentOpConversion
+ : public fir::FIROpConversion<fir::AssumedSizeExtentOp> {
+ using FIROpConversion::FIROpConversion;
+
+ llvm::LogicalResult
+ matchAndRewrite(fir::AssumedSizeExtentOp op, OpAdaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
+ mlir::Location loc = op.getLoc();
+ mlir::Type ity = lowerTy().indexType();
+ auto cst = fir::genConstantIndex(loc, ity, rewriter, -1);
+ rewriter.replaceOp(op, cst.getResult());
+ return mlir::success();
+ }
+};
+
+/// Lower `fir.is_assumed_size_extent` to integer equality with -1.
+struct IsAssumedSizeExtentOpConversion
+ : public fir::FIROpConversion<fir::IsAssumedSizeExtentOp> {
+ using FIROpConversion::FIROpConversion;
+
+ llvm::LogicalResult
+ matchAndRewrite(fir::IsAssumedSizeExtentOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
+ mlir::Location loc = op.getLoc();
+ mlir::Value val = adaptor.getVal();
+ mlir::Type valTy = val.getType();
+ // Create constant -1 of the operand type.
+ auto negOneAttr = rewriter.getIntegerAttr(valTy, -1);
+ auto negOne =
+ mlir::LLVM::ConstantOp::create(rewriter, loc, valTy, negOneAttr);
+ auto cmp = mlir::LLVM::ICmpOp::create(
+ rewriter, loc, mlir::LLVM::ICmpPredicate::eq, val, negOne);
+ rewriter.replaceOp(op, cmp.getResult());
+ return mlir::success();
+ }
+};
+
/// convert value of from-type to value of to-type
struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
using FIROpConversion::FIROpConversion;
@@ -762,6 +817,60 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
auto fromFirTy = convert.getValue().getType();
auto toFirTy = convert.getRes().getType();
+
+ // Handle conversions between pointer-like values and memref descriptors.
+ // These are produced by FIR-to-MemRef lowering and represent descriptor
+ // conversion rather than pure value conversions.
+ if (auto memRefTy = mlir::dyn_cast<mlir::MemRefType>(toFirTy)) {
+ mlir::Location loc = convert.getLoc();
+ mlir::Value basePtr = adaptor.getValue();
+ assert(basePtr && "null base pointer");
+
+ auto [strides, offset] = memRefTy.getStridesAndOffset();
+ bool hasStaticLayout =
+ mlir::ShapedType::isStatic(offset) &&
+ llvm::none_of(strides, mlir::ShapedType::isDynamic);
+
+ auto *firConv =
+ static_cast<const fir::LLVMTypeConverter *>(this->getTypeConverter());
+ assert(firConv && "expected non-null LLVMTypeConverter");
+
+ if (memRefTy.hasStaticShape() && hasStaticLayout) {
+ // Static shape and layout: build a fully-populated descriptor.
+ mlir::Value memrefDesc = mlir::MemRefDescriptor::fromStaticShape(
+ rewriter, loc, *firConv, memRefTy, basePtr);
+ rewriter.replaceOp(convert, memrefDesc);
+ return mlir::success();
+ }
+
+ // Dynamic shape or layout: create an LLVM memref descriptor and insert
+ // the base pointer field, letting the rest of the fields be populated
+ // by subsequent lowering.
+ mlir::Type llvmMemRefTy = firConv->convertType(memRefTy);
+ auto undef = mlir::LLVM::UndefOp::create(rewriter, loc, llvmMemRefTy);
+ auto insert =
+ mlir::LLVM::InsertValueOp::create(rewriter, loc, undef, basePtr, 1);
+ rewriter.replaceOp(convert, insert);
+ return mlir::success();
+ }
+
+ if (auto memRefTy = mlir::dyn_cast<mlir::MemRefType>(fromFirTy)) {
+ // Legalize conversions *from* memref descriptors to pointer-like values
+ // by extracting the underlying buffer pointer from the descriptor.
+ mlir::Location loc = convert.getLoc();
+ mlir::Value base = adaptor.getValue();
+ auto alignedPtr =
+ mlir::LLVM::ExtractValueOp::create(rewriter, loc, base, 1);
+ auto offset = mlir::LLVM::ExtractValueOp::create(rewriter, loc, base, 2);
+ mlir::Type elementType =
+ this->getTypeConverter()->convertType(memRefTy.getElementType());
+ auto gepOp = mlir::LLVM::GEPOp::create(rewriter, loc,
+ alignedPtr.getType(), elementType,
+ alignedPtr, offset.getResult());
+ rewriter.replaceOp(convert, gepOp);
+ return mlir::success();
+ }
+
auto fromTy = convertType(fromFirTy);
auto toTy = convertType(toFirTy);
mlir::Value op0 = adaptor.getOperands()[0];
@@ -1113,7 +1222,7 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> {
mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy);
if (auto scaleSize =
fir::genAllocationScaleSize(loc, heap.getInType(), ity, rewriter))
- size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
+ size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize);
for (mlir::Value opnd : adaptor.getOperands())
size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size,
integerCast(loc, rewriter, ity, opnd));
@@ -3296,6 +3405,26 @@ private:
}
};
+/// `fir.prefetch` --> `llvm.prefetch`
+struct PrefetchOpConversion : public fir::FIROpConversion<fir::PrefetchOp> {
+ using FIROpConversion::FIROpConversion;
+
+ llvm::LogicalResult
+ matchAndRewrite(fir::PrefetchOp prefetch, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
+ mlir::IntegerAttr rw = mlir::IntegerAttr::get(rewriter.getI32Type(),
+ prefetch.getRwAttr() ? 1 : 0);
+ mlir::IntegerAttr localityHint = prefetch.getLocalityHintAttr();
+ mlir::IntegerAttr cacheType = mlir::IntegerAttr::get(
+ rewriter.getI32Type(), prefetch.getCacheTypeAttr() ? 1 : 0);
+ mlir::LLVM::Prefetch::create(rewriter, prefetch.getLoc(),
+ adaptor.getOperands().front(), rw,
+ localityHint, cacheType);
+ rewriter.eraseOp(prefetch);
+ return mlir::success();
+ }
+};
+
/// `fir.load` --> `llvm.load`
struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
using FIROpConversion::FIROpConversion;
@@ -3352,6 +3481,9 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
loadOp.setTBAATags(*optionalTag);
else
attachTBAATag(loadOp, load.getType(), load.getType(), nullptr);
+ if (std::optional<mlir::ArrayAttr> optionalAccessGroups =
+ load.getAccessGroups())
+ loadOp.setAccessGroups(*optionalAccessGroups);
rewriter.replaceOp(load, loadOp.getResult());
}
return mlir::success();
@@ -3396,6 +3528,20 @@ struct NoReassocOpConversion : public fir::FIROpConversion<fir::NoReassocOp> {
}
};
+/// Erase `fir.use_stmt` operations during LLVM lowering.
+/// These operations are only used for debug info generation by the
+/// AddDebugInfo pass and have no runtime representation.
+struct UseStmtOpConversion : public fir::FIROpConversion<fir::UseStmtOp> {
+ using FIROpConversion::FIROpConversion;
+
+ llvm::LogicalResult
+ matchAndRewrite(fir::UseStmtOp useStmt, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
+ rewriter.eraseOp(useStmt);
+ return mlir::success();
+ }
+};
+
static void genCondBrOp(mlir::Location loc, mlir::Value cmp, mlir::Block *dest,
std::optional<mlir::ValueRange> destOps,
mlir::ConversionPatternRewriter &rewriter,
@@ -3466,6 +3612,11 @@ struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> {
mlir::Block *dest = caseOp.getSuccessor(t);
std::optional<mlir::ValueRange> destOps =
caseOp.getSuccessorOperands(adaptor.getOperands(), t);
+      // Convert the block signature if needed.
+ if (destOps && !destOps->empty())
+ if (auto conversion = getTypeConverter()->convertBlockSignature(dest))
+ dest = rewriter.applySignatureConversion(dest, *conversion,
+ getTypeConverter());
std::optional<mlir::ValueRange> cmpOps =
*caseOp.getCompareOperands(adaptor.getOperands(), t);
mlir::Attribute attr = cases[t];
@@ -3683,6 +3834,10 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
if (store.getNontemporal())
storeOp.setNontemporal(true);
+ if (std::optional<mlir::ArrayAttr> optionalAccessGroups =
+ store.getAccessGroups())
+ storeOp.setAccessGroups(*optionalAccessGroups);
+
newOp = storeOp;
}
if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())
@@ -4360,6 +4515,7 @@ void fir::populateFIRToLLVMConversionPatterns(
AllocaOpConversion, AllocMemOpConversion, BoxAddrOpConversion,
BoxCharLenOpConversion, BoxDimsOpConversion, BoxEleSizeOpConversion,
BoxIsAllocOpConversion, BoxIsArrayOpConversion, BoxIsPtrOpConversion,
+ AssumedSizeExtentOpConversion, IsAssumedSizeExtentOpConversion,
BoxOffsetOpConversion, BoxProcHostOpConversion, BoxRankOpConversion,
BoxTypeCodeOpConversion, BoxTypeDescOpConversion, CallOpConversion,
CmpcOpConversion, VolatileCastOpConversion, ConvertOpConversion,
@@ -4372,14 +4528,15 @@ void fir::populateFIRToLLVMConversionPatterns(
FirEndOpConversion, FreeMemOpConversion, GlobalLenOpConversion,
GlobalOpConversion, InsertOnRangeOpConversion, IsPresentOpConversion,
LenParamIndexOpConversion, LoadOpConversion, MulcOpConversion,
- NegcOpConversion, NoReassocOpConversion, SelectCaseOpConversion,
- SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
- ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion,
- SliceOpConversion, StoreOpConversion, StringLitOpConversion,
- SubcOpConversion, TypeDescOpConversion, TypeInfoOpConversion,
- UnboxCharOpConversion, UnboxProcOpConversion, UndefOpConversion,
- UnreachableOpConversion, XArrayCoorOpConversion, XEmboxOpConversion,
- XReboxOpConversion, ZeroOpConversion>(converter, options);
+ NegcOpConversion, NoReassocOpConversion, PrefetchOpConversion,
+ SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion,
+ SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion,
+ ShiftOpConversion, SliceOpConversion, StoreOpConversion,
+ StringLitOpConversion, SubcOpConversion, TypeDescOpConversion,
+ TypeInfoOpConversion, UnboxCharOpConversion, UnboxProcOpConversion,
+ UndefOpConversion, UnreachableOpConversion, UseStmtOpConversion,
+ XArrayCoorOpConversion, XEmboxOpConversion, XReboxOpConversion,
+ ZeroOpConversion>(converter, options);
// Patterns that are populated without a type converter do not trigger
// target materializations for the operands of the root op.
diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
index 381b2a2..3e1fe1d 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp
@@ -242,10 +242,11 @@ struct TargetAllocMemOpConversion
loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout());
if (auto scaleSize = fir::genAllocationScaleSize(
loc, allocmemOp.getInType(), ity, rewriter))
- size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
+ size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize);
for (mlir::Value opnd : adaptor.getOperands().drop_front())
- size = rewriter.create<mlir::LLVM::MulOp>(
- loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd));
+ size = mlir::LLVM::MulOp::create(
+ rewriter, loc, ity, size,
+ integerCast(lowerTy(), loc, rewriter, ity, opnd));
auto mallocTyWidth = lowerTy().getIndexTypeBitwidth();
auto mallocTy =
mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth);
@@ -259,6 +260,21 @@ struct TargetAllocMemOpConversion
return mlir::success();
}
};
+
+struct DeclareMapperOpConversion
+ : public OpenMPFIROpConversion<mlir::omp::DeclareMapperOp> {
+ using OpenMPFIROpConversion::OpenMPFIROpConversion;
+
+ llvm::LogicalResult
+ matchAndRewrite(mlir::omp::DeclareMapperOp curOp, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
+ rewriter.startOpModification(curOp);
+ curOp.setType(convertObjectType(lowerTy(), curOp.getType()));
+ rewriter.finalizeOpModification(curOp);
+ return mlir::success();
+ }
+};
+
} // namespace
void fir::populateOpenMPFIRToLLVMConversionPatterns(
@@ -266,4 +282,5 @@ void fir::populateOpenMPFIRToLLVMConversionPatterns(
patterns.add<MapInfoOpConversion>(converter);
patterns.add<PrivateClauseOpConversion>(converter);
patterns.add<TargetAllocMemOpConversion>(converter);
+ patterns.add<DeclareMapperOpConversion>(converter);
}
diff --git a/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp b/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp
index ac432c7..81488d7 100644
--- a/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp
+++ b/flang/lib/Optimizer/CodeGen/LowerRepackArrays.cpp
@@ -289,7 +289,6 @@ PackArrayConversion::genRepackedBox(fir::FirOpBuilder &builder,
fir::factory::genDimInfoFromBox(builder, loc, box, &lbounds, &extents,
/*strides=*/nullptr);
// Get the type parameters from the box, if needed.
- llvm::SmallVector<mlir::Value> assumedTypeParams;
if (numTypeParams != 0) {
if (auto charType =
mlir::dyn_cast<fir::CharacterType>(boxType.unwrapInnerType()))
diff --git a/flang/lib/Optimizer/CodeGen/PassDetail.h b/flang/lib/Optimizer/CodeGen/PassDetail.h
index f703013..252da02 100644
--- a/flang/lib/Optimizer/CodeGen/PassDetail.h
+++ b/flang/lib/Optimizer/CodeGen/PassDetail.h
@@ -18,7 +18,7 @@
namespace fir {
-#define GEN_PASS_CLASSES
+#define GEN_PASS_DECL
#include "flang/Optimizer/CodeGen/CGPasses.h.inc"
} // namespace fir
diff --git a/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp b/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp
index 1b1d43c..3b137d1 100644
--- a/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp
@@ -302,11 +302,16 @@ public:
else
return mlir::failure();
}
+    // Extract the dummy_arg_no attribute, if present.
+ mlir::IntegerAttr dummyArgNoAttr;
+ if (auto attr = declareOp->getAttrOfType<mlir::IntegerAttr>("dummy_arg_no"))
+ dummyArgNoAttr = attr;
// FIXME: Add FortranAttrs and CudaAttrs
auto xDeclOp = fir::cg::XDeclareOp::create(
rewriter, loc, declareOp.getType(), declareOp.getMemref(), shapeOpers,
shiftOpers, declareOp.getTypeparams(), declareOp.getDummyScope(),
- declareOp.getUniqName());
+ declareOp.getStorage(), declareOp.getStorageOffset(),
+ declareOp.getUniqName(), dummyArgNoAttr);
LLVM_DEBUG(llvm::dbgs()
<< "rewriting " << declareOp << " to " << xDeclOp << '\n');
rewriter.replaceOp(declareOp, xDeclOp.getOperation()->getResults());
diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index b60a72e..9b6c9be 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -353,7 +353,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
ArgClass &current = byteOffset < 8 ? Lo : Hi;
// System V AMD64 ABI 3.2.3. version 1.0
llvm::TypeSwitch<mlir::Type>(type)
- .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+ .Case([&](mlir::IntegerType intTy) {
if (intTy.getWidth() == 128)
Hi = Lo = ArgClass::Integer;
else
@@ -371,7 +371,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
current = ArgClass::SSE;
}
})
- .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+ .Case([&](mlir::ComplexType cmplx) {
const auto *sem = &floatToSemantics(kindMap, cmplx.getElementType());
if (sem == &llvm::APFloat::x87DoubleExtended()) {
current = ArgClass::ComplexX87;
@@ -382,23 +382,23 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
byteOffset, Lo, Hi);
}
})
- .template Case<fir::LogicalType>([&](fir::LogicalType logical) {
+ .Case([&](fir::LogicalType logical) {
if (kindMap.getLogicalBitsize(logical.getFKind()) == 128)
Hi = Lo = ArgClass::Integer;
else
current = ArgClass::Integer;
})
- .template Case<fir::CharacterType>(
+ .Case(
[&](fir::CharacterType character) { current = ArgClass::Integer; })
- .template Case<fir::SequenceType>([&](fir::SequenceType seqTy) {
+ .Case([&](fir::SequenceType seqTy) {
// Array component.
classifyArray(loc, seqTy, byteOffset, Lo, Hi);
})
- .template Case<fir::RecordType>([&](fir::RecordType recTy) {
+ .Case([&](fir::RecordType recTy) {
// Component that is a derived type.
classifyStruct(loc, recTy, byteOffset, Lo, Hi);
})
- .template Case<fir::VectorType>([&](fir::VectorType vecTy) {
+ .Case([&](fir::VectorType vecTy) {
// Previously marshalled SSE eight byte for a previous struct
// argument.
auto *sem = fir::isa_real(vecTy.getEleTy())
@@ -939,23 +939,23 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
NRegs usedRegsForType(mlir::Location loc, mlir::Type type) const {
return llvm::TypeSwitch<mlir::Type, NRegs>(type)
- .Case<mlir::IntegerType>([&](auto intTy) {
+ .Case([&](mlir::IntegerType intTy) {
return intTy.getWidth() == 128 ? NRegs{2, false} : NRegs{1, false};
})
- .Case<mlir::FloatType>([&](auto) { return NRegs{1, true}; })
- .Case<mlir::ComplexType>([&](auto) { return NRegs{2, true}; })
- .Case<fir::LogicalType>([&](auto) { return NRegs{1, false}; })
- .Case<fir::CharacterType>([&](auto) { return NRegs{1, false}; })
- .Case<fir::SequenceType>([&](auto ty) {
+ .Case([&](mlir::FloatType) { return NRegs{1, true}; })
+ .Case([&](mlir::ComplexType) { return NRegs{2, true}; })
+ .Case([&](fir::LogicalType) { return NRegs{1, false}; })
+ .Case([&](fir::CharacterType) { return NRegs{1, false}; })
+ .Case([&](fir::SequenceType ty) {
assert(ty.getShape().size() == 1 &&
"invalid array dimensions in BIND(C)");
NRegs nregs = usedRegsForType(loc, ty.getEleTy());
nregs.n *= ty.getShape()[0];
return nregs;
})
- .Case<fir::RecordType>(
- [&](auto ty) { return usedRegsForRecordType(loc, ty); })
- .Case<fir::VectorType>([&](auto) {
+ .Case(
+ [&](fir::RecordType ty) { return usedRegsForRecordType(loc, ty); })
+ .Case([&](fir::VectorType) {
TODO(loc, "passing vector argument to C by value is not supported");
return NRegs{};
})
@@ -1167,13 +1167,12 @@ struct TargetPPC64le : public GenericTarget<TargetPPC64le> {
unsigned getElemWidth(mlir::Type ty) const {
unsigned width{};
llvm::TypeSwitch<mlir::Type>(ty)
- .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+ .Case([&](mlir::ComplexType cmplx) {
auto elemType{
mlir::dyn_cast<mlir::FloatType>(cmplx.getElementType())};
width = elemType.getWidth();
})
- .template Case<mlir::FloatType>(
- [&](mlir::FloatType real) { width = real.getWidth(); });
+ .Case([&](mlir::FloatType real) { width = real.getWidth(); });
return width;
}
@@ -1594,15 +1593,15 @@ struct TargetLoongArch64 : public GenericTarget<TargetLoongArch64> {
llvm::SmallVector<mlir::Type> flatTypes;
llvm::TypeSwitch<mlir::Type>(type)
- .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+ .Case([&](mlir::IntegerType intTy) {
if (intTy.getWidth() != 0)
flatTypes.push_back(intTy);
})
- .template Case<mlir::FloatType>([&](mlir::FloatType floatTy) {
+ .Case([&](mlir::FloatType floatTy) {
if (floatTy.getWidth() != 0)
flatTypes.push_back(floatTy);
})
- .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+ .Case([&](mlir::ComplexType cmplx) {
const auto *sem = &floatToSemantics(kindMap, cmplx.getElementType());
if (sem == &llvm::APFloat::IEEEsingle() ||
sem == &llvm::APFloat::IEEEdouble() ||
@@ -1614,21 +1613,21 @@ struct TargetLoongArch64 : public GenericTarget<TargetLoongArch64> {
"IEEEquad) as a structure component for BIND(C), "
"VALUE derived type argument and type return");
})
- .template Case<fir::LogicalType>([&](fir::LogicalType logicalTy) {
+ .Case([&](fir::LogicalType logicalTy) {
const unsigned width =
kindMap.getLogicalBitsize(logicalTy.getFKind());
if (width != 0)
flatTypes.push_back(
mlir::IntegerType::get(type.getContext(), width));
})
- .template Case<fir::CharacterType>([&](fir::CharacterType charTy) {
+ .Case([&](fir::CharacterType charTy) {
assert(kindMap.getCharacterBitsize(charTy.getFKind()) <= 8 &&
"the bit size of characterType as an interoperable type must "
"not exceed 8");
for (unsigned i = 0; i < charTy.getLen(); ++i)
flatTypes.push_back(mlir::IntegerType::get(type.getContext(), 8));
})
- .template Case<fir::SequenceType>([&](fir::SequenceType seqTy) {
+ .Case([&](fir::SequenceType seqTy) {
if (!seqTy.hasDynamicExtents()) {
const std::uint64_t numOfEle = seqTy.getConstantArraySize();
mlir::Type eleTy = seqTy.getEleTy();
@@ -1646,7 +1645,7 @@ struct TargetLoongArch64 : public GenericTarget<TargetLoongArch64> {
"component for BIND(C), "
"VALUE derived type argument and type return");
})
- .template Case<fir::RecordType>([&](fir::RecordType recTy) {
+ .Case([&](fir::RecordType recTy) {
for (auto &component : recTy.getTypeList()) {
mlir::Type eleTy = component.second;
llvm::SmallVector<mlir::Type> subTypeList =
@@ -1655,7 +1654,7 @@ struct TargetLoongArch64 : public GenericTarget<TargetLoongArch64> {
llvm::copy(subTypeList, std::back_inserter(flatTypes));
}
})
- .template Case<fir::VectorType>([&](fir::VectorType vecTy) {
+ .Case([&](fir::VectorType vecTy) {
auto sizeAndAlign = fir::getTypeSizeAndAlignmentOrCrash(
loc, vecTy, getDataLayout(), kindMap);
if (sizeAndAlign.first == 2 * GRLenInChar)
@@ -1742,7 +1741,7 @@ struct TargetLoongArch64 : public GenericTarget<TargetLoongArch64> {
return true;
llvm::TypeSwitch<mlir::Type>(type)
- .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+ .Case([&](mlir::IntegerType intTy) {
const unsigned width = intTy.getWidth();
if (width > 128)
TODO(loc,
@@ -1754,7 +1753,7 @@ struct TargetLoongArch64 : public GenericTarget<TargetLoongArch64> {
else if (width <= 2 * GRLen)
GARsLeft = GARsLeft - 2;
})
- .template Case<mlir::FloatType>([&](mlir::FloatType floatTy) {
+ .Case([&](mlir::FloatType floatTy) {
const unsigned width = floatTy.getWidth();
if (width > 128)
TODO(loc, "floatType with width exceeding 128 bits is unsupported");
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index ac285b5..3ef4703 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -143,7 +143,8 @@ public:
llvm::SmallVector<mlir::Type> operandsTypes;
for (auto arg : gpuLaunchFunc.getKernelOperands())
operandsTypes.push_back(arg.getType());
- auto fctTy = mlir::FunctionType::get(&context, operandsTypes, {});
+ auto fctTy = mlir::FunctionType::get(&context, operandsTypes,
+ gpuLaunchFunc.getResultTypes());
if (!hasPortableSignature(fctTy, op))
convertCallOp(gpuLaunchFunc, fctTy);
} else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
@@ -392,12 +393,12 @@ public:
if (fnTy.getResults().size() == 1) {
mlir::Type ty = fnTy.getResult(0);
llvm::TypeSwitch<mlir::Type>(ty)
- .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+ .Case([&](mlir::ComplexType cmplx) {
wrap = rewriteCallComplexResultType(loc, cmplx, newResTys,
newInTyAndAttrs, newOpers,
savedStackPtr);
})
- .template Case<fir::RecordType>([&](fir::RecordType recTy) {
+ .Case([&](fir::RecordType recTy) {
wrap = rewriteCallStructResultType(loc, recTy, newResTys,
newInTyAndAttrs, newOpers,
savedStackPtr);
@@ -421,7 +422,7 @@ public:
mlir::Value oper = std::get<1>(e.value());
unsigned index = e.index();
llvm::TypeSwitch<mlir::Type>(ty)
- .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
+ .Case([&](fir::BoxCharType boxTy) {
if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
if (noCharacterConversion) {
newInTyAndAttrs.push_back(
@@ -455,15 +456,15 @@ public:
}
}
})
- .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+ .Case([&](mlir::ComplexType cmplx) {
rewriteCallComplexInputType(loc, cmplx, oper, newInTyAndAttrs,
newOpers, savedStackPtr);
})
- .template Case<fir::RecordType>([&](fir::RecordType recTy) {
+ .Case([&](fir::RecordType recTy) {
rewriteCallStructInputType(loc, recTy, oper, newInTyAndAttrs,
newOpers, savedStackPtr);
})
- .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
+ .Case([&](mlir::TupleType tuple) {
if (fir::isCharacterProcedureTuple(tuple)) {
mlir::ModuleOp module = getModule();
if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
@@ -520,10 +521,14 @@ public:
llvm::SmallVector<mlir::Value, 1> newCallResults;
// TODO propagate/update call argument and result attributes.
if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) {
+ mlir::Value asyncToken = callOp.getAsyncToken();
auto newCall = A::create(*rewriter, loc, callOp.getKernel(),
callOp.getGridSizeOperandValues(),
callOp.getBlockSizeOperandValues(),
- callOp.getDynamicSharedMemorySize(), newOpers);
+ callOp.getDynamicSharedMemorySize(), newOpers,
+ asyncToken ? asyncToken.getType() : nullptr,
+ callOp.getAsyncDependencies(),
+ /*clusterSize=*/std::nullopt);
if (callOp.getClusterSizeX())
newCall.getClusterSizeXMutable().assign(callOp.getClusterSizeX());
if (callOp.getClusterSizeY())
@@ -702,10 +707,10 @@ public:
auto loc = addrOp.getLoc();
for (mlir::Type ty : addrTy.getResults()) {
llvm::TypeSwitch<mlir::Type>(ty)
- .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
+ .Case([&](mlir::ComplexType ty) {
lowerComplexSignatureRes(loc, ty, newResTys, newInTyAndAttrs);
})
- .Case<fir::RecordType>([&](fir::RecordType ty) {
+ .Case([&](fir::RecordType ty) {
lowerStructSignatureRes(loc, ty, newResTys, newInTyAndAttrs);
})
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
@@ -713,7 +718,7 @@ public:
llvm::SmallVector<mlir::Type> trailingInTys;
for (mlir::Type ty : addrTy.getInputs()) {
llvm::TypeSwitch<mlir::Type>(ty)
- .Case<fir::BoxCharType>([&](auto box) {
+ .Case([&](fir::BoxCharType box) {
if (noCharacterConversion) {
newInTyAndAttrs.push_back(
fir::CodeGenSpecifics::getTypeAndAttr(box));
@@ -728,10 +733,10 @@ public:
}
}
})
- .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
+ .Case([&](mlir::ComplexType ty) {
lowerComplexSignatureArg(loc, ty, newInTyAndAttrs);
})
- .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
+ .Case([&](mlir::TupleType tuple) {
if (fir::isCharacterProcedureTuple(tuple)) {
newInTyAndAttrs.push_back(
fir::CodeGenSpecifics::getTypeAndAttr(tuple.getType(0)));
@@ -741,7 +746,7 @@ public:
fir::CodeGenSpecifics::getTypeAndAttr(ty));
}
})
- .template Case<fir::RecordType>([&](fir::RecordType recTy) {
+ .Case([&](fir::RecordType recTy) {
lowerStructSignatureArg(loc, recTy, newInTyAndAttrs);
})
.Default([&](mlir::Type ty) {
@@ -872,16 +877,24 @@ public:
}
}
+ // Count the number of arguments that have to stay in place at the end of
+ // the argument list.
+ unsigned trailingArgs = 0;
+ if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>) {
+ trailingArgs =
+ func.getNumWorkgroupAttributions() + func.getNumPrivateAttributions();
+ }
+
// Convert return value(s)
for (auto ty : funcTy.getResults())
llvm::TypeSwitch<mlir::Type>(ty)
- .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+ .Case([&](mlir::ComplexType cmplx) {
if (noComplexConversion)
newResTys.push_back(cmplx);
else
doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups);
})
- .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+ .Case([&](mlir::IntegerType intTy) {
auto m = specifics->integerArgumentType(func.getLoc(), intTy);
assert(m.size() == 1);
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
@@ -895,7 +908,7 @@ public:
rewriter->getUnitAttr()));
newResTys.push_back(retTy);
})
- .template Case<fir::RecordType>([&](fir::RecordType recTy) {
+ .Case([&](fir::RecordType recTy) {
doStructReturn(func, recTy, newResTys, newInTyAndAttrs, fixups);
})
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
@@ -910,7 +923,7 @@ public:
auto ty = e.value();
unsigned index = e.index();
llvm::TypeSwitch<mlir::Type>(ty)
- .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
+ .Case([&](fir::BoxCharType boxTy) {
if (noCharacterConversion) {
newInTyAndAttrs.push_back(
fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
@@ -933,10 +946,10 @@ public:
}
}
})
- .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+ .Case([&](mlir::ComplexType cmplx) {
doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
})
- .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
+ .Case([&](mlir::TupleType tuple) {
if (fir::isCharacterProcedureTuple(tuple)) {
fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
newInTyAndAttrs.size(), trailingTys.size());
@@ -948,7 +961,7 @@ public:
fir::CodeGenSpecifics::getTypeAndAttr(ty));
}
})
- .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+ .Case([&](mlir::IntegerType intTy) {
auto m = specifics->integerArgumentType(func.getLoc(), intTy);
assert(m.size() == 1);
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
@@ -965,7 +978,7 @@ public:
newInTyAndAttrs.push_back(m[0]);
})
- .template Case<fir::RecordType>([&](fir::RecordType recTy) {
+ .Case([&](fir::RecordType recTy) {
doStructArg(func, recTy, newInTyAndAttrs, fixups);
})
.Default([&](mlir::Type ty) {
@@ -981,6 +994,16 @@ public:
}
}
+ // Add the argument at the end if the number of trailing arguments is 0,
+ // otherwise insert the argument at the appropriate index.
+ auto addOrInsertArgument = [&](mlir::Type ty, mlir::Location loc) {
+ unsigned inputIndex = func.front().getArguments().size() - trailingArgs;
+ auto newArg = trailingArgs == 0
+ ? func.front().addArgument(ty, loc)
+ : func.front().insertArgument(inputIndex, ty, loc);
+ return newArg;
+ };
+
if (!func.empty()) {
// If the function has a body, then apply the fixups to the arguments and
// return ops as required. These fixups are done in place.
@@ -1117,8 +1140,7 @@ public:
// original arguments. (Boxchar arguments.)
auto newBufArg =
func.front().insertArgument(fixup.index, fixupType, loc);
- auto newLenArg =
- func.front().addArgument(trailingTys[fixup.second], loc);
+ auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc);
auto boxTy = oldArgTys[fixup.index - offset];
rewriter->setInsertionPointToStart(&func.front());
auto box = fir::EmboxCharOp::create(*rewriter, loc, boxTy, newBufArg,
@@ -1133,8 +1155,7 @@ public:
// appended after all the original arguments.
auto newProcPointerArg =
func.front().insertArgument(fixup.index, fixupType, loc);
- auto newLenArg =
- func.front().addArgument(trailingTys[fixup.second], loc);
+ auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc);
auto tupleType = oldArgTys[fixup.index - offset];
rewriter->setInsertionPointToStart(&func.front());
fir::FirOpBuilder builder(*rewriter, getModule());
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index 2283560..3c4162c 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -163,8 +163,8 @@ LLVMTypeConverter::convertRecordType(fir::RecordType derived,
return mlir::success();
}
callStack.push_back(derived);
- auto popConversionCallStack =
- llvm::make_scope_exit([&callStack]() { callStack.pop_back(); });
+ llvm::scope_exit popConversionCallStack(
+ [&callStack]() { callStack.pop_back(); });
llvm::SmallVector<mlir::Type> members;
for (auto mem : derived.getTypeList()) {
diff --git a/flang/lib/Optimizer/Dialect/CMakeLists.txt b/flang/lib/Optimizer/Dialect/CMakeLists.txt
index 65d1f2c..f81989a 100644
--- a/flang/lib/Optimizer/Dialect/CMakeLists.txt
+++ b/flang/lib/Optimizer/Dialect/CMakeLists.txt
@@ -6,6 +6,7 @@ add_subdirectory(MIF)
add_flang_library(FIRDialect
FIRAttr.cpp
FIRDialect.cpp
+ FIROperationMoveOpInterface.cpp
FIROps.cpp
FIRType.cpp
FirAliasTagOpInterface.cpp
@@ -15,6 +16,7 @@ add_flang_library(FIRDialect
DEPENDS
CanonicalizationPatternsIncGen
+ FIROperationMoveOpInterfaceIncGen
FIROpsIncGen
FIRSafeTempArrayCopyAttrInterfaceIncGen
CUFAttrsIncGen
diff --git a/flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp b/flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp
index bd0499f..3f58065 100644
--- a/flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/Attributes/CUFAttr.cpp
@@ -52,4 +52,18 @@ bool hasDataAttr(mlir::Operation *op, cuf::DataAttribute value) {
return false;
}
+bool isDeviceDataAttribute(cuf::DataAttribute attr) {
+ return attr == cuf::DataAttribute::Device ||
+ attr == cuf::DataAttribute::Managed ||
+ attr == cuf::DataAttribute::Constant ||
+ attr == cuf::DataAttribute::Shared ||
+ attr == cuf::DataAttribute::Unified;
+}
+
+bool hasDeviceDataAttr(mlir::Operation *op) {
+ if (auto dataAttr = getDataAttr(op))
+ return isDeviceDataAttribute(dataAttr.getValue());
+ return false;
+}
+
} // namespace cuf
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index 687007d..a157c47 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -274,6 +274,26 @@ llvm::LogicalResult cuf::KernelOp::verify() {
return checkStreamType(*this);
}
+bool cuf::KernelOp::canMoveFromDescendant(mlir::Operation *descendant,
+ mlir::Operation *candidate) {
+ // Moving operations out of loops inside cuf.kernel is always legal.
+ return true;
+}
+
+bool cuf::KernelOp::canMoveOutOf(mlir::Operation *candidate) {
+ // In general, some movement of operations out of cuf.kernel is allowed.
+ if (!candidate)
+ return true;
+
+ // Operations that have !fir.ref operands cannot be moved
+ // out of cuf.kernel, because this may break implicit data mapping
+ // passes that may run after LICM.
+ return !llvm::any_of(candidate->getOperands(),
+ [&](mlir::Value candidateOperand) {
+ return fir::isa_ref_type(candidateOperand.getType());
+ });
+}
+
//===----------------------------------------------------------------------===//
// RegisterKernelOp
//===----------------------------------------------------------------------===//
@@ -333,7 +353,8 @@ void cuf::SharedMemoryOp::build(
bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName);
build(builder, result, wrapAllocaResultType(inType),
mlir::TypeAttr::get(inType), nameAttr, bindcAttr, typeparams, shape,
- /*offset=*/mlir::Value{});
+ /*offset=*/mlir::Value{}, /*alignment=*/mlir::IntegerAttr{},
+ /*isStatic=*/nullptr);
result.addAttributes(attributes);
}
diff --git a/flang/lib/Optimizer/Dialect/FIROperationMoveOpInterface.cpp b/flang/lib/Optimizer/Dialect/FIROperationMoveOpInterface.cpp
new file mode 100644
index 0000000..dcf5323
--- /dev/null
+++ b/flang/lib/Optimizer/Dialect/FIROperationMoveOpInterface.cpp
@@ -0,0 +1,49 @@
+//===-- FIROperationMoveOpInterface.cpp -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIROperationMoveOpInterface.h"
+
+#include "flang/Optimizer/Dialect/FIROperationMoveOpInterface.cpp.inc"
+
+llvm::LogicalResult
+fir::detail::verifyOperationMoveOpInterface(mlir::Operation *op) {
+ // It does not make sense to use this interface for operations
+ // without any regions.
+ if (op->getNumRegions() == 0)
+ return op->emitOpError("must contain at least one region");
+ return llvm::success();
+}
+
+bool fir::canMoveFromDescendant(mlir::Operation *op,
+ mlir::Operation *descendant,
+ mlir::Operation *candidate) {
+ // Perform some sanity checks.
+ assert(op->isProperAncestor(descendant) &&
+ "op must be an ancestor of descendant");
+ if (candidate)
+ assert(descendant->isProperAncestor(candidate) &&
+ "descendant must be an ancestor of candidate");
+ if (auto iface = mlir::dyn_cast<OperationMoveOpInterface>(op))
+ return iface.canMoveFromDescendant(descendant, candidate);
+
+ return true;
+}
+
+bool fir::canMoveOutOf(mlir::Operation *op, mlir::Operation *candidate) {
+ if (candidate)
+ assert(op->isProperAncestor(candidate) &&
+ "op must be an ancestor of candidate");
+ if (auto iface = mlir::dyn_cast<OperationMoveOpInterface>(op))
+ return iface.canMoveOutOf(candidate);
+
+ return true;
+}
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 1712af1..9c22b61 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -174,6 +174,32 @@ static void printAllocatableOp(mlir::OpAsmPrinter &p, OP &op) {
p.printOptionalAttrDict(op->getAttrs(), {"in_type", "operandSegmentSizes"});
}
+bool fir::mayBeAbsentBox(mlir::Value val) {
+ assert(mlir::isa<fir::BaseBoxType>(val.getType()) && "expected box argument");
+ while (val) {
+ mlir::Operation *defOp = val.getDefiningOp();
+ if (!defOp)
+ return true;
+
+ if (auto varIface = mlir::dyn_cast<fir::FortranVariableOpInterface>(defOp))
+ return varIface.isOptional();
+
+ // Check for fir.embox and fir.rebox before checking for
+ // FortranObjectViewOpInterface, which they support.
+ // A box created by fir.embox/rebox cannot be absent.
+ if (mlir::isa<fir::ReboxOp, fir::EmboxOp, fir::LoadOp>(defOp))
+ return false;
+
+ if (auto viewIface =
+ mlir::dyn_cast<fir::FortranObjectViewOpInterface>(defOp)) {
+ val = viewIface.getViewSource(mlir::cast<mlir::OpResult>(val));
+ continue;
+ }
+ break;
+ }
+ return true;
+}
+
//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//
@@ -186,6 +212,36 @@ static mlir::Type wrapAllocaResultType(mlir::Type intype) {
return fir::ReferenceType::get(intype);
}
+llvm::SmallVector<mlir::MemorySlot> fir::AllocaOp::getPromotableSlots() {
+ // TODO: support promotion of dynamic allocas
+ if (isDynamic())
+ return {};
+
+ return {mlir::MemorySlot{getResult(), getAllocatedType()}};
+}
+
+mlir::Value fir::AllocaOp::getDefaultValue(const mlir::MemorySlot &slot,
+ mlir::OpBuilder &builder) {
+ return fir::UndefOp::create(builder, getLoc(), slot.elemType);
+}
+
+void fir::AllocaOp::handleBlockArgument(const mlir::MemorySlot &slot,
+ mlir::BlockArgument argument,
+ mlir::OpBuilder &builder) {}
+
+std::optional<mlir::PromotableAllocationOpInterface>
+fir::AllocaOp::handlePromotionComplete(const mlir::MemorySlot &slot,
+ mlir::Value defaultValue,
+ mlir::OpBuilder &builder) {
+ if (defaultValue && defaultValue.use_empty()) {
+ assert(mlir::isa<fir::UndefOp>(defaultValue.getDefiningOp()) &&
+ "Expected undef op to be the default value");
+ defaultValue.getDefiningOp()->erase();
+ }
+ this->erase();
+ return std::nullopt;
+}
+
mlir::Type fir::AllocaOp::getAllocatedType() {
return mlir::cast<fir::ReferenceType>(getType()).getEleTy();
}
@@ -834,6 +890,11 @@ void fir::ArrayCoorOp::getCanonicalizationPatterns(
patterns.add<SimplifyArrayCoorOp>(context);
}
+std::optional<std::int64_t> fir::ArrayCoorOp::getViewOffset(mlir::OpResult) {
+ // TODO: we can try to compute the constant offset.
+ return std::nullopt;
+}
+
//===----------------------------------------------------------------------===//
// ArrayLoadOp
//===----------------------------------------------------------------------===//
@@ -1054,17 +1115,16 @@ void fir::BoxAddrOp::build(mlir::OpBuilder &builder,
mlir::OperationState &result, mlir::Value val) {
mlir::Type type =
llvm::TypeSwitch<mlir::Type, mlir::Type>(val.getType())
- .Case<fir::BaseBoxType>([&](fir::BaseBoxType ty) -> mlir::Type {
+ .Case([&](fir::BaseBoxType ty) -> mlir::Type {
mlir::Type eleTy = ty.getEleTy();
if (fir::isa_ref_type(eleTy))
return eleTy;
return fir::ReferenceType::get(eleTy);
})
- .Case<fir::BoxCharType>([&](fir::BoxCharType ty) -> mlir::Type {
+ .Case([&](fir::BoxCharType ty) -> mlir::Type {
return fir::ReferenceType::get(ty.getEleTy());
})
- .Case<fir::BoxProcType>(
- [&](fir::BoxProcType ty) { return ty.getEleTy(); })
+ .Case([&](fir::BoxProcType ty) { return ty.getEleTy(); })
.Default([&](const auto &) { return mlir::Type{}; });
assert(type && "bad val type");
build(builder, result, type, val);
@@ -1086,6 +1146,22 @@ mlir::OpFoldResult fir::BoxAddrOp::fold(FoldAdaptor adaptor) {
return {};
}
+std::optional<std::int64_t> fir::BoxAddrOp::getViewOffset(mlir::OpResult) {
+ // fir.box_addr just returns the base address stored inside a box,
+ // so the direct accesses through the base address and through the box
+ // are not offsetted.
+ return 0;
+}
+
+mlir::Speculation::Speculatability fir::BoxAddrOp::getSpeculatability() {
+ // Do not speculate fir.box_addr with BoxProcType and BoxCharType
+ // inputs.
+ if (!mlir::isa<fir::BaseBoxType>(getVal().getType()))
+ return mlir::Speculation::NotSpeculatable;
+ return mayBeAbsentBox(getVal()) ? mlir::Speculation::NotSpeculatable
+ : mlir::Speculation::Speculatable;
+}
+
//===----------------------------------------------------------------------===//
// BoxCharLenOp
//===----------------------------------------------------------------------===//
@@ -1110,6 +1186,11 @@ mlir::Type fir::BoxDimsOp::getTupleType() {
return mlir::TupleType::get(getContext(), triple);
}
+mlir::Speculation::Speculatability fir::BoxDimsOp::getSpeculatability() {
+ return mayBeAbsentBox(getVal()) ? mlir::Speculation::NotSpeculatable
+ : mlir::Speculation::Speculatable;
+}
+
//===----------------------------------------------------------------------===//
// BoxRankOp
//===----------------------------------------------------------------------===//
@@ -1588,6 +1669,22 @@ llvm::LogicalResult fir::ConvertOp::verify() {
<< getValue().getType() << " / " << getType();
}
+mlir::Speculation::Speculatability fir::ConvertOp::getSpeculatability() {
+ // fir.convert is speculatable, in general. The only concern may be
+ // converting from or/and to floating point types, which may trigger
+ // some FP exceptions. Disallow speculating such converts for the time being.
+ // Also disallow speculation for converts to/from non-FIR types, except
+ // for some builtin types.
+ auto canSpeculateType = [](mlir::Type ty) {
+ if (fir::isa_fir_type(ty) || fir::isa_integer(ty))
+ return true;
+ return false;
+ };
+ return (canSpeculateType(getValue().getType()) && canSpeculateType(getType()))
+ ? mlir::Speculation::Speculatable
+ : mlir::Speculation::NotSpeculatable;
+}
+
//===----------------------------------------------------------------------===//
// CoordinateOp
//===----------------------------------------------------------------------===//
@@ -1627,11 +1724,11 @@ void fir::CoordinateOp::build(mlir::OpBuilder &builder,
bool anyField = false;
for (fir::IntOrValue index : coor) {
llvm::TypeSwitch<fir::IntOrValue>(index)
- .Case<mlir::IntegerAttr>([&](mlir::IntegerAttr intAttr) {
+ .Case([&](mlir::IntegerAttr intAttr) {
fieldIndices.push_back(intAttr.getInt());
anyField = true;
})
- .Case<mlir::Value>([&](mlir::Value value) {
+ .Case([&](mlir::Value value) {
dynamicIndices.push_back(value);
fieldIndices.push_back(fir::CoordinateOp::kDynamicIndex);
});
@@ -1654,7 +1751,7 @@ void fir::CoordinateOp::print(mlir::OpAsmPrinter &p) {
for (auto index : getIndices()) {
p << ", ";
llvm::TypeSwitch<fir::IntOrValue>(index)
- .Case<mlir::IntegerAttr>([&](mlir::IntegerAttr intAttr) {
+ .Case([&](mlir::IntegerAttr intAttr) {
if (auto recordType = llvm::dyn_cast<fir::RecordType>(eleTy)) {
int fieldId = intAttr.getInt();
if (fieldId < static_cast<int>(recordType.getNumFields())) {
@@ -1669,7 +1766,7 @@ void fir::CoordinateOp::print(mlir::OpAsmPrinter &p) {
// investigated.
p << intAttr;
})
- .Case<mlir::Value>([&](mlir::Value value) { p << value; });
+ .Case([&](mlir::Value value) { p << value; });
}
}
p.printOptionalAttrDict(
@@ -1820,6 +1917,20 @@ fir::CoordinateIndicesAdaptor fir::CoordinateOp::getIndices() {
return CoordinateIndicesAdaptor(getFieldIndicesAttr(), getCoor());
}
+std::optional<std::int64_t> fir::CoordinateOp::getViewOffset(mlir::OpResult) {
+ // TODO: we can try to compute the constant offset.
+ return std::nullopt;
+}
+
+mlir::Speculation::Speculatability fir::CoordinateOp::getSpeculatability() {
+ const mlir::Type refTy = getRef().getType();
+ if (fir::isa_ref_type(refTy))
+ return mlir::Speculation::Speculatable;
+
+ return mayBeAbsentBox(getRef()) ? mlir::Speculation::NotSpeculatable
+ : mlir::Speculation::Speculatable;
+}
+
//===----------------------------------------------------------------------===//
// DispatchOp
//===----------------------------------------------------------------------===//
@@ -2066,6 +2177,20 @@ bool fir::isContiguousEmbox(fir::EmboxOp embox, bool checkWhole) {
return false;
}
+std::optional<std::int64_t> fir::EmboxOp::getViewOffset(mlir::OpResult) {
+ // The address offset is zero, unless there is a slice.
+ // TODO: we can handle slices that leave the base address untouched.
+ if (!getSlice())
+ return 0;
+ return std::nullopt;
+}
+
+mlir::Speculation::Speculatability fir::EmboxOp::getSpeculatability() {
+ return (getSourceBox() && mayBeAbsentBox(getSourceBox()))
+ ? mlir::Speculation::NotSpeculatable
+ : mlir::Speculation::Speculatable;
+}
+
//===----------------------------------------------------------------------===//
// EmboxCharOp
//===----------------------------------------------------------------------===//
@@ -2836,6 +2961,39 @@ llvm::SmallVector<mlir::Attribute> fir::LenParamIndexOp::getAttributes() {
// LoadOp
//===----------------------------------------------------------------------===//
+bool fir::LoadOp::loadsFrom(const mlir::MemorySlot &slot) {
+ return getMemref() == slot.ptr;
+}
+
+bool fir::LoadOp::storesTo(const mlir::MemorySlot &slot) { return false; }
+
+mlir::Value fir::LoadOp::getStored(const mlir::MemorySlot &slot,
+ mlir::OpBuilder &builder,
+ mlir::Value reachingDef,
+ const mlir::DataLayout &dataLayout) {
+ return mlir::Value();
+}
+
+bool fir::LoadOp::canUsesBeRemoved(
+ const mlir::MemorySlot &slot,
+ const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
+ mlir::SmallVectorImpl<mlir::OpOperand *> &newBlockingUses,
+ const mlir::DataLayout &dataLayout) {
+ if (blockingUses.size() != 1)
+ return false;
+ mlir::Value blockingUse = (*blockingUses.begin())->get();
+ return blockingUse == slot.ptr && getMemref() == slot.ptr;
+}
+
+mlir::DeletionKind fir::LoadOp::removeBlockingUses(
+ const mlir::MemorySlot &slot,
+ const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
+ mlir::OpBuilder &builder, mlir::Value reachingDefinition,
+ const mlir::DataLayout &dataLayout) {
+ getResult().replaceAllUsesWith(reachingDefinition);
+ return mlir::DeletionKind::Delete;
+}
+
void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
mlir::Value refVal) {
if (!refVal) {
@@ -3205,11 +3363,19 @@ mlir::ParseResult fir::DTEntryOp::parse(mlir::OpAsmParser &parser,
parser.parseAttribute(calleeAttr, fir::DTEntryOp::getProcAttrNameStr(),
result.attributes))
return mlir::failure();
+
+ // Optional "deferred" keyword.
+ if (succeeded(parser.parseOptionalKeyword("deferred"))) {
+ result.addAttribute(fir::DTEntryOp::getDeferredAttrNameStr(),
+ parser.getBuilder().getUnitAttr());
+ }
return mlir::success();
}
void fir::DTEntryOp::print(mlir::OpAsmPrinter &p) {
p << ' ' << getMethodAttr() << ", " << getProcAttr();
+ if ((*this)->getAttr(fir::DTEntryOp::getDeferredAttrNameStr()))
+ p << " deferred";
}
//===----------------------------------------------------------------------===//
@@ -3313,6 +3479,19 @@ llvm::LogicalResult fir::ReboxOp::verify() {
return mlir::success();
}
+std::optional<std::int64_t> fir::ReboxOp::getViewOffset(mlir::OpResult) {
+ // The address offset is zero, unless there is a slice.
+ // TODO: we can handle slices that leave the base address untouched.
+ if (!getSlice())
+ return 0;
+ return std::nullopt;
+}
+
+mlir::Speculation::Speculatability fir::ReboxOp::getSpeculatability() {
+ return mayBeAbsentBox(getBox()) ? mlir::Speculation::NotSpeculatable
+ : mlir::Speculation::Speculatable;
+}
+
//===----------------------------------------------------------------------===//
// ReboxAssumedRankOp
//===----------------------------------------------------------------------===//
@@ -4215,6 +4394,39 @@ llvm::LogicalResult fir::SliceOp::verify() {
// StoreOp
//===----------------------------------------------------------------------===//
+bool fir::StoreOp::loadsFrom(const mlir::MemorySlot &slot) { return false; }
+
+bool fir::StoreOp::storesTo(const mlir::MemorySlot &slot) {
+ return getMemref() == slot.ptr;
+}
+
+mlir::Value fir::StoreOp::getStored(const mlir::MemorySlot &slot,
+ mlir::OpBuilder &builder,
+ mlir::Value reachingDef,
+ const mlir::DataLayout &dataLayout) {
+ return getValue();
+}
+
+bool fir::StoreOp::canUsesBeRemoved(
+ const mlir::MemorySlot &slot,
+ const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
+ mlir::SmallVectorImpl<mlir::OpOperand *> &newBlockingUses,
+ const mlir::DataLayout &dataLayout) {
+ if (blockingUses.size() != 1)
+ return false;
+ mlir::Value blockingUse = (*blockingUses.begin())->get();
+ return blockingUse == slot.ptr && getMemref() == slot.ptr &&
+ getValue() != slot.ptr;
+}
+
+mlir::DeletionKind fir::StoreOp::removeBlockingUses(
+ const mlir::MemorySlot &slot,
+ const SmallPtrSetImpl<mlir::OpOperand *> &blockingUses,
+ mlir::OpBuilder &builder, mlir::Value reachingDefinition,
+ const mlir::DataLayout &dataLayout) {
+ return mlir::DeletionKind::Delete;
+}
+
mlir::Type fir::StoreOp::elementType(mlir::Type refType) {
return fir::dyn_cast_ptrEleTy(refType);
}
@@ -4252,7 +4464,7 @@ llvm::LogicalResult fir::StoreOp::verify() {
void fir::StoreOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
mlir::Value value, mlir::Value memref) {
- build(builder, result, value, memref, {});
+ build(builder, result, value, memref, {}, {}, {});
}
void fir::StoreOp::getEffects(
@@ -4265,6 +4477,84 @@ void fir::StoreOp::getEffects(
}
//===----------------------------------------------------------------------===//
+// PrefetchOp
+//===----------------------------------------------------------------------===//
+
+mlir::ParseResult fir::PrefetchOp::parse(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ mlir::OpAsmParser::UnresolvedOperand memref;
+ if (parser.parseOperand(memref))
+ return mlir::failure();
+
+ if (mlir::succeeded(parser.parseLBrace())) {
+ llvm::StringRef kw;
+ if (parser.parseKeyword(&kw))
+ return mlir::failure();
+
+ if (kw == "read")
+ result.addAttribute("rw", parser.getBuilder().getBoolAttr(false));
+ else if (kw == "write")
+ result.addAttribute("rw", parser.getBuilder().getUnitAttr());
+ else
+ return parser.emitError(parser.getCurrentLocation(),
+ "Expected either read or write keyword");
+
+ if (parser.parseComma())
+ return mlir::failure();
+
+ if (parser.parseKeyword(&kw))
+ return mlir::failure();
+ if (kw == "instruction") {
+ result.addAttribute("cacheType", parser.getBuilder().getBoolAttr(false));
+ } else if (kw == "data") {
+ result.addAttribute("cacheType", parser.getBuilder().getUnitAttr());
+ } else
+ return parser.emitError(parser.getCurrentLocation(),
+ "Expected either intruction or data keyword");
+
+ if (parser.parseComma())
+ return mlir::failure();
+
+ if (mlir::succeeded(parser.parseKeyword("localityHint"))) {
+ if (parser.parseEqual())
+ return mlir::failure();
+ mlir::Attribute intAttr;
+ if (parser.parseAttribute(intAttr))
+ return mlir::failure();
+ result.addAttribute("localityHint", intAttr);
+ }
+ if (parser.parseRBrace())
+ return mlir::failure();
+ }
+ mlir::Type type;
+ if (parser.parseColonType(type))
+ return mlir::failure();
+
+ if (parser.resolveOperand(memref, type, result.operands))
+ return mlir::failure();
+ return mlir::success();
+}
+
+void fir::PrefetchOp::print(mlir::OpAsmPrinter &p) {
+ p << " ";
+ p.printOperand(getMemref());
+ p << " {";
+ if (getRw())
+ p << "write";
+ else
+ p << "read";
+ p << ", ";
+ if (getCacheType())
+ p << "data";
+ else
+ p << "instruction";
+ p << ", localityHint = ";
+ p << getLocalityHint();
+ p << " : " << getLocalityHintAttr().getType();
+ p << "} : " << getMemref().getType();
+}
+
+//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
@@ -4484,7 +4774,7 @@ void fir::IfOp::getSuccessorRegions(
llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
if (!point.isParent()) {
- regions.push_back(mlir::RegionSuccessor(getResults()));
+ regions.push_back(mlir::RegionSuccessor::parent());
return;
}
@@ -4494,11 +4784,18 @@ void fir::IfOp::getSuccessorRegions(
// Don't consider the else region if it is empty.
mlir::Region *elseRegion = &this->getElseRegion();
if (elseRegion->empty())
- regions.push_back(mlir::RegionSuccessor());
+ regions.push_back(mlir::RegionSuccessor::parent());
else
regions.push_back(mlir::RegionSuccessor(elseRegion));
}
+mlir::ValueRange
+fir::IfOp::getSuccessorInputs(mlir::RegionSuccessor successor) {
+ if (successor.isParent())
+ return getOperation()->getResults();
+ return mlir::ValueRange();
+}
+
void fir::IfOp::getEntrySuccessorRegions(
llvm::ArrayRef<mlir::Attribute> operands,
llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
@@ -4513,7 +4810,7 @@ void fir::IfOp::getEntrySuccessorRegions(
if (!getElseRegion().empty())
regions.emplace_back(&getElseRegion());
else
- regions.emplace_back(getResults());
+ regions.push_back(mlir::RegionSuccessor::parent());
}
}
@@ -4887,7 +5184,7 @@ bool fir::isDummyArgument(mlir::Value v) {
mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) {
for (auto i = path.begin(), end = path.end(); eleTy && i < end;) {
eleTy = llvm::TypeSwitch<mlir::Type, mlir::Type>(eleTy)
- .Case<fir::RecordType>([&](fir::RecordType ty) {
+ .Case([&](fir::RecordType ty) {
if (auto *op = (*i++).getDefiningOp()) {
if (auto off = mlir::dyn_cast<fir::FieldIndexOp>(op))
return ty.getType(off.getFieldName());
@@ -4896,7 +5193,7 @@ mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) {
}
return mlir::Type{};
})
- .Case<fir::SequenceType>([&](fir::SequenceType ty) {
+ .Case([&](fir::SequenceType ty) {
bool valid = true;
const auto rank = ty.getDimension();
for (std::remove_const_t<decltype(rank)> ii = 0;
@@ -4904,13 +5201,13 @@ mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) {
valid = i < end && fir::isa_integer((*i++).getType());
return valid ? ty.getEleTy() : mlir::Type{};
})
- .Case<mlir::TupleType>([&](mlir::TupleType ty) {
+ .Case([&](mlir::TupleType ty) {
if (auto *op = (*i++).getDefiningOp())
if (auto off = mlir::dyn_cast<mlir::arith::ConstantOp>(op))
return ty.getType(fir::toInt(off));
return mlir::Type{};
})
- .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
+ .Case([&](mlir::ComplexType ty) {
if (fir::isa_integer((*i++).getType()))
return ty.getElementType();
return mlir::Type{};
@@ -5143,6 +5440,34 @@ void fir::BoxTotalElementsOp::getCanonicalizationPatterns(
}
//===----------------------------------------------------------------------===//
+// IsAssumedSizeExtentOp and AssumedSizeExtentOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct FoldIsAssumedSizeExtentOnCtor
+ : public mlir::OpRewritePattern<fir::IsAssumedSizeExtentOp> {
+ using mlir::OpRewritePattern<fir::IsAssumedSizeExtentOp>::OpRewritePattern;
+ mlir::LogicalResult
+ matchAndRewrite(fir::IsAssumedSizeExtentOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ if (llvm::isa_and_nonnull<fir::AssumedSizeExtentOp>(
+ op.getVal().getDefiningOp())) {
+ mlir::Type i1 = rewriter.getI1Type();
+ rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
+ op, i1, rewriter.getIntegerAttr(i1, 1));
+ return mlir::success();
+ }
+ return mlir::failure();
+ }
+};
+} // namespace
+
+void fir::IsAssumedSizeExtentOp::getCanonicalizationPatterns(
+ mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
+ patterns.add<FoldIsAssumedSizeExtentOnCtor>(context);
+}
+
+//===----------------------------------------------------------------------===//
// LocalitySpecifierOp
//===----------------------------------------------------------------------===//
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index fe35b08..ccdc8e4 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -183,18 +183,21 @@ struct RecordTypeStorage : public mlir::TypeStorage {
bool isPacked() const { return packed; }
void pack(bool p) { packed = p; }
+ bool isSequence() const { return sequence; }
+ void setSequence(bool s) { sequence = s; }
protected:
std::string name;
bool finalized;
bool packed;
+ bool sequence;
std::vector<RecordType::TypePair> lens;
std::vector<RecordType::TypePair> types;
private:
RecordTypeStorage() = delete;
explicit RecordTypeStorage(llvm::StringRef name)
- : name{name}, finalized{false}, packed{false} {}
+ : name{name}, finalized{false}, packed{false}, sequence{false} {}
};
} // namespace detail
@@ -226,8 +229,7 @@ mlir::Type getDerivedType(mlir::Type ty) {
return seq.getEleTy();
return p.getEleTy();
})
- .Case<fir::BaseBoxType>(
- [](auto p) { return getDerivedType(p.getEleTy()); })
+ .Case([](fir::BaseBoxType p) { return getDerivedType(p.getEleTy()); })
.Default([](mlir::Type t) { return t; });
}
@@ -423,7 +425,7 @@ mlir::Type unwrapInnerType(mlir::Type ty) {
return seqTy.getEleTy();
return eleTy;
})
- .Case<fir::RecordType>([](auto t) { return t; })
+ .Case([](fir::RecordType t) { return t; })
.Default([](mlir::Type) { return mlir::Type{}; });
}
@@ -685,7 +687,7 @@ std::string getTypeAsString(mlir::Type ty, const fir::KindMapping &kindMap,
mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType,
bool turnBoxIntoClass) {
return llvm::TypeSwitch<mlir::Type, mlir::Type>(type)
- .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
+ .Case([&](fir::SequenceType seqTy) -> mlir::Type {
return fir::SequenceType::get(seqTy.getShape(), newElementType);
})
.Case<fir::ReferenceType, fir::ClassType>([&](auto t) -> mlir::Type {
@@ -699,7 +701,7 @@ mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType,
return FIRT::get(
changeElementType(t.getEleTy(), newElementType, turnBoxIntoClass));
})
- .Case<fir::BoxType>([&](fir::BoxType t) -> mlir::Type {
+ .Case([&](fir::BoxType t) -> mlir::Type {
mlir::Type newInnerType =
changeElementType(t.getEleTy(), newElementType, false);
if (turnBoxIntoClass)
@@ -1014,6 +1016,14 @@ mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) {
if (parser.parseLess() || parser.parseKeyword(&name))
return {};
RecordType result = RecordType::get(parser.getContext(), name);
+ // Optional SEQUENCE attribute: ", sequence"
+ if (!parser.parseOptionalComma()) {
+ if (parser.parseKeyword("sequence")) {
+ parser.emitError(parser.getNameLoc(), "expected 'sequence' keyword");
+ return {};
+ }
+ result.setSequence(true);
+ }
RecordType::TypeVector lenParamList;
if (!parser.parseOptionalLParen()) {
@@ -1069,6 +1079,8 @@ mlir::Type fir::RecordType::parse(mlir::AsmParser &parser) {
void fir::RecordType::print(mlir::AsmPrinter &printer) const {
printer << "<" << getName();
+ if (isSequence())
+ printer << ",sequence";
if (!recordTypeVisited.count(uniqueKey())) {
recordTypeVisited.insert(uniqueKey());
if (getLenParamList().size()) {
@@ -1123,6 +1135,10 @@ void fir::RecordType::pack(bool p) { getImpl()->pack(p); }
bool fir::RecordType::isPacked() const { return getImpl()->isPacked(); }
+bool fir::RecordType::isSequence() const { return getImpl()->isSequence(); }
+
+void fir::RecordType::setSequence(bool s) { getImpl()->setSequence(s); }
+
detail::RecordTypeStorage const *fir::RecordType::uniqueKey() const {
return getImpl();
}
@@ -1438,7 +1454,7 @@ static mlir::Type
changeTypeShape(mlir::Type type,
std::optional<fir::SequenceType::ShapeRef> newShape) {
return llvm::TypeSwitch<mlir::Type, mlir::Type>(type)
- .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
+ .Case([&](fir::SequenceType seqTy) -> mlir::Type {
if (newShape)
return fir::SequenceType::get(*newShape, seqTy.getEleTy());
return seqTy.getEleTy();
@@ -1498,10 +1514,10 @@ fir::BaseBoxType fir::BaseBoxType::getBoxTypeWithNewAttr(
break;
}
return llvm::TypeSwitch<fir::BaseBoxType, fir::BaseBoxType>(*this)
- .Case<fir::BoxType>([baseType](auto b) {
+ .Case([baseType](fir::BoxType b) {
return fir::BoxType::get(baseType, b.isVolatile());
})
- .Case<fir::ClassType>([baseType](auto b) {
+ .Case([baseType](fir::ClassType b) {
return fir::ClassType::get(baseType, b.isVolatile());
});
}
diff --git a/flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt b/flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt
index d52ab09..d53937eb 100644
--- a/flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt
+++ b/flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt
@@ -3,18 +3,21 @@ add_flang_library(MIFDialect
MIFOps.cpp
DEPENDS
- MLIRIR
MIFOpsIncGen
LINK_LIBS
FIRDialect
FIRDialectSupport
- FIRSupport
- MLIRIR
- MLIRTargetLLVMIRExport
LINK_COMPONENTS
AsmParser
AsmPrinter
Remarks
+
+ MLIR_DEPS
+ MLIRIR
+
+ MLIR_LIBS
+ MLIRIR
+ MLIRTargetLLVMIRExport
)
diff --git a/flang/lib/Optimizer/Dialect/MIF/MIFOps.cpp b/flang/lib/Optimizer/Dialect/MIF/MIFOps.cpp
index c6cc2e8..8b04226 100644
--- a/flang/lib/Optimizer/Dialect/MIF/MIFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/MIF/MIFOps.cpp
@@ -15,9 +15,6 @@
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"
-#define GET_OP_CLASSES
-#include "flang/Optimizer/Dialect/MIF/MIFOps.cpp.inc"
-
//===----------------------------------------------------------------------===//
// NumImagesOp
//===----------------------------------------------------------------------===//
@@ -151,3 +148,59 @@ llvm::LogicalResult mif::CoSumOp::verify() {
return emitOpError("`A` shall be of numeric type.");
return mlir::success();
}
+
+//===----------------------------------------------------------------------===//
+// ChangeTeamOp
+//===----------------------------------------------------------------------===//
+
+void mif::ChangeTeamOp::build(mlir::OpBuilder &builder,
+ mlir::OperationState &result, mlir::Value team,
+ llvm::ArrayRef<mlir::NamedAttribute> attributes) {
+ build(builder, result, team, /*stat*/ mlir::Value{}, /*errmsg*/ mlir::Value{},
+ attributes);
+}
+
+void mif::ChangeTeamOp::build(mlir::OpBuilder &builder,
+ mlir::OperationState &result, mlir::Value team,
+ mlir::Value stat, mlir::Value errmsg,
+ llvm::ArrayRef<mlir::NamedAttribute> attributes) {
+ std::int32_t argStat = 0, argErrmsg = 0;
+ result.addOperands(team);
+ if (stat) {
+ result.addOperands(stat);
+ argStat++;
+ }
+ if (errmsg) {
+ result.addOperands(errmsg);
+ argErrmsg++;
+ }
+
+ mlir::Region *bodyRegion = result.addRegion();
+ bodyRegion->push_back(new mlir::Block{});
+
+ result.addAttribute(getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({1, argStat, argErrmsg}));
+ result.addAttributes(attributes);
+}
+
+static mlir::ParseResult parseChangeTeamOpBody(mlir::OpAsmParser &parser,
+ mlir::Region &body) {
+ if (parser.parseRegion(body))
+ return mlir::failure();
+
+ mlir::Operation *terminator = body.back().getTerminator();
+ if (!terminator || !mlir::isa<mif::EndTeamOp>(terminator))
+ return parser.emitError(parser.getNameLoc(),
+ "missing mif.end_team terminator");
+
+ return mlir::success();
+}
+
+static void printChangeTeamOpBody(mlir::OpAsmPrinter &p, mif::ChangeTeamOp op,
+ mlir::Region &body) {
+ p.printRegion(op.getRegion(), /*printEntryBlockArgs=*/true,
+ /*printBlockTerminators=*/true);
+}
+
+#define GET_OP_CLASSES
+#include "flang/Optimizer/Dialect/MIF/MIFOps.cpp.inc"
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index 1b1abef..e0fee2f 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -87,8 +87,8 @@ bool hlfir::isFortranVariableType(mlir::Type type) {
return mlir::isa<fir::BaseBoxType>(eleType) ||
!fir::hasDynamicSize(eleType);
})
- .Case<fir::BaseBoxType, fir::BoxCharType>([](auto) { return true; })
- .Case<fir::VectorType>([](auto) { return true; })
+ .Case<fir::BaseBoxType, fir::BoxCharType>([](mlir::Type) { return true; })
+ .Case([](fir::VectorType) { return true; })
.Default([](mlir::Type) { return false; });
}
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 1332dc5..e42c064 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -261,14 +261,12 @@ updateDeclaredInputTypeWithVolatility(mlir::Type inputType, mlir::Value memref,
return std::make_pair(inputType, memref);
}
-void hlfir::DeclareOp::build(mlir::OpBuilder &builder,
- mlir::OperationState &result, mlir::Value memref,
- llvm::StringRef uniq_name, mlir::Value shape,
- mlir::ValueRange typeparams,
- mlir::Value dummy_scope, mlir::Value storage,
- std::uint64_t storage_offset,
- fir::FortranVariableFlagsAttr fortran_attrs,
- cuf::DataAttributeAttr data_attr) {
+void hlfir::DeclareOp::build(
+ mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value memref,
+ llvm::StringRef uniq_name, mlir::Value shape, mlir::ValueRange typeparams,
+ mlir::Value dummy_scope, mlir::Value storage, std::uint64_t storage_offset,
+ fir::FortranVariableFlagsAttr fortran_attrs,
+ cuf::DataAttributeAttr data_attr, unsigned dummy_arg_no) {
auto nameAttr = builder.getStringAttr(uniq_name);
mlir::Type inputType = memref.getType();
bool hasExplicitLbs = hasExplicitLowerBounds(shape);
@@ -279,9 +277,12 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder,
}
auto [hlfirVariableType, firVarType] =
getDeclareOutputTypes(inputType, hasExplicitLbs);
+ mlir::IntegerAttr argNoAttr;
+ if (dummy_arg_no > 0)
+ argNoAttr = builder.getUI32IntegerAttr(dummy_arg_no);
build(builder, result, {hlfirVariableType, firVarType}, memref, shape,
typeparams, dummy_scope, storage, storage_offset, nameAttr,
- fortran_attrs, data_attr, /*skip_rebox=*/mlir::UnitAttr{});
+ fortran_attrs, data_attr, /*skip_rebox=*/mlir::UnitAttr{}, argNoAttr);
}
llvm::LogicalResult hlfir::DeclareOp::verify() {
@@ -591,6 +592,12 @@ llvm::LogicalResult hlfir::DesignateOp::verify() {
return mlir::success();
}
+/// hlfir.designate implementation of the view-offset query: return the
+/// constant offset of the designated view into its source when it can be
+/// computed. Currently always conservative (std::nullopt) — see TODO.
+std::optional<std::int64_t> hlfir::DesignateOp::getViewOffset(mlir::OpResult) {
+ // TODO: we can compute the constant offset
+ // based on the component/indices/etc.
+ return std::nullopt;
+}
+
//===----------------------------------------------------------------------===//
// ParentComponentOp
//===----------------------------------------------------------------------===//
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
index 6a57bf2..13d9fc2 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
@@ -149,13 +149,18 @@ public:
!assignOp.isTemporaryLHS() &&
mlir::isa<fir::RecordType>(fir::getElementTypeOf(lhsExv));
+ mlir::ArrayAttr accessGroups;
+ if (auto attrs = assignOp.getOperation()->getAttrOfType<mlir::ArrayAttr>(
+ fir::getAccessGroupsAttrName()))
+ accessGroups = attrs;
+
// genScalarAssignment() must take care of potential overlap
// between LHS and RHS. Note that the overlap is possible
// also for components of LHS/RHS, and the Assign() runtime
// must take care of it.
- fir::factory::genScalarAssignment(builder, loc, lhsExv, rhsExv,
- needFinalization,
- assignOp.isTemporaryLHS());
+ fir::factory::genScalarAssignment(
+ builder, loc, lhsExv, rhsExv, needFinalization,
+ assignOp.isTemporaryLHS(), accessGroups);
}
rewriter.eraseOp(assignOp);
return mlir::success();
@@ -308,7 +313,8 @@ public:
declareOp.getTypeparams(), declareOp.getDummyScope(),
/*storage=*/declareOp.getStorage(),
/*storage_offset=*/declareOp.getStorageOffset(),
- declareOp.getUniqName(), fortranAttrs, dataAttr);
+ declareOp.getUniqName(), fortranAttrs, dataAttr,
+ declareOp.getDummyArgNoAttr());
// Propagate other attributes from hlfir.declare to fir.declare.
// OpenACC's acc.declare is one example. Right now, the propagation
@@ -467,7 +473,7 @@ public:
if (designate.getComponent()) {
mlir::Type baseRecordType = baseEntity.getFortranElementType();
if (fir::isRecordWithTypeParameters(baseRecordType))
- TODO(loc, "hlfir.designate with a parametrized derived type base");
+ TODO(loc, "hlfir.designate with a parameterized derived type base");
fieldIndex = fir::FieldIndexOp::create(
builder, loc, fir::FieldType::get(builder.getContext()),
designate.getComponent().value(), baseRecordType,
@@ -493,7 +499,7 @@ public:
return mlir::success();
}
TODO(loc,
- "addressing parametrized derived type automatic components");
+ "addressing parameterized derived type automatic components");
}
baseEleTy = hlfir::getFortranElementType(componentType);
shape = designate.getComponentShape();
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp
index 86d3974..356552f 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRAssign.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Analysis/AliasAnalysis.h"
+#include "flang/Optimizer/Analysis/ArraySectionAnalyzer.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
@@ -93,40 +94,32 @@ public:
// and proceed with the inlining.
fir::AliasAnalysis aliasAnalysis;
mlir::AliasResult aliasRes = aliasAnalysis.alias(lhs, rhs);
- // TODO: use areIdenticalOrDisjointSlices() from
- // OptimizedBufferization.cpp to check if we can still do the expansion.
if (!aliasRes.isNo()) {
- LLVM_DEBUG(llvm::dbgs() << "InlineHLFIRAssign:\n"
- << "\tLHS: " << lhs << "\n"
- << "\tRHS: " << rhs << "\n"
- << "\tALIAS: " << aliasRes << "\n");
- return rewriter.notifyMatchFailure(assign, "RHS/LHS may alias");
+ // Alias analysis reports potential aliasing, but we can use
+ // ArraySectionAnalyzer to check if the slices are disjoint
+ // or identical (which is safe for element-wise assignment).
+ fir::ArraySectionAnalyzer::SlicesOverlapKind overlap =
+ fir::ArraySectionAnalyzer::analyze(lhs, rhs);
+ if (overlap == fir::ArraySectionAnalyzer::SlicesOverlapKind::Unknown) {
+ LLVM_DEBUG(llvm::dbgs() << "InlineHLFIRAssign:\n"
+ << "\tLHS: " << lhs << "\n"
+ << "\tRHS: " << rhs << "\n"
+ << "\tALIAS: " << aliasRes << "\n");
+ return rewriter.notifyMatchFailure(assign, "RHS/LHS may alias");
+ }
}
}
mlir::Location loc = assign->getLoc();
fir::FirOpBuilder builder(rewriter, assign.getOperation());
builder.setInsertionPoint(assign);
- rhs = hlfir::derefPointersAndAllocatables(loc, builder, rhs);
- lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs);
- mlir::Value lhsShape = hlfir::genShape(loc, builder, lhs);
- llvm::SmallVector<mlir::Value> lhsExtents =
- hlfir::getIndexExtents(loc, builder, lhsShape);
- mlir::Value rhsShape = hlfir::genShape(loc, builder, rhs);
- llvm::SmallVector<mlir::Value> rhsExtents =
- hlfir::getIndexExtents(loc, builder, rhsShape);
- llvm::SmallVector<mlir::Value> extents =
- fir::factory::deduceOptimalExtents(lhsExtents, rhsExtents);
- hlfir::LoopNest loopNest =
- hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
- flangomp::shouldUseWorkshareLowering(assign));
- builder.setInsertionPointToStart(loopNest.body);
- auto rhsArrayElement =
- hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);
- rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement);
- auto lhsArrayElement =
- hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
- hlfir::AssignOp::create(builder, loc, rhsArrayElement, lhsArrayElement);
+ mlir::ArrayAttr accessGroups;
+ if (auto attrs = assign.getOperation()->getAttrOfType<mlir::ArrayAttr>(
+ fir::getAccessGroupsAttrName()))
+ accessGroups = attrs;
+ hlfir::genNoAliasArrayAssignment(
+ loc, builder, rhs, lhs, flangomp::shouldUseWorkshareLowering(assign),
+ /*temporaryLHS=*/false, /*combiner=*/nullptr, accessGroups);
rewriter.eraseOp(assign);
return mlir::success();
}
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
index 32998ab..a3fd19d 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
@@ -96,7 +96,7 @@ struct MaskedArrayExpr {
/// hlfir.elemental_addr that form the elemental tree producing
/// the expression value. hlfir.elemental that produce values
/// used inside transformational operations are not part of this set.
- llvm::SmallPtrSet<mlir::Operation *, 4> elementalParts{};
+ hlfir::ElementalTree elementalParts;
/// Was generateNoneElementalPart called?
bool noneElementalPartWasGenerated = false;
/// Is this expression the mask expression of the outer where statement?
@@ -517,7 +517,10 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
} else {
// TODO: preserve allocatable assignment aspects for forall once
// they are conveyed in hlfir.region_assign.
- hlfir::AssignOp::create(builder, loc, rhsEntity, lhsEntity);
+ auto assignOp = hlfir::AssignOp::create(builder, loc, rhsEntity, lhsEntity);
+ if (auto accessGroups = regionAssignOp->getAttrOfType<mlir::ArrayAttr>(
+ fir::getAccessGroupsAttrName()))
+ assignOp->setAttr(fir::getAccessGroupsAttrName(), accessGroups);
}
generateCleanupIfAny(loweredLhs.elementalCleanup);
if (loweredLhs.vectorSubscriptLoopNest)
@@ -897,62 +900,11 @@ bool OrderedAssignmentRewriter::isRequiredInCurrentRun(
return false;
}
-/// Is the apply using all the elemental indices in order?
-static bool isInOrderApply(hlfir::ApplyOp apply,
- hlfir::ElementalOpInterface elemental) {
- mlir::Region::BlockArgListType elementalIndices = elemental.getIndices();
- if (elementalIndices.size() != apply.getIndices().size())
- return false;
- for (auto [elementalIdx, applyIdx] :
- llvm::zip(elementalIndices, apply.getIndices()))
- if (elementalIdx != applyIdx)
- return false;
- return true;
-}
-
-/// Gather the tree of hlfir::ElementalOpInterface use-def, if any, starting
-/// from \p elemental, which may be a nullptr.
-static void
-gatherElementalTree(hlfir::ElementalOpInterface elemental,
- llvm::SmallPtrSetImpl<mlir::Operation *> &elementalOps,
- bool isOutOfOrder) {
- if (elemental) {
- // Only inline an applied elemental that must be executed in order if the
- // applying indices are in order. An hlfir::Elemental may have been created
- // for a transformational like transpose, and Fortran 2018 standard
- // section 10.2.3.2, point 10 imply that impure elemental sub-expression
- // evaluations should not be masked if they are the arguments of
- // transformational expressions.
- if (isOutOfOrder && elemental.isOrdered())
- return;
- elementalOps.insert(elemental.getOperation());
- for (mlir::Operation &op : elemental.getElementalRegion().getOps())
- if (auto apply = mlir::dyn_cast<hlfir::ApplyOp>(op)) {
- bool isUnorderedApply =
- isOutOfOrder || !isInOrderApply(apply, elemental);
- auto maybeElemental =
- mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>(
- apply.getExpr().getDefiningOp());
- gatherElementalTree(maybeElemental, elementalOps, isUnorderedApply);
- }
- }
-}
-
MaskedArrayExpr::MaskedArrayExpr(mlir::Location loc, mlir::Region &region,
bool isOuterMaskExpr)
: loc{loc}, region{region}, isOuterMaskExpr{isOuterMaskExpr} {
mlir::Operation &terminator = region.back().back();
- if (auto elementalAddr =
- mlir::dyn_cast<hlfir::ElementalOpInterface>(terminator)) {
- // Vector subscripted designator (hlfir.elemental_addr terminator).
- gatherElementalTree(elementalAddr, elementalParts, /*isOutOfOrder=*/false);
- return;
- }
- // Try if elemental expression.
- mlir::Value entity = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
- auto maybeElemental = mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>(
- entity.getDefiningOp());
- gatherElementalTree(maybeElemental, elementalParts, /*isOutOfOrder=*/false);
+ elementalParts = hlfir::ElementalTree::buildElementalTree(terminator);
}
void MaskedArrayExpr::generateNoneElementalPart(fir::FirOpBuilder &builder,
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 2712bfb..5889122 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Analysis/AliasAnalysis.h"
+#include "flang/Optimizer/Analysis/ArraySectionAnalyzer.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Dialect/FIROps.h"
@@ -88,13 +89,6 @@ private:
/// determines if the transformation can be applied to this elemental
static std::optional<MatchInfo> findMatch(hlfir::ElementalOp elemental);
- /// Returns the array indices for the given hlfir.designate.
- /// It recognizes the computations used to transform the one-based indices
- /// into the array's lb-based indices, and returns the one-based indices
- /// in these cases.
- static llvm::SmallVector<mlir::Value>
- getDesignatorIndices(hlfir::DesignateOp designate);
-
public:
using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;
@@ -167,344 +161,6 @@ containsReadOrWriteEffectOn(const mlir::MemoryEffects::EffectInstance &effect,
return mlir::AliasResult::NoAlias;
}
-// Helper class for analyzing two array slices represented
-// by two hlfir.designate operations.
-class ArraySectionAnalyzer {
-public:
- // The result of the analyzis is one of the values below.
- enum class SlicesOverlapKind {
- // Slices overlap is unknown.
- Unknown,
- // Slices are definitely identical.
- DefinitelyIdentical,
- // Slices are definitely disjoint.
- DefinitelyDisjoint,
- // Slices may be either disjoint or identical,
- // i.e. there is definitely no partial overlap.
- EitherIdenticalOrDisjoint
- };
-
- // Analyzes two hlfir.designate results and returns the overlap kind.
- // The callers may use this method when the alias analysis reports
- // an alias of some kind, so that we can run Fortran specific analysis
- // on the array slices to see if they are identical or disjoint.
- // Note that the alias analysis are not able to give such an answer
- // about the references.
- static SlicesOverlapKind analyze(mlir::Value ref1, mlir::Value ref2);
-
-private:
- struct SectionDesc {
- // An array section is described by <lb, ub, stride> tuple.
- // If the designator's subscript is not a triple, then
- // the section descriptor is constructed as <lb, nullptr, nullptr>.
- mlir::Value lb, ub, stride;
-
- SectionDesc(mlir::Value lb, mlir::Value ub, mlir::Value stride)
- : lb(lb), ub(ub), stride(stride) {
- assert(lb && "lower bound or index must be specified");
- normalize();
- }
-
- // Normalize the section descriptor:
- // 1. If UB is nullptr, then it is set to LB.
- // 2. If LB==UB, then stride does not matter,
- // so it is reset to nullptr.
- // 3. If STRIDE==1, then it is reset to nullptr.
- void normalize() {
- if (!ub)
- ub = lb;
- if (lb == ub)
- stride = nullptr;
- if (stride)
- if (auto val = fir::getIntIfConstant(stride))
- if (*val == 1)
- stride = nullptr;
- }
-
- bool operator==(const SectionDesc &other) const {
- return lb == other.lb && ub == other.ub && stride == other.stride;
- }
- };
-
- // Given an operand_iterator over the indices operands,
- // read the subscript values and return them as SectionDesc
- // updating the iterator. If isTriplet is true,
- // the subscript is a triplet, and the result is <lb, ub, stride>.
- // Otherwise, the subscript is a scalar index, and the result
- // is <index, nullptr, nullptr>.
- static SectionDesc readSectionDesc(mlir::Operation::operand_iterator &it,
- bool isTriplet) {
- if (isTriplet)
- return {*it++, *it++, *it++};
- return {*it++, nullptr, nullptr};
- }
-
- // Return the ordered lower and upper bounds of the section.
- // If stride is known to be non-negative, then the ordered
- // bounds match the <lb, ub> of the descriptor.
- // If stride is known to be negative, then the ordered
- // bounds are <ub, lb> of the descriptor.
- // If stride is unknown, we cannot deduce any order,
- // so the result is <nullptr, nullptr>
- static std::pair<mlir::Value, mlir::Value>
- getOrderedBounds(const SectionDesc &desc) {
- mlir::Value stride = desc.stride;
- // Null stride means stride=1.
- if (!stride)
- return {desc.lb, desc.ub};
- // Reverse the bounds, if stride is negative.
- if (auto val = fir::getIntIfConstant(stride)) {
- if (*val >= 0)
- return {desc.lb, desc.ub};
- else
- return {desc.ub, desc.lb};
- }
-
- return {nullptr, nullptr};
- }
-
- // Given two array sections <lb1, ub1, stride1> and
- // <lb2, ub2, stride2>, return true only if the sections
- // are known to be disjoint.
- //
- // For example, for any positive constant C:
- // X:Y does not overlap with (Y+C):Z
- // X:Y does not overlap with Z:(X-C)
- static bool areDisjointSections(const SectionDesc &desc1,
- const SectionDesc &desc2) {
- auto [lb1, ub1] = getOrderedBounds(desc1);
- auto [lb2, ub2] = getOrderedBounds(desc2);
- if (!lb1 || !lb2)
- return false;
- // Note that this comparison must be made on the ordered bounds,
- // otherwise 'a(x:y:1) = a(z:x-1:-1) + 1' may be incorrectly treated
- // as not overlapping (x=2, y=10, z=9).
- if (isLess(ub1, lb2) || isLess(ub2, lb1))
- return true;
- return false;
- }
-
- // Given two array sections <lb1, ub1, stride1> and
- // <lb2, ub2, stride2>, return true only if the sections
- // are known to be identical.
- //
- // For example:
- // <x, x, stride>
- // <x, nullptr, nullptr>
- //
- // These sections are identical, from the point of which array
- // elements are being addresses, even though the shape
- // of the array slices might be different.
- static bool areIdenticalSections(const SectionDesc &desc1,
- const SectionDesc &desc2) {
- if (desc1 == desc2)
- return true;
- return false;
- }
-
- // Return true, if v1 is known to be less than v2.
- static bool isLess(mlir::Value v1, mlir::Value v2);
-};
-
-ArraySectionAnalyzer::SlicesOverlapKind
-ArraySectionAnalyzer::analyze(mlir::Value ref1, mlir::Value ref2) {
- if (ref1 == ref2)
- return SlicesOverlapKind::DefinitelyIdentical;
-
- auto des1 = ref1.getDefiningOp<hlfir::DesignateOp>();
- auto des2 = ref2.getDefiningOp<hlfir::DesignateOp>();
- // We only support a pair of designators right now.
- if (!des1 || !des2)
- return SlicesOverlapKind::Unknown;
-
- if (des1.getMemref() != des2.getMemref()) {
- // If the bases are different, then there is unknown overlap.
- LLVM_DEBUG(llvm::dbgs() << "No identical base for:\n"
- << des1 << "and:\n"
- << des2 << "\n");
- return SlicesOverlapKind::Unknown;
- }
-
- // Require all components of the designators to be the same.
- // It might be too strict, e.g. we may probably allow for
- // different type parameters.
- if (des1.getComponent() != des2.getComponent() ||
- des1.getComponentShape() != des2.getComponentShape() ||
- des1.getSubstring() != des2.getSubstring() ||
- des1.getComplexPart() != des2.getComplexPart() ||
- des1.getTypeparams() != des2.getTypeparams()) {
- LLVM_DEBUG(llvm::dbgs() << "Different designator specs for:\n"
- << des1 << "and:\n"
- << des2 << "\n");
- return SlicesOverlapKind::Unknown;
- }
-
- // Analyze the subscripts.
- auto des1It = des1.getIndices().begin();
- auto des2It = des2.getIndices().begin();
- bool identicalTriplets = true;
- bool identicalIndices = true;
- for (auto [isTriplet1, isTriplet2] :
- llvm::zip(des1.getIsTriplet(), des2.getIsTriplet())) {
- SectionDesc desc1 = readSectionDesc(des1It, isTriplet1);
- SectionDesc desc2 = readSectionDesc(des2It, isTriplet2);
-
- // See if we can prove that any of the sections do not overlap.
- // This is mostly a Polyhedron/nf performance hack that looks for
- // particular relations between the lower and upper bounds
- // of the array sections, e.g. for any positive constant C:
- // X:Y does not overlap with (Y+C):Z
- // X:Y does not overlap with Z:(X-C)
- if (areDisjointSections(desc1, desc2))
- return SlicesOverlapKind::DefinitelyDisjoint;
-
- if (!areIdenticalSections(desc1, desc2)) {
- if (isTriplet1 || isTriplet2) {
- // For example:
- // hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %0)
- // hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %1)
- //
- // If all the triplets (section speficiers) are the same, then
- // we do not care if %0 is equal to %1 - the slices are either
- // identical or completely disjoint.
- //
- // Also, treat these as identical sections:
- // hlfir.designate %6#0 (%c2:%c2:%c1)
- // hlfir.designate %6#0 (%c2)
- identicalTriplets = false;
- LLVM_DEBUG(llvm::dbgs() << "Triplet mismatch for:\n"
- << des1 << "and:\n"
- << des2 << "\n");
- } else {
- identicalIndices = false;
- LLVM_DEBUG(llvm::dbgs() << "Indices mismatch for:\n"
- << des1 << "and:\n"
- << des2 << "\n");
- }
- }
- }
-
- if (identicalTriplets) {
- if (identicalIndices)
- return SlicesOverlapKind::DefinitelyIdentical;
- else
- return SlicesOverlapKind::EitherIdenticalOrDisjoint;
- }
-
- LLVM_DEBUG(llvm::dbgs() << "Different sections for:\n"
- << des1 << "and:\n"
- << des2 << "\n");
- return SlicesOverlapKind::Unknown;
-}
-
-bool ArraySectionAnalyzer::isLess(mlir::Value v1, mlir::Value v2) {
- auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
- auto *op = v.getDefiningOp();
- while (auto conv = mlir::dyn_cast_or_null<fir::ConvertOp>(op))
- op = conv.getValue().getDefiningOp();
- return op;
- };
-
- auto isPositiveConstant = [](mlir::Value v) -> bool {
- if (auto val = fir::getIntIfConstant(v))
- return *val > 0;
- return false;
- };
-
- auto *op1 = removeConvert(v1);
- auto *op2 = removeConvert(v2);
- if (!op1 || !op2)
- return false;
-
- // Check if they are both constants.
- if (auto val1 = fir::getIntIfConstant(op1->getResult(0)))
- if (auto val2 = fir::getIntIfConstant(op2->getResult(0)))
- return *val1 < *val2;
-
- // Handle some variable cases (C > 0):
- // v2 = v1 + C
- // v2 = C + v1
- // v1 = v2 - C
- if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
- if ((addi.getLhs().getDefiningOp() == op1 &&
- isPositiveConstant(addi.getRhs())) ||
- (addi.getRhs().getDefiningOp() == op1 &&
- isPositiveConstant(addi.getLhs())))
- return true;
- if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
- if (subi.getLhs().getDefiningOp() == op2 &&
- isPositiveConstant(subi.getRhs()))
- return true;
- return false;
-}
-
-llvm::SmallVector<mlir::Value>
-ElementalAssignBufferization::getDesignatorIndices(
- hlfir::DesignateOp designate) {
- mlir::Value memref = designate.getMemref();
-
- // If the object is a box, then the indices may be adjusted
- // according to the box's lower bound(s). Scan through
- // the computations to try to find the one-based indices.
- if (mlir::isa<fir::BaseBoxType>(memref.getType())) {
- // Look for the following pattern:
- // %13 = fir.load %12 : !fir.ref<!fir.box<...>
- // %14:3 = fir.box_dims %13, %c0 : (!fir.box<...>, index) -> ...
- // %17 = arith.subi %14#0, %c1 : index
- // %18 = arith.addi %arg2, %17 : index
- // %19 = hlfir.designate %13 (%18) : (!fir.box<...>, index) -> ...
- //
- // %arg2 is a one-based index.
-
- auto isNormalizedLb = [memref](mlir::Value v, unsigned dim) {
- // Return true, if v and dim are such that:
- // %14:3 = fir.box_dims %13, %dim : (!fir.box<...>, index) -> ...
- // %17 = arith.subi %14#0, %c1 : index
- // %19 = hlfir.designate %13 (...) : (!fir.box<...>, index) -> ...
- if (auto subOp =
- mlir::dyn_cast_or_null<mlir::arith::SubIOp>(v.getDefiningOp())) {
- auto cst = fir::getIntIfConstant(subOp.getRhs());
- if (!cst || *cst != 1)
- return false;
- if (auto dimsOp = mlir::dyn_cast_or_null<fir::BoxDimsOp>(
- subOp.getLhs().getDefiningOp())) {
- if (memref != dimsOp.getVal() ||
- dimsOp.getResult(0) != subOp.getLhs())
- return false;
- auto dimsOpDim = fir::getIntIfConstant(dimsOp.getDim());
- return dimsOpDim && dimsOpDim == dim;
- }
- }
- return false;
- };
-
- llvm::SmallVector<mlir::Value> newIndices;
- for (auto index : llvm::enumerate(designate.getIndices())) {
- if (auto addOp = mlir::dyn_cast_or_null<mlir::arith::AddIOp>(
- index.value().getDefiningOp())) {
- for (unsigned opNum = 0; opNum < 2; ++opNum)
- if (isNormalizedLb(addOp->getOperand(opNum), index.index())) {
- newIndices.push_back(addOp->getOperand((opNum + 1) % 2));
- break;
- }
-
- // If new one-based index was not added, exit early.
- if (newIndices.size() <= index.index())
- break;
- }
- }
-
- // If any of the indices is not adjusted to the array's lb,
- // then return the original designator indices.
- if (newIndices.size() != designate.getIndices().size())
- return designate.getIndices();
-
- return newIndices;
- }
-
- return designate.getIndices();
-}
-
std::optional<ElementalAssignBufferization::MatchInfo>
ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
mlir::Operation::user_range users = elemental->getUsers();
@@ -627,22 +283,20 @@ ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
if (!res.isPartial()) {
if (auto designate =
effect.getValue().getDefiningOp<hlfir::DesignateOp>()) {
- ArraySectionAnalyzer::SlicesOverlapKind overlap =
- ArraySectionAnalyzer::analyze(match.array, designate.getMemref());
+ fir::ArraySectionAnalyzer::SlicesOverlapKind overlap =
+ fir::ArraySectionAnalyzer::analyze(match.array,
+ designate.getMemref());
if (overlap ==
- ArraySectionAnalyzer::SlicesOverlapKind::DefinitelyDisjoint)
+ fir::ArraySectionAnalyzer::SlicesOverlapKind::DefinitelyDisjoint)
continue;
- if (overlap == ArraySectionAnalyzer::SlicesOverlapKind::Unknown) {
+ if (overlap == fir::ArraySectionAnalyzer::SlicesOverlapKind::Unknown) {
LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate
<< " at " << elemental.getLoc() << "\n");
return std::nullopt;
}
- auto indices = getDesignatorIndices(designate);
- auto elementalIndices = elemental.getIndices();
- if (indices.size() == elementalIndices.size() &&
- std::equal(indices.begin(), indices.end(), elementalIndices.begin(),
- elementalIndices.end()))
+ if (fir::ArraySectionAnalyzer::isDesignatingArrayInOrder(designate,
+ elemental))
continue;
LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate
@@ -727,9 +381,13 @@ llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
// Assign the element value to the array element for this iteration.
auto arrayElement =
hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
- hlfir::AssignOp::create(
+ auto newAssign = hlfir::AssignOp::create(
builder, loc, elementValue, arrayElement, /*realloc=*/false,
/*keep_lhs_length_if_realloc=*/false, match->assign.getTemporaryLhs());
+ if (auto accessGroups =
+ match->assign.getOperation()->getAttrOfType<mlir::ArrayAttr>(
+ fir::getAccessGroupsAttrName()))
+ newAssign->setAttr(fir::getAccessGroupsAttrName(), accessGroups);
rewriter.eraseOp(match->assign);
rewriter.eraseOp(match->destroy);
@@ -788,6 +446,11 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
llvm::SmallVector<mlir::Value> extents =
hlfir::getIndexExtents(loc, builder, shape);
+ mlir::ArrayAttr accessGroups;
+ if (auto attrs = assign.getOperation()->getAttrOfType<mlir::ArrayAttr>(
+ fir::getAccessGroupsAttrName()))
+ accessGroups = attrs;
+
if (lhs.isSimplyContiguous() && extents.size() > 1) {
// Flatten the array to use a single assign loop, that can be better
// optimized.
@@ -824,7 +487,9 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
mlir::Value arrayElement =
hlfir::DesignateOp::create(builder, loc, fir::ReferenceType::get(eleTy),
flatArray, loopNest.oneBasedIndices);
- hlfir::AssignOp::create(builder, loc, rhs, arrayElement);
+ auto newAssign = hlfir::AssignOp::create(builder, loc, rhs, arrayElement);
+ if (accessGroups)
+ newAssign->setAttr(fir::getAccessGroupsAttrName(), accessGroups);
} else {
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
@@ -832,7 +497,9 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
builder.setInsertionPointToStart(loopNest.body);
auto arrayElement =
hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
- hlfir::AssignOp::create(builder, loc, rhs, arrayElement);
+ auto newAssign = hlfir::AssignOp::create(builder, loc, rhs, arrayElement);
+ if (accessGroups)
+ newAssign->setAttr(fir::getAccessGroupsAttrName(), accessGroups);
}
rewriter.eraseOp(assign);
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp
index 63a5803..6bc5317 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp
@@ -8,6 +8,7 @@
#include "ScheduleOrderedAssignments.h"
#include "flang/Optimizer/Analysis/AliasAnalysis.h"
+#include "flang/Optimizer/Analysis/ArraySectionAnalyzer.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
@@ -23,7 +24,13 @@
/// Log RAW or WAW conflict.
[[maybe_unused]] static void logConflict(llvm::raw_ostream &os,
mlir::Value writtenOrReadVarA,
- mlir::Value writtenVarB);
+ mlir::Value writtenVarB,
+ bool isAligned = false);
+/// Log when a region must be retroactively saved.
+[[maybe_unused]] static void
+logRetroactiveSave(llvm::raw_ostream &os, mlir::Region &yieldRegion,
+ hlfir::Run &modifyingRun,
+ hlfir::RegionAssignOp currentAssign);
/// Log when an expression evaluation must be saved.
[[maybe_unused]] static void logSaveEvaluation(llvm::raw_ostream &os,
unsigned runid,
@@ -39,15 +46,129 @@ logStartScheduling(llvm::raw_ostream &os,
hlfir::OrderedAssignmentTreeOpInterface root);
/// Log op if effect value is not known.
[[maybe_unused]] static void
-logIfUnkownEffectValue(llvm::raw_ostream &os,
- mlir::MemoryEffects::EffectInstance effect,
- mlir::Operation &op);
+logIfUnknownEffectValue(llvm::raw_ostream &os,
+ mlir::MemoryEffects::EffectInstance effect,
+ mlir::Operation &op);
//===----------------------------------------------------------------------===//
// Scheduling Implementation
//===----------------------------------------------------------------------===//
+/// Is the apply using all the elemental indices in order?
+static bool isInOrderApply(hlfir::ApplyOp apply,
+ hlfir::ElementalOpInterface elemental) {
+ mlir::Region::BlockArgListType elementalIndices = elemental.getIndices();
+ if (elementalIndices.size() != apply.getIndices().size())
+ return false;
+ for (auto [elementalIdx, applyIdx] :
+ llvm::zip(elementalIndices, apply.getIndices()))
+ if (elementalIdx != applyIdx)
+ return false;
+ return true;
+}
+
+hlfir::ElementalTree
+hlfir::ElementalTree::buildElementalTree(mlir::Operation &regionTerminator) {
+ ElementalTree tree;
+ if (auto elementalAddr =
+ mlir::dyn_cast<hlfir::ElementalOpInterface>(regionTerminator)) {
+ // Vector subscripted designator (hlfir.elemental_addr terminator).
+ tree.gatherElementalTree(elementalAddr, /*isAppliedInOrder=*/true);
+ return tree;
+ }
+ // Try if elemental expression.
+ if (auto yield = mlir::dyn_cast<hlfir::YieldOp>(regionTerminator)) {
+ mlir::Value entity = yield.getEntity();
+ if (auto maybeElemental =
+ mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>(
+ entity.getDefiningOp()))
+ tree.gatherElementalTree(maybeElemental, /*isAppliedInOrder=*/true);
+ }
+ return tree;
+}
+
+// Check if op is an ElementalOpInterface that is part of this elemental tree.
+bool hlfir::ElementalTree::contains(mlir::Operation *op) const {
+ for (auto &p : tree)
+ if (p.first == op)
+ return true;
+ return false;
+}
+
+std::optional<bool> hlfir::ElementalTree::isOrdered(mlir::Operation *op) const {
+ for (auto &p : tree)
+ if (p.first == op)
+ return p.second;
+ return std::nullopt;
+}
+
+void hlfir::ElementalTree::gatherElementalTree(
+ hlfir::ElementalOpInterface elemental, bool isAppliedInOrder) {
+ if (!elemental)
+ return;
+ // Only inline an applied elemental that must be executed in order if the
+ // applying indices are in order. An hlfir::Elemental may have been created
+ // for a transformational like transpose, and Fortran 2018 standard
+ // section 10.2.3.2, point 10 implies that impure elemental sub-expression
+ // evaluations should not be masked if they are the arguments of
+ // transformational expressions.
+ if (!isAppliedInOrder && elemental.isOrdered())
+ return;
+
+ insert(elemental, isAppliedInOrder);
+ for (mlir::Operation &op : elemental.getElementalRegion().getOps())
+ if (auto apply = mlir::dyn_cast<hlfir::ApplyOp>(op)) {
+ bool isUnorderedApply =
+ !isAppliedInOrder || !isInOrderApply(apply, elemental);
+ auto maybeElemental = mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>(
+ apply.getExpr().getDefiningOp());
+ gatherElementalTree(maybeElemental, !isUnorderedApply);
+ }
+}
+
+void hlfir::ElementalTree::insert(hlfir::ElementalOpInterface elementalOp,
+ bool isAppliedInOrder) {
+ tree.push_back({elementalOp.getOperation(), isAppliedInOrder});
+}
+
+static bool isInOrderDesignate(hlfir::DesignateOp designate,
+ hlfir::ElementalTree *tree) {
+ if (!tree)
+ return false;
+ if (auto elemental =
+ designate->getParentOfType<hlfir::ElementalOpInterface>())
+ if (tree->isOrdered(elemental.getOperation()))
+ return fir::ArraySectionAnalyzer::isDesignatingArrayInOrder(designate,
+ elemental);
+ return false;
+}
+
+hlfir::DetailedEffectInstance::DetailedEffectInstance(
+ mlir::MemoryEffects::Effect *effect, mlir::OpOperand *value,
+ mlir::Value orderedElementalEffectOn)
+ : effectInstance(effect, value),
+ orderedElementalEffectOn(orderedElementalEffectOn) {}
+
+hlfir::DetailedEffectInstance::DetailedEffectInstance(
+ mlir::MemoryEffects::EffectInstance effectInst,
+ mlir::Value orderedElementalEffectOn)
+ : effectInstance(effectInst),
+ orderedElementalEffectOn(orderedElementalEffectOn) {}
+
+hlfir::DetailedEffectInstance
+hlfir::DetailedEffectInstance::getArrayReadEffect(mlir::OpOperand *array) {
+ return DetailedEffectInstance(mlir::MemoryEffects::Read::get(), array,
+ array->get());
+}
+
+hlfir::DetailedEffectInstance
+hlfir::DetailedEffectInstance::getArrayWriteEffect(mlir::OpOperand *array) {
+ return DetailedEffectInstance(mlir::MemoryEffects::Write::get(), array,
+ array->get());
+}
+
namespace {
+
/// Structure that is in charge of building the schedule. For each
/// hlfir.region_assign inside an ordered assignment tree, it is walked through
/// the parent operations and their "leaf" regions (that contain expression
@@ -99,20 +220,25 @@ public:
/// After all the dependent evaluation regions have been analyzed, create the
/// action to evaluate the assignment that was being analyzed.
- void finishSchedulingAssignment(hlfir::RegionAssignOp assign);
+ void finishSchedulingAssignment(hlfir::RegionAssignOp assign,
+ bool leafRegionsMayOnlyRead);
/// Once all the assignments have been analyzed and scheduled, return the
/// schedule. The scheduler object should not be used after this call.
hlfir::Schedule moveSchedule() { return std::move(schedule); }
private:
+ struct EvaluationState {
+ bool saved = false;
+ std::optional<hlfir::Schedule::iterator> modifiedInRun;
+ };
+
/// Save a conflicting region that is evaluating an expression that is
/// controlling or masking the current assignment, or is evaluating the
/// RHS/LHS.
- void
- saveEvaluation(mlir::Region &yieldRegion,
- llvm::ArrayRef<mlir::MemoryEffects::EffectInstance> effects,
- bool anyWrite);
+ void saveEvaluation(mlir::Region &yieldRegion,
+ llvm::ArrayRef<hlfir::DetailedEffectInstance> effects,
+ bool anyWrite);
/// Can the current assignment be schedule with the previous run. This is
/// only possible if the assignment and all of its dependencies have no side
@@ -120,19 +246,17 @@ private:
bool canFuseAssignmentWithPreviousRun();
/// Memory effects of the assignments being lowered.
- llvm::SmallVector<mlir::MemoryEffects::EffectInstance> assignEffects;
+ llvm::SmallVector<hlfir::DetailedEffectInstance> assignEffects;
/// Memory effects of the evaluations implied by the assignments
/// being lowered. They do not include the implicit writes
/// to the LHS of the assignments.
- llvm::SmallVector<mlir::MemoryEffects::EffectInstance> assignEvaluateEffects;
+ llvm::SmallVector<hlfir::DetailedEffectInstance> assignEvaluateEffects;
/// Memory effects of the unsaved evaluation region that are controlling or
/// masking the current assignments.
- llvm::SmallVector<mlir::MemoryEffects::EffectInstance>
- parentEvaluationEffects;
+ llvm::SmallVector<hlfir::DetailedEffectInstance> parentEvaluationEffects;
/// Same as parentEvaluationEffects, but for the current "leaf group" being
/// analyzed scheduled.
- llvm::SmallVector<mlir::MemoryEffects::EffectInstance>
- independentEvaluationEffects;
+ llvm::SmallVector<hlfir::DetailedEffectInstance> independentEvaluationEffects;
/// Were any region saved for the current assignment?
bool savedAnyRegionForCurrentAssignment = false;
@@ -140,7 +264,10 @@ private:
// Schedule being built.
hlfir::Schedule schedule;
/// Leaf regions that have been saved so far.
- llvm::SmallPtrSet<mlir::Region *, 16> savedRegions;
+ llvm::DenseMap<mlir::Region *, EvaluationState> regionStates;
+ /// Regions that have an aligned conflict with the current assignment.
+ llvm::SmallVector<mlir::Region *> pendingAlignedRegions;
+
/// Is schedule.back() a schedule that is only saving region with read
/// effects?
bool currentRunIsReadOnly = false;
@@ -171,9 +298,10 @@ static bool isForallIndex(mlir::Value var) {
/// side effect interface, or that are writing temporary variables that may be
/// hard to identify as such (one would have to prove the write is "local" to
/// the region even when the alloca may be outside of the region).
-static void gatherMemoryEffects(
+static void gatherMemoryEffectsImpl(
mlir::Region &region, bool mayOnlyRead,
- llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &effects) {
+ llvm::SmallVectorImpl<hlfir::DetailedEffectInstance> &effects,
+ hlfir::ElementalTree *tree = nullptr) {
/// This analysis is a simple walk of all the operations of the region that is
/// evaluating and yielding a value. This is a lot simpler and safer than
/// trying to walk back the SSA DAG from the yielded value. But if desired,
@@ -181,7 +309,7 @@ static void gatherMemoryEffects(
for (mlir::Operation &op : region.getOps()) {
if (op.hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) {
for (mlir::Region &subRegion : op.getRegions())
- gatherMemoryEffects(subRegion, mayOnlyRead, effects);
+ gatherMemoryEffectsImpl(subRegion, mayOnlyRead, effects, tree);
// In MLIR, RecursiveMemoryEffects can be combined with
// MemoryEffectOpInterface to describe extra effects on top of the
// effects of the nested operations. However, the presence of
@@ -214,17 +342,45 @@ static void gatherMemoryEffects(
interface.getEffects(opEffects);
for (auto &effect : opEffects)
if (!isForallIndex(effect.getValue())) {
+ mlir::Value array;
+ if (effect.getValue())
+ if (auto designate =
+ effect.getValue().getDefiningOp<hlfir::DesignateOp>())
+ if (isInOrderDesignate(designate, tree))
+ array = designate.getMemref();
+
if (mlir::isa<mlir::MemoryEffects::Read>(effect.getEffect())) {
- LLVM_DEBUG(logIfUnkownEffectValue(llvm::dbgs(), effect, op););
- effects.push_back(effect);
+ LLVM_DEBUG(logIfUnknownEffectValue(llvm::dbgs(), effect, op););
+ effects.emplace_back(effect, array);
} else if (!mayOnlyRead &&
mlir::isa<mlir::MemoryEffects::Write>(effect.getEffect())) {
- LLVM_DEBUG(logIfUnkownEffectValue(llvm::dbgs(), effect, op););
- effects.push_back(effect);
+ LLVM_DEBUG(logIfUnknownEffectValue(llvm::dbgs(), effect, op););
+ effects.emplace_back(effect, array);
}
}
}
}
+static void gatherMemoryEffects(
+ mlir::Region &region, bool mayOnlyRead,
+ llvm::SmallVectorImpl<hlfir::DetailedEffectInstance> &effects) {
+ if (!region.getParentOfType<hlfir::ForallOp>()) {
+ // TODO: leverage array access analysis for FORALL.
+ // While FORALL assignments can be array assignments, the iteration space
+ // is also driven by the FORALL indices, so the way ArraySectionAnalyzer
+ // results are used is not adequate for it.
+ // For instance "disjoint" array access cannot be ignored in:
+ // "forall (i=1:10) x(i+1,:) = x(i,:)".
+ // While identical access can probably also be accepted, this would deserve
+ // more thinking; it would probably make sense to also deal with "aligned
+ // scalar" access for them like in "forall (i=1:10) x(i) = x(i) + 1". For
+ // now this feature is disabled for inside FORALL.
+ hlfir::ElementalTree tree =
+ hlfir::ElementalTree::buildElementalTree(region.back().back());
+ gatherMemoryEffectsImpl(region, mayOnlyRead, effects, &tree);
+ return;
+ }
+ gatherMemoryEffectsImpl(region, mayOnlyRead, effects, /*tree=*/nullptr);
+}
/// Return the entity yielded by a region, or a null value if the region
/// is not terminated by a yield.
@@ -246,10 +402,14 @@ static mlir::OpOperand *getYieldedEntity(mlir::Region &region) {
static void gatherAssignEffects(
hlfir::RegionAssignOp regionAssign,
bool userDefAssignmentMayOnlyWriteToAssignedVariable,
- llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &assignEffects) {
+ llvm::SmallVectorImpl<hlfir::DetailedEffectInstance> &assignEffects) {
mlir::OpOperand *assignedVar = getYieldedEntity(regionAssign.getLhsRegion());
assert(assignedVar && "lhs cannot be an empty region");
- assignEffects.emplace_back(mlir::MemoryEffects::Write::get(), assignedVar);
+ if (regionAssign->getParentOfType<hlfir::ForallOp>())
+ assignEffects.emplace_back(mlir::MemoryEffects::Write::get(), assignedVar);
+ else
+ assignEffects.emplace_back(
+ hlfir::DetailedEffectInstance::getArrayWriteEffect(assignedVar));
if (!regionAssign.getUserDefinedAssignment().empty()) {
// The write effect on the INTENT(OUT) LHS argument is already taken
@@ -273,7 +433,7 @@ static void gatherAssignEffects(
static void gatherAssignEvaluationEffects(
hlfir::RegionAssignOp regionAssign,
bool userDefAssignmentMayOnlyWriteToAssignedVariable,
- llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &assignEffects) {
+ llvm::SmallVectorImpl<hlfir::DetailedEffectInstance> &assignEffects) {
gatherMemoryEffects(regionAssign.getLhsRegion(),
userDefAssignmentMayOnlyWriteToAssignedVariable,
assignEffects);
@@ -308,12 +468,57 @@ static mlir::Value getStorageSource(mlir::Value var) {
return source;
}
+namespace {
+
+/// Class to represent conflicts between several accesses (effects) to a memory
+/// location (read after write, write after write).
+struct ConflictKind {
+ enum Kind {
+ // None: The effects are not affecting the same memory location, or they are
+ // all reads.
+ None,
+ // Aligned: There are both read and write effects affecting the same memory
+ // location, but it is known that these effects are all accessing the memory
+ // location element by element in array order. This means the conflict does
+ // not introduce loop-carried dependencies.
+ Aligned,
+ // Any: There may be both read and write effects affecting the same memory
+ // in any way.
+ Any
+ };
+ Kind kind;
+
+ ConflictKind(Kind k) : kind(k) {}
+
+ static ConflictKind none() { return ConflictKind(None); }
+ static ConflictKind aligned() { return ConflictKind(Aligned); }
+ static ConflictKind any() { return ConflictKind(Any); }
+
+ bool isNone() const { return kind == None; }
+ bool isAligned() const { return kind == Aligned; }
+ bool isAny() const { return kind == Any; }
+
+ // Merge conflicts:
+ // none || none -> none
+ // aligned || <not any> -> aligned
+ // any || _ -> any
+ ConflictKind operator||(const ConflictKind &other) const {
+ if (kind == Any || other.kind == Any)
+ return any();
+ if (kind == Aligned || other.kind == Aligned)
+ return aligned();
+ return none();
+ }
+};
+} // namespace
+
/// Could there be any read or write in effectsA on a variable written to in
/// effectsB?
-static bool
-anyRAWorWAW(llvm::ArrayRef<mlir::MemoryEffects::EffectInstance> effectsA,
- llvm::ArrayRef<mlir::MemoryEffects::EffectInstance> effectsB,
+static ConflictKind
+anyRAWorWAW(llvm::ArrayRef<hlfir::DetailedEffectInstance> effectsA,
+ llvm::ArrayRef<hlfir::DetailedEffectInstance> effectsB,
fir::AliasAnalysis &aliasAnalysis) {
+ ConflictKind result = ConflictKind::none();
for (const auto &effectB : effectsB)
if (mlir::isa<mlir::MemoryEffects::Write>(effectB.getEffect())) {
mlir::Value writtenVarB = effectB.getValue();
@@ -325,38 +530,66 @@ anyRAWorWAW(llvm::ArrayRef<mlir::MemoryEffects::EffectInstance> effectsA,
mlir::Value writtenOrReadVarA = effectA.getValue();
if (!writtenVarB || !writtenOrReadVarA) {
LLVM_DEBUG(
- logConflict(llvm::dbgs(), writtenOrReadVarA, writtenVarB););
- return true; // unknown conflict.
+ logConflict(llvm::dbgs(), writtenOrReadVarA, writtenVarB));
+ return ConflictKind::any(); // unknown conflict.
}
writtenOrReadVarA = getStorageSource(writtenOrReadVarA);
if (!aliasAnalysis.alias(writtenOrReadVarA, writtenVarB).isNo()) {
+ mlir::Value arrayA = effectA.getOrderedElementalEffectOn();
+ mlir::Value arrayB = effectB.getOrderedElementalEffectOn();
+ if (arrayA && arrayB) {
+ if (arrayA == arrayB) {
+ result = result || ConflictKind::aligned();
+ LLVM_DEBUG(logConflict(llvm::dbgs(), writtenOrReadVarA,
+ writtenVarB, /*isAligned=*/true));
+ continue;
+ }
+ auto overlap = fir::ArraySectionAnalyzer::analyze(arrayA, arrayB);
+ if (overlap == fir::ArraySectionAnalyzer::SlicesOverlapKind::
+ DefinitelyDisjoint)
+ continue;
+ if (overlap == fir::ArraySectionAnalyzer::SlicesOverlapKind::
+ DefinitelyIdentical ||
+ overlap == fir::ArraySectionAnalyzer::SlicesOverlapKind::
+ EitherIdenticalOrDisjoint) {
+ result = result || ConflictKind::aligned();
+ LLVM_DEBUG(logConflict(llvm::dbgs(), writtenOrReadVarA,
+ writtenVarB, /*isAligned=*/true));
+ continue;
+ }
+ LLVM_DEBUG(llvm::dbgs() << "conflicting arrays:" << arrayA
+ << " and " << arrayB << "\n");
+ return ConflictKind::any();
+ }
LLVM_DEBUG(
- logConflict(llvm::dbgs(), writtenOrReadVarA, writtenVarB););
- return true;
+ logConflict(llvm::dbgs(), writtenOrReadVarA, writtenVarB));
+ return ConflictKind::any();
}
}
}
- return false;
+ return result;
}
/// Could there be any read or write in effectsA on a variable written to in
/// effectsB, or any read in effectsB on a variable written to in effectsA?
-static bool
-conflict(llvm::ArrayRef<mlir::MemoryEffects::EffectInstance> effectsA,
- llvm::ArrayRef<mlir::MemoryEffects::EffectInstance> effectsB) {
+static ConflictKind
+conflict(llvm::ArrayRef<hlfir::DetailedEffectInstance> effectsA,
+ llvm::ArrayRef<hlfir::DetailedEffectInstance> effectsB) {
fir::AliasAnalysis aliasAnalysis;
// (RAW || WAW) || (WAR || WAW).
- return anyRAWorWAW(effectsA, effectsB, aliasAnalysis) ||
- anyRAWorWAW(effectsB, effectsA, aliasAnalysis);
+ ConflictKind result = anyRAWorWAW(effectsA, effectsB, aliasAnalysis);
+ if (result.isAny())
+ return result;
+ return result || anyRAWorWAW(effectsB, effectsA, aliasAnalysis);
}
/// Could there be any write effects in "effects" affecting memory storages
/// that are not local to the current region.
static bool
-anyNonLocalWrite(llvm::ArrayRef<mlir::MemoryEffects::EffectInstance> effects,
+anyNonLocalWrite(llvm::ArrayRef<hlfir::DetailedEffectInstance> effects,
mlir::Region &region) {
return llvm::any_of(
- effects, [&region](const mlir::MemoryEffects::EffectInstance &effect) {
+ effects, [&region](const hlfir::DetailedEffectInstance &effect) {
if (mlir::isa<mlir::MemoryEffects::Write>(effect.getEffect())) {
if (mlir::Value v = effect.getValue()) {
v = getStorageSource(v);
@@ -393,9 +626,9 @@ void Scheduler::saveEvaluationIfConflict(mlir::Region &yieldRegion,
// If the region evaluation was previously executed and saved, the saved
// value will be used when evaluating the current assignment and this has
// no effects in the current assignment evaluation.
- if (savedRegions.contains(&yieldRegion))
+ if (regionStates[&yieldRegion].saved)
return;
- llvm::SmallVector<mlir::MemoryEffects::EffectInstance> effects;
+ llvm::SmallVector<hlfir::DetailedEffectInstance> effects;
gatherMemoryEffects(yieldRegion, leafRegionsMayOnlyRead, effects);
// Yield has no effect as such, but in the context of order assignments.
// The order assignments will usually read the yielded entity (except for
@@ -404,8 +637,13 @@ void Scheduler::saveEvaluationIfConflict(mlir::Region &yieldRegion,
// intent(inout)).
if (yieldIsImplicitRead) {
mlir::OpOperand *entity = getYieldedEntity(yieldRegion);
- if (entity && hlfir::isFortranVariableType(entity->get().getType()))
- effects.emplace_back(mlir::MemoryEffects::Read::get(), entity);
+ if (entity && hlfir::isFortranVariableType(entity->get().getType())) {
+ if (yieldRegion.getParentOfType<hlfir::ForallOp>())
+ effects.emplace_back(mlir::MemoryEffects::Read::get(), entity);
+ else
+ effects.emplace_back(
+ hlfir::DetailedEffectInstance::getArrayReadEffect(entity));
+ }
}
if (!leafRegionsMayOnlyRead && anyNonLocalWrite(effects, yieldRegion)) {
// Region with write effect must be executed only once (unless all writes
@@ -415,33 +653,58 @@ void Scheduler::saveEvaluationIfConflict(mlir::Region &yieldRegion,
<< "saving eval because write effect prevents re-evaluation"
<< "\n";);
saveEvaluation(yieldRegion, effects, /*anyWrite=*/true);
- } else if (conflict(effects, assignEffects)) {
- // Region that conflicts with the current assignments must be fully
- // evaluated and saved before doing the assignment (Note that it may
- // have already have been evaluated without saving it before, but this
- // implies that it never conflicted with a prior assignment, so its value
- // should be the same.)
- saveEvaluation(yieldRegion, effects, /*anyWrite=*/false);
- } else if (evaluationsMayConflict &&
- conflict(effects, assignEvaluateEffects)) {
- // If evaluations of the assignment may conflict with the yield
- // evaluations, we have to save yield evaluation.
- // For example, a WHERE mask might be written by the masked assignment
- // evaluations, and it has to be saved in this case:
- // where (mask) r = f() ! function f modifies mask
- saveEvaluation(yieldRegion, effects,
- anyNonLocalWrite(effects, yieldRegion));
} else {
- // Can be executed while doing the assignment.
- independentEvaluationEffects.append(effects.begin(), effects.end());
+ ConflictKind conflictKind = conflict(effects, assignEffects);
+ if (conflictKind.isAny()) {
+ // Region that conflicts with the current assignments must be fully
+ // evaluated and saved before doing the assignment (Note that it may
+ // have already been evaluated without saving it before, but this
+ // implies that it never conflicted with a prior assignment, so its value
+ // should be the same.)
+ saveEvaluation(yieldRegion, effects, /*anyWrite=*/false);
+ } else {
+ if (conflictKind.isAligned())
+ pendingAlignedRegions.push_back(&yieldRegion);
+
+ if (evaluationsMayConflict &&
+ !conflict(effects, assignEvaluateEffects).isNone()) {
+ // If evaluations of the assignment may conflict with the yield
+ // evaluations, we have to save yield evaluation.
+ // For example, a WHERE mask might be written by the masked assignment
+ // evaluations, and it has to be saved in this case:
+ // where (mask) r = f() ! function f modifies mask
+ saveEvaluation(yieldRegion, effects,
+ anyNonLocalWrite(effects, yieldRegion));
+ } else {
+ // Can be executed while doing the assignment.
+ independentEvaluationEffects.append(effects.begin(), effects.end());
+ }
+ }
}
}
void Scheduler::saveEvaluation(
mlir::Region &yieldRegion,
- llvm::ArrayRef<mlir::MemoryEffects::EffectInstance> effects,
- bool anyWrite) {
+ llvm::ArrayRef<hlfir::DetailedEffectInstance> effects, bool anyWrite) {
savedAnyRegionForCurrentAssignment = true;
+ auto &state = regionStates[&yieldRegion];
+ if (state.modifiedInRun) {
+ // The region was modified in a previous run, but we now realize we need its
+ // value. We must save it before that modification run.
+ auto &newRun = *schedule.emplace(*state.modifiedInRun, hlfir::Run{});
+ newRun.actions.emplace_back(hlfir::SaveEntity{&yieldRegion});
+ // We do not have the parent effects from that time easily available here.
+ // However, since we are saving a parent of the current assignment, its
+ // parents are also parents of the current assignment.
+ newRun.memoryEffects.append(parentEvaluationEffects.begin(),
+ parentEvaluationEffects.end());
+ newRun.memoryEffects.append(effects.begin(), effects.end());
+ state.saved = true;
+ LLVM_DEBUG(
+ logSaveEvaluation(llvm::dbgs(), /*runid=*/0, yieldRegion, anyWrite););
+ return;
+ }
+
if (anyWrite) {
// Create a new run just for regions with side effect. Further analysis
// could try to prove the effects do not conflict with the previous
@@ -465,7 +728,7 @@ void Scheduler::saveEvaluation(
schedule.back().memoryEffects.append(parentEvaluationEffects.begin(),
parentEvaluationEffects.end());
schedule.back().memoryEffects.append(effects.begin(), effects.end());
- savedRegions.insert(&yieldRegion);
+ state.saved = true;
LLVM_DEBUG(
logSaveEvaluation(llvm::dbgs(), schedule.size(), yieldRegion, anyWrite););
}
@@ -476,18 +739,78 @@ bool Scheduler::canFuseAssignmentWithPreviousRun() {
if (savedAnyRegionForCurrentAssignment || schedule.empty())
return false;
auto &previousRunEffects = schedule.back().memoryEffects;
- return !conflict(previousRunEffects, assignEffects) &&
- !conflict(previousRunEffects, parentEvaluationEffects) &&
- !conflict(previousRunEffects, independentEvaluationEffects);
+ return !conflict(previousRunEffects, assignEffects).isAny() &&
+ !conflict(previousRunEffects, parentEvaluationEffects).isAny() &&
+ !conflict(previousRunEffects, independentEvaluationEffects).isAny();
+}
+
+/// Gather the parents of (not included) \p node in reverse execution order.
+static void gatherParents(
+ hlfir::OrderedAssignmentTreeOpInterface node,
+ llvm::SmallVectorImpl<hlfir::OrderedAssignmentTreeOpInterface> &parents) {
+ while (node) {
+ auto parent =
+ mlir::dyn_cast_or_null<hlfir::OrderedAssignmentTreeOpInterface>(
+ node->getParentOp());
+ if (parent && parent.getSubTreeRegion() == node->getParentRegion()) {
+ parents.push_back(parent);
+ node = parent;
+ } else {
+ break;
+ }
+ }
+}
+
+// Build the list of the parent nodes for this assignment. The list is built
+// from the closest parent until the ordered assignment tree root (this is the
+// reverse of their execution order).
+static void gatherAssignmentParents(
+ hlfir::RegionAssignOp assign,
+ llvm::SmallVectorImpl<hlfir::OrderedAssignmentTreeOpInterface> &parents) {
+ gatherParents(mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
+ assign.getOperation()),
+ parents);
}
-void Scheduler::finishSchedulingAssignment(hlfir::RegionAssignOp assign) {
- // For now, always schedule each assignment in its own run. They could
- // be done as part of previous assignment runs if it is proven they have
- // no conflicting effects.
+void Scheduler::finishSchedulingAssignment(hlfir::RegionAssignOp assign,
+ bool leafRegionsMayOnlyRead) {
+ // Schedule the assignment in a new run, unless it can be fused with the
+ // previous run (if enabled and proven safe).
currentRunIsReadOnly = false;
- if (!tryFusingAssignments || !canFuseAssignmentWithPreviousRun())
+ bool fuse = tryFusingAssignments && canFuseAssignmentWithPreviousRun();
+ if (!fuse) {
+ // If we cannot fuse, we are about to start a new run.
+ // Check if any parent region was modified in a previous run and needs to be
+ // saved.
+ llvm::SmallVector<hlfir::OrderedAssignmentTreeOpInterface> parents;
+ gatherAssignmentParents(assign, parents);
+ for (auto parent : parents) {
+ llvm::SmallVector<mlir::Region *, 4> yieldRegions;
+ parent.getLeafRegions(yieldRegions);
+ for (mlir::Region *yieldRegion : yieldRegions) {
+ if (regionStates[yieldRegion].modifiedInRun &&
+ !regionStates[yieldRegion].saved) {
+ LLVM_DEBUG(logRetroactiveSave(
+ llvm::dbgs(), *yieldRegion,
+ **regionStates[yieldRegion].modifiedInRun, assign));
+ llvm::SmallVector<hlfir::DetailedEffectInstance> effects;
+ gatherMemoryEffects(*yieldRegion, leafRegionsMayOnlyRead, effects);
+ saveEvaluation(*yieldRegion, effects,
+ anyNonLocalWrite(effects, *yieldRegion));
+ }
+ }
+ }
schedule.emplace_back(hlfir::Run{});
+ }
+
+ // Mark pending aligned regions as modified in the current run (which is the
+ // last one).
+ auto runIt = std::prev(schedule.end());
+ for (mlir::Region *region : pendingAlignedRegions)
+ if (!regionStates[region].saved)
+ regionStates[region].modifiedInRun = runIt;
+ pendingAlignedRegions.clear();
+
schedule.back().actions.emplace_back(assign);
// TODO: when fusing, it would probably be best to filter the
// parentEvaluationEffects that already in the previous run effects (since
@@ -530,34 +853,6 @@ gatherAssignments(hlfir::OrderedAssignmentTreeOpInterface root,
}
}
-/// Gather the parents of (not included) \p node in reverse execution order.
-static void gatherParents(
- hlfir::OrderedAssignmentTreeOpInterface node,
- llvm::SmallVectorImpl<hlfir::OrderedAssignmentTreeOpInterface> &parents) {
- while (node) {
- auto parent =
- mlir::dyn_cast_or_null<hlfir::OrderedAssignmentTreeOpInterface>(
- node->getParentOp());
- if (parent && parent.getSubTreeRegion() == node->getParentRegion()) {
- parents.push_back(parent);
- node = parent;
- } else {
- break;
- }
- }
-}
-
-// Build the list of the parent nodes for this assignment. The list is built
-// from the closest parent until the ordered assignment tree root (this is the
-// revere of their execution order).
-static void gatherAssignmentParents(
- hlfir::RegionAssignOp assign,
- llvm::SmallVectorImpl<hlfir::OrderedAssignmentTreeOpInterface> &parents) {
- gatherParents(mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
- assign.getOperation()),
- parents);
-}
-
hlfir::Schedule
hlfir::buildEvaluationSchedule(hlfir::OrderedAssignmentTreeOpInterface root,
bool tryFusingAssignments) {
@@ -616,7 +911,7 @@ hlfir::buildEvaluationSchedule(hlfir::OrderedAssignmentTreeOpInterface root,
leafRegionsMayOnlyRead,
/*yieldIsImplicitRead=*/false);
scheduler.finishIndependentEvaluationGroup();
- scheduler.finishSchedulingAssignment(assign);
+ scheduler.finishSchedulingAssignment(assign, leafRegionsMayOnlyRead);
}
return scheduler.moveSchedule();
}
@@ -704,6 +999,25 @@ static llvm::raw_ostream &printRegionPath(llvm::raw_ostream &os,
return printRegionId(os, yieldRegion);
}
+[[maybe_unused]] static void
+logRetroactiveSave(llvm::raw_ostream &os, mlir::Region &yieldRegion,
+ hlfir::Run &modifyingRun,
+ hlfir::RegionAssignOp currentAssign) {
+ printRegionPath(os, yieldRegion) << " is modified in order by ";
+ bool first = true;
+ for (auto &action : modifyingRun.actions) {
+ if (auto *assign = std::get_if<hlfir::RegionAssignOp>(&action)) {
+ if (!first)
+ os << ", ";
+ printNodePath(os, assign->getOperation());
+ first = false;
+ }
+ }
+ os << " and is needed by ";
+ printNodePath(os, currentAssign.getOperation());
+ os << " that is scheduled in a later run\n";
+}
+
[[maybe_unused]] static void logSaveEvaluation(llvm::raw_ostream &os,
unsigned runid,
mlir::Region &yieldRegion,
@@ -721,13 +1035,14 @@ logAssignmentEvaluation(llvm::raw_ostream &os, unsigned runid,
[[maybe_unused]] static void logConflict(llvm::raw_ostream &os,
mlir::Value writtenOrReadVarA,
- mlir::Value writtenVarB) {
+ mlir::Value writtenVarB,
+ bool isAligned) {
auto printIfValue = [&](mlir::Value var) -> llvm::raw_ostream & {
if (!var)
return os << "<unknown>";
return os << var;
};
- os << "conflict: R/W: ";
+ os << "conflict" << (isAligned ? " (aligned)" : "") << ": R/W: ";
printIfValue(writtenOrReadVarA) << " W:";
printIfValue(writtenVarB) << "\n";
}
@@ -743,9 +1058,9 @@ logStartScheduling(llvm::raw_ostream &os,
}
[[maybe_unused]] static void
-logIfUnkownEffectValue(llvm::raw_ostream &os,
- mlir::MemoryEffects::EffectInstance effect,
- mlir::Operation &op) {
+logIfUnknownEffectValue(llvm::raw_ostream &os,
+ mlir::MemoryEffects::EffectInstance effect,
+ mlir::Operation &op) {
if (effect.getValue() != nullptr)
return;
os << "unknown effected value (";
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.h b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.h
index 2ed242e..7f479ab 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.h
+++ b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.h
@@ -15,9 +15,30 @@
#define OPTIMIZER_HLFIR_TRANSFORM_SCHEDULEORDEREDASSIGNMENTS_H
#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include <list>
namespace hlfir {
+struct ElementalTree {
+ // Build an elemental tree given a masked region terminator.
+ static ElementalTree buildElementalTree(mlir::Operation &regionTerminator);
+ // Check if op is an ElementalOpInterface that is part of this elemental tree.
+ bool contains(mlir::Operation *op) const;
+
+ std::optional<bool> isOrdered(mlir::Operation *op) const;
+
+private:
+ void gatherElementalTree(hlfir::ElementalOpInterface elemental,
+ bool isAppliedInOrder);
+ void insert(hlfir::ElementalOpInterface elementalOp, bool isAppliedInOrder);
+ // List of ElementalOpInterface operations forming this tree, as well as a
+ // Boolean to indicate if they are applied in order (that is, if their
+ // indexing space is the same as the one for the array yielded by the mask
+ // region that owns this tree).
+ llvm::SmallVector<std::pair<mlir::Operation *, bool>> tree;
+};
+
/// Structure to represent that the value yielded by some region
/// must be fully evaluated and saved for all index values at
/// a given point of the ordered assignment tree evaluation.
@@ -29,6 +50,37 @@ struct SaveEntity {
mlir::Value getSavedValue();
};
+/// Wrapper class around mlir::MemoryEffects::EffectInstance that
+/// allows providing an extra array value that indicates that the
+/// effect is done element by element in array order (one element
+/// accessed at each iteration of the ordered assignment iteration
+/// space).
+class DetailedEffectInstance {
+public:
+ DetailedEffectInstance(mlir::MemoryEffects::Effect *effect,
+ mlir::OpOperand *value = nullptr,
+ mlir::Value orderedElementalEffectOn = nullptr);
+ DetailedEffectInstance(mlir::MemoryEffects::EffectInstance effectInstance,
+ mlir::Value orderedElementalEffectOn = nullptr);
+
+ static DetailedEffectInstance getArrayReadEffect(mlir::OpOperand *array);
+ static DetailedEffectInstance getArrayWriteEffect(mlir::OpOperand *array);
+
+ mlir::Value getValue() const { return effectInstance.getValue(); }
+ mlir::MemoryEffects::Effect *getEffect() const {
+ return effectInstance.getEffect();
+ }
+ mlir::Value getOrderedElementalEffectOn() const {
+ return orderedElementalEffectOn;
+ }
+
+private:
+ mlir::MemoryEffects::EffectInstance effectInstance;
+ // Array whose elements are affected in array order by the
+ // ordered assignment iterations. Null value otherwise.
+ mlir::Value orderedElementalEffectOn;
+};
+
/// A run is a list of actions required to evaluate an ordered assignment tree
/// that can be done in the same loop nest.
/// The actions can evaluate and saves element values into temporary or evaluate
@@ -42,11 +94,11 @@ struct Run {
/// the assignment part of an hlfir::RegionAssignOp.
using Action = std::variant<hlfir::RegionAssignOp, SaveEntity>;
llvm::SmallVector<Action> actions;
- llvm::SmallVector<mlir::MemoryEffects::EffectInstance> memoryEffects;
+ llvm::SmallVector<DetailedEffectInstance> memoryEffects;
};
/// List of runs to be executed in order to evaluate an order assignment tree.
-using Schedule = llvm::SmallVector<Run>;
+using Schedule = std::list<Run>;
/// Example of schedules and run, and what they mean:
/// Fortran: forall (i=i:10) x(i) = y(i)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index ce8ebaa..cc39652 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -931,6 +931,37 @@ private:
mlir::Value genScalarAdd(mlir::Value value1, mlir::Value value2);
};
+/// Reduction converter for Product.
+class ProductAsElementalConverter
+ : public NumericReductionAsElementalConverterBase<hlfir::ProductOp> {
+ using Base = NumericReductionAsElementalConverterBase;
+
+public:
+ ProductAsElementalConverter(hlfir::ProductOp op,
+ mlir::PatternRewriter &rewriter)
+ : Base{op, rewriter} {}
+
+private:
+ virtual llvm::SmallVector<mlir::Value> genReductionInitValues(
+ [[maybe_unused]] mlir::ValueRange oneBasedIndices,
+ [[maybe_unused]] const llvm::SmallVectorImpl<mlir::Value> &extents)
+ final {
+ return {fir::factory::createOneValue(builder, loc, getResultElementType())};
+ }
+ virtual llvm::SmallVector<mlir::Value>
+ reduceOneElement(const llvm::SmallVectorImpl<mlir::Value> &currentValue,
+ hlfir::Entity array,
+ mlir::ValueRange oneBasedIndices) final {
+ checkReductions(currentValue);
+ hlfir::Entity elementValue =
+ hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
+ return {genScalarMult(currentValue[0], elementValue)};
+ }
+
+ // Generate scalar multiplication of the two values (of the same data type).
+ mlir::Value genScalarMult(mlir::Value value1, mlir::Value value2);
+};
+
/// Base class for logical reductions like ALL, ANY, COUNT.
/// They do not have MASK and FastMathFlags.
template <typename OpT>
@@ -1194,6 +1225,20 @@ mlir::Value SumAsElementalConverter::genScalarAdd(mlir::Value value1,
llvm_unreachable("unsupported SUM reduction type");
}
+mlir::Value ProductAsElementalConverter::genScalarMult(mlir::Value value1,
+ mlir::Value value2) {
+ mlir::Type ty = value1.getType();
+ assert(ty == value2.getType() && "reduction values' types do not match");
+ if (mlir::isa<mlir::FloatType>(ty))
+ return mlir::arith::MulFOp::create(builder, loc, value1, value2);
+ else if (mlir::isa<mlir::ComplexType>(ty))
+ return fir::MulcOp::create(builder, loc, value1, value2);
+ else if (mlir::isa<mlir::IntegerType>(ty))
+ return mlir::arith::MulIOp::create(builder, loc, value1, value2);
+
+ llvm_unreachable("unsupported MUL reduction type");
+}
+
mlir::Value ReductionAsElementalConverter::genMaskValue(
mlir::Value mask, mlir::Value isPresentPred, mlir::ValueRange indices) {
mlir::OpBuilder::InsertionGuard guard(builder);
@@ -1265,6 +1310,9 @@ public:
} else if constexpr (std::is_same_v<Op, hlfir::SumOp>) {
SumAsElementalConverter converter{op, rewriter};
return converter.convert();
+ } else if constexpr (std::is_same_v<Op, hlfir::ProductOp>) {
+ ProductAsElementalConverter converter{op, rewriter};
+ return converter.convert();
}
return rewriter.notifyMatchFailure(op, "unexpected reduction operation");
}
@@ -1371,15 +1419,12 @@ private:
}
/// The indices computations for the array shifts are done using I64 type.
- /// For CSHIFT, all computations do not overflow signed and unsigned I64.
- /// For EOSHIFT, some computations may involve negative shift values,
- /// so using no-unsigned wrap flag would be incorrect.
+  /// For CSHIFT and EOSHIFT, all computations do not overflow signed I64.
+ /// While no-unsigned wrap could be set on some operation generated for
+ /// CSHIFT, it is in general unsafe to mix with computations involving
+ /// user defined bounds that may be negative.
static void setArithOverflowFlags(Op op, fir::FirOpBuilder &builder) {
- if constexpr (std::is_same_v<Op, hlfir::EOShiftOp>)
- builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw);
- else
- builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw |
- mlir::arith::IntegerOverflowFlags::nuw);
+ builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nsw);
}
/// Return the element type of the EOSHIFT boundary that may be omitted
@@ -1841,11 +1886,9 @@ private:
hlfir::Entity srcArray = array;
if (exposeContiguity && mlir::isa<fir::BaseBoxType>(srcArray.getType())) {
assert(dimVal == 1 && "can expose contiguity only for dim 1");
- llvm::SmallVector<mlir::Value, maxRank> arrayLbounds =
- hlfir::genLowerbounds(loc, builder, arrayShape, array.getRank());
hlfir::Entity section =
- hlfir::gen1DSection(loc, builder, srcArray, dimVal, arrayLbounds,
- arrayExtents, oneBasedIndices, typeParams);
+ hlfir::gen1DSection(loc, builder, srcArray, dimVal, arrayExtents,
+ oneBasedIndices, typeParams);
mlir::Value addr = hlfir::genVariableRawAddress(loc, builder, section);
mlir::Value shape = hlfir::genShape(loc, builder, section);
mlir::Type boxType = fir::wrapInClassOrBoxType(
@@ -3158,6 +3201,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.insert<TransposeAsElementalConversion>(context);
patterns.insert<ReductionConversion<hlfir::SumOp>>(context);
+ patterns.insert<ReductionConversion<hlfir::ProductOp>>(context);
patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context);
patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context);
patterns.insert<CmpCharOpConversion>(context);
diff --git a/flang/lib/Optimizer/OpenACC/Analysis/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/Analysis/CMakeLists.txt
new file mode 100644
index 0000000..d9dda9d
--- /dev/null
+++ b/flang/lib/Optimizer/OpenACC/Analysis/CMakeLists.txt
@@ -0,0 +1,24 @@
+add_flang_library(FIROpenACCAnalysis
+ FIROpenACCSupportAnalysis.cpp
+
+ DEPENDS
+ FIRAnalysis
+ FIRDialect
+ FIROpenACCSupport
+ HLFIRDialect
+
+ LINK_LIBS
+ FIRAnalysis
+ FIRDialect
+ FIROpenACCSupport
+ HLFIRDialect
+
+ MLIR_DEPS
+ MLIROpenACCDialect
+ MLIROpenACCUtils
+
+ MLIR_LIBS
+ MLIROpenACCDialect
+ MLIROpenACCUtils
+)
+
diff --git a/flang/lib/Optimizer/OpenACC/Analysis/FIROpenACCSupportAnalysis.cpp b/flang/lib/Optimizer/OpenACC/Analysis/FIROpenACCSupportAnalysis.cpp
new file mode 100644
index 0000000..3ad3188
--- /dev/null
+++ b/flang/lib/Optimizer/OpenACC/Analysis/FIROpenACCSupportAnalysis.cpp
@@ -0,0 +1,56 @@
+//===- FIROpenACCSupportAnalysis.cpp - FIR OpenACCSupport Analysis -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the FIR-specific OpenACCSupport analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/OpenACC/Analysis/FIROpenACCSupportAnalysis.h"
+
+#include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/OpenACC/Support/FIROpenACCUtils.h"
+#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
+
+using namespace mlir;
+
+namespace fir {
+namespace acc {
+
+std::string FIROpenACCSupportAnalysis::getVariableName(Value v) {
+ return fir::acc::getVariableName(v, /*preferDemangledName=*/true);
+}
+
+std::string FIROpenACCSupportAnalysis::getRecipeName(mlir::acc::RecipeKind kind,
+ Type type, Value var) {
+ return fir::acc::getRecipeName(kind, type, var);
+}
+
+mlir::InFlightDiagnostic
+FIROpenACCSupportAnalysis::emitNYI(Location loc, const Twine &message) {
+ TODO(loc, message);
+ // Should be unreachable, but we return an actual diagnostic
+ // to satisfy the interface.
+ return mlir::emitError(loc, "not yet implemented: " + message.str());
+}
+
+bool FIROpenACCSupportAnalysis::isValidValueUse(Value v, Region &region) {
+ // First check using the base utility.
+ if (mlir::acc::isValidValueUse(v, region))
+ return true;
+
+ // FIR-specific: fir.logical is a trivial scalar type that can be
+ // passed by value.
+ if (mlir::isa<fir::LogicalType>(v.getType()))
+ return true;
+
+ return false;
+}
+
+} // namespace acc
+} // namespace fir
diff --git a/flang/lib/Optimizer/OpenACC/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/CMakeLists.txt
index 790b9fd..16a4025 100644
--- a/flang/lib/Optimizer/OpenACC/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenACC/CMakeLists.txt
@@ -1,2 +1,3 @@
+add_subdirectory(Analysis)
add_subdirectory(Support)
add_subdirectory(Transforms)
diff --git a/flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt
index ef67ab1..9ff46c7 100644
--- a/flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenACC/Support/CMakeLists.txt
@@ -2,10 +2,14 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
add_flang_library(FIROpenACCSupport
FIROpenACCAttributes.cpp
+ FIROpenACCOpsInterfaces.cpp
FIROpenACCTypeInterfaces.cpp
+ FIROpenACCUtils.cpp
RegisterOpenACCExtensions.cpp
DEPENDS
+ CUFAttrs
+ CUFDialect
FIRBuilder
FIRDialect
FIRDialectSupport
@@ -13,6 +17,8 @@ add_flang_library(FIROpenACCSupport
HLFIRDialect
LINK_LIBS
+ CUFAttrs
+ CUFDialect
FIRBuilder
FIRCodeGenDialect
FIRDialect
@@ -22,7 +28,9 @@ add_flang_library(FIROpenACCSupport
MLIR_DEPS
MLIROpenACCDialect
+ MLIROpenACCUtils
MLIR_LIBS
MLIROpenACCDialect
+ MLIROpenACCUtils
)
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
new file mode 100644
index 0000000..fc654e4
--- /dev/null
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
@@ -0,0 +1,227 @@
+//===-- FIROpenACCOpsInterfaces.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implementation of external operation interfaces for FIR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h"
+
+#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/Support/InternalNames.h"
+#include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/SmallSet.h"
+
+namespace fir::acc {
+
+template <>
+mlir::Value PartialEntityAccessModel<fir::ArrayCoorOp>::getBaseEntity(
+ mlir::Operation *op) const {
+ return mlir::cast<fir::ArrayCoorOp>(op).getMemref();
+}
+
+template <>
+mlir::Value PartialEntityAccessModel<fir::CoordinateOp>::getBaseEntity(
+ mlir::Operation *op) const {
+ return mlir::cast<fir::CoordinateOp>(op).getRef();
+}
+
+template <>
+mlir::Value PartialEntityAccessModel<hlfir::DesignateOp>::getBaseEntity(
+ mlir::Operation *op) const {
+ return mlir::cast<hlfir::DesignateOp>(op).getMemref();
+}
+
+mlir::Value PartialEntityAccessModel<fir::DeclareOp>::getBaseEntity(
+ mlir::Operation *op) const {
+ auto declareOp = mlir::cast<fir::DeclareOp>(op);
+ // If storage is present, return it (partial view case)
+ if (mlir::Value storage = declareOp.getStorage())
+ return storage;
+ // Otherwise return the memref (complete view case)
+ return declareOp.getMemref();
+}
+
+bool PartialEntityAccessModel<fir::DeclareOp>::isCompleteView(
+ mlir::Operation *op) const {
+ // Complete view if storage is absent
+ return !mlir::cast<fir::DeclareOp>(op).getStorage();
+}
+
+mlir::Value PartialEntityAccessModel<hlfir::DeclareOp>::getBaseEntity(
+ mlir::Operation *op) const {
+ auto declareOp = mlir::cast<hlfir::DeclareOp>(op);
+ // If storage is present, return it (partial view case)
+ if (mlir::Value storage = declareOp.getStorage())
+ return storage;
+ // Otherwise return the memref (complete view case)
+ return declareOp.getMemref();
+}
+
+bool PartialEntityAccessModel<hlfir::DeclareOp>::isCompleteView(
+ mlir::Operation *op) const {
+ // Complete view if storage is absent
+ return !mlir::cast<hlfir::DeclareOp>(op).getStorage();
+}
+
+mlir::SymbolRefAttr AddressOfGlobalModel::getSymbol(mlir::Operation *op) const {
+ return mlir::cast<fir::AddrOfOp>(op).getSymbolAttr();
+}
+
+bool GlobalVariableModel::isConstant(mlir::Operation *op) const {
+ auto globalOp = mlir::cast<fir::GlobalOp>(op);
+ return globalOp.getConstant().has_value();
+}
+
+mlir::Region *GlobalVariableModel::getInitRegion(mlir::Operation *op) const {
+ auto globalOp = mlir::cast<fir::GlobalOp>(op);
+ return globalOp.hasInitializationBody() ? &globalOp.getRegion() : nullptr;
+}
+
+bool GlobalVariableModel::isDeviceData(mlir::Operation *op) const {
+ if (auto dataAttr = cuf::getDataAttr(op))
+ return cuf::isDeviceDataAttribute(dataAttr.getValue());
+ return false;
+}
+
+// Helper to recursively process address-of operations in derived type
+// descriptors and collect all needed fir.globals.
+static void processAddrOfOpInDerivedTypeDescriptor(
+ fir::AddrOfOp addrOfOp, mlir::SymbolTable &symTab,
+ llvm::SmallSet<mlir::Operation *, 16> &globalsSet,
+ llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols) {
+ if (auto globalOp = symTab.lookup<fir::GlobalOp>(
+ addrOfOp.getSymbol().getLeafReference().getValue())) {
+ if (globalsSet.contains(globalOp))
+ return;
+ globalsSet.insert(globalOp);
+ symbols.push_back(addrOfOp.getSymbolAttr());
+ globalOp.walk([&](fir::AddrOfOp op) {
+ processAddrOfOpInDerivedTypeDescriptor(op, symTab, globalsSet, symbols);
+ });
+ }
+}
+
+// Utility to collect referenced symbols for type descriptors of derived types.
+// This is the common logic for operations that may require type descriptor
+// globals.
+static void collectReferencedSymbolsForType(
+ mlir::Type ty, mlir::Operation *op,
+ llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) {
+ ty = fir::getDerivedType(fir::unwrapRefType(ty));
+
+ // Look for type descriptor globals only if it's a derived (record) type
+ if (auto recTy = mlir::dyn_cast_if_present<fir::RecordType>(ty)) {
+ // If no symbol table provided, simply add the type descriptor name
+ if (!symbolTable) {
+ symbols.push_back(mlir::SymbolRefAttr::get(
+ op->getContext(),
+ fir::NameUniquer::getTypeDescriptorName(recTy.getName())));
+ return;
+ }
+
+ // Otherwise, do full lookup and recursive processing
+ llvm::SmallSet<mlir::Operation *, 16> globalsSet;
+
+ fir::GlobalOp globalOp = symbolTable->lookup<fir::GlobalOp>(
+ fir::NameUniquer::getTypeDescriptorName(recTy.getName()));
+ if (!globalOp)
+ globalOp = symbolTable->lookup<fir::GlobalOp>(
+ fir::NameUniquer::getTypeDescriptorAssemblyName(recTy.getName()));
+
+ if (globalOp) {
+ globalsSet.insert(globalOp);
+ symbols.push_back(
+ mlir::SymbolRefAttr::get(op->getContext(), globalOp.getSymName()));
+ globalOp.walk([&](fir::AddrOfOp addrOp) {
+ processAddrOfOpInDerivedTypeDescriptor(addrOp, *symbolTable, globalsSet,
+ symbols);
+ });
+ }
+ }
+}
+
+template <>
+void IndirectGlobalAccessModel<fir::AllocaOp>::getReferencedSymbols(
+ mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) const {
+ auto allocaOp = mlir::cast<fir::AllocaOp>(op);
+ collectReferencedSymbolsForType(allocaOp.getType(), op, symbols, symbolTable);
+}
+
+template <>
+void IndirectGlobalAccessModel<fir::EmboxOp>::getReferencedSymbols(
+ mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) const {
+ auto emboxOp = mlir::cast<fir::EmboxOp>(op);
+ collectReferencedSymbolsForType(emboxOp.getMemref().getType(), op, symbols,
+ symbolTable);
+}
+
+template <>
+void IndirectGlobalAccessModel<fir::ReboxOp>::getReferencedSymbols(
+ mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) const {
+ auto reboxOp = mlir::cast<fir::ReboxOp>(op);
+ collectReferencedSymbolsForType(reboxOp.getBox().getType(), op, symbols,
+ symbolTable);
+}
+
+template <>
+void IndirectGlobalAccessModel<fir::TypeDescOp>::getReferencedSymbols(
+ mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) const {
+ auto typeDescOp = mlir::cast<fir::TypeDescOp>(op);
+ collectReferencedSymbolsForType(typeDescOp.getInType(), op, symbols,
+ symbolTable);
+}
+
+template <>
+bool OperationMoveModel<mlir::acc::LoopOp>::canMoveFromDescendant(
+ mlir::Operation *op, mlir::Operation *descendant,
+ mlir::Operation *candidate) const {
+  // It should always be allowed to move operations from descendants
+  // of acc.loop into the acc.loop.
+ return true;
+}
+
+template <>
+bool OperationMoveModel<mlir::acc::LoopOp>::canMoveOutOf(
+ mlir::Operation *op, mlir::Operation *candidate) const {
+ // Disallow moving operations, which have operands that are referenced
+ // in the data operands (e.g. in [first]private() etc.) of the acc.loop.
+ // For example:
+ // %17 = acc.private var(%16 : !fir.box<!fir.array<?xf32>>)
+ // acc.loop private(%17 : !fir.box<!fir.array<?xf32>>) ... {
+ // %19 = fir.box_addr %17
+ // }
+ // We cannot hoist %19 without violating assumptions that OpenACC
+ // transformations rely on.
+
+ // In general, some movement out of acc.loop is allowed,
+ // so return true if candidate is nullptr.
+ if (!candidate)
+ return true;
+
+ auto loopOp = mlir::cast<mlir::acc::LoopOp>(op);
+ unsigned numDataOperands = loopOp.getNumDataOperands();
+ for (unsigned i = 0; i < numDataOperands; ++i) {
+ mlir::Value dataOperand = loopOp.getDataOperand(i);
+ if (llvm::any_of(candidate->getOperands(),
+ [&](mlir::Value candidateOperand) {
+ return dataOperand == candidateOperand;
+ }))
+ return false;
+ }
+ return true;
+}
+
+} // namespace fir::acc
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
index ed9e41c..9ced235 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
@@ -15,12 +15,15 @@
#include "flang/Optimizer/Builder/DirectivesCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
+#include "flang/Optimizer/Builder/IntrinsicCall.h"
+#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
#include "flang/Optimizer/Dialect/FIRCG/CGOps.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/OpenACC/Support/FIROpenACCUtils.h"
#include "flang/Optimizer/Support/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -193,6 +196,28 @@ OpenACCMappableModel<fir::PointerType>::getOffsetInBytes(
mlir::Type type, mlir::Value var, mlir::ValueRange accBounds,
const mlir::DataLayout &dataLayout) const;
+template <typename Ty>
+bool OpenACCMappableModel<Ty>::hasUnknownDimensions(mlir::Type type) const {
+ assert(fir::isa_ref_type(type) && "expected FIR reference type");
+ return fir::hasDynamicSize(fir::unwrapRefType(type));
+}
+
+template bool OpenACCMappableModel<fir::ReferenceType>::hasUnknownDimensions(
+ mlir::Type type) const;
+
+template bool OpenACCMappableModel<fir::HeapType>::hasUnknownDimensions(
+ mlir::Type type) const;
+
+template bool OpenACCMappableModel<fir::PointerType>::hasUnknownDimensions(
+ mlir::Type type) const;
+
+template <>
+bool OpenACCMappableModel<fir::BaseBoxType>::hasUnknownDimensions(
+ mlir::Type type) const {
+ // Descriptor-based entities have dimensions encoded.
+ return false;
+}
+
static llvm::SmallVector<mlir::Value>
generateSeqTyAccBounds(fir::SequenceType seqType, mlir::Value var,
mlir::OpBuilder &builder) {
@@ -202,48 +227,53 @@ generateSeqTyAccBounds(fir::SequenceType seqType, mlir::Value var,
fir::FirOpBuilder firBuilder(builder, var.getDefiningOp());
mlir::Location loc = var.getLoc();
- if (seqType.hasDynamicExtents() || seqType.hasUnknownShape()) {
- if (auto boxAddr =
- mlir::dyn_cast_if_present<fir::BoxAddrOp>(var.getDefiningOp())) {
- mlir::Value box = boxAddr.getVal();
- auto res =
- hlfir::translateToExtendedValue(loc, firBuilder, hlfir::Entity(box));
- fir::ExtendedValue exv = res.first;
- mlir::Value boxRef = box;
- if (auto boxPtr = mlir::cast<mlir::acc::MappableType>(box.getType())
- .getVarPtr(box)) {
- boxRef = boxPtr;
+ // If [hl]fir.declare is visible, extract the bounds from the declaration's
+ // shape (if it is provided).
+ if (mlir::isa<hlfir::DeclareOp, fir::DeclareOp>(var.getDefiningOp())) {
+ mlir::Value zero =
+ firBuilder.createIntegerConstant(loc, builder.getIndexType(), 0);
+ mlir::Value one =
+ firBuilder.createIntegerConstant(loc, builder.getIndexType(), 1);
+
+ mlir::Value shape;
+ if (auto declareOp =
+ mlir::dyn_cast_if_present<fir::DeclareOp>(var.getDefiningOp()))
+ shape = declareOp.getShape();
+ else if (auto declareOp = mlir::dyn_cast_if_present<hlfir::DeclareOp>(
+ var.getDefiningOp()))
+ shape = declareOp.getShape();
+
+ const bool strideIncludeLowerExtent = true;
+
+ llvm::SmallVector<mlir::Value> accBounds;
+ mlir::Operation *anyShapeOp = shape ? shape.getDefiningOp() : nullptr;
+ if (auto shapeOp = mlir::dyn_cast_if_present<fir::ShapeOp>(anyShapeOp)) {
+ mlir::Value cummulativeExtent = one;
+ for (auto extent : shapeOp.getExtents()) {
+ mlir::Value upperbound =
+ mlir::arith::SubIOp::create(builder, loc, extent, one);
+ mlir::Value stride = one;
+ if (strideIncludeLowerExtent) {
+ stride = cummulativeExtent;
+ cummulativeExtent = mlir::arith::MulIOp::create(
+ builder, loc, cummulativeExtent, extent);
+ }
+ auto accBound = mlir::acc::DataBoundsOp::create(
+ builder, loc, mlir::acc::DataBoundsType::get(builder.getContext()),
+ /*lowerbound=*/zero, /*upperbound=*/upperbound,
+ /*extent=*/extent, /*stride=*/stride, /*strideInBytes=*/false,
+ /*startIdx=*/one);
+ accBounds.push_back(accBound);
}
- // TODO: Handle Fortran optional.
- const mlir::Value isPresent;
- fir::factory::AddrAndBoundsInfo info(box, boxRef, isPresent,
- box.getType());
- return fir::factory::genBoundsOpsFromBox<mlir::acc::DataBoundsOp,
- mlir::acc::DataBoundsType>(
- firBuilder, loc, exv, info);
- }
-
- if (mlir::isa<hlfir::DeclareOp, fir::DeclareOp>(var.getDefiningOp())) {
- mlir::Value zero =
- firBuilder.createIntegerConstant(loc, builder.getIndexType(), 0);
- mlir::Value one =
- firBuilder.createIntegerConstant(loc, builder.getIndexType(), 1);
-
- mlir::Value shape;
- if (auto declareOp =
- mlir::dyn_cast_if_present<fir::DeclareOp>(var.getDefiningOp()))
- shape = declareOp.getShape();
- else if (auto declareOp = mlir::dyn_cast_if_present<hlfir::DeclareOp>(
- var.getDefiningOp()))
- shape = declareOp.getShape();
-
- const bool strideIncludeLowerExtent = true;
-
- llvm::SmallVector<mlir::Value> accBounds;
- if (auto shapeOp =
- mlir::dyn_cast_if_present<fir::ShapeOp>(shape.getDefiningOp())) {
- mlir::Value cummulativeExtent = one;
- for (auto extent : shapeOp.getExtents()) {
+ } else if (auto shapeShiftOp =
+ mlir::dyn_cast_if_present<fir::ShapeShiftOp>(anyShapeOp)) {
+ mlir::Value lowerbound;
+ mlir::Value cummulativeExtent = one;
+ for (auto [idx, val] : llvm::enumerate(shapeShiftOp.getPairs())) {
+ if (idx % 2 == 0) {
+ lowerbound = val;
+ } else {
+ mlir::Value extent = val;
mlir::Value upperbound =
mlir::arith::SubIOp::create(builder, loc, extent, one);
mlir::Value stride = one;
@@ -257,40 +287,48 @@ generateSeqTyAccBounds(fir::SequenceType seqType, mlir::Value var,
mlir::acc::DataBoundsType::get(builder.getContext()),
/*lowerbound=*/zero, /*upperbound=*/upperbound,
/*extent=*/extent, /*stride=*/stride, /*strideInBytes=*/false,
- /*startIdx=*/one);
+ /*startIdx=*/lowerbound);
accBounds.push_back(accBound);
}
- } else if (auto shapeShiftOp =
- mlir::dyn_cast_if_present<fir::ShapeShiftOp>(
- shape.getDefiningOp())) {
- mlir::Value lowerbound;
- mlir::Value cummulativeExtent = one;
- for (auto [idx, val] : llvm::enumerate(shapeShiftOp.getPairs())) {
- if (idx % 2 == 0) {
- lowerbound = val;
- } else {
- mlir::Value extent = val;
- mlir::Value upperbound =
- mlir::arith::SubIOp::create(builder, loc, extent, one);
- mlir::Value stride = one;
- if (strideIncludeLowerExtent) {
- stride = cummulativeExtent;
- cummulativeExtent = mlir::arith::MulIOp::create(
- builder, loc, cummulativeExtent, extent);
- }
- auto accBound = mlir::acc::DataBoundsOp::create(
- builder, loc,
- mlir::acc::DataBoundsType::get(builder.getContext()),
- /*lowerbound=*/zero, /*upperbound=*/upperbound,
- /*extent=*/extent, /*stride=*/stride, /*strideInBytes=*/false,
- /*startIdx=*/lowerbound);
- accBounds.push_back(accBound);
- }
- }
}
+ }
+
+ if (!accBounds.empty())
+ return accBounds;
+ }
+
+ if (seqType.hasDynamicExtents() || seqType.hasUnknownShape()) {
+ mlir::Value box;
+ bool mayBeOptional = false;
+ if (auto boxAddr =
+ mlir::dyn_cast_if_present<fir::BoxAddrOp>(var.getDefiningOp())) {
+ box = boxAddr.getVal();
+      // Since fir.box_addr already accesses the box, we do not bother
+      // checking whether it is optional.
+ } else if (mlir::isa<fir::BaseBoxType>(var.getType())) {
+ box = var;
+ mayBeOptional = fir::mayBeAbsentBox(box);
+ }
+
+ if (box) {
+ auto res =
+ hlfir::translateToExtendedValue(loc, firBuilder, hlfir::Entity(box));
+ fir::ExtendedValue exv = res.first;
+ mlir::Value boxRef = box;
+ if (auto boxPtr =
+ mlir::cast<mlir::acc::MappableType>(box.getType()).getVarPtr(box))
+ boxRef = boxPtr;
+
+ mlir::Value isPresent =
+ !mayBeOptional ? mlir::Value{}
+ : fir::IsPresentOp::create(builder, loc,
+ builder.getI1Type(), box);
- if (!accBounds.empty())
- return accBounds;
+ fir::factory::AddrAndBoundsInfo info(box, boxRef, isPresent,
+ box.getType());
+ return fir::factory::genBoundsOpsFromBox<mlir::acc::DataBoundsOp,
+ mlir::acc::DataBoundsType>(
+ firBuilder, loc, exv, info);
}
assert(false && "array with unknown dimension expected to have descriptor");
@@ -353,7 +391,7 @@ getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
// calculation op.
mlir::Value baseRef =
llvm::TypeSwitch<mlir::Operation *, mlir::Value>(op)
- .Case<fir::DeclareOp>([&](auto op) {
+ .Case([&](fir::DeclareOp op) {
// If this declare binds a view with an underlying storage operand,
// treat that storage as the base reference. Otherwise, fall back
// to the declared memref.
@@ -361,7 +399,7 @@ getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
return storage;
return mlir::Value(varPtr);
})
- .Case<hlfir::DesignateOp>([&](auto op) {
+ .Case([&](hlfir::DesignateOp op) {
// Get the base object.
return op.getMemref();
})
@@ -369,12 +407,12 @@ getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
// Get the base array on which the coordinate is being applied.
return op.getMemref();
})
- .Case<fir::CoordinateOp>([&](auto op) {
+ .Case([&](fir::CoordinateOp op) {
// For coordinate operation which is applied on derived type
// object, get the base object.
return op.getRef();
})
- .Case<fir::ConvertOp>([&](auto op) -> mlir::Value {
+ .Case([&](fir::ConvertOp op) -> mlir::Value {
// Strip the conversion and recursively check the operand
if (auto ptrLikeOperand = mlir::dyn_cast_if_present<
mlir::TypedValue<mlir::acc::PointerLikeType>>(
@@ -543,30 +581,141 @@ OpenACCPointerLikeModel<fir::LLVMPointerType>::getPointeeTypeCategory(
return categorizePointee(pointer, varPtr, varType);
}
-static fir::ShapeOp genShapeOp(mlir::OpBuilder &builder,
- fir::SequenceType seqTy, mlir::Location loc) {
+static hlfir::Entity
+genDesignateWithTriplets(fir::FirOpBuilder &builder, mlir::Location loc,
+ hlfir::Entity &entity,
+ hlfir::DesignateOp::Subscripts &triplets,
+ mlir::Value shape, mlir::ValueRange extents) { // creates an hlfir.designate for the section entity(triplets)
+ llvm::SmallVector<mlir::Value> lenParams;
+ hlfir::genLengthParameters(loc, builder, entity, lenParams);
+
+ // Compute the result type of the array section (constant extents when known).
+ fir::SequenceType::Shape resultTypeShape;
+ bool shapeIsConstant = true;
+ for (mlir::Value extent : extents) {
+ if (std::optional<std::int64_t> cst_extent =
+ fir::getIntIfConstant(extent)) {
+ resultTypeShape.push_back(*cst_extent);
+ } else {
+ resultTypeShape.push_back(fir::SequenceType::getUnknownExtent());
+ shapeIsConstant = false;
+ }
+ }
+ assert(!resultTypeShape.empty() &&
+ "expect private sections to always represented as arrays");
+ mlir::Type eleTy = entity.getFortranElementType();
+ auto seqTy = fir::SequenceType::get(resultTypeShape, eleTy);
+ bool isVolatile = fir::isa_volatile_type(entity.getType());
+ bool resultNeedsBox = // a boxed base or a dynamic shape requires a descriptor result
+ llvm::isa<fir::BaseBoxType>(entity.getType()) || !shapeIsConstant;
+ bool isPolymorphic = fir::isPolymorphicType(entity.getType());
+ mlir::Type resultType;
+ if (isPolymorphic) {
+ resultType = fir::ClassType::get(seqTy, isVolatile);
+ } else if (resultNeedsBox) {
+ resultType = fir::BoxType::get(seqTy, isVolatile);
+ } else {
+ resultType = fir::ReferenceType::get(seqTy, isVolatile);
+ }
+
+ // Generate section with hlfir.designate.
+ auto designate = hlfir::DesignateOp::create(
+ builder, loc, resultType, entity, /*component=*/"",
+ /*componentShape=*/mlir::Value{}, triplets,
+ /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt, shape,
+ lenParams);
+ return hlfir::Entity{designate.getResult()};
+}
+
+// Designate uses triplets based on object lower bounds while acc.bounds are
+// zero based. This helper shifts the bounds to create the designate triplets.
+static hlfir::DesignateOp::Subscripts
+genTripletsFromAccBounds(fir::FirOpBuilder &builder, mlir::Location loc,
+ const llvm::SmallVector<mlir::Value> &accBounds,
+ hlfir::Entity entity) {
+ assert(entity.getRank() * 3 == static_cast<int>(accBounds.size()) &&
+ "must get lb,ub,step for each dimension")
+ hlfir::DesignateOp::Subscripts triplets;
+ for (unsigned i = 0; i < accBounds.size(); i += 3) {
+ mlir::Value lb = hlfir::genLBound(loc, builder, entity, i / 3); // lower bound of dimension i/3 of the base entity
+ lb = builder.createConvert(loc, accBounds[i].getType(), lb);
+ assert(accBounds[i].getType() == accBounds[i + 1].getType() &&
+ "mix of integer types in triplets");
+ mlir::Value sliceLB =
+ builder.createOrFold<mlir::arith::AddIOp>(loc, accBounds[i], lb);
+ mlir::Value sliceUB =
+ builder.createOrFold<mlir::arith::AddIOp>(loc, accBounds[i + 1], lb);
+ triplets.emplace_back(
+ hlfir::DesignateOp::Triplet{sliceLB, sliceUB, accBounds[i + 2]});
+ }
+ return triplets;
+}
+
+static std::pair<mlir::Value, llvm::SmallVector<mlir::Value>>
+computeSectionShapeAndExtents(fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::ValueRange bounds) {
llvm::SmallVector<mlir::Value> extents;
+ // Compute the fir.shape of the array section together with the individual
+ // extent values (one extent per lb/ub/step triplet in the acc bounds).
mlir::Type idxTy = builder.getIndexType();
- for (auto extent : seqTy.getShape())
- extents.push_back(mlir::arith::ConstantOp::create(
- builder, loc, idxTy, builder.getIntegerAttr(idxTy, extent)));
- return fir::ShapeOp::create(builder, loc, extents);
+ for (unsigned i = 0; i + 2 < bounds.size(); i += 3)
+ extents.push_back(builder.genExtentFromTriplet(
+ loc, bounds[i], bounds[i + 1], bounds[i + 2], idxTy, /*fold=*/true));
+ mlir::Value shape = fir::ShapeOp::create(builder, loc, extents);
+ return {shape, extents};
+}
+
+static std::pair<hlfir::Entity, hlfir::Entity>
+genArraySectionsInRecipe(fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::ValueRange bounds, hlfir::Entity lhs,
+ hlfir::Entity rhs) { // returns {lhs section, rhs section}
+ assert(lhs.getRank() * 3 == static_cast<int>(bounds.size()) &&
+ "must get lb,ub,step for each dimension");
+ lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs);
+ rhs = hlfir::derefPointersAndAllocatables(loc, builder, rhs);
+ // Get the list of lb,ub,step values for the sections that can be used inside
+ // the recipe region.
+ auto [shape, extents] = computeSectionShapeAndExtents(builder, loc, bounds);
+ hlfir::DesignateOp::Subscripts rhsTriplets =
+ genTripletsFromAccBounds(builder, loc, bounds, rhs);
+ hlfir::DesignateOp::Subscripts lhsTriplets;
+ // Share the bounds when both rhs/lhs are known to be 1-based to avoid noise
+ // in the IR for the most common cases.
+ if (!lhs.mayHaveNonDefaultLowerBounds() &&
+ !rhs.mayHaveNonDefaultLowerBounds())
+ lhsTriplets = rhsTriplets; // both sides 1-based: identical triplets
+ else
+ lhsTriplets = genTripletsFromAccBounds(builder, loc, bounds, lhs);
+ hlfir::Entity leftSection =
+ genDesignateWithTriplets(builder, loc, lhs, lhsTriplets, shape, extents);
+ hlfir::Entity rightSection =
+ genDesignateWithTriplets(builder, loc, rhs, rhsTriplets, shape, extents);
+ return {leftSection, rightSection};
+}
+
+static bool boundsAreAllConstants(mlir::ValueRange bounds) { // true iff every lb/ub/step value is a compile-time constant
+ for (mlir::Value bound : bounds)
+ if (!fir::getIntIfConstant(bound).has_value())
+ return false;
+ return true;
+}
template <typename Ty>
mlir::Value OpenACCMappableModel<Ty>::generatePrivateInit(
- mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::Type type, mlir::OpBuilder &mlirBuilder, mlir::Location loc,
mlir::TypedValue<mlir::acc::MappableType> var, llvm::StringRef varName,
- mlir::ValueRange extents, mlir::Value initVal, bool &needsDestroy) const {
- needsDestroy = false;
- mlir::Value retVal;
- mlir::Type unwrappedTy = fir::unwrapRefType(type);
- mlir::ModuleOp mod = builder.getInsertionBlock()
+ mlir::ValueRange bounds, mlir::Value initVal, bool &needsDestroy) const {
+ mlir::ModuleOp mod = mlirBuilder.getInsertionBlock()
->getParent()
->getParentOfType<mlir::ModuleOp>();
-
- if (auto recType = llvm::dyn_cast<fir::RecordType>(
- fir::getFortranElementType(unwrappedTy))) {
+ assert(mod && "failed to retrieve ModuleOp");
+ fir::FirOpBuilder builder(mlirBuilder, mod);
+
+ hlfir::Entity inputVar = hlfir::Entity{var};
+ if (inputVar.isPolymorphic())
+ TODO(loc, "OpenACC: polymorphic variable privatization");
+ if (auto recType =
+ llvm::dyn_cast<fir::RecordType>(inputVar.getFortranElementType())) {
// Need to make deep copies of allocatable components.
if (fir::isRecordWithAllocatableMember(recType))
TODO(loc,
@@ -575,117 +724,161 @@ mlir::Value OpenACCMappableModel<Ty>::generatePrivateInit(
if (fir::isRecordWithFinalRoutine(recType, mod).value_or(false))
TODO(loc, "OpenACC: privatizing derived type with user assignment or "
"final routine ");
+ // Pointer components need to be initialized to NULL() for private-like
+ // recipes.
+ if (fir::isRecordWithDescriptorMember(recType))
+ TODO(loc, "OpenACC: privatizing derived type with pointer components");
+ }
+ bool isPointerOrAllocatable = inputVar.isMutableBox();
+ hlfir::Entity dereferencedVar =
+ hlfir::derefPointersAndAllocatables(loc, builder, inputVar);
+
+ // Step 1: Gather the address, shape, extents, and lengths parameters of the
+ // entity being privatized. Designate the array section if only a section is
+ // privatized, otherwise just use the original variable.
+ hlfir::Entity privatizedVar = dereferencedVar;
+ mlir::Value tempShape;
+ llvm::SmallVector<mlir::Value> tempExtents;
+ // TODO: while it seems best to allocate as little memory as possible and
+ // allocate only the storage for the section, this may actually have drawbacks
+ // when the array has static size and can be privatized with an alloca while
+ // the section size is dynamic and requires a dynamic allocmem. Hence, we
+ // currently allocate the full array storage in such cases. This could be
+ // improved via some kind of threshold if the base array size is large enough
+ // to justify doing a dynamic allocation with the hope that it is much
+ // smaller.
+ bool allocateSection = false;
+ bool isDynamicSectionOfStaticSizeArray =
+ !bounds.empty() &&
+ !fir::hasDynamicSize(dereferencedVar.getElementOrSequenceType()) &&
+ !boundsAreAllConstants(bounds);
+ if (!bounds.empty() && !isDynamicSectionOfStaticSizeArray) {
+ allocateSection = true;
+ hlfir::DesignateOp::Subscripts triplets;
+ std::tie(tempShape, tempExtents) =
+ computeSectionShapeAndExtents(builder, loc, bounds);
+ triplets = genTripletsFromAccBounds(builder, loc, bounds, dereferencedVar);
+ privatizedVar = genDesignateWithTriplets(builder, loc, dereferencedVar,
+ triplets, tempShape, tempExtents);
+ } else if (privatizedVar.getRank() > 0) {
+ mlir::Value shape = hlfir::genShape(loc, builder, privatizedVar);
+ tempExtents = hlfir::getExplicitExtentsFromShape(shape, builder);
+ tempShape = fir::ShapeOp::create(builder, loc, tempExtents);
+ }
+ llvm::SmallVector<mlir::Value> typeParams;
+ hlfir::genLengthParameters(loc, builder, privatizedVar, typeParams);
+ mlir::Type baseType = privatizedVar.getElementOrSequenceType();
+ // Step 2: Create a temporary allocation for the privatized part.
+ mlir::Value alloc;
+ if (fir::hasDynamicSize(baseType) ||
+ (isPointerOrAllocatable && bounds.empty())) {
+ // Note: heap allocation is forced for whole pointers/allocatable so that
+ // the private POINTER/ALLOCATABLE can be deallocated/reallocated on the
+ // device inside the compute region. It may not be a requirement, and this
+ // could be revisited. In practice, this only matters for scalars since
+ // array POINTER and ALLOCATABLE always have dynamic size. Constant sections
+ // of POINTER/ALLOCATABLE can use alloca since only part of the data is
+ // privatized (it makes no sense to deallocate them).
+ alloc = builder.createHeapTemporary(loc, baseType, varName, tempExtents,
+ typeParams);
+ needsDestroy = true; // caller must later emit generatePrivateDestroy
+ } else {
+ alloc = builder.createTemporary(loc, baseType, varName, tempExtents,
+ typeParams);
+ }
+ // Step 3: Assign the initial value to the privatized part if any.
+ if (initVal) {
+ mlir::Value tempEntity = alloc;
+ if (fir::hasDynamicSize(baseType))
+ tempEntity =
+ fir::EmboxOp::create(builder, loc, fir::BoxType::get(baseType), alloc,
+ tempShape, /*slice=*/mlir::Value{}, typeParams);
+ hlfir::genNoAliasAssignment(
+ loc, builder, hlfir::Entity{initVal}, hlfir::Entity{tempEntity},
+ /*emitWorkshareLoop=*/false, /*temporaryLHS=*/true);
}
- fir::FirOpBuilder firBuilder(builder, mod);
- auto getDeclareOpForType = [&](mlir::Type ty) -> hlfir::DeclareOp {
- auto alloca = fir::AllocaOp::create(firBuilder, loc, ty);
- return hlfir::DeclareOp::create(firBuilder, loc, alloca, varName);
- };
+ // Making a dynamic allocation of the size of the whole base instead of the
+ // section in case of section would lead to improper deallocation because
+ // generatePrivateDestroy always deallocates the start of the section when
+ // there is a section.
+ assert(!(needsDestroy && !bounds.empty() && !allocateSection) &&
+ "dynamic allocation of the whole base in case of section is not "
+ "expected");
+
+ if (inputVar.getType() == alloc.getType() && !allocateSection) // temp already has the expected type
+ return alloc;
+
+ // Step 4: reconstruct the input variable from the privatized part:
+ // - get a mock base address if the privatized part is a section (so that any
+ // addressing of the input variable can be replaced by the same addressing of
+ // the privatized part even though the allocated part for the private does not
+ // cover all the input variable storage. This is relying on OpenACC
+ // constraint that any addressing of such privatized variable inside the
+ // construct region can only address the variable inside the privatized
+ // section).
+ // - reconstruct a descriptor with the same bounds and type parameters as the
+ // input if needed.
+ // - store this new descriptor in a temporary allocation if the input variable
+ // is a POINTER/ALLOCATABLE.
+ llvm::SmallVector<mlir::Value> inputVarLowerBounds, inputVarExtents;
+ if (dereferencedVar.isArray()) {
+ for (int dim = 0; dim < dereferencedVar.getRank(); ++dim) {
+ inputVarLowerBounds.push_back(
+ hlfir::genLBound(loc, builder, dereferencedVar, dim));
+ inputVarExtents.push_back(
+ hlfir::genExtent(loc, builder, dereferencedVar, dim));
+ }
+ }
- if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(unwrappedTy)) {
- if (fir::isa_trivial(seqTy.getEleTy())) {
- mlir::Value shape;
- if (seqTy.hasDynamicExtents()) {
- shape = fir::ShapeOp::create(firBuilder, loc, llvm::to_vector(extents));
- } else {
- shape = genShapeOp(firBuilder, seqTy, loc);
- }
- auto alloca = fir::AllocaOp::create(
- firBuilder, loc, seqTy, /*typeparams=*/mlir::ValueRange{}, extents);
- auto declareOp =
- hlfir::DeclareOp::create(firBuilder, loc, alloca, varName, shape);
-
- if (initVal) {
- mlir::Type idxTy = firBuilder.getIndexType();
- mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
- llvm::SmallVector<fir::DoLoopOp> loops;
- llvm::SmallVector<mlir::Value> ivs;
-
- if (seqTy.hasDynamicExtents()) {
- hlfir::AssignOp::create(firBuilder, loc, initVal,
- declareOp.getBase());
- } else {
- // Generate loop nest from slowest to fastest running dimension
- for (auto ext : llvm::reverse(seqTy.getShape())) {
- auto lb = firBuilder.createIntegerConstant(loc, idxTy, 0);
- auto ub = firBuilder.createIntegerConstant(loc, idxTy, ext - 1);
- auto step = firBuilder.createIntegerConstant(loc, idxTy, 1);
- auto loop = fir::DoLoopOp::create(firBuilder, loc, lb, ub, step,
- /*unordered=*/false);
- firBuilder.setInsertionPointToStart(loop.getBody());
- loops.push_back(loop);
- ivs.push_back(loop.getInductionVar());
- }
- // Reverse IVs to match CoordinateOp's canonical index order.
- std::reverse(ivs.begin(), ivs.end());
- auto coord = fir::CoordinateOp::create(firBuilder, loc, refTy,
- declareOp.getBase(), ivs);
- fir::StoreOp::create(firBuilder, loc, initVal, coord);
- firBuilder.setInsertionPointAfter(loops[0]);
- }
- }
- retVal = declareOp.getBase();
+ mlir::Value privateVarBaseAddr = alloc;
+ if (allocateSection) {
+ // To compute the mock base address without doing pointer arithmetic,
+ // compute: TYPE, TEMP(ZERO_BASED_SECTION_LB:) MOCK_BASE = TEMP(0)
+ // This addresses the section "backwards" (0 <= ZERO_BASED_SECTION_LB). This
+ // is currently OK, but care should be taken to avoid tripping bound checks
+ // if added in the future.
+ mlir::Type inputBaseAddrType =
+ dereferencedVar.getBoxType().getBaseAddressType();
+ mlir::Value tempBaseAddr =
+ builder.createConvert(loc, inputBaseAddrType, alloc);
+ mlir::Value zero =
+ builder.createIntegerConstant(loc, builder.getIndexType(), 0);
+ llvm::SmallVector<mlir::Value> lowerBounds;
+ llvm::SmallVector<mlir::Value> zeros;
+ for (unsigned i = 0; i < bounds.size(); i += 3) {
+ lowerBounds.push_back(bounds[i]);
+ zeros.push_back(zero);
}
- } else if (auto boxTy =
- mlir::dyn_cast_or_null<fir::BaseBoxType>(unwrappedTy)) {
- mlir::Type innerTy = fir::unwrapRefType(boxTy.getEleTy());
- if (fir::isa_trivial(innerTy)) {
- retVal = getDeclareOpForType(unwrappedTy).getBase();
- mlir::Value allocatedScalar =
- fir::AllocMemOp::create(builder, loc, innerTy);
- mlir::Value firClass =
- fir::EmboxOp::create(builder, loc, boxTy, allocatedScalar);
- fir::StoreOp::create(builder, loc, firClass, retVal);
- needsDestroy = true;
- } else if (mlir::isa<fir::SequenceType>(innerTy)) {
- hlfir::Entity source = hlfir::Entity{var};
- auto [temp, cleanupFlag] =
- hlfir::createTempFromMold(loc, firBuilder, source);
- if (fir::isa_ref_type(type)) {
- // When the temp is created - it is not a reference - thus we can
- // end up with a type inconsistency. Therefore ensure storage is created
- // for it.
- retVal = getDeclareOpForType(unwrappedTy).getBase();
- mlir::Value storeDst = retVal;
- if (fir::unwrapRefType(retVal.getType()) != temp.getType()) {
- // `createTempFromMold` makes the unfortunate choice to lose the
- // `fir.heap` and `fir.ptr` types when wrapping with a box. Namely,
- // when wrapping a `fir.heap<fir.array>`, it will create instead a
- // `fir.box<fir.array>`. Cast here to deal with this inconsistency.
- storeDst = firBuilder.createConvert(
- loc, firBuilder.getRefType(temp.getType()), retVal);
- }
- fir::StoreOp::create(builder, loc, temp, storeDst);
- } else {
- retVal = temp;
- }
- // If heap was allocated, a destroy is required later.
- if (cleanupFlag)
- needsDestroy = true;
+ mlir::Value offsetShapeShift =
+ builder.genShape(loc, lowerBounds, inputVarExtents);
+ mlir::Type eleRefType =
+ builder.getRefType(privatizedVar.getFortranElementType());
+ mlir::Value mockBase = fir::ArrayCoorOp::create(
+ builder, loc, eleRefType, tempBaseAddr, offsetShapeShift,
+ /*slice=*/mlir::Value{}, /*indices=*/zeros,
+ /*typeParams=*/mlir::ValueRange{});
+ privateVarBaseAddr =
+ builder.createConvert(loc, inputBaseAddrType, mockBase);
+ }
+
+ mlir::Value retVal = privateVarBaseAddr;
+ if (inputVar.isBoxAddressOrValue()) {
+ // Recreate descriptor with same bounds as the input variable.
+ mlir::Value shape;
+ if (!inputVarExtents.empty())
+ shape = builder.genShape(loc, inputVarLowerBounds, inputVarExtents);
+ mlir::Value box = fir::EmboxOp::create(builder, loc, inputVar.getBoxType(),
+ privateVarBaseAddr, shape,
+ /*slice=*/mlir::Value{}, typeParams);
+ if (inputVar.isMutableBox()) {
+ mlir::Value boxAlloc =
+ fir::AllocaOp::create(builder, loc, inputVar.getBoxType());
+ fir::StoreOp::create(builder, loc, box, boxAlloc);
+ retVal = boxAlloc;
} else {
- TODO(loc, "Unsupported boxed type for OpenACC private-like recipe");
- }
- if (initVal) {
- hlfir::AssignOp::create(builder, loc, initVal, retVal);
+ retVal = box;
}
- } else if (llvm::isa<fir::BoxCharType, fir::CharacterType>(unwrappedTy)) {
- TODO(loc, "Character type for OpenACC private-like recipe");
- } else {
- assert((fir::isa_trivial(unwrappedTy) ||
- llvm::isa<fir::RecordType>(unwrappedTy)) &&
- "expected numerical, logical, and derived type without length "
- "parameters");
- auto declareOp = getDeclareOpForType(unwrappedTy);
- if (initVal && fir::isa_trivial(unwrappedTy)) {
- auto convert = firBuilder.createConvert(loc, unwrappedTy, initVal);
- fir::StoreOp::create(firBuilder, loc, convert, declareOp.getBase());
- } else if (initVal) {
- // hlfir.assign with temporary LHS flag should just do it. Not implemented
- // because not clear it is needed, so cannot be tested.
- TODO(loc, "initial value for derived type in private-like recipe");
- }
- retVal = declareOp.getBase();
}
return retVal;
}
@@ -714,42 +907,249 @@ OpenACCMappableModel<fir::PointerType>::generatePrivateInit(
mlir::ValueRange extents, mlir::Value initVal, bool &needsDestroy) const;
template <typename Ty>
+bool OpenACCMappableModel<Ty>::generateCopy(
+ mlir::Type type, mlir::OpBuilder &mlirBuilder, mlir::Location loc,
+ mlir::TypedValue<mlir::acc::MappableType> src,
+ mlir::TypedValue<mlir::acc::MappableType> dest,
+ mlir::ValueRange bounds) const {
+ mlir::ModuleOp mod =
+ mlirBuilder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
+ assert(mod && "failed to retrieve parent module");
+ fir::FirOpBuilder builder(mlirBuilder, mod);
+ hlfir::Entity source{src};
+ hlfir::Entity destination{dest};
+
+ source = hlfir::derefPointersAndAllocatables(loc, builder, source);
+ destination = hlfir::derefPointersAndAllocatables(loc, builder, destination);
+
+ if (!bounds.empty()) // only the designated sections are copied
+ std::tie(source, destination) =
+ genArraySectionsInRecipe(builder, loc, bounds, source, destination);
+ // The source and the destination of the firstprivate copy cannot alias,
+ // the destination is already properly allocated, so a simple assignment
+ // can be generated right away to avoid ending-up with runtime calls
+ // for arrays of numerical, logical and, character types.
+ //
+ // The temporary_lhs flag allows indicating that user defined assignments
+ // should not be called while copying components, and that the LHS and RHS
+ // are known to not alias since the LHS is a created object.
+ //
+ // TODO: detect cases where user defined assignment is needed and add a TODO.
+ // using temporary_lhs allows more aggressive optimizations of simple derived
+ // types. Existing compilers supporting OpenACC do not call user defined
+ // assignments, some use case is needed to decide what to do.
+ source = hlfir::loadTrivialScalar(loc, builder, source);
+ hlfir::AssignOp::create(builder, loc, source, destination, /*realloc=*/false,
+ /*keep_lhs_length_if_realloc=*/false,
+ /*temporary_lhs=*/true);
+ return true;
+}
+
+template bool OpenACCMappableModel<fir::BaseBoxType>::generateCopy( // explicit instantiations for all mappable FIR types
+ mlir::Type, mlir::OpBuilder &, mlir::Location,
+ mlir::TypedValue<mlir::acc::MappableType>,
+ mlir::TypedValue<mlir::acc::MappableType>, mlir::ValueRange) const;
+template bool OpenACCMappableModel<fir::ReferenceType>::generateCopy(
+ mlir::Type, mlir::OpBuilder &, mlir::Location,
+ mlir::TypedValue<mlir::acc::MappableType>,
+ mlir::TypedValue<mlir::acc::MappableType>, mlir::ValueRange) const;
+template bool OpenACCMappableModel<fir::PointerType>::generateCopy(
+ mlir::Type, mlir::OpBuilder &, mlir::Location,
+ mlir::TypedValue<mlir::acc::MappableType>,
+ mlir::TypedValue<mlir::acc::MappableType>, mlir::ValueRange) const;
+template bool OpenACCMappableModel<fir::HeapType>::generateCopy(
+ mlir::Type, mlir::OpBuilder &, mlir::Location,
+ mlir::TypedValue<mlir::acc::MappableType>,
+ mlir::TypedValue<mlir::acc::MappableType>, mlir::ValueRange) const;
+
+template <typename Op>
+static mlir::Value genLogicalCombiner(fir::FirOpBuilder &builder,
+ mlir::Location loc, mlir::Value value1,
+ mlir::Value value2) { // applies Op on i1 casts, then converts back to value1's type
+ mlir::Type i1 = builder.getI1Type();
+ mlir::Value v1 = fir::ConvertOp::create(builder, loc, i1, value1);
+ mlir::Value v2 = fir::ConvertOp::create(builder, loc, i1, value2);
+ mlir::Value combined = Op::create(builder, loc, v1, v2);
+ return fir::ConvertOp::create(builder, loc, value1.getType(), combined);
+}
+
+static mlir::Value genComparisonCombiner(fir::FirOpBuilder &builder,
+ mlir::Location loc,
+ mlir::arith::CmpIPredicate pred,
+ mlir::Value value1,
+ mlir::Value value2) { // compares the i1 casts (eq/ne used for EQV/NEQV), converts back to value1's type
+ mlir::Type i1 = builder.getI1Type();
+ mlir::Value v1 = fir::ConvertOp::create(builder, loc, i1, value1);
+ mlir::Value v2 = fir::ConvertOp::create(builder, loc, i1, value2);
+ mlir::Value add = mlir::arith::CmpIOp::create(builder, loc, pred, v1, v2); // NOTE(review): 'add' holds the comparison result; name is misleading
+ return fir::ConvertOp::create(builder, loc, value1.getType(), add);
+}
+
+static mlir::Value genScalarCombiner(fir::FirOpBuilder &builder,
+ mlir::Location loc,
+ mlir::acc::ReductionOperator op,
+ mlir::Type ty, mlir::Value value1,
+ mlir::Value value2) { // combines two scalar values per the acc reduction operator
+ value1 = builder.loadIfRef(loc, value1);
+ value2 = builder.loadIfRef(loc, value2);
+ if (op == mlir::acc::ReductionOperator::AccAdd) {
+ if (ty.isIntOrIndex())
+ return mlir::arith::AddIOp::create(builder, loc, value1, value2);
+ if (mlir::isa<mlir::FloatType>(ty))
+ return mlir::arith::AddFOp::create(builder, loc, value1, value2);
+ if (auto cmplxTy = mlir::dyn_cast_or_null<mlir::ComplexType>(ty)) // NOTE(review): cmplxTy unused — a plain isa<> would do
+ return fir::AddcOp::create(builder, loc, value1, value2);
+ TODO(loc, "reduction add type");
+ }
+
+ if (op == mlir::acc::ReductionOperator::AccMul) {
+ if (ty.isIntOrIndex())
+ return mlir::arith::MulIOp::create(builder, loc, value1, value2);
+ if (mlir::isa<mlir::FloatType>(ty))
+ return mlir::arith::MulFOp::create(builder, loc, value1, value2);
+ if (mlir::isa<mlir::ComplexType>(ty))
+ return fir::MulcOp::create(builder, loc, value1, value2);
+ TODO(loc, "reduction mul type");
+ }
+
+ if (op == mlir::acc::ReductionOperator::AccMin)
+ return fir::genMin(builder, loc, {value1, value2});
+
+ if (op == mlir::acc::ReductionOperator::AccMax)
+ return fir::genMax(builder, loc, {value1, value2});
+
+ if (op == mlir::acc::ReductionOperator::AccIand)
+ return mlir::arith::AndIOp::create(builder, loc, value1, value2);
+
+ if (op == mlir::acc::ReductionOperator::AccIor)
+ return mlir::arith::OrIOp::create(builder, loc, value1, value2);
+
+ if (op == mlir::acc::ReductionOperator::AccXor)
+ return mlir::arith::XOrIOp::create(builder, loc, value1, value2);
+
+ if (op == mlir::acc::ReductionOperator::AccLand)
+ return genLogicalCombiner<mlir::arith::AndIOp>(builder, loc, value1,
+ value2);
+
+ if (op == mlir::acc::ReductionOperator::AccLor)
+ return genLogicalCombiner<mlir::arith::OrIOp>(builder, loc, value1, value2);
+
+ if (op == mlir::acc::ReductionOperator::AccEqv)
+ return genComparisonCombiner(builder, loc, mlir::arith::CmpIPredicate::eq,
+ value1, value2);
+
+ if (op == mlir::acc::ReductionOperator::AccNeqv)
+ return genComparisonCombiner(builder, loc, mlir::arith::CmpIPredicate::ne,
+ value1, value2);
+
+ TODO(loc, "reduction operator");
+}
+
+template <typename Ty>
+bool OpenACCMappableModel<Ty>::generateCombiner(
+ mlir::Type type, mlir::OpBuilder &mlirBuilder, mlir::Location loc,
+ mlir::TypedValue<mlir::acc::MappableType> dest,
+ mlir::TypedValue<mlir::acc::MappableType> source, mlir::ValueRange bounds,
+ mlir::acc::ReductionOperator op, mlir::Attribute fastmathFlags) const {
+ mlir::ModuleOp mod =
+ mlirBuilder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
+ assert(mod && "failed to retrieve parent module");
+ fir::FirOpBuilder builder(mlirBuilder, mod);
+ if (fastmathFlags)
+ if (auto fastMathAttr =
+ mlir::dyn_cast<mlir::arith::FastMathFlagsAttr>(fastmathFlags))
+ builder.setFastMathFlags(fastMathAttr.getValue());
+ // Generate loops that combine and assign the inputs into dest (or array
+ // section of the inputs when there are bounds).
+ hlfir::Entity srcSection{source};
+ hlfir::Entity destSection{dest};
+ if (!bounds.empty()) {
+ std::tie(srcSection, destSection) =
+ genArraySectionsInRecipe(builder, loc, bounds, srcSection, destSection);
+ }
+
+ mlir::Type elementType = fir::getFortranElementType(dest.getType());
+ auto genKernel = [&](mlir::Location l, fir::FirOpBuilder &b, // NOTE(review): l/b unused; outer loc/builder captured instead
+ hlfir::Entity srcElementValue,
+ hlfir::Entity destElementValue) -> hlfir::Entity {
+ return hlfir::Entity{genScalarCombiner(builder, loc, op, elementType,
+ srcElementValue, destElementValue)};
+ };
+ hlfir::genNoAliasAssignment(loc, builder, srcSection, destSection,
+ /*emitWorkshareLoop=*/false,
+ /*temporaryLHS=*/false, genKernel);
+ return true;
+}
+
+template bool OpenACCMappableModel<fir::BaseBoxType>::generateCombiner( // explicit instantiations for all mappable FIR types
+ mlir::Type, mlir::OpBuilder &, mlir::Location,
+ mlir::TypedValue<mlir::acc::MappableType>,
+ mlir::TypedValue<mlir::acc::MappableType>, mlir::ValueRange,
+ mlir::acc::ReductionOperator op, mlir::Attribute) const;
+template bool OpenACCMappableModel<fir::ReferenceType>::generateCombiner(
+ mlir::Type, mlir::OpBuilder &, mlir::Location,
+ mlir::TypedValue<mlir::acc::MappableType>,
+ mlir::TypedValue<mlir::acc::MappableType>, mlir::ValueRange,
+ mlir::acc::ReductionOperator op, mlir::Attribute) const;
+template bool OpenACCMappableModel<fir::PointerType>::generateCombiner(
+ mlir::Type, mlir::OpBuilder &, mlir::Location,
+ mlir::TypedValue<mlir::acc::MappableType>,
+ mlir::TypedValue<mlir::acc::MappableType>, mlir::ValueRange,
+ mlir::acc::ReductionOperator op, mlir::Attribute) const;
+template bool OpenACCMappableModel<fir::HeapType>::generateCombiner(
+ mlir::Type, mlir::OpBuilder &, mlir::Location,
+ mlir::TypedValue<mlir::acc::MappableType>, mlir::ValueRange,
+ mlir::acc::ReductionOperator op, mlir::Attribute) const;
+
+template <typename Ty>
bool OpenACCMappableModel<Ty>::generatePrivateDestroy(
- mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
- mlir::Value privatized) const {
- mlir::Type unwrappedTy = fir::unwrapRefType(type);
- // For boxed scalars allocated with AllocMem during init, free the heap.
- if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(unwrappedTy)) {
- mlir::Value boxVal = privatized;
- if (fir::isa_ref_type(boxVal.getType()))
- boxVal = fir::LoadOp::create(builder, loc, boxVal);
- mlir::Value addr = fir::BoxAddrOp::create(builder, loc, boxVal);
- // FreeMem only accepts fir.heap and this may not be represented in the box
- // type if the privatized entity is not an allocatable.
+ mlir::Type type, mlir::OpBuilder &mlirBuilder, mlir::Location loc,
+ mlir::Value privatized, mlir::ValueRange bounds) const {
+ hlfir::Entity inputVar = hlfir::Entity{privatized};
+ mlir::ModuleOp mod =
+ mlirBuilder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
+ assert(mod && "failed to retrieve parent module");
+ fir::FirOpBuilder builder(mlirBuilder, mod);
+ auto genFreeRawAddress = [&](hlfir::Entity entity) { // frees the entity's raw base address
+ mlir::Value addr = hlfir::genVariableRawAddress(loc, builder, entity);
mlir::Type heapType =
fir::HeapType::get(fir::unwrapRefType(addr.getType()));
if (heapType != addr.getType())
addr = fir::ConvertOp::create(builder, loc, heapType, addr);
fir::FreeMemOp::create(builder, loc, addr);
+ };
+ if (bounds.empty()) {
+ genFreeRawAddress(inputVar);
return true;
}
-
- // Nothing to do for other categories by default, they are stack allocated.
+ // The input variable is an array section, the base address is not the real
+ // allocation. Compute the section base address and deallocate that.
+ hlfir::Entity dereferencedVar =
+ hlfir::derefPointersAndAllocatables(loc, builder, inputVar);
+ hlfir::DesignateOp::Subscripts triplets;
+ auto [tempShape, tempExtents] =
+ computeSectionShapeAndExtents(builder, loc, bounds);
+ (void)tempExtents; // NOTE(review): redundant — tempExtents is used below
+ triplets = genTripletsFromAccBounds(builder, loc, bounds, dereferencedVar);
+ hlfir::Entity arraySection = genDesignateWithTriplets(
+ builder, loc, dereferencedVar, triplets, tempShape, tempExtents);
+ genFreeRawAddress(arraySection);
return true;
}
template bool OpenACCMappableModel<fir::BaseBoxType>::generatePrivateDestroy(
mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
- mlir::Value privatized) const;
+ mlir::Value privatized, mlir::ValueRange bounds) const;
template bool OpenACCMappableModel<fir::ReferenceType>::generatePrivateDestroy(
mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
- mlir::Value privatized) const;
+ mlir::Value privatized, mlir::ValueRange bounds) const;
template bool OpenACCMappableModel<fir::HeapType>::generatePrivateDestroy(
mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
- mlir::Value privatized) const;
+ mlir::Value privatized, mlir::ValueRange bounds) const;
template bool OpenACCMappableModel<fir::PointerType>::generatePrivateDestroy(
mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
- mlir::Value privatized) const;
+ mlir::Value privatized, mlir::ValueRange bounds) const;
template <typename Ty>
mlir::Value OpenACCPointerLikeModel<Ty>::genAllocate(
@@ -825,41 +1225,6 @@ template mlir::Value OpenACCPointerLikeModel<fir::LLVMPointerType>::genAllocate(
llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar,
bool &needsFree) const;
-static mlir::Value stripCasts(mlir::Value value, bool stripDeclare = true) {
- mlir::Value currentValue = value;
-
- while (currentValue) {
- auto *definingOp = currentValue.getDefiningOp();
- if (!definingOp)
- break;
-
- if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(definingOp)) {
- currentValue = convertOp.getValue();
- continue;
- }
-
- if (auto viewLike = mlir::dyn_cast<mlir::ViewLikeOpInterface>(definingOp)) {
- currentValue = viewLike.getViewSource();
- continue;
- }
-
- if (stripDeclare) {
- if (auto declareOp = mlir::dyn_cast<hlfir::DeclareOp>(definingOp)) {
- currentValue = declareOp.getMemref();
- continue;
- }
-
- if (auto declareOp = mlir::dyn_cast<fir::DeclareOp>(definingOp)) {
- currentValue = declareOp.getMemref();
- continue;
- }
- }
- break;
- }
-
- return currentValue;
-}
-
template <typename Ty>
bool OpenACCPointerLikeModel<Ty>::genFree(
mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
@@ -887,7 +1252,7 @@ bool OpenACCPointerLikeModel<Ty>::genFree(
mlir::Value valueToInspect = allocRes ? allocRes : varToFree;
// Strip casts and declare operations to find the original allocation
- mlir::Value strippedValue = stripCasts(valueToInspect);
+ mlir::Value strippedValue = fir::acc::getOriginalDef(valueToInspect);
mlir::Operation *originalAlloc = strippedValue.getDefiningOp();
// If we found an AllocMemOp (heap allocation), free it
@@ -992,4 +1357,232 @@ template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::genCopy(
mlir::TypedValue<mlir::acc::PointerLikeType> source,
mlir::Type varType) const;
+/// Generate a plain load of the pointee for OpenACC copy/update lowering.
+/// Returns a null mlir::Value when a simple fir.load cannot represent the
+/// operation (boxes, unlimited polymorphic, or dynamically sized pointees),
+/// signalling the caller to fall back to another strategy.
+template <typename Ty>
+mlir::Value OpenACCPointerLikeModel<Ty>::genLoad(
+    mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+    mlir::TypedValue<mlir::acc::PointerLikeType> srcPtr,
+    mlir::Type valueType) const {
+
+  // Unwrap to get the pointee type.
+  mlir::Type pointeeTy = fir::dyn_cast_ptrEleTy(pointer);
+  assert(pointeeTy && "expected pointee type to be extractable");
+
+  // Box types contain both a descriptor and referenced data. The genLoad API
+  // handles simple loads and cannot properly manage both parts.
+  if (fir::isa_box_type(pointeeTy))
+    return {};
+
+  // Unlimited polymorphic (class(*)) cannot be handled because type is unknown.
+  if (fir::isUnlimitedPolymorphicType(pointeeTy))
+    return {};
+
+  // Return empty for dynamic size types because the load logic
+  // cannot be determined simply from the type.
+  if (fir::hasDynamicSize(pointeeTy))
+    return {};
+
+  mlir::Value loadedValue = fir::LoadOp::create(builder, loc, srcPtr);
+
+  // If valueType is provided and differs from the loaded type, insert a convert
+  if (valueType && loadedValue.getType() != valueType)
+    return fir::ConvertOp::create(builder, loc, valueType, loadedValue);
+
+  return loadedValue;
+}
+
+template mlir::Value OpenACCPointerLikeModel<fir::ReferenceType>::genLoad(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::TypedValue<mlir::acc::PointerLikeType> srcPtr,
+ mlir::Type valueType) const;
+
+template mlir::Value OpenACCPointerLikeModel<fir::PointerType>::genLoad(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::TypedValue<mlir::acc::PointerLikeType> srcPtr,
+ mlir::Type valueType) const;
+
+template mlir::Value OpenACCPointerLikeModel<fir::HeapType>::genLoad(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::TypedValue<mlir::acc::PointerLikeType> srcPtr,
+ mlir::Type valueType) const;
+
+template mlir::Value OpenACCPointerLikeModel<fir::LLVMPointerType>::genLoad(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::TypedValue<mlir::acc::PointerLikeType> srcPtr,
+ mlir::Type valueType) const;
+
+/// Generate a plain store through the pointer for OpenACC lowering. Mirrors
+/// genLoad: returns false when a simple fir.store cannot represent the
+/// operation (boxes, unlimited polymorphic, or dynamically sized pointees).
+/// A fir.convert is inserted when the stored value's type differs from the
+/// pointee type.
+template <typename Ty>
+bool OpenACCPointerLikeModel<Ty>::genStore(
+    mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+    mlir::Value valueToStore,
+    mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const {
+
+  // Unwrap to get the pointee type.
+  mlir::Type pointeeTy = fir::dyn_cast_ptrEleTy(pointer);
+  assert(pointeeTy && "expected pointee type to be extractable");
+
+  // Box types contain both a descriptor and referenced data. The genStore API
+  // handles simple stores and cannot properly manage both parts.
+  if (fir::isa_box_type(pointeeTy))
+    return false;
+
+  // Unlimited polymorphic (class(*)) cannot be handled because type is unknown.
+  if (fir::isUnlimitedPolymorphicType(pointeeTy))
+    return false;
+
+  // Return false for dynamic size types because the store logic
+  // cannot be determined simply from the type.
+  if (fir::hasDynamicSize(pointeeTy))
+    return false;
+
+  // Get the type from the value being stored
+  mlir::Type valueType = valueToStore.getType();
+  mlir::Value convertedValue = valueToStore;
+
+  // If the value type differs from the pointee type, insert a convert
+  if (valueType != pointeeTy)
+    convertedValue =
+        fir::ConvertOp::create(builder, loc, pointeeTy, valueToStore);
+
+  fir::StoreOp::create(builder, loc, convertedValue, destPtr);
+  return true;
+}
+
+template bool OpenACCPointerLikeModel<fir::ReferenceType>::genStore(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::Value valueToStore,
+ mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const;
+
+template bool OpenACCPointerLikeModel<fir::PointerType>::genStore(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::Value valueToStore,
+ mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const;
+
+template bool OpenACCPointerLikeModel<fir::HeapType>::genStore(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::Value valueToStore,
+ mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const;
+
+template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::genStore(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::Value valueToStore,
+ mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const;
+
+/// Check CUDA attributes on a function argument.
+/// Returns true when \p blockArg is an entry argument of a function-like op
+/// that carries a cuf.data_attr argument attribute marking it as device data.
+static bool hasCUDADeviceAttrOnFuncArg(mlir::BlockArgument blockArg) {
+  auto *owner = blockArg.getOwner();
+  if (!owner)
+    return false;
+
+  auto *parentOp = owner->getParentOp();
+  if (!parentOp)
+    return false;
+
+  // Only function-like parents carry per-argument attributes. Arguments of
+  // non-entry blocks have indices beyond getNumArguments() and are skipped.
+  if (auto funcLike = mlir::dyn_cast<mlir::FunctionOpInterface>(parentOp)) {
+    unsigned argIndex = blockArg.getArgNumber();
+    if (argIndex < funcLike.getNumArguments())
+      if (auto attr = funcLike.getArgAttr(argIndex, cuf::getDataAttrName()))
+        if (auto cudaAttr = mlir::dyn_cast<cuf::DataAttributeAttr>(attr))
+          return cuf::isDeviceDataAttribute(cudaAttr.getValue());
+  }
+  return false;
+}
+
+/// Shared implementation for checking if a value represents device data.
+static bool isDeviceDataImpl(mlir::Value var) {
+ // Strip casts to find the underlying value.
+ mlir::Value currentVal =
+ fir::acc::getOriginalDef(var, /*stripDeclare=*/false);
+
+ if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(currentVal))
+ return hasCUDADeviceAttrOnFuncArg(blockArg);
+
+ mlir::Operation *defOp = currentVal.getDefiningOp();
+ assert(defOp && "expected defining op for non-block-argument value");
+
+ // Check for CUDA attributes on the defining operation.
+ if (cuf::hasDeviceDataAttr(defOp))
+ return true;
+
+ // Handle operations that access a partial entity - check if the base entity
+ // is device data.
+ if (auto partialAccess =
+ mlir::dyn_cast<mlir::acc::PartialEntityAccessOpInterface>(defOp))
+ if (mlir::Value base = partialAccess.getBaseEntity())
+ return isDeviceDataImpl(base);
+
+ // Handle fir.embox, fir.rebox, and similar ops via
+ // FortranObjectViewOpInterface to check if the underlying source is device
+ // data.
+ if (auto viewOp = mlir::dyn_cast<fir::FortranObjectViewOpInterface>(defOp))
+ if (mlir::Value source = viewOp.getViewSource(defOp->getResult(0)))
+ return isDeviceDataImpl(source);
+
+ // Handle address_of - check the referenced global.
+ if (auto addrOfIface =
+ mlir::dyn_cast<mlir::acc::AddressOfGlobalOpInterface>(defOp)) {
+ auto symbol = addrOfIface.getSymbol();
+ if (auto global = mlir::SymbolTable::lookupNearestSymbolFrom<
+ mlir::acc::GlobalVariableOpInterface>(defOp, symbol))
+ return global.isDeviceData();
+ return false;
+ }
+
+ return false;
+}
+
+/// PointerLike interface hook: defer to the shared device-data walk.
+template <typename Ty>
+bool OpenACCPointerLikeModel<Ty>::isDeviceData(mlir::Type pointer,
+                                               mlir::Value var) const {
+  return isDeviceDataImpl(var);
+}
+
+template bool OpenACCPointerLikeModel<fir::ReferenceType>::isDeviceData(
+ mlir::Type, mlir::Value) const;
+template bool
+ OpenACCPointerLikeModel<fir::PointerType>::isDeviceData(mlir::Type,
+ mlir::Value) const;
+template bool
+ OpenACCPointerLikeModel<fir::HeapType>::isDeviceData(mlir::Type,
+ mlir::Value) const;
+template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::isDeviceData(
+ mlir::Type, mlir::Value) const;
+
+/// Mappable interface hook: defer to the shared device-data walk.
+template <typename Ty>
+bool OpenACCMappableModel<Ty>::isDeviceData(mlir::Type type,
+                                            mlir::Value var) const {
+  return isDeviceDataImpl(var);
+}
+
+template bool
+ OpenACCMappableModel<fir::BaseBoxType>::isDeviceData(mlir::Type,
+ mlir::Value) const;
+template bool
+ OpenACCMappableModel<fir::ReferenceType>::isDeviceData(mlir::Type,
+ mlir::Value) const;
+template bool
+ OpenACCMappableModel<fir::HeapType>::isDeviceData(mlir::Type,
+ mlir::Value) const;
+template bool
+ OpenACCMappableModel<fir::PointerType>::isDeviceData(mlir::Type,
+ mlir::Value) const;
+
+/// Map an OpenACC reduction operator on fir.logical to the arith.atomicrmw
+/// kind implementing it, or std::nullopt when no direct mapping exists.
+std::optional<mlir::arith::AtomicRMWKind>
+OpenACCReducibleLogicalModel::getAtomicRMWKind(
+    mlir::Type type, mlir::acc::ReductionOperator redOp) const {
+  switch (redOp) {
+  case mlir::acc::ReductionOperator::AccLand:
+    return mlir::arith::AtomicRMWKind::andi;
+  case mlir::acc::ReductionOperator::AccLor:
+    return mlir::arith::AtomicRMWKind::ori;
+  case mlir::acc::ReductionOperator::AccEqv:
+  case mlir::acc::ReductionOperator::AccNeqv:
+    // Eqv and Neqv are valid for logical types but don't have a direct
+    // AtomicRMWKind mapping yet.
+    return std::nullopt;
+  default:
+    // Other reduction operators are not valid for logical types.
+    return std::nullopt;
+  }
+}
+
} // namespace fir::acc
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCUtils.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCUtils.cpp
new file mode 100644
index 0000000..a53ea92
--- /dev/null
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCUtils.cpp
@@ -0,0 +1,655 @@
+//===- FIROpenACCUtils.cpp - FIR OpenACC Utilities ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utility functions for FIR OpenACC support.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/OpenACC/Support/FIROpenACCUtils.h"
+#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/Complex.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Dialect/Support/FIRContext.h"
+#include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/Support/InternalNames.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+static constexpr llvm::StringRef accPrivateInitName = "acc.private.init";
+static constexpr llvm::StringRef accReductionInitName = "acc.reduction.init";
+
+/// Reconstruct a user-readable name for \p v by walking back through the
+/// defining ops (views, converts, loads, boxes, designators, declarations).
+/// Component accesses are rendered as "base%field" and array accesses as
+/// "name(i,j,...)". When \p preferDemangledName is set, uniqued FIR names are
+/// deconstructed back to their Fortran source names. Falls back to
+/// mlir::acc::getVariableName when nothing better is found.
+std::string fir::acc::getVariableName(Value v, bool preferDemangledName) {
+  std::string srcName;
+  std::string prefix;
+  llvm::SmallVector<std::string, 4> arrayIndices;
+  bool iterate = true;
+  mlir::Operation *defOp;
+
+  // For integer constants, no need to further iterate - print their value
+  // immediately.
+  if (v.getDefiningOp()) {
+    IntegerAttr::ValueType val;
+    if (matchPattern(v.getDefiningOp(), m_ConstantInt(&val))) {
+      llvm::raw_string_ostream os(prefix);
+      val.print(os, /*isSigned=*/true);
+      return prefix;
+    }
+  }
+
+  // Each Case either steps `v` one link up the def chain (returning true to
+  // keep iterating) or records a final name and stops (returning false).
+  while (v && (defOp = v.getDefiningOp()) && iterate) {
+    iterate =
+        llvm::TypeSwitch<mlir::Operation *, bool>(defOp)
+            .Case([&v](mlir::ViewLikeOpInterface op) {
+              v = op.getViewSource();
+              return true;
+            })
+            .Case([&v](fir::ReboxOp op) {
+              v = op.getBox();
+              return true;
+            })
+            .Case([&v](fir::EmboxOp op) {
+              v = op.getMemref();
+              return true;
+            })
+            .Case([&v](fir::ConvertOp op) {
+              v = op.getValue();
+              return true;
+            })
+            .Case([&v](fir::LoadOp op) {
+              v = op.getMemref();
+              return true;
+            })
+            .Case([&v](fir::BoxAddrOp op) {
+              // The box holds the name of the variable.
+              v = op.getVal();
+              return true;
+            })
+            .Case([&](fir::AddrOfOp op) {
+              // Only use address_of symbol if mangled name is preferred
+              if (!preferDemangledName) {
+                auto symRef = op.getSymbol();
+                srcName = symRef.getLeafReference().getValue().str();
+              }
+              return false;
+            })
+            .Case([&](fir::ArrayCoorOp op) {
+              v = op.getMemref();
+              for (auto coor : op.getIndices()) {
+                auto idxName = getVariableName(coor, preferDemangledName);
+                arrayIndices.push_back(idxName.empty() ? "?" : idxName);
+              }
+              return true;
+            })
+            .Case([&](fir::CoordinateOp op) {
+              std::optional<llvm::ArrayRef<int32_t>> fieldIndices =
+                  op.getFieldIndices();
+              if (fieldIndices && fieldIndices->size() > 0 &&
+                  (*fieldIndices)[0] != fir::CoordinateOp::kDynamicIndex) {
+                int fieldId = (*fieldIndices)[0];
+                mlir::Type baseType =
+                    fir::getFortranElementType(op.getRef().getType());
+                if (auto recType = llvm::dyn_cast<fir::RecordType>(baseType)) {
+                  srcName = recType.getTypeList()[fieldId].first;
+                }
+              }
+              if (!srcName.empty()) {
+                // If the field name is known - attempt to continue building
+                // name by looking at its parents.
+                prefix =
+                    getVariableName(op.getRef(), preferDemangledName) + "%";
+              }
+              return false;
+            })
+            .Case([&](hlfir::DesignateOp op) {
+              if (op.getComponent()) {
+                srcName = op.getComponent().value().str();
+                prefix =
+                    getVariableName(op.getMemref(), preferDemangledName) + "%";
+                return false;
+              }
+              for (auto coor : op.getIndices()) {
+                auto idxName = getVariableName(coor, preferDemangledName);
+                arrayIndices.push_back(idxName.empty() ? "?" : idxName);
+              }
+              v = op.getMemref();
+              return true;
+            })
+            .Case<fir::DeclareOp, hlfir::DeclareOp>([&](auto op) {
+              srcName = op.getUniqName().str();
+              return false;
+            })
+            .Case([&](fir::AllocaOp op) {
+              if (preferDemangledName) {
+                // Prefer demangled name (bindc_name over uniq_name)
+                srcName = op.getBindcName() ? *op.getBindcName()
+                          : op.getUniqName() ? *op.getUniqName()
+                                             : "";
+              } else {
+                // Prefer mangled name (uniq_name over bindc_name)
+                srcName = op.getUniqName()    ? *op.getUniqName()
+                          : op.getBindcName() ? *op.getBindcName()
+                                              : "";
+              }
+              return false;
+            })
+            .Default([](mlir::Operation *) { return false; });
+  }
+
+  // Fallback to the default implementation.
+  if (srcName.empty())
+    return mlir::acc::getVariableName(v);
+
+  // Build array index suffix if present
+  std::string suffix;
+  if (!arrayIndices.empty()) {
+    llvm::raw_string_ostream os(suffix);
+    os << "(";
+    llvm::interleaveComma(arrayIndices, os);
+    os << ")";
+  }
+
+  // Names from FIR operations may be mangled.
+  // When the demangled name is requested - demangle it.
+  if (preferDemangledName) {
+    auto [kind, deconstructed] = fir::NameUniquer::deconstruct(srcName);
+    if (kind != fir::NameUniquer::NameKind::NOT_UNIQUED)
+      return prefix + deconstructed.name + suffix;
+  }
+
+  return prefix + srcName + suffix;
+}
+
+/// Return true iff every value in \p bounds is produced by an
+/// acc.bounds op whose lower/upper bounds (or extent) are compile-time
+/// constants. Used to decide whether recipe regions need bound arguments.
+bool fir::acc::areAllBoundsConstant(llvm::ArrayRef<Value> bounds) {
+  for (auto bound : bounds) {
+    auto dataBound =
+        mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
+    if (!dataBound)
+      return false;
+
+    // Check if this bound has constant values
+    bool hasConstant = false;
+    if (dataBound.getLowerbound() && dataBound.getUpperbound())
+      hasConstant =
+          fir::getIntIfConstant(dataBound.getLowerbound()).has_value() &&
+          fir::getIntIfConstant(dataBound.getUpperbound()).has_value();
+    else if (dataBound.getExtent())
+      hasConstant = fir::getIntIfConstant(dataBound.getExtent()).has_value();
+
+    if (!hasConstant)
+      return false;
+  }
+  return true;
+}
+
+/// Encode the given acc.bounds values into a mangling suffix such as
+/// "_section_lb0.ub9xext5" so that recipes specialized for different array
+/// sections get distinct symbol names. Bounds without constant values are
+/// rendered as "?". Returns an empty string when there are no bounds.
+static std::string getBoundsString(llvm::ArrayRef<Value> bounds) {
+  if (bounds.empty())
+    return "";
+
+  std::string boundStr;
+  llvm::raw_string_ostream os(boundStr);
+  os << "_section_";
+
+  llvm::interleave(
+      bounds,
+      [&](Value bound) {
+        auto boundsOp =
+            mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
+        if (boundsOp.getLowerbound() &&
+            fir::getIntIfConstant(boundsOp.getLowerbound()) &&
+            boundsOp.getUpperbound() &&
+            fir::getIntIfConstant(boundsOp.getUpperbound())) {
+          os << "lb" << *fir::getIntIfConstant(boundsOp.getLowerbound())
+             << ".ub" << *fir::getIntIfConstant(boundsOp.getUpperbound());
+        } else if (boundsOp.getExtent() &&
+                   fir::getIntIfConstant(boundsOp.getExtent())) {
+          os << "ext" << *fir::getIntIfConstant(boundsOp.getExtent());
+        } else {
+          os << "?";
+        }
+      },
+      [&] { os << "x"; });
+
+  // raw_string_ostream writes straight through to boundStr; return the
+  // string directly rather than via the deprecated raw_string_ostream::str().
+  return boundStr;
+}
+
+/// Build a unique recipe symbol name of the form
+/// "<kind>[_<reduction-op>][_section_<bounds>]_<type-string>".
+/// The type string is produced by fir::getTypeAsString using \p kindMap so the
+/// same FIR type always maps to the same recipe symbol.
+static std::string getRecipeName(mlir::acc::RecipeKind kind, Type type,
+                                 const fir::KindMapping &kindMap,
+                                 llvm::ArrayRef<Value> bounds,
+                                 mlir::acc::ReductionOperator reductionOp =
+                                     mlir::acc::ReductionOperator::AccNone) {
+  assert(fir::isa_fir_type(type) && "getRecipeName expects a FIR type");
+
+  // Build the complete prefix with all components before calling
+  // getTypeAsString
+  std::string prefixStr;
+  llvm::raw_string_ostream prefixOS(prefixStr);
+
+  switch (kind) {
+  case mlir::acc::RecipeKind::private_recipe:
+    prefixOS << "privatization";
+    break;
+  case mlir::acc::RecipeKind::firstprivate_recipe:
+    prefixOS << "firstprivatization";
+    break;
+  case mlir::acc::RecipeKind::reduction_recipe:
+    prefixOS << "reduction";
+    // Embed the reduction operator in the prefix
+    if (reductionOp != mlir::acc::ReductionOperator::AccNone)
+      prefixOS << "_"
+               << mlir::acc::stringifyReductionOperator(reductionOp).str();
+    break;
+  }
+
+  if (!bounds.empty())
+    prefixOS << getBoundsString(bounds);
+
+  // prefixOS writes straight through to prefixStr; pass the string directly
+  // instead of going through the deprecated raw_string_ostream::str().
+  return fir::getTypeAsString(type, kindMap, prefixStr);
+}
+
+/// Public wrapper: derive the kind mapping from \p var's defining op when
+/// available (so kind defaults follow the module), otherwise fall back to a
+/// default KindMapping for the type's context.
+std::string fir::acc::getRecipeName(mlir::acc::RecipeKind kind, Type type,
+                                    Value var, llvm::ArrayRef<Value> bounds,
+                                    mlir::acc::ReductionOperator reductionOp) {
+  auto kindMap = var && var.getDefiningOp()
+                     ? fir::getKindMapping(var.getDefiningOp())
+                     : fir::KindMapping(type.getContext());
+  return ::getRecipeName(kind, type, kindMap, bounds, reductionOp);
+}
+
+/// Get the initial value for reduction operator.
+template <typename R>
+static R getReductionInitValue(mlir::acc::ReductionOperator op, mlir::Type ty) {
+ if (op == mlir::acc::ReductionOperator::AccMin) {
+ // min init value -> largest
+ if constexpr (std::is_same_v<R, llvm::APInt>) {
+ assert(ty.isIntOrIndex() && "expect integer or index type");
+ return llvm::APInt::getSignedMaxValue(ty.getIntOrFloatBitWidth());
+ }
+ if constexpr (std::is_same_v<R, llvm::APFloat>) {
+ auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty);
+ assert(floatTy && "expect float type");
+ return llvm::APFloat::getLargest(floatTy.getFloatSemantics(),
+ /*negative=*/false);
+ }
+ } else if (op == mlir::acc::ReductionOperator::AccMax) {
+ // max init value -> smallest
+ if constexpr (std::is_same_v<R, llvm::APInt>) {
+ assert(ty.isIntOrIndex() && "expect integer or index type");
+ return llvm::APInt::getSignedMinValue(ty.getIntOrFloatBitWidth());
+ }
+ if constexpr (std::is_same_v<R, llvm::APFloat>) {
+ auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty);
+ assert(floatTy && "expect float type");
+ return llvm::APFloat::getSmallest(floatTy.getFloatSemantics(),
+ /*negative=*/true);
+ }
+ } else if (op == mlir::acc::ReductionOperator::AccIand) {
+ if constexpr (std::is_same_v<R, llvm::APInt>) {
+ assert(ty.isIntOrIndex() && "expect integer type");
+ unsigned bits = ty.getIntOrFloatBitWidth();
+ return llvm::APInt::getAllOnes(bits);
+ }
+ } else {
+ assert(op != mlir::acc::ReductionOperator::AccNone);
+ // +, ior, ieor init value -> 0
+ // * init value -> 1
+ int64_t value = (op == mlir::acc::ReductionOperator::AccMul) ? 1 : 0;
+ if constexpr (std::is_same_v<R, llvm::APInt>) {
+ assert(ty.isIntOrIndex() && "expect integer or index type");
+ return llvm::APInt(ty.getIntOrFloatBitWidth(), value, true);
+ }
+
+ if constexpr (std::is_same_v<R, llvm::APFloat>) {
+ assert(mlir::isa<mlir::FloatType>(ty) && "expect float type");
+ auto floatTy = mlir::dyn_cast<mlir::FloatType>(ty);
+ return llvm::APFloat(floatTy.getFloatSemantics(), value);
+ }
+
+ if constexpr (std::is_same_v<R, int64_t>)
+ return value;
+ }
+ llvm_unreachable("OpenACC reduction unsupported type");
+}
+
+/// Return a constant with the initial value for the reduction operator and
+/// type combination.
+/// Return a constant with the initial value for the reduction operator and
+/// type combination.
+/// Dispatches on the Fortran element type of \p varType: logicals get a
+/// boolean constant, integers/floats an arith.constant, complex a pair of
+/// real constants; anything else is a fatal error.
+static mlir::Value getReductionInitValue(fir::FirOpBuilder &builder,
+                                         mlir::Location loc, mlir::Type varType,
+                                         mlir::acc::ReductionOperator op) {
+  mlir::Type ty = fir::getFortranElementType(varType);
+  if (op == mlir::acc::ReductionOperator::AccLand ||
+      op == mlir::acc::ReductionOperator::AccLor ||
+      op == mlir::acc::ReductionOperator::AccEqv ||
+      op == mlir::acc::ReductionOperator::AccNeqv) {
+    assert(mlir::isa<fir::LogicalType>(ty) && "expect fir.logical type");
+    bool value = true; // .true. for .and. and .eqv.
+    if (op == mlir::acc::ReductionOperator::AccLor ||
+        op == mlir::acc::ReductionOperator::AccNeqv)
+      value = false; // .false. for .or. and .neqv.
+    return builder.createBool(loc, value);
+  }
+  if (ty.isIntOrIndex())
+    return mlir::arith::ConstantOp::create(
+        builder, loc, ty,
+        builder.getIntegerAttr(ty, getReductionInitValue<llvm::APInt>(op, ty)));
+  if (op == mlir::acc::ReductionOperator::AccMin ||
+      op == mlir::acc::ReductionOperator::AccMax) {
+    // min/max only make sense on ordered types; complex is rejected outright.
+    if (mlir::isa<mlir::ComplexType>(ty))
+      llvm::report_fatal_error(
+          "min/max reduction not supported for complex type");
+    if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty))
+      return mlir::arith::ConstantOp::create(
+          builder, loc, ty,
+          builder.getFloatAttr(ty,
+                               getReductionInitValue<llvm::APFloat>(op, ty)));
+  } else if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty)) {
+    return mlir::arith::ConstantOp::create(
+        builder, loc, ty,
+        builder.getFloatAttr(ty, getReductionInitValue<int64_t>(op, ty)));
+  } else if (auto cmplxTy = mlir::dyn_cast_or_null<mlir::ComplexType>(ty)) {
+    // Complex identity: (init, 0) where init comes from the real-part rule.
+    mlir::Type floatTy = cmplxTy.getElementType();
+    mlir::Value realInit = builder.createRealConstant(
+        loc, floatTy, getReductionInitValue<int64_t>(op, cmplxTy));
+    mlir::Value imagInit = builder.createRealConstant(loc, floatTy, 0.0);
+    return fir::factory::Complex{builder, loc}.createComplex(cmplxTy, realInit,
+                                                             imagInit);
+  }
+  llvm::report_fatal_error("Unsupported OpenACC reduction type");
+}
+
+/// Materialize (lb, ub, stride) index triples for each dimension of a recipe
+/// region. When the recipe block has bound arguments, they are used directly
+/// (stride forced to 1); otherwise the bounds must be constants recoverable
+/// from the acc.bounds ops and are rebuilt as index constants.
+static llvm::SmallVector<mlir::Value>
+getRecipeBounds(fir::FirOpBuilder &builder, mlir::Location loc,
+                mlir::ValueRange dataBoundOps,
+                mlir::ValueRange blockBoundArgs) {
+  if (dataBoundOps.empty())
+    return {};
+  mlir::Type idxTy = builder.getIndexType();
+  mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
+  llvm::SmallVector<mlir::Value> bounds;
+  if (!blockBoundArgs.empty()) {
+    // Block arguments come in (lb, ub, startIdx) triples; consume them
+    // three at a time.
+    for (unsigned i = 0; i + 2 < blockBoundArgs.size(); i += 3) {
+      bounds.push_back(blockBoundArgs[i]);
+      bounds.push_back(blockBoundArgs[i + 1]);
+      // acc data bound strides is the inner size in bytes or elements, but
+      // sections are always 1-based, so there is no need to try to compute
+      // that back from the acc bounds.
+      bounds.push_back(one);
+    }
+    return bounds;
+  }
+  for (auto bound : dataBoundOps) {
+    auto dataBound = llvm::dyn_cast_if_present<mlir::acc::DataBoundsOp>(
+        bound.getDefiningOp());
+    assert(dataBound && "expect acc bounds to be produced by DataBoundsOp");
+    assert(
+        dataBound.getLowerbound() && dataBound.getUpperbound() &&
+        "expect acc bounds for Fortran to always have lower and upper bounds");
+    std::optional<std::int64_t> lb =
+        fir::getIntIfConstant(dataBound.getLowerbound());
+    std::optional<std::int64_t> ub =
+        fir::getIntIfConstant(dataBound.getUpperbound());
+    assert(lb.has_value() && ub.has_value() &&
+           "must get constant bounds when there are no bound block arguments");
+    bounds.push_back(builder.createIntegerConstant(loc, idxTy, *lb));
+    bounds.push_back(builder.createIntegerConstant(loc, idxTy, *ub));
+    bounds.push_back(one);
+  }
+  return bounds;
+}
+
+/// Append the (lowerbound, upperbound, startIdx) types/locations of each
+/// non-constant bound to the recipe region's block argument lists. No-op
+/// when all bounds are constant (the recipe then takes no bound arguments).
+/// NOTE(review): bounds are appended in reverse order here while
+/// getRecipeBounds consumes block arguments in forward triples — confirm the
+/// intended dimension ordering matches on both sides.
+static void addRecipeBoundsArgs(llvm::SmallVector<mlir::Value> &bounds,
+                                bool allConstantBound,
+                                llvm::SmallVector<mlir::Type> &argsTy,
+                                llvm::SmallVector<mlir::Location> &argsLoc) {
+  if (!allConstantBound) {
+    for (mlir::Value bound : llvm::reverse(bounds)) {
+      auto dataBound =
+          mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
+      argsTy.push_back(dataBound.getLowerbound().getType());
+      argsLoc.push_back(dataBound.getLowerbound().getLoc());
+      argsTy.push_back(dataBound.getUpperbound().getType());
+      argsLoc.push_back(dataBound.getUpperbound().getLoc());
+      argsTy.push_back(dataBound.getStartIdx().getType());
+      argsLoc.push_back(dataBound.getStartIdx().getLoc());
+    }
+  }
+}
+
+using MappableValue = mlir::TypedValue<mlir::acc::MappableType>;
+
+// Generate the combiner or copy region block and block arguments and return the
+// source and destination entities.
+// Generate the combiner or copy region block and block arguments and return the
+// source and destination entities.
+// The block takes two mappable values (plus bound arguments when bounds are
+// not all constant); the builder is left positioned at the end of the block.
+static std::pair<MappableValue, MappableValue>
+genRecipeCombinerOrCopyRegion(fir::FirOpBuilder &builder, mlir::Location loc,
+                              mlir::Type ty, mlir::Region &region,
+                              llvm::SmallVector<mlir::Value> &bounds,
+                              bool allConstantBound) {
+  llvm::SmallVector<mlir::Type> argsTy{ty, ty};
+  llvm::SmallVector<mlir::Location> argsLoc{loc, loc};
+  addRecipeBoundsArgs(bounds, allConstantBound, argsTy, argsLoc);
+  mlir::Block *block =
+      builder.createBlock(&region, region.end(), argsTy, argsLoc);
+  builder.setInsertionPointToEnd(&region.back());
+  auto firstArg = mlir::cast<MappableValue>(block->getArgument(0));
+  auto secondArg = mlir::cast<MappableValue>(block->getArgument(1));
+  return {firstArg, secondArg};
+}
+
+/// Create a recipe op (private/firstprivate/reduction) at module scope and
+/// populate its init region via MappableType::generatePrivateInit; a destroy
+/// region is added when the init reports it needs one. Combiner/copy regions
+/// are left for the callers to fill in. Returns the new recipe op.
+template <typename RecipeOp>
+static RecipeOp genRecipeOp(
+    fir::FirOpBuilder &builder, mlir::ModuleOp mod, llvm::StringRef recipeName,
+    mlir::Location loc, mlir::Type ty,
+    llvm::SmallVector<mlir::Value> &dataOperationBounds, bool allConstantBound,
+    mlir::acc::ReductionOperator op = mlir::acc::ReductionOperator::AccNone) {
+  mlir::OpBuilder modBuilder(mod.getBodyRegion());
+  RecipeOp recipe;
+  if constexpr (std::is_same_v<RecipeOp, mlir::acc::ReductionRecipeOp>) {
+    recipe = mlir::acc::ReductionRecipeOp::create(modBuilder, loc, recipeName,
+                                                  ty, op);
+  } else {
+    recipe = RecipeOp::create(modBuilder, loc, recipeName, ty);
+  }
+
+  assert(hlfir::isFortranVariableType(ty) && "expect Fortran variable type");
+
+  // Init region arguments: the original entity, then any bound arguments.
+  llvm::SmallVector<mlir::Type> argsTy{ty};
+  llvm::SmallVector<mlir::Location> argsLoc{loc};
+  if (!dataOperationBounds.empty())
+    addRecipeBoundsArgs(dataOperationBounds, allConstantBound, argsTy, argsLoc);
+
+  auto initBlock = builder.createBlock(
+      &recipe.getInitRegion(), recipe.getInitRegion().end(), argsTy, argsLoc);
+  builder.setInsertionPointToEnd(&recipe.getInitRegion().back());
+  mlir::Value initValue;
+  if constexpr (std::is_same_v<RecipeOp, mlir::acc::ReductionRecipeOp>) {
+    assert(op != mlir::acc::ReductionOperator::AccNone);
+    initValue = getReductionInitValue(builder, loc, ty, op);
+  }
+
+  // Since we reuse the same recipe for all variables of the same type - we
+  // cannot use the actual variable name. Thus use a temporary name.
+  llvm::StringRef initName;
+  if constexpr (std::is_same_v<RecipeOp, mlir::acc::ReductionRecipeOp>)
+    initName = accReductionInitName;
+  else
+    initName = accPrivateInitName;
+
+  auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(ty);
+  assert(mappableTy &&
+         "Expected that all variable types are considered mappable");
+  bool needsDestroy = false;
+  llvm::SmallVector<mlir::Value> initBounds =
+      getRecipeBounds(builder, loc, dataOperationBounds,
+                      initBlock->getArguments().drop_front(1));
+  mlir::Value retVal = mappableTy.generatePrivateInit(
+      builder, loc, mlir::cast<MappableValue>(initBlock->getArgument(0)),
+      initName, initBounds, initValue, needsDestroy);
+  mlir::acc::YieldOp::create(builder, loc, retVal);
+  // Create destroy region and generate destruction if requested.
+  if (needsDestroy) {
+    llvm::SmallVector<mlir::Type> destroyArgsTy;
+    llvm::SmallVector<mlir::Location> destroyArgsLoc;
+    // original and privatized/reduction value
+    destroyArgsTy.push_back(ty);
+    destroyArgsTy.push_back(ty);
+    destroyArgsLoc.push_back(loc);
+    destroyArgsLoc.push_back(loc);
+    // Append bounds arguments (if any) in the same order as init region
+    if (argsTy.size() > 1) {
+      destroyArgsTy.append(argsTy.begin() + 1, argsTy.end());
+      destroyArgsLoc.insert(destroyArgsLoc.end(), argsTy.size() - 1, loc);
+    }
+
+    mlir::Block *destroyBlock = builder.createBlock(
+        &recipe.getDestroyRegion(), recipe.getDestroyRegion().end(),
+        destroyArgsTy, destroyArgsLoc);
+    builder.setInsertionPointToEnd(destroyBlock);
+
+    llvm::SmallVector<mlir::Value> destroyBounds =
+        getRecipeBounds(builder, loc, dataOperationBounds,
+                        destroyBlock->getArguments().drop_front(2));
+    [[maybe_unused]] bool success = mappableTy.generatePrivateDestroy(
+        builder, loc, destroyBlock->getArgument(1), destroyBounds);
+    assert(success && "failed to generate destroy region");
+    mlir::acc::TerminatorOp::create(builder, loc);
+  }
+  return recipe;
+}
+
+/// Look up or create the acc.private.recipe for (ty, bounds) at module scope
+/// and return a symbol reference to it. Recipes are cached by name, so two
+/// variables with the same type and section bounds share one recipe.
+mlir::SymbolRefAttr
+fir::acc::createOrGetPrivateRecipe(mlir::OpBuilder &mlirBuilder,
+                                   mlir::Location loc, mlir::Type ty,
+                                   llvm::SmallVector<mlir::Value> &bounds) {
+  mlir::ModuleOp mod =
+      mlirBuilder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
+  fir::FirOpBuilder builder(mlirBuilder, mod);
+  std::string recipeName = ::getRecipeName(
+      mlir::acc::RecipeKind::private_recipe, ty, builder.getKindMap(), bounds);
+  if (auto recipe = mod.lookupSymbol<mlir::acc::PrivateRecipeOp>(recipeName))
+    return mlir::SymbolRefAttr::get(builder.getContext(), recipe.getSymName());
+
+  mlir::OpBuilder::InsertionGuard guard(builder);
+  bool allConstantBound = fir::acc::areAllBoundsConstant(bounds);
+  auto recipe = genRecipeOp<mlir::acc::PrivateRecipeOp>(
+      builder, mod, recipeName, loc, ty, bounds, allConstantBound);
+  return mlir::SymbolRefAttr::get(builder.getContext(), recipe.getSymName());
+}
+
+/// Look up or create the acc.firstprivate.recipe for (ty, bounds). Beyond the
+/// init (and optional destroy) regions, firstprivate also needs a copy region
+/// that transfers the original value into the private copy, generated via
+/// MappableType::generateCopy.
+mlir::SymbolRefAttr fir::acc::createOrGetFirstprivateRecipe(
+    mlir::OpBuilder &mlirBuilder, mlir::Location loc, mlir::Type ty,
+    llvm::SmallVector<mlir::Value> &dataBoundOps) {
+  mlir::ModuleOp mod =
+      mlirBuilder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
+  fir::FirOpBuilder builder(mlirBuilder, mod);
+  std::string recipeName =
+      ::getRecipeName(mlir::acc::RecipeKind::firstprivate_recipe, ty,
+                      builder.getKindMap(), dataBoundOps);
+  if (auto recipe =
+          mod.lookupSymbol<mlir::acc::FirstprivateRecipeOp>(recipeName))
+    return mlir::SymbolRefAttr::get(builder.getContext(), recipe.getSymName());
+
+  mlir::OpBuilder::InsertionGuard guard(builder);
+  bool allConstantBound = fir::acc::areAllBoundsConstant(dataBoundOps);
+  auto recipe = genRecipeOp<mlir::acc::FirstprivateRecipeOp>(
+      builder, mod, recipeName, loc, ty, dataBoundOps, allConstantBound);
+  // Copy region: block args are (source, destination [, bounds...]).
+  auto [source, destination] = genRecipeCombinerOrCopyRegion(
+      builder, loc, ty, recipe.getCopyRegion(), dataBoundOps, allConstantBound);
+  llvm::SmallVector<mlir::Value> copyBounds =
+      getRecipeBounds(builder, loc, dataBoundOps,
+                      recipe.getCopyRegion().getArguments().drop_front(2));
+
+  auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(ty);
+  assert(mappableTy &&
+         "Expected that all variable types are considered mappable");
+  [[maybe_unused]] bool success =
+      mappableTy.generateCopy(builder, loc, source, destination, copyBounds);
+  assert(success && "failed to generate copy");
+  mlir::acc::TerminatorOp::create(builder, loc);
+  return mlir::SymbolRefAttr::get(builder.getContext(), recipe.getSymName());
+}
+
+/// Look up or create the acc.reduction.recipe for (ty, op, bounds). Recipe
+/// names embed the reduction operator, so different operators on the same
+/// type get distinct recipes. The combiner region (dest = dest `op` source)
+/// is generated via MappableType::generateCombiner and yields dest.
+mlir::SymbolRefAttr fir::acc::createOrGetReductionRecipe(
+    mlir::OpBuilder &mlirBuilder, mlir::Location loc, mlir::Type ty,
+    mlir::acc::ReductionOperator op,
+    llvm::SmallVector<mlir::Value> &dataBoundOps,
+    mlir::Attribute fastMathAttr) {
+  mlir::ModuleOp mod =
+      mlirBuilder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
+  fir::FirOpBuilder builder(mlirBuilder, mod);
+  std::string recipeName =
+      ::getRecipeName(mlir::acc::RecipeKind::reduction_recipe, ty,
+                      builder.getKindMap(), dataBoundOps, op);
+  if (auto recipe = mod.lookupSymbol<mlir::acc::ReductionRecipeOp>(recipeName))
+    return mlir::SymbolRefAttr::get(builder.getContext(), recipe.getSymName());
+
+  mlir::OpBuilder::InsertionGuard guard(builder);
+  bool allConstantBound = fir::acc::areAllBoundsConstant(dataBoundOps);
+  auto recipe = genRecipeOp<mlir::acc::ReductionRecipeOp>(
+      builder, mod, recipeName, loc, ty, dataBoundOps, allConstantBound, op);
+
+  // Combiner region: block args are (dest, source [, bounds...]).
+  auto [dest, source] = genRecipeCombinerOrCopyRegion(
+      builder, loc, ty, recipe.getCombinerRegion(), dataBoundOps,
+      allConstantBound);
+  llvm::SmallVector<mlir::Value> combinerBounds =
+      getRecipeBounds(builder, loc, dataBoundOps,
+                      recipe.getCombinerRegion().getArguments().drop_front(2));
+
+  auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(ty);
+  assert(mappableTy &&
+         "Expected that all variable types are considered mappable");
+  [[maybe_unused]] bool success = mappableTy.generateCombiner(
+      builder, loc, dest, source, combinerBounds, op, fastMathAttr);
+  assert(success && "failed to generate combiner");
+  mlir::acc::YieldOp::create(builder, loc, dest);
+  return mlir::SymbolRefAttr::get(builder.getContext(), recipe.getSymName());
+}
+
+/// Walk backwards from \p value through fir.convert ops, view-like ops, and
+/// (when \p stripDeclare is set) fir/hlfir declare ops, returning the first
+/// value that is a block argument or whose defining op is none of the above.
+/// Callers use stripDeclare=false to keep declare-level attributes visible.
+mlir::Value fir::acc::getOriginalDef(mlir::Value value, bool stripDeclare) {
+  mlir::Value currentValue = value;
+
+  while (currentValue) {
+    auto *definingOp = currentValue.getDefiningOp();
+    if (!definingOp)
+      break;
+
+    if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(definingOp)) {
+      currentValue = convertOp.getValue();
+      continue;
+    }
+
+    if (auto viewLike = mlir::dyn_cast<mlir::ViewLikeOpInterface>(definingOp)) {
+      currentValue = viewLike.getViewSource();
+      continue;
+    }
+
+    if (stripDeclare) {
+      if (auto declareOp = mlir::dyn_cast<hlfir::DeclareOp>(definingOp)) {
+        currentValue = declareOp.getMemref();
+        continue;
+      }
+
+      if (auto declareOp = mlir::dyn_cast<fir::DeclareOp>(definingOp)) {
+        currentValue = declareOp.getMemref();
+        continue;
+      }
+    }
+    // Reached an op we do not look through; stop here.
+    break;
+  }
+
+  return currentValue;
+}
diff --git a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
index 717bf34..c0be247 100644
--- a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
@@ -11,8 +11,15 @@
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.h"
+
+#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h"
#include "flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h"
namespace fir::acc {
@@ -37,7 +44,62 @@ void registerOpenACCExtensions(mlir::DialectRegistry &registry) {
fir::LLVMPointerType::attachInterface<
OpenACCPointerLikeModel<fir::LLVMPointerType>>(*ctx);
+
+ fir::LogicalType::attachInterface<OpenACCReducibleLogicalModel>(*ctx);
+
+ fir::ArrayCoorOp::attachInterface<
+ PartialEntityAccessModel<fir::ArrayCoorOp>>(*ctx);
+ fir::CoordinateOp::attachInterface<
+ PartialEntityAccessModel<fir::CoordinateOp>>(*ctx);
+ fir::DeclareOp::attachInterface<PartialEntityAccessModel<fir::DeclareOp>>(
+ *ctx);
+
+ fir::AddrOfOp::attachInterface<AddressOfGlobalModel>(*ctx);
+ fir::GlobalOp::attachInterface<GlobalVariableModel>(*ctx);
+
+ fir::AllocaOp::attachInterface<IndirectGlobalAccessModel<fir::AllocaOp>>(
+ *ctx);
+ fir::EmboxOp::attachInterface<IndirectGlobalAccessModel<fir::EmboxOp>>(
+ *ctx);
+ fir::ReboxOp::attachInterface<IndirectGlobalAccessModel<fir::ReboxOp>>(
+ *ctx);
+ fir::TypeDescOp::attachInterface<
+ IndirectGlobalAccessModel<fir::TypeDescOp>>(*ctx);
+
+ // Attach OutlineRematerializationOpInterface to FIR operations that
+ // produce synthetic types (shapes, field indices) which cannot be passed
+ // as arguments to outlined regions and must be rematerialized inside.
+ fir::ShapeOp::attachInterface<OutlineRematerializationModel<fir::ShapeOp>>(
+ *ctx);
+ fir::ShapeShiftOp::attachInterface<
+ OutlineRematerializationModel<fir::ShapeShiftOp>>(*ctx);
+ fir::ShiftOp::attachInterface<OutlineRematerializationModel<fir::ShiftOp>>(
+ *ctx);
+ fir::FieldIndexOp::attachInterface<
+ OutlineRematerializationModel<fir::FieldIndexOp>>(*ctx);
+ });
+
+ // Register HLFIR operation interfaces
+ registry.addExtension(
+ +[](mlir::MLIRContext *ctx, hlfir::hlfirDialect *dialect) {
+ hlfir::DesignateOp::attachInterface<
+ PartialEntityAccessModel<hlfir::DesignateOp>>(*ctx);
+ hlfir::DeclareOp::attachInterface<
+ PartialEntityAccessModel<hlfir::DeclareOp>>(*ctx);
+ });
+
+ // Register CUF operation interfaces
+ registry.addExtension(+[](mlir::MLIRContext *ctx, cuf::CUFDialect *dialect) {
+ cuf::KernelOp::attachInterface<OffloadRegionModel<cuf::KernelOp>>(*ctx);
});
+
+ // Attach FIR dialect interfaces to OpenACC operations.
+ registry.addExtension(+[](mlir::MLIRContext *ctx,
+ mlir::acc::OpenACCDialect *dialect) {
+ mlir::acc::LoopOp::attachInterface<OperationMoveModel<mlir::acc::LoopOp>>(
+ *ctx);
+ });
+
registerAttrsExtensions(registry);
}
diff --git a/flang/lib/Optimizer/OpenACC/Transforms/ACCInitializeFIRAnalyses.cpp b/flang/lib/Optimizer/OpenACC/Transforms/ACCInitializeFIRAnalyses.cpp
new file mode 100644
index 0000000..679b29b
--- /dev/null
+++ b/flang/lib/Optimizer/OpenACC/Transforms/ACCInitializeFIRAnalyses.cpp
@@ -0,0 +1,56 @@
+//===- ACCInitializeFIRAnalyses.cpp - Initialize FIR analyses ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass initializes analyses that can be reused by subsequent OpenACC
+// passes in the pipeline.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Analysis/AliasAnalysis.h"
+#include "flang/Optimizer/OpenACC/Analysis/FIROpenACCSupportAnalysis.h"
+#include "flang/Optimizer/OpenACC/Passes.h"
+#include "mlir/Analysis/AliasAnalysis.h"
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+
+namespace fir {
+namespace acc {
+#define GEN_PASS_DEF_ACCINITIALIZEFIRANALYSES
+#include "flang/Optimizer/OpenACC/Passes.h.inc"
+} // namespace acc
+} // namespace fir
+
+#define DEBUG_TYPE "acc-initialize-fir-analyses"
+
+namespace {
+
+/// This pass initializes analyses for reuse by subsequent OpenACC passes in the
+/// pipeline. It creates and caches analyses like OpenACCSupport so they can be
+/// retrieved by later passes using getAnalysis() or getCachedAnalysis().
+class ACCInitializeFIRAnalysesPass
+ : public fir::acc::impl::ACCInitializeFIRAnalysesBase<
+ ACCInitializeFIRAnalysesPass> {
+public:
+ void runOnOperation() override {
+ // Initialize OpenACCSupport with FIR-specific implementation.
+ auto &openACCSupport = getAnalysis<mlir::acc::OpenACCSupport>();
+ openACCSupport.setImplementation(fir::acc::FIROpenACCSupportAnalysis());
+
+ // Initialize AliasAnalysis with FIR-specific implementation.
+ auto &aliasAnalysis = getAnalysis<mlir::AliasAnalysis>();
+ aliasAnalysis.addAnalysisImplementation(fir::AliasAnalysis());
+
+ // Mark all analyses as preserved since this pass only initializes them
+ markAllAnalysesPreserved();
+ }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::acc::createACCInitializeFIRAnalysesPass() {
+ return std::make_unique<ACCInitializeFIRAnalysesPass>();
+}
diff --git a/flang/lib/Optimizer/OpenACC/Transforms/ACCOptimizeFirstprivateMap.cpp b/flang/lib/Optimizer/OpenACC/Transforms/ACCOptimizeFirstprivateMap.cpp
new file mode 100644
index 0000000..ec40e12
--- /dev/null
+++ b/flang/lib/Optimizer/OpenACC/Transforms/ACCOptimizeFirstprivateMap.cpp
@@ -0,0 +1,193 @@
+//===- ACCOptimizeFirstprivateMap.cpp -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass optimizes firstprivate mapping operations (acc.firstprivate_map).
+// The optimization hoists loads from the firstprivate variable to before the
+// compute region, effectively converting the firstprivate copy to a
+// pass-by-value pattern. This eliminates the need for runtime copying into
+// global memory.
+//
+// Example transformation:
+//
+// Before:
+// %decl = fir.declare %alloca : !fir.ref<i32>
+// %fp = acc.firstprivate_map varPtr(%decl) -> !fir.ref<i32>
+// acc.parallel {
+// %val = fir.load %fp : !fir.ref<i32> // load inside region
+// ...
+// }
+//
+// After:
+// %decl = fir.declare %alloca : !fir.ref<i32>
+// %val = fir.load %decl : !fir.ref<i32> // load hoisted before region
+// acc.parallel {
+// ... // uses %val directly
+// }
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Dialect/FortranVariableInterface.h"
+#include "flang/Optimizer/OpenACC/Passes.h"
+#include "flang/Optimizer/OpenACC/Support/FIROpenACCUtils.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace fir::acc {
+#define GEN_PASS_DEF_ACCOPTIMIZEFIRSTPRIVATEMAP
+#include "flang/Optimizer/OpenACC/Passes.h.inc"
+} // namespace fir::acc
+
+using namespace mlir;
+
+namespace {
+
+/// Returns the enclosing offload region interface, or nullptr if not inside
+/// one.
+static acc::OffloadRegionOpInterface getEnclosingOffloadRegion(Operation *op) {
+ return op->getParentOfType<acc::OffloadRegionOpInterface>();
+}
+
+/// Returns true if the value is defined by an OpenACC data clause operation.
+static bool isDefinedByDataClause(Value value) {
+ Operation *defOp = value.getDefiningOp();
+ if (!defOp)
+ return false;
+ return acc::getDataClause(defOp).has_value();
+}
+
+/// Returns true if the value is defined inside the given offload region.
+/// This handles both operation results and block arguments.
+static bool isDefinedInsideRegion(Value value,
+ acc::OffloadRegionOpInterface offloadOp) {
+ Region *valueRegion = value.getParentRegion();
+ if (!valueRegion)
+ return false;
+ return offloadOp.getOffloadRegion().isAncestor(valueRegion);
+}
+
+/// Returns true if the variable may be optional.
+static bool mayBeOptionalVariable(Value var) {
+ // Don't strip declare ops - we need to check the optional attribute on them.
+ Value originalDef = fir::acc::getOriginalDef(var, /*stripDeclare=*/false);
+ if (auto varIface = dyn_cast_or_null<fir::FortranVariableOpInterface>(
+ originalDef.getDefiningOp()))
+ return varIface.isOptional();
+ // If the defining op is an alloca, it's a local variable and not optional.
+ if (isa_and_nonnull<fir::AllocaOp, fir::AllocMemOp>(
+ originalDef.getDefiningOp()))
+ return false;
+ // Conservative: if we can't determine, assume it may be optional.
+ return true;
+}
+
+/// Returns true if the type is a reference to a trivial type.
+/// Note that this does not allow fir.heap, fir.ptr, or fir.llvm_ptr
+/// types - since we would need to check if the load is valid via
+/// a null-check to enable the optimization.
+static bool isRefToTrivialType(Type type) {
+ if (!mlir::isa<fir::ReferenceType>(type))
+ return false;
+ return fir::isa_trivial(fir::unwrapRefType(type));
+}
+
+/// Attempts to hoist loads from accVar to before firstprivateInitOp.
+/// Returns true if all uses of accVar are loads and they were hoisted.
+static bool hoistLoads(acc::FirstprivateMapInitialOp firstprivateInitOp,
+ Value var, Value accVar) {
+ // Check if all uses are loads - only hoist if we can optimize all uses.
+ bool allLoads = llvm::all_of(accVar.getUsers(), [](Operation *user) {
+ return isa<fir::LoadOp>(user);
+ });
+ if (!allLoads)
+ return false;
+
+ // Hoist all loads before the firstprivate_map operation.
+ for (Operation *user : llvm::make_early_inc_range(accVar.getUsers())) {
+ auto loadOp = cast<fir::LoadOp>(user);
+ loadOp.getMemrefMutable().assign(var);
+ loadOp->moveBefore(firstprivateInitOp);
+ }
+ return true;
+}
+
+class ACCOptimizeFirstprivateMap
+ : public fir::acc::impl::ACCOptimizeFirstprivateMapBase<
+ ACCOptimizeFirstprivateMap> {
+public:
+ void runOnOperation() override {
+ func::FuncOp funcOp = getOperation();
+
+ // Collect all firstprivate_map ops first to avoid modifying IR during walk.
+ llvm::SmallVector<acc::FirstprivateMapInitialOp> firstprivateOps;
+ funcOp.walk([&](acc::FirstprivateMapInitialOp op) {
+ firstprivateOps.push_back(op);
+ });
+
+ llvm::SmallVector<acc::FirstprivateMapInitialOp> opsToErase;
+
+ for (acc::FirstprivateMapInitialOp firstprivateInitOp : firstprivateOps) {
+ Value var = firstprivateInitOp.getVar();
+
+ if (auto offloadOp = getEnclosingOffloadRegion(firstprivateInitOp)) {
+ // Inside an offload region.
+ if (isDefinedByDataClause(var) ||
+ isDefinedInsideRegion(var, offloadOp)) {
+ // The variable is already mapped or defined locally - just replace
+ // uses and erase.
+ firstprivateInitOp.getAccVar().replaceAllUsesWith(var);
+ opsToErase.push_back(firstprivateInitOp);
+ } else {
+ // Variable is defined outside - hoist the op out of the region,
+ // then apply optimization.
+ firstprivateInitOp->moveBefore(offloadOp);
+ if (optimizeFirstprivateMapping(firstprivateInitOp))
+ opsToErase.push_back(firstprivateInitOp);
+ }
+ } else {
+ // Outside offload region, apply type-restricted optimization
+ // to pre-load before the compute region.
+ if (optimizeFirstprivateMapping(firstprivateInitOp))
+ opsToErase.push_back(firstprivateInitOp);
+ }
+ }
+
+ for (auto op : opsToErase)
+ op.erase();
+ }
+
+private:
+ /// Returns true if the operation was optimized and can be erased.
+ static bool optimizeFirstprivateMapping(
+ acc::FirstprivateMapInitialOp firstprivateInitOp) {
+ Value var = firstprivateInitOp.getVar();
+ Value accVar = firstprivateInitOp.getAccVar();
+
+ // If there are no uses, we can erase the operation.
+ if (accVar.use_empty())
+ return true;
+
+ // Only optimize references to trivial types.
+ if (!isRefToTrivialType(var.getType()))
+ return false;
+
+ // Avoid hoisting optional variables as they may be
+ // null and thus not safe to access.
+ if (mayBeOptionalVariable(var))
+ return false;
+
+ return hoistLoads(firstprivateInitOp, var, accVar);
+ }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> fir::acc::createACCOptimizeFirstprivateMapPass() {
+ return std::make_unique<ACCOptimizeFirstprivateMap>();
+}
diff --git a/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp b/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp
index 4840a99..ad0cfa3 100644
--- a/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp
+++ b/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp
@@ -39,13 +39,13 @@ public:
static mlir::Operation *load(mlir::OpBuilder &builder, mlir::Location loc,
mlir::Value value) {
- return builder.create<fir::LoadOp>(loc, value);
+ return fir::LoadOp::create(builder, loc, value);
}
static mlir::Value placeInMemory(mlir::OpBuilder &builder, mlir::Location loc,
mlir::Value value) {
- auto alloca = builder.create<fir::AllocaOp>(loc, value.getType());
- builder.create<fir::StoreOp>(loc, value, alloca);
+ auto alloca = fir::AllocaOp::create(builder, loc, value.getType());
+ fir::StoreOp::create(builder, loc, value, alloca);
return alloca;
}
};
@@ -87,30 +87,26 @@ static void bufferizeRegionArgsAndYields(mlir::Region &region,
}
}
-static void updateRecipeUse(mlir::ArrayAttr recipes, mlir::ValueRange operands,
+template <typename OpTy>
+static void updateRecipeUse(mlir::ValueRange operands,
llvm::StringRef recipeSymName,
mlir::Operation *computeOp) {
- if (!recipes)
- return;
- for (auto [recipeSym, oldRes] : llvm::zip(recipes, operands)) {
- if (llvm::cast<mlir::SymbolRefAttr>(recipeSym).getLeafReference() !=
- recipeSymName)
+ for (auto operand : operands) {
+ auto op = operand.getDefiningOp<OpTy>();
+ if (!op || !op.getRecipe().has_value() ||
+ op.getRecipeAttr().getLeafReference() != recipeSymName)
continue;
- mlir::Operation *dataOp = oldRes.getDefiningOp();
- assert(dataOp && "dataOp must be paired with computeOp");
- mlir::Location loc = dataOp->getLoc();
- mlir::OpBuilder builder(dataOp);
- llvm::TypeSwitch<mlir::Operation *, void>(dataOp)
- .Case<mlir::acc::PrivateOp, mlir::acc::FirstprivateOp,
- mlir::acc::ReductionOp>([&](auto privateOp) {
- builder.setInsertionPointAfterValue(privateOp.getVar());
- mlir::Value alloca = BufferizeInterface::placeInMemory(
- builder, loc, privateOp.getVar());
- privateOp.getVarMutable().assign(alloca);
- privateOp.getAccVar().setType(alloca.getType());
- });
+ mlir::Location loc = op->getLoc();
+
+ mlir::OpBuilder builder(op);
+ builder.setInsertionPointAfterValue(op.getVar());
+ mlir::Value alloca =
+ BufferizeInterface::placeInMemory(builder, loc, op.getVar());
+ op.getVarMutable().assign(alloca);
+ op.getAccVar().setType(alloca.getType());
+ mlir::Value oldRes = op.getAccVar();
llvm::SmallVector<mlir::Operation *> users(oldRes.getUsers().begin(),
oldRes.getUsers().end());
for (mlir::Operation *useOp : users) {
@@ -166,18 +162,15 @@ public:
.Case<mlir::acc::LoopOp, mlir::acc::ParallelOp, mlir::acc::SerialOp>(
[&](auto computeOp) {
for (llvm::StringRef recipeName : recipeNames) {
- if (computeOp.getPrivatizationRecipes())
- updateRecipeUse(computeOp.getPrivatizationRecipesAttr(),
- computeOp.getPrivateOperands(), recipeName,
- op);
- if (computeOp.getFirstprivatizationRecipes())
- updateRecipeUse(
- computeOp.getFirstprivatizationRecipesAttr(),
+ if (!computeOp.getPrivateOperands().empty())
+ updateRecipeUse<mlir::acc::PrivateOp>(
+ computeOp.getPrivateOperands(), recipeName, op);
+ if (!computeOp.getFirstprivateOperands().empty())
+ updateRecipeUse<mlir::acc::FirstprivateOp>(
computeOp.getFirstprivateOperands(), recipeName, op);
- if (computeOp.getReductionRecipes())
- updateRecipeUse(computeOp.getReductionRecipesAttr(),
- computeOp.getReductionOperands(),
- recipeName, op);
+ if (!computeOp.getReductionOperands().empty())
+ updateRecipeUse<mlir::acc::ReductionOp>(
+ computeOp.getReductionOperands(), recipeName, op);
}
});
});
diff --git a/flang/lib/Optimizer/OpenACC/Transforms/ACCUseDeviceCanonicalizer.cpp b/flang/lib/Optimizer/OpenACC/Transforms/ACCUseDeviceCanonicalizer.cpp
new file mode 100644
index 0000000..51ab7960
--- /dev/null
+++ b/flang/lib/Optimizer/OpenACC/Transforms/ACCUseDeviceCanonicalizer.cpp
@@ -0,0 +1,400 @@
+//===- ACCUseDeviceCanonicalizer.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass canonicalizes the use_device clause on a host_data construct such
+// that use_device(x) can be lowered to a simple runtime call that takes the
+// actual host pointer as argument.
+//
+// For a use_device operand that is a box type or a reference to a box, the
+// pass:
+// 1. Extracts the host base address for mapping to a device address using
+// acc.use_device.
+// 2. Creates a new boxed descriptor with the device address as the base
+// address for use inside the host_data region.
+//
+// The pass also removes unused use_device clauses, reducing the number of
+// runtime calls.
+//
+// Supported use_device operand types:
+//
+// Scalars:
+// - !fir.ref<i32>, !fir.ref<f64>, etc.
+//
+// Arrays:
+// - Explicit shape (no descriptor): !fir.ref<!fir.array<100xi32>>
+// - Adjustable size: !fir.ref<!fir.array<?xi32>>
+// - Assumed shape (handled by hoistBox): !fir.box<!fir.array<?xi32>>
+// - Assumed size: !fir.ref<!fir.array<?xi32>>
+// - Deferred shape (handled by hoistRefToBox):
+// - Allocatable: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// - Pointer: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
+// - Subarray specification (handled by hoistBox):
+// !fir.box<!fir.array<?xi32>>
+//
+// Not yet supported:
+// - Assumed rank arrays
+// - Composite variables: !fir.ref<!fir.type<...>>
+// - Array elements (device pointer arithmetic in host_data region)
+// - Composite variable members
+// - Fortran common blocks: use_device(/cm_block/)
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/OpenACC/Passes.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+#include <cassert>
+
+namespace fir::acc {
+#define GEN_PASS_DEF_ACCUSEDEVICECANONICALIZER
+#include "flang/Optimizer/OpenACC/Passes.h.inc"
+} // namespace fir::acc
+
+#define DEBUG_TYPE "acc-use-device-canonicalizer"
+
+using namespace mlir;
+
+namespace {
+
+struct UseDeviceHostDataHoisting : public OpRewritePattern<acc::HostDataOp> {
+ using OpRewritePattern<acc::HostDataOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(acc::HostDataOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Value> usedOperands;
+ SmallVector<Value> unusedUseDeviceOperands;
+ SmallVector<acc::UseDeviceOp> refToBoxUseDeviceOps;
+ SmallVector<acc::UseDeviceOp> boxUseDeviceOps;
+
+ for (Value operand : op.getDataClauseOperands()) {
+ if (acc::UseDeviceOp useDeviceOp =
+ operand.getDefiningOp<acc::UseDeviceOp>()) {
+ if (fir::isBoxAddress(useDeviceOp.getVar().getType())) {
+ if (!llvm::hasSingleElement(useDeviceOp->getUsers()))
+ refToBoxUseDeviceOps.push_back(useDeviceOp);
+ } else if (isa<fir::BoxType>(useDeviceOp.getVar().getType())) {
+ if (!llvm::hasSingleElement(useDeviceOp->getUsers()))
+ boxUseDeviceOps.push_back(useDeviceOp);
+ }
+
+ // host_data is the only user of this use_device operand - mark for
+ // removal
+ if (llvm::hasSingleElement(useDeviceOp->getUsers()))
+ unusedUseDeviceOperands.push_back(useDeviceOp.getResult());
+ else
+ usedOperands.push_back(useDeviceOp.getResult());
+ } else {
+ // Operand is not an `acc.use_device` result, keep it as is.
+ usedOperands.push_back(operand);
+ }
+ }
+
+ assert(!usedOperands.empty() && "Host_data operation has no used operands");
+
+ if (!unusedUseDeviceOperands.empty()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "ACCUseDeviceCanonicalizer: Removing "
+ << unusedUseDeviceOperands.size()
+ << " unused use_device operands from host_data operation\n");
+
+ // Update the host_data operation to have only used operands
+ rewriter.modifyOpInPlace(op, [&]() {
+ op.getDataClauseOperandsMutable().assign(usedOperands);
+ });
+
+ // Remove unused use_device operations
+ for (Value operand : unusedUseDeviceOperands) {
+ acc::UseDeviceOp useDeviceOp =
+ operand.getDefiningOp<acc::UseDeviceOp>();
+ LLVM_DEBUG(llvm::dbgs() << "ACCUseDeviceCanonicalizer: Erasing: "
+ << *useDeviceOp << "\n");
+ rewriter.eraseOp(useDeviceOp);
+ }
+ return success();
+ }
+
+ // Handle references to box types
+ bool modified = false;
+ for (acc::UseDeviceOp useDeviceOp : refToBoxUseDeviceOps)
+ modified |=
+ hoistRefToBox(rewriter, useDeviceOp.getResult(), useDeviceOp, op);
+
+ // Handle box types
+ for (acc::UseDeviceOp useDeviceOp : boxUseDeviceOps)
+ modified |= hoistBox(rewriter, useDeviceOp.getResult(), useDeviceOp, op);
+
+ return modified ? success() : failure();
+ }
+
+private:
+ /// Collect users of `acc.use_device` operation inside the `acc.host_data`
+ /// region that need to be updated with the final replacement value.
+ void collectUseDeviceUsersToUpdate(
+ acc::UseDeviceOp useDeviceOp, acc::HostDataOp hostDataOp,
+ SmallVectorImpl<Operation *> &usersToUpdate) const {
+ for (mlir::Operation *user : useDeviceOp->getUsers())
+ if (hostDataOp.getRegion().isAncestor(user->getParentRegion()))
+ usersToUpdate.push_back(user);
+ }
+
+ /// Create new `acc.use_device` operation with the given box address as
+ /// operand. Updates the `acc.host_data` operation to use the new
+ /// `acc.use_device` result.
+ acc::UseDeviceOp createNewUseDeviceOp(PatternRewriter &rewriter,
+ acc::UseDeviceOp useDeviceOp,
+ acc::HostDataOp hostDataOp,
+ fir::BoxAddrOp boxAddr) const {
+ // Create use_device on the raw pointer
+ acc::UseDeviceOp newUseDeviceOp = acc::UseDeviceOp::create(
+ rewriter, useDeviceOp.getLoc(), boxAddr.getType(), boxAddr.getResult(),
+ useDeviceOp.getVarTypeAttr(), useDeviceOp.getVarPtrPtr(),
+ useDeviceOp.getBounds(), useDeviceOp.getAsyncOperands(),
+ useDeviceOp.getAsyncOperandsDeviceTypeAttr(),
+ useDeviceOp.getAsyncOnlyAttr(), useDeviceOp.getDataClauseAttr(),
+ useDeviceOp.getStructuredAttr(), useDeviceOp.getImplicitAttr(),
+ useDeviceOp.getModifiersAttr(), useDeviceOp.getNameAttr(),
+ useDeviceOp.getRecipeAttr());
+
+ LLVM_DEBUG(llvm::dbgs() << "Created new hoisted pattern for box access:\n"
+ << " box_addr: " << *boxAddr << "\n"
+ << " new use_device: " << *newUseDeviceOp << "\n");
+
+ // Replace the old `acc.use_device` operand in the `acc.host_data` operation
+ // with the new one
+ rewriter.modifyOpInPlace(hostDataOp, [&]() {
+ hostDataOp->replaceUsesOfWith(useDeviceOp.getResult(),
+ newUseDeviceOp.getResult());
+ });
+
+ return newUseDeviceOp;
+ }
+
+ /// Canonicalize use_device operand that is a reference to a box.
+ /// Transforms:
+ /// %3 = fir.address_of(@_QFEtgt) : !fir.ref<i32>
+ /// %5 = fir.embox %3 : (!fir.ref<i32>) -> !fir.box<!fir.ptr<i32>>
+ /// fir.store %5 to %0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
+ /// %9 = acc.use_device varPtr(%0 : !fir.ref<!fir.box<!fir.ptr<i32>>>)
+ /// -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "ptr"}
+ /// acc.host_data dataOperands(%9 : !fir.ref<!fir.box<!fir.ptr<i32>>>) {
+ /// %loaded = fir.load %9 : !fir.ref<!fir.box<!fir.ptr<i32>>>
+ /// %addr = fir.box_addr %loaded : (!fir.box<!fir.ptr<i32>>) ->
+ /// !fir.ptr<i32> %conv = fir.convert %addr : (!fir.ptr<i32>) -> i64
+ /// fir.call @foo(%conv) : (i64) -> ()
+ /// acc.terminator
+ /// }
+ /// into:
+ /// %loaded = fir.load %0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
+ /// %addr = fir.box_addr %loaded : (!fir.box<!fir.ptr<i32>>) ->
+ /// !fir.ptr<i32>
+ /// %dev_ptr = acc.use_device varPtr(%addr : !fir.ptr<i32>) ->
+ /// !fir.ptr<i32>
+  ///   {name = "ptr"}
+  ///   acc.host_data dataOperands(%dev_ptr : !fir.ptr<i32>)
+ /// {
+ /// %embox = fir.embox %dev_ptr : (!fir.ptr<i32>) ->
+ /// !fir.box<!fir.ptr<i32>> %alloca = fir.alloca !fir.box<!fir.ptr<i32>>
+ /// fir.store %embox to %alloca : !fir.ref<!fir.box<!fir.ptr<i32>>>
+ /// %loaded2 = fir.load %alloca : !fir.ref<!fir.box<!fir.ptr<i32>>>
+ /// %addr2 = fir.box_addr %loaded2 : (!fir.box<!fir.ptr<i32>>) ->
+ /// !fir.ptr<i32> %conv = fir.convert %addr2 : (!fir.ptr<i32>) -> i64
+ /// fir.call @foo(%conv) : (i64) -> ()
+ /// acc.terminator
+ /// }
+ bool hoistRefToBox(PatternRewriter &rewriter, Value operand,
+ acc::UseDeviceOp useDeviceOp,
+ acc::HostDataOp hostDataOp) const {
+
+ // Safety check: if the use_device operation is already using a box_addr
+ // result, it means it has already been processed, so skip to avoid infinite
+ // loop
+ if (useDeviceOp.getVar().getDefiningOp<fir::BoxAddrOp>()) {
+ LLVM_DEBUG(llvm::dbgs() << "ACCUseDeviceCanonicalizer: Skipping "
+ "already processed use_device operation\n");
+ return false;
+ }
+ // Get the ModuleOp before we erase useDeviceOp to avoid invalid reference
+ ModuleOp mod = useDeviceOp->getParentOfType<ModuleOp>();
+
+ // Collect users of the original `acc.use_device` operation that need to be
+ // updated
+ SmallVector<Operation *> usersToUpdate;
+ collectUseDeviceUsersToUpdate(useDeviceOp, hostDataOp, usersToUpdate);
+
+ rewriter.setInsertionPoint(useDeviceOp);
+ // Create a load operation to get the box from the variable
+ fir::LoadOp box = fir::LoadOp::create(rewriter, useDeviceOp.getLoc(),
+ useDeviceOp.getVar());
+ // Create a box_addr operation to get the address from the box
+ fir::BoxAddrOp boxAddr =
+ fir::BoxAddrOp::create(rewriter, useDeviceOp.getLoc(), box);
+
+ acc::UseDeviceOp newUseDeviceOp =
+ createNewUseDeviceOp(rewriter, useDeviceOp, hostDataOp, boxAddr);
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "Created new hoisted pattern for pointer access:\n"
+ << " load box: " << *box << "\n"
+ << " box_addr: " << *boxAddr << "\n"
+ << " new use_device: " << *newUseDeviceOp << "\n");
+
+ // Set insertion point to the first op inside the host_data region
+ rewriter.setInsertionPoint(&hostDataOp.getRegion().front().front());
+
+ // Create a FirOpBuilder from the PatternRewriter using the module we got
+ // earlier
+ fir::FirOpBuilder builder(rewriter, mod);
+ Value newBoxwithDevicePtr = fir::factory::getDescriptorWithNewBaseAddress(
+ builder, useDeviceOp.getLoc(), box.getResult(),
+ newUseDeviceOp.getResult());
+
+ // Create new memory location and store the newBoxwithDevicePtr into new
+ // memory location
+ fir::AllocaOp newMemLoc = fir::AllocaOp::create(
+ rewriter, useDeviceOp.getLoc(), newBoxwithDevicePtr.getType());
+ [[maybe_unused]] fir::StoreOp newStoreOp = fir::StoreOp::create(
+ rewriter, useDeviceOp.getLoc(), newBoxwithDevicePtr, newMemLoc);
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "host_data region updated with new host descriptor "
+ "containing device pointer:\n"
+ << " box with device pointer: "
+ << *newBoxwithDevicePtr.getDefiningOp() << "\n"
+ << " mem loc: " << *newMemLoc << "\n"
+ << " store op: " << *newStoreOp << "\n");
+
+ // Replace all uses of the original `acc.use_device` operation inside the
+ // `acc.host_data` region with the new memory location containing the box
+ // with device pointer
+ for (mlir::Operation *user : usersToUpdate)
+ user->replaceUsesOfWith(useDeviceOp.getResult(), newMemLoc);
+
+ assert(useDeviceOp.getResult().use_empty() &&
+ "expected all uses of use_device to be replaced");
+ rewriter.eraseOp(useDeviceOp);
+ return true;
+ }
+
+ /// Canonicalize use_device operand that is a box type.
+ /// Transforms:
+ /// %box = ... : !fir.box<!fir.array<?xi32>>
+ /// %dev_box = acc.use_device varPtr(%box : !fir.box<!fir.array<?xi32>>)
+ /// -> !fir.box<!fir.array<?xi32>>
+ /// acc.host_data dataOperands(%dev_box : !fir.box<!fir.array<?xi32>>) {
+ /// %addr = fir.box_addr %dev_box : (!fir.box<!fir.array<?xi32>>) ->
+ /// !fir.heap<!fir.array<?xi32>>
+ /// // use %addr
+ /// }
+ /// into:
+ /// %box = ... : !fir.box<!fir.array<?xi32>>
+ /// %addr = fir.box_addr %box : (!fir.box<!fir.array<?xi32>>) ->
+ /// !fir.heap<!fir.array<?xi32>>
+ /// %dev_ptr = acc.use_device varPtr(%addr : !fir.heap<!fir.array<?xi32>>)
+ /// -> !fir.heap<!fir.array<?xi32>>
+ /// acc.host_data dataOperands(%dev_ptr : !fir.heap<!fir.array<?xi32>>) {
+ /// %new_box = fir.embox %dev_ptr ... : !fir.box<!fir.array<?xi32>>
+ /// %new_addr = fir.box_addr %new_box : (!fir.box<!fir.array<?xi32>>) ->
+ /// !fir.heap<!fir.array<?xi32>>
+ /// // use %new_addr instead of %addr
+ /// }
+ bool hoistBox(PatternRewriter &rewriter, Value operand,
+ acc::UseDeviceOp useDeviceOp,
+ acc::HostDataOp hostDataOp) const {
+
+ // Safety check: if the use_device operation is already using a box_addr
+ // result, it means it has already been processed, so skip to avoid infinite
+ // loop
+ if (useDeviceOp.getVar().getDefiningOp<fir::BoxAddrOp>()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "ACCUseDeviceCanonicalizer: Skipping "
+ "already processed box use_device operation\n");
+ return false;
+ }
+
+ // Collect users of the original `acc.use_device` operation that need to be
+ // updated
+ SmallVector<Operation *> usersToUpdate;
+ collectUseDeviceUsersToUpdate(useDeviceOp, hostDataOp, usersToUpdate);
+
+ // Get the ModuleOp before we erase useDeviceOp to avoid invalid reference
+ ModuleOp mod = useDeviceOp->getParentOfType<ModuleOp>();
+
+ rewriter.setInsertionPoint(useDeviceOp);
+ // Extract the raw pointer from the box descriptor
+ fir::BoxAddrOp boxAddr = fir::BoxAddrOp::create(
+ rewriter, useDeviceOp.getLoc(), useDeviceOp.getVar());
+
+ acc::UseDeviceOp newUseDeviceOp =
+ createNewUseDeviceOp(rewriter, useDeviceOp, hostDataOp, boxAddr);
+
+ // Set insertion point to the first op inside the host_data region
+ rewriter.setInsertionPoint(&hostDataOp.getRegion().front().front());
+
+ // Create a FirOpBuilder from the PatternRewriter using the module we got
+ // earlier
+ fir::FirOpBuilder builder(rewriter, mod);
+
+ // Create a new host descriptor at the start of the host_data region
+ // with the device pointer as the base address
+ Value newBoxWithDevicePtr = fir::factory::getDescriptorWithNewBaseAddress(
+ builder, useDeviceOp.getLoc(), useDeviceOp.getVar(),
+ newUseDeviceOp.getResult());
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "host_data region updated with new host descriptor "
+ "containing device pointer:\n"
+ << " box with device pointer: "
+ << *newBoxWithDevicePtr.getDefiningOp() << "\n");
+
+ // Replace all uses of the original `acc.use_device` operation inside the
+ // `acc.host_data` region with the new box containing device pointer
+ for (mlir::Operation *user : usersToUpdate)
+ user->replaceUsesOfWith(useDeviceOp.getResult(), newBoxWithDevicePtr);
+
+ assert(useDeviceOp.getResult().use_empty() &&
+ "expected all uses of use_device to be replaced");
+ rewriter.eraseOp(useDeviceOp);
+ return true;
+ }
+};
+
+class ACCUseDeviceCanonicalizer
+ : public fir::acc::impl::ACCUseDeviceCanonicalizerBase<
+ ACCUseDeviceCanonicalizer> {
+public:
+ void runOnOperation() override {
+ MLIRContext *context = getOperation()->getContext();
+
+ RewritePatternSet patterns(context);
+
+ // Add the custom use_device canonicalization patterns
+ patterns.insert<UseDeviceHostDataHoisting>(context);
+
+ // Apply patterns greedily
+ GreedyRewriteConfig config;
+ // Prevent the pattern driver from merging blocks.
+ config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled);
+ config.setUseTopDownTraversal(true);
+
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
+ }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::acc::createACCUseDeviceCanonicalizerPass() {
+ return std::make_unique<ACCUseDeviceCanonicalizer>();
+}
diff --git a/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt
index ed177ba..27c5ee6 100644
--- a/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt
@@ -1,14 +1,25 @@
add_flang_library(FIROpenACCTransforms
+ ACCInitializeFIRAnalyses.cpp
+ ACCOptimizeFirstprivateMap.cpp
ACCRecipeBufferization.cpp
+ ACCUseDeviceCanonicalizer.cpp
DEPENDS
FIROpenACCPassesIncGen
LINK_LIBS
+ FIRAnalysis
+ FIRBuilder
FIRDialect
+ FIRDialectSupport
+ FIROpenACCAnalysis
+ FIROpenACCSupport
+ HLFIRDialect
MLIR_LIBS
MLIRIR
MLIRPass
MLIROpenACCDialect
+ MLIROpenACCUtils
+ MLIRTransformUtils
)
diff --git a/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp b/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp
index 8b99913..5793d46 100644
--- a/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp
+++ b/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp
@@ -20,8 +20,6 @@
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
-#include "llvm/Frontend/OpenMP/OMPConstants.h"
-
namespace flangomp {
#define GEN_PASS_DEF_AUTOMAPTOTARGETDATAPASS
#include "flang/Optimizer/OpenMP/Passes.h.inc"
@@ -120,12 +118,9 @@ class AutomapToTargetDataPass
builder, memOp.getLoc(), memOp.getMemref().getType(),
memOp.getMemref(),
TypeAttr::get(fir::unwrapRefType(memOp.getMemref().getType())),
- builder.getIntegerAttr(
- builder.getIntegerType(64, false),
- static_cast<unsigned>(
- isa<fir::StoreOp>(memOp)
- ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
- : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)),
+ builder.getAttr<omp::ClauseMapFlagsAttr>(
+ isa<fir::StoreOp>(memOp) ? omp::ClauseMapFlags::to
+ : omp::ClauseMapFlags::del),
builder.getAttr<omp::VariableCaptureKindAttr>(
omp::VariableCaptureKind::ByCopy),
/*var_ptr_ptr=*/mlir::Value{},
@@ -135,8 +130,8 @@ class AutomapToTargetDataPass
builder.getBoolAttr(false));
clauses.mapVars.push_back(mapInfo);
isa<fir::StoreOp>(memOp)
- ? builder.create<omp::TargetEnterDataOp>(memOp.getLoc(), clauses)
- : builder.create<omp::TargetExitDataOp>(memOp.getLoc(), clauses);
+ ? omp::TargetEnterDataOp::create(builder, memOp.getLoc(), clauses)
+ : omp::TargetExitDataOp::create(builder, memOp.getLoc(), clauses);
};
for (fir::GlobalOp globalOp : automapGlobals) {
diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
index 03ff163..ff346e7 100644
--- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
@@ -22,7 +22,6 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SmallPtrSet.h"
-#include "llvm/Frontend/OpenMP/OMPConstants.h"
namespace flangomp {
#define GEN_PASS_DEF_DOCONCURRENTCONVERSIONPASS
@@ -484,6 +483,8 @@ private:
}
loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
+ loopNestClauseOps.collapseNumLoops =
+ rewriter.getI64IntegerAttr(loopNestClauseOps.loopLowerBounds.size());
}
std::pair<mlir::omp::LoopNestOp, mlir::omp::WsloopOp>
@@ -568,16 +569,15 @@ private:
if (auto refType = mlir::dyn_cast<fir::ReferenceType>(liveInType))
eleType = refType.getElementType();
- llvm::omp::OpenMPOffloadMappingFlags mapFlag =
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
+ mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::implicit;
mlir::omp::VariableCaptureKind captureKind =
mlir::omp::VariableCaptureKind::ByRef;
if (fir::isa_trivial(eleType) || fir::isa_char(eleType)) {
captureKind = mlir::omp::VariableCaptureKind::ByCopy;
} else if (!fir::isa_builtin_cptr_type(eleType)) {
- mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
- mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+ mapFlag |= mlir::omp::ClauseMapFlags::to;
+ mapFlag |= mlir::omp::ClauseMapFlags::from;
}
llvm::SmallVector<mlir::Value> boundsOps;
@@ -587,11 +587,8 @@ private:
builder, liveIn.getLoc(), rawAddr,
/*varPtrPtr=*/{}, name.str(), boundsOps,
/*members=*/{},
- /*membersIndex=*/mlir::ArrayAttr{},
- static_cast<
- std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
- mapFlag),
- captureKind, rawAddr.getType());
+ /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind,
+ rawAddr.getType());
}
mlir::omp::TargetOp
@@ -600,7 +597,7 @@ private:
mlir::omp::TargetOperands &clauseOps,
mlir::omp::LoopNestOperands &loopNestClauseOps,
const LiveInShapeInfoMap &liveInShapeInfoMap) const {
- auto targetOp = rewriter.create<mlir::omp::TargetOp>(loc, clauseOps);
+ auto targetOp = mlir::omp::TargetOp::create(rewriter, loc, clauseOps);
auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp);
mlir::Region &region = targetOp.getRegion();
@@ -677,7 +674,7 @@ private:
// temporary.
Fortran::utils::openmp::cloneOrMapRegionOutsiders(builder, targetOp);
rewriter.setInsertionPoint(
- rewriter.create<mlir::omp::TerminatorOp>(targetOp.getLoc()));
+ mlir::omp::TerminatorOp::create(rewriter, targetOp.getLoc()));
return targetOp;
}
@@ -697,9 +694,6 @@ private:
if (!targetShapeCreationInfo.isShapedValue())
return {};
- llvm::SmallVector<mlir::Value> extentOperands;
- llvm::SmallVector<mlir::Value> startIndexOperands;
-
if (targetShapeCreationInfo.isShapeShiftedValue()) {
llvm::SmallVector<mlir::Value> shapeShiftOperands;
@@ -720,8 +714,8 @@ private:
auto shapeShiftType = fir::ShapeShiftType::get(
builder.getContext(), shapeShiftOperands.size() / 2);
- return builder.create<fir::ShapeShiftOp>(
- liveInArg.getLoc(), shapeShiftType, shapeShiftOperands);
+ return fir::ShapeShiftOp::create(builder, liveInArg.getLoc(),
+ shapeShiftType, shapeShiftOperands);
}
llvm::SmallVector<mlir::Value> shapeOperands;
@@ -733,11 +727,11 @@ private:
++shapeIdx;
}
- return builder.create<fir::ShapeOp>(liveInArg.getLoc(), shapeOperands);
+ return fir::ShapeOp::create(builder, liveInArg.getLoc(), shapeOperands);
}();
- return builder.create<hlfir::DeclareOp>(liveInArg.getLoc(), liveInArg,
- liveInName, shape);
+ return hlfir::DeclareOp::create(builder, liveInArg.getLoc(), liveInArg,
+ liveInName, shape);
}
mlir::omp::TeamsOp genTeamsOp(mlir::ConversionPatternRewriter &rewriter,
@@ -747,13 +741,13 @@ private:
genReductions(rewriter, mapper, loop, teamsOps);
mlir::Location loc = loop.getLoc();
- auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(loc, teamsOps);
+ auto teamsOp = mlir::omp::TeamsOp::create(rewriter, loc, teamsOps);
Fortran::common::openmp::EntryBlockArgs teamsArgs;
teamsArgs.reduction.vars = teamsOps.reductionVars;
Fortran::common::openmp::genEntryBlock(rewriter, teamsArgs,
teamsOp.getRegion());
- rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
+ rewriter.setInsertionPoint(mlir::omp::TerminatorOp::create(rewriter, loc));
for (auto [loopVar, teamsArg] : llvm::zip_equal(
loop.getReduceVars(), teamsOp.getRegion().getArguments())) {
@@ -766,8 +760,8 @@ private:
mlir::omp::DistributeOp
genDistributeOp(mlir::Location loc,
mlir::ConversionPatternRewriter &rewriter) const {
- auto distOp = rewriter.create<mlir::omp::DistributeOp>(
- loc, /*clauses=*/mlir::omp::DistributeOperands{});
+ auto distOp = mlir::omp::DistributeOp::create(
+ rewriter, loc, /*clauses=*/mlir::omp::DistributeOperands{});
rewriter.createBlock(&distOp.getRegion());
return distOp;
@@ -856,7 +850,8 @@ private:
if (!ompReducer) {
ompReducer = mlir::omp::DeclareReductionOp::create(
rewriter, firReducer.getLoc(), ompReducerName,
- firReducer.getTypeAttr().getValue());
+ firReducer.getTypeAttr().getValue(),
+ firReducer.getByrefElementTypeAttr());
cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
ompReducer.getAllocRegion());
diff --git a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
index 3031bb5..0acee89 100644
--- a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
+++ b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
@@ -11,15 +11,19 @@
//
//===----------------------------------------------------------------------===//
+#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/OpenMP/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
namespace flangomp {
#define GEN_PASS_DEF_FUNCTIONFILTERINGPASS
@@ -28,6 +32,77 @@ namespace flangomp {
using namespace mlir;
+/// This function triggers TODO errors and halts compilation if it detects
+/// patterns representing unimplemented features.
+///
+/// It exclusively checks situations that cannot be detected after all of the
+/// MLIR pipeline has run (i.e. at the MLIR to LLVM IR translation stage, where
+/// the preferred location for these types of checks is), and it only checks for
+/// features that have not been implemented for target offload, but are
+/// supported on host execution.
+static void
+checkDeviceImplementationStatus(omp::OffloadModuleInterface offloadModule) {
+ if (!offloadModule.getIsGPU())
+ return;
+
+ offloadModule->walk<WalkOrder::PreOrder>([&](omp::DeclareReductionOp redOp) {
+ if (redOp.symbolKnownUseEmpty(offloadModule))
+ return WalkResult::advance();
+
+ if (!redOp.getByrefElementType())
+ return WalkResult::advance();
+
+ auto seqTy =
+ mlir::dyn_cast<fir::SequenceType>(*redOp.getByrefElementType());
+
+ bool isByRefReductionSupported =
+ !seqTy || !fir::sequenceWithNonConstantShape(seqTy);
+
+ if (!isByRefReductionSupported) {
+ TODO(redOp.getLoc(),
+ "Reduction of dynamically-shaped arrays are not supported yet "
+ "on the GPU.");
+ }
+
+ return WalkResult::advance();
+ });
+}
+
+/// Add an operation to one of the output sets to be later rewritten.
+template <typename OpTy>
+static void collectRewrite(OpTy op, llvm::SetVector<OpTy> &rewrites) {
+ rewrites.insert(op);
+}
+
+/// Add an \c omp.map.info operation and all its members recursively to the
+/// output set to be later rewritten.
+///
+/// Dependencies across \c omp.map.info are maintained by ensuring dependencies
+/// are added to the output sets before operations based on them.
+template <>
+void collectRewrite(omp::MapInfoOp mapOp,
+ llvm::SetVector<omp::MapInfoOp> &rewrites) {
+ for (Value member : mapOp.getMembers())
+ collectRewrite(cast<omp::MapInfoOp>(member.getDefiningOp()), rewrites);
+
+ rewrites.insert(mapOp);
+}
+
+/// Add the given value to a sorted set if it should be replaced by a
+/// placeholder when used as an operand that must remain for the device.
+///
+/// Values that are block arguments of \c func.func operations are skipped,
+/// since they will still be available after all rewrites are completed.
+static void collectRewrite(Value value, llvm::SetVector<Value> &rewrites) {
+ if ((isa<BlockArgument>(value) &&
+ isa<func::FuncOp>(
+ cast<BlockArgument>(value).getOwner()->getParentOp())) ||
+ rewrites.contains(value))
+ return;
+
+ rewrites.insert(value);
+}
+
namespace {
class FunctionFilteringPass
: public flangomp::impl::FunctionFilteringPassBase<FunctionFilteringPass> {
@@ -90,10 +165,17 @@ public:
// Remove the callOp
callOp->erase();
}
+
if (!hasTargetRegion) {
funcOp.erase();
return WalkResult::skip();
}
+
+ if (failed(rewriteHostFunction(funcOp))) {
+ funcOp.emitOpError() << "could not be rewritten for target device";
+ return WalkResult::interrupt();
+ }
+
if (declareTargetOp)
declareTargetOp.setDeclareTarget(
declareType, omp::DeclareTargetCaptureClause::to,
@@ -101,6 +183,311 @@ public:
}
return WalkResult::advance();
});
+
+ checkDeviceImplementationStatus(op);
+ }
+
+private:
+ /// Rewrite the given host device function containing \c omp.target
+ /// operations, to remove host-only operations that are not used by device
+ /// codegen.
+ ///
+ /// It is based on the expected form of the MLIR module as produced by Flang
+ /// lowering and it performs the following mutations:
+ /// - Replace all values returned by the function with \c fir.undefined.
+ /// - \c omp.target operations are moved to the end of the function. If they
+ /// are nested inside of any other operations, they are hoisted out of
+ /// them.
+///   - \c depend, \c device and \c if clauses are removed from these target
+///     operations. Values used to initialize other clauses are replaced by
+ /// placeholders as follows:
+ /// - Values defined by block arguments are replaced by placeholders only
+ /// if they are not attached to the parent \c func.func operation. In
+ /// that case, they are passed unmodified.
+ /// - \c arith.constant and \c fir.address_of ops are maintained.
+ /// - Values of type \c fir.boxchar are replaced with a combination of
+ /// \c fir.alloca for a single bit and a \c fir.emboxchar.
+ /// - Other values are replaced by a combination of an \c fir.alloca for a
+ /// single bit and an \c fir.convert to the original type of the value.
+ /// This can be done because the code eventually generated for these
+ /// operations will be discarded, as they aren't runnable by the target
+ /// device.
+ /// - \c omp.map.info operations associated to these target regions are
+ /// preserved. These are moved above all \c omp.target and sorted to
+ /// satisfy dependencies among them.
+ /// - \c bounds arguments are removed from \c omp.map.info operations.
+ /// - \c var_ptr and \c var_ptr_ptr arguments of \c omp.map.info are
+ /// handled as follows:
+ /// - \c var_ptr_ptr is expected to be defined by a \c fir.box_offset
+ /// operation which is preserved. Otherwise, the pass will fail.
+ /// - \c var_ptr can be defined by an \c hlfir.declare which is also
+ /// preserved. Its \c memref argument is replaced by a placeholder or
+ /// maintained, similarly to non-map clauses of target operations
+ /// described above. If it has \c shape or \c typeparams arguments, they
+ /// are replaced by applicable constants. \c dummy_scope arguments
+ /// are discarded.
+ /// - Every other operation not located inside of an \c omp.target is
+ /// removed.
+ LogicalResult rewriteHostFunction(func::FuncOp funcOp) {
+ Region &region = funcOp.getRegion();
+
+ // Collect target operations inside of the function.
+ llvm::SmallVector<omp::TargetOp> targetOps;
+ region.walk<WalkOrder::PreOrder>([&](Operation *op) {
+ // Skip the inside of omp.target regions, since these contain device code.
+ if (auto targetOp = dyn_cast<omp::TargetOp>(op)) {
+ targetOps.push_back(targetOp);
+ return WalkResult::skip();
+ }
+
+ // Replace omp.target_data entry block argument uses with the value used
+ // to initialize the associated omp.map.info operation. This way,
+ // references are still valid once the omp.target operation has been
+ // extracted out of the omp.target_data region.
+ if (auto targetDataOp = dyn_cast<omp::TargetDataOp>(op)) {
+ llvm::SmallVector<std::pair<Value, BlockArgument>> argPairs;
+ cast<omp::BlockArgOpenMPOpInterface>(*targetDataOp)
+ .getBlockArgsPairs(argPairs);
+ for (auto [operand, blockArg] : argPairs) {
+ auto mapInfo = cast<omp::MapInfoOp>(operand.getDefiningOp());
+ Value varPtr = mapInfo.getVarPtr();
+
+ // If the var_ptr operand of the omp.map.info op defining this entry
+ // block argument is an hlfir.declare, the uses of all users of that
+ // entry block argument that are themselves hlfir.declare are replaced
+ // by values produced by the outer one.
+ //
+ // This prevents this pass from producing chains of hlfir.declare of
+ // the type:
+ // %0 = ...
+ // %1:2 = hlfir.declare %0
+ // %2:2 = hlfir.declare %1#1...
+ // %3 = omp.map.info var_ptr(%2#1 ...
+ if (auto outerDeclare = varPtr.getDefiningOp<hlfir::DeclareOp>())
+ for (Operation *user : blockArg.getUsers())
+ if (isa<hlfir::DeclareOp>(user))
+ user->replaceAllUsesWith(outerDeclare);
+
+ // All remaining uses of the entry block argument are replaced with
+ // the var_ptr initialization value.
+ blockArg.replaceAllUsesWith(varPtr);
+ }
+ }
+ return WalkResult::advance();
+ });
+
+ // Make a temporary clone of the parent operation with an empty region,
+ // and update all references to entry block arguments to those of the new
+ // region. Users will later either be moved to the new region or deleted
+ // when the original region is replaced by the new.
+ OpBuilder builder(&getContext());
+ builder.setInsertionPointAfter(funcOp);
+ Operation *newOp = builder.cloneWithoutRegions(funcOp);
+ Block &block = newOp->getRegion(0).emplaceBlock();
+
+ llvm::SmallVector<Location> locs;
+ locs.reserve(region.getNumArguments());
+ llvm::transform(region.getArguments(), std::back_inserter(locs),
+ [](const BlockArgument &arg) { return arg.getLoc(); });
+ block.addArguments(region.getArgumentTypes(), locs);
+
+ for (auto [oldArg, newArg] :
+ llvm::zip_equal(region.getArguments(), block.getArguments()))
+ oldArg.replaceAllUsesWith(newArg);
+
+ // Collect omp.map.info ops while satisfying interdependencies and remove
+ // operands that aren't used by target device codegen.
+ //
+ // This logic must be updated whenever operands to omp.target change.
+ llvm::SetVector<Value> rewriteValues;
+ llvm::SetVector<omp::MapInfoOp> mapInfos;
+ for (omp::TargetOp targetOp : targetOps) {
+ assert(targetOp.getHostEvalVars().empty() &&
+ "unexpected host_eval in target device module");
+
+ // Variables unused by the device.
+ targetOp.getDependVarsMutable().clear();
+ targetOp.setDependKindsAttr(nullptr);
+ targetOp.getDeviceMutable().clear();
+ targetOp.getIfExprMutable().clear();
+
+ // TODO: Clear some of these operands rather than rewriting them,
+ // depending on whether they are needed by device codegen once support for
+ // them is fully implemented.
+ for (Value allocVar : targetOp.getAllocateVars())
+ collectRewrite(allocVar, rewriteValues);
+ for (Value allocVar : targetOp.getAllocatorVars())
+ collectRewrite(allocVar, rewriteValues);
+ for (Value inReduction : targetOp.getInReductionVars())
+ collectRewrite(inReduction, rewriteValues);
+ for (Value isDevPtr : targetOp.getIsDevicePtrVars())
+ collectRewrite(isDevPtr, rewriteValues);
+ for (Value mapVar : targetOp.getHasDeviceAddrVars())
+ collectRewrite(cast<omp::MapInfoOp>(mapVar.getDefiningOp()), mapInfos);
+ for (Value mapVar : targetOp.getMapVars())
+ collectRewrite(cast<omp::MapInfoOp>(mapVar.getDefiningOp()), mapInfos);
+ for (Value privateVar : targetOp.getPrivateVars())
+ collectRewrite(privateVar, rewriteValues);
+ for (Value threadLimit : targetOp.getThreadLimitVars())
+ collectRewrite(threadLimit, rewriteValues);
+ }
+
+ // Move omp.map.info ops to the new block and collect dependencies.
+ llvm::SetVector<hlfir::DeclareOp> declareOps;
+ llvm::SetVector<fir::BoxOffsetOp> boxOffsets;
+ for (omp::MapInfoOp mapOp : mapInfos) {
+ if (auto declareOp = dyn_cast_if_present<hlfir::DeclareOp>(
+ mapOp.getVarPtr().getDefiningOp()))
+ collectRewrite(declareOp, declareOps);
+ else
+ collectRewrite(mapOp.getVarPtr(), rewriteValues);
+
+ if (Value varPtrPtr = mapOp.getVarPtrPtr()) {
+ if (auto boxOffset = llvm::dyn_cast_if_present<fir::BoxOffsetOp>(
+ varPtrPtr.getDefiningOp()))
+ collectRewrite(boxOffset, boxOffsets);
+ else
+ return mapOp->emitOpError() << "var_ptr_ptr rewrite only supported "
+ "if defined by fir.box_offset";
+ }
+
+ // Bounds are not used during target device codegen.
+ mapOp.getBoundsMutable().clear();
+ mapOp->moveBefore(&block, block.end());
+ }
+
+ // Create a temporary marker to simplify the op moving process below.
+ builder.setInsertionPointToStart(&block);
+ auto marker = fir::UndefOp::create(builder, builder.getUnknownLoc(),
+ builder.getNoneType());
+ builder.setInsertionPoint(marker);
+
+ // Handle dependencies of hlfir.declare ops.
+ for (hlfir::DeclareOp declareOp : declareOps) {
+ collectRewrite(declareOp.getMemref(), rewriteValues);
+
+ if (declareOp.getStorage())
+ collectRewrite(declareOp.getStorage(), rewriteValues);
+
+ // Shape and typeparams aren't needed for target device codegen, but
+ // removing them would break verifiers.
+ Value zero;
+ if (declareOp.getShape() || !declareOp.getTypeparams().empty())
+ zero = arith::ConstantOp::create(builder, declareOp.getLoc(),
+ builder.getI64IntegerAttr(0));
+
+ if (auto shape = declareOp.getShape()) {
+ // The pre-cg rewrite pass requires the shape to be defined by one of
+ // fir.shape, fir.shapeshift or fir.shift, so we need to make sure it's
+ // still defined by one of these after this pass.
+ Operation *shapeOp = shape.getDefiningOp();
+ llvm::SmallVector<Value> extents(shapeOp->getNumOperands(), zero);
+ Value newShape =
+ llvm::TypeSwitch<Operation *, Value>(shapeOp)
+ .Case([&](fir::ShapeOp op) {
+ return fir::ShapeOp::create(builder, op.getLoc(), extents);
+ })
+ .Case([&](fir::ShapeShiftOp op) {
+ auto type = fir::ShapeShiftType::get(op.getContext(),
+ extents.size() / 2);
+ return fir::ShapeShiftOp::create(builder, op.getLoc(), type,
+ extents);
+ })
+ .Case([&](fir::ShiftOp op) {
+ auto type =
+ fir::ShiftType::get(op.getContext(), extents.size());
+ return fir::ShiftOp::create(builder, op.getLoc(), type,
+ extents);
+ })
+ .Default([](Operation *op) {
+ op->emitOpError()
+ << "hlfir.declare shape expected to be one of: "
+ "fir.shape, fir.shapeshift or fir.shift";
+ return nullptr;
+ });
+
+ if (!newShape)
+ return failure();
+
+ declareOp.getShapeMutable().assign(newShape);
+ }
+
+ for (OpOperand &typeParam : declareOp.getTypeparamsMutable())
+ typeParam.assign(zero);
+
+ declareOp.getDummyScopeMutable().clear();
+ }
+
+ // We don't actually need the proper initialization, but rather just
+ // maintain the basic form of these operands. Generally, we create 1-bit
+ // placeholder allocas that we "typecast" to the expected type and replace
+ // all uses. Using fir.undefined here instead is not possible because these
+ // variables cannot be constants, as that would trigger different codegen
+ // for target regions.
+ for (Value value : rewriteValues) {
+ Location loc = value.getLoc();
+ Value rewriteValue;
+ if (isa_and_present<arith::ConstantOp, fir::AddrOfOp>(
+ value.getDefiningOp())) {
+ // If it's defined by fir.address_of, then we need to keep that op as
+ // well because it might be pointing to a 'declare target' global.
+ // Constants can also trigger different codegen paths, so we keep them
+ // as well.
+ rewriteValue = builder.clone(*value.getDefiningOp())->getResult(0);
+ } else if (auto boxCharType =
+ dyn_cast<fir::BoxCharType>(value.getType())) {
+ // !fir.boxchar types cannot be directly obtained by converting a
+ // !fir.ref<i1>, as they aren't reference types. Since they can appear
+ // representing some `target firstprivate` clauses, we need to create
+ // a special case here based on creating a placeholder fir.emboxchar op.
+ MLIRContext *ctx = &getContext();
+ fir::KindTy kind = boxCharType.getKind();
+ auto placeholder = fir::AllocaOp::create(
+ builder, loc, fir::CharacterType::getSingleton(ctx, kind));
+ auto one = arith::ConstantOp::create(builder, loc, builder.getI32Type(),
+ builder.getI32IntegerAttr(1));
+ rewriteValue = fir::EmboxCharOp::create(builder, loc, boxCharType,
+ placeholder, one);
+ } else {
+ Value placeholder =
+ fir::AllocaOp::create(builder, loc, builder.getI1Type());
+ rewriteValue =
+ fir::ConvertOp::create(builder, loc, value.getType(), placeholder);
+ }
+ value.replaceAllUsesWith(rewriteValue);
+ }
+
+ // Move omp.map.info dependencies.
+ for (hlfir::DeclareOp declareOp : declareOps)
+ declareOp->moveBefore(marker);
+
+ // The box_ref argument of fir.box_offset is expected to be the same value
+ // that was passed as var_ptr to the corresponding omp.map.info, so we don't
+ // need to handle its defining op here.
+ for (fir::BoxOffsetOp boxOffset : boxOffsets)
+ boxOffset->moveBefore(marker);
+
+ marker->erase();
+
+ // Move target operations to the end of the new block.
+ for (omp::TargetOp targetOp : targetOps)
+ targetOp->moveBefore(&block, block.end());
+
+ // Add terminator to the new block.
+ builder.setInsertionPointToEnd(&block);
+ llvm::SmallVector<Value> returnValues;
+ returnValues.reserve(funcOp.getNumResults());
+ for (auto type : funcOp.getResultTypes())
+ returnValues.push_back(
+ fir::UndefOp::create(builder, funcOp.getLoc(), type));
+
+ func::ReturnOp::create(builder, funcOp.getLoc(), returnValues);
+
+ // Replace old region (now missing ops) with the new one and remove the
+ // temporary operation clone.
+ region.takeBody(newOp->getRegion(0));
+ newOp->erase();
+ return success();
}
};
} // namespace
diff --git a/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp b/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp
index 5aa1273..be0bdb7 100644
--- a/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp
@@ -41,7 +41,7 @@ class LowerNontemporalPass
operand = op.getMemref();
defOp = operand.getDefiningOp();
})
- .Case<fir::BoxAddrOp>([&](auto op) {
+ .Case([&](fir::BoxAddrOp op) {
operand = op.getVal();
defOp = operand.getDefiningOp();
})
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 9278e17..2c79800 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -282,14 +282,14 @@ fissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
&newTeams.getRegion(), newTeams.getRegion().begin(), {}, {});
for (auto arg : teamsBlock->getArguments())
newTeamsBlock->addArgument(arg.getType(), arg.getLoc());
- auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
- rewriter.create<omp::TerminatorOp>(loc);
+ auto newWorkdistribute = omp::WorkdistributeOp::create(rewriter, loc);
+ omp::TerminatorOp::create(rewriter, loc);
rewriter.createBlock(&newWorkdistribute.getRegion(),
newWorkdistribute.getRegion().begin(), {}, {});
auto *cloned = rewriter.clone(*parallelize);
parallelize->replaceAllUsesWith(cloned);
parallelize->erase();
- rewriter.create<omp::TerminatorOp>(loc);
+ omp::TerminatorOp::create(rewriter, loc);
changed = true;
}
}
@@ -298,10 +298,10 @@ fissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
/// Generate omp.parallel operation with an empty region.
static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
- auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc);
+ auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc);
parallelOp.setComposite(composite);
rewriter.createBlock(&parallelOp.getRegion());
- rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
+ rewriter.setInsertionPoint(mlir::omp::TerminatorOp::create(rewriter, loc));
return;
}
@@ -309,7 +309,7 @@ static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
mlir::omp::DistributeOperands distributeClauseOps;
auto distributeOp =
- rewriter.create<mlir::omp::DistributeOp>(loc, distributeClauseOps);
+ mlir::omp::DistributeOp::create(rewriter, loc, distributeClauseOps);
distributeOp.setComposite(composite);
auto distributeBlock = rewriter.createBlock(&distributeOp.getRegion());
rewriter.setInsertionPointToStart(distributeBlock);
@@ -334,12 +334,12 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
const mlir::omp::LoopNestOperands &clauseOps,
bool composite) {
- auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+ auto wsloopOp = mlir::omp::WsloopOp::create(rewriter, doLoop.getLoc());
wsloopOp.setComposite(composite);
rewriter.createBlock(&wsloopOp.getRegion());
auto loopNestOp =
- rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+ mlir::omp::LoopNestOp::create(rewriter, doLoop.getLoc(), clauseOps);
// Clone the loop's body inside the loop nest construct using the
// mapped values.
@@ -351,7 +351,7 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
// Erase fir.result op of do loop and create yield op.
if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) {
rewriter.setInsertionPoint(terminatorOp);
- rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc());
+ mlir::omp::YieldOp::create(rewriter, doLoop->getLoc());
terminatorOp->erase();
}
}
@@ -494,15 +494,15 @@ static SmallVector<Value> convertFlatToMultiDim(OpBuilder &builder,
// Convert flat index to multi-dimensional indices
SmallVector<Value> indices(rank);
Value temp = flatIdx;
- auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1);
+ auto c1 = arith::ConstantIndexOp::create(builder, loc, 1);
// Work backwards through dimensions (row-major order)
for (int i = rank - 1; i >= 0; --i) {
- Value zeroBasedIdx = builder.create<arith::RemSIOp>(loc, temp, extents[i]);
+ Value zeroBasedIdx = arith::RemSIOp::create(builder, loc, temp, extents[i]);
// Convert to one-based index
- indices[i] = builder.create<arith::AddIOp>(loc, zeroBasedIdx, c1);
+ indices[i] = arith::AddIOp::create(builder, loc, zeroBasedIdx, c1);
if (i > 0) {
- temp = builder.create<arith::DivSIOp>(loc, temp, extents[i]);
+ temp = arith::DivSIOp::create(builder, loc, temp, extents[i]);
}
}
@@ -525,7 +525,7 @@ static Value CalculateTotalElements(OpBuilder &builder, Location loc,
if (i == 0) {
totalElems = extent;
} else {
- totalElems = builder.create<arith::MulIOp>(loc, totalElems, extent);
+ totalElems = arith::MulIOp::create(builder, loc, totalElems, extent);
}
}
return totalElems;
@@ -562,14 +562,14 @@ static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc,
// Load destination array box (if it's a reference)
Value arrayBox = destBox;
if (isa<fir::ReferenceType>(destBox.getType()))
- arrayBox = builder.create<fir::LoadOp>(loc, destBox);
+ arrayBox = fir::LoadOp::create(builder, loc, destBox);
- auto scalarValue = builder.create<fir::BoxAddrOp>(loc, srcBox);
- Value scalar = builder.create<fir::LoadOp>(loc, scalarValue);
+ auto scalarValue = fir::BoxAddrOp::create(builder, loc, srcBox);
+ Value scalar = fir::LoadOp::create(builder, loc, scalarValue);
// Calculate total number of elements (flattened)
- auto c0 = builder.create<arith::ConstantIndexOp>(loc, 0);
- auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1);
+ auto c0 = arith::ConstantIndexOp::create(builder, loc, 0);
+ auto c1 = arith::ConstantIndexOp::create(builder, loc, 1);
Value totalElems = CalculateTotalElements(builder, loc, arrayBox);
auto *workdistributeBlock = &workdistribute.getRegion().front();
@@ -587,7 +587,7 @@ static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc,
builder, loc, fir::ReferenceType::get(scalar.getType()), arrayBox,
nullptr, nullptr, ValueRange{indices}, ValueRange{});
- builder.create<fir::StoreOp>(loc, scalar, elemPtr);
+ fir::StoreOp::create(builder, loc, scalar, elemPtr);
}
/// workdistributeRuntimeCallLower method finds the runtime calls
@@ -719,10 +719,9 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
SmallVector<Value> outerMapInfos;
// Create new mapinfo ops for the inner target region
for (auto mapInfo : mapInfos) {
- auto originalMapType =
- (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType());
+ mlir::omp::ClauseMapFlags originalMapType = mapInfo.getMapType();
auto originalCaptureType = mapInfo.getMapCaptureType();
- llvm::omp::OpenMPOffloadMappingFlags newMapType;
+ mlir::omp::ClauseMapFlags newMapType;
mlir::omp::VariableCaptureKind newCaptureType;
// For bycopy, we keep the same map type and capture type
// For byref, we change the map type to none and keep the capture type
@@ -730,7 +729,7 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
newMapType = originalMapType;
newCaptureType = originalCaptureType;
} else if (originalCaptureType == mlir::omp::VariableCaptureKind::ByRef) {
- newMapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
+ newMapType = mlir::omp::ClauseMapFlags::storage;
newCaptureType = originalCaptureType;
outerMapInfos.push_back(mapInfo);
} else {
@@ -738,11 +737,8 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
return failure();
}
auto innerMapInfo = cast<omp::MapInfoOp>(rewriter.clone(*mapInfo));
- innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr(
- rewriter.getIntegerType(64, false),
- static_cast<
- std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
- newMapType)));
+ innerMapInfo.setMapTypeAttr(
+ rewriter.getAttr<omp::ClauseMapFlagsAttr>(newMapType));
innerMapInfo.setMapCaptureType(newCaptureType);
innerMapInfos.push_back(innerMapInfo.getResult());
}
@@ -753,14 +749,15 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
auto deviceAddrVars = targetOp.getHasDeviceAddrVars();
auto devicePtrVars = targetOp.getIsDevicePtrVars();
// Create the target data op
- auto targetDataOp = rewriter.create<omp::TargetDataOp>(
- loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars);
+ auto targetDataOp =
+ omp::TargetDataOp::create(rewriter, loc, device, ifExpr, outerMapInfos,
+ deviceAddrVars, devicePtrVars);
auto taregtDataBlock = rewriter.createBlock(&targetDataOp.getRegion());
- rewriter.create<mlir::omp::TerminatorOp>(loc);
+ mlir::omp::TerminatorOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(taregtDataBlock);
// Create the inner target op
- auto newTargetOp = rewriter.create<omp::TargetOp>(
- targetOp.getLoc(), targetOp.getAllocateVars(),
+ auto newTargetOp = omp::TargetOp::create(
+ rewriter, targetOp.getLoc(), targetOp.getAllocateVars(),
targetOp.getAllocatorVars(), targetOp.getBareAttr(),
targetOp.getDependKindsAttr(), targetOp.getDependVars(),
targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
@@ -769,7 +766,7 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
- targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+ targetOp.getThreadLimitVars(), targetOp.getPrivateMapsAttr());
rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
newTargetOp.getRegion().begin());
rewriter.replaceOp(targetOp, targetDataOp);
@@ -825,20 +822,20 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
// Get the appropriate type for allocation
if (isPtr(ty)) {
Type intTy = rewriter.getI32Type();
- auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1);
+ auto one = LLVM::ConstantOp::create(rewriter, loc, intTy, 1);
allocType = llvmPtrTy;
- alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one);
+ alloc = LLVM::AllocaOp::create(rewriter, loc, llvmPtrTy, allocType, one);
allocType = intTy;
} else {
allocType = ty;
- alloc = rewriter.create<fir::AllocaOp>(loc, allocType);
+ alloc = fir::AllocaOp::create(rewriter, loc, allocType);
}
// Lambda to create mapinfo ops
- auto getMapInfo = [&](uint64_t mappingFlags, const char *name) {
- return rewriter.create<omp::MapInfoOp>(
- loc, alloc.getType(), alloc, TypeAttr::get(allocType),
- rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false),
- mappingFlags),
+ auto getMapInfo = [&](mlir::omp::ClauseMapFlags mappingFlags,
+ const char *name) {
+ return omp::MapInfoOp::create(
+ rewriter, loc, alloc.getType(), alloc, TypeAttr::get(allocType),
+ rewriter.getAttr<omp::ClauseMapFlagsAttr>(mappingFlags),
rewriter.getAttr<omp::VariableCaptureKindAttr>(
omp::VariableCaptureKind::ByRef),
/*varPtrPtr=*/Value{},
@@ -849,14 +846,10 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
/*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false));
};
// Create mapinfo ops.
- uint64_t mapFrom =
- static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
- uint64_t mapTo =
- static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
- auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from");
- auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to");
+ auto mapInfoFrom = getMapInfo(mlir::omp::ClauseMapFlags::from,
+ "__flang_workdistribute_from");
+ auto mapInfoTo =
+ getMapInfo(mlir::omp::ClauseMapFlags::to, "__flang_workdistribute_to");
return TempOmpVar{mapInfoFrom, mapInfoTo};
}
@@ -987,12 +980,12 @@ static void reloadCacheAndRecompute(
// If the original value is a pointer or reference, load and convert if
// necessary.
if (isPtr(original.getType())) {
- restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
+ restored = LLVM::LoadOp::create(rewriter, loc, llvmPtrTy, newArg);
if (!isa<LLVM::LLVMPointerType>(original.getType()))
restored =
- rewriter.create<fir::ConvertOp>(loc, original.getType(), restored);
+ fir::ConvertOp::create(rewriter, loc, original.getType(), restored);
} else {
- restored = rewriter.create<fir::LoadOp>(loc, newArg);
+ restored = fir::LoadOp::create(rewriter, loc, newArg);
}
irMapping.map(original, restored);
}
@@ -1061,7 +1054,7 @@ static mlir::LLVM::ConstantOp
genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
mlir::Type i32Ty = rewriter.getI32Type();
mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value);
- return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
+ return mlir::LLVM::ConstantOp::create(rewriter, loc, i32Ty, attr);
}
/// Given a box descriptor, extract the base address of the data it describes.
@@ -1238,8 +1231,8 @@ static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder,
genOmpGetMappedPtrIfPresent(builder, loc, destBase, device, module);
Value srcPtr =
genOmpGetMappedPtrIfPresent(builder, loc, srcBase, device, module);
- Value zero = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
- builder.getI64IntegerAttr(0));
+ Value zero = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(),
+ builder.getI64IntegerAttr(0));
// Generate the call to omp_target_memcpy to perform the data copy on the
// device.
@@ -1356,23 +1349,24 @@ static LogicalResult moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
for (Operation *op : opsToReplace) {
if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
rewriter.setInsertionPoint(allocOp);
- auto ompAllocmemOp = rewriter.create<omp::TargetAllocMemOp>(
- allocOp.getLoc(), rewriter.getI64Type(), device,
+ auto ompAllocmemOp = omp::TargetAllocMemOp::create(
+ rewriter, allocOp.getLoc(), rewriter.getI64Type(), device,
allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(),
allocOp.getBindcNameAttr(), allocOp.getTypeparams(),
allocOp.getShape());
- auto firConvertOp = rewriter.create<fir::ConvertOp>(
- allocOp.getLoc(), allocOp.getResult().getType(),
- ompAllocmemOp.getResult());
+ auto firConvertOp = fir::ConvertOp::create(rewriter, allocOp.getLoc(),
+ allocOp.getResult().getType(),
+ ompAllocmemOp.getResult());
rewriter.replaceOp(allocOp, firConvertOp.getResult());
}
// Replace fir.freemem with omp.target_freemem.
else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) {
rewriter.setInsertionPoint(freeOp);
- auto firConvertOp = rewriter.create<fir::ConvertOp>(
- freeOp.getLoc(), rewriter.getI64Type(), freeOp.getHeapref());
- rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device,
- firConvertOp.getResult());
+ auto firConvertOp =
+ fir::ConvertOp::create(rewriter, freeOp.getLoc(),
+ rewriter.getI64Type(), freeOp.getHeapref());
+ omp::TargetFreeMemOp::create(rewriter, freeOp.getLoc(), device,
+ firConvertOp.getResult());
rewriter.eraseOp(freeOp);
}
// fir.declare changes its type when hoisting it out of omp.target to
@@ -1384,8 +1378,9 @@ static LogicalResult moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
dyn_cast<fir::ReferenceType>(clonedInType);
Type clonedEleTy = clonedRefType.getElementType();
rewriter.setInsertionPoint(op);
- Value loadedValue = rewriter.create<fir::LoadOp>(
- clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref());
+ Value loadedValue =
+ fir::LoadOp::create(rewriter, clonedDeclareOp.getLoc(), clonedEleTy,
+ clonedDeclareOp.getMemref());
clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue);
}
// Replace runtime calls with omp versions.
@@ -1481,8 +1476,8 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
auto *targetBlock = &targetOp.getRegion().front();
SmallVector<Value> preHostEvalVars{targetOp.getHostEvalVars()};
// update the hostEvalVars of preTargetOp
- omp::TargetOp preTargetOp = rewriter.create<omp::TargetOp>(
- targetOp.getLoc(), targetOp.getAllocateVars(),
+ omp::TargetOp preTargetOp = omp::TargetOp::create(
+ rewriter, targetOp.getLoc(), targetOp.getAllocateVars(),
targetOp.getAllocatorVars(), targetOp.getBareAttr(),
targetOp.getDependKindsAttr(), targetOp.getDependVars(),
targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), preHostEvalVars,
@@ -1490,7 +1485,7 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(),
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
- targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
+ targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVars(),
targetOp.getPrivateMapsAttr());
auto *preTargetBlock = rewriter.createBlock(
&preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
@@ -1521,13 +1516,13 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
// Create the store operation.
if (isPtr(originalResult.getType())) {
if (!isa<LLVM::LLVMPointerType>(toStore.getType()))
- toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore);
- rewriter.create<LLVM::StoreOp>(loc, toStore, newArg);
+ toStore = fir::ConvertOp::create(rewriter, loc, llvmPtrTy, toStore);
+ LLVM::StoreOp::create(rewriter, loc, toStore, newArg);
} else {
- rewriter.create<fir::StoreOp>(loc, toStore, newArg);
+ fir::StoreOp::create(rewriter, loc, toStore, newArg);
}
}
- rewriter.create<omp::TerminatorOp>(loc);
+ omp::TerminatorOp::create(rewriter, loc);
// Update hostEvalVars with the mapped values for the loop bounds if we have
// a loopNestOp and we are not generating code for the target device.
@@ -1571,8 +1566,8 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
hostEvalVars.steps.end());
}
// Create the isolated target op
- omp::TargetOp isolatedTargetOp = rewriter.create<omp::TargetOp>(
- targetOp.getLoc(), targetOp.getAllocateVars(),
+ omp::TargetOp isolatedTargetOp = omp::TargetOp::create(
+ rewriter, targetOp.getLoc(), targetOp.getAllocateVars(),
targetOp.getAllocatorVars(), targetOp.getBareAttr(),
targetOp.getDependKindsAttr(), targetOp.getDependVars(),
targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
@@ -1580,7 +1575,7 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
- targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
+ targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVars(),
targetOp.getPrivateMapsAttr());
auto *isolatedTargetBlock =
rewriter.createBlock(&isolatedTargetOp.getRegion(),
@@ -1598,7 +1593,7 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
// Clone the original operations.
rewriter.clone(*splitBeforeOp, isolatedMapping);
- rewriter.create<omp::TerminatorOp>(loc);
+ omp::TerminatorOp::create(rewriter, loc);
// update the loop bounds in the isolatedTargetOp if we have host_eval vars
// and we are not generating code for the target device.
@@ -1651,8 +1646,8 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
auto *targetBlock = &targetOp.getRegion().front();
SmallVector<Value> postHostEvalVars{targetOp.getHostEvalVars()};
// Create the post target op
- omp::TargetOp postTargetOp = rewriter.create<omp::TargetOp>(
- targetOp.getLoc(), targetOp.getAllocateVars(),
+ omp::TargetOp postTargetOp = omp::TargetOp::create(
+ rewriter, targetOp.getLoc(), targetOp.getAllocateVars(),
targetOp.getAllocatorVars(), targetOp.getBareAttr(),
targetOp.getDependKindsAttr(), targetOp.getDependVars(),
targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), postHostEvalVars,
@@ -1660,7 +1655,7 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
- targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
+ targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVars(),
targetOp.getPrivateMapsAttr());
// Create the block for postTargetOp
auto *postTargetBlock = rewriter.createBlock(
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 2bbd803..a60960e 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -43,7 +43,6 @@
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringSet.h"
-#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstddef>
@@ -348,10 +347,10 @@ class MapInfoFinalizationPass
/// base address (BoxOffsetOp) and a MapInfoOp for it. The most
/// important thing to note is that we normally move the bounds from
/// the descriptor map onto the base address map.
- mlir::omp::MapInfoOp genBaseAddrMap(mlir::Value descriptor,
- mlir::OperandRange bounds,
- int64_t mapType,
- fir::FirOpBuilder &builder) {
+ mlir::omp::MapInfoOp
+ genBaseAddrMap(mlir::Value descriptor, mlir::OperandRange bounds,
+ mlir::omp::ClauseMapFlags mapType, fir::FirOpBuilder &builder,
+ mlir::FlatSymbolRefAttr mapperId = mlir::FlatSymbolRefAttr()) {
mlir::Location loc = descriptor.getLoc();
mlir::Value baseAddrAddr = fir::BoxOffsetOp::create(
builder, loc, descriptor, fir::BoxFieldAttr::base_addr);
@@ -368,12 +367,12 @@ class MapInfoFinalizationPass
return mlir::omp::MapInfoOp::create(
builder, loc, baseAddrAddr.getType(), descriptor,
mlir::TypeAttr::get(underlyingVarType),
- builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
+ builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(mapType),
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
mlir::omp::VariableCaptureKind::ByRef),
baseAddrAddr, /*members=*/mlir::SmallVector<mlir::Value>{},
/*membersIndex=*/mlir::ArrayAttr{}, bounds,
- /*mapperId*/ mlir::FlatSymbolRefAttr(),
+ /*mapperId=*/mapperId,
/*name=*/builder.getStringAttr(""),
/*partial_map=*/builder.getBoolAttr(false));
}
@@ -428,22 +427,36 @@ class MapInfoFinalizationPass
/// allowing `to` mappings, and `target update` not allowing both `to` and
/// `from` simultaneously. We currently try to maintain the `implicit` flag
/// where necessary, although it does not seem strictly required.
- unsigned long getDescriptorMapType(unsigned long mapTypeFlag,
- mlir::Operation *target) {
- using mapFlags = llvm::omp::OpenMPOffloadMappingFlags;
+ mlir::omp::ClauseMapFlags
+ getDescriptorMapType(mlir::omp::ClauseMapFlags mapTypeFlag,
+ mlir::Operation *target) {
+ using mapFlags = mlir::omp::ClauseMapFlags;
if (llvm::isa_and_nonnull<mlir::omp::TargetExitDataOp,
mlir::omp::TargetUpdateOp>(target))
return mapTypeFlag;
- mapFlags flags = mapFlags::OMP_MAP_TO |
- (mapFlags(mapTypeFlag) &
- (mapFlags::OMP_MAP_IMPLICIT | mapFlags::OMP_MAP_ALWAYS));
+ mapFlags flags =
+ mapFlags::to | (mapTypeFlag & (mapFlags::implicit | mapFlags::always));
+
+ // Descriptors for objects will always be copied. This is because the
+ // descriptor can be rematerialized by the compiler, and so the address
+ // of the descriptor for a given object at one place in the code may
+ // differ from that address in another place. The contents of the
+ // descriptor (the base address in particular) will remain unchanged
+ // though.
+  // TODO/FIXME: We currently cannot have MAP_CLOSE and MAP_ALWAYS on
+  // the descriptor at once; they are mutually exclusive, and when
+  // both are applied the runtime will fail to map.
+ flags |= ((mapFlags(mapTypeFlag) & mapFlags::close) == mapFlags::close)
+ ? mapFlags::close
+ : mapFlags::always;
+
// For unified_shared_memory, we additionally add `CLOSE` on the descriptor
// to ensure device-local placement where required by tests relying on USM +
// close semantics.
if (moduleRequiresUSM(target->getParentOfType<mlir::ModuleOp>()))
- flags |= mapFlags::OMP_MAP_CLOSE;
- return llvm::to_underlying(flags);
+ flags |= mapFlags::close;
+ return flags;
}
/// Check if the mapOp is present in the HasDeviceAddr clause on
@@ -478,62 +491,6 @@ class MapInfoFinalizationPass
return false;
}
- mlir::omp::MapInfoOp genBoxcharMemberMap(mlir::omp::MapInfoOp op,
- fir::FirOpBuilder &builder) {
- if (!op.getMembers().empty())
- return op;
- mlir::Location loc = op.getVarPtr().getLoc();
- mlir::Value boxChar = op.getVarPtr();
-
- if (mlir::isa<fir::ReferenceType>(op.getVarPtr().getType()))
- boxChar = fir::LoadOp::create(builder, loc, op.getVarPtr());
-
- fir::BoxCharType boxCharType =
- mlir::dyn_cast<fir::BoxCharType>(boxChar.getType());
- mlir::Value boxAddr = fir::BoxOffsetOp::create(
- builder, loc, op.getVarPtr(), fir::BoxFieldAttr::base_addr);
-
- uint64_t mapTypeToImplicit = static_cast<
- std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
-
- mlir::ArrayAttr newMembersAttr;
- llvm::SmallVector<llvm::SmallVector<int64_t>> memberIdx = {{0}};
- newMembersAttr = builder.create2DI64ArrayAttr(memberIdx);
-
- mlir::Value varPtr = op.getVarPtr();
- mlir::omp::MapInfoOp memberMapInfoOp = mlir::omp::MapInfoOp::create(
- builder, op.getLoc(), varPtr.getType(), varPtr,
- mlir::TypeAttr::get(boxCharType.getEleTy()),
- builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
- mapTypeToImplicit),
- builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
- mlir::omp::VariableCaptureKind::ByRef),
- /*varPtrPtr=*/boxAddr,
- /*members=*/llvm::SmallVector<mlir::Value>{},
- /*member_index=*/mlir::ArrayAttr{},
- /*bounds=*/op.getBounds(),
- /*mapperId=*/mlir::FlatSymbolRefAttr(), /*name=*/op.getNameAttr(),
- builder.getBoolAttr(false));
-
- mlir::omp::MapInfoOp newMapInfoOp = mlir::omp::MapInfoOp::create(
- builder, op.getLoc(), op.getResult().getType(), varPtr,
- mlir::TypeAttr::get(
- llvm::cast<mlir::omp::PointerLikeType>(varPtr.getType())
- .getElementType()),
- op.getMapTypeAttr(), op.getMapCaptureTypeAttr(),
- /*varPtrPtr=*/mlir::Value{},
- /*members=*/llvm::SmallVector<mlir::Value>{memberMapInfoOp},
- /*member_index=*/newMembersAttr,
- /*bounds=*/llvm::SmallVector<mlir::Value>{},
- /*mapperId=*/mlir::FlatSymbolRefAttr(), op.getNameAttr(),
- /*partial_map=*/builder.getBoolAttr(false));
- op.replaceAllUsesWith(newMapInfoOp.getResult());
- op->erase();
- return newMapInfoOp;
- }
-
// Expand mappings of type(C_PTR) to map their `__address` field explicitly
// as a single pointer-sized member (USM-gated at callsite). This helps in
// USM scenarios to ensure the pointer-sized mapping is used.
@@ -568,12 +525,9 @@ class MapInfoFinalizationPass
mlir::ArrayAttr newMembersAttr = builder.create2DI64ArrayAttr(memberIdx);
// Force CLOSE in USM paths so the pointer gets device-local placement
// when required by tests relying on USM + close semantics.
- uint64_t mapTypeVal =
- op.getMapType() |
- llvm::to_underlying(
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
- mlir::IntegerAttr mapTypeAttr = builder.getIntegerAttr(
- builder.getIntegerType(64, /*isSigned=*/false), mapTypeVal);
+ mlir::omp::ClauseMapFlagsAttr mapTypeAttr =
+ builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(
+ op.getMapType() | mlir::omp::ClauseMapFlags::close);
mlir::omp::MapInfoOp memberMap = mlir::omp::MapInfoOp::create(
builder, loc, coord.getType(), coord,
@@ -638,6 +592,7 @@ class MapInfoFinalizationPass
// from the descriptor to be used verbatim, i.e. without additional
// remapping. To avoid this remapping, simply don't generate any map
// information for the descriptor members.
+ mlir::FlatSymbolRefAttr mapperId = op.getMapperIdAttr();
if (!mapMemberUsers.empty()) {
// Currently, there should only be one user per map when this pass
// is executed. Either a parent map, holding the current map in its
@@ -648,8 +603,8 @@ class MapInfoFinalizationPass
assert(mapMemberUsers.size() == 1 &&
"OMPMapInfoFinalization currently only supports single users of a "
"MapInfoOp");
- auto baseAddr =
- genBaseAddrMap(descriptor, op.getBounds(), op.getMapType(), builder);
+ auto baseAddr = genBaseAddrMap(descriptor, op.getBounds(),
+ op.getMapType(), builder, mapperId);
ParentAndPlacement mapUser = mapMemberUsers[0];
adjustMemberIndices(memberIndices, mapUser.index);
llvm::SmallVector<mlir::Value> newMemberOps;
@@ -662,8 +617,8 @@ class MapInfoFinalizationPass
mapUser.parent.setMembersIndexAttr(
builder.create2DI64ArrayAttr(memberIndices));
} else if (!isHasDeviceAddrFlag) {
- auto baseAddr =
- genBaseAddrMap(descriptor, op.getBounds(), op.getMapType(), builder);
+ auto baseAddr = genBaseAddrMap(descriptor, op.getBounds(),
+ op.getMapType(), builder, mapperId);
newMembers.push_back(baseAddr);
if (!op.getMembers().empty()) {
for (auto &indices : memberIndices)
@@ -683,20 +638,19 @@ class MapInfoFinalizationPass
// one place in the code may differ from that address in another place.
// The contents of the descriptor (the base address in particular) will
// remain unchanged though.
- uint64_t mapType = op.getMapType();
+ mlir::omp::ClauseMapFlags mapType = op.getMapType();
if (isHasDeviceAddrFlag) {
- mapType |= llvm::to_underlying(
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
+ mapType |= mlir::omp::ClauseMapFlags::always;
}
mlir::omp::MapInfoOp newDescParentMapOp = mlir::omp::MapInfoOp::create(
builder, op->getLoc(), op.getResult().getType(), descriptor,
mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())),
- builder.getIntegerAttr(builder.getIntegerType(64, false),
- getDescriptorMapType(mapType, target)),
+ builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(
+ getDescriptorMapType(mapType, target)),
op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, newMembers,
newMembersAttr, /*bounds=*/mlir::SmallVector<mlir::Value>{},
- /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
+ /*mapperId=*/mlir::FlatSymbolRefAttr(), op.getNameAttr(),
/*partial_map=*/builder.getBoolAttr(false));
op.replaceAllUsesWith(newDescParentMapOp.getResult());
op->erase();
@@ -892,20 +846,16 @@ class MapInfoFinalizationPass
if (explicitMappingPresent(op, targetDataOp))
return;
- mlir::omp::MapInfoOp newDescParentMapOp =
- builder.create<mlir::omp::MapInfoOp>(
- op->getLoc(), op.getResult().getType(), op.getVarPtr(),
- op.getVarTypeAttr(),
- builder.getIntegerAttr(
- builder.getIntegerType(64, false),
- llvm::to_underlying(
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS)),
- op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{},
- mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{},
- /*bounds=*/mlir::SmallVector<mlir::Value>{},
- /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
- /*partial_map=*/builder.getBoolAttr(false));
+ mlir::omp::MapInfoOp newDescParentMapOp = mlir::omp::MapInfoOp::create(
+ builder, op->getLoc(), op.getResult().getType(), op.getVarPtr(),
+ op.getVarTypeAttr(),
+ builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(
+ mlir::omp::ClauseMapFlags::to | mlir::omp::ClauseMapFlags::always),
+ op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{},
+ mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{},
+ /*bounds=*/mlir::SmallVector<mlir::Value>{},
+ /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
+ /*partial_map=*/builder.getBoolAttr(false));
targetDataOp.getMapVarsMutable().append({newDescParentMapOp});
}
@@ -957,19 +907,26 @@ class MapInfoFinalizationPass
// need to see how well this alteration works.
auto loadBaseAddr =
builder.loadIfRef(op->getLoc(), baseAddr.getVarPtrPtr());
- mlir::omp::MapInfoOp newBaseAddrMapOp =
- builder.create<mlir::omp::MapInfoOp>(
- op->getLoc(), loadBaseAddr.getType(), loadBaseAddr,
- baseAddr.getVarTypeAttr(), baseAddr.getMapTypeAttr(),
- baseAddr.getMapCaptureTypeAttr(), mlir::Value{}, members,
- membersAttr, baseAddr.getBounds(),
- /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
- /*partial_map=*/builder.getBoolAttr(false));
+ mlir::omp::MapInfoOp newBaseAddrMapOp = mlir::omp::MapInfoOp::create(
+ builder, op->getLoc(), loadBaseAddr.getType(), loadBaseAddr,
+ baseAddr.getVarTypeAttr(), baseAddr.getMapTypeAttr(),
+ baseAddr.getMapCaptureTypeAttr(), mlir::Value{}, members, membersAttr,
+ baseAddr.getBounds(),
+ /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
+ /*partial_map=*/builder.getBoolAttr(false));
op.replaceAllUsesWith(newBaseAddrMapOp.getResult());
op->erase();
baseAddr.erase();
}
+ static bool hasADescriptor(mlir::Operation *varOp, mlir::Type varType) {
+ if (fir::isTypeWithDescriptor(varType) ||
+ mlir::isa<fir::BoxCharType>(varType) ||
+ mlir::isa_and_present<fir::BoxAddrOp>(varOp))
+ return true;
+ return false;
+ }
+
// This pass executes on omp::MapInfoOp's containing descriptor based types
// (allocatables, pointers, assumed shape etc.) and expanding them into
// multiple omp::MapInfoOp's for each pointer member contained within the
@@ -1001,38 +958,6 @@ class MapInfoFinalizationPass
localBoxAllocas.clear();
deferrableDesc.clear();
- // First, walk `omp.map.info` ops to see if any of them have varPtrs
- // with an underlying type of fir.char<k, ?>, i.e a character
- // with dynamic length. If so, check if they need bounds added.
- func->walk([&](mlir::omp::MapInfoOp op) {
- if (!op.getBounds().empty())
- return;
-
- mlir::Value varPtr = op.getVarPtr();
- mlir::Type underlyingVarType = fir::unwrapRefType(varPtr.getType());
-
- if (!fir::characterWithDynamicLen(underlyingVarType))
- return;
-
- fir::factory::AddrAndBoundsInfo info =
- fir::factory::getDataOperandBaseAddr(
- builder, varPtr, /*isOptional=*/false, varPtr.getLoc());
-
- fir::ExtendedValue extendedValue =
- hlfir::translateToExtendedValue(varPtr.getLoc(), builder,
- hlfir::Entity{info.addr},
- /*continguousHint=*/true)
- .first;
- builder.setInsertionPoint(op);
- llvm::SmallVector<mlir::Value> boundsOps =
- fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
- mlir::omp::MapBoundsType>(
- builder, info, extendedValue,
- /*dataExvIsAssumedSize=*/false, varPtr.getLoc());
-
- op.getBoundsMutable().append(boundsOps);
- });
-
// Next, walk `omp.map.info` ops to see if any record members should be
// implicitly mapped.
func->walk([&](mlir::omp::MapInfoOp op) {
@@ -1218,42 +1143,12 @@ class MapInfoFinalizationPass
newMemberIndices.emplace_back(path);
op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
- op.setPartialMap(true);
+ // Set to partial map only if there is no user-defined mapper.
+ op.setPartialMap(op.getMapperIdAttr() == nullptr);
return mlir::WalkResult::advance();
});
- func->walk([&](mlir::omp::MapInfoOp op) {
- if (!op.getMembers().empty())
- return;
-
- if (!mlir::isa<fir::BoxCharType>(fir::unwrapRefType(op.getVarType())))
- return;
-
- // POSSIBLE_HACK_ALERT: If the boxchar has been implicitly mapped then
- // it is likely that the underlying pointer to the data
- // (!fir.ref<fir.char<k,?>>) has already been mapped. So, skip such
- // boxchars. We are primarily interested in boxchars that were mapped
- // by passes such as MapsForPrivatizedSymbols that map boxchars that
- // are privatized. At present, such boxchar maps are not marked
- // implicit. Should they be? I don't know. If they should be then
- // we need to change this check for early return OR live with
- // over-mapping.
- bool hasImplicitMap =
- (llvm::omp::OpenMPOffloadMappingFlags(op.getMapType()) &
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT) ==
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
- if (hasImplicitMap)
- return;
-
- assert(llvm::hasSingleElement(op->getUsers()) &&
- "OMPMapInfoFinalization currently only supports single users "
- "of a MapInfoOp");
-
- builder.setInsertionPoint(op);
- genBoxcharMemberMap(op, builder);
- });
-
// Expand type(C_PTR) only when unified_shared_memory is required,
// to ensure device-visible pointer size/behavior in USM scenarios
// without changing default expectations elsewhere.
@@ -1281,9 +1176,8 @@ class MapInfoFinalizationPass
"OMPMapInfoFinalization currently only supports single users "
"of a MapInfoOp");
- if (fir::isTypeWithDescriptor(op.getVarType()) ||
- mlir::isa_and_present<fir::BoxAddrOp>(
- op.getVarPtr().getDefiningOp())) {
+ if (hasADescriptor(op.getVarPtr().getDefiningOp(),
+ fir::unwrapRefType(op.getVarType()))) {
builder.setInsertionPoint(op);
mlir::Operation *targetUser = getFirstTargetUser(op);
assert(targetUser && "expected user of map operation was not found");
diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
index 3032857..6404e18 100644
--- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
@@ -35,7 +35,6 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
-#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Support/Debug.h"
#include <type_traits>
@@ -70,9 +69,6 @@ class MapsForPrivatizedSymbolsPass
return size <= ptrSize && align <= ptrAlign;
};
- uint64_t mapTypeTo = static_cast<
- std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
Operation *definingOp = var.getDefiningOp();
Value varPtr = var;
@@ -108,22 +104,31 @@ class MapsForPrivatizedSymbolsPass
llvm::SmallVector<mlir::Value> boundsOps;
if (needsBoundsOps(varPtr))
genBoundsOps(builder, varPtr, boundsOps);
+ mlir::Type varType = varPtr.getType();
mlir::omp::VariableCaptureKind captureKind =
mlir::omp::VariableCaptureKind::ByRef;
- if (fir::isa_trivial(fir::unwrapRefType(varPtr.getType())) ||
- fir::isa_char(fir::unwrapRefType(varPtr.getType()))) {
- if (canPassByValue(fir::unwrapRefType(varPtr.getType()))) {
+ if (fir::isa_trivial(fir::unwrapRefType(varType)) ||
+ fir::isa_char(fir::unwrapRefType(varType))) {
+ if (canPassByValue(fir::unwrapRefType(varType))) {
captureKind = mlir::omp::VariableCaptureKind::ByCopy;
}
}
+    // Use tofrom if what we are mapping is not a trivial type; in all
+    // likelihood, it is a descriptor.
+ mlir::omp::ClauseMapFlags mapFlag;
+ if (fir::isa_trivial(fir::unwrapRefType(varType)) ||
+ fir::isa_char(fir::unwrapRefType(varType)))
+ mapFlag = mlir::omp::ClauseMapFlags::to;
+ else
+ mapFlag = mlir::omp::ClauseMapFlags::to | mlir::omp::ClauseMapFlags::from;
+
return omp::MapInfoOp::create(
- builder, loc, varPtr.getType(), varPtr,
- TypeAttr::get(llvm::cast<omp::PointerLikeType>(varPtr.getType())
- .getElementType()),
- builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
- mapTypeTo),
+ builder, loc, varType, varPtr,
+ TypeAttr::get(
+ llvm::cast<omp::PointerLikeType>(varType).getElementType()),
+ builder.getAttr<omp::ClauseMapFlagsAttr>(mapFlag),
builder.getAttr<omp::VariableCaptureKindAttr>(captureKind),
/*varPtrPtr=*/Value{},
/*members=*/SmallVector<Value>{},
diff --git a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
index 0b0e6bd..5fa77fb 100644
--- a/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
+++ b/flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
@@ -21,6 +21,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/TypeSwitch.h"
namespace flangomp {
#define GEN_PASS_DEF_MARKDECLARETARGETPASS
@@ -31,9 +32,93 @@ namespace {
class MarkDeclareTargetPass
: public flangomp::impl::MarkDeclareTargetPassBase<MarkDeclareTargetPass> {
- void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy,
- mlir::omp::DeclareTargetCaptureClause parentCapClause,
- bool parentAutomap, mlir::Operation *currOp,
+ struct ParentInfo {
+ mlir::omp::DeclareTargetDeviceType devTy;
+ mlir::omp::DeclareTargetCaptureClause capClause;
+ bool automap;
+ };
+
+ void processSymbolRef(mlir::SymbolRefAttr symRef, ParentInfo parentInfo,
+ llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
+ if (auto currFOp =
+ getOperation().lookupSymbol<mlir::func::FuncOp>(symRef)) {
+ auto current = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
+ currFOp.getOperation());
+
+ if (current.isDeclareTarget()) {
+ auto currentDt = current.getDeclareTargetDeviceType();
+
+ // Found the same function twice, with different device_types,
+ // mark as Any as it belongs to both
+ if (currentDt != parentInfo.devTy &&
+ currentDt != mlir::omp::DeclareTargetDeviceType::any) {
+ current.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any,
+ current.getDeclareTargetCaptureClause(),
+ current.getDeclareTargetAutomap());
+ }
+ } else {
+ current.setDeclareTarget(parentInfo.devTy, parentInfo.capClause,
+ parentInfo.automap);
+ }
+
+ markNestedFuncs(parentInfo, currFOp, visited);
+ }
+ }
+
+ void processReductionRefs(std::optional<mlir::ArrayAttr> symRefs,
+ ParentInfo parentInfo,
+ llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
+ if (!symRefs)
+ return;
+
+ for (auto symRef : symRefs->getAsRange<mlir::SymbolRefAttr>()) {
+ if (auto declareReductionOp =
+ getOperation().lookupSymbol<mlir::omp::DeclareReductionOp>(
+ symRef)) {
+ markNestedFuncs(parentInfo, declareReductionOp, visited);
+ }
+ }
+ }
+
+ void
+ processReductionClauses(mlir::Operation *op, ParentInfo parentInfo,
+ llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
+ llvm::TypeSwitch<mlir::Operation &>(*op)
+ .Case([&](mlir::omp::LoopOp op) {
+ processReductionRefs(op.getReductionSyms(), parentInfo, visited);
+ })
+ .Case([&](mlir::omp::ParallelOp op) {
+ processReductionRefs(op.getReductionSyms(), parentInfo, visited);
+ })
+ .Case([&](mlir::omp::SectionsOp op) {
+ processReductionRefs(op.getReductionSyms(), parentInfo, visited);
+ })
+ .Case([&](mlir::omp::SimdOp op) {
+ processReductionRefs(op.getReductionSyms(), parentInfo, visited);
+ })
+ .Case([&](mlir::omp::TargetOp op) {
+ processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
+ })
+ .Case([&](mlir::omp::TaskgroupOp op) {
+ processReductionRefs(op.getTaskReductionSyms(), parentInfo, visited);
+ })
+ .Case([&](mlir::omp::TaskloopOp op) {
+ processReductionRefs(op.getReductionSyms(), parentInfo, visited);
+ processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
+ })
+ .Case([&](mlir::omp::TaskOp op) {
+ processReductionRefs(op.getInReductionSyms(), parentInfo, visited);
+ })
+ .Case([&](mlir::omp::TeamsOp op) {
+ processReductionRefs(op.getReductionSyms(), parentInfo, visited);
+ })
+ .Case([&](mlir::omp::WsloopOp op) {
+ processReductionRefs(op.getReductionSyms(), parentInfo, visited);
+ })
+ .Default([](mlir::Operation &) {});
+ }
+
+ void markNestedFuncs(ParentInfo parentInfo, mlir::Operation *currOp,
llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
if (visited.contains(currOp))
return;
@@ -43,33 +128,10 @@ class MarkDeclareTargetPass
if (auto callOp = llvm::dyn_cast<mlir::CallOpInterface>(op)) {
if (auto symRef = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
callOp.getCallableForCallee())) {
- if (auto currFOp =
- getOperation().lookupSymbol<mlir::func::FuncOp>(symRef)) {
- auto current = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
- currFOp.getOperation());
-
- if (current.isDeclareTarget()) {
- auto currentDt = current.getDeclareTargetDeviceType();
-
- // Found the same function twice, with different device_types,
- // mark as Any as it belongs to both
- if (currentDt != parentDevTy &&
- currentDt != mlir::omp::DeclareTargetDeviceType::any) {
- current.setDeclareTarget(
- mlir::omp::DeclareTargetDeviceType::any,
- current.getDeclareTargetCaptureClause(),
- current.getDeclareTargetAutomap());
- }
- } else {
- current.setDeclareTarget(parentDevTy, parentCapClause,
- parentAutomap);
- }
-
- markNestedFuncs(parentDevTy, parentCapClause, parentAutomap,
- currFOp, visited);
- }
+ processSymbolRef(symRef, parentInfo, visited);
}
}
+ processReductionClauses(op, parentInfo, visited);
});
}
@@ -82,10 +144,10 @@ class MarkDeclareTargetPass
functionOp.getOperation());
if (declareTargetOp.isDeclareTarget()) {
llvm::SmallPtrSet<mlir::Operation *, 16> visited;
- markNestedFuncs(declareTargetOp.getDeclareTargetDeviceType(),
- declareTargetOp.getDeclareTargetCaptureClause(),
- declareTargetOp.getDeclareTargetAutomap(), functionOp,
- visited);
+ ParentInfo parentInfo{declareTargetOp.getDeclareTargetDeviceType(),
+ declareTargetOp.getDeclareTargetCaptureClause(),
+ declareTargetOp.getDeclareTargetAutomap()};
+ markNestedFuncs(parentInfo, functionOp, visited);
}
}
@@ -96,12 +158,13 @@ class MarkDeclareTargetPass
// the contents of the device clause
getOperation()->walk([&](mlir::omp::TargetOp tarOp) {
llvm::SmallPtrSet<mlir::Operation *, 16> visited;
- markNestedFuncs(
- /*parentDevTy=*/mlir::omp::DeclareTargetDeviceType::nohost,
- /*parentCapClause=*/mlir::omp::DeclareTargetCaptureClause::to,
- /*parentAutomap=*/false, tarOp, visited);
+ ParentInfo parentInfo = {
+ /*devTy=*/mlir::omp::DeclareTargetDeviceType::nohost,
+ /*capClause=*/mlir::omp::DeclareTargetCaptureClause::to,
+ /*automap=*/false,
+ };
+ markNestedFuncs(parentInfo, tarOp, visited);
});
}
};
-
} // namespace
diff --git a/flang/lib/Optimizer/OpenMP/Support/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/Support/CMakeLists.txt
index dee35e4..004753d 100644
--- a/flang/lib/Optimizer/OpenMP/Support/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/Support/CMakeLists.txt
@@ -2,6 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
add_flang_library(FIROpenMPSupport
FIROpenMPAttributes.cpp
+ FIROpenMPOpsInterfaces.cpp
RegisterOpenMPExtensions.cpp
DEPENDS
diff --git a/flang/lib/Optimizer/OpenMP/Support/FIROpenMPOpsInterfaces.cpp b/flang/lib/Optimizer/OpenMP/Support/FIROpenMPOpsInterfaces.cpp
new file mode 100644
index 0000000..a396ef0
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/Support/FIROpenMPOpsInterfaces.cpp
@@ -0,0 +1,102 @@
+//===-- FIROpenMPOpsInterfaces.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file implements FIR operation interfaces, which may be attached
+/// to OpenMP dialect operations.
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIROperationMoveOpInterface.h"
+#include "flang/Optimizer/OpenMP/Support/RegisterOpenMPExtensions.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+
+namespace {
+/// Helper template that must be specialized for each operation.
+/// The methods are declared just for documentation.
+template <typename OP, typename Enable = void>
+struct OperationMoveModel {
+ // Returns true if it is allowed to move the given 'candidate'
+ // operation from the 'descendant' operation into operation 'op'.
+ // If 'candidate' is nullptr, then the caller is querying whether
+ // any operation from any descendant can be moved into 'op' operation.
+ bool canMoveFromDescendant(mlir::Operation *op, mlir::Operation *descendant,
+ mlir::Operation *candidate) const;
+
+ // Returns true if it is allowed to move the given 'candidate'
+ // operation out of operation 'op'. If 'candidate' is nullptr,
+ // then the caller is querying whether any operation can be moved
+ // out of 'op' operation.
+ bool canMoveOutOf(mlir::Operation *op, mlir::Operation *candidate) const;
+};
+
+// Helpers to check if T is one of Ts.
+template <typename T, typename... Ts>
+struct is_any_type : std::disjunction<std::is_same<T, Ts>...> {};
+
+template <typename T, typename... Ts>
+struct is_any_omp_op
+ : std::integral_constant<
+ bool, is_any_type<typename std::remove_cv<T>::type, Ts...>::value> {};
+
+template <typename T, typename... Ts>
+constexpr bool is_any_omp_op_v = is_any_omp_op<T, Ts...>::value;
+
+/// OperationMoveModel specialization for OMP_LOOP_WRAPPER_OPS.
+template <typename OP>
+struct OperationMoveModel<
+ OP,
+ typename std::enable_if<is_any_omp_op_v<OP, OMP_LOOP_WRAPPER_OPS>>::type>
+ : public fir::OperationMoveOpInterface::ExternalModel<
+ OperationMoveModel<OP>, OP> {
+ bool canMoveFromDescendant(mlir::Operation *op, mlir::Operation *descendant,
+ mlir::Operation *candidate) const {
+ // Operations cannot be moved from descendants of LoopWrapperInterface
+ // operation into the LoopWrapperInterface operation.
+ return false;
+ }
+ bool canMoveOutOf(mlir::Operation *op, mlir::Operation *candidate) const {
+ // The LoopWrapperInterface operations are only supposed to contain
+ // a loop operation, and it is probably okay to move operations
+ // from the descendant loop operation out of the LoopWrapperInterface
+ // operation. For now, return false to be conservative.
+ return false;
+ }
+};
+
+/// OperationMoveModel specialization for OMP_OUTLINEABLE_OPS.
+template <typename OP>
+struct OperationMoveModel<
+ OP, typename std::enable_if<is_any_omp_op_v<OP, OMP_OUTLINEABLE_OPS>>::type>
+ : public fir::OperationMoveOpInterface::ExternalModel<
+ OperationMoveModel<OP>, OP> {
+ bool canMoveFromDescendant(mlir::Operation *op, mlir::Operation *descendant,
+ mlir::Operation *candidate) const {
+ // Operations can be moved from descendants of OutlineableOpenMPOpInterface
+ // operation into the OutlineableOpenMPOpInterface operation.
+ return true;
+ }
+ bool canMoveOutOf(mlir::Operation *op, mlir::Operation *candidate) const {
+ // Operations cannot be moved out of OutlineableOpenMPOpInterface operation.
+ return false;
+ }
+};
+
+// Helper to call attachInterface<OperationMoveModel> for all Ts
+// (types of operations).
+template <typename... Ts>
+void attachInterfaces(mlir::MLIRContext *ctx) {
+ (Ts::template attachInterface<OperationMoveModel<Ts>>(*ctx), ...);
+}
+} // anonymous namespace
+
+void fir::omp::registerOpInterfacesExtensions(mlir::DialectRegistry &registry) {
+ registry.addExtension(
+ +[](mlir::MLIRContext *ctx, mlir::omp::OpenMPDialect *dialect) {
+ attachInterfaces<OMP_LOOP_WRAPPER_OPS>(ctx);
+ attachInterfaces<OMP_OUTLINEABLE_OPS>(ctx);
+ });
+}
diff --git a/flang/lib/Optimizer/OpenMP/Support/RegisterOpenMPExtensions.cpp b/flang/lib/Optimizer/OpenMP/Support/RegisterOpenMPExtensions.cpp
index 2495d54..de4906e 100644
--- a/flang/lib/Optimizer/OpenMP/Support/RegisterOpenMPExtensions.cpp
+++ b/flang/lib/Optimizer/OpenMP/Support/RegisterOpenMPExtensions.cpp
@@ -15,6 +15,7 @@
namespace fir::omp {
void registerOpenMPExtensions(mlir::DialectRegistry &registry) {
registerAttrsExtensions(registry);
+ registerOpInterfacesExtensions(registry);
}
} // namespace fir::omp
diff --git a/flang/lib/Optimizer/Passes/CommandLineOpts.cpp b/flang/lib/Optimizer/Passes/CommandLineOpts.cpp
index 0142375..75e818d 100644
--- a/flang/lib/Optimizer/Passes/CommandLineOpts.cpp
+++ b/flang/lib/Optimizer/Passes/CommandLineOpts.cpp
@@ -61,6 +61,7 @@ cl::opt<bool> useOldAliasTags(
cl::desc("Use a single TBAA tree for all functions and do not use "
"the FIR alias tags pass"),
cl::init(false), cl::Hidden);
+EnableOption(FirLICM, "fir-licm", "FIR loop invariant code motion");
/// CodeGen Passes
DisableOption(CodeGenRewrite, "codegen-rewrite", "rewrite FIR for codegen");
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 6dae39b..18ad22f 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -20,8 +20,9 @@ namespace fir {
template <typename F>
void addNestedPassToAllTopLevelOperations(mlir::PassManager &pm, F ctor) {
- addNestedPassToOps<F, mlir::func::FuncOp, mlir::omp::DeclareReductionOp,
- mlir::omp::PrivateClauseOp, fir::GlobalOp>(pm, ctor);
+ addNestedPassToOps<F, mlir::func::FuncOp, mlir::omp::DeclareMapperOp,
+ mlir::omp::DeclareReductionOp, mlir::omp::PrivateClauseOp,
+ fir::GlobalOp>(pm, ctor);
}
template <typename F>
@@ -107,8 +108,8 @@ void addDebugInfoPass(mlir::PassManager &pm,
[&]() { return fir::createAddDebugInfoPass(options); });
}
-void addFIRToLLVMPass(mlir::PassManager &pm,
- const MLIRToLLVMPassPipelineConfig &config) {
+fir::FIRToLLVMPassOptions
+getFIRToLLVMPassOptions(const MLIRToLLVMPassPipelineConfig &config) {
fir::FIRToLLVMPassOptions options;
options.ignoreMissingTypeDescriptors = ignoreMissingTypeDescriptors;
options.skipExternalRttiDefinition = skipExternalRttiDefinition;
@@ -117,6 +118,12 @@ void addFIRToLLVMPass(mlir::PassManager &pm,
options.typeDescriptorsRenamedForAssembly =
!disableCompilerGeneratedNamesConversion;
options.ComplexRange = config.ComplexRange;
+ return options;
+}
+
+void addFIRToLLVMPass(mlir::PassManager &pm,
+ const MLIRToLLVMPassPipelineConfig &config) {
+ fir::FIRToLLVMPassOptions options = getFIRToLLVMPassOptions(config);
addPassConditionally(pm, disableFirToLlvmIr,
[&]() { return fir::createFIRToLLVMPass(options); });
// The dialect conversion framework may leave dead unrealized_conversion_cast
@@ -206,6 +213,10 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
pm.addPass(fir::createSimplifyRegionLite());
pm.addPass(mlir::createCSEPass());
+ // Run LICM after CSE, which may reduce the number of operations to hoist.
+ if (enableFirLICM && pc.OptLevel.isOptimizingForSpeed())
+ pm.addPass(fir::createLoopInvariantCodeMotion());
+
// Polymorphic types
pm.addPass(fir::createPolymorphicOpConversion());
pm.addPass(fir::createAssumedRankOpConversion());
@@ -279,7 +290,8 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm,
pm, hlfir::createInlineHLFIRCopyIn);
}
}
- pm.addPass(hlfir::createLowerHLFIROrderedAssignments());
+ pm.addPass(hlfir::createLowerHLFIROrderedAssignments(
+ {/*tryFusingAssignments=*/optLevel.isOptimizingForSpeed()}));
pm.addPass(hlfir::createLowerHLFIRIntrinsics());
hlfir::BufferizeHLFIROptions bufferizeOptions;
@@ -372,7 +384,7 @@ void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm,
fir::addCompilerGeneratedNamesConversionPass(pm);
if (config.VScaleMin != 0)
- pm.addPass(fir::createVScaleAttr({{config.VScaleMin, config.VScaleMax}}));
+ pm.addPass(fir::createVScaleAttr({config.VScaleMin, config.VScaleMax}));
// Add function attributes
mlir::LLVM::framePointerKind::FramePointerKind framePointerKind;
@@ -383,6 +395,9 @@ void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm,
framePointerKind = mlir::LLVM::framePointerKind::FramePointerKind::All;
else if (config.FramePointerKind == llvm::FramePointerKind::Reserved)
framePointerKind = mlir::LLVM::framePointerKind::FramePointerKind::Reserved;
+ else if (config.FramePointerKind == llvm::FramePointerKind::NonLeafNoReserve)
+ framePointerKind =
+ mlir::LLVM::framePointerKind::FramePointerKind::NonLeafNoReserve;
else
framePointerKind = mlir::LLVM::framePointerKind::FramePointerKind::None;
@@ -426,6 +441,12 @@ void createMLIRToLLVMPassPipeline(mlir::PassManager &pm,
// Add codegen pass pipeline.
fir::createDefaultFIRCodeGenPassPipeline(pm, config, inputFilename);
+
+ // Run a pass to prepare for translation of delayed privatization in the
+ // context of deferred target tasks.
+ addPassConditionally(pm, disableFirToLlvmIr, [&]() {
+ return mlir::omp::createPrepareForOMPOffloadPrivatizationPass();
+ });
}
} // namespace fir
diff --git a/flang/lib/Optimizer/Support/CMakeLists.txt b/flang/lib/Optimizer/Support/CMakeLists.txt
index 38038e1..6f3652b 100644
--- a/flang/lib/Optimizer/Support/CMakeLists.txt
+++ b/flang/lib/Optimizer/Support/CMakeLists.txt
@@ -7,9 +7,11 @@ add_flang_library(FIRSupport
DEPENDS
FIROpsIncGen
HLFIROpsIncGen
+ MIFOpsIncGen
LINK_LIBS
FIRDialect
+ MIFDialect
LINK_COMPONENTS
TargetParser
diff --git a/flang/lib/Optimizer/Support/Utils.cpp b/flang/lib/Optimizer/Support/Utils.cpp
index 92390e4a..2f33d89 100644
--- a/flang/lib/Optimizer/Support/Utils.cpp
+++ b/flang/lib/Optimizer/Support/Utils.cpp
@@ -66,7 +66,7 @@ fir::genConstantIndex(mlir::Location loc, mlir::Type ity,
mlir::ConversionPatternRewriter &rewriter,
std::int64_t offset) {
auto cattr = rewriter.getI64IntegerAttr(offset);
- return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr);
+ return mlir::LLVM::ConstantOp::create(rewriter, loc, ity, cattr);
}
mlir::Value
@@ -125,9 +125,9 @@ mlir::Value fir::integerCast(const fir::LLVMTypeConverter &converter,
return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val);
} else {
if (toSize < fromSize)
- return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val);
+ return mlir::LLVM::TruncOp::create(rewriter, loc, ty, val);
if (toSize > fromSize)
- return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val);
+ return mlir::LLVM::SExtOp::create(rewriter, loc, ty, val);
}
return val;
}
diff --git a/flang/lib/Optimizer/Transforms/AddAliasTags.cpp b/flang/lib/Optimizer/Transforms/AddAliasTags.cpp
index 0221c7a..142e4c8 100644
--- a/flang/lib/Optimizer/Transforms/AddAliasTags.cpp
+++ b/flang/lib/Optimizer/Transforms/AddAliasTags.cpp
@@ -60,6 +60,9 @@ static llvm::cl::opt<unsigned> localAllocsThreshold(
llvm::cl::desc("If present, stops generating TBAA tags for accesses of "
"local allocations after N accesses in a module"));
+// Defined in AliasAnalysis.cpp
+extern llvm::cl::opt<bool> supportCrayPointers;
+
namespace {
// Return the size and alignment (in bytes) for the given type.
@@ -210,10 +213,7 @@ public:
void processFunctionScopes(mlir::func::FuncOp func);
// For the given fir.declare returns the dominating fir.dummy_scope
// operation.
- fir::DummyScopeOp getDeclarationScope(fir::DeclareOp declareOp) const;
- // For the given fir.declare returns the outermost fir.dummy_scope
- // in the current function.
- fir::DummyScopeOp getOutermostScope(fir::DeclareOp declareOp) const;
+ fir::DummyScopeOp getDeclarationScope(fir::DeclareOp declareOp);
// Returns true, if the given type of a memref of a FirAliasTagOpInterface
// operation is a descriptor or contains a descriptor
// (e.g. !fir.ref<!fir.type<Derived{f:!fir.box<!fir.heap<f32>>}>>).
@@ -353,8 +353,9 @@ void PassState::processFunctionScopes(mlir::func::FuncOp func) {
}
}
-fir::DummyScopeOp
-PassState::getDeclarationScope(fir::DeclareOp declareOp) const {
+// For the given fir.declare returns the dominating fir.dummy_scope
+// operation.
+fir::DummyScopeOp PassState::getDeclarationScope(fir::DeclareOp declareOp) {
auto func = declareOp->getParentOfType<mlir::func::FuncOp>();
assert(func && "fir.declare does not have parent func.func");
auto &scopeOps = sortedScopeOperations.at(func);
@@ -365,15 +366,6 @@ PassState::getDeclarationScope(fir::DeclareOp declareOp) const {
return nullptr;
}
-fir::DummyScopeOp PassState::getOutermostScope(fir::DeclareOp declareOp) const {
- auto func = declareOp->getParentOfType<mlir::func::FuncOp>();
- assert(func && "fir.declare does not have parent func.func");
- auto &scopeOps = sortedScopeOperations.at(func);
- if (!scopeOps.empty())
- return scopeOps[0];
- return nullptr;
-}
-
bool PassState::typeReferencesDescriptor(mlir::Type type) {
type = fir::unwrapAllRefAndSeqType(type);
if (mlir::isa<fir::BaseBoxType>(type))
@@ -668,6 +660,7 @@ void AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op,
LLVM_DEBUG(llvm::dbgs() << "Analysing " << op << "\n");
const fir::AliasAnalysis::Source &source = state.getSource(memref);
+ LLVM_DEBUG(llvm::dbgs() << "Got source " << source << "\n");
// Process the scopes, if not processed yet.
state.processFunctionScopes(func);
@@ -686,14 +679,22 @@ void AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op,
}
mlir::LLVM::TBAATagAttr tag;
- // TBAA for dummy arguments
- if (enableDummyArgs &&
- source.kind == fir::AliasAnalysis::SourceKind::Argument) {
+ // Cray pointer/pointee is a special case. These might alias with any data.
+ if (supportCrayPointers && source.isCrayPointerOrPointee()) {
+ LLVM_DEBUG(llvm::dbgs().indent(2)
+ << "Found reference to Cray pointer/pointee at " << *op << "\n");
+ mlir::LLVM::TBAATypeDescriptorAttr anyDataDesc =
+ state.getFuncTreeWithScope(func, scopeOp).anyDataTypeDesc;
+ tag = mlir::LLVM::TBAATagAttr::get(anyDataDesc, anyDataDesc, /*offset=*/0);
+ // TBAA for dummy arguments
+ } else if (enableDummyArgs &&
+ source.kind == fir::AliasAnalysis::SourceKind::Argument) {
LLVM_DEBUG(llvm::dbgs().indent(2)
<< "Found reference to dummy argument at " << *op << "\n");
std::string name = getFuncArgName(llvm::cast<mlir::Value>(source.origin.u));
- // If it is a TARGET or POINTER, then we do not care about the name,
- // because the tag points to the root of the subtree currently.
+ // POINTERS can alias with any POINTER or TARGET. Assume that TARGET dummy
+ // arguments might alias with each other (because of the "TARGET" hole for
+ // dummy arguments). See flang/docs/Aliasing.md.
if (source.isTargetOrPointer()) {
tag = state.getFuncTreeWithScope(func, scopeOp).targetDataTree.getTag();
} else if (!name.empty()) {
@@ -715,13 +716,10 @@ void AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op,
LLVM_DEBUG(llvm::dbgs().indent(2)
<< "Found reference to global " << globalName.str() << " at "
<< *op << "\n");
- if (source.isPointer()) {
- tag = state.getFuncTreeWithScope(func, scopeOp).targetDataTree.getTag();
- } else {
- // In general, place the tags under the "global data" root.
- fir::TBAATree::SubtreeState *subTree =
- &state.getMutableFuncTreeWithScope(func, scopeOp).globalDataTree;
+    // Add a named tag inside the given subtree, disambiguating members of a
+    // common block.
+ auto addTagUsingStorageDesc = [&](fir::TBAATree::SubtreeState *subTree) {
mlir::Operation *instantiationPoint = source.origin.instantiationPoint;
auto storageIface =
mlir::dyn_cast_or_null<fir::FortranVariableStorageOpInterface>(
@@ -766,6 +764,19 @@ void AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op,
LLVM_DEBUG(llvm::dbgs()
<< "Tagged under '" << globalName << "' root\n");
}
+ };
+
+ if (source.isPointer()) {
+ // Pointers can alias with any pointer or target.
+ tag = state.getFuncTreeWithScope(func, scopeOp).targetDataTree.getTag();
+ } else if (source.isTarget()) {
+ // Targets could alias with any pointer but not with each other.
+ addTagUsingStorageDesc(
+ &state.getMutableFuncTreeWithScope(func, scopeOp).targetDataTree);
+ } else {
+ // In general, place the tags under the "global data" root.
+ addTagUsingStorageDesc(
+ &state.getMutableFuncTreeWithScope(func, scopeOp).globalDataTree);
}
// TBAA for global variables with descriptors
@@ -776,9 +787,17 @@ void AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op,
const char *name = glbl.getRootReference().data();
LLVM_DEBUG(llvm::dbgs().indent(2) << "Found reference to direct " << name
<< " at " << *op << "\n");
+ // Pointer can alias with any pointer or target so that gets the root.
if (source.isPointer())
tag = state.getFuncTreeWithScope(func, scopeOp).targetDataTree.getTag();
+ // Targets could alias with any pointer but not with each other so they
+ // get their own node inside of the target data tree.
+ else if (source.isTarget())
+ tag = state.getFuncTreeWithScope(func, scopeOp)
+ .targetDataTree.getTag(name);
else
+ // Boxes that are not pointers or targets cannot alias with those that
+ // are. Put them under global data.
tag = state.getFuncTreeWithScope(func, scopeOp)
.directDataTree.getTag(name);
} else {
@@ -800,22 +819,23 @@ void AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op,
else
unknownAllocOp = true;
- if (auto declOp = source.origin.instantiationPoint) {
- // Use the outermost scope for local allocations,
- // because using the innermost scope may result
- // in incorrect TBAA, when calls are inlined in MLIR.
- auto declareOp = mlir::dyn_cast<fir::DeclareOp>(declOp);
- assert(declareOp && "Instantiation point must be fir.declare");
- scopeOp = state.getOutermostScope(declareOp);
- }
-
if (unknownAllocOp) {
LLVM_DEBUG(llvm::dbgs().indent(2)
<< "WARN: unknown defining op for SourceKind::Allocate " << *op
<< "\n");
} else if (source.isPointer() && state.attachLocalAllocTag()) {
LLVM_DEBUG(llvm::dbgs().indent(2)
- << "Found reference to allocation at " << *op << "\n");
+ << "Found reference to POINTER allocation at " << *op << "\n");
+ tag = state.getFuncTreeWithScope(func, scopeOp).targetDataTree.getTag();
+ } else if (name && source.isTarget() && state.attachLocalAllocTag()) {
+ LLVM_DEBUG(llvm::dbgs().indent(2)
+ << "Found reference to TARGET allocation at " << *op << "\n");
+ tag = state.getFuncTreeWithScope(func, scopeOp)
+ .targetDataTree.getTag(*name);
+ } else if (source.isTarget() && state.attachLocalAllocTag()) {
+ LLVM_DEBUG(llvm::dbgs().indent(2)
+ << "WARN: couldn't find a name for TARGET allocation " << *op
+ << "\n");
tag = state.getFuncTreeWithScope(func, scopeOp).targetDataTree.getTag();
} else if (name && state.attachLocalAllocTag()) {
LLVM_DEBUG(llvm::dbgs().indent(2) << "Found reference to allocation "
diff --git a/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp b/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp
index e006d2e..35d8a2f6 100644
--- a/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp
+++ b/flang/lib/Optimizer/Transforms/AddDebugInfo.cpp
@@ -53,7 +53,7 @@ class AddDebugInfoPass : public fir::impl::AddDebugInfoBase<AddDebugInfoPass> {
mlir::LLVM::DIFileAttr fileAttr,
mlir::LLVM::DIScopeAttr scopeAttr,
fir::DebugTypeGenerator &typeGen,
- mlir::SymbolTable *symbolTable);
+ mlir::SymbolTable *symbolTable, mlir::Value dummyScope);
public:
AddDebugInfoPass(fir::AddDebugInfoOptions options) : Base(options) {}
@@ -84,6 +84,24 @@ private:
mlir::LLVM::DICompileUnitAttr cuAttr,
fir::DebugTypeGenerator &typeGen,
mlir::SymbolTable *symbolTable);
+ void handleOnlyClause(
+ fir::UseStmtOp useOp, mlir::LLVM::DISubprogramAttr spAttr,
+ mlir::LLVM::DIFileAttr fileAttr, mlir::SymbolTable *symbolTable,
+ llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> &importedModules);
+ void handleRenamesWithoutOnly(
+ fir::UseStmtOp useOp, mlir::LLVM::DISubprogramAttr spAttr,
+ mlir::LLVM::DIModuleAttr modAttr, mlir::LLVM::DIFileAttr fileAttr,
+ mlir::SymbolTable *symbolTable,
+ llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> &importedModules);
+ void handleUseStatements(
+ mlir::func::FuncOp funcOp, mlir::LLVM::DISubprogramAttr spAttr,
+ mlir::LLVM::DIFileAttr fileAttr, mlir::LLVM::DICompileUnitAttr cuAttr,
+ mlir::SymbolTable *symbolTable,
+ llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> &importedEntities);
+ std::optional<mlir::LLVM::DIImportedEntityAttr> createImportedDeclForGlobal(
+ llvm::StringRef symbolName, mlir::LLVM::DISubprogramAttr spAttr,
+ mlir::LLVM::DIFileAttr fileAttr, mlir::StringAttr localNameAttr,
+ mlir::SymbolTable *symbolTable);
bool createCommonBlockGlobal(fir::cg::XDeclareOp declOp,
const std::string &name,
mlir::LLVM::DIFileAttr fileAttr,
@@ -138,75 +156,122 @@ mlir::StringAttr getTargetFunctionName(mlir::MLIRContext *context,
} // namespace
+// Check if a global represents a module variable.
+static bool isModuleVariable(fir::GlobalOp globalOp) {
+ std::pair result = fir::NameUniquer::deconstruct(globalOp.getSymName());
+ return result.first == fir::NameUniquer::NameKind::VARIABLE &&
+ result.second.procs.empty() && !result.second.modules.empty();
+}
+
+// Look up DIGlobalVariable from a global symbol.
+static std::optional<mlir::LLVM::DIGlobalVariableAttr>
+lookupDIGlobalVariable(llvm::StringRef symbolName,
+ mlir::SymbolTable *symbolTable) {
+ if (auto globalOp = symbolTable->lookup<fir::GlobalOp>(symbolName)) {
+ if (auto fusedLoc = mlir::dyn_cast<mlir::FusedLoc>(globalOp.getLoc())) {
+ if (auto metadata = fusedLoc.getMetadata()) {
+ if (auto arrayAttr = mlir::dyn_cast<mlir::ArrayAttr>(metadata)) {
+ for (auto elem : arrayAttr) {
+ if (auto gvExpr =
+ mlir::dyn_cast<mlir::LLVM::DIGlobalVariableExpressionAttr>(
+ elem))
+ return gvExpr.getVar();
+ }
+ }
+ }
+ }
+ }
+ return std::nullopt;
+}
+
bool AddDebugInfoPass::createCommonBlockGlobal(
fir::cg::XDeclareOp declOp, const std::string &name,
mlir::LLVM::DIFileAttr fileAttr, mlir::LLVM::DIScopeAttr scopeAttr,
fir::DebugTypeGenerator &typeGen, mlir::SymbolTable *symbolTable) {
mlir::MLIRContext *context = &getContext();
mlir::OpBuilder builder(context);
- std::optional<std::int64_t> optint;
- mlir::Operation *op = declOp.getMemref().getDefiningOp();
-
- if (auto conOp = mlir::dyn_cast_if_present<fir::ConvertOp>(op))
- op = conOp.getValue().getDefiningOp();
- if (auto cordOp = mlir::dyn_cast_if_present<fir::CoordinateOp>(op)) {
- auto coors = cordOp.getCoor();
- if (coors.size() != 1)
- return false;
- optint = fir::getIntIfConstant(coors[0]);
- if (!optint)
- return false;
- op = cordOp.getRef().getDefiningOp();
- if (auto conOp2 = mlir::dyn_cast_if_present<fir::ConvertOp>(op))
- op = conOp2.getValue().getDefiningOp();
-
- if (auto addrOfOp = mlir::dyn_cast_if_present<fir::AddrOfOp>(op)) {
- mlir::SymbolRefAttr sym = addrOfOp.getSymbol();
- if (auto global =
- symbolTable->lookup<fir::GlobalOp>(sym.getRootReference())) {
-
- unsigned line = getLineFromLoc(global.getLoc());
- llvm::StringRef commonName(sym.getRootReference());
- // FIXME: We are trying to extract the name of the common block from the
- // name of the global. As part of mangling, GetCommonBlockObjectName can
- // add a trailing _ in the name of that global. The demangle function
- // does not seem to handle such cases. So the following hack is used to
- // remove the trailing '_'.
- if (commonName != Fortran::common::blankCommonObjectName &&
- commonName.back() == '_')
- commonName = commonName.drop_back();
- mlir::LLVM::DICommonBlockAttr commonBlock =
- getOrCreateCommonBlockAttr(commonName, fileAttr, scopeAttr, line);
- mlir::LLVM::DITypeAttr diType = typeGen.convertType(
- fir::unwrapRefType(declOp.getType()), fileAttr, scopeAttr, declOp);
- line = getLineFromLoc(declOp.getLoc());
- auto gvAttr = mlir::LLVM::DIGlobalVariableAttr::get(
- context, commonBlock, mlir::StringAttr::get(context, name),
- declOp.getUniqName(), fileAttr, line, diType,
- /*isLocalToUnit*/ false, /*isDefinition*/ true, /* alignInBits*/ 0);
- mlir::LLVM::DIExpressionAttr expr;
- if (*optint != 0) {
- llvm::SmallVector<mlir::LLVM::DIExpressionElemAttr> ops;
- ops.push_back(mlir::LLVM::DIExpressionElemAttr::get(
- context, llvm::dwarf::DW_OP_plus_uconst, *optint));
- expr = mlir::LLVM::DIExpressionAttr::get(context, ops);
- }
- auto dbgExpr = mlir::LLVM::DIGlobalVariableExpressionAttr::get(
- global.getContext(), gvAttr, expr);
- globalToGlobalExprsMap[global].push_back(dbgExpr);
- return true;
- }
- }
+ std::optional<std::int64_t> offset;
+ mlir::Value storage = declOp.getStorage();
+ if (!storage)
+ return false;
+
+ // Extract offset from storage_offset attribute
+ uint64_t storageOffset = declOp.getStorageOffset();
+ if (storageOffset != 0)
+ offset = static_cast<std::int64_t>(storageOffset);
+
+ // Get the GlobalOp from the storage value.
+ // The storage may be wrapped in ConvertOp, so unwrap it first.
+ mlir::Operation *storageOp = storage.getDefiningOp();
+ if (auto convertOp = mlir::dyn_cast_if_present<fir::ConvertOp>(storageOp))
+ storageOp = convertOp.getValue().getDefiningOp();
+
+ auto addrOfOp = mlir::dyn_cast_if_present<fir::AddrOfOp>(storageOp);
+ if (!addrOfOp)
+ return false;
+
+ mlir::SymbolRefAttr sym = addrOfOp.getSymbol();
+ fir::GlobalOp global =
+ symbolTable->lookup<fir::GlobalOp>(sym.getRootReference());
+ if (!global)
+ return false;
+
+ // Check if the global is actually a common block by demangling its name.
+ // Module EQUIVALENCE variables also use storage operands but are mangled
+ // as VARIABLE type, so we reject them to avoid treating them as common
+ // blocks.
+ llvm::StringRef globalSymbol = sym.getRootReference();
+ auto globalResult = fir::NameUniquer::deconstruct(globalSymbol);
+ if (globalResult.first == fir::NameUniquer::NameKind::VARIABLE)
+ return false;
+
+ // FIXME: We are trying to extract the name of the common block from the
+ // name of the global. As part of mangling, GetCommonBlockObjectName can
+ // add a trailing _ in the name of that global. The demangle function
+ // does not seem to handle such cases. So the following hack is used to
+ // remove the trailing '_'.
+ llvm::StringRef commonName = globalSymbol;
+ if (commonName != Fortran::common::blankCommonObjectName &&
+ !commonName.empty() && commonName.back() == '_')
+ commonName = commonName.drop_back();
+
+ // Create the debug attributes.
+ unsigned line = getLineFromLoc(global.getLoc());
+ mlir::LLVM::DICommonBlockAttr commonBlock =
+ getOrCreateCommonBlockAttr(commonName, fileAttr, scopeAttr, line);
+
+ mlir::LLVM::DITypeAttr diType = typeGen.convertType(
+ fir::unwrapRefType(declOp.getType()), fileAttr, scopeAttr, declOp);
+
+ line = getLineFromLoc(declOp.getLoc());
+ auto gvAttr = mlir::LLVM::DIGlobalVariableAttr::get(
+ context, commonBlock, mlir::StringAttr::get(context, name),
+ declOp.getUniqName(), fileAttr, line, diType,
+ /*isLocalToUnit*/ false, /*isDefinition*/ true, /* alignInBits*/ 0);
+
+ // Create DIExpression for offset if needed
+ mlir::LLVM::DIExpressionAttr expr;
+ if (offset && *offset != 0) {
+ llvm::SmallVector<mlir::LLVM::DIExpressionElemAttr> ops;
+ ops.push_back(mlir::LLVM::DIExpressionElemAttr::get(
+ context, llvm::dwarf::DW_OP_plus_uconst, *offset));
+ expr = mlir::LLVM::DIExpressionAttr::get(context, ops);
}
- return false;
+
+ auto dbgExpr = mlir::LLVM::DIGlobalVariableExpressionAttr::get(
+ global.getContext(), gvAttr, expr);
+ globalToGlobalExprsMap[global].push_back(dbgExpr);
+
+ return true;
}
void AddDebugInfoPass::handleDeclareOp(fir::cg::XDeclareOp declOp,
mlir::LLVM::DIFileAttr fileAttr,
mlir::LLVM::DIScopeAttr scopeAttr,
fir::DebugTypeGenerator &typeGen,
- mlir::SymbolTable *symbolTable) {
+ mlir::SymbolTable *symbolTable,
+ mlir::Value dummyScope) {
mlir::MLIRContext *context = &getContext();
mlir::OpBuilder builder(context);
auto result = fir::NameUniquer::deconstruct(declOp.getUniqName());
@@ -228,24 +293,11 @@ void AddDebugInfoPass::handleDeclareOp(fir::cg::XDeclareOp declOp,
}
}
- // FIXME: There may be cases where an argument is processed a bit before
- // DeclareOp is generated. In that case, DeclareOp may point to an
- // intermediate op and not to BlockArgument.
- // Moreover, with MLIR inlining we cannot use the BlockArgument
- // position to identify the original number of the dummy argument.
- // If we want to keep running AddDebugInfoPass late, the dummy argument
- // position in the argument list has to be expressed in FIR (e.g. as a
- // constant attribute of [hl]fir.declare/fircg.ext_declare operation that has
- // a dummy_scope operand).
+ // Get the dummy argument position from the explicit attribute.
unsigned argNo = 0;
- if (declOp.getDummyScope()) {
- if (auto arg = llvm::dyn_cast<mlir::BlockArgument>(declOp.getMemref())) {
- // Check if it is the BlockArgument of the function's entry block.
- if (auto funcLikeOp =
- declOp->getParentOfType<mlir::FunctionOpInterface>())
- if (arg.getOwner() == &funcLikeOp.front())
- argNo = arg.getArgNumber() + 1;
- }
+ if (dummyScope && declOp.getDummyScope() == dummyScope) {
+ if (auto argNoOpt = declOp.getDummyArgNo())
+ argNo = *argNoOpt;
}
auto tyAttr = typeGen.convertType(fir::unwrapRefType(declOp.getType()),
@@ -520,7 +572,7 @@ void AddDebugInfoPass::handleFuncOp(mlir::func::FuncOp funcOp,
CC = llvm::dwarf::getCallingConvention("DW_CC_normal");
mlir::LLVM::DISubroutineTypeAttr spTy =
mlir::LLVM::DISubroutineTypeAttr::get(context, CC, types);
- if (lineTableOnly) {
+ if (lineTableOnly || entities.empty()) {
auto spAttr = mlir::LLVM::DISubprogramAttr::get(
context, id, compilationUnit, Scope, name, name, funcFileAttr, line,
line, flags, spTy, /*retainedNodes=*/{}, /*annotations=*/{});
@@ -540,9 +592,9 @@ void AddDebugInfoPass::handleFuncOp(mlir::func::FuncOp funcOp,
for (mlir::LLVM::DINodeAttr N : entities) {
if (auto entity = mlir::dyn_cast<mlir::LLVM::DIImportedEntityAttr>(N)) {
auto importedEntity = mlir::LLVM::DIImportedEntityAttr::get(
- context, llvm::dwarf::DW_TAG_imported_module, spAttr,
- entity.getEntity(), fileAttr, /*line=*/1, /*name=*/nullptr,
- /*elements*/ {});
+ context, entity.getTag(), spAttr, entity.getEntity(),
+ entity.getFile(), entity.getLine(), entity.getName(),
+ entity.getElements());
opEntities.push_back(importedEntity);
}
}
@@ -567,61 +619,72 @@ void AddDebugInfoPass::handleFuncOp(mlir::func::FuncOp funcOp,
return;
}
- mlir::DistinctAttr recId =
- mlir::DistinctAttr::create(mlir::UnitAttr::get(context));
-
- // The debug attribute in MLIR are readonly once created. But in case of
- // imported entities, we have a circular dependency. The
- // DIImportedEntityAttr requires scope information (DISubprogramAttr in this
- // case) and DISubprogramAttr requires the list of imported entities. The
- // MLIR provides a way where a DISubprogramAttr an be created with a certain
- // recID and be used in places like DIImportedEntityAttr. After that another
- // DISubprogramAttr can be created with same recID but with list of entities
- // now available. The MLIR translation code takes care of updating the
- // references. Note that references will be updated only in the things that
- // are part of DISubprogramAttr (like DIImportedEntityAttr) so we have to
- // create the final DISubprogramAttr before we process local variables.
- // Look at DIRecursiveTypeAttrInterface for more details.
-
- auto spAttr = mlir::LLVM::DISubprogramAttr::get(
- context, recId, /*isRecSelf=*/true, id, compilationUnit, Scope, funcName,
- fullName, funcFileAttr, line, line, subprogramFlags, subTypeAttr,
- /*retainedNodes=*/{}, /*annotations=*/{});
-
- // There is no direct information in the IR for any 'use' statement in the
- // function. We have to extract that information from the DeclareOp. We do
- // a pass on the DeclareOp and generate ModuleAttr and corresponding
- // DIImportedEntityAttr for that module.
- // FIXME: As we are depending on the variables to see which module is being
- // 'used' in the function, there are certain limitations.
- // For things like 'use mod1, only: v1', whole module will be brought into the
- // namespace in the debug info. It is not a problem as such unless there is a
- // clash of names.
- // There is no information about module variable renaming
- llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> importedModules;
- funcOp.walk([&](fir::cg::XDeclareOp declOp) {
- if (&funcOp.front() == declOp->getBlock())
- if (auto global =
- symbolTable->lookup<fir::GlobalOp>(declOp.getUniqName())) {
- std::optional<mlir::LLVM::DIModuleAttr> modOpt =
- getModuleAttrFromGlobalOp(global, fileAttr, cuAttr);
- if (modOpt) {
- auto importedEntity = mlir::LLVM::DIImportedEntityAttr::get(
- context, llvm::dwarf::DW_TAG_imported_module, spAttr, *modOpt,
- fileAttr, /*line=*/1, /*name=*/nullptr, /*elements*/ {});
- importedModules.insert(importedEntity);
- }
- }
+ // Check if there are any USE statements
+ bool hasUseStmts = false;
+ funcOp.walk([&](fir::UseStmtOp useOp) {
+ hasUseStmts = true;
+ return mlir::WalkResult::interrupt();
});
- llvm::SmallVector<mlir::LLVM::DINodeAttr> entities(importedModules.begin(),
- importedModules.end());
- // We have the imported entities now. Generate the final DISubprogramAttr.
- spAttr = mlir::LLVM::DISubprogramAttr::get(
- context, recId, /*isRecSelf=*/false, id2, compilationUnit, Scope,
- funcName, fullName, funcFileAttr, line, line, subprogramFlags,
- subTypeAttr, entities, /*annotations=*/{});
+
+ mlir::LLVM::DISubprogramAttr spAttr;
+ llvm::SmallVector<mlir::LLVM::DINodeAttr> retainedNodes;
+
+ if (hasUseStmts) {
+ mlir::DistinctAttr recId =
+ mlir::DistinctAttr::create(mlir::UnitAttr::get(context));
+    // The debug attributes in MLIR are readonly once created. But in case of
+    // imported entities, we have a circular dependency. The
+    // DIImportedEntityAttr requires scope information (DISubprogramAttr in this
+    // case) and DISubprogramAttr requires the list of imported entities. The
+    // MLIR provides a way where a DISubprogramAttr can be created with a
+    // certain recID and be used in places like DIImportedEntityAttr. After that
+    // another DISubprogramAttr can be created with the same recID but with the
+    // list of entities now available. The MLIR translation code takes care of
+    // updating the references. Note that references will be updated only in the
+    // things that are part of DISubprogramAttr (like DIImportedEntityAttr) so
+    // we have to create the final DISubprogramAttr before we process local
+    // variables. Look at DIRecursiveTypeAttrInterface for more details.
+ spAttr = mlir::LLVM::DISubprogramAttr::get(
+ context, recId, /*isRecSelf=*/true, id, compilationUnit, Scope,
+ funcName, fullName, funcFileAttr, line, line, subprogramFlags,
+ subTypeAttr, /*retainedNodes=*/{}, /*annotations=*/{});
+
+ // Process USE statements (module globals are already processed)
+ llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> importedEntities;
+ handleUseStatements(funcOp, spAttr, fileAttr, cuAttr, symbolTable,
+ importedEntities);
+
+ retainedNodes.append(importedEntities.begin(), importedEntities.end());
+
+ // Create final DISubprogramAttr with imported entities and same recId
+ spAttr = mlir::LLVM::DISubprogramAttr::get(
+ context, recId, /*isRecSelf=*/false, id2, compilationUnit, Scope,
+ funcName, fullName, funcFileAttr, line, line, subprogramFlags,
+ subTypeAttr, retainedNodes, /*annotations=*/{});
+ } else
+ // No USE statements - create final DISubprogramAttr directly
+ spAttr = mlir::LLVM::DISubprogramAttr::get(
+ context, id, compilationUnit, Scope, funcName, fullName, funcFileAttr,
+ line, line, subprogramFlags, subTypeAttr, /*retainedNodes=*/{},
+ /*annotations=*/{});
+
funcOp->setLoc(builder.getFusedLoc({l}, spAttr));
- addTargetOpDISP(/*lineTableOnly=*/false, entities);
+ addTargetOpDISP(/*lineTableOnly=*/false, retainedNodes);
+
+ // Find the first dummy_scope definition. This is the one of the current
+ // function. The other ones may come from inlined calls. The variables inside
+ // those inlined calls should not be identified as arguments of the current
+ // function.
+ mlir::Value dummyScope;
+ funcOp.walk([&](fir::UndefOp undef) -> mlir::WalkResult {
+    // TODO: delay fir.dummy_scope translation to undefined until
+    // code generation. This is nicer and safer to match.
+ if (llvm::isa<fir::DummyScopeType>(undef.getType())) {
+ dummyScope = undef;
+ return mlir::WalkResult::interrupt();
+ }
+ return mlir::WalkResult::advance();
+ });
funcOp.walk([&](fir::cg::XDeclareOp declOp) {
mlir::LLVM::DISubprogramAttr spTy = spAttr;
@@ -632,7 +695,7 @@ void AddDebugInfoPass::handleFuncOp(mlir::func::FuncOp funcOp,
spTy = sp;
}
}
- handleDeclareOp(declOp, fileAttr, spTy, typeGen, symbolTable);
+ handleDeclareOp(declOp, fileAttr, spTy, typeGen, symbolTable, dummyScope);
});
// commonBlockMap ensures that we don't create multiple DICommonBlockAttr of
// the same name in one function. But it is ok (rather required) to create
@@ -641,6 +704,110 @@ void AddDebugInfoPass::handleFuncOp(mlir::func::FuncOp funcOp,
commonBlockMap.clear();
}
+// Helper function to create a DIImportedEntityAttr for an imported declaration.
+// Looks up the DIGlobalVariable for the given symbol and creates an imported
+// declaration with the optional local name (for renames).
+// Returns std::nullopt if the symbol's DIGlobalVariable is not found.
+std::optional<mlir::LLVM::DIImportedEntityAttr>
+AddDebugInfoPass::createImportedDeclForGlobal(
+ llvm::StringRef symbolName, mlir::LLVM::DISubprogramAttr spAttr,
+ mlir::LLVM::DIFileAttr fileAttr, mlir::StringAttr localNameAttr,
+ mlir::SymbolTable *symbolTable) {
+ mlir::MLIRContext *context = &getContext();
+ if (auto gvAttr = lookupDIGlobalVariable(symbolName, symbolTable)) {
+ return mlir::LLVM::DIImportedEntityAttr::get(
+ context, llvm::dwarf::DW_TAG_imported_declaration, spAttr, *gvAttr,
+ fileAttr, /*line=*/1, /*name=*/localNameAttr, /*elements*/ {});
+ }
+ return std::nullopt;
+}
+
+// Process USE with ONLY clause
+void AddDebugInfoPass::handleOnlyClause(
+ fir::UseStmtOp useOp, mlir::LLVM::DISubprogramAttr spAttr,
+ mlir::LLVM::DIFileAttr fileAttr, mlir::SymbolTable *symbolTable,
+ llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> &importedModules) {
+ // Process ONLY symbols (without renames)
+ if (auto onlySymbols = useOp.getOnlySymbols()) {
+ for (mlir::Attribute attr : *onlySymbols) {
+ auto symbolRef = mlir::cast<mlir::FlatSymbolRefAttr>(attr);
+ if (auto importedDecl = createImportedDeclForGlobal(
+ symbolRef.getValue(), spAttr, fileAttr, mlir::StringAttr(),
+ symbolTable))
+ importedModules.insert(*importedDecl);
+ }
+ }
+
+ // Process renames within ONLY clause
+ if (auto renames = useOp.getRenames()) {
+ for (auto attr : *renames) {
+ auto renameAttr = mlir::cast<fir::UseRenameAttr>(attr);
+ if (auto importedDecl = createImportedDeclForGlobal(
+ renameAttr.getSymbol().getValue(), spAttr, fileAttr,
+ renameAttr.getLocalName(), symbolTable))
+ importedModules.insert(*importedDecl);
+ }
+ }
+}
+
+// Process USE with renames but no ONLY clause
+void AddDebugInfoPass::handleRenamesWithoutOnly(
+ fir::UseStmtOp useOp, mlir::LLVM::DISubprogramAttr spAttr,
+ mlir::LLVM::DIModuleAttr modAttr, mlir::LLVM::DIFileAttr fileAttr,
+ mlir::SymbolTable *symbolTable,
+ llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> &importedModules) {
+ mlir::MLIRContext *context = &getContext();
+ llvm::SmallVector<mlir::LLVM::DINodeAttr> childDeclarations;
+
+ if (auto renames = useOp.getRenames()) {
+ for (auto attr : *renames) {
+ auto renameAttr = mlir::cast<fir::UseRenameAttr>(attr);
+ if (auto importedDecl = createImportedDeclForGlobal(
+ renameAttr.getSymbol().getValue(), spAttr, fileAttr,
+ renameAttr.getLocalName(), symbolTable))
+ childDeclarations.push_back(*importedDecl);
+ }
+ }
+
+ // Create module import with renamed declarations as children
+ auto moduleImport = mlir::LLVM::DIImportedEntityAttr::get(
+ context, llvm::dwarf::DW_TAG_imported_module, spAttr, modAttr, fileAttr,
+ /*line=*/1, /*name=*/nullptr, childDeclarations);
+ importedModules.insert(moduleImport);
+}
+
+// Process all USE statements in a function and collect imported entities
+void AddDebugInfoPass::handleUseStatements(
+ mlir::func::FuncOp funcOp, mlir::LLVM::DISubprogramAttr spAttr,
+ mlir::LLVM::DIFileAttr fileAttr, mlir::LLVM::DICompileUnitAttr cuAttr,
+ mlir::SymbolTable *symbolTable,
+ llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> &importedEntities) {
+ mlir::MLIRContext *context = &getContext();
+
+ funcOp.walk([&](fir::UseStmtOp useOp) {
+ mlir::LLVM::DIModuleAttr modAttr = getOrCreateModuleAttr(
+ useOp.getModuleName().str(), fileAttr, cuAttr, /*line=*/1,
+ /*decl=*/true);
+
+ llvm::DenseSet<mlir::LLVM::DIImportedEntityAttr> importedModules;
+
+ if (useOp.hasOnlyClause())
+ handleOnlyClause(useOp, spAttr, fileAttr, symbolTable, importedModules);
+ else if (useOp.hasRenames())
+ handleRenamesWithoutOnly(useOp, spAttr, modAttr, fileAttr, symbolTable,
+ importedModules);
+ else {
+ // Simple module import
+ auto importedEntity = mlir::LLVM::DIImportedEntityAttr::get(
+ context, llvm::dwarf::DW_TAG_imported_module, spAttr, modAttr,
+ fileAttr, /*line=*/1, /*name=*/nullptr, /*elements*/ {});
+ importedModules.insert(importedEntity);
+ }
+
+ importedEntities.insert(importedModules.begin(), importedModules.end());
+ });
+}
+
void AddDebugInfoPass::runOnOperation() {
mlir::ModuleOp module = getOperation();
mlir::MLIRContext *context = &getContext();
@@ -704,6 +871,26 @@ void AddDebugInfoPass::runOnOperation() {
splitDwarfFile.empty() ? mlir::StringAttr()
: mlir::StringAttr::get(context, splitDwarfFile));
+ // Process module globals early.
+ // Walk through all DeclareOps in functions and process globals that are
+ // module variables. This ensures that when we process USE statements,
+ // the DIGlobalVariable lookups will succeed.
+ if (debugLevel == mlir::LLVM::DIEmissionKind::Full) {
+ module.walk([&](fir::cg::XDeclareOp declOp) {
+ mlir::Operation *defOp = declOp.getMemref().getDefiningOp();
+ if (defOp && llvm::isa<fir::AddrOfOp>(defOp)) {
+ if (auto globalOp =
+ symbolTable.lookup<fir::GlobalOp>(declOp.getUniqName())) {
+ // Only process module variables here, not SAVE variables
+ if (isModuleVariable(globalOp)) {
+ handleGlobalOp(globalOp, fileAttr, cuAttr, typeGen, &symbolTable,
+ declOp);
+ }
+ }
+ }
+ });
+ }
+
module.walk([&](mlir::func::FuncOp funcOp) {
handleFuncOp(funcOp, fileAttr, cuAttr, typeGen, &symbolTable);
});
diff --git a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp
index ed9a2ae..5bf783db 100644
--- a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp
+++ b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp
@@ -832,8 +832,8 @@ static mlir::Type getEleTy(mlir::Type ty) {
static bool isAssumedSize(llvm::SmallVectorImpl<mlir::Value> &extents) {
if (extents.empty())
return false;
- auto cstLen = fir::getIntIfConstant(extents.back());
- return cstLen.has_value() && *cstLen == -1;
+ return llvm::isa_and_nonnull<fir::AssumedSizeExtentOp>(
+ extents.back().getDefiningOp());
}
// Extract extents from the ShapeOp/ShapeShiftOp into the result vector.
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 0388439..5a3059eb 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -1,22 +1,35 @@
add_flang_library(FIRTransforms
AbstractResult.cpp
AddAliasTags.cpp
- AffinePromotion.cpp
+ AddDebugInfo.cpp
AffineDemotion.cpp
+ AffinePromotion.cpp
+ AlgebraicSimplification.cpp
AnnotateConstant.cpp
+ ArrayValueCopy.cpp
+ ArrayValueCopy.cpp
AssumedRankOpConversion.cpp
+ CUDA/CUFAddConstructor.cpp
+ CUDA/CUFAllocationConversion.cpp
+ CUDA/CUFAllocationConversion.cpp
+ CUDA/CUFComputeSharedMemoryOffsetsAndSize.cpp
+ CUDA/CUFDeviceFuncTransform.cpp
+ CUDA/CUFDeviceGlobal.cpp
+ CUDA/CUFFunctionRewrite.cpp
+ CUDA/CUFGPUToLLVMConversion.cpp
+ CUDA/CUFLaunchAttachAttr.cpp
+ CUDA/CUFOpConversion.cpp
+ CUDA/CUFOpConversionLate.cpp
+ CUDA/CUFPredefinedVarToGPU.cpp
CharacterConversion.cpp
CompilerGeneratedNames.cpp
ConstantArgumentGlobalisation.cpp
ControlFlowConverter.cpp
- CUFAddConstructor.cpp
- CUFDeviceGlobal.cpp
- CUFOpConversion.cpp
- CUFGPUToLLVMConversion.cpp
- CUFComputeSharedMemoryOffsetsAndSize.cpp
- ArrayValueCopy.cpp
+ ConvertComplexPow.cpp
+ DebugTypeGenerator.cpp
ExternalNameConversion.cpp
FIRToSCF.cpp
+ FIRToMemRef.cpp
MemoryUtils.cpp
MemoryAllocation.cpp
StackArrays.cpp
@@ -26,17 +39,23 @@ add_flang_library(FIRTransforms
SimplifyIntrinsics.cpp
AddDebugInfo.cpp
PolymorphicOpConversion.cpp
- LoopVersioning.cpp
- StackReclaim.cpp
- VScaleAttr.cpp
FunctionAttr.cpp
- DebugTypeGenerator.cpp
- SetRuntimeCallAttributes.cpp
GenRuntimeCallsForTest.cpp
- SimplifyFIROperations.cpp
- OptimizeArrayRepacking.cpp
- ConvertComplexPow.cpp
+ LoopInvariantCodeMotion.cpp
+ LoopVersioning.cpp
MIFOpConversion.cpp
+ MemRefDataFlowOpt.cpp
+ MemoryAllocation.cpp
+ MemoryUtils.cpp
+ OptimizeArrayRepacking.cpp
+ PolymorphicOpConversion.cpp
+ SetRuntimeCallAttributes.cpp
+ SimplifyFIROperations.cpp
+ SimplifyIntrinsics.cpp
+ SimplifyRegionLite.cpp
+ StackArrays.cpp
+ StackReclaim.cpp
+ VScaleAttr.cpp
DEPENDS
CUFAttrs
@@ -62,12 +81,14 @@ add_flang_library(FIRTransforms
MLIR_LIBS
MLIRAffineUtils
+ MLIRAnalysis
MLIRFuncDialect
MLIRGPUDialect
- MLIRLLVMDialect
MLIRLLVMCommonConversion
+ MLIRLLVMDialect
MLIRMathTransforms
MLIROpenACCDialect
MLIROpenACCToLLVMIRTranslation
MLIROpenMPDialect
+ MLIRTransformUtils
)
diff --git a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp
index baa8e59..baa8e59 100644
--- a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp
new file mode 100644
index 0000000..4e2bcb6
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFAllocationConversion.cpp
@@ -0,0 +1,445 @@
+//===-- CUFAllocationConversion.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Transforms/CUDA/CUFAllocationConversion.h"
+#include "flang/Optimizer/Builder/CUFCommon.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h"
+#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
+#include "flang/Optimizer/CodeGen/TypeConverter.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/Support/DataLayout.h"
+#include "flang/Runtime/CUDA/allocatable.h"
+#include "flang/Runtime/CUDA/common.h"
+#include "flang/Runtime/CUDA/descriptor.h"
+#include "flang/Runtime/CUDA/memory.h"
+#include "flang/Runtime/CUDA/pointer.h"
+#include "flang/Runtime/allocatable.h"
+#include "flang/Runtime/allocator-registry-consts.h"
+#include "flang/Support/Fortran.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace fir {
+#define GEN_PASS_DEF_CUFALLOCATIONCONVERSION
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace fir;
+using namespace mlir;
+using namespace Fortran::runtime;
+using namespace Fortran::runtime::cuda;
+
+namespace {
+
+template <typename OpTy>
+static bool isPinned(OpTy op) {
+ if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned)
+ return true;
+ return false;
+}
+
+static inline unsigned getMemType(cuf::DataAttribute attr) {
+ if (attr == cuf::DataAttribute::Device)
+ return kMemTypeDevice;
+ if (attr == cuf::DataAttribute::Managed)
+ return kMemTypeManaged;
+ if (attr == cuf::DataAttribute::Pinned)
+ return kMemTypePinned;
+ if (attr == cuf::DataAttribute::Unified)
+ return kMemTypeUnified;
+ llvm_unreachable("unsupported memory type");
+}
+
+static bool inDeviceContext(mlir::Operation *op) {
+ if (op->getParentOfType<cuf::KernelOp>())
+ return true;
+ if (auto funcOp = op->getParentOfType<mlir::gpu::GPUFuncOp>())
+ return true;
+ if (auto funcOp = op->getParentOfType<mlir::gpu::LaunchOp>())
+ return true;
+ if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) {
+ if (auto cudaProcAttr =
+ funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
+ cuf::getProcAttrName())) {
+ return cudaProcAttr.getValue() != cuf::ProcAttribute::Host &&
+ cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice;
+ }
+ }
+ return false;
+}
+
+template <typename OpTy>
+static mlir::LogicalResult convertOpToCall(OpTy op,
+ mlir::PatternRewriter &rewriter,
+ mlir::func::FuncOp func) {
+ auto mod = op->template getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Location loc = op.getLoc();
+ auto fTy = func.getFunctionType();
+
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+ mlir::Value sourceLine;
+ if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>)
+ sourceLine = fir::factory::locationToLineNo(
+ builder, loc, op.getSource() ? fTy.getInput(7) : fTy.getInput(6));
+ else
+ sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
+
+ mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true)
+ : builder.createBool(loc, false);
+ mlir::Value errmsg;
+ if (op.getErrmsg()) {
+ errmsg = op.getErrmsg();
+ } else {
+ mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType());
+ errmsg = fir::AbsentOp::create(builder, loc, boxNoneTy).getResult();
+ }
+ llvm::SmallVector<mlir::Value> args;
+ if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>) {
+ mlir::Value pinned =
+ op.getPinned()
+ ? op.getPinned()
+ : builder.createNullConstant(
+ loc, fir::ReferenceType::get(
+ mlir::IntegerType::get(op.getContext(), 1)));
+ if (op.getSource()) {
+ mlir::Value isDeviceSource = op.getDeviceSource()
+ ? builder.createBool(loc, true)
+ : builder.createBool(loc, false);
+ mlir::Value stream =
+ op.getStream() ? op.getStream()
+ : builder.createNullConstant(loc, fTy.getInput(2));
+ args = fir::runtime::createArguments(
+ builder, loc, fTy, op.getBox(), op.getSource(), stream, pinned,
+ hasStat, errmsg, sourceFile, sourceLine, isDeviceSource);
+ } else {
+ mlir::Value stream =
+ op.getStream() ? op.getStream()
+ : builder.createNullConstant(loc, fTy.getInput(1));
+ mlir::Value deviceInit =
+ (op.getDataAttrAttr() &&
+ op.getDataAttrAttr().getValue() == cuf::DataAttribute::Device)
+ ? builder.createBool(loc, true)
+ : builder.createBool(loc, false);
+ args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(),
+ stream, pinned, hasStat, errmsg,
+ sourceFile, sourceLine, deviceInit);
+ }
+ } else {
+ args =
+ fir::runtime::createArguments(builder, loc, fTy, op.getBox(), hasStat,
+ errmsg, sourceFile, sourceLine);
+ }
+ auto callOp = fir::CallOp::create(builder, loc, func, args);
+ rewriter.replaceOp(op, callOp);
+ return mlir::success();
+}
+
+struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ CUFAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl,
+ const fir::LLVMTypeConverter *typeConverter)
+ : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}
+
+ mlir::LogicalResult
+ matchAndRewrite(cuf::AllocOp op,
+ mlir::PatternRewriter &rewriter) const override {
+
+ mlir::Location loc = op.getLoc();
+
+ if (inDeviceContext(op.getOperation())) {
+      // In device context just replace the cuf.alloc operation with a
+      // fir.alloca; the corresponding cuf.free will be removed.
+ auto allocaOp =
+ fir::AllocaOp::create(rewriter, loc, op.getInType(),
+ op.getUniqName() ? *op.getUniqName() : "",
+ op.getBindcName() ? *op.getBindcName() : "",
+ op.getTypeparams(), op.getShape());
+ allocaOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
+ rewriter.replaceOp(op, allocaOp);
+ return mlir::success();
+ }
+
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+
+ if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
+ // Convert scalar and known size array allocations.
+ mlir::Value bytes;
+ fir::KindMapping kindMap{fir::getKindMapping(mod)};
+ if (fir::isa_trivial(op.getInType())) {
+ int width = cuf::computeElementByteSize(loc, op.getInType(), kindMap);
+ bytes =
+ builder.createIntegerConstant(loc, builder.getIndexType(), width);
+ } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
+ op.getInType())) {
+ std::size_t size = 0;
+ if (fir::isa_derived(seqTy.getEleTy())) {
+ mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
+ size = dl->getTypeSizeInBits(structTy) / 8;
+ } else {
+ size = cuf::computeElementByteSize(loc, seqTy.getEleTy(), kindMap);
+ }
+ mlir::Value width =
+ builder.createIntegerConstant(loc, builder.getIndexType(), size);
+ mlir::Value nbElem;
+ if (fir::sequenceWithNonConstantShape(seqTy)) {
+ assert(!op.getShape().empty() && "expect shape with dynamic arrays");
+ nbElem = builder.loadIfRef(loc, op.getShape()[0]);
+ for (unsigned i = 1; i < op.getShape().size(); ++i) {
+ nbElem = mlir::arith::MulIOp::create(
+ rewriter, loc, nbElem,
+ builder.loadIfRef(loc, op.getShape()[i]));
+ }
+ } else {
+ nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
+ seqTy.getConstantArraySize());
+ }
+ bytes = mlir::arith::MulIOp::create(rewriter, loc, nbElem, width);
+ } else if (fir::isa_derived(op.getInType())) {
+ mlir::Type structTy = typeConverter->convertType(op.getInType());
+ std::size_t structSize = dl->getTypeSizeInBits(structTy) / 8;
+ bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
+ structSize);
+ } else if (fir::isa_char(op.getInType())) {
+ mlir::Type charTy = typeConverter->convertType(op.getInType());
+ std::size_t charSize = dl->getTypeSizeInBits(charTy) / 8;
+ bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
+ charSize);
+ } else {
+ mlir::emitError(loc, "unsupported type in cuf.alloc\n");
+ }
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
+ auto fTy = func.getFunctionType();
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+ mlir::Value memTy = builder.createIntegerConstant(
+ loc, builder.getI32Type(), getMemType(op.getDataAttr()));
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
+ auto callOp = fir::CallOp::create(builder, loc, func, args);
+ callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
+ auto convOp = builder.createConvert(loc, op.getResult().getType(),
+ callOp.getResult(0));
+ rewriter.replaceOp(op, convOp);
+ return mlir::success();
+ }
+
+ // Convert descriptor allocations to function call.
+ auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDescriptor)>(loc, builder);
+ auto fTy = func.getFunctionType();
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+
+ mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy);
+ std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
+ mlir::Value sizeInBytes =
+ builder.createIntegerConstant(loc, builder.getIndexType(), boxSize);
+
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)};
+ auto callOp = fir::CallOp::create(builder, loc, func, args);
+ callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
+ auto convOp = builder.createConvert(loc, op.getResult().getType(),
+ callOp.getResult(0));
+ rewriter.replaceOp(op, convOp);
+ return mlir::success();
+ }
+
+private:
+ mlir::DataLayout *dl;
+ const fir::LLVMTypeConverter *typeConverter;
+};
+
+struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cuf::FreeOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ if (inDeviceContext(op.getOperation())) {
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+
+ if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
+ return failure();
+
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Location loc = op.getLoc();
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+
+ auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
+ if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
+ auto fTy = func.getFunctionType();
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+ mlir::Value memTy = builder.createIntegerConstant(
+ loc, builder.getI32Type(), getMemType(op.getDataAttr()));
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
+ fir::CallOp::create(builder, loc, func, args);
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+
+ // Convert cuf.free on descriptors.
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDescriptor)>(loc, builder);
+ auto fTy = func.getFunctionType();
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)};
+ auto callOp = fir::CallOp::create(builder, loc, func, args);
+ callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+};
+
+struct CUFAllocateOpConversion
+ : public mlir::OpRewritePattern<cuf::AllocateOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cuf::AllocateOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Location loc = op.getLoc();
+
+ bool isPointer = op.getPointer();
+ if (op.getHasDoubleDescriptor()) {
+      // Allocation for module variables is done with a custom runtime entry
+      // point so the descriptors can be synchronized.
+ mlir::func::FuncOp func;
+ if (op.getSource()) {
+ func = isPointer ? fir::runtime::getRuntimeFunc<mkRTKey(
+ CUFPointerAllocateSourceSync)>(loc, builder)
+ : fir::runtime::getRuntimeFunc<mkRTKey(
+ CUFAllocatableAllocateSourceSync)>(loc, builder);
+ } else {
+ func =
+ isPointer
+ ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSync)>(
+ loc, builder)
+ : fir::runtime::getRuntimeFunc<mkRTKey(
+ CUFAllocatableAllocateSync)>(loc, builder);
+ }
+ return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
+ }
+
+ mlir::func::FuncOp func;
+ if (op.getSource()) {
+ func =
+ isPointer
+ ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSource)>(
+ loc, builder)
+ : fir::runtime::getRuntimeFunc<mkRTKey(
+ CUFAllocatableAllocateSource)>(loc, builder);
+ } else {
+ func =
+ isPointer
+ ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocate)>(
+ loc, builder)
+ : fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>(
+ loc, builder);
+ }
+
+ return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
+ }
+};
+
+struct CUFDeallocateOpConversion
+ : public mlir::OpRewritePattern<cuf::DeallocateOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cuf::DeallocateOp op,
+ mlir::PatternRewriter &rewriter) const override {
+
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Location loc = op.getLoc();
+
+ if (op.getHasDoubleDescriptor()) {
+      // Deallocation for module variables is done with a custom runtime
+      // entry point so the descriptors can be synchronized.
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableDeallocate)>(
+ loc, builder);
+ return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
+ }
+
+ // Deallocation for local descriptor falls back on the standard runtime
+ // AllocatableDeallocate as the dedicated deallocator is set in the
+ // descriptor before the call.
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc,
+ builder);
+ return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
+ }
+};
+
+class CUFAllocationConversion
+ : public fir::impl::CUFAllocationConversionBase<CUFAllocationConversion> {
+public:
+ void runOnOperation() override {
+ auto *ctx = &getContext();
+ mlir::RewritePatternSet patterns(ctx);
+ mlir::ConversionTarget target(*ctx);
+
+ mlir::Operation *op = getOperation();
+ mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
+ if (!module)
+ return signalPassFailure();
+ mlir::SymbolTable symtab(module);
+
+ std::optional<mlir::DataLayout> dl = fir::support::getOrSetMLIRDataLayout(
+ module, /*allowDefaultLayout=*/false);
+ fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
+ /*forceUnifiedTBAATree=*/false, *dl);
+ target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+ mlir::gpu::GPUDialect>();
+ target.addLegalOp<cuf::StreamCastOp>();
+ cuf::populateCUFAllocationConversionPatterns(typeConverter, *dl, symtab,
+ patterns);
+ if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
+ std::move(patterns)))) {
+ mlir::emitError(mlir::UnknownLoc::get(ctx),
+ "error in CUF allocation conversion\n");
+ signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+void cuf::populateCUFAllocationConversionPatterns(
+ const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
+ const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
+ patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
+ patterns.insert<CUFFreeOpConversion, CUFAllocateOpConversion,
+ CUFDeallocateOpConversion>(patterns.getContext());
+}
diff --git a/flang/lib/Optimizer/Transforms/CUFComputeSharedMemoryOffsetsAndSize.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFComputeSharedMemoryOffsetsAndSize.cpp
index 09126e0..87dc27e 100644
--- a/flang/lib/Optimizer/Transforms/CUFComputeSharedMemoryOffsetsAndSize.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFComputeSharedMemoryOffsetsAndSize.cpp
@@ -41,12 +41,41 @@ namespace {
static bool isAssumedSize(mlir::ValueRange shape) {
if (shape.size() != 1)
return false;
- std::optional<std::int64_t> val = fir::getIntIfConstant(shape[0]);
- if (val && *val == -1)
+ if (llvm::isa_and_nonnull<fir::AssumedSizeExtentOp>(shape[0].getDefiningOp()))
return true;
return false;
}
+static void createSharedMemoryGlobal(fir::FirOpBuilder &builder,
+ mlir::Location loc, llvm::StringRef prefix,
+ llvm::StringRef suffix,
+ mlir::gpu::GPUModuleOp gpuMod,
+ mlir::Type sharedMemType, unsigned size,
+ unsigned align, bool isDynamic) {
+ std::string sharedMemGlobalName =
+ isDynamic ? (prefix + llvm::Twine(cudaSharedMemSuffix)).str()
+ : (prefix + llvm::Twine(cudaSharedMemSuffix) + suffix).str();
+
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToEnd(gpuMod.getBody());
+
+ mlir::StringAttr linkage = isDynamic ? builder.createExternalLinkage()
+ : builder.createInternalLinkage();
+ llvm::SmallVector<mlir::NamedAttribute> attrs;
+ auto globalOpName = mlir::OperationName(fir::GlobalOp::getOperationName(),
+ gpuMod.getContext());
+ attrs.push_back(mlir::NamedAttribute(
+ fir::GlobalOp::getDataAttrAttrName(globalOpName),
+ cuf::DataAttributeAttr::get(gpuMod.getContext(),
+ cuf::DataAttribute::Shared)));
+
+ mlir::DenseElementsAttr init = {};
+ auto sharedMem =
+ fir::GlobalOp::create(builder, loc, sharedMemGlobalName, false, false,
+ sharedMemType, init, linkage, attrs);
+ sharedMem.setAlignment(align);
+}
+
struct CUFComputeSharedMemoryOffsetsAndSize
: public fir::impl::CUFComputeSharedMemoryOffsetsAndSizeBase<
CUFComputeSharedMemoryOffsetsAndSize> {
@@ -109,18 +138,23 @@ struct CUFComputeSharedMemoryOffsetsAndSize
crtDynOffset, dynSize);
else
crtDynOffset = dynSize;
-
- continue;
+ } else {
+ // Static shared memory.
+ auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash(
+ loc, sharedOp.getInType(), *dl, kindMap);
+ createSharedMemoryGlobal(
+ builder, sharedOp.getLoc(), funcOp.getName(),
+ *sharedOp.getBindcName(), gpuMod,
+ fir::SequenceType::get(size, i8Ty), size,
+ sharedOp.getAlignment() ? *sharedOp.getAlignment() : align,
+ /*isDynamic=*/false);
+ mlir::Value zero = builder.createIntegerConstant(loc, i32Ty, 0);
+ sharedOp.getOffsetMutable().assign(zero);
+ if (!sharedOp.getAlignment())
+ sharedOp.setAlignment(align);
+ sharedOp.setIsStatic(true);
+ ++nbStaticSharedVariables;
}
- auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash(
- sharedOp.getLoc(), sharedOp.getInType(), *dl, kindMap);
- ++nbStaticSharedVariables;
- mlir::Value offset = builder.createIntegerConstant(
- loc, i32Ty, llvm::alignTo(sharedMemSize, align));
- sharedOp.getOffsetMutable().assign(offset);
- sharedMemSize =
- llvm::alignTo(sharedMemSize, align) + llvm::alignTo(size, align);
- alignment = std::max(alignment, align);
}
if (nbDynamicSharedVariables == 0 && nbStaticSharedVariables == 0)
@@ -131,35 +165,13 @@ struct CUFComputeSharedMemoryOffsetsAndSize
funcOp.getLoc(),
"static and dynamic shared variables in a single kernel");
- mlir::DenseElementsAttr init = {};
- if (sharedMemSize > 0) {
- auto vecTy = mlir::VectorType::get(sharedMemSize, i8Ty);
- mlir::Attribute zero = mlir::IntegerAttr::get(i8Ty, 0);
- init = mlir::DenseElementsAttr::get(vecTy, llvm::ArrayRef(zero));
- }
+ if (nbStaticSharedVariables > 0)
+ continue;
- // Create the shared memory global where each shared variable will point
- // to.
auto sharedMemType = fir::SequenceType::get(sharedMemSize, i8Ty);
- std::string sharedMemGlobalName =
- (funcOp.getName() + llvm::Twine(cudaSharedMemSuffix)).str();
- // Dynamic shared memory needs an external linkage while static shared
- // memory needs an internal linkage.
- mlir::StringAttr linkage = nbDynamicSharedVariables > 0
- ? builder.createExternalLinkage()
- : builder.createInternalLinkage();
- builder.setInsertionPointToEnd(gpuMod.getBody());
- llvm::SmallVector<mlir::NamedAttribute> attrs;
- auto globalOpName = mlir::OperationName(fir::GlobalOp::getOperationName(),
- gpuMod.getContext());
- attrs.push_back(mlir::NamedAttribute(
- fir::GlobalOp::getDataAttrAttrName(globalOpName),
- cuf::DataAttributeAttr::get(gpuMod.getContext(),
- cuf::DataAttribute::Shared)));
- auto sharedMem = fir::GlobalOp::create(
- builder, funcOp.getLoc(), sharedMemGlobalName, false, false,
- sharedMemType, init, linkage, attrs);
- sharedMem.setAlignment(alignment);
+ createSharedMemoryGlobal(builder, funcOp.getLoc(), funcOp.getName(), "",
+ gpuMod, sharedMemType, sharedMemSize, alignment,
+ /*isDynamic=*/true);
}
}
};
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFDeviceFuncTransform.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFDeviceFuncTransform.cpp
new file mode 100644
index 0000000..4532af9
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFDeviceFuncTransform.cpp
@@ -0,0 +1,250 @@
+//===-- CUFDeviceFuncTransform.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/CUFCommon.h"
+#include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/Dialect/FIRAttr.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "flang/Optimizer/Support/InternalNames.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringSet.h"
+
+namespace fir {
+#define GEN_PASS_DEF_CUFDEVICEFUNCTRANSFORM
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace mlir;
+
+namespace {
+
+class CUFDeviceFuncTransform
+ : public fir::impl::CUFDeviceFuncTransformBase<CUFDeviceFuncTransform> {
+ using CUFDeviceFuncTransformBase<
+ CUFDeviceFuncTransform>::CUFDeviceFuncTransformBase;
+
+ static gpu::GPUFuncOp createGPUFuncOp(mlir::func::FuncOp funcOp,
+ bool isGlobal, int computeCap) {
+ mlir::OpBuilder builder(funcOp.getContext());
+
+ mlir::Region &funcOpBody = funcOp.getBody();
+ SetVector<Value> operands;
+ for (mlir::Value operand : funcOp.getArguments())
+ operands.insert(operand);
+
+ llvm::SmallVector<mlir::Type> funcOperandTypes;
+ llvm::SmallVector<mlir::Type> funcResultTypes;
+ funcOperandTypes.reserve(funcOp.getArgumentTypes().size());
+ funcResultTypes.reserve(funcOp.getResultTypes().size());
+ for (mlir::Type opTy : funcOp.getArgumentTypes())
+ funcOperandTypes.push_back(opTy);
+ for (mlir::Type resTy : funcOp.getResultTypes())
+ funcResultTypes.push_back(resTy);
+
+ mlir::Location loc = funcOp.getLoc();
+
+ mlir::FunctionType type = mlir::FunctionType::get(
+ funcOp.getContext(), funcOperandTypes, funcResultTypes);
+
+ auto deviceFuncOp =
+ gpu::GPUFuncOp::create(builder, loc, funcOp.getName(), type,
+ mlir::TypeRange{}, mlir::TypeRange{});
+ if (isGlobal)
+ deviceFuncOp->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
+ builder.getUnitAttr());
+
+ mlir::Region &deviceFuncBody = deviceFuncOp.getBody();
+ mlir::Block &entryBlock = deviceFuncBody.front();
+
+ mlir::IRMapping map;
+ for (const auto &operand : enumerate(operands))
+ map.map(operand.value(), entryBlock.getArgument(operand.index()));
+
+ funcOpBody.cloneInto(&deviceFuncBody, map);
+
+ deviceFuncOp.walk([](func::ReturnOp op) {
+ mlir::OpBuilder replacer(op);
+ gpu::ReturnOp gpuReturnOp = gpu::ReturnOp::create(replacer, op.getLoc());
+ gpuReturnOp->setOperands(op.getOperands());
+ op.erase();
+ });
+
+ mlir::Block &funcOpEntry = funcOp.front();
+ mlir::Block *clonedFuncOpEntry = map.lookup(&funcOpEntry);
+
+ entryBlock.getOperations().splice(entryBlock.getOperations().end(),
+ clonedFuncOpEntry->getOperations());
+ clonedFuncOpEntry->erase();
+
+ auto launchBoundsAttr =
+ funcOp.getOperation()->getAttrOfType<cuf::LaunchBoundsAttr>(
+ cuf::getLaunchBoundsAttrName());
+ if (launchBoundsAttr) {
+ auto maxTPB = launchBoundsAttr.getMaxTPB().getInt();
+ auto maxntid =
+ builder.getDenseI32ArrayAttr({static_cast<int32_t>(maxTPB), 1, 1});
+ deviceFuncOp->setAttr(NVVM::NVVMDialect::getMaxntidAttrName(), maxntid);
+ deviceFuncOp->setAttr(NVVM::NVVMDialect::getMinctasmAttrName(),
+ launchBoundsAttr.getMinBPM());
+ if (computeCap >= 90 && launchBoundsAttr.getUpperBoundClusterSize())
+ deviceFuncOp->setAttr(NVVM::NVVMDialect::getClusterMaxBlocksAttrName(),
+ launchBoundsAttr.getUpperBoundClusterSize());
+ }
+
+ return deviceFuncOp;
+ }
+
+ static void createHostStub(mlir::func::FuncOp funcOp,
+ mlir::SymbolTable &symTab, mlir::ModuleOp mod) {
+ mlir::Location loc = funcOp.getLoc();
+ mlir::OpBuilder modBuilder(mod.getBodyRegion());
+ modBuilder.setInsertionPointToEnd(mod.getBody());
+ auto emptyStub = func::FuncOp::create(modBuilder, loc, funcOp.getName(),
+ funcOp.getFunctionType());
+ emptyStub.setVisibility(funcOp.getVisibility());
+ emptyStub->setAttrs(funcOp->getAttrs());
+ auto entryBlock = emptyStub.addEntryBlock();
+ modBuilder.setInsertionPointToEnd(entryBlock);
+ func::ReturnOp::create(modBuilder, loc);
+
+ symTab.erase(funcOp);
+ symTab.insert(emptyStub);
+ }
+
+ static bool isDeviceFunc(mlir::func::FuncOp funcOp) {
+ if (auto cudaProcAttr =
+ funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
+ cuf::getProcAttrName()))
+ if (cudaProcAttr.getValue() == cuf::ProcAttribute::Device ||
+ cudaProcAttr.getValue() == cuf::ProcAttribute::Global ||
+ cudaProcAttr.getValue() == cuf::ProcAttribute::GridGlobal ||
+ cudaProcAttr.getValue() == cuf::ProcAttribute::HostDevice)
+ return true;
+ return false;
+ }
+
+ void runOnOperation() override {
+ // Working on Module operation because inserting/removing function from the
+ // module is not thread-safe.
+ ModuleOp mod = getOperation();
+ mlir::SymbolTable symbolTable(getOperation());
+
+ auto *ctx = getOperation().getContext();
+ mlir::OpBuilder builder(ctx);
+
+ gpu::GPUModuleOp gpuMod = cuf::getOrCreateGPUModule(mod, symbolTable);
+ mlir::SymbolTable gpuModSymTab(gpuMod);
+
+ llvm::SetVector<mlir::func::FuncOp> funcsToClone;
+ llvm::SetVector<mlir::func::FuncOp> deviceFuncs;
+ llvm::SetVector<mlir::func::FuncOp> keepInModule;
+ llvm::StringSet<> deviceFuncNames;
+
+ // Look for all function to migrate to the GPU module.
+ mod.walk([&](mlir::func::FuncOp op) {
+ if (isDeviceFunc(op)) {
+ deviceFuncs.insert(op);
+ deviceFuncNames.insert(op.getSymName());
+ }
+ });
+
+ auto processCallOp = [&](fir::CallOp op) {
+ if (op.getCallee()) {
+ auto func = symbolTable.lookup<mlir::func::FuncOp>(
+ op.getCallee()->getLeafReference());
+ if (deviceFuncs.count(func) == 0)
+ funcsToClone.insert(func);
+ }
+ };
+
+ // Gather all function called by device functions.
+ for (auto funcOp : deviceFuncs) {
+ funcOp.walk([&](fir::CallOp op) { processCallOp(op); });
+ funcOp.walk([&](fir::DispatchOp op) {
+ TODO(op.getLoc(), "type-bound procedure call with dynamic dispatch "
+ "in device procedure");
+ });
+ }
+
+ // Functions that are referenced in a derived-type binding table must be
+ // kept in the host module to avoid LLVM dialect verification errors.
+ for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
+ if (globalOp.getName().contains(fir::kBindingTableSeparator)) {
+ globalOp.walk([&](fir::AddrOfOp addrOfOp) {
+ if (deviceFuncNames.contains(addrOfOp.getSymbol().getLeafReference()))
+ keepInModule.insert(
+ *llvm::find_if(deviceFuncs, [&](mlir::func::FuncOp f) {
+ return f.getSymName() ==
+ addrOfOp.getSymbol().getLeafReference();
+ }));
+ });
+ }
+ }
+
+ // Gather all functions called by CUF kernels.
+ mod.walk([&](cuf::KernelOp kernelOp) {
+ kernelOp.walk([&](fir::CallOp op) { processCallOp(op); });
+ kernelOp.walk([&](fir::DispatchOp op) {
+ TODO(op.getLoc(),
+ "type-bound procedure call with dynamic dispatch in cuf kernel");
+ });
+ });
+
+ for (auto funcOp : funcsToClone)
+ gpuModSymTab.insert(funcOp->clone());
+
+ for (auto funcOp : deviceFuncs) {
+ auto cudaProcAttr =
+ funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
+ cuf::getProcAttrName());
+ auto isGlobal = cudaProcAttr.getValue() == cuf::ProcAttribute::Global ||
+ cudaProcAttr.getValue() == cuf::ProcAttribute::GridGlobal;
+ if (funcOp.isDeclaration()) {
+ mlir::Operation *clonedFuncOp = funcOp->clone();
+ if (isGlobal) {
+ clonedFuncOp->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
+ builder.getUnitAttr());
+ clonedFuncOp->removeAttr(cuf::getProcAttrName());
+ if (auto funcOp = mlir::dyn_cast<func::FuncOp>(clonedFuncOp))
+ funcOp.setNested();
+ }
+ gpuModSymTab.insert(clonedFuncOp);
+ } else {
+ gpu::GPUFuncOp deviceFuncOp =
+ createGPUFuncOp(funcOp, isGlobal, computeCap);
+ gpuModSymTab.insert(deviceFuncOp);
+
+ if (cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice) {
+ // If the function is a global, we need to keep the host side
+ // declaration for the kernel registration. Currently we just
+ // erase its body but in the future, the body should be rewritten
+ // to be able to launch CUDA Fortran kernel from C code.
+ if (isGlobal || keepInModule.contains(funcOp))
+ createHostStub(funcOp, symbolTable, mod);
+ else
+ funcOp.erase();
+ }
+ }
+ }
+ }
+};
+
+} // end anonymous namespace
diff --git a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFDeviceGlobal.cpp
index 35badb6..35badb6 100644
--- a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFDeviceGlobal.cpp
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp
new file mode 100644
index 0000000..bcbfb529
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp
@@ -0,0 +1,103 @@
+//===-- CUFFunctionRewrite.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/CodeGen/TypeConverter.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Support/DataLayout.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
+#include <string_view>
+
+#define DEBUG_TYPE "flang-cuf-function-rewrite"
+
+namespace fir {
+#define GEN_PASS_DEF_CUFFUNCTIONREWRITE
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace mlir;
+
+namespace {
+
+using genFunctionType =
+ std::function<mlir::Value(mlir::PatternRewriter &, fir::CallOp op)>;
+
+class CallConversion : public OpRewritePattern<fir::CallOp> {
+public:
+ CallConversion(MLIRContext *context)
+ : OpRewritePattern<fir::CallOp>(context) {}
+
+ LogicalResult
+ matchAndRewrite(fir::CallOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ auto callee = op.getCallee();
+ if (!callee)
+ return failure();
+ auto name = callee->getRootReference().getValue();
+
+ if (genMappings_.contains(name)) {
+ auto fct = genMappings_.find(name);
+ mlir::Value result = fct->second(rewriter, op);
+ if (result)
+ rewriter.replaceOp(op, result);
+ else
+ rewriter.eraseOp(op);
+ return success();
+ }
+ return failure();
+ }
+
+private:
+ static mlir::Value genOnDevice(mlir::PatternRewriter &rewriter,
+ fir::CallOp op) {
+ assert(op.getArgs().size() == 0 && "expect 0 arguments");
+ mlir::Location loc = op.getLoc();
+ unsigned inGPUMod = op->getParentOfType<gpu::GPUModuleOp>() ? 1 : 0;
+ mlir::Type i1Ty = rewriter.getIntegerType(1);
+ mlir::Value t = mlir::arith::ConstantOp::create(
+ rewriter, loc, i1Ty, rewriter.getIntegerAttr(i1Ty, inGPUMod));
+ return fir::ConvertOp::create(rewriter, loc, op.getResult(0).getType(), t);
+ }
+
+ const llvm::StringMap<genFunctionType> genMappings_ = {
+ {"on_device", &genOnDevice}};
+};
+
+class CUFFunctionRewrite
+ : public fir::impl::CUFFunctionRewriteBase<CUFFunctionRewrite> {
+public:
+ void runOnOperation() override {
+ auto *ctx = &getContext();
+ mlir::RewritePatternSet patterns(ctx);
+
+ patterns.insert<CallConversion>(patterns.getContext());
+
+ if (mlir::failed(
+ mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+ mlir::emitError(mlir::UnknownLoc::get(ctx),
+ "error in CUFFunctionRewrite op conversion\n");
+ signalPassFailure();
+ }
+ }
+};
+
+} // namespace
diff --git a/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFGPUToLLVMConversion.cpp
index 40f180a..d5a8212 100644
--- a/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFGPUToLLVMConversion.cpp
@@ -249,8 +249,13 @@ struct CUFSharedMemoryOpConversion
"cuf.shared_memory must have an offset for code gen");
auto gpuMod = op->getParentOfType<gpu::GPUModuleOp>();
+
std::string sharedGlobalName =
- (getFuncName(op) + llvm::Twine(cudaSharedMemSuffix)).str();
+ op.getIsStatic()
+ ? (getFuncName(op) + llvm::Twine(cudaSharedMemSuffix) +
+ *op.getBindcName())
+ .str()
+ : (getFuncName(op) + llvm::Twine(cudaSharedMemSuffix)).str();
mlir::Value sharedGlobalAddr =
createAddressOfOp(rewriter, loc, gpuMod, sharedGlobalName);
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFLaunchAttachAttr.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFLaunchAttachAttr.cpp
new file mode 100644
index 0000000..41a0e5c
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFLaunchAttachAttr.cpp
@@ -0,0 +1,70 @@
+//===-- CUFLaunchAttachAttr.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace fir {
+#define GEN_PASS_DEF_CUFLAUNCHATTACHATTR
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace mlir;
+
+namespace {
+
+static constexpr llvm::StringRef cudaKernelInfix = "_cufk_";
+
+class CUFGPUAttachAttrPattern
+ : public OpRewritePattern<mlir::gpu::LaunchFuncOp> {
+ using OpRewritePattern<mlir::gpu::LaunchFuncOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(mlir::gpu::LaunchFuncOp op,
+ PatternRewriter &rewriter) const override {
+ op->setAttr(cuf::getProcAttrName(),
+ cuf::ProcAttributeAttr::get(op.getContext(),
+ cuf::ProcAttribute::Global));
+ return mlir::success();
+ }
+};
+
+struct CUFLaunchAttachAttr
+ : public fir::impl::CUFLaunchAttachAttrBase<CUFLaunchAttachAttr> {
+
+ void runOnOperation() override {
+ auto *context = &this->getContext();
+
+ mlir::RewritePatternSet patterns(context);
+ patterns.add<CUFGPUAttachAttrPattern>(context);
+
+ mlir::ConversionTarget target(*context);
+ target.addIllegalOp<mlir::gpu::LaunchFuncOp>();
+ target.addDynamicallyLegalOp<mlir::gpu::LaunchFuncOp>(
+ [&](mlir::gpu::LaunchFuncOp op) -> bool {
+ if (op.getKernelName().getValue().contains(cudaKernelInfix)) {
+ if (op.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
+ cuf::getProcAttrName()))
+ return true;
+ return false;
+ }
+ return true;
+ });
+
+ if (mlir::failed(mlir::applyPartialConversion(this->getOperation(), target,
+ std::move(patterns)))) {
+ mlir::emitError(mlir::UnknownLoc::get(context),
+ "Pattern conversion failed\n");
+ this->signalPassFailure();
+ }
+ }
+};
+
+} // end anonymous namespace
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
index 759e3a65d..ddae324 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
@@ -1,4 +1,4 @@
-//===-- CUFDeviceGlobal.cpp -----------------------------------------------===//
+//===-- CUFOpConversion.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -16,6 +16,7 @@
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Support/DataLayout.h"
+#include "flang/Optimizer/Transforms/Passes.h"
#include "flang/Runtime/CUDA/allocatable.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Runtime/CUDA/descriptor.h"
@@ -27,6 +28,7 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -44,213 +46,12 @@ using namespace Fortran::runtime::cuda;
namespace {
-static inline unsigned getMemType(cuf::DataAttribute attr) {
- if (attr == cuf::DataAttribute::Device)
- return kMemTypeDevice;
- if (attr == cuf::DataAttribute::Managed)
- return kMemTypeManaged;
- if (attr == cuf::DataAttribute::Unified)
- return kMemTypeUnified;
- if (attr == cuf::DataAttribute::Pinned)
- return kMemTypePinned;
- llvm::report_fatal_error("unsupported memory type");
-}
-
-template <typename OpTy>
-static bool isPinned(OpTy op) {
- if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned)
- return true;
- return false;
-}
-
-template <typename OpTy>
-static bool hasDoubleDescriptors(OpTy op) {
- if (auto declareOp =
- mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp())) {
- if (mlir::isa_and_nonnull<fir::AddrOfOp>(
- declareOp.getMemref().getDefiningOp())) {
- if (isPinned(declareOp))
- return false;
- return true;
- }
- } else if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
- op.getBox().getDefiningOp())) {
- if (mlir::isa_and_nonnull<fir::AddrOfOp>(
- declareOp.getMemref().getDefiningOp())) {
- if (isPinned(declareOp))
- return false;
- return true;
- }
- }
- return false;
-}
-
-static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
- mlir::Location loc, mlir::Type toTy,
- mlir::Value val) {
- if (val.getType() != toTy)
- return fir::ConvertOp::create(rewriter, loc, toTy, val);
- return val;
-}
-
-template <typename OpTy>
-static mlir::LogicalResult convertOpToCall(OpTy op,
- mlir::PatternRewriter &rewriter,
- mlir::func::FuncOp func) {
- auto mod = op->template getParentOfType<mlir::ModuleOp>();
- fir::FirOpBuilder builder(rewriter, mod);
- mlir::Location loc = op.getLoc();
- auto fTy = func.getFunctionType();
-
- mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
- mlir::Value sourceLine;
- if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>)
- sourceLine = fir::factory::locationToLineNo(
- builder, loc, op.getSource() ? fTy.getInput(7) : fTy.getInput(6));
- else
- sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
-
- mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true)
- : builder.createBool(loc, false);
-
- mlir::Value errmsg;
- if (op.getErrmsg()) {
- errmsg = op.getErrmsg();
- } else {
- mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType());
- errmsg = fir::AbsentOp::create(builder, loc, boxNoneTy).getResult();
- }
- llvm::SmallVector<mlir::Value> args;
- if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>) {
- mlir::Value pinned =
- op.getPinned()
- ? op.getPinned()
- : builder.createNullConstant(
- loc, fir::ReferenceType::get(
- mlir::IntegerType::get(op.getContext(), 1)));
- if (op.getSource()) {
- mlir::Value stream =
- op.getStream() ? op.getStream()
- : builder.createNullConstant(loc, fTy.getInput(2));
- args = fir::runtime::createArguments(
- builder, loc, fTy, op.getBox(), op.getSource(), stream, pinned,
- hasStat, errmsg, sourceFile, sourceLine);
- } else {
- mlir::Value stream =
- op.getStream() ? op.getStream()
- : builder.createNullConstant(loc, fTy.getInput(1));
- args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(),
- stream, pinned, hasStat, errmsg,
- sourceFile, sourceLine);
- }
- } else {
- args =
- fir::runtime::createArguments(builder, loc, fTy, op.getBox(), hasStat,
- errmsg, sourceFile, sourceLine);
- }
- auto callOp = fir::CallOp::create(builder, loc, func, args);
- rewriter.replaceOp(op, callOp);
- return mlir::success();
-}
-
-struct CUFAllocateOpConversion
- : public mlir::OpRewritePattern<cuf::AllocateOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult
- matchAndRewrite(cuf::AllocateOp op,
- mlir::PatternRewriter &rewriter) const override {
- auto mod = op->getParentOfType<mlir::ModuleOp>();
- fir::FirOpBuilder builder(rewriter, mod);
- mlir::Location loc = op.getLoc();
-
- bool isPointer = false;
-
- if (auto declareOp =
- mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp()))
- if (declareOp.getFortranAttrs() &&
- bitEnumContainsAny(*declareOp.getFortranAttrs(),
- fir::FortranVariableFlagsEnum::pointer))
- isPointer = true;
-
- if (hasDoubleDescriptors(op)) {
- // Allocation for module variable are done with custom runtime entry point
- // so the descriptors can be synchronized.
- mlir::func::FuncOp func;
- if (op.getSource()) {
- func = isPointer ? fir::runtime::getRuntimeFunc<mkRTKey(
- CUFPointerAllocateSourceSync)>(loc, builder)
- : fir::runtime::getRuntimeFunc<mkRTKey(
- CUFAllocatableAllocateSourceSync)>(loc, builder);
- } else {
- func =
- isPointer
- ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSync)>(
- loc, builder)
- : fir::runtime::getRuntimeFunc<mkRTKey(
- CUFAllocatableAllocateSync)>(loc, builder);
- }
- return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
- }
-
- mlir::func::FuncOp func;
- if (op.getSource()) {
- func =
- isPointer
- ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSource)>(
- loc, builder)
- : fir::runtime::getRuntimeFunc<mkRTKey(
- CUFAllocatableAllocateSource)>(loc, builder);
- } else {
- func =
- isPointer
- ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocate)>(
- loc, builder)
- : fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>(
- loc, builder);
- }
-
- return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
- }
-};
-
-struct CUFDeallocateOpConversion
- : public mlir::OpRewritePattern<cuf::DeallocateOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult
- matchAndRewrite(cuf::DeallocateOp op,
- mlir::PatternRewriter &rewriter) const override {
-
- auto mod = op->getParentOfType<mlir::ModuleOp>();
- fir::FirOpBuilder builder(rewriter, mod);
- mlir::Location loc = op.getLoc();
-
- if (hasDoubleDescriptors(op)) {
- // Deallocation for module variable are done with custom runtime entry
- // point so the descriptors can be synchronized.
- mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableDeallocate)>(
- loc, builder);
- return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
- }
-
- // Deallocation for local descriptor falls back on the standard runtime
- // AllocatableDeallocate as the dedicated deallocator is set in the
- // descriptor before the call.
- mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc,
- builder);
- return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
- }
-};
-
static bool inDeviceContext(mlir::Operation *op) {
if (op->getParentOfType<cuf::KernelOp>())
return true;
- if (auto funcOp = op->getParentOfType<mlir::gpu::GPUFuncOp>())
+ if (op->getParentOfType<mlir::acc::OffloadRegionOpInterface>())
return true;
- if (auto funcOp = op->getParentOfType<mlir::gpu::LaunchOp>())
+ if (auto funcOp = op->getParentOfType<mlir::gpu::GPUFuncOp>())
return true;
if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) {
if (auto cudaProcAttr =
@@ -263,187 +64,14 @@ static bool inDeviceContext(mlir::Operation *op) {
return false;
}
-static int computeWidth(mlir::Location loc, mlir::Type type,
- fir::KindMapping &kindMap) {
- auto eleTy = fir::unwrapSequenceType(type);
- if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
- return t.getWidth() / 8;
- if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
- return t.getWidth() / 8;
- if (eleTy.isInteger(1))
- return 1;
- if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
- return kindMap.getLogicalBitsize(t.getFKind()) / 8;
- if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
- int elemSize =
- mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
- return 2 * elemSize;
- }
- if (auto t{mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)})
- return kindMap.getCharacterBitsize(t.getFKind()) / 8;
- mlir::emitError(loc, "unsupported type");
- return 0;
+static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
+ mlir::Location loc, mlir::Type toTy,
+ mlir::Value val) {
+ if (val.getType() != toTy)
+ return fir::ConvertOp::create(rewriter, loc, toTy, val);
+ return val;
}
-struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
- using OpRewritePattern::OpRewritePattern;
-
- CUFAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl,
- const fir::LLVMTypeConverter *typeConverter)
- : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}
-
- mlir::LogicalResult
- matchAndRewrite(cuf::AllocOp op,
- mlir::PatternRewriter &rewriter) const override {
-
- mlir::Location loc = op.getLoc();
-
- if (inDeviceContext(op.getOperation())) {
- // In device context just replace the cuf.alloc operation with a fir.alloc
- // the cuf.free will be removed.
- auto allocaOp =
- fir::AllocaOp::create(rewriter, loc, op.getInType(),
- op.getUniqName() ? *op.getUniqName() : "",
- op.getBindcName() ? *op.getBindcName() : "",
- op.getTypeparams(), op.getShape());
- allocaOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
- rewriter.replaceOp(op, allocaOp);
- return mlir::success();
- }
-
- auto mod = op->getParentOfType<mlir::ModuleOp>();
- fir::FirOpBuilder builder(rewriter, mod);
- mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
-
- if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
- // Convert scalar and known size array allocations.
- mlir::Value bytes;
- fir::KindMapping kindMap{fir::getKindMapping(mod)};
- if (fir::isa_trivial(op.getInType())) {
- int width = computeWidth(loc, op.getInType(), kindMap);
- bytes =
- builder.createIntegerConstant(loc, builder.getIndexType(), width);
- } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
- op.getInType())) {
- std::size_t size = 0;
- if (fir::isa_derived(seqTy.getEleTy())) {
- mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
- size = dl->getTypeSizeInBits(structTy) / 8;
- } else {
- size = computeWidth(loc, seqTy.getEleTy(), kindMap);
- }
- mlir::Value width =
- builder.createIntegerConstant(loc, builder.getIndexType(), size);
- mlir::Value nbElem;
- if (fir::sequenceWithNonConstantShape(seqTy)) {
- assert(!op.getShape().empty() && "expect shape with dynamic arrays");
- nbElem = builder.loadIfRef(loc, op.getShape()[0]);
- for (unsigned i = 1; i < op.getShape().size(); ++i) {
- nbElem = mlir::arith::MulIOp::create(
- rewriter, loc, nbElem,
- builder.loadIfRef(loc, op.getShape()[i]));
- }
- } else {
- nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
- seqTy.getConstantArraySize());
- }
- bytes = mlir::arith::MulIOp::create(rewriter, loc, nbElem, width);
- } else if (fir::isa_derived(op.getInType())) {
- mlir::Type structTy = typeConverter->convertType(op.getInType());
- std::size_t structSize = dl->getTypeSizeInBits(structTy) / 8;
- bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
- structSize);
- } else {
- mlir::emitError(loc, "unsupported type in cuf.alloc\n");
- }
- mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
- auto fTy = func.getFunctionType();
- mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
- mlir::Value memTy = builder.createIntegerConstant(
- loc, builder.getI32Type(), getMemType(op.getDataAttr()));
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
- auto callOp = fir::CallOp::create(builder, loc, func, args);
- callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
- auto convOp = builder.createConvert(loc, op.getResult().getType(),
- callOp.getResult(0));
- rewriter.replaceOp(op, convOp);
- return mlir::success();
- }
-
- // Convert descriptor allocations to function call.
- auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
- mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDescriptor)>(loc, builder);
- auto fTy = func.getFunctionType();
- mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
-
- mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy);
- std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
- mlir::Value sizeInBytes =
- builder.createIntegerConstant(loc, builder.getIndexType(), boxSize);
-
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)};
- auto callOp = fir::CallOp::create(builder, loc, func, args);
- callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
- auto convOp = builder.createConvert(loc, op.getResult().getType(),
- callOp.getResult(0));
- rewriter.replaceOp(op, convOp);
- return mlir::success();
- }
-
-private:
- mlir::DataLayout *dl;
- const fir::LLVMTypeConverter *typeConverter;
-};
-
-struct CUFDeviceAddressOpConversion
- : public mlir::OpRewritePattern<cuf::DeviceAddressOp> {
- using OpRewritePattern::OpRewritePattern;
-
- CUFDeviceAddressOpConversion(mlir::MLIRContext *context,
- const mlir::SymbolTable &symtab)
- : OpRewritePattern(context), symTab{symtab} {}
-
- mlir::LogicalResult
- matchAndRewrite(cuf::DeviceAddressOp op,
- mlir::PatternRewriter &rewriter) const override {
- if (auto global = symTab.lookup<fir::GlobalOp>(
- op.getHostSymbol().getRootReference().getValue())) {
- auto mod = op->getParentOfType<mlir::ModuleOp>();
- mlir::Location loc = op.getLoc();
- auto hostAddr = fir::AddrOfOp::create(
- rewriter, loc, fir::ReferenceType::get(global.getType()),
- op.getHostSymbol());
- fir::FirOpBuilder builder(rewriter, mod);
- mlir::func::FuncOp callee =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc,
- builder);
- auto fTy = callee.getFunctionType();
- mlir::Value conv =
- createConvertOp(rewriter, loc, fTy.getInput(0), hostAddr);
- mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
- mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, conv, sourceFile, sourceLine)};
- auto call = fir::CallOp::create(rewriter, loc, callee, args);
- mlir::Value addr = createConvertOp(rewriter, loc, hostAddr.getType(),
- call->getResult(0));
- rewriter.replaceOp(op, addr.getDefiningOp());
- return success();
- }
- return failure();
- }
-
-private:
- const mlir::SymbolTable &symTab;
-};
-
struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
using OpRewritePattern::OpRewritePattern;
@@ -454,7 +82,12 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
mlir::LogicalResult
matchAndRewrite(fir::DeclareOp op,
mlir::PatternRewriter &rewriter) const override {
+ if (op.getResult().getUsers().empty())
+ return success();
if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
+ if (inDeviceContext(addrOfOp)) {
+ return failure();
+ }
if (auto global = symTab.lookup<fir::GlobalOp>(
addrOfOp.getSymbol().getRootReference().getValue())) {
if (cuf::isRegisteredDeviceGlobal(global)) {
@@ -475,56 +108,6 @@ private:
const mlir::SymbolTable &symTab;
};
-struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult
- matchAndRewrite(cuf::FreeOp op,
- mlir::PatternRewriter &rewriter) const override {
- if (inDeviceContext(op.getOperation())) {
- rewriter.eraseOp(op);
- return mlir::success();
- }
-
- if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
- return failure();
-
- auto mod = op->getParentOfType<mlir::ModuleOp>();
- fir::FirOpBuilder builder(rewriter, mod);
- mlir::Location loc = op.getLoc();
- mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
-
- auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
- if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
- mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
- auto fTy = func.getFunctionType();
- mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
- mlir::Value memTy = builder.createIntegerConstant(
- loc, builder.getI32Type(), getMemType(op.getDataAttr()));
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
- fir::CallOp::create(builder, loc, func, args);
- rewriter.eraseOp(op);
- return mlir::success();
- }
-
- // Convert cuf.free on descriptors.
- mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDescriptor)>(loc, builder);
- auto fTy = func.getFunctionType();
- mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)};
- auto callOp = fir::CallOp::create(builder, loc, func, args);
- callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
- rewriter.eraseOp(op);
- return mlir::success();
- }
-};
-
static bool isDstGlobal(cuf::DataTransferOp op) {
if (auto declareOp = op.getDst().getDefiningOp<fir::DeclareOp>())
if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
@@ -671,38 +254,15 @@ struct CUFDataTransferOpConversion
}
mlir::Type i64Ty = builder.getI64Type();
- mlir::Value nbElement;
- if (op.getShape()) {
- llvm::SmallVector<mlir::Value> extents;
- if (auto shapeOp =
- mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp())) {
- extents = shapeOp.getExtents();
- } else if (auto shapeShiftOp = mlir::dyn_cast<fir::ShapeShiftOp>(
- op.getShape().getDefiningOp())) {
- for (auto i : llvm::enumerate(shapeShiftOp.getPairs()))
- if (i.index() & 1)
- extents.push_back(i.value());
- }
-
- nbElement = fir::ConvertOp::create(rewriter, loc, i64Ty, extents[0]);
- for (unsigned i = 1; i < extents.size(); ++i) {
- auto operand =
- fir::ConvertOp::create(rewriter, loc, i64Ty, extents[i]);
- nbElement =
- mlir::arith::MulIOp::create(rewriter, loc, nbElement, operand);
- }
- } else {
- if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(dstTy))
- nbElement = builder.createIntegerConstant(
- loc, i64Ty, seqTy.getConstantArraySize());
- }
+ mlir::Value nbElement =
+ cuf::computeElementCount(rewriter, loc, op.getShape(), dstTy, i64Ty);
unsigned width = 0;
if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) {
mlir::Type structTy =
typeConverter->convertType(fir::unwrapSequenceType(dstTy));
width = dl->getTypeSizeInBits(structTy) / 8;
} else {
- width = computeWidth(loc, dstTy, kindMap);
+ width = cuf::computeElementByteSize(loc, dstTy, kindMap);
}
mlir::Value widthValue = mlir::arith::ConstantOp::create(
rewriter, loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));
@@ -934,6 +494,8 @@ struct CUFSyncDescriptorOpConversion
};
class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
+ using CUFOpConversionBase::CUFOpConversionBase;
+
public:
void runOnOperation() override {
auto *ctx = &getContext();
@@ -953,6 +515,7 @@ public:
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
mlir::gpu::GPUDialect>();
target.addLegalOp<cuf::StreamCastOp>();
+ target.addLegalOp<cuf::DeviceAddressOp>();
cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
patterns);
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
@@ -963,6 +526,8 @@ public:
}
target.addDynamicallyLegalOp<fir::DeclareOp>([&](fir::DeclareOp op) {
+ if (op.getResult().getUsers().empty())
+ return true;
if (inDeviceContext(op))
return true;
if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
@@ -992,18 +557,13 @@ public:
void cuf::populateCUFToFIRConversionPatterns(
const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
- patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
- patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion,
- CUFFreeOpConversion, CUFSyncDescriptorOpConversion>(
- patterns.getContext());
+ patterns.insert<CUFSyncDescriptorOpConversion>(patterns.getContext());
patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
&dl, &converter);
- patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(
- patterns.getContext(), symtab);
+ patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
}
void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
mlir::RewritePatternSet &patterns) {
- patterns.insert<DeclareOpConversion, CUFDeviceAddressOpConversion>(
- patterns.getContext(), symtab);
+ patterns.insert<DeclareOpConversion>(patterns.getContext(), symtab);
}
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversionLate.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversionLate.cpp
new file mode 100644
index 0000000..fe45971
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversionLate.cpp
@@ -0,0 +1,120 @@
+//===-- CUFOpConversionLate.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/CUFCommon.h"
+#include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h"
+#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
+#include "flang/Optimizer/CodeGen/TypeConverter.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "flang/Runtime/CUDA/common.h"
+#include "flang/Runtime/CUDA/descriptor.h"
+#include "flang/Runtime/allocatable.h"
+#include "flang/Runtime/allocator-registry-consts.h"
+#include "flang/Support/Fortran.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace fir {
+#define GEN_PASS_DEF_CUFOPCONVERSIONLATE
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace fir;
+using namespace mlir;
+using namespace Fortran::runtime;
+using namespace Fortran::runtime::cuda;
+
+namespace {
+
+static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
+ mlir::Location loc, mlir::Type toTy,
+ mlir::Value val) {
+ if (val.getType() != toTy)
+ return fir::ConvertOp::create(rewriter, loc, toTy, val);
+ return val;
+}
+
+struct CUFDeviceAddressOpConversion
+ : public mlir::OpRewritePattern<cuf::DeviceAddressOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ CUFDeviceAddressOpConversion(mlir::MLIRContext *context,
+ const mlir::SymbolTable &symtab)
+ : OpRewritePattern(context), symTab{symtab} {}
+
+ mlir::LogicalResult
+ matchAndRewrite(cuf::DeviceAddressOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ if (auto global = symTab.lookup<fir::GlobalOp>(
+ op.getHostSymbol().getRootReference().getValue())) {
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ mlir::Location loc = op.getLoc();
+ auto hostAddr = fir::AddrOfOp::create(
+ rewriter, loc, fir::ReferenceType::get(global.getType()),
+ op.getHostSymbol());
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::func::FuncOp callee =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc,
+ builder);
+ auto fTy = callee.getFunctionType();
+ mlir::Value conv =
+ createConvertOp(rewriter, loc, fTy.getInput(0), hostAddr);
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, conv, sourceFile, sourceLine)};
+ auto call = fir::CallOp::create(rewriter, loc, callee, args);
+ mlir::Value addr = createConvertOp(rewriter, loc, hostAddr.getType(),
+ call->getResult(0));
+ rewriter.replaceOp(op, addr.getDefiningOp());
+ return success();
+ }
+ return failure();
+ }
+
+private:
+ const mlir::SymbolTable &symTab;
+};
+
+class CUFOpConversionLate
+ : public fir::impl::CUFOpConversionLateBase<CUFOpConversionLate> {
+ using CUFOpConversionLateBase::CUFOpConversionLateBase;
+
+public:
+ void runOnOperation() override {
+ auto *ctx = &getContext();
+ mlir::RewritePatternSet patterns(ctx);
+ mlir::ConversionTarget target(*ctx);
+ mlir::Operation *op = getOperation();
+ mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
+ if (!module)
+ return signalPassFailure();
+ mlir::SymbolTable symtab(module);
+ target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+ mlir::gpu::GPUDialect>();
+ patterns.insert<CUFDeviceAddressOpConversion>(patterns.getContext(),
+ symtab);
+ if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
+ std::move(patterns)))) {
+ mlir::emitError(mlir::UnknownLoc::get(ctx),
+ "error in CUF op conversion\n");
+ signalPassFailure();
+ }
+ }
+};
+} // namespace
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp
new file mode 100644
index 0000000..3eb6559
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp
@@ -0,0 +1,153 @@
+//===-- CUFPredefinedVarToGPU.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Pass/Pass.h"
+
+namespace fir {
+#define GEN_PASS_DEF_CUFPREDEFINEDVARTOGPU
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace mlir;
+
+namespace {
+
+template <typename OpTyX, typename OpTyY, typename OpTyZ>
+static void createForAllDimensions(mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::Value c1,
+ SmallVectorImpl<mlir::Value> &values,
+ bool incrementByOne = false) {
+ if (incrementByOne) {
+ auto baseX = OpTyX::create(builder, loc, builder.getI32Type());
+ values.push_back(mlir::arith::AddIOp::create(builder, loc, baseX, c1));
+ auto baseY = OpTyY::create(builder, loc, builder.getI32Type());
+ values.push_back(mlir::arith::AddIOp::create(builder, loc, baseY, c1));
+ auto baseZ = OpTyZ::create(builder, loc, builder.getI32Type());
+ values.push_back(mlir::arith::AddIOp::create(builder, loc, baseZ, c1));
+ } else {
+ values.push_back(OpTyX::create(builder, loc, builder.getI32Type()));
+ values.push_back(OpTyY::create(builder, loc, builder.getI32Type()));
+ values.push_back(OpTyZ::create(builder, loc, builder.getI32Type()));
+ }
+}
+
+static constexpr llvm::StringRef builtinsModuleName = "__fortran_builtins";
+static constexpr llvm::StringRef builtinVarPrefix = "__builtin_";
+static constexpr llvm::StringRef threadidx = "threadidx";
+static constexpr llvm::StringRef blockidx = "blockidx";
+static constexpr llvm::StringRef blockdim = "blockdim";
+static constexpr llvm::StringRef griddim = "griddim";
+
+static constexpr unsigned field_x = 0;
+static constexpr unsigned field_y = 1;
+static constexpr unsigned field_z = 2;
+
+std::string mangleBuiltin(llvm::StringRef varName) {
+ return "_QM" + builtinsModuleName.str() + "E" + builtinVarPrefix.str() +
+ varName.str();
+}
+
+static void processCoordinateOp(mlir::OpBuilder &builder, mlir::Location loc,
+ fir::CoordinateOp coordOp, unsigned fieldIdx,
+ mlir::Value &gpuValue) {
+ std::optional<llvm::ArrayRef<int32_t>> fieldIndices =
+ coordOp.getFieldIndices();
+ assert(fieldIndices && fieldIndices->size() == 1 &&
+ "expect only one coordinate");
+ if (static_cast<unsigned>((*fieldIndices)[0]) == fieldIdx) {
+ llvm::SmallVector<fir::LoadOp> opToErase;
+ for (mlir::OpOperand &coordUse : coordOp.getResult().getUses()) {
+ assert(mlir::isa<fir::LoadOp>(coordUse.getOwner()) &&
+ "only expect load op");
+ auto loadOp = mlir::dyn_cast<fir::LoadOp>(coordUse.getOwner());
+ loadOp.getResult().replaceAllUsesWith(gpuValue);
+ opToErase.push_back(loadOp);
+ }
+ for (auto op : opToErase)
+ op.erase();
+ }
+}
+
+static void
+processDeclareOp(mlir::OpBuilder &builder, mlir::Location loc,
+ fir::DeclareOp declareOp, llvm::StringRef builtinVar,
+ llvm::SmallVectorImpl<mlir::Value> &gpuValues,
+ llvm::SmallVectorImpl<mlir::Operation *> &opsToDelete) {
+ if (declareOp.getUniqName().str().compare(builtinVar) == 0) {
+ for (mlir::OpOperand &use : declareOp.getResult().getUses()) {
+ fir::CoordinateOp coordOp =
+ mlir::dyn_cast<fir::CoordinateOp>(use.getOwner());
+ processCoordinateOp(builder, loc, coordOp, field_x, gpuValues[0]);
+ processCoordinateOp(builder, loc, coordOp, field_y, gpuValues[1]);
+ processCoordinateOp(builder, loc, coordOp, field_z, gpuValues[2]);
+ opsToDelete.push_back(coordOp);
+ }
+ opsToDelete.push_back(declareOp.getOperation());
+ if (declareOp.getMemref().getDefiningOp())
+ opsToDelete.push_back(declareOp.getMemref().getDefiningOp());
+ }
+}
+
+struct CUFPredefinedVarToGPU
+ : public fir::impl::CUFPredefinedVarToGPUBase<CUFPredefinedVarToGPU> {
+
+ void runOnOperation() override {
+ func::FuncOp funcOp = getOperation();
+ if (funcOp.getBody().empty())
+ return;
+
+ if (auto cudaProcAttr =
+ funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
+ cuf::getProcAttrName())) {
+ if (cudaProcAttr.getValue() == cuf::ProcAttribute::Device ||
+ cudaProcAttr.getValue() == cuf::ProcAttribute::Global ||
+ cudaProcAttr.getValue() == cuf::ProcAttribute::GridGlobal ||
+ cudaProcAttr.getValue() == cuf::ProcAttribute::HostDevice) {
+ mlir::Location loc = funcOp.getLoc();
+ mlir::OpBuilder builder(funcOp.getContext());
+ builder.setInsertionPointToStart(&funcOp.getBody().front());
+ auto c1 = mlir::arith::ConstantOp::create(
+ builder, loc, builder.getI32Type(), builder.getI32IntegerAttr(1));
+ llvm::SmallVector<mlir::Value, 3> threadids, blockids, blockdims,
+ griddims;
+ createForAllDimensions<mlir::NVVM::ThreadIdXOp, mlir::NVVM::ThreadIdYOp,
+ mlir::NVVM::ThreadIdZOp>(
+ builder, loc, c1, threadids, /*incrementByOne=*/true);
+ createForAllDimensions<mlir::NVVM::BlockIdXOp, mlir::NVVM::BlockIdYOp,
+ mlir::NVVM::BlockIdZOp>(
+ builder, loc, c1, blockids, /*incrementByOne=*/true);
+ createForAllDimensions<mlir::NVVM::GridDimXOp, mlir::NVVM::GridDimYOp,
+ mlir::NVVM::GridDimZOp>(builder, loc, c1,
+ griddims);
+ createForAllDimensions<mlir::NVVM::BlockDimXOp, mlir::NVVM::BlockDimYOp,
+ mlir::NVVM::BlockDimZOp>(builder, loc, c1,
+ blockdims);
+
+ llvm::SmallVector<mlir::Operation *> opsToDelete;
+ for (auto declareOp : funcOp.getOps<fir::DeclareOp>()) {
+ processDeclareOp(builder, loc, declareOp, mangleBuiltin(threadidx),
+ threadids, opsToDelete);
+ processDeclareOp(builder, loc, declareOp, mangleBuiltin(blockidx),
+ blockids, opsToDelete);
+ processDeclareOp(builder, loc, declareOp, mangleBuiltin(blockdim),
+ blockdims, opsToDelete);
+ processDeclareOp(builder, loc, declareOp, mangleBuiltin(griddim),
+ griddims, opsToDelete);
+ }
+
+ for (auto op : opsToDelete)
+ op->erase();
+ }
+ }
+ }
+};
+
+} // end anonymous namespace
diff --git a/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp b/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp
index 00fdb5a..7c8ee09 100644
--- a/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp
+++ b/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp
@@ -438,6 +438,7 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType(
context, llvm::dwarf::DW_TAG_member,
mlir::StringAttr::get(context, fieldName), elemTy, byteSize * 8,
byteAlign * 8, offset * 8, /*optional<address space>=*/std::nullopt,
+ /*flags=*/mlir::LLVM::DIFlags::Zero,
/*extra data=*/nullptr);
elements.push_back(tyAttr);
offset += llvm::alignTo(byteSize, byteAlign);
@@ -480,6 +481,7 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertTupleType(
context, llvm::dwarf::DW_TAG_member, mlir::StringAttr::get(context, ""),
elemTy, byteSize * 8, byteAlign * 8, offset * 8,
/*optional<address space>=*/std::nullopt,
+ /*flags=*/mlir::LLVM::DIFlags::Zero,
/*extra data=*/nullptr);
elements.push_back(tyAttr);
offset += llvm::alignTo(byteSize, byteAlign);
@@ -528,13 +530,10 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertSequenceType(
if (dim == seqTy.getUnknownExtent()) {
// This path is taken for both assumed size array or when the size of the
// array is variable. In the case of variable size, we create a variable
- // to use as countAttr. Note that fir has a constant size of -1 for
- // assumed size array. So !optint check makes sure we don't generate
- // variable in that case.
+ // to use as countAttr.
if (declOp && declOp.getShape().size() > index) {
- std::optional<std::int64_t> optint =
- getIntIfConstant(declOp.getShape()[index]);
- if (!optint)
+ if (!llvm::isa_and_nonnull<fir::AssumedSizeExtentOp>(
+ declOp.getShape()[index].getDefiningOp()))
countAttr = generateArtificialVariable(
context, declOp.getShape()[index], fileAttr, scope, declOp);
}
@@ -676,7 +675,8 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertPointerLikeType(
context, llvm::dwarf::DW_TAG_pointer_type,
mlir::StringAttr::get(context, ""), elTyAttr, /*sizeInBits=*/ptrSize * 8,
/*alignInBits=*/0, /*offset=*/0,
- /*optional<address space>=*/std::nullopt, /*extra data=*/nullptr);
+ /*optional<address space>=*/std::nullopt,
+ /*flags=*/mlir::LLVM::DIFlags::Zero, /*extra data=*/nullptr);
}
static mlir::StringAttr getBasicTypeName(mlir::MLIRContext *context,
@@ -721,6 +721,32 @@ DebugTypeGenerator::convertType(mlir::Type Ty, mlir::LLVM::DIFileAttr fileAttr,
return convertRecordType(recTy, fileAttr, scope, declOp);
} else if (auto tupleTy = mlir::dyn_cast_if_present<mlir::TupleType>(Ty)) {
return convertTupleType(tupleTy, fileAttr, scope, declOp);
+ } else if (mlir::isa<mlir::FunctionType>(Ty)) {
+ // Handle function types - these represent procedure pointers after the
+ // BoxedProcedure pass has run and unwrapped the fir.boxproc type, as well
+ // as dummy procedures (which are represented as function types in FIR)
+ llvm::SmallVector<mlir::LLVM::DITypeAttr> types;
+
+ auto funcTy = mlir::cast<mlir::FunctionType>(Ty);
+ // Add return type (or void if no return type)
+ if (funcTy.getNumResults() == 0)
+ types.push_back(mlir::LLVM::DINullTypeAttr::get(context));
+ else
+ types.push_back(
+ convertType(funcTy.getResult(0), fileAttr, scope, declOp));
+
+ for (mlir::Type paramTy : funcTy.getInputs())
+ types.push_back(convertType(paramTy, fileAttr, scope, declOp));
+
+ auto subroutineTy = mlir::LLVM::DISubroutineTypeAttr::get(
+ context, /*callingConvention=*/0, types);
+
+ return mlir::LLVM::DIDerivedTypeAttr::get(
+ context, llvm::dwarf::DW_TAG_pointer_type,
+ mlir::StringAttr::get(context, ""), subroutineTy,
+ /*sizeInBits=*/ptrSize * 8, /*alignInBits=*/0, /*offset=*/0,
+ /*optional<address space>=*/std::nullopt,
+ /*flags=*/mlir::LLVM::DIFlags::Zero, /*extra data=*/nullptr);
} else if (auto refTy = mlir::dyn_cast_if_present<fir::ReferenceType>(Ty)) {
auto elTy = refTy.getEleTy();
return convertPointerLikeType(elTy, fileAttr, scope, declOp,
diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
new file mode 100644
index 0000000..bf125eb
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp
@@ -0,0 +1,1061 @@
+//===-- FIRToMemRef.cpp - Convert FIR loads and stores to MemRef ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass lowers FIR dialect memory operations to the MemRef dialect.
+// In particular it:
+//
+// - Rewrites `fir.alloca` to `memref.alloca`.
+//
+// - Rewrites `fir.load` / `fir.store` to `memref.load` / `memref.store`.
+//
+// - Allows FIR and MemRef to coexist by introducing `fir.convert` at
+// memory-use sites. Memory operations (`memref.load`, `memref.store`,
+// `memref.reinterpret_cast`, etc.) see MemRef-typed values, while the
+// original FIR-typed values remain available for non-memory uses. For
+// example:
+//
+// %fir_ref = ... : !fir.ref<!fir.array<...>>
+// %memref = fir.convert %fir_ref
+// : !fir.ref<!fir.array<...>> -> memref<...>
+// %val = memref.load %memref[...] : memref<...>
+// fir.call @callee(%fir_ref) : (!fir.ref<!fir.array<...>>) -> ()
+//
+// Here the MemRef-typed value is used for `memref.load`, while the
+// original FIR-typed value is preserved for `fir.call`.
+//
+// - Computes shapes, strides, and indices as needed for slices and shifts
+// and emits `memref.reinterpret_cast` when dynamic layout is required
+// (TODO: use memref.cast instead).
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/CUFCommon.h"
+#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Dialect/Support/FIRContext.h"
+#include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/Transforms/FIRToMemRefTypeConverter.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "fir-to-memref"
+
+using namespace mlir;
+
+namespace fir {
+
+#define GEN_PASS_DEF_FIRTOMEMREF
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+
/// Returns true if \p op is a `fir.convert` that marshals a value between the
/// FIR and MemRef worlds, i.e. exactly one side of the conversion has a
/// MemRefType. A memref -> memref convert is never produced by this pass, so
/// that shape is rejected by the assertion.
static bool isMarshalLike(Operation *op) {
  auto convert = dyn_cast_if_present<fir::ConvertOp>(op);
  if (!convert)
    return false;

  bool resIsMemRef = isa<MemRefType>(convert.getType());
  bool argIsMemRef = isa<MemRefType>(convert.getValue().getType());

  assert(!(resIsMemRef && argIsMemRef) &&
         "unexpected fir.convert memref -> memref in isMarshalLike");

  // Exactly one side is a memref: this convert bridges the two dialects.
  return resIsMemRef || argIsMemRef;
}
+
+using MemRefInfo = FailureOr<std::pair<Value, SmallVector<Value>>>;
+
+static llvm::cl::opt<bool> enableFIRConvertOptimizations(
+ "enable-fir-convert-opts",
+ llvm::cl::desc("enable emilinating redundant fir.convert in FIR-to-MemRef"),
+ llvm::cl::init(false), llvm::cl::Hidden);
+
/// Pass that rewrites FIR memory operations (`fir.alloca`, `fir.load`,
/// `fir.store`) into their MemRef dialect equivalents, bridging between the
/// two dialects with `fir.convert` at memory-use sites.
class FIRToMemRef : public fir::impl::FIRToMemRefBase<FIRToMemRef> {
public:
  void runOnOperation() override;

private:
  // Operations made dead during the rewrite; erased in one batch at the end
  // of the pass so IR walks stay valid while rewriting.
  llvm::SmallSetVector<Operation *, 32> eraseOps;

  // Dominance information for the function being processed. Owned by
  // runOnOperation() and only valid while it runs.
  DominanceInfo *domInfo = nullptr;

  /// Rewrites a `fir.alloca` into `memref.alloca` plus a `fir.convert` back
  /// to the original FIR reference type.
  void rewriteAlloca(fir::AllocaOp, PatternRewriter &,
                     FIRToMemRefTypeConverter &);

  /// Rewrites a `fir.load` into `memref.load`.
  void rewriteLoadOp(fir::LoadOp, PatternRewriter &,
                     FIRToMemRefTypeConverter &);

  /// Rewrites a `fir.store` into `memref.store`.
  void rewriteStoreOp(fir::StoreOp, PatternRewriter &,
                      FIRToMemRefTypeConverter &);

  /// Computes the MemRef-typed base value and the access indices for the
  /// given FIR memory reference used by the memory operation.
  MemRefInfo getMemRefInfo(Value, PatternRewriter &, FIRToMemRefTypeConverter &,
                           Operation *);

  /// Handles the `fir.array_coor` case of getMemRefInfo().
  MemRefInfo convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp,
                                PatternRewriter &, FIRToMemRefTypeConverter &);

  /// Redirects remaining FIR-typed uses of a converted memref through new
  /// `fir.convert` operations so both dialects can coexist.
  void replaceFIRMemrefs(Value, Value, PatternRewriter &) const;

  /// Creates (or, when optimizations are enabled, reuses) a `fir.convert`
  /// that yields a MemRef-typed view of the given memory reference.
  FailureOr<Value> getFIRConvert(Operation *memOp, Operation *memref,
                                 PatternRewriter &, FIRToMemRefTypeConverter &);

  /// Translates `fir.array_coor` one-based indices into zero-based memref
  /// indices, folding in shifts (lower bounds) and slice bounds/strides.
  FailureOr<SmallVector<Value>> getMemrefIndices(fir::ArrayCoorOp, Operation *,
                                                 PatternRewriter &, Value,
                                                 Value) const;

  /// Conservatively returns true if the memory reference may stem from an
  /// OPTIONAL dummy argument.
  bool memrefIsOptional(Operation *) const;

  /// Simplifies an index expression (constants, extsi/index_cast chains)
  /// before it is used for memref indexing.
  Value canonicalizeIndex(Value, PatternRewriter &) const;

  /// Collects shape extents, shifts, and slice triples from the shape/slice
  /// operands of \p op (array_coor, embox, or rebox).
  template <typename OpTy>
  void getShapeFrom(OpTy op, SmallVector<Value> &shapeVec,
                    SmallVector<Value> &shiftVec,
                    SmallVector<Value> &sliceVec) const;

  void populateShapeAndShift(SmallVectorImpl<Value> &shapeVec,
                             SmallVectorImpl<Value> &shiftVec,
                             fir::ShapeShiftOp shift) const;

  void populateShift(SmallVectorImpl<Value> &vec, fir::ShiftOp shift) const;

  void populateShape(SmallVectorImpl<Value> &vec, fir::ShapeOp shape) const;

  /// Returns the rank of the array addressed by \p embox (0 for scalars).
  unsigned getRankFromEmbox(fir::EmboxOp embox) const {
    auto memrefType = embox.getMemref().getType();
    Type unwrappedType = fir::unwrapRefType(memrefType);
    if (auto seqType = dyn_cast<fir::SequenceType>(unwrappedType))
      return seqType.getDimension();
    return 0;
  }

  /// True if the alloca carries neither `bindc_name` nor `uniq_name`, i.e. it
  /// does not back a user-declared variable.
  bool isCompilerGeneratedAlloca(Operation *op) const;

  /// Copies the attribute named \p name from \p from to \p to when present.
  void copyAttribute(Operation *from, Operation *to,
                     llvm::StringRef name) const;

  /// Strips FIR/MemRef wrappers (and, unless requested otherwise, complex)
  /// down to the scalar element type.
  Type getBaseType(Type type, bool complexBaseTypes = false) const;

  /// True if \p memref addresses device-resident (OpenACC/CUF) data.
  bool memrefIsDeviceData(Operation *memref) const;

  /// Best-effort backward walk looking for a CUF data attribute on the
  /// defining chain of \p val.
  mlir::Attribute findCudaDataAttr(Value val) const;
};
+
/// Splits the interleaved operand pairs of a `fir.shape_shift` into
/// \p shiftVec (origins / lower bounds) and \p shapeVec (extents).
void FIRToMemRef::populateShapeAndShift(SmallVectorImpl<Value> &shapeVec,
                                        SmallVectorImpl<Value> &shiftVec,
                                        fir::ShapeShiftOp shift) const {
  for (mlir::OperandRange::iterator i = shift.getPairs().begin(),
                                    endIter = shift.getPairs().end();
       i != endIter;) {
    // Each pair is laid out origin-first, extent-second; the two increments
    // below consume one pair per loop iteration.
    shiftVec.push_back(*i++);
    shapeVec.push_back(*i++);
  }
}
+
+bool FIRToMemRef::isCompilerGeneratedAlloca(Operation *op) const {
+ if (!isa<fir::AllocaOp, memref::AllocaOp>(op))
+ llvm_unreachable("expected alloca op");
+
+ return !op->getAttr("bindc_name") && !op->getAttr("uniq_name");
+}
+
+void FIRToMemRef::copyAttribute(Operation *from, Operation *to,
+ llvm::StringRef name) const {
+ if (Attribute value = from->getAttr(name))
+ to->setAttr(name, value);
+}
+
+Type FIRToMemRef::getBaseType(Type type, bool complexBaseTypes) const {
+ if (fir::isa_fir_type(type)) {
+ type = fir::getFortranElementType(type);
+ } else if (auto memrefTy = dyn_cast<MemRefType>(type)) {
+ type = memrefTy.getElementType();
+ }
+
+ if (!complexBaseTypes)
+ if (auto complexTy = dyn_cast<ComplexType>(type))
+ type = complexTy.getElementType();
+ return type;
+}
+
+bool FIRToMemRef::memrefIsDeviceData(Operation *memref) const {
+ if (isa<ACC_DATA_ENTRY_OPS>(memref))
+ return true;
+
+ return cuf::hasDeviceDataAttr(memref);
+}
+
/// Best-effort search for a CUF data attribute describing \p val. Walks
/// backwards through value-forwarding operations (rebox, convert, embox,
/// box_addr, declare) until an attribute is found or the chain ends; the
/// visited set guards against revisiting the same defining op.
/// Returns a null attribute when nothing is found.
mlir::Attribute FIRToMemRef::findCudaDataAttr(Value val) const {
  Value currentVal = val;
  llvm::SmallPtrSet<Operation *, 8> visited;

  while (currentVal) {
    Operation *defOp = currentVal.getDefiningOp();
    // Stop at block arguments and at any op already visited.
    if (!defOp || !visited.insert(defOp).second)
      break;

    if (cuf::DataAttributeAttr cudaAttr = cuf::getDataAttr(defOp))
      return cudaAttr;

    // TODO: This is a best-effort backward walk; it is easy to miss attributes
    // as FIR evolves. Long term, it would be preferable if the necessary
    // information was carried in the type system (or otherwise made available
    // without relying on a walk-back through defining ops).
    if (auto reboxOp = dyn_cast<fir::ReboxOp>(defOp)) {
      currentVal = reboxOp.getBox();
    } else if (auto convertOp = dyn_cast<fir::ConvertOp>(defOp)) {
      currentVal = convertOp->getOperand(0);
    } else if (auto emboxOp = dyn_cast<fir::EmboxOp>(defOp)) {
      currentVal = emboxOp.getMemref();
    } else if (auto boxAddrOp = dyn_cast<fir::BoxAddrOp>(defOp)) {
      currentVal = boxAddrOp.getVal();
    } else if (auto declareOp = dyn_cast<fir::DeclareOp>(defOp)) {
      currentVal = declareOp.getMemref();
    } else {
      // Unknown producer: give up rather than guess.
      break;
    }
  }
  return nullptr;
}
+
+void FIRToMemRef::populateShift(SmallVectorImpl<Value> &vec,
+ fir::ShiftOp shift) const {
+ vec.append(shift.getOrigins().begin(), shift.getOrigins().end());
+}
+
+void FIRToMemRef::populateShape(SmallVectorImpl<Value> &vec,
+ fir::ShapeOp shape) const {
+ vec.append(shape.getExtents().begin(), shape.getExtents().end());
+}
+
/// Extracts shape extents, shifts (lower bounds), and slice triples from the
/// shape/slice operands of \p op. Only array_coor, rebox, and embox provide
/// these operands; for any other op type the function is a no-op.
template <typename OpTy>
void FIRToMemRef::getShapeFrom(OpTy op, SmallVector<Value> &shapeVec,
                               SmallVector<Value> &shiftVec,
                               SmallVector<Value> &sliceVec) const {
  if constexpr (std::is_same_v<OpTy, fir::ArrayCoorOp> ||
                std::is_same_v<OpTy, fir::ReboxOp> ||
                std::is_same_v<OpTy, fir::EmboxOp>) {
    Value shapeVal = op.getShape();

    if (shapeVal) {
      Operation *shapeValOp = shapeVal.getDefiningOp();

      // Dispatch on the concrete producer: shape gives extents only,
      // shape_shift gives (origin, extent) pairs, shift gives origins only.
      if (auto shapeOp = dyn_cast<fir::ShapeOp>(shapeValOp)) {
        populateShape(shapeVec, shapeOp);
      } else if (auto shapeShiftOp = dyn_cast<fir::ShapeShiftOp>(shapeValOp)) {
        populateShapeAndShift(shapeVec, shiftVec, shapeShiftOp);
      } else if (auto shiftOp = dyn_cast<fir::ShiftOp>(shapeValOp)) {
        populateShift(shiftVec, shiftOp);
      }
    }

    Value sliceVal = op.getSlice();
    if (sliceVal) {
      // Slice triples are (lb, ub, stride) per dimension, appended flat.
      if (auto sliceOp = sliceVal.getDefiningOp<fir::SliceOp>()) {
        auto triples = sliceOp.getTriples();
        sliceVec.append(triples.begin(), triples.end());
      }
    }
  }
}
+
+void FIRToMemRef::rewriteAlloca(fir::AllocaOp firAlloca,
+ PatternRewriter &rewriter,
+ FIRToMemRefTypeConverter &typeConverter) {
+ if (!typeConverter.convertibleType(firAlloca.getInType()))
+ return;
+
+ if (typeConverter.isEmptyArray(firAlloca.getType()))
+ return;
+
+ rewriter.setInsertionPointAfter(firAlloca);
+
+ Type type = firAlloca.getType();
+ MemRefType memrefTy = typeConverter.convertMemrefType(type);
+
+ Location loc = firAlloca.getLoc();
+
+ SmallVector<Value> sizes = firAlloca.getOperands();
+ std::reverse(sizes.begin(), sizes.end());
+
+ auto alloca = memref::AllocaOp::create(rewriter, loc, memrefTy, sizes);
+ copyAttribute(firAlloca, alloca, firAlloca.getBindcNameAttrName());
+ copyAttribute(firAlloca, alloca, firAlloca.getUniqNameAttrName());
+ copyAttribute(firAlloca, alloca, cuf::getDataAttrName());
+
+ auto convert = fir::ConvertOp::create(rewriter, loc, type, alloca);
+
+ rewriter.replaceOp(firAlloca, convert);
+
+ if (isCompilerGeneratedAlloca(alloca)) {
+ for (Operation *userOp : convert->getUsers()) {
+ if (auto declareOp = dyn_cast<fir::DeclareOp>(userOp)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "FIRToMemRef: removing declare for compiler temp:\n";
+ declareOp->dump());
+ declareOp->replaceAllUsesWith(convert);
+ eraseOps.insert(userOp);
+ }
+ }
+ }
+}
+
/// Conservatively determines whether \p op may reference an OPTIONAL dummy
/// argument: either the defining declare is marked optional (or declares a
/// `fir.absent` value), or one of the op's results feeds a `fir.is_present`
/// check.
bool FIRToMemRef::memrefIsOptional(Operation *op) const {
  if (auto declare = dyn_cast<fir::DeclareOp>(op)) {
    if (fir::FortranVariableOpInterface(declare).isOptional())
      return true;

    // An absent box declared here also denotes a missing OPTIONAL argument.
    Value operand = declare.getMemref();
    Operation *operandOp = operand.getDefiningOp();
    if (operandOp && isa<fir::AbsentOp>(operandOp))
      return true;
  }

  // A presence test on any result implies the value may be absent.
  for (mlir::Value result : op->getResults())
    for (mlir::Operation *userOp : result.getUsers())
      if (isa<fir::IsPresentOp>(userOp))
        return true;

  // TODO: If `op` is not a `fir.declare`, OPTIONAL information may still be
  // present on a related `fir.declare` reached by tracing the address/box
  // through common forwarding ops (e.g. `fir.convert`, `fir.rebox`,
  // `fir.embox`, `fir.box_addr`), then checking `declare.isOptional()`. Add the
  // search after FIR improves on it.
  return false;
}
+
+static Value castTypeToIndexType(Value originalValue,
+ PatternRewriter &rewriter) {
+ if (originalValue.getType().isIndex())
+ return originalValue;
+
+ Type indexType = rewriter.getIndexType();
+ return arith::IndexCastOp::create(rewriter, originalValue.getLoc(), indexType,
+ originalValue);
+}
+
/// Translates the one-based (possibly shifted/sliced) indices of
/// \p arrayCoorOp into zero-based memref indices. Dimensions whose slice
/// stride is `fir.undef` are degenerate: their index is fully determined by
/// the slice lower bound. The remaining array_coor indices are distributed
/// over the non-degenerate dimensions. Indices are finally reversed into
/// memref (row-major) dimension order.
///
/// \p one must be an index-typed constant 1 supplied by the caller.
/// NOTE(review): the \p converted parameter is currently unused — confirm
/// whether it can be dropped from the signature.
FailureOr<SmallVector<Value>>
FIRToMemRef::getMemrefIndices(fir::ArrayCoorOp arrayCoorOp, Operation *memref,
                              PatternRewriter &rewriter, Value converted,
                              Value one) const {
  IndexType indexTy = rewriter.getIndexType();
  SmallVector<Value> indices;
  Location loc = arrayCoorOp->getLoc();
  SmallVector<Value> shiftVec, shapeVec, sliceVec;
  int rank = arrayCoorOp.getIndices().size();
  getShapeFrom<fir::ArrayCoorOp>(arrayCoorOp, shapeVec, shiftVec, sliceVec);

  // When the base is an embox, its shape/slice and rank take precedence.
  if (auto embox = dyn_cast_or_null<fir::EmboxOp>(memref)) {
    getShapeFrom<fir::EmboxOp>(embox, shapeVec, shiftVec, sliceVec);
    rank = getRankFromEmbox(embox);
  }

  // Split the flat (lb, ub, stride) triples into per-dimension lb/stride.
  SmallVector<Value> sliceLbs, sliceStrides;
  for (size_t i = 0; i < sliceVec.size(); i += 3) {
    sliceLbs.push_back(castTypeToIndexType(sliceVec[i], rewriter));
    sliceStrides.push_back(castTypeToIndexType(sliceVec[i + 2], rewriter));
  }

  const bool isShifted = !shiftVec.empty();
  const bool isSliced = !sliceVec.empty();

  ValueRange idxs = arrayCoorOp.getIndices();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);

  // First pass: fill degenerate dimensions (undef stride) with the offset
  // implied by the slice lower bound; placeholders elsewhere.
  SmallVector<bool> filledPositions(rank, false);
  for (int i = 0; i < rank; ++i) {
    Value step = isSliced ? sliceStrides[i] : one;
    Operation *stepOp = step.getDefiningOp();
    if (stepOp && mlir::isa_and_nonnull<fir::UndefOp>(stepOp)) {
      Value shift = isShifted ? shiftVec[i] : one;
      Value sliceLb = isSliced ? sliceLbs[i] : shift;
      Value offset = arith::SubIOp::create(rewriter, loc, sliceLb, shift);
      indices.push_back(offset);
      filledPositions[i] = true;
    } else {
      indices.push_back(zero);
    }
  }

  // Second pass: map each array_coor index onto the next unfilled dimension,
  // computing (index - adjustment) * stride + (sliceLb - shift).
  int arrayCoorIdx = 0;
  for (int i = 0; i < rank; ++i) {
    if (filledPositions[i])
      continue;

    assert((unsigned int)arrayCoorIdx < idxs.size() &&
           "empty dimension should be eliminated\n");
    Value index = canonicalizeIndex(idxs[arrayCoorIdx], rewriter);
    Type cTy = index.getType();
    if (!llvm::isa<IndexType>(cTy)) {
      assert(cTy.isSignlessInteger() && "expected signless integer type");
      index = arith::IndexCastOp::create(rewriter, loc, indexTy, index);
    }

    Value shift = isShifted ? shiftVec[i] : one;
    Value stride = isSliced ? sliceStrides[i] : one;
    Value sliceLb = isSliced ? sliceLbs[i] : shift;

    // Sliced accesses are already relative to the slice, so only subtract 1;
    // unsliced accesses subtract the dimension's lower bound.
    Value oneIdx = arith::ConstantIndexOp::create(rewriter, loc, 1);
    Value indexAdjustment = isSliced ? oneIdx : sliceLb;
    Value delta = arith::SubIOp::create(rewriter, loc, index, indexAdjustment);

    Value scaled = arith::MulIOp::create(rewriter, loc, delta, stride);

    Value offset = arith::SubIOp::create(rewriter, loc, sliceLb, shift);

    Value finalIndex = arith::AddIOp::create(rewriter, loc, scaled, offset);

    indices[i] = finalIndex;
    arrayCoorIdx++;
  }

  // array_coor lists dimensions column-major; memref expects row-major.
  std::reverse(indices.begin(), indices.end());

  return indices;
}
+
/// Materializes a MemRef-typed base and index list for a memory operation
/// whose address comes from \p arrayCoorOp. For statically shaped bases the
/// converted memref is returned directly; for dynamic layouts (boxes, shifts,
/// slices, rebox) a `memref.reinterpret_cast` with explicit sizes and strides
/// is emitted.
MemRefInfo
FIRToMemRef::convertArrayCoorOp(Operation *memOp, fir::ArrayCoorOp arrayCoorOp,
                                PatternRewriter &rewriter,
                                FIRToMemRefTypeConverter &typeConverter) {
  IndexType indexTy = rewriter.getIndexType();
  Value firMemref = arrayCoorOp.getMemref();
  if (!typeConverter.convertibleMemrefType(firMemref.getType()))
    return failure();

  if (typeConverter.isEmptyArray(firMemref.getType()))
    return failure();

  // Block-argument base: convert the already-computed element reference
  // itself and access it with no indices.
  if (auto blockArg = dyn_cast<BlockArgument>(firMemref)) {
    Value elemRef = arrayCoorOp.getResult();
    rewriter.setInsertionPointAfter(arrayCoorOp);
    Location loc = arrayCoorOp->getLoc();
    Type elemMemrefTy = typeConverter.convertMemrefType(elemRef.getType());
    Value converted =
        fir::ConvertOp::create(rewriter, loc, elemMemrefTy, elemRef);
    SmallVector<Value> indices;
    return std::pair{converted, indices};
  }

  Operation *memref = firMemref.getDefiningOp();

  FailureOr<Value> converted;
  if (enableFIRConvertOptimizations && isMarshalLike(memref) &&
      !fir::isa_fir_type(firMemref.getType())) {
    // The base is already memref-typed; reuse it directly.
    converted = firMemref;
    rewriter.setInsertionPoint(arrayCoorOp);
  } else {
    Operation *arrayCoorOperation = arrayCoorOp.getOperation();
    rewriter.setInsertionPoint(arrayCoorOp);
    // For OPTIONAL bases guarded by `if (is_present(...))`, hoist the convert
    // to the start of the guarded branch so it is only executed when present.
    if (memrefIsOptional(memref)) {
      auto ifOp = arrayCoorOperation->getParentOfType<scf::IfOp>();
      if (ifOp) {
        Operation *condition = ifOp.getCondition().getDefiningOp();
        if (condition && isa<fir::IsPresentOp>(condition))
          if (condition->getOperand(0) == firMemref) {
            if (arrayCoorOperation->getParentRegion() == &ifOp.getThenRegion())
              rewriter.setInsertionPointToStart(
                  &(ifOp.getThenRegion().front()));
            else if (arrayCoorOperation->getParentRegion() ==
                     &ifOp.getElseRegion())
              rewriter.setInsertionPointToStart(
                  &(ifOp.getElseRegion().front()));
          }
      }
    }

    converted = getFIRConvert(memOp, memref, rewriter, typeConverter);
    if (failed(converted))
      return failure();

    rewriter.setInsertionPointAfter(arrayCoorOp);
  }

  Location loc = arrayCoorOp->getLoc();
  Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
  FailureOr<SmallVector<Value>> failureOrIndices =
      getMemrefIndices(arrayCoorOp, memref, rewriter, *converted, one);
  if (failed(failureOrIndices))
    return failure();
  SmallVector<Value> indices = *failureOrIndices;

  if (converted == firMemref)
    return std::pair{*converted, indices};

  Value convertedVal = *converted;
  MemRefType memRefTy = dyn_cast<MemRefType>(convertedVal.getType());

  bool isRebox = firMemref.getDefiningOp<fir::ReboxOp>() != nullptr;

  // Static shapes need no layout fixup unless the base was reboxed.
  if (memRefTy.hasStaticShape() && !isRebox)
    return std::pair{*converted, indices};

  unsigned rank = arrayCoorOp.getIndices().size();

  if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>())
    rank = getRankFromEmbox(embox);

  SmallVector<Value> sizes;
  sizes.reserve(rank);
  SmallVector<Value> strides;
  strides.reserve(rank);

  SmallVector<Value> shapeVec, shiftVec, sliceVec;
  getShapeFrom<fir::ArrayCoorOp>(arrayCoorOp, shapeVec, shiftVec, sliceVec);

  Value box = firMemref;
  if (!isa<BlockArgument>(firMemref)) {
    if (auto embox = firMemref.getDefiningOp<fir::EmboxOp>())
      getShapeFrom<fir::EmboxOp>(embox, shapeVec, shiftVec, sliceVec);
    else if (auto rebox = firMemref.getDefiningOp<fir::ReboxOp>())
      getShapeFrom<fir::ReboxOp>(rebox, shapeVec, shiftVec, sliceVec);
  }

  if (shapeVec.empty()) {
    // No static shape info: read extents and byte strides out of the box and
    // normalize strides to element units.
    auto boxElementSize =
        fir::BoxEleSizeOp::create(rewriter, loc, indexTy, box);

    for (unsigned i = 0; i < rank; ++i) {
      Value dim = arith::ConstantIndexOp::create(rewriter, loc, rank - i - 1);
      auto boxDims = fir::BoxDimsOp::create(rewriter, loc, indexTy, indexTy,
                                            indexTy, box, dim);

      Value extent = boxDims->getResult(1);
      sizes.push_back(castTypeToIndexType(extent, rewriter));

      Value byteStride = boxDims->getResult(2);
      Value div =
          arith::DivSIOp::create(rewriter, loc, byteStride, boxElementSize);
      strides.push_back(castTypeToIndexType(div, rewriter));
    }

  } else {
    // Dense column-major layout: stride of dim i is the product of all
    // lower-dimension extents.
    // NOTE(review): the descending loop assumes rank >= 1 here (shapeVec is
    // non-empty); `rank - 1` would wrap for rank == 0 — confirm this cannot
    // occur.
    Value oneIdx =
        arith::ConstantIndexOp::create(rewriter, arrayCoorOp->getLoc(), 1);
    for (unsigned i = rank - 1; i > 0; --i) {
      Value size = shapeVec[i];
      sizes.push_back(castTypeToIndexType(size, rewriter));

      Value stride = shapeVec[0];
      for (unsigned j = 1; j <= i - 1; ++j)
        stride = arith::MulIOp::create(rewriter, loc, shapeVec[j], stride);
      strides.push_back(castTypeToIndexType(stride, rewriter));
    }

    sizes.push_back(castTypeToIndexType(shapeVec[0], rewriter));
    strides.push_back(oneIdx);
  }

  assert(strides.size() == sizes.size() && sizes.size() == rank);

  // Build a fully dynamic strided memref type and reinterpret the base.
  int64_t dynamicOffset = ShapedType::kDynamic;
  SmallVector<int64_t> dynamicStrides(rank, ShapedType::kDynamic);
  auto stridedLayout = StridedLayoutAttr::get(convertedVal.getContext(),
                                              dynamicOffset, dynamicStrides);

  SmallVector<int64_t> dynamicShape(rank, ShapedType::kDynamic);
  memRefTy =
      MemRefType::get(dynamicShape, memRefTy.getElementType(), stridedLayout);

  Value offset = arith::ConstantIndexOp::create(rewriter, loc, 0);

  auto reinterpret = memref::ReinterpretCastOp::create(
      rewriter, loc, memRefTy, *converted, offset, sizes, strides);

  Value result = reinterpret->getResult(0);
  return std::pair{result, indices};
}
+
/// Produces a MemRef-typed view of the single result of \p op, either by
/// reusing an existing dominating convert (optimization mode), by reusing the
/// memref.alloca behind a compiler temporary, or by emitting a fresh
/// `fir.convert`. Boxed values are unwrapped with `fir.box_addr` first;
/// embox/rebox bases are validated and their underlying storage used when a
/// slice is involved. Fails when the types cannot be reconciled.
FailureOr<Value>
FIRToMemRef::getFIRConvert(Operation *memOp, Operation *op,
                           PatternRewriter &rewriter,
                           FIRToMemRefTypeConverter &typeConverter) {
  // Optimization: reuse an existing memref-producing convert of `op` that
  // dominates the memory operation (same region only).
  if (enableFIRConvertOptimizations && !op->hasOneUse() &&
      !memrefIsOptional(op)) {
    for (Operation *userOp : op->getUsers()) {
      if (auto convertOp = dyn_cast<fir::ConvertOp>(userOp)) {
        Value converted = convertOp.getResult();
        if (!isa<MemRefType>(converted.getType()))
          continue;

        if (userOp->getParentOp() == memOp->getParentOp() &&
            domInfo->dominates(userOp, memOp))
          return converted;
      }
    }
  }

  assert(op->getNumResults() == 1 && "expecting one result");

  Value basePtr = op->getResult(0);

  MemRefType memrefTy = typeConverter.convertMemrefType(basePtr.getType());
  Type baseTy = memrefTy.getElementType();

  // Scalar compiler temporaries: the memref.alloca created by rewriteAlloca
  // can be used directly instead of converting back and forth.
  if (fir::isa_std_type(baseTy) && memrefTy.getRank() == 0) {
    if (auto convertOp = basePtr.getDefiningOp<fir::ConvertOp>()) {
      Value input = convertOp.getOperand();
      if (auto alloca = input.getDefiningOp<memref::AllocaOp>()) {
        assert(alloca.getType() == memrefTy && "expected same types");
        if (isCompilerGeneratedAlloca(alloca))
          return alloca.getResult();
      }
    }
  }

  const Location loc = op->getLoc();

  if (isa<fir::BoxType>(basePtr.getType())) {
    // Unwrap the box to its base address, propagating any CUF data attribute
    // found along the defining chain.
    Operation *baseOp = basePtr.getDefiningOp();
    auto boxAddrOp = fir::BoxAddrOp::create(rewriter, loc, basePtr);

    if (auto cudaAttr = findCudaDataAttr(basePtr))
      boxAddrOp->setAttr(cuf::getDataAttrName(), cudaAttr);

    basePtr = boxAddrOp;
    memrefTy = typeConverter.convertMemrefType(basePtr.getType());

    if (baseOp) {
      // Bail out when the box reinterprets its storage as a different
      // element type; the memref view would be wrong.
      auto sameBaseBoxTypes = [&](Type baseType, Type memrefType) -> bool {
        Type emboxBaseTy = getBaseType(baseType, true);
        Type emboxMemrefTy = getBaseType(memrefType, true);
        return emboxBaseTy == emboxMemrefTy;
      };

      if (auto embox = dyn_cast_or_null<fir::EmboxOp>(baseOp)) {
        if (!sameBaseBoxTypes(embox.getType(), embox.getMemref().getType())) {
          LLVM_DEBUG(llvm::dbgs()
                     << "FIRToMemRef: embox base type and memref type are not "
                        "the same, bailing out of conversion\n");
          return failure();
        }
        // Sliced embox: convert the underlying storage instead of the box.
        if (embox.getSlice() &&
            embox.getSlice().getDefiningOp<fir::SliceOp>()) {
          Type originalType = embox.getMemref().getType();
          basePtr = embox.getMemref();

          if (typeConverter.convertibleMemrefType(originalType)) {
            auto convertedMemrefTy =
                typeConverter.convertMemrefType(originalType);
            memrefTy = convertedMemrefTy;
          } else {
            return failure();
          }
        }
      }

      if (auto rebox = dyn_cast<fir::ReboxOp>(baseOp)) {
        if (!sameBaseBoxTypes(rebox.getType(), rebox.getBox().getType())) {
          LLVM_DEBUG(llvm::dbgs()
                     << "FIRToMemRef: rebox base type and box type are not the "
                        "same, bailing out of conversion\n");
          return failure();
        }
        Type originalType = rebox.getBox().getType();
        if (auto boxTy = dyn_cast<fir::BoxType>(originalType))
          originalType = boxTy.getElementType();
        if (!typeConverter.convertibleMemrefType(originalType)) {
          return failure();
        } else {
          auto convertedMemrefTy =
              typeConverter.convertMemrefType(originalType);
          memrefTy = convertedMemrefTy;
        }
      }
    }
  }

  auto convert = fir::ConvertOp::create(rewriter, loc, memrefTy, basePtr);
  return convert->getResult(0);
}
+
/// Attempts to simplify an index expression before it is used for memref
/// indexing: integer constants become index constants, sign-extensions of
/// index_casts collapse back to the original index value, and additions are
/// rebuilt from canonicalized operands when their types still agree. The
/// value is returned unchanged when no simplification applies.
Value FIRToMemRef::canonicalizeIndex(Value index,
                                     PatternRewriter &rewriter) const {
  if (auto blockArg = dyn_cast<BlockArgument>(index))
    return index;

  Operation *op = index.getDefiningOp();

  if (auto constant = dyn_cast<arith::ConstantIntOp>(op)) {
    if (!constant.getType().isIndex()) {
      Value v = arith::ConstantIndexOp::create(rewriter, op->getLoc(),
                                               constant.value());
      return v;
    }
    return constant;
  }

  if (auto extsi = dyn_cast<arith::ExtSIOp>(op)) {
    Value operand = extsi.getOperand();
    // extsi(index_cast(x)) round-trips an index; recover x directly.
    if (auto indexCast = operand.getDefiningOp<arith::IndexCastOp>()) {
      Value v = indexCast.getOperand();
      return v;
    }
    return canonicalizeIndex(operand, rewriter);
  }

  if (auto add = dyn_cast<arith::AddIOp>(op)) {
    Value lhs = canonicalizeIndex(add.getLhs(), rewriter);
    Value rhs = canonicalizeIndex(add.getRhs(), rewriter);
    // Only rebuild the add when canonicalization kept the operands
    // type-compatible; otherwise fall through and keep the original.
    if (lhs.getType() == rhs.getType())
      return arith::AddIOp::create(rewriter, op->getLoc(), lhs, rhs);
  }
  return index;
}
+
/// Entry point for address conversion: given the FIR memory reference of a
/// load/store, returns the MemRef-typed base plus access indices. Dispatches
/// on the reference's defining op (array_coor, marshal convert, declare,
/// coordinate_of, ref-to-ref convert, box_addr, device data) and fails for
/// anything else.
MemRefInfo FIRToMemRef::getMemRefInfo(Value firMemref,
                                      PatternRewriter &rewriter,
                                      FIRToMemRefTypeConverter &typeConverter,
                                      Operation *memOp) {
  Operation *memrefOp = firMemref.getDefiningOp();
  if (!memrefOp) {
    // Block arguments have no defining op: convert the argument directly,
    // unwrapping one level of nested memref type if the converter produced
    // one.
    if (auto blockArg = dyn_cast<BlockArgument>(firMemref)) {
      rewriter.setInsertionPoint(memOp);
      Type memrefTy = typeConverter.convertMemrefType(blockArg.getType());
      if (auto mt = dyn_cast<MemRefType>(memrefTy))
        if (auto inner = llvm::dyn_cast<MemRefType>(mt.getElementType()))
          memrefTy = inner;
      Value converted = fir::ConvertOp::create(rewriter, blockArg.getLoc(),
                                               memrefTy, blockArg);
      SmallVector<Value> indices;
      return std::pair{converted, indices};
    }
    llvm_unreachable(
        "FIRToMemRef: expected defining op or block argument for FIR memref");
  }

  if (auto arrayCoorOp = dyn_cast<fir::ArrayCoorOp>(memrefOp)) {
    MemRefInfo memrefInfo =
        convertArrayCoorOp(memOp, arrayCoorOp, rewriter, typeConverter);
    if (succeeded(memrefInfo)) {
      // The array_coor can only be erased if every user is a load/store that
      // will be rewritten; otherwise it must be kept alive.
      for (auto user : memrefOp->getUsers()) {
        if (!isa<fir::LoadOp, fir::StoreOp>(user)) {
          LLVM_DEBUG(
              llvm::dbgs()
                  << "FIRToMemRef: array memref used by unsupported op:\n";
              firMemref.dump(); user->dump());
          return memrefInfo;
        }
      }
      eraseOps.insert(memrefOp);
    }
    return memrefInfo;
  }

  rewriter.setInsertionPoint(memOp);

  // Remaining cases are all scalar-like accesses: convert the base and use
  // an empty index list.
  if (isMarshalLike(memrefOp)) {
    FailureOr<Value> converted =
        getFIRConvert(memOp, memrefOp, rewriter, typeConverter);
    if (failed(converted)) {
      LLVM_DEBUG(llvm::dbgs()
                     << "FIRToMemRef: expected FIR memref in convert, bailing "
                        "out:\n";
                 firMemref.dump());
      return failure();
    }
    SmallVector<Value> indices;
    return std::pair{*converted, indices};
  }

  if (auto declareOp = dyn_cast<fir::DeclareOp>(memrefOp)) {
    FailureOr<Value> converted =
        getFIRConvert(memOp, declareOp, rewriter, typeConverter);
    if (failed(converted)) {
      LLVM_DEBUG(llvm::dbgs()
                     << "FIRToMemRef: unable to create convert for scalar "
                        "memref:\n";
                 firMemref.dump());
      return failure();
    }
    SmallVector<Value> indices;
    return std::pair{*converted, indices};
  }

  if (auto coordinateOp = dyn_cast<fir::CoordinateOp>(memrefOp)) {
    FailureOr<Value> converted =
        getFIRConvert(memOp, coordinateOp, rewriter, typeConverter);
    if (failed(converted)) {
      LLVM_DEBUG(
          llvm::dbgs()
              << "FIRToMemRef: unable to create convert for derived-type "
                 "memref:\n";
          firMemref.dump());
      return failure();
    }
    SmallVector<Value> indices;
    return std::pair{*converted, indices};
  }

  // Only ref-to-ref converts are handled here; other converts fall through
  // to the failure path below.
  if (auto convertOp = dyn_cast<fir::ConvertOp>(memrefOp)) {
    Type fromTy = convertOp->getOperand(0).getType();
    Type toTy = firMemref.getType();
    if (isa<fir::ReferenceType>(fromTy) && isa<fir::ReferenceType>(toTy)) {
      FailureOr<Value> converted =
          getFIRConvert(memOp, convertOp, rewriter, typeConverter);
      if (failed(converted)) {
        LLVM_DEBUG(
            llvm::dbgs()
                << "FIRToMemRef: unable to create convert for conversion "
                   "op:\n";
            firMemref.dump());
        return failure();
      }
      SmallVector<Value> indices;
      return std::pair{*converted, indices};
    }
  }

  if (auto boxAddrOp = dyn_cast<fir::BoxAddrOp>(memrefOp)) {
    FailureOr<Value> converted =
        getFIRConvert(memOp, boxAddrOp, rewriter, typeConverter);
    if (failed(converted)) {
      LLVM_DEBUG(llvm::dbgs()
                     << "FIRToMemRef: unable to create convert for box_addr "
                        "op:\n";
                 firMemref.dump());
      return failure();
    }
    SmallVector<Value> indices;
    return std::pair{*converted, indices};
  }

  if (memrefIsDeviceData(memrefOp)) {
    FailureOr<Value> converted =
        getFIRConvert(memOp, memrefOp, rewriter, typeConverter);
    if (failed(converted))
      return failure();
    SmallVector<Value> indices;
    return std::pair{*converted, indices};
  }

  LLVM_DEBUG(llvm::dbgs()
                 << "FIRToMemRef: unable to create convert for memref value:\n";
             firMemref.dump());

  return failure();
}
+
/// After a load/store has been rewritten, redirects the remaining FIR-typed
/// users of \p firMemref through fresh `fir.convert`s of \p converted, so the
/// memref value becomes the single source of truth. Users inside
/// omp/acc.atomic.capture regions are handled in a second phase with the
/// convert hoisted before the capture op (one convert per enclosing region),
/// since new ops cannot be inserted inside those regions.
void FIRToMemRef::replaceFIRMemrefs(Value firMemref, Value converted,
                                    PatternRewriter &rewriter) const {
  // array_coor bases and marshal converts are managed elsewhere.
  Operation *op = firMemref.getDefiningOp();
  if (op && (isa<fir::ArrayCoorOp>(op) || isMarshalLike(op)))
    return;

  // Phase 1: ordinary users dominated by the converted value.
  SmallPtrSet<Operation *, 4> worklist;
  for (auto user : firMemref.getUsers()) {
    if (isMarshalLike(user) || isa<fir::LoadOp, fir::StoreOp>(user))
      continue;
    if (!domInfo->dominates(converted, user))
      continue;
    if (!(isa<omp::AtomicCaptureOp>(user->getParentOp()) ||
          isa<acc::AtomicCaptureOp>(user->getParentOp())))
      worklist.insert(user);
  }

  Type ty = firMemref.getType();

  for (auto op : worklist) {
    rewriter.setInsertionPoint(op);
    Location loc = op->getLoc();
    Value replaceConvert = fir::ConvertOp::create(rewriter, loc, ty, converted);
    op->replaceUsesOfWith(firMemref, replaceConvert);
  }

  worklist.clear();

  // Phase 2: users nested in atomic capture regions.
  for (auto user : firMemref.getUsers()) {
    if (isMarshalLike(user) || isa<fir::LoadOp, fir::StoreOp>(user))
      continue;
    if (isa<omp::AtomicCaptureOp>(user->getParentOp()) ||
        isa<acc::AtomicCaptureOp>(user->getParentOp()))
      if (domInfo->dominates(converted, user))
        worklist.insert(user);
  }

  if (worklist.empty())
    return;

  // Group the remaining users by their parent capture op and share one
  // convert inserted just before it.
  while (!worklist.empty()) {
    Operation *parentOp = (*worklist.begin())->getParentOp();

    Value replaceConvert;
    SmallVector<Operation *> erase;
    for (auto op : worklist) {
      if (op->getParentOp() != parentOp)
        continue;
      if (!replaceConvert) {
        rewriter.setInsertionPoint(parentOp);
        replaceConvert =
            fir::ConvertOp::create(rewriter, op->getLoc(), ty, converted);
      }
      op->replaceUsesOfWith(firMemref, replaceConvert);
      erase.push_back(op);
    }

    // Remove processed users outside the iteration to keep the set stable.
    for (auto op : erase)
      worklist.erase(op);
  }
}
+
/// Rewrites \p load into a `memref.load`, preserving any `tbaa` attribute.
/// Loads of `fir.logical` values are converted back to the FIR logical type
/// so existing users keep their expected type.
void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter,
                                FIRToMemRefTypeConverter &typeConverter) {
  Value firMemref = load.getMemref();
  if (!typeConverter.convertibleType(firMemref.getType()))
    return;

  LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: attempting to convert FIR load:\n";
             load.dump(); firMemref.dump());

  MemRefInfo memrefInfo =
      getMemRefInfo(firMemref, rewriter, typeConverter, load.getOperation());
  if (failed(memrefInfo))
    return;

  Type originalType = load.getResult().getType();
  Value converted = memrefInfo->first;
  SmallVector<Value> indices = memrefInfo->second;

  LLVM_DEBUG(llvm::dbgs()
                 << "FIRToMemRef: convert for FIR load created successfully:\n";
             converted.dump());

  rewriter.setInsertionPointAfter(load);

  // Preserve alias information across the dialect switch.
  Attribute attr = (load.getOperation())->getAttr("tbaa");
  memref::LoadOp loadOp =
      rewriter.replaceOpWithNewOp<memref::LoadOp>(load, converted, indices);
  if (attr)
    loadOp.getOperation()->setAttr("tbaa", attr);

  LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: new memref.load op:\n";
             loadOp.dump(); assert(succeeded(verify(loadOp))));

  // Logical loads yield the converted integer; re-wrap to fir.logical so the
  // original users see the type they expect.
  if (isa<fir::LogicalType>(originalType)) {
    Value logicalVal =
        fir::ConvertOp::create(rewriter, loadOp.getLoc(), originalType, loadOp);
    loadOp.getResult().replaceAllUsesExcept(logicalVal,
                                            logicalVal.getDefiningOp());
  }

  // NOTE(review): FIR-use replacement is skipped for logical loads —
  // presumably because the memref element type differs from fir.logical;
  // confirm.
  if (!isa<fir::LogicalType>(originalType))
    replaceFIRMemrefs(firMemref, converted, rewriter);
}
+
/// Rewrites \p store into a `memref.store`, preserving any `tbaa` attribute.
/// Stored `fir.logical` values are first converted to the memref element
/// type.
void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter,
                                 FIRToMemRefTypeConverter &typeConverter) {
  Value firMemref = store.getMemref();

  if (!typeConverter.convertibleType(firMemref.getType()))
    return;

  LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: attempting to convert FIR store:\n";
             store.dump(); firMemref.dump());

  MemRefInfo memrefInfo =
      getMemRefInfo(firMemref, rewriter, typeConverter, store.getOperation());
  if (failed(memrefInfo))
    return;

  Value converted = memrefInfo->first;
  SmallVector<Value> indices = memrefInfo->second;
  LLVM_DEBUG(
      llvm::dbgs()
          << "FIRToMemRef: convert for FIR store created successfully:\n";
      converted.dump());

  Value value = store.getValue();
  rewriter.setInsertionPointAfter(store);

  // memref.store cannot take a fir.logical directly; convert the stored
  // value to the converted (integer) element type first.
  if (isa<fir::LogicalType>(value.getType())) {
    Type convertedType = typeConverter.convertType(value.getType());
    value =
        fir::ConvertOp::create(rewriter, store.getLoc(), convertedType, value);
  }

  // Preserve alias information across the dialect switch.
  Attribute attr = (store.getOperation())->getAttr("tbaa");
  memref::StoreOp storeOp = rewriter.replaceOpWithNewOp<memref::StoreOp>(
      store, value, converted, indices);
  if (attr)
    storeOp.getOperation()->setAttr("tbaa", attr);

  LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: new memref.store op:\n";
             storeOp.dump(); assert(succeeded(verify(storeOp))));

  // NOTE(review): FIR-use replacement is skipped for logical references —
  // mirrors the load path; confirm the type-mismatch rationale.
  bool isLogicalRef = false;
  if (fir::ReferenceType refTy =
          llvm::dyn_cast<fir::ReferenceType>(firMemref.getType()))
    isLogicalRef = llvm::isa<fir::LogicalType>(refTy.getEleTy());
  if (!isLogicalRef)
    replaceFIRMemrefs(firMemref, converted, rewriter);
}
+
+void FIRToMemRef::runOnOperation() {
+ LLVM_DEBUG(llvm::dbgs() << "Enter FIRToMemRef()\n");
+
+ func::FuncOp op = getOperation();
+ MLIRContext *context = op.getContext();
+ ModuleOp mod = op->getParentOfType<ModuleOp>();
+ FIRToMemRefTypeConverter typeConverter(mod);
+
+ typeConverter.setConvertComplexTypes(true);
+
+ PatternRewriter rewriter(context);
+ domInfo = new DominanceInfo(op);
+
+ op.walk([&](fir::AllocaOp alloca) {
+ rewriteAlloca(alloca, rewriter, typeConverter);
+ });
+
+ op.walk([&](Operation *op) {
+ if (fir::LoadOp loadOp = dyn_cast<fir::LoadOp>(op))
+ rewriteLoadOp(loadOp, rewriter, typeConverter);
+ else if (fir::StoreOp storeOp = dyn_cast<fir::StoreOp>(op))
+ rewriteStoreOp(storeOp, rewriter, typeConverter);
+ });
+
+ for (auto eraseOp : eraseOps)
+ rewriter.eraseOp(eraseOp);
+ eraseOps.clear();
+
+ if (domInfo)
+ delete domInfo;
+
+ LLVM_DEBUG(llvm::dbgs() << "After FIRToMemRef()\n"; op.dump();
+ llvm::dbgs() << "Exit FIRToMemRef()\n";);
+}
+
+} // namespace fir
diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
index 70d6ebb..04ba053 100644
--- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
@@ -18,6 +18,8 @@ namespace fir {
namespace {
class FIRToSCFPass : public fir::impl::FIRToSCFPassBase<FIRToSCFPass> {
+ using FIRToSCFPassBase::FIRToSCFPassBase;
+
public:
void runOnOperation() override;
};
@@ -25,11 +27,18 @@ public:
struct DoLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
using OpRewritePattern<fir::DoLoopOp>::OpRewritePattern;
+ DoLoopConversion(mlir::MLIRContext *context,
+ bool parallelUnorderedLoop = false,
+ mlir::PatternBenefit benefit = 1)
+ : OpRewritePattern<fir::DoLoopOp>(context, benefit),
+ parallelUnorderedLoop(parallelUnorderedLoop) {}
+
mlir::LogicalResult
matchAndRewrite(fir::DoLoopOp doLoopOp,
mlir::PatternRewriter &rewriter) const override {
mlir::Location loc = doLoopOp.getLoc();
bool hasFinalValue = doLoopOp.getFinalValue().has_value();
+ bool isUnordered = doLoopOp.getUnordered().has_value();
// Get loop values from the DoLoopOp
mlir::Value low = doLoopOp.getLowerBound();
@@ -53,39 +62,78 @@ struct DoLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
mlir::arith::DivSIOp::create(rewriter, loc, distance, step);
auto zero = mlir::arith::ConstantIndexOp::create(rewriter, loc, 0);
auto one = mlir::arith::ConstantIndexOp::create(rewriter, loc, 1);
- auto scfForOp =
- mlir::scf::ForOp::create(rewriter, loc, zero, tripCount, one, iterArgs);
+ // Create the scf.for or scf.parallel operation
+ mlir::Operation *scfLoopOp = nullptr;
+ if (isUnordered && parallelUnorderedLoop) {
+ scfLoopOp = mlir::scf::ParallelOp::create(rewriter, loc, {zero},
+ {tripCount}, {one}, iterArgs);
+ } else {
+ scfLoopOp = mlir::scf::ForOp::create(rewriter, loc, zero, tripCount, one,
+ iterArgs);
+ }
+
+ // Move the body of the fir.do_loop to the scf.for or scf.parallel
auto &loopOps = doLoopOp.getBody()->getOperations();
auto resultOp =
mlir::cast<fir::ResultOp>(doLoopOp.getBody()->getTerminator());
auto results = resultOp.getOperands();
- mlir::Block *loweredBody = scfForOp.getBody();
+ auto scfLoopLikeOp = mlir::cast<mlir::LoopLikeOpInterface>(scfLoopOp);
+ mlir::Block &scfLoopBody = scfLoopLikeOp.getLoopRegions().front()->front();
- loweredBody->getOperations().splice(loweredBody->begin(), loopOps,
- loopOps.begin(),
- std::prev(loopOps.end()));
+ scfLoopBody.getOperations().splice(scfLoopBody.begin(), loopOps,
+ loopOps.begin(),
+ std::prev(loopOps.end()));
- rewriter.setInsertionPointToStart(loweredBody);
+ rewriter.setInsertionPointToStart(&scfLoopBody);
mlir::Value iv = mlir::arith::MulIOp::create(
- rewriter, loc, scfForOp.getInductionVar(), step);
+ rewriter, loc, scfLoopLikeOp.getSingleInductionVar().value(), step);
iv = mlir::arith::AddIOp::create(rewriter, loc, low, iv);
+ mlir::Value firIV = doLoopOp.getInductionVar();
+ firIV.replaceAllUsesWith(iv);
+
+ mlir::Value finalValue;
+ if (hasFinalValue) {
+ // Prefer re-using an existing `arith.addi` in the moved loop body if it
+ // already computes the next `iv + step`.
+ if (!results.empty()) {
+ if (auto addOp = results.front().getDefiningOp<mlir::arith::AddIOp>()) {
+ mlir::Value lhs = addOp.getLhs();
+ mlir::Value rhs = addOp.getRhs();
+ if ((lhs == iv && rhs == step) || (lhs == step && rhs == iv))
+ finalValue = results.front();
+ }
+ }
+ if (!finalValue)
+ finalValue = mlir::arith::AddIOp::create(rewriter, loc, iv, step);
+ }
- if (!results.empty()) {
- rewriter.setInsertionPointToEnd(loweredBody);
- mlir::scf::YieldOp::create(rewriter, resultOp->getLoc(), results);
+ if (hasFinalValue || !results.empty()) {
+ rewriter.setInsertionPointToEnd(&scfLoopBody);
+ llvm::SmallVector<mlir::Value> yieldOperands;
+ if (hasFinalValue) {
+ yieldOperands.push_back(finalValue);
+ llvm::append_range(yieldOperands, results.drop_front());
+ } else {
+ llvm::append_range(yieldOperands, results);
+ }
+ mlir::scf::YieldOp::create(rewriter, resultOp->getLoc(), yieldOperands);
}
- doLoopOp.getInductionVar().replaceAllUsesWith(iv);
- rewriter.replaceAllUsesWith(doLoopOp.getRegionIterArgs(),
- hasFinalValue
- ? scfForOp.getRegionIterArgs().drop_front()
- : scfForOp.getRegionIterArgs());
-
- // Copy all the attributes from the old to new op.
- scfForOp->setAttrs(doLoopOp->getAttrs());
- rewriter.replaceOp(doLoopOp, scfForOp);
+ rewriter.replaceAllUsesWith(
+ doLoopOp.getRegionIterArgs(),
+ hasFinalValue ? scfLoopLikeOp.getRegionIterArgs().drop_front()
+ : scfLoopLikeOp.getRegionIterArgs());
+
+ // Copy loop annotations from the fir.do_loop to scf loop op.
+ if (auto ann = doLoopOp.getLoopAnnotation())
+ scfLoopOp->setAttr("loop_annotation", *ann);
+
+ rewriter.replaceOp(doLoopOp, scfLoopOp);
return mlir::success();
}
+
+private:
+ bool parallelUnorderedLoop;
};
struct IterWhileConversion : public mlir::OpRewritePattern<fir::IterWhileOp> {
@@ -102,6 +150,7 @@ struct IterWhileConversion : public mlir::OpRewritePattern<fir::IterWhileOp> {
mlir::Value okInit = iterWhileOp.getIterateIn();
mlir::ValueRange iterArgs = iterWhileOp.getInitArgs();
+ bool hasFinalValue = iterWhileOp.getFinalValue().has_value();
mlir::SmallVector<mlir::Value> initVals;
initVals.push_back(lowerBound);
@@ -128,10 +177,23 @@ struct IterWhileConversion : public mlir::OpRewritePattern<fir::IterWhileOp> {
rewriter.setInsertionPointToStart(&beforeBlock);
- mlir::Value inductionCmp = mlir::arith::CmpIOp::create(
+ // The comparison depends on the sign of the step value. We fully expect
+ // this expression to be folded by the optimizer or LLVM. This expression
+ // is written this way so that `step == 0` always returns `false`.
+ auto zero = mlir::arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto compl0 = mlir::arith::CmpIOp::create(
+ rewriter, loc, mlir::arith::CmpIPredicate::slt, zero, step);
+ auto compl1 = mlir::arith::CmpIOp::create(
rewriter, loc, mlir::arith::CmpIPredicate::sle, ivInBefore, upperBound);
- mlir::Value cond = mlir::arith::AndIOp::create(rewriter, loc, inductionCmp,
- earlyExitInBefore);
+ auto compl2 = mlir::arith::CmpIOp::create(
+ rewriter, loc, mlir::arith::CmpIPredicate::slt, step, zero);
+ auto compl3 = mlir::arith::CmpIOp::create(
+ rewriter, loc, mlir::arith::CmpIPredicate::sge, ivInBefore, upperBound);
+ auto cmp0 = mlir::arith::AndIOp::create(rewriter, loc, compl0, compl1);
+ auto cmp1 = mlir::arith::AndIOp::create(rewriter, loc, compl2, compl3);
+ auto cmp2 = mlir::arith::OrIOp::create(rewriter, loc, cmp0, cmp1);
+ mlir::Value cond =
+ mlir::arith::AndIOp::create(rewriter, loc, earlyExitInBefore, cmp2);
mlir::scf::ConditionOp::create(rewriter, loc, cond, argsInBefore);
@@ -140,17 +202,22 @@ struct IterWhileConversion : public mlir::OpRewritePattern<fir::IterWhileOp> {
auto *afterBody = scfWhileOp.getAfterBody();
auto resultOp = mlir::cast<fir::ResultOp>(afterBody->getTerminator());
- mlir::SmallVector<mlir::Value> results(resultOp->getOperands());
- mlir::Value ivInAfter = scfWhileOp.getAfterArguments()[0];
+ mlir::SmallVector<mlir::Value> results;
+ mlir::Value iv = scfWhileOp.getAfterArguments()[0];
rewriter.setInsertionPointToStart(afterBody);
- results[0] = mlir::arith::AddIOp::create(rewriter, loc, ivInAfter, step);
+ results.push_back(mlir::arith::AddIOp::create(rewriter, loc, iv, step));
+ llvm::append_range(results, hasFinalValue
+ ? resultOp->getOperands().drop_front()
+ : resultOp->getOperands());
rewriter.setInsertionPointToEnd(afterBody);
rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(resultOp, results);
scfWhileOp->setAttrs(iterWhileOp->getAttrs());
- rewriter.replaceOp(iterWhileOp, scfWhileOp);
+ rewriter.replaceOp(iterWhileOp,
+ hasFinalValue ? scfWhileOp->getResults()
+ : scfWhileOp->getResults().drop_front());
return mlir::success();
}
};
@@ -197,13 +264,14 @@ struct IfConversion : public mlir::OpRewritePattern<fir::IfOp> {
};
} // namespace
+void fir::populateFIRToSCFRewrites(mlir::RewritePatternSet &patterns,
+ bool parallelUnordered) {
+ patterns.add<IterWhileConversion, IfConversion>(patterns.getContext());
+ patterns.add<DoLoopConversion>(patterns.getContext(), parallelUnordered);
+}
+
void FIRToSCFPass::runOnOperation() {
mlir::RewritePatternSet patterns(&getContext());
- patterns.add<DoLoopConversion, IterWhileConversion, IfConversion>(
- patterns.getContext());
+ fir::populateFIRToSCFRewrites(patterns, parallelUnordered);
walkAndApplyPatterns(getOperation(), std::move(patterns));
}
-
-std::unique_ptr<mlir::Pass> fir::createFIRToSCFPass() {
- return std::make_unique<FIRToSCFPass>();
-}
diff --git a/flang/lib/Optimizer/Transforms/FunctionAttr.cpp b/flang/lib/Optimizer/Transforms/FunctionAttr.cpp
index 9dfe26cb..3879a80 100644
--- a/flang/lib/Optimizer/Transforms/FunctionAttr.cpp
+++ b/flang/lib/Optimizer/Transforms/FunctionAttr.cpp
@@ -87,10 +87,6 @@ void FunctionAttrPass::runOnOperation() {
func->setAttr(mlir::LLVM::LLVMFuncOp::getInstrumentFunctionExitAttrName(
llvmFuncOpName),
mlir::StringAttr::get(context, instrumentFunctionExit));
- if (noInfsFPMath)
- func->setAttr(
- mlir::LLVM::LLVMFuncOp::getNoInfsFpMathAttrName(llvmFuncOpName),
- mlir::BoolAttr::get(context, true));
if (noNaNsFPMath)
func->setAttr(
mlir::LLVM::LLVMFuncOp::getNoNansFpMathAttrName(llvmFuncOpName),
@@ -99,10 +95,6 @@ void FunctionAttrPass::runOnOperation() {
func->setAttr(
mlir::LLVM::LLVMFuncOp::getNoSignedZerosFpMathAttrName(llvmFuncOpName),
mlir::BoolAttr::get(context, true));
- if (unsafeFPMath)
- func->setAttr(
- mlir::LLVM::LLVMFuncOp::getUnsafeFpMathAttrName(llvmFuncOpName),
- mlir::BoolAttr::get(context, true));
if (!reciprocals.empty())
func->setAttr(
mlir::LLVM::LLVMFuncOp::getReciprocalEstimatesAttrName(llvmFuncOpName),
diff --git a/flang/lib/Optimizer/Transforms/LoopInvariantCodeMotion.cpp b/flang/lib/Optimizer/Transforms/LoopInvariantCodeMotion.cpp
new file mode 100644
index 0000000..8ebb898
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/LoopInvariantCodeMotion.cpp
@@ -0,0 +1,323 @@
+//===- LoopInvariantCodeMotion.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// FIR-specific Loop Invariant Code Motion pass.
+/// The pass relies on FIR types and interfaces to prove the safety
+/// of hoisting invariant operations out of loop-like operations.
+/// It may be run on both HLFIR and FIR representations.
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Analysis/AliasAnalysis.h"
+#include "flang/Optimizer/Dialect/FIROperationMoveOpInterface.h"
+#include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "flang/Optimizer/Dialect/FortranVariableInterface.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/DebugLog.h"
+
+namespace fir {
+#define GEN_PASS_DEF_LOOPINVARIANTCODEMOTION
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+#define DEBUG_TYPE "flang-licm"
+
+// Temporary engineering option for triaging LICM.
+static llvm::cl::opt<bool> disableFlangLICM(
+ "disable-flang-licm", llvm::cl::init(false), llvm::cl::Hidden,
+ llvm::cl::desc("Disable Flang's loop invariant code motion"));
+
+namespace {
+
+using namespace mlir;
+
+/// The pass tries to hoist loop invariant operations with only
+/// MemoryEffects::Read effects (MemoryEffects::Write support
+/// may be added later).
+/// The safety of hoisting is proven by either of:
+/// * Proving that the loop runs at least one iteration.
+/// * Proving that it is always safe to load from this location
+///   (see isSafeToHoistLoad() comments below).
+struct LoopInvariantCodeMotion
+    : fir::impl::LoopInvariantCodeMotionBase<LoopInvariantCodeMotion> {
+  void runOnOperation() override;
+};
+
+} // namespace
+
+/// 'location' is a memory reference used by a memory access.
+/// The type of 'location' defines the data type of the access
+/// (e.g. it is considered to be invalid to access 'i64'
+/// data using '!fir.ref<i32>`).
+/// For the given location, this function returns true iff
+/// the Fortran object being accessed is a scalar that
+/// may not be OPTIONAL.
+///
+/// Note that the '!fir.ref<!fir.box<>>' accesses are considered
+/// to be scalar, even if the underlying data is an array.
+///
+/// Note that an access of '!fir.ref<scalar>' may access
+/// an array object. For example:
+/// real :: x(:)
+/// do i=...
+/// = x(10)
+/// 'x(10)' accesses array 'x', and it may be unsafe to hoist
+/// it without proving that '10' is a valid index for the array.
+/// The fact that 'x' is not OPTIONAL does not allow hoisting
+/// on its own.
+static bool isNonOptionalScalar(Value location) {
+  while (true) {
+    LDBG() << "Checking location:\n" << location;
+    Type dataType = fir::unwrapRefType(location.getType());
+    // The access must be scalar-typed: a descriptor (box), a trivial
+    // scalar, or a derived-type scalar.
+    if (!isa<fir::BaseBoxType>(location.getType()) &&
+        (!dataType ||
+         (!isa<fir::BaseBoxType>(dataType) && !fir::isa_trivial(dataType) &&
+          !fir::isa_derived(dataType)))) {
+      LDBG() << "Failure: data access is not scalar";
+      return false;
+    }
+    Operation *defOp = location.getDefiningOp();
+    if (!defOp) {
+      // A value with no defining operation is a block argument. Only
+      // entry-block (dummy) arguments without the OPTIONAL attribute
+      // are known to be present.
+      auto blockArg = cast<BlockArgument>(location);
+      Block *block = blockArg.getOwner();
+      if (block && block->isEntryBlock())
+        if (auto funcOp =
+                dyn_cast_if_present<FunctionOpInterface>(block->getParentOp()))
+          if (!funcOp.getArgAttrOfType<UnitAttr>(blockArg.getArgNumber(),
+                                                 fir::getOptionalAttrName())) {
+            LDBG() << "Success: is non optional scalar dummy";
+            return true;
+          }
+
+      LDBG() << "Failure: no defining operation";
+      return false;
+    }
+
+    // Storage created by fir.alloca or addressed by fir.address_of always
+    // exists, so the scalar cannot be an absent OPTIONAL.
+    if (isa<fir::AllocaOp, fir::AddrOfOp>(defOp)) {
+      LDBG() << "Success: is non optional scalar";
+      return true;
+    }
+
+    if (auto varIface = dyn_cast<fir::FortranVariableOpInterface>(defOp)) {
+      if (varIface.isOptional()) {
+        // The variable is optional, so do not look further.
+        // Note that it is possible to deduce that the optional
+        // is actually present, but we are not doing it now.
+        LDBG() << "Failure: is optional";
+        return false;
+      }
+
+      // In case of MLIR inlining and ASSOCIATE an [hl]fir.declare
+      // may declare a scalar variable that is actually a "view"
+      // of an array element. Originally, such [hl]fir.declare
+      // would be located inside the loop preventing the hoisting.
+      // But if we decide to hoist such [hl]fir.declare in future,
+      // we cannot rely on their attributes/types.
+      // Use reliable checks based on the variable storage.
+
+      // If the variable has storage specifier (e.g. it is a member
+      // of COMMON, etc.), we can rely that the storage is present,
+      // and we can also rely on its FortranVariableOpInterface
+      // definition type (which is a scalar due to previous checks).
+      if (auto storageIface =
+              dyn_cast<fir::FortranVariableStorageOpInterface>(defOp))
+        if (Value storage = storageIface.getStorage()) {
+          LDBG() << "Success: is scalar with existing storage";
+          return true;
+        }
+
+      // TODO: we can probably use FIR AliasAnalysis' getSource()
+      // method to identify the storage in more cases.
+      Value memref = llvm::TypeSwitch<Operation *, Value>(defOp)
+                         .Case<fir::DeclareOp, hlfir::DeclareOp>(
+                             [](auto op) { return op.getMemref(); })
+                         .Default([](auto) { return nullptr; });
+
+      // Recurse on the declared storage.
+      if (memref)
+        return isNonOptionalScalar(memref);
+
+      LDBG() << "Failure: cannot reason about variable storage";
+      return false;
+    }
+    // Step through view-like operations to their underlying source value
+    // and keep walking.
+    if (auto viewIface = dyn_cast<fir::FortranObjectViewOpInterface>(defOp)) {
+      location = viewIface.getViewSource(cast<OpResult>(location));
+    } else {
+      LDBG() << "Failure: unknown operation:\n" << *defOp;
+      return false;
+    }
+  }
+}
+
+/// Returns true iff it is safe to hoist the given load-like operation 'op',
+/// which accesses the given memory 'locations', out of the operation
+/// 'loopLike'. The current safety conditions are:
+/// * The loop runs at least one iteration, OR
+/// * all the accessed locations are inside scalar non-OPTIONAL
+///   Fortran objects (Fortran descriptors are considered to be scalars).
+/// Note that 'op' itself is currently unused here: the decision is made
+/// purely from 'locations' and 'loopLike'.
+static bool isSafeToHoistLoad(Operation *op, ArrayRef<Value> locations,
+                              LoopLikeOpInterface loopLike,
+                              AliasAnalysis &aliasAnalysis) {
+  // The loaded locations must not be modified inside the loop.
+  // NOTE(review): 'op' lives inside the loop, so every location is already
+  // referenced by the loop; isModAndRef() thus effectively tests for a
+  // modification — confirm this is the intended reading.
+  for (Value location : locations)
+    if (aliasAnalysis.getModRef(loopLike.getOperation(), location)
+            .isModAndRef()) {
+      LDBG() << "Failure: reads location:\n"
+             << location << "\nwhich is modified inside the loop";
+      return false;
+    }
+
+  // Check that it is safe to read from all the locations before the loop.
+  std::optional<llvm::APInt> tripCount = loopLike.getStaticTripCount();
+  if (tripCount && !tripCount->isZero()) {
+    // Loop executes at least one iteration, so it is safe to hoist.
+    LDBG() << "Success: loop has non-zero iterations";
+    return true;
+  }
+
+  // Check whether the access must always be valid.
+  return llvm::all_of(
+      locations, [&](Value location) { return isNonOptionalScalar(location); });
+  // TODO: consider hoisting under condition of the loop's trip count
+  // being non-zero.
+}
+
+/// Returns true iff 'op' is a load-like operation whose hoisting out of
+/// 'loopLike' can be proven safe (see isSafeToHoistLoad()).
+static bool canHoistLoad(Operation *op, LoopLikeOpInterface loopLike,
+                         AliasAnalysis &aliasAnalysis) {
+  LDBG() << "Checking operation:\n" << *op;
+  auto memIface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!memIface) {
+    LDBG() << "Failure: has unknown effects";
+    return false;
+  }
+  SmallVector<MemoryEffects::EffectInstance> instances;
+  memIface.getEffects(instances);
+  if (instances.empty()) {
+    LDBG() << "Failure: not a load";
+    return false;
+  }
+  // Collect the distinct locations read by 'op'. Bail out on any effect
+  // other than a read from a known location.
+  llvm::SetVector<Value> readLocations;
+  for (const MemoryEffects::EffectInstance &instance : instances) {
+    if (!isa<MemoryEffects::Read>(instance.getEffect())) {
+      LDBG() << "Failure: has unsupported effects";
+      return false;
+    }
+    Value loc = instance.getValue();
+    if (!loc) {
+      LDBG() << "Failure: reads from unknown location";
+      return false;
+    }
+    readLocations.insert(loc);
+  }
+  return isSafeToHoistLoad(op, readLocations.getArrayRef(), loopLike,
+                           aliasAnalysis);
+}
+
+void LoopInvariantCodeMotion::runOnOperation() {
+  if (disableFlangLICM) {
+    LDBG() << "Skipping [HL]FIR LoopInvariantCodeMotion()";
+    return;
+  }
+
+  LDBG() << "Enter [HL]FIR LoopInvariantCodeMotion()";
+
+  auto &aliasAnalysis = getAnalysis<AliasAnalysis>();
+  aliasAnalysis.addAnalysisImplementation(fir::AliasAnalysis{});
+
+  // Declared as std::function rather than auto so the lambda can call
+  // itself recursively when checking nested operations.
+  std::function<bool(Operation *, LoopLikeOpInterface loopLike)>
+      shouldMoveOutOfLoop = [&](Operation *op, LoopLikeOpInterface loopLike) {
+        if (isPure(op)) {
+          LDBG() << "Pure operation: " << *op;
+          return true;
+        }
+
+        // Handle RecursivelySpeculatable operations that have
+        // RecursiveMemoryEffects by checking if all their
+        // nested operations can be hoisted.
+        auto iface = dyn_cast<ConditionallySpeculatable>(op);
+        if (iface && iface.getSpeculatability() ==
+                         Speculation::RecursivelySpeculatable) {
+          if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
+            LDBG() << "Checking recursive operation:\n" << *op;
+            llvm::SmallVector<Operation *> nestedOps;
+            for (Region &region : op->getRegions())
+              for (Block &block : region)
+                for (Operation &nestedOp : block)
+                  nestedOps.push_back(&nestedOp);
+
+            bool result = llvm::all_of(nestedOps, [&](Operation *nestedOp) {
+              return shouldMoveOutOfLoop(nestedOp, loopLike);
+            });
+            LDBG() << "Recursive operation can" << (result ? "" : "not")
+                   << " be hoisted";
+
+            // If nested operations cannot be hoisted, there is nothing
+            // else to check. Also if the operation itself does not have
+            // any memory effects, we can return the result now.
+            // Otherwise, we have to check the operation itself below.
+            if (!result || !isa<MemoryEffectOpInterface>(op))
+              return result;
+          }
+        }
+        return canHoistLoad(op, loopLike, aliasAnalysis);
+      };
+
+  getOperation()->walk([&](LoopLikeOpInterface loopLike) {
+    if (!fir::canMoveOutOf(loopLike, nullptr)) {
+      LDBG() << "Cannot hoist anything out of loop operation: ";
+      LDBG_OS([&](llvm::raw_ostream &os) {
+        loopLike->print(os, OpPrintingFlags().skipRegions());
+      });
+      return;
+    }
+    // We always hoist operations to the parent operation of the loopLike.
+    // Check that the parent operation allows the hoisting, e.g.
+    // omp::LoopWrapperInterface operations assume tight nesting
+    // of the inner maybe loop-like operations, so hoisting
+    // to such a parent would be invalid. We rely on
+    // fir::canMoveFromDescendant() to identify whether the hoisting
+    // is allowed.
+    Operation *parentOp = loopLike->getParentOp();
+    if (!parentOp) {
+      LDBG() << "Skipping top-level loop-like operation?";
+      return;
+    } else if (!fir::canMoveFromDescendant(parentOp, loopLike, nullptr)) {
+      LDBG() << "Cannot hoist anything into operation: ";
+      LDBG_OS([&](llvm::raw_ostream &os) {
+        parentOp->print(os, OpPrintingFlags().skipRegions());
+      });
+      return;
+    }
+    moveLoopInvariantCode(
+        loopLike.getLoopRegions(),
+        /*isDefinedOutsideRegion=*/
+        [&](Value value, Region *) {
+          return loopLike.isDefinedOutsideOfLoop(value);
+        },
+        /*shouldMoveOutOfRegion=*/
+        [&](Operation *op, Region *) {
+          if (!fir::canMoveOutOf(loopLike, op)) {
+            LDBG() << "Cannot hoist " << *op << " out of the loop";
+            return false;
+          }
+          if (!fir::canMoveFromDescendant(parentOp, loopLike, op)) {
+            LDBG() << "Cannot hoist " << *op << " into the parent of the loop";
+            return false;
+          }
+          return shouldMoveOutOfLoop(op, loopLike);
+        },
+        /*moveOutOfRegion=*/
+        [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
+  });
+
+  LDBG() << "Exit [HL]FIR LoopInvariantCodeMotion()";
+}
diff --git a/flang/lib/Optimizer/Transforms/MIFOpConversion.cpp b/flang/lib/Optimizer/Transforms/MIFOpConversion.cpp
index 206cb9b..fed941c0 100644
--- a/flang/lib/Optimizer/Transforms/MIFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/MIFOpConversion.cpp
@@ -16,6 +16,7 @@
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "flang/Optimizer/Support/InternalNames.h"
+#include "flang/Runtime/stop.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -67,6 +68,118 @@ genErrmsgPRIF(fir::FirOpBuilder &builder, mlir::Location loc,
return {errMsg, errMsgAlloc};
}
+static mlir::Value genStatPRIF(fir::FirOpBuilder &builder, mlir::Location loc,
+                               mlir::Value stat) {
+  // Pass the user-provided STAT through; materialize an absent argument
+  // when none was supplied.
+  return stat ? stat
+              : fir::AbsentOp::create(builder, loc, getPRIFStatType(builder));
+}
+
+static fir::CallOp genPRIFStopErrorStop(fir::FirOpBuilder &builder,
+ mlir::Location loc,
+ mlir::Value stopCode,
+ bool isError = false) {
+ mlir::Type stopCharTy = fir::BoxCharType::get(builder.getContext(), 1);
+ mlir::Type i1Ty = builder.getI1Type();
+ mlir::Type i32Ty = builder.getI32Type();
+
+ mlir::FunctionType ftype = mlir::FunctionType::get(
+ builder.getContext(),
+ /*inputs*/
+ {builder.getRefType(i1Ty), builder.getRefType(i32Ty), stopCharTy},
+ /*results*/ {});
+ mlir::func::FuncOp funcOp =
+ isError
+ ? builder.createFunction(loc, getPRIFProcName("error_stop"), ftype)
+ : builder.createFunction(loc, getPRIFProcName("stop"), ftype);
+
+ // QUIET is managed in flang-rt, so its value is set to TRUE here.
+ mlir::Value q = builder.createBool(loc, true);
+ mlir::Value quiet = builder.createTemporary(loc, i1Ty);
+ fir::StoreOp::create(builder, loc, q, quiet);
+
+ mlir::Value stopCodeInt, stopCodeChar;
+ if (!stopCode) {
+ stopCodeChar = fir::AbsentOp::create(builder, loc, stopCharTy);
+ stopCodeInt =
+ fir::AbsentOp::create(builder, loc, builder.getRefType(i32Ty));
+ } else if (fir::isa_integer(stopCode.getType())) {
+ stopCodeChar = fir::AbsentOp::create(builder, loc, stopCharTy);
+ stopCodeInt = builder.createTemporary(loc, i32Ty);
+ if (stopCode.getType() != i32Ty)
+ stopCode = fir::ConvertOp::create(builder, loc, i32Ty, stopCode);
+ fir::StoreOp::create(builder, loc, stopCode, stopCodeInt);
+ } else {
+ stopCodeChar = stopCode;
+ if (!mlir::isa<fir::BoxCharType>(stopCodeChar.getType())) {
+ auto len =
+ fir::UndefOp::create(builder, loc, builder.getCharacterLengthType());
+ stopCodeChar =
+ fir::EmboxCharOp::create(builder, loc, stopCharTy, stopCodeChar, len);
+ }
+ stopCodeInt =
+ fir::AbsentOp::create(builder, loc, builder.getRefType(i32Ty));
+ }
+
+ llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
+ builder, loc, ftype, quiet, stopCodeInt, stopCodeChar);
+ return fir::CallOp::create(builder, loc, funcOp, args);
+}
+
+enum class TerminationKind { Normal = 0, Error = 1, FailImage = 2 };
+// Generates a wrapper function for the different kinds of termination in PRIF.
+// This function will be used to register wrappers on PRIF runtime termination
+// functions into the Fortran runtime. Returns the address of the (possibly
+// already existing) wrapper.
+// 'static' for internal linkage, consistent with the other file-local
+// helpers (genStatPRIF, genPRIFStopErrorStop).
+static mlir::Value genTerminationOperationWrapper(fir::FirOpBuilder &builder,
+                                                  mlir::Location loc,
+                                                  mlir::ModuleOp module,
+                                                  TerminationKind termKind) {
+  std::string funcName;
+  mlir::FunctionType funcType =
+      mlir::FunctionType::get(builder.getContext(), {}, {});
+  mlir::Type i32Ty = builder.getI32Type();
+  if (termKind == TerminationKind::Normal) {
+    funcName = getPRIFProcName("stop");
+    funcType = mlir::FunctionType::get(builder.getContext(), {i32Ty}, {});
+  } else if (termKind == TerminationKind::Error) {
+    funcName = getPRIFProcName("error_stop");
+    funcType = mlir::FunctionType::get(builder.getContext(), {i32Ty}, {});
+  } else {
+    funcName = getPRIFProcName("fail_image");
+  }
+  funcName += "_termination_wrapper";
+  mlir::func::FuncOp funcWrapperOp =
+      module.lookupSymbol<mlir::func::FuncOp>(funcName);
+
+  if (!funcWrapperOp) {
+    funcWrapperOp = builder.createFunction(loc, funcName, funcType);
+
+    // Generate the body of the wrapper function.
+    mlir::OpBuilder::InsertPoint saveInsertPoint = builder.saveInsertionPoint();
+    builder.setInsertionPointToStart(funcWrapperOp.addEntryBlock());
+
+    if (termKind == TerminationKind::Normal) {
+      genPRIFStopErrorStop(builder, loc, funcWrapperOp.getArgument(0),
+                           /*isError=*/false);
+    } else if (termKind == TerminationKind::Error) {
+      genPRIFStopErrorStop(builder, loc, funcWrapperOp.getArgument(0),
+                           /*isError=*/true);
+    } else {
+      mlir::func::FuncOp fOp = builder.createFunction(
+          loc, getPRIFProcName("fail_image"),
+          mlir::FunctionType::get(builder.getContext(), {}, {}));
+      fir::CallOp::create(builder, loc, fOp);
+    }
+
+    mlir::func::ReturnOp::create(builder, loc);
+    builder.restoreInsertionPoint(saveInsertPoint);
+  }
+
+  mlir::SymbolRefAttr symbolRef = mlir::SymbolRefAttr::get(
+      builder.getContext(), funcWrapperOp.getSymNameAttr());
+  return fir::AddrOfOp::create(builder, loc, funcType, symbolRef);
+}
+
/// Convert mif.init operation to runtime call of 'prif_init'
struct MIFInitOpConversion : public mlir::OpRewritePattern<mif::InitOp> {
using OpRewritePattern::OpRewritePattern;
@@ -80,6 +193,39 @@ struct MIFInitOpConversion : public mlir::OpRewritePattern<mif::InitOp> {
mlir::Type i32Ty = builder.getI32Type();
mlir::Value result = builder.createTemporary(loc, i32Ty);
+
+    // Registering PRIF runtime termination to the Fortran runtime
+    // STOP
+    mlir::Value funcStopOp = genTerminationOperationWrapper(
+        builder, loc, mod, TerminationKind::Normal);
+    mlir::func::FuncOp normalEndFunc =
+        fir::runtime::getRuntimeFunc<mkRTKey(RegisterImagesNormalEndCallback)>(
+            loc, builder);
+    llvm::SmallVector<mlir::Value> args1 = fir::runtime::createArguments(
+        builder, loc, normalEndFunc.getFunctionType(), funcStopOp);
+    fir::CallOp::create(builder, loc, normalEndFunc, args1);
+
+    // ERROR STOP
+    mlir::Value funcErrorStopOp = genTerminationOperationWrapper(
+        builder, loc, mod, TerminationKind::Error);
+    mlir::func::FuncOp errorFunc =
+        fir::runtime::getRuntimeFunc<mkRTKey(RegisterImagesErrorCallback)>(
+            loc, builder);
+    llvm::SmallVector<mlir::Value> args2 = fir::runtime::createArguments(
+        builder, loc, errorFunc.getFunctionType(), funcErrorStopOp);
+    fir::CallOp::create(builder, loc, errorFunc, args2);
+
+    // FAIL IMAGE
+    mlir::Value failImageOp = genTerminationOperationWrapper(
+        builder, loc, mod, TerminationKind::FailImage);
+    mlir::func::FuncOp failImageFunc =
+        fir::runtime::getRuntimeFunc<mkRTKey(RegisterFailImageCallback)>(
+            loc, builder);
+    llvm::SmallVector<mlir::Value> args3 = fir::runtime::createArguments(
+        builder, loc, failImageFunc.getFunctionType(), failImageOp);
+    fir::CallOp::create(builder, loc, failImageFunc, args3);
+
+    // Initialize the multi-image parallel environment
mlir::FunctionType ftype = mlir::FunctionType::get(
builder.getContext(),
/*inputs*/ {builder.getRefType(i32Ty)}, /*results*/ {});
@@ -210,9 +356,7 @@ struct MIFSyncAllOpConversion : public mlir::OpRewritePattern<mif::SyncAllOp> {
auto [errmsgArg, errmsgAllocArg] =
genErrmsgPRIF(builder, loc, op.getErrmsg());
- mlir::Value stat = op.getStat();
- if (!stat)
- stat = fir::AbsentOp::create(builder, loc, getPRIFStatType(builder));
+ mlir::Value stat = genStatPRIF(builder, loc, op.getStat());
llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
builder, loc, ftype, stat, errmsgArg, errmsgAllocArg);
rewriter.replaceOpWithNewOp<fir::CallOp>(op, funcOp, args);
@@ -261,9 +405,7 @@ struct MIFSyncImagesOpConversion
}
auto [errmsgArg, errmsgAllocArg] =
genErrmsgPRIF(builder, loc, op.getErrmsg());
- mlir::Value stat = op.getStat();
- if (!stat)
- stat = fir::AbsentOp::create(builder, loc, getPRIFStatType(builder));
+ mlir::Value stat = genStatPRIF(builder, loc, op.getStat());
llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
builder, loc, ftype, imageSet, stat, errmsgArg, errmsgAllocArg);
rewriter.replaceOpWithNewOp<fir::CallOp>(op, funcOp, args);
@@ -293,9 +435,7 @@ struct MIFSyncMemoryOpConversion
auto [errmsgArg, errmsgAllocArg] =
genErrmsgPRIF(builder, loc, op.getErrmsg());
- mlir::Value stat = op.getStat();
- if (!stat)
- stat = fir::AbsentOp::create(builder, loc, getPRIFStatType(builder));
+ mlir::Value stat = genStatPRIF(builder, loc, op.getStat());
llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
builder, loc, ftype, stat, errmsgArg, errmsgAllocArg);
rewriter.replaceOpWithNewOp<fir::CallOp>(op, funcOp, args);
@@ -303,6 +443,37 @@ struct MIFSyncMemoryOpConversion
}
};
+/// Convert mif.sync_team operation to runtime call of 'prif_sync_team'
+struct MIFSyncTeamOpConversion
+    : public mlir::OpRewritePattern<mif::SyncTeamOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mif::SyncTeamOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto module = op->template getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, module);
+    mlir::Location loc = op.getLoc();
+
+    // prif_sync_team(team, stat, errmsg, errmsg_alloc)
+    mlir::Type noneBoxTy = fir::BoxType::get(builder.getNoneType());
+    mlir::Type errmsgTy = getPRIFErrmsgType(builder);
+    llvm::SmallVector<mlir::Type> inputTypes = {
+        noneBoxTy, getPRIFStatType(builder), errmsgTy, errmsgTy};
+    mlir::FunctionType prifFuncTy =
+        mlir::FunctionType::get(builder.getContext(), inputTypes, {});
+    mlir::func::FuncOp prifFunc =
+        builder.createFunction(loc, getPRIFProcName("sync_team"), prifFuncTy);
+
+    auto [errmsgArg, errmsgAllocArg] =
+        genErrmsgPRIF(builder, loc, op.getErrmsg());
+    mlir::Value stat = genStatPRIF(builder, loc, op.getStat());
+    llvm::SmallVector<mlir::Value> callArgs = fir::runtime::createArguments(
+        builder, loc, prifFuncTy, op.getTeam(), stat, errmsgArg,
+        errmsgAllocArg);
+    rewriter.replaceOpWithNewOp<fir::CallOp>(op, prifFunc, callArgs);
+    return mlir::success();
+  }
+};
+
/// Generate call to collective subroutines except co_reduce
/// A must be lowered as a box
static fir::CallOp genCollectiveSubroutine(fir::FirOpBuilder &builder,
@@ -432,6 +603,208 @@ struct MIFCoSumOpConversion : public mlir::OpRewritePattern<mif::CoSumOp> {
}
};
+/// Convert mif.form_team operation to runtime call of 'prif_form_team'
+struct MIFFormTeamOpConversion
+ : public mlir::OpRewritePattern<mif::FormTeamOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(mif::FormTeamOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ auto mod = op->template getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Location loc = op.getLoc();
+ mlir::Type errmsgTy = getPRIFErrmsgType(builder);
+ mlir::Type boxTy = fir::BoxType::get(builder.getNoneType());
+ mlir::FunctionType ftype = mlir::FunctionType::get(
+ builder.getContext(),
+ /*inputs*/
+ {builder.getRefType(builder.getI64Type()), boxTy,
+ builder.getRefType(builder.getI32Type()), getPRIFStatType(builder),
+ errmsgTy, errmsgTy},
+ /*results*/ {});
+ mlir::func::FuncOp funcOp =
+ builder.createFunction(loc, getPRIFProcName("form_team"), ftype);
+
+ mlir::Type i64Ty = builder.getI64Type();
+ mlir::Value teamNumber = builder.createTemporary(loc, i64Ty);
+ mlir::Value t =
+ (op.getTeamNumber().getType() == i64Ty)
+ ? op.getTeamNumber()
+ : fir::ConvertOp::create(builder, loc, i64Ty, op.getTeamNumber());
+ fir::StoreOp::create(builder, loc, t, teamNumber);
+
+ mlir::Type i32Ty = builder.getI32Type();
+ mlir::Value newIndex;
+ if (op.getNewIndex()) {
+ newIndex = builder.createTemporary(loc, i32Ty);
+ mlir::Value ni =
+ (op.getNewIndex().getType() == i32Ty)
+ ? op.getNewIndex()
+ : fir::ConvertOp::create(builder, loc, i32Ty, op.getNewIndex());
+ fir::StoreOp::create(builder, loc, ni, newIndex);
+ } else
+ newIndex = fir::AbsentOp::create(builder, loc, builder.getRefType(i32Ty));
+
+ mlir::Value stat = genStatPRIF(builder, loc, op.getStat());
+ auto [errmsgArg, errmsgAllocArg] =
+ genErrmsgPRIF(builder, loc, op.getErrmsg());
+ llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
+ builder, loc, ftype, teamNumber, op.getTeamVar(), newIndex, stat,
+ errmsgArg, errmsgAllocArg);
+ fir::CallOp callOp = fir::CallOp::create(builder, loc, funcOp, args);
+ rewriter.replaceOp(op, callOp);
+ return mlir::success();
+ }
+};
+
+/// Convert mif.change_team operation to runtime call of 'prif_change_team'
+struct MIFChangeTeamOpConversion
+ : public mlir::OpRewritePattern<mif::ChangeTeamOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(mif::ChangeTeamOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ auto mod = op->template getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ builder.setInsertionPoint(op);
+
+ mlir::Location loc = op.getLoc();
+ mlir::Type errmsgTy = getPRIFErrmsgType(builder);
+ mlir::Type boxTy = fir::BoxType::get(builder.getNoneType());
+ mlir::FunctionType ftype = mlir::FunctionType::get(
+ builder.getContext(),
+ /*inputs*/ {boxTy, getPRIFStatType(builder), errmsgTy, errmsgTy},
+ /*results*/ {});
+ mlir::func::FuncOp funcOp =
+ builder.createFunction(loc, getPRIFProcName("change_team"), ftype);
+
+ mlir::Value stat = genStatPRIF(builder, loc, op.getStat());
+ auto [errmsgArg, errmsgAllocArg] =
+ genErrmsgPRIF(builder, loc, op.getErrmsg());
+ llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
+ builder, loc, ftype, op.getTeam(), stat, errmsgArg, errmsgAllocArg);
+ fir::CallOp::create(builder, loc, funcOp, args);
+
+ mlir::Operation *changeOp = op.getOperation();
+ auto &bodyRegion = op.getRegion();
+ mlir::Block &bodyBlock = bodyRegion.front();
+
+ rewriter.inlineBlockBefore(&bodyBlock, changeOp);
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+};
+
+/// Convert mif.end_team operation to runtime call of 'prif_end_team'
+struct MIFEndTeamOpConversion : public mlir::OpRewritePattern<mif::EndTeamOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(mif::EndTeamOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ auto mod = op->template getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Location loc = op.getLoc();
+ mlir::Type errmsgTy = getPRIFErrmsgType(builder);
+ mlir::FunctionType ftype = mlir::FunctionType::get(
+ builder.getContext(),
+ /*inputs*/ {getPRIFStatType(builder), errmsgTy, errmsgTy},
+ /*results*/ {});
+ mlir::func::FuncOp funcOp =
+ builder.createFunction(loc, getPRIFProcName("end_team"), ftype);
+
+ mlir::Value stat = genStatPRIF(builder, loc, op.getStat());
+ auto [errmsgArg, errmsgAllocArg] =
+ genErrmsgPRIF(builder, loc, op.getErrmsg());
+ llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
+ builder, loc, ftype, stat, errmsgArg, errmsgAllocArg);
+ fir::CallOp callOp = fir::CallOp::create(builder, loc, funcOp, args);
+ rewriter.replaceOp(op, callOp);
+ return mlir::success();
+ }
+};
+
+/// Convert mif.get_team operation to runtime call of 'prif_get_team'
+struct MIFGetTeamOpConversion : public mlir::OpRewritePattern<mif::GetTeamOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(mif::GetTeamOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ auto mod = op->template getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Location loc = op.getLoc();
+
+ mlir::Type boxTy = fir::BoxType::get(builder.getNoneType());
+ mlir::Type lvlTy = builder.getRefType(builder.getI32Type());
+ mlir::FunctionType ftype =
+ mlir::FunctionType::get(builder.getContext(),
+ /*inputs*/ {lvlTy, boxTy},
+ /*results*/ {});
+ mlir::func::FuncOp funcOp =
+ builder.createFunction(loc, getPRIFProcName("get_team"), ftype);
+
+ mlir::Value level = op.getLevel();
+ if (!level)
+ level = fir::AbsentOp::create(builder, loc, lvlTy);
+ else {
+ mlir::Value cst = op.getLevel();
+ mlir::Type i32Ty = builder.getI32Type();
+ level = builder.createTemporary(loc, i32Ty);
+ if (cst.getType() != i32Ty)
+ cst = builder.createConvert(loc, i32Ty, cst);
+ fir::StoreOp::create(builder, loc, cst, level);
+ }
+ mlir::Type resultType = op.getResult().getType();
+ mlir::Type baseTy = fir::unwrapRefType(resultType);
+ mlir::Value team = builder.createTemporary(loc, baseTy);
+ fir::EmboxOp box = fir::EmboxOp::create(builder, loc, resultType, team);
+
+ llvm::SmallVector<mlir::Value> args =
+ fir::runtime::createArguments(builder, loc, ftype, level, box);
+ fir::CallOp::create(builder, loc, funcOp, args);
+
+ rewriter.replaceOp(op, box);
+ return mlir::success();
+ }
+};
+
+/// Convert mif.team_number operation to runtime call of 'prif_team_number'
+struct MIFTeamNumberOpConversion
+ : public mlir::OpRewritePattern<mif::TeamNumberOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(mif::TeamNumberOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ auto mod = op->template getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::Location loc = op.getLoc();
+ mlir::Type i64Ty = builder.getI64Type();
+ mlir::Type boxTy = fir::BoxType::get(builder.getNoneType());
+ mlir::FunctionType ftype =
+ mlir::FunctionType::get(builder.getContext(),
+ /*inputs*/ {boxTy, builder.getRefType(i64Ty)},
+ /*results*/ {});
+ mlir::func::FuncOp funcOp =
+ builder.createFunction(loc, getPRIFProcName("team_number"), ftype);
+
+ mlir::Value team = op.getTeam();
+ if (!team)
+ team = fir::AbsentOp::create(builder, loc, boxTy);
+
+ mlir::Value result = builder.createTemporary(loc, i64Ty);
+ llvm::SmallVector<mlir::Value> args =
+ fir::runtime::createArguments(builder, loc, ftype, team, result);
+ fir::CallOp::create(builder, loc, funcOp, args);
+ fir::LoadOp load = fir::LoadOp::create(builder, loc, result);
+ rewriter.replaceOp(op, load);
+ return mlir::success();
+ }
+};
+
class MIFOpConversion : public fir::impl::MIFOpConversionBase<MIFOpConversion> {
public:
void runOnOperation() override {
@@ -458,7 +831,10 @@ void mif::populateMIFOpConversionPatterns(mlir::RewritePatternSet &patterns) {
patterns.insert<MIFInitOpConversion, MIFThisImageOpConversion,
MIFNumImagesOpConversion, MIFSyncAllOpConversion,
MIFSyncImagesOpConversion, MIFSyncMemoryOpConversion,
- MIFCoBroadcastOpConversion, MIFCoMaxOpConversion,
- MIFCoMinOpConversion, MIFCoSumOpConversion>(
+ MIFSyncTeamOpConversion, MIFCoBroadcastOpConversion,
+ MIFCoMaxOpConversion, MIFCoMinOpConversion,
+ MIFCoSumOpConversion, MIFFormTeamOpConversion,
+ MIFChangeTeamOpConversion, MIFEndTeamOpConversion,
+ MIFGetTeamOpConversion, MIFTeamNumberOpConversion>(
patterns.getContext());
}
diff --git a/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp b/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp
index 25a8f7a..c9d52c4 100644
--- a/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp
@@ -246,7 +246,9 @@ struct DispatchOpConv : public OpConversionPattern<fir::DispatchOp> {
args.append(dispatch.getArgs().begin(), dispatch.getArgs().end());
rewriter.replaceOpWithNewOp<fir::CallOp>(
dispatch, resTypes, nullptr, args, dispatch.getArgAttrsAttr(),
- dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr());
+ dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr(),
+ /*inline_attr*/ fir::FortranInlineEnumAttr{},
+ /*accessGroups*/ mlir::ArrayAttr{});
return mlir::success();
}
diff --git a/flang/lib/Optimizer/Transforms/SetRuntimeCallAttributes.cpp b/flang/lib/Optimizer/Transforms/SetRuntimeCallAttributes.cpp
index 378037e..4ba2ea5 100644
--- a/flang/lib/Optimizer/Transforms/SetRuntimeCallAttributes.cpp
+++ b/flang/lib/Optimizer/Transforms/SetRuntimeCallAttributes.cpp
@@ -85,7 +85,10 @@ static mlir::LLVM::MemoryEffectsAttr getGenericMemoryAttr(fir::CallOp callOp) {
callOp->getContext(),
{/*other=*/mlir::LLVM::ModRefInfo::NoModRef,
/*argMem=*/mlir::LLVM::ModRefInfo::ModRef,
- /*inaccessibleMem=*/mlir::LLVM::ModRefInfo::ModRef});
+ /*inaccessibleMem=*/mlir::LLVM::ModRefInfo::ModRef,
+ /*errnoMem=*/mlir::LLVM::ModRefInfo::NoModRef,
+ /*targetMem0=*/mlir::LLVM::ModRefInfo::NoModRef,
+ /*targetMem1=*/mlir::LLVM::ModRefInfo::NoModRef});
}
return {};
diff --git a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
index 03f97eb..3c4da62 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
@@ -254,6 +254,10 @@ public:
// Collect iteration variable(s) allocations so that we can move them
// outside the `fir.do_concurrent` wrapper.
+ // There actually may be more operations than just allocations
+ // at the beginning of the wrapper block, e.g. LICM may move
+ // some operations from the inner fir.do_concurrent.loop into
+ // this block.
llvm::SmallVector<mlir::Operation *> opsToMove;
for (mlir::Operation &op : llvm::drop_end(wrapperBlock))
opsToMove.push_back(&op);
@@ -262,8 +266,13 @@ public:
rewriter, doConcurentOp->getParentOfType<mlir::ModuleOp>());
auto *allocIt = firBuilder.getAllocaBlock();
- for (mlir::Operation *op : llvm::reverse(opsToMove))
- rewriter.moveOpBefore(op, allocIt, allocIt->begin());
+ // Move alloca operations into the alloca-block, and all other
+ // operations - right before fir.do_concurrent.
+ for (mlir::Operation *op : opsToMove)
+ if (mlir::isa<fir::AllocaOp>(op))
+ rewriter.moveOpBefore(op, allocIt, allocIt->begin());
+ else
+ rewriter.moveOpBefore(op, doConcurentOp);
rewriter.setInsertionPointAfter(doConcurentOp);
fir::DoLoopOp innermostUnorderdLoop;
diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index 49a085e..49ae189 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -730,7 +730,6 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
mlir::Value ifCompatElem =
fir::ConvertOp::create(builder, loc, ifCompatType, maskElem);
- llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
fir::IfOp ifOp =
fir::IfOp::create(builder, loc, elementType, ifCompatElem,
/*withElseRegion=*/true);
diff --git a/flang/lib/Optimizer/Transforms/VScaleAttr.cpp b/flang/lib/Optimizer/Transforms/VScaleAttr.cpp
index 54a2456..d0e83ef 100644
--- a/flang/lib/Optimizer/Transforms/VScaleAttr.cpp
+++ b/flang/lib/Optimizer/Transforms/VScaleAttr.cpp
@@ -33,9 +33,11 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
+#include <mlir/IR/Diagnostics.h>
namespace fir {
#define GEN_PASS_DEF_VSCALEATTR
@@ -49,7 +51,8 @@ namespace {
class VScaleAttrPass : public fir::impl::VScaleAttrBase<VScaleAttrPass> {
public:
VScaleAttrPass(const fir::VScaleAttrOptions &options) {
- vscaleRange = options.vscaleRange;
+ vscaleMin = options.vscaleMin;
+ vscaleMax = options.vscaleMax;
}
VScaleAttrPass() {}
void runOnOperation() override;
@@ -63,16 +66,28 @@ void VScaleAttrPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "Func-name:" << func.getSymName() << "\n");
+ if (!llvm::isPowerOf2_32(vscaleMin)) {
+ func->emitError(
+ "VScaleAttr: vscaleMin has to be a power-of-two greater than 0\n");
+ return signalPassFailure();
+ }
+
+ if (vscaleMax != 0 &&
+ (!llvm::isPowerOf2_32(vscaleMax) || (vscaleMin > vscaleMax))) {
+ func->emitError("VScaleAttr: vscaleMax has to be a power-of-two "
+ "greater-than-or-equal to vscaleMin or 0 to signify "
+ "an unbounded maximum\n");
+ return signalPassFailure();
+ }
+
auto context = &getContext();
auto intTy = mlir::IntegerType::get(context, 32);
- assert(vscaleRange.first && "VScaleRange minimum should be non-zero");
-
func->setAttr("vscale_range",
mlir::LLVM::VScaleRangeAttr::get(
- context, mlir::IntegerAttr::get(intTy, vscaleRange.first),
- mlir::IntegerAttr::get(intTy, vscaleRange.second)));
+ context, mlir::IntegerAttr::get(intTy, vscaleMin),
+ mlir::IntegerAttr::get(intTy, vscaleMax)));
LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
}