diff options
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 77 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVBuiltins.td | 13 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 32 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 8 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 16 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 9 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td | 2 | ||||
-rw-r--r-- | llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll | 54 | ||||
-rw-r--r-- | llvm/test/CodeGen/SPIRV/transcoding/OpPtrCastToGeneric.ll | 30 |
11 files changed, 232 insertions, 13 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index c14e509..f5f3607 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -558,16 +558,21 @@ static Register buildMemSemanticsReg(Register SemanticsRegister, static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode, const SPIRV::IncomingCall *Call, - Register TypeReg = Register(0)) { + Register TypeReg, + ArrayRef<uint32_t> ImmArgs = {}) { MachineRegisterInfo *MRI = MIRBuilder.getMRI(); auto MIB = MIRBuilder.buildInstr(Opcode); if (TypeReg.isValid()) MIB.addDef(Call->ReturnRegister).addUse(TypeReg); - for (Register ArgReg : Call->Arguments) { + unsigned Sz = Call->Arguments.size() - ImmArgs.size(); + for (unsigned i = 0; i < Sz; ++i) { + Register ArgReg = Call->Arguments[i]; if (!MRI->getRegClassOrNull(ArgReg)) MRI->setRegClass(ArgReg, &SPIRV::IDRegClass); MIB.addUse(ArgReg); } + for (uint32_t ImmArg : ImmArgs) + MIB.addImm(ImmArg); return true; } @@ -575,7 +580,7 @@ static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode, static bool buildAtomicInitInst(const SPIRV::IncomingCall *Call, MachineIRBuilder &MIRBuilder) { if (Call->isSpirvOp()) - return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call); + return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call, Register(0)); assert(Call->Arguments.size() == 2 && "Need 2 arguments for atomic init translation"); @@ -633,7 +638,7 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call, MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) { if (Call->isSpirvOp()) - return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call); + return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0)); Register ScopeRegister = buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR); @@ -870,7 +875,7 @@ static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode, MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) { if (Call->isSpirvOp()) - return buildOpFromWrapper(MIRBuilder, Opcode, Call); + return buildOpFromWrapper(MIRBuilder, Opcode, Call, Register(0)); MachineRegisterInfo *MRI = MIRBuilder.getMRI(); unsigned MemFlags = getIConstVal(Call->Arguments[0], MRI); @@ -1824,6 +1829,45 @@ static bool generateSelectInst(const SPIRV::IncomingCall *Call, return true; } +static bool generateConstructInst(const SPIRV::IncomingCall *Call, + MachineIRBuilder &MIRBuilder, + SPIRVGlobalRegistry *GR) { + return buildOpFromWrapper(MIRBuilder, SPIRV::OpCompositeConstruct, Call, + GR->getSPIRVTypeID(Call->ReturnType)); +} + +static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call, + MachineIRBuilder &MIRBuilder, + SPIRVGlobalRegistry *GR) { + const SPIRV::DemangledBuiltin *Builtin = Call->Builtin; + unsigned Opcode = + SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode; + bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR; + unsigned ArgSz = Call->Arguments.size(); + unsigned LiteralIdx = 0; + if (Opcode == SPIRV::OpCooperativeMatrixLoadKHR && ArgSz > 3) + LiteralIdx = 3; + else if (Opcode == SPIRV::OpCooperativeMatrixStoreKHR && ArgSz > 4) + LiteralIdx = 4; + SmallVector<uint32_t, 1> ImmArgs; + MachineRegisterInfo *MRI = MIRBuilder.getMRI(); + if (LiteralIdx > 0) + ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[LiteralIdx], MRI)); + Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType); + if (Opcode == SPIRV::OpCooperativeMatrixLengthKHR) { + SPIRVType *CoopMatrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]); + if (!CoopMatrType) + report_fatal_error("Can't find a register's type definition"); + MIRBuilder.buildInstr(Opcode) + .addDef(Call->ReturnRegister) + .addUse(TypeReg) + .addUse(CoopMatrType->getOperand(0).getReg()); + return true; + } + return buildOpFromWrapper(MIRBuilder, Opcode, Call, + IsSet ? TypeReg : Register(0), ImmArgs); +} + static bool generateSpecConstantInst(const SPIRV::IncomingCall *Call, MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) { @@ -2382,6 +2426,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall, return generateSampleImageInst(DemangledCall, Call.get(), MIRBuilder, GR); case SPIRV::Select: return generateSelectInst(Call.get(), MIRBuilder); + case SPIRV::Construct: + return generateConstructInst(Call.get(), MIRBuilder, GR); case SPIRV::SpecConstant: return generateSpecConstantInst(Call.get(), MIRBuilder, GR); case SPIRV::Enqueue: @@ -2400,6 +2446,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall, return generateGroupUniformInst(Call.get(), MIRBuilder, GR); case SPIRV::KernelClock: return generateKernelClockInst(Call.get(), MIRBuilder, GR); + case SPIRV::CoopMatr: + return generateCoopMatrInst(Call.get(), MIRBuilder, GR); } return false; } @@ -2524,6 +2572,22 @@ static SPIRVType *getPipeType(const TargetExtType *ExtensionType, ExtensionType->getIntParameter(0))); } +static SPIRVType *getCoopMatrType(const TargetExtType *ExtensionType, + MachineIRBuilder &MIRBuilder, + SPIRVGlobalRegistry *GR) { + assert(ExtensionType->getNumIntParameters() == 4 && + "Invalid number of parameters for SPIR-V coop matrices builtin!"); + assert(ExtensionType->getNumTypeParameters() == 1 && + "SPIR-V coop matrices builtin type must have a type parameter!"); + const SPIRVType *ElemType = + GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder); + // Create or get an existing type from GlobalRegistry. + return GR->getOrCreateOpTypeCoopMatr( + MIRBuilder, ExtensionType, ElemType, ExtensionType->getIntParameter(0), + ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2), + ExtensionType->getIntParameter(3)); +} + static SPIRVType * getImageType(const TargetExtType *ExtensionType, const SPIRV::AccessQualifier::AccessQualifier Qualifier, @@ -2654,6 +2718,9 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType, case SPIRV::OpTypeSampledImage: TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR); break; + case SPIRV::OpTypeCooperativeMatrixKHR: + TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR); + break; default: TargetType = getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR); diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td index 2b8e6d8..4bd1104 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td @@ -60,6 +60,8 @@ def AtomicFloating : BuiltinGroup; def GroupUniform : BuiltinGroup; def KernelClock : BuiltinGroup; def CastToPtr : BuiltinGroup; +def Construct : BuiltinGroup; +def CoopMatr : BuiltinGroup; //===----------------------------------------------------------------------===// // Class defining a demangled builtin record. The information in the record @@ -114,6 +116,9 @@ def : DemangledBuiltin<"__spirv_ImageSampleExplicitLod", OpenCL_std, SampleImage // Select builtin record: def : DemangledBuiltin<"__spirv_Select", OpenCL_std, Select, 3, 3>; +// Composite Construct builtin record: +def : DemangledBuiltin<"__spirv_CompositeConstruct", OpenCL_std, Construct, 1, 0>; + //===----------------------------------------------------------------------===// // Class defining an extended builtin record used for lowering into an // OpExtInst instruction. @@ -608,6 +613,12 @@ defm : DemangledNativeBuiltin<"__spirv_OpGenericCastToPtrExplicit_ToGlobal", Ope defm : DemangledNativeBuiltin<"__spirv_OpGenericCastToPtrExplicit_ToLocal", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>; defm : DemangledNativeBuiltin<"__spirv_OpGenericCastToPtrExplicit_ToPrivate", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>; +// Cooperative Matrix builtin records: +defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLoadKHR", OpenCL_std, CoopMatr, 2, 0, OpCooperativeMatrixLoadKHR>; +defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixStoreKHR", OpenCL_std, CoopMatr, 3, 0, OpCooperativeMatrixStoreKHR>; +defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixMulAddKHR", OpenCL_std, CoopMatr, 3, 0, OpCooperativeMatrixMulAddKHR>; +defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLengthKHR", OpenCL_std, CoopMatr, 1, 1, OpCooperativeMatrixLengthKHR>; + //===----------------------------------------------------------------------===// // Class defining a work/sub group builtin that should be translated into a // SPIR-V instruction using the defined properties. @@ -1436,7 +1447,7 @@ def : BuiltinType<"spirv.DeviceEvent", OpTypeDeviceEvent>; def : BuiltinType<"spirv.Image", OpTypeImage>; def : BuiltinType<"spirv.SampledImage", OpTypeSampledImage>; def : BuiltinType<"spirv.Pipe", OpTypePipe>; - +def : BuiltinType<"spirv.CooperativeMatrixKHR", OpTypeCooperativeMatrixKHR>; //===----------------------------------------------------------------------===// // Class matching an OpenCL builtin type name to an equivalent SPIR-V diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index 75aa182..c7c244c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -66,6 +66,8 @@ static const std::map<std::string, SPIRV::Extension::Extension> SPIRV::Extension::Extension::SPV_INTEL_function_pointers}, {"SPV_KHR_shader_clock", SPIRV::Extension::Extension::SPV_KHR_shader_clock}, + {"SPV_KHR_cooperative_matrix", + SPIRV::Extension::Extension::SPV_KHR_cooperative_matrix}, }; bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName, diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index b22d2a04..b8710d2 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -1080,12 +1080,14 @@ bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const { return IntType && IntType->getOperand(2).getImm() != 0; } +SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) { + return PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer + ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg()) + : nullptr; +} + unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) { - SPIRVType *PtrType = getSPIRVTypeForVReg(PtrReg); - SPIRVType *ElemType = - PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer - ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg()) - : nullptr; + SPIRVType *ElemType = getPointeeType(getSPIRVTypeForVReg(PtrReg)); return ElemType ? ElemType->getOpcode() : 0; } @@ -1189,6 +1191,26 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage( .addUse(getSPIRVTypeID(ImageType)); } +SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr( + MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType, + const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns, + uint32_t Use) { + Register ResVReg = DT.find(ExtensionType, &MIRBuilder.getMF()); + if (ResVReg.isValid()) + return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg); + ResVReg = createTypeVReg(MIRBuilder); + SPIRVType *SpirvTy = + MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR) + .addDef(ResVReg) + .addUse(getSPIRVTypeID(ElemType)) + .addUse(buildConstantInt(Scope, MIRBuilder, nullptr, true)) + .addUse(buildConstantInt(Rows, MIRBuilder, nullptr, true)) + .addUse(buildConstantInt(Columns, MIRBuilder, nullptr, true)) + .addUse(buildConstantInt(Use, MIRBuilder, nullptr, true)); + DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg); + return SpirvTy; +} + SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode( const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) { Register ResVReg = DT.find(Ty, &MIRBuilder.getMF()); diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index db01f68..cc4e20b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -292,6 +292,8 @@ public: return Res->second; } + // Return a pointee's type, or nullptr otherwise. + SPIRVType *getPointeeType(SPIRVType *PtrType); // Return a pointee's type op code, or 0 otherwise. unsigned getPointeeTypeOp(Register PtrReg); @@ -514,7 +516,11 @@ public: SPIRVType *getOrCreateOpTypeSampledImage(SPIRVType *ImageType, MachineIRBuilder &MIRBuilder); - + SPIRVType *getOrCreateOpTypeCoopMatr(MachineIRBuilder &MIRBuilder, + const TargetExtType *ExtensionType, + const SPIRVType *ElemType, + uint32_t Scope, uint32_t Rows, + uint32_t Columns, uint32_t Use); SPIRVType * getOrCreateOpTypePipe(MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AccQual); diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index dedfd5e..63549b0 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -211,6 +211,9 @@ def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins), def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res), (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols), "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">; +def OpTypeCooperativeMatrixKHR: Op<4456, (outs TYPE:$res), + (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use), + "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols $use">; // 3.42.7 Constant-Creation Instructions @@ -864,3 +867,16 @@ def OpAsmINTEL: Op<5610, (outs ID:$res), (ins TYPE:$type, TYPE:$asm_type, ID:$ta "$res = OpAsmINTEL $type $asm_type $target $asm">; def OpAsmCallINTEL: Op<5611, (outs ID:$res), (ins TYPE:$type, ID:$asm, variable_ops), "$res = OpAsmCallINTEL $type $asm">; + +// SPV_KHR_cooperative_matrix +def OpCooperativeMatrixLoadKHR: Op<4457, (outs ID:$res), + (ins TYPE:$resType, ID:$pointer, ID:$memory_layout, variable_ops), + "$res = OpCooperativeMatrixLoadKHR $resType $pointer $memory_layout">; +def OpCooperativeMatrixStoreKHR: Op<4458, (outs), + (ins ID:$pointer, ID:$objectToStore, ID:$memory_layout, variable_ops), + "OpCooperativeMatrixStoreKHR $pointer $objectToStore $memory_layout">; +def OpCooperativeMatrixMulAddKHR: Op<4459, (outs ID:$res), + (ins TYPE:$type, ID:$A, ID:$B, ID:$C, variable_ops), + "$res = OpCooperativeMatrixMulAddKHR $type $A $B $C">; +def OpCooperativeMatrixLengthKHR: Op<4460, (outs ID:$res), (ins TYPE:$type, ID:$coop_matr_type), + "$res = OpCooperativeMatrixLengthKHR $type $coop_matr_type">; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index d7b96b2..a7c6c71 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -1117,7 +1117,7 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg, if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) { Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass); SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType( - SrcPtrTy, I, TII, SPIRV::StorageClass::Generic); + GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic); MachineBasicBlock &BB = *I.getParent(); const DebugLoc &DL = I.getDebugLoc(); bool Success = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric)) diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 30a6c47..ac0aa68 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1168,6 +1168,15 @@ void addInstrRequirements(const MachineInstr &MI, Reqs.addCapability(SPIRV::Capability::AsmINTEL); } break; + case SPIRV::OpTypeCooperativeMatrixKHR: + if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix)) + report_fatal_error( + "OpTypeCooperativeMatrixKHR type requires the " + "following SPIR-V extension: SPV_KHR_cooperative_matrix", + false); + Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix); + Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR); + break; default: break; } diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index 318c5ce..96601dd 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -302,6 +302,7 @@ defm SPV_INTEL_inline_assembly : ExtensionOperand<107>; defm SPV_INTEL_cache_controls : ExtensionOperand<108>; defm SPV_INTEL_global_variable_host_access : ExtensionOperand<109>; defm SPV_INTEL_global_variable_fpga_decorations : ExtensionOperand<110>; +defm SPV_KHR_cooperative_matrix : ExtensionOperand<111>; //===----------------------------------------------------------------------===// // Multiclass used to define Capabilities enum values and at the same time @@ -478,6 +479,7 @@ defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_gl defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>; defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV_INTEL_global_variable_fpga_decorations], []>; defm CacheControlsINTEL : CapabilityOperand<6441, 0, 0, [SPV_INTEL_cache_controls], []>; +defm CooperativeMatrixKHR : CapabilityOperand<6022, 0, 0, [SPV_KHR_cooperative_matrix], []>; //===----------------------------------------------------------------------===// // Multiclass used to define SourceLanguage enum values and at the same time diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll new file mode 100644 index 0000000..1c41c73 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll @@ -0,0 +1,54 @@ +; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %} + +; CHECK-ERROR: LLVM ERROR: OpTypeCooperativeMatrixKHR type requires the following SPIR-V extension: SPV_KHR_cooperative_matrix + +; CHECK: OpCapability CooperativeMatrixKHR +; CHECK: OpExtension "SPV_KHR_cooperative_matrix" + +; CHECK-DAG: %[[#Int32Ty:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#Const12:]] = OpConstant %[[#Int32Ty]] 12 +; CHECK-DAG: %[[#Const48:]] = OpConstant %[[#Int32Ty]] 48 +; CHECK-DAG: %[[#Const3:]] = OpConstant %[[#Int32Ty]] 3 +; CHECK-DAG: %[[#Const2:]] = OpConstant %[[#Int32Ty]] 2 +; CHECK-DAG: %[[#Const1:]] = OpConstant %[[#Int32Ty]] 1 +; CHECK-DAG: %[[#Const0:]] = OpConstant %[[#Int32Ty]] 0 +; CHECK-DAG: %[[#MatTy1:]] = OpTypeCooperativeMatrixKHR %[[#Int32Ty]] %[[#Const3]] %[[#Const12]] %[[#Const12]] %[[#Const2]] +; CHECK-DAG: %[[#MatTy2:]] = OpTypeCooperativeMatrixKHR %[[#Int32Ty]] %[[#Const3]] %[[#Const12]] %[[#Const48]] %[[#Const0]] +; CHECK-DAG: %[[#MatTy3:]] = OpTypeCooperativeMatrixKHR %[[#Int32Ty]] %[[#Const3]] %[[#Const48]] %[[#Const12]] %[[#Const1]] +; CHECK: OpCompositeConstruct %[[#MatTy1]] +; CHECK: %[[#Load1:]] = OpCooperativeMatrixLoadKHR %[[#MatTy2]] +; CHECK: OpCooperativeMatrixLengthKHR %[[#Int32Ty]] %[[#MatTy2:]] +; CHECK: OpCooperativeMatrixLoadKHR %[[#MatTy3]] +; CHECK: OpCooperativeMatrixMulAddKHR %[[#MatTy1]] +; CHECK: OpCooperativeMatrixStoreKHR + +define spir_kernel void @matr_mult(ptr addrspace(1) align 1 %_arg_accA, ptr addrspace(1) align 1 %_arg_accB, ptr addrspace(1) align 4 %_arg_accC, i64 %_arg_N, i64 %_arg_K) { +entry: + %addr1 = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), align 8 + %res = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), align 8 + %m1 = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32 0) + store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %m1, ptr %addr1, align 8 + %accA3 = addrspacecast ptr addrspace(1) %_arg_accA to ptr addrspace(3) + %m2 = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(3) %accA3, i32 0, i64 %_arg_K, i32 1) + %len = tail call spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0) %m2) + %accB3 = addrspacecast ptr addrspace(1) %_arg_accB to ptr addrspace(3) + %m3 = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(3) %accB3, i32 0, i64 0) + %m4 = load target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), ptr %addr1, align 8 + %m5 = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0) %m2, target("spirv.CooperativeMatrixKHR", i32, 3, 48, 12, 1) %m3, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %m4, i32 12) + store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %m5, ptr %res, align 8 + %r = load i64, ptr %res, align 8 + store i64 %r, ptr %addr1, align 8 + %accC3 = addrspacecast ptr addrspace(1) %_arg_accC to ptr addrspace(3) + %m6 = load target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), ptr %addr1, align 8 + tail call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(3) %accC3, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %m6, i32 0, i64 %_arg_N, i32 1) + ret void +} + +declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32) +declare dso_local spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0)) +declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(3), i32, i64, i32) +declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(3), i32, i64) +declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0), target("spirv.CooperativeMatrixKHR", i32, 3, 48, 12, 1), target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), i32) +declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(3), target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), i32, i64, i32) diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpPtrCastToGeneric.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpPtrCastToGeneric.ll new file mode 100644 index 0000000..818243a --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/transcoding/OpPtrCastToGeneric.ll @@ -0,0 +1,30 @@ +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK-SPIRV +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-SPIRV-DAG: %[[#Char:]] = OpTypeInt 8 0 +; CHECK-SPIRV-DAG: %[[#GlobalCharPtr:]] = OpTypePointer CrossWorkgroup %[[#Char]] +; CHECK-SPIRV-DAG: %[[#LocalCharPtr:]] = OpTypePointer Workgroup %[[#Char]] +; CHECK-SPIRV-DAG: %[[#GenericCharPtr:]] = OpTypePointer Generic %[[#Char]] +; CHECK-SPIRV: OpFunction +; CHECK-SPIRV: %[[#Arg1:]] = OpFunctionParameter %[[#GlobalCharPtr]] +; CHECK-SPIRV: %[[#Ptr1:]] = OpPtrCastToGeneric %[[#GenericCharPtr]] %[[#Arg1]] +; CHECK-SPIRV: OpGenericCastToPtr %[[#LocalCharPtr]] %[[#Ptr1]] +; CHECK-SPIRV: OpFunctionEnd +; CHECK-SPIRV: OpFunction +; CHECK-SPIRV: %[[#Arg2:]] = OpFunctionParameter %[[#GlobalCharPtr]] +; CHECK-SPIRV: %[[#Ptr2:]] = OpPtrCastToGeneric %[[#GenericCharPtr]] %[[#Arg2]] +; CHECK-SPIRV: OpGenericCastToPtr %[[#LocalCharPtr]] %[[#Ptr2]] +; CHECK-SPIRV: OpFunctionEnd + +define spir_kernel void @foo(ptr addrspace(1) %arg) { +entry: + %p = addrspacecast ptr addrspace(1) %arg to ptr addrspace(3) + ret void +} + +define spir_kernel void @bar(ptr addrspace(1) %arg) { +entry: + %p1 = addrspacecast ptr addrspace(1) %arg to ptr addrspace(4) + %p2 = addrspacecast ptr addrspace(4) %p1 to ptr addrspace(3) + ret void +} |