aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
diff options
context:
space:
mode:
authorMehdi Amini <joker.eph@gmail.com>2023-07-28 16:44:25 -0700
committerMehdi Amini <joker.eph@gmail.com>2023-07-28 16:45:42 -0700
commitb86a13211fcd84bfae39066b51f9f079f970cea8 (patch)
treefeab13aad3136d79f530a9f45b64f3d5c83054ea /mlir/lib/Bytecode/Reader/BytecodeReader.cpp
parentb08d358e8ac83cf31c007a58c814bf2dca03d591 (diff)
downloadllvm-b86a13211fcd84bfae39066b51f9f079f970cea8.zip
llvm-b86a13211fcd84bfae39066b51f9f079f970cea8.tar.gz
llvm-b86a13211fcd84bfae39066b51f9f079f970cea8.tar.bz2
Revert "Expose callbacks for encoding of types/attributes"
This reverts commit b299ec16661f653df66cdaf161cdc5441bc9803c. The authorship informations were incorrect.
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
-rw-r--r--mlir/lib/Bytecode/Reader/BytecodeReader.cpp181
1 files changed, 61 insertions, 120 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 91e47c4..0639baf 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -451,7 +451,7 @@ struct BytecodeDialect {
/// Returns failure if the dialect couldn't be loaded *and* the provided
/// context does not allow unregistered dialects. The provided reader is used
/// for error emission if necessary.
- LogicalResult load(const DialectReader &reader, MLIRContext *ctx);
+ LogicalResult load(DialectReader &reader, MLIRContext *ctx);
/// Return the loaded dialect, or nullptr if the dialect is unknown. This can
/// only be called after `load`.
@@ -505,11 +505,10 @@ struct BytecodeOperationName {
/// Parse a single dialect group encoded in the byte stream.
static LogicalResult parseDialectGrouping(
- EncodingReader &reader,
- MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
+ EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects,
function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
// Parse the dialect and the number of entries in the group.
- std::unique_ptr<BytecodeDialect> *dialect;
+ BytecodeDialect *dialect;
if (failed(parseEntry(reader, dialects, dialect, "dialect")))
return failure();
uint64_t numEntries;
@@ -517,7 +516,7 @@ static LogicalResult parseDialectGrouping(
return failure();
for (uint64_t i = 0; i < numEntries; ++i)
- if (failed(entryCallback(dialect->get())))
+ if (failed(entryCallback(dialect)))
return failure();
return success();
}
@@ -533,7 +532,7 @@ public:
/// Initialize the resource section reader with the given section data.
LogicalResult
initialize(Location fileLoc, const ParserConfig &config,
- MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
+ MutableArrayRef<BytecodeDialect> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
@@ -683,7 +682,7 @@ parseResourceGroup(Location fileLoc, bool allowEmpty,
LogicalResult ResourceSectionReader::initialize(
Location fileLoc, const ParserConfig &config,
- MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
+ MutableArrayRef<BytecodeDialect> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
@@ -732,19 +731,19 @@ LogicalResult ResourceSectionReader::initialize(
// Read the dialect resources from the bytecode.
MLIRContext *ctx = fileLoc->getContext();
while (!offsetReader.empty()) {
- std::unique_ptr<BytecodeDialect> *dialect;
+ BytecodeDialect *dialect;
if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
- failed((*dialect)->load(dialectReader, ctx)))
+ failed(dialect->load(dialectReader, ctx)))
return failure();
- Dialect *loadedDialect = (*dialect)->getLoadedDialect();
+ Dialect *loadedDialect = dialect->getLoadedDialect();
if (!loadedDialect) {
return resourceReader.emitError()
- << "dialect '" << (*dialect)->name << "' is unknown";
+ << "dialect '" << dialect->name << "' is unknown";
}
const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
if (!handler) {
return resourceReader.emitError()
- << "unexpected resources for dialect '" << (*dialect)->name << "'";
+ << "unexpected resources for dialect '" << dialect->name << "'";
}
// Ensure that each resource is declared before being processed.
@@ -754,7 +753,7 @@ LogicalResult ResourceSectionReader::initialize(
if (failed(handle)) {
return resourceReader.emitError()
<< "unknown 'resource' key '" << key << "' for dialect '"
- << (*dialect)->name << "'";
+ << dialect->name << "'";
}
dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
dialectResources.push_back(*handle);
@@ -797,19 +796,15 @@ class AttrTypeReader {
public:
AttrTypeReader(StringSectionReader &stringReader,
- ResourceSectionReader &resourceReader,
- const llvm::StringMap<BytecodeDialect *> &dialectsMap,
- uint64_t &bytecodeVersion, Location fileLoc,
- const ParserConfig &config)
+ ResourceSectionReader &resourceReader, Location fileLoc,
+ uint64_t &bytecodeVersion)
: stringReader(stringReader), resourceReader(resourceReader),
- dialectsMap(dialectsMap), fileLoc(fileLoc),
- bytecodeVersion(bytecodeVersion), parserConfig(config) {}
+ fileLoc(fileLoc), bytecodeVersion(bytecodeVersion) {}
/// Initialize the attribute and type information within the reader.
- LogicalResult
- initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
- ArrayRef<uint8_t> sectionData,
- ArrayRef<uint8_t> offsetSectionData);
+ LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
+ ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData);
/// Resolve the attribute or type at the given index. Returns nullptr on
/// failure.
@@ -883,10 +878,6 @@ private:
/// parsing custom encoded attribute/type entries.
ResourceSectionReader &resourceReader;
- /// The map of the loaded dialects used to retrieve dialect information, such
- /// as the dialect version.
- const llvm::StringMap<BytecodeDialect *> &dialectsMap;
-
/// The set of attribute and type entries.
SmallVector<AttrEntry> attributes;
SmallVector<TypeEntry> types;
@@ -896,48 +887,27 @@ private:
/// Current bytecode version being used.
uint64_t &bytecodeVersion;
-
- /// Reference to the parser configuration.
- const ParserConfig &parserConfig;
};
class DialectReader : public DialectBytecodeReader {
public:
DialectReader(AttrTypeReader &attrTypeReader,
StringSectionReader &stringReader,
- ResourceSectionReader &resourceReader,
- const llvm::StringMap<BytecodeDialect *> &dialectsMap,
- EncodingReader &reader, uint64_t &bytecodeVersion)
+ ResourceSectionReader &resourceReader, EncodingReader &reader,
+ uint64_t &bytecodeVersion)
: attrTypeReader(attrTypeReader), stringReader(stringReader),
- resourceReader(resourceReader), dialectsMap(dialectsMap),
- reader(reader), bytecodeVersion(bytecodeVersion) {}
+ resourceReader(resourceReader), reader(reader),
+ bytecodeVersion(bytecodeVersion) {}
- InFlightDiagnostic emitError(const Twine &msg) const override {
+ InFlightDiagnostic emitError(const Twine &msg) override {
return reader.emitError(msg);
}
- FailureOr<const DialectVersion *>
- getDialectVersion(StringRef dialectName) const override {
- // First check if the dialect is available in the map.
- auto dialectEntry = dialectsMap.find(dialectName);
- if (dialectEntry == dialectsMap.end())
- return failure();
- // If the dialect was found, try to load it. This will trigger reading the
- // bytecode version from the version buffer if it wasn't already processed.
- // Return failure if either of those two actions could not be completed.
- if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) ||
- dialectEntry->getValue()->loadedVersion.get() == nullptr)
- return failure();
- return dialectEntry->getValue()->loadedVersion.get();
- }
-
- MLIRContext *getContext() const override { return getLoc().getContext(); }
-
uint64_t getBytecodeVersion() const override { return bytecodeVersion; }
- DialectReader withEncodingReader(EncodingReader &encReader) const {
+ DialectReader withEncodingReader(EncodingReader &encReader) {
return DialectReader(attrTypeReader, stringReader, resourceReader,
- dialectsMap, encReader, bytecodeVersion);
+ encReader, bytecodeVersion);
}
Location getLoc() const { return reader.getLoc(); }
@@ -1040,7 +1010,6 @@ private:
AttrTypeReader &attrTypeReader;
StringSectionReader &stringReader;
ResourceSectionReader &resourceReader;
- const llvm::StringMap<BytecodeDialect *> &dialectsMap;
EncodingReader &reader;
uint64_t &bytecodeVersion;
};
@@ -1127,9 +1096,10 @@ private:
};
} // namespace
-LogicalResult AttrTypeReader::initialize(
- MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
- ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
+LogicalResult
+AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
+ ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData) {
EncodingReader offsetReader(offsetSectionData, fileLoc);
// Parse the number of attribute and type entries.
@@ -1181,7 +1151,6 @@ LogicalResult AttrTypeReader::initialize(
return offsetReader.emitError(
"unexpected trailing data in the Attribute/Type offset section");
}
-
return success();
}
@@ -1247,54 +1216,32 @@ template <typename T>
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
EncodingReader &reader,
StringRef entryType) {
- DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
- reader, bytecodeVersion);
+ DialectReader dialectReader(*this, stringReader, resourceReader, reader,
+ bytecodeVersion);
if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
return failure();
-
- if constexpr (std::is_same_v<T, Type>) {
- // Try parsing with callbacks first if available.
- for (const auto &callback :
- parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) {
- if (failed(
- callback->read(dialectReader, entry.dialect->name, entry.entry)))
- return failure();
- // Early return if parsing was successful.
- if (!!entry.entry)
- return success();
-
- // Reset the reader if we failed to parse, so we can fall through the
- // other parsing functions.
- reader = EncodingReader(entry.data, reader.getLoc());
- }
- } else {
- // Try parsing with callbacks first if available.
- for (const auto &callback :
- parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) {
- if (failed(
- callback->read(dialectReader, entry.dialect->name, entry.entry)))
- return failure();
- // Early return if parsing was successful.
- if (!!entry.entry)
- return success();
-
- // Reset the reader if we failed to parse, so we can fall through the
- // other parsing functions.
- reader = EncodingReader(entry.data, reader.getLoc());
- }
- }
-
// Ensure that the dialect implements the bytecode interface.
if (!entry.dialect->interface) {
return reader.emitError("dialect '", entry.dialect->name,
"' does not implement the bytecode interface");
}
- if constexpr (std::is_same_v<T, Type>)
- entry.entry = entry.dialect->interface->readType(dialectReader);
- else
- entry.entry = entry.dialect->interface->readAttribute(dialectReader);
+ // Ask the dialect to parse the entry. If the dialect is versioned, parse
+ // using the versioned encoding readers.
+ if (entry.dialect->loadedVersion.get()) {
+ if constexpr (std::is_same_v<T, Type>)
+ entry.entry = entry.dialect->interface->readType(
+ dialectReader, *entry.dialect->loadedVersion);
+ else
+ entry.entry = entry.dialect->interface->readAttribute(
+ dialectReader, *entry.dialect->loadedVersion);
+ } else {
+ if constexpr (std::is_same_v<T, Type>)
+ entry.entry = entry.dialect->interface->readType(dialectReader);
+ else
+ entry.entry = entry.dialect->interface->readAttribute(dialectReader);
+ }
return success(!!entry.entry);
}
@@ -1315,8 +1262,7 @@ public:
llvm::MemoryBufferRef buffer,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
: config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
- attrTypeReader(stringReader, resourceReader, dialectsMap, version,
- fileLoc, config),
+ attrTypeReader(stringReader, resourceReader, fileLoc, version),
// Use the builtin unrealized conversion cast operation to represent
// forward references to values that aren't yet defined.
forwardRefOpState(UnknownLoc::get(config.getContext()),
@@ -1582,8 +1528,7 @@ private:
StringRef producer;
/// The table of IR units referenced within the bytecode file.
- SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
- llvm::StringMap<BytecodeDialect *> dialectsMap;
+ SmallVector<BytecodeDialect> dialects;
SmallVector<BytecodeOperationName> opNames;
/// The reader used to process resources within the bytecode.
@@ -1730,8 +1675,7 @@ LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
//===----------------------------------------------------------------------===//
// Dialect Section
-LogicalResult BytecodeDialect::load(const DialectReader &reader,
- MLIRContext *ctx) {
+LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
if (dialect)
return success();
Dialect *loadedDialect = ctx->getOrLoadDialect(name);
@@ -1775,15 +1719,13 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
// Parse each of the dialects.
for (uint64_t i = 0; i < numDialects; ++i) {
- dialects[i] = std::make_unique<BytecodeDialect>();
/// Before version kDialectVersioning, there wasn't any versioning available
/// for dialects, and the entryIdx represent the string itself.
if (version < bytecode::kDialectVersioning) {
- if (failed(stringReader.parseString(sectionReader, dialects[i]->name)))
+ if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
return failure();
continue;
}
-
// Parse ID representing dialect and version.
uint64_t dialectNameIdx;
bool versionAvailable;
@@ -1791,19 +1733,18 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
versionAvailable)))
return failure();
if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
- dialects[i]->name)))
+ dialects[i].name)))
return failure();
if (versionAvailable) {
bytecode::Section::ID sectionID;
- if (failed(sectionReader.parseSection(sectionID,
- dialects[i]->versionBuffer)))
+ if (failed(
+ sectionReader.parseSection(sectionID, dialects[i].versionBuffer)))
return failure();
if (sectionID != bytecode::Section::kDialectVersions) {
emitError(fileLoc, "expected dialect version section");
return failure();
}
}
- dialectsMap[dialects[i]->name] = dialects[i].get();
}
// Parse the operation names, which are grouped by dialect.
@@ -1851,7 +1792,7 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
if (!opName->opName) {
// Load the dialect and its version.
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
- dialectsMap, reader, version);
+ reader, version);
if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
// If the opName is empty, this is because we use to accept names such as
@@ -1894,7 +1835,7 @@ LogicalResult BytecodeReader::Impl::parseResourceSection(
// Initialize the resource reader with the resource sections.
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
- dialectsMap, reader, version);
+ reader, version);
return resourceReader.initialize(fileLoc, config, dialects, stringReader,
*resourceData, *resourceOffsetData,
dialectReader, bufferOwnerRef);
@@ -2095,14 +2036,14 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
"parsed use-list orders were invalid and could not be applied");
// Resolve dialect version.
- for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
+ for (const BytecodeDialect &byteCodeDialect : dialects) {
// Parsing is complete, give an opportunity to each dialect to visit the
// IR and perform upgrades.
- if (!byteCodeDialect->loadedVersion)
+ if (!byteCodeDialect.loadedVersion)
continue;
- if (byteCodeDialect->interface &&
- failed(byteCodeDialect->interface->upgradeFromVersion(
- *moduleOp, *byteCodeDialect->loadedVersion)))
+ if (byteCodeDialect.interface &&
+ failed(byteCodeDialect.interface->upgradeFromVersion(
+ *moduleOp, *byteCodeDialect.loadedVersion)))
return failure();
}
@@ -2255,7 +2196,7 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
// interface and control the serialization.
if (wasRegistered) {
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
- dialectsMap, reader, version);
+ reader, version);
if (failed(
propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
return failure();