aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorModi Mo <mmo@nvidia.com>2026-04-22 21:20:32 -0700
committerModi Mo <mmo@nvidia.com>2026-04-22 21:20:55 -0700
commitf0d721ab83fdf9f0a93bb2b793ed300dc1471105 (patch)
tree12e4dc81107524871218b988d1fc23333a1ae9d3
parent23ef732a8c235777b56a8cfb83119acf955117e5 (diff)
downloadllvm-users/modiking/nvptx-setp-predicate-inversion.tar.gz
llvm-users/modiking/nvptx-setp-predicate-inversion.tar.bz2
llvm-users/modiking/nvptx-setp-predicate-inversion.zip
move cmp modes into td and update usersusers/modiking/nvptx-setp-predicate-inversion
-rw-r--r--llvm/lib/Target/NVPTX/CMakeLists.txt1
-rw-r--r--llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp120
-rw-r--r--llvm/lib/Target/NVPTX/NVPTX.h34
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp34
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp58
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXInstrInfo.td64
6 files changed, 113 insertions, 198 deletions
diff --git a/llvm/lib/Target/NVPTX/CMakeLists.txt b/llvm/lib/Target/NVPTX/CMakeLists.txt
index e0e2b1a40bff..6fd07fe2210f 100644
--- a/llvm/lib/Target/NVPTX/CMakeLists.txt
+++ b/llvm/lib/Target/NVPTX/CMakeLists.txt
@@ -8,6 +8,7 @@ tablegen(LLVM NVPTXGenInstrInfo.inc -gen-instr-info)
tablegen(LLVM NVPTXGenRegisterInfo.inc -gen-register-info)
tablegen(LLVM NVPTXGenSDNodeInfo.inc -gen-sd-node-info)
tablegen(LLVM NVPTXGenSubtargetInfo.inc -gen-subtarget)
+tablegen(LLVM NVPTXGenCmpModes.inc -gen-searchable-tables)
add_public_tablegen_target(NVPTXCommonTableGen)
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 9764d3a72619..8d59ea910ecd 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -33,6 +33,11 @@ using namespace llvm;
#include "NVPTXGenAsmWriter.inc"
+namespace llvm::NVPTX::PTXCmpMode {
+#define GET_PTXCmpModeTable_IMPL
+#include "NVPTXGenCmpModes.inc"
+} // namespace llvm::NVPTX::PTXCmpMode
+
static bool hasParamSubqualifiers(const MCSubtargetInfo &STI) {
return STI.hasFeature(NVPTX::PTX83);
}
@@ -183,108 +188,19 @@ void NVPTXInstPrinter::printNegatedPredicate(const MCInst *MI, int OpNum,
void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum,
const MCSubtargetInfo &, raw_ostream &O,
StringRef Modifier) {
- const MCOperand &MO = MI->getOperand(OpNum);
- int64_t Imm = MO.getImm();
-
- if (Modifier == "FCmp") {
- switch (Imm) {
- default:
- return;
- case NVPTX::PTXCmpMode::EQ:
- O << "eq";
- return;
- case NVPTX::PTXCmpMode::NE:
- O << "ne";
- return;
- case NVPTX::PTXCmpMode::LT:
- O << "lt";
- return;
- case NVPTX::PTXCmpMode::LE:
- O << "le";
- return;
- case NVPTX::PTXCmpMode::GT:
- O << "gt";
- return;
- case NVPTX::PTXCmpMode::GE:
- O << "ge";
- return;
- case NVPTX::PTXCmpMode::EQU:
- O << "equ";
- return;
- case NVPTX::PTXCmpMode::NEU:
- O << "neu";
- return;
- case NVPTX::PTXCmpMode::LTU:
- O << "ltu";
- return;
- case NVPTX::PTXCmpMode::LEU:
- O << "leu";
- return;
- case NVPTX::PTXCmpMode::GTU:
- O << "gtu";
- return;
- case NVPTX::PTXCmpMode::GEU:
- O << "geu";
- return;
- case NVPTX::PTXCmpMode::NUM:
- O << "num";
- return;
- case NVPTX::PTXCmpMode::NotANumber:
- O << "nan";
- return;
- }
- }
- if (Modifier == "ICmp") {
- switch (Imm) {
- default:
- llvm_unreachable("Invalid ICmp mode");
- case NVPTX::PTXCmpMode::EQ:
- O << "eq";
- return;
- case NVPTX::PTXCmpMode::NE:
- O << "ne";
- return;
- case NVPTX::PTXCmpMode::LT:
- case NVPTX::PTXCmpMode::LTU:
- O << "lt";
- return;
- case NVPTX::PTXCmpMode::LE:
- case NVPTX::PTXCmpMode::LEU:
- O << "le";
- return;
- case NVPTX::PTXCmpMode::GT:
- case NVPTX::PTXCmpMode::GTU:
- O << "gt";
- return;
- case NVPTX::PTXCmpMode::GE:
- case NVPTX::PTXCmpMode::GEU:
- O << "ge";
- return;
- }
- }
- if (Modifier == "IType") {
- switch (Imm) {
- default:
- llvm_unreachable("Invalid IType");
- case NVPTX::PTXCmpMode::EQ:
- case NVPTX::PTXCmpMode::NE:
- O << "b";
- return;
- case NVPTX::PTXCmpMode::LT:
- case NVPTX::PTXCmpMode::LE:
- case NVPTX::PTXCmpMode::GT:
- case NVPTX::PTXCmpMode::GE:
- O << "s";
- return;
- case NVPTX::PTXCmpMode::LTU:
- case NVPTX::PTXCmpMode::LEU:
- case NVPTX::PTXCmpMode::GTU:
- case NVPTX::PTXCmpMode::GEU:
- O << "u";
- return;
- }
- }
- llvm_unreachable("Empty Modifier");
+ const NVPTX::CmpModeInfo *Info =
+ NVPTX::PTXCmpMode::lookupCmpModeByValue(MI->getOperand(OpNum).getImm());
+ StringRef Str;
+ if (Modifier == "FCmp")
+ Str = Info->FCmpPrintStr;
+ else if (Modifier == "ICmp")
+ Str = Info->ICmpPrintStr;
+ else if (Modifier == "IType")
+ Str = Info->ITypePrintStr;
+ else
+ llvm_unreachable("Empty Modifier");
+ assert(!Str.empty() && "Invalid comparison mode for this modifier");
+ O << Str;
}
void NVPTXInstPrinter::printAtomicCode(const MCInst *MI, int OpNum,
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 09a94034894e..50d6508e5f9b 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -240,26 +240,22 @@ enum CvtMode {
};
}
-/// PTXCmpMode - Comparison mode enumeration
-namespace PTXCmpMode {
-enum CmpMode {
- EQ = 0,
- NE,
- LT,
- LE,
- GT,
- GE,
- EQU,
- NEU,
- LTU,
- LEU,
- GTU,
- GEU,
- NUM,
- // NAN is a MACRO
- NotANumber,
+// Field order must match NVPTXInstrInfo.td
+struct CmpModeInfo {
+ uint8_t Value;
+ StringRef Name;
+ uint8_t IntInverseValue;
+ uint8_t FPInverseValue;
+ StringRef FCmpPrintStr;
+ StringRef ICmpPrintStr;
+ StringRef ITypePrintStr;
};
-}
+
+namespace PTXCmpMode {
+#define GET_PTXCmpMode_DECL
+#define GET_PTXCmpModeTable_DECL
+#include "NVPTXGenCmpModes.inc"
+} // namespace PTXCmpMode
namespace PTXPrmtMode {
enum PrmtMode {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index ede1deb5400b..3904c74c080e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -366,48 +366,48 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
// Map ISD:CONDCODE value to appropriate CmpMode expected by
// NVPTXInstPrinter::printCmpMode()
SDValue NVPTXDAGToDAGISel::getPTXCmpMode(const CondCodeSDNode &CondCode) {
- using NVPTX::PTXCmpMode::CmpMode;
- const unsigned PTXCmpMode = [](ISD::CondCode CC) {
+ using namespace NVPTX::PTXCmpMode;
+ const unsigned Mode = [](ISD::CondCode CC) {
switch (CC) {
default:
llvm_unreachable("Unexpected condition code.");
case ISD::SETOEQ:
case ISD::SETEQ:
- return CmpMode::EQ;
+ return EQ;
case ISD::SETOGT:
case ISD::SETGT:
- return CmpMode::GT;
+ return GT;
case ISD::SETOGE:
case ISD::SETGE:
- return CmpMode::GE;
+ return GE;
case ISD::SETOLT:
case ISD::SETLT:
- return CmpMode::LT;
+ return LT;
case ISD::SETOLE:
case ISD::SETLE:
- return CmpMode::LE;
+ return LE;
case ISD::SETONE:
case ISD::SETNE:
- return CmpMode::NE;
+ return NE;
case ISD::SETO:
- return CmpMode::NUM;
+ return NUM;
case ISD::SETUO:
- return CmpMode::NotANumber;
+ return NotANumber;
case ISD::SETUEQ:
- return CmpMode::EQU;
+ return EQU;
case ISD::SETUGT:
- return CmpMode::GTU;
+ return GTU;
case ISD::SETUGE:
- return CmpMode::GEU;
+ return GEU;
case ISD::SETULT:
- return CmpMode::LTU;
+ return LTU;
case ISD::SETULE:
- return CmpMode::LEU;
+ return LEU;
case ISD::SETUNE:
- return CmpMode::NEU;
+ return NEU;
}
}(CondCode.get());
- return CurDAG->getTargetConstant(PTXCmpMode, SDLoc(), MVT::i32);
+ return CurDAG->getTargetConstant(Mode, SDLoc(), MVT::i32);
}
bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
index d9900e98fe69..981cef95e1da 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -255,65 +255,11 @@ static bool isScalarFloatSetp(const MachineInstr &MI) {
}
static int64_t invertIntegerCmpMode(int64_t Mode) {
- switch (Mode) {
- case NVPTX::PTXCmpMode::EQ:
- return NVPTX::PTXCmpMode::NE;
- case NVPTX::PTXCmpMode::NE:
- return NVPTX::PTXCmpMode::EQ;
- case NVPTX::PTXCmpMode::LT:
- return NVPTX::PTXCmpMode::GE;
- case NVPTX::PTXCmpMode::LE:
- return NVPTX::PTXCmpMode::GT;
- case NVPTX::PTXCmpMode::GT:
- return NVPTX::PTXCmpMode::LE;
- case NVPTX::PTXCmpMode::GE:
- return NVPTX::PTXCmpMode::LT;
- case NVPTX::PTXCmpMode::LTU:
- return NVPTX::PTXCmpMode::GEU;
- case NVPTX::PTXCmpMode::LEU:
- return NVPTX::PTXCmpMode::GTU;
- case NVPTX::PTXCmpMode::GTU:
- return NVPTX::PTXCmpMode::LEU;
- case NVPTX::PTXCmpMode::GEU:
- return NVPTX::PTXCmpMode::LTU;
- default:
- llvm_unreachable("Invalid integer comparison mode");
- }
+ return NVPTX::PTXCmpMode::lookupCmpModeByValue(Mode)->IntInverseValue;
}
static int64_t invertScalarFloatCmpMode(int64_t Mode) {
- switch (Mode) {
- case NVPTX::PTXCmpMode::EQ:
- return NVPTX::PTXCmpMode::NEU;
- case NVPTX::PTXCmpMode::NE:
- return NVPTX::PTXCmpMode::EQU;
- case NVPTX::PTXCmpMode::EQU:
- return NVPTX::PTXCmpMode::NE;
- case NVPTX::PTXCmpMode::NEU:
- return NVPTX::PTXCmpMode::EQ;
- case NVPTX::PTXCmpMode::LT:
- return NVPTX::PTXCmpMode::GEU;
- case NVPTX::PTXCmpMode::LE:
- return NVPTX::PTXCmpMode::GTU;
- case NVPTX::PTXCmpMode::GT:
- return NVPTX::PTXCmpMode::LEU;
- case NVPTX::PTXCmpMode::GE:
- return NVPTX::PTXCmpMode::LTU;
- case NVPTX::PTXCmpMode::LTU:
- return NVPTX::PTXCmpMode::GE;
- case NVPTX::PTXCmpMode::LEU:
- return NVPTX::PTXCmpMode::GT;
- case NVPTX::PTXCmpMode::GTU:
- return NVPTX::PTXCmpMode::LE;
- case NVPTX::PTXCmpMode::GEU:
- return NVPTX::PTXCmpMode::LT;
- case NVPTX::PTXCmpMode::NUM:
- return NVPTX::PTXCmpMode::NotANumber;
- case NVPTX::PTXCmpMode::NotANumber:
- return NVPTX::PTXCmpMode::NUM;
- default:
- llvm_unreachable("Invalid scalar float comparison mode");
- }
+ return NVPTX::PTXCmpMode::lookupCmpModeByValue(Mode)->FPInverseValue;
}
static void invertScalarCompareInstr(MachineInstr &MI) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 5ca7941a77a7..29983d00c9fd 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
include "NVPTXInstrFormats.td"
+include "llvm/TableGen/SearchableTable.td"
let OperandType = "OPERAND_IMMEDIATE" in {
def f16imm : Operand<f16>;
@@ -77,10 +78,65 @@ def BranchFlag : OperandWithDefaultOps<i32, (ops (i1 0))> {
let PrintMethod = "printNegatedPredicate";
}
-// Compare modes
-// These must match the enum in NVPTX.h
-def CmpEQ : PatLeaf<(i32 0)>;
-def CmpNE : PatLeaf<(i32 1)>;
+// SETP compare modes
+
+// Sentinel meaning "no valid inverse for this mode".
+defvar CmpModeInvalidInt = 0xFF;
+
+class CmpModeRec<int val, string name,
+ int intInv, int fpInv,
+ string fcmp, string icmp, string ityp> {
+ // Numeric encoding carried by the setp machine operand.
+ bits<8> Value = val;
+ // Enumerator spelling in NVPTX::PTXCmpMode.
+ string Name = name;
+ // Inverse predicate under integer-setp semantics
+ bits<8> IntInverseValue = intInv;
+ // Inverse predicate under scalar-float-setp semantics
+ bits<8> FPInverseValue = fpInv;
+ // Print mnemonic for float compares.
+ string FCmpPrintStr = fcmp;
+ // Print mnemonic for int compares.
+ string ICmpPrintStr = icmp;
+ // Signedness suffix for the integer setp instruction name (b/s/u);
+ string ITypePrintStr = ityp;
+}
+
+// val name intInv fpInv fcmp icmp ityp
+def CmpModeEQ_Rec : CmpModeRec<0, "EQ", 1, 7, "eq", "eq", "b">;
+def CmpModeNE_Rec : CmpModeRec<1, "NE", 0, 6, "ne", "ne", "b">;
+def CmpModeLT_Rec : CmpModeRec<2, "LT", 5, 11, "lt", "lt", "s">;
+def CmpModeLE_Rec : CmpModeRec<3, "LE", 4, 10, "le", "le", "s">;
+def CmpModeGT_Rec : CmpModeRec<4, "GT", 3, 9, "gt", "gt", "s">;
+def CmpModeGE_Rec : CmpModeRec<5, "GE", 2, 8, "ge", "ge", "s">;
+def CmpModeEQU_Rec : CmpModeRec<6, "EQU", CmpModeInvalidInt, 1, "equ", "", "">;
+def CmpModeNEU_Rec : CmpModeRec<7, "NEU", CmpModeInvalidInt, 0, "neu", "", "">;
+def CmpModeLTU_Rec : CmpModeRec<8, "LTU", 11, 5, "ltu", "lt", "u">;
+def CmpModeLEU_Rec : CmpModeRec<9, "LEU", 10, 4, "leu", "le", "u">;
+def CmpModeGTU_Rec : CmpModeRec<10, "GTU", 9, 3, "gtu", "gt", "u">;
+def CmpModeGEU_Rec : CmpModeRec<11, "GEU", 8, 2, "geu", "ge", "u">;
+def CmpModeNUM_Rec : CmpModeRec<12, "NUM", CmpModeInvalidInt, 13, "num", "", "">;
+def CmpModeNotANumber_Rec : CmpModeRec<13, "NotANumber", CmpModeInvalidInt, 12, "nan", "", "">;
+
+def PTXCmpMode : GenericEnum {
+ let FilterClass = "CmpModeRec";
+ let NameField = "Name";
+ let ValueField = "Value";
+ let UnderlyingType = "unsigned";
+}
+
+def PTXCmpModeTable : GenericTable {
+ let FilterClass = "CmpModeRec";
+ let CppTypeName = "CmpModeInfo";
+ let Fields = ["Value", "Name",
+ "IntInverseValue", "FPInverseValue",
+ "FCmpPrintStr", "ICmpPrintStr", "ITypePrintStr"];
+ let PrimaryKey = ["Value"];
+ let PrimaryKeyName = "lookupCmpModeByValue";
+}
+
+def CmpEQ : PatLeaf<(i32 !cast<int>(CmpModeEQ_Rec.Value))>;
+def CmpNE : PatLeaf<(i32 !cast<int>(CmpModeNE_Rec.Value))>;
def CmpMode : Operand<i32> {
let PrintMethod = "printCmpMode";