Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXInstrInfo.td')
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXInstrInfo.td  833
1 file changed, 250 insertions(+), 583 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index b5df4c6..d8047d3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -15,19 +15,8 @@ include "NVPTXInstrFormats.td"
let OperandType = "OPERAND_IMMEDIATE" in {
def f16imm : Operand<f16>;
def bf16imm : Operand<bf16>;
-
}
-// List of vector specific properties
-def isVecLD : VecInstTypeEnum<1>;
-def isVecST : VecInstTypeEnum<2>;
-def isVecBuild : VecInstTypeEnum<3>;
-def isVecShuffle : VecInstTypeEnum<4>;
-def isVecExtract : VecInstTypeEnum<5>;
-def isVecInsert : VecInstTypeEnum<6>;
-def isVecDest : VecInstTypeEnum<7>;
-def isVecOther : VecInstTypeEnum<15>;
-
//===----------------------------------------------------------------------===//
// NVPTX Operand Definitions.
//===----------------------------------------------------------------------===//
@@ -125,8 +114,6 @@ def doF32FTZ : Predicate<"useF32FTZ()">;
def doNoF32FTZ : Predicate<"!useF32FTZ()">;
def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
-def doMulWide : Predicate<"doMulWide">;
-
def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
@@ -486,46 +473,28 @@ let hasSideEffects = false in {
// takes a CvtMode immediate that defines the conversion mode to use. It can
// be CvtNONE to omit a conversion mode.
multiclass CVT_FROM_ALL<string ToType, RegisterClass RC, list<Predicate> Preds = []> {
- def _s8 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B16:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s8">,
- Requires<Preds>;
- def _u8 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B16:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u8">,
- Requires<Preds>;
- def _s16 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B16:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s16">,
- Requires<Preds>;
- def _u16 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B16:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u16">,
- Requires<Preds>;
- def _s32 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B32:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s32">,
- Requires<Preds>;
- def _u32 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B32:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u32">,
- Requires<Preds>;
- def _s64 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B64:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s64">,
- Requires<Preds>;
- def _u64 :
- BasicFlagsNVPTXInst<(outs RC:$dst),
- (ins B64:$src), (ins CvtMode:$mode),
- "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u64">,
- Requires<Preds>;
+ foreach sign = ["s", "u"] in {
+ def _ # sign # "8" :
+ BasicFlagsNVPTXInst<(outs RC:$dst),
+ (ins B16:$src), (ins CvtMode:$mode),
+ "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "8">,
+ Requires<Preds>;
+ def _ # sign # "16" :
+ BasicFlagsNVPTXInst<(outs RC:$dst),
+ (ins B16:$src), (ins CvtMode:$mode),
+ "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "16">,
+ Requires<Preds>;
+ def _ # sign # "32" :
+ BasicFlagsNVPTXInst<(outs RC:$dst),
+ (ins B32:$src), (ins CvtMode:$mode),
+ "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "32">,
+ Requires<Preds>;
+ def _ # sign # "64" :
+ BasicFlagsNVPTXInst<(outs RC:$dst),
+ (ins B64:$src), (ins CvtMode:$mode),
+ "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "64">,
+ Requires<Preds>;
+ }
def _f16 :
BasicFlagsNVPTXInst<(outs RC:$dst),
(ins B16:$src), (ins CvtMode:$mode),
@@ -556,14 +525,12 @@ let hasSideEffects = false in {
}
// Generate cvts from all types to all types.
- defm CVT_s8 : CVT_FROM_ALL<"s8", B16>;
- defm CVT_u8 : CVT_FROM_ALL<"u8", B16>;
- defm CVT_s16 : CVT_FROM_ALL<"s16", B16>;
- defm CVT_u16 : CVT_FROM_ALL<"u16", B16>;
- defm CVT_s32 : CVT_FROM_ALL<"s32", B32>;
- defm CVT_u32 : CVT_FROM_ALL<"u32", B32>;
- defm CVT_s64 : CVT_FROM_ALL<"s64", B64>;
- defm CVT_u64 : CVT_FROM_ALL<"u64", B64>;
+ foreach sign = ["s", "u"] in {
+ defm CVT_ # sign # "8" : CVT_FROM_ALL<sign # "8", B16>;
+ defm CVT_ # sign # "16" : CVT_FROM_ALL<sign # "16", B16>;
+ defm CVT_ # sign # "32" : CVT_FROM_ALL<sign # "32", B32>;
+ defm CVT_ # sign # "64" : CVT_FROM_ALL<sign # "64", B64>;
+ }
defm CVT_f16 : CVT_FROM_ALL<"f16", B16>;
defm CVT_bf16 : CVT_FROM_ALL<"bf16", B16, [hasPTX<78>, hasSM<90>]>;
defm CVT_f32 : CVT_FROM_ALL<"f32", B32>;
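The CVT_FROM_ALL refactor above leans on two TableGen features: foreach to stamp out the signed/unsigned variants, and the # paste operator to build both the record name and the PTX type suffix. Below is a minimal sketch of that mechanism, separate from the patch itself; CvtSketch is a stand-in class, the widths are collapsed into a second loop (the real multiclass keeps them unrolled because the source register class differs per width), and the CvtMode/FTZ/SAT flags are omitted.

// Sketch only; should be accepted by llvm-tblgen on its own.
class CvtSketch<string asm> { string AsmString = asm; }

multiclass CVT_FROM_ALL_SKETCH<string ToType> {
  foreach sign = ["s", "u"] in {
    foreach bits = [8, 16, 32, 64] in {
      def _ # sign # bits : CvtSketch<"cvt." # ToType # "." # sign # bits>;
    }
  }
}

defm CVT_f32_sketch : CVT_FROM_ALL_SKETCH<"f32">;
// Produces CVT_f32_sketch_s8 ... CVT_f32_sketch_u64; the real
// defm CVT_f32 : CVT_FROM_ALL<"f32", B32> likewise yields CVT_f32_s8 ... CVT_f32_u64.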
@@ -571,18 +538,12 @@ let hasSideEffects = false in {
// These cvts are different from those above: The source and dest registers
// are of the same type.
- def CVT_INREG_s16_s8 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src),
- "cvt.s16.s8">;
- def CVT_INREG_s32_s8 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src),
- "cvt.s32.s8">;
- def CVT_INREG_s32_s16 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src),
- "cvt.s32.s16">;
- def CVT_INREG_s64_s8 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src),
- "cvt.s64.s8">;
- def CVT_INREG_s64_s16 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src),
- "cvt.s64.s16">;
- def CVT_INREG_s64_s32 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src),
- "cvt.s64.s32">;
+ def CVT_INREG_s16_s8 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src), "cvt.s16.s8">;
+ def CVT_INREG_s32_s8 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "cvt.s32.s8">;
+ def CVT_INREG_s32_s16 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "cvt.s32.s16">;
+ def CVT_INREG_s64_s8 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s8">;
+ def CVT_INREG_s64_s16 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s16">;
+ def CVT_INREG_s64_s32 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s32">;
multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> {
def _f32 :
@@ -784,7 +745,7 @@ defm SUB : I3<"sub.s", sub, commutative = false>;
def ADD16x2 : I16x2<"add.s", add>;
-// in32 and int64 addition and subtraction with carry-out.
+// int32 and int64 addition and subtraction with carry-out.
defm ADDCC : ADD_SUB_INT_CARRY<"add.cc", addc, commutative = true>;
defm SUBCC : ADD_SUB_INT_CARRY<"sub.cc", subc, commutative = false>;
@@ -805,17 +766,17 @@ defm UDIV : I3<"div.u", udiv, commutative = false>;
defm SREM : I3<"rem.s", srem, commutative = false>;
defm UREM : I3<"rem.u", urem, commutative = false>;
-// Integer absolute value. NumBits should be one minus the bit width of RC.
-// This idiom implements the algorithm at
-// http://graphics.stanford.edu/~seander/bithacks.html#IntegerAbs.
-multiclass ABS<ValueType T, RegisterClass RC, string SizeName> {
- def : BasicNVPTXInst<(outs RC:$dst), (ins RC:$a),
- "abs" # SizeName,
- [(set T:$dst, (abs T:$a))]>;
+foreach t = [I16RT, I32RT, I64RT] in {
+ def ABS_S # t.Size :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a),
+ "abs.s" # t.Size,
+ [(set t.Ty:$dst, (abs t.Ty:$a))]>;
+
+ def NEG_S # t.Size :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
+ "neg.s" # t.Size,
+ [(set t.Ty:$dst, (ineg t.Ty:$src))]>;
}
-defm ABS_16 : ABS<i16, B16, ".s16">;
-defm ABS_32 : ABS<i32, B32, ".s32">;
-defm ABS_64 : ABS<i64, B64, ".s64">;
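The abs/neg change above shows the other recurring pattern in this diff: a top-level foreach over the RegTyInfo records (I16RT, I32RT, I64RT), whose fields feed both the pasted record name and the instruction string. A small stand-alone sketch of that shape follows; TypeSketch and InstSketch are stand-ins for RegTyInfo and BasicNVPTXInst, not the real classes.

class TypeSketch<int size> { int Size = size; }
def T16 : TypeSketch<16>;
def T32 : TypeSketch<32>;
def T64 : TypeSketch<64>;

class InstSketch<string asm> { string AsmString = asm; }

// Field accesses (t.Size) drive both the def name and the PTX mnemonic,
// mirroring the ABS_S*/NEG_S* foreach above.
foreach t = [T16, T32, T64] in {
  def ABS_SKETCH_S # t.Size : InstSketch<"abs.s" # t.Size>;
  def NEG_SKETCH_S # t.Size : InstSketch<"neg.s" # t.Size>;
}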
// Integer min/max.
defm SMAX : I3<"max.s", smax, commutative = true>;
@@ -832,170 +793,63 @@ def UMIN16x2 : I16x2<"min.u", umin>;
//
// Wide multiplication
//
-def MULWIDES64 :
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.s32">;
-def MULWIDES64Imm :
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.s32">;
-def MULWIDES64Imm64 :
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i64imm:$b), "mul.wide.s32">;
-
-def MULWIDEU64 :
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.u32">;
-def MULWIDEU64Imm :
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.u32">;
-def MULWIDEU64Imm64 :
- BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i64imm:$b), "mul.wide.u32">;
-
-def MULWIDES32 :
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.s16">;
-def MULWIDES32Imm :
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.s16">;
-def MULWIDES32Imm32 :
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.s16">;
-
-def MULWIDEU32 :
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.u16">;
-def MULWIDEU32Imm :
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.u16">;
-def MULWIDEU32Imm32 :
- BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.u16">;
-
-def SDTMulWide : SDTypeProfile<1, 2, [SDTCisSameAs<1, 2>]>;
-def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>;
-def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>;
-
-// Matchers for signed, unsigned mul.wide ISD nodes.
-let Predicates = [doMulWide] in {
- def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>;
- def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>;
- def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>;
- def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)), (MULWIDEU32Imm $a, imm:$b)>;
-
- def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)), (MULWIDES64 $a, $b)>;
- def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)), (MULWIDES64Imm $a, imm:$b)>;
- def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)), (MULWIDEU64 $a, $b)>;
- def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>;
-}
-
-// Predicates used for converting some patterns to mul.wide.
-def SInt32Const : PatLeaf<(imm), [{
- const APInt &v = N->getAPIntValue();
- return v.isSignedIntN(32);
-}]>;
-
-def UInt32Const : PatLeaf<(imm), [{
- const APInt &v = N->getAPIntValue();
- return v.isIntN(32);
-}]>;
-def SInt16Const : PatLeaf<(imm), [{
- const APInt &v = N->getAPIntValue();
- return v.isSignedIntN(16);
-}]>;
-
-def UInt16Const : PatLeaf<(imm), [{
- const APInt &v = N->getAPIntValue();
- return v.isIntN(16);
-}]>;
-
-def IntConst_0_30 : PatLeaf<(imm), [{
- // Check if 0 <= v < 31; only then will the result of (x << v) be an int32.
- const APInt &v = N->getAPIntValue();
- return v.sge(0) && v.slt(31);
-}]>;
-
-def IntConst_0_14 : PatLeaf<(imm), [{
- // Check if 0 <= v < 15; only then will the result of (x << v) be an int16.
- const APInt &v = N->getAPIntValue();
- return v.sge(0) && v.slt(15);
-}]>;
-
-def SHL2MUL32 : SDNodeXForm<imm, [{
- const APInt &v = N->getAPIntValue();
- APInt temp(32, 1);
- return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i32);
-}]>;
+def SDTMulWide : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisInt<1>, SDTCisSameAs<1, 2>]>;
+def smul_wide : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide, [SDNPCommutative]>;
+def umul_wide : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide, [SDNPCommutative]>;
-def SHL2MUL16 : SDNodeXForm<imm, [{
- const APInt &v = N->getAPIntValue();
- APInt temp(16, 1);
- return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i16);
-}]>;
-
-// Convert "sign/zero-extend, then shift left by an immediate" to mul.wide.
-let Predicates = [doMulWide] in {
- def : Pat<(shl (sext i32:$a), (i32 IntConst_0_30:$b)),
- (MULWIDES64Imm $a, (SHL2MUL32 $b))>;
- def : Pat<(shl (zext i32:$a), (i32 IntConst_0_30:$b)),
- (MULWIDEU64Imm $a, (SHL2MUL32 $b))>;
-
- def : Pat<(shl (sext i16:$a), (i16 IntConst_0_14:$b)),
- (MULWIDES32Imm $a, (SHL2MUL16 $b))>;
- def : Pat<(shl (zext i16:$a), (i16 IntConst_0_14:$b)),
- (MULWIDEU32Imm $a, (SHL2MUL16 $b))>;
-
- // Convert "sign/zero-extend then multiply" to mul.wide.
- def : Pat<(mul (sext i32:$a), (sext i32:$b)),
- (MULWIDES64 $a, $b)>;
- def : Pat<(mul (sext i32:$a), (i64 SInt32Const:$b)),
- (MULWIDES64Imm64 $a, (i64 SInt32Const:$b))>;
-
- def : Pat<(mul (zext i32:$a), (zext i32:$b)),
- (MULWIDEU64 $a, $b)>;
- def : Pat<(mul (zext i32:$a), (i64 UInt32Const:$b)),
- (MULWIDEU64Imm64 $a, (i64 UInt32Const:$b))>;
- def : Pat<(mul (sext i16:$a), (sext i16:$b)),
- (MULWIDES32 $a, $b)>;
- def : Pat<(mul (sext i16:$a), (i32 SInt16Const:$b)),
- (MULWIDES32Imm32 $a, (i32 SInt16Const:$b))>;
-
- def : Pat<(mul (zext i16:$a), (zext i16:$b)),
- (MULWIDEU32 $a, $b)>;
- def : Pat<(mul (zext i16:$a), (i32 UInt16Const:$b)),
- (MULWIDEU32Imm32 $a, (i32 UInt16Const:$b))>;
+multiclass MULWIDEInst<string suffix, SDPatternOperator op, RegTyInfo big_t, RegTyInfo small_t> {
+ def suffix # _rr :
+ BasicNVPTXInst<(outs big_t.RC:$dst), (ins small_t.RC:$a, small_t.RC:$b),
+ "mul.wide." # suffix,
+ [(set big_t.Ty:$dst, (op small_t.Ty:$a, small_t.Ty:$b))]>;
+ def suffix # _ri :
+ BasicNVPTXInst<(outs big_t.RC:$dst), (ins small_t.RC:$a, small_t.Imm:$b),
+ "mul.wide." # suffix,
+ [(set big_t.Ty:$dst, (op small_t.Ty:$a, imm:$b))]>;
}
+defm MUL_WIDE : MULWIDEInst<"s32", smul_wide, I64RT, I32RT>;
+defm MUL_WIDE : MULWIDEInst<"u32", umul_wide, I64RT, I32RT>;
+defm MUL_WIDE : MULWIDEInst<"s16", smul_wide, I32RT, I16RT>;
+defm MUL_WIDE : MULWIDEInst<"u16", umul_wide, I32RT, I16RT>;
+
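For the mul.wide rewrite, the notable part is that a single multiclass now carries both the register-register and register-immediate forms, with the PTX suffix pasted into the record name. The sketch below shows only the naming and asm-string side; the real MULWIDEInst also attaches the smul_wide/umul_wide selection patterns and takes the wide/narrow RegTyInfo pair.

class InstSketch<string asm> { string AsmString = asm; }

multiclass MulWideSketch<string suffix> {
  def suffix # _rr : InstSketch<"mul.wide." # suffix>;  // both operands in registers
  def suffix # _ri : InstSketch<"mul.wide." # suffix>;  // immediate second operand
}

defm MUL_WIDE_SKETCH_ : MulWideSketch<"s32">;
// Yields MUL_WIDE_SKETCH_s32_rr / MUL_WIDE_SKETCH_s32_ri, both printing
// "mul.wide.s32"; the real defm MUL_WIDE : MULWIDEInst<"s32", ...> yields
// MUL_WIDEs32_rr / MUL_WIDEs32_ri the same way.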
//
// Integer multiply-add
//
-def mul_oneuse : OneUse2<mul>;
-
-multiclass MAD<string Ptx, ValueType VT, NVPTXRegClass Reg, Operand Imm> {
+multiclass MADInst<string suffix, SDPatternOperator op, RegTyInfo big_t, RegTyInfo small_t> {
def rrr:
- BasicNVPTXInst<(outs Reg:$dst),
- (ins Reg:$a, Reg:$b, Reg:$c),
- Ptx,
- [(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), VT:$c))]>;
-
- def rir:
- BasicNVPTXInst<(outs Reg:$dst),
- (ins Reg:$a, Imm:$b, Reg:$c),
- Ptx,
- [(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), VT:$c))]>;
+ BasicNVPTXInst<(outs big_t.RC:$dst),
+ (ins small_t.RC:$a, small_t.RC:$b, big_t.RC:$c),
+ "mad." # suffix,
+ [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, small_t.Ty:$b), big_t.Ty:$c))]>;
def rri:
- BasicNVPTXInst<(outs Reg:$dst),
- (ins Reg:$a, Reg:$b, Imm:$c),
- Ptx,
- [(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), imm:$c))]>;
+ BasicNVPTXInst<(outs big_t.RC:$dst),
+ (ins small_t.RC:$a, small_t.RC:$b, big_t.Imm:$c),
+ "mad." # suffix,
+ [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, small_t.Ty:$b), imm:$c))]>;
+ def rir:
+ BasicNVPTXInst<(outs big_t.RC:$dst),
+ (ins small_t.RC:$a, small_t.Imm:$b, big_t.RC:$c),
+ "mad." # suffix,
+ [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, imm:$b), big_t.Ty:$c))]>;
def rii:
- BasicNVPTXInst<(outs Reg:$dst),
- (ins Reg:$a, Imm:$b, Imm:$c),
- Ptx,
- [(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), imm:$c))]>;
+ BasicNVPTXInst<(outs big_t.RC:$dst),
+ (ins small_t.RC:$a, small_t.Imm:$b, big_t.Imm:$c),
+ "mad." # suffix,
+ [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, imm:$b), imm:$c))]>;
}
let Predicates = [hasOptEnabled] in {
-defm MAD16 : MAD<"mad.lo.s16", i16, B16, i16imm>;
-defm MAD32 : MAD<"mad.lo.s32", i32, B32, i32imm>;
-defm MAD64 : MAD<"mad.lo.s64", i64, B64, i64imm>;
-}
+ defm MAD_LO_S16 : MADInst<"lo.s16", mul, I16RT, I16RT>;
+ defm MAD_LO_S32 : MADInst<"lo.s32", mul, I32RT, I32RT>;
+ defm MAD_LO_S64 : MADInst<"lo.s64", mul, I64RT, I64RT>;
-foreach t = [I16RT, I32RT, I64RT] in {
- def NEG_S # t.Size :
- BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
- "neg.s" # t.Size,
- [(set t.Ty:$dst, (ineg t.Ty:$src))]>;
+ defm MAD_WIDE_U16 : MADInst<"wide.u16", umul_wide, I32RT, I16RT>;
+ defm MAD_WIDE_S16 : MADInst<"wide.s16", smul_wide, I32RT, I16RT>;
+ defm MAD_WIDE_U32 : MADInst<"wide.u32", umul_wide, I64RT, I32RT>;
+ defm MAD_WIDE_S32 : MADInst<"wide.s32", smul_wide, I64RT, I32RT>;
}
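The mad patterns above only fire when the multiply feeding the add has a single use; otherwise the mul would still have to be emitted and the mad would save nothing. OneUse2 provides that gating. Its real definition lives elsewhere in the NVPTX .td files; the following is a rough, hedged sketch of how such a wrapper is typically written as a PatFrag with a one-use predicate, not the verbatim NVPTX definition.

include "llvm/Target/Target.td"

// Assumed shape of a OneUse2-style wrapper: match a two-operand node only
// when it has a single use, so folding it away is always profitable.
class OneUseBinSketch<SDPatternOperator op>
  : PatFrag<(ops node:$a, node:$b), (op node:$a, node:$b), [{
      return N->hasOneUse();
    }]>;

def one_use_mul_sketch : OneUseBinSketch<mul>;
// MADInst applies OneUse2<op> inline in its (add (OneUse2<op> ...), c)
// patterns above, for op in {mul, smul_wide, umul_wide}.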
//-----------------------------------
@@ -1106,8 +960,7 @@ def fdiv_approx : PatFrag<(ops node:$a, node:$b),
def FRCP32_approx_r :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$b), (ins FTZFlag:$ftz),
"rcp.approx$ftz.f32",
[(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>;
@@ -1116,14 +969,12 @@ def FRCP32_approx_r :
//
def FDIV32_approx_rr :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, B32:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, B32:$b), (ins FTZFlag:$ftz),
"div.approx$ftz.f32",
[(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>;
def FDIV32_approx_ri :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, f32imm:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz),
"div.approx$ftz.f32",
[(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>;
//
@@ -1146,14 +997,12 @@ def : Pat<(fdiv_full f32imm_1, f32:$b),
//
def FDIV32rr :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, B32:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, B32:$b), (ins FTZFlag:$ftz),
"div.full$ftz.f32",
[(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>;
def FDIV32ri :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, f32imm:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz),
"div.full$ftz.f32",
[(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>;
//
@@ -1167,8 +1016,7 @@ def fdiv_ftz : PatFrag<(ops node:$a, node:$b),
def FRCP32r_prec :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$b), (ins FTZFlag:$ftz),
"rcp.rn$ftz.f32",
[(set f32:$dst, (fdiv_ftz f32imm_1, f32:$b))]>;
//
@@ -1176,14 +1024,12 @@ def FRCP32r_prec :
//
def FDIV32rr_prec :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, B32:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, B32:$b), (ins FTZFlag:$ftz),
"div.rn$ftz.f32",
[(set f32:$dst, (fdiv_ftz f32:$a, f32:$b))]>;
def FDIV32ri_prec :
BasicFlagsNVPTXInst<(outs B32:$dst),
- (ins B32:$a, f32imm:$b),
- (ins FTZFlag:$ftz),
+ (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz),
"div.rn$ftz.f32",
[(set f32:$dst, (fdiv_ftz f32:$a, fpimm:$b))]>;
@@ -1234,7 +1080,7 @@ defm FMA_F32 : FMA<F32RT, allow_ftz = true>;
defm FMA_F32x2 : FMA<F32X2RT, allow_ftz = true, preds = [hasF32x2Instructions]>;
defm FMA_F64 : FMA<F64RT, allow_ftz = false>;
-// sin/cos
+// sin/cos/tanh
class UnaryOpAllowsApproxFn<SDPatternOperator operator>
: PatFrag<(ops node:$A),
@@ -1250,6 +1096,10 @@ def COS_APPROX_f32 :
BasicFlagsNVPTXInst<(outs B32:$dst), (ins B32:$src), (ins FTZFlag:$ftz),
"cos.approx$ftz.f32",
[(set f32:$dst, (UnaryOpAllowsApproxFn<fcos> f32:$src))]>;
+def TANH_APPROX_f32 :
+ BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "tanh.approx.f32",
+ [(set f32:$dst, (UnaryOpAllowsApproxFn<ftanh> f32:$src))]>,
+ Requires<[hasPTX<70>, hasSM<75>]>;
//-----------------------------------
// Bitwise operations
@@ -1258,10 +1108,8 @@ def COS_APPROX_f32 :
// Template for three-arg bitwise operations. Takes three args, Creates .b16,
// .b32, .b64, and .pred (predicate registers -- i.e., i1) versions of OpcStr.
multiclass BITWISE<string OpcStr, SDNode OpNode> {
- defm b1 : I3Inst<OpcStr # ".pred", OpNode, I1RT, commutative = true>;
- defm b16 : I3Inst<OpcStr # ".b16", OpNode, I16RT, commutative = true>;
- defm b32 : I3Inst<OpcStr # ".b32", OpNode, I32RT, commutative = true>;
- defm b64 : I3Inst<OpcStr # ".b64", OpNode, I64RT, commutative = true>;
+ foreach t = [I1RT, I16RT, I32RT, I64RT] in
+ defm _ # t.PtxType : I3Inst<OpcStr # "." # t.PtxType, OpNode, t, commutative = true>;
}
defm OR : BITWISE<"or", or>;
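This BITWISE change is what renames the generated records from the ANDb1rr / ORb32rr style to AND_predrr / OR_b32rr, which is why every Pat further down is updated: the sub-record name is now derived from t.PtxType behind an underscore. A compressed sketch of just the naming follows, with stand-in classes; the real I3Inst also adds the rr/ri forms and the selection patterns.

class InstSketch<string asm> { string AsmString = asm; }
class TySketch<string ptx> { string PtxType = ptx; }
def PredT : TySketch<"pred">;
def B32T  : TySketch<"b32">;

multiclass BitwiseSketch<string op> {
  foreach t = [PredT, B32T] in
    def _ # t.PtxType # "rr" : InstSketch<op # "." # t.PtxType>;
}

defm AND_SKETCH : BitwiseSketch<"and">;
// Yields AND_SKETCH_predrr and AND_SKETCH_b32rr; the real defm AND therefore
// exposes AND_predrr, AND_b32rr, AND_b32ri, ... as referenced below.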
@@ -1269,48 +1117,40 @@ defm AND : BITWISE<"and", and>;
defm XOR : BITWISE<"xor", xor>;
// PTX does not support mul on predicates, convert to and instructions
-def : Pat<(mul i1:$a, i1:$b), (ANDb1rr $a, $b)>;
-def : Pat<(mul i1:$a, imm:$b), (ANDb1ri $a, imm:$b)>;
+def : Pat<(mul i1:$a, i1:$b), (AND_predrr $a, $b)>;
+def : Pat<(mul i1:$a, imm:$b), (AND_predri $a, imm:$b)>;
foreach op = [add, sub] in {
- def : Pat<(op i1:$a, i1:$b), (XORb1rr $a, $b)>;
- def : Pat<(op i1:$a, imm:$b), (XORb1ri $a, imm:$b)>;
+ def : Pat<(op i1:$a, i1:$b), (XOR_predrr $a, $b)>;
+ def : Pat<(op i1:$a, imm:$b), (XOR_predri $a, imm:$b)>;
}
// These transformations were once reliably performed by instcombine, but thanks
// to poison semantics they are no longer safe for LLVM IR, perform them here
// instead.
-def : Pat<(select i1:$a, i1:$b, 0), (ANDb1rr $a, $b)>;
-def : Pat<(select i1:$a, 1, i1:$b), (ORb1rr $a, $b)>;
+def : Pat<(select i1:$a, i1:$b, 0), (AND_predrr $a, $b)>;
+def : Pat<(select i1:$a, 1, i1:$b), (OR_predrr $a, $b)>;
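The Pat groups above encode plain i1 identities: multiplication on i1 is and, and addition or subtraction mod 2 is xor. A tiny sanity-check sketch using TableGen's bang operators (no includes needed); it is illustrative only and not part of the patch.

class BoolIdSketch<bit a, bit b> {
  bit MulIsAnd = !eq(!mul(a, b), !and(a, b));            // a * b == a & b
  bit AddIsXor = !eq(!and(!add(a, b), 1), !xor(a, b));   // (a + b) mod 2 == a ^ b
}

foreach a = [0, 1] in {
  foreach b = [0, 1] in {
    def BOOL_ID_SKETCH_ # a # b : BoolIdSketch<a, b>;
  }
}
// Every field of BOOL_ID_SKETCH_00 ... BOOL_ID_SKETCH_11 evaluates to 1.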
// Lower logical v2i16/v4i8 ops as bitwise ops on b32.
foreach vt = [v2i16, v4i8] in {
- def : Pat<(or vt:$a, vt:$b), (ORb32rr $a, $b)>;
- def : Pat<(xor vt:$a, vt:$b), (XORb32rr $a, $b)>;
- def : Pat<(and vt:$a, vt:$b), (ANDb32rr $a, $b)>;
+ def : Pat<(or vt:$a, vt:$b), (OR_b32rr $a, $b)>;
+ def : Pat<(xor vt:$a, vt:$b), (XOR_b32rr $a, $b)>;
+ def : Pat<(and vt:$a, vt:$b), (AND_b32rr $a, $b)>;
// The constants get legalized into a bitcast from i32, so that's what we need
// to match here.
def: Pat<(or vt:$a, (vt (bitconvert (i32 imm:$b)))),
- (ORb32ri $a, imm:$b)>;
+ (OR_b32ri $a, imm:$b)>;
def: Pat<(xor vt:$a, (vt (bitconvert (i32 imm:$b)))),
- (XORb32ri $a, imm:$b)>;
+ (XOR_b32ri $a, imm:$b)>;
def: Pat<(and vt:$a, (vt (bitconvert (i32 imm:$b)))),
- (ANDb32ri $a, imm:$b)>;
-}
-
-def NOT1 : BasicNVPTXInst<(outs B1:$dst), (ins B1:$src),
- "not.pred",
- [(set i1:$dst, (not i1:$src))]>;
-def NOT16 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src),
- "not.b16",
- [(set i16:$dst, (not i16:$src))]>;
-def NOT32 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src),
- "not.b32",
- [(set i32:$dst, (not i32:$src))]>;
-def NOT64 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src),
- "not.b64",
- [(set i64:$dst, (not i64:$src))]>;
+ (AND_b32ri $a, imm:$b)>;
+}
+
+foreach t = [I1RT, I16RT, I32RT, I64RT] in
+ def NOT_ # t.PtxType : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
+ "not." # t.PtxType,
+ [(set t.Ty:$dst, (not t.Ty:$src))]>;
// Template for left/right shifts. Takes three operands,
// [dest (reg), src (reg), shift (reg or imm)].
@@ -1318,34 +1158,22 @@ def NOT64 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src),
//
// This template also defines a 32-bit shift (imm, imm) instruction.
multiclass SHIFT<string OpcStr, SDNode OpNode> {
- def i64rr :
- BasicNVPTXInst<(outs B64:$dst), (ins B64:$a, B32:$b),
- OpcStr # "64",
- [(set i64:$dst, (OpNode i64:$a, i32:$b))]>;
- def i64ri :
- BasicNVPTXInst<(outs B64:$dst), (ins B64:$a, i32imm:$b),
- OpcStr # "64",
- [(set i64:$dst, (OpNode i64:$a, (i32 imm:$b)))]>;
- def i32rr :
- BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
- OpcStr # "32",
- [(set i32:$dst, (OpNode i32:$a, i32:$b))]>;
- def i32ri :
- BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, i32imm:$b),
- OpcStr # "32",
- [(set i32:$dst, (OpNode i32:$a, (i32 imm:$b)))]>;
- def i32ii :
- BasicNVPTXInst<(outs B32:$dst), (ins i32imm:$a, i32imm:$b),
- OpcStr # "32",
- [(set i32:$dst, (OpNode (i32 imm:$a), (i32 imm:$b)))]>;
- def i16rr :
- BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B32:$b),
- OpcStr # "16",
- [(set i16:$dst, (OpNode i16:$a, i32:$b))]>;
- def i16ri :
- BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, i32imm:$b),
- OpcStr # "16",
- [(set i16:$dst, (OpNode i16:$a, (i32 imm:$b)))]>;
+ let hasSideEffects = false in {
+ foreach t = [I64RT, I32RT, I16RT] in {
+ def t.Size # _rr :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, B32:$b),
+ OpcStr # t.Size,
+ [(set t.Ty:$dst, (OpNode t.Ty:$a, i32:$b))]>;
+ def t.Size # _ri :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, i32imm:$b),
+ OpcStr # t.Size,
+ [(set t.Ty:$dst, (OpNode t.Ty:$a, (i32 imm:$b)))]>;
+ def t.Size # _ii :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, i32imm:$b),
+ OpcStr # t.Size,
+ [(set t.Ty:$dst, (OpNode (t.Ty imm:$a), (i32 imm:$b)))]>;
+ }
+ }
}
defm SHL : SHIFT<"shl.b", shl>;
@@ -1353,14 +1181,11 @@ defm SRA : SHIFT<"shr.s", sra>;
defm SRL : SHIFT<"shr.u", srl>;
// Bit-reverse
-def BREV32 :
- BasicNVPTXInst<(outs B32:$dst), (ins B32:$a),
- "brev.b32",
- [(set i32:$dst, (bitreverse i32:$a))]>;
-def BREV64 :
- BasicNVPTXInst<(outs B64:$dst), (ins B64:$a),
- "brev.b64",
- [(set i64:$dst, (bitreverse i64:$a))]>;
+foreach t = [I64RT, I32RT] in
+ def BREV_ # t.PtxType :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a),
+ "brev." # t.PtxType,
+ [(set t.Ty:$dst, (bitreverse t.Ty:$a))]>;
//
@@ -1512,20 +1337,19 @@ def : Pat<(i16 (sext_inreg (trunc (prmt i32:$s, 0, byte_extract_prmt:$sel, PrmtN
// Byte extraction via shift/trunc/sext
-def : Pat<(i16 (sext_inreg (trunc i32:$s), i8)),
- (CVT_s8_s32 $s, CvtNONE)>;
-def : Pat<(i16 (sext_inreg (trunc (srl i32:$s, (i32 imm:$o))), i8)),
+def : Pat<(i16 (sext_inreg (trunc i32:$s), i8)), (CVT_s8_s32 $s, CvtNONE)>;
+def : Pat<(i16 (sext_inreg (trunc i64:$s), i8)), (CVT_s8_s64 $s, CvtNONE)>;
+
+def : Pat<(sext_inreg (srl i32:$s, (i32 imm:$o)), i8), (BFE_S32rii $s, imm:$o, 8)>;
+def : Pat<(sext_inreg (srl i64:$s, (i32 imm:$o)), i8), (BFE_S64rii $s, imm:$o, 8)>;
+
+def : Pat<(i16 (sext_inreg (trunc (srl i32:$s, (i32 imm:$o))), i8)),
(CVT_s8_s32 (BFE_S32rii $s, imm:$o, 8), CvtNONE)>;
-def : Pat<(sext_inreg (srl i32:$s, (i32 imm:$o)), i8),
- (BFE_S32rii $s, imm:$o, 8)>;
+def : Pat<(i16 (sext_inreg (trunc (srl i64:$s, (i32 imm:$o))), i8)),
+ (CVT_s8_s64 (BFE_S64rii $s, imm:$o, 8), CvtNONE)>;
+
def : Pat<(i16 (sra (i16 (trunc i32:$s)), (i32 8))),
(CVT_s8_s32 (BFE_S32rii $s, 8, 8), CvtNONE)>;
-def : Pat<(sext_inreg (srl i64:$s, (i32 imm:$o)), i8),
- (BFE_S64rii $s, imm:$o, 8)>;
-def : Pat<(i16 (sext_inreg (trunc i64:$s), i8)),
- (CVT_s8_s64 $s, CvtNONE)>;
-def : Pat<(i16 (sext_inreg (trunc (srl i64:$s, (i32 imm:$o))), i8)),
- (CVT_s8_s64 (BFE_S64rii $s, imm:$o, 8), CvtNONE)>;
//-----------------------------------
// Comparison instructions (setp, set)
@@ -1615,10 +1439,7 @@ def SETP_bf16x2rr :
def addr : ComplexPattern<pAny, 2, "SelectADDR">;
-def ADDR_base : Operand<pAny> {
- let PrintMethod = "printOperand";
-}
-
+def ADDR_base : Operand<pAny>;
def ADDR : Operand<pAny> {
let PrintMethod = "printMemOperand";
let MIOperandInfo = (ops ADDR_base, i32imm);
@@ -1632,10 +1453,6 @@ def MmaCode : Operand<i32> {
let PrintMethod = "printMmaCode";
}
-def Offseti32imm : Operand<i32> {
- let PrintMethod = "printOffseti32imm";
-}
-
// Get pointer to local stack.
let hasSideEffects = false in {
def MOV_DEPOT_ADDR : NVPTXInst<(outs B32:$d), (ins i32imm:$num),
@@ -1647,33 +1464,31 @@ let hasSideEffects = false in {
// copyPhysreg is hard-coded in NVPTXInstrInfo.cpp
let hasSideEffects = false, isAsCheapAsAMove = true in {
- // Class for register-to-register moves
- class MOVr<RegisterClass RC, string OpStr> :
- BasicNVPTXInst<(outs RC:$dst), (ins RC:$src),
- "mov." # OpStr>;
-
- // Class for immediate-to-register moves
- class MOVi<RegisterClass RC, string OpStr, ValueType VT, Operand IMMType, SDNode ImmNode> :
- BasicNVPTXInst<(outs RC:$dst), (ins IMMType:$src),
- "mov." # OpStr,
- [(set VT:$dst, ImmNode:$src)]>;
-}
+ let isMoveReg = true in
+ class MOVr<RegisterClass RC, string OpStr> :
+ BasicNVPTXInst<(outs RC:$dst), (ins RC:$src), "mov." # OpStr>;
-def IMOV1r : MOVr<B1, "pred">;
-def MOV16r : MOVr<B16, "b16">;
-def IMOV32r : MOVr<B32, "b32">;
-def IMOV64r : MOVr<B64, "b64">;
-def IMOV128r : MOVr<B128, "b128">;
+ let isMoveImm = true in
+ class MOVi<RegTyInfo t, string suffix> :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.Imm:$src),
+ "mov." # suffix,
+ [(set t.Ty:$dst, t.ImmNode:$src)]>;
+}
+def MOV_B1_r : MOVr<B1, "pred">;
+def MOV_B16_r : MOVr<B16, "b16">;
+def MOV_B32_r : MOVr<B32, "b32">;
+def MOV_B64_r : MOVr<B64, "b64">;
+def MOV_B128_r : MOVr<B128, "b128">;
-def IMOV1i : MOVi<B1, "pred", i1, i1imm, imm>;
-def IMOV16i : MOVi<B16, "b16", i16, i16imm, imm>;
-def IMOV32i : MOVi<B32, "b32", i32, i32imm, imm>;
-def IMOV64i : MOVi<B64, "b64", i64, i64imm, imm>;
-def FMOV16i : MOVi<B16, "b16", f16, f16imm, fpimm>;
-def BFMOV16i : MOVi<B16, "b16", bf16, bf16imm, fpimm>;
-def FMOV32i : MOVi<B32, "b32", f32, f32imm, fpimm>;
-def FMOV64i : MOVi<B64, "b64", f64, f64imm, fpimm>;
+def MOV_B1_i : MOVi<I1RT, "pred">;
+def MOV_B16_i : MOVi<I16RT, "b16">;
+def MOV_B32_i : MOVi<I32RT, "b32">;
+def MOV_B64_i : MOVi<I64RT, "b64">;
+def MOV_F16_i : MOVi<F16RT, "b16">;
+def MOV_BF16_i : MOVi<BF16RT, "b16">;
+def MOV_F32_i : MOVi<F32RT, "b32">;
+def MOV_F64_i : MOVi<F64RT, "b64">;
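The MOVi rewrite above is why the separate IMOV*/FMOV*/BFMOV* immediate moves collapse into the MOV_*_i family referenced below: the RegTyInfo argument bundles the register class, value type, immediate operand, and the matching pattern leaf (imm for integers, fpimm for floats), so F16RT/BF16RT/F32RT/F64RT reuse the same class. A string-level sketch of that bundling; field names echo the t.Imm/t.ImmNode uses above, everything else is a stand-in.

class TyBundleSketch<string imm_op, string imm_node> {
  string Imm     = imm_op;    // immediate operand, e.g. "i32imm" or "f32imm"
  string ImmNode = imm_node;  // pattern leaf: "imm" for ints, "fpimm" for FP
}
def I32Sketch : TyBundleSketch<"i32imm", "imm">;
def F32Sketch : TyBundleSketch<"f32imm", "fpimm">;

class MoviSketch<TyBundleSketch t, string suffix> {
  string AsmString   = "mov." # suffix;
  string PatternLeaf = t.ImmNode # ":$src";
}
def MOV_B32_i_SKETCH : MoviSketch<I32Sketch, "b32">;
def MOV_F32_i_SKETCH : MoviSketch<F32Sketch, "b32">;
// Both print mov.b32; only the matched immediate leaf differs.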
def to_tglobaladdr : SDNodeXForm<globaladdr, [{
@@ -1691,11 +1506,11 @@ def to_tframeindex : SDNodeXForm<frameindex, [{
return CurDAG->getTargetFrameIndex(N->getIndex(), N->getValueType(0));
}]>;
-def : Pat<(i32 globaladdr:$dst), (IMOV32i (to_tglobaladdr $dst))>;
-def : Pat<(i64 globaladdr:$dst), (IMOV64i (to_tglobaladdr $dst))>;
+def : Pat<(i32 globaladdr:$dst), (MOV_B32_i (to_tglobaladdr $dst))>;
+def : Pat<(i64 globaladdr:$dst), (MOV_B64_i (to_tglobaladdr $dst))>;
-def : Pat<(i32 externalsym:$dst), (IMOV32i (to_texternsym $dst))>;
-def : Pat<(i64 externalsym:$dst), (IMOV64i (to_texternsym $dst))>;
+def : Pat<(i32 externalsym:$dst), (MOV_B32_i (to_texternsym $dst))>;
+def : Pat<(i64 externalsym:$dst), (MOV_B64_i (to_texternsym $dst))>;
//---- Copy Frame Index ----
def LEA_ADDRi : NVPTXInst<(outs B32:$dst), (ins ADDR:$addr),
@@ -1709,56 +1524,39 @@ def : Pat<(i64 frameindex:$fi), (LEA_ADDRi64 (to_tframeindex $fi), 0)>;
//-----------------------------------
// Comparison and Selection
//-----------------------------------
+// TODO: These patterns seem very specific and brittle. We should try to find
+// a more general solution.
def cond_signed : PatLeaf<(cond), [{
return isSignedIntSetCC(N->get());
}]>;
-def cond_not_signed : PatLeaf<(cond), [{
- return !isSignedIntSetCC(N->get());
-}]>;
+// A 16-bit signed comparison of sign-extended byte extracts can be converted
+// to a 32-bit comparison if we change the PRMT to sign-extend the extracted
+// bytes.
+def : Pat<(setcc (i16 (sext_inreg (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE)), i8)),
+ (i16 (sext_inreg (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE)), i8)),
+ cond_signed:$cc),
+ (SETP_i32rr (PRMT_B32rii i32:$a, 0, (to_sign_extend_selector $sel_a), PrmtNONE),
+ (PRMT_B32rii i32:$b, 0, (to_sign_extend_selector $sel_b), PrmtNONE),
+ (cond2cc $cc))>;
+
+// A 16-bit comparison of truncated byte extracts can be converted to a 32-bit
+// comparison because we know that the truncate is just truncating off zeros
+// and that the most-significant byte is also zero, so the meaning of signed and
+// unsigned comparisons will not be changed.
+def : Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))),
+ (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))),
+ cond:$cc),
+ (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
+ (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
+ (cond2cc $cc))>;
-// comparisons of i8 extracted with PRMT as i32
-// It's faster to do comparison directly on i32 extracted by PRMT,
-// instead of the long conversion and sign extending.
-def: Pat<(setcc (i16 (sext_inreg (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))), i8)),
- (i16 (sext_inreg (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))), i8)),
- cond_signed:$cc),
- (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
- (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
- (cond2cc $cc))>;
-
-def: Pat<(setcc (i16 (sext_inreg (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE)), i8)),
- (i16 (sext_inreg (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE)), i8)),
- cond_signed:$cc),
- (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
- (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
- (cond2cc $cc))>;
-
-def: Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))),
- (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))),
- cond_signed:$cc),
- (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
- (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
- (cond2cc $cc))>;
-
-def: Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))),
- (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))),
- cond_not_signed:$cc),
- (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
- (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
- (cond2cc $cc))>;
def SDTDeclareArrayParam :
SDTypeProfile<0, 3, [SDTCisVT<0, i32>, SDTCisVT<1, i32>, SDTCisVT<2, i32>]>;
def SDTDeclareScalarParam :
SDTypeProfile<0, 2, [SDTCisVT<0, i32>, SDTCisVT<1, i32>]>;
-def SDTLoadParamProfile : SDTypeProfile<1, 2, [SDTCisInt<1>, SDTCisInt<2>]>;
-def SDTLoadParamV2Profile : SDTypeProfile<2, 2, [SDTCisSameAs<0, 1>, SDTCisInt<2>, SDTCisInt<3>]>;
-def SDTLoadParamV4Profile : SDTypeProfile<4, 2, [SDTCisInt<4>, SDTCisInt<5>]>;
-def SDTStoreParamProfile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>;
-def SDTStoreParamV2Profile : SDTypeProfile<0, 4, [SDTCisInt<0>, SDTCisInt<1>]>;
-def SDTStoreParamV4Profile : SDTypeProfile<0, 6, [SDTCisInt<0>, SDTCisInt<1>]>;
def SDTMoveParamProfile : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisSameAs<0, 1>]>;
def SDTProxyReg : SDTypeProfile<1, 1, [SDTCisSameAs<0, 1>]>;
@@ -1770,104 +1568,20 @@ def declare_array_param :
def declare_scalar_param :
SDNode<"NVPTXISD::DeclareScalarParam", SDTDeclareScalarParam,
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
-
-def LoadParam :
- SDNode<"NVPTXISD::LoadParam", SDTLoadParamProfile,
- [SDNPHasChain, SDNPMayLoad, SDNPOutGlue, SDNPInGlue]>;
-def LoadParamV2 :
- SDNode<"NVPTXISD::LoadParamV2", SDTLoadParamV2Profile,
- [SDNPHasChain, SDNPMayLoad, SDNPOutGlue, SDNPInGlue]>;
-def LoadParamV4 :
- SDNode<"NVPTXISD::LoadParamV4", SDTLoadParamV4Profile,
- [SDNPHasChain, SDNPMayLoad, SDNPOutGlue, SDNPInGlue]>;
-def StoreParam :
- SDNode<"NVPTXISD::StoreParam", SDTStoreParamProfile,
- [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
-def StoreParamV2 :
- SDNode<"NVPTXISD::StoreParamV2", SDTStoreParamV2Profile,
- [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
-def StoreParamV4 :
- SDNode<"NVPTXISD::StoreParamV4", SDTStoreParamV4Profile,
- [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
def MoveParam :
SDNode<"NVPTXISD::MoveParam", SDTMoveParamProfile, []>;
def proxy_reg :
SDNode<"NVPTXISD::ProxyReg", SDTProxyReg, [SDNPHasChain]>;
/// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
- /// NumParams, Callee, Proto, InGlue)
+ /// NumParams, Callee, Proto)
def SDTCallProfile : SDTypeProfile<0, 6,
[SDTCisVT<0, i32>, SDTCisVT<1, i32>, SDTCisVT<2, i32>,
SDTCisVT<3, i32>, SDTCisVT<5, i32>]>;
-def call :
- SDNode<"NVPTXISD::CALL", SDTCallProfile,
- [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
-
-let mayLoad = true in {
- class LoadParamMemInst<NVPTXRegClass regclass, string opstr> :
- NVPTXInst<(outs regclass:$dst), (ins Offseti32imm:$b),
- !strconcat("ld.param", opstr, " \t$dst, [retval0$b];"),
- []>;
-
- class LoadParamV2MemInst<NVPTXRegClass regclass, string opstr> :
- NVPTXInst<(outs regclass:$dst, regclass:$dst2), (ins Offseti32imm:$b),
- !strconcat("ld.param.v2", opstr,
- " \t{{$dst, $dst2}}, [retval0$b];"), []>;
-
- class LoadParamV4MemInst<NVPTXRegClass regclass, string opstr> :
- NVPTXInst<(outs regclass:$dst, regclass:$dst2, regclass:$dst3,
- regclass:$dst4),
- (ins Offseti32imm:$b),
- !strconcat("ld.param.v4", opstr,
- " \t{{$dst, $dst2, $dst3, $dst4}}, [retval0$b];"),
- []>;
-}
-
-let mayStore = true in {
-
- multiclass StoreParamInst<NVPTXRegClass regclass, Operand IMMType, string opstr, bit support_imm = true> {
- foreach op = [IMMType, regclass] in
- if !or(support_imm, !isa<NVPTXRegClass>(op)) then
- def _ # !if(!isa<NVPTXRegClass>(op), "r", "i")
- : NVPTXInst<(outs),
- (ins op:$val, i32imm:$a, Offseti32imm:$b),
- "st.param" # opstr # " \t[param$a$b], $val;",
- []>;
- }
-
- multiclass StoreParamV2Inst<NVPTXRegClass regclass, Operand IMMType, string opstr> {
- foreach op1 = [IMMType, regclass] in
- foreach op2 = [IMMType, regclass] in
- def _ # !if(!isa<NVPTXRegClass>(op1), "r", "i")
- # !if(!isa<NVPTXRegClass>(op2), "r", "i")
- : NVPTXInst<(outs),
- (ins op1:$val1, op2:$val2,
- i32imm:$a, Offseti32imm:$b),
- "st.param.v2" # opstr # " \t[param$a$b], {{$val1, $val2}};",
- []>;
- }
-
- multiclass StoreParamV4Inst<NVPTXRegClass regclass, Operand IMMType, string opstr> {
- foreach op1 = [IMMType, regclass] in
- foreach op2 = [IMMType, regclass] in
- foreach op3 = [IMMType, regclass] in
- foreach op4 = [IMMType, regclass] in
- def _ # !if(!isa<NVPTXRegClass>(op1), "r", "i")
- # !if(!isa<NVPTXRegClass>(op2), "r", "i")
- # !if(!isa<NVPTXRegClass>(op3), "r", "i")
- # !if(!isa<NVPTXRegClass>(op4), "r", "i")
-
- : NVPTXInst<(outs),
- (ins op1:$val1, op2:$val2, op3:$val3, op4:$val4,
- i32imm:$a, Offseti32imm:$b),
- "st.param.v4" # opstr #
- " \t[param$a$b], {{$val1, $val2, $val3, $val4}};",
- []>;
- }
-}
+def call : SDNode<"NVPTXISD::CALL", SDTCallProfile, [SDNPHasChain, SDNPSideEffect]>;
/// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
-/// NumParams, Callee, Proto, InGlue)
+/// NumParams, Callee, Proto)
def CallOperand : Operand<i32> { let PrintMethod = "printCallOperand"; }
@@ -1904,43 +1618,6 @@ foreach is_convergent = [0, 1] in {
(call_uni_inst $addr, imm:$rets, imm:$params)>;
}
-def LoadParamMemI64 : LoadParamMemInst<B64, ".b64">;
-def LoadParamMemI32 : LoadParamMemInst<B32, ".b32">;
-def LoadParamMemI16 : LoadParamMemInst<B16, ".b16">;
-def LoadParamMemI8 : LoadParamMemInst<B16, ".b8">;
-def LoadParamMemV2I64 : LoadParamV2MemInst<B64, ".b64">;
-def LoadParamMemV2I32 : LoadParamV2MemInst<B32, ".b32">;
-def LoadParamMemV2I16 : LoadParamV2MemInst<B16, ".b16">;
-def LoadParamMemV2I8 : LoadParamV2MemInst<B16, ".b8">;
-def LoadParamMemV4I32 : LoadParamV4MemInst<B32, ".b32">;
-def LoadParamMemV4I16 : LoadParamV4MemInst<B16, ".b16">;
-def LoadParamMemV4I8 : LoadParamV4MemInst<B16, ".b8">;
-
-defm StoreParamI64 : StoreParamInst<B64, i64imm, ".b64">;
-defm StoreParamI32 : StoreParamInst<B32, i32imm, ".b32">;
-defm StoreParamI16 : StoreParamInst<B16, i16imm, ".b16">;
-defm StoreParamI8 : StoreParamInst<B16, i8imm, ".b8">;
-
-defm StoreParamI8TruncI32 : StoreParamInst<B32, i8imm, ".b8", /* support_imm */ false>;
-defm StoreParamI8TruncI64 : StoreParamInst<B64, i8imm, ".b8", /* support_imm */ false>;
-
-defm StoreParamV2I64 : StoreParamV2Inst<B64, i64imm, ".b64">;
-defm StoreParamV2I32 : StoreParamV2Inst<B32, i32imm, ".b32">;
-defm StoreParamV2I16 : StoreParamV2Inst<B16, i16imm, ".b16">;
-defm StoreParamV2I8 : StoreParamV2Inst<B16, i8imm, ".b8">;
-
-defm StoreParamV4I32 : StoreParamV4Inst<B32, i32imm, ".b32">;
-defm StoreParamV4I16 : StoreParamV4Inst<B16, i16imm, ".b16">;
-defm StoreParamV4I8 : StoreParamV4Inst<B16, i8imm, ".b8">;
-
-defm StoreParamF32 : StoreParamInst<B32, f32imm, ".b32">;
-defm StoreParamF64 : StoreParamInst<B64, f64imm, ".b64">;
-
-defm StoreParamV2F32 : StoreParamV2Inst<B32, f32imm, ".b32">;
-defm StoreParamV2F64 : StoreParamV2Inst<B64, f64imm, ".b64">;
-
-defm StoreParamV4F32 : StoreParamV4Inst<B32, f32imm, ".b32">;
-
def DECLARE_PARAM_array :
NVPTXInst<(outs), (ins i32imm:$a, i32imm:$align, i32imm:$size),
".param .align $align .b8 \t$a[$size];", []>;
@@ -1953,6 +1630,18 @@ def : Pat<(declare_array_param externalsym:$a, imm:$align, imm:$size),
def : Pat<(declare_scalar_param externalsym:$a, imm:$size),
(DECLARE_PARAM_scalar (to_texternsym $a), imm:$size)>;
+// Call prototype wrapper: this is a dummy instruction that just prints its
+// operand, which is a string defining the prototype.
+def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
+def CallPrototype :
+ SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype,
+ [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def ProtoIdent : Operand<i32> { let PrintMethod = "printProtoIdent"; }
+def CALL_PROTOTYPE :
+ NVPTXInst<(outs), (ins ProtoIdent:$ident),
+ "$ident", [(CallPrototype (i32 texternalsym:$ident))]>;
+
+
foreach t = [I32RT, I64RT] in {
defvar inst_name = "MOV" # t.Size # "_PARAM";
def inst_name : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), "mov.b" # t.Size>;
@@ -1972,6 +1661,32 @@ defm ProxyRegB16 : ProxyRegInst<"b16", B16>;
defm ProxyRegB32 : ProxyRegInst<"b32", B32>;
defm ProxyRegB64 : ProxyRegInst<"b64", B64>;
+
+// Callseq start and end
+
+// Note: these nodes are marked as SDNPMayStore and SDNPMayLoad because
+// they define the scope in which the declared params may be used. Therefore
+// we add these flags to ensure ld.param and st.param are not sunk or hoisted
+// out of that scope.
+
+def callseq_start : SDNode<"ISD::CALLSEQ_START",
+ SDCallSeqStart<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>,
+ [SDNPHasChain, SDNPOutGlue,
+ SDNPSideEffect, SDNPMayStore, SDNPMayLoad]>;
+def callseq_end : SDNode<"ISD::CALLSEQ_END",
+ SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>,
+ [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
+ SDNPSideEffect, SDNPMayStore, SDNPMayLoad]>;
+
+def Callseq_Start :
+ NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
+ "\\{ // callseq $amt1, $amt2",
+ [(callseq_start timm:$amt1, timm:$amt2)]>;
+def Callseq_End :
+ NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
+ "\\} // callseq $amt1",
+ [(callseq_end timm:$amt1, timm:$amt2)]>;
+
//
// Load / Store Handling
//
@@ -1984,7 +1699,6 @@ class LD<NVPTXRegClass regclass>
"\t$dst, [$addr];", []>;
let mayLoad=1, hasSideEffects=0 in {
- def LD_i8 : LD<B16>;
def LD_i16 : LD<B16>;
def LD_i32 : LD<B32>;
def LD_i64 : LD<B64>;
@@ -2000,7 +1714,6 @@ class ST<DAGOperand O>
" \t[$addr], $src;", []>;
let mayStore=1, hasSideEffects=0 in {
- def ST_i8 : ST<RI16>;
def ST_i16 : ST<RI16>;
def ST_i32 : ST<RI32>;
def ST_i64 : ST<RI64>;
@@ -2033,7 +1746,6 @@ multiclass LD_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
"[$addr];", []>;
}
let mayLoad=1, hasSideEffects=0 in {
- defm LDV_i8 : LD_VEC<B16>;
defm LDV_i16 : LD_VEC<B16>;
defm LDV_i32 : LD_VEC<B32, support_v8 = true>;
defm LDV_i64 : LD_VEC<B64>;
@@ -2067,7 +1779,6 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
}
let mayStore=1, hasSideEffects=0 in {
- defm STV_i8 : ST_VEC<RI16>;
defm STV_i16 : ST_VEC<RI16>;
defm STV_i32 : ST_VEC<RI32, support_v8 = true>;
defm STV_i64 : ST_VEC<RI64>;
@@ -2237,14 +1948,14 @@ def : Pat<(i64 (anyext i32:$a)), (CVT_u64_u32 $a, CvtNONE)>;
// truncate i64
def : Pat<(i32 (trunc i64:$a)), (CVT_u32_u64 $a, CvtNONE)>;
def : Pat<(i16 (trunc i64:$a)), (CVT_u16_u64 $a, CvtNONE)>;
-def : Pat<(i1 (trunc i64:$a)), (SETP_i64ri (ANDb64ri $a, 1), 0, CmpNE)>;
+def : Pat<(i1 (trunc i64:$a)), (SETP_i64ri (AND_b64ri $a, 1), 0, CmpNE)>;
// truncate i32
def : Pat<(i16 (trunc i32:$a)), (CVT_u16_u32 $a, CvtNONE)>;
-def : Pat<(i1 (trunc i32:$a)), (SETP_i32ri (ANDb32ri $a, 1), 0, CmpNE)>;
+def : Pat<(i1 (trunc i32:$a)), (SETP_i32ri (AND_b32ri $a, 1), 0, CmpNE)>;
// truncate i16
-def : Pat<(i1 (trunc i16:$a)), (SETP_i16ri (ANDb16ri $a, 1), 0, CmpNE)>;
+def : Pat<(i1 (trunc i16:$a)), (SETP_i16ri (AND_b16ri $a, 1), 0, CmpNE)>;
// sext_inreg
def : Pat<(sext_inreg i16:$a, i8), (CVT_INREG_s16_s8 $a)>;
@@ -2488,52 +2199,20 @@ defm : CVT_ROUND<frint, CvtRNI, CvtRNI_FTZ>;
//-----------------------------------
let isTerminator=1 in {
- let isReturn=1, isBarrier=1 in
+ let isReturn=1, isBarrier=1 in
def Return : BasicNVPTXInst<(outs), (ins), "ret", [(retglue)]>;
- let isBranch=1 in
- def CBranch : NVPTXInst<(outs), (ins B1:$a, brtarget:$target),
+ let isBranch=1 in {
+ def CBranch : NVPTXInst<(outs), (ins B1:$a, brtarget:$target),
"@$a bra \t$target;",
[(brcond i1:$a, bb:$target)]>;
- let isBranch=1 in
- def CBranchOther : NVPTXInst<(outs), (ins B1:$a, brtarget:$target),
- "@!$a bra \t$target;", []>;
- let isBranch=1, isBarrier=1 in
+ let isBarrier=1 in
def GOTO : BasicNVPTXInst<(outs), (ins brtarget:$target),
- "bra.uni", [(br bb:$target)]>;
+ "bra.uni", [(br bb:$target)]>;
+ }
}
-def : Pat<(brcond i32:$a, bb:$target),
- (CBranch (SETP_i32ri $a, 0, CmpNE), bb:$target)>;
-
-// SelectionDAGBuilder::visitSWitchCase() will invert the condition of a
-// conditional branch if the target block is the next block so that the code
-// can fall through to the target block. The inversion is done by 'xor
-// condition, 1', which will be translated to (setne condition, -1). Since ptx
-// supports '@!pred bra target', we should use it.
-def : Pat<(brcond (i1 (setne i1:$a, -1)), bb:$target),
- (CBranchOther $a, bb:$target)>;
-
-// Call
-def SDT_NVPTXCallSeqStart : SDCallSeqStart<[SDTCisVT<0, i32>,
- SDTCisVT<1, i32>]>;
-def SDT_NVPTXCallSeqEnd : SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>;
-
-def callseq_start : SDNode<"ISD::CALLSEQ_START", SDT_NVPTXCallSeqStart,
- [SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>;
-def callseq_end : SDNode<"ISD::CALLSEQ_END", SDT_NVPTXCallSeqEnd,
- [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
- SDNPSideEffect]>;
-
-def Callseq_Start :
- NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
- "\\{ // callseq $amt1, $amt2",
- [(callseq_start timm:$amt1, timm:$amt2)]>;
-def Callseq_End :
- NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
- "\\} // callseq $amt1",
- [(callseq_end timm:$amt1, timm:$amt2)]>;
// trap instruction
def trapinst : BasicNVPTXInst<(outs), (ins), "trap", [(trap)]>, Requires<[noPTXASUnreachableBug]>;
@@ -2543,18 +2222,6 @@ def trapexitinst : NVPTXInst<(outs), (ins), "trap; exit;", [(trap)]>, Requires<[
// brkpt instruction
def debugtrapinst : BasicNVPTXInst<(outs), (ins), "brkpt", [(debugtrap)]>;
-// Call prototype wrapper
-def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
-def CallPrototype :
- SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype,
- [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
-def ProtoIdent : Operand<i32> {
- let PrintMethod = "printProtoIdent";
-}
-def CALL_PROTOTYPE :
- NVPTXInst<(outs), (ins ProtoIdent:$ident),
- "$ident", [(CallPrototype (i32 texternalsym:$ident))]>;
-
def SDTDynAllocaOp :
SDTypeProfile<1, 2, [SDTCisSameAs<0, 1>, SDTCisInt<1>, SDTCisVT<2, i32>]>;