aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
blob: 0b2e080e52b751d491ab6c4801a3d436659ae987 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
//======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"

#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/SetOperations.h"

#include <utility>

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// BufferViewFlowAnalysis
//===----------------------------------------------------------------------===//

/// Creates the analysis for `op` and everything nested inside it.
BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) {
  // All dependency, reverse-dependency and terminal information is computed
  // eagerly here; later queries read the state produced by `build`.
  build(op);
}

/// Returns `value` together with every value reachable from it by repeatedly
/// following the edges stored in `map`. An explicit worklist is used, and
/// values that were already visited are not expanded again, so cycles in
/// `map` terminate.
static BufferViewFlowAnalysis::ValueSetT
resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value) {
  BufferViewFlowAnalysis::ValueSetT visited;
  SmallVector<Value, 8> worklist{value};
  while (!worklist.empty()) {
    Value next = worklist.pop_back_val();
    // Already expanded: nothing new to discover from here.
    if (!visited.insert(next).second)
      continue;
    auto edges = map.find(next);
    if (edges == map.end())
      continue;
    // Enqueue all immediate neighbors (same order as iterating the set).
    worklist.append(edges->second.begin(), edges->second.end());
  }
  return visited;
}

/// Returns all values that are reachable from `rootValue` along the forward
/// dependency edges, i.e. the immediate and indirect dependents (such as views
/// or forwarded block/region arguments) that may alias it. The result always
/// contains `rootValue` itself.
BufferViewFlowAnalysis::ValueSetT
BufferViewFlowAnalysis::resolve(Value rootValue) const {
  const ValueMapT &forwardEdges = dependencies;
  return resolveValues(forwardEdges, rootValue);
}

/// Returns all values from which `rootValue` may (transitively) originate,
/// found by following the dependency edges backwards. The result always
/// contains `rootValue` itself.
BufferViewFlowAnalysis::ValueSetT
BufferViewFlowAnalysis::resolveReverse(Value rootValue) const {
  const ValueMapT &backwardEdges = reverseDependencies;
  return resolveValues(backwardEdges, rootValue);
}

/// Removes the given values from all alias sets.
void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) {
  for (auto &entry : dependencies)
    llvm::set_subtract(entry.second, aliasValues);
}

/// Replaces all occurrences of `from` in the analysis state with `to`.
///
/// In both the forward and the reverse dependency map:
/// - the set recorded for `from` becomes the set recorded for `to`. It
///   replaces whatever `to` had; if `from` had no entry, `to` gets an empty
///   set.
/// - every set that mentions `from` mentions `to` instead.
///
/// If `from` was marked as a possible terminal buffer, that marker moves to
/// `to`. Renaming a value to itself is a no-op.
void BufferViewFlowAnalysis::rename(Value from, Value to) {
  // Without this guard, the entry for `from` would be erased and the value
  // dropped from every set.
  if (from == to)
    return;

  auto renameInMap = [&](ValueMapT &map) {
    // Detach the set recorded for `from` *before* calling `map[to]`.
    // `operator[]` may insert a new entry and grow the map, which would
    // invalidate any reference previously obtained through `map[from]`.
    ValueSetT fromSet;
    if (auto it = map.find(from); it != map.end()) {
      fromSet = std::move(it->second);
      map.erase(it);
    }
    map[to] = std::move(fromSet);

    // Substitute `to` for `from` wherever `from` is recorded as a neighbor.
    for (auto &[_, values] : map) {
      if (values.contains(from)) {
        values.insert(to);
        values.erase(from);
      }
    }
  };

  renameInMap(dependencies);
  renameInMap(reverseDependencies);

  // Keep `mayBeTerminalBuffer` consistent with the renamed value.
  if (terminals.erase(from))
    terminals.insert(to);
}

