diff options
| author | Modi Mo <mmo@nvidia.com> | 2026-04-22 21:20:32 -0700 |
|---|---|---|
| committer | Modi Mo <mmo@nvidia.com> | 2026-04-22 21:20:55 -0700 |
| commit | f0d721ab83fdf9f0a93bb2b793ed300dc1471105 (patch) | |
| tree | 12e4dc81107524871218b988d1fc23333a1ae9d3 | |
| parent | 23ef732a8c235777b56a8cfb83119acf955117e5 (diff) | |
| download | llvm-users/modiking/nvptx-setp-predicate-inversion.tar.gz llvm-users/modiking/nvptx-setp-predicate-inversion.tar.bz2 llvm-users/modiking/nvptx-setp-predicate-inversion.zip | |
move cmp modes into td and update usersusers/modiking/nvptx-setp-predicate-inversion
| -rw-r--r-- | llvm/lib/Target/NVPTX/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp | 120 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTX.h | 34 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 34 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp | 58 | ||||
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 64 |
6 files changed, 113 insertions, 198 deletions
diff --git a/llvm/lib/Target/NVPTX/CMakeLists.txt b/llvm/lib/Target/NVPTX/CMakeLists.txt index e0e2b1a40bff..6fd07fe2210f 100644 --- a/llvm/lib/Target/NVPTX/CMakeLists.txt +++ b/llvm/lib/Target/NVPTX/CMakeLists.txt @@ -8,6 +8,7 @@ tablegen(LLVM NVPTXGenInstrInfo.inc -gen-instr-info) tablegen(LLVM NVPTXGenRegisterInfo.inc -gen-register-info) tablegen(LLVM NVPTXGenSDNodeInfo.inc -gen-sd-node-info) tablegen(LLVM NVPTXGenSubtargetInfo.inc -gen-subtarget) +tablegen(LLVM NVPTXGenCmpModes.inc -gen-searchable-tables) add_public_tablegen_target(NVPTXCommonTableGen) diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp index 9764d3a72619..8d59ea910ecd 100644 --- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp +++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp @@ -33,6 +33,11 @@ using namespace llvm; #include "NVPTXGenAsmWriter.inc" +namespace llvm::NVPTX::PTXCmpMode { +#define GET_PTXCmpModeTable_IMPL +#include "NVPTXGenCmpModes.inc" +} // namespace llvm::NVPTX::PTXCmpMode + static bool hasParamSubqualifiers(const MCSubtargetInfo &STI) { return STI.hasFeature(NVPTX::PTX83); } @@ -183,108 +188,19 @@ void NVPTXInstPrinter::printNegatedPredicate(const MCInst *MI, int OpNum, void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, const MCSubtargetInfo &, raw_ostream &O, StringRef Modifier) { - const MCOperand &MO = MI->getOperand(OpNum); - int64_t Imm = MO.getImm(); - - if (Modifier == "FCmp") { - switch (Imm) { - default: - return; - case NVPTX::PTXCmpMode::EQ: - O << "eq"; - return; - case NVPTX::PTXCmpMode::NE: - O << "ne"; - return; - case NVPTX::PTXCmpMode::LT: - O << "lt"; - return; - case NVPTX::PTXCmpMode::LE: - O << "le"; - return; - case NVPTX::PTXCmpMode::GT: - O << "gt"; - return; - case NVPTX::PTXCmpMode::GE: - O << "ge"; - return; - case NVPTX::PTXCmpMode::EQU: - O << "equ"; - return; - case NVPTX::PTXCmpMode::NEU: - O << "neu"; - return; - case NVPTX::PTXCmpMode::LTU: - O << "ltu"; - return; - case NVPTX::PTXCmpMode::LEU: - O << "leu"; - return; - case NVPTX::PTXCmpMode::GTU: - O << "gtu"; - return; - case NVPTX::PTXCmpMode::GEU: - O << "geu"; - return; - case NVPTX::PTXCmpMode::NUM: - O << "num"; - return; - case NVPTX::PTXCmpMode::NotANumber: - O << "nan"; - return; - } - } - if (Modifier == "ICmp") { - switch (Imm) { - default: - llvm_unreachable("Invalid ICmp mode"); - case NVPTX::PTXCmpMode::EQ: - O << "eq"; - return; - case NVPTX::PTXCmpMode::NE: - O << "ne"; - return; - case NVPTX::PTXCmpMode::LT: - case NVPTX::PTXCmpMode::LTU: - O << "lt"; - return; - case NVPTX::PTXCmpMode::LE: - case NVPTX::PTXCmpMode::LEU: - O << "le"; - return; - case NVPTX::PTXCmpMode::GT: - case NVPTX::PTXCmpMode::GTU: - O << "gt"; - return; - case NVPTX::PTXCmpMode::GE: - case NVPTX::PTXCmpMode::GEU: - O << "ge"; - return; - } - } - if (Modifier == "IType") { - switch (Imm) { - default: - llvm_unreachable("Invalid IType"); - case NVPTX::PTXCmpMode::EQ: - case NVPTX::PTXCmpMode::NE: - O << "b"; - return; - case NVPTX::PTXCmpMode::LT: - case NVPTX::PTXCmpMode::LE: - case NVPTX::PTXCmpMode::GT: - case NVPTX::PTXCmpMode::GE: - O << "s"; - return; - case NVPTX::PTXCmpMode::LTU: - case NVPTX::PTXCmpMode::LEU: - case NVPTX::PTXCmpMode::GTU: - case NVPTX::PTXCmpMode::GEU: - O << "u"; - return; - } - } - llvm_unreachable("Empty Modifier"); + const NVPTX::CmpModeInfo *Info = + NVPTX::PTXCmpMode::lookupCmpModeByValue(MI->getOperand(OpNum).getImm()); + StringRef Str; + if (Modifier == "FCmp") + Str = Info->FCmpPrintStr; + else if (Modifier == "ICmp") + Str = Info->ICmpPrintStr; + else if (Modifier == "IType") + Str = Info->ITypePrintStr; + else + llvm_unreachable("Empty Modifier"); + assert(!Str.empty() && "Invalid comparison mode for this modifier"); + O << Str; } void NVPTXInstPrinter::printAtomicCode(const MCInst *MI, int OpNum, diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h index 09a94034894e..50d6508e5f9b 100644 --- a/llvm/lib/Target/NVPTX/NVPTX.h +++ b/llvm/lib/Target/NVPTX/NVPTX.h @@ -240,26 +240,22 @@ enum CvtMode { }; } -/// PTXCmpMode - Comparison mode enumeration -namespace PTXCmpMode { -enum CmpMode { - EQ = 0, - NE, - LT, - LE, - GT, - GE, - EQU, - NEU, - LTU, - LEU, - GTU, - GEU, - NUM, - // NAN is a MACRO - NotANumber, +// Field order must match NVPTXInstrInfo.td +struct CmpModeInfo { + uint8_t Value; + StringRef Name; + uint8_t IntInverseValue; + uint8_t FPInverseValue; + StringRef FCmpPrintStr; + StringRef ICmpPrintStr; + StringRef ITypePrintStr; }; -} + +namespace PTXCmpMode { +#define GET_PTXCmpMode_DECL +#define GET_PTXCmpModeTable_DECL +#include "NVPTXGenCmpModes.inc" +} // namespace PTXCmpMode namespace PTXPrmtMode { enum PrmtMode { diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index ede1deb5400b..3904c74c080e 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -366,48 +366,48 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) { // Map ISD:CONDCODE value to appropriate CmpMode expected by // NVPTXInstPrinter::printCmpMode() SDValue NVPTXDAGToDAGISel::getPTXCmpMode(const CondCodeSDNode &CondCode) { - using NVPTX::PTXCmpMode::CmpMode; - const unsigned PTXCmpMode = [](ISD::CondCode CC) { + using namespace NVPTX::PTXCmpMode; + const unsigned Mode = [](ISD::CondCode CC) { switch (CC) { default: llvm_unreachable("Unexpected condition code."); case ISD::SETOEQ: case ISD::SETEQ: - return CmpMode::EQ; + return EQ; case ISD::SETOGT: case ISD::SETGT: - return CmpMode::GT; + return GT; case ISD::SETOGE: case ISD::SETGE: - return CmpMode::GE; + return GE; case ISD::SETOLT: case ISD::SETLT: - return CmpMode::LT; + return LT; case ISD::SETOLE: case ISD::SETLE: - return CmpMode::LE; + return LE; case ISD::SETONE: case ISD::SETNE: - return CmpMode::NE; + return NE; case ISD::SETO: - return CmpMode::NUM; + return NUM; case ISD::SETUO: - return CmpMode::NotANumber; + return NotANumber; case ISD::SETUEQ: - return CmpMode::EQU; + return EQU; case ISD::SETUGT: - return CmpMode::GTU; + return GTU; case ISD::SETUGE: - return CmpMode::GEU; + return GEU; case ISD::SETULT: - return CmpMode::LTU; + return LTU; case ISD::SETULE: - return CmpMode::LEU; + return LEU; case ISD::SETUNE: - return CmpMode::NEU; + return NEU; } }(CondCode.get()); - return CurDAG->getTargetConstant(PTXCmpMode, SDLoc(), MVT::i32); + return CurDAG->getTargetConstant(Mode, SDLoc(), MVT::i32); } bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) { diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp index d9900e98fe69..981cef95e1da 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp @@ -255,65 +255,11 @@ static bool isScalarFloatSetp(const MachineInstr &MI) { } static int64_t invertIntegerCmpMode(int64_t Mode) { - switch (Mode) { - case NVPTX::PTXCmpMode::EQ: - return NVPTX::PTXCmpMode::NE; - case NVPTX::PTXCmpMode::NE: - return NVPTX::PTXCmpMode::EQ; - case NVPTX::PTXCmpMode::LT: - return NVPTX::PTXCmpMode::GE; - case NVPTX::PTXCmpMode::LE: - return NVPTX::PTXCmpMode::GT; - case NVPTX::PTXCmpMode::GT: - return NVPTX::PTXCmpMode::LE; - case NVPTX::PTXCmpMode::GE: - return NVPTX::PTXCmpMode::LT; - case NVPTX::PTXCmpMode::LTU: - return NVPTX::PTXCmpMode::GEU; - case NVPTX::PTXCmpMode::LEU: - return NVPTX::PTXCmpMode::GTU; - case NVPTX::PTXCmpMode::GTU: - return NVPTX::PTXCmpMode::LEU; - case NVPTX::PTXCmpMode::GEU: - return NVPTX::PTXCmpMode::LTU; - default: - llvm_unreachable("Invalid integer comparison mode"); - } + return NVPTX::PTXCmpMode::lookupCmpModeByValue(Mode)->IntInverseValue; } static int64_t invertScalarFloatCmpMode(int64_t Mode) { - switch (Mode) { - case NVPTX::PTXCmpMode::EQ: - return NVPTX::PTXCmpMode::NEU; - case NVPTX::PTXCmpMode::NE: - return NVPTX::PTXCmpMode::EQU; - case NVPTX::PTXCmpMode::EQU: - return NVPTX::PTXCmpMode::NE; - case NVPTX::PTXCmpMode::NEU: - return NVPTX::PTXCmpMode::EQ; - case NVPTX::PTXCmpMode::LT: - return NVPTX::PTXCmpMode::GEU; - case NVPTX::PTXCmpMode::LE: - return NVPTX::PTXCmpMode::GTU; - case NVPTX::PTXCmpMode::GT: - return NVPTX::PTXCmpMode::LEU; - case NVPTX::PTXCmpMode::GE: - return NVPTX::PTXCmpMode::LTU; - case NVPTX::PTXCmpMode::LTU: - return NVPTX::PTXCmpMode::GE; - case NVPTX::PTXCmpMode::LEU: - return NVPTX::PTXCmpMode::GT; - case NVPTX::PTXCmpMode::GTU: - return NVPTX::PTXCmpMode::LE; - case NVPTX::PTXCmpMode::GEU: - return NVPTX::PTXCmpMode::LT; - case NVPTX::PTXCmpMode::NUM: - return NVPTX::PTXCmpMode::NotANumber; - case NVPTX::PTXCmpMode::NotANumber: - return NVPTX::PTXCmpMode::NUM; - default: - llvm_unreachable("Invalid scalar float comparison mode"); - } + return NVPTX::PTXCmpMode::lookupCmpModeByValue(Mode)->FPInverseValue; } static void invertScalarCompareInstr(MachineInstr &MI) { diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 5ca7941a77a7..29983d00c9fd 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// include "NVPTXInstrFormats.td" +include "llvm/TableGen/SearchableTable.td" let OperandType = "OPERAND_IMMEDIATE" in { def f16imm : Operand<f16>; @@ -77,10 +78,65 @@ def BranchFlag : OperandWithDefaultOps<i32, (ops (i1 0))> { let PrintMethod = "printNegatedPredicate"; } -// Compare modes -// These must match the enum in NVPTX.h -def CmpEQ : PatLeaf<(i32 0)>; -def CmpNE : PatLeaf<(i32 1)>; +// SETP compare modes + +// Sentinel meaning "no valid inverse for this mode". +defvar CmpModeInvalidInt = 0xFF; + +class CmpModeRec<int val, string name, + int intInv, int fpInv, + string fcmp, string icmp, string ityp> { + // Numeric encoding carried by the setp machine operand. + bits<8> Value = val; + // Enumerator spelling in NVPTX::PTXCmpMode. + string Name = name; + // Inverse predicate under integer-setp semantics + bits<8> IntInverseValue = intInv; + // Inverse predicate under scalar-float-setp semantics + bits<8> FPInverseValue = fpInv; + // Print mnemonic for float compares. + string FCmpPrintStr = fcmp; + // Print mnemonic for int compares. + string ICmpPrintStr = icmp; + // Signedness suffix for the integer setp instruction name (b/s/u); + string ITypePrintStr = ityp; +} + +// val name intInv fpInv fcmp icmp ityp +def CmpModeEQ_Rec : CmpModeRec<0, "EQ", 1, 7, "eq", "eq", "b">; +def CmpModeNE_Rec : CmpModeRec<1, "NE", 0, 6, "ne", "ne", "b">; +def CmpModeLT_Rec : CmpModeRec<2, "LT", 5, 11, "lt", "lt", "s">; +def CmpModeLE_Rec : CmpModeRec<3, "LE", 4, 10, "le", "le", "s">; +def CmpModeGT_Rec : CmpModeRec<4, "GT", 3, 9, "gt", "gt", "s">; +def CmpModeGE_Rec : CmpModeRec<5, "GE", 2, 8, "ge", "ge", "s">; +def CmpModeEQU_Rec : CmpModeRec<6, "EQU", CmpModeInvalidInt, 1, "equ", "", "">; +def CmpModeNEU_Rec : CmpModeRec<7, "NEU", CmpModeInvalidInt, 0, "neu", "", "">; +def CmpModeLTU_Rec : CmpModeRec<8, "LTU", 11, 5, "ltu", "lt", "u">; +def CmpModeLEU_Rec : CmpModeRec<9, "LEU", 10, 4, "leu", "le", "u">; +def CmpModeGTU_Rec : CmpModeRec<10, "GTU", 9, 3, "gtu", "gt", "u">; +def CmpModeGEU_Rec : CmpModeRec<11, "GEU", 8, 2, "geu", "ge", "u">; +def CmpModeNUM_Rec : CmpModeRec<12, "NUM", CmpModeInvalidInt, 13, "num", "", "">; +def CmpModeNotANumber_Rec : CmpModeRec<13, "NotANumber", CmpModeInvalidInt, 12, "nan", "", "">; + +def PTXCmpMode : GenericEnum { + let FilterClass = "CmpModeRec"; + let NameField = "Name"; + let ValueField = "Value"; + let UnderlyingType = "unsigned"; +} + +def PTXCmpModeTable : GenericTable { + let FilterClass = "CmpModeRec"; + let CppTypeName = "CmpModeInfo"; + let Fields = ["Value", "Name", + "IntInverseValue", "FPInverseValue", + "FCmpPrintStr", "ICmpPrintStr", "ITypePrintStr"]; + let PrimaryKey = ["Value"]; + let PrimaryKeyName = "lookupCmpModeByValue"; +} + +def CmpEQ : PatLeaf<(i32 !cast<int>(CmpModeEQ_Rec.Value))>; +def CmpNE : PatLeaf<(i32 !cast<int>(CmpModeNE_Rec.Value))>; def CmpMode : Operand<i32> { let PrintMethod = "printCmpMode"; |
