diff options
Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXInstrInfo.td')
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 833 |
1 files changed, 250 insertions, 583 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index b5df4c6..d8047d3 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -15,19 +15,8 @@ include "NVPTXInstrFormats.td" let OperandType = "OPERAND_IMMEDIATE" in { def f16imm : Operand<f16>; def bf16imm : Operand<bf16>; - } -// List of vector specific properties -def isVecLD : VecInstTypeEnum<1>; -def isVecST : VecInstTypeEnum<2>; -def isVecBuild : VecInstTypeEnum<3>; -def isVecShuffle : VecInstTypeEnum<4>; -def isVecExtract : VecInstTypeEnum<5>; -def isVecInsert : VecInstTypeEnum<6>; -def isVecDest : VecInstTypeEnum<7>; -def isVecOther : VecInstTypeEnum<15>; - //===----------------------------------------------------------------------===// // NVPTX Operand Definitions. //===----------------------------------------------------------------------===// @@ -125,8 +114,6 @@ def doF32FTZ : Predicate<"useF32FTZ()">; def doNoF32FTZ : Predicate<"!useF32FTZ()">; def doRsqrtOpt : Predicate<"doRsqrtOpt()">; -def doMulWide : Predicate<"doMulWide">; - def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">; def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">; def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">; @@ -486,46 +473,28 @@ let hasSideEffects = false in { // takes a CvtMode immediate that defines the conversion mode to use. It can // be CvtNONE to omit a conversion mode. multiclass CVT_FROM_ALL<string ToType, RegisterClass RC, list<Predicate> Preds = []> { - def _s8 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B16:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s8">, - Requires<Preds>; - def _u8 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B16:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u8">, - Requires<Preds>; - def _s16 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B16:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s16">, - Requires<Preds>; - def _u16 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B16:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u16">, - Requires<Preds>; - def _s32 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B32:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s32">, - Requires<Preds>; - def _u32 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B32:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u32">, - Requires<Preds>; - def _s64 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B64:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".s64">, - Requires<Preds>; - def _u64 : - BasicFlagsNVPTXInst<(outs RC:$dst), - (ins B64:$src), (ins CvtMode:$mode), - "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # ".u64">, - Requires<Preds>; + foreach sign = ["s", "u"] in { + def _ # sign # "8" : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B16:$src), (ins CvtMode:$mode), + "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "8">, + Requires<Preds>; + def _ # sign # "16" : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B16:$src), (ins CvtMode:$mode), + "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "16">, + Requires<Preds>; + def _ # sign # "32" : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B32:$src), (ins CvtMode:$mode), + "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "32">, + Requires<Preds>; + def _ # sign # "64" : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B64:$src), (ins CvtMode:$mode), + "cvt${mode:base}${mode:ftz}${mode:sat}." # ToType # "." # sign # "64">, + Requires<Preds>; + } def _f16 : BasicFlagsNVPTXInst<(outs RC:$dst), (ins B16:$src), (ins CvtMode:$mode), @@ -556,14 +525,12 @@ let hasSideEffects = false in { } // Generate cvts from all types to all types. - defm CVT_s8 : CVT_FROM_ALL<"s8", B16>; - defm CVT_u8 : CVT_FROM_ALL<"u8", B16>; - defm CVT_s16 : CVT_FROM_ALL<"s16", B16>; - defm CVT_u16 : CVT_FROM_ALL<"u16", B16>; - defm CVT_s32 : CVT_FROM_ALL<"s32", B32>; - defm CVT_u32 : CVT_FROM_ALL<"u32", B32>; - defm CVT_s64 : CVT_FROM_ALL<"s64", B64>; - defm CVT_u64 : CVT_FROM_ALL<"u64", B64>; + foreach sign = ["s", "u"] in { + defm CVT_ # sign # "8" : CVT_FROM_ALL<sign # "8", B16>; + defm CVT_ # sign # "16" : CVT_FROM_ALL<sign # "16", B16>; + defm CVT_ # sign # "32" : CVT_FROM_ALL<sign # "32", B32>; + defm CVT_ # sign # "64" : CVT_FROM_ALL<sign # "64", B64>; + } defm CVT_f16 : CVT_FROM_ALL<"f16", B16>; defm CVT_bf16 : CVT_FROM_ALL<"bf16", B16, [hasPTX<78>, hasSM<90>]>; defm CVT_f32 : CVT_FROM_ALL<"f32", B32>; @@ -571,18 +538,12 @@ let hasSideEffects = false in { // These cvts are different from those above: The source and dest registers // are of the same type. - def CVT_INREG_s16_s8 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src), - "cvt.s16.s8">; - def CVT_INREG_s32_s8 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), - "cvt.s32.s8">; - def CVT_INREG_s32_s16 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), - "cvt.s32.s16">; - def CVT_INREG_s64_s8 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), - "cvt.s64.s8">; - def CVT_INREG_s64_s16 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), - "cvt.s64.s16">; - def CVT_INREG_s64_s32 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), - "cvt.s64.s32">; + def CVT_INREG_s16_s8 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src), "cvt.s16.s8">; + def CVT_INREG_s32_s8 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "cvt.s32.s8">; + def CVT_INREG_s32_s16 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "cvt.s32.s16">; + def CVT_INREG_s64_s8 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s8">; + def CVT_INREG_s64_s16 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s16">; + def CVT_INREG_s64_s32 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), "cvt.s64.s32">; multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> { def _f32 : @@ -784,7 +745,7 @@ defm SUB : I3<"sub.s", sub, commutative = false>; def ADD16x2 : I16x2<"add.s", add>; -// in32 and int64 addition and subtraction with carry-out. +// int32 and int64 addition and subtraction with carry-out. defm ADDCC : ADD_SUB_INT_CARRY<"add.cc", addc, commutative = true>; defm SUBCC : ADD_SUB_INT_CARRY<"sub.cc", subc, commutative = false>; @@ -805,17 +766,17 @@ defm UDIV : I3<"div.u", udiv, commutative = false>; defm SREM : I3<"rem.s", srem, commutative = false>; defm UREM : I3<"rem.u", urem, commutative = false>; -// Integer absolute value. NumBits should be one minus the bit width of RC. -// This idiom implements the algorithm at -// http://graphics.stanford.edu/~seander/bithacks.html#IntegerAbs. -multiclass ABS<ValueType T, RegisterClass RC, string SizeName> { - def : BasicNVPTXInst<(outs RC:$dst), (ins RC:$a), - "abs" # SizeName, - [(set T:$dst, (abs T:$a))]>; +foreach t = [I16RT, I32RT, I64RT] in { + def ABS_S # t.Size : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a), + "abs.s" # t.Size, + [(set t.Ty:$dst, (abs t.Ty:$a))]>; + + def NEG_S # t.Size : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), + "neg.s" # t.Size, + [(set t.Ty:$dst, (ineg t.Ty:$src))]>; } -defm ABS_16 : ABS<i16, B16, ".s16">; -defm ABS_32 : ABS<i32, B32, ".s32">; -defm ABS_64 : ABS<i64, B64, ".s64">; // Integer min/max. defm SMAX : I3<"max.s", smax, commutative = true>; @@ -832,170 +793,63 @@ def UMIN16x2 : I16x2<"min.u", umin>; // // Wide multiplication // -def MULWIDES64 : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.s32">; -def MULWIDES64Imm : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.s32">; -def MULWIDES64Imm64 : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i64imm:$b), "mul.wide.s32">; - -def MULWIDEU64 : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, B32:$b), "mul.wide.u32">; -def MULWIDEU64Imm : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i32imm:$b), "mul.wide.u32">; -def MULWIDEU64Imm64 : - BasicNVPTXInst<(outs B64:$dst), (ins B32:$a, i64imm:$b), "mul.wide.u32">; - -def MULWIDES32 : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.s16">; -def MULWIDES32Imm : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.s16">; -def MULWIDES32Imm32 : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.s16">; - -def MULWIDEU32 : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b), "mul.wide.u16">; -def MULWIDEU32Imm : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i16imm:$b), "mul.wide.u16">; -def MULWIDEU32Imm32 : - BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, i32imm:$b), "mul.wide.u16">; - -def SDTMulWide : SDTypeProfile<1, 2, [SDTCisSameAs<1, 2>]>; -def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>; -def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>; - -// Matchers for signed, unsigned mul.wide ISD nodes. -let Predicates = [doMulWide] in { - def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>; - def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>; - def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>; - def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)), (MULWIDEU32Imm $a, imm:$b)>; - - def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)), (MULWIDES64 $a, $b)>; - def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)), (MULWIDES64Imm $a, imm:$b)>; - def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)), (MULWIDEU64 $a, $b)>; - def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>; -} - -// Predicates used for converting some patterns to mul.wide. -def SInt32Const : PatLeaf<(imm), [{ - const APInt &v = N->getAPIntValue(); - return v.isSignedIntN(32); -}]>; - -def UInt32Const : PatLeaf<(imm), [{ - const APInt &v = N->getAPIntValue(); - return v.isIntN(32); -}]>; -def SInt16Const : PatLeaf<(imm), [{ - const APInt &v = N->getAPIntValue(); - return v.isSignedIntN(16); -}]>; - -def UInt16Const : PatLeaf<(imm), [{ - const APInt &v = N->getAPIntValue(); - return v.isIntN(16); -}]>; - -def IntConst_0_30 : PatLeaf<(imm), [{ - // Check if 0 <= v < 31; only then will the result of (x << v) be an int32. - const APInt &v = N->getAPIntValue(); - return v.sge(0) && v.slt(31); -}]>; - -def IntConst_0_14 : PatLeaf<(imm), [{ - // Check if 0 <= v < 15; only then will the result of (x << v) be an int16. - const APInt &v = N->getAPIntValue(); - return v.sge(0) && v.slt(15); -}]>; - -def SHL2MUL32 : SDNodeXForm<imm, [{ - const APInt &v = N->getAPIntValue(); - APInt temp(32, 1); - return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i32); -}]>; +def SDTMulWide : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisInt<1>, SDTCisSameAs<1, 2>]>; +def smul_wide : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide, [SDNPCommutative]>; +def umul_wide : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide, [SDNPCommutative]>; -def SHL2MUL16 : SDNodeXForm<imm, [{ - const APInt &v = N->getAPIntValue(); - APInt temp(16, 1); - return CurDAG->getTargetConstant(temp.shl(v), SDLoc(N), MVT::i16); -}]>; - -// Convert "sign/zero-extend, then shift left by an immediate" to mul.wide. -let Predicates = [doMulWide] in { - def : Pat<(shl (sext i32:$a), (i32 IntConst_0_30:$b)), - (MULWIDES64Imm $a, (SHL2MUL32 $b))>; - def : Pat<(shl (zext i32:$a), (i32 IntConst_0_30:$b)), - (MULWIDEU64Imm $a, (SHL2MUL32 $b))>; - - def : Pat<(shl (sext i16:$a), (i16 IntConst_0_14:$b)), - (MULWIDES32Imm $a, (SHL2MUL16 $b))>; - def : Pat<(shl (zext i16:$a), (i16 IntConst_0_14:$b)), - (MULWIDEU32Imm $a, (SHL2MUL16 $b))>; - - // Convert "sign/zero-extend then multiply" to mul.wide. - def : Pat<(mul (sext i32:$a), (sext i32:$b)), - (MULWIDES64 $a, $b)>; - def : Pat<(mul (sext i32:$a), (i64 SInt32Const:$b)), - (MULWIDES64Imm64 $a, (i64 SInt32Const:$b))>; - - def : Pat<(mul (zext i32:$a), (zext i32:$b)), - (MULWIDEU64 $a, $b)>; - def : Pat<(mul (zext i32:$a), (i64 UInt32Const:$b)), - (MULWIDEU64Imm64 $a, (i64 UInt32Const:$b))>; - def : Pat<(mul (sext i16:$a), (sext i16:$b)), - (MULWIDES32 $a, $b)>; - def : Pat<(mul (sext i16:$a), (i32 SInt16Const:$b)), - (MULWIDES32Imm32 $a, (i32 SInt16Const:$b))>; - - def : Pat<(mul (zext i16:$a), (zext i16:$b)), - (MULWIDEU32 $a, $b)>; - def : Pat<(mul (zext i16:$a), (i32 UInt16Const:$b)), - (MULWIDEU32Imm32 $a, (i32 UInt16Const:$b))>; +multiclass MULWIDEInst<string suffix, SDPatternOperator op, RegTyInfo big_t, RegTyInfo small_t> { + def suffix # _rr : + BasicNVPTXInst<(outs big_t.RC:$dst), (ins small_t.RC:$a, small_t.RC:$b), + "mul.wide." # suffix, + [(set big_t.Ty:$dst, (op small_t.Ty:$a, small_t.Ty:$b))]>; + def suffix # _ri : + BasicNVPTXInst<(outs big_t.RC:$dst), (ins small_t.RC:$a, small_t.Imm:$b), + "mul.wide." # suffix, + [(set big_t.Ty:$dst, (op small_t.Ty:$a, imm:$b))]>; } +defm MUL_WIDE : MULWIDEInst<"s32", smul_wide, I64RT, I32RT>; +defm MUL_WIDE : MULWIDEInst<"u32", umul_wide, I64RT, I32RT>; +defm MUL_WIDE : MULWIDEInst<"s16", smul_wide, I32RT, I16RT>; +defm MUL_WIDE : MULWIDEInst<"u16", umul_wide, I32RT, I16RT>; + // // Integer multiply-add // -def mul_oneuse : OneUse2<mul>; - -multiclass MAD<string Ptx, ValueType VT, NVPTXRegClass Reg, Operand Imm> { +multiclass MADInst<string suffix, SDPatternOperator op, RegTyInfo big_t, RegTyInfo small_t> { def rrr: - BasicNVPTXInst<(outs Reg:$dst), - (ins Reg:$a, Reg:$b, Reg:$c), - Ptx, - [(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), VT:$c))]>; - - def rir: - BasicNVPTXInst<(outs Reg:$dst), - (ins Reg:$a, Imm:$b, Reg:$c), - Ptx, - [(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), VT:$c))]>; + BasicNVPTXInst<(outs big_t.RC:$dst), + (ins small_t.RC:$a, small_t.RC:$b, big_t.RC:$c), + "mad." # suffix, + [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, small_t.Ty:$b), big_t.Ty:$c))]>; def rri: - BasicNVPTXInst<(outs Reg:$dst), - (ins Reg:$a, Reg:$b, Imm:$c), - Ptx, - [(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), imm:$c))]>; + BasicNVPTXInst<(outs big_t.RC:$dst), + (ins small_t.RC:$a, small_t.RC:$b, big_t.Imm:$c), + "mad." # suffix, + [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, small_t.Ty:$b), imm:$c))]>; + def rir: + BasicNVPTXInst<(outs big_t.RC:$dst), + (ins small_t.RC:$a, small_t.Imm:$b, big_t.RC:$c), + "mad." # suffix, + [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, imm:$b), big_t.Ty:$c))]>; def rii: - BasicNVPTXInst<(outs Reg:$dst), - (ins Reg:$a, Imm:$b, Imm:$c), - Ptx, - [(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), imm:$c))]>; + BasicNVPTXInst<(outs big_t.RC:$dst), + (ins small_t.RC:$a, small_t.Imm:$b, big_t.Imm:$c), + "mad." # suffix, + [(set big_t.Ty:$dst, (add (OneUse2<op> small_t.Ty:$a, imm:$b), imm:$c))]>; } let Predicates = [hasOptEnabled] in { -defm MAD16 : MAD<"mad.lo.s16", i16, B16, i16imm>; -defm MAD32 : MAD<"mad.lo.s32", i32, B32, i32imm>; -defm MAD64 : MAD<"mad.lo.s64", i64, B64, i64imm>; -} + defm MAD_LO_S16 : MADInst<"lo.s16", mul, I16RT, I16RT>; + defm MAD_LO_S32 : MADInst<"lo.s32", mul, I32RT, I32RT>; + defm MAD_LO_S64 : MADInst<"lo.s64", mul, I64RT, I64RT>; -foreach t = [I16RT, I32RT, I64RT] in { - def NEG_S # t.Size : - BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), - "neg.s" # t.Size, - [(set t.Ty:$dst, (ineg t.Ty:$src))]>; + defm MAD_WIDE_U16 : MADInst<"wide.u16", umul_wide, I32RT, I16RT>; + defm MAD_WIDE_S16 : MADInst<"wide.s16", smul_wide, I32RT, I16RT>; + defm MAD_WIDE_U32 : MADInst<"wide.u32", umul_wide, I64RT, I32RT>; + defm MAD_WIDE_S32 : MADInst<"wide.s32", smul_wide, I64RT, I32RT>; } //----------------------------------- @@ -1106,8 +960,7 @@ def fdiv_approx : PatFrag<(ops node:$a, node:$b), def FRCP32_approx_r : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$b), (ins FTZFlag:$ftz), "rcp.approx$ftz.f32", [(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>; @@ -1116,14 +969,12 @@ def FRCP32_approx_r : // def FDIV32_approx_rr : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, B32:$b), (ins FTZFlag:$ftz), "div.approx$ftz.f32", [(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>; def FDIV32_approx_ri : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, f32imm:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz), "div.approx$ftz.f32", [(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>; // @@ -1146,14 +997,12 @@ def : Pat<(fdiv_full f32imm_1, f32:$b), // def FDIV32rr : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, B32:$b), (ins FTZFlag:$ftz), "div.full$ftz.f32", [(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>; def FDIV32ri : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, f32imm:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz), "div.full$ftz.f32", [(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>; // @@ -1167,8 +1016,7 @@ def fdiv_ftz : PatFrag<(ops node:$a, node:$b), def FRCP32r_prec : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$b), (ins FTZFlag:$ftz), "rcp.rn$ftz.f32", [(set f32:$dst, (fdiv_ftz f32imm_1, f32:$b))]>; // @@ -1176,14 +1024,12 @@ def FRCP32r_prec : // def FDIV32rr_prec : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, B32:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, B32:$b), (ins FTZFlag:$ftz), "div.rn$ftz.f32", [(set f32:$dst, (fdiv_ftz f32:$a, f32:$b))]>; def FDIV32ri_prec : BasicFlagsNVPTXInst<(outs B32:$dst), - (ins B32:$a, f32imm:$b), - (ins FTZFlag:$ftz), + (ins B32:$a, f32imm:$b), (ins FTZFlag:$ftz), "div.rn$ftz.f32", [(set f32:$dst, (fdiv_ftz f32:$a, fpimm:$b))]>; @@ -1234,7 +1080,7 @@ defm FMA_F32 : FMA<F32RT, allow_ftz = true>; defm FMA_F32x2 : FMA<F32X2RT, allow_ftz = true, preds = [hasF32x2Instructions]>; defm FMA_F64 : FMA<F64RT, allow_ftz = false>; -// sin/cos +// sin/cos/tanh class UnaryOpAllowsApproxFn<SDPatternOperator operator> : PatFrag<(ops node:$A), @@ -1250,6 +1096,10 @@ def COS_APPROX_f32 : BasicFlagsNVPTXInst<(outs B32:$dst), (ins B32:$src), (ins FTZFlag:$ftz), "cos.approx$ftz.f32", [(set f32:$dst, (UnaryOpAllowsApproxFn<fcos> f32:$src))]>; +def TANH_APPROX_f32 : + BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "tanh.approx.f32", + [(set f32:$dst, (UnaryOpAllowsApproxFn<ftanh> f32:$src))]>, + Requires<[hasPTX<70>, hasSM<75>]>; //----------------------------------- // Bitwise operations @@ -1258,10 +1108,8 @@ def COS_APPROX_f32 : // Template for three-arg bitwise operations. Takes three args, Creates .b16, // .b32, .b64, and .pred (predicate registers -- i.e., i1) versions of OpcStr. multiclass BITWISE<string OpcStr, SDNode OpNode> { - defm b1 : I3Inst<OpcStr # ".pred", OpNode, I1RT, commutative = true>; - defm b16 : I3Inst<OpcStr # ".b16", OpNode, I16RT, commutative = true>; - defm b32 : I3Inst<OpcStr # ".b32", OpNode, I32RT, commutative = true>; - defm b64 : I3Inst<OpcStr # ".b64", OpNode, I64RT, commutative = true>; + foreach t = [I1RT, I16RT, I32RT, I64RT] in + defm _ # t.PtxType : I3Inst<OpcStr # "." # t.PtxType, OpNode, t, commutative = true>; } defm OR : BITWISE<"or", or>; @@ -1269,48 +1117,40 @@ defm AND : BITWISE<"and", and>; defm XOR : BITWISE<"xor", xor>; // PTX does not support mul on predicates, convert to and instructions -def : Pat<(mul i1:$a, i1:$b), (ANDb1rr $a, $b)>; -def : Pat<(mul i1:$a, imm:$b), (ANDb1ri $a, imm:$b)>; +def : Pat<(mul i1:$a, i1:$b), (AND_predrr $a, $b)>; +def : Pat<(mul i1:$a, imm:$b), (AND_predri $a, imm:$b)>; foreach op = [add, sub] in { - def : Pat<(op i1:$a, i1:$b), (XORb1rr $a, $b)>; - def : Pat<(op i1:$a, imm:$b), (XORb1ri $a, imm:$b)>; + def : Pat<(op i1:$a, i1:$b), (XOR_predrr $a, $b)>; + def : Pat<(op i1:$a, imm:$b), (XOR_predri $a, imm:$b)>; } // These transformations were once reliably performed by instcombine, but thanks // to poison semantics they are no longer safe for LLVM IR, perform them here // instead. -def : Pat<(select i1:$a, i1:$b, 0), (ANDb1rr $a, $b)>; -def : Pat<(select i1:$a, 1, i1:$b), (ORb1rr $a, $b)>; +def : Pat<(select i1:$a, i1:$b, 0), (AND_predrr $a, $b)>; +def : Pat<(select i1:$a, 1, i1:$b), (OR_predrr $a, $b)>; // Lower logical v2i16/v4i8 ops as bitwise ops on b32. foreach vt = [v2i16, v4i8] in { - def : Pat<(or vt:$a, vt:$b), (ORb32rr $a, $b)>; - def : Pat<(xor vt:$a, vt:$b), (XORb32rr $a, $b)>; - def : Pat<(and vt:$a, vt:$b), (ANDb32rr $a, $b)>; + def : Pat<(or vt:$a, vt:$b), (OR_b32rr $a, $b)>; + def : Pat<(xor vt:$a, vt:$b), (XOR_b32rr $a, $b)>; + def : Pat<(and vt:$a, vt:$b), (AND_b32rr $a, $b)>; // The constants get legalized into a bitcast from i32, so that's what we need // to match here. def: Pat<(or vt:$a, (vt (bitconvert (i32 imm:$b)))), - (ORb32ri $a, imm:$b)>; + (OR_b32ri $a, imm:$b)>; def: Pat<(xor vt:$a, (vt (bitconvert (i32 imm:$b)))), - (XORb32ri $a, imm:$b)>; + (XOR_b32ri $a, imm:$b)>; def: Pat<(and vt:$a, (vt (bitconvert (i32 imm:$b)))), - (ANDb32ri $a, imm:$b)>; -} - -def NOT1 : BasicNVPTXInst<(outs B1:$dst), (ins B1:$src), - "not.pred", - [(set i1:$dst, (not i1:$src))]>; -def NOT16 : BasicNVPTXInst<(outs B16:$dst), (ins B16:$src), - "not.b16", - [(set i16:$dst, (not i16:$src))]>; -def NOT32 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), - "not.b32", - [(set i32:$dst, (not i32:$src))]>; -def NOT64 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), - "not.b64", - [(set i64:$dst, (not i64:$src))]>; + (AND_b32ri $a, imm:$b)>; +} + +foreach t = [I1RT, I16RT, I32RT, I64RT] in + def NOT_ # t.PtxType : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), + "not." # t.PtxType, + [(set t.Ty:$dst, (not t.Ty:$src))]>; // Template for left/right shifts. Takes three operands, // [dest (reg), src (reg), shift (reg or imm)]. @@ -1318,34 +1158,22 @@ def NOT64 : BasicNVPTXInst<(outs B64:$dst), (ins B64:$src), // // This template also defines a 32-bit shift (imm, imm) instruction. multiclass SHIFT<string OpcStr, SDNode OpNode> { - def i64rr : - BasicNVPTXInst<(outs B64:$dst), (ins B64:$a, B32:$b), - OpcStr # "64", - [(set i64:$dst, (OpNode i64:$a, i32:$b))]>; - def i64ri : - BasicNVPTXInst<(outs B64:$dst), (ins B64:$a, i32imm:$b), - OpcStr # "64", - [(set i64:$dst, (OpNode i64:$a, (i32 imm:$b)))]>; - def i32rr : - BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b), - OpcStr # "32", - [(set i32:$dst, (OpNode i32:$a, i32:$b))]>; - def i32ri : - BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, i32imm:$b), - OpcStr # "32", - [(set i32:$dst, (OpNode i32:$a, (i32 imm:$b)))]>; - def i32ii : - BasicNVPTXInst<(outs B32:$dst), (ins i32imm:$a, i32imm:$b), - OpcStr # "32", - [(set i32:$dst, (OpNode (i32 imm:$a), (i32 imm:$b)))]>; - def i16rr : - BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, B32:$b), - OpcStr # "16", - [(set i16:$dst, (OpNode i16:$a, i32:$b))]>; - def i16ri : - BasicNVPTXInst<(outs B16:$dst), (ins B16:$a, i32imm:$b), - OpcStr # "16", - [(set i16:$dst, (OpNode i16:$a, (i32 imm:$b)))]>; + let hasSideEffects = false in { + foreach t = [I64RT, I32RT, I16RT] in { + def t.Size # _rr : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, B32:$b), + OpcStr # t.Size, + [(set t.Ty:$dst, (OpNode t.Ty:$a, i32:$b))]>; + def t.Size # _ri : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, i32imm:$b), + OpcStr # t.Size, + [(set t.Ty:$dst, (OpNode t.Ty:$a, (i32 imm:$b)))]>; + def t.Size # _ii : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, i32imm:$b), + OpcStr # t.Size, + [(set t.Ty:$dst, (OpNode (t.Ty imm:$a), (i32 imm:$b)))]>; + } + } } defm SHL : SHIFT<"shl.b", shl>; @@ -1353,14 +1181,11 @@ defm SRA : SHIFT<"shr.s", sra>; defm SRL : SHIFT<"shr.u", srl>; // Bit-reverse -def BREV32 : - BasicNVPTXInst<(outs B32:$dst), (ins B32:$a), - "brev.b32", - [(set i32:$dst, (bitreverse i32:$a))]>; -def BREV64 : - BasicNVPTXInst<(outs B64:$dst), (ins B64:$a), - "brev.b64", - [(set i64:$dst, (bitreverse i64:$a))]>; +foreach t = [I64RT, I32RT] in + def BREV_ # t.PtxType : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$a), + "brev." # t.PtxType, + [(set t.Ty:$dst, (bitreverse t.Ty:$a))]>; // @@ -1512,20 +1337,19 @@ def : Pat<(i16 (sext_inreg (trunc (prmt i32:$s, 0, byte_extract_prmt:$sel, PrmtN // Byte extraction via shift/trunc/sext -def : Pat<(i16 (sext_inreg (trunc i32:$s), i8)), - (CVT_s8_s32 $s, CvtNONE)>; -def : Pat<(i16 (sext_inreg (trunc (srl i32:$s, (i32 imm:$o))), i8)), +def : Pat<(i16 (sext_inreg (trunc i32:$s), i8)), (CVT_s8_s32 $s, CvtNONE)>; +def : Pat<(i16 (sext_inreg (trunc i64:$s), i8)), (CVT_s8_s64 $s, CvtNONE)>; + +def : Pat<(sext_inreg (srl i32:$s, (i32 imm:$o)), i8), (BFE_S32rii $s, imm:$o, 8)>; +def : Pat<(sext_inreg (srl i64:$s, (i32 imm:$o)), i8), (BFE_S64rii $s, imm:$o, 8)>; + +def : Pat<(i16 (sext_inreg (trunc (srl i32:$s, (i32 imm:$o))), i8)), (CVT_s8_s32 (BFE_S32rii $s, imm:$o, 8), CvtNONE)>; -def : Pat<(sext_inreg (srl i32:$s, (i32 imm:$o)), i8), - (BFE_S32rii $s, imm:$o, 8)>; +def : Pat<(i16 (sext_inreg (trunc (srl i64:$s, (i32 imm:$o))), i8)), + (CVT_s8_s64 (BFE_S64rii $s, imm:$o, 8), CvtNONE)>; + def : Pat<(i16 (sra (i16 (trunc i32:$s)), (i32 8))), (CVT_s8_s32 (BFE_S32rii $s, 8, 8), CvtNONE)>; -def : Pat<(sext_inreg (srl i64:$s, (i32 imm:$o)), i8), - (BFE_S64rii $s, imm:$o, 8)>; -def : Pat<(i16 (sext_inreg (trunc i64:$s), i8)), - (CVT_s8_s64 $s, CvtNONE)>; -def : Pat<(i16 (sext_inreg (trunc (srl i64:$s, (i32 imm:$o))), i8)), - (CVT_s8_s64 (BFE_S64rii $s, imm:$o, 8), CvtNONE)>; //----------------------------------- // Comparison instructions (setp, set) @@ -1615,10 +1439,7 @@ def SETP_bf16x2rr : def addr : ComplexPattern<pAny, 2, "SelectADDR">; -def ADDR_base : Operand<pAny> { - let PrintMethod = "printOperand"; -} - +def ADDR_base : Operand<pAny>; def ADDR : Operand<pAny> { let PrintMethod = "printMemOperand"; let MIOperandInfo = (ops ADDR_base, i32imm); @@ -1632,10 +1453,6 @@ def MmaCode : Operand<i32> { let PrintMethod = "printMmaCode"; } -def Offseti32imm : Operand<i32> { - let PrintMethod = "printOffseti32imm"; -} - // Get pointer to local stack. let hasSideEffects = false in { def MOV_DEPOT_ADDR : NVPTXInst<(outs B32:$d), (ins i32imm:$num), @@ -1647,33 +1464,31 @@ let hasSideEffects = false in { // copyPhysreg is hard-coded in NVPTXInstrInfo.cpp let hasSideEffects = false, isAsCheapAsAMove = true in { - // Class for register-to-register moves - class MOVr<RegisterClass RC, string OpStr> : - BasicNVPTXInst<(outs RC:$dst), (ins RC:$src), - "mov." # OpStr>; - - // Class for immediate-to-register moves - class MOVi<RegisterClass RC, string OpStr, ValueType VT, Operand IMMType, SDNode ImmNode> : - BasicNVPTXInst<(outs RC:$dst), (ins IMMType:$src), - "mov." # OpStr, - [(set VT:$dst, ImmNode:$src)]>; -} + let isMoveReg = true in + class MOVr<RegisterClass RC, string OpStr> : + BasicNVPTXInst<(outs RC:$dst), (ins RC:$src), "mov." # OpStr>; -def IMOV1r : MOVr<B1, "pred">; -def MOV16r : MOVr<B16, "b16">; -def IMOV32r : MOVr<B32, "b32">; -def IMOV64r : MOVr<B64, "b64">; -def IMOV128r : MOVr<B128, "b128">; + let isMoveImm = true in + class MOVi<RegTyInfo t, string suffix> : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.Imm:$src), + "mov." # suffix, + [(set t.Ty:$dst, t.ImmNode:$src)]>; +} +def MOV_B1_r : MOVr<B1, "pred">; +def MOV_B16_r : MOVr<B16, "b16">; +def MOV_B32_r : MOVr<B32, "b32">; +def MOV_B64_r : MOVr<B64, "b64">; +def MOV_B128_r : MOVr<B128, "b128">; -def IMOV1i : MOVi<B1, "pred", i1, i1imm, imm>; -def IMOV16i : MOVi<B16, "b16", i16, i16imm, imm>; -def IMOV32i : MOVi<B32, "b32", i32, i32imm, imm>; -def IMOV64i : MOVi<B64, "b64", i64, i64imm, imm>; -def FMOV16i : MOVi<B16, "b16", f16, f16imm, fpimm>; -def BFMOV16i : MOVi<B16, "b16", bf16, bf16imm, fpimm>; -def FMOV32i : MOVi<B32, "b32", f32, f32imm, fpimm>; -def FMOV64i : MOVi<B64, "b64", f64, f64imm, fpimm>; +def MOV_B1_i : MOVi<I1RT, "pred">; +def MOV_B16_i : MOVi<I16RT, "b16">; +def MOV_B32_i : MOVi<I32RT, "b32">; +def MOV_B64_i : MOVi<I64RT, "b64">; +def MOV_F16_i : MOVi<F16RT, "b16">; +def MOV_BF16_i : MOVi<BF16RT, "b16">; +def MOV_F32_i : MOVi<F32RT, "b32">; +def MOV_F64_i : MOVi<F64RT, "b64">; def to_tglobaladdr : SDNodeXForm<globaladdr, [{ @@ -1691,11 +1506,11 @@ def to_tframeindex : SDNodeXForm<frameindex, [{ return CurDAG->getTargetFrameIndex(N->getIndex(), N->getValueType(0)); }]>; -def : Pat<(i32 globaladdr:$dst), (IMOV32i (to_tglobaladdr $dst))>; -def : Pat<(i64 globaladdr:$dst), (IMOV64i (to_tglobaladdr $dst))>; +def : Pat<(i32 globaladdr:$dst), (MOV_B32_i (to_tglobaladdr $dst))>; +def : Pat<(i64 globaladdr:$dst), (MOV_B64_i (to_tglobaladdr $dst))>; -def : Pat<(i32 externalsym:$dst), (IMOV32i (to_texternsym $dst))>; -def : Pat<(i64 externalsym:$dst), (IMOV64i (to_texternsym $dst))>; +def : Pat<(i32 externalsym:$dst), (MOV_B32_i (to_texternsym $dst))>; +def : Pat<(i64 externalsym:$dst), (MOV_B64_i (to_texternsym $dst))>; //---- Copy Frame Index ---- def LEA_ADDRi : NVPTXInst<(outs B32:$dst), (ins ADDR:$addr), @@ -1709,56 +1524,39 @@ def : Pat<(i64 frameindex:$fi), (LEA_ADDRi64 (to_tframeindex $fi), 0)>; //----------------------------------- // Comparison and Selection //----------------------------------- +// TODO: These patterns seem very specific and brittle. We should try to find +// a more general solution. def cond_signed : PatLeaf<(cond), [{ return isSignedIntSetCC(N->get()); }]>; -def cond_not_signed : PatLeaf<(cond), [{ - return !isSignedIntSetCC(N->get()); -}]>; +// A 16-bit signed comparison of sign-extended byte extracts can be converted +// to 32-bit comparison if we change the PRMT to sign-extend the extracted +// bytes. +def : Pat<(setcc (i16 (sext_inreg (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE)), i8)), + (i16 (sext_inreg (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE)), i8)), + cond_signed:$cc), + (SETP_i32rr (PRMT_B32rii i32:$a, 0, (to_sign_extend_selector $sel_a), PrmtNONE), + (PRMT_B32rii i32:$b, 0, (to_sign_extend_selector $sel_b), PrmtNONE), + (cond2cc $cc))>; + +// A 16-bit comparison of truncated byte extracts can be be converted to 32-bit +// comparison because we know that the truncate is just trancating off zeros +// and that the most-significant byte is also zeros so the meaning of signed and +// unsigned comparisons will not be changed. +def : Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))), + (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))), + cond:$cc), + (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE), + (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE), + (cond2cc $cc))>; -// comparisons of i8 extracted with PRMT as i32 -// It's faster to do comparison directly on i32 extracted by PRMT, -// instead of the long conversion and sign extending. -def: Pat<(setcc (i16 (sext_inreg (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))), i8)), - (i16 (sext_inreg (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))), i8)), - cond_signed:$cc), - (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE), - (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE), - (cond2cc $cc))>; - -def: Pat<(setcc (i16 (sext_inreg (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE)), i8)), - (i16 (sext_inreg (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE)), i8)), - cond_signed:$cc), - (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE), - (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE), - (cond2cc $cc))>; - -def: Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))), - (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))), - cond_signed:$cc), - (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE), - (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE), - (cond2cc $cc))>; - -def: Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))), - (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))), - cond_not_signed:$cc), - (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE), - (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE), - (cond2cc $cc))>; def SDTDeclareArrayParam : SDTypeProfile<0, 3, [SDTCisVT<0, i32>, SDTCisVT<1, i32>, SDTCisVT<2, i32>]>; def SDTDeclareScalarParam : SDTypeProfile<0, 2, [SDTCisVT<0, i32>, SDTCisVT<1, i32>]>; -def SDTLoadParamProfile : SDTypeProfile<1, 2, [SDTCisInt<1>, SDTCisInt<2>]>; -def SDTLoadParamV2Profile : SDTypeProfile<2, 2, [SDTCisSameAs<0, 1>, SDTCisInt<2>, SDTCisInt<3>]>; -def SDTLoadParamV4Profile : SDTypeProfile<4, 2, [SDTCisInt<4>, SDTCisInt<5>]>; -def SDTStoreParamProfile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>; -def SDTStoreParamV2Profile : SDTypeProfile<0, 4, [SDTCisInt<0>, SDTCisInt<1>]>; -def SDTStoreParamV4Profile : SDTypeProfile<0, 6, [SDTCisInt<0>, SDTCisInt<1>]>; def SDTMoveParamProfile : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisSameAs<0, 1>]>; def SDTProxyReg : SDTypeProfile<1, 1, [SDTCisSameAs<0, 1>]>; @@ -1770,104 +1568,20 @@ def declare_array_param : def declare_scalar_param : SDNode<"NVPTXISD::DeclareScalarParam", SDTDeclareScalarParam, [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; - -def LoadParam : - SDNode<"NVPTXISD::LoadParam", SDTLoadParamProfile, - [SDNPHasChain, SDNPMayLoad, SDNPOutGlue, SDNPInGlue]>; -def LoadParamV2 : - SDNode<"NVPTXISD::LoadParamV2", SDTLoadParamV2Profile, - [SDNPHasChain, SDNPMayLoad, SDNPOutGlue, SDNPInGlue]>; -def LoadParamV4 : - SDNode<"NVPTXISD::LoadParamV4", SDTLoadParamV4Profile, - [SDNPHasChain, SDNPMayLoad, SDNPOutGlue, SDNPInGlue]>; -def StoreParam : - SDNode<"NVPTXISD::StoreParam", SDTStoreParamProfile, - [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; -def StoreParamV2 : - SDNode<"NVPTXISD::StoreParamV2", SDTStoreParamV2Profile, - [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; -def StoreParamV4 : - SDNode<"NVPTXISD::StoreParamV4", SDTStoreParamV4Profile, - [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; def MoveParam : SDNode<"NVPTXISD::MoveParam", SDTMoveParamProfile, []>; def proxy_reg : SDNode<"NVPTXISD::ProxyReg", SDTProxyReg, [SDNPHasChain]>; /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns, - /// NumParams, Callee, Proto, InGlue) + /// NumParams, Callee, Proto) def SDTCallProfile : SDTypeProfile<0, 6, [SDTCisVT<0, i32>, SDTCisVT<1, i32>, SDTCisVT<2, i32>, SDTCisVT<3, i32>, SDTCisVT<5, i32>]>; -def call : - SDNode<"NVPTXISD::CALL", SDTCallProfile, - [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; - -let mayLoad = true in { - class LoadParamMemInst<NVPTXRegClass regclass, string opstr> : - NVPTXInst<(outs regclass:$dst), (ins Offseti32imm:$b), - !strconcat("ld.param", opstr, " \t$dst, [retval0$b];"), - []>; - - class LoadParamV2MemInst<NVPTXRegClass regclass, string opstr> : - NVPTXInst<(outs regclass:$dst, regclass:$dst2), (ins Offseti32imm:$b), - !strconcat("ld.param.v2", opstr, - " \t{{$dst, $dst2}}, [retval0$b];"), []>; - - class LoadParamV4MemInst<NVPTXRegClass regclass, string opstr> : - NVPTXInst<(outs regclass:$dst, regclass:$dst2, regclass:$dst3, - regclass:$dst4), - (ins Offseti32imm:$b), - !strconcat("ld.param.v4", opstr, - " \t{{$dst, $dst2, $dst3, $dst4}}, [retval0$b];"), - []>; -} - -let mayStore = true in { - - multiclass StoreParamInst<NVPTXRegClass regclass, Operand IMMType, string opstr, bit support_imm = true> { - foreach op = [IMMType, regclass] in - if !or(support_imm, !isa<NVPTXRegClass>(op)) then - def _ # !if(!isa<NVPTXRegClass>(op), "r", "i") - : NVPTXInst<(outs), - (ins op:$val, i32imm:$a, Offseti32imm:$b), - "st.param" # opstr # " \t[param$a$b], $val;", - []>; - } - - multiclass StoreParamV2Inst<NVPTXRegClass regclass, Operand IMMType, string opstr> { - foreach op1 = [IMMType, regclass] in - foreach op2 = [IMMType, regclass] in - def _ # !if(!isa<NVPTXRegClass>(op1), "r", "i") - # !if(!isa<NVPTXRegClass>(op2), "r", "i") - : NVPTXInst<(outs), - (ins op1:$val1, op2:$val2, - i32imm:$a, Offseti32imm:$b), - "st.param.v2" # opstr # " \t[param$a$b], {{$val1, $val2}};", - []>; - } - - multiclass StoreParamV4Inst<NVPTXRegClass regclass, Operand IMMType, string opstr> { - foreach op1 = [IMMType, regclass] in - foreach op2 = [IMMType, regclass] in - foreach op3 = [IMMType, regclass] in - foreach op4 = [IMMType, regclass] in - def _ # !if(!isa<NVPTXRegClass>(op1), "r", "i") - # !if(!isa<NVPTXRegClass>(op2), "r", "i") - # !if(!isa<NVPTXRegClass>(op3), "r", "i") - # !if(!isa<NVPTXRegClass>(op4), "r", "i") - - : NVPTXInst<(outs), - (ins op1:$val1, op2:$val2, op3:$val3, op4:$val4, - i32imm:$a, Offseti32imm:$b), - "st.param.v4" # opstr # - " \t[param$a$b], {{$val1, $val2, $val3, $val4}};", - []>; - } -} +def call : SDNode<"NVPTXISD::CALL", SDTCallProfile, [SDNPHasChain, SDNPSideEffect]>; /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns, -/// NumParams, Callee, Proto, InGlue) +/// NumParams, Callee, Proto) def CallOperand : Operand<i32> { let PrintMethod = "printCallOperand"; } @@ -1904,43 +1618,6 @@ foreach is_convergent = [0, 1] in { (call_uni_inst $addr, imm:$rets, imm:$params)>; } -def LoadParamMemI64 : LoadParamMemInst<B64, ".b64">; -def LoadParamMemI32 : LoadParamMemInst<B32, ".b32">; -def LoadParamMemI16 : LoadParamMemInst<B16, ".b16">; -def LoadParamMemI8 : LoadParamMemInst<B16, ".b8">; -def LoadParamMemV2I64 : LoadParamV2MemInst<B64, ".b64">; -def LoadParamMemV2I32 : LoadParamV2MemInst<B32, ".b32">; -def LoadParamMemV2I16 : LoadParamV2MemInst<B16, ".b16">; -def LoadParamMemV2I8 : LoadParamV2MemInst<B16, ".b8">; -def LoadParamMemV4I32 : LoadParamV4MemInst<B32, ".b32">; -def LoadParamMemV4I16 : LoadParamV4MemInst<B16, ".b16">; -def LoadParamMemV4I8 : LoadParamV4MemInst<B16, ".b8">; - -defm StoreParamI64 : StoreParamInst<B64, i64imm, ".b64">; -defm StoreParamI32 : StoreParamInst<B32, i32imm, ".b32">; -defm StoreParamI16 : StoreParamInst<B16, i16imm, ".b16">; -defm StoreParamI8 : StoreParamInst<B16, i8imm, ".b8">; - -defm StoreParamI8TruncI32 : StoreParamInst<B32, i8imm, ".b8", /* support_imm */ false>; -defm StoreParamI8TruncI64 : StoreParamInst<B64, i8imm, ".b8", /* support_imm */ false>; - -defm StoreParamV2I64 : StoreParamV2Inst<B64, i64imm, ".b64">; -defm StoreParamV2I32 : StoreParamV2Inst<B32, i32imm, ".b32">; -defm StoreParamV2I16 : StoreParamV2Inst<B16, i16imm, ".b16">; -defm StoreParamV2I8 : StoreParamV2Inst<B16, i8imm, ".b8">; - -defm StoreParamV4I32 : StoreParamV4Inst<B32, i32imm, ".b32">; -defm StoreParamV4I16 : StoreParamV4Inst<B16, i16imm, ".b16">; -defm StoreParamV4I8 : StoreParamV4Inst<B16, i8imm, ".b8">; - -defm StoreParamF32 : StoreParamInst<B32, f32imm, ".b32">; -defm StoreParamF64 : StoreParamInst<B64, f64imm, ".b64">; - -defm StoreParamV2F32 : StoreParamV2Inst<B32, f32imm, ".b32">; -defm StoreParamV2F64 : StoreParamV2Inst<B64, f64imm, ".b64">; - -defm StoreParamV4F32 : StoreParamV4Inst<B32, f32imm, ".b32">; - def DECLARE_PARAM_array : NVPTXInst<(outs), (ins i32imm:$a, i32imm:$align, i32imm:$size), ".param .align $align .b8 \t$a[$size];", []>; @@ -1953,6 +1630,18 @@ def : Pat<(declare_array_param externalsym:$a, imm:$align, imm:$size), def : Pat<(declare_scalar_param externalsym:$a, imm:$size), (DECLARE_PARAM_scalar (to_texternsym $a), imm:$size)>; +// Call prototype wrapper, this is a dummy instruction that just prints it's +// operand which is string defining the prototype. +def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>; +def CallPrototype : + SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype, + [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; +def ProtoIdent : Operand<i32> { let PrintMethod = "printProtoIdent"; } +def CALL_PROTOTYPE : + NVPTXInst<(outs), (ins ProtoIdent:$ident), + "$ident", [(CallPrototype (i32 texternalsym:$ident))]>; + + foreach t = [I32RT, I64RT] in { defvar inst_name = "MOV" # t.Size # "_PARAM"; def inst_name : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), "mov.b" # t.Size>; @@ -1972,6 +1661,32 @@ defm ProxyRegB16 : ProxyRegInst<"b16", B16>; defm ProxyRegB32 : ProxyRegInst<"b32", B32>; defm ProxyRegB64 : ProxyRegInst<"b64", B64>; + +// Callseq start and end + +// Note: these nodes are marked as SDNPMayStore and SDNPMayLoad because +// they define the scope in which the declared params may be used. Therefore +// we add these flags to ensure ld.param and st.param are not sunk or hoisted +// out of that scope. + +def callseq_start : SDNode<"ISD::CALLSEQ_START", + SDCallSeqStart<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>, + [SDNPHasChain, SDNPOutGlue, + SDNPSideEffect, SDNPMayStore, SDNPMayLoad]>; +def callseq_end : SDNode<"ISD::CALLSEQ_END", + SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>, + [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue, + SDNPSideEffect, SDNPMayStore, SDNPMayLoad]>; + +def Callseq_Start : + NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2), + "\\{ // callseq $amt1, $amt2", + [(callseq_start timm:$amt1, timm:$amt2)]>; +def Callseq_End : + NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2), + "\\} // callseq $amt1", + [(callseq_end timm:$amt1, timm:$amt2)]>; + // // Load / Store Handling // @@ -1984,7 +1699,6 @@ class LD<NVPTXRegClass regclass> "\t$dst, [$addr];", []>; let mayLoad=1, hasSideEffects=0 in { - def LD_i8 : LD<B16>; def LD_i16 : LD<B16>; def LD_i32 : LD<B32>; def LD_i64 : LD<B64>; @@ -2000,7 +1714,6 @@ class ST<DAGOperand O> " \t[$addr], $src;", []>; let mayStore=1, hasSideEffects=0 in { - def ST_i8 : ST<RI16>; def ST_i16 : ST<RI16>; def ST_i32 : ST<RI32>; def ST_i64 : ST<RI64>; @@ -2033,7 +1746,6 @@ multiclass LD_VEC<NVPTXRegClass regclass, bit support_v8 = false> { "[$addr];", []>; } let mayLoad=1, hasSideEffects=0 in { - defm LDV_i8 : LD_VEC<B16>; defm LDV_i16 : LD_VEC<B16>; defm LDV_i32 : LD_VEC<B32, support_v8 = true>; defm LDV_i64 : LD_VEC<B64>; @@ -2067,7 +1779,6 @@ multiclass ST_VEC<DAGOperand O, bit support_v8 = false> { } let mayStore=1, hasSideEffects=0 in { - defm STV_i8 : ST_VEC<RI16>; defm STV_i16 : ST_VEC<RI16>; defm STV_i32 : ST_VEC<RI32, support_v8 = true>; defm STV_i64 : ST_VEC<RI64>; @@ -2237,14 +1948,14 @@ def : Pat<(i64 (anyext i32:$a)), (CVT_u64_u32 $a, CvtNONE)>; // truncate i64 def : Pat<(i32 (trunc i64:$a)), (CVT_u32_u64 $a, CvtNONE)>; def : Pat<(i16 (trunc i64:$a)), (CVT_u16_u64 $a, CvtNONE)>; -def : Pat<(i1 (trunc i64:$a)), (SETP_i64ri (ANDb64ri $a, 1), 0, CmpNE)>; +def : Pat<(i1 (trunc i64:$a)), (SETP_i64ri (AND_b64ri $a, 1), 0, CmpNE)>; // truncate i32 def : Pat<(i16 (trunc i32:$a)), (CVT_u16_u32 $a, CvtNONE)>; -def : Pat<(i1 (trunc i32:$a)), (SETP_i32ri (ANDb32ri $a, 1), 0, CmpNE)>; +def : Pat<(i1 (trunc i32:$a)), (SETP_i32ri (AND_b32ri $a, 1), 0, CmpNE)>; // truncate i16 -def : Pat<(i1 (trunc i16:$a)), (SETP_i16ri (ANDb16ri $a, 1), 0, CmpNE)>; +def : Pat<(i1 (trunc i16:$a)), (SETP_i16ri (AND_b16ri $a, 1), 0, CmpNE)>; // sext_inreg def : Pat<(sext_inreg i16:$a, i8), (CVT_INREG_s16_s8 $a)>; @@ -2488,52 +2199,20 @@ defm : CVT_ROUND<frint, CvtRNI, CvtRNI_FTZ>; //----------------------------------- let isTerminator=1 in { - let isReturn=1, isBarrier=1 in + let isReturn=1, isBarrier=1 in def Return : BasicNVPTXInst<(outs), (ins), "ret", [(retglue)]>; - let isBranch=1 in - def CBranch : NVPTXInst<(outs), (ins B1:$a, brtarget:$target), + let isBranch=1 in { + def CBranch : NVPTXInst<(outs), (ins B1:$a, brtarget:$target), "@$a bra \t$target;", [(brcond i1:$a, bb:$target)]>; - let isBranch=1 in - def CBranchOther : NVPTXInst<(outs), (ins B1:$a, brtarget:$target), - "@!$a bra \t$target;", []>; - let isBranch=1, isBarrier=1 in + let isBarrier=1 in def GOTO : BasicNVPTXInst<(outs), (ins brtarget:$target), - "bra.uni", [(br bb:$target)]>; + "bra.uni", [(br bb:$target)]>; + } } -def : Pat<(brcond i32:$a, bb:$target), - (CBranch (SETP_i32ri $a, 0, CmpNE), bb:$target)>; - -// SelectionDAGBuilder::visitSWitchCase() will invert the condition of a -// conditional branch if the target block is the next block so that the code -// can fall through to the target block. The inversion is done by 'xor -// condition, 1', which will be translated to (setne condition, -1). Since ptx -// supports '@!pred bra target', we should use it. -def : Pat<(brcond (i1 (setne i1:$a, -1)), bb:$target), - (CBranchOther $a, bb:$target)>; - -// Call -def SDT_NVPTXCallSeqStart : SDCallSeqStart<[SDTCisVT<0, i32>, - SDTCisVT<1, i32>]>; -def SDT_NVPTXCallSeqEnd : SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>; - -def callseq_start : SDNode<"ISD::CALLSEQ_START", SDT_NVPTXCallSeqStart, - [SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>; -def callseq_end : SDNode<"ISD::CALLSEQ_END", SDT_NVPTXCallSeqEnd, - [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue, - SDNPSideEffect]>; - -def Callseq_Start : - NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2), - "\\{ // callseq $amt1, $amt2", - [(callseq_start timm:$amt1, timm:$amt2)]>; -def Callseq_End : - NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2), - "\\} // callseq $amt1", - [(callseq_end timm:$amt1, timm:$amt2)]>; // trap instruction def trapinst : BasicNVPTXInst<(outs), (ins), "trap", [(trap)]>, Requires<[noPTXASUnreachableBug]>; @@ -2543,18 +2222,6 @@ def trapexitinst : NVPTXInst<(outs), (ins), "trap; exit;", [(trap)]>, Requires<[ // brkpt instruction def debugtrapinst : BasicNVPTXInst<(outs), (ins), "brkpt", [(debugtrap)]>; -// Call prototype wrapper -def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>; -def CallPrototype : - SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype, - [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>; -def ProtoIdent : Operand<i32> { - let PrintMethod = "printProtoIdent"; -} -def CALL_PROTOTYPE : - NVPTXInst<(outs), (ins ProtoIdent:$ident), - "$ident", [(CallPrototype (i32 texternalsym:$ident))]>; - def SDTDynAllocaOp : SDTypeProfile<1, 2, [SDTCisSameAs<0, 1>, SDTCisInt<1>, SDTCisVT<2, i32>]>; |