about | summary | refs | log | tree | commit | diff
path: root/llvm/unittests/CodeGen
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/unittests/CodeGen')
-rw-r--r-- llvm/unittests/CodeGen/AsmPrinterDwarfTest.cpp 16
-rw-r--r-- llvm/unittests/CodeGen/InstrRefLDVTest.cpp 2
-rw-r--r-- llvm/unittests/CodeGen/MIR2VecTest.cpp 598
-rw-r--r-- llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp 25
4 files changed, 595 insertions, 46 deletions
diff --git a/llvm/unittests/CodeGen/AsmPrinterDwarfTest.cpp b/llvm/unittests/CodeGen/AsmPrinterDwarfTest.cpp
index 6c08173..af2d56d 100644
--- a/llvm/unittests/CodeGen/AsmPrinterDwarfTest.cpp
+++ b/llvm/unittests/CodeGen/AsmPrinterDwarfTest.cpp
@@ -383,14 +383,14 @@ class AsmPrinterHandlerTest : public AsmPrinterFixtureBase {
public:
TestHandler(AsmPrinterHandlerTest &Test) : Test(Test) {}
- virtual ~TestHandler() {}
- virtual void setSymbolSize(const MCSymbol *Sym, uint64_t Size) override {}
- virtual void beginModule(Module *M) override { Test.BeginCount++; }
- virtual void endModule() override { Test.EndCount++; }
- virtual void beginFunction(const MachineFunction *MF) override {}
- virtual void endFunction(const MachineFunction *MF) override {}
- virtual void beginInstruction(const MachineInstr *MI) override {}
- virtual void endInstruction() override {}
+ ~TestHandler() override {}
+ void setSymbolSize(const MCSymbol *Sym, uint64_t Size) override {}
+ void beginModule(Module *M) override { Test.BeginCount++; }
+ void endModule() override { Test.EndCount++; }
+ void beginFunction(const MachineFunction *MF) override {}
+ void endFunction(const MachineFunction *MF) override {}
+ void beginInstruction(const MachineInstr *MI) override {}
+ void endInstruction() override {}
};
protected:
diff --git a/llvm/unittests/CodeGen/InstrRefLDVTest.cpp b/llvm/unittests/CodeGen/InstrRefLDVTest.cpp
index ce2a38b..ff87e7b 100644
--- a/llvm/unittests/CodeGen/InstrRefLDVTest.cpp
+++ b/llvm/unittests/CodeGen/InstrRefLDVTest.cpp
@@ -69,7 +69,7 @@ public:
InstrRefLDVTest() : Ctx(), Mod(std::make_unique<Module>("beehives", Ctx)) {}
- void SetUp() {
+ void SetUp() override {
// Boilerplate that creates a MachineFunction and associated blocks.
Mod->setDataLayout("e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-"
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index 11222b4..d42749c 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -54,6 +54,9 @@ protected:
std::unique_ptr<Module> M;
std::unique_ptr<TargetMachine> TM;
const TargetInstrInfo *TII = nullptr;
+ const TargetRegisterInfo *TRI = nullptr;
+ std::unique_ptr<MachineModuleInfo> MMI;
+ MachineFunction *MF = nullptr;
static void SetUpTestCase() {
InitializeAllTargets();
@@ -82,31 +85,131 @@ protected:
return;
}
+ // Set the data layout to match the target machine
+ M->setDataLayout(TM->createDataLayout());
+
// Create a dummy function to get subtarget info
FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false);
Function *F =
Function::Create(FT, Function::ExternalLinkage, "test", M.get());
- // Get the target instruction info
+ // Create MMI and MF; tests obtain the MRI via MF->getRegInfo()
+ MMI = std::make_unique<MachineModuleInfo>(TM.get());
+ MF = &MMI->getOrCreateMachineFunction(*F);
+
+ // Get the target instruction info and register info
TII = TM->getSubtargetImpl(*F)->getInstrInfo();
- if (!TII) {
- GTEST_SKIP() << "Failed to get target instruction info; Skipping test";
+ TRI = TM->getSubtargetImpl(*F)->getRegisterInfo();
+ if (!TII || !TRI) {
+ GTEST_SKIP()
+ << "Failed to get target instruction/register info; Skipping test";
return;
}
}
- void TearDown() override { TII = nullptr; }
+ void TearDown() override {
+ TII = nullptr;
+ TRI = nullptr;
+ }
+
+ // Find an opcode by name
+ int findOpcodeByName(StringRef Name) {
+ for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
+ if (TII->getName(Opcode) == Name)
+ return Opcode;
+ }
+ return -1; // Not found
+ }
+
+ // Create a vocabulary from the given opcode, common-operand, physical-register
+ // and virtual-register entries. This might cause errors in the future if the
+ // validation in MIRVocabulary::generateStorage() starts enforcing hard checks
+ // on the vocabulary entries.
+ Expected<MIRVocabulary> createTestVocab(
+ std::initializer_list<std::pair<const char *, float>> Opcodes,
+ std::initializer_list<std::pair<const char *, float>> CommonOperands,
+ std::initializer_list<std::pair<const char *, float>> PhyRegs,
+ std::initializer_list<std::pair<const char *, float>> VirtRegs,
+ unsigned Dimension = 2) {
+ assert(TII && TRI && MF && "Target info not initialized");
+ VocabMap OpcodeMap, CommonOperandMap, PhyRegMap, VirtRegMap;
+ for (const auto &[Name, Value] : Opcodes)
+ OpcodeMap[Name] = Embedding(Dimension, Value);
+
+ for (const auto &[Name, Value] : CommonOperands)
+ CommonOperandMap[Name] = Embedding(Dimension, Value);
+
+ for (const auto &[Name, Value] : PhyRegs)
+ PhyRegMap[Name] = Embedding(Dimension, Value);
+
+ for (const auto &[Name, Value] : VirtRegs)
+ VirtRegMap[Name] = Embedding(Dimension, Value);
+
+ // Give every section the caller left empty a single zero-valued placeholder
+ // entry so that the vocabulary still satisfies validation
+ if (Opcodes.size() == 0)
+ OpcodeMap["NOOP"] = Embedding(Dimension, 0.0f);
+ if (CommonOperands.size() == 0)
+ CommonOperandMap["Immediate"] = Embedding(Dimension, 0.0f);
+ if (PhyRegs.size() == 0)
+ PhyRegMap["GR32"] = Embedding(Dimension, 0.0f);
+ if (VirtRegs.size() == 0)
+ VirtRegMap["GR32"] = Embedding(Dimension, 0.0f);
+
+ return MIRVocabulary::create(
+ std::move(OpcodeMap), std::move(CommonOperandMap), std::move(PhyRegMap),
+ std::move(VirtRegMap), *TII, *TRI, MF->getRegInfo());
+ }
};
-// Function to find an opcode by name
-static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) {
- for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
- if (TII->getName(Opcode) == Name)
- return Opcode;
+// Parameterized test for empty vocab sections
+class MIR2VecVocabEmptySectionTestFixture
+ : public MIR2VecVocabTestFixture,
+ public ::testing::WithParamInterface<int> {
+protected:
+ void SetUp() override {
+ MIR2VecVocabTestFixture::SetUp();
+ // If base class setup was skipped (TII not initialized), skip derived setup
+ if (!TII)
+ GTEST_SKIP() << "Failed to get target instruction info in "
+ "the base class setup; Skipping test";
+ }
+};
+
+TEST_P(MIR2VecVocabEmptySectionTestFixture, EmptySectionFailsValidation) {
+ int EmptySection = GetParam();
+ VocabMap OpcodeMap, CommonOperandMap, PhyRegMap, VirtRegMap;
+
+ if (EmptySection != 0)
+ OpcodeMap["ADD"] = Embedding(2, 1.0f);
+ if (EmptySection != 1)
+ CommonOperandMap["Immediate"] = Embedding(2, 0.0f);
+ if (EmptySection != 2)
+ PhyRegMap["GR32"] = Embedding(2, 0.0f);
+ if (EmptySection != 3)
+ VirtRegMap["GR32"] = Embedding(2, 0.0f);
+
+ ASSERT_TRUE(TII != nullptr);
+ ASSERT_TRUE(TRI != nullptr);
+ ASSERT_TRUE(MF != nullptr);
+
+ auto VocabOrErr = MIRVocabulary::create(
+ std::move(OpcodeMap), std::move(CommonOperandMap), std::move(PhyRegMap),
+ std::move(VirtRegMap), *TII, *TRI, MF->getRegInfo());
+ EXPECT_FALSE(static_cast<bool>(VocabOrErr))
+ << "Factory method should fail when section " << EmptySection
+ << " is empty";
+
+ if (!VocabOrErr) {
+ auto Err = VocabOrErr.takeError();
+ std::string ErrorMsg = toString(std::move(Err));
+ EXPECT_FALSE(ErrorMsg.empty());
}
- return -1; // Not found
}
+INSTANTIATE_TEST_SUITE_P(EmptySection, MIR2VecVocabEmptySectionTestFixture,
+ ::testing::Values(0, 1, 2, 3));
+
TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
// Test that same base opcodes get same canonical indices
std::string BaseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri");
@@ -118,10 +221,8 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
// Create a MIRVocabulary instance to test the mapping
// Use a minimal MIRVocabulary to trigger canonical mapping construction
- VocabMap VMap;
Embedding Val = Embedding(64, 1.0f);
- VMap["ADD"] = Val;
- auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, {}, {}, {}, 64);
ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
<< "Failed to create vocabulary: "
<< toString(TestVocabOrErr.takeError());
@@ -156,16 +257,16 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
6880u); // X86 has >6880 unique base opcodes
// Check that an opcode in the vocab gets its embedding and that opcodes not in the vocab are zero vectors
- int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr");
+ int Add32rrOpcode = findOpcodeByName("ADD32rr");
ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found";
EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val));
- int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr");
+ int Sub32rrOpcode = findOpcodeByName("SUB32rr");
ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found";
EXPECT_TRUE(
TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
- int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr");
+ int Mov32rrOpcode = findOpcodeByName("MOV32rr");
ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found";
EXPECT_TRUE(
TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
@@ -178,9 +279,7 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Create a MIRVocabulary instance to test deterministic mapping
// Use a minimal MIRVocabulary to trigger canonical mapping construction
- VocabMap VMap;
- VMap["ADD"] = Embedding(64, 1.0f);
- auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, {}, {}, {}, 64);
ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
<< "Failed to create vocabulary: "
<< toString(TestVocabOrErr.takeError());
@@ -189,8 +288,6 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
unsigned Index3 = TestVocab.getCanonicalIndexForBaseName(BaseName);
-
- EXPECT_EQ(Index1, Index2);
EXPECT_EQ(Index2, Index3);
// Test across multiple runs
@@ -202,11 +299,8 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Test MIRVocabulary construction
TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
- VocabMap VMap;
- VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
- VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
-
- auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ auto VocabOrErr =
+ createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, {}, {}, {}, 128);
ASSERT_TRUE(static_cast<bool>(VocabOrErr))
<< "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
auto &Vocab = *VocabOrErr;
@@ -227,20 +321,450 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
EXPECT_GT(Count, 0u);
}
-// Test factory method with empty vocabulary
-TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
- VocabMap EmptyVMap;
+// Fixture for embedding related tests
+class MIR2VecEmbeddingTestFixture : public MIR2VecVocabTestFixture {
+protected:
+ void SetUp() override {
+ MIR2VecVocabTestFixture::SetUp();
+ // If base class setup was skipped (TII not initialized), skip derived setup
+ if (!TII)
+ GTEST_SKIP() << "Failed to get target instruction info in "
+ "the base class setup; Skipping test";
+ }
- auto VocabOrErr = MIRVocabulary::create(std::move(EmptyVMap), *TII);
- EXPECT_FALSE(static_cast<bool>(VocabOrErr))
- << "Factory method should fail with empty vocabulary";
+ void TearDown() override { MIR2VecVocabTestFixture::TearDown(); }
- // Consume the error
- if (!VocabOrErr) {
- auto Err = VocabOrErr.takeError();
- std::string ErrorMsg = toString(std::move(Err));
- EXPECT_FALSE(ErrorMsg.empty());
+ // Create a machine instruction
+ MachineInstr *createMachineInstr(MachineBasicBlock &MBB, unsigned Opcode) {
+ const MCInstrDesc &Desc = TII->get(Opcode);
+ // Create instruction - operands don't affect opcode-based embeddings
+ MachineInstr *MI = BuildMI(MBB, MBB.end(), DebugLoc(), Desc);
+ return MI;
+ }
+
+ MachineInstr *createMachineInstr(MachineBasicBlock &MBB,
+ const char *OpcodeName) {
+ int Opcode = findOpcodeByName(OpcodeName);
+ if (Opcode == -1)
+ return nullptr;
+ return createMachineInstr(MBB, Opcode);
+ }
+
+ void createMachineInstrs(MachineBasicBlock &MBB,
+ std::initializer_list<const char *> Opcodes) {
+ for (const char *OpcodeName : Opcodes) {
+ MachineInstr *MI = createMachineInstr(MBB, OpcodeName);
+ ASSERT_TRUE(MI != nullptr);
+ }
+ }
+};
+
+// Test factory method for creating embedder
+TEST_F(MIR2VecEmbeddingTestFixture, CreateSymbolicEmbedder) {
+ auto VocabOrErr =
+ MIRVocabulary::createDummyVocabForTest(*TII, *TRI, MF->getRegInfo(), 1);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &V = *VocabOrErr;
+ auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, *MF, V);
+ EXPECT_NE(Emb, nullptr);
+}
+
+TEST_F(MIR2VecEmbeddingTestFixture, CreateInvalidMode) {
+ auto VocabOrErr =
+ MIRVocabulary::createDummyVocabForTest(*TII, *TRI, MF->getRegInfo(), 1);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &V = *VocabOrErr;
+ auto Result = MIREmbedder::create(static_cast<MIR2VecKind>(-1), *MF, V);
+ EXPECT_FALSE(static_cast<bool>(Result));
+}
+
+// Test SymbolicMIREmbedder with simple target opcodes
+TEST_F(MIR2VecEmbeddingTestFixture, TestSymbolicEmbedder) {
+ // Create a test vocabulary with specific values
+ auto VocabOrErr = createTestVocab(
+ {
+ {"NOOP", 1.0f}, // [1.0, 1.0, 1.0, 1.0]
+ {"RET", 2.0f}, // [2.0, 2.0, 2.0, 2.0]
+ {"TRAP", 3.0f} // [3.0, 3.0, 3.0, 3.0]
+ },
+ {}, {}, {}, 4);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
+ // Create a basic block using fixture's MF
+ MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
+ MF->push_back(MBB);
+
+ // Use real X86 opcodes that should exist and not be pseudo
+ auto NoopInst = createMachineInstr(*MBB, "NOOP");
+ ASSERT_TRUE(NoopInst != nullptr);
+
+ auto RetInst = createMachineInstr(*MBB, "RET64");
+ ASSERT_TRUE(RetInst != nullptr);
+
+ auto TrapInst = createMachineInstr(*MBB, "TRAP");
+ ASSERT_TRUE(TrapInst != nullptr);
+
+ // Verify these are not pseudo instructions
+ ASSERT_FALSE(NoopInst->isPseudo()) << "NOOP is marked as pseudo instruction";
+ ASSERT_FALSE(RetInst->isPseudo()) << "RET is marked as pseudo instruction";
+ ASSERT_FALSE(TrapInst->isPseudo()) << "TRAP is marked as pseudo instruction";
+
+ // Create embedder
+ auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+ ASSERT_TRUE(Embedder != nullptr);
+
+ // Test instruction embeddings
+ auto NoopEmb = Embedder->getMInstVector(*NoopInst);
+ auto RetEmb = Embedder->getMInstVector(*RetInst);
+ auto TrapEmb = Embedder->getMInstVector(*TrapInst);
+
+ // Verify embeddings match expected values (accounting for weight scaling)
+ float ExpectedWeight = mir2vec::OpcWeight; // Global weight from command line
+ EXPECT_TRUE(NoopEmb.approximatelyEquals(Embedding(4, 1.0f * ExpectedWeight)));
+ EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(4, 2.0f * ExpectedWeight)));
+ EXPECT_TRUE(TrapEmb.approximatelyEquals(Embedding(4, 3.0f * ExpectedWeight)));
+
+ // Test basic block embedding (should be sum of instruction embeddings)
+ auto MBBVector = Embedder->getMBBVector(*MBB);
+
+ // Expected BB vector: NOOP + RET + TRAP = [1+2+3, 1+2+3, 1+2+3, 1+2+3] *
+ // weight = [6, 6, 6, 6] * weight
+ Embedding ExpectedMBBVector(4, 6.0f * ExpectedWeight);
+ EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedMBBVector));
+
+ // Test function embedding (should equal MBB embedding since we have one MBB)
+ auto MFuncVector = Embedder->getMFunctionVector();
+ EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedMBBVector));
+}
+
+// Test embedder with multiple basic blocks
+TEST_F(MIR2VecEmbeddingTestFixture, MultipleBasicBlocks) {
+ // Create a test vocabulary
+ auto VocabOrErr =
+ createTestVocab({{"NOOP", 1.0f}, {"TRAP", 2.0f}}, {}, {}, {});
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
+
+ // Create two basic blocks using fixture's MF
+ MachineBasicBlock *MBB1 = MF->CreateMachineBasicBlock();
+ MachineBasicBlock *MBB2 = MF->CreateMachineBasicBlock();
+ MF->push_back(MBB1);
+ MF->push_back(MBB2);
+
+ createMachineInstrs(*MBB1, {"NOOP", "NOOP"});
+ createMachineInstr(*MBB2, "TRAP");
+
+ // Create embedder
+ auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+ ASSERT_TRUE(Embedder != nullptr);
+
+ // Test basic block embeddings
+ auto MBB1Vector = Embedder->getMBBVector(*MBB1);
+ auto MBB2Vector = Embedder->getMBBVector(*MBB2);
+
+ float ExpectedWeight = mir2vec::OpcWeight;
+ // BB1: NOOP + NOOP = 2 * ([1, 1] * weight)
+ Embedding ExpectedMBB1Vector(2, 2.0f * ExpectedWeight);
+ EXPECT_TRUE(MBB1Vector.approximatelyEquals(ExpectedMBB1Vector));
+
+ // BB2: TRAP = [2, 2] * weight
+ Embedding ExpectedMBB2Vector(2, 2.0f * ExpectedWeight);
+ EXPECT_TRUE(MBB2Vector.approximatelyEquals(ExpectedMBB2Vector));
+
+ // BB1 + BB2 would be [2+2, 2+2] * weight = [4, 4] * weight, but BB2 is not
+ // yet reachable from BB1, so the function embedding should be just the first
+ // BB's embedding
+ auto MFuncVector = Embedder->getMFunctionVector();
+ EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedMBB1Vector));
+
+ // Add a branch from BB1 to BB2 to make both reachable; now function embedding
+ // should be MBB1 + MBB2
+ MBB1->addSuccessor(MBB2);
+ auto NewMFuncVector = Embedder->getMFunctionVector(); // Recompute embeddings
+ Embedding ExpectedFuncVector = MBB1Vector + MBB2Vector;
+ EXPECT_TRUE(NewMFuncVector.approximatelyEquals(ExpectedFuncVector));
+}
+
+// Test embedder with empty basic block
+TEST_F(MIR2VecEmbeddingTestFixture, EmptyBasicBlock) {
+
+ // Create an empty basic block
+ MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
+ MF->push_back(MBB);
+
+ // Create embedder
+ auto VocabOrErr =
+ MIRVocabulary::createDummyVocabForTest(*TII, *TRI, MF->getRegInfo(), 2);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &V = *VocabOrErr;
+ auto Embedder = SymbolicMIREmbedder::create(*MF, V);
+ ASSERT_TRUE(Embedder != nullptr);
+
+ // Test that empty BB has zero embedding
+ auto MBBVector = Embedder->getMBBVector(*MBB);
+ Embedding ExpectedBBVector(2, 0.0f);
+ EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedBBVector));
+
+ // Function embedding should also be zero
+ auto MFuncVector = Embedder->getMFunctionVector();
+ EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedBBVector));
+}
+
+// Test embedder with opcodes not in vocabulary
+TEST_F(MIR2VecEmbeddingTestFixture, UnknownOpcodes) {
+ // Create a test vocabulary with limited entries
+ // SUB is intentionally not included
+ auto VocabOrErr = createTestVocab({{"ADD", 1.0f}}, {}, {}, {});
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
+
+ // Create a basic block
+ MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
+ MF->push_back(MBB);
+
+ // Find opcodes
+ int AddOpcode = findOpcodeByName("ADD32rr");
+ int SubOpcode = findOpcodeByName("SUB32rr");
+
+ ASSERT_NE(AddOpcode, -1) << "ADD32rr opcode not found";
+ ASSERT_NE(SubOpcode, -1) << "SUB32rr opcode not found";
+
+ // Create instructions
+ MachineInstr *AddInstr = createMachineInstr(*MBB, AddOpcode);
+ MachineInstr *SubInstr = createMachineInstr(*MBB, SubOpcode);
+
+ // Create embedder
+ auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+ ASSERT_TRUE(Embedder != nullptr);
+
+ // Test instruction embeddings
+ auto AddVector = Embedder->getMInstVector(*AddInstr);
+ auto SubVector = Embedder->getMInstVector(*SubInstr);
+
+ float ExpectedWeight = mir2vec::OpcWeight;
+ // ADD should have the embedding from vocabulary
+ EXPECT_TRUE(
+ AddVector.approximatelyEquals(Embedding(2, 1.0f * ExpectedWeight)));
+
+ // SUB should have zero embedding (not in vocabulary)
+ EXPECT_TRUE(SubVector.approximatelyEquals(Embedding(2, 0.0f)));
+
+ // Basic block embedding should be ADD + SUB = [1.0, 1.0] * weight + [0.0,
+ // 0.0] = [1.0, 1.0] * weight
+ const auto &MBBVector = Embedder->getMBBVector(*MBB);
+ Embedding ExpectedBBVector(2, 1.0f * ExpectedWeight);
+ EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedBBVector));
+}
+
+// Test vocabulary string key generation
+TEST_F(MIR2VecEmbeddingTestFixture, VocabularyStringKeys) {
+ auto VocabOrErr =
+ createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, {}, {}, {}, 2);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
+
+ // Test that we can get string keys for all positions
+ for (size_t Pos = 0; Pos < Vocab.getCanonicalSize(); ++Pos) {
+ std::string Key = Vocab.getStringKey(Pos);
+ EXPECT_FALSE(Key.empty()) << "Empty key at position " << Pos;
+ }
+
+ // Test specific known positions if we can identify them
+ unsigned AddIndex = Vocab.getCanonicalIndexForBaseName("ADD");
+ std::string AddKey = Vocab.getStringKey(AddIndex);
+ EXPECT_EQ(AddKey, "ADD");
+
+ unsigned SubIndex = Vocab.getCanonicalIndexForBaseName("SUB");
+ std::string SubKey = Vocab.getStringKey(SubIndex);
+ EXPECT_EQ(SubKey, "SUB");
+
+ unsigned ImmIndex = Vocab.getCanonicalIndexForOperandName("Immediate");
+ std::string ImmKey = Vocab.getStringKey(ImmIndex);
+ EXPECT_EQ(ImmKey, "Immediate");
+
+ unsigned PhyRegIndex = Vocab.getCanonicalIndexForRegisterClass("GR32", true);
+ std::string PhyRegKey = Vocab.getStringKey(PhyRegIndex);
+ EXPECT_EQ(PhyRegKey, "PhyReg_GR32");
+
+ unsigned VirtRegIndex =
+ Vocab.getCanonicalIndexForRegisterClass("GR32", false);
+ std::string VirtRegKey = Vocab.getStringKey(VirtRegIndex);
+ EXPECT_EQ(VirtRegKey, "VirtReg_GR32");
+}
+
+// Test vocabulary dimension consistency
+TEST_F(MIR2VecEmbeddingTestFixture, DimensionConsistency) {
+ auto VocabOrErr = createTestVocab({{"TEST", 1.0f}}, {}, {}, {}, 5);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
+
+ EXPECT_EQ(Vocab.getDimension(), 5u);
+
+ // All embeddings should have the same dimension
+ for (auto IT = Vocab.begin(); IT != Vocab.end(); ++IT)
+ EXPECT_EQ((*IT).size(), 5u);
+}
+
+// Test invalid register handling through machine instruction creation
+TEST_F(MIR2VecEmbeddingTestFixture, InvalidRegisterHandling) {
+ float MOVValue = 1.5f;
+ float ImmValue = 0.5f;
+ float PhyRegValue = 0.2f;
+ auto VocabOrErr = createTestVocab(
+ {{"MOV", MOVValue}}, {{"Immediate", ImmValue}},
+ {{"GR8_ABCD_H", PhyRegValue}, {"GR8_ABCD_L", PhyRegValue + 0.1f}}, {}, 3);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
+
+ MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
+ MF->push_back(MBB);
+
+ // Create a MOV instruction with actual operands including potential $noreg
+ // This tests the actual scenario where invalid registers are encountered
+ auto MovOpcode = findOpcodeByName("MOV32mr");
+ ASSERT_NE(MovOpcode, -1) << "MOV32mr opcode not found";
+ const MCInstrDesc &Desc = TII->get(MovOpcode);
+
+ // Use available physical registers from the target
+ unsigned BaseReg =
+ TRI->getNumRegs() > 1 ? 1 : 0; // First available physical register
+ unsigned ValueReg = TRI->getNumRegs() > 2 ? 2 : BaseReg;
+
+ // MOV32mr typically has: base, scale, index, displacement, segment, value
+ // Use the MachineInstrBuilder API properly
+ auto MovInst = BuildMI(*MBB, MBB->end(), DebugLoc(), Desc)
+ .addReg(BaseReg) // base
+ .addImm(1) // scale
+ .addReg(0) // index ($noreg)
+ .addImm(-4) // displacement
+ .addReg(0) // segment ($noreg)
+ .addReg(ValueReg); // value
+
+ auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+ ASSERT_TRUE(Embedder != nullptr);
+
+ // This should not crash even if the instruction has $noreg operands
+ auto InstEmb = Embedder->getMInstVector(*MovInst);
+ EXPECT_EQ(InstEmb.size(), 3u);
+
+ // Test the expected embedding value
+ Embedding ExpectedOpcodeContribution(3, MOVValue * mir2vec::OpcWeight);
+ auto ExpectedOperandContribution =
+ Embedding(3, PhyRegValue * mir2vec::RegOperandWeight) // Base
+ + Embedding(3, ImmValue * mir2vec::CommonOperandWeight) // Scale
+ + Embedding(3, 0.0f) // noreg
+ + Embedding(3, ImmValue * mir2vec::CommonOperandWeight) // displacement
+ + Embedding(3, 0.0f) // noreg
+ + Embedding(3, (PhyRegValue + 0.1f) * mir2vec::RegOperandWeight); // Value
+ auto ExpectedEmb = ExpectedOpcodeContribution + ExpectedOperandContribution;
+ EXPECT_TRUE(InstEmb.approximatelyEquals(ExpectedEmb))
+ << "MOV instruction embedding should match expected embedding";
+}
+
+// Test handling of both physical and virtual registers in an instruction
+TEST_F(MIR2VecEmbeddingTestFixture, PhysicalAndVirtualRegisterHandling) {
+ float MOVValue = 2.0f;
+ float ImmValue = 0.7f;
+ float PhyRegValue = 0.3f;
+ float VirtRegValue = 0.9f;
+
+ // Find GR32 register class
+ const TargetRegisterClass *GR32RC = nullptr;
+ for (unsigned i = 0; i < TRI->getNumRegClasses(); ++i) {
+ const TargetRegisterClass *RC = TRI->getRegClass(i);
+ if (std::string(TRI->getRegClassName(RC)) == "GR32") {
+ GR32RC = RC;
+ break;
+ }
}
+ ASSERT_TRUE(GR32RC != nullptr && GR32RC->isAllocatable())
+ << "No allocatable GR32 register class found";
+
+ // Get first available physical register from GR32
+ unsigned PhyReg = *GR32RC->begin();
+ // Create a virtual register of class GR32
+ unsigned VirtReg = MF->getRegInfo().createVirtualRegister(GR32RC);
+
+ // Create vocabulary with register class based keys
+ auto VocabOrErr =
+ createTestVocab({{"MOV", MOVValue}}, {{"Immediate", ImmValue}},
+ {{"GR32_AD", PhyRegValue}}, // GR32_AD is the minimal key
+ {{"GR32", VirtRegValue}}, 4);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
+
+ MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
+ MF->push_back(MBB);
+
+ // Create a MOV32rr instruction: MOV32rr dst, src
+ auto MovOpcode = findOpcodeByName("MOV32rr");
+ ASSERT_NE(MovOpcode, -1) << "MOV32rr opcode not found";
+ const MCInstrDesc &Desc = TII->get(MovOpcode);
+
+ // MOV32rr: dst (physical), src (virtual)
+ auto MovInst = BuildMI(*MBB, MBB->end(), DebugLoc(), Desc)
+ .addReg(PhyReg) // physical register destination
+ .addReg(VirtReg); // virtual register source
+
+ // Create embedder with virtual register support
+ auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+ ASSERT_TRUE(Embedder != nullptr);
+
+ // This should not crash and should produce a valid embedding
+ auto InstEmb = Embedder->getMInstVector(*MovInst);
+ EXPECT_EQ(InstEmb.size(), 4u);
+
+ // Test the expected embedding value
+ Embedding ExpectedOpcodeContribution(4, MOVValue * mir2vec::OpcWeight);
+ auto ExpectedOperandContribution =
+ Embedding(4, PhyRegValue * mir2vec::RegOperandWeight) // dst (physical)
+ + Embedding(4, VirtRegValue * mir2vec::RegOperandWeight); // src (virtual)
+ auto ExpectedEmb = ExpectedOpcodeContribution + ExpectedOperandContribution;
+ EXPECT_TRUE(InstEmb.approximatelyEquals(ExpectedEmb))
+ << "MOV32rr instruction embedding should match expected embedding";
}
+// Test precise embedding calculation with known operands
+TEST_F(MIR2VecEmbeddingTestFixture, EmbeddingCalculation) {
+ auto VocabOrErr = createTestVocab({{"NOOP", 2.0f}}, {}, {}, {}, 2);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
+
+ MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
+ MF->push_back(MBB);
+
+ // Create a simple NOOP instruction (no operands)
+ auto NoopInst = createMachineInstr(*MBB, "NOOP");
+ ASSERT_TRUE(NoopInst != nullptr);
+
+ auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+ ASSERT_TRUE(Embedder != nullptr);
+
+ // Get the instruction embedding
+ auto InstEmb = Embedder->getMInstVector(*NoopInst);
+ EXPECT_EQ(InstEmb.size(), 2u);
+
+ // For NOOP with no operands, the embedding should be exactly the opcode
+ // embedding
+ float ExpectedWeight = mir2vec::OpcWeight;
+ Embedding ExpectedEmb(2, 2.0f * ExpectedWeight);
+
+ EXPECT_TRUE(InstEmb.approximatelyEquals(ExpectedEmb))
+ << "NOOP instruction embedding should match opcode embedding";
+
+ // Verify individual components
+ EXPECT_FLOAT_EQ(InstEmb[0], 2.0f * ExpectedWeight);
+ EXPECT_FLOAT_EQ(InstEmb[1], 2.0f * ExpectedWeight);
+}
} // namespace
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 16b9979..aa56aaf 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -550,6 +550,31 @@ TEST_F(SelectionDAGPatternMatchTest, matchNode) {
EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_ConstInt(), m_Value())));
}
+// Verify that m_SelectLike matches both the scalar ISD::SELECT node and the
+// vector ISD::VSELECT node when given their (condition, true, false) operands.
+TEST_F(SelectionDAGPatternMatchTest, matchSelectLike) {
+  SDLoc DL;
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+  auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
+
+  // Scalar operands for the SELECT node.
+  SDValue Cond = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 0, Int32VT);
+  SDValue TVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+  SDValue FVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
+
+  // Vector operands for the VSELECT node.
+  SDValue VCond = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 0, VInt32VT);
+  SDValue VTVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
+  SDValue VFVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
+
+  SDValue Select = DAG->getNode(ISD::SELECT, DL, Int32VT, Cond, TVal, FVal);
+  // The VSELECT result type is the vector type of its value operands, not the
+  // scalar element type.
+  SDValue VSelect =
+      DAG->getNode(ISD::VSELECT, DL, VInt32VT, VCond, VTVal, VFVal);
+
+  using namespace SDPatternMatch;
+  EXPECT_TRUE(sd_match(Select, m_SelectLike(m_Specific(Cond), m_Specific(TVal),
+                                            m_Specific(FVal))));
+  EXPECT_TRUE(
+      sd_match(VSelect, m_SelectLike(m_Specific(VCond), m_Specific(VTVal),
+                                     m_Specific(VFVal))));
+}
+
namespace {
struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
using SDPatternMatch::BasicMatchContext::BasicMatchContext;