//===-- SPIRVLegalizePointerCast.cpp ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The LLVM IR has multiple legal patterns we cannot lower to Logical SPIR-V.
// This pass modifies such loads and stores to have an IR we can directly lower
// to valid logical SPIR-V.
// OpenCL can avoid this because they rely on ptrcast, which is not supported
// by logical SPIR-V.
//
// This pass relies on the assign_ptr_type intrinsic to deduce the type of the
// pointed values, and must replace all occurrences of `ptrcast`. This is why
// unhandled cases are reported as unreachable: we MUST cover all cases.
//
// 1. Loading the first element of an array
//
// %array = [10 x i32]
// %value = load i32, ptr %array
//
// LLVM can skip the GEP instruction, and only request loading the first 4
// bytes. In logical SPIR-V, we need an OpAccessChain to access the first
// element. This pass will add a getelementptr instruction before the load.
//
//
// 2. Implicit downcast from load
//
// %1 = getelementptr <4 x i32>, ptr %vec4, i64 0
// %2 = load <3 x i32>, ptr %1
//
// The pointer in the GEP instruction is only used for offset computations,
// but it doesn't NEED to match the pointed type. OpAccessChain however
// requires this. Also, LLVM loads define the bitwidth of the load, not the
// pointer. In this example, we can guess %vec4 is a vec4 thanks to the GEP
// instruction basetype, but we only want to load the first 3 elements, hence
// do a partial load. In logical SPIR-V, this is not legal. What we must do
// is load the full vector (basetype), extract 3 elements, and recombine them
// to form a 3-element vector.
//
//===----------------------------------------------------------------------===//
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/CodeGen/IntrinsicLowering.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
using namespace llvm;
namespace {
class SPIRVLegalizePointerCast : public FunctionPass {
// Emits an `spv_assign_type` intrinsic call at the current builder position
// that associates the type |Ty| with the value |Arg|, and records that call in
// the global registry as the type assignment for |Arg|.
void buildAssignType(IRBuilder<> &B, Type *Ty, Value *Arg) {
  // The type is conveyed to the intrinsic through a poison value of type |Ty|.
  CallInst *TypeAnnotation =
      buildIntrWithMD(Intrinsic::spv_assign_type, {Arg->getType()},
                      PoisonValue::get(Ty), Arg, {}, B);
  GR->addAssignPtrTypeInstr(Arg, TypeAnnotation);
}
// Performs a full load of a |SourceType| vector from |Source|, then keeps only
// its leading lanes to build a value of type |TargetType|. Both vector types
// must share the same element type and |TargetType| must have strictly fewer
// lanes than |SourceType| (both conditions are asserted).
// Returns the narrowed vector.
Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
                            FixedVectorType *TargetType, Value *Source) {
  const unsigned TargetLanes = TargetType->getNumElements();

  // We expect the codegen to avoid doing implicit bitcast from a load.
  assert(TargetType->getElementType() == SourceType->getElementType());
  assert(TargetLanes < SourceType->getNumElements());

  // Load the whole source vector and tag it with its type.
  LoadInst *WideLoad = B.CreateLoad(SourceType, Source);
  buildAssignType(B, SourceType, WideLoad);

  // Identity mask over the first |TargetLanes| lanes: <0, 1, ..., N-1>.
  SmallVector<int> LaneMask;
  LaneMask.reserve(TargetLanes);
  for (unsigned Lane = 0; Lane != TargetLanes; ++Lane)
    LaneMask.push_back(Lane);

  Value *Narrowed = B.CreateShuffleVector(WideLoad, WideLoad, LaneMask);
  buildAssignType(B, TargetType, Narrowed);
  return Narrowed;
}
// Loads the first value of the aggregate pointed to by |Source|, where that
// value has type |ElementType|. An `spv_gep` with indices (0, 0) is emitted to
// address the first element and the load goes through that pointer. Only the
// alignment is copied from |BadLoad|, which should be the load being
// legalized. Returns the loaded value.
Value *loadFirstValueFromAggregate(IRBuilder<> &B, Type *ElementType,
                                   Value *Source, LoadInst *BadLoad) {
  Type *PtrTy = BadLoad->getPointerOperandType();
  SmallVector<Type *, 2> Types = {PtrTy, PtrTy};
  // Four operands are passed, so the inline capacity is 4 to keep the vector
  // in its inline storage.
  SmallVector<Value *, 4> Args{/* isInBounds= */ B.getInt1(false), Source,
                               B.getInt32(0), B.getInt32(0)};
  auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
  GR->buildAssignPtr(B, ElementType, GEP);

  LoadInst *LI = B.CreateLoad(ElementType, GEP);
  LI->setAlignment(BadLoad->getAlign());
  buildAssignType(B, ElementType, LI);
  return LI;
}
// Rewrites |LI|, a load whose pointer operand is a ptrcast, into a sequence
// that reads from |OriginalOperand| (the ptrcast source) instead.
// The pointee type deduced for |OriginalOperand| is the "source" type and the
// one deduced for |CastedOperand| is the "destination" type. The supported
// source/destination combinations are listed below; anything else is reported
// as unreachable. All uses of |LI| are redirected to the new value and |LI| is
// queued in DeadInstructions.
void transformLoad(IRBuilder<> &B, LoadInst *LI, Value *CastedOperand,
                   Value *OriginalOperand) {
  Type *FromTy = GR->findDeducedElementType(OriginalOperand);
  Type *ToTy = GR->findDeducedElementType(CastedOperand);

  auto *FromArray = dyn_cast<ArrayType>(FromTy);
  auto *FromVector = dyn_cast<FixedVectorType>(FromTy);
  auto *FromStruct = dyn_cast<StructType>(FromTy);
  auto *ToVector = dyn_cast<FixedVectorType>(ToTy);

  B.SetInsertPoint(LI);

  // Shorthand for "load the first element, of type |ElementTy|, of the
  // original aggregate".
  auto LoadFirst = [&](Type *ElementTy) {
    return loadFirstValueFromAggregate(B, ElementTy, OriginalOperand, LI);
  };

  Value *Replacement = nullptr;
  if (FromArray && FromArray->getElementType() == ToTy) {
    // Source is an array and destination is its element type: load the first
    // element.
    // - float a = array[0];
    Replacement = LoadFirst(FromArray->getElementType());
  } else if (FromVector && !ToVector &&
             FromVector->getElementType() == ToTy) {
    // Source is a vector and destination is its (scalar) element type.
    // - float a = vector.x;
    Replacement = LoadFirst(FromVector->getElementType());
  } else if (FromVector && ToVector) {
    // Destination is a smaller vector than source.
    // - float3 v3 = vector4;
    Replacement =
        loadVectorFromVector(B, FromVector, ToVector, OriginalOperand);
  } else if (FromStruct && FromStruct->getTypeAtIndex(0u) == ToTy) {
    // Destination is the type stored at the start of a struct.
    // - struct S { float m };
    // - float v = s.m;
    Replacement = LoadFirst(ToTy);
  } else {
    llvm_unreachable("Unimplemented implicit down-cast from load.");
  }

  GR->replaceAllUsesWith(LI, Replacement, /* DeleteOld= */ true);
  DeadInstructions.push_back(LI);
}
// Emits an `spv_insertelt` intrinsic (equivalent to llvm's insertelement)
// that writes |Element| into lane |Index| of |Vector|. The result is tagged
// with the type of |Vector| and returned.
Value *makeInsertElement(IRBuilder<> &B, Value *Vector, Value *Element,
                         unsigned Index) {
  Type *VecTy = Vector->getType();
  // Overload types: result, vector operand, element operand, index operand.
  SmallVector<Type *, 4> Types = {VecTy, VecTy, Element->getType(),
                                  B.getInt32Ty()};
  SmallVector<Value *> Args = {Vector, Element, B.getInt32(Index)};

  Instruction *Inserted =
      B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
  buildAssignType(B, VecTy, Inserted);
  return Inserted;
}
// Emits an `spv_extractelt` intrinsic (equivalent to llvm's extractelement)
// that reads lane |Index| of |Vector| as a value of type |ElementType|. The
// result is tagged with |ElementType| and returned.
Value *makeExtractElement(IRBuilder<> &B, Type *ElementType, Value *Vector,
                          unsigned Index) {
  // Overload types: result, vector operand, index operand.
  SmallVector<Type *, 3> Types = {ElementType, Vector->getType(),
                                  B.getInt32Ty()};
  SmallVector<Value *> Args = {Vector, B.getInt32(Index)};

  Instruction *Extracted =
      B.CreateIntrinsic(Intrinsic::spv_extractelt, {Types}, {Args});
  buildAssignType(B, ElementType, Extracted);
  return Extracted;
}
// Stores the vector |Src| into the vector pointed to by |Dst|, whose type is
// taken from the registry and may have more lanes than |Src|. The current
// destination value is loaded, its leading lanes are overwritten one by one
// with the lanes of |Src|, and the merged vector is stored back. |Alignment|
// is applied to both the load and the store. Returns the new store.
Value *storeVectorFromVector(IRBuilder<> &B, Value *Src, Value *Dst,
                             Align Alignment) {
  auto *SrcVecTy = cast<FixedVectorType>(Src->getType());
  auto *DstVecTy = cast<FixedVectorType>(GR->findDeducedElementType(Dst));
  const unsigned SrcLanes = SrcVecTy->getNumElements();
  assert(DstVecTy->getNumElements() >= SrcLanes);

  // Read the full destination so the lanes not covered by |Src| are kept.
  LoadInst *Existing = B.CreateLoad(DstVecTy, Dst);
  Existing->setAlignment(Alignment);
  buildAssignType(B, Existing->getType(), Existing);

  Type *LaneTy = SrcVecTy->getElementType();
  Value *Merged = Existing;
  for (unsigned Lane = 0; Lane != SrcLanes; ++Lane) {
    Value *Extracted = makeExtractElement(B, LaneTy, Src, Lane);
    Merged = makeInsertElement(B, Merged, Extracted, Lane);
  }

  StoreInst *NewStore = B.CreateStore(Merged, Dst);
  NewStore->setAlignment(Alignment);
  return NewStore;
}
// Appends to |Indices| the chain of zero indices that walks from |Aggregate|
// down through its first element (first struct member, array element or
// vector element) until the type |Search| is reached. One index is pushed for
// every type visited, including |Aggregate| itself. Reaching a type that is
// neither |Search| nor a struct/array/fixed vector is reported as unreachable.
void buildGEPIndexChain(IRBuilder<> &B, Type *Search, Type *Aggregate,
                        SmallVectorImpl<Value *> &Indices) {
  Type *Current = Aggregate;
  while (true) {
    Indices.push_back(B.getInt32(0));
    if (Current == Search)
      return;

    // Descend into the first contained type.
    if (auto *ST = dyn_cast<StructType>(Current))
      Current = ST->getTypeAtIndex(0u);
    else if (auto *AT = dyn_cast<ArrayType>(Current))
      Current = AT->getElementType();
    else if (auto *VT = dyn_cast<FixedVectorType>(Current))
      Current = VT->getElementType();
    else
      llvm_unreachable("Bad access chain?");
  }
}
// Stores |Src| into the first entry of type `Src->getType()` found by walking
// the leading elements of the aggregate |DstPointeeType| pointed to by |Dst|.
// An in-bounds `spv_gep` addressing that entry is emitted, tagged as pointing
// to the type of |Src|, and the store (with |Alignment|) goes through it.
// Returns the new store.
Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
                                  Type *DstPointeeType, Align Alignment) {
  Type *DstPtrTy = Dst->getType();
  Type *SrcTy = Src->getType();

  SmallVector<Type *, 2> Types = {DstPtrTy, DstPtrTy};
  // GEP operands: in-bounds flag, base pointer, then the zero-index chain.
  SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst};
  buildGEPIndexChain(B, SrcTy, DstPointeeType, Args);

  auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
  GR->buildAssignPtr(B, SrcTy, ElementPtr);

  StoreInst *NewStore = B.CreateStore(Src, ElementPtr);
  NewStore->setAlignment(Alignment);
  return NewStore;
}
// Returns true if |Search| is |Aggregate| itself, or is found by repeatedly
// descending into the first element of |Aggregate| (first struct member,
// vector element or array element). Returns false once a type that is neither
// |Search| nor a struct/fixed vector/array is reached, or when a struct with
// no elements is encountered (it has no first member to descend into).
bool isTypeFirstElementAggregate(Type *Search, Type *Aggregate) {
  Type *Current = Aggregate;
  while (Current != Search) {
    if (auto *ST = dyn_cast<StructType>(Current)) {
      // Asking a struct without elements for index 0 is out of range.
      if (ST->getNumElements() == 0)
        return false;
      Current = ST->getTypeAtIndex(0u);
    } else if (auto *VT = dyn_cast<FixedVectorType>(Current)) {
      Current = VT->getElementType();
    } else if (auto *AT = dyn_cast<ArrayType>(Current)) {
      Current = AT->getElementType();
    } else {
      return false;
    }
  }
  return true;
}
// Rewrites |BadStore|, a store instruction (or SPV store intrinsic) whose
// pointer goes through a ptrcast, into a store that targets |Dst| directly.
// |Src| is the stored value, |Dst| the ptrcast source pointer and |Alignment|
// the alignment to apply to the generated memory accesses. Unsupported
// source/destination type combinations are reported as unreachable.
// |BadStore| is queued in DeadInstructions.
void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
                    Value *Dst, Align Alignment) {
  Type *ToTy = GR->findDeducedElementType(Dst);
  Type *FromTy = Src->getType();

  auto *SrcVector = dyn_cast<FixedVectorType>(FromTy);
  auto *DstStruct = dyn_cast<StructType>(ToTy);
  auto *DstVector = dyn_cast<FixedVectorType>(ToTy);

  B.SetInsertPoint(BadStore);

  if (DstStruct && isTypeFirstElementAggregate(FromTy, DstStruct)) {
    // The stored type sits at the very start of the destination struct.
    storeToFirstValueAggregate(B, Src, Dst, DstStruct, Alignment);
  } else if (SrcVector && DstVector) {
    // Vector into vector: merge the lanes into the destination.
    storeVectorFromVector(B, Src, Dst, Alignment);
  } else if (DstVector && !SrcVector &&
             DstVector->getElementType() == FromTy) {
    // Scalar into the first lane of a vector of that scalar type.
    storeToFirstValueAggregate(B, Src, Dst, DstVector, Alignment);
  } else {
    llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
  }

  DeadInstructions.push_back(BadStore);
}
// Legalizes every user of the ptrcast intrinsic |II| so that it no longer
// depends on the cast, then queues |II| in DeadInstructions. Handled users
// are: load instructions, store instructions, and the `spv_assign_ptr_type`,
// `spv_gep` and `spv_store` intrinsics. Any other user is reported as
// unreachable.
void legalizePointerCast(IntrinsicInst *II) {
  Value *CastedOperand = II;
  Value *OriginalOperand = II->getOperand(0);

  IRBuilder<> B(II->getContext());

  // Collect the users up front so the loop below walks a fixed list rather
  // than the live use list of |II|.
  std::vector<Value *> Users;
  for (User *U : II->users())
    Users.push_back(U);

  for (Value *Current : Users) {
    if (auto *Load = dyn_cast<LoadInst>(Current)) {
      transformLoad(B, Load, CastedOperand, OriginalOperand);
      continue;
    }

    if (auto *Store = dyn_cast<StoreInst>(Current)) {
      transformStore(B, Store, Store->getValueOperand(), OriginalOperand,
                     Store->getAlign());
      continue;
    }

    if (auto *Intrin = dyn_cast<IntrinsicInst>(Current)) {
      switch (Intrin->getIntrinsicID()) {
      case Intrinsic::spv_assign_ptr_type:
        // The type annotation of the cast result is simply dropped.
        DeadInstructions.push_back(Intrin);
        continue;
      case Intrinsic::spv_gep:
        // Redirect the uses of the cast to the original pointer.
        GR->replaceAllUsesWith(CastedOperand, OriginalOperand,
                               /* DeleteOld= */ false);
        continue;
      case Intrinsic::spv_store: {
        // Operand 3 carries the alignment when it is a constant integer;
        // otherwise the default-constructed alignment is used.
        Align Alignment;
        if (auto *C = dyn_cast<ConstantInt>(Intrin->getOperand(3)))
          Alignment = Align(C->getZExtValue());
        transformStore(B, Intrin, Intrin->getArgOperand(0), OriginalOperand,
                       Alignment);
        continue;
      }
      default:
        break;
      }
    }

    llvm_unreachable("Unsupported ptrcast user. Please fix.");
  }

  DeadInstructions.push_back(II);
}
public:
SPIRVLegalizePointerCast(SPIRVTargetMachine *TM) : FunctionPass(ID), TM(TM) {}
// Pass entry point: gathers every `spv_ptrcast` intrinsic call in |F|,
// legalizes each of them, then erases all the instructions queued in
// DeadInstructions. Returns true if at least one instruction was erased.
bool runOnFunction(Function &F) override {
  GR = TM->getSubtarget<SPIRVSubtarget>(F).getSPIRVGlobalRegistry();
  DeadInstructions.clear();

  // Gather the casts before rewriting anything, so the function body is not
  // modified while it is being iterated.
  std::vector<IntrinsicInst *> PtrCasts;
  for (BasicBlock &Block : F)
    for (Instruction &Inst : Block)
      if (auto *Call = dyn_cast<IntrinsicInst>(&Inst);
          Call && Call->getIntrinsicID() == Intrinsic::spv_ptrcast)
        PtrCasts.push_back(Call);

  for (IntrinsicInst *Cast : PtrCasts)
    legalizePointerCast(Cast);

  for (Instruction *Dead : DeadInstructions)
    Dead->eraseFromParent();

  return !DeadInstructions.empty();
}
private:
// Target machine given at construction; runOnFunction queries it for the
// subtarget of the function being processed.
SPIRVTargetMachine *TM = nullptr;
// Global registry of the current function's subtarget. Set at the start of
// runOnFunction and used by the helpers to look up and record deduced types.
SPIRVGlobalRegistry *GR = nullptr;
// Instructions made obsolete by the legalization. They are queued here by the
// transform helpers and erased at the end of runOnFunction.
std::vector<Instruction *> DeadInstructions;
public:
// Pass identifier handed to the FunctionPass constructor.
static char ID;
};
} // namespace
// Out-of-class definition of the static pass identifier declared in
// SPIRVLegalizePointerCast.
char SPIRVLegalizePointerCast::ID = 0;
// Pass initialization boilerplate: argument string "spirv-legalize-bitcast"
// and a human-readable description.
// NOTE(review): the two trailing `false` flags are presumably the "CFG only"
// and "is analysis" markers of INITIALIZE_PASS; confirm against PassSupport.h.
INITIALIZE_PASS(SPIRVLegalizePointerCast, "spirv-legalize-bitcast",
"SPIRV legalize bitcast pass", false, false)
// Factory for the pointer-cast legalization pass bound to the target machine
// |TM|. Returns a newly allocated pass instance.
FunctionPass *llvm::createSPIRVLegalizePointerCastPass(SPIRVTargetMachine *TM) {
  FunctionPass *Pass = new SPIRVLegalizePointerCast(TM);
  return Pass;
}
|