1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
|
//===- Transforms.cpp - Patterns and transforms for the EmitC dialect -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace emitc {
/// Wraps `op` in a newly created ExpressionOp and returns that expression.
///
/// `op` must implement CExpressionInterface and have exactly one result (both
/// conditions are asserted). The new expression is created right after `op`,
/// takes `op`'s operands as its own operands and has `op`'s result type. All
/// uses of `op`'s result are redirected to the expression's result, `op` is
/// cloned into the expression's body with its operands remapped to the body's
/// block arguments, the original `op` is erased, and a YieldOp yielding the
/// clone's result is appended to the body.
///
/// Note: `builder`'s insertion point is moved by this function and is not
/// restored before returning.
ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
  assert(isa<emitc::CExpressionInterface>(op) && "Expected a C expression");
  assert(op->getNumResults() == 1 && "Expected exactly one result");

  Value opResult = op->getResult(0);
  Location opLoc = op->getLoc();

  // Build the (still body-less) expression immediately after op, forwarding
  // op's operands and result type.
  builder.setInsertionPointAfter(op);
  auto wrapper = emitc::ExpressionOp::create(
      builder, opLoc, opResult.getType(), op->getOperands());

  // From now on, every user of op's result reads the expression's result.
  opResult.replaceAllUsesWith(wrapper.getResult());

  // Pair each operand of the expression with the matching block argument of
  // its body so the clone below refers to the arguments rather than to the
  // outer values.
  Block &body = wrapper.createBody();
  IRMapping operandToArg;
  operandToArg.map(wrapper.getOperands(), body.getArguments());

  // Move op into the body: clone it with remapped operands, then drop the
  // original (it has no remaining uses after the replacement above).
  builder.setInsertionPointToEnd(&body);
  Operation *clonedOp = builder.clone(*op, operandToArg);
  op->erase();

  // Terminate the body by yielding the clone's single result.
  emitc::YieldOp::create(builder, opLoc, clonedOp->getResult(0));
  return wrapper;
}
} // namespace emitc
} // namespace mlir
using namespace mlir;
using namespace mlir::emitc;
namespace {
/// Rewrite pattern that merges an ExpressionOp with one of the ExpressionOps
/// defining its operands, replacing both with a single combined ExpressionOp.
struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
  using OpRewritePattern<ExpressionOp>::OpRewritePattern;

  /// Tries to fold one operand-defining expression into `expressionOp`.
  ///
  /// Returns failure() without modifying the IR when no operand qualifies.
  /// Otherwise creates a new expression after `expressionOp` that contains the
  /// body of the selected operand expression followed by the body of
  /// `expressionOp`, replaces `expressionOp` with it, erases the selected
  /// operand expression and returns success().
  LogicalResult matchAndRewrite(ExpressionOp expressionOp,
                                PatternRewriter &rewriter) const override {
    // True when `user` is an ApplyOp applying the "&" operator.
    auto isAddressOf = [](Operation *user) {
      if (auto apply = dyn_cast<emitc::ApplyOp>(user))
        return apply.getApplicableOperator() == "&";
      return false;
    };

    // Pick the first operand that is produced by an expression which
    // - is not used through an "&" apply inside the matched expression,
    // - has exactly one use (any re-materialization is assumed to have been
    //   handled separately),
    // - has no side effects.
    // Every other operand is kept as an operand of the combined expression.
    ExpressionOp producer;
    SetVector<Value> combinedOperands;
    Block *matchedBody = expressionOp.getBody();
    for (auto [value, blockArg] : llvm::zip(expressionOp.getOperands(),
                                            matchedBody->getArguments())) {
      ExpressionOp candidate = value.getDefiningOp<ExpressionOp>();
      bool foldable = !producer && candidate &&
                      llvm::none_of(blockArg.getUsers(), isAddressOf) &&
                      candidate.getResult().hasOneUse() &&
                      !candidate.hasSideEffects();
      if (foldable)
        producer = candidate;
      else
        combinedOperands.insert(value);
    }

    // Nothing to fold.
    if (!producer)
      return failure();

    // The producer's own operands become operands of the combined expression
    // as well (the set drops duplicates).
    auto producerOperands = producer.getOperands();
    combinedOperands.insert(producerOperands.begin(), producerOperands.end());

    // Create the combined expression right after the matched one, keeping the
    // matched expression's result type and do-not-inline setting.
    rewriter.setInsertionPointAfter(expressionOp);
    auto merged = emitc::ExpressionOp::create(
        rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(),
        combinedOperands.getArrayRef(), expressionOp.getDoNotInline());
    Block &mergedBody = merged.createBody();

    // Associate every operand of the combined expression with its block
    // argument.
    IRMapping mapping;
    for (auto [outerValue, bodyArg] :
         llvm::zip(merged.getOperands(), mergedBody.getArguments()))
      mapping.map(outerValue, bodyArg);

    // Clones the body of `source` at the rewriter's insertion point. The block
    // arguments of `source` are first redirected to whatever its operands are
    // currently mapped to. The terminator is cloned only on request.
    auto inlineBody = [&rewriter, &mapping](ExpressionOp source,
                                            bool cloneTerminator) {
      Block *sourceBody = source.getBody();
      for (auto [outer, inner] :
           llvm::zip(source.getOperands(), sourceBody->getArguments()))
        mapping.map(inner, mapping.lookup(outer));

      for (Operation &nested : sourceBody->without_terminator())
        rewriter.clone(nested, mapping);

      if (cloneTerminator)
        rewriter.clone(*sourceBody->getTerminator(), mapping);
    };

    rewriter.setInsertionPointToStart(&mergedBody);

    // Step 1: inline the producer (without its terminator) and make the
    // producer's result resolve to the result of its cloned root operation.
    inlineBody(producer, /*cloneTerminator=*/false);
    Operation *clonedRoot = mapping.lookup(producer.getRootOp());
    assert(clonedRoot && "Expected cloned expression root to be in mapper");
    assert(clonedRoot->getNumResults() == 1 &&
           "Expected cloned root to have a single result");
    mapping.map(producer.getResult(), clonedRoot->getResult(0));

    // Step 2: inline the matched expression, terminator included.
    inlineBody(expressionOp, /*cloneTerminator=*/true);

    // Swap in the combined expression and drop the now-unused producer.
    rewriter.replaceOp(expressionOp, merged);
    rewriter.eraseOp(producer);
    return success();
  }
};
} // namespace
/// Adds the expression-folding pattern (FoldExpressionOp) to `patterns`.
void mlir::emitc::populateExpressionPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<FoldExpressionOp>(context);
}
|