aboutsummaryrefslogtreecommitdiff
path: root/flang/lib/Lower/Allocatable.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Lower/Allocatable.cpp')
-rw-r--r--flang/lib/Lower/Allocatable.cpp14
1 files changed, 9 insertions, 5 deletions
diff --git a/flang/lib/Lower/Allocatable.cpp b/flang/lib/Lower/Allocatable.cpp
index 219f920..ce9d894 100644
--- a/flang/lib/Lower/Allocatable.cpp
+++ b/flang/lib/Lower/Allocatable.cpp
@@ -13,9 +13,9 @@
#include "flang/Lower/Allocatable.h"
#include "flang/Evaluate/tools.h"
#include "flang/Lower/AbstractConverter.h"
+#include "flang/Lower/CUDA.h"
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/ConvertVariable.h"
-#include "flang/Lower/Cuda.h"
#include "flang/Lower/IterationSpace.h"
#include "flang/Lower/Mangler.h"
#include "flang/Lower/OpenACC.h"
@@ -445,10 +445,14 @@ private:
/*mustBeHeap=*/true);
}
- void postAllocationAction(const Allocation &alloc) {
+ void postAllocationAction(const Allocation &alloc,
+ const fir::MutableBoxValue &box) {
if (alloc.getSymbol().test(Fortran::semantics::Symbol::Flag::AccDeclare))
Fortran::lower::attachDeclarePostAllocAction(converter, builder,
alloc.getSymbol());
+ if (Fortran::semantics::HasCUDAComponent(alloc.getSymbol()))
+ Fortran::lower::initializeDeviceComponentAllocator(
+ converter, alloc.getSymbol(), box);
}
void setPinnedToFalse() {
@@ -481,7 +485,7 @@ private:
// Pointers must use PointerAllocate so that their deallocations
// can be validated.
genInlinedAllocation(alloc, box);
- postAllocationAction(alloc);
+ postAllocationAction(alloc, box);
setPinnedToFalse();
return;
}
@@ -504,7 +508,7 @@ private:
genCudaAllocate(builder, loc, box, errorManager, alloc.getSymbol());
}
fir::factory::syncMutableBoxFromIRBox(builder, loc, box);
- postAllocationAction(alloc);
+ postAllocationAction(alloc, box);
errorManager.assignStat(builder, loc, stat);
}
@@ -647,7 +651,7 @@ private:
setPinnedToFalse();
}
fir::factory::syncMutableBoxFromIRBox(builder, loc, box);
- postAllocationAction(alloc);
+ postAllocationAction(alloc, box);
errorManager.assignStat(builder, loc, stat);
}