diff options
author | S. VenkataKeerthy <31350914+svkeerthy@users.noreply.github.com> | 2025-06-13 10:43:22 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-13 10:43:22 -0700 |
commit | 09c54c2e9e044fa0857831e6ce1bf77c8ce16ecc (patch) | |
tree | 152e52509f30a78bfd40ee0dea1732198479a948 /llvm/unittests/Analysis | |
parent | ecdb549e6de60b3211cfa860eec498270e3980f1 (diff) | |
download | llvm-09c54c2e9e044fa0857831e6ce1bf77c8ce16ecc.zip llvm-09c54c2e9e044fa0857831e6ce1bf77c8ce16ecc.tar.gz llvm-09c54c2e9e044fa0857831e6ce1bf77c8ce16ecc.tar.bz2 |
[IR2Vec] Minor vocab changes and exposing weights (#143200)
This PR changes some asserts in Vocab to hard checks that emit error and exposes flags and constructor to help in unit tests.
(Tracking issue - #141817)
Diffstat (limited to 'llvm/unittests/Analysis')
-rw-r--r-- | llvm/unittests/Analysis/IR2VecTest.cpp | 137 |
1 files changed, 102 insertions, 35 deletions
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp index 053b9f7..90d07d0 100644 --- a/llvm/unittests/Analysis/IR2VecTest.cpp +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -281,25 +281,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) { EXPECT_EQ(validResult.getDimension(), 2u); } -// Helper to create a minimal function and embedder for getter tests -struct GetterTestEnv { - Vocab V = {}; +// Fixture for IR2Vec tests requiring IR setup and weight management. +class IR2VecTestFixture : public ::testing::Test { +protected: + Vocab V; LLVMContext Ctx; - std::unique_ptr<Module> M = nullptr; + std::unique_ptr<Module> M; Function *F = nullptr; BasicBlock *BB = nullptr; - Instruction *Add = nullptr; - Instruction *Ret = nullptr; - std::unique_ptr<Embedder> Emb = nullptr; + Instruction *AddInst = nullptr; + Instruction *RetInst = nullptr; - GetterTestEnv() { + float OriginalOpcWeight = ::OpcWeight; + float OriginalTypeWeight = ::TypeWeight; + float OriginalArgWeight = ::ArgWeight; + + void SetUp() override { V = {{"add", {1.0, 2.0}}, {"integerTy", {0.5, 0.5}}, {"constant", {0.2, 0.3}}, {"variable", {0.0, 0.0}}, {"unknownTy", {0.0, 0.0}}}; - M = std::make_unique<Module>("M", Ctx); + // Setup IR + M = std::make_unique<Module>("TestM", Ctx); FunctionType *FTy = FunctionType::get( Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)}, false); @@ -308,61 +313,82 @@ struct GetterTestEnv { Argument *Arg = F->getArg(0); llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42); - Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB); - Ret = ReturnInst::Create(Ctx, Add, BB); + AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB); + RetInst = ReturnInst::Create(Ctx, AddInst, BB); + } + + void setWeights(float OpcWeight, float TypeWeight, float ArgWeight) { + ::OpcWeight = OpcWeight; + ::TypeWeight = TypeWeight; + ::ArgWeight = ArgWeight; + } - auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); - EXPECT_TRUE(static_cast<bool>(Result)); - Emb = std::move(*Result); + void TearDown() override { + // Restore original global weights + ::OpcWeight = OriginalOpcWeight; + ::TypeWeight = OriginalTypeWeight; + ::ArgWeight = OriginalArgWeight; } }; -TEST(IR2VecTest, GetInstVecMap) { - GetterTestEnv Env; - const auto &InstMap = Env.Emb->getInstVecMap(); +TEST_F(IR2VecTestFixture, GetInstVecMap) { + auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); + ASSERT_TRUE(static_cast<bool>(Result)); + auto Emb = std::move(*Result); + + const auto &InstMap = Emb->getInstVecMap(); EXPECT_EQ(InstMap.size(), 2u); - EXPECT_TRUE(InstMap.count(Env.Add)); - EXPECT_TRUE(InstMap.count(Env.Ret)); + EXPECT_TRUE(InstMap.count(AddInst)); + EXPECT_TRUE(InstMap.count(RetInst)); - EXPECT_EQ(InstMap.at(Env.Add).size(), 2u); - EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u); + EXPECT_EQ(InstMap.at(AddInst).size(), 2u); + EXPECT_EQ(InstMap.at(RetInst).size(), 2u); // Check values for add: {1.29, 2.31} - EXPECT_THAT(InstMap.at(Env.Add), + EXPECT_THAT(InstMap.at(AddInst), ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6))); // Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in // vocab - EXPECT_THAT(InstMap.at(Env.Ret), ElementsAre(0.0, 0.0)); + EXPECT_THAT(InstMap.at(RetInst), ElementsAre(0.0, 0.0)); } -TEST(IR2VecTest, GetBBVecMap) { - GetterTestEnv Env; - const auto &BBMap = Env.Emb->getBBVecMap(); +TEST_F(IR2VecTestFixture, GetBBVecMap) { + auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); + ASSERT_TRUE(static_cast<bool>(Result)); + auto Emb = std::move(*Result); + + const auto &BBMap = Emb->getBBVecMap(); EXPECT_EQ(BBMap.size(), 1u); - EXPECT_TRUE(BBMap.count(Env.BB)); - EXPECT_EQ(BBMap.at(Env.BB).size(), 2u); + EXPECT_TRUE(BBMap.count(BB)); + EXPECT_EQ(BBMap.at(BB).size(), 2u); // BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} = // {1.29, 2.31} - EXPECT_THAT(BBMap.at(Env.BB), + EXPECT_THAT(BBMap.at(BB), ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6))); } -TEST(IR2VecTest, GetBBVector) { - GetterTestEnv Env; - const auto &BBVec = Env.Emb->getBBVector(*Env.BB); +TEST_F(IR2VecTestFixture, GetBBVector) { + auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); + ASSERT_TRUE(static_cast<bool>(Result)); + auto Emb = std::move(*Result); + + const auto &BBVec = Emb->getBBVector(*BB); EXPECT_EQ(BBVec.size(), 2u); EXPECT_THAT(BBVec, ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6))); } -TEST(IR2VecTest, GetFunctionVector) { - GetterTestEnv Env; - const auto &FuncVec = Env.Emb->getFunctionVector(); +TEST_F(IR2VecTestFixture, GetFunctionVector) { + auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); + ASSERT_TRUE(static_cast<bool>(Result)); + auto Emb = std::move(*Result); + + const auto &FuncVec = Emb->getFunctionVector(); EXPECT_EQ(FuncVec.size(), 2u); @@ -371,4 +397,45 @@ TEST(IR2VecTest, GetFunctionVector) { ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6))); } +TEST_F(IR2VecTestFixture, GetFunctionVectorWithCustomWeights) { + setWeights(1.0, 1.0, 1.0); + + auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V); + ASSERT_TRUE(static_cast<bool>(Result)); + auto Emb = std::move(*Result); + + const auto &FuncVec = Emb->getFunctionVector(); + + EXPECT_EQ(FuncVec.size(), 2u); + + // Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2 + // 0.3] + [0.0 0.0]) + EXPECT_THAT(FuncVec, + ElementsAre(DoubleNear(1.7, 1e-6), DoubleNear(2.8, 1e-6))); +} + +TEST(IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) { + Vocab InitialVocab = {{"key1", {1.1, 2.2}}, {"key2", {3.3, 4.4}}}; + Vocab ExpectedVocab = InitialVocab; + unsigned ExpectedDim = InitialVocab.begin()->second.size(); + + IR2VecVocabAnalysis VocabAnalysis(std::move(InitialVocab)); + + LLVMContext TestCtx; + Module TestMod("TestModuleForVocabAnalysis", TestCtx); + ModuleAnalysisManager MAM; + IR2VecVocabResult Result = VocabAnalysis.run(TestMod, MAM); + + EXPECT_TRUE(Result.isValid()); + ASSERT_FALSE(Result.getVocabulary().empty()); + EXPECT_EQ(Result.getDimension(), ExpectedDim); + + const auto &ResultVocab = Result.getVocabulary(); + EXPECT_EQ(ResultVocab.size(), ExpectedVocab.size()); + for (const auto &pair : ExpectedVocab) { + EXPECT_TRUE(ResultVocab.count(pair.first)); + EXPECT_THAT(ResultVocab.at(pair.first), ElementsAreArray(pair.second)); + } +} + } // end anonymous namespace |