aboutsummaryrefslogtreecommitdiff
path: root/llvm/unittests/Analysis/IR2VecTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/unittests/Analysis/IR2VecTest.cpp')
-rw-r--r--llvm/unittests/Analysis/IR2VecTest.cpp63
1 files changed, 63 insertions, 0 deletions
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index cb6d633..7c9a546 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -396,6 +396,69 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
}
}
+TEST(IR2VecVocabularyTest, NumericIDMap) {
+ // Test getNumericID for opcodes
+ EXPECT_EQ(Vocabulary::getNumericID(1u), 0u);
+ EXPECT_EQ(Vocabulary::getNumericID(13u), 12u);
+ EXPECT_EQ(Vocabulary::getNumericID(MaxOpcodes), MaxOpcodes - 1);
+
+ // Test getNumericID for Type IDs
+ EXPECT_EQ(Vocabulary::getNumericID(Type::VoidTyID),
+ MaxOpcodes + static_cast<unsigned>(Type::VoidTyID));
+ EXPECT_EQ(Vocabulary::getNumericID(Type::HalfTyID),
+ MaxOpcodes + static_cast<unsigned>(Type::HalfTyID));
+ EXPECT_EQ(Vocabulary::getNumericID(Type::FloatTyID),
+ MaxOpcodes + static_cast<unsigned>(Type::FloatTyID));
+ EXPECT_EQ(Vocabulary::getNumericID(Type::IntegerTyID),
+ MaxOpcodes + static_cast<unsigned>(Type::IntegerTyID));
+ EXPECT_EQ(Vocabulary::getNumericID(Type::PointerTyID),
+ MaxOpcodes + static_cast<unsigned>(Type::PointerTyID));
+
+ // Test getNumericID for Value operands
+ LLVMContext Ctx;
+ Module M("TestM", Ctx);
+ FunctionType *FTy =
+ FunctionType::get(Type::getVoidTy(Ctx), {Type::getInt32Ty(Ctx)}, false);
+ Function *F = Function::Create(FTy, Function::ExternalLinkage, "testFunc", M);
+
+ // Test Function operand
+ EXPECT_EQ(Vocabulary::getNumericID(F),
+ MaxOpcodes + MaxTypeIDs + 0u); // Function = 0
+
+ // Test Constant operand
+ Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
+ EXPECT_EQ(Vocabulary::getNumericID(C),
+ MaxOpcodes + MaxTypeIDs + 2u); // Constant = 2
+
+ // Test Pointer operand
+ BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
+ AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB);
+ EXPECT_EQ(Vocabulary::getNumericID(PtrVal),
+ MaxOpcodes + MaxTypeIDs + 1u); // Pointer = 1
+
+ // Test Variable operand (function argument)
+ Argument *Arg = F->getArg(0);
+ EXPECT_EQ(Vocabulary::getNumericID(Arg),
+ MaxOpcodes + MaxTypeIDs + 3u); // Variable = 3
+}
+
+#if GTEST_HAS_DEATH_TEST
+#ifndef NDEBUG
+TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) {
+ // Test invalid opcode IDs
+ EXPECT_DEATH(Vocabulary::getNumericID(0u), "Invalid opcode");
+ EXPECT_DEATH(Vocabulary::getNumericID(MaxOpcodes + 1), "Invalid opcode");
+
+ // Test invalid type IDs
+ EXPECT_DEATH(Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs)),
+ "Invalid type ID");
+ EXPECT_DEATH(
+ Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
+ "Invalid type ID");
+}
+#endif // NDEBUG
+#endif // GTEST_HAS_DEATH_TEST
+
TEST(IR2VecVocabularyTest, StringKeyGeneration) {
EXPECT_EQ(Vocabulary::getStringKey(0), "Ret");
EXPECT_EQ(Vocabulary::getStringKey(12), "Add");