diff options
author | Lang Hames <lhames@gmail.com> | 2025-07-03 22:04:41 +1000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-07-03 22:04:41 +1000 |
commit | 2638fa1be63c33407b779e959027e6dbeec6cb4f (patch) | |
tree | 3ff77a96bf506342b8952840d197fc1ff0f9ae16 /llvm | |
parent | 2532bde0388980ac7e299b02bc554e6fde6c686e (diff) | |
download | llvm-2638fa1be63c33407b779e959027e6dbeec6cb4f.zip llvm-2638fa1be63c33407b779e959027e6dbeec6cb4f.tar.gz llvm-2638fa1be63c33407b779e959027e6dbeec6cb4f.tar.bz2 |
[ORC] Add cloneToContext: Clone Module to a given ThreadSafeContext (#146852)
This is a generalization of the existing cloneToNewContext operation:
rather than cloning the given module into a new ThreadSafeContext it
clones it into any given ThreadSafeContext. The given ThreadSafeContext
is locked to ensure that the cloning operation is safe.
Diffstat (limited to 'llvm')
3 files changed, 101 insertions, 34 deletions
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h b/llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h index f135377..0316589 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h @@ -153,6 +153,12 @@ private: using GVPredicate = std::function<bool(const GlobalValue &)>; using GVModifier = std::function<void(GlobalValue &)>; +/// Clones teh given module onto the given context. +LLVM_ABI ThreadSafeModule +cloneToContext(const ThreadSafeModule &TSMW, ThreadSafeContext TSCtx, + GVPredicate ShouldCloneDef = GVPredicate(), + GVModifier UpdateClonedDefSource = GVModifier()); + /// Clones the given module on to a new context. LLVM_ABI ThreadSafeModule cloneToNewContext( const ThreadSafeModule &TSMW, GVPredicate ShouldCloneDef = GVPredicate(), diff --git a/llvm/lib/ExecutionEngine/Orc/ThreadSafeModule.cpp b/llvm/lib/ExecutionEngine/Orc/ThreadSafeModule.cpp index fadd53e..19c000e 100644 --- a/llvm/lib/ExecutionEngine/Orc/ThreadSafeModule.cpp +++ b/llvm/lib/ExecutionEngine/Orc/ThreadSafeModule.cpp @@ -14,51 +14,63 @@ namespace llvm { namespace orc { -ThreadSafeModule cloneToNewContext(const ThreadSafeModule &TSM, - GVPredicate ShouldCloneDef, - GVModifier UpdateClonedDefSource) { +ThreadSafeModule cloneToContext(const ThreadSafeModule &TSM, + ThreadSafeContext TSCtx, + GVPredicate ShouldCloneDef, + GVModifier UpdateClonedDefSource) { assert(TSM && "Can not clone null module"); if (!ShouldCloneDef) ShouldCloneDef = [](const GlobalValue &) { return true; }; - return TSM.withModuleDo([&](Module &M) { - SmallVector<char, 1> ClonedModuleBuffer; + // First copy the source module into a buffer. + std::string ModuleName; + SmallVector<char, 1> ClonedModuleBuffer; + TSM.withModuleDo([&](Module &M) { + ModuleName = M.getModuleIdentifier(); + std::set<GlobalValue *> ClonedDefsInSrc; + ValueToValueMapTy VMap; + auto Tmp = CloneModule(M, VMap, [&](const GlobalValue *GV) { + if (ShouldCloneDef(*GV)) { + ClonedDefsInSrc.insert(const_cast<GlobalValue *>(GV)); + return true; + } + return false; + }); - { - std::set<GlobalValue *> ClonedDefsInSrc; - ValueToValueMapTy VMap; - auto Tmp = CloneModule(M, VMap, [&](const GlobalValue *GV) { - if (ShouldCloneDef(*GV)) { - ClonedDefsInSrc.insert(const_cast<GlobalValue *>(GV)); - return true; - } - return false; - }); + if (UpdateClonedDefSource) + for (auto *GV : ClonedDefsInSrc) + UpdateClonedDefSource(*GV); - if (UpdateClonedDefSource) - for (auto *GV : ClonedDefsInSrc) - UpdateClonedDefSource(*GV); + BitcodeWriter BCWriter(ClonedModuleBuffer); + BCWriter.writeModule(*Tmp); + BCWriter.writeSymtab(); + BCWriter.writeStrtab(); + }); + + MemoryBufferRef ClonedModuleBufferRef( + StringRef(ClonedModuleBuffer.data(), ClonedModuleBuffer.size()), + "cloned module buffer"); - BitcodeWriter BCWriter(ClonedModuleBuffer); + // Then parse the buffer into the new Module. + auto M = TSCtx.withContextDo([&](LLVMContext *Ctx) { + assert(Ctx && "No LLVMContext provided"); + auto TmpM = cantFail(parseBitcodeFile(ClonedModuleBufferRef, *Ctx)); + TmpM->setModuleIdentifier(ModuleName); + return TmpM; + }); - BCWriter.writeModule(*Tmp); - BCWriter.writeSymtab(); - BCWriter.writeStrtab(); - } + return ThreadSafeModule(std::move(M), std::move(TSCtx)); +} - MemoryBufferRef ClonedModuleBufferRef( - StringRef(ClonedModuleBuffer.data(), ClonedModuleBuffer.size()), - "cloned module buffer"); - ThreadSafeContext NewTSCtx(std::make_unique<LLVMContext>()); +ThreadSafeModule cloneToNewContext(const ThreadSafeModule &TSM, + GVPredicate ShouldCloneDef, + GVModifier UpdateClonedDefSource) { + assert(TSM && "Can not clone null module"); - auto ClonedModule = NewTSCtx.withContextDo([&](LLVMContext *Ctx) { - auto TmpM = cantFail(parseBitcodeFile(ClonedModuleBufferRef, *Ctx)); - TmpM->setModuleIdentifier(M.getName()); - return TmpM; - }); - return ThreadSafeModule(std::move(ClonedModule), std::move(NewTSCtx)); - }); + ThreadSafeContext TSCtx(std::make_unique<LLVMContext>()); + return cloneToContext(TSM, std::move(TSCtx), std::move(ShouldCloneDef), + std::move(UpdateClonedDefSource)); } } // end namespace orc diff --git a/llvm/unittests/ExecutionEngine/Orc/ThreadSafeModuleTest.cpp b/llvm/unittests/ExecutionEngine/Orc/ThreadSafeModuleTest.cpp index adaa4d9..bbb9e8d 100644 --- a/llvm/unittests/ExecutionEngine/Orc/ThreadSafeModuleTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/ThreadSafeModuleTest.cpp @@ -7,6 +7,13 @@ //===----------------------------------------------------------------------===// #include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + #include "gtest/gtest.h" #include <atomic> @@ -18,6 +25,24 @@ using namespace llvm::orc; namespace { +const llvm::StringRef FooSrc = R"( + define void @foo() { + ret void + } +)"; + +static ThreadSafeModule parseModule(llvm::StringRef Source, + llvm::StringRef Name) { + auto Ctx = std::make_unique<LLVMContext>(); + SMDiagnostic Err; + auto M = parseIR(MemoryBufferRef(Source, Name), Err, *Ctx); + if (!M) { + Err.print("Testcase source failed to parse: ", errs()); + exit(1); + } + return ThreadSafeModule(std::move(M), std::move(Ctx)); +} + TEST(ThreadSafeModuleTest, ContextWhollyOwnedByOneModule) { // Test that ownership of a context can be transferred to a single // ThreadSafeModule. @@ -103,4 +128,28 @@ TEST(ThreadSafeModuleTest, ConsumingModuleDo) { TSM.consumingModuleDo([](std::unique_ptr<Module> M) {}); } +TEST(ThreadSafeModuleTest, CloneToNewContext) { + auto TSM1 = parseModule(FooSrc, "foo.ll"); + auto TSM2 = cloneToNewContext(TSM1); + TSM2.withModuleDo([&](Module &NewM) { + EXPECT_FALSE(verifyModule(NewM, &errs())); + TSM1.withModuleDo([&](Module &OrigM) { + EXPECT_NE(&NewM.getContext(), &OrigM.getContext()); + }); + }); +} + +TEST(ObjectFormatsTest, CloneToContext) { + auto TSM1 = parseModule(FooSrc, "foo.ll"); + + auto TSCtx = ThreadSafeContext(std::make_unique<LLVMContext>()); + auto TSM2 = cloneToContext(TSM1, TSCtx); + + TSM2.withModuleDo([&](Module &M) { + EXPECT_FALSE(verifyModule(M, &errs())); + TSCtx.withContextDo( + [&](LLVMContext *Ctx) { EXPECT_EQ(&M.getContext(), Ctx); }); + }); +} + } // end anonymous namespace |