diff options
Diffstat (limited to 'llvm/lib/Target')
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 51 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVBuiltins.td | 24 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp | 5 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 11 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 33 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 14 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 68 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp | 24 |
8 files changed, 219 insertions, 11 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index 86f4459..f704d3a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -1096,6 +1096,41 @@ static bool build2DBlockIOINTELInst(const SPIRV::IncomingCall *Call, return true; } +static bool buildPipeInst(const SPIRV::IncomingCall *Call, unsigned Opcode, + unsigned Scope, MachineIRBuilder &MIRBuilder, + SPIRVGlobalRegistry *GR) { + switch (Opcode) { + case SPIRV::OpCommitReadPipe: + case SPIRV::OpCommitWritePipe: + return buildOpFromWrapper(MIRBuilder, Opcode, Call, Register(0)); + case SPIRV::OpGroupCommitReadPipe: + case SPIRV::OpGroupCommitWritePipe: + case SPIRV::OpGroupReserveReadPipePackets: + case SPIRV::OpGroupReserveWritePipePackets: { + Register ScopeConstReg = + MIRBuilder.buildConstant(LLT::scalar(32), Scope).getReg(0); + MachineRegisterInfo *MRI = MIRBuilder.getMRI(); + MRI->setRegClass(ScopeConstReg, &SPIRV::iIDRegClass); + MachineInstrBuilder MIB; + MIB = MIRBuilder.buildInstr(Opcode); + // Add Return register and type. + if (Opcode == SPIRV::OpGroupReserveReadPipePackets || + Opcode == SPIRV::OpGroupReserveWritePipePackets) + MIB.addDef(Call->ReturnRegister) + .addUse(GR->getSPIRVTypeID(Call->ReturnType)); + + MIB.addUse(ScopeConstReg); + for (unsigned int i = 0; i < Call->Arguments.size(); ++i) + MIB.addUse(Call->Arguments[i]); + + return true; + } + default: + return buildOpFromWrapper(MIRBuilder, Opcode, Call, + GR->getSPIRVTypeID(Call->ReturnType)); + } +} + static unsigned getNumComponentsForDim(SPIRV::Dim::Dim dim) { switch (dim) { case SPIRV::Dim::DIM_1D: @@ -2350,6 +2385,20 @@ static bool generate2DBlockIOINTELInst(const SPIRV::IncomingCall *Call, return build2DBlockIOINTELInst(Call, Opcode, MIRBuilder, GR); } +static bool generatePipeInst(const SPIRV::IncomingCall *Call, + MachineIRBuilder &MIRBuilder, + SPIRVGlobalRegistry *GR) { + const SPIRV::DemangledBuiltin *Builtin = Call->Builtin; + unsigned Opcode = + SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode; + + unsigned Scope = SPIRV::Scope::Workgroup; + if (Builtin->Name.contains("sub_group")) + Scope = SPIRV::Scope::Subgroup; + + return buildPipeInst(Call, Opcode, Scope, MIRBuilder, GR); +} + static bool buildNDRange(const SPIRV::IncomingCall *Call, MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) { @@ -2948,6 +2997,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall, return generateTernaryBitwiseFunctionINTELInst(Call.get(), MIRBuilder, GR); case SPIRV::Block2DLoadStore: return generate2DBlockIOINTELInst(Call.get(), MIRBuilder, GR); + case SPIRV::Pipe: + return generatePipeInst(Call.get(), MIRBuilder, GR); } return false; } diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td index d08560b..2a8deb6 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td @@ -69,6 +69,7 @@ def ExtendedBitOps : BuiltinGroup; def BindlessINTEL : BuiltinGroup; def TernaryBitwiseINTEL : BuiltinGroup; def Block2DLoadStore : BuiltinGroup; +def Pipe : BuiltinGroup; //===----------------------------------------------------------------------===// // Class defining a demangled builtin record. The information in the record @@ -633,6 +634,29 @@ defm : DemangledNativeBuiltin<"__spirv_AtomicSMax", OpenCL_std, Atomic, 4, 4, Op defm : DemangledNativeBuiltin<"__spirv_AtomicUMin", OpenCL_std, Atomic, 4, 4, OpAtomicUMin>; defm : DemangledNativeBuiltin<"__spirv_AtomicUMax", OpenCL_std, Atomic, 4, 4, OpAtomicUMax>; +// Pipe Instruction +defm : DemangledNativeBuiltin<"__read_pipe_2", OpenCL_std, Pipe,2, 2, OpReadPipe>; +defm : DemangledNativeBuiltin<"__write_pipe_2", OpenCL_std, Pipe, 2, 2, OpWritePipe>; +defm : DemangledNativeBuiltin<"__read_pipe_4", OpenCL_std, Pipe,4, 4, OpReservedReadPipe>; +defm : DemangledNativeBuiltin<"__write_pipe_4", OpenCL_std, Pipe, 4, 4, OpReservedWritePipe>; +defm : DemangledNativeBuiltin<"__reserve_read_pipe", OpenCL_std, Pipe, 2, 2, OpReserveReadPipePackets>; +defm : DemangledNativeBuiltin<"__reserve_write_pipe", OpenCL_std, Pipe, 2, 2, OpReserveWritePipePackets>; +defm : DemangledNativeBuiltin<"__commit_read_pipe", OpenCL_std, Pipe, 2, 2, OpCommitReadPipe>; +defm : DemangledNativeBuiltin<"__commit_write_pipe", OpenCL_std, Pipe, 2, 2, OpCommitWritePipe>; +defm : DemangledNativeBuiltin<"is_valid_reserve_id", OpenCL_std, Pipe, 1, 1, OpIsValidReserveId>; +defm : DemangledNativeBuiltin<"__get_pipe_num_packets_ro", OpenCL_std, Pipe, 1, 1, OpGetNumPipePackets>; +defm : DemangledNativeBuiltin<"__get_pipe_max_packets_ro", OpenCL_std, Pipe, 1, 1, OpGetMaxPipePackets>; +defm : DemangledNativeBuiltin<"__get_pipe_num_packets_wo", OpenCL_std, Pipe, 1, 1, OpGetNumPipePackets>; +defm : DemangledNativeBuiltin<"__get_pipe_max_packets_wo", OpenCL_std, Pipe, 1, 1, OpGetMaxPipePackets>; +defm : DemangledNativeBuiltin<"__work_group_reserve_read_pipe", OpenCL_std, Pipe, 2, 2, OpGroupReserveReadPipePackets>; +defm : DemangledNativeBuiltin<"__work_group_reserve_write_pipe", OpenCL_std, Pipe, 2, 2, OpGroupReserveWritePipePackets>; +defm : DemangledNativeBuiltin<"__work_group_commit_read_pipe", OpenCL_std, Pipe, 2, 2, OpGroupCommitReadPipe>; +defm : DemangledNativeBuiltin<"__work_group_commit_write_pipe", OpenCL_std, Pipe, 2, 2, OpGroupCommitWritePipe>; +defm : DemangledNativeBuiltin<"__sub_group_reserve_read_pipe", OpenCL_std, Pipe, 2, 2, OpGroupReserveReadPipePackets>; +defm : DemangledNativeBuiltin<"__sub_group_reserve_write_pipe", OpenCL_std, Pipe, 2, 2, OpGroupReserveWritePipePackets>; +defm : DemangledNativeBuiltin<"__sub_group_commit_read_pipe", OpenCL_std, Pipe, 2, 2, OpGroupCommitReadPipe>; +defm : DemangledNativeBuiltin<"__sub_group_commit_write_pipe", OpenCL_std, Pipe, 2, 2, OpGroupCommitWritePipe>; + // Barrier builtin records: defm : DemangledNativeBuiltin<"barrier", OpenCL_std, Barrier, 1, 3, OpControlBarrier>; defm : DemangledNativeBuiltin<"work_group_barrier", OpenCL_std, Barrier, 1, 3, OpControlBarrier>; diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index 993de9e..85ea9e1 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -148,7 +148,10 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>> SPIRV::Extension::Extension::SPV_KHR_float_controls2}, {"SPV_INTEL_tensor_float32_conversion", SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}, - {"SPV_KHR_bfloat16", SPIRV::Extension::Extension::SPV_KHR_bfloat16}}; + {"SPV_KHR_bfloat16", SPIRV::Extension::Extension::SPV_KHR_bfloat16}, + {"SPV_EXT_relaxed_printf_string_address_space", + SPIRV::Extension::Extension:: + SPV_EXT_relaxed_printf_string_address_space}}; bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName, StringRef ArgValue, diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index f5a49e2..704edd3 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -1909,11 +1909,12 @@ Instruction *SPIRVEmitIntrinsics::visitInsertValueInst(InsertValueInst &I) { B.SetInsertPoint(&I); SmallVector<Type *, 1> Types = {I.getInsertedValueOperand()->getType()}; SmallVector<Value *> Args; - for (auto &Op : I.operands()) - if (isa<UndefValue>(Op)) - Args.push_back(UndefValue::get(B.getInt32Ty())); - else - Args.push_back(Op); + Value *AggregateOp = I.getAggregateOperand(); + if (isa<UndefValue>(AggregateOp)) + Args.push_back(UndefValue::get(B.getInt32Ty())); + else + Args.push_back(AggregateOp); + Args.push_back(I.getInsertedValueOperand()); for (auto &Op : I.indices()) Args.push_back(B.getInt32(Op)); Instruction *NewI = diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index 496dcba..1723bfb 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -763,7 +763,38 @@ def OpGetDefaultQueue: Op<303, (outs ID:$res), (ins TYPE:$type), def OpBuildNDRange: Op<304, (outs ID:$res), (ins TYPE:$type, ID:$GWS, ID:$LWS, ID:$GWO), "$res = OpBuildNDRange $type $GWS $LWS $GWO">; -// TODO: 3.42.23. Pipe Instructions +// 3.42.23. Pipe Instructions + +def OpReadPipe: Op<274, (outs ID:$res), (ins TYPE:$type, ID:$Pipe, ID:$Pointer, ID:$PcktSize, ID:$PcktAlign), + "$res = OpReadPipe $type $Pipe $Pointer $PcktSize $PcktAlign">; +def OpWritePipe: Op<275, (outs ID:$res), (ins TYPE:$type, ID:$Pipe, ID:$Pointer, ID:$PcktSize, ID:$PcktAlign), + "$res = OpWritePipe $type $Pipe $Pointer $PcktSize $PcktAlign">; +def OpReservedReadPipe : Op<276, (outs ID:$res), (ins TYPE:$type, ID:$Pipe, ID:$ReserveId, ID:$Index, ID:$Pointer, ID:$PcktSize, ID:$PcktAlign), + "$res = OpReservedReadPipe $type $Pipe $ReserveId $Index $Pointer $PcktSize $PcktAlign">; +def OpReservedWritePipe : Op<277, (outs ID:$res), (ins TYPE:$type, ID:$Pipe, ID:$ReserveId, ID:$Index, ID:$Pointer, ID:$PcktSize, ID:$PcktAlign), + "$res = OpReservedWritePipe $type $Pipe $ReserveId $Index $Pointer $PcktSize $PcktAlign">; +def OpReserveReadPipePackets : Op<278, (outs ID:$res), (ins TYPE:$type, ID:$Pipe, ID:$NumPckts, ID:$PcktSize, ID:$PcktAlign), + "$res = OpReserveReadPipePackets $type $Pipe $NumPckts $PcktSize $PcktAlign">; +def OpReserveWritePipePackets : Op<279, (outs ID:$res), (ins TYPE:$type, ID:$Pipe, ID:$NumPckts, ID:$PcktSize, ID:$PcktAlign), + "$res = OpReserveWritePipePackets $type $Pipe $NumPckts $PcktSize $PcktAlign">; +def OpCommitReadPipe : Op<280, (outs), (ins ID:$Pipe, ID:$ReserveId, ID:$PcktSize, ID:$PcktAlign), + "OpCommitReadPipe $Pipe $ReserveId $PcktSize $PcktAlign">; +def OpCommitWritePipe : Op<281, (outs), (ins ID:$Pipe, ID:$ReserveId, ID:$PcktSize, ID:$PcktAlign), + "OpCommitWritePipe $Pipe $ReserveId $PcktSize $PcktAlign">; +def OpIsValidReserveId : Op<282, (outs ID:$res), (ins TYPE:$type, ID:$ReserveId), + "$res = OpIsValidReserveId $type $ReserveId">; +def OpGetNumPipePackets : Op<283, (outs ID:$res), (ins TYPE:$type, ID:$Pipe, ID:$PacketSize, ID:$PacketAlign), + "$res = OpGetNumPipePackets $type $Pipe $PacketSize $PacketAlign">; +def OpGetMaxPipePackets : Op<284, (outs ID:$res), (ins TYPE:$type, ID:$Pipe, ID:$PacketSize, ID:$PacketAlign), + "$res = OpGetMaxPipePackets $type $Pipe $PacketSize $PacketAlign">; +def OpGroupReserveReadPipePackets : Op<285, (outs ID:$res), (ins TYPE:$type, ID:$Scope, ID:$Pipe, ID:$NumPckts, ID:$PacketSize, ID:$PacketAlign), + "$res = OpGroupReserveReadPipePackets $type $Scope $Pipe $NumPckts $PacketSize $PacketAlign">; +def OpGroupReserveWritePipePackets : Op<286, (outs ID:$res), (ins TYPE:$type, ID:$Scope, ID:$Pipe, ID:$NumPckts, ID:$PacketSize, ID:$PacketAlign), + "$res = OpGroupReserveWritePipePackets $type $Scope $Pipe $NumPckts $PacketSize $PacketAlign">; +def OpGroupCommitReadPipe : Op<287, (outs), (ins ID:$Scope, ID:$Pipe, ID:$ReserveId, ID:$PacketSize, ID:$PacketAlign), + "OpGroupCommitReadPipe $Scope $Pipe $ReserveId $PacketSize $PacketAlign">; +def OpGroupCommitWritePipe : Op<288, (outs), (ins ID:$Scope, ID:$Pipe, ID:$ReserveId, ID:$PacketSize, ID:$PacketAlign), + "OpGroupCommitWritePipe $Scope $Pipe $ReserveId $PacketSize $PacketAlign">; // 3.42.24. Non-Uniform Instructions diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index a7b2179..5266e20 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -197,6 +197,8 @@ private: bool selectOverflowArith(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, unsigned Opcode) const; + bool selectDebugTrap(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, bool Signed) const; @@ -999,16 +1001,26 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg, // represent code after lowering or intrinsics which are not implemented but // should not crash when found in a customer's LLVM IR input. case TargetOpcode::G_TRAP: - case TargetOpcode::G_DEBUGTRAP: case TargetOpcode::G_UBSANTRAP: case TargetOpcode::DBG_LABEL: return true; + case TargetOpcode::G_DEBUGTRAP: + return selectDebugTrap(ResVReg, ResType, I); default: return false; } } +bool SPIRVInstructionSelector::selectDebugTrap(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + unsigned Opcode = SPIRV::OpNop; + MachineBasicBlock &BB = *I.getParent(); + return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)) + .constrainAllUses(TII, TRI, RBI); +} + bool SPIRVInstructionSelector::selectExtInst(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index a95f393..bc159d5 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1222,6 +1222,31 @@ static void AddDotProductRequirements(const MachineInstr &MI, } } +void addPrintfRequirements(const MachineInstr &MI, + SPIRV::RequirementHandler &Reqs, + const SPIRVSubtarget &ST) { + SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry(); + const SPIRVType *PtrType = GR->getSPIRVTypeForVReg(MI.getOperand(4).getReg()); + if (PtrType) { + MachineOperand ASOp = PtrType->getOperand(1); + if (ASOp.isImm()) { + unsigned AddrSpace = ASOp.getImm(); + if (AddrSpace != SPIRV::StorageClass::UniformConstant) { + if (!ST.canUseExtension( + SPIRV::Extension:: + SPV_EXT_relaxed_printf_string_address_space)) { + report_fatal_error("SPV_EXT_relaxed_printf_string_address_space is " + "required because printf uses a format string not " + "in constant address space.", + false); + } + Reqs.addExtension( + SPIRV::Extension::SPV_EXT_relaxed_printf_string_address_space); + } + } + } +} + static bool isBFloat16Type(const SPIRVType *TypeDef) { return TypeDef && TypeDef->getNumOperands() == 3 && TypeDef->getOpcode() == SPIRV::OpTypeFloat && @@ -1230,8 +1255,9 @@ static bool isBFloat16Type(const SPIRVType *TypeDef) { } void addInstrRequirements(const MachineInstr &MI, - SPIRV::RequirementHandler &Reqs, + SPIRV::ModuleAnalysisInfo &MAI, const SPIRVSubtarget &ST) { + SPIRV::RequirementHandler &Reqs = MAI.Reqs; switch (MI.getOpcode()) { case SPIRV::OpMemoryModel: { int64_t Addr = MI.getOperand(0).getImm(); @@ -1321,6 +1347,12 @@ void addInstrRequirements(const MachineInstr &MI, static_cast<int64_t>( SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) { Reqs.addExtension(SPIRV::Extension::SPV_KHR_non_semantic_info); + break; + } + if (MI.getOperand(3).getImm() == + static_cast<int64_t>(SPIRV::OpenCLExtInst::printf)) { + addPrintfRequirements(MI, Reqs, ST); + break; } break; } @@ -1781,15 +1813,45 @@ void addInstrRequirements(const MachineInstr &MI, break; case SPIRV::OpConvertHandleToImageINTEL: case SPIRV::OpConvertHandleToSamplerINTEL: - case SPIRV::OpConvertHandleToSampledImageINTEL: + case SPIRV::OpConvertHandleToSampledImageINTEL: { if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bindless_images)) report_fatal_error("OpConvertHandleTo[Image/Sampler/SampledImage]INTEL " "instructions require the following SPIR-V extension: " "SPV_INTEL_bindless_images", false); + SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry(); + SPIRV::AddressingModel::AddressingModel AddrModel = MAI.Addr; + SPIRVType *TyDef = GR->getSPIRVTypeForVReg(MI.getOperand(1).getReg()); + if (MI.getOpcode() == SPIRV::OpConvertHandleToImageINTEL && + TyDef->getOpcode() != SPIRV::OpTypeImage) { + report_fatal_error("Incorrect return type for the instruction " + "OpConvertHandleToImageINTEL", + false); + } else if (MI.getOpcode() == SPIRV::OpConvertHandleToSamplerINTEL && + TyDef->getOpcode() != SPIRV::OpTypeSampler) { + report_fatal_error("Incorrect return type for the instruction " + "OpConvertHandleToSamplerINTEL", + false); + } else if (MI.getOpcode() == SPIRV::OpConvertHandleToSampledImageINTEL && + TyDef->getOpcode() != SPIRV::OpTypeSampledImage) { + report_fatal_error("Incorrect return type for the instruction " + "OpConvertHandleToSampledImageINTEL", + false); + } + SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(MI.getOperand(2).getReg()); + unsigned Bitwidth = GR->getScalarOrVectorBitWidth(SpvTy); + if (!(Bitwidth == 32 && AddrModel == SPIRV::AddressingModel::Physical32) && + !(Bitwidth == 64 && AddrModel == SPIRV::AddressingModel::Physical64)) { + report_fatal_error( + "Parameter value must be a 32-bit scalar in case of " + "Physical32 addressing model or a 64-bit scalar in case of " + "Physical64 addressing model", + false); + } Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bindless_images); Reqs.addCapability(SPIRV::Capability::BindlessImagesINTEL); break; + } case SPIRV::OpSubgroup2DBlockLoadINTEL: case SPIRV::OpSubgroup2DBlockLoadTransposeINTEL: case SPIRV::OpSubgroup2DBlockLoadTransformINTEL: @@ -1927,7 +1989,7 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, continue; for (const MachineBasicBlock &MBB : *MF) for (const MachineInstr &MI : MBB) - addInstrRequirements(MI, MAI.Reqs, ST); + addInstrRequirements(MI, MAI, ST); } // Collect requirements for OpExecutionMode instructions. auto Node = M.getNamedMetadata("spirv.ExecutionMode"); diff --git a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp index 2b34f61..4e4e6fb 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp @@ -335,6 +335,21 @@ static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) { FSHIntrinsic->setCalledFunction(FSHFunc); } +static void lowerConstrainedFPCmpIntrinsic( + ConstrainedFPCmpIntrinsic *ConstrainedCmpIntrinsic, + SmallVector<Instruction *> &EraseFromParent) { + if (!ConstrainedCmpIntrinsic) + return; + // Extract the floating-point values being compared + Value *LHS = ConstrainedCmpIntrinsic->getArgOperand(0); + Value *RHS = ConstrainedCmpIntrinsic->getArgOperand(1); + FCmpInst::Predicate Pred = ConstrainedCmpIntrinsic->getPredicate(); + IRBuilder<> Builder(ConstrainedCmpIntrinsic); + Value *FCmp = Builder.CreateFCmp(Pred, LHS, RHS); + ConstrainedCmpIntrinsic->replaceAllUsesWith(FCmp); + EraseFromParent.push_back(dyn_cast<Instruction>(ConstrainedCmpIntrinsic)); +} + static void lowerExpectAssume(IntrinsicInst *II) { // If we cannot use the SPV_KHR_expect_assume extension, then we need to // ignore the intrinsic and move on. It should be removed later on by LLVM. @@ -376,6 +391,7 @@ static bool toSpvLifetimeIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID) { bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) { bool Changed = false; const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F); + SmallVector<Instruction *> EraseFromParent; for (BasicBlock &BB : *F) { for (Instruction &I : make_early_inc_range(BB)) { auto Call = dyn_cast<CallInst>(&I); @@ -423,9 +439,17 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) { lowerPtrAnnotation(II); Changed = true; break; + case Intrinsic::experimental_constrained_fcmp: + case Intrinsic::experimental_constrained_fcmps: + lowerConstrainedFPCmpIntrinsic(dyn_cast<ConstrainedFPCmpIntrinsic>(II), + EraseFromParent); + Changed = true; + break; } } } + for (auto *I : EraseFromParent) + I->eraseFromParent(); return Changed; } |