diff options
author | Akshay Deodhar <adeodhar@nvidia.com> | 2025-03-21 10:56:38 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-03-21 10:56:38 -0700 |
commit | cb2ee1e64db663ec8b39554a3cf93cc924d89818 (patch) | |
tree | 480a68ca9be81f7b07a8fefc70e80daac56a54b3 /llvm/lib | |
parent | 3b0ec611565ca603389db6d71e1c917f22439456 (diff) | |
download | llvm-cb2ee1e64db663ec8b39554a3cf93cc924d89818.zip llvm-cb2ee1e64db663ec8b39554a3cf93cc924d89818.tar.gz llvm-cb2ee1e64db663ec8b39554a3cf93cc924d89818.tar.bz2 |
[NVPTX][NVPTXLowerArgs] Add NewPM interface for NVPTXLowerArgs (#128960)
Add a NewPM interface for NVPTXLowerArgs
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTX.h | 10 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp | 55 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXPassRegistry.def | 1 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp | 4 |
4 files changed, 39 insertions, 31 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h index 62f5186..20a5bf4 100644 --- a/llvm/lib/Target/NVPTX/NVPTX.h +++ b/llvm/lib/Target/NVPTX/NVPTX.h @@ -18,6 +18,7 @@ #include "llvm/Pass.h" #include "llvm/Support/AtomicOrdering.h" #include "llvm/Support/CodeGen.h" +#include "llvm/Target/TargetMachine.h" namespace llvm { class FunctionPass; @@ -75,6 +76,15 @@ struct NVPTXCopyByValArgsPass : PassInfoMixin<NVPTXCopyByValArgsPass> { PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); }; +struct NVPTXLowerArgsPass : PassInfoMixin<NVPTXLowerArgsPass> { +private: + TargetMachine &TM; + +public: + NVPTXLowerArgsPass(TargetMachine &TM) : TM(TM) {}; + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); +}; + namespace NVPTX { enum DrvInterface { NVCL, diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp index 6dc9277..2637b9f 100644 --- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp @@ -162,27 +162,16 @@ using namespace llvm; namespace llvm { -void initializeNVPTXLowerArgsPass(PassRegistry &); +void initializeNVPTXLowerArgsLegacyPassPass(PassRegistry &); } namespace { -class NVPTXLowerArgs : public FunctionPass { +class NVPTXLowerArgsLegacyPass : public FunctionPass { bool runOnFunction(Function &F) override; - bool runOnKernelFunction(const NVPTXTargetMachine &TM, Function &F); - bool runOnDeviceFunction(const NVPTXTargetMachine &TM, Function &F); - - // handle byval parameters - void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg); - // Knowing Ptr must point to the global address space, this function - // addrspacecasts Ptr to global and then back to generic. This allows - // NVPTXInferAddressSpaces to fold the global-to-generic cast into - // loads/stores that appear later. - void markPointerAsGlobal(Value *Ptr); - public: static char ID; // Pass identification, replacement for typeid - NVPTXLowerArgs() : FunctionPass(ID) {} + NVPTXLowerArgsLegacyPass() : FunctionPass(ID) {} StringRef getPassName() const override { return "Lower pointer arguments of CUDA kernels"; } @@ -192,12 +181,12 @@ public: }; } // namespace -char NVPTXLowerArgs::ID = 1; +char NVPTXLowerArgsLegacyPass::ID = 1; -INITIALIZE_PASS_BEGIN(NVPTXLowerArgs, "nvptx-lower-args", +INITIALIZE_PASS_BEGIN(NVPTXLowerArgsLegacyPass, "nvptx-lower-args", "Lower arguments (NVPTX)", false, false) INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) -INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args", +INITIALIZE_PASS_END(NVPTXLowerArgsLegacyPass, "nvptx-lower-args", "Lower arguments (NVPTX)", false, false) // ============================================================================= @@ -552,8 +541,7 @@ void copyByValParam(Function &F, Argument &Arg) { } } // namespace -void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM, - Argument *Arg) { +static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) { Function *Func = Arg->getParent(); bool HasCvtaParam = TM.getSubtargetImpl(*Func)->hasCvtaParam() && isKernelFunction(*Func); @@ -647,20 +635,19 @@ static void markPointerAsAS(Value *Ptr, const unsigned AS) { PtrInGlobal->setOperand(0, Ptr); } -void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) { +static void markPointerAsGlobal(Value *Ptr) { markPointerAsAS(Ptr, ADDRESS_SPACE_GLOBAL); } // ============================================================================= // Main function for this pass. // ============================================================================= -bool NVPTXLowerArgs::runOnKernelFunction(const NVPTXTargetMachine &TM, - Function &F) { +static bool runOnKernelFunction(const NVPTXTargetMachine &TM, Function &F) { // Copying of byval aggregates + SROA may result in pointers being loaded as // integers, followed by intotoptr. We may want to mark those as global, too, // but only if the loaded integer is used exclusively for conversion to a // pointer with inttoptr. - auto HandleIntToPtr = [this](Value &V) { + auto HandleIntToPtr = [](Value &V) { if (llvm::all_of(V.users(), [](User *U) { return isa<IntToPtrInst>(U); })) { SmallVector<User *, 16> UsersToUpdate(V.users()); for (User *U : UsersToUpdate) @@ -705,8 +692,7 @@ bool NVPTXLowerArgs::runOnKernelFunction(const NVPTXTargetMachine &TM, } // Device functions only need to copy byval args into local memory. -bool NVPTXLowerArgs::runOnDeviceFunction(const NVPTXTargetMachine &TM, - Function &F) { +static bool runOnDeviceFunction(const NVPTXTargetMachine &TM, Function &F) { LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n"); const auto *TLI = @@ -720,14 +706,18 @@ bool NVPTXLowerArgs::runOnDeviceFunction(const NVPTXTargetMachine &TM, return true; } -bool NVPTXLowerArgs::runOnFunction(Function &F) { - auto &TM = getAnalysis<TargetPassConfig>().getTM<NVPTXTargetMachine>(); - +static bool processFunction(Function &F, NVPTXTargetMachine &TM) { return isKernelFunction(F) ? runOnKernelFunction(TM, F) : runOnDeviceFunction(TM, F); } -FunctionPass *llvm::createNVPTXLowerArgsPass() { return new NVPTXLowerArgs(); } +bool NVPTXLowerArgsLegacyPass::runOnFunction(Function &F) { + auto &TM = getAnalysis<TargetPassConfig>().getTM<NVPTXTargetMachine>(); + return processFunction(F, TM); +} +FunctionPass *llvm::createNVPTXLowerArgsPass() { + return new NVPTXLowerArgsLegacyPass(); +} static bool copyFunctionByValArgs(Function &F) { LLVM_DEBUG(dbgs() << "Creating a copy of byval args of " << F.getName() @@ -747,3 +737,10 @@ PreservedAnalyses NVPTXCopyByValArgsPass::run(Function &F, return copyFunctionByValArgs(F) ? PreservedAnalyses::none() : PreservedAnalyses::all(); } + +PreservedAnalyses NVPTXLowerArgsPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &NTM = static_cast<NVPTXTargetMachine &>(TM); + bool Changed = processFunction(F, NTM); + return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def index 28ea9dd..34c79b8f 100644 --- a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def +++ b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def @@ -38,4 +38,5 @@ FUNCTION_ALIAS_ANALYSIS("nvptx-aa", NVPTXAA()) FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass()) FUNCTION_PASS("nvvm-reflect", NVVMReflectPass()) FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass()) +FUNCTION_PASS("nvptx-lower-args", NVPTXLowerArgsPass(*this)); #undef FUNCTION_PASS diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp index b363472..5bb1687 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp @@ -98,7 +98,7 @@ void initializeNVPTXLowerAggrCopiesPass(PassRegistry &); void initializeNVPTXLowerAllocaPass(PassRegistry &); void initializeNVPTXLowerUnreachablePass(PassRegistry &); void initializeNVPTXCtorDtorLoweringLegacyPass(PassRegistry &); -void initializeNVPTXLowerArgsPass(PassRegistry &); +void initializeNVPTXLowerArgsLegacyPassPass(PassRegistry &); void initializeNVPTXProxyRegErasurePass(PassRegistry &); void initializeNVPTXForwardParamsPassPass(PassRegistry &); void initializeNVVMIntrRangePass(PassRegistry &); @@ -122,7 +122,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() { initializeNVPTXAllocaHoistingPass(PR); initializeNVPTXAssignValidGlobalNamesPass(PR); initializeNVPTXAtomicLowerPass(PR); - initializeNVPTXLowerArgsPass(PR); + initializeNVPTXLowerArgsLegacyPassPass(PR); initializeNVPTXLowerAllocaPass(PR); initializeNVPTXLowerUnreachablePass(PR); initializeNVPTXCtorDtorLoweringLegacyPass(PR); |