diff options
| author | Mehdi Amini <joker.eph@gmail.com> | 2023-05-25 21:04:35 -0700 |
|---|---|---|
| committer | Mehdi Amini <joker.eph@gmail.com> | 2023-05-26 17:45:01 -0700 |
| commit | 660f714e26999d266232a1fbb02712bb879bd34e (patch) | |
| tree | a3a473f8ac64651140d855c2d6521cada262fd65 /mlir/lib/Bytecode/Reader/BytecodeReader.cpp | |
| parent | f354e971b09c244147ff59eb65b34487755598c0 (diff) | |
| download | llvm-660f714e26999d266232a1fbb02712bb879bd34e.zip llvm-660f714e26999d266232a1fbb02712bb879bd34e.tar.gz llvm-660f714e26999d266232a1fbb02712bb879bd34e.tar.bz2 | |
[MLIR] Add native Bytecode support for properties
This is adding a new interface (`BytecodeOpInterface`) to allow operations to
opt-in skipping conversion to attribute and serializing properties to native
bytecode.
The scheme relies on a new section where properties are stored in sequence
{ size, serialize_properties }, ...
The operations are storing the index of a properties, a table of offset is
built when loading the properties section the first time.
This is a re-commit of 837d1ce0dc which conflicted with another patch upgrading
the bytecode and the collision wasn't properly resolved before.
Differential Revision: https://reviews.llvm.org/D151065
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
| -rw-r--r-- | mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 198 |
1 files changed, 184 insertions, 14 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index ca05eac..b4fe53e 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -11,6 +11,7 @@ #include "mlir/Bytecode/BytecodeReader.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Bytecode/Encoding.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" @@ -20,6 +21,7 @@ #include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallString.h" @@ -28,6 +30,7 @@ #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/SourceMgr.h" +#include <cstddef> #include <list> #include <memory> #include <numeric> @@ -56,13 +59,15 @@ static std::string toString(bytecode::Section::ID sectionID) { return "ResourceOffset (6)"; case bytecode::Section::kDialectVersions: return "DialectVersions (7)"; + case bytecode::Section::kProperties: + return "Properties (8)"; default: return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str(); } } /// Returns true if the given top-level section ID is optional. -static bool isSectionOptional(bytecode::Section::ID sectionID) { +static bool isSectionOptional(bytecode::Section::ID sectionID, int version) { switch (sectionID) { case bytecode::Section::kString: case bytecode::Section::kDialect: @@ -74,6 +79,8 @@ static bool isSectionOptional(bytecode::Section::ID sectionID) { case bytecode::Section::kResourceOffset: case bytecode::Section::kDialectVersions: return true; + case bytecode::Section::kProperties: + return version < 5; default: llvm_unreachable("unknown section ID"); } @@ -364,6 +371,17 @@ public: /// Parse a shared string from the string section. The shared string is /// encoded using an index to a corresponding string in the string section. + /// This variant parses a flag compressed with the index. + LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result, + bool &flag) { + uint64_t entryIdx; + if (failed(reader.parseVarIntWithFlag(entryIdx, flag))) + return failure(); + return parseStringAtIndex(reader, entryIdx, result); + } + + /// Parse a shared string from the string section. The shared string is + /// encoded using an index to a corresponding string in the string section. LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index, StringRef &result) { return resolveEntry(reader, strings, index, result, "string"); @@ -459,8 +477,9 @@ struct BytecodeDialect { /// This struct represents an operation name entry within the bytecode. struct BytecodeOperationName { - BytecodeOperationName(BytecodeDialect *dialect, StringRef name) - : dialect(dialect), name(name) {} + BytecodeOperationName(BytecodeDialect *dialect, StringRef name, + std::optional<bool> wasRegistered) + : dialect(dialect), name(name), wasRegistered(wasRegistered) {} /// The loaded operation name, or std::nullopt if it hasn't been processed /// yet. @@ -471,6 +490,10 @@ struct BytecodeOperationName { /// The name of the operation, without the dialect prefix. StringRef name; + + /// Whether this operation was registered when the bytecode was produced. + /// This flag is populated when bytecode version >=5. + std::optional<bool> wasRegistered; }; } // namespace @@ -791,6 +814,18 @@ public: result = resolveAttribute(attrIdx); return success(!!result); } + LogicalResult parseOptionalAttribute(EncodingReader &reader, + Attribute &result) { + uint64_t attrIdx; + bool flag; + if (failed(reader.parseVarIntWithFlag(attrIdx, flag))) + return failure(); + if (!flag) + return success(); + result = resolveAttribute(attrIdx); + return success(!!result); + } + LogicalResult parseType(EncodingReader &reader, Type &result) { uint64_t typeIdx; if (failed(reader.parseVarInt(typeIdx))) @@ -870,7 +905,9 @@ public: LogicalResult readAttribute(Attribute &result) override { return attrTypeReader.parseAttribute(reader, result); } - + LogicalResult readOptionalAttribute(Attribute &result) override { + return attrTypeReader.parseOptionalAttribute(reader, result); + } LogicalResult readType(Type &result) override { return attrTypeReader.parseType(reader, result); } @@ -957,6 +994,87 @@ private: ResourceSectionReader &resourceReader; EncodingReader &reader; }; + +/// Wraps the properties section and handles reading properties out of it. +class PropertiesSectionReader { +public: + /// Initialize the properties section reader with the given section data. + LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) { + if (sectionData.empty()) + return success(); + EncodingReader propReader(sectionData, fileLoc); + size_t count; + if (failed(propReader.parseVarInt(count))) + return failure(); + // Parse the raw properties buffer. + if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers))) + return failure(); + + EncodingReader offsetsReader(propertiesBuffers, fileLoc); + offsetTable.reserve(count); + for (auto idx : llvm::seq<int64_t>(0, count)) { + (void)idx; + offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size()); + ArrayRef<uint8_t> rawProperties; + size_t dataSize; + if (failed(offsetsReader.parseVarInt(dataSize)) || + failed(offsetsReader.parseBytes(dataSize, rawProperties))) + return failure(); + } + if (!offsetsReader.empty()) + return offsetsReader.emitError() + << "Broken properties section: didn't exhaust the offsets table"; + return success(); + } + + LogicalResult read(Location fileLoc, DialectReader &dialectReader, + OperationName *opName, OperationState &opState) { + uint64_t propertiesIdx; + if (failed(dialectReader.readVarInt(propertiesIdx))) + return failure(); + if (propertiesIdx >= offsetTable.size()) + return dialectReader.emitError("Properties idx out-of-bound for ") + << opName->getStringRef(); + size_t propertiesOffset = offsetTable[propertiesIdx]; + if (propertiesIdx >= propertiesBuffers.size()) + return dialectReader.emitError("Properties offset out-of-bound for ") + << opName->getStringRef(); + + // Acquire the sub-buffer that represent the requested properties. + ArrayRef<char> rawProperties; + { + // "Seek" to the requested offset by getting a new reader with the right + // sub-buffer. + EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset), + fileLoc); + // Properties are stored as a sequence of {size + raw_data}. + if (failed( + dialectReader.withEncodingReader(reader).readBlob(rawProperties))) + return failure(); + } + // Setup a new reader to read from the `rawProperties` sub-buffer. + EncodingReader reader( + StringRef(rawProperties.begin(), rawProperties.size()), fileLoc); + DialectReader propReader = dialectReader.withEncodingReader(reader); + + auto *iface = opName->getInterface<BytecodeOpInterface>(); + if (iface) + return iface->readProperties(propReader, opState); + if (opName->isRegistered()) + return propReader.emitError( + "has properties but missing BytecodeOpInterface for ") + << opName->getStringRef(); + // Unregistered op are storing properties as an attribute. + return propReader.readAttribute(opState.propertiesAttr); + } + +private: + /// The properties buffer referenced within the bytecode file. + ArrayRef<uint8_t> propertiesBuffers; + + /// Table of offset in the buffer above. + SmallVector<int64_t> offsetTable; +}; } // namespace LogicalResult @@ -1194,7 +1312,9 @@ private: lazyLoadableOps.erase(it->getSecond()); lazyLoadableOpsMap.erase(it); auto result = parseRegions(regionStack, regionStack.back()); - assert(regionStack.empty()); + assert((regionStack.empty() || failed(result)) && + "broken invariant: regionStack should be empty when parseRegions " + "succeeds"); return result; } @@ -1209,8 +1329,11 @@ private: LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData); - /// Parse an operation name reference using the given reader. - FailureOr<OperationName> parseOpName(EncodingReader &reader); + /// Parse an operation name reference using the given reader, and set the + /// `wasRegistered` flag that indicates if the bytecode was produced by a + /// context where opName was registered. + FailureOr<OperationName> parseOpName(EncodingReader &reader, + std::optional<bool> &wasRegistered); //===--------------------------------------------------------------------===// // Attribute/Type Section @@ -1398,6 +1521,9 @@ private: /// The table of strings referenced within the bytecode file. StringSectionReader stringReader; + /// The table of properties referenced by the operation in the bytecode file. + PropertiesSectionReader propertiesReader; + /// The current set of available IR value scopes. std::vector<ValueScope> valueScopes; @@ -1466,7 +1592,7 @@ LogicalResult BytecodeReader::Impl::read( // Check that all of the required sections were found. for (int i = 0; i < bytecode::Section::kNumSections; ++i) { bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i); - if (!sectionDatas[i] && !isSectionOptional(sectionID)) { + if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) { return reader.emitError("missing data for top-level section: ", ::toString(sectionID)); } @@ -1477,6 +1603,12 @@ LogicalResult BytecodeReader::Impl::read( fileLoc, *sectionDatas[bytecode::Section::kString]))) return failure(); + // Process the properties section. + if (sectionDatas[bytecode::Section::kProperties] && + failed(propertiesReader.initialize( + fileLoc, *sectionDatas[bytecode::Section::kProperties]))) + return failure(); + // Process the dialect section. if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) return failure(); @@ -1598,9 +1730,20 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { // Parse the operation names, which are grouped by dialect. auto parseOpName = [&](BytecodeDialect *dialect) { StringRef opName; - if (failed(stringReader.parseString(sectionReader, opName))) - return failure(); - opNames.emplace_back(dialect, opName); + std::optional<bool> wasRegistered; + // Prior to version 5, the information about wheter an op was registered or + // not wasn't encoded. + if (version < 5) { + if (failed(stringReader.parseString(sectionReader, opName))) + return failure(); + } else { + bool wasRegisteredFlag; + if (failed(stringReader.parseStringWithFlag(sectionReader, opName, + wasRegisteredFlag))) + return failure(); + wasRegistered = wasRegisteredFlag; + } + opNames.emplace_back(dialect, opName, wasRegistered); return success(); }; // Avoid re-allocation in bytecode version > 3 where the number of ops are @@ -1618,11 +1761,12 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { } FailureOr<OperationName> -BytecodeReader::Impl::parseOpName(EncodingReader &reader) { +BytecodeReader::Impl::parseOpName(EncodingReader &reader, + std::optional<bool> &wasRegistered) { BytecodeOperationName *opName = nullptr; if (failed(parseEntry(reader, opNames, opName, "operation name"))) return failure(); - + wasRegistered = opName->wasRegistered; // Check to see if this operation name has already been resolved. If we // haven't, load the dialect and build the operation name. if (!opName->opName) { @@ -1994,7 +2138,8 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, RegionReadState &readState, bool &isIsolatedFromAbove) { // Parse the name of the operation. - FailureOr<OperationName> opName = parseOpName(reader); + std::optional<bool> wasRegistered; + FailureOr<OperationName> opName = parseOpName(reader, wasRegistered); if (failed(opName)) return failure(); @@ -2021,6 +2166,31 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, opState.attributes = dictAttr; } + if (opMask & bytecode::OpEncodingMask::kHasProperties) { + // kHasProperties wasn't emitted in older bytecode, we should never get + // there without also having the `wasRegistered` flag available. + if (!wasRegistered) + return emitError(fileLoc, + "Unexpected missing `wasRegistered` opname flag at " + "bytecode version ") + << version << " with properties."; + // When an operation is emitted without being registered, the properties are + // stored as an attribute. Otherwise the op must implement the bytecode + // interface and control the serialization. + if (wasRegistered) { + DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, + reader); + if (failed( + propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) + return failure(); + } else { + // If the operation wasn't registered when it was emitted, the properties + // was serialized as an attribute. + if (failed(parseAttribute(reader, opState.propertiesAttr))) + return failure(); + } + } + /// Parse the results of the operation. if (opMask & bytecode::OpEncodingMask::kHasResults) { uint64_t numResults; |
