1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
|
//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir
using namespace mlir;
/// Returns true if `t` is one of the builtin scalar types (integer, float,
/// index or complex) that this conversion accepts as an operand alongside
/// ranked tensors.
static inline bool isScalarLike(Type t) {
  return isa<IntegerType>(t) || isa<FloatType>(t) || isa<IndexType>(t) ||
         isa<ComplexType>(t);
}
/// Returns true if `op` is an ElementwiseMappable operation that this
/// conversion can lower to a `linalg.generic`:
///   - at least one operand is a ranked tensor,
///   - every operand is either a ranked tensor or scalar-like (see
///     `isScalarLike`), and
///   - the op has at least one result and every result is a ranked tensor.
///
/// The result check is needed because the rewrite reads `getResult(0)` and
/// `cast`s result types to `RankedTensorType` (to get the iteration rank and
/// the element type of `tensor.empty` init operands). This predicate is the
/// guard in front of those casts, so it checks the results explicitly instead
/// of relying on the op's own verifier.
static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return false;

  auto operandTypes = op->getOperandTypes();

  // We want at least one ranked tensor operand.
  bool anyRankedTensor =
      llvm::any_of(operandTypes, llvm::IsaPred<RankedTensorType>);

  // Every operand must be a ranked tensor or scalar-like.
  bool allOperandsValid = llvm::all_of(operandTypes, [](Type t) {
    return isa<RankedTensorType>(t) || isScalarLike(t);
  });

  // The rewrite needs at least one result, and all results ranked.
  bool resultsRanked =
      op->getNumResults() != 0 &&
      llvm::all_of(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);

  return anyRankedTensor && allOperandsValid && resultsRanked;
}
/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
/// the result types and return a list of values such that, for each result type
/// `t` and value `v` at the same index `idx`:
///   1. `v.getType() == t`
///   2. If an operand of `op` has type `t`, let `operand_first` be the first
///      such operand. Then `v == operand_first`.
///   3. Otherwise, `v` is a newly created `tensor::EmptyOp` with:
///      a. Static and dynamic dims extracted from the first *ranked-tensor*
///         operand of `op`. Operands may also be scalar-like, so the very
///         first operand is not necessarily shaped.
///      b. Elemental type equal to the elemental type of `t`.
///
/// This is sufficient because ElementwiseMappable guarantees that "The static
/// types of all vector (resp. tensor) operands and results must have the same
/// shape".
static SmallVector<Value, 4>
getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
  assert(isElementwiseMappableOpOnRankedTensors(op));
  Location loc = op->getLoc();
  ValueRange operands = op->getOperands();
  TypeRange resultTypes = op->getResultTypes();

  // Pick the shape source for any `tensor.empty` we need to create.
  // `operands.front()` may be a scalar-like value (e.g. an `i1`), and
  // `tensor::getMixedSizes` requires a ranked tensor. The precondition above
  // guarantees that at least one ranked-tensor operand exists.
  auto shapeSourceIt = llvm::find_if(operands, [](Value v) {
    return isa<RankedTensorType>(v.getType());
  });
  assert(shapeSourceIt != operands.end() &&
         "expected at least one ranked tensor operand");
  Value shapeSource = *shapeSourceIt;

  SmallVector<Value, 4> res;
  res.reserve(resultTypes.size());
  for (Type t : resultTypes) {
    // Reuse the first operand whose type matches the result exactly.
    auto matchIt =
        llvm::find_if(operands, [&](Value v) { return v.getType() == t; });
    if (matchIt != operands.end()) {
      res.push_back(*matchIt);
      continue;
    }
    // Otherwise create an empty tensor with the shape of `shapeSource`. The
    // sizes are materialized lazily, only for results that need them.
    res.push_back(tensor::EmptyOp::create(
        b, loc, tensor::getMixedSizes(b, loc, shapeSource),
        cast<RankedTensorType>(t).getElementType()));
  }
  return res;
}
namespace {
/// Rewrites any operation accepted by `isElementwiseMappableOpOnRankedTensors`
/// into a `linalg.generic` whose body re-creates the original operation on the
/// scalar block arguments.
///
/// Ranked-tensor operands and all results use the identity indexing map over
/// the rank of the first result. Scalar-like operands use a map with no
/// results, so the same scalar value is used for every iteration. All
/// iterators are parallel.
struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
  ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const final {
    if (!isElementwiseMappableOpOnRankedTensors(op))
      return rewriter.notifyMatchFailure(
          op, "requires elementwise op on ranked tensors");

    // Everything below needs a ranked type for every result. The first result
    // fixes the iteration rank, and all results supply the element types of
    // the scalar body op. Fail the match instead of asserting in a `cast` if
    // that does not hold.
    if (op->getNumResults() == 0 ||
        !llvm::all_of(op->getResultTypes(), llvm::IsaPred<RankedTensorType>))
      return rewriter.notifyMatchFailure(
          op, "requires all results to be ranked tensors");

    auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
    auto rank = resTy.getRank();

    // Maps: identity for tensors, and a zero-result map that broadcasts
    // scalars across the iteration space.
    AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0,
                                         /*results=*/{}, rewriter.getContext());
    AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);

    // Match phase: build the input maps. No IR has been created yet, so
    // bailing out here is safe.
    SmallVector<AffineMap> indexingMaps;
    indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
    for (Type ty : op->getOperandTypes()) {
      if (isScalarLike(ty))
        indexingMaps.push_back(scalarMap);
      else if (isa<RankedTensorType>(ty))
        indexingMaps.push_back(idMap);
      else
        return rewriter.notifyMatchFailure(
            op,
            "unsupported operand type (expected scalar-like or ranked tensor)");
    }
    // Every output (one per result) is accessed with the identity map.
    indexingMaps.append(op->getNumResults(), idMap);

    SmallVector<utils::IteratorType> iteratorTypes(
        rank, utils::IteratorType::parallel);
    SmallVector<Value> outputs =
        getOrCreateOperandsMatchingResultTypes(rewriter, op);

    // Result types of the scalar op created inside the generic body.
    SmallVector<Type> resultEltTys =
        llvm::to_vector(llvm::map_range(op->getResultTypes(), [](Type type) {
          return cast<TensorType>(type).getElementType();
        }));

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, /*resultTensorTypes=*/op->getResultTypes(),
        /*inputs=*/op->getOperands(),
        /*outputs=*/outputs,
        /*indexingMaps=*/indexingMaps,
        /*iteratorTypes=*/iteratorTypes,
        /*bodyBuilder=*/
        [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
          // Re-create `op` on the leading block arguments, which correspond
          // to its inputs, keeping its name and attributes.
          Operation *scalarOp =
              builder.create(loc, op->getName().getIdentifier(),
                             regionArgs.take_front(op->getNumOperands()),
                             resultEltTys, op->getAttrs());
          linalg::YieldOp::create(builder, loc, scalarOp->getResults());
        });
    return success();
  }
};
} // namespace
/// Adds to `patterns` the rewrite that lowers elementwise-mappable ops on
/// ranked tensors to `linalg.generic`.
void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(ctx);
}
namespace {
/// Pass that uses a partial dialect conversion to rewrite every op satisfying
/// `isElementwiseMappableOpOnRankedTensors` under the pass's root operation
/// into `linalg.generic`.
class ConvertElementwiseToLinalgPass
    : public impl::ConvertElementwiseToLinalgPassBase<
          ConvertElementwiseToLinalgPass> {
  using impl::ConvertElementwiseToLinalgPassBase<
      ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase;

  void runOnOperation() final {
    auto *rootOp = getOperation();
    auto *ctx = &getContext();

    // Ops this conversion can handle are illegal; every other op is legal.
    ConversionTarget target(*ctx);
    target.markUnknownOpDynamicallyLegal([](Operation *candidate) {
      return !isElementwiseMappableOpOnRankedTensors(candidate);
    });

    RewritePatternSet patterns(ctx);
    mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns);

    if (failed(applyPartialConversion(rootOp, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
|