diff options
Diffstat (limited to 'flang/lib/Lower')
| -rw-r--r-- | flang/lib/Lower/Allocatable.cpp | 12 | ||||
| -rw-r--r-- | flang/lib/Lower/Bridge.cpp | 360 | ||||
| -rw-r--r-- | flang/lib/Lower/CMakeLists.txt | 2 | ||||
| -rw-r--r-- | flang/lib/Lower/CUDA.cpp | 14 | ||||
| -rw-r--r-- | flang/lib/Lower/Coarray.cpp | 66 | ||||
| -rw-r--r-- | flang/lib/Lower/ConvertCall.cpp | 15 | ||||
| -rw-r--r-- | flang/lib/Lower/ConvertExpr.cpp | 2 | ||||
| -rw-r--r-- | flang/lib/Lower/ConvertVariable.cpp | 11 | ||||
| -rw-r--r-- | flang/lib/Lower/MultiImageFortran.cpp | 278 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenACC.cpp | 973 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 242 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/ClauseProcessor.h | 8 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/Clauses.cpp | 46 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/DataSharingProcessor.cpp | 5 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/OpenMP.cpp | 453 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/Utils.cpp | 191 | ||||
| -rw-r--r-- | flang/lib/Lower/OpenMP/Utils.h | 10 | ||||
| -rw-r--r-- | flang/lib/Lower/Runtime.cpp | 135 | ||||
| -rw-r--r-- | flang/lib/Lower/Support/ReductionProcessor.cpp | 131 | ||||
| -rw-r--r-- | flang/lib/Lower/Support/Utils.cpp | 5 |
20 files changed, 1880 insertions, 1079 deletions
diff --git a/flang/lib/Lower/Allocatable.cpp b/flang/lib/Lower/Allocatable.cpp index e7a6c4d..c9a9d93 100644 --- a/flang/lib/Lower/Allocatable.cpp +++ b/flang/lib/Lower/Allocatable.cpp @@ -798,10 +798,13 @@ private: // Keep return type the same as a standard AllocatableAllocate call. mlir::Type retTy = fir::runtime::getModel<int>()(builder.getContext()); + bool doubleDescriptors = Fortran::lower::hasDoubleDescriptor(box.getAddr()); return cuf::AllocateOp::create( builder, loc, retTy, box.getAddr(), errmsg, stream, pinned, source, cudaAttr, - errorManager.hasStatSpec() ? builder.getUnitAttr() : nullptr) + errorManager.hasStatSpec() ? builder.getUnitAttr() : nullptr, + doubleDescriptors ? builder.getUnitAttr() : nullptr, + box.isPointer() ? builder.getUnitAttr() : nullptr) .getResult(); } @@ -865,11 +868,14 @@ static mlir::Value genCudaDeallocate(fir::FirOpBuilder &builder, ? nullptr : errorManager.errMsgAddr; - // Keep return type the same as a standard AllocatableAllocate call. + // Keep return type the same as a standard AllocatableDeallocate call. mlir::Type retTy = fir::runtime::getModel<int>()(builder.getContext()); + bool doubleDescriptors = Fortran::lower::hasDoubleDescriptor(box.getAddr()); return cuf::DeallocateOp::create( builder, loc, retTy, box.getAddr(), errmsg, cudaAttr, - errorManager.hasStatSpec() ? builder.getUnitAttr() : nullptr) + errorManager.hasStatSpec() ? builder.getUnitAttr() : nullptr, + doubleDescriptors ? builder.getUnitAttr() : nullptr, + box.isPointer() ? builder.getUnitAttr() : nullptr) .getResult(); } diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 6e72987..d175e2a 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -15,7 +15,6 @@ #include "flang/Lower/Allocatable.h" #include "flang/Lower/CUDA.h" #include "flang/Lower/CallInterface.h" -#include "flang/Lower/Coarray.h" #include "flang/Lower/ConvertCall.h" #include "flang/Lower/ConvertExpr.h" #include "flang/Lower/ConvertExprToHLFIR.h" @@ -26,6 +25,7 @@ #include "flang/Lower/IO.h" #include "flang/Lower/IterationSpace.h" #include "flang/Lower/Mangler.h" +#include "flang/Lower/MultiImageFortran.h" #include "flang/Lower/OpenACC.h" #include "flang/Lower/OpenMP.h" #include "flang/Lower/PFTBuilder.h" @@ -86,10 +86,6 @@ #define DEBUG_TYPE "flang-lower-bridge" -static llvm::cl::opt<bool> dumpBeforeFir( - "fdebug-dump-pre-fir", llvm::cl::init(false), - llvm::cl::desc("dump the Pre-FIR tree prior to FIR generation")); - static llvm::cl::opt<bool> forceLoopToExecuteOnce( "always-execute-loop-body", llvm::cl::init(false), llvm::cl::desc("force the body of a loop to execute at least once")); @@ -311,7 +307,11 @@ private: if (!insertPointIfCreated.isSet()) return; // fir.type_info was already built in a previous call. - // Set init, destroy, and nofinal attributes. + // Set abstract, init, destroy, and nofinal attributes. + const Fortran::semantics::Symbol &dtSymbol = info.typeSpec.typeSymbol(); + if (dtSymbol.attrs().test(Fortran::semantics::Attr::ABSTRACT)) + dt->setAttr(dt.getAbstractAttrName(), builder.getUnitAttr()); + if (!info.typeSpec.HasDefaultInitialization(/*ignoreAllocatable=*/false, /*ignorePointer=*/false)) dt->setAttr(dt.getNoInitAttrName(), builder.getUnitAttr()); @@ -335,10 +335,14 @@ private: if (details.numPrivatesNotOverridden() > 0) tbpName += "."s + std::to_string(details.numPrivatesNotOverridden()); std::string bindingName = converter.mangleName(details.symbol()); - fir::DTEntryOp::create( + auto dtEntry = fir::DTEntryOp::create( builder, info.loc, mlir::StringAttr::get(builder.getContext(), tbpName), mlir::SymbolRefAttr::get(builder.getContext(), bindingName)); + // Propagate DEFERRED attribute on the binding to fir.dt_entry. + if (binding.get().attrs().test(Fortran::semantics::Attr::DEFERRED)) + dtEntry->setAttr(fir::DTEntryOp::getDeferredAttrNameStr(), + builder.getUnitAttr()); } fir::FirEndOp::create(builder, info.loc); } @@ -448,6 +452,13 @@ public: } }); + // Ensure imported OpenMP declare mappers are materialized at module + // scope before lowering any constructs that may reference them. + createBuilderOutsideOfFuncOpAndDo([&]() { + Fortran::lower::materializeOpenMPDeclareMappers( + *this, bridge.getSemanticsContext()); + }); + // Create definitions of intrinsic module constants. createBuilderOutsideOfFuncOpAndDo( [&]() { createIntrinsicModuleDefinitions(pft); }); @@ -1111,6 +1122,34 @@ public: return bridge.fctCtx(); } + /// Initializes values for STAT and ERRMSG + std::pair<mlir::Value, mlir::Value> + genStatAndErrmsg(mlir::Location loc, + const std::list<Fortran::parser::StatOrErrmsg> + &statOrErrList) override final { + Fortran::lower::StatementContext stmtCtx; + + mlir::Value errMsgExpr, statExpr; + for (const Fortran::parser::StatOrErrmsg &statOrErr : statOrErrList) { + std::visit(Fortran::common::visitors{ + [&](const Fortran::parser::StatVariable &statVar) { + const Fortran::semantics::SomeExpr *expr = + Fortran::semantics::GetExpr(statVar); + statExpr = + fir::getBase(genExprAddr(*expr, stmtCtx, &loc)); + }, + [&](const Fortran::parser::MsgVariable &errMsgVar) { + const Fortran::semantics::SomeExpr *expr = + Fortran::semantics::GetExpr(errMsgVar); + errMsgExpr = + fir::getBase(genExprBox(loc, *expr, stmtCtx)); + }}, + statOrErr.u); + } + + return {statExpr, errMsgExpr}; + } + mlir::Value hostAssocTupleValue() override final { return hostAssocTuple; } /// Record a binding for the ssa-value of the tuple for this function. @@ -1129,6 +1168,12 @@ public: return registeredDummySymbols.contains(sym); } + unsigned getDummyArgPosition( + const Fortran::semantics::Symbol &sym) const override final { + auto it = dummyArgPositions.find(&sym); + return (it != dummyArgPositions.end()) ? it->second : 0; + } + const Fortran::lower::pft::FunctionLikeUnit * getCurrentFunctionUnit() const override final { return currentFunctionUnit; @@ -1413,11 +1458,14 @@ private: /// definitive mapping. The specification expression have not been lowered /// yet. The final mapping will be done using this pre-mapping in /// Fortran::lower::mapSymbolAttributes. + /// \param argNo The 1-based source position of this argument (0 if + /// unknown/result) bool mapBlockArgToDummyOrResult(const Fortran::semantics::SymbolRef sym, - mlir::Value val, bool isResult) { + mlir::Value val, bool isResult, + unsigned argNo = 0) { localSymbols.addSymbol(sym, val); if (!isResult) - registerDummySymbol(sym); + registerDummySymbol(sym, argNo); return true; } @@ -2264,6 +2312,35 @@ private: } } + // Add AccessGroups attribute on operations in fir::DoLoopOp if this + // operation has the parallelAccesses attribute. + void attachAccessGroupAttrToDoLoopOperations(fir::DoLoopOp &doLoop) { + if (auto loopAnnotAttr = doLoop.getLoopAnnotationAttr()) { + if (loopAnnotAttr.getParallelAccesses().size()) { + llvm::SmallVector<mlir::Attribute> accessGroupAttrs( + loopAnnotAttr.getParallelAccesses().begin(), + loopAnnotAttr.getParallelAccesses().end()); + mlir::ArrayAttr attrs = + mlir::ArrayAttr::get(builder->getContext(), accessGroupAttrs); + doLoop.walk([&](mlir::Operation *op) { + if (fir::StoreOp storeOp = mlir::dyn_cast<fir::StoreOp>(op)) { + storeOp.setAccessGroupsAttr(attrs); + } else if (fir::LoadOp loadOp = mlir::dyn_cast<fir::LoadOp>(op)) { + loadOp.setAccessGroupsAttr(attrs); + } else if (hlfir::AssignOp assignOp = + mlir::dyn_cast<hlfir::AssignOp>(op)) { + // In some loops, the HLFIR AssignOp operation can be translated + // into FIR operation(s) containing StoreOp. It is therefore + // necessary to forward the AccessGroups attribute. + assignOp.getOperation()->setAttr("access_groups", attrs); + } else if (fir::CallOp callOp = mlir::dyn_cast<fir::CallOp>(op)) { + callOp.setAccessGroupsAttr(attrs); + } + }); + } + } + } + /// Generate FIR for a DO construct. There are six variants: /// - unstructured infinite and while loops /// - structured and unstructured increment loops @@ -2412,6 +2489,11 @@ private: // This call may generate a branch in some contexts. genFIR(endDoEval, unstructuredContext); + // Add AccessGroups attribute on operations in fir::DoLoopOp if necessary + for (IncrementLoopInfo &info : incrementLoopNestInfo) + if (auto loopOp = mlir::dyn_cast_if_present<fir::DoLoopOp>(info.loopOp)) + attachAccessGroupAttrToDoLoopOperations(loopOp); + if (!incrementLoopNestInfo.empty() && incrementLoopNestInfo.back().isConcurrent) localSymbols.popScope(); @@ -2500,22 +2582,61 @@ private: {}, {}, {}, {}); } + // Enabling loop vectorization attribute. + mlir::LLVM::LoopVectorizeAttr + genLoopVectorizeAttr(mlir::BoolAttr disableAttr, + mlir::BoolAttr scalableEnable, + mlir::IntegerAttr vectorWidth) { + mlir::LLVM::LoopVectorizeAttr va; + if (disableAttr) + va = mlir::LLVM::LoopVectorizeAttr::get( + builder->getContext(), + /*disable=*/disableAttr, /*predicate=*/{}, + /*scalableEnable=*/scalableEnable, + /*vectorWidth=*/vectorWidth, {}, {}, {}); + return va; + } + void addLoopAnnotationAttr( IncrementLoopInfo &info, llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs) { - mlir::LLVM::LoopVectorizeAttr va; + mlir::BoolAttr disableVecAttr; + mlir::BoolAttr scalableEnable; + mlir::IntegerAttr vectorWidth; mlir::LLVM::LoopUnrollAttr ua; mlir::LLVM::LoopUnrollAndJamAttr uja; + llvm::SmallVector<mlir::LLVM::AccessGroupAttr> aga; bool has_attrs = false; for (const auto *dir : dirs) { Fortran::common::visit( Fortran::common::visitors{ [&](const Fortran::parser::CompilerDirective::VectorAlways &) { - mlir::BoolAttr falseAttr = + disableVecAttr = mlir::BoolAttr::get(builder->getContext(), false); - va = mlir::LLVM::LoopVectorizeAttr::get(builder->getContext(), - /*disable=*/falseAttr, - {}, {}, {}, {}, {}, {}); + has_attrs = true; + }, + [&](const Fortran::parser::CompilerDirective::VectorLength &vl) { + using Kind = + Fortran::parser::CompilerDirective::VectorLength::Kind; + Kind kind = std::get<Kind>(vl.t); + uint64_t length = std::get<uint64_t>(vl.t); + disableVecAttr = + mlir::BoolAttr::get(builder->getContext(), false); + if (length != 0) + vectorWidth = + builder->getIntegerAttr(builder->getI64Type(), length); + switch (kind) { + case Kind::Scalable: + scalableEnable = + mlir::BoolAttr::get(builder->getContext(), true); + break; + case Kind::Fixed: + scalableEnable = + mlir::BoolAttr::get(builder->getContext(), false); + break; + case Kind::Auto: + break; + } has_attrs = true; }, [&](const Fortran::parser::CompilerDirective::Unroll &u) { @@ -2527,11 +2648,8 @@ private: has_attrs = true; }, [&](const Fortran::parser::CompilerDirective::NoVector &u) { - mlir::BoolAttr trueAttr = + disableVecAttr = mlir::BoolAttr::get(builder->getContext(), true); - va = mlir::LLVM::LoopVectorizeAttr::get(builder->getContext(), - /*disable=*/trueAttr, - {}, {}, {}, {}, {}, {}); has_attrs = true; }, [&](const Fortran::parser::CompilerDirective::NoUnroll &u) { @@ -2542,13 +2660,22 @@ private: uja = genLoopUnrollAndJamAttr(/*unrollingFactor=*/0); has_attrs = true; }, - + [&](const Fortran::parser::CompilerDirective::IVDep &iv) { + disableVecAttr = + mlir::BoolAttr::get(builder->getContext(), false); + aga.push_back( + mlir::LLVM::AccessGroupAttr::get(builder->getContext())); + has_attrs = true; + }, [&](const auto &) {}}, dir->u); } + mlir::LLVM::LoopVectorizeAttr va = + genLoopVectorizeAttr(disableVecAttr, scalableEnable, vectorWidth); mlir::LLVM::LoopAnnotationAttr la = mlir::LLVM::LoopAnnotationAttr::get( builder->getContext(), {}, /*vectorize=*/va, {}, /*unroll*/ ua, - /*unroll_and_jam*/ uja, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}); + /*unroll_and_jam*/ uja, {}, {}, {}, {}, {}, {}, {}, {}, {}, + /*parallelAccesses*/ aga); if (has_attrs) { if (auto loopOp = mlir::dyn_cast<fir::DoLoopOp>(info.loopOp)) loopOp.setLoopAnnotationAttr(la); @@ -3251,6 +3378,9 @@ private: [&](const Fortran::parser::CompilerDirective::VectorAlways &) { attachDirectiveToLoop(dir, &eval); }, + [&](const Fortran::parser::CompilerDirective::VectorLength &) { + attachDirectiveToLoop(dir, &eval); + }, [&](const Fortran::parser::CompilerDirective::Unroll &) { attachDirectiveToLoop(dir, &eval); }, @@ -3275,6 +3405,12 @@ private: [&](const Fortran::parser::CompilerDirective::NoInline &) { attachInliningDirectiveToStmt(dir, &eval); }, + [&](const Fortran::parser::CompilerDirective::Prefetch &prefetch) { + TODO(getCurrentLocation(), "!$dir prefetch"); + }, + [&](const Fortran::parser::CompilerDirective::IVDep &) { + attachDirectiveToLoop(dir, &eval); + }, [&](const auto &) {}}, dir.u); } @@ -3832,14 +3968,8 @@ private: if (!isCharSelector) return mlir::arith::CmpIOp::create(*builder, loc, pred, selector, rhs); - fir::factory::CharacterExprHelper charHelper{*builder, loc}; - std::pair<mlir::Value, mlir::Value> lhsVal = - charHelper.createUnboxChar(selector); - std::pair<mlir::Value, mlir::Value> rhsVal = - charHelper.createUnboxChar(rhs); - return fir::runtime::genCharCompare(*builder, loc, pred, lhsVal.first, - lhsVal.second, rhsVal.first, - rhsVal.second); + else + return hlfir::CmpCharOp::create(*builder, loc, pred, selector, rhs); }; mlir::Block *newBlock = insertBlock(*caseBlock); if (mlir::isa<fir::ClosedIntervalAttr>(attr)) { @@ -3950,13 +4080,30 @@ private: } void genFIR(const Fortran::parser::ChangeTeamConstruct &construct) { - TODO(toLocation(), "coarray: ChangeTeamConstruct"); + Fortran::lower::StatementContext stmtCtx; + pushActiveConstruct(getEval(), stmtCtx); + + for (Fortran::lower::pft::Evaluation &e : + getEval().getNestedEvaluations()) { + if (e.getIf<Fortran::parser::ChangeTeamStmt>()) { + maybeStartBlock(e.block); + setCurrentPosition(e.position); + genFIR(e); + } else if (e.getIf<Fortran::parser::EndChangeTeamStmt>()) { + maybeStartBlock(e.block); + setCurrentPosition(e.position); + genFIR(e); + } else { + genFIR(e); + } + } + popActiveConstruct(); } void genFIR(const Fortran::parser::ChangeTeamStmt &stmt) { - TODO(toLocation(), "coarray: ChangeTeamStmt"); + genChangeTeamStmt(*this, getEval(), stmt); } void genFIR(const Fortran::parser::EndChangeTeamStmt &stmt) { - TODO(toLocation(), "coarray: EndChangeTeamStmt"); + genEndChangeTeamStmt(*this, getEval(), stmt); } void genFIR(const Fortran::parser::CriticalConstruct &criticalConstruct) { @@ -4702,32 +4849,14 @@ private: // Generate pointer assignment with possibly empty bounds-spec. R1035: a // bounds-spec is a lower bound value. - void genPointerAssignment( + void genNoHLFIRPointerAssignment( mlir::Location loc, const Fortran::evaluate::Assignment &assign, const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) { Fortran::lower::StatementContext stmtCtx; - if (!lowerToHighLevelFIR() && - Fortran::evaluate::IsProcedureDesignator(assign.rhs)) + assert(!lowerToHighLevelFIR() && "code should not be called with HFLIR"); + if (Fortran::evaluate::IsProcedureDesignator(assign.rhs)) TODO(loc, "procedure pointer assignment"); - if (Fortran::evaluate::IsProcedurePointer(assign.lhs)) { - hlfir::Entity lhs = Fortran::lower::convertExprToHLFIR( - loc, *this, assign.lhs, localSymbols, stmtCtx); - if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>( - assign.rhs)) { - // rhs is null(). rhs being null(pptr) is handled in genNull. - auto boxTy{ - Fortran::lower::getUntypedBoxProcType(builder->getContext())}; - hlfir::Entity rhs( - fir::factory::createNullBoxProc(*builder, loc, boxTy)); - builder->createStoreWithConvert(loc, rhs, lhs); - return; - } - hlfir::Entity rhs(getBase(Fortran::lower::convertExprToAddress( - loc, *this, assign.rhs, localSymbols, stmtCtx))); - builder->createStoreWithConvert(loc, rhs, lhs); - return; - } std::optional<Fortran::evaluate::DynamicType> lhsType = assign.lhs.GetType(); @@ -4735,7 +4864,7 @@ private: // to the runtime. element size, type code, attribute and of // course base_addr might need to be updated. if (lhsType && lhsType->IsPolymorphic()) { - if (!lowerToHighLevelFIR() && explicitIterationSpace()) + if (explicitIterationSpace()) TODO(loc, "polymorphic pointer assignment in FORALL"); llvm::SmallVector<mlir::Value> lbounds; for (const Fortran::evaluate::ExtentExpr &lbExpr : lbExprs) @@ -4762,7 +4891,7 @@ private: llvm::SmallVector<mlir::Value> lbounds; for (const Fortran::evaluate::ExtentExpr &lbExpr : lbExprs) lbounds.push_back(fir::getBase(genExprValue(toEvExpr(lbExpr), stmtCtx))); - if (!lowerToHighLevelFIR() && explicitIterationSpace()) { + if (explicitIterationSpace()) { // Pointer assignment in FORALL context. Copy the rhs box value // into the lhs box variable. genArrayAssignment(assign, stmtCtx, lbounds); @@ -4773,6 +4902,21 @@ private: stmtCtx); } + void genPointerAssignment(mlir::Location loc, + const Fortran::evaluate::Assignment &assign) { + if (isInsideHlfirForallOrWhere()) { + // Generate Pointer assignment as hlfir.region_assign. + genForallPointerAssignment(loc, assign); + return; + } + Fortran::lower::StatementContext stmtCtx; + hlfir::Entity lhs = Fortran::lower::convertExprToHLFIR( + loc, *this, assign.lhs, localSymbols, stmtCtx); + mlir::Value rhs = genPointerAssignmentRhs(loc, lhs, assign, stmtCtx); + builder->createStoreWithConvert(loc, rhs, lhs); + cuf::genPointerSync(lhs, *builder); + } + void genForallPointerAssignment(mlir::Location loc, const Fortran::evaluate::Assignment &assign) { // Lower pointer assignment inside forall with hlfir.region_assign with @@ -4793,8 +4937,7 @@ private: // Lower RHS in its own region. builder->createBlock(®ionAssignOp.getRhsRegion()); Fortran::lower::StatementContext rhsContext; - mlir::Value rhs = - genForallPointerAssignmentRhs(loc, lhs, assign, rhsContext); + mlir::Value rhs = genPointerAssignmentRhs(loc, lhs, assign, rhsContext); auto rhsYieldOp = hlfir::YieldOp::create(*builder, loc, rhs); Fortran::lower::genCleanUpInRegionIfAny( loc, *builder, rhsYieldOp.getCleanup(), rhsContext); @@ -4810,9 +4953,9 @@ private: } mlir::Value - genForallPointerAssignmentRhs(mlir::Location loc, mlir::Value lhs, - const Fortran::evaluate::Assignment &assign, - Fortran::lower::StatementContext &rhsContext) { + genPointerAssignmentRhs(mlir::Location loc, hlfir::Entity lhs, + const Fortran::evaluate::Assignment &assign, + Fortran::lower::StatementContext &rhsContext) { if (Fortran::evaluate::IsProcedureDesignator(assign.lhs)) { if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>( assign.rhs)) @@ -4824,11 +4967,34 @@ private: // Data target. auto lhsBoxType = llvm::cast<fir::BaseBoxType>(fir::unwrapRefType(lhs.getType())); - // For NULL, create disassociated descriptor whose dynamic type is - // the static type of the LHS. + // For NULL, create disassociated descriptor whose dynamic type is the + // static type of the LHS (fulfills 7.3.2.3 requirements that the dynamic + // type of a deallocated polymorphic pointer is its static type). if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>( - assign.rhs)) - return fir::factory::createUnallocatedBox(*builder, loc, lhsBoxType, {}); + assign.rhs)) { + llvm::SmallVector<mlir::Value, 1> nonDeferredLenParams; + if (auto lhsVar = + llvm::dyn_cast_if_present<fir::FortranVariableOpInterface>( + lhs.getDefiningOp())) + nonDeferredLenParams = lhsVar.getExplicitTypeParams(); + if (isInsideHlfirForallOrWhere()) { + // Inside FORALL, the non deferred type parameters may only be + // accessible in the hlfir.region_assign lhs region if they were + // computed there. + for (mlir::Value ¶m : nonDeferredLenParams) + if (!param.getParentRegion()->isAncestor( + builder->getBlock()->getParent())) { + if (llvm::isa_and_nonnull<mlir::arith::ConstantOp>( + param.getDefiningOp())) + param = builder->clone(*param.getDefiningOp())->getResult(0); + else + TODO(loc, "Pointer assignment with non deferred type parameter " + "inside FORALL"); + } + } + return fir::factory::createUnallocatedBox(*builder, loc, lhsBoxType, + nonDeferredLenParams); + } hlfir::Entity rhs = Fortran::lower::convertExprToHLFIR( loc, *this, assign.rhs, localSymbols, rhsContext); auto rhsBoxType = rhs.getBoxType(); @@ -4876,6 +5042,10 @@ private: mlir::Value shape = builder->genShape(loc, lbounds, extents); rhsBox = fir::ReboxOp::create(*builder, loc, lhsBoxType, rhsBox, shape, /*slice=*/mlir::Value{}); + } else if (fir::isClassStarType(lhsBoxType) && + !fir::ConvertOp::canBeConverted(rhsBoxType, lhsBoxType)) { + rhsBox = fir::ReboxOp::create(*builder, loc, lhsBoxType, rhsBox, + mlir::Value{}, mlir::Value{}); } return rhsBox; } @@ -4917,9 +5087,10 @@ private: // Pointer assignment with bounds-remapping. R1036: a bounds-remapping is a // pair, lower bound and upper bound. - void genPointerAssignment( + void genNoHLFIRPointerAssignment( mlir::Location loc, const Fortran::evaluate::Assignment &assign, const Fortran::evaluate::Assignment::BoundsRemapping &boundExprs) { + assert(!lowerToHighLevelFIR() && "code should not be called with HFLIR"); Fortran::lower::StatementContext stmtCtx; llvm::SmallVector<mlir::Value> lbounds; llvm::SmallVector<mlir::Value> ubounds; @@ -4938,7 +5109,7 @@ private: // Polymorphic lhs/rhs need more care. See F2018 10.2.2.3. if ((lhsType && lhsType->IsPolymorphic()) || (rhsType && rhsType->IsPolymorphic())) { - if (!lowerToHighLevelFIR() && explicitIterationSpace()) + if (explicitIterationSpace()) TODO(loc, "polymorphic pointer assignment in FORALL"); fir::MutableBoxValue lhsMutableBox = genExprMutableBox(loc, assign.lhs); @@ -4956,7 +5127,7 @@ private: rhsType->IsPolymorphic()); return; } - if (!lowerToHighLevelFIR() && explicitIterationSpace()) { + if (explicitIterationSpace()) { // Pointer assignment in FORALL context. Copy the rhs box value // into the lhs box variable. genArrayAssignment(assign, stmtCtx, lbounds, ubounds); @@ -4968,13 +5139,6 @@ private: fir::factory::disassociateMutableBox(*builder, loc, lhs); return; } - if (lowerToHighLevelFIR()) { - fir::ExtendedValue rhs = genExprAddr(assign.rhs, stmtCtx); - fir::factory::associateMutableBoxWithRemap(*builder, loc, lhs, rhs, - lbounds, ubounds); - return; - } - // Legacy lowering below. // Do not generate a temp in case rhs is an array section. fir::ExtendedValue rhs = Fortran::lower::isArraySectionWithoutVectorSubscript(assign.rhs) @@ -5364,18 +5528,10 @@ private: dirs); }, [&](const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) { - if (isInsideHlfirForallOrWhere()) - genForallPointerAssignment(loc, assign); - else - genPointerAssignment(loc, assign, lbExprs); + genPointerAssignment(loc, assign); }, [&](const Fortran::evaluate::Assignment::BoundsRemapping - &boundExprs) { - if (isInsideHlfirForallOrWhere()) - genForallPointerAssignment(loc, assign); - else - genPointerAssignment(loc, assign, boundExprs); - }, + &boundExprs) { genPointerAssignment(loc, assign); }, }, assign.u); return; @@ -5577,11 +5733,11 @@ private: }, [&](const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) { - return genPointerAssignment(loc, assign, lbExprs); + return genNoHLFIRPointerAssignment(loc, assign, lbExprs); }, [&](const Fortran::evaluate::Assignment::BoundsRemapping &boundExprs) { - return genPointerAssignment(loc, assign, boundExprs); + return genNoHLFIRPointerAssignment(loc, assign, boundExprs); }, }, assign.u); @@ -5955,7 +6111,16 @@ private: const Fortran::lower::CalleeInterface &callee) { assert(builder && "require a builder object at this point"); using PassBy = Fortran::lower::CalleeInterface::PassEntityBy; + + // Track the source-level argument position (1-based) + unsigned argPosition = 0; + auto mapPassedEntity = [&](const auto arg, bool isResult = false) { + // Count only actual source-level dummy arguments (not results or + // host assoc tuples) + if (!isResult && arg.entity.has_value()) + argPosition++; + if (arg.passBy == PassBy::AddressAndLength) { if (callee.characterize().IsBindC()) return; @@ -5966,11 +6131,12 @@ private: mlir::Value casted = builder->createVolatileCast(loc, false, arg.firArgument); mlir::Value box = charHelp.createEmboxChar(casted, arg.firLength); - mapBlockArgToDummyOrResult(arg.entity->get(), box, isResult); + mapBlockArgToDummyOrResult(arg.entity->get(), box, isResult, + isResult ? 0 : argPosition); } else { if (arg.entity.has_value()) { mapBlockArgToDummyOrResult(arg.entity->get(), arg.firArgument, - isResult); + isResult, isResult ? 0 : argPosition); } else { assert(funit.parentHasTupleHostAssoc() && "expect tuple argument"); } @@ -6828,13 +6994,22 @@ private: } /// Record the given symbol as a dummy argument of this function. - void registerDummySymbol(Fortran::semantics::SymbolRef symRef) { + /// \param symRef The symbol representing the dummy argument + /// \param argNo The 1-based position of this argument in the source (0 = + /// unknown) + void registerDummySymbol(Fortran::semantics::SymbolRef symRef, + unsigned argNo = 0) { auto *sym = &*symRef; registeredDummySymbols.insert(sym); + if (argNo > 0) + dummyArgPositions[sym] = argNo; } /// Reset all registered dummy symbols. - void resetRegisteredDummySymbols() { registeredDummySymbols.clear(); } + void resetRegisteredDummySymbols() { + registeredDummySymbols.clear(); + dummyArgPositions.clear(); + } void setCurrentFunctionUnit(Fortran::lower::pft::FunctionLikeUnit *unit) { currentFunctionUnit = unit; @@ -6876,6 +7051,11 @@ private: llvm::SmallPtrSet<const Fortran::semantics::Symbol *, 16> registeredDummySymbols; + /// Map from dummy symbols to their 1-based argument positions. + /// Used to generate debug info with correct argument numbers. + llvm::DenseMap<const Fortran::semantics::Symbol *, unsigned> + dummyArgPositions; + /// A map of unique names for constant expressions. /// The names are used for representing the constant expressions /// with global constant initialized objects. @@ -6935,8 +7115,6 @@ void Fortran::lower::LoweringBridge::lower( const Fortran::semantics::SemanticsContext &semanticsContext) { std::unique_ptr<Fortran::lower::pft::Program> pft = Fortran::lower::createPFT(prg, semanticsContext); - if (dumpBeforeFir) - Fortran::lower::dumpPFT(llvm::errs(), *pft); FirConverter converter{*this}; converter.run(*pft); } diff --git a/flang/lib/Lower/CMakeLists.txt b/flang/lib/Lower/CMakeLists.txt index 3d0b4e4..230a56a 100644 --- a/flang/lib/Lower/CMakeLists.txt +++ b/flang/lib/Lower/CMakeLists.txt @@ -5,7 +5,6 @@ add_flang_library(FortranLower Allocatable.cpp Bridge.cpp CallInterface.cpp - Coarray.cpp ComponentPath.cpp ConvertArrayConstructor.cpp ConvertCall.cpp @@ -23,6 +22,7 @@ add_flang_library(FortranLower IterationSpace.cpp LoweringOptions.cpp Mangler.cpp + MultiImageFortran.cpp OpenACC.cpp OpenMP/Atomic.cpp OpenMP/ClauseProcessor.cpp diff --git a/flang/lib/Lower/CUDA.cpp b/flang/lib/Lower/CUDA.cpp index 9501b0e..fb05528 100644 --- a/flang/lib/Lower/CUDA.cpp +++ b/flang/lib/Lower/CUDA.cpp @@ -91,3 +91,17 @@ hlfir::ElementalOp Fortran::lower::isTransferWithConversion(mlir::Value rhs) { return elOp; return {}; } + +bool Fortran::lower::hasDoubleDescriptor(mlir::Value addr) { + if (auto declareOp = + mlir::dyn_cast_or_null<hlfir::DeclareOp>(addr.getDefiningOp())) { + if (mlir::isa_and_nonnull<fir::AddrOfOp>( + declareOp.getMemref().getDefiningOp())) { + if (declareOp.getDataAttr() && + *declareOp.getDataAttr() == cuf::DataAttribute::Pinned) + return false; + return true; + } + } + return false; +} diff --git a/flang/lib/Lower/Coarray.cpp b/flang/lib/Lower/Coarray.cpp deleted file mode 100644 index a84f65a..0000000 --- a/flang/lib/Lower/Coarray.cpp +++ /dev/null @@ -1,66 +0,0 @@ -//===-- Coarray.cpp -------------------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// Implementation of the lowering of image related constructs and expressions. -/// Fortran images can form teams, communicate via coarrays, etc. -/// -//===----------------------------------------------------------------------===// - -#include "flang/Lower/Coarray.h" -#include "flang/Lower/AbstractConverter.h" -#include "flang/Lower/SymbolMap.h" -#include "flang/Optimizer/Builder/FIRBuilder.h" -#include "flang/Optimizer/Builder/Todo.h" -#include "flang/Parser/parse-tree.h" -#include "flang/Semantics/expression.h" - -//===----------------------------------------------------------------------===// -// TEAM statements and constructs -//===----------------------------------------------------------------------===// - -void Fortran::lower::genChangeTeamConstruct( - Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &, - const Fortran::parser::ChangeTeamConstruct &) { - TODO(converter.getCurrentLocation(), "coarray: CHANGE TEAM construct"); -} - -void Fortran::lower::genChangeTeamStmt( - Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &, - const Fortran::parser::ChangeTeamStmt &) { - TODO(converter.getCurrentLocation(), "coarray: CHANGE TEAM statement"); -} - -void Fortran::lower::genEndChangeTeamStmt( - Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &, - const Fortran::parser::EndChangeTeamStmt &) { - TODO(converter.getCurrentLocation(), "coarray: END CHANGE TEAM statement"); -} - -void Fortran::lower::genFormTeamStatement( - Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &, const Fortran::parser::FormTeamStmt &) { - TODO(converter.getCurrentLocation(), "coarray: FORM TEAM statement"); -} - -//===----------------------------------------------------------------------===// -// COARRAY expressions -//===----------------------------------------------------------------------===// - -fir::ExtendedValue Fortran::lower::CoarrayExprHelper::genAddr( - const Fortran::evaluate::CoarrayRef &expr) { - (void)symMap; - TODO(converter.getCurrentLocation(), "co-array address"); -} - -fir::ExtendedValue Fortran::lower::CoarrayExprHelper::genValue( - const Fortran::evaluate::CoarrayRef &expr) { - TODO(converter.getCurrentLocation(), "co-array value"); -} diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp index 9bf994e..cd5218e 100644 --- a/flang/lib/Lower/ConvertCall.cpp +++ b/flang/lib/Lower/ConvertCall.cpp @@ -713,7 +713,8 @@ Fortran::lower::genCallOpAndResult( builder.getContext(), fir::FortranInlineEnum::always_inline); auto call = fir::CallOp::create( builder, loc, funcType.getResults(), funcSymbolAttr, operands, - /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs, inlineAttr); + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs, inlineAttr, + /*accessGroups=*/mlir::ArrayAttr{}); callNumResults = call.getNumResults(); if (callNumResults != 0) @@ -1296,10 +1297,14 @@ static PreparedDummyArgument preparePresentUserCallActualArgument( Fortran::evaluate::FoldingContext &foldingContext{ callContext.converter.getFoldingContext()}; - bool suggestCopyIn = Fortran::evaluate::MayNeedCopy( - arg.entity, arg.characteristics, foldingContext, /*forCopyOut=*/false); - bool suggestCopyOut = Fortran::evaluate::MayNeedCopy( - arg.entity, arg.characteristics, foldingContext, /*forCopyOut=*/true); + bool suggestCopyIn = Fortran::evaluate::ActualArgNeedsCopy( + arg.entity, arg.characteristics, foldingContext, + /*forCopyOut=*/false) + .value_or(true); + bool suggestCopyOut = Fortran::evaluate::ActualArgNeedsCopy( + arg.entity, arg.characteristics, foldingContext, + /*forCopyOut=*/true) + .value_or(true); mustDoCopyIn = actual.isArray() && suggestCopyIn; mustDoCopyOut = actual.isArray() && suggestCopyOut; } diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp index a46d219..b2910a0 100644 --- a/flang/lib/Lower/ConvertExpr.cpp +++ b/flang/lib/Lower/ConvertExpr.cpp @@ -19,7 +19,6 @@ #include "flang/Lower/Bridge.h" #include "flang/Lower/BuiltinModules.h" #include "flang/Lower/CallInterface.h" -#include "flang/Lower/Coarray.h" #include "flang/Lower/ComponentPath.h" #include "flang/Lower/ConvertCall.h" #include "flang/Lower/ConvertConstant.h" @@ -28,6 +27,7 @@ #include "flang/Lower/ConvertVariable.h" #include "flang/Lower/CustomIntrinsicCall.h" #include "flang/Lower/Mangler.h" +#include "flang/Lower/MultiImageFortran.h" #include "flang/Lower/Runtime.h" #include "flang/Lower/Support/Utils.h" #include "flang/Optimizer/Builder/Character.h" diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp index 2517ab3..53d4d75 100644 --- a/flang/lib/Lower/ConvertVariable.cpp +++ b/flang/lib/Lower/ConvertVariable.cpp @@ -1946,12 +1946,15 @@ static void genDeclareSymbol(Fortran::lower::AbstractConverter &converter, return; } mlir::Value dummyScope; - if (converter.isRegisteredDummySymbol(sym)) + unsigned argNo = 0; + if (converter.isRegisteredDummySymbol(sym)) { dummyScope = converter.dummyArgsScopeValue(); + argNo = converter.getDummyArgPosition(sym); + } auto [storage, storageOffset] = converter.getSymbolStorage(sym); auto newBase = hlfir::DeclareOp::create( builder, loc, base, name, shapeOrShift, lenParams, dummyScope, storage, - storageOffset, attributes, dataAttr); + storageOffset, attributes, dataAttr, argNo); symMap.addVariableDefinition(sym, newBase, force); return; } @@ -2004,15 +2007,17 @@ void Fortran::lower::genDeclareSymbol( sym.GetUltimate()); auto name = converter.mangleName(sym); mlir::Value dummyScope; + unsigned argNo = 0; fir::ExtendedValue base = exv; if (converter.isRegisteredDummySymbol(sym)) { base = genPackArray(converter, sym, exv); dummyScope = converter.dummyArgsScopeValue(); + argNo = converter.getDummyArgPosition(sym); } auto [storage, storageOffset] = converter.getSymbolStorage(sym); hlfir::EntityWithAttributes declare = hlfir::genDeclare(loc, builder, base, name, attributes, dummyScope, - storage, storageOffset, dataAttr); + storage, storageOffset, dataAttr, argNo); symMap.addVariableDefinition(sym, declare.getIfVariableInterface(), force); return; } diff --git a/flang/lib/Lower/MultiImageFortran.cpp b/flang/lib/Lower/MultiImageFortran.cpp new file mode 100644 index 0000000..745ca249 --- /dev/null +++ b/flang/lib/Lower/MultiImageFortran.cpp @@ -0,0 +1,278 @@ +//===-- MultiImageFortran.cpp ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// Implementation of the lowering of image related constructs and expressions. +/// Fortran images can form teams, communicate via coarrays, etc. +/// +//===----------------------------------------------------------------------===// + +#include "flang/Lower/MultiImageFortran.h" +#include "flang/Lower/AbstractConverter.h" +#include "flang/Lower/SymbolMap.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Builder/Todo.h" +#include "flang/Optimizer/Dialect/MIF/MIFOps.h" +#include "flang/Parser/parse-tree.h" +#include "flang/Semantics/expression.h" + +//===----------------------------------------------------------------------===// +// Synchronization statements +//===----------------------------------------------------------------------===// + +void Fortran::lower::genSyncAllStatement( + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::SyncAllStmt &stmt) { + mlir::Location loc = converter.getCurrentLocation(); + converter.checkCoarrayEnabled(); + + // Handle STAT and ERRMSG values + const std::list<Fortran::parser::StatOrErrmsg> &statOrErrList = stmt.v; + auto [statAddr, errMsgAddr] = converter.genStatAndErrmsg(loc, statOrErrList); + + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + mif::SyncAllOp::create(builder, loc, statAddr, errMsgAddr); +} + +void Fortran::lower::genSyncImagesStatement( + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::SyncImagesStmt &stmt) { + mlir::Location loc = converter.getCurrentLocation(); + converter.checkCoarrayEnabled(); + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + + // Handle STAT and ERRMSG values + const std::list<Fortran::parser::StatOrErrmsg> &statOrErrList = + std::get<std::list<Fortran::parser::StatOrErrmsg>>(stmt.t); + auto [statAddr, errMsgAddr] = converter.genStatAndErrmsg(loc, statOrErrList); + + // SYNC_IMAGES(*) is passed as count == -1 while SYNC IMAGES([]) has count + // == 0. Note further that SYNC IMAGES(*) is not semantically equivalent to + // SYNC ALL. + Fortran::lower::StatementContext stmtCtx; + mlir::Value imageSet; + const Fortran::parser::SyncImagesStmt::ImageSet &imgSet = + std::get<Fortran::parser::SyncImagesStmt::ImageSet>(stmt.t); + std::visit(Fortran::common::visitors{ + [&](const Fortran::parser::IntExpr &intExpr) { + const SomeExpr *expr = Fortran::semantics::GetExpr(intExpr); + imageSet = + fir::getBase(converter.genExprBox(loc, *expr, stmtCtx)); + }, + [&](const Fortran::parser::Star &) { + // Image set is not set. + imageSet = mlir::Value{}; + }}, + imgSet.u); + + mif::SyncImagesOp::create(builder, loc, imageSet, statAddr, errMsgAddr); +} + +void Fortran::lower::genSyncMemoryStatement( + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::SyncMemoryStmt &stmt) { + mlir::Location loc = converter.getCurrentLocation(); + converter.checkCoarrayEnabled(); + + // Handle STAT and ERRMSG values + const std::list<Fortran::parser::StatOrErrmsg> &statOrErrList = stmt.v; + auto [statAddr, errMsgAddr] = converter.genStatAndErrmsg(loc, statOrErrList); + + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + mif::SyncMemoryOp::create(builder, loc, statAddr, errMsgAddr); +} + +void Fortran::lower::genSyncTeamStatement( + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::SyncTeamStmt &stmt) { + mlir::Location loc = converter.getCurrentLocation(); + converter.checkCoarrayEnabled(); + + // Handle TEAM + Fortran::lower::StatementContext stmtCtx; + const Fortran::parser::TeamValue &teamValue = + std::get<Fortran::parser::TeamValue>(stmt.t); + const SomeExpr *teamExpr = Fortran::semantics::GetExpr(teamValue); + mlir::Value team = + fir::getBase(converter.genExprBox(loc, *teamExpr, stmtCtx)); + + // Handle STAT and ERRMSG values + const std::list<Fortran::parser::StatOrErrmsg> &statOrErrList = + std::get<std::list<Fortran::parser::StatOrErrmsg>>(stmt.t); + auto [statAddr, errMsgAddr] = converter.genStatAndErrmsg(loc, statOrErrList); + + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + mif::SyncTeamOp::create(builder, loc, team, statAddr, errMsgAddr); +} + +//===----------------------------------------------------------------------===// +// TEAM statements and constructs +//===----------------------------------------------------------------------===// + +void Fortran::lower::genChangeTeamConstruct( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &, + const Fortran::parser::ChangeTeamConstruct &) { + TODO(converter.getCurrentLocation(), "coarray: CHANGE TEAM construct"); +} + +void Fortran::lower::genChangeTeamStmt( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &, + const Fortran::parser::ChangeTeamStmt &stmt) { + mlir::Location loc = converter.getCurrentLocation(); + converter.checkCoarrayEnabled(); + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + + mlir::Value errMsgAddr, statAddr, team; + // Handle STAT and ERRMSG values + Fortran::lower::StatementContext stmtCtx; + const std::list<Fortran::parser::StatOrErrmsg> &statOrErrList = + std::get<std::list<Fortran::parser::StatOrErrmsg>>(stmt.t); + for (const Fortran::parser::StatOrErrmsg &statOrErr : statOrErrList) { + std::visit(Fortran::common::visitors{ + [&](const Fortran::parser::StatVariable &statVar) { + const auto *expr = Fortran::semantics::GetExpr(statVar); + statAddr = fir::getBase( + converter.genExprAddr(loc, *expr, stmtCtx)); + }, + [&](const Fortran::parser::MsgVariable &errMsgVar) { + const auto *expr = Fortran::semantics::GetExpr(errMsgVar); + errMsgAddr = fir::getBase( + converter.genExprBox(loc, *expr, stmtCtx)); + }, + }, + statOrErr.u); + } + + // TODO: Manage the list of coarrays associated in + // `std::list<CoarrayAssociation>`. According to the PRIF specification, it is + // necessary to call `prif_alias_{create|destroy}` for each coarray defined in + // this list. Support will be added once lowering to this procedure is + // possible. + const std::list<Fortran::parser::CoarrayAssociation> &coarrayAssocList = + std::get<std::list<Fortran::parser::CoarrayAssociation>>(stmt.t); + if (coarrayAssocList.size()) + TODO(loc, "Coarrays provided in the association list."); + + // Handle TEAM-VALUE + const auto *teamExpr = + Fortran::semantics::GetExpr(std::get<Fortran::parser::TeamValue>(stmt.t)); + team = fir::getBase(converter.genExprBox(loc, *teamExpr, stmtCtx)); + + mif::ChangeTeamOp changeOp = mif::ChangeTeamOp::create( + builder, loc, team, statAddr, errMsgAddr, /*terminator*/ false); + builder.setInsertionPointToStart(changeOp.getBody()); +} + +void Fortran::lower::genEndChangeTeamStmt( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &, + const Fortran::parser::EndChangeTeamStmt &stmt) { + converter.checkCoarrayEnabled(); + mlir::Location loc = converter.getCurrentLocation(); + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + + mlir::Value errMsgAddr, statAddr; + // Handle STAT and ERRMSG values + Fortran::lower::StatementContext stmtCtx; + const std::list<Fortran::parser::StatOrErrmsg> &statOrErrList = + std::get<std::list<Fortran::parser::StatOrErrmsg>>(stmt.t); + for (const Fortran::parser::StatOrErrmsg &statOrErr : statOrErrList) { + std::visit(Fortran::common::visitors{ + [&](const Fortran::parser::StatVariable &statVar) { + const auto *expr = Fortran::semantics::GetExpr(statVar); + statAddr = fir::getBase( + converter.genExprAddr(loc, *expr, stmtCtx)); + }, + [&](const Fortran::parser::MsgVariable &errMsgVar) { + const auto *expr = Fortran::semantics::GetExpr(errMsgVar); + errMsgAddr = fir::getBase( + converter.genExprBox(loc, *expr, stmtCtx)); + }, + }, + statOrErr.u); + } + + mif::EndTeamOp endOp = + mif::EndTeamOp::create(builder, loc, statAddr, errMsgAddr); + builder.setInsertionPointAfter(endOp.getParentOp()); +} + +void Fortran::lower::genFormTeamStatement( + Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &, + const Fortran::parser::FormTeamStmt &stmt) { + converter.checkCoarrayEnabled(); + mlir::Location loc = converter.getCurrentLocation(); + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + + mlir::Value errMsgAddr, statAddr, newIndex, teamNumber, team; + // Handle NEW_INDEX, STAT and ERRMSG + std::list<Fortran::parser::StatOrErrmsg> statOrErrList{}; + Fortran::lower::StatementContext stmtCtx; + const auto &formSpecList = + std::get<std::list<Fortran::parser::FormTeamStmt::FormTeamSpec>>(stmt.t); + for (const Fortran::parser::FormTeamStmt::FormTeamSpec &formSpec : + formSpecList) { + std::visit( + Fortran::common::visitors{ + [&](const Fortran::parser::StatOrErrmsg &statOrErr) { + std::visit( + Fortran::common::visitors{ + [&](const Fortran::parser::StatVariable &statVar) { + const auto *expr = Fortran::semantics::GetExpr(statVar); + statAddr = fir::getBase( + converter.genExprAddr(loc, *expr, stmtCtx)); + }, + [&](const Fortran::parser::MsgVariable &errMsgVar) { + const auto *expr = + Fortran::semantics::GetExpr(errMsgVar); + errMsgAddr = fir::getBase( + converter.genExprBox(loc, *expr, stmtCtx)); + }, + }, + statOrErr.u); + }, + [&](const Fortran::parser::ScalarIntExpr &intExpr) { + fir::ExtendedValue newIndexExpr = converter.genExprValue( + loc, Fortran::semantics::GetExpr(intExpr), stmtCtx); + newIndex = fir::getBase(newIndexExpr); + }, + }, + formSpec.u); + } + + // Handle TEAM-NUMBER + const auto *teamNumberExpr = Fortran::semantics::GetExpr( + std::get<Fortran::parser::ScalarIntExpr>(stmt.t)); + teamNumber = + fir::getBase(converter.genExprValue(loc, *teamNumberExpr, stmtCtx)); + + // Handle TEAM-VARIABLE + const auto *teamExpr = Fortran::semantics::GetExpr( + std::get<Fortran::parser::TeamVariable>(stmt.t)); + team = fir::getBase(converter.genExprBox(loc, *teamExpr, stmtCtx)); + + mif::FormTeamOp::create(builder, loc, teamNumber, team, newIndex, statAddr, + errMsgAddr); +} + +//===----------------------------------------------------------------------===// +// COARRAY expressions +//===----------------------------------------------------------------------===// + +fir::ExtendedValue Fortran::lower::CoarrayExprHelper::genAddr( + const Fortran::evaluate::CoarrayRef &expr) { + (void)symMap; + TODO(converter.getCurrentLocation(), "co-array address"); +} + +fir::ExtendedValue Fortran::lower::CoarrayExprHelper::genValue( + const Fortran::evaluate::CoarrayRef &expr) { + TODO(converter.getCurrentLocation(), "co-array value"); +} diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index d7861ac..50b08ce 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -28,6 +28,7 @@ #include "flang/Optimizer/Builder/IntrinsicCall.h" #include "flang/Optimizer/Builder/Todo.h" #include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/OpenACC/Support/FIROpenACCUtils.h" #include "flang/Parser/parse-tree-visitor.h" #include "flang/Parser/parse-tree.h" #include "flang/Parser/tools.h" @@ -1159,18 +1160,6 @@ bool isConstantBound(mlir::acc::DataBoundsOp &op) { return false; } -/// Return true iff all the bounds are expressed with constant values. -bool areAllBoundConstant(const llvm::SmallVector<mlir::Value> &bounds) { - for (auto bound : bounds) { - auto dataBound = - mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp()); - assert(dataBound && "Must be DataBoundOp operation"); - if (!isConstantBound(dataBound)) - return false; - } - return true; -} - static llvm::SmallVector<mlir::Value> genConstantBounds(fir::FirOpBuilder &builder, mlir::Location loc, mlir::acc::DataBoundsOp &dataBound) { @@ -1196,59 +1185,6 @@ genConstantBounds(fir::FirOpBuilder &builder, mlir::Location loc, return {lb, ub, step}; } -static mlir::Value genShapeFromBoundsOrArgs( - mlir::Location loc, fir::FirOpBuilder &builder, fir::SequenceType seqTy, - const llvm::SmallVector<mlir::Value> &bounds, mlir::ValueRange arguments) { - llvm::SmallVector<mlir::Value> args; - if (bounds.empty() && seqTy) { - if (seqTy.hasDynamicExtents()) { - assert(!arguments.empty() && "arguments must hold the entity"); - auto entity = hlfir::Entity{arguments[0]}; - return hlfir::genShape(loc, builder, entity); - } - return genShapeOp(builder, seqTy, loc).getResult(); - } else if (areAllBoundConstant(bounds)) { - for (auto bound : llvm::reverse(bounds)) { - auto dataBound = - mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp()); - args.append(genConstantBounds(builder, loc, dataBound)); - } - } else { - assert(((arguments.size() - 2) / 3 == seqTy.getDimension()) && - "Expect 3 block arguments per dimension"); - for (auto arg : arguments.drop_front(2)) - args.push_back(arg); - } - - assert(args.size() % 3 == 0 && "Triplets must be a multiple of 3"); - llvm::SmallVector<mlir::Value> extents; - mlir::Type idxTy = builder.getIndexType(); - mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); - mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0); - for (unsigned i = 0; i < args.size(); i += 3) { - mlir::Value s1 = - mlir::arith::SubIOp::create(builder, loc, args[i + 1], args[0]); - mlir::Value s2 = mlir::arith::AddIOp::create(builder, loc, s1, one); - mlir::Value s3 = - mlir::arith::DivSIOp::create(builder, loc, s2, args[i + 2]); - mlir::Value cmp = mlir::arith::CmpIOp::create( - builder, loc, mlir::arith::CmpIPredicate::sgt, s3, zero); - mlir::Value ext = - mlir::arith::SelectOp::create(builder, loc, cmp, s3, zero); - extents.push_back(ext); - } - return fir::ShapeOp::create(builder, loc, extents); -} - -static hlfir::DesignateOp::Subscripts -getSubscriptsFromArgs(mlir::ValueRange args) { - hlfir::DesignateOp::Subscripts triplets; - for (unsigned i = 2; i < args.size(); i += 3) - triplets.emplace_back( - hlfir::DesignateOp::Triplet{args[i], args[i + 1], args[i + 2]}); - return triplets; -} - static hlfir::Entity genDesignateWithTriplets( fir::FirOpBuilder &builder, mlir::Location loc, hlfir::Entity &entity, hlfir::DesignateOp::Subscripts &triplets, mlir::Value shape) { @@ -1262,19 +1198,88 @@ static hlfir::Entity genDesignateWithTriplets( return hlfir::Entity{designate.getResult()}; } -mlir::acc::FirstprivateRecipeOp Fortran::lower::createOrGetFirstprivateRecipe( - fir::FirOpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc, - mlir::Type ty, llvm::SmallVector<mlir::Value> &bounds) { - mlir::ModuleOp mod = - builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>(); - if (auto recipe = - mod.lookupSymbol<mlir::acc::FirstprivateRecipeOp>(recipeName)) - return recipe; +// Designate uses triplets based on object lower bounds while acc.bounds are +// zero based. This helper shift the bounds to create the designate triplets. +static hlfir::DesignateOp::Subscripts +genTripletsFromAccBounds(fir::FirOpBuilder &builder, mlir::Location loc, + const llvm::SmallVector<mlir::Value> &accBounds, + hlfir::Entity entity) { + assert(entity.getRank() * 3 == static_cast<int>(accBounds.size()) && + "must get lb,ub,step for each dimension"); + hlfir::DesignateOp::Subscripts triplets; + for (unsigned i = 0; i < accBounds.size(); i += 3) { + mlir::Value lb = hlfir::genLBound(loc, builder, entity, i / 3); + lb = builder.createConvert(loc, accBounds[i].getType(), lb); + assert(accBounds[i].getType() == accBounds[i + 1].getType() && + "mix of integer types in triplets"); + mlir::Value sliceLB = + builder.createOrFold<mlir::arith::AddIOp>(loc, accBounds[i], lb); + mlir::Value sliceUB = + builder.createOrFold<mlir::arith::AddIOp>(loc, accBounds[i + 1], lb); + triplets.emplace_back( + hlfir::DesignateOp::Triplet{sliceLB, sliceUB, accBounds[i + 2]}); + } + return triplets; +} - auto ip = builder.saveInsertionPoint(); - auto recipe = genRecipeOp<mlir::acc::FirstprivateRecipeOp>( - builder, mod, recipeName, loc, ty); - bool allConstantBound = areAllBoundConstant(bounds); +static std::pair<hlfir::Entity, hlfir::Entity> +genArraySectionsInRecipe(fir::FirOpBuilder &builder, mlir::Location loc, + llvm::SmallVector<mlir::Value> &dataOperationBounds, + mlir::ValueRange recipeArguments, + bool allConstantBound, hlfir::Entity lhs, + hlfir::Entity rhs) { + lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs); + rhs = hlfir::derefPointersAndAllocatables(loc, builder, rhs); + // Get the list of lb,ub,step values for the sections that can be used inside + // the recipe region. + llvm::SmallVector<mlir::Value> bounds; + if (allConstantBound) { + // For constant bounds, the bounds are not region arguments. Materialize + // constants looking at the IR for the bounds on the data operation. + for (auto bound : dataOperationBounds) { + auto dataBound = + mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp()); + bounds.append(genConstantBounds(builder, loc, dataBound)); + } + } else { + // If one bound is not constant, all of the bounds are region arguments. + for (auto arg : recipeArguments.drop_front(2)) + bounds.push_back(arg); + } + // Compute the fir.shape of the array section and the triplets to create + // hlfir.designate. + assert(lhs.getRank() * 3 == static_cast<int>(bounds.size()) && + "must get lb,ub,step for each dimension"); + llvm::SmallVector<mlir::Value> extents; + mlir::Type idxTy = builder.getIndexType(); + for (unsigned i = 0; i < bounds.size(); i += 3) + extents.push_back(builder.genExtentFromTriplet( + loc, bounds[i], bounds[i + 1], bounds[i + 2], idxTy)); + mlir::Value shape = fir::ShapeOp::create(builder, loc, extents); + hlfir::DesignateOp::Subscripts rhsTriplets = + genTripletsFromAccBounds(builder, loc, bounds, rhs); + hlfir::DesignateOp::Subscripts lhsTriplets; + // Share the bounds when both rhs/lhs are known to be 1-based to avoid noise + // in the IR for the most common cases. + if (!lhs.mayHaveNonDefaultLowerBounds() && + !rhs.mayHaveNonDefaultLowerBounds()) + lhsTriplets = rhsTriplets; + else + lhsTriplets = genTripletsFromAccBounds(builder, loc, bounds, lhs); + hlfir::Entity leftSection = + genDesignateWithTriplets(builder, loc, lhs, lhsTriplets, shape); + hlfir::Entity rightSection = + genDesignateWithTriplets(builder, loc, rhs, rhsTriplets, shape); + return {leftSection, rightSection}; +} + +// Generate the combiner or copy region block and block arguments and return the +// source and destination entities. +static std::pair<hlfir::Entity, hlfir::Entity> +genRecipeCombinerOrCopyRegion(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type ty, mlir::Region ®ion, + llvm::SmallVector<mlir::Value> &bounds, + bool allConstantBound) { llvm::SmallVector<mlir::Type> argsTy{ty, ty}; llvm::SmallVector<mlir::Location> argsLoc{loc, loc}; if (!allConstantBound) { @@ -1289,100 +1294,57 @@ mlir::acc::FirstprivateRecipeOp Fortran::lower::createOrGetFirstprivateRecipe( argsLoc.push_back(dataBound.getStartIdx().getLoc()); } } - builder.createBlock(&recipe.getCopyRegion(), recipe.getCopyRegion().end(), - argsTy, argsLoc); + mlir::Block *block = + builder.createBlock(®ion, region.end(), argsTy, argsLoc); + builder.setInsertionPointToEnd(®ion.back()); + return {hlfir::Entity{block->getArgument(0)}, + hlfir::Entity{block->getArgument(1)}}; +} - builder.setInsertionPointToEnd(&recipe.getCopyRegion().back()); - ty = fir::unwrapRefType(ty); - if (fir::isa_trivial(ty)) { - mlir::Value initValue = fir::LoadOp::create( - builder, loc, recipe.getCopyRegion().front().getArgument(0)); - fir::StoreOp::create(builder, loc, initValue, - recipe.getCopyRegion().front().getArgument(1)); - } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(ty)) { - fir::FirOpBuilder firBuilder{builder, recipe.getOperation()}; - auto shape = genShapeFromBoundsOrArgs( - loc, firBuilder, seqTy, bounds, recipe.getCopyRegion().getArguments()); - - auto leftDeclOp = hlfir::DeclareOp::create( - builder, loc, recipe.getCopyRegion().getArgument(0), llvm::StringRef{}, - shape); - auto rightDeclOp = hlfir::DeclareOp::create( - builder, loc, recipe.getCopyRegion().getArgument(1), llvm::StringRef{}, - shape); - - hlfir::DesignateOp::Subscripts triplets = - getSubscriptsFromArgs(recipe.getCopyRegion().getArguments()); - auto leftEntity = hlfir::Entity{leftDeclOp.getBase()}; - auto left = - genDesignateWithTriplets(firBuilder, loc, leftEntity, triplets, shape); - auto rightEntity = hlfir::Entity{rightDeclOp.getBase()}; - auto right = - genDesignateWithTriplets(firBuilder, loc, rightEntity, triplets, shape); - - hlfir::AssignOp::create(firBuilder, loc, left, right); - - } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) { - fir::FirOpBuilder firBuilder{builder, recipe.getOperation()}; - llvm::SmallVector<mlir::Value> tripletArgs; - mlir::Type innerTy = fir::extractSequenceType(boxTy); - fir::SequenceType seqTy = - mlir::dyn_cast_or_null<fir::SequenceType>(innerTy); - if (!seqTy) - TODO(loc, "Unsupported boxed type in OpenACC firstprivate"); - - auto shape = genShapeFromBoundsOrArgs( - loc, firBuilder, seqTy, bounds, recipe.getCopyRegion().getArguments()); - hlfir::DesignateOp::Subscripts triplets = - getSubscriptsFromArgs(recipe.getCopyRegion().getArguments()); - auto leftEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(0)}; - auto left = - genDesignateWithTriplets(firBuilder, loc, leftEntity, triplets, shape); - auto rightEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(1)}; - auto right = - genDesignateWithTriplets(firBuilder, loc, rightEntity, triplets, shape); - hlfir::AssignOp::create(firBuilder, loc, left, right); - } else { - // Copy scalar derived type. - // The temporary_lhs flag allows indicating that user defined assignments - // should not be called while copying components, and that the LHS and RHS - // are known to not alias since the LHS is a created object. - hlfir::AssignOp::create( - builder, loc, recipe.getCopyRegion().getArgument(0), - recipe.getCopyRegion().getArgument(1), /*realloc=*/false, - /*keep_lhs_length_if_realloc=*/false, /*temporary_lhs=*/true); - } +mlir::acc::FirstprivateRecipeOp Fortran::lower::createOrGetFirstprivateRecipe( + fir::FirOpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc, + mlir::Type ty, llvm::SmallVector<mlir::Value> &bounds) { + mlir::ModuleOp mod = + builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>(); + if (auto recipe = + mod.lookupSymbol<mlir::acc::FirstprivateRecipeOp>(recipeName)) + return recipe; - mlir::acc::TerminatorOp::create(builder, loc); - builder.restoreInsertionPoint(ip); - return recipe; -} + mlir::OpBuilder::InsertionGuard guard(builder); + auto recipe = genRecipeOp<mlir::acc::FirstprivateRecipeOp>( + builder, mod, recipeName, loc, ty); + bool allConstantBound = fir::acc::areAllBoundsConstant(bounds); + auto [source, destination] = genRecipeCombinerOrCopyRegion( + builder, loc, ty, recipe.getCopyRegion(), bounds, allConstantBound); + + fir::FirOpBuilder firBuilder{builder, recipe.getOperation()}; + + source = hlfir::derefPointersAndAllocatables(loc, builder, source); + destination = hlfir::derefPointersAndAllocatables(loc, builder, destination); -/// Get a string representation of the bounds. -std::string getBoundsString(llvm::SmallVector<mlir::Value> &bounds) { - std::stringstream boundStr; if (!bounds.empty()) - boundStr << "_section_"; - llvm::interleave( - bounds, - [&](mlir::Value bound) { - auto boundsOp = - mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp()); - if (boundsOp.getLowerbound() && - fir::getIntIfConstant(boundsOp.getLowerbound()) && - boundsOp.getUpperbound() && - fir::getIntIfConstant(boundsOp.getUpperbound())) { - boundStr << "lb" << *fir::getIntIfConstant(boundsOp.getLowerbound()) - << ".ub" << *fir::getIntIfConstant(boundsOp.getUpperbound()); - } else if (boundsOp.getExtent() && - fir::getIntIfConstant(boundsOp.getExtent())) { - boundStr << "ext" << *fir::getIntIfConstant(boundsOp.getExtent()); - } else { - boundStr << "?"; - } - }, - [&] { boundStr << "x"; }); - return boundStr.str(); + std::tie(source, destination) = genArraySectionsInRecipe( + firBuilder, loc, bounds, recipe.getCopyRegion().getArguments(), + allConstantBound, source, destination); + // The source and the destination of the firstprivate copy cannot alias, + // the destination is already properly allocated, so a simple assignment + // can be generated right away to avoid ending-up with runtime calls + // for arrays of numerical, logical and, character types. + // + // The temporary_lhs flag allows indicating that user defined assignments + // should not be called while copying components, and that the LHS and RHS + // are known to not alias since the LHS is a created object. + // + // TODO: detect cases where user defined assignment is needed and add a TODO. + // using temporary_lhs allows more aggressive optimizations of simple derived + // types. Existing compilers supporting OpenACC do not call user defined + // assignments, some use case is needed to decide what to do. + source = hlfir::loadTrivialScalar(loc, builder, source); + hlfir::AssignOp::create(builder, loc, source, destination, /*realloc=*/false, + /*keep_lhs_length_if_realloc=*/false, + /*temporary_lhs=*/true); + mlir::acc::TerminatorOp::create(builder, loc); + return recipe; } /// Rebuild the array type from the acc.bounds operation with constant @@ -1427,7 +1389,6 @@ static void genPrivatizationRecipes( Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::StatementContext &stmtCtx, llvm::SmallVectorImpl<mlir::Value> &dataOperands, - llvm::SmallVector<mlir::Attribute> &privatizationRecipes, llvm::ArrayRef<mlir::Value> async, llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes, llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes, @@ -1458,15 +1419,16 @@ static void genPrivatizationRecipes( RecipeOp recipe; mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType()); if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) { - std::string recipeName = - fir::getTypeAsString(retTy, converter.getKindMap(), - Fortran::lower::privatizationRecipePrefix); + std::string recipeName = fir::acc::getRecipeName( + mlir::acc::RecipeKind::private_recipe, retTy, info.addr, bounds); recipe = Fortran::lower::createOrGetPrivateRecipe(builder, recipeName, operandLocation, retTy); auto op = createDataEntryOp<mlir::acc::PrivateOp>( builder, operandLocation, info.addr, asFortran, bounds, true, /*implicit=*/false, mlir::acc::DataClause::acc_private, retTy, async, asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true); + op.setRecipeAttr( + mlir::SymbolRefAttr::get(builder.getContext(), recipe.getSymName())); dataOperands.push_back(op.getAccVar()); // Track the symbol and its corresponding mlir::Value if requested @@ -1474,10 +1436,8 @@ static void genPrivatizationRecipes( symbolPairs->emplace_back(op.getAccVar(), Fortran::semantics::SymbolRef(symbol)); } else { - std::string suffix = - areAllBoundConstant(bounds) ? getBoundsString(bounds) : ""; - std::string recipeName = fir::getTypeAsString( - retTy, converter.getKindMap(), "firstprivatization" + suffix); + std::string recipeName = fir::acc::getRecipeName( + mlir::acc::RecipeKind::firstprivate_recipe, retTy, info.addr, bounds); recipe = Fortran::lower::createOrGetFirstprivateRecipe( builder, recipeName, operandLocation, retTy, bounds); auto op = createDataEntryOp<mlir::acc::FirstprivateOp>( @@ -1485,6 +1445,8 @@ static void genPrivatizationRecipes( /*implicit=*/false, mlir::acc::DataClause::acc_firstprivate, retTy, async, asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true); + op.setRecipeAttr( + mlir::SymbolRefAttr::get(builder.getContext(), recipe.getSymName())); dataOperands.push_back(op.getAccVar()); // Track the symbol and its corresponding mlir::Value if requested @@ -1492,8 +1454,6 @@ static void genPrivatizationRecipes( symbolPairs->emplace_back(op.getAccVar(), Fortran::semantics::SymbolRef(symbol)); } - privatizationRecipes.push_back(mlir::SymbolRefAttr::get( - builder.getContext(), recipe.getSymName().str())); } } @@ -1611,205 +1571,6 @@ static mlir::Value genScalarCombiner(fir::FirOpBuilder &builder, TODO(loc, "reduction operator"); } -static hlfir::DesignateOp::Subscripts -getTripletsFromArgs(mlir::acc::ReductionRecipeOp recipe) { - hlfir::DesignateOp::Subscripts triplets; - for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size(); - i += 3) - triplets.emplace_back(hlfir::DesignateOp::Triplet{ - recipe.getCombinerRegion().getArgument(i), - recipe.getCombinerRegion().getArgument(i + 1), - recipe.getCombinerRegion().getArgument(i + 2)}); - return triplets; -} - -static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc, - mlir::acc::ReductionOperator op, mlir::Type ty, - mlir::Value value1, mlir::Value value2, - mlir::acc::ReductionRecipeOp &recipe, - llvm::SmallVector<mlir::Value> &bounds, - bool allConstantBound) { - ty = fir::unwrapRefType(ty); - - if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) { - mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy()); - llvm::SmallVector<fir::DoLoopOp> loops; - llvm::SmallVector<mlir::Value> ivs; - if (seqTy.hasDynamicExtents()) { - auto shape = - genShapeFromBoundsOrArgs(loc, builder, seqTy, bounds, - recipe.getCombinerRegion().getArguments()); - auto v1DeclareOp = hlfir::DeclareOp::create(builder, loc, value1, - llvm::StringRef{}, shape); - auto v2DeclareOp = hlfir::DeclareOp::create(builder, loc, value2, - llvm::StringRef{}, shape); - hlfir::DesignateOp::Subscripts triplets = getTripletsFromArgs(recipe); - - llvm::SmallVector<mlir::Value> lenParamsLeft; - auto leftEntity = hlfir::Entity{v1DeclareOp.getBase()}; - hlfir::genLengthParameters(loc, builder, leftEntity, lenParamsLeft); - auto leftDesignate = hlfir::DesignateOp::create( - builder, loc, v1DeclareOp.getBase().getType(), v1DeclareOp.getBase(), - /*component=*/"", - /*componentShape=*/mlir::Value{}, triplets, - /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt, - shape, lenParamsLeft); - auto left = hlfir::Entity{leftDesignate.getResult()}; - - llvm::SmallVector<mlir::Value> lenParamsRight; - auto rightEntity = hlfir::Entity{v2DeclareOp.getBase()}; - hlfir::genLengthParameters(loc, builder, rightEntity, lenParamsLeft); - auto rightDesignate = hlfir::DesignateOp::create( - builder, loc, v2DeclareOp.getBase().getType(), v2DeclareOp.getBase(), - /*component=*/"", - /*componentShape=*/mlir::Value{}, triplets, - /*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt, - shape, lenParamsRight); - auto right = hlfir::Entity{rightDesignate.getResult()}; - - llvm::SmallVector<mlir::Value, 1> typeParams; - auto genKernel = [&builder, &loc, op, seqTy, &left, &right]( - mlir::Location l, fir::FirOpBuilder &b, - mlir::ValueRange oneBasedIndices) -> hlfir::Entity { - auto leftElement = hlfir::getElementAt(l, b, left, oneBasedIndices); - auto rightElement = hlfir::getElementAt(l, b, right, oneBasedIndices); - auto leftVal = hlfir::loadTrivialScalar(l, b, leftElement); - auto rightVal = hlfir::loadTrivialScalar(l, b, rightElement); - return hlfir::Entity{genScalarCombiner( - builder, loc, op, seqTy.getEleTy(), leftVal, rightVal)}; - }; - mlir::Value elemental = hlfir::genElementalOp( - loc, builder, seqTy.getEleTy(), shape, typeParams, genKernel, - /*isUnordered=*/true); - hlfir::AssignOp::create(builder, loc, elemental, v1DeclareOp.getBase()); - return; - } - if (bounds.empty()) { - llvm::SmallVector<mlir::Value> extents; - mlir::Type idxTy = builder.getIndexType(); - for (auto extent : llvm::reverse(seqTy.getShape())) { - mlir::Value lb = mlir::arith::ConstantOp::create( - builder, loc, idxTy, builder.getIntegerAttr(idxTy, 0)); - mlir::Value ub = mlir::arith::ConstantOp::create( - builder, loc, idxTy, builder.getIntegerAttr(idxTy, extent - 1)); - mlir::Value step = mlir::arith::ConstantOp::create( - builder, loc, idxTy, builder.getIntegerAttr(idxTy, 1)); - auto loop = fir::DoLoopOp::create(builder, loc, lb, ub, step, - /*unordered=*/false); - builder.setInsertionPointToStart(loop.getBody()); - loops.push_back(loop); - ivs.push_back(loop.getInductionVar()); - } - } else if (allConstantBound) { - // Use the constant bound directly in the combiner region so they do not - // need to be passed as block argument. - assert(!bounds.empty() && - "seq type with constant bounds cannot have empty bounds"); - for (auto bound : llvm::reverse(bounds)) { - auto dataBound = - mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp()); - llvm::SmallVector<mlir::Value> values = - genConstantBounds(builder, loc, dataBound); - auto loop = - fir::DoLoopOp::create(builder, loc, values[0], values[1], values[2], - /*unordered=*/false); - builder.setInsertionPointToStart(loop.getBody()); - loops.push_back(loop); - ivs.push_back(loop.getInductionVar()); - } - } else { - // Lowerbound, upperbound and step are passed as block arguments. - unsigned nbRangeArgs = - recipe.getCombinerRegion().getArguments().size() - 2; - assert((nbRangeArgs / 3 == seqTy.getDimension()) && - "Expect 3 block arguments per dimension"); - for (int i = nbRangeArgs - 1; i >= 2; i -= 3) { - mlir::Value lb = recipe.getCombinerRegion().getArgument(i); - mlir::Value ub = recipe.getCombinerRegion().getArgument(i + 1); - mlir::Value step = recipe.getCombinerRegion().getArgument(i + 2); - auto loop = fir::DoLoopOp::create(builder, loc, lb, ub, step, - /*unordered=*/false); - builder.setInsertionPointToStart(loop.getBody()); - loops.push_back(loop); - ivs.push_back(loop.getInductionVar()); - } - } - llvm::SmallVector<mlir::Value> reversedIvs(ivs.rbegin(), ivs.rend()); - auto addr1 = - fir::CoordinateOp::create(builder, loc, refTy, value1, reversedIvs); - auto addr2 = - fir::CoordinateOp::create(builder, loc, refTy, value2, reversedIvs); - auto load1 = fir::LoadOp::create(builder, loc, addr1); - auto load2 = fir::LoadOp::create(builder, loc, addr2); - mlir::Value res = - genScalarCombiner(builder, loc, op, seqTy.getEleTy(), load1, load2); - fir::StoreOp::create(builder, loc, res, addr1); - builder.setInsertionPointAfter(loops[0]); - } else if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { - mlir::Type innerTy = fir::unwrapRefType(boxTy.getEleTy()); - if (fir::isa_trivial(innerTy)) { - mlir::Value boxAddr1 = value1, boxAddr2 = value2; - if (fir::isBoxAddress(boxAddr1.getType())) - boxAddr1 = fir::LoadOp::create(builder, loc, boxAddr1); - if (fir::isBoxAddress(boxAddr2.getType())) - boxAddr2 = fir::LoadOp::create(builder, loc, boxAddr2); - boxAddr1 = fir::BoxAddrOp::create(builder, loc, boxAddr1); - boxAddr2 = fir::BoxAddrOp::create(builder, loc, boxAddr2); - auto leftEntity = hlfir::Entity{boxAddr1}; - auto rightEntity = hlfir::Entity{boxAddr2}; - - auto leftVal = hlfir::loadTrivialScalar(loc, builder, leftEntity); - auto rightVal = hlfir::loadTrivialScalar(loc, builder, rightEntity); - mlir::Value res = - genScalarCombiner(builder, loc, op, innerTy, leftVal, rightVal); - hlfir::AssignOp::create(builder, loc, res, boxAddr1); - } else { - mlir::Type innerTy = fir::extractSequenceType(boxTy); - fir::SequenceType seqTy = - mlir::dyn_cast_or_null<fir::SequenceType>(innerTy); - if (!seqTy) - TODO(loc, "Unsupported boxed type in OpenACC reduction combiner"); - - auto shape = - genShapeFromBoundsOrArgs(loc, builder, seqTy, bounds, - recipe.getCombinerRegion().getArguments()); - hlfir::DesignateOp::Subscripts triplets = - getSubscriptsFromArgs(recipe.getCombinerRegion().getArguments()); - auto leftEntity = hlfir::Entity{value1}; - if (fir::isBoxAddress(value1.getType())) - leftEntity = hlfir::Entity{ - fir::LoadOp::create(builder, loc, value1).getResult()}; - auto left = - genDesignateWithTriplets(builder, loc, leftEntity, triplets, shape); - auto rightEntity = hlfir::Entity{value2}; - if (fir::isBoxAddress(value2.getType())) - rightEntity = hlfir::Entity{ - fir::LoadOp::create(builder, loc, value2).getResult()}; - auto right = - genDesignateWithTriplets(builder, loc, rightEntity, triplets, shape); - - llvm::SmallVector<mlir::Value, 1> typeParams; - auto genKernel = [&builder, &loc, op, seqTy, &left, &right]( - mlir::Location l, fir::FirOpBuilder &b, - mlir::ValueRange oneBasedIndices) -> hlfir::Entity { - auto leftElement = hlfir::getElementAt(l, b, left, oneBasedIndices); - auto rightElement = hlfir::getElementAt(l, b, right, oneBasedIndices); - auto leftVal = hlfir::loadTrivialScalar(l, b, leftElement); - auto rightVal = hlfir::loadTrivialScalar(l, b, rightElement); - return hlfir::Entity{genScalarCombiner( - builder, loc, op, seqTy.getEleTy(), leftVal, rightVal)}; - }; - mlir::Value elemental = hlfir::genElementalOp( - loc, builder, seqTy.getEleTy(), shape, typeParams, genKernel, - /*isUnordered=*/true); - hlfir::AssignOp::create(builder, loc, elemental, value1); - } - } else { - mlir::Value res = genScalarCombiner(builder, loc, op, ty, value1, value2); - fir::StoreOp::create(builder, loc, res, value1); - } -} - mlir::acc::ReductionRecipeOp Fortran::lower::createOrGetReductionRecipe( fir::FirOpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc, mlir::Type ty, mlir::acc::ReductionOperator op, @@ -1819,37 +1580,33 @@ mlir::acc::ReductionRecipeOp Fortran::lower::createOrGetReductionRecipe( if (auto recipe = mod.lookupSymbol<mlir::acc::ReductionRecipeOp>(recipeName)) return recipe; - auto ip = builder.saveInsertionPoint(); - + mlir::OpBuilder::InsertionGuard guard(builder); auto recipe = genRecipeOp<mlir::acc::ReductionRecipeOp>( builder, mod, recipeName, loc, ty, op); - - // The two first block arguments are the two values to be combined. - // The next arguments are the iteration ranges (lb, ub, step) to be used - // for the combiner if needed. - llvm::SmallVector<mlir::Type> argsTy{ty, ty}; - llvm::SmallVector<mlir::Location> argsLoc{loc, loc}; - bool allConstantBound = areAllBoundConstant(bounds); - if (!allConstantBound) { - for (mlir::Value bound : llvm::reverse(bounds)) { - auto dataBound = - mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp()); - argsTy.push_back(dataBound.getLowerbound().getType()); - argsLoc.push_back(dataBound.getLowerbound().getLoc()); - argsTy.push_back(dataBound.getUpperbound().getType()); - argsLoc.push_back(dataBound.getUpperbound().getLoc()); - argsTy.push_back(dataBound.getStartIdx().getType()); - argsLoc.push_back(dataBound.getStartIdx().getLoc()); - } - } - builder.createBlock(&recipe.getCombinerRegion(), - recipe.getCombinerRegion().end(), argsTy, argsLoc); - builder.setInsertionPointToEnd(&recipe.getCombinerRegion().back()); - mlir::Value v1 = recipe.getCombinerRegion().front().getArgument(0); - mlir::Value v2 = recipe.getCombinerRegion().front().getArgument(1); - genCombiner(builder, loc, op, ty, v1, v2, recipe, bounds, allConstantBound); - mlir::acc::YieldOp::create(builder, loc, v1); - builder.restoreInsertionPoint(ip); + bool allConstantBound = fir::acc::areAllBoundsConstant(bounds); + + auto [dest, src] = genRecipeCombinerOrCopyRegion( + builder, loc, ty, recipe.getCombinerRegion(), bounds, allConstantBound); + // Generate loops that combine and assign the inputs into dest (or array + // section of the inputs when there are bounds). + hlfir::Entity srcSection = src; + hlfir::Entity destSection = dest; + if (!bounds.empty()) + std::tie(srcSection, destSection) = genArraySectionsInRecipe( + builder, loc, bounds, recipe.getCombinerRegion().getArguments(), + allConstantBound, srcSection, destSection); + + mlir::Type elementType = fir::getFortranElementType(ty); + auto genKernel = [&](mlir::Location l, fir::FirOpBuilder &b, + hlfir::Entity srcElementValue, + hlfir::Entity destElementValue) -> hlfir::Entity { + return hlfir::Entity{genScalarCombiner(builder, loc, op, elementType, + srcElementValue, destElementValue)}; + }; + hlfir::genNoAliasAssignment(loc, builder, srcSection, destSection, + /*emitWorkshareLoop=*/false, + /*temporaryLHS=*/false, genKernel); + mlir::acc::YieldOp::create(builder, loc, dest); return recipe; } @@ -1866,16 +1623,17 @@ static bool isSupportedReductionType(mlir::Type ty) { return fir::isa_trivial(ty); } -static void -genReductions(const Fortran::parser::AccObjectListWithReduction &objectList, - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semanticsContext, - Fortran::lower::StatementContext &stmtCtx, - llvm::SmallVectorImpl<mlir::Value> &reductionOperands, - llvm::SmallVector<mlir::Attribute> &reductionRecipes, - llvm::ArrayRef<mlir::Value> async, - llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes, - llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes) { +static void genReductions( + const Fortran::parser::AccObjectListWithReduction &objectList, + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::lower::StatementContext &stmtCtx, + llvm::SmallVectorImpl<mlir::Value> &reductionOperands, + llvm::ArrayRef<mlir::Value> async, + llvm::ArrayRef<mlir::Attribute> asyncDeviceTypes, + llvm::ArrayRef<mlir::Attribute> asyncOnlyDeviceTypes, + llvm::SmallVectorImpl<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> + *symbolPairs = nullptr) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); const auto &objects = std::get<Fortran::parser::AccObjectList>(objectList.t); const auto &op = std::get<Fortran::parser::ReductionOperator>(objectList.t); @@ -1888,6 +1646,8 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList, Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject); Fortran::semantics::MaybeExpr designator = Fortran::common::visit( [&](auto &&s) { return ea.Analyze(s); }, accObject.u); + bool isWholeSymbol = + !designator || Fortran::evaluate::UnwrapWholeSymbolDataRef(*designator); fir::factory::AddrAndBoundsInfo info = Fortran::lower::gatherDataOperandAddrAndBounds< mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>( @@ -1911,22 +1671,24 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList, mlir::acc::DataClause::acc_reduction, info.addr.getType(), async, asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true); mlir::Type ty = op.getAccVar().getType(); - if (!areAllBoundConstant(bounds) || + if (!fir::acc::areAllBoundsConstant(bounds) || fir::isAssumedShape(info.addr.getType()) || fir::isAllocatableOrPointerArray(info.addr.getType())) ty = info.addr.getType(); - std::string suffix = - areAllBoundConstant(bounds) ? getBoundsString(bounds) : ""; - std::string recipeName = fir::getTypeAsString( - ty, converter.getKindMap(), - ("reduction_" + stringifyReductionOperator(mlirOp)).str() + suffix); + std::string recipeName = fir::acc::getRecipeName( + mlir::acc::RecipeKind::reduction_recipe, ty, info.addr, bounds, mlirOp); mlir::acc::ReductionRecipeOp recipe = Fortran::lower::createOrGetReductionRecipe( builder, recipeName, operandLocation, ty, mlirOp, bounds); - reductionRecipes.push_back(mlir::SymbolRefAttr::get( - builder.getContext(), recipe.getSymName().str())); + op.setRecipeAttr( + mlir::SymbolRefAttr::get(builder.getContext(), recipe.getSymName())); reductionOperands.push_back(op.getAccVar()); + // Track the symbol and its corresponding mlir::Value if requested so that + // accesses inside the compute/loop regions use the acc.reduction variable. + if (symbolPairs && isWholeSymbol) + symbolPairs->emplace_back(op.getAccVar(), + Fortran::semantics::SymbolRef(symbol)); } } @@ -2138,7 +1900,6 @@ static void privatizeIv( llvm::SmallVector<mlir::Value> &privateOperands, llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> &ivPrivate, - llvm::SmallVector<mlir::Attribute> &privatizationRecipes, bool isDoConcurrent = false) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); @@ -2164,9 +1925,8 @@ static void privatizeIv( } if (privateOp == nullptr) { - std::string recipeName = - fir::getTypeAsString(ivValue.getType(), converter.getKindMap(), - Fortran::lower::privatizationRecipePrefix); + std::string recipeName = fir::acc::getRecipeName( + mlir::acc::RecipeKind::private_recipe, ivValue.getType(), ivValue, {}); auto recipe = Fortran::lower::createOrGetPrivateRecipe( builder, recipeName, loc, ivValue.getType()); @@ -2176,11 +1936,11 @@ static void privatizeIv( builder, loc, ivValue, asFortran, {}, true, /*implicit=*/true, mlir::acc::DataClause::acc_private, ivValue.getType(), /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); + op.setRecipeAttr( + mlir::SymbolRefAttr::get(builder.getContext(), recipe.getSymName())); privateOp = op.getOperation(); privateOperands.push_back(op.getAccVar()); - privatizationRecipes.push_back(mlir::SymbolRefAttr::get( - builder.getContext(), recipe.getSymName().str())); } ivPrivate.emplace_back(mlir::acc::getAccVar(privateOp), @@ -2251,6 +2011,49 @@ static void determineDefaultLoopParMode( } } +// Helper to visit Bounds of DO LOOP nest. +static void visitLoopControl( + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::DoConstruct &outerDoConstruct, + uint64_t loopsToProcess, Fortran::lower::pft::Evaluation &eval, + std::function<void(const Fortran::parser::LoopControl::Bounds &, + mlir::Location)> + callback) { + Fortran::lower::pft::Evaluation *crtEval = &eval.getFirstNestedEvaluation(); + for (uint64_t i = 0; i < loopsToProcess; ++i) { + const Fortran::parser::LoopControl *loopControl; + if (i == 0) { + loopControl = &*outerDoConstruct.GetLoopControl(); + mlir::Location loc = converter.genLocation( + Fortran::parser::FindSourceLocation(outerDoConstruct)); + callback(std::get<Fortran::parser::LoopControl::Bounds>(loopControl->u), + loc); + } else { + // Safely locate the next inner DoConstruct within this eval. + const Fortran::parser::DoConstruct *innerDo = nullptr; + if (crtEval && crtEval->hasNestedEvaluations()) { + for (Fortran::lower::pft::Evaluation &child : + crtEval->getNestedEvaluations()) { + if (auto *stmt = child.getIf<Fortran::parser::DoConstruct>()) { + innerDo = stmt; + // Prepare to descend for the next iteration + crtEval = &child; + break; + } + } + } + if (!innerDo) + break; // No deeper loop; stop collecting collapsed bounds. + + loopControl = &*innerDo->GetLoopControl(); + mlir::Location loc = + converter.genLocation(Fortran::parser::FindSourceLocation(*innerDo)); + callback(std::get<Fortran::parser::LoopControl::Bounds>(loopControl->u), + loc); + } + } +} + // Extract loop bounds, steps, induction variables, and privatization info // for both DO CONCURRENT and regular do loops static void processDoLoopBounds( @@ -2265,14 +2068,12 @@ static void processDoLoopBounds( llvm::SmallVector<mlir::Value> &privateOperands, llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> &ivPrivate, - llvm::SmallVector<mlir::Attribute> &privatizationRecipes, llvm::SmallVector<mlir::Type> &ivTypes, llvm::SmallVector<mlir::Location> &ivLocs, llvm::SmallVector<bool> &inclusiveBounds, llvm::SmallVector<mlir::Location> &locs, uint64_t loopsToProcess) { assert(loopsToProcess > 0 && "expect at least one loop"); locs.push_back(currentLocation); // Location of the directive - Fortran::lower::pft::Evaluation *crtEval = &eval.getFirstNestedEvaluation(); bool isDoConcurrent = outerDoConstruct.IsDoConcurrent(); if (isDoConcurrent) { @@ -2307,63 +2108,34 @@ static void processDoLoopBounds( const auto &name = std::get<Fortran::parser::Name>(control.t); privatizeIv(converter, *name.symbol, currentLocation, ivTypes, ivLocs, - privateOperands, ivPrivate, privatizationRecipes, - isDoConcurrent); + privateOperands, ivPrivate, isDoConcurrent); inclusiveBounds.push_back(true); } } else { - for (uint64_t i = 0; i < loopsToProcess; ++i) { - const Fortran::parser::LoopControl *loopControl; - if (i == 0) { - loopControl = &*outerDoConstruct.GetLoopControl(); - locs.push_back(converter.genLocation( - Fortran::parser::FindSourceLocation(outerDoConstruct))); - } else { - // Safely locate the next inner DoConstruct within this eval. - const Fortran::parser::DoConstruct *innerDo = nullptr; - if (crtEval && crtEval->hasNestedEvaluations()) { - for (Fortran::lower::pft::Evaluation &child : - crtEval->getNestedEvaluations()) { - if (auto *stmt = child.getIf<Fortran::parser::DoConstruct>()) { - innerDo = stmt; - // Prepare to descend for the next iteration - crtEval = &child; - break; - } - } - } - if (!innerDo) - break; // No deeper loop; stop collecting collapsed bounds. - - loopControl = &*innerDo->GetLoopControl(); - locs.push_back(converter.genLocation( - Fortran::parser::FindSourceLocation(*innerDo))); - } - - const Fortran::parser::LoopControl::Bounds *bounds = - std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u); - assert(bounds && "Expected bounds on the loop construct"); - lowerbounds.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))); - upperbounds.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))); - if (bounds->step) - steps.push_back(fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(bounds->step), stmtCtx))); - else // If `step` is not present, assume it is `1`. - steps.push_back(builder.createIntegerConstant( - currentLocation, upperbounds[upperbounds.size() - 1].getType(), 1)); - - Fortran::semantics::Symbol &ivSym = - bounds->name.thing.symbol->GetUltimate(); - privatizeIv(converter, ivSym, currentLocation, ivTypes, ivLocs, - privateOperands, ivPrivate, privatizationRecipes); - - inclusiveBounds.push_back(true); - - // crtEval already updated when descending; no blind increment here. - } + visitLoopControl( + converter, outerDoConstruct, loopsToProcess, eval, + [&](const Fortran::parser::LoopControl::Bounds &bounds, + mlir::Location loc) { + locs.push_back(loc); + lowerbounds.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(bounds.lower), stmtCtx))); + upperbounds.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(bounds.upper), stmtCtx))); + if (bounds.step) + steps.push_back(fir::getBase(converter.genExprValue( + *Fortran::semantics::GetExpr(bounds.step), stmtCtx))); + else // If `step` is not present, assume it is `1`. + steps.push_back(builder.createIntegerConstant( + currentLocation, upperbounds[upperbounds.size() - 1].getType(), + 1)); + Fortran::semantics::Symbol &ivSym = + bounds.name.thing.symbol->GetUltimate(); + privatizeIv(converter, ivSym, currentLocation, ivTypes, ivLocs, + privateOperands, ivPrivate); + + inclusiveBounds.push_back(true); + }); } } @@ -2499,6 +2271,32 @@ static void remapDataOperandSymbols( } } +static void privatizeInductionVariables( + Fortran::lower::AbstractConverter &converter, + mlir::Location currentLocation, + const Fortran::parser::DoConstruct &outerDoConstruct, + Fortran::lower::pft::Evaluation &eval, + llvm::SmallVector<mlir::Value> &privateOperands, + llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> + &ivPrivate, + llvm::SmallVector<mlir::Location> &locs, uint64_t loopsToProcess) { + // ivTypes and locs will be ignored since no acc.loop control arguments will + // be created. + llvm::SmallVector<mlir::Type> ivTypes; + llvm::SmallVector<mlir::Location> ivLocs; + assert(!outerDoConstruct.IsDoConcurrent() && + "do concurrent loops are not expected to contained earlty exits"); + visitLoopControl(converter, outerDoConstruct, loopsToProcess, eval, + [&](const Fortran::parser::LoopControl::Bounds &bounds, + mlir::Location loc) { + locs.push_back(loc); + Fortran::semantics::Symbol &ivSym = + bounds.name.thing.symbol->GetUltimate(); + privatizeIv(converter, ivSym, currentLocation, ivTypes, + ivLocs, privateOperands, ivPrivate); + }); +} + static mlir::acc::LoopOp buildACCLoopOp( Fortran::lower::AbstractConverter &converter, mlir::Location currentLocation, @@ -2507,7 +2305,6 @@ static mlir::acc::LoopOp buildACCLoopOp( const Fortran::parser::DoConstruct &outerDoConstruct, Fortran::lower::pft::Evaluation &eval, llvm::SmallVector<mlir::Value> &privateOperands, - llvm::SmallVector<mlir::Attribute> &privatizationRecipes, llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> &dataOperandSymbolPairs, llvm::SmallVector<mlir::Value> &gangOperands, @@ -2528,13 +2325,22 @@ static mlir::acc::LoopOp buildACCLoopOp( llvm::SmallVector<mlir::Location> locs; llvm::SmallVector<mlir::Value> lowerbounds, upperbounds, steps; - // Look at the do/do concurrent loops to extract bounds information. - processDoLoopBounds(converter, currentLocation, stmtCtx, builder, - outerDoConstruct, eval, lowerbounds, upperbounds, steps, - privateOperands, ivPrivate, privatizationRecipes, ivTypes, - ivLocs, inclusiveBounds, locs, loopsToProcess); - - // Prepare the operand segment size attribute and the operands value range. + // Look at the do/do concurrent loops to extract bounds information unless + // this loop is lowered in an unstructured fashion, in which case bounds are + // not represented on acc.loop and explicit control flow is used inside body. + if (!eval.lowerAsUnstructured()) { + processDoLoopBounds(converter, currentLocation, stmtCtx, builder, + outerDoConstruct, eval, lowerbounds, upperbounds, steps, + privateOperands, ivPrivate, ivTypes, ivLocs, + inclusiveBounds, locs, loopsToProcess); + } else { + // When the loop contains early exits, privatize induction variables, but do + // not create acc.loop bounds. The control flow of the loop will be + // generated explicitly in the acc.loop body that is just a container. + privatizeInductionVariables(converter, currentLocation, outerDoConstruct, + eval, privateOperands, ivPrivate, locs, + loopsToProcess); + } llvm::SmallVector<mlir::Value> operands; llvm::SmallVector<int32_t> operandSegments; addOperands(operands, operandSegments, lowerbounds); @@ -2563,20 +2369,36 @@ static mlir::acc::LoopOp buildACCLoopOp( // Remap symbols from data clauses to use data operation results remapDataOperandSymbols(converter, builder, loopOp, dataOperandSymbolPairs); - for (auto [arg, iv] : - llvm::zip(loopOp.getLoopRegions().front()->front().getArguments(), - ivPrivate)) { - // Store block argument to the related iv private variable. - mlir::Value privateValue = - converter.getSymbolAddress(std::get<Fortran::semantics::SymbolRef>(iv)); - fir::StoreOp::create(builder, currentLocation, arg, privateValue); + if (!eval.lowerAsUnstructured()) { + for (auto [arg, iv] : + llvm::zip(loopOp.getLoopRegions().front()->front().getArguments(), + ivPrivate)) { + // Store block argument to the related iv private variable. + mlir::Value privateValue = converter.getSymbolAddress( + std::get<Fortran::semantics::SymbolRef>(iv)); + fir::StoreOp::create(builder, currentLocation, arg, privateValue); + } + loopOp.setInclusiveUpperbound(inclusiveBounds); + } else { + loopOp.setUnstructuredAttr(builder.getUnitAttr()); } - loopOp.setInclusiveUpperbound(inclusiveBounds); - return loopOp; } +static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) { + bool hasReturnStmt = false; + for (auto &e : eval.getNestedEvaluations()) { + e.visit(Fortran::common::visitors{ + [&](const Fortran::parser::ReturnStmt &) { hasReturnStmt = true; }, + [&](const auto &s) {}, + }); + if (e.hasNestedEvaluations()) + hasReturnStmt = hasEarlyReturn(e); + } + return hasReturnStmt; +} + static mlir::acc::LoopOp createLoopOp( Fortran::lower::AbstractConverter &converter, mlir::Location currentLocation, @@ -2586,13 +2408,11 @@ static mlir::acc::LoopOp createLoopOp( Fortran::lower::pft::Evaluation &eval, const Fortran::parser::AccClauseList &accClauseList, std::optional<mlir::acc::CombinedConstructsType> combinedConstructs = - std::nullopt, - bool needEarlyReturnHandling = false) { + std::nullopt) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); llvm::SmallVector<mlir::Value> tileOperands, privateOperands, reductionOperands, cacheOperands, vectorOperands, workerNumOperands, gangOperands; - llvm::SmallVector<mlir::Attribute> privatizationRecipes, reductionRecipes; llvm::SmallVector<int32_t> tileOperandsSegments, gangOperandsSegments; llvm::SmallVector<int64_t> collapseValues; @@ -2719,15 +2539,16 @@ static mlir::acc::LoopOp createLoopOp( &clause.u)) { genPrivatizationRecipes<mlir::acc::PrivateRecipeOp>( privateClause->v, converter, semanticsContext, stmtCtx, - privateOperands, privatizationRecipes, /*async=*/{}, + privateOperands, /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}, &dataOperandSymbolPairs); } else if (const auto *reductionClause = std::get_if<Fortran::parser::AccClause::Reduction>( &clause.u)) { genReductions(reductionClause->v, converter, semanticsContext, stmtCtx, - reductionOperands, reductionRecipes, /*async=*/{}, - /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}); + reductionOperands, /*async=*/{}, + /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}, + &dataOperandSymbolPairs); } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) { for (auto crtDeviceTypeAttr : crtDeviceTypes) seqDeviceTypes.push_back(crtDeviceTypeAttr); @@ -2763,7 +2584,10 @@ static mlir::acc::LoopOp createLoopOp( llvm::SmallVector<mlir::Type> retTy; mlir::Value yieldValue; - if (needEarlyReturnHandling) { + if (eval.lowerAsUnstructured() && hasEarlyReturn(eval)) { + // When there is a return statement inside the loop, add a result to the + // acc.loop that will be used in a conditional branch after the loop to + // return. mlir::Type i1Ty = builder.getI1Type(); yieldValue = builder.createIntegerConstant(currentLocation, i1Ty, 0); retTy.push_back(i1Ty); @@ -2773,9 +2597,9 @@ static mlir::acc::LoopOp createLoopOp( Fortran::lower::getLoopCountForCollapseAndTile(accClauseList); auto loopOp = buildACCLoopOp( converter, currentLocation, semanticsContext, stmtCtx, outerDoConstruct, - eval, privateOperands, privatizationRecipes, dataOperandSymbolPairs, - gangOperands, workerNumOperands, vectorOperands, tileOperands, - cacheOperands, reductionOperands, retTy, yieldValue, loopsToProcess); + eval, privateOperands, dataOperandSymbolPairs, gangOperands, + workerNumOperands, vectorOperands, tileOperands, cacheOperands, + reductionOperands, retTy, yieldValue, loopsToProcess); if (!gangDeviceTypes.empty()) loopOp.setGangAttr(builder.getArrayAttr(gangDeviceTypes)); @@ -2817,14 +2641,6 @@ static mlir::acc::LoopOp createLoopOp( if (!autoDeviceTypes.empty()) loopOp.setAuto_Attr(builder.getArrayAttr(autoDeviceTypes)); - if (!privatizationRecipes.empty()) - loopOp.setPrivatizationRecipesAttr( - mlir::ArrayAttr::get(builder.getContext(), privatizationRecipes)); - - if (!reductionRecipes.empty()) - loopOp.setReductionRecipesAttr( - mlir::ArrayAttr::get(builder.getContext(), reductionRecipes)); - if (!collapseValues.empty()) loopOp.setCollapseAttr(builder.getI64ArrayAttr(collapseValues)); if (!collapseDeviceTypes.empty()) @@ -2844,19 +2660,6 @@ static mlir::acc::LoopOp createLoopOp( return loopOp; } -static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) { - bool hasReturnStmt = false; - for (auto &e : eval.getNestedEvaluations()) { - e.visit(Fortran::common::visitors{ - [&](const Fortran::parser::ReturnStmt &) { hasReturnStmt = true; }, - [&](const auto &s) {}, - }); - if (e.hasNestedEvaluations()) - hasReturnStmt = hasEarlyReturn(e); - } - return hasReturnStmt; -} - static mlir::Value genACC(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semanticsContext, @@ -2870,17 +2673,6 @@ genACC(Fortran::lower::AbstractConverter &converter, mlir::Location currentLocation = converter.genLocation(beginLoopDirective.source); - bool needEarlyExitHandling = false; - if (eval.lowerAsUnstructured()) { - needEarlyExitHandling = hasEarlyReturn(eval); - // If the loop is lowered in an unstructured fashion, lowering generates - // explicit control flow that duplicates the looping semantics of the - // loops. - if (!needEarlyExitHandling) - TODO(currentLocation, - "loop with early exit inside OpenACC loop construct"); - } - Fortran::lower::StatementContext stmtCtx; assert(loopDirective.v == llvm::acc::ACCD_loop && @@ -2893,8 +2685,8 @@ genACC(Fortran::lower::AbstractConverter &converter, std::get<std::optional<Fortran::parser::DoConstruct>>(loopConstruct.t); auto loopOp = createLoopOp(converter, currentLocation, semanticsContext, stmtCtx, *outerDoConstruct, eval, accClauseList, - /*combinedConstructs=*/{}, needEarlyExitHandling); - if (needEarlyExitHandling) + /*combinedConstructs=*/{}); + if (loopOp.getNumResults() == 1) return loopOp.getResult(0); return mlir::Value{}; @@ -2955,8 +2747,6 @@ static Op createComputeOp( llvm::SmallVector<mlir::Value> reductionOperands, privateOperands, firstprivateOperands; - llvm::SmallVector<mlir::Attribute> privatizationRecipes, - firstPrivatizationRecipes, reductionRecipes; // Vector to track mlir::Value results and their corresponding Fortran symbols llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> @@ -3106,8 +2896,8 @@ static Op createComputeOp( genDataOperandOperationsWithModifier<mlir::acc::CreateOp, Fortran::parser::AccClause::Copyout>( copyoutClause, converter, semanticsContext, stmtCtx, - Fortran::parser::AccDataModifier::Modifier::ReadOnly, - dataClauseOperands, mlir::acc::DataClause::acc_copyout, + Fortran::parser::AccDataModifier::Modifier::Zero, dataClauseOperands, + mlir::acc::DataClause::acc_copyout, mlir::acc::DataClause::acc_copyout_zero, async, asyncDeviceTypes, asyncOnlyDeviceTypes, /*setDeclareAttr=*/false, &dataOperandSymbolPairs); @@ -3178,15 +2968,15 @@ static Op createComputeOp( if (!combinedConstructs) genPrivatizationRecipes<mlir::acc::PrivateRecipeOp>( privateClause->v, converter, semanticsContext, stmtCtx, - privateOperands, privatizationRecipes, async, asyncDeviceTypes, - asyncOnlyDeviceTypes, &dataOperandSymbolPairs); + privateOperands, async, asyncDeviceTypes, asyncOnlyDeviceTypes, + &dataOperandSymbolPairs); } else if (const auto *firstprivateClause = std::get_if<Fortran::parser::AccClause::Firstprivate>( &clause.u)) { genPrivatizationRecipes<mlir::acc::FirstprivateRecipeOp>( firstprivateClause->v, converter, semanticsContext, stmtCtx, - firstprivateOperands, firstPrivatizationRecipes, async, - asyncDeviceTypes, asyncOnlyDeviceTypes, &dataOperandSymbolPairs); + firstprivateOperands, async, asyncDeviceTypes, asyncOnlyDeviceTypes, + &dataOperandSymbolPairs); } else if (const auto *reductionClause = std::get_if<Fortran::parser::AccClause::Reduction>( &clause.u)) { @@ -3197,8 +2987,8 @@ static Op createComputeOp( // instead. if (!combinedConstructs) { genReductions(reductionClause->v, converter, semanticsContext, stmtCtx, - reductionOperands, reductionRecipes, async, - asyncDeviceTypes, asyncOnlyDeviceTypes); + reductionOperands, async, asyncDeviceTypes, + asyncOnlyDeviceTypes, &dataOperandSymbolPairs); } else { auto crtDataStart = dataClauseOperands.size(); genDataOperandOperations<mlir::acc::CopyinOp>( @@ -3234,11 +3024,9 @@ static Op createComputeOp( } addOperand(operands, operandSegments, ifCond); addOperand(operands, operandSegments, selfCond); - if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) { - addOperands(operands, operandSegments, reductionOperands); - addOperands(operands, operandSegments, privateOperands); - addOperands(operands, operandSegments, firstprivateOperands); - } + addOperands(operands, operandSegments, reductionOperands); + addOperands(operands, operandSegments, privateOperands); + addOperands(operands, operandSegments, firstprivateOperands); addOperands(operands, operandSegments, dataClauseOperands); Op computeOp; @@ -3290,18 +3078,6 @@ static Op createComputeOp( if (!waitOnlyDeviceTypes.empty()) computeOp.setWaitOnlyAttr(builder.getArrayAttr(waitOnlyDeviceTypes)); - if constexpr (!std::is_same_v<Op, mlir::acc::KernelsOp>) { - if (!privatizationRecipes.empty()) - computeOp.setPrivatizationRecipesAttr( - mlir::ArrayAttr::get(builder.getContext(), privatizationRecipes)); - if (!reductionRecipes.empty()) - computeOp.setReductionRecipesAttr( - mlir::ArrayAttr::get(builder.getContext(), reductionRecipes)); - if (!firstPrivatizationRecipes.empty()) - computeOp.setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get( - builder.getContext(), firstPrivatizationRecipes)); - } - if (combinedConstructs) computeOp.setCombinedAttr(builder.getUnitAttr()); @@ -3679,10 +3455,6 @@ genACC(Fortran::lower::AbstractConverter &converter, converter.genLocation(beginCombinedDirective.source); Fortran::lower::StatementContext stmtCtx; - if (eval.lowerAsUnstructured()) - TODO(currentLocation, - "loop with early exit inside OpenACC combined construct"); - if (combinedDirective.v == llvm::acc::ACCD_kernels_loop) { createComputeOp<mlir::acc::KernelsOp>( converter, currentLocation, eval, semanticsContext, stmtCtx, @@ -5014,37 +4786,8 @@ static void genACC(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semanticsContext, const Fortran::parser::OpenACCCacheConstruct &cacheConstruct) { - fir::FirOpBuilder &builder = converter.getFirOpBuilder(); - auto loopOp = builder.getRegion().getParentOfType<mlir::acc::LoopOp>(); - auto crtPos = builder.saveInsertionPoint(); - if (loopOp) { - builder.setInsertionPoint(loopOp); - Fortran::lower::StatementContext stmtCtx; - llvm::SmallVector<mlir::Value> cacheOperands; - const Fortran::parser::AccObjectListWithModifier &listWithModifier = - std::get<Fortran::parser::AccObjectListWithModifier>(cacheConstruct.t); - const auto &accObjectList = - std::get<Fortran::parser::AccObjectList>(listWithModifier.t); - const auto &modifier = - std::get<std::optional<Fortran::parser::AccDataModifier>>( - listWithModifier.t); - - mlir::acc::DataClause dataClause = mlir::acc::DataClause::acc_cache; - if (modifier && - (*modifier).v == Fortran::parser::AccDataModifier::Modifier::ReadOnly) - dataClause = mlir::acc::DataClause::acc_cache_readonly; - genDataOperandOperations<mlir::acc::CacheOp>( - accObjectList, converter, semanticsContext, stmtCtx, cacheOperands, - dataClause, - /*structured=*/true, /*implicit=*/false, - /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{}, - /*setDeclareAttr*/ false); - loopOp.getCacheOperandsMutable().append(cacheOperands); - } else { - llvm::report_fatal_error( - "could not find loop to attach OpenACC cache information."); - } - builder.restoreInsertionPoint(crtPos); + mlir::Location loc = converter.genLocation(cacheConstruct.source); + TODO(loc, "OpenACC cache directive"); } mlir::Value Fortran::lower::genOpenACCConstruct( @@ -5315,7 +5058,6 @@ mlir::Operation *Fortran::lower::genOpenACCLoopFromDoConstruct( llvm::SmallVector<mlir::Value> privateOperands, gangOperands, workerNumOperands, vectorOperands, tileOperands, cacheOperands, reductionOperands; - llvm::SmallVector<mlir::Attribute> privatizationRecipes; llvm::SmallVector<mlir::Type> retTy; llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>> dataOperandSymbolPairs; @@ -5327,15 +5069,9 @@ mlir::Operation *Fortran::lower::genOpenACCLoopFromDoConstruct( Fortran::lower::StatementContext stmtCtx; auto loopOp = buildACCLoopOp( converter, converter.getCurrentLocation(), semanticsContext, stmtCtx, - doConstruct, eval, privateOperands, privatizationRecipes, - dataOperandSymbolPairs, gangOperands, workerNumOperands, vectorOperands, - tileOperands, cacheOperands, reductionOperands, retTy, yieldValue, - loopsToProcess); - - fir::FirOpBuilder &builder = converter.getFirOpBuilder(); - if (!privatizationRecipes.empty()) - loopOp.setPrivatizationRecipesAttr(mlir::ArrayAttr::get( - converter.getFirOpBuilder().getContext(), privatizationRecipes)); + doConstruct, eval, privateOperands, dataOperandSymbolPairs, gangOperands, + workerNumOperands, vectorOperands, tileOperands, cacheOperands, + reductionOperands, retTy, yieldValue, loopsToProcess); // Normal do loops which are not annotated with `acc loop` should be // left for analysis by marking with `auto`. This is the case even in the case @@ -5349,8 +5085,9 @@ mlir::Operation *Fortran::lower::genOpenACCLoopFromDoConstruct( // So this means that in all cases we mark with `auto` unless it is a // `do concurrent` in an `acc parallel` construct or it must be `seq` because // it is in an `acc serial` construct. + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); mlir::Operation *accRegionOp = - mlir::acc::getEnclosingComputeOp(converter.getFirOpBuilder().getRegion()); + mlir::acc::getEnclosingComputeOp(builder.getRegion()); mlir::acc::LoopParMode parMode = mlir::isa_and_present<mlir::acc::ParallelOp>(accRegionOp) && doConstruct.IsDoConcurrent() diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 1c163e6..a81ba37 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -13,10 +13,12 @@ #include "ClauseProcessor.h" #include "Utils.h" +#include "flang/Lower/ConvertCall.h" #include "flang/Lower/ConvertExprToHLFIR.h" #include "flang/Lower/OpenMP/Clauses.h" #include "flang/Lower/PFTBuilder.h" #include "flang/Lower/Support/ReductionProcessor.h" +#include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Parser/tools.h" #include "flang/Semantics/tools.h" #include "flang/Utils/OpenMP.h" @@ -42,15 +44,6 @@ mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) { return mlir::omp::ReductionModifier::defaultmod; } -/// Check for unsupported map operand types. -static void checkMapType(mlir::Location location, mlir::Type type) { - if (auto refType = mlir::dyn_cast<fir::ReferenceType>(type)) - type = refType.getElementType(); - if (auto boxType = mlir::dyn_cast_or_null<fir::BoxType>(type)) - if (!mlir::isa<fir::PointerType>(boxType.getElementType())) - TODO(location, "OMPD_target_data MapOperand BoxType"); -} - static mlir::omp::ScheduleModifier translateScheduleModifier(const omp::clause::Schedule::OrderingModifier &m) { switch (m) { @@ -209,18 +202,6 @@ getIfClauseOperand(lower::AbstractConverter &converter, ifVal); } -static void addUseDeviceClause( - lower::AbstractConverter &converter, const omp::ObjectList &objects, - llvm::SmallVectorImpl<mlir::Value> &operands, - llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) { - genObjectList(objects, converter, operands); - for (mlir::Value &operand : operands) - checkMapType(operand.getLoc(), operand.getType()); - - for (const omp::Object &object : objects) - useDeviceSyms.push_back(object.sym()); -} - //===----------------------------------------------------------------------===// // ClauseProcessor unique clauses //===----------------------------------------------------------------------===// @@ -401,11 +382,75 @@ bool ClauseProcessor::processInclusive( return false; } +bool ClauseProcessor::processInitializer( + lower::SymMap &symMap, const parser::OmpClause::Initializer &inp, + ReductionProcessor::GenInitValueCBTy &genInitValueCB) const { + if (auto *clause = findUniqueClause<omp::clause::Initializer>()) { + genInitValueCB = [&, clause](fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type type, mlir::Value ompOrig) { + lower::SymMapScope scope(symMap); + const parser::OmpInitializerExpression &iexpr = inp.v.v; + const parser::OmpStylizedInstance &styleInstance = iexpr.v.front(); + const std::list<parser::OmpStylizedDeclaration> &declList = + std::get<std::list<parser::OmpStylizedDeclaration>>(styleInstance.t); + mlir::Value ompPrivVar; + for (const parser::OmpStylizedDeclaration &decl : declList) { + auto &name = std::get<parser::ObjectName>(decl.var.t); + assert(name.symbol && "Name does not have a symbol"); + mlir::Value addr = builder.createTemporary(loc, ompOrig.getType()); + fir::StoreOp::create(builder, loc, ompOrig, addr); + fir::FortranVariableFlagsEnum extraFlags = {}; + fir::FortranVariableFlagsAttr attributes = + Fortran::lower::translateSymbolAttributes(builder.getContext(), + *name.symbol, extraFlags); + auto declareOp = hlfir::DeclareOp::create( + builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr, + 0, attributes); + if (name.ToString() == "omp_priv") + ompPrivVar = declareOp.getResult(0); + symMap.addVariableDefinition(*name.symbol, declareOp); + } + // Lower the expression/function call + lower::StatementContext stmtCtx; + mlir::Value result = common::visit( + common::visitors{ + [&](const evaluate::ProcedureRef &procRef) -> mlir::Value { + convertCallToHLFIR(loc, converter, procRef, std::nullopt, + symMap, stmtCtx); + auto privVal = fir::LoadOp::create(builder, loc, ompPrivVar); + return privVal; + }, + [&](const auto &expr) -> mlir::Value { + mlir::Value exprResult = fir::getBase(convertExprToValue( + loc, converter, clause->v, symMap, stmtCtx)); + // Conversion can either give a value or a refrence to a value, + // we need to return the reduction type, so an optional load may + // be generated. + if (auto refType = llvm::dyn_cast<fir::ReferenceType>( + exprResult.getType())) + if (ompPrivVar.getType() == refType) + exprResult = fir::LoadOp::create(builder, loc, exprResult); + return exprResult; + }}, + clause->v.u); + stmtCtx.finalizeAndPop(); + return result; + }; + return true; + } + return false; +} + bool ClauseProcessor::processMergeable( mlir::omp::MergeableClauseOps &result) const { return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable); } +bool ClauseProcessor::processNogroup( + mlir::omp::NogroupClauseOps &result) const { + return markClauseOccurrence<omp::clause::Nogroup>(result.nogroup); +} + bool ClauseProcessor::processNowait(mlir::omp::NowaitClauseOps &result) const { return markClauseOccurrence<omp::clause::Nowait>(result.nowait); } @@ -1159,14 +1204,26 @@ bool ClauseProcessor::processInReduction( } bool ClauseProcessor::processIsDevicePtr( - mlir::omp::IsDevicePtrClauseOps &result, + lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result, llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const { - return findRepeatableClause<omp::clause::IsDevicePtr>( - [&](const omp::clause::IsDevicePtr &devPtrClause, - const parser::CharBlock &) { - addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars, - isDeviceSyms); + std::map<Object, OmpMapParentAndMemberData> parentMemberIndices; + bool clauseFound = findRepeatableClause<omp::clause::IsDevicePtr>( + [&](const omp::clause::IsDevicePtr &clause, + const parser::CharBlock &source) { + mlir::Location location = converter.genLocation(source); + // Force a map so the descriptor is materialized on the device with the + // device address inside. + mlir::omp::ClauseMapFlags mapTypeBits = + mlir::omp::ClauseMapFlags::is_device_ptr | + mlir::omp::ClauseMapFlags::to; + processMapObjects(stmtCtx, location, clause.v, mapTypeBits, + parentMemberIndices, result.isDevicePtrVars, + isDeviceSyms); }); + + insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices, + result.isDevicePtrVars, isDeviceSyms); + return clauseFound; } bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const { @@ -1175,11 +1232,20 @@ bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const { omp::clause::Linear>([&](const omp::clause::Linear &clause, const parser::CharBlock &) { auto &objects = std::get<omp::ObjectList>(clause.t); + static std::vector<mlir::Attribute> typeAttrs; + + if (!result.linearVars.size()) + typeAttrs.clear(); + for (const omp::Object &object : objects) { semantics::Symbol *sym = object.sym(); const mlir::Value variable = converter.getSymbolAddress(*sym); result.linearVars.push_back(variable); + mlir::Type ty = converter.genType(*sym); + typeAttrs.push_back(mlir::TypeAttr::get(ty)); } + result.linearVarTypes = + mlir::ArrayAttr::get(&converter.getMLIRContext(), typeAttrs); if (objects.size()) { if (auto &mod = std::get<std::optional<omp::clause::Linear::StepComplexModifier>>( @@ -1223,26 +1289,67 @@ void ClauseProcessor::processMapObjects( llvm::StringRef mapperIdNameRef) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - auto getDefaultMapperID = [&](const omp::Object &object, - std::string &mapperIdName) { - if (!mlir::isa<mlir::omp::DeclareMapperOp>( - firOpBuilder.getRegion().getParentOp())) { - const semantics::DerivedTypeSpec *typeSpec = nullptr; + auto getSymbolDerivedType = [](const semantics::Symbol &symbol) + -> const semantics::DerivedTypeSpec * { + const semantics::Symbol &ultimate = symbol.GetUltimate(); + if (const semantics::DeclTypeSpec *declType = ultimate.GetType()) + if (const auto *derived = declType->AsDerived()) + return derived; + return nullptr; + }; + + auto addImplicitMapper = [&](const omp::Object &object, + std::string &mapperIdName, + bool allowGenerate) -> mlir::FlatSymbolRefAttr { + if (mapperIdName.empty()) + return mlir::FlatSymbolRefAttr(); - if (object.sym()->owner().IsDerivedType()) - typeSpec = object.sym()->owner().derivedTypeSpec(); - else if (object.sym()->GetType() && - object.sym()->GetType()->category() == - semantics::DeclTypeSpec::TypeDerived) - typeSpec = &object.sym()->GetType()->derivedTypeSpec(); - - if (typeSpec) { - mapperIdName = - typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName; - if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName)) - mapperIdName = converter.mangleName(mapperIdName, sym->owner()); - } + if (converter.getModuleOp().lookupSymbol(mapperIdName)) + return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), + mapperIdName); + + if (!allowGenerate) + return mlir::FlatSymbolRefAttr(); + + const semantics::DerivedTypeSpec *typeSpec = + getSymbolDerivedType(*object.sym()); + if (!typeSpec && object.sym()->owner().IsDerivedType()) + typeSpec = object.sym()->owner().derivedTypeSpec(); + + if (!typeSpec) + return mlir::FlatSymbolRefAttr(); + + mlir::Type type = converter.genType(*typeSpec); + auto recordType = mlir::dyn_cast<fir::RecordType>(type); + if (!recordType) + return mlir::FlatSymbolRefAttr(); + + return getOrGenImplicitDefaultDeclareMapper(converter, clauseLocation, + recordType, mapperIdName); + }; + + auto getDefaultMapperID = + [&](const semantics::DerivedTypeSpec *typeSpec) -> std::string { + if (mlir::isa<mlir::omp::DeclareMapperOp>( + firOpBuilder.getRegion().getParentOp()) || + !typeSpec) + return {}; + + std::string mapperIdName = + typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName; + if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName)) { + mapperIdName = + converter.mangleName(mapperIdName, sym->GetUltimate().owner()); + } else { + mapperIdName = converter.mangleName(mapperIdName, *typeSpec->GetScope()); } + + // Make sure we don't return a mapper to self. + if (auto declMapOp = mlir::dyn_cast<mlir::omp::DeclareMapperOp>( + firOpBuilder.getRegion().getParentOp())) + if (mapperIdName == declMapOp.getSymName()) + return {}; + return mapperIdName; }; // Create the mapper symbol from its name, if specified. @@ -1251,8 +1358,13 @@ void ClauseProcessor::processMapObjects( mapperIdNameRef != "__implicit_mapper") { std::string mapperIdName = mapperIdNameRef.str(); const omp::Object &object = objects.front(); - if (mapperIdNameRef == "default") - getDefaultMapperID(object, mapperIdName); + if (mapperIdNameRef == "default") { + const semantics::DerivedTypeSpec *typeSpec = + getSymbolDerivedType(*object.sym()); + if (!typeSpec && object.sym()->owner().IsDerivedType()) + typeSpec = object.sym()->owner().derivedTypeSpec(); + mapperIdName = getDefaultMapperID(typeSpec); + } assert(converter.getModuleOp().lookupSymbol(mapperIdName) && "mapper not found"); mapperId = @@ -1290,13 +1402,25 @@ void ClauseProcessor::processMapObjects( } } + const semantics::DerivedTypeSpec *objectTypeSpec = + getSymbolDerivedType(*object.sym()); + if (mapperIdNameRef == "__implicit_mapper") { - std::string mapperIdName; - getDefaultMapperID(object, mapperIdName); - mapperId = converter.getModuleOp().lookupSymbol(mapperIdName) - ? mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), - mapperIdName) - : mlir::FlatSymbolRefAttr(); + if (parentObj.has_value()) { + mapperId = mlir::FlatSymbolRefAttr(); + } else if (objectTypeSpec) { + std::string mapperIdName = getDefaultMapperID(objectTypeSpec); + bool needsDefaultMapper = + semantics::IsAllocatableOrObjectPointer(object.sym()) || + requiresImplicitDefaultDeclareMapper(*objectTypeSpec); + if (!mapperIdName.empty()) + mapperId = addImplicitMapper(object, mapperIdName, + /*allowGenerate=*/needsDefaultMapper); + else + mapperId = mlir::FlatSymbolRefAttr(); + } else { + mapperId = mlir::FlatSymbolRefAttr(); + } } // Explicit map captures are captured ByRef by default, @@ -1392,10 +1516,14 @@ bool ClauseProcessor::processMap( } if (mappers) { assert(mappers->size() == 1 && "more than one mapper"); - mapperIdName = mappers->front().v.id().symbol->name().ToString(); - if (mapperIdName != "default") - mapperIdName = converter.mangleName( - mapperIdName, mappers->front().v.id().symbol->owner()); + const semantics::Symbol *mapperSym = mappers->front().v.id().symbol; + mapperIdName = mapperSym->name().ToString(); + if (mapperIdName != "default") { + // Mangle with the ultimate owner so that use-associated mapper + // identifiers resolve to the same symbol as their defining scope. + const semantics::Symbol &ultimate = mapperSym->GetUltimate(); + mapperIdName = converter.mangleName(mapperIdName, ultimate.owner()); + } } processMapObjects(stmtCtx, clauseLocation, diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 6452e39..3485a4e 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -18,8 +18,8 @@ #include "flang/Lower/Bridge.h" #include "flang/Lower/DirectivesCommon.h" #include "flang/Lower/OpenMP/Clauses.h" +#include "flang/Lower/Support/ReductionProcessor.h" #include "flang/Optimizer/Builder/Todo.h" -#include "flang/Parser/dump-parse-tree.h" #include "flang/Parser/parse-tree.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" @@ -88,7 +88,11 @@ public: bool processHint(mlir::omp::HintClauseOps &result) const; bool processInclusive(mlir::Location currentLocation, mlir::omp::InclusiveClauseOps &result) const; + bool processInitializer( + lower::SymMap &symMap, const parser::OmpClause::Initializer &inp, + ReductionProcessor::GenInitValueCBTy &genInitValueCB) const; bool processMergeable(mlir::omp::MergeableClauseOps &result) const; + bool processNogroup(mlir::omp::NogroupClauseOps &result) const; bool processNowait(mlir::omp::NowaitClauseOps &result) const; bool processNumTasks(lower::StatementContext &stmtCtx, mlir::omp::NumTasksClauseOps &result) const; @@ -130,7 +134,7 @@ public: mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result, llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const; bool processIsDevicePtr( - mlir::omp::IsDevicePtrClauseOps &result, + lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result, llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const; bool processLinear(mlir::omp::LinearClauseOps &result) const; bool diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index 0f60b47..61430fc 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -10,7 +10,6 @@ #include "flang/Common/idioms.h" #include "flang/Evaluate/expression.h" -#include "flang/Optimizer/Builder/Todo.h" #include "flang/Parser/parse-tree.h" #include "flang/Semantics/expression.h" #include "flang/Semantics/openmp-modifiers.h" @@ -249,8 +248,10 @@ MAKE_EMPTY_CLASS(Groupprivate, Groupprivate); MAKE_INCOMPLETE_CLASS(AdjustArgs, AdjustArgs); MAKE_INCOMPLETE_CLASS(AppendArgs, AppendArgs); +MAKE_INCOMPLETE_CLASS(Collector, Collector); MAKE_INCOMPLETE_CLASS(GraphId, GraphId); MAKE_INCOMPLETE_CLASS(GraphReset, GraphReset); +MAKE_INCOMPLETE_CLASS(Inductor, Inductor); MAKE_INCOMPLETE_CLASS(Replayable, Replayable); MAKE_INCOMPLETE_CLASS(Transparent, Transparent); @@ -394,8 +395,6 @@ makePrescriptiveness(parser::OmpPrescriptiveness::Value v) { switch (v) { case parser::OmpPrescriptiveness::Value::Strict: return clause::Prescriptiveness::Strict; - case parser::OmpPrescriptiveness::Value::Fallback: - return clause::Prescriptiveness::Fallback; } llvm_unreachable("Unexpected prescriptiveness"); } @@ -797,21 +796,31 @@ DynGroupprivate make(const parser::OmpClause::DynGroupprivate &inp, semantics::SemanticsContext &semaCtx) { // imp.v -> OmpDyngroupprivateClause CLAUSET_ENUM_CONVERT( // - convert, parser::OmpAccessGroup::Value, DynGroupprivate::AccessGroup, + makeAccessGroup, parser::OmpAccessGroup::Value, + DynGroupprivate::AccessGroup, // clang-format off MS(Cgroup, Cgroup) // clang-format on ); + CLAUSET_ENUM_CONVERT( // + makeFallback, parser::OmpFallbackModifier::Value, + DynGroupprivate::Fallback, + // clang-format off + MS(Abort, Abort) + MS(Default_Mem, Default_Mem) + MS(Null, Null) + // clang-format on + ); + auto &mods = semantics::OmpGetModifiers(inp.v); auto *m0 = semantics::OmpGetUniqueModifier<parser::OmpAccessGroup>(mods); - auto *m1 = semantics::OmpGetUniqueModifier<parser::OmpPrescriptiveness>(mods); + auto *m1 = semantics::OmpGetUniqueModifier<parser::OmpFallbackModifier>(mods); auto &size = std::get<parser::ScalarIntExpr>(inp.v.t); - return DynGroupprivate{ - {/*AccessGroup=*/maybeApplyToV(convert, m0), - /*Prescriptiveness=*/maybeApplyToV(makePrescriptiveness, m1), - /*Size=*/makeExpr(size, semaCtx)}}; + return DynGroupprivate{{/*AccessGroup=*/maybeApplyToV(makeAccessGroup, m0), + /*Fallback=*/maybeApplyToV(makeFallback, m1), + /*Size=*/makeExpr(size, semaCtx)}}; } Enter make(const parser::OmpClause::Enter &inp, @@ -972,7 +981,22 @@ Init make(const parser::OmpClause::Init &inp, Initializer make(const parser::OmpClause::Initializer &inp, semantics::SemanticsContext &semaCtx) { - llvm_unreachable("Empty: initializer"); + const parser::OmpInitializerExpression &iexpr = inp.v.v; + const parser::OmpStylizedInstance &styleInstance = iexpr.v.front(); + const parser::OmpStylizedInstance::Instance &instance = + std::get<parser::OmpStylizedInstance::Instance>(styleInstance.t); + if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) { + auto &expr = std::get<parser::Expr>(as->t); + return Initializer{makeExpr(expr, semaCtx)}; + } else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) { + if (call->typedCall) { + const auto &procRef = *call->typedCall; + semantics::SomeExpr evalProcRef{procRef}; + return Initializer{evalProcRef}; + } + } + + llvm_unreachable("Unexpected initializer"); } InReduction make(const parser::OmpClause::InReduction &inp, @@ -1052,7 +1076,7 @@ Link make(const parser::OmpClause::Link &inp, return Link{/*List=*/makeObjects(inp.v, semaCtx)}; } -LoopRange make(const parser::OmpClause::Looprange &inp, +Looprange make(const parser::OmpClause::Looprange &inp, semantics::SemanticsContext &semaCtx) { llvm_unreachable("Unimplemented: looprange"); } diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp index 146a252..83c2eda 100644 --- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp @@ -342,7 +342,8 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) { if (!hasLastPrivate) return; - if (mlir::isa<mlir::omp::WsloopOp>(op) || mlir::isa<mlir::omp::SimdOp>(op)) { + if (mlir::isa<mlir::omp::WsloopOp>(op) || mlir::isa<mlir::omp::SimdOp>(op) || + mlir::isa<mlir::omp::TaskloopOp>(op)) { mlir::omp::LoopRelatedClauseOps result; llvm::SmallVector<const semantics::Symbol *> iv; collectLoopRelatedInfo(converter, converter.getCurrentLocation(), eval, @@ -408,7 +409,7 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) { } else { TODO(converter.getCurrentLocation(), "lastprivate clause in constructs other than " - "simd/worksharing-loop"); + "simd/worksharing-loop/taskloop"); } } diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 7106728..9c25c19 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -18,12 +18,17 @@ #include "Decomposer.h" #include "Utils.h" #include "flang/Common/idioms.h" +#include "flang/Evaluate/type.h" #include "flang/Lower/Bridge.h" +#include "flang/Lower/ConvertCall.h" #include "flang/Lower/ConvertExpr.h" +#include "flang/Lower/ConvertExprToHLFIR.h" #include "flang/Lower/ConvertVariable.h" #include "flang/Lower/DirectivesCommon.h" #include "flang/Lower/OpenMP/Clauses.h" +#include "flang/Lower/PFTBuilder.h" #include "flang/Lower/StatementContext.h" +#include "flang/Lower/Support/ReductionProcessor.h" #include "flang/Lower/SymbolMap.h" #include "flang/Optimizer/Builder/BoxValue.h" #include "flang/Optimizer/Builder/FIRBuilder.h" @@ -565,14 +570,9 @@ getCollapsedLoopEval(lower::pft::Evaluation &eval, int collapseValue) { if (collapseValue == 0) return &eval; - lower::pft::Evaluation *curEval = &eval.getFirstNestedEvaluation(); - for (int i = 1; i < collapseValue; i++) { - // The nested evaluations should be DoConstructs (i.e. they should form - // a loop nest). Each DoConstruct is a tuple <NonLabelDoStmt, Block, - // EndDoStmt>. - assert(curEval->isA<parser::DoConstruct>()); - curEval = &*std::next(curEval->getNestedEvaluations().begin()); - } + lower::pft::Evaluation *curEval = &eval; + for (int i = 0; i < collapseValue; i++) + curEval = getNestedDoConstruct(*curEval); return curEval; } @@ -1008,9 +1008,7 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, mlir::omp::VariableCaptureKind::ByRef); break; case DefMap::ImplicitBehavior::Firstprivate: - case DefMap::ImplicitBehavior::None: - TODO(loc, "Firstprivate and None are currently unsupported defaultmap " - "behaviour"); + TODO(loc, "Firstprivate is currently unsupported defaultmap behaviour"); break; case DefMap::ImplicitBehavior::From: return std::make_pair(mapFlag |= mlir::omp::ClauseMapFlags::from, @@ -1032,8 +1030,9 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder, mlir::omp::VariableCaptureKind::ByRef); break; case DefMap::ImplicitBehavior::Default: + case DefMap::ImplicitBehavior::None: llvm_unreachable( - "Implicit None Behaviour Should Have Been Handled Earlier"); + "Implicit None and Default behaviour should have been handled earlier"); break; } @@ -1203,7 +1202,7 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info, // Start with privatization, so that the lowering of the nested // code will use the right symbols. bool isLoop = llvm::omp::getDirectiveAssociation(info.dir) == - llvm::omp::Association::Loop; + llvm::omp::Association::LoopNest; bool privatize = info.clauses && info.privatize; firOpBuilder.setInsertionPoint(marker); @@ -1637,8 +1636,7 @@ static void genSimdClauses( cp.processReduction(loc, clauseOps, reductionSyms); cp.processSafelen(clauseOps); cp.processSimdlen(clauseOps); - - cp.processTODO<clause::Linear>(loc, llvm::omp::Directive::OMPD_simd); + cp.processLinear(clauseOps); } static void genSingleClauses(lower::AbstractConverter &converter, @@ -1673,7 +1671,7 @@ static void genTargetClauses( hostEvalInfo->collectValues(clauseOps.hostEvalVars); } cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); - cp.processIsDevicePtr(clauseOps, isDevicePtrSyms); + cp.processIsDevicePtr(stmtCtx, clauseOps, isDevicePtrSyms); cp.processMap(loc, stmtCtx, clauseOps, llvm::omp::Directive::OMPD_unknown, &mapSyms); cp.processNowait(clauseOps); @@ -1763,21 +1761,25 @@ static void genTaskgroupClauses( cp.processTaskReduction(loc, clauseOps, taskReductionSyms); } -static void genTaskloopClauses(lower::AbstractConverter &converter, - semantics::SemanticsContext &semaCtx, - lower::StatementContext &stmtCtx, - const List<Clause> &clauses, mlir::Location loc, - mlir::omp::TaskloopOperands &clauseOps) { +static void genTaskloopClauses( + lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, + lower::StatementContext &stmtCtx, const List<Clause> &clauses, + mlir::Location loc, mlir::omp::TaskloopOperands &clauseOps, + llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms, + llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); + cp.processAllocate(clauseOps); + cp.processFinal(stmtCtx, clauseOps); cp.processGrainsize(stmtCtx, clauseOps); + cp.processIf(llvm::omp::Directive::OMPD_taskloop, clauseOps); + cp.processInReduction(loc, clauseOps, inReductionSyms); + cp.processMergeable(clauseOps); + cp.processNogroup(clauseOps); cp.processNumTasks(stmtCtx, clauseOps); - - cp.processTODO<clause::Allocate, clause::Collapse, clause::Default, - clause::Final, clause::If, clause::InReduction, - clause::Lastprivate, clause::Mergeable, clause::Nogroup, - clause::Priority, clause::Reduction, clause::Shared, - clause::Untied>(loc, llvm::omp::Directive::OMPD_taskloop); + cp.processPriority(stmtCtx, clauseOps); + cp.processReduction(loc, clauseOps, reductionSyms); + cp.processUntied(clauseOps); } static void genTaskwaitClauses(lower::AbstractConverter &converter, @@ -1828,9 +1830,9 @@ static void genWsloopClauses( cp.processOrdered(clauseOps); cp.processReduction(loc, clauseOps, reductionSyms); cp.processSchedule(stmtCtx, clauseOps); + cp.processLinear(clauseOps); - cp.processTODO<clause::Allocate, clause::Linear>( - loc, llvm::omp::Directive::OMPD_do); + cp.processTODO<clause::Allocate>(loc, llvm::omp::Directive::OMPD_do); } //===----------------------------------------------------------------------===// @@ -2485,13 +2487,15 @@ static bool isDuplicateMappedSymbol( const semantics::Symbol &sym, const llvm::SetVector<const semantics::Symbol *> &privatizedSyms, const llvm::SmallVectorImpl<const semantics::Symbol *> &hasDevSyms, - const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms) { + const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms, + const llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms) { llvm::SmallVector<const semantics::Symbol *> concatSyms; concatSyms.reserve(privatizedSyms.size() + hasDevSyms.size() + - mappedSyms.size()); + mappedSyms.size() + isDevicePtrSyms.size()); concatSyms.append(privatizedSyms.begin(), privatizedSyms.end()); concatSyms.append(hasDevSyms.begin(), hasDevSyms.end()); concatSyms.append(mappedSyms.begin(), mappedSyms.end()); + concatSyms.append(isDevicePtrSyms.begin(), isDevicePtrSyms.end()); auto checkSymbol = [&](const semantics::Symbol &checkSym) { return std::any_of(concatSyms.begin(), concatSyms.end(), @@ -2531,6 +2535,38 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, loc, clauseOps, defaultMaps, hasDeviceAddrSyms, isDevicePtrSyms, mapSyms); + if (!isDevicePtrSyms.empty()) { + // is_device_ptr maps get duplicated so the clause and synthesized + // has_device_addr entry each own a unique MapInfoOp user, keeping + // MapInfoFinalization happy while still wiring the symbol into + // has_device_addr when the user didn’t spell it explicitly. + auto insertionPt = firOpBuilder.saveInsertionPoint(); + auto alreadyPresent = [&](const semantics::Symbol *sym) { + return llvm::any_of(hasDeviceAddrSyms, [&](const semantics::Symbol *s) { + return s && sym && s->GetUltimate() == sym->GetUltimate(); + }); + }; + + for (auto [idx, sym] : llvm::enumerate(isDevicePtrSyms)) { + mlir::Value mapVal = clauseOps.isDevicePtrVars[idx]; + assert(sym && "expected symbol for is_device_ptr"); + assert(mapVal && "expected map value for is_device_ptr"); + auto mapInfo = mapVal.getDefiningOp<mlir::omp::MapInfoOp>(); + assert(mapInfo && "expected map info op"); + + if (!alreadyPresent(sym)) { + clauseOps.hasDeviceAddrVars.push_back(mapVal); + hasDeviceAddrSyms.push_back(sym); + } + + firOpBuilder.setInsertionPointAfter(mapInfo); + mlir::Operation *clonedOp = firOpBuilder.clone(*mapInfo.getOperation()); + auto clonedMapInfo = mlir::cast<mlir::omp::MapInfoOp>(clonedOp); + clauseOps.isDevicePtrVars[idx] = clonedMapInfo.getResult(); + } + firOpBuilder.restoreInsertionPoint(insertionPt); + } + DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/ lower::omp::isLastItemInQueue(item, queue), @@ -2570,7 +2606,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, return; if (!isDuplicateMappedSymbol(sym, dsp.getAllSymbolsToPrivatize(), - hasDeviceAddrSyms, mapSyms)) { + hasDeviceAddrSyms, mapSyms, isDevicePtrSyms)) { if (const auto *details = sym.template detailsIf<semantics::HostAssocDetails>()) converter.copySymbolBinding(details->symbol(), sym); @@ -2578,18 +2614,6 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym); name << sym.name().ToString(); - mlir::FlatSymbolRefAttr mapperId; - if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived) { - auto &typeSpec = sym.GetType()->derivedTypeSpec(); - std::string mapperIdName = - typeSpec.name().ToString() + llvm::omp::OmpDefaultMapperName; - if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName)) - mapperIdName = converter.mangleName(mapperIdName, sym->owner()); - if (converter.getModuleOp().lookupSymbol(mapperIdName)) - mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), - mapperIdName); - } - fir::factory::AddrAndBoundsInfo info = Fortran::lower::getDataOperandBaseAddr( converter, firOpBuilder, sym.GetUltimate(), @@ -2609,6 +2633,44 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, mapFlagAndKind = getImplicitMapTypeAndKind( firOpBuilder, converter, defaultMaps, eleType, loc, sym); + mlir::FlatSymbolRefAttr mapperId; + if (defaultMaps.empty()) { + // TODO: Honor user-provided defaultmap clauses (aggregates/pointers) + // instead of blanket-disabling implicit mapper generation whenever any + // explicit default map is present. + const semantics::DerivedTypeSpec *typeSpec = + sym.GetType() ? sym.GetType()->AsDerived() : nullptr; + if (typeSpec) { + std::string mapperIdName = + typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName; + if (auto *mapperSym = + converter.getCurrentScope().FindSymbol(mapperIdName)) + mapperIdName = converter.mangleName( + mapperIdName, mapperSym->GetUltimate().owner()); + else + mapperIdName = + converter.mangleName(mapperIdName, *typeSpec->GetScope()); + + if (!mapperIdName.empty()) { + bool allowImplicitMapper = + semantics::IsAllocatableOrObjectPointer(&sym); + bool hasDefaultMapper = + converter.getModuleOp().lookupSymbol(mapperIdName); + if (hasDefaultMapper || allowImplicitMapper) { + if (!hasDefaultMapper) { + if (auto recordType = mlir::dyn_cast_or_null<fir::RecordType>( + converter.genType(*typeSpec))) + mapperId = getOrGenImplicitDefaultDeclareMapper( + converter, loc, recordType, mapperIdName); + } else { + mapperId = mlir::FlatSymbolRefAttr::get( + &converter.getMLIRContext(), mapperIdName); + } + } + } + } + } + mlir::Value mapOp = createMapInfoOp( firOpBuilder, converter.getCurrentLocation(), baseOp, /*varPtrPtr=*/mlir::Value{}, name.str(), bounds, /*members=*/{}, @@ -2818,7 +2880,6 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, // TODO: Add private syms and vars. args.reduction.syms = reductionSyms; args.reduction.vars = clauseOps.reductionVars; - return genOpWithBody<mlir::omp::TeamsOp>( OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, llvm::omp::Directive::OMPD_teams) @@ -2979,8 +3040,11 @@ static mlir::omp::TaskloopOp genStandaloneTaskloop( lower::pft::Evaluation &eval, mlir::Location loc, const ConstructQueue &queue, ConstructQueue::const_iterator item) { mlir::omp::TaskloopOperands taskloopClauseOps; + llvm::SmallVector<const semantics::Symbol *> reductionSyms; + llvm::SmallVector<const semantics::Symbol *> inReductionSyms; + genTaskloopClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - taskloopClauseOps); + taskloopClauseOps, reductionSyms, inReductionSyms); DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/true, enableDelayedPrivatization, symTable); @@ -2994,6 +3058,10 @@ static mlir::omp::TaskloopOp genStandaloneTaskloop( EntryBlockArgs taskloopArgs; taskloopArgs.priv.syms = dsp.getDelayedPrivSymbols(); taskloopArgs.priv.vars = taskloopClauseOps.privateVars; + taskloopArgs.reduction.syms = reductionSyms; + taskloopArgs.reduction.vars = taskloopClauseOps.reductionVars; + taskloopArgs.inReduction.syms = inReductionSyms; + taskloopArgs.inReduction.vars = taskloopClauseOps.inReductionVars; auto taskLoopOp = genWrapperOp<mlir::omp::TaskloopOp>( converter, loc, taskloopClauseOps, taskloopArgs); @@ -3246,17 +3314,12 @@ static mlir::omp::WsloopOp genCompositeDoSimd( genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps, simdReductionSyms); - DataSharingProcessor wsloopItemDSP( - converter, semaCtx, doItem->clauses, eval, - /*shouldCollectPreDeterminedSymbols=*/false, - /*useDelayedPrivatization=*/true, symTable); + DataSharingProcessor wsloopItemDSP(converter, semaCtx, doItem->clauses, eval, + /*shouldCollectPreDeterminedSymbols=*/true, + /*useDelayedPrivatization=*/true, + symTable); wsloopItemDSP.processStep1(&wsloopClauseOps); - DataSharingProcessor simdItemDSP(converter, semaCtx, simdItem->clauses, eval, - /*shouldCollectPreDeterminedSymbols=*/true, - /*useDelayedPrivatization=*/true, symTable); - simdItemDSP.processStep1(&simdClauseOps, simdItem->id); - // Pass the innermost leaf construct's clauses because that's where COLLAPSE // is placed by construct decomposition. mlir::omp::LoopNestOperands loopNestClauseOps; @@ -3275,8 +3338,9 @@ static mlir::omp::WsloopOp genCompositeDoSimd( wsloopOp.setComposite(/*val=*/true); EntryBlockArgs simdArgs; - simdArgs.priv.syms = simdItemDSP.getDelayedPrivSymbols(); - simdArgs.priv.vars = simdClauseOps.privateVars; + // For composite 'do simd', privatization is handled by the wsloop. + // The simd does not create separate private storage for variables already + // privatized by the worksharing construct. simdArgs.reduction.syms = simdReductionSyms; simdArgs.reduction.vars = simdClauseOps.reductionVars; auto simdOp = @@ -3286,7 +3350,7 @@ static mlir::omp::WsloopOp genCompositeDoSimd( genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem, loopNestClauseOps, iv, {{wsloopOp, wsloopArgs}, {simdOp, simdArgs}}, - llvm::omp::Directive::OMPD_do_simd, simdItemDSP); + llvm::omp::Directive::OMPD_do_simd, wsloopItemDSP); return wsloopOp; } @@ -3362,7 +3426,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter, }; bool loopLeaf = llvm::omp::getDirectiveAssociation(item->id) == - llvm::omp::Association::Loop; + llvm::omp::Association::LoopNest; if (loopLeaf) { symTable.pushScope(); if (genOMPCompositeDispatch(converter, symTable, stmtCtx, semaCtx, eval, @@ -3471,6 +3535,13 @@ static void genOMPDispatch(lower::AbstractConverter &converter, case llvm::omp::Directive::OMPD_tile: genTileOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); break; + case llvm::omp::Directive::OMPD_fuse: { + unsigned version = semaCtx.langOptions().OpenMPVersion; + if (!semaCtx.langOptions().OpenMPSimd) + TODO(loc, "Unhandled loop directive (" + + llvm::omp::getOpenMPDirectiveName(dir, version) + ")"); + break; + } case llvm::omp::Directive::OMPD_unroll: genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); break; @@ -3503,12 +3574,12 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, lower::pft::Evaluation &eval, const parser::OpenMPUtilityConstruct &); -static void -genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, - const parser::OpenMPDeclarativeAllocate &declarativeAllocate) { +static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, + lower::pft::Evaluation &eval, + const parser::OmpAllocateDirective &allocate) { if (!semaCtx.langOptions().OpenMPSimd) - TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate"); + TODO(converter.getCurrentLocation(), "OmpAllocateDirective"); } static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, @@ -3527,12 +3598,186 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective"); } +static ReductionProcessor::GenCombinerCBTy +processReductionCombiner(lower::AbstractConverter &converter, + lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, + const parser::OmpReductionSpecifier &specifier) { + ReductionProcessor::GenCombinerCBTy genCombinerCB; + const auto &combinerExpression = + std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t) + .value(); + const parser::OmpStylizedInstance &combinerInstance = + combinerExpression.v.front(); + const parser::OmpStylizedInstance::Instance &instance = + std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t); + + std::optional<semantics::SomeExpr> evalExprOpt; + if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) { + auto &expr = std::get<parser::Expr>(as->t); + evalExprOpt = makeExpr(expr, semaCtx); + } else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) { + if (call->typedCall) { + const auto &procRef = *call->typedCall; + evalExprOpt = semantics::SomeExpr{procRef}; + } else { + TODO(converter.getCurrentLocation(), + "CallStmt without typedCall is not yet supported"); + } + } else { + TODO(converter.getCurrentLocation(), "Unsupported combiner instance type"); + } + + assert(evalExprOpt.has_value() && "evalExpr must be initialized"); + semantics::SomeExpr evalExpr = *evalExprOpt; + + genCombinerCB = [&, evalExpr](fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type type, mlir::Value lhs, + mlir::Value rhs, bool isByRef) { + lower::SymMapScope scope(symTable); + const std::list<parser::OmpStylizedDeclaration> &declList = + std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t); + mlir::Value ompOutVar; + for (const parser::OmpStylizedDeclaration &decl : declList) { + auto &name = std::get<parser::ObjectName>(decl.var.t); + mlir::Value addr = lhs; + mlir::Type type = lhs.getType(); + bool isRhs = name.ToString() == std::string("omp_in"); + if (isRhs) { + addr = rhs; + type = rhs.getType(); + } + + assert(name.symbol && "Reduction object name does not have a symbol"); + if (!fir::conformsWithPassByRef(type)) { + addr = builder.createTemporary(loc, type); + fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr); + } + fir::FortranVariableFlagsEnum extraFlags = {}; + fir::FortranVariableFlagsAttr attributes = + Fortran::lower::translateSymbolAttributes(builder.getContext(), + *name.symbol, extraFlags); + auto declareOp = + hlfir::DeclareOp::create(builder, loc, addr, name.ToString(), nullptr, + {}, nullptr, nullptr, 0, attributes); + if (name.ToString() == "omp_out") + ompOutVar = declareOp.getResult(0); + symTable.addVariableDefinition(*name.symbol, declareOp); + } + + lower::StatementContext stmtCtx; + mlir::Value result = common::visit( + common::visitors{ + [&](const evaluate::ProcedureRef &procRef) -> mlir::Value { + convertCallToHLFIR(loc, converter, procRef, std::nullopt, + symTable, stmtCtx); + auto outVal = fir::LoadOp::create(builder, loc, ompOutVar); + return outVal; + }, + [&](const auto &expr) -> mlir::Value { + mlir::Value exprResult = fir::getBase(convertExprToValue( + loc, converter, evalExpr, symTable, stmtCtx)); + // Optional load may be generated if we get a reference to the + // reduction type. + if (auto refType = + llvm::dyn_cast<fir::ReferenceType>(exprResult.getType())) + if (lhs.getType() == refType.getElementType()) + exprResult = fir::LoadOp::create(builder, loc, exprResult); + return exprResult; + }}, + evalExpr.u); + stmtCtx.finalizeAndPop(); + if (isByRef) { + fir::StoreOp::create(builder, loc, result, lhs); + mlir::omp::YieldOp::create(builder, loc, lhs); + } else { + mlir::omp::YieldOp::create(builder, loc, result); + } + }; + return genCombinerCB; +} + +// Checks that the reduction type is either a trivial type or a derived type of +// trivial types. +static bool isSimpleReductionType(mlir::Type reductionType) { + if (fir::isa_trivial(reductionType)) + return true; + if (auto recordTy = mlir::dyn_cast<fir::RecordType>(reductionType)) { + for (auto [_, fieldType] : recordTy.getTypeList()) { + if (!fir::isa_trivial(fieldType)) + return false; + } + } + return true; +} + +// Getting the type from a symbol compared to a DeclSpec is simpler since we do +// not need to consider derived vs intrinsic types. Semantics is guaranteed to +// generate these symbols. +static mlir::Type +getReductionType(lower::AbstractConverter &converter, + const parser::OmpReductionSpecifier &specifier) { + const auto &combinerExpression = + std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t) + .value(); + const parser::OmpStylizedInstance &combinerInstance = + combinerExpression.v.front(); + const std::list<parser::OmpStylizedDeclaration> &declList = + std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t); + const parser::OmpStylizedDeclaration &decl = declList.front(); + const auto &name = std::get<parser::ObjectName>(decl.var.t); + const auto &symbol = semantics::SymbolRef(*name.symbol); + mlir::Type reductionType = converter.genType(symbol); + + if (!isSimpleReductionType(reductionType)) + TODO(converter.getCurrentLocation(), + "declare reduction currently only supports trival types or derived " + "types containing trivial types"); + return reductionType; +} + static void genOMP( lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) { - if (!semaCtx.langOptions().OpenMPSimd) - TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct"); + if (semaCtx.langOptions().OpenMPSimd) + return; + + const parser::OmpArgumentList &args{declareReductionConstruct.v.Arguments()}; + const parser::OmpArgument &arg{args.v.front()}; + const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u); + + if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1) + TODO(converter.getCurrentLocation(), + "multiple types in declare reduction is not yet supported"); + + mlir::Type reductionType = getReductionType(converter, specifier); + ReductionProcessor::GenCombinerCBTy genCombinerCB = + processReductionCombiner(converter, symTable, semaCtx, specifier); + const parser::OmpClauseList &initializer = + declareReductionConstruct.v.Clauses(); + if (initializer.v.size() > 0) { + List<Clause> clauses = makeClauses(initializer, semaCtx); + ReductionProcessor::GenInitValueCBTy genInitValueCB; + ClauseProcessor cp(converter, semaCtx, clauses); + const parser::OmpClause::Initializer &iclause{ + std::get<parser::OmpClause::Initializer>(initializer.v.front().u)}; + cp.processInitializer(symTable, iclause, genInitValueCB); + const auto &identifier = + std::get<parser::OmpReductionIdentifier>(specifier.t); + const auto &designator = + std::get<parser::ProcedureDesignator>(identifier.u); + const auto &reductionName = std::get<parser::Name>(designator.u); + bool isByRef = ReductionProcessor::doReductionByRef(reductionType); + ReductionProcessor::createDeclareReductionHelper< + mlir::omp::DeclareReductionOp>( + converter, reductionName.ToString(), reductionType, + converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB); + } else { + TODO(converter.getCurrentLocation(), + "declare reduction without an initializer clause is not yet " + "supported"); + } } static void @@ -3543,10 +3788,10 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, TODO(converter.getCurrentLocation(), "OpenMPDeclareSimdConstruct"); } -static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, - lower::pft::Evaluation &eval, - const parser::OpenMPDeclareMapperConstruct &construct) { +static void genOpenMPDeclareMapperImpl( + lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, + const parser::OpenMPDeclareMapperConstruct &construct, + const semantics::Symbol *mapperSymOpt = nullptr) { mlir::Location loc = converter.genLocation(construct.source); fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); const parser::OmpArgumentList &args = construct.v.Arguments(); @@ -3562,8 +3807,17 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, "Expected derived type"); std::string mapperNameStr = mapperName; - if (auto *sym = converter.getCurrentScope().FindSymbol(mapperNameStr)) + if (mapperSymOpt && mapperNameStr != "default") { + mapperNameStr = converter.mangleName(mapperNameStr, mapperSymOpt->owner()); + } else if (auto *sym = + converter.getCurrentScope().FindSymbol(mapperNameStr)) { mapperNameStr = converter.mangleName(mapperNameStr, sym->owner()); + } + + // If the mapper op already exists (e.g., created by regular lowering or by + // materialization of imported mappers), do not recreate it. + if (converter.getModuleOp().lookupSymbol(mapperNameStr)) + return; // Save current insertion point before moving to the module scope to create // the DeclareMapperOp @@ -3586,6 +3840,13 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseOps.mapVars); } +static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, + lower::pft::Evaluation &eval, + const parser::OpenMPDeclareMapperConstruct &construct) { + genOpenMPDeclareMapperImpl(converter, semaCtx, construct); +} + static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, @@ -3902,14 +4163,6 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, - const parser::OpenMPExecutableAllocate &execAllocConstruct) { - if (!semaCtx.langOptions().OpenMPSimd) - TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate"); -} - -static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, - semantics::SemanticsContext &semaCtx, - lower::pft::Evaluation &eval, const parser::OpenMPLoopConstruct &loopConstruct) { const parser::OmpDirectiveSpecification &beginSpec = loopConstruct.BeginDir(); List<Clause> clauses = makeClauses(beginSpec.Clauses(), semaCtx); @@ -3918,12 +4171,9 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, mlir::Location currentLocation = converter.genLocation(beginSpec.source); - auto &optLoopCons = - std::get<std::optional<parser::NestedConstruct>>(loopConstruct.t); - if (optLoopCons.has_value()) { - if (auto *ompNestedLoopCons{ - std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>( - &*optLoopCons)}) { + for (auto &construct : std::get<parser::Block>(loopConstruct.t)) { + if (const parser::OpenMPLoopConstruct *ompNestedLoopCons = + parser::omp::GetOmpLoop(construct)) { llvm::omp::Directive nestedDirective = parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v; switch (nestedDirective) { @@ -4229,3 +4479,36 @@ void Fortran::lower::genOpenMPRequires(mlir::Operation *mod, offloadMod.setRequires(mlirFlags); } } + +// Walk scopes and materialize omp.declare_mapper ops for mapper declarations +// found in imported modules. If \p scope is null, start from the global scope. +void Fortran::lower::materializeOpenMPDeclareMappers( + Fortran::lower::AbstractConverter &converter, + semantics::SemanticsContext &semaCtx, const semantics::Scope *scope) { + const semantics::Scope &root = scope ? *scope : semaCtx.globalScope(); + + // Recurse into child scopes first (modules, submodules, etc.). + for (const semantics::Scope &child : root.children()) + materializeOpenMPDeclareMappers(converter, semaCtx, &child); + + // Only consider module scopes to avoid duplicating local constructs. + if (!root.IsModule()) + return; + + // Only materialize for modules coming from mod files to avoid duplicates. + if (!root.symbol() || !root.symbol()->test(semantics::Symbol::Flag::ModFile)) + return; + + // Scan symbols in this module scope for MapperDetails. + for (auto &it : root) { + const semantics::Symbol &sym = *it.second; + if (auto *md = sym.detailsIf<semantics::MapperDetails>()) { + for (const auto *decl : md->GetDeclList()) { + if (const auto *mapperDecl = + std::get_if<parser::OpenMPDeclareMapperConstruct>(&decl->u)) { + genOpenMPDeclareMapperImpl(converter, semaCtx, *mapperDecl, &sym); + } + } + } + } +} diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index 6487f59..a818d63 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -14,22 +14,28 @@ #include "ClauseFinder.h" #include "flang/Evaluate/fold.h" +#include "flang/Evaluate/tools.h" #include <flang/Lower/AbstractConverter.h> #include <flang/Lower/ConvertType.h> #include <flang/Lower/DirectivesCommon.h> #include <flang/Lower/OpenMP/Clauses.h> #include <flang/Lower/PFTBuilder.h> #include <flang/Lower/Support/PrivateReductionUtils.h> +#include <flang/Optimizer/Builder/BoxValue.h> #include <flang/Optimizer/Builder/FIRBuilder.h> #include <flang/Optimizer/Builder/Todo.h> +#include <flang/Optimizer/HLFIR/HLFIROps.h> #include <flang/Parser/openmp-utils.h> #include <flang/Parser/parse-tree.h> #include <flang/Parser/tools.h> #include <flang/Semantics/tools.h> #include <flang/Semantics/type.h> #include <flang/Utils/OpenMP.h> +#include <llvm/ADT/SmallPtrSet.h> +#include <llvm/ADT/StringRef.h> #include <llvm/Support/CommandLine.h> +#include <functional> #include <iterator> template <typename T> @@ -61,6 +67,142 @@ namespace Fortran { namespace lower { namespace omp { +mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper( + lower::AbstractConverter &converter, mlir::Location loc, + fir::RecordType recordType, llvm::StringRef mapperNameStr) { + if (mapperNameStr.empty()) + return {}; + + if (converter.getModuleOp().lookupSymbol(mapperNameStr)) + return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), + mapperNameStr); + + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::OpBuilder::InsertionGuard guard(firOpBuilder); + + firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody()); + auto declMapperOp = mlir::omp::DeclareMapperOp::create( + firOpBuilder, loc, mapperNameStr, recordType); + auto ®ion = declMapperOp.getRegion(); + firOpBuilder.createBlock(®ion); + auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc); + + auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg, + /*uniq_name=*/""); + + const auto genBoundsOps = [&](mlir::Value mapVal, + llvm::SmallVectorImpl<mlir::Value> &bounds) { + fir::ExtendedValue extVal = + hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder, + hlfir::Entity{mapVal}, + /*contiguousHint=*/true) + .first; + fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr( + firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc()); + bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp, + mlir::omp::MapBoundsType>( + firOpBuilder, info, extVal, + /*dataExvIsAssumedSize=*/false, mapVal.getLoc()); + }; + + const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName, + mlir::Type fieldTy, mlir::Type recType) { + mlir::Value field = fir::FieldIndexOp::create( + firOpBuilder, loc, fir::FieldType::get(recType.getContext()), fieldName, + recType, fir::getTypeParams(rec)); + return fir::CoordinateOp::create( + firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field); + }; + + llvm::SmallVector<mlir::Value> clauseMapVars; + llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices; + llvm::SmallVector<mlir::Value> memberMapOps; + + mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to | + mlir::omp::ClauseMapFlags::from | + mlir::omp::ClauseMapFlags::implicit; + mlir::omp::VariableCaptureKind captureKind = + mlir::omp::VariableCaptureKind::ByRef; + + for (const auto &entry : llvm::enumerate(recordType.getTypeList())) { + const auto &memberName = entry.value().first; + const auto &memberType = entry.value().second; + mlir::FlatSymbolRefAttr mapperId; + if (auto recType = mlir::dyn_cast<fir::RecordType>( + fir::getFortranElementType(memberType))) { + std::string mapperIdName = + recType.getName().str() + llvm::omp::OmpDefaultMapperName; + if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName)) + mapperIdName = converter.mangleName(mapperIdName, sym->owner()); + else if (auto *memberSym = + converter.getCurrentScope().FindSymbol(memberName)) + mapperIdName = converter.mangleName(mapperIdName, memberSym->owner()); + + mapperId = getOrGenImplicitDefaultDeclareMapper(converter, loc, recType, + mapperIdName); + } + + auto ref = + getFieldRef(declareOp.getBase(), memberName, memberType, recordType); + llvm::SmallVector<mlir::Value> bounds; + genBoundsOps(ref, bounds); + mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp( + firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"", + bounds, + /*members=*/{}, + /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(), + /*partialMap=*/false, mapperId); + memberMapOps.emplace_back(mapOp); + memberPlacementIndices.emplace_back( + llvm::SmallVector<int64_t>{(int64_t)entry.index()}); + } + + llvm::SmallVector<mlir::Value> bounds; + genBoundsOps(declareOp.getOriginalBase(), bounds); + mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit; + mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp( + firOpBuilder, loc, declareOp.getOriginalBase(), + /*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps, + firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag, + captureKind, declareOp.getType(0), + /*partialMap=*/true); + + clauseMapVars.emplace_back(mapOp); + mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars); + return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), + mapperNameStr); +} + +bool requiresImplicitDefaultDeclareMapper( + const semantics::DerivedTypeSpec &typeSpec) { + // ISO C interoperable types (e.g., c_ptr, c_funptr) must always have implicit + // default mappers available so that OpenMP offloading can correctly map them. + if (semantics::IsIsoCType(&typeSpec)) + return true; + + llvm::SmallPtrSet<const semantics::DerivedTypeSpec *, 8> visited; + + std::function<bool(const semantics::DerivedTypeSpec &)> requiresMapper = + [&](const semantics::DerivedTypeSpec &spec) -> bool { + if (!visited.insert(&spec).second) + return false; + + semantics::DirectComponentIterator directComponents{spec}; + for (const semantics::Symbol &component : directComponents) { + if (component.attrs().test(semantics::Attr::ALLOCATABLE)) + return true; + + if (const semantics::DeclTypeSpec *declType = component.GetType()) + if (const auto *nested = declType->AsDerived()) + if (requiresMapper(*nested)) + return true; + } + return false; + }; + + return requiresMapper(typeSpec); +} + int64_t getCollapseValue(const List<Clause> &clauses) { auto iter = llvm::find_if(clauses, [](const Clause &clause) { return clause.id == llvm::omp::Clause::OMPC_collapse; @@ -537,6 +679,12 @@ void insertChildMapInfoIntoParent( mapOperands[std::distance(mapSyms.begin(), parentIter)] .getDefiningOp()); + // Once explicit members are attached to a parent map, do not also invoke + // a declare mapper on it, otherwise the mapper would remap the same + // components leading to duplicate mappings at runtime. + if (!indices.second.memberMap.empty() && mapOp.getMapperIdAttr()) + mapOp.setMapperIdAttr(nullptr); + // NOTE: To maintain appropriate SSA ordering, we move the parent map // which will now have references to its children after the last // of its members to be generated. This is necessary when a user @@ -631,17 +779,9 @@ static void processTileSizesFromOpenMPConstruct( if (!ompCons) return; if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) { - const auto &nestedOptional = - std::get<std::optional<parser::NestedConstruct>>(ompLoop->t); - assert(nestedOptional.has_value() && - "Expected a DoConstruct or OpenMPLoopConstruct"); - const auto *innerConstruct = - std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>( - &(nestedOptional.value())); - if (innerConstruct) { - const auto &innerLoopDirective = innerConstruct->value(); + if (auto *innerConstruct = ompLoop->GetNestedConstruct()) { const parser::OmpDirectiveSpecification &innerBeginSpec = - innerLoopDirective.BeginDir(); + innerConstruct->BeginDir(); if (innerBeginSpec.DirId() == llvm::omp::Directive::OMPD_tile) { // Get the size values from parse tree and convert to a vector. for (const auto &clause : innerBeginSpec.Clauses().v) { @@ -656,6 +796,28 @@ static void processTileSizesFromOpenMPConstruct( } } +pft::Evaluation *getNestedDoConstruct(pft::Evaluation &eval) { + for (pft::Evaluation &nested : eval.getNestedEvaluations()) { + // In an OpenMPConstruct there can be compiler directives: + // 1 <<OpenMPConstruct>> + // 2 CompilerDirective: !unroll + // <<DoConstruct>> -> 8 + if (nested.getIf<parser::CompilerDirective>()) + continue; + // Within a DoConstruct, there can be compiler directives, plus + // there is a DoStmt before the body: + // <<DoConstruct>> -> 8 + // 3 NonLabelDoStmt -> 7: do i = 1, n + // <<DoConstruct>> -> 7 + if (nested.getIf<parser::NonLabelDoStmt>()) + continue; + assert(nested.getIf<parser::DoConstruct>() && + "Unexpected construct in the nested evaluations"); + return &nested; + } + llvm_unreachable("Expected do loop to be in the nested evaluations"); +} + /// Populates the sizes vector with values if the given OpenMPConstruct /// contains a loop construct with an inner tiling construct. void collectTileSizesFromOpenMPConstruct( @@ -678,7 +840,7 @@ int64_t collectLoopRelatedInfo( int64_t numCollapse = 1; // Collect the loops to collapse. - lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation(); + lower::pft::Evaluation *doConstructEval = getNestedDoConstruct(eval); if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) { TODO(currentLocation, "Do Concurrent in Worksharing loop construct"); } @@ -704,7 +866,7 @@ void collectLoopRelatedInfo( fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); // Collect the loops to collapse. - lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation(); + lower::pft::Evaluation *doConstructEval = getNestedDoConstruct(eval); if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) { TODO(currentLocation, "Do Concurrent in Worksharing loop construct"); } @@ -745,9 +907,8 @@ void collectLoopRelatedInfo( iv.push_back(bounds->name.thing.symbol); loopVarTypeSize = std::max(loopVarTypeSize, bounds->name.thing.symbol->GetUltimate().size()); - collapseValue--; - doConstructEval = - &*std::next(doConstructEval->getNestedEvaluations().begin()); + if (--collapseValue) + doConstructEval = getNestedDoConstruct(*doConstructEval); } while (collapseValue > 0); convertLoopBounds(converter, currentLocation, result, loopVarTypeSize); diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h index ef1f37a..8a68ff8 100644 --- a/flang/lib/Lower/OpenMP/Utils.h +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -20,6 +20,7 @@ extern llvm::cl::opt<bool> treatIndexAsSection; namespace fir { class FirOpBuilder; +class RecordType; } // namespace fir namespace Fortran { @@ -136,6 +137,13 @@ mlir::Value createParentSymAndGenIntermediateMaps( OmpMapParentAndMemberData &parentMemberIndices, llvm::StringRef asFortran, mlir::omp::ClauseMapFlags mapTypeBits); +mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper( + Fortran::lower::AbstractConverter &converter, mlir::Location loc, + fir::RecordType recordType, llvm::StringRef mapperNameStr); + +bool requiresImplicitDefaultDeclareMapper( + const semantics::DerivedTypeSpec &typeSpec); + omp::ObjectList gatherObjectsOf(omp::Object derivedTypeMember, semantics::SemanticsContext &semaCtx); @@ -159,6 +167,8 @@ void genObjectList(const ObjectList &objects, void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp, mlir::Location loc); +pft::Evaluation *getNestedDoConstruct(pft::Evaluation &eval); + int64_t collectLoopRelatedInfo( lower::AbstractConverter &converter, mlir::Location currentLocation, lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses, diff --git a/flang/lib/Lower/Runtime.cpp b/flang/lib/Lower/Runtime.cpp index cb55524..5f8586b 100644 --- a/flang/lib/Lower/Runtime.cpp +++ b/flang/lib/Lower/Runtime.cpp @@ -48,31 +48,6 @@ static void genUnreachable(fir::FirOpBuilder &builder, mlir::Location loc) { builder.setInsertionPointToStart(newBlock); } -/// Initializes values for STAT and ERRMSG -static std::pair<mlir::Value, mlir::Value> getStatAndErrmsg( - Fortran::lower::AbstractConverter &converter, mlir::Location loc, - const std::list<Fortran::parser::StatOrErrmsg> &statOrErrList) { - Fortran::lower::StatementContext stmtCtx; - - mlir::Value errMsgExpr, statExpr; - for (const Fortran::parser::StatOrErrmsg &statOrErr : statOrErrList) { - std::visit(Fortran::common::visitors{ - [&](const Fortran::parser::StatVariable &statVar) { - statExpr = fir::getBase(converter.genExprAddr( - loc, Fortran::semantics::GetExpr(statVar), stmtCtx)); - }, - [&](const Fortran::parser::MsgVariable &errMsgVar) { - const Fortran::semantics::SomeExpr *expr = - Fortran::semantics::GetExpr(errMsgVar); - errMsgExpr = fir::getBase( - converter.genExprBox(loc, *expr, stmtCtx)); - }}, - statOrErr.u); - } - - return {statExpr, errMsgExpr}; -} - //===----------------------------------------------------------------------===// // Misc. Fortran statements that lower to runtime calls //===----------------------------------------------------------------------===// @@ -115,8 +90,7 @@ void Fortran::lower::genStopStatement( operands.push_back(cast); }, [&](auto) { - mlir::emitError(loc, "unhandled expression in STOP"); - std::exit(1); + fir::emitFatalError(loc, "unhandled expression in STOP"); }); } else { callee = fir::runtime::getRuntimeFunc<mkRTKey(StopStatement)>(loc, builder); @@ -193,82 +167,57 @@ void Fortran::lower::genUnlockStatement( TODO(converter.getCurrentLocation(), "coarray: UNLOCK runtime"); } -void Fortran::lower::genSyncAllStatement( +void Fortran::lower::genPauseStatement( Fortran::lower::AbstractConverter &converter, - const Fortran::parser::SyncAllStmt &stmt) { - mlir::Location loc = converter.getCurrentLocation(); - converter.checkCoarrayEnabled(); - - // Handle STAT and ERRMSG values - const std::list<Fortran::parser::StatOrErrmsg> &statOrErrList = stmt.v; - auto [statAddr, errMsgAddr] = getStatAndErrmsg(converter, loc, statOrErrList); + const Fortran::parser::PauseStmt &stmt) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); - mif::SyncAllOp::create(builder, loc, statAddr, errMsgAddr); -} - -void Fortran::lower::genSyncImagesStatement( - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::SyncImagesStmt &stmt) { mlir::Location loc = converter.getCurrentLocation(); - converter.checkCoarrayEnabled(); - fir::FirOpBuilder &builder = converter.getFirOpBuilder(); - - // Handle STAT and ERRMSG values - const std::list<Fortran::parser::StatOrErrmsg> &statOrErrList = - std::get<std::list<Fortran::parser::StatOrErrmsg>>(stmt.t); - auto [statAddr, errMsgAddr] = getStatAndErrmsg(converter, loc, statOrErrList); - - // SYNC_IMAGES(*) is passed as count == -1 while SYNC IMAGES([]) has count - // == 0. Note further that SYNC IMAGES(*) is not semantically equivalent to - // SYNC ALL. Fortran::lower::StatementContext stmtCtx; - mlir::Value imageSet; - const Fortran::parser::SyncImagesStmt::ImageSet &imgSet = - std::get<Fortran::parser::SyncImagesStmt::ImageSet>(stmt.t); - std::visit(Fortran::common::visitors{ - [&](const Fortran::parser::IntExpr &intExpr) { - const SomeExpr *expr = Fortran::semantics::GetExpr(intExpr); - imageSet = - fir::getBase(converter.genExprBox(loc, *expr, stmtCtx)); - }, - [&](const Fortran::parser::Star &) { - // Image set is not set. - imageSet = mlir::Value{}; - }}, - imgSet.u); - - mif::SyncImagesOp::create(builder, loc, imageSet, statAddr, errMsgAddr); -} -void Fortran::lower::genSyncMemoryStatement( - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::SyncMemoryStmt &stmt) { - mlir::Location loc = converter.getCurrentLocation(); - converter.checkCoarrayEnabled(); + llvm::SmallVector<mlir::Value> operands; + mlir::func::FuncOp callee; + mlir::FunctionType calleeType; - // Handle STAT and ERRMSG values - const std::list<Fortran::parser::StatOrErrmsg> &statOrErrList = stmt.v; - auto [statAddr, errMsgAddr] = getStatAndErrmsg(converter, loc, statOrErrList); + if (stmt.v.has_value()) { + const auto &code = stmt.v.value(); + auto expr = + converter.genExprValue(*Fortran::semantics::GetExpr(code), stmtCtx); + expr.match( + // Character-valued expression -> call PauseStatementText (CHAR, LEN) + [&](const fir::CharBoxValue &x) { + callee = fir::runtime::getRuntimeFunc<mkRTKey(PauseStatementText)>( + loc, builder); + calleeType = callee.getFunctionType(); - fir::FirOpBuilder &builder = converter.getFirOpBuilder(); - mif::SyncMemoryOp::create(builder, loc, statAddr, errMsgAddr); -} + operands.push_back( + builder.createConvert(loc, calleeType.getInput(0), x.getAddr())); + operands.push_back( + builder.createConvert(loc, calleeType.getInput(1), x.getLen())); + }, + // Unboxed value -> call PauseStatementInt which accepts an integer. + [&](fir::UnboxedValue x) { + callee = fir::runtime::getRuntimeFunc<mkRTKey(PauseStatementInt)>( + loc, builder); + calleeType = callee.getFunctionType(); + assert(calleeType.getNumInputs() >= 1); + mlir::Value cast = + builder.createConvert(loc, calleeType.getInput(0), x); + operands.push_back(cast); + }, + [&](auto) { + fir::emitFatalError(loc, "unhandled expression in PAUSE"); + }); + } else { + callee = + fir::runtime::getRuntimeFunc<mkRTKey(PauseStatement)>(loc, builder); + calleeType = callee.getFunctionType(); + } -void Fortran::lower::genSyncTeamStatement( - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::SyncTeamStmt &) { - TODO(converter.getCurrentLocation(), "coarray: SYNC TEAM runtime"); -} + fir::CallOp::create(builder, loc, callee, operands); -void Fortran::lower::genPauseStatement( - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::PauseStmt &) { - fir::FirOpBuilder &builder = converter.getFirOpBuilder(); - mlir::Location loc = converter.getCurrentLocation(); - mlir::func::FuncOp callee = - fir::runtime::getRuntimeFunc<mkRTKey(PauseStatement)>(loc, builder); - fir::CallOp::create(builder, loc, callee, mlir::ValueRange{}); + // NOTE: PAUSE does not terminate the current block. The program may resume + // and continue normal execution, so we do not emit control-flow terminators. } void Fortran::lower::genPointerAssociate(fir::FirOpBuilder &builder, diff --git a/flang/lib/Lower/Support/ReductionProcessor.cpp b/flang/lib/Lower/Support/ReductionProcessor.cpp index 605a5b6b..db8ad90 100644 --- a/flang/lib/Lower/Support/ReductionProcessor.cpp +++ b/flang/lib/Lower/Support/ReductionProcessor.cpp @@ -501,7 +501,7 @@ static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) { template <typename OpType> static void createReductionAllocAndInitRegions( AbstractConverter &converter, mlir::Location loc, OpType &reductionDecl, - const ReductionProcessor::ReductionIdentifier redId, mlir::Type type, + ReductionProcessor::GenInitValueCBTy genInitValueCB, mlir::Type type, bool isByRef) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); auto yield = [&](mlir::Value ret) { genYield<OpType>(builder, loc, ret); }; @@ -523,9 +523,8 @@ static void createReductionAllocAndInitRegions( mlir::Type ty = fir::unwrapRefType(type); builder.setInsertionPointToEnd(initBlock); - mlir::Value initValue = ReductionProcessor::getReductionInitValue( - loc, unwrapSeqOrBoxedType(ty), redId, builder); - + mlir::Value initValue = + genInitValueCB(builder, loc, ty, initBlock->getArgument(0)); if (isByRef) { populateByRefInitAndCleanupRegions( converter, loc, type, initValue, initBlock, @@ -536,7 +535,7 @@ static void createReductionAllocAndInitRegions( /*isDoConcurrent*/ std::is_same_v<OpType, fir::DeclareReductionOp>); } - if (fir::isa_trivial(ty)) { + if (fir::isa_trivial(ty) || fir::isa_derived(ty)) { if (isByRef) { // alloc region builder.setInsertionPointToEnd(allocBlock); @@ -556,43 +555,117 @@ static void createReductionAllocAndInitRegions( yield(boxAlloca); } -template <typename OpType> -OpType ReductionProcessor::createDeclareReduction( +template <typename DeclareRedType> +DeclareRedType ReductionProcessor::createDeclareReductionHelper( AbstractConverter &converter, llvm::StringRef reductionOpName, - const ReductionIdentifier redId, mlir::Type type, mlir::Location loc, - bool isByRef) { + mlir::Type type, mlir::Location loc, bool isByRef, + GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB) { fir::FirOpBuilder &builder = converter.getFirOpBuilder(); mlir::OpBuilder::InsertionGuard guard(builder); mlir::ModuleOp module = builder.getModule(); assert(!reductionOpName.empty()); - auto decl = module.lookupSymbol<OpType>(reductionOpName); + auto decl = module.lookupSymbol<DeclareRedType>(reductionOpName); if (decl) return decl; mlir::OpBuilder modBuilder(module.getBodyRegion()); mlir::Type valTy = fir::unwrapRefType(type); - if (!isByRef) + + // For by-ref reductions, we want to keep track of the + // boxed/referenced/allocated type. For example, for a `real, allocatable` + // variable, `real` should be stored. + mlir::TypeAttr boxedTyAttr{}; + mlir::Type boxedTy; + + if (isByRef) { + boxedTy = fir::unwrapPassByRefType(valTy); + boxedTyAttr = mlir::TypeAttr::get(boxedTy); + } else type = valTy; - decl = OpType::create(modBuilder, loc, reductionOpName, type); - createReductionAllocAndInitRegions(converter, loc, decl, redId, type, + decl = DeclareRedType::create(modBuilder, loc, reductionOpName, type, + boxedTyAttr); + createReductionAllocAndInitRegions(converter, loc, decl, genInitValueCB, type, isByRef); - builder.createBlock(&decl.getReductionRegion(), decl.getReductionRegion().end(), {type, type}, {loc, loc}); - builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); - genCombiner<OpType>(builder, loc, redId, type, op1, op2, isByRef); + genCombinerCB(builder, loc, type, op1, op2, isByRef); + + if (isByRef && fir::isa_box_type(valTy)) { + bool isBoxReductionSupported = [&]() { + auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>( + *builder.getModule()); + + // This check tests the implementation status on the GPU. Box reductions + // are fully supported on the CPU. + if (!offloadMod.getIsGPU()) + return true; + + auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxedTy); + + // Dynamically-shaped arrays are not supported yet on the GPU. + return !seqTy || !fir::sequenceWithNonConstantShape(seqTy); + }(); + + if (!isBoxReductionSupported) { + TODO(loc, "Reduction of dynamically-shaped arrays are not supported yet " + "on the GPU."); + } + + mlir::Region &dataPtrPtrRegion = decl.getDataPtrPtrRegion(); + mlir::Block &dataAddrBlock = *builder.createBlock( + &dataPtrPtrRegion, dataPtrPtrRegion.end(), {type}, {loc}); + builder.setInsertionPointToEnd(&dataAddrBlock); + mlir::Value boxRefOperand = dataAddrBlock.getArgument(0); + mlir::Value baseAddrOffset = fir::BoxOffsetOp::create( + builder, loc, boxRefOperand, fir::BoxFieldAttr::base_addr); + genYield<DeclareRedType>(builder, loc, baseAddrOffset); + } return decl; } -static bool doReductionByRef(mlir::Value reductionVar) { +template <typename OpType> +OpType ReductionProcessor::createDeclareReduction( + AbstractConverter &converter, llvm::StringRef reductionOpName, + const ReductionIdentifier redId, mlir::Type type, mlir::Location loc, + bool isByRef) { + auto genInitValueCB = [&](fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type type, mlir::Value val) { + mlir::Type ty = fir::unwrapRefType(type); + mlir::Value initValue = ReductionProcessor::getReductionInitValue( + loc, unwrapSeqOrBoxedType(ty), redId, builder); + return initValue; + }; + auto genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type type, mlir::Value op1, mlir::Value op2, + bool isByRef) { + genCombiner<OpType>(builder, loc, redId, type, op1, op2, isByRef); + }; + + return createDeclareReductionHelper<OpType>(converter, reductionOpName, type, + loc, isByRef, genCombinerCB, + genInitValueCB); +} + +bool ReductionProcessor::doReductionByRef(mlir::Type reductionType) { + if (forceByrefReduction) + return true; + + if (!fir::isa_trivial(fir::unwrapRefType(reductionType)) && + !fir::isa_derived(fir::unwrapRefType(reductionType))) + return true; + + return false; +} + +bool ReductionProcessor::doReductionByRef(mlir::Value reductionVar) { if (forceByrefReduction) return true; @@ -600,10 +673,7 @@ static bool doReductionByRef(mlir::Value reductionVar) { mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp())) reductionVar = declare.getMemref(); - if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType()))) - return true; - - return false; + return doReductionByRef(reductionVar.getType()); } template <typename OpType, typename RedOperatorListTy> @@ -614,6 +684,8 @@ bool ReductionProcessor::processReductionArguments( llvm::SmallVectorImpl<bool> &reduceVarByRef, llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) { + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + if constexpr (std::is_same_v<RedOperatorListTy, omp::clause::ReductionOperatorList>) { // For OpenMP reduction clauses, check if the reduction operator is @@ -627,7 +699,13 @@ bool ReductionProcessor::processReductionArguments( std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) { if (!ReductionProcessor::supportedIntrinsicProcReduction( *reductionIntrinsic)) { - return false; + // If not an intrinsic is has to be a custom reduction op, and should + // be available in the module. + semantics::Symbol *sym = reductionIntrinsic->v.sym(); + mlir::ModuleOp module = builder.getModule(); + auto decl = module.lookupSymbol<OpType>(getRealName(sym).ToString()); + if (!decl) + return false; } } else { return false; @@ -637,7 +715,6 @@ bool ReductionProcessor::processReductionArguments( // Reduction variable processing common to both intrinsic operators and // procedure designators - fir::FirOpBuilder &builder = converter.getFirOpBuilder(); mlir::OpBuilder::InsertPoint dcIP; constexpr bool isDoConcurrent = std::is_same_v<OpType, fir::DeclareReductionOp>; @@ -741,7 +818,13 @@ bool ReductionProcessor::processReductionArguments( &redOperator.u)) { if (!ReductionProcessor::supportedIntrinsicProcReduction( *reductionIntrinsic)) { - TODO(currentLocation, "Unsupported intrinsic proc reduction"); + // Custom reductions we can just add to the symbols without + // generating the declare reduction op. + semantics::Symbol *sym = reductionIntrinsic->v.sym(); + reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( + builder.getContext(), sym->name().ToString())); + ++idx; + continue; } redId = getReductionType(*reductionIntrinsic); reductionName = diff --git a/flang/lib/Lower/Support/Utils.cpp b/flang/lib/Lower/Support/Utils.cpp index 1b4d37e..4b95a3a 100644 --- a/flang/lib/Lower/Support/Utils.cpp +++ b/flang/lib/Lower/Support/Utils.cpp @@ -82,7 +82,7 @@ public: x.cosubscript()) cosubs -= getHashValue(v); return getHashValue(x.base()) * 97u - cosubs + getHashValue(x.stat()) + - 257u + getHashValue(x.team()); + 257u + getHashValue(x.team()) + getHashValue(x.notify()); } static unsigned getHashValue(const Fortran::evaluate::NamedEntity &x) { if (x.IsSymbol()) @@ -341,7 +341,8 @@ public: const Fortran::evaluate::CoarrayRef &y) { return isEqual(x.base(), y.base()) && isEqual(x.cosubscript(), y.cosubscript()) && - isEqual(x.stat(), y.stat()) && isEqual(x.team(), y.team()); + isEqual(x.stat(), y.stat()) && isEqual(x.team(), y.team()) && + isEqual(x.notify(), y.notify()); } static bool isEqual(const Fortran::evaluate::NamedEntity &x, const Fortran::evaluate::NamedEntity &y) { |
