diff options
| author | Mehdi Amini <joker.eph@gmail.com> | 2023-07-28 10:43:51 -0700 |
|---|---|---|
| committer | Mehdi Amini <joker.eph@gmail.com> | 2023-07-28 10:44:02 -0700 |
| commit | b299ec16661f653df66cdaf161cdc5441bc9803c (patch) | |
| tree | 871d39e32caf5b3cacdc546905d4ee7731b8053e /mlir/lib/Bytecode/Reader/BytecodeReader.cpp | |
| parent | bb65caf90ae1ade0ab1896c8e781cff34b34a846 (diff) | |
| download | llvm-b299ec16661f653df66cdaf161cdc5441bc9803c.zip llvm-b299ec16661f653df66cdaf161cdc5441bc9803c.tar.gz llvm-b299ec16661f653df66cdaf161cdc5441bc9803c.tar.bz2 | |
Expose callbacks for encoding of types/attributes
[mlir] Expose a mechanism to provide a callback for encoding types and attributes in MLIR bytecode.
Two callbacks are exposed, one on the BytecodeWriterConfig and one on the ParserConfig. When printing or parsing bytecode, clients can specify a callback that is used to optionally write or read the encoding. On failure, the fallback path executes the default parsers and printers for the dialect.
The tests show how to leverage this functionality to support back-deployment and backward-compatibility use cases when round-tripping through bytecode a client dialect with type/attribute dependencies on upstream.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D153383
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
| -rw-r--r-- | mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 181 |
1 files changed, 120 insertions, 61 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 0639baf..91e47c4 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -451,7 +451,7 @@ struct BytecodeDialect { /// Returns failure if the dialect couldn't be loaded *and* the provided /// context does not allow unregistered dialects. The provided reader is used /// for error emission if necessary. - LogicalResult load(DialectReader &reader, MLIRContext *ctx); + LogicalResult load(const DialectReader &reader, MLIRContext *ctx); /// Return the loaded dialect, or nullptr if the dialect is unknown. This can /// only be called after `load`. @@ -505,10 +505,11 @@ struct BytecodeOperationName { /// Parse a single dialect group encoded in the byte stream. static LogicalResult parseDialectGrouping( - EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects, + EncodingReader &reader, + MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, function_ref<LogicalResult(BytecodeDialect *)> entryCallback) { // Parse the dialect and the number of entries in the group. - BytecodeDialect *dialect; + std::unique_ptr<BytecodeDialect> *dialect; if (failed(parseEntry(reader, dialects, dialect, "dialect"))) return failure(); uint64_t numEntries; @@ -516,7 +517,7 @@ static LogicalResult parseDialectGrouping( return failure(); for (uint64_t i = 0; i < numEntries; ++i) - if (failed(entryCallback(dialect))) + if (failed(entryCallback(dialect->get()))) return failure(); return success(); } @@ -532,7 +533,7 @@ public: /// Initialize the resource section reader with the given section data. 
LogicalResult initialize(Location fileLoc, const ParserConfig &config, - MutableArrayRef<BytecodeDialect> dialects, + MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef); @@ -682,7 +683,7 @@ parseResourceGroup(Location fileLoc, bool allowEmpty, LogicalResult ResourceSectionReader::initialize( Location fileLoc, const ParserConfig &config, - MutableArrayRef<BytecodeDialect> dialects, + MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { @@ -731,19 +732,19 @@ LogicalResult ResourceSectionReader::initialize( // Read the dialect resources from the bytecode. MLIRContext *ctx = fileLoc->getContext(); while (!offsetReader.empty()) { - BytecodeDialect *dialect; + std::unique_ptr<BytecodeDialect> *dialect; if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || - failed(dialect->load(dialectReader, ctx))) + failed((*dialect)->load(dialectReader, ctx))) return failure(); - Dialect *loadedDialect = dialect->getLoadedDialect(); + Dialect *loadedDialect = (*dialect)->getLoadedDialect(); if (!loadedDialect) { return resourceReader.emitError() - << "dialect '" << dialect->name << "' is unknown"; + << "dialect '" << (*dialect)->name << "' is unknown"; } const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect); if (!handler) { return resourceReader.emitError() - << "unexpected resources for dialect '" << dialect->name << "'"; + << "unexpected resources for dialect '" << (*dialect)->name << "'"; } // Ensure that each resource is declared before being processed. 
@@ -753,7 +754,7 @@ LogicalResult ResourceSectionReader::initialize( if (failed(handle)) { return resourceReader.emitError() << "unknown 'resource' key '" << key << "' for dialect '" - << dialect->name << "'"; + << (*dialect)->name << "'"; } dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle); dialectResources.push_back(*handle); @@ -796,15 +797,19 @@ class AttrTypeReader { public: AttrTypeReader(StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, Location fileLoc, - uint64_t &bytecodeVersion) + ResourceSectionReader &resourceReader, + const llvm::StringMap<BytecodeDialect *> &dialectsMap, + uint64_t &bytecodeVersion, Location fileLoc, + const ParserConfig &config) : stringReader(stringReader), resourceReader(resourceReader), - fileLoc(fileLoc), bytecodeVersion(bytecodeVersion) {} + dialectsMap(dialectsMap), fileLoc(fileLoc), + bytecodeVersion(bytecodeVersion), parserConfig(config) {} /// Initialize the attribute and type information within the reader. - LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects, - ArrayRef<uint8_t> sectionData, - ArrayRef<uint8_t> offsetSectionData); + LogicalResult + initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, + ArrayRef<uint8_t> sectionData, + ArrayRef<uint8_t> offsetSectionData); /// Resolve the attribute or type at the given index. Returns nullptr on /// failure. @@ -878,6 +883,10 @@ private: /// parsing custom encoded attribute/type entries. ResourceSectionReader &resourceReader; + /// The map of the loaded dialects used to retrieve dialect information, such + /// as the dialect version. + const llvm::StringMap<BytecodeDialect *> &dialectsMap; + /// The set of attribute and type entries. SmallVector<AttrEntry> attributes; SmallVector<TypeEntry> types; @@ -887,27 +896,48 @@ private: /// Current bytecode version being used. uint64_t &bytecodeVersion; + + /// Reference to the parser configuration. 
+ const ParserConfig &parserConfig; }; class DialectReader : public DialectBytecodeReader { public: DialectReader(AttrTypeReader &attrTypeReader, StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, EncodingReader &reader, - uint64_t &bytecodeVersion) + ResourceSectionReader &resourceReader, + const llvm::StringMap<BytecodeDialect *> &dialectsMap, + EncodingReader &reader, uint64_t &bytecodeVersion) : attrTypeReader(attrTypeReader), stringReader(stringReader), - resourceReader(resourceReader), reader(reader), - bytecodeVersion(bytecodeVersion) {} + resourceReader(resourceReader), dialectsMap(dialectsMap), + reader(reader), bytecodeVersion(bytecodeVersion) {} - InFlightDiagnostic emitError(const Twine &msg) override { + InFlightDiagnostic emitError(const Twine &msg) const override { return reader.emitError(msg); } + FailureOr<const DialectVersion *> + getDialectVersion(StringRef dialectName) const override { + // First check if the dialect is available in the map. + auto dialectEntry = dialectsMap.find(dialectName); + if (dialectEntry == dialectsMap.end()) + return failure(); + // If the dialect was found, try to load it. This will trigger reading the + // bytecode version from the version buffer if it wasn't already processed. + // Return failure if either of those two actions could not be completed. 
+ if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) || + dialectEntry->getValue()->loadedVersion.get() == nullptr) + return failure(); + return dialectEntry->getValue()->loadedVersion.get(); + } + + MLIRContext *getContext() const override { return getLoc().getContext(); } + uint64_t getBytecodeVersion() const override { return bytecodeVersion; } - DialectReader withEncodingReader(EncodingReader &encReader) { + DialectReader withEncodingReader(EncodingReader &encReader) const { return DialectReader(attrTypeReader, stringReader, resourceReader, - encReader, bytecodeVersion); + dialectsMap, encReader, bytecodeVersion); } Location getLoc() const { return reader.getLoc(); } @@ -1010,6 +1040,7 @@ private: AttrTypeReader &attrTypeReader; StringSectionReader &stringReader; ResourceSectionReader &resourceReader; + const llvm::StringMap<BytecodeDialect *> &dialectsMap; EncodingReader &reader; uint64_t &bytecodeVersion; }; @@ -1096,10 +1127,9 @@ private: }; } // namespace -LogicalResult -AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects, - ArrayRef<uint8_t> sectionData, - ArrayRef<uint8_t> offsetSectionData) { +LogicalResult AttrTypeReader::initialize( + MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects, + ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) { EncodingReader offsetReader(offsetSectionData, fileLoc); // Parse the number of attribute and type entries. 
@@ -1151,6 +1181,7 @@ AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects, return offsetReader.emitError( "unexpected trailing data in the Attribute/Type offset section"); } + return success(); } @@ -1216,32 +1247,54 @@ template <typename T> LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, EncodingReader &reader, StringRef entryType) { - DialectReader dialectReader(*this, stringReader, resourceReader, reader, - bytecodeVersion); + DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap, + reader, bytecodeVersion); if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); + + if constexpr (std::is_same_v<T, Type>) { + // Try parsing with callbacks first if available. + for (const auto &callback : + parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) { + if (failed( + callback->read(dialectReader, entry.dialect->name, entry.entry))) + return failure(); + // Early return if parsing was successful. + if (!!entry.entry) + return success(); + + // Reset the reader if we failed to parse, so we can fall through the + // other parsing functions. + reader = EncodingReader(entry.data, reader.getLoc()); + } + } else { + // Try parsing with callbacks first if available. + for (const auto &callback : + parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) { + if (failed( + callback->read(dialectReader, entry.dialect->name, entry.entry))) + return failure(); + // Early return if parsing was successful. + if (!!entry.entry) + return success(); + + // Reset the reader if we failed to parse, so we can fall through the + // other parsing functions. + reader = EncodingReader(entry.data, reader.getLoc()); + } + } + // Ensure that the dialect implements the bytecode interface. if (!entry.dialect->interface) { return reader.emitError("dialect '", entry.dialect->name, "' does not implement the bytecode interface"); } - // Ask the dialect to parse the entry. 
If the dialect is versioned, parse - // using the versioned encoding readers. - if (entry.dialect->loadedVersion.get()) { - if constexpr (std::is_same_v<T, Type>) - entry.entry = entry.dialect->interface->readType( - dialectReader, *entry.dialect->loadedVersion); - else - entry.entry = entry.dialect->interface->readAttribute( - dialectReader, *entry.dialect->loadedVersion); + if constexpr (std::is_same_v<T, Type>) + entry.entry = entry.dialect->interface->readType(dialectReader); + else + entry.entry = entry.dialect->interface->readAttribute(dialectReader); - } else { - if constexpr (std::is_same_v<T, Type>) - entry.entry = entry.dialect->interface->readType(dialectReader); - else - entry.entry = entry.dialect->interface->readAttribute(dialectReader); - } return success(!!entry.entry); } @@ -1262,7 +1315,8 @@ public: llvm::MemoryBufferRef buffer, const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), - attrTypeReader(stringReader, resourceReader, fileLoc, version), + attrTypeReader(stringReader, resourceReader, dialectsMap, version, + fileLoc, config), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()), @@ -1528,7 +1582,8 @@ private: StringRef producer; /// The table of IR units referenced within the bytecode file. - SmallVector<BytecodeDialect> dialects; + SmallVector<std::unique_ptr<BytecodeDialect>> dialects; + llvm::StringMap<BytecodeDialect *> dialectsMap; SmallVector<BytecodeOperationName> opNames; /// The reader used to process resources within the bytecode. 
@@ -1675,7 +1730,8 @@ LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) { //===----------------------------------------------------------------------===// // Dialect Section -LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) { +LogicalResult BytecodeDialect::load(const DialectReader &reader, + MLIRContext *ctx) { if (dialect) return success(); Dialect *loadedDialect = ctx->getOrLoadDialect(name); @@ -1719,13 +1775,15 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { // Parse each of the dialects. for (uint64_t i = 0; i < numDialects; ++i) { + dialects[i] = std::make_unique<BytecodeDialect>(); /// Before version kDialectVersioning, there wasn't any versioning available /// for dialects, and the entryIdx represent the string itself. if (version < bytecode::kDialectVersioning) { - if (failed(stringReader.parseString(sectionReader, dialects[i].name))) + if (failed(stringReader.parseString(sectionReader, dialects[i]->name))) return failure(); continue; } + // Parse ID representing dialect and version. uint64_t dialectNameIdx; bool versionAvailable; @@ -1733,18 +1791,19 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { versionAvailable))) return failure(); if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, - dialects[i].name))) + dialects[i]->name))) return failure(); if (versionAvailable) { bytecode::Section::ID sectionID; - if (failed( - sectionReader.parseSection(sectionID, dialects[i].versionBuffer))) + if (failed(sectionReader.parseSection(sectionID, + dialects[i]->versionBuffer))) return failure(); if (sectionID != bytecode::Section::kDialectVersions) { emitError(fileLoc, "expected dialect version section"); return failure(); } } + dialectsMap[dialects[i]->name] = dialects[i].get(); } // Parse the operation names, which are grouped by dialect. 
@@ -1792,7 +1851,7 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader, if (!opName->opName) { // Load the dialect and its version. DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader, version); + dialectsMap, reader, version); if (failed(opName->dialect->load(dialectReader, getContext()))) return failure(); // If the opName is empty, this is because we use to accept names such as @@ -1835,7 +1894,7 @@ LogicalResult BytecodeReader::Impl::parseResourceSection( // Initialize the resource reader with the resource sections. DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader, version); + dialectsMap, reader, version); return resourceReader.initialize(fileLoc, config, dialects, stringReader, *resourceData, *resourceOffsetData, dialectReader, bufferOwnerRef); @@ -2036,14 +2095,14 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData, "parsed use-list orders were invalid and could not be applied"); // Resolve dialect version. - for (const BytecodeDialect &byteCodeDialect : dialects) { + for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) { // Parsing is complete, give an opportunity to each dialect to visit the // IR and perform upgrades. - if (!byteCodeDialect.loadedVersion) + if (!byteCodeDialect->loadedVersion) continue; - if (byteCodeDialect.interface && - failed(byteCodeDialect.interface->upgradeFromVersion( - *moduleOp, *byteCodeDialect.loadedVersion))) + if (byteCodeDialect->interface && + failed(byteCodeDialect->interface->upgradeFromVersion( + *moduleOp, *byteCodeDialect->loadedVersion))) return failure(); } @@ -2196,7 +2255,7 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, // interface and control the serialization. 
if (wasRegistered) { DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader, version); + dialectsMap, reader, version); if (failed( propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) return failure(); |
