diff options
| author | Matteo Franciolini <m_franciolini@apple.com> | 2023-02-14 08:45:08 -0800 |
|---|---|---|
| committer | Mehdi Amini <joker.eph@gmail.com> | 2023-03-10 23:28:56 +0100 |
| commit | 0e0b6070fd2a2a8f188ddb32aa526beda38190b7 (patch) | |
| tree | e036874c925ed056d0e3707d2d7ba4949c56ef6a /mlir/lib/Bytecode/Reader/BytecodeReader.cpp | |
| parent | 6a02cd45a51945cbe6b9f85629805c38e6009271 (diff) | |
| download | llvm-0e0b6070fd2a2a8f188ddb32aa526beda38190b7.zip llvm-0e0b6070fd2a2a8f188ddb32aa526beda38190b7.tar.gz llvm-0e0b6070fd2a2a8f188ddb32aa526beda38190b7.tar.bz2 | |
Implements MLIR Bytecode versioning capability
A dialect can opt in to handle versioning through the
`BytecodeDialectInterface`. A few hooks are exposed to the dialect to allow
managing a version encoded into the bytecode file. The version is loaded
lazily, which allows retrieving the version information while parsing the input
IR, and each dialect for which a version is present is given an opportunity
to perform IR upgrades post-parsing through the `upgradeFromVersion` method.
Custom Attribute and Type encodings can also be upgraded according to the
dialect version using the `readAttribute` and `readType` methods.
There is no restriction on what kind of information a dialect is allowed to
encode to model its versioning. Currently, versioning is supported only for
bytecode formats.
Reviewed By: rriddle, mehdi_amini
Differential Revision: https://reviews.llvm.org/D143647
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
| -rw-r--r-- | mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 160 |
1 files changed, 123 insertions, 37 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 5e71c3a..d6f1e18 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -47,6 +47,8 @@ static std::string toString(bytecode::Section::ID sectionID) { return "Resource (5)"; case bytecode::Section::kResourceOffset: return "ResourceOffset (6)"; + case bytecode::Section::kDialectVersions: + return "DialectVersions (7)"; default: return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str(); } @@ -63,6 +65,7 @@ static bool isSectionOptional(bytecode::Section::ID sectionID) { return false; case bytecode::Section::kResource: case bytecode::Section::kResourceOffset: + case bytecode::Section::kDialectVersions: return true; default: llvm_unreachable("unknown section ID"); @@ -350,6 +353,13 @@ public: return parseEntry(reader, strings, result, "string"); } + /// Parse a shared string from the string section. The shared string is + /// encoded using an index to a corresponding string in the string section. + LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index, + StringRef &result) { + return resolveEntry(reader, strings, index, result, "string"); + } + private: /// The table of strings referenced within the bytecode file. SmallVector<StringRef> strings; @@ -400,31 +410,15 @@ LogicalResult StringSectionReader::initialize(Location fileLoc, //===----------------------------------------------------------------------===// namespace { +class DialectReader; + /// This struct represents a dialect entry within the bytecode. struct BytecodeDialect { /// Load the dialect into the provided context if it hasn't been loaded yet. /// Returns failure if the dialect couldn't be loaded *and* the provided /// context does not allow unregistered dialects. The provided reader is used /// for error emission if necessary. 
- LogicalResult load(EncodingReader &reader, MLIRContext *ctx) { - if (dialect) - return success(); - Dialect *loadedDialect = ctx->getOrLoadDialect(name); - if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { - return reader.emitError( - "dialect '", name, - "' is unknown. If this is intended, please call " - "allowUnregisteredDialects() on the MLIRContext, or use " - "-allow-unregistered-dialect with the MLIR tool used."); - } - dialect = loadedDialect; - - // If the dialect was actually loaded, check to see if it has a bytecode - // interface. - if (loadedDialect) - interface = dyn_cast<BytecodeDialectInterface>(loadedDialect); - return success(); - } + LogicalResult load(DialectReader &reader, MLIRContext *ctx); /// Return the loaded dialect, or nullptr if the dialect is unknown. This can /// only be called after `load`. @@ -446,6 +440,12 @@ struct BytecodeDialect { /// The name of the dialect. StringRef name; + + /// A buffer containing the encoding of the dialect version parsed. + ArrayRef<uint8_t> versionBuffer; + + /// Lazy loaded dialect version from the handle above. + std::unique_ptr<DialectVersion> loadedVersion; }; /// This struct represents an operation name entry within the bytecode. @@ -496,7 +496,7 @@ public: initialize(Location fileLoc, const ParserConfig &config, MutableArrayRef<BytecodeDialect> dialects, StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, - ArrayRef<uint8_t> offsetSectionData, + ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef); /// Parse a dialect resource handle from the resource section. 
@@ -643,7 +643,7 @@ LogicalResult ResourceSectionReader::initialize( Location fileLoc, const ParserConfig &config, MutableArrayRef<BytecodeDialect> dialects, StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, - ArrayRef<uint8_t> offsetSectionData, + ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { EncodingReader resourceReader(sectionData, fileLoc); EncodingReader offsetReader(offsetSectionData, fileLoc); @@ -684,7 +684,7 @@ LogicalResult ResourceSectionReader::initialize( while (!offsetReader.empty()) { BytecodeDialect *dialect; if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || - failed(dialect->load(resourceReader, ctx))) + failed(dialect->load(dialectReader, ctx))) return failure(); Dialect *loadedDialect = dialect->getLoadedDialect(); if (!loadedDialect) { @@ -1051,7 +1051,8 @@ template <typename T> LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, EncodingReader &reader, StringRef entryType) { - if (failed(entry.dialect->load(reader, fileLoc.getContext()))) + DialectReader dialectReader(*this, stringReader, resourceReader, reader); + if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); // Ensure that the dialect implements the bytecode interface. @@ -1060,12 +1061,22 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, "' does not implement the bytecode interface"); } - // Ask the dialect to parse the entry. - DialectReader dialectReader(*this, stringReader, resourceReader, reader); - if constexpr (std::is_same_v<T, Type>) - entry.entry = entry.dialect->interface->readType(dialectReader); - else - entry.entry = entry.dialect->interface->readAttribute(dialectReader); + // Ask the dialect to parse the entry. If the dialect is versioned, parse + // using the versioned encoding readers. 
+ if (entry.dialect->loadedVersion.get()) { + if constexpr (std::is_same_v<T, Type>) + entry.entry = entry.dialect->interface->readType( + dialectReader, *entry.dialect->loadedVersion); + else + entry.entry = entry.dialect->interface->readAttribute( + dialectReader, *entry.dialect->loadedVersion); + + } else { + if constexpr (std::is_same_v<T, Type>) + entry.entry = entry.dialect->interface->readType(dialectReader); + else + entry.entry = entry.dialect->interface->readAttribute(dialectReader); + } return success(!!entry.entry); } @@ -1122,7 +1133,8 @@ private: // Resource Section LogicalResult - parseResourceSection(std::optional<ArrayRef<uint8_t>> resourceData, + parseResourceSection(EncodingReader &reader, + std::optional<ArrayRef<uint8_t>> resourceData, std::optional<ArrayRef<uint8_t>> resourceOffsetData); //===--------------------------------------------------------------------===// @@ -1306,7 +1318,7 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) { // Process the resource section if present. if (failed(parseResourceSection( - sectionDatas[bytecode::Section::kResource], + reader, sectionDatas[bytecode::Section::kResource], sectionDatas[bytecode::Section::kResourceOffset]))) return failure(); @@ -1326,7 +1338,8 @@ LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) { // Validate the bytecode version. 
uint64_t currentVersion = bytecode::kVersion; - if (version < currentVersion) { + uint64_t minSupportedVersion = bytecode::kMinSupportedVersion; + if (version < minSupportedVersion) { return reader.emitError("bytecode version ", version, " is older than the current version of ", currentVersion, ", and upgrade is not supported"); @@ -1342,6 +1355,36 @@ LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) { //===----------------------------------------------------------------------===// // Dialect Section +LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) { + if (dialect) + return success(); + Dialect *loadedDialect = ctx->getOrLoadDialect(name); + if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { + return reader.emitError("dialect '") + << name + << "' is unknown. If this is intended, please call " + "allowUnregisteredDialects() on the MLIRContext, or use " + "-allow-unregistered-dialect with the MLIR tool used."; + } + dialect = loadedDialect; + + // If the dialect was actually loaded, check to see if it has a bytecode + // interface. + if (loadedDialect) + interface = dyn_cast<BytecodeDialectInterface>(loadedDialect); + if (!versionBuffer.empty()) { + if (!interface) + return reader.emitError("dialect '") + << name + << "' does not implement the bytecode interface, " + "but found a version entry"; + loadedVersion = interface->readVersion(reader); + if (!loadedVersion) + return failure(); + } + return success(); +} + LogicalResult BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) { EncodingReader sectionReader(sectionData, fileLoc); @@ -1353,9 +1396,34 @@ BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) { dialects.resize(numDialects); // Parse each of the dialects. 
- for (uint64_t i = 0; i < numDialects; ++i) - if (failed(stringReader.parseString(sectionReader, dialects[i].name))) + for (uint64_t i = 0; i < numDialects; ++i) { + /// Before version 1, there wasn't any versioning available for dialects, + /// and the entryIdx represent the string itself. + if (version == 0) { + if (failed(stringReader.parseString(sectionReader, dialects[i].name))) + return failure(); + continue; + } + // Parse ID representing dialect and version. + uint64_t dialectNameIdx; + bool versionAvailable; + if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx, + versionAvailable))) + return failure(); + if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, + dialects[i].name))) return failure(); + if (versionAvailable) { + bytecode::Section::ID sectionID; + if (failed( + sectionReader.parseSection(sectionID, dialects[i].versionBuffer))) + return failure(); + if (sectionID != bytecode::Section::kDialectVersions) { + emitError(fileLoc, "expected dialect version section"); + return failure(); + } + } + } // Parse the operation names, which are grouped by dialect. auto parseOpName = [&](BytecodeDialect *dialect) { @@ -1379,7 +1447,11 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) { // Check to see if this operation name has already been resolved. If we // haven't, load the dialect and build the operation name. if (!opName->opName) { - if (failed(opName->dialect->load(reader, getContext()))) + // Load the dialect and its version. + EncodingReader versionReader(opName->dialect->versionBuffer, fileLoc); + DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, + versionReader); + if (failed(opName->dialect->load(dialectReader, getContext()))) return failure(); opName->opName.emplace((opName->dialect->name + "." 
+ opName->name).str(), getContext()); @@ -1391,7 +1463,7 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) { // Resource Section LogicalResult BytecodeReader::parseResourceSection( - std::optional<ArrayRef<uint8_t>> resourceData, + EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData, std::optional<ArrayRef<uint8_t>> resourceOffsetData) { // Ensure both sections are either present or not. if (resourceData.has_value() != resourceOffsetData.has_value()) { @@ -1408,9 +1480,11 @@ LogicalResult BytecodeReader::parseResourceSection( return success(); // Initialize the resource reader with the resource sections. + DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, + reader); return resourceReader.initialize(fileLoc, config, dialects, stringReader, *resourceData, *resourceOffsetData, - bufferOwnerRef); + dialectReader, bufferOwnerRef); } //===----------------------------------------------------------------------===// @@ -1442,6 +1516,18 @@ LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData, "not all forward unresolved forward operand references"); } + // Resolve dialect version. + for (const BytecodeDialect &byteCodeDialect : dialects) { + // Parsing is complete, give an opportunity to each dialect to visit the + // IR and perform upgrades. + if (!byteCodeDialect.loadedVersion) + continue; + if (byteCodeDialect.interface && + failed(byteCodeDialect.interface->upgradeFromVersion( + *moduleOp, *byteCodeDialect.loadedVersion))) + return failure(); + } + // Verify that the parsed operations are valid. if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp))) return failure(); |
