diff options
Diffstat (limited to 'mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp')
-rw-r--r-- | mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp | 27 |
1 files changed, 16 insertions, 11 deletions
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 349841f..1eb27e4 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -480,18 +480,21 @@ bool AnalysisState::isValueRead(Value value) const { return false; } -// Starting from `value`, follow the use-def chain in reverse, always selecting -// the aliasing OpOperands. Find and return Values for which `condition` -// evaluates to true. OpOperands of such matching Values are not traversed any -// further, the visited aliasing opOperands will be preserved through -// `visitedOpOperands`. +// Starting from `opOperand`, follow the use-def chain in reverse, always +// selecting the aliasing OpOperands. Find and return Values for which +// `condition` evaluates to true. Uses of such matching Values are not +// traversed any further, the visited aliasing opOperands will be preserved +// through `visitedOpOperands`. llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain( - Value value, llvm::function_ref<bool(Value)> condition, + OpOperand *opOperand, llvm::function_ref<bool(Value)> condition, TraversalConfig config, llvm::DenseSet<OpOperand *> *visitedOpOperands) const { llvm::DenseSet<Value> visited; llvm::SetVector<Value> result, workingSet; - workingSet.insert(value); + workingSet.insert(opOperand->get()); + + if (visitedOpOperands) + visitedOpOperands->insert(opOperand); while (!workingSet.empty()) { Value value = workingSet.pop_back_val(); @@ -563,12 +566,14 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain( return result; } -// Find the values that define the contents of the given value. -llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const { +// Find the values that define the contents of the given operand's value. +llvm::SetVector<Value> +AnalysisState::findDefinitions(OpOperand *opOperand) const { TraversalConfig config; config.alwaysIncludeLeaves = false; return findValueInReverseUseDefChain( - value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config); + opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, + config); } AnalysisState::AnalysisState(const BufferizationOptions &options) @@ -892,7 +897,7 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite( config.alwaysIncludeLeaves = false; for (AliasingOpOperand alias : opOperands) { if (!state - .findValueInReverseUseDefChain(alias.opOperand->get(), + .findValueInReverseUseDefChain(alias.opOperand, isMemoryWriteInsideOp, config) .empty()) return true; |