diff options
Diffstat (limited to 'llvm/unittests/CodeGen/MIR2VecTest.cpp')
-rw-r--r-- | llvm/unittests/CodeGen/MIR2VecTest.cpp | 299 |
1 files changed, 275 insertions, 24 deletions
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp index 11222b4..8710d6b 100644 --- a/llvm/unittests/CodeGen/MIR2VecTest.cpp +++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp @@ -82,6 +82,9 @@ protected: return; } + // Set the data layout to match the target machine + M->setDataLayout(TM->createDataLayout()); + // Create a dummy function to get subtarget info FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false); Function *F = @@ -96,16 +99,27 @@ protected: } void TearDown() override { TII = nullptr; } -}; -// Function to find an opcode by name -static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) { - for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) { - if (TII->getName(Opcode) == Name) - return Opcode; + // Find an opcode by name + int findOpcodeByName(StringRef Name) { + for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) { + if (TII->getName(Opcode) == Name) + return Opcode; + } + return -1; // Not found } - return -1; // Not found -} + + // Create a vocabulary with specific opcodes and embeddings + Expected<MIRVocabulary> + createTestVocab(std::initializer_list<std::pair<const char *, float>> opcodes, + unsigned dimension = 2) { + assert(TII && "TargetInstrInfo not initialized"); + VocabMap VMap; + for (const auto &[name, value] : opcodes) + VMap[name] = Embedding(dimension, value); + return MIRVocabulary::create(std::move(VMap), *TII); + } +}; TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { // Test that same base opcodes get same canonical indices @@ -118,10 +132,8 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { // Create a MIRVocabulary instance to test the mapping // Use a minimal MIRVocabulary to trigger canonical mapping construction - VocabMap VMap; Embedding Val = Embedding(64, 1.0f); - VMap["ADD"] = Val; - auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII); + auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, 64); ASSERT_TRUE(static_cast<bool>(TestVocabOrErr)) << "Failed to create vocabulary: " << toString(TestVocabOrErr.takeError()); @@ -156,16 +168,16 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { 6880u); // X86 has >6880 unique base opcodes // Check that the embeddings for opcodes not in the vocab are zero vectors - int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr"); + int Add32rrOpcode = findOpcodeByName("ADD32rr"); ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found"; EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val)); - int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr"); + int Sub32rrOpcode = findOpcodeByName("SUB32rr"); ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found"; EXPECT_TRUE( TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f))); - int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr"); + int Mov32rrOpcode = findOpcodeByName("MOV32rr"); ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found"; EXPECT_TRUE( TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f))); @@ -178,9 +190,7 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Create a MIRVocabulary instance to test deterministic mapping // Use a minimal MIRVocabulary to trigger canonical mapping construction - VocabMap VMap; - VMap["ADD"] = Embedding(64, 1.0f); - auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII); + auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, 64); ASSERT_TRUE(static_cast<bool>(TestVocabOrErr)) << "Failed to create vocabulary: " << toString(TestVocabOrErr.takeError()); @@ -189,8 +199,6 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName); unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName); unsigned Index3 = TestVocab.getCanonicalIndexForBaseName(BaseName); - - EXPECT_EQ(Index1, Index2); EXPECT_EQ(Index2, Index3); // Test across multiple runs @@ -202,11 +210,7 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Test MIRVocabulary construction TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) { - VocabMap VMap; - VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0 - VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0 - - auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII); + auto VocabOrErr = createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, 128); ASSERT_TRUE(static_cast<bool>(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &Vocab = *VocabOrErr; @@ -243,4 +247,251 @@ TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) { } } +// Fixture for embedding related tests +class MIR2VecEmbeddingTestFixture : public MIR2VecVocabTestFixture { +protected: + std::unique_ptr<MachineModuleInfo> MMI; + MachineFunction *MF = nullptr; + + void SetUp() override { + MIR2VecVocabTestFixture::SetUp(); + // If base class setup was skipped (TII not initialized), skip derived setup + if (!TII) + GTEST_SKIP() << "Failed to get target instruction info in " + "the base class setup; Skipping test"; + + // Create a dummy function for MachineFunction + FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false); + Function *F = + Function::Create(FT, Function::ExternalLinkage, "test", M.get()); + + MMI = std::make_unique<MachineModuleInfo>(TM.get()); + MF = &MMI->getOrCreateMachineFunction(*F); + } + + void TearDown() override { MIR2VecVocabTestFixture::TearDown(); } + + // Create a machine instruction + MachineInstr *createMachineInstr(MachineBasicBlock &MBB, unsigned Opcode) { + const MCInstrDesc &Desc = TII->get(Opcode); + // Create instruction - operands don't affect opcode-based embeddings + MachineInstr *MI = BuildMI(MBB, MBB.end(), DebugLoc(), Desc); + return MI; + } + + MachineInstr *createMachineInstr(MachineBasicBlock &MBB, + const char *OpcodeName) { + int Opcode = findOpcodeByName(OpcodeName); + if (Opcode == -1) + return nullptr; + return createMachineInstr(MBB, Opcode); + } + + void createMachineInstrs(MachineBasicBlock &MBB, + std::initializer_list<const char *> Opcodes) { + for (const char *OpcodeName : Opcodes) { + MachineInstr *MI = createMachineInstr(MBB, OpcodeName); + ASSERT_TRUE(MI != nullptr); + } + } +}; + +// Test factory method for creating embedder +TEST_F(MIR2VecEmbeddingTestFixture, CreateSymbolicEmbedder) { + auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 1); + ASSERT_TRUE(static_cast<bool>(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &V = *VocabOrErr; + auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, *MF, V); + EXPECT_NE(Emb, nullptr); +} + +TEST_F(MIR2VecEmbeddingTestFixture, CreateInvalidMode) { + auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 1); + ASSERT_TRUE(static_cast<bool>(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &V = *VocabOrErr; + auto Result = MIREmbedder::create(static_cast<MIR2VecKind>(-1), *MF, V); + EXPECT_FALSE(static_cast<bool>(Result)); +} + +// Test SymbolicMIREmbedder with simple target opcodes +TEST_F(MIR2VecEmbeddingTestFixture, TestSymbolicEmbedder) { + // Create a test vocabulary with specific values + auto VocabOrErr = createTestVocab( + { + {"NOOP", 1.0f}, // [1.0, 1.0, 1.0, 1.0] + {"RET", 2.0f}, // [2.0, 2.0, 2.0, 2.0] + {"TRAP", 3.0f} // [3.0, 3.0, 3.0, 3.0] + }, + 4); + ASSERT_TRUE(static_cast<bool>(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &Vocab = *VocabOrErr; + // Create a basic block using fixture's MF + MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); + MF->push_back(MBB); + + // Use real X86 opcodes that should exist and not be pseudo + auto NoopInst = createMachineInstr(*MBB, "NOOP"); + ASSERT_TRUE(NoopInst != nullptr); + + auto RetInst = createMachineInstr(*MBB, "RET64"); + ASSERT_TRUE(RetInst != nullptr); + + auto TrapInst = createMachineInstr(*MBB, "TRAP"); + ASSERT_TRUE(TrapInst != nullptr); + + // Verify these are not pseudo instructions + ASSERT_FALSE(NoopInst->isPseudo()) << "NOOP is marked as pseudo instruction"; + ASSERT_FALSE(RetInst->isPseudo()) << "RET is marked as pseudo instruction"; + ASSERT_FALSE(TrapInst->isPseudo()) << "TRAP is marked as pseudo instruction"; + + // Create embedder + auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); + ASSERT_TRUE(Embedder != nullptr); + + // Test instruction embeddings + auto NoopEmb = Embedder->getMInstVector(*NoopInst); + auto RetEmb = Embedder->getMInstVector(*RetInst); + auto TrapEmb = Embedder->getMInstVector(*TrapInst); + + // Verify embeddings match expected values (accounting for weight scaling) + float ExpectedWeight = mir2vec::OpcWeight; // Global weight from command line + EXPECT_TRUE(NoopEmb.approximatelyEquals(Embedding(4, 1.0f * ExpectedWeight))); + EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(4, 2.0f * ExpectedWeight))); + EXPECT_TRUE(TrapEmb.approximatelyEquals(Embedding(4, 3.0f * ExpectedWeight))); + + // Test basic block embedding (should be sum of instruction embeddings) + auto MBBVector = Embedder->getMBBVector(*MBB); + + // Expected BB vector: NOOP + RET + TRAP = [1+2+3, 1+2+3, 1+2+3, 1+2+3] * + // weight = [6, 6, 6, 6] * weight + Embedding ExpectedMBBVector(4, 6.0f * ExpectedWeight); + EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedMBBVector)); + + // Test function embedding (should equal MBB embedding since we have one MBB) + auto MFuncVector = Embedder->getMFunctionVector(); + EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedMBBVector)); +} + +// Test embedder with multiple basic blocks +TEST_F(MIR2VecEmbeddingTestFixture, MultipleBasicBlocks) { + // Create a test vocabulary + auto VocabOrErr = createTestVocab({{"NOOP", 1.0f}, {"TRAP", 2.0f}}); + ASSERT_TRUE(static_cast<bool>(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &Vocab = *VocabOrErr; + + // Create two basic blocks using fixture's MF + MachineBasicBlock *MBB1 = MF->CreateMachineBasicBlock(); + MachineBasicBlock *MBB2 = MF->CreateMachineBasicBlock(); + MF->push_back(MBB1); + MF->push_back(MBB2); + + createMachineInstrs(*MBB1, {"NOOP", "NOOP"}); + createMachineInstr(*MBB2, "TRAP"); + + // Create embedder + auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); + ASSERT_TRUE(Embedder != nullptr); + + // Test basic block embeddings + auto MBB1Vector = Embedder->getMBBVector(*MBB1); + auto MBB2Vector = Embedder->getMBBVector(*MBB2); + + float ExpectedWeight = mir2vec::OpcWeight; + // BB1: NOOP + NOOP = 2 * ([1, 1] * weight) + Embedding ExpectedMBB1Vector(2, 2.0f * ExpectedWeight); + EXPECT_TRUE(MBB1Vector.approximatelyEquals(ExpectedMBB1Vector)); + + // BB2: TRAP = [2, 2] * weight + Embedding ExpectedMBB2Vector(2, 2.0f * ExpectedWeight); + EXPECT_TRUE(MBB2Vector.approximatelyEquals(ExpectedMBB2Vector)); + + // Function embedding: BB1 + BB2 = [2+2, 2+2] * weight = [4, 4] * weight + // Function embedding should be just the first BB embedding as the second BB + // is unreachable + auto MFuncVector = Embedder->getMFunctionVector(); + EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedMBB1Vector)); + + // Add a branch from BB1 to BB2 to make both reachable; now function embedding + // should be MBB1 + MBB2 + MBB1->addSuccessor(MBB2); + auto NewMFuncVector = Embedder->getMFunctionVector(); // Recompute embeddings + Embedding ExpectedFuncVector = MBB1Vector + MBB2Vector; + EXPECT_TRUE(NewMFuncVector.approximatelyEquals(ExpectedFuncVector)); +} + +// Test embedder with empty basic block +TEST_F(MIR2VecEmbeddingTestFixture, EmptyBasicBlock) { + + // Create an empty basic block + MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); + MF->push_back(MBB); + + // Create embedder + auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 2); + ASSERT_TRUE(static_cast<bool>(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &V = *VocabOrErr; + auto Embedder = SymbolicMIREmbedder::create(*MF, V); + ASSERT_TRUE(Embedder != nullptr); + + // Test that empty BB has zero embedding + auto MBBVector = Embedder->getMBBVector(*MBB); + Embedding ExpectedBBVector(2, 0.0f); + EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedBBVector)); + + // Function embedding should also be zero + auto MFuncVector = Embedder->getMFunctionVector(); + EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedBBVector)); +} + +// Test embedder with opcodes not in vocabulary +TEST_F(MIR2VecEmbeddingTestFixture, UnknownOpcodes) { + // Create a test vocabulary with limited entries + // SUB is intentionally not included + auto VocabOrErr = createTestVocab({{"ADD", 1.0f}}); + ASSERT_TRUE(static_cast<bool>(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &Vocab = *VocabOrErr; + + // Create a basic block + MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); + MF->push_back(MBB); + + // Find opcodes + int AddOpcode = findOpcodeByName("ADD32rr"); + int SubOpcode = findOpcodeByName("SUB32rr"); + + ASSERT_NE(AddOpcode, -1) << "ADD32rr opcode not found"; + ASSERT_NE(SubOpcode, -1) << "SUB32rr opcode not found"; + + // Create instructions + MachineInstr *AddInstr = createMachineInstr(*MBB, AddOpcode); + MachineInstr *SubInstr = createMachineInstr(*MBB, SubOpcode); + + // Create embedder + auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); + ASSERT_TRUE(Embedder != nullptr); + + // Test instruction embeddings + auto AddVector = Embedder->getMInstVector(*AddInstr); + auto SubVector = Embedder->getMInstVector(*SubInstr); + + float ExpectedWeight = mir2vec::OpcWeight; + // ADD should have the embedding from vocabulary + EXPECT_TRUE( + AddVector.approximatelyEquals(Embedding(2, 1.0f * ExpectedWeight))); + + // SUB should have zero embedding (not in vocabulary) + EXPECT_TRUE(SubVector.approximatelyEquals(Embedding(2, 0.0f))); + + // Basic block embedding should be ADD + SUB = [1.0, 1.0] * weight + [0.0, + // 0.0] = [1.0, 1.0] * weight + const auto &MBBVector = Embedder->getMBBVector(*MBB); + Embedding ExpectedBBVector(2, 1.0f * ExpectedWeight); + EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedBBVector)); +} } // namespace |