1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
|
//===- TestUseListOrders.cpp - Test use-list order bytecode roundtrip -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Bytecode/Encoding.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include <algorithm>
#include <numeric>
#include <random>
#include <utility>
using namespace mlir;
namespace {
/// This pass tests that:
/// 1) we can shuffle use-lists correctly;
/// 2) use-list orders are preserved after a roundtrip to bytecode.
///
/// All shuffling happens on a clone of the module, so the IR the pass is
/// anchored on is left untouched.
class TestPreserveUseListOrders
    : public PassWrapper<TestPreserveUseListOrders, OperationPass<ModuleOp>> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPreserveUseListOrders)

  TestPreserveUseListOrders() = default;
  TestPreserveUseListOrders(const TestPreserveUseListOrders &pass)
      : PassWrapper(pass) {}

  StringRef getArgument() const final { return "test-verify-uselistorder"; }
  StringRef getDescription() const final {
    return "Verify that roundtripping the IR to bytecode preserves the order "
           "of the uselists";
  }

  Option<unsigned> rngSeed{*this, "rng-seed",
                           llvm::cl::desc("Specify an input random seed"),
                           llvm::cl::init(1)};

  /// Seed the random engine used for shuffling from the `rng-seed` option.
  LogicalResult initialize(MLIRContext *context) override {
    rng.seed(static_cast<unsigned>(rngSeed));
    return success();
  }

  void runOnOperation() override {
    // Clone the module so that we can plug in this pass to any other
    // independently.
    OwningOpRef<ModuleOp> cloneModule = getOperation().clone();

    // 1. Compute the op numbering of the module.
    computeOpNumbering(*cloneModule);

    // 2. Loop over all the values and shuffle the uses. While doing so, check
    // that each shuffle is correct.
    if (failed(shuffleUses(*cloneModule)))
      return signalPassFailure();

    // 3. Do a bytecode roundtrip to version 3, which supports use-list order
    // preservation.
    auto roundtripModuleOr = doRoundtripToBytecode(*cloneModule, 3);

    // If the bytecode roundtrip failed, try to roundtrip the original module
    // to version 2, which does not support use-list. If this also fails, the
    // original module had an issue unrelated to uselists and nothing is
    // checked.
    if (failed(roundtripModuleOr)) {
      auto testModuleOr = doRoundtripToBytecode(getOperation(), 2);
      if (failed(testModuleOr))
        return;
      // Only the use-list-aware roundtrip failed: say so instead of failing
      // the pass silently.
      getOperation().emitError()
          << "bytecode roundtrip failed with use-list order encoding "
             "(version 3) but succeeded without it (version 2)";
      return signalPassFailure();
    }

    // 4. Recompute the op numbering on the new module. The numbering should be
    // the same as (1), but on the new operation pointers.
    computeOpNumbering(roundtripModuleOr->get());

    // 5. Loop over all the values and verify that the use-list is consistent
    // with the post-shuffle order of step (2).
    if (failed(verifyUseListOrders(roundtripModuleOr->get())))
      return signalPassFailure();
  }

