diff options
| author | Mehdi Amini <joker.eph@gmail.com> | 2023-04-29 02:36:45 -0700 |
|---|---|---|
| committer | Mehdi Amini <joker.eph@gmail.com> | 2023-05-20 15:24:33 -0700 |
| commit | 3128b3105d7a226fc26174be265da479ff619f3e (patch) | |
| tree | 07a0b642d033b5622d1bb36254faf3b12af79ee6 /mlir/lib/Bytecode/Reader/BytecodeReader.cpp | |
| parent | e8cc0d310c4060b605dc293bff5d2bca95ff528b (diff) | |
| download | llvm-3128b3105d7a226fc26174be265da479ff619f3e.zip llvm-3128b3105d7a226fc26174be265da479ff619f3e.tar.gz llvm-3128b3105d7a226fc26174be265da479ff619f3e.tar.bz2 | |
Add support for Lazyloading to the MLIR bytecode
IsolatedRegions are emitted in sections in order for the reader to be
able to skip over them. A new class is exposed to manage the state and
allow the readers to load these IsolatedRegions on-demand.
Differential Revision: https://reviews.llvm.org/D149515
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
| -rw-r--r-- | mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 284 |
1 files changed, 230 insertions, 54 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 9344ec9..58145fa 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -17,6 +17,9 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Verifier.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallString.h" @@ -24,6 +27,8 @@ #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/SourceMgr.h" +#include <list> +#include <memory> #include <optional> #define DEBUG_TYPE "mlir-bytecode-reader" @@ -1092,25 +1097,93 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, // Bytecode Reader //===----------------------------------------------------------------------===// -namespace { /// This class is used to read a bytecode buffer and translate it into MLIR. -class BytecodeReader { +class mlir::BytecodeReader::Impl { + struct RegionReadState; + using LazyLoadableOpsInfo = + std::list<std::pair<Operation *, RegionReadState>>; + using LazyLoadableOpsMap = + DenseMap<Operation *, LazyLoadableOpsInfo::iterator>; + public: - BytecodeReader(Location fileLoc, const ParserConfig &config, - const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) - : config(config), fileLoc(fileLoc), + Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, + llvm::MemoryBufferRef buffer, + const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) + : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), attrTypeReader(stringReader, resourceReader, fileLoc), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()), "builtin.unrealized_conversion_cast", ValueRange(), NoneType::get(config.getContext())), - bufferOwnerRef(bufferOwnerRef) {} + buffer(buffer), bufferOwnerRef(bufferOwnerRef) {} /// Read the bytecode defined within `buffer` into the given block. - LogicalResult read(llvm::MemoryBufferRef buffer, Block *block); + LogicalResult read(Block *block, + llvm::function_ref<bool(Operation *)> lazyOps); + + /// Return the number of ops that haven't been materialized yet. + int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); } + + bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); } + + /// Materialize the provided operation, invoke the lazyOpsCallback on every + /// newly found lazy operation. + LogicalResult + materialize(Operation *op, + llvm::function_ref<bool(Operation *)> lazyOpsCallback) { + this->lazyOpsCallback = lazyOpsCallback; + auto resetlazyOpsCallback = + llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); + auto it = lazyLoadableOpsMap.find(op); + assert(it != lazyLoadableOpsMap.end() && + "materialize called on non-materializable op"); + return materialize(it); + } + + /// Materialize all operations. + LogicalResult materializeAll() { + while (!lazyLoadableOpsMap.empty()) { + if (failed(materialize(lazyLoadableOpsMap.begin()))) + return failure(); + } + return success(); + } + + /// Finalize the lazy-loading by calling back with every op that hasn't been + /// materialized to let the client decide if the op should be deleted or + /// materialized. The op is materialized if the callback returns true, deleted + /// otherwise. + LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) { + while (!lazyLoadableOps.empty()) { + Operation *op = lazyLoadableOps.begin()->first; + if (shouldMaterialize(op)) { + if (failed(materialize(lazyLoadableOpsMap.find(op)))) + return failure(); + continue; + } + op->dropAllReferences(); + op->erase(); + lazyLoadableOps.pop_front(); + lazyLoadableOpsMap.erase(op); + } + return success(); + } private: + LogicalResult materialize(LazyLoadableOpsMap::iterator it) { + assert(it != lazyLoadableOpsMap.end() && + "materialize called on non-materializable op"); + valueScopes.emplace_back(); + std::vector<RegionReadState> regionStack; + regionStack.push_back(std::move(it->getSecond()->second)); + lazyLoadableOps.erase(it->getSecond()); + lazyLoadableOpsMap.erase(it); + auto result = parseRegions(regionStack, regionStack.back()); + assert(regionStack.empty()); + return result; + } + /// Return the context for this config. MLIRContext *getContext() const { return config.getContext(); } @@ -1151,14 +1224,22 @@ private: /// This struct represents the current read state of a range of regions. This /// struct is used to enable iterative parsing of regions. struct RegionReadState { - RegionReadState(Operation *op, bool isIsolatedFromAbove) - : RegionReadState(op->getRegions(), isIsolatedFromAbove) {} - RegionReadState(MutableArrayRef<Region> regions, bool isIsolatedFromAbove) - : curRegion(regions.begin()), endRegion(regions.end()), + RegionReadState(Operation *op, EncodingReader *reader, + bool isIsolatedFromAbove) + : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {} + RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader, + bool isIsolatedFromAbove) + : curRegion(regions.begin()), endRegion(regions.end()), reader(reader), isIsolatedFromAbove(isIsolatedFromAbove) {} /// The current regions being read. MutableArrayRef<Region>::iterator curRegion, endRegion; + /// This is the reader to use for this region, this pointer is pointing to + /// the parent region reader unless the current region is IsolatedFromAbove, + /// in which case the pointer is pointing to the `owningReader` which is a + /// section dedicated to the current region. + EncodingReader *reader; + std::unique_ptr<EncodingReader> owningReader; /// The number of values defined immediately within this region. unsigned numValues = 0; @@ -1176,15 +1257,15 @@ private: }; LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block); - LogicalResult parseRegions(EncodingReader &reader, - std::vector<RegionReadState> ®ionStack, + LogicalResult parseRegions(std::vector<RegionReadState> ®ionStack, RegionReadState &readState); FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader, RegionReadState &readState, bool &isIsolatedFromAbove); - LogicalResult parseRegion(EncodingReader &reader, RegionReadState &readState); - LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState); + LogicalResult parseRegion(RegionReadState &readState); + LogicalResult parseBlockHeader(EncodingReader &reader, + RegionReadState &readState); LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); //===--------------------------------------------------------------------===// @@ -1234,6 +1315,16 @@ private: /// A location to use when emitting errors. Location fileLoc; + /// Flag that indicates if lazyloading is enabled. + bool lazyLoading; + + /// Keep track of operations that have been lazy loaded (their regions haven't + /// been materialized), along with the `RegionReadState` that allows to + /// lazy-load the regions nested under the operation. + LazyLoadableOpsInfo lazyLoadableOps; + LazyLoadableOpsMap lazyLoadableOpsMap; + llvm::function_ref<bool(Operation *)> lazyOpsCallback; + /// The reader used to process attribute and types within the bytecode. AttrTypeReader attrTypeReader; @@ -1264,14 +1355,20 @@ private: /// An operation state used when instantiating forward references. OperationState forwardRefOpState; + /// Reference to the input buffer. + llvm::MemoryBufferRef buffer; + /// The optional owning source manager, which when present may be used to /// extend the lifetime of the input buffer. const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; }; -} // namespace -LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) { +LogicalResult BytecodeReader::Impl::read( + Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { EncodingReader reader(buffer.getBuffer(), fileLoc); + this->lazyOpsCallback = lazyOpsCallback; + auto resetlazyOpsCallback = + llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); // Skip over the bytecode header, this should have already been checked. if (failed(reader.skipBytes(StringRef("ML\xefR").size()))) @@ -1302,7 +1399,7 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) { // Check for duplicate sections, we only expect one instance of each. if (sectionDatas[sectionID]) { return reader.emitError("duplicate top-level section: ", - toString(sectionID)); + ::toString(sectionID)); } sectionDatas[sectionID] = sectionData; } @@ -1311,7 +1408,7 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) { bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i); if (!sectionDatas[i] && !isSectionOptional(sectionID)) { return reader.emitError("missing data for top-level section: ", - toString(sectionID)); + ::toString(sectionID)); } } @@ -1340,7 +1437,7 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) { return parseIRSection(*sectionDatas[bytecode::Section::kIR], block); } -LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) { +LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) { if (failed(reader.parseVarInt(version))) return failure(); @@ -1357,6 +1454,9 @@ LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) { " is newer than the current version ", currentVersion); } + // Override any request to lazy-load if the bytecode version is too old. + if (version < 2) + lazyLoading = false; return success(); } @@ -1396,7 +1496,7 @@ LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) { } LogicalResult -BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) { +BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { EncodingReader sectionReader(sectionData, fileLoc); // Parse the number of dialects in the section. @@ -1449,7 +1549,8 @@ BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) { return success(); } -FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) { +FailureOr<OperationName> +BytecodeReader::Impl::parseOpName(EncodingReader &reader) { BytecodeOperationName *opName = nullptr; if (failed(parseEntry(reader, opNames, opName, "operation name"))) return failure(); @@ -1471,7 +1572,7 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) { //===----------------------------------------------------------------------===// // Resource Section -LogicalResult BytecodeReader::parseResourceSection( +LogicalResult BytecodeReader::Impl::parseResourceSection( EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData, std::optional<ArrayRef<uint8_t>> resourceOffsetData) { // Ensure both sections are either present or not. @@ -1499,8 +1600,9 @@ LogicalResult BytecodeReader::parseResourceSection( //===----------------------------------------------------------------------===// // IR Section -LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData, - Block *block) { +LogicalResult +BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData, + Block *block) { EncodingReader reader(sectionData, fileLoc); // A stack of operation regions currently being read from the bytecode. @@ -1508,17 +1610,17 @@ LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData, // Parse the top-level block using a temporary module operation. OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc); - regionStack.emplace_back(*moduleOp, /*isIsolatedFromAbove=*/true); + regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true); regionStack.back().curBlocks.push_back(moduleOp->getBody()); regionStack.back().curBlock = regionStack.back().curRegion->begin(); - if (failed(parseBlock(reader, regionStack.back()))) + if (failed(parseBlockHeader(reader, regionStack.back()))) return failure(); valueScopes.emplace_back(); valueScopes.back().push(regionStack.back()); // Iteratively parse regions until everything has been resolved. while (!regionStack.empty()) - if (failed(parseRegions(reader, regionStack, regionStack.back()))) + if (failed(parseRegions(regionStack, regionStack.back()))) return failure(); if (!forwardRefOps.empty()) { return reader.emitError( @@ -1549,15 +1651,18 @@ LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData, } LogicalResult -BytecodeReader::parseRegions(EncodingReader &reader, - std::vector<RegionReadState> ®ionStack, - RegionReadState &readState) { - // Read the regions of this operation. +BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack, + RegionReadState &readState) { + // Process regions, blocks, and operations until the end or if a nested + // region is encountered. In this case we push a new state in regionStack and + // return, the processing of the current region will resume afterward. for (; readState.curRegion != readState.endRegion; ++readState.curRegion) { // If the current block hasn't been setup yet, parse the header for this - // region. + // region. The current block is already setup when this function was + // interrupted to recurse down in a nested region and we resume the current + // block after processing the nested region. if (readState.curBlock == Region::iterator()) { - if (failed(parseRegion(reader, readState))) + if (failed(parseRegion(readState))) return failure(); // If the region is empty, there is nothing to more to do. @@ -1566,6 +1671,7 @@ BytecodeReader::parseRegions(EncodingReader &reader, } // Parse the blocks within the region. + EncodingReader &reader = *readState.reader; do { while (readState.numOpsRemaining--) { // Read in the next operation. We don't read its regions directly, we @@ -1576,9 +1682,38 @@ BytecodeReader::parseRegions(EncodingReader &reader, if (failed(op)) return failure(); - // If the op has regions, add it to the stack for processing. + // If the op has regions, add it to the stack for processing and return: + // we stop the processing of the current region and resume it after the + // inner one is completed. Unless LazyLoading is activated in which case + // nested region parsing is delayed. if ((*op)->getNumRegions()) { - regionStack.emplace_back(*op, isIsolatedFromAbove); + RegionReadState childState(*op, &reader, isIsolatedFromAbove); + + // Isolated regions are encoded as a section in version 2 and above. + if (version >= 2 && isIsolatedFromAbove) { + bytecode::Section::ID sectionID; + ArrayRef<uint8_t> sectionData; + if (failed(reader.parseSection(sectionID, sectionData))) + return failure(); + if (sectionID != bytecode::Section::kIR) + return emitError(fileLoc, "expected IR section for region"); + childState.owningReader = + std::make_unique<EncodingReader>(sectionData, fileLoc); + childState.reader = childState.owningReader.get(); + } + + if (lazyLoading) { + // If the user has a callback set, they have the opportunity + // to control lazyloading as we go. + if (!lazyOpsCallback || !lazyOpsCallback(*op)) { + lazyLoadableOps.push_back( + std::make_pair(*op, std::move(childState))); + lazyLoadableOpsMap.try_emplace(*op, + std::prev(lazyLoadableOps.end())); + continue; + } + } + regionStack.push_back(std::move(childState)); // If the op is isolated from above, push a new value scope. if (isIsolatedFromAbove) @@ -1590,7 +1725,7 @@ BytecodeReader::parseRegions(EncodingReader &reader, // Move to the next block of the region. if (++readState.curBlock == readState.curRegion->end()) break; - if (failed(parseBlock(reader, readState))) + if (failed(parseBlockHeader(reader, readState))) return failure(); } while (true); @@ -1601,16 +1736,19 @@ BytecodeReader::parseRegions(EncodingReader &reader, // When the regions have been fully parsed, pop them off of the read stack. If // the regions were isolated from above, we also pop the last value scope. - if (readState.isIsolatedFromAbove) + if (readState.isIsolatedFromAbove) { + assert(!valueScopes.empty() && "Expect a valueScope after reading region"); valueScopes.pop_back(); + } + assert(!regionStack.empty() && "Expect a regionStack after reading region"); regionStack.pop_back(); return success(); } FailureOr<Operation *> -BytecodeReader::parseOpWithoutRegions(EncodingReader &reader, - RegionReadState &readState, - bool &isIsolatedFromAbove) { +BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, + RegionReadState &readState, + bool &isIsolatedFromAbove) { // Parse the name of the operation. FailureOr<OperationName> opName = parseOpName(reader); if (failed(opName)) @@ -1696,8 +1834,9 @@ BytecodeReader::parseOpWithoutRegions(EncodingReader &reader, return op; } -LogicalResult BytecodeReader::parseRegion(EncodingReader &reader, - RegionReadState &readState) { +LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) { + EncodingReader &reader = *readState.reader; + // Parse the number of blocks in the region. uint64_t numBlocks; if (failed(reader.parseVarInt(numBlocks))) @@ -1727,11 +1866,12 @@ LogicalResult BytecodeReader::parseRegion(EncodingReader &reader, // Parse the entry block of the region. readState.curBlock = readState.curRegion->begin(); - return parseBlock(reader, readState); + return parseBlockHeader(reader, readState); } -LogicalResult BytecodeReader::parseBlock(EncodingReader &reader, - RegionReadState &readState) { +LogicalResult +BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader, + RegionReadState &readState) { bool hasArgs; if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs))) return failure(); @@ -1744,8 +1884,8 @@ LogicalResult BytecodeReader::parseBlock(EncodingReader &reader, return success(); } -LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader, - Block *block) { +LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader, + Block *block) { // Parse the value ID for the first argument, and the number of arguments. uint64_t numArgs; if (failed(reader.parseVarInt(numArgs))) @@ -1773,7 +1913,7 @@ LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader, //===----------------------------------------------------------------------===// // Value Processing -Value BytecodeReader::parseOperand(EncodingReader &reader) { +Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) { std::vector<Value> &values = valueScopes.back().values; Value *value = nullptr; if (failed(parseEntry(reader, values, value, "value"))) @@ -1785,8 +1925,8 @@ Value BytecodeReader::parseOperand(EncodingReader &reader) { return *value; } -LogicalResult BytecodeReader::defineValues(EncodingReader &reader, - ValueRange newValues) { +LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader, + ValueRange newValues) { ValueScope &valueScope = valueScopes.back(); std::vector<Value> &values = valueScope.values; @@ -1821,7 +1961,7 @@ LogicalResult BytecodeReader::defineValues(EncodingReader &reader, return success(); } -Value BytecodeReader::createForwardRef() { +Value BytecodeReader::Impl::createForwardRef() { // Check for an avaliable existing operation to use. Otherwise, create a new // fake operation to use for the reference. if (!openForwardRefOps.empty()) { @@ -1837,6 +1977,41 @@ Value BytecodeReader::createForwardRef() { // Entry Points //===----------------------------------------------------------------------===// +BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); } + +BytecodeReader::BytecodeReader( + llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading, + const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { + Location sourceFileLoc = + FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), + /*line=*/0, /*column=*/0); + impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer, + bufferOwnerRef); +} + +LogicalResult BytecodeReader::readTopLevel( + Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { + return impl->read(block, lazyOpsCallback); +} + +int64_t BytecodeReader::getNumOpsToMaterialize() const { + return impl->getNumOpsToMaterialize(); +} + +bool BytecodeReader::isMaterializable(Operation *op) { + return impl->isMaterializable(op); +} + +LogicalResult BytecodeReader::materialize( + Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) { + return impl->materialize(op, lazyOpsCallback); +} + +LogicalResult +BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) { + return impl->finalize(shouldMaterialize); +} + bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { return buffer.getBuffer().startswith("ML\xefR"); } @@ -1856,8 +2031,9 @@ readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, "input buffer is not an MLIR bytecode file"); } - BytecodeReader reader(sourceFileLoc, config, bufferOwnerRef); - return reader.read(buffer, block); + BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false, + buffer, bufferOwnerRef); + return reader.read(block, /*lazyOpsCallback=*/nullptr); } LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, |
