diff options
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
| -rw-r--r-- | mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 160 |
1 files changed, 123 insertions, 37 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 5e71c3a..d6f1e18 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -47,6 +47,8 @@ static std::string toString(bytecode::Section::ID sectionID) { return "Resource (5)"; case bytecode::Section::kResourceOffset: return "ResourceOffset (6)"; + case bytecode::Section::kDialectVersions: + return "DialectVersions (7)"; default: return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str(); } @@ -63,6 +65,7 @@ static bool isSectionOptional(bytecode::Section::ID sectionID) { return false; case bytecode::Section::kResource: case bytecode::Section::kResourceOffset: + case bytecode::Section::kDialectVersions: return true; default: llvm_unreachable("unknown section ID"); @@ -350,6 +353,13 @@ public: return parseEntry(reader, strings, result, "string"); } + /// Look up the shared string at the given `index` of the string section and + /// store it in `result`, delegating the lookup to `resolveEntry`. + LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index, + StringRef &result) { + return resolveEntry(reader, strings, index, result, "string"); + } + private: /// The table of strings referenced within the bytecode file. SmallVector<StringRef> strings; @@ -400,31 +410,15 @@ LogicalResult StringSectionReader::initialize(Location fileLoc, //===----------------------------------------------------------------------===// namespace { +class DialectReader; + /// This struct represents a dialect entry within the bytecode. struct BytecodeDialect { /// Load the dialect into the provided context if it hasn't been loaded yet. /// Returns failure if the dialect couldn't be loaded *and* the provided /// context does not allow unregistered dialects. The provided reader is used /// for error emission if necessary. 
- LogicalResult load(EncodingReader &reader, MLIRContext *ctx) { - if (dialect) - return success(); - Dialect *loadedDialect = ctx->getOrLoadDialect(name); - if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { - return reader.emitError( - "dialect '", name, - "' is unknown. If this is intended, please call " - "allowUnregisteredDialects() on the MLIRContext, or use " - "-allow-unregistered-dialect with the MLIR tool used."); - } - dialect = loadedDialect; - - // If the dialect was actually loaded, check to see if it has a bytecode - // interface. - if (loadedDialect) - interface = dyn_cast<BytecodeDialectInterface>(loadedDialect); - return success(); - } + LogicalResult load(DialectReader &reader, MLIRContext *ctx); /// Return the loaded dialect, or nullptr if the dialect is unknown. This can /// only be called after `load`. @@ -446,6 +440,12 @@ struct BytecodeDialect { /// The name of the dialect. StringRef name; + + /// The raw encoded dialect version bytes; empty when no version entry exists. + ArrayRef<uint8_t> versionBuffer; + + /// The decoded dialect version; null until `load` reads a version entry. + std::unique_ptr<DialectVersion> loadedVersion; }; /// This struct represents an operation name entry within the bytecode. @@ -496,7 +496,7 @@ public: initialize(Location fileLoc, const ParserConfig &config, MutableArrayRef<BytecodeDialect> dialects, StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, - ArrayRef<uint8_t> offsetSectionData, + ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef); /// Parse a dialect resource handle from the resource section. 
@@ -643,7 +643,7 @@ LogicalResult ResourceSectionReader::initialize( Location fileLoc, const ParserConfig &config, MutableArrayRef<BytecodeDialect> dialects, StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, - ArrayRef<uint8_t> offsetSectionData, + ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { EncodingReader resourceReader(sectionData, fileLoc); EncodingReader offsetReader(offsetSectionData, fileLoc); @@ -684,7 +684,7 @@ LogicalResult ResourceSectionReader::initialize( while (!offsetReader.empty()) { BytecodeDialect *dialect; if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || - failed(dialect->load(resourceReader, ctx))) + failed(dialect->load(dialectReader, ctx))) return failure(); Dialect *loadedDialect = dialect->getLoadedDialect(); if (!loadedDialect) { @@ -1051,7 +1051,8 @@ template <typename T> LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, EncodingReader &reader, StringRef entryType) { - if (failed(entry.dialect->load(reader, fileLoc.getContext()))) + DialectReader dialectReader(*this, stringReader, resourceReader, reader); + if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); // Ensure that the dialect implements the bytecode interface. @@ -1060,12 +1061,22 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, "' does not implement the bytecode interface"); } - // Ask the dialect to parse the entry. - DialectReader dialectReader(*this, stringReader, resourceReader, reader); - if constexpr (std::is_same_v<T, Type>) - entry.entry = entry.dialect->interface->readType(dialectReader); - else - entry.entry = entry.dialect->interface->readAttribute(dialectReader); + // Ask the dialect to parse the entry. If a version was loaded for the + // dialect, pass it along to the interface's read hooks. 
+ if (entry.dialect->loadedVersion.get()) { + if constexpr (std::is_same_v<T, Type>) + entry.entry = entry.dialect->interface->readType( + dialectReader, *entry.dialect->loadedVersion); + else + entry.entry = entry.dialect->interface->readAttribute( + dialectReader, *entry.dialect->loadedVersion); + + } else { + if constexpr (std::is_same_v<T, Type>) + entry.entry = entry.dialect->interface->readType(dialectReader); + else + entry.entry = entry.dialect->interface->readAttribute(dialectReader); + } return success(!!entry.entry); } @@ -1122,7 +1133,8 @@ private: // Resource Section LogicalResult - parseResourceSection(std::optional<ArrayRef<uint8_t>> resourceData, + parseResourceSection(EncodingReader &reader, + std::optional<ArrayRef<uint8_t>> resourceData, std::optional<ArrayRef<uint8_t>> resourceOffsetData); //===--------------------------------------------------------------------===// @@ -1306,7 +1318,7 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) { // Process the resource section if present. if (failed(parseResourceSection( - sectionDatas[bytecode::Section::kResource], + reader, sectionDatas[bytecode::Section::kResource], sectionDatas[bytecode::Section::kResourceOffset]))) return failure(); @@ -1326,7 +1338,8 @@ LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) { // Validate the bytecode version. 
uint64_t currentVersion = bytecode::kVersion; - if (version < currentVersion) { + uint64_t minSupportedVersion = bytecode::kMinSupportedVersion; + if (version < minSupportedVersion) { return reader.emitError("bytecode version ", version, " is older than the current version of ", currentVersion, ", and upgrade is not supported"); @@ -1342,6 +1355,36 @@ LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) { //===----------------------------------------------------------------------===// // Dialect Section +LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) { + if (dialect) + return success(); + Dialect *loadedDialect = ctx->getOrLoadDialect(name); + if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { + return reader.emitError("dialect '") + << name + << "' is unknown. If this is intended, please call " + "allowUnregisteredDialects() on the MLIRContext, or use " + "-allow-unregistered-dialect with the MLIR tool used."; + } + dialect = loadedDialect; + + // If the dialect was actually loaded, query its bytecode interface. A version + // entry, if any, is then decoded via that interface from the given `reader`. + if (loadedDialect) + interface = dyn_cast<BytecodeDialectInterface>(loadedDialect); + if (!versionBuffer.empty()) { + if (!interface) + return reader.emitError("dialect '") + << name + << "' does not implement the bytecode interface, " + "but found a version entry"; + loadedVersion = interface->readVersion(reader); + if (!loadedVersion) + return failure(); + } + return success(); +} + LogicalResult BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) { EncodingReader sectionReader(sectionData, fileLoc); @@ -1353,9 +1396,34 @@ BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) { dialects.resize(numDialects); // Parse each of the dialects. 
- for (uint64_t i = 0; i < numDialects; ++i) - if (failed(stringReader.parseString(sectionReader, dialects[i].name))) + for (uint64_t i = 0; i < numDialects; ++i) { + // Bytecode version 0 carries no dialect versions: each dialect entry is + // just its name, read as a shared string. + if (version == 0) { + if (failed(stringReader.parseString(sectionReader, dialects[i].name))) + return failure(); + continue; + } + // Parse the dialect name index and the flag marking a version entry. + uint64_t dialectNameIdx; + bool versionAvailable; + if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx, + versionAvailable))) + return failure(); + if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, + dialects[i].name))) return failure(); + if (versionAvailable) { + bytecode::Section::ID sectionID; + if (failed( + sectionReader.parseSection(sectionID, dialects[i].versionBuffer))) + return failure(); + if (sectionID != bytecode::Section::kDialectVersions) { + emitError(fileLoc, "expected dialect version section"); + return failure(); + } + } + } // Parse the operation names, which are grouped by dialect. auto parseOpName = [&](BytecodeDialect *dialect) { @@ -1379,7 +1447,11 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) { // Check to see if this operation name has already been resolved. If we // haven't, load the dialect and build the operation name. if (!opName->opName) { - if (failed(opName->dialect->load(reader, getContext()))) + // Load the dialect, reading its version from its own version buffer. + EncodingReader versionReader(opName->dialect->versionBuffer, fileLoc); + DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, + versionReader); + if (failed(opName->dialect->load(dialectReader, getContext()))) return failure(); opName->opName.emplace((opName->dialect->name + "." 
+ opName->name).str(), getContext()); @@ -1391,7 +1463,7 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) { // Resource Section LogicalResult BytecodeReader::parseResourceSection( - std::optional<ArrayRef<uint8_t>> resourceData, + EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData, std::optional<ArrayRef<uint8_t>> resourceOffsetData) { // Ensure both sections are either present or not. if (resourceData.has_value() != resourceOffsetData.has_value()) { @@ -1408,9 +1480,11 @@ LogicalResult BytecodeReader::parseResourceSection( return success(); // Initialize the resource reader with the resource sections. + DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, + reader); return resourceReader.initialize(fileLoc, config, dialects, stringReader, *resourceData, *resourceOffsetData, - bufferOwnerRef); + dialectReader, bufferOwnerRef); } //===----------------------------------------------------------------------===// @@ -1442,6 +1516,18 @@ LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData, "not all forward unresolved forward operand references"); } + // Run upgrades for every dialect that loaded a version. + for (const BytecodeDialect &byteCodeDialect : dialects) { + // Parsing is complete; dialects without a loaded version are skipped, + // the others get to visit the IR through `upgradeFromVersion`. + if (!byteCodeDialect.loadedVersion) + continue; + if (byteCodeDialect.interface && + failed(byteCodeDialect.interface->upgradeFromVersion( + *moduleOp, *byteCodeDialect.loadedVersion))) + return failure(); + } + // Verify that the parsed operations are valid. if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp))) return failure(); |
