aboutsummaryrefslogtreecommitdiff
path: root/llvm/unittests/CodeGen/MIR2VecTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/unittests/CodeGen/MIR2VecTest.cpp')
-rw-r--r--llvm/unittests/CodeGen/MIR2VecTest.cpp299
1 files changed, 275 insertions, 24 deletions
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index 11222b4..8710d6b 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -82,6 +82,9 @@ protected:
return;
}
+ // Set the data layout to match the target machine
+ M->setDataLayout(TM->createDataLayout());
+
// Create a dummy function to get subtarget info
FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false);
Function *F =
@@ -96,16 +99,27 @@ protected:
}
void TearDown() override { TII = nullptr; }
-};
-// Function to find an opcode by name
-static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) {
- for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
- if (TII->getName(Opcode) == Name)
- return Opcode;
+ // Find an opcode by name
+ int findOpcodeByName(StringRef Name) {
+ for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
+ if (TII->getName(Opcode) == Name)
+ return Opcode;
+ }
+ return -1; // Not found
}
- return -1; // Not found
-}
+
+ // Create a vocabulary with specific opcodes and embeddings
+ Expected<MIRVocabulary>
+ createTestVocab(std::initializer_list<std::pair<const char *, float>> opcodes,
+ unsigned dimension = 2) {
+ assert(TII && "TargetInstrInfo not initialized");
+ VocabMap VMap;
+ for (const auto &[name, value] : opcodes)
+ VMap[name] = Embedding(dimension, value);
+ return MIRVocabulary::create(std::move(VMap), *TII);
+ }
+};
TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
// Test that same base opcodes get same canonical indices
@@ -118,10 +132,8 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
// Create a MIRVocabulary instance to test the mapping
// Use a minimal MIRVocabulary to trigger canonical mapping construction
- VocabMap VMap;
Embedding Val = Embedding(64, 1.0f);
- VMap["ADD"] = Val;
- auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, 64);
ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
<< "Failed to create vocabulary: "
<< toString(TestVocabOrErr.takeError());
@@ -156,16 +168,16 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
6880u); // X86 has >6880 unique base opcodes
// Check that the embeddings for opcodes not in the vocab are zero vectors
- int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr");
+ int Add32rrOpcode = findOpcodeByName("ADD32rr");
ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found";
EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val));
- int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr");
+ int Sub32rrOpcode = findOpcodeByName("SUB32rr");
ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found";
EXPECT_TRUE(
TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
- int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr");
+ int Mov32rrOpcode = findOpcodeByName("MOV32rr");
ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found";
EXPECT_TRUE(
TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
@@ -178,9 +190,7 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Create a MIRVocabulary instance to test deterministic mapping
// Use a minimal MIRVocabulary to trigger canonical mapping construction
- VocabMap VMap;
- VMap["ADD"] = Embedding(64, 1.0f);
- auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, 64);
ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
<< "Failed to create vocabulary: "
<< toString(TestVocabOrErr.takeError());
@@ -189,8 +199,6 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
unsigned Index3 = TestVocab.getCanonicalIndexForBaseName(BaseName);
-
- EXPECT_EQ(Index1, Index2);
EXPECT_EQ(Index2, Index3);
// Test across multiple runs
@@ -202,11 +210,7 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Test MIRVocabulary construction
TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
- VocabMap VMap;
- VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
- VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
-
- auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ auto VocabOrErr = createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, 128);
ASSERT_TRUE(static_cast<bool>(VocabOrErr))
<< "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
auto &Vocab = *VocabOrErr;
@@ -243,4 +247,251 @@ TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
}
}
+// Fixture for embedding related tests
+class MIR2VecEmbeddingTestFixture : public MIR2VecVocabTestFixture {
+protected:
+ std::unique_ptr<MachineModuleInfo> MMI;
+ MachineFunction *MF = nullptr;
+
+ void SetUp() override {
+ MIR2VecVocabTestFixture::SetUp();
+ // If base class setup was skipped (TII not initialized), skip derived setup
+ if (!TII)
+ GTEST_SKIP() << "Failed to get target instruction info in "
+ "the base class setup; Skipping test";
+
+ // Create a dummy function for MachineFunction
+ FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false);
+ Function *F =
+ Function::Create(FT, Function::ExternalLinkage, "test", M.get());
+
+ MMI = std::make_unique<MachineModuleInfo>(TM.get());
+ MF = &MMI->getOrCreateMachineFunction(*F);
+ }
+
+ void TearDown() override { MIR2VecVocabTestFixture::TearDown(); }
+
+ // Create a machine instruction
+ MachineInstr *createMachineInstr(MachineBasicBlock &MBB, unsigned Opcode) {
+ const MCInstrDesc &Desc = TII->get(Opcode);
+ // Create instruction - operands don't affect opcode-based embeddings
+ MachineInstr *MI = BuildMI(MBB, MBB.end(), DebugLoc(), Desc);
+ return MI;
+ }
+
+ MachineInstr *createMachineInstr(MachineBasicBlock &MBB,
+ const char *OpcodeName) {
+ int Opcode = findOpcodeByName(OpcodeName);
+ if (Opcode == -1)
+ return nullptr;
+ return createMachineInstr(MBB, Opcode);
+ }
+
+ void createMachineInstrs(MachineBasicBlock &MBB,
+ std::initializer_list<const char *> Opcodes) {
+ for (const char *OpcodeName : Opcodes) {
+ MachineInstr *MI = createMachineInstr(MBB, OpcodeName);
+ ASSERT_TRUE(MI != nullptr);
+ }
+ }
+};
+
+// Test factory method for creating embedder
+TEST_F(MIR2VecEmbeddingTestFixture, CreateSymbolicEmbedder) {
+ auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 1);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &V = *VocabOrErr;
+ auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, *MF, V);
+ EXPECT_NE(Emb, nullptr);
+}
+
+TEST_F(MIR2VecEmbeddingTestFixture, CreateInvalidMode) {
+ auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 1);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &V = *VocabOrErr;
+ auto Result = MIREmbedder::create(static_cast<MIR2VecKind>(-1), *MF, V);
+ EXPECT_FALSE(static_cast<bool>(Result));
+}
+
+// Test SymbolicMIREmbedder with simple target opcodes
+TEST_F(MIR2VecEmbeddingTestFixture, TestSymbolicEmbedder) {
+ // Create a test vocabulary with specific values
+ auto VocabOrErr = createTestVocab(
+ {
+ {"NOOP", 1.0f}, // [1.0, 1.0, 1.0, 1.0]
+ {"RET", 2.0f}, // [2.0, 2.0, 2.0, 2.0]
+ {"TRAP", 3.0f} // [3.0, 3.0, 3.0, 3.0]
+ },
+ 4);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
+ // Create a basic block using fixture's MF
+ MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
+ MF->push_back(MBB);
+
+ // Use real X86 opcodes that should exist and not be pseudo
+ auto NoopInst = createMachineInstr(*MBB, "NOOP");
+ ASSERT_TRUE(NoopInst != nullptr);
+
+ auto RetInst = createMachineInstr(*MBB, "RET64");
+ ASSERT_TRUE(RetInst != nullptr);
+
+ auto TrapInst = createMachineInstr(*MBB, "TRAP");
+ ASSERT_TRUE(TrapInst != nullptr);
+
+ // Verify these are not pseudo instructions
+ ASSERT_FALSE(NoopInst->isPseudo()) << "NOOP is marked as pseudo instruction";
+ ASSERT_FALSE(RetInst->isPseudo()) << "RET is marked as pseudo instruction";
+ ASSERT_FALSE(TrapInst->isPseudo()) << "TRAP is marked as pseudo instruction";
+
+ // Create embedder
+ auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+ ASSERT_TRUE(Embedder != nullptr);
+
+ // Test instruction embeddings
+ auto NoopEmb = Embedder->getMInstVector(*NoopInst);
+ auto RetEmb = Embedder->getMInstVector(*RetInst);
+ auto TrapEmb = Embedder->getMInstVector(*TrapInst);
+
+ // Verify embeddings match expected values (accounting for weight scaling)
+ float ExpectedWeight = mir2vec::OpcWeight; // Global weight from command line
+ EXPECT_TRUE(NoopEmb.approximatelyEquals(Embedding(4, 1.0f * ExpectedWeight)));
+ EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(4, 2.0f * ExpectedWeight)));
+ EXPECT_TRUE(TrapEmb.approximatelyEquals(Embedding(4, 3.0f * ExpectedWeight)));
+
+ // Test basic block embedding (should be sum of instruction embeddings)
+ auto MBBVector = Embedder->getMBBVector(*MBB);
+
+ // Expected BB vector: NOOP + RET + TRAP = [1+2+3, 1+2+3, 1+2+3, 1+2+3] *
+ // weight = [6, 6, 6, 6] * weight
+ Embedding ExpectedMBBVector(4, 6.0f * ExpectedWeight);
+ EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedMBBVector));
+
+ // Test function embedding (should equal MBB embedding since we have one MBB)
+ auto MFuncVector = Embedder->getMFunctionVector();
+ EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedMBBVector));
+}
+
+// Test embedder with multiple basic blocks
+TEST_F(MIR2VecEmbeddingTestFixture, MultipleBasicBlocks) {
+ // Create a test vocabulary
+ auto VocabOrErr = createTestVocab({{"NOOP", 1.0f}, {"TRAP", 2.0f}});
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
+
+ // Create two basic blocks using fixture's MF
+ MachineBasicBlock *MBB1 = MF->CreateMachineBasicBlock();
+ MachineBasicBlock *MBB2 = MF->CreateMachineBasicBlock();
+ MF->push_back(MBB1);
+ MF->push_back(MBB2);
+
+ createMachineInstrs(*MBB1, {"NOOP", "NOOP"});
+ createMachineInstr(*MBB2, "TRAP");
+
+ // Create embedder
+ auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+ ASSERT_TRUE(Embedder != nullptr);
+
+ // Test basic block embeddings
+ auto MBB1Vector = Embedder->getMBBVector(*MBB1);
+ auto MBB2Vector = Embedder->getMBBVector(*MBB2);
+
+ float ExpectedWeight = mir2vec::OpcWeight;
+ // BB1: NOOP + NOOP = 2 * ([1, 1] * weight)
+ Embedding ExpectedMBB1Vector(2, 2.0f * ExpectedWeight);
+ EXPECT_TRUE(MBB1Vector.approximatelyEquals(ExpectedMBB1Vector));
+
+ // BB2: TRAP = [2, 2] * weight
+ Embedding ExpectedMBB2Vector(2, 2.0f * ExpectedWeight);
+ EXPECT_TRUE(MBB2Vector.approximatelyEquals(ExpectedMBB2Vector));
+
+ // Function embedding: BB1 + BB2 = [2+2, 2+2] * weight = [4, 4] * weight
+ // Function embedding should be just the first BB embedding as the second BB
+ // is unreachable
+ auto MFuncVector = Embedder->getMFunctionVector();
+ EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedMBB1Vector));
+
+ // Add a branch from BB1 to BB2 to make both reachable; now function embedding
+ // should be MBB1 + MBB2
+ MBB1->addSuccessor(MBB2);
+ auto NewMFuncVector = Embedder->getMFunctionVector(); // Recompute embeddings
+ Embedding ExpectedFuncVector = MBB1Vector + MBB2Vector;
+ EXPECT_TRUE(NewMFuncVector.approximatelyEquals(ExpectedFuncVector));
+}
+
+// Test embedder with empty basic block
+TEST_F(MIR2VecEmbeddingTestFixture, EmptyBasicBlock) {
+
+ // Create an empty basic block
+ MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
+ MF->push_back(MBB);
+
+ // Create embedder
+ auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 2);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &V = *VocabOrErr;
+ auto Embedder = SymbolicMIREmbedder::create(*MF, V);
+ ASSERT_TRUE(Embedder != nullptr);
+
+ // Test that empty BB has zero embedding
+ auto MBBVector = Embedder->getMBBVector(*MBB);
+ Embedding ExpectedBBVector(2, 0.0f);
+ EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedBBVector));
+
+ // Function embedding should also be zero
+ auto MFuncVector = Embedder->getMFunctionVector();
+ EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedBBVector));
+}
+
+// Test embedder with opcodes not in vocabulary
+TEST_F(MIR2VecEmbeddingTestFixture, UnknownOpcodes) {
+ // Create a test vocabulary with limited entries
+ // SUB is intentionally not included
+ auto VocabOrErr = createTestVocab({{"ADD", 1.0f}});
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
+
+ // Create a basic block
+ MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
+ MF->push_back(MBB);
+
+ // Find opcodes
+ int AddOpcode = findOpcodeByName("ADD32rr");
+ int SubOpcode = findOpcodeByName("SUB32rr");
+
+ ASSERT_NE(AddOpcode, -1) << "ADD32rr opcode not found";
+ ASSERT_NE(SubOpcode, -1) << "SUB32rr opcode not found";
+
+ // Create instructions
+ MachineInstr *AddInstr = createMachineInstr(*MBB, AddOpcode);
+ MachineInstr *SubInstr = createMachineInstr(*MBB, SubOpcode);
+
+ // Create embedder
+ auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+ ASSERT_TRUE(Embedder != nullptr);
+
+ // Test instruction embeddings
+ auto AddVector = Embedder->getMInstVector(*AddInstr);
+ auto SubVector = Embedder->getMInstVector(*SubInstr);
+
+ float ExpectedWeight = mir2vec::OpcWeight;
+ // ADD should have the embedding from vocabulary
+ EXPECT_TRUE(
+ AddVector.approximatelyEquals(Embedding(2, 1.0f * ExpectedWeight)));
+
+ // SUB should have zero embedding (not in vocabulary)
+ EXPECT_TRUE(SubVector.approximatelyEquals(Embedding(2, 0.0f)));
+
+ // Basic block embedding should be ADD + SUB = [1.0, 1.0] * weight + [0.0,
+ // 0.0] = [1.0, 1.0] * weight
+ const auto &MBBVector = Embedder->getMBBVector(*MBB);
+ Embedding ExpectedBBVector(2, 1.0f * ExpectedWeight);
+ EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedBBVector));
+}
} // namespace