diff options
Diffstat (limited to 'flang/lib/Lower/OpenMP.cpp')
-rw-r--r-- | flang/lib/Lower/OpenMP.cpp | 698 |
1 files changed, 378 insertions, 320 deletions
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index ad4cffc..06850be 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -552,8 +552,9 @@ class ClauseProcessor { public: ClauseProcessor(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, const Fortran::parser::OmpClauseList &clauses) - : converter(converter), clauses(clauses) {} + : converter(converter), semaCtx(semaCtx), clauses(clauses) {} // 'Unique' clauses: They can appear at most once in the clause list. bool @@ -614,17 +615,18 @@ public: // target directives that require it. bool processMap(mlir::Location currentLocation, const llvm::omp::Directive &directive, - Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::StatementContext &stmtCtx, llvm::SmallVectorImpl<mlir::Value> &mapOperands, llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr, llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols = nullptr) const; - bool processReduction( - mlir::Location currentLocation, - llvm::SmallVectorImpl<mlir::Value> &reductionVars, - llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const; + bool + processReduction(mlir::Location currentLocation, + llvm::SmallVectorImpl<mlir::Value> &reductionVars, + llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> + *reductionSymbols = nullptr) const; bool processSectionsReduction(mlir::Location currentLocation) const; bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; bool @@ -641,10 +643,8 @@ public: &useDeviceSymbols) const; template <typename T> - bool - processMotionClauses(Fortran::semantics::SemanticsContext &semanticsContext, - Fortran::lower::StatementContext &stmtCtx, - llvm::SmallVectorImpl<mlir::Value> &mapOperands); + bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx, + llvm::SmallVectorImpl<mlir::Value> &mapOperands); // Call this method for these clauses that should be supported but are not // implemented yet. It triggers a compilation error if any of the given @@ -713,6 +713,7 @@ private: } Fortran::lower::AbstractConverter &converter; + Fortran::semantics::SemanticsContext &semaCtx; const Fortran::parser::OmpClauseList &clauses; }; @@ -731,21 +732,59 @@ static void checkMapType(mlir::Location location, mlir::Type type) { class ReductionProcessor { public: - enum IntrinsicProc { MAX, MIN, IAND, IOR, IEOR }; - static IntrinsicProc + // TODO: Move this enumeration to the OpenMP dialect + enum ReductionIdentifier { + ID, + USER_DEF_OP, + ADD, + SUBTRACT, + MULTIPLY, + AND, + OR, + EQV, + NEQV, + MAX, + MIN, + IAND, + IOR, + IEOR + }; + static ReductionIdentifier getReductionType(const Fortran::parser::ProcedureDesignator &pd) { - auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>( + auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>( getRealName(pd).ToString()) - .Case("max", IntrinsicProc::MAX) - .Case("min", IntrinsicProc::MIN) - .Case("iand", IntrinsicProc::IAND) - .Case("ior", IntrinsicProc::IOR) - .Case("ieor", IntrinsicProc::IEOR) + .Case("max", ReductionIdentifier::MAX) + .Case("min", ReductionIdentifier::MIN) + .Case("iand", ReductionIdentifier::IAND) + .Case("ior", ReductionIdentifier::IOR) + .Case("ieor", ReductionIdentifier::IEOR) .Default(std::nullopt); assert(redType && "Invalid Reduction"); return *redType; } + static ReductionIdentifier getReductionType( + Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) { + switch (intrinsicOp) { + case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + return ReductionIdentifier::ADD; + case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract: + return ReductionIdentifier::SUBTRACT; + case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + return ReductionIdentifier::MULTIPLY; + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + return ReductionIdentifier::AND; + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + return ReductionIdentifier::EQV; + case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + return ReductionIdentifier::OR; + case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + return ReductionIdentifier::NEQV; + default: + llvm_unreachable("unexpected intrinsic operator in reduction"); + } + } + static bool supportedIntrinsicProcReduction( const Fortran::parser::ProcedureDesignator &pd) { const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)}; @@ -753,17 +792,14 @@ public: if (!name->symbol->GetUltimate().attrs().test( Fortran::semantics::Attr::INTRINSIC)) return false; - auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>( - getRealName(name).ToString()) - .Case("max", IntrinsicProc::MAX) - .Case("min", IntrinsicProc::MIN) - .Case("iand", IntrinsicProc::IAND) - .Case("ior", IntrinsicProc::IOR) - .Case("ieor", IntrinsicProc::IEOR) - .Default(std::nullopt); - if (redType) - return true; - return false; + auto redType = llvm::StringSwitch<bool>(getRealName(name).ToString()) + .Case("max", true) + .Case("min", true) + .Case("iand", true) + .Case("ior", true) + .Case("ieor", true) + .Default(false); + return redType; } static const Fortran::semantics::SourceName @@ -817,32 +853,30 @@ public: /// reductionOpName. For example: /// 0 + x = x, /// 1 * x = x - static int getOperationIdentity( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - mlir::Location loc) { - switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + static int getOperationIdentity(ReductionIdentifier redId, + mlir::Location loc) { + switch (redId) { + case ReductionIdentifier::ADD: + case ReductionIdentifier::OR: + case ReductionIdentifier::NEQV: return 0; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + case ReductionIdentifier::MULTIPLY: + case ReductionIdentifier::AND: + case ReductionIdentifier::EQV: return 1; default: TODO(loc, "Reduction of some intrinsic operators is not supported"); } } - static mlir::Value getIntrinsicProcInitValue( - mlir::Location loc, mlir::Type type, - const Fortran::parser::ProcedureDesignator &procDesignator, - fir::FirOpBuilder &builder) { + static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type, + ReductionIdentifier redId, + fir::FirOpBuilder &builder) { assert((fir::isa_integer(type) || fir::isa_real(type) || type.isa<fir::LogicalType>()) && "only integer, logical and real types are currently supported"); - switch (getReductionType(procDesignator)) { - case IntrinsicProc::MAX: { + switch (redId) { + case ReductionIdentifier::MAX: { if (auto ty = type.dyn_cast<mlir::FloatType>()) { const llvm::fltSemantics &sem = ty.getFloatSemantics(); return builder.createRealConstant( @@ -852,7 +886,7 @@ public: int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); return builder.createIntegerConstant(loc, type, minInt); } - case IntrinsicProc::MIN: { + case ReductionIdentifier::MIN: { if (auto ty = type.dyn_cast<mlir::FloatType>()) { const llvm::fltSemantics &sem = ty.getFloatSemantics(); return builder.createRealConstant( @@ -862,46 +896,50 @@ public: int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); return builder.createIntegerConstant(loc, type, maxInt); } - case IntrinsicProc::IOR: { + case ReductionIdentifier::IOR: { unsigned bits = type.getIntOrFloatBitWidth(); int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); return builder.createIntegerConstant(loc, type, zeroInt); } - case IntrinsicProc::IEOR: { + case ReductionIdentifier::IEOR: { unsigned bits = type.getIntOrFloatBitWidth(); int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); return builder.createIntegerConstant(loc, type, zeroInt); } - case IntrinsicProc::IAND: { + case ReductionIdentifier::IAND: { unsigned bits = type.getIntOrFloatBitWidth(); int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); return builder.createIntegerConstant(loc, type, allOnInt); } - } - llvm_unreachable("Unknown Reduction Intrinsic"); - } + case ReductionIdentifier::ADD: + case ReductionIdentifier::MULTIPLY: + case ReductionIdentifier::AND: + case ReductionIdentifier::OR: + case ReductionIdentifier::EQV: + case ReductionIdentifier::NEQV: + if (type.isa<mlir::FloatType>()) + return builder.create<mlir::arith::ConstantOp>( + loc, type, + builder.getFloatAttr(type, + (double)getOperationIdentity(redId, loc))); + + if (type.isa<fir::LogicalType>()) { + mlir::Value intConst = builder.create<mlir::arith::ConstantOp>( + loc, builder.getI1Type(), + builder.getIntegerAttr(builder.getI1Type(), + getOperationIdentity(redId, loc))); + return builder.createConvert(loc, type, intConst); + } - static mlir::Value getIntrinsicOpInitValue( - mlir::Location loc, mlir::Type type, - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - fir::FirOpBuilder &builder) { - if (type.isa<mlir::FloatType>()) return builder.create<mlir::arith::ConstantOp>( loc, type, - builder.getFloatAttr(type, - (double)getOperationIdentity(intrinsicOp, loc))); - - if (type.isa<fir::LogicalType>()) { - mlir::Value intConst = builder.create<mlir::arith::ConstantOp>( - loc, builder.getI1Type(), - builder.getIntegerAttr(builder.getI1Type(), - getOperationIdentity(intrinsicOp, loc))); - return builder.createConvert(loc, type, intConst); + builder.getIntegerAttr(type, getOperationIdentity(redId, loc))); + case ReductionIdentifier::ID: + case ReductionIdentifier::USER_DEF_OP: + case ReductionIdentifier::SUBTRACT: + TODO(loc, "Reduction of some identifier types is not supported"); } - - return builder.create<mlir::arith::ConstantOp>( - loc, type, - builder.getIntegerAttr(type, getOperationIdentity(intrinsicOp, loc))); + llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue"); } template <typename FloatOp, typename IntegerOp> @@ -915,118 +953,46 @@ public: return builder.create<FloatOp>(loc, op1, op2); } - /// Creates an OpenMP reduction declaration and inserts it into the provided - /// symbol table. The declaration has a constant initializer with the neutral - /// value `initValue`, and the reduction combiner carried over from `reduce`. - /// TODO: Generalize this for non-integer types, add atomic region. - static mlir::omp::ReductionDeclareOp createReductionDecl( - fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, - const Fortran::parser::ProcedureDesignator &procDesignator, - mlir::Type type, mlir::Location loc) { - mlir::OpBuilder::InsertionGuard guard(builder); - mlir::ModuleOp module = builder.getModule(); - - auto decl = - module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName); - if (decl) - return decl; - - mlir::OpBuilder modBuilder(module.getBodyRegion()); - - decl = modBuilder.create<mlir::omp::ReductionDeclareOp>( - loc, reductionOpName, type); - builder.createBlock(&decl.getInitializerRegion(), - decl.getInitializerRegion().end(), {type}, {loc}); - builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); - mlir::Value init = - getIntrinsicProcInitValue(loc, type, procDesignator, builder); - builder.create<mlir::omp::YieldOp>(loc, init); - - builder.createBlock(&decl.getReductionRegion(), - decl.getReductionRegion().end(), {type, type}, - {loc, loc}); - - builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); - mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); - mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); - + static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder, + mlir::Location loc, + ReductionIdentifier redId, + mlir::Type type, mlir::Value op1, + mlir::Value op2) { mlir::Value reductionOp; - switch (getReductionType(procDesignator)) { - case IntrinsicProc::MAX: + switch (redId) { + case ReductionIdentifier::MAX: reductionOp = getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>( builder, type, loc, op1, op2); break; - case IntrinsicProc::MIN: + case ReductionIdentifier::MIN: reductionOp = getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>( builder, type, loc, op1, op2); break; - case IntrinsicProc::IOR: + case ReductionIdentifier::IOR: assert((type.isIntOrIndex()) && "only integer is expected"); reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2); break; - case IntrinsicProc::IEOR: + case ReductionIdentifier::IEOR: assert((type.isIntOrIndex()) && "only integer is expected"); reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2); break; - case IntrinsicProc::IAND: + case ReductionIdentifier::IAND: assert((type.isIntOrIndex()) && "only integer is expected"); reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2); break; - } - - builder.create<mlir::omp::YieldOp>(loc, reductionOp); - return decl; - } - - /// Creates an OpenMP reduction declaration and inserts it into the provided - /// symbol table. The declaration has a constant initializer with the neutral - /// value `initValue`, and the reduction combiner carried over from `reduce`. - /// TODO: Generalize this for non-integer types, add atomic region. - static mlir::omp::ReductionDeclareOp createReductionDecl( - fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - mlir::Type type, mlir::Location loc) { - mlir::OpBuilder::InsertionGuard guard(builder); - mlir::ModuleOp module = builder.getModule(); - - auto decl = - module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName); - if (decl) - return decl; - - mlir::OpBuilder modBuilder(module.getBodyRegion()); - - decl = modBuilder.create<mlir::omp::ReductionDeclareOp>( - loc, reductionOpName, type); - builder.createBlock(&decl.getInitializerRegion(), - decl.getInitializerRegion().end(), {type}, {loc}); - builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); - mlir::Value init = getIntrinsicOpInitValue(loc, type, intrinsicOp, builder); - builder.create<mlir::omp::YieldOp>(loc, init); - - builder.createBlock(&decl.getReductionRegion(), - decl.getReductionRegion().end(), {type, type}, - {loc, loc}); - - builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); - mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); - mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); - - mlir::Value reductionOp; - switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + case ReductionIdentifier::ADD: reductionOp = getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>( builder, type, loc, op1, op2); break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case ReductionIdentifier::MULTIPLY: reductionOp = getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>( builder, type, loc, op1, op2); break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: { + case ReductionIdentifier::AND: { mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); @@ -1036,7 +1002,7 @@ public: reductionOp = builder.createConvert(loc, type, andiOp); break; } - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: { + case ReductionIdentifier::OR: { mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); @@ -1045,7 +1011,7 @@ public: reductionOp = builder.createConvert(loc, type, oriOp); break; } - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: { + case ReductionIdentifier::EQV: { mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); @@ -1055,7 +1021,7 @@ public: reductionOp = builder.createConvert(loc, type, cmpiOp); break; } - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: { + case ReductionIdentifier::NEQV: { mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); @@ -1069,18 +1035,59 @@ public: TODO(loc, "Reduction of some intrinsic operators is not supported"); } + return reductionOp; + } + + /// Creates an OpenMP reduction declaration and inserts it into the provided + /// symbol table. The declaration has a constant initializer with the neutral + /// value `initValue`, and the reduction combiner carried over from `reduce`. + /// TODO: Generalize this for non-integer types, add atomic region. + static mlir::omp::ReductionDeclareOp createReductionDecl( + fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, + const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) { + mlir::OpBuilder::InsertionGuard guard(builder); + mlir::ModuleOp module = builder.getModule(); + + auto decl = + module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName); + if (decl) + return decl; + + mlir::OpBuilder modBuilder(module.getBodyRegion()); + + decl = modBuilder.create<mlir::omp::ReductionDeclareOp>( + loc, reductionOpName, type); + builder.createBlock(&decl.getInitializerRegion(), + decl.getInitializerRegion().end(), {type}, {loc}); + builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); + mlir::Value init = getReductionInitValue(loc, type, redId, builder); + builder.create<mlir::omp::YieldOp>(loc, init); + + builder.createBlock(&decl.getReductionRegion(), + decl.getReductionRegion().end(), {type, type}, + {loc, loc}); + + builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); + mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); + mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); + + mlir::Value reductionOp = + createScalarCombiner(builder, loc, redId, type, op1, op2); builder.create<mlir::omp::YieldOp>(loc, reductionOp); + return decl; } /// Creates a reduction declaration and associates it with an OpenMP block /// directive. - static void addReductionDecl( - mlir::Location currentLocation, - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpReductionClause &reduction, - llvm::SmallVectorImpl<mlir::Value> &reductionVars, - llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) { + static void + addReductionDecl(mlir::Location currentLocation, + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpReductionClause &reduction, + llvm::SmallVectorImpl<mlir::Value> &reductionVars, + llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> + *reductionSymbols = nullptr) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::omp::ReductionDeclareOp decl; const auto &redOperator{ @@ -1092,15 +1099,15 @@ public: const auto &intrinsicOp{ std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>( redDefinedOp->u)}; - switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + ReductionIdentifier redId = getReductionType(intrinsicOp); + switch (redId) { + case ReductionIdentifier::ADD: + case ReductionIdentifier::MULTIPLY: + case ReductionIdentifier::AND: + case ReductionIdentifier::EQV: + case ReductionIdentifier::OR: + case ReductionIdentifier::NEQV: break; - default: TODO(currentLocation, "Reduction of some intrinsic operators is not supported"); @@ -1110,6 +1117,8 @@ public: if (const auto *name{ Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + if (reductionSymbols) + reductionSymbols->push_back(symbol); mlir::Value symVal = converter.getSymbolAddress(*symbol); if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) symVal = declOp.getBase(); @@ -1120,11 +1129,11 @@ public: decl = createReductionDecl( firOpBuilder, getReductionName(intrinsicOp, firOpBuilder.getI1Type()), - intrinsicOp, redType, currentLocation); + redId, redType, currentLocation); else if (redType.isIntOrIndexOrFloat()) { decl = createReductionDecl(firOpBuilder, getReductionName(intrinsicOp, redType), - intrinsicOp, redType, currentLocation); + redId, redType, currentLocation); } else { TODO(currentLocation, "Reduction of some types is not supported"); } @@ -1138,10 +1147,14 @@ public: &redOperator.u)) { if (ReductionProcessor::supportedIntrinsicProcReduction( *reductionIntrinsic)) { + ReductionProcessor::ReductionIdentifier redId = + ReductionProcessor::getReductionType(*reductionIntrinsic); for (const Fortran::parser::OmpObject &ompObject : objectList.v) { if (const auto *name{ Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + if (reductionSymbols) + reductionSymbols->push_back(symbol); mlir::Value symVal = converter.getSymbolAddress(*symbol); if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) symVal = declOp.getBase(); @@ -1154,7 +1167,7 @@ public: firOpBuilder, getReductionName(getRealName(*reductionIntrinsic).ToString(), redType), - *reductionIntrinsic, redType, currentLocation); + redId, redType, currentLocation); reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( firOpBuilder.getContext(), decl.getSymName())); } @@ -1845,7 +1858,6 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, bool ClauseProcessor::processMap( mlir::Location currentLocation, const llvm::omp::Directive &directive, - Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::StatementContext &stmtCtx, llvm::SmallVectorImpl<mlir::Value> &mapOperands, llvm::SmallVectorImpl<mlir::Type> *mapSymTypes, @@ -1907,7 +1919,7 @@ bool ClauseProcessor::processMap( Fortran::lower::gatherDataOperandAddrAndBounds< Fortran::parser::OmpObject, mlir::omp::DataBoundsOp, mlir::omp::DataBoundsType>( - converter, firOpBuilder, semanticsContext, stmtCtx, ompObject, + converter, firOpBuilder, semaCtx, stmtCtx, ompObject, clauseLocation, asFortran, bounds, treatIndexAsSection); auto origSymbol = @@ -1942,13 +1954,16 @@ bool ClauseProcessor::processMap( bool ClauseProcessor::processReduction( mlir::Location currentLocation, llvm::SmallVectorImpl<mlir::Value> &reductionVars, - llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const { + llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols) + const { return findRepeatableClause<ClauseTy::Reduction>( [&](const ClauseTy::Reduction *reductionClause, const Fortran::parser::CharBlock &) { ReductionProcessor rp; rp.addReductionDecl(currentLocation, converter, reductionClause->v, - reductionVars, reductionDeclSymbols); + reductionVars, reductionDeclSymbols, + reductionSymbols); }); } @@ -2012,7 +2027,6 @@ bool ClauseProcessor::processUseDevicePtr( template <typename T> bool ClauseProcessor::processMotionClauses( - Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::StatementContext &stmtCtx, llvm::SmallVectorImpl<mlir::Value> &mapOperands) { return findRepeatableClause<T>( @@ -2036,7 +2050,7 @@ bool ClauseProcessor::processMotionClauses( Fortran::lower::gatherDataOperandAddrAndBounds< Fortran::parser::OmpObject, mlir::omp::DataBoundsOp, mlir::omp::DataBoundsType>( - converter, firOpBuilder, semanticsContext, stmtCtx, ompObject, + converter, firOpBuilder, semaCtx, stmtCtx, ompObject, clauseLocation, asFortran, bounds, treatIndexAsSection); auto origSymbol = @@ -2275,8 +2289,9 @@ struct OpWithBodyGenInfo { mlir::Operation *)>; OpWithBodyGenInfo(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, mlir::Location loc, Fortran::lower::pft::Evaluation &eval) - : converter(converter), loc(loc), eval(eval) {} + : converter(converter), semaCtx(semaCtx), loc(loc), eval(eval) {} OpWithBodyGenInfo &setGenNested(bool value) { genNested = value; @@ -2298,6 +2313,14 @@ struct OpWithBodyGenInfo { return *this; } + OpWithBodyGenInfo & + setReductions(llvm::SmallVector<const Fortran::semantics::Symbol *> *value1, + llvm::SmallVector<mlir::Type> *value2) { + reductionSymbols = value1; + reductionTypes = value2; + return *this; + } + OpWithBodyGenInfo &setGenRegionEntryCb(GenOMPRegionEntryCBFn value) { genRegionEntryCB = value; return *this; @@ -2305,6 +2328,8 @@ struct OpWithBodyGenInfo { /// [inout] converter to use for the clauses. Fortran::lower::AbstractConverter &converter; + /// [in] Semantics context + Fortran::semantics::SemanticsContext &semaCtx; /// [in] location in source code. mlir::Location loc; /// [in] current PFT node/evaluation. @@ -2317,6 +2342,11 @@ struct OpWithBodyGenInfo { const Fortran::parser::OmpClauseList *clauses = nullptr; /// [in] if provided, processes the construct's data-sharing attributes. DataSharingProcessor *dsp = nullptr; + /// [in] if provided, list of reduction symbols + llvm::SmallVector<const Fortran::semantics::Symbol *> *reductionSymbols = + nullptr; + /// [in] if provided, list of reduction types + llvm::SmallVector<mlir::Type> *reductionTypes = nullptr; /// [in] if provided, emits the op's region entry. Otherwise, an emtpy block /// is created in the region. GenOMPRegionEntryCBFn genRegionEntryCB = nullptr; @@ -2378,7 +2408,8 @@ static void createBodyOfOp(Op &op, OpWithBodyGenInfo &info) { threadPrivatizeVars(info.converter, info.eval); if (info.clauses) { firOpBuilder.setInsertionPoint(marker); - ClauseProcessor(info.converter, *info.clauses).processCopyin(); + ClauseProcessor(info.converter, info.semaCtx, *info.clauses) + .processCopyin(); } } @@ -2458,6 +2489,7 @@ static void createBodyOfOp(Op &op, OpWithBodyGenInfo &info) { static void genBodyOfTargetDataOp( Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, mlir::omp::DataOp &dataOp, const llvm::SmallVector<mlir::Type> &useDeviceTypes, @@ -2531,26 +2563,29 @@ static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) { static mlir::omp::MasterOp genMasterOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, mlir::Location currentLocation) { return genOpWithBody<mlir::omp::MasterOp>( - OpWithBodyGenInfo(converter, currentLocation, eval) + OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested), /*resultTypes=*/mlir::TypeRange()); } static mlir::omp::OrderedRegionOp genOrderedRegionOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, mlir::Location currentLocation) { return genOpWithBody<mlir::omp::OrderedRegionOp>( - OpWithBodyGenInfo(converter, currentLocation, eval) + OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested), /*simd=*/false); } static mlir::omp::ParallelOp genParallelOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, mlir::Location currentLocation, const Fortran::parser::OmpClauseList &clauseList, @@ -2561,8 +2596,9 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands, reductionVars; llvm::SmallVector<mlir::Attribute> reductionDeclSymbols; + llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols; - ClauseProcessor cp(converter, clauseList); + ClauseProcessor cp(converter, semaCtx, clauseList); cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel, ifClauseOperand); cp.processNumThreads(stmtCtx, numThreadsClauseOperand); @@ -2570,13 +2606,33 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, cp.processDefault(); cp.processAllocate(allocatorOperands, allocateOperands); if (!outerCombined) - cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols); + cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols, + &reductionSymbols); + + llvm::SmallVector<mlir::Type> reductionTypes; + reductionTypes.reserve(reductionVars.size()); + llvm::transform(reductionVars, std::back_inserter(reductionTypes), + [](mlir::Value v) { return v.getType(); }); + + auto reductionCallback = [&](mlir::Operation *op) { + llvm::SmallVector<mlir::Location> locs(reductionVars.size(), + currentLocation); + auto block = converter.getFirOpBuilder().createBlock(&op->getRegion(0), {}, + reductionTypes, locs); + for (auto [arg, prv] : + llvm::zip_equal(reductionSymbols, block->getArguments())) { + converter.bindSymbol(*arg, prv); + } + return reductionSymbols; + }; return genOpWithBody<mlir::omp::ParallelOp>( - OpWithBodyGenInfo(converter, currentLocation, eval) + OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested) .setOuterCombined(outerCombined) - .setClauses(&clauseList), + .setClauses(&clauseList) + .setReductions(&reductionSymbols, &reductionTypes) + .setGenRegionEntryCb(reductionCallback), /*resultTypes=*/mlir::TypeRange(), ifClauseOperand, numThreadsClauseOperand, allocateOperands, allocatorOperands, reductionVars, @@ -2589,19 +2645,21 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, static mlir::omp::SectionOp genSectionOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, mlir::Location currentLocation, const Fortran::parser::OmpClauseList §ionsClauseList) { // Currently only private/firstprivate clause is handled, and // all privatization is done within `omp.section` operations. return genOpWithBody<mlir::omp::SectionOp>( - OpWithBodyGenInfo(converter, currentLocation, eval) + OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested) .setClauses(§ionsClauseList)); } static mlir::omp::SingleOp genSingleOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, mlir::Location currentLocation, const Fortran::parser::OmpClauseList &beginClauseList, @@ -2609,15 +2667,15 @@ genSingleOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands; mlir::UnitAttr nowaitAttr; - ClauseProcessor cp(converter, beginClauseList); + ClauseProcessor cp(converter, semaCtx, beginClauseList); cp.processAllocate(allocatorOperands, allocateOperands); cp.processTODO<Fortran::parser::OmpClause::Copyprivate>( currentLocation, llvm::omp::Directive::OMPD_single); - ClauseProcessor(converter, endClauseList).processNowait(nowaitAttr); + ClauseProcessor(converter, semaCtx, endClauseList).processNowait(nowaitAttr); return genOpWithBody<mlir::omp::SingleOp>( - OpWithBodyGenInfo(converter, currentLocation, eval) + OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested) .setClauses(&beginClauseList), allocateOperands, allocatorOperands, nowaitAttr); @@ -2625,6 +2683,7 @@ genSingleOp(Fortran::lower::AbstractConverter &converter, static mlir::omp::TaskOp genTaskOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, mlir::Location currentLocation, const Fortran::parser::OmpClauseList &clauseList) { @@ -2635,7 +2694,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands, dependOperands; - ClauseProcessor cp(converter, clauseList); + ClauseProcessor cp(converter, semaCtx, clauseList); cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Task, ifClauseOperand); cp.processAllocate(allocatorOperands, allocateOperands); @@ -2651,7 +2710,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter, currentLocation, llvm::omp::Directive::OMPD_task); return genOpWithBody<mlir::omp::TaskOp>( - OpWithBodyGenInfo(converter, currentLocation, eval) + OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested) .setClauses(&clauseList), ifClauseOperand, finalClauseOperand, untiedAttr, mergeableAttr, @@ -2666,16 +2725,17 @@ genTaskOp(Fortran::lower::AbstractConverter &converter, static mlir::omp::TaskGroupOp genTaskGroupOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, mlir::Location currentLocation, const Fortran::parser::OmpClauseList &clauseList) { llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands; - ClauseProcessor cp(converter, clauseList); + ClauseProcessor cp(converter, semaCtx, clauseList); cp.processAllocate(allocatorOperands, allocateOperands); cp.processTODO<Fortran::parser::OmpClause::TaskReduction>( currentLocation, llvm::omp::Directive::OMPD_taskgroup); return genOpWithBody<mlir::omp::TaskGroupOp>( - OpWithBodyGenInfo(converter, currentLocation, eval) + OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested) .setClauses(&clauseList), /*task_reduction_vars=*/mlir::ValueRange(), @@ -2684,9 +2744,9 @@ genTaskGroupOp(Fortran::lower::AbstractConverter &converter, static mlir::omp::DataOp genDataOp(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - Fortran::semantics::SemanticsContext &semanticsContext, - bool genNested, mlir::Location currentLocation, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location currentLocation, const Fortran::parser::OmpClauseList &clauseList) { Fortran::lower::StatementContext stmtCtx; mlir::Value ifClauseOperand, deviceOperand; @@ -2696,7 +2756,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<mlir::Location> useDeviceLocs; llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols; - ClauseProcessor cp(converter, clauseList); + ClauseProcessor cp(converter, semaCtx, clauseList); cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData, ifClauseOperand); cp.processDevice(stmtCtx, deviceOperand); @@ -2705,20 +2765,21 @@ genDataOp(Fortran::lower::AbstractConverter &converter, cp.processUseDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs, useDeviceSymbols); cp.processMap(currentLocation, llvm::omp::Directive::OMPD_target_data, - semanticsContext, stmtCtx, mapOperands); + stmtCtx, mapOperands); auto dataOp = converter.getFirOpBuilder().create<mlir::omp::DataOp>( currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands, deviceAddrOperands, mapOperands); - genBodyOfTargetDataOp(converter, eval, genNested, dataOp, useDeviceTypes, - useDeviceLocs, useDeviceSymbols, currentLocation); + genBodyOfTargetDataOp(converter, semaCtx, eval, genNested, dataOp, + useDeviceTypes, useDeviceLocs, useDeviceSymbols, + currentLocation); return dataOp; } template <typename OpTy> static OpTy genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::semantics::SemanticsContext &semaCtx, mlir::Location currentLocation, const Fortran::parser::OmpClauseList &clauseList) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); @@ -2745,33 +2806,34 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter, return nullptr; } - ClauseProcessor cp(converter, clauseList); + ClauseProcessor cp(converter, semaCtx, clauseList); cp.processIf(directiveName, ifClauseOperand); cp.processDevice(stmtCtx, deviceOperand); cp.processNowait(nowaitAttr); if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) { - cp.processMotionClauses<Fortran::parser::OmpClause::To>( - semanticsContext, stmtCtx, mapOperands); - cp.processMotionClauses<Fortran::parser::OmpClause::From>( - semanticsContext, stmtCtx, mapOperands); + cp.processMotionClauses<Fortran::parser::OmpClause::To>(stmtCtx, + mapOperands); + cp.processMotionClauses<Fortran::parser::OmpClause::From>(stmtCtx, + mapOperands); } else { - cp.processMap(currentLocation, directive, semanticsContext, stmtCtx, - mapOperands); + cp.processMap(currentLocation, directive, stmtCtx, mapOperands); } cp.processTODO<Fortran::parser::OmpClause::Depend>(currentLocation, directive); return firOpBuilder.create<OpTy>(currentLocation, ifClauseOperand, - deviceOperand, nowaitAttr, mapOperands); + deviceOperand, nullptr, mlir::ValueRange(), + nowaitAttr, mapOperands); } // This functions creates a block for the body of the targetOp's region. It adds // all the symbols present in mapSymbols as block arguments to this block. static void genBodyOfTargetOp( Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, mlir::omp::TargetOp &targetOp, const llvm::SmallVector<mlir::Type> &mapSymTypes, @@ -2923,9 +2985,9 @@ static void genBodyOfTargetOp( static mlir::omp::TargetOp genTargetOp(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - Fortran::semantics::SemanticsContext &semanticsContext, - bool genNested, mlir::Location currentLocation, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location currentLocation, const Fortran::parser::OmpClauseList &clauseList, llvm::omp::Directive directive, bool outerCombined = false) { Fortran::lower::StatementContext stmtCtx; @@ -2936,14 +2998,14 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<mlir::Location> mapSymLocs; llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols; - ClauseProcessor cp(converter, clauseList); + ClauseProcessor cp(converter, semaCtx, clauseList); cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target, ifClauseOperand); cp.processDevice(stmtCtx, deviceOperand); cp.processThreadLimit(stmtCtx, threadLimitOperand); cp.processNowait(nowaitAttr); - cp.processMap(currentLocation, directive, semanticsContext, stmtCtx, - mapOperands, &mapSymTypes, &mapSymLocs, &mapSymbols); + cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes, + &mapSymLocs, &mapSymbols); cp.processTODO<Fortran::parser::OmpClause::Private, Fortran::parser::OmpClause::Depend, Fortran::parser::OmpClause::Firstprivate, @@ -3029,9 +3091,9 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>( currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand, - nowaitAttr, mapOperands); + nullptr, mlir::ValueRange(), nowaitAttr, mapOperands); - genBodyOfTargetOp(converter, eval, genNested, targetOp, mapSymTypes, + genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes, mapSymLocs, mapSymbols, currentLocation); return targetOp; @@ -3039,6 +3101,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, static mlir::omp::TeamsOp genTeamsOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, mlir::Location currentLocation, const Fortran::parser::OmpClauseList &clauseList, @@ -3049,7 +3112,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter, reductionVars; llvm::SmallVector<mlir::Attribute> reductionDeclSymbols; - ClauseProcessor cp(converter, clauseList); + ClauseProcessor cp(converter, semaCtx, clauseList); cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams, ifClauseOperand); cp.processAllocate(allocatorOperands, allocateOperands); @@ -3060,7 +3123,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter, currentLocation, llvm::omp::Directive::OMPD_teams); return genOpWithBody<mlir::omp::TeamsOp>( - OpWithBodyGenInfo(converter, currentLocation, eval) + OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) .setGenNested(genNested) .setOuterCombined(outerCombined) .setClauses(&clauseList), @@ -3077,6 +3140,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter, /// 'declare target' directive and return the intended device type for them. static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) { @@ -3102,7 +3166,7 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( eval.getOwningProcedure()->getSubprogramSymbol()); } - ClauseProcessor cp(converter, *clauseList); + ClauseProcessor cp(converter, semaCtx, *clauseList); cp.processTo(symbolAndClause); cp.processEnter(symbolAndClause); cp.processLink(symbolAndClause); @@ -3118,12 +3182,13 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( static std::optional<mlir::omp::DeclareTargetDeviceType> getDeclareTargetFunctionDevice( Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) { llvm::SmallVector<DeclareTargetCapturePair, 0> symbolAndClause; mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo( - converter, eval, declareTargetConstruct, symbolAndClause); + converter, semaCtx, eval, declareTargetConstruct, symbolAndClause); // Return the device type only if at least one of the targets for the // directive is a function or subroutine @@ -3145,9 +3210,8 @@ getDeclareTargetFunctionDevice( static void genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - Fortran::semantics::SemanticsContext &semanticsContext, - bool genNested, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, const Fortran::parser::OpenMPSimpleStandaloneConstruct &simpleStandaloneConstruct) { const auto &directive = @@ -3165,7 +3229,7 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter, firOpBuilder.create<mlir::omp::BarrierOp>(currentLocation); break; case llvm::omp::Directive::OMPD_taskwait: - ClauseProcessor(converter, opClauseList) + ClauseProcessor(converter, semaCtx, opClauseList) .processTODO<Fortran::parser::OmpClause::Depend, Fortran::parser::OmpClause::Nowait>( currentLocation, llvm::omp::Directive::OMPD_taskwait); @@ -3175,20 +3239,20 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter, firOpBuilder.create<mlir::omp::TaskyieldOp>(currentLocation); break; case llvm::omp::Directive::OMPD_target_data: - genDataOp(converter, eval, semanticsContext, genNested, currentLocation, + genDataOp(converter, semaCtx, eval, genNested, currentLocation, opClauseList); break; case llvm::omp::Directive::OMPD_target_enter_data: genEnterExitUpdateDataOp<mlir::omp::EnterDataOp>( - converter, semanticsContext, currentLocation, opClauseList); + converter, semaCtx, currentLocation, opClauseList); break; case llvm::omp::Directive::OMPD_target_exit_data: genEnterExitUpdateDataOp<mlir::omp::ExitDataOp>( - converter, semanticsContext, currentLocation, opClauseList); + converter, semaCtx, currentLocation, opClauseList); break; case llvm::omp::Directive::OMPD_target_update: genEnterExitUpdateDataOp<mlir::omp::UpdateDataOp>( - converter, semanticsContext, currentLocation, opClauseList); + converter, semaCtx, currentLocation, opClauseList); break; case llvm::omp::Directive::OMPD_ordered: TODO(currentLocation, "OMPD_ordered"); @@ -3197,6 +3261,7 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter, static void genOmpFlush(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPFlushConstruct &flushConstruct) { llvm::SmallVector<mlir::Value, 4> operandRange; @@ -3216,19 +3281,19 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter, static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symTable, - Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) { std::visit( Fortran::common::visitors{ [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct &simpleStandaloneConstruct) { - genOmpSimpleStandalone(converter, eval, semanticsContext, + genOmpSimpleStandalone(converter, semaCtx, eval, /*genNested=*/true, simpleStandaloneConstruct); }, [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) { - genOmpFlush(converter, eval, flushConstruct); + genOmpFlush(converter, semaCtx, eval, flushConstruct); }, [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) { TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); @@ -3289,6 +3354,7 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter, static void createSimdLoop(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, llvm::omp::Directive ompDirective, const Fortran::parser::OmpClauseList &loopOpClauseList, @@ -3307,7 +3373,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter, mlir::IntegerAttr simdlenClauseOperand, safelenClauseOperand; std::size_t loopVarTypeSize; - ClauseProcessor cp(converter, loopOpClauseList); + ClauseProcessor cp(converter, semaCtx, loopOpClauseList); cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv, loopVarTypeSize); cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand); @@ -3340,13 +3406,14 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter, }; createBodyOfOp<mlir::omp::SimdLoopOp>( - simdLoopOp, OpWithBodyGenInfo(converter, loc, *nestedEval) + simdLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) .setClauses(&loopOpClauseList) .setDataSharingProcessor(&dsp) .setGenRegionEntryCb(ivCallback)); } static void createWsLoop(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, llvm::omp::Directive ompDirective, const Fortran::parser::OmpClauseList &beginClauseList, @@ -3369,7 +3436,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter, mlir::omp::ScheduleModifierAttr scheduleModClauseOperand; std::size_t loopVarTypeSize; - ClauseProcessor cp(converter, beginClauseList); + ClauseProcessor cp(converter, semaCtx, beginClauseList); cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv, loopVarTypeSize); cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand); @@ -3409,7 +3476,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter, // <...> // !$omp end do nowait if (endClauseList) { - if (ClauseProcessor(converter, *endClauseList) + if (ClauseProcessor(converter, semaCtx, *endClauseList) .processNowait(nowaitClauseOperand)) wsLoopOp.setNowaitAttr(nowaitClauseOperand); } @@ -3422,7 +3489,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter, }; createBodyOfOp<mlir::omp::WsLoopOp>( - wsLoopOp, OpWithBodyGenInfo(converter, loc, *nestedEval) + wsLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) .setClauses(&beginClauseList) .setDataSharingProcessor(&dsp) .setGenRegionEntryCb(ivCallback)); @@ -3430,10 +3497,11 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter, static void createSimdWsLoop( Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, llvm::omp::Directive ompDirective, const Fortran::parser::OmpClauseList &beginClauseList, const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { - ClauseProcessor cp(converter, beginClauseList); + ClauseProcessor cp(converter, semaCtx, beginClauseList); cp.processTODO< Fortran::parser::OmpClause::Aligned, Fortran::parser::OmpClause::Allocate, Fortran::parser::OmpClause::Linear, Fortran::parser::OmpClause::Safelen, @@ -3447,13 +3515,13 @@ static void createSimdWsLoop( // When support for vectorization is enabled, then we need to add handling of // if clause. Currently if clause can be skipped because we always assume // SIMD length = 1. - createWsLoop(converter, eval, ompDirective, beginClauseList, endClauseList, - loc); + createWsLoop(converter, semaCtx, eval, ompDirective, beginClauseList, + endClauseList, loc); } static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symTable, - Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { const auto &beginLoopDirective = @@ -3485,14 +3553,14 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, if ((llvm::omp::allTargetSet & llvm::omp::loopConstructSet) .test(ompDirective)) { validDirective = true; - genTargetOp(converter, eval, semanticsContext, /*genNested=*/false, + genTargetOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation, loopOpClauseList, ompDirective, /*outerCombined=*/true); } if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet) .test(ompDirective)) { validDirective = true; - genTeamsOp(converter, eval, /*genNested=*/false, currentLocation, + genTeamsOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation, loopOpClauseList, /*outerCombined=*/true); } @@ -3503,8 +3571,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet) .test(ompDirective)) { validDirective = true; - genParallelOp(converter, eval, /*genNested=*/false, currentLocation, - loopOpClauseList, + genParallelOp(converter, semaCtx, eval, /*genNested=*/false, + currentLocation, loopOpClauseList, /*outerCombined=*/true); } } @@ -3519,25 +3587,25 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, if (llvm::omp::allDoSimdSet.test(ompDirective)) { // 2.9.3.2 Workshare SIMD construct - createSimdWsLoop(converter, eval, ompDirective, loopOpClauseList, + createSimdWsLoop(converter, semaCtx, eval, ompDirective, loopOpClauseList, endClauseList, currentLocation); } else if (llvm::omp::allSimdSet.test(ompDirective)) { // 2.9.3.1 SIMD construct - createSimdLoop(converter, eval, ompDirective, loopOpClauseList, + createSimdLoop(converter, semaCtx, eval, ompDirective, loopOpClauseList, currentLocation); } else { - createWsLoop(converter, eval, ompDirective, loopOpClauseList, endClauseList, - currentLocation); + createWsLoop(converter, semaCtx, eval, ompDirective, loopOpClauseList, + endClauseList, currentLocation); } - genOpenMPReduction(converter, loopOpClauseList); + genOpenMPReduction(converter, semaCtx, loopOpClauseList); } static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symTable, - Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPBlockConstruct &blockConstruct) { const auto &beginBlockDirective = @@ -3586,37 +3654,38 @@ genOMP(Fortran::lower::AbstractConverter &converter, mlir::Location currentLocation = converter.genLocation(directive.source); switch (directive.v) { case llvm::omp::Directive::OMPD_master: - genMasterOp(converter, eval, /*genNested=*/true, currentLocation); + genMasterOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation); break; case llvm::omp::Directive::OMPD_ordered: - genOrderedRegionOp(converter, eval, /*genNested=*/true, currentLocation); + genOrderedRegionOp(converter, semaCtx, eval, /*genNested=*/true, + currentLocation); break; case llvm::omp::Directive::OMPD_parallel: - genParallelOp(converter, eval, /*genNested=*/true, currentLocation, + genParallelOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, beginClauseList); break; case llvm::omp::Directive::OMPD_single: - genSingleOp(converter, eval, /*genNested=*/true, currentLocation, + genSingleOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, beginClauseList, endClauseList); break; case llvm::omp::Directive::OMPD_target: - genTargetOp(converter, eval, semanticsContext, /*genNested=*/true, - currentLocation, beginClauseList, directive.v); + genTargetOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, + beginClauseList, directive.v); break; case llvm::omp::Directive::OMPD_target_data: - genDataOp(converter, eval, semanticsContext, /*genNested=*/true, - currentLocation, beginClauseList); + genDataOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, + beginClauseList); break; case llvm::omp::Directive::OMPD_task: - genTaskOp(converter, eval, /*genNested=*/true, currentLocation, + genTaskOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, beginClauseList); break; case llvm::omp::Directive::OMPD_taskgroup: - genTaskGroupOp(converter, eval, /*genNested=*/true, currentLocation, - beginClauseList); + genTaskGroupOp(converter, semaCtx, eval, /*genNested=*/true, + currentLocation, beginClauseList); break; case llvm::omp::Directive::OMPD_teams: - genTeamsOp(converter, eval, /*genNested=*/true, currentLocation, + genTeamsOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, beginClauseList, /*outerCombined=*/false); break; @@ -3628,23 +3697,21 @@ genOMP(Fortran::lower::AbstractConverter &converter, break; } - if (singleDirective) { - genOpenMPReduction(converter, beginClauseList); + if (singleDirective) return; - } // Codegen for combined directives bool combinedDirective = false; if ((llvm::omp::allTargetSet & llvm::omp::blockConstructSet) .test(directive.v)) { - genTargetOp(converter, eval, semanticsContext, /*genNested=*/false, - currentLocation, beginClauseList, directive.v, + genTargetOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation, + beginClauseList, directive.v, /*outerCombined=*/true); combinedDirective = true; } if ((llvm::omp::allTeamsSet & llvm::omp::blockConstructSet) .test(directive.v)) { - genTeamsOp(converter, eval, /*genNested=*/false, currentLocation, + genTeamsOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation, beginClauseList); combinedDirective = true; } @@ -3652,8 +3719,8 @@ genOMP(Fortran::lower::AbstractConverter &converter, .test(directive.v)) { bool outerCombined = directive.v != llvm::omp::Directive::OMPD_target_parallel; - genParallelOp(converter, eval, /*genNested=*/false, currentLocation, - beginClauseList, outerCombined); + genParallelOp(converter, semaCtx, eval, /*genNested=*/false, + currentLocation, beginClauseList, outerCombined); combinedDirective = true; } if ((llvm::omp::workShareSet & llvm::omp::blockConstructSet) @@ -3667,13 +3734,12 @@ genOMP(Fortran::lower::AbstractConverter &converter, ")"); genNestedEvaluations(converter, eval); - genOpenMPReduction(converter, beginClauseList); } static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symTable, - Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); @@ -3688,7 +3754,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, } const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t); - ClauseProcessor(converter, clauseList).processHint(hintClauseOp); + ClauseProcessor(converter, semaCtx, clauseList).processHint(hintClauseOp); mlir::omp::CriticalOp criticalOp = [&]() { if (name.empty()) { @@ -3706,14 +3772,14 @@ genOMP(Fortran::lower::AbstractConverter &converter, currentLocation, mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), global.getSymName())); }(); - auto genInfo = OpWithBodyGenInfo(converter, currentLocation, eval); + auto genInfo = OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval); createBodyOfOp<mlir::omp::CriticalOp>(criticalOp, genInfo); } static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symTable, - Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPSectionsConstruct §ionsConstruct) { mlir::Location currentLocation = converter.getCurrentLocation(); @@ -3726,7 +3792,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, // Process clauses before optional omp.parallel, so that new variables are // allocated outside of the parallel region - ClauseProcessor cp(converter, sectionsClauseList); + ClauseProcessor cp(converter, semaCtx, sectionsClauseList); cp.processSectionsReduction(currentLocation); cp.processAllocate(allocatorOperands, allocateOperands); @@ -3736,7 +3802,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, // Parallel wrapper of PARALLEL SECTIONS construct if (dir == llvm::omp::Directive::OMPD_parallel_sections) { - genParallelOp(converter, eval, + genParallelOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation, sectionsClauseList, /*outerCombined=*/true); } else { @@ -3744,13 +3810,14 @@ genOMP(Fortran::lower::AbstractConverter &converter, std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t); const auto &endSectionsClauseList = std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t); - ClauseProcessor(converter, endSectionsClauseList) + ClauseProcessor(converter, semaCtx, endSectionsClauseList) .processNowait(nowaitClauseOperand); } // SECTIONS construct genOpWithBody<mlir::omp::SectionsOp>( - OpWithBodyGenInfo(converter, currentLocation, eval).setGenNested(false), + OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) + .setGenNested(false), /*reduction_vars=*/mlir::ValueRange(), /*reductions=*/nullptr, allocateOperands, allocatorOperands, nowaitClauseOperand); @@ -3762,7 +3829,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, for (const auto &[nblock, neval] : llvm::zip(sectionBlocks.v, eval.getNestedEvaluations())) { symTable.pushScope(); - genSectionOp(converter, neval, /*genNested=*/true, currentLocation, + genSectionOp(converter, semaCtx, neval, /*genNested=*/true, currentLocation, sectionsClauseList); symTable.popScope(); firOpBuilder.restoreInsertionPoint(ip); @@ -3772,7 +3839,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symTable, - Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) { std::visit( @@ -3817,14 +3884,14 @@ genOMP(Fortran::lower::AbstractConverter &converter, static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symTable, - Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) { llvm::SmallVector<DeclareTargetCapturePair, 0> symbolAndClause; mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo( - converter, eval, declareTargetConstruct, symbolAndClause); + converter, semaCtx, eval, declareTargetConstruct, symbolAndClause); for (const DeclareTargetCapturePair &symClause : symbolAndClause) { mlir::Operation *op = mod.lookupSymbol( @@ -3870,27 +3937,25 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symTable, - Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPConstruct &ompConstruct) { std::visit( Fortran::common::visitors{ [&](const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) { - genOMP(converter, symTable, semanticsContext, eval, - standaloneConstruct); + genOMP(converter, symTable, semaCtx, eval, standaloneConstruct); }, [&](const Fortran::parser::OpenMPSectionsConstruct §ionsConstruct) { - genOMP(converter, symTable, semanticsContext, eval, - sectionsConstruct); + genOMP(converter, symTable, semaCtx, eval, sectionsConstruct); }, [&](const Fortran::parser::OpenMPSectionConstruct §ionConstruct) { // SECTION constructs are handled as a part of SECTIONS. llvm_unreachable("Unexpected standalone OMP SECTION"); }, [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { - genOMP(converter, symTable, semanticsContext, eval, loopConstruct); + genOMP(converter, symTable, semaCtx, eval, loopConstruct); }, [&](const Fortran::parser::OpenMPDeclarativeAllocate &execAllocConstruct) { @@ -3905,16 +3970,14 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, TODO(converter.getCurrentLocation(), "OpenMPAllocatorsConstruct"); }, [&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) { - genOMP(converter, symTable, semanticsContext, eval, blockConstruct); + genOMP(converter, symTable, semaCtx, eval, blockConstruct); }, [&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) { - genOMP(converter, symTable, semanticsContext, eval, - atomicConstruct); + genOMP(converter, symTable, semaCtx, eval, atomicConstruct); }, [&](const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { - genOMP(converter, symTable, semanticsContext, eval, - criticalConstruct); + genOMP(converter, symTable, semaCtx, eval, criticalConstruct); }, }, ompConstruct.u); @@ -3923,7 +3986,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symTable, - Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPDeclarativeConstruct &ompDeclConstruct) { std::visit( @@ -3943,8 +4006,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, }, [&](const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) { - genOMP(converter, symTable, semanticsContext, eval, - declareTargetConstruct); + genOMP(converter, symTable, semaCtx, eval, declareTargetConstruct); }, [&](const Fortran::parser::OpenMPRequiresConstruct &requiresConstruct) { @@ -3978,21 +4040,21 @@ mlir::Operation *Fortran::lower::genOpenMPTerminator(fir::FirOpBuilder &builder, void Fortran::lower::genOpenMPConstruct( Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symTable, - Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPConstruct &omp) { symTable.pushScope(); - genOMP(converter, symTable, semanticsContext, eval, omp); + genOMP(converter, symTable, semaCtx, eval, omp); symTable.popScope(); } void Fortran::lower::genOpenMPDeclarativeConstruct( Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symTable, - Fortran::semantics::SemanticsContext &semanticsContext, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPDeclarativeConstruct &omp) { - genOMP(converter, symTable, semanticsContext, eval, omp); + genOMP(converter, symTable, semaCtx, eval, omp); genNestedEvaluations(converter, eval); } @@ -4107,6 +4169,7 @@ void Fortran::lower::genDeclareTargetIntGlobal( // ops in the builder (instead of a rewriter) is probably not the best approach. void Fortran::lower::genOpenMPReduction( Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, const Fortran::parser::OmpClauseList &clauseList) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); @@ -4174,7 +4237,7 @@ void Fortran::lower::genOpenMPReduction( if (!ReductionProcessor::supportedIntrinsicProcReduction( *reductionIntrinsic)) continue; - ReductionProcessor::IntrinsicProc redIntrinsicProc = + ReductionProcessor::ReductionIdentifier redId = ReductionProcessor::getReductionType(*reductionIntrinsic); for (const Fortran::parser::OmpObject &ompObject : objectList.v) { if (const auto *name{ @@ -4195,10 +4258,8 @@ void Fortran::lower::genOpenMPReduction( if (reductionOp == nullptr) continue; - if (redIntrinsicProc == - ReductionProcessor::IntrinsicProc::MAX || - redIntrinsicProc == - ReductionProcessor::IntrinsicProc::MIN) { + if (redId == ReductionProcessor::ReductionIdentifier::MAX || + redId == ReductionProcessor::ReductionIdentifier::MIN) { assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) && "Selection Op not found in reduction intrinsic"); mlir::Operation *compareOp = @@ -4206,13 +4267,9 @@ void Fortran::lower::genOpenMPReduction( updateReduction(compareOp, firOpBuilder, loadVal, reductionVal); } - if (redIntrinsicProc == - ReductionProcessor::IntrinsicProc::IOR || - redIntrinsicProc == - ReductionProcessor::IntrinsicProc::IEOR || - redIntrinsicProc == - ReductionProcessor::IntrinsicProc::IAND) { - + if (redId == ReductionProcessor::ReductionIdentifier::IOR || + redId == ReductionProcessor::ReductionIdentifier::IEOR || + redId == ReductionProcessor::ReductionIdentifier::IAND) { updateReduction(reductionOp, firOpBuilder, loadVal, reductionVal); } @@ -4335,13 +4392,14 @@ bool Fortran::lower::isOpenMPTargetConstruct( bool Fortran::lower::isOpenMPDeviceDeclareTarget( Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) { return std::visit( Fortran::common::visitors{ [&](const Fortran::parser::OpenMPDeclareTargetConstruct &ompReq) { mlir::omp::DeclareTargetDeviceType targetType = - getDeclareTargetFunctionDevice(converter, eval, ompReq) + getDeclareTargetFunctionDevice(converter, semaCtx, eval, ompReq) .value_or(mlir::omp::DeclareTargetDeviceType::host); return targetType != mlir::omp::DeclareTargetDeviceType::host; }, |