about | summary | refs | log | tree | commit | diff
path: root/llvm/unittests/Analysis/IR2VecTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/unittests/Analysis/IR2VecTest.cpp')
-rw-r--r-- llvm/unittests/Analysis/IR2VecTest.cpp 388
1 files changed, 352 insertions, 36 deletions
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 9f2f6a3..743628f 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -295,7 +295,7 @@ TEST(IR2VecTest, ZeroDimensionEmbedding) {
// Fixture for IR2Vec tests requiring IR setup.
class IR2VecTestFixture : public ::testing::Test {
protected:
- Vocabulary V;
+ Vocabulary *V;
LLVMContext Ctx;
std::unique_ptr<Module> M;
Function *F = nullptr;
@@ -304,7 +304,7 @@ protected:
Instruction *RetInst = nullptr;
void SetUp() override {
- V = Vocabulary(Vocabulary::createDummyVocabForTest(2));
+ V = new Vocabulary(Vocabulary::createDummyVocabForTest(2));
// Setup IR
M = std::make_unique<Module>("TestM", Ctx);
@@ -322,7 +322,7 @@ protected:
};
TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
- auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &InstMap = Emb->getInstVecMap();
@@ -341,7 +341,7 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
}
TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
- auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &InstMap = Emb->getInstVecMap();
@@ -358,7 +358,7 @@ TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
}
TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
- auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &BBMap = Emb->getBBVecMap();
@@ -373,7 +373,7 @@ TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
}
TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
- auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &BBMap = Emb->getBBVecMap();
@@ -388,7 +388,7 @@ TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
}
TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
- auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &BBVec = Emb->getBBVector(*BB);
@@ -398,7 +398,7 @@ TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
}
TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
- auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &BBVec = Emb->getBBVector(*BB);
@@ -408,7 +408,7 @@ TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
}
TEST_F(IR2VecTestFixture, GetFunctionVector_Symbolic) {
- auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &FuncVec = Emb->getFunctionVector();
@@ -420,7 +420,7 @@ TEST_F(IR2VecTestFixture, GetFunctionVector_Symbolic) {
}
TEST_F(IR2VecTestFixture, GetFunctionVector_FlowAware) {
- auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, *V);
ASSERT_TRUE(static_cast<bool>(Emb));
const auto &FuncVec = Emb->getFunctionVector();
@@ -435,6 +435,7 @@ static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;
static constexpr unsigned MaxTypeIDs = Vocabulary::MaxTypeIDs;
static constexpr unsigned MaxCanonicalTypeIDs = Vocabulary::MaxCanonicalTypeIDs;
static constexpr unsigned MaxOperands = Vocabulary::MaxOperandKinds;
+static constexpr unsigned MaxPredicateKinds = Vocabulary::MaxPredicateKinds;
// Mapping between LLVM Type::TypeID tokens and Vocabulary::CanonicalTypeID
// names and their canonical string keys.
@@ -460,9 +461,13 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
EXPECT_EQ(Emb.size(), Dim);
// Should have the correct total number of embeddings
- EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands);
+ EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands +
+ MaxPredicateKinds);
- auto ExpectedVocab = VocabVec;
+ // Collect embeddings for later comparison before moving VocabVec
+ std::vector<Embedding> ExpectedVocab;
+ for (const auto &Emb : VocabVec)
+ ExpectedVocab.push_back(Emb);
IR2VecVocabAnalysis VocabAnalysis(std::move(VocabVec));
LLVMContext TestCtx;
@@ -480,17 +485,17 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
}
TEST(IR2VecVocabularyTest, SlotIdxMapping) {
- // Test getSlotIndex for Opcodes
+ // Test getIndex for Opcodes
#define EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS) \
- EXPECT_EQ(Vocabulary::getSlotIndex(NUM), static_cast<unsigned>(NUM - 1));
+ EXPECT_EQ(Vocabulary::getIndex(NUM), static_cast<unsigned>(NUM - 1));
#define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS)
#include "llvm/IR/Instruction.def"
#undef HANDLE_INST
#undef EXPECT_OPCODE_SLOT
- // Test getSlotIndex for Types
+ // Test getIndex for Types
#define EXPECT_TYPE_SLOT(TypeIDTok, CanonEnum, CanonStr) \
- EXPECT_EQ(Vocabulary::getSlotIndex(Type::TypeIDTok), \
+ EXPECT_EQ(Vocabulary::getIndex(Type::TypeIDTok), \
MaxOpcodes + static_cast<unsigned>( \
Vocabulary::CanonicalTypeID::CanonEnum));
@@ -498,7 +503,7 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
#undef EXPECT_TYPE_SLOT
- // Test getSlotIndex for Value operands
+ // Test getIndex for Value operands
LLVMContext Ctx;
Module M("TestM", Ctx);
FunctionType *FTy =
@@ -508,40 +513,59 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
#define EXPECTED_VOCAB_OPERAND_SLOT(X) \
MaxOpcodes + MaxCanonicalTypeIDs + static_cast<unsigned>(X)
// Test Function operand
- EXPECT_EQ(Vocabulary::getSlotIndex(*F),
+ EXPECT_EQ(Vocabulary::getIndex(*F),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::FunctionID));
// Test Constant operand
Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
- EXPECT_EQ(Vocabulary::getSlotIndex(*C),
+ EXPECT_EQ(Vocabulary::getIndex(*C),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::ConstantID));
// Test Pointer operand
BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB);
- EXPECT_EQ(Vocabulary::getSlotIndex(*PtrVal),
+ EXPECT_EQ(Vocabulary::getIndex(*PtrVal),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::PointerID));
// Test Variable operand (function argument)
Argument *Arg = F->getArg(0);
- EXPECT_EQ(Vocabulary::getSlotIndex(*Arg),
+ EXPECT_EQ(Vocabulary::getIndex(*Arg),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::VariableID));
#undef EXPECTED_VOCAB_OPERAND_SLOT
+
+ // Test getIndex for predicates
+#define EXPECTED_VOCAB_PREDICATE_SLOT(X) \
+ MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands + static_cast<unsigned>(X)
+ for (unsigned P = CmpInst::FIRST_FCMP_PREDICATE;
+ P <= CmpInst::LAST_FCMP_PREDICATE; ++P) {
+ CmpInst::Predicate Pred = static_cast<CmpInst::Predicate>(P);
+ unsigned ExpectedIdx =
+ EXPECTED_VOCAB_PREDICATE_SLOT((P - CmpInst::FIRST_FCMP_PREDICATE));
+ EXPECT_EQ(Vocabulary::getIndex(Pred), ExpectedIdx);
+ }
+ auto ICMP_Start = CmpInst::LAST_FCMP_PREDICATE + 1;
+ for (unsigned P = CmpInst::FIRST_ICMP_PREDICATE;
+ P <= CmpInst::LAST_ICMP_PREDICATE; ++P) {
+ CmpInst::Predicate Pred = static_cast<CmpInst::Predicate>(P);
+ unsigned ExpectedIdx = EXPECTED_VOCAB_PREDICATE_SLOT(
+ ICMP_Start + P - CmpInst::FIRST_ICMP_PREDICATE);
+ EXPECT_EQ(Vocabulary::getIndex(Pred), ExpectedIdx);
+ }
+#undef EXPECTED_VOCAB_PREDICATE_SLOT
}
#if GTEST_HAS_DEATH_TEST
#ifndef NDEBUG
TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) {
// Test invalid opcode IDs
- EXPECT_DEATH(Vocabulary::getSlotIndex(0u), "Invalid opcode");
- EXPECT_DEATH(Vocabulary::getSlotIndex(MaxOpcodes + 1), "Invalid opcode");
+ EXPECT_DEATH(Vocabulary::getIndex(0u), "Invalid opcode");
+ EXPECT_DEATH(Vocabulary::getIndex(MaxOpcodes + 1), "Invalid opcode");
// Test invalid type IDs
- EXPECT_DEATH(Vocabulary::getSlotIndex(static_cast<Type::TypeID>(MaxTypeIDs)),
+ EXPECT_DEATH(Vocabulary::getIndex(static_cast<Type::TypeID>(MaxTypeIDs)),
+ "Invalid type ID");
+ EXPECT_DEATH(Vocabulary::getIndex(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
"Invalid type ID");
- EXPECT_DEATH(
- Vocabulary::getSlotIndex(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
- "Invalid type ID");
}
#endif // NDEBUG
#endif // GTEST_HAS_DEATH_TEST
@@ -551,7 +575,7 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) {
EXPECT_EQ(Vocabulary::getStringKey(12), "Add");
#define EXPECT_OPCODE(NUM, OPCODE, CLASS) \
- EXPECT_EQ(Vocabulary::getStringKey(Vocabulary::getSlotIndex(NUM)), \
+ EXPECT_EQ(Vocabulary::getStringKey(Vocabulary::getIndex(NUM)), \
Vocabulary::getVocabKeyForOpcode(NUM));
#define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE(NUM, OPCODE, CLASS)
#include "llvm/IR/Instruction.def"
@@ -569,6 +593,7 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) {
#undef EXPECT_CANONICAL_TYPE_NAME
+ // Verify OperandKind -> string mapping
#define HANDLE_OPERAND_KINDS(X) \
X(FunctionID, "Function") \
X(PointerID, "Pointer") \
@@ -592,6 +617,28 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) {
Vocabulary::getStringKey(MaxOpcodes + MaxCanonicalTypeIDs + 1);
EXPECT_EQ(FuncArgKey, "Function");
EXPECT_EQ(PtrArgKey, "Pointer");
+
+// Verify PredicateKind -> string mapping
+#define EXPECT_PREDICATE_KIND(PredNum, PredPos, PredKind) \
+ do { \
+ std::string PredStr = \
+ std::string(PredKind) + "_" + \
+ CmpInst::getPredicateName(static_cast<CmpInst::Predicate>(PredNum)) \
+ .str(); \
+ unsigned Pos = MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands + PredPos; \
+ EXPECT_EQ(Vocabulary::getStringKey(Pos), PredStr); \
+ } while (0)
+
+ for (unsigned P = CmpInst::FIRST_FCMP_PREDICATE;
+ P <= CmpInst::LAST_FCMP_PREDICATE; ++P)
+ EXPECT_PREDICATE_KIND(P, P - CmpInst::FIRST_FCMP_PREDICATE, "FCMP");
+
+ auto ICMP_Pos = CmpInst::LAST_FCMP_PREDICATE + 1;
+ for (unsigned P = CmpInst::FIRST_ICMP_PREDICATE;
+ P <= CmpInst::LAST_ICMP_PREDICATE; ++P)
+ EXPECT_PREDICATE_KIND(P, ICMP_Pos++, "ICMP");
+
+#undef EXPECT_PREDICATE_KIND
}
TEST(IR2VecVocabularyTest, VocabularyDimensions) {
@@ -627,10 +674,12 @@ TEST(IR2VecVocabularyTest, InvalidAccess) {
#endif // GTEST_HAS_DEATH_TEST
TEST(IR2VecVocabularyTest, TypeIDStringKeyMapping) {
+ Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest());
#define EXPECT_TYPE_TO_CANONICAL(TypeIDTok, CanonEnum, CanonStr) \
- EXPECT_EQ( \
- Vocabulary::getStringKey(Vocabulary::getSlotIndex(Type::TypeIDTok)), \
- CanonStr);
+ do { \
+ unsigned FlatIdx = V.getIndex(Type::TypeIDTok); \
+ EXPECT_EQ(Vocabulary::getStringKey(FlatIdx), CanonStr); \
+ } while (0);
IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_TYPE_TO_CANONICAL)
@@ -638,14 +687,20 @@ TEST(IR2VecVocabularyTest, TypeIDStringKeyMapping) {
}
TEST(IR2VecVocabularyTest, InvalidVocabularyConstruction) {
- std::vector<Embedding> InvalidVocab;
- InvalidVocab.push_back(Embedding(2, 1.0));
- InvalidVocab.push_back(Embedding(2, 2.0));
-
- Vocabulary V(std::move(InvalidVocab));
+ // Test 1: Create invalid VocabStorage with insufficient sections
+ std::vector<std::vector<Embedding>> InvalidSectionData;
+ // Only add one section with 2 embeddings, but the vocabulary needs 4 sections
+ std::vector<Embedding> Section1;
+ Section1.push_back(Embedding(2, 1.0));
+ Section1.push_back(Embedding(2, 2.0));
+ InvalidSectionData.push_back(std::move(Section1));
+
+ VocabStorage InvalidStorage(std::move(InvalidSectionData));
+ Vocabulary V(std::move(InvalidStorage));
EXPECT_FALSE(V.isValid());
{
+ // Test 2: Default-constructed vocabulary should be invalid
Vocabulary InvalidResult;
EXPECT_FALSE(InvalidResult.isValid());
#if GTEST_HAS_DEATH_TEST
@@ -656,4 +711,265 @@ TEST(IR2VecVocabularyTest, InvalidVocabularyConstruction) {
}
}
+TEST(VocabStorageTest, DefaultConstructor) {
+  // Nothing is ever handed to this storage object.
+  VocabStorage Storage;
+  EXPECT_EQ(Storage.size(), 0u);
+  EXPECT_EQ(Storage.getNumSections(), 0u);
+  EXPECT_EQ(Storage.getDimension(), 0u);
+  EXPECT_FALSE(Storage.isValid());
+
+  // begin() and end() must compare equal when nothing is stored.
+  EXPECT_EQ(Storage.begin(), Storage.end());
+}
+
+TEST(VocabStorageTest, BasicConstruction) {
+  // Build the per-section embedding lists directly inside the outer vector.
+  // Every embedding below has dimension 3; the sections hold 2, 1 and 3
+  // embeddings respectively.
+  std::vector<std::vector<Embedding>> Sections(3);
+
+  // Section 0: two embeddings.
+  Sections[0].emplace_back(std::vector<double>{1.0, 2.0, 3.0});
+  Sections[0].emplace_back(std::vector<double>{4.0, 5.0, 6.0});
+
+  // Section 1: a single embedding.
+  Sections[1].emplace_back(std::vector<double>{7.0, 8.0, 9.0});
+
+  // Section 2: three embeddings.
+  Sections[2].emplace_back(std::vector<double>{10.0, 11.0, 12.0});
+  Sections[2].emplace_back(std::vector<double>{13.0, 14.0, 15.0});
+  Sections[2].emplace_back(std::vector<double>{16.0, 17.0, 18.0});
+
+  // Pass the section data to the storage as an rvalue.
+  VocabStorage Storage(std::move(Sections));
+
+  // size() is expected to count embeddings across all sections: 2 + 1 + 3.
+  EXPECT_EQ(Storage.size(), 6u);
+  // One section is expected per inner vector supplied above.
+  EXPECT_EQ(Storage.getNumSections(), 3u);
+  // The dimension is expected to be the length shared by every embedding.
+  EXPECT_EQ(Storage.getDimension(), 3u);
+  EXPECT_TRUE(Storage.isValid());
+}
+
+TEST(VocabStorageTest, SectionAccess) {
+  // Two sections of dimension-2 embeddings, filled in place.
+  std::vector<std::vector<Embedding>> Sections(2);
+
+  // Section 0: two embeddings.
+  Sections[0].emplace_back(std::vector<double>{1.0, 2.0});
+  Sections[0].emplace_back(std::vector<double>{3.0, 4.0});
+
+  // Section 1: one embedding.
+  Sections[1].emplace_back(std::vector<double>{5.0, 6.0});
+
+  VocabStorage Storage(std::move(Sections));
+
+  // Indexing the storage with a section ID yields that section; its size()
+  // must match the number of embeddings added above.
+  EXPECT_EQ(Storage[0].size(), 2u);
+  EXPECT_EQ(Storage[1].size(), 1u);
+
+  // A second index picks an embedding inside the section; its data must come
+  // back unchanged and in insertion order.
+  EXPECT_THAT(Storage[0][0].getData(), ElementsAre(1.0, 2.0));
+  EXPECT_THAT(Storage[0][1].getData(), ElementsAre(3.0, 4.0));
+  EXPECT_THAT(Storage[1][0].getData(), ElementsAre(5.0, 6.0));
+}
+
+#if GTEST_HAS_DEATH_TEST
+#ifndef NDEBUG
+TEST(VocabStorageTest, InvalidSectionAccess) {
+  // A storage with exactly one section, so 0 is the only section ID in range.
+  std::vector<std::vector<Embedding>> Sections(1);
+  Sections[0].emplace_back(std::vector<double>{1.0, 2.0});
+  VocabStorage Storage(std::move(Sections));
+
+  // The first out-of-range ID, and one well beyond it, are both expected to
+  // kill the process with a message matching the section-ID diagnostic.
+  EXPECT_DEATH(Storage[1], "Invalid section ID");
+  EXPECT_DEATH(Storage[10], "Invalid section ID");
+}
+
+TEST(VocabStorageTest, EmptySection) {
+  // Section 0 is deliberately left without any embeddings.
+  std::vector<std::vector<Embedding>> Sections(2);
+
+  // Section 1 is well formed.
+  Sections[1].emplace_back(std::vector<double>{1.0});
+
+  // Constructing storage from this data is expected to die with a message
+  // matching the empty-section diagnostic.
+  EXPECT_DEATH(VocabStorage(std::move(Sections)),
+               "Vocabulary section is empty");
+}
+
+TEST(VocabStorageTest, EmptyMiddleSection) {
+  // Three sections are supplied; the one in the middle never receives an
+  // embedding.
+  std::vector<std::vector<Embedding>> Sections(3);
+
+  // Section 0: well formed.
+  Sections[0].emplace_back(std::vector<double>{1.0});
+
+  // Section 1: deliberately left empty.
+
+  // Section 2: well formed.
+  Sections[2].emplace_back(std::vector<double>{2.0});
+
+  // The empty section sits between two populated ones here, rather than at
+  // the front of the section list.
+  // Construction is still expected to die with a message matching the
+  // empty-section diagnostic.
+  EXPECT_DEATH(VocabStorage(std::move(Sections)),
+               "Vocabulary section is empty");
+}
+
+TEST(VocabStorageTest, NoSections) {
+  // An outer vector that contains no section vectors at all.
+  std::vector<std::vector<Embedding>> Sections;
+  EXPECT_DEATH(VocabStorage(std::move(Sections)),
+               "Vocabulary has no sections");
+}
+
+TEST(VocabStorageTest, MismatchedDimensionsAcrossSections) {
+  // Each section is consistent on its own, but the two sections use
+  // different embedding lengths.
+  std::vector<std::vector<Embedding>> Sections(2);
+
+  // Section 0: two embeddings of dimension 2.
+  Sections[0].emplace_back(std::vector<double>{1.0, 2.0});
+  Sections[0].emplace_back(std::vector<double>{3.0, 4.0});
+
+  // Section 1: one embedding of dimension 3, which disagrees with section 0.
+  Sections[1].emplace_back(std::vector<double>{5.0, 6.0, 7.0});
+
+  // Construction is expected to die with a message matching the
+  // dimension-mismatch diagnostic.
+  EXPECT_DEATH(VocabStorage(std::move(Sections)),
+               "All embeddings must have the same dimension");
+}
+
+TEST(VocabStorageTest, MismatchedDimensionsWithinSection) {
+  // A single section whose two embeddings disagree in length.
+  std::vector<std::vector<Embedding>> Sections(1);
+
+  Sections[0].emplace_back(std::vector<double>{1.0, 2.0});      // dimension 2
+  Sections[0].emplace_back(std::vector<double>{3.0, 4.0, 5.0}); // dimension 3
+
+  // Construction is expected to die with a message matching the
+  // dimension-mismatch diagnostic.
+  EXPECT_DEATH(VocabStorage(std::move(Sections)),
+               "All embeddings must have the same dimension");
+}
+#endif // NDEBUG
+#endif // GTEST_HAS_DEATH_TEST
+
+TEST(VocabStorageTest, IteratorBasics) {
+  // Section 0 holds two embeddings and section 1 holds one, so a full
+  // traversal visits three embeddings and crosses one section boundary.
+  std::vector<std::vector<Embedding>> Sections(2);
+  // Section 0.
+  Sections[0].emplace_back(std::vector<double>{1.0, 2.0});
+  Sections[0].emplace_back(std::vector<double>{3.0, 4.0});
+
+  // Section 1.
+  Sections[1].emplace_back(std::vector<double>{5.0, 6.0});
+
+  VocabStorage Storage(std::move(Sections));
+
+  // Cursor walks the storage; Last is the past-the-end position.
+  auto Cursor = Storage.begin();
+  auto Last = Storage.end();
+
+  // The storage is not empty, so the range must not start at its end.
+  EXPECT_NE(Cursor, Last);
+
+  // Position 0: first embedding of section 0.
+  EXPECT_THAT((*Cursor).getData(), ElementsAre(1.0, 2.0));
+
+  // Position 1: second embedding of section 0.
+  ++Cursor;
+  EXPECT_NE(Cursor, Last);
+  EXPECT_THAT((*Cursor).getData(), ElementsAre(3.0, 4.0));
+
+  // Position 2: the increment has to carry over into section 1.
+  ++Cursor;
+  EXPECT_NE(Cursor, Last);
+  EXPECT_THAT((*Cursor).getData(), ElementsAre(5.0, 6.0));
+
+  // One more increment exhausts the range.
+  ++Cursor;
+  EXPECT_EQ(Cursor, Last);
+}
+
+TEST(VocabStorageTest, IteratorTraversal) {
+  // One-dimensional embeddings spread over three sections (2, 1 and 3
+  // entries), so the scalar values alone identify the visiting order.
+  std::vector<std::vector<Embedding>> Sections(3);
+  // Section 0.
+  Sections[0].emplace_back(std::vector<double>{10.0});
+  Sections[0].emplace_back(std::vector<double>{20.0});
+
+  // Section 1.
+  Sections[1].emplace_back(std::vector<double>{25.0});
+
+  // Section 2.
+  Sections[2].emplace_back(std::vector<double>{30.0});
+  Sections[2].emplace_back(std::vector<double>{40.0});
+  Sections[2].emplace_back(std::vector<double>{50.0});
+
+  VocabStorage Storage(std::move(Sections));
+
+  // Walk the whole storage by hand and record the single component of each
+  // embedding in the order it is visited.
+  std::vector<double> Visited;
+  auto It = Storage.begin();
+  auto End = Storage.end();
+  for (; It != End; ++It) {
+    const auto &Current = *It;
+    EXPECT_EQ(Current.size(), 1u);
+    Visited.push_back(Current[0]);
+  }
+
+  // Iteration must flatten the sections: section by section, and in
+  // insertion order within each section.
+  EXPECT_THAT(Visited, ElementsAre(10.0, 20.0, 25.0, 30.0, 40.0, 50.0));
+}
+
+TEST(VocabStorageTest, IteratorComparison) {
+  // A single section with two embeddings gives three distinct iterator
+  // positions: first, second and past-the-end.
+  std::vector<std::vector<Embedding>> Sections(1);
+  Sections[0].emplace_back(std::vector<double>{1.0});
+  Sections[0].emplace_back(std::vector<double>{2.0});
+  VocabStorage Storage(std::move(Sections));
+
+  // Lead and Follow are separate iterators obtained from the same storage.
+  auto Lead = Storage.begin();
+  auto Follow = Storage.begin();
+  auto Last = Storage.end();
+
+  // Two fresh begin() iterators equal each other and differ from end().
+  EXPECT_EQ(Lead, Follow);
+  EXPECT_NE(Lead, Last);
+
+  // After Lead moves ahead alone it matches neither Follow nor end().
+  ++Lead;
+  EXPECT_NE(Lead, Follow);
+  EXPECT_NE(Lead, Last);
+
+  // Once Follow catches up, the two compare equal again.
+  ++Follow;
+  EXPECT_EQ(Lead, Follow);
+
+  // A final step on each puts both at end(), where they still agree.
+  ++Lead;
+  ++Follow;
+  EXPECT_EQ(Lead, Last);
+  EXPECT_EQ(Follow, Last);
+  EXPECT_EQ(Lead, Follow);
+}
+
} // end anonymous namespace