diff options
| author | Mehdi Amini <joker.eph@gmail.com> | 2023-07-28 16:44:25 -0700 |
|---|---|---|
| committer | Mehdi Amini <joker.eph@gmail.com> | 2023-07-28 16:45:42 -0700 |
| commit | b86a13211fcd84bfae39066b51f9f079f970cea8 (patch) | |
| tree | feab13aad3136d79f530a9f45b64f3d5c83054ea /mlir/lib/Bytecode/Reader/BytecodeReader.cpp | |
| parent | b08d358e8ac83cf31c007a58c814bf2dca03d591 (diff) | |
| download | llvm-b86a13211fcd84bfae39066b51f9f079f970cea8.zip llvm-b86a13211fcd84bfae39066b51f9f079f970cea8.tar.gz llvm-b86a13211fcd84bfae39066b51f9f079f970cea8.tar.bz2 | |
Revert "Expose callbacks for encoding of types/attributes"
This reverts commit b299ec16661f653df66cdaf161cdc5441bc9803c.
The authorship informations were incorrect.
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
| -rw-r--r-- | mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 181 |
1 files changed, 61 insertions, 120 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 91e47c4..0639baf 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -451,7 +451,7 @@ struct BytecodeDialect { /// Returns failure if the dialect couldn't be loaded *and* the provided /// context does not allow unregistered dialects. The provided reader is used /// for error emission if necessary. - LogicalResult load(const DialectReader &reader, MLIRContext *ctx); + LogicalResult load(DialectReader &reader, MLIRContext *ctx); /// Return the loaded dialect, or nullptr if the dialect is unknown. This can /// only be called after `load`. @@ -505,11 +505,10 @@ struct BytecodeOperationName { /// Parse a single dialect group encoded in the byte stream. static LogicalResult parseDialectGrouping( - EncodingReader &reader, - MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, + EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects, function_ref<LogicalResult(BytecodeDialect *)> entryCallback) { // Parse the dialect and the number of entries in the group. - std::unique_ptr<BytecodeDialect> *dialect; + BytecodeDialect *dialect; if (failed(parseEntry(reader, dialects, dialect, "dialect"))) return failure(); uint64_t numEntries; @@ -517,7 +516,7 @@ static LogicalResult parseDialectGrouping( return failure(); for (uint64_t i = 0; i < numEntries; ++i) - if (failed(entryCallback(dialect->get()))) + if (failed(entryCallback(dialect))) return failure(); return success(); } @@ -533,7 +532,7 @@ public: /// Initialize the resource section reader with the given section data. LogicalResult initialize(Location fileLoc, const ParserConfig &config, - MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, + MutableArrayRef<BytecodeDialect> dialects, StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef); @@ -683,7 +682,7 @@ parseResourceGroup(Location fileLoc, bool allowEmpty, LogicalResult ResourceSectionReader::initialize( Location fileLoc, const ParserConfig &config, - MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, + MutableArrayRef<BytecodeDialect> dialects, StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { @@ -732,19 +731,19 @@ LogicalResult ResourceSectionReader::initialize( // Read the dialect resources from the bytecode. MLIRContext *ctx = fileLoc->getContext(); while (!offsetReader.empty()) { - std::unique_ptr<BytecodeDialect> *dialect; + BytecodeDialect *dialect; if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || - failed((*dialect)->load(dialectReader, ctx))) + failed(dialect->load(dialectReader, ctx))) return failure(); - Dialect *loadedDialect = (*dialect)->getLoadedDialect(); + Dialect *loadedDialect = dialect->getLoadedDialect(); if (!loadedDialect) { return resourceReader.emitError() - << "dialect '" << (*dialect)->name << "' is unknown"; + << "dialect '" << dialect->name << "' is unknown"; } const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect); if (!handler) { return resourceReader.emitError() - << "unexpected resources for dialect '" << (*dialect)->name << "'"; + << "unexpected resources for dialect '" << dialect->name << "'"; } // Ensure that each resource is declared before being processed. @@ -754,7 +753,7 @@ LogicalResult ResourceSectionReader::initialize( if (failed(handle)) { return resourceReader.emitError() << "unknown 'resource' key '" << key << "' for dialect '" - << (*dialect)->name << "'"; + << dialect->name << "'"; } dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle); dialectResources.push_back(*handle); @@ -797,19 +796,15 @@ class AttrTypeReader { public: AttrTypeReader(StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, - const llvm::StringMap<BytecodeDialect *> &dialectsMap, - uint64_t &bytecodeVersion, Location fileLoc, - const ParserConfig &config) + ResourceSectionReader &resourceReader, Location fileLoc, + uint64_t &bytecodeVersion) : stringReader(stringReader), resourceReader(resourceReader), - dialectsMap(dialectsMap), fileLoc(fileLoc), - bytecodeVersion(bytecodeVersion), parserConfig(config) {} + fileLoc(fileLoc), bytecodeVersion(bytecodeVersion) {} /// Initialize the attribute and type information within the reader. - LogicalResult - initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, - ArrayRef<uint8_t> sectionData, - ArrayRef<uint8_t> offsetSectionData); + LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects, + ArrayRef<uint8_t> sectionData, + ArrayRef<uint8_t> offsetSectionData); /// Resolve the attribute or type at the given index. Returns nullptr on /// failure. @@ -883,10 +878,6 @@ private: /// parsing custom encoded attribute/type entries. ResourceSectionReader &resourceReader; - /// The map of the loaded dialects used to retrieve dialect information, such - /// as the dialect version. - const llvm::StringMap<BytecodeDialect *> &dialectsMap; - /// The set of attribute and type entries. SmallVector<AttrEntry> attributes; SmallVector<TypeEntry> types; @@ -896,48 +887,27 @@ private: /// Current bytecode version being used. uint64_t &bytecodeVersion; - - /// Reference to the parser configuration. - const ParserConfig &parserConfig; }; class DialectReader : public DialectBytecodeReader { public: DialectReader(AttrTypeReader &attrTypeReader, StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, - const llvm::StringMap<BytecodeDialect *> &dialectsMap, - EncodingReader &reader, uint64_t &bytecodeVersion) + ResourceSectionReader &resourceReader, EncodingReader &reader, + uint64_t &bytecodeVersion) : attrTypeReader(attrTypeReader), stringReader(stringReader), - resourceReader(resourceReader), dialectsMap(dialectsMap), - reader(reader), bytecodeVersion(bytecodeVersion) {} + resourceReader(resourceReader), reader(reader), + bytecodeVersion(bytecodeVersion) {} - InFlightDiagnostic emitError(const Twine &msg) const override { + InFlightDiagnostic emitError(const Twine &msg) override { return reader.emitError(msg); } - FailureOr<const DialectVersion *> - getDialectVersion(StringRef dialectName) const override { - // First check if the dialect is available in the map. - auto dialectEntry = dialectsMap.find(dialectName); - if (dialectEntry == dialectsMap.end()) - return failure(); - // If the dialect was found, try to load it. This will trigger reading the - // bytecode version from the version buffer if it wasn't already processed. - // Return failure if either of those two actions could not be completed. - if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) || - dialectEntry->getValue()->loadedVersion.get() == nullptr) - return failure(); - return dialectEntry->getValue()->loadedVersion.get(); - } - - MLIRContext *getContext() const override { return getLoc().getContext(); } - uint64_t getBytecodeVersion() const override { return bytecodeVersion; } - DialectReader withEncodingReader(EncodingReader &encReader) const { + DialectReader withEncodingReader(EncodingReader &encReader) { return DialectReader(attrTypeReader, stringReader, resourceReader, - dialectsMap, encReader, bytecodeVersion); + encReader, bytecodeVersion); } Location getLoc() const { return reader.getLoc(); } @@ -1040,7 +1010,6 @@ private: AttrTypeReader &attrTypeReader; StringSectionReader &stringReader; ResourceSectionReader &resourceReader; - const llvm::StringMap<BytecodeDialect *> &dialectsMap; EncodingReader &reader; uint64_t &bytecodeVersion; }; @@ -1127,9 +1096,10 @@ private: }; } // namespace -LogicalResult AttrTypeReader::initialize( - MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, - ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) { +LogicalResult +AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects, + ArrayRef<uint8_t> sectionData, + ArrayRef<uint8_t> offsetSectionData) { EncodingReader offsetReader(offsetSectionData, fileLoc); // Parse the number of attribute and type entries. @@ -1181,7 +1151,6 @@ LogicalResult AttrTypeReader::initialize( return offsetReader.emitError( "unexpected trailing data in the Attribute/Type offset section"); } - return success(); } @@ -1247,54 +1216,32 @@ template <typename T> LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, EncodingReader &reader, StringRef entryType) { - DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap, - reader, bytecodeVersion); + DialectReader dialectReader(*this, stringReader, resourceReader, reader, + bytecodeVersion); if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); - - if constexpr (std::is_same_v<T, Type>) { - // Try parsing with callbacks first if available. - for (const auto &callback : - parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) { - if (failed( - callback->read(dialectReader, entry.dialect->name, entry.entry))) - return failure(); - // Early return if parsing was successful. - if (!!entry.entry) - return success(); - - // Reset the reader if we failed to parse, so we can fall through the - // other parsing functions. - reader = EncodingReader(entry.data, reader.getLoc()); - } - } else { - // Try parsing with callbacks first if available. - for (const auto &callback : - parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) { - if (failed( - callback->read(dialectReader, entry.dialect->name, entry.entry))) - return failure(); - // Early return if parsing was successful. - if (!!entry.entry) - return success(); - - // Reset the reader if we failed to parse, so we can fall through the - // other parsing functions. - reader = EncodingReader(entry.data, reader.getLoc()); - } - } - // Ensure that the dialect implements the bytecode interface. if (!entry.dialect->interface) { return reader.emitError("dialect '", entry.dialect->name, "' does not implement the bytecode interface"); } - if constexpr (std::is_same_v<T, Type>) - entry.entry = entry.dialect->interface->readType(dialectReader); - else - entry.entry = entry.dialect->interface->readAttribute(dialectReader); + // Ask the dialect to parse the entry. If the dialect is versioned, parse + // using the versioned encoding readers. + if (entry.dialect->loadedVersion.get()) { + if constexpr (std::is_same_v<T, Type>) + entry.entry = entry.dialect->interface->readType( + dialectReader, *entry.dialect->loadedVersion); + else + entry.entry = entry.dialect->interface->readAttribute( + dialectReader, *entry.dialect->loadedVersion); + } else { + if constexpr (std::is_same_v<T, Type>) + entry.entry = entry.dialect->interface->readType(dialectReader); + else + entry.entry = entry.dialect->interface->readAttribute(dialectReader); + } return success(!!entry.entry); } @@ -1315,8 +1262,7 @@ public: llvm::MemoryBufferRef buffer, const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), - attrTypeReader(stringReader, resourceReader, dialectsMap, version, - fileLoc, config), + attrTypeReader(stringReader, resourceReader, fileLoc, version), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()), @@ -1582,8 +1528,7 @@ private: StringRef producer; /// The table of IR units referenced within the bytecode file. - SmallVector<std::unique_ptr<BytecodeDialect>> dialects; - llvm::StringMap<BytecodeDialect *> dialectsMap; + SmallVector<BytecodeDialect> dialects; SmallVector<BytecodeOperationName> opNames; /// The reader used to process resources within the bytecode. @@ -1730,8 +1675,7 @@ LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) { //===----------------------------------------------------------------------===// // Dialect Section -LogicalResult BytecodeDialect::load(const DialectReader &reader, - MLIRContext *ctx) { +LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) { if (dialect) return success(); Dialect *loadedDialect = ctx->getOrLoadDialect(name); @@ -1775,15 +1719,13 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { // Parse each of the dialects. for (uint64_t i = 0; i < numDialects; ++i) { - dialects[i] = std::make_unique<BytecodeDialect>(); /// Before version kDialectVersioning, there wasn't any versioning available /// for dialects, and the entryIdx represent the string itself. if (version < bytecode::kDialectVersioning) { - if (failed(stringReader.parseString(sectionReader, dialects[i]->name))) + if (failed(stringReader.parseString(sectionReader, dialects[i].name))) return failure(); continue; } - // Parse ID representing dialect and version. uint64_t dialectNameIdx; bool versionAvailable; @@ -1791,19 +1733,18 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { versionAvailable))) return failure(); if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, - dialects[i]->name))) + dialects[i].name))) return failure(); if (versionAvailable) { bytecode::Section::ID sectionID; - if (failed(sectionReader.parseSection(sectionID, - dialects[i]->versionBuffer))) + if (failed( + sectionReader.parseSection(sectionID, dialects[i].versionBuffer))) return failure(); if (sectionID != bytecode::Section::kDialectVersions) { emitError(fileLoc, "expected dialect version section"); return failure(); } } - dialectsMap[dialects[i]->name] = dialects[i].get(); } // Parse the operation names, which are grouped by dialect. @@ -1851,7 +1792,7 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader, if (!opName->opName) { // Load the dialect and its version. DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - dialectsMap, reader, version); + reader, version); if (failed(opName->dialect->load(dialectReader, getContext()))) return failure(); // If the opName is empty, this is because we use to accept names such as @@ -1894,7 +1835,7 @@ LogicalResult BytecodeReader::Impl::parseResourceSection( // Initialize the resource reader with the resource sections. DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - dialectsMap, reader, version); + reader, version); return resourceReader.initialize(fileLoc, config, dialects, stringReader, *resourceData, *resourceOffsetData, dialectReader, bufferOwnerRef); @@ -2095,14 +2036,14 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData, "parsed use-list orders were invalid and could not be applied"); // Resolve dialect version. - for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) { + for (const BytecodeDialect &byteCodeDialect : dialects) { // Parsing is complete, give an opportunity to each dialect to visit the // IR and perform upgrades. - if (!byteCodeDialect->loadedVersion) + if (!byteCodeDialect.loadedVersion) continue; - if (byteCodeDialect->interface && - failed(byteCodeDialect->interface->upgradeFromVersion( - *moduleOp, *byteCodeDialect->loadedVersion))) + if (byteCodeDialect.interface && + failed(byteCodeDialect.interface->upgradeFromVersion( + *moduleOp, *byteCodeDialect.loadedVersion))) return failure(); } @@ -2255,7 +2196,7 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, // interface and control the serialization. if (wasRegistered) { DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - dialectsMap, reader, version); + reader, version); if (failed( propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) return failure(); |
