aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHelena Kotas <hekotas@microsoft.com>2024-06-18 10:35:38 -0700
committerGitHub <noreply@github.com>2024-06-18 10:35:38 -0700
commit35a2b60973074ab7b9add53c58acafe166820551 (patch)
treedbcbec88e158be27ff36af29b21bfcb2321bfe69
parent30efdce77e523454a6f1778827170f0e70ba8616 (diff)
downloadllvm-35a2b60973074ab7b9add53c58acafe166820551.zip
llvm-35a2b60973074ab7b9add53c58acafe166820551.tar.gz
llvm-35a2b60973074ab7b9add53c58acafe166820551.tar.bz2
[SPIRV][HLSL] Add lowering of `rsqrt` to SPIRV (#95849)
Add lowering of `rsqrt` to SPIRV. Fixes #88949
-rw-r--r--clang/lib/CodeGen/CGBuiltin.cpp4
-rw-r--r--clang/lib/CodeGen/CGHLSLRuntime.h1
-rw-r--r--clang/test/CodeGenHLSL/builtins/rsqrt.hlsl109
-rw-r--r--llvm/include/llvm/IR/IntrinsicsSPIRV.td1
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp22
-rw-r--r--llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll68
6 files changed, 164 insertions, 41 deletions
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 511e1fd..eb56bba 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18331,8 +18331,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
llvm_unreachable("rsqrt operand must have a float representation");
return Builder.CreateIntrinsic(
- /*ReturnType=*/Op0->getType(), Intrinsic::dx_rsqrt,
- ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt");
+ /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getRsqrtIntrinsic(),
+ ArrayRef<Value *>{Op0}, nullptr, "hlsl.rsqrt");
}
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
return EmitRuntimeCall(CGM.CreateRuntimeFunction(
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 0abe39d..4036ce7 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -75,6 +75,7 @@ public:
GENERATE_HLSL_INTRINSIC_FUNCTION(All, all)
GENERATE_HLSL_INTRINSIC_FUNCTION(Any, any)
GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
//===----------------------------------------------------------------------===//
diff --git a/clang/test/CodeGenHLSL/builtins/rsqrt.hlsl b/clang/test/CodeGenHLSL/builtins/rsqrt.hlsl
index c87a8c4..bb96ad8 100644
--- a/clang/test/CodeGenHLSL/builtins/rsqrt.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/rsqrt.hlsl
@@ -1,53 +1,84 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
-// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
-// RUN: --check-prefixes=CHECK,NATIVE_HALF
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN: --check-prefixes=CHECK,DXIL_CHECK,DXIL_NATIVE_HALF,NATIVE_HALF
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
-// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+// RUN: -o - | FileCheck %s --check-prefixes=CHECK,DXIL_CHECK,NO_HALF,DXIL_NO_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN: --check-prefixes=CHECK,SPIR_CHECK,NATIVE_HALF,SPIR_NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s --check-prefixes=CHECK,SPIR_CHECK,NO_HALF,SPIR_NO_HALF
-// NATIVE_HALF: define noundef half @
-// NATIVE_HALF: %dx.rsqrt = call half @llvm.dx.rsqrt.f16(
-// NATIVE_HALF: ret half %dx.rsqrt
-// NO_HALF: define noundef float @"?test_rsqrt_half@@YA$halff@$halff@@Z"(
-// NO_HALF: %dx.rsqrt = call float @llvm.dx.rsqrt.f32(
-// NO_HALF: ret float %dx.rsqrt
+// DXIL_NATIVE_HALF: define noundef half @
+// SPIR_NATIVE_HALF: define spir_func noundef half @
+// DXIL_NATIVE_HALF: %hlsl.rsqrt = call half @llvm.dx.rsqrt.f16(
+// SPIR_NATIVE_HALF: %hlsl.rsqrt = call half @llvm.spv.rsqrt.f16(
+// NATIVE_HALF: ret half %hlsl.rsqrt
+// DXIL_NO_HALF: define noundef float @
+// SPIR_NO_HALF: define spir_func noundef float @
+// DXIL_NO_HALF: %hlsl.rsqrt = call float @llvm.dx.rsqrt.f32(
+// SPIR_NO_HALF: %hlsl.rsqrt = call float @llvm.spv.rsqrt.f32(
+// NO_HALF: ret float %hlsl.rsqrt
half test_rsqrt_half(half p0) { return rsqrt(p0); }
-// NATIVE_HALF: define noundef <2 x half> @
-// NATIVE_HALF: %dx.rsqrt = call <2 x half> @llvm.dx.rsqrt.v2f16
-// NATIVE_HALF: ret <2 x half> %dx.rsqrt
-// NO_HALF: define noundef <2 x float> @
-// NO_HALF: %dx.rsqrt = call <2 x float> @llvm.dx.rsqrt.v2f32(
-// NO_HALF: ret <2 x float> %dx.rsqrt
+// DXIL_NATIVE_HALF: define noundef <2 x half> @
+// SPIR_NATIVE_HALF: define spir_func noundef <2 x half> @
+// DXIL_NATIVE_HALF: %hlsl.rsqrt = call <2 x half> @llvm.dx.rsqrt.v2f16
+// SPIR_NATIVE_HALF: %hlsl.rsqrt = call <2 x half> @llvm.spv.rsqrt.v2f16
+// NATIVE_HALF: ret <2 x half> %hlsl.rsqrt
+// DXIL_NO_HALF: define noundef <2 x float> @
+// SPIR_NO_HALF: define spir_func noundef <2 x float> @
+// DXIL_NO_HALF: %hlsl.rsqrt = call <2 x float> @llvm.dx.rsqrt.v2f32(
+// SPIR_NO_HALF: %hlsl.rsqrt = call <2 x float> @llvm.spv.rsqrt.v2f32(
+// NO_HALF: ret <2 x float> %hlsl.rsqrt
half2 test_rsqrt_half2(half2 p0) { return rsqrt(p0); }
-// NATIVE_HALF: define noundef <3 x half> @
-// NATIVE_HALF: %dx.rsqrt = call <3 x half> @llvm.dx.rsqrt.v3f16
-// NATIVE_HALF: ret <3 x half> %dx.rsqrt
-// NO_HALF: define noundef <3 x float> @
-// NO_HALF: %dx.rsqrt = call <3 x float> @llvm.dx.rsqrt.v3f32(
-// NO_HALF: ret <3 x float> %dx.rsqrt
+// DXIL_NATIVE_HALF: define noundef <3 x half> @
+// SPIR_NATIVE_HALF: define spir_func noundef <3 x half> @
+// DXIL_NATIVE_HALF: %hlsl.rsqrt = call <3 x half> @llvm.dx.rsqrt.v3f16
+// SPIR_NATIVE_HALF: %hlsl.rsqrt = call <3 x half> @llvm.spv.rsqrt.v3f16
+// NATIVE_HALF: ret <3 x half> %hlsl.rsqrt
+// DXIL_NO_HALF: define noundef <3 x float> @
+// SPIR_NO_HALF: define spir_func noundef <3 x float> @
+// DXIL_NO_HALF: %hlsl.rsqrt = call <3 x float> @llvm.dx.rsqrt.v3f32(
+// SPIR_NO_HALF: %hlsl.rsqrt = call <3 x float> @llvm.spv.rsqrt.v3f32(
+// NO_HALF: ret <3 x float> %hlsl.rsqrt
half3 test_rsqrt_half3(half3 p0) { return rsqrt(p0); }
-// NATIVE_HALF: define noundef <4 x half> @
-// NATIVE_HALF: %dx.rsqrt = call <4 x half> @llvm.dx.rsqrt.v4f16
-// NATIVE_HALF: ret <4 x half> %dx.rsqrt
-// NO_HALF: define noundef <4 x float> @
-// NO_HALF: %dx.rsqrt = call <4 x float> @llvm.dx.rsqrt.v4f32(
-// NO_HALF: ret <4 x float> %dx.rsqrt
+// DXIL_NATIVE_HALF: define noundef <4 x half> @
+// SPIR_NATIVE_HALF: define spir_func noundef <4 x half> @
+// DXIL_NATIVE_HALF: %hlsl.rsqrt = call <4 x half> @llvm.dx.rsqrt.v4f16
+// SPIR_NATIVE_HALF: %hlsl.rsqrt = call <4 x half> @llvm.spv.rsqrt.v4f16
+// NATIVE_HALF: ret <4 x half> %hlsl.rsqrt
+// DXIL_NO_HALF: define noundef <4 x float> @
+// SPIR_NO_HALF: define spir_func noundef <4 x float> @
+// DXIL_NO_HALF: %hlsl.rsqrt = call <4 x float> @llvm.dx.rsqrt.v4f32(
+// SPIR_NO_HALF: %hlsl.rsqrt = call <4 x float> @llvm.spv.rsqrt.v4f32(
+// NO_HALF: ret <4 x float> %hlsl.rsqrt
half4 test_rsqrt_half4(half4 p0) { return rsqrt(p0); }
-// CHECK: define noundef float @
-// CHECK: %dx.rsqrt = call float @llvm.dx.rsqrt.f32(
-// CHECK: ret float %dx.rsqrt
+// DXIL_CHECK: define noundef float @
+// SPIR_CHECK: define spir_func noundef float @
+// DXIL_CHECK: %hlsl.rsqrt = call float @llvm.dx.rsqrt.f32(
+// SPIR_CHECK: %hlsl.rsqrt = call float @llvm.spv.rsqrt.f32(
+// CHECK: ret float %hlsl.rsqrt
float test_rsqrt_float(float p0) { return rsqrt(p0); }
-// CHECK: define noundef <2 x float> @
-// CHECK: %dx.rsqrt = call <2 x float> @llvm.dx.rsqrt.v2f32
-// CHECK: ret <2 x float> %dx.rsqrt
+// DXIL_CHECK: define noundef <2 x float> @
+// SPIR_CHECK: define spir_func noundef <2 x float> @
+// DXIL_CHECK: %hlsl.rsqrt = call <2 x float> @llvm.dx.rsqrt.v2f32
+// SPIR_CHECK: %hlsl.rsqrt = call <2 x float> @llvm.spv.rsqrt.v2f32
+// CHECK: ret <2 x float> %hlsl.rsqrt
float2 test_rsqrt_float2(float2 p0) { return rsqrt(p0); }
-// CHECK: define noundef <3 x float> @
-// CHECK: %dx.rsqrt = call <3 x float> @llvm.dx.rsqrt.v3f32
-// CHECK: ret <3 x float> %dx.rsqrt
+// DXIL_CHECK: define noundef <3 x float> @
+// SPIR_CHECK: define spir_func noundef <3 x float> @
+// DXIL_CHECK: %hlsl.rsqrt = call <3 x float> @llvm.dx.rsqrt.v3f32
+// SPIR_CHECK: %hlsl.rsqrt = call <3 x float> @llvm.spv.rsqrt.v3f32
+// CHECK: ret <3 x float> %hlsl.rsqrt
float3 test_rsqrt_float3(float3 p0) { return rsqrt(p0); }
-// CHECK: define noundef <4 x float> @
-// CHECK: %dx.rsqrt = call <4 x float> @llvm.dx.rsqrt.v4f32
-// CHECK: ret <4 x float> %dx.rsqrt
+// DXIL_CHECK: define noundef <4 x float> @
+// SPIR_CHECK: define spir_func noundef <4 x float> @
+// DXIL_CHECK: %hlsl.rsqrt = call <4 x float> @llvm.dx.rsqrt.v4f32
+// SPIR_CHECK: %hlsl.rsqrt = call <4 x float> @llvm.spv.rsqrt.v4f32
+// CHECK: ret <4 x float> %hlsl.rsqrt
float4 test_rsqrt_float4(float4 p0) { return rsqrt(p0); }
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 90f1267..683acf4 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -62,4 +62,5 @@ let TargetPrefix = "spv" in {
def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;
+ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index db83172..b9e5569 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -173,6 +173,9 @@ private:
bool selectFmix(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectRsqrt(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
int OpIdx) const;
void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
@@ -1315,6 +1318,23 @@ bool SPIRVInstructionSelector::selectFmix(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}
+bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+
+ assert(I.getNumOperands() == 3);
+ assert(I.getOperand(2).isReg());
+ MachineBasicBlock &BB = *I.getParent();
+
+ return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+ .addImm(GL::InverseSqrt)
+ .addUse(I.getOperand(2).getReg())
+ .constrainAllUses(TII, TRI, RBI);
+}
+
bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -1992,6 +2012,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectAny(ResVReg, ResType, I);
case Intrinsic::spv_lerp:
return selectFmix(ResVReg, ResType, I);
+ case Intrinsic::spv_rsqrt:
+ return selectRsqrt(ResVReg, ResType, I);
case Intrinsic::spv_lifetime_start:
case Intrinsic::spv_lifetime_end: {
unsigned Op = IID == Intrinsic::spv_lifetime_start ? SPIRV::OpLifetimeStart
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll
new file mode 100644
index 0000000..650b329
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll
@@ -0,0 +1,68 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#op_ext_glsl:]] = OpExtInstImport "GLSL.std.450"
+
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#float_64:]] = OpTypeFloat 64
+
+; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
+; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4
+; CHECK-DAG: %[[#vec4_float_64:]] = OpTypeVector %[[#float_64]] 4
+
+define noundef float @rsqrt_float(float noundef %a) {
+entry:
+; CHECK: %[[#float_32_arg:]] = OpFunctionParameter %[[#float_32]]
+; CHECK: %[[#]] = OpExtInst %[[#float_32]] %[[#op_ext_glsl]] InverseSqrt %[[#float_32_arg]]
+ %elt.rsqrt = call float @llvm.spv.rsqrt.f32(float %a)
+ ret float %elt.rsqrt
+}
+
+define noundef half @rsqrt_half(half noundef %a) {
+entry:
+; CHECK: %[[#float_16_arg:]] = OpFunctionParameter %[[#float_16]]
+; CHECK: %[[#]] = OpExtInst %[[#float_16]] %[[#op_ext_glsl]] InverseSqrt %[[#float_16_arg]]
+ %elt.rsqrt = call half @llvm.spv.rsqrt.f16(half %a)
+ ret half %elt.rsqrt
+}
+
+define noundef double @rsqrt_double(double noundef %a) {
+entry:
+; CHECK: %[[#float_64_arg:]] = OpFunctionParameter %[[#float_64]]
+; CHECK: %[[#]] = OpExtInst %[[#float_64]] %[[#op_ext_glsl]] InverseSqrt %[[#float_64_arg]]
+ %elt.rsqrt = call double @llvm.spv.rsqrt.f64(double %a)
+ ret double %elt.rsqrt
+}
+
+define noundef <4 x float> @rsqrt_float_vector(<4 x float> noundef %a) {
+entry:
+; CHECK: %[[#vec4_float_32_arg:]] = OpFunctionParameter %[[#vec4_float_32]]
+; CHECK: %[[#]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] InverseSqrt %[[#vec4_float_32_arg]]
+ %elt.rsqrt = call <4 x float> @llvm.spv.rsqrt.v4f32(<4 x float> %a)
+ ret <4 x float> %elt.rsqrt
+}
+
+define noundef <4 x half> @rsqrt_half_vector(<4 x half> noundef %a) {
+entry:
+; CHECK: %[[#vec4_float_16_arg:]] = OpFunctionParameter %[[#vec4_float_16]]
+; CHECK: %[[#]] = OpExtInst %[[#vec4_float_16]] %[[#op_ext_glsl]] InverseSqrt %[[#vec4_float_16_arg]]
+ %elt.rsqrt = call <4 x half> @llvm.spv.rsqrt.v4f16(<4 x half> %a)
+ ret <4 x half> %elt.rsqrt
+}
+
+define noundef <4 x double> @rsqrt_double_vector(<4 x double> noundef %a) {
+entry:
+; CHECK: %[[#vec4_float_64_arg:]] = OpFunctionParameter %[[#vec4_float_64]]
+; CHECK: %[[#]] = OpExtInst %[[#vec4_float_64]] %[[#op_ext_glsl]] InverseSqrt %[[#vec4_float_64_arg]]
+ %elt.rsqrt = call <4 x double> @llvm.spv.rsqrt.v4f64(<4 x double> %a)
+ ret <4 x double> %elt.rsqrt
+}
+
+declare half @llvm.spv.rsqrt.f16(half)
+declare float @llvm.spv.rsqrt.f32(float)
+declare double @llvm.spv.rsqrt.f64(double)
+
+declare <4 x float> @llvm.spv.rsqrt.v4f32(<4 x float>)
+declare <4 x half> @llvm.spv.rsqrt.v4f16(<4 x half>)
+declare <4 x double> @llvm.spv.rsqrt.v4f64(<4 x double>)