//===- ol_impl.cpp - Implementation of the new LLVM/Offload API ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This contains the definitions of the new LLVM/Offload API entry points. See
// new-api/API/README.md for more information.
//
//===----------------------------------------------------------------------===//

#include "OffloadImpl.hpp"
#include "Helpers.hpp"
#include "OffloadPrint.hpp"
#include "PluginManager.h"
#include "llvm/Support/FormatVariadic.h"

#include <cstring>
#include <mutex>

// TODO: Some plugins expect to be linked into libomptarget which defines these
// symbols to implement ompt callbacks. The least invasive workaround here is
// to define them in libLLVMOffload as false/null so they are never used. In
// future it would be better to allow the plugins to implement callbacks
// without pulling in details from libomptarget.
#ifdef OMPT_SUPPORT
namespace llvm::omp::target {
namespace ompt {
bool Initialized = false;
ompt_get_callback_t lookupCallbackByCode = nullptr;
ompt_function_lookup_t lookupCallbackByName = nullptr;
} // namespace ompt
} // namespace llvm::omp::target
#endif

using namespace llvm::omp::target;
using namespace llvm::omp::target::plugin;
using namespace error;

struct ol_platform_impl_t {
  ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
                     ol_platform_backend_t BackendType)
      : Plugin(std::move(Plugin)), BackendType(BackendType) {}
  std::unique_ptr<GenericPluginTy> Plugin;
  llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
  ol_platform_backend_t BackendType;

  /// Complete all pending work for this platform and perform any needed
  /// cleanup.
  ///
  /// After calling this function, no liboffload functions should be called
  /// with this platform handle.
  llvm::Error destroy();
};

// Handle type definitions. Ideally these would be 1:1 with the plugins, but
// we add some additional data here for now to avoid churn in the plugin
// interface.
struct ol_device_impl_t {
  ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device,
                   ol_platform_impl_t &Platform, InfoTreeNode &&DevInfo)
      : DeviceNum(DeviceNum), Device(Device), Platform(Platform),
        Info(std::forward<InfoTreeNode>(DevInfo)) {}

  ~ol_device_impl_t() {
    assert(!OutstandingQueues.size() &&
           "Device object dropped with outstanding queues");
  }

  int DeviceNum;
  GenericDeviceTy *Device;
  ol_platform_impl_t &Platform;
  InfoTreeNode Info;

  llvm::SmallVector<__tgt_async_info *> OutstandingQueues;
  std::mutex OutstandingQueuesMutex;

  /// If the device has any outstanding queues that are now complete, remove
  /// one from the list and return it.
  ///
  /// Queues may be added to the outstanding queue list by olDestroyQueue if
  /// they are destroyed but not completed.
  __tgt_async_info *getOutstandingQueue() {
    // Not locking the `size()` access is fine here - In the worst case we
    // either miss a queue that exists or loop through an empty array after
    // taking the lock. Both are sub-optimal but not that bad.
    if (OutstandingQueues.size()) {
      std::lock_guard<std::mutex> Lock(OutstandingQueuesMutex);
      // As queues are pulled and popped from this list, longer running queues
      // naturally bubble to the start of the array. Hence looping backwards.
      for (auto Q = OutstandingQueues.rbegin(); Q != OutstandingQueues.rend();
           Q++) {
        // hasPendingWork returns an Expected<bool>; treat a failed query as
        // "still busy" and leave the queue in the list.
        auto Pending = Device->hasPendingWork(*Q);
        if (!Pending) {
          llvm::consumeError(Pending.takeError());
          continue;
        }
        if (!*Pending) {
          auto OutstandingQueue = *Q;
          *Q = OutstandingQueues.back();
          OutstandingQueues.pop_back();
          return OutstandingQueue;
        }
      }
    }
    return nullptr;
  }
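
  // Note on the removal above: a completed queue is replaced by the back
  // element and the back is popped, so the longest-running queues that have
  // accumulated at the front keep their positions. Illustrative trace (not
  // API behaviour): with outstanding queues [A, B, C] and only C complete,
  // C is returned and the list remains [A, B].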
  /// Complete all pending work for this device and perform any needed
  /// cleanup.
  ///
  /// After calling this function, no liboffload functions should be called
  /// with this device handle.
  llvm::Error destroy() {
    llvm::Error Result = Plugin::success();
    for (auto Q : OutstandingQueues)
      if (auto Err = Device->synchronize(Q, /*Release=*/true))
        Result = llvm::joinErrors(std::move(Result), std::move(Err));
    OutstandingQueues.clear();
    return Result;
  }
};

llvm::Error ol_platform_impl_t::destroy() {
  llvm::Error Result = Plugin::success();
  for (auto &D : Devices)
    if (auto Err = D->destroy())
      Result = llvm::joinErrors(std::move(Result), std::move(Err));

  if (auto Res = Plugin->deinit())
    Result = llvm::joinErrors(std::move(Result), std::move(Res));

  return Result;
}

struct ol_queue_impl_t {
  ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
      : AsyncInfo(AsyncInfo), Device(Device), Id(IdCounter++) {}
  __tgt_async_info *AsyncInfo;
  ol_device_handle_t Device;
  // A unique identifier for the queue
  size_t Id;
  static std::atomic<size_t> IdCounter;
};
std::atomic<size_t> ol_queue_impl_t::IdCounter(0);

struct ol_event_impl_t {
  ol_event_impl_t(void *EventInfo, ol_device_handle_t Device,
                  ol_queue_handle_t Queue)
      : EventInfo(EventInfo), Device(Device), QueueId(Queue->Id),
        Queue(Queue) {}
  // EventInfo may be null, in which case the event should be considered
  // always complete
  void *EventInfo;
  ol_device_handle_t Device;
  size_t QueueId;
  // Events may outlive the queue - don't assume this is always valid.
  // It is provided only to implement OL_EVENT_INFO_QUEUE. Use QueueId to
  // check for queue equality instead.
  ol_queue_handle_t Queue;
};

struct ol_program_impl_t {
  ol_program_impl_t(plugin::DeviceImageTy *Image,
                    llvm::MemoryBufferRef DeviceImage)
      : Image(Image), DeviceImage(DeviceImage) {}
  plugin::DeviceImageTy *Image;
  std::mutex SymbolListMutex;
  llvm::MemoryBufferRef DeviceImage;
  llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> KernelSymbols;
  llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> GlobalSymbols;
};

struct ol_symbol_impl_t {
  ol_symbol_impl_t(const char *Name, GenericKernelTy *Kernel)
      : PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL), Name(Name) {}
  ol_symbol_impl_t(const char *Name, GlobalTy &&Global)
      : PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE), Name(Name) {}
  std::variant<GenericKernelTy *, GlobalTy> PluginImpl;
  ol_symbol_kind_t Kind;
  llvm::StringRef Name;
};

namespace llvm {
namespace offload {

struct AllocInfo {
  ol_device_handle_t Device;
  ol_alloc_type_t Type;
  void *Start;
  // One byte past the end
  void *End;
};

// Global shared state for liboffload
struct OffloadContext;
// This pointer is non-null if and only if the context is valid and fully
// initialized
static std::atomic<OffloadContext *> OffloadContextVal;
std::mutex OffloadContextValMutex;

struct OffloadContext {
  OffloadContext(OffloadContext &) = delete;
  OffloadContext(OffloadContext &&) = delete;
  OffloadContext &operator=(OffloadContext &) = delete;
  OffloadContext &operator=(OffloadContext &&) = delete;

  bool TracingEnabled = false;
  bool ValidationEnabled = true;
  DenseMap<void *, AllocInfo> AllocInfoMap{};
  std::mutex AllocInfoMapMutex{};
  // Partitioned list of memory base addresses. Each element in this list is
  // a key in AllocInfoMap
  llvm::SmallVector<void *> AllocBases{};
  SmallVector<std::unique_ptr<ol_platform_impl_t>, 4> Platforms{};
  size_t RefCount;

  ol_device_handle_t HostDevice() {
    // The host platform is always inserted last
    return Platforms.back()->Devices[0].get();
  }

  static OffloadContext &get() {
    assert(OffloadContextVal);
    return *OffloadContextVal;
  }
};
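
// Note: OffloadContextValMutex serializes olInit/olShutDown, while the
// atomic OffloadContextVal lets every other entry point observe a fully
// initialized context without taking the lock (see olInit_impl below).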
// If the context is uninited, then we assume tracing is disabled
bool isTracingEnabled() {
  return isOffloadInitialized() && OffloadContext::get().TracingEnabled;
}
bool isValidationEnabled() { return OffloadContext::get().ValidationEnabled; }
bool isOffloadInitialized() { return OffloadContextVal != nullptr; }

template <typename HandleT> Error olDestroy(HandleT Handle) {
  delete Handle;
  return Error::success();
}

constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
  if (Name == "amdgpu") {
    return OL_PLATFORM_BACKEND_AMDGPU;
  } else if (Name == "cuda") {
    return OL_PLATFORM_BACKEND_CUDA;
  } else {
    return OL_PLATFORM_BACKEND_UNKNOWN;
  }
}

// Every plugin exports this method to create an instance of the plugin type.
#define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name();
#include "Shared/Targets.def"

Error initPlugins(OffloadContext &Context) {
  // Attempt to create an instance of each supported plugin.
#define PLUGIN_TARGET(Name)                                                    \
  do {                                                                         \
    if (StringRef(#Name) != "host")                                            \
      Context.Platforms.emplace_back(std::make_unique<ol_platform_impl_t>(    \
          std::unique_ptr<GenericPluginTy>(createPlugin_##Name()),            \
          pluginNameToBackend(#Name)));                                       \
  } while (false);
#include "Shared/Targets.def"

  // Preemptively initialize all devices in the plugin
  for (auto &Platform : Context.Platforms) {
    auto Err = Platform->Plugin->init();
    [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
    for (auto DevNum = 0; DevNum < Platform->Plugin->number_of_devices();
         DevNum++) {
      if (Platform->Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
        auto Device = &Platform->Plugin->getDevice(DevNum);
        auto Info = Device->obtainInfoImpl();
        if (auto Err = Info.takeError())
          return Err;
        Platform->Devices.emplace_back(std::make_unique<ol_device_impl_t>(
            DevNum, Device, *Platform, std::move(*Info)));
      }
    }
  }

  // Add the special host device
  auto &HostPlatform = Context.Platforms.emplace_back(
      std::make_unique<ol_platform_impl_t>(nullptr, OL_PLATFORM_BACKEND_HOST));
  HostPlatform->Devices.emplace_back(std::make_unique<ol_device_impl_t>(
      -1, nullptr, *HostPlatform, InfoTreeNode{}));

  Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
  Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");

  return Plugin::success();
}
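
// For reference, the PLUGIN_TARGET expansion above produces one block per
// entry in Shared/Targets.def; e.g. for an `amdgpu` entry it is roughly:
//
//   do {
//     if (StringRef("amdgpu") != "host")
//       Context.Platforms.emplace_back(std::make_unique<ol_platform_impl_t>(
//           std::unique_ptr<GenericPluginTy>(createPlugin_amdgpu()),
//           pluginNameToBackend("amdgpu")));
//   } while (false);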
Error olInit_impl() {
  std::lock_guard<std::mutex> Lock(OffloadContextValMutex);

  if (isOffloadInitialized()) {
    OffloadContext::get().RefCount++;
    return Plugin::success();
  }

  // Use a temporary to ensure that entry points querying OffloadContextVal
  // do not get a partially initialized context
  auto *NewContext = new OffloadContext{};
  Error InitResult = initPlugins(*NewContext);
  OffloadContextVal.store(NewContext);
  OffloadContext::get().RefCount++;

  return InitResult;
}

Error olShutDown_impl() {
  std::lock_guard<std::mutex> Lock(OffloadContextValMutex);

  if (--OffloadContext::get().RefCount != 0)
    return Error::success();

  llvm::Error Result = Error::success();
  auto *OldContext = OffloadContextVal.exchange(nullptr);

  for (auto &P : OldContext->Platforms) {
    // Host plugin is nullptr and has no deinit
    if (!P->Plugin || !P->Plugin->is_initialized())
      continue;

    if (auto Res = P->destroy())
      Result = llvm::joinErrors(std::move(Result), std::move(Res));
  }

  delete OldContext;
  return Result;
}

Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
                                  ol_platform_info_t PropName, size_t PropSize,
                                  void *PropValue, size_t *PropSizeRet) {
  InfoWriter Info(PropSize, PropValue, PropSizeRet);
  bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST;

  switch (PropName) {
  case OL_PLATFORM_INFO_NAME:
    return Info.writeString(IsHost ? "Host" : Platform->Plugin->getName());
  case OL_PLATFORM_INFO_VENDOR_NAME:
    // TODO: Implement this
    return Info.writeString("Unknown platform vendor");
  case OL_PLATFORM_INFO_VERSION: {
    return Info.writeString(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR,
                                    OL_VERSION_MINOR, OL_VERSION_PATCH)
                                .str());
  }
  case OL_PLATFORM_INFO_BACKEND: {
    return Info.write<ol_platform_backend_t>(Platform->BackendType);
  }
  default:
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "getPlatformInfo enum '%i' is invalid", PropName);
  }

  return Error::success();
}

Error olGetPlatformInfo_impl(ol_platform_handle_t Platform,
                             ol_platform_info_t PropName, size_t PropSize,
                             void *PropValue) {
  return olGetPlatformInfoImplDetail(Platform, PropName, PropSize, PropValue,
                                     nullptr);
}

Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
                                 ol_platform_info_t PropName,
                                 size_t *PropSizeRet) {
  return olGetPlatformInfoImplDetail(Platform, PropName, 0, nullptr,
                                     PropSizeRet);
}
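
// The ImplDetail helpers above (and below) serve both queries: passing a
// null PropSizeRet implements the value query, and a null PropValue the size
// query. An illustrative two-step use of the public wrappers (which are
// generated elsewhere and assumed here):
//
//   size_t Size = 0;
//   olGetPlatformInfoSize(Platform, OL_PLATFORM_INFO_NAME, &Size);
//   std::vector<char> Name(Size);
//   olGetPlatformInfo(Platform, OL_PLATFORM_INFO_NAME, Size, Name.data());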
Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
                                ol_device_info_t PropName, size_t PropSize,
                                void *PropValue, size_t *PropSizeRet) {
  assert(Device != OffloadContext::get().HostDevice());
  InfoWriter Info(PropSize, PropValue, PropSizeRet);

  auto makeError = [&](ErrorCode Code, StringRef Err) {
    std::string ErrBuffer;
    llvm::raw_string_ostream(ErrBuffer) << PropName << ": " << Err;
    return Plugin::error(Code, ErrBuffer.c_str());
  };

  // These are not implemented by the plugin interface
  switch (PropName) {
  case OL_DEVICE_INFO_PLATFORM:
    return Info.write(&Device->Platform);
  case OL_DEVICE_INFO_TYPE:
    return Info.write(OL_DEVICE_TYPE_GPU);
  case OL_DEVICE_INFO_SINGLE_FP_CONFIG:
  case OL_DEVICE_INFO_DOUBLE_FP_CONFIG: {
    ol_device_fp_capability_flags_t flags{0};
    flags |= OL_DEVICE_FP_CAPABILITY_FLAG_CORRECTLY_ROUNDED_DIVIDE_SQRT |
             OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_NEAREST |
             OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_ZERO |
             OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_INF |
             OL_DEVICE_FP_CAPABILITY_FLAG_INF_NAN |
             OL_DEVICE_FP_CAPABILITY_FLAG_DENORM |
             OL_DEVICE_FP_CAPABILITY_FLAG_FMA;
    return Info.write(flags);
  }
  case OL_DEVICE_INFO_HALF_FP_CONFIG:
    return Info.write<ol_device_fp_capability_flags_t>(0);
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_CHAR:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_SHORT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_INT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_LONG:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_FLOAT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_DOUBLE:
    return Info.write<uint32_t>(1);
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_HALF:
    return Info.write<uint32_t>(0);
  // None of the existing plugins specify a limit on a single allocation, so
  // return the global memory size instead
  case OL_DEVICE_INFO_MAX_MEM_ALLOC_SIZE:
    [[fallthrough]];
  // AMD doesn't provide the global memory size (trivially) with the device
  // info struct, so use the plugin interface
  case OL_DEVICE_INFO_GLOBAL_MEM_SIZE: {
    uint64_t Mem;
    if (auto Err = Device->Device->getDeviceMemorySize(Mem))
      return Err;
    return Info.write(Mem);
  } break;
  default:
    break;
  }

  if (PropName >= OL_DEVICE_INFO_LAST)
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "getDeviceInfo enum '%i' is invalid", PropName);

  auto EntryOpt = Device->Info.get(static_cast<DeviceInfo>(PropName));
  if (!EntryOpt)
    return makeError(ErrorCode::UNIMPLEMENTED,
                     "plugin did not provide a response for this information");
  auto Entry = *EntryOpt;

  // Retrieve properties from the plugin interface
  switch (PropName) {
  case OL_DEVICE_INFO_NAME:
  case OL_DEVICE_INFO_PRODUCT_NAME:
  case OL_DEVICE_INFO_VENDOR:
  case OL_DEVICE_INFO_DRIVER_VERSION: {
    // String values
    if (!std::holds_alternative<std::string>(Entry->Value))
      return makeError(ErrorCode::BACKEND_FAILURE,
                       "plugin returned incorrect type");
    return Info.writeString(std::get<std::string>(Entry->Value).c_str());
  }

  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
  case OL_DEVICE_INFO_MAX_WORK_SIZE:
  case OL_DEVICE_INFO_VENDOR_ID:
  case OL_DEVICE_INFO_NUM_COMPUTE_UNITS:
  case OL_DEVICE_INFO_ADDRESS_BITS:
  case OL_DEVICE_INFO_MAX_CLOCK_FREQUENCY:
  case OL_DEVICE_INFO_MEMORY_CLOCK_RATE: {
    // Uint32 values
    if (!std::holds_alternative<uint64_t>(Entry->Value))
      return makeError(ErrorCode::BACKEND_FAILURE,
                       "plugin returned incorrect type");
    auto Value = std::get<uint64_t>(Entry->Value);
    if (Value > std::numeric_limits<uint32_t>::max())
      return makeError(ErrorCode::BACKEND_FAILURE,
                       "plugin returned out of range device info");
    return Info.write(static_cast<uint32_t>(Value));
  }

  case OL_DEVICE_INFO_MAX_WORK_SIZE_PER_DIMENSION:
  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE_PER_DIMENSION: {
    // {x, y, z} triples
    ol_dimensions_t Out{0, 0, 0};

    auto getField = [&](StringRef Name, uint32_t &Dest) {
      if (auto F = Entry->get(Name)) {
        if (!std::holds_alternative<uint64_t>((*F)->Value))
          return makeError(
              ErrorCode::BACKEND_FAILURE,
              "plugin returned incorrect type for dimensions element");
        Dest = std::get<uint64_t>((*F)->Value);
      } else
        return makeError(ErrorCode::BACKEND_FAILURE,
                         "plugin didn't provide all values for dimensions");
      return Plugin::success();
    };

    if (auto Res = getField("x", Out.x))
      return Res;
    if (auto Res = getField("y", Out.y))
      return Res;
    if (auto Res = getField("z", Out.z))
      return Res;

    return Info.write(Out);
  }

  default:
    llvm_unreachable("Unimplemented device info");
  }
}
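
// The host device is virtual and has no plugin behind it, so none of its
// properties can come from the InfoTreeNode path above; they are hard-coded
// below instead.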
Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
                                    ol_device_info_t PropName, size_t PropSize,
                                    void *PropValue, size_t *PropSizeRet) {
  assert(Device == OffloadContext::get().HostDevice());
  InfoWriter Info(PropSize, PropValue, PropSizeRet);
  constexpr auto uint32_max = std::numeric_limits<uint32_t>::max();

  switch (PropName) {
  case OL_DEVICE_INFO_PLATFORM:
    return Info.write(&Device->Platform);
  case OL_DEVICE_INFO_TYPE:
    return Info.write(OL_DEVICE_TYPE_HOST);
  case OL_DEVICE_INFO_NAME:
    return Info.writeString("Virtual Host Device");
  case OL_DEVICE_INFO_PRODUCT_NAME:
    return Info.writeString("Virtual Host Device");
  case OL_DEVICE_INFO_VENDOR:
    return Info.writeString("Liboffload");
  case OL_DEVICE_INFO_DRIVER_VERSION:
    return Info.writeString(LLVM_VERSION_STRING);
  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
    return Info.write<uint32_t>(1);
  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE_PER_DIMENSION:
    return Info.write(ol_dimensions_t{1, 1, 1});
  case OL_DEVICE_INFO_MAX_WORK_SIZE:
    return Info.write(uint32_max);
  case OL_DEVICE_INFO_MAX_WORK_SIZE_PER_DIMENSION:
    return Info.write(ol_dimensions_t{uint32_max, uint32_max, uint32_max});
  case OL_DEVICE_INFO_VENDOR_ID:
    return Info.write<uint32_t>(0);
  case OL_DEVICE_INFO_NUM_COMPUTE_UNITS:
    return Info.write<uint32_t>(1);
  case OL_DEVICE_INFO_SINGLE_FP_CONFIG:
  case OL_DEVICE_INFO_DOUBLE_FP_CONFIG:
    return Info.write<ol_device_fp_capability_flags_t>(
        OL_DEVICE_FP_CAPABILITY_FLAG_CORRECTLY_ROUNDED_DIVIDE_SQRT |
        OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_NEAREST |
        OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_ZERO |
        OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_INF |
        OL_DEVICE_FP_CAPABILITY_FLAG_INF_NAN |
        OL_DEVICE_FP_CAPABILITY_FLAG_DENORM |
        OL_DEVICE_FP_CAPABILITY_FLAG_FMA);
  case OL_DEVICE_INFO_HALF_FP_CONFIG:
    return Info.write<ol_device_fp_capability_flags_t>(0);
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_CHAR:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_SHORT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_INT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_LONG:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_FLOAT:
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_DOUBLE:
    return Info.write<uint32_t>(1);
  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_HALF:
    return Info.write<uint32_t>(0);
  case OL_DEVICE_INFO_MAX_CLOCK_FREQUENCY:
  case OL_DEVICE_INFO_MEMORY_CLOCK_RATE:
  case OL_DEVICE_INFO_ADDRESS_BITS:
    return Info.write<uint32_t>(std::numeric_limits<uintptr_t>::digits);
  case OL_DEVICE_INFO_MAX_MEM_ALLOC_SIZE:
  case OL_DEVICE_INFO_GLOBAL_MEM_SIZE:
    return Info.write<uint64_t>(0);
  default:
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "getDeviceInfo enum '%i' is invalid", PropName);
  }

  return Error::success();
}

Error olGetDeviceInfo_impl(ol_device_handle_t Device,
                           ol_device_info_t PropName, size_t PropSize,
                           void *PropValue) {
  if (Device == OffloadContext::get().HostDevice())
    return olGetDeviceInfoImplDetailHost(Device, PropName, PropSize, PropValue,
                                         nullptr);
  return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue,
                                   nullptr);
}

Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
                               ol_device_info_t PropName,
                               size_t *PropSizeRet) {
  if (Device == OffloadContext::get().HostDevice())
    return olGetDeviceInfoImplDetailHost(Device, PropName, 0, nullptr,
                                         PropSizeRet);
  return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet);
}

Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
  for (auto &Platform : OffloadContext::get().Platforms) {
    for (auto &Device : Platform->Devices) {
      if (!Callback(Device.get(), UserData)) {
        break;
      }
    }
  }

  return Error::success();
}

TargetAllocTy convertOlToPluginAllocTy(ol_alloc_type_t Type) {
  switch (Type) {
  case OL_ALLOC_TYPE_DEVICE:
    return TARGET_ALLOC_DEVICE;
  case OL_ALLOC_TYPE_HOST:
    return TARGET_ALLOC_HOST;
  case OL_ALLOC_TYPE_MANAGED:
  default:
    return TARGET_ALLOC_SHARED;
  }
}

constexpr size_t MAX_ALLOC_TRIES = 50;

Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
                      size_t Size, void **AllocationOut) {
  SmallVector<void *> Rejects;

  // Repeat the allocation up to a certain amount of times. If it happens to
  // already be allocated (e.g. by a device from another vendor) throw it away
  // and try again.
  for (size_t Count = 0; Count < MAX_ALLOC_TRIES; Count++) {
    auto NewAlloc = Device->Device->dataAlloc(Size, nullptr,
                                              convertOlToPluginAllocTy(Type));
    if (!NewAlloc)
      return NewAlloc.takeError();
    void *NewEnd = &static_cast<char *>(*NewAlloc)[Size];

    auto &AllocBases = OffloadContext::get().AllocBases;
    auto &AllocInfoMap = OffloadContext::get().AllocInfoMap;
    {
      std::lock_guard<std::mutex> Lock(
          OffloadContext::get().AllocInfoMapMutex);

      // Check that this memory region doesn't overlap another one. That is,
      // the start of this allocation needs to be after another allocation's
      // end point, and the end of this allocation needs to be before the
      // next one's start. `Gap` is the first alloc whose end is after the
      // new alloc's start point.
      auto Gap = std::lower_bound(AllocBases.begin(), AllocBases.end(),
                                  *NewAlloc,
                                  [&](const void *Iter, const void *Val) {
                                    return AllocInfoMap.at(Iter).End <= Val;
                                  });
      if (Gap == AllocBases.end() || NewEnd <= AllocInfoMap.at(*Gap).Start) {
        // Success, no conflict
        AllocInfoMap.insert_or_assign(
            *NewAlloc, AllocInfo{Device, Type, *NewAlloc, NewEnd});
        AllocBases.insert(std::lower_bound(AllocBases.begin(),
                                           AllocBases.end(), *NewAlloc),
                          *NewAlloc);
        *AllocationOut = *NewAlloc;

        for (void *R : Rejects)
          if (auto Err = Device->Device->dataDelete(
                  R, convertOlToPluginAllocTy(Type)))
            return Err;
        return Error::success();
      }

      // Hold onto the rejected region until the allocation completes, so the
      // next attempt cannot be handed back the same memory.
      Rejects.push_back(*NewAlloc);
    }
  }

  // We've tried multiple times, and can't allocate a non-overlapping region.
  return createOffloadError(ErrorCode::BACKEND_FAILURE,
                            "failed to allocate non-overlapping memory");
}
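
// Worked example of the overlap probe above (addresses illustrative): with
// AllocBases = {0x1000, 0x2000} tracking [0x1000, 0x1400) and
// [0x2000, 0x2100), a request that lands at [0x1400, 0x1800) finds
// Gap = 0x2000 (the first allocation whose end lies past 0x1400) and
// succeeds since 0x1800 <= 0x2000. A request at [0x13f0, 0x1500) finds
// Gap = 0x1000 and is rejected, since 0x1500 <= 0x1000 fails (it overlaps
// [0x1000, 0x1400)).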
Error olMemFree_impl(void *Address) {
  ol_device_handle_t Device;
  ol_alloc_type_t Type;
  {
    std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);

    if (!OffloadContext::get().AllocInfoMap.contains(Address))
      return createOffloadError(ErrorCode::INVALID_ARGUMENT,
                                "address is not a known allocation");

    auto AllocInfo = OffloadContext::get().AllocInfoMap.at(Address);
    Device = AllocInfo.Device;
    Type = AllocInfo.Type;
    OffloadContext::get().AllocInfoMap.erase(Address);

    auto &Bases = OffloadContext::get().AllocBases;
    Bases.erase(std::lower_bound(Bases.begin(), Bases.end(), Address));
  }

  if (auto Res =
          Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type)))
    return Res;

  return Error::success();
}

Error olGetMemInfoImplDetail(const void *Ptr, ol_mem_info_t PropName,
                             size_t PropSize, void *PropValue,
                             size_t *PropSizeRet) {
  InfoWriter Info(PropSize, PropValue, PropSizeRet);

  std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);
  auto &AllocBases = OffloadContext::get().AllocBases;
  auto &AllocInfoMap = OffloadContext::get().AllocInfoMap;

  const AllocInfo *Alloc = nullptr;
  if (AllocInfoMap.contains(Ptr)) {
    // Fast case, we have been given the base pointer directly
    Alloc = &AllocInfoMap.at(Ptr);
  } else {
    // Slower case, we need to look up the base pointer first. Find the first
    // memory allocation whose end is after the target pointer, and then
    // check to see if it is in range.
    auto Loc = std::lower_bound(AllocBases.begin(), AllocBases.end(), Ptr,
                                [&](const void *Iter, const void *Val) {
                                  return AllocInfoMap.at(Iter).End <= Val;
                                });
    if (Loc == AllocBases.end() || Ptr < AllocInfoMap.at(*Loc).Start)
      return Plugin::error(ErrorCode::NOT_FOUND,
                           "allocated memory information not found");
    Alloc = &AllocInfoMap.at(*Loc);
  }

  switch (PropName) {
  case OL_MEM_INFO_DEVICE:
    return Info.write(Alloc->Device);
  case OL_MEM_INFO_BASE:
    return Info.write(Alloc->Start);
  case OL_MEM_INFO_SIZE:
    return Info.write<size_t>(static_cast<char *>(Alloc->End) -
                              static_cast<char *>(Alloc->Start));
  case OL_MEM_INFO_TYPE:
    return Info.write(Alloc->Type);
  default:
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "olGetMemInfo enum '%i' is invalid", PropName);
  }

  return Error::success();
}

Error olGetMemInfo_impl(const void *Ptr, ol_mem_info_t PropName,
                        size_t PropSize, void *PropValue) {
  return olGetMemInfoImplDetail(Ptr, PropName, PropSize, PropValue, nullptr);
}

Error olGetMemInfoSize_impl(const void *Ptr, ol_mem_info_t PropName,
                            size_t *PropSizeRet) {
  return olGetMemInfoImplDetail(Ptr, PropName, 0, nullptr, PropSizeRet);
}
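
// Queue lifecycle: olCreateQueue first tries to recycle a completed queue
// from the device's outstanding-queue pool (populated by olDestroyQueue when
// a queue is destroyed while it still has work in flight), and only asks the
// plugin for a fresh __tgt_async_info when none is available.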
Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) {
  auto CreatedQueue = std::make_unique<ol_queue_impl_t>(nullptr, Device);

  auto OutstandingQueue = Device->getOutstandingQueue();
  if (OutstandingQueue) {
    // The queue is empty, but we still need to sync it to release any
    // temporary memory allocations or do other cleanup.
    if (auto Err =
            Device->Device->synchronize(OutstandingQueue, /*Release=*/false))
      return Err;
    CreatedQueue->AsyncInfo = OutstandingQueue;
  } else if (auto Err =
                 Device->Device->initAsyncInfo(&(CreatedQueue->AsyncInfo))) {
    return Err;
  }

  *Queue = CreatedQueue.release();
  return Error::success();
}

Error olDestroyQueue_impl(ol_queue_handle_t Queue) {
  auto *Device = Queue->Device;

  // This is safe; as soon as olDestroyQueue is called it is not possible to
  // add any more work to the queue, so if it's finished now it will remain
  // finished forever.
  auto Res = Device->Device->hasPendingWork(Queue->AsyncInfo);
  if (!Res)
    return Res.takeError();

  if (!*Res) {
    // The queue is complete, so sync it and throw it back into the pool.
    if (auto Err =
            Device->Device->synchronize(Queue->AsyncInfo, /*Release=*/true))
      return Err;
  } else {
    // The queue still has outstanding work. Store it so we can check it
    // later.
    std::lock_guard<std::mutex> Lock(Device->OutstandingQueuesMutex);
    Device->OutstandingQueues.push_back(Queue->AsyncInfo);
  }

  return olDestroy(Queue);
}

Error olSyncQueue_impl(ol_queue_handle_t Queue) {
  // Host plugin doesn't have a queue set so it's not safe to call
  // synchronize on it, but we have nothing to synchronize in that situation
  // anyway.
  if (Queue->AsyncInfo->Queue) {
    // We don't need to release the queue and we would like the ability for
    // other offload threads to submit work concurrently, so pass "false"
    // here so we don't release the underlying queue object.
    if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo,
                                                      /*Release=*/false))
      return Err;
  }

  return Error::success();
}

Error olWaitEvents_impl(ol_queue_handle_t Queue, ol_event_handle_t *Events,
                        size_t NumEvents) {
  auto *Device = Queue->Device->Device;

  for (size_t I = 0; I < NumEvents; I++) {
    auto *Event = Events[I];

    if (!Event)
      return Plugin::error(ErrorCode::INVALID_NULL_HANDLE,
                           "olWaitEvents asked to wait on a NULL event");

    // Do nothing if the event is for this queue or the event is always
    // complete
    if (Event->QueueId == Queue->Id || !Event->EventInfo)
      continue;

    if (auto Err = Device->waitEvent(Event->EventInfo, Queue->AsyncInfo))
      return Err;
  }

  return Error::success();
}

Error olGetQueueInfoImplDetail(ol_queue_handle_t Queue,
                               ol_queue_info_t PropName, size_t PropSize,
                               void *PropValue, size_t *PropSizeRet) {
  InfoWriter Info(PropSize, PropValue, PropSizeRet);

  switch (PropName) {
  case OL_QUEUE_INFO_DEVICE:
    return Info.write(Queue->Device);
  case OL_QUEUE_INFO_EMPTY: {
    auto Pending = Queue->Device->Device->hasPendingWork(Queue->AsyncInfo);
    if (auto Err = Pending.takeError())
      return Err;
    return Info.write(!*Pending);
  }
  default:
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "olGetQueueInfo enum '%i' is invalid", PropName);
  }

  return Error::success();
}

Error olGetQueueInfo_impl(ol_queue_handle_t Queue, ol_queue_info_t PropName,
                          size_t PropSize, void *PropValue) {
  return olGetQueueInfoImplDetail(Queue, PropName, PropSize, PropValue,
                                  nullptr);
}

Error olGetQueueInfoSize_impl(ol_queue_handle_t Queue,
                              ol_queue_info_t PropName, size_t *PropSizeRet) {
  return olGetQueueInfoImplDetail(Queue, PropName, 0, nullptr, PropSizeRet);
}
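
// Events: an ol_event_impl_t whose EventInfo is null is considered complete
// from creation (olCreateEvent_impl below skips recording when the queue is
// already empty), so the event functions that follow all short-circuit on a
// null EventInfo.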
Error olSyncEvent_impl(ol_event_handle_t Event) {
  // No event info means that this event was complete on creation
  if (!Event->EventInfo)
    return Plugin::success();

  if (auto Res = Event->Device->Device->syncEvent(Event->EventInfo))
    return Res;

  return Error::success();
}

Error olDestroyEvent_impl(ol_event_handle_t Event) {
  if (Event->EventInfo)
    if (auto Res = Event->Device->Device->destroyEvent(Event->EventInfo))
      return Res;

  return olDestroy(Event);
}

Error olGetEventInfoImplDetail(ol_event_handle_t Event,
                               ol_event_info_t PropName, size_t PropSize,
                               void *PropValue, size_t *PropSizeRet) {
  InfoWriter Info(PropSize, PropValue, PropSizeRet);
  auto Queue = Event->Queue;

  switch (PropName) {
  case OL_EVENT_INFO_QUEUE:
    return Info.write(Queue);
  case OL_EVENT_INFO_IS_COMPLETE: {
    // No event info means that this event was complete on creation
    if (!Event->EventInfo)
      return Info.write(true);

    auto Res = Queue->Device->Device->isEventComplete(Event->EventInfo,
                                                      Queue->AsyncInfo);
    if (auto Err = Res.takeError())
      return Err;
    return Info.write(*Res);
  }
  default:
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "olGetEventInfo enum '%i' is invalid", PropName);
  }

  return Error::success();
}

Error olGetEventInfo_impl(ol_event_handle_t Event, ol_event_info_t PropName,
                          size_t PropSize, void *PropValue) {
  return olGetEventInfoImplDetail(Event, PropName, PropSize, PropValue,
                                  nullptr);
}

Error olGetEventInfoSize_impl(ol_event_handle_t Event,
                              ol_event_info_t PropName, size_t *PropSizeRet) {
  return olGetEventInfoImplDetail(Event, PropName, 0, nullptr, PropSizeRet);
}

Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) {
  auto Pending = Queue->Device->Device->hasPendingWork(Queue->AsyncInfo);
  if (auto Err = Pending.takeError())
    return Err;

  *EventOut = new ol_event_impl_t(nullptr, Queue->Device, Queue);
  if (!*Pending)
    // Queue is empty, don't record an event and consider the event always
    // complete
    return Plugin::success();

  if (auto Res = Queue->Device->Device->createEvent(&(*EventOut)->EventInfo))
    return Res;

  if (auto Res = Queue->Device->Device->recordEvent((*EventOut)->EventInfo,
                                                    Queue->AsyncInfo))
    return Res;

  return Plugin::success();
}
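
// olMemcpy dispatches on which side is the virtual host device:
//   device -> host:   SrcDevice->dataRetrieve
//   host -> device:   DstDevice->dataSubmit
//   device -> device: SrcDevice->dataExchange
// A host -> host copy is a plain std::memcpy and is only valid without a
// queue.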
Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
                    ol_device_handle_t DstDevice, const void *SrcPtr,
                    ol_device_handle_t SrcDevice, size_t Size) {
  auto Host = OffloadContext::get().HostDevice();
  if (DstDevice == Host && SrcDevice == Host) {
    if (!Queue) {
      std::memcpy(DstPtr, SrcPtr, Size);
      return Error::success();
    } else {
      return createOffloadError(
          ErrorCode::INVALID_ARGUMENT,
          "one of DstDevice and SrcDevice must be a non-host device if a "
          "queue is specified");
    }
  }

  // If no queue is given the memcpy will be synchronous
  auto QueueImpl = Queue ? Queue->AsyncInfo : nullptr;

  if (DstDevice == Host) {
    if (auto Res =
            SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, QueueImpl))
      return Res;
  } else if (SrcDevice == Host) {
    if (auto Res =
            DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, QueueImpl))
      return Res;
  } else {
    if (auto Res = SrcDevice->Device->dataExchange(SrcPtr, *DstDevice->Device,
                                                   DstPtr, Size, QueueImpl))
      return Res;
  }

  return Error::success();
}

Error olMemFill_impl(ol_queue_handle_t Queue, void *Ptr, size_t PatternSize,
                     const void *PatternPtr, size_t FillSize) {
  return Queue->Device->Device->dataFill(Ptr, PatternPtr, PatternSize,
                                         FillSize, Queue->AsyncInfo);
}

Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
                           size_t ProgDataSize, ol_program_handle_t *Program) {
  StringRef Buffer(reinterpret_cast<const char *>(ProgData), ProgDataSize);
  Expected<plugin::DeviceImageTy *> Res =
      Device->Device->loadBinary(Device->Device->Plugin, Buffer);
  if (!Res)
    return Res.takeError();
  assert(*Res && "loadBinary returned nullptr");

  *Program = new ol_program_impl_t(*Res, (*Res)->getMemoryBuffer());
  return Error::success();
}

Error olIsValidBinary_impl(ol_device_handle_t Device, const void *ProgData,
                           size_t ProgDataSize, bool *IsValid) {
  StringRef Buffer(reinterpret_cast<const char *>(ProgData), ProgDataSize);
  *IsValid = Device->Device->Plugin.isDeviceCompatible(
      Device->Device->getDeviceId(), Buffer);
  return Error::success();
}

Error olDestroyProgram_impl(ol_program_handle_t Program) {
  auto &Device = Program->Image->getDevice();
  if (auto Err = Device.unloadBinary(Program->Image))
    return Err;

  auto &LoadedImages = Device.LoadedImages;
  LoadedImages.erase(
      std::find(LoadedImages.begin(), LoadedImages.end(), Program->Image));

  return olDestroy(Program);
}

Error olCalculateOptimalOccupancy_impl(ol_device_handle_t Device,
                                       ol_symbol_handle_t Kernel,
                                       size_t DynamicMemSize,
                                       size_t *GroupSize) {
  if (Kernel->Kind != OL_SYMBOL_KIND_KERNEL)
    return createOffloadError(ErrorCode::SYMBOL_KIND,
                              "provided symbol is not a kernel");

  auto *KernelImpl = std::get<GenericKernelTy *>(Kernel->PluginImpl);
  auto Res = KernelImpl->maxGroupSize(*Device->Device, DynamicMemSize);
  if (auto Err = Res.takeError())
    return Err;

  *GroupSize = *Res;
  return Error::success();
}
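
// Kernel arguments are passed as a single opaque blob rather than an array
// of per-argument pointers: ArgPtrs is pointed at a KernelLaunchParamsTy
// describing the blob, and Flags.IsCUDA is set so the plugin forwards the
// data as-is.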
Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
                          ol_symbol_handle_t Kernel, const void *ArgumentsData,
                          size_t ArgumentsSize,
                          const ol_kernel_launch_size_args_t *LaunchSizeArgs) {
  auto *DeviceImpl = Device->Device;
  if (Queue && Device != Queue->Device) {
    return createOffloadError(
        ErrorCode::INVALID_DEVICE,
        "device specified does not match the device of the given queue");
  }

  if (Kernel->Kind != OL_SYMBOL_KIND_KERNEL)
    return createOffloadError(ErrorCode::SYMBOL_KIND,
                              "provided symbol is not a kernel");

  auto *QueueImpl = Queue ? Queue->AsyncInfo : nullptr;
  AsyncInfoWrapperTy AsyncInfoWrapper(*DeviceImpl, QueueImpl);
  KernelArgsTy LaunchArgs{};
  LaunchArgs.NumTeams[0] = LaunchSizeArgs->NumGroups.x;
  LaunchArgs.NumTeams[1] = LaunchSizeArgs->NumGroups.y;
  LaunchArgs.NumTeams[2] = LaunchSizeArgs->NumGroups.z;
  LaunchArgs.ThreadLimit[0] = LaunchSizeArgs->GroupSize.x;
  LaunchArgs.ThreadLimit[1] = LaunchSizeArgs->GroupSize.y;
  LaunchArgs.ThreadLimit[2] = LaunchSizeArgs->GroupSize.z;
  LaunchArgs.DynCGroupMem = LaunchSizeArgs->DynSharedMemory;

  KernelLaunchParamsTy Params;
  Params.Data = const_cast<void *>(ArgumentsData);
  Params.Size = ArgumentsSize;
  LaunchArgs.ArgPtrs = reinterpret_cast<void **>(&Params);
  // Don't do anything with pointer indirection; use arg data as-is
  LaunchArgs.Flags.IsCUDA = true;

  auto *KernelImpl = std::get<GenericKernelTy *>(Kernel->PluginImpl);
  auto Err = KernelImpl->launch(*DeviceImpl, LaunchArgs.ArgPtrs, nullptr,
                                LaunchArgs, AsyncInfoWrapper);

  AsyncInfoWrapper.finalize(Err);
  if (Err)
    return Err;

  return Error::success();
}

Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
                       ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) {
  auto &Device = Program->Image->getDevice();

  std::lock_guard<std::mutex> Lock(Program->SymbolListMutex);

  switch (Kind) {
  case OL_SYMBOL_KIND_KERNEL: {
    auto &Kernel = Program->KernelSymbols[Name];
    if (!Kernel) {
      auto KernelImpl = Device.constructKernel(Name);
      if (!KernelImpl)
        return KernelImpl.takeError();

      if (auto Err = KernelImpl->init(Device, *Program->Image))
        return Err;

      Kernel = std::make_unique<ol_symbol_impl_t>(KernelImpl->getName(),
                                                  &*KernelImpl);
    }

    *Symbol = Kernel.get();
    return Error::success();
  }
  case OL_SYMBOL_KIND_GLOBAL_VARIABLE: {
    auto &Global = Program->GlobalSymbols[Name];
    if (!Global) {
      GlobalTy GlobalObj{Name};
      if (auto Res =
              Device.Plugin.getGlobalHandler().getGlobalMetadataFromDevice(
                  Device, *Program->Image, GlobalObj))
        return Res;

      Global = std::make_unique<ol_symbol_impl_t>(GlobalObj.getName().c_str(),
                                                  std::move(GlobalObj));
    }

    *Symbol = Global.get();
    return Error::success();
  }
  default:
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "getSymbol kind enum '%i' is invalid", Kind);
  }
}

Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol,
                                ol_symbol_info_t PropName, size_t PropSize,
                                void *PropValue, size_t *PropSizeRet) {
  InfoWriter Info(PropSize, PropValue, PropSizeRet);

  auto CheckKind = [&](ol_symbol_kind_t Required) {
    if (Symbol->Kind != Required) {
      std::string ErrBuffer;
      llvm::raw_string_ostream(ErrBuffer)
          << PropName << ": Expected a symbol of Kind " << Required
          << " but given a symbol of Kind " << Symbol->Kind;
      return Plugin::error(ErrorCode::SYMBOL_KIND, ErrBuffer.c_str());
    }
    return Plugin::success();
  };

  switch (PropName) {
  case OL_SYMBOL_INFO_KIND:
    return Info.write(Symbol->Kind);
  case OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS:
    if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
      return Err;
    return Info.write(std::get<GlobalTy>(Symbol->PluginImpl).getPtr());
  case OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE:
    if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
      return Err;
    return Info.write<size_t>(
        std::get<GlobalTy>(Symbol->PluginImpl).getSize());
  default:
    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                              "olGetSymbolInfo enum '%i' is invalid",
                              PropName);
  }

  return Error::success();
}

Error olGetSymbolInfo_impl(ol_symbol_handle_t Symbol,
                           ol_symbol_info_t PropName, size_t PropSize,
                           void *PropValue) {
  return olGetSymbolInfoImplDetail(Symbol, PropName, PropSize, PropValue,
                                   nullptr);
}

Error olGetSymbolInfoSize_impl(ol_symbol_handle_t Symbol,
                               ol_symbol_info_t PropName,
                               size_t *PropSizeRet) {
  return olGetSymbolInfoImplDetail(Symbol, PropName, 0, nullptr, PropSizeRet);
}
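
// Symbol handles returned by olGetSymbol above are cached in and owned by
// the program's KernelSymbols/GlobalSymbols maps, so they stay valid until
// the program itself is destroyed; callers do not free them individually.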
Error olLaunchHostFunction_impl(ol_queue_handle_t Queue,
                                ol_host_function_cb_t Callback,
                                void *UserData) {
  return Queue->Device->Device->enqueueHostCall(Callback, UserData,
                                                Queue->AsyncInfo);
}

} // namespace offload
} // namespace llvm