aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
blob: caaffc5a33155d18b4cc52d45589501a28454ce3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
//===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert the ControlFlow dialect to the
// SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "cf-to-spirv-pattern"

using namespace mlir;

/// Legalizes the argument types of `block`, a successor of the branch `op`.
///
/// Each block argument whose type is not legal under `converter` is replaced
/// by a new argument of the converted type. A source materialization back to
/// the original type is created at the start of the block, and all uses of the
/// old argument are redirected to that materialized value before the old
/// argument is erased.
///
/// Returns failure (reported through `rewriter.notifyMatchFailure` on `op`) if
/// a type cannot be converted or the cast back cannot be materialized. In the
/// latter case the newly inserted argument is removed again (when it has no
/// uses) so the block is not left with a stray extra argument.
///
/// NOTE(review): the block is modified directly (insertArgument /
/// replaceAllUsesWith / eraseArgument) rather than through `rewriter`, so the
/// conversion driver presumably cannot roll these edits back if the enclosing
/// pattern fails later — TODO confirm this is acceptable.
static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
                                            PatternRewriter &rewriter,
                                            const TypeConverter &converter) {
  // Casts are all created at the beginning of the block, in argument order.
  auto builder = OpBuilder::atBlockBegin(&block);
  for (unsigned i = 0; i < block.getNumArguments(); ++i) {
    BlockArgument arg = block.getArgument(i);
    Type ty = arg.getType();
    if (converter.isLegal(ty))
      continue;
    Type newTy = converter.convertType(ty);
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("failed to legalize type for argument {0}", arg));
    }
    unsigned argNum = arg.getArgNumber();
    // The replacement goes in front of the old argument, which therefore moves
    // to position argNum + 1 until it is erased below.
    Value newArg = block.insertArgument(argNum, newTy, arg.getLoc());
    Value convertedValue = converter.materializeSourceConversion(
        builder, op->getLoc(), ty, newArg);
    if (!convertedValue) {
      // Emit the diagnostic while `newArg` is still alive (the message prints
      // it), then undo the insertion so the block signature is unchanged.
      LogicalResult result = rewriter.notifyMatchFailure(
          op, llvm::formatv("failed to cast new argument {0} to type {1}",
                            newArg, ty));
      if (newArg.use_empty())
        block.eraseArgument(argNum);
      return result;
    }
    arg.replaceAllUsesWith(convertedValue);
    block.eraseArgument(argNum + 1);
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

namespace {
/// Lowers `cf.br` to `spirv.Branch`. The destination block's argument types
/// are legalized with the pattern's type converter before the new branch is
/// created; the branch operands are taken from the adaptor.
struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Block *target = op.getDest();
    const TypeConverter &converter = *getTypeConverter();

    if (failed(legalizeBlockArguments(*target, op, rewriter, converter)))
      return failure();

    rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, target,
                                                 adaptor.getDestOperands());
    return success();
  }
};

/// Lowers `cf.cond_br` to `spirv.BranchConditional`. Both successors have
/// their argument types legalized (true successor first) before the new
/// conditional branch is created from the adaptor's condition and operands.
struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Block *trueBlock = op.getTrueDest();
    Block *falseBlock = op.getFalseDest();
    const TypeConverter &converter = *getTypeConverter();

    // Stop at the first successor whose arguments cannot be legalized.
    Block *const successors[] = {trueBlock, falseBlock};
    for (Block *successor : successors) {
      if (failed(legalizeBlockArguments(*successor, op, rewriter, converter)))
        return failure();
    }

    rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
        op, adaptor.getCondition(), trueBlock, adaptor.getTrueDestOperands(),
        falseBlock, adaptor.getFalseDestOperands());
    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

/// Adds the `cf.br` and `cf.cond_br` to SPIR-V lowering patterns to
/// `patterns`; both patterns use `typeConverter` and the pattern set's
/// context.
void mlir::cf::populateControlFlowToSPIRVPatterns(
    const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter,
                                                     patterns.getContext());
}