1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
|
//===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_EMPTYTENSORELIMINATIONPASS
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir
using namespace mlir;
using namespace mlir::bufferization;
/// Check whether every value in `neededValues` is in scope at
/// `insertionPoint`.
///
/// A block argument counts as in scope when `insertionPoint` lies inside the
/// argument's owner block (directly or nested in one of its ops). An op result
/// counts as in scope when its defining op properly dominates
/// `insertionPoint`.
static bool
neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
                                   Operation *insertionPoint,
                                   const SmallVector<Value> &neededValues) {
  for (Value needed : neededValues) {
    bool visible = false;
    if (auto blockArg = dyn_cast<BlockArgument>(needed)) {
      Block *ownerBlock = blockArg.getOwner();
      visible = ownerBlock->findAncestorOpInBlock(*insertionPoint) != nullptr;
    } else {
      Operation *definingOp = cast<OpResult>(needed).getOwner();
      visible = domInfo.properlyDominates(definingOp, insertionPoint);
    }
    // A single out-of-scope value disqualifies the insertion point.
    if (!visible)
      return false;
  }
  return true;
}
/// Choose an operation before which a replacement for `emptyTensorOp`'s use by
/// `user` can be inserted, assuming the replacement may reference any value in
/// `neededValues`.
///
/// Candidates are tried in this order:
/// 1. `emptyTensorOp` itself.
/// 2. For each needed value, the earliest op at which it becomes available:
///    the first op of a block argument's owner block, or the op right after an
///    op result's defining op.
///
/// The first candidate qualifies if all needed values are in scope there and
/// it dominates `user`. If no candidate qualifies, nullptr is returned.
static Operation *
findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
                        const SmallVector<Value> &neededValues) {
  DominanceInfo domInfo;

  SmallVector<Operation *> candidates;
  candidates.push_back(emptyTensorOp);
  for (Value needed : neededValues) {
    // Assumption (relied on, not checked): the anchor op uses every needed
    // value. So a block argument's owner block contains at least one op (the
    // anchor op or one of its parents). Likewise, an op result's defining op
    // is followed by at least one op (the anchor op or one of its parents).
    Operation *earliest = nullptr;
    if (auto blockArg = dyn_cast<BlockArgument>(needed))
      earliest = &blockArg.getOwner()->getOperations().front();
    else
      earliest = needed.getDefiningOp()->getNextNode();
    candidates.push_back(earliest);
  }

  for (Operation *candidate : candidates) {
    // All needed values must be visible, and the new op must come before the
    // use that is being replaced.
    if (neededValuesDominateInsertionPoint(domInfo, candidate, neededValues) &&
        domInfo.dominates(candidate, user))
      return candidate;
  }

  // No candidate satisfies both conditions.
  return nullptr;
}
/// Build a value that can replace the use of `emptyTensorOp` by `user`, by
/// calling `op.buildSubsetExtraction`.
///
/// The extraction is created before the insertion point selected by
/// `findValidInsertionPoint`. If no valid insertion point exists, a null Value
/// is returned so that the caller can skip this replacement. The rewriter's
/// insertion point is restored before returning.
Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
                                                 SubsetInsertionOpInterface op,
                                                 tensor::EmptyOp emptyTensorOp,
                                                 Operation *user) {
  mlir::OpBuilder::InsertionGuard guard(rewriter);

  // Values that must be in scope wherever the extraction op is created.
  SmallVector<Value> requiredValues =
      op.getValuesNeededToBuildSubsetExtraction();
  Operation *insertionPoint =
      findValidInsertionPoint(emptyTensorOp, user, requiredValues);
  if (insertionPoint == nullptr)
    return Value();

  rewriter.setInsertionPoint(insertionPoint);
  return op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
}
/// Replace tensor.empty ops that feed the source operand of
/// SubsetInsertionOpInterface ops with values built by `subsetsExtractionFn`.
///
/// For each SubsetInsertionOpInterface op visited by `op->walk` whose source
/// operand bufferizes in place according to `state`:
/// - The reverse use-def chain of the source operand is searched for
///   tensor.empty results. The search follows only equivalent values whose
///   types are identical or differ by a cast.
/// - For each tensor.empty found, its use on that chain is redirected to the
///   value returned by `subsetsExtractionFn`. A tensor.cast is inserted when
///   the types differ but the element types agree.
/// - A candidate is skipped when the callback returns a null value, returns
///   the tensor.empty itself, or returns a value with a different element type.
///
/// \param rewriter Used to create tensor.cast ops and to modify users in
///        place. Its insertion point is restored on return.
/// \param op Root of the walk.
/// \param state One-Shot analysis state. It answers in-place queries, performs
///        the use-def traversal, and has its cache reset after each
///        replacement.
/// \param subsetsExtractionFn Callback that builds the replacement value.
/// \return Always success().
LogicalResult mlir::bufferization::eliminateEmptyTensors(
    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
    ControlBuildSubsetExtractionFn subsetsExtractionFn) {
  // Restores the rewriter's insertion point on return. The loop below moves it
  // when it creates tensor.cast ops.
  OpBuilder::InsertionGuard g(rewriter);
  // Operands traversed by the use-def chain search for the current op. It is
  // filled by findValueInReverseUseDefChain and cleared for every visited op.
  llvm::DenseSet<OpOperand *> visitedOpOperands;
  op->walk([&](SubsetInsertionOpInterface op) {
    visitedOpOperands.clear();
    OpOperand &source = op.getSourceOperand();
    // Skip operands that do not bufferize inplace. "tensor.empty" could still
    // be replaced, but the transformation may not be beneficial.
    if (!state.isInPlace(source))
      return WalkResult::skip();
    // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
    // equivalent tensors. I.e., stop when there are ops such as extract_slice
    // on the path.
    TraversalConfig config;
    config.followEquivalentOnly = true;
    // NOTE(review): presumably this restricts the result to values matching
    // the condition (tensor.empty results), excluding other chain leaves.
    // Confirm against TraversalConfig.
    config.alwaysIncludeLeaves = false;
    // Replace only if the types match or are static <-> dynamic casts. We do
    // not support slices or reshapes.
    // TODO: This could be extended to support IR such as:
    // %0 = tensor.empty() : tensor<128xf32>
    // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
    // %2 = tensor.expand_shape %1 ...
    // %3 = tensor.insert_slice %2 into ...
    config.followSameTypeOrCastsOnly = true;
    SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
        &source, /*condition=*/
        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
        &visitedOpOperands);
    for (Value v : emptyTensors) {
      auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
      assert(emptyTensorOp && "expected tensor.empty op");
      // Find the use to be replaced from the use-def chain.
      // NOTE(review): visitedOpOperands is a DenseSet. If several traversed
      // operands use this tensor.empty, which one is chosen depends on the
      // set's iteration order.
      auto iter = llvm::find_if(
          visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
            return llvm::count(emptyTensorOp->getUses(), *opOperand);
          });
      assert(iter != visitedOpOperands.end() && "could not find use");
      OpOperand *useToBeReplaced = *iter;
      Operation *user = useToBeReplaced->getOwner();
      auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
      // A null value means the callback did not build a replacement.
      if (!replacement)
        continue;
      // Nothing to do if the callback handed back the tensor.empty itself.
      if (emptyTensorOp == replacement.getDefiningOp())
        continue;
      if (replacement.getType() != v.getType()) {
        // Only shape differences can be bridged with a tensor.cast, so skip on
        // an element type mismatch.
        // NOTE(review): any IR the callback created for this replacement is
        // left in place, unused, when skipped here.
        if (cast<ShapedType>(replacement.getType()).getElementType() !=
            cast<ShapedType>(v.getType()).getElementType())
          continue;
        rewriter.setInsertionPointAfterValue(replacement);
        replacement = tensor::CastOp::create(rewriter, v.getLoc(), v.getType(),
                                             replacement);
      }
      // Replace the specific use of the tensor::EmptyOp.
      rewriter.modifyOpInPlace(user, [&]() {
        user->setOperand(useToBeReplaced->getOperandNumber(), replacement);
      });
      // NOTE(review): presumably this invalidates cached analysis results now
      // that the IR has changed. Confirm in OneShotAnalysisState.
      state.resetCache();
    }
    return WalkResult::advance();
  });
  return success();
}
namespace {
/// Pass that runs tensor.empty elimination on its root operation via
/// `bufferization::eliminateEmptyTensors` (see `runOnOperation`).
struct EmptyTensorElimination
    : public bufferization::impl::EmptyTensorEliminationPassBase<
          EmptyTensorElimination> {
  using Base::Base;

  void runOnOperation() override;

  /// Declare the dialects whose ops this pass may create. The tensor dialect
  /// is needed because the rewrite can insert tensor.cast ops.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
  }
};
} // namespace
/// Run One-Shot analysis on `op`, then eliminate tensor.empty ops using the
/// resulting analysis state.
///
/// If `op` is a ModuleOp, function-boundary bufferization is enabled and the
/// module-level analysis is used. Otherwise the regular per-op analysis runs.
/// Returns failure if the analysis fails; otherwise returns the result of the
/// state-taking `eliminateEmptyTensors` overload.
LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
                                                         Operation *op) {
  ModuleOp moduleOp = dyn_cast<ModuleOp>(op);

  OneShotBufferizationOptions options;
  options.allowReturnAllocsFromLoops = true;
  if (moduleOp)
    options.bufferizeFunctionBoundaries = true;

  OneShotAnalysisState state(op, options);
  // Module analysis takes function boundaries into account. Regular One-Shot
  // analysis ignores func.func block arguments, func.call and func.return.
  LogicalResult analysis =
      moduleOp ? analyzeModuleOp(moduleOp, state) : analyzeOp(op, state);
  if (failed(analysis))
    return failure();

  return bufferization::eliminateEmptyTensors(rewriter, op, state);
}
void EmptyTensorElimination::runOnOperation() {
IRRewriter rewriter(getOperation()->getContext());
if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation())))
signalPassFailure();
}
|