aboutsummaryrefslogtreecommitdiff
path: root/offload/plugins-nextgen/common/src
diff options
context:
space:
mode:
Diffstat (limited to 'offload/plugins-nextgen/common/src')
-rw-r--r-- offload/plugins-nextgen/common/src/PluginInterface.cpp | 93
1 files changed, 86 insertions, 7 deletions
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index ed79af9..e5a313d 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -662,6 +662,10 @@ uint32_t GenericKernelTy::getNumBlocks(GenericDeviceTy &GenericDevice,
return std::min(NumTeamsClause[0], GenericDevice.getBlockLimit());
}
+ // Return the number of teams required to cover the loop iterations.
+ if (isNoLoopMode())
+ return LoopTripCount > 0 ? (((LoopTripCount - 1) / NumThreads) + 1) : 1;
+
uint64_t DefaultNumBlocks = GenericDevice.getDefaultNumBlocks();
uint64_t TripCountNumBlocks = std::numeric_limits<uint64_t>::max();
if (LoopTripCount > 0) {
@@ -815,8 +819,11 @@ Error GenericDeviceTy::init(GenericPluginTy &Plugin) {
// Enable the memory manager if required.
auto [ThresholdMM, EnableMM] = MemoryManagerTy::getSizeThresholdFromEnv();
- if (EnableMM)
+ if (EnableMM) {
+ if (ThresholdMM == 0)
+ ThresholdMM = getMemoryManagerSizeThreshold();
MemoryManager = new MemoryManagerTy(*this, ThresholdMM);
+ }
return Plugin::success();
}
@@ -1332,18 +1339,28 @@ Error PinnedAllocationMapTy::unlockUnmappedHostBuffer(void *HstPtr) {
return eraseEntry(*Entry);
}
-Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo) {
- if (!AsyncInfo || !AsyncInfo->Queue)
+// Synchronize the work tracked by AsyncInfo and release the device
+// allocations associated with it.
+//
+// A null AsyncInfo is rejected with an INVALID_ARGUMENT error. A null
+// AsyncInfo->Queue is not an error: synchronizeImpl() is simply skipped.
+// ReleaseQueue is forwarded unchanged to synchronizeImpl(); its meaning is
+// defined there.
+//
+// NOTE(review): the error text below still says "queue" although only the
+// AsyncInfo pointer itself is checked now -- confirm the wording is intended.
+Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo,
+ bool ReleaseQueue) {
+ if (!AsyncInfo)
return Plugin::error(ErrorCode::INVALID_ARGUMENT,
"invalid async info queue");
- if (auto Err = synchronizeImpl(*AsyncInfo))
- return Err;
+ // Allocations are moved into this local list while AsyncInfo->Mutex is held
+ // and are deleted only after the guard below has gone out of scope.
+ SmallVector<void *> AllocsToDelete{};
+ {
+ std::lock_guard<std::mutex> AllocationGuard{AsyncInfo->Mutex};
+
+ // The queue can be null when no work has been added to the AsyncInfo, in
+ // which case the device has nothing to synchronize.
+ if (AsyncInfo->Queue)
+ if (auto Err = synchronizeImpl(*AsyncInfo, ReleaseQueue))
+ return Err;
+
+ // Reached only if synchronizeImpl() did not fail; on failure the list stays
+ // attached to AsyncInfo.
+ std::swap(AllocsToDelete, AsyncInfo->AssociatedAllocations);
+ }
- for (auto *Ptr : AsyncInfo->AssociatedAllocations)
+ // NOTE(review): the first dataDelete() failure returns immediately, so the
+ // remaining pointers in AllocsToDelete are never passed to dataDelete() and
+ // are no longer recorded on AsyncInfo -- confirm this is acceptable.
+ for (auto *Ptr : AllocsToDelete)
if (auto Err = dataDelete(Ptr, TargetAllocTy::TARGET_ALLOC_DEVICE))
return Err;
- AsyncInfo->AssociatedAllocations.clear();
return Plugin::success();
}
@@ -1530,6 +1547,16 @@ Error GenericDeviceTy::dataExchange(const void *SrcPtr, GenericDeviceTy &DstDev,
return Err;
}
+// Generic entry point for a device fill: wraps AsyncInfo in an
+// AsyncInfoWrapperTy and forwards all arguments to dataFillImpl().
+//
+// Presumably fills Size bytes at TgtPtr with the PatternSize-byte pattern read
+// from PatternPtr; the exact semantics are defined by dataFillImpl() -- TODO
+// confirm.
+//
+// Returns the error produced by dataFillImpl() after it has been handed to
+// AsyncInfoWrapper.finalize().
+Error GenericDeviceTy::dataFill(void *TgtPtr, const void *PatternPtr,
+ int64_t PatternSize, int64_t Size,
+ __tgt_async_info *AsyncInfo) {
+ AsyncInfoWrapperTy AsyncInfoWrapper(*this, AsyncInfo);
+ auto Err =
+ dataFillImpl(TgtPtr, PatternPtr, PatternSize, Size, AsyncInfoWrapper);
+ // finalize() is handed Err before it is returned; it may update it -- confirm
+ // against AsyncInfoWrapperTy::finalize.
+ AsyncInfoWrapper.finalize(Err);
+ return Err;
+}
+
Error GenericDeviceTy::launchKernel(void *EntryPtr, void **ArgPtrs,
ptrdiff_t *ArgOffsets,
KernelArgsTy &KernelArgs,
@@ -1579,6 +1606,15 @@ Error GenericDeviceTy::initAsyncInfo(__tgt_async_info **AsyncInfoPtr) {
return Err;
}
+// Generic entry point for enqueueing a host callback: wraps AsyncInfo in an
+// AsyncInfoWrapperTy and forwards Callback and UserData to
+// enqueueHostCallImpl().
+//
+// Presumably Callback is later invoked with UserData as its only argument;
+// when and on which thread is decided by enqueueHostCallImpl() -- TODO confirm.
+//
+// Returns the error produced by enqueueHostCallImpl() after it has been handed
+// to AsyncInfoWrapper.finalize().
+Error GenericDeviceTy::enqueueHostCall(void (*Callback)(void *), void *UserData,
+ __tgt_async_info *AsyncInfo) {
+ AsyncInfoWrapperTy AsyncInfoWrapper(*this, AsyncInfo);
+
+ auto Err = enqueueHostCallImpl(Callback, UserData, AsyncInfoWrapper);
+ // finalize() is handed Err before it is returned; it may update it -- confirm
+ // against AsyncInfoWrapperTy::finalize.
+ AsyncInfoWrapper.finalize(Err);
+ return Err;
+}
+
Error GenericDeviceTy::initDeviceInfo(__tgt_device_info *DeviceInfo) {
assert(DeviceInfo && "Invalid device info");
@@ -1623,6 +1659,37 @@ Error GenericDeviceTy::waitEvent(void *EventPtr, __tgt_async_info *AsyncInfo) {
return Err;
}
+// Forward the pending-work query for AsyncInfo to hasPendingWorkImpl() through
+// an AsyncInfoWrapperTy.
+//
+// Returns the result of hasPendingWorkImpl() when neither the implementation
+// nor the finalize() step leaves an error; otherwise returns that error.
+Expected<bool> GenericDeviceTy::hasPendingWork(__tgt_async_info *AsyncInfo) {
+ AsyncInfoWrapperTy AsyncInfoWrapper(*this, AsyncInfo);
+ auto Res = hasPendingWorkImpl(AsyncInfoWrapper);
+ // Failure path: the wrapper is still finalized, with the implementation's
+ // error, before that error is propagated.
+ if (auto Err = Res.takeError()) {
+ AsyncInfoWrapper.finalize(Err);
+ return Err;
+ }
+
+ // Success path: finalize with a success value and return whatever error Err
+ // holds afterwards, if any.
+ auto Err = Plugin::success();
+ AsyncInfoWrapper.finalize(Err);
+ if (Err)
+ return Err;
+ return Res;
+}
+
+// Forward the completion query for Event to isEventCompleteImpl() through an
+// AsyncInfoWrapperTy built from AsyncInfo.
+//
+// Returns the result of isEventCompleteImpl() when neither the implementation
+// nor the finalize() step leaves an error; otherwise returns that error.
+Expected<bool> GenericDeviceTy::isEventComplete(void *Event,
+ __tgt_async_info *AsyncInfo) {
+ AsyncInfoWrapperTy AsyncInfoWrapper(*this, AsyncInfo);
+ auto Res = isEventCompleteImpl(Event, AsyncInfoWrapper);
+ // Failure path: the wrapper is still finalized, with the implementation's
+ // error, before that error is propagated.
+ if (auto Err = Res.takeError()) {
+ AsyncInfoWrapper.finalize(Err);
+ return Err;
+ }
+
+ // Success path: finalize with a success value and return whatever error Err
+ // holds afterwards, if any.
+ auto Err = Plugin::success();
+ AsyncInfoWrapper.finalize(Err);
+ if (Err)
+ return Err;
+ return Res;
+}
+
Error GenericDeviceTy::syncEvent(void *EventPtr) {
return syncEventImpl(EventPtr);
}
@@ -2299,3 +2366,15 @@ int32_t GenericPluginTy::async_barrier(omp_interop_val_t *Interop) {
}
return OFFLOAD_SUCCESS;
}
+
+// Plugin-level entry point that calls dataFence(AsyncInfo) on the device
+// selected by DeviceId.
+//
+// Returns OFFLOAD_SUCCESS when dataFence() reports no error. Otherwise the
+// error is converted to text, reported through REPORT, and OFFLOAD_FAIL is
+// returned.
+int32_t GenericPluginTy::data_fence(int32_t DeviceId,
+ __tgt_async_info *AsyncInfo) {
+ auto Err = getDevice(DeviceId).dataFence(AsyncInfo);
+ if (Err) {
+ // toString() takes the error by rvalue, so Err is moved from here.
+ REPORT("failure to place data fence on device %d: %s\n", DeviceId,
+ toString(std::move(Err)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+}