Diffstat (limited to 'llvm/lib/Target/NVPTX')
-rw-r--r--  llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp |  10
-rw-r--r--  llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h   |   1
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp            |   5
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp             | 362
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h               |   5
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp             | 682
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXISelLowering.h               |  10
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXInstrFormats.td              |  10
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp                |  26
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXInstrInfo.td                 | 827
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXIntrinsics.td                | 624
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td              |   8
12 files changed, 877 insertions(+), 1693 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 8eec915..ee1ca45 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -391,16 +391,6 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
}
}
-void NVPTXInstPrinter::printOffseti32imm(const MCInst *MI, int OpNum,
- raw_ostream &O) {
- auto &Op = MI->getOperand(OpNum);
- assert(Op.isImm() && "Invalid operand");
- if (Op.getImm() != 0) {
- O << "+";
- printOperand(MI, OpNum, O);
- }
-}
-
void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
raw_ostream &O) {
int64_t Imm = MI->getOperand(OpNum).getImm();
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index c3ff346..92155b0 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -46,7 +46,6 @@ public:
StringRef Modifier = {});
void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier = {});
- void printOffseti32imm(const MCInst *MI, int OpNum, raw_ostream &O);
void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O);
void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O);
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O);
diff --git a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
index cd40481..a349609 100644
--- a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
@@ -56,15 +56,12 @@ static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI,
case NVPTX::LD_i16:
case NVPTX::LD_i32:
case NVPTX::LD_i64:
- case NVPTX::LD_i8:
case NVPTX::LDV_i16_v2:
case NVPTX::LDV_i16_v4:
case NVPTX::LDV_i32_v2:
case NVPTX::LDV_i32_v4:
case NVPTX::LDV_i64_v2:
- case NVPTX::LDV_i64_v4:
- case NVPTX::LDV_i8_v2:
- case NVPTX::LDV_i8_v4: {
+ case NVPTX::LDV_i64_v4: {
LoadInsts.push_back(&U);
return true;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 65e7c56..6068035 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -56,9 +56,7 @@ INITIALIZE_PASS(NVPTXDAGToDAGISelLegacy, DEBUG_TYPE, PASS_NAME, false, false)
NVPTXDAGToDAGISel::NVPTXDAGToDAGISel(NVPTXTargetMachine &tm,
CodeGenOptLevel OptLevel)
- : SelectionDAGISel(tm, OptLevel), TM(tm) {
- doMulWide = (OptLevel > CodeGenOptLevel::None);
-}
+ : SelectionDAGISel(tm, OptLevel), TM(tm) {}
bool NVPTXDAGToDAGISel::runOnMachineFunction(MachineFunction &MF) {
Subtarget = &MF.getSubtarget<NVPTXSubtarget>();
@@ -145,18 +143,6 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
if (tryStoreVector(N))
return;
break;
- case NVPTXISD::LoadParam:
- case NVPTXISD::LoadParamV2:
- case NVPTXISD::LoadParamV4:
- if (tryLoadParam(N))
- return;
- break;
- case NVPTXISD::StoreParam:
- case NVPTXISD::StoreParamV2:
- case NVPTXISD::StoreParamV4:
- if (tryStoreParam(N))
- return;
- break;
case ISD::INTRINSIC_W_CHAIN:
if (tryIntrinsicChain(N))
return;
@@ -1017,14 +1003,10 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
// Helper function template to reduce amount of boilerplate code for
// opcode selection.
static std::optional<unsigned>
-pickOpcodeForVT(MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i8,
- std::optional<unsigned> Opcode_i16,
+pickOpcodeForVT(MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i16,
std::optional<unsigned> Opcode_i32,
std::optional<unsigned> Opcode_i64) {
switch (VT) {
- case MVT::i1:
- case MVT::i8:
- return Opcode_i8;
case MVT::f16:
case MVT::i16:
case MVT::bf16:
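
A sketch for orientation, not part of the patch: with the i8 variant dropped, callers pass only the 16/32/64-bit opcodes, on the assumption that sub-16-bit values reach selection in 16-bit registers and the real memory width travels as a separate immediate operand.

    // Hypothetical call site; the opcode names are real, the scenario assumed:
    const std::optional<unsigned> Opc =
        pickOpcodeForVT(MVT::i32, NVPTX::LD_i16, NVPTX::LD_i32, NVPTX::LD_i64);
    // Opc == NVPTX::LD_i32. An i8 access would be selected through the i16
    // opcode with an 8-bit width immediate rather than a dedicated LD_i8.
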
@@ -1092,8 +1074,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
Chain};
const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
- const std::optional<unsigned> Opcode = pickOpcodeForVT(
- TargetVT, NVPTX::LD_i8, NVPTX::LD_i16, NVPTX::LD_i32, NVPTX::LD_i64);
+ const std::optional<unsigned> Opcode =
+ pickOpcodeForVT(TargetVT, NVPTX::LD_i16, NVPTX::LD_i32, NVPTX::LD_i64);
if (!Opcode)
return false;
@@ -1178,17 +1160,15 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
default:
llvm_unreachable("Unexpected opcode");
case NVPTXISD::LoadV2:
- Opcode =
- pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i8_v2, NVPTX::LDV_i16_v2,
- NVPTX::LDV_i32_v2, NVPTX::LDV_i64_v2);
+ Opcode = pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i16_v2,
+ NVPTX::LDV_i32_v2, NVPTX::LDV_i64_v2);
break;
case NVPTXISD::LoadV4:
- Opcode =
- pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i8_v4, NVPTX::LDV_i16_v4,
- NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4);
+ Opcode = pickOpcodeForVT(EltVT.SimpleTy, NVPTX::LDV_i16_v4,
+ NVPTX::LDV_i32_v4, NVPTX::LDV_i64_v4);
break;
case NVPTXISD::LoadV8:
- Opcode = pickOpcodeForVT(EltVT.SimpleTy, {/* no v8i8 */}, {/* no v8i16 */},
+ Opcode = pickOpcodeForVT(EltVT.SimpleTy, {/* no v8i16 */},
NVPTX::LDV_i32_v8, {/* no v8i64 */});
break;
}
@@ -1244,22 +1224,21 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
default:
llvm_unreachable("Unexpected opcode");
case ISD::LOAD:
- Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i8,
- NVPTX::LD_GLOBAL_NC_i16, NVPTX::LD_GLOBAL_NC_i32,
- NVPTX::LD_GLOBAL_NC_i64);
+ Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i16,
+ NVPTX::LD_GLOBAL_NC_i32, NVPTX::LD_GLOBAL_NC_i64);
break;
case NVPTXISD::LoadV2:
- Opcode = pickOpcodeForVT(
- TargetVT, NVPTX::LD_GLOBAL_NC_v2i8, NVPTX::LD_GLOBAL_NC_v2i16,
- NVPTX::LD_GLOBAL_NC_v2i32, NVPTX::LD_GLOBAL_NC_v2i64);
+ Opcode =
+ pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v2i16,
+ NVPTX::LD_GLOBAL_NC_v2i32, NVPTX::LD_GLOBAL_NC_v2i64);
break;
case NVPTXISD::LoadV4:
- Opcode = pickOpcodeForVT(
- TargetVT, NVPTX::LD_GLOBAL_NC_v4i8, NVPTX::LD_GLOBAL_NC_v4i16,
- NVPTX::LD_GLOBAL_NC_v4i32, NVPTX::LD_GLOBAL_NC_v4i64);
+ Opcode =
+ pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v4i16,
+ NVPTX::LD_GLOBAL_NC_v4i32, NVPTX::LD_GLOBAL_NC_v4i64);
break;
case NVPTXISD::LoadV8:
- Opcode = pickOpcodeForVT(TargetVT, {/* no v8i8 */}, {/* no v8i16 */},
+ Opcode = pickOpcodeForVT(TargetVT, {/* no v8i16 */},
NVPTX::LD_GLOBAL_NC_v8i32, {/* no v8i64 */});
break;
}
@@ -1290,8 +1269,9 @@ bool NVPTXDAGToDAGISel::tryLDU(SDNode *N) {
break;
}
- const MVT::SimpleValueType SelectVT =
- MVT::getIntegerVT(LD->getMemoryVT().getSizeInBits() / NumElts).SimpleTy;
+ SDLoc DL(N);
+ const unsigned FromTypeWidth = LD->getMemoryVT().getSizeInBits() / NumElts;
+ const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
 // If this is an LDU intrinsic, the address is the third operand. If it's an
 // LDU SD node (from custom vector handling), then it's the second operand
@@ -1300,32 +1280,28 @@ bool NVPTXDAGToDAGISel::tryLDU(SDNode *N) {
SDValue Base, Offset;
SelectADDR(Addr, Base, Offset);
- SDValue Ops[] = {Base, Offset, LD->getChain()};
+ SDValue Ops[] = {getI32Imm(FromTypeWidth, DL), Base, Offset, LD->getChain()};
std::optional<unsigned> Opcode;
switch (N->getOpcode()) {
default:
llvm_unreachable("Unexpected opcode");
case ISD::INTRINSIC_W_CHAIN:
- Opcode =
- pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_i8, NVPTX::LDU_GLOBAL_i16,
- NVPTX::LDU_GLOBAL_i32, NVPTX::LDU_GLOBAL_i64);
+ Opcode = pickOpcodeForVT(TargetVT, NVPTX::LDU_GLOBAL_i16,
+ NVPTX::LDU_GLOBAL_i32, NVPTX::LDU_GLOBAL_i64);
break;
case NVPTXISD::LDUV2:
- Opcode = pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_v2i8,
- NVPTX::LDU_GLOBAL_v2i16, NVPTX::LDU_GLOBAL_v2i32,
- NVPTX::LDU_GLOBAL_v2i64);
+ Opcode = pickOpcodeForVT(TargetVT, NVPTX::LDU_GLOBAL_v2i16,
+ NVPTX::LDU_GLOBAL_v2i32, NVPTX::LDU_GLOBAL_v2i64);
break;
case NVPTXISD::LDUV4:
- Opcode = pickOpcodeForVT(SelectVT, NVPTX::LDU_GLOBAL_v4i8,
- NVPTX::LDU_GLOBAL_v4i16, NVPTX::LDU_GLOBAL_v4i32,
- {/* no v4i64 */});
+ Opcode = pickOpcodeForVT(TargetVT, NVPTX::LDU_GLOBAL_v4i16,
+ NVPTX::LDU_GLOBAL_v4i32, {/* no v4i64 */});
break;
}
if (!Opcode)
return false;
- SDLoc DL(N);
SDNode *NVPTXLDU = CurDAG->getMachineNode(*Opcode, DL, LD->getVTList(), Ops);
ReplaceNode(LD, NVPTXLDU);
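
A worked example of the new width operand (values assumed, not from the patch):

    // For an ldu.global.v2 of v2i16: MemoryVT is 32 bits over NumElts == 2,
    // so the selector computes
    //   FromTypeWidth = 32 / 2 = 16
    // and emits a 16-bit width immediate ahead of the address operands,
    // instead of re-deriving an integer element type from the size as before.
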
@@ -1376,8 +1352,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
Chain};
const std::optional<unsigned> Opcode =
- pickOpcodeForVT(Value.getSimpleValueType().SimpleTy, NVPTX::ST_i8,
- NVPTX::ST_i16, NVPTX::ST_i32, NVPTX::ST_i64);
+ pickOpcodeForVT(Value.getSimpleValueType().SimpleTy, NVPTX::ST_i16,
+ NVPTX::ST_i32, NVPTX::ST_i64);
if (!Opcode)
return false;
@@ -1437,16 +1413,16 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
default:
return false;
case NVPTXISD::StoreV2:
- Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i8_v2, NVPTX::STV_i16_v2,
- NVPTX::STV_i32_v2, NVPTX::STV_i64_v2);
+ Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i16_v2, NVPTX::STV_i32_v2,
+ NVPTX::STV_i64_v2);
break;
case NVPTXISD::StoreV4:
- Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i8_v4, NVPTX::STV_i16_v4,
- NVPTX::STV_i32_v4, NVPTX::STV_i64_v4);
+ Opcode = pickOpcodeForVT(EltVT, NVPTX::STV_i16_v4, NVPTX::STV_i32_v4,
+ NVPTX::STV_i64_v4);
break;
case NVPTXISD::StoreV8:
- Opcode = pickOpcodeForVT(EltVT, {/* no v8i8 */}, {/* no v8i16 */},
- NVPTX::STV_i32_v8, {/* no v8i64 */});
+ Opcode = pickOpcodeForVT(EltVT, {/* no v8i16 */}, NVPTX::STV_i32_v8,
+ {/* no v8i64 */});
break;
}
@@ -1462,267 +1438,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
return true;
}
-bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
- SDValue Chain = Node->getOperand(0);
- SDValue Offset = Node->getOperand(2);
- SDValue Glue = Node->getOperand(3);
- SDLoc DL(Node);
- MemSDNode *Mem = cast<MemSDNode>(Node);
-
- unsigned VecSize;
- switch (Node->getOpcode()) {
- default:
- return false;
- case NVPTXISD::LoadParam:
- VecSize = 1;
- break;
- case NVPTXISD::LoadParamV2:
- VecSize = 2;
- break;
- case NVPTXISD::LoadParamV4:
- VecSize = 4;
- break;
- }
-
- EVT EltVT = Node->getValueType(0);
- EVT MemVT = Mem->getMemoryVT();
-
- std::optional<unsigned> Opcode;
-
- switch (VecSize) {
- default:
- return false;
- case 1:
- Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
- NVPTX::LoadParamMemI8, NVPTX::LoadParamMemI16,
- NVPTX::LoadParamMemI32, NVPTX::LoadParamMemI64);
- break;
- case 2:
- Opcode =
- pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, NVPTX::LoadParamMemV2I8,
- NVPTX::LoadParamMemV2I16, NVPTX::LoadParamMemV2I32,
- NVPTX::LoadParamMemV2I64);
- break;
- case 4:
- Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
- NVPTX::LoadParamMemV4I8, NVPTX::LoadParamMemV4I16,
- NVPTX::LoadParamMemV4I32, {/* no v4i64 */});
- break;
- }
- if (!Opcode)
- return false;
-
- SDVTList VTs;
- if (VecSize == 1) {
- VTs = CurDAG->getVTList(EltVT, MVT::Other, MVT::Glue);
- } else if (VecSize == 2) {
- VTs = CurDAG->getVTList(EltVT, EltVT, MVT::Other, MVT::Glue);
- } else {
- EVT EVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other, MVT::Glue };
- VTs = CurDAG->getVTList(EVTs);
- }
-
- unsigned OffsetVal = Offset->getAsZExtVal();
-
- SmallVector<SDValue, 2> Ops(
- {CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});
-
- ReplaceNode(Node, CurDAG->getMachineNode(*Opcode, DL, VTs, Ops));
- return true;
-}
-
-// Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri)
-#define getOpcV2H(ty, opKind0, opKind1) \
- NVPTX::StoreParamV2##ty##_##opKind0##opKind1
-
-#define getOpcV2H1(ty, opKind0, isImm1) \
- (isImm1) ? getOpcV2H(ty, opKind0, i) : getOpcV2H(ty, opKind0, r)
-
-#define getOpcodeForVectorStParamV2(ty, isimm) \
- (isimm[0]) ? getOpcV2H1(ty, i, isimm[1]) : getOpcV2H1(ty, r, isimm[1])
-
-#define getOpcV4H(ty, opKind0, opKind1, opKind2, opKind3) \
- NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3
-
-#define getOpcV4H3(ty, opKind0, opKind1, opKind2, isImm3) \
- (isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i) \
- : getOpcV4H(ty, opKind0, opKind1, opKind2, r)
-
-#define getOpcV4H2(ty, opKind0, opKind1, isImm2, isImm3) \
- (isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3) \
- : getOpcV4H3(ty, opKind0, opKind1, r, isImm3)
-
-#define getOpcV4H1(ty, opKind0, isImm1, isImm2, isImm3) \
- (isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3) \
- : getOpcV4H2(ty, opKind0, r, isImm2, isImm3)
-
-#define getOpcodeForVectorStParamV4(ty, isimm) \
- (isimm[0]) ? getOpcV4H1(ty, i, isimm[1], isimm[2], isimm[3]) \
- : getOpcV4H1(ty, r, isimm[1], isimm[2], isimm[3])
-
-#define getOpcodeForVectorStParam(n, ty, isimm) \
- (n == 2) ? getOpcodeForVectorStParamV2(ty, isimm) \
- : getOpcodeForVectorStParamV4(ty, isimm)
-
-static unsigned pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops,
- unsigned NumElts,
- MVT::SimpleValueType MemTy,
- SelectionDAG *CurDAG, SDLoc DL) {
- // Determine which inputs are registers and immediates make new operators
- // with constant values
- SmallVector<bool, 4> IsImm(NumElts, false);
- for (unsigned i = 0; i < NumElts; i++) {
- IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i]));
- if (IsImm[i]) {
- SDValue Imm = Ops[i];
- if (MemTy == MVT::f32 || MemTy == MVT::f64) {
- const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
- const ConstantFP *CF = ConstImm->getConstantFPValue();
- Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
- } else {
- const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
- const ConstantInt *CI = ConstImm->getConstantIntValue();
- Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
- }
- Ops[i] = Imm;
- }
- }
-
- // Get opcode for MemTy, size, and register/immediate operand ordering
- switch (MemTy) {
- case MVT::i8:
- return getOpcodeForVectorStParam(NumElts, I8, IsImm);
- case MVT::i16:
- return getOpcodeForVectorStParam(NumElts, I16, IsImm);
- case MVT::i32:
- return getOpcodeForVectorStParam(NumElts, I32, IsImm);
- case MVT::i64:
- assert(NumElts == 2 && "MVT too large for NumElts > 2");
- return getOpcodeForVectorStParamV2(I64, IsImm);
- case MVT::f32:
- return getOpcodeForVectorStParam(NumElts, F32, IsImm);
- case MVT::f64:
- assert(NumElts == 2 && "MVT too large for NumElts > 2");
- return getOpcodeForVectorStParamV2(F64, IsImm);
-
- // These cases don't support immediates, just use the all register version
- // and generate moves.
- case MVT::i1:
- return (NumElts == 2) ? NVPTX::StoreParamV2I8_rr
- : NVPTX::StoreParamV4I8_rrrr;
- case MVT::f16:
- case MVT::bf16:
- return (NumElts == 2) ? NVPTX::StoreParamV2I16_rr
- : NVPTX::StoreParamV4I16_rrrr;
- case MVT::v2f16:
- case MVT::v2bf16:
- case MVT::v2i16:
- case MVT::v4i8:
- return (NumElts == 2) ? NVPTX::StoreParamV2I32_rr
- : NVPTX::StoreParamV4I32_rrrr;
- default:
- llvm_unreachable("Cannot select st.param for unknown MemTy");
- }
-}
-
-bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
- SDLoc DL(N);
- SDValue Chain = N->getOperand(0);
- SDValue Param = N->getOperand(1);
- unsigned ParamVal = Param->getAsZExtVal();
- SDValue Offset = N->getOperand(2);
- unsigned OffsetVal = Offset->getAsZExtVal();
- MemSDNode *Mem = cast<MemSDNode>(N);
- SDValue Glue = N->getOperand(N->getNumOperands() - 1);
-
- // How many elements do we have?
- unsigned NumElts;
- switch (N->getOpcode()) {
- default:
- llvm_unreachable("Unexpected opcode");
- case NVPTXISD::StoreParam:
- NumElts = 1;
- break;
- case NVPTXISD::StoreParamV2:
- NumElts = 2;
- break;
- case NVPTXISD::StoreParamV4:
- NumElts = 4;
- break;
- }
-
- // Build vector of operands
- SmallVector<SDValue, 8> Ops;
- for (unsigned i = 0; i < NumElts; ++i)
- Ops.push_back(N->getOperand(i + 3));
- Ops.append({CurDAG->getTargetConstant(ParamVal, DL, MVT::i32),
- CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});
-
- // Determine target opcode
- // If we have an i1, use an 8-bit store. The lowering code in
- // NVPTXISelLowering will have already emitted an upcast.
- std::optional<unsigned> Opcode;
- switch (NumElts) {
- default:
- llvm_unreachable("Unexpected NumElts");
- case 1: {
- MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
- SDValue Imm = Ops[0];
- if (MemTy != MVT::f16 && MemTy != MVT::bf16 &&
- (isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
- // Convert immediate to target constant
- if (MemTy == MVT::f32 || MemTy == MVT::f64) {
- const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
- const ConstantFP *CF = ConstImm->getConstantFPValue();
- Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
- } else {
- const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
- const ConstantInt *CI = ConstImm->getConstantIntValue();
- Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
- }
- Ops[0] = Imm;
- // Use immediate version of store param
- Opcode =
- pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_i, NVPTX::StoreParamI16_i,
- NVPTX::StoreParamI32_i, NVPTX::StoreParamI64_i);
- } else
- Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
- NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r,
- NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r);
- if (Opcode == NVPTX::StoreParamI8_r) {
- // Fine tune the opcode depending on the size of the operand.
- // This helps to avoid creating redundant COPY instructions in
- // InstrEmitter::AddRegisterOperand().
- switch (Ops[0].getSimpleValueType().SimpleTy) {
- default:
- break;
- case MVT::i32:
- Opcode = NVPTX::StoreParamI8TruncI32_r;
- break;
- case MVT::i64:
- Opcode = NVPTX::StoreParamI8TruncI64_r;
- break;
- }
- }
- break;
- }
- case 2:
- case 4: {
- MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
- Opcode = pickOpcodeForVectorStParam(Ops, NumElts, MemTy, CurDAG, DL);
- break;
- }
- }
-
- SDVTList RetVTs = CurDAG->getVTList(MVT::Other, MVT::Glue);
- SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, RetVTs, Ops);
- MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
- CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef});
-
- ReplaceNode(N, Ret);
- return true;
-}
-
/// SelectBFE - Look for instruction sequences that can be made more efficient
/// by using the 'bfe' (bit-field extract) PTX instruction
bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {
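
For reference, the deleted macro ladder assembled store-param opcode names from the register/immediate pattern of the stored operands. A hedged expansion trace, following the old "ex: NVPTX::StoreParamV4F32_iiri" comment with an assumed operand pattern:

    //   IsImm = {true, false}                  // first operand imm, second reg
    //   getOpcodeForVectorStParamV2(I32, IsImm)
    //     -> getOpcV2H1(I32, i, IsImm[1])      // IsImm[0] picks 'i'
    //     -> getOpcV2H(I32, i, r)              // IsImm[1] picks 'r'
    //     -> NVPTX::StoreParamV2I32_ir
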
@@ -1962,10 +1677,11 @@ bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
auto API = APF.bitcastToAPInt();
API = API.concat(API);
auto Const = CurDAG->getTargetConstant(API, DL, MVT::i32);
- return SDValue(CurDAG->getMachineNode(NVPTX::IMOV32i, DL, VT, Const), 0);
+ return SDValue(CurDAG->getMachineNode(NVPTX::MOV_B32_i, DL, VT, Const),
+ 0);
}
auto Const = CurDAG->getTargetConstantFP(APF, DL, VT);
- return SDValue(CurDAG->getMachineNode(NVPTX::BFMOV16i, DL, VT, Const), 0);
+ return SDValue(CurDAG->getMachineNode(NVPTX::MOV_BF16_i, DL, VT, Const), 0);
};
switch (N->getOpcode()) {
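
A worked example for the v2bf16 splat path above (bit values illustrative):

    // APF = bf16 1.0, whose bit pattern is 0x3F80. Then
    //   API.concat(API) == 0x3F803F80   // {1.0, 1.0} packed in one register
    // which now materializes through NVPTX::MOV_B32_i, i.e. roughly
    //   mov.b32 %r, 0x3F803F80;
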
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index b99b4ef..9e0f88e5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -40,9 +40,6 @@ private:
class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
const NVPTXTargetMachine &TM;
- // If true, generate mul.wide from sext and mul
- bool doMulWide;
-
NVPTX::DivPrecisionLevel getDivF32Level(const SDNode *N) const;
bool usePrecSqrtF32(const SDNode *N) const;
bool useF32FTZ() const;
@@ -78,8 +75,6 @@ private:
bool tryLDG(MemSDNode *N);
bool tryStore(SDNode *N);
bool tryStoreVector(SDNode *N);
- bool tryLoadParam(SDNode *N);
- bool tryStoreParam(SDNode *N);
bool tryFence(SDNode *N);
void SelectAddrSpaceCast(SDNode *N);
bool tryBFE(SDNode *N);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index ddcecc00..65d1be3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -843,7 +843,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
- ISD::STORE});
+ ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -1075,12 +1075,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::DeclareArrayParam)
MAKE_CASE(NVPTXISD::DeclareScalarParam)
MAKE_CASE(NVPTXISD::CALL)
- MAKE_CASE(NVPTXISD::LoadParam)
- MAKE_CASE(NVPTXISD::LoadParamV2)
- MAKE_CASE(NVPTXISD::LoadParamV4)
- MAKE_CASE(NVPTXISD::StoreParam)
- MAKE_CASE(NVPTXISD::StoreParamV2)
- MAKE_CASE(NVPTXISD::StoreParamV4)
MAKE_CASE(NVPTXISD::MoveParam)
MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
MAKE_CASE(NVPTXISD::BUILD_VECTOR)
@@ -1318,105 +1312,6 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
return DL.getABITypeAlign(Ty);
}
-static bool adjustElementType(EVT &ElementType) {
- switch (ElementType.getSimpleVT().SimpleTy) {
- default:
- return false;
- case MVT::f16:
- case MVT::bf16:
- ElementType = MVT::i16;
- return true;
- case MVT::f32:
- case MVT::v2f16:
- case MVT::v2bf16:
- ElementType = MVT::i32;
- return true;
- case MVT::f64:
- ElementType = MVT::i64;
- return true;
- }
-}
-
-// Use byte-store when the param address of the argument value is unaligned.
-// This may happen when the return value is a field of a packed structure.
-//
-// This is called in LowerCall() when passing the param values.
-static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain,
- uint64_t Offset, EVT ElementType,
- SDValue StVal, SDValue &InGlue,
- unsigned ArgID, const SDLoc &dl) {
- // Bit logic only works on integer types
- if (adjustElementType(ElementType))
- StVal = DAG.getNode(ISD::BITCAST, dl, ElementType, StVal);
-
- // Store each byte
- SDVTList StoreVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
- // Shift the byte to the last byte position
- SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, StVal,
- DAG.getConstant(i * 8, dl, MVT::i32));
- SDValue StoreOperands[] = {Chain, DAG.getConstant(ArgID, dl, MVT::i32),
- DAG.getConstant(Offset + i, dl, MVT::i32),
- ShiftVal, InGlue};
- // Trunc store only the last byte by using
- // st.param.b8
- // The register type can be larger than b8.
- Chain = DAG.getMemIntrinsicNode(
- NVPTXISD::StoreParam, dl, StoreVTs, StoreOperands, MVT::i8,
- MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
- InGlue = Chain.getValue(1);
- }
- return Chain;
-}
-
-// Use byte-load when the param address of the returned value is unaligned.
-// This may happen when the returned value is a field of a packed structure.
-static SDValue
-LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
- EVT ElementType, SDValue &InGlue,
- SmallVectorImpl<SDValue> &TempProxyRegOps,
- const SDLoc &dl) {
- // Bit logic only works on integer types
- EVT MergedType = ElementType;
- adjustElementType(MergedType);
-
- // Load each byte and construct the whole value. Initial value to 0
- SDValue RetVal = DAG.getConstant(0, dl, MergedType);
- // LoadParamMemI8 loads into i16 register only
- SDVTList LoadVTs = DAG.getVTList(MVT::i16, MVT::Other, MVT::Glue);
- for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
- SDValue LoadOperands[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
- DAG.getConstant(Offset + i, dl, MVT::i32),
- InGlue};
- // This will be selected to LoadParamMemI8
- SDValue LdVal =
- DAG.getMemIntrinsicNode(NVPTXISD::LoadParam, dl, LoadVTs, LoadOperands,
- MVT::i8, MachinePointerInfo(), Align(1));
- SDValue TmpLdVal = LdVal.getValue(0);
- Chain = LdVal.getValue(1);
- InGlue = LdVal.getValue(2);
-
- TmpLdVal = DAG.getNode(NVPTXISD::ProxyReg, dl,
- TmpLdVal.getSimpleValueType(), TmpLdVal);
- TempProxyRegOps.push_back(TmpLdVal);
-
- SDValue CMask = DAG.getConstant(255, dl, MergedType);
- SDValue CShift = DAG.getConstant(i * 8, dl, MVT::i32);
- // Need to extend the i16 register to the whole width.
- TmpLdVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MergedType, TmpLdVal);
- // Mask off the high bits. Leave only the lower 8bits.
- // Do this because we are using loadparam.b8.
- TmpLdVal = DAG.getNode(ISD::AND, dl, MergedType, TmpLdVal, CMask);
- // Shift and merge
- TmpLdVal = DAG.getNode(ISD::SHL, dl, MergedType, TmpLdVal, CShift);
- RetVal = DAG.getNode(ISD::OR, dl, MergedType, RetVal, TmpLdVal);
- }
- if (ElementType != MergedType)
- RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
-
- return RetVal;
-}
-
static bool shouldConvertToIndirectCall(const CallBase *CB,
const GlobalAddressSDNode *Func) {
if (!Func)
@@ -1483,10 +1378,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SelectionDAG &DAG = CLI.DAG;
SDLoc dl = CLI.DL;
- SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
- SDValue Chain = CLI.Chain;
+ const SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
SDValue Callee = CLI.Callee;
- bool &isTailCall = CLI.IsTailCall;
ArgListTy &Args = CLI.getArgs();
Type *RetTy = CLI.RetTy;
const CallBase *CB = CLI.CB;
@@ -1496,6 +1389,36 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
return DAG.getConstant(I, dl, MVT::i32);
};
+ const unsigned UniqueCallSite = GlobalUniqueCallSite++;
+ const SDValue CallChain = CLI.Chain;
+ const SDValue StartChain =
+ DAG.getCALLSEQ_START(CallChain, UniqueCallSite, 0, dl);
+ SDValue DeclareGlue = StartChain.getValue(1);
+
+ SmallVector<SDValue, 16> CallPrereqs{StartChain};
+
+ const auto MakeDeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
+ // PTX ABI requires integral types to be at least 32 bits in size. FP16 is
+ // loaded/stored using i16, so it's handled here as well.
+ const unsigned SizeBits = promoteScalarArgumentSize(Size * 8);
+ SDValue Declare =
+ DAG.getNode(NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
+ {StartChain, Symbol, GetI32(SizeBits), DeclareGlue});
+ CallPrereqs.push_back(Declare);
+ DeclareGlue = Declare.getValue(1);
+ return Declare;
+ };
+
+ const auto MakeDeclareArrayParam = [&](SDValue Symbol, Align Align,
+ unsigned Size) {
+ SDValue Declare = DAG.getNode(
+ NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
+ {StartChain, Symbol, GetI32(Align.value()), GetI32(Size), DeclareGlue});
+ CallPrereqs.push_back(Declare);
+ DeclareGlue = Declare.getValue(1);
+ return Declare;
+ };
+
// Variadic arguments.
//
// Normally, for each argument, we declare a param scalar or a param
@@ -1511,15 +1434,17 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
//
// After all vararg is processed, 'VAOffset' holds the size of the
// vararg byte array.
+ assert((CLI.IsVarArg || CLI.Args.size() == CLI.NumFixedArgs) &&
+ "Non-VarArg function with extra arguments");
- SDValue VADeclareParam; // vararg byte array
const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
- unsigned VAOffset = 0; // current offset in the param array
+ unsigned VAOffset = 0; // current offset in the param array
- const unsigned UniqueCallSite = GlobalUniqueCallSite++;
- SDValue TempChain = Chain;
- Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
- SDValue InGlue = Chain.getValue(1);
+ const SDValue VADeclareParam =
+ CLI.Args.size() > FirstVAArg
+ ? MakeDeclareArrayParam(getCallParamSymbol(DAG, FirstVAArg, MVT::i32),
+ Align(STI.getMaxRequiredAlignment()), 0)
+ : SDValue();
// Args.size() and Outs.size() need not match.
// Outs.size() will be larger
@@ -1580,43 +1505,19 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
"type size mismatch");
- const std::optional<SDValue> ArgDeclare = [&]() -> std::optional<SDValue> {
- if (IsVAArg) {
- if (ArgI == FirstVAArg) {
- VADeclareParam = DAG.getNode(
- NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
- {Chain, ParamSymbol, GetI32(STI.getMaxRequiredAlignment()),
- GetI32(0), InGlue});
- return VADeclareParam;
- }
- return std::nullopt;
- }
- if (IsByVal || shouldPassAsArray(Arg.Ty)) {
- // declare .param .align <align> .b8 .param<n>[<size>];
- return DAG.getNode(NVPTXISD::DeclareArrayParam, dl,
- {MVT::Other, MVT::Glue},
- {Chain, ParamSymbol, GetI32(ArgAlign.value()),
- GetI32(TypeSize), InGlue});
- }
+ const SDValue ArgDeclare = [&]() {
+ if (IsVAArg)
+ return VADeclareParam;
+
+ if (IsByVal || shouldPassAsArray(Arg.Ty))
+ return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TypeSize);
+
assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
- // declare .param .b<size> .param<n>;
-
- // PTX ABI requires integral types to be at least 32 bits in
- // size. FP16 is loaded/stored using i16, so it's handled
- // here as well.
- const unsigned PromotedSize =
- (ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint())
- ? promoteScalarArgumentSize(TypeSize * 8)
- : TypeSize * 8;
-
- return DAG.getNode(NVPTXISD::DeclareScalarParam, dl,
- {MVT::Other, MVT::Glue},
- {Chain, ParamSymbol, GetI32(PromotedSize), InGlue});
+ assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
+ "Only int and float types are supported as non-array arguments");
+
+ return MakeDeclareScalarParam(ParamSymbol, TypeSize);
}();
- if (ArgDeclare) {
- Chain = ArgDeclare->getValue(0);
- InGlue = ArgDeclare->getValue(1);
- }
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter
// than 32-bits are sign extended or zero extended, depending on
@@ -1626,36 +1527,25 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
Arg.Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Arg.Ty) < 32;
const auto GetStoredValue = [&](const unsigned I, EVT EltVT,
- const Align PartAlign) {
- SDValue StVal;
+ const MaybeAlign PartAlign) {
if (IsByVal) {
SDValue Ptr = ArgOutVals[0];
auto MPI = refinePtrAS(Ptr, DAG, DL, *this);
SDValue SrcAddr =
DAG.getObjectPtrOffset(dl, Ptr, TypeSize::getFixed(Offsets[I]));
- StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
- } else {
- StVal = ArgOutVals[I];
-
- auto PromotedVT = promoteScalarIntegerPTX(StVal.getValueType());
- if (PromotedVT != StVal.getValueType()) {
- StVal = DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, PromotedVT,
- StVal);
- }
+ return DAG.getLoad(EltVT, dl, CallChain, SrcAddr, MPI, PartAlign);
}
+ SDValue StVal = ArgOutVals[I];
+ assert(promoteScalarIntegerPTX(StVal.getValueType()) ==
+ StVal.getValueType() &&
+ "OutVal type should always be legal");
- if (ExtendIntegerParam) {
- assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
- // zext/sext to i32
- StVal =
- DAG.getNode(getExtOpcode(ArgOuts[I].Flags), dl, MVT::i32, StVal);
- } else if (EltVT.getSizeInBits() < 16) {
- // Use 16-bit registers for small stores as it's the
- // smallest general purpose register size supported by NVPTX.
- StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
- }
- return StVal;
+ const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
+ const EVT StoreVT =
+ ExtendIntegerParam ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
+
+ return correctParamType(StVal, StoreVT, ArgOuts[I].Flags, DAG, dl);
};
const auto VectorInfo =
@@ -1664,23 +1554,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
unsigned J = 0;
for (const unsigned NumElts : VectorInfo) {
const int CurOffset = Offsets[J];
- EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
- const Align PartAlign = commonAlignment(ArgAlign, CurOffset);
-
- // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
- // scalar store. In such cases, fall back to byte stores.
- if (NumElts == 1 && !IsVAArg && PartAlign < DAG.getEVTAlign(EltVT)) {
-
- SDValue StVal = GetStoredValue(J, EltVT, PartAlign);
- Chain = LowerUnalignedStoreParam(DAG, Chain,
- CurOffset + (IsByVal ? VAOffset : 0),
- EltVT, StVal, InGlue, ArgI, dl);
-
- // LowerUnalignedStoreParam took care of inserting the necessary nodes
- // into the SDAG, so just move on to the next element.
- J++;
- continue;
- }
+ const EVT EltVT = promoteScalarIntegerPTX(VTs[J]);
if (IsVAArg && !IsByVal)
// Align each part of the variadic argument to their type.
@@ -1688,44 +1562,45 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
assert((IsVAArg || VAOffset == 0) &&
"VAOffset must be 0 for non-VA args");
- SmallVector<SDValue, 6> StoreOperands{
- Chain, GetI32(IsVAArg ? FirstVAArg : ArgI),
- GetI32(VAOffset + ((IsVAArg && !IsByVal) ? 0 : CurOffset))};
- // Record the values to store.
- for (const unsigned K : llvm::seq(NumElts))
- StoreOperands.push_back(GetStoredValue(J + K, EltVT, PartAlign));
- StoreOperands.push_back(InGlue);
+ const unsigned Offset =
+ (VAOffset + ((IsVAArg && !IsByVal) ? 0 : CurOffset));
+ SDValue Ptr =
+ DAG.getObjectPtrOffset(dl, ParamSymbol, TypeSize::getFixed(Offset));
- NVPTXISD::NodeType Op;
- switch (NumElts) {
- case 1:
- Op = NVPTXISD::StoreParam;
- break;
- case 2:
- Op = NVPTXISD::StoreParamV2;
- break;
- case 4:
- Op = NVPTXISD::StoreParamV4;
- break;
- default:
- llvm_unreachable("Invalid vector info.");
+ const MaybeAlign CurrentAlign = ExtendIntegerParam
+ ? MaybeAlign(std::nullopt)
+ : commonAlignment(ArgAlign, Offset);
+
+ SDValue Val;
+ if (NumElts == 1) {
+ Val = GetStoredValue(J, EltVT, CurrentAlign);
+ } else {
+ SmallVector<SDValue, 8> StoreVals;
+ for (const unsigned K : llvm::seq(NumElts)) {
+ SDValue ValJ = GetStoredValue(J + K, EltVT, CurrentAlign);
+ if (ValJ.getValueType().isVector())
+ DAG.ExtractVectorElements(ValJ, StoreVals);
+ else
+ StoreVals.push_back(ValJ);
+ }
+
+ EVT VT = EVT::getVectorVT(
+ *DAG.getContext(), StoreVals[0].getValueType(), StoreVals.size());
+ Val = DAG.getBuildVector(VT, dl, StoreVals);
}
- // Adjust type of the store op if we've extended the scalar
- // return value.
- EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
- Chain = DAG.getMemIntrinsicNode(
- Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
- TheStoreType, MachinePointerInfo(), PartAlign,
- MachineMemOperand::MOStore);
- InGlue = Chain.getValue(1);
+ SDValue StoreParam =
+ DAG.getStore(ArgDeclare, dl, Val, Ptr,
+ MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
+ CallPrereqs.push_back(StoreParam);
// TODO: We may need to support vector types that can be passed
// as scalars in variadic arguments.
if (IsVAArg && !IsByVal) {
assert(NumElts == 1 &&
"Vectorization is expected to be disabled for variadics.");
+ const EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : EltVT;
VAOffset +=
DL.getTypeAllocSize(TheStoreType.getTypeForEVT(*DAG.getContext()));
}
@@ -1736,33 +1611,21 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
VAOffset += TypeSize;
}
- GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
-
// Handle Result
if (!Ins.empty()) {
- const SDValue RetDeclare = [&]() {
- const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
- const unsigned ResultSize = DL.getTypeAllocSizeInBits(RetTy);
- if (shouldPassAsArray(RetTy)) {
- const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
- return DAG.getNode(NVPTXISD::DeclareArrayParam, dl,
- {MVT::Other, MVT::Glue},
- {Chain, RetSymbol, GetI32(RetAlign.value()),
- GetI32(ResultSize / 8), InGlue});
- }
- const auto PromotedResultSize = promoteScalarArgumentSize(ResultSize);
- return DAG.getNode(
- NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
- {Chain, RetSymbol, GetI32(PromotedResultSize), InGlue});
- }();
- Chain = RetDeclare.getValue(0);
- InGlue = RetDeclare.getValue(1);
+ const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
+ const unsigned ResultSize = DL.getTypeAllocSize(RetTy);
+ if (shouldPassAsArray(RetTy)) {
+ const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
+ MakeDeclareArrayParam(RetSymbol, RetAlign, ResultSize);
+ } else {
+ MakeDeclareScalarParam(RetSymbol, ResultSize);
+ }
}
- const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
// Set the size of the vararg param byte array if the callee is a variadic
// function and the variadic part is not empty.
- if (HasVAArgs) {
+ if (VADeclareParam) {
SDValue DeclareParamOps[] = {VADeclareParam.getOperand(0),
VADeclareParam.getOperand(1),
VADeclareParam.getOperand(2), GetI32(VAOffset),
@@ -1771,6 +1634,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
VADeclareParam->getVTList(), DeclareParamOps);
}
+ const auto *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
// If the type of the callsite does not match that of the function, convert
// the callsite to an indirect call.
const bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
@@ -1800,15 +1664,16 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// instruction.
// The prototype is embedded in a string and put as the operand for a
// CallPrototype SDNode which will print out to the value of the string.
+ const bool HasVAArgs = CLI.IsVarArg && (CLI.Args.size() > CLI.NumFixedArgs);
std::string Proto =
getPrototype(DL, RetTy, Args, CLI.Outs,
HasVAArgs ? std::optional(FirstVAArg) : std::nullopt, *CB,
UniqueCallSite);
const char *ProtoStr = nvTM->getStrPool().save(Proto).data();
- Chain = DAG.getNode(
- NVPTXISD::CallPrototype, dl, {MVT::Other, MVT::Glue},
- {Chain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32), InGlue});
- InGlue = Chain.getValue(1);
+ const SDValue PrototypeDeclare = DAG.getNode(
+ NVPTXISD::CallPrototype, dl, MVT::Other,
+ {StartChain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32)});
+ CallPrereqs.push_back(PrototypeDeclare);
}
if (ConvertToIndirectCall) {
@@ -1826,24 +1691,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const unsigned NumArgs =
std::min<unsigned>(CLI.NumFixedArgs + 1, Args.size());
/// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
- /// NumParams, Callee, Proto, InGlue)
- Chain = DAG.getNode(NVPTXISD::CALL, dl, {MVT::Other, MVT::Glue},
- {Chain, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall),
- GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee,
- GetI32(Proto), InGlue});
- InGlue = Chain.getValue(1);
-
+ /// NumParams, Callee, Proto)
+ const SDValue CallToken = DAG.getTokenFactor(dl, CallPrereqs);
+ const SDValue Call = DAG.getNode(
+ NVPTXISD::CALL, dl, MVT::Other,
+ {CallToken, GetI32(CLI.IsConvergent), GetI32(IsIndirectCall),
+ GetI32(Ins.empty() ? 0 : 1), GetI32(NumArgs), Callee, GetI32(Proto)});
+
+ SmallVector<SDValue, 16> LoadChains{Call};
SmallVector<SDValue, 16> ProxyRegOps;
- // An item of the vector is filled if the element does not need a ProxyReg
- // operation on it and should be added to InVals as is. ProxyRegOps and
- // ProxyRegTruncates contain empty/none items at the same index.
- SmallVector<SDValue, 16> RetElts;
- // A temporary ProxyReg operations inserted in `LowerUnalignedLoadRetParam()`
- // to use the values of `LoadParam`s and to be replaced later then
- // `CALLSEQ_END` is added.
- SmallVector<SDValue, 16> TempProxyRegOps;
-
- // Generate loads from param memory/moves from registers for result
if (!Ins.empty()) {
SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
@@ -1860,104 +1716,65 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
const auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
unsigned I = 0;
- for (const unsigned VectorizedSize : VectorInfo) {
- EVT TheLoadType = promoteScalarIntegerPTX(VTs[I]);
- EVT EltType = Ins[I].VT;
- const Align EltAlign = commonAlignment(RetAlign, Offsets[I]);
-
- if (TheLoadType != VTs[I])
- EltType = TheLoadType;
-
- if (ExtendIntegerRetVal) {
- TheLoadType = MVT::i32;
- EltType = MVT::i32;
- } else if (TheLoadType.getSizeInBits() < 16) {
- EltType = MVT::i16;
- }
+ for (const unsigned NumElts : VectorInfo) {
+ const MaybeAlign CurrentAlign =
+ ExtendIntegerRetVal ? MaybeAlign(std::nullopt)
+ : commonAlignment(RetAlign, Offsets[I]);
- // If we have a PVF_SCALAR entry, it may not be sufficiently aligned for a
- // scalar load. In such cases, fall back to byte loads.
- if (VectorizedSize == 1 && RetTy->isAggregateType() &&
- EltAlign < DAG.getEVTAlign(TheLoadType)) {
- SDValue Ret = LowerUnalignedLoadRetParam(
- DAG, Chain, Offsets[I], TheLoadType, InGlue, TempProxyRegOps, dl);
- ProxyRegOps.push_back(SDValue());
- RetElts.resize(I);
- RetElts.push_back(Ret);
-
- I++;
- continue;
- }
+ const EVT VTI = promoteScalarIntegerPTX(VTs[I]);
+ const EVT LoadVT =
+ ExtendIntegerRetVal ? MVT::i32 : (VTI == MVT::i1 ? MVT::i8 : VTI);
- SmallVector<EVT, 6> LoadVTs(VectorizedSize, EltType);
- LoadVTs.append({MVT::Other, MVT::Glue});
+ const unsigned PackingAmt =
+ LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
- NVPTXISD::NodeType Op;
- switch (VectorizedSize) {
- case 1:
- Op = NVPTXISD::LoadParam;
- break;
- case 2:
- Op = NVPTXISD::LoadParamV2;
- break;
- case 4:
- Op = NVPTXISD::LoadParamV4;
- break;
- default:
- llvm_unreachable("Invalid vector info.");
- }
+ const EVT VecVT = NumElts == 1 ? LoadVT
+ : EVT::getVectorVT(*DAG.getContext(),
+ LoadVT.getScalarType(),
+ NumElts * PackingAmt);
- SDValue LoadOperands[] = {Chain, GetI32(1), GetI32(Offsets[I]), InGlue};
- SDValue RetVal = DAG.getMemIntrinsicNode(
- Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
- MachinePointerInfo(), EltAlign, MachineMemOperand::MOLoad);
+ const SDValue RetSymbol = DAG.getExternalSymbol("retval0", MVT::i32);
+ SDValue Ptr =
+ DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
- for (const unsigned J : llvm::seq(VectorizedSize)) {
- ProxyRegOps.push_back(RetVal.getValue(J));
- }
+ SDValue R =
+ DAG.getLoad(VecVT, dl, Call, Ptr,
+ MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
- Chain = RetVal.getValue(VectorizedSize);
- InGlue = RetVal.getValue(VectorizedSize + 1);
+ LoadChains.push_back(R.getValue(1));
- I += VectorizedSize;
+ if (NumElts == 1)
+ ProxyRegOps.push_back(R);
+ else
+ for (const unsigned J : llvm::seq(NumElts)) {
+ SDValue Elt = DAG.getNode(
+ LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
+ : ISD::EXTRACT_VECTOR_ELT,
+ dl, LoadVT, R, DAG.getVectorIdxConstant(J * PackingAmt, dl));
+ ProxyRegOps.push_back(Elt);
+ }
+ I += NumElts;
}
}
- Chain =
- DAG.getCALLSEQ_END(Chain, UniqueCallSite, UniqueCallSite + 1, InGlue, dl);
- InGlue = Chain.getValue(1);
+ const SDValue EndToken = DAG.getTokenFactor(dl, LoadChains);
+ const SDValue CallEnd = DAG.getCALLSEQ_END(EndToken, UniqueCallSite,
+ UniqueCallSite + 1, SDValue(), dl);
// Append ProxyReg instructions to the chain to make sure that `callseq_end`
// will not get lost. Otherwise, during libcalls expansion, the nodes can become
// dangling.
- for (const unsigned I : llvm::seq(ProxyRegOps.size())) {
- if (I < RetElts.size() && RetElts[I]) {
- InVals.push_back(RetElts[I]);
- continue;
- }
-
- SDValue Ret =
- DAG.getNode(NVPTXISD::ProxyReg, dl, ProxyRegOps[I].getSimpleValueType(),
- {Chain, ProxyRegOps[I]});
-
- const EVT ExpectedVT = Ins[I].VT;
- if (!Ret.getValueType().bitsEq(ExpectedVT)) {
- Ret = DAG.getNode(ISD::TRUNCATE, dl, ExpectedVT, Ret);
- }
+ for (const auto [I, Reg] : llvm::enumerate(ProxyRegOps)) {
+ SDValue Proxy =
+ DAG.getNode(NVPTXISD::ProxyReg, dl, Reg.getValueType(), {CallEnd, Reg});
+ SDValue Ret = correctParamType(Proxy, Ins[I].VT, Ins[I].Flags, DAG, dl);
InVals.push_back(Ret);
}
- for (SDValue &T : TempProxyRegOps) {
- SDValue Repl = DAG.getNode(NVPTXISD::ProxyReg, dl, T.getSimpleValueType(),
- {Chain, T.getOperand(0)});
- DAG.ReplaceAllUsesWith(T, Repl);
- DAG.RemoveDeadNode(T.getNode());
- }
-
- // set isTailCall to false for now, until we figure out how to express
+ // set IsTailCall to false for now, until we figure out how to express
// tail call optimization in PTX
- isTailCall = false;
- return Chain;
+ CLI.IsTailCall = false;
+ return CallEnd;
}
SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
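
A minimal sketch of the promotion rule the declare helpers in this hunk rely on, assuming the documented behavior of promoteScalarArgumentSize rather than quoting it:

    static unsigned promoteScalarArgumentSizeSketch(unsigned SizeBits) {
      // PTX requires scalar .param declarations of at least 32 bits; i1, i8,
      // i16 (and f16/bf16, which travel through 16-bit registers) widen to 32.
      return SizeBits < 32 ? 32 : SizeBits;
    }
    // e.g. an i16 argument declares roughly ".param .b32 param0;" while an
    // i64 argument stays ".b64".
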
@@ -5100,7 +4917,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return SDValue();
auto *LD = cast<MemSDNode>(N);
- EVT MemVT = LD->getMemoryVT();
SDLoc DL(LD);
// the new opcode after we double the number of operands
@@ -5117,10 +4933,6 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
Operands.push_back(DCI.DAG.getIntPtrConstant(
cast<LoadSDNode>(LD)->getExtensionType(), DL));
break;
- case NVPTXISD::LoadParamV2:
- OldNumOutputs = 2;
- Opcode = NVPTXISD::LoadParamV4;
- break;
case NVPTXISD::LoadV2:
OldNumOutputs = 2;
Opcode = NVPTXISD::LoadV4;
@@ -5145,9 +4957,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
NewVTs.append(LD->value_begin() + OldNumOutputs, LD->value_end());
// Create the new load
- SDValue NewLoad =
- DCI.DAG.getMemIntrinsicNode(Opcode, DL, DCI.DAG.getVTList(NewVTs),
- Operands, MemVT, LD->getMemOperand());
+ SDValue NewLoad = DCI.DAG.getMemIntrinsicNode(
+ Opcode, DL, DCI.DAG.getVTList(NewVTs), Operands, LD->getMemoryVT(),
+ LD->getMemOperand());
// Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
// the outputs the same. These nodes will be optimized away in later
@@ -5189,7 +5001,6 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
return SDValue();
auto *ST = cast<MemSDNode>(N);
- EVT MemVT = ElementVT.getVectorElementType();
// The new opcode after we double the number of operands.
NVPTXISD::NodeType Opcode;
@@ -5198,17 +5009,9 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
// Any packed type is legal, so the legalizer will not have lowered
// ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do
// it here.
- MemVT = ST->getMemoryVT();
Opcode = NVPTXISD::StoreV2;
break;
- case NVPTXISD::StoreParam:
- Opcode = NVPTXISD::StoreParamV2;
- break;
- case NVPTXISD::StoreParamV2:
- Opcode = NVPTXISD::StoreParamV4;
- break;
case NVPTXISD::StoreV2:
- MemVT = ST->getMemoryVT();
Opcode = NVPTXISD::StoreV4;
break;
case NVPTXISD::StoreV4:
@@ -5218,7 +5021,6 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
return SDValue();
Opcode = NVPTXISD::StoreV8;
break;
- case NVPTXISD::StoreParamV4:
case NVPTXISD::StoreV8:
// PTX doesn't support the next doubling of operands
return SDValue();
@@ -5260,19 +5062,7 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
// Now we replace the store
return DCI.DAG.getMemIntrinsicNode(Opcode, SDLoc(N), N->getVTList(), Operands,
- MemVT, ST->getMemOperand());
-}
-
-static SDValue PerformStoreCombineHelper(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI,
- unsigned Front, unsigned Back) {
- if (all_of(N->ops().drop_front(Front).drop_back(Back),
- [](const SDUse &U) { return U.get()->isUndef(); }))
- // Operand 0 is the previous value in the chain. Cannot return EntryToken
- // as the previous value will become unused and eliminated later.
- return N->getOperand(0);
-
- return combinePackingMovIntoStore(N, DCI, Front, Back);
+ ST->getMemoryVT(), ST->getMemOperand());
}
static SDValue PerformStoreCombine(SDNode *N,
@@ -5280,13 +5070,6 @@ static SDValue PerformStoreCombine(SDNode *N,
return combinePackingMovIntoStore(N, DCI, 1, 2);
}
-static SDValue PerformStoreParamCombine(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI) {
- // Operands from the 3rd to the 2nd last one are the values to be stored.
- // {Chain, ArgID, Offset, Val, Glue}
- return PerformStoreCombineHelper(N, DCI, 3, 1);
-}
-
/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
///
static SDValue PerformADDCombine(SDNode *N,
@@ -5432,6 +5215,42 @@ static SDValue PerformREMCombine(SDNode *N,
return SDValue();
}
+// (sign_extend|zero_extend (mul|shl) x, y) -> (mul.wide x, y)
+static SDValue combineMulWide(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ CodeGenOptLevel OptLevel) {
+ if (OptLevel == CodeGenOptLevel::None)
+ return SDValue();
+
+ SDValue Op = N->getOperand(0);
+ if (!Op.hasOneUse())
+ return SDValue();
+ EVT ToVT = N->getValueType(0);
+ EVT FromVT = Op.getValueType();
+ if (!((ToVT == MVT::i32 && FromVT == MVT::i16) ||
+ (ToVT == MVT::i64 && FromVT == MVT::i32)))
+ return SDValue();
+ if (!(Op.getOpcode() == ISD::MUL ||
+ (Op.getOpcode() == ISD::SHL && isa<ConstantSDNode>(Op.getOperand(1)))))
+ return SDValue();
+
+ SDLoc DL(N);
+ unsigned ExtOpcode = N->getOpcode();
+ unsigned Opcode = 0;
+ if (ExtOpcode == ISD::SIGN_EXTEND && Op->getFlags().hasNoSignedWrap())
+ Opcode = NVPTXISD::MUL_WIDE_SIGNED;
+ else if (ExtOpcode == ISD::ZERO_EXTEND && Op->getFlags().hasNoUnsignedWrap())
+ Opcode = NVPTXISD::MUL_WIDE_UNSIGNED;
+ else
+ return SDValue();
+ SDValue RHS = Op.getOperand(1);
+ if (Op.getOpcode() == ISD::SHL) {
+ const auto ShiftAmt = Op.getConstantOperandVal(1);
+ const auto MulVal = APInt(ToVT.getSizeInBits(), 1) << ShiftAmt;
+ RHS = DCI.DAG.getConstant(MulVal, DL, ToVT);
+ }
+ return DCI.DAG.getNode(Opcode, DL, ToVT, Op.getOperand(0), RHS);
+}
+
enum OperandSignedness {
Signed = 0,
Unsigned,
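
DAG shorthand for what the new combine matches and emits (shapes assumed, not code from the patch):

    //   (i64 (sign_extend (mul nsw i32 %a, %b)))  ->  (MUL_WIDE_SIGNED %a, %b)
    //   (i64 (zero_extend (shl nuw i32 %a, 3)))   ->  (MUL_WIDE_UNSIGNED %a, 8)
    // A shl by a constant is folded into a multiply by (1 << 3) == 8, built in
    // the wider result type; without nsw/nuw the combine backs off entirely.
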
@@ -5942,6 +5761,86 @@ static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
N->getConstantOperandAPInt(2),
N->getConstantOperandVal(3)),
SDLoc(N), N->getValueType(0));
+ return SDValue();
+}
+
+// During call lowering we wrap the return values in a ProxyReg node which
+// depend on the chain value produced by the completed call. This ensures that
+// the full call is emitted in cases where libcalls are used to legalize
+// operations. To improve the functioning of other DAG combines we pull all
+// operations we can through one of these nodes, ensuring that the ProxyReg
+// directly wraps a load. That is:
+//
+// (ProxyReg (zext (load retval0))) => (zext (ProxyReg (load retval0)))
+//
+static SDValue sinkProxyReg(SDValue R, SDValue Chain,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ switch (R.getOpcode()) {
+ case ISD::TRUNCATE:
+ case ISD::ANY_EXTEND:
+ case ISD::SIGN_EXTEND:
+ case ISD::ZERO_EXTEND:
+ case ISD::BITCAST: {
+ if (SDValue V = sinkProxyReg(R.getOperand(0), Chain, DCI))
+ return DCI.DAG.getNode(R.getOpcode(), SDLoc(R), R.getValueType(), V);
+ return SDValue();
+ }
+ case ISD::SHL:
+ case ISD::SRL:
+ case ISD::SRA:
+ case ISD::OR: {
+ if (SDValue A = sinkProxyReg(R.getOperand(0), Chain, DCI))
+ if (SDValue B = sinkProxyReg(R.getOperand(1), Chain, DCI))
+ return DCI.DAG.getNode(R.getOpcode(), SDLoc(R), R.getValueType(), A, B);
+ return SDValue();
+ }
+ case ISD::Constant:
+ return R;
+ case ISD::LOAD:
+ case NVPTXISD::LoadV2:
+ case NVPTXISD::LoadV4: {
+ return DCI.DAG.getNode(NVPTXISD::ProxyReg, SDLoc(R), R.getValueType(),
+ {Chain, R});
+ }
+ case ISD::BUILD_VECTOR: {
+ if (DCI.isBeforeLegalize())
+ return SDValue();
+
+ SmallVector<SDValue, 16> Ops;
+ for (auto &Op : R->ops()) {
+ SDValue V = sinkProxyReg(Op, Chain, DCI);
+ if (!V)
+ return SDValue();
+ Ops.push_back(V);
+ }
+ return DCI.DAG.getNode(ISD::BUILD_VECTOR, SDLoc(R), R.getValueType(), Ops);
+ }
+ case ISD::EXTRACT_VECTOR_ELT: {
+ if (DCI.isBeforeLegalize())
+ return SDValue();
+
+ if (SDValue V = sinkProxyReg(R.getOperand(0), Chain, DCI))
+ return DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(R),
+ R.getValueType(), V, R.getOperand(1));
+ return SDValue();
+ }
+ default:
+ return SDValue();
+ }
+}
+
+static SDValue combineProxyReg(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
+
+ SDValue Chain = N->getOperand(0);
+ SDValue Reg = N->getOperand(1);
+
+ // If the ProxyReg is not wrapping a load, try to pull the operations through
+ // the ProxyReg.
+ if (Reg.getOpcode() != ISD::LOAD) {
+ if (SDValue V = sinkProxyReg(Reg, Chain, DCI))
+ return V;
+ }
return SDValue();
}
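
A before/after sketch of the sinking rewrite (node shapes assumed):

    //   before: (ProxyReg ch, (or (shl (load retval0), c), C))
    //   after:  (or (shl (ProxyReg ch, (load retval0)), c), C)
    // Loads are wrapped in fresh ProxyReg nodes, constants pass through
    // untouched, and the arithmetic is rebuilt on top; any unsupported opcode
    // makes the whole attempt return an empty SDValue and leaves N alone.
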
@@ -5958,6 +5857,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return combineADDRSPACECAST(N, DCI);
case ISD::AND:
return PerformANDCombine(N, DCI);
+ case ISD::SIGN_EXTEND:
+ case ISD::ZERO_EXTEND:
+ return combineMulWide(N, DCI, OptLevel);
case ISD::BUILD_VECTOR:
return PerformBUILD_VECTORCombine(N, DCI);
case ISD::EXTRACT_VECTOR_ELT:
@@ -5965,7 +5867,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::FADD:
return PerformFADDCombine(N, DCI, OptLevel);
case ISD::LOAD:
- case NVPTXISD::LoadParamV2:
case NVPTXISD::LoadV2:
case NVPTXISD::LoadV4:
return combineUnpackingMovIntoLoad(N, DCI);
@@ -5973,6 +5874,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformMULCombine(N, DCI, OptLevel);
case NVPTXISD::PRMT:
return combinePRMT(N, DCI, OptLevel);
+ case NVPTXISD::ProxyReg:
+ return combineProxyReg(N, DCI);
case ISD::SETCC:
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
case ISD::SHL:
@@ -5980,10 +5883,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::SREM:
case ISD::UREM:
return PerformREMCombine(N, DCI, OptLevel);
- case NVPTXISD::StoreParam:
- case NVPTXISD::StoreParamV2:
- case NVPTXISD::StoreParamV4:
- return PerformStoreParamCombine(N, DCI);
case ISD::STORE:
case NVPTXISD::StoreV2:
case NVPTXISD::StoreV4:
@@ -6332,6 +6231,22 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
Results.push_back(NewValue.getValue(3));
}
+static void replaceProxyReg(SDNode *N, SelectionDAG &DAG,
+ const TargetLowering &TLI,
+ SmallVectorImpl<SDValue> &Results) {
+ SDValue Chain = N->getOperand(0);
+ SDValue Reg = N->getOperand(1);
+
+ MVT VT = TLI.getRegisterType(*DAG.getContext(), Reg.getValueType());
+
+ SDValue NewReg = DAG.getAnyExtOrTrunc(Reg, SDLoc(N), VT);
+ SDValue NewProxy =
+ DAG.getNode(NVPTXISD::ProxyReg, SDLoc(N), VT, {Chain, NewReg});
+ SDValue Res = DAG.getAnyExtOrTrunc(NewProxy, SDLoc(N), N->getValueType(0));
+
+ Results.push_back(Res);
+}
+
void NVPTXTargetLowering::ReplaceNodeResults(
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
switch (N->getOpcode()) {
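
A hedged example of the type-legalization path above, assuming i16 is NVPTX's register type for i8:

    //   (i8 (ProxyReg ch, %v))
    // is replaced by
    //   (i8 (trunc (i16 (ProxyReg ch, (i16 (any_extend %v))))))
    // so the ProxyReg node itself always carries a legal register type.
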
@@ -6349,6 +6264,9 @@ void NVPTXTargetLowering::ReplaceNodeResults(
case ISD::CopyFromReg:
ReplaceCopyFromReg_128(N, DAG, Results);
return;
+ case NVPTXISD::ProxyReg:
+ replaceProxyReg(N, DAG, *this, Results);
+ return;
}
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 228e2aa..cf72a1e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -38,7 +38,7 @@ enum NodeType : unsigned {
/// This node represents a PTX call instruction. It's operands are as follows:
///
/// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
- /// NumParams, Callee, Proto, InGlue)
+ /// NumParams, Callee, Proto)
CALL,
MoveParam,
@@ -84,13 +84,7 @@ enum NodeType : unsigned {
StoreV2,
StoreV4,
StoreV8,
- LoadParam,
- LoadParamV2,
- LoadParamV4,
- StoreParam,
- StoreParamV2,
- StoreParamV4,
- LAST_MEMORY_OPCODE = StoreParamV4,
+ LAST_MEMORY_OPCODE = StoreV8,
};
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td b/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td
index 86dcb4a..719be03 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td
@@ -11,15 +11,9 @@
//
//===----------------------------------------------------------------------===//
-// Vector instruction type enum
-class VecInstTypeEnum<bits<4> val> {
- bits<4> Value=val;
-}
-def VecNOP : VecInstTypeEnum<0>;
-
// Generic NVPTX Format
-class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern>
+class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern = []>
: Instruction {
field bits<14> Inst;
@@ -30,7 +24,6 @@ class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern>
let Pattern = pattern;
// TSFlagFields
- bits<4> VecInstType = VecNOP.Value;
bit IsLoad = false;
bit IsStore = false;
@@ -45,7 +38,6 @@ class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern>
// 2**(2-1) = 2.
bits<2> IsSuld = 0;
- let TSFlags{3...0} = VecInstType;
let TSFlags{4} = IsLoad;
let TSFlags{5} = IsStore;
let TSFlags{6} = IsTex;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
index e218ef1..34fe467 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -35,23 +35,23 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
const TargetRegisterClass *DestRC = MRI.getRegClass(DestReg);
const TargetRegisterClass *SrcRC = MRI.getRegClass(SrcReg);
- if (RegInfo.getRegSizeInBits(*DestRC) != RegInfo.getRegSizeInBits(*SrcRC))
+ if (DestRC != SrcRC)
report_fatal_error("Copy one register into another with a different width");
unsigned Op;
- if (DestRC == &NVPTX::B1RegClass) {
- Op = NVPTX::IMOV1r;
- } else if (DestRC == &NVPTX::B16RegClass) {
- Op = NVPTX::MOV16r;
- } else if (DestRC == &NVPTX::B32RegClass) {
- Op = NVPTX::IMOV32r;
- } else if (DestRC == &NVPTX::B64RegClass) {
- Op = NVPTX::IMOV64r;
- } else if (DestRC == &NVPTX::B128RegClass) {
- Op = NVPTX::IMOV128r;
- } else {
+ if (DestRC == &NVPTX::B1RegClass)
+ Op = NVPTX::MOV_B1_r;
+ else if (DestRC == &NVPTX::B16RegClass)
+ Op = NVPTX::MOV_B16_r;
+ else if (DestRC == &NVPTX::B32RegClass)
+ Op = NVPTX::MOV_B32_r;
+ else if (DestRC == &NVPTX::B64RegClass)
+ Op = NVPTX::MOV_B64_r;
+ else if (DestRC == &NVPTX::B128RegClass)
+ Op = NVPTX::MOV_B128_r;
+ else
llvm_unreachable("Bad register copy");
- }
+
BuildMI(MBB, I, DL, get(Op), DestReg)
.addReg(SrcReg, getKillRegState(KillSrc));
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 442b900..d8047d3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -15,19 +15,8 @@ include "NVPTXInstrFormats.td"
let OperandType = "OPERAND_IMMEDIATE" in {
def f16imm : Operand<f16>;
def bf16imm : Operand<bf16>;
-
}
-// List of vector specific properties
-def isVecLD : VecInstTypeEnum<1>;
-def isVecST : VecInstTypeEnum<2>;
-def isVecBuild : VecInstTypeEnum<3>;
-def isVecShuffle : VecInstTypeEnum<4>;
-def isVecExtract : VecInstTypeEnum<5>;
-def isVecInsert : VecInstTypeEnum<6>;
-def isVecDest : VecInstTypeEnum<7>;
-def isVecOther : VecInstTypeEnum<15>;
-
//===----------------------------------------------------------------------===//
// NVPTX Operand Definitions.
//===----------------------------------------------------------------------===//
@@ -125,8 +114,6 @@ def doF32FTZ : Predicate<"useF32FTZ()">;
def doNoF32FTZ : Predicate<"!useF32FTZ()">;
def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
-def doMulWide : Predicate<"doMulWide">;
-
def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
@@ -486,46 +473,28 @@ let hasSideEffects = false in {
// takes a CvtMode immediate that defines the conversion mode to use. It can
// be CvtNONE to omit a conversion mode.
multiclass CVT_FROM_ALL<string ToType, RegisterClass RC, list<Predicate> Preds = []> {
- def _s8 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B16:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s8">,
- Requires<Preds>;
- def _u8 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B16:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u8">,
- Requires<Preds>;
- def _s16 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B16:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s16">,
- Requires<Preds>;
- def _u16 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B16:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u16">,
- Requires<Preds>;
- def _s32 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B32:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s32">,
- Requires<Preds>;
- def _u32 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B32:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u32">,
- Requires<Preds>;
- def _s64 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B64:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s64">,
- Requires<Preds>;
- def _u64 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B64:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u64">,
- Requires<Preds>;
+ foreach sign = ["s", "u"] in {
+ def _ # sign # "8" :
+ BasicFlagsNVPTXInst<(outs RC:$dst),
+ (ins B16:$src), (ins CvtMode:$mode),
+ "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "8">,
+ Requires<Preds>;
+ def _ # sign # "16" :
+ BasicFlagsNVPTXInst<(outs RC:$dst),
+ (ins B16:$src), (ins CvtMode:$mode),
+ "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "16">,
+ Requires<Preds>;
+ def _ # sign # "32" :
+ BasicFlagsNVPTXInst<(outs RC:$dst),
+ (ins B32:$src), (ins CvtMode:$mode),
+ "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "32">,
+ Requires<Preds>;
+ def _ # sign # "64" :
+ BasicFlagsNVPTXInst<(outs RC:$dst),
+ (ins B64:$src), (ins CvtMode:$mode),
+ "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "64">,
+ Requires<Preds>;
+ }
def _f16 :
BasicFlagsNVPTXInst<(outs RC:$dst),
(ins B16:$src), (ins CvtMode:$mode),
@@ -556,14 +525,12 @@ let hasSideEffects = false in {
}
// Generate cvts from all types to all types.
- defm CVT_s8 : CVT_FROM_ALL<"s8", B16>;
- defm CVT_u8 : CVT_FROM_ALL<"u8", B16>;
- defm CVT_s16 : CVT_FROM_ALL<"s16", B16>;
- defm CVT_u16 : CVT_FROM_ALL<"u16", B16>;
- defm CVT_s32 : CVT_FROM_ALL<"s32", B32>;
- defm CVT_u32 : CVT_FROM_ALL<"u32", B32>;
- defm CVT_s64 : CVT_FROM_ALL<"s64", B64>;
- defm CVT_u64 : CVT_FROM_ALL<"u64", B64>;
+ foreach sign = ["s", "u"] in {
+ defm CVT_ # sign # "8" : CVT_FROM_ALL<sign # "8", B16>;
+ defm CVT_ # sign # "16" : CVT_FROM_ALL<sign # "16", B16>;
+ defm CVT_ # sign # "32" : CVT_FROM_ALL<sign # "32", B32>;
+ defm CVT_ # sign # "64" : CVT_FROM_ALL<sign # "64", B64>;
+ }
defm CVT_f16 : CVT_FROM_ALL<"f16", B16>;
defm CVT_bf16 : CVT_FROM_ALL<"bf16", B16, [hasPTX<78>, hasSM<90>]>;
defm CVT_f32 : CVT_FROM_ALL<"f32", B32>;
@@ -571,18 +538,12 @@ let hasSideEffects = false in {
// These cvts are different from those above: The source and dest registers
// are of the same type.
- def CVT_INREG_s16_s8 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src),
- "cvt.s16.s8">;
- def CVT_INREG_s32_s8 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src),
- "cvt.s32.s8">;
- def CVT_INREG_s32_s16 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src),
- "cvt.s32.s16">;
- def CVT_INREG_s64_s8 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src),
- "cvt.s64.s8">;
- def CVT_INREG_s64_s16 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src),
- "cvt.s64.s16">;
- def CVT_INREG_s64_s32 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src),
- "cvt.s64.s32">;
+ def CVT_INREG_s16_s8 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src), "cvt.s16.s8">;
+ def CVT_INREG_s32_s8 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "cvt.s32.s8">;
+ def CVT_INREG_s32_s16 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "cvt.s32.s16">;
+ def CVT_INREG_s64_s8 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s8">;
+ def CVT_INREG_s64_s16 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s16">;
+ def CVT_INREG_s64_s32 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s32">;
multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> {
def _f32 :
@@ -784,7 +745,7 @@ defm SUB : I3<"sub.s", sub, commutative = false>;
def ADD16x2 : I16x2<"add.s", add>;
-// in32 and int64 addition and subtraction with carry-out.
+// int32 and int64 addition and subtraction with carry-out.
defm ADDCC : ADD_SUB_INT_CARRY<"add.cc", addc, commutative = true>;
defm SUBCC : ADD_SUB_INT_CARRY<"sub.cc", subc, commutative = false>;
@@ -805,17 +766,17 @@ defm UDIV : I3<"div.u", udiv, commutative = false>;
defm SREM : I3<"rem.s", srem, commutative = false>;
defm UREM : I3<"rem.u", urem, commutative = false>;
-// Integer absolute value. NumBits should be one minus the bit width of RC.
-// This idiom implements the algorithm at
-// http://graphics.stanford.edu/~seander/bithacks.html#IntegerAbs.
-multiclass ABS<ValueType T, RegisterClass RC, string SizeName> {
- def : BasicNVPTXInst<(outs RC:$dst), (ins RC:$a),
- "abs" # SizeName,
- [(set T:$dst, (abs T:$a))]>;
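+// Integer absolute value and negation.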
+foreach t = [I16RT, I32RT, I64RT] in {
+ def ABS_S # t.Size :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a),
+ "abs.s" # t.Size,
+ [(set t.Ty:$dst, (abs t.Ty:$a))]>;
+
+ def NEG_S # t.Size :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
+ "neg.s" # t.Size,
+ [(set t.Ty:$dst, (ineg t.Ty:$src))]>;
}
-defm ABS_16 : ABS<i16, B16, ".s16">;
-defm ABS_32 : ABS<i32, B32, ".s32">;
-defm ABS_64 : ABS<i64, B64, ".s64">;
// Integer min/max.
defm SMAX : I3<"max.s", smax, commutative = true>;
@@ -832,170 +793,63 @@ def UMIN16x2 : I16x2<"min.u", umin>;
//
// Wide multiplication
//
-def MULWIDES64 :
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.s32">;
-def MULWIDES64Imm :
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.s32">;
-def MULWIDES64Imm64 :
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i64imm:$b), "mul.wide.s32">;
-
-def MULWIDEU64 :
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.u32">;
-def MULWIDEU64Imm :
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.u32">;
-def MULWIDEU64Imm64 :
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i64imm:$b), "mul.wide.u32">;
-
-def MULWIDES32 :
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.s16">;
-def MULWIDES32Imm :
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.s16">;
-def MULWIDES32Imm32 :
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.s16">;
-
-def MULWIDEU32 :
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.u16">;
-def MULWIDEU32Imm :
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.u16">;
-def MULWIDEU32Imm32 :
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.u16">;
-
-def SDTMulWide : SDTypeProfile<1, 2, [SDTCisSameAs<1, 2>]>;
-def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>;
-def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>;
-
-// Matchers for signed, unsigned mul.wide ISD nodes.
-let Predicates = [doMulWide] in {
- def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>;
- def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>;
- def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>;
- def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)), (MULWIDEU32Imm $a, imm:$b)>;
-
- def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)), (MULWIDES64 $a, $b)>;
- def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)), (MULWIDES64Imm $a, imm:$b)>;
- def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)), (MULWIDEU64 $a, $b)>;
- def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>;
-}
-
-// Predicates used for converting some patterns to mul.wide.
-def SInt32Const : PatLeaf<(imm), [{
- const APInt &v = N->getAPIntValue();
- return v.isSignedIntN(32);
-}]>;
-
-def UInt32Const : PatLeaf<(imm), [{
- const APInt &v = N->getAPIntValue();
- return v.isIntN(32);
-}]>;
-def SInt16Const : PatLeaf<(imm), [{
- const APInt &v = N->getAPIntValue();
- return v.isSignedIntN(16);
-}]>;
-
-def UInt16Const : PatLeaf<(imm), [{
- const APInt &v = N->getAPIntValue();
- return v.isIntN(16);
-}]>;
-
-def IntConst_0_30 : PatLeaf<(imm), [{
- // Check if 0 <= v < 31; only then will the result of (x << v) be an int32.
- const APInt &v = N->getAPIntValue();
- return v.sge(0) && v.slt(31);
-}]>;
-
-def IntConst_0_14 : PatLeaf<(imm), [{
- // Check if 0 <= v < 15; only then will the result of (x << v) be an int16.
- const APInt &v = N->getAPIntValue();
- return v.sge(0) && v.slt(15);
-}]>;
-
-def SHL2MUL32 : SDNodeXForm<imm, [{
- const APInt &v = N->getAPIntValue();
- APInt temp(32, 1);
- return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i32);
-}]>;
+def SDTMulWide : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisInt<1>, SDTCisSameAs<1, 2>]>;
+def smul_wide : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide, [SDNPCommutative]>;
+def umul_wide : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide, [SDNPCommutative]>;
-def SHL2MUL16 : SDNodeXForm<imm, [{
- const APInt &v = N->getAPIntValue();
- APInt temp(16, 1);
- return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i16);
-}]>;
-
-// Convert "sign/zero-extend, then shift left by an immediate" to mul.wide.
-let Predicates = [doMulWide] in {
- def : Pat<(shl (sext i32:$a), (i32 IntConst_0_30:$b)),
- (MULWIDES64Imm $a, (SHL2MUL32 $b))>;
- def : Pat<(shl (zext i32:$a), (i32 IntConst_0_30:$b)),
- (MULWIDEU64Imm $a, (SHL2MUL32 $b))>;
-
- def : Pat<(shl (sext i16:$a), (i16 IntConst_0_14:$b)),
- (MULWIDES32Imm $a, (SHL2MUL16 $b))>;
- def : Pat<(shl (zext i16:$a), (i16 IntConst_0_14:$b)),
- (MULWIDEU32Imm $a, (SHL2MUL16 $b))>;
-
- // Convert "sign/zero-extend then multiply" to mul.wide.
- def : Pat<(mul (sext i32:$a), (sext i32:$b)),
- (MULWIDES64 $a, $b)>;
- def : Pat<(mul (sext i32:$a), (i64 SInt32Const:$b)),
- (MULWIDES64Imm64 $a, (i64 SInt32Const:$b))>;
-
- def : Pat<(mul (zext i32:$a), (zext i32:$b)),
- (MULWIDEU64 $a, $b)>;
- def : Pat<(mul (zext i32:$a), (i64 UInt32Const:$b)),
- (MULWIDEU64Imm64 $a, (i64 UInt32Const:$b))>;
- def : Pat<(mul (sext i16:$a), (sext i16:$b)),
- (MULWIDES32 $a, $b)>;
- def : Pat<(mul (sext i16:$a), (i32 SInt16Const:$b)),
- (MULWIDES32Imm32 $a, (i32 SInt16Const:$b))>;
-
- def : Pat<(mul (zext i16:$a), (zext i16:$b)),
- (MULWIDEU32 $a, $b)>;
- def : Pat<(mul (zext i16:$a), (i32 UInt16Const:$b)),
- (MULWIDEU32Imm32 $a, (i32 UInt16Const:$b))>;
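+// mul.wide: multiply two small_t values, producing a big_t (double-width)
+// result, in register/register and register/immediate forms.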
+multiclass MULWIDEInst<string suffix, SDPatternOperator op, RegTyInfo big_t, RegTyInfo small_t> {
+ def suffix # _rr :
+ BasicNVPTXInst<(outs big_t.RC:$dst), (ins small_t.RC:$a, small_t.RC:$b),
+ "mul.wide." # suffix,
+ [(set big_t.Ty:$dst, (op small_t.Ty:$a, small_t.Ty:$b))]>;
+ def suffix # _ri :
+ BasicNVPTXInst<(outs big_t.RC:$dst), (ins small_t.RC:$a, small_t.Imm:$b),
+ "mul.wide." # suffix,
+ [(set big_t.Ty:$dst, (op small_t.Ty:$a, imm:$b))]>;
}
+defm MUL_WIDE : MULWIDEInst<"s32", smul_wide, I64RT, I32RT>;
+defm MUL_WIDE : MULWIDEInst<"u32", umul_wide, I64RT, I32RT>;
+defm MUL_WIDE : MULWIDEInst<"s16", smul_wide, I32RT, I16RT>;
+defm MUL_WIDE : MULWIDEInst<"u16", umul_wide, I32RT, I16RT>;
+
//
// Integer multiply-add
//
-def mul_oneuse : OneUse2<mul>;
-
-multiclass MAD<string Ptx, ValueType VT, NVPTXRegClass Reg, Operand Imm> {
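+// mad: $dst = op($a, $b) + $c, where op multiplies two small_t values into a
+// big_t result. The multiply must have no other uses; otherwise it would have
+// to be emitted separately anyway and folding it into the mad saves nothing.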
+multiclass MADInst<string suffix, SDPatternOperator op, RegTyInfo big_t, RegTyInfo small_t> {
def rrr:
- BasicNVPTXInst<(outs Reg:$dst),
- (ins Reg:$a, Reg:$b, Reg:$c),
- Ptx,
- [(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), VT:$c))]>;
-
- def rir:
- BasicNVPTXInst<(outs Reg:$dst),
- (ins Reg:$a, Imm:$b, Reg:$c),
- Ptx,
- [(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), VT:$c))]>;
+ BasicNVPTXInst<(outs big_t.RC:$dst),
+ (ins small_t.RC:$a, small_t.RC:$b, big_t.RC:$c),
+ "mad." # suffix,
+ [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, small_t.Ty:$b), big_t.Ty:$c))]>;
def rri:
- BasicNVPTXInst<(outs Reg:$dst),
- (ins Reg:$a, Reg:$b, Imm:$c),
- Ptx,
- [(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), imm:$c))]>;
+ BasicNVPTXInst<(outs big_t.RC:$dst),
+ (ins small_t.RC:$a, small_t.RC:$b, big_t.Imm:$c),
+ "mad." # suffix,
+ [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, small_t.Ty:$b), imm:$c))]>;
+ def rir:
+ BasicNVPTXInst<(outs big_t.RC:$dst),
+ (ins small_t.RC:$a, small_t.Imm:$b, big_t.RC:$c),
+ "mad." # suffix,
+ [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, imm:$b), big_t.Ty:$c))]>;
def rii:
- BasicNVPTXInst<(outs Reg:$dst),
- (ins Reg:$a, Imm:$b, Imm:$c),
- Ptx,
- [(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), imm:$c))]>;
+ BasicNVPTXInst<(outs big_t.RC:$dst),
+ (ins small_t.RC:$a, small_t.Imm:$b, big_t.Imm:$c),
+ "mad." # suffix,
+ [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, imm:$b), imm:$c))]>;
}
let Predicates = [hasOptEnabled] in {
-defm MAD16 : MAD<"mad.lo.s16", i16, B16, i16imm>;
-defm MAD32 : MAD<"mad.lo.s32", i32, B32, i32imm>;
-defm MAD64 : MAD<"mad.lo.s64", i64, B64, i64imm>;
-}
+ defm MAD_LO_S16 : MADInst<"lo.s16", mul, I16RT, I16RT>;
+ defm MAD_LO_S32 : MADInst<"lo.s32", mul, I32RT, I32RT>;
+ defm MAD_LO_S64 : MADInst<"lo.s64", mul, I64RT, I64RT>;
-foreach t = [I16RT, I32RT, I64RT] in {
- def NEG_S # t.Size :
- BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
- "neg.s" # t.Size,
- [(set t.Ty:$dst, (ineg t.Ty:$src))]>;
+ defm MAD_WIDE_U16 : MADInst<"wide.u16", umul_wide, I32RT, I16RT>;
+ defm MAD_WIDE_S16 : MADInst<"wide.s16", smul_wide, I32RT, I16RT>;
+ defm MAD_WIDE_U32 : MADInst<"wide.u32", umul_wide, I64RT, I32RT>;
+ defm MAD_WIDE_S32 : MADInst<"wide.s32", smul_wide, I64RT, I32RT>;
}
//-----------------------------------
@@ -1106,8 +960,7 @@ def fdiv_approx : PatFrag<(ops node:$a, node:$b),
def FRCP32_approx_r :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$b), (ins FTZFlag:$ftz),
"rcp.approx$ftz.f32",
[(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>;
@@ -1116,14 +969,12 @@ def FRCP32_approx_r :
//
def FDIV32_approx_rr :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, B32:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, B32:$b), (ins FTZFlag:$ftz),
"div.approx$ftz.f32",
[(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>;
def FDIV32_approx_ri :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, f32imm:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz),
"div.approx$ftz.f32",
[(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>;
//
@@ -1146,14 +997,12 @@ def : Pat<(fdiv_full f32imm_1, f32:$b),
//
def FDIV32rr :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, B32:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, B32:$b), (ins FTZFlag:$ftz),
"div.full$ftz.f32",
[(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>;
def FDIV32ri :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, f32imm:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz),
"div.full$ftz.f32",
[(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>;
//
@@ -1167,8 +1016,7 @@ def fdiv_ftz : PatFrag<(ops node:$a, node:$b),
def FRCP32r_prec :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$b), (ins FTZFlag:$ftz),
"rcp.rn$ftz.f32",
[(set f32:$dst, (fdiv_ftz f32imm_1, f32:$b))]>;
//
@@ -1176,14 +1024,12 @@ def FRCP32r_prec :
//
def FDIV32rr_prec :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, B32:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, B32:$b), (ins FTZFlag:$ftz),
"div.rn$ftz.f32",
[(set f32:$dst, (fdiv_ftz f32:$a, f32:$b))]>;
def FDIV32ri_prec :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, f32imm:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz),
"div.rn$ftz.f32",
[(set f32:$dst, (fdiv_ftz f32:$a, fpimm:$b))]>;
@@ -1262,10 +1108,8 @@ def TANH_APPROX_f32 :
// Template for three-arg bitwise operations. Takes three args and creates .b16,
// .b32, .b64, and .pred (predicate registers -- i.e., i1) versions of OpcStr.
multiclass BITWISE<string OpcStr, SDNode OpNode> {
- defm b1 : I3Inst<OpcStr # ".pred", OpNode, I1RT, commutative = true>;
- defm b16 : I3Inst<OpcStr # ".b16", OpNode, I16RT, commutative = true>;
- defm b32 : I3Inst<OpcStr # ".b32", OpNode, I32RT, commutative = true>;
- defm b64 : I3Inst<OpcStr # ".b64", OpNode, I64RT, commutative = true>;
+ foreach t = [I1RT, I16RT, I32RT, I64RT] in
+ defm _ # t.PtxType : I3Inst<OpcStr # "." # t.PtxType, OpNode, t, commutative = true>;
}
defm OR : BITWISE<"or", or>;
@@ -1273,48 +1117,40 @@ defm AND : BITWISE<"and", and>;
defm XOR : BITWISE<"xor", xor>;
// PTX does not support mul on predicates; convert to and instructions.
-def : Pat<(mul i1:$a, i1:$b), (ANDb1rr $a, $b)>;
-def : Pat<(mul i1:$a, imm:$b), (ANDb1ri $a, imm:$b)>;
+def : Pat<(mul i1:$a, i1:$b), (AND_predrr $a, $b)>;
+def : Pat<(mul i1:$a, imm:$b), (AND_predri $a, imm:$b)>;
foreach op = [add, sub] in {
- def : Pat<(op i1:$a, i1:$b), (XORb1rr $a, $b)>;
- def : Pat<(op i1:$a, imm:$b), (XORb1ri $a, imm:$b)>;
+ def : Pat<(op i1:$a, i1:$b), (XOR_predrr $a, $b)>;
+ def : Pat<(op i1:$a, imm:$b), (XOR_predri $a, imm:$b)>;
}
// These transformations were once reliably performed by instcombine, but thanks
// to poison semantics they are no longer safe for LLVM IR; perform them here
// instead.
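// For example, (select i1 %a, i1 %b, false) is not equivalent to (and i1 %a, %b)
// in IR: if %a is false and %b is poison, the select yields false while the and
// yields poison. No such hazard exists for PTX predicate registers.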
-def : Pat<(select i1:$a, i1:$b, 0), (ANDb1rr $a, $b)>;
-def : Pat<(select i1:$a, 1, i1:$b), (ORb1rr $a, $b)>;
+def : Pat<(select i1:$a, i1:$b, 0), (AND_predrr $a, $b)>;
+def : Pat<(select i1:$a, 1, i1:$b), (OR_predrr $a, $b)>;
// Lower logical v2i16/v4i8 ops as bitwise ops on b32.
foreach vt = [v2i16, v4i8] in {
- def : Pat<(or vt:$a, vt:$b), (ORb32rr $a, $b)>;
- def : Pat<(xor vt:$a, vt:$b), (XORb32rr $a, $b)>;
- def : Pat<(and vt:$a, vt:$b), (ANDb32rr $a, $b)>;
+ def : Pat<(or vt:$a, vt:$b), (OR_b32rr $a, $b)>;
+ def : Pat<(xor vt:$a, vt:$b), (XOR_b32rr $a, $b)>;
+ def : Pat<(and vt:$a, vt:$b), (AND_b32rr $a, $b)>;
// The constants get legalized into a bitcast from i32, so that's what we need
// to match here.
def: Pat<(or vt:$a, (vt (bitconvert (i32 imm:$b)))),
- (ORb32ri $a, imm:$b)>;
+ (OR_b32ri $a, imm:$b)>;
def: Pat<(xor vt:$a, (vt (bitconvert (i32 imm:$b)))),
- (XORb32ri $a, imm:$b)>;
+ (XOR_b32ri $a, imm:$b)>;
def: Pat<(and vt:$a, (vt (bitconvert (i32 imm:$b)))),
- (ANDb32ri $a, imm:$b)>;
-}
-
-def NOT1 : BasicNVPTXInst<(outs B1:$dst), (ins B1:$src),
- "not.pred",
- [(set i1:$dst, (not i1:$src))]>;
-def NOT16 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src),
- "not.b16",
- [(set i16:$dst, (not i16:$src))]>;
-def NOT32 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src),
- "not.b32",
- [(set i32:$dst, (not i32:$src))]>;
-def NOT64 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src),
- "not.b64",
- [(set i64:$dst, (not i64:$src))]>;
+ (AND_b32ri $a, imm:$b)>;
+}
+
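+// Bitwise not for predicates and all integer widths.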
+foreach t = [I1RT, I16RT, I32RT, I64RT] in
+ def NOT_ # t.PtxType : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
+ "not." # t.PtxType,
+ [(set t.Ty:$dst, (not t.Ty:$src))]>;
// Template for left/right shifts. Takes three operands,
// [dest (reg), src (reg), shift (reg or imm)].
@@ -1322,34 +1158,22 @@ def NOT64 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src),
//
// This template also defines (imm, imm) shift instructions for each size.
multiclass SHIFT<string OpcStr, SDNode OpNode> {
- def i64rr :
- BasicNVPTXInst<(outs B64:$dst), (ins B64:$a, B32:$b),
- OpcStr # "64",
- [(set i64:$dst, (OpNode i64:$a, i32:$b))]>;
- def i64ri :
- BasicNVPTXInst<(outs B64:$dst), (ins B64:$a, i32imm:$b),
- OpcStr # "64",
- [(set i64:$dst, (OpNode i64:$a, (i32 imm:$b)))]>;
- def i32rr :
- BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
- OpcStr # "32",
- [(set i32:$dst, (OpNode i32:$a, i32:$b))]>;
- def i32ri :
- BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, i32imm:$b),
- OpcStr # "32",
- [(set i32:$dst, (OpNode i32:$a, (i32 imm:$b)))]>;
- def i32ii :
- BasicNVPTXInst<(outs B32:$dst), (ins i32imm:$a, i32imm:$b),
- OpcStr # "32",
- [(set i32:$dst, (OpNode (i32 imm:$a), (i32 imm:$b)))]>;
- def i16rr :
- BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B32:$b),
- OpcStr # "16",
- [(set i16:$dst, (OpNode i16:$a, i32:$b))]>;
- def i16ri :
- BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, i32imm:$b),
- OpcStr # "16",
- [(set i16:$dst, (OpNode i16:$a, (i32 imm:$b)))]>;
+ let hasSideEffects = false in {
+ foreach t = [I64RT, I32RT, I16RT] in {
+ def t.Size # _rr :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, B32:$b),
+ OpcStr # t.Size,
+ [(set t.Ty:$dst, (OpNode t.Ty:$a, i32:$b))]>;
+ def t.Size # _ri :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, i32imm:$b),
+ OpcStr # t.Size,
+ [(set t.Ty:$dst, (OpNode t.Ty:$a, (i32 imm:$b)))]>;
+ def t.Size # _ii :
+        BasicNVPTXInst<(outs t.RC:$dst), (ins t.Imm:$a, i32imm:$b),
+ OpcStr # t.Size,
+ [(set t.Ty:$dst, (OpNode (t.Ty imm:$a), (i32 imm:$b)))]>;
+ }
+ }
}
defm SHL : SHIFT<"shl.b", shl>;
@@ -1357,14 +1181,11 @@ defm SRA : SHIFT<"shr.s", sra>;
defm SRL : SHIFT<"shr.u", srl>;
// Bit-reverse
-def BREV32 :
- BasicNVPTXInst<(outs B32:$dst), (ins B32:$a),
- "brev.b32",
- [(set i32:$dst, (bitreverse i32:$a))]>;
-def BREV64 :
- BasicNVPTXInst<(outs B64:$dst), (ins B64:$a),
- "brev.b64",
- [(set i64:$dst, (bitreverse i64:$a))]>;
+foreach t = [I64RT, I32RT] in
+ def BREV_ # t.PtxType :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a),
+ "brev." # t.PtxType,
+ [(set t.Ty:$dst, (bitreverse t.Ty:$a))]>;
//
@@ -1516,20 +1337,19 @@ def : Pat<(i16 (sext_inreg (trunc (prmt i32:$s, 0, byte_extract_prmt:$sel, PrmtN
// Byte extraction via shift/trunc/sext
-def : Pat<(i16 (sext_inreg (trunc i32:$s), i8)),
- (CVT_s8_s32 $s, CvtNONE)>;
-def : Pat<(i16 (sext_inreg (trunc (srl i32:$s, (i32 imm:$o))), i8)),
+def : Pat<(i16 (sext_inreg (trunc i32:$s), i8)), (CVT_s8_s32 $s, CvtNONE)>;
+def : Pat<(i16 (sext_inreg (trunc i64:$s), i8)), (CVT_s8_s64 $s, CvtNONE)>;
+
+def : Pat<(sext_inreg (srl i32:$s, (i32 imm:$o)), i8), (BFE_S32rii $s, imm:$o, 8)>;
+def : Pat<(sext_inreg (srl i64:$s, (i32 imm:$o)), i8), (BFE_S64rii $s, imm:$o, 8)>;
+
+def : Pat<(i16 (sext_inreg (trunc (srl i32:$s, (i32 imm:$o))), i8)),
(CVT_s8_s32 (BFE_S32rii $s, imm:$o, 8), CvtNONE)>;
-def : Pat<(sext_inreg (srl i32:$s, (i32 imm:$o)), i8),
- (BFE_S32rii $s, imm:$o, 8)>;
+def : Pat<(i16 (sext_inreg (trunc (srl i64:$s, (i32 imm:$o))), i8)),
+ (CVT_s8_s64 (BFE_S64rii $s, imm:$o, 8), CvtNONE)>;
+
def : Pat<(i16 (sra (i16 (trunc i32:$s)), (i32 8))),
(CVT_s8_s32 (BFE_S32rii $s, 8, 8), CvtNONE)>;
-def : Pat<(sext_inreg (srl i64:$s, (i32 imm:$o)), i8),
- (BFE_S64rii $s, imm:$o, 8)>;
-def : Pat<(i16 (sext_inreg (trunc i64:$s), i8)),
- (CVT_s8_s64 $s, CvtNONE)>;
-def : Pat<(i16 (sext_inreg (trunc (srl i64:$s, (i32 imm:$o))), i8)),
- (CVT_s8_s64 (BFE_S64rii $s, imm:$o, 8), CvtNONE)>;
//-----------------------------------
// Comparison instructions (setp, set)
@@ -1619,10 +1439,7 @@ def SETP_bf16x2rr :
def addr : ComplexPattern<pAny, 2, "SelectADDR">;
-def ADDR_base : Operand<pAny> {
- let PrintMethod = "printOperand";
-}
-
+def ADDR_base : Operand<pAny>;
def ADDR : Operand<pAny> {
let PrintMethod = "printMemOperand";
let MIOperandInfo = (ops ADDR_base, i32imm);
@@ -1636,10 +1453,6 @@ def MmaCode : Operand<i32> {
let PrintMethod = "printMmaCode";
}
-def Offseti32imm : Operand<i32> {
- let PrintMethod = "printOffseti32imm";
-}
-
// Get pointer to local stack.
let hasSideEffects = false in {
def MOV_DEPOT_ADDR : NVPTXInst<(outs B32:$d), (ins i32imm:$num),
@@ -1651,33 +1464,31 @@ let hasSideEffects = false in {
// copyPhysreg is hard-coded in NVPTXInstrInfo.cpp
let hasSideEffects = false, isAsCheapAsAMove = true in {
- // Class for register-to-register moves
- class MOVr<RegisterClass RC, string OpStr> :
- BasicNVPTXInst<(outs RC:$dst), (ins RC:$src),
- "mov." # OpStr>;
-
- // Class for immediate-to-register moves
- class MOVi<RegisterClass RC, string OpStr, ValueType VT, Operand IMMType, SDNode ImmNode> :
- BasicNVPTXInst<(outs RC:$dst), (ins IMMType:$src),
- "mov." # OpStr,
- [(set VT:$dst, ImmNode:$src)]>;
-}
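+  // Register-to-register moves.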
+ let isMoveReg = true in
+ class MOVr<RegisterClass RC, string OpStr> :
+ BasicNVPTXInst<(outs RC:$dst), (ins RC:$src), "mov." # OpStr>;
-def IMOV1r : MOVr<B1, "pred">;
-def MOV16r : MOVr<B16, "b16">;
-def IMOV32r : MOVr<B32, "b32">;
-def IMOV64r : MOVr<B64, "b64">;
-def IMOV128r : MOVr<B128, "b128">;
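+  // Immediate-to-register moves.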
+ let isMoveImm = true in
+ class MOVi<RegTyInfo t, string suffix> :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.Imm:$src),
+ "mov." # suffix,
+ [(set t.Ty:$dst, t.ImmNode:$src)]>;
+}
+def MOV_B1_r : MOVr<B1, "pred">;
+def MOV_B16_r : MOVr<B16, "b16">;
+def MOV_B32_r : MOVr<B32, "b32">;
+def MOV_B64_r : MOVr<B64, "b64">;
+def MOV_B128_r : MOVr<B128, "b128">;
-def IMOV1i : MOVi<B1, "pred", i1, i1imm, imm>;
-def IMOV16i : MOVi<B16, "b16", i16, i16imm, imm>;
-def IMOV32i : MOVi<B32, "b32", i32, i32imm, imm>;
-def IMOV64i : MOVi<B64, "b64", i64, i64imm, imm>;
-def FMOV16i : MOVi<B16, "b16", f16, f16imm, fpimm>;
-def BFMOV16i : MOVi<B16, "b16", bf16, bf16imm, fpimm>;
-def FMOV32i : MOVi<B32, "b32", f32, f32imm, fpimm>;
-def FMOV64i : MOVi<B64, "b64", f64, f64imm, fpimm>;
+def MOV_B1_i : MOVi<I1RT, "pred">;
+def MOV_B16_i : MOVi<I16RT, "b16">;
+def MOV_B32_i : MOVi<I32RT, "b32">;
+def MOV_B64_i : MOVi<I64RT, "b64">;
+def MOV_F16_i : MOVi<F16RT, "b16">;
+def MOV_BF16_i : MOVi<BF16RT, "b16">;
+def MOV_F32_i : MOVi<F32RT, "b32">;
+def MOV_F64_i : MOVi<F64RT, "b64">;
def to_tglobaladdr : SDNodeXForm<globaladdr, [{
@@ -1695,11 +1506,11 @@ def to_tframeindex : SDNodeXForm<frameindex, [{
return CurDAG->getTargetFrameIndex(N->getIndex(), N->getValueType(0));
}]>;
-def : Pat<(i32 globaladdr:$dst), (IMOV32i (to_tglobaladdr $dst))>;
-def : Pat<(i64 globaladdr:$dst), (IMOV64i (to_tglobaladdr $dst))>;
+def : Pat<(i32 globaladdr:$dst), (MOV_B32_i (to_tglobaladdr $dst))>;
+def : Pat<(i64 globaladdr:$dst), (MOV_B64_i (to_tglobaladdr $dst))>;
-def : Pat<(i32 externalsym:$dst), (IMOV32i (to_texternsym $dst))>;
-def : Pat<(i64 externalsym:$dst), (IMOV64i (to_texternsym $dst))>;
+def : Pat<(i32 externalsym:$dst), (MOV_B32_i (to_texternsym $dst))>;
+def : Pat<(i64 externalsym:$dst), (MOV_B64_i (to_texternsym $dst))>;
//---- Copy Frame Index ----
def LEA_ADDRi : NVPTXInst<(outs B32:$dst), (ins ADDR:$addr),
@@ -1713,56 +1524,39 @@ def : Pat<(i64 frameindex:$fi), (LEA_ADDRi64 (to_tframeindex $fi), 0)>;
//-----------------------------------
// Comparison and Selection
//-----------------------------------
+// TODO: These patterns seem very specific and brittle. We should try to find
+// a more general solution.
def cond_signed : PatLeaf<(cond), [{
return isSignedIntSetCC(N->get());
}]>;
-def cond_not_signed : PatLeaf<(cond), [{
- return !isSignedIntSetCC(N->get());
-}]>;
+// A 16-bit signed comparison of sign-extended byte extracts can be converted
+// to a 32-bit comparison if we change the PRMT to sign-extend the extracted
+// bytes.
+def : Pat<(setcc (i16 (sext_inreg (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE)), i8)),
+ (i16 (sext_inreg (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE)), i8)),
+ cond_signed:$cc),
+ (SETP_i32rr (PRMT_B32rii i32:$a, 0, (to_sign_extend_selector $sel_a), PrmtNONE),
+ (PRMT_B32rii i32:$b, 0, (to_sign_extend_selector $sel_b), PrmtNONE),
+ (cond2cc $cc))>;
+
+// A 16-bit comparison of truncated byte extracts can be converted to a 32-bit
+// comparison because we know that the truncate is just truncating off zeros
+// and that the most-significant byte is also zero, so the meaning of signed and
+// unsigned comparisons will not be changed.
+def : Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))),
+ (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))),
+ cond:$cc),
+ (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
+ (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
+ (cond2cc $cc))>;
-// comparisons of i8 extracted with PRMT as i32
-// It's faster to do comparison directly on i32 extracted by PRMT,
-// instead of the long conversion and sign extending.
-def: Pat<(setcc (i16 (sext_inreg (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))), i8)),
- (i16 (sext_inreg (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))), i8)),
- cond_signed:$cc),
- (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
- (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
- (cond2cc $cc))>;
-
-def: Pat<(setcc (i16 (sext_inreg (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE)), i8)),
- (i16 (sext_inreg (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE)), i8)),
- cond_signed:$cc),
- (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
- (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
- (cond2cc $cc))>;
-
-def: Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))),
- (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))),
- cond_signed:$cc),
- (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
- (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
- (cond2cc $cc))>;
-
-def: Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))),
- (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))),
- cond_not_signed:$cc),
- (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
- (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
- (cond2cc $cc))>;
def SDTDeclareArrayParam :
SDTypeProfile<0, 3, [SDTCisVT<0, i32>, SDTCisVT<1, i32>, SDTCisVT<2, i32>]>;
def SDTDeclareScalarParam :
SDTypeProfile<0, 2, [SDTCisVT<0, i32>, SDTCisVT<1, i32>]>;
-def SDTLoadParamProfile : SDTypeProfile<1, 2, [SDTCisInt<1>, SDTCisInt<2>]>;
-def SDTLoadParamV2Profile : SDTypeProfile<2, 2, [SDTCisSameAs<0, 1>, SDTCisInt<2>, SDTCisInt<3>]>;
-def SDTLoadParamV4Profile : SDTypeProfile<4, 2, [SDTCisInt<4>, SDTCisInt<5>]>;
-def SDTStoreParamProfile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>;
-def SDTStoreParamV2Profile : SDTypeProfile<0, 4, [SDTCisInt<0>, SDTCisInt<1>]>;
-def SDTStoreParamV4Profile : SDTypeProfile<0, 6, [SDTCisInt<0>, SDTCisInt<1>]>;
def SDTMoveParamProfile : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisSameAs<0, 1>]>;
def SDTProxyReg : SDTypeProfile<1, 1, [SDTCisSameAs<0, 1>]>;
@@ -1774,104 +1568,20 @@ def declare_array_param :
def declare_scalar_param :
SDNode<"NVPTXISD::DeclareScalarParam", SDTDeclareScalarParam,
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
-
-def LoadParam :
- SDNode<"NVPTXISD::LoadParam", SDTLoadParamProfile,
- [SDNPHasChain, SDNPMayLoad, SDNPOutGlue, SDNPInGlue]>;
-def LoadParamV2 :
- SDNode<"NVPTXISD::LoadParamV2", SDTLoadParamV2Profile,
- [SDNPHasChain, SDNPMayLoad, SDNPOutGlue, SDNPInGlue]>;
-def LoadParamV4 :
- SDNode<"NVPTXISD::LoadParamV4", SDTLoadParamV4Profile,
- [SDNPHasChain, SDNPMayLoad, SDNPOutGlue, SDNPInGlue]>;
-def StoreParam :
- SDNode<"NVPTXISD::StoreParam", SDTStoreParamProfile,
- [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
-def StoreParamV2 :
- SDNode<"NVPTXISD::StoreParamV2", SDTStoreParamV2Profile,
- [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
-def StoreParamV4 :
- SDNode<"NVPTXISD::StoreParamV4", SDTStoreParamV4Profile,
- [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
def MoveParam :
SDNode<"NVPTXISD::MoveParam", SDTMoveParamProfile, []>;
def proxy_reg :
SDNode<"NVPTXISD::ProxyReg", SDTProxyReg, [SDNPHasChain]>;
/// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
- /// NumParams, Callee, Proto, InGlue)
+ /// NumParams, Callee, Proto)
def SDTCallProfile : SDTypeProfile<0, 6,
[SDTCisVT<0, i32>, SDTCisVT<1, i32>, SDTCisVT<2, i32>,
SDTCisVT<3, i32>, SDTCisVT<5, i32>]>;
-def call :
- SDNode<"NVPTXISD::CALL", SDTCallProfile,
- [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
-
-let mayLoad = true in {
- class LoadParamMemInst<NVPTXRegClass regclass, string opstr> :
- NVPTXInst<(outs regclass:$dst), (ins Offseti32imm:$b),
- !strconcat("ld.param", opstr, " \t$dst, [retval0$b];"),
- []>;
-
- class LoadParamV2MemInst<NVPTXRegClass regclass, string opstr> :
- NVPTXInst<(outs regclass:$dst, regclass:$dst2), (ins Offseti32imm:$b),
- !strconcat("ld.param.v2", opstr,
- " \t{{$dst, $dst2}}, [retval0$b];"), []>;
-
- class LoadParamV4MemInst<NVPTXRegClass regclass, string opstr> :
- NVPTXInst<(outs regclass:$dst, regclass:$dst2, regclass:$dst3,
- regclass:$dst4),
- (ins Offseti32imm:$b),
- !strconcat("ld.param.v4", opstr,
- " \t{{$dst, $dst2, $dst3, $dst4}}, [retval0$b];"),
- []>;
-}
-
-let mayStore = true in {
-
- multiclass StoreParamInst<NVPTXRegClass regclass, Operand IMMType, string opstr, bit support_imm = true> {
- foreach op = [IMMType, regclass] in
- if !or(support_imm, !isa<NVPTXRegClass>(op)) then
- def _ # !if(!isa<NVPTXRegClass>(op), "r", "i")
- : NVPTXInst<(outs),
- (ins op:$val, i32imm:$a, Offseti32imm:$b),
- "st.param" # opstr # " \t[param$a$b], $val;",
- []>;
- }
-
- multiclass StoreParamV2Inst<NVPTXRegClass regclass, Operand IMMType, string opstr> {
- foreach op1 = [IMMType, regclass] in
- foreach op2 = [IMMType, regclass] in
- def _ # !if(!isa<NVPTXRegClass>(op1), "r", "i")
- # !if(!isa<NVPTXRegClass>(op2), "r", "i")
- : NVPTXInst<(outs),
- (ins op1:$val1, op2:$val2,
- i32imm:$a, Offseti32imm:$b),
- "st.param.v2" # opstr # " \t[param$a$b], {{$val1, $val2}};",
- []>;
- }
-
- multiclass StoreParamV4Inst<NVPTXRegClass regclass, Operand IMMType, string opstr> {
- foreach op1 = [IMMType, regclass] in
- foreach op2 = [IMMType, regclass] in
- foreach op3 = [IMMType, regclass] in
- foreach op4 = [IMMType, regclass] in
- def _ # !if(!isa<NVPTXRegClass>(op1), "r", "i")
- # !if(!isa<NVPTXRegClass>(op2), "r", "i")
- # !if(!isa<NVPTXRegClass>(op3), "r", "i")
- # !if(!isa<NVPTXRegClass>(op4), "r", "i")
-
- : NVPTXInst<(outs),
- (ins op1:$val1, op2:$val2, op3:$val3, op4:$val4,
- i32imm:$a, Offseti32imm:$b),
- "st.param.v4" # opstr #
- " \t[param$a$b], {{$val1, $val2, $val3, $val4}};",
- []>;
- }
-}
+def call : SDNode<"NVPTXISD::CALL", SDTCallProfile, [SDNPHasChain, SDNPSideEffect]>;
/// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
-/// NumParams, Callee, Proto, InGlue)
+/// NumParams, Callee, Proto)
def CallOperand : Operand<i32> { let PrintMethod = "printCallOperand"; }
@@ -1908,43 +1618,6 @@ foreach is_convergent = [0, 1] in {
(call_uni_inst $addr, imm:$rets, imm:$params)>;
}
-def LoadParamMemI64 : LoadParamMemInst<B64, ".b64">;
-def LoadParamMemI32 : LoadParamMemInst<B32, ".b32">;
-def LoadParamMemI16 : LoadParamMemInst<B16, ".b16">;
-def LoadParamMemI8 : LoadParamMemInst<B16, ".b8">;
-def LoadParamMemV2I64 : LoadParamV2MemInst<B64, ".b64">;
-def LoadParamMemV2I32 : LoadParamV2MemInst<B32, ".b32">;
-def LoadParamMemV2I16 : LoadParamV2MemInst<B16, ".b16">;
-def LoadParamMemV2I8 : LoadParamV2MemInst<B16, ".b8">;
-def LoadParamMemV4I32 : LoadParamV4MemInst<B32, ".b32">;
-def LoadParamMemV4I16 : LoadParamV4MemInst<B16, ".b16">;
-def LoadParamMemV4I8 : LoadParamV4MemInst<B16, ".b8">;
-
-defm StoreParamI64 : StoreParamInst<B64, i64imm, ".b64">;
-defm StoreParamI32 : StoreParamInst<B32, i32imm, ".b32">;
-defm StoreParamI16 : StoreParamInst<B16, i16imm, ".b16">;
-defm StoreParamI8 : StoreParamInst<B16, i8imm, ".b8">;
-
-defm StoreParamI8TruncI32 : StoreParamInst<B32, i8imm, ".b8", /* support_imm */ false>;
-defm StoreParamI8TruncI64 : StoreParamInst<B64, i8imm, ".b8", /* support_imm */ false>;
-
-defm StoreParamV2I64 : StoreParamV2Inst<B64, i64imm, ".b64">;
-defm StoreParamV2I32 : StoreParamV2Inst<B32, i32imm, ".b32">;
-defm StoreParamV2I16 : StoreParamV2Inst<B16, i16imm, ".b16">;
-defm StoreParamV2I8 : StoreParamV2Inst<B16, i8imm, ".b8">;
-
-defm StoreParamV4I32 : StoreParamV4Inst<B32, i32imm, ".b32">;
-defm StoreParamV4I16 : StoreParamV4Inst<B16, i16imm, ".b16">;
-defm StoreParamV4I8 : StoreParamV4Inst<B16, i8imm, ".b8">;
-
-defm StoreParamF32 : StoreParamInst<B32, f32imm, ".b32">;
-defm StoreParamF64 : StoreParamInst<B64, f64imm, ".b64">;
-
-defm StoreParamV2F32 : StoreParamV2Inst<B32, f32imm, ".b32">;
-defm StoreParamV2F64 : StoreParamV2Inst<B64, f64imm, ".b64">;
-
-defm StoreParamV4F32 : StoreParamV4Inst<B32, f32imm, ".b32">;
-
def DECLARE_PARAM_array :
NVPTXInst<(outs), (ins i32imm:$a, i32imm:$align, i32imm:$size),
".param .align $align .b8 \t$a[$size];", []>;
@@ -1957,6 +1630,18 @@ def : Pat<(declare_array_param externalsym:$a, imm:$align, imm:$size),
def : Pat<(declare_scalar_param externalsym:$a, imm:$size),
(DECLARE_PARAM_scalar (to_texternsym $a), imm:$size)>;
+// Call prototype wrapper: a dummy instruction that just prints its operand,
+// which is a string defining the call prototype.
+def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
+def CallPrototype :
+ SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype,
+ [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def ProtoIdent : Operand<i32> { let PrintMethod = "printProtoIdent"; }
+def CALL_PROTOTYPE :
+ NVPTXInst<(outs), (ins ProtoIdent:$ident),
+ "$ident", [(CallPrototype (i32 texternalsym:$ident))]>;
+
+
foreach t = [I32RT, I64RT] in {
defvar inst_name = "MOV" # t.Size # "_PARAM";
def inst_name : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), "mov.b" # t.Size>;
@@ -1976,6 +1661,32 @@ defm ProxyRegB16 : ProxyRegInst<"b16", B16>;
defm ProxyRegB32 : ProxyRegInst<"b32", B32>;
defm ProxyRegB64 : ProxyRegInst<"b64", B64>;
+
+// Callseq start and end
+
+// Note: these nodes are marked as SDNPMayStore and SDNPMayLoad because
+// they define the scope in which the declared params may be used. Therefore
+// we add these flags to ensure ld.param and st.param are not sunk or hoisted
+// out of that scope.
+
+def callseq_start : SDNode<"ISD::CALLSEQ_START",
+ SDCallSeqStart<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>,
+ [SDNPHasChain, SDNPOutGlue,
+ SDNPSideEffect, SDNPMayStore, SDNPMayLoad]>;
+def callseq_end : SDNode<"ISD::CALLSEQ_END",
+ SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>,
+ [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
+ SDNPSideEffect, SDNPMayStore, SDNPMayLoad]>;
+
+def Callseq_Start :
+ NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
+ "\\{ // callseq $amt1, $amt2",
+ [(callseq_start timm:$amt1, timm:$amt2)]>;
+def Callseq_End :
+ NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
+ "\\} // callseq $amt1",
+ [(callseq_end timm:$amt1, timm:$amt2)]>;
+
//
// Load / Store Handling
//
@@ -1988,7 +1699,6 @@ class LD<NVPTXRegClass regclass>
"\t$dst, [$addr];", []>;
let mayLoad=1, hasSideEffects=0 in {
- def LD_i8 : LD<B16>;
def LD_i16 : LD<B16>;
def LD_i32 : LD<B32>;
def LD_i64 : LD<B64>;
@@ -2004,7 +1714,6 @@ class ST<DAGOperand O>
" \t[$addr], $src;", []>;
let mayStore=1, hasSideEffects=0 in {
- def ST_i8 : ST<RI16>;
def ST_i16 : ST<RI16>;
def ST_i32 : ST<RI32>;
def ST_i64 : ST<RI64>;
@@ -2037,7 +1746,6 @@ multiclass LD_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
"[$addr];", []>;
}
let mayLoad=1, hasSideEffects=0 in {
- defm LDV_i8 : LD_VEC<B16>;
defm LDV_i16 : LD_VEC<B16>;
defm LDV_i32 : LD_VEC<B32, support_v8 = true>;
defm LDV_i64 : LD_VEC<B64>;
@@ -2071,7 +1779,6 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
}
let mayStore=1, hasSideEffects=0 in {
- defm STV_i8 : ST_VEC<RI16>;
defm STV_i16 : ST_VEC<RI16>;
defm STV_i32 : ST_VEC<RI32, support_v8 = true>;
defm STV_i64 : ST_VEC<RI64>;
@@ -2241,14 +1948,14 @@ def : Pat<(i64 (anyext i32:$a)), (CVT_u64_u32 $a, CvtNONE)>;
// truncate i64
def : Pat<(i32 (trunc i64:$a)), (CVT_u32_u64 $a, CvtNONE)>;
def : Pat<(i16 (trunc i64:$a)), (CVT_u16_u64 $a, CvtNONE)>;
-def : Pat<(i1 (trunc i64:$a)), (SETP_i64ri (ANDb64ri $a, 1), 0, CmpNE)>;
+def : Pat<(i1 (trunc i64:$a)), (SETP_i64ri (AND_b64ri $a, 1), 0, CmpNE)>;
// truncate i32
def : Pat<(i16 (trunc i32:$a)), (CVT_u16_u32 $a, CvtNONE)>;
-def : Pat<(i1 (trunc i32:$a)), (SETP_i32ri (ANDb32ri $a, 1), 0, CmpNE)>;
+def : Pat<(i1 (trunc i32:$a)), (SETP_i32ri (AND_b32ri $a, 1), 0, CmpNE)>;
// truncate i16
-def : Pat<(i1 (trunc i16:$a)), (SETP_i16ri (ANDb16ri $a, 1), 0, CmpNE)>;
+def : Pat<(i1 (trunc i16:$a)), (SETP_i16ri (AND_b16ri $a, 1), 0, CmpNE)>;
// sext_inreg
def : Pat<(sext_inreg i16:$a, i8), (CVT_INREG_s16_s8 $a)>;
@@ -2492,52 +2199,20 @@ defm : CVT_ROUND<frint, CvtRNI, CvtRNI_FTZ>;
//-----------------------------------
let isTerminator=1 in {
- let isReturn=1, isBarrier=1 in
+ let isReturn=1, isBarrier=1 in
def Return : BasicNVPTXInst<(outs), (ins), "ret", [(retglue)]>;
- let isBranch=1 in
- def CBranch : NVPTXInst<(outs), (ins B1:$a, brtarget:$target),
+ let isBranch=1 in {
+ def CBranch : NVPTXInst<(outs), (ins B1:$a, brtarget:$target),
"@$a bra \t$target;",
[(brcond i1:$a, bb:$target)]>;
- let isBranch=1 in
- def CBranchOther : NVPTXInst<(outs), (ins B1:$a, brtarget:$target),
- "@!$a bra \t$target;", []>;
- let isBranch=1, isBarrier=1 in
+ let isBarrier=1 in
def GOTO : BasicNVPTXInst<(outs), (ins brtarget:$target),
- "bra.uni", [(br bb:$target)]>;
+ "bra.uni", [(br bb:$target)]>;
+ }
}
-def : Pat<(brcond i32:$a, bb:$target),
- (CBranch (SETP_i32ri $a, 0, CmpNE), bb:$target)>;
-
-// SelectionDAGBuilder::visitSWitchCase() will invert the condition of a
-// conditional branch if the target block is the next block so that the code
-// can fall through to the target block. The inversion is done by 'xor
-// condition, 1', which will be translated to (setne condition, -1). Since ptx
-// supports '@!pred bra target', we should use it.
-def : Pat<(brcond (i1 (setne i1:$a, -1)), bb:$target),
- (CBranchOther $a, bb:$target)>;
-
-// Call
-def SDT_NVPTXCallSeqStart : SDCallSeqStart<[SDTCisVT<0, i32>,
- SDTCisVT<1, i32>]>;
-def SDT_NVPTXCallSeqEnd : SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>;
-
-def callseq_start : SDNode<"ISD::CALLSEQ_START", SDT_NVPTXCallSeqStart,
- [SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>;
-def callseq_end : SDNode<"ISD::CALLSEQ_END", SDT_NVPTXCallSeqEnd,
- [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
- SDNPSideEffect]>;
-
-def Callseq_Start :
- NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
- "\\{ // callseq $amt1, $amt2",
- [(callseq_start timm:$amt1, timm:$amt2)]>;
-def Callseq_End :
- NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
- "\\} // callseq $amt1",
- [(callseq_end timm:$amt1, timm:$amt2)]>;
// trap instruction
def trapinst : BasicNVPTXInst<(outs), (ins), "trap", [(trap)]>, Requires<[noPTXASUnreachableBug]>;
@@ -2547,18 +2222,6 @@ def trapexitinst : NVPTXInst<(outs), (ins), "trap; exit;", [(trap)]>, Requires<[
// brkpt instruction
def debugtrapinst : BasicNVPTXInst<(outs), (ins), "brkpt", [(debugtrap)]>;
-// Call prototype wrapper
-def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
-def CallPrototype :
- SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype,
- [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
-def ProtoIdent : Operand<i32> {
- let PrintMethod = "printProtoIdent";
-}
-def CALL_PROTOTYPE :
- NVPTXInst<(outs), (ins ProtoIdent:$ident),
- "$ident", [(CallPrototype (i32 texternalsym:$ident))]>;
-
def SDTDynAllocaOp :
SDTypeProfile<1, 2, [SDTCisSameAs<0, 1>, SDTCisInt<1>, SDTCisVT<2, i32>]>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 0a00220..d337192 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -243,63 +243,82 @@ foreach sync = [false, true] in {
}
// vote.{all,any,uni,ballot}
-multiclass VOTE<NVPTXRegClass regclass, string mode, Intrinsic IntOp> {
- def : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred),
- "vote." # mode,
- [(set regclass:$dest, (IntOp i1:$pred))]>,
- Requires<[hasPTX<60>, hasSM<30>]>;
-}
+let Predicates = [hasPTX<60>, hasSM<30>] in {
+ multiclass VOTE<string mode, RegTyInfo t, Intrinsic op> {
+ def : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred),
+ "vote." # mode # "." # t.PtxType,
+ [(set t.Ty:$dest, (op i1:$pred))]>;
+ }
-defm VOTE_ALL : VOTE<B1, "all.pred", int_nvvm_vote_all>;
-defm VOTE_ANY : VOTE<B1, "any.pred", int_nvvm_vote_any>;
-defm VOTE_UNI : VOTE<B1, "uni.pred", int_nvvm_vote_uni>;
-defm VOTE_BALLOT : VOTE<B32, "ballot.b32", int_nvvm_vote_ballot>;
+ defm VOTE_ALL : VOTE<"all", I1RT, int_nvvm_vote_all>;
+ defm VOTE_ANY : VOTE<"any", I1RT, int_nvvm_vote_any>;
+ defm VOTE_UNI : VOTE<"uni", I1RT, int_nvvm_vote_uni>;
+ defm VOTE_BALLOT : VOTE<"ballot", I32RT, int_nvvm_vote_ballot>;
+
+ // vote.sync.{all,any,uni,ballot}
+ multiclass VOTE_SYNC<string mode, RegTyInfo t, Intrinsic op> {
+ def i : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred, i32imm:$mask),
+ "vote.sync." # mode # "." # t.PtxType,
+ [(set t.Ty:$dest, (op imm:$mask, i1:$pred))]>;
+ def r : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred, B32:$mask),
+ "vote.sync." # mode # "." # t.PtxType,
+ [(set t.Ty:$dest, (op i32:$mask, i1:$pred))]>;
+ }
-// vote.sync.{all,any,uni,ballot}
-multiclass VOTE_SYNC<NVPTXRegClass regclass, string mode, Intrinsic IntOp> {
- def i : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred, i32imm:$mask),
- "vote.sync." # mode,
- [(set regclass:$dest, (IntOp imm:$mask, i1:$pred))]>,
- Requires<[hasPTX<60>, hasSM<30>]>;
- def r : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred, B32:$mask),
- "vote.sync." # mode,
- [(set regclass:$dest, (IntOp i32:$mask, i1:$pred))]>,
- Requires<[hasPTX<60>, hasSM<30>]>;
+ defm VOTE_SYNC_ALL : VOTE_SYNC<"all", I1RT, int_nvvm_vote_all_sync>;
+ defm VOTE_SYNC_ANY : VOTE_SYNC<"any", I1RT, int_nvvm_vote_any_sync>;
+ defm VOTE_SYNC_UNI : VOTE_SYNC<"uni", I1RT, int_nvvm_vote_uni_sync>;
+ defm VOTE_SYNC_BALLOT : VOTE_SYNC<"ballot", I32RT, int_nvvm_vote_ballot_sync>;
}
-
-defm VOTE_SYNC_ALL : VOTE_SYNC<B1, "all.pred", int_nvvm_vote_all_sync>;
-defm VOTE_SYNC_ANY : VOTE_SYNC<B1, "any.pred", int_nvvm_vote_any_sync>;
-defm VOTE_SYNC_UNI : VOTE_SYNC<B1, "uni.pred", int_nvvm_vote_uni_sync>;
-defm VOTE_SYNC_BALLOT : VOTE_SYNC<B32, "ballot.b32", int_nvvm_vote_ballot_sync>;
-
// elect.sync
+let Predicates = [hasPTX<80>, hasSM<90>] in {
def INT_ELECT_SYNC_I : BasicNVPTXInst<(outs B32:$dest, B1:$pred), (ins i32imm:$mask),
"elect.sync",
- [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync imm:$mask))]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync imm:$mask))]>;
def INT_ELECT_SYNC_R : BasicNVPTXInst<(outs B32:$dest, B1:$pred), (ins B32:$mask),
"elect.sync",
- [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync i32:$mask))]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync i32:$mask))]>;
+}
+
+let Predicates = [hasPTX<60>, hasSM<70>] in {
+ multiclass MATCH_ANY_SYNC<Intrinsic op, RegTyInfo t> {
+ def ii : BasicNVPTXInst<(outs B32:$dest), (ins t.Imm:$value, i32imm:$mask),
+ "match.any.sync." # t.PtxType,
+ [(set i32:$dest, (op imm:$mask, imm:$value))]>;
+ def ir : BasicNVPTXInst<(outs B32:$dest), (ins t.Imm:$value, B32:$mask),
+ "match.any.sync." # t.PtxType,
+ [(set i32:$dest, (op i32:$mask, imm:$value))]>;
+ def ri : BasicNVPTXInst<(outs B32:$dest), (ins t.RC:$value, i32imm:$mask),
+ "match.any.sync." # t.PtxType,
+ [(set i32:$dest, (op imm:$mask, t.Ty:$value))]>;
+ def rr : BasicNVPTXInst<(outs B32:$dest), (ins t.RC:$value, B32:$mask),
+ "match.any.sync." # t.PtxType,
+ [(set i32:$dest, (op i32:$mask, t.Ty:$value))]>;
+ }
-multiclass MATCH_ANY_SYNC<NVPTXRegClass regclass, string ptxtype, Intrinsic IntOp,
- Operand ImmOp> {
- def ii : BasicNVPTXInst<(outs B32:$dest), (ins ImmOp:$value, i32imm:$mask),
- "match.any.sync." # ptxtype,
- [(set i32:$dest, (IntOp imm:$mask, imm:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def ir : BasicNVPTXInst<(outs B32:$dest), (ins ImmOp:$value, B32:$mask),
- "match.any.sync." # ptxtype,
- [(set i32:$dest, (IntOp i32:$mask, imm:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def ri : BasicNVPTXInst<(outs B32:$dest), (ins regclass:$value, i32imm:$mask),
- "match.any.sync." # ptxtype,
- [(set i32:$dest, (IntOp imm:$mask, regclass:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def rr : BasicNVPTXInst<(outs B32:$dest), (ins regclass:$value, B32:$mask),
- "match.any.sync." # ptxtype,
- [(set i32:$dest, (IntOp i32:$mask, regclass:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
+ defm MATCH_ANY_SYNC_32 : MATCH_ANY_SYNC<int_nvvm_match_any_sync_i32, I32RT>;
+ defm MATCH_ANY_SYNC_64 : MATCH_ANY_SYNC<int_nvvm_match_any_sync_i64, I64RT>;
+
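+  // match.all.sync, additionally setting a predicate that is true when all
+  // participating threads provided the same value.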
+ multiclass MATCH_ALLP_SYNC<RegTyInfo t, Intrinsic op> {
+ def ii : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
+ (ins t.Imm:$value, i32imm:$mask),
+ "match.all.sync." # t.PtxType,
+ [(set i32:$dest, i1:$pred, (op imm:$mask, imm:$value))]>;
+ def ir : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
+ (ins t.Imm:$value, B32:$mask),
+ "match.all.sync." # t.PtxType,
+ [(set i32:$dest, i1:$pred, (op i32:$mask, imm:$value))]>;
+ def ri : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
+ (ins t.RC:$value, i32imm:$mask),
+ "match.all.sync." # t.PtxType,
+ [(set i32:$dest, i1:$pred, (op imm:$mask, t.Ty:$value))]>;
+ def rr : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
+ (ins t.RC:$value, B32:$mask),
+ "match.all.sync." # t.PtxType,
+ [(set i32:$dest, i1:$pred, (op i32:$mask, t.Ty:$value))]>;
+ }
+ defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC<I32RT, int_nvvm_match_all_sync_i32p>;
+ defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC<I64RT, int_nvvm_match_all_sync_i64p>;
}
// activemask.b32
@@ -308,39 +327,6 @@ def ACTIVEMASK : BasicNVPTXInst<(outs B32:$dest), (ins),
[(set i32:$dest, (int_nvvm_activemask))]>,
Requires<[hasPTX<62>, hasSM<30>]>;
-defm MATCH_ANY_SYNC_32 : MATCH_ANY_SYNC<B32, "b32", int_nvvm_match_any_sync_i32,
- i32imm>;
-defm MATCH_ANY_SYNC_64 : MATCH_ANY_SYNC<B64, "b64", int_nvvm_match_any_sync_i64,
- i64imm>;
-
-multiclass MATCH_ALLP_SYNC<NVPTXRegClass regclass, string ptxtype, Intrinsic IntOp,
- Operand ImmOp> {
- def ii : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
- (ins ImmOp:$value, i32imm:$mask),
- "match.all.sync." # ptxtype,
- [(set i32:$dest, i1:$pred, (IntOp imm:$mask, imm:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def ir : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
- (ins ImmOp:$value, B32:$mask),
- "match.all.sync." # ptxtype,
- [(set i32:$dest, i1:$pred, (IntOp i32:$mask, imm:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def ri : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
- (ins regclass:$value, i32imm:$mask),
- "match.all.sync." # ptxtype,
- [(set i32:$dest, i1:$pred, (IntOp imm:$mask, regclass:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def rr : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
- (ins regclass:$value, B32:$mask),
- "match.all.sync." # ptxtype,
- [(set i32:$dest, i1:$pred, (IntOp i32:$mask, regclass:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
-}
-defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC<B32, "b32", int_nvvm_match_all_sync_i32p,
- i32imm>;
-defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC<B64, "b64", int_nvvm_match_all_sync_i64p,
- i64imm>;
-
multiclass REDUX_SYNC<string BinOp, string PTXType, Intrinsic Intrin> {
def : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src, B32:$mask),
"redux.sync." # BinOp # "." # PTXType,
@@ -381,24 +367,20 @@ defm REDUX_SYNC_FMAX_ABS_NAN: REDUX_SYNC_F<"max", ".abs", ".NaN">;
//-----------------------------------
// Explicit Memory Fence Functions
//-----------------------------------
-class MEMBAR<string StrOp, Intrinsic IntOP> :
- BasicNVPTXInst<(outs), (ins),
- StrOp, [(IntOP)]>;
+class NullaryInst<string StrOp, Intrinsic IntOP> :
+ BasicNVPTXInst<(outs), (ins), StrOp, [(IntOP)]>;
-def INT_MEMBAR_CTA : MEMBAR<"membar.cta", int_nvvm_membar_cta>;
-def INT_MEMBAR_GL : MEMBAR<"membar.gl", int_nvvm_membar_gl>;
-def INT_MEMBAR_SYS : MEMBAR<"membar.sys", int_nvvm_membar_sys>;
+def INT_MEMBAR_CTA : NullaryInst<"membar.cta", int_nvvm_membar_cta>;
+def INT_MEMBAR_GL : NullaryInst<"membar.gl", int_nvvm_membar_gl>;
+def INT_MEMBAR_SYS : NullaryInst<"membar.sys", int_nvvm_membar_sys>;
def INT_FENCE_SC_CLUSTER:
- MEMBAR<"fence.sc.cluster", int_nvvm_fence_sc_cluster>,
+ NullaryInst<"fence.sc.cluster", int_nvvm_fence_sc_cluster>,
Requires<[hasPTX<78>, hasSM<90>]>;
// Proxy fence (uni-directional)
-// fence.proxy.tensormap.release variants
-
class FENCE_PROXY_TENSORMAP_GENERIC_RELEASE<string Scope, Intrinsic Intr> :
- BasicNVPTXInst<(outs), (ins),
- "fence.proxy.tensormap::generic.release." # Scope, [(Intr)]>,
+ NullaryInst<"fence.proxy.tensormap::generic.release." # Scope, Intr>,
Requires<[hasPTX<83>, hasSM<90>]>;
def INT_FENCE_PROXY_TENSORMAP_GENERIC_RELEASE_CTA:
@@ -488,35 +470,31 @@ defm CP_ASYNC_CG_SHARED_GLOBAL_16 :
CP_ASYNC_SHARED_GLOBAL_I<"cg", "16", int_nvvm_cp_async_cg_shared_global_16,
int_nvvm_cp_async_cg_shared_global_16_s>;
-def CP_ASYNC_COMMIT_GROUP :
- BasicNVPTXInst<(outs), (ins), "cp.async.commit_group", [(int_nvvm_cp_async_commit_group)]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
+let Predicates = [hasPTX<70>, hasSM<80>] in {
+ def CP_ASYNC_COMMIT_GROUP :
+ NullaryInst<"cp.async.commit_group", int_nvvm_cp_async_commit_group>;
-def CP_ASYNC_WAIT_GROUP :
- BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.wait_group",
- [(int_nvvm_cp_async_wait_group timm:$n)]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
+ def CP_ASYNC_WAIT_GROUP :
+ BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.wait_group",
+ [(int_nvvm_cp_async_wait_group timm:$n)]>;
-def CP_ASYNC_WAIT_ALL :
- BasicNVPTXInst<(outs), (ins), "cp.async.wait_all",
- [(int_nvvm_cp_async_wait_all)]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
+ def CP_ASYNC_WAIT_ALL :
+ NullaryInst<"cp.async.wait_all", int_nvvm_cp_async_wait_all>;
+}
-// cp.async.bulk variants of the commit/wait group
-def CP_ASYNC_BULK_COMMIT_GROUP :
- BasicNVPTXInst<(outs), (ins), "cp.async.bulk.commit_group",
- [(int_nvvm_cp_async_bulk_commit_group)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+let Predicates = [hasPTX<80>, hasSM<90>] in {
+ // cp.async.bulk variants of the commit/wait group
+ def CP_ASYNC_BULK_COMMIT_GROUP :
+ NullaryInst<"cp.async.bulk.commit_group", int_nvvm_cp_async_bulk_commit_group>;
-def CP_ASYNC_BULK_WAIT_GROUP :
- BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group",
- [(int_nvvm_cp_async_bulk_wait_group timm:$n)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ def CP_ASYNC_BULK_WAIT_GROUP :
+ BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group",
+ [(int_nvvm_cp_async_bulk_wait_group timm:$n)]>;
-def CP_ASYNC_BULK_WAIT_GROUP_READ :
- BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group.read",
- [(int_nvvm_cp_async_bulk_wait_group_read timm:$n)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ def CP_ASYNC_BULK_WAIT_GROUP_READ :
+ BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group.read",
+ [(int_nvvm_cp_async_bulk_wait_group_read timm:$n)]>;
+}
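
A minimal CUDA sketch (illustrative only, assuming clang's cp.async builtins; the bulk variants follow the same commit/wait pattern):

__device__ void drain_cp_async() {
  __nvvm_cp_async_commit_group(); // CP_ASYNC_COMMIT_GROUP
  __nvvm_cp_async_wait_group(0);  // CP_ASYNC_WAIT_GROUP, immediate operand n=0
  __nvvm_cp_async_wait_all();     // CP_ASYNC_WAIT_ALL
}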
//------------------------------
// TMA Async Bulk Copy Functions
@@ -974,33 +952,30 @@ defm TMA_TENSOR_PF_TILE_GATHER4_2D : TMA_TENSOR_PREFETCH_INTR<5, "tile_gather4",
// Prefetch and Prefetchu
-class PREFETCH_INTRS<string InstName> :
- BasicNVPTXInst<(outs), (ins ADDR:$addr),
- InstName,
- [(!cast<Intrinsic>(!strconcat("int_nvvm_",
- !subst(".", "_", InstName))) addr:$addr)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
-
+let Predicates = [hasPTX<80>, hasSM<90>] in {
+ class PREFETCH_INTRS<string InstName> :
+ BasicNVPTXInst<(outs), (ins ADDR:$addr),
+ InstName,
+ [(!cast<Intrinsic>(!strconcat("int_nvvm_",
+ !subst(".", "_", InstName))) addr:$addr)]>;

-def PREFETCH_L1 : PREFETCH_INTRS<"prefetch.L1">;
-def PREFETCH_L2 : PREFETCH_INTRS<"prefetch.L2">;
-def PREFETCH_GLOBAL_L1 : PREFETCH_INTRS<"prefetch.global.L1">;
-def PREFETCH_LOCAL_L1 : PREFETCH_INTRS<"prefetch.local.L1">;
-def PREFETCH_GLOBAL_L2 : PREFETCH_INTRS<"prefetch.global.L2">;
-def PREFETCH_LOCAL_L2 : PREFETCH_INTRS<"prefetch.local.L2">;
+ def PREFETCH_L1 : PREFETCH_INTRS<"prefetch.L1">;
+ def PREFETCH_L2 : PREFETCH_INTRS<"prefetch.L2">;
+ def PREFETCH_GLOBAL_L1 : PREFETCH_INTRS<"prefetch.global.L1">;
+ def PREFETCH_LOCAL_L1 : PREFETCH_INTRS<"prefetch.local.L1">;
+ def PREFETCH_GLOBAL_L2 : PREFETCH_INTRS<"prefetch.global.L2">;
+ def PREFETCH_LOCAL_L2 : PREFETCH_INTRS<"prefetch.local.L2">;

-def PREFETCH_GLOBAL_L2_EVICT_NORMAL : BasicNVPTXInst<(outs), (ins ADDR:$addr),
- "prefetch.global.L2::evict_normal",
- [(int_nvvm_prefetch_global_L2_evict_normal addr:$addr)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ def PREFETCH_GLOBAL_L2_EVICT_NORMAL : BasicNVPTXInst<(outs), (ins ADDR:$addr),
+ "prefetch.global.L2::evict_normal",
+ [(int_nvvm_prefetch_global_L2_evict_normal addr:$addr)]>;

-def PREFETCH_GLOBAL_L2_EVICT_LAST : BasicNVPTXInst<(outs), (ins ADDR:$addr),
- "prefetch.global.L2::evict_last",
- [(int_nvvm_prefetch_global_L2_evict_last addr:$addr)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ def PREFETCH_GLOBAL_L2_EVICT_LAST : BasicNVPTXInst<(outs), (ins ADDR:$addr),
+ "prefetch.global.L2::evict_last",
+ [(int_nvvm_prefetch_global_L2_evict_last addr:$addr)]>;

-
-def PREFETCHU_L1 : PREFETCH_INTRS<"prefetchu.L1">;
+ def PREFETCHU_L1 : PREFETCH_INTRS<"prefetchu.L1">;
+}
// Applypriority intrinsics
class APPLYPRIORITY_L2_INTRS<string addrspace> :
@@ -1031,99 +1006,82 @@ def DISCARD_GLOBAL_L2 : DISCARD_L2_INTRS<"global">;
// MBarrier Functions
//-----------------------------------
-multiclass MBARRIER_INIT<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$count),
- "mbarrier.init" # AddrSpace # ".b64",
- [(Intrin addr:$addr, i32:$count)]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_INIT : MBARRIER_INIT<"", int_nvvm_mbarrier_init>;
-defm MBARRIER_INIT_SHARED : MBARRIER_INIT<".shared",
- int_nvvm_mbarrier_init_shared>;
-
-multiclass MBARRIER_INVAL<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs), (ins ADDR:$addr),
- "mbarrier.inval" # AddrSpace # ".b64",
- [(Intrin addr:$addr)]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_INVAL : MBARRIER_INVAL<"", int_nvvm_mbarrier_inval>;
-defm MBARRIER_INVAL_SHARED : MBARRIER_INVAL<".shared",
- int_nvvm_mbarrier_inval_shared>;
-
-multiclass MBARRIER_ARRIVE<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr),
- "mbarrier.arrive" # AddrSpace # ".b64",
- [(set i64:$state, (Intrin addr:$addr))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_ARRIVE : MBARRIER_ARRIVE<"", int_nvvm_mbarrier_arrive>;
-defm MBARRIER_ARRIVE_SHARED :
- MBARRIER_ARRIVE<".shared", int_nvvm_mbarrier_arrive_shared>;
-
-multiclass MBARRIER_ARRIVE_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs B64:$state),
- (ins ADDR:$addr, B32:$count),
- "mbarrier.arrive.noComplete" # AddrSpace # ".b64",
- [(set i64:$state, (Intrin addr:$addr, i32:$count))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_ARRIVE_NOCOMPLETE :
- MBARRIER_ARRIVE_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_noComplete>;
-defm MBARRIER_ARRIVE_NOCOMPLETE_SHARED :
- MBARRIER_ARRIVE_NOCOMPLETE<".shared", int_nvvm_mbarrier_arrive_noComplete_shared>;
-
-multiclass MBARRIER_ARRIVE_DROP<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr),
- "mbarrier.arrive_drop" # AddrSpace # ".b64",
- [(set i64:$state, (Intrin addr:$addr))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_ARRIVE_DROP :
- MBARRIER_ARRIVE_DROP<"", int_nvvm_mbarrier_arrive_drop>;
-defm MBARRIER_ARRIVE_DROP_SHARED :
- MBARRIER_ARRIVE_DROP<".shared", int_nvvm_mbarrier_arrive_drop_shared>;
-
-multiclass MBARRIER_ARRIVE_DROP_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs B64:$state),
- (ins ADDR:$addr, B32:$count),
- "mbarrier.arrive_drop.noComplete" # AddrSpace # ".b64",
- [(set i64:$state, (Intrin addr:$addr, i32:$count))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_ARRIVE_DROP_NOCOMPLETE :
- MBARRIER_ARRIVE_DROP_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_drop_noComplete>;
-defm MBARRIER_ARRIVE_DROP_NOCOMPLETE_SHARED :
- MBARRIER_ARRIVE_DROP_NOCOMPLETE<".shared",
- int_nvvm_mbarrier_arrive_drop_noComplete_shared>;
-
-multiclass MBARRIER_TEST_WAIT<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs B1:$res), (ins ADDR:$addr, B64:$state),
- "mbarrier.test_wait" # AddrSpace # ".b64",
- [(set i1:$res, (Intrin addr:$addr, i64:$state))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
+let Predicates = [hasPTX<70>, hasSM<80>] in {
+ class MBARRIER_INIT<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$count),
+ "mbarrier.init" # AddrSpace # ".b64",
+ [(Intrin addr:$addr, i32:$count)]>;
+
+ def MBARRIER_INIT : MBARRIER_INIT<"", int_nvvm_mbarrier_init>;
+ def MBARRIER_INIT_SHARED : MBARRIER_INIT<".shared",
+ int_nvvm_mbarrier_init_shared>;
+
+ class MBARRIER_INVAL<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs), (ins ADDR:$addr),
+ "mbarrier.inval" # AddrSpace # ".b64",
+ [(Intrin addr:$addr)]>;
+
+ def MBARRIER_INVAL : MBARRIER_INVAL<"", int_nvvm_mbarrier_inval>;
+ def MBARRIER_INVAL_SHARED : MBARRIER_INVAL<".shared",
+ int_nvvm_mbarrier_inval_shared>;
+
+ class MBARRIER_ARRIVE<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr),
+ "mbarrier.arrive" # AddrSpace # ".b64",
+ [(set i64:$state, (Intrin addr:$addr))]>;
+
+ def MBARRIER_ARRIVE : MBARRIER_ARRIVE<"", int_nvvm_mbarrier_arrive>;
+ def MBARRIER_ARRIVE_SHARED :
+ MBARRIER_ARRIVE<".shared", int_nvvm_mbarrier_arrive_shared>;
+
+ class MBARRIER_ARRIVE_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs B64:$state),
+ (ins ADDR:$addr, B32:$count),
+ "mbarrier.arrive.noComplete" # AddrSpace # ".b64",
+ [(set i64:$state, (Intrin addr:$addr, i32:$count))]>;
+
+ def MBARRIER_ARRIVE_NOCOMPLETE :
+ MBARRIER_ARRIVE_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_noComplete>;
+ def MBARRIER_ARRIVE_NOCOMPLETE_SHARED :
+ MBARRIER_ARRIVE_NOCOMPLETE<".shared", int_nvvm_mbarrier_arrive_noComplete_shared>;
+
+ class MBARRIER_ARRIVE_DROP<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr),
+ "mbarrier.arrive_drop" # AddrSpace # ".b64",
+ [(set i64:$state, (Intrin addr:$addr))]>;
+
+ def MBARRIER_ARRIVE_DROP :
+ MBARRIER_ARRIVE_DROP<"", int_nvvm_mbarrier_arrive_drop>;
+ def MBARRIER_ARRIVE_DROP_SHARED :
+ MBARRIER_ARRIVE_DROP<".shared", int_nvvm_mbarrier_arrive_drop_shared>;
+
+ class MBARRIER_ARRIVE_DROP_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs B64:$state),
+ (ins ADDR:$addr, B32:$count),
+ "mbarrier.arrive_drop.noComplete" # AddrSpace # ".b64",
+ [(set i64:$state, (Intrin addr:$addr, i32:$count))]>;
+
+ def MBARRIER_ARRIVE_DROP_NOCOMPLETE :
+ MBARRIER_ARRIVE_DROP_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_drop_noComplete>;
+ def MBARRIER_ARRIVE_DROP_NOCOMPLETE_SHARED :
+ MBARRIER_ARRIVE_DROP_NOCOMPLETE<".shared",
+ int_nvvm_mbarrier_arrive_drop_noComplete_shared>;
+
+ class MBARRIER_TEST_WAIT<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs B1:$res), (ins ADDR:$addr, B64:$state),
+ "mbarrier.test_wait" # AddrSpace # ".b64",
+ [(set i1:$res, (Intrin addr:$addr, i64:$state))]>;
+
+ def MBARRIER_TEST_WAIT :
+ MBARRIER_TEST_WAIT<"", int_nvvm_mbarrier_test_wait>;
+ def MBARRIER_TEST_WAIT_SHARED :
+ MBARRIER_TEST_WAIT<".shared", int_nvvm_mbarrier_test_wait_shared>;
+
+ def MBARRIER_PENDING_COUNT :
+ BasicNVPTXInst<(outs B32:$res), (ins B64:$state),
+ "mbarrier.pending_count.b64",
+ [(set i32:$res, (int_nvvm_mbarrier_pending_count i64:$state))]>;
}
-
-defm MBARRIER_TEST_WAIT :
- MBARRIER_TEST_WAIT<"", int_nvvm_mbarrier_test_wait>;
-defm MBARRIER_TEST_WAIT_SHARED :
- MBARRIER_TEST_WAIT<".shared", int_nvvm_mbarrier_test_wait_shared>;
-
-class MBARRIER_PENDING_COUNT<Intrinsic Intrin> :
- BasicNVPTXInst<(outs B32:$res), (ins B64:$state),
- "mbarrier.pending_count.b64",
- [(set i32:$res, (Intrin i64:$state))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-
-def MBARRIER_PENDING_COUNT :
- MBARRIER_PENDING_COUNT<int_nvvm_mbarrier_pending_count>;
-
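
A minimal CUDA sketch of the generic-address-space mbarrier flow these definitions select (illustrative only, assuming clang's __nvvm_mbarrier_* builtins):

__device__ void mbarrier_roundtrip(long long *bar) {
  __nvvm_mbarrier_init(bar, 1);                  // MBARRIER_INIT
  long long state = __nvvm_mbarrier_arrive(bar); // MBARRIER_ARRIVE
  while (!__nvvm_mbarrier_test_wait(bar, state)) // MBARRIER_TEST_WAIT
    ;                                            // spin until the phase completes
  __nvvm_mbarrier_inval(bar);                    // MBARRIER_INVAL
}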
//-----------------------------------
// Math Functions
//-----------------------------------
@@ -1449,15 +1407,11 @@ defm ABS_F64 : F_ABS<"f64", F64RT, support_ftz = false>;
def fcopysign_nvptx : SDNode<"NVPTXISD::FCOPYSIGN", SDTFPBinOp>;
-def COPYSIGN_F :
- BasicNVPTXInst<(outs B32:$dst), (ins B32:$src0, B32:$src1),
- "copysign.f32",
- [(set f32:$dst, (fcopysign_nvptx f32:$src1, f32:$src0))]>;
-
-def COPYSIGN_D :
- BasicNVPTXInst<(outs B64:$dst), (ins B64:$src0, B64:$src1),
- "copysign.f64",
- [(set f64:$dst, (fcopysign_nvptx f64:$src1, f64:$src0))]>;
+foreach t = [F32RT, F64RT] in
+ def COPYSIGN_ # t :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src0, t.RC:$src1),
+ "copysign." # t.PtxType,
+ [(set t.Ty:$dst, (fcopysign_nvptx t.Ty:$src1, t.Ty:$src0))]>;
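
A minimal CUDA sketch (illustrative only): standard copysign calls reach the foreach-generated definitions above through NVPTX's custom fcopysign lowering.

__device__ float  cs_f32(float mag, float sgn)   { return __builtin_copysignf(mag, sgn); } // copysign.f32
__device__ double cs_f64(double mag, double sgn) { return __builtin_copysign(mag, sgn); }  // copysign.f64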
//
// Neg bf16, bf16x2
@@ -2255,38 +2209,35 @@ defm INT_PTX_SATOM_XOR : ATOM2_bitwise_impl<"xor">;
// Scalar
-class LDU_G<string TyStr, NVPTXRegClass regclass>
- : NVPTXInst<(outs regclass:$result), (ins ADDR:$src),
- "ldu.global." # TyStr # " \t$result, [$src];", []>;
+class LDU_G<NVPTXRegClass regclass>
+ : NVPTXInst<(outs regclass:$result), (ins i32imm:$fromWidth, ADDR:$src),
+ "ldu.global.b$fromWidth \t$result, [$src];", []>;
-def LDU_GLOBAL_i8 : LDU_G<"b8", B16>;
-def LDU_GLOBAL_i16 : LDU_G<"b16", B16>;
-def LDU_GLOBAL_i32 : LDU_G<"b32", B32>;
-def LDU_GLOBAL_i64 : LDU_G<"b64", B64>;
+def LDU_GLOBAL_i16 : LDU_G<B16>;
+def LDU_GLOBAL_i32 : LDU_G<B32>;
+def LDU_GLOBAL_i64 : LDU_G<B64>;
// vector
// Elementized vector ldu
-class VLDU_G_ELE_V2<string TyStr, NVPTXRegClass regclass>
+class VLDU_G_ELE_V2<NVPTXRegClass regclass>
: NVPTXInst<(outs regclass:$dst1, regclass:$dst2),
- (ins ADDR:$src),
- "ldu.global.v2." # TyStr # " \t{{$dst1, $dst2}}, [$src];", []>;
+ (ins i32imm:$fromWidth, ADDR:$src),
+ "ldu.global.v2.b$fromWidth \t{{$dst1, $dst2}}, [$src];", []>;
-class VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass>
- : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins ADDR:$src),
- "ldu.global.v4." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>;
+class VLDU_G_ELE_V4<NVPTXRegClass regclass>
+ : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4),
+ (ins i32imm:$fromWidth, ADDR:$src),
+ "ldu.global.v4.b$fromWidth \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>;
-def LDU_GLOBAL_v2i8 : VLDU_G_ELE_V2<"b8", B16>;
-def LDU_GLOBAL_v2i16 : VLDU_G_ELE_V2<"b16", B16>;
-def LDU_GLOBAL_v2i32 : VLDU_G_ELE_V2<"b32", B32>;
-def LDU_GLOBAL_v2i64 : VLDU_G_ELE_V2<"b64", B64>;
+def LDU_GLOBAL_v2i16 : VLDU_G_ELE_V2<B16>;
+def LDU_GLOBAL_v2i32 : VLDU_G_ELE_V2<B32>;
+def LDU_GLOBAL_v2i64 : VLDU_G_ELE_V2<B64>;

-def LDU_GLOBAL_v4i8 : VLDU_G_ELE_V4<"b8", B16>;
-def LDU_GLOBAL_v4i16 : VLDU_G_ELE_V4<"b16", B16>;
-def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<"b32", B32>;
+def LDU_GLOBAL_v4i16 : VLDU_G_ELE_V4<B16>;
+def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<B32>;
//-----------------------------------
@@ -2327,12 +2278,10 @@ class VLDG_G_ELE_V8<NVPTXRegClass regclass> :
"ld.global.nc.v8.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];", []>;
// FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads.
-def LD_GLOBAL_NC_v2i8 : VLDG_G_ELE_V2<B16>;
def LD_GLOBAL_NC_v2i16 : VLDG_G_ELE_V2<B16>;
def LD_GLOBAL_NC_v2i32 : VLDG_G_ELE_V2<B32>;
def LD_GLOBAL_NC_v2i64 : VLDG_G_ELE_V2<B64>;
-def LD_GLOBAL_NC_v4i8 : VLDG_G_ELE_V4<B16>;
def LD_GLOBAL_NC_v4i16 : VLDG_G_ELE_V4<B16>;
def LD_GLOBAL_NC_v4i32 : VLDG_G_ELE_V4<B32>;
@@ -2342,19 +2291,19 @@ def LD_GLOBAL_NC_v8i32 : VLDG_G_ELE_V8<B32>;
multiclass NG_TO_G<string Str, bit Supports32 = 1, list<Predicate> Preds = []> {
if Supports32 then
def "" : BasicNVPTXInst<(outs B32:$result), (ins B32:$src),
- "cvta." # Str # ".u32", []>, Requires<Preds>;
+ "cvta." # Str # ".u32">, Requires<Preds>;
def _64 : BasicNVPTXInst<(outs B64:$result), (ins B64:$src),
- "cvta." # Str # ".u64", []>, Requires<Preds>;
+ "cvta." # Str # ".u64">, Requires<Preds>;
}
multiclass G_TO_NG<string Str, bit Supports32 = 1, list<Predicate> Preds = []> {
if Supports32 then
def "" : BasicNVPTXInst<(outs B32:$result), (ins B32:$src),
- "cvta.to." # Str # ".u32", []>, Requires<Preds>;
+ "cvta.to." # Str # ".u32">, Requires<Preds>;
def _64 : BasicNVPTXInst<(outs B64:$result), (ins B64:$src),
- "cvta.to." # Str # ".u64", []>, Requires<Preds>;
+ "cvta.to." # Str # ".u64">, Requires<Preds>;
}
foreach space = ["local", "shared", "global", "const", "param"] in {
@@ -4614,9 +4563,9 @@ def INT_PTX_SREG_LANEMASK_GT :
PTX_READ_SREG_R32<"lanemask_gt", int_nvvm_read_ptx_sreg_lanemask_gt>;
let hasSideEffects = 1 in {
-def SREG_CLOCK : PTX_READ_SREG_R32<"clock", int_nvvm_read_ptx_sreg_clock>;
-def SREG_CLOCK64 : PTX_READ_SREG_R64<"clock64", int_nvvm_read_ptx_sreg_clock64>;
-def SREG_GLOBALTIMER : PTX_READ_SREG_R64<"globaltimer", int_nvvm_read_ptx_sreg_globaltimer>;
+ def SREG_CLOCK : PTX_READ_SREG_R32<"clock", int_nvvm_read_ptx_sreg_clock>;
+ def SREG_CLOCK64 : PTX_READ_SREG_R64<"clock64", int_nvvm_read_ptx_sreg_clock64>;
+ def SREG_GLOBALTIMER : PTX_READ_SREG_R64<"globaltimer", int_nvvm_read_ptx_sreg_globaltimer>;
}
def: Pat <(i64 (readcyclecounter)), (SREG_CLOCK64)>;
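
The Pat above maps the generic cycle counter onto SREG_CLOCK64; a one-line CUDA sketch (illustrative only):

__device__ unsigned long long cycles() {
  return __builtin_readcyclecounter(); // selects SREG_CLOCK64, i.e. reads %clock64
}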
@@ -5096,37 +5045,36 @@ foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in
def : MMA_PAT<mma>;
multiclass MAPA<string suffix, Intrinsic Intr> {
- def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, B32:$b),
- "mapa" # suffix # ".u32",
- [(set i32:$d, (Intr i32:$a, i32:$b))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
- def _32i: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, i32imm:$b),
- "mapa" # suffix # ".u32",
- [(set i32:$d, (Intr i32:$a, imm:$b))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
- def _64: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, B32:$b),
- "mapa" # suffix # ".u64",
- [(set i64:$d, (Intr i64:$a, i32:$b))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
- def _64i: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, i32imm:$b),
- "mapa" # suffix # ".u64",
- [(set i64:$d, (Intr i64:$a, imm:$b))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
+ let Predicates = [hasSM<90>, hasPTX<78>] in {
+ def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, B32:$b),
+ "mapa" # suffix # ".u32",
+ [(set i32:$d, (Intr i32:$a, i32:$b))]>;
+ def _32i: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, i32imm:$b),
+ "mapa" # suffix # ".u32",
+ [(set i32:$d, (Intr i32:$a, imm:$b))]>;
+ def _64: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, B32:$b),
+ "mapa" # suffix # ".u64",
+ [(set i64:$d, (Intr i64:$a, i32:$b))]>;
+ def _64i: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, i32imm:$b),
+ "mapa" # suffix # ".u64",
+ [(set i64:$d, (Intr i64:$a, imm:$b))]>;
+ }
}
+
defm mapa : MAPA<"", int_nvvm_mapa>;
defm mapa_shared_cluster : MAPA<".shared::cluster", int_nvvm_mapa_shared_cluster>;
multiclass GETCTARANK<string suffix, Intrinsic Intr> {
- def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a),
- "getctarank" # suffix # ".u32",
- [(set i32:$d, (Intr i32:$a))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
- def _64: BasicNVPTXInst<(outs B32:$d), (ins B64:$a),
- "getctarank" # suffix # ".u64",
- [(set i32:$d, (Intr i64:$a))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
+ let Predicates = [hasSM<90>, hasPTX<78>] in {
+ def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a),
+ "getctarank" # suffix # ".u32",
+ [(set i32:$d, (Intr i32:$a))]>;
+ def _64: BasicNVPTXInst<(outs B32:$d), (ins B64:$a),
+ "getctarank" # suffix # ".u64",
+ [(set i32:$d, (Intr i64:$a))]>;
+ }
}
defm getctarank : GETCTARANK<"", int_nvvm_getctarank>;
@@ -5165,29 +5113,25 @@ def INT_NVVM_WGMMA_WAIT_GROUP_SYNC_ALIGNED : BasicNVPTXInst<(outs), (ins i64imm:
[(int_nvvm_wgmma_wait_group_sync_aligned timm:$n)]>, Requires<[hasSM90a, hasPTX<80>]>;
} // isConvergent = true
-def GRIDDEPCONTROL_LAUNCH_DEPENDENTS :
- BasicNVPTXInst<(outs), (ins),
- "griddepcontrol.launch_dependents",
- [(int_nvvm_griddepcontrol_launch_dependents)]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
-
-def GRIDDEPCONTROL_WAIT :
- BasicNVPTXInst<(outs), (ins),
- "griddepcontrol.wait",
- [(int_nvvm_griddepcontrol_wait)]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
+let Predicates = [hasSM<90>, hasPTX<78>] in {
+ def GRIDDEPCONTROL_LAUNCH_DEPENDENTS :
+ BasicNVPTXInst<(outs), (ins), "griddepcontrol.launch_dependents",
+ [(int_nvvm_griddepcontrol_launch_dependents)]>;
+ def GRIDDEPCONTROL_WAIT :
+ BasicNVPTXInst<(outs), (ins), "griddepcontrol.wait",
+ [(int_nvvm_griddepcontrol_wait)]>;
+}
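
An illustrative CUDA sketch of programmatic dependent launch, assuming clang exposes the __nvvm_griddepcontrol_* builtins: the launching kernel signals completion, the dependent kernel waits on it.

__device__ void parent_kernel_tail() {
  __nvvm_griddepcontrol_launch_dependents(); // GRIDDEPCONTROL_LAUNCH_DEPENDENTS
}
__device__ void dependent_kernel_head() {
  __nvvm_griddepcontrol_wait();              // GRIDDEPCONTROL_WAIT
}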
def INT_EXIT : BasicNVPTXInst<(outs), (ins), "exit", [(int_nvvm_exit)]>;
// Tcgen05 intrinsics
-let isConvergent = true in {
+let isConvergent = true, Predicates = [hasTcgen05Instructions] in {
multiclass TCGEN05_ALLOC_INTR<string AS, string num, Intrinsic Intr> {
def "" : BasicNVPTXInst<(outs),
(ins ADDR:$dst, B32:$ncols),
"tcgen05.alloc.cta_group::" # num # ".sync.aligned" # AS # ".b32",
- [(Intr addr:$dst, B32:$ncols)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(Intr addr:$dst, B32:$ncols)]>;
}
defm TCGEN05_ALLOC_CG1 : TCGEN05_ALLOC_INTR<"", "1", int_nvvm_tcgen05_alloc_cg1>;
@@ -5200,8 +5144,7 @@ multiclass TCGEN05_DEALLOC_INTR<string num, Intrinsic Intr> {
def "" : BasicNVPTXInst<(outs),
(ins B32:$tmem_addr, B32:$ncols),
"tcgen05.dealloc.cta_group::" # num # ".sync.aligned.b32",
- [(Intr B32:$tmem_addr, B32:$ncols)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(Intr B32:$tmem_addr, B32:$ncols)]>;
}
defm TCGEN05_DEALLOC_CG1: TCGEN05_DEALLOC_INTR<"1", int_nvvm_tcgen05_dealloc_cg1>;
defm TCGEN05_DEALLOC_CG2: TCGEN05_DEALLOC_INTR<"2", int_nvvm_tcgen05_dealloc_cg2>;
@@ -5209,19 +5152,13 @@ defm TCGEN05_DEALLOC_CG2: TCGEN05_DEALLOC_INTR<"2", int_nvvm_tcgen05_dealloc_cg2
multiclass TCGEN05_RELINQ_PERMIT_INTR<string num, Intrinsic Intr> {
def "" : BasicNVPTXInst<(outs), (ins),
"tcgen05.relinquish_alloc_permit.cta_group::" # num # ".sync.aligned",
- [(Intr)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(Intr)]>;
}
defm TCGEN05_RELINQ_CG1: TCGEN05_RELINQ_PERMIT_INTR<"1", int_nvvm_tcgen05_relinq_alloc_permit_cg1>;
defm TCGEN05_RELINQ_CG2: TCGEN05_RELINQ_PERMIT_INTR<"2", int_nvvm_tcgen05_relinq_alloc_permit_cg2>;
-def tcgen05_wait_ld: BasicNVPTXInst<(outs), (ins), "tcgen05.wait::ld.sync.aligned",
- [(int_nvvm_tcgen05_wait_ld)]>,
- Requires<[hasTcgen05Instructions]>;
-
-def tcgen05_wait_st: BasicNVPTXInst<(outs), (ins), "tcgen05.wait::st.sync.aligned",
- [(int_nvvm_tcgen05_wait_st)]>,
- Requires<[hasTcgen05Instructions]>;
+def tcgen05_wait_ld: NullaryInst<"tcgen05.wait::ld.sync.aligned", int_nvvm_tcgen05_wait_ld>;
+def tcgen05_wait_st: NullaryInst<"tcgen05.wait::st.sync.aligned", int_nvvm_tcgen05_wait_st>;
multiclass TCGEN05_COMMIT_INTR<string AS, string num> {
defvar prefix = "tcgen05.commit.cta_group::" # num #".mbarrier::arrive::one.shared::cluster";
@@ -5232,12 +5169,10 @@ multiclass TCGEN05_COMMIT_INTR<string AS, string num> {
def "" : BasicNVPTXInst<(outs), (ins ADDR:$mbar),
prefix # ".b64",
- [(Intr addr:$mbar)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(Intr addr:$mbar)]>;
def _MC : BasicNVPTXInst<(outs), (ins ADDR:$mbar, B16:$mc),
prefix # ".multicast::cluster.b64",
- [(IntrMC addr:$mbar, B16:$mc)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(IntrMC addr:$mbar, B16:$mc)]>;
}
defm TCGEN05_COMMIT_CG1 : TCGEN05_COMMIT_INTR<"", "1">;
@@ -5249,8 +5184,7 @@ multiclass TCGEN05_SHIFT_INTR<string num, Intrinsic Intr> {
def "" : BasicNVPTXInst<(outs),
(ins ADDR:$tmem_addr),
"tcgen05.shift.cta_group::" # num # ".down",
- [(Intr addr:$tmem_addr)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(Intr addr:$tmem_addr)]>;
}
defm TCGEN05_SHIFT_CG1: TCGEN05_SHIFT_INTR<"1", int_nvvm_tcgen05_shift_down_cg1>;
defm TCGEN05_SHIFT_CG2: TCGEN05_SHIFT_INTR<"2", int_nvvm_tcgen05_shift_down_cg2>;
@@ -5270,13 +5204,11 @@ multiclass TCGEN05_CP_INTR<string shape, string src_fmt, string mc = ""> {
def _cg1 : BasicNVPTXInst<(outs),
(ins ADDR:$tmem_addr, B64:$sdesc),
"tcgen05.cp.cta_group::1." # shape_mc_asm # fmt_asm,
- [(IntrCG1 addr:$tmem_addr, B64:$sdesc)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(IntrCG1 addr:$tmem_addr, B64:$sdesc)]>;
def _cg2 : BasicNVPTXInst<(outs),
(ins ADDR:$tmem_addr, B64:$sdesc),
"tcgen05.cp.cta_group::2." # shape_mc_asm # fmt_asm,
- [(IntrCG2 addr:$tmem_addr, B64:$sdesc)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(IntrCG2 addr:$tmem_addr, B64:$sdesc)]>;
}
foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in {
@@ -5289,17 +5221,13 @@ foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in {
}
} // isConvergent
-let hasSideEffects = 1 in {
+let hasSideEffects = 1, Predicates = [hasTcgen05Instructions] in {

-def tcgen05_fence_before_thread_sync: BasicNVPTXInst<(outs), (ins),
- "tcgen05.fence::before_thread_sync",
- [(int_nvvm_tcgen05_fence_before_thread_sync)]>,
- Requires<[hasTcgen05Instructions]>;
+ def tcgen05_fence_before_thread_sync: NullaryInst<
+ "tcgen05.fence::before_thread_sync", int_nvvm_tcgen05_fence_before_thread_sync>;
-def tcgen05_fence_after_thread_sync: BasicNVPTXInst<(outs), (ins),
- "tcgen05.fence::after_thread_sync",
- [(int_nvvm_tcgen05_fence_after_thread_sync)]>,
- Requires<[hasTcgen05Instructions]>;
+ def tcgen05_fence_after_thread_sync: NullaryInst<
+ "tcgen05.fence::after_thread_sync", int_nvvm_tcgen05_fence_after_thread_sync>;
} // hasSideEffects
@@ -5392,17 +5320,17 @@ foreach shape = ["16x64b", "16x128b", "16x256b", "32x32b", "16x32bx2"] in {
// Bulk store instructions
def st_bulk_imm : TImmLeaf<i64, [{ return Imm == 0; }]>;
-def INT_NVVM_ST_BULK_GENERIC :
- BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value),
- "st.bulk",
- [(int_nvvm_st_bulk addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>,
- Requires<[hasSM<100>, hasPTX<86>]>;
+let Predicates = [hasSM<100>, hasPTX<86>] in {
+ def INT_NVVM_ST_BULK_GENERIC :
+ BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value),
+ "st.bulk",
+ [(int_nvvm_st_bulk addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>;

-def INT_NVVM_ST_BULK_SHARED_CTA:
- BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value),
- "st.bulk.shared::cta",
- [(int_nvvm_st_bulk_shared_cta addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>,
- Requires<[hasSM<100>, hasPTX<86>]>;
+ def INT_NVVM_ST_BULK_SHARED_CTA:
+ BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value),
+ "st.bulk.shared::cta",
+ [(int_nvvm_st_bulk_shared_cta addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>;
+}
//
// clusterlaunchcontrol Instructions
diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
index d40886a..2e81ab1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
@@ -38,14 +38,6 @@ foreach i = 0...4 in {
def R#i : NVPTXReg<"%r"#i>; // 32-bit
def RL#i : NVPTXReg<"%rd"#i>; // 64-bit
def RQ#i : NVPTXReg<"%rq"#i>; // 128-bit
- def H#i : NVPTXReg<"%h"#i>; // 16-bit float
- def HH#i : NVPTXReg<"%hh"#i>; // 2x16-bit float
-
- // Arguments
- def ia#i : NVPTXReg<"%ia"#i>;
- def la#i : NVPTXReg<"%la"#i>;
- def fa#i : NVPTXReg<"%fa"#i>;
- def da#i : NVPTXReg<"%da"#i>;
}
foreach i = 0...31 in {