aboutsummaryrefslogtreecommitdiff
path: root/llvm/unittests/Analysis/IR2VecTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/unittests/Analysis/IR2VecTest.cpp')
-rw-r--r--llvm/unittests/Analysis/IR2VecTest.cpp92
1 files changed, 23 insertions, 69 deletions
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 40b4aa2..8ffc5f6 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -30,7 +30,9 @@ namespace {
class TestableEmbedder : public Embedder {
public:
TestableEmbedder(const Function &F, const Vocabulary &V) : Embedder(F, V) {}
- void computeEmbeddings(const BasicBlock &BB) const override {}
+ Embedding computeEmbeddings(const Instruction &I) const override {
+ return Embedding();
+ }
};
TEST(EmbeddingTest, ConstructorsAndAccessors) {
@@ -321,18 +323,12 @@ protected:
}
};
-TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
+TEST_F(IR2VecTestFixture, GetInstVec_Symbolic) {
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
- const auto &InstMap = Emb->getInstVecMap();
-
- EXPECT_EQ(InstMap.size(), 2u);
- EXPECT_TRUE(InstMap.count(AddInst));
- EXPECT_TRUE(InstMap.count(RetInst));
-
- const auto &AddEmb = InstMap.at(AddInst);
- const auto &RetEmb = InstMap.at(RetInst);
+ const auto &AddEmb = Emb->getInstVector(*AddInst);
+ const auto &RetEmb = Emb->getInstVector(*RetInst);
EXPECT_EQ(AddEmb.size(), 2u);
EXPECT_EQ(RetEmb.size(), 2u);
@@ -340,51 +336,17 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 15.5)));
}
-TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
- auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V);
- ASSERT_TRUE(static_cast<bool>(Emb));
-
- const auto &InstMap = Emb->getInstVecMap();
-
- EXPECT_EQ(InstMap.size(), 2u);
- EXPECT_TRUE(InstMap.count(AddInst));
- EXPECT_TRUE(InstMap.count(RetInst));
-
- EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
- EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
-
- EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 25.5)));
- EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 32.6)));
-}
-
-TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
- auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V);
- ASSERT_TRUE(static_cast<bool>(Emb));
-
- const auto &BBMap = Emb->getBBVecMap();
-
- EXPECT_EQ(BBMap.size(), 1u);
- EXPECT_TRUE(BBMap.count(BB));
- EXPECT_EQ(BBMap.at(BB).size(), 2u);
-
- // BB vector should be sum of add and ret: {25.5, 25.5} + {15.5, 15.5} =
- // {41.0, 41.0}
- EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 41.0)));
-}
-
-TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
+TEST_F(IR2VecTestFixture, GetInstVec_FlowAware) {
auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
- const auto &BBMap = Emb->getBBVecMap();
-
- EXPECT_EQ(BBMap.size(), 1u);
- EXPECT_TRUE(BBMap.count(BB));
- EXPECT_EQ(BBMap.at(BB).size(), 2u);
+ const auto &AddEmb = Emb->getInstVector(*AddInst);
+ const auto &RetEmb = Emb->getInstVector(*RetInst);
+ EXPECT_EQ(AddEmb.size(), 2u);
+ EXPECT_EQ(RetEmb.size(), 2u);
- // BB vector should be sum of add and ret: {25.5, 25.5} + {32.6, 32.6} =
- // {58.1, 58.1}
- EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 58.1)));
+ EXPECT_TRUE(AddEmb.approximatelyEquals(Embedding(2, 25.5)));
+ EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 32.6)));
}
TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
@@ -394,6 +356,8 @@ TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
const auto &BBVec = Emb->getBBVector(*BB);
EXPECT_EQ(BBVec.size(), 2u);
+ // BB vector should be sum of add and ret: {25.5, 25.5} + {15.5, 15.5} =
+ // {41.0, 41.0}
EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 41.0)));
}
@@ -404,6 +368,8 @@ TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
const auto &BBVec = Emb->getBBVector(*BB);
EXPECT_EQ(BBVec.size(), 2u);
+ // BB vector should be sum of add and ret: {25.5, 25.5} + {32.6, 32.6} =
+ // {58.1, 58.1}
EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 58.1)));
}
@@ -446,15 +412,9 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_Symbolic) {
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3));
EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3));
- // Also check that instruction vectors remain consistent
- const auto &InstMap1 = Emb->getInstVecMap();
- const auto &InstMap2 = Emb->getInstVecMap();
-
- EXPECT_EQ(InstMap1.size(), InstMap2.size());
- for (const auto &[Inst, Vec1] : InstMap1) {
- ASSERT_TRUE(InstMap2.count(Inst));
- EXPECT_TRUE(Vec1.approximatelyEquals(InstMap2.at(Inst)));
- }
+ Emb->invalidateEmbeddings();
+ const auto &FuncVec4 = Emb->getFunctionVector();
+ EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec4));
}
TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) {
@@ -473,15 +433,9 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) {
EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3));
EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3));
- // Also check that instruction vectors remain consistent
- const auto &InstMap1 = Emb->getInstVecMap();
- const auto &InstMap2 = Emb->getInstVecMap();
-
- EXPECT_EQ(InstMap1.size(), InstMap2.size());
- for (const auto &[Inst, Vec1] : InstMap1) {
- ASSERT_TRUE(InstMap2.count(Inst));
- EXPECT_TRUE(Vec1.approximatelyEquals(InstMap2.at(Inst)));
- }
+ Emb->invalidateEmbeddings();
+ const auto &FuncVec4 = Emb->getFunctionVector();
+ EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec4));
}
static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;