diff options
Diffstat (limited to 'llvm/unittests/CodeGen/MIR2VecTest.cpp')
-rw-r--r-- | llvm/unittests/CodeGen/MIR2VecTest.cpp | 369 |
1 files changed, 321 insertions, 48 deletions
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp index 8710d6b..d42749c 100644 --- a/llvm/unittests/CodeGen/MIR2VecTest.cpp +++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp @@ -54,6 +54,9 @@ protected: std::unique_ptr<Module> M; std::unique_ptr<TargetMachine> TM; const TargetInstrInfo *TII = nullptr; + const TargetRegisterInfo *TRI = nullptr; + std::unique_ptr<MachineModuleInfo> MMI; + MachineFunction *MF = nullptr; static void SetUpTestCase() { InitializeAllTargets(); @@ -90,15 +93,24 @@ protected: Function *F = Function::Create(FT, Function::ExternalLinkage, "test", M.get()); - // Get the target instruction info + // Create MMI and MF to get TRI and MRI + MMI = std::make_unique<MachineModuleInfo>(TM.get()); + MF = &MMI->getOrCreateMachineFunction(*F); + + // Get the target instruction info and register info TII = TM->getSubtargetImpl(*F)->getInstrInfo(); - if (!TII) { - GTEST_SKIP() << "Failed to get target instruction info; Skipping test"; + TRI = TM->getSubtargetImpl(*F)->getRegisterInfo(); + if (!TII || !TRI) { + GTEST_SKIP() + << "Failed to get target instruction/register info; Skipping test"; return; } } - void TearDown() override { TII = nullptr; } + void TearDown() override { + TII = nullptr; + TRI = nullptr; + } // Find an opcode by name int findOpcodeByName(StringRef Name) { @@ -110,17 +122,94 @@ protected: } // Create a vocabulary with specific opcodes and embeddings - Expected<MIRVocabulary> - createTestVocab(std::initializer_list<std::pair<const char *, float>> opcodes, - unsigned dimension = 2) { - assert(TII && "TargetInstrInfo not initialized"); - VocabMap VMap; - for (const auto &[name, value] : opcodes) - VMap[name] = Embedding(dimension, value); - return MIRVocabulary::create(std::move(VMap), *TII); + // This might cause errors in future when the validation in + // MIRVocabulary::generateStorage() enforces hard checks on the vocabulary + // entries. + Expected<MIRVocabulary> createTestVocab( + std::initializer_list<std::pair<const char *, float>> Opcodes, + std::initializer_list<std::pair<const char *, float>> CommonOperands, + std::initializer_list<std::pair<const char *, float>> PhyRegs, + std::initializer_list<std::pair<const char *, float>> VirtRegs, + unsigned Dimension = 2) { + assert(TII && TRI && MF && "Target info not initialized"); + VocabMap OpcodeMap, CommonOperandMap, PhyRegMap, VirtRegMap; + for (const auto &[Name, Value] : Opcodes) + OpcodeMap[Name] = Embedding(Dimension, Value); + + for (const auto &[Name, Value] : CommonOperands) + CommonOperandMap[Name] = Embedding(Dimension, Value); + + for (const auto &[Name, Value] : PhyRegs) + PhyRegMap[Name] = Embedding(Dimension, Value); + + for (const auto &[Name, Value] : VirtRegs) + VirtRegMap[Name] = Embedding(Dimension, Value); + + // If any section is empty, create minimal maps for other vocabulary + // sections to satisfy validation + if (Opcodes.size() == 0) + OpcodeMap["NOOP"] = Embedding(Dimension, 0.0f); + if (CommonOperands.size() == 0) + CommonOperandMap["Immediate"] = Embedding(Dimension, 0.0f); + if (PhyRegs.size() == 0) + PhyRegMap["GR32"] = Embedding(Dimension, 0.0f); + if (VirtRegs.size() == 0) + VirtRegMap["GR32"] = Embedding(Dimension, 0.0f); + + return MIRVocabulary::create( + std::move(OpcodeMap), std::move(CommonOperandMap), std::move(PhyRegMap), + std::move(VirtRegMap), *TII, *TRI, MF->getRegInfo()); } }; +// Parameterized test for empty vocab sections +class MIR2VecVocabEmptySectionTestFixture + : public MIR2VecVocabTestFixture, + public ::testing::WithParamInterface<int> { +protected: + void SetUp() override { + MIR2VecVocabTestFixture::SetUp(); + // If base class setup was skipped (TII not initialized), skip derived setup + if (!TII) + GTEST_SKIP() << "Failed to get target instruction info in " + "the base class setup; Skipping test"; + } +}; + +TEST_P(MIR2VecVocabEmptySectionTestFixture, EmptySectionFailsValidation) { + int EmptySection = GetParam(); + VocabMap OpcodeMap, CommonOperandMap, PhyRegMap, VirtRegMap; + + if (EmptySection != 0) + OpcodeMap["ADD"] = Embedding(2, 1.0f); + if (EmptySection != 1) + CommonOperandMap["Immediate"] = Embedding(2, 0.0f); + if (EmptySection != 2) + PhyRegMap["GR32"] = Embedding(2, 0.0f); + if (EmptySection != 3) + VirtRegMap["GR32"] = Embedding(2, 0.0f); + + ASSERT_TRUE(TII != nullptr); + ASSERT_TRUE(TRI != nullptr); + ASSERT_TRUE(MF != nullptr); + + auto VocabOrErr = MIRVocabulary::create( + std::move(OpcodeMap), std::move(CommonOperandMap), std::move(PhyRegMap), + std::move(VirtRegMap), *TII, *TRI, MF->getRegInfo()); + EXPECT_FALSE(static_cast<bool>(VocabOrErr)) + << "Factory method should fail when section " << EmptySection + << " is empty"; + + if (!VocabOrErr) { + auto Err = VocabOrErr.takeError(); + std::string ErrorMsg = toString(std::move(Err)); + EXPECT_FALSE(ErrorMsg.empty()); + } +} + +INSTANTIATE_TEST_SUITE_P(EmptySection, MIR2VecVocabEmptySectionTestFixture, + ::testing::Values(0, 1, 2, 3)); + TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { // Test that same base opcodes get same canonical indices std::string BaseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri"); @@ -133,7 +222,7 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { // Create a MIRVocabulary instance to test the mapping // Use a minimal MIRVocabulary to trigger canonical mapping construction Embedding Val = Embedding(64, 1.0f); - auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, 64); + auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, {}, {}, {}, 64); ASSERT_TRUE(static_cast<bool>(TestVocabOrErr)) << "Failed to create vocabulary: " << toString(TestVocabOrErr.takeError()); @@ -190,7 +279,7 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Create a MIRVocabulary instance to test deterministic mapping // Use a minimal MIRVocabulary to trigger canonical mapping construction - auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, 64); + auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, {}, {}, {}, 64); ASSERT_TRUE(static_cast<bool>(TestVocabOrErr)) << "Failed to create vocabulary: " << toString(TestVocabOrErr.takeError()); @@ -210,7 +299,8 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Test MIRVocabulary construction TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) { - auto VocabOrErr = createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, 128); + auto VocabOrErr = + createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, {}, {}, {}, 128); ASSERT_TRUE(static_cast<bool>(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &Vocab = *VocabOrErr; @@ -231,42 +321,15 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) { EXPECT_GT(Count, 0u); } -// Test factory method with empty vocabulary -TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) { - VocabMap EmptyVMap; - - auto VocabOrErr = MIRVocabulary::create(std::move(EmptyVMap), *TII); - EXPECT_FALSE(static_cast<bool>(VocabOrErr)) - << "Factory method should fail with empty vocabulary"; - - // Consume the error - if (!VocabOrErr) { - auto Err = VocabOrErr.takeError(); - std::string ErrorMsg = toString(std::move(Err)); - EXPECT_FALSE(ErrorMsg.empty()); - } -} - // Fixture for embedding related tests class MIR2VecEmbeddingTestFixture : public MIR2VecVocabTestFixture { protected: - std::unique_ptr<MachineModuleInfo> MMI; - MachineFunction *MF = nullptr; - void SetUp() override { MIR2VecVocabTestFixture::SetUp(); // If base class setup was skipped (TII not initialized), skip derived setup if (!TII) GTEST_SKIP() << "Failed to get target instruction info in " "the base class setup; Skipping test"; - - // Create a dummy function for MachineFunction - FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false); - Function *F = - Function::Create(FT, Function::ExternalLinkage, "test", M.get()); - - MMI = std::make_unique<MachineModuleInfo>(TM.get()); - MF = &MMI->getOrCreateMachineFunction(*F); } void TearDown() override { MIR2VecVocabTestFixture::TearDown(); } @@ -298,7 +361,8 @@ protected: // Test factory method for creating embedder TEST_F(MIR2VecEmbeddingTestFixture, CreateSymbolicEmbedder) { - auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 1); + auto VocabOrErr = + MIRVocabulary::createDummyVocabForTest(*TII, *TRI, MF->getRegInfo(), 1); ASSERT_TRUE(static_cast<bool>(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &V = *VocabOrErr; @@ -307,7 +371,8 @@ TEST_F(MIR2VecEmbeddingTestFixture, CreateSymbolicEmbedder) { } TEST_F(MIR2VecEmbeddingTestFixture, CreateInvalidMode) { - auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 1); + auto VocabOrErr = + MIRVocabulary::createDummyVocabForTest(*TII, *TRI, MF->getRegInfo(), 1); ASSERT_TRUE(static_cast<bool>(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &V = *VocabOrErr; @@ -324,7 +389,7 @@ TEST_F(MIR2VecEmbeddingTestFixture, TestSymbolicEmbedder) { {"RET", 2.0f}, // [2.0, 2.0, 2.0, 2.0] {"TRAP", 3.0f} // [3.0, 3.0, 3.0, 3.0] }, - 4); + {}, {}, {}, 4); ASSERT_TRUE(static_cast<bool>(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &Vocab = *VocabOrErr; @@ -378,7 +443,8 @@ TEST_F(MIR2VecEmbeddingTestFixture, TestSymbolicEmbedder) { // Test embedder with multiple basic blocks TEST_F(MIR2VecEmbeddingTestFixture, MultipleBasicBlocks) { // Create a test vocabulary - auto VocabOrErr = createTestVocab({{"NOOP", 1.0f}, {"TRAP", 2.0f}}); + auto VocabOrErr = + createTestVocab({{"NOOP", 1.0f}, {"TRAP", 2.0f}}, {}, {}, {}); ASSERT_TRUE(static_cast<bool>(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &Vocab = *VocabOrErr; @@ -431,7 +497,8 @@ TEST_F(MIR2VecEmbeddingTestFixture, EmptyBasicBlock) { MF->push_back(MBB); // Create embedder - auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 2); + auto VocabOrErr = + MIRVocabulary::createDummyVocabForTest(*TII, *TRI, MF->getRegInfo(), 2); ASSERT_TRUE(static_cast<bool>(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &V = *VocabOrErr; @@ -452,7 +519,7 @@ TEST_F(MIR2VecEmbeddingTestFixture, EmptyBasicBlock) { TEST_F(MIR2VecEmbeddingTestFixture, UnknownOpcodes) { // Create a test vocabulary with limited entries // SUB is intentionally not included - auto VocabOrErr = createTestVocab({{"ADD", 1.0f}}); + auto VocabOrErr = createTestVocab({{"ADD", 1.0f}}, {}, {}, {}); ASSERT_TRUE(static_cast<bool>(VocabOrErr)) << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); auto &Vocab = *VocabOrErr; @@ -494,4 +561,210 @@ TEST_F(MIR2VecEmbeddingTestFixture, UnknownOpcodes) { Embedding ExpectedBBVector(2, 1.0f * ExpectedWeight); EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedBBVector)); } + +// Test vocabulary string key generation +TEST_F(MIR2VecEmbeddingTestFixture, VocabularyStringKeys) { + auto VocabOrErr = + createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, {}, {}, {}, 2); + ASSERT_TRUE(static_cast<bool>(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &Vocab = *VocabOrErr; + + // Test that we can get string keys for all positions + for (size_t Pos = 0; Pos < Vocab.getCanonicalSize(); ++Pos) { + std::string Key = Vocab.getStringKey(Pos); + EXPECT_FALSE(Key.empty()) << "Empty key at position " << Pos; + } + + // Test specific known positions if we can identify them + unsigned AddIndex = Vocab.getCanonicalIndexForBaseName("ADD"); + std::string AddKey = Vocab.getStringKey(AddIndex); + EXPECT_EQ(AddKey, "ADD"); + + unsigned SubIndex = Vocab.getCanonicalIndexForBaseName("SUB"); + std::string SubKey = Vocab.getStringKey(SubIndex); + EXPECT_EQ(SubKey, "SUB"); + + unsigned ImmIndex = Vocab.getCanonicalIndexForOperandName("Immediate"); + std::string ImmKey = Vocab.getStringKey(ImmIndex); + EXPECT_EQ(ImmKey, "Immediate"); + + unsigned PhyRegIndex = Vocab.getCanonicalIndexForRegisterClass("GR32", true); + std::string PhyRegKey = Vocab.getStringKey(PhyRegIndex); + EXPECT_EQ(PhyRegKey, "PhyReg_GR32"); + + unsigned VirtRegIndex = + Vocab.getCanonicalIndexForRegisterClass("GR32", false); + std::string VirtRegKey = Vocab.getStringKey(VirtRegIndex); + EXPECT_EQ(VirtRegKey, "VirtReg_GR32"); +} + +// Test vocabulary dimension consistency +TEST_F(MIR2VecEmbeddingTestFixture, DimensionConsistency) { + auto VocabOrErr = createTestVocab({{"TEST", 1.0f}}, {}, {}, {}, 5); + ASSERT_TRUE(static_cast<bool>(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &Vocab = *VocabOrErr; + + EXPECT_EQ(Vocab.getDimension(), 5u); + + // All embeddings should have the same dimension + for (auto IT = Vocab.begin(); IT != Vocab.end(); ++IT) + EXPECT_EQ((*IT).size(), 5u); +} + +// Test invalid register handling through machine instruction creation +TEST_F(MIR2VecEmbeddingTestFixture, InvalidRegisterHandling) { + float MOVValue = 1.5f; + float ImmValue = 0.5f; + float PhyRegValue = 0.2f; + auto VocabOrErr = createTestVocab( + {{"MOV", MOVValue}}, {{"Immediate", ImmValue}}, + {{"GR8_ABCD_H", PhyRegValue}, {"GR8_ABCD_L", PhyRegValue + 0.1f}}, {}, 3); + ASSERT_TRUE(static_cast<bool>(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &Vocab = *VocabOrErr; + + MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); + MF->push_back(MBB); + + // Create a MOV instruction with actual operands including potential $noreg + // This tests the actual scenario where invalid registers are encountered + auto MovOpcode = findOpcodeByName("MOV32mr"); + ASSERT_NE(MovOpcode, -1) << "MOV32mr opcode not found"; + const MCInstrDesc &Desc = TII->get(MovOpcode); + + // Use available physical registers from the target + unsigned BaseReg = + TRI->getNumRegs() > 1 ? 1 : 0; // First available physical register + unsigned ValueReg = TRI->getNumRegs() > 2 ? 2 : BaseReg; + + // MOV32mr typically has: base, scale, index, displacement, segment, value + // Use the MachineInstrBuilder API properly + auto MovInst = BuildMI(*MBB, MBB->end(), DebugLoc(), Desc) + .addReg(BaseReg) // base + .addImm(1) // scale + .addReg(0) // index ($noreg) + .addImm(-4) // displacement + .addReg(0) // segment ($noreg) + .addReg(ValueReg); // value + + auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); + ASSERT_TRUE(Embedder != nullptr); + + // This should not crash even if the instruction has $noreg operands + auto InstEmb = Embedder->getMInstVector(*MovInst); + EXPECT_EQ(InstEmb.size(), 3u); + + // Test the expected embedding value + Embedding ExpectedOpcodeContribution(3, MOVValue * mir2vec::OpcWeight); + auto ExpectedOperandContribution = + Embedding(3, PhyRegValue * mir2vec::RegOperandWeight) // Base + + Embedding(3, ImmValue * mir2vec::CommonOperandWeight) // Scale + + Embedding(3, 0.0f) // noreg + + Embedding(3, ImmValue * mir2vec::CommonOperandWeight) // displacement + + Embedding(3, 0.0f) // noreg + + Embedding(3, (PhyRegValue + 0.1f) * mir2vec::RegOperandWeight); // Value + auto ExpectedEmb = ExpectedOpcodeContribution + ExpectedOperandContribution; + EXPECT_TRUE(InstEmb.approximatelyEquals(ExpectedEmb)) + << "MOV instruction embedding should match expected embedding"; +} + +// Test handling of both physical and virtual registers in an instruction +TEST_F(MIR2VecEmbeddingTestFixture, PhysicalAndVirtualRegisterHandling) { + float MOVValue = 2.0f; + float ImmValue = 0.7f; + float PhyRegValue = 0.3f; + float VirtRegValue = 0.9f; + + // Find GR32 register class + const TargetRegisterClass *GR32RC = nullptr; + for (unsigned i = 0; i < TRI->getNumRegClasses(); ++i) { + const TargetRegisterClass *RC = TRI->getRegClass(i); + if (std::string(TRI->getRegClassName(RC)) == "GR32") { + GR32RC = RC; + break; + } + } + ASSERT_TRUE(GR32RC != nullptr && GR32RC->isAllocatable()) + << "No allocatable GR32 register class found"; + + // Get first available physical register from GR32 + unsigned PhyReg = *GR32RC->begin(); + // Create a virtual register of class GR32 + unsigned VirtReg = MF->getRegInfo().createVirtualRegister(GR32RC); + + // Create vocabulary with register class based keys + auto VocabOrErr = + createTestVocab({{"MOV", MOVValue}}, {{"Immediate", ImmValue}}, + {{"GR32_AD", PhyRegValue}}, // GR32_AD is the minimal key + {{"GR32", VirtRegValue}}, 4); + ASSERT_TRUE(static_cast<bool>(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &Vocab = *VocabOrErr; + + MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); + MF->push_back(MBB); + + // Create a MOV32rr instruction: MOV32rr dst, src + auto MovOpcode = findOpcodeByName("MOV32rr"); + ASSERT_NE(MovOpcode, -1) << "MOV32rr opcode not found"; + const MCInstrDesc &Desc = TII->get(MovOpcode); + + // MOV32rr: dst (physical), src (virtual) + auto MovInst = BuildMI(*MBB, MBB->end(), DebugLoc(), Desc) + .addReg(PhyReg) // physical register destination + .addReg(VirtReg); // virtual register source + + // Create embedder with virtual register support + auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); + ASSERT_TRUE(Embedder != nullptr); + + // This should not crash and should produce a valid embedding + auto InstEmb = Embedder->getMInstVector(*MovInst); + EXPECT_EQ(InstEmb.size(), 4u); + + // Test the expected embedding value + Embedding ExpectedOpcodeContribution(4, MOVValue * mir2vec::OpcWeight); + auto ExpectedOperandContribution = + Embedding(4, PhyRegValue * mir2vec::RegOperandWeight) // dst (physical) + + Embedding(4, VirtRegValue * mir2vec::RegOperandWeight); // src (virtual) + auto ExpectedEmb = ExpectedOpcodeContribution + ExpectedOperandContribution; + EXPECT_TRUE(InstEmb.approximatelyEquals(ExpectedEmb)) + << "MOV32rr instruction embedding should match expected embedding"; +} + +// Test precise embedding calculation with known operands +TEST_F(MIR2VecEmbeddingTestFixture, EmbeddingCalculation) { + auto VocabOrErr = createTestVocab({{"NOOP", 2.0f}}, {}, {}, {}, 2); + ASSERT_TRUE(static_cast<bool>(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &Vocab = *VocabOrErr; + + MachineBasicBlock *MBB = MF->CreateMachineBasicBlock(); + MF->push_back(MBB); + + // Create a simple NOOP instruction (no operands) + auto NoopInst = createMachineInstr(*MBB, "NOOP"); + ASSERT_TRUE(NoopInst != nullptr); + + auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab); + ASSERT_TRUE(Embedder != nullptr); + + // Get the instruction embedding + auto InstEmb = Embedder->getMInstVector(*NoopInst); + EXPECT_EQ(InstEmb.size(), 2u); + + // For NOOP with no operands, the embedding should be exactly the opcode + // embedding + float ExpectedWeight = mir2vec::OpcWeight; + Embedding ExpectedEmb(2, 2.0f * ExpectedWeight); + + EXPECT_TRUE(InstEmb.approximatelyEquals(ExpectedEmb)) + << "NOOP instruction embedding should match opcode embedding"; + + // Verify individual components + EXPECT_FLOAT_EQ(InstEmb[0], 2.0f * ExpectedWeight); + EXPECT_FLOAT_EQ(InstEmb[1], 2.0f * ExpectedWeight); +} } // namespace |