aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
blob: 06d7e07005f8ae5ffe6a6e246c16926234f93980 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
//===- WrapFuncInClass.cpp - Wrap EmitC Funcs in classes -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/Passes.h"
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

using namespace mlir;
using namespace emitc;

namespace mlir {
namespace emitc {
#define GEN_PASS_DEF_WRAPFUNCINCLASSPASS
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"

namespace {
/// Pass driver: collects the function-wrapping patterns and applies them to
/// the ops nested under the pass root in a single walk
/// (walkAndApplyPatterns), rather than iterating to a fixpoint.
struct WrapFuncInClassPass
    : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
  using WrapFuncInClassPassBase::WrapFuncInClassPassBase;

  void runOnOperation() override {
    RewritePatternSet funcPatterns(&getContext());
    populateFuncPatterns(funcPatterns);

    Operation *root = getOperation();
    walkAndApplyPatterns(root, std::move(funcPatterns));
  }
};

} // namespace
} // namespace emitc
} // namespace mlir

/// Rewrites `emitc.func @name(args...)` into `emitc.class @nameClass`.
///
/// The class contains:
///   * one `emitc.field` per original argument, named `fieldName0`,
///     `fieldName1`, ..., typed like the argument and carrying the argument's
///     attribute dictionary;
///   * an `emitc.func @execute` whose body is the original function body.
///
/// Inside `execute`, each use of an original argument is replaced by an
/// `emitc.get_field` of the matching field, and the block arguments are then
/// erased. Function declarations (no body) are not matched, because they have
/// no entry block to move into `execute`.
class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
public:
  WrapFuncInClass(MLIRContext *context)
      : OpRewritePattern<emitc::FuncOp>(context) {}

  LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // A declaration has an empty body region. Moving it and then
    // dereferencing its (non-existent) entry block below would be invalid.
    if (funcOp.isExternal())
      return rewriter.notifyMatchFailure(
          funcOp, "cannot wrap a function declaration without a body");

    Location loc = funcOp.getLoc();
    std::string className = funcOp.getSymNameAttr().str() + "Class";
    ClassOp newClassOp = ClassOp::create(rewriter, loc, className);

    // createBlock also moves the insertion point to the new (empty) block.
    Block *classBody = rewriter.createBlock(&newClassOp.getBody());

    // Emit one field per function argument and remember its name/type so the
    // arguments can be rematerialized inside `execute`.
    SmallVector<std::pair<StringAttr, TypeAttr>> fields;
    auto argAttrs = funcOp.getArgAttrs();
    for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) {
      StringAttr fieldName =
          rewriter.getStringAttr("fieldName" + std::to_string(idx));
      TypeAttr typeAttr = TypeAttr::get(val.getType());
      fields.push_back({fieldName, typeAttr});

      FieldOp fieldOp =
          emitc::FieldOp::create(rewriter, loc, fieldName, typeAttr, nullptr);
      if (argAttrs && idx < argAttrs->size())
        fieldOp->setDiscardableAttrs(funcOp.getArgAttrDict(idx));
    }

    rewriter.setInsertionPointToEnd(classBody);
    FunctionType funcType = funcOp.getFunctionType();
    FuncOp newFuncOp =
        emitc::FuncOp::create(rewriter, loc, "execute", funcType);

    // Move the original body through the rewriter so listeners observe the
    // change. No placeholder block is needed in the new function.
    Region &newBody = newFuncOp.getBody();
    rewriter.inlineRegionBefore(funcOp.getBody(), newBody, newBody.end());

    // Rematerialize each former argument as a field read at the top of the
    // (moved) entry block.
    rewriter.setInsertionPointToStart(&newBody.front());
    SmallVector<Value> newArguments;
    newArguments.reserve(fields.size());
    for (auto &[fieldName, attr] : fields) {
      GetFieldOp arg =
          emitc::GetFieldOp::create(rewriter, loc, attr.getValue(), fieldName);
      newArguments.push_back(arg);
    }

    for (auto [oldArg, newArg] :
         llvm::zip(newFuncOp.getArguments(), newArguments))
      rewriter.replaceAllUsesWith(oldArg, newArg);

    // All former arguments are now dead. Drop them (this also updates the
    // function type of `execute`).
    llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true);
    if (failed(newFuncOp.eraseArguments(argsToErase)))
      newFuncOp->emitOpError("failed to erase all arguments using BitVector");

    rewriter.replaceOp(funcOp, newClassOp);
    return success();
  }
};

/// Adds the WrapFuncInClass rewrite pattern to `patterns`, constructed with
/// the set's own MLIR context.
void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<WrapFuncInClass>(ctx);
}