//===- LevelZeroRuntimeWrappers.cpp - MLIR Level Zero (L0) wrapper library-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements wrappers around the Level Zero (L0) runtime library with C
// linkage.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Twine.h"

#include "level_zero/ze_api.h"

#include <cassert>
#include <cstdlib>
#include <deque>
#include <exception>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

namespace {

template <typename F>
auto catchAll(F &&func) {
  try {
    return func();
  } catch (const std::exception &e) {
    std::cerr << "An exception was thrown: " << e.what() << std::endl;
    std::abort();
  } catch (...) {
    std::cerr << "An unknown exception was thrown." << std::endl;
    std::abort();
  }
}

#define L0_SAFE_CALL(call)                                                     \
  {                                                                            \
    ze_result_t status = (call);                                               \
    if (status != ZE_RESULT_SUCCESS) {                                         \
      const char *errorString;                                                 \
      zeDriverGetLastErrorDescription(NULL, &errorString);                     \
      std::cerr << "L0 error " << status << ": " << errorString << std::endl;  \
      std::abort();                                                            \
    }                                                                          \
  }

} // namespace

//===----------------------------------------------------------------------===//
// L0 RT context & device setters
//===----------------------------------------------------------------------===//

// Returns the L0 driver handle for the given index. Default index is 0
// (i.e., returns the first driver handle of the available drivers).
static ze_driver_handle_t getDriver(uint32_t idx = 0) {
  ze_init_driver_type_desc_t driver_type = {};
  driver_type.stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC;
  driver_type.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU;
  driver_type.pNext = nullptr;
  uint32_t driverCount{0};
  thread_local static std::vector<ze_driver_handle_t> drivers;
  thread_local static bool isDriverInitialised{false};
  if (isDriverInitialised && idx < drivers.size())
    return drivers[idx];
  L0_SAFE_CALL(zeInitDrivers(&driverCount, nullptr, &driver_type));
  if (!driverCount)
    throw std::runtime_error("No L0 drivers found.");
  drivers.resize(driverCount);
  L0_SAFE_CALL(zeInitDrivers(&driverCount, drivers.data(), &driver_type));
  if (idx >= driverCount)
    throw std::runtime_error((llvm::Twine("Requested driver idx out-of-bound, "
                                          "number of available drivers: ") +
                              std::to_string(driverCount))
                                 .str());
  isDriverInitialised = true;
  return drivers[idx];
}

static ze_device_handle_t getDevice(const uint32_t driverIdx = 0,
                                    const int32_t devIdx = 0) {
  thread_local static ze_device_handle_t l0Device;
  thread_local int32_t currDevIdx{-1};
  thread_local uint32_t currDriverIdx{0};
  if (currDriverIdx == driverIdx && currDevIdx == devIdx)
    return l0Device;
  auto driver = getDriver(driverIdx);
  uint32_t deviceCount{0};
  L0_SAFE_CALL(zeDeviceGet(driver, &deviceCount, nullptr));
  if (!deviceCount)
    throw std::runtime_error("getDevice failed: did not find L0 device.");
  if (static_cast<int32_t>(deviceCount) < devIdx + 1)
    throw std::runtime_error("getDevice failed: devIdx out-of-bounds.");
  std::vector<ze_device_handle_t> devices(deviceCount);
  L0_SAFE_CALL(zeDeviceGet(driver, &deviceCount, devices.data()));
  l0Device = devices[devIdx];
  currDriverIdx = driverIdx;
  currDevIdx = devIdx;
  return l0Device;
}
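// Illustrative usage sketch (not part of the wrapper API): with the helpers
// above, looking up the second device of the first driver, assuming such a
// device exists, would be:
//
//   ze_driver_handle_t drv = getDriver(/*idx=*/0);
//   ze_device_handle_t dev = getDevice(/*driverIdx=*/0, /*devIdx=*/1);
//
// Both helpers cache their last result in thread_local storage, so repeated
// lookups with the same indices do not re-enumerate drivers or devices.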
// Returns the default L0 context of the default driver.
static ze_context_handle_t getContext(ze_driver_handle_t driver) {
  thread_local static ze_context_handle_t context;
  thread_local static bool isContextInitialised{false};
  if (isContextInitialised)
    return context;
  ze_context_desc_t ctxtDesc = {ZE_STRUCTURE_TYPE_CONTEXT_DESC, nullptr, 0};
  L0_SAFE_CALL(zeContextCreate(driver, &ctxtDesc, &context));
  isContextInitialised = true;
  return context;
}

//===----------------------------------------------------------------------===//
// L0 RT helper structs
//===----------------------------------------------------------------------===//

struct ZeContextDeleter {
  void operator()(ze_context_handle_t ctx) const {
    if (ctx)
      L0_SAFE_CALL(zeContextDestroy(ctx));
  }
};

struct ZeCommandListDeleter {
  void operator()(ze_command_list_handle_t cmdList) const {
    if (cmdList)
      L0_SAFE_CALL(zeCommandListDestroy(cmdList));
  }
};

using UniqueZeContext =
    std::unique_ptr<std::remove_pointer<ze_context_handle_t>::type,
                    ZeContextDeleter>;
using UniqueZeCommandList =
    std::unique_ptr<std::remove_pointer<ze_command_list_handle_t>::type,
                    ZeCommandListDeleter>;

struct L0RTContextWrapper {
  ze_driver_handle_t driver{nullptr};
  ze_device_handle_t device{nullptr};
  UniqueZeContext context;
  // Usually, one immediate command list with ordinal 0 suffices for
  // both copy and compute ops, but it leaves the HW underutilized.
  UniqueZeCommandList immCmdListCompute;
  // Copy engines can be used for both memcpy and memset, but
  // they have limitations on the memset pattern size (e.g., 1 byte).
  UniqueZeCommandList immCmdListCopy;
  uint32_t copyEngineMaxMemoryFillPatternSize{-1u};

  L0RTContextWrapper() = default;
  L0RTContextWrapper(const uint32_t driverIdx = 0, const int32_t devIdx = 0)
      : driver(getDriver(driverIdx)), device(getDevice(driverIdx, devIdx)) {
    // Create the context.
    ze_context_handle_t ctx = getContext(driver);
    context.reset(ctx);
    // Determine the engine ordinals.
    uint32_t computeEngineOrdinal = -1u, copyEngineOrdinal = -1u;
    ze_device_properties_t deviceProperties{};
    L0_SAFE_CALL(zeDeviceGetProperties(device, &deviceProperties));
    uint32_t queueGroupCount = 0;
    L0_SAFE_CALL(zeDeviceGetCommandQueueGroupProperties(
        device, &queueGroupCount, nullptr));
    std::vector<ze_command_queue_group_properties_t> queueGroupProperties(
        queueGroupCount);
    L0_SAFE_CALL(zeDeviceGetCommandQueueGroupProperties(
        device, &queueGroupCount, queueGroupProperties.data()));
    for (uint32_t queueGroupIdx = 0; queueGroupIdx < queueGroupCount;
         ++queueGroupIdx) {
      const auto &group = queueGroupProperties[queueGroupIdx];
      if (group.flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE)
        computeEngineOrdinal = queueGroupIdx;
      else if (group.flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COPY) {
        copyEngineOrdinal = queueGroupIdx;
        copyEngineMaxMemoryFillPatternSize = group.maxMemoryFillPatternSize;
      }
      if (copyEngineOrdinal != -1u && computeEngineOrdinal != -1u)
        break;
    }
    // Fall back to the default queue if no dedicated copy queue is available.
    if (copyEngineOrdinal == -1u)
      copyEngineOrdinal = computeEngineOrdinal;
    assert(copyEngineOrdinal != -1u && computeEngineOrdinal != -1u &&
           "Expected two engines to be available.");
    // Create the copy command list.
    ze_command_queue_desc_t cmdQueueDesc{
        ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
        nullptr,
        copyEngineOrdinal, // ordinal
        0,                 // index (assume one physical engine in the group)
        0,                 // flags
        ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS,
        ZE_COMMAND_QUEUE_PRIORITY_NORMAL};
    ze_command_list_handle_t rawCmdListCopy = nullptr;
    L0_SAFE_CALL(zeCommandListCreateImmediate(context.get(), device,
                                              &cmdQueueDesc, &rawCmdListCopy));
    immCmdListCopy.reset(rawCmdListCopy);
    // Create the compute command list.
    cmdQueueDesc.ordinal = computeEngineOrdinal;
    ze_command_list_handle_t rawCmdListCompute = nullptr;
    L0_SAFE_CALL(zeCommandListCreateImmediate(
        context.get(), device, &cmdQueueDesc, &rawCmdListCompute));
    immCmdListCompute.reset(rawCmdListCompute);
  }
  L0RTContextWrapper(const L0RTContextWrapper &) = delete;
  L0RTContextWrapper &operator=(const L0RTContextWrapper &) = delete;
  // Allow move.
  L0RTContextWrapper(L0RTContextWrapper &&) noexcept = default;
  L0RTContextWrapper &operator=(L0RTContextWrapper &&) noexcept = default;
  ~L0RTContextWrapper() = default;
};
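// Illustrative sketch (assumes a device exposing at least one compute queue
// group): a wrapper owns one context and two immediate command lists, and the
// intended split between them is
//
//   L0RTContextWrapper ctx(/*driverIdx=*/0, /*devIdx=*/0);
//   // ctx.immCmdListCopy:    memcpy and small-pattern memset traffic.
//   // ctx.immCmdListCompute: kernel launches and large-pattern fills.
//
// Members are destroyed in reverse declaration order, so both command lists
// are destroyed before the context itself.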
struct ZeEventDeleter {
  void operator()(ze_event_handle_t event) const {
    if (event)
      L0_SAFE_CALL(zeEventDestroy(event));
  }
};

struct ZeEventPoolDeleter {
  void operator()(ze_event_pool_handle_t pool) const {
    if (pool)
      L0_SAFE_CALL(zeEventPoolDestroy(pool));
  }
};

using UniqueZeEvent =
    std::unique_ptr<std::remove_pointer<ze_event_handle_t>::type,
                    ZeEventDeleter>;
using UniqueZeEventPool =
    std::unique_ptr<std::remove_pointer<ze_event_pool_handle_t>::type,
                    ZeEventPoolDeleter>;

// L0 only supports pre-determined sizes of event pools;
// implement a runtime data structure to avoid running out of events.
struct DynamicEventPool {
  constexpr static size_t numEventsPerPool{128};
  std::vector<UniqueZeEventPool> eventPools;
  std::vector<UniqueZeEvent> availableEvents;
  std::unordered_map<ze_event_handle_t, UniqueZeEvent> takenEvents;
  // Limit the number of events to avoid running out of memory.
  // The limit is set to 32K events, which should be sufficient for most use
  // cases.
  size_t maxEventsCount{32768}; // 32K events
  size_t currentEventsLimit{0};
  size_t currentEventsCnt{0};
  L0RTContextWrapper *rtCtx;

  DynamicEventPool(L0RTContextWrapper *rtCtx) : rtCtx(rtCtx) {
    createNewPool(numEventsPerPool);
  }
  DynamicEventPool(const DynamicEventPool &) = delete;
  DynamicEventPool &operator=(const DynamicEventPool &) = delete;
  // Allow move.
  DynamicEventPool(DynamicEventPool &&) noexcept = default;
  DynamicEventPool &operator=(DynamicEventPool &&) noexcept = default;

  ~DynamicEventPool() {
    assert(takenEvents.empty() && "Some events were not released");
  }

  void createNewPool(size_t numEvents) {
    ze_event_pool_desc_t eventPoolDesc = {};
    eventPoolDesc.flags = ZE_EVENT_POOL_FLAG_HOST_VISIBLE;
    eventPoolDesc.count = numEvents;
    ze_event_pool_handle_t rawPool = nullptr;
    L0_SAFE_CALL(zeEventPoolCreate(rtCtx->context.get(), &eventPoolDesc, 1,
                                   &rtCtx->device, &rawPool));
    eventPools.emplace_back(UniqueZeEventPool(rawPool));
    currentEventsLimit += numEvents;
  }

  ze_event_handle_t takeEvent() {
    ze_event_handle_t rawEvent = nullptr;
    if (!availableEvents.empty()) {
      // Reuse an available event.
      auto uniqueEvent = std::move(availableEvents.back());
      availableEvents.pop_back();
      rawEvent = uniqueEvent.get();
      takenEvents[rawEvent] = std::move(uniqueEvent);
    } else {
      if (currentEventsCnt >= maxEventsCount)
        throw std::runtime_error("DynamicEventPool: reached max events limit");
      if (currentEventsCnt == currentEventsLimit)
        createNewPool(numEventsPerPool);
      ze_event_desc_t eventDesc = {
          ZE_STRUCTURE_TYPE_EVENT_DESC, nullptr,
          static_cast<uint32_t>(currentEventsCnt % numEventsPerPool),
          ZE_EVENT_SCOPE_FLAG_DEVICE, ZE_EVENT_SCOPE_FLAG_HOST};
      ze_event_handle_t newEvent = nullptr;
      L0_SAFE_CALL(
          zeEventCreate(eventPools.back().get(), &eventDesc, &newEvent));
      takenEvents[newEvent] = UniqueZeEvent(newEvent);
      rawEvent = newEvent;
      currentEventsCnt++;
    }
    return rawEvent;
  }

  void releaseEvent(ze_event_handle_t event) {
    auto it = takenEvents.find(event);
    assert(it != takenEvents.end() &&
           "Attempting to release an unknown or already released event");
    L0_SAFE_CALL(zeEventHostReset(event));
    availableEvents.emplace_back(std::move(it->second));
    takenEvents.erase(it);
  }
};
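// Illustrative sketch: the pool hands out raw ze_event_handle_t values while
// keeping ownership, so every takeEvent() must be paired with a
// releaseEvent() (here `pool` is any live DynamicEventPool):
//
//   ze_event_handle_t ev = pool.takeEvent();
//   // ... use `ev` as a signal and/or wait event ...
//   pool.releaseEvent(ev); // host-resets the event and recycles it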
L0RTContextWrapper &getRtContext() {
  thread_local static L0RTContextWrapper rtContext(0);
  return rtContext;
}

DynamicEventPool &getDynamicEventPool() {
  thread_local static DynamicEventPool dynEventPool{&getRtContext()};
  return dynEventPool;
}

struct StreamWrapper {
  // Use a deque to avoid event pointer invalidation.
  std::deque<ze_event_handle_t> implicitEventStack;
  DynamicEventPool &dynEventPool;

  StreamWrapper(DynamicEventPool &dynEventPool) : dynEventPool(dynEventPool) {}
  ~StreamWrapper() { sync(); }

  ze_event_handle_t *getLastImplicitEventPtr() {
    // Assume the current implicit events will not be used after `sync`.
    return implicitEventStack.size() ? &implicitEventStack.back() : nullptr;
  }

  void sync(ze_event_handle_t explicitEvent = nullptr) {
    ze_event_handle_t syncEvent{nullptr};
    if (!explicitEvent) {
      ze_event_handle_t *lastImplicitEventPtr = getLastImplicitEventPtr();
      syncEvent = lastImplicitEventPtr ? *lastImplicitEventPtr : nullptr;
    } else {
      syncEvent = explicitEvent;
    }
    if (syncEvent)
      L0_SAFE_CALL(zeEventHostSynchronize(
          syncEvent, std::numeric_limits<uint64_t>::max()));
    // All of the "implicit" events were signaled and are of no further use;
    // release them. An "explicit" event must be released via mgpuEventDestroy.
    for (auto event : implicitEventStack)
      dynEventPool.releaseEvent(event);
    implicitEventStack.clear();
  }

  template <typename Func>
  void enqueueOp(Func &&op) {
    ze_event_handle_t newImplicitEvent = dynEventPool.takeEvent();
    ze_event_handle_t *lastImplicitEventPtr = getLastImplicitEventPtr();
    const uint32_t numWaitEvents = lastImplicitEventPtr ? 1 : 0;
    std::forward<Func>(op)(newImplicitEvent, numWaitEvents,
                           lastImplicitEventPtr);
    implicitEventStack.push_back(newImplicitEvent);
  }
};
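// Illustrative sketch of the implicit-event chaining done by enqueueOp
// (`dst`, `src`, and `size` are placeholders):
//
//   StreamWrapper stream(getDynamicEventPool());
//   stream.enqueueOp([&](ze_event_handle_t signalEvent,
//                        uint32_t numWaitEvents,
//                        ze_event_handle_t *waitEvents) {
//     L0_SAFE_CALL(zeCommandListAppendMemoryCopy(
//         getRtContext().immCmdListCopy.get(), dst, src, size, signalEvent,
//         numWaitEvents, waitEvents));
//   });
//   stream.sync(); // waits on the last implicit event, then recycles them all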
"explicit" event must be "released" via mgpuEventDestroy for (auto event : implicitEventStack) dynEventPool.releaseEvent(event); implicitEventStack.clear(); } template void enqueueOp(Func &&op) { ze_event_handle_t newImplicitEvent = dynEventPool.takeEvent(); ze_event_handle_t *lastImplicitEventPtr = getLastImplicitEventPtr(); const uint32_t numWaitEvents = lastImplicitEventPtr ? 1 : 0; std::forward(op)(newImplicitEvent, numWaitEvents, lastImplicitEventPtr); implicitEventStack.push_back(newImplicitEvent); } }; static ze_module_handle_t loadModule(const void *data, size_t dataSize) { assert(data); ze_module_handle_t zeModule; ze_module_desc_t desc = {ZE_STRUCTURE_TYPE_MODULE_DESC, nullptr, ZE_MODULE_FORMAT_IL_SPIRV, dataSize, (const uint8_t *)data, nullptr, nullptr}; ze_module_build_log_handle_t buildLogHandle; ze_result_t result = zeModuleCreate(getRtContext().context.get(), getRtContext().device, &desc, &zeModule, &buildLogHandle); if (result != ZE_RESULT_SUCCESS) { std::cerr << "Error creating module, error code: " << result << std::endl; size_t logSize = 0; L0_SAFE_CALL(zeModuleBuildLogGetString(buildLogHandle, &logSize, nullptr)); std::string buildLog(" ", logSize); L0_SAFE_CALL( zeModuleBuildLogGetString(buildLogHandle, &logSize, buildLog.data())); std::cerr << "Build log:\n" << buildLog << std::endl; std::abort(); } return zeModule; } //===----------------------------------------------------------------------===// // L0 Wrappers definition //===----------------------------------------------------------------------===// extern "C" StreamWrapper *mgpuStreamCreate() { return new StreamWrapper(getDynamicEventPool()); } extern "C" void mgpuStreamSynchronize(StreamWrapper *stream) { if (stream) stream->sync(); } extern "C" void mgpuStreamDestroy(StreamWrapper *stream) { delete stream; } extern "C" void mgpuStreamWaitEvent(StreamWrapper *stream, ze_event_handle_t event) { assert(stream && "Invalid stream"); assert(event && "Invalid event"); stream->sync(event); } extern "C" ze_event_handle_t mgpuEventCreate() { return getDynamicEventPool().takeEvent(); } extern "C" void mgpuEventDestroy(ze_event_handle_t event) { return getDynamicEventPool().releaseEvent(event); } extern "C" void mgpuEventSynchronize(ze_event_handle_t event) { L0_SAFE_CALL( zeEventHostSynchronize(event, std::numeric_limits::max())); L0_SAFE_CALL(zeEventHostReset(event)); } extern "C" void mgpuEventRecord(ze_event_handle_t event, StreamWrapper *stream) { L0_SAFE_CALL(zeCommandListAppendSignalEvent( getRtContext().immCmdListCopy.get(), event)); L0_SAFE_CALL(zeCommandListAppendSignalEvent( getRtContext().immCmdListCompute.get(), event)); } extern "C" void *mgpuMemAlloc(uint64_t size, StreamWrapper *stream, bool isShared) { return catchAll([&]() { void *memPtr = nullptr; constexpr size_t alignment{64}; ze_device_mem_alloc_desc_t deviceDesc = {}; deviceDesc.stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC; if (isShared) { ze_host_mem_alloc_desc_t hostDesc = {}; hostDesc.stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC; L0_SAFE_CALL(zeMemAllocShared(getRtContext().context.get(), &deviceDesc, &hostDesc, size, alignment, getRtContext().device, &memPtr)); } else { L0_SAFE_CALL(zeMemAllocDevice(getRtContext().context.get(), &deviceDesc, size, alignment, getRtContext().device, &memPtr)); } if (!memPtr) throw std::runtime_error("mem allocation failed!"); return memPtr; }); } extern "C" void mgpuMemFree(void *ptr, StreamWrapper *stream) { stream->sync(); if (ptr) L0_SAFE_CALL(zeMemFree(getRtContext().context.get(), ptr)); } extern "C" void 
extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
                           StreamWrapper *stream) {
  stream->enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
                        ze_event_handle_t *waitEvents) {
    L0_SAFE_CALL(zeCommandListAppendMemoryCopy(
        getRtContext().immCmdListCopy.get(), dst, src, sizeBytes, newEvent,
        numWaitEvents, waitEvents));
  });
}

template <typename PATTERN_TYPE>
void mgpuMemset(void *dst, PATTERN_TYPE value, size_t count,
                StreamWrapper *stream) {
  L0RTContextWrapper &rtContext = getRtContext();
  auto listType =
      rtContext.copyEngineMaxMemoryFillPatternSize >= sizeof(PATTERN_TYPE)
          ? rtContext.immCmdListCopy.get()
          : rtContext.immCmdListCompute.get();
  stream->enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
                        ze_event_handle_t *waitEvents) {
    L0_SAFE_CALL(zeCommandListAppendMemoryFill(
        listType, dst, &value, sizeof(PATTERN_TYPE),
        count * sizeof(PATTERN_TYPE), newEvent, numWaitEvents, waitEvents));
  });
}

extern "C" void mgpuMemset32(void *dst, unsigned int value, size_t count,
                             StreamWrapper *stream) {
  mgpuMemset(dst, value, count, stream);
}

extern "C" void mgpuMemset16(void *dst, unsigned short value, size_t count,
                             StreamWrapper *stream) {
  mgpuMemset(dst, value, count, stream);
}

extern "C" ze_module_handle_t mgpuModuleLoad(const void *data,
                                             size_t gpuBlobSize) {
  return catchAll([&]() { return loadModule(data, gpuBlobSize); });
}

extern "C" ze_kernel_handle_t mgpuModuleGetFunction(ze_module_handle_t module,
                                                    const char *name) {
  assert(module && name);
  ze_kernel_handle_t zeKernel;
  ze_kernel_desc_t desc = {};
  desc.pKernelName = name;
  L0_SAFE_CALL(zeKernelCreate(module, &desc, &zeKernel));
  return zeKernel;
}

extern "C" void mgpuLaunchKernel(ze_kernel_handle_t kernel, size_t gridX,
                                 size_t gridY, size_t gridZ, size_t blockX,
                                 size_t blockY, size_t blockZ,
                                 size_t sharedMemBytes, StreamWrapper *stream,
                                 void **params, void ** /*extra*/,
                                 size_t paramsCount) {
  if (sharedMemBytes > 0) {
    paramsCount = paramsCount - 1; // The last param is the shared memory size.
    L0_SAFE_CALL(zeKernelSetArgumentValue(kernel, paramsCount, sharedMemBytes,
                                          nullptr));
  }
  for (size_t i = 0; i < paramsCount; ++i)
    L0_SAFE_CALL(zeKernelSetArgumentValue(kernel, static_cast<uint32_t>(i),
                                          sizeof(void *), params[i]));
  L0_SAFE_CALL(zeKernelSetGroupSize(kernel, blockX, blockY, blockZ));
  ze_group_count_t dispatch;
  dispatch.groupCountX = static_cast<uint32_t>(gridX);
  dispatch.groupCountY = static_cast<uint32_t>(gridY);
  dispatch.groupCountZ = static_cast<uint32_t>(gridZ);
  stream->enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
                        ze_event_handle_t *waitEvents) {
    L0_SAFE_CALL(zeCommandListAppendLaunchKernel(
        getRtContext().immCmdListCompute.get(), kernel, &dispatch, newEvent,
        numWaitEvents, waitEvents));
  });
}

extern "C" void mgpuModuleUnload(ze_module_handle_t module) {
  L0_SAFE_CALL(zeModuleDestroy(module));
}

extern "C" void mgpuSetDefaultDevice(int32_t devIdx) {
  catchAll([&]() {
    // For now, a user must ensure that streams and events complete
    // and are destroyed before switching a device.
    getRtContext() = L0RTContextWrapper(devIdx);
    getDynamicEventPool() = DynamicEventPool(&getRtContext());
  });
}
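// Illustrative sketch of loading a SPIR-V module and launching one of its
// kernels (`spirvBlob`, `blobSize`, `devPtr`, and the kernel name are
// placeholders):
//
//   ze_module_handle_t mod = mgpuModuleLoad(spirvBlob, blobSize);
//   ze_kernel_handle_t krn = mgpuModuleGetFunction(mod, "my_kernel");
//   StreamWrapper *stream = mgpuStreamCreate();
//   void *args[] = {&devPtr}; // each entry points at the argument value
//   mgpuLaunchKernel(krn, /*gridX=*/1, 1, 1, /*blockX=*/32, 1, 1,
//                    /*sharedMemBytes=*/0, stream, args, /*extra=*/nullptr,
//                    /*paramsCount=*/1);
//   mgpuStreamSynchronize(stream);
//   mgpuStreamDestroy(stream);
//   mgpuModuleUnload(mod);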