1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
|
//===- WrapFuncInClass.cpp - Wrap EmitC Funcs in classes -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/Passes.h"
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
using namespace mlir;
using namespace emitc;
namespace mlir {
namespace emitc {
#define GEN_PASS_DEF_WRAPFUNCINCLASSPASS
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
namespace {
/// Pass driver: collects the patterns registered by `populateFuncPatterns`
/// and applies them to the operation this pass runs on using the walk-based
/// pattern driver.
struct WrapFuncInClassPass
    : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
  using WrapFuncInClassPassBase::WrapFuncInClassPassBase;

  void runOnOperation() override {
    // Build the pattern set in the pass's context, then hand it over to the
    // walk driver together with the root operation.
    RewritePatternSet funcPatterns(&getContext());
    populateFuncPatterns(funcPatterns);
    walkAndApplyPatterns(getOperation(), std::move(funcPatterns));
  }
};
} // namespace
} // namespace emitc
} // namespace mlir
/// Rewrite pattern that wraps an `emitc::FuncOp` in a newly created
/// `emitc::ClassOp`.
///
/// For a function whose symbol name is `<name>`, the pattern creates a class
/// named `<name>Class` containing:
///   * one `emitc::FieldOp` per function argument, named `fieldName<idx>` and
///     typed like the argument; when the function carries argument
///     attributes, the argument's attribute dictionary is set as the field's
///     discardable attributes;
///   * a function named `execute` that takes over the original body. Every
///     use of a former block argument is redirected to an `emitc::GetFieldOp`
///     reading the matching field, after which all arguments of `execute`
///     are erased.
/// The original function op is finally replaced with the class op.
///
/// Functions without a body are not matched.
class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
public:
  WrapFuncInClass(MLIRContext *context)
      : OpRewritePattern<emitc::FuncOp>(context) {}

  /// Performs the rewrite described on the class.
  ///
  /// \param funcOp   the function to wrap.
  /// \param rewriter the rewriter used to create and replace operations.
  /// \returns failure (without touching the IR) when `funcOp` has no body,
  ///          success otherwise.
  LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // A function without a body has no entry block to move into `execute`;
    // bail out before creating any IR instead of dereferencing the front of
    // an empty region further down.
    if (funcOp.getBody().empty())
      return rewriter.notifyMatchFailure(
          funcOp, "cannot wrap a function without a body in a class");

    Location loc = funcOp.getLoc();
    auto className = funcOp.getSymNameAttr().str() + "Class";
    ClassOp newClassOp = ClassOp::create(rewriter, loc, className);

    // (field name, field type) for every argument, in argument order.
    SmallVector<std::pair<StringAttr, TypeAttr>> fields;
    rewriter.createBlock(&newClassOp.getBody());
    rewriter.setInsertionPointToStart(&newClassOp.getBody().front());

    // Declare one field per function argument.
    auto argAttrs = funcOp.getArgAttrs();
    for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) {
      StringAttr fieldName =
          rewriter.getStringAttr("fieldName" + std::to_string(idx));
      TypeAttr typeAttr = TypeAttr::get(val.getType());
      fields.push_back({fieldName, typeAttr});

      FieldOp fieldop =
          emitc::FieldOp::create(rewriter, loc, fieldName, typeAttr, nullptr);

      // Carry the argument's attributes over to the field.
      if (argAttrs && idx < argAttrs->size()) {
        fieldop->setDiscardableAttrs(funcOp.getArgAttrDict(idx));
      }
    }

    // Create `execute` after the fields. It starts out with the original
    // function type; its arguments are erased once their uses are rewired.
    rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
    FuncOp newFuncOp = emitc::FuncOp::create(rewriter, loc, "execute",
                                             funcOp.getFunctionType());

    // Move the original body into `execute` through the rewriter rather than
    // creating a placeholder block and discarding it again.
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.getBody().end());

    // Materialize one field read per former argument at the top of the body.
    rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
    SmallVector<Value> newArguments;
    newArguments.reserve(fields.size());
    for (auto &[fieldName, attr] : fields) {
      GetFieldOp arg =
          emitc::GetFieldOp::create(rewriter, loc, attr.getValue(), fieldName);
      newArguments.push_back(arg);
    }

    // Redirect every use of a block argument to the matching field read.
    for (auto [oldArg, newArg] :
         llvm::zip(newFuncOp.getArguments(), newArguments)) {
      rewriter.replaceAllUsesWith(oldArg, newArg);
    }

    // Drop all (now unused) arguments. The IR has already been rewritten at
    // this point, so a failure is reported as an op error while the pattern
    // still completes and returns success.
    llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true);
    if (failed(newFuncOp.eraseArguments(argsToErase)))
      newFuncOp->emitOpError("failed to erase all arguments using BitVector");

    rewriter.replaceOp(funcOp, newClassOp);
    return success();
  }
};
/// Adds the `WrapFuncInClass` rewrite pattern to `patterns`, constructing it
/// with the context the pattern set was created in.
void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<WrapFuncInClass>(ctx);
}
|