diff options
author | Valentin Clement (バレンタイン クレメン) <clementval@gmail.com> | 2024-04-05 09:11:37 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-05 09:11:37 -0700 |
commit | 953aa102a90099ae655eaa4645dd8d15c95ea86a (patch) | |
tree | aa5dfbac45899142387bbb0fda3307eb6b873888 /flang/lib/Lower/Bridge.cpp | |
parent | 5f9ed2ff8364ff3e4fac410472f421299dafa793 (diff) | |
download | llvm-953aa102a90099ae655eaa4645dd8d15c95ea86a.zip llvm-953aa102a90099ae655eaa4645dd8d15c95ea86a.tar.gz llvm-953aa102a90099ae655eaa4645dd8d15c95ea86a.tar.bz2 |
[flang][cuda] Lower device to host and device to device transfer (#87387)
Add more support for CUDA data transfer in assignment. This patch adds
device-to-device and device-to-host support. If device symbols are
present on the rhs, implicit data transfers are initiated: a
temporary is created and the data are transferred to the host. The
expression is then evaluated on the host and the assignment is done.
Diffstat (limited to 'flang/lib/Lower/Bridge.cpp')
-rw-r--r-- | flang/lib/Lower/Bridge.cpp | 99 |
1 files changed, 88 insertions, 11 deletions
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 5bba097..478c8f4 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -3710,16 +3710,18 @@ private: return false; } - static void genCUDADataTransfer(fir::FirOpBuilder &builder, - mlir::Location loc, bool lhsIsDevice, - hlfir::Entity &lhs, bool rhsIsDevice, - hlfir::Entity &rhs) { + void genCUDADataTransfer(fir::FirOpBuilder &builder, mlir::Location loc, + const Fortran::evaluate::Assignment &assign, + hlfir::Entity &lhs, hlfir::Entity &rhs) { + bool lhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.lhs); + bool rhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.rhs); if (rhs.isBoxAddressOrValue() || lhs.isBoxAddressOrValue()) TODO(loc, "CUDA data transfler with descriptors"); + + // device = host if (lhsIsDevice && !rhsIsDevice) { auto transferKindAttr = fir::CUDADataTransferKindAttr::get( builder.getContext(), fir::CUDADataTransferKind::HostDevice); - // device = host if (!rhs.isVariable()) { auto associate = hlfir::genAssociateExpr( loc, builder, rhs, rhs.getType(), ".cuf_host_tmp"); @@ -3732,7 +3734,73 @@ private: } return; } - TODO(loc, "Assignement with CUDA Fortran variables"); + + // host = device + if (!lhsIsDevice && rhsIsDevice) { + auto transferKindAttr = fir::CUDADataTransferKindAttr::get( + builder.getContext(), fir::CUDADataTransferKind::DeviceHost); + if (!rhs.isVariable()) { + // evaluateRhs loads scalar. Look for the memory reference to be used in + // the transfer. 
+ if (mlir::isa_and_nonnull<fir::LoadOp>(rhs.getDefiningOp())) { + auto loadOp = mlir::dyn_cast<fir::LoadOp>(rhs.getDefiningOp()); + builder.create<fir::CUDADataTransferOp>(loc, loadOp.getMemref(), lhs, + transferKindAttr); + return; + } + } else { + builder.create<fir::CUDADataTransferOp>(loc, rhs, lhs, + transferKindAttr); + } + return; + } + + if (lhsIsDevice && rhsIsDevice) { + assert(rhs.isVariable() && "CUDA Fortran assignment rhs is not legal"); + auto transferKindAttr = fir::CUDADataTransferKindAttr::get( + builder.getContext(), fir::CUDADataTransferKind::DeviceDevice); + builder.create<fir::CUDADataTransferOp>(loc, rhs, lhs, transferKindAttr); + return; + } + llvm_unreachable("Unhandled CUDA data transfer"); + } + + llvm::SmallVector<mlir::Value> + genCUDAImplicitDataTransfer(fir::FirOpBuilder &builder, mlir::Location loc, + const Fortran::evaluate::Assignment &assign) { + llvm::SmallVector<mlir::Value> temps; + localSymbols.pushScope(); + auto transferKindAttr = fir::CUDADataTransferKindAttr::get( + builder.getContext(), fir::CUDADataTransferKind::DeviceHost); + unsigned nbDeviceResidentObject = 0; + for (const Fortran::semantics::Symbol &sym : + Fortran::evaluate::CollectSymbols(assign.rhs)) { + if (const auto *details = + sym.GetUltimate() + .detailsIf<Fortran::semantics::ObjectEntityDetails>()) { + if (details->cudaDataAttr()) { + if (sym.owner().IsDerivedType() && IsAllocatable(sym.GetUltimate())) + TODO(loc, "Device resident allocatable derived-type component"); + // TODO: This should probably being checked in semantic and give a + // proper error. 
+ assert( + nbDeviceResidentObject <= 1 && + "Only one reference to the device resident object is supported"); + auto addr = getSymbolAddress(sym); + hlfir::Entity entity{addr}; + auto [temp, cleanup] = + hlfir::createTempFromMold(loc, builder, entity); + auto needCleanup = fir::getIntIfConstant(cleanup); + if (needCleanup && *needCleanup) + temps.push_back(temp); + addSymbol(sym, temp, /*forced=*/true); + builder.create<fir::CUDADataTransferOp>(loc, addr, temp, + transferKindAttr); + ++nbDeviceResidentObject; + } + } + } + return temps; } void genDataAssignment( @@ -3741,8 +3809,13 @@ private: mlir::Location loc = getCurrentLocation(); fir::FirOpBuilder &builder = getFirOpBuilder(); - bool lhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.lhs); - bool rhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.rhs); + bool isCUDATransfer = Fortran::evaluate::HasCUDAAttrs(assign.lhs) || + Fortran::evaluate::HasCUDAAttrs(assign.rhs); + bool hasCUDAImplicitTransfer = + Fortran::evaluate::HasCUDAImplicitTransfer(assign.rhs); + llvm::SmallVector<mlir::Value> implicitTemps; + if (hasCUDAImplicitTransfer) + implicitTemps = genCUDAImplicitDataTransfer(builder, loc, assign); // Gather some information about the assignment that will impact how it is // lowered. @@ -3800,12 +3873,16 @@ private: Fortran::lower::StatementContext localStmtCtx; hlfir::Entity rhs = evaluateRhs(localStmtCtx); hlfir::Entity lhs = evaluateLhs(localStmtCtx); - if (lhsIsDevice || rhsIsDevice) { - genCUDADataTransfer(builder, loc, lhsIsDevice, lhs, rhsIsDevice, rhs); - } else { + if (isCUDATransfer && !hasCUDAImplicitTransfer) + genCUDADataTransfer(builder, loc, assign, lhs, rhs); + else builder.create<hlfir::AssignOp>(loc, rhs, lhs, isWholeAllocatableAssignment, keepLhsLengthInAllocatableAssignment); + if (hasCUDAImplicitTransfer) { + localSymbols.popScope(); + for (mlir::Value temp : implicitTemps) + builder.create<fir::FreeMemOp>(loc, temp); } return; } |