/// Populates the analysis state by walking `op` and every operation nested in
/// it:
///  - `dependencies` maps a value to the values that may be derived from it
///    (e.g. a view source to its view result, or a forwarded branch operand to
///    the successor block argument it is passed to).
///  - `reverseDependencies` holds the same edges in the opposite direction.
///  - `terminals` collects memref values that may be terminal buffers, i.e.
///    values whose origin cannot be traced back any further.
///
/// Each visited op is handled by the first of the following cases that
/// applies; later cases are not consulted for that op:
///  1. BufferViewFlowOpInterface: the op registers its own dependencies and
///     decides which of its memref results / region entry-block arguments may
///     be terminal.
///  2. ViewLikeOpInterface: the view result depends on the view source.
///  3. BranchOpInterface: the forwarded successor operands flow into the
///     corresponding successor block arguments.
///  4. RegionBranchOpInterface: entry operands flow into the inputs of the
///     entry successors. The operands of each region's
///     RegionBranchTerminatorOpInterface terminators flow into the inputs of
///     every successor reachable from that region.
///  5. RegionBranchTerminatorOpInterface: nothing to do; covered by case 4.
///  6. CallOpInterface: every result conservatively depends on every operand,
///     and memref results / region entry-block arguments are marked terminal.
///  7. Any other op: memref results and region entry-block arguments are
///     marked terminal; no dependencies are recorded.
void BufferViewFlowAnalysis::build(Operation *op) {
  // Records the edge values[i] -> dependencies[i] in the forward map and the
  // inverse edge in the reverse map. The two ranges are expected to have equal
  // length (`zip_equal`). The parameter shadows the member `dependencies`,
  // hence the explicit `this->` below.
  auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
    for (auto [value, dep] : llvm::zip_equal(values, dependencies)) {
      this->dependencies[value].insert(dep);
      this->reverseDependencies[dep].insert(value);
    }
  };

  // Mark all buffer results and buffer region entry block arguments of the
  // given op as terminals. Values of non-memref type are ignored.
  auto populateTerminalValues = [&](Operation *op) {
    for (Value v : op->getResults())
      if (isa<BaseMemRefType>(v.getType()))
        this->terminals.insert(v);
    for (Region &r : op->getRegions())
      for (BlockArgument v : r.getArguments())
        if (isa<BaseMemRefType>(v.getType()))
          this->terminals.insert(v);
  };

  // Every case below returns `advance`, so all nested ops are visited. The
  // lambda parameter `op` shadows the root op passed to `build`.
  op->walk([&](Operation *op) {
    // Query BufferViewFlowOpInterface. If the op does not implement that
    // interface, try to infer the dependencies from other interfaces that the
    // op may implement.
    if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
      bufferViewFlowOp.populateDependencies(registerDependencies);
      // Unlike `populateTerminalValues`, only the values that the op itself
      // reports as possible terminals are recorded.
      for (Value v : op->getResults())
        if (isa<BaseMemRefType>(v.getType()) &&
            bufferViewFlowOp.mayBeTerminalBuffer(v))
          this->terminals.insert(v);
      for (Region &r : op->getRegions())
        for (BlockArgument v : r.getArguments())
          if (isa<BaseMemRefType>(v.getType()) &&
              bufferViewFlowOp.mayBeTerminalBuffer(v))
            this->terminals.insert(v);
      return WalkResult::advance();
    }

    // Add additional dependencies created by view changes to the alias list.
    // The view result is not a terminal: it can be traced back to its source.
    if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
      registerDependencies(viewInterface.getViewSource(),
                           viewInterface.getViewDest());
      return WalkResult::advance();
    }

    if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
      // Query all branch interfaces to link block argument dependencies.
      Block *parentBlock = branchInterface->getBlock();
      for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
           it != e; ++it) {
        // Query the branch op interface to get the successor operands.
        auto successorOperands =
            branchInterface.getSuccessorOperands(it.getIndex());
        // Build the actual mapping of values to their immediate dependencies.
        // The leading successor block arguments that are produced by the
        // branch op itself (not forwarded from its operands) are skipped.
        registerDependencies(successorOperands.getForwardedOperands(),
                             (*it)->getArguments().drop_front(
                                 successorOperands.getProducedOperandCount()));
      }
      return WalkResult::advance();
    }

    if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
      // Query the RegionBranchOpInterface to find potential successor regions.
      // Extract all entry regions and wire all initial entry successor inputs.
      SmallVector<RegionSuccessor, 2> entrySuccessors;
      regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
                                          entrySuccessors);
      for (RegionSuccessor &entrySuccessor : entrySuccessors) {
        // Wire the entry region's successor arguments with the initial
        // successor inputs.
        registerDependencies(
            regionInterface.getEntrySuccessorOperands(entrySuccessor),
            entrySuccessor.getSuccessorInputs());
      }

      // Wire flow between regions and from region exits.
      for (Region &region : regionInterface->getRegions()) {
        // Iterate over all successor region entries that are reachable from the
        // current region.
        SmallVector<RegionSuccessor, 2> successorRegions;
        regionInterface.getSuccessorRegions(region, successorRegions);
        for (RegionSuccessor &successorRegion : successorRegions) {
          // Iterate over all immediate terminator operations and wire the
          // successor inputs with the successor operands of each terminator.
          // Blocks whose terminator does not implement
          // RegionBranchTerminatorOpInterface contribute nothing here.
          for (Block &block : region)
            if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
                    block.getTerminator()))
              registerDependencies(
                  terminator.getSuccessorOperands(successorRegion),
                  successorRegion.getSuccessorInputs());
        }
      }

      return WalkResult::advance();
    }

    // Region terminators are handled together with RegionBranchOpInterface.
    if (isa<RegionBranchTerminatorOpInterface>(op))
      return WalkResult::advance();

    if (isa<CallOpInterface>(op)) {
      // This is an intra-function analysis. We have no information about other
      // functions. Conservatively assume that each operand may alias with each
      // result. Also mark the results as terminals because the function could
      // return newly allocated buffers. Note that the operand/result pairs are
      // registered regardless of their types.
      populateTerminalValues(op);
      for (Value operand : op->getOperands())
        for (Value result : op->getResults())
          registerDependencies({operand}, {result});
      return WalkResult::advance();
    }

    // We have no information about unknown ops: their buffer results and
    // region entry arguments cannot be traced back any further.
    populateTerminalValues(op);

    return WalkResult::advance();
  });
}

