aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
-rw-r--r--mlir/lib/Bytecode/Reader/BytecodeReader.cpp160
1 files changed, 123 insertions, 37 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 5e71c3a..d6f1e18 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -47,6 +47,8 @@ static std::string toString(bytecode::Section::ID sectionID) {
return "Resource (5)";
case bytecode::Section::kResourceOffset:
return "ResourceOffset (6)";
+ case bytecode::Section::kDialectVersions:
+ return "DialectVersions (7)";
default:
return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
}
@@ -63,6 +65,7 @@ static bool isSectionOptional(bytecode::Section::ID sectionID) {
return false;
case bytecode::Section::kResource:
case bytecode::Section::kResourceOffset:
+ case bytecode::Section::kDialectVersions:
return true;
default:
llvm_unreachable("unknown section ID");
@@ -350,6 +353,13 @@ public:
return parseEntry(reader, strings, result, "string");
}
+ /// Parse a shared string from the string section. The shared string is
+ /// encoded using an index to a corresponding string in the string section.
+ LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index,
+ StringRef &result) {
+ return resolveEntry(reader, strings, index, result, "string");
+ }
+
private:
/// The table of strings referenced within the bytecode file.
SmallVector<StringRef> strings;
@@ -400,31 +410,15 @@ LogicalResult StringSectionReader::initialize(Location fileLoc,
//===----------------------------------------------------------------------===//
namespace {
+class DialectReader;
+
/// This struct represents a dialect entry within the bytecode.
struct BytecodeDialect {
/// Load the dialect into the provided context if it hasn't been loaded yet.
/// Returns failure if the dialect couldn't be loaded *and* the provided
/// context does not allow unregistered dialects. The provided reader is used
/// for error emission if necessary.
- LogicalResult load(EncodingReader &reader, MLIRContext *ctx) {
- if (dialect)
- return success();
- Dialect *loadedDialect = ctx->getOrLoadDialect(name);
- if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
- return reader.emitError(
- "dialect '", name,
- "' is unknown. If this is intended, please call "
- "allowUnregisteredDialects() on the MLIRContext, or use "
- "-allow-unregistered-dialect with the MLIR tool used.");
- }
- dialect = loadedDialect;
-
- // If the dialect was actually loaded, check to see if it has a bytecode
- // interface.
- if (loadedDialect)
- interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
- return success();
- }
+ LogicalResult load(DialectReader &reader, MLIRContext *ctx);
/// Return the loaded dialect, or nullptr if the dialect is unknown. This can
/// only be called after `load`.
@@ -446,6 +440,12 @@ struct BytecodeDialect {
/// The name of the dialect.
StringRef name;
+
+ /// A buffer containing the encoding of the dialect version parsed.
+ ArrayRef<uint8_t> versionBuffer;
+
+ /// Lazy loaded dialect version from the handle above.
+ std::unique_ptr<DialectVersion> loadedVersion;
};
/// This struct represents an operation name entry within the bytecode.
@@ -496,7 +496,7 @@ public:
initialize(Location fileLoc, const ParserConfig &config,
MutableArrayRef<BytecodeDialect> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
- ArrayRef<uint8_t> offsetSectionData,
+ ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
/// Parse a dialect resource handle from the resource section.
@@ -643,7 +643,7 @@ LogicalResult ResourceSectionReader::initialize(
Location fileLoc, const ParserConfig &config,
MutableArrayRef<BytecodeDialect> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
- ArrayRef<uint8_t> offsetSectionData,
+ ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
EncodingReader resourceReader(sectionData, fileLoc);
EncodingReader offsetReader(offsetSectionData, fileLoc);
@@ -684,7 +684,7 @@ LogicalResult ResourceSectionReader::initialize(
while (!offsetReader.empty()) {
BytecodeDialect *dialect;
if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
- failed(dialect->load(resourceReader, ctx)))
+ failed(dialect->load(dialectReader, ctx)))
return failure();
Dialect *loadedDialect = dialect->getLoadedDialect();
if (!loadedDialect) {
@@ -1051,7 +1051,8 @@ template <typename T>
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
EncodingReader &reader,
StringRef entryType) {
- if (failed(entry.dialect->load(reader, fileLoc.getContext())))
+ DialectReader dialectReader(*this, stringReader, resourceReader, reader);
+ if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
return failure();
// Ensure that the dialect implements the bytecode interface.
@@ -1060,12 +1061,22 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
"' does not implement the bytecode interface");
}
- // Ask the dialect to parse the entry.
- DialectReader dialectReader(*this, stringReader, resourceReader, reader);
- if constexpr (std::is_same_v<T, Type>)
- entry.entry = entry.dialect->interface->readType(dialectReader);
- else
- entry.entry = entry.dialect->interface->readAttribute(dialectReader);
+ // Ask the dialect to parse the entry. If the dialect is versioned, parse
+ // using the versioned encoding readers.
+ if (entry.dialect->loadedVersion.get()) {
+ if constexpr (std::is_same_v<T, Type>)
+ entry.entry = entry.dialect->interface->readType(
+ dialectReader, *entry.dialect->loadedVersion);
+ else
+ entry.entry = entry.dialect->interface->readAttribute(
+ dialectReader, *entry.dialect->loadedVersion);
+
+ } else {
+ if constexpr (std::is_same_v<T, Type>)
+ entry.entry = entry.dialect->interface->readType(dialectReader);
+ else
+ entry.entry = entry.dialect->interface->readAttribute(dialectReader);
+ }
return success(!!entry.entry);
}
@@ -1122,7 +1133,8 @@ private:
// Resource Section
LogicalResult
- parseResourceSection(std::optional<ArrayRef<uint8_t>> resourceData,
+ parseResourceSection(EncodingReader &reader,
+ std::optional<ArrayRef<uint8_t>> resourceData,
std::optional<ArrayRef<uint8_t>> resourceOffsetData);
//===--------------------------------------------------------------------===//
@@ -1306,7 +1318,7 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
// Process the resource section if present.
if (failed(parseResourceSection(
- sectionDatas[bytecode::Section::kResource],
+ reader, sectionDatas[bytecode::Section::kResource],
sectionDatas[bytecode::Section::kResourceOffset])))
return failure();
@@ -1326,7 +1338,8 @@ LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) {
// Validate the bytecode version.
uint64_t currentVersion = bytecode::kVersion;
- if (version < currentVersion) {
+ uint64_t minSupportedVersion = bytecode::kMinSupportedVersion;
+ if (version < minSupportedVersion) {
return reader.emitError("bytecode version ", version,
" is older than the current version of ",
currentVersion, ", and upgrade is not supported");
@@ -1342,6 +1355,36 @@ LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) {
//===----------------------------------------------------------------------===//
// Dialect Section
+LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
+ if (dialect)
+ return success();
+ Dialect *loadedDialect = ctx->getOrLoadDialect(name);
+ if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
+ return reader.emitError("dialect '")
+ << name
+ << "' is unknown. If this is intended, please call "
+ "allowUnregisteredDialects() on the MLIRContext, or use "
+ "-allow-unregistered-dialect with the MLIR tool used.";
+ }
+ dialect = loadedDialect;
+
+ // If the dialect was actually loaded, check to see if it has a bytecode
+ // interface.
+ if (loadedDialect)
+ interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
+ if (!versionBuffer.empty()) {
+ if (!interface)
+ return reader.emitError("dialect '")
+ << name
+ << "' does not implement the bytecode interface, "
+ "but found a version entry";
+ loadedVersion = interface->readVersion(reader);
+ if (!loadedVersion)
+ return failure();
+ }
+ return success();
+}
+
LogicalResult
BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) {
EncodingReader sectionReader(sectionData, fileLoc);
@@ -1353,9 +1396,34 @@ BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) {
dialects.resize(numDialects);
// Parse each of the dialects.
- for (uint64_t i = 0; i < numDialects; ++i)
- if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
+ for (uint64_t i = 0; i < numDialects; ++i) {
+ /// Before version 1, there wasn't any versioning available for dialects,
+ /// and the entryIdx represent the string itself.
+ if (version == 0) {
+ if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
+ return failure();
+ continue;
+ }
+ // Parse ID representing dialect and version.
+ uint64_t dialectNameIdx;
+ bool versionAvailable;
+ if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx,
+ versionAvailable)))
+ return failure();
+ if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
+ dialects[i].name)))
return failure();
+ if (versionAvailable) {
+ bytecode::Section::ID sectionID;
+ if (failed(
+ sectionReader.parseSection(sectionID, dialects[i].versionBuffer)))
+ return failure();
+ if (sectionID != bytecode::Section::kDialectVersions) {
+ emitError(fileLoc, "expected dialect version section");
+ return failure();
+ }
+ }
+ }
// Parse the operation names, which are grouped by dialect.
auto parseOpName = [&](BytecodeDialect *dialect) {
@@ -1379,7 +1447,11 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
// Check to see if this operation name has already been resolved. If we
// haven't, load the dialect and build the operation name.
if (!opName->opName) {
- if (failed(opName->dialect->load(reader, getContext())))
+ // Load the dialect and its version.
+ EncodingReader versionReader(opName->dialect->versionBuffer, fileLoc);
+ DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
+ versionReader);
+ if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
getContext());
@@ -1391,7 +1463,7 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
// Resource Section
LogicalResult BytecodeReader::parseResourceSection(
- std::optional<ArrayRef<uint8_t>> resourceData,
+ EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
// Ensure both sections are either present or not.
if (resourceData.has_value() != resourceOffsetData.has_value()) {
@@ -1408,9 +1480,11 @@ LogicalResult BytecodeReader::parseResourceSection(
return success();
// Initialize the resource reader with the resource sections.
+ DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
+ reader);
return resourceReader.initialize(fileLoc, config, dialects, stringReader,
*resourceData, *resourceOffsetData,
- bufferOwnerRef);
+ dialectReader, bufferOwnerRef);
}
//===----------------------------------------------------------------------===//
@@ -1442,6 +1516,18 @@ LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData,
"not all forward unresolved forward operand references");
}
+ // Resolve dialect version.
+ for (const BytecodeDialect &byteCodeDialect : dialects) {
+ // Parsing is complete, give an opportunity to each dialect to visit the
+ // IR and perform upgrades.
+ if (!byteCodeDialect.loadedVersion)
+ continue;
+ if (byteCodeDialect.interface &&
+ failed(byteCodeDialect.interface->upgradeFromVersion(
+ *moduleOp, *byteCodeDialect.loadedVersion)))
+ return failure();
+ }
+
// Verify that the parsed operations are valid.
if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp)))
return failure();