aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib
diff options
context:
space:
mode:
authorBenjamin Chetioui <bchetioui@google.com>2023-01-05 09:26:13 +0100
committerChristian Sigg <csigg@google.com>2023-01-05 09:27:54 +0100
commit2c3f82b7759691f3b67f7e5940e95ac3434b1a9c (patch)
treeae2fa0f4e6ae25c243e26fa458836c338f39534f /llvm/lib
parentccc13241208f3b6975bbef384f0a01b1b6e83e8e (diff)
downloadllvm-2c3f82b7759691f3b67f7e5940e95ac3434b1a9c.zip
llvm-2c3f82b7759691f3b67f7e5940e95ac3434b1a9c.tar.gz
llvm-2c3f82b7759691f3b67f7e5940e95ac3434b1a9c.tar.bz2
[NVPTX] Fix NVPTX lowering of frem when denominator is infinite.
`frem x, {+,-}inf` must return x to match the specification of LLVM's frem. Reviewed By: tra Differential Revision: https://reviews.llvm.org/D140846
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXInstrInfo.td242
1 files changed, 155 insertions, 87 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index a114d92..b6a1394 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -132,6 +132,7 @@ def doMulWide : Predicate<"doMulWide">;
def allowFMA : Predicate<"allowFMA()">;
def noFMA : Predicate<"!allowFMA()">;
def allowUnsafeFPMath : Predicate<"allowUnsafeFPMath()">;
+def noUnsafeFPMath : Predicate<"!allowUnsafeFPMath()">;
def do_DIVF32_APPROX : Predicate<"getDivF32Level()==0">;
def do_DIVF32_FULL : Predicate<"getDivF32Level()==1">;
@@ -166,7 +167,7 @@ def hasSM80 : Predicate<"Subtarget->getSmVersion() >= 80">;
def hasSM86 : Predicate<"Subtarget->getSmVersion() >= 86">;
// non-sync shfl instructions are not available on sm_70+ in PTX6.4+
-def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70"
+def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70"
"&& Subtarget->getPTXVersion() >= 64)">;
def useShortPtr : Predicate<"useShortPointers()">;
@@ -192,7 +193,7 @@ class ValueToRegClass<ValueType T> {
!eq(name, "af32"): Float32ArgRegs,
!eq(name, "if64"): Float64ArgRegs,
);
-}
+}
//===----------------------------------------------------------------------===//
@@ -598,6 +599,99 @@ multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> {
}
//-----------------------------------
+// Selection instructions (selp)
+//-----------------------------------
+
+// TODO: Missing slct
+
+// selp instructions that don't have any pattern matches; we explicitly use
+// them within this file.
+let hasSideEffects = false in {
+ multiclass SELP<string TypeStr, RegisterClass RC, Operand ImmCls> {
+ def rr : NVPTXInst<(outs RC:$dst),
+ (ins RC:$a, RC:$b, Int1Regs:$p),
+ !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
+ def ri : NVPTXInst<(outs RC:$dst),
+ (ins RC:$a, ImmCls:$b, Int1Regs:$p),
+ !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
+ def ir : NVPTXInst<(outs RC:$dst),
+ (ins ImmCls:$a, RC:$b, Int1Regs:$p),
+ !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
+ def ii : NVPTXInst<(outs RC:$dst),
+ (ins ImmCls:$a, ImmCls:$b, Int1Regs:$p),
+ !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
+ }
+
+ multiclass SELP_PATTERN<string TypeStr, ValueType T, RegisterClass RC,
+ Operand ImmCls, SDNode ImmNode> {
+ def rr :
+ NVPTXInst<(outs RC:$dst),
+ (ins RC:$a, RC:$b, Int1Regs:$p),
+ !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
+ [(set (T RC:$dst), (select Int1Regs:$p, (T RC:$a), (T RC:$b)))]>;
+ def ri :
+ NVPTXInst<(outs RC:$dst),
+ (ins RC:$a, ImmCls:$b, Int1Regs:$p),
+ !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
+ [(set (T RC:$dst), (select Int1Regs:$p, (T RC:$a), (T ImmNode:$b)))]>;
+ def ir :
+ NVPTXInst<(outs RC:$dst),
+ (ins ImmCls:$a, RC:$b, Int1Regs:$p),
+ !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
+ [(set (T RC:$dst), (select Int1Regs:$p, ImmNode:$a, (T RC:$b)))]>;
+ def ii :
+ NVPTXInst<(outs RC:$dst),
+ (ins ImmCls:$a, ImmCls:$b, Int1Regs:$p),
+ !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
+ [(set (T RC:$dst), (select Int1Regs:$p, ImmNode:$a, ImmNode:$b))]>;
+ }
+}
+
+// Don't pattern match on selp.{s,u}{16,32,64} -- selp.b{16,32,64} is just as
+// good.
+defm SELP_b16 : SELP_PATTERN<"b16", i16, Int16Regs, i16imm, imm>;
+defm SELP_s16 : SELP<"s16", Int16Regs, i16imm>;
+defm SELP_u16 : SELP<"u16", Int16Regs, i16imm>;
+defm SELP_b32 : SELP_PATTERN<"b32", i32, Int32Regs, i32imm, imm>;
+defm SELP_s32 : SELP<"s32", Int32Regs, i32imm>;
+defm SELP_u32 : SELP<"u32", Int32Regs, i32imm>;
+defm SELP_b64 : SELP_PATTERN<"b64", i64, Int64Regs, i64imm, imm>;
+defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>;
+defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>;
+defm SELP_f16 : SELP_PATTERN<"b16", f16, Float16Regs, f16imm, fpimm>;
+
+defm SELP_f32 : SELP_PATTERN<"f32", f32, Float32Regs, f32imm, fpimm>;
+defm SELP_f64 : SELP_PATTERN<"f64", f64, Float64Regs, f64imm, fpimm>;
+
+// This does not work as tablegen fails to infer the type of 'imm'.
+// def v2f16imm : Operand<v2f16>;
+// defm SELP_f16x2 : SELP_PATTERN<"b32", v2f16, Float16x2Regs, v2f16imm, imm>;
+
+def SELP_f16x2rr :
+ NVPTXInst<(outs Float16x2Regs:$dst),
+ (ins Float16x2Regs:$a, Float16x2Regs:$b, Int1Regs:$p),
+ "selp.b32 \t$dst, $a, $b, $p;",
+ [(set Float16x2Regs:$dst,
+ (select Int1Regs:$p, (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>;
+
+//-----------------------------------
+// Test Instructions
+//-----------------------------------
+
+def TESTINF_f32r : NVPTXInst<(outs Int1Regs:$p), (ins Float32Regs:$a),
+ "testp.infinite.f32 \t$p, $a;",
+ []>;
+def TESTINF_f32i : NVPTXInst<(outs Int1Regs:$p), (ins f32imm:$a),
+ "testp.infinite.f32 \t$p, $a;",
+ []>;
+def TESTINF_f64r : NVPTXInst<(outs Int1Regs:$p), (ins Float64Regs:$a),
+ "testp.infinite.f64 \t$p, $a;",
+ []>;
+def TESTINF_f64i : NVPTXInst<(outs Int1Regs:$p), (ins f64imm:$a),
+ "testp.infinite.f64 \t$p, $a;",
+ []>;
+
+//-----------------------------------
// Integer Arithmetic
//-----------------------------------
@@ -1154,39 +1248,89 @@ def COSF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
Requires<[allowUnsafeFPMath]>;
// Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
-// i.e. "poor man's fmod()"
+// i.e. "poor man's fmod()". When y is infinite, x is returned. This matches the
+// semantics of LLVM's frem.
// frem - f32 FTZ
def : Pat<(frem Float32Regs:$x, Float32Regs:$y),
(FSUBf32rr_ftz Float32Regs:$x, (FMULf32rr_ftz (CVT_f32_f32
(FDIV32rr_prec_ftz Float32Regs:$x, Float32Regs:$y), CvtRZI_FTZ),
Float32Regs:$y))>,
- Requires<[doF32FTZ]>;
+ Requires<[doF32FTZ, allowUnsafeFPMath]>;
def : Pat<(frem Float32Regs:$x, fpimm:$y),
(FSUBf32rr_ftz Float32Regs:$x, (FMULf32ri_ftz (CVT_f32_f32
(FDIV32ri_prec_ftz Float32Regs:$x, fpimm:$y), CvtRZI_FTZ),
fpimm:$y))>,
- Requires<[doF32FTZ]>;
+ Requires<[doF32FTZ, allowUnsafeFPMath]>;
+
+def : Pat<(frem Float32Regs:$x, Float32Regs:$y),
+ (SELP_f32rr Float32Regs:$x,
+ (FSUBf32rr_ftz Float32Regs:$x, (FMULf32rr_ftz (CVT_f32_f32
+ (FDIV32rr_prec_ftz Float32Regs:$x, Float32Regs:$y), CvtRZI_FTZ),
+ Float32Regs:$y)),
+ (TESTINF_f32r Float32Regs:$y))>,
+ Requires<[doF32FTZ, noUnsafeFPMath]>;
+def : Pat<(frem Float32Regs:$x, fpimm:$y),
+ (SELP_f32rr Float32Regs:$x,
+ (FSUBf32rr_ftz Float32Regs:$x, (FMULf32ri_ftz (CVT_f32_f32
+ (FDIV32ri_prec_ftz Float32Regs:$x, fpimm:$y), CvtRZI_FTZ),
+ fpimm:$y)),
+ (TESTINF_f32i fpimm:$y))>,
+ Requires<[doF32FTZ, noUnsafeFPMath]>;
// frem - f32
def : Pat<(frem Float32Regs:$x, Float32Regs:$y),
(FSUBf32rr Float32Regs:$x, (FMULf32rr (CVT_f32_f32
(FDIV32rr_prec Float32Regs:$x, Float32Regs:$y), CvtRZI),
- Float32Regs:$y))>;
+ Float32Regs:$y))>,
+ Requires<[allowUnsafeFPMath]>;
def : Pat<(frem Float32Regs:$x, fpimm:$y),
(FSUBf32rr Float32Regs:$x, (FMULf32ri (CVT_f32_f32
(FDIV32ri_prec Float32Regs:$x, fpimm:$y), CvtRZI),
- fpimm:$y))>;
+ fpimm:$y))>,
+ Requires<[allowUnsafeFPMath]>;
+
+def : Pat<(frem Float32Regs:$x, Float32Regs:$y),
+ (SELP_f32rr Float32Regs:$x,
+ (FSUBf32rr Float32Regs:$x, (FMULf32rr (CVT_f32_f32
+ (FDIV32rr_prec Float32Regs:$x, Float32Regs:$y), CvtRZI),
+ Float32Regs:$y)),
+ (TESTINF_f32r Float32Regs:$y))>,
+ Requires<[noUnsafeFPMath]>;
+def : Pat<(frem Float32Regs:$x, fpimm:$y),
+ (SELP_f32rr Float32Regs:$x,
+ (FSUBf32rr Float32Regs:$x, (FMULf32ri (CVT_f32_f32
+ (FDIV32ri_prec Float32Regs:$x, fpimm:$y), CvtRZI),
+ fpimm:$y)),
+ (TESTINF_f32i fpimm:$y))>,
+ Requires<[noUnsafeFPMath]>;
// frem - f64
def : Pat<(frem Float64Regs:$x, Float64Regs:$y),
(FSUBf64rr Float64Regs:$x, (FMULf64rr (CVT_f64_f64
(FDIV64rr Float64Regs:$x, Float64Regs:$y), CvtRZI),
- Float64Regs:$y))>;
+ Float64Regs:$y))>,
+ Requires<[allowUnsafeFPMath]>;
def : Pat<(frem Float64Regs:$x, fpimm:$y),
(FSUBf64rr Float64Regs:$x, (FMULf64ri (CVT_f64_f64
(FDIV64ri Float64Regs:$x, fpimm:$y), CvtRZI),
- fpimm:$y))>;
+ fpimm:$y))>,
+ Requires<[allowUnsafeFPMath]>;
+
+def : Pat<(frem Float64Regs:$x, Float64Regs:$y),
+ (SELP_f64rr Float64Regs:$x,
+ (FSUBf64rr Float64Regs:$x, (FMULf64rr (CVT_f64_f64
+ (FDIV64rr Float64Regs:$x, Float64Regs:$y), CvtRZI),
+ Float64Regs:$y)),
+ (TESTINF_f64r Float64Regs:$y))>,
+ Requires<[noUnsafeFPMath]>;
+def : Pat<(frem Float64Regs:$x, fpimm:$y),
+ (SELP_f64rr Float64Regs:$x,
+ (FSUBf64rr Float64Regs:$x, (FMULf64ri (CVT_f64_f64
+ (FDIV64ri Float64Regs:$x, fpimm:$y), CvtRZI),
+ fpimm:$y)),
+ (TESTINF_f64r Float64Regs:$y))>,
+ Requires<[noUnsafeFPMath]>;
//-----------------------------------
// Bitwise operations
@@ -1569,82 +1713,6 @@ defm SET_f32 : SET<"f32", Float32Regs, f32imm>;
defm SET_f64 : SET<"f64", Float64Regs, f64imm>;
//-----------------------------------
-// Selection instructions (selp)
-//-----------------------------------
-
-// FIXME: Missing slct
-
-// selp instructions that don't have any pattern matches; we explicitly use
-// them within this file.
-let hasSideEffects = false in {
- multiclass SELP<string TypeStr, RegisterClass RC, Operand ImmCls> {
- def rr : NVPTXInst<(outs RC:$dst),
- (ins RC:$a, RC:$b, Int1Regs:$p),
- !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
- def ri : NVPTXInst<(outs RC:$dst),
- (ins RC:$a, ImmCls:$b, Int1Regs:$p),
- !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
- def ir : NVPTXInst<(outs RC:$dst),
- (ins ImmCls:$a, RC:$b, Int1Regs:$p),
- !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
- def ii : NVPTXInst<(outs RC:$dst),
- (ins ImmCls:$a, ImmCls:$b, Int1Regs:$p),
- !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
- }
-
- multiclass SELP_PATTERN<string TypeStr, ValueType T, RegisterClass RC,
- Operand ImmCls, SDNode ImmNode> {
- def rr :
- NVPTXInst<(outs RC:$dst),
- (ins RC:$a, RC:$b, Int1Regs:$p),
- !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
- [(set (T RC:$dst), (select Int1Regs:$p, (T RC:$a), (T RC:$b)))]>;
- def ri :
- NVPTXInst<(outs RC:$dst),
- (ins RC:$a, ImmCls:$b, Int1Regs:$p),
- !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
- [(set (T RC:$dst), (select Int1Regs:$p, (T RC:$a), (T ImmNode:$b)))]>;
- def ir :
- NVPTXInst<(outs RC:$dst),
- (ins ImmCls:$a, RC:$b, Int1Regs:$p),
- !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
- [(set (T RC:$dst), (select Int1Regs:$p, ImmNode:$a, (T RC:$b)))]>;
- def ii :
- NVPTXInst<(outs RC:$dst),
- (ins ImmCls:$a, ImmCls:$b, Int1Regs:$p),
- !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
- [(set (T RC:$dst), (select Int1Regs:$p, ImmNode:$a, ImmNode:$b))]>;
- }
-}
-
-// Don't pattern match on selp.{s,u}{16,32,64} -- selp.b{16,32,64} is just as
-// good.
-defm SELP_b16 : SELP_PATTERN<"b16", i16, Int16Regs, i16imm, imm>;
-defm SELP_s16 : SELP<"s16", Int16Regs, i16imm>;
-defm SELP_u16 : SELP<"u16", Int16Regs, i16imm>;
-defm SELP_b32 : SELP_PATTERN<"b32", i32, Int32Regs, i32imm, imm>;
-defm SELP_s32 : SELP<"s32", Int32Regs, i32imm>;
-defm SELP_u32 : SELP<"u32", Int32Regs, i32imm>;
-defm SELP_b64 : SELP_PATTERN<"b64", i64, Int64Regs, i64imm, imm>;
-defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>;
-defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>;
-defm SELP_f16 : SELP_PATTERN<"b16", f16, Float16Regs, f16imm, fpimm>;
-
-defm SELP_f32 : SELP_PATTERN<"f32", f32, Float32Regs, f32imm, fpimm>;
-defm SELP_f64 : SELP_PATTERN<"f64", f64, Float64Regs, f64imm, fpimm>;
-
-// This does not work as tablegen fails to infer the type of 'imm'.
-//def v2f16imm : Operand<v2f16>;
-//defm SELP_f16x2 : SELP_PATTERN<"b32", v2f16, Float16x2Regs, v2f16imm, imm>;
-
-def SELP_f16x2rr :
- NVPTXInst<(outs Float16x2Regs:$dst),
- (ins Float16x2Regs:$a, Float16x2Regs:$b, Int1Regs:$p),
- "selp.b32 \t$dst, $a, $b, $p;",
- [(set Float16x2Regs:$dst,
- (select Int1Regs:$p, (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>;
-
-//-----------------------------------
// Data Movement (Load / Store, Move)
//-----------------------------------
@@ -1879,7 +1947,7 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
def : Pat<(i1 (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b))),
(SETP_f16rr Float16Regs:$a, Float16Regs:$b, ModeFTZ)>,
Requires<[useFP16Math,doF32FTZ]>;
- def : Pat<(i1 (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b))),
+ def : Pat<(i1 (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b))),
(SETP_f16rr Float16Regs:$a, Float16Regs:$b, Mode)>,
Requires<[useFP16Math]>;
def : Pat<(i1 (OpNode (f16 Float16Regs:$a), fpimm:$b)),
@@ -2700,7 +2768,7 @@ let mayStore=1, hasSideEffects=0 in {
//---- Conversion ----
-class F_BITCONVERT<string SzStr, ValueType TIn, ValueType TOut,
+class F_BITCONVERT<string SzStr, ValueType TIn, ValueType TOut,
NVPTXRegClass regclassIn = ValueToRegClass<TIn>.ret,
NVPTXRegClass regclassOut = ValueToRegClass<TOut>.ret> :
NVPTXInst<(outs regclassOut:$d), (ins regclassIn:$a),