Diffstat (limited to 'flang/lib')
-rw-r--r--  flang/lib/Lower/Bridge.cpp                                          |    7
-rw-r--r--  flang/lib/Lower/CUDA.cpp                                            |   27
-rw-r--r--  flang/lib/Lower/ConvertExpr.cpp                                     |    2
-rw-r--r--  flang/lib/Lower/IO.cpp                                              |    3
-rw-r--r--  flang/lib/Lower/OpenMP/Atomic.cpp                                   |   97
-rw-r--r--  flang/lib/Optimizer/Builder/Character.cpp                           |    2
-rw-r--r--  flang/lib/Optimizer/Builder/IntrinsicCall.cpp                       |    3
-rw-r--r--  flang/lib/Optimizer/Dialect/FIRType.cpp                             |   19
-rw-r--r--  flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp |   62
-rw-r--r--  flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp    |  241
-rw-r--r--  flang/lib/Optimizer/OpenMP/CMakeLists.txt                           |    1
-rw-r--r--  flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp                  | 1852
-rw-r--r--  flang/lib/Optimizer/Passes/Pipelines.cpp                            |    4
-rw-r--r--  flang/lib/Optimizer/Transforms/AffinePromotion.cpp                  |    2
-rw-r--r--  flang/lib/Optimizer/Transforms/StackArrays.cpp                      |    2
-rw-r--r--  flang/lib/Parser/parse-tree.cpp                                     |    2
-rw-r--r--  flang/lib/Semantics/assignment.cpp                                  |    3
-rw-r--r--  flang/lib/Semantics/check-allocate.cpp                              |   13
-rw-r--r--  flang/lib/Semantics/check-case.cpp                                  |    2
-rw-r--r--  flang/lib/Semantics/check-coarray.cpp                               |    9
-rw-r--r--  flang/lib/Semantics/check-data.cpp                                  |   10
-rw-r--r--  flang/lib/Semantics/check-deallocate.cpp                            |    3
-rw-r--r--  flang/lib/Semantics/check-do-forall.cpp                             |   50
-rw-r--r--  flang/lib/Semantics/check-io.cpp                                    |    6
-rw-r--r--  flang/lib/Semantics/check-omp-structure.cpp                         |    8
-rw-r--r--  flang/lib/Semantics/data-to-inits.cpp                               |   17
-rw-r--r--  flang/lib/Semantics/expression.cpp                                  |  107
-rw-r--r--  flang/lib/Semantics/resolve-names-utils.cpp                         |    9
-rw-r--r--  flang/lib/Semantics/resolve-names.cpp                               |   32
29 files changed, 2416 insertions(+), 179 deletions(-)
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 68adf34..0595ca0 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4987,11 +4987,8 @@ private:
// host = device
if (!lhsIsDevice && rhsIsDevice) {
- if (Fortran::lower::isTransferWithConversion(rhs)) {
+ if (auto elementalOp = Fortran::lower::isTransferWithConversion(rhs)) {
mlir::OpBuilder::InsertionGuard insertionGuard(builder);
- auto elementalOp =
- mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp());
- assert(elementalOp && "expect elemental op");
auto designateOp =
*elementalOp.getBody()->getOps<hlfir::DesignateOp>().begin();
builder.setInsertionPoint(elementalOp);
@@ -6079,7 +6076,7 @@ private:
if (resTy != wrappedSymTy) {
// check size of the pointed to type so we can't overflow by writing
// double precision to a single precision allocation, etc
- LLVM_ATTRIBUTE_UNUSED auto getBitWidth = [this](mlir::Type ty) {
+ [[maybe_unused]] auto getBitWidth = [this](mlir::Type ty) {
// 15.6.2.6.3: differering result types should be integer, real,
// complex or logical
if (auto cmplx = mlir::dyn_cast_or_null<mlir::ComplexType>(ty))
diff --git a/flang/lib/Lower/CUDA.cpp b/flang/lib/Lower/CUDA.cpp
index bb4bdee..9501b0e 100644
--- a/flang/lib/Lower/CUDA.cpp
+++ b/flang/lib/Lower/CUDA.cpp
@@ -68,11 +68,26 @@ cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute(
return cuf::getDataAttribute(mlirContext, cudaAttr);
}
-bool Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
+hlfir::ElementalOp Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
+ auto isConversionElementalOp = [](hlfir::ElementalOp elOp) {
+    return llvm::hasSingleElement(
+               elOp.getBody()->getOps<hlfir::DesignateOp>()) &&
+           llvm::hasSingleElement(elOp.getBody()->getOps<fir::LoadOp>()) &&
+           llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>());
+ };
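+  // The conversion elemental can be reached either directly from `rhs` or
+  // through a hlfir.declare of a hlfir.associate of the elemental when the
+  // converted value has been associated with a temporary; handle both cases.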
+ if (auto declOp = mlir::dyn_cast<hlfir::DeclareOp>(rhs.getDefiningOp())) {
+ if (!declOp.getMemref().getDefiningOp())
+ return {};
+ if (auto associateOp = mlir::dyn_cast<hlfir::AssociateOp>(
+ declOp.getMemref().getDefiningOp()))
+ if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(
+ associateOp.getSource().getDefiningOp()))
+ if (isConversionElementalOp(elOp))
+ return elOp;
+ }
if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp()))
- if (llvm::hasSingleElement(elOp.getBody()->getOps<hlfir::DesignateOp>()) &&
- llvm::hasSingleElement(elOp.getBody()->getOps<fir::LoadOp>()) == 1 &&
- llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>()) == 1)
- return true;
- return false;
+ if (isConversionElementalOp(elOp))
+ return elOp;
+ return {};
}
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index d7f94e1..a46d219 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -5603,7 +5603,7 @@ private:
return newIters;
};
if (useTripsForSlice) {
- LLVM_ATTRIBUTE_UNUSED auto vectorSubscriptShape =
+ [[maybe_unused]] auto vectorSubscriptShape =
getShape(arrayOperands.back());
auto undef = fir::UndefOp::create(builder, loc, idxTy);
trips.push_back(undef);
diff --git a/flang/lib/Lower/IO.cpp b/flang/lib/Lower/IO.cpp
index 604b137..cd53dc9 100644
--- a/flang/lib/Lower/IO.cpp
+++ b/flang/lib/Lower/IO.cpp
@@ -950,7 +950,8 @@ static void genIoLoop(Fortran::lower::AbstractConverter &converter,
makeNextConditionalOn(builder, loc, checkResult, ok, inLoop);
const auto &itemList = std::get<0>(ioImpliedDo.t);
const auto &control = std::get<1>(ioImpliedDo.t);
- const auto &loopSym = *control.name.thing.thing.symbol;
+ const auto &loopSym =
+ *Fortran::parser::UnwrapRef<Fortran::parser::Name>(control.name).symbol;
mlir::Value loopVar = fir::getBase(converter.genExprAddr(
Fortran::evaluate::AsGenericExpr(loopSym).value(), stmtCtx));
auto genControlValue = [&](const Fortran::parser::ScalarIntExpr &expr) {
diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp
index ff82a36..3ab8a58 100644
--- a/flang/lib/Lower/OpenMP/Atomic.cpp
+++ b/flang/lib/Lower/OpenMP/Atomic.cpp
@@ -20,6 +20,7 @@
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/parse-tree.h"
+#include "flang/Semantics/openmp-utils.h"
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/type.h"
#include "flang/Support/Fortran.h"
@@ -183,12 +184,8 @@ getMemoryOrderFromRequires(const semantics::Scope &scope) {
// scope.
// For safety, traverse all enclosing scopes and check if their symbol
// contains REQUIRES.
- for (const auto *sc{&scope}; sc->kind() != semantics::Scope::Kind::Global;
- sc = &sc->parent()) {
- const semantics::Symbol *sym = sc->symbol();
- if (!sym)
- continue;
-
+ const semantics::Scope &unitScope = semantics::omp::GetProgramUnit(scope);
+ if (auto *symbol = unitScope.symbol()) {
const common::OmpMemoryOrderType *admo = common::visit(
[](auto &&s) {
using WithOmpDeclarative = semantics::WithOmpDeclarative;
@@ -198,7 +195,8 @@ getMemoryOrderFromRequires(const semantics::Scope &scope) {
}
return static_cast<const common::OmpMemoryOrderType *>(nullptr);
},
- sym->details());
+ symbol->details());
+
if (admo)
return getMemoryOrderKind(*admo);
}
@@ -214,19 +212,83 @@ getDefaultAtomicMemOrder(semantics::SemanticsContext &semaCtx) {
return std::nullopt;
}
-static std::optional<mlir::omp::ClauseMemoryOrderKind>
+static std::pair<std::optional<mlir::omp::ClauseMemoryOrderKind>, bool>
getAtomicMemoryOrder(semantics::SemanticsContext &semaCtx,
const omp::List<omp::Clause> &clauses,
const semantics::Scope &scope) {
for (const omp::Clause &clause : clauses) {
if (auto maybeKind = getMemoryOrderKind(clause.id))
- return *maybeKind;
+ return std::make_pair(*maybeKind, /*canOverride=*/false);
}
if (auto maybeKind = getMemoryOrderFromRequires(scope))
- return *maybeKind;
+ return std::make_pair(*maybeKind, /*canOverride=*/true);
- return getDefaultAtomicMemOrder(semaCtx);
+ return std::make_pair(getDefaultAtomicMemOrder(semaCtx),
+ /*canOverride=*/false);
+}
+
+static std::optional<mlir::omp::ClauseMemoryOrderKind>
+makeValidForAction(std::optional<mlir::omp::ClauseMemoryOrderKind> memOrder,
+ int action0, int action1, unsigned version) {
+ // When the atomic default memory order specified on a REQUIRES directive is
+ // disallowed on a given ATOMIC operation, and it's not ACQ_REL, the order
+ // reverts to RELAXED. ACQ_REL decays to either ACQUIRE or RELEASE, depending
+ // on the operation.
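+  // For example, with REQUIRES ATOMIC_DEFAULT_MEM_ORDER(ACQ_REL) an ATOMIC
+  // READ ends up with ACQUIRE and an ATOMIC WRITE with RELEASE, whereas with
+  // ATOMIC_DEFAULT_MEM_ORDER(RELEASE) an ATOMIC READ falls back to RELAXED.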
+
+ if (!memOrder) {
+ return memOrder;
+ }
+
+ using Analysis = parser::OpenMPAtomicConstruct::Analysis;
+ // Figure out the main action (i.e. disregard a potential capture operation)
+ int action = action0;
+ if (action1 != Analysis::None)
+ action = action0 == Analysis::Read ? action1 : action0;
+
+  // Available orderings: acquire, acq_rel, relaxed, release, seq_cst
+
+ if (action == Analysis::Read) {
+ // "acq_rel" decays to "acquire"
+ if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acq_rel)
+ return mlir::omp::ClauseMemoryOrderKind::Acquire;
+ } else if (action == Analysis::Write) {
+ // "acq_rel" decays to "release"
+ if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acq_rel)
+ return mlir::omp::ClauseMemoryOrderKind::Release;
+ }
+
+ if (version > 50) {
+ if (action == Analysis::Read) {
+ // "release" prohibited
+ if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Release)
+ return mlir::omp::ClauseMemoryOrderKind::Relaxed;
+ }
+ if (action == Analysis::Write) {
+ // "acquire" prohibited
+ if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acquire)
+ return mlir::omp::ClauseMemoryOrderKind::Relaxed;
+ }
+ } else {
+ if (action == Analysis::Read) {
+ // "release" prohibited
+ if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Release)
+ return mlir::omp::ClauseMemoryOrderKind::Relaxed;
+ } else {
+ if (action & Analysis::Write) { // include "update"
+ // "acquire" prohibited
+ if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acquire)
+ return mlir::omp::ClauseMemoryOrderKind::Relaxed;
+ if (action == Analysis::Update) {
+ // "acq_rel" prohibited
+ if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acq_rel)
+ return mlir::omp::ClauseMemoryOrderKind::Relaxed;
+ }
+ }
+ }
+ }
+
+ return memOrder;
}
static mlir::omp::ClauseMemoryOrderKindAttr
@@ -449,16 +511,19 @@ void Fortran::lower::omp::lowerAtomic(
mlir::Value atomAddr =
fir::getBase(converter.genExprAddr(atom, stmtCtx, &loc));
mlir::IntegerAttr hint = getAtomicHint(converter, clauses);
- std::optional<mlir::omp::ClauseMemoryOrderKind> memOrder =
- getAtomicMemoryOrder(semaCtx, clauses,
- semaCtx.FindScope(construct.source));
+ auto [memOrder, canOverride] = getAtomicMemoryOrder(
+ semaCtx, clauses, semaCtx.FindScope(construct.source));
+
+ unsigned version = semaCtx.langOptions().OpenMPVersion;
+ int action0 = analysis.op0.what & analysis.Action;
+ int action1 = analysis.op1.what & analysis.Action;
+ if (canOverride)
+ memOrder = makeValidForAction(memOrder, action0, action1, version);
if (auto *cond = get(analysis.cond)) {
(void)cond;
TODO(loc, "OpenMP ATOMIC COMPARE");
} else {
- int action0 = analysis.op0.what & analysis.Action;
- int action1 = analysis.op1.what & analysis.Action;
mlir::Operation *captureOp = nullptr;
fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint();
fir::FirOpBuilder::InsertPoint atomicAt, postAt;
diff --git a/flang/lib/Optimizer/Builder/Character.cpp b/flang/lib/Optimizer/Builder/Character.cpp
index a096099..155bc0f 100644
--- a/flang/lib/Optimizer/Builder/Character.cpp
+++ b/flang/lib/Optimizer/Builder/Character.cpp
@@ -92,7 +92,7 @@ getCompileTimeLength(const fir::CharBoxValue &box) {
/// Detect the precondition that the value `str` does not reside in memory. Such
/// values will have a type `!fir.array<...x!fir.char<N>>` or `!fir.char<N>`.
-LLVM_ATTRIBUTE_UNUSED static bool needToMaterialize(mlir::Value str) {
+[[maybe_unused]] static bool needToMaterialize(mlir::Value str) {
return mlir::isa<fir::SequenceType>(str.getType()) ||
fir::isa_char(str.getType());
}
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index e07baaf..0195178 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -2169,7 +2169,8 @@ IntrinsicLibrary::genElementalCall<IntrinsicLibrary::ExtendedGenerator>(
for (const fir::ExtendedValue &arg : args) {
auto *box = arg.getBoxOf<fir::BoxValue>();
if (!arg.getUnboxed() && !arg.getCharBox() &&
- !(box && fir::isScalarBoxedRecordType(fir::getBase(*box).getType())))
+ !(box && (fir::isScalarBoxedRecordType(fir::getBase(*box).getType()) ||
+ fir::isClassStarType(fir::getBase(*box).getType()))))
fir::emitFatalError(loc, "nonscalar intrinsic argument");
}
if (outline)
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index 4a9579c..48e1622 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -336,6 +336,17 @@ bool isBoxedRecordType(mlir::Type ty) {
return false;
}
+// CLASS(*)
+bool isClassStarType(mlir::Type ty) {
+ if (auto clTy = mlir::dyn_cast<fir::ClassType>(fir::unwrapRefType(ty))) {
+ if (mlir::isa<mlir::NoneType>(clTy.getEleTy()))
+ return true;
+ mlir::Type innerType = clTy.unwrapInnerType();
+ return innerType && mlir::isa<mlir::NoneType>(innerType);
+ }
+ return false;
+}
+
bool isScalarBoxedRecordType(mlir::Type ty) {
if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
ty = refTy;
@@ -398,12 +409,8 @@ bool isPolymorphicType(mlir::Type ty) {
bool isUnlimitedPolymorphicType(mlir::Type ty) {
// CLASS(*)
- if (auto clTy = mlir::dyn_cast<fir::ClassType>(fir::unwrapRefType(ty))) {
- if (mlir::isa<mlir::NoneType>(clTy.getEleTy()))
- return true;
- mlir::Type innerType = clTy.unwrapInnerType();
- return innerType && mlir::isa<mlir::NoneType>(innerType);
- }
+ if (isClassStarType(ty))
+ return true;
// TYPE(*)
return isAssumedType(ty);
}
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp
index a48b7ba..63a5803 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp
@@ -21,24 +21,27 @@
//===----------------------------------------------------------------------===//
/// Log RAW or WAW conflict.
-static void LLVM_ATTRIBUTE_UNUSED logConflict(llvm::raw_ostream &os,
- mlir::Value writtenOrReadVarA,
- mlir::Value writtenVarB);
+[[maybe_unused]] static void logConflict(llvm::raw_ostream &os,
+ mlir::Value writtenOrReadVarA,
+ mlir::Value writtenVarB);
/// Log when an expression evaluation must be saved.
-static void LLVM_ATTRIBUTE_UNUSED logSaveEvaluation(llvm::raw_ostream &os,
- unsigned runid,
- mlir::Region &yieldRegion,
- bool anyWrite);
+[[maybe_unused]] static void logSaveEvaluation(llvm::raw_ostream &os,
+ unsigned runid,
+ mlir::Region &yieldRegion,
+ bool anyWrite);
/// Log when an assignment is scheduled.
-static void LLVM_ATTRIBUTE_UNUSED logAssignmentEvaluation(
- llvm::raw_ostream &os, unsigned runid, hlfir::RegionAssignOp assign);
+[[maybe_unused]] static void
+logAssignmentEvaluation(llvm::raw_ostream &os, unsigned runid,
+ hlfir::RegionAssignOp assign);
/// Log when starting to schedule an order assignment tree.
-static void LLVM_ATTRIBUTE_UNUSED logStartScheduling(
- llvm::raw_ostream &os, hlfir::OrderedAssignmentTreeOpInterface root);
+[[maybe_unused]] static void
+logStartScheduling(llvm::raw_ostream &os,
+ hlfir::OrderedAssignmentTreeOpInterface root);
/// Log op if effect value is not known.
-static void LLVM_ATTRIBUTE_UNUSED logIfUnkownEffectValue(
- llvm::raw_ostream &os, mlir::MemoryEffects::EffectInstance effect,
- mlir::Operation &op);
+[[maybe_unused]] static void
+logIfUnkownEffectValue(llvm::raw_ostream &os,
+ mlir::MemoryEffects::EffectInstance effect,
+ mlir::Operation &op);
//===----------------------------------------------------------------------===//
// Scheduling Implementation
@@ -701,23 +704,24 @@ static llvm::raw_ostream &printRegionPath(llvm::raw_ostream &os,
return printRegionId(os, yieldRegion);
}
-static void LLVM_ATTRIBUTE_UNUSED logSaveEvaluation(llvm::raw_ostream &os,
- unsigned runid,
- mlir::Region &yieldRegion,
- bool anyWrite) {
+[[maybe_unused]] static void logSaveEvaluation(llvm::raw_ostream &os,
+ unsigned runid,
+ mlir::Region &yieldRegion,
+ bool anyWrite) {
os << "run " << runid << " save " << (anyWrite ? "(w)" : " ") << ": ";
printRegionPath(os, yieldRegion) << "\n";
}
-static void LLVM_ATTRIBUTE_UNUSED logAssignmentEvaluation(
- llvm::raw_ostream &os, unsigned runid, hlfir::RegionAssignOp assign) {
+[[maybe_unused]] static void
+logAssignmentEvaluation(llvm::raw_ostream &os, unsigned runid,
+ hlfir::RegionAssignOp assign) {
os << "run " << runid << " evaluate: ";
printNodePath(os, assign.getOperation()) << "\n";
}
-static void LLVM_ATTRIBUTE_UNUSED logConflict(llvm::raw_ostream &os,
- mlir::Value writtenOrReadVarA,
- mlir::Value writtenVarB) {
+[[maybe_unused]] static void logConflict(llvm::raw_ostream &os,
+ mlir::Value writtenOrReadVarA,
+ mlir::Value writtenVarB) {
auto printIfValue = [&](mlir::Value var) -> llvm::raw_ostream & {
if (!var)
return os << "<unknown>";
@@ -728,8 +732,9 @@ static void LLVM_ATTRIBUTE_UNUSED logConflict(llvm::raw_ostream &os,
printIfValue(writtenVarB) << "\n";
}
-static void LLVM_ATTRIBUTE_UNUSED logStartScheduling(
- llvm::raw_ostream &os, hlfir::OrderedAssignmentTreeOpInterface root) {
+[[maybe_unused]] static void
+logStartScheduling(llvm::raw_ostream &os,
+ hlfir::OrderedAssignmentTreeOpInterface root) {
os << "------------ scheduling ";
printNodePath(os, root.getOperation());
if (auto funcOp = root->getParentOfType<mlir::func::FuncOp>())
@@ -737,9 +742,10 @@ static void LLVM_ATTRIBUTE_UNUSED logStartScheduling(
os << "------------\n";
}
-static void LLVM_ATTRIBUTE_UNUSED logIfUnkownEffectValue(
- llvm::raw_ostream &os, mlir::MemoryEffects::EffectInstance effect,
- mlir::Operation &op) {
+[[maybe_unused]] static void
+logIfUnkownEffectValue(llvm::raw_ostream &os,
+ mlir::MemoryEffects::EffectInstance effect,
+ mlir::Operation &op) {
if (effect.getValue() != nullptr)
return;
os << "unknown effected value (";
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
index 9bf10b5..ed9e41c 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
@@ -751,4 +751,245 @@ template bool OpenACCMappableModel<fir::PointerType>::generatePrivateDestroy(
mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc,
mlir::Value privatized) const;
+template <typename Ty>
+mlir::Value OpenACCPointerLikeModel<Ty>::genAllocate(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar,
+ bool &needsFree) const {
+
+ // Unwrap to get the pointee type.
+ mlir::Type pointeeTy = fir::dyn_cast_ptrEleTy(pointer);
+ assert(pointeeTy && "expected pointee type to be extractable");
+
+ // Box types are descriptors that contain both metadata and a pointer to data.
+ // The `genAllocate` API is designed for simple allocations and cannot
+  // properly handle the dual nature of boxes; use `generatePrivateInit`
+  // instead, which can allocate both the descriptor and its referenced data.
+  // Use cases that only need empty descriptor storage could potentially be
+  // implemented here.
+ if (fir::isa_box_type(pointeeTy))
+ return {};
+
+ // Unlimited polymorphic (class(*)) cannot be handled - size unknown
+ if (fir::isUnlimitedPolymorphicType(pointeeTy))
+ return {};
+
+ // Return null for dynamic size types because the size of the
+ // allocation cannot be determined simply from the type.
+ if (fir::hasDynamicSize(pointeeTy))
+ return {};
+
+  // Use heap allocation for fir.heap and stack allocation for the others
+  // (fir.ref, fir.ptr, fir.llvm_ptr). For fir.ptr, which represents a Fortran
+  // POINTER, allocating may seem odd since it normally points to an existing
+  // entity; however, the pointee may itself be privatized, in which case
+  // issuing an allocation makes sense.
+ mlir::Value allocation;
+ if (std::is_same_v<Ty, fir::HeapType>) {
+ needsFree = true;
+ allocation = fir::AllocMemOp::create(builder, loc, pointeeTy);
+ } else {
+ needsFree = false;
+ allocation = fir::AllocaOp::create(builder, loc, pointeeTy);
+ }
+
+ // Convert to the requested pointer type if needed.
+ // This means converting from a fir.ref to either a fir.llvm_ptr or a fir.ptr.
+  // fir.heap is already the correct type in this case.
+  if (allocation.getType() != pointer) {
+    assert(!(std::is_same_v<Ty, fir::HeapType>) &&
+           "fir.heap is already the correct type because of allocmem");
+ return fir::ConvertOp::create(builder, loc, pointer, allocation);
+ }
+
+ return allocation;
+}
+
+template mlir::Value OpenACCPointerLikeModel<fir::ReferenceType>::genAllocate(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar,
+ bool &needsFree) const;
+
+template mlir::Value OpenACCPointerLikeModel<fir::PointerType>::genAllocate(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar,
+ bool &needsFree) const;
+
+template mlir::Value OpenACCPointerLikeModel<fir::HeapType>::genAllocate(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar,
+ bool &needsFree) const;
+
+template mlir::Value OpenACCPointerLikeModel<fir::LLVMPointerType>::genAllocate(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar,
+ bool &needsFree) const;
+
+static mlir::Value stripCasts(mlir::Value value, bool stripDeclare = true) {
+ mlir::Value currentValue = value;
+
+ while (currentValue) {
+ auto *definingOp = currentValue.getDefiningOp();
+ if (!definingOp)
+ break;
+
+ if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(definingOp)) {
+ currentValue = convertOp.getValue();
+ continue;
+ }
+
+ if (auto viewLike = mlir::dyn_cast<mlir::ViewLikeOpInterface>(definingOp)) {
+ currentValue = viewLike.getViewSource();
+ continue;
+ }
+
+ if (stripDeclare) {
+ if (auto declareOp = mlir::dyn_cast<hlfir::DeclareOp>(definingOp)) {
+ currentValue = declareOp.getMemref();
+ continue;
+ }
+
+ if (auto declareOp = mlir::dyn_cast<fir::DeclareOp>(definingOp)) {
+ currentValue = declareOp.getMemref();
+ continue;
+ }
+ }
+ break;
+ }
+
+ return currentValue;
+}
+
+template <typename Ty>
+bool OpenACCPointerLikeModel<Ty>::genFree(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::TypedValue<mlir::acc::PointerLikeType> varToFree,
+ mlir::Value allocRes, mlir::Type varType) const {
+
+ // Unwrap to get the pointee type.
+ mlir::Type pointeeTy = fir::dyn_cast_ptrEleTy(pointer);
+ assert(pointeeTy && "expected pointee type to be extractable");
+
+ // Box types contain both a descriptor and data. The `genFree` API
+ // handles simple deallocations and cannot properly manage both parts.
+ // Using `generatePrivateDestroy` instead can free both the descriptor and
+ // its referenced data.
+ if (fir::isa_box_type(pointeeTy))
+ return false;
+
+ // If pointer type is HeapType, assume it's a heap allocation
+ if (std::is_same_v<Ty, fir::HeapType>) {
+ fir::FreeMemOp::create(builder, loc, varToFree);
+ return true;
+ }
+
+ // Use allocRes if provided to determine the allocation type
+ mlir::Value valueToInspect = allocRes ? allocRes : varToFree;
+
+ // Strip casts and declare operations to find the original allocation
+ mlir::Value strippedValue = stripCasts(valueToInspect);
+ mlir::Operation *originalAlloc = strippedValue.getDefiningOp();
+
+ // If we found an AllocMemOp (heap allocation), free it
+ if (mlir::isa_and_nonnull<fir::AllocMemOp>(originalAlloc)) {
+ mlir::Value toFree = varToFree;
+ if (!mlir::isa<fir::HeapType>(valueToInspect.getType()))
+ toFree = fir::ConvertOp::create(
+ builder, loc,
+ fir::HeapType::get(varToFree.getType().getElementType()), toFree);
+ fir::FreeMemOp::create(builder, loc, toFree);
+ return true;
+ }
+
+ // If we found an AllocaOp (stack allocation), no deallocation needed
+ if (mlir::isa_and_nonnull<fir::AllocaOp>(originalAlloc))
+ return true;
+
+ // Unable to determine allocation type
+ return false;
+}
+
+template bool OpenACCPointerLikeModel<fir::ReferenceType>::genFree(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::TypedValue<mlir::acc::PointerLikeType> varToFree,
+ mlir::Value allocRes, mlir::Type varType) const;
+
+template bool OpenACCPointerLikeModel<fir::PointerType>::genFree(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::TypedValue<mlir::acc::PointerLikeType> varToFree,
+ mlir::Value allocRes, mlir::Type varType) const;
+
+template bool OpenACCPointerLikeModel<fir::HeapType>::genFree(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::TypedValue<mlir::acc::PointerLikeType> varToFree,
+ mlir::Value allocRes, mlir::Type varType) const;
+
+template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::genFree(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::TypedValue<mlir::acc::PointerLikeType> varToFree,
+ mlir::Value allocRes, mlir::Type varType) const;
+
+template <typename Ty>
+bool OpenACCPointerLikeModel<Ty>::genCopy(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::TypedValue<mlir::acc::PointerLikeType> destination,
+ mlir::TypedValue<mlir::acc::PointerLikeType> source,
+ mlir::Type varType) const {
+
+ // Check that source and destination types match
+ if (source.getType() != destination.getType())
+ return false;
+
+ // Unwrap to get the pointee type.
+ mlir::Type pointeeTy = fir::dyn_cast_ptrEleTy(pointer);
+ assert(pointeeTy && "expected pointee type to be extractable");
+
+ // Box types contain both a descriptor and referenced data. The genCopy API
+ // handles simple copies and cannot properly manage both parts.
+ if (fir::isa_box_type(pointeeTy))
+ return false;
+
+ // Unlimited polymorphic (class(*)) cannot be handled because source and
+ // destination types are not known.
+ if (fir::isUnlimitedPolymorphicType(pointeeTy))
+ return false;
+
+ // Return false for dynamic size types because the copy logic
+ // cannot be determined simply from the type.
+ if (fir::hasDynamicSize(pointeeTy))
+ return false;
+
+ if (fir::isa_trivial(pointeeTy)) {
+ auto loadVal = fir::LoadOp::create(builder, loc, source);
+ fir::StoreOp::create(builder, loc, loadVal, destination);
+ } else {
+ hlfir::AssignOp::create(builder, loc, source, destination);
+ }
+ return true;
+}
+
+template bool OpenACCPointerLikeModel<fir::ReferenceType>::genCopy(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::TypedValue<mlir::acc::PointerLikeType> destination,
+ mlir::TypedValue<mlir::acc::PointerLikeType> source,
+ mlir::Type varType) const;
+
+template bool OpenACCPointerLikeModel<fir::PointerType>::genCopy(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::TypedValue<mlir::acc::PointerLikeType> destination,
+ mlir::TypedValue<mlir::acc::PointerLikeType> source,
+ mlir::Type varType) const;
+
+template bool OpenACCPointerLikeModel<fir::HeapType>::genCopy(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::TypedValue<mlir::acc::PointerLikeType> destination,
+ mlir::TypedValue<mlir::acc::PointerLikeType> source,
+ mlir::Type varType) const;
+
+template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::genCopy(
+ mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc,
+ mlir::TypedValue<mlir::acc::PointerLikeType> destination,
+ mlir::TypedValue<mlir::acc::PointerLikeType> source,
+ mlir::Type varType) const;
+
} // namespace fir::acc
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index b85ee7e..23a7dc8 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -8,6 +8,7 @@ add_flang_library(FlangOpenMPTransforms
MapsForPrivatizedSymbols.cpp
MapInfoFinalization.cpp
MarkDeclareTarget.cpp
+ LowerWorkdistribute.cpp
LowerWorkshare.cpp
LowerNontemporal.cpp
SimdOnly.cpp
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
new file mode 100644
index 0000000..9278e17
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -0,0 +1,1852 @@
+//===- LowerWorkdistribute.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering and optimisations of omp.workdistribute.
+//
+// Fortran array statements are lowered to FIR as unordered fir.do_loop ops.
+// The lower-workdistribute pass mainly identifies an unordered fir.do_loop
+// nested in target{teams{workdistribute{fir.do_loop unordered}}} and lowers it
+// to target{teams{parallel{distribute{wsloop{loop_nest}}}}}.
+// All other ops are hoisted outside the target region.
+// Heap allocations on the target are replaced with omp.target_allocmem and
+// deallocations with omp.target_freemem issued from the host. Calls to the
+// runtime function "Assign" are also replaced with omp_target_memcpy.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Utils.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/Utils/IndexingUtils.h>
+#include <mlir/IR/BlockSupport.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/Diagnostics.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <optional>
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workdistribute"
+
+using namespace mlir;
+
+namespace {
+
+/// This string is used to identify the Fortran runtime function FortranAAssign.
+static constexpr llvm::StringRef FortranAssignStr = "_FortranAAssign";
+
+/// The isRuntimeCall function is a utility designed to determine
+/// if a given operation is a call to a Fortran-specific runtime function.
+static bool isRuntimeCall(Operation *op) {
+ if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+ auto callee = callOp.getCallee();
+ if (!callee)
+ return false;
+ auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
+ if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
+ return true;
+ }
+ return false;
+}
+
+/// This is the single source of truth about whether we should parallelize an
+/// operation nested in an omp.workdistribute region.
+/// Parallelize here refers to dividing into units of work.
+static bool shouldParallelize(Operation *op) {
+ // True if the op is a runtime call to Assign
+ if (isRuntimeCall(op)) {
+ fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+ auto funcName = runtimeCall.getCallee()->getRootReference().getValue();
+ if (funcName == FortranAssignStr) {
+ return true;
+ }
+ }
+ // We cannot parallelize ops with side effects.
+ // Parallelizable operations should not produce
+ // values that other operations depend on
+ if (llvm::any_of(op->getResults(),
+ [](OpResult v) -> bool { return !v.use_empty(); }))
+ return false;
+ // We will parallelize unordered loops - these come from array syntax
+ if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
+ auto unordered = loop.getUnordered();
+ if (!unordered)
+ return false;
+ return *unordered;
+ }
+ // We cannot parallelize anything else.
+ return false;
+}
+
+/// The getPerfectlyNested function is a generic utility for finding
+/// a single, "perfectly nested" operation within a parent operation.
+template <typename T>
+static T getPerfectlyNested(Operation *op) {
+ if (op->getNumRegions() != 1)
+ return nullptr;
+ auto &region = op->getRegion(0);
+ if (region.getBlocks().size() != 1)
+ return nullptr;
+ auto *block = &region.front();
+ auto *firstOp = &block->front();
+ if (auto nested = dyn_cast<T>(firstOp))
+ if (firstOp->getNextNode() == block->getTerminator())
+ return nested;
+ return nullptr;
+}
+
+/// verifyTargetTeamsWorkdistribute method verifies that
+/// omp.target { teams { workdistribute { ... } } } is well formed
+/// and fails for function calls that don't have lowering implemented yet.
+static LogicalResult
+verifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) {
+ OpBuilder rewriter(workdistribute);
+ auto loc = workdistribute->getLoc();
+ auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+ if (!teams) {
+ emitError(loc, "workdistribute not nested in teams\n");
+ return failure();
+ }
+ if (workdistribute.getRegion().getBlocks().size() != 1) {
+ emitError(loc, "workdistribute with multiple blocks\n");
+ return failure();
+ }
+ if (teams.getRegion().getBlocks().size() != 1) {
+ emitError(loc, "teams with multiple blocks\n");
+ return failure();
+ }
+
+ bool foundWorkdistribute = false;
+ for (auto &op : teams.getOps()) {
+ if (isa<omp::WorkdistributeOp>(op)) {
+ if (foundWorkdistribute) {
+ emitError(loc, "teams has multiple workdistribute ops.\n");
+ return failure();
+ }
+ foundWorkdistribute = true;
+ continue;
+ }
+ // Identify any omp dialect ops present before/after workdistribute.
+ if (op.getDialect() && isa<omp::OpenMPDialect>(op.getDialect()) &&
+ !isa<omp::TerminatorOp>(op)) {
+ emitError(loc, "teams has omp ops other than workdistribute. Lowering "
+ "not implemented yet.\n");
+ return failure();
+ }
+ }
+
+ omp::TargetOp targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp());
+ // return if not omp.target
+ if (!targetOp)
+ return success();
+
+ for (auto &op : workdistribute.getOps()) {
+ if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+ if (isRuntimeCall(&op)) {
+ auto funcName = (*callOp.getCallee()).getRootReference().getValue();
+ // _FortranAAssign is handled. Other runtime calls are not supported
+ // in omp.workdistribute yet.
+ if (funcName == FortranAssignStr)
+ continue;
+ else {
+ emitError(loc, "Runtime call " + funcName +
+ " lowering not supported for workdistribute yet.");
+ return failure();
+ }
+ }
+ }
+ }
+ return success();
+}
+
+/// fissionWorkdistribute method finds the parallelizable ops
+/// within teams {workdistribute} region and moves them to their
+/// own teams{workdistribute} region.
+///
+/// If B() and D() are parallelizable,
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///     C()
+///     D()
+///     E()
+///   }
+/// }
+///
+/// becomes
+///
+/// A()
+/// omp.teams {
+///   omp.workdistribute {
+///     B()
+///   }
+/// }
+/// C()
+/// omp.teams {
+///   omp.workdistribute {
+///     D()
+///   }
+/// }
+/// E()
+static FailureOr<bool>
+fissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
+ OpBuilder rewriter(workdistribute);
+ auto loc = workdistribute->getLoc();
+ auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+ auto *teamsBlock = &teams.getRegion().front();
+ bool changed = false;
+ // Move the ops inside teams and before workdistribute outside.
+ IRMapping irMapping;
+ llvm::SmallVector<Operation *> teamsHoisted;
+ for (auto &op : teams.getOps()) {
+ if (&op == workdistribute) {
+ break;
+ }
+ if (shouldParallelize(&op)) {
+      emitError(loc, "teams has parallelizable ops before first workdistribute\n");
+ return failure();
+ } else {
+ rewriter.setInsertionPoint(teams);
+ rewriter.clone(op, irMapping);
+ teamsHoisted.push_back(&op);
+ changed = true;
+ }
+ }
+ for (auto *op : llvm::reverse(teamsHoisted)) {
+ op->replaceAllUsesWith(irMapping.lookup(op));
+ op->erase();
+ }
+
+ // While we have unhandled operations in the original workdistribute
+ auto *workdistributeBlock = &workdistribute.getRegion().front();
+ auto *terminator = workdistributeBlock->getTerminator();
+ while (&workdistributeBlock->front() != terminator) {
+ rewriter.setInsertionPoint(teams);
+ IRMapping mapping;
+ llvm::SmallVector<Operation *> hoisted;
+ Operation *parallelize = nullptr;
+ for (auto &op : workdistribute.getOps()) {
+ if (&op == terminator) {
+ break;
+ }
+ if (shouldParallelize(&op)) {
+ parallelize = &op;
+ break;
+ } else {
+ rewriter.clone(op, mapping);
+ hoisted.push_back(&op);
+ changed = true;
+ }
+ }
+
+ for (auto *op : llvm::reverse(hoisted)) {
+ op->replaceAllUsesWith(mapping.lookup(op));
+ op->erase();
+ }
+
+ if (parallelize && hoisted.empty() &&
+ parallelize->getNextNode() == terminator)
+ break;
+ if (parallelize) {
+ auto newTeams = rewriter.cloneWithoutRegions(teams);
+ auto *newTeamsBlock = rewriter.createBlock(
+ &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {});
+ for (auto arg : teamsBlock->getArguments())
+ newTeamsBlock->addArgument(arg.getType(), arg.getLoc());
+ auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
+ rewriter.create<omp::TerminatorOp>(loc);
+ rewriter.createBlock(&newWorkdistribute.getRegion(),
+ newWorkdistribute.getRegion().begin(), {}, {});
+ auto *cloned = rewriter.clone(*parallelize);
+ parallelize->replaceAllUsesWith(cloned);
+ parallelize->erase();
+ rewriter.create<omp::TerminatorOp>(loc);
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+/// Generate omp.parallel operation with an empty region.
+static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
+ auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc);
+ parallelOp.setComposite(composite);
+ rewriter.createBlock(&parallelOp.getRegion());
+ rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
+ return;
+}
+
+/// Generate omp.distribute operation with an empty region.
+static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
+ mlir::omp::DistributeOperands distributeClauseOps;
+ auto distributeOp =
+ rewriter.create<mlir::omp::DistributeOp>(loc, distributeClauseOps);
+ distributeOp.setComposite(composite);
+ auto distributeBlock = rewriter.createBlock(&distributeOp.getRegion());
+ rewriter.setInsertionPointToStart(distributeBlock);
+ return;
+}
+
+/// Generate loop nest clause operands from fir.do_loop operation.
+static void
+genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
+ mlir::omp::LoopNestOperands &loopNestClauseOps) {
+ assert(loopNestClauseOps.loopLowerBounds.empty() &&
+ "Loop nest bounds were already emitted!");
+ loopNestClauseOps.loopLowerBounds.push_back(loop.getLowerBound());
+ loopNestClauseOps.loopUpperBounds.push_back(loop.getUpperBound());
+ loopNestClauseOps.loopSteps.push_back(loop.getStep());
+ loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
+}
+
+/// Generate omp.wsloop operation with an empty region and
+/// clone the body of fir.do_loop operation inside the loop nest region.
+static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
+ const mlir::omp::LoopNestOperands &clauseOps,
+ bool composite) {
+
+ auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+ wsloopOp.setComposite(composite);
+ rewriter.createBlock(&wsloopOp.getRegion());
+
+ auto loopNestOp =
+ rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+
+ // Clone the loop's body inside the loop nest construct using the
+ // mapped values.
+ rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(),
+ loopNestOp.getRegion().begin());
+ Block *clonedBlock = &loopNestOp.getRegion().back();
+ mlir::Operation *terminatorOp = clonedBlock->getTerminator();
+
+ // Erase fir.result op of do loop and create yield op.
+ if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) {
+ rewriter.setInsertionPoint(terminatorOp);
+ rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc());
+ terminatorOp->erase();
+ }
+}
+
+/// workdistributeDoLower method finds the fir.do_loop unordered
+/// nested in teams {workdistribute{fir.do_loop unordered}} and
+/// lowers it to teams {parallel {distribute {wsloop {loop_nest}}}}.
+///
+/// If fir.do_loop is present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     fir.do_loop unordered {
+///       ...
+///     }
+///   }
+/// }
+///
+/// Then, it is lowered to
+///
+/// omp.teams {
+///   omp.parallel {
+///     omp.distribute {
+///       omp.wsloop {
+///         omp.loop_nest {
+///           ...
+///         }
+///       }
+///     }
+///   }
+/// }
+static bool
+workdistributeDoLower(omp::WorkdistributeOp workdistribute,
+ SetVector<omp::TargetOp> &targetOpsToProcess) {
+ OpBuilder rewriter(workdistribute);
+ auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
+ auto wdLoc = workdistribute->getLoc();
+ if (doLoop && shouldParallelize(doLoop)) {
+ assert(doLoop.getReduceOperands().empty());
+
+ // Record the target ops to process later
+ if (auto teamsOp = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp())) {
+ auto targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp());
+ if (targetOp) {
+ targetOpsToProcess.insert(targetOp);
+ }
+ }
+ // Generate the nested parallel, distribute, wsloop and loop_nest ops.
+ genParallelOp(wdLoc, rewriter, true);
+ genDistributeOp(wdLoc, rewriter, true);
+ mlir::omp::LoopNestOperands loopNestClauseOps;
+ genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps);
+ genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true);
+ workdistribute.erase();
+ return true;
+ }
+ return false;
+}
+
+/// Check if the enclosed type in fir.ref is fir.box and fir.box encloses array
+static bool isEnclosedTypeRefToBoxArray(Type type) {
+ // Check if it's a reference type
+ if (auto refType = dyn_cast<fir::ReferenceType>(type)) {
+ // Get the referenced type (should be fir.box)
+ auto referencedType = refType.getEleTy();
+ // Check if referenced type is a box
+ if (auto boxType = dyn_cast<fir::BoxType>(referencedType)) {
+ // Get the boxed type and check if it's an array
+ auto boxedType = boxType.getEleTy();
+ // Check if boxed type is a sequence (array)
+ return isa<fir::SequenceType>(boxedType);
+ }
+ }
+ return false;
+}
+
+/// Check if the enclosed type in fir.box is scalar (not array)
+static bool isEnclosedTypeBoxScalar(Type type) {
+ // Check if it's a box type
+ if (auto boxType = dyn_cast<fir::BoxType>(type)) {
+ // Get the boxed type
+ auto boxedType = boxType.getEleTy();
+ // Check if boxed type is NOT a sequence (array)
+ return !isa<fir::SequenceType>(boxedType);
+ }
+ return false;
+}
+
+/// Check if the FortranAAssign call has src as scalar and dest as array
+static bool isFortranAssignSrcScalarAndDestArray(fir::CallOp callOp) {
+ if (callOp.getNumOperands() < 2)
+ return false;
+ auto srcArg = callOp.getOperand(1);
+ auto destArg = callOp.getOperand(0);
+ // Both operands should be fir.convert ops
+ auto srcConvert = srcArg.getDefiningOp<fir::ConvertOp>();
+ auto destConvert = destArg.getDefiningOp<fir::ConvertOp>();
+ if (!srcConvert || !destConvert) {
+ emitError(callOp->getLoc(),
+ "Unimplemented: FortranAssign to OpenMP lowering\n");
+ return false;
+ }
+ // Get the original types before conversion
+ auto srcOrigType = srcConvert.getValue().getType();
+ auto destOrigType = destConvert.getValue().getType();
+
+ // Check if src is scalar and dest is array
+ bool srcIsScalar = isEnclosedTypeBoxScalar(srcOrigType);
+ bool destIsArray = isEnclosedTypeRefToBoxArray(destOrigType);
+ return srcIsScalar && destIsArray;
+}
+
+/// Convert a flat index to multi-dimensional indices for an array box
+/// Example: 2D array with shape (2,4)
+/// Col 1 Col 2 Col 3 Col 4
+/// Row 1: (1,1) (1,2) (1,3) (1,4)
+/// Row 2: (2,1) (2,2) (2,3) (2,4)
+///
+/// extents: (2,4)
+///
+/// flatIdx: 0 1 2 3 4 5 6 7
+/// Indices: (1,1) (1,2) (1,3) (1,4) (2,1) (2,2) (2,3) (2,4)
+static SmallVector<Value> convertFlatToMultiDim(OpBuilder &builder,
+ Location loc, Value flatIdx,
+ Value arrayBox) {
+ // Get array type and rank
+ auto boxType = cast<fir::BoxType>(arrayBox.getType());
+ auto seqType = cast<fir::SequenceType>(boxType.getEleTy());
+ int rank = seqType.getDimension();
+
+ // Get all extents
+ SmallVector<Value> extents;
+ // Get extents for each dimension
+ for (int i = 0; i < rank; ++i) {
+ auto dimIdx = arith::ConstantIndexOp::create(builder, loc, i);
+ auto boxDims = fir::BoxDimsOp::create(builder, loc, arrayBox, dimIdx);
+ extents.push_back(boxDims.getResult(1));
+ }
+
+ // Convert flat index to multi-dimensional indices
+ SmallVector<Value> indices(rank);
+ Value temp = flatIdx;
+ auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1);
+
+ // Work backwards through dimensions (row-major order)
+ for (int i = rank - 1; i >= 0; --i) {
+ Value zeroBasedIdx = builder.create<arith::RemSIOp>(loc, temp, extents[i]);
+ // Convert to one-based index
+ indices[i] = builder.create<arith::AddIOp>(loc, zeroBasedIdx, c1);
+ if (i > 0) {
+ temp = builder.create<arith::DivSIOp>(loc, temp, extents[i]);
+ }
+ }
+
+ return indices;
+}
+
+/// Calculate the total number of elements in the array box
+/// (totalElems = extent(1) * extent(2) * ... * extent(n))
+static Value CalculateTotalElements(OpBuilder &builder, Location loc,
+ Value arrayBox) {
+ auto boxType = cast<fir::BoxType>(arrayBox.getType());
+ auto seqType = cast<fir::SequenceType>(boxType.getEleTy());
+ int rank = seqType.getDimension();
+
+ Value totalElems = nullptr;
+ for (int i = 0; i < rank; ++i) {
+ auto dimIdx = arith::ConstantIndexOp::create(builder, loc, i);
+ auto boxDims = fir::BoxDimsOp::create(builder, loc, arrayBox, dimIdx);
+ Value extent = boxDims.getResult(1);
+ if (i == 0) {
+ totalElems = extent;
+ } else {
+ totalElems = builder.create<arith::MulIOp>(loc, totalElems, extent);
+ }
+ }
+ return totalElems;
+}
+
+/// Replace the FortranAAssign runtime call with an unordered do loop
+static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc,
+ omp::TeamsOp teamsOp,
+ omp::WorkdistributeOp workdistribute,
+ fir::CallOp callOp) {
+ auto destConvert = callOp.getOperand(0).getDefiningOp<fir::ConvertOp>();
+ auto srcConvert = callOp.getOperand(1).getDefiningOp<fir::ConvertOp>();
+
+ Value destBox = destConvert.getValue();
+ Value srcBox = srcConvert.getValue();
+
+ // get defining alloca op of destBox and srcBox
+ auto destAlloca = destBox.getDefiningOp<fir::AllocaOp>();
+
+ if (!destAlloca) {
+ emitError(loc, "Unimplemented: FortranAssign to OpenMP lowering\n");
+ return;
+ }
+
+ // get the store op that stores to the alloca
+ for (auto user : destAlloca->getUsers()) {
+ if (auto storeOp = dyn_cast<fir::StoreOp>(user)) {
+ destBox = storeOp.getValue();
+ break;
+ }
+ }
+
+ builder.setInsertionPoint(teamsOp);
+ // Load destination array box (if it's a reference)
+ Value arrayBox = destBox;
+ if (isa<fir::ReferenceType>(destBox.getType()))
+ arrayBox = builder.create<fir::LoadOp>(loc, destBox);
+
+ auto scalarValue = builder.create<fir::BoxAddrOp>(loc, srcBox);
+ Value scalar = builder.create<fir::LoadOp>(loc, scalarValue);
+
+ // Calculate total number of elements (flattened)
+ auto c0 = builder.create<arith::ConstantIndexOp>(loc, 0);
+ auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1);
+ Value totalElems = CalculateTotalElements(builder, loc, arrayBox);
+
+ auto *workdistributeBlock = &workdistribute.getRegion().front();
+ builder.setInsertionPointToStart(workdistributeBlock);
+ // Create single unordered loop for flattened array
+ auto doLoop = fir::DoLoopOp::create(builder, loc, c0, totalElems, c1, true);
+ Block *loopBlock = &doLoop.getRegion().front();
+ builder.setInsertionPointToStart(doLoop.getBody());
+
+ auto flatIdx = loopBlock->getArgument(0);
+ SmallVector<Value> indices =
+ convertFlatToMultiDim(builder, loc, flatIdx, arrayBox);
+ // Use fir.array_coor for linear addressing
+ auto elemPtr = fir::ArrayCoorOp::create(
+ builder, loc, fir::ReferenceType::get(scalar.getType()), arrayBox,
+ nullptr, nullptr, ValueRange{indices}, ValueRange{});
+
+ builder.create<fir::StoreOp>(loc, scalar, elemPtr);
+}
+
+/// workdistributeRuntimeCallLower method finds the runtime calls
+/// nested in teams {workdistribute{}} and
+/// lowers FortranAAssign to an unordered do loop if src is scalar and dest is
+/// array. Other runtime calls are not handled currently.
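+///
+/// For a scalar-to-array assignment the generated loop is roughly:
+///
+///   fir.do_loop %i = %c0 to %total_elems step %c1 unordered {
+///     fir.store %scalar to <element %i of dest>
+///   }
+///
+/// where %total_elems is the product of the extents of the destination array.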
+static FailureOr<bool>
+workdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute,
+ SetVector<omp::TargetOp> &targetOpsToProcess) {
+ OpBuilder rewriter(workdistribute);
+ auto loc = workdistribute->getLoc();
+ auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+ if (!teams) {
+ emitError(loc, "workdistribute not nested in teams\n");
+ return failure();
+ }
+ if (workdistribute.getRegion().getBlocks().size() != 1) {
+ emitError(loc, "workdistribute with multiple blocks\n");
+ return failure();
+ }
+ if (teams.getRegion().getBlocks().size() != 1) {
+ emitError(loc, "teams with multiple blocks\n");
+ return failure();
+ }
+ bool changed = false;
+ // Get the target op parent of teams
+ omp::TargetOp targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp());
+ SmallVector<Operation *> opsToErase;
+ for (auto &op : workdistribute.getOps()) {
+ if (isRuntimeCall(&op)) {
+ rewriter.setInsertionPoint(&op);
+ fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+ auto funcName = runtimeCall.getCallee()->getRootReference().getValue();
+ if (funcName == FortranAssignStr) {
+ if (isFortranAssignSrcScalarAndDestArray(runtimeCall) && targetOp) {
+ // Record the target ops to process later
+ targetOpsToProcess.insert(targetOp);
+ replaceWithUnorderedDoLoop(rewriter, loc, teams, workdistribute,
+ runtimeCall);
+ opsToErase.push_back(&op);
+ changed = true;
+ }
+ }
+ }
+ }
+ // Erase the runtime calls that have been replaced.
+ for (auto *op : opsToErase) {
+ op->erase();
+ }
+ return changed;
+}
+
+/// teamsWorkdistributeToSingleOp method hoists all the ops inside
+/// teams {workdistribute{}} before teams op.
+///
+/// If A() and B() are present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///   }
+/// }
+///
+/// Then, it is lowered to
+///
+/// A()
+/// B()
+///
+/// If only the terminator remains in teams after hoisting, we erase teams op.
+static bool
+teamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp,
+ SetVector<omp::TargetOp> &targetOpsToProcess) {
+ auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+ if (!workdistributeOp)
+ return false;
+ // Get the block containing teamsOp (the parent block).
+ Block *parentBlock = teamsOp->getBlock();
+ Block &workdistributeBlock = *workdistributeOp.getRegion().begin();
+ // Record the target ops to process later
+ for (auto &op : workdistributeBlock.getOperations()) {
+ if (shouldParallelize(&op)) {
+ auto targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp());
+ if (targetOp) {
+ targetOpsToProcess.insert(targetOp);
+ }
+ }
+ }
+ auto insertPoint = Block::iterator(teamsOp);
+ // Get the range of operations to move (excluding the terminator).
+ auto workdistributeBegin = workdistributeBlock.begin();
+ auto workdistributeEnd = workdistributeBlock.getTerminator()->getIterator();
+ // Move the operations from workdistribute block to before teamsOp.
+ parentBlock->getOperations().splice(insertPoint,
+ workdistributeBlock.getOperations(),
+ workdistributeBegin, workdistributeEnd);
+ // Erase the now-empty workdistributeOp.
+ workdistributeOp.erase();
+ Block &teamsBlock = *teamsOp.getRegion().begin();
+ // Check if only the terminator remains and erase teams op.
+ if (teamsBlock.getOperations().size() == 1 &&
+ teamsBlock.getTerminator() != nullptr) {
+ teamsOp.erase();
+ }
+ return true;
+}
+
+/// If multiple workdistribute ops are nested in a target region, we will need
+/// to split the target region, but we want to preserve the data semantics of
+/// the original data region and avoid unnecessary data movement at each of the
+/// subkernels. We therefore split the target region into a target_data{target}
+/// nest where only the outer one moves the data.
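+///
+/// Schematically, by-ref map entries are placed on the outer target_data and
+/// re-created on the inner target with the map type cleared (by-copy entries
+/// stay on the inner target unchanged):
+///
+///   omp.target map(%a, %b) { ... }
+///
+/// becomes
+///
+///   omp.target_data map(%a, %b) {
+///     omp.target map(none: %a, %b) { ... }
+///   }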
+FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
+ RewriterBase &rewriter) {
+ auto loc = targetOp->getLoc();
+ if (targetOp.getMapVars().empty()) {
+ emitError(loc, "Target region has no data maps\n");
+ return failure();
+ }
+ // Collect all the mapinfo ops
+ SmallVector<omp::MapInfoOp> mapInfos;
+ for (auto opr : targetOp.getMapVars()) {
+ auto mapInfo = cast<omp::MapInfoOp>(opr.getDefiningOp());
+ mapInfos.push_back(mapInfo);
+ }
+
+ rewriter.setInsertionPoint(targetOp);
+ SmallVector<Value> innerMapInfos;
+ SmallVector<Value> outerMapInfos;
+ // Create new mapinfo ops for the inner target region
+ for (auto mapInfo : mapInfos) {
+ auto originalMapType =
+ (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType());
+ auto originalCaptureType = mapInfo.getMapCaptureType();
+ llvm::omp::OpenMPOffloadMappingFlags newMapType;
+ mlir::omp::VariableCaptureKind newCaptureType;
+ // For bycopy, we keep the same map type and capture type
+ // For byref, we change the map type to none and keep the capture type
+ if (originalCaptureType == mlir::omp::VariableCaptureKind::ByCopy) {
+ newMapType = originalMapType;
+ newCaptureType = originalCaptureType;
+ } else if (originalCaptureType == mlir::omp::VariableCaptureKind::ByRef) {
+ newMapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
+ newCaptureType = originalCaptureType;
+ outerMapInfos.push_back(mapInfo);
+ } else {
+ emitError(targetOp->getLoc(), "Unhandled case");
+ return failure();
+ }
+ auto innerMapInfo = cast<omp::MapInfoOp>(rewriter.clone(*mapInfo));
+ innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr(
+ rewriter.getIntegerType(64, false),
+ static_cast<
+ std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+ newMapType)));
+ innerMapInfo.setMapCaptureType(newCaptureType);
+ innerMapInfos.push_back(innerMapInfo.getResult());
+ }
+
+ rewriter.setInsertionPoint(targetOp);
+ auto device = targetOp.getDevice();
+ auto ifExpr = targetOp.getIfExpr();
+ auto deviceAddrVars = targetOp.getHasDeviceAddrVars();
+ auto devicePtrVars = targetOp.getIsDevicePtrVars();
+ // Create the target data op
+ auto targetDataOp = rewriter.create<omp::TargetDataOp>(
+ loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars);
+  auto targetDataBlock = rewriter.createBlock(&targetDataOp.getRegion());
+  rewriter.create<mlir::omp::TerminatorOp>(loc);
+  rewriter.setInsertionPointToStart(targetDataBlock);
+ // Create the inner target op
+ auto newTargetOp = rewriter.create<omp::TargetOp>(
+ targetOp.getLoc(), targetOp.getAllocateVars(),
+ targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+ targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+ targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+ targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+ targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+ targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+ innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+ targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+ targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+ rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
+ newTargetOp.getRegion().begin());
+ rewriter.replaceOp(targetOp, targetDataOp);
+ return newTargetOp;
+}
+
+/// getNestedOpToIsolate identifies a teams op within the body of an
+/// omp::TargetOp that should be "isolated". It returns a tuple containing the
+/// op, whether it is the first op in the target block, and whether it is the
+/// last op in the target block.
+static std::optional<std::tuple<Operation *, bool, bool>>
+getNestedOpToIsolate(omp::TargetOp targetOp) {
+ if (targetOp.getRegion().empty())
+ return std::nullopt;
+ auto *targetBlock = &targetOp.getRegion().front();
+ for (auto &op : *targetBlock) {
+ bool first = &op == &*targetBlock->begin();
+ bool last = op.getNextNode() == targetBlock->getTerminator();
+ if (first && last)
+ return std::nullopt;
+
+ if (isa<omp::TeamsOp>(&op))
+ return {{&op, first, last}};
+ }
+ return std::nullopt;
+}
+
+/// Temporary structure to hold the two mapinfo ops
+struct TempOmpVar {
+ omp::MapInfoOp from, to;
+};
+
+/// isPtr checks if the type is a pointer or reference type.
+static bool isPtr(Type ty) {
+ return isa<fir::ReferenceType>(ty) || isa<LLVM::LLVMPointerType>(ty);
+}
+
+/// getPtrTypeForOmp returns an LLVM pointer type for the given type.
+static Type getPtrTypeForOmp(Type ty) {
+ if (isPtr(ty))
+ return LLVM::LLVMPointerType::get(ty.getContext());
+ else
+ return fir::ReferenceType::get(ty);
+}
+
+/// allocateTempOmpVar allocates a temporary variable for OpenMP mapping
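+/// between the split target regions. It returns a pair of omp.map_info ops
+/// referring to the same allocation: a "from" entry (OMP_MAP_FROM) used by the
+/// region that stores the cached value, and a "to" entry (OMP_MAP_TO) used by
+/// the regions that reload it.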
+static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
+ RewriterBase &rewriter) {
+ MLIRContext &ctx = *ty.getContext();
+ Value alloc;
+ Type allocType;
+ auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+ // Get the appropriate type for allocation
+ if (isPtr(ty)) {
+ Type intTy = rewriter.getI32Type();
+ auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1);
+ allocType = llvmPtrTy;
+ alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one);
+ allocType = intTy;
+ } else {
+ allocType = ty;
+ alloc = rewriter.create<fir::AllocaOp>(loc, allocType);
+ }
+ // Lambda to create mapinfo ops
+ auto getMapInfo = [&](uint64_t mappingFlags, const char *name) {
+ return rewriter.create<omp::MapInfoOp>(
+ loc, alloc.getType(), alloc, TypeAttr::get(allocType),
+ rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false),
+ mappingFlags),
+ rewriter.getAttr<omp::VariableCaptureKindAttr>(
+ omp::VariableCaptureKind::ByRef),
+ /*varPtrPtr=*/Value{},
+ /*members=*/SmallVector<Value>{},
+ /*member_index=*/mlir::ArrayAttr{},
+ /*bounds=*/ValueRange(),
+ /*mapperId=*/mlir::FlatSymbolRefAttr(),
+ /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false));
+ };
+ // Create mapinfo ops.
+ uint64_t mapFrom =
+ static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
+ uint64_t mapTo =
+ static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
+ auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from");
+ auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to");
+ return TempOmpVar{mapInfoFrom, mapInfoTo};
+}
+
+/// usedOutsideSplit checks whether a value defined before the split point is
+/// used at or after the split operation.
+static bool usedOutsideSplit(Value v, Operation *split) {
+ if (!split)
+ return false;
+ auto targetOp = cast<omp::TargetOp>(split->getParentOp());
+ auto *targetBlock = &targetOp.getRegion().front();
+ for (auto *user : v.getUsers()) {
+ while (user->getBlock() != targetBlock) {
+ user = user->getParentOp();
+ }
+ if (!user->isBeforeInBlock(split))
+ return true;
+ }
+ return false;
+}
+
+/// isRecomputableAfterFission checks if an operation can be recomputed
+static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {
+ // If the op has side effects, it cannot be recomputed.
+ // We consider fir.declare as having no side effects.
+ return isa<fir::DeclareOp>(op) || isMemoryEffectFree(op);
+}
+
+/// collectNonRecomputableDeps collects dependencies that cannot be recomputed
+static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp,
+ SetVector<Operation *> &nonRecomputable,
+ SetVector<Operation *> &toCache,
+ SetVector<Operation *> &toRecompute) {
+ Operation *op = v.getDefiningOp();
+ // If v is a block argument, it must be from the targetOp.
+ if (!op) {
+ assert(cast<BlockArgument>(v).getOwner()->getParentOp() == targetOp);
+ return;
+ }
+ // If the op is in the nonRecomputable set, add it to toCache and return.
+ if (nonRecomputable.contains(op)) {
+ toCache.insert(op);
+ return;
+ }
+ // Add the op to toRecompute.
+ toRecompute.insert(op);
+ for (auto opr : op->getOperands())
+ collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache,
+ toRecompute);
+}
+
+/// createBlockArgsAndMap creates block arguments on the new target block and
+/// maps the corresponding original values to them.
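+/// The block argument layout is: host_eval vars first, then map vars, then
+/// private vars. Entries beyond the original targetOp operand counts
+/// correspond to the extra host_eval values and cached allocs added by the
+/// split.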
+static void createBlockArgsAndMap(Location loc, RewriterBase &rewriter,
+ omp::TargetOp &targetOp, Block *targetBlock,
+ Block *newTargetBlock,
+ SmallVector<Value> &hostEvalVars,
+ SmallVector<Value> &mapOperands,
+ SmallVector<Value> &allocs,
+ IRMapping &irMapping) {
+ // FIRST: Map `host_eval_vars` to block arguments
+ unsigned originalHostEvalVarsSize = targetOp.getHostEvalVars().size();
+ for (unsigned i = 0; i < hostEvalVars.size(); ++i) {
+ Value originalValue;
+ BlockArgument newArg;
+ if (i < originalHostEvalVarsSize) {
+ originalValue = targetBlock->getArgument(i); // Host_eval args come first
+ newArg = newTargetBlock->addArgument(originalValue.getType(),
+ originalValue.getLoc());
+ } else {
+ originalValue = hostEvalVars[i];
+ newArg = newTargetBlock->addArgument(originalValue.getType(),
+ originalValue.getLoc());
+ }
+ irMapping.map(originalValue, newArg);
+ }
+
+ // SECOND: Map `map_operands` to block arguments
+ unsigned originalMapVarsSize = targetOp.getMapVars().size();
+ for (unsigned i = 0; i < mapOperands.size(); ++i) {
+ Value originalValue;
+ BlockArgument newArg;
+ // Map the new arguments from the original block.
+ if (i < originalMapVarsSize) {
+ originalValue = targetBlock->getArgument(originalHostEvalVarsSize +
+ i); // Offset by host_eval count
+ newArg = newTargetBlock->addArgument(originalValue.getType(),
+ originalValue.getLoc());
+ }
+ // Map the new arguments from the `allocs`.
+ else {
+ originalValue = allocs[i - originalMapVarsSize];
+ newArg = newTargetBlock->addArgument(
+ getPtrTypeForOmp(originalValue.getType()), originalValue.getLoc());
+ }
+ irMapping.map(originalValue, newArg);
+ }
+
+ // THIRD: Map `private_vars` to block arguments (if any)
+ unsigned originalPrivateVarsSize = targetOp.getPrivateVars().size();
+ for (unsigned i = 0; i < originalPrivateVarsSize; ++i) {
+ auto originalArg = targetBlock->getArgument(originalHostEvalVarsSize +
+ originalMapVarsSize + i);
+ auto newArg = newTargetBlock->addArgument(originalArg.getType(),
+ originalArg.getLoc());
+ irMapping.map(originalArg, newArg);
+ }
+ return;
+}
+
+/// reloadCacheAndRecompute reloads cached values and recomputes operations
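+/// that are needed at or after the split point. Each cached alloc is loaded
+/// from its corresponding block argument (llvm.load plus fir.convert for
+/// pointer-like values, fir.load otherwise), and the recomputable ops that
+/// precede the split point are cloned into the new block.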
+static void reloadCacheAndRecompute(
+ Location loc, RewriterBase &rewriter, Operation *splitBefore,
+ omp::TargetOp &targetOp, Block *targetBlock, Block *newTargetBlock,
+ SmallVector<Value> &hostEvalVars, SmallVector<Value> &mapOperands,
+ SmallVector<Value> &allocs, SetVector<Operation *> &toRecompute,
+ IRMapping &irMapping) {
+ // Handle the load operations for the allocs.
+ rewriter.setInsertionPointToStart(newTargetBlock);
+ auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
+
+ unsigned originalMapVarsSize = targetOp.getMapVars().size();
+ unsigned hostEvalVarsSize = hostEvalVars.size();
+ // Create load operations for each allocated variable.
+ for (unsigned i = 0; i < allocs.size(); ++i) {
+ Value original = allocs[i];
+ // Get the new block argument for this specific allocated value.
+ Value newArg =
+ newTargetBlock->getArgument(hostEvalVarsSize + originalMapVarsSize + i);
+ Value restored;
+ // If the original value is a pointer or reference, load and convert if
+ // necessary.
+ if (isPtr(original.getType())) {
+ restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg);
+ if (!isa<LLVM::LLVMPointerType>(original.getType()))
+ restored =
+ rewriter.create<fir::ConvertOp>(loc, original.getType(), restored);
+ } else {
+ restored = rewriter.create<fir::LoadOp>(loc, newArg);
+ }
+ irMapping.map(original, restored);
+ }
+ // Clone the operations if they are in the toRecompute set.
+ for (auto it = targetBlock->begin(); it != splitBefore->getIterator(); it++) {
+ if (toRecompute.contains(&*it))
+ rewriter.clone(*it, irMapping);
+ }
+}
+
+/// Given a teamsOp, navigate down the nested structure to find the
+/// innermost LoopNestOp. The expected nesting is:
+/// teams -> parallel -> distribute -> wsloop -> loop_nest
+static mlir::omp::LoopNestOp getLoopNestFromTeams(mlir::omp::TeamsOp teamsOp) {
+ if (teamsOp.getRegion().empty())
+ return nullptr;
+ // Ensure the teams region has a single block.
+ if (teamsOp.getRegion().getBlocks().size() != 1)
+ return nullptr;
+ // Find parallel op inside teams
+ mlir::omp::ParallelOp parallelOp = nullptr;
+ // Look for the parallel op in the teams region
+ for (auto &op : teamsOp.getRegion().front()) {
+ if (auto parallel = dyn_cast<mlir::omp::ParallelOp>(op)) {
+ parallelOp = parallel;
+ break;
+ }
+ }
+ if (!parallelOp)
+ return nullptr;
+
+ // Find distribute op inside parallel
+ mlir::omp::DistributeOp distributeOp = nullptr;
+ for (auto &op : parallelOp.getRegion().front()) {
+ if (auto distribute = dyn_cast<mlir::omp::DistributeOp>(op)) {
+ distributeOp = distribute;
+ break;
+ }
+ }
+ if (!distributeOp)
+ return nullptr;
+
+ // Find wsloop op inside distribute
+ mlir::omp::WsloopOp wsloopOp = nullptr;
+ for (auto &op : distributeOp.getRegion().front()) {
+ if (auto wsloop = dyn_cast<mlir::omp::WsloopOp>(op)) {
+ wsloopOp = wsloop;
+ break;
+ }
+ }
+ if (!wsloopOp)
+ return nullptr;
+
+ // Find loop_nest op inside wsloop
+ for (auto &op : wsloopOp.getRegion().front()) {
+ if (auto loopNest = dyn_cast<mlir::omp::LoopNestOp>(op)) {
+ return loopNest;
+ }
+ }
+
+ return nullptr;
+}
+
+/// Generate an LLVM i32 constant operation.
+static mlir::LLVM::ConstantOp
+genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) {
+ mlir::Type i32Ty = rewriter.getI32Type();
+ mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value);
+ return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
+}
+
+/// Given a box descriptor, extract the base address of the data it describes.
+/// If the box descriptor is a reference, load it first.
+/// The base address is returned as an i8* pointer.
+static Value genDescriptorGetBaseAddress(fir::FirOpBuilder &builder,
+ Location loc, Value boxDesc) {
+ Value box = boxDesc;
+ if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+ box = fir::LoadOp::create(builder, loc, boxDesc);
+ }
+ assert(isa<fir::BoxType>(box.getType()) &&
+ "Unknown type passed to genDescriptorGetBaseAddress");
+ auto i8Type = builder.getI8Type();
+ auto unknownArrayType =
+ fir::SequenceType::get({fir::SequenceType::getUnknownExtent()}, i8Type);
+ auto i8BoxType = fir::BoxType::get(unknownArrayType);
+ auto typedBox = fir::ConvertOp::create(builder, loc, i8BoxType, box);
+ auto rawAddr = fir::BoxAddrOp::create(builder, loc, typedBox);
+ return rawAddr;
+}
+
+/// Given a box descriptor, extract the total number of elements in the array it
+/// describes. If the box descriptor is a reference, load it first.
+/// The total number of elements is returned as an i64 value.
+static Value genDescriptorGetTotalElements(fir::FirOpBuilder &builder,
+ Location loc, Value boxDesc) {
+ Value box = boxDesc;
+ if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+ box = fir::LoadOp::create(builder, loc, boxDesc);
+ }
+ assert(isa<fir::BoxType>(box.getType()) &&
+ "Unknown type passed to genDescriptorGetTotalElements");
+ auto i64Type = builder.getI64Type();
+ return fir::BoxTotalElementsOp::create(builder, loc, i64Type, box);
+}
+
+/// Given a box descriptor, extract the size of each element in the array it
+/// describes. If the box descriptor is a reference, load it first.
+/// The element size is returned as an i64 value.
+static Value genDescriptorGetEleSize(fir::FirOpBuilder &builder, Location loc,
+ Value boxDesc) {
+ Value box = boxDesc;
+ if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+ box = fir::LoadOp::create(builder, loc, boxDesc);
+ }
+ assert(isa<fir::BoxType>(box.getType()) &&
+ "Unknown type passed to genDescriptorGetElementSize");
+ auto i64Type = builder.getI64Type();
+ return fir::BoxEleSizeOp::create(builder, loc, i64Type, box);
+}
+
+/// Given a box descriptor, compute the total size in bytes of the data it
+/// describes. This is done by multiplying the total number of elements by the
+/// size of each element. If the box descriptor is a reference, load it first.
+/// The total size in bytes is returned as an i64 value.
+static Value genDescriptorGetDataSizeInBytes(fir::FirOpBuilder &builder,
+ Location loc, Value boxDesc) {
+ Value box = boxDesc;
+ if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+ box = fir::LoadOp::create(builder, loc, boxDesc);
+ }
+ assert(isa<fir::BoxType>(box.getType()) &&
+ "Unknown type passed to genDescriptorGetElementSize");
+ Value eleSize = genDescriptorGetEleSize(builder, loc, box);
+ Value totalElements = genDescriptorGetTotalElements(builder, loc, box);
+ return mlir::arith::MulIOp::create(builder, loc, totalElements, eleSize);
+}
+
+/// Generate a call to the OpenMP runtime function `omp_get_mapped_ptr` to
+/// retrieve the device pointer corresponding to a given host pointer and device
+/// number. If no mapping exists, the original host pointer is returned.
+/// Signature:
+/// void *omp_get_mapped_ptr(void *host_ptr, int device_num);
+static mlir::Value genOmpGetMappedPtrIfPresent(fir::FirOpBuilder &builder,
+ mlir::Location loc,
+ mlir::Value hostPtr,
+ mlir::Value deviceNum,
+ mlir::ModuleOp module) {
+ auto *context = builder.getContext();
+ auto voidPtrType = fir::LLVMPointerType::get(context, builder.getI8Type());
+ auto i32Type = builder.getI32Type();
+ auto funcName = "omp_get_mapped_ptr";
+ auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(funcName);
+
+ if (!funcOp) {
+ auto funcType =
+ mlir::FunctionType::get(context, {voidPtrType, i32Type}, {voidPtrType});
+
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(module.getBody());
+
+ funcOp = mlir::func::FuncOp::create(builder, loc, funcName, funcType);
+ funcOp.setPrivate();
+ }
+
+ llvm::SmallVector<mlir::Value> args;
+ args.push_back(fir::ConvertOp::create(builder, loc, voidPtrType, hostPtr));
+ args.push_back(fir::ConvertOp::create(builder, loc, i32Type, deviceNum));
+ auto callOp = fir::CallOp::create(builder, loc, funcOp, args);
+ auto mappedPtr = callOp.getResult(0);
+ auto isNull = builder.genIsNullAddr(loc, mappedPtr);
+ auto convertedHostPtr =
+ fir::ConvertOp::create(builder, loc, voidPtrType, hostPtr);
+ auto result = arith::SelectOp::create(builder, loc, isNull, convertedHostPtr,
+ mappedPtr);
+ return result;
+}
+
+/// Generate a call to the OpenMP runtime function `omp_target_memcpy` to
+/// perform memory copy between host and device or between devices.
+/// Signature:
+/// int omp_target_memcpy(void *dst, const void *src, size_t length,
+/// size_t dst_offset, size_t src_offset,
+/// int dst_device, int src_device);
+static void genOmpTargetMemcpyCall(fir::FirOpBuilder &builder,
+ mlir::Location loc, mlir::Value dst,
+ mlir::Value src, mlir::Value length,
+ mlir::Value dstOffset, mlir::Value srcOffset,
+ mlir::Value device, mlir::ModuleOp module) {
+ auto *context = builder.getContext();
+ auto funcName = "omp_target_memcpy";
+ auto voidPtrType = fir::LLVMPointerType::get(context, builder.getI8Type());
+ auto sizeTType = builder.getI64Type(); // assuming size_t is 64-bit
+ auto i32Type = builder.getI32Type();
+ auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(funcName);
+
+ if (!funcOp) {
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(module.getBody());
+ llvm::SmallVector<mlir::Type> argTypes = {
+ voidPtrType, voidPtrType, sizeTType, sizeTType,
+ sizeTType, i32Type, i32Type};
+ auto funcType = mlir::FunctionType::get(context, argTypes, {i32Type});
+ funcOp = mlir::func::FuncOp::create(builder, loc, funcName, funcType);
+ funcOp.setPrivate();
+ }
+
+ llvm::SmallVector<mlir::Value> args{dst, src, length, dstOffset,
+ srcOffset, device, device};
+ fir::CallOp::create(builder, loc, funcOp, args);
+ return;
+}
+
+/// Generate code to replace a Fortran array assignment call with OpenMP
+/// runtime calls to perform the equivalent operation on the device.
+/// This involves extracting the source and destination pointers from the
+/// Fortran array descriptors, retrieving their mapped device pointers (if any),
+/// and invoking `omp_target_memcpy` to copy the data on the device.
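+/// Roughly (illustrative sketch, both offsets are zero):
+///   omp_target_memcpy(omp_get_mapped_ptr(dst_base, dev),
+///                     omp_get_mapped_ptr(src_base, dev),
+///                     src_elem_size * src_num_elements, 0, 0, dev, dev)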
+static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder,
+ mlir::Location loc,
+ fir::CallOp callOp,
+ mlir::Value device,
+ mlir::ModuleOp module) {
+ assert(callOp.getNumResults() == 0 &&
+ "Expected _FortranAAssign to have no results");
+ assert(callOp.getNumOperands() >= 2 &&
+ "Expected _FortranAAssign to have at least two operands");
+
+ // Extract the source and destination pointers from the call operands.
+ mlir::Value dest = callOp.getOperand(0);
+ mlir::Value src = callOp.getOperand(1);
+
+ // Get the base addresses of the source and destination arrays.
+ mlir::Value srcBase = genDescriptorGetBaseAddress(builder, loc, src);
+ mlir::Value destBase = genDescriptorGetBaseAddress(builder, loc, dest);
+
+ // Get the total size in bytes of the data to be copied.
+ mlir::Value srcDataSize = genDescriptorGetDataSizeInBytes(builder, loc, src);
+
+ // Retrieve the mapped device pointers for source and destination.
+ // If no mapping exists, the original host pointer is used.
+ Value destPtr =
+ genOmpGetMappedPtrIfPresent(builder, loc, destBase, device, module);
+ Value srcPtr =
+ genOmpGetMappedPtrIfPresent(builder, loc, srcBase, device, module);
+ Value zero = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
+ builder.getI64IntegerAttr(0));
+
+ // Generate the call to omp_target_memcpy to perform the data copy on the
+ // device.
+ genOmpTargetMemcpyCall(builder, loc, destPtr, srcPtr, srcDataSize, zero, zero,
+ device, module);
+}
+
+/// Struct to hold the host eval vars corresponding to loop bounds and steps
+struct HostEvalVars {
+ SmallVector<Value> lbs;
+ SmallVector<Value> ubs;
+ SmallVector<Value> steps;
+};
+
+/// moveToHost clones all ops from the target region out into the enclosing
+/// omp.target_data region and erases the original omp.target op. It replaces
+/// the hoisted runtime call "_FortranAAssign" with an OpenMP runtime based
+/// version, fir.allocmem with omp.target_allocmem, and fir.freemem with
+/// omp.target_freemem.
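+/// For example (pseudo-IR sketch, exact assembly syntax omitted), a hoisted
+///   %m = fir.allocmem !fir.array<?xf32>, %n
+/// becomes on the host
+///   %i = omp.target_allocmem <device> !fir.array<?xf32>, %n   // yields i64
+///   %m = fir.convert %i : (i64) -> !fir.heap<!fir.array<?xf32>>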
+static LogicalResult moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
+ mlir::ModuleOp module,
+ struct HostEvalVars &hostEvalVars) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ Block *targetBlock = &targetOp.getRegion().front();
+ assert(targetBlock == &targetOp.getRegion().back());
+ IRMapping mapping;
+
+ // Get the parent target_data op
+  auto targetDataOp = dyn_cast<omp::TargetDataOp>(targetOp->getParentOp());
+ if (!targetDataOp) {
+ emitError(targetOp->getLoc(),
+ "Expected target op to be inside target_data op");
+ return failure();
+ }
+ // create mapping for host_eval_vars
+ unsigned hostEvalVarCount = targetOp.getHostEvalVars().size();
+ for (unsigned i = 0; i < targetOp.getHostEvalVars().size(); ++i) {
+ Value hostEvalVar = targetOp.getHostEvalVars()[i];
+ BlockArgument arg = targetBlock->getArguments()[i];
+ mapping.map(arg, hostEvalVar);
+ }
+ // create mapping for map_vars
+ for (unsigned i = 0; i < targetOp.getMapVars().size(); ++i) {
+ Value mapInfo = targetOp.getMapVars()[i];
+ BlockArgument arg = targetBlock->getArguments()[hostEvalVarCount + i];
+ Operation *op = mapInfo.getDefiningOp();
+ assert(op);
+ auto mapInfoOp = cast<omp::MapInfoOp>(op);
+ // map the block argument to the host-side variable pointer
+ mapping.map(arg, mapInfoOp.getVarPtr());
+ }
+ // create mapping for private_vars
+ unsigned mapSize = targetOp.getMapVars().size();
+ for (unsigned i = 0; i < targetOp.getPrivateVars().size(); ++i) {
+ Value privateVar = targetOp.getPrivateVars()[i];
+ // The mapping should link the device-side variable to the host-side one.
+ BlockArgument arg =
+ targetBlock->getArguments()[hostEvalVarCount + mapSize + i];
+ // Map the device-side copy (`arg`) to the host-side value (`privateVar`).
+ mapping.map(arg, privateVar);
+ }
+
+ rewriter.setInsertionPoint(targetOp);
+ SmallVector<Operation *> opsToReplace;
+ Value device = targetOp.getDevice();
+
+ // If device is not specified, default to device 0.
+ if (!device) {
+ device = genI32Constant(targetOp.getLoc(), rewriter, 0);
+ }
+ // Clone all operations.
+ for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end());
+ it != end; ++it) {
+ auto *op = &*it;
+ Operation *clonedOp = rewriter.clone(*op, mapping);
+ // Map the results of the original op to the cloned op.
+ for (unsigned i = 0; i < op->getNumResults(); ++i) {
+ mapping.map(op->getResult(i), clonedOp->getResult(i));
+ }
+    // fir.declare changes its type when hoisted out of omp.target into
+    // omp.target_data. Introduce a load if the original declareOp input is not
+    // of reference type but the cloned declareOp input is.
+ if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
+ auto originalDeclareOp = cast<fir::DeclareOp>(op);
+ Type originalInType = originalDeclareOp.getMemref().getType();
+ Type clonedInType = clonedDeclareOp.getMemref().getType();
+
+ fir::ReferenceType originalRefType =
+ dyn_cast<fir::ReferenceType>(originalInType);
+ fir::ReferenceType clonedRefType =
+ dyn_cast<fir::ReferenceType>(clonedInType);
+ if (!originalRefType && clonedRefType) {
+ Type clonedEleTy = clonedRefType.getElementType();
+ if (clonedEleTy == originalDeclareOp.getType()) {
+ opsToReplace.push_back(clonedOp);
+ }
+ }
+ }
+ // Collect the ops to be replaced.
+ if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
+ opsToReplace.push_back(clonedOp);
+ // Check for runtime calls to be replaced.
+ if (isRuntimeCall(clonedOp)) {
+ fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+ auto funcName = runtimeCall.getCallee()->getRootReference().getValue();
+ if (funcName == FortranAssignStr) {
+ opsToReplace.push_back(clonedOp);
+ } else {
+ emitError(runtimeCall->getLoc(), "Unhandled runtime call hoisting.");
+ return failure();
+ }
+ }
+ }
+ // Replace fir.allocmem with omp.target_allocmem.
+ for (Operation *op : opsToReplace) {
+ if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
+ rewriter.setInsertionPoint(allocOp);
+ auto ompAllocmemOp = rewriter.create<omp::TargetAllocMemOp>(
+ allocOp.getLoc(), rewriter.getI64Type(), device,
+ allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(),
+ allocOp.getBindcNameAttr(), allocOp.getTypeparams(),
+ allocOp.getShape());
+ auto firConvertOp = rewriter.create<fir::ConvertOp>(
+ allocOp.getLoc(), allocOp.getResult().getType(),
+ ompAllocmemOp.getResult());
+ rewriter.replaceOp(allocOp, firConvertOp.getResult());
+ }
+ // Replace fir.freemem with omp.target_freemem.
+ else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) {
+ rewriter.setInsertionPoint(freeOp);
+ auto firConvertOp = rewriter.create<fir::ConvertOp>(
+ freeOp.getLoc(), rewriter.getI64Type(), freeOp.getHeapref());
+ rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device,
+ firConvertOp.getResult());
+ rewriter.eraseOp(freeOp);
+ }
+  // fir.declare changes its type when hoisted out of omp.target into
+  // omp.target_data. Introduce a load if the original declareOp input is not
+  // of reference type but the cloned declareOp input is.
+ else if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(op)) {
+ Type clonedInType = clonedDeclareOp.getMemref().getType();
+ fir::ReferenceType clonedRefType =
+ dyn_cast<fir::ReferenceType>(clonedInType);
+ Type clonedEleTy = clonedRefType.getElementType();
+ rewriter.setInsertionPoint(op);
+ Value loadedValue = rewriter.create<fir::LoadOp>(
+ clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref());
+ clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue);
+ }
+ // Replace runtime calls with omp versions.
+ else if (isRuntimeCall(op)) {
+ fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+ auto funcName = runtimeCall.getCallee()->getRootReference().getValue();
+ if (funcName == FortranAssignStr) {
+ rewriter.setInsertionPoint(op);
+ fir::FirOpBuilder builder{rewriter, op};
+
+ mlir::Location loc = runtimeCall.getLoc();
+ genFortranAssignOmpReplacement(builder, loc, runtimeCall, device,
+ module);
+ rewriter.eraseOp(op);
+ } else {
+ emitError(runtimeCall->getLoc(), "Unhandled runtime call hoisting.");
+ return failure();
+ }
+ } else {
+ emitError(op->getLoc(), "Unhandled op hoisting.");
+ return failure();
+ }
+ }
+
+ // Update the host_eval_vars to use the mapped values.
+ for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i) {
+ hostEvalVars.lbs[i] = mapping.lookup(hostEvalVars.lbs[i]);
+ hostEvalVars.ubs[i] = mapping.lookup(hostEvalVars.ubs[i]);
+ hostEvalVars.steps[i] = mapping.lookup(hostEvalVars.steps[i]);
+ }
+ // Finally erase the original targetOp.
+ rewriter.eraseOp(targetOp);
+ return success();
+}
+
+/// Result of isolateOp method
+struct SplitResult {
+ omp::TargetOp preTargetOp;
+ omp::TargetOp isolatedTargetOp;
+ omp::TargetOp postTargetOp;
+};
+
+/// computeAllocsCacheRecomputable method computes the allocs needed to cache
+/// the values that are used outside the split point. It also computes the ops
+/// that need to be cached and the ops that can be recomputed after the split.
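+/// Values defined before the split point and used at or after it must either
+/// be recomputed in the later regions (side-effect free ops) or cached in a
+/// temporary that is mapped out of the pre region and back into the later
+/// regions.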
+static void computeAllocsCacheRecomputable(
+ omp::TargetOp targetOp, Operation *splitBeforeOp, RewriterBase &rewriter,
+ SmallVector<Value> &preMapOperands, SmallVector<Value> &postMapOperands,
+ SmallVector<Value> &allocs, SmallVector<Value> &requiredVals,
+ SetVector<Operation *> &nonRecomputable, SetVector<Operation *> &toCache,
+ SetVector<Operation *> &toRecompute) {
+ auto *targetBlock = &targetOp.getRegion().front();
+ // Find all values that are used outside the split point.
+ for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+ it++) {
+ // Check if any of the results are used outside the split point.
+ for (auto res : it->getResults()) {
+ if (usedOutsideSplit(res, splitBeforeOp)) {
+ requiredVals.push_back(res);
+ }
+ }
+ // If the op is not recomputable, add it to the nonRecomputable set.
+ if (!isRecomputableAfterFission(&*it, splitBeforeOp)) {
+ nonRecomputable.insert(&*it);
+ }
+ }
+ // For each required value, collect its dependencies.
+ for (auto requiredVal : requiredVals)
+ collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache,
+ toRecompute);
+ // For each op in toCache, create an alloc and update the pre and post map
+ // operands.
+ for (Operation *op : toCache) {
+ for (auto res : op->getResults()) {
+ auto alloc =
+ allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter);
+ allocs.push_back(res);
+ preMapOperands.push_back(alloc.from);
+ postMapOperands.push_back(alloc.to);
+ }
+ }
+}
+
+/// genPreTargetOp method generates the preTargetOp that contains all the ops
+/// before the split point. It also creates the block arguments and maps the
+/// values accordingly. It also creates the store operations for the allocs.
+static omp::TargetOp
+genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
+ SmallVector<Value> &allocs, Operation *splitBeforeOp,
+ RewriterBase &rewriter, struct HostEvalVars &hostEvalVars,
+ bool isTargetDevice) {
+ auto loc = targetOp.getLoc();
+ auto *targetBlock = &targetOp.getRegion().front();
+ SmallVector<Value> preHostEvalVars{targetOp.getHostEvalVars()};
+ // update the hostEvalVars of preTargetOp
+ omp::TargetOp preTargetOp = rewriter.create<omp::TargetOp>(
+ targetOp.getLoc(), targetOp.getAllocateVars(),
+ targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+ targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+ targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), preHostEvalVars,
+ targetOp.getIfExpr(), targetOp.getInReductionVars(),
+ targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
+ targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(),
+ targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
+ targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
+ targetOp.getPrivateMapsAttr());
+ auto *preTargetBlock = rewriter.createBlock(
+ &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
+ IRMapping preMapping;
+ // Create block arguments and map the values.
+ createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, preTargetBlock,
+ preHostEvalVars, preMapOperands, allocs, preMapping);
+
+ // Handle the store operations for the allocs.
+ rewriter.setInsertionPointToStart(preTargetBlock);
+ auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext());
+
+ // Clone the original operations.
+ for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator();
+ it++) {
+ rewriter.clone(*it, preMapping);
+ }
+
+ unsigned originalHostEvalVarsSize = preHostEvalVars.size();
+ unsigned originalMapVarsSize = targetOp.getMapVars().size();
+ // Create Stores for allocs.
+ for (unsigned i = 0; i < allocs.size(); ++i) {
+ Value originalResult = allocs[i];
+ Value toStore = preMapping.lookup(originalResult);
+ // Get the new block argument for this specific allocated value.
+ Value newArg = preTargetBlock->getArgument(originalHostEvalVarsSize +
+ originalMapVarsSize + i);
+ // Create the store operation.
+ if (isPtr(originalResult.getType())) {
+ if (!isa<LLVM::LLVMPointerType>(toStore.getType()))
+ toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore);
+ rewriter.create<LLVM::StoreOp>(loc, toStore, newArg);
+ } else {
+ rewriter.create<fir::StoreOp>(loc, toStore, newArg);
+ }
+ }
+ rewriter.create<omp::TerminatorOp>(loc);
+
+ // Update hostEvalVars with the mapped values for the loop bounds if we have
+ // a loopNestOp and we are not generating code for the target device.
+ omp::LoopNestOp loopNestOp =
+ getLoopNestFromTeams(cast<omp::TeamsOp>(splitBeforeOp));
+ if (loopNestOp && !isTargetDevice) {
+ for (size_t i = 0; i < loopNestOp.getLoopLowerBounds().size(); ++i) {
+ Value lb = loopNestOp.getLoopLowerBounds()[i];
+ Value ub = loopNestOp.getLoopUpperBounds()[i];
+ Value step = loopNestOp.getLoopSteps()[i];
+
+ hostEvalVars.lbs.push_back(preMapping.lookup(lb));
+ hostEvalVars.ubs.push_back(preMapping.lookup(ub));
+ hostEvalVars.steps.push_back(preMapping.lookup(step));
+ }
+ }
+
+ return preTargetOp;
+}
+
+/// genIsolatedTargetOp method generates the isolatedTargetOp that contains the
+/// op at the split point. It also creates the block arguments and maps
+/// the values accordingly. It also creates the load operations for the allocs
+/// and recomputes the necessary ops.
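+/// When host-evaluated loop bounds were collected by genPreTargetOp, they are
+/// appended to the isolated op's host_eval operands and the cloned loop_nest
+/// bounds are rewired to the corresponding new block arguments.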
+static omp::TargetOp
+genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
+ Operation *splitBeforeOp, RewriterBase &rewriter,
+ SmallVector<Value> &allocs,
+ SetVector<Operation *> &toRecompute,
+ struct HostEvalVars &hostEvalVars, bool isTargetDevice) {
+ auto loc = targetOp.getLoc();
+ auto *targetBlock = &targetOp.getRegion().front();
+ SmallVector<Value> isolatedHostEvalVars{targetOp.getHostEvalVars()};
+ // update the hostEvalVars of isolatedTargetOp
+ if (!hostEvalVars.lbs.empty() && !isTargetDevice) {
+ isolatedHostEvalVars.append(hostEvalVars.lbs.begin(),
+ hostEvalVars.lbs.end());
+ isolatedHostEvalVars.append(hostEvalVars.ubs.begin(),
+ hostEvalVars.ubs.end());
+ isolatedHostEvalVars.append(hostEvalVars.steps.begin(),
+ hostEvalVars.steps.end());
+ }
+ // Create the isolated target op
+ omp::TargetOp isolatedTargetOp = rewriter.create<omp::TargetOp>(
+ targetOp.getLoc(), targetOp.getAllocateVars(),
+ targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+ targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+ targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+ isolatedHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(),
+ targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
+ targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
+ targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
+ targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
+ targetOp.getPrivateMapsAttr());
+ auto *isolatedTargetBlock =
+ rewriter.createBlock(&isolatedTargetOp.getRegion(),
+ isolatedTargetOp.getRegion().begin(), {}, {});
+ IRMapping isolatedMapping;
+ // Create block arguments and map the values.
+ createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock,
+ isolatedTargetBlock, isolatedHostEvalVars,
+ postMapOperands, allocs, isolatedMapping);
+ // Handle the load operations for the allocs and recompute ops.
+ reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock,
+ isolatedTargetBlock, isolatedHostEvalVars,
+ postMapOperands, allocs, toRecompute,
+ isolatedMapping);
+
+ // Clone the original operations.
+ rewriter.clone(*splitBeforeOp, isolatedMapping);
+ rewriter.create<omp::TerminatorOp>(loc);
+
+ // update the loop bounds in the isolatedTargetOp if we have host_eval vars
+ // and we are not generating code for the target device.
+ if (!hostEvalVars.lbs.empty() && !isTargetDevice) {
+ omp::TeamsOp teamsOp;
+ for (auto &op : *isolatedTargetBlock) {
+ if (isa<omp::TeamsOp>(&op))
+ teamsOp = cast<omp::TeamsOp>(&op);
+ }
+ assert(teamsOp && "No teamsOp found in isolated target region");
+ // Get the loopNestOp inside the teamsOp
+ auto loopNestOp = getLoopNestFromTeams(teamsOp);
+ // Get the BlockArgs related to host_eval vars and update loop_nest bounds
+ // to them
+ unsigned originalHostEvalVarsSize = targetOp.getHostEvalVars().size();
+ unsigned index = originalHostEvalVarsSize;
+ // Replace loop bounds with the block arguments passed down via host_eval
+ SmallVector<Value> lbs, ubs, steps;
+
+ // Collect new lb/ub/step values from target block args
+ for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i)
+ lbs.push_back(isolatedTargetBlock->getArgument(index++));
+
+ for (size_t i = 0; i < hostEvalVars.ubs.size(); ++i)
+ ubs.push_back(isolatedTargetBlock->getArgument(index++));
+
+ for (size_t i = 0; i < hostEvalVars.steps.size(); ++i)
+ steps.push_back(isolatedTargetBlock->getArgument(index++));
+
+ // Reset the loop bounds
+ loopNestOp.getLoopLowerBoundsMutable().assign(lbs);
+ loopNestOp.getLoopUpperBoundsMutable().assign(ubs);
+ loopNestOp.getLoopStepsMutable().assign(steps);
+ }
+
+ return isolatedTargetOp;
+}
+
+/// genPostTargetOp method generates the postTargetOp that contains all the ops
+/// after the split point. It also creates the block arguments and maps the
+/// values accordingly. It also creates the load operations for the allocs
+/// and recomputes the necessary ops.
+static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
+ Operation *splitBeforeOp,
+ SmallVector<Value> &postMapOperands,
+ RewriterBase &rewriter,
+ SmallVector<Value> &allocs,
+ SetVector<Operation *> &toRecompute) {
+ auto loc = targetOp.getLoc();
+ auto *targetBlock = &targetOp.getRegion().front();
+ SmallVector<Value> postHostEvalVars{targetOp.getHostEvalVars()};
+ // Create the post target op
+ omp::TargetOp postTargetOp = rewriter.create<omp::TargetOp>(
+ targetOp.getLoc(), targetOp.getAllocateVars(),
+ targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+ targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+ targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), postHostEvalVars,
+ targetOp.getIfExpr(), targetOp.getInReductionVars(),
+ targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
+ targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
+ targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
+ targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
+ targetOp.getPrivateMapsAttr());
+ // Create the block for postTargetOp
+ auto *postTargetBlock = rewriter.createBlock(
+ &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
+ IRMapping postMapping;
+ // Create block arguments and map the values.
+ createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, postTargetBlock,
+ postHostEvalVars, postMapOperands, allocs, postMapping);
+ // Handle the load operations for the allocs and recompute ops.
+ reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock,
+ postTargetBlock, postHostEvalVars, postMapOperands,
+ allocs, toRecompute, postMapping);
+ assert(splitBeforeOp->getNumResults() == 0 ||
+ llvm::all_of(splitBeforeOp->getResults(),
+ [](Value result) { return result.use_empty(); }));
+ // Clone the original operations after the split point.
+ for (auto it = std::next(splitBeforeOp->getIterator());
+ it != targetBlock->end(); it++)
+ rewriter.clone(*it, postMapping);
+ return postTargetOp;
+}
+
+/// isolateOp rewrites an omp.target_data { omp.target } into
+/// omp.target_data {
+/// // preTargetOp region contains ops before splitBeforeOp.
+/// omp.target {}
+///   // isolatedTargetOp region contains splitBeforeOp.
+/// omp.target {}
+/// // postTargetOp region contains ops after splitBeforeOp.
+/// omp.target {}
+/// }
+/// It also handles the mapping of variables and the caching/recomputing
+/// of values as needed.
+static FailureOr<SplitResult> isolateOp(Operation *splitBeforeOp,
+ bool splitAfter, RewriterBase &rewriter,
+ mlir::ModuleOp module,
+ bool isTargetDevice) {
+ auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
+ assert(targetOp);
+ rewriter.setInsertionPoint(targetOp);
+
+ // Prepare the map operands for preTargetOp and postTargetOp
+ auto preMapOperands = SmallVector<Value>(targetOp.getMapVars());
+ auto postMapOperands = SmallVector<Value>(targetOp.getMapVars());
+
+ // Vectors to hold analysis results
+ SmallVector<Value> requiredVals;
+ SetVector<Operation *> toCache;
+ SetVector<Operation *> toRecompute;
+ SetVector<Operation *> nonRecomputable;
+ SmallVector<Value> allocs;
+ struct HostEvalVars hostEvalVars;
+
+ // Analyze the ops in target region to determine which ops need to be
+ // cached and which ops need to be recomputed
+ computeAllocsCacheRecomputable(
+ targetOp, splitBeforeOp, rewriter, preMapOperands, postMapOperands,
+ allocs, requiredVals, nonRecomputable, toCache, toRecompute);
+
+ rewriter.setInsertionPoint(targetOp);
+
+ // Generate the preTargetOp that contains all the ops before splitBeforeOp.
+ auto preTargetOp =
+ genPreTargetOp(targetOp, preMapOperands, allocs, splitBeforeOp, rewriter,
+ hostEvalVars, isTargetDevice);
+
+ // Move the ops of preTarget to host.
+ auto res = moveToHost(preTargetOp, rewriter, module, hostEvalVars);
+ if (failed(res))
+ return failure();
+ rewriter.setInsertionPoint(targetOp);
+
+ // Generate the isolatedTargetOp
+ omp::TargetOp isolatedTargetOp =
+ genIsolatedTargetOp(targetOp, postMapOperands, splitBeforeOp, rewriter,
+ allocs, toRecompute, hostEvalVars, isTargetDevice);
+
+ omp::TargetOp postTargetOp = nullptr;
+ // Generate the postTargetOp that contains all the ops after splitBeforeOp.
+ if (splitAfter) {
+ rewriter.setInsertionPoint(targetOp);
+ postTargetOp = genPostTargetOp(targetOp, splitBeforeOp, postMapOperands,
+ rewriter, allocs, toRecompute);
+ }
+ // Finally erase the original targetOp.
+ rewriter.eraseOp(targetOp);
+ return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp};
+}
+
+/// Recursively fission target ops until no more nested ops can be isolated.
+static LogicalResult fissionTarget(omp::TargetOp targetOp,
+ RewriterBase &rewriter,
+ mlir::ModuleOp module, bool isTargetDevice) {
+ auto tuple = getNestedOpToIsolate(targetOp);
+ if (!tuple) {
+ LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n");
+ struct HostEvalVars hostEvalVars;
+ return moveToHost(targetOp, rewriter, module, hostEvalVars);
+ }
+ Operation *toIsolate = std::get<0>(*tuple);
+ bool splitBefore = !std::get<1>(*tuple);
+ bool splitAfter = !std::get<2>(*tuple);
+ // Recursively isolate the target op.
+ if (splitBefore && splitAfter) {
+ auto res =
+ isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice);
+ if (failed(res))
+ return failure();
+ return fissionTarget((*res).postTargetOp, rewriter, module, isTargetDevice);
+ }
+ // Isolate only before the op.
+ if (splitBefore) {
+ auto res =
+ isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice);
+ if (failed(res))
+ return failure();
+ } else {
+ emitError(toIsolate->getLoc(), "Unhandled case in fissionTarget");
+ return failure();
+ }
+ return success();
+}
+
+/// Pass to lower omp.workdistribute ops.
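+/// The pass verifies target-teams-workdistribute nests, fissions
+/// workdistribute regions, lowers supported runtime calls and do-loops, and
+/// finally splits the affected omp.target ops via splitTargetData and
+/// fissionTarget.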
+class LowerWorkdistributePass
+ : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
+public:
+ void runOnOperation() override {
+ MLIRContext &context = getContext();
+ auto moduleOp = getOperation();
+ bool changed = false;
+ SetVector<omp::TargetOp> targetOpsToProcess;
+ auto verify =
+ moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+ if (failed(verifyTargetTeamsWorkdistribute(workdistribute)))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ });
+ if (verify.wasInterrupted())
+ return signalPassFailure();
+
+ auto fission =
+ moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+ auto res = fissionWorkdistribute(workdistribute);
+ if (failed(res))
+ return WalkResult::interrupt();
+ changed |= *res;
+ return WalkResult::advance();
+ });
+ if (fission.wasInterrupted())
+ return signalPassFailure();
+
+ auto rtCallLower =
+ moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+ auto res = workdistributeRuntimeCallLower(workdistribute,
+ targetOpsToProcess);
+ if (failed(res))
+ return WalkResult::interrupt();
+ changed |= *res;
+ return WalkResult::advance();
+ });
+ if (rtCallLower.wasInterrupted())
+ return signalPassFailure();
+
+ moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) {
+ changed |= workdistributeDoLower(workdistribute, targetOpsToProcess);
+ });
+
+ moduleOp->walk([&](mlir::omp::TeamsOp teams) {
+ changed |= teamsWorkdistributeToSingleOp(teams, targetOpsToProcess);
+ });
+ if (changed) {
+ bool isTargetDevice =
+ llvm::cast<mlir::omp::OffloadModuleInterface>(*moduleOp)
+ .getIsTargetDevice();
+ IRRewriter rewriter(&context);
+ for (auto targetOp : targetOpsToProcess) {
+ auto res = splitTargetData(targetOp, rewriter);
+ if (failed(res))
+ return signalPassFailure();
+ if (*res) {
+ if (failed(fissionTarget(*res, rewriter, moduleOp, isTargetDevice)))
+ return signalPassFailure();
+ }
+ }
+ }
+ }
+};
+} // namespace
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index a83b066..1ecb6d3 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -301,8 +301,10 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm,
addNestedPassToAllTopLevelOperations<PassConstructor>(
pm, hlfir::createInlineHLFIRAssign);
pm.addPass(hlfir::createConvertHLFIRtoFIR());
- if (enableOpenMP != EnableOpenMP::None)
+ if (enableOpenMP != EnableOpenMP::None) {
pm.addPass(flangomp::createLowerWorkshare());
+ pm.addPass(flangomp::createLowerWorkdistribute());
+ }
if (enableOpenMP == EnableOpenMP::Simd)
pm.addPass(flangomp::createSimdOnlyPass());
}
diff --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
index 061a7d2..bdc3418 100644
--- a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
+++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
@@ -474,7 +474,7 @@ public:
mlir::PatternRewriter &rewriter) const override {
LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n";
loop.dump(););
- LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
+ [[maybe_unused]] auto loopAnalysis =
functionAnalysis.getChildLoopAnalysis(loop);
if (!loopAnalysis.canPromoteToAffine())
return rewriter.notifyMatchFailure(loop, "cannot promote to affine");
diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp
index 80b3f68..8601499 100644
--- a/flang/lib/Optimizer/Transforms/StackArrays.cpp
+++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp
@@ -561,7 +561,7 @@ static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter,
return stack;
fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy);
- LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy =
+ [[maybe_unused]] fir::ReferenceType firRefTy =
mlir::cast<fir::ReferenceType>(stackTy);
assert(firHeapTy.getElementType() == firRefTy.getElementType() &&
"Allocations must have the same type");
diff --git a/flang/lib/Parser/parse-tree.cpp b/flang/lib/Parser/parse-tree.cpp
index cb30939..8cbaa39 100644
--- a/flang/lib/Parser/parse-tree.cpp
+++ b/flang/lib/Parser/parse-tree.cpp
@@ -185,7 +185,7 @@ StructureConstructor ArrayElement::ConvertToStructureConstructor(
std::list<ComponentSpec> components;
for (auto &subscript : subscripts) {
components.emplace_back(std::optional<Keyword>{},
- ComponentDataSource{std::move(*Unwrap<Expr>(subscript))});
+ ComponentDataSource{std::move(UnwrapRef<Expr>(subscript))});
}
DerivedTypeSpec spec{std::move(name), std::list<TypeParamSpec>{}};
spec.derivedTypeSpec = &derived;
diff --git a/flang/lib/Semantics/assignment.cpp b/flang/lib/Semantics/assignment.cpp
index f4aa496..1824a7d 100644
--- a/flang/lib/Semantics/assignment.cpp
+++ b/flang/lib/Semantics/assignment.cpp
@@ -194,7 +194,8 @@ void AssignmentContext::CheckShape(parser::CharBlock at, const SomeExpr *expr) {
template <typename A> void AssignmentContext::PushWhereContext(const A &x) {
const auto &expr{std::get<parser::LogicalExpr>(x.t)};
- CheckShape(expr.thing.value().source, GetExpr(context_, expr));
+ CheckShape(
+ parser::UnwrapRef<parser::Expr>(expr).source, GetExpr(context_, expr));
++whereDepth_;
}
diff --git a/flang/lib/Semantics/check-allocate.cpp b/flang/lib/Semantics/check-allocate.cpp
index 823aa4e..e019bbd 100644
--- a/flang/lib/Semantics/check-allocate.cpp
+++ b/flang/lib/Semantics/check-allocate.cpp
@@ -151,7 +151,9 @@ static std::optional<AllocateCheckerInfo> CheckAllocateOptions(
[&](const parser::MsgVariable &var) {
WarnOnDeferredLengthCharacterScalar(context,
GetExpr(context, var),
- var.v.thing.thing.GetSource(), "ERRMSG=");
+ parser::UnwrapRef<parser::Variable>(var)
+ .GetSource(),
+ "ERRMSG=");
if (info.gotMsg) { // C943
context.Say(
"ERRMSG may not be duplicated in a ALLOCATE statement"_err_en_US);
@@ -439,7 +441,7 @@ static bool HaveCompatibleLengths(
evaluate::ToInt64(type1.characterTypeSpec().length().GetExplicit())};
auto v2{
evaluate::ToInt64(type2.characterTypeSpec().length().GetExplicit())};
- return !v1 || !v2 || *v1 == *v2;
+ return !v1 || !v2 || (*v1 >= 0 ? *v1 : 0) == (*v2 >= 0 ? *v2 : 0);
} else {
return true;
}
@@ -452,7 +454,7 @@ static bool HaveCompatibleLengths(
auto v1{
evaluate::ToInt64(type1.characterTypeSpec().length().GetExplicit())};
auto v2{type2.knownLength()};
- return !v1 || !v2 || *v1 == *v2;
+ return !v1 || !v2 || (*v1 >= 0 ? *v1 : 0) == (*v2 >= 0 ? *v2 : 0);
} else {
return true;
}
@@ -598,7 +600,7 @@ bool AllocationCheckerHelper::RunChecks(SemanticsContext &context) {
std::optional<evaluate::ConstantSubscript> lbound;
if (const auto &lb{std::get<0>(shapeSpec.t)}) {
lbound.reset();
- const auto &lbExpr{lb->thing.thing.value()};
+ const auto &lbExpr{parser::UnwrapRef<parser::Expr>(lb)};
if (const auto *expr{GetExpr(context, lbExpr)}) {
auto folded{
evaluate::Fold(context.foldingContext(), SomeExpr(*expr))};
@@ -609,7 +611,8 @@ bool AllocationCheckerHelper::RunChecks(SemanticsContext &context) {
lbound = 1;
}
if (lbound) {
- const auto &ubExpr{std::get<1>(shapeSpec.t).thing.thing.value()};
+ const auto &ubExpr{
+ parser::UnwrapRef<parser::Expr>(std::get<1>(shapeSpec.t))};
if (const auto *expr{GetExpr(context, ubExpr)}) {
auto folded{
evaluate::Fold(context.foldingContext(), SomeExpr(*expr))};
diff --git a/flang/lib/Semantics/check-case.cpp b/flang/lib/Semantics/check-case.cpp
index 5ce143c..7593154 100644
--- a/flang/lib/Semantics/check-case.cpp
+++ b/flang/lib/Semantics/check-case.cpp
@@ -72,7 +72,7 @@ private:
}
std::optional<Value> GetValue(const parser::CaseValue &caseValue) {
- const parser::Expr &expr{caseValue.thing.thing.value()};
+ const auto &expr{parser::UnwrapRef<parser::Expr>(caseValue)};
auto *x{expr.typedExpr.get()};
if (x && x->v) { // C1147
auto type{x->v->GetType()};
diff --git a/flang/lib/Semantics/check-coarray.cpp b/flang/lib/Semantics/check-coarray.cpp
index 0e444f1..9113369 100644
--- a/flang/lib/Semantics/check-coarray.cpp
+++ b/flang/lib/Semantics/check-coarray.cpp
@@ -112,7 +112,7 @@ static void CheckTeamType(
static void CheckTeamStat(
SemanticsContext &context, const parser::ImageSelectorSpec::Stat &stat) {
- const parser::Variable &var{stat.v.thing.thing.value()};
+ const auto &var{parser::UnwrapRef<parser::Variable>(stat)};
if (parser::GetCoindexedNamedObject(var)) {
context.Say(parser::FindSourceLocation(var), // C931
"Image selector STAT variable must not be a coindexed "
@@ -147,7 +147,8 @@ static void CheckSyncStat(SemanticsContext &context,
},
[&](const parser::MsgVariable &var) {
WarnOnDeferredLengthCharacterScalar(context, GetExpr(context, var),
- var.v.thing.thing.GetSource(), "ERRMSG=");
+ parser::UnwrapRef<parser::Variable>(var).GetSource(),
+ "ERRMSG=");
if (gotMsg) {
context.Say( // C1172
"The errmsg-variable in a sync-stat-list may not be repeated"_err_en_US);
@@ -260,7 +261,9 @@ static void CheckEventWaitSpecList(SemanticsContext &context,
[&](const parser::MsgVariable &var) {
WarnOnDeferredLengthCharacterScalar(context,
GetExpr(context, var),
- var.v.thing.thing.GetSource(), "ERRMSG=");
+ parser::UnwrapRef<parser::Variable>(var)
+ .GetSource(),
+ "ERRMSG=");
if (gotMsg) {
context.Say( // C1178
"A errmsg-variable in a event-wait-spec-list may not be repeated"_err_en_US);
diff --git a/flang/lib/Semantics/check-data.cpp b/flang/lib/Semantics/check-data.cpp
index 5459290..3bcf711 100644
--- a/flang/lib/Semantics/check-data.cpp
+++ b/flang/lib/Semantics/check-data.cpp
@@ -25,9 +25,10 @@ namespace Fortran::semantics {
// Ensures that references to an implied DO loop control variable are
// represented as such in the "body" of the implied DO loop.
void DataChecker::Enter(const parser::DataImpliedDo &x) {
- auto name{std::get<parser::DataImpliedDo::Bounds>(x.t).name.thing.thing};
+ const auto &name{parser::UnwrapRef<parser::Name>(
+ std::get<parser::DataImpliedDo::Bounds>(x.t).name)};
int kind{evaluate::ResultType<evaluate::ImpliedDoIndex>::kind};
- if (const auto dynamicType{evaluate::DynamicType::From(*name.symbol)}) {
+ if (const auto dynamicType{evaluate::DynamicType::From(DEREF(name.symbol))}) {
if (dynamicType->category() == TypeCategory::Integer) {
kind = dynamicType->kind();
}
@@ -36,7 +37,8 @@ void DataChecker::Enter(const parser::DataImpliedDo &x) {
}
void DataChecker::Leave(const parser::DataImpliedDo &x) {
- auto name{std::get<parser::DataImpliedDo::Bounds>(x.t).name.thing.thing};
+ const auto &name{parser::UnwrapRef<parser::Name>(
+ std::get<parser::DataImpliedDo::Bounds>(x.t).name)};
exprAnalyzer_.RemoveImpliedDo(name.source);
}
@@ -211,7 +213,7 @@ void DataChecker::Leave(const parser::DataIDoObject &object) {
std::get_if<parser::Scalar<common::Indirection<parser::Designator>>>(
&object.u)}) {
if (MaybeExpr expr{exprAnalyzer_.Analyze(*designator)}) {
- auto source{designator->thing.value().source};
+ auto source{parser::UnwrapRef<parser::Designator>(*designator).source};
DataVarChecker checker{exprAnalyzer_.context(), source};
if (checker(*expr)) {
if (checker.HasComponentWithoutSubscripts()) { // C880
diff --git a/flang/lib/Semantics/check-deallocate.cpp b/flang/lib/Semantics/check-deallocate.cpp
index c45b585..c1ebc5f 100644
--- a/flang/lib/Semantics/check-deallocate.cpp
+++ b/flang/lib/Semantics/check-deallocate.cpp
@@ -114,7 +114,8 @@ void DeallocateChecker::Leave(const parser::DeallocateStmt &deallocateStmt) {
},
[&](const parser::MsgVariable &var) {
WarnOnDeferredLengthCharacterScalar(context_,
- GetExpr(context_, var), var.v.thing.thing.GetSource(),
+ GetExpr(context_, var),
+ parser::UnwrapRef<parser::Variable>(var).GetSource(),
"ERRMSG=");
if (gotMsg) {
context_.Say(
diff --git a/flang/lib/Semantics/check-do-forall.cpp b/flang/lib/Semantics/check-do-forall.cpp
index a2f3685..8a47340 100644
--- a/flang/lib/Semantics/check-do-forall.cpp
+++ b/flang/lib/Semantics/check-do-forall.cpp
@@ -535,7 +535,8 @@ private:
if (const SomeExpr * expr{GetExpr(context_, scalarExpression)}) {
if (!ExprHasTypeCategory(*expr, TypeCategory::Integer)) {
// No warnings or errors for type INTEGER
- const parser::CharBlock &loc{scalarExpression.thing.value().source};
+ parser::CharBlock loc{
+ parser::UnwrapRef<parser::Expr>(scalarExpression).source};
CheckDoControl(loc, ExprHasTypeCategory(*expr, TypeCategory::Real));
}
}
@@ -552,7 +553,7 @@ private:
CheckDoExpression(*bounds.step);
if (IsZero(*bounds.step)) {
context_.Warn(common::UsageWarning::ZeroDoStep,
- bounds.step->thing.value().source,
+ parser::UnwrapRef<parser::Expr>(bounds.step).source,
"DO step expression should not be zero"_warn_en_US);
}
}
@@ -615,7 +616,7 @@ private:
// C1121 - procedures in mask must be pure
void CheckMaskIsPure(const parser::ScalarLogicalExpr &mask) const {
UnorderedSymbolSet references{
- GatherSymbolsFromExpression(mask.thing.thing.value())};
+ GatherSymbolsFromExpression(parser::UnwrapRef<parser::Expr>(mask))};
for (const Symbol &ref : OrderBySourcePosition(references)) {
if (IsProcedure(ref) && !IsPureProcedure(ref)) {
context_.SayWithDecl(ref, parser::Unwrap<parser::Expr>(mask)->source,
@@ -639,32 +640,33 @@ private:
}
void HasNoReferences(const UnorderedSymbolSet &indexNames,
- const parser::ScalarIntExpr &expr) const {
- CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()),
- indexNames,
+ const parser::ScalarIntExpr &scalarIntExpr) const {
+ const auto &expr{parser::UnwrapRef<parser::Expr>(scalarIntExpr)};
+ CheckNoCollisions(GatherSymbolsFromExpression(expr), indexNames,
"%s limit expression may not reference index variable '%s'"_err_en_US,
- expr.thing.thing.value().source);
+ expr.source);
}
// C1129, names in local locality-specs can't be in mask expressions
void CheckMaskDoesNotReferenceLocal(const parser::ScalarLogicalExpr &mask,
const UnorderedSymbolSet &localVars) const {
- CheckNoCollisions(GatherSymbolsFromExpression(mask.thing.thing.value()),
- localVars,
+ const auto &expr{parser::UnwrapRef<parser::Expr>(mask)};
+ CheckNoCollisions(GatherSymbolsFromExpression(expr), localVars,
"%s mask expression references variable '%s'"
" in LOCAL locality-spec"_err_en_US,
- mask.thing.thing.value().source);
+ expr.source);
}
// C1129, names in local locality-specs can't be in limit or step
// expressions
- void CheckExprDoesNotReferenceLocal(const parser::ScalarIntExpr &expr,
+ void CheckExprDoesNotReferenceLocal(
+ const parser::ScalarIntExpr &scalarIntExpr,
const UnorderedSymbolSet &localVars) const {
- CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()),
- localVars,
+ const auto &expr{parser::UnwrapRef<parser::Expr>(scalarIntExpr)};
+ CheckNoCollisions(GatherSymbolsFromExpression(expr), localVars,
"%s expression references variable '%s'"
" in LOCAL locality-spec"_err_en_US,
- expr.thing.thing.value().source);
+ expr.source);
}
// C1130, DEFAULT(NONE) locality requires names to be in locality-specs to
@@ -772,7 +774,7 @@ private:
HasNoReferences(indexNames, std::get<2>(control.t));
if (const auto &intExpr{
std::get<std::optional<parser::ScalarIntExpr>>(control.t)}) {
- const parser::Expr &expr{intExpr->thing.thing.value()};
+ const auto &expr{parser::UnwrapRef<parser::Expr>(intExpr)};
CheckNoCollisions(GatherSymbolsFromExpression(expr), indexNames,
"%s step expression may not reference index variable '%s'"_err_en_US,
expr.source);
@@ -840,7 +842,7 @@ private:
}
void CheckForImpureCall(const parser::ScalarIntExpr &x,
std::optional<IndexVarKind> nesting) const {
- const auto &parsedExpr{x.thing.thing.value()};
+ const auto &parsedExpr{parser::UnwrapRef<parser::Expr>(x)};
auto oldLocation{context_.location()};
context_.set_location(parsedExpr.source);
if (const auto &typedExpr{parsedExpr.typedExpr}) {
@@ -1124,7 +1126,8 @@ void DoForallChecker::Leave(const parser::ConnectSpec &connectSpec) {
const auto *newunit{
std::get_if<parser::ConnectSpec::Newunit>(&connectSpec.u)};
if (newunit) {
- context_.CheckIndexVarRedefine(newunit->v.thing.thing);
+ context_.CheckIndexVarRedefine(
+ parser::UnwrapRef<parser::Variable>(newunit));
}
}
@@ -1166,14 +1169,14 @@ void DoForallChecker::Leave(const parser::InquireSpec &inquireSpec) {
const auto *intVar{std::get_if<parser::InquireSpec::IntVar>(&inquireSpec.u)};
if (intVar) {
const auto &scalar{std::get<parser::ScalarIntVariable>(intVar->t)};
- context_.CheckIndexVarRedefine(scalar.thing.thing);
+ context_.CheckIndexVarRedefine(parser::UnwrapRef<parser::Variable>(scalar));
}
}
void DoForallChecker::Leave(const parser::IoControlSpec &ioControlSpec) {
const auto *size{std::get_if<parser::IoControlSpec::Size>(&ioControlSpec.u)};
if (size) {
- context_.CheckIndexVarRedefine(size->v.thing.thing);
+ context_.CheckIndexVarRedefine(parser::UnwrapRef<parser::Variable>(size));
}
}
@@ -1190,16 +1193,19 @@ static void CheckIoImpliedDoIndex(
void DoForallChecker::Leave(const parser::OutputImpliedDo &outputImpliedDo) {
CheckIoImpliedDoIndex(context_,
- std::get<parser::IoImpliedDoControl>(outputImpliedDo.t).name.thing.thing);
+ parser::UnwrapRef<parser::Name>(
+ std::get<parser::IoImpliedDoControl>(outputImpliedDo.t).name));
}
void DoForallChecker::Leave(const parser::InputImpliedDo &inputImpliedDo) {
CheckIoImpliedDoIndex(context_,
- std::get<parser::IoImpliedDoControl>(inputImpliedDo.t).name.thing.thing);
+ parser::UnwrapRef<parser::Name>(
+ std::get<parser::IoImpliedDoControl>(inputImpliedDo.t).name));
}
void DoForallChecker::Leave(const parser::StatVariable &statVariable) {
- context_.CheckIndexVarRedefine(statVariable.v.thing.thing);
+ context_.CheckIndexVarRedefine(
+ parser::UnwrapRef<parser::Variable>(statVariable));
}
} // namespace Fortran::semantics
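
The recurring change across these semantics files replaces hand-written accessor
chains such as expr.thing.thing.value() with parser::UnwrapRef<parser::Expr>(expr).
The following is only a rough sketch with made-up stand-in wrapper types, not the
flang helper itself; it illustrates how a recursive UnwrapRef can peel Scalar<>,
Integer<>, and Indirection<> layers until the requested node type is reached (the
parse-tree comments in the check-omp-structure.cpp hunks below spell out one such
nest, Scalar<Integer<indirection<Expr>>>).

// Minimal sketch, assuming UnwrapRef<T>(x) strips wrapper layers until it
// reaches a T. The wrapper types here are stand-ins, not the parser classes.
struct Expr {};                                   // stand-in for parser::Expr
template <typename A> struct Scalar { A thing; };
template <typename A> struct Integer { A thing; };
template <typename A> struct Indirection {
  A *p;
  A &value() const { return *p; }
};

template <typename T> const T &UnwrapRef(const T &x) { return x; }
template <typename T, typename A> const T &UnwrapRef(const Scalar<A> &x) {
  return UnwrapRef<T>(x.thing);
}
template <typename T, typename A> const T &UnwrapRef(const Integer<A> &x) {
  return UnwrapRef<T>(x.thing);
}
template <typename T, typename A> const T &UnwrapRef(const Indirection<A> &x) {
  return UnwrapRef<T>(x.value());
}

int main() {
  Expr e;
  Scalar<Integer<Indirection<Expr>>> scalarIntExpr{{{&e}}};
  // Replaces the old spelling: scalarIntExpr.thing.thing.value()
  const Expr &inner{UnwrapRef<Expr>(scalarIntExpr)};
  return &inner == &e ? 0 : 1;
}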
diff --git a/flang/lib/Semantics/check-io.cpp b/flang/lib/Semantics/check-io.cpp
index a1ff4b9..19059ad 100644
--- a/flang/lib/Semantics/check-io.cpp
+++ b/flang/lib/Semantics/check-io.cpp
@@ -424,8 +424,8 @@ void IoChecker::Enter(const parser::InquireSpec::CharVar &spec) {
specKind = IoSpecKind::Dispose;
break;
}
- const parser::Variable &var{
- std::get<parser::ScalarDefaultCharVariable>(spec.t).thing.thing};
+ const auto &var{parser::UnwrapRef<parser::Variable>(
+ std::get<parser::ScalarDefaultCharVariable>(spec.t))};
std::string what{parser::ToUpperCaseLetters(common::EnumToString(specKind))};
CheckForDefinableVariable(var, what);
WarnOnDeferredLengthCharacterScalar(
@@ -627,7 +627,7 @@ void IoChecker::Enter(const parser::IoUnit &spec) {
}
void IoChecker::Enter(const parser::MsgVariable &msgVar) {
- const parser::Variable &var{msgVar.v.thing.thing};
+ const auto &var{parser::UnwrapRef<parser::Variable>(msgVar)};
if (stmt_ == IoStmtKind::None) {
// allocate, deallocate, image control
CheckForDefinableVariable(var, "ERRMSG");
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index b4c1bf7..ea6fe43 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -2358,7 +2358,7 @@ private:
}
if (auto &repl{std::get<parser::OmpClause::Replayable>(clause.u).v}) {
// Scalar<Logical<Constant<indirection<Expr>>>>
- const parser::Expr &parserExpr{repl->v.thing.thing.thing.value()};
+ const auto &parserExpr{parser::UnwrapRef<parser::Expr>(repl)};
if (auto &&expr{GetEvaluateExpr(parserExpr)}) {
return GetLogicalValue(*expr).value_or(true);
}
@@ -2372,7 +2372,7 @@ private:
bool isTransparent{true};
if (auto &transp{std::get<parser::OmpClause::Transparent>(clause.u).v}) {
// Scalar<Integer<indirection<Expr>>>
- const parser::Expr &parserExpr{transp->v.thing.thing.value()};
+ const auto &parserExpr{parser::UnwrapRef<parser::Expr>(transp)};
if (auto &&expr{GetEvaluateExpr(parserExpr)}) {
// If the argument is omp_not_impex (defined as 0), then
// the task is not transparent, otherwise it is.
@@ -2411,8 +2411,8 @@ private:
}
}
// Scalar<Logical<indirection<Expr>>>
- auto &parserExpr{
- std::get<parser::ScalarLogicalExpr>(ifc.v.t).thing.thing.value()};
+ const auto &parserExpr{parser::UnwrapRef<parser::Expr>(
+ std::get<parser::ScalarLogicalExpr>(ifc.v.t))};
if (auto &&expr{GetEvaluateExpr(parserExpr)}) {
// If the value is known to be false, an undeferred task will be
// generated.
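
For the REPLAYABLE and TRANSPARENT clauses, the hunks above follow one pattern:
unwrap the clause argument, try to fold it to a constant, and keep the clause's
default when the value is not known at compile time (value_or(true) for
REPLAYABLE; isTransparent stays true unless the TRANSPARENT argument folds to
omp_not_impex, i.e. zero). A minimal sketch of that decision, using a
hypothetical helper name rather than any flang API:

#include <optional>

// Hypothetical helper: interpret a possibly-unknown folded clause value,
// keeping the clause's documented default when no constant is available.
bool ClauseValueOrDefault(std::optional<long long> folded, bool defaultValue) {
  return folded ? (*folded != 0) : defaultValue;
}

// ClauseValueOrDefault(std::nullopt, /*defaultValue=*/true) models REPLAYABLE
// with a non-constant argument: the task is treated as replayable.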
diff --git a/flang/lib/Semantics/data-to-inits.cpp b/flang/lib/Semantics/data-to-inits.cpp
index 1e46dab..bbf3b28 100644
--- a/flang/lib/Semantics/data-to-inits.cpp
+++ b/flang/lib/Semantics/data-to-inits.cpp
@@ -179,13 +179,14 @@ bool DataInitializationCompiler<DSV>::Scan(
template <typename DSV>
bool DataInitializationCompiler<DSV>::Scan(const parser::DataImpliedDo &ido) {
const auto &bounds{std::get<parser::DataImpliedDo::Bounds>(ido.t)};
- auto name{bounds.name.thing.thing};
- const auto *lowerExpr{
- GetExpr(exprAnalyzer_.context(), bounds.lower.thing.thing)};
- const auto *upperExpr{
- GetExpr(exprAnalyzer_.context(), bounds.upper.thing.thing)};
+ const auto &name{parser::UnwrapRef<parser::Name>(bounds.name)};
+ const auto *lowerExpr{GetExpr(
+ exprAnalyzer_.context(), parser::UnwrapRef<parser::Expr>(bounds.lower))};
+ const auto *upperExpr{GetExpr(
+ exprAnalyzer_.context(), parser::UnwrapRef<parser::Expr>(bounds.upper))};
const auto *stepExpr{bounds.step
- ? GetExpr(exprAnalyzer_.context(), bounds.step->thing.thing)
+ ? GetExpr(exprAnalyzer_.context(),
+ parser::UnwrapRef<parser::Expr>(bounds.step))
: nullptr};
if (lowerExpr && upperExpr) {
// Fold the bounds expressions (again) in case any of them depend
@@ -240,7 +241,9 @@ bool DataInitializationCompiler<DSV>::Scan(
return common::visit(
common::visitors{
[&](const parser::Scalar<common::Indirection<parser::Designator>>
- &var) { return Scan(var.thing.value()); },
+ &var) {
+ return Scan(parser::UnwrapRef<parser::Designator>(var));
+ },
[&](const common::Indirection<parser::DataImpliedDo> &ido) {
return Scan(ido.value());
},
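
Once the lower, upper, and (optional) step expressions of a DATA implied-DO fold
to constants, the scan above amounts to a standard DO triplet walk over the
index name, with a missing step treated as 1. A small standalone sketch of the
trip-count rule that walk relies on (standard Fortran DO semantics, independent
of the flang data structures):

#include <cstdint>

// Trip count of a DO triplet: max(0, (upper - lower + step) / step); a zero
// step is invalid per the standard and yields no iterations here.
std::int64_t ImpliedDoTripCount(std::int64_t lower, std::int64_t upper,
                                std::int64_t step = 1) {
  if (step == 0) {
    return 0;
  }
  std::int64_t trips{(upper - lower + step) / step};
  return trips > 0 ? trips : 0;
}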
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index 2feec98..4aeb9a4 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -176,8 +176,8 @@ public:
// Find and return a user-defined operator or report an error.
// The provided message is used if there is no such operator.
- MaybeExpr TryDefinedOp(
- const char *, parser::MessageFixedText, bool isUserOp = false);
+ MaybeExpr TryDefinedOp(const char *, parser::MessageFixedText,
+ bool isUserOp = false, bool checkForNullPointer = true);
template <typename E>
MaybeExpr TryDefinedOp(E opr, parser::MessageFixedText msg) {
return TryDefinedOp(
@@ -211,7 +211,8 @@ private:
void SayNoMatch(
const std::string &, bool isAssignment = false, bool isAmbiguous = false);
std::string TypeAsFortran(std::size_t);
- bool AnyUntypedOrMissingOperand() const;
+ bool AnyUntypedOperand() const;
+ bool AnyMissingOperand() const;
ExpressionAnalyzer &context_;
ActualArguments actuals_;
@@ -1954,9 +1955,10 @@ void ArrayConstructorContext::Add(const parser::AcImpliedDo &impliedDo) {
const auto &control{std::get<parser::AcImpliedDoControl>(impliedDo.t)};
const auto &bounds{std::get<parser::AcImpliedDoControl::Bounds>(control.t)};
exprAnalyzer_.Analyze(bounds.name);
- parser::CharBlock name{bounds.name.thing.thing.source};
+ const auto &parsedName{parser::UnwrapRef<parser::Name>(bounds.name)};
+ parser::CharBlock name{parsedName.source};
int kind{ImpliedDoIntType::kind};
- if (const Symbol * symbol{bounds.name.thing.thing.symbol}) {
+ if (const Symbol *symbol{parsedName.symbol}) {
if (auto dynamicType{DynamicType::From(symbol)}) {
if (dynamicType->category() == TypeCategory::Integer) {
kind = dynamicType->kind();
@@ -1981,7 +1983,7 @@ void ArrayConstructorContext::Add(const parser::AcImpliedDo &impliedDo) {
auto cUpper{ToInt64(upper)};
auto cStride{ToInt64(stride)};
if (!(messageDisplayedSet_ & 0x10) && cStride && *cStride == 0) {
- exprAnalyzer_.SayAt(bounds.step.value().thing.thing.value().source,
+ exprAnalyzer_.SayAt(parser::UnwrapRef<parser::Expr>(bounds.step).source,
"The stride of an implied DO loop must not be zero"_err_en_US);
messageDisplayedSet_ |= 0x10;
}
@@ -2526,7 +2528,7 @@ static const Symbol *GetBindingResolution(
auto ExpressionAnalyzer::AnalyzeProcedureComponentRef(
const parser::ProcComponentRef &pcr, ActualArguments &&arguments,
bool isSubroutine) -> std::optional<CalleeAndArguments> {
- const parser::StructureComponent &sc{pcr.v.thing};
+ const auto &sc{parser::UnwrapRef<parser::StructureComponent>(pcr)};
if (MaybeExpr base{Analyze(sc.base)}) {
if (const Symbol *sym{sc.component.symbol}) {
if (context_.HasError(sym)) {
@@ -3695,11 +3697,12 @@ std::optional<characteristics::Procedure> ExpressionAnalyzer::CheckCall(
MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr::Parentheses &x) {
if (MaybeExpr operand{Analyze(x.v.value())}) {
- if (const semantics::Symbol *symbol{GetLastSymbol(*operand)}) {
+ if (IsNullPointerOrAllocatable(&*operand)) {
+ Say("NULL() may not be parenthesized"_err_en_US);
+ } else if (const semantics::Symbol *symbol{GetLastSymbol(*operand)}) {
if (const semantics::Symbol *result{FindFunctionResult(*symbol)}) {
if (semantics::IsProcedurePointer(*result)) {
- Say("A function reference that returns a procedure "
- "pointer may not be parenthesized"_err_en_US); // C1003
+ Say("A function reference that returns a procedure pointer may not be parenthesized"_err_en_US); // C1003
}
}
}
@@ -3788,7 +3791,7 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr::DefinedUnary &x) {
ArgumentAnalyzer analyzer{*this, name.source};
analyzer.Analyze(std::get<1>(x.t));
return analyzer.TryDefinedOp(name.source.ToString().c_str(),
- "No operator %s defined for %s"_err_en_US, true);
+ "No operator %s defined for %s"_err_en_US, /*isUserOp=*/true);
}
// Binary (dyadic) operations
@@ -3997,7 +4000,9 @@ static bool CheckFuncRefToArrayElement(semantics::SemanticsContext &context,
auto &proc{std::get<parser::ProcedureDesignator>(funcRef.v.t)};
const auto *name{std::get_if<parser::Name>(&proc.u)};
if (!name) {
- name = &std::get<parser::ProcComponentRef>(proc.u).v.thing.component;
+ name = &parser::UnwrapRef<parser::StructureComponent>(
+ std::get<parser::ProcComponentRef>(proc.u))
+ .component;
}
if (!name->symbol) {
return false;
@@ -4047,14 +4052,16 @@ static void FixMisparsedFunctionReference(
}
}
auto &proc{std::get<parser::ProcedureDesignator>(funcRef.v.t)};
- if (Symbol *origSymbol{
- common::visit(common::visitors{
- [&](parser::Name &name) { return name.symbol; },
- [&](parser::ProcComponentRef &pcr) {
- return pcr.v.thing.component.symbol;
- },
- },
- proc.u)}) {
+ if (Symbol *
+ origSymbol{common::visit(
+ common::visitors{
+ [&](parser::Name &name) { return name.symbol; },
+ [&](parser::ProcComponentRef &pcr) {
+ return parser::UnwrapRef<parser::StructureComponent>(pcr)
+ .component.symbol;
+ },
+ },
+ proc.u)}) {
Symbol &symbol{origSymbol->GetUltimate()};
if (symbol.has<semantics::ObjectEntityDetails>() ||
symbol.has<semantics::AssocEntityDetails>()) {
@@ -4176,15 +4183,23 @@ MaybeExpr ExpressionAnalyzer::IterativelyAnalyzeSubexpressions(
} while (!queue.empty());
// Analyze the collected subexpressions in bottom-up order.
// On an error, bail out and leave partial results in place.
- MaybeExpr result;
- for (auto riter{finish.rbegin()}; riter != finish.rend(); ++riter) {
- const parser::Expr &expr{**riter};
- result = ExprOrVariable(expr, expr.source);
- if (!result) {
- return result;
+ if (finish.size() == 1) {
+ const parser::Expr &expr{DEREF(finish.front())};
+ return ExprOrVariable(expr, expr.source);
+ } else {
+    // Catching NULL() operands is deferred to operation analysis so
+    // that defined operators can accept them.
+ auto restorer{AllowNullPointer()};
+ MaybeExpr result;
+ for (auto riter{finish.rbegin()}; riter != finish.rend(); ++riter) {
+ const parser::Expr &expr{**riter};
+ result = ExprOrVariable(expr, expr.source);
+ if (!result) {
+ return result;
+ }
}
+ return result; // last value was from analysis of "top"
}
- return result; // last value was from analysis of "top"
}
MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr &expr) {
@@ -4681,7 +4696,7 @@ bool ArgumentAnalyzer::AnyCUDADeviceData() const {
// attribute.
bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride(
const char *opr) const {
- if (AnyCUDADeviceData() && !AnyUntypedOrMissingOperand()) {
+ if (AnyCUDADeviceData() && !AnyUntypedOperand() && !AnyMissingOperand()) {
std::string oprNameString{"operator("s + opr + ')'};
parser::CharBlock oprName{oprNameString};
parser::Messages buffer;
@@ -4709,9 +4724,9 @@ bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride(
return false;
}
-MaybeExpr ArgumentAnalyzer::TryDefinedOp(
- const char *opr, parser::MessageFixedText error, bool isUserOp) {
- if (AnyUntypedOrMissingOperand()) {
+MaybeExpr ArgumentAnalyzer::TryDefinedOp(const char *opr,
+ parser::MessageFixedText error, bool isUserOp, bool checkForNullPointer) {
+ if (AnyMissingOperand()) {
context_.Say(error, ToUpperCase(opr), TypeAsFortran(0), TypeAsFortran(1));
return std::nullopt;
}
@@ -4790,7 +4805,9 @@ MaybeExpr ArgumentAnalyzer::TryDefinedOp(
context_.Say(
"Operands of %s are not conformable; have rank %d and rank %d"_err_en_US,
ToUpperCase(opr), actuals_[0]->Rank(), actuals_[1]->Rank());
- } else if (CheckForNullPointer() && CheckForAssumedRank()) {
+ } else if (!CheckForAssumedRank()) {
+ } else if (checkForNullPointer && !CheckForNullPointer()) {
+ } else { // use the supplied error
context_.Say(error, ToUpperCase(opr), TypeAsFortran(0), TypeAsFortran(1));
}
return result;
@@ -4808,15 +4825,16 @@ MaybeExpr ArgumentAnalyzer::TryDefinedOp(
for (std::size_t i{0}; i < oprs.size(); ++i) {
parser::Messages buffer;
auto restorer{context_.GetContextualMessages().SetMessages(buffer)};
- if (MaybeExpr thisResult{TryDefinedOp(oprs[i], error)}) {
+ if (MaybeExpr thisResult{TryDefinedOp(oprs[i], error, /*isUserOp=*/false,
+ /*checkForNullPointer=*/false)}) {
result = std::move(thisResult);
hit.push_back(oprs[i]);
hitBuffer = std::move(buffer);
}
}
}
- if (hit.empty()) { // for the error
- result = TryDefinedOp(oprs[0], error);
+ if (hit.empty()) { // run TryDefinedOp() again just to emit errors
+ CHECK(!TryDefinedOp(oprs[0], error).has_value());
} else if (hit.size() > 1) {
context_.Say(
"Matching accessible definitions were found with %zd variant spellings of the generic operator ('%s', '%s')"_err_en_US,
@@ -5232,10 +5250,19 @@ std::string ArgumentAnalyzer::TypeAsFortran(std::size_t i) {
}
}
-bool ArgumentAnalyzer::AnyUntypedOrMissingOperand() const {
+bool ArgumentAnalyzer::AnyUntypedOperand() const {
+ for (const auto &actual : actuals_) {
+ if (actual && !actual->GetType() &&
+ !IsBareNullPointer(actual->UnwrapExpr())) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool ArgumentAnalyzer::AnyMissingOperand() const {
for (const auto &actual : actuals_) {
- if (!actual ||
- (!actual->GetType() && !IsBareNullPointer(actual->UnwrapExpr()))) {
+ if (!actual) {
return true;
}
}
@@ -5268,9 +5295,9 @@ void ExprChecker::Post(const parser::DataStmtObject &obj) {
bool ExprChecker::Pre(const parser::DataImpliedDo &ido) {
parser::Walk(std::get<parser::DataImpliedDo::Bounds>(ido.t), *this);
const auto &bounds{std::get<parser::DataImpliedDo::Bounds>(ido.t)};
- auto name{bounds.name.thing.thing};
+ const auto &name{parser::UnwrapRef<parser::Name>(bounds.name)};
int kind{evaluate::ResultType<evaluate::ImpliedDoIndex>::kind};
- if (const auto dynamicType{evaluate::DynamicType::From(*name.symbol)}) {
+ if (const auto dynamicType{evaluate::DynamicType::From(DEREF(name.symbol))}) {
if (dynamicType->category() == TypeCategory::Integer) {
kind = dynamicType->kind();
}
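
The expression.cpp changes center on NULL() handling: a parenthesized NULL() is
now diagnosed outright, while bare NULL() operands are deferred (the
AllowNullPointer restorer plus the new checkForNullPointer parameter of
TryDefinedOp) so that a defined operator still gets a chance to accept them.
The loop over generic operator spellings probes each candidate with its
diagnostics captured in a local buffer and replays errors only when nothing
matched. Below is a self-contained sketch of that buffering pattern only, with
made-up types standing in for the analyzer and its messages:

#include <functional>
#include <optional>
#include <string>
#include <vector>

using Messages = std::vector<std::string>;
// A candidate spelling either yields a result or records its diagnostics.
using Candidate = std::function<std::optional<int>(Messages &)>;

// Assumes at least one candidate spelling, as in the analyzer's operator list.
std::optional<int> TrySpellings(const std::vector<Candidate> &spellings,
                                Messages &out) {
  std::optional<int> result;
  Messages hitBuffer;
  int hits{0};
  for (const Candidate &candidate : spellings) {
    Messages buffer; // probe quietly: capture this candidate's messages
    if (std::optional<int> r{candidate(buffer)}) {
      result = r;
      hitBuffer = std::move(buffer);
      ++hits;
    }
  }
  if (hits == 0) {
    spellings.front()(out); // rerun one candidate just to surface its errors
    return std::nullopt;
  }
  // With several hits the real analyzer reports ambiguous spellings; this
  // sketch simply keeps the messages from the last successful candidate.
  out.insert(out.end(), hitBuffer.begin(), hitBuffer.end());
  return result;
}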
diff --git a/flang/lib/Semantics/resolve-names-utils.cpp b/flang/lib/Semantics/resolve-names-utils.cpp
index 742bb74..ac67799 100644
--- a/flang/lib/Semantics/resolve-names-utils.cpp
+++ b/flang/lib/Semantics/resolve-names-utils.cpp
@@ -492,12 +492,14 @@ bool EquivalenceSets::CheckDesignator(const parser::Designator &designator) {
const auto &range{std::get<parser::SubstringRange>(x.t)};
bool ok{CheckDataRef(designator.source, dataRef)};
if (const auto &lb{std::get<0>(range.t)}) {
- ok &= CheckSubstringBound(lb->thing.thing.value(), true);
+ ok &= CheckSubstringBound(
+ parser::UnwrapRef<parser::Expr>(lb), true);
} else {
currObject_.substringStart = 1;
}
if (const auto &ub{std::get<1>(range.t)}) {
- ok &= CheckSubstringBound(ub->thing.thing.value(), false);
+ ok &= CheckSubstringBound(
+ parser::UnwrapRef<parser::Expr>(ub), false);
}
return ok;
},
@@ -528,7 +530,8 @@ bool EquivalenceSets::CheckDataRef(
return false;
},
[&](const parser::IntExpr &y) {
- return CheckArrayBound(y.thing.value());
+ return CheckArrayBound(
+ parser::UnwrapRef<parser::Expr>(y));
},
},
subscript.u);
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index ae0ff9ca..699de41 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -1140,7 +1140,7 @@ protected:
std::optional<SourceName> BeginCheckOnIndexUseInOwnBounds(
const parser::DoVariable &name) {
std::optional<SourceName> result{checkIndexUseInOwnBounds_};
- checkIndexUseInOwnBounds_ = name.thing.thing.source;
+ checkIndexUseInOwnBounds_ = parser::UnwrapRef<parser::Name>(name).source;
return result;
}
void EndCheckOnIndexUseInOwnBounds(const std::optional<SourceName> &restore) {
@@ -2130,7 +2130,7 @@ public:
void Post(const parser::SubstringInquiry &);
template <typename A, typename B>
void Post(const parser::LoopBounds<A, B> &x) {
- ResolveName(*parser::Unwrap<parser::Name>(x.name));
+ ResolveName(parser::UnwrapRef<parser::Name>(x.name));
}
void Post(const parser::ProcComponentRef &);
bool Pre(const parser::FunctionReference &);
@@ -2560,7 +2560,7 @@ KindExpr DeclTypeSpecVisitor::GetKindParamExpr(
CHECK(!state_.originalKindParameter);
// Save a pointer to the KIND= expression in the parse tree
// in case we need to reanalyze it during PDT instantiation.
- state_.originalKindParameter = &expr->thing.thing.thing.value();
+ state_.originalKindParameter = parser::Unwrap<parser::Expr>(expr);
}
}
// Inhibit some errors now that will be caught later during instantiations.
@@ -5649,6 +5649,7 @@ bool DeclarationVisitor::Pre(const parser::NamedConstantDef &x) {
if (details->init() || symbol.test(Symbol::Flag::InDataStmt)) {
Say(name, "Named constant '%s' already has a value"_err_en_US);
}
+ parser::CharBlock at{parser::UnwrapRef<parser::Expr>(expr).source};
if (inOldStyleParameterStmt_) {
// non-standard extension PARAMETER statement (no parentheses)
Walk(expr);
@@ -5657,7 +5658,6 @@ bool DeclarationVisitor::Pre(const parser::NamedConstantDef &x) {
SayWithDecl(name, symbol,
"Alternative style PARAMETER '%s' must not already have an explicit type"_err_en_US);
} else if (folded) {
- auto at{expr.thing.value().source};
if (evaluate::IsActuallyConstant(*folded)) {
if (const auto *type{currScope().GetType(*folded)}) {
if (type->IsPolymorphic()) {
@@ -5682,8 +5682,7 @@ bool DeclarationVisitor::Pre(const parser::NamedConstantDef &x) {
// standard-conforming PARAMETER statement (with parentheses)
ApplyImplicitRules(symbol);
Walk(expr);
- if (auto converted{EvaluateNonPointerInitializer(
- symbol, expr, expr.thing.value().source)}) {
+ if (auto converted{EvaluateNonPointerInitializer(symbol, expr, at)}) {
details->set_init(std::move(*converted));
}
}
@@ -6149,7 +6148,7 @@ bool DeclarationVisitor::Pre(const parser::KindParam &x) {
if (const auto *kind{std::get_if<
parser::Scalar<parser::Integer<parser::Constant<parser::Name>>>>(
&x.u)}) {
- const parser::Name &name{kind->thing.thing.thing};
+ const auto &name{parser::UnwrapRef<parser::Name>(kind)};
if (!FindSymbol(name)) {
Say(name, "Parameter '%s' not found"_err_en_US);
}
@@ -7460,7 +7459,7 @@ void DeclarationVisitor::DeclareLocalEntity(
Symbol *DeclarationVisitor::DeclareStatementEntity(
const parser::DoVariable &doVar,
const std::optional<parser::IntegerTypeSpec> &type) {
- const parser::Name &name{doVar.thing.thing};
+ const auto &name{parser::UnwrapRef<parser::Name>(doVar)};
const DeclTypeSpec *declTypeSpec{nullptr};
if (auto *prev{FindSymbol(name)}) {
if (prev->owner() == currScope()) {
@@ -7893,13 +7892,14 @@ bool ConstructVisitor::Pre(const parser::DataIDoObject &x) {
common::visit(
common::visitors{
[&](const parser::Scalar<Indirection<parser::Designator>> &y) {
- Walk(y.thing.value());
- const parser::Name &first{parser::GetFirstName(y.thing.value())};
+ const auto &designator{parser::UnwrapRef<parser::Designator>(y)};
+ Walk(designator);
+ const parser::Name &first{parser::GetFirstName(designator)};
if (first.symbol) {
first.symbol->set(Symbol::Flag::InDataStmt);
}
},
- [&](const Indirection<parser::DataImpliedDo> &y) { Walk(y.value()); },
+ [&](const Indirection<parser::DataImpliedDo> &y) { Walk(y); },
},
x.u);
return false;
@@ -8582,8 +8582,7 @@ public:
void Post(const parser::WriteStmt &) { inAsyncIO_ = false; }
void Post(const parser::IoControlSpec::Size &size) {
if (const auto *designator{
- std::get_if<common::Indirection<parser::Designator>>(
- &size.v.thing.thing.u)}) {
+ parser::Unwrap<common::Indirection<parser::Designator>>(size)}) {
NoteAsyncIODesignator(designator->value());
}
}
@@ -9175,16 +9174,17 @@ bool DeclarationVisitor::CheckNonPointerInitialization(
}
void DeclarationVisitor::NonPointerInitialization(
- const parser::Name &name, const parser::ConstantExpr &expr) {
+ const parser::Name &name, const parser::ConstantExpr &constExpr) {
if (CheckNonPointerInitialization(
name, /*inLegacyDataInitialization=*/false)) {
Symbol &ultimate{name.symbol->GetUltimate()};
auto &details{ultimate.get<ObjectEntityDetails>()};
+ const auto &expr{parser::UnwrapRef<parser::Expr>(constExpr)};
if (ultimate.owner().IsParameterizedDerivedType()) {
// Save the expression for per-instantiation analysis.
- details.set_unanalyzedPDTComponentInit(&expr.thing.value());
+ details.set_unanalyzedPDTComponentInit(&expr);
} else if (MaybeExpr folded{EvaluateNonPointerInitializer(
- ultimate, expr, expr.thing.value().source)}) {
+ ultimate, constExpr, expr.source)}) {
details.set_init(std::move(*folded));
ultimate.set(Symbol::Flag::InDataStmt, false);
}