diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp | 17 | ||||
| -rw-r--r-- | llvm/lib/Target/AMDGPU/GCNSubtarget.h | 2 |
2 files changed, 14 insertions, 5 deletions
diff --git a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp index a7d8ff0..bcd93e3 100644 --- a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp +++ b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp @@ -1450,20 +1450,27 @@ bool GCNHazardRecognizer::fixLdsDirectVMEMHazard(MachineInstr *MI) { return false; return I.readsRegister(VDSTReg, &TRI) || I.modifiesRegister(VDSTReg, &TRI); }; - auto IsExpiredFn = [](const MachineInstr &I, int) { + bool LdsdirCanWait = ST.hasLdsWaitVMSRC(); + auto IsExpiredFn = [this, LdsdirCanWait](const MachineInstr &I, int) { return SIInstrInfo::isVALU(I) || SIInstrInfo::isEXP(I) || (I.getOpcode() == AMDGPU::S_WAITCNT && !I.getOperand(0).getImm()) || (I.getOpcode() == AMDGPU::S_WAITCNT_DEPCTR && - AMDGPU::DepCtr::decodeFieldVmVsrc(I.getOperand(0).getImm()) == 0); + AMDGPU::DepCtr::decodeFieldVmVsrc(I.getOperand(0).getImm()) == 0) || + (LdsdirCanWait && SIInstrInfo::isLDSDIR(I) && + !TII.getNamedOperand(I, AMDGPU::OpName::waitvsrc)->getImm()); }; if (::getWaitStatesSince(IsHazardFn, MI, IsExpiredFn) == std::numeric_limits<int>::max()) return false; - BuildMI(*MI->getParent(), MI, MI->getDebugLoc(), - TII.get(AMDGPU::S_WAITCNT_DEPCTR)) - .addImm(AMDGPU::DepCtr::encodeFieldVmVsrc(0)); + if (LdsdirCanWait) { + TII.getNamedOperand(*MI, AMDGPU::OpName::waitvsrc)->setImm(0); + } else { + BuildMI(*MI->getParent(), MI, MI->getDebugLoc(), + TII.get(AMDGPU::S_WAITCNT_DEPCTR)) + .addImm(AMDGPU::DepCtr::encodeFieldVmVsrc(0)); + } return true; } diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.h b/llvm/lib/Target/AMDGPU/GCNSubtarget.h index f6f37f5..85d062a 100644 --- a/llvm/lib/Target/AMDGPU/GCNSubtarget.h +++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.h @@ -1128,6 +1128,8 @@ public: bool hasLdsDirect() const { return getGeneration() >= GFX11; } + bool hasLdsWaitVMSRC() const { return getGeneration() >= GFX12; } + bool hasVALUPartialForwardingHazard() const { return getGeneration() >= GFX11; } |
