diff options
Diffstat (limited to 'flang/lib/Lower/ConvertVariable.cpp')
-rw-r--r-- | flang/lib/Lower/ConvertVariable.cpp | 82 |
1 files changed, 9 insertions, 73 deletions
diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp index a4a8a69..fd66592 100644 --- a/flang/lib/Lower/ConvertVariable.cpp +++ b/flang/lib/Lower/ConvertVariable.cpp @@ -14,12 +14,12 @@ #include "flang/Lower/AbstractConverter.h" #include "flang/Lower/Allocatable.h" #include "flang/Lower/BoxAnalyzer.h" +#include "flang/Lower/CUDA.h" #include "flang/Lower/CallInterface.h" #include "flang/Lower/ConvertConstant.h" #include "flang/Lower/ConvertExpr.h" #include "flang/Lower/ConvertExprToHLFIR.h" #include "flang/Lower/ConvertProcedureDesignator.h" -#include "flang/Lower/Cuda.h" #include "flang/Lower/Mangler.h" #include "flang/Lower/PFTBuilder.h" #include "flang/Lower/StatementContext.h" @@ -814,81 +814,24 @@ initializeDeviceComponentAllocator(Fortran::lower::AbstractConverter &converter, baseTy = boxTy.getEleTy(); baseTy = fir::unwrapRefType(baseTy); - if (mlir::isa<fir::SequenceType>(baseTy) && - (fir::isAllocatableType(fir::getBase(exv).getType()) || - fir::isPointerType(fir::getBase(exv).getType()))) + if (fir::isAllocatableType(fir::getBase(exv).getType()) || + fir::isPointerType(fir::getBase(exv).getType())) return; // Allocator index need to be set after allocation. auto recTy = mlir::dyn_cast<fir::RecordType>(fir::unwrapSequenceType(baseTy)); assert(recTy && "expected fir::RecordType"); - llvm::SmallVector<mlir::Value> coordinates; Fortran::semantics::UltimateComponentIterator components{*derived}; for (const auto &sym : components) { if (Fortran::semantics::IsDeviceAllocatable(sym)) { - unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString()); - mlir::Type fieldTy; - llvm::SmallVector<mlir::Value> coordinates; - - if (fieldIdx != std::numeric_limits<unsigned>::max()) { - // Field found in the base record type. - auto fieldName = recTy.getTypeList()[fieldIdx].first; - fieldTy = recTy.getTypeList()[fieldIdx].second; - mlir::Value fieldIndex = fir::FieldIndexOp::create( - builder, loc, fir::FieldType::get(fieldTy.getContext()), - fieldName, recTy, - /*typeParams=*/mlir::ValueRange{}); - coordinates.push_back(fieldIndex); - } else { - // Field not found in base record type, search in potential - // record type components. - for (auto component : recTy.getTypeList()) { - if (auto childRecTy = - mlir::dyn_cast<fir::RecordType>(component.second)) { - fieldIdx = childRecTy.getFieldIndex(sym.name().ToString()); - if (fieldIdx != std::numeric_limits<unsigned>::max()) { - mlir::Value parentFieldIndex = fir::FieldIndexOp::create( - builder, loc, - fir::FieldType::get(childRecTy.getContext()), - component.first, recTy, - /*typeParams=*/mlir::ValueRange{}); - coordinates.push_back(parentFieldIndex); - auto fieldName = childRecTy.getTypeList()[fieldIdx].first; - fieldTy = childRecTy.getTypeList()[fieldIdx].second; - mlir::Value childFieldIndex = fir::FieldIndexOp::create( - builder, loc, fir::FieldType::get(fieldTy.getContext()), - fieldName, childRecTy, - /*typeParams=*/mlir::ValueRange{}); - coordinates.push_back(childFieldIndex); - break; - } - } - } - } - - if (coordinates.empty()) - TODO(loc, "device resident component in complex derived-type " - "hierarchy"); - + llvm::SmallVector<mlir::Value> coord; + mlir::Type fieldTy = + Fortran::lower::gatherDeviceComponentCoordinatesAndType( + builder, loc, sym, recTy, coord); mlir::Value base = fir::getBase(exv); - mlir::Value comp; - if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(base.getType()))) { - mlir::Value box = fir::LoadOp::create(builder, loc, base); - mlir::Value addr = fir::BoxAddrOp::create(builder, loc, box); - llvm::SmallVector<mlir::Value> lenParams; - assert(coordinates.size() == 1 && "expect one coordinate"); - auto field = mlir::dyn_cast<fir::FieldIndexOp>( - coordinates[0].getDefiningOp()); - comp = hlfir::DesignateOp::create( - builder, loc, builder.getRefType(fieldTy), addr, - /*component=*/field.getFieldName(), - /*componentShape=*/mlir::Value{}, - hlfir::DesignateOp::Subscripts{}); - } else { - comp = fir::CoordinateOp::create( - builder, loc, builder.getRefType(fieldTy), base, coordinates); - } + mlir::Value comp = fir::CoordinateOp::create( + builder, loc, builder.getRefType(fieldTy), base, coord); cuf::DataAttributeAttr dataAttr = Fortran::lower::translateSymbolCUFDataAttribute( builder.getContext(), sym); @@ -1950,13 +1893,6 @@ fir::FortranVariableFlagsAttr Fortran::lower::translateSymbolAttributes( return fir::FortranVariableFlagsAttr::get(mlirContext, flags); } -cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute( - mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym) { - std::optional<Fortran::common::CUDADataAttr> cudaAttr = - Fortran::semantics::GetCUDADataAttr(&sym.GetUltimate()); - return cuf::getDataAttribute(mlirContext, cudaAttr); -} - static bool isCapturedInInternalProcedure(Fortran::lower::AbstractConverter &converter, const Fortran::semantics::Symbol &sym) { |