diff options
Diffstat (limited to 'offload/plugins-nextgen/common/src')
-rw-r--r-- | offload/plugins-nextgen/common/src/PluginInterface.cpp | 93 |
1 files changed, 86 insertions, 7 deletions
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp index ed79af9..e5a313d 100644 --- a/offload/plugins-nextgen/common/src/PluginInterface.cpp +++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp @@ -662,6 +662,10 @@ uint32_t GenericKernelTy::getNumBlocks(GenericDeviceTy &GenericDevice, return std::min(NumTeamsClause[0], GenericDevice.getBlockLimit()); } + // Return the number of teams required to cover the loop iterations. + if (isNoLoopMode()) + return LoopTripCount > 0 ? (((LoopTripCount - 1) / NumThreads) + 1) : 1; + uint64_t DefaultNumBlocks = GenericDevice.getDefaultNumBlocks(); uint64_t TripCountNumBlocks = std::numeric_limits<uint64_t>::max(); if (LoopTripCount > 0) { @@ -815,8 +819,11 @@ Error GenericDeviceTy::init(GenericPluginTy &Plugin) { // Enable the memory manager if required. auto [ThresholdMM, EnableMM] = MemoryManagerTy::getSizeThresholdFromEnv(); - if (EnableMM) + if (EnableMM) { + if (ThresholdMM == 0) + ThresholdMM = getMemoryManagerSizeThreshold(); MemoryManager = new MemoryManagerTy(*this, ThresholdMM); + } return Plugin::success(); } @@ -1332,18 +1339,28 @@ Error PinnedAllocationMapTy::unlockUnmappedHostBuffer(void *HstPtr) { return eraseEntry(*Entry); } -Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo) { - if (!AsyncInfo || !AsyncInfo->Queue) +Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo, + bool ReleaseQueue) { + if (!AsyncInfo) return Plugin::error(ErrorCode::INVALID_ARGUMENT, "invalid async info queue"); - if (auto Err = synchronizeImpl(*AsyncInfo)) - return Err; + SmallVector<void *> AllocsToDelete{}; + { + std::lock_guard<std::mutex> AllocationGuard{AsyncInfo->Mutex}; + + // This can be false when no work has been added to the AsyncInfo. In which + // case, the device has nothing to synchronize. + if (AsyncInfo->Queue) + if (auto Err = synchronizeImpl(*AsyncInfo, ReleaseQueue)) + return Err; + + std::swap(AllocsToDelete, AsyncInfo->AssociatedAllocations); + } - for (auto *Ptr : AsyncInfo->AssociatedAllocations) + for (auto *Ptr : AllocsToDelete) if (auto Err = dataDelete(Ptr, TargetAllocTy::TARGET_ALLOC_DEVICE)) return Err; - AsyncInfo->AssociatedAllocations.clear(); return Plugin::success(); } @@ -1530,6 +1547,16 @@ Error GenericDeviceTy::dataExchange(const void *SrcPtr, GenericDeviceTy &DstDev, return Err; } +Error GenericDeviceTy::dataFill(void *TgtPtr, const void *PatternPtr, + int64_t PatternSize, int64_t Size, + __tgt_async_info *AsyncInfo) { + AsyncInfoWrapperTy AsyncInfoWrapper(*this, AsyncInfo); + auto Err = + dataFillImpl(TgtPtr, PatternPtr, PatternSize, Size, AsyncInfoWrapper); + AsyncInfoWrapper.finalize(Err); + return Err; +} + Error GenericDeviceTy::launchKernel(void *EntryPtr, void **ArgPtrs, ptrdiff_t *ArgOffsets, KernelArgsTy &KernelArgs, @@ -1579,6 +1606,15 @@ Error GenericDeviceTy::initAsyncInfo(__tgt_async_info **AsyncInfoPtr) { return Err; } +Error GenericDeviceTy::enqueueHostCall(void (*Callback)(void *), void *UserData, + __tgt_async_info *AsyncInfo) { + AsyncInfoWrapperTy AsyncInfoWrapper(*this, AsyncInfo); + + auto Err = enqueueHostCallImpl(Callback, UserData, AsyncInfoWrapper); + AsyncInfoWrapper.finalize(Err); + return Err; +} + Error GenericDeviceTy::initDeviceInfo(__tgt_device_info *DeviceInfo) { assert(DeviceInfo && "Invalid device info"); @@ -1623,6 +1659,37 @@ Error GenericDeviceTy::waitEvent(void *EventPtr, __tgt_async_info *AsyncInfo) { return Err; } +Expected<bool> GenericDeviceTy::hasPendingWork(__tgt_async_info *AsyncInfo) { + AsyncInfoWrapperTy AsyncInfoWrapper(*this, AsyncInfo); + auto Res = hasPendingWorkImpl(AsyncInfoWrapper); + if (auto Err = Res.takeError()) { + AsyncInfoWrapper.finalize(Err); + return Err; + } + + auto Err = Plugin::success(); + AsyncInfoWrapper.finalize(Err); + if (Err) + return Err; + return Res; +} + +Expected<bool> GenericDeviceTy::isEventComplete(void *Event, + __tgt_async_info *AsyncInfo) { + AsyncInfoWrapperTy AsyncInfoWrapper(*this, AsyncInfo); + auto Res = isEventCompleteImpl(Event, AsyncInfoWrapper); + if (auto Err = Res.takeError()) { + AsyncInfoWrapper.finalize(Err); + return Err; + } + + auto Err = Plugin::success(); + AsyncInfoWrapper.finalize(Err); + if (Err) + return Err; + return Res; +} + Error GenericDeviceTy::syncEvent(void *EventPtr) { return syncEventImpl(EventPtr); } @@ -2299,3 +2366,15 @@ int32_t GenericPluginTy::async_barrier(omp_interop_val_t *Interop) { } return OFFLOAD_SUCCESS; } + +int32_t GenericPluginTy::data_fence(int32_t DeviceId, + __tgt_async_info *AsyncInfo) { + auto Err = getDevice(DeviceId).dataFence(AsyncInfo); + if (Err) { + REPORT("failure to place data fence on device %d: %s\n", DeviceId, + toString(std::move(Err)).data()); + return OFFLOAD_FAIL; + } + + return OFFLOAD_SUCCESS; +} |