diff options
author | Meredith Julian <35236176+mjulian31@users.noreply.github.com> | 2025-07-24 14:32:59 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-07-24 14:32:59 -0700 |
commit | be58069515e38d25a8e6ae5f1ef3b7b2e5eddbd1 (patch) | |
tree | eda93ea4797e070df2847351dcb17343cd475979 | |
parent | 581ba1cbf70bc5f89a095807c16f668a9b00ded9 (diff) | |
download | llvm-be58069515e38d25a8e6ae5f1ef3b7b2e5eddbd1.zip llvm-be58069515e38d25a8e6ae5f1ef3b7b2e5eddbd1.tar.gz llvm-be58069515e38d25a8e6ae5f1ef3b7b2e5eddbd1.tar.bz2 |
[LLVM][NVPTX] Upstream tanh intrinsic for libdevice (#149596)
Currently __nv_fast_tanhf() in libdevice maps to an nvvm intrinsic that
has not been upstreamed, which is causing issues when using the NVPTX
backend from upstream. Rather than upstreaming the intrinsic, we can
use the existing Intrinsic::tanh with the afn flag. This change
adds NVPTX backend support for ISD::FTANH, adds auto-upgrade for the old
llvm.nvvm.tanh.approx.f32 intrinsic to @llvm.tanh.f32 with the afn flag so that libdevice
works properly upstream, and adds a basic codegen test and a case to the
auto-upgrade test.
-rw-r--r-- | llvm/lib/IR/AutoUpgrade.cpp | 7 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 7 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 6 | ||||
-rw-r--r-- | llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll | 9 | ||||
-rw-r--r-- | llvm/test/CodeGen/NVPTX/tanhf.ll | 40 |
5 files changed, 66 insertions, 3 deletions
diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp index 28ed1e5..7159107 100644 --- a/llvm/lib/IR/AutoUpgrade.cpp +++ b/llvm/lib/IR/AutoUpgrade.cpp @@ -1450,6 +1450,7 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn, .Case("popc.ll", true) .Case("h2f", true) .Case("swap.lo.hi.b64", true) + .Case("tanh.approx.f32", true) .Default(false); if (Expand) { @@ -2543,6 +2544,12 @@ static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI, MDNode *MD = MDNode::get(Builder.getContext(), {}); LD->setMetadata(LLVMContext::MD_invariant_load, MD); return LD; + } else if (Name == "tanh.approx.f32") { + // nvvm.tanh.approx.f32 -> afn llvm.tanh.f32 + FastMathFlags FMF; + FMF.setApproxFunc(); + Rep = Builder.CreateUnaryIntrinsic(Intrinsic::tanh, CI->getArgOperand(0), + FMF); } else if (Name == "barrier0" || Name == "barrier.n" || Name == "bar.sync") { Value *Arg = Name.ends_with('0') ? Builder.getInt32(0) : CI->getArgOperand(0); diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index f2c2f46..ddcecc00 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -952,10 +952,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, // promoted to f32. v2f16 is expanded to f16, which is then promoted // to f32. 
for (const auto &Op : - {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS}) { + {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FTANH}) { setOperationAction(Op, MVT::f16, Promote); setOperationAction(Op, MVT::f32, Legal); - setOperationAction(Op, MVT::f64, Legal); + // only div/rem/sqrt are legal for f64 + if (Op == ISD::FDIV || Op == ISD::FREM || Op == ISD::FSQRT) { + setOperationAction(Op, MVT::f64, Legal); + } setOperationAction(Op, {MVT::v2f16, MVT::v2bf16, MVT::v2f32}, Expand); setOperationAction(Op, MVT::bf16, Promote); AddPromotedToType(Op, MVT::bf16, MVT::f32); diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index b5df4c6..442b900 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -1234,7 +1234,7 @@ defm FMA_F32 : FMA<F32RT, allow_ftz = true>; defm FMA_F32x2 : FMA<F32X2RT, allow_ftz = true, preds = [hasF32x2Instructions]>; defm FMA_F64 : FMA<F64RT, allow_ftz = false>; -// sin/cos +// sin/cos/tanh class UnaryOpAllowsApproxFn<SDPatternOperator operator> : PatFrag<(ops node:$A), @@ -1250,6 +1250,10 @@ def COS_APPROX_f32 : BasicFlagsNVPTXInst<(outs B32:$dst), (ins B32:$src), (ins FTZFlag:$ftz), "cos.approx$ftz.f32", [(set f32:$dst, (UnaryOpAllowsApproxFn<fcos> f32:$src))]>; +def TANH_APPROX_f32 : + BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "tanh.approx.f32", + [(set f32:$dst, (UnaryOpAllowsApproxFn<ftanh> f32:$src))]>, + Requires<[hasPTX<70>, hasSM<75>]>; //----------------------------------- // Bitwise operations diff --git a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll index a17f11a..362586a 100644 --- a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll +++ b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll @@ -17,6 +17,8 @@ declare float @llvm.nvvm.fabs.f(float) declare float @llvm.nvvm.fabs.ftz.f(float) declare double @llvm.nvvm.fabs.d(double) +declare float 
@llvm.nvvm.tanh.approx.f32(float) + declare i16 @llvm.nvvm.max.s(i16, i16) declare i32 @llvm.nvvm.max.i(i32, i32) declare i64 @llvm.nvvm.max.ll(i64, i64) @@ -138,6 +140,13 @@ define void @fabs(float %a, double %b) { ret void } +; CHECK-LABEL: @tanh +define void @tanh(float %a) { +; CHECK: call afn float @llvm.tanh.f32(float %a) + %r1 = call float @llvm.nvvm.tanh.approx.f32(float %a) + ret void +} + ; CHECK-LABEL: @min_max define void @min_max(i16 %a1, i16 %a2, i32 %b1, i32 %b2, i64 %c1, i64 %c2) { ; CHECK: [[maxs:%[a-zA-Z0-9.]+]] = icmp sge i16 %a1, %a2 diff --git a/llvm/test/CodeGen/NVPTX/tanhf.ll b/llvm/test/CodeGen/NVPTX/tanhf.ll new file mode 100644 index 0000000..6f4eb22 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/tanhf.ll @@ -0,0 +1,40 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc < %s -mcpu=sm_75 -mattr=+ptx70 | FileCheck %s +; RUN: %if ptxas-11.0 %{ llc < %s -mcpu=sm_75 -mattr=+ptx70 | %ptxas-verify -arch=sm_75 %} + +target triple = "nvptx64-nvidia-cuda" + +define float @test1(float %in) local_unnamed_addr { +; CHECK-LABEL: test1( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %r1, [test1_param_0]; +; CHECK-NEXT: tanh.approx.f32 %r2, %r1; +; CHECK-NEXT: st.param.b32 [func_retval0], %r2; +; CHECK-NEXT: ret; + %call = call afn float @llvm.tanh.f32(float %in) + ret float %call +} + +define half @test2(half %in) local_unnamed_addr { +; CHECK-LABEL: test2( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .b32 %r<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test2_param_0]; +; CHECK-NEXT: cvt.f32.f16 %r1, %rs1; +; CHECK-NEXT: tanh.approx.f32 %r2, %r1; +; CHECK-NEXT: cvt.rn.f16.f32 %rs2, %r2; +; CHECK-NEXT: st.param.b16 [func_retval0], %rs2; +; CHECK-NEXT: ret; + %call = call afn half @llvm.tanh.f16(half %in) + ret half %call +} + +declare float @llvm.tanh.f32(float) +declare half 
@llvm.tanh.f16(half) + |