aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
diff options
context:
space:
mode:
authorMehdi Amini <joker.eph@gmail.com>2023-07-28 10:43:51 -0700
committerMehdi Amini <joker.eph@gmail.com>2023-07-28 10:44:02 -0700
commitb299ec16661f653df66cdaf161cdc5441bc9803c (patch)
tree871d39e32caf5b3cacdc546905d4ee7731b8053e /mlir/lib/Bytecode/Reader/BytecodeReader.cpp
parentbb65caf90ae1ade0ab1896c8e781cff34b34a846 (diff)
downloadllvm-b299ec16661f653df66cdaf161cdc5441bc9803c.zip
llvm-b299ec16661f653df66cdaf161cdc5441bc9803c.tar.gz
llvm-b299ec16661f653df66cdaf161cdc5441bc9803c.tar.bz2
Expose callbacks for encoding of types/attributes
[mlir] Expose a mechanism to provide a callback for encoding types and attributes in MLIR bytecode. Two callbacks are exposed, respectively, to the BytecodeWriterConfig and to the ParserConfig. At bytecode parsing/printing, clients have the ability to specify a callback to be used to optionally read/write the encoding. On failure, fallback path will execute the default parsers and printers for the dialect. Testing shows how to leverage this functionality to support back-deployment and backward-compatibility usecases when roundtripping to bytecode a client dialect with type/attributes dependencies on upstream. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D153383
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
-rw-r--r--mlir/lib/Bytecode/Reader/BytecodeReader.cpp181
1 files changed, 120 insertions, 61 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 0639baf..91e47c4 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -451,7 +451,7 @@ struct BytecodeDialect {
/// Returns failure if the dialect couldn't be loaded *and* the provided
/// context does not allow unregistered dialects. The provided reader is used
/// for error emission if necessary.
- LogicalResult load(DialectReader &reader, MLIRContext *ctx);
+ LogicalResult load(const DialectReader &reader, MLIRContext *ctx);
/// Return the loaded dialect, or nullptr if the dialect is unknown. This can
/// only be called after `load`.
@@ -505,10 +505,11 @@ struct BytecodeOperationName {
/// Parse a single dialect group encoded in the byte stream.
static LogicalResult parseDialectGrouping(
- EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects,
+ EncodingReader &reader,
+ MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
// Parse the dialect and the number of entries in the group.
- BytecodeDialect *dialect;
+ std::unique_ptr<BytecodeDialect> *dialect;
if (failed(parseEntry(reader, dialects, dialect, "dialect")))
return failure();
uint64_t numEntries;
@@ -516,7 +517,7 @@ static LogicalResult parseDialectGrouping(
return failure();
for (uint64_t i = 0; i < numEntries; ++i)
- if (failed(entryCallback(dialect)))
+ if (failed(entryCallback(dialect->get())))
return failure();
return success();
}
@@ -532,7 +533,7 @@ public:
/// Initialize the resource section reader with the given section data.
LogicalResult
initialize(Location fileLoc, const ParserConfig &config,
- MutableArrayRef<BytecodeDialect> dialects,
+ MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
@@ -682,7 +683,7 @@ parseResourceGroup(Location fileLoc, bool allowEmpty,
LogicalResult ResourceSectionReader::initialize(
Location fileLoc, const ParserConfig &config,
- MutableArrayRef<BytecodeDialect> dialects,
+ MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
@@ -731,19 +732,19 @@ LogicalResult ResourceSectionReader::initialize(
// Read the dialect resources from the bytecode.
MLIRContext *ctx = fileLoc->getContext();
while (!offsetReader.empty()) {
- BytecodeDialect *dialect;
+ std::unique_ptr<BytecodeDialect> *dialect;
if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
- failed(dialect->load(dialectReader, ctx)))
+ failed((*dialect)->load(dialectReader, ctx)))
return failure();
- Dialect *loadedDialect = dialect->getLoadedDialect();
+ Dialect *loadedDialect = (*dialect)->getLoadedDialect();
if (!loadedDialect) {
return resourceReader.emitError()
- << "dialect '" << dialect->name << "' is unknown";
+ << "dialect '" << (*dialect)->name << "' is unknown";
}
const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
if (!handler) {
return resourceReader.emitError()
- << "unexpected resources for dialect '" << dialect->name << "'";
+ << "unexpected resources for dialect '" << (*dialect)->name << "'";
}
// Ensure that each resource is declared before being processed.
@@ -753,7 +754,7 @@ LogicalResult ResourceSectionReader::initialize(
if (failed(handle)) {
return resourceReader.emitError()
<< "unknown 'resource' key '" << key << "' for dialect '"
- << dialect->name << "'";
+ << (*dialect)->name << "'";
}
dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
dialectResources.push_back(*handle);
@@ -796,15 +797,19 @@ class AttrTypeReader {
public:
AttrTypeReader(StringSectionReader &stringReader,
- ResourceSectionReader &resourceReader, Location fileLoc,
- uint64_t &bytecodeVersion)
+ ResourceSectionReader &resourceReader,
+ const llvm::StringMap<BytecodeDialect *> &dialectsMap,
+ uint64_t &bytecodeVersion, Location fileLoc,
+ const ParserConfig &config)
: stringReader(stringReader), resourceReader(resourceReader),
- fileLoc(fileLoc), bytecodeVersion(bytecodeVersion) {}
+ dialectsMap(dialectsMap), fileLoc(fileLoc),
+ bytecodeVersion(bytecodeVersion), parserConfig(config) {}
/// Initialize the attribute and type information within the reader.
- LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
- ArrayRef<uint8_t> sectionData,
- ArrayRef<uint8_t> offsetSectionData);
+ LogicalResult
+ initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
+ ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData);
/// Resolve the attribute or type at the given index. Returns nullptr on
/// failure.
@@ -878,6 +883,10 @@ private:
/// parsing custom encoded attribute/type entries.
ResourceSectionReader &resourceReader;
+ /// The map of the loaded dialects used to retrieve dialect information, such
+ /// as the dialect version.
+ const llvm::StringMap<BytecodeDialect *> &dialectsMap;
+
/// The set of attribute and type entries.
SmallVector<AttrEntry> attributes;
SmallVector<TypeEntry> types;
@@ -887,27 +896,48 @@ private:
/// Current bytecode version being used.
uint64_t &bytecodeVersion;
+
+ /// Reference to the parser configuration.
+ const ParserConfig &parserConfig;
};
class DialectReader : public DialectBytecodeReader {
public:
DialectReader(AttrTypeReader &attrTypeReader,
StringSectionReader &stringReader,
- ResourceSectionReader &resourceReader, EncodingReader &reader,
- uint64_t &bytecodeVersion)
+ ResourceSectionReader &resourceReader,
+ const llvm::StringMap<BytecodeDialect *> &dialectsMap,
+ EncodingReader &reader, uint64_t &bytecodeVersion)
: attrTypeReader(attrTypeReader), stringReader(stringReader),
- resourceReader(resourceReader), reader(reader),
- bytecodeVersion(bytecodeVersion) {}
+ resourceReader(resourceReader), dialectsMap(dialectsMap),
+ reader(reader), bytecodeVersion(bytecodeVersion) {}
- InFlightDiagnostic emitError(const Twine &msg) override {
+ InFlightDiagnostic emitError(const Twine &msg) const override {
return reader.emitError(msg);
}
+ FailureOr<const DialectVersion *>
+ getDialectVersion(StringRef dialectName) const override {
+ // First check if the dialect is available in the map.
+ auto dialectEntry = dialectsMap.find(dialectName);
+ if (dialectEntry == dialectsMap.end())
+ return failure();
+ // If the dialect was found, try to load it. This will trigger reading the
+ // bytecode version from the version buffer if it wasn't already processed.
+ // Return failure if either of those two actions could not be completed.
+ if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) ||
+ dialectEntry->getValue()->loadedVersion.get() == nullptr)
+ return failure();
+ return dialectEntry->getValue()->loadedVersion.get();
+ }
+
+ MLIRContext *getContext() const override { return getLoc().getContext(); }
+
uint64_t getBytecodeVersion() const override { return bytecodeVersion; }
- DialectReader withEncodingReader(EncodingReader &encReader) {
+ DialectReader withEncodingReader(EncodingReader &encReader) const {
return DialectReader(attrTypeReader, stringReader, resourceReader,
- encReader, bytecodeVersion);
+ dialectsMap, encReader, bytecodeVersion);
}
Location getLoc() const { return reader.getLoc(); }
@@ -1010,6 +1040,7 @@ private:
AttrTypeReader &attrTypeReader;
StringSectionReader &stringReader;
ResourceSectionReader &resourceReader;
+ const llvm::StringMap<BytecodeDialect *> &dialectsMap;
EncodingReader &reader;
uint64_t &bytecodeVersion;
};
@@ -1096,10 +1127,9 @@ private:
};
} // namespace
-LogicalResult
-AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
- ArrayRef<uint8_t> sectionData,
- ArrayRef<uint8_t> offsetSectionData) {
+LogicalResult AttrTypeReader::initialize(
+ MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
+ ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
EncodingReader offsetReader(offsetSectionData, fileLoc);
// Parse the number of attribute and type entries.
@@ -1151,6 +1181,7 @@ AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
return offsetReader.emitError(
"unexpected trailing data in the Attribute/Type offset section");
}
+
return success();
}
@@ -1216,32 +1247,54 @@ template <typename T>
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
EncodingReader &reader,
StringRef entryType) {
- DialectReader dialectReader(*this, stringReader, resourceReader, reader,
- bytecodeVersion);
+ DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
+ reader, bytecodeVersion);
if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
return failure();
+
+ if constexpr (std::is_same_v<T, Type>) {
+ // Try parsing with callbacks first if available.
+ for (const auto &callback :
+ parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) {
+ if (failed(
+ callback->read(dialectReader, entry.dialect->name, entry.entry)))
+ return failure();
+ // Early return if parsing was successful.
+ if (!!entry.entry)
+ return success();
+
+ // Reset the reader if we failed to parse, so we can fall through the
+ // other parsing functions.
+ reader = EncodingReader(entry.data, reader.getLoc());
+ }
+ } else {
+ // Try parsing with callbacks first if available.
+ for (const auto &callback :
+ parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) {
+ if (failed(
+ callback->read(dialectReader, entry.dialect->name, entry.entry)))
+ return failure();
+ // Early return if parsing was successful.
+ if (!!entry.entry)
+ return success();
+
+ // Reset the reader if we failed to parse, so we can fall through the
+ // other parsing functions.
+ reader = EncodingReader(entry.data, reader.getLoc());
+ }
+ }
+
// Ensure that the dialect implements the bytecode interface.
if (!entry.dialect->interface) {
return reader.emitError("dialect '", entry.dialect->name,
"' does not implement the bytecode interface");
}
- // Ask the dialect to parse the entry. If the dialect is versioned, parse
- // using the versioned encoding readers.
- if (entry.dialect->loadedVersion.get()) {
- if constexpr (std::is_same_v<T, Type>)
- entry.entry = entry.dialect->interface->readType(
- dialectReader, *entry.dialect->loadedVersion);
- else
- entry.entry = entry.dialect->interface->readAttribute(
- dialectReader, *entry.dialect->loadedVersion);
+ if constexpr (std::is_same_v<T, Type>)
+ entry.entry = entry.dialect->interface->readType(dialectReader);
+ else
+ entry.entry = entry.dialect->interface->readAttribute(dialectReader);
- } else {
- if constexpr (std::is_same_v<T, Type>)
- entry.entry = entry.dialect->interface->readType(dialectReader);
- else
- entry.entry = entry.dialect->interface->readAttribute(dialectReader);
- }
return success(!!entry.entry);
}
@@ -1262,7 +1315,8 @@ public:
llvm::MemoryBufferRef buffer,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
: config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
- attrTypeReader(stringReader, resourceReader, fileLoc, version),
+ attrTypeReader(stringReader, resourceReader, dialectsMap, version,
+ fileLoc, config),
// Use the builtin unrealized conversion cast operation to represent
// forward references to values that aren't yet defined.
forwardRefOpState(UnknownLoc::get(config.getContext()),
@@ -1528,7 +1582,8 @@ private:
StringRef producer;
/// The table of IR units referenced within the bytecode file.
- SmallVector<BytecodeDialect> dialects;
+ SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
+ llvm::StringMap<BytecodeDialect *> dialectsMap;
SmallVector<BytecodeOperationName> opNames;
/// The reader used to process resources within the bytecode.
@@ -1675,7 +1730,8 @@ LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
//===----------------------------------------------------------------------===//
// Dialect Section
-LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
+LogicalResult BytecodeDialect::load(const DialectReader &reader,
+ MLIRContext *ctx) {
if (dialect)
return success();
Dialect *loadedDialect = ctx->getOrLoadDialect(name);
@@ -1719,13 +1775,15 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
// Parse each of the dialects.
for (uint64_t i = 0; i < numDialects; ++i) {
+ dialects[i] = std::make_unique<BytecodeDialect>();
/// Before version kDialectVersioning, there wasn't any versioning available
/// for dialects, and the entryIdx represent the string itself.
if (version < bytecode::kDialectVersioning) {
- if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
+ if (failed(stringReader.parseString(sectionReader, dialects[i]->name)))
return failure();
continue;
}
+
// Parse ID representing dialect and version.
uint64_t dialectNameIdx;
bool versionAvailable;
@@ -1733,18 +1791,19 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
versionAvailable)))
return failure();
if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
- dialects[i].name)))
+ dialects[i]->name)))
return failure();
if (versionAvailable) {
bytecode::Section::ID sectionID;
- if (failed(
- sectionReader.parseSection(sectionID, dialects[i].versionBuffer)))
+ if (failed(sectionReader.parseSection(sectionID,
+ dialects[i]->versionBuffer)))
return failure();
if (sectionID != bytecode::Section::kDialectVersions) {
emitError(fileLoc, "expected dialect version section");
return failure();
}
}
+ dialectsMap[dialects[i]->name] = dialects[i].get();
}
// Parse the operation names, which are grouped by dialect.
@@ -1792,7 +1851,7 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
if (!opName->opName) {
// Load the dialect and its version.
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
- reader, version);
+ dialectsMap, reader, version);
if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
// If the opName is empty, this is because we use to accept names such as
@@ -1835,7 +1894,7 @@ LogicalResult BytecodeReader::Impl::parseResourceSection(
// Initialize the resource reader with the resource sections.
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
- reader, version);
+ dialectsMap, reader, version);
return resourceReader.initialize(fileLoc, config, dialects, stringReader,
*resourceData, *resourceOffsetData,
dialectReader, bufferOwnerRef);
@@ -2036,14 +2095,14 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
"parsed use-list orders were invalid and could not be applied");
// Resolve dialect version.
- for (const BytecodeDialect &byteCodeDialect : dialects) {
+ for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
// Parsing is complete, give an opportunity to each dialect to visit the
// IR and perform upgrades.
- if (!byteCodeDialect.loadedVersion)
+ if (!byteCodeDialect->loadedVersion)
continue;
- if (byteCodeDialect.interface &&
- failed(byteCodeDialect.interface->upgradeFromVersion(
- *moduleOp, *byteCodeDialect.loadedVersion)))
+ if (byteCodeDialect->interface &&
+ failed(byteCodeDialect->interface->upgradeFromVersion(
+ *moduleOp, *byteCodeDialect->loadedVersion)))
return failure();
}
@@ -2196,7 +2255,7 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
// interface and control the serialization.
if (wasRegistered) {
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
- reader, version);
+ dialectsMap, reader, version);
if (failed(
propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
return failure();