diff options
Diffstat (limited to 'llvm/unittests/CodeGen/MIR2VecTest.cpp')
-rw-r--r-- | llvm/unittests/CodeGen/MIR2VecTest.cpp | 41 |
1 files changed, 35 insertions, 6 deletions
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp index d243d82..11222b4 100644 --- a/llvm/unittests/CodeGen/MIR2VecTest.cpp +++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp @@ -17,6 +17,7 @@ #include "llvm/IR/Module.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" #include "llvm/TargetParser/Triple.h" @@ -52,7 +53,7 @@ protected: std::unique_ptr<LLVMContext> Ctx; std::unique_ptr<Module> M; std::unique_ptr<TargetMachine> TM; - const TargetInstrInfo *TII; + const TargetInstrInfo *TII = nullptr; static void SetUpTestCase() { InitializeAllTargets(); @@ -93,6 +94,8 @@ protected: return; } } + + void TearDown() override { TII = nullptr; } }; // Function to find an opcode by name @@ -118,7 +121,11 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) { VocabMap VMap; Embedding Val = Embedding(64, 1.0f); VMap["ADD"] = Val; - MIRVocabulary TestVocab(std::move(VMap), TII); + auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII); + ASSERT_TRUE(static_cast<bool>(TestVocabOrErr)) + << "Failed to create vocabulary: " + << toString(TestVocabOrErr.takeError()); + auto &TestVocab = *TestVocabOrErr; unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1); unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2); @@ -173,7 +180,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) { // Use a minimal MIRVocabulary to trigger canonical mapping construction VocabMap VMap; VMap["ADD"] = Embedding(64, 1.0f); - MIRVocabulary TestVocab(std::move(VMap), TII); + auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII); + ASSERT_TRUE(static_cast<bool>(TestVocabOrErr)) + << "Failed to create vocabulary: " + << toString(TestVocabOrErr.takeError()); + auto &TestVocab = *TestVocabOrErr; unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName); unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName); @@ -195,8 +206,10 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) { VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0 VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0 - MIRVocabulary Vocab(std::move(VMap), TII); - EXPECT_TRUE(Vocab.isValid()); + auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII); + ASSERT_TRUE(static_cast<bool>(VocabOrErr)) + << "Failed to create vocabulary: " << toString(VocabOrErr.takeError()); + auto &Vocab = *VocabOrErr; EXPECT_EQ(Vocab.getDimension(), 128u); // Test iterator - iterates over individual embeddings @@ -214,4 +227,20 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) { EXPECT_GT(Count, 0u); } -} // namespace
\ No newline at end of file +// Test factory method with empty vocabulary +TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) { + VocabMap EmptyVMap; + + auto VocabOrErr = MIRVocabulary::create(std::move(EmptyVMap), *TII); + EXPECT_FALSE(static_cast<bool>(VocabOrErr)) + << "Factory method should fail with empty vocabulary"; + + // Consume the error + if (!VocabOrErr) { + auto Err = VocabOrErr.takeError(); + std::string ErrorMsg = toString(std::move(Err)); + EXPECT_FALSE(ErrorMsg.empty()); + } +} + +} // namespace |