aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
blob: dc92367fc58cdfde5f71f739096144e72118bb04 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
//===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert SCF dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Context
//===----------------------------------------------------------------------===//

namespace mlir {
struct ScfToSPIRVContextImpl {
  // Map between the spirv region control flow operation (spirv.mlir.loop or
  // spirv.mlir.selection) to the VariableOp created to store the region
  // results. The order of the VariableOp matches the order of the results.
  DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
};
} // namespace mlir

/// ScfToSPIRVContext keeps information about the lowering of SCF regions that
/// is needed after the region op itself has been rewritten. Lowering
/// scf.for/scf.if creates one VariableOp per result to hold it; those
/// VariableOps are remembered so that stores into them can be emitted when the
/// corresponding scf.yield is lowered. The StoreOps cannot be created earlier
/// because they may use a different type than the yield operands.
ScfToSPIRVContext::ScfToSPIRVContext()
    : impl(std::make_unique<mlir::ScfToSPIRVContextImpl>()) {}

// Defined here, where ScfToSPIRVContextImpl is a complete type, so that the
// owned implementation object can be destroyed. NOTE(review): assumes the
// header only forward-declares ScfToSPIRVContextImpl — confirm.
ScfToSPIRVContext::~ScfToSPIRVContext() = default;

namespace {

//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//

/// Replaces SCF op outputs with SPIR-V variable loads.
/// We create VariableOp to handle the results value of the control flow region.
/// spirv.mlir.loop/spirv.mlir.selection currently don't yield value. Right
/// after the loop we load the value from the allocation and use it as the SCF
/// op result.
template <typename ScfOp, typename OpTy>
void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
                           ConversionPatternRewriter &rewriter,
                           ScfToSPIRVContextImpl *scfToSPIRVContext,
                           ArrayRef<Type> returnTypes) {

  Location loc = scfOp.getLoc();
  auto &allocas = scfToSPIRVContext->outputVars[newOp];
  // Clearing the allocas is necessary in case a dialect conversion path failed
  // previously, and this is the second attempt of this conversion.
  allocas.clear();
  SmallVector<Value, 8> resultValue;
  for (Type convertedType : returnTypes) {
    auto pointerType =
        spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
    rewriter.setInsertionPoint(newOp);
    auto alloc = spirv::VariableOp::create(rewriter, loc, pointerType,
                                           spirv::StorageClass::Function,
                                           /*initializer=*/nullptr);
    allocas.push_back(alloc);
    rewriter.setInsertionPointAfter(newOp);
    Value loadResult = spirv::LoadOp::create(rewriter, loc, alloc);
    resultValue.push_back(loadResult);
  }
  rewriter.replaceOp(scfOp, resultValue);
}

/// Returns an iterator to the block at position `index` within `region`.
Region::iterator getBlockIt(Region &region, unsigned index) {
  Region::iterator blockIt = region.begin();
  std::advance(blockIt, index);
  return blockIt;
}

//===----------------------------------------------------------------------===//
// Conversion Patterns
//===----------------------------------------------------------------------===//

/// Common base class for all SCF to SPIR-V conversion patterns. Each pattern
/// gets access to:
///   - the shared ScfToSPIRVContextImpl, which records the variables backing
///     the results of lowered region ops;
///   - the SPIR-V type converter.
template <typename OpTy>
class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
public:
  SCFToSPIRVPattern(MLIRContext *context, const SPIRVTypeConverter &converter,
                    ScfToSPIRVContextImpl *scfToSPIRVContext)
      : OpConversionPattern<OpTy>::OpConversionPattern(converter, context),
        scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}

