aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib
diff options
context:
space:
mode:
authorViktoria Maximova <viktoria.maksimova@intel.com>2025-02-05 18:35:08 +0100
committerGitHub <noreply@github.com>2025-02-05 09:35:08 -0800
commit50a27ce88cb070b68da739c6ec6e7eb255601495 (patch)
treec17982a59bd5d251f70490095219d8c4ef263dcc /llvm/lib
parenta907008bcb8dcc093f8aa5c0450d92cd63473b81 (diff)
downloadllvm-50a27ce88cb070b68da739c6ec6e7eb255601495.zip
llvm-50a27ce88cb070b68da739c6ec6e7eb255601495.tar.gz
llvm-50a27ce88cb070b68da739c6ec6e7eb255601495.tar.bz2
[SPIR-V] Support all the instructions of SPV_KHR_integer_dot_product (#123792)
This continues the work on dot product instructions already started in 3cdac06. This change adds support for all OpenCL integer dot product builtins under `cl_khr_integer_dot_product` extension, namely: ``` * dot * dot_acc_sat * dot_4x8packed_(uu/ss/su/us)_(u)int * dot_acc_sat_4x8packed_(uu/ss/su/us)_(u)int ```
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp97
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVBuiltins.td47
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVInstrInfo.td14
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp4
4 files changed, 151 insertions, 11 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 95fa7bc..08ee94a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -132,6 +132,15 @@ struct ImageQueryBuiltin {
#define GET_ImageQueryBuiltins_DECL
#define GET_ImageQueryBuiltins_IMPL
+struct IntegerDotProductBuiltin {
+ StringRef Name;
+ uint32_t Opcode;
+ bool IsSwapReq;
+};
+
+#define GET_IntegerDotProductBuiltins_DECL
+#define GET_IntegerDotProductBuiltins_IMPL
+
struct ConvertBuiltin {
StringRef Name;
InstructionSet::InstructionSet Set;
@@ -1579,20 +1588,84 @@ static bool generateCastToPtrInst(const SPIRV::IncomingCall *Call,
return true;
}
-static bool generateDotOrFMulInst(const SPIRV::IncomingCall *Call,
+static bool generateDotOrFMulInst(const StringRef DemangledCall,
+ const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
if (Call->isSpirvOp())
return buildOpFromWrapper(MIRBuilder, SPIRV::OpDot, Call,
GR->getSPIRVTypeID(Call->ReturnType));
- unsigned Opcode = GR->getSPIRVTypeForVReg(Call->Arguments[0])->getOpcode();
- bool IsVec = Opcode == SPIRV::OpTypeVector;
+
+ bool IsVec = GR->getSPIRVTypeForVReg(Call->Arguments[0])->getOpcode() ==
+ SPIRV::OpTypeVector;
// Use OpDot only in case of vector args and OpFMul in case of scalar args.
- MIRBuilder.buildInstr(IsVec ? SPIRV::OpDot : SPIRV::OpFMulS)
- .addDef(Call->ReturnRegister)
- .addUse(GR->getSPIRVTypeID(Call->ReturnType))
- .addUse(Call->Arguments[0])
- .addUse(Call->Arguments[1]);
+ uint32_t OC = IsVec ? SPIRV::OpDot : SPIRV::OpFMulS;
+ bool IsSwapReq = false;
+
+ const auto *ST =
+ static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
+ if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt) &&
+ (ST->canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product) ||
+ ST->isAtLeastSPIRVVer(VersionTuple(1, 6)))) {
+ const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+ const SPIRV::IntegerDotProductBuiltin *IntDot =
+ SPIRV::lookupIntegerDotProductBuiltin(Builtin->Name);
+ if (IntDot) {
+ OC = IntDot->Opcode;
+ IsSwapReq = IntDot->IsSwapReq;
+ } else if (IsVec) {
+ // Handling "dot" and "dot_acc_sat" builtins which use vectors of
+ // integers.
+ LLVMContext &Ctx = MIRBuilder.getContext();
+ SmallVector<StringRef, 10> TypeStrs;
+ SPIRV::parseBuiltinTypeStr(TypeStrs, DemangledCall, Ctx);
+ bool IsFirstSigned = TypeStrs[0].trim()[0] != 'u';
+ bool IsSecondSigned = TypeStrs[1].trim()[0] != 'u';
+
+ if (Call->BuiltinName == "dot") {
+ if (IsFirstSigned && IsSecondSigned)
+ OC = SPIRV::OpSDot;
+ else if (!IsFirstSigned && !IsSecondSigned)
+ OC = SPIRV::OpUDot;
+ else {
+ OC = SPIRV::OpSUDot;
+ if (!IsFirstSigned)
+ IsSwapReq = true;
+ }
+ } else if (Call->BuiltinName == "dot_acc_sat") {
+ if (IsFirstSigned && IsSecondSigned)
+ OC = SPIRV::OpSDotAccSat;
+ else if (!IsFirstSigned && !IsSecondSigned)
+ OC = SPIRV::OpUDotAccSat;
+ else {
+ OC = SPIRV::OpSUDotAccSat;
+ if (!IsFirstSigned)
+ IsSwapReq = true;
+ }
+ }
+ }
+ }
+
+ MachineInstrBuilder MIB = MIRBuilder.buildInstr(OC)
+ .addDef(Call->ReturnRegister)
+ .addUse(GR->getSPIRVTypeID(Call->ReturnType));
+
+ if (IsSwapReq) {
+ MIB.addUse(Call->Arguments[1]);
+ MIB.addUse(Call->Arguments[0]);
+ // needed for dot_acc_sat* builtins
+ for (size_t i = 2; i < Call->Arguments.size(); ++i)
+ MIB.addUse(Call->Arguments[i]);
+ } else {
+ for (size_t i = 0; i < Call->Arguments.size(); ++i)
+ MIB.addUse(Call->Arguments[i]);
+ }
+
+ // Add Packed Vector Format for Integer dot product builtins if arguments are
+ // scalar
+ if (!IsVec && OC != SPIRV::OpFMulS)
+ MIB.addImm(0);
+
return true;
}
@@ -2576,6 +2649,11 @@ mapBuiltinToOpcode(const StringRef DemangledCall,
if (const auto *R = SPIRV::lookupGroupUniformBuiltin(Call->Builtin->Name))
return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
break;
+ case SPIRV::IntegerDot:
+ if (const auto *R =
+ SPIRV::lookupIntegerDotProductBuiltin(Call->Builtin->Name))
+ return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
+ break;
case SPIRV::WriteImage:
return std::make_tuple(Call->Builtin->Group, SPIRV::OpImageWrite, 0);
case SPIRV::Select:
@@ -2635,7 +2713,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
case SPIRV::CastToPtr:
return generateCastToPtrInst(Call.get(), MIRBuilder);
case SPIRV::Dot:
- return generateDotOrFMulInst(Call.get(), MIRBuilder, GR);
+ case SPIRV::IntegerDot:
+ return generateDotOrFMulInst(DemangledCall, Call.get(), MIRBuilder, GR);
case SPIRV::Wave:
return generateWaveInst(Call.get(), MIRBuilder, GR);
case SPIRV::ICarryBorrow:
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index af3901c..8125e12 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -42,6 +42,7 @@ def Variable : BuiltinGroup;
def Atomic : BuiltinGroup;
def Barrier : BuiltinGroup;
def Dot : BuiltinGroup;
+def IntegerDot : BuiltinGroup;
def Wave : BuiltinGroup;
def GetQuery : BuiltinGroup;
def ImageSizeQuery : BuiltinGroup;
@@ -101,6 +102,8 @@ def lookupBuiltin : SearchIndex {
// Dot builtin record:
def : DemangledBuiltin<"dot", OpenCL_std, Dot, 2, 2>;
def : DemangledBuiltin<"__spirv_Dot", OpenCL_std, Dot, 2, 2>;
+def : DemangledBuiltin<"dot_acc_sat", OpenCL_std, IntegerDot, 3, 3>;
+def : DemangledBuiltin<"__spirv_DotAccSat", OpenCL_std, IntegerDot, 3, 3>;
// Image builtin records:
def : DemangledBuiltin<"read_imagei", OpenCL_std, ReadImage, 2, 4>;
@@ -1715,3 +1718,47 @@ class CLMemoryFenceFlags<bits<32> value> {
def CLK_LOCAL_MEM_FENCE : CLMemoryFenceFlags<0x1>;
def CLK_GLOBAL_MEM_FENCE : CLMemoryFenceFlags<0x2>;
def CLK_IMAGE_MEM_FENCE : CLMemoryFenceFlags<0x4>;
+
+//===----------------------------------------------------------------------===//
+// Class defining dot builtins that should be translated into a
+// SPIR-V instruction using SPIR-V 1.6 or SPV_KHR_integer_dot_product extension.
+//
+// name is the demangled name of the given builtin.
+// opcode specifies the SPIR-V operation code of the generated instruction.
+// isSwapRequired specifies if the operands need to be swapped (the SPIR-V extension
+// has only one instruction for arguments of different signedness).
+//===----------------------------------------------------------------------===//
+class IntegerDotProductBuiltin<string name, Op operation> {
+ string Name = name;
+ Op Opcode = operation;
+ bit IsSwapReq = !not(!eq(!find(name, "_us"), -1));
+}
+
+// Table gathering all the integer dot product builtins.
+def IntegerDotProductBuiltins : GenericTable {
+ let FilterClass = "IntegerDotProductBuiltin";
+ let Fields = ["Name", "Opcode", "IsSwapReq"];
+}
+
+// Function to lookup group builtins by their name and set.
+def lookupIntegerDotProductBuiltin : SearchIndex {
+ let Table = IntegerDotProductBuiltins;
+ let Key = ["Name"];
+}
+
+// Multiclass used to define incoming builtin records for the SPV_KHR_integer_dot_product extension.
+multiclass DemangledIntegerDotProductBuiltin<string name, bits<8> minNumArgs, bits<8> maxNumArgs, Op operation> {
+ def : DemangledBuiltin<!strconcat("dot", name), OpenCL_std, IntegerDot, minNumArgs, maxNumArgs>;
+ def : IntegerDotProductBuiltin<!strconcat("dot", name), operation>;
+}
+
+// cl_khr_integer_dot_product
+defm : DemangledIntegerDotProductBuiltin<"_4x8packed_uu_uint", 2, 3, OpUDot>;
+defm : DemangledIntegerDotProductBuiltin<"_4x8packed_ss_int", 2, 3, OpSDot>;
+defm : DemangledIntegerDotProductBuiltin<"_4x8packed_us_int", 2, 3, OpSUDot>;
+defm : DemangledIntegerDotProductBuiltin<"_4x8packed_su_int", 2, 3, OpSUDot>;
+
+defm : DemangledIntegerDotProductBuiltin<"_acc_sat_4x8packed_uu_uint", 3, 4, OpUDotAccSat>;
+defm : DemangledIntegerDotProductBuiltin<"_acc_sat_4x8packed_ss_int", 3, 4, OpSDotAccSat>;
+defm : DemangledIntegerDotProductBuiltin<"_acc_sat_4x8packed_us_int", 3, 4, OpSUDotAccSat>;
+defm : DemangledIntegerDotProductBuiltin<"_acc_sat_4x8packed_su_int", 3, 4, OpSUDotAccSat>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 1bc35c6..981e224 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -530,8 +530,18 @@ defm OpISubBorrow: BinOpTypedGen<"OpISubBorrow", 150, subc, 0, 1>;
def OpUMulExtended: BinOp<"OpUMulExtended", 151>;
def OpSMulExtended: BinOp<"OpSMulExtended", 152>;
-def OpSDot: BinOp<"OpSDot", 4450>;
-def OpUDot: BinOp<"OpUDot", 4451>;
+def OpSDot: Op<4450, (outs ID:$res), (ins TYPE:$type, ID:$vec1, ID:$vec2, variable_ops),
+ "$res = OpSDot $type $vec1 $vec2">;
+def OpUDot: Op<4451, (outs ID:$res), (ins TYPE:$type, ID:$vec1, ID:$vec2, variable_ops),
+ "$res = OpUDot $type $vec1 $vec2">;
+def OpSUDot: Op<4452, (outs ID:$res), (ins TYPE:$type, ID:$vec1, ID:$vec2, variable_ops),
+ "$res = OpSUDot $type $vec1 $vec2">;
+def OpSDotAccSat: Op<4453, (outs ID:$res), (ins TYPE:$type, ID:$vec1, ID:$vec2, ID:$acc, variable_ops),
+ "$res = OpSDotAccSat $type $vec1 $vec2 $acc">;
+def OpUDotAccSat: Op<4454, (outs ID:$res), (ins TYPE:$type, ID:$vec1, ID:$vec2, ID:$acc, variable_ops),
+ "$res = OpUDotAccSat $type $vec1 $vec2 $acc">;
+def OpSUDotAccSat: Op<4455, (outs ID:$res), (ins TYPE:$type, ID:$vec1, ID:$vec2, ID:$acc, variable_ops),
+ "$res = OpSUDotAccSat $type $vec1 $vec2 $acc">;
// 3.42.14 Bit Instructions
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index d3afaf42..a7a5ece 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1692,6 +1692,10 @@ void addInstrRequirements(const MachineInstr &MI,
break;
case SPIRV::OpSDot:
case SPIRV::OpUDot:
+ case SPIRV::OpSUDot:
+ case SPIRV::OpSDotAccSat:
+ case SPIRV::OpUDotAccSat:
+ case SPIRV::OpSUDotAccSat:
AddDotProductRequirements(MI, Reqs, ST);
break;
case SPIRV::OpImageRead: {