1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
|
//===- NVVMReflect.cpp - NVVM Emulate conditional compilation -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
// with an integer.
//
// We choose the value we use by looking at metadata in the module itself. Note
// that we intentionally only have one way to choose these values, because other
// parts of LLVM (particularly, InstCombineCall) rely on being able to predict
// the values chosen by this pass.
//
// If we see an unknown string, we replace its call with 0.
//
//===----------------------------------------------------------------------===//
#include "NVPTX.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
// Name of the CUDA-style reflect function whose calls are folded by this pass.
#define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
// Name of the OpenCL-style reflect function; runOnModule processes it exactly
// like NVVM_REFLECT_FUNCTION.
#define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
// Argument of reflect call to retrieve arch number
#define CUDA_ARCH_NAME "__CUDA_ARCH"
// Argument of reflect call to retrieve ftz mode
#define CUDA_FTZ_NAME "__CUDA_FTZ"
// Name of the module flag (read via Module::getModuleFlag) where ftz mode is
// stored
#define CUDA_FTZ_MODULE_NAME "nvvm-reflect-ftz"
using namespace llvm;
// Tag for LLVM_DEBUG output from this file (e.g. -debug-only=nvvm-reflect).
#define DEBUG_TYPE "nvvm-reflect"
namespace {
/// Shared implementation behind the legacy and new pass manager wrappers:
/// replaces calls to the NVVM reflect functions with integer constants looked
/// up in ReflectMap and folds the resulting constants through their users.
class NVVMReflect {
  // Map from reflect function call arguments to the value to replace the call
  // with. Should include __CUDA_FTZ and __CUDA_ARCH values.
  StringMap<unsigned> ReflectMap;
  bool handleReflectFunction(Module &M, StringRef ReflectName);
  void populateReflectMap(Module &M);
  void foldReflectCall(CallInst *Call, Constant *NewValue);

public:
  // __CUDA_ARCH is seeded here with SmVersion * 10. __CUDA_FTZ is assigned
  // later, when `runOnModule` calls populateReflectMap, from the
  // nvvm-reflect-ftz module flag.
  explicit NVVMReflect(unsigned SmVersion)
      : ReflectMap({{CUDA_ARCH_NAME, SmVersion * 10}}) {}
  /// Returns true if the module was modified.
  bool runOnModule(Module &M);
};

/// Legacy pass manager wrapper that forwards to an NVVMReflect instance.
class NVVMReflectLegacyPass : public ModulePass {
  NVVMReflect Impl;

public:
  static char ID;
  // explicit: a bare integer must not implicitly convert into a pass object.
  explicit NVVMReflectLegacyPass(unsigned SmVersion)
      : ModulePass(ID), Impl(SmVersion) {}
  bool runOnModule(Module &M) override;
};
} // namespace
/// Factory for the legacy-PM reflect pass. Returns a freshly heap-allocated
/// NVVMReflectLegacyPass configured with \p SmVersion; this function keeps no
/// reference to it.
ModulePass *llvm::createNVVMReflectPass(unsigned SmVersion) {
  auto *LegacyPass = new NVVMReflectLegacyPass(SmVersion);
  return LegacyPass;
}
// Master switch (-nvvm-reflect-enable, hidden, default on). When false,
// NVVMReflect::runOnModule returns immediately without touching the module.
static cl::opt<bool> NVVMReflectEnabled(
    "nvvm-reflect-enable", cl::desc("NVVM reflection, enabled by default"),
    cl::Hidden, cl::init(true));
// Static pass identifier passed to the ModulePass base constructor
// (`ModulePass(ID)` in NVVMReflectLegacyPass).
char NVVMReflectLegacyPass::ID = 0;
// Register the legacy pass under the command-line name "nvvm-reflect". The two
// trailing `false` arguments are the CFG-only and is-analysis flags.
// NOTE(review): the description says "0/1", but calls are actually replaced
// with arbitrary integers (e.g. __CUDA_ARCH becomes SmVersion * 10).
INITIALIZE_PASS(NVVMReflectLegacyPass, "nvvm-reflect",
                "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
                false)
// Extra key/value pairs to reflect, given as -nvvm-reflect-add=name=<int>
// (repeatable, since this is a cl::list). populateReflectMap applies them
// after everything else, so they take precedence over the initial values
// (__CUDA_FTZ from the nvvm-reflect-ftz module flag and __CUDA_ARCH from
// SmVersion).
static cl::list<std::string> ReflectList(
    "nvvm-reflect-add",
    cl::desc("A key=value pair. Replace __nvvm_reflect(name) with value."),
    cl::value_desc("name=<int>"), cl::ValueRequired, cl::Hidden);
