1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
|
//===- TestDenseBackwardDataFlowAnalysis.cpp - Test pass ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Test pass for backward dense dataflow analysis.
//
//===----------------------------------------------------------------------===//
#include "TestDenseDataFlowAnalysis.h"
#include "TestOps.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
#include "mlir/Analysis/DataFlow/Utils.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/TypeID.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::dataflow;
using namespace mlir::dataflow::test;
#define DEBUG_TYPE "test-next-access"
namespace {
/// Dense lattice used by the next-access analysis. It combines the generic
/// dense-lattice interface with the access bookkeeping implemented by
/// AccessLatticeBase.
class NextAccess : public AbstractDenseLattice, public AccessLatticeBase {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccess)

  using dataflow::AbstractDenseLattice::AbstractDenseLattice;

  /// Meets this lattice with `lattice` by forwarding to
  /// AccessLatticeBase::merge. `lattice` is downcast to NextAccess without a
  /// runtime check.
  ChangeResult meet(const AbstractDenseLattice &lattice) override {
    const auto &other = static_cast<const NextAccess &>(lattice);
    // Casting to the non-reference type AccessLatticeBase materializes a
    // temporary copy of the base subobject, so merge reads from that copy
    // rather than from `lattice` itself.
    return AccessLatticeBase::merge(static_cast<AccessLatticeBase>(other));
  }

