diff options
Diffstat (limited to 'llvm/unittests/CodeGen')
-rw-r--r-- | llvm/unittests/CodeGen/MIR2VecTest.cpp | 88 |
1 files changed, 55 insertions, 33 deletions
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp index 01f2ead..d243d82 100644 --- a/llvm/unittests/CodeGen/MIR2VecTest.cpp +++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp @@ -54,27 +54,32 @@ protected: std::unique_ptr<TargetMachine> TM; const TargetInstrInfo *TII; + static void SetUpTestCase() { + InitializeAllTargets(); + InitializeAllTargetMCs(); + } + void SetUp() override { - LLVMInitializeX86TargetInfo(); - LLVMInitializeX86Target(); - LLVMInitializeX86TargetMC(); + Triple TargetTriple("x86_64-unknown-linux-gnu"); + std::string Error; + const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error); + if (!T) { + GTEST_SKIP() << "x86_64-unknown-linux-gnu target triple not available; " + "Skipping test"; + return; + } Ctx = std::make_unique<LLVMContext>(); M = std::make_unique<Module>("test", *Ctx); - - // Set up X86 target - Triple TargetTriple("x86_64-unknown-linux-gnu"); M->setTargetTriple(TargetTriple); - std::string Error; - const Target *TheTarget = - TargetRegistry::lookupTarget(M->getTargetTriple(), Error); - ASSERT_TRUE(TheTarget) << "Failed to lookup target: " << Error; - TargetOptions Options; - TM = std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine( - M->getTargetTriple(), "", "", Options, Reloc::Model::Static)); - ASSERT_TRUE(TM); + TM = std::unique_ptr<TargetMachine>( + T->createTargetMachine(TargetTriple, "", "", Options, std::nullopt)); + if (!TM) { + GTEST_SKIP() << "Failed to create X86 target machine; Skipping test"; + return; + } // Create a dummy function to get subtarget info FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false); @@ -83,10 +88,22 @@ protected: // Get the target instruction info TII = TM->getSubtargetImpl(*F)->getInstrInfo(); - ASSERT_TRUE(TII); + if (!TII) { + GTEST_SKIP() << "Failed to get target instruction info; Skipping test"; + return; + } } }; +// Function to find an opcode by name +static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) { + for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) { + if (TII->getName(Opcode) == Name) + return Opcode; + } + return -1; // Not found +} + TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { // Test that same base opcodes get same canonical indices std::string BaseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri"); @@ -98,10 +115,10 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { // Create a MIRVocabulary instance to test the mapping // Use a minimal MIRVocabulary to trigger canonical mapping construction - VocabMap VM; + VocabMap VMap; Embedding Val = Embedding(64, 1.0f); - VM["ADD"] = Val; - MIRVocabulary TestVocab(std::move(VM), TII); + VMap["ADD"] = Val; + MIRVocabulary TestVocab(std::move(VMap), TII); unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1); unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2); @@ -132,9 +149,19 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { 6880u); // X86 has >6880 unique base opcodes // Check that the embeddings for opcodes not in the vocab are zero vectors - EXPECT_TRUE(TestVocab[AddIndex].approximatelyEquals(Val)); - EXPECT_TRUE(TestVocab[SubIndex].approximatelyEquals(Embedding(64, 0.0f))); - EXPECT_TRUE(TestVocab[MovIndex].approximatelyEquals(Embedding(64, 0.0f))); + int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr"); + ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found"; + EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val)); + + int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr"); + ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found"; + EXPECT_TRUE( + TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f))); + + int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr"); + ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found"; + EXPECT_TRUE( + TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f))); } // Test deterministic mapping @@ -144,9 +171,9 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Create a MIRVocabulary instance to test deterministic mapping // Use a minimal MIRVocabulary to trigger canonical mapping construction - VocabMap VM; - VM["ADD"] = Embedding(64, 1.0f); - MIRVocabulary TestVocab(std::move(VM), TII); + VocabMap VMap; + VMap["ADD"] = Embedding(64, 1.0f); + MIRVocabulary TestVocab(std::move(VMap), TII); unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName); unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName); @@ -164,16 +191,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Test MIRVocabulary construction TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) { - // Test empty MIRVocabulary - MIRVocabulary EmptyVocab; - EXPECT_FALSE(EmptyVocab.isValid()); - - // Test MIRVocabulary with embeddings via VocabMap - VocabMap VM; - VM["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0 - VM["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0 + VocabMap VMap; + VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0 + VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0 - MIRVocabulary Vocab(std::move(VM), TII); + MIRVocabulary Vocab(std::move(VMap), TII); EXPECT_TRUE(Vocab.isValid()); EXPECT_EQ(Vocab.getDimension(), 128u); |