aboutsummaryrefslogtreecommitdiff
path: root/llvm/unittests/Analysis
diff options
context:
space:
mode:
authorS. VenkataKeerthy <31350914+svkeerthy@users.noreply.github.com>2025-06-13 10:43:22 -0700
committerGitHub <noreply@github.com>2025-06-13 10:43:22 -0700
commit09c54c2e9e044fa0857831e6ce1bf77c8ce16ecc (patch)
tree152e52509f30a78bfd40ee0dea1732198479a948 /llvm/unittests/Analysis
parentecdb549e6de60b3211cfa860eec498270e3980f1 (diff)
downloadllvm-09c54c2e9e044fa0857831e6ce1bf77c8ce16ecc.zip
llvm-09c54c2e9e044fa0857831e6ce1bf77c8ce16ecc.tar.gz
llvm-09c54c2e9e044fa0857831e6ce1bf77c8ce16ecc.tar.bz2
[IR2Vec] Minor vocab changes and exposing weights (#143200)
This PR changes some asserts in Vocab to hard checks that emit error and exposes flags and constructor to help in unit tests. (Tracking issue - #141817)
Diffstat (limited to 'llvm/unittests/Analysis')
-rw-r--r--llvm/unittests/Analysis/IR2VecTest.cpp137
1 files changed, 102 insertions, 35 deletions
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 053b9f7..90d07d0 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -281,25 +281,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
EXPECT_EQ(validResult.getDimension(), 2u);
}
-// Helper to create a minimal function and embedder for getter tests
-struct GetterTestEnv {
- Vocab V = {};
+// Fixture for IR2Vec tests requiring IR setup and weight management.
+class IR2VecTestFixture : public ::testing::Test {
+protected:
+ Vocab V;
LLVMContext Ctx;
- std::unique_ptr<Module> M = nullptr;
+ std::unique_ptr<Module> M;
Function *F = nullptr;
BasicBlock *BB = nullptr;
- Instruction *Add = nullptr;
- Instruction *Ret = nullptr;
- std::unique_ptr<Embedder> Emb = nullptr;
+ Instruction *AddInst = nullptr;
+ Instruction *RetInst = nullptr;
- GetterTestEnv() {
+ float OriginalOpcWeight = ::OpcWeight;
+ float OriginalTypeWeight = ::TypeWeight;
+ float OriginalArgWeight = ::ArgWeight;
+
+ void SetUp() override {
V = {{"add", {1.0, 2.0}},
{"integerTy", {0.5, 0.5}},
{"constant", {0.2, 0.3}},
{"variable", {0.0, 0.0}},
{"unknownTy", {0.0, 0.0}}};
- M = std::make_unique<Module>("M", Ctx);
+ // Setup IR
+ M = std::make_unique<Module>("TestM", Ctx);
FunctionType *FTy = FunctionType::get(
Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)},
false);
@@ -308,61 +313,82 @@ struct GetterTestEnv {
Argument *Arg = F->getArg(0);
llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
- Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
- Ret = ReturnInst::Create(Ctx, Add, BB);
+ AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
+ RetInst = ReturnInst::Create(Ctx, AddInst, BB);
+ }
+
+ void setWeights(float OpcWeight, float TypeWeight, float ArgWeight) {
+ ::OpcWeight = OpcWeight;
+ ::TypeWeight = TypeWeight;
+ ::ArgWeight = ArgWeight;
+ }
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- EXPECT_TRUE(static_cast<bool>(Result));
- Emb = std::move(*Result);
+ void TearDown() override {
+ // Restore original global weights
+ ::OpcWeight = OriginalOpcWeight;
+ ::TypeWeight = OriginalTypeWeight;
+ ::ArgWeight = OriginalArgWeight;
}
};
-TEST(IR2VecTest, GetInstVecMap) {
- GetterTestEnv Env;
- const auto &InstMap = Env.Emb->getInstVecMap();
+TEST_F(IR2VecTestFixture, GetInstVecMap) {
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Result));
+ auto Emb = std::move(*Result);
+
+ const auto &InstMap = Emb->getInstVecMap();
EXPECT_EQ(InstMap.size(), 2u);
- EXPECT_TRUE(InstMap.count(Env.Add));
- EXPECT_TRUE(InstMap.count(Env.Ret));
+ EXPECT_TRUE(InstMap.count(AddInst));
+ EXPECT_TRUE(InstMap.count(RetInst));
- EXPECT_EQ(InstMap.at(Env.Add).size(), 2u);
- EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u);
+ EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
+ EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
// Check values for add: {1.29, 2.31}
- EXPECT_THAT(InstMap.at(Env.Add),
+ EXPECT_THAT(InstMap.at(AddInst),
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
// Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
// vocab
- EXPECT_THAT(InstMap.at(Env.Ret), ElementsAre(0.0, 0.0));
+ EXPECT_THAT(InstMap.at(RetInst), ElementsAre(0.0, 0.0));
}
-TEST(IR2VecTest, GetBBVecMap) {
- GetterTestEnv Env;
- const auto &BBMap = Env.Emb->getBBVecMap();
+TEST_F(IR2VecTestFixture, GetBBVecMap) {
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Result));
+ auto Emb = std::move(*Result);
+
+ const auto &BBMap = Emb->getBBVecMap();
EXPECT_EQ(BBMap.size(), 1u);
- EXPECT_TRUE(BBMap.count(Env.BB));
- EXPECT_EQ(BBMap.at(Env.BB).size(), 2u);
+ EXPECT_TRUE(BBMap.count(BB));
+ EXPECT_EQ(BBMap.at(BB).size(), 2u);
// BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
// {1.29, 2.31}
- EXPECT_THAT(BBMap.at(Env.BB),
+ EXPECT_THAT(BBMap.at(BB),
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}
-TEST(IR2VecTest, GetBBVector) {
- GetterTestEnv Env;
- const auto &BBVec = Env.Emb->getBBVector(*Env.BB);
+TEST_F(IR2VecTestFixture, GetBBVector) {
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Result));
+ auto Emb = std::move(*Result);
+
+ const auto &BBVec = Emb->getBBVector(*BB);
EXPECT_EQ(BBVec.size(), 2u);
EXPECT_THAT(BBVec,
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}
-TEST(IR2VecTest, GetFunctionVector) {
- GetterTestEnv Env;
- const auto &FuncVec = Env.Emb->getFunctionVector();
+TEST_F(IR2VecTestFixture, GetFunctionVector) {
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Result));
+ auto Emb = std::move(*Result);
+
+ const auto &FuncVec = Emb->getFunctionVector();
EXPECT_EQ(FuncVec.size(), 2u);
@@ -371,4 +397,45 @@ TEST(IR2VecTest, GetFunctionVector) {
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}
+TEST_F(IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
+ setWeights(1.0, 1.0, 1.0);
+
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Result));
+ auto Emb = std::move(*Result);
+
+ const auto &FuncVec = Emb->getFunctionVector();
+
+ EXPECT_EQ(FuncVec.size(), 2u);
+
+ // Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2
+ // 0.3] + [0.0 0.0])
+ EXPECT_THAT(FuncVec,
+ ElementsAre(DoubleNear(1.7, 1e-6), DoubleNear(2.8, 1e-6)));
+}
+
+TEST(IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
+ Vocab InitialVocab = {{"key1", {1.1, 2.2}}, {"key2", {3.3, 4.4}}};
+ Vocab ExpectedVocab = InitialVocab;
+ unsigned ExpectedDim = InitialVocab.begin()->second.size();
+
+ IR2VecVocabAnalysis VocabAnalysis(std::move(InitialVocab));
+
+ LLVMContext TestCtx;
+ Module TestMod("TestModuleForVocabAnalysis", TestCtx);
+ ModuleAnalysisManager MAM;
+ IR2VecVocabResult Result = VocabAnalysis.run(TestMod, MAM);
+
+ EXPECT_TRUE(Result.isValid());
+ ASSERT_FALSE(Result.getVocabulary().empty());
+ EXPECT_EQ(Result.getDimension(), ExpectedDim);
+
+ const auto &ResultVocab = Result.getVocabulary();
+ EXPECT_EQ(ResultVocab.size(), ExpectedVocab.size());
+ for (const auto &pair : ExpectedVocab) {
+ EXPECT_TRUE(ResultVocab.count(pair.first));
+ EXPECT_THAT(ResultVocab.at(pair.first), ElementsAreArray(pair.second));
+ }
+}
+
} // end anonymous namespace