/// Returns "true" if `value` was recorded as a possible terminal buffer while
/// the analysis was built. `value` must have a memref type.
bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
  assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
  const bool recordedAsTerminal = terminals.contains(value);
  return recordedAsTerminal;
}

//===----------------------------------------------------------------------===//
// BufferOriginAnalysis
//===----------------------------------------------------------------------===//

/// Return "true" if `v` is an op result whose defining op reports a
/// `MemoryEffects::Allocate` effect on `v`. Block arguments never qualify.
static bool hasAllocateSideEffect(Value v) {
  Operation *definingOp = v.getDefiningOp();
  return definingOp != nullptr &&
         hasEffect<MemoryEffects::Allocate>(definingOp, v);
}

/// Return "true" if the given value is a function block argument, i.e. an
/// argument of the entry block of the body of an op implementing
/// `FunctionOpInterface`. Arguments of blocks that have no parent op (detached
/// blocks or blocks in a detached region) yield "false".
static bool isFunctionArgument(Value v) {
  auto bbArg = dyn_cast<BlockArgument>(v);
  if (!bbArg)
    return false;
  Block *b = bbArg.getOwner();
  // `getParentOp()` is null for a block that is not attached to an op; a
  // plain `dyn_cast` would assert on the null pointer.
  auto funcOp = llvm::dyn_cast_if_present<FunctionOpInterface>(b->getParentOp());
  if (!funcOp)
    return false;
  return b == &funcOp.getFunctionBody().front();
}

