aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Lower/Bridge.cpp
diff options
context:
space:
mode:
authorValentin Clement (バレンタイン クレメン) <clementval@gmail.com>2024-04-05 09:11:37 -0700
committerGitHub <noreply@github.com>2024-04-05 09:11:37 -0700
commit953aa102a90099ae655eaa4645dd8d15c95ea86a (patch)
treeaa5dfbac45899142387bbb0fda3307eb6b873888 /flang/lib/Lower/Bridge.cpp
parent5f9ed2ff8364ff3e4fac410472f421299dafa793 (diff)
downloadllvm-953aa102a90099ae655eaa4645dd8d15c95ea86a.zip
llvm-953aa102a90099ae655eaa4645dd8d15c95ea86a.tar.gz
llvm-953aa102a90099ae655eaa4645dd8d15c95ea86a.tar.bz2
[flang][cuda] Lower device to host and device to device transfer (#87387)
Add more support for CUDA data transfer in assignment. This patch adds device-to-device and device-to-host support. If device symbols are present on the rhs, some implicit data transfers are initiated. A temporary is created and the data is transferred to the host. The expression is evaluated on the host and the assignment is done.
Diffstat (limited to 'flang/lib/Lower/Bridge.cpp')
-rw-r--r--flang/lib/Lower/Bridge.cpp99
1 files changed, 88 insertions, 11 deletions
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 5bba097..478c8f4 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3710,16 +3710,18 @@ private:
return false;
}
- static void genCUDADataTransfer(fir::FirOpBuilder &builder,
- mlir::Location loc, bool lhsIsDevice,
- hlfir::Entity &lhs, bool rhsIsDevice,
- hlfir::Entity &rhs) {
+ void genCUDADataTransfer(fir::FirOpBuilder &builder, mlir::Location loc,
+ const Fortran::evaluate::Assignment &assign,
+ hlfir::Entity &lhs, hlfir::Entity &rhs) {
+ bool lhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.lhs);
+ bool rhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.rhs);
if (rhs.isBoxAddressOrValue() || lhs.isBoxAddressOrValue())
TODO(loc, "CUDA data transfer with descriptors");
+
+ // device = host
if (lhsIsDevice && !rhsIsDevice) {
auto transferKindAttr = fir::CUDADataTransferKindAttr::get(
builder.getContext(), fir::CUDADataTransferKind::HostDevice);
- // device = host
if (!rhs.isVariable()) {
auto associate = hlfir::genAssociateExpr(
loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
@@ -3732,7 +3734,73 @@ private:
}
return;
}
- TODO(loc, "Assignement with CUDA Fortran variables");
+
+ // host = device
+ if (!lhsIsDevice && rhsIsDevice) {
+ auto transferKindAttr = fir::CUDADataTransferKindAttr::get(
+ builder.getContext(), fir::CUDADataTransferKind::DeviceHost);
+ if (!rhs.isVariable()) {
+ // evaluateRhs loads scalar. Look for the memory reference to be used in
+ // the transfer.
+ if (mlir::isa_and_nonnull<fir::LoadOp>(rhs.getDefiningOp())) {
+ auto loadOp = mlir::dyn_cast<fir::LoadOp>(rhs.getDefiningOp());
+ builder.create<fir::CUDADataTransferOp>(loc, loadOp.getMemref(), lhs,
+ transferKindAttr);
+ return;
+ }
+ } else {
+ builder.create<fir::CUDADataTransferOp>(loc, rhs, lhs,
+ transferKindAttr);
+ }
+ return;
+ }
+
+ if (lhsIsDevice && rhsIsDevice) {
+ assert(rhs.isVariable() && "CUDA Fortran assignment rhs is not legal");
+ auto transferKindAttr = fir::CUDADataTransferKindAttr::get(
+ builder.getContext(), fir::CUDADataTransferKind::DeviceDevice);
+ builder.create<fir::CUDADataTransferOp>(loc, rhs, lhs, transferKindAttr);
+ return;
+ }
+ llvm_unreachable("Unhandled CUDA data transfer");
+ }
+
+ llvm::SmallVector<mlir::Value>
+ genCUDAImplicitDataTransfer(fir::FirOpBuilder &builder, mlir::Location loc,
+ const Fortran::evaluate::Assignment &assign) {
+ llvm::SmallVector<mlir::Value> temps;
+ localSymbols.pushScope();
+ auto transferKindAttr = fir::CUDADataTransferKindAttr::get(
+ builder.getContext(), fir::CUDADataTransferKind::DeviceHost);
+ unsigned nbDeviceResidentObject = 0;
+ for (const Fortran::semantics::Symbol &sym :
+ Fortran::evaluate::CollectSymbols(assign.rhs)) {
+ if (const auto *details =
+ sym.GetUltimate()
+ .detailsIf<Fortran::semantics::ObjectEntityDetails>()) {
+ if (details->cudaDataAttr()) {
+ if (sym.owner().IsDerivedType() && IsAllocatable(sym.GetUltimate()))
+ TODO(loc, "Device resident allocatable derived-type component");
+ // TODO: This should probably be checked in semantics and a proper
+ // error emitted.
+ assert(
+ nbDeviceResidentObject <= 1 &&
+ "Only one reference to the device resident object is supported");
+ auto addr = getSymbolAddress(sym);
+ hlfir::Entity entity{addr};
+ auto [temp, cleanup] =
+ hlfir::createTempFromMold(loc, builder, entity);
+ auto needCleanup = fir::getIntIfConstant(cleanup);
+ if (needCleanup && *needCleanup)
+ temps.push_back(temp);
+ addSymbol(sym, temp, /*forced=*/true);
+ builder.create<fir::CUDADataTransferOp>(loc, addr, temp,
+ transferKindAttr);
+ ++nbDeviceResidentObject;
+ }
+ }
+ }
+ return temps;
}
void genDataAssignment(
@@ -3741,8 +3809,13 @@ private:
mlir::Location loc = getCurrentLocation();
fir::FirOpBuilder &builder = getFirOpBuilder();
- bool lhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.lhs);
- bool rhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.rhs);
+ bool isCUDATransfer = Fortran::evaluate::HasCUDAAttrs(assign.lhs) ||
+ Fortran::evaluate::HasCUDAAttrs(assign.rhs);
+ bool hasCUDAImplicitTransfer =
+ Fortran::evaluate::HasCUDAImplicitTransfer(assign.rhs);
+ llvm::SmallVector<mlir::Value> implicitTemps;
+ if (hasCUDAImplicitTransfer)
+ implicitTemps = genCUDAImplicitDataTransfer(builder, loc, assign);
// Gather some information about the assignment that will impact how it is
// lowered.
@@ -3800,12 +3873,16 @@ private:
Fortran::lower::StatementContext localStmtCtx;
hlfir::Entity rhs = evaluateRhs(localStmtCtx);
hlfir::Entity lhs = evaluateLhs(localStmtCtx);
- if (lhsIsDevice || rhsIsDevice) {
- genCUDADataTransfer(builder, loc, lhsIsDevice, lhs, rhsIsDevice, rhs);
- } else {
+ if (isCUDATransfer && !hasCUDAImplicitTransfer)
+ genCUDADataTransfer(builder, loc, assign, lhs, rhs);
+ else
builder.create<hlfir::AssignOp>(loc, rhs, lhs,
isWholeAllocatableAssignment,
keepLhsLengthInAllocatableAssignment);
+ if (hasCUDAImplicitTransfer) {
+ localSymbols.popScope();
+ for (mlir::Value temp : implicitTemps)
+ builder.create<fir::FreeMemOp>(loc, temp);
}
return;
}