diff options
author | S. VenkataKeerthy <31350914+svkeerthy@users.noreply.github.com> | 2025-05-29 13:35:29 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-05-29 13:35:29 -0700 |
commit | e74b45e0789ab1da1afab48dc2fe39e0ed7a326e (patch) | |
tree | ccd68a2475a27f4760887e42fc65fde605b7190a /llvm/unittests/Analysis | |
parent | a8c6a5017de7076f3011b0ddba6f224f7e1f93f3 (diff) | |
download | llvm-e74b45e0789ab1da1afab48dc2fe39e0ed7a326e.zip llvm-e74b45e0789ab1da1afab48dc2fe39e0ed7a326e.tar.gz llvm-e74b45e0789ab1da1afab48dc2fe39e0ed7a326e.tar.bz2 |
[IR2Vec] Adding unit tests (#141873)
This PR adds unit tests for IR2Vec
(Tracking issue - #141817)
Diffstat (limited to 'llvm/unittests/Analysis')
-rw-r--r-- | llvm/unittests/Analysis/CMakeLists.txt | 1 | ||||
-rw-r--r-- | llvm/unittests/Analysis/IR2VecTest.cpp | 243 |
2 files changed, 244 insertions, 0 deletions
diff --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt index 67f0b04..cd04a77 100644 --- a/llvm/unittests/Analysis/CMakeLists.txt +++ b/llvm/unittests/Analysis/CMakeLists.txt @@ -32,6 +32,7 @@ set(ANALYSIS_TEST_SOURCES GlobalsModRefTest.cpp FunctionPropertiesAnalysisTest.cpp InlineCostTest.cpp + IR2VecTest.cpp IRSimilarityIdentifierTest.cpp IVDescriptorsTest.cpp LastRunTrackingAnalysisTest.cpp diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp new file mode 100644 index 0000000..5fb4da9 --- /dev/null +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -0,0 +1,243 @@ +//===- IR2VecTest.cpp - Unit tests for IR2Vec -----------------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/IR2Vec.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Error.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include <map> +#include <vector> + +using namespace llvm; +using namespace ir2vec; +using namespace ::testing; + +namespace { + +class TestableEmbedder : public Embedder { +public: + TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim) + : Embedder(F, V, Dim) {} + void computeEmbeddings() const override {} + using Embedder::lookupVocab; + static void addVectors(Embedding &Dst, const Embedding &Src) { + Embedder::addVectors(Dst, Src); + } + static void addScaledVector(Embedding &Dst, const Embedding &Src, + float Factor) { + Embedder::addScaledVector(Dst, Src, Factor); + } +}; + +TEST(IR2VecTest, CreateSymbolicEmbedder) { + Vocab V = {{"foo", {1.0, 2.0}}}; + + LLVMContext Ctx; + Module M("M", Ctx); + FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false); + Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M); + + auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2); + EXPECT_TRUE(static_cast<bool>(Result)); + + auto *Emb = Result->get(); + EXPECT_NE(Emb, nullptr); +} + +TEST(IR2VecTest, CreateInvalidMode) { + Vocab V = {{"foo", {1.0, 2.0}}}; + + LLVMContext Ctx; + Module M("M", Ctx); + FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false); + Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M); + + // static_cast an invalid int to IR2VecKind + auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V, 2); + EXPECT_FALSE(static_cast<bool>(Result)); + + std::string ErrMsg; + llvm::handleAllErrors( + Result.takeError(), + [&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); }); + EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos); +} + +TEST(IR2VecTest, AddVectors) { + Embedding E1 = {1.0, 2.0, 3.0}; + Embedding E2 = {0.5, 1.5, -1.0}; + + TestableEmbedder::addVectors(E1, E2); + EXPECT_THAT(E1, ElementsAre(1.5, 3.5, 2.0)); + + // Check that E2 is unchanged + EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0)); +} + +TEST(IR2VecTest, AddScaledVector) { + Embedding E1 = {1.0, 2.0, 3.0}; + Embedding E2 = {2.0, 0.5, -1.0}; + + TestableEmbedder::addScaledVector(E1, E2, 0.5f); + EXPECT_THAT(E1, ElementsAre(2.0, 2.25, 2.5)); + + // Check that E2 is unchanged + EXPECT_THAT(E2, ElementsAre(2.0, 0.5, -1.0)); +} + +#if GTEST_HAS_DEATH_TEST +#ifndef NDEBUG +TEST(IR2VecTest, MismatchedDimensionsAddVectors) { + Embedding E1 = {1.0, 2.0}; + Embedding E2 = {1.0}; + EXPECT_DEATH(TestableEmbedder::addVectors(E1, E2), + "Vectors must have the same dimension"); +} + +TEST(IR2VecTest, MismatchedDimensionsAddScaledVector) { + Embedding E1 = {1.0, 2.0}; + Embedding E2 = {1.0}; + EXPECT_DEATH(TestableEmbedder::addScaledVector(E1, E2, 1.0f), + "Vectors must have the same dimension"); +} +#endif // NDEBUG +#endif // GTEST_HAS_DEATH_TEST + +TEST(IR2VecTest, LookupVocab) { + Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}}; + LLVMContext Ctx; + Module M("M", Ctx); + FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false); + Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M); + + TestableEmbedder E(*F, V, 2); + auto V_foo = E.lookupVocab("foo"); + EXPECT_EQ(V_foo.size(), 2u); + EXPECT_THAT(V_foo, ElementsAre(1.0, 2.0)); + + auto V_missing = E.lookupVocab("missing"); + EXPECT_EQ(V_missing.size(), 2u); + EXPECT_THAT(V_missing, ElementsAre(0.0, 0.0)); +} + +TEST(IR2VecTest, ZeroDimensionEmbedding) { + Embedding E1; + Embedding E2; + // Should be no-op, but not crash + TestableEmbedder::addVectors(E1, E2); + TestableEmbedder::addScaledVector(E1, E2, 1.0f); + EXPECT_TRUE(E1.empty()); +} + +TEST(IR2VecTest, IR2VecVocabResultValidity) { + // Default constructed is invalid + IR2VecVocabResult invalidResult; + EXPECT_FALSE(invalidResult.isValid()); +#if GTEST_HAS_DEATH_TEST +#ifndef NDEBUG + EXPECT_DEATH(invalidResult.getVocabulary(), "IR2Vec Vocabulary is invalid"); + EXPECT_DEATH(invalidResult.getDimension(), "IR2Vec Vocabulary is invalid"); +#endif // NDEBUG +#endif // GTEST_HAS_DEATH_TEST + + // Valid vocab + Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}}; + IR2VecVocabResult validResult(std::move(V)); + EXPECT_TRUE(validResult.isValid()); + EXPECT_EQ(validResult.getDimension(), 2u); +} + +// Helper to create a minimal function and embedder for getter tests +struct GetterTestEnv { + Vocab V = {}; + LLVMContext Ctx; + std::unique_ptr<Module> M = nullptr; + Function *F = nullptr; + BasicBlock *BB = nullptr; + Instruction *Add = nullptr; + Instruction *Ret = nullptr; + std::unique_ptr<Embedder> Emb = nullptr; + + GetterTestEnv() { + V = {{"add", {1.0, 2.0}}, + {"integerTy", {0.5, 0.5}}, + {"constant", {0.2, 0.3}}, + {"variable", {0.0, 0.0}}, + {"unknownTy", {0.0, 0.0}}}; + + M = std::make_unique<Module>("M", Ctx); + FunctionType *FTy = FunctionType::get( + Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)}, + false); + F = Function::Create(FTy, Function::ExternalLinkage, "f", M.get()); + BB = BasicBlock::Create(Ctx, "entry", F); + Argument *Arg = F->getArg(0); + llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42); + + Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB); + Ret = ReturnInst::Create(Ctx, Add, BB); + + auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2); + EXPECT_TRUE(static_cast<bool>(Result)); + Emb = std::move(*Result); + } +}; + +TEST(IR2VecTest, GetInstVecMap) { + GetterTestEnv Env; + const auto &InstMap = Env.Emb->getInstVecMap(); + + EXPECT_EQ(InstMap.size(), 2u); + EXPECT_TRUE(InstMap.count(Env.Add)); + EXPECT_TRUE(InstMap.count(Env.Ret)); + + EXPECT_EQ(InstMap.at(Env.Add).size(), 2u); + EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u); + + // Check values for add: {1.29, 2.31} + EXPECT_THAT(InstMap.at(Env.Add), + ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6))); + + // Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in + // vocab + EXPECT_THAT(InstMap.at(Env.Ret), ElementsAre(0.0, 0.0)); +} + +TEST(IR2VecTest, GetBBVecMap) { + GetterTestEnv Env; + const auto &BBMap = Env.Emb->getBBVecMap(); + + EXPECT_EQ(BBMap.size(), 1u); + EXPECT_TRUE(BBMap.count(Env.BB)); + EXPECT_EQ(BBMap.at(Env.BB).size(), 2u); + + // BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} = + // {1.29, 2.31} + EXPECT_THAT(BBMap.at(Env.BB), + ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6))); +} + +TEST(IR2VecTest, GetFunctionVector) { + GetterTestEnv Env; + const auto &FuncVec = Env.Emb->getFunctionVector(); + + EXPECT_EQ(FuncVec.size(), 2u); + + // Function vector should match BB vector (only one BB): {1.29, 2.31} + EXPECT_THAT(FuncVec, + ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6))); +} + +} // end anonymous namespace |