diff options
Diffstat (limited to 'flang/lib')
23 files changed, 887 insertions, 122 deletions
| diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index a516a44..6e72987 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -1884,6 +1884,26 @@ private:      setCurrentPosition(stmt.source);      assert(stmt.typedCall && "Call was not analyzed");      mlir::Value res{}; + +    // Set 'no_inline', 'inline_hint' or 'always_inline' to true on the +    // ProcedureRef. The NoInline and AlwaysInline attribute will be set in +    // genProcedureRef later. +    for (const auto *dir : eval.dirs) { +      Fortran::common::visit( +          Fortran::common::visitors{ +              [&](const Fortran::parser::CompilerDirective::ForceInline &) { +                stmt.typedCall->setAlwaysInline(true); +              }, +              [&](const Fortran::parser::CompilerDirective::Inline &) { +                stmt.typedCall->setInlineHint(true); +              }, +              [&](const Fortran::parser::CompilerDirective::NoInline &) { +                stmt.typedCall->setNoInline(true); +              }, +              [&](const auto &) {}}, +          dir->u); +    } +      if (lowerToHighLevelFIR()) {        std::optional<mlir::Type> resultType;        if (stmt.typedCall->hasAlternateReturns()) @@ -2200,6 +2220,50 @@ private:      // so no clean-up needs to be generated for these entities.    } +  void attachInlineAttributes( +      mlir::Operation &op, +      const llvm::ArrayRef<const Fortran::parser::CompilerDirective *> &dirs) { +    if (dirs.empty()) +      return; + +    for (mlir::Value operand : op.getOperands()) { +      if (operand.getDefiningOp()) +        attachInlineAttributes(*operand.getDefiningOp(), dirs); +    } + +    if (fir::CallOp callOp = mlir::dyn_cast<fir::CallOp>(op)) { +      for (const auto *dir : dirs) { +        Fortran::common::visit( +            Fortran::common::visitors{ +                [&](const Fortran::parser::CompilerDirective::NoInline &) { +                  callOp.setInlineAttr(fir::FortranInlineEnum::no_inline); +                }, +                [&](const Fortran::parser::CompilerDirective::Inline &) { +                  callOp.setInlineAttr(fir::FortranInlineEnum::inline_hint); +                }, +                [&](const Fortran::parser::CompilerDirective::ForceInline &) { +                  callOp.setInlineAttr(fir::FortranInlineEnum::always_inline); +                }, +                [&](const auto &) {}}, +            dir->u); +      } +    } +  } + +  void attachAttributesToDoLoopOperations( +      fir::DoLoopOp &doLoop, +      llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs) { +    if (!doLoop.getOperation() || dirs.empty()) +      return; + +    for (mlir::Block &block : doLoop.getRegion()) { +      for (mlir::Operation &op : block.getOperations()) { +        if (!dirs.empty()) +          attachInlineAttributes(op, dirs); +      } +    } +  } +    /// Generate FIR for a DO construct. There are six variants:    ///  - unstructured infinite and while loops    ///  - structured and unstructured increment loops @@ -2351,6 +2415,11 @@ private:      if (!incrementLoopNestInfo.empty() &&          incrementLoopNestInfo.back().isConcurrent)        localSymbols.popScope(); + +    // Add attribute(s) on operations in fir::DoLoopOp if necessary +    for (IncrementLoopInfo &info : incrementLoopNestInfo) +      if (auto loopOp = mlir::dyn_cast_if_present<fir::DoLoopOp>(info.loopOp)) +        attachAttributesToDoLoopOperations(loopOp, doStmtEval.dirs);    }    /// Generate FIR to evaluate loop control values (lower, upper and step). @@ -3154,6 +3223,26 @@ private:        e->dirs.push_back(&dir);    } +  void +  attachInliningDirectiveToStmt(const Fortran::parser::CompilerDirective &dir, +                                Fortran::lower::pft::Evaluation *e) { +    while (e->isDirective()) +      e = e->lexicalSuccessor; + +    // If the successor is a statement or a do loop, the compiler +    // will perform inlining. +    if (e->isA<Fortran::parser::CallStmt>() || +        e->isA<Fortran::parser::NonLabelDoStmt>() || +        e->isA<Fortran::parser::AssignmentStmt>()) { +      e->dirs.push_back(&dir); +    } else { +      mlir::Location loc = toLocation(); +      mlir::emitWarning(loc, +                        "Inlining directive not in front of loops, function" +                        "call or assignment.\n"); +    } +  } +    void genFIR(const Fortran::parser::CompilerDirective &dir) {      Fortran::lower::pft::Evaluation &eval = getEval(); @@ -3177,6 +3266,15 @@ private:              [&](const Fortran::parser::CompilerDirective::NoUnrollAndJam &) {                attachDirectiveToLoop(dir, &eval);              }, +            [&](const Fortran::parser::CompilerDirective::ForceInline &) { +              attachInliningDirectiveToStmt(dir, &eval); +            }, +            [&](const Fortran::parser::CompilerDirective::Inline &) { +              attachInliningDirectiveToStmt(dir, &eval); +            }, +            [&](const Fortran::parser::CompilerDirective::NoInline &) { +              attachInliningDirectiveToStmt(dir, &eval); +            },              [&](const auto &) {}},          dir.u);    } @@ -5086,7 +5184,9 @@ private:    void genDataAssignment(        const Fortran::evaluate::Assignment &assign, -      const Fortran::evaluate::ProcedureRef *userDefinedAssignment) { +      const Fortran::evaluate::ProcedureRef *userDefinedAssignment, +      const llvm::ArrayRef<const Fortran::parser::CompilerDirective *> &dirs = +          {}) {      mlir::Location loc = getCurrentLocation();      fir::FirOpBuilder &builder = getFirOpBuilder(); @@ -5166,10 +5266,20 @@ private:          genCUDADataTransfer(builder, loc, assign, lhs, rhs,                              isWholeAllocatableAssignment,                              keepLhsLengthInAllocatableAssignment); -      else +      else { +        // If RHS or LHS have a CallOp in their expression, this operation will +        // have the 'no_inline' or 'always_inline' attribute if there is a +        // directive just before the assignement. +        if (!dirs.empty()) { +          if (rhs.getDefiningOp()) +            attachInlineAttributes(*rhs.getDefiningOp(), dirs); +          if (lhs.getDefiningOp()) +            attachInlineAttributes(*lhs.getDefiningOp(), dirs); +        }          hlfir::AssignOp::create(builder, loc, rhs, lhs,                                  isWholeAllocatableAssignment,                                  keepLhsLengthInAllocatableAssignment); +      }        if (hasCUDAImplicitTransfer && !isInDeviceContext) {          localSymbols.popScope();          for (mlir::Value temp : implicitTemps) @@ -5237,16 +5347,21 @@ private:    }    /// Shared for both assignments and pointer assignments. -  void genAssignment(const Fortran::evaluate::Assignment &assign) { +  void +  genAssignment(const Fortran::evaluate::Assignment &assign, +                const llvm::ArrayRef<const Fortran::parser::CompilerDirective *> +                    &dirs = {}) {      mlir::Location loc = toLocation();      if (lowerToHighLevelFIR()) {        Fortran::common::visit(            Fortran::common::visitors{                [&](const Fortran::evaluate::Assignment::Intrinsic &) { -                genDataAssignment(assign, /*userDefinedAssignment=*/nullptr); +                genDataAssignment(assign, /*userDefinedAssignment=*/nullptr, +                                  dirs);                },                [&](const Fortran::evaluate::ProcedureRef &procRef) { -                genDataAssignment(assign, /*userDefinedAssignment=*/&procRef); +                genDataAssignment(assign, /*userDefinedAssignment=*/&procRef, +                                  dirs);                },                [&](const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) {                  if (isInsideHlfirForallOrWhere()) @@ -5651,7 +5766,8 @@ private:    }    void genFIR(const Fortran::parser::AssignmentStmt &stmt) { -    genAssignment(*stmt.typedAssignment->v); +    Fortran::lower::pft::Evaluation &eval = getEval(); +    genAssignment(*stmt.typedAssignment->v, eval.dirs);    }    void genFIR(const Fortran::parser::SyncAllStmt &stmt) { diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp index fb72040..9bf994e 100644 --- a/flang/lib/Lower/ConvertCall.cpp +++ b/flang/lib/Lower/ConvertCall.cpp @@ -700,9 +700,20 @@ Fortran::lower::genCallOpAndResult(        callResult = dispatch.getResult(0);    } else {      // Standard procedure call with fir.call. +    fir::FortranInlineEnumAttr inlineAttr; + +    if (caller.getCallDescription().hasNoInline()) +      inlineAttr = fir::FortranInlineEnumAttr::get( +          builder.getContext(), fir::FortranInlineEnum::no_inline); +    else if (caller.getCallDescription().hasInlineHint()) +      inlineAttr = fir::FortranInlineEnumAttr::get( +          builder.getContext(), fir::FortranInlineEnum::inline_hint); +    else if (caller.getCallDescription().hasAlwaysInline()) +      inlineAttr = fir::FortranInlineEnumAttr::get( +          builder.getContext(), fir::FortranInlineEnum::always_inline);      auto call = fir::CallOp::create(          builder, loc, funcType.getResults(), funcSymbolAttr, operands, -        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs); +        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs, inlineAttr);      callNumResults = call.getNumResults();      if (callNumResults != 0) diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index d39f9dd..0f60b47 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -1482,6 +1482,21 @@ ThreadLimit make(const parser::OmpClause::ThreadLimit &inp,    return ThreadLimit{/*Threadlim=*/makeExpr(inp.v, semaCtx)};  } +Threadset make(const parser::OmpClause::Threadset &inp, +               semantics::SemanticsContext &semaCtx) { +  // inp.v -> parser::OmpThreadsetClause +  using wrapped = parser::OmpThreadsetClause; + +  CLAUSET_ENUM_CONVERT( // +      convert, wrapped::ThreadsetPolicy, Threadset::ThreadsetPolicy, +      // clang-format off +      MS(Omp_Pool, Omp_Pool) +      MS(Omp_Team, Omp_Team) +      // clang-format on +  ); +  return Threadset{/*ThreadsetPolicy=*/convert(inp.v.v)}; +} +  // Threadprivate: empty  // Threads: empty diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 39bac81..ca3e1cd 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -50,6 +50,7 @@  #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  #include "mlir/Dialect/LLVMIR/LLVMTypes.h"  #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h"  #include "mlir/Dialect/Vector/IR/VectorOps.h"  #include "llvm/Support/CommandLine.h"  #include "llvm/Support/Debug.h" @@ -358,6 +359,14 @@ static constexpr IntrinsicHandler handlers[]{       &I::genBarrierInit,       {{{"barrier", asAddr}, {"count", asValue}}},       /*isElemental=*/false}, +    {"barrier_try_wait", +     &I::genBarrierTryWait, +     {{{"barrier", asAddr}, {"token", asValue}}}, +     /*isElemental=*/false}, +    {"barrier_try_wait_sleep", +     &I::genBarrierTryWaitSleep, +     {{{"barrier", asAddr}, {"token", asValue}, {"ns", asValue}}}, +     /*isElemental=*/false},      {"bessel_jn",       &I::genBesselJn,       {{{"n1", asValue}, {"n2", asValue}, {"x", asValue}}}, @@ -1036,10 +1045,87 @@ static constexpr IntrinsicHandler handlers[]{         {"dst", asAddr},         {"nbytes", asValue}}},       /*isElemental=*/false}, +    {"tma_bulk_ldc4", +     &I::genTMABulkLoadC4, +     {{{"barrier", asAddr}, +       {"src", asAddr}, +       {"dst", asAddr}, +       {"nelems", asValue}}}, +     /*isElemental=*/false}, +    {"tma_bulk_ldc8", +     &I::genTMABulkLoadC8, +     {{{"barrier", asAddr}, +       {"src", asAddr}, +       {"dst", asAddr}, +       {"nelems", asValue}}}, +     /*isElemental=*/false}, +    {"tma_bulk_ldi4", +     &I::genTMABulkLoadI4, +     {{{"barrier", asAddr}, +       {"src", asAddr}, +       {"dst", asAddr}, +       {"nelems", asValue}}}, +     /*isElemental=*/false}, +    {"tma_bulk_ldi8", +     &I::genTMABulkLoadI8, +     {{{"barrier", asAddr}, +       {"src", asAddr}, +       {"dst", asAddr}, +       {"nelems", asValue}}}, +     /*isElemental=*/false}, +    {"tma_bulk_ldr2", +     &I::genTMABulkLoadR2, +     {{{"barrier", asAddr}, +       {"src", asAddr}, +       {"dst", asAddr}, +       {"nelems", asValue}}}, +     /*isElemental=*/false}, +    {"tma_bulk_ldr4", +     &I::genTMABulkLoadR4, +     {{{"barrier", asAddr}, +       {"src", asAddr}, +       {"dst", asAddr}, +       {"nelems", asValue}}}, +     /*isElemental=*/false}, +    {"tma_bulk_ldr8", +     &I::genTMABulkLoadR8, +     {{{"barrier", asAddr}, +       {"src", asAddr}, +       {"dst", asAddr}, +       {"nelems", asValue}}}, +     /*isElemental=*/false},      {"tma_bulk_s2g",       &I::genTMABulkS2G,       {{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}},       /*isElemental=*/false}, +    {"tma_bulk_store_c4", +     &I::genTMABulkStoreC4, +     {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, +     /*isElemental=*/false}, +    {"tma_bulk_store_c8", +     &I::genTMABulkStoreC8, +     {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, +     /*isElemental=*/false}, +    {"tma_bulk_store_i4", +     &I::genTMABulkStoreI4, +     {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, +     /*isElemental=*/false}, +    {"tma_bulk_store_i8", +     &I::genTMABulkStoreI8, +     {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, +     /*isElemental=*/false}, +    {"tma_bulk_store_r2", +     &I::genTMABulkStoreR2, +     {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, +     /*isElemental=*/false}, +    {"tma_bulk_store_r4", +     &I::genTMABulkStoreR4, +     {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, +     /*isElemental=*/false}, +    {"tma_bulk_store_r8", +     &I::genTMABulkStoreR8, +     {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}}, +     /*isElemental=*/false},      {"tma_bulk_wait_group",       &I::genTMABulkWaitGroup,       {{}}, @@ -3282,6 +3368,57 @@ void IntrinsicLibrary::genBarrierInit(llvm::ArrayRef<fir::ExtendedValue> args) {    mlir::NVVM::FenceProxyOp::create(builder, loc, kind, space);  } +// BARRIER_TRY_WAIT (CUDA) +mlir::Value +IntrinsicLibrary::genBarrierTryWait(mlir::Type resultType, +                                    llvm::ArrayRef<mlir::Value> args) { +  assert(args.size() == 2); +  mlir::Value res = fir::AllocaOp::create(builder, loc, resultType); +  mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0); +  fir::StoreOp::create(builder, loc, zero, res); +  mlir::Value ns = +      builder.createIntegerConstant(loc, builder.getI32Type(), 1000000); +  mlir::Value load = fir::LoadOp::create(builder, loc, res); +  auto whileOp = mlir::scf::WhileOp::create( +      builder, loc, mlir::TypeRange{resultType}, mlir::ValueRange{load}); +  mlir::Block *beforeBlock = builder.createBlock(&whileOp.getBefore()); +  mlir::Value beforeArg = beforeBlock->addArgument(resultType, loc); +  builder.setInsertionPointToStart(beforeBlock); +  mlir::Value condition = mlir::arith::CmpIOp::create( +      builder, loc, mlir::arith::CmpIPredicate::ne, beforeArg, zero); +  mlir::scf::ConditionOp::create(builder, loc, condition, beforeArg); +  mlir::Block *afterBlock = builder.createBlock(&whileOp.getAfter()); +  afterBlock->addArgument(resultType, loc); +  builder.setInsertionPointToStart(afterBlock); +  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext()); +  auto barrier = builder.createConvert(loc, llvmPtrTy, args[0]); +  mlir::Value ret = +      mlir::NVVM::InlinePtxOp::create( +          builder, loc, {resultType}, {barrier, args[1], ns}, {}, +          ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%1], %2, %3; " +          "selp.b32 %0, 1, 0, p;", +          {}) +          .getResult(0); +  mlir::scf::YieldOp::create(builder, loc, ret); +  builder.setInsertionPointAfter(whileOp); +  return whileOp.getResult(0); +} + +// BARRIER_TRY_WAIT_SLEEP (CUDA) +mlir::Value +IntrinsicLibrary::genBarrierTryWaitSleep(mlir::Type resultType, +                                         llvm::ArrayRef<mlir::Value> args) { +  assert(args.size() == 3); +  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext()); +  auto barrier = builder.createConvert(loc, llvmPtrTy, args[0]); +  return mlir::NVVM::InlinePtxOp::create( +             builder, loc, {resultType}, {barrier, args[1], args[2]}, {}, +             ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%1], %2, %3; " +             "selp.b32 %0, 1, 0, p;", +             {}) +      .getResult(0); +} +  // BESSEL_JN  fir::ExtendedValue  IntrinsicLibrary::genBesselJn(mlir::Type resultType, @@ -9218,6 +9355,95 @@ void IntrinsicLibrary::genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue> args) {        builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {});  } +static void genTMABulkLoad(fir::FirOpBuilder &builder, mlir::Location loc, +                           mlir::Value barrier, mlir::Value src, +                           mlir::Value dst, mlir::Value nelem, +                           mlir::Value eleSize) { +  mlir::Value size = mlir::arith::MulIOp::create(builder, loc, nelem, eleSize); +  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext()); +  barrier = builder.createConvert(loc, llvmPtrTy, barrier); +  dst = builder.createConvert(loc, llvmPtrTy, dst); +  src = builder.createConvert(loc, llvmPtrTy, src); +  mlir::NVVM::InlinePtxOp::create( +      builder, loc, mlir::TypeRange{}, {dst, src, size, barrier}, {}, +      "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], " +      "[%1], %2, [%3];", +      {}); +  mlir::NVVM::InlinePtxOp::create( +      builder, loc, mlir::TypeRange{}, {barrier, size}, {}, +      "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;", {}); +} + +// TMA_BULK_LOADC4 +void IntrinsicLibrary::genTMABulkLoadC4( +    llvm::ArrayRef<fir::ExtendedValue> args) { +  assert(args.size() == 4); +  mlir::Value eleSize = +      builder.createIntegerConstant(loc, builder.getI32Type(), 8); +  genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), +                 fir::getBase(args[2]), fir::getBase(args[3]), eleSize); +} + +// TMA_BULK_LOADC8 +void IntrinsicLibrary::genTMABulkLoadC8( +    llvm::ArrayRef<fir::ExtendedValue> args) { +  assert(args.size() == 4); +  mlir::Value eleSize = +      builder.createIntegerConstant(loc, builder.getI32Type(), 16); +  genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), +                 fir::getBase(args[2]), fir::getBase(args[3]), eleSize); +} + +// TMA_BULK_LOADI4 +void IntrinsicLibrary::genTMABulkLoadI4( +    llvm::ArrayRef<fir::ExtendedValue> args) { +  assert(args.size() == 4); +  mlir::Value eleSize = +      builder.createIntegerConstant(loc, builder.getI32Type(), 4); +  genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), +                 fir::getBase(args[2]), fir::getBase(args[3]), eleSize); +} + +// TMA_BULK_LOADI8 +void IntrinsicLibrary::genTMABulkLoadI8( +    llvm::ArrayRef<fir::ExtendedValue> args) { +  assert(args.size() == 4); +  mlir::Value eleSize = +      builder.createIntegerConstant(loc, builder.getI32Type(), 8); +  genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), +                 fir::getBase(args[2]), fir::getBase(args[3]), eleSize); +} + +// TMA_BULK_LOADR2 +void IntrinsicLibrary::genTMABulkLoadR2( +    llvm::ArrayRef<fir::ExtendedValue> args) { +  assert(args.size() == 4); +  mlir::Value eleSize = +      builder.createIntegerConstant(loc, builder.getI32Type(), 2); +  genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), +                 fir::getBase(args[2]), fir::getBase(args[3]), eleSize); +} + +// TMA_BULK_LOADR4 +void IntrinsicLibrary::genTMABulkLoadR4( +    llvm::ArrayRef<fir::ExtendedValue> args) { +  assert(args.size() == 4); +  mlir::Value eleSize = +      builder.createIntegerConstant(loc, builder.getI32Type(), 4); +  genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), +                 fir::getBase(args[2]), fir::getBase(args[3]), eleSize); +} + +// TMA_BULK_LOADR8 +void IntrinsicLibrary::genTMABulkLoadR8( +    llvm::ArrayRef<fir::ExtendedValue> args) { +  assert(args.size() == 4); +  mlir::Value eleSize = +      builder.createIntegerConstant(loc, builder.getI32Type(), 8); +  genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), +                 fir::getBase(args[2]), fir::getBase(args[3]), eleSize); +} +  // TMA_BULK_S2G (CUDA)  void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) {    assert(args.size() == 3); @@ -9227,6 +9453,97 @@ void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) {                                            mlir::NVVM::NVVMMemorySpace::Global);    mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create(        builder, loc, dst, src, fir::getBase(args[2]), {}, {}); + +  mlir::NVVM::InlinePtxOp::create(builder, loc, mlir::TypeRange{}, {}, {}, +                                  "cp.async.bulk.commit_group", {}); +  mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc, +                                             builder.getI32IntegerAttr(0), {}); +} + +static void genTMABulkStore(fir::FirOpBuilder &builder, mlir::Location loc, +                            mlir::Value src, mlir::Value dst, mlir::Value count, +                            mlir::Value eleSize) { +  mlir::Value size = mlir::arith::MulIOp::create(builder, loc, eleSize, count); +  src = convertPtrToNVVMSpace(builder, loc, src, +                              mlir::NVVM::NVVMMemorySpace::Shared); +  dst = convertPtrToNVVMSpace(builder, loc, dst, +                              mlir::NVVM::NVVMMemorySpace::Global); +  mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create(builder, loc, dst, src, +                                                     size, {}, {}); +  mlir::NVVM::InlinePtxOp::create(builder, loc, mlir::TypeRange{}, {}, {}, +                                  "cp.async.bulk.commit_group", {}); +  mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc, +                                             builder.getI32IntegerAttr(0), {}); +} + +// TMA_BULK_STORE_C4 (CUDA) +void IntrinsicLibrary::genTMABulkStoreC4( +    llvm::ArrayRef<fir::ExtendedValue> args) { +  assert(args.size() == 3); +  mlir::Value eleSize = +      builder.createIntegerConstant(loc, builder.getI32Type(), 8); +  genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), +                  fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_C8 (CUDA) +void IntrinsicLibrary::genTMABulkStoreC8( +    llvm::ArrayRef<fir::ExtendedValue> args) { +  assert(args.size() == 3); +  mlir::Value eleSize = +      builder.createIntegerConstant(loc, builder.getI32Type(), 16); +  genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), +                  fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_I4 (CUDA) +void IntrinsicLibrary::genTMABulkStoreI4( +    llvm::ArrayRef<fir::ExtendedValue> args) { +  assert(args.size() == 3); +  mlir::Value eleSize = +      builder.createIntegerConstant(loc, builder.getI32Type(), 4); +  genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), +                  fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_I8 (CUDA) +void IntrinsicLibrary::genTMABulkStoreI8( +    llvm::ArrayRef<fir::ExtendedValue> args) { +  assert(args.size() == 3); +  mlir::Value eleSize = +      builder.createIntegerConstant(loc, builder.getI32Type(), 8); +  genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), +                  fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_R2 (CUDA) +void IntrinsicLibrary::genTMABulkStoreR2( +    llvm::ArrayRef<fir::ExtendedValue> args) { +  assert(args.size() == 3); +  mlir::Value eleSize = +      builder.createIntegerConstant(loc, builder.getI32Type(), 2); +  genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), +                  fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_R4 (CUDA) +void IntrinsicLibrary::genTMABulkStoreR4( +    llvm::ArrayRef<fir::ExtendedValue> args) { +  assert(args.size() == 3); +  mlir::Value eleSize = +      builder.createIntegerConstant(loc, builder.getI32Type(), 4); +  genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), +                  fir::getBase(args[2]), eleSize); +} + +// TMA_BULK_STORE_R8 (CUDA) +void IntrinsicLibrary::genTMABulkStoreR8( +    llvm::ArrayRef<fir::ExtendedValue> args) { +  assert(args.size() == 3); +  mlir::Value eleSize = +      builder.createIntegerConstant(loc, builder.getI32Type(), 8); +  genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]), +                  fir::getBase(args[2]), eleSize);  }  // TMA_BULK_WAIT_GROUP (CUDA) diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 478ab15..ca4aefb 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -680,6 +680,18 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {      if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())        llvmCall.setResAttrsAttr(resAttrs); +    if (auto inlineAttr = call.getInlineAttrAttr()) { +      llvmCall->removeAttr("inline_attr"); +      if (inlineAttr.getValue() == fir::FortranInlineEnum::no_inline) { +        llvmCall.setNoInlineAttr(rewriter.getUnitAttr()); +      } else if (inlineAttr.getValue() == fir::FortranInlineEnum::inline_hint) { +        llvmCall.setInlineHintAttr(rewriter.getUnitAttr()); +      } else if (inlineAttr.getValue() == +                 fir::FortranInlineEnum::always_inline) { +        llvmCall.setAlwaysInlineAttr(rewriter.getUnitAttr()); +      } +    } +      if (memAttr)        llvmCall.setMemoryEffectsAttr(            mlir::cast<mlir::LLVM::MemoryEffectsAttr>(memAttr)); diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp index 0776346..8ca2869 100644 --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -143,7 +143,8 @@ public:          llvm::SmallVector<mlir::Type> operandsTypes;          for (auto arg : gpuLaunchFunc.getKernelOperands())            operandsTypes.push_back(arg.getType()); -        auto fctTy = mlir::FunctionType::get(&context, operandsTypes, {}); +        auto fctTy = mlir::FunctionType::get(&context, operandsTypes, +                                             gpuLaunchFunc.getResultTypes());          if (!hasPortableSignature(fctTy, op))            convertCallOp(gpuLaunchFunc, fctTy);        } else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) { @@ -520,10 +521,14 @@ public:      llvm::SmallVector<mlir::Value, 1> newCallResults;      // TODO propagate/update call argument and result attributes.      if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) { +      mlir::Value asyncToken = callOp.getAsyncToken();        auto newCall = A::create(*rewriter, loc, callOp.getKernel(),                                 callOp.getGridSizeOperandValues(),                                 callOp.getBlockSizeOperandValues(), -                               callOp.getDynamicSharedMemorySize(), newOpers); +                               callOp.getDynamicSharedMemorySize(), newOpers, +                               asyncToken ? asyncToken.getType() : nullptr, +                               callOp.getAsyncDependencies(), +                               /*clusterSize=*/std::nullopt);        if (callOp.getClusterSizeX())          newCall.getClusterSizeXMutable().assign(callOp.getClusterSizeX());        if (callOp.getClusterSizeY()) diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index d0164f3..4f97aca 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -4484,7 +4484,7 @@ void fir::IfOp::getSuccessorRegions(      llvm::SmallVectorImpl<mlir::RegionSuccessor> ®ions) {    // The `then` and the `else` region branch back to the parent operation.    if (!point.isParent()) { -    regions.push_back(mlir::RegionSuccessor(getResults())); +    regions.push_back(mlir::RegionSuccessor(getOperation(), getResults()));      return;    } @@ -4494,7 +4494,8 @@ void fir::IfOp::getSuccessorRegions(    // Don't consider the else region if it is empty.    mlir::Region *elseRegion = &this->getElseRegion();    if (elseRegion->empty()) -    regions.push_back(mlir::RegionSuccessor()); +    regions.push_back( +        mlir::RegionSuccessor(getOperation(), getOperation()->getResults()));    else      regions.push_back(mlir::RegionSuccessor(elseRegion));  } @@ -4513,7 +4514,7 @@ void fir::IfOp::getEntrySuccessorRegions(      if (!getElseRegion().empty())        regions.emplace_back(&getElseRegion());      else -      regions.emplace_back(getResults()); +      regions.emplace_back(getOperation(), getOperation()->getResults());    }  } diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp index ed9e41c..ae0f5fb8 100644 --- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp +++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp @@ -193,6 +193,28 @@ OpenACCMappableModel<fir::PointerType>::getOffsetInBytes(      mlir::Type type, mlir::Value var, mlir::ValueRange accBounds,      const mlir::DataLayout &dataLayout) const; +template <typename Ty> +bool OpenACCMappableModel<Ty>::hasUnknownDimensions(mlir::Type type) const { +  assert(fir::isa_ref_type(type) && "expected FIR reference type"); +  return fir::hasDynamicSize(fir::unwrapRefType(type)); +} + +template bool OpenACCMappableModel<fir::ReferenceType>::hasUnknownDimensions( +    mlir::Type type) const; + +template bool OpenACCMappableModel<fir::HeapType>::hasUnknownDimensions( +    mlir::Type type) const; + +template bool OpenACCMappableModel<fir::PointerType>::hasUnknownDimensions( +    mlir::Type type) const; + +template <> +bool OpenACCMappableModel<fir::BaseBoxType>::hasUnknownDimensions( +    mlir::Type type) const { +  // Descriptor-based entities have dimensions encoded. +  return false; +} +  static llvm::SmallVector<mlir::Value>  generateSeqTyAccBounds(fir::SequenceType seqType, mlir::Value var,                         mlir::OpBuilder &builder) { diff --git a/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp b/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp index 25a8f7a..8c0acc5 100644 --- a/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp @@ -246,7 +246,8 @@ struct DispatchOpConv : public OpConversionPattern<fir::DispatchOp> {      args.append(dispatch.getArgs().begin(), dispatch.getArgs().end());      rewriter.replaceOpWithNewOp<fir::CallOp>(          dispatch, resTypes, nullptr, args, dispatch.getArgAttrsAttr(), -        dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr()); +        dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr(), +        /*inline_attr*/ fir::FortranInlineEnumAttr{});      return mlir::success();    } diff --git a/flang/lib/Parser/Fortran-parsers.cpp b/flang/lib/Parser/Fortran-parsers.cpp index d33a18f..59fe7d8 100644 --- a/flang/lib/Parser/Fortran-parsers.cpp +++ b/flang/lib/Parser/Fortran-parsers.cpp @@ -1314,6 +1314,11 @@ constexpr auto novector{"NOVECTOR" >> construct<CompilerDirective::NoVector>()};  constexpr auto nounroll{"NOUNROLL" >> construct<CompilerDirective::NoUnroll>()};  constexpr auto nounrollAndJam{      "NOUNROLL_AND_JAM" >> construct<CompilerDirective::NoUnrollAndJam>()}; +constexpr auto forceinlineDir{ +    "FORCEINLINE" >> construct<CompilerDirective::ForceInline>()}; +constexpr auto noinlineDir{ +    "NOINLINE" >> construct<CompilerDirective::NoInline>()}; +constexpr auto inlineDir{"INLINE" >> construct<CompilerDirective::Inline>()};  TYPE_PARSER(beginDirective >> "DIR$ "_tok >>      sourced((construct<CompilerDirective>(ignore_tkr) ||                  construct<CompilerDirective>(loopCount) || @@ -1324,6 +1329,9 @@ TYPE_PARSER(beginDirective >> "DIR$ "_tok >>                  construct<CompilerDirective>(novector) ||                  construct<CompilerDirective>(nounrollAndJam) ||                  construct<CompilerDirective>(nounroll) || +                construct<CompilerDirective>(noinlineDir) || +                construct<CompilerDirective>(forceinlineDir) || +                construct<CompilerDirective>(inlineDir) ||                  construct<CompilerDirective>(                      many(construct<CompilerDirective::NameValue>(                          name, maybe(("="_tok || ":"_tok) >> digitString64))))) / diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp index d1e081c..4159d2e 100644 --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -275,6 +275,13 @@ struct SpecificModifierParser {  // --- Iterator helpers ----------------------------------------------- +static EntityDecl MakeEntityDecl(ObjectName &&name) { +  return EntityDecl( +      /*ObjectName=*/std::move(name), std::optional<ArraySpec>{}, +      std::optional<CoarraySpec>{}, std::optional<CharLength>{}, +      std::optional<Initialization>{}); +} +  // [5.0:47:17-18] In an iterator-specifier, if the iterator-type is not  // specified then the type of that iterator is default integer.  // [5.0:49:14] The iterator-type must be an integer type. @@ -282,11 +289,7 @@ static std::list<EntityDecl> makeEntityList(std::list<ObjectName> &&names) {    std::list<EntityDecl> entities;    for (auto iter = names.begin(), end = names.end(); iter != end; ++iter) { -    EntityDecl entityDecl( -        /*ObjectName=*/std::move(*iter), std::optional<ArraySpec>{}, -        std::optional<CoarraySpec>{}, std::optional<CharLength>{}, -        std::optional<Initialization>{}); -    entities.push_back(std::move(entityDecl)); +    entities.push_back(MakeEntityDecl(std::move(*iter)));    }    return entities;  } @@ -306,6 +309,217 @@ static TypeDeclarationStmt makeIterSpecDecl(std::list<ObjectName> &&names) {        makeEntityList(std::move(names)));  } +// --- Stylized expression handling ----------------------------------- + +// OpenMP has a concept of am "OpenMP stylized expression". Syntactially +// it looks like a typical Fortran expression (or statement), except: +// - the only variables allowed in it are OpenMP special variables, the +//   exact set of these variables depends on the specific case of the +//   stylized expression +// - the special OpenMP variables present may assume one or more types, +//   and the expression should be semantically valid for each type. +// +// The stylized expression can be thought of as a template, which will be +// instantiated for each type provided somewhere in the context in which +// the stylized expression appears. +// +// AST nodes: +// - OmpStylizedExpression: contains the source string for the expression, +//   plus the list of instances (OmpStylizedInstance). +// - OmpStylizedInstance: corresponds to the instantiation of the stylized +//   expression for a specific type. The way that the type is specified is +//   by creating declarations (OmpStylizedDeclaration) for the special +//   variables. Together with the AST tree corresponding to the stylized +//   expression the instantiation has enough information for semantic +//   analysis. Each instance has its own scope, and the special variables +//   have their own Symbol's (local to the scope). +// - OmpStylizedDeclaration: encapsulates the information that the visitors +//   in resolve-names can use to "emulate" a declaration for a special +//   variable and allow name resolution in the instantiation AST to work. +// +// Implementation specifics: +// The semantic analysis stores "evaluate::Expr" in each AST node rooted +// in parser::Expr (in the typedExpr member). The evaluate::Expr is specific +// to a given type, and so to allow different types for a given expression, +// for each type a separate copy of the parser::Expr subtree is created. +// Normally, AST nodes are non-copyable (copy-ctor is deleted), so to create +// several copies of a subtree, the same source string is parsed several +// times. The ParseState member in OmpStylizedExpression is the parser state +// immediately before the stylized expression. +// +// Initially, when OmpStylizedExpression is first created, the expression is +// parsed as if it was an actual code, but this parsing is only done to +// establish where the stylized expression ends (in the source). The source +// and the initial parser state are stored in the object, and the instance +// list is empty. +// Once the parsing of the containing OmpDirectiveSpecification completes, +// a post-processing "parser" (OmpStylizedInstanceCreator) executes. This +// post-processor examines the directive specification to see if it expects +// any stylized expressions to be contained in it, and then instantiates +// them for each such directive. + +template <typename A> struct NeverParser { +  using resultType = A; +  std::optional<resultType> Parse(ParseState &state) const { +    // Always fail, but without any messages. +    return std::nullopt; +  } +}; + +template <typename A> constexpr auto never() { return NeverParser<A>{}; } + +// Parser for optional<T> which always succeeds and returns std::nullptr. +// It's only needed to produce "std::optional<CallStmt::Chevrons>" in +// CallStmt. +template <typename A, typename B = void> struct NullParser; +template <typename B> struct NullParser<std::optional<B>> { +  using resultType = std::optional<B>; +  std::optional<resultType> Parse(ParseState &) const { +    return resultType{std::nullopt}; +  } +}; + +template <typename A> constexpr auto null() { return NullParser<A>{}; } + +// OmpStylizedDeclaration and OmpStylizedInstance are helper classes, and +// don't correspond to anything in the source. Their parsers should still +// exist, but they should never be executed. +TYPE_PARSER(construct<OmpStylizedDeclaration>(never<OmpStylizedDeclaration>())) +TYPE_PARSER(construct<OmpStylizedInstance>(never<OmpStylizedInstance>())) + +TYPE_PARSER( // +    construct<OmpStylizedInstance::Instance>(Parser<AssignmentStmt>{}) || +    construct<OmpStylizedInstance::Instance>( +        sourced(construct<CallStmt>(Parser<ProcedureDesignator>{}, +            null<std::optional<CallStmt::Chevrons>>(), +            parenthesized(optionalList(actualArgSpec))))) || +    construct<OmpStylizedInstance::Instance>(indirect(expr))) + +struct OmpStylizedExpressionParser { +  using resultType = OmpStylizedExpression; + +  std::optional<resultType> Parse(ParseState &state) const { +    auto *saved{new ParseState(state)}; +    auto getSource{verbatim(Parser<OmpStylizedInstance::Instance>{} >> ok)}; +    if (auto &&ok{getSource.Parse(state)}) { +      OmpStylizedExpression result{std::list<OmpStylizedInstance>{}}; +      result.source = ok->source; +      result.state = saved; +      // result.v remains empty +      return std::move(result); +    } +    delete saved; +    return std::nullopt; +  } +}; + +static void Instantiate(OmpStylizedExpression &ose, +    llvm::ArrayRef<const OmpTypeName *> types, llvm::ArrayRef<CharBlock> vars) { +  // 1. For each var in the vars list, declare it with the corresponding +  //    type from types. +  // 2. Run the parser to get the AST for the stylized expression. +  // 3. Create OmpStylizedInstance and append it to the list in ose. +  assert(types.size() == vars.size() && "List size mismatch"); +  // A ParseState object is irreversibly modified during parsing (in +  // particular, it cannot be rewound to an earlier position in the source). +  // Because of that we need to create a local copy for each instantiation. +  // If rewinding was possible, we could just use the current one, and we +  // wouldn't need to save it in the AST node. +  ParseState state{DEREF(ose.state)}; + +  std::list<OmpStylizedDeclaration> decls; +  for (auto [type, var] : llvm::zip_equal(types, vars)) { +    decls.emplace_back(OmpStylizedDeclaration{ +        common::Reference(*type), MakeEntityDecl(Name{var})}); +  } + +  if (auto &&instance{Parser<OmpStylizedInstance::Instance>{}.Parse(state)}) { +    ose.v.emplace_back( +        OmpStylizedInstance{std::move(decls), std::move(*instance)}); +  } +} + +static void InstantiateForTypes(OmpStylizedExpression &ose, +    const OmpTypeNameList &typeNames, llvm::ArrayRef<CharBlock> vars) { +  // For each type in the type list, declare all variables in vars with +  // that type, and complete the instantiation. +  for (const OmpTypeName &t : typeNames.v) { +    std::vector<const OmpTypeName *> types(vars.size(), &t); +    Instantiate(ose, types, vars); +  } +} + +static void InstantiateDeclareReduction(OmpDirectiveSpecification &spec) { +  // There can be arguments/clauses that don't make sense, that analysis +  // is left until semantic checks. Tolerate any unexpected stuff. +  auto *rspec{GetFirstArgument<OmpReductionSpecifier>(spec)}; +  if (!rspec) { +    return; +  } + +  const OmpTypeNameList *typeNames{nullptr}; + +  if (auto *cexpr{ +          const_cast<OmpCombinerExpression *>(GetCombinerExpr(*rspec))}) { +    typeNames = &std::get<OmpTypeNameList>(rspec->t); + +    InstantiateForTypes(*cexpr, *typeNames, OmpCombinerExpression::Variables()); +    delete cexpr->state; +    cexpr->state = nullptr; +  } else { +    // If there are no types, there is nothing else to do. +    return; +  } + +  for (const OmpClause &clause : spec.Clauses().v) { +    llvm::omp::Clause id{clause.Id()}; +    if (id == llvm::omp::Clause::OMPC_initializer) { +      if (auto *iexpr{const_cast<OmpInitializerExpression *>( +              GetInitializerExpr(clause))}) { +        InstantiateForTypes( +            *iexpr, *typeNames, OmpInitializerExpression::Variables()); +        delete iexpr->state; +        iexpr->state = nullptr; +      } +    } +  } +} + +static void InstantiateStylizedDirective(OmpDirectiveSpecification &spec) { +  const OmpDirectiveName &dirName{spec.DirName()}; +  if (dirName.v == llvm::omp::Directive::OMPD_declare_reduction) { +    InstantiateDeclareReduction(spec); +  } +} + +template <typename P, +    typename = std::enable_if_t< +        std::is_same_v<typename P::resultType, OmpDirectiveSpecification>>> +struct OmpStylizedInstanceCreator { +  using resultType = OmpDirectiveSpecification; +  constexpr OmpStylizedInstanceCreator(P p) : parser_(p) {} + +  std::optional<resultType> Parse(ParseState &state) const { +    if (auto &&spec{parser_.Parse(state)}) { +      InstantiateStylizedDirective(*spec); +      return std::move(spec); +    } +    return std::nullopt; +  } + +private: +  const P parser_; +}; + +template <typename P> +OmpStylizedInstanceCreator(P) -> OmpStylizedInstanceCreator<P>; + +// --- Parsers for types ---------------------------------------------- + +TYPE_PARSER( // +    sourced(construct<OmpTypeName>(Parser<DeclarationTypeSpec>{})) || +    sourced(construct<OmpTypeName>(Parser<TypeSpec>{}))) +  // --- Parsers for arguments ------------------------------------------  // At the moment these are only directive arguments. This is needed for @@ -366,10 +580,6 @@ struct OmpArgumentListParser {    }  }; -TYPE_PARSER( // -    construct<OmpTypeName>(Parser<DeclarationTypeSpec>{}) || -    construct<OmpTypeName>(Parser<TypeSpec>{})) -  // 2.15.3.6 REDUCTION (reduction-identifier: variable-name-list)  TYPE_PARSER(construct<OmpReductionIdentifier>(Parser<DefinedOperator>{}) ||      construct<OmpReductionIdentifier>(Parser<ProcedureDesignator>{})) @@ -1065,7 +1275,8 @@ TYPE_PARSER(construct<OmpOtherwiseClause>(  TYPE_PARSER(construct<OmpWhenClause>(      maybe(nonemptyList(Parser<OmpWhenClause::Modifier>{}) / ":"), -    maybe(indirect(Parser<OmpDirectiveSpecification>{})))) +    maybe(indirect( +        OmpStylizedInstanceCreator(Parser<OmpDirectiveSpecification>{})))))  // OMP 5.2 12.6.1 grainsize([ prescriptiveness :] scalar-integer-expression)  TYPE_PARSER(construct<OmpGrainsizeClause>( @@ -1777,12 +1988,7 @@ TYPE_PARSER(              Parser<OpenMPInteropConstruct>{})) /      endOfLine) -TYPE_PARSER(construct<OmpInitializerProc>(Parser<ProcedureDesignator>{}, -    parenthesized(many(maybe(","_tok) >> Parser<ActualArgSpec>{})))) - -TYPE_PARSER(construct<OmpInitializerClause>( -    construct<OmpInitializerClause>(assignmentStmt) || -    construct<OmpInitializerClause>(Parser<OmpInitializerProc>{}))) +TYPE_PARSER(construct<OmpInitializerClause>(Parser<OmpInitializerExpression>{}))  // OpenMP 5.2: 7.5.4 Declare Variant directive  TYPE_PARSER(sourced(construct<OmpDeclareVariantDirective>( @@ -1794,7 +2000,7 @@ TYPE_PARSER(sourced(construct<OmpDeclareVariantDirective>(  TYPE_PARSER(sourced(construct<OpenMPDeclareReductionConstruct>(      predicated(Parser<OmpDirectiveName>{},          IsDirective(llvm::omp::Directive::OMPD_declare_reduction)) >= -    Parser<OmpDirectiveSpecification>{}))) +    OmpStylizedInstanceCreator(Parser<OmpDirectiveSpecification>{}))))  // 2.10.6 Declare Target Construct  TYPE_PARSER(sourced(construct<OpenMPDeclareTargetConstruct>( @@ -1832,8 +2038,8 @@ TYPE_PARSER(sourced(construct<OpenMPDeclareMapperConstruct>(          IsDirective(llvm::omp::Directive::OMPD_declare_mapper)) >=      Parser<OmpDirectiveSpecification>{}))) -TYPE_PARSER(construct<OmpCombinerExpression>(Parser<AssignmentStmt>{}) || -    construct<OmpCombinerExpression>(Parser<FunctionReference>{})) +TYPE_PARSER(construct<OmpCombinerExpression>(OmpStylizedExpressionParser{})) +TYPE_PARSER(construct<OmpInitializerExpression>(OmpStylizedExpressionParser{}))  TYPE_PARSER(sourced(construct<OpenMPCriticalConstruct>(      OmpBlockConstructParser{llvm::omp::Directive::OMPD_critical}))) diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp index 937a17f..95ad3f6 100644 --- a/flang/lib/Parser/openmp-utils.cpp +++ b/flang/lib/Parser/openmp-utils.cpp @@ -74,4 +74,16 @@ const BlockConstruct *GetFortranBlockConstruct(    return nullptr;  } +const OmpCombinerExpression *GetCombinerExpr( +    const OmpReductionSpecifier &rspec) { +  return addr_if(std::get<std::optional<OmpCombinerExpression>>(rspec.t)); +} + +const OmpInitializerExpression *GetInitializerExpr(const OmpClause &init) { +  if (auto *wrapped{std::get_if<OmpClause::Initializer>(&init.u)}) { +    return &wrapped->v.v; +  } +  return nullptr; +} +  } // namespace Fortran::parser::omp diff --git a/flang/lib/Parser/parse-tree.cpp b/flang/lib/Parser/parse-tree.cpp index 8cbaa39..ad0016e 100644 --- a/flang/lib/Parser/parse-tree.cpp +++ b/flang/lib/Parser/parse-tree.cpp @@ -11,6 +11,7 @@  #include "flang/Common/indirection.h"  #include "flang/Parser/tools.h"  #include "flang/Parser/user-state.h" +#include "llvm/ADT/ArrayRef.h"  #include "llvm/Frontend/OpenMP/OMP.h"  #include "llvm/Support/raw_ostream.h"  #include <algorithm> @@ -430,4 +431,30 @@ const OmpClauseList &OmpDirectiveSpecification::Clauses() const {    }    return empty;  } + +static bool InitCharBlocksFromStrings(llvm::MutableArrayRef<CharBlock> blocks, +    llvm::ArrayRef<std::string> strings) { +  for (auto [i, n] : llvm::enumerate(strings)) { +    blocks[i] = CharBlock(n); +  } +  return true; +} + +// The names should have static storage duration. Keep these names +// in a sigle place. +llvm::ArrayRef<CharBlock> OmpCombinerExpression::Variables() { +  static std::string names[]{"omp_in", "omp_out"}; +  static CharBlock vars[std::size(names)]; + +  [[maybe_unused]] static bool init = InitCharBlocksFromStrings(vars, names); +  return vars; +} + +llvm::ArrayRef<CharBlock> OmpInitializerExpression::Variables() { +  static std::string names[]{"omp_orig", "omp_priv"}; +  static CharBlock vars[std::size(names)]; + +  [[maybe_unused]] static bool init = InitCharBlocksFromStrings(vars, names); +  return vars; +}  } // namespace Fortran::parser diff --git a/flang/lib/Parser/prescan.cpp b/flang/lib/Parser/prescan.cpp index 4739da0..fd69404 100644 --- a/flang/lib/Parser/prescan.cpp +++ b/flang/lib/Parser/prescan.cpp @@ -557,7 +557,7 @@ bool Prescanner::MustSkipToEndOfLine() const {      return true; // skip over ignored columns in right margin (73:80)    } else if (*at_ == '!' && !inCharLiteral_ &&        (!inFixedForm_ || tabInCurrentLine_ || column_ != 6)) { -    return !IsCompilerDirectiveSentinel(at_); +    return !IsCompilerDirectiveSentinel(at_ + 1);    } else {      return false;    } diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp index 2f86c76..9b38cfc 100644 --- a/flang/lib/Parser/unparse.cpp +++ b/flang/lib/Parser/unparse.cpp @@ -1867,6 +1867,13 @@ public:              [&](const CompilerDirective::NoUnrollAndJam &) {                Word("!DIR$ NOUNROLL_AND_JAM");              }, +            [&](const CompilerDirective::ForceInline &) { +              Word("!DIR$ FORCEINLINE"); +            }, +            [&](const CompilerDirective::Inline &) { Word("!DIR$ INLINE"); }, +            [&](const CompilerDirective::NoInline &) { +              Word("!DIR$ NOINLINE"); +            },              [&](const CompilerDirective::Unrecognized &) {                Word("!DIR$ ");                Word(x.source.ToString()); @@ -2088,15 +2095,13 @@ public:    // OpenMP Clauses & Directives    void Unparse(const OmpArgumentList &x) { Walk(x.v, ", "); } +  void Unparse(const OmpTypeNameList &x) { Walk(x.v, ", "); }    void Unparse(const OmpBaseVariantNames &x) {      Walk(std::get<0>(x.t)); // OmpObject      Put(":");      Walk(std::get<1>(x.t)); // OmpObject    } -  void Unparse(const OmpTypeNameList &x) { // -    Walk(x.v, ","); -  }    void Unparse(const OmpMapperSpecifier &x) {      const auto &mapperName{std::get<std::string>(x.t)};      if (mapperName.find(llvm::omp::OmpDefaultMapperName) == std::string::npos) { @@ -2195,6 +2200,15 @@ public:      unsigned ompVersion{langOpts_.OpenMPVersion};      Word(llvm::omp::getOpenMPDirectiveName(x.v, ompVersion));    } +  void Unparse(const OmpStylizedDeclaration &x) { +    // empty +  } +  void Unparse(const OmpStylizedExpression &x) { // +    Put(x.source.ToString()); +  } +  void Unparse(const OmpStylizedInstance &x) { +    // empty +  }    void Unparse(const OmpIteratorSpecifier &x) {      Walk(std::get<TypeDeclarationStmt>(x.t));      Put(" = "); @@ -2504,29 +2518,11 @@ public:    void Unparse(const OpenMPCriticalConstruct &x) {      Unparse(static_cast<const OmpBlockConstruct &>(x));    } -  void Unparse(const OmpInitializerProc &x) { -    Walk(std::get<ProcedureDesignator>(x.t)); -    Put("("); -    Walk(std::get<std::list<ActualArgSpec>>(x.t)); -    Put(")"); -  } -  void Unparse(const OmpInitializerClause &x) { -    // Don't let the visitor go to the normal AssignmentStmt Unparse function, -    // it adds an extra newline that we don't want. -    if (const auto *assignment{std::get_if<AssignmentStmt>(&x.u)}) { -      Walk(assignment->t, " = "); -    } else { -      Walk(x.u); -    } +  void Unparse(const OmpInitializerExpression &x) { +    Unparse(static_cast<const OmpStylizedExpression &>(x));    }    void Unparse(const OmpCombinerExpression &x) { -    // Don't let the visitor go to the normal AssignmentStmt Unparse function, -    // it adds an extra newline that we don't want. -    if (const auto *assignment{std::get_if<AssignmentStmt>(&x.u)}) { -      Walk(assignment->t, " = "); -    } else { -      Walk(x.u); -    } +    Unparse(static_cast<const OmpStylizedExpression &>(x));    }    void Unparse(const OpenMPDeclareReductionConstruct &x) {      BeginOpenMP(); diff --git a/flang/lib/Semantics/canonicalize-directives.cpp b/flang/lib/Semantics/canonicalize-directives.cpp index 104df25..a651a87 100644 --- a/flang/lib/Semantics/canonicalize-directives.cpp +++ b/flang/lib/Semantics/canonicalize-directives.cpp @@ -60,7 +60,11 @@ static bool IsExecutionDirective(const parser::CompilerDirective &dir) {        std::holds_alternative<parser::CompilerDirective::UnrollAndJam>(dir.u) ||        std::holds_alternative<parser::CompilerDirective::NoVector>(dir.u) ||        std::holds_alternative<parser::CompilerDirective::NoUnroll>(dir.u) || -      std::holds_alternative<parser::CompilerDirective::NoUnrollAndJam>(dir.u); +      std::holds_alternative<parser::CompilerDirective::NoUnrollAndJam>( +          dir.u) || +      std::holds_alternative<parser::CompilerDirective::ForceInline>(dir.u) || +      std::holds_alternative<parser::CompilerDirective::Inline>(dir.u) || +      std::holds_alternative<parser::CompilerDirective::NoInline>(dir.u);  }  void CanonicalizationOfDirectives::Post(parser::SpecificationPart &spec) { diff --git a/flang/lib/Semantics/check-call.cpp b/flang/lib/Semantics/check-call.cpp index c51d40b..995deaa 100644 --- a/flang/lib/Semantics/check-call.cpp +++ b/flang/lib/Semantics/check-call.cpp @@ -914,7 +914,8 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,              dummyName);        }        // INTENT(OUT) and INTENT(IN OUT) cases are caught elsewhere -    } else { +    } else if (!actualIsAllocatable && +        !dummy.ignoreTKR.test(common::IgnoreTKR::Pointer)) {        messages.Say(            "ALLOCATABLE %s must be associated with an ALLOCATABLE actual argument"_err_en_US,            dummyName); @@ -929,7 +930,8 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,              dummy, actual, *scope,              /*isAssumedRank=*/dummyIsAssumedRank, actualIsPointer);        } -    } else if (!actualIsPointer) { +    } else if (!actualIsPointer && +        !dummy.ignoreTKR.test(common::IgnoreTKR::Pointer)) {        messages.Say(            "Actual argument associated with POINTER %s must also be POINTER unless INTENT(IN)"_err_en_US,            dummyName); diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp index 549ee83..de407d3 100644 --- a/flang/lib/Semantics/check-declarations.cpp +++ b/flang/lib/Semantics/check-declarations.cpp @@ -949,7 +949,8 @@ void CheckHelper::CheckObjectEntity(              "!DIR$ IGNORE_TKR(R) may not apply in an ELEMENTAL procedure"_err_en_US);        }        if (IsPassedViaDescriptor(symbol)) { -        if (IsAllocatableOrObjectPointer(&symbol)) { +        if (IsAllocatableOrObjectPointer(&symbol) && +            !ignoreTKR.test(common::IgnoreTKR::Pointer)) {            if (inExplicitExternalInterface) {              Warn(common::UsageWarning::IgnoreTKRUsage,                  "!DIR$ IGNORE_TKR should not apply to an allocatable or pointer"_warn_en_US); diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp index e094458f..aaaf1ec 100644 --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -3390,6 +3390,7 @@ CHECK_SIMPLE_CLAUSE(Read, OMPC_read)  CHECK_SIMPLE_CLAUSE(Threadprivate, OMPC_threadprivate)  CHECK_SIMPLE_CLAUSE(Groupprivate, OMPC_groupprivate)  CHECK_SIMPLE_CLAUSE(Threads, OMPC_threads) +CHECK_SIMPLE_CLAUSE(Threadset, OMPC_threadset)  CHECK_SIMPLE_CLAUSE(Inbranch, OMPC_inbranch)  CHECK_SIMPLE_CLAUSE(Link, OMPC_link)  CHECK_SIMPLE_CLAUSE(Indirect, OMPC_indirect) diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp index 556259d..b419864 100644 --- a/flang/lib/Semantics/mod-file.cpp +++ b/flang/lib/Semantics/mod-file.cpp @@ -1021,6 +1021,9 @@ void ModFileWriter::PutObjectEntity(        case common::IgnoreTKR::Contiguous:          os << 'c';          break; +      case common::IgnoreTKR::Pointer: +        os << 'p'; +        break;        }      });      os << ") " << symbol.name() << '\n'; diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index 196755e..628068f 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -26,6 +26,8 @@  #include "flang/Semantics/symbol.h"  #include "flang/Semantics/tools.h"  #include "flang/Support/Flags.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h"  #include "llvm/Frontend/OpenMP/OMP.h.inc"  #include "llvm/Support/Debug.h"  #include <list> @@ -453,6 +455,21 @@ public:      return true;    } +  bool Pre(const parser::OmpStylizedDeclaration &x) { +    static llvm::StringMap<Symbol::Flag> map{ +        {"omp_in", Symbol::Flag::OmpInVar}, +        {"omp_orig", Symbol::Flag::OmpOrigVar}, +        {"omp_out", Symbol::Flag::OmpOutVar}, +        {"omp_priv", Symbol::Flag::OmpPrivVar}, +    }; +    if (auto &name{std::get<parser::ObjectName>(x.var.t)}; name.symbol) { +      if (auto found{map.find(name.ToString())}; found != map.end()) { +        ResolveOmp(name, found->second, +            const_cast<Scope &>(DEREF(name.symbol).owner())); +      } +    } +    return false; +  }    bool Pre(const parser::OmpMetadirectiveDirective &x) {      PushContext(x.v.source, llvm::omp::Directive::OMPD_metadirective);      return true; diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index 561ebd2..f88af5f 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -1605,6 +1605,12 @@ public:      Post(static_cast<const parser::OmpDirectiveSpecification &>(x));    } +  void Post(const parser::OmpTypeName &); +  bool Pre(const parser::OmpStylizedDeclaration &); +  void Post(const parser::OmpStylizedDeclaration &); +  bool Pre(const parser::OmpStylizedInstance &); +  void Post(const parser::OmpStylizedInstance &); +    bool Pre(const parser::OpenMPDeclareMapperConstruct &x) {      AddOmpSourceRange(x.source);      return true; @@ -1615,18 +1621,6 @@ public:      return true;    } -  bool Pre(const parser::OmpInitializerProc &x) { -    auto &procDes = std::get<parser::ProcedureDesignator>(x.t); -    auto &name = std::get<parser::Name>(procDes.u); -    auto *symbol{FindSymbol(NonDerivedTypeScope(), name)}; -    if (!symbol) { -      context().Say(name.source, -          "Implicit subroutine declaration '%s' in DECLARE REDUCTION"_err_en_US, -          name.source); -    } -    return true; -  } -    bool Pre(const parser::OmpDeclareVariantDirective &x) {      AddOmpSourceRange(x.source);      return true; @@ -1772,14 +1766,6 @@ public:      messageHandler().set_currStmtSource(std::nullopt);    } -  bool Pre(const parser::OmpTypeName &x) { -    BeginDeclTypeSpec(); -    return true; -  } -  void Post(const parser::OmpTypeName &x) { // -    EndDeclTypeSpec(); -  } -    bool Pre(const parser::OpenMPConstruct &x) {      // Indicate that the current directive is not a declarative one.      declaratives_.push_back(nullptr); @@ -1835,6 +1821,30 @@ void OmpVisitor::Post(const parser::OmpBlockConstruct &x) {    }  } +void OmpVisitor::Post(const parser::OmpTypeName &x) { +  x.declTypeSpec = GetDeclTypeSpec(); +} + +bool OmpVisitor::Pre(const parser::OmpStylizedDeclaration &x) { +  BeginDecl(); +  Walk(x.type.get()); +  Walk(x.var); +  return true; +} + +void OmpVisitor::Post(const parser::OmpStylizedDeclaration &x) { // +  EndDecl(); +} + +bool OmpVisitor::Pre(const parser::OmpStylizedInstance &x) { +  PushScope(Scope::Kind::OtherConstruct, nullptr); +  return true; +} + +void OmpVisitor::Post(const parser::OmpStylizedInstance &x) { // +  PopScope(); +} +  bool OmpVisitor::Pre(const parser::OmpMapClause &x) {    auto &mods{OmpGetModifiers(x)};    if (auto *mapper{OmpGetUniqueModifier<parser::OmpMapper>(mods)}) { @@ -1969,51 +1979,20 @@ void OmpVisitor::ProcessReductionSpecifier(      }    } -  auto &typeList{std::get<parser::OmpTypeNameList>(spec.t)}; - -  // Create a temporary variable declaration for the four variables -  // used in the reduction specifier and initializer (omp_out, omp_in, -  // omp_priv and omp_orig), with the type in the  typeList. -  // -  // In theory it would be possible to create only variables that are -  // actually used, but that requires walking the entire parse-tree of the -  // expressions, and finding the relevant variables [there may well be other -  // variables involved too]. -  // -  // This allows doing semantic analysis where the type is a derived type -  // e.g omp_out%x = omp_out%x + omp_in%x. -  // -  // These need to be temporary (in their own scope). If they are created -  // as variables in the outer scope, if there's more than one type in the -  // typelist, duplicate symbols will be reported. -  const parser::CharBlock ompVarNames[]{ -      {"omp_in", 6}, {"omp_out", 7}, {"omp_priv", 8}, {"omp_orig", 8}}; - -  for (auto &t : typeList.v) { -    PushScope(Scope::Kind::OtherConstruct, nullptr); -    BeginDeclTypeSpec(); -    // We need to walk t.u because Walk(t) does it's own BeginDeclTypeSpec. -    Walk(t.u); +  reductionDetails->AddDecl(declaratives_.back()); -    // Only process types we can find. There will be an error later on when -    // a type isn't found. -    if (const DeclTypeSpec *typeSpec{GetDeclTypeSpec()}) { -      reductionDetails->AddType(*typeSpec); +  // Do not walk OmpTypeNameList. The types on the list will be visited +  // during procesing of OmpCombinerExpression. +  Walk(std::get<std::optional<parser::OmpCombinerExpression>>(spec.t)); +  Walk(clauses); -      for (auto &nm : ompVarNames) { -        ObjectEntityDetails details{}; -        details.set_type(*typeSpec); -        MakeSymbol(nm, Attrs{}, std::move(details)); -      } +  for (auto &type : std::get<parser::OmpTypeNameList>(spec.t).v) { +    // The declTypeSpec can be null if there is some semantic error. +    if (type.declTypeSpec) { +      reductionDetails->AddType(*type.declTypeSpec);      } -    EndDeclTypeSpec(); -    Walk(std::get<std::optional<parser::OmpCombinerExpression>>(spec.t)); -    Walk(clauses); -    PopScope();    } -  reductionDetails->AddDecl(declaratives_.back()); -    if (!symbol) {      symbol = &MakeSymbol(mangledName, Attrs{}, std::move(*reductionDetails));    } @@ -10078,7 +10057,10 @@ void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {        std::holds_alternative<parser::CompilerDirective::UnrollAndJam>(x.u) ||        std::holds_alternative<parser::CompilerDirective::NoVector>(x.u) ||        std::holds_alternative<parser::CompilerDirective::NoUnroll>(x.u) || -      std::holds_alternative<parser::CompilerDirective::NoUnrollAndJam>(x.u)) { +      std::holds_alternative<parser::CompilerDirective::NoUnrollAndJam>(x.u) || +      std::holds_alternative<parser::CompilerDirective::ForceInline>(x.u) || +      std::holds_alternative<parser::CompilerDirective::Inline>(x.u) || +      std::holds_alternative<parser::CompilerDirective::NoInline>(x.u)) {      return;    }    if (const auto *tkr{ @@ -10127,6 +10109,9 @@ void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {                case 'c':                  set.set(common::IgnoreTKR::Contiguous);                  break; +              case 'p': +                set.set(common::IgnoreTKR::Pointer); +                break;                case 'a':                  set = common::ignoreTKRAll;                  break; diff --git a/flang/lib/Support/Fortran.cpp b/flang/lib/Support/Fortran.cpp index 3a8ebbb..05d6e0e 100644 --- a/flang/lib/Support/Fortran.cpp +++ b/flang/lib/Support/Fortran.cpp @@ -95,6 +95,9 @@ std::string AsFortran(IgnoreTKRSet tkr) {    if (tkr.test(IgnoreTKR::Contiguous)) {      result += 'C';    } +  if (tkr.test(IgnoreTKR::Pointer)) { +    result += 'P'; +  }    return result;  } | 
