diff options
author | Diana Picus <Diana-Magda.Picus@amd.com> | 2025-09-04 10:34:43 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-09-04 10:34:43 +0200 |
commit | 018dc1b3977bb249d55a6808bb45802a10f818fa (patch) | |
tree | 0151a68c8f3628dc2ae093d3cfa1537958bce48a /llvm/lib/Target | |
parent | d1408667de830da8817c24cb9788da6caae551c7 (diff) | |
download | llvm-018dc1b3977bb249d55a6808bb45802a10f818fa.zip llvm-018dc1b3977bb249d55a6808bb45802a10f818fa.tar.gz llvm-018dc1b3977bb249d55a6808bb45802a10f818fa.tar.bz2 |
[AMDGPU] Tail call support for whole wave functions (#145860)
Support tail calls to whole wave functions (trivial) and from whole wave
functions (slightly more involved because we need a new pseudo for the
tail call return, that patches up the EXEC mask).
Move the expansion of whole wave function return pseudos (regular and
tail call returns) to prolog epilog insertion, since that's where we
patch up the EXEC mask.
Diffstat (limited to 'llvm/lib/Target')
-rw-r--r-- | llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp | 35 | ||||
-rw-r--r-- | llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp | 1 | ||||
-rw-r--r-- | llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h | 1 | ||||
-rw-r--r-- | llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td | 4 | ||||
-rw-r--r-- | llvm/lib/Target/AMDGPU/SIFrameLowering.cpp | 20 | ||||
-rw-r--r-- | llvm/lib/Target/AMDGPU/SIISelLowering.cpp | 8 | ||||
-rw-r--r-- | llvm/lib/Target/AMDGPU/SIInstrInfo.cpp | 1 | ||||
-rw-r--r-- | llvm/lib/Target/AMDGPU/SIInstructions.td | 27 | ||||
-rw-r--r-- | llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h | 1 |
9 files changed, 86 insertions, 12 deletions
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp index d1a5b4e..21255f69 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp @@ -1004,8 +1004,14 @@ static unsigned getCallOpcode(const MachineFunction &CallerF, bool IsIndirect, return IsWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32 : AMDGPU::SI_CS_CHAIN_TC_W64; } - return CC == CallingConv::AMDGPU_Gfx ? AMDGPU::SI_TCRETURN_GFX : - AMDGPU::SI_TCRETURN; + if (CallerF.getFunction().getCallingConv() == + CallingConv::AMDGPU_Gfx_WholeWave) + return AMDGPU::SI_TCRETURN_GFX_WholeWave; + + if (CC == CallingConv::AMDGPU_Gfx || CC == CallingConv::AMDGPU_Gfx_WholeWave) + return AMDGPU::SI_TCRETURN_GFX; + + return AMDGPU::SI_TCRETURN; } // Add operands to call instruction to track the callee. @@ -1284,6 +1290,13 @@ bool AMDGPUCallLowering::lowerTailCall( unsigned Opc = getCallOpcode(MF, Info.Callee.isReg(), /*IsTailCall*/ true, ST.isWave32(), CalleeCC, IsDynamicVGPRChainCall); auto MIB = MIRBuilder.buildInstrNoInsert(Opc); + + if (FuncInfo->isWholeWaveFunction()) + addOriginalExecToReturn(MF, MIB); + + // Keep track of the index of the next operand to be added to the call + unsigned CalleeIdx = MIB->getNumOperands(); + if (!addCallTargetOperands(MIB, MIRBuilder, Info, IsDynamicVGPRChainCall)) return false; @@ -1401,7 +1414,7 @@ bool AMDGPUCallLowering::lowerTailCall( // If we have -tailcallopt, we need to adjust the stack. We'll do the call // sequence start and end here. if (!IsSibCall) { - MIB->getOperand(1).setImm(FPDiff); + MIB->getOperand(CalleeIdx + 1).setImm(FPDiff); CallSeqStart.addImm(NumBytes).addImm(0); // End the call sequence *before* emitting the call. Normally, we would // tidy the frame up after the call. However, here, we've laid out the @@ -1413,16 +1426,24 @@ bool AMDGPUCallLowering::lowerTailCall( // Now we can add the actual call instruction to the correct basic block. MIRBuilder.insertInstr(MIB); + // If this is a whole wave tail call, we need to constrain the register for + // the original EXEC. + if (MIB->getOpcode() == AMDGPU::SI_TCRETURN_GFX_WholeWave) { + MIB->getOperand(0).setReg( + constrainOperandRegClass(MF, *TRI, MRI, *TII, *ST.getRegBankInfo(), + *MIB, MIB->getDesc(), MIB->getOperand(0), 0)); + } + // If Callee is a reg, since it is used by a target specific // instruction, it must have a register class matching the // constraint of that instruction. // FIXME: We should define regbankselectable call instructions to handle // divergent call targets. - if (MIB->getOperand(0).isReg()) { - MIB->getOperand(0).setReg( - constrainOperandRegClass(MF, *TRI, MRI, *TII, *ST.getRegBankInfo(), - *MIB, MIB->getDesc(), MIB->getOperand(0), 0)); + if (MIB->getOperand(CalleeIdx).isReg()) { + MIB->getOperand(CalleeIdx).setReg(constrainOperandRegClass( + MF, *TRI, MRI, *TII, *ST.getRegBankInfo(), *MIB, MIB->getDesc(), + MIB->getOperand(CalleeIdx), CalleeIdx)); } MF.getFrameInfo().setHasTailCall(); diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp index c0ab055..5c9b616 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp @@ -5667,6 +5667,7 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(CALL) NODE_NAME_CASE(TC_RETURN) NODE_NAME_CASE(TC_RETURN_GFX) + NODE_NAME_CASE(TC_RETURN_GFX_WholeWave) NODE_NAME_CASE(TC_RETURN_CHAIN) NODE_NAME_CASE(TC_RETURN_CHAIN_DVGPR) NODE_NAME_CASE(TRAP) diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h index 78394ac..bdaf486 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h @@ -418,6 +418,7 @@ enum NodeType : unsigned { CALL, TC_RETURN, TC_RETURN_GFX, + TC_RETURN_GFX_WholeWave, TC_RETURN_CHAIN, TC_RETURN_CHAIN_DVGPR, TRAP, diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td b/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td index e305f08..b8fa6f3 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td @@ -94,6 +94,10 @@ def AMDGPUtc_return_gfx: SDNode<"AMDGPUISD::TC_RETURN_GFX", AMDGPUTCReturnTP, [SDNPHasChain, SDNPOptInGlue, SDNPVariadic] >; +def AMDGPUtc_return_gfx_ww: SDNode<"AMDGPUISD::TC_RETURN_GFX_WholeWave", AMDGPUTCReturnTP, +[SDNPHasChain, SDNPOptInGlue, SDNPVariadic] +>; + def AMDGPUtc_return_chain: SDNode<"AMDGPUISD::TC_RETURN_CHAIN", SDTypeProfile<0, -1, [SDTCisPtrTy<0>]>, [SDNPHasChain, SDNPOptInGlue, SDNPVariadic] diff --git a/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp b/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp index 357b6f0..ce25bf4 100644 --- a/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp @@ -1132,9 +1132,18 @@ void SIFrameLowering::emitCSRSpillRestores( RestoreWWMRegisters(WWMCalleeSavedRegs); // The original EXEC is the first operand of the return instruction. - const MachineInstr &Return = MBB.instr_back(); - assert(Return.getOpcode() == AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN && - "Unexpected return inst"); + MachineInstr &Return = MBB.instr_back(); + unsigned Opcode = Return.getOpcode(); + switch (Opcode) { + case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN: + Opcode = AMDGPU::SI_RETURN; + break; + case AMDGPU::SI_TCRETURN_GFX_WholeWave: + Opcode = AMDGPU::SI_TCRETURN_GFX; + break; + default: + llvm_unreachable("Unexpected return inst"); + } Register OrigExec = Return.getOperand(0).getReg(); if (!WWMScratchRegs.empty()) { @@ -1148,6 +1157,11 @@ void SIFrameLowering::emitCSRSpillRestores( // Restore original EXEC. unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64; BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addReg(OrigExec); + + // Drop the first operand and update the opcode. + Return.removeOperand(0); + Return.setDesc(TII->get(Opcode)); + return; } diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp index 1332ef6..c523789 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -4278,6 +4278,11 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI, break; } + // If the caller is a whole wave function, we need to use a special opcode + // so we can patch up EXEC. + if (Info->isWholeWaveFunction()) + OPC = AMDGPUISD::TC_RETURN_GFX_WholeWave; + return DAG.getNode(OPC, DL, MVT::Other, Ops); } @@ -6041,14 +6046,15 @@ SITargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, MI.eraseFromParent(); return SplitBB; } + case AMDGPU::SI_TCRETURN_GFX_WholeWave: case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN: { assert(MFI->isWholeWaveFunction()); // During ISel, it's difficult to propagate the original EXEC mask to use as // an input to SI_WHOLE_WAVE_FUNC_RETURN. Set it up here instead. MachineInstr *Setup = TII->getWholeWaveFunctionSetup(*BB->getParent()); - Register OriginalExec = Setup->getOperand(0).getReg(); assert(Setup && "Couldn't find SI_SETUP_WHOLE_WAVE_FUNC"); + Register OriginalExec = Setup->getOperand(0).getReg(); MF->getRegInfo().clearKillFlags(OriginalExec); MI.getOperand(0).setReg(OriginalExec); return BB; diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp index 44d8192..85981bc 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp @@ -2493,7 +2493,6 @@ bool SIInstrInfo::expandPostRAPseudo(MachineInstr &MI) const { MI.setDesc(get(ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64)); break; } - case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN: case AMDGPU::SI_RETURN: { const MachineFunction *MF = MBB.getParent(); const GCNSubtarget &ST = MF->getSubtarget<GCNSubtarget>(); diff --git a/llvm/lib/Target/AMDGPU/SIInstructions.td b/llvm/lib/Target/AMDGPU/SIInstructions.td index dd9f200..fddde42 100644 --- a/llvm/lib/Target/AMDGPU/SIInstructions.td +++ b/llvm/lib/Target/AMDGPU/SIInstructions.td @@ -692,6 +692,33 @@ def SI_WHOLE_WAVE_FUNC_RETURN : SPseudoInstSI < def : GCNPat< (AMDGPUwhole_wave_return), (SI_WHOLE_WAVE_FUNC_RETURN (i1 (IMPLICIT_DEF)))>; +// Restores the previous EXEC and otherwise behaves entirely like a SI_TCRETURN. +// This is used for tail calls *from* a whole wave function. Tail calls to +// a whole wave function may use the usual opcodes, depending on the calling +// convention of the caller. +def SI_TCRETURN_GFX_WholeWave : SPseudoInstSI < + (outs), + (ins SReg_1:$orig_exec, Gfx_CCR_SGPR_64:$src0, unknown:$callee, i32imm:$fpdiff)> { + let isCall = 1; + let isTerminator = 1; + let isReturn = 1; + let isBarrier = 1; + let UseNamedOperandTable = 1; + let SchedRW = [WriteBranch]; + let isConvergent = 1; + + // We're going to use custom handling to set the $orig_exec to the correct value. + let usesCustomInserter = 1; +} + +// Generate a SI_TCRETURN_GFX_WholeWave pseudo with a placeholder for its +// argument. It will be filled in by the custom inserter. +def : GCNPat< + (AMDGPUtc_return_gfx_ww i64:$src0, tglobaladdr:$callee, i32:$fpdiff), + (SI_TCRETURN_GFX_WholeWave (i1 (IMPLICIT_DEF)), Gfx_CCR_SGPR_64:$src0, + tglobaladdr:$callee, i32:$fpdiff)>; + + // Return for returning shaders to a shader variant epilog. def SI_RETURN_TO_EPILOG : SPseudoInstSI < (outs), (ins variable_ops), [(AMDGPUreturn_to_epilog)]> { diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h index 23ea3ba..4ab17d8 100644 --- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h +++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h @@ -1517,6 +1517,7 @@ constexpr bool mayTailCallThisCC(CallingConv::ID CC) { switch (CC) { case CallingConv::C: case CallingConv::AMDGPU_Gfx: + case CallingConv::AMDGPU_Gfx_WholeWave: return true; default: return canGuaranteeTCO(CC); |