  /// Prints the lattice using the AccessLatticeBase printer.
  void print(raw_ostream &os) const override { AccessLatticeBase::print(os); }
};
/// Backward dense dataflow analysis over NextAccess lattices. The transfer
/// functions declared here record, for underlying values, the operation that
/// accesses them next.
class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
public:
  /// Creates the analysis. When `assumeFuncReads` is set, calls to external
  /// callees are handled as accesses of all their argument operands (see
  /// visitCallControlFlowTransfer).
  NextAccessAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
                     bool assumeFuncReads = false)
      : DenseBackwardDataFlowAnalysis(solver, symbolTable),
        assumeFuncReads(assumeFuncReads) {}

  /// Transfer function for a single operation.
  LogicalResult visitOperation(Operation *op, const NextAccess &after,
                               NextAccess *before) override;

  /// Transfer function applied when control flows through a call operation.
  void visitCallControlFlowTransfer(CallOpInterface call,
                                    CallControlFlowAction action,
                                    const NextAccess &after,
                                    NextAccess *before) override;

  /// Transfer function applied when control flows between a region-branching
  /// operation and its regions.
  void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
                                            RegionBranchPoint regionFrom,
                                            RegionBranchPoint regionTo,
                                            const NextAccess &after,
                                            NextAccess *before) override;

  // TODO: this is not ideal for the analysis. The absence of a next access
  // currently means "the next access is not known" rather than "no next
  // access exists", and it is unclear how to tell the two cases apart.
  /// Marks `lattice` as unknown and propagates the update if it changed.
  void setToExitState(NextAccess *lattice) override {
    LDBG() << "setToExitState: setting lattice to unknown state";
    propagateIfChanged(lattice, lattice->setKnownToUnknown());
  }

  /// Visits an operation. When the analysis can confirm that the lattice
  /// contents at the anchors around the operation are necessarily identical,
  /// the anchors are joined into one equivalence class.
  void buildOperationEquivalentLatticeAnchor(Operation *op) override;

  /// Whether external callees are assumed to read all of their arguments.
  const bool assumeFuncReads;
};
} // namespace
/// Transfer function of the next-access analysis for a single operation.
///
/// Given the lattice state `after` the operation, updates the state `before`
/// it:
///  - if `op` does not implement MemoryEffectOpInterface, or one of its
///    effects has no associated value, `before` is reset via setToExitState;
///  - if the most underlying value of an affected value is not available
///    yet, nothing is propagated;
///  - otherwise `before` is met with `after` and `op` is recorded as the
///    access for every underlying value it affects.
/// Always returns success().
LogicalResult NextAccessAnalysis::visitOperation(Operation *op,
                                                 const NextAccess &after,
                                                 NextAccess *before) {
  LDBG() << "visitOperation: "
         << OpWithFlags(op, OpPrintingFlags().skipRegions());
  LDBG() << " after state: " << after;
  LDBG() << " before state: " << *before;

  // Without a memory effect interface nothing can be said about the next
  // access, so conservatively fall back to the exit state.
  auto memory = dyn_cast<MemoryEffectOpInterface>(op);
  if (!memory) {
    LDBG() << " No memory effect interface, setting to exit state";
    setToExitState(before);
    return success();
  }

  SmallVector<MemoryEffects::EffectInstance> effects;
  memory.getEffects(effects);
  LDBG() << " Found " << effects.size() << " memory effects";

  // First make sure that all underlying values are already known. Otherwise
  // do not propagate and stay in the "undefined" state: propagating values
  // that may be overwritten later could be problematic for convergence, which
  // relies on the monotonicity of lattice updates.
  SmallVector<Value> underlyingValues;
  underlyingValues.reserve(effects.size());
  for (const MemoryEffects::EffectInstance &effect : effects) {
    Value value = effect.getValue();

    // An effect that is not attached to a value is handled conservatively:
    // nothing can be assumed about the next access.
    if (!value) {
      LDBG() << " Effect has unspecified value, setting to exit state";
      setToExitState(before);
      return success();
    }

    std::optional<Value> underlyingValue =
        UnderlyingValueAnalysis::getMostUnderlyingValue(
            value, [&](Value v) {
              return getOrCreateFor<UnderlyingValueLattice>(
                  getProgramPointBefore(op), v);
            });

    // The underlying value has not been computed yet; do not propagate.
    if (!underlyingValue) {
      LDBG() << " Underlying value not known for " << value
             << ", skipping propagation";
      return success();
    }

    LDBG() << " Found underlying value " << *underlyingValue << " for "
           << value;
    underlyingValues.push_back(*underlyingValue);
  }

  // All underlying values are known: update the state.
  LDBG() << " All underlying values known, updating state";
  ChangeResult result = before->meet(after);

  // Every effect contributed exactly one entry to `underlyingValues` (the
  // loop above returns early otherwise), so iterating the values alone is
  // sufficient here.
  for (Value underlying : underlyingValues) {
    // A null value means the underlying value is known to be unknown; move to
    // the fixpoint.
    if (!underlying) {
      LDBG() << " Underlying value is unknown, setting to exit state";
      setToExitState(before);
      return success();
    }

    LDBG() << " Setting next access for value " << underlying
           << " to operation "
           << OpWithFlags(op, OpPrintingFlags().skipRegions());
    result |= before->set(underlying, op);
  }

  LDBG() << " Final result: "
         << (result == ChangeResult::Change ? "changed" : "no change");
  propagateIfChanged(before, result);
  return success();
}
/// Unions the lattice anchors before and after `op` when the operation is
/// memory effect free; operations with memory effects keep separate anchors.
void NextAccessAnalysis::buildOperationEquivalentLatticeAnchor(Operation *op) {
  LDBG() << "buildOperationEquivalentLatticeAnchor: "
         << OpWithFlags(op, OpPrintingFlags().skipRegions());

  if (!isMemoryEffectFree(op)) {
    LDBG() << " Operation has memory effects, not unioning lattice anchors";
    return;
  }

  LDBG() << " Operation is memory effect free, unioning lattice anchors";
  unionLatticeAnchors<NextAccess>(getProgramPointBefore(op),
                                  getProgramPointAfter(op));
}
/// Transfer function applied at call operations.
///
/// Three cases are distinguished:
///  - external callee while `assumeFuncReads` is set: the call is recorded as
///    the access of the underlying value of every argument operand, provided
///    all of those underlying values are available;
///  - a TestCallAndStoreOp whose store position matches `action`: the call is
///    visited like a regular operation;
///  - anything else: the default transfer logic of the base class is used.
void NextAccessAnalysis::visitCallControlFlowTransfer(
    CallOpInterface call, CallControlFlowAction action, const NextAccess &after,
    NextAccess *before) {
  LDBG() << "visitCallControlFlowTransfer: "
         << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());

  const char *actionName = "ExitCallee";
  if (action == CallControlFlowAction::ExternalCallee)
    actionName = "ExternalCallee";
  else if (action == CallControlFlowAction::EnterCallee)
    actionName = "EnterCallee";
  LDBG() << " action: " << actionName;
  LDBG() << " assumeFuncReads: " << assumeFuncReads;

  const bool treatAsReadOfArgs =
      action == CallControlFlowAction::ExternalCallee && assumeFuncReads;
  if (treatAsReadOfArgs) {
    LDBG() << " Handling external callee with assumed function reads";

    // Resolve every argument first; give up without propagating as soon as
    // one underlying value is not available.
    SmallVector<Value> resolvedArgs;
    resolvedArgs.reserve(call->getNumOperands());
    for (Value arg : call.getArgOperands()) {
      std::optional<Value> resolved =
          UnderlyingValueAnalysis::getMostUnderlyingValue(
              arg, [&](Value v) {
                return getOrCreateFor<UnderlyingValueLattice>(
                    getProgramPointBefore(call.getOperation()), v);
              });
      if (!resolved) {
        LDBG() << " Underlying value not known for operand " << arg
               << ", returning";
        return;
      }
      LDBG() << " Found underlying value " << *resolved << " for operand "
             << arg;
      resolvedArgs.push_back(*resolved);
    }

    LDBG() << " Setting next access for " << resolvedArgs.size()
           << " operands";
    ChangeResult result = before->meet(after);
    for (Value resolved : resolvedArgs) {
      LDBG() << " Setting next access for operand " << resolved << " to call "
             << call;
      result |= before->set(resolved, call);
    }
    LDBG() << " Call control flow result: "
           << (result == ChangeResult::Change ? "changed" : "no change");
    propagateIfChanged(before, result);
    return;
  }

  // The test op acts as a store either when entering the callee (store before
  // the call) or when leaving it (store after the call).
  bool visitAsStore = false;
  if (auto callAndStore =
          dyn_cast<::test::TestCallAndStoreOp>(call.getOperation())) {
    const bool storeBeforeCall = callAndStore.getStoreBeforeCall();
    visitAsStore =
        (action == CallControlFlowAction::EnterCallee && storeBeforeCall) ||
        (action == CallControlFlowAction::ExitCallee && !storeBeforeCall);
  }

  if (visitAsStore) {
    LDBG() << " Handling TestCallAndStoreOp with special logic";
    (void)visitOperation(call, after, before);
    return;
  }

  LDBG() << " Using default call control flow transfer logic";
  AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer(
      call, action, after, before);
}
/// Transfer function applied on control-flow edges of region-branching
/// operations.
///
/// A TestStoreWithARegion op is visited like a regular operation on the edge
/// that matches its store position: the edge leaving the parent when the store
/// happens before the region, the edge returning to the parent otherwise. On
/// every other edge `before` is simply met with `after`.
void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
    RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
    RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) {
  LDBG() << "visitRegionBranchControlFlowTransfer: "
         << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions());
  LDBG() << " regionFrom: " << (regionFrom.isParent() ? "parent" : "region");
  LDBG() << " regionTo: " << (regionTo.isParent() ? "parent" : "region");

  bool visitAsStore = false;
  if (auto storeWithRegion =
          dyn_cast<::test::TestStoreWithARegion>(branch.getOperation())) {
    const bool storeBeforeRegion = storeWithRegion.getStoreBeforeRegion();
    visitAsStore = (regionTo.isParent() && !storeBeforeRegion) ||
                   (regionFrom.isParent() && storeBeforeRegion);
  }

  if (visitAsStore) {
    LDBG() << " Handling TestStoreWithARegion with special logic";
    (void)visitOperation(branch, after, before);
    return;
  }

  LDBG() << " Using default region branch control flow transfer logic";
  propagateIfChanged(before, before->meet(after));
}
namespace {
/// Test pass that runs the next-access analysis and attaches its results as
/// attributes to every operation carrying a `name` string attribute.
struct TestNextAccessPass
    : public PassWrapper<TestNextAccessPass, OperationPass<>> {
  TestNextAccessPass() = default;

  /// Copies the pass together with the current values of its options.
  TestNextAccessPass(const TestNextAccessPass &other) : PassWrapper(other) {
    interprocedural = other.interprocedural;
    assumeFuncReads = other.assumeFuncReads;
  }

  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNextAccessPass)

