aboutsummaryrefslogtreecommitdiff
path: root/llvm/unittests/CodeGen
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/unittests/CodeGen')
-rw-r--r--llvm/unittests/CodeGen/MIR2VecTest.cpp88
1 files changed, 55 insertions, 33 deletions
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index 01f2ead..d243d82 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -54,27 +54,32 @@ protected:
std::unique_ptr<TargetMachine> TM;
const TargetInstrInfo *TII;
+ static void SetUpTestCase() {
+ InitializeAllTargets();
+ InitializeAllTargetMCs();
+ }
+
void SetUp() override {
- LLVMInitializeX86TargetInfo();
- LLVMInitializeX86Target();
- LLVMInitializeX86TargetMC();
+ Triple TargetTriple("x86_64-unknown-linux-gnu");
+ std::string Error;
+ const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);
+ if (!T) {
+ GTEST_SKIP() << "x86_64-unknown-linux-gnu target triple not available; "
+ "Skipping test";
+ return;
+ }
Ctx = std::make_unique<LLVMContext>();
M = std::make_unique<Module>("test", *Ctx);
-
- // Set up X86 target
- Triple TargetTriple("x86_64-unknown-linux-gnu");
M->setTargetTriple(TargetTriple);
- std::string Error;
- const Target *TheTarget =
- TargetRegistry::lookupTarget(M->getTargetTriple(), Error);
- ASSERT_TRUE(TheTarget) << "Failed to lookup target: " << Error;
-
TargetOptions Options;
- TM = std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
- M->getTargetTriple(), "", "", Options, Reloc::Model::Static));
- ASSERT_TRUE(TM);
+ TM = std::unique_ptr<TargetMachine>(
+ T->createTargetMachine(TargetTriple, "", "", Options, std::nullopt));
+ if (!TM) {
+ GTEST_SKIP() << "Failed to create X86 target machine; Skipping test";
+ return;
+ }
// Create a dummy function to get subtarget info
FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false);
@@ -83,10 +88,22 @@ protected:
// Get the target instruction info
TII = TM->getSubtargetImpl(*F)->getInstrInfo();
- ASSERT_TRUE(TII);
+ if (!TII) {
+ GTEST_SKIP() << "Failed to get target instruction info; Skipping test";
+ return;
+ }
}
};
+// Function to find an opcode by name
+static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) {
+ for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
+ if (TII->getName(Opcode) == Name)
+ return Opcode;
+ }
+ return -1; // Not found
+}
+
TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
// Test that same base opcodes get same canonical indices
std::string BaseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri");
@@ -98,10 +115,10 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
// Create a MIRVocabulary instance to test the mapping
// Use a minimal MIRVocabulary to trigger canonical mapping construction
- VocabMap VM;
+ VocabMap VMap;
Embedding Val = Embedding(64, 1.0f);
- VM["ADD"] = Val;
- MIRVocabulary TestVocab(std::move(VM), TII);
+ VMap["ADD"] = Val;
+ MIRVocabulary TestVocab(std::move(VMap), TII);
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2);
@@ -132,9 +149,19 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
6880u); // X86 has >6880 unique base opcodes
// Check that the embeddings for opcodes not in the vocab are zero vectors
- EXPECT_TRUE(TestVocab[AddIndex].approximatelyEquals(Val));
- EXPECT_TRUE(TestVocab[SubIndex].approximatelyEquals(Embedding(64, 0.0f)));
- EXPECT_TRUE(TestVocab[MovIndex].approximatelyEquals(Embedding(64, 0.0f)));
+ int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr");
+ ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found";
+ EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val));
+
+ int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr");
+ ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found";
+ EXPECT_TRUE(
+ TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
+
+ int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr");
+ ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found";
+ EXPECT_TRUE(
+ TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
}
// Test deterministic mapping
@@ -144,9 +171,9 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Create a MIRVocabulary instance to test deterministic mapping
// Use a minimal MIRVocabulary to trigger canonical mapping construction
- VocabMap VM;
- VM["ADD"] = Embedding(64, 1.0f);
- MIRVocabulary TestVocab(std::move(VM), TII);
+ VocabMap VMap;
+ VMap["ADD"] = Embedding(64, 1.0f);
+ MIRVocabulary TestVocab(std::move(VMap), TII);
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
@@ -164,16 +191,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Test MIRVocabulary construction
TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
- // Test empty MIRVocabulary
- MIRVocabulary EmptyVocab;
- EXPECT_FALSE(EmptyVocab.isValid());
-
- // Test MIRVocabulary with embeddings via VocabMap
- VocabMap VM;
- VM["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
- VM["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
+ VocabMap VMap;
+ VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
+ VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
- MIRVocabulary Vocab(std::move(VM), TII);
+ MIRVocabulary Vocab(std::move(VMap), TII);
EXPECT_TRUE(Vocab.isValid());
EXPECT_EQ(Vocab.getDimension(), 128u);