aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/IR/SymbolTable.h44
-rw-r--r--mlir/lib/IR/SymbolTable.cpp286
-rw-r--r--mlir/test/lib/IR/TestSymbolUses.cpp10
3 files changed, 205 insertions, 135 deletions
diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index bee27ad..e9d0b32 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -137,47 +137,47 @@ public:
/// Get an iterator range for all of the uses, for any symbol, that are nested
/// within the given operation 'from'. This does not traverse into any nested
- /// symbol tables, and will also only return uses on 'from' if it does not
- /// also define a symbol table. This is because we treat the region as the
- /// boundary of the symbol table, and not the op itself. This function returns
- /// None if there are any unknown operations that may potentially be symbol
- /// tables.
+ /// symbol tables. This function returns None if there are any unknown
+ /// operations that may potentially be symbol tables.
static Optional<UseRange> getSymbolUses(Operation *from);
+ static Optional<UseRange> getSymbolUses(Region *from);
/// Get all of the uses of the given symbol that are nested within the given
- /// operation 'from'. This does not traverse into any nested symbol tables,
- /// and will also only return uses on 'from' if it does not also define a
- /// symbol table. This is because we treat the region as the boundary of the
- /// symbol table, and not the op itself. This function returns None if there
- /// are any unknown operations that may potentially be symbol tables.
+ /// operation 'from'. This does not traverse into any nested symbol tables.
+ /// This function returns None if there are any unknown operations that may
+ /// potentially be symbol tables.
static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from);
static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from);
+ static Optional<UseRange> getSymbolUses(StringRef symbol, Region *from);
+ static Optional<UseRange> getSymbolUses(Operation *symbol, Region *from);
/// Return if the given symbol is known to have no uses that are nested
/// within the given operation 'from'. This does not traverse into any nested
- /// symbol tables, and will also only count uses on 'from' if it does not also
- /// define a symbol table. This is because we treat the region as the boundary
- /// of the symbol table, and not the op itself. This function will also return
- /// false if there are any unknown operations that may potentially be symbol
- /// tables. This doesn't necessarily mean that there are no uses, we just
- /// can't conservatively prove it.
+ /// symbol tables. This function will also return false if there are any
+ /// unknown operations that may potentially be symbol tables. This doesn't
+ /// necessarily mean that there are no uses, we just can't conservatively
+ /// prove it.
static bool symbolKnownUseEmpty(StringRef symbol, Operation *from);
static bool symbolKnownUseEmpty(Operation *symbol, Operation *from);
+ static bool symbolKnownUseEmpty(StringRef symbol, Region *from);
+ static bool symbolKnownUseEmpty(Operation *symbol, Region *from);
/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
/// provided symbol 'newSymbol' that are nested within the given operation
- /// 'from'. This does not traverse into any nested symbol tables, and will
- /// also only replace uses on 'from' if it does not also define a symbol
- /// table. This is because we treat the region as the boundary of the symbol
- /// table, and not the op itself. If there are any unknown operations that may
- /// potentially be symbol tables, no uses are replaced and failure is
- /// returned.
+ /// 'from'. This does not traverse into any nested symbol tables. If there are
+ /// any unknown operations that may potentially be symbol tables, no uses are
+ /// replaced and failure is returned.
LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(StringRef oldSymbol,
StringRef newSymbol,
Operation *from);
LLVM_NODISCARD static LogicalResult
replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName,
Operation *from);
+ LLVM_NODISCARD static LogicalResult
+ replaceAllSymbolUses(StringRef oldSymbol, StringRef newSymbol, Region *from);
+ LLVM_NODISCARD static LogicalResult
+ replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName,
+ Region *from);
private:
Operation *symbolTableOp;
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 4265ace..bdeff68 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -401,35 +401,19 @@ static WalkResult walkSymbolRefs(
}
/// Walk all of the uses, for any symbol, that are nested within the given
-/// operation 'from', invoking the provided callback for each. This does not
-/// traverse into any nested symbol tables, and will also only return uses on
-/// 'from' if it does not also define a symbol table.
+/// regions, invoking the provided callback for each. This does not traverse
+/// into any nested symbol tables.
static Optional<WalkResult> walkSymbolUses(
- Operation *from,
+ MutableArrayRef<Region> regions,
function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
- // If from is not a symbol table, check for uses. A symbol table defines a new
- // scope, so we can't walk the attributes from the symbol table op.
- if (!from->hasTrait<OpTrait::SymbolTable>()) {
- if (walkSymbolRefs(from, callback).wasInterrupted())
- return WalkResult::interrupt();
- }
-
- SmallVector<Region *, 1> worklist;
- worklist.reserve(from->getNumRegions());
- for (Region &region : from->getRegions())
- worklist.push_back(&region);
-
+ SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
while (!worklist.empty()) {
- Region *region = worklist.pop_back_val();
- for (Block &block : *region) {
+ for (Block &block : *worklist.pop_back_val()) {
for (Operation &op : block) {
if (walkSymbolRefs(&op, callback).wasInterrupted())
return WalkResult::interrupt();
- // If this operation has regions, and it as well as its dialect aren't
- // registered then conservatively fail. The operation may define a
- // symbol table, so we can't opaquely know if we should traverse to find
- // nested uses.
+ // Check that this isn't a potentially unknown symbol table.
if (isPotentiallyUnknownSymbolTable(&op))
return llvm::None;
@@ -444,16 +428,74 @@ static Optional<WalkResult> walkSymbolUses(
}
return WalkResult::advance();
}
+/// Walk all of the uses, for any symbol, that are nested within the given
+/// operaion 'from', invoking the provided callback for each. This does not
+/// traverse into any nested symbol tables.
+static Optional<WalkResult> walkSymbolUses(
+ Operation *from,
+ function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
+ // If this operation has regions, and it, as well as its dialect, isn't
+ // registered then conservatively fail. The operation may define a
+ // symbol table, so we can't opaquely know if we should traverse to find
+ // nested uses.
+ if (isPotentiallyUnknownSymbolTable(from))
+ return llvm::None;
+
+ // Walk the uses on this operation.
+ if (walkSymbolRefs(from, callback).wasInterrupted())
+ return WalkResult::interrupt();
+
+ // Only recurse if this operation is not a symbol table. A symbol table
+ // defines a new scope, so we can't walk the attributes from within the symbol
+ // table op.
+ if (!from->hasTrait<OpTrait::SymbolTable>())
+ return walkSymbolUses(from->getRegions(), callback);
+ return WalkResult::advance();
+}
+
+namespace {
+/// This class represents a single symbol scope. A symbol scope represents the
+/// set of operations nested within a symbol table that may reference symbols
+/// within that table. A symbol scope does not contain the symbol table
+/// operation itself, just its contained operations. A scope ends at leaf
+/// operations or another symbol table operation.
+struct SymbolScope {
+ /// Walk the symbol uses within this scope, invoking the given callback.
+ /// This variant is used when the callback type matches that expected by
+ /// 'walkSymbolUses'.
+ template <typename CallbackT,
+ typename std::enable_if_t<!std::is_same<
+ typename FunctionTraits<CallbackT>::result_t, void>::value> * =
+ nullptr>
+ Optional<WalkResult> walk(CallbackT cback) {
+ if (Region *region = limit.dyn_cast<Region *>())
+ return walkSymbolUses(*region, cback);
+ return walkSymbolUses(limit.get<Operation *>(), cback);
+ }
+ /// This variant is used when the callback type matches a stripped down type:
+ /// void(SymbolTable::SymbolUse use)
+ template <typename CallbackT,
+ typename std::enable_if_t<std::is_same<
+ typename FunctionTraits<CallbackT>::result_t, void>::value> * =
+ nullptr>
+ Optional<WalkResult> walk(CallbackT cback) {
+ return walk([=](SymbolTable::SymbolUse use, ArrayRef<int>) {
+ return cback(use), WalkResult::advance();
+ });
+ }
-/// Walks all of the symbol scopes from 'symbol' to (inclusive) 'limit' invoking
-/// the provided callback at each one with a properly scoped reference to
-/// 'symbol'. The callback takes as parameters the symbol reference at the
-/// current scope as well as the top-level operation representing the top of
-/// that scope.
-static Optional<WalkResult> walkSymbolScopes(
- Operation *symbol, Operation *limit,
- function_ref<Optional<WalkResult>(SymbolRefAttr, Operation *)> callback) {
- StringRef symbolName = SymbolTable::getSymbolName(symbol);
+ /// The representation of the symbol within this scope.
+ SymbolRefAttr symbol;
+
+ /// The IR unit representing this scope.
+ llvm::PointerUnion<Operation *, Region *> limit;
+};
+} // end anonymous namespace
+
+/// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
+static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
+ Operation *limit) {
+ StringRef symName = SymbolTable::getSymbolName(symbol);
assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
// Compute the ancestors of 'limit'.
@@ -466,10 +508,10 @@ static Optional<WalkResult> walkSymbolScopes(
if (limitAncestor == symbol) {
// Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
// doesn't support parent references.
- if (SymbolTable::getNearestSymbolTable(limit) != symbol->getParentOp())
- return WalkResult::advance();
- return callback(SymbolRefAttr::get(symbolName, symbol->getContext()),
- limit);
+ if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
+ symbol->getParentOp())
+ return {{SymbolRefAttr::get(symName, symbol->getContext()), limit}};
+ return {};
}
limitAncestors.insert(limitAncestor);
@@ -486,36 +528,45 @@ static Optional<WalkResult> walkSymbolScopes(
// Compute the set of valid nested references for 'symbol' as far up to the
// common ancestor as possible.
SmallVector<SymbolRefAttr, 2> references;
- bool collectedAllReferences = succeeded(collectValidReferencesFor(
- symbol, symbolName, commonAncestor, references));
+ bool collectedAllReferences = succeeded(
+ collectValidReferencesFor(symbol, symName, commonAncestor, references));
// Handle the case where the common ancestor is 'limit'.
if (commonAncestor == limit) {
+ SmallVector<SymbolScope, 2> scopes;
+
// Walk each of the ancestors of 'symbol', calling the compute function for
// each one.
Operation *limitIt = symbol->getParentOp();
for (size_t i = 0, e = references.size(); i != e;
++i, limitIt = limitIt->getParentOp()) {
- Optional<WalkResult> callbackResult = callback(references[i], limitIt);
- if (callbackResult != WalkResult::advance())
- return callbackResult;
+ assert(limitIt->hasTrait<OpTrait::SymbolTable>());
+ scopes.push_back({references[i], &limitIt->getRegion(0)});
}
- return WalkResult::advance();
+ return scopes;
}
// Otherwise, we just need the symbol reference for 'symbol' that will be
// used within 'limit'. This is the last reference in the list we computed
// above if we were able to collect all references.
if (!collectedAllReferences)
- return WalkResult::advance();
- return callback(references.back(), limit);
+ return {};
+ return {{references.back(), limit}};
}
+static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
+ Region *limit) {
+ auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
-/// Walk the symbol scopes defined by 'limit' invoking the provided callback.
-static Optional<WalkResult> walkSymbolScopes(
- StringRef symbol, Operation *limit,
- function_ref<Optional<WalkResult>(SymbolRefAttr, Operation *)> callback) {
- return callback(SymbolRefAttr::get(symbol, limit->getContext()), limit);
+ // If we collected some scopes to walk, make sure to constrain the one for
+ // limit to the specific region requested.
+ if (!scopes.empty())
+ scopes.back().limit = limit;
+ return scopes;
+}
+template <typename IRUnit>
+static SmallVector<SymbolScope, 1> collectSymbolScopes(StringRef symbol,
+ IRUnit *limit) {
+ return {{SymbolRefAttr::get(symbol, limit->getContext()), limit}};
}
/// Returns true if the given reference 'SubRef' is a sub reference of the
@@ -539,6 +590,18 @@ static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
//===----------------------------------------------------------------------===//
// SymbolTable::getSymbolUses
+/// The implementation of SymbolTable::getSymbolUses below.
+template <typename FromT>
+static Optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
+ std::vector<SymbolTable::SymbolUse> uses;
+ auto walkFn = [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
+ uses.push_back(symbolUse);
+ return WalkResult::advance();
+ };
+ auto result = walkSymbolUses(from, walkFn);
+ return result ? Optional<SymbolTable::UseRange>(std::move(uses)) : llvm::None;
+}
+
/// Get an iterator range for all of the uses, for any symbol, that are nested
/// within the given operation 'from'. This does not traverse into any nested
/// symbol tables, and will also only return uses on 'from' if it does not
@@ -547,43 +610,34 @@ static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
/// None if there are any unknown operations that may potentially be symbol
/// tables.
auto SymbolTable::getSymbolUses(Operation *from) -> Optional<UseRange> {
- std::vector<SymbolUse> uses;
- auto walkFn = [&](SymbolUse symbolUse, ArrayRef<int>) {
- uses.push_back(symbolUse);
- return WalkResult::advance();
- };
- auto result = walkSymbolUses(from, walkFn);
- return result ? Optional<UseRange>(std::move(uses)) : Optional<UseRange>();
+ return getSymbolUsesImpl(from);
+}
+auto SymbolTable::getSymbolUses(Region *from) -> Optional<UseRange> {
+ return getSymbolUsesImpl(MutableArrayRef<Region>(*from));
}
//===----------------------------------------------------------------------===//
// SymbolTable::getSymbolUses
/// The implementation of SymbolTable::getSymbolUses below.
-template <typename SymbolT>
+template <typename SymbolT, typename IRUnitT>
static Optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
- Operation *limit) {
+ IRUnitT *limit) {
std::vector<SymbolTable::SymbolUse> uses;
- auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) {
- return walkSymbolUses(
- from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
- if (isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef()))
+ for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
+ if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
+ if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
uses.push_back(symbolUse);
- return WalkResult::advance();
- });
- };
- if (walkSymbolScopes(symbol, limit, walkFn))
- return SymbolTable::UseRange(std::move(uses));
- return llvm::None;
+ }))
+ return llvm::None;
+ }
+ return SymbolTable::UseRange(std::move(uses));
}
/// Get all of the uses of the given symbol that are nested within the given
/// operation 'from', invoking the provided callback for each. This does not
-/// traverse into any nested symbol tables, and will also only return uses on
-/// 'from' if it does not also define a symbol table. This is because we treat
-/// the region as the boundary of the symbol table, and not the op itself. This
-/// function returns None if there are any unknown operations that may
-/// potentially be symbol tables.
+/// traverse into any nested symbol tables. This function returns None if there
+/// are any unknown operations that may potentially be symbol tables.
auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from)
-> Optional<UseRange> {
return getSymbolUsesImpl(symbol, from);
@@ -592,37 +646,49 @@ auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from)
-> Optional<UseRange> {
return getSymbolUsesImpl(symbol, from);
}
+auto SymbolTable::getSymbolUses(StringRef symbol, Region *from)
+ -> Optional<UseRange> {
+ return getSymbolUsesImpl(symbol, from);
+}
+auto SymbolTable::getSymbolUses(Operation *symbol, Region *from)
+ -> Optional<UseRange> {
+ return getSymbolUsesImpl(symbol, from);
+}
//===----------------------------------------------------------------------===//
// SymbolTable::symbolKnownUseEmpty
/// The implementation of SymbolTable::symbolKnownUseEmpty below.
-template <typename SymbolT>
-static bool symbolKnownUseEmptyImpl(SymbolT symbol, Operation *limit) {
- // Walk all of the symbol uses looking for a reference to 'symbol'.
- auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) {
- return walkSymbolUses(
- from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
- return isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef())
+template <typename SymbolT, typename IRUnitT>
+static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
+ for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
+ // Walk all of the symbol uses looking for a reference to 'symbol'.
+ if (scope.walk([&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
+ return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
? WalkResult::interrupt()
: WalkResult::advance();
- });
- };
- return walkSymbolScopes(symbol, limit, walkFn) == WalkResult::advance();
+ }) != WalkResult::advance())
+ return false;
+ }
+ return true;
}
/// Return if the given symbol is known to have no uses that are nested within
/// the given operation 'from'. This does not traverse into any nested symbol
-/// tables, and will also only count uses on 'from' if it does not also define
-/// a symbol table. This is because we treat the region as the boundary of the
-/// symbol table, and not the op itself. This function will also return false if
-/// there are any unknown operations that may potentially be symbol tables.
+/// tables. This function will also return false if there are any unknown
+/// operations that may potentially be symbol tables.
bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Operation *from) {
return symbolKnownUseEmptyImpl(symbol, from);
}
bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) {
return symbolKnownUseEmptyImpl(symbol, from);
}
+bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Region *from) {
+ return symbolKnownUseEmptyImpl(symbol, from);
+}
+bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
+ return symbolKnownUseEmptyImpl(symbol, from);
+}
//===----------------------------------------------------------------------===//
// SymbolTable::replaceAllSymbolUses
@@ -685,10 +751,9 @@ static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
}
/// The implementation of SymbolTable::replaceAllSymbolUses below.
-template <typename SymbolT>
-static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol,
- StringRef newSymbol,
- Operation *limit) {
+template <typename SymbolT, typename IRUnitT>
+static LogicalResult
+replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
// A collection of operations along with their new attribute dictionary.
std::vector<std::pair<Operation *, DictionaryAttr>> updatedAttrDicts;
@@ -710,26 +775,26 @@ static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol,
// Generate a new attribute to replace the given attribute.
MLIRContext *ctx = limit->getContext();
FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx);
- auto scopeWalkFn = [&](SymbolRefAttr oldAttr,
- Operation *from) -> Optional<WalkResult> {
- SymbolRefAttr newAttr = generateNewRefAttr(oldAttr, newLeafAttr);
+ for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
+ SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
auto walkFn = [&](SymbolTable::SymbolUse symbolUse,
ArrayRef<int> accessChain) {
SymbolRefAttr useRef = symbolUse.getSymbolRef();
- if (!isReferencePrefixOf(oldAttr, useRef))
+ if (!isReferencePrefixOf(scope.symbol, useRef))
return WalkResult::advance();
// If we have a valid match, check to see if this is a proper
// subreference. If it is, then we will need to generate a different new
// attribute specifically for this use.
SymbolRefAttr replacementRef = newAttr;
- if (useRef != oldAttr) {
- if (oldAttr.isa<FlatSymbolRefAttr>()) {
+ if (useRef != scope.symbol) {
+ if (scope.symbol.isa<FlatSymbolRefAttr>()) {
replacementRef =
SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx);
} else {
auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences());
- nestedRefs[oldAttr.getNestedReferences().size() - 1] = newLeafAttr;
+ nestedRefs[scope.symbol.getNestedReferences().size() - 1] =
+ newLeafAttr;
replacementRef =
SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx);
}
@@ -748,18 +813,15 @@ static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol,
accessChains.push_back({llvm::to_vector<1>(accessChain), replacementRef});
return WalkResult::advance();
};
- if (!walkSymbolUses(from, walkFn))
- return llvm::None;
+ if (!scope.walk(walkFn))
+ return failure();
// Check to see if we have a dangling op that needs to be processed.
if (curOp) {
updatedAttrDicts.push_back({curOp, generateNewAttrDict()});
curOp = nullptr;
}
- return WalkResult::advance();
- };
- if (!walkSymbolScopes(symbol, limit, scopeWalkFn))
- return failure();
+ }
// Update the attribute dictionaries as necessary.
for (auto &it : updatedAttrDicts)
@@ -769,11 +831,9 @@ static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol,
/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
/// provided symbol 'newSymbol' that are nested within the given operation
-/// 'from'. This does not traverse into any nested symbol tables, and will
-/// also only replace uses on 'from' if it does not also define a symbol
-/// table. This is because we treat the region as the boundary of the symbol
-/// table, and not the op itself. If there are any unknown operations that may
-/// potentially be symbol tables, no uses are replaced and failure is returned.
+/// 'from'. This does not traverse into any nested symbol tables. If there are
+/// any unknown operations that may potentially be symbol tables, no uses are
+/// replaced and failure is returned.
LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol,
StringRef newSymbol,
Operation *from) {
@@ -784,3 +844,13 @@ LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
Operation *from) {
return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
}
+LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol,
+ StringRef newSymbol,
+ Region *from) {
+ return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
+}
+LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
+ StringRef newSymbol,
+ Region *from) {
+ return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
+}
diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp
index 89fd5ef..7757da5 100644
--- a/mlir/test/lib/IR/TestSymbolUses.cpp
+++ b/mlir/test/lib/IR/TestSymbolUses.cpp
@@ -16,7 +16,7 @@ namespace {
/// This is a symbol test pass that tests the symbol uselist functionality
/// provided by the symbol table along with erasing from the symbol table.
struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
- WalkResult operateOnSymbol(Operation *symbol, Operation *module,
+ WalkResult operateOnSymbol(Operation *symbol, ModuleOp module,
SmallVectorImpl<FuncOp> &deadFunctions) {
// Test computing uses on a non symboltable op.
Optional<SymbolTable::UseRange> symbolUses =
@@ -34,7 +34,7 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
<< " nested references";
// Test the functionality of symbolKnownUseEmpty.
- if (SymbolTable::symbolKnownUseEmpty(symbol, module)) {
+ if (SymbolTable::symbolKnownUseEmpty(symbol, &module.getBodyRegion())) {
FuncOp funcSymbol = dyn_cast<FuncOp>(symbol);
if (funcSymbol && funcSymbol.isExternal())
deadFunctions.push_back(funcSymbol);
@@ -44,7 +44,7 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
}
// Test the functionality of getSymbolUses.
- symbolUses = SymbolTable::getSymbolUses(symbol, module);
+ symbolUses = SymbolTable::getSymbolUses(symbol, &module.getBodyRegion());
assert(symbolUses.hasValue() && "expected no unknown operations");
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
// Check that we can resolve back to our symbol.
@@ -70,10 +70,10 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
return WalkResult::advance();
});
+ SymbolTable table(module);
for (Operation *op : deadFunctions) {
// In order to test the SymbolTable::erase method, also erase completely
// useless functions.
- SymbolTable table(module);
auto name = SymbolTable::getSymbolName(op);
assert(table.lookup(name) && "expected no unknown operations");
table.erase(op);
@@ -96,7 +96,7 @@ struct SymbolReplacementPass : public ModulePass<SymbolReplacementPass> {
if (!newName)
return;
if (succeeded(SymbolTable::replaceAllSymbolUses(
- nestedOp, newName.getValue(), module)))
+ nestedOp, newName.getValue(), &module.getBodyRegion())))
SymbolTable::setSymbolName(nestedOp, newName.getValue());
});
}