//===- MIR2VecTest.cpp ---------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "llvm/CodeGen/MIR2Vec.h" #include "llvm/CodeGen/MachineBasicBlock.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineInstr.h" #include "llvm/CodeGen/MachineModuleInfo.h" #include "llvm/CodeGen/TargetInstrInfo.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" #include "llvm/TargetParser/Triple.h" #include "gtest/gtest.h" using namespace llvm; using namespace mir2vec; using VocabMap = std::map; namespace { TEST(MIR2VecTest, RegexExtraction) { // Test simple instruction names EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("NOP"), "NOP"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("RET"), "RET"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("ADD16ri"), "ADD"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("ADD32rr"), "ADD"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("ADD64rm"), "ADD"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("MOV8ri"), "MOV"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("MOV32mr"), "MOV"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("PUSH64r"), "PUSH"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("POP64r"), "POP"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("JMP_4"), "JMP"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("CALL64pcrel32"), "CALL"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("SOME_INSTR_123"), "SOME_INSTR"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("123ADD"), "ADD"); EXPECT_FALSE(MIRVocabulary::extractBaseOpcodeName("123").empty()); } class MIR2VecVocabTestFixture : public ::testing::Test { protected: std::unique_ptr Ctx; std::unique_ptr M; std::unique_ptr TM; const TargetInstrInfo *TII = nullptr; const TargetRegisterInfo *TRI = nullptr; std::unique_ptr MMI; MachineFunction *MF = nullptr; static void SetUpTestCase() { InitializeAllTargets(); InitializeAllTargetMCs(); } void SetUp() override { Triple TargetTriple("x86_64-unknown-linux-gnu"); std::string Error; const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error); if (!T) { GTEST_SKIP() << "x86_64-unknown-linux-gnu target triple not available; " "Skipping test"; return; } Ctx = std::make_unique(); M = std::make_unique("test", *Ctx); M->setTargetTriple(TargetTriple); TargetOptions Options; TM = std::unique_ptr( T->createTargetMachine(TargetTriple, "", "", Options, std::nullopt)); if (!TM) { GTEST_SKIP() << "Failed to create X86 target machine; Skipping test"; return; } // Set the data layout to match the target machine M->setDataLayout(TM->createDataLayout()); // Create a dummy function to get subtarget info FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false); Function *F = Function::Create(FT, Function::ExternalLinkage, "test", M.get()); // Create MMI and MF to get TRI and MRI MMI = std::make_unique(TM.get()); MF = &MMI->getOrCreateMachineFunction(*F); // Get the target instruction info and register info TII = TM->getSubtargetImpl(*F)->getInstrInfo(); TRI = TM->getSubtargetImpl(*F)->getRegisterInfo(); if (!TII || !TRI) { GTEST_SKIP() << "Failed to get target instruction/register info; Skipping test"; return; } } void TearDown() override { TII = nullptr; TRI = nullptr; } // Find an opcode by name int findOpcodeByName(StringRef Name) { for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) { if (TII->getName(Opcode) == Name) return Opcode; } return -1; // Not found } // Create a vocabulary with specific opcodes and embeddings // This might cause errors in future when the validation in // MIRVocabulary::generateStorage() enforces hard checks on the vocabulary // entries. Expected createTestVocab( std::initializer_list> Opcodes, std::initializer_list> CommonOperands, std::initializer_list> PhyRegs, std::initializer_list> VirtRegs, unsigned Dimension = 2) { assert(TII && TRI && MF && "Target info not initialized"); VocabMap OpcodeMap, CommonOperandMap, PhyRegMap, VirtRegMap; for (const auto &[Name, Value] : Opcodes) OpcodeMap[Name] = Embedding(Dimension, Value); for (const auto &[Name, Value] : CommonOperands) CommonOperandMap[Name] = Embedding(Dimension, Value); for (const auto &[Name, Value] : PhyRegs) PhyRegMap[Name] = Embedding(Dimension, Value); for (const auto &[Name, Value] : VirtRegs) VirtRegMap[Name] = Embedding(Dimension, Value); // If any section is empty, create minimal maps for other vocabulary // sections to satisfy validation if (Opcodes.size() == 0) OpcodeMap["NOOP"] = Embedding(Dimension, 0.0f); if (CommonOperands.size() == 0) CommonOperandMap["Immediate"] = Embedding(Dimension, 0.0f); if (PhyRegs.size() == 0) PhyRegMap["GR32"] = Embedding(Dimension, 0.0f); if (VirtRegs.size() == 0) VirtRegMap["GR32"] = Embedding(Dimension, 0.0f); return MIRVocabulary::create( std::move(OpcodeMap), std::move(CommonOperandMap), std::move(PhyRegMap), std::move(VirtRegMap), *TII, *TRI, MF->getRegInfo()); } }; // Parameterized test for empty vocab sections class MIR2VecVocabEmptySectionTestFixture : public MIR2VecVocabTestFixture, public ::testing::WithParamInterface { protected: void SetUp() override { MIR2VecVocabTestFixture::SetUp(); // If base class setup was skipped (TII not initialized), skip derived setup if (!TII) GTEST_SKIP() << "Failed to get target instruction info in " "the base class setup; Skipping test"; } }; TEST_P(MIR2VecVocabEmptySectionTestFixture, EmptySectionFailsValidation) { int EmptySection = GetParam(); VocabMap OpcodeMap, CommonOperandMap, PhyRegMap, VirtRegMap; if (EmptySection != 0) OpcodeMap["ADD"] = Embedding(2, 1.0f); if (EmptySection != 1) CommonOperandMap["Immediate"] = Embedding(2, 0.0f); if (EmptySection != 2) PhyRegMap["GR32"] = Embedding(2, 0.0f); if (EmptySection != 3) VirtRegMap["GR32"] = Embedding(2, 0.0f); ASSERT_TRUE(TII != nullptr); ASSERT_TRUE(TRI != nullptr); ASSERT_TRUE(MF != nullptr); auto VocabOrErr = MIRVocabulary::create( std::move(OpcodeMap), std::move(CommonOperandMap), std::move(PhyRegMap), std::move(VirtRegMap), *TII, *TRI, MF->getRegInfo()); EXPECT_FALSE(static_cast(VocabOrErr)) << "Factory method should fail when section " << EmptySection << " is empty"; if (!VocabOrErr) { auto Err = VocabOrErr.takeError(); std::string ErrorMsg = toString(std::move(Err)); EXPECT_FALSE(ErrorMsg.empty()); } } INSTANTIATE_TEST_SUITE_P(EmptySection, MIR2VecVocabEmptySectionTestFixture, ::testing::Values(0, 1, 2, 3)); TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { // Test that same base opcodes get same canonical indices std::string BaseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri"); std::string BaseName2 = MIRVocabulary::extractBaseOpcodeName("ADD32rr"); std::string BaseName3 = MIRVocabulary::extractBaseOpcodeName("ADD64rm"); EXPECT_EQ(BaseName1, BaseName2); EXPECT_EQ(BaseName2, BaseName3); // Create a MIRVocabulary instance to test the mapping // Use a minimal MIRVocabulary to trigger canonical mapping construction Embedding Val = Embedding(64, 1.0f); auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, {}, {}, {}, 64); ASSERT_TRUE(static_cast(TestVocabOrErr)) << "Failed to create vocabulary: " << toString(TestVocabOrErr.takeError()); auto &TestVocab = *TestVocabOrErr; unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1); unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2); unsigned Index3 = TestVocab.getCanonicalIndexForBaseName(BaseName3); EXPECT_EQ(Index1, Index2); EXPECT_EQ(Index2, Index3); // Test that different base opcodes get different canonical indices std::string AddBase = MIRVocabulary::extractBaseOpcodeName("ADD32rr"); std::string SubBase = MIRVocabulary::extractBaseOpcodeName("SUB32rr"); std::string MovBase = MIRVocabulary::extractBaseOpcodeName("MOV32rr"); unsigned AddIndex = TestVocab.getCanonicalIndexForBaseName(AddBase); unsigned SubIndex = TestVocab.getCanonicalIndexForBaseName(SubBase); unsigned MovIndex = TestVocab.getCanonicalIndexForBaseName(MovBase); EXPECT_NE(AddIndex, SubIndex); EXPECT_NE(SubIndex, MovIndex); EXPECT_NE(AddIndex, MovIndex); // Even though we only added "ADD" to the vocab, the canonical mapping // should assign unique indices to all the base opcodes of the target // Ideally, we would check against the exact number of unique base opcodes // for X86, but that would make the test brittle. So we just check that // the number is reasonably closer to the expected number (>6880) and not just // opcodes that we added. EXPECT_GT(TestVocab.getCanonicalSize(), 6880u); // X86 has >6880 unique base opcodes // Check that the embeddings for opcodes not in the vocab are zero vectors int Add32rrOpcode = findOpcodeByName("ADD32rr"); ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found"; EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val)); int Sub32rrOpcode = findOpcodeByName("SUB32rr"); ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found"; EXPECT_TRUE( TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f))); int Mov32rrOpcode = findOpcodeByName("MOV32rr"); ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found"; EXPECT_TRUE( TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f))); } // Test deterministic mapping TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Test that the same base name always maps to the same canonical index std::string BaseName = "ADD"; // Create a MIRVocabulary instance to test deterministic mapping // Use a minimal MIRVocabulary to trigger canonical mapping construction auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, {}, {}, {}, 64); ASSERT_TRUE(static_cast(TestVocabOrErr)) << "Failed to create vocabulary: " << toString(TestVocabOrErr.takeError()); auto &TestVocab = *TestVocabOrErr; unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName); unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName); unsigned Index3 = TestVocab.getCanonicalIndexForBaseName(BaseName); EXPECT_EQ(Index2, Index3); // Test across multiple runs for (int Pos = 0; Pos < 100; ++Pos) { unsigned Index = TestVocab.getCanonicalIndexForBaseName(BaseName); EXPECT_EQ(Index, Index1); } } // Test MIRVocabulary construction TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) { auto VocabOrErr = createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, {}, {}, {}, 128); ASSERT_TRUE(static_cast(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &Vocab = *VocabOrErr; EXPECT_EQ(Vocab.getDimension(), 128u); // Test iterator - iterates over individual embeddings auto IT = Vocab.begin(); EXPECT_NE(IT, Vocab.end()); // Check first embedding exists and has correct dimension EXPECT_EQ((*IT).size(), 128u); size_t Count = 0; for (auto IT = Vocab.begin(); IT != Vocab.end(); ++IT) { EXPECT_EQ((*IT).size(), 128u); ++Count; } EXPECT_GT(Count, 0u); } // Fixture for embedding related tests class MIR2VecEmbeddingTestFixture : public MIR2VecVocabTestFixture { protected: void SetUp() override { MIR2VecVocabTestFixture::SetUp(); // If base class setup was skipped (TII not initialized), skip derived setup if (!TII) GTEST_SKIP() << "Failed to get target instruction info in " "the base class setup; Skipping test"; } void TearDown() override { MIR2VecVocabTestFixture::TearDown(); } // Create a machine instruction MachineInstr *createMachineInstr(MachineBasicBlock &MBB, unsigned Opcode) { const MCInstrDesc &Desc = TII->get(Opcode); // Create instruction - operands don't affect opcode-based embeddings MachineInstr *MI = BuildMI(MBB, MBB.end(), DebugLoc(), Desc); return MI; } MachineInstr *createMachineInstr(MachineBasicBlock &MBB, const char *OpcodeName) { int Opcode = findOpcodeByName(OpcodeName); if (Opcode == -1) return nullptr; return createMachineInstr(MBB, Opcode); } void createMachineInstrs(MachineBasicBlock &MBB, std::initializer_list Opcodes) { for (const char *OpcodeName : Opcodes) { MachineInstr *MI = createMachineInstr(MBB, OpcodeName); ASSERT_TRUE(MI != nullptr); } } }; // Test factory method for creating embedder TEST_F(MIR2VecEmbeddingTestFixture, CreateSymbolicEmbedder) { auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, *TRI, MF->getRegInfo(), 1); ASSERT_TRUE(static_cast(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &V = *VocabOrErr; auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, *MF, V); EXPECT_NE(Emb, nullptr); } TEST_F(MIR2VecEmbeddingTestFixture, CreateInvalidMode) { auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, *TRI, MF->getRegInfo(), 1); ASSERT_TRUE(static_cast(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &V = *VocabOrErr; auto Result = MIREmbedder::create(static_cast(-1), *MF, V); EXPECT_FALSE(static_cast(Result)); } // Test SymbolicMIREmbedder with simple target opcodes TEST_F(MIR2VecEmbeddingTestFixture, TestSymbolicEmbedder) { // Create a test vocabulary with specific values auto VocabOrErr = createTestVocab( { {"NOOP", 1.0f}, // [1.0, 1.0, 1.0, 1.0] {"RET", 2.0f}, // [2.0, 2.0, 2.0, 2.0] {"TRAP", 3.0f} // [3.0, 3.0, 3.0, 3.0] }, {}, {}, {}, 4); ASSERT_TRUE(static_cast(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &Vocab = *VocabOrErr; // Create a basic block using fixture's MF MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); MF->push_back(MBB); // Use real X86 opcodes that should exist and not be pseudo auto NoopInst = createMachineInstr(*MBB, "NOOP"); ASSERT_TRUE(NoopInst != nullptr); auto RetInst = createMachineInstr(*MBB, "RET64"); ASSERT_TRUE(RetInst != nullptr); auto TrapInst = createMachineInstr(*MBB, "TRAP"); ASSERT_TRUE(TrapInst != nullptr); // Verify these are not pseudo instructions ASSERT_FALSE(NoopInst->isPseudo()) << "NOOP is marked as pseudo instruction"; ASSERT_FALSE(RetInst->isPseudo()) << "RET is marked as pseudo instruction"; ASSERT_FALSE(TrapInst->isPseudo()) << "TRAP is marked as pseudo instruction"; // Create embedder auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); ASSERT_TRUE(Embedder != nullptr); // Test instruction embeddings auto NoopEmb = Embedder->getMInstVector(*NoopInst); auto RetEmb = Embedder->getMInstVector(*RetInst); auto TrapEmb = Embedder->getMInstVector(*TrapInst); // Verify embeddings match expected values (accounting for weight scaling) float ExpectedWeight = mir2vec::OpcWeight; // Global weight from command line EXPECT_TRUE(NoopEmb.approximatelyEquals(Embedding(4, 1.0f * ExpectedWeight))); EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(4, 2.0f * ExpectedWeight))); EXPECT_TRUE(TrapEmb.approximatelyEquals(Embedding(4, 3.0f * ExpectedWeight))); // Test basic block embedding (should be sum of instruction embeddings) auto MBBVector = Embedder->getMBBVector(*MBB); // Expected BB vector: NOOP + RET + TRAP = [1+2+3, 1+2+3, 1+2+3, 1+2+3] * // weight = [6, 6, 6, 6] * weight Embedding ExpectedMBBVector(4, 6.0f * ExpectedWeight); EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedMBBVector)); // Test function embedding (should equal MBB embedding since we have one MBB) auto MFuncVector = Embedder->getMFunctionVector(); EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedMBBVector)); } // Test embedder with multiple basic blocks TEST_F(MIR2VecEmbeddingTestFixture, MultipleBasicBlocks) { // Create a test vocabulary auto VocabOrErr = createTestVocab({{"NOOP", 1.0f}, {"TRAP", 2.0f}}, {}, {}, {}); ASSERT_TRUE(static_cast(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &Vocab = *VocabOrErr; // Create two basic blocks using fixture's MF MachineBasicBlock *MBB1 = MF->CreateMachineBasicBlock(); MachineBasicBlock *MBB2 = MF->CreateMachineBasicBlock(); MF->push_back(MBB1); MF->push_back(MBB2); createMachineInstrs(*MBB1, {"NOOP", "NOOP"}); createMachineInstr(*MBB2, "TRAP"); // Create embedder auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); ASSERT_TRUE(Embedder != nullptr); // Test basic block embeddings auto MBB1Vector = Embedder->getMBBVector(*MBB1); auto MBB2Vector = Embedder->getMBBVector(*MBB2); float ExpectedWeight = mir2vec::OpcWeight; // BB1: NOOP + NOOP = 2 * ([1, 1] * weight) Embedding ExpectedMBB1Vector(2, 2.0f * ExpectedWeight); EXPECT_TRUE(MBB1Vector.approximatelyEquals(ExpectedMBB1Vector)); // BB2: TRAP = [2, 2] * weight Embedding ExpectedMBB2Vector(2, 2.0f * ExpectedWeight); EXPECT_TRUE(MBB2Vector.approximatelyEquals(ExpectedMBB2Vector)); // Function embedding: BB1 + BB2 = [2+2, 2+2] * weight = [4, 4] * weight // Function embedding should be just the first BB embedding as the second BB // is unreachable auto MFuncVector = Embedder->getMFunctionVector(); EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedMBB1Vector)); // Add a branch from BB1 to BB2 to make both reachable; now function embedding // should be MBB1 + MBB2 MBB1->addSuccessor(MBB2); auto NewMFuncVector = Embedder->getMFunctionVector(); // Recompute embeddings Embedding ExpectedFuncVector = MBB1Vector + MBB2Vector; EXPECT_TRUE(NewMFuncVector.approximatelyEquals(ExpectedFuncVector)); } // Test embedder with empty basic block TEST_F(MIR2VecEmbeddingTestFixture, EmptyBasicBlock) { // Create an empty basic block MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); MF->push_back(MBB); // Create embedder auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, *TRI, MF->getRegInfo(), 2); ASSERT_TRUE(static_cast(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &V = *VocabOrErr; auto Embedder = SymbolicMIREmbedder::create(*MF, V); ASSERT_TRUE(Embedder != nullptr); // Test that empty BB has zero embedding auto MBBVector = Embedder->getMBBVector(*MBB); Embedding ExpectedBBVector(2, 0.0f); EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedBBVector)); // Function embedding should also be zero auto MFuncVector = Embedder->getMFunctionVector(); EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedBBVector)); } // Test embedder with opcodes not in vocabulary TEST_F(MIR2VecEmbeddingTestFixture, UnknownOpcodes) { // Create a test vocabulary with limited entries // SUB is intentionally not included auto VocabOrErr = createTestVocab({{"ADD", 1.0f}}, {}, {}, {}); ASSERT_TRUE(static_cast(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &Vocab = *VocabOrErr; // Create a basic block MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); MF->push_back(MBB); // Find opcodes int AddOpcode = findOpcodeByName("ADD32rr"); int SubOpcode = findOpcodeByName("SUB32rr"); ASSERT_NE(AddOpcode, -1) << "ADD32rr opcode not found"; ASSERT_NE(SubOpcode, -1) << "SUB32rr opcode not found"; // Create instructions MachineInstr *AddInstr = createMachineInstr(*MBB, AddOpcode); MachineInstr *SubInstr = createMachineInstr(*MBB, SubOpcode); // Create embedder auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); ASSERT_TRUE(Embedder != nullptr); // Test instruction embeddings auto AddVector = Embedder->getMInstVector(*AddInstr); auto SubVector = Embedder->getMInstVector(*SubInstr); float ExpectedWeight = mir2vec::OpcWeight; // ADD should have the embedding from vocabulary EXPECT_TRUE( AddVector.approximatelyEquals(Embedding(2, 1.0f * ExpectedWeight))); // SUB should have zero embedding (not in vocabulary) EXPECT_TRUE(SubVector.approximatelyEquals(Embedding(2, 0.0f))); // Basic block embedding should be ADD + SUB = [1.0, 1.0] * weight + [0.0, // 0.0] = [1.0, 1.0] * weight const auto &MBBVector = Embedder->getMBBVector(*MBB); Embedding ExpectedBBVector(2, 1.0f * ExpectedWeight); EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedBBVector)); } // Test vocabulary string key generation TEST_F(MIR2VecEmbeddingTestFixture, VocabularyStringKeys) { auto VocabOrErr = createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, {}, {}, {}, 2); ASSERT_TRUE(static_cast(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &Vocab = *VocabOrErr; // Test that we can get string keys for all positions for (size_t Pos = 0; Pos < Vocab.getCanonicalSize(); ++Pos) { std::string Key = Vocab.getStringKey(Pos); EXPECT_FALSE(Key.empty()) << "Empty key at position " << Pos; } // Test specific known positions if we can identify them unsigned AddIndex = Vocab.getCanonicalIndexForBaseName("ADD"); std::string AddKey = Vocab.getStringKey(AddIndex); EXPECT_EQ(AddKey, "ADD"); unsigned SubIndex = Vocab.getCanonicalIndexForBaseName("SUB"); std::string SubKey = Vocab.getStringKey(SubIndex); EXPECT_EQ(SubKey, "SUB"); unsigned ImmIndex = Vocab.getCanonicalIndexForOperandName("Immediate"); std::string ImmKey = Vocab.getStringKey(ImmIndex); EXPECT_EQ(ImmKey, "Immediate"); unsigned PhyRegIndex = Vocab.getCanonicalIndexForRegisterClass("GR32", true); std::string PhyRegKey = Vocab.getStringKey(PhyRegIndex); EXPECT_EQ(PhyRegKey, "PhyReg_GR32"); unsigned VirtRegIndex = Vocab.getCanonicalIndexForRegisterClass("GR32", false); std::string VirtRegKey = Vocab.getStringKey(VirtRegIndex); EXPECT_EQ(VirtRegKey, "VirtReg_GR32"); } // Test vocabulary dimension consistency TEST_F(MIR2VecEmbeddingTestFixture, DimensionConsistency) { auto VocabOrErr = createTestVocab({{"TEST", 1.0f}}, {}, {}, {}, 5); ASSERT_TRUE(static_cast(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &Vocab = *VocabOrErr; EXPECT_EQ(Vocab.getDimension(), 5u); // All embeddings should have the same dimension for (auto IT = Vocab.begin(); IT != Vocab.end(); ++IT) EXPECT_EQ((*IT).size(), 5u); } // Test invalid register handling through machine instruction creation TEST_F(MIR2VecEmbeddingTestFixture, InvalidRegisterHandling) { float MOVValue = 1.5f; float ImmValue = 0.5f; float PhyRegValue = 0.2f; auto VocabOrErr = createTestVocab( {{"MOV", MOVValue}}, {{"Immediate", ImmValue}}, {{"GR8_ABCD_H", PhyRegValue}, {"GR8_ABCD_L", PhyRegValue + 0.1f}}, {}, 3); ASSERT_TRUE(static_cast(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &Vocab = *VocabOrErr; MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); MF->push_back(MBB); // Create a MOV instruction with actual operands including potential $noreg // This tests the actual scenario where invalid registers are encountered auto MovOpcode = findOpcodeByName("MOV32mr"); ASSERT_NE(MovOpcode, -1) << "MOV32mr opcode not found"; const MCInstrDesc &Desc = TII->get(MovOpcode); // Use available physical registers from the target unsigned BaseReg = TRI->getNumRegs() > 1 ? 1 : 0; // First available physical register unsigned ValueReg = TRI->getNumRegs() > 2 ? 2 : BaseReg; // MOV32mr typically has: base, scale, index, displacement, segment, value // Use the MachineInstrBuilder API properly auto MovInst = BuildMI(*MBB, MBB->end(), DebugLoc(), Desc) .addReg(BaseReg) // base .addImm(1) // scale .addReg(0) // index ($noreg) .addImm(-4) // displacement .addReg(0) // segment ($noreg) .addReg(ValueReg); // value auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); ASSERT_TRUE(Embedder != nullptr); // This should not crash even if the instruction has $noreg operands auto InstEmb = Embedder->getMInstVector(*MovInst); EXPECT_EQ(InstEmb.size(), 3u); // Test the expected embedding value Embedding ExpectedOpcodeContribution(3, MOVValue * mir2vec::OpcWeight); auto ExpectedOperandContribution = Embedding(3, PhyRegValue * mir2vec::RegOperandWeight) // Base + Embedding(3, ImmValue * mir2vec::CommonOperandWeight) // Scale + Embedding(3, 0.0f) // noreg + Embedding(3, ImmValue * mir2vec::CommonOperandWeight) // displacement + Embedding(3, 0.0f) // noreg + Embedding(3, (PhyRegValue + 0.1f) * mir2vec::RegOperandWeight); // Value auto ExpectedEmb = ExpectedOpcodeContribution + ExpectedOperandContribution; EXPECT_TRUE(InstEmb.approximatelyEquals(ExpectedEmb)) << "MOV instruction embedding should match expected embedding"; } // Test handling of both physical and virtual registers in an instruction TEST_F(MIR2VecEmbeddingTestFixture, PhysicalAndVirtualRegisterHandling) { float MOVValue = 2.0f; float ImmValue = 0.7f; float PhyRegValue = 0.3f; float VirtRegValue = 0.9f; // Find GR32 register class const TargetRegisterClass *GR32RC = nullptr; for (unsigned i = 0; i < TRI->getNumRegClasses(); ++i) { const TargetRegisterClass *RC = TRI->getRegClass(i); if (std::string(TRI->getRegClassName(RC)) == "GR32") { GR32RC = RC; break; } } ASSERT_TRUE(GR32RC != nullptr && GR32RC->isAllocatable()) << "No allocatable GR32 register class found"; // Get first available physical register from GR32 unsigned PhyReg = *GR32RC->begin(); // Create a virtual register of class GR32 unsigned VirtReg = MF->getRegInfo().createVirtualRegister(GR32RC); // Create vocabulary with register class based keys auto VocabOrErr = createTestVocab({{"MOV", MOVValue}}, {{"Immediate", ImmValue}}, {{"GR32_AD", PhyRegValue}}, // GR32_AD is the minimal key {{"GR32", VirtRegValue}}, 4); ASSERT_TRUE(static_cast(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &Vocab = *VocabOrErr; MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); MF->push_back(MBB); // Create a MOV32rr instruction: MOV32rr dst, src auto MovOpcode = findOpcodeByName("MOV32rr"); ASSERT_NE(MovOpcode, -1) << "MOV32rr opcode not found"; const MCInstrDesc &Desc = TII->get(MovOpcode); // MOV32rr: dst (physical), src (virtual) auto MovInst = BuildMI(*MBB, MBB->end(), DebugLoc(), Desc) .addReg(PhyReg) // physical register destination .addReg(VirtReg); // virtual register source // Create embedder with virtual register support auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); ASSERT_TRUE(Embedder != nullptr); // This should not crash and should produce a valid embedding auto InstEmb = Embedder->getMInstVector(*MovInst); EXPECT_EQ(InstEmb.size(), 4u); // Test the expected embedding value Embedding ExpectedOpcodeContribution(4, MOVValue * mir2vec::OpcWeight); auto ExpectedOperandContribution = Embedding(4, PhyRegValue * mir2vec::RegOperandWeight) // dst (physical) + Embedding(4, VirtRegValue * mir2vec::RegOperandWeight); // src (virtual) auto ExpectedEmb = ExpectedOpcodeContribution + ExpectedOperandContribution; EXPECT_TRUE(InstEmb.approximatelyEquals(ExpectedEmb)) << "MOV32rr instruction embedding should match expected embedding"; } // Test precise embedding calculation with known operands TEST_F(MIR2VecEmbeddingTestFixture, EmbeddingCalculation) { auto VocabOrErr = createTestVocab({{"NOOP", 2.0f}}, {}, {}, {}, 2); ASSERT_TRUE(static_cast(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &Vocab = *VocabOrErr; MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); MF->push_back(MBB); // Create a simple NOOP instruction (no operands) auto NoopInst = createMachineInstr(*MBB, "NOOP"); ASSERT_TRUE(NoopInst != nullptr); auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); ASSERT_TRUE(Embedder != nullptr); // Get the instruction embedding auto InstEmb = Embedder->getMInstVector(*NoopInst); EXPECT_EQ(InstEmb.size(), 2u); // For NOOP with no operands, the embedding should be exactly the opcode // embedding float ExpectedWeight = mir2vec::OpcWeight; Embedding ExpectedEmb(2, 2.0f * ExpectedWeight); EXPECT_TRUE(InstEmb.approximatelyEquals(ExpectedEmb)) << "NOOP instruction embedding should match opcode embedding"; // Verify individual components EXPECT_FLOAT_EQ(InstEmb[0], 2.0f * ExpectedWeight); EXPECT_FLOAT_EQ(InstEmb[1], 2.0f * ExpectedWeight); } } // namespace