private:
  /// Serialize `module` to bytecode at the requested `version`, then parse it
  /// back (with verification enabled) into a new operation.
  ///
  /// Returns failure if either the bytecode write or the re-parse fails.
  FailureOr<OwningOpRef<Operation *>> doRoundtripToBytecode(Operation *module,
                                                            uint32_t version) {
    std::string str;
    llvm::raw_string_ostream m(str);
    BytecodeWriterConfig config;
    config.setDesiredBytecodeVersion(version);
    if (failed(writeBytecodeToFile(module, m, config)))
      return failure();

    ParserConfig parseConfig(&getContext(), /*verifyAfterParse=*/true);
    auto newModuleOp = parseSourceString(StringRef(str), parseConfig);
    if (!newModuleOp.get())
      return failure();
    return newModuleOp;
  }

  /// Compute an ordered numbering for all the operations in the IR.
  /// Operations, including `topLevelOp` itself, are numbered in pre-order walk
  /// order, starting at 0. Any previous numbering is discarded.
  void computeOpNumbering(Operation *topLevelOp) {
    uint32_t operationID = 0;
    opNumbering.clear();
    topLevelOp->walk<mlir::WalkOrder::PreOrder>(
        [&](Operation *op) { opNumbering.try_emplace(op, operationID++); });
  }

  /// Return the bytecode use IDs of `val`'s uses, in current use-list order.
  /// Every use owner must already be present in `opNumbering`.
  template <typename ValueT>
  SmallVector<uint64_t> getUseIDs(ValueT val) {
    return SmallVector<uint64_t>(llvm::map_range(val.getUses(), [&](auto &use) {
      return bytecode::getUseID(use, opNumbering.at(use.getOwner()));
    }));
  }

  /// Randomly permute the use-list of every value in `topLevelOp` that has at
  /// least two uses. Check that each permutation was applied as requested, and
  /// record the resulting order in `referenceUseListOrder`.
  ///
  /// Requires `opNumbering` to have been computed for `topLevelOp`.
  LogicalResult shuffleUses(Operation *topLevelOp) {
    // Drop orders recorded by a previous run of this pass instance. Value IDs
    // restart at zero below, and try_emplace would otherwise keep the stale
    // entries instead of the freshly computed orders.
    referenceUseListOrder.clear();

    uint32_t valueID = 0;
    auto doShuffleForRange = [&](ValueRange range) -> LogicalResult {
      for (auto val : range) {
        // Nothing to permute with fewer than two uses.
        if (val.use_empty() || val.hasOneUse())
          continue;

        // Get a valid, non-identity index permutation for the uses of value.
        SmallVector<unsigned> permutation = getRandomPermutation(val);

        // Record the original order, apply the shuffle, and read it back.
        SmallVector<uint64_t> useIDs = getUseIDs(val);
        val.shuffleUseList(permutation);
        SmallVector<uint64_t> permutedIDs = getUseIDs(val);

        // The use that was at position `idx` must now be found at position
        // `permutation[idx]`.
        bool shuffledCorrectly = permutedIDs.size() == useIDs.size();
        for (size_t idx = 0; shuffledCorrectly && idx < permutation.size();
             ++idx)
          shuffledCorrectly = useIDs[idx] == permutedIDs[permutation[idx]];
        if (!shuffledCorrectly) {
          mlir::emitError(val.getLoc())
              << "use-list shuffle was not applied correctly for value: "
              << val;
          return failure();
        }

        // The post-shuffle order is what the roundtripped IR must reproduce.
        referenceUseListOrder.try_emplace(valueID++, std::move(permutedIDs));
      }
      return success();
    };
    return walkOverValues(topLevelOp, doShuffleForRange);
  }

  /// Check that the use-list of every value in `topLevelOp` with at least two
  /// uses matches the order recorded by `shuffleUses`. Values are paired with
  /// their references by visitation order. This assumes the roundtrip
  /// preserves the IR structure.
  ///
  /// Requires `opNumbering` to have been recomputed for `topLevelOp`.
  LogicalResult verifyUseListOrders(Operation *topLevelOp) {
    uint32_t valueID = 0;
    auto doValidationForRange = [&](ValueRange range) -> LogicalResult {
      for (auto val : range) {
        if (val.use_empty() || val.hasOneUse())
          continue;

        // A missing entry means this IR has more multi-use values than the
        // shuffled IR had.
        auto referenceIt = referenceUseListOrder.find(valueID++);
        if (referenceIt == referenceUseListOrder.end()) {
          mlir::emitError(val.getLoc())
              << "found no reference use-list order for value: " << val;
          return failure();
        }
        const SmallVector<uint64_t> &referenceOrder = referenceIt->second;

        // llvm::zip stops at the shorter range, so compare lengths first.
        size_t numUses = std::distance(val.use_begin(), val.use_end());
        if (numUses != referenceOrder.size()) {
          mlir::emitError(val.getLoc())
              << "found use-list size mismatch for value: " << val;
          return failure();
        }

        for (auto [use, referenceID] :
             llvm::zip(val.getUses(), referenceOrder)) {
          uint64_t uniqueID =
              bytecode::getUseID(use, opNumbering.at(use.getOwner()));
          if (uniqueID != referenceID) {
            use.getOwner()->emitError()
                << "found use-list order mismatch for value: " << val;
            return failure();
          }
        }
      }
      return success();
    };
    return walkOverValues(topLevelOp, doValidationForRange);
  }

  /// Walk over blocks and operations and execute a callable over the ranges of
  /// operands/results respectively.
  /// All block arguments are visited first, then all operation results. The
  /// walk stops at the first range for which `callable` fails.
  template <typename FuncT>
  LogicalResult walkOverValues(Operation *topLevelOp, FuncT callable) {
    auto blockWalk = topLevelOp->walk([&](Block *block) {
      if (failed(callable(block->getArguments())))
        return WalkResult::interrupt();
      return WalkResult::advance();
    });

    if (blockWalk.wasInterrupted())
      return failure();

    auto resultsWalk = topLevelOp->walk([&](Operation *op) {
      if (failed(callable(op->getResults())))
        return WalkResult::interrupt();
      return WalkResult::advance();
    });

    return failure(resultsWalk.wasInterrupted());
  }

  /// Creates a random permutation of the uselist order chain of the provided
  /// value. For values with two or more uses, the identity permutation is
  /// never returned, so the use-list order really changes.
  SmallVector<unsigned> getRandomPermutation(Value value) {
    size_t numUses = std::distance(value.use_begin(), value.use_end());
    SmallVector<unsigned> permutation(numUses);
    std::iota(permutation.begin(), permutation.end(), 0u);
    // std::shuffle may legitimately produce the identity. Reshuffle in that
    // case; the expected number of retries is small for numUses >= 2.
    do {
      std::shuffle(permutation.begin(), permutation.end(), rng);
    } while (numUses > 1 &&
             std::is_sorted(permutation.begin(), permutation.end()));
    return permutation;
  }

  /// Map each value to its use-list order encoded with unique use IDs.
  DenseMap<uint32_t, SmallVector<uint64_t>> referenceUseListOrder;

  /// Map each operation to its global ID.
  DenseMap<Operation *, uint32_t> opNumbering;

  /// Random engine used for shuffling; seeded in initialize().
  std::default_random_engine rng;
};
} // namespace
namespace mlir {
void registerTestPreserveUseListOrders() {
PassRegistration<TestPreserveUseListOrders>();
}
} // namespace mlir
|