aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Dhruva Chakrabarti <Dhruva.Chakrabarti@amd.com> 2026-03-18 14:21:45 -0500
committer Dhruva Chakrabarti <Dhruva.Chakrabarti@amd.com> 2026-03-18 14:21:45 -0500
commit f441b9d1f7bf4dee0f7958e93926d6fe9a8d6e4c (patch)
tree d29f08bf8e625dcf60f7cf73224f76e1062e842b
parent dd0a60895cecea9f8b2f45a8644aa388f3ccccce (diff)
download llvm-users/dhruvachak/refactor_max_vgpr_api.tar.gz
llvm-users/dhruvachak/refactor_max_vgpr_api.tar.bz2
llvm-users/dhruvachak/refactor_max_vgpr_api.zip
[AMDGPU][NFC] If outside range, clamp target occupancy to nearest endpoint. (users/dhruvachak/refactor_max_vgpr_api)
-rw-r--r-- llvm/lib/Target/AMDGPU/GCNSubtarget.cpp 14
-rw-r--r-- llvm/lib/Target/AMDGPU/GCNSubtarget.h 5
2 files changed, 13 insertions, 6 deletions
diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp
index d8b67cee946a..72d7c487f460 100644
--- a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp
@@ -565,11 +565,15 @@ GCNSubtarget::getMaxNumVGPRs(const Function &F,
if (DynamicVGPRBlockSize == 0 && isDynamicVGPREnabled())
DynamicVGPRBlockSize = getDynamicVGPRBlockSize();
- std::pair<unsigned, unsigned> Waves;
- if (TargetOccupancy)
- Waves = {*TargetOccupancy, *TargetOccupancy};
- else
- Waves = getWavesPerEU(F);
+ std::pair<unsigned, unsigned> Waves = getWavesPerEU(F);
+ if (TargetOccupancy) {
+ if (*TargetOccupancy >= Waves.first && *TargetOccupancy <= Waves.second)
+ Waves = {*TargetOccupancy, *TargetOccupancy};
+ else if (*TargetOccupancy < Waves.first)
+ Waves = {Waves.first, Waves.first};
+ else
+ Waves = {Waves.second, Waves.second};
+ }
return getBaseMaxNumVGPRs(
F, {getMinNumVGPRs(Waves.second, DynamicVGPRBlockSize),
getMaxNumVGPRs(Waves.first, DynamicVGPRBlockSize)});
diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.h b/llvm/lib/Target/AMDGPU/GCNSubtarget.h
index 54798e36700e..5f1f945cb43b 100644
--- a/llvm/lib/Target/AMDGPU/GCNSubtarget.h
+++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.h
@@ -859,7 +859,10 @@ public:
/// subtarget's specifications, or does not meet number of waves per execution
/// unit requirement.
/// When \p TargetOccupancy is present, use it for both min and max waves
- /// instead of getWavesPerEU(F).
+ /// if it lies within the function's wave range from getWavesPerEU(F)
+ /// (inclusive). Otherwise clamp \p TargetOccupancy to the nearest endpoint
+ /// of that range (below the minimum -> minimum waves; above the maximum ->
+ /// maximum waves).
unsigned
getMaxNumVGPRs(const Function &F,
std::optional<unsigned> TargetOccupancy = std::nullopt) const;