aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAkash Banerjee <Akash.Banerjee@amd.com>2025-07-29 00:23:19 +0100
committerAkash Banerjee <Akash.Banerjee@amd.com>2025-07-29 00:23:19 +0100
commit1c9de4524391b28ea14279ce09fdb90651d83425 (patch)
treebbf7c93bb4be29d4273256af57b9560559922fe3
parentc93083ac24fb9b7f65951c85c6e174a35da0914c (diff)
downloadllvm-users/Akash/implicit_default_mapper.zip
llvm-users/Akash/implicit_default_mapper.tar.gz
llvm-users/Akash/implicit_default_mapper.tar.bz2
User inner FortranType when checking for RecordType.users/Akash/implicit_default_mapper
-rw-r--r--flang/lib/Lower/OpenMP/OpenMP.cpp17
1 files changed, 8 insertions, 9 deletions
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 816e04d..5a90e24 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2510,15 +2510,13 @@ static mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
// Return a reference to the contents of a derived type with one field.
// Also return the field type.
- const auto getFieldRef =
- [&](mlir::Value rec, llvm::StringRef fieldName, mlir::Type fieldTy,
- mlir::Type recType) -> std::tuple<mlir::Value, mlir::Type> {
+ const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName,
+ mlir::Type fieldTy, mlir::Type recType) {
mlir::Value field = firOpBuilder.create<fir::FieldIndexOp>(
loc, fir::FieldType::get(recType.getContext()), fieldName, recType,
fir::getTypeParams(rec));
- return {firOpBuilder.create<fir::CoordinateOp>(
- loc, firOpBuilder.getRefType(fieldTy), rec, field),
- fieldTy};
+ return firOpBuilder.create<fir::CoordinateOp>(
+ loc, firOpBuilder.getRefType(fieldTy), rec, field);
};
mlir::omp::DeclareMapperInfoOperands clauseOps;
@@ -2536,10 +2534,9 @@ static mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
for (const auto &entry : llvm::enumerate(recordType.getTypeList())) {
const auto &memberName = entry.value().first;
const auto &memberType = entry.value().second;
- auto [ref, type] =
- getFieldRef(declareOp.getBase(), memberName, memberType, recordType);
mlir::FlatSymbolRefAttr mapperId;
- if (auto recType = mlir::dyn_cast<fir::RecordType>(memberType)) {
+ if (auto recType = mlir::dyn_cast<fir::RecordType>(
+ fir::getFortranElementType(memberType))) {
std::string mapperIdName =
recType.getName().str() + llvm::omp::OmpDefaultMapperName;
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
@@ -2551,6 +2548,8 @@ static mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
mapperIdName);
}
+ auto ref =
+ getFieldRef(declareOp.getBase(), memberName, memberType, recordType);
llvm::SmallVector<mlir::Value> bounds;
genBoundsOps(ref, bounds);
mlir::Value mapOp = createMapInfoOp(