protected:
  /// Shared lowering state. Not owned; it belongs to the ScfToSPIRVContext
  /// passed to populateSCFToSPIRVPatterns.
  ScfToSPIRVContextImpl *scfToSPIRVContext;
  // FIXME: We explicitly keep a reference to the SPIR-V type converter and use
  // it directly for types the patterns compute themselves (e.g. scf.if result
  // types and scf.while region signatures).
  //
  // NOTE(review): an earlier version of this comment said the converter was
  // kept *instead of* being passed to OpConversionPattern. However, the
  // constructor above does pass `converter` to the base class, so adaptor
  // operands are converted by the framework. Only the types queried through
  // this member bypass the framework's automation. Confirm the remaining
  // intent of this FIXME.
  //
  // Original rationale: the conversion framework would unconditionally
  // legalize all types used by SCF ops upon discovering them, for example, the
  // types of loop carried values. We use SPIR-V variables for those loop
  // carried values. Depending on the available capabilities, the SPIR-V
  // variable can be different, for example, cooperative matrix or normal
  // variable. We'd like to detach the conversion of the loop carried values
  // from the SCF ops (which is mainly a region). So we need to "mark" types
  // used by SCF ops as legal, if to use the conversion framework for type
  // conversion. There isn't a straightforward way to do that yet, as when
  // converting types, ops aren't taken into consideration.
  const SPIRVTypeConverter &typeConverter;
};

//===----------------------------------------------------------------------===//
// scf::ForOp
//===----------------------------------------------------------------------===//

/// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
///
/// The spirv.mlir.loop body is assembled from four blocks:
///   - the entry block, which branches to the header with the lower bound and
///     the init values;
///   - a new header block, whose arguments are [induction var, iter_args...].
///     It compares the induction variable against the upper bound and
///     branches either to the body or to the merge block;
///   - the inlined scf.for body;
///   - the merge block.
/// The scf.for results are replaced by loads of Function-storage variables
/// created by replaceSCFOutputValue. The stores into them, and the loop-carried
/// values on the back edge, are added later when the body's scf.yield is
/// lowered by TerminatorOpConversion.
struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
  using SCFToSPIRVPattern::SCFToSPIRVPattern;

  LogicalResult
  matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // scf::ForOp can be lowered to the structured control flow represented by
    // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
    // latch and the merge block the exit block. The resulting spirv::LoopOp has
    // a single back edge from the continue to header block, and a single exit
    // from header to merge.
    auto loc = forOp.getLoc();
    auto loopOp =
        spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
    loopOp.addEntryAndMergeBlock(rewriter);

    OpBuilder::InsertionGuard guard(rewriter);
    // Create the block for the header.
    // It is placed at index 1, i.e. right after the entry block.
    Block *header = rewriter.createBlock(&loopOp.getBody(),
                                         getBlockIt(loopOp.getBody(), 1));
    rewriter.setInsertionPointAfter(loopOp);

    // Create the new induction variable to use.
    // Header arguments are [induction variable, iter_args...], using the
    // already-converted types from the adaptor.
    Value adapLowerBound = adaptor.getLowerBound();
    BlockArgument newIndVar =
        header->addArgument(adapLowerBound.getType(), adapLowerBound.getLoc());
    for (Value arg : adaptor.getInitArgs())
      header->addArgument(arg.getType(), arg.getLoc());
    Block *body = forOp.getBody();

    // Apply signature conversion to the body of the forOp. It has a single
    // block, with argument which is the induction variable. That has to be
    // replaced with the new induction variable.
    // The remaining body arguments (the iter_args, index >= 1) are remapped to
    // the header argument at the same index, so the body reads the values
    // carried in by the header.
    TypeConverter::SignatureConversion signatureConverter(
        body->getNumArguments());
    signatureConverter.remapInput(0, newIndVar);
    for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
      signatureConverter.remapInput(i, header->getArgument(i));
    body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
                                             signatureConverter);

    // Move the blocks from the forOp into the loopOp. This is the body of the
    // loopOp.
    // Inserted at index 2, i.e. between the header and the merge block.
    rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(),
                                getBlockIt(loopOp.getBody(), 2));

    SmallVector<Value, 8> args(1, adaptor.getLowerBound());
    args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
    // Branch into it from the entry.
    rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
    spirv::BranchOp::create(rewriter, loc, header, args);

    // Generate the rest of the loop header.
    rewriter.setInsertionPointToEnd(header);
    auto *mergeBlock = loopOp.getMergeBlock();
    // NOTE(review): SLessThan is a signed comparison, so the bounds are always
    // compared as signed integers here. Confirm this is intended for loops
    // whose bounds should be treated as unsigned.
    auto cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
                                            newIndVar, adaptor.getUpperBound());

    spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body,
                                       ArrayRef<Value>(), mergeBlock,
                                       ArrayRef<Value>());

    // Generate instructions to increment the step of the induction variable and
    // branch to the header.
    // NOTE(review): presumably getContinueBlock() returns the block just before
    // the merge block, which at this point is the inlined scf.for body itself
    // (a single block). If so, the IAdd/Branch below are appended after that
    // body's scf.yield. TerminatorOpConversion relies on this: it casts the
    // block's terminator to spirv.Branch and recreates it with the yielded
    // values appended. Confirm against spirv::LoopOp.
    Block *continueBlock = loopOp.getContinueBlock();
    rewriter.setInsertionPointToEnd(continueBlock);

    // Add the step to the induction variable and branch to the header.
    // Only the updated induction variable is passed here; the loop-carried
    // values are appended when the scf.yield is lowered.
    Value updatedIndVar = spirv::IAddOp::create(
        rewriter, loc, newIndVar.getType(), newIndVar, adaptor.getStep());
    spirv::BranchOp::create(rewriter, loc, header, updatedIndVar);

    // Infer the return types from the init operands. Vector type may get
    // converted to CooperativeMatrix or to Vector type, to avoid having complex
    // extra logic to figure out the right type we just infer it from the Init
    // operands.
    SmallVector<Type, 8> initTypes;
    for (auto arg : adaptor.getInitArgs())
      initTypes.push_back(arg.getType());
    replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext,
                          initTypes);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// scf::IfOp
