commit     b5e884976b71a1bd12653d6c4d67bf121da95e5f
author     Justin Lebar <jlebar@google.com>  2016-09-09 21:07:26 +0000
committer  Justin Lebar <jlebar@google.com>  2016-09-09 21:07:26 +0000
parent     b9e51397bfe1fdbf0f1ae444744a8958b985bcb4
[NVPTX] Implement llvm.fabs.f32, llvm.maxnum.f32, etc.
Summary:
Previously these only worked via NVPTX-specific intrinsics.
This change will allow us to convert these target-specific intrinsics
into the general LLVM versions, allowing existing LLVM passes to reason
about their behavior.
It also gets us some minor codegen improvements on its own, in situations
where we canonicalize code into one of these LLVM intrinsics.
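For illustration, a sketch in LLVM IR of the kind of conversion this enables
(llvm.nvvm.fmax.f is one of the NVPTX-specific intrinsics in question; treat
the exact pairing as illustrative):

  ; Before: the target-specific intrinsic is opaque to generic passes.
  %r0 = call float @llvm.nvvm.fmax.f(float %a, float %b)

  ; After: the generic intrinsic lowers to the same PTX max.f32, and
  ; passes such as InstCombine understand its NaN semantics.
  %r1 = call float @llvm.maxnum.f32(float %a, float %b)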
Reviewers: majnemer
Subscribers: llvm-commits, jholewinski, tra
Differential Revision: https://reviews.llvm.org/D24300
llvm-svn: 281092
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp |  22 ++
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td     | 126 +++++++---
 llvm/test/CodeGen/NVPTX/bug22322.ll         |   2 +-
 llvm/test/CodeGen/NVPTX/math-intrins.ll     | 261 ++++++++++++++++
 4 files changed, 394 insertions(+), 17 deletions(-)
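In the NVPTXISelLowering.cpp hunk below, marking these nodes Legal tells the
SelectionDAG legalizer to keep them intact so the NVPTXInstrInfo.td patterns
can select PTX instructions for them; at the default Expand action they would
instead be lowered to a libcall or an expanded node sequence. A minimal sketch
of the effect, modeled on the new math-intrins.ll test (the function name is
hypothetical):

  target triple = "nvptx64-nvidia-cuda"

  declare float @llvm.minnum.f32(float, float)

  ; With ISD::FMINNUM marked Legal, llc emits a single min.f32 here.
  define float @min_example(float %a, float %b) {
    %x = call float @llvm.minnum.f32(float %a, float %b)
    ret float %x
  }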
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 0ad59b0..52eaf6e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -279,6 +279,28 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setTargetDAGCombine(ISD::SHL);
   setTargetDAGCombine(ISD::SELECT);
 
+  // Library functions. These default to Expand, but we have instructions
+  // for them.
+  setOperationAction(ISD::FCEIL, MVT::f32, Legal);
+  setOperationAction(ISD::FCEIL, MVT::f64, Legal);
+  setOperationAction(ISD::FFLOOR, MVT::f32, Legal);
+  setOperationAction(ISD::FFLOOR, MVT::f64, Legal);
+  setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal);
+  setOperationAction(ISD::FNEARBYINT, MVT::f64, Legal);
+  setOperationAction(ISD::FRINT, MVT::f32, Legal);
+  setOperationAction(ISD::FRINT, MVT::f64, Legal);
+  setOperationAction(ISD::FROUND, MVT::f32, Legal);
+  setOperationAction(ISD::FROUND, MVT::f64, Legal);
+  setOperationAction(ISD::FTRUNC, MVT::f32, Legal);
+  setOperationAction(ISD::FTRUNC, MVT::f64, Legal);
+  setOperationAction(ISD::FMINNUM, MVT::f32, Legal);
+  setOperationAction(ISD::FMINNUM, MVT::f64, Legal);
+  setOperationAction(ISD::FMAXNUM, MVT::f32, Legal);
+  setOperationAction(ISD::FMAXNUM, MVT::f64, Legal);
+
+  // No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
+  // No FPOW or FREM in PTX.
+
   // Now deduce the information based on the above mentioned
   // actions
   computeRegisterProperties(STI.getRegisterInfo());
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index a59235ba..8418e0e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -207,15 +207,63 @@ multiclass ADD_SUB_INT_32<string OpcStr, SDNode OpNode> {
 }
 
 // Template for instructions which take three fp64 or fp32 args. The
-// instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
+// instructions are named "<OpcStr>.f<Width>" (e.g. "min.f64").
 //
 // Also defines ftz (flush subnormal inputs and results to sign-preserving
 // zero) variants for fp32 functions.
+//
+// This multiclass should be used for nodes that cannot be folded into FMAs.
+// For nodes that can be folded into FMAs (i.e. adds and muls), use
+// F3_fma_component.
 multiclass F3<string OpcStr, SDNode OpNode> {
   def f64rr :
     NVPTXInst<(outs Float64Regs:$dst),
               (ins Float64Regs:$a, Float64Regs:$b),
               !strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
+              [(set Float64Regs:$dst, (OpNode Float64Regs:$a, Float64Regs:$b))]>;
+  def f64ri :
+    NVPTXInst<(outs Float64Regs:$dst),
+              (ins Float64Regs:$a, f64imm:$b),
+              !strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
+              [(set Float64Regs:$dst, (OpNode Float64Regs:$a, fpimm:$b))]>;
+  def f32rr_ftz :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b),
+              !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b;"),
+              [(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>,
+              Requires<[doF32FTZ]>;
+  def f32ri_ftz :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, f32imm:$b),
+              !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b;"),
+              [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>,
+              Requires<[doF32FTZ]>;
+  def f32rr :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b),
+              !strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
+              [(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>;
+  def f32ri :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, f32imm:$b),
+              !strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
+              [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>;
+}
+
+// Template for instructions which take three fp64 or fp32 args. The
+// instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
+//
+// Also defines ftz (flush subnormal inputs and results to sign-preserving
+// zero) variants for fp32 functions.
+//
+// This multiclass should be used for nodes that can be folded to make fma ops.
+// In this case, we use the ".rn" variant when FMA is disabled, as this behaves
+// just like the non ".rn" op, but prevents ptxas from creating FMAs.
+multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
+  def f64rr :
+    NVPTXInst<(outs Float64Regs:$dst),
+              (ins Float64Regs:$a, Float64Regs:$b),
+              !strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
               [(set Float64Regs:$dst, (OpNode Float64Regs:$a, Float64Regs:$b))]>,
               Requires<[allowFMA]>;
   def f64ri :
@@ -248,41 +296,39 @@ multiclass F3<string OpcStr, SDNode OpNode> {
               !strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>,
               Requires<[allowFMA]>;
-}
-
-// Same as F3, but defines ".rn" variants (round to nearest even).
-multiclass F3_rn<string OpcStr, SDNode OpNode> {
-  def f64rr :
+
+  // These have strange names so we don't perturb existing mir tests.
+  def _rnf64rr :
     NVPTXInst<(outs Float64Regs:$dst),
               (ins Float64Regs:$a, Float64Regs:$b),
               !strconcat(OpcStr, ".rn.f64 \t$dst, $a, $b;"),
               [(set Float64Regs:$dst, (OpNode Float64Regs:$a, Float64Regs:$b))]>,
               Requires<[noFMA]>;
-  def f64ri :
+  def _rnf64ri :
     NVPTXInst<(outs Float64Regs:$dst),
               (ins Float64Regs:$a, f64imm:$b),
               !strconcat(OpcStr, ".rn.f64 \t$dst, $a, $b;"),
               [(set Float64Regs:$dst, (OpNode Float64Regs:$a, fpimm:$b))]>,
               Requires<[noFMA]>;
-  def f32rr_ftz :
+  def _rnf32rr_ftz :
     NVPTXInst<(outs Float32Regs:$dst),
               (ins Float32Regs:$a, Float32Regs:$b),
               !strconcat(OpcStr, ".rn.ftz.f32 \t$dst, $a, $b;"),
               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>,
               Requires<[noFMA, doF32FTZ]>;
-  def f32ri_ftz :
+  def _rnf32ri_ftz :
     NVPTXInst<(outs Float32Regs:$dst),
               (ins Float32Regs:$a, f32imm:$b),
               !strconcat(OpcStr, ".rn.ftz.f32 \t$dst, $a, $b;"),
               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>,
               Requires<[noFMA, doF32FTZ]>;
-  def f32rr :
+  def _rnf32rr :
     NVPTXInst<(outs Float32Regs:$dst),
               (ins Float32Regs:$a, Float32Regs:$b),
               !strconcat(OpcStr, ".rn.f32 \t$dst, $a, $b;"),
               [(set Float32Regs:$dst, (OpNode Float32Regs:$a, Float32Regs:$b))]>,
               Requires<[noFMA]>;
-  def f32ri :
+  def _rnf32ri :
     NVPTXInst<(outs Float32Regs:$dst),
               (ins Float32Regs:$a, f32imm:$b),
               !strconcat(OpcStr, ".rn.f32 \t$dst, $a, $b;"),
@@ -713,13 +759,12 @@ def DoubleConst1 : PatLeaf<(fpimm), [{
   N->getValueAPF().convertToDouble() == 1.0;
 }]>;
 
-defm FADD : F3<"add", fadd>;
-defm FSUB : F3<"sub", fsub>;
-defm FMUL : F3<"mul", fmul>;
+defm FADD : F3_fma_component<"add", fadd>;
+defm FSUB : F3_fma_component<"sub", fsub>;
+defm FMUL : F3_fma_component<"mul", fmul>;
 
-defm FADD_rn : F3_rn<"add", fadd>;
-defm FSUB_rn : F3_rn<"sub", fsub>;
-defm FMUL_rn : F3_rn<"mul", fmul>;
+defm FMIN : F3<"min", fminnum>;
+defm FMAX : F3<"max", fmaxnum>;
 
 defm FABS : F2<"abs", fabs>;
 defm FNEG : F2<"neg", fneg>;
@@ -2628,6 +2673,55 @@ def : Pat<(f64 (fpextend Float32Regs:$a)),
 def retflag : SDNode<"NVPTXISD::RET_FLAG", SDTNone,
                      [SDNPHasChain, SDNPOptInGlue]>;
 
+// fceil, ffloor, fround, ftrunc.
+
+def : Pat<(fceil Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRPI_FTZ)>, Requires<[doF32FTZ]>;
+def : Pat<(fceil Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRPI)>, Requires<[doNoF32FTZ]>;
+def : Pat<(fceil Float64Regs:$a),
+          (CVT_f64_f64 Float64Regs:$a, CvtRPI)>;
+
+def : Pat<(ffloor Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRMI_FTZ)>, Requires<[doF32FTZ]>;
+def : Pat<(ffloor Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRMI)>, Requires<[doNoF32FTZ]>;
+def : Pat<(ffloor Float64Regs:$a),
+          (CVT_f64_f64 Float64Regs:$a, CvtRMI)>;
+
+def : Pat<(fround Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
+def : Pat<(f32 (fround Float32Regs:$a)),
+          (CVT_f32_f32 Float32Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
+def : Pat<(f64 (fround Float64Regs:$a)),
+          (CVT_f64_f64 Float64Regs:$a, CvtRNI)>;
+
+def : Pat<(ftrunc Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRZI_FTZ)>, Requires<[doF32FTZ]>;
+def : Pat<(ftrunc Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRZI)>, Requires<[doNoF32FTZ]>;
+def : Pat<(ftrunc Float64Regs:$a),
+          (CVT_f64_f64 Float64Regs:$a, CvtRZI)>;
+
+// nearbyint and rint are implemented as rounding to nearest even. This isn't
+// strictly correct, because it causes us to ignore the rounding mode. But it
+// matches what CUDA's "libm" does.
+
+def : Pat<(fnearbyint Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
+def : Pat<(fnearbyint Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
+def : Pat<(fnearbyint Float64Regs:$a),
+          (CVT_f64_f64 Float64Regs:$a, CvtRNI)>;
+
+def : Pat<(frint Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
+def : Pat<(frint Float32Regs:$a),
+          (CVT_f32_f32 Float32Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
+def : Pat<(frint Float64Regs:$a),
+          (CVT_f64_f64 Float64Regs:$a, CvtRNI)>;
+
+
 //-----------------------------------
 // Control-flow
 //-----------------------------------
diff --git a/llvm/test/CodeGen/NVPTX/bug22322.ll b/llvm/test/CodeGen/NVPTX/bug22322.ll
index e073e36..0c4c30c 100644
--- a/llvm/test/CodeGen/NVPTX/bug22322.ll
+++ b/llvm/test/CodeGen/NVPTX/bug22322.ll
@@ -22,7 +22,7 @@ _ZL11compute_vecRK6float3jb.exit:
   %8 = icmp eq i32 %7, 0
   %9 = select i1 %8, float 0.000000e+00, float -1.000000e+00
   store float %9, float* %ret_vec.sroa.8.i, align 4
-; CHECK: setp.lt.f32 %p{{[0-9]+}}, %f{{[0-9]+}}, 0f00000000
+; CHECK: max.f32 %f{{[0-9]+}}, %f{{[0-9]+}}, 0f00000000
   %10 = fcmp olt float %9, 0.000000e+00
   %ret_vec.sroa.8.i.val = load float, float* %ret_vec.sroa.8.i, align 4
   %11 = select i1 %10, float 0.000000e+00, float %ret_vec.sroa.8.i.val
diff --git a/llvm/test/CodeGen/NVPTX/math-intrins.ll b/llvm/test/CodeGen/NVPTX/math-intrins.ll
new file mode 100644
index 0000000..de911d0
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/math-intrins.ll
@@ -0,0 +1,261 @@
+; RUN: llc < %s | FileCheck %s
+target triple = "nvptx64-nvidia-cuda"
+
+; Checks that llvm intrinsics for math functions are correctly lowered to PTX.
+
+declare float @llvm.ceil.f32(float) #0
+declare double @llvm.ceil.f64(double) #0
+declare float @llvm.floor.f32(float) #0
+declare double @llvm.floor.f64(double) #0
+declare float @llvm.round.f32(float) #0
+declare double @llvm.round.f64(double) #0
+declare float @llvm.nearbyint.f32(float) #0
+declare double @llvm.nearbyint.f64(double) #0
+declare float @llvm.rint.f32(float) #0
+declare double @llvm.rint.f64(double) #0
+declare float @llvm.trunc.f32(float) #0
+declare double @llvm.trunc.f64(double) #0
+declare float @llvm.fabs.f32(float) #0
+declare double @llvm.fabs.f64(double) #0
+declare float @llvm.minnum.f32(float, float) #0
+declare double @llvm.minnum.f64(double, double) #0
+declare float @llvm.maxnum.f32(float, float) #0
+declare double @llvm.maxnum.f64(double, double) #0
+
+; ---- ceil ----
+
+; CHECK-LABEL: ceil_float
+define float @ceil_float(float %a) {
+  ; CHECK: cvt.rpi.f32.f32
+  %b = call float @llvm.ceil.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: ceil_float_ftz
+define float @ceil_float_ftz(float %a) #1 {
+  ; CHECK: cvt.rpi.ftz.f32.f32
+  %b = call float @llvm.ceil.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: ceil_double
+define double @ceil_double(double %a) {
+  ; CHECK: cvt.rpi.f64.f64
+  %b = call double @llvm.ceil.f64(double %a)
+  ret double %b
+}
+
+; ---- floor ----
+
+; CHECK-LABEL: floor_float
+define float @floor_float(float %a) {
+  ; CHECK: cvt.rmi.f32.f32
+  %b = call float @llvm.floor.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: floor_float_ftz
+define float @floor_float_ftz(float %a) #1 {
+  ; CHECK: cvt.rmi.ftz.f32.f32
+  %b = call float @llvm.floor.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: floor_double
+define double @floor_double(double %a) {
+  ; CHECK: cvt.rmi.f64.f64
+  %b = call double @llvm.floor.f64(double %a)
+  ret double %b
+}
+
+; ---- round ----
+
+; CHECK-LABEL: round_float
+define float @round_float(float %a) {
+  ; CHECK: cvt.rni.f32.f32
+  %b = call float @llvm.round.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: round_float_ftz
+define float @round_float_ftz(float %a) #1 {
+  ; CHECK: cvt.rni.ftz.f32.f32
+  %b = call float @llvm.round.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: round_double
+define double @round_double(double %a) {
+  ; CHECK: cvt.rni.f64.f64
+  %b = call double @llvm.round.f64(double %a)
+  ret double %b
+}
+
+; ---- nearbyint ----
+
+; CHECK-LABEL: nearbyint_float
+define float @nearbyint_float(float %a) {
+  ; CHECK: cvt.rni.f32.f32
+  %b = call float @llvm.nearbyint.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: nearbyint_float_ftz
+define float @nearbyint_float_ftz(float %a) #1 {
+  ; CHECK: cvt.rni.ftz.f32.f32
+  %b = call float @llvm.nearbyint.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: nearbyint_double
+define double @nearbyint_double(double %a) {
+  ; CHECK: cvt.rni.f64.f64
+  %b = call double @llvm.nearbyint.f64(double %a)
+  ret double %b
+}
+
+; ---- rint ----
+
+; CHECK-LABEL: rint_float
+define float @rint_float(float %a) {
+  ; CHECK: cvt.rni.f32.f32
+  %b = call float @llvm.rint.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: rint_float_ftz
+define float @rint_float_ftz(float %a) #1 {
+  ; CHECK: cvt.rni.ftz.f32.f32
+  %b = call float @llvm.rint.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: rint_double
+define double @rint_double(double %a) {
+  ; CHECK: cvt.rni.f64.f64
+  %b = call double @llvm.rint.f64(double %a)
+  ret double %b
+}
+
+; ---- trunc ----
+
+; CHECK-LABEL: trunc_float
+define float @trunc_float(float %a) {
+  ; CHECK: cvt.rzi.f32.f32
+  %b = call float @llvm.trunc.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: trunc_float_ftz
+define float @trunc_float_ftz(float %a) #1 {
+  ; CHECK: cvt.rzi.ftz.f32.f32
+  %b = call float @llvm.trunc.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: trunc_double
+define double @trunc_double(double %a) {
+  ; CHECK: cvt.rzi.f64.f64
+  %b = call double @llvm.trunc.f64(double %a)
+  ret double %b
+}
+
+; ---- abs ----
+
+; CHECK-LABEL: abs_float
+define float @abs_float(float %a) {
+  ; CHECK: abs.f32
+  %b = call float @llvm.fabs.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: abs_float_ftz
+define float @abs_float_ftz(float %a) #1 {
+  ; CHECK: abs.ftz.f32
+  %b = call float @llvm.fabs.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: abs_double
+define double @abs_double(double %a) {
+  ; CHECK: abs.f64
+  %b = call double @llvm.fabs.f64(double %a)
+  ret double %b
+}
+
+; ---- min ----
+
+; CHECK-LABEL: min_float
+define float @min_float(float %a, float %b) {
+  ; CHECK: min.f32
+  %x = call float @llvm.minnum.f32(float %a, float %b)
+  ret float %x
+}
+
+; CHECK-LABEL: min_imm1
+define float @min_imm1(float %a) {
+  ; CHECK: min.f32
+  %x = call float @llvm.minnum.f32(float %a, float 0.0)
+  ret float %x
+}
+
+; CHECK-LABEL: min_imm2
+define float @min_imm2(float %a) {
+  ; CHECK: min.f32
+  %x = call float @llvm.minnum.f32(float 0.0, float %a)
+  ret float %x
+}
+
+; CHECK-LABEL: min_float_ftz
+define float @min_float_ftz(float %a, float %b) #1 {
+  ; CHECK: min.ftz.f32
+  %x = call float @llvm.minnum.f32(float %a, float %b)
+  ret float %x
+}
+
+; CHECK-LABEL: min_double
+define double @min_double(double %a, double %b) {
+  ; CHECK: min.f64
+  %x = call double @llvm.minnum.f64(double %a, double %b)
+  ret double %x
+}
+
+; ---- max ----
+
+; CHECK-LABEL: max_imm1
+define float @max_imm1(float %a) {
+  ; CHECK: max.f32
+  %x = call float @llvm.maxnum.f32(float %a, float 0.0)
+  ret float %x
+}
+
+; CHECK-LABEL: max_imm2
+define float @max_imm2(float %a) {
+  ; CHECK: max.f32
+  %x = call float @llvm.maxnum.f32(float 0.0, float %a)
+  ret float %x
+}
+
+; CHECK-LABEL: max_float
+define float @max_float(float %a, float %b) {
+  ; CHECK: max.f32
+  %x = call float @llvm.maxnum.f32(float %a, float %b)
+  ret float %x
+}
+
+; CHECK-LABEL: max_float_ftz
+define float @max_float_ftz(float %a, float %b) #1 {
+  ; CHECK: max.ftz.f32
+  %x = call float @llvm.maxnum.f32(float %a, float %b)
+  ret float %x
+}
+
+; CHECK-LABEL: max_double
+define double @max_double(double %a, double %b) {
+  ; CHECK: max.f64
+  %x = call double @llvm.maxnum.f64(double %a, double %b)
+  ret double %x
+}
+
+attributes #0 = { nounwind readnone }
+attributes #1 = { "nvptx-f32ftz" = "true" }
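The F3_fma_component comment above explains the ".rn" trick; here is a minimal
sketch of the case it guards against (the function name is hypothetical, and
the comments restate the multiclass comment rather than captured llc output):

  ; With FMA contraction enabled, this fmul/fadd pair may be fused into
  ; fma.rn.f32. With contraction disabled, FADD selects add.rn.f32 rather
  ; than add.f32: the two behave identically, but the explicit .rn rounding
  ; suffix also stops ptxas from fusing the add into an FMA later.
  define float @mul_add(float %a, float %b, float %c) {
    %m = fmul float %a, %b
    %s = fadd float %m, %c
    ret float %s
  }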
