diff options
Diffstat (limited to 'flang/lib/Optimizer')
-rw-r--r-- | flang/lib/Optimizer/Builder/CMakeLists.txt | 3 | ||||
-rw-r--r-- | flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 111 | ||||
-rw-r--r-- | flang/lib/Optimizer/Builder/Runtime/Coarray.cpp | 228 | ||||
-rw-r--r-- | flang/lib/Optimizer/Builder/Runtime/Main.cpp | 4 | ||||
-rw-r--r-- | flang/lib/Optimizer/Dialect/CMakeLists.txt | 1 | ||||
-rw-r--r-- | flang/lib/Optimizer/Dialect/FIRType.cpp | 7 | ||||
-rw-r--r-- | flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt | 24 | ||||
-rw-r--r-- | flang/lib/Optimizer/Dialect/MIF/MIFDialect.cpp | 24 | ||||
-rw-r--r-- | flang/lib/Optimizer/Dialect/MIF/MIFOps.cpp | 153 | ||||
-rw-r--r-- | flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt | 4 | ||||
-rw-r--r-- | flang/lib/Optimizer/Passes/Pipelines.cpp | 1 | ||||
-rw-r--r-- | flang/lib/Optimizer/Transforms/CMakeLists.txt | 3 | ||||
-rw-r--r-- | flang/lib/Optimizer/Transforms/MIFOpConversion.cpp | 464 |
13 files changed, 723 insertions, 304 deletions
diff --git a/flang/lib/Optimizer/Builder/CMakeLists.txt b/flang/lib/Optimizer/Builder/CMakeLists.txt index 404afd1..1f95259 100644 --- a/flang/lib/Optimizer/Builder/CMakeLists.txt +++ b/flang/lib/Optimizer/Builder/CMakeLists.txt @@ -16,7 +16,6 @@ add_flang_library(FIRBuilder Runtime/Allocatable.cpp Runtime/ArrayConstructor.cpp Runtime/Assign.cpp - Runtime/Coarray.cpp Runtime/Character.cpp Runtime/Command.cpp Runtime/CUDA/Descriptor.cpp @@ -42,6 +41,7 @@ add_flang_library(FIRBuilder CUFDialect FIRDialect HLFIRDialect + MIFDialect LINK_LIBS CUFAttrs @@ -52,6 +52,7 @@ add_flang_library(FIRBuilder FortranEvaluate FortranSupport HLFIRDialect + MIFDialect MLIR_DEPS ${dialect_libs} diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 0195178..29eedfb 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -25,7 +25,6 @@ #include "flang/Optimizer/Builder/Runtime/Allocatable.h" #include "flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h" #include "flang/Optimizer/Builder/Runtime/Character.h" -#include "flang/Optimizer/Builder/Runtime/Coarray.h" #include "flang/Optimizer/Builder/Runtime/Command.h" #include "flang/Optimizer/Builder/Runtime/Derived.h" #include "flang/Optimizer/Builder/Runtime/Exceptions.h" @@ -40,6 +39,7 @@ #include "flang/Optimizer/Builder/Todo.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIROpsSupport.h" +#include "flang/Optimizer/Dialect/MIF/MIFOps.h" #include "flang/Optimizer/Dialect/Support/FIRContext.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/Support/FatalError.h" @@ -412,28 +412,28 @@ static constexpr IntrinsicHandler handlers[]{ {"co_broadcast", &I::genCoBroadcast, {{{"a", asBox}, - {"source_image", asAddr}, + {"source_image", asValue}, {"stat", asAddr, handleDynamicOptional}, {"errmsg", asBox, handleDynamicOptional}}}, /*isElemental*/ false}, {"co_max", &I::genCoMax, {{{"a", asBox}, - {"result_image", asAddr, handleDynamicOptional}, + {"result_image", asValue, handleDynamicOptional}, {"stat", asAddr, handleDynamicOptional}, {"errmsg", asBox, handleDynamicOptional}}}, /*isElemental*/ false}, {"co_min", &I::genCoMin, {{{"a", asBox}, - {"result_image", asAddr, handleDynamicOptional}, + {"result_image", asValue, handleDynamicOptional}, {"stat", asAddr, handleDynamicOptional}, {"errmsg", asBox, handleDynamicOptional}}}, /*isElemental*/ false}, {"co_sum", &I::genCoSum, {{{"a", asBox}, - {"result_image", asAddr, handleDynamicOptional}, + {"result_image", asValue, handleDynamicOptional}, {"stat", asAddr, handleDynamicOptional}, {"errmsg", asBox, handleDynamicOptional}}}, /*isElemental*/ false}, @@ -829,7 +829,7 @@ static constexpr IntrinsicHandler handlers[]{ {"null", &I::genNull, {{{"mold", asInquired}}}, /*isElemental=*/false}, {"num_images", &I::genNumImages, - {{{"team", asAddr}, {"team_number", asAddr}}}, + {{{"team_number", asValue}, {"team", asBox}}}, /*isElemental*/ false}, {"pack", &I::genPack, @@ -3516,11 +3516,23 @@ static mlir::Value getAddrFromBox(fir::FirOpBuilder &builder, return addr; } +static void clocDeviceArgRewrite(fir::ExtendedValue arg) { + // Special case for device address in c_loc. + if (auto emboxOp = mlir::dyn_cast_or_null<fir::EmboxOp>( + fir::getBase(arg).getDefiningOp())) + if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>( + emboxOp.getMemref().getDefiningOp())) + if (declareOp.getDataAttr() && + declareOp.getDataAttr() == cuf::DataAttribute::Device) + emboxOp.getMemrefMutable().assign(declareOp.getMemref()); +} + static fir::ExtendedValue genCLocOrCFunLoc(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type resultType, llvm::ArrayRef<fir::ExtendedValue> args, bool isFunc = false, bool isDevLoc = false) { assert(args.size() == 1); + clocDeviceArgRewrite(args[0]); mlir::Value res = fir::AllocaOp::create(builder, loc, resultType); mlir::Value resAddr; if (isDevLoc) @@ -3795,79 +3807,40 @@ mlir::Value IntrinsicLibrary::genCmplx(mlir::Type resultType, void IntrinsicLibrary::genCoBroadcast(llvm::ArrayRef<fir::ExtendedValue> args) { converter->checkCoarrayEnabled(); assert(args.size() == 4); - mlir::Value sourceImage = fir::getBase(args[1]); - mlir::Value status = - isStaticallyAbsent(args[2]) - ? fir::AbsentOp::create(builder, loc, - builder.getRefType(builder.getI32Type())) - .getResult() - : fir::getBase(args[2]); - mlir::Value errmsg = - isStaticallyAbsent(args[3]) - ? fir::AbsentOp::create(builder, loc, PRIF_ERRMSG_TYPE).getResult() - : fir::getBase(args[3]); - fir::runtime::genCoBroadcast(builder, loc, fir::getBase(args[0]), sourceImage, - status, errmsg); + mif::CoBroadcastOp::create(builder, loc, fir::getBase(args[0]), + /*sourceImage*/ fir::getBase(args[1]), + /*status*/ fir::getBase(args[2]), + /*errmsg*/ fir::getBase(args[3])); } // CO_MAX void IntrinsicLibrary::genCoMax(llvm::ArrayRef<fir::ExtendedValue> args) { converter->checkCoarrayEnabled(); assert(args.size() == 4); - mlir::Value refNone = - fir::AbsentOp::create(builder, loc, - builder.getRefType(builder.getI32Type())) - .getResult(); - mlir::Value resultImage = - isStaticallyAbsent(args[1]) ? refNone : fir::getBase(args[1]); - mlir::Value status = - isStaticallyAbsent(args[2]) ? refNone : fir::getBase(args[2]); - mlir::Value errmsg = - isStaticallyAbsent(args[3]) - ? fir::AbsentOp::create(builder, loc, PRIF_ERRMSG_TYPE).getResult() - : fir::getBase(args[3]); - fir::runtime::genCoMax(builder, loc, fir::getBase(args[0]), resultImage, - status, errmsg); + mif::CoMaxOp::create(builder, loc, fir::getBase(args[0]), + /*resultImage*/ fir::getBase(args[1]), + /*status*/ fir::getBase(args[2]), + /*errmsg*/ fir::getBase(args[3])); } // CO_MIN void IntrinsicLibrary::genCoMin(llvm::ArrayRef<fir::ExtendedValue> args) { converter->checkCoarrayEnabled(); assert(args.size() == 4); - mlir::Value refNone = - fir::AbsentOp::create(builder, loc, - builder.getRefType(builder.getI32Type())) - .getResult(); - mlir::Value resultImage = - isStaticallyAbsent(args[1]) ? refNone : fir::getBase(args[1]); - mlir::Value status = - isStaticallyAbsent(args[2]) ? refNone : fir::getBase(args[2]); - mlir::Value errmsg = - isStaticallyAbsent(args[3]) - ? fir::AbsentOp::create(builder, loc, PRIF_ERRMSG_TYPE).getResult() - : fir::getBase(args[3]); - fir::runtime::genCoMin(builder, loc, fir::getBase(args[0]), resultImage, - status, errmsg); + mif::CoMinOp::create(builder, loc, fir::getBase(args[0]), + /*resultImage*/ fir::getBase(args[1]), + /*status*/ fir::getBase(args[2]), + /*errmsg*/ fir::getBase(args[3])); } // CO_SUM void IntrinsicLibrary::genCoSum(llvm::ArrayRef<fir::ExtendedValue> args) { converter->checkCoarrayEnabled(); assert(args.size() == 4); - mlir::Value absentInt = - fir::AbsentOp::create(builder, loc, - builder.getRefType(builder.getI32Type())) - .getResult(); - mlir::Value resultImage = - isStaticallyAbsent(args[1]) ? absentInt : fir::getBase(args[1]); - mlir::Value status = - isStaticallyAbsent(args[2]) ? absentInt : fir::getBase(args[2]); - mlir::Value errmsg = - isStaticallyAbsent(args[3]) - ? fir::AbsentOp::create(builder, loc, PRIF_ERRMSG_TYPE).getResult() - : fir::getBase(args[3]); - fir::runtime::genCoSum(builder, loc, fir::getBase(args[0]), resultImage, - status, errmsg); + mif::CoSumOp::create(builder, loc, fir::getBase(args[0]), + /*resultImage*/ fir::getBase(args[1]), + /*status*/ fir::getBase(args[2]), + /*errmsg*/ fir::getBase(args[3])); } // COMMAND_ARGUMENT_COUNT @@ -7579,9 +7552,9 @@ IntrinsicLibrary::genNumImages(mlir::Type resultType, assert(args.size() == 0 || args.size() == 1); if (args.size()) - return fir::runtime::getNumImagesWithTeam(builder, loc, - fir::getBase(args[0])); - return fir::runtime::getNumImages(builder, loc); + return mif::NumImagesOp::create(builder, loc, fir::getBase(args[0])) + .getResult(); + return mif::NumImagesOp::create(builder, loc).getResult(); } // CLOCK, CLOCK64, GLOBALTIMER @@ -8659,17 +8632,11 @@ IntrinsicLibrary::genThisImage(mlir::Type resultType, converter->checkCoarrayEnabled(); assert(args.size() >= 1 && args.size() <= 3); const bool coarrayIsAbsent = args.size() == 1; - mlir::Value team = - !isStaticallyAbsent(args, args.size() - 1) - ? fir::getBase(args[args.size() - 1]) - : builder - .create<fir::AbsentOp>(loc, - fir::BoxType::get(builder.getNoneType())) - .getResult(); + mlir::Value team = fir::getBase(args[args.size() - 1]); if (!coarrayIsAbsent) TODO(loc, "this_image with coarray argument."); - mlir::Value res = fir::runtime::getThisImage(builder, loc, team); + mlir::Value res = mif::ThisImageOp::create(builder, loc, team); return builder.createConvert(loc, resultType, res); } diff --git a/flang/lib/Optimizer/Builder/Runtime/Coarray.cpp b/flang/lib/Optimizer/Builder/Runtime/Coarray.cpp deleted file mode 100644 index 364e7b7..0000000 --- a/flang/lib/Optimizer/Builder/Runtime/Coarray.cpp +++ /dev/null @@ -1,228 +0,0 @@ -//===-- Coarray.cpp -- runtime API for coarray intrinsics -----------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "flang/Optimizer/Builder/Runtime/Coarray.h" -#include "flang/Optimizer/Builder/FIRBuilder.h" -#include "flang/Optimizer/Builder/Runtime/RTBuilder.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" - -using namespace Fortran::runtime; -using namespace Fortran::semantics; - -// Most PRIF functions take `errmsg` and `errmsg_alloc` as two optional -// arguments of intent (out). One is allocatable, the other is not. -// It is the responsibility of the compiler to ensure that the appropriate -// optional argument is passed, and at most one must be provided in a given -// call. -// Depending on the type of `errmsg`, this function will return the pair -// corresponding to (`errmsg`, `errmsg_alloc`). -static std::pair<mlir::Value, mlir::Value> -genErrmsgPRIF(fir::FirOpBuilder &builder, mlir::Location loc, - mlir::Value errmsg) { - bool isAllocatableErrmsg = fir::isAllocatableType(errmsg.getType()); - - mlir::Value absent = fir::AbsentOp::create(builder, loc, PRIF_ERRMSG_TYPE); - mlir::Value errMsg = isAllocatableErrmsg ? absent : errmsg; - mlir::Value errMsgAlloc = isAllocatableErrmsg ? errmsg : absent; - return {errMsg, errMsgAlloc}; -} - -/// Generate Call to runtime prif_init -mlir::Value fir::runtime::genInitCoarray(fir::FirOpBuilder &builder, - mlir::Location loc) { - mlir::Type i32Ty = builder.getI32Type(); - mlir::Value result = builder.createTemporary(loc, i32Ty); - mlir::FunctionType ftype = PRIF_FUNCTYPE(builder.getRefType(i32Ty)); - mlir::func::FuncOp funcOp = - builder.createFunction(loc, PRIFNAME_SUB("init"), ftype); - llvm::SmallVector<mlir::Value> args = - fir::runtime::createArguments(builder, loc, ftype, result); - fir::CallOp::create(builder, loc, funcOp, args); - return fir::LoadOp::create(builder, loc, result); -} - -/// Generate Call to runtime prif_num_images -mlir::Value fir::runtime::getNumImages(fir::FirOpBuilder &builder, - mlir::Location loc) { - mlir::Value result = builder.createTemporary(loc, builder.getI32Type()); - mlir::FunctionType ftype = - PRIF_FUNCTYPE(builder.getRefType(builder.getI32Type())); - mlir::func::FuncOp funcOp = - builder.createFunction(loc, PRIFNAME_SUB("num_images"), ftype); - llvm::SmallVector<mlir::Value> args = - fir::runtime::createArguments(builder, loc, ftype, result); - fir::CallOp::create(builder, loc, funcOp, args); - return fir::LoadOp::create(builder, loc, result); -} - -/// Generate Call to runtime prif_num_images_with_{team|team_number} -mlir::Value fir::runtime::getNumImagesWithTeam(fir::FirOpBuilder &builder, - mlir::Location loc, - mlir::Value team) { - bool isTeamNumber = fir::unwrapPassByRefType(team.getType()).isInteger(); - std::string numImagesName = isTeamNumber - ? PRIFNAME_SUB("num_images_with_team_number") - : PRIFNAME_SUB("num_images_with_team"); - - mlir::Value result = builder.createTemporary(loc, builder.getI32Type()); - mlir::Type refTy = builder.getRefType(builder.getI32Type()); - mlir::FunctionType ftype = - isTeamNumber - ? PRIF_FUNCTYPE(builder.getRefType(builder.getI64Type()), refTy) - : PRIF_FUNCTYPE(fir::BoxType::get(builder.getNoneType()), refTy); - mlir::func::FuncOp funcOp = builder.createFunction(loc, numImagesName, ftype); - - if (!isTeamNumber) - team = builder.createBox(loc, team); - llvm::SmallVector<mlir::Value> args = - fir::runtime::createArguments(builder, loc, ftype, team, result); - fir::CallOp::create(builder, loc, funcOp, args); - return fir::LoadOp::create(builder, loc, result); -} - -/// Generate Call to runtime prif_this_image_no_coarray -mlir::Value fir::runtime::getThisImage(fir::FirOpBuilder &builder, - mlir::Location loc, mlir::Value team) { - mlir::Type refTy = builder.getRefType(builder.getI32Type()); - mlir::Type boxTy = fir::BoxType::get(builder.getNoneType()); - mlir::FunctionType ftype = PRIF_FUNCTYPE(boxTy, refTy); - mlir::func::FuncOp funcOp = - builder.createFunction(loc, PRIFNAME_SUB("this_image_no_coarray"), ftype); - - mlir::Value result = builder.createTemporary(loc, builder.getI32Type()); - mlir::Value teamArg = - !team ? fir::AbsentOp::create(builder, loc, boxTy) : team; - llvm::SmallVector<mlir::Value> args = - fir::runtime::createArguments(builder, loc, ftype, teamArg, result); - fir::CallOp::create(builder, loc, funcOp, args); - return fir::LoadOp::create(builder, loc, result); -} - -/// Generate call to collective subroutines except co_reduce -/// A must be lowered as a box -void genCollectiveSubroutine(fir::FirOpBuilder &builder, mlir::Location loc, - mlir::Value A, mlir::Value rootImage, - mlir::Value stat, mlir::Value errmsg, - std::string coName) { - mlir::Type boxTy = fir::BoxType::get(builder.getNoneType()); - mlir::FunctionType ftype = - PRIF_FUNCTYPE(boxTy, builder.getRefType(builder.getI32Type()), - PRIF_STAT_TYPE, PRIF_ERRMSG_TYPE, PRIF_ERRMSG_TYPE); - mlir::func::FuncOp funcOp = builder.createFunction(loc, coName, ftype); - - auto [errmsgArg, errmsgAllocArg] = genErrmsgPRIF(builder, loc, errmsg); - llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( - builder, loc, ftype, A, rootImage, stat, errmsgArg, errmsgAllocArg); - fir::CallOp::create(builder, loc, funcOp, args); -} - -/// Generate call to runtime subroutine prif_co_broadcast -void fir::runtime::genCoBroadcast(fir::FirOpBuilder &builder, - mlir::Location loc, mlir::Value A, - mlir::Value sourceImage, mlir::Value stat, - mlir::Value errmsg) { - genCollectiveSubroutine(builder, loc, A, sourceImage, stat, errmsg, - PRIFNAME_SUB("co_broadcast")); -} - -/// Generate call to runtime subroutine prif_co_max or prif_co_max_character -void fir::runtime::genCoMax(fir::FirOpBuilder &builder, mlir::Location loc, - mlir::Value A, mlir::Value resultImage, - mlir::Value stat, mlir::Value errmsg) { - mlir::Type argTy = - fir::unwrapSequenceType(fir::unwrapPassByRefType(A.getType())); - if (mlir::isa<fir::CharacterType>(argTy)) - genCollectiveSubroutine(builder, loc, A, resultImage, stat, errmsg, - PRIFNAME_SUB("co_max_character")); - else - genCollectiveSubroutine(builder, loc, A, resultImage, stat, errmsg, - PRIFNAME_SUB("co_max")); -} - -/// Generate call to runtime subroutine prif_co_min or prif_co_min_character -void fir::runtime::genCoMin(fir::FirOpBuilder &builder, mlir::Location loc, - mlir::Value A, mlir::Value resultImage, - mlir::Value stat, mlir::Value errmsg) { - mlir::Type argTy = - fir::unwrapSequenceType(fir::unwrapPassByRefType(A.getType())); - if (mlir::isa<fir::CharacterType>(argTy)) - genCollectiveSubroutine(builder, loc, A, resultImage, stat, errmsg, - PRIFNAME_SUB("co_min_character")); - else - genCollectiveSubroutine(builder, loc, A, resultImage, stat, errmsg, - PRIFNAME_SUB("co_min")); -} - -/// Generate call to runtime subroutine prif_co_sum -void fir::runtime::genCoSum(fir::FirOpBuilder &builder, mlir::Location loc, - mlir::Value A, mlir::Value resultImage, - mlir::Value stat, mlir::Value errmsg) { - genCollectiveSubroutine(builder, loc, A, resultImage, stat, errmsg, - PRIFNAME_SUB("co_sum")); -} - -/// Generate call to runtime subroutine prif_sync_all -void fir::runtime::genSyncAllStatement(fir::FirOpBuilder &builder, - mlir::Location loc, mlir::Value stat, - mlir::Value errmsg) { - mlir::FunctionType ftype = - PRIF_FUNCTYPE(PRIF_STAT_TYPE, PRIF_ERRMSG_TYPE, PRIF_ERRMSG_TYPE); - mlir::func::FuncOp funcOp = - builder.createFunction(loc, PRIFNAME_SUB("sync_all"), ftype); - - auto [errmsgArg, errmsgAllocArg] = genErrmsgPRIF(builder, loc, errmsg); - llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( - builder, loc, ftype, stat, errmsgArg, errmsgAllocArg); - fir::CallOp::create(builder, loc, funcOp, args); -} - -/// Generate call to runtime subroutine prif_sync_memory -void fir::runtime::genSyncMemoryStatement(fir::FirOpBuilder &builder, - mlir::Location loc, mlir::Value stat, - mlir::Value errmsg) { - mlir::FunctionType ftype = - PRIF_FUNCTYPE(PRIF_STAT_TYPE, PRIF_ERRMSG_TYPE, PRIF_ERRMSG_TYPE); - mlir::func::FuncOp funcOp = - builder.createFunction(loc, PRIFNAME_SUB("sync_memory"), ftype); - - auto [errmsgArg, errmsgAllocArg] = genErrmsgPRIF(builder, loc, errmsg); - llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( - builder, loc, ftype, stat, errmsgArg, errmsgAllocArg); - fir::CallOp::create(builder, loc, funcOp, args); -} - -/// Generate call to runtime subroutine prif_sync_images -void fir::runtime::genSyncImagesStatement(fir::FirOpBuilder &builder, - mlir::Location loc, - mlir::Value imageSet, - mlir::Value stat, - mlir::Value errmsg) { - mlir::Type imgSetTy = fir::BoxType::get(fir::SequenceType::get( - {fir::SequenceType::getUnknownExtent()}, builder.getI32Type())); - mlir::FunctionType ftype = PRIF_FUNCTYPE(imgSetTy, PRIF_STAT_TYPE, - PRIF_ERRMSG_TYPE, PRIF_ERRMSG_TYPE); - mlir::func::FuncOp funcOp = - builder.createFunction(loc, PRIFNAME_SUB("sync_images"), ftype); - - // If imageSet is scalar, PRIF require to pass an array of size 1. - if (auto boxTy = mlir::dyn_cast<fir::BoxType>(imageSet.getType())) { - if (!mlir::isa<fir::SequenceType>(boxTy.getEleTy())) { - mlir::Value one = - builder.createIntegerConstant(loc, builder.getI32Type(), 1); - mlir::Value shape = fir::ShapeOp::create(builder, loc, one); - imageSet = fir::ReboxOp::create( - builder, loc, - fir::BoxType::get(fir::SequenceType::get({1}, builder.getI32Type())), - imageSet, shape, mlir::Value{}); - } - } - auto [errmsgArg, errmsgAllocArg] = genErrmsgPRIF(builder, loc, errmsg); - llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( - builder, loc, ftype, imageSet, stat, errmsgArg, errmsgAllocArg); - fir::CallOp::create(builder, loc, funcOp, args); -} diff --git a/flang/lib/Optimizer/Builder/Runtime/Main.cpp b/flang/lib/Optimizer/Builder/Runtime/Main.cpp index d303e0a..9ce5e17 100644 --- a/flang/lib/Optimizer/Builder/Runtime/Main.cpp +++ b/flang/lib/Optimizer/Builder/Runtime/Main.cpp @@ -10,11 +10,11 @@ #include "flang/Lower/EnvironmentDefault.h" #include "flang/Optimizer/Builder/BoxValue.h" #include "flang/Optimizer/Builder/FIRBuilder.h" -#include "flang/Optimizer/Builder/Runtime/Coarray.h" #include "flang/Optimizer/Builder/Runtime/EnvironmentDefaults.h" #include "flang/Optimizer/Builder/Runtime/RTBuilder.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Dialect/MIF/MIFOps.h" #include "flang/Runtime/CUDA/init.h" #include "flang/Runtime/main.h" #include "flang/Runtime/stop.h" @@ -71,7 +71,7 @@ void fir::runtime::genMain( fir::CallOp::create(builder, loc, initFn); } if (initCoarrayEnv) - fir::runtime::genInitCoarray(builder, loc); + mif::InitOp::create(builder, loc); fir::CallOp::create(builder, loc, qqMainFn); fir::CallOp::create(builder, loc, stopFn); diff --git a/flang/lib/Optimizer/Dialect/CMakeLists.txt b/flang/lib/Optimizer/Dialect/CMakeLists.txt index 4fd4d28..65d1f2c 100644 --- a/flang/lib/Optimizer/Dialect/CMakeLists.txt +++ b/flang/lib/Optimizer/Dialect/CMakeLists.txt @@ -1,6 +1,7 @@ add_subdirectory(Support) add_subdirectory(CUF) add_subdirectory(FIRCG) +add_subdirectory(MIF) add_flang_library(FIRDialect FIRAttr.cpp diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp index 48e1622..fe35b08 100644 --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -1427,6 +1427,13 @@ mlir::Type BaseBoxType::unwrapInnerType() const { return fir::unwrapInnerType(getEleTy()); } +mlir::Type BaseBoxType::getElementOrSequenceType() const { + mlir::Type eleTy = getEleTy(); + if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) + return seqTy; + return fir::unwrapRefType(eleTy); +} + static mlir::Type changeTypeShape(mlir::Type type, std::optional<fir::SequenceType::ShapeRef> newShape) { diff --git a/flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt b/flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt new file mode 100644 index 0000000..ed8463e --- /dev/null +++ b/flang/lib/Optimizer/Dialect/MIF/CMakeLists.txt @@ -0,0 +1,24 @@ +add_flang_library(MIFDialect + MIFDialect.cpp + MIFOps.cpp + + DEPENDS + MIFOpsIncGen + + LINK_LIBS + FIRDialect + FIRDialectSupport + FIRSupport + + LINK_COMPONENTS + AsmParser + AsmPrinter + Remarks + + MLIR_DEPS + MLIRIR + + MLIR_LIBS + MLIRIR + MLIRTargetLLVMIRExport +) diff --git a/flang/lib/Optimizer/Dialect/MIF/MIFDialect.cpp b/flang/lib/Optimizer/Dialect/MIF/MIFDialect.cpp new file mode 100644 index 0000000..edc723d --- /dev/null +++ b/flang/lib/Optimizer/Dialect/MIF/MIFDialect.cpp @@ -0,0 +1,24 @@ +//===- MIFDialect.cpp - MIF dialect implementation ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// C +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Dialect/MIF/MIFDialect.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/MIF/MIFOps.h" + +//===----------------------------------------------------------------------===// +/// Tablegen Definitions +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Dialect/MIF/MIFDialect.cpp.inc" + +void mif::MIFDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "flang/Optimizer/Dialect/MIF/MIFOps.cpp.inc" + >(); +} diff --git a/flang/lib/Optimizer/Dialect/MIF/MIFOps.cpp b/flang/lib/Optimizer/Dialect/MIF/MIFOps.cpp new file mode 100644 index 0000000..c6cc2e8 --- /dev/null +++ b/flang/lib/Optimizer/Dialect/MIF/MIFOps.cpp @@ -0,0 +1,153 @@ +//===-- MIFOps.cpp - MIF dialect ops implementation -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Dialect/MIF/MIFOps.h" +#include "flang/Optimizer/Builder/Todo.h" +#include "flang/Optimizer/Dialect/FIRAttr.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Dialect/MIF/MIFDialect.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/SmallVector.h" + +#define GET_OP_CLASSES +#include "flang/Optimizer/Dialect/MIF/MIFOps.cpp.inc" + +//===----------------------------------------------------------------------===// +// NumImagesOp +//===----------------------------------------------------------------------===// + +void mif::NumImagesOp::build(mlir::OpBuilder &builder, + mlir::OperationState &result, + mlir::Value teamArg) { + bool isTeamNumber = + teamArg && fir::unwrapPassByRefType(teamArg.getType()).isInteger(); + if (isTeamNumber) + build(builder, result, teamArg, /*team*/ mlir::Value{}); + else + build(builder, result, /*team_number*/ mlir::Value{}, teamArg); +} + +llvm::LogicalResult mif::NumImagesOp::verify() { + if (getTeam() && getTeamNumber()) + return emitOpError( + "team and team_number must not be provided at the same time"); + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// ThisImageOp +//===----------------------------------------------------------------------===// + +void mif::ThisImageOp::build(mlir::OpBuilder &builder, + mlir::OperationState &result, mlir::Value coarray, + mlir::Value team) { + build(builder, result, coarray, /*dim*/ mlir::Value{}, team); +} + +void mif::ThisImageOp::build(mlir::OpBuilder &builder, + mlir::OperationState &result, mlir::Value team) { + build(builder, result, /*coarray*/ mlir::Value{}, /*dim*/ mlir::Value{}, + team); +} + +llvm::LogicalResult mif::ThisImageOp::verify() { + if (getDim() && !getCoarray()) + return emitOpError( + "`dim` must be provied at the same time as the `coarray` argument."); + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// SyncImagesOp +//===----------------------------------------------------------------------===// + +llvm::LogicalResult mif::SyncImagesOp::verify() { + if (getImageSet()) { + mlir::Type t = getImageSet().getType(); + fir::BoxType boxTy = mlir::dyn_cast<fir::BoxType>(t); + if (auto seqTy = mlir::dyn_cast<fir::SequenceType>( + boxTy.getElementOrSequenceType())) { + if (seqTy.getDimension() != 0 && seqTy.getDimension() != 1) + return emitOpError( + "`image_set` must be a boxed integer expression of rank 1."); + if (!fir::isa_integer(seqTy.getElementType())) + return emitOpError("`image_set` must be a boxed array of integer."); + } else if (!fir::isa_integer(boxTy.getElementType())) + return emitOpError( + "`image_set` must be a boxed scalar integer expression."); + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// CoBroadcastOp +//===----------------------------------------------------------------------===// + +llvm::LogicalResult mif::CoBroadcastOp::verify() { + fir::BoxType boxTy = mlir::dyn_cast<fir::BoxType>(getA().getType()); + + if (fir::isPolymorphicType(boxTy)) + return emitOpError("`A` cannot be polymorphic."); + else if (auto recTy = + mlir::dyn_cast<fir::RecordType>(boxTy.getElementType())) { + for (auto component : recTy.getTypeList()) { + if (fir::isPolymorphicType(component.second)) + TODO(getLoc(), "`A` with polymorphic subobject component."); + } + } + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// CoMaxOp +//===----------------------------------------------------------------------===// + +llvm::LogicalResult mif::CoMaxOp::verify() { + fir::BoxType boxTy = mlir::dyn_cast<fir::BoxType>(getA().getType()); + mlir::Type elemTy = boxTy.getElementOrSequenceType(); + if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(elemTy)) + elemTy = seqTy.getElementType(); + + if (!fir::isa_real(elemTy) && !fir::isa_integer(elemTy) && + !fir::isa_char(elemTy)) + return emitOpError("`A` shall be of type integer, real or character."); + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// CoMinOp +//===----------------------------------------------------------------------===// + +llvm::LogicalResult mif::CoMinOp::verify() { + fir::BoxType boxTy = mlir::dyn_cast<fir::BoxType>(getA().getType()); + mlir::Type elemTy = boxTy.getElementOrSequenceType(); + if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(elemTy)) + elemTy = seqTy.getElementType(); + + if (!fir::isa_real(elemTy) && !fir::isa_integer(elemTy) && + !fir::isa_char(elemTy)) + return emitOpError("`A` shall be of type integer, real or character."); + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// CoSumOp +//===----------------------------------------------------------------------===// + +llvm::LogicalResult mif::CoSumOp::verify() { + fir::BoxType boxTy = mlir::dyn_cast<fir::BoxType>(getA().getType()); + mlir::Type elemTy = boxTy.getElementOrSequenceType(); + if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(elemTy)) + elemTy = seqTy.getElementType(); + + if (!fir::isa_real(elemTy) && !fir::isa_integer(elemTy) && + !fir::isa_complex(elemTy)) + return emitOpError("`A` shall be of numeric type."); + return mlir::success(); +} diff --git a/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt index 2427da0..ed177ba 100644 --- a/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt @@ -5,8 +5,10 @@ add_flang_library(FIROpenACCTransforms FIROpenACCPassesIncGen LINK_LIBS + FIRDialect + + MLIR_LIBS MLIRIR MLIRPass - FIRDialect MLIROpenACCDialect ) diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 1ecb6d3..6dae39b 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -354,6 +354,7 @@ void createDebugPasses(mlir::PassManager &pm, void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm, MLIRToLLVMPassPipelineConfig config, llvm::StringRef inputFilename) { + pm.addPass(fir::createMIFOpConversion()); fir::addBoxedProcedurePass(pm); if (config.OptLevel.isOptimizingForSpeed() && config.AliasAnalysis && !disableFirAliasTags && !useOldAliasTags) diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt index 4ec1627..0388439 100644 --- a/flang/lib/Optimizer/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt @@ -36,6 +36,7 @@ add_flang_library(FIRTransforms SimplifyFIROperations.cpp OptimizeArrayRepacking.cpp ConvertComplexPow.cpp + MIFOpConversion.cpp DEPENDS CUFAttrs @@ -43,6 +44,7 @@ add_flang_library(FIRTransforms FIRDialect FIROptTransformsPassIncGen HLFIROpsIncGen + MIFDialect LINK_LIBS CUFAttrs @@ -56,6 +58,7 @@ add_flang_library(FIRTransforms FIRSupport FortranSupport HLFIRDialect + MIFDialect MLIR_LIBS MLIRAffineUtils diff --git a/flang/lib/Optimizer/Transforms/MIFOpConversion.cpp b/flang/lib/Optimizer/Transforms/MIFOpConversion.cpp new file mode 100644 index 0000000..206cb9b --- /dev/null +++ b/flang/lib/Optimizer/Transforms/MIFOpConversion.cpp @@ -0,0 +1,464 @@ +//===-- MIFOpConversion.cpp -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Transforms/MIFOpConversion.h" +#include "flang/Optimizer/Builder/Runtime/RTBuilder.h" +#include "flang/Optimizer/Builder/Todo.h" +#include "flang/Optimizer/CodeGen/TypeConverter.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/MIF/MIFOps.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "flang/Optimizer/Support/DataLayout.h" +#include "flang/Optimizer/Support/InternalNames.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace fir { +#define GEN_PASS_DEF_MIFOPCONVERSION +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +using namespace mlir; +using namespace Fortran::runtime; + +namespace { + +// Default prefix for subroutines of PRIF compiled with LLVM +static std::string getPRIFProcName(std::string fmt) { + std::ostringstream oss; + oss << "prif_" << fmt; + return fir::NameUniquer::doProcedure({"prif"}, {}, oss.str()); +} + +static mlir::Type getPRIFStatType(fir::FirOpBuilder &builder) { + return builder.getRefType(builder.getI32Type()); +} + +static mlir::Type getPRIFErrmsgType(fir::FirOpBuilder &builder) { + return fir::BoxType::get(fir::CharacterType::get( + builder.getContext(), 1, fir::CharacterType::unknownLen())); +} + +// Most PRIF functions take `errmsg` and `errmsg_alloc` as two optional +// arguments of intent (out). One is allocatable, the other is not. +// It is the responsibility of the compiler to ensure that the appropriate +// optional argument is passed, and at most one must be provided in a given +// call. +// Depending on the type of `errmsg`, this function will return the pair +// corresponding to (`errmsg`, `errmsg_alloc`). +static std::pair<mlir::Value, mlir::Value> +genErrmsgPRIF(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value errmsg) { + mlir::Value absent = + fir::AbsentOp::create(builder, loc, getPRIFErrmsgType(builder)); + if (!errmsg) + return {absent, absent}; + + bool isAllocatableErrmsg = fir::isAllocatableType(errmsg.getType()); + mlir::Value errMsg = isAllocatableErrmsg ? absent : errmsg; + mlir::Value errMsgAlloc = isAllocatableErrmsg ? errmsg : absent; + return {errMsg, errMsgAlloc}; +} + +/// Convert mif.init operation to runtime call of 'prif_init' +struct MIFInitOpConversion : public mlir::OpRewritePattern<mif::InitOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mif::InitOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + mlir::Type i32Ty = builder.getI32Type(); + mlir::Value result = builder.createTemporary(loc, i32Ty); + mlir::FunctionType ftype = mlir::FunctionType::get( + builder.getContext(), + /*inputs*/ {builder.getRefType(i32Ty)}, /*results*/ {}); + mlir::func::FuncOp funcOp = + builder.createFunction(loc, getPRIFProcName("init"), ftype); + llvm::SmallVector<mlir::Value> args = + fir::runtime::createArguments(builder, loc, ftype, result); + fir::CallOp::create(builder, loc, funcOp, args); + rewriter.replaceOpWithNewOp<fir::LoadOp>(op, result); + return mlir::success(); + } +}; + +/// Convert mif.this_image operation to PRIF runtime call +struct MIFThisImageOpConversion + : public mlir::OpRewritePattern<mif::ThisImageOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mif::ThisImageOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + if (op.getCoarray()) + TODO(loc, "mif.this_image op with coarray argument."); + else { + mlir::Type i32Ty = builder.getI32Type(); + mlir::Type boxTy = fir::BoxType::get(rewriter.getNoneType()); + mlir::Value result = builder.createTemporary(loc, i32Ty); + mlir::FunctionType ftype = mlir::FunctionType::get( + builder.getContext(), + /*inputs*/ {boxTy, builder.getRefType(i32Ty)}, /*results*/ {}); + mlir::Value teamArg = op.getTeam(); + if (!op.getTeam()) + teamArg = fir::AbsentOp::create(builder, loc, boxTy); + + mlir::func::FuncOp funcOp = builder.createFunction( + loc, getPRIFProcName("this_image_no_coarray"), ftype); + llvm::SmallVector<mlir::Value> args = + fir::runtime::createArguments(builder, loc, ftype, teamArg, result); + fir::CallOp::create(builder, loc, funcOp, args); + rewriter.replaceOpWithNewOp<fir::LoadOp>(op, result); + return mlir::success(); + } + } +}; + +/// Convert mif.num_images operation to runtime call of +/// prif_num_images_with_{team|team_number} +struct MIFNumImagesOpConversion + : public mlir::OpRewritePattern<mif::NumImagesOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mif::NumImagesOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + mlir::Type i32Ty = builder.getI32Type(); + mlir::Type i64Ty = builder.getI64Type(); + mlir::Type boxTy = fir::BoxType::get(rewriter.getNoneType()); + mlir::Value result = builder.createTemporary(loc, i32Ty); + + mlir::func::FuncOp funcOp; + llvm::SmallVector<mlir::Value> args; + if (!op.getTeam() && !op.getTeamNumber()) { + mlir::FunctionType ftype = mlir::FunctionType::get( + builder.getContext(), + /*inputs*/ {builder.getRefType(i32Ty)}, /*results*/ {}); + funcOp = + builder.createFunction(loc, getPRIFProcName("num_images"), ftype); + args = fir::runtime::createArguments(builder, loc, ftype, result); + } else { + if (op.getTeam()) { + mlir::FunctionType ftype = + mlir::FunctionType::get(builder.getContext(), + /*inputs*/ + {boxTy, builder.getRefType(i32Ty)}, + /*results*/ {}); + funcOp = builder.createFunction( + loc, getPRIFProcName("num_images_with_team"), ftype); + args = fir::runtime::createArguments(builder, loc, ftype, op.getTeam(), + result); + } else { + mlir::Value teamNumber = builder.createTemporary(loc, i64Ty); + mlir::Value cst = op.getTeamNumber(); + if (op.getTeamNumber().getType() != i64Ty) + cst = fir::ConvertOp::create(builder, loc, i64Ty, op.getTeamNumber()); + fir::StoreOp::create(builder, loc, cst, teamNumber); + mlir::FunctionType ftype = mlir::FunctionType::get( + builder.getContext(), + /*inputs*/ {builder.getRefType(i64Ty), builder.getRefType(i32Ty)}, + /*results*/ {}); + funcOp = builder.createFunction( + loc, getPRIFProcName("num_images_with_team_number"), ftype); + args = fir::runtime::createArguments(builder, loc, ftype, teamNumber, + result); + } + } + fir::CallOp::create(builder, loc, funcOp, args); + rewriter.replaceOpWithNewOp<fir::LoadOp>(op, result); + return mlir::success(); + } +}; + +/// Convert mif.sync_all operation to runtime call of 'prif_sync_all' +struct MIFSyncAllOpConversion : public mlir::OpRewritePattern<mif::SyncAllOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mif::SyncAllOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + mlir::Type errmsgTy = getPRIFErrmsgType(builder); + mlir::FunctionType ftype = mlir::FunctionType::get( + builder.getContext(), + /*inputs*/ {getPRIFStatType(builder), errmsgTy, errmsgTy}, + /*results*/ {}); + mlir::func::FuncOp funcOp = + builder.createFunction(loc, getPRIFProcName("sync_all"), ftype); + + auto [errmsgArg, errmsgAllocArg] = + genErrmsgPRIF(builder, loc, op.getErrmsg()); + mlir::Value stat = op.getStat(); + if (!stat) + stat = fir::AbsentOp::create(builder, loc, getPRIFStatType(builder)); + llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( + builder, loc, ftype, stat, errmsgArg, errmsgAllocArg); + rewriter.replaceOpWithNewOp<fir::CallOp>(op, funcOp, args); + return mlir::success(); + } +}; + +/// Convert mif.sync_images operation to runtime call of 'prif_sync_images' +struct MIFSyncImagesOpConversion + : public mlir::OpRewritePattern<mif::SyncImagesOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mif::SyncImagesOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + mlir::Type errmsgTy = getPRIFErrmsgType(builder); + mlir::Type imgSetTy = fir::BoxType::get(fir::SequenceType::get( + {fir::SequenceType::getUnknownExtent()}, builder.getI32Type())); + mlir::FunctionType ftype = mlir::FunctionType::get( + builder.getContext(), + /*inputs*/ + {imgSetTy, getPRIFStatType(builder), errmsgTy, errmsgTy}, + /*results*/ {}); + mlir::func::FuncOp funcOp = + builder.createFunction(loc, getPRIFProcName("sync_images"), ftype); + + // If imageSet is scalar, PRIF require to pass an array of size 1. + mlir::Value imageSet = op.getImageSet(); + if (!imageSet) + imageSet = fir::AbsentOp::create(builder, loc, imgSetTy); + else if (auto boxTy = mlir::dyn_cast<fir::BoxType>(imageSet.getType())) { + if (!mlir::isa<fir::SequenceType>(boxTy.getEleTy())) { + mlir::Value one = + builder.createIntegerConstant(loc, builder.getI32Type(), 1); + mlir::Value shape = fir::ShapeOp::create(builder, loc, one); + imageSet = + fir::ReboxOp::create(builder, loc, + fir::BoxType::get(fir::SequenceType::get( + {1}, builder.getI32Type())), + imageSet, shape, mlir::Value{}); + } + } + auto [errmsgArg, errmsgAllocArg] = + genErrmsgPRIF(builder, loc, op.getErrmsg()); + mlir::Value stat = op.getStat(); + if (!stat) + stat = fir::AbsentOp::create(builder, loc, getPRIFStatType(builder)); + llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( + builder, loc, ftype, imageSet, stat, errmsgArg, errmsgAllocArg); + rewriter.replaceOpWithNewOp<fir::CallOp>(op, funcOp, args); + return mlir::success(); + } +}; + +/// Convert mif.sync_memory operation to runtime call of 'prif_sync_memory' +struct MIFSyncMemoryOpConversion + : public mlir::OpRewritePattern<mif::SyncMemoryOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mif::SyncMemoryOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + mlir::Type errmsgTy = getPRIFErrmsgType(builder); + mlir::FunctionType ftype = mlir::FunctionType::get( + builder.getContext(), + /*inputs*/ {getPRIFStatType(builder), errmsgTy, errmsgTy}, + /*results*/ {}); + mlir::func::FuncOp funcOp = + builder.createFunction(loc, getPRIFProcName("sync_memory"), ftype); + + auto [errmsgArg, errmsgAllocArg] = + genErrmsgPRIF(builder, loc, op.getErrmsg()); + mlir::Value stat = op.getStat(); + if (!stat) + stat = fir::AbsentOp::create(builder, loc, getPRIFStatType(builder)); + llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( + builder, loc, ftype, stat, errmsgArg, errmsgAllocArg); + rewriter.replaceOpWithNewOp<fir::CallOp>(op, funcOp, args); + return mlir::success(); + } +}; + +/// Generate call to collective subroutines except co_reduce +/// A must be lowered as a box +static fir::CallOp genCollectiveSubroutine(fir::FirOpBuilder &builder, + mlir::Location loc, mlir::Value A, + mlir::Value image, mlir::Value stat, + mlir::Value errmsg, + std::string coName) { + mlir::Value rootImage; + mlir::Type i32Ty = builder.getI32Type(); + if (!image) + rootImage = fir::AbsentOp::create(builder, loc, builder.getRefType(i32Ty)); + else { + rootImage = builder.createTemporary(loc, i32Ty); + if (image.getType() != i32Ty) + image = fir::ConvertOp::create(builder, loc, i32Ty, image); + fir::StoreOp::create(builder, loc, image, rootImage); + } + + mlir::Type errmsgTy = getPRIFErrmsgType(builder); + mlir::Type boxTy = fir::BoxType::get(builder.getNoneType()); + mlir::FunctionType ftype = + mlir::FunctionType::get(builder.getContext(), + /*inputs*/ + {boxTy, builder.getRefType(builder.getI32Type()), + getPRIFStatType(builder), errmsgTy, errmsgTy}, + /*results*/ {}); + mlir::func::FuncOp funcOp = builder.createFunction(loc, coName, ftype); + + auto [errmsgArg, errmsgAllocArg] = genErrmsgPRIF(builder, loc, errmsg); + if (!stat) + stat = fir::AbsentOp::create(builder, loc, getPRIFStatType(builder)); + llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments( + builder, loc, ftype, A, rootImage, stat, errmsgArg, errmsgAllocArg); + return fir::CallOp::create(builder, loc, funcOp, args); +} + +/// Convert mif.co_broadcast operation to runtime call of 'prif_co_broadcast' +struct MIFCoBroadcastOpConversion + : public mlir::OpRewritePattern<mif::CoBroadcastOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mif::CoBroadcastOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + fir::CallOp callOp = genCollectiveSubroutine( + builder, loc, op.getA(), op.getSourceImage(), op.getStat(), + op.getErrmsg(), getPRIFProcName("co_broadcast")); + rewriter.replaceOp(op, callOp); + return mlir::success(); + } +}; + +/// Convert mif.co_max operation to runtime call of 'prif_co_max' +struct MIFCoMaxOpConversion : public mlir::OpRewritePattern<mif::CoMaxOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mif::CoMaxOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + fir::CallOp callOp; + mlir::Type argTy = + fir::unwrapSequenceType(fir::unwrapPassByRefType(op.getA().getType())); + if (mlir::isa<fir::CharacterType>(argTy)) + callOp = genCollectiveSubroutine( + builder, loc, op.getA(), op.getResultImage(), op.getStat(), + op.getErrmsg(), getPRIFProcName("co_max_character")); + else + callOp = genCollectiveSubroutine( + builder, loc, op.getA(), op.getResultImage(), op.getStat(), + op.getErrmsg(), getPRIFProcName("co_max")); + rewriter.replaceOp(op, callOp); + return mlir::success(); + } +}; + +/// Convert mif.co_min operation to runtime call of 'prif_co_min' +struct MIFCoMinOpConversion : public mlir::OpRewritePattern<mif::CoMinOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mif::CoMinOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + fir::CallOp callOp; + mlir::Type argTy = + fir::unwrapSequenceType(fir::unwrapPassByRefType(op.getA().getType())); + if (mlir::isa<fir::CharacterType>(argTy)) + callOp = genCollectiveSubroutine( + builder, loc, op.getA(), op.getResultImage(), op.getStat(), + op.getErrmsg(), getPRIFProcName("co_min_character")); + else + callOp = genCollectiveSubroutine( + builder, loc, op.getA(), op.getResultImage(), op.getStat(), + op.getErrmsg(), getPRIFProcName("co_min")); + rewriter.replaceOp(op, callOp); + return mlir::success(); + } +}; + +/// Convert mif.co_sum operation to runtime call of 'prif_co_sum' +struct MIFCoSumOpConversion : public mlir::OpRewritePattern<mif::CoSumOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(mif::CoSumOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->template getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + fir::CallOp callOp = genCollectiveSubroutine( + builder, loc, op.getA(), op.getResultImage(), op.getStat(), + op.getErrmsg(), getPRIFProcName("co_sum")); + rewriter.replaceOp(op, callOp); + return mlir::success(); + } +}; + +class MIFOpConversion : public fir::impl::MIFOpConversionBase<MIFOpConversion> { +public: + void runOnOperation() override { + auto *ctx = &getContext(); + mlir::RewritePatternSet patterns(ctx); + mlir::ConversionTarget target(*ctx); + + mif::populateMIFOpConversionPatterns(patterns); + + target.addLegalDialect<fir::FIROpsDialect>(); + target.addLegalOp<mlir::ModuleOp>(); + + if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + mlir::emitError(mlir::UnknownLoc::get(ctx), + "error in MIF op conversion\n"); + return signalPassFailure(); + } + } +}; +} // namespace + +void mif::populateMIFOpConversionPatterns(mlir::RewritePatternSet &patterns) { + patterns.insert<MIFInitOpConversion, MIFThisImageOpConversion, + MIFNumImagesOpConversion, MIFSyncAllOpConversion, + MIFSyncImagesOpConversion, MIFSyncMemoryOpConversion, + MIFCoBroadcastOpConversion, MIFCoMaxOpConversion, + MIFCoMinOpConversion, MIFCoSumOpConversion>( + patterns.getContext()); +} |