aboutsummaryrefslogtreecommitdiff
path: root/llvm/unittests/CodeGen/MIR2VecTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/unittests/CodeGen/MIR2VecTest.cpp')
-rw-r--r--llvm/unittests/CodeGen/MIR2VecTest.cpp41
1 files changed, 35 insertions, 6 deletions
diff --git a/llvm/unittests/CodeGen/MIR2VecTest.cpp b/llvm/unittests/CodeGen/MIR2VecTest.cpp
index d243d82..11222b4 100644
--- a/llvm/unittests/CodeGen/MIR2VecTest.cpp
+++ b/llvm/unittests/CodeGen/MIR2VecTest.cpp
@@ -17,6 +17,7 @@
#include "llvm/IR/Module.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/TargetParser/Triple.h"
@@ -52,7 +53,7 @@ protected:
std::unique_ptr<LLVMContext> Ctx;
std::unique_ptr<Module> M;
std::unique_ptr<TargetMachine> TM;
- const TargetInstrInfo *TII;
+ const TargetInstrInfo *TII = nullptr;
static void SetUpTestCase() {
InitializeAllTargets();
@@ -93,6 +94,8 @@ protected:
return;
}
}
+
+ void TearDown() override { TII = nullptr; }
};
// Function to find an opcode by name
@@ -118,7 +121,11 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
VocabMap VMap;
Embedding Val = Embedding(64, 1.0f);
VMap["ADD"] = Val;
- MIRVocabulary TestVocab(std::move(VMap), TII);
+ auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
+ << "Failed to create vocabulary: "
+ << toString(TestVocabOrErr.takeError());
+ auto &TestVocab = *TestVocabOrErr;
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2);
@@ -173,7 +180,11 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Use a minimal MIRVocabulary to trigger canonical mapping construction
VocabMap VMap;
VMap["ADD"] = Embedding(64, 1.0f);
- MIRVocabulary TestVocab(std::move(VMap), TII);
+ auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
+ << "Failed to create vocabulary: "
+ << toString(TestVocabOrErr.takeError());
+ auto &TestVocab = *TestVocabOrErr;
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
@@ -195,8 +206,10 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
- MIRVocabulary Vocab(std::move(VMap), TII);
- EXPECT_TRUE(Vocab.isValid());
+ auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
+ ASSERT_TRUE(static_cast<bool>(VocabOrErr))
+ << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
+ auto &Vocab = *VocabOrErr;
EXPECT_EQ(Vocab.getDimension(), 128u);
// Test iterator - iterates over individual embeddings
@@ -214,4 +227,20 @@ TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
EXPECT_GT(Count, 0u);
}
-} // namespace \ No newline at end of file
+// Test factory method with empty vocabulary
+TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
+ VocabMap EmptyVMap;
+
+ auto VocabOrErr = MIRVocabulary::create(std::move(EmptyVMap), *TII);
+ EXPECT_FALSE(static_cast<bool>(VocabOrErr))
+ << "Factory method should fail with empty vocabulary";
+
+ // Consume the error
+ if (!VocabOrErr) {
+ auto Err = VocabOrErr.takeError();
+ std::string ErrorMsg = toString(std::move(Err));
+ EXPECT_FALSE(ErrorMsg.empty());
+ }
+}
+
+} // namespace