aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsvkeerthy <venkatakeerthy@google.com>2025-09-03 22:56:08 +0000
committersvkeerthy <venkatakeerthy@google.com>2025-09-08 23:22:27 +0000
commit52875ac0a149d77682fb8ba62c07fb1dcc132737 (patch)
treeae24115b05720cdeb6b23a6cfcf212d331248f71
parent69f1aebf2017f4362347ec9890ee4b57c3b989f1 (diff)
downloadllvm-users/svkeerthy/09-03-support_predicates.zip
llvm-users/svkeerthy/09-03-support_predicates.tar.gz
llvm-users/svkeerthy/09-03-support_predicates.tar.bz2
-rw-r--r--llvm/include/llvm/Analysis/IR2Vec.h48
-rw-r--r--llvm/lib/Analysis/IR2Vec.cpp76
-rw-r--r--llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json28
-rw-r--r--llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json28
-rw-r--r--llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json29
-rw-r--r--llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt26
-rw-r--r--llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt26
-rw-r--r--llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt26
-rw-r--r--llvm/test/Analysis/IR2Vec/if-else.ll2
-rw-r--r--llvm/test/Analysis/IR2Vec/unreachable.ll2
-rw-r--r--llvm/test/tools/llvm-ir2vec/entities.ll28
-rw-r--r--llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp2
-rw-r--r--llvm/unittests/Analysis/IR2VecTest.cpp47
13 files changed, 344 insertions, 24 deletions
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 3671c1c..4a6db5d 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -36,6 +36,7 @@
#define LLVM_ANALYSIS_IR2VEC_H
#include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/CommandLine.h"
@@ -162,15 +163,29 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
/// embeddings.
class Vocabulary {
friend class llvm::IR2VecVocabAnalysis;
+
+ // Vocabulary Slot Layout:
+ // +----------------+------------------------------------------------------+
+ // | Entity Type | Index Range |
+ // +----------------+------------------------------------------------------+
+ // | Opcodes | [0 .. (MaxOpcodes-1)] |
+ // | Canonical Types| [MaxOpcodes .. (MaxOpcodes+MaxCanonicalTypeIDs-1)] |
+ // | Operands | [(MaxOpcodes+MaxCanonicalTypeIDs) .. NumCanEntries] |
+ // +----------------+------------------------------------------------------+
+ // Note: "Similar" LLVM Types are grouped/canonicalized together.
+ // Operands include Comparison predicates (ICmp/FCmp).
+ // This can be extended to include other specializations in future.
using VocabVector = std::vector<ir2vec::Embedding>;
VocabVector Vocab;
-public:
- // Slot layout:
- // [0 .. MaxOpcodes-1] => Instruction opcodes
- // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
- // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
+ static constexpr unsigned NumICmpPredicates =
+ static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
+ static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + 1;
+ static constexpr unsigned NumFCmpPredicates =
+ static_cast<unsigned>(CmpInst::LAST_FCMP_PREDICATE) -
+ static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + 1;
+public:
/// Canonical type IDs supported by IR2Vec Vocabulary
enum class CanonicalTypeID : unsigned {
FloatTy,
@@ -207,13 +222,18 @@ public:
static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
static constexpr unsigned MaxOperandKinds =
static_cast<unsigned>(OperandKind::MaxOperandKind);
+ // CmpInst::Predicate has gaps. We want the vocabulary to be dense without
+ // empty slots.
+ static constexpr unsigned MaxPredicateKinds =
+ NumICmpPredicates + NumFCmpPredicates;
Vocabulary() = default;
LLVM_ABI Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)) {}
LLVM_ABI bool isValid() const { return Vocab.size() == NumCanonicalEntries; };
LLVM_ABI unsigned getDimension() const;
- /// Total number of entries (opcodes + canonicalized types + operand kinds)
+ /// Total number of entries (opcodes + canonicalized types + operand kinds +
+ /// predicates)
static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
/// Function to get vocabulary key for a given Opcode
@@ -228,16 +248,21 @@ public:
/// Function to classify an operand into OperandKind
LLVM_ABI static OperandKind getOperandKind(const Value *Op);
+ /// Function to get vocabulary key for a given predicate
+ LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
+
/// Functions to return the slot index or position of a given Opcode, TypeID,
/// or OperandKind in the vocabulary.
LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
LLVM_ABI static unsigned getSlotIndex(const Value &Op);
+ LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);
/// Accessors to get the embedding for a given entity.
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
+ LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
/// Const Iterator type aliases
using const_iterator = VocabVector::const_iterator;
@@ -274,7 +299,13 @@ public:
private:
constexpr static unsigned NumCanonicalEntries =
- MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
+ MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
+
+ // Base offsets for slot layout to simplify index computation
+ constexpr static unsigned OperandBaseOffset =
+ MaxOpcodes + MaxCanonicalTypeIDs;
+ constexpr static unsigned PredicateBaseOffset =
+ OperandBaseOffset + MaxOperandKinds;
/// String mappings for CanonicalTypeID values
static constexpr StringLiteral CanonicalTypeNames[] = {
@@ -326,6 +357,9 @@ private:
/// Function to convert TypeID to CanonicalTypeID
LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
+
+ /// Function to get the predicate enum value for a given index
+ LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index);
};
/// Embedder provides the interface to generate embeddings (vector
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 99afc06..f51f089 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -216,6 +216,8 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
ArgEmb += Vocab[*Op];
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+ if (const auto *IC = dyn_cast<CmpInst>(&I))
+ InstVector += Vocab[IC->getPredicate()];
InstVecMap[&I] = InstVector;
BBVector += InstVector;
}
@@ -250,6 +252,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
// embeddings
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+ // Add compare predicate embedding as an additional operand if applicable
+ if (const auto *IC = dyn_cast<CmpInst>(&I))
+ InstVector += Vocab[IC->getPredicate()];
InstVecMap[&I] = InstVector;
BBVector += InstVector;
}
@@ -278,7 +283,17 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
unsigned Vocabulary::getSlotIndex(const Value &Op) {
unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
assert(Index < MaxOperandKinds && "Invalid OperandKind");
- return MaxOpcodes + MaxCanonicalTypeIDs + Index;
+ return OperandBaseOffset + Index;
+}
+
+unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) {
+ unsigned PU = static_cast<unsigned>(P);
+ unsigned FirstFC = static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
+ unsigned FirstIC = static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);
+
+ unsigned PredIdx =
+ (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC);
+ return PredicateBaseOffset + PredIdx;
}
const Embedding &Vocabulary::operator[](unsigned Opcode) const {
@@ -293,6 +308,10 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
return Vocab[getSlotIndex(Arg)];
}
+const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const {
+ return Vocab[getSlotIndex(P)];
+}
+
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
#define HANDLE_INST(NUM, OPCODE, CLASS) \
@@ -338,18 +357,41 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
return OperandKind::VariableID;
}
+CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
+ assert(Index < MaxPredicateKinds && "Invalid predicate index");
+ unsigned PredEnumVal =
+ (Index < NumFCmpPredicates)
+ ? (static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + Index)
+ : (static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) +
+ (Index - NumFCmpPredicates));
+ return static_cast<CmpInst::Predicate>(PredEnumVal);
+}
+
+StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
+ static SmallString<16> PredNameBuffer;
+ if (Pred < CmpInst::FIRST_ICMP_PREDICATE)
+ PredNameBuffer = "FCMP_";
+ else
+ PredNameBuffer = "ICMP_";
+ PredNameBuffer += CmpInst::getPredicateName(Pred);
+ return PredNameBuffer;
+}
+
StringRef Vocabulary::getStringKey(unsigned Pos) {
assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
// Opcode
if (Pos < MaxOpcodes)
return getVocabKeyForOpcode(Pos + 1);
// Type
- if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
+ if (Pos < OperandBaseOffset)
return getVocabKeyForCanonicalTypeID(
static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
// Operand
- return getVocabKeyForOperandKind(
- static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
+ if (Pos < PredicateBaseOffset)
+ return getVocabKeyForOperandKind(
+ static_cast<OperandKind>(Pos - OperandBaseOffset));
+ // Predicates
+ return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset));
}
// For now, assume vocabulary is stable unless explicitly invalidated.
@@ -363,11 +405,9 @@ Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
VocabVector DummyVocab;
DummyVocab.reserve(NumCanonicalEntries);
float DummyVal = 0.1f;
- // Create a dummy vocabulary with entries for all opcodes, types, and
- // operands
- for ([[maybe_unused]] unsigned _ :
- seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
- Vocabulary::MaxOperandKinds)) {
+ // Create a dummy vocabulary with entries for all opcodes, types, operands
+ // and predicates
+ for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) {
DummyVocab.push_back(Embedding(Dim, DummyVal));
DummyVal += 0.1f;
}
@@ -510,6 +550,24 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
}
Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
NumericArgEmbeddings.end());
+
+ // Handle Predicates: part of Operands section. We look up predicate keys
+ // in ArgVocab.
+ std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
+ Embedding(Dim, 0));
+ NumericPredEmbeddings.reserve(Vocabulary::MaxPredicateKinds);
+ for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
+ StringRef VocabKey =
+ Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
+ auto It = ArgVocab.find(VocabKey.str());
+ if (It != ArgVocab.end()) {
+ NumericPredEmbeddings[PK] = It->second;
+ continue;
+ }
+ handleMissingEntity(VocabKey.str());
+ }
+ Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(),
+ NumericPredEmbeddings.end());
}
IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
index 07fde84..ae36ff5 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
@@ -87,6 +87,32 @@
"Function": [1, 2],
"Pointer": [3, 4],
"Constant": [5, 6],
- "Variable": [7, 8]
+ "Variable": [7, 8],
+ "FCMP_false": [9, 10],
+ "FCMP_oeq": [11, 12],
+ "FCMP_ogt": [13, 14],
+ "FCMP_oge": [15, 16],
+ "FCMP_olt": [17, 18],
+ "FCMP_ole": [19, 20],
+ "FCMP_one": [21, 22],
+ "FCMP_ord": [23, 24],
+ "FCMP_uno": [25, 26],
+ "FCMP_ueq": [27, 28],
+ "FCMP_ugt": [29, 30],
+ "FCMP_uge": [31, 32],
+ "FCMP_ult": [33, 34],
+ "FCMP_ule": [35, 36],
+ "FCMP_une": [37, 38],
+ "FCMP_true": [39, 40],
+ "ICMP_eq": [41, 42],
+ "ICMP_ne": [43, 44],
+ "ICMP_ugt": [45, 46],
+ "ICMP_uge": [47, 48],
+ "ICMP_ult": [49, 50],
+ "ICMP_ule": [51, 52],
+ "ICMP_sgt": [53, 54],
+ "ICMP_sge": [55, 56],
+ "ICMP_slt": [57, 58],
+ "ICMP_sle": [59, 60]
}
}
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
index 932b3a2..9003dc7 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
@@ -86,6 +86,32 @@
"Function": [1, 2, 3],
"Pointer": [4, 5, 6],
"Constant": [7, 8, 9],
- "Variable": [10, 11, 12]
+ "Variable": [10, 11, 12],
+ "FCMP_false": [13, 14, 15],
+ "FCMP_oeq": [16, 17, 18],
+ "FCMP_ogt": [19, 20, 21],
+ "FCMP_oge": [22, 23, 24],
+ "FCMP_olt": [25, 26, 27],
+ "FCMP_ole": [28, 29, 30],
+ "FCMP_one": [31, 32, 33],
+ "FCMP_ord": [34, 35, 36],
+ "FCMP_uno": [37, 38, 39],
+ "FCMP_ueq": [40, 41, 42],
+ "FCMP_ugt": [43, 44, 45],
+ "FCMP_uge": [46, 47, 48],
+ "FCMP_ult": [49, 50, 51],
+ "FCMP_ule": [52, 53, 54],
+ "FCMP_une": [55, 56, 57],
+ "FCMP_true": [58, 59, 60],
+ "ICMP_eq": [61, 62, 63],
+ "ICMP_ne": [64, 65, 66],
+ "ICMP_ugt": [67, 68, 69],
+ "ICMP_uge": [70, 71, 72],
+ "ICMP_ult": [73, 74, 75],
+ "ICMP_ule": [76, 77, 78],
+ "ICMP_sgt": [79, 80, 81],
+ "ICMP_sge": [82, 83, 84],
+ "ICMP_slt": [85, 86, 87],
+ "ICMP_sle": [88, 89, 90]
}
}
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
index 19f3efe..7ef8549 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
@@ -47,6 +47,7 @@
"FPTrunc": [133, 134, 135],
"FPExt": [136, 137, 138],
"PtrToInt": [139, 140, 141],
+ "PtrToAddr": [202, 203, 204],
"IntToPtr": [142, 143, 144],
"BitCast": [145, 146, 147],
"AddrSpaceCast": [148, 149, 150],
@@ -86,6 +87,32 @@
"Function": [0, 0, 0],
"Pointer": [0, 0, 0],
"Constant": [0, 0, 0],
- "Variable": [0, 0, 0]
+ "Variable": [0, 0, 0],
+ "FCMP_false": [0, 0, 0],
+ "FCMP_oeq": [0, 0, 0],
+ "FCMP_ogt": [0, 0, 0],
+ "FCMP_oge": [0, 0, 0],
+ "FCMP_olt": [0, 0, 0],
+ "FCMP_ole": [0, 0, 0],
+ "FCMP_one": [0, 0, 0],
+ "FCMP_ord": [0, 0, 0],
+ "FCMP_uno": [0, 0, 0],
+ "FCMP_ueq": [0, 0, 0],
+ "FCMP_ugt": [0, 0, 0],
+ "FCMP_uge": [0, 0, 0],
+ "FCMP_ult": [0, 0, 0],
+ "FCMP_ule": [0, 0, 0],
+ "FCMP_une": [0, 0, 0],
+ "FCMP_true": [0, 0, 0],
+ "ICMP_eq": [0, 0, 0],
+ "ICMP_ne": [0, 0, 0],
+ "ICMP_ugt": [0, 0, 0],
+ "ICMP_uge": [0, 0, 0],
+ "ICMP_ult": [0, 0, 0],
+ "ICMP_ule": [0, 0, 0],
+ "ICMP_sgt": [1, 1, 1],
+ "ICMP_sge": [0, 0, 0],
+ "ICMP_slt": [0, 0, 0],
+ "ICMP_sle": [0, 0, 0]
}
}
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
index df7769c..d62b0dd 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function: [ 0.20 0.40 ]
Key: Pointer: [ 0.60 0.80 ]
Key: Constant: [ 1.00 1.20 ]
Key: Variable: [ 1.40 1.60 ]
+Key: FCMP_false: [ 1.80 2.00 ]
+Key: FCMP_oeq: [ 2.20 2.40 ]
+Key: FCMP_ogt: [ 2.60 2.80 ]
+Key: FCMP_oge: [ 3.00 3.20 ]
+Key: FCMP_olt: [ 3.40 3.60 ]
+Key: FCMP_ole: [ 3.80 4.00 ]
+Key: FCMP_one: [ 4.20 4.40 ]
+Key: FCMP_ord: [ 4.60 4.80 ]
+Key: FCMP_uno: [ 5.00 5.20 ]
+Key: FCMP_ueq: [ 5.40 5.60 ]
+Key: FCMP_ugt: [ 5.80 6.00 ]
+Key: FCMP_uge: [ 6.20 6.40 ]
+Key: FCMP_ult: [ 6.60 6.80 ]
+Key: FCMP_ule: [ 7.00 7.20 ]
+Key: FCMP_une: [ 7.40 7.60 ]
+Key: FCMP_true: [ 7.80 8.00 ]
+Key: ICMP_eq: [ 8.20 8.40 ]
+Key: ICMP_ne: [ 8.60 8.80 ]
+Key: ICMP_ugt: [ 9.00 9.20 ]
+Key: ICMP_uge: [ 9.40 9.60 ]
+Key: ICMP_ult: [ 9.80 10.00 ]
+Key: ICMP_ule: [ 10.20 10.40 ]
+Key: ICMP_sgt: [ 10.60 10.80 ]
+Key: ICMP_sge: [ 11.00 11.20 ]
+Key: ICMP_slt: [ 11.40 11.60 ]
+Key: ICMP_sle: [ 11.80 12.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
index f3ce809..e443adb 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function: [ 0.50 1.00 ]
Key: Pointer: [ 1.50 2.00 ]
Key: Constant: [ 2.50 3.00 ]
Key: Variable: [ 3.50 4.00 ]
+Key: FCMP_false: [ 4.50 5.00 ]
+Key: FCMP_oeq: [ 5.50 6.00 ]
+Key: FCMP_ogt: [ 6.50 7.00 ]
+Key: FCMP_oge: [ 7.50 8.00 ]
+Key: FCMP_olt: [ 8.50 9.00 ]
+Key: FCMP_ole: [ 9.50 10.00 ]
+Key: FCMP_one: [ 10.50 11.00 ]
+Key: FCMP_ord: [ 11.50 12.00 ]
+Key: FCMP_uno: [ 12.50 13.00 ]
+Key: FCMP_ueq: [ 13.50 14.00 ]
+Key: FCMP_ugt: [ 14.50 15.00 ]
+Key: FCMP_uge: [ 15.50 16.00 ]
+Key: FCMP_ult: [ 16.50 17.00 ]
+Key: FCMP_ule: [ 17.50 18.00 ]
+Key: FCMP_une: [ 18.50 19.00 ]
+Key: FCMP_true: [ 19.50 20.00 ]
+Key: ICMP_eq: [ 20.50 21.00 ]
+Key: ICMP_ne: [ 21.50 22.00 ]
+Key: ICMP_ugt: [ 22.50 23.00 ]
+Key: ICMP_uge: [ 23.50 24.00 ]
+Key: ICMP_ult: [ 24.50 25.00 ]
+Key: ICMP_ule: [ 25.50 26.00 ]
+Key: ICMP_sgt: [ 26.50 27.00 ]
+Key: ICMP_sge: [ 27.50 28.00 ]
+Key: ICMP_slt: [ 28.50 29.00 ]
+Key: ICMP_sle: [ 29.50 30.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
index 72b25b9..7fb6043 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function: [ 0.00 0.00 ]
Key: Pointer: [ 0.00 0.00 ]
Key: Constant: [ 0.00 0.00 ]
Key: Variable: [ 0.00 0.00 ]
+Key: FCMP_false: [ 0.00 0.00 ]
+Key: FCMP_oeq: [ 0.00 0.00 ]
+Key: FCMP_ogt: [ 0.00 0.00 ]
+Key: FCMP_oge: [ 0.00 0.00 ]
+Key: FCMP_olt: [ 0.00 0.00 ]
+Key: FCMP_ole: [ 0.00 0.00 ]
+Key: FCMP_one: [ 0.00 0.00 ]
+Key: FCMP_ord: [ 0.00 0.00 ]
+Key: FCMP_uno: [ 0.00 0.00 ]
+Key: FCMP_ueq: [ 0.00 0.00 ]
+Key: FCMP_ugt: [ 0.00 0.00 ]
+Key: FCMP_uge: [ 0.00 0.00 ]
+Key: FCMP_ult: [ 0.00 0.00 ]
+Key: FCMP_ule: [ 0.00 0.00 ]
+Key: FCMP_une: [ 0.00 0.00 ]
+Key: FCMP_true: [ 0.00 0.00 ]
+Key: ICMP_eq: [ 0.00 0.00 ]
+Key: ICMP_ne: [ 0.00 0.00 ]
+Key: ICMP_ugt: [ 0.00 0.00 ]
+Key: ICMP_uge: [ 0.00 0.00 ]
+Key: ICMP_ult: [ 0.00 0.00 ]
+Key: ICMP_ule: [ 0.00 0.00 ]
+Key: ICMP_sgt: [ 0.00 0.00 ]
+Key: ICMP_sge: [ 0.00 0.00 ]
+Key: ICMP_slt: [ 0.00 0.00 ]
+Key: ICMP_sle: [ 0.00 0.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/if-else.ll b/llvm/test/Analysis/IR2Vec/if-else.ll
index fe53247..804c1ca 100644
--- a/llvm/test/Analysis/IR2Vec/if-else.ll
+++ b/llvm/test/Analysis/IR2Vec/if-else.ll
@@ -29,7 +29,7 @@ return: ; preds = %if.else, %if.then
; CHECK: Basic block vectors:
; CHECK-NEXT: Basic block: entry:
-; CHECK-NEXT: [ 816.00 825.00 834.00 ]
+; CHECK-NEXT: [ 816.20 825.20 834.20 ]
; CHECK-NEXT: Basic block: if.then:
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
; CHECK-NEXT: Basic block: if.else:
diff --git a/llvm/test/Analysis/IR2Vec/unreachable.ll b/llvm/test/Analysis/IR2Vec/unreachable.ll
index b0e3e49..9be0ee1 100644
--- a/llvm/test/Analysis/IR2Vec/unreachable.ll
+++ b/llvm/test/Analysis/IR2Vec/unreachable.ll
@@ -33,7 +33,7 @@ return: ; preds = %if.else, %if.then
; CHECK: Basic block vectors:
; CHECK-NEXT: Basic block: entry:
-; CHECK-NEXT: [ 816.00 825.00 834.00 ]
+; CHECK-NEXT: [ 816.20 825.20 834.20 ]
; CHECK-NEXT: Basic block: if.then:
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
; CHECK-NEXT: Basic block: if.else:
diff --git a/llvm/test/tools/llvm-ir2vec/entities.ll b/llvm/test/tools/llvm-ir2vec/entities.ll
index 4b51adf..8dbce57 100644
--- a/llvm/test/tools/llvm-ir2vec/entities.ll
+++ b/llvm/test/tools/llvm-ir2vec/entities.ll
@@ -1,6 +1,6 @@
; RUN: llvm-ir2vec entities | FileCheck %s
-CHECK: 84
+CHECK: 110
CHECK-NEXT: Ret 0
CHECK-NEXT: Br 1
CHECK-NEXT: Switch 2
@@ -85,3 +85,29 @@ CHECK-NEXT: Function 80
CHECK-NEXT: Pointer 81
CHECK-NEXT: Constant 82
CHECK-NEXT: Variable 83
+CHECK-NEXT: FCMP_false 84
+CHECK-NEXT: FCMP_oeq 85
+CHECK-NEXT: FCMP_ogt 86
+CHECK-NEXT: FCMP_oge 87
+CHECK-NEXT: FCMP_olt 88
+CHECK-NEXT: FCMP_ole 89
+CHECK-NEXT: FCMP_one 90
+CHECK-NEXT: FCMP_ord 91
+CHECK-NEXT: FCMP_uno 92
+CHECK-NEXT: FCMP_ueq 93
+CHECK-NEXT: FCMP_ugt 94
+CHECK-NEXT: FCMP_uge 95
+CHECK-NEXT: FCMP_ult 96
+CHECK-NEXT: FCMP_ule 97
+CHECK-NEXT: FCMP_une 98
+CHECK-NEXT: FCMP_true 99
+CHECK-NEXT: ICMP_eq 100
+CHECK-NEXT: ICMP_ne 101
+CHECK-NEXT: ICMP_ugt 102
+CHECK-NEXT: ICMP_uge 103
+CHECK-NEXT: ICMP_ult 104
+CHECK-NEXT: ICMP_ule 105
+CHECK-NEXT: ICMP_sgt 106
+CHECK-NEXT: ICMP_sge 107
+CHECK-NEXT: ICMP_slt 108
+CHECK-NEXT: ICMP_sle 109
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index aabebf0..1c656b8 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -184,7 +184,7 @@ public:
// Add "Arg" relationships
unsigned ArgIndex = 0;
for (const Use &U : I.operands()) {
- unsigned OperandID = Vocabulary::getSlotIndex(*U);
+ unsigned OperandID = Vocabulary::getSlotIndex(*U.get());
unsigned RelationID = ArgRelation + ArgIndex;
OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 9f2f6a3..9bc48e4 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -435,6 +435,7 @@ static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;
static constexpr unsigned MaxTypeIDs = Vocabulary::MaxTypeIDs;
static constexpr unsigned MaxCanonicalTypeIDs = Vocabulary::MaxCanonicalTypeIDs;
static constexpr unsigned MaxOperands = Vocabulary::MaxOperandKinds;
+static constexpr unsigned MaxPredicateKinds = Vocabulary::MaxPredicateKinds;
// Mapping between LLVM Type::TypeID tokens and Vocabulary::CanonicalTypeID
// names and their canonical string keys.
@@ -460,7 +461,8 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
EXPECT_EQ(Emb.size(), Dim);
// Should have the correct total number of embeddings
- EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands);
+ EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands +
+ MaxPredicateKinds);
auto ExpectedVocab = VocabVec;
@@ -527,6 +529,26 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
EXPECT_EQ(Vocabulary::getSlotIndex(*Arg),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::VariableID));
#undef EXPECTED_VOCAB_OPERAND_SLOT
+
+ // Test getSlotIndex for predicates
+#define EXPECTED_VOCAB_PREDICATE_SLOT(X) \
+ MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands + static_cast<unsigned>(X)
+ for (unsigned P = CmpInst::FIRST_FCMP_PREDICATE;
+ P <= CmpInst::LAST_FCMP_PREDICATE; ++P) {
+ CmpInst::Predicate Pred = static_cast<CmpInst::Predicate>(P);
+ unsigned ExpectedIdx =
+ EXPECTED_VOCAB_PREDICATE_SLOT((P - CmpInst::FIRST_FCMP_PREDICATE));
+ EXPECT_EQ(Vocabulary::getSlotIndex(Pred), ExpectedIdx);
+ }
+ auto ICMP_Start = CmpInst::LAST_FCMP_PREDICATE + 1;
+ for (unsigned P = CmpInst::FIRST_ICMP_PREDICATE;
+ P <= CmpInst::LAST_ICMP_PREDICATE; ++P) {
+ CmpInst::Predicate Pred = static_cast<CmpInst::Predicate>(P);
+ unsigned ExpectedIdx = EXPECTED_VOCAB_PREDICATE_SLOT(
+ ICMP_Start + P - CmpInst::FIRST_ICMP_PREDICATE);
+ EXPECT_EQ(Vocabulary::getSlotIndex(Pred), ExpectedIdx);
+ }
+#undef EXPECTED_VOCAB_PREDICATE_SLOT
}
#if GTEST_HAS_DEATH_TEST
@@ -569,6 +591,7 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) {
#undef EXPECT_CANONICAL_TYPE_NAME
+ // Verify OperandKind -> string mapping
#define HANDLE_OPERAND_KINDS(X) \
X(FunctionID, "Function") \
X(PointerID, "Pointer") \
@@ -592,6 +615,28 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) {
Vocabulary::getStringKey(MaxOpcodes + MaxCanonicalTypeIDs + 1);
EXPECT_EQ(FuncArgKey, "Function");
EXPECT_EQ(PtrArgKey, "Pointer");
+
+// Verify PredicateKind -> string mapping
+#define EXPECT_PREDICATE_KIND(PredNum, PredPos, PredKind) \
+ do { \
+ std::string PredStr = \
+ std::string(PredKind) + "_" + \
+ CmpInst::getPredicateName(static_cast<CmpInst::Predicate>(PredNum)) \
+ .str(); \
+ unsigned Pos = MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands + PredPos; \
+ EXPECT_EQ(Vocabulary::getStringKey(Pos), PredStr); \
+ } while (0)
+
+ for (unsigned P = CmpInst::FIRST_FCMP_PREDICATE;
+ P <= CmpInst::LAST_FCMP_PREDICATE; ++P)
+ EXPECT_PREDICATE_KIND(P, P - CmpInst::FIRST_FCMP_PREDICATE, "FCMP");
+
+ auto ICMP_Pos = CmpInst::LAST_FCMP_PREDICATE + 1;
+ for (unsigned P = CmpInst::FIRST_ICMP_PREDICATE;
+ P <= CmpInst::LAST_ICMP_PREDICATE; ++P)
+ EXPECT_PREDICATE_KIND(P, ICMP_Pos++, "ICMP");
+
+#undef EXPECT_PREDICATE_KIND
}
TEST(IR2VecVocabularyTest, VocabularyDimensions) {