diff options
Diffstat (limited to 'flang/lib')
21 files changed, 528 insertions, 223 deletions
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 3b711cc..a516a44 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -1766,7 +1766,7 @@ private: // to a crash due to a block with no terminator. See issue #126452. mlir::FunctionType funcType = builder->getFunction().getFunctionType(); mlir::Type resultType = funcType.getResult(0); - mlir::Value undefResult = builder->create<fir::UndefOp>(loc, resultType); + mlir::Value undefResult = fir::UndefOp::create(*builder, loc, resultType); genExitRoutine(false, undefResult); return; } @@ -4010,8 +4010,8 @@ private: // parameters and dynamic type. The selector cannot be a // POINTER/ALLOCATBLE as per F'2023 C1160. fir::ExtendedValue newExv; - llvm::SmallVector assumeSizeExtents{ - builder->createMinusOneInteger(loc, builder->getIndexType())}; + llvm::SmallVector<mlir::Value> assumeSizeExtents{ + fir::AssumedSizeExtentOp::create(*builder, loc)}; mlir::Value baseAddr = hlfir::genVariableRawAddress(loc, *builder, selector); const bool isVolatile = fir::isa_volatile_type(selector.getType()); @@ -4733,11 +4733,21 @@ private: return fir::factory::createUnallocatedBox(*builder, loc, lhsBoxType, {}); hlfir::Entity rhs = Fortran::lower::convertExprToHLFIR( loc, *this, assign.rhs, localSymbols, rhsContext); + auto rhsBoxType = rhs.getBoxType(); // Create pointer descriptor value from the RHS. if (rhs.isMutableBox()) rhs = hlfir::Entity{fir::LoadOp::create(*builder, loc, rhs)}; - mlir::Value rhsBox = hlfir::genVariableBox( - loc, *builder, rhs, lhsBoxType.getBoxTypeWithNewShape(rhs.getRank())); + + // Use LHS type if LHS is not polymorphic. 
+ fir::BaseBoxType targetBoxType; + if (assign.lhs.GetType()->IsPolymorphic()) + targetBoxType = rhsBoxType.getBoxTypeWithNewAttr( + fir::BaseBoxType::Attribute::Pointer); + else + targetBoxType = lhsBoxType.getBoxTypeWithNewShape(rhs.getRank()); + mlir::Value rhsBox = + hlfir::genVariableBox(loc, *builder, rhs, targetBoxType); + // Apply lower bounds or reshaping if any. if (const auto *lbExprs = std::get_if<Fortran::evaluate::Assignment::BoundsSpec>(&assign.u); diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp index 00ec1b5..2517ab3 100644 --- a/flang/lib/Lower/ConvertVariable.cpp +++ b/flang/lib/Lower/ConvertVariable.cpp @@ -1711,7 +1711,7 @@ static void lowerExplicitLowerBounds( /// CFI_desc_t requirements in 18.5.3 point 5.). static mlir::Value getAssumedSizeExtent(mlir::Location loc, fir::FirOpBuilder &builder) { - return builder.createMinusOneInteger(loc, builder.getIndexType()); + return fir::AssumedSizeExtentOp::create(builder, loc); } /// Lower explicit extents into \p result if this is an explicit-shape or diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index a49961c..7106728 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -2059,37 +2059,38 @@ static void genCanonicalLoopNest( // Start lowering mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0); mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1); - mlir::Value isDownwards = firOpBuilder.create<mlir::arith::CmpIOp>( - loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero); + mlir::Value isDownwards = mlir::arith::CmpIOp::create( + firOpBuilder, loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero); // Ensure we are counting upwards. If not, negate step and swap lb and ub. 
mlir::Value negStep = - firOpBuilder.create<mlir::arith::SubIOp>(loc, zero, loopStepVar); - mlir::Value incr = firOpBuilder.create<mlir::arith::SelectOp>( - loc, isDownwards, negStep, loopStepVar); - mlir::Value lb = firOpBuilder.create<mlir::arith::SelectOp>( - loc, isDownwards, loopUBVar, loopLBVar); - mlir::Value ub = firOpBuilder.create<mlir::arith::SelectOp>( - loc, isDownwards, loopLBVar, loopUBVar); + mlir::arith::SubIOp::create(firOpBuilder, loc, zero, loopStepVar); + mlir::Value incr = mlir::arith::SelectOp::create( + firOpBuilder, loc, isDownwards, negStep, loopStepVar); + mlir::Value lb = mlir::arith::SelectOp::create( + firOpBuilder, loc, isDownwards, loopUBVar, loopLBVar); + mlir::Value ub = mlir::arith::SelectOp::create( + firOpBuilder, loc, isDownwards, loopLBVar, loopUBVar); // Compute the trip count assuming lb <= ub. This guarantees that the result // is non-negative and we can use unsigned arithmetic. - mlir::Value span = firOpBuilder.create<mlir::arith::SubIOp>( - loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw); + mlir::Value span = mlir::arith::SubIOp::create( + firOpBuilder, loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw); mlir::Value tcMinusOne = - firOpBuilder.create<mlir::arith::DivUIOp>(loc, span, incr); - mlir::Value tcIfLooping = firOpBuilder.create<mlir::arith::AddIOp>( - loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw); + mlir::arith::DivUIOp::create(firOpBuilder, loc, span, incr); + mlir::Value tcIfLooping = + mlir::arith::AddIOp::create(firOpBuilder, loc, tcMinusOne, one, + ::mlir::arith::IntegerOverflowFlags::nuw); // Fall back to 0 if lb > ub - mlir::Value isZeroTC = firOpBuilder.create<mlir::arith::CmpIOp>( - loc, mlir::arith::CmpIPredicate::slt, ub, lb); - mlir::Value tripcount = firOpBuilder.create<mlir::arith::SelectOp>( - loc, isZeroTC, zero, tcIfLooping); + mlir::Value isZeroTC = mlir::arith::CmpIOp::create( + firOpBuilder, loc, mlir::arith::CmpIPredicate::slt, ub, lb); + mlir::Value tripcount = 
mlir::arith::SelectOp::create( + firOpBuilder, loc, isZeroTC, zero, tcIfLooping); tripcounts.push_back(tripcount); // Create the CLI handle. - auto newcli = firOpBuilder.create<mlir::omp::NewCliOp>(loc); + auto newcli = mlir::omp::NewCliOp::create(firOpBuilder, loc); mlir::Value cli = newcli.getResult(); clis.push_back(cli); @@ -2122,10 +2123,10 @@ static void genCanonicalLoopNest( "Expecting all block args to have been collected by now"); for (auto j : llvm::seq<size_t>(numLoops)) { mlir::Value natIterNum = fir::getBase(blockArgs[j]); - mlir::Value scaled = firOpBuilder.create<mlir::arith::MulIOp>( - loc, natIterNum, loopStepVars[j]); - mlir::Value userVal = firOpBuilder.create<mlir::arith::AddIOp>( - loc, loopLBVars[j], scaled); + mlir::Value scaled = mlir::arith::MulIOp::create( + firOpBuilder, loc, natIterNum, loopStepVars[j]); + mlir::Value userVal = mlir::arith::AddIOp::create( + firOpBuilder, loc, loopLBVars[j], scaled); mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint(); @@ -2198,9 +2199,9 @@ static void genTileOp(Fortran::lower::AbstractConverter &converter, gridGeneratees.reserve(numLoops); intratileGeneratees.reserve(numLoops); for ([[maybe_unused]] auto i : llvm::seq<int>(0, sizesClause.sizes.size())) { - auto gridCLI = firOpBuilder.create<mlir::omp::NewCliOp>(loc); + auto gridCLI = mlir::omp::NewCliOp::create(firOpBuilder, loc); gridGeneratees.push_back(gridCLI.getResult()); - auto intratileCLI = firOpBuilder.create<mlir::omp::NewCliOp>(loc); + auto intratileCLI = mlir::omp::NewCliOp::create(firOpBuilder, loc); intratileGeneratees.push_back(intratileCLI.getResult()); } @@ -2209,8 +2210,8 @@ static void genTileOp(Fortran::lower::AbstractConverter &converter, generatees.append(gridGeneratees); generatees.append(intratileGeneratees); - firOpBuilder.create<mlir::omp::TileOp>(loc, generatees, applyees, - sizesClause.sizes); + mlir::omp::TileOp::create(firOpBuilder, loc, generatees, applyees, + sizesClause.sizes); } static void 
genUnrollOp(Fortran::lower::AbstractConverter &converter, diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 70bb43a2..478ab15 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -749,6 +749,44 @@ struct VolatileCastOpConversion } }; +/// Lower `fir.assumed_size_extent` to constant -1 of index type. +struct AssumedSizeExtentOpConversion + : public fir::FIROpConversion<fir::AssumedSizeExtentOp> { + using FIROpConversion::FIROpConversion; + + llvm::LogicalResult + matchAndRewrite(fir::AssumedSizeExtentOp op, OpAdaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Location loc = op.getLoc(); + mlir::Type ity = lowerTy().indexType(); + auto cst = fir::genConstantIndex(loc, ity, rewriter, -1); + rewriter.replaceOp(op, cst.getResult()); + return mlir::success(); + } +}; + +/// Lower `fir.is_assumed_size_extent` to integer equality with -1. +struct IsAssumedSizeExtentOpConversion + : public fir::FIROpConversion<fir::IsAssumedSizeExtentOp> { + using FIROpConversion::FIROpConversion; + + llvm::LogicalResult + matchAndRewrite(fir::IsAssumedSizeExtentOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Location loc = op.getLoc(); + mlir::Value val = adaptor.getVal(); + mlir::Type valTy = val.getType(); + // Create constant -1 of the operand type. 
+ auto negOneAttr = rewriter.getIntegerAttr(valTy, -1); + auto negOne = + mlir::LLVM::ConstantOp::create(rewriter, loc, valTy, negOneAttr); + auto cmp = mlir::LLVM::ICmpOp::create( + rewriter, loc, mlir::LLVM::ICmpPredicate::eq, val, negOne); + rewriter.replaceOp(op, cmp.getResult()); + return mlir::success(); + } +}; + /// convert value of from-type to value of to-type struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> { using FIROpConversion::FIROpConversion; @@ -1113,7 +1151,7 @@ struct AllocMemOpConversion : public fir::FIROpConversion<fir::AllocMemOp> { mlir::Value size = genTypeSizeInBytes(loc, ity, rewriter, llvmObjectTy); if (auto scaleSize = fir::genAllocationScaleSize(loc, heap.getInType(), ity, rewriter)) - size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize); + size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize); for (mlir::Value opnd : adaptor.getOperands()) size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, integerCast(loc, rewriter, ity, opnd)); @@ -4360,6 +4398,7 @@ void fir::populateFIRToLLVMConversionPatterns( AllocaOpConversion, AllocMemOpConversion, BoxAddrOpConversion, BoxCharLenOpConversion, BoxDimsOpConversion, BoxEleSizeOpConversion, BoxIsAllocOpConversion, BoxIsArrayOpConversion, BoxIsPtrOpConversion, + AssumedSizeExtentOpConversion, IsAssumedSizeExtentOpConversion, BoxOffsetOpConversion, BoxProcHostOpConversion, BoxRankOpConversion, BoxTypeCodeOpConversion, BoxTypeDescOpConversion, CallOpConversion, CmpcOpConversion, VolatileCastOpConversion, ConvertOpConversion, diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp index 381b2a2..f74d635 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp @@ -242,10 +242,11 @@ struct TargetAllocMemOpConversion loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout()); if (auto scaleSize = fir::genAllocationScaleSize( loc, 
allocmemOp.getInType(), ity, rewriter)) - size = rewriter.create<mlir::LLVM::MulOp>(loc, ity, size, scaleSize); + size = mlir::LLVM::MulOp::create(rewriter, loc, ity, size, scaleSize); for (mlir::Value opnd : adaptor.getOperands().drop_front()) - size = rewriter.create<mlir::LLVM::MulOp>( - loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd)); + size = mlir::LLVM::MulOp::create( + rewriter, loc, ity, size, + integerCast(lowerTy(), loc, rewriter, ity, opnd)); auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); auto mallocTy = mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth); diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp index ac285b5..0776346 100644 --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -872,6 +872,14 @@ public: } } + // Count the number of arguments that have to stay in place at the end of + // the argument list. + unsigned trailingArgs = 0; + if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>) { + trailingArgs = + func.getNumWorkgroupAttributions() + func.getNumPrivateAttributions(); + } + // Convert return value(s) for (auto ty : funcTy.getResults()) llvm::TypeSwitch<mlir::Type>(ty) @@ -981,6 +989,16 @@ public: } } + // Add the argument at the end if the number of trailing arguments is 0, + // otherwise insert the argument at the appropriate index. + auto addOrInsertArgument = [&](mlir::Type ty, mlir::Location loc) { + unsigned inputIndex = func.front().getArguments().size() - trailingArgs; + auto newArg = trailingArgs == 0 + ? func.front().addArgument(ty, loc) + : func.front().insertArgument(inputIndex, ty, loc); + return newArg; + }; + if (!func.empty()) { // If the function has a body, then apply the fixups to the arguments and // return ops as required. These fixups are done in place. @@ -1117,8 +1135,7 @@ public: // original arguments. (Boxchar arguments.) 
auto newBufArg = func.front().insertArgument(fixup.index, fixupType, loc); - auto newLenArg = - func.front().addArgument(trailingTys[fixup.second], loc); + auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc); auto boxTy = oldArgTys[fixup.index - offset]; rewriter->setInsertionPointToStart(&func.front()); auto box = fir::EmboxCharOp::create(*rewriter, loc, boxTy, newBufArg, @@ -1133,8 +1150,7 @@ public: // appended after all the original arguments. auto newProcPointerArg = func.front().insertArgument(fixup.index, fixupType, loc); - auto newLenArg = - func.front().addArgument(trailingTys[fixup.second], loc); + auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc); auto tupleType = oldArgTys[fixup.index - offset]; rewriter->setInsertionPointToStart(&func.front()); fir::FirOpBuilder builder(*rewriter, getModule()); diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 1712af1..d0164f3 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -5143,6 +5143,34 @@ void fir::BoxTotalElementsOp::getCanonicalizationPatterns( } //===----------------------------------------------------------------------===// +// IsAssumedSizeExtentOp and AssumedSizeExtentOp +//===----------------------------------------------------------------------===// + +namespace { +struct FoldIsAssumedSizeExtentOnCtor + : public mlir::OpRewritePattern<fir::IsAssumedSizeExtentOp> { + using mlir::OpRewritePattern<fir::IsAssumedSizeExtentOp>::OpRewritePattern; + mlir::LogicalResult + matchAndRewrite(fir::IsAssumedSizeExtentOp op, + mlir::PatternRewriter &rewriter) const override { + if (llvm::isa_and_nonnull<fir::AssumedSizeExtentOp>( + op.getVal().getDefiningOp())) { + mlir::Type i1 = rewriter.getI1Type(); + rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>( + op, i1, rewriter.getIntegerAttr(i1, 1)); + return mlir::success(); + } + return mlir::failure(); + } +}; +} // namespace + +void 
fir::IsAssumedSizeExtentOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { + patterns.add<FoldIsAssumedSizeExtentOnCtor>(context); +} + +//===----------------------------------------------------------------------===// // LocalitySpecifierOp //===----------------------------------------------------------------------===// diff --git a/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp b/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp index 4840a99..0d135a9 100644 --- a/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp +++ b/flang/lib/Optimizer/OpenACC/Transforms/ACCRecipeBufferization.cpp @@ -39,13 +39,13 @@ public: static mlir::Operation *load(mlir::OpBuilder &builder, mlir::Location loc, mlir::Value value) { - return builder.create<fir::LoadOp>(loc, value); + return fir::LoadOp::create(builder, loc, value); } static mlir::Value placeInMemory(mlir::OpBuilder &builder, mlir::Location loc, mlir::Value value) { - auto alloca = builder.create<fir::AllocaOp>(loc, value.getType()); - builder.create<fir::StoreOp>(loc, value, alloca); + auto alloca = fir::AllocaOp::create(builder, loc, value.getType()); + fir::StoreOp::create(builder, loc, value, alloca); return alloca; } }; diff --git a/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp b/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp index 817434f..5793d46 100644 --- a/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp +++ b/flang/lib/Optimizer/OpenMP/AutomapToTargetData.cpp @@ -130,8 +130,8 @@ class AutomapToTargetDataPass builder.getBoolAttr(false)); clauses.mapVars.push_back(mapInfo); isa<fir::StoreOp>(memOp) - ? builder.create<omp::TargetEnterDataOp>(memOp.getLoc(), clauses) - : builder.create<omp::TargetExitDataOp>(memOp.getLoc(), clauses); + ? 
omp::TargetEnterDataOp::create(builder, memOp.getLoc(), clauses) + : omp::TargetExitDataOp::create(builder, memOp.getLoc(), clauses); }; for (fir::GlobalOp globalOp : automapGlobals) { diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp index 65a23be..1229018 100644 --- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp @@ -595,7 +595,7 @@ private: mlir::omp::TargetOperands &clauseOps, mlir::omp::LoopNestOperands &loopNestClauseOps, const LiveInShapeInfoMap &liveInShapeInfoMap) const { - auto targetOp = rewriter.create<mlir::omp::TargetOp>(loc, clauseOps); + auto targetOp = mlir::omp::TargetOp::create(rewriter, loc, clauseOps); auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp); mlir::Region &region = targetOp.getRegion(); @@ -672,7 +672,7 @@ private: // temporary. Fortran::utils::openmp::cloneOrMapRegionOutsiders(builder, targetOp); rewriter.setInsertionPoint( - rewriter.create<mlir::omp::TerminatorOp>(targetOp.getLoc())); + mlir::omp::TerminatorOp::create(rewriter, targetOp.getLoc())); return targetOp; } @@ -715,8 +715,8 @@ private: auto shapeShiftType = fir::ShapeShiftType::get( builder.getContext(), shapeShiftOperands.size() / 2); - return builder.create<fir::ShapeShiftOp>( - liveInArg.getLoc(), shapeShiftType, shapeShiftOperands); + return fir::ShapeShiftOp::create(builder, liveInArg.getLoc(), + shapeShiftType, shapeShiftOperands); } llvm::SmallVector<mlir::Value> shapeOperands; @@ -728,11 +728,11 @@ private: ++shapeIdx; } - return builder.create<fir::ShapeOp>(liveInArg.getLoc(), shapeOperands); + return fir::ShapeOp::create(builder, liveInArg.getLoc(), shapeOperands); }(); - return builder.create<hlfir::DeclareOp>(liveInArg.getLoc(), liveInArg, - liveInName, shape); + return hlfir::DeclareOp::create(builder, liveInArg.getLoc(), liveInArg, + liveInName, shape); } mlir::omp::TeamsOp 
genTeamsOp(mlir::ConversionPatternRewriter &rewriter, @@ -742,13 +742,13 @@ private: genReductions(rewriter, mapper, loop, teamsOps); mlir::Location loc = loop.getLoc(); - auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(loc, teamsOps); + auto teamsOp = mlir::omp::TeamsOp::create(rewriter, loc, teamsOps); Fortran::common::openmp::EntryBlockArgs teamsArgs; teamsArgs.reduction.vars = teamsOps.reductionVars; Fortran::common::openmp::genEntryBlock(rewriter, teamsArgs, teamsOp.getRegion()); - rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc)); + rewriter.setInsertionPoint(mlir::omp::TerminatorOp::create(rewriter, loc)); for (auto [loopVar, teamsArg] : llvm::zip_equal( loop.getReduceVars(), teamsOp.getRegion().getArguments())) { @@ -761,8 +761,8 @@ private: mlir::omp::DistributeOp genDistributeOp(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter) const { - auto distOp = rewriter.create<mlir::omp::DistributeOp>( - loc, /*clauses=*/mlir::omp::DistributeOperands{}); + auto distOp = mlir::omp::DistributeOp::create( + rewriter, loc, /*clauses=*/mlir::omp::DistributeOperands{}); rewriter.createBlock(&distOp.getRegion()); return distOp; diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 8a9b383..7b61539 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -282,14 +282,14 @@ fissionWorkdistribute(omp::WorkdistributeOp workdistribute) { &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {}); for (auto arg : teamsBlock->getArguments()) newTeamsBlock->addArgument(arg.getType(), arg.getLoc()); - auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc); - rewriter.create<omp::TerminatorOp>(loc); + auto newWorkdistribute = omp::WorkdistributeOp::create(rewriter, loc); + omp::TerminatorOp::create(rewriter, loc); rewriter.createBlock(&newWorkdistribute.getRegion(), 
newWorkdistribute.getRegion().begin(), {}, {}); auto *cloned = rewriter.clone(*parallelize); parallelize->replaceAllUsesWith(cloned); parallelize->erase(); - rewriter.create<omp::TerminatorOp>(loc); + omp::TerminatorOp::create(rewriter, loc); changed = true; } } @@ -298,10 +298,10 @@ fissionWorkdistribute(omp::WorkdistributeOp workdistribute) { /// Generate omp.parallel operation with an empty region. static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) { - auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc); + auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc); parallelOp.setComposite(composite); rewriter.createBlock(&parallelOp.getRegion()); - rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc)); + rewriter.setInsertionPoint(mlir::omp::TerminatorOp::create(rewriter, loc)); return; } @@ -309,7 +309,7 @@ static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) { static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) { mlir::omp::DistributeOperands distributeClauseOps; auto distributeOp = - rewriter.create<mlir::omp::DistributeOp>(loc, distributeClauseOps); + mlir::omp::DistributeOp::create(rewriter, loc, distributeClauseOps); distributeOp.setComposite(composite); auto distributeBlock = rewriter.createBlock(&distributeOp.getRegion()); rewriter.setInsertionPointToStart(distributeBlock); @@ -334,12 +334,12 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop, const mlir::omp::LoopNestOperands &clauseOps, bool composite) { - auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc()); + auto wsloopOp = mlir::omp::WsloopOp::create(rewriter, doLoop.getLoc()); wsloopOp.setComposite(composite); rewriter.createBlock(&wsloopOp.getRegion()); auto loopNestOp = - rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps); + mlir::omp::LoopNestOp::create(rewriter, doLoop.getLoc(), clauseOps); // Clone the loop's body inside the loop 
nest construct using the // mapped values. @@ -351,7 +351,7 @@ static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop, // Erase fir.result op of do loop and create yield op. if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) { rewriter.setInsertionPoint(terminatorOp); - rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc()); + mlir::omp::YieldOp::create(rewriter, doLoop->getLoc()); terminatorOp->erase(); } } @@ -494,15 +494,15 @@ static SmallVector<Value> convertFlatToMultiDim(OpBuilder &builder, // Convert flat index to multi-dimensional indices SmallVector<Value> indices(rank); Value temp = flatIdx; - auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1); + auto c1 = arith::ConstantIndexOp::create(builder, loc, 1); // Work backwards through dimensions (row-major order) for (int i = rank - 1; i >= 0; --i) { - Value zeroBasedIdx = builder.create<arith::RemSIOp>(loc, temp, extents[i]); + Value zeroBasedIdx = arith::RemSIOp::create(builder, loc, temp, extents[i]); // Convert to one-based index - indices[i] = builder.create<arith::AddIOp>(loc, zeroBasedIdx, c1); + indices[i] = arith::AddIOp::create(builder, loc, zeroBasedIdx, c1); if (i > 0) { - temp = builder.create<arith::DivSIOp>(loc, temp, extents[i]); + temp = arith::DivSIOp::create(builder, loc, temp, extents[i]); } } @@ -525,7 +525,7 @@ static Value CalculateTotalElements(OpBuilder &builder, Location loc, if (i == 0) { totalElems = extent; } else { - totalElems = builder.create<arith::MulIOp>(loc, totalElems, extent); + totalElems = arith::MulIOp::create(builder, loc, totalElems, extent); } } return totalElems; @@ -562,14 +562,14 @@ static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc, // Load destination array box (if it's a reference) Value arrayBox = destBox; if (isa<fir::ReferenceType>(destBox.getType())) - arrayBox = builder.create<fir::LoadOp>(loc, destBox); + arrayBox = fir::LoadOp::create(builder, loc, destBox); - auto scalarValue = 
builder.create<fir::BoxAddrOp>(loc, srcBox); - Value scalar = builder.create<fir::LoadOp>(loc, scalarValue); + auto scalarValue = fir::BoxAddrOp::create(builder, loc, srcBox); + Value scalar = fir::LoadOp::create(builder, loc, scalarValue); // Calculate total number of elements (flattened) - auto c0 = builder.create<arith::ConstantIndexOp>(loc, 0); - auto c1 = builder.create<arith::ConstantIndexOp>(loc, 1); + auto c0 = arith::ConstantIndexOp::create(builder, loc, 0); + auto c1 = arith::ConstantIndexOp::create(builder, loc, 1); Value totalElems = CalculateTotalElements(builder, loc, arrayBox); auto *workdistributeBlock = &workdistribute.getRegion().front(); @@ -587,7 +587,7 @@ static void replaceWithUnorderedDoLoop(OpBuilder &builder, Location loc, builder, loc, fir::ReferenceType::get(scalar.getType()), arrayBox, nullptr, nullptr, ValueRange{indices}, ValueRange{}); - builder.create<fir::StoreOp>(loc, scalar, elemPtr); + fir::StoreOp::create(builder, loc, scalar, elemPtr); } /// workdistributeRuntimeCallLower method finds the runtime calls @@ -749,14 +749,15 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp, auto deviceAddrVars = targetOp.getHasDeviceAddrVars(); auto devicePtrVars = targetOp.getIsDevicePtrVars(); // Create the target data op - auto targetDataOp = rewriter.create<omp::TargetDataOp>( - loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars); + auto targetDataOp = + omp::TargetDataOp::create(rewriter, loc, device, ifExpr, outerMapInfos, + deviceAddrVars, devicePtrVars); auto taregtDataBlock = rewriter.createBlock(&targetDataOp.getRegion()); - rewriter.create<mlir::omp::TerminatorOp>(loc); + mlir::omp::TerminatorOp::create(rewriter, loc); rewriter.setInsertionPointToStart(taregtDataBlock); // Create the inner target op - auto newTargetOp = rewriter.create<omp::TargetOp>( - targetOp.getLoc(), targetOp.getAllocateVars(), + auto newTargetOp = omp::TargetOp::create( + rewriter, targetOp.getLoc(), targetOp.getAllocateVars(), 
targetOp.getAllocatorVars(), targetOp.getBareAttr(), targetOp.getDependKindsAttr(), targetOp.getDependVars(), targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), @@ -821,19 +822,19 @@ static TempOmpVar allocateTempOmpVar(Location loc, Type ty, // Get the appropriate type for allocation if (isPtr(ty)) { Type intTy = rewriter.getI32Type(); - auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1); + auto one = LLVM::ConstantOp::create(rewriter, loc, intTy, 1); allocType = llvmPtrTy; - alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one); + alloc = LLVM::AllocaOp::create(rewriter, loc, llvmPtrTy, allocType, one); allocType = intTy; } else { allocType = ty; - alloc = rewriter.create<fir::AllocaOp>(loc, allocType); + alloc = fir::AllocaOp::create(rewriter, loc, allocType); } // Lambda to create mapinfo ops auto getMapInfo = [&](mlir::omp::ClauseMapFlags mappingFlags, const char *name) { - return rewriter.create<omp::MapInfoOp>( - loc, alloc.getType(), alloc, TypeAttr::get(allocType), + return omp::MapInfoOp::create( + rewriter, loc, alloc.getType(), alloc, TypeAttr::get(allocType), rewriter.getAttr<omp::ClauseMapFlagsAttr>(mappingFlags), rewriter.getAttr<omp::VariableCaptureKindAttr>( omp::VariableCaptureKind::ByRef), @@ -979,12 +980,12 @@ static void reloadCacheAndRecompute( // If the original value is a pointer or reference, load and convert if // necessary. 
if (isPtr(original.getType())) { - restored = rewriter.create<LLVM::LoadOp>(loc, llvmPtrTy, newArg); + restored = LLVM::LoadOp::create(rewriter, loc, llvmPtrTy, newArg); if (!isa<LLVM::LLVMPointerType>(original.getType())) restored = - rewriter.create<fir::ConvertOp>(loc, original.getType(), restored); + fir::ConvertOp::create(rewriter, loc, original.getType(), restored); } else { - restored = rewriter.create<fir::LoadOp>(loc, newArg); + restored = fir::LoadOp::create(rewriter, loc, newArg); } irMapping.map(original, restored); } @@ -1053,7 +1054,7 @@ static mlir::LLVM::ConstantOp genI32Constant(mlir::Location loc, mlir::RewriterBase &rewriter, int value) { mlir::Type i32Ty = rewriter.getI32Type(); mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value); - return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr); + return mlir::LLVM::ConstantOp::create(rewriter, loc, i32Ty, attr); } /// Given a box descriptor, extract the base address of the data it describes. @@ -1230,8 +1231,8 @@ static void genFortranAssignOmpReplacement(fir::FirOpBuilder &builder, genOmpGetMappedPtrIfPresent(builder, loc, destBase, device, module); Value srcPtr = genOmpGetMappedPtrIfPresent(builder, loc, srcBase, device, module); - Value zero = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(), - builder.getI64IntegerAttr(0)); + Value zero = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(), + builder.getI64IntegerAttr(0)); // Generate the call to omp_target_memcpy to perform the data copy on the // device. 
@@ -1348,23 +1349,24 @@ static LogicalResult moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, for (Operation *op : opsToReplace) { if (auto allocOp = dyn_cast<fir::AllocMemOp>(op)) { rewriter.setInsertionPoint(allocOp); - auto ompAllocmemOp = rewriter.create<omp::TargetAllocMemOp>( - allocOp.getLoc(), rewriter.getI64Type(), device, + auto ompAllocmemOp = omp::TargetAllocMemOp::create( + rewriter, allocOp.getLoc(), rewriter.getI64Type(), device, allocOp.getInTypeAttr(), allocOp.getUniqNameAttr(), allocOp.getBindcNameAttr(), allocOp.getTypeparams(), allocOp.getShape()); - auto firConvertOp = rewriter.create<fir::ConvertOp>( - allocOp.getLoc(), allocOp.getResult().getType(), - ompAllocmemOp.getResult()); + auto firConvertOp = fir::ConvertOp::create(rewriter, allocOp.getLoc(), + allocOp.getResult().getType(), + ompAllocmemOp.getResult()); rewriter.replaceOp(allocOp, firConvertOp.getResult()); } // Replace fir.freemem with omp.target_freemem. else if (auto freeOp = dyn_cast<fir::FreeMemOp>(op)) { rewriter.setInsertionPoint(freeOp); - auto firConvertOp = rewriter.create<fir::ConvertOp>( - freeOp.getLoc(), rewriter.getI64Type(), freeOp.getHeapref()); - rewriter.create<omp::TargetFreeMemOp>(freeOp.getLoc(), device, - firConvertOp.getResult()); + auto firConvertOp = + fir::ConvertOp::create(rewriter, freeOp.getLoc(), + rewriter.getI64Type(), freeOp.getHeapref()); + omp::TargetFreeMemOp::create(rewriter, freeOp.getLoc(), device, + firConvertOp.getResult()); rewriter.eraseOp(freeOp); } // fir.declare changes its type when hoisting it out of omp.target to @@ -1376,8 +1378,9 @@ static LogicalResult moveToHost(omp::TargetOp targetOp, RewriterBase &rewriter, dyn_cast<fir::ReferenceType>(clonedInType); Type clonedEleTy = clonedRefType.getElementType(); rewriter.setInsertionPoint(op); - Value loadedValue = rewriter.create<fir::LoadOp>( - clonedDeclareOp.getLoc(), clonedEleTy, clonedDeclareOp.getMemref()); + Value loadedValue = + fir::LoadOp::create(rewriter, 
clonedDeclareOp.getLoc(), clonedEleTy, + clonedDeclareOp.getMemref()); clonedDeclareOp.getResult().replaceAllUsesWith(loadedValue); } // Replace runtime calls with omp versions. @@ -1473,8 +1476,8 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands, auto *targetBlock = &targetOp.getRegion().front(); SmallVector<Value> preHostEvalVars{targetOp.getHostEvalVars()}; // update the hostEvalVars of preTargetOp - omp::TargetOp preTargetOp = rewriter.create<omp::TargetOp>( - targetOp.getLoc(), targetOp.getAllocateVars(), + omp::TargetOp preTargetOp = omp::TargetOp::create( + rewriter, targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), targetOp.getBareAttr(), targetOp.getDependKindsAttr(), targetOp.getDependVars(), targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), preHostEvalVars, @@ -1513,13 +1516,13 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands, // Create the store operation. if (isPtr(originalResult.getType())) { if (!isa<LLVM::LLVMPointerType>(toStore.getType())) - toStore = rewriter.create<fir::ConvertOp>(loc, llvmPtrTy, toStore); - rewriter.create<LLVM::StoreOp>(loc, toStore, newArg); + toStore = fir::ConvertOp::create(rewriter, loc, llvmPtrTy, toStore); + LLVM::StoreOp::create(rewriter, loc, toStore, newArg); } else { - rewriter.create<fir::StoreOp>(loc, toStore, newArg); + fir::StoreOp::create(rewriter, loc, toStore, newArg); } } - rewriter.create<omp::TerminatorOp>(loc); + omp::TerminatorOp::create(rewriter, loc); // Update hostEvalVars with the mapped values for the loop bounds if we have // a loopNestOp and we are not generating code for the target device. 
@@ -1563,8 +1566,8 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands, hostEvalVars.steps.end()); } // Create the isolated target op - omp::TargetOp isolatedTargetOp = rewriter.create<omp::TargetOp>( - targetOp.getLoc(), targetOp.getAllocateVars(), + omp::TargetOp isolatedTargetOp = omp::TargetOp::create( + rewriter, targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), targetOp.getBareAttr(), targetOp.getDependKindsAttr(), targetOp.getDependVars(), targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), @@ -1590,7 +1593,7 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands, // Clone the original operations. rewriter.clone(*splitBeforeOp, isolatedMapping); - rewriter.create<omp::TerminatorOp>(loc); + omp::TerminatorOp::create(rewriter, loc); // update the loop bounds in the isolatedTargetOp if we have host_eval vars // and we are not generating code for the target device. @@ -1643,8 +1646,8 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp, auto *targetBlock = &targetOp.getRegion().front(); SmallVector<Value> postHostEvalVars{targetOp.getHostEvalVars()}; // Create the post target op - omp::TargetOp postTargetOp = rewriter.create<omp::TargetOp>( - targetOp.getLoc(), targetOp.getAllocateVars(), + omp::TargetOp postTargetOp = omp::TargetOp::create( + rewriter, targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), targetOp.getBareAttr(), targetOp.getDependKindsAttr(), targetOp.getDependVars(), targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), postHostEvalVars, diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp index 566e88b..bd07d7f 100644 --- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp +++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp @@ -883,18 +883,16 @@ class MapInfoFinalizationPass if (explicitMappingPresent(op, targetDataOp)) return; - mlir::omp::MapInfoOp 
newDescParentMapOp = - builder.create<mlir::omp::MapInfoOp>( - op->getLoc(), op.getResult().getType(), op.getVarPtr(), - op.getVarTypeAttr(), - builder.getAttr<mlir::omp::ClauseMapFlagsAttr>( - mlir::omp::ClauseMapFlags::to | - mlir::omp::ClauseMapFlags::always), - op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, - mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{}, - /*bounds=*/mlir::SmallVector<mlir::Value>{}, - /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(), - /*partial_map=*/builder.getBoolAttr(false)); + mlir::omp::MapInfoOp newDescParentMapOp = mlir::omp::MapInfoOp::create( + builder, op->getLoc(), op.getResult().getType(), op.getVarPtr(), + op.getVarTypeAttr(), + builder.getAttr<mlir::omp::ClauseMapFlagsAttr>( + mlir::omp::ClauseMapFlags::to | mlir::omp::ClauseMapFlags::always), + op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, + mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{}, + /*bounds=*/mlir::SmallVector<mlir::Value>{}, + /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(), + /*partial_map=*/builder.getBoolAttr(false)); targetDataOp.getMapVarsMutable().append({newDescParentMapOp}); } @@ -946,14 +944,13 @@ class MapInfoFinalizationPass // need to see how well this alteration works. 
auto loadBaseAddr = builder.loadIfRef(op->getLoc(), baseAddr.getVarPtrPtr()); - mlir::omp::MapInfoOp newBaseAddrMapOp = - builder.create<mlir::omp::MapInfoOp>( - op->getLoc(), loadBaseAddr.getType(), loadBaseAddr, - baseAddr.getVarTypeAttr(), baseAddr.getMapTypeAttr(), - baseAddr.getMapCaptureTypeAttr(), mlir::Value{}, members, - membersAttr, baseAddr.getBounds(), - /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(), - /*partial_map=*/builder.getBoolAttr(false)); + mlir::omp::MapInfoOp newBaseAddrMapOp = mlir::omp::MapInfoOp::create( + builder, op->getLoc(), loadBaseAddr.getType(), loadBaseAddr, + baseAddr.getVarTypeAttr(), baseAddr.getMapTypeAttr(), + baseAddr.getMapCaptureTypeAttr(), mlir::Value{}, members, membersAttr, + baseAddr.getBounds(), + /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(), + /*partial_map=*/builder.getBoolAttr(false)); op.replaceAllUsesWith(newBaseAddrMapOp.getResult()); op->erase(); baseAddr.erase(); diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 6dae39b..103e736 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -426,6 +426,12 @@ void createMLIRToLLVMPassPipeline(mlir::PassManager &pm, // Add codegen pass pipeline. fir::createDefaultFIRCodeGenPassPipeline(pm, config, inputFilename); + + // Run a pass to prepare for translation of delayed privatization in the + // context of deferred target tasks. 
+ addPassConditionally(pm, disableFirToLlvmIr, [&]() { + return mlir::omp::createPrepareForOMPOffloadPrivatizationPass(); + }); } } // namespace fir diff --git a/flang/lib/Optimizer/Support/Utils.cpp b/flang/lib/Optimizer/Support/Utils.cpp index 92390e4a..2f33d89 100644 --- a/flang/lib/Optimizer/Support/Utils.cpp +++ b/flang/lib/Optimizer/Support/Utils.cpp @@ -66,7 +66,7 @@ fir::genConstantIndex(mlir::Location loc, mlir::Type ity, mlir::ConversionPatternRewriter &rewriter, std::int64_t offset) { auto cattr = rewriter.getI64IntegerAttr(offset); - return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr); + return mlir::LLVM::ConstantOp::create(rewriter, loc, ity, cattr); } mlir::Value @@ -125,9 +125,9 @@ mlir::Value fir::integerCast(const fir::LLVMTypeConverter &converter, return rewriter.createOrFold<mlir::LLVM::SExtOp>(loc, ty, val); } else { if (toSize < fromSize) - return rewriter.create<mlir::LLVM::TruncOp>(loc, ty, val); + return mlir::LLVM::TruncOp::create(rewriter, loc, ty, val); if (toSize > fromSize) - return rewriter.create<mlir::LLVM::SExtOp>(loc, ty, val); + return mlir::LLVM::SExtOp::create(rewriter, loc, ty, val); } return val; } diff --git a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp index ed9a2ae..5bf783d 100644 --- a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp +++ b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp @@ -832,8 +832,8 @@ static mlir::Type getEleTy(mlir::Type ty) { static bool isAssumedSize(llvm::SmallVectorImpl<mlir::Value> &extents) { if (extents.empty()) return false; - auto cstLen = fir::getIntIfConstant(extents.back()); - return cstLen.has_value() && *cstLen == -1; + return llvm::isa_and_nonnull<fir::AssumedSizeExtentOp>( + extents.back().getDefiningOp()); } // Extract extents from the ShapeOp/ShapeShiftOp into the result vector. 
diff --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp index 3d2db6a..caa9bdd 100644 --- a/flang/lib/Semantics/check-cuda.cpp +++ b/flang/lib/Semantics/check-cuda.cpp @@ -131,6 +131,9 @@ struct FindHostArray return (*this)(x.base()); } Result operator()(const Symbol &symbol) const { + if (symbol.IsFuncResult()) { + return nullptr; + } if (const auto *details{ symbol.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()}) { if (details->IsArray() && diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp index be10669..4141630 100644 --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -61,6 +61,124 @@ namespace Fortran::semantics { using namespace Fortran::semantics::omp; using namespace Fortran::parser::omp; +OmpStructureChecker::OmpStructureChecker(SemanticsContext &context) + : DirectiveStructureChecker(context, +#define GEN_FLANG_DIRECTIVE_CLAUSE_MAP +#include "llvm/Frontend/OpenMP/OMP.inc" + ) { + scopeStack_.push_back(&context.globalScope()); +} + +bool OmpStructureChecker::Enter(const parser::MainProgram &x) { + using StatementProgramStmt = parser::Statement<parser::ProgramStmt>; + if (auto &stmt{std::get<std::optional<StatementProgramStmt>>(x.t)}) { + scopeStack_.push_back(stmt->statement.v.symbol->scope()); + } else { + for (const Scope &scope : context_.globalScope().children()) { + // There can only be one main program. + if (scope.kind() == Scope::Kind::MainProgram) { + scopeStack_.push_back(&scope); + break; + } + } + } + return true; +} + +void OmpStructureChecker::Leave(const parser::MainProgram &x) { + scopeStack_.pop_back(); +} + +bool OmpStructureChecker::Enter(const parser::BlockData &x) { + // The BLOCK DATA name is optional, so we need to look for the + // corresponding scope in the global scope. 
+ auto &stmt{std::get<parser::Statement<parser::BlockDataStmt>>(x.t)}; + if (auto &name{stmt.statement.v}) { + scopeStack_.push_back(name->symbol->scope()); + } else { + for (const Scope &scope : context_.globalScope().children()) { + if (scope.kind() == Scope::Kind::BlockData) { + if (scope.symbol()->name().empty()) { + scopeStack_.push_back(&scope); + break; + } + } + } + } + return true; +} + +void OmpStructureChecker::Leave(const parser::BlockData &x) { + scopeStack_.pop_back(); +} + +bool OmpStructureChecker::Enter(const parser::Module &x) { + auto &stmt{std::get<parser::Statement<parser::ModuleStmt>>(x.t)}; + const Symbol *sym{stmt.statement.v.symbol}; + scopeStack_.push_back(sym->scope()); + return true; +} + +void OmpStructureChecker::Leave(const parser::Module &x) { + scopeStack_.pop_back(); +} + +bool OmpStructureChecker::Enter(const parser::Submodule &x) { + auto &stmt{std::get<parser::Statement<parser::SubmoduleStmt>>(x.t)}; + const Symbol *sym{std::get<parser::Name>(stmt.statement.t).symbol}; + scopeStack_.push_back(sym->scope()); + return true; +} + +void OmpStructureChecker::Leave(const parser::Submodule &x) { + scopeStack_.pop_back(); +} + +// Function/subroutine subprogram nodes don't appear in INTERFACEs, but +// the subprogram/end statements do. 
+bool OmpStructureChecker::Enter(const parser::SubroutineStmt &x) { + const Symbol *sym{std::get<parser::Name>(x.t).symbol}; + scopeStack_.push_back(sym->scope()); + return true; +} + +bool OmpStructureChecker::Enter(const parser::EndSubroutineStmt &x) { + scopeStack_.pop_back(); + return true; +} + +bool OmpStructureChecker::Enter(const parser::FunctionStmt &x) { + const Symbol *sym{std::get<parser::Name>(x.t).symbol}; + scopeStack_.push_back(sym->scope()); + return true; +} + +bool OmpStructureChecker::Enter(const parser::EndFunctionStmt &x) { + scopeStack_.pop_back(); + return true; +} + +bool OmpStructureChecker::Enter(const parser::BlockConstruct &x) { + auto &specPart{std::get<parser::BlockSpecificationPart>(x.t)}; + auto &execPart{std::get<parser::Block>(x.t)}; + if (auto &&source{parser::GetSource(specPart)}) { + scopeStack_.push_back(&context_.FindScope(*source)); + } else if (auto &&source{parser::GetSource(execPart)}) { + scopeStack_.push_back(&context_.FindScope(*source)); + } + return true; +} + +void OmpStructureChecker::Leave(const parser::BlockConstruct &x) { + auto &specPart{std::get<parser::BlockSpecificationPart>(x.t)}; + auto &execPart{std::get<parser::Block>(x.t)}; + // Pop the scope pushed by the matching Enter, which pushes only when a + // source location exists for the specification or execution part. + if (parser::GetSource(specPart) || parser::GetSource(execPart)) { + scopeStack_.pop_back(); + } +} + // Use when clause falls under 'struct OmpClause' in 'parse-tree.h'. #define CHECK_SIMPLE_CLAUSE(X, Y) \ void OmpStructureChecker::Enter(const parser::OmpClause::X &) { \ @@ -362,6 +480,36 @@ bool OmpStructureChecker::IsNestedInDirective(llvm::omp::Directive directive) { return false; } +bool OmpStructureChecker::InTargetRegion() { + if (IsNestedInDirective(llvm::omp::Directive::OMPD_target)) { + // Return true even for device_type(host). 
+ return true; + } + for (const Scope *scope : llvm::reverse(scopeStack_)) { + if (const auto *symbol{scope->symbol()}) { + if (symbol->test(Symbol::Flag::OmpDeclareTarget)) { + return true; + } + } + } + return false; +} + +bool OmpStructureChecker::HasRequires(llvm::omp::Clause req) { + const Scope &unit{GetProgramUnit(*scopeStack_.back())}; + return common::visit( + [&](const auto &details) { + if constexpr (std::is_convertible_v<decltype(details), + const WithOmpDeclarative &>) { + if (auto *reqs{details.ompRequires()}) { + return reqs->test(req); + } + } + return false; + }, + DEREF(unit.symbol()).details()); +} + void OmpStructureChecker::CheckVariableListItem( const SymbolSourceMap &symbols) { for (auto &[symbol, source] : symbols) { @@ -1562,40 +1710,95 @@ void OmpStructureChecker::Leave(const parser::OpenMPRequiresConstruct &) { dirContext_.pop_back(); } -void OmpStructureChecker::Enter(const parser::OpenMPDeclarativeAllocate &x) { - isPredefinedAllocator = true; - const auto &dir{std::get<parser::Verbatim>(x.t)}; - const auto &objectList{std::get<parser::OmpObjectList>(x.t)}; - PushContextAndClauseSets(dir.source, llvm::omp::Directive::OMPD_allocate); - SymbolSourceMap currSymbols; - GetSymbolsInObjectList(objectList, currSymbols); - for (auto &[symbol, source] : currSymbols) { - if (IsPointer(*symbol)) { - context_.Say(source, - "List item '%s' in ALLOCATE directive must not have POINTER " - "attribute"_err_en_US, - source.ToString()); +void OmpStructureChecker::CheckAllocateDirective(parser::CharBlock source, + const parser::OmpObjectList &objects, + const parser::OmpClauseList &clauses) { + const Scope &thisScope{context_.FindScope(source)}; + SymbolSourceMap symbols; + GetSymbolsInObjectList(objects, symbols); + + auto maybeHasPredefinedAllocator{[&](const parser::OmpClause *calloc) { + // Return "true" if the ALLOCATOR clause was provided with an argument + // that is either a predefined allocator, or a run-time value. + // Otherwise return "false". 
+ if (!calloc) { + return false; } - if (IsDummy(*symbol)) { + auto *allocator{std::get_if<parser::OmpClause::Allocator>(&calloc->u)}; + if (auto val{ToInt64(GetEvaluateExpr(DEREF(allocator).v))}) { + // Predefined allocators (defined in OpenMP 6.0 20.8.1): + // omp_null_allocator = 0, + // omp_default_mem_alloc = 1, + // omp_large_cap_mem_alloc = 2, + // omp_const_mem_alloc = 3, + // omp_high_bw_mem_alloc = 4, + // omp_low_lat_mem_alloc = 5, + // omp_cgroup_mem_alloc = 6, + // omp_pteam_mem_alloc = 7, + // omp_thread_mem_alloc = 8 + return *val >= 0 && *val <= 8; + } + return true; + }}; + + const auto *allocator{FindClause(llvm::omp::Clause::OMPC_allocator)}; + if (InTargetRegion()) { + bool hasDynAllocators{ + HasRequires(llvm::omp::Clause::OMPC_dynamic_allocators)}; + if (!allocator && !hasDynAllocators) { context_.Say(source, - "List item '%s' in ALLOCATE directive must not be a dummy " - "argument"_err_en_US, - source.ToString()); + "An ALLOCATE directive in a TARGET region must specify an ALLOCATOR clause or REQUIRES(DYNAMIC_ALLOCATORS) must be specified"_err_en_US); + } + } + + auto maybePredefined{maybeHasPredefinedAllocator(allocator)}; + + for (auto &[symbol, source] : symbols) { + if (!inExecutableAllocate_) { + if (symbol->owner() != thisScope) { + context_.Say(source, + "A list item on a declarative ALLOCATE must be declared in the same scope in which the directive appears"_err_en_US); + } + if (IsPointer(*symbol) || IsAllocatable(*symbol)) { + context_.Say(source, + "A list item in a declarative ALLOCATE cannot have the ALLOCATABLE or POINTER attribute"_err_en_US); + } } if (symbol->GetUltimate().has<AssocEntityDetails>()) { context_.Say(source, - "List item '%s' in ALLOCATE directive must not be an associate " - "name"_err_en_US, - source.ToString()); + "A list item in a declarative ALLOCATE cannot be an associate name"_err_en_US); + } + if (symbol->attrs().test(Attr::SAVE) || IsCommonBlock(*symbol)) { + if (!allocator) { + context_.Say(source, + 
"If a list item is a named common block or has SAVE attribute, an ALLOCATOR clause must be present with a predefined allocator"_err_en_US); + } else if (!maybePredefined) { + context_.Say(source, + "If a list item is a named common block or has SAVE attribute, only a predefined allocator may be used on the ALLOCATOR clause"_err_en_US); + } + } + if (FindCommonBlockContaining(*symbol)) { + context_.Say(source, + "A variable that is part of a common block may not be specified as a list item in an ALLOCATE directive, except implicitly via the named common block"_err_en_US); } } - CheckVarIsNotPartOfAnotherVar(dir.source, objectList); + CheckVarIsNotPartOfAnotherVar(source, objects); } -void OmpStructureChecker::Leave(const parser::OpenMPDeclarativeAllocate &x) { +void OmpStructureChecker::Enter(const parser::OpenMPDeclarativeAllocate &x) { const auto &dir{std::get<parser::Verbatim>(x.t)}; - const auto &objectList{std::get<parser::OmpObjectList>(x.t)}; - CheckPredefinedAllocatorRestriction(dir.source, objectList); + PushContextAndClauseSets(dir.source, llvm::omp::Directive::OMPD_allocate); +} + +void OmpStructureChecker::Leave(const parser::OpenMPDeclarativeAllocate &x) { + if (!inExecutableAllocate_) { + const auto &dir{std::get<parser::Verbatim>(x.t)}; + const auto &clauseList{std::get<parser::OmpClauseList>(x.t)}; + const auto &objectList{std::get<parser::OmpObjectList>(x.t)}; + + isPredefinedAllocator = true; + CheckAllocateDirective(dir.source, objectList, clauseList); + } dirContext_.pop_back(); } @@ -1951,6 +2154,7 @@ void OmpStructureChecker::CheckNameInAllocateStmt( } void OmpStructureChecker::Enter(const parser::OpenMPExecutableAllocate &x) { + inExecutableAllocate_ = true; const auto &dir{std::get<parser::Verbatim>(x.t)}; PushContextAndClauseSets(dir.source, llvm::omp::Directive::OMPD_allocate); @@ -1960,24 +2164,6 @@ void OmpStructureChecker::Enter(const parser::OpenMPExecutableAllocate &x) { "The executable form of the OpenMP ALLOCATE directive has been 
deprecated, please use ALLOCATORS instead"_warn_en_US); } - bool hasAllocator = false; - // TODO: Investigate whether searching the clause list can be done with - // parser::Unwrap instead of the following loop - const auto &clauseList{std::get<parser::OmpClauseList>(x.t)}; - for (const auto &clause : clauseList.v) { - if (std::get_if<parser::OmpClause::Allocator>(&clause.u)) { - hasAllocator = true; - } - } - - if (IsNestedInDirective(llvm::omp::Directive::OMPD_target) && !hasAllocator) { - // TODO: expand this check to exclude the case when a requires - // directive with the dynamic_allocators clause is present - // in the same compilation unit (OMP5.0 2.11.3). - context_.Say(x.source, - "ALLOCATE directives that appear in a TARGET region must specify an allocator clause"_err_en_US); - } - const auto &allocateStmt = std::get<parser::Statement<parser::AllocateStmt>>(x.t).statement; if (const auto &list{std::get<std::optional<parser::OmpObjectList>>(x.t)}) { @@ -1994,18 +2180,34 @@ void OmpStructureChecker::Enter(const parser::OpenMPExecutableAllocate &x) { } isPredefinedAllocator = true; - const auto &objectList{std::get<std::optional<parser::OmpObjectList>>(x.t)}; - if (objectList) { - CheckVarIsNotPartOfAnotherVar(dir.source, *objectList); - } } void OmpStructureChecker::Leave(const parser::OpenMPExecutableAllocate &x) { - const auto &dir{std::get<parser::Verbatim>(x.t)}; - const auto &objectList{std::get<std::optional<parser::OmpObjectList>>(x.t)}; - if (objectList) - CheckPredefinedAllocatorRestriction(dir.source, *objectList); + parser::OmpObjectList empty{std::list<parser::OmpObject>{}}; + auto &objects{[&]() -> const parser::OmpObjectList & { + if (auto &objects{std::get<std::optional<parser::OmpObjectList>>(x.t)}) { + return *objects; + } else { + return empty; + } + }()}; + auto &clauses{std::get<parser::OmpClauseList>(x.t)}; + CheckAllocateDirective( + std::get<parser::Verbatim>(x.t).source, objects, clauses); + + if (const auto &subDirs{ + 
std::get<std::optional<std::list<parser::OpenMPDeclarativeAllocate>>>( + x.t)}) { + for (const auto &dalloc : *subDirs) { + const auto &dir{std::get<parser::Verbatim>(dalloc.t)}; + const auto &clauses{std::get<parser::OmpClauseList>(dalloc.t)}; + const auto &objects{std::get<parser::OmpObjectList>(dalloc.t)}; + CheckAllocateDirective(dir.source, objects, clauses); + } + } + dirContext_.pop_back(); + inExecutableAllocate_ = false; } void OmpStructureChecker::Enter(const parser::OpenMPAllocatorsConstruct &x) { diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h index b3fd6c8..7426559 100644 --- a/flang/lib/Semantics/check-omp-structure.h +++ b/flang/lib/Semantics/check-omp-structure.h @@ -56,21 +56,32 @@ using SymbolSourceMap = std::multimap<const Symbol *, parser::CharBlock>; using DirectivesClauseTriple = std::multimap<llvm::omp::Directive, std::pair<llvm::omp::Directive, const OmpClauseSet>>; -class OmpStructureChecker - : public DirectiveStructureChecker<llvm::omp::Directive, llvm::omp::Clause, - parser::OmpClause, llvm::omp::Clause_enumSize> { +using OmpStructureCheckerBase = DirectiveStructureChecker<llvm::omp::Directive, + llvm::omp::Clause, parser::OmpClause, llvm::omp::Clause_enumSize>; + +class OmpStructureChecker : public OmpStructureCheckerBase { public: - using Base = DirectiveStructureChecker<llvm::omp::Directive, - llvm::omp::Clause, parser::OmpClause, llvm::omp::Clause_enumSize>; + using Base = OmpStructureCheckerBase; + + OmpStructureChecker(SemanticsContext &context); - OmpStructureChecker(SemanticsContext &context) - : DirectiveStructureChecker(context, -#define GEN_FLANG_DIRECTIVE_CLAUSE_MAP -#include "llvm/Frontend/OpenMP/OMP.inc" - ) { - } using llvmOmpClause = const llvm::omp::Clause; + bool Enter(const parser::MainProgram &); + void Leave(const parser::MainProgram &); + bool Enter(const parser::BlockData &); + void Leave(const parser::BlockData &); + bool Enter(const parser::Module &); + void 
Leave(const parser::Module &); + bool Enter(const parser::Submodule &); + void Leave(const parser::Submodule &); + bool Enter(const parser::SubroutineStmt &); + bool Enter(const parser::EndSubroutineStmt &); + bool Enter(const parser::FunctionStmt &); + bool Enter(const parser::EndFunctionStmt &); + bool Enter(const parser::BlockConstruct &); + void Leave(const parser::BlockConstruct &); + void Enter(const parser::OpenMPConstruct &); void Leave(const parser::OpenMPConstruct &); void Enter(const parser::OpenMPInteropConstruct &); @@ -177,10 +188,12 @@ private: const parser::CharBlock &, const OmpDirectiveSet &); bool IsCloselyNestedRegion(const OmpDirectiveSet &set); bool IsNestedInDirective(llvm::omp::Directive directive); + bool InTargetRegion(); void HasInvalidTeamsNesting( const llvm::omp::Directive &dir, const parser::CharBlock &source); void HasInvalidDistributeNesting(const parser::OpenMPLoopConstruct &x); void HasInvalidLoopBinding(const parser::OpenMPLoopConstruct &x); + bool HasRequires(llvm::omp::Clause req); // specific clause related void CheckAllowedMapTypes( parser::OmpMapType::Value, llvm::ArrayRef<parser::OmpMapType::Value>); @@ -250,6 +263,9 @@ private: bool CheckTargetBlockOnlyTeams(const parser::Block &); void CheckWorkshareBlockStmts(const parser::Block &, parser::CharBlock); void CheckWorkdistributeBlockStmts(const parser::Block &, parser::CharBlock); + void CheckAllocateDirective(parser::CharBlock source, + const parser::OmpObjectList &objects, + const parser::OmpClauseList &clauses); void CheckIteratorRange(const parser::OmpIteratorSpecifier &x); void CheckIteratorModifier(const parser::OmpIterator &x); @@ -367,12 +383,15 @@ private: }; int directiveNest_[LastType + 1] = {0}; + bool inExecutableAllocate_{false}; parser::CharBlock visitedAtomicSource_; SymbolSourceMap deferredNonVariables_; using LoopConstruct = std::variant<const parser::DoConstruct *, const parser::OpenMPLoopConstruct *>; std::vector<LoopConstruct> loopStack_; + // Scopes 
for scoping units. + std::vector<const Scope *> scopeStack_; }; /// Find a duplicate entry in the range, and return an iterator to it. diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp index 292e73b..cc55bb4 100644 --- a/flang/lib/Semantics/openmp-utils.cpp +++ b/flang/lib/Semantics/openmp-utils.cpp @@ -218,7 +218,7 @@ bool IsMapExitingType(parser::OmpMapType::Value type) { } } -std::optional<SomeExpr> GetEvaluateExpr(const parser::Expr &parserExpr) { +MaybeExpr GetEvaluateExpr(const parser::Expr &parserExpr) { const parser::TypedExpr &typedExpr{parserExpr.typedExpr}; // ForwardOwningPointer typedExpr // `- GenericExprWrapper ^.get() diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index c410bd4..196755e 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -3094,26 +3094,6 @@ void OmpAttributeVisitor::ResolveOmpDesignator( AddAllocateName(name); } } - if (ompFlag == Symbol::Flag::OmpDeclarativeAllocateDirective && - IsAllocatable(*symbol) && - !IsNestedInDirective(llvm::omp::Directive::OMPD_allocate)) { - context_.Say(designator.source, - "List items specified in the ALLOCATE directive must not have the ALLOCATABLE attribute unless the directive is associated with an ALLOCATE statement"_err_en_US); - } - bool checkScope{ompFlag == Symbol::Flag::OmpDeclarativeAllocateDirective}; - // In 5.1 the scope check only applies to declarative allocate. 
- if (version == 50 && !checkScope) { - checkScope = ompFlag == Symbol::Flag::OmpExecutableAllocateDirective; - } - if (checkScope) { - if (omp::GetScopingUnit(GetContext().scope) != - omp::GetScopingUnit(symbol->GetUltimate().owner())) { - context_.Say(designator.source, // 2.15.3 - "List items must be declared in the same scoping unit in which the %s directive appears"_err_en_US, - parser::ToUpperCaseLetters( - llvm::omp::getOpenMPDirectiveName(directive, version))); - } - } if (ompFlag == Symbol::Flag::OmpReduction) { // Using variables inside of a namelist in OpenMP reductions // is allowed by the standard, but is not allowed for diff --git a/flang/lib/Utils/OpenMP.cpp b/flang/lib/Utils/OpenMP.cpp index 15a42c3..c2036c4 100644 --- a/flang/lib/Utils/OpenMP.cpp +++ b/flang/lib/Utils/OpenMP.cpp @@ -112,7 +112,7 @@ mlir::Value mapTemporaryValue(fir::FirOpBuilder &firOpBuilder, mlir::Block *entryBlock = ®ion.getBlocks().front(); firOpBuilder.setInsertionPointToStart(entryBlock); auto loadOp = - firOpBuilder.create<fir::LoadOp>(clonedValArg.getLoc(), clonedValArg); + fir::LoadOp::create(firOpBuilder, clonedValArg.getLoc(), clonedValArg); return loadOp.getResult(); } |