//===----------------------------------------------------------------------===//

/// Pattern to convert a scf::IfOp within kernel functions into
/// spirv::SelectionOp.
///
/// The selection body consists of:
///   - a header block ending in spirv.BranchConditional on the converted
///     condition;
///   - the inlined `then` blocks and, if present, the `else` blocks. The last
///     block of each region gets a branch to the merge block;
///   - the merge block.
/// When there is no `else` region, the false edge goes straight to the merge
/// block. The scf.if results are replaced by loads of variables created by
/// replaceSCFOutputValue, which the scf.yield lowering stores into.
struct IfOpConversion final : SCFToSPIRVPattern<scf::IfOp> {
  using SCFToSPIRVPattern::SCFToSPIRVPattern;

  LogicalResult
  matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // When lowering `scf::IfOp` we explicitly create a selection header block
    // before the control flow diverges and a merge block where control flow
    // subsequently converges.
    auto loc = ifOp.getLoc();

    // Compute return types. This is the only fallible step and it runs before
    // any IR is created, so a failed match leaves the IR untouched.
    SmallVector<Type, 8> returnTypes;
    for (auto result : ifOp.getResults()) {
      auto convertedType = typeConverter.convertType(result.getType());
      if (!convertedType)
        return rewriter.notifyMatchFailure(
            loc,
            llvm::formatv("failed to convert type '{0}'", result.getType()));

      returnTypes.push_back(convertedType);
    }

    // Create `spirv.selection` operation, selection header block and merge
    // block.
    auto selectionOp = spirv::SelectionOp::create(
        rewriter, loc, spirv::SelectionControl::None);
    auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
                                            selectionOp.getBody().end());
    spirv::MergeOp::create(rewriter, loc);

    OpBuilder::InsertionGuard guard(rewriter);
    auto *selectionHeaderBlock =
        rewriter.createBlock(&selectionOp.getBody().front());

    // Inline `then` region before the merge block and branch to it.
    // The branch is appended after the region's scf.yield. That yield is
    // erased later by the scf.yield lowering, leaving the branch as the
    // terminator.
    auto &thenRegion = ifOp.getThenRegion();
    auto *thenBlock = &thenRegion.front();
    rewriter.setInsertionPointToEnd(&thenRegion.back());
    spirv::BranchOp::create(rewriter, loc, mergeBlock);
    rewriter.inlineRegionBefore(thenRegion, mergeBlock);

    auto *elseBlock = mergeBlock;
    // If `else` region is not empty, inline that region before the merge block
    // and branch to it.
    if (!ifOp.getElseRegion().empty()) {
      auto &elseRegion = ifOp.getElseRegion();
      elseBlock = &elseRegion.front();
      rewriter.setInsertionPointToEnd(&elseRegion.back());
      spirv::BranchOp::create(rewriter, loc, mergeBlock);
      rewriter.inlineRegionBefore(elseRegion, mergeBlock);
    }

    // Create a `spirv.BranchConditional` operation for selection header block.
    rewriter.setInsertionPointToEnd(selectionHeaderBlock);
    spirv::BranchConditionalOp::create(rewriter, loc, adaptor.getCondition(),
                                       thenBlock, ArrayRef<Value>(), elseBlock,
                                       ArrayRef<Value>());

    replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
                          returnTypes);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// scf::YieldOp
