diff options
Diffstat (limited to 'offload/plugins-nextgen/common/src')
-rw-r--r-- | offload/plugins-nextgen/common/src/JIT.cpp | 26 | ||||
-rw-r--r-- | offload/plugins-nextgen/common/src/PluginInterface.cpp | 3 |
2 files changed, 20 insertions, 9 deletions
diff --git a/offload/plugins-nextgen/common/src/JIT.cpp b/offload/plugins-nextgen/common/src/JIT.cpp index c82a06e..00720fa 100644 --- a/offload/plugins-nextgen/common/src/JIT.cpp +++ b/offload/plugins-nextgen/common/src/JIT.cpp @@ -285,8 +285,8 @@ JITEngine::compile(const __tgt_device_image &Image, // Check if we JITed this image for the given compute unit kind before. ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind]; - if (__tgt_device_image *JITedImage = CUI.TgtImageMap.lookup(&Image)) - return JITedImage; + if (CUI.TgtImageMap.contains(&Image)) + return CUI.TgtImageMap[&Image].get(); auto ObjMBOrErr = getOrCreateObjFile(Image, CUI.Context, ComputeUnitKind); if (!ObjMBOrErr) @@ -296,17 +296,15 @@ JITEngine::compile(const __tgt_device_image &Image, if (!ImageMBOrErr) return ImageMBOrErr.takeError(); - CUI.JITImages.push_back(std::move(*ImageMBOrErr)); - __tgt_device_image *&JITedImage = CUI.TgtImageMap[&Image]; - JITedImage = new __tgt_device_image(); + CUI.JITImages.insert({&Image, std::move(*ImageMBOrErr)}); + auto &ImageMB = CUI.JITImages[&Image]; + CUI.TgtImageMap.insert({&Image, std::make_unique<__tgt_device_image>()}); + auto &JITedImage = CUI.TgtImageMap[&Image]; *JITedImage = Image; - - auto &ImageMB = CUI.JITImages.back(); - JITedImage->ImageStart = const_cast<char *>(ImageMB->getBufferStart()); JITedImage->ImageEnd = const_cast<char *>(ImageMB->getBufferEnd()); - return JITedImage; + return JITedImage.get(); } Expected<const __tgt_device_image *> @@ -324,3 +322,13 @@ JITEngine::process(const __tgt_device_image &Image, return &Image; } + +void JITEngine::erase(const __tgt_device_image &Image, + target::plugin::GenericDeviceTy &Device) { + std::lock_guard<std::mutex> Lock(ComputeUnitMapMutex); + const std::string &ComputeUnitKind = Device.getComputeUnitKind(); + ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind]; + + CUI.TgtImageMap.erase(&Image); + CUI.JITImages.erase(&Image); +} diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp index 81b9d42..94a050b5 100644 --- a/offload/plugins-nextgen/common/src/PluginInterface.cpp +++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp @@ -854,6 +854,9 @@ Error GenericDeviceTy::unloadBinary(DeviceImageTy *Image) { return Err; } + if (Image->getTgtImageBitcode()) + Plugin.getJIT().erase(*Image->getTgtImageBitcode(), Image->getDevice()); + return unloadBinaryImpl(Image); } |