  StringRef getArgument() const override { return "test-next-access"; }

  Option<bool> interprocedural{
      *this, "interprocedural", llvm::cl::init(true),
      llvm::cl::desc("perform interprocedural analysis")};
  Option<bool> assumeFuncReads{
      *this, "assume-func-reads", llvm::cl::init(false),
      llvm::cl::desc(
          "assume external functions have read effect on all arguments")};

  /// Attribute that marks the operations to annotate.
  static constexpr llvm::StringLiteral kTagAttrName = "name";
  /// Attribute receiving the result looked up after the operation.
  static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access";
  /// Attribute receiving the results looked up at region entry blocks.
  static constexpr llvm::StringLiteral kAtEntryPointAttrName =
      "next_at_entry_point";

  /// Builds the attribute describing `nextAccess` for the operands of `op`.
  ///
  /// Returns the string "not computed" when `nextAccess` is null. Otherwise
  /// returns an array with one entry per operand: the string "unknown" when
  /// the underlying value or its access cannot be determined, or an array
  /// naming the accessing operations (their `name` attribute when present,
  /// their printed form otherwise).
  static Attribute makeNextAccessAttribute(Operation *op,
                                           const DataFlowSolver &solver,
                                           const NextAccess *nextAccess) {
    if (!nextAccess)
      return StringAttr::get(op->getContext(), "not computed");

    // An operand whose underlying value could not be computed, or is unknown,
    // is conservatively reported as unknown as well.
    SmallVector<Attribute> perOperand;
    for (Value operand : op->getOperands()) {
      std::optional<Value> resolved =
          UnderlyingValueAnalysis::getMostUnderlyingValue(
              operand, [&](Value v) {
                return solver.lookupState<UnderlyingValueLattice>(v);
              });
      if (!resolved) {
        perOperand.push_back(StringAttr::get(op->getContext(), "unknown"));
        continue;
      }

      const AdjacentAccess *access = nextAccess->getAdjacentAccess(*resolved);
      if (!access || !access->isKnown()) {
        perOperand.push_back(StringAttr::get(op->getContext(), "unknown"));
        continue;
      }

      SmallVector<Attribute> accessors;
      accessors.reserve(access->get().size());
      for (Operation *accessor : access->get()) {
        if (auto accessorTag =
                accessor->getAttrOfType<StringAttr>(kTagAttrName)) {
          accessors.push_back(accessorTag);
        } else {
          // Untagged operations are identified by their textual form.
          std::string text;
          llvm::raw_string_ostream stream(text);
          accessor->print(stream);
          accessors.push_back(StringAttr::get(op->getContext(), stream.str()));
        }
      }
      perOperand.push_back(ArrayAttr::get(op->getContext(), accessors));
    }
    return ArrayAttr::get(op->getContext(), perOperand);
  }

