diff options
Diffstat (limited to 'llvm/unittests/Analysis/IR2VecTest.cpp')
-rw-r--r-- | llvm/unittests/Analysis/IR2VecTest.cpp | 92 |
1 files changed, 23 insertions, 69 deletions
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp index 40b4aa2..8ffc5f6 100644 --- a/llvm/unittests/Analysis/IR2VecTest.cpp +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -30,7 +30,9 @@ namespace { class TestableEmbedder : public Embedder { public: TestableEmbedder(const Function &F, const Vocabulary &V) : Embedder(F, V) {} - void computeEmbeddings(const BasicBlock &BB) const override {} + Embedding computeEmbeddings(const Instruction &I) const override { + return Embedding(); + } }; TEST(EmbeddingTest, ConstructorsAndAccessors) { @@ -321,18 +323,12 @@ protected: } }; -TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) { +TEST_F(IR2VecTestFixture, GetInstVec_Symbolic) { auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V); ASSERT_TRUE(static_cast<bool>(Emb)); - const auto &InstMap = Emb->getInstVecMap(); - - EXPECT_EQ(InstMap.size(), 2u); - EXPECT_TRUE(InstMap.count(AddInst)); - EXPECT_TRUE(InstMap.count(RetInst)); - - const auto &AddEmb = InstMap.at(AddInst); - const auto &RetEmb = InstMap.at(RetInst); + const auto &AddEmb = Emb->getInstVector(*AddInst); + const auto &RetEmb = Emb->getInstVector(*RetInst); EXPECT_EQ(AddEmb.size(), 2u); EXPECT_EQ(RetEmb.size(), 2u); @@ -340,51 +336,17 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) { EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 15.5))); } -TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) { - auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V); - ASSERT_TRUE(static_cast<bool>(Emb)); - - const auto &InstMap = Emb->getInstVecMap(); - - EXPECT_EQ(InstMap.size(), 2u); - EXPECT_TRUE(InstMap.count(AddInst)); - EXPECT_TRUE(InstMap.count(RetInst)); - - EXPECT_EQ(InstMap.at(AddInst).size(), 2u); - EXPECT_EQ(InstMap.at(RetInst).size(), 2u); - - EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 25.5))); - EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 32.6))); -} - -TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) { - auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V); - ASSERT_TRUE(static_cast<bool>(Emb)); - - const auto &BBMap = Emb->getBBVecMap(); - - EXPECT_EQ(BBMap.size(), 1u); - EXPECT_TRUE(BBMap.count(BB)); - EXPECT_EQ(BBMap.at(BB).size(), 2u); - - // BB vector should be sum of add and ret: {25.5, 25.5} + {15.5, 15.5} = - // {41.0, 41.0} - EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 41.0))); -} - -TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) { +TEST_F(IR2VecTestFixture, GetInstVec_FlowAware) { auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V); ASSERT_TRUE(static_cast<bool>(Emb)); - const auto &BBMap = Emb->getBBVecMap(); - - EXPECT_EQ(BBMap.size(), 1u); - EXPECT_TRUE(BBMap.count(BB)); - EXPECT_EQ(BBMap.at(BB).size(), 2u); + const auto &AddEmb = Emb->getInstVector(*AddInst); + const auto &RetEmb = Emb->getInstVector(*RetInst); + EXPECT_EQ(AddEmb.size(), 2u); + EXPECT_EQ(RetEmb.size(), 2u); - // BB vector should be sum of add and ret: {25.5, 25.5} + {32.6, 32.6} = - // {58.1, 58.1} - EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 58.1))); + EXPECT_TRUE(AddEmb.approximatelyEquals(Embedding(2, 25.5))); + EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 32.6))); } TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) { @@ -394,6 +356,8 @@ TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) { const auto &BBVec = Emb->getBBVector(*BB); EXPECT_EQ(BBVec.size(), 2u); + // BB vector should be sum of add and ret: {25.5, 25.5} + {15.5, 15.5} = + // {41.0, 41.0} EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 41.0))); } @@ -404,6 +368,8 @@ TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) { const auto &BBVec = Emb->getBBVector(*BB); EXPECT_EQ(BBVec.size(), 2u); + // BB vector should be sum of add and ret: {25.5, 25.5} + {32.6, 32.6} = + // {58.1, 58.1} EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 58.1))); } @@ -446,15 +412,9 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_Symbolic) { EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3)); EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3)); - // Also check that instruction vectors remain consistent - const auto &InstMap1 = Emb->getInstVecMap(); - const auto &InstMap2 = Emb->getInstVecMap(); - - EXPECT_EQ(InstMap1.size(), InstMap2.size()); - for (const auto &[Inst, Vec1] : InstMap1) { - ASSERT_TRUE(InstMap2.count(Inst)); - EXPECT_TRUE(Vec1.approximatelyEquals(InstMap2.at(Inst))); - } + Emb->invalidateEmbeddings(); + const auto &FuncVec4 = Emb->getFunctionVector(); + EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec4)); } TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) { @@ -473,15 +433,9 @@ TEST_F(IR2VecTestFixture, MultipleComputeEmbeddingsConsistency_FlowAware) { EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec3)); EXPECT_TRUE(FuncVec2.approximatelyEquals(FuncVec3)); - // Also check that instruction vectors remain consistent - const auto &InstMap1 = Emb->getInstVecMap(); - const auto &InstMap2 = Emb->getInstVecMap(); - - EXPECT_EQ(InstMap1.size(), InstMap2.size()); - for (const auto &[Inst, Vec1] : InstMap1) { - ASSERT_TRUE(InstMap2.count(Inst)); - EXPECT_TRUE(Vec1.approximatelyEquals(InstMap2.at(Inst))); - } + Emb->invalidateEmbeddings(); + const auto &FuncVec4 = Emb->getFunctionVector(); + EXPECT_TRUE(FuncVec1.approximatelyEquals(FuncVec4)); } static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes; |