//===----------------------------------------------------------------------===//

/// Pattern to lower the scf.yield terminating a region whose parent has already
/// been rewritten into a SPIR-V structured control-flow op.
///
/// - Yielded values are stored into the variables that replaceSCFOutputValue
///   recorded for the parent op.
/// - For spirv.mlir.loop parents, the yielded values are also appended to the
///   back-edge spirv.Branch so that they feed the loop header's block
///   arguments.
/// - The scf.yield itself is then erased.
struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
  using SCFToSPIRVPattern::SCFToSPIRVPattern;

  LogicalResult
  matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ValueRange operands = adaptor.getOperands();

    Operation *parent = terminatorOp->getParentOp();

    // TODO: Implement conversion for the remaining `scf` ops.
    // The namespace is read from the OperationName rather than through
    // `getDialect()`, which returns null when the parent's dialect is not
    // loaded.
    if (parent->getName().getDialectNamespace() ==
            scf::SCFDialect::getDialectNamespace() &&
        !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
      return rewriter.notifyMatchFailure(
          terminatorOp,
          llvm::formatv("conversion not supported for parent op: '{0}'",
                        parent->getName()));

    // If the region returns values, store each value into the associated
    // VariableOp created during lowering of the parent region.
    if (!operands.empty()) {
      // Use find() instead of operator[] so that a failed lookup does not
      // insert an empty entry into the shared context.
      auto &outputVars = scfToSPIRVContext->outputVars;
      auto varsIt = outputVars.find(parent);
      if (varsIt == outputVars.end() ||
          varsIt->second.size() != operands.size())
        return rewriter.notifyMatchFailure(
            terminatorOp, "no matching output variables recorded for the "
                          "parent op of the yielded values");
      auto &allocas = varsIt->second;

      auto loc = terminatorOp.getLoc();
      for (auto [var, value] : llvm::zip(allocas, operands))
        spirv::StoreOp::create(rewriter, loc, var, value);
      if (isa<spirv::LoopOp>(parent)) {
        // For loops we also need to update the branch jumping back to the
        // header: rebuild it with the yielded values appended after its
        // existing arguments (the updated induction variable).
        auto br = cast<spirv::BranchOp>(
            rewriter.getInsertionBlock()->getTerminator());
        SmallVector<Value, 8> args(br.getBlockArguments());
        args.append(operands.begin(), operands.end());
        rewriter.setInsertionPoint(br);
        spirv::BranchOp::create(rewriter, loc, br.getTarget(), args);
        rewriter.eraseOp(br);
      }
    }
    rewriter.eraseOp(terminatorOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// scf::WhileOp
//===----------------------------------------------------------------------===//

/// Pattern to convert a scf::WhileOp into spirv::LoopOp.
///
/// - The `before` region becomes the loop header. Its scf.condition is turned
///   into a spirv.BranchConditional to the body or to the merge block.
/// - The `after` region becomes the loop body. Its scf.yield is turned into a
///   spirv.Branch back to the header.
/// - The op results are carried through Function-storage variables that are
///   stored in the header on every iteration and loaded after the loop.
///
/// All steps that can fail run before any SPIR-V op is created, so a failed
/// match does not leave a partially-built loop behind.
struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
  using SCFToSPIRVPattern::SCFToSPIRVPattern;

  LogicalResult
  matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = whileOp.getLoc();
    OpBuilder::InsertionGuard guard(rewriter);

    Region &beforeRegion = whileOp.getBefore();
    Region &afterRegion = whileOp.getAfter();

    if (failed(rewriter.convertRegionTypes(&beforeRegion, typeConverter)) ||
        failed(rewriter.convertRegionTypes(&afterRegion, typeConverter)))
      return rewriter.notifyMatchFailure(whileOp,
                                         "Failed to convert region types");

    // Look the entry blocks up only after the region signatures have been
    // converted above.
    Block &beforeBlock = beforeRegion.front();
    Block &afterBlock = afterRegion.front();

    auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator());
    SmallVector<Value> condArgs;
    if (failed(rewriter.getRemappedValues(cond.getArgs(), condArgs)))
      return rewriter.notifyMatchFailure(
          whileOp, "failed to remap scf.condition arguments");

    Value conditionVal = rewriter.getRemappedValue(cond.getCondition());
    if (!conditionVal)
      return rewriter.notifyMatchFailure(
          whileOp, "failed to remap scf.condition condition value");

    auto yield = cast<scf::YieldOp>(afterBlock.getTerminator());
    SmallVector<Value> yieldArgs;
    if (failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs)))
      return rewriter.notifyMatchFailure(whileOp,
                                         "failed to remap scf.yield operands");

    // Every fallible step has succeeded; build the SPIR-V loop in place of the
    // scf.while.
    rewriter.setInsertionPoint(whileOp);
    auto loopOp =
        spirv::LoopOp::create(rewriter, loc, spirv::LoopControl::None);
    loopOp.addEntryAndMergeBlock(rewriter);

    Block &entryBlock = *loopOp.getEntryBlock();
    Block &mergeBlock = *loopOp.getMergeBlock();

    // Move the while before block as the initial loop header block.
    rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(),
                                getBlockIt(loopOp.getBody(), 1));

    // Move the while after block as the initial loop body block.
    rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(),
                                getBlockIt(loopOp.getBody(), 2));

    // Jump from the loop entry block to the loop header block.
    rewriter.setInsertionPointToEnd(&entryBlock);
    spirv::BranchOp::create(rewriter, loc, &beforeBlock, adaptor.getInits());

    auto condLoc = cond.getLoc();

    SmallVector<Value> resultValues(condArgs.size());

    // For other SCF ops, the scf.yield op yields the value for the whole SCF
    // op. So we use the scf.yield op as the anchor to create/load/store SPIR-V
    // local variables. But for the scf.while op, the scf.yield op yields a
    // value for the before region, which may not match the whole op's
    // result. Instead, the scf.condition op returns values matching the whole
    // op's results. So we need to create/load/store variables according to
    // that. These variables are local to this pattern and are not recorded in
    // the shared context, because the scf.yield below is rewritten here
    // directly.
    for (const auto &it : llvm::enumerate(condArgs)) {
      auto res = it.value();
      auto i = it.index();
      auto pointerType =
          spirv::PointerType::get(res.getType(), spirv::StorageClass::Function);

      // Create local variables before the scf.while op.
      rewriter.setInsertionPoint(loopOp);
      auto alloc = spirv::VariableOp::create(rewriter, condLoc, pointerType,
                                             spirv::StorageClass::Function,
                                             /*initializer=*/nullptr);

      // Load the final result values after the scf.while op.
      rewriter.setInsertionPointAfter(loopOp);
      auto loadResult = spirv::LoadOp::create(rewriter, condLoc, alloc);
      resultValues[i] = loadResult;

      // Store the current iteration's result value.
      rewriter.setInsertionPointToEnd(&beforeBlock);
      spirv::StoreOp::create(rewriter, condLoc, alloc, res);
    }

    rewriter.setInsertionPointToEnd(&beforeBlock);
    rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
        cond, conditionVal, &afterBlock, condArgs, &mergeBlock, ValueRange());

    // Convert the scf.yield op to a branch back to the header block.
    rewriter.setInsertionPointToEnd(&afterBlock);
    rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock,
                                                 yieldArgs);

    rewriter.replaceOp(whileOp, resultValues);
    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Public API
//===----------------------------------------------------------------------===//

void mlir::populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                      ScfToSPIRVContext &scfToSPIRVContext,
                                      RewritePatternSet &patterns) {
  patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
               WhileOpConversion>(patterns.getContext(), typeConverter,
                                  scfToSPIRVContext.getImpl());
}