aboutsummaryrefslogtreecommitdiff
path: root/llvm/unittests/CodeGen
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/unittests/CodeGen')
-rw-r--r--llvm/unittests/CodeGen/AsmPrinterDwarfTest.cpp16
-rw-r--r--llvm/unittests/CodeGen/InstrRefLDVTest.cpp2
-rw-r--r--llvm/unittests/CodeGen/MIR2VecTest.cpp369
-rw-r--r--llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp25
4 files changed, 355 insertions, 57 deletions
diff --git a/llvm/unittests/CodeGen/AsmPrinterDwarfTest.cpp b/llvm/unittests/CodeGen/AsmPrinterDwarfTest.cpp
index 6c08173..af2d56d 100644
--- a/llvm/unittests/CodeGen/AsmPrinterDwarfTest.cpp
+++ b/llvm/unittests/CodeGen/AsmPrinterDwarfTest.cpp
@@ -383,14 +383,14 @@ class AsmPrinterHandlerTest : public AsmPrinterFixtureBase {
public:
TestHandler(AsmPrinterHandlerTest &Test) : Test(Test) {}
- virtual ~TestHandler() {}
- virtual void setSymbolSize(const MCSymbol *Sym, uint64_t Size) override {}
- virtual void beginModule(Module *M) override { Test.BeginCount++; }
- virtual void endModule() override { Test.EndCount++; }
- virtual void beginFunction(const MachineFunction *MF) override {}
- virtual void endFunction(const MachineFunction *MF) override {}
- virtual void beginInstruction(const MachineInstr *MI) override {}
- virtual void endInstruction() override {}
+ ~TestHandler() override {}
+ void setSymbolSize(const MCSymbol *Sym, uint64_t Size) override {}
+ void beginModule(Module *M) override { Test.BeginCount++; }
+ void endModule() override { Test.EndCount++; }
+ void beginFunction(const MachineFunction *MF) override {}
+ void endFunction(const MachineFunction *MF) override {}
+ void beginInstruction(const MachineInstr *MI) override {}
+ void endInstruction() override {}
};
protected:
diff --git a/llvm/unittests/CodeGen/InstrRefLDVTest.cpp b/llvm/unittests/CodeGen/InstrRefLDVTest.cpp
index ce2a38b..ff87e7b 100644
--- a/llvm/unittests/CodeGen/InstrRefLDVTest.cpp
+++ b/llvm/unittests/CodeGen/InstrRefLDVTest.cpp
@@ -69,7 +69,7 @@ public:
InstrRefLDVTest() : Ctx(), Mod(std::make_unique<Module>("beehives", Ctx)) {}
- void SetUp() {
+ void SetUp() override {
// Boilerplate that creates a MachineFunction and associated blocks.
Mod->setDataLayout("e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-"
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index 8710d6b..d42749c 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -54,6 +54,9 @@ protected:
std::unique_ptr<Module> M;
std::unique_ptr<TargetMachine> TM;
const TargetInstrInfo *TII = nullptr;
+ const TargetRegisterInfo *TRI = nullptr;
+ std::unique_ptr<MachineModuleInfo> MMI;
+ MachineFunction *MF = nullptr;
static void SetUpTestCase() {
InitializeAllTargets();
@@ -90,15 +93,24 @@ protected:
Function *F =
Function::Create(FT, Function::ExternalLinkage, "test", M.get());
- // Get the target instruction info
+ // Create MMI and MF to get TRI and MRI
+ MMI = std::make_unique<MachineModuleInfo>(TM.get());
+ MF = &MMI->getOrCreateMachineFunction(*F);
+
+ // Get the target instruction info and register info
TII = TM->getSubtargetImpl(*F)->getInstrInfo();
- if (!TII) {
- GTEST_SKIP() << "Failed to get target instruction info; Skipping test";
+ TRI = TM->getSubtargetImpl(*F)->getRegisterInfo();
+ if (!TII || !TRI) {
+ GTEST_SKIP()
+ << "Failed to get target instruction/register info; Skipping test";
return;
}
}
- void TearDown() override { TII = nullptr; }
+ void TearDown() override {
+ TII = nullptr;
+ TRI = nullptr;
+ }
// Find an opcode by name
int findOpcodeByName(StringRef Name) {
@@ -110,17 +122,94 @@ protected:
}
// Create a vocabulary with specific opcodes and embeddings
- Expected<MIRVocabulary>
- createTestVocab(std::initializer_list<std::pair<const char *, float>> opcodes,
- unsigned dimension = 2) {
- assert(TII && "TargetInstrInfo not initialized");
- VocabMap VMap;
- for (const auto &[name, value] : opcodes)
- VMap[name] = Embedding(dimension, value);
- return MIRVocabulary::create(std::move(VMap), *TII);
+ // This might cause errors in future when the validation in
+ // MIRVocabulary::generateStorage() enforces hard checks on the vocabulary
+ // entries.
+ Expected<MIRVocabulary> createTestVocab(
+ std::initializer_list<std::pair<const char *, float>> Opcodes,
+ std::initializer_list<std::pair<const char *, float>> CommonOperands,
+ std::initializer_list<std::pair<const char *, float>> PhyRegs,
+ std::initializer_list<std::pair<const char *, float>> VirtRegs,
+ unsigned Dimension = 2) {
+ assert(TII && TRI && MF && "Target info not initialized");
+ VocabMap OpcodeMap, CommonOperandMap, PhyRegMap, VirtRegMap;
+ for (const auto &[Name, Value] : Opcodes)
+ OpcodeMap[Name] = Embedding(Dimension, Value);
+
+ for (const auto &[Name, Value] : CommonOperands)
+ CommonOperandMap[Name] = Embedding(Dimension, Value);
+
+ for (const auto &[Name, Value] : PhyRegs)
+ PhyRegMap[Name] = Embedding(Dimension, Value);
+
+ for (const auto &[Name, Value] : VirtRegs)
+ VirtRegMap[Name] = Embedding(Dimension, Value);
+
+ // If any section is empty, create minimal maps for other vocabulary
+ // sections to satisfy validation
+ if (Opcodes.size() == 0)
+ OpcodeMap["NOOP"] = Embedding(Dimension, 0.0f);
+ if (CommonOperands.size() == 0)
+ CommonOperandMap["Immediate"] = Embedding(Dimension, 0.0f);
+ if (PhyRegs.size() == 0)
+ PhyRegMap["GR32"] = Embedding(Dimension, 0.0f);
+ if (VirtRegs.size() == 0)
+ VirtRegMap["GR32"] = Embedding(Dimension, 0.0f);
+
+ return MIRVocabulary::create(
+ std::move(OpcodeMap), std::move(CommonOperandMap), std::move(PhyRegMap),
+ std::move(VirtRegMap), *TII, *TRI, MF->getRegInfo());
}
};
+// Parameterized test for empty vocab sections
+class MIR2VecVocabEmptySectionTestFixture
+ : public MIR2VecVocabTestFixture,
+ public ::testing::WithParamInterface<int> {
+protected:
+ void SetUp() override {
+ MIR2VecVocabTestFixture::SetUp();
+ // If base class setup was skipped (TII not initialized), skip derived setup
+ if (!TII)
+ GTEST_SKIP() << "Failed to get target instruction info in "
+ "the base class setup; Skipping test";
+ }
+};
+
+TEST_P(MIR2VecVocabEmptySectionTestFixture, EmptySectionFailsValidation) {
+ int EmptySection = GetParam();
+ VocabMap OpcodeMap, CommonOperandMap, PhyRegMap, VirtRegMap;
+
+ if (EmptySection != 0)
+ OpcodeMap["ADD"] = Embedding(2, 1.0f);
+ if (EmptySection != 1)
+ CommonOperandMap["Immediate"] = Embedding(2, 0.0f);
+ if (EmptySection != 2)
+ PhyRegMap["GR32"] = Embedding(2, 0.0f);
+ if (EmptySection != 3)
+ VirtRegMap["GR32"] = Embedding(2, 0.0f);
+
+ ASSERT_TRUE(TII != nullptr);
+ ASSERT_TRUE(TRI != nullptr);
+ ASSERT_TRUE(MF != nullptr);
+
+ auto VocabOrErr = MIRVocabulary::create(
+ std::move(OpcodeMap), std::move(CommonOperandMap), std::move(PhyRegMap),
+ std::move(VirtRegMap), *TII, *TRI, MF->getRegInfo());
+ EXPECT_FALSE(static_cast<bool>(VocabOrErr))
+ << "Factory method should fail when section " << EmptySection
+ << " is empty";
+
+ if (!VocabOrErr) {
+ auto Err = VocabOrErr.takeError();
+ std::string ErrorMsg = toString(std::move(Err));
+ EXPECT_FALSE(ErrorMsg.empty());
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(EmptySection, MIR2VecVocabEmptySectionTestFixture,
+ ::testing::Values(0, 1, 2, 3));
+
TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
// Test that same base opcodes get same canonical indices
std::string BaseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri");
@@ -133,7 +222,7 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
// Create a MIRVocabulary instance to test the mapping
// Use a minimal MIRVocabulary to trigger canonical mapping construction
Embedding Val = Embedding(64, 1.0f);
- auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, 64);
+ auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, {}, {}, {}, 64);
ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
<< "Failed to create vocabulary: "
<< toString(TestVocabOrErr.takeError());
@@ -190,7 +279,7 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Create a MIRVocabulary instance to test deterministic mapping
// Use a minimal MIRVocabulary to trigger canonical mapping construction
- auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, 64);
+ auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, {}, {}, {}, 64);
ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
<< "Failed to create vocabulary: "
<< toString(TestVocabOrErr.takeError());
@@ -210,7 +299,8 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Test MIRVocabulary construction
TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
- auto VocabOrErr = createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, 128);
+ auto VocabOrErr =
+ createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, {}, {}, {}, 128);
ASSERT_TRUE(static_cast<bool>(VocabOrErr))
<< "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
auto &Vocab = *VocabOrErr;
@@ -231,42 +321,15 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
EXPECT_GT(Count, 0u);
}
-// Test factory method with empty vocabulary
-TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
- VocabMap EmptyVMap;
-
- auto VocabOrErr = MIRVocabulary::create(std::move(EmptyVMap), *TII);
- EXPECT_FALSE(static_cast<bool>(VocabOrErr))
- << "Factory method should fail with empty vocabulary";
-
- // Consume the error
- if (!VocabOrErr) {
- auto Err = VocabOrErr.takeError();
- std::string ErrorMsg = toString(std::move(Err));
- EXPECT_FALSE(ErrorMsg.empty());
- }
-}
-
// Fixture for embedding related tests
class MIR2VecEmbeddingTestFixture : public MIR2VecVocabTestFixture {
protected:
- std::unique_ptr<MachineModuleInfo> MMI;
- MachineFunction *MF = nullptr;
-
void SetUp() override {
MIR2VecVocabTestFixture::SetUp();
// If base class setup was skipped (TII not initialized), skip derived setup
if (!TII)
GTEST_SKIP() << "Failed to get target instruction info in "
"the base class setup; Skipping test";
-
- // Create a dummy function for MachineFunction
- FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false);
- Function *F =
- Function::Create(FT, Function::ExternalLinkage, "test", M.get());
-
- MMI = std::make_unique<MachineModuleInfo>(TM.get());
- MF = &MMI->getOrCreateMachineFunction(*F);
}
void TearDown() override { MIR2VecVocabTestFixture::TearDown(); }
@@ -298,7 +361,8 @@ protected:
// Test factory method for creating embedder
TEST_F(MIR2VecEmbeddingTestFixture, CreateSymbolicEmbedder) {
- auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 1);
+ auto VocabOrErr =
+ MIRVocabulary::createDummyVocabForTest(*TII, *TRI, MF->getRegInfo(), 1);
ASSERT_TRUE(static_cast<bool>(VocabOrErr))
<< "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
auto &V = *VocabOrErr;
@@ -307,7 +371,8 @@ TEST_F(MIR2VecEmbeddingTestFixture, CreateSymbolicEmbedder) {
}
TEST_F(MIR2VecEmbeddingTestFixture, CreateInvalidMode) {
- auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 1);
+ auto VocabOrErr =
+ MIRVocabulary::createDummyVocabForTest(*TII, *TRI, MF->getRegInfo(), 1);
ASSERT_TRUE(static_cast<bool>(VocabOrErr))
<< "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
auto &V = *VocabOrErr;
@@ -324,7 +389,7 @@ TEST_F(MIR2VecEmbeddingTestFixture, TestSymbolicEmbedder) {
{"RET", 2.0f}, // [2.0, 2.0, 2.0, 2.0]
{"TRAP", 3.0f} // [3.0, 3.0, 3.0, 3.0]
},
- 4);
+ {}, {}, {}, 4);
ASSERT_TRUE(static_cast<bool>(VocabOrErr))
<< "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
auto &Vocab = *VocabOrErr;
@@ -378,7 +443,8 @@ TEST_F(MIR2VecEmbeddingTestFixture, TestSymbolicEmbedder) {
// Test embedder with multiple basic blocks
TEST_F(MIR2VecEmbeddingTestFixture, MultipleBasicBlocks) {
// Create a test vocabulary
- auto VocabOrErr = createTestVocab({{"NOOP", 1.0f}, {"TRAP", 2.0f}});
+ auto VocabOrErr =
+ createTestVocab({{"NOOP", 1.0f}, {"TRAP", 2.0f}}, {}, {}, {});
ASSERT_TRUE(static_cast<bool>(VocabOrErr))
<< "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
auto &Vocab = *VocabOrErr;
@@ -431,7 +497,8 @@ TEST_F(MIR2VecEmbeddingTestFixture, EmptyBasicBlock) {
MF->push_back(MBB);
// Create embedder
- auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(*TII, 2);
+ auto VocabOrErr =
+ MIRVocabulary::createDummyVocabForTest(*TII, *TRI, MF->getRegInfo(), 2);
ASSERT_TRUE(static_cast<bool>(VocabOrErr))
<< "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
auto &V = *VocabOrErr;
@@ -452,7 +519,7 @@ TEST_F(MIR2VecEmbeddingTestFixture, EmptyBasicBlock) {
TEST_F(MIR2VecEmbeddingTestFixture, UnknownOpcodes) {
// Create a test vocabulary with limited entries
// SUB is intentionally not included
- auto VocabOrErr = createTestVocab({{"ADD", 1.0f}});
+ auto VocabOrErr = createTestVocab({{"ADD", 1.0f}}, {}, {}, {});
ASSERT_TRUE(static_cast<bool>(VocabOrErr))
<< "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
auto &Vocab = *VocabOrErr;
@@ -494,4 +561,210 @@ TEST_F(MIR2VecEmbeddingTestFixture, UnknownOpcodes) {
Embedding ExpectedBBVector(2, 1.0f * ExpectedWeight);
EXPECT_TRUE(MBBVector.approximatelyEquals(ExpectedBBVector));
}
+
+// Test vocabulary string key generation
+TEST_F(MIR2VecEmbeddingTestFixture, VocabularyStringKeys) {
+ auto VocabOrErr =
+ createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, {}, {}, {}, 2);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
+
+ // Test that we can get string keys for all positions
+ for (size_t Pos = 0; Pos < Vocab.getCanonicalSize(); ++Pos) {
+ std::string Key = Vocab.getStringKey(Pos);
+ EXPECT_FALSE(Key.empty()) << "Empty key at position " << Pos;
+ }
+
+ // Test specific known positions if we can identify them
+ unsigned AddIndex = Vocab.getCanonicalIndexForBaseName("ADD");
+ std::string AddKey = Vocab.getStringKey(AddIndex);
+ EXPECT_EQ(AddKey, "ADD");
+
+ unsigned SubIndex = Vocab.getCanonicalIndexForBaseName("SUB");
+ std::string SubKey = Vocab.getStringKey(SubIndex);
+ EXPECT_EQ(SubKey, "SUB");
+
+ unsigned ImmIndex = Vocab.getCanonicalIndexForOperandName("Immediate");
+ std::string ImmKey = Vocab.getStringKey(ImmIndex);
+ EXPECT_EQ(ImmKey, "Immediate");
+
+ unsigned PhyRegIndex = Vocab.getCanonicalIndexForRegisterClass("GR32", true);
+ std::string PhyRegKey = Vocab.getStringKey(PhyRegIndex);
+ EXPECT_EQ(PhyRegKey, "PhyReg_GR32");
+
+ unsigned VirtRegIndex =
+ Vocab.getCanonicalIndexForRegisterClass("GR32", false);
+ std::string VirtRegKey = Vocab.getStringKey(VirtRegIndex);
+ EXPECT_EQ(VirtRegKey, "VirtReg_GR32");
+}
+
+// Test vocabulary dimension consistency
+TEST_F(MIR2VecEmbeddingTestFixture, DimensionConsistency) {
+ auto VocabOrErr = createTestVocab({{"TEST", 1.0f}}, {}, {}, {}, 5);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
+
+ EXPECT_EQ(Vocab.getDimension(), 5u);
+
+ // All embeddings should have the same dimension
+ for (auto IT = Vocab.begin(); IT != Vocab.end(); ++IT)
+ EXPECT_EQ((*IT).size(), 5u);
+}
+
+// Test invalid register handling through machine instruction creation
+TEST_F(MIR2VecEmbeddingTestFixture, InvalidRegisterHandling) {
+ float MOVValue = 1.5f;
+ float ImmValue = 0.5f;
+ float PhyRegValue = 0.2f;
+ auto VocabOrErr = createTestVocab(
+ {{"MOV", MOVValue}}, {{"Immediate", ImmValue}},
+ {{"GR8_ABCD_H", PhyRegValue}, {"GR8_ABCD_L", PhyRegValue + 0.1f}}, {}, 3);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
+
+ MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
+ MF->push_back(MBB);
+
+ // Create a MOV instruction with actual operands including potential $noreg
+ // This tests the actual scenario where invalid registers are encountered
+ auto MovOpcode = findOpcodeByName("MOV32mr");
+ ASSERT_NE(MovOpcode, -1) << "MOV32mr opcode not found";
+ const MCInstrDesc &Desc = TII->get(MovOpcode);
+
+ // Use available physical registers from the target
+ unsigned BaseReg =
+ TRI->getNumRegs() > 1 ? 1 : 0; // First available physical register
+ unsigned ValueReg = TRI->getNumRegs() > 2 ? 2 : BaseReg;
+
+ // MOV32mr typically has: base, scale, index, displacement, segment, value
+ // Use the MachineInstrBuilder API properly
+ auto MovInst = BuildMI(*MBB, MBB->end(), DebugLoc(), Desc)
+ .addReg(BaseReg) // base
+ .addImm(1) // scale
+ .addReg(0) // index ($noreg)
+ .addImm(-4) // displacement
+ .addReg(0) // segment ($noreg)
+ .addReg(ValueReg); // value
+
+ auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+ ASSERT_TRUE(Embedder != nullptr);
+
+ // This should not crash even if the instruction has $noreg operands
+ auto InstEmb = Embedder->getMInstVector(*MovInst);
+ EXPECT_EQ(InstEmb.size(), 3u);
+
+ // Test the expected embedding value
+ Embedding ExpectedOpcodeContribution(3, MOVValue * mir2vec::OpcWeight);
+ auto ExpectedOperandContribution =
+ Embedding(3, PhyRegValue * mir2vec::RegOperandWeight) // Base
+ + Embedding(3, ImmValue * mir2vec::CommonOperandWeight) // Scale
+ + Embedding(3, 0.0f) // noreg
+ + Embedding(3, ImmValue * mir2vec::CommonOperandWeight) // displacement
+ + Embedding(3, 0.0f) // noreg
+ + Embedding(3, (PhyRegValue + 0.1f) * mir2vec::RegOperandWeight); // Value
+ auto ExpectedEmb = ExpectedOpcodeContribution + ExpectedOperandContribution;
+ EXPECT_TRUE(InstEmb.approximatelyEquals(ExpectedEmb))
+ << "MOV instruction embedding should match expected embedding";
+}
+
+// Test handling of both physical and virtual registers in an instruction
+TEST_F(MIR2VecEmbeddingTestFixture, PhysicalAndVirtualRegisterHandling) {
+ float MOVValue = 2.0f;
+ float ImmValue = 0.7f;
+ float PhyRegValue = 0.3f;
+ float VirtRegValue = 0.9f;
+
+ // Find GR32 register class
+ const TargetRegisterClass *GR32RC = nullptr;
+ for (unsigned i = 0; i < TRI->getNumRegClasses(); ++i) {
+ const TargetRegisterClass *RC = TRI->getRegClass(i);
+ if (std::string(TRI->getRegClassName(RC)) == "GR32") {
+ GR32RC = RC;
+ break;
+ }
+ }
+ ASSERT_TRUE(GR32RC != nullptr && GR32RC->isAllocatable())
+ << "No allocatable GR32 register class found";
+
+ // Get first available physical register from GR32
+ unsigned PhyReg = *GR32RC->begin();
+ // Create a virtual register of class GR32
+ unsigned VirtReg = MF->getRegInfo().createVirtualRegister(GR32RC);
+
+ // Create vocabulary with register class based keys
+ auto VocabOrErr =
+ createTestVocab({{"MOV", MOVValue}}, {{"Immediate", ImmValue}},
+ {{"GR32_AD", PhyRegValue}}, // GR32_AD is the minimal key
+ {{"GR32", VirtRegValue}}, 4);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
+
+ MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
+ MF->push_back(MBB);
+
+ // Create a MOV32rr instruction: MOV32rr dst, src
+ auto MovOpcode = findOpcodeByName("MOV32rr");
+ ASSERT_NE(MovOpcode, -1) << "MOV32rr opcode not found";
+ const MCInstrDesc &Desc = TII->get(MovOpcode);
+
+ // MOV32rr: dst (physical), src (virtual)
+ auto MovInst = BuildMI(*MBB, MBB->end(), DebugLoc(), Desc)
+ .addReg(PhyReg) // physical register destination
+ .addReg(VirtReg); // virtual register source
+
+ // Create embedder with virtual register support
+ auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+ ASSERT_TRUE(Embedder != nullptr);
+
+ // This should not crash and should produce a valid embedding
+ auto InstEmb = Embedder->getMInstVector(*MovInst);
+ EXPECT_EQ(InstEmb.size(), 4u);
+
+ // Test the expected embedding value
+ Embedding ExpectedOpcodeContribution(4, MOVValue * mir2vec::OpcWeight);
+ auto ExpectedOperandContribution =
+ Embedding(4, PhyRegValue * mir2vec::RegOperandWeight) // dst (physical)
+ + Embedding(4, VirtRegValue * mir2vec::RegOperandWeight); // src (virtual)
+ auto ExpectedEmb = ExpectedOpcodeContribution + ExpectedOperandContribution;
+ EXPECT_TRUE(InstEmb.approximatelyEquals(ExpectedEmb))
+ << "MOV32rr instruction embedding should match expected embedding";
+}
+
+// Test precise embedding calculation with known operands
+TEST_F(MIR2VecEmbeddingTestFixture, EmbeddingCalculation) {
+ auto VocabOrErr = createTestVocab({{"NOOP", 2.0f}}, {}, {}, {}, 2);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
+
+ MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
+ MF->push_back(MBB);
+
+ // Create a simple NOOP instruction (no operands)
+ auto NoopInst = createMachineInstr(*MBB, "NOOP");
+ ASSERT_TRUE(NoopInst != nullptr);
+
+ auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
+ ASSERT_TRUE(Embedder != nullptr);
+
+ // Get the instruction embedding
+ auto InstEmb = Embedder->getMInstVector(*NoopInst);
+ EXPECT_EQ(InstEmb.size(), 2u);
+
+ // For NOOP with no operands, the embedding should be exactly the opcode
+ // embedding
+ float ExpectedWeight = mir2vec::OpcWeight;
+ Embedding ExpectedEmb(2, 2.0f * ExpectedWeight);
+
+ EXPECT_TRUE(InstEmb.approximatelyEquals(ExpectedEmb))
+ << "NOOP instruction embedding should match opcode embedding";
+
+ // Verify individual components
+ EXPECT_FLOAT_EQ(InstEmb[0], 2.0f * ExpectedWeight);
+ EXPECT_FLOAT_EQ(InstEmb[1], 2.0f * ExpectedWeight);
+}
} // namespace
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 16b9979..aa56aaf 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -550,6 +550,31 @@ TEST_F(SelectionDAGPatternMatchTest, matchNode) {
EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_ConstInt(), m_Value())));
}
+TEST_F(SelectionDAGPatternMatchTest, matchSelectLike) {
+ SDLoc DL;
+ auto Int32VT = EVT::getIntegerVT(Context, 32);
+ auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
+
+ SDValue Cond = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 0, Int32VT);
+ SDValue TVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+ SDValue FVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
+
+ SDValue VCond = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 0, VInt32VT);
+ SDValue VTVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
+ SDValue VFVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
+
+ SDValue Select = DAG->getNode(ISD::SELECT, DL, Int32VT, Cond, TVal, FVal);
+ SDValue VSelect =
+ DAG->getNode(ISD::VSELECT, DL, Int32VT, VCond, VTVal, VFVal);
+
+ using namespace SDPatternMatch;
+ EXPECT_TRUE(sd_match(Select, m_SelectLike(m_Specific(Cond), m_Specific(TVal),
+ m_Specific(FVal))));
+ EXPECT_TRUE(
+ sd_match(VSelect, m_SelectLike(m_Specific(VCond), m_Specific(VTVal),
+ m_Specific(VFVal))));
+}
+
namespace {
struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
using SDPatternMatch::BasicMatchContext::BasicMatchContext;