aboutsummaryrefslogtreecommitdiff
path: root/llvm
diff options
context:
space:
mode:
authorHelena Kotas <hekotas@microsoft.com>2024-06-18 10:35:38 -0700
committerGitHub <noreply@github.com>2024-06-18 10:35:38 -0700
commit35a2b60973074ab7b9add53c58acafe166820551 (patch)
treedbcbec88e158be27ff36af29b21bfcb2321bfe69 /llvm
parent30efdce77e523454a6f1778827170f0e70ba8616 (diff)
downloadllvm-35a2b60973074ab7b9add53c58acafe166820551.zip
llvm-35a2b60973074ab7b9add53c58acafe166820551.tar.gz
llvm-35a2b60973074ab7b9add53c58acafe166820551.tar.bz2
[SPIRV][HLSL] Add lowering of `rsqrt` to SPIRV (#95849)
Add lowering of `rsqrt` to SPIRV. Fixes #88949
Diffstat (limited to 'llvm')
-rw-r--r--llvm/include/llvm/IR/IntrinsicsSPIRV.td1
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp22
-rw-r--r--llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll68
3 files changed, 91 insertions, 0 deletions
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 90f1267..683acf4 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -62,4 +62,5 @@ let TargetPrefix = "spv" in {
def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;
+ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index db83172..b9e5569 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -173,6 +173,9 @@ private:
bool selectFmix(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectRsqrt(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
int OpIdx) const;
void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
@@ -1315,6 +1318,23 @@ bool SPIRVInstructionSelector::selectFmix(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}
+bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+
+ assert(I.getNumOperands() == 3);
+ assert(I.getOperand(2).isReg());
+ MachineBasicBlock &BB = *I.getParent();
+
+ return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+ .addImm(GL::InverseSqrt)
+ .addUse(I.getOperand(2).getReg())
+ .constrainAllUses(TII, TRI, RBI);
+}
+
bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -1992,6 +2012,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectAny(ResVReg, ResType, I);
case Intrinsic::spv_lerp:
return selectFmix(ResVReg, ResType, I);
+ case Intrinsic::spv_rsqrt:
+ return selectRsqrt(ResVReg, ResType, I);
case Intrinsic::spv_lifetime_start:
case Intrinsic::spv_lifetime_end: {
unsigned Op = IID == Intrinsic::spv_lifetime_start ? SPIRV::OpLifetimeStart
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll
new file mode 100644
index 0000000..650b329
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll
@@ -0,0 +1,68 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#op_ext_glsl:]] = OpExtInstImport "GLSL.std.450"
+
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#float_64:]] = OpTypeFloat 64
+
+; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
+; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4
+; CHECK-DAG: %[[#vec4_float_64:]] = OpTypeVector %[[#float_64]] 4
+
+define noundef float @rsqrt_float(float noundef %a) {
+entry:
+; CHECK: %[[#float_32_arg:]] = OpFunctionParameter %[[#float_32]]
+; CHECK: %[[#]] = OpExtInst %[[#float_32]] %[[#op_ext_glsl]] InverseSqrt %[[#float_32_arg]]
+ %elt.rsqrt = call float @llvm.spv.rsqrt.f32(float %a)
+ ret float %elt.rsqrt
+}
+
+define noundef half @rsqrt_half(half noundef %a) {
+entry:
+; CHECK: %[[#float_16_arg:]] = OpFunctionParameter %[[#float_16]]
+; CHECK: %[[#]] = OpExtInst %[[#float_16]] %[[#op_ext_glsl]] InverseSqrt %[[#float_16_arg]]
+ %elt.rsqrt = call half @llvm.spv.rsqrt.f16(half %a)
+ ret half %elt.rsqrt
+}
+
+define noundef double @rsqrt_double(double noundef %a) {
+entry:
+; CHECK: %[[#float_64_arg:]] = OpFunctionParameter %[[#float_64]]
+; CHECK: %[[#]] = OpExtInst %[[#float_64]] %[[#op_ext_glsl]] InverseSqrt %[[#float_64_arg]]
+ %elt.rsqrt = call double @llvm.spv.rsqrt.f64(double %a)
+ ret double %elt.rsqrt
+}
+
+define noundef <4 x float> @rsqrt_float_vector(<4 x float> noundef %a) {
+entry:
+; CHECK: %[[#vec4_float_32_arg:]] = OpFunctionParameter %[[#vec4_float_32]]
+; CHECK: %[[#]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] InverseSqrt %[[#vec4_float_32_arg]]
+ %elt.rsqrt = call <4 x float> @llvm.spv.rsqrt.v4f32(<4 x float> %a)
+ ret <4 x float> %elt.rsqrt
+}
+
+define noundef <4 x half> @rsqrt_half_vector(<4 x half> noundef %a) {
+entry:
+; CHECK: %[[#vec4_float_16_arg:]] = OpFunctionParameter %[[#vec4_float_16]]
+; CHECK: %[[#]] = OpExtInst %[[#vec4_float_16]] %[[#op_ext_glsl]] InverseSqrt %[[#vec4_float_16_arg]]
+ %elt.rsqrt = call <4 x half> @llvm.spv.rsqrt.v4f16(<4 x half> %a)
+ ret <4 x half> %elt.rsqrt
+}
+
+define noundef <4 x double> @rsqrt_double_vector(<4 x double> noundef %a) {
+entry:
+; CHECK: %[[#vec4_float_64_arg:]] = OpFunctionParameter %[[#vec4_float_64]]
+; CHECK: %[[#]] = OpExtInst %[[#vec4_float_64]] %[[#op_ext_glsl]] InverseSqrt %[[#vec4_float_64_arg]]
+ %elt.rsqrt = call <4 x double> @llvm.spv.rsqrt.v4f64(<4 x double> %a)
+ ret <4 x double> %elt.rsqrt
+}
+
+declare half @llvm.spv.rsqrt.f16(half)
+declare float @llvm.spv.rsqrt.f32(float)
+declare double @llvm.spv.rsqrt.f64(double)
+
+declare <4 x float> @llvm.spv.rsqrt.v4f32(<4 x float>)
+declare <4 x half> @llvm.spv.rsqrt.v4f16(<4 x half>)
+declare <4 x double> @llvm.spv.rsqrt.v4f64(<4 x double>)