diff options
author | vporpo <vporpodas@google.com> | 2025-02-14 11:26:21 -0800 |
---|---|---|
committer | joaosaffran <joao.saffran@microsoft.com> | 2025-02-14 20:26:34 +0000 |
commit | d6b82e8bf4de66278da7660cc194de5823ff05bb (patch) | |
tree | daa6462f9ae16c63d79078abee20f5b65228d4f9 | |
parent | 2c50fc34dea7495bcc0b125a721151cdb7a7b9d6 (diff) | |
download | llvm-users/joaosaffran/123147.zip llvm-users/joaosaffran/123147.tar.gz llvm-users/joaosaffran/123147.tar.bz2 |
[SandboxIR] SetUse callback (#126985)users/joaosaffran/123147
This patch implements a callback mechanism similar to the existing ones,
but for getting notified whenever a Use edge gets updated. This is going
to be used in a follow up patch by the Dependency Graph.
-rw-r--r-- | llvm/include/llvm/SandboxIR/Context.h | 15 | ||||
-rw-r--r-- | llvm/lib/SandboxIR/Context.cpp | 18 | ||||
-rw-r--r-- | llvm/lib/SandboxIR/User.cpp | 13 | ||||
-rw-r--r-- | llvm/lib/SandboxIR/Value.cpp | 8 | ||||
-rw-r--r-- | llvm/unittests/SandboxIR/SandboxIRTest.cpp | 66 |
5 files changed, 111 insertions, 9 deletions
diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h index a88b000..714d1ec 100644 --- a/llvm/include/llvm/SandboxIR/Context.h +++ b/llvm/include/llvm/SandboxIR/Context.h @@ -26,6 +26,7 @@ class BBIterator; class Constant; class Module; class Value; +class Use; class Context { public: @@ -37,6 +38,8 @@ public: // destination BB and an iterator pointing to the insertion position. using MoveInstrCallback = std::function<void(Instruction *, const BBIterator &)>; + // A SetUseCallback receives the Use that is about to get its source set. + using SetUseCallback = std::function<void(const Use &, Value *)>; /// An ID for a registered callback. Used for deregistration. A dedicated type /// is employed so as to keep IDs opaque to the end user; only Context should @@ -98,6 +101,9 @@ protected: /// Callbacks called when an IR instruction is about to get moved. Keys are /// used as IDs for deregistration. MapVector<CallbackID, MoveInstrCallback> MoveInstrCallbacks; + /// Callbacks called when a Use gets its source set. Keys are used as IDs for + /// deregistration. + MapVector<CallbackID, SetUseCallback> SetUseCallbacks; /// A counter used for assigning callback IDs during registration. The same /// counter is used for all kinds of callbacks so we can detect mismatched @@ -129,6 +135,10 @@ protected: void runEraseInstrCallbacks(Instruction *I); void runCreateInstrCallbacks(Instruction *I); void runMoveInstrCallbacks(Instruction *I, const BBIterator &Where); + void runSetUseCallbacks(const Use &U, Value *NewSrc); + + friend class User; // For runSetUseCallbacks(). + friend class Value; // For runSetUseCallbacks(). // Friends for getOrCreateConstant(). #define DEF_CONST(ID, CLASS) friend class CLASS; @@ -281,7 +291,10 @@ public: CallbackID registerMoveInstrCallback(MoveInstrCallback CB); void unregisterMoveInstrCallback(CallbackID ID); - // TODO: Add callbacks for instructions inserted/removed if needed. + /// Register a callback that gets called when a Use gets set. + /// \Returns a callback ID for later deregistration. + CallbackID registerSetUseCallback(SetUseCallback CB); + void unregisterSetUseCallback(CallbackID ID); }; } // namespace sandboxir diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp index 830f283..6a397b0 100644 --- a/llvm/lib/SandboxIR/Context.cpp +++ b/llvm/lib/SandboxIR/Context.cpp @@ -687,6 +687,11 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) { CBEntry.second(I, WhereIt); } +void Context::runSetUseCallbacks(const Use &U, Value *NewSrc) { + for (auto &CBEntry : SetUseCallbacks) + CBEntry.second(U, NewSrc); +} + // An arbitrary limit, to check for accidental misuse. We expect a small number // of callbacks to be registered at a time, but we can increase this number if // we discover we needed more. @@ -732,4 +737,17 @@ void Context::unregisterMoveInstrCallback(CallbackID ID) { "Callback ID not found in MoveInstrCallbacks during deregistration"); } +Context::CallbackID Context::registerSetUseCallback(SetUseCallback CB) { + assert(SetUseCallbacks.size() <= MaxRegisteredCallbacks && + "SetUseCallbacks size limit exceeded"); + CallbackID ID{NextCallbackID++}; + SetUseCallbacks[ID] = CB; + return ID; +} +void Context::unregisterSetUseCallback(CallbackID ID) { + [[maybe_unused]] bool Erased = SetUseCallbacks.erase(ID); + assert(Erased && + "Callback ID not found in SetUseCallbacks during deregistration"); +} + } // namespace llvm::sandboxir diff --git a/llvm/lib/SandboxIR/User.cpp b/llvm/lib/SandboxIR/User.cpp index d7e4656..43fd565 100644 --- a/llvm/lib/SandboxIR/User.cpp +++ b/llvm/lib/SandboxIR/User.cpp @@ -90,17 +90,20 @@ bool User::classof(const Value *From) { void User::setOperand(unsigned OperandIdx, Value *Operand) { assert(isa<llvm::User>(Val) && "No operands!"); - Ctx.getTracker().emplaceIfTracking<UseSet>(getOperandUse(OperandIdx)); + const auto &U = getOperandUse(OperandIdx); + Ctx.getTracker().emplaceIfTracking<UseSet>(U); + Ctx.runSetUseCallbacks(U, Operand); // We are delegating to llvm::User::setOperand(). cast<llvm::User>(Val)->setOperand(OperandIdx, Operand->Val); } bool User::replaceUsesOfWith(Value *FromV, Value *ToV) { auto &Tracker = Ctx.getTracker(); - if (Tracker.isTracking()) { - for (auto OpIdx : seq<unsigned>(0, getNumOperands())) { - auto Use = getOperandUse(OpIdx); - if (Use.get() == FromV) + for (auto OpIdx : seq<unsigned>(0, getNumOperands())) { + auto Use = getOperandUse(OpIdx); + if (Use.get() == FromV) { + Ctx.runSetUseCallbacks(Use, ToV); + if (Tracker.isTracking()) Tracker.emplaceIfTracking<UseSet>(Use); } } diff --git a/llvm/lib/SandboxIR/Value.cpp b/llvm/lib/SandboxIR/Value.cpp index b9d91c7..e39bbc4 100644 --- a/llvm/lib/SandboxIR/Value.cpp +++ b/llvm/lib/SandboxIR/Value.cpp @@ -51,7 +51,7 @@ void Value::replaceUsesWithIf( llvm::Value *OtherVal = OtherV->Val; // We are delegating RUWIf to LLVM IR's RUWIf. Val->replaceUsesWithIf( - OtherVal, [&ShouldReplace, this](llvm::Use &LLVMUse) -> bool { + OtherVal, [&ShouldReplace, this, OtherV](llvm::Use &LLVMUse) -> bool { User *DstU = cast_or_null<User>(Ctx.getValue(LLVMUse.getUser())); if (DstU == nullptr) return false; @@ -59,6 +59,7 @@ void Value::replaceUsesWithIf( if (!ShouldReplace(UseToReplace)) return false; Ctx.getTracker().emplaceIfTracking<UseSet>(UseToReplace); + Ctx.runSetUseCallbacks(UseToReplace, OtherV); return true; }); } @@ -67,8 +68,9 @@ void Value::replaceAllUsesWith(Value *Other) { assert(getType() == Other->getType() && "Replacing with Value of different type!"); auto &Tracker = Ctx.getTracker(); - if (Tracker.isTracking()) { - for (auto Use : uses()) + for (auto Use : uses()) { + Ctx.runSetUseCallbacks(Use, Other); + if (Tracker.isTracking()) Tracker.track(std::make_unique<UseSet>(Use)); } // We are delegating RAUW to LLVM IR's RAUW. diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 9eeac9b..2ad3365 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -6081,6 +6081,72 @@ TEST_F(SandboxIRTest, InstructionCallbacks) { EXPECT_THAT(Moved, testing::IsEmpty()); } +// Check callbacks when we set a Use. +TEST_F(SandboxIRTest, SetUseCallbacks) { + parseIR(C, R"IR( +define void @foo(i8 %v0, i8 %v1) { + %add0 = add i8 %v0, %v1 + %add1 = add i8 %add0, %v1 + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *Arg0 = F->getArg(0); + auto *BB = &*F->begin(); + auto It = BB->begin(); + auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++); + auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++); + + SmallVector<std::pair<sandboxir::Use, sandboxir::Value *>> UsesSet; + auto Id = Ctx.registerSetUseCallback( + [&UsesSet](sandboxir::Use U, sandboxir::Value *NewSrc) { + UsesSet.push_back({U, NewSrc}); + }); + + // Now change %add1 operand to not use %add0. + Add1->setOperand(0, Arg0); + EXPECT_EQ(UsesSet.size(), 1u); + EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get()); + EXPECT_EQ(UsesSet[0].second, Arg0); + // Restore to previous state. + Add1->setOperand(0, Add0); + UsesSet.clear(); + + // RAUW + Add0->replaceAllUsesWith(Arg0); + EXPECT_EQ(UsesSet.size(), 1u); + EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get()); + EXPECT_EQ(UsesSet[0].second, Arg0); + // Restore to previous state. + Add1->setOperand(0, Add0); + UsesSet.clear(); + + // RUWIf + Add0->replaceUsesWithIf(Arg0, [](const auto &U) { return true; }); + EXPECT_EQ(UsesSet.size(), 1u); + EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get()); + EXPECT_EQ(UsesSet[0].second, Arg0); + // Restore to previous state. + Add1->setOperand(0, Add0); + UsesSet.clear(); + + // RUOW + Add1->replaceUsesOfWith(Add0, Arg0); + EXPECT_EQ(UsesSet.size(), 1u); + EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get()); + EXPECT_EQ(UsesSet[0].second, Arg0); + // Restore to previous state. + Add1->setOperand(0, Add0); + UsesSet.clear(); + + // Check unregister. + Ctx.unregisterSetUseCallback(Id); + Add0->replaceAllUsesWith(Arg0); + EXPECT_TRUE(UsesSet.empty()); +} + TEST_F(SandboxIRTest, FunctionObjectAlreadyExists) { parseIR(C, R"IR( define void @foo() { |