/// Walks the reverse use-def chain of a memref value through
/// ViewLikeOpInterface ops, as long as the current value is the view result
/// of such an op, and returns the first value that is not. If `value` is not
/// a view result, it is returned unchanged.
static Value getViewBase(Value value) {
  Value current = value;
  for (;;) {
    auto viewOp = current.getDefiningOp<ViewLikeOpInterface>();
    // Stop at non-view producers and at results other than the view result.
    if (!viewOp || viewOp.getViewDest() != current)
      return current;
    current = viewOp.getViewSource();
  }
}

BufferOriginAnalysis::BufferOriginAnalysis(Operation *op) : analysis(op) {}

/// Decides whether the buffers `v1` and `v2` originate from the same
/// allocation.
///
/// Returns:
/// - "true" if they are guaranteed to share an allocation,
/// - "false" if they are guaranteed to come from different allocations,
/// - std::nullopt if this cannot be decided.
///
/// Both values must have a memref type.
std::optional<bool> BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) {
  assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
  assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");

  // Views never introduce a new allocation: compare their base buffers.
  v1 = getViewBase(v1);
  v2 = getViewBase(v2);

  // Identical base values trivially share their allocation.
  if (v1 == v2)
    return true;

  // Summary of the terminal buffers a value may originate from. A buffer is
  // "terminal" if the BufferViewFlowAnalysis cannot trace it back any further,
  // for example:
  // - function block arguments
  // - values defined by allocation ops such as "memref.alloc"
  // - values defined by ops that are unknown to the buffer view flow analysis
  // - values that are marked as "terminal" in the `BufferViewFlowOpInterface`
  struct TerminalInfo {
    SmallPtrSet<Value, 16> buffers;
    // Every terminal buffer is a newly allocated buffer.
    bool onlyAllocs = true;
    // Every terminal buffer is a newly allocated buffer or a function entry
    // argument.
    bool onlyAllocsOrFuncArgs = true;
  };

  // Collects the terminal buffers among all values `root` may originate from.
  auto collectTerminals = [&](Value root) {
    TerminalInfo info;
    for (Value origin : analysis.resolveReverse(root)) {
      if (!isa<BaseMemRefType>(origin.getType()) ||
          !analysis.mayBeTerminalBuffer(origin))
        continue;
      info.buffers.insert(origin);
      const bool isAlloc = hasAllocateSideEffect(origin);
      info.onlyAllocs &= isAlloc;
      info.onlyAllocsOrFuncArgs &= isAlloc || isFunctionArgument(origin);
    }
    assert(!info.buffers.empty() && "expected non-empty terminal set");
    return info;
  };

  TerminalInfo t1 = collectTerminals(v1);
  TerminalInfo t2 = collectTerminals(v2);

  // Both values trace back to one and the same single terminal buffer.
  if (llvm::hasSingleElement(t1.buffers) && llvm::hasSingleElement(t2.buffers) &&
      *t1.buffers.begin() == *t2.buffers.begin())
    return true;

  // The terminal sets are not the same singleton. If they still share any
  // buffer, deciding requires further analysis.
  for (Value terminal : t1.buffers) {
    if (t2.buffers.contains(terminal))
      return std::nullopt;
  }

  // The terminal sets are disjoint. Suppose one side originates only from
  // allocations and the other only from allocations or function entry
  // arguments. Then the two values are guaranteed to come from different
  // allocations. (Having only allocations implies having only allocations or
  // function entry arguments, so this covers the all-allocs case on both
  // sides.)
  if ((t1.onlyAllocs && t2.onlyAllocsOrFuncArgs) ||
      (t2.onlyAllocs && t1.onlyAllocsOrFuncArgs))
    return false;

  // Otherwise: We do not know whether `v1` and `v2` originate from the same
  // allocation or not.
  // TODO: Function arguments are currently handled conservatively. We assume
  // that they could be the same allocation.
  // TODO: Terminals other than allocations and function arguments are
  // currently handled conservatively. We assume that they could be the same
  // allocation. E.g., we currently return "nullopt" for values that originate
  // from different "memref.get_global" ops (with different symbols).
  return std::nullopt;
}