  /// Runs the dataflow solver on the root operation and annotates every tagged
  /// operation with the computed information.
  void runOnOperation() override {
    Operation *root = getOperation();
    LDBG() << "runOnOperation: Starting test-next-access pass on "
           << OpWithFlags(root, OpPrintingFlags().skipRegions());
    LDBG() << " interprocedural: " << interprocedural;
    LDBG() << " assumeFuncReads: " << assumeFuncReads;

    SymbolTableCollection symbolTable;
    auto config = DataFlowConfig().setInterprocedural(interprocedural);
    DataFlowSolver solver(config);
    loadBaselineAnalyses(solver);
    solver.load<NextAccessAnalysis>(symbolTable, assumeFuncReads);
    solver.load<UnderlyingValueAnalysis>();

    LDBG() << " Initializing and running dataflow solver";
    if (failed(solver.initializeAndRun(root))) {
      emitError(root->getLoc(), "dataflow solver failed");
      signalPassFailure();
      return;
    }
    LDBG() << " Dataflow solver completed successfully";

    LDBG() << " Walking operations to set next access attributes";
    root->walk([&](Operation *tagged) {
      // Only operations carrying the tag attribute are annotated.
      if (!tagged->getAttrOfType<StringAttr>(kTagAttrName))
        return;

      LDBG() << " Processing tagged operation: "
             << OpWithFlags(tagged, OpPrintingFlags().skipRegions());

      const NextAccess *stateAfter =
          solver.lookupState<NextAccess>(solver.getProgramPointAfter(tagged));
      tagged->setAttr(kNextAccessAttrName,
                      makeNextAccessAttribute(tagged, solver, stateAfter));

      // Region-branching operations additionally report the state at the
      // start of each non-empty region reachable from the parent.
      auto iface = dyn_cast<RegionBranchOpInterface>(tagged);
      if (!iface)
        return;

      SmallVector<RegionSuccessor> successors;
      iface.getSuccessorRegions(RegionBranchPoint::parent(), successors);

      SmallVector<Attribute> atEntryPoints;
      for (const RegionSuccessor &successor : successors) {
        const bool hasEntryBlock =
            successor.getSuccessor() && !successor.getSuccessor()->empty();
        if (!hasEntryBlock)
          continue;

        Block &entryBlock = successor.getSuccessor()->front();
        ProgramPoint *entryPoint = solver.getProgramPointBefore(&entryBlock);
        atEntryPoints.push_back(makeNextAccessAttribute(
            tagged, solver, solver.lookupState<NextAccess>(entryPoint)));
      }
      tagged->setAttr(kAtEntryPointAttrName,
                      ArrayAttr::get(tagged->getContext(), atEntryPoints));
    });
  }
};
} // namespace
namespace mlir::test {
/// Registers TestNextAccessPass with the pass registry.
void registerTestNextAccessPass() {
  PassRegistration<TestNextAccessPass>();
}
} // namespace mlir::test
|