//===- MIR2VecTest.cpp ---------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "llvm/CodeGen/MIR2Vec.h" #include "llvm/CodeGen/MachineBasicBlock.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineInstr.h" #include "llvm/CodeGen/MachineModuleInfo.h" #include "llvm/CodeGen/TargetInstrInfo.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" #include "llvm/TargetParser/Triple.h" #include "gtest/gtest.h" using namespace llvm; using namespace mir2vec; using VocabMap = std::map; namespace { TEST(MIR2VecTest, RegexExtraction) { // Test simple instruction names EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("NOP"), "NOP"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("RET"), "RET"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("ADD16ri"), "ADD"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("ADD32rr"), "ADD"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("ADD64rm"), "ADD"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("MOV8ri"), "MOV"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("MOV32mr"), "MOV"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("PUSH64r"), "PUSH"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("POP64r"), "POP"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("JMP_4"), "JMP"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("CALL64pcrel32"), "CALL"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("SOME_INSTR_123"), "SOME_INSTR"); EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("123ADD"), "ADD"); EXPECT_FALSE(MIRVocabulary::extractBaseOpcodeName("123").empty()); } class MIR2VecVocabTestFixture : public ::testing::Test { protected: std::unique_ptr Ctx; std::unique_ptr M; std::unique_ptr TM; const TargetInstrInfo *TII = nullptr; static void SetUpTestCase() { InitializeAllTargets(); InitializeAllTargetMCs(); } void SetUp() override { Triple TargetTriple("x86_64-unknown-linux-gnu"); std::string Error; const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error); if (!T) { GTEST_SKIP() << "x86_64-unknown-linux-gnu target triple not available; " "Skipping test"; return; } Ctx = std::make_unique(); M = std::make_unique("test", *Ctx); M->setTargetTriple(TargetTriple); TargetOptions Options; TM = std::unique_ptr( T->createTargetMachine(TargetTriple, "", "", Options, std::nullopt)); if (!TM) { GTEST_SKIP() << "Failed to create X86 target machine; Skipping test"; return; } // Create a dummy function to get subtarget info FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false); Function *F = Function::Create(FT, Function::ExternalLinkage, "test", M.get()); // Get the target instruction info TII = TM->getSubtargetImpl(*F)->getInstrInfo(); if (!TII) { GTEST_SKIP() << "Failed to get target instruction info; Skipping test"; return; } } void TearDown() override { TII = nullptr; } }; // Function to find an opcode by name static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) { for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) { if (TII->getName(Opcode) == Name) return Opcode; } return -1; // Not found } TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { // Test that same base opcodes get same canonical indices std::string BaseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri"); std::string BaseName2 = MIRVocabulary::extractBaseOpcodeName("ADD32rr"); std::string BaseName3 = MIRVocabulary::extractBaseOpcodeName("ADD64rm"); EXPECT_EQ(BaseName1, BaseName2); EXPECT_EQ(BaseName2, BaseName3); // Create a MIRVocabulary instance to test the mapping // Use a minimal MIRVocabulary to trigger canonical mapping construction VocabMap VMap; Embedding Val = Embedding(64, 1.0f); VMap["ADD"] = Val; auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII); ASSERT_TRUE(static_cast(TestVocabOrErr)) << "Failed to create vocabulary: " << toString(TestVocabOrErr.takeError()); auto &TestVocab = *TestVocabOrErr; unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1); unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2); unsigned Index3 = TestVocab.getCanonicalIndexForBaseName(BaseName3); EXPECT_EQ(Index1, Index2); EXPECT_EQ(Index2, Index3); // Test that different base opcodes get different canonical indices std::string AddBase = MIRVocabulary::extractBaseOpcodeName("ADD32rr"); std::string SubBase = MIRVocabulary::extractBaseOpcodeName("SUB32rr"); std::string MovBase = MIRVocabulary::extractBaseOpcodeName("MOV32rr"); unsigned AddIndex = TestVocab.getCanonicalIndexForBaseName(AddBase); unsigned SubIndex = TestVocab.getCanonicalIndexForBaseName(SubBase); unsigned MovIndex = TestVocab.getCanonicalIndexForBaseName(MovBase); EXPECT_NE(AddIndex, SubIndex); EXPECT_NE(SubIndex, MovIndex); EXPECT_NE(AddIndex, MovIndex); // Even though we only added "ADD" to the vocab, the canonical mapping // should assign unique indices to all the base opcodes of the target // Ideally, we would check against the exact number of unique base opcodes // for X86, but that would make the test brittle. So we just check that // the number is reasonably closer to the expected number (>6880) and not just // opcodes that we added. EXPECT_GT(TestVocab.getCanonicalSize(), 6880u); // X86 has >6880 unique base opcodes // Check that the embeddings for opcodes not in the vocab are zero vectors int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr"); ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found"; EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val)); int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr"); ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found"; EXPECT_TRUE( TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f))); int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr"); ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found"; EXPECT_TRUE( TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f))); } // Test deterministic mapping TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Test that the same base name always maps to the same canonical index std::string BaseName = "ADD"; // Create a MIRVocabulary instance to test deterministic mapping // Use a minimal MIRVocabulary to trigger canonical mapping construction VocabMap VMap; VMap["ADD"] = Embedding(64, 1.0f); auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII); ASSERT_TRUE(static_cast(TestVocabOrErr)) << "Failed to create vocabulary: " << toString(TestVocabOrErr.takeError()); auto &TestVocab = *TestVocabOrErr; unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName); unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName); unsigned Index3 = TestVocab.getCanonicalIndexForBaseName(BaseName); EXPECT_EQ(Index1, Index2); EXPECT_EQ(Index2, Index3); // Test across multiple runs for (int Pos = 0; Pos < 100; ++Pos) { unsigned Index = TestVocab.getCanonicalIndexForBaseName(BaseName); EXPECT_EQ(Index, Index1); } } // Test MIRVocabulary construction TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) { VocabMap VMap; VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0 VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0 auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII); ASSERT_TRUE(static_cast(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &Vocab = *VocabOrErr; EXPECT_EQ(Vocab.getDimension(), 128u); // Test iterator - iterates over individual embeddings auto IT = Vocab.begin(); EXPECT_NE(IT, Vocab.end()); // Check first embedding exists and has correct dimension EXPECT_EQ((*IT).size(), 128u); size_t Count = 0; for (auto IT = Vocab.begin(); IT != Vocab.end(); ++IT) { EXPECT_EQ((*IT).size(), 128u); ++Count; } EXPECT_GT(Count, 0u); } // Test factory method with empty vocabulary TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) { VocabMap EmptyVMap; auto VocabOrErr = MIRVocabulary::create(std::move(EmptyVMap), *TII); EXPECT_FALSE(static_cast(VocabOrErr)) << "Factory method should fail with empty vocabulary"; // Consume the error if (!VocabOrErr) { auto Err = VocabOrErr.takeError(); std::string ErrorMsg = toString(std::move(Err)); EXPECT_FALSE(ErrorMsg.empty()); } } } // namespace