aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--flang/include/flang/Utils/OpenMP.h9
-rw-r--r--flang/lib/Lower/OpenMP/ClauseProcessor.cpp12
-rw-r--r--flang/lib/Lower/OpenMP/OpenMP.cpp15
-rw-r--r--flang/lib/Lower/OpenMP/Utils.cpp107
-rw-r--r--flang/lib/Lower/OpenMP/Utils.h4
-rw-r--r--flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp37
-rw-r--r--flang/lib/Utils/OpenMP.cpp97
-rw-r--r--flang/test/Transforms/DoConcurrent/implicit_mapper.f9028
8 files changed, 194 insertions, 115 deletions
diff --git a/flang/include/flang/Utils/OpenMP.h b/flang/include/flang/Utils/OpenMP.h
index bad0abb6f578..e8627347fd57 100644
--- a/flang/include/flang/Utils/OpenMP.h
+++ b/flang/include/flang/Utils/OpenMP.h
@@ -13,6 +13,7 @@
namespace fir {
class FirOpBuilder;
+class RecordType;
} // namespace fir
namespace Fortran::utils::openmp {
@@ -59,6 +60,14 @@ mlir::Value mapTemporaryValue(fir::FirOpBuilder &firOpBuilder,
/// maps.
void cloneOrMapRegionOutsiders(
fir::FirOpBuilder &firOpBuilder, mlir::omp::TargetOp targetOp);
+
+using RecordMemberMapperMangler =
+ std::function<void(std::string &mapperId, llvm::StringRef memberName)>;
+
+mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
+ fir::FirOpBuilder &firOpBuilder, mlir::Location loc,
+ fir::RecordType recordType, llvm::StringRef mapperNameStr,
+ RecordMemberMapperMangler mangler = {});
} // namespace Fortran::utils::openmp
#endif // FORTRAN_UTILS_OPENMP_H_
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index b1973a3b8bf0..8edb7cd5d193 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1352,8 +1352,16 @@ void ClauseProcessor::processMapObjects(
if (!recordType)
return mlir::FlatSymbolRefAttr();
- return getOrGenImplicitDefaultDeclareMapper(converter, clauseLocation,
- recordType, mapperIdName);
+ return utils::openmp::getOrGenImplicitDefaultDeclareMapper(
+ converter.getFirOpBuilder(), clauseLocation, recordType, mapperIdName,
+ [&](std::string &mapperIdName, llvm::StringRef memberName) {
+ if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
+ mapperIdName = converter.mangleName(mapperIdName, sym->owner());
+ else if (auto *memberSym =
+ converter.getCurrentScope().FindSymbol(memberName.str()))
+ mapperIdName =
+ converter.mangleName(mapperIdName, memberSym->owner());
+ });
};
auto getDefaultMapperID =
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index df89cbe46a5c..17d466c6340d 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2792,7 +2792,20 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
if (auto recordType = mlir::dyn_cast_or_null<fir::RecordType>(
converter.genType(*typeSpec)))
mapperId = getOrGenImplicitDefaultDeclareMapper(
- converter, loc, recordType, mapperIdName);
+ converter.getFirOpBuilder(), loc, recordType,
+ mapperIdName,
+ [&](std::string &mapperIdName,
+ llvm::StringRef memberName) {
+ if (auto *sym = converter.getCurrentScope().FindSymbol(
+ mapperIdName))
+ mapperIdName =
+ converter.mangleName(mapperIdName, sym->owner());
+ else if (auto *memberSym =
+ converter.getCurrentScope().FindSymbol(
+ memberName.str()))
+ mapperIdName = converter.mangleName(
+ mapperIdName, memberSym->owner());
+ });
} else {
mapperId = mlir::FlatSymbolRefAttr::get(
&converter.getMLIRContext(), mapperIdName);
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index dce858085666..1bf5117168a7 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -67,113 +67,6 @@ llvm::cl::opt<bool> treatIndexAsSection(
namespace Fortran {
namespace lower {
namespace omp {
-
-mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
- lower::AbstractConverter &converter, mlir::Location loc,
- fir::RecordType recordType, llvm::StringRef mapperNameStr) {
- if (mapperNameStr.empty())
- return {};
-
- if (converter.getModuleOp().lookupSymbol(mapperNameStr))
- return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
- mapperNameStr);
-
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
-
- firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody());
- auto declMapperOp = mlir::omp::DeclareMapperOp::create(
- firOpBuilder, loc, mapperNameStr, recordType);
- auto &region = declMapperOp.getRegion();
- firOpBuilder.createBlock(&region);
- auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);
-
- auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg,
- /*uniq_name=*/"");
-
- const auto genBoundsOps = [&](mlir::Value mapVal,
- llvm::SmallVectorImpl<mlir::Value> &bounds) {
- fir::ExtendedValue extVal =
- hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder,
- hlfir::Entity{mapVal},
- /*contiguousHint=*/true)
- .first;
- fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
- firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
- bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
- mlir::omp::MapBoundsType>(
- firOpBuilder, info, extVal,
- /*dataExvIsAssumedSize=*/false, mapVal.getLoc());
- };
-
- const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName,
- mlir::Type fieldTy, mlir::Type recType) {
- mlir::Value field = fir::FieldIndexOp::create(
- firOpBuilder, loc, fir::FieldType::get(recType.getContext()), fieldName,
- recType, fir::getTypeParams(rec));
- return fir::CoordinateOp::create(
- firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field);
- };
-
- llvm::SmallVector<mlir::Value> clauseMapVars;
- llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
- llvm::SmallVector<mlir::Value> memberMapOps;
-
- mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to |
- mlir::omp::ClauseMapFlags::from |
- mlir::omp::ClauseMapFlags::implicit;
- mlir::omp::VariableCaptureKind captureKind =
- mlir::omp::VariableCaptureKind::ByRef;
-
- for (const auto &entry : llvm::enumerate(recordType.getTypeList())) {
- const auto &memberName = entry.value().first;
- const auto &memberType = entry.value().second;
- mlir::FlatSymbolRefAttr mapperId;
- if (auto recType = mlir::dyn_cast<fir::RecordType>(
- fir::getFortranElementType(memberType))) {
- std::string mapperIdName =
- recType.getName().str() + llvm::omp::OmpDefaultMapperName;
- if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
- mapperIdName = converter.mangleName(mapperIdName, sym->owner());
- else if (auto *memberSym =
- converter.getCurrentScope().FindSymbol(memberName))
- mapperIdName = converter.mangleName(mapperIdName, memberSym->owner());
-
- mapperId = getOrGenImplicitDefaultDeclareMapper(converter, loc, recType,
- mapperIdName);
- }
-
- auto ref =
- getFieldRef(declareOp.getBase(), memberName, memberType, recordType);
- llvm::SmallVector<mlir::Value> bounds;
- genBoundsOps(ref, bounds);
- mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp(
- firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"",
- bounds,
- /*members=*/{},
- /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(),
- /*partialMap=*/false, mapperId);
- memberMapOps.emplace_back(mapOp);
- memberPlacementIndices.emplace_back(
- llvm::SmallVector<int64_t>{(int64_t)entry.index()});
- }
-
- llvm::SmallVector<mlir::Value> bounds;
- genBoundsOps(declareOp.getOriginalBase(), bounds);
- mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit;
- mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp(
- firOpBuilder, loc, declareOp.getOriginalBase(),
- /*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
- firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag,
- captureKind, declareOp.getType(0),
- /*partialMap=*/true);
-
- clauseMapVars.emplace_back(mapOp);
- mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars);
- return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
- mapperNameStr);
-}
-
bool requiresImplicitDefaultDeclareMapper(
const semantics::DerivedTypeSpec &typeSpec) {
// ISO C interoperable types (e.g., c_ptr, c_funptr) must always have implicit
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 8a68ff8bd3bd..0b651572e087 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -137,10 +137,6 @@ mlir::Value createParentSymAndGenIntermediateMaps(
OmpMapParentAndMemberData &parentMemberIndices, llvm::StringRef asFortran,
mlir::omp::ClauseMapFlags mapTypeBits);
-mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
- Fortran::lower::AbstractConverter &converter, mlir::Location loc,
- fir::RecordType recordType, llvm::StringRef mapperNameStr);
-
bool requiresImplicitDefaultDeclareMapper(
const semantics::DerivedTypeSpec &typeSpec);
diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
index ff346e79276c..ce3659660290 100644
--- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
@@ -22,6 +22,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
namespace flangomp {
#define GEN_PASS_DEF_DOCONCURRENTCONVERSIONPASS
@@ -583,12 +584,46 @@ private:
llvm::SmallVector<mlir::Value> boundsOps;
genBoundsOps(builder, liveIn, rawAddr, boundsOps);
+ auto asRecordType = [&](mlir::Type eleType) {
+ fir::RecordType recordType = mlir::dyn_cast<fir::RecordType>(eleType);
+
+ if (auto seqType = mlir::dyn_cast<fir::SequenceType>(eleType))
+ recordType = mlir::dyn_cast<fir::RecordType>(seqType.getElementType());
+
+ return recordType;
+ };
+
+ fir::RecordType recordType = asRecordType(eleType);
+
+ bool requiresImplcitMapper = [&]() {
+ if (!recordType)
+ return false;
+
+ for (auto [fieldName, fieldType] : recordType.getTypeList()) {
+ if (fir::isAllocatableType(fieldType))
+ return true;
+
+ if (asRecordType(fieldType))
+ TODO(liveIn.getLoc(), "Nested record types are not supported yet.");
+ }
+
+ return false;
+ }();
+
+ mlir::FlatSymbolRefAttr mapperId;
+ if (requiresImplcitMapper) {
+ std::string mapperIdName =
+ recordType.getName().str() + llvm::omp::OmpDefaultMapperName;
+ mapperId = Fortran::utils::openmp::getOrGenImplicitDefaultDeclareMapper(
+ builder, liveIn.getLoc(), recordType, mapperIdName);
+ }
+
return Fortran::utils::openmp::createMapInfoOp(
builder, liveIn.getLoc(), rawAddr,
/*varPtrPtr=*/{}, name.str(), boundsOps,
/*members=*/{},
/*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind,
- rawAddr.getType());
+ rawAddr.getType(), /*partialMap=*/false, mapperId);
}
mlir::omp::TargetOp
diff --git a/flang/lib/Utils/OpenMP.cpp b/flang/lib/Utils/OpenMP.cpp
index c2036c4a383f..a0a67249f7f0 100644
--- a/flang/lib/Utils/OpenMP.cpp
+++ b/flang/lib/Utils/OpenMP.cpp
@@ -155,4 +155,101 @@ void cloneOrMapRegionOutsiders(
mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
}
}
+
+mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
+ fir::FirOpBuilder &firOpBuilder, mlir::Location loc,
+ fir::RecordType recordType, llvm::StringRef mapperNameStr,
+ RecordMemberMapperMangler mangler) {
+ if (mapperNameStr.empty())
+ return {};
+
+ mlir::ModuleOp moduleOp = firOpBuilder.getModule();
+ if (moduleOp.lookupSymbol(mapperNameStr))
+ return mlir::FlatSymbolRefAttr::get(
+ firOpBuilder.getContext(), mapperNameStr);
+
+ mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+
+ firOpBuilder.setInsertionPointToStart(moduleOp.getBody());
+ auto declMapperOp = mlir::omp::DeclareMapperOp::create(
+ firOpBuilder, loc, mapperNameStr, recordType);
+ auto &region = declMapperOp.getRegion();
+ firOpBuilder.createBlock(&region);
+ auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);
+
+ auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg,
+ /*uniq_name=*/"");
+
+ const auto genBoundsOps = [&](mlir::Value mapVal,
+ llvm::SmallVectorImpl<mlir::Value> &bounds) {
+ fir::ExtendedValue extVal = hlfir::translateToExtendedValue(mapVal.getLoc(),
+ firOpBuilder, hlfir::Entity{mapVal},
+ /*contiguousHint=*/true)
+ .first;
+ fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
+ firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
+ bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+ mlir::omp::MapBoundsType>(firOpBuilder, info, extVal,
+ /*dataExvIsAssumedSize=*/false, mapVal.getLoc());
+ };
+
+ const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName,
+ mlir::Type fieldTy, mlir::Type recType) {
+ mlir::Value field = fir::FieldIndexOp::create(firOpBuilder, loc,
+ fir::FieldType::get(recType.getContext()), fieldName, recType,
+ fir::getTypeParams(rec));
+ return fir::CoordinateOp::create(
+ firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field);
+ };
+
+ llvm::SmallVector<mlir::Value> clauseMapVars;
+ llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
+ llvm::SmallVector<mlir::Value> memberMapOps;
+
+ mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to |
+ mlir::omp::ClauseMapFlags::from | mlir::omp::ClauseMapFlags::implicit;
+ mlir::omp::VariableCaptureKind captureKind =
+ mlir::omp::VariableCaptureKind::ByRef;
+
+ for (const auto &entry : llvm::enumerate(recordType.getTypeList())) {
+ const auto &memberName = entry.value().first;
+ const auto &memberType = entry.value().second;
+ mlir::FlatSymbolRefAttr mapperId;
+ if (auto recType = mlir::dyn_cast<fir::RecordType>(
+ fir::getFortranElementType(memberType))) {
+ std::string mapperIdName =
+ recType.getName().str() + llvm::omp::OmpDefaultMapperName;
+ mangler(mapperIdName, memberName);
+ mapperId = getOrGenImplicitDefaultDeclareMapper(
+ firOpBuilder, loc, recType, mapperIdName, mangler);
+ }
+
+ auto ref =
+ getFieldRef(declareOp.getBase(), memberName, memberType, recordType);
+ llvm::SmallVector<mlir::Value> bounds;
+ genBoundsOps(ref, bounds);
+ mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp(firOpBuilder,
+ loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"", bounds,
+ /*members=*/{},
+ /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(),
+ /*partialMap=*/false, mapperId);
+ memberMapOps.emplace_back(mapOp);
+ memberPlacementIndices.emplace_back(
+ llvm::SmallVector<int64_t>{(int64_t)entry.index()});
+ }
+
+ llvm::SmallVector<mlir::Value> bounds;
+ genBoundsOps(declareOp.getOriginalBase(), bounds);
+ mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit;
+ mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp(
+ firOpBuilder, loc, declareOp.getOriginalBase(),
+ /*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
+ firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag,
+ captureKind, declareOp.getType(0),
+ /*partialMap=*/true);
+
+ clauseMapVars.emplace_back(mapOp);
+ mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars);
+ return mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), mapperNameStr);
+}
} // namespace Fortran::utils::openmp
diff --git a/flang/test/Transforms/DoConcurrent/implicit_mapper.f90 b/flang/test/Transforms/DoConcurrent/implicit_mapper.f90
new file mode 100644
index 000000000000..f77c0cd0320c
--- /dev/null
+++ b/flang/test/Transforms/DoConcurrent/implicit_mapper.f90
@@ -0,0 +1,28 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-to-openmp=device %s -o - \
+! RUN: | FileCheck %s
+
+module record_with_alloc_mod
+ implicit none
+ public :: record_with_alloc
+
+ type record_with_alloc
+ real, allocatable :: values_(:)
+ end type
+end module record_with_alloc_mod
+
+subroutine random_inputs()
+ use record_with_alloc_mod, only : record_with_alloc
+ implicit none
+ type(record_with_alloc) :: inputs(2)
+ integer :: i
+
+ do concurrent(i=1:10)
+ inputs(1)%values_ = [1,2,3,4]
+ end do
+end subroutine
+
+! CHECK: omp.declare_mapper @[[MAPPER_NAME:.*record_with_alloc_omp_default_mapper]] : !fir.type<{{.*}}record_with_alloc{{.*}}>
+
+! CHECK: func.func @{{.*}}random_inputs()
+! CHECK: %[[ARR_DECL:.*]]:2 = hlfir.declare {{.*}} {uniq_name = "{{.*}}inputs"}
+! CHECK: omp.map.info var_ptr(%[[ARR_DECL]]#1 : {{.*}}) {{.*}} mapper(@[[MAPPER_NAME]])