aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/docs/BytecodeFormat.md70
-rw-r--r--mlir/include/mlir/Bytecode/BytecodeImplementation.h21
-rw-r--r--mlir/include/mlir/Bytecode/BytecodeWriter.h48
-rw-r--r--mlir/include/mlir/IR/AsmState.h14
-rw-r--r--mlir/lib/AsmParser/Parser.cpp8
-rw-r--r--mlir/lib/Bytecode/Encoding.h12
-rw-r--r--mlir/lib/Bytecode/Reader/BytecodeReader.cpp381
-rw-r--r--mlir/lib/Bytecode/Writer/BytecodeWriter.cpp234
-rw-r--r--mlir/lib/Bytecode/Writer/IRNumbering.cpp117
-rw-r--r--mlir/lib/Bytecode/Writer/IRNumbering.h48
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp12
-rw-r--r--mlir/lib/IR/BuiltinDialectBytecode.cpp47
-rw-r--r--mlir/test/Bytecode/invalid/invalid-structure.mlir2
-rw-r--r--mlir/test/Bytecode/resources.mlir27
14 files changed, 994 insertions, 47 deletions
diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
index fc33286..d07b461 100644
--- a/mlir/docs/BytecodeFormat.md
+++ b/mlir/docs/BytecodeFormat.md
@@ -89,17 +89,23 @@ Strings are blobs of characters with an associated length.
```
section {
- id: byte
- length: varint
+ idAndIsAligned: byte // id | (hasAlign << 7)
+ length: varint,
+
+ alignment: varint?,
+ padding: byte[], // Padding bytes are always `0xCB`.
+
+ data: byte[]
}
```
-Sections are a mechanism for grouping data within the bytecode. The enable
+Sections are a mechanism for grouping data within the bytecode. They enable
delayed processing, which is useful for out-of-order processing of data,
-lazy-loading, and more. Each section contains a Section ID and a length (which
-allowing for skipping over the section).
-
-TODO: Sections should also carry an optional alignment. Add this when necessary.
+lazy-loading, and more. Each section contains a Section ID, whose high bit
+indicates if the section has alignment requirements, a length (which allows for
+skipping over the section), and an optional alignment. When an alignment is
+present, a variable number of padding bytes (0xCB) may appear before the section
+data. The alignment of a section must be a power of 2.
## MLIR Encoding
@@ -244,6 +250,56 @@ encountered an attribute or type of a given dialect, it doesn't encode any
further information. As such, a common encoding idiom is to use a leading
`varint` code to indicate how the attribute or type was encoded.
+### Resource Section
+
+Resources are encoded using two [sections](#sections), one section
+(`resource_section`) containing the actual encoded representation, and another
+section (`resource_offset_section`) containing the offsets of each encoded
+resource into the previous section.
+
+```
+resource_section {
+ resources: resource[]
+}
+resource {
+ value: resource_bool | resource_string | resource_blob
+}
+resource_bool {
+ value: byte
+}
+resource_string {
+ value: varint
+}
+resource_blob {
+ alignment: varint,
+ size: varint,
+ padding: byte[],
+ blob: byte[]
+}
+
+resource_offset_section {
+ numExternalResourceGroups: varint,
+ resourceGroups: resource_group[]
+}
+resource_group {
+ key: varint,
+ numResources: varint,
+ resources: resource_info[]
+}
+resource_info {
+ key: varint,
+ size: varint
+ kind: byte,
+}
+```
+
+Resources are grouped by the provider, either an external entity or a dialect,
+with each `resource_group` in the offset section containing the corresponding
+provider, number of elements, and info for each element within the group. For
+each element, we record the key, the value kind, and the encoded size. We avoid
+using the direct offset into the `resource_section`, as a smaller relative
+offsets provides more effective compression.
+
### IR Section
The IR section contains the encoded form of operations within the bytecode.
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 845c790..7607f2a 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -18,6 +18,7 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/OpImplementation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/Twine.h"
@@ -105,6 +106,18 @@ public:
<< ", but got: " << baseResult;
}
+ /// Read a handle to a dialect resource.
+ template <typename ResourceT>
+ FailureOr<ResourceT> readResourceHandle() {
+ FailureOr<AsmDialectResourceHandle> handle = readResourceHandle();
+ if (failed(handle))
+ return failure();
+ if (auto *result = dyn_cast<ResourceT>(&*handle))
+ return std::move(*result);
+ return emitError() << "provided resource handle differs from the "
+ "expected resource type";
+ }
+
//===--------------------------------------------------------------------===//
// Primitives
//===--------------------------------------------------------------------===//
@@ -129,6 +142,10 @@ public:
/// Read a string from the bytecode.
virtual LogicalResult readString(StringRef &result) = 0;
+
+private:
+ /// Read a handle to a dialect resource.
+ virtual FailureOr<AsmDialectResourceHandle> readResourceHandle() = 0;
};
//===----------------------------------------------------------------------===//
@@ -171,6 +188,10 @@ public:
writeList(types, [this](T type) { writeType(type); });
}
+ /// Write the given handle to a dialect resource.
+ virtual void
+ writeResourceHandle(const AsmDialectResourceHandle &resource) = 0;
+
//===--------------------------------------------------------------------===//
// Primitives
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
index 6b86e01..cd1d8c7 100644
--- a/mlir/include/mlir/Bytecode/BytecodeWriter.h
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -13,23 +13,59 @@
#ifndef MLIR_BYTECODE_BYTECODEWRITER_H
#define MLIR_BYTECODE_BYTECODEWRITER_H
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/StringRef.h"
+#include "mlir/IR/AsmState.h"
namespace mlir {
class Operation;
+/// This class contains the configuration used for the bytecode writer. It
+/// controls various aspects of bytecode generation, and contains all of the
+/// various bytecode writer hooks.
+class BytecodeWriterConfig {
+public:
+ /// `producer` is an optional string that can be used to identify the producer
+ /// of the bytecode when reading. It has no functional effect on the bytecode
+ /// serialization.
+ BytecodeWriterConfig(StringRef producer = "MLIR" LLVM_VERSION_STRING);
+ ~BytecodeWriterConfig();
+
+ /// An internal implementation class that contains the state of the
+ /// configuration.
+ struct Impl;
+
+ /// Return an instance of the internal implementation.
+ const Impl &getImpl() const { return *impl; }
+
+ //===--------------------------------------------------------------------===//
+ // Resources
+ //===--------------------------------------------------------------------===//
+
+ /// Attach the given resource printer to the writer configuration.
+ void attachResourcePrinter(std::unique_ptr<AsmResourcePrinter> printer);
+
+ /// Attach an resource printer, in the form of a callable, to the
+ /// configuration.
+ template <typename CallableT>
+ std::enable_if_t<std::is_convertible<
+ CallableT, function_ref<void(Operation *, AsmResourceBuilder &)>>::value>
+ attachResourcePrinter(StringRef name, CallableT &&printFn) {
+ attachResourcePrinter(AsmResourcePrinter::fromCallable(
+ name, std::forward<CallableT>(printFn)));
+ }
+
+private:
+ /// A pointer to allocated storage for the impl state.
+ std::unique_ptr<Impl> impl;
+};
+
//===----------------------------------------------------------------------===//
// Entry Points
//===----------------------------------------------------------------------===//
/// Write the bytecode for the given operation to the provided output stream.
/// For streams where it matters, the given stream should be in "binary" mode.
-/// `producer` is an optional string that can be used to identify the producer
-/// of the bytecode when reading. It has no functional effect on the bytecode
-/// serialization.
void writeBytecodeToFile(Operation *op, raw_ostream &os,
- StringRef producer = "MLIR" LLVM_VERSION_STRING);
+ const BytecodeWriterConfig &config = {});
} // namespace mlir
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index d3ef630..3ff4cfdf 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -262,6 +262,17 @@ public:
}
};
+/// This enum represents the different kinds of resource values.
+enum class AsmResourceEntryKind {
+ /// A blob of data with an accompanying alignment.
+ Blob,
+ /// A boolean value.
+ Bool,
+ /// A string value.
+ String,
+};
+StringRef toString(AsmResourceEntryKind kind);
+
/// This class represents a single parsed resource entry.
class AsmParsedResourceEntry {
public:
@@ -273,6 +284,9 @@ public:
/// Emit an error at the location of this entry.
virtual InFlightDiagnostic emitError() const = 0;
+ /// Return the kind of this value.
+ virtual AsmResourceEntryKind getKind() const = 0;
+
/// Parse the resource entry represented by a boolean. Returns failure if the
/// entry does not correspond to a bool.
virtual FailureOr<bool> parseAsBool() const = 0;
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index a48a258..9e4e11e 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -2344,6 +2344,14 @@ public:
InFlightDiagnostic emitError() const final { return p.emitError(keyLoc); }
+ AsmResourceEntryKind getKind() const final {
+ if (value.isAny(Token::kw_true, Token::kw_false))
+ return AsmResourceEntryKind::Bool;
+ return value.getSpelling().startswith("\"0x")
+ ? AsmResourceEntryKind::Blob
+ : AsmResourceEntryKind::String;
+ }
+
FailureOr<bool> parseAsBool() const final {
if (value.is(Token::kw_true))
return true;
diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h
index 3b8de65..ee1789f 100644
--- a/mlir/lib/Bytecode/Encoding.h
+++ b/mlir/lib/Bytecode/Encoding.h
@@ -25,6 +25,9 @@ namespace bytecode {
enum {
/// The current bytecode version.
kVersion = 0,
+
+ /// An arbitrary value used to fill alignment padding.
+ kAlignmentByte = 0xCB,
};
//===----------------------------------------------------------------------===//
@@ -51,8 +54,15 @@ enum ID : uint8_t {
/// and their nested regions/operations.
kIR = 4,
+ /// This section contains the resources of the bytecode.
+ kResource = 5,
+
+ /// This section contains the offsets of resources within the Resource
+ /// section.
+ kResourceOffset = 6,
+
/// The total number of section types.
- kNumSections = 5,
+ kNumSections = 7,
};
} // namespace Section
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 26510a8..5e10dfa 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -20,6 +20,7 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SaveAndRestore.h"
@@ -40,11 +41,32 @@ static std::string toString(bytecode::Section::ID sectionID) {
return "AttrTypeOffset (3)";
case bytecode::Section::kIR:
return "IR (4)";
+ case bytecode::Section::kResource:
+ return "Resource (5)";
+ case bytecode::Section::kResourceOffset:
+ return "ResourceOffset (6)";
default:
return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
}
}
+/// Returns true if the given top-level section ID is optional.
+static bool isSectionOptional(bytecode::Section::ID sectionID) {
+ switch (sectionID) {
+ case bytecode::Section::kString:
+ case bytecode::Section::kDialect:
+ case bytecode::Section::kAttrType:
+ case bytecode::Section::kAttrTypeOffset:
+ case bytecode::Section::kIR:
+ return false;
+ case bytecode::Section::kResource:
+ case bytecode::Section::kResourceOffset:
+ return true;
+ default:
+ llvm_unreachable("unknown section ID");
+ }
+}
+
//===----------------------------------------------------------------------===//
// EncodingReader
//===----------------------------------------------------------------------===//
@@ -65,11 +87,34 @@ public:
/// Returns the remaining size of the bytecode.
size_t size() const { return dataEnd - dataIt; }
+ /// Align the current reader position to the specified alignment.
+ LogicalResult alignTo(unsigned alignment) {
+ if (!llvm::isPowerOf2_32(alignment))
+ return emitError("expected alignment to be a power-of-two");
+
+ // Shift the reader position to the next alignment boundary.
+ while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) {
+ uint8_t padding;
+ if (failed(parseByte(padding)))
+ return failure();
+ if (padding != bytecode::kAlignmentByte) {
+ return emitError("expected alignment byte (0xCB), but got: '0x" +
+ llvm::utohexstr(padding) + "'");
+ }
+ }
+
+ // TODO: Check that the current data pointer is actually at the expected
+ // alignment.
+
+ return success();
+ }
+
/// Emit an error using the given arguments.
template <typename... Args>
InFlightDiagnostic emitError(Args &&...args) const {
return ::emitError(fileLoc).append(std::forward<Args>(args)...);
}
+ InFlightDiagnostic emitError() const { return ::emitError(fileLoc); }
/// Parse a single byte from the stream.
template <typename T>
@@ -101,6 +146,17 @@ public:
return success();
}
+ /// Parse an aligned blob of data, where the alignment was encoded alongside
+ /// the data.
+ LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data,
+ uint64_t &alignment) {
+ uint64_t dataSize;
+ if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) ||
+ failed(alignTo(alignment)))
+ return failure();
+ return parseBytes(dataSize, data);
+ }
+
/// Parse a variable length encoded integer from the byte stream. The first
/// encoded byte contains a prefix in the low bits indicating the encoded
/// length of the value. This length prefix is a bit sequence of '0's followed
@@ -177,13 +233,31 @@ public:
/// contents of the section in `sectionData`.
LogicalResult parseSection(bytecode::Section::ID &sectionID,
ArrayRef<uint8_t> &sectionData) {
+ uint8_t sectionIDAndHasAlignment;
uint64_t length;
- if (failed(parseByte(sectionID)) || failed(parseVarInt(length)))
+ if (failed(parseByte(sectionIDAndHasAlignment)) ||
+ failed(parseVarInt(length)))
return failure();
+
+ // Extract the section ID and whether the section is aligned. The high bit
+ // of the ID is the alignment flag.
+ sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment &
+ 0b01111111);
+ bool hasAlignment = sectionIDAndHasAlignment & 0b10000000;
+
+ // Check that the section is actually valid before trying to process its
+ // data.
if (sectionID >= bytecode::Section::kNumSections)
return emitError("invalid section ID: ", unsigned(sectionID));
- // Parse the actua section data now that we have its length.
+ // Process the section alignment if present.
+ if (hasAlignment) {
+ uint64_t alignment;
+ if (failed(parseVarInt(alignment)) || failed(alignTo(alignment)))
+ return failure();
+ }
+
+ // Parse the actual section data.
return parseBytes(static_cast<size_t>(length), sectionData);
}
@@ -346,6 +420,14 @@ struct BytecodeDialect {
return success();
}
+ /// Return the loaded dialect, or nullptr if the dialect is unknown. This can
+ /// only be called after `load`.
+ Dialect *getLoadedDialect() const {
+ assert(dialect &&
+ "expected `load` to be invoked before `getLoadedDialect`");
+ return *dialect;
+ }
+
/// The loaded dialect entry. This field is None if we haven't attempted to
/// load, nullptr if we failed to load, otherwise the loaded dialect.
Optional<Dialect *> dialect;
@@ -394,6 +476,225 @@ static LogicalResult parseDialectGrouping(
}
//===----------------------------------------------------------------------===//
+// ResourceSectionReader
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class is used to read the resource section from the bytecode.
+class ResourceSectionReader {
+public:
+ /// Initialize the resource section reader with the given section data.
+ LogicalResult initialize(Location fileLoc, const ParserConfig &config,
+ MutableArrayRef<BytecodeDialect> dialects,
+ StringSectionReader &stringReader,
+ ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData);
+
+ /// Parse a dialect resource handle from the resource section.
+ LogicalResult parseResourceHandle(EncodingReader &reader,
+ AsmDialectResourceHandle &result) {
+ return parseEntry(reader, dialectResources, result, "resource handle");
+ }
+
+private:
+ /// The table of dialect resources within the bytecode file.
+ SmallVector<AsmDialectResourceHandle> dialectResources;
+};
+
+class ParsedResourceEntry : public AsmParsedResourceEntry {
+public:
+ ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind,
+ EncodingReader &reader, StringSectionReader &stringReader)
+ : key(key), kind(kind), reader(reader), stringReader(stringReader) {}
+ ~ParsedResourceEntry() override = default;
+
+ StringRef getKey() const final { return key; }
+
+ InFlightDiagnostic emitError() const final { return reader.emitError(); }
+
+ AsmResourceEntryKind getKind() const final { return kind; }
+
+ FailureOr<bool> parseAsBool() const final {
+ if (kind != AsmResourceEntryKind::Bool)
+ return emitError() << "expected a bool resource entry, but found a "
+ << toString(kind) << " entry instead";
+
+ bool value;
+ if (failed(reader.parseByte(value)))
+ return failure();
+ return value;
+ }
+ FailureOr<std::string> parseAsString() const final {
+ if (kind != AsmResourceEntryKind::String)
+ return emitError() << "expected a string resource entry, but found a "
+ << toString(kind) << " entry instead";
+
+ StringRef string;
+ if (failed(stringReader.parseString(reader, string)))
+ return failure();
+ return string.str();
+ }
+
+ FailureOr<AsmResourceBlob>
+ parseAsBlob(BlobAllocatorFn allocator) const final {
+ if (kind != AsmResourceEntryKind::Blob)
+ return emitError() << "expected a blob resource entry, but found a "
+ << toString(kind) << " entry instead";
+
+ ArrayRef<uint8_t> data;
+ uint64_t alignment;
+ if (failed(reader.parseBlobAndAlignment(data, alignment)))
+ return failure();
+
+ // Allocate memory for the blob using the provided allocator and copy the
+ // data into it.
+ // FIXME: If the current holder of the bytecode can ensure its lifetime
+ // (e.g. when mmap'd), we should not copy the data. We should use the data
+ // from the bytecode directly.
+ AsmResourceBlob blob = allocator(data.size(), alignment);
+ assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) &&
+ blob.isMutable() &&
+ "blob allocator did not return a properly aligned address");
+ memcpy(blob.getMutableData().data(), data.data(), data.size());
+ return blob;
+ }
+
+private:
+ StringRef key;
+ AsmResourceEntryKind kind;
+ EncodingReader &reader;
+ StringSectionReader &stringReader;
+};
+} // namespace
+
+template <typename T>
+static LogicalResult
+parseResourceGroup(Location fileLoc, bool allowEmpty,
+ EncodingReader &offsetReader, EncodingReader &resourceReader,
+ StringSectionReader &stringReader, T *handler,
+ function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
+ uint64_t numResources;
+ if (failed(offsetReader.parseVarInt(numResources)))
+ return failure();
+
+ for (uint64_t i = 0; i < numResources; ++i) {
+ StringRef key;
+ AsmResourceEntryKind kind;
+ uint64_t resourceOffset;
+ ArrayRef<uint8_t> data;
+ if (failed(stringReader.parseString(offsetReader, key)) ||
+ failed(offsetReader.parseVarInt(resourceOffset)) ||
+ failed(offsetReader.parseByte(kind)) ||
+ failed(resourceReader.parseBytes(resourceOffset, data)))
+ return failure();
+
+ // Process the resource key.
+ if ((processKeyFn && failed(processKeyFn(key))))
+ return failure();
+
+ // If the resource data is empty and we allow it, don't error out when
+ // parsing below, just skip it.
+ if (allowEmpty && data.empty())
+ continue;
+
+ // Ignore the entry if we don't have a valid handler.
+ if (!handler)
+ continue;
+
+ // Otherwise, parse the resource value.
+ EncodingReader entryReader(data, fileLoc);
+ ParsedResourceEntry entry(key, kind, entryReader, stringReader);
+ if (failed(handler->parseResource(entry)))
+ return failure();
+ if (!entryReader.empty()) {
+ return entryReader.emitError(
+ "unexpected trailing bytes in resource entry '", key, "'");
+ }
+ }
+ return success();
+}
+
+LogicalResult
+ResourceSectionReader::initialize(Location fileLoc, const ParserConfig &config,
+ MutableArrayRef<BytecodeDialect> dialects,
+ StringSectionReader &stringReader,
+ ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData) {
+ EncodingReader resourceReader(sectionData, fileLoc);
+ EncodingReader offsetReader(offsetSectionData, fileLoc);
+
+ // Read the number of external resource providers.
+ uint64_t numExternalResourceGroups;
+ if (failed(offsetReader.parseVarInt(numExternalResourceGroups)))
+ return failure();
+
+ // Utility functor that dispatches to `parseResourceGroup`, but implicitly
+ // provides most of the arguments.
+ auto parseGroup = [&](auto *handler, bool allowEmpty = false,
+ function_ref<LogicalResult(StringRef)> keyFn = {}) {
+ return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader,
+ stringReader, handler, keyFn);
+ };
+
+ // Read the external resources from the bytecode.
+ for (uint64_t i = 0; i < numExternalResourceGroups; ++i) {
+ StringRef key;
+ if (failed(stringReader.parseString(offsetReader, key)))
+ return failure();
+
+ // Get the handler for these resources.
+ // TODO: Should we require handling external resources in some scenarios?
+ AsmResourceParser *handler = config.getResourceParser(key);
+ if (!handler) {
+ emitWarning(fileLoc) << "ignoring unknown external resources for '" << key
+ << "'";
+ }
+
+ if (failed(parseGroup(handler)))
+ return failure();
+ }
+
+ // Read the dialect resources from the bytecode.
+ MLIRContext *ctx = fileLoc->getContext();
+ while (!offsetReader.empty()) {
+ BytecodeDialect *dialect;
+ if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
+ failed(dialect->load(resourceReader, ctx)))
+ return failure();
+ Dialect *loadedDialect = dialect->getLoadedDialect();
+ if (!loadedDialect) {
+ return resourceReader.emitError()
+ << "dialect '" << dialect->name << "' is unknown";
+ }
+ const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
+ if (!handler) {
+ return resourceReader.emitError()
+ << "unexpected resources for dialect '" << dialect->name << "'";
+ }
+
+ // Ensure that each resource is declared before being processed.
+ auto processResourceKeyFn = [&](StringRef key) -> LogicalResult {
+ FailureOr<AsmDialectResourceHandle> handle =
+ handler->declareResource(key);
+ if (failed(handle)) {
+ return resourceReader.emitError()
+ << "unknown 'resource' key '" << key << "' for dialect '"
+ << dialect->name << "'";
+ }
+ dialectResources.push_back(*handle);
+ return success();
+ };
+
+ // Parse the resources for this dialect. We allow empty resources because we
+ // just treat these as declarations.
+ if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn)))
+ return failure();
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// Attribute/Type Reader
//===----------------------------------------------------------------------===//
@@ -419,8 +720,10 @@ class AttrTypeReader {
using TypeEntry = Entry<Type>;
public:
- AttrTypeReader(StringSectionReader &stringReader, Location fileLoc)
- : stringReader(stringReader), fileLoc(fileLoc) {}
+ AttrTypeReader(StringSectionReader &stringReader,
+ ResourceSectionReader &resourceReader, Location fileLoc)
+ : stringReader(stringReader), resourceReader(resourceReader),
+ fileLoc(fileLoc) {}
/// Initialize the attribute and type information within the reader.
LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
@@ -483,6 +786,10 @@ private:
/// custom encoded attribute/type entries.
StringSectionReader &stringReader;
+ /// The resource section reader used to resolve resource references when
+ /// parsing custom encoded attribute/type entries.
+ ResourceSectionReader &resourceReader;
+
/// The set of attribute and type entries.
SmallVector<AttrEntry> attributes;
SmallVector<TypeEntry> types;
@@ -494,9 +801,10 @@ private:
class DialectReader : public DialectBytecodeReader {
public:
DialectReader(AttrTypeReader &attrTypeReader,
- StringSectionReader &stringReader, EncodingReader &reader)
+ StringSectionReader &stringReader,
+ ResourceSectionReader &resourceReader, EncodingReader &reader)
: attrTypeReader(attrTypeReader), stringReader(stringReader),
- reader(reader) {}
+ resourceReader(resourceReader), reader(reader) {}
InFlightDiagnostic emitError(const Twine &msg) override {
return reader.emitError(msg);
@@ -514,6 +822,13 @@ public:
return attrTypeReader.parseType(reader, result);
}
+ FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
+ AsmDialectResourceHandle handle;
+ if (failed(resourceReader.parseResourceHandle(reader, handle)))
+ return failure();
+ return handle;
+ }
+
//===--------------------------------------------------------------------===//
// Primitives
//===--------------------------------------------------------------------===//
@@ -575,6 +890,7 @@ public:
private:
AttrTypeReader &attrTypeReader;
StringSectionReader &stringReader;
+ ResourceSectionReader &resourceReader;
EncodingReader &reader;
};
} // namespace
@@ -707,7 +1023,7 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
}
// Ask the dialect to parse the entry.
- DialectReader dialectReader(*this, stringReader, reader);
+ DialectReader dialectReader(*this, stringReader, resourceReader, reader);
if constexpr (std::is_same_v<T, Type>)
entry.entry = entry.dialect->interface->readType(dialectReader);
else
@@ -724,7 +1040,8 @@ namespace {
class BytecodeReader {
public:
BytecodeReader(Location fileLoc, const ParserConfig &config)
- : config(config), fileLoc(fileLoc), attrTypeReader(stringReader, fileLoc),
+ : config(config), fileLoc(fileLoc),
+ attrTypeReader(stringReader, resourceReader, fileLoc),
// Use the builtin unrealized conversion cast operation to represent
// forward references to values that aren't yet defined.
forwardRefOpState(UnknownLoc::get(config.getContext()),
@@ -762,6 +1079,13 @@ private:
}
//===--------------------------------------------------------------------===//
+ // Resource Section
+
+ LogicalResult
+ parseResourceSection(Optional<ArrayRef<uint8_t>> resourceData,
+ Optional<ArrayRef<uint8_t>> resourceOffsetData);
+
+ //===--------------------------------------------------------------------===//
// IR Section
/// This struct represents the current read state of a range of regions. This
@@ -863,6 +1187,9 @@ private:
SmallVector<BytecodeDialect> dialects;
SmallVector<BytecodeOperationName> opNames;
+ /// The reader used to process resources within the bytecode.
+ ResourceSectionReader resourceReader;
+
/// The table of strings referenced within the bytecode file.
StringSectionReader stringReader;
@@ -914,11 +1241,12 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
}
sectionDatas[sectionID] = sectionData;
}
- // Check that all of the sections were found.
+ // Check that all of the required sections were found.
for (int i = 0; i < bytecode::Section::kNumSections; ++i) {
- if (!sectionDatas[i]) {
+ bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i);
+ if (!sectionDatas[i] && !isSectionOptional(sectionID)) {
return reader.emitError("missing data for top-level section: ",
- toString(bytecode::Section::ID(i)));
+ toString(sectionID));
}
}
@@ -931,6 +1259,12 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
return failure();
+ // Process the resource section if present.
+ if (failed(parseResourceSection(
+ sectionDatas[bytecode::Section::kResource],
+ sectionDatas[bytecode::Section::kResourceOffset])))
+ return failure();
+
// Process the attribute and type section.
if (failed(attrTypeReader.initialize(
dialects, *sectionDatas[bytecode::Section::kAttrType],
@@ -1009,6 +1343,31 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
}
//===----------------------------------------------------------------------===//
+// Resource Section
+
+LogicalResult BytecodeReader::parseResourceSection(
+ Optional<ArrayRef<uint8_t>> resourceData,
+ Optional<ArrayRef<uint8_t>> resourceOffsetData) {
+ // Ensure both sections are either present or not.
+ if (resourceData.has_value() != resourceOffsetData.has_value()) {
+ if (resourceOffsetData)
+ return emitError(fileLoc, "unexpected resource offset section when "
+ "resource section is not present");
+ return emitError(
+ fileLoc,
+ "expected resource offset section when resource section is present");
+ }
+
+ // If the resource sections are absent, there is nothing to do.
+ if (!resourceData)
+ return success();
+
+ // Initialize the resource reader with the resource sections.
+ return resourceReader.initialize(fileLoc, config, dialects, stringReader,
+ *resourceData, *resourceOffsetData);
+}
+
+//===----------------------------------------------------------------------===//
// IR Section
LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData,
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 5f34d9b..ff53cec 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -24,6 +24,29 @@ using namespace mlir;
using namespace mlir::bytecode::detail;
//===----------------------------------------------------------------------===//
+// BytecodeWriterConfig
+//===----------------------------------------------------------------------===//
+
+struct BytecodeWriterConfig::Impl {
+ Impl(StringRef producer) : producer(producer) {}
+
+ /// The producer of the bytecode.
+ StringRef producer;
+
+ /// A collection of non-dialect resource printers.
+ SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;
+};
+
+BytecodeWriterConfig::BytecodeWriterConfig(StringRef producer)
+ : impl(std::make_unique<Impl>(producer)) {}
+BytecodeWriterConfig::~BytecodeWriterConfig() = default;
+
+void BytecodeWriterConfig::attachResourcePrinter(
+ std::unique_ptr<AsmResourcePrinter> printer) {
+ impl->externalResourcePrinters.emplace_back(std::move(printer));
+}
+
+//===----------------------------------------------------------------------===//
// EncodingEmitter
//===----------------------------------------------------------------------===//
@@ -56,6 +79,48 @@ public:
currentResult[offset - prevResultSize] = value;
}
+ /// Emit the provided blob of data, which is owned by the caller and is
+ /// guaranteed to not die before the end of the bytecode process.
+ void emitOwnedBlob(ArrayRef<uint8_t> data) {
+ // Push the current buffer before adding the provided data.
+ appendResult(std::move(currentResult));
+ appendOwnedResult(data);
+ }
+
+ /// Emit the provided blob of data that has the given alignment, which is
+ /// owned by the caller and is guaranteed to not die before the end of the
+ /// bytecode process. The alignment value is also encoded, making it available
+ /// on load.
+ void emitOwnedBlobAndAlignment(ArrayRef<uint8_t> data, uint32_t alignment) {
+ emitVarInt(alignment);
+ emitVarInt(data.size());
+
+ alignTo(alignment);
+ emitOwnedBlob(data);
+ }
+ void emitOwnedBlobAndAlignment(ArrayRef<char> data, uint32_t alignment) {
+ ArrayRef<uint8_t> castedData(reinterpret_cast<const uint8_t *>(data.data()),
+ data.size());
+ emitOwnedBlobAndAlignment(castedData, alignment);
+ }
+
+ /// Align the emitter to the given alignment.
+ void alignTo(unsigned alignment) {
+ if (alignment < 2)
+ return;
+ assert(llvm::isPowerOf2_32(alignment) && "expected valid alignment");
+
+ // Check to see if we need to emit any padding bytes to meet the desired
+ // alignment.
+ size_t curOffset = size();
+ size_t paddingSize = llvm::alignTo(curOffset, alignment) - curOffset;
+ while (paddingSize--)
+ emitByte(bytecode::kAlignmentByte);
+
+ // Keep track of the maximum required alignment.
+ requiredAlignment = std::max(requiredAlignment, alignment);
+ }
+
//===--------------------------------------------------------------------===//
// Integer Emission
@@ -119,15 +184,37 @@ public:
/// Emit a nested section of the given code, whose contents are encoded in the
/// provided emitter.
void emitSection(bytecode::Section::ID code, EncodingEmitter &&emitter) {
- // Emit the section code and length.
+ // Emit the section code and length. The high bit of the code is used to
+ // indicate whether the section alignment is present, so save an offset to
+ // it.
+ uint64_t codeOffset = currentResult.size();
emitByte(code);
emitVarInt(emitter.size());
+ // Integrate the alignment of the section into this emitter if necessary.
+ unsigned emitterAlign = emitter.requiredAlignment;
+ if (emitterAlign > 1) {
+ if (size() & (emitterAlign - 1)) {
+ emitVarInt(emitterAlign);
+ alignTo(emitterAlign);
+
+ // Indicate that we needed to align the section, the high bit of the
+ // code field is used for this.
+ currentResult[codeOffset] |= 0b10000000;
+ } else {
+ // Otherwise, if we happen to be at a compatible offset, we just
+ // remember that we need this alignment.
+ requiredAlignment = std::max(requiredAlignment, emitterAlign);
+ }
+ }
+
// Push our current buffer and then merge the provided section body into
// ours.
appendResult(std::move(currentResult));
for (std::vector<uint8_t> &result : emitter.prevResultStorage)
- appendResult(std::move(result));
+ prevResultStorage.push_back(std::move(result));
+ llvm::append_range(prevResultList, emitter.prevResultList);
+ prevResultSize += emitter.prevResultSize;
appendResult(std::move(emitter.currentResult));
}
@@ -140,9 +227,16 @@ private:
/// Append a new result buffer to the current contents.
void appendResult(std::vector<uint8_t> &&result) {
- prevResultSize += result.size();
+ if (result.empty())
+ return;
prevResultStorage.emplace_back(std::move(result));
- prevResultList.emplace_back(prevResultStorage.back());
+ appendOwnedResult(prevResultStorage.back());
+ }
+ void appendOwnedResult(ArrayRef<uint8_t> result) {
+ if (result.empty())
+ return;
+ prevResultSize += result.size();
+ prevResultList.emplace_back(result);
}
/// The result of the emitter currently being built. We refrain from building
@@ -157,6 +251,9 @@ private:
/// An up-to-date total size of all of the buffers within `prevResultList`.
/// This enables O(1) size checks of the current encoding.
size_t prevResultSize = 0;
+
+ /// The highest required alignment for the start of this section.
+ unsigned requiredAlignment = 1;
};
/// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need
@@ -250,7 +347,8 @@ public:
BytecodeWriter(Operation *op) : numberingState(op) {}
/// Write the bytecode for the given root operation.
- void write(Operation *rootOp, raw_ostream &os, StringRef producer);
+ void write(Operation *rootOp, raw_ostream &os,
+ const BytecodeWriterConfig::Impl &config);
private:
//===--------------------------------------------------------------------===//
@@ -272,6 +370,12 @@ private:
void writeIRSection(EncodingEmitter &emitter, Operation *op);
//===--------------------------------------------------------------------===//
+ // Resources
+
+ void writeResourceSection(Operation *op, EncodingEmitter &emitter,
+ const BytecodeWriterConfig::Impl &config);
+
+ //===--------------------------------------------------------------------===//
// Strings
void writeStringSection(EncodingEmitter &emitter);
@@ -288,7 +392,7 @@ private:
} // namespace
void BytecodeWriter::write(Operation *rootOp, raw_ostream &os,
- StringRef producer) {
+ const BytecodeWriterConfig::Impl &config) {
EncodingEmitter emitter;
// Emit the bytecode file header. This is how we identify the output as a
@@ -299,7 +403,7 @@ void BytecodeWriter::write(Operation *rootOp, raw_ostream &os,
emitter.emitVarInt(bytecode::kVersion);
// Emit the producer.
- emitter.emitNulTerminatedString(producer);
+ emitter.emitNulTerminatedString(config.producer);
// Emit the dialect section.
writeDialectSection(emitter);
@@ -310,6 +414,9 @@ void BytecodeWriter::write(Operation *rootOp, raw_ostream &os,
// Emit the IR section.
writeIRSection(emitter, rootOp);
+ // Emit the resources section.
+ writeResourceSection(rootOp, emitter, config);
+
// Emit the string section.
writeStringSection(emitter);
@@ -386,6 +493,10 @@ public:
emitter.emitVarInt(numberingState.getNumber(type));
}
+ void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
+ emitter.emitVarInt(numberingState.getNumber(resource));
+ }
+
//===--------------------------------------------------------------------===//
// Primitives
//===--------------------------------------------------------------------===//
@@ -614,6 +725,111 @@ void BytecodeWriter::writeIRSection(EncodingEmitter &emitter, Operation *op) {
}
//===----------------------------------------------------------------------===//
+// Resources
+
+namespace {
+/// This class represents a resource builder implementation for the MLIR
+/// bytecode format.
+class ResourceBuilder : public AsmResourceBuilder {
+public:
+ using PostProcessFn = function_ref<void(StringRef, AsmResourceEntryKind)>;
+
+ ResourceBuilder(EncodingEmitter &emitter, StringSectionBuilder &stringSection,
+ PostProcessFn postProcessFn)
+ : emitter(emitter), stringSection(stringSection),
+ postProcessFn(postProcessFn) {}
+ ~ResourceBuilder() override = default;
+
+ void buildBlob(StringRef key, ArrayRef<char> data,
+ uint32_t dataAlignment) final {
+ emitter.emitOwnedBlobAndAlignment(data, dataAlignment);
+ postProcessFn(key, AsmResourceEntryKind::Blob);
+ }
+ void buildBool(StringRef key, bool data) final {
+ emitter.emitByte(data);
+ postProcessFn(key, AsmResourceEntryKind::Bool);
+ }
+ void buildString(StringRef key, StringRef data) final {
+ emitter.emitVarInt(stringSection.insert(data));
+ postProcessFn(key, AsmResourceEntryKind::String);
+ }
+
+private:
+ EncodingEmitter &emitter;
+ StringSectionBuilder &stringSection;
+ PostProcessFn postProcessFn;
+};
+} // namespace
+
+void BytecodeWriter::writeResourceSection(
+ Operation *op, EncodingEmitter &emitter,
+ const BytecodeWriterConfig::Impl &config) {
+ EncodingEmitter resourceEmitter;
+ EncodingEmitter resourceOffsetEmitter;
+ uint64_t prevOffset = 0;
+ SmallVector<std::tuple<StringRef, AsmResourceEntryKind, uint64_t>>
+ curResourceEntries;
+
+ // Functor used to process the offset for a resource of `kind` defined by
+ // 'key'.
+ auto appendResourceOffset = [&](StringRef key, AsmResourceEntryKind kind) {
+ uint64_t curOffset = resourceEmitter.size();
+ curResourceEntries.emplace_back(key, kind, curOffset - prevOffset);
+ prevOffset = curOffset;
+ };
+
+ // Functor used to emit a resource group defined by 'key'.
+ auto emitResourceGroup = [&](uint64_t key) {
+ resourceOffsetEmitter.emitVarInt(key);
+ resourceOffsetEmitter.emitVarInt(curResourceEntries.size());
+ for (auto [key, kind, size] : curResourceEntries) {
+ resourceOffsetEmitter.emitVarInt(stringSection.insert(key));
+ resourceOffsetEmitter.emitVarInt(size);
+ resourceOffsetEmitter.emitByte(kind);
+ }
+ };
+
+ // Builder used to emit resources.
+ ResourceBuilder entryBuilder(resourceEmitter, stringSection,
+ appendResourceOffset);
+
+ // Emit the external resource entries.
+ resourceOffsetEmitter.emitVarInt(config.externalResourcePrinters.size());
+ for (const auto &printer : config.externalResourcePrinters) {
+ curResourceEntries.clear();
+ printer->buildResources(op, entryBuilder);
+ emitResourceGroup(stringSection.insert(printer->getName()));
+ }
+
+ // Emit the dialect resource entries.
+ for (DialectNumbering &dialect : numberingState.getDialects()) {
+ if (!dialect.asmInterface)
+ continue;
+ curResourceEntries.clear();
+ dialect.asmInterface->buildResources(op, dialect.resources, entryBuilder);
+
+ // Emit the declaration resources for this dialect, these didn't get emitted
+ // by the interface. These resources don't have data attached, so just use a
+ // "blob" kind as a placeholder.
+ for (const auto &resource : dialect.resourceMap)
+ if (resource.second->isDeclaration)
+ appendResourceOffset(resource.first, AsmResourceEntryKind::Blob);
+
+ // Emit the resource group for this dialect.
+ if (!curResourceEntries.empty())
+ emitResourceGroup(dialect.number);
+ }
+
+ // If we didn't emit any resource groups, elide the resource sections.
+ if (resourceOffsetEmitter.size() == 0)
+ return;
+
+ emitter.emitSection(bytecode::Section::kResourceOffset,
+ std::move(resourceOffsetEmitter));
+ emitter.emitSection(bytecode::Section::kResource, std::move(resourceEmitter));
+}
+
+//===----------------------------------------------------------------------===//
// Strings
void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) {
@@ -627,7 +843,7 @@ void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) {
//===----------------------------------------------------------------------===//
void mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
- StringRef producer) {
+ const BytecodeWriterConfig &config) {
BytecodeWriter writer(op);
- writer.write(op, os, producer);
+ writer.write(op, os, config.getImpl());
}
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 549a6ac..d5a2ef5 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -9,6 +9,7 @@
#include "IRNumbering.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
@@ -24,6 +25,9 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
void writeAttribute(Attribute attr) override { state.number(attr); }
void writeType(Type type) override { state.number(type); }
+ void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
+ state.number(resource.getDialect(), resource);
+ }
/// Stubbed out methods that are not used for numbering.
void writeVarInt(uint64_t) override {}
@@ -148,6 +152,9 @@ IRNumberingState::IRNumberingState(Operation *op) {
groupByDialectPerByte(llvm::makeMutableArrayRef(orderedAttrs));
groupByDialectPerByte(llvm::makeMutableArrayRef(orderedOpNames));
groupByDialectPerByte(llvm::makeMutableArrayRef(orderedTypes));
+
+ // Finalize the numbering of the dialect resources.
+ finalizeDialectResourceNumberings(op);
}
void IRNumberingState::number(Attribute attr) {
@@ -174,12 +181,23 @@ void IRNumberingState::number(Attribute attr) {
// dummy writing to number any nested components.
if (const auto *interface = numbering->dialect->interface) {
// TODO: We don't allow custom encodings for mutable attributes right now.
- if (attr.hasTrait<AttributeTrait::IsMutable>())
- return;
-
- NumberingDialectWriter writer(*this);
- (void)interface->writeAttribute(attr, writer);
+ if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
+ NumberingDialectWriter writer(*this);
+ if (succeeded(interface->writeAttribute(attr, writer)))
+ return;
+ }
}
+ // If this attribute will be emitted using the fallback, number the nested
+ // dialect resources. We don't number everything (e.g. no nested
+ // attributes/types), because we don't want to encode things we won't decode
+ // (the textual format can't really share much).
+ AsmState tempState(attr.getContext());
+ llvm::raw_null_ostream dummyOS;
+ attr.print(dummyOS, tempState);
+
+ // Number the used dialect resources.
+ for (const auto &it : tempState.getDialectResources())
+ number(it.getFirst(), it.getSecond().getArrayRef());
}
void IRNumberingState::number(Block &block) {
@@ -203,6 +221,7 @@ auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & {
if (!numbering) {
numbering = &numberDialect(dialect->getNamespace());
numbering->interface = dyn_cast<BytecodeDialectInterface>(dialect);
+ numbering->asmInterface = dyn_cast<OpAsmDialectInterface>(dialect);
}
return *numbering;
}
@@ -292,10 +311,92 @@ void IRNumberingState::number(Type type) {
// writing to number any nested components.
if (const auto *interface = numbering->dialect->interface) {
// TODO: We don't allow custom encodings for mutable types right now.
- if (type.hasTrait<TypeTrait::IsMutable>())
+ if (!type.hasTrait<TypeTrait::IsMutable>()) {
+ NumberingDialectWriter writer(*this);
+ if (succeeded(interface->writeType(type, writer)))
+ return;
+ }
+ }
+ // If this type will be emitted using the fallback, number the nested dialect
+ // resources. We don't number everything (e.g. no nested attributes/types),
+ // because we don't want to encode things we won't decode (the textual format
+ // can't really share much).
+ AsmState tempState(type.getContext());
+ llvm::raw_null_ostream dummyOS;
+ type.print(dummyOS, tempState);
+
+ // Number the used dialect resources.
+ for (const auto &it : tempState.getDialectResources())
+ number(it.getFirst(), it.getSecond().getArrayRef());
+}
+
+void IRNumberingState::number(Dialect *dialect,
+ ArrayRef<AsmDialectResourceHandle> resources) {
+ DialectNumbering &dialectNumber = numberDialect(dialect);
+ assert(
+ dialectNumber.asmInterface &&
+ "expected dialect owning a resource to implement OpAsmDialectInterface");
+
+ for (const auto &resource : resources) {
+ // Check if this is a newly seen resource.
+ if (!dialectNumber.resources.insert(resource))
return;
- NumberingDialectWriter writer(*this);
- (void)interface->writeType(type, writer);
+ auto *numbering =
+ new (resourceAllocator.Allocate()) DialectResourceNumbering(
+ dialectNumber.asmInterface->getResourceKey(resource));
+ dialectNumber.resourceMap.insert({numbering->key, numbering});
+ dialectResources.try_emplace(resource, numbering);
+ }
+}
+
+namespace {
+/// A dummy resource builder used to number dialect resources.
+struct NumberingResourceBuilder : public AsmResourceBuilder {
+ NumberingResourceBuilder(DialectNumbering *dialect, unsigned &nextResourceID)
+ : dialect(dialect), nextResourceID(nextResourceID) {}
+ ~NumberingResourceBuilder() override = default;
+
+ void buildBlob(StringRef key, ArrayRef<char>, uint32_t) final {
+ numberEntry(key);
+ }
+ void buildBool(StringRef key, bool) final { numberEntry(key); }
+ void buildString(StringRef key, StringRef) final {
+ // TODO: We could pre-number the value string here as well.
+ numberEntry(key);
+ }
+
+ /// Number the dialect entry for the given key.
+ void numberEntry(StringRef key) {
+ // TODO: We could pre-number resource key strings here as well.
+
+ auto it = dialect->resourceMap.find(key);
+ if (it != dialect->resourceMap.end()) {
+ it->second->number = nextResourceID++;
+ it->second->isDeclaration = false;
+ }
+ }
+
+ DialectNumbering *dialect;
+ unsigned &nextResourceID;
+};
+} // namespace
+
+void IRNumberingState::finalizeDialectResourceNumberings(Operation *rootOp) {
+ unsigned nextResourceID = 0;
+ for (DialectNumbering &dialect : getDialects()) {
+ if (!dialect.asmInterface)
+ continue;
+ NumberingResourceBuilder entryBuilder(&dialect, nextResourceID);
+ dialect.asmInterface->buildResources(rootOp, dialect.resources,
+ entryBuilder);
+
+ // Number any resources that weren't added by the dialect. This can happen
+ // if there was no backing data to the resource, but we still want these
+ // resource references to roundtrip, so we number them and indicate that the
+ // data is missing.
+ for (const auto &it : dialect.resourceMap)
+ if (it.second->isDeclaration)
+ it.second->number = nextResourceID++;
}
}
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
index 9f4cbfe..aeb624e 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.h
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -14,8 +14,10 @@
#ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
#define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
-#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringMap.h"
namespace mlir {
class BytecodeDialectInterface;
@@ -77,6 +79,25 @@ struct OpNameNumbering {
};
//===----------------------------------------------------------------------===//
+// Dialect Resource Numbering
+//===----------------------------------------------------------------------===//
+
+/// This class represents a numbering entry for a dialect resource.
+struct DialectResourceNumbering {
+ DialectResourceNumbering(std::string key) : key(std::move(key)) {}
+
+ /// The key used to reference this resource.
+ std::string key;
+
+ /// The number assigned to this resource.
+ unsigned number = 0;
+
+ /// A flag indicating if this resource is only a declaration, not a full
+ /// definition.
+ bool isDeclaration = true;
+};
+
+//===----------------------------------------------------------------------===//
// Dialect Numbering
//===----------------------------------------------------------------------===//
@@ -93,6 +114,15 @@ struct DialectNumbering {
/// The bytecode dialect interface of the dialect if defined.
const BytecodeDialectInterface *interface = nullptr;
+
+ /// The asm dialect interface of the dialect if defined.
+ const OpAsmDialectInterface *asmInterface = nullptr;
+
+ /// The referenced resources of this dialect.
+ SetVector<AsmDialectResourceHandle> resources;
+
+ /// A mapping from resource key to the corresponding resource numbering entry.
+ llvm::MapVector<StringRef, DialectResourceNumbering *> resourceMap;
};
//===----------------------------------------------------------------------===//
@@ -134,6 +164,10 @@ public:
assert(valueIDs.count(value) && "value not numbered");
return valueIDs[value];
}
+ unsigned getNumber(const AsmDialectResourceHandle &resource) {
+ assert(dialectResources.count(resource) && "resource not numbered");
+ return dialectResources[resource]->number;
+ }
/// Return the block and value counts of the given region.
std::pair<unsigned, unsigned> getBlockValueCount(Region *region) {
@@ -162,6 +196,12 @@ private:
void number(Region &region);
void number(Type type);
+ /// Number the given dialect resources.
+ void number(Dialect *dialect, ArrayRef<AsmDialectResourceHandle> resources);
+
+ /// Finalize the numberings of any dialect resources.
+ void finalizeDialectResourceNumberings(Operation *rootOp);
+
/// Mapping from IR to the respective numbering entries.
DenseMap<Attribute, AttributeNumbering *> attrs;
DenseMap<OperationName, OpNameNumbering *> opNames;
@@ -172,10 +212,16 @@ private:
std::vector<OpNameNumbering *> orderedOpNames;
std::vector<TypeNumbering *> orderedTypes;
+ /// A mapping from dialect resource handle to the numbering for the referenced
+ /// resource.
+ llvm::DenseMap<AsmDialectResourceHandle, DialectResourceNumbering *>
+ dialectResources;
+
/// Allocators used for the various numbering entries.
llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator;
llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator;
llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator;
+ llvm::SpecificBumpPtrAllocator<DialectResourceNumbering> resourceAllocator;
llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator;
/// The value ID for each Block and Value.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 841cc56..395bd03 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1271,6 +1271,18 @@ AsmResourceBuilder::~AsmResourceBuilder() = default;
AsmResourceParser::~AsmResourceParser() = default;
AsmResourcePrinter::~AsmResourcePrinter() = default;
+StringRef mlir::toString(AsmResourceEntryKind kind) {
+ switch (kind) {
+ case AsmResourceEntryKind::Blob:
+ return "blob";
+ case AsmResourceEntryKind::Bool:
+ return "bool";
+ case AsmResourceEntryKind::String:
+ return "string";
+ }
+ llvm_unreachable("unknown AsmResourceEntryKind");
+}
+
//===----------------------------------------------------------------------===//
// AsmState
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp
index ce6120c..2602828 100644
--- a/mlir/lib/IR/BuiltinDialectBytecode.cpp
+++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -116,6 +117,12 @@ enum AttributeCode {
/// UnknownLoc {
/// }
kUnknownLoc = 15,
+
+ /// DenseResourceElementsAttr {
+ /// type: Type,
+ /// handle: ResourceHandle
+ /// }
+ kDenseResourceElementsAttr = 16,
};
/// This enum contains marker codes used to indicate which type is currently
@@ -272,6 +279,8 @@ struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
Attribute readAttribute(DialectBytecodeReader &reader) const override;
ArrayAttr readArrayAttr(DialectBytecodeReader &reader) const;
+ DenseResourceElementsAttr
+ readDenseResourceElementsAttr(DialectBytecodeReader &reader) const;
DictionaryAttr readDictionaryAttr(DialectBytecodeReader &reader) const;
FloatAttr readFloatAttr(DialectBytecodeReader &reader) const;
IntegerAttr readIntegerAttr(DialectBytecodeReader &reader) const;
@@ -289,6 +298,8 @@ struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
LogicalResult writeAttribute(Attribute attr,
DialectBytecodeWriter &writer) const override;
void write(ArrayAttr attr, DialectBytecodeWriter &writer) const;
+ void write(DenseResourceElementsAttr attr,
+ DialectBytecodeWriter &writer) const;
void write(DictionaryAttr attr, DialectBytecodeWriter &writer) const;
void write(IntegerAttr attr, DialectBytecodeWriter &writer) const;
void write(FloatAttr attr, DialectBytecodeWriter &writer) const;
@@ -381,6 +392,8 @@ Attribute BuiltinDialectBytecodeInterface::readAttribute(
return readNameLoc(reader);
case builtin_encoding::kUnknownLoc:
return UnknownLoc::get(getContext());
+ case builtin_encoding::kDenseResourceElementsAttr:
+ return readDenseResourceElementsAttr(reader);
default:
reader.emitError() << "unknown builtin attribute code: " << code;
return Attribute();
@@ -390,9 +403,12 @@ Attribute BuiltinDialectBytecodeInterface::readAttribute(
LogicalResult BuiltinDialectBytecodeInterface::writeAttribute(
Attribute attr, DialectBytecodeWriter &writer) const {
return TypeSwitch<Attribute, LogicalResult>(attr)
- .Case<ArrayAttr, DictionaryAttr, FloatAttr, IntegerAttr, StringAttr,
- SymbolRefAttr, TypeAttr, CallSiteLoc, FileLineColLoc, FusedLoc,
- NameLoc>([&](auto attr) {
+ .Case<ArrayAttr, DenseResourceElementsAttr, DictionaryAttr, FloatAttr,
+ IntegerAttr, StringAttr, SymbolRefAttr, TypeAttr>([&](auto attr) {
+ write(attr, writer);
+ return success();
+ })
+ .Case<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc>([&](auto attr) {
write(attr, writer);
return success();
})
@@ -426,6 +442,31 @@ void BuiltinDialectBytecodeInterface::write(
}
//===----------------------------------------------------------------------===//
+// DenseResourceElementsAttr
+
+DenseResourceElementsAttr
+BuiltinDialectBytecodeInterface::readDenseResourceElementsAttr(
+ DialectBytecodeReader &reader) const {
+ ShapedType type;
+ if (failed(reader.readType(type)))
+ return DenseResourceElementsAttr();
+
+ FailureOr<DenseResourceElementsHandle> handle =
+ reader.readResourceHandle<DenseResourceElementsHandle>();
+ if (failed(handle))
+ return DenseResourceElementsAttr();
+
+ return DenseResourceElementsAttr::get(type, *handle);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ DenseResourceElementsAttr attr, DialectBytecodeWriter &writer) const {
+ writer.writeVarInt(builtin_encoding::kDenseResourceElementsAttr);
+ writer.writeType(attr.getType());
+ writer.writeResourceHandle(attr.getRawHandle());
+}
+
+//===----------------------------------------------------------------------===//
// DictionaryAttr
DictionaryAttr BuiltinDialectBytecodeInterface::readDictionaryAttr(
diff --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir
index 5f71ba7..544f6a4 100644
--- a/mlir/test/Bytecode/invalid/invalid-structure.mlir
+++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir
@@ -32,7 +32,7 @@
// ID
// RUN: not mlir-opt %S/invalid-structure-section-id-unknown.mlirbc 2>&1 | FileCheck %s --check-prefix=SECTION_ID_UNKNOWN
-// SECTION_ID_UNKNOWN: invalid section ID: 255
+// SECTION_ID_UNKNOWN: invalid section ID: 127
//===--------------------------------------------------------------------===//
// Length
diff --git a/mlir/test/Bytecode/resources.mlir b/mlir/test/Bytecode/resources.mlir
new file mode 100644
index 0000000..467bfaf
--- /dev/null
+++ b/mlir/test/Bytecode/resources.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -emit-bytecode %s | mlir-opt | FileCheck %s
+
+// Bytecode currently does not support big-endian platforms
+// UNSUPPORTED: s390x-
+
+// CHECK-LABEL: @TestDialectResources
+module @TestDialectResources attributes {
+ // CHECK: bytecode.test = dense_resource<decl_resource> : tensor<2xui32>
+ // CHECK: bytecode.test2 = dense_resource<resource> : tensor<4xf64>
+ // CHECK: bytecode.test3 = dense_resource<resource_2> : tensor<4xf64>
+ bytecode.test = dense_resource<decl_resource> : tensor<2xui32>,
+ bytecode.test2 = dense_resource<resource> : tensor<4xf64>,
+ bytecode.test3 = dense_resource<resource_2> : tensor<4xf64>
+} {}
+
+// CHECK: builtin: {
+// CHECK-NEXT: resource: "0x08000000010000000000000002000000000000000300000000000000"
+// CHECK-NEXT: resource_2: "0x08000000010000000000000002000000000000000300000000000000"
+
+{-#
+ dialect_resources: {
+ builtin: {
+ resource: "0x08000000010000000000000002000000000000000300000000000000",
+ resource_2: "0x08000000010000000000000002000000000000000300000000000000"
+ }
+ }
+#-}