Diffstat (limited to 'flang/lib')
29 files changed, 2416 insertions(+), 179 deletions(-)
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 68adf34..0595ca0 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -4987,11 +4987,8 @@ private: // host = device if (!lhsIsDevice && rhsIsDevice) { - if (Fortran::lower::isTransferWithConversion(rhs)) { + if (auto elementalOp = Fortran::lower::isTransferWithConversion(rhs)) { mlir::OpBuilder::InsertionGuard insertionGuard(builder); - auto elementalOp = - mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp()); - assert(elementalOp && "expect elemental op"); auto designateOp = *elementalOp.getBody()->getOps<hlfir::DesignateOp>().begin(); builder.setInsertionPoint(elementalOp); @@ -6079,7 +6076,7 @@ private: if (resTy != wrappedSymTy) { // check size of the pointed to type so we can't overflow by writing // double precision to a single precision allocation, etc - LLVM_ATTRIBUTE_UNUSED auto getBitWidth = [this](mlir::Type ty) { + [[maybe_unused]] auto getBitWidth = [this](mlir::Type ty) { // 15.6.2.6.3: differering result types should be integer, real, // complex or logical if (auto cmplx = mlir::dyn_cast_or_null<mlir::ComplexType>(ty)) diff --git a/flang/lib/Lower/CUDA.cpp b/flang/lib/Lower/CUDA.cpp index bb4bdee..9501b0e 100644 --- a/flang/lib/Lower/CUDA.cpp +++ b/flang/lib/Lower/CUDA.cpp @@ -68,11 +68,26 @@ cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute( return cuf::getDataAttribute(mlirContext, cudaAttr); } -bool Fortran::lower::isTransferWithConversion(mlir::Value rhs) { +hlfir::ElementalOp Fortran::lower::isTransferWithConversion(mlir::Value rhs) { + auto isConversionElementalOp = [](hlfir::ElementalOp elOp) { + return llvm::hasSingleElement( + elOp.getBody()->getOps<hlfir::DesignateOp>()) && + llvm::hasSingleElement(elOp.getBody()->getOps<fir::LoadOp>()) == 1 && + llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>()) == + 1; + }; + if (auto declOp = mlir::dyn_cast<hlfir::DeclareOp>(rhs.getDefiningOp())) { + if (!declOp.getMemref().getDefiningOp()) + return {}; + if (auto associateOp = mlir::dyn_cast<hlfir::AssociateOp>( + declOp.getMemref().getDefiningOp())) + if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>( + associateOp.getSource().getDefiningOp())) + if (isConversionElementalOp(elOp)) + return elOp; + } if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp())) - if (llvm::hasSingleElement(elOp.getBody()->getOps<hlfir::DesignateOp>()) && - llvm::hasSingleElement(elOp.getBody()->getOps<fir::LoadOp>()) == 1 && - llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>()) == 1) - return true; - return false; + if (isConversionElementalOp(elOp)) + return elOp; + return {}; } diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp index d7f94e1..a46d219 100644 --- a/flang/lib/Lower/ConvertExpr.cpp +++ b/flang/lib/Lower/ConvertExpr.cpp @@ -5603,7 +5603,7 @@ private: return newIters; }; if (useTripsForSlice) { - LLVM_ATTRIBUTE_UNUSED auto vectorSubscriptShape = + [[maybe_unused]] auto vectorSubscriptShape = getShape(arrayOperands.back()); auto undef = fir::UndefOp::create(builder, loc, idxTy); trips.push_back(undef); diff --git a/flang/lib/Lower/IO.cpp b/flang/lib/Lower/IO.cpp index 604b137..cd53dc9 100644 --- a/flang/lib/Lower/IO.cpp +++ b/flang/lib/Lower/IO.cpp @@ -950,7 +950,8 @@ static void genIoLoop(Fortran::lower::AbstractConverter &converter, makeNextConditionalOn(builder, loc, checkResult, ok, inLoop); const auto &itemList = std::get<0>(ioImpliedDo.t); const auto &control = std::get<1>(ioImpliedDo.t); - 
const auto &loopSym = *control.name.thing.thing.symbol; + const auto &loopSym = + *Fortran::parser::UnwrapRef<Fortran::parser::Name>(control.name).symbol; mlir::Value loopVar = fir::getBase(converter.genExprAddr( Fortran::evaluate::AsGenericExpr(loopSym).value(), stmtCtx)); auto genControlValue = [&](const Fortran::parser::ScalarIntExpr &expr) { diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp index ff82a36..3ab8a58 100644 --- a/flang/lib/Lower/OpenMP/Atomic.cpp +++ b/flang/lib/Lower/OpenMP/Atomic.cpp @@ -20,6 +20,7 @@ #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/Todo.h" #include "flang/Parser/parse-tree.h" +#include "flang/Semantics/openmp-utils.h" #include "flang/Semantics/semantics.h" #include "flang/Semantics/type.h" #include "flang/Support/Fortran.h" @@ -183,12 +184,8 @@ getMemoryOrderFromRequires(const semantics::Scope &scope) { // scope. // For safety, traverse all enclosing scopes and check if their symbol // contains REQUIRES. - for (const auto *sc{&scope}; sc->kind() != semantics::Scope::Kind::Global; - sc = &sc->parent()) { - const semantics::Symbol *sym = sc->symbol(); - if (!sym) - continue; - + const semantics::Scope &unitScope = semantics::omp::GetProgramUnit(scope); + if (auto *symbol = unitScope.symbol()) { const common::OmpMemoryOrderType *admo = common::visit( [](auto &&s) { using WithOmpDeclarative = semantics::WithOmpDeclarative; @@ -198,7 +195,8 @@ getMemoryOrderFromRequires(const semantics::Scope &scope) { } return static_cast<const common::OmpMemoryOrderType *>(nullptr); }, - sym->details()); + symbol->details()); + if (admo) return getMemoryOrderKind(*admo); } @@ -214,19 +212,83 @@ getDefaultAtomicMemOrder(semantics::SemanticsContext &semaCtx) { return std::nullopt; } -static std::optional<mlir::omp::ClauseMemoryOrderKind> +static std::pair<std::optional<mlir::omp::ClauseMemoryOrderKind>, bool> getAtomicMemoryOrder(semantics::SemanticsContext &semaCtx, const omp::List<omp::Clause> &clauses, const semantics::Scope &scope) { for (const omp::Clause &clause : clauses) { if (auto maybeKind = getMemoryOrderKind(clause.id)) - return *maybeKind; + return std::make_pair(*maybeKind, /*canOverride=*/false); } if (auto maybeKind = getMemoryOrderFromRequires(scope)) - return *maybeKind; + return std::make_pair(*maybeKind, /*canOverride=*/true); - return getDefaultAtomicMemOrder(semaCtx); + return std::make_pair(getDefaultAtomicMemOrder(semaCtx), + /*canOverride=*/false); +} + +static std::optional<mlir::omp::ClauseMemoryOrderKind> +makeValidForAction(std::optional<mlir::omp::ClauseMemoryOrderKind> memOrder, + int action0, int action1, unsigned version) { + // When the atomic default memory order specified on a REQUIRES directive is + // disallowed on a given ATOMIC operation, and it's not ACQ_REL, the order + // reverts to RELAXED. ACQ_REL decays to either ACQUIRE or RELEASE, depending + // on the operation. + + if (!memOrder) { + return memOrder; + } + + using Analysis = parser::OpenMPAtomicConstruct::Analysis; + // Figure out the main action (i.e. disregard a potential capture operation) + int action = action0; + if (action1 != Analysis::None) + action = action0 == Analysis::Read ? 
action1 : action0; + + // Avaliable orderings: acquire, acq_rel, relaxed, release, seq_cst + + if (action == Analysis::Read) { + // "acq_rel" decays to "acquire" + if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acq_rel) + return mlir::omp::ClauseMemoryOrderKind::Acquire; + } else if (action == Analysis::Write) { + // "acq_rel" decays to "release" + if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acq_rel) + return mlir::omp::ClauseMemoryOrderKind::Release; + } + + if (version > 50) { + if (action == Analysis::Read) { + // "release" prohibited + if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Release) + return mlir::omp::ClauseMemoryOrderKind::Relaxed; + } + if (action == Analysis::Write) { + // "acquire" prohibited + if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acquire) + return mlir::omp::ClauseMemoryOrderKind::Relaxed; + } + } else { + if (action == Analysis::Read) { + // "release" prohibited + if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Release) + return mlir::omp::ClauseMemoryOrderKind::Relaxed; + } else { + if (action & Analysis::Write) { // include "update" + // "acquire" prohibited + if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acquire) + return mlir::omp::ClauseMemoryOrderKind::Relaxed; + if (action == Analysis::Update) { + // "acq_rel" prohibited + if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acq_rel) + return mlir::omp::ClauseMemoryOrderKind::Relaxed; + } + } + } + } + + return memOrder; } static mlir::omp::ClauseMemoryOrderKindAttr @@ -449,16 +511,19 @@ void Fortran::lower::omp::lowerAtomic( mlir::Value atomAddr = fir::getBase(converter.genExprAddr(atom, stmtCtx, &loc)); mlir::IntegerAttr hint = getAtomicHint(converter, clauses); - std::optional<mlir::omp::ClauseMemoryOrderKind> memOrder = - getAtomicMemoryOrder(semaCtx, clauses, - semaCtx.FindScope(construct.source)); + auto [memOrder, canOverride] = getAtomicMemoryOrder( + semaCtx, clauses, semaCtx.FindScope(construct.source)); + + unsigned version = semaCtx.langOptions().OpenMPVersion; + int action0 = analysis.op0.what & analysis.Action; + int action1 = analysis.op1.what & analysis.Action; + if (canOverride) + memOrder = makeValidForAction(memOrder, action0, action1, version); if (auto *cond = get(analysis.cond)) { (void)cond; TODO(loc, "OpenMP ATOMIC COMPARE"); } else { - int action0 = analysis.op0.what & analysis.Action; - int action1 = analysis.op1.what & analysis.Action; mlir::Operation *captureOp = nullptr; fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint(); fir::FirOpBuilder::InsertPoint atomicAt, postAt; diff --git a/flang/lib/Optimizer/Builder/Character.cpp b/flang/lib/Optimizer/Builder/Character.cpp index a096099..155bc0f 100644 --- a/flang/lib/Optimizer/Builder/Character.cpp +++ b/flang/lib/Optimizer/Builder/Character.cpp @@ -92,7 +92,7 @@ getCompileTimeLength(const fir::CharBoxValue &box) { /// Detect the precondition that the value `str` does not reside in memory. Such /// values will have a type `!fir.array<...x!fir.char<N>>` or `!fir.char<N>`. 
-LLVM_ATTRIBUTE_UNUSED static bool needToMaterialize(mlir::Value str) { +[[maybe_unused]] static bool needToMaterialize(mlir::Value str) { return mlir::isa<fir::SequenceType>(str.getType()) || fir::isa_char(str.getType()); } diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index e07baaf..0195178 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -2169,7 +2169,8 @@ IntrinsicLibrary::genElementalCall<IntrinsicLibrary::ExtendedGenerator>( for (const fir::ExtendedValue &arg : args) { auto *box = arg.getBoxOf<fir::BoxValue>(); if (!arg.getUnboxed() && !arg.getCharBox() && - !(box && fir::isScalarBoxedRecordType(fir::getBase(*box).getType()))) + !(box && (fir::isScalarBoxedRecordType(fir::getBase(*box).getType()) || + fir::isClassStarType(fir::getBase(*box).getType())))) fir::emitFatalError(loc, "nonscalar intrinsic argument"); } if (outline) diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp index 4a9579c..48e1622 100644 --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -336,6 +336,17 @@ bool isBoxedRecordType(mlir::Type ty) { return false; } +// CLASS(*) +bool isClassStarType(mlir::Type ty) { + if (auto clTy = mlir::dyn_cast<fir::ClassType>(fir::unwrapRefType(ty))) { + if (mlir::isa<mlir::NoneType>(clTy.getEleTy())) + return true; + mlir::Type innerType = clTy.unwrapInnerType(); + return innerType && mlir::isa<mlir::NoneType>(innerType); + } + return false; +} + bool isScalarBoxedRecordType(mlir::Type ty) { if (auto refTy = fir::dyn_cast_ptrEleTy(ty)) ty = refTy; @@ -398,12 +409,8 @@ bool isPolymorphicType(mlir::Type ty) { bool isUnlimitedPolymorphicType(mlir::Type ty) { // CLASS(*) - if (auto clTy = mlir::dyn_cast<fir::ClassType>(fir::unwrapRefType(ty))) { - if (mlir::isa<mlir::NoneType>(clTy.getEleTy())) - return true; - mlir::Type innerType = clTy.unwrapInnerType(); - return innerType && mlir::isa<mlir::NoneType>(innerType); - } + if (isClassStarType(ty)) + return true; // TYPE(*) return isAssumedType(ty); } diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp index a48b7ba..63a5803 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ScheduleOrderedAssignments.cpp @@ -21,24 +21,27 @@ //===----------------------------------------------------------------------===// /// Log RAW or WAW conflict. -static void LLVM_ATTRIBUTE_UNUSED logConflict(llvm::raw_ostream &os, - mlir::Value writtenOrReadVarA, - mlir::Value writtenVarB); +[[maybe_unused]] static void logConflict(llvm::raw_ostream &os, + mlir::Value writtenOrReadVarA, + mlir::Value writtenVarB); /// Log when an expression evaluation must be saved. -static void LLVM_ATTRIBUTE_UNUSED logSaveEvaluation(llvm::raw_ostream &os, - unsigned runid, - mlir::Region &yieldRegion, - bool anyWrite); +[[maybe_unused]] static void logSaveEvaluation(llvm::raw_ostream &os, + unsigned runid, + mlir::Region &yieldRegion, + bool anyWrite); /// Log when an assignment is scheduled. 
-static void LLVM_ATTRIBUTE_UNUSED logAssignmentEvaluation( - llvm::raw_ostream &os, unsigned runid, hlfir::RegionAssignOp assign); +[[maybe_unused]] static void +logAssignmentEvaluation(llvm::raw_ostream &os, unsigned runid, + hlfir::RegionAssignOp assign); /// Log when starting to schedule an order assignment tree. -static void LLVM_ATTRIBUTE_UNUSED logStartScheduling( - llvm::raw_ostream &os, hlfir::OrderedAssignmentTreeOpInterface root); +[[maybe_unused]] static void +logStartScheduling(llvm::raw_ostream &os, + hlfir::OrderedAssignmentTreeOpInterface root); /// Log op if effect value is not known. -static void LLVM_ATTRIBUTE_UNUSED logIfUnkownEffectValue( - llvm::raw_ostream &os, mlir::MemoryEffects::EffectInstance effect, - mlir::Operation &op); +[[maybe_unused]] static void +logIfUnkownEffectValue(llvm::raw_ostream &os, + mlir::MemoryEffects::EffectInstance effect, + mlir::Operation &op); //===----------------------------------------------------------------------===// // Scheduling Implementation @@ -701,23 +704,24 @@ static llvm::raw_ostream &printRegionPath(llvm::raw_ostream &os, return printRegionId(os, yieldRegion); } -static void LLVM_ATTRIBUTE_UNUSED logSaveEvaluation(llvm::raw_ostream &os, - unsigned runid, - mlir::Region &yieldRegion, - bool anyWrite) { +[[maybe_unused]] static void logSaveEvaluation(llvm::raw_ostream &os, + unsigned runid, + mlir::Region &yieldRegion, + bool anyWrite) { os << "run " << runid << " save " << (anyWrite ? "(w)" : " ") << ": "; printRegionPath(os, yieldRegion) << "\n"; } -static void LLVM_ATTRIBUTE_UNUSED logAssignmentEvaluation( - llvm::raw_ostream &os, unsigned runid, hlfir::RegionAssignOp assign) { +[[maybe_unused]] static void +logAssignmentEvaluation(llvm::raw_ostream &os, unsigned runid, + hlfir::RegionAssignOp assign) { os << "run " << runid << " evaluate: "; printNodePath(os, assign.getOperation()) << "\n"; } -static void LLVM_ATTRIBUTE_UNUSED logConflict(llvm::raw_ostream &os, - mlir::Value writtenOrReadVarA, - mlir::Value writtenVarB) { +[[maybe_unused]] static void logConflict(llvm::raw_ostream &os, + mlir::Value writtenOrReadVarA, + mlir::Value writtenVarB) { auto printIfValue = [&](mlir::Value var) -> llvm::raw_ostream & { if (!var) return os << "<unknown>"; @@ -728,8 +732,9 @@ static void LLVM_ATTRIBUTE_UNUSED logConflict(llvm::raw_ostream &os, printIfValue(writtenVarB) << "\n"; } -static void LLVM_ATTRIBUTE_UNUSED logStartScheduling( - llvm::raw_ostream &os, hlfir::OrderedAssignmentTreeOpInterface root) { +[[maybe_unused]] static void +logStartScheduling(llvm::raw_ostream &os, + hlfir::OrderedAssignmentTreeOpInterface root) { os << "------------ scheduling "; printNodePath(os, root.getOperation()); if (auto funcOp = root->getParentOfType<mlir::func::FuncOp>()) @@ -737,9 +742,10 @@ static void LLVM_ATTRIBUTE_UNUSED logStartScheduling( os << "------------\n"; } -static void LLVM_ATTRIBUTE_UNUSED logIfUnkownEffectValue( - llvm::raw_ostream &os, mlir::MemoryEffects::EffectInstance effect, - mlir::Operation &op) { +[[maybe_unused]] static void +logIfUnkownEffectValue(llvm::raw_ostream &os, + mlir::MemoryEffects::EffectInstance effect, + mlir::Operation &op) { if (effect.getValue() != nullptr) return; os << "unknown effected value ("; diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp index 9bf10b5..ed9e41c 100644 --- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp +++ 
b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp @@ -751,4 +751,245 @@ template bool OpenACCMappableModel<fir::PointerType>::generatePrivateDestroy( mlir::Type type, mlir::OpBuilder &builder, mlir::Location loc, mlir::Value privatized) const; +template <typename Ty> +mlir::Value OpenACCPointerLikeModel<Ty>::genAllocate( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar, + bool &needsFree) const { + + // Unwrap to get the pointee type. + mlir::Type pointeeTy = fir::dyn_cast_ptrEleTy(pointer); + assert(pointeeTy && "expected pointee type to be extractable"); + + // Box types are descriptors that contain both metadata and a pointer to data. + // The `genAllocate` API is designed for simple allocations and cannot + // properly handle the dual nature of boxes. Using `generatePrivateInit` + // instead can allocate both the descriptor and its referenced data. For use + // cases that require an empty descriptor storage, potentially this could be + // implemented here. + if (fir::isa_box_type(pointeeTy)) + return {}; + + // Unlimited polymorphic (class(*)) cannot be handled - size unknown + if (fir::isUnlimitedPolymorphicType(pointeeTy)) + return {}; + + // Return null for dynamic size types because the size of the + // allocation cannot be determined simply from the type. + if (fir::hasDynamicSize(pointeeTy)) + return {}; + + // Use heap allocation for fir.heap, stack allocation for others (fir.ref, + // fir.ptr, fir.llvm_ptr). For fir.ptr, which is supposed to represent a + // Fortran pointer type, it feels a bit odd to "allocate" since it is meant + // to point to an existing entity - but one can imagine where a pointee is + // privatized - thus it makes sense to issue an allocate. + mlir::Value allocation; + if (std::is_same_v<Ty, fir::HeapType>) { + needsFree = true; + allocation = fir::AllocMemOp::create(builder, loc, pointeeTy); + } else { + needsFree = false; + allocation = fir::AllocaOp::create(builder, loc, pointeeTy); + } + + // Convert to the requested pointer type if needed. + // This means converting from a fir.ref to either a fir.llvm_ptr or a fir.ptr. + // fir.heap is already correct type in this case. 
+ if (allocation.getType() != pointer) { + assert(!(std::is_same_v<Ty, fir::HeapType>) && + "fir.heap is already correct type because of allocmem"); + return fir::ConvertOp::create(builder, loc, pointer, allocation); + } + + return allocation; +} + +template mlir::Value OpenACCPointerLikeModel<fir::ReferenceType>::genAllocate( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar, + bool &needsFree) const; + +template mlir::Value OpenACCPointerLikeModel<fir::PointerType>::genAllocate( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar, + bool &needsFree) const; + +template mlir::Value OpenACCPointerLikeModel<fir::HeapType>::genAllocate( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar, + bool &needsFree) const; + +template mlir::Value OpenACCPointerLikeModel<fir::LLVMPointerType>::genAllocate( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + llvm::StringRef varName, mlir::Type varType, mlir::Value originalVar, + bool &needsFree) const; + +static mlir::Value stripCasts(mlir::Value value, bool stripDeclare = true) { + mlir::Value currentValue = value; + + while (currentValue) { + auto *definingOp = currentValue.getDefiningOp(); + if (!definingOp) + break; + + if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(definingOp)) { + currentValue = convertOp.getValue(); + continue; + } + + if (auto viewLike = mlir::dyn_cast<mlir::ViewLikeOpInterface>(definingOp)) { + currentValue = viewLike.getViewSource(); + continue; + } + + if (stripDeclare) { + if (auto declareOp = mlir::dyn_cast<hlfir::DeclareOp>(definingOp)) { + currentValue = declareOp.getMemref(); + continue; + } + + if (auto declareOp = mlir::dyn_cast<fir::DeclareOp>(definingOp)) { + currentValue = declareOp.getMemref(); + continue; + } + } + break; + } + + return currentValue; +} + +template <typename Ty> +bool OpenACCPointerLikeModel<Ty>::genFree( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> varToFree, + mlir::Value allocRes, mlir::Type varType) const { + + // Unwrap to get the pointee type. + mlir::Type pointeeTy = fir::dyn_cast_ptrEleTy(pointer); + assert(pointeeTy && "expected pointee type to be extractable"); + + // Box types contain both a descriptor and data. The `genFree` API + // handles simple deallocations and cannot properly manage both parts. + // Using `generatePrivateDestroy` instead can free both the descriptor and + // its referenced data. + if (fir::isa_box_type(pointeeTy)) + return false; + + // If pointer type is HeapType, assume it's a heap allocation + if (std::is_same_v<Ty, fir::HeapType>) { + fir::FreeMemOp::create(builder, loc, varToFree); + return true; + } + + // Use allocRes if provided to determine the allocation type + mlir::Value valueToInspect = allocRes ? 
allocRes : varToFree; + + // Strip casts and declare operations to find the original allocation + mlir::Value strippedValue = stripCasts(valueToInspect); + mlir::Operation *originalAlloc = strippedValue.getDefiningOp(); + + // If we found an AllocMemOp (heap allocation), free it + if (mlir::isa_and_nonnull<fir::AllocMemOp>(originalAlloc)) { + mlir::Value toFree = varToFree; + if (!mlir::isa<fir::HeapType>(valueToInspect.getType())) + toFree = fir::ConvertOp::create( + builder, loc, + fir::HeapType::get(varToFree.getType().getElementType()), toFree); + fir::FreeMemOp::create(builder, loc, toFree); + return true; + } + + // If we found an AllocaOp (stack allocation), no deallocation needed + if (mlir::isa_and_nonnull<fir::AllocaOp>(originalAlloc)) + return true; + + // Unable to determine allocation type + return false; +} + +template bool OpenACCPointerLikeModel<fir::ReferenceType>::genFree( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> varToFree, + mlir::Value allocRes, mlir::Type varType) const; + +template bool OpenACCPointerLikeModel<fir::PointerType>::genFree( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> varToFree, + mlir::Value allocRes, mlir::Type varType) const; + +template bool OpenACCPointerLikeModel<fir::HeapType>::genFree( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> varToFree, + mlir::Value allocRes, mlir::Type varType) const; + +template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::genFree( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> varToFree, + mlir::Value allocRes, mlir::Type varType) const; + +template <typename Ty> +bool OpenACCPointerLikeModel<Ty>::genCopy( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> destination, + mlir::TypedValue<mlir::acc::PointerLikeType> source, + mlir::Type varType) const { + + // Check that source and destination types match + if (source.getType() != destination.getType()) + return false; + + // Unwrap to get the pointee type. + mlir::Type pointeeTy = fir::dyn_cast_ptrEleTy(pointer); + assert(pointeeTy && "expected pointee type to be extractable"); + + // Box types contain both a descriptor and referenced data. The genCopy API + // handles simple copies and cannot properly manage both parts. + if (fir::isa_box_type(pointeeTy)) + return false; + + // Unlimited polymorphic (class(*)) cannot be handled because source and + // destination types are not known. + if (fir::isUnlimitedPolymorphicType(pointeeTy)) + return false; + + // Return false for dynamic size types because the copy logic + // cannot be determined simply from the type. 
+ if (fir::hasDynamicSize(pointeeTy)) + return false; + + if (fir::isa_trivial(pointeeTy)) { + auto loadVal = fir::LoadOp::create(builder, loc, source); + fir::StoreOp::create(builder, loc, loadVal, destination); + } else { + hlfir::AssignOp::create(builder, loc, source, destination); + } + return true; +} + +template bool OpenACCPointerLikeModel<fir::ReferenceType>::genCopy( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> destination, + mlir::TypedValue<mlir::acc::PointerLikeType> source, + mlir::Type varType) const; + +template bool OpenACCPointerLikeModel<fir::PointerType>::genCopy( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> destination, + mlir::TypedValue<mlir::acc::PointerLikeType> source, + mlir::Type varType) const; + +template bool OpenACCPointerLikeModel<fir::HeapType>::genCopy( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> destination, + mlir::TypedValue<mlir::acc::PointerLikeType> source, + mlir::Type varType) const; + +template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::genCopy( + mlir::Type pointer, mlir::OpBuilder &builder, mlir::Location loc, + mlir::TypedValue<mlir::acc::PointerLikeType> destination, + mlir::TypedValue<mlir::acc::PointerLikeType> source, + mlir::Type varType) const; + } // namespace fir::acc diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt index b85ee7e..23a7dc8 100644 --- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt +++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt @@ -8,6 +8,7 @@ add_flang_library(FlangOpenMPTransforms MapsForPrivatizedSymbols.cpp MapInfoFinalization.cpp MarkDeclareTarget.cpp + LowerWorkdistribute.cpp LowerWorkshare.cpp LowerNontemporal.cpp SimdOnly.cpp diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp new file mode 100644 index 0000000..9278e17 --- /dev/null +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -0,0 +1,1852 @@ +//===- LowerWorkdistribute.cpp +//-------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the lowering and optimisations of omp.workdistribute. +// +// Fortran array statements are lowered to fir as fir.do_loop unordered. +// lower-workdistribute pass works mainly on identifying fir.do_loop unordered +// that is nested in target{teams{workdistribute{fir.do_loop unordered}}} and +// lowers it to target{teams{parallel{distribute{wsloop{loop_nest}}}}}. +// It hoists all the other ops outside target region. +// Relaces heap allocation on target with omp.target_allocmem and +// deallocation with omp.target_freemem from host. Also replaces +// runtime function "Assign" with omp_target_memcpy. 
+// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/HLFIR/Passes.h" +#include "flang/Optimizer/OpenMP/Utils.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/Frontend/OpenMP/OMPConstants.h" +#include <mlir/Dialect/Arith/IR/Arith.h> +#include <mlir/Dialect/LLVMIR/LLVMTypes.h> +#include <mlir/Dialect/Utils/IndexingUtils.h> +#include <mlir/IR/BlockSupport.h> +#include <mlir/IR/BuiltinOps.h> +#include <mlir/IR/Diagnostics.h> +#include <mlir/IR/IRMapping.h> +#include <mlir/IR/PatternMatch.h> +#include <mlir/Interfaces/SideEffectInterfaces.h> +#include <mlir/Support/LLVM.h> +#include <optional> +#include <variant> + +namespace flangomp { +#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE +#include "flang/Optimizer/OpenMP/Passes.h.inc" +} // namespace flangomp + +#define DEBUG_TYPE "lower-workdistribute" + +using namespace mlir; + +namespace { + +/// This string is used to identify the Fortran-specific runtime FortranAAssign. +static constexpr llvm::StringRef FortranAssignStr = "_FortranAAssign"; + +/// The isRuntimeCall function is a utility designed to determine +/// if a given operation is a call to a Fortran-specific runtime function. +static bool isRuntimeCall(Operation *op) { + if (auto callOp = dyn_cast<fir::CallOp>(op)) { + auto callee = callOp.getCallee(); + if (!callee) + return false; + auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee); + if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName())) + return true; + } + return false; +} + +/// This is the single source of truth about whether we should parallelize an +/// operation nested in an omp.workdistribute region. +/// Parallelize here refers to dividing into units of work. +static bool shouldParallelize(Operation *op) { + // True if the op is a runtime call to Assign + if (isRuntimeCall(op)) { + fir::CallOp runtimeCall = cast<fir::CallOp>(op); + auto funcName = runtimeCall.getCallee()->getRootReference().getValue(); + if (funcName == FortranAssignStr) { + return true; + } + } + // We cannot parallelize ops with side effects. + // Parallelizable operations should not produce + // values that other operations depend on + if (llvm::any_of(op->getResults(), + [](OpResult v) -> bool { return !v.use_empty(); })) + return false; + // We will parallelize unordered loops - these come from array syntax + if (auto loop = dyn_cast<fir::DoLoopOp>(op)) { + auto unordered = loop.getUnordered(); + if (!unordered) + return false; + return *unordered; + } + // We cannot parallelize anything else. + return false; +} + +/// The getPerfectlyNested function is a generic utility for finding +/// a single, "perfectly nested" operation within a parent operation. 
+template <typename T> +static T getPerfectlyNested(Operation *op) { + if (op->getNumRegions() != 1) + return nullptr; + auto ®ion = op->getRegion(0); + if (region.getBlocks().size() != 1) + return nullptr; + auto *block = ®ion.front(); + auto *firstOp = &block->front(); + if (auto nested = dyn_cast<T>(firstOp)) + if (firstOp->getNextNode() == block->getTerminator()) + return nested; + return nullptr; +} + +/// verifyTargetTeamsWorkdistribute method verifies that +/// omp.target { teams { workdistribute { ... } } } is well formed +/// and fails for function calls that don't have lowering implemented yet. +static LogicalResult +verifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) { + OpBuilder rewriter(workdistribute); + auto loc = workdistribute->getLoc(); + auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp()); + if (!teams) { + emitError(loc, "workdistribute not nested in teams\n"); + return failure(); + } + if (workdistribute.getRegion().getBlocks().size() != 1) { + emitError(loc, "workdistribute with multiple blocks\n"); + return failure(); + } + if (teams.getRegion().getBlocks().size() != 1) { + emitError(loc, "teams with multiple blocks\n"); + return failure(); + } + + bool foundWorkdistribute = false; + for (auto &op : teams.getOps()) { + if (isa<omp::WorkdistributeOp>(op)) { + if (foundWorkdistribute) { + emitError(loc, "teams has multiple workdistribute ops.\n"); + return failure(); + } + foundWorkdistribute = true; + continue; + } + // Identify any omp dialect ops present before/after workdistribute. + if (op.getDialect() && isa<omp::OpenMPDialect>(op.getDialect()) && + !isa<omp::TerminatorOp>(op)) { + emitError(loc, "teams has omp ops other than workdistribute. Lowering " + "not implemented yet.\n"); + return failure(); + } + } + + omp::TargetOp targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp()); + // return if not omp.target + if (!targetOp) + return success(); + + for (auto &op : workdistribute.getOps()) { + if (auto callOp = dyn_cast<fir::CallOp>(op)) { + if (isRuntimeCall(&op)) { + auto funcName = (*callOp.getCallee()).getRootReference().getValue(); + // _FortranAAssign is handled. Other runtime calls are not supported + // in omp.workdistribute yet. + if (funcName == FortranAssignStr) + continue; + else { + emitError(loc, "Runtime call " + funcName + + " lowering not supported for workdistribute yet."); + return failure(); + } + } + } + } + return success(); +} + +/// fissionWorkdistribute method finds the parallelizable ops +/// within teams {workdistribute} region and moves them to their +/// own teams{workdistribute} region. +/// +/// If B() and D() are parallelizable, +/// +/// omp.teams { +/// omp.workdistribute { +/// A() +/// B() +/// C() +/// D() +/// E() +/// } +/// } +/// +/// becomes +/// +/// A() +/// omp.teams { +/// omp.workdistribute { +/// B() +/// } +/// } +/// C() +/// omp.teams { +/// omp.workdistribute { +/// D() +/// } +/// } +/// E() +static FailureOr<bool> +fissionWorkdistribute(omp::WorkdistributeOp workdistribute) { + OpBuilder rewriter(workdistribute); + auto loc = workdistribute->getLoc(); + auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp()); + auto *teamsBlock = &teams.getRegion().front(); + bool changed = false; + // Move the ops inside teams and before workdistribute outside. 
+ IRMapping irMapping; + llvm::SmallVector<Operation *> teamsHoisted; + for (auto &op : teams.getOps()) { + if (&op == workdistribute) { + break; + } + if (shouldParallelize(&op)) { + emitError(loc, "teams has parallelize ops before first workdistribute\n"); + return failure(); + } else { + rewriter.setInsertionPoint(teams); + rewriter.clone(op, irMapping); + teamsHoisted.push_back(&op); + changed = true; + } + } + for (auto *op : llvm::reverse(teamsHoisted)) { + op->replaceAllUsesWith(irMapping.lookup(op)); + op->erase(); + } + + // While we have unhandled operations in the original workdistribute + auto *workdistributeBlock = &workdistribute.getRegion().front(); + auto *terminator = workdistributeBlock->getTerminator(); + while (&workdistributeBlock->front() != terminator) { + rewriter.setInsertionPoint(teams); + IRMapping mapping; + llvm::SmallVector<Operation *> hoisted; + Operation *parallelize = nullptr; + for (auto &op : workdistribute.getOps()) { + if (&op == terminator) { + break; + } + if (shouldParallelize(&op)) { + parallelize = &op; + break; + } else { + rewriter.clone(op, mapping); + hoisted.push_back(&op); + changed = true; + } + } + + for (auto *op : llvm::reverse(hoisted)) { + op->replaceAllUsesWith(mapping.lookup(op)); + op->erase(); + } + + if (parallelize && hoisted.empty() && + parallelize->getNextNode() == terminator) + break; + if (parallelize) { + auto newTeams = rewriter.cloneWithoutRegions(teams); + auto *newTeamsBlock = rewriter.createBlock( + &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {}); + for (auto arg : teamsBlock->getArguments()) + newTeamsBlock->addArgument(arg.getType(), arg.getLoc()); + auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc); + rewriter.create<omp::TerminatorOp>(loc); + rewriter.createBlock(&newWorkdistribute.getRegion(), + newWorkdistribute.getRegion().begin(), {}, {}); + auto *cloned = rewriter.clone(*parallelize); + parallelize->replaceAllUsesWith(cloned); + parallelize->erase(); + rewriter.create<omp::TerminatorOp>(loc); + changed = true; + } + } + return changed; +} + +/// Generate omp.parallel operation with an empty region. +static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) { + auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc); + parallelOp.setComposite(composite); + rewriter.createBlock(¶llelOp.getRegion()); + rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc)); + return; +} + +/// Generate omp.distribute operation with an empty region. +static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) { + mlir::omp::DistributeOperands distributeClauseOps; + auto distributeOp = + rewriter.create<mlir::omp::DistributeOp>(loc, distributeClauseOps); + distributeOp.setComposite(composite); + auto distributeBlock = rewriter.createBlock(&distributeOp.getRegion()); + rewriter.setInsertionPointToStart(distributeBlock); + return; +} + +/// Generate loop nest clause operands from fir.do_loop operation. 
+static void +genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop, + mlir::omp::LoopNestOperands &loopNestClauseOps) { + assert(loopNestClauseOps.loopLowerBounds.empty() && + "Loop nest bounds were already emitted!"); + loopNestClauseOps.loopLowerBounds.push_back(loop.getLowerBound()); + loopNestClauseOps.loopUpperBounds.push_back(loop.getUpperBound()); + loopNestClauseOps.loopSteps.push_back(loop.getStep()); + loopNestClauseOps.loopInclusive = rewriter.getUnitAttr(); +} + +/// Generate omp.wsloop operation with an empty region and +/// clone the body of fir.do_loop operation inside the loop nest region. +static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop, + const mlir::omp::LoopNestOperands &clauseOps, + bool composite) { + + auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc()); + wsloopOp.setComposite(composite); + rewriter.createBlock(&wsloopOp.getRegion()); + + auto loopNestOp = + rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps); + + // Clone the loop's body inside the loop nest construct using the + // mapped values. + rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(), + loopNestOp.getRegion().begin()); + Block *clonedBlock = &loopNestOp.getRegion().back(); + mlir::Operation *terminatorOp = clonedBlock->getTerminator(); + + // Erase fir.result op of do loop and create yield op. + if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) { + rewriter.setInsertionPoint(terminatorOp); + rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc()); + terminatorOp->erase(); + } +} + +/// workdistributeDoLower method finds the fir.do_loop unoredered +/// nested in teams {workdistribute{fir.do_loop unoredered}} and +/// lowers it to teams {parallel { distribute {wsloop {loop_nest}}}}. +/// +/// If fir.do_loop is present inside teams workdistribute +/// +/// omp.teams { +/// omp.workdistribute { +/// fir.do_loop unoredered { +/// ... +/// } +/// } +/// } +/// +/// Then, its lowered to +/// +/// omp.teams { +/// omp.parallel { +/// omp.distribute { +/// omp.wsloop { +/// omp.loop_nest +/// ... +/// } +/// } +/// } +/// } +/// } +static bool +workdistributeDoLower(omp::WorkdistributeOp workdistribute, + SetVector<omp::TargetOp> &targetOpsToProcess) { + OpBuilder rewriter(workdistribute); + auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute); + auto wdLoc = workdistribute->getLoc(); + if (doLoop && shouldParallelize(doLoop)) { + assert(doLoop.getReduceOperands().empty()); + + // Record the target ops to process later + if (auto teamsOp = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp())) { + auto targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp()); + if (targetOp) { + targetOpsToProcess.insert(targetOp); + } + } + // Generate the nested parallel, distribute, wsloop and loop_nest ops. 
+ genParallelOp(wdLoc, rewriter, true); + genDistributeOp(wdLoc, rewriter, true); + mlir::omp::LoopNestOperands loopNestClauseOps; + genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps); + genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true); + workdistribute.erase(); + return true; + } + return false; +} + +/// Check if the enclosed type in fir.ref is fir.box and fir.box encloses array +static bool isEnclosedTypeRefToBoxArray(Type type) { + // Check if it's a reference type + if (auto refType = dyn_cast<fir::ReferenceType>(type)) { + // Get the referenced type (should be fir.box) + auto referencedType = refType.getEleTy(); + // Check if referenced type is a box + if (auto boxType = dyn_cast<fir::BoxType>(referencedType)) { + // Get the boxed type and check if it's an array + auto boxedType = boxType.getEleTy(); + // Check if boxed type is a sequence (array) + return isa<fir::SequenceType>(boxedType); + } + } + return false; +} + +/// Check if the enclosed type in fir.box is scalar (not array) +static bool isEnclosedTypeBoxScalar(Type type) { + // Check if it's a box type + if (auto boxType = dyn_cast<fir::BoxType>(type)) { + // Get the boxed type + auto boxedType = boxType.getEleTy(); + // Check if boxed type is NOT a sequence (array) + return !isa<fir::SequenceType>(boxedType); + } + return false; +} + +/// Check if the FortranAAssign call has src as scalar and dest as array +static bool isFortranAssignSrcScalarAndDestArray(fir::CallOp callOp) { + if (callOp.getNumOperands() < 2) + return false; + auto srcArg = callOp.getOperand(1); + auto destArg = callOp.getOperand(0); + // Both operands should be fir.convert ops + auto srcConvert = srcArg.getDefiningOp<fir::ConvertOp>(); + auto destConvert = destArg.getDefiningOp<fir::ConvertOp>(); + if (!srcConvert || !destConvert) { + emitError(callOp->getLoc(), + "Unimplemented: FortranAssign to OpenMP lowering\n"); + return false; + } + // Get the original types before conversion + auto srcOrigType = srcConvert.getValue().getType(); + auto destOrigType = destConvert.getValue().getType(); + + // Check if src is scalar and dest is array + bool srcIsScalar = isEnclosedTypeBoxScalar(srcOrigType); + bool destIsArray = isEnclosedTypeRefToBoxArray(destOrigType); + return srcIsScalar && destIsArray; +} + +/// Convert a flat index to multi-dimensional indices for an array box +/// Example: 2D array with shape (2,4) +/// Col 1 Col 2 Col 3 Col 4 +/// Row 1: (1,1) (1,2) (1,3) (1,4) +/// Row 2: (2,1) (2,2) (2,3) (2,4) +/// +/// extents: (2,4) +/// +/// flatIdx: 0 1 2 3 4 5 6 7 +/// Indices: (1,1) (1,2) (1,3) (1,4) (2,1) (2,2) (2,3) (2,4) +static SmallVector<Value> convertFlatToMultiDim(OpBuilder &builder, + Location loc, Value flatIdx, + Value arrayBox) { + // Get array type and rank + auto boxType = cast<fir::BoxType>(arrayBox.getType()); + auto seqType = cast<fir::SequenceType>(boxType.getEleTy()); + int rank = seqType.getDimension(); + + // Get all extents + SmallVector<Value> extents; + // Get extents for each dimension + for (int i = 0; i < rank; ++i) { + auto dimIdx = arith::ConstantIndexOp::create(builder, loc, i); + auto boxDims = fir::BoxDimsOp::create(builder, loc, arrayBox, dimIdx); + extents.push_back(boxDims.getResult(1)); + } + + // Convert flat index to multi-dimensional indices + SmallVector<Value> indices(rank); + Value temp = flatIdx; + auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1); + + // Work backwards through dimensions (row-major order) + for (int i = rank - 1; i >= 0; --i) { + Value zeroBasedIdx = 
builder.create<arith::RemSIOp>(loc, temp, extents[i]); + // Convert to one-based index + indices[i] = builder.create<arith::AddIOp>(loc, zeroBasedIdx, c1); + if (i > 0) { + temp = builder.create<arith::DivSIOp>(loc, temp, extents[i]); + } + } + + return indices; +} + +/// Calculate the total number of elements in the array box +/// (totalElems = extent(1) * extent(2) * ... * extent(n)) +static Value CalculateTotalElements(OpBuilder &builder, Location loc, + Value arrayBox) { + auto boxType = cast<fir::BoxType>(arrayBox.getType()); + auto seqType = cast<fir::SequenceType>(boxType.getEleTy()); + int rank = seqType.getDimension(); + + Value totalElems = nullptr; + for (int i = 0; i < rank; ++i) { + auto dimIdx = arith::ConstantIndexOp::create(builder, loc, i); + auto boxDims = fir::BoxDimsOp::create(builder, loc, arrayBox, dimIdx); + Value extent = boxDims.getResult(1); + if (i == 0) { + totalElems = extent; + } else { + totalElems = builder.create<arith::MulIOp>(loc, totalElems, extent); + } + } + return totalElems; +} + +/// Replace the FortranAAssign runtime call with an unordered do loop +static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc, + omp::TeamsOp teamsOp, + omp::WorkdistributeOp workdistribute, + fir::CallOp callOp) { + auto destConvert = callOp.getOperand(0).getDefiningOp<fir::ConvertOp>(); + auto srcConvert = callOp.getOperand(1).getDefiningOp<fir::ConvertOp>(); + + Value destBox = destConvert.getValue(); + Value srcBox = srcConvert.getValue(); + + // get defining alloca op of destBox and srcBox + auto destAlloca = destBox.getDefiningOp<fir::AllocaOp>(); + + if (!destAlloca) { + emitError(loc, "Unimplemented: FortranAssign to OpenMP lowering\n"); + return; + } + + // get the store op that stores to the alloca + for (auto user : destAlloca->getUsers()) { + if (auto storeOp = dyn_cast<fir::StoreOp>(user)) { + destBox = storeOp.getValue(); + break; + } + } + + builder.setInsertionPoint(teamsOp); + // Load destination array box (if it's a reference) + Value arrayBox = destBox; + if (isa<fir::ReferenceType>(destBox.getType())) + arrayBox = builder.create<fir::LoadOp>(loc, destBox); + + auto scalarValue = builder.create<fir::BoxAddrOp>(loc, srcBox); + Value scalar = builder.create<fir::LoadOp>(loc, scalarValue); + + // Calculate total number of elements (flattened) + auto c0 = builder.create<arith::ConstantIndexOp>(loc, 0); + auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1); + Value totalElems = CalculateTotalElements(builder, loc, arrayBox); + + auto *workdistributeBlock = &workdistribute.getRegion().front(); + builder.setInsertionPointToStart(workdistributeBlock); + // Create single unordered loop for flattened array + auto doLoop = fir::DoLoopOp::create(builder, loc, c0, totalElems, c1, true); + Block *loopBlock = &doLoop.getRegion().front(); + builder.setInsertionPointToStart(doLoop.getBody()); + + auto flatIdx = loopBlock->getArgument(0); + SmallVector<Value> indices = + convertFlatToMultiDim(builder, loc, flatIdx, arrayBox); + // Use fir.array_coor for linear addressing + auto elemPtr = fir::ArrayCoorOp::create( + builder, loc, fir::ReferenceType::get(scalar.getType()), arrayBox, + nullptr, nullptr, ValueRange{indices}, ValueRange{}); + + builder.create<fir::StoreOp>(loc, scalar, elemPtr); +} + +/// workdistributeRuntimeCallLower method finds the runtime calls +/// nested in teams {workdistribute{}} and +/// lowers FortranAAssign to unordered do loop if src is scalar and dest is +/// array. Other runtime calls are not handled currently. 
+static FailureOr<bool> +workdistributeRuntimeCallLower(omp::WorkdistributeOp workdistribute, + SetVector<omp::TargetOp> &targetOpsToProcess) { + OpBuilder rewriter(workdistribute); + auto loc = workdistribute->getLoc(); + auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp()); + if (!teams) { + emitError(loc, "workdistribute not nested in teams\n"); + return failure(); + } + if (workdistribute.getRegion().getBlocks().size() != 1) { + emitError(loc, "workdistribute with multiple blocks\n"); + return failure(); + } + if (teams.getRegion().getBlocks().size() != 1) { + emitError(loc, "teams with multiple blocks\n"); + return failure(); + } + bool changed = false; + // Get the target op parent of teams + omp::TargetOp targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp()); + SmallVector<Operation *> opsToErase; + for (auto &op : workdistribute.getOps()) { + if (isRuntimeCall(&op)) { + rewriter.setInsertionPoint(&op); + fir::CallOp runtimeCall = cast<fir::CallOp>(op); + auto funcName = runtimeCall.getCallee()->getRootReference().getValue(); + if (funcName == FortranAssignStr) { + if (isFortranAssignSrcScalarAndDestArray(runtimeCall) && targetOp) { + // Record the target ops to process later + targetOpsToProcess.insert(targetOp); + replaceWithUnorderedDoLoop(rewriter, loc, teams, workdistribute, + runtimeCall); + opsToErase.push_back(&op); + changed = true; + } + } + } + } + // Erase the runtime calls that have been replaced. + for (auto *op : opsToErase) { + op->erase(); + } + return changed; +} + +/// teamsWorkdistributeToSingleOp method hoists all the ops inside +/// teams {workdistribute{}} before teams op. +/// +/// If A() and B () are present inside teams workdistribute +/// +/// omp.teams { +/// omp.workdistribute { +/// A() +/// B() +/// } +/// } +/// +/// Then, its lowered to +/// +/// A() +/// B() +/// +/// If only the terminator remains in teams after hoisting, we erase teams op. +static bool +teamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp, + SetVector<omp::TargetOp> &targetOpsToProcess) { + auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp); + if (!workdistributeOp) + return false; + // Get the block containing teamsOp (the parent block). + Block *parentBlock = teamsOp->getBlock(); + Block &workdistributeBlock = *workdistributeOp.getRegion().begin(); + // Record the target ops to process later + for (auto &op : workdistributeBlock.getOperations()) { + if (shouldParallelize(&op)) { + auto targetOp = dyn_cast<omp::TargetOp>(teamsOp->getParentOp()); + if (targetOp) { + targetOpsToProcess.insert(targetOp); + } + } + } + auto insertPoint = Block::iterator(teamsOp); + // Get the range of operations to move (excluding the terminator). + auto workdistributeBegin = workdistributeBlock.begin(); + auto workdistributeEnd = workdistributeBlock.getTerminator()->getIterator(); + // Move the operations from workdistribute block to before teamsOp. + parentBlock->getOperations().splice(insertPoint, + workdistributeBlock.getOperations(), + workdistributeBegin, workdistributeEnd); + // Erase the now-empty workdistributeOp. + workdistributeOp.erase(); + Block &teamsBlock = *teamsOp.getRegion().begin(); + // Check if only the terminator remains and erase teams op. 
+ if (teamsBlock.getOperations().size() == 1 && + teamsBlock.getTerminator() != nullptr) { + teamsOp.erase(); + } + return true; +} + +/// If multiple workdistribute are nested in a target regions, we will need to +/// split the target region, but we want to preserve the data semantics of the +/// original data region and avoid unnecessary data movement at each of the +/// subkernels - we split the target region into a target_data{target} +/// nest where only the outer one moves the data +FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, + RewriterBase &rewriter) { + auto loc = targetOp->getLoc(); + if (targetOp.getMapVars().empty()) { + emitError(loc, "Target region has no data maps\n"); + return failure(); + } + // Collect all the mapinfo ops + SmallVector<omp::MapInfoOp> mapInfos; + for (auto opr : targetOp.getMapVars()) { + auto mapInfo = cast<omp::MapInfoOp>(opr.getDefiningOp()); + mapInfos.push_back(mapInfo); + } + + rewriter.setInsertionPoint(targetOp); + SmallVector<Value> innerMapInfos; + SmallVector<Value> outerMapInfos; + // Create new mapinfo ops for the inner target region + for (auto mapInfo : mapInfos) { + auto originalMapType = + (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType()); + auto originalCaptureType = mapInfo.getMapCaptureType(); + llvm::omp::OpenMPOffloadMappingFlags newMapType; + mlir::omp::VariableCaptureKind newCaptureType; + // For bycopy, we keep the same map type and capture type + // For byref, we change the map type to none and keep the capture type + if (originalCaptureType == mlir::omp::VariableCaptureKind::ByCopy) { + newMapType = originalMapType; + newCaptureType = originalCaptureType; + } else if (originalCaptureType == mlir::omp::VariableCaptureKind::ByRef) { + newMapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + newCaptureType = originalCaptureType; + outerMapInfos.push_back(mapInfo); + } else { + emitError(targetOp->getLoc(), "Unhandled case"); + return failure(); + } + auto innerMapInfo = cast<omp::MapInfoOp>(rewriter.clone(*mapInfo)); + innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr( + rewriter.getIntegerType(64, false), + static_cast< + std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( + newMapType))); + innerMapInfo.setMapCaptureType(newCaptureType); + innerMapInfos.push_back(innerMapInfo.getResult()); + } + + rewriter.setInsertionPoint(targetOp); + auto device = targetOp.getDevice(); + auto ifExpr = targetOp.getIfExpr(); + auto deviceAddrVars = targetOp.getHasDeviceAddrVars(); + auto devicePtrVars = targetOp.getIsDevicePtrVars(); + // Create the target data op + auto targetDataOp = rewriter.create<omp::TargetDataOp>( + loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars); + auto taregtDataBlock = rewriter.createBlock(&targetDataOp.getRegion()); + rewriter.create<mlir::omp::TerminatorOp>(loc); + rewriter.setInsertionPointToStart(taregtDataBlock); + // Create the inner target op + auto newTargetOp = rewriter.create<omp::TargetOp>( + targetOp.getLoc(), targetOp.getAllocateVars(), + targetOp.getAllocatorVars(), targetOp.getBareAttr(), + targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), + targetOp.getHostEvalVars(), targetOp.getIfExpr(), + targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(), + targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(), + innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(), + targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(), + 
targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr()); + rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(), + newTargetOp.getRegion().begin()); + rewriter.replaceOp(targetOp, targetDataOp); + return newTargetOp; +} + +/// getNestedOpToIsolate function is designed to identify a specific teams +/// parallel op within the body of an omp::TargetOp that should be "isolated." +/// This returns a tuple of op, if its first op in targetBlock, or if the op is +/// last op in the traget block. +static std::optional<std::tuple<Operation *, bool, bool>> +getNestedOpToIsolate(omp::TargetOp targetOp) { + if (targetOp.getRegion().empty()) + return std::nullopt; + auto *targetBlock = &targetOp.getRegion().front(); + for (auto &op : *targetBlock) { + bool first = &op == &*targetBlock->begin(); + bool last = op.getNextNode() == targetBlock->getTerminator(); + if (first && last) + return std::nullopt; + + if (isa<omp::TeamsOp>(&op)) + return {{&op, first, last}}; + } + return std::nullopt; +} + +/// Temporary structure to hold the two mapinfo ops +struct TempOmpVar { + omp::MapInfoOp from, to; +}; + +/// isPtr checks if the type is a pointer or reference type. +static bool isPtr(Type ty) { + return isa<fir::ReferenceType>(ty) || isa<LLVM::LLVMPointerType>(ty); +} + +/// getPtrTypeForOmp returns an LLVM pointer type for the given type. +static Type getPtrTypeForOmp(Type ty) { + if (isPtr(ty)) + return LLVM::LLVMPointerType::get(ty.getContext()); + else + return fir::ReferenceType::get(ty); +} + +/// allocateTempOmpVar allocates a temporary variable for OpenMP mapping +static TempOmpVar allocateTempOmpVar(Location loc, Type ty, + RewriterBase &rewriter) { + MLIRContext &ctx = *ty.getContext(); + Value alloc; + Type allocType; + auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx); + // Get the appropriate type for allocation + if (isPtr(ty)) { + Type intTy = rewriter.getI32Type(); + auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1); + allocType = llvmPtrTy; + alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one); + allocType = intTy; + } else { + allocType = ty; + alloc = rewriter.create<fir::AllocaOp>(loc, allocType); + } + // Lambda to create mapinfo ops + auto getMapInfo = [&](uint64_t mappingFlags, const char *name) { + return rewriter.create<omp::MapInfoOp>( + loc, alloc.getType(), alloc, TypeAttr::get(allocType), + rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false), + mappingFlags), + rewriter.getAttr<omp::VariableCaptureKindAttr>( + omp::VariableCaptureKind::ByRef), + /*varPtrPtr=*/Value{}, + /*members=*/SmallVector<Value>{}, + /*member_index=*/mlir::ArrayAttr{}, + /*bounds=*/ValueRange(), + /*mapperId=*/mlir::FlatSymbolRefAttr(), + /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false)); + }; + // Create mapinfo ops. + uint64_t mapFrom = + static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); + uint64_t mapTo = + static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); + auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from"); + auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to"); + return TempOmpVar{mapInfoFrom, mapInfoTo}; +} + +// usedOutsideSplit checks if a value is used outside the split operation. 
+static bool usedOutsideSplit(Value v, Operation *split) { + if (!split) + return false; + auto targetOp = cast<omp::TargetOp>(split->getParentOp()); + auto *targetBlock = &targetOp.getRegion().front(); + for (auto *user : v.getUsers()) { + while (user->getBlock() != targetBlock) { + user = user->getParentOp(); + } + if (!user->isBeforeInBlock(split)) + return true; + } + return false; +} + +/// isRecomputableAfterFission checks if an operation can be recomputed +static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) { + // If the op has side effects, it cannot be recomputed. + // We consider fir.declare as having no side effects. + return isa<fir::DeclareOp>(op) || isMemoryEffectFree(op); +} + +/// collectNonRecomputableDeps collects dependencies that cannot be recomputed +static void collectNonRecomputableDeps(Value &v, omp::TargetOp targetOp, + SetVector<Operation *> &nonRecomputable, + SetVector<Operation *> &toCache, + SetVector<Operation *> &toRecompute) { + Operation *op = v.getDefiningOp(); + // If v is a block argument, it must be from the targetOp. + if (!op) { + assert(cast<BlockArgument>(v).getOwner()->getParentOp() == targetOp); + return; + } + // If the op is in the nonRecomputable set, add it to toCache and return. + if (nonRecomputable.contains(op)) { + toCache.insert(op); + return; + } + // Add the op to toRecompute. + toRecompute.insert(op); + for (auto opr : op->getOperands()) + collectNonRecomputableDeps(opr, targetOp, nonRecomputable, toCache, + toRecompute); +} + +/// createBlockArgsAndMap creates block arguments and maps them +static void createBlockArgsAndMap(Location loc, RewriterBase &rewriter, + omp::TargetOp &targetOp, Block *targetBlock, + Block *newTargetBlock, + SmallVector<Value> &hostEvalVars, + SmallVector<Value> &mapOperands, + SmallVector<Value> &allocs, + IRMapping &irMapping) { + // FIRST: Map `host_eval_vars` to block arguments + unsigned originalHostEvalVarsSize = targetOp.getHostEvalVars().size(); + for (unsigned i = 0; i < hostEvalVars.size(); ++i) { + Value originalValue; + BlockArgument newArg; + if (i < originalHostEvalVarsSize) { + originalValue = targetBlock->getArgument(i); // Host_eval args come first + newArg = newTargetBlock->addArgument(originalValue.getType(), + originalValue.getLoc()); + } else { + originalValue = hostEvalVars[i]; + newArg = newTargetBlock->addArgument(originalValue.getType(), + originalValue.getLoc()); + } + irMapping.map(originalValue, newArg); + } + + // SECOND: Map `map_operands` to block arguments + unsigned originalMapVarsSize = targetOp.getMapVars().size(); + for (unsigned i = 0; i < mapOperands.size(); ++i) { + Value originalValue; + BlockArgument newArg; + // Map the new arguments from the original block. + if (i < originalMapVarsSize) { + originalValue = targetBlock->getArgument(originalHostEvalVarsSize + + i); // Offset by host_eval count + newArg = newTargetBlock->addArgument(originalValue.getType(), + originalValue.getLoc()); + } + // Map the new arguments from the `allocs`. 
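+    // The trailing map operands are the temporaries created by
+    // allocateTempOmpVar for values cached across the split; their block
+    // arguments are typed as pointers to the cached values.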
+ else { + originalValue = allocs[i - originalMapVarsSize]; + newArg = newTargetBlock->addArgument( + getPtrTypeForOmp(originalValue.getType()), originalValue.getLoc()); + } + irMapping.map(originalValue, newArg); + } + + // THIRD: Map `private_vars` to block arguments (if any) + unsigned originalPrivateVarsSize = targetOp.getPrivateVars().size(); + for (unsigned i = 0; i < originalPrivateVarsSize; ++i) { + auto originalArg = targetBlock->getArgument(originalHostEvalVarsSize + + originalMapVarsSize + i); + auto newArg = newTargetBlock->addArgument(originalArg.getType(), + originalArg.getLoc()); + irMapping.map(originalArg, newArg); + } + return; +} + +/// reloadCacheAndRecompute reloads cached values and recomputes operations +static void reloadCacheAndRecompute( + Location loc, RewriterBase &rewriter, Operation *splitBefore, + omp::TargetOp &targetOp, Block *targetBlock, Block *newTargetBlock, + SmallVector<Value> &hostEvalVars, SmallVector<Value> &mapOperands, + SmallVector<Value> &allocs, SetVector<Operation *> &toRecompute, + IRMapping &irMapping) { + // Handle the load operations for the allocs. + rewriter.setInsertionPointToStart(newTargetBlock); + auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext()); + + unsigned originalMapVarsSize = targetOp.getMapVars().size(); + unsigned hostEvalVarsSize = hostEvalVars.size(); + // Create load operations for each allocated variable. + for (unsigned i = 0; i < allocs.size(); ++i) { + Value original = allocs[i]; + // Get the new block argument for this specific allocated value. + Value newArg = + newTargetBlock->getArgument(hostEvalVarsSize + originalMapVarsSize + i); + Value restored; + // If the original value is a pointer or reference, load and convert if + // necessary. + if (isPtr(original.getType())) { + restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg); + if (!isa<LLVM::LLVMPointerType>(original.getType())) + restored = + rewriter.create<fir::ConvertOp>(loc, original.getType(), restored); + } else { + restored = rewriter.create<fir::LoadOp>(loc, newArg); + } + irMapping.map(original, restored); + } + // Clone the operations if they are in the toRecompute set. + for (auto it = targetBlock->begin(); it != splitBefore->getIterator(); it++) { + if (toRecompute.contains(&*it)) + rewriter.clone(*it, irMapping); + } +} + +/// Given a teamsOp, navigate down the nested structure to find the +/// innermost LoopNestOp. The expected nesting is: +/// teams -> parallel -> distribute -> wsloop -> loop_nest +static mlir::omp::LoopNestOp getLoopNestFromTeams(mlir::omp::TeamsOp teamsOp) { + if (teamsOp.getRegion().empty()) + return nullptr; + // Ensure the teams region has a single block. 
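+  // (Multiple blocks would imply control flow that this structural walk does
+  // not attempt to analyze.)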
+ if (teamsOp.getRegion().getBlocks().size() != 1) + return nullptr; + // Find parallel op inside teams + mlir::omp::ParallelOp parallelOp = nullptr; + // Look for the parallel op in the teams region + for (auto &op : teamsOp.getRegion().front()) { + if (auto parallel = dyn_cast<mlir::omp::ParallelOp>(op)) { + parallelOp = parallel; + break; + } + } + if (!parallelOp) + return nullptr; + + // Find distribute op inside parallel + mlir::omp::DistributeOp distributeOp = nullptr; + for (auto &op : parallelOp.getRegion().front()) { + if (auto distribute = dyn_cast<mlir::omp::DistributeOp>(op)) { + distributeOp = distribute; + break; + } + } + if (!distributeOp) + return nullptr; + + // Find wsloop op inside distribute + mlir::omp::WsloopOp wsloopOp = nullptr; + for (auto &op : distributeOp.getRegion().front()) { + if (auto wsloop = dyn_cast<mlir::omp::WsloopOp>(op)) { + wsloopOp = wsloop; + break; + } + } + if (!wsloopOp) + return nullptr; + + // Find loop_nest op inside wsloop + for (auto &op : wsloopOp.getRegion().front()) { + if (auto loopNest = dyn_cast<mlir::omp::LoopNestOp>(op)) { + return loopNest; + } + } + + return nullptr; +} + +/// Generate LLVM constant operations for i32 and i64 types. +static mlir::LLVM::ConstantOp +genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) { + mlir::Type i32Ty = rewriter.getI32Type(); + mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value); + return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr); +} + +/// Given a box descriptor, extract the base address of the data it describes. +/// If the box descriptor is a reference, load it first. +/// The base address is returned as an i8* pointer. +static Value genDescriptorGetBaseAddress(fir::FirOpBuilder &builder, + Location loc, Value boxDesc) { + Value box = boxDesc; + if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) { + box = fir::LoadOp::create(builder, loc, boxDesc); + } + assert(isa<fir::BoxType>(box.getType()) && + "Unknown type passed to genDescriptorGetBaseAddress"); + auto i8Type = builder.getI8Type(); + auto unknownArrayType = + fir::SequenceType::get({fir::SequenceType::getUnknownExtent()}, i8Type); + auto i8BoxType = fir::BoxType::get(unknownArrayType); + auto typedBox = fir::ConvertOp::create(builder, loc, i8BoxType, box); + auto rawAddr = fir::BoxAddrOp::create(builder, loc, typedBox); + return rawAddr; +} + +/// Given a box descriptor, extract the total number of elements in the array it +/// describes. If the box descriptor is a reference, load it first. +/// The total number of elements is returned as an i64 value. +static Value genDescriptorGetTotalElements(fir::FirOpBuilder &builder, + Location loc, Value boxDesc) { + Value box = boxDesc; + if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) { + box = fir::LoadOp::create(builder, loc, boxDesc); + } + assert(isa<fir::BoxType>(box.getType()) && + "Unknown type passed to genDescriptorGetTotalElements"); + auto i64Type = builder.getI64Type(); + return fir::BoxTotalElementsOp::create(builder, loc, i64Type, box); +} + +/// Given a box descriptor, extract the size of each element in the array it +/// describes. If the box descriptor is a reference, load it first. +/// The element size is returned as an i64 value. 
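+/// For example, a descriptor of a default real array typically yields 4, and
+/// one of a double precision array yields 8.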
+static Value genDescriptorGetEleSize(fir::FirOpBuilder &builder, Location loc,
+                                     Value boxDesc) {
+  Value box = boxDesc;
+  if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+    box = fir::LoadOp::create(builder, loc, boxDesc);
+  }
+  assert(isa<fir::BoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetEleSize");
+  auto i64Type = builder.getI64Type();
+  return fir::BoxEleSizeOp::create(builder, loc, i64Type, box);
+}
+
+/// Given a box descriptor, compute the total size in bytes of the data it
+/// describes. This is done by multiplying the total number of elements by the
+/// size of each element. If the box descriptor is a reference, load it first.
+/// The total size in bytes is returned as an i64 value.
+static Value genDescriptorGetDataSizeInBytes(fir::FirOpBuilder &builder,
+                                             Location loc, Value boxDesc) {
+  Value box = boxDesc;
+  if (auto refBox = dyn_cast<fir::ReferenceType>(boxDesc.getType())) {
+    box = fir::LoadOp::create(builder, loc, boxDesc);
+  }
+  assert(isa<fir::BoxType>(box.getType()) &&
+         "Unknown type passed to genDescriptorGetDataSizeInBytes");
+  Value eleSize = genDescriptorGetEleSize(builder, loc, box);
+  Value totalElements = genDescriptorGetTotalElements(builder, loc, box);
+  return mlir::arith::MulIOp::create(builder, loc, totalElements, eleSize);
+}
+
+/// Generate a call to the OpenMP runtime function `omp_get_mapped_ptr` to
+/// retrieve the device pointer corresponding to a given host pointer and device
+/// number. If no mapping exists, the original host pointer is returned.
+/// Signature:
+/// void *omp_get_mapped_ptr(void *host_ptr, int device_num);
+static mlir::Value genOmpGetMappedPtrIfPresent(fir::FirOpBuilder &builder,
+                                               mlir::Location loc,
+                                               mlir::Value hostPtr,
+                                               mlir::Value deviceNum,
+                                               mlir::ModuleOp module) {
+  auto *context = builder.getContext();
+  auto voidPtrType = fir::LLVMPointerType::get(context, builder.getI8Type());
+  auto i32Type = builder.getI32Type();
+  auto funcName = "omp_get_mapped_ptr";
+  auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(funcName);
+
+  if (!funcOp) {
+    auto funcType =
+        mlir::FunctionType::get(context, {voidPtrType, i32Type}, {voidPtrType});
+
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(module.getBody());
+
+    funcOp = mlir::func::FuncOp::create(builder, loc, funcName, funcType);
+    funcOp.setPrivate();
+  }
+
+  llvm::SmallVector<mlir::Value> args;
+  args.push_back(fir::ConvertOp::create(builder, loc, voidPtrType, hostPtr));
+  args.push_back(fir::ConvertOp::create(builder, loc, i32Type, deviceNum));
+  auto callOp = fir::CallOp::create(builder, loc, funcOp, args);
+  auto mappedPtr = callOp.getResult(0);
+  auto isNull = builder.genIsNullAddr(loc, mappedPtr);
+  auto convertedHostPtr =
+      fir::ConvertOp::create(builder, loc, voidPtrType, hostPtr);
+  auto result = arith::SelectOp::create(builder, loc, isNull, convertedHostPtr,
+                                        mappedPtr);
+  return result;
+}
+
+/// Generate a call to the OpenMP runtime function `omp_target_memcpy` to
+/// perform memory copy between host and device or between devices.
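+/// In this pass the same device number is passed for both the source and the
+/// destination, and the offsets are always zero.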
+/// Signature: +/// int omp_target_memcpy(void *dst, const void *src, size_t length, +/// size_t dst_offset, size_t src_offset, +/// int dst_device, int src_device); +static void genOmpTargetMemcpyCall(fir::FirOpBuilder &builder, + mlir::Location loc, mlir::Value dst, + mlir::Value src, mlir::Value length, + mlir::Value dstOffset, mlir::Value srcOffset, + mlir::Value device, mlir::ModuleOp module) { + auto *context = builder.getContext(); + auto funcName = "omp_target_memcpy"; + auto voidPtrType = fir::LLVMPointerType::get(context, builder.getI8Type()); + auto sizeTType = builder.getI64Type(); // assuming size_t is 64-bit + auto i32Type = builder.getI32Type(); + auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(funcName); + + if (!funcOp) { + mlir::OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + llvm::SmallVector<mlir::Type> argTypes = { + voidPtrType, voidPtrType, sizeTType, sizeTType, + sizeTType, i32Type, i32Type}; + auto funcType = mlir::FunctionType::get(context, argTypes, {i32Type}); + funcOp = mlir::func::FuncOp::create(builder, loc, funcName, funcType); + funcOp.setPrivate(); + } + + llvm::SmallVector<mlir::Value> args{dst, src, length, dstOffset, + srcOffset, device, device}; + fir::CallOp::create(builder, loc, funcOp, args); + return; +} + +/// Generate code to replace a Fortran array assignment call with OpenMP +/// runtime calls to perform the equivalent operation on the device. +/// This involves extracting the source and destination pointers from the +/// Fortran array descriptors, retrieving their mapped device pointers (if any), +/// and invoking `omp_target_memcpy` to copy the data on the device. +static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder, + mlir::Location loc, + fir::CallOp callOp, + mlir::Value device, + mlir::ModuleOp module) { + assert(callOp.getNumResults() == 0 && + "Expected _FortranAAssign to have no results"); + assert(callOp.getNumOperands() >= 2 && + "Expected _FortranAAssign to have at least two operands"); + + // Extract the source and destination pointers from the call operands. + mlir::Value dest = callOp.getOperand(0); + mlir::Value src = callOp.getOperand(1); + + // Get the base addresses of the source and destination arrays. + mlir::Value srcBase = genDescriptorGetBaseAddress(builder, loc, src); + mlir::Value destBase = genDescriptorGetBaseAddress(builder, loc, dest); + + // Get the total size in bytes of the data to be copied. + mlir::Value srcDataSize = genDescriptorGetDataSizeInBytes(builder, loc, src); + + // Retrieve the mapped device pointers for source and destination. + // If no mapping exists, the original host pointer is used. + Value destPtr = + genOmpGetMappedPtrIfPresent(builder, loc, destBase, device, module); + Value srcPtr = + genOmpGetMappedPtrIfPresent(builder, loc, srcBase, device, module); + Value zero = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(), + builder.getI64IntegerAttr(0)); + + // Generate the call to omp_target_memcpy to perform the data copy on the + // device. + genOmpTargetMemcpyCall(builder, loc, destPtr, srcPtr, srcDataSize, zero, zero, + device, module); +} + +/// Struct to hold the host eval vars corresponding to loop bounds and steps +struct HostEvalVars { + SmallVector<Value> lbs; + SmallVector<Value> ubs; + SmallVector<Value> steps; +}; + +/// moveToHost method clones all the ops from target region outside of it. +/// It hoists runtime function "_FortranAAssign" and replaces it with omp +/// version. 
It also hoists and replaces fir.allocmem with omp.target_allocmem and
+/// fir.freemem with omp.target_freemem.
+static LogicalResult moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter,
+                                mlir::ModuleOp module,
+                                struct HostEvalVars &hostEvalVars) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  Block *targetBlock = &targetOp.getRegion().front();
+  assert(targetBlock == &targetOp.getRegion().back());
+  IRMapping mapping;
+
+  // Get the parent target_data op
+  auto targetDataOp = dyn_cast<omp::TargetDataOp>(targetOp->getParentOp());
+  if (!targetDataOp) {
+    emitError(targetOp->getLoc(),
+              "Expected target op to be inside target_data op");
+    return failure();
+  }
+  // create mapping for host_eval_vars
+  unsigned hostEvalVarCount = targetOp.getHostEvalVars().size();
+  for (unsigned i = 0; i < targetOp.getHostEvalVars().size(); ++i) {
+    Value hostEvalVar = targetOp.getHostEvalVars()[i];
+    BlockArgument arg = targetBlock->getArguments()[i];
+    mapping.map(arg, hostEvalVar);
+  }
+  // create mapping for map_vars
+  for (unsigned i = 0; i < targetOp.getMapVars().size(); ++i) {
+    Value mapInfo = targetOp.getMapVars()[i];
+    BlockArgument arg = targetBlock->getArguments()[hostEvalVarCount + i];
+    Operation *op = mapInfo.getDefiningOp();
+    assert(op);
+    auto mapInfoOp = cast<omp::MapInfoOp>(op);
+    // map the block argument to the host-side variable pointer
+    mapping.map(arg, mapInfoOp.getVarPtr());
+  }
+  // create mapping for private_vars
+  unsigned mapSize = targetOp.getMapVars().size();
+  for (unsigned i = 0; i < targetOp.getPrivateVars().size(); ++i) {
+    Value privateVar = targetOp.getPrivateVars()[i];
+    // The mapping should link the device-side variable to the host-side one.
+    BlockArgument arg =
+        targetBlock->getArguments()[hostEvalVarCount + mapSize + i];
+    // Map the device-side copy (`arg`) to the host-side value (`privateVar`).
+    mapping.map(arg, privateVar);
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  SmallVector<Operation *> opsToReplace;
+  Value device = targetOp.getDevice();
+
+  // If device is not specified, default to device 0.
+  if (!device) {
+    device = genI32Constant(targetOp.getLoc(), rewriter, 0);
+  }
+  // Clone all operations.
+  for (auto it = targetBlock->begin(), end = std::prev(targetBlock->end());
+       it != end; ++it) {
+    auto *op = &*it;
+    Operation *clonedOp = rewriter.clone(*op, mapping);
+    // Map the results of the original op to the cloned op.
+    for (unsigned i = 0; i < op->getNumResults(); ++i) {
+      mapping.map(op->getResult(i), clonedOp->getResult(i));
+    }
+    // fir.declare changes its type when hoisted out of omp.target into
+    // omp.target_data. Introduce a load if the original declareOp input is not
+    // of reference type but the cloned declareOp input is of reference type.
+    if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(clonedOp)) {
+      auto originalDeclareOp = cast<fir::DeclareOp>(op);
+      Type originalInType = originalDeclareOp.getMemref().getType();
+      Type clonedInType = clonedDeclareOp.getMemref().getType();
+
+      fir::ReferenceType originalRefType =
+          dyn_cast<fir::ReferenceType>(originalInType);
+      fir::ReferenceType clonedRefType =
+          dyn_cast<fir::ReferenceType>(clonedInType);
+      if (!originalRefType && clonedRefType) {
+        Type clonedEleTy = clonedRefType.getElementType();
+        if (clonedEleTy == originalDeclareOp.getType()) {
+          opsToReplace.push_back(clonedOp);
+        }
+      }
+    }
+    // Collect the ops to be replaced.
+    if (isa<fir::AllocMemOp>(clonedOp) || isa<fir::FreeMemOp>(clonedOp))
+      opsToReplace.push_back(clonedOp);
+    // Check for runtime calls to be replaced.
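+    // Only the _FortranAAssign runtime call is currently supported; any other
+    // runtime call aborts the hoisting with an error.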
+    if (isRuntimeCall(clonedOp)) {
+      fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+      auto funcName = runtimeCall.getCallee()->getRootReference().getValue();
+      if (funcName == FortranAssignStr) {
+        opsToReplace.push_back(clonedOp);
+      } else {
+        emitError(runtimeCall->getLoc(), "Unhandled runtime call hoisting.");
+        return failure();
+      }
+    }
+  }
+  // Replace fir.allocmem with omp.target_allocmem.
+  for (Operation *op : opsToReplace) {
+    if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) {
+      rewriter.setInsertionPoint(allocOp);
+      auto ompAllocmemOp = rewriter.create<omp::TargetAllocMemOp>(
+          allocOp.getLoc(), rewriter.getI64Type(), device,
+          allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(),
+          allocOp.getBindcNameAttr(), allocOp.getTypeparams(),
+          allocOp.getShape());
+      auto firConvertOp = rewriter.create<fir::ConvertOp>(
+          allocOp.getLoc(), allocOp.getResult().getType(),
+          ompAllocmemOp.getResult());
+      rewriter.replaceOp(allocOp, firConvertOp.getResult());
+    }
+    // Replace fir.freemem with omp.target_freemem.
+    else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) {
+      rewriter.setInsertionPoint(freeOp);
+      auto firConvertOp = rewriter.create<fir::ConvertOp>(
+          freeOp.getLoc(), rewriter.getI64Type(), freeOp.getHeapref());
+      rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device,
+                                            firConvertOp.getResult());
+      rewriter.eraseOp(freeOp);
+    }
+    // fir.declare changes its type when hoisted out of omp.target into
+    // omp.target_data. Introduce a load if the original declareOp input is not
+    // of reference type but the cloned declareOp input is of reference type.
+    else if (fir::DeclareOp clonedDeclareOp = dyn_cast<fir::DeclareOp>(op)) {
+      Type clonedInType = clonedDeclareOp.getMemref().getType();
+      fir::ReferenceType clonedRefType =
+          dyn_cast<fir::ReferenceType>(clonedInType);
+      Type clonedEleTy = clonedRefType.getElementType();
+      rewriter.setInsertionPoint(op);
+      Value loadedValue = rewriter.create<fir::LoadOp>(
+          clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref());
+      clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue);
+    }
+    // Replace runtime calls with omp versions.
+    else if (isRuntimeCall(op)) {
+      fir::CallOp runtimeCall = cast<fir::CallOp>(op);
+      auto funcName = runtimeCall.getCallee()->getRootReference().getValue();
+      if (funcName == FortranAssignStr) {
+        rewriter.setInsertionPoint(op);
+        fir::FirOpBuilder builder{rewriter, op};
+
+        mlir::Location loc = runtimeCall.getLoc();
+        genFortranAssignOmpReplacement(builder, loc, runtimeCall, device,
+                                       module);
+        rewriter.eraseOp(op);
+      } else {
+        emitError(runtimeCall->getLoc(), "Unhandled runtime call hoisting.");
+        return failure();
+      }
+    } else {
+      emitError(op->getLoc(), "Unhandled op hoisting.");
+      return failure();
+    }
+  }
+
+  // Update the host_eval_vars to use the mapped values.
+  for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i) {
+    hostEvalVars.lbs[i] = mapping.lookup(hostEvalVars.lbs[i]);
+    hostEvalVars.ubs[i] = mapping.lookup(hostEvalVars.ubs[i]);
+    hostEvalVars.steps[i] = mapping.lookup(hostEvalVars.steps[i]);
+  }
+  // Finally erase the original targetOp.
+  rewriter.eraseOp(targetOp);
+  return success();
+}
+
+/// Result of isolateOp method
+struct SplitResult {
+  omp::TargetOp preTargetOp;
+  omp::TargetOp isolatedTargetOp;
+  omp::TargetOp postTargetOp;
+};
+
+/// computeAllocsCacheRecomputable method computes the allocs needed to cache
+/// the values that are used outside the split point. It also computes the ops
+/// that need to be cached and the ops that can be recomputed after the split.
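+/// For example, a fir.allocmem whose result is live across the split has side
+/// effects and cannot be recomputed, so it is cached through a mapped
+/// temporary; a memory-effect-free op such as fir.declare is simply
+/// recomputed in the later region.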
+static void computeAllocsCacheRecomputable( + omp::TargetOp targetOp, Operation *splitBeforeOp, RewriterBase &rewriter, + SmallVector<Value> &preMapOperands, SmallVector<Value> &postMapOperands, + SmallVector<Value> &allocs, SmallVector<Value> &requiredVals, + SetVector<Operation *> &nonRecomputable, SetVector<Operation *> &toCache, + SetVector<Operation *> &toRecompute) { + auto *targetBlock = &targetOp.getRegion().front(); + // Find all values that are used outside the split point. + for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); + it++) { + // Check if any of the results are used outside the split point. + for (auto res : it->getResults()) { + if (usedOutsideSplit(res, splitBeforeOp)) { + requiredVals.push_back(res); + } + } + // If the op is not recomputable, add it to the nonRecomputable set. + if (!isRecomputableAfterFission(&*it, splitBeforeOp)) { + nonRecomputable.insert(&*it); + } + } + // For each required value, collect its dependencies. + for (auto requiredVal : requiredVals) + collectNonRecomputableDeps(requiredVal, targetOp, nonRecomputable, toCache, + toRecompute); + // For each op in toCache, create an alloc and update the pre and post map + // operands. + for (Operation *op : toCache) { + for (auto res : op->getResults()) { + auto alloc = + allocateTempOmpVar(targetOp.getLoc(), res.getType(), rewriter); + allocs.push_back(res); + preMapOperands.push_back(alloc.from); + postMapOperands.push_back(alloc.to); + } + } +} + +/// genPreTargetOp method generates the preTargetOp that contains all the ops +/// before the split point. It also creates the block arguments and maps the +/// values accordingly. It also creates the store operations for the allocs. +static omp::TargetOp +genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands, + SmallVector<Value> &allocs, Operation *splitBeforeOp, + RewriterBase &rewriter, struct HostEvalVars &hostEvalVars, + bool isTargetDevice) { + auto loc = targetOp.getLoc(); + auto *targetBlock = &targetOp.getRegion().front(); + SmallVector<Value> preHostEvalVars{targetOp.getHostEvalVars()}; + // update the hostEvalVars of preTargetOp + omp::TargetOp preTargetOp = rewriter.create<omp::TargetOp>( + targetOp.getLoc(), targetOp.getAllocateVars(), + targetOp.getAllocatorVars(), targetOp.getBareAttr(), + targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), preHostEvalVars, + targetOp.getIfExpr(), targetOp.getInReductionVars(), + targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), + targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(), + targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), + targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(), + targetOp.getPrivateMapsAttr()); + auto *preTargetBlock = rewriter.createBlock( + &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {}); + IRMapping preMapping; + // Create block arguments and map the values. + createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, preTargetBlock, + preHostEvalVars, preMapOperands, allocs, preMapping); + + // Handle the store operations for the allocs. + rewriter.setInsertionPointToStart(preTargetBlock); + auto llvmPtrTy = LLVM::LLVMPointerType::get(targetOp.getContext()); + + // Clone the original operations. 
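+  // Everything before the split point is cloned into the pre-target region;
+  // the stores below then spill the values in `allocs` into the mapped
+  // temporaries so that the later regions can reload them.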
+ for (auto it = targetBlock->begin(); it != splitBeforeOp->getIterator(); + it++) { + rewriter.clone(*it, preMapping); + } + + unsigned originalHostEvalVarsSize = preHostEvalVars.size(); + unsigned originalMapVarsSize = targetOp.getMapVars().size(); + // Create Stores for allocs. + for (unsigned i = 0; i < allocs.size(); ++i) { + Value originalResult = allocs[i]; + Value toStore = preMapping.lookup(originalResult); + // Get the new block argument for this specific allocated value. + Value newArg = preTargetBlock->getArgument(originalHostEvalVarsSize + + originalMapVarsSize + i); + // Create the store operation. + if (isPtr(originalResult.getType())) { + if (!isa<LLVM::LLVMPointerType>(toStore.getType())) + toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore); + rewriter.create<LLVM::StoreOp>(loc, toStore, newArg); + } else { + rewriter.create<fir::StoreOp>(loc, toStore, newArg); + } + } + rewriter.create<omp::TerminatorOp>(loc); + + // Update hostEvalVars with the mapped values for the loop bounds if we have + // a loopNestOp and we are not generating code for the target device. + omp::LoopNestOp loopNestOp = + getLoopNestFromTeams(cast<omp::TeamsOp>(splitBeforeOp)); + if (loopNestOp && !isTargetDevice) { + for (size_t i = 0; i < loopNestOp.getLoopLowerBounds().size(); ++i) { + Value lb = loopNestOp.getLoopLowerBounds()[i]; + Value ub = loopNestOp.getLoopUpperBounds()[i]; + Value step = loopNestOp.getLoopSteps()[i]; + + hostEvalVars.lbs.push_back(preMapping.lookup(lb)); + hostEvalVars.ubs.push_back(preMapping.lookup(ub)); + hostEvalVars.steps.push_back(preMapping.lookup(step)); + } + } + + return preTargetOp; +} + +/// genIsolatedTargetOp method generates the isolatedTargetOp that contains the +/// ops between the split point. It also creates the block arguments and maps +/// the values accordingly. It also creates the load operations for the allocs +/// and recomputes the necessary ops. 
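+/// When not compiling for the target device, the loop bounds collected in
+/// hostEvalVars are appended to the host_eval operands so that the host can
+/// evaluate the trip count of the isolated kernel.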
+static omp::TargetOp +genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands, + Operation *splitBeforeOp, RewriterBase &rewriter, + SmallVector<Value> &allocs, + SetVector<Operation *> &toRecompute, + struct HostEvalVars &hostEvalVars, bool isTargetDevice) { + auto loc = targetOp.getLoc(); + auto *targetBlock = &targetOp.getRegion().front(); + SmallVector<Value> isolatedHostEvalVars{targetOp.getHostEvalVars()}; + // update the hostEvalVars of isolatedTargetOp + if (!hostEvalVars.lbs.empty() && !isTargetDevice) { + isolatedHostEvalVars.append(hostEvalVars.lbs.begin(), + hostEvalVars.lbs.end()); + isolatedHostEvalVars.append(hostEvalVars.ubs.begin(), + hostEvalVars.ubs.end()); + isolatedHostEvalVars.append(hostEvalVars.steps.begin(), + hostEvalVars.steps.end()); + } + // Create the isolated target op + omp::TargetOp isolatedTargetOp = rewriter.create<omp::TargetOp>( + targetOp.getLoc(), targetOp.getAllocateVars(), + targetOp.getAllocatorVars(), targetOp.getBareAttr(), + targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), + isolatedHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(), + targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), + targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(), + targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), + targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(), + targetOp.getPrivateMapsAttr()); + auto *isolatedTargetBlock = + rewriter.createBlock(&isolatedTargetOp.getRegion(), + isolatedTargetOp.getRegion().begin(), {}, {}); + IRMapping isolatedMapping; + // Create block arguments and map the values. + createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, + isolatedTargetBlock, isolatedHostEvalVars, + postMapOperands, allocs, isolatedMapping); + // Handle the load operations for the allocs and recompute ops. + reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock, + isolatedTargetBlock, isolatedHostEvalVars, + postMapOperands, allocs, toRecompute, + isolatedMapping); + + // Clone the original operations. + rewriter.clone(*splitBeforeOp, isolatedMapping); + rewriter.create<omp::TerminatorOp>(loc); + + // update the loop bounds in the isolatedTargetOp if we have host_eval vars + // and we are not generating code for the target device. 
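+  // The bounds were computed on the host by the pre-target region (see
+  // genPreTargetOp) and are rewired here to the new host_eval block
+  // arguments.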
+ if (!hostEvalVars.lbs.empty() && !isTargetDevice) { + omp::TeamsOp teamsOp; + for (auto &op : *isolatedTargetBlock) { + if (isa<omp::TeamsOp>(&op)) + teamsOp = cast<omp::TeamsOp>(&op); + } + assert(teamsOp && "No teamsOp found in isolated target region"); + // Get the loopNestOp inside the teamsOp + auto loopNestOp = getLoopNestFromTeams(teamsOp); + // Get the BlockArgs related to host_eval vars and update loop_nest bounds + // to them + unsigned originalHostEvalVarsSize = targetOp.getHostEvalVars().size(); + unsigned index = originalHostEvalVarsSize; + // Replace loop bounds with the block arguments passed down via host_eval + SmallVector<Value> lbs, ubs, steps; + + // Collect new lb/ub/step values from target block args + for (size_t i = 0; i < hostEvalVars.lbs.size(); ++i) + lbs.push_back(isolatedTargetBlock->getArgument(index++)); + + for (size_t i = 0; i < hostEvalVars.ubs.size(); ++i) + ubs.push_back(isolatedTargetBlock->getArgument(index++)); + + for (size_t i = 0; i < hostEvalVars.steps.size(); ++i) + steps.push_back(isolatedTargetBlock->getArgument(index++)); + + // Reset the loop bounds + loopNestOp.getLoopLowerBoundsMutable().assign(lbs); + loopNestOp.getLoopUpperBoundsMutable().assign(ubs); + loopNestOp.getLoopStepsMutable().assign(steps); + } + + return isolatedTargetOp; +} + +/// genPostTargetOp method generates the postTargetOp that contains all the ops +/// after the split point. It also creates the block arguments and maps the +/// values accordingly. It also creates the load operations for the allocs +/// and recomputes the necessary ops. +static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp, + Operation *splitBeforeOp, + SmallVector<Value> &postMapOperands, + RewriterBase &rewriter, + SmallVector<Value> &allocs, + SetVector<Operation *> &toRecompute) { + auto loc = targetOp.getLoc(); + auto *targetBlock = &targetOp.getRegion().front(); + SmallVector<Value> postHostEvalVars{targetOp.getHostEvalVars()}; + // Create the post target op + omp::TargetOp postTargetOp = rewriter.create<omp::TargetOp>( + targetOp.getLoc(), targetOp.getAllocateVars(), + targetOp.getAllocatorVars(), targetOp.getBareAttr(), + targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), postHostEvalVars, + targetOp.getIfExpr(), targetOp.getInReductionVars(), + targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), + targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(), + targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(), + targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(), + targetOp.getPrivateMapsAttr()); + // Create the block for postTargetOp + auto *postTargetBlock = rewriter.createBlock( + &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {}); + IRMapping postMapping; + // Create block arguments and map the values. + createBlockArgsAndMap(loc, rewriter, targetOp, targetBlock, postTargetBlock, + postHostEvalVars, postMapOperands, allocs, postMapping); + // Handle the load operations for the allocs and recompute ops. + reloadCacheAndRecompute(loc, rewriter, splitBeforeOp, targetOp, targetBlock, + postTargetBlock, postHostEvalVars, postMapOperands, + allocs, toRecompute, postMapping); + assert(splitBeforeOp->getNumResults() == 0 || + llvm::all_of(splitBeforeOp->getResults(), + [](Value result) { return result.use_empty(); })); + // Clone the original operations after the split point. 
+  for (auto it = std::next(splitBeforeOp->getIterator());
+       it != targetBlock->end(); it++)
+    rewriter.clone(*it, postMapping);
+  return postTargetOp;
+}
+
+/// isolateOp method rewrites an omp.target_data { omp.target } into
+/// omp.target_data {
+///   // preTargetOp region contains ops before splitBeforeOp.
+///   omp.target {}
+///   // isolatedTargetOp region contains splitBeforeOp.
+///   omp.target {}
+///   // postTargetOp region contains ops after splitBeforeOp.
+///   omp.target {}
+/// }
+/// It also handles the mapping of variables and the caching/recomputing
+/// of values as needed.
+static FailureOr<SplitResult> isolateOp(Operation *splitBeforeOp,
+                                        bool splitAfter, RewriterBase &rewriter,
+                                        mlir::ModuleOp module,
+                                        bool isTargetDevice) {
+  auto targetOp = cast<omp::TargetOp>(splitBeforeOp->getParentOp());
+  assert(targetOp);
+  rewriter.setInsertionPoint(targetOp);
+
+  // Prepare the map operands for preTargetOp and postTargetOp
+  auto preMapOperands = SmallVector<Value>(targetOp.getMapVars());
+  auto postMapOperands = SmallVector<Value>(targetOp.getMapVars());
+
+  // Vectors to hold analysis results
+  SmallVector<Value> requiredVals;
+  SetVector<Operation *> toCache;
+  SetVector<Operation *> toRecompute;
+  SetVector<Operation *> nonRecomputable;
+  SmallVector<Value> allocs;
+  struct HostEvalVars hostEvalVars;
+
+  // Analyze the ops in target region to determine which ops need to be
+  // cached and which ops need to be recomputed
+  computeAllocsCacheRecomputable(
+      targetOp, splitBeforeOp, rewriter, preMapOperands, postMapOperands,
+      allocs, requiredVals, nonRecomputable, toCache, toRecompute);
+
+  rewriter.setInsertionPoint(targetOp);
+
+  // Generate the preTargetOp that contains all the ops before splitBeforeOp.
+  auto preTargetOp =
+      genPreTargetOp(targetOp, preMapOperands, allocs, splitBeforeOp, rewriter,
+                     hostEvalVars, isTargetDevice);
+
+  // Move the ops of preTarget to host.
+  auto res = moveToHost(preTargetOp, rewriter, module, hostEvalVars);
+  if (failed(res))
+    return failure();
+  rewriter.setInsertionPoint(targetOp);
+
+  // Generate the isolatedTargetOp
+  omp::TargetOp isolatedTargetOp =
+      genIsolatedTargetOp(targetOp, postMapOperands, splitBeforeOp, rewriter,
+                          allocs, toRecompute, hostEvalVars, isTargetDevice);
+
+  omp::TargetOp postTargetOp = nullptr;
+  // Generate the postTargetOp that contains all the ops after splitBeforeOp.
+  if (splitAfter) {
+    rewriter.setInsertionPoint(targetOp);
+    postTargetOp = genPostTargetOp(targetOp, splitBeforeOp, postMapOperands,
+                                   rewriter, allocs, toRecompute);
+  }
+  // Finally erase the original targetOp.
+  rewriter.eraseOp(targetOp);
+  return SplitResult{preTargetOp, isolatedTargetOp, postTargetOp};
+}
+
+/// Recursively fission target ops until no more nested ops can be isolated.
+static LogicalResult fissionTarget(omp::TargetOp targetOp,
+                                   RewriterBase &rewriter,
+                                   mlir::ModuleOp module, bool isTargetDevice) {
+  auto tuple = getNestedOpToIsolate(targetOp);
+  if (!tuple) {
+    LLVM_DEBUG(llvm::dbgs() << " No op to isolate\n");
+    struct HostEvalVars hostEvalVars;
+    return moveToHost(targetOp, rewriter, module, hostEvalVars);
+  }
+  Operation *toIsolate = std::get<0>(*tuple);
+  bool splitBefore = !std::get<1>(*tuple);
+  bool splitAfter = !std::get<2>(*tuple);
+  // Recursively isolate the target op.
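+  // When ops exist both before and after the isolated op, split on both sides
+  // and recurse into the post region, which may contain more ops to isolate.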
+ if (splitBefore && splitAfter) { + auto res = + isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice); + if (failed(res)) + return failure(); + return fissionTarget((*res).postTargetOp, rewriter, module, isTargetDevice); + } + // Isolate only before the op. + if (splitBefore) { + auto res = + isolateOp(toIsolate, splitAfter, rewriter, module, isTargetDevice); + if (failed(res)) + return failure(); + } else { + emitError(toIsolate->getLoc(), "Unhandled case in fissionTarget"); + return failure(); + } + return success(); +} + +/// Pass to lower omp.workdistribute ops. +class LowerWorkdistributePass + : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> { +public: + void runOnOperation() override { + MLIRContext &context = getContext(); + auto moduleOp = getOperation(); + bool changed = false; + SetVector<omp::TargetOp> targetOpsToProcess; + auto verify = + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + if (failed(verifyTargetTeamsWorkdistribute(workdistribute))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + if (verify.wasInterrupted()) + return signalPassFailure(); + + auto fission = + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + auto res = fissionWorkdistribute(workdistribute); + if (failed(res)) + return WalkResult::interrupt(); + changed |= *res; + return WalkResult::advance(); + }); + if (fission.wasInterrupted()) + return signalPassFailure(); + + auto rtCallLower = + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + auto res = workdistributeRuntimeCallLower(workdistribute, + targetOpsToProcess); + if (failed(res)) + return WalkResult::interrupt(); + changed |= *res; + return WalkResult::advance(); + }); + if (rtCallLower.wasInterrupted()) + return signalPassFailure(); + + moduleOp->walk([&](mlir::omp::WorkdistributeOp workdistribute) { + changed |= workdistributeDoLower(workdistribute, targetOpsToProcess); + }); + + moduleOp->walk([&](mlir::omp::TeamsOp teams) { + changed |= teamsWorkdistributeToSingleOp(teams, targetOpsToProcess); + }); + if (changed) { + bool isTargetDevice = + llvm::cast<mlir::omp::OffloadModuleInterface>(*moduleOp) + .getIsTargetDevice(); + IRRewriter rewriter(&context); + for (auto targetOp : targetOpsToProcess) { + auto res = splitTargetData(targetOp, rewriter); + if (failed(res)) + return signalPassFailure(); + if (*res) { + if (failed(fissionTarget(*res, rewriter, moduleOp, isTargetDevice))) + return signalPassFailure(); + } + } + } + } +}; +} // namespace diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index a83b066..1ecb6d3 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -301,8 +301,10 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, addNestedPassToAllTopLevelOperations<PassConstructor>( pm, hlfir::createInlineHLFIRAssign); pm.addPass(hlfir::createConvertHLFIRtoFIR()); - if (enableOpenMP != EnableOpenMP::None) + if (enableOpenMP != EnableOpenMP::None) { pm.addPass(flangomp::createLowerWorkshare()); + pm.addPass(flangomp::createLowerWorkdistribute()); + } if (enableOpenMP == EnableOpenMP::Simd) pm.addPass(flangomp::createSimdOnlyPass()); } diff --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp index 061a7d2..bdc3418 100644 --- a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp +++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp @@ -474,7 +474,7 @@ public: 
mlir::PatternRewriter &rewriter) const override { LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n"; loop.dump();); - LLVM_ATTRIBUTE_UNUSED auto loopAnalysis = + [[maybe_unused]] auto loopAnalysis = functionAnalysis.getChildLoopAnalysis(loop); if (!loopAnalysis.canPromoteToAffine()) return rewriter.notifyMatchFailure(loop, "cannot promote to affine"); diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp index 80b3f68..8601499 100644 --- a/flang/lib/Optimizer/Transforms/StackArrays.cpp +++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp @@ -561,7 +561,7 @@ static mlir::Value convertAllocationType(mlir::PatternRewriter &rewriter, return stack; fir::HeapType firHeapTy = mlir::cast<fir::HeapType>(heapTy); - LLVM_ATTRIBUTE_UNUSED fir::ReferenceType firRefTy = + [[maybe_unused]] fir::ReferenceType firRefTy = mlir::cast<fir::ReferenceType>(stackTy); assert(firHeapTy.getElementType() == firRefTy.getElementType() && "Allocations must have the same type"); diff --git a/flang/lib/Parser/parse-tree.cpp b/flang/lib/Parser/parse-tree.cpp index cb30939..8cbaa39 100644 --- a/flang/lib/Parser/parse-tree.cpp +++ b/flang/lib/Parser/parse-tree.cpp @@ -185,7 +185,7 @@ StructureConstructor ArrayElement::ConvertToStructureConstructor( std::list<ComponentSpec> components; for (auto &subscript : subscripts) { components.emplace_back(std::optional<Keyword>{}, - ComponentDataSource{std::move(*Unwrap<Expr>(subscript))}); + ComponentDataSource{std::move(UnwrapRef<Expr>(subscript))}); } DerivedTypeSpec spec{std::move(name), std::list<TypeParamSpec>{}}; spec.derivedTypeSpec = &derived; diff --git a/flang/lib/Semantics/assignment.cpp b/flang/lib/Semantics/assignment.cpp index f4aa496..1824a7d 100644 --- a/flang/lib/Semantics/assignment.cpp +++ b/flang/lib/Semantics/assignment.cpp @@ -194,7 +194,8 @@ void AssignmentContext::CheckShape(parser::CharBlock at, const SomeExpr *expr) { template <typename A> void AssignmentContext::PushWhereContext(const A &x) { const auto &expr{std::get<parser::LogicalExpr>(x.t)}; - CheckShape(expr.thing.value().source, GetExpr(context_, expr)); + CheckShape( + parser::UnwrapRef<parser::Expr>(expr).source, GetExpr(context_, expr)); ++whereDepth_; } diff --git a/flang/lib/Semantics/check-allocate.cpp b/flang/lib/Semantics/check-allocate.cpp index 823aa4e..e019bbd 100644 --- a/flang/lib/Semantics/check-allocate.cpp +++ b/flang/lib/Semantics/check-allocate.cpp @@ -151,7 +151,9 @@ static std::optional<AllocateCheckerInfo> CheckAllocateOptions( [&](const parser::MsgVariable &var) { WarnOnDeferredLengthCharacterScalar(context, GetExpr(context, var), - var.v.thing.thing.GetSource(), "ERRMSG="); + parser::UnwrapRef<parser::Variable>(var) + .GetSource(), + "ERRMSG="); if (info.gotMsg) { // C943 context.Say( "ERRMSG may not be duplicated in a ALLOCATE statement"_err_en_US); @@ -439,7 +441,7 @@ static bool HaveCompatibleLengths( evaluate::ToInt64(type1.characterTypeSpec().length().GetExplicit())}; auto v2{ evaluate::ToInt64(type2.characterTypeSpec().length().GetExplicit())}; - return !v1 || !v2 || *v1 == *v2; + return !v1 || !v2 || (*v1 >= 0 ? *v1 : 0) == (*v2 >= 0 ? *v2 : 0); } else { return true; } @@ -452,7 +454,7 @@ static bool HaveCompatibleLengths( auto v1{ evaluate::ToInt64(type1.characterTypeSpec().length().GetExplicit())}; auto v2{type2.knownLength()}; - return !v1 || !v2 || *v1 == *v2; + return !v1 || !v2 || (*v1 >= 0 ? *v1 : 0) == (*v2 >= 0 ? 
*v2 : 0); } else { return true; } @@ -598,7 +600,7 @@ bool AllocationCheckerHelper::RunChecks(SemanticsContext &context) { std::optional<evaluate::ConstantSubscript> lbound; if (const auto &lb{std::get<0>(shapeSpec.t)}) { lbound.reset(); - const auto &lbExpr{lb->thing.thing.value()}; + const auto &lbExpr{parser::UnwrapRef<parser::Expr>(lb)}; if (const auto *expr{GetExpr(context, lbExpr)}) { auto folded{ evaluate::Fold(context.foldingContext(), SomeExpr(*expr))}; @@ -609,7 +611,8 @@ bool AllocationCheckerHelper::RunChecks(SemanticsContext &context) { lbound = 1; } if (lbound) { - const auto &ubExpr{std::get<1>(shapeSpec.t).thing.thing.value()}; + const auto &ubExpr{ + parser::UnwrapRef<parser::Expr>(std::get<1>(shapeSpec.t))}; if (const auto *expr{GetExpr(context, ubExpr)}) { auto folded{ evaluate::Fold(context.foldingContext(), SomeExpr(*expr))}; diff --git a/flang/lib/Semantics/check-case.cpp b/flang/lib/Semantics/check-case.cpp index 5ce143c..7593154 100644 --- a/flang/lib/Semantics/check-case.cpp +++ b/flang/lib/Semantics/check-case.cpp @@ -72,7 +72,7 @@ private: } std::optional<Value> GetValue(const parser::CaseValue &caseValue) { - const parser::Expr &expr{caseValue.thing.thing.value()}; + const auto &expr{parser::UnwrapRef<parser::Expr>(caseValue)}; auto *x{expr.typedExpr.get()}; if (x && x->v) { // C1147 auto type{x->v->GetType()}; diff --git a/flang/lib/Semantics/check-coarray.cpp b/flang/lib/Semantics/check-coarray.cpp index 0e444f1..9113369 100644 --- a/flang/lib/Semantics/check-coarray.cpp +++ b/flang/lib/Semantics/check-coarray.cpp @@ -112,7 +112,7 @@ static void CheckTeamType( static void CheckTeamStat( SemanticsContext &context, const parser::ImageSelectorSpec::Stat &stat) { - const parser::Variable &var{stat.v.thing.thing.value()}; + const auto &var{parser::UnwrapRef<parser::Variable>(stat)}; if (parser::GetCoindexedNamedObject(var)) { context.Say(parser::FindSourceLocation(var), // C931 "Image selector STAT variable must not be a coindexed " @@ -147,7 +147,8 @@ static void CheckSyncStat(SemanticsContext &context, }, [&](const parser::MsgVariable &var) { WarnOnDeferredLengthCharacterScalar(context, GetExpr(context, var), - var.v.thing.thing.GetSource(), "ERRMSG="); + parser::UnwrapRef<parser::Variable>(var).GetSource(), + "ERRMSG="); if (gotMsg) { context.Say( // C1172 "The errmsg-variable in a sync-stat-list may not be repeated"_err_en_US); @@ -260,7 +261,9 @@ static void CheckEventWaitSpecList(SemanticsContext &context, [&](const parser::MsgVariable &var) { WarnOnDeferredLengthCharacterScalar(context, GetExpr(context, var), - var.v.thing.thing.GetSource(), "ERRMSG="); + parser::UnwrapRef<parser::Variable>(var) + .GetSource(), + "ERRMSG="); if (gotMsg) { context.Say( // C1178 "A errmsg-variable in a event-wait-spec-list may not be repeated"_err_en_US); diff --git a/flang/lib/Semantics/check-data.cpp b/flang/lib/Semantics/check-data.cpp index 5459290..3bcf711 100644 --- a/flang/lib/Semantics/check-data.cpp +++ b/flang/lib/Semantics/check-data.cpp @@ -25,9 +25,10 @@ namespace Fortran::semantics { // Ensures that references to an implied DO loop control variable are // represented as such in the "body" of the implied DO loop. 
void DataChecker::Enter(const parser::DataImpliedDo &x) { - auto name{std::get<parser::DataImpliedDo::Bounds>(x.t).name.thing.thing}; + const auto &name{parser::UnwrapRef<parser::Name>( + std::get<parser::DataImpliedDo::Bounds>(x.t).name)}; int kind{evaluate::ResultType<evaluate::ImpliedDoIndex>::kind}; - if (const auto dynamicType{evaluate::DynamicType::From(*name.symbol)}) { + if (const auto dynamicType{evaluate::DynamicType::From(DEREF(name.symbol))}) { if (dynamicType->category() == TypeCategory::Integer) { kind = dynamicType->kind(); } @@ -36,7 +37,8 @@ void DataChecker::Enter(const parser::DataImpliedDo &x) { } void DataChecker::Leave(const parser::DataImpliedDo &x) { - auto name{std::get<parser::DataImpliedDo::Bounds>(x.t).name.thing.thing}; + const auto &name{parser::UnwrapRef<parser::Name>( + std::get<parser::DataImpliedDo::Bounds>(x.t).name)}; exprAnalyzer_.RemoveImpliedDo(name.source); } @@ -211,7 +213,7 @@ void DataChecker::Leave(const parser::DataIDoObject &object) { std::get_if<parser::Scalar<common::Indirection<parser::Designator>>>( &object.u)}) { if (MaybeExpr expr{exprAnalyzer_.Analyze(*designator)}) { - auto source{designator->thing.value().source}; + auto source{parser::UnwrapRef<parser::Designator>(*designator).source}; DataVarChecker checker{exprAnalyzer_.context(), source}; if (checker(*expr)) { if (checker.HasComponentWithoutSubscripts()) { // C880 diff --git a/flang/lib/Semantics/check-deallocate.cpp b/flang/lib/Semantics/check-deallocate.cpp index c45b585..c1ebc5f 100644 --- a/flang/lib/Semantics/check-deallocate.cpp +++ b/flang/lib/Semantics/check-deallocate.cpp @@ -114,7 +114,8 @@ void DeallocateChecker::Leave(const parser::DeallocateStmt &deallocateStmt) { }, [&](const parser::MsgVariable &var) { WarnOnDeferredLengthCharacterScalar(context_, - GetExpr(context_, var), var.v.thing.thing.GetSource(), + GetExpr(context_, var), + parser::UnwrapRef<parser::Variable>(var).GetSource(), "ERRMSG="); if (gotMsg) { context_.Say( diff --git a/flang/lib/Semantics/check-do-forall.cpp b/flang/lib/Semantics/check-do-forall.cpp index a2f3685..8a47340 100644 --- a/flang/lib/Semantics/check-do-forall.cpp +++ b/flang/lib/Semantics/check-do-forall.cpp @@ -535,7 +535,8 @@ private: if (const SomeExpr * expr{GetExpr(context_, scalarExpression)}) { if (!ExprHasTypeCategory(*expr, TypeCategory::Integer)) { // No warnings or errors for type INTEGER - const parser::CharBlock &loc{scalarExpression.thing.value().source}; + parser::CharBlock loc{ + parser::UnwrapRef<parser::Expr>(scalarExpression).source}; CheckDoControl(loc, ExprHasTypeCategory(*expr, TypeCategory::Real)); } } @@ -552,7 +553,7 @@ private: CheckDoExpression(*bounds.step); if (IsZero(*bounds.step)) { context_.Warn(common::UsageWarning::ZeroDoStep, - bounds.step->thing.value().source, + parser::UnwrapRef<parser::Expr>(bounds.step).source, "DO step expression should not be zero"_warn_en_US); } } @@ -615,7 +616,7 @@ private: // C1121 - procedures in mask must be pure void CheckMaskIsPure(const parser::ScalarLogicalExpr &mask) const { UnorderedSymbolSet references{ - GatherSymbolsFromExpression(mask.thing.thing.value())}; + GatherSymbolsFromExpression(parser::UnwrapRef<parser::Expr>(mask))}; for (const Symbol &ref : OrderBySourcePosition(references)) { if (IsProcedure(ref) && !IsPureProcedure(ref)) { context_.SayWithDecl(ref, parser::Unwrap<parser::Expr>(mask)->source, @@ -639,32 +640,33 @@ private: } void HasNoReferences(const UnorderedSymbolSet &indexNames, - const parser::ScalarIntExpr &expr) const { - 
CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()), - indexNames, + const parser::ScalarIntExpr &scalarIntExpr) const { + const auto &expr{parser::UnwrapRef<parser::Expr>(scalarIntExpr)}; + CheckNoCollisions(GatherSymbolsFromExpression(expr), indexNames, "%s limit expression may not reference index variable '%s'"_err_en_US, - expr.thing.thing.value().source); + expr.source); } // C1129, names in local locality-specs can't be in mask expressions void CheckMaskDoesNotReferenceLocal(const parser::ScalarLogicalExpr &mask, const UnorderedSymbolSet &localVars) const { - CheckNoCollisions(GatherSymbolsFromExpression(mask.thing.thing.value()), - localVars, + const auto &expr{parser::UnwrapRef<parser::Expr>(mask)}; + CheckNoCollisions(GatherSymbolsFromExpression(expr), localVars, "%s mask expression references variable '%s'" " in LOCAL locality-spec"_err_en_US, - mask.thing.thing.value().source); + expr.source); } // C1129, names in local locality-specs can't be in limit or step // expressions - void CheckExprDoesNotReferenceLocal(const parser::ScalarIntExpr &expr, + void CheckExprDoesNotReferenceLocal( + const parser::ScalarIntExpr &scalarIntExpr, const UnorderedSymbolSet &localVars) const { - CheckNoCollisions(GatherSymbolsFromExpression(expr.thing.thing.value()), - localVars, + const auto &expr{parser::UnwrapRef<parser::Expr>(scalarIntExpr)}; + CheckNoCollisions(GatherSymbolsFromExpression(expr), localVars, "%s expression references variable '%s'" " in LOCAL locality-spec"_err_en_US, - expr.thing.thing.value().source); + expr.source); } // C1130, DEFAULT(NONE) locality requires names to be in locality-specs to @@ -772,7 +774,7 @@ private: HasNoReferences(indexNames, std::get<2>(control.t)); if (const auto &intExpr{ std::get<std::optional<parser::ScalarIntExpr>>(control.t)}) { - const parser::Expr &expr{intExpr->thing.thing.value()}; + const auto &expr{parser::UnwrapRef<parser::Expr>(intExpr)}; CheckNoCollisions(GatherSymbolsFromExpression(expr), indexNames, "%s step expression may not reference index variable '%s'"_err_en_US, expr.source); @@ -840,7 +842,7 @@ private: } void CheckForImpureCall(const parser::ScalarIntExpr &x, std::optional<IndexVarKind> nesting) const { - const auto &parsedExpr{x.thing.thing.value()}; + const auto &parsedExpr{parser::UnwrapRef<parser::Expr>(x)}; auto oldLocation{context_.location()}; context_.set_location(parsedExpr.source); if (const auto &typedExpr{parsedExpr.typedExpr}) { @@ -1124,7 +1126,8 @@ void DoForallChecker::Leave(const parser::ConnectSpec &connectSpec) { const auto *newunit{ std::get_if<parser::ConnectSpec::Newunit>(&connectSpec.u)}; if (newunit) { - context_.CheckIndexVarRedefine(newunit->v.thing.thing); + context_.CheckIndexVarRedefine( + parser::UnwrapRef<parser::Variable>(newunit)); } } @@ -1166,14 +1169,14 @@ void DoForallChecker::Leave(const parser::InquireSpec &inquireSpec) { const auto *intVar{std::get_if<parser::InquireSpec::IntVar>(&inquireSpec.u)}; if (intVar) { const auto &scalar{std::get<parser::ScalarIntVariable>(intVar->t)}; - context_.CheckIndexVarRedefine(scalar.thing.thing); + context_.CheckIndexVarRedefine(parser::UnwrapRef<parser::Variable>(scalar)); } } void DoForallChecker::Leave(const parser::IoControlSpec &ioControlSpec) { const auto *size{std::get_if<parser::IoControlSpec::Size>(&ioControlSpec.u)}; if (size) { - context_.CheckIndexVarRedefine(size->v.thing.thing); + context_.CheckIndexVarRedefine(parser::UnwrapRef<parser::Variable>(size)); } } @@ -1190,16 +1193,19 @@ static void CheckIoImpliedDoIndex( void 
DoForallChecker::Leave(const parser::OutputImpliedDo &outputImpliedDo) { CheckIoImpliedDoIndex(context_, - std::get<parser::IoImpliedDoControl>(outputImpliedDo.t).name.thing.thing); + parser::UnwrapRef<parser::Name>( + std::get<parser::IoImpliedDoControl>(outputImpliedDo.t).name)); } void DoForallChecker::Leave(const parser::InputImpliedDo &inputImpliedDo) { CheckIoImpliedDoIndex(context_, - std::get<parser::IoImpliedDoControl>(inputImpliedDo.t).name.thing.thing); + parser::UnwrapRef<parser::Name>( + std::get<parser::IoImpliedDoControl>(inputImpliedDo.t).name)); } void DoForallChecker::Leave(const parser::StatVariable &statVariable) { - context_.CheckIndexVarRedefine(statVariable.v.thing.thing); + context_.CheckIndexVarRedefine( + parser::UnwrapRef<parser::Variable>(statVariable)); } } // namespace Fortran::semantics diff --git a/flang/lib/Semantics/check-io.cpp b/flang/lib/Semantics/check-io.cpp index a1ff4b9..19059ad 100644 --- a/flang/lib/Semantics/check-io.cpp +++ b/flang/lib/Semantics/check-io.cpp @@ -424,8 +424,8 @@ void IoChecker::Enter(const parser::InquireSpec::CharVar &spec) { specKind = IoSpecKind::Dispose; break; } - const parser::Variable &var{ - std::get<parser::ScalarDefaultCharVariable>(spec.t).thing.thing}; + const auto &var{parser::UnwrapRef<parser::Variable>( + std::get<parser::ScalarDefaultCharVariable>(spec.t))}; std::string what{parser::ToUpperCaseLetters(common::EnumToString(specKind))}; CheckForDefinableVariable(var, what); WarnOnDeferredLengthCharacterScalar( @@ -627,7 +627,7 @@ void IoChecker::Enter(const parser::IoUnit &spec) { } void IoChecker::Enter(const parser::MsgVariable &msgVar) { - const parser::Variable &var{msgVar.v.thing.thing}; + const auto &var{parser::UnwrapRef<parser::Variable>(msgVar)}; if (stmt_ == IoStmtKind::None) { // allocate, deallocate, image control CheckForDefinableVariable(var, "ERRMSG"); diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp index b4c1bf7..ea6fe43 100644 --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -2358,7 +2358,7 @@ private: } if (auto &repl{std::get<parser::OmpClause::Replayable>(clause.u).v}) { // Scalar<Logical<Constant<indirection<Expr>>>> - const parser::Expr &parserExpr{repl->v.thing.thing.thing.value()}; + const auto &parserExpr{parser::UnwrapRef<parser::Expr>(repl)}; if (auto &&expr{GetEvaluateExpr(parserExpr)}) { return GetLogicalValue(*expr).value_or(true); } @@ -2372,7 +2372,7 @@ private: bool isTransparent{true}; if (auto &transp{std::get<parser::OmpClause::Transparent>(clause.u).v}) { // Scalar<Integer<indirection<Expr>>> - const parser::Expr &parserExpr{transp->v.thing.thing.value()}; + const auto &parserExpr{parser::UnwrapRef<parser::Expr>(transp)}; if (auto &&expr{GetEvaluateExpr(parserExpr)}) { // If the argument is omp_not_impex (defined as 0), then // the task is not transparent, otherwise it is. @@ -2411,8 +2411,8 @@ private: } } // Scalar<Logical<indirection<Expr>>> - auto &parserExpr{ - std::get<parser::ScalarLogicalExpr>(ifc.v.t).thing.thing.value()}; + const auto &parserExpr{parser::UnwrapRef<parser::Expr>( + std::get<parser::ScalarLogicalExpr>(ifc.v.t))}; if (auto &&expr{GetEvaluateExpr(parserExpr)}) { // If the value is known to be false, an undeferred task will be // generated. 
diff --git a/flang/lib/Semantics/data-to-inits.cpp b/flang/lib/Semantics/data-to-inits.cpp
index 1e46dab..bbf3b28 100644
--- a/flang/lib/Semantics/data-to-inits.cpp
+++ b/flang/lib/Semantics/data-to-inits.cpp
@@ -179,13 +179,14 @@ bool DataInitializationCompiler<DSV>::Scan(
 template <typename DSV>
 bool DataInitializationCompiler<DSV>::Scan(const parser::DataImpliedDo &ido) {
   const auto &bounds{std::get<parser::DataImpliedDo::Bounds>(ido.t)};
-  auto name{bounds.name.thing.thing};
-  const auto *lowerExpr{
-      GetExpr(exprAnalyzer_.context(), bounds.lower.thing.thing)};
-  const auto *upperExpr{
-      GetExpr(exprAnalyzer_.context(), bounds.upper.thing.thing)};
+  const auto &name{parser::UnwrapRef<parser::Name>(bounds.name)};
+  const auto *lowerExpr{GetExpr(
+      exprAnalyzer_.context(), parser::UnwrapRef<parser::Expr>(bounds.lower))};
+  const auto *upperExpr{GetExpr(
+      exprAnalyzer_.context(), parser::UnwrapRef<parser::Expr>(bounds.upper))};
   const auto *stepExpr{bounds.step
-          ? GetExpr(exprAnalyzer_.context(), bounds.step->thing.thing)
+          ? GetExpr(exprAnalyzer_.context(),
+                parser::UnwrapRef<parser::Expr>(bounds.step))
           : nullptr};
   if (lowerExpr && upperExpr) {
     // Fold the bounds expressions (again) in case any of them depend
@@ -240,7 +241,9 @@ bool DataInitializationCompiler<DSV>::Scan(
   return common::visit(
       common::visitors{
           [&](const parser::Scalar<common::Indirection<parser::Designator>>
-                  &var) { return Scan(var.thing.value()); },
+                  &var) {
+            return Scan(parser::UnwrapRef<parser::Designator>(var));
+          },
          [&](const common::Indirection<parser::DataImpliedDo> &ido) {
            return Scan(ido.value());
          },
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index 2feec98..4aeb9a4 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -176,8 +176,8 @@ public:

   // Find and return a user-defined operator or report an error.
   // The provided message is used if there is no such operator.
-  MaybeExpr TryDefinedOp(
-      const char *, parser::MessageFixedText, bool isUserOp = false);
+  MaybeExpr TryDefinedOp(const char *, parser::MessageFixedText,
+      bool isUserOp = false, bool checkForNullPointer = true);
   template <typename E>
   MaybeExpr TryDefinedOp(E opr, parser::MessageFixedText msg) {
     return TryDefinedOp(
@@ -211,7 +211,8 @@ private:
   void SayNoMatch(
       const std::string &, bool isAssignment = false, bool isAmbiguous = false);
   std::string TypeAsFortran(std::size_t);
-  bool AnyUntypedOrMissingOperand() const;
+  bool AnyUntypedOperand() const;
+  bool AnyMissingOperand() const;

   ExpressionAnalyzer &context_;
   ActualArguments actuals_;
@@ -1954,9 +1955,10 @@ void ArrayConstructorContext::Add(const parser::AcImpliedDo &impliedDo) {
   const auto &control{std::get<parser::AcImpliedDoControl>(impliedDo.t)};
   const auto &bounds{std::get<parser::AcImpliedDoControl::Bounds>(control.t)};
   exprAnalyzer_.Analyze(bounds.name);
-  parser::CharBlock name{bounds.name.thing.thing.source};
+  const auto &parsedName{parser::UnwrapRef<parser::Name>(bounds.name)};
+  parser::CharBlock name{parsedName.source};
   int kind{ImpliedDoIntType::kind};
-  if (const Symbol * symbol{bounds.name.thing.thing.symbol}) {
+  if (const Symbol *symbol{parsedName.symbol}) {
     if (auto dynamicType{DynamicType::From(symbol)}) {
       if (dynamicType->category() == TypeCategory::Integer) {
         kind = dynamicType->kind();
@@ -1981,7 +1983,7 @@ void ArrayConstructorContext::Add(const parser::AcImpliedDo &impliedDo) {
     auto cUpper{ToInt64(upper)};
     auto cStride{ToInt64(stride)};
     if (!(messageDisplayedSet_ & 0x10) && cStride && *cStride == 0) {
-      exprAnalyzer_.SayAt(bounds.step.value().thing.thing.value().source,
+      exprAnalyzer_.SayAt(parser::UnwrapRef<parser::Expr>(bounds.step).source,
           "The stride of an implied DO loop must not be zero"_err_en_US);
       messageDisplayedSet_ |= 0x10;
     }
@@ -2526,7 +2528,7 @@ static const Symbol *GetBindingResolution(
 auto ExpressionAnalyzer::AnalyzeProcedureComponentRef(
     const parser::ProcComponentRef &pcr, ActualArguments &&arguments,
     bool isSubroutine) -> std::optional<CalleeAndArguments> {
-  const parser::StructureComponent &sc{pcr.v.thing};
+  const auto &sc{parser::UnwrapRef<parser::StructureComponent>(pcr)};
   if (MaybeExpr base{Analyze(sc.base)}) {
     if (const Symbol *sym{sc.component.symbol}) {
       if (context_.HasError(sym)) {
@@ -3695,11 +3697,12 @@ std::optional<characteristics::Procedure> ExpressionAnalyzer::CheckCall(

 MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr::Parentheses &x) {
   if (MaybeExpr operand{Analyze(x.v.value())}) {
-    if (const semantics::Symbol *symbol{GetLastSymbol(*operand)}) {
+    if (IsNullPointerOrAllocatable(&*operand)) {
+      Say("NULL() may not be parenthesized"_err_en_US);
+    } else if (const semantics::Symbol *symbol{GetLastSymbol(*operand)}) {
       if (const semantics::Symbol *result{FindFunctionResult(*symbol)}) {
         if (semantics::IsProcedurePointer(*result)) {
-          Say("A function reference that returns a procedure "
-              "pointer may not be parenthesized"_err_en_US); // C1003
+          Say("A function reference that returns a procedure pointer may not be parenthesized"_err_en_US); // C1003
         }
       }
     }
@@ -3788,7 +3791,7 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr::DefinedUnary &x) {
   ArgumentAnalyzer analyzer{*this, name.source};
   analyzer.Analyze(std::get<1>(x.t));
   return analyzer.TryDefinedOp(name.source.ToString().c_str(),
-      "No operator %s defined for %s"_err_en_US, true);
+      "No operator %s defined for %s"_err_en_US, /*isUserOp=*/true);
 }

 // Binary (dyadic) operations
@@ -3997,7 +4000,9 @@ static bool CheckFuncRefToArrayElement(semantics::SemanticsContext &context,
   auto &proc{std::get<parser::ProcedureDesignator>(funcRef.v.t)};
   const auto *name{std::get_if<parser::Name>(&proc.u)};
   if (!name) {
-    name = &std::get<parser::ProcComponentRef>(proc.u).v.thing.component;
+    name = &parser::UnwrapRef<parser::StructureComponent>(
+        std::get<parser::ProcComponentRef>(proc.u))
+                .component;
   }
   if (!name->symbol) {
     return false;
   }
@@ -4047,14 +4052,16 @@ static void FixMisparsedFunctionReference(
     }
   }
   auto &proc{std::get<parser::ProcedureDesignator>(funcRef.v.t)};
-  if (Symbol *origSymbol{
-          common::visit(common::visitors{
-                            [&](parser::Name &name) { return name.symbol; },
-                            [&](parser::ProcComponentRef &pcr) {
-                              return pcr.v.thing.component.symbol;
-                            },
-                        },
-              proc.u)}) {
+  if (Symbol *
+          origSymbol{common::visit(
+              common::visitors{
+                  [&](parser::Name &name) { return name.symbol; },
+                  [&](parser::ProcComponentRef &pcr) {
+                    return parser::UnwrapRef<parser::StructureComponent>(pcr)
+                        .component.symbol;
+                  },
+              },
+              proc.u)}) {
     Symbol &symbol{origSymbol->GetUltimate()};
     if (symbol.has<semantics::ObjectEntityDetails>() ||
         symbol.has<semantics::AssocEntityDetails>()) {
@@ -4176,15 +4183,23 @@ MaybeExpr ExpressionAnalyzer::IterativelyAnalyzeSubexpressions(
   } while (!queue.empty());
   // Analyze the collected subexpressions in bottom-up order.
   // On an error, bail out and leave partial results in place.
-  MaybeExpr result;
-  for (auto riter{finish.rbegin()}; riter != finish.rend(); ++riter) {
-    const parser::Expr &expr{**riter};
-    result = ExprOrVariable(expr, expr.source);
-    if (!result) {
-      return result;
+  if (finish.size() == 1) {
+    const parser::Expr &expr{DEREF(finish.front())};
+    return ExprOrVariable(expr, expr.source);
+  } else {
+    // NULL() operand catching is deferred to operation analysis so
+    // that they can be accepted by defined operators.
+    auto restorer{AllowNullPointer()};
+    MaybeExpr result;
+    for (auto riter{finish.rbegin()}; riter != finish.rend(); ++riter) {
+      const parser::Expr &expr{**riter};
+      result = ExprOrVariable(expr, expr.source);
+      if (!result) {
+        return result;
+      }
     }
+    return result; // last value was from analysis of "top"
   }
-  return result; // last value was from analysis of "top"
 }

 MaybeExpr ExpressionAnalyzer::Analyze(const parser::Expr &expr) {
@@ -4681,7 +4696,7 @@ bool ArgumentAnalyzer::AnyCUDADeviceData() const {
 // attribute.
 bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride(
     const char *opr) const {
-  if (AnyCUDADeviceData() && !AnyUntypedOrMissingOperand()) {
+  if (AnyCUDADeviceData() && !AnyUntypedOperand() && !AnyMissingOperand()) {
     std::string oprNameString{"operator("s + opr + ')'};
     parser::CharBlock oprName{oprNameString};
     parser::Messages buffer;
@@ -4709,9 +4724,9 @@ bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride(
   return false;
 }

-MaybeExpr ArgumentAnalyzer::TryDefinedOp(
-    const char *opr, parser::MessageFixedText error, bool isUserOp) {
-  if (AnyUntypedOrMissingOperand()) {
+MaybeExpr ArgumentAnalyzer::TryDefinedOp(const char *opr,
+    parser::MessageFixedText error, bool isUserOp, bool checkForNullPointer) {
+  if (AnyMissingOperand()) {
     context_.Say(error, ToUpperCase(opr), TypeAsFortran(0), TypeAsFortran(1));
     return std::nullopt;
   }
@@ -4790,7 +4805,9 @@ MaybeExpr ArgumentAnalyzer::TryDefinedOp(
       context_.Say(
           "Operands of %s are not conformable; have rank %d and rank %d"_err_en_US,
           ToUpperCase(opr), actuals_[0]->Rank(), actuals_[1]->Rank());
-    } else if (CheckForNullPointer() && CheckForAssumedRank()) {
+    } else if (!CheckForAssumedRank()) {
+    } else if (checkForNullPointer && !CheckForNullPointer()) {
+    } else { // use the supplied error
       context_.Say(error, ToUpperCase(opr), TypeAsFortran(0), TypeAsFortran(1));
     }
     return result;
@@ -4808,15 +4825,16 @@ MaybeExpr ArgumentAnalyzer::TryDefinedOp(
     for (std::size_t i{0}; i < oprs.size(); ++i) {
       parser::Messages buffer;
       auto restorer{context_.GetContextualMessages().SetMessages(buffer)};
-      if (MaybeExpr thisResult{TryDefinedOp(oprs[i], error)}) {
+      if (MaybeExpr thisResult{TryDefinedOp(oprs[i], error, /*isUserOp=*/false,
+              /*checkForNullPointer=*/false)}) {
         result = std::move(thisResult);
         hit.push_back(oprs[i]);
         hitBuffer = std::move(buffer);
       }
     }
   }
-  if (hit.empty()) { // for the error
-    result = TryDefinedOp(oprs[0], error);
+  if (hit.empty()) { // run TryDefinedOp() again just to emit errors
+    CHECK(!TryDefinedOp(oprs[0], error).has_value());
   } else if (hit.size() > 1) {
     context_.Say(
         "Matching accessible definitions were found with %zd variant spellings of the generic operator ('%s', '%s')"_err_en_US,
@@ -5232,10 +5250,19 @@ std::string ArgumentAnalyzer::TypeAsFortran(std::size_t i) {
   }
 }

-bool ArgumentAnalyzer::AnyUntypedOrMissingOperand() const {
+bool ArgumentAnalyzer::AnyUntypedOperand() const {
+  for (const auto &actual : actuals_) {
+    if (actual && !actual->GetType() &&
+        !IsBareNullPointer(actual->UnwrapExpr())) {
+      return true;
+    }
+  }
+  return false;
+}
+
+bool ArgumentAnalyzer::AnyMissingOperand() const {
   for (const auto &actual : actuals_) {
-    if (!actual ||
-        (!actual->GetType() && !IsBareNullPointer(actual->UnwrapExpr()))) {
+    if (!actual) {
       return true;
     }
   }
@@ -5268,9 +5295,9 @@ void ExprChecker::Post(const parser::DataStmtObject &obj) {
 bool ExprChecker::Pre(const parser::DataImpliedDo &ido) {
   parser::Walk(std::get<parser::DataImpliedDo::Bounds>(ido.t), *this);
   const auto &bounds{std::get<parser::DataImpliedDo::Bounds>(ido.t)};
-  auto name{bounds.name.thing.thing};
+  const auto &name{parser::UnwrapRef<parser::Name>(bounds.name)};
   int kind{evaluate::ResultType<evaluate::ImpliedDoIndex>::kind};
-  if (const auto dynamicType{evaluate::DynamicType::From(*name.symbol)}) {
+  if (const auto dynamicType{evaluate::DynamicType::From(DEREF(name.symbol))}) {
     if (dynamicType->category() == TypeCategory::Integer) {
       kind = dynamicType->kind();
     }
diff --git a/flang/lib/Semantics/resolve-names-utils.cpp b/flang/lib/Semantics/resolve-names-utils.cpp
index 742bb74..ac67799 100644
--- a/flang/lib/Semantics/resolve-names-utils.cpp
+++ b/flang/lib/Semantics/resolve-names-utils.cpp
@@ -492,12 +492,14 @@ bool EquivalenceSets::CheckDesignator(const parser::Designator &designator) {
             const auto &range{std::get<parser::SubstringRange>(x.t)};
             bool ok{CheckDataRef(designator.source, dataRef)};
             if (const auto &lb{std::get<0>(range.t)}) {
-              ok &= CheckSubstringBound(lb->thing.thing.value(), true);
+              ok &= CheckSubstringBound(
+                  parser::UnwrapRef<parser::Expr>(lb), true);
             } else {
               currObject_.substringStart = 1;
             }
             if (const auto &ub{std::get<1>(range.t)}) {
-              ok &= CheckSubstringBound(ub->thing.thing.value(), false);
+              ok &= CheckSubstringBound(
+                  parser::UnwrapRef<parser::Expr>(ub), false);
             }
             return ok;
           },
@@ -528,7 +530,8 @@ bool EquivalenceSets::CheckDataRef(
               return false;
             },
             [&](const parser::IntExpr &y) {
-              return CheckArrayBound(y.thing.value());
+              return CheckArrayBound(
+                  parser::UnwrapRef<parser::Expr>(y));
             },
         },
         subscript.u);
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index ae0ff9ca..699de41 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -1140,7 +1140,7 @@ protected:
   std::optional<SourceName> BeginCheckOnIndexUseInOwnBounds(
       const parser::DoVariable &name) {
     std::optional<SourceName> result{checkIndexUseInOwnBounds_};
-    checkIndexUseInOwnBounds_ = name.thing.thing.source;
+    checkIndexUseInOwnBounds_ = parser::UnwrapRef<parser::Name>(name).source;
     return result;
   }
   void EndCheckOnIndexUseInOwnBounds(const std::optional<SourceName> &restore) {
@@ -2130,7 +2130,7 @@ public:
   void Post(const parser::SubstringInquiry &);
   template <typename A, typename B>
   void Post(const parser::LoopBounds<A, B> &x) {
-    ResolveName(*parser::Unwrap<parser::Name>(x.name));
+    ResolveName(parser::UnwrapRef<parser::Name>(x.name));
   }
   void Post(const parser::ProcComponentRef &);
   bool Pre(const parser::FunctionReference &);
@@ -2560,7 +2560,7 @@ KindExpr DeclTypeSpecVisitor::GetKindParamExpr(
       CHECK(!state_.originalKindParameter);
       // Save a pointer to the KIND= expression in the parse tree
       // in case we need to reanalyze it during PDT instantiation.
-      state_.originalKindParameter = &expr->thing.thing.thing.value();
+      state_.originalKindParameter = parser::Unwrap<parser::Expr>(expr);
     }
   }
   // Inhibit some errors now that will be caught later during instantiations.
@@ -5649,6 +5649,7 @@ bool DeclarationVisitor::Pre(const parser::NamedConstantDef &x) {
   if (details->init() || symbol.test(Symbol::Flag::InDataStmt)) {
     Say(name, "Named constant '%s' already has a value"_err_en_US);
   }
+  parser::CharBlock at{parser::UnwrapRef<parser::Expr>(expr).source};
   if (inOldStyleParameterStmt_) {
     // non-standard extension PARAMETER statement (no parentheses)
     Walk(expr);
@@ -5657,7 +5658,6 @@ bool DeclarationVisitor::Pre(const parser::NamedConstantDef &x) {
       SayWithDecl(name, symbol,
          "Alternative style PARAMETER '%s' must not already have an explicit type"_err_en_US);
     } else if (folded) {
-      auto at{expr.thing.value().source};
       if (evaluate::IsActuallyConstant(*folded)) {
         if (const auto *type{currScope().GetType(*folded)}) {
           if (type->IsPolymorphic()) {
@@ -5682,8 +5682,7 @@ bool DeclarationVisitor::Pre(const parser::NamedConstantDef &x) {
     // standard-conforming PARAMETER statement (with parentheses)
     ApplyImplicitRules(symbol);
     Walk(expr);
-    if (auto converted{EvaluateNonPointerInitializer(
-            symbol, expr, expr.thing.value().source)}) {
+    if (auto converted{EvaluateNonPointerInitializer(symbol, expr, at)}) {
       details->set_init(std::move(*converted));
     }
   }
@@ -6149,7 +6148,7 @@ bool DeclarationVisitor::Pre(const parser::KindParam &x) {
   if (const auto *kind{std::get_if<
          parser::Scalar<parser::Integer<parser::Constant<parser::Name>>>>(
          &x.u)}) {
-    const parser::Name &name{kind->thing.thing.thing};
+    const auto &name{parser::UnwrapRef<parser::Name>(kind)};
     if (!FindSymbol(name)) {
       Say(name, "Parameter '%s' not found"_err_en_US);
     }
@@ -7460,7 +7459,7 @@ void DeclarationVisitor::DeclareLocalEntity(
 Symbol *DeclarationVisitor::DeclareStatementEntity(
     const parser::DoVariable &doVar,
     const std::optional<parser::IntegerTypeSpec> &type) {
-  const parser::Name &name{doVar.thing.thing};
+  const auto &name{parser::UnwrapRef<parser::Name>(doVar)};
   const DeclTypeSpec *declTypeSpec{nullptr};
   if (auto *prev{FindSymbol(name)}) {
     if (prev->owner() == currScope()) {
@@ -7893,13 +7892,14 @@ bool ConstructVisitor::Pre(const parser::DataIDoObject &x) {
   common::visit(
       common::visitors{
           [&](const parser::Scalar<Indirection<parser::Designator>> &y) {
-            Walk(y.thing.value());
-            const parser::Name &first{parser::GetFirstName(y.thing.value())};
+            const auto &designator{parser::UnwrapRef<parser::Designator>(y)};
+            Walk(designator);
+            const parser::Name &first{parser::GetFirstName(designator)};
             if (first.symbol) {
               first.symbol->set(Symbol::Flag::InDataStmt);
             }
           },
-          [&](const Indirection<parser::DataImpliedDo> &y) { Walk(y.value()); },
+          [&](const Indirection<parser::DataImpliedDo> &y) { Walk(y); },
       },
       x.u);
   return false;
@@ -8582,8 +8582,7 @@ public:
   void Post(const parser::WriteStmt &) { inAsyncIO_ = false; }
   void Post(const parser::IoControlSpec::Size &size) {
     if (const auto *designator{
-            std::get_if<common::Indirection<parser::Designator>>(
-                &size.v.thing.thing.u)}) {
+            parser::Unwrap<common::Indirection<parser::Designator>>(size)}) {
       NoteAsyncIODesignator(designator->value());
     }
   }
@@ -9175,16 +9174,17 @@ bool DeclarationVisitor::CheckNonPointerInitialization(
 }

 void DeclarationVisitor::NonPointerInitialization(
-    const parser::Name &name, const parser::ConstantExpr &expr) {
+    const parser::Name &name, const parser::ConstantExpr &constExpr) {
   if (CheckNonPointerInitialization(
           name, /*inLegacyDataInitialization=*/false)) {
     Symbol &ultimate{name.symbol->GetUltimate()};
     auto &details{ultimate.get<ObjectEntityDetails>()};
+    const auto &expr{parser::UnwrapRef<parser::Expr>(constExpr)};
     if (ultimate.owner().IsParameterizedDerivedType()) {
       // Save the expression for per-instantiation analysis.
-      details.set_unanalyzedPDTComponentInit(&expr.thing.value());
+      details.set_unanalyzedPDTComponentInit(&expr);
     } else if (MaybeExpr folded{EvaluateNonPointerInitializer(
-                   ultimate, expr, expr.thing.value().source)}) {
+                   ultimate, constExpr, expr.source)}) {
       details.set_init(std::move(*folded));
       ultimate.set(Symbol::Flag::InDataStmt, false);
     }
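
The hunks above repeatedly replace spelled-out wrapper chains such as x.thing.thing.value() with parser::UnwrapRef<T>(x), so a call site names only the parse-tree node type it wants. The sketch below is only an illustration of how a helper of that general shape can peel nested wrapper types; Scalar, Integer, Indirection, and the UnwrapRef overloads here are simplified stand-ins invented for this example, not flang's actual parse-tree classes or its real UnwrapRef implementation, and "n + 1" is just a sample source string.

// Hypothetical, simplified stand-ins -- not flang's parse-tree types.
#include <iostream>
#include <memory>
#include <string>

template <typename A> struct Indirection {
  std::unique_ptr<A> p;
  const A &value() const { return *p; }
};
template <typename A> struct Integer { A thing; };
template <typename A> struct Scalar { A thing; };

struct Expr {
  std::string source;
};

// Base case: the requested type itself.
template <typename T> const T &UnwrapRef(const T &x) { return x; }
// Each overload peels one wrapper layer and recurses.
template <typename T, typename A> const T &UnwrapRef(const Indirection<A> &x) {
  return UnwrapRef<T>(x.value());
}
template <typename T, typename A> const T &UnwrapRef(const Integer<A> &x) {
  return UnwrapRef<T>(x.thing);
}
template <typename T, typename A> const T &UnwrapRef(const Scalar<A> &x) {
  return UnwrapRef<T>(x.thing);
}

int main() {
  Scalar<Integer<Indirection<Expr>>> wrapped{
      {{std::make_unique<Expr>(Expr{"n + 1"})}}};
  // Equivalent to wrapped.thing.thing.value().source, without spelling
  // out the wrapper chain at the call site.
  std::cout << UnwrapRef<Expr>(wrapped).source << '\n';
  return 0;
}

A call such as UnwrapRef<Expr>(wrapped) then works regardless of how deeply the target type is nested in this toy hierarchy.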