// Set the ReflectMap with, first, the value of __CUDA_FTZ from the
// nvvm-reflect-ftz module flag, and then the key/value pairs from the command
// line (-nvvm-reflect-add), which therefore win on conflicts.
//
// Any __CUDA_FTZ left over from a previous module is dropped first. The legacy
// pass wrapper keeps one NVVMReflect across runs, so without this a module
// lacking the flag would inherit the previous module's FTZ value.
//
// Malformed -nvvm-reflect-add entries (empty name, missing value, non-decimal
// value) are reported with report_fatal_error.
void NVVMReflect::populateReflectMap(Module &M) {
  ReflectMap.erase(CUDA_FTZ_NAME);
  if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
          M.getModuleFlag(CUDA_FTZ_MODULE_NAME)))
    ReflectMap[CUDA_FTZ_NAME] = Flag->getSExtValue();

  for (auto &Option : ReflectList) {
    LLVM_DEBUG(dbgs() << "ReflectOption : " << Option << "\n");
    StringRef OptionRef(Option);
    // Split at the first '='; everything after it is the value.
    auto [Name, Val] = OptionRef.split('=');
    if (Name.empty())
      report_fatal_error(Twine("Empty name in nvvm-reflect-add option '") +
                         Option + "'");
    if (Val.empty())
      report_fatal_error(Twine("Missing value in nvvm-reflect-add option '") +
                         Option + "'");
    unsigned ValInt;
    // Base 10 only; surrounding whitespace in the value is ignored.
    if (!to_integer(Val.trim(), ValInt, 10))
      report_fatal_error(
          Twine("integer value expected in nvvm-reflect-add option '") +
          Option + "'");
    ReflectMap[Name] = ValInt;
  }
}
/// Process a reflect function by finding all its calls and replacing them with
/// the constant stored in ReflectMap for the call's string argument.
/// __CUDA_ARCH comes from SmVersion * 10 and __CUDA_FTZ from the module flag,
/// and -nvvm-reflect-add may override either. Any other string folds to 0.
/// Afterwards the reflect function declaration itself is erased.
///
/// \param M            Module to search.
/// \param ReflectName  Name of the reflect function or intrinsic to process.
/// \returns true if the module was modified, which is whenever the function
///          exists, because it is always erased. Returns false only if no
///          function named \p ReflectName is present.
///
/// Any malformed use of the reflect function is reported with
/// report_fatal_error.
bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
  Function *F = M.getFunction(ReflectName);
  if (!F)
    return false;
  assert(F->isDeclaration() && "_reflect function should not have a body");
  assert(F->getReturnType()->isIntegerTy() &&
         "_reflect's return type should be integer");

  for (User *U : make_early_inc_range(F->users())) {
    // Reflect function calls look like:
    //   @arch = private unnamed_addr addrspace(1) constant [12 x i8]
    //       c"__CUDA_ARCH\00"
    //   call i32 @__nvvm_reflect(ptr addrspacecast (ptr addrspace(1) @arch
    //                                                to ptr))
    // We need to extract the string argument from the call (i.e.
    // "__CUDA_ARCH").
    auto *Call = dyn_cast<CallInst>(U);
    // F has to be the callee. A call that merely passes F as an argument is a
    // non-call use and must not be rewritten as if it were a reflect call.
    if (!Call || Call->getCalledOperand() != F)
      report_fatal_error(
          "__nvvm_reflect can only be used in a call instruction");
    // Count call arguments only. getNumOperands() also counts the callee and
    // any operand-bundle inputs.
    if (Call->arg_size() != 1)
      report_fatal_error("__nvvm_reflect requires exactly one argument");

    auto *GlobalStr =
        dyn_cast<Constant>(Call->getArgOperand(0)->stripPointerCasts());
    if (!GlobalStr)
      report_fatal_error("__nvvm_reflect argument must be a constant string");

    // The string is the initializer of a global variable. Check for that
    // explicitly: constants such as `ptr null`, a global declaration or a
    // function have no operand 0 to read.
    auto *GV = dyn_cast<GlobalVariable>(GlobalStr);
    auto *ConstantStr =
        GV && GV->hasInitializer()
            ? dyn_cast<ConstantDataSequential>(GV->getInitializer())
            : nullptr;
    if (!ConstantStr)
      report_fatal_error("__nvvm_reflect argument must be a string constant");
    if (!ConstantStr->isCString())
      report_fatal_error(
          "__nvvm_reflect argument must be a null-terminated string");

    // Drop the trailing NUL guaranteed by isCString().
    StringRef ReflectArg = ConstantStr->getAsString().drop_back();
    if (ReflectArg.empty())
      report_fatal_error("__nvvm_reflect argument cannot be empty");

    // Single lookup; StringMap::lookup yields 0 (the default) for unknown
    // strings.
    const unsigned ReflectVal = ReflectMap.lookup(ReflectArg);

    LLVM_DEBUG(dbgs() << "Replacing call of reflect function " << F->getName()
                      << "(" << ReflectArg << ") with value " << ReflectVal
                      << "\n");
    auto *NewValue = ConstantInt::get(Call->getType(), ReflectVal);
    foldReflectCall(Call, NewValue);
    Call->eraseFromParent();
  }

  // Remove the reflect function declaration from the module. Erasing it is a
  // modification even if it had no calls, so always report a change.
  F->eraseFromParent();
  return true;
}
/// Replace every use of \p Call with \p NewValue and propagate the constant.
/// Each instruction that becomes foldable is replaced by its folded constant
/// and erased if trivially dead. Terminators that cannot be folded as values
/// are simplified with ConstantFoldTerminator. \p Call itself is left in place,
/// with no remaining uses, for the caller to erase.
void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
  // Worklist entries are WeakVH handles, which become null if the instruction
  // is deleted while still queued. Raw pointers would dangle when:
  //  * an instruction is queued more than once (e.g. `add %r, %r`, or a PHI
  //    with the same incoming value on two edges) and is erased on its first
  //    visit, or
  //  * ConstantFoldTerminator erases a queued terminator, or PHI nodes in the
  //    successor it disconnects.
  SmallVector<WeakVH, 8> Worklist;
  // Replace an instruction with a constant and add all users of the instruction
  // to the worklist
  auto ReplaceInstructionWithConst = [&](Instruction *I, Constant *C) {
    for (User *U : I->users())
      if (auto *UI = dyn_cast<Instruction>(U))
        Worklist.push_back(UI);
    I->replaceAllUsesWith(C);
  };

  ReplaceInstructionWithConst(Call, NewValue);

  const auto &DL = Call->getModule()->getDataLayout();
  while (!Worklist.empty()) {
    Value *V = Worklist.pop_back_val();
    // A null handle means the instruction has already been erased.
    auto *I = dyn_cast_or_null<Instruction>(V);
    if (!I)
      continue;
    if (Constant *C = ConstantFoldInstruction(I, DL)) {
      ReplaceInstructionWithConst(I, C);
      if (isInstructionTriviallyDead(I))
        I->eraseFromParent();
    } else if (I->isTerminator()) {
      ConstantFoldTerminator(I->getParent());
    }
  }
}
/// Replace calls to every supported reflect spelling (__nvvm_reflect,
/// __nvvm_reflect_ocl and the llvm.nvvm.reflect intrinsic) with constants.
/// Does nothing when -nvvm-reflect-enable=false.
/// \returns true iff the module was modified.
bool NVVMReflect::runOnModule(Module &M) {
  if (!NVVMReflectEnabled)
    return false;
  populateReflectMap(M);
  // Start from "unchanged" and let each handler report whether it modified
  // the module. Starting from true forced the new-PM wrapper to invalidate
  // every analysis even for modules with no reflect functions at all.
  bool Changed = false;
  Changed |= handleReflectFunction(M, NVVM_REFLECT_FUNCTION);
  Changed |= handleReflectFunction(M, NVVM_REFLECT_OCL_FUNCTION);
  Changed |=
      handleReflectFunction(M, Intrinsic::getName(Intrinsic::nvvm_reflect));
  return Changed;
}
// Legacy pass manager entry point. All work is delegated to the wrapped
// NVVMReflect instance, and its "module modified" result is returned as is.
bool NVVMReflectLegacyPass::runOnModule(Module &M) {
  const bool Modified = Impl.runOnModule(M);
  return Modified;
}
// New pass manager entry point. Runs a fresh NVVMReflect for SmVersion (not
// declared in this file; presumably an NVVMReflectPass member). If the module
// was left untouched, all analyses are preserved; otherwise none are.
PreservedAnalyses NVVMReflectPass::run(Module &M, ModuleAnalysisManager &AM) {
  NVVMReflect Reflector(SmVersion);
  if (!Reflector.runOnModule(M))
    return PreservedAnalyses::all();
  return PreservedAnalyses::none();
}
|