aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
-rw-r--r-- mlir/lib/Bytecode/Reader/BytecodeReader.cpp 260
1 files changed, 206 insertions, 54 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 1659437..0ac5fc5 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -27,6 +27,7 @@
#include <cstddef>
#include <cstdint>
+#include <deque>
#include <list>
#include <memory>
#include <numeric>
@@ -830,6 +831,23 @@ namespace {
/// This class provides support for reading attribute and type entries from the
/// bytecode. Attribute and Type entries are read lazily on demand, so we use
/// this reader to manage when to actually parse them from the bytecode.
+///
+/// The parsing of attributes & types is generally recursive, which can lead to
+/// stack overflows for deeply nested structures, so we track a few extra pieces
+/// of information to avoid this:
+///
+/// - `depth`: The current depth while parsing nested attributes. We defer
+/// parsing deeply nested attributes to avoid potential stack overflows. The
+/// deferred parsing is achieved by reporting a failure when parsing a nested
+/// attribute/type and registering the index of the encountered attribute/type
+/// in the deferred parsing worklist. Hence, a failure with deferred entries
+/// is not a hard error; it does require that callers return on the first
+/// failure rather than attempting additional parses.
+/// - `deferredWorklist`: A list of attribute/type indices that we could not
+/// parse due to hitting the depth limit. The worklist is used to capture the
+/// indices of attributes/types that need to be parsed/reparsed when we hit
+/// the depth limit. This enables moving the tracking of what needs to be
+/// parsed to the heap.
class AttrTypeReader {
/// This class represents a single attribute or type entry.
template <typename T>
@@ -863,12 +881,34 @@ public:
ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData);
+ LogicalResult readAttribute(uint64_t index, Attribute &result,
+ uint64_t depth = 0) {
+ return readEntry(attributes, index, result, "attribute", depth);
+ }
+
+ LogicalResult readType(uint64_t index, Type &result, uint64_t depth = 0) {
+ return readEntry(types, index, result, "type", depth);
+ }
+
/// Resolve the attribute or type at the given index. Returns nullptr on
/// failure.
- Attribute resolveAttribute(size_t index) {
- return resolveEntry(attributes, index, "Attribute");
+ Attribute resolveAttribute(size_t index, uint64_t depth = 0) {
+ return resolveEntry(attributes, index, "Attribute", depth);
+ }
+ Type resolveType(size_t index, uint64_t depth = 0) {
+ return resolveEntry(types, index, "Type", depth);
+ }
+
+ Attribute getAttributeOrSentinel(size_t index) {
+ if (index >= attributes.size())
+ return nullptr;
+ return attributes[index].entry;
+ }
+ Type getTypeOrSentinel(size_t index) {
+ if (index >= types.size())
+ return nullptr;
+ return types[index].entry;
}
- Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }
/// Parse a reference to an attribute or type using the given reader.
LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
@@ -909,23 +949,33 @@ public:
llvm::getTypeName<T>(), ", but got: ", baseResult);
}
+ /// Add an index to the deferred worklist for re-parsing.
+ void addDeferredParsing(uint64_t index) { deferredWorklist.push_back(index); }
+
private:
/// Resolve the given entry at `index`.
template <typename T>
- T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
- StringRef entryType);
+ T resolveEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
+ StringRef entryType, uint64_t depth = 0);
- /// Parse an entry using the given reader that was encoded using the textual
- /// assembly format.
+ /// Read the entry at the given index, parsing it if it has not been resolved
+ /// yet. Fails if the index is invalid or if parsing fails or is deferred.
template <typename T>
- LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
- StringRef entryType);
+ LogicalResult readEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
+ T &result, StringRef entryType, uint64_t depth);
/// Parse an entry using the given reader that was encoded using a custom
/// bytecode format.
template <typename T>
LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
- StringRef entryType);
+ StringRef entryType, uint64_t index,
+ uint64_t depth);
+
+ /// Parse an entry using the given reader that was encoded using the textual
+ /// assembly format.
+ template <typename T>
+ LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
+ StringRef entryType);
/// The string section reader used to resolve string references when parsing
/// custom encoded attribute/type entries.
@@ -951,6 +1001,10 @@ private:
/// Reference to the parser configuration.
const ParserConfig &parserConfig;
+
+ /// Worklist for deferred attribute/type parsing. This is used to handle
+ /// deeply nested structures like CallSiteLoc iteratively.
+ std::vector<uint64_t> deferredWorklist;
};
class DialectReader : public DialectBytecodeReader {
@@ -959,10 +1013,11 @@ public:
const StringSectionReader &stringReader,
const ResourceSectionReader &resourceReader,
const llvm::StringMap<BytecodeDialect *> &dialectsMap,
- EncodingReader &reader, uint64_t &bytecodeVersion)
+ EncodingReader &reader, uint64_t &bytecodeVersion,
+ uint64_t depth = 0)
: attrTypeReader(attrTypeReader), stringReader(stringReader),
resourceReader(resourceReader), dialectsMap(dialectsMap),
- reader(reader), bytecodeVersion(bytecodeVersion) {}
+ reader(reader), bytecodeVersion(bytecodeVersion), depth(depth) {}
InFlightDiagnostic emitError(const Twine &msg) const override {
return reader.emitError(msg);
@@ -998,14 +1053,40 @@ public:
// IR
//===--------------------------------------------------------------------===//
+ /// The maximum depth to eagerly parse nested attributes/types before
+ /// deferring.
+ static constexpr uint64_t maxAttrTypeDepth = 5;
+
LogicalResult readAttribute(Attribute &result) override {
- return attrTypeReader.parseAttribute(reader, result);
+ uint64_t index;
+ if (failed(reader.parseVarInt(index)))
+ return failure();
+ if (depth > maxAttrTypeDepth) {
+ if (Attribute attr = attrTypeReader.getAttributeOrSentinel(index)) {
+ result = attr;
+ return success();
+ }
+ attrTypeReader.addDeferredParsing(index);
+ return failure();
+ }
+ return attrTypeReader.readAttribute(index, result, depth + 1);
}
LogicalResult readOptionalAttribute(Attribute &result) override {
return attrTypeReader.parseOptionalAttribute(reader, result);
}
LogicalResult readType(Type &result) override {
- return attrTypeReader.parseType(reader, result);
+ uint64_t index;
+ if (failed(reader.parseVarInt(index)))
+ return failure();
+ if (depth > maxAttrTypeDepth) {
+ if (Type type = attrTypeReader.getTypeOrSentinel(index)) {
+ result = type;
+ return success();
+ }
+ attrTypeReader.addDeferredParsing(index);
+ return failure();
+ }
+ return attrTypeReader.readType(index, result, depth + 1);
}
FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
@@ -1095,6 +1176,7 @@ private:
const llvm::StringMap<BytecodeDialect *> &dialectsMap;
EncodingReader &reader;
uint64_t &bytecodeVersion;
+ uint64_t depth;
};
/// Wraps the properties section and handles reading properties out of it.
@@ -1238,69 +1320,112 @@ LogicalResult AttrTypeReader::initialize(
}
template <typename T>
-T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
- StringRef entryType) {
+T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries,
+ uint64_t index, StringRef entryType,
+ uint64_t depth) {
if (index >= entries.size()) {
emitError(fileLoc) << "invalid " << entryType << " index: " << index;
return {};
}
- // If the entry has already been resolved, there is nothing left to do.
- Entry<T> &entry = entries[index];
- if (entry.entry)
- return entry.entry;
+ // Fast path: Try direct parsing without worklist overhead. This handles the
+ // common case where there are no deferred dependencies.
+ assert(deferredWorklist.empty());
+ T result;
+ if (succeeded(readEntry(entries, index, result, entryType, depth))) {
+ assert(deferredWorklist.empty());
+ return result;
+ }
+ if (deferredWorklist.empty()) {
+ // A failure with no deferred entries is a genuine error.
+ return T();
+ }
- // Parse the entry.
- EncodingReader reader(entry.data, fileLoc);
+ // Slow path: Use worklist to handle deferred dependencies. Use a deque to
+ // iteratively resolve entries with dependencies.
+ // - Pop from front to process
+ // - Push new dependencies to front (depth-first)
+ // - Move failed entries to back (retry after dependencies)
+ std::deque<size_t> worklist;
+ llvm::DenseSet<size_t> inWorklist;
- // Parse based on how the entry was encoded.
- if (entry.hasCustomEncoding) {
- if (failed(parseCustomEntry(entry, reader, entryType)))
- return T();
- } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
- return T();
+ // Add the original index and any dependencies from the fast path attempt.
+ worklist.push_back(index);
+ inWorklist.insert(index);
+ for (uint64_t idx : llvm::reverse(deferredWorklist)) {
+ if (inWorklist.insert(idx).second)
+ worklist.push_front(idx);
}
- if (!reader.empty()) {
- reader.emitError("unexpected trailing bytes after " + entryType + " entry");
- return T();
+ while (!worklist.empty()) {
+ size_t currentIndex = worklist.front();
+ worklist.pop_front();
+
+ // Clear the deferred worklist before parsing to capture any new entries.
+ deferredWorklist.clear();
+
+ T result;
+ if (succeeded(readEntry(entries, currentIndex, result, entryType, depth))) {
+ inWorklist.erase(currentIndex);
+ continue;
+ }
+
+ if (deferredWorklist.empty()) {
+ // Parsing failed with no deferred entries which implies an error.
+ return T();
+ }
+
+ // Move this entry to the back to retry after dependencies.
+ worklist.push_back(currentIndex);
+
+ // Add dependencies to the front (in reverse so they maintain order).
+ for (uint64_t idx : llvm::reverse(deferredWorklist)) {
+ if (inWorklist.insert(idx).second)
+ worklist.push_front(idx);
+ }
+ deferredWorklist.clear();
}
- return entry.entry;
+ return entries[index].entry;
}
template <typename T>
-LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
- StringRef entryType) {
- StringRef asmStr;
- if (failed(reader.parseNullTerminatedString(asmStr)))
- return failure();
+LogicalResult AttrTypeReader::readEntry(SmallVectorImpl<Entry<T>> &entries,
+ uint64_t index, T &result,
+ StringRef entryType, uint64_t depth) {
+ if (index >= entries.size())
+ return emitError(fileLoc) << "invalid " << entryType << " index: " << index;
- // Invoke the MLIR assembly parser to parse the entry text.
- size_t numRead = 0;
- MLIRContext *context = fileLoc->getContext();
- if constexpr (std::is_same_v<T, Type>)
- result =
- ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
- else
- result = ::parseAttribute(asmStr, context, Type(), &numRead,
- /*isKnownNullTerminated=*/true);
- if (!result)
+ // If the entry has already been resolved, return it.
+ Entry<T> &entry = entries[index];
+ if (entry.entry) {
+ result = entry.entry;
+ return success();
+ }
+
+ // If the entry hasn't been resolved, try to parse it.
+ EncodingReader reader(entry.data, fileLoc);
+ LogicalResult parseResult =
+ entry.hasCustomEncoding
+ ? parseCustomEntry(entry, reader, entryType, index, depth)
+ : parseAsmEntry(entry.entry, reader, entryType);
+ if (failed(parseResult))
return failure();
- // Ensure there weren't dangling characters after the entry.
- if (numRead != asmStr.size()) {
- return reader.emitError("trailing characters found after ", entryType,
- " assembly format: ", asmStr.drop_front(numRead));
- }
+ if (!reader.empty())
+ return reader.emitError("unexpected trailing bytes after " + entryType +
+ " entry");
+
+ result = entry.entry;
return success();
}
template <typename T>
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
EncodingReader &reader,
- StringRef entryType) {
+ StringRef entryType,
+ uint64_t index, uint64_t depth) {
DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
- reader, bytecodeVersion);
+ reader, bytecodeVersion, depth);
if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
return failure();
@@ -1350,6 +1475,33 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
return success(!!entry.entry);
}
+template <typename T>
+LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
+ StringRef entryType) {
+ StringRef asmStr;
+ if (failed(reader.parseNullTerminatedString(asmStr)))
+ return failure();
+
+ // Invoke the MLIR assembly parser to parse the entry text.
+ size_t numRead = 0;
+ MLIRContext *context = fileLoc->getContext();
+ if constexpr (std::is_same_v<T, Type>)
+ result =
+ ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
+ else
+ result = ::parseAttribute(asmStr, context, Type(), &numRead,
+ /*isKnownNullTerminated=*/true);
+ if (!result)
+ return failure();
+
+ // Ensure there weren't dangling characters after the entry.
+ if (numRead != asmStr.size()) {
+ return reader.emitError("trailing characters found after ", entryType,
+ " assembly format: ", asmStr.drop_front(numRead));
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Bytecode Reader
//===----------------------------------------------------------------------===//