1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
|
//===- DecomposeGenericByUnfoldingPermutation.cpp -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include <map>
#include <utility>
using namespace mlir;
using namespace mlir::linalg;
namespace {
/// Rewrite pattern that "unfolds" the transpose and/or broadcast that a
/// `linalg.generic` performs implicitly through the indexing map of each of
/// its input operands, materializing them as explicit `linalg.transpose` and
/// `linalg.broadcast` ops. Keeping them folded into the generic is usually
/// the better optimization, but the unfolded form is sometimes wanted, for
/// instance when trying to recognize named ops.
///
/// An indexing map that expresses a transpose, a broadcast, or a mix of both
/// is a *projected permutation*. Example:
///
/// ```mlir
/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
/// #identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
/// ...
/// %res = linalg.generic
///     { indexing_maps = [#projection, #identity, #identity],
///       iterator_types = ["parallel", "parallel", "parallel",
///                         "parallel", "parallel"]}
///     ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
///     outs(%z : tensor<5x9x7x8x10xf32>) {
///   ^bb0(%in: f32, %in_1: f32, %out: f32):
///     %div = arith.divf %in, %in_1 : f32
///     linalg.yield %div : f32
/// } -> tensor<5x9x7x8x10xf32>
/// ```
///
/// The map of operand `%x` is a projected permutation, so `%x` is unfolded:
///
/// ```mlir
/// ...
/// %x_trans = linalg.transpose
///     ins(%x : tensor<7x8x9xf32>)
///     outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
/// ...
/// %x_trans_bc = linalg.broadcast
///     ins(%x_trans : tensor<9x7x8xf32>)
///     outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
/// %2 = linalg.div
///     ins(%x_trans_bc, %y :
///         tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
///     outs(%z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
/// ```
///
/// The final op is shown after the resulting generic has additionally been
/// specialized to `linalg.div`. This pattern itself emits a `linalg.generic`
/// whose unfolded inputs use identity maps.
///
/// The transpose is emitted before the broadcast so that it runs on the
/// smaller, not-yet-broadcast tensor. As a consequence, the permutation has
/// to be expressed over the operand's own dimensions rather than the loop
/// dimensions. The extents of the broadcast dimensions come from the other
/// operands (those not broadcast along that dimension). They are obtained
/// from the static iteration-space extents of the generic, which are derived
/// from the shapes of all its operands, inputs and outputs alike.
struct DecomposeProjectedPermutation : OpRewritePattern<GenericOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override;
};
/// For the given projected-permutation `map`, determine which operand
/// dimensions must be transposed and which loop dimensions are broadcast.
///
/// Returns the pair {transpose-permutation, broadcast-dimensions}:
///  * transpose-permutation follows the `linalg.transpose` convention
///    (dim(result, r) = dim(input, permutation[r])). It ranges over the
///    operand's own rank. It is empty when the map already accesses the loop
///    dimensions in increasing order, i.e. when no transpose is needed.
///  * broadcast-dimensions lists, in increasing order, the map inputs that
///    do not appear in the map results. It is empty when nothing is
///    broadcast.
///
/// Example: `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>` yields the
/// permutation [2, 0, 1] and the broadcast dimensions [0, 4].
std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
computeTransposeBroadcast(AffineMap &map) {
  assert(map.isProjectedPermutation(false) && "not a projection");

  // Loop dimension accessed by each operand dimension, in operand order.
  // Relies on the assert above: every result is a plain dimension expression.
  SmallVector<int64_t> minorResult;
  minorResult.reserve(map.getNumResults());
  for (AffineExpr expr : map.getResults())
    minorResult.push_back(cast<AffineDimExpr>(expr).getPosition());

  // Loop dimensions that the operand does not index are broadcast.
  SmallVector<int64_t> broadcast;
  for (int64_t dim = 0, e = map.getNumInputs(); dim < e; ++dim)
    if (!llvm::is_contained(minorResult, dim))
      broadcast.push_back(dim);

  // Loop dims already accessed in increasing order: no transpose is needed.
  SmallVector<int64_t> permutation;
  if (llvm::is_sorted(minorResult))
    return {std::move(permutation), std::move(broadcast)};

  // The transposed operand must list its dims in increasing loop-dim order,
  // so result dim r comes from the operand dim holding the r-th smallest
  // loop dim. The permutation is therefore the argsort of `minorResult`.
  // Sorting operand-dim indices (rather than loop dims) rebases the transpose
  // onto the operand's own rank. For example, for `x : tensor<7x8x9>` with
  // map `(d0, d1, d2, d3, d4) -> (d2, d3, d1)`, the transpose permutes the
  // three dims of `x` and never refers to d3/d4 of the loop nest.
  const int64_t minorSize = minorResult.size();
  permutation.reserve(minorSize);
  for (int64_t dim = 0; dim < minorSize; ++dim)
    permutation.push_back(dim);
  llvm::sort(permutation, [&](int64_t lhs, int64_t rhs) {
    return minorResult[lhs] < minorResult[rhs];
  });
  return {std::move(permutation), std::move(broadcast)};
}
/// Unfold each non-identity input indexing map of `op` into an explicit
/// `linalg.transpose` and/or `linalg.broadcast`. Then rebuild the generic
/// with identity maps for the unfolded inputs, keeping the original body,
/// outputs, output maps and iterator types. Non-tensor inputs (e.g. scalars)
/// are passed through unchanged together with their original map.
LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
    GenericOp op, PatternRewriter &rewriter) const {
  if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
      op.isSingleYieldOp() || !op.isAllParallelLoops())
    return failure();

  // If the map of an operand is not a `projected permutation` then
  // it cannot be decomposed to mere transpose and broadcast.
  // The requirement that all maps be `projected permutation` may be
  // over-restrictive but since we need to determine shape of the
  // iteration space as well, reject if any map violates assumption.
  for (OpOperand &opOperand : op->getOpOperands()) {
    AffineMap map = op.getMatchingIndexingMap(&opOperand);
    if (!map.isProjectedPermutation(false))
      return failure();
  }

  // Decomposing linalg.generic involves creating `tensor.empty`
  // which can have dynamic shapes but then we would have to work
  // out which operand can supply that runtime-value (tensor.dim).
  // Leaving it as a future TODO.
  // Non-shaped operands (scalar inputs, which `hasPureTensorSemantics`
  // permits) have no shape to check. Every shaped operand must be a
  // statically shaped ranked tensor; an unranked tensor is rejected too.
  if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
        auto shapedType = dyn_cast<ShapedType>(oper.get().getType());
        return shapedType && !shapedType.hasStaticShape();
      }))
    return failure();

  // Static extent of every loop of the iteration space. This is the shape
  // that each broadcast materializes.
  auto loopRanges = op.getStaticLoopRanges();

  Location loc = op.getLoc();
  bool isChanged = false;
  SmallVector<Value> newInputs = op.getDpsInputs();
  SmallVector<AffineMap> newMaps = op.getIndexingMapsArray();

  // Walk over each input operand and unfold if it is transposed, broadcast
  // or mix of two via operand's affine-map.
  for (int64_t i = 0, e = op.getNumDpsInputs(); i < e; ++i) {
    AffineMap map = newMaps[i];

    // Nothing to do if map is already an identity.
    if (map.isIdentity())
      continue;

    // Only ranked tensors can be transposed/broadcast. Other inputs (e.g.
    // scalars with a `(d0, ...) -> ()` map) keep their original map.
    auto inputType = dyn_cast<RankedTensorType>(newInputs[i].getType());
    if (!inputType)
      continue;
    Type elType = inputType.getElementType();

    auto [permutation, broadcastedDims] = computeTransposeBroadcast(map);

    // Does it need transpose?
    if (!permutation.empty()) {
      // linalg.transpose permutes the dimensions of input using
      // rule: dim(result, d) = dim(input, permutation[d]).
      ArrayRef<int64_t> inputShape = inputType.getShape();
      SmallVector<int64_t> transposedShape;
      transposedShape.reserve(permutation.size());
      for (int64_t srcDim : permutation)
        transposedShape.push_back(inputShape[srcDim]);

      Value emptyTensor =
          tensor::EmptyOp::create(rewriter, loc, transposedShape, elType);
      auto transposeOp = TransposeOp::create(rewriter, loc, newInputs[i],
                                             emptyTensor, permutation);
      newInputs[i] = transposeOp->getResult(0);
      isChanged = true;
    }

    // Does it require broadcast? The (possibly transposed) operand now lists
    // its dims in increasing loop-dim order, so broadcasting it into the
    // full iteration space along `broadcastedDims` yields an identity access.
    if (!broadcastedDims.empty()) {
      Value emptyTensor =
          tensor::EmptyOp::create(rewriter, loc, loopRanges, elType);
      auto broadcastOp = linalg::BroadcastOp::create(
          rewriter, loc, newInputs[i], emptyTensor, broadcastedDims);
      newInputs[i] = broadcastOp->getResult(0);
      isChanged = true;
    }

    newMaps[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
  }

  if (!isChanged)
    return failure();

  auto newOp = linalg::GenericOp::create(
      rewriter,
      /*location=*/loc,
      /*resultTensorTypes=*/op->getResultTypes(),
      /*inputs=*/newInputs,
      /*outputs=*/op.getDpsInits(),
      /*indexingMaps=*/newMaps,
      /*iteratorTypes=*/op.getIteratorTypesArray());

  // Reuse the original payload: element types of the inputs are unchanged,
  // so the block arguments still match.
  newOp.getRegion().takeBody(op->getRegion(0));
  rewriter.replaceOp(op, newOp->getResults());
  return success();
}
} // namespace
/// Registers the pattern that unfolds projected-permutation input maps of
/// `linalg.generic` into explicit `linalg.transpose` / `linalg.broadcast`.
void mlir::linalg::populateDecomposeProjectedPermutationPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<DecomposeProjectedPermutation>(context);
}
|