diff options
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
| -rw-r--r-- | mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 327 |
1 files changed, 269 insertions, 58 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 1659437..c974603 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -27,6 +27,7 @@ #include <cstddef> #include <cstdint> +#include <deque> #include <list> #include <memory> #include <numeric> @@ -830,6 +831,23 @@ namespace { /// This class provides support for reading attribute and type entries from the /// bytecode. Attribute and Type entries are read lazily on demand, so we use /// this reader to manage when to actually parse them from the bytecode. +/// +/// The parsing of attributes & types are generally recursive, this can lead to +/// stack overflows for deeply nested structures, so we track a few extra pieces +/// of information to avoid this: +/// +/// - `depth`: The current depth while parsing nested attributes. We defer on +/// parsing deeply nested attributes to avoid potential stack overflows. The +/// deferred parsing is achieved by reporting a failure when parsing a nested +/// attribute/type and registering the index of the encountered attribute/type +/// in the deferred parsing worklist. Hence, a failure with deffered entry +/// does not constitute a failure, it also requires that folks return on +/// first failure rather than attempting additional parses. +/// - `deferredWorklist`: A list of attribute/type indices that we could not +/// parse due to hitting the depth limit. The worklist is used to capture the +/// indices of attributes/types that need to be parsed/reparsed when we hit +/// the depth limit. This enables moving the tracking of what needs to be +/// parsed to the heap. class AttrTypeReader { /// This class represents a single attribute or type entry. template <typename T> @@ -863,12 +881,34 @@ public: ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData); + LogicalResult readAttribute(uint64_t index, Attribute &result, + uint64_t depth = 0) { + return readEntry(attributes, index, result, "attribute", depth); + } + + LogicalResult readType(uint64_t index, Type &result, uint64_t depth = 0) { + return readEntry(types, index, result, "type", depth); + } + /// Resolve the attribute or type at the given index. Returns nullptr on /// failure. - Attribute resolveAttribute(size_t index) { - return resolveEntry(attributes, index, "Attribute"); + Attribute resolveAttribute(size_t index, uint64_t depth = 0) { + return resolveEntry(attributes, index, "Attribute", depth); + } + Type resolveType(size_t index, uint64_t depth = 0) { + return resolveEntry(types, index, "Type", depth); + } + + Attribute getAttributeOrSentinel(size_t index) { + if (index >= attributes.size()) + return nullptr; + return attributes[index].entry; + } + Type getTypeOrSentinel(size_t index) { + if (index >= types.size()) + return nullptr; + return types[index].entry; } - Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); } /// Parse a reference to an attribute or type using the given reader. LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) { @@ -909,23 +949,41 @@ public: llvm::getTypeName<T>(), ", but got: ", baseResult); } + /// The kind of entry being parsed. + enum class EntryKind { Attribute, Type }; + + /// Add an index to the deferred worklist for re-parsing. + void addDeferredParsing(uint64_t index, EntryKind kind) { + deferredWorklist.emplace_back(index, kind); + } + + /// Whether currently resolving. + bool isResolving() const { return resolving; } + private: /// Resolve the given entry at `index`. template <typename T> - T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, - StringRef entryType); + T resolveEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index, + StringRef entryType, uint64_t depth = 0); - /// Parse an entry using the given reader that was encoded using the textual - /// assembly format. + /// Read the entry at the given index, returning failure if the entry is not + /// yet resolved. template <typename T> - LogicalResult parseAsmEntry(T &result, EncodingReader &reader, - StringRef entryType); + LogicalResult readEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index, + T &result, StringRef entryType, uint64_t depth); /// Parse an entry using the given reader that was encoded using a custom /// bytecode format. template <typename T> LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader, - StringRef entryType); + StringRef entryType, uint64_t index, + uint64_t depth); + + /// Parse an entry using the given reader that was encoded using the textual + /// assembly format. + template <typename T> + LogicalResult parseAsmEntry(T &result, EncodingReader &reader, + StringRef entryType); /// The string section reader used to resolve string references when parsing /// custom encoded attribute/type entries. @@ -951,6 +1009,15 @@ private: /// Reference to the parser configuration. const ParserConfig &parserConfig; + + /// Worklist for deferred attribute/type parsing. This is used to handle + /// deeply nested structures like CallSiteLoc iteratively. + /// - The first element is the index of the attribute/type to parse. + /// - The second element is the kind of entry being parsed. + std::vector<std::pair<uint64_t, EntryKind>> deferredWorklist; + + /// Flag indicating if we are currently resolving an attribute or type. + bool resolving = false; }; class DialectReader : public DialectBytecodeReader { @@ -959,10 +1026,11 @@ public: const StringSectionReader &stringReader, const ResourceSectionReader &resourceReader, const llvm::StringMap<BytecodeDialect *> &dialectsMap, - EncodingReader &reader, uint64_t &bytecodeVersion) + EncodingReader &reader, uint64_t &bytecodeVersion, + uint64_t depth = 0) : attrTypeReader(attrTypeReader), stringReader(stringReader), resourceReader(resourceReader), dialectsMap(dialectsMap), - reader(reader), bytecodeVersion(bytecodeVersion) {} + reader(reader), bytecodeVersion(bytecodeVersion), depth(depth) {} InFlightDiagnostic emitError(const Twine &msg) const override { return reader.emitError(msg); @@ -998,14 +1066,65 @@ public: // IR //===--------------------------------------------------------------------===// + /// The maximum depth to eagerly parse nested attributes/types before + /// deferring. + static constexpr uint64_t maxAttrTypeDepth = 5; + LogicalResult readAttribute(Attribute &result) override { - return attrTypeReader.parseAttribute(reader, result); + uint64_t index; + if (failed(reader.parseVarInt(index))) + return failure(); + + // If we aren't currently resolving an attribute/type, we resolve this + // attribute eagerly. This is the case when we are parsing properties, which + // aren't processed via the worklist. + if (!attrTypeReader.isResolving()) { + if (Attribute attr = attrTypeReader.resolveAttribute(index)) { + result = attr; + return success(); + } + return failure(); + } + + if (depth > maxAttrTypeDepth) { + if (Attribute attr = attrTypeReader.getAttributeOrSentinel(index)) { + result = attr; + return success(); + } + attrTypeReader.addDeferredParsing(index, + AttrTypeReader::EntryKind::Attribute); + return failure(); + } + return attrTypeReader.readAttribute(index, result, depth + 1); } LogicalResult readOptionalAttribute(Attribute &result) override { return attrTypeReader.parseOptionalAttribute(reader, result); } LogicalResult readType(Type &result) override { - return attrTypeReader.parseType(reader, result); + uint64_t index; + if (failed(reader.parseVarInt(index))) + return failure(); + + // If we aren't currently resolving an attribute/type, we resolve this + // type eagerly. This is the case when we are parsing properties, which + // aren't processed via the worklist. + if (!attrTypeReader.isResolving()) { + if (Type type = attrTypeReader.resolveType(index)) { + result = type; + return success(); + } + return failure(); + } + + if (depth > maxAttrTypeDepth) { + if (Type type = attrTypeReader.getTypeOrSentinel(index)) { + result = type; + return success(); + } + attrTypeReader.addDeferredParsing(index, AttrTypeReader::EntryKind::Type); + return failure(); + } + return attrTypeReader.readType(index, result, depth + 1); } FailureOr<AsmDialectResourceHandle> readResourceHandle() override { @@ -1095,6 +1214,7 @@ private: const llvm::StringMap<BytecodeDialect *> &dialectsMap; EncodingReader &reader; uint64_t &bytecodeVersion; + uint64_t depth; }; /// Wraps the properties section and handles reading properties out of it. @@ -1238,69 +1358,133 @@ LogicalResult AttrTypeReader::initialize( } template <typename T> -T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, - StringRef entryType) { +T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, + uint64_t index, StringRef entryType, + uint64_t depth) { + bool oldResolving = resolving; + resolving = true; + llvm::scope_exit restoreResolving([&]() { resolving = oldResolving; }); + if (index >= entries.size()) { emitError(fileLoc) << "invalid " << entryType << " index: " << index; return {}; } - // If the entry has already been resolved, there is nothing left to do. - Entry<T> &entry = entries[index]; - if (entry.entry) - return entry.entry; + // Fast path: Try direct parsing without worklist overhead. This handles the + // common case where there are no deferred dependencies. + assert(deferredWorklist.empty()); + T result; + if (succeeded(readEntry(entries, index, result, entryType, depth))) { + assert(deferredWorklist.empty()); + return result; + } + if (deferredWorklist.empty()) { + // Failed with no deferred entries is error. + return T(); + } - // Parse the entry. - EncodingReader reader(entry.data, fileLoc); + // Slow path: Use worklist to handle deferred dependencies. Use a deque to + // iteratively resolve entries with dependencies. + // - Pop from front to process + // - Push new dependencies to front (depth-first) + // - Move failed entries to back (retry after dependencies) + std::deque<std::pair<uint64_t, EntryKind>> worklist; + llvm::DenseSet<std::pair<uint64_t, EntryKind>> inWorklist; + + EntryKind entryKind = + std::is_same_v<T, Type> ? EntryKind::Type : EntryKind::Attribute; + + static_assert((std::is_same_v<T, Type> || std::is_same_v<T, Attribute>) && + "Only support resolving Attributes and Types"); - // Parse based on how the entry was encoded. - if (entry.hasCustomEncoding) { - if (failed(parseCustomEntry(entry, reader, entryType))) + auto addToWorklistFront = [&](std::pair<uint64_t, EntryKind> entry) { + if (inWorklist.insert(entry).second) + worklist.push_front(entry); + }; + + // Add the original index and any dependencies from the fast path attempt. + worklist.emplace_back(index, entryKind); + inWorklist.insert({index, entryKind}); + for (auto entry : llvm::reverse(deferredWorklist)) + addToWorklistFront(entry); + + while (!worklist.empty()) { + auto [currentIndex, entryKind] = worklist.front(); + worklist.pop_front(); + + // Clear the deferred worklist before parsing to capture any new entries. + deferredWorklist.clear(); + + if (entryKind == EntryKind::Type) { + Type result; + if (succeeded(readType(currentIndex, result, depth))) { + inWorklist.erase({currentIndex, entryKind}); + continue; + } + } else { + assert(entryKind == EntryKind::Attribute && "Unexpected entry kind"); + Attribute result; + if (succeeded(readAttribute(currentIndex, result, depth))) { + inWorklist.erase({currentIndex, entryKind}); + continue; + } + } + + if (deferredWorklist.empty()) { + // Parsing failed with no deferred entries which implies an error. return T(); - } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { - return T(); - } + } - if (!reader.empty()) { - reader.emitError("unexpected trailing bytes after " + entryType + " entry"); - return T(); + // Move this entry to the back to retry after dependencies. + worklist.emplace_back(currentIndex, entryKind); + + // Add dependencies to the front (in reverse so they maintain order). + for (auto entry : llvm::reverse(deferredWorklist)) + addToWorklistFront(entry); + + deferredWorklist.clear(); } - return entry.entry; + return entries[index].entry; } template <typename T> -LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, - StringRef entryType) { - StringRef asmStr; - if (failed(reader.parseNullTerminatedString(asmStr))) - return failure(); +LogicalResult AttrTypeReader::readEntry(SmallVectorImpl<Entry<T>> &entries, + uint64_t index, T &result, + StringRef entryType, uint64_t depth) { + if (index >= entries.size()) + return emitError(fileLoc) << "invalid " << entryType << " index: " << index; - // Invoke the MLIR assembly parser to parse the entry text. - size_t numRead = 0; - MLIRContext *context = fileLoc->getContext(); - if constexpr (std::is_same_v<T, Type>) - result = - ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); - else - result = ::parseAttribute(asmStr, context, Type(), &numRead, - /*isKnownNullTerminated=*/true); - if (!result) + // If the entry has already been resolved, return it. + Entry<T> &entry = entries[index]; + if (entry.entry) { + result = entry.entry; + return success(); + } + + // If the entry hasn't been resolved, try to parse it. + EncodingReader reader(entry.data, fileLoc); + LogicalResult parseResult = + entry.hasCustomEncoding + ? parseCustomEntry(entry, reader, entryType, index, depth) + : parseAsmEntry(entry.entry, reader, entryType); + if (failed(parseResult)) return failure(); - // Ensure there weren't dangling characters after the entry. - if (numRead != asmStr.size()) { - return reader.emitError("trailing characters found after ", entryType, - " assembly format: ", asmStr.drop_front(numRead)); - } + if (!reader.empty()) + return reader.emitError("unexpected trailing bytes after " + entryType + + " entry"); + + result = entry.entry; return success(); } template <typename T> LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, EncodingReader &reader, - StringRef entryType) { + StringRef entryType, + uint64_t index, uint64_t depth) { DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap, - reader, bytecodeVersion); + reader, bytecodeVersion, depth); if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); @@ -1350,6 +1534,33 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, return success(!!entry.entry); } +template <typename T> +LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, + StringRef entryType) { + StringRef asmStr; + if (failed(reader.parseNullTerminatedString(asmStr))) + return failure(); + + // Invoke the MLIR assembly parser to parse the entry text. + size_t numRead = 0; + MLIRContext *context = fileLoc->getContext(); + if constexpr (std::is_same_v<T, Type>) + result = + ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); + else + result = ::parseAttribute(asmStr, context, Type(), &numRead, + /*isKnownNullTerminated=*/true); + if (!result) + return failure(); + + // Ensure there weren't dangling characters after the entry. + if (numRead != asmStr.size()) { + return reader.emitError("trailing characters found after ", entryType, + " assembly format: ", asmStr.drop_front(numRead)); + } + return success(); +} + //===----------------------------------------------------------------------===// // Bytecode Reader //===----------------------------------------------------------------------===// @@ -1391,8 +1602,8 @@ public: materialize(Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { this->lazyOpsCallback = lazyOpsCallback; - auto resetlazyOpsCallback = - llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); + llvm::scope_exit resetlazyOpsCallback( + [&] { this->lazyOpsCallback = nullptr; }); auto it = lazyLoadableOpsMap.find(op); assert(it != lazyLoadableOpsMap.end() && "materialize called on non-materializable op"); @@ -1703,8 +1914,8 @@ LogicalResult BytecodeReader::Impl::read( Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { EncodingReader reader(buffer.getBuffer(), fileLoc); this->lazyOpsCallback = lazyOpsCallback; - auto resetlazyOpsCallback = - llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); + llvm::scope_exit resetlazyOpsCallback( + [&] { this->lazyOpsCallback = nullptr; }); // Skip over the bytecode header, this should have already been checked. if (failed(reader.skipBytes(StringRef("ML\xefR").size()))) |
