diff options
| author | Matteo Franciolini <m_franciolini@apple.com> | 2023-05-21 16:46:59 -0700 |
|---|---|---|
| committer | Mehdi Amini <joker.eph@gmail.com> | 2023-05-21 16:48:12 -0700 |
| commit | 612781918fb01a2a0985a1c4c9200f5d5d1581cc (patch) | |
| tree | 8a0383af5f3512dd9e1d256eb0fa230d33b4fa9d /mlir/lib/Bytecode/Reader/BytecodeReader.cpp | |
| parent | f81ccb520927247b02708873567428d6988e2a07 (diff) | |
| download | llvm-612781918fb01a2a0985a1c4c9200f5d5d1581cc.zip llvm-612781918fb01a2a0985a1c4c9200f5d5d1581cc.tar.gz llvm-612781918fb01a2a0985a1c4c9200f5d5d1581cc.tar.bz2 | |
Preserve use-list orders in mlir bytecode
This patch implements a mechanism to read/write use-list orders from/to the mlir bytecode format. When producing bytecode, use-list orders are appended to each value of the IR. When reading bytecode, use-lists orders are loaded in memory and used at the end of parsing to sort the existing use-list chains.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D149755
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
| -rw-r--r-- | mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 257 |
1 files changed, 255 insertions, 2 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 58145fa..92584d5 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -7,12 +7,11 @@ //===----------------------------------------------------------------------===// // TODO: Support for big-endian architectures. -// TODO: Properly preserve use lists of values. #include "mlir/Bytecode/BytecodeReader.h" -#include "../Encoding.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Bytecode/Encoding.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OpImplementation.h" @@ -29,6 +28,7 @@ #include "llvm/Support/SourceMgr.h" #include <list> #include <memory> +#include <numeric> #include <optional> #define DEBUG_TYPE "mlir-bytecode-reader" @@ -1282,6 +1282,42 @@ private: Value createForwardRef(); //===--------------------------------------------------------------------===// + // Use-list order helpers + + /// This struct is a simple storage that contains information required to + /// reorder the use-list of a value with respect to the pre-order traversal + /// ordering. + struct UseListOrderStorage { + UseListOrderStorage(bool isIndexPairEncoding, + SmallVector<unsigned, 4> &&indices) + : indices(std::move(indices)), + isIndexPairEncoding(isIndexPairEncoding){}; + /// The vector containing the information required to reorder the + /// use-list of a value. + SmallVector<unsigned, 4> indices; + + /// Whether indices represent a pair of type `(src, dst)` or it is a direct + /// indexing, such as `dst = order[src]`. + bool isIndexPairEncoding; + }; + + /// Parse use-list order from bytecode for a range of values if available. The + /// range is expected to be either a block argument or an op result range. On + /// success, return a map of the position in the range and the use-list order + /// encoding. The function assumes to know the size of the range it is + /// processing. + using UseListMapT = DenseMap<unsigned, UseListOrderStorage>; + FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader, + uint64_t rangeSize); + + /// Shuffle the use-chain according to the order parsed. + LogicalResult sortUseListOrder(Value value); + + /// Recursively visit all the values defined within topLevelOp and sort the + /// use-list orders according to the indices parsed. + LogicalResult processUseLists(Operation *topLevelOp); + + //===--------------------------------------------------------------------===// // Fields /// This class represents a single value scope, in which a value scope is @@ -1341,17 +1377,27 @@ private: /// The reader used to process resources within the bytecode. ResourceSectionReader resourceReader; + /// Worklist of values with custom use-list orders to process before the end + /// of the parsing. + DenseMap<void *, UseListOrderStorage> valueToUseListMap; + /// The table of strings referenced within the bytecode file. StringSectionReader stringReader; /// The current set of available IR value scopes. std::vector<ValueScope> valueScopes; + + /// The global pre-order operation ordering. + DenseMap<Operation *, unsigned> operationIDs; + /// A block containing the set of operations defined to create forward /// references. Block forwardRefOps; + /// A block containing previously created, and no longer used, forward /// reference operations. Block openForwardRefOps; + /// An operation state used when instantiating forward references. OperationState forwardRefOpState; @@ -1598,6 +1644,165 @@ LogicalResult BytecodeReader::Impl::parseResourceSection( } //===----------------------------------------------------------------------===// +// UseListOrder Helpers + +FailureOr<BytecodeReader::Impl::UseListMapT> +BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader, + uint64_t numResults) { + BytecodeReader::Impl::UseListMapT map; + uint64_t numValuesToRead = 1; + if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead))) + return failure(); + + for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) { + uint64_t resultIdx = 0; + if (numResults > 1 && failed(reader.parseVarInt(resultIdx))) + return failure(); + + uint64_t numValues; + bool indexPairEncoding; + if (failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding))) + return failure(); + + SmallVector<unsigned, 4> useListOrders; + for (size_t idx = 0; idx < numValues; idx++) { + uint64_t index; + if (failed(reader.parseVarInt(index))) + return failure(); + useListOrders.push_back(index); + } + + // Store in a map the result index + map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding, + std::move(useListOrders))); + } + + return map; +} + +/// Sorts each use according to the order specified in the use-list parsed. If +/// the custom use-list is not found, this means that the order needs to be +/// consistent with the reverse pre-order walk of the IR. If multiple uses lie +/// on the same operation, the order will follow the reverse operand number +/// ordering. +LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) { + // Early return for trivial use-lists. + if (value.use_empty() || value.hasOneUse()) + return success(); + + bool hasIncomingOrder = + valueToUseListMap.contains(value.getAsOpaquePointer()); + + // Compute the current order of the use-list with respect to the global + // ordering. Detect if the order is already sorted while doing so. + bool alreadySorted = true; + auto &firstUse = *value.use_begin(); + uint64_t prevID = + bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner())); + llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}}; + for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) { + uint64_t currentID = bytecode::getUseID( + item.value(), operationIDs.at(item.value().getOwner())); + alreadySorted &= prevID > currentID; + currentOrder.push_back({item.index(), currentID}); + prevID = currentID; + } + + // If the order is already sorted, and there wasn't a custom order to apply + // from the bytecode file, we are done. + if (alreadySorted && !hasIncomingOrder) + return success(); + + // If not already sorted, sort the indices of the current order by descending + // useIDs. + if (!alreadySorted) + std::sort( + currentOrder.begin(), currentOrder.end(), + [](auto elem1, auto elem2) { return elem1.second > elem2.second; }); + + if (!hasIncomingOrder) { + // If the bytecode file did not contain any custom use-list order, it means + // that the order was descending useID. Hence, shuffle by the first index + // of the `currentOrder` pair. + SmallVector<unsigned> shuffle = SmallVector<unsigned>( + llvm::map_range(currentOrder, [&](auto item) { return item.first; })); + value.shuffleUseList(shuffle); + return success(); + } + + // Pull the custom order info from the map. + UseListOrderStorage customOrder = + valueToUseListMap.at(value.getAsOpaquePointer()); + SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices); + uint64_t numUses = + std::distance(value.getUses().begin(), value.getUses().end()); + + // If the encoding was a pair of indices `(src, dst)` for every permutation, + // reconstruct the shuffle vector for every use. Initialize the shuffle vector + // as identity, and then apply the mapping encoded in the indices. + if (customOrder.isIndexPairEncoding) { + // Return failure if the number of indices was not representing pairs. + if (shuffle.size() & 1) + return failure(); + + SmallVector<unsigned, 4> newShuffle(numUses); + size_t idx = 0; + std::iota(newShuffle.begin(), newShuffle.end(), idx); + for (idx = 0; idx < shuffle.size(); idx += 2) + newShuffle[shuffle[idx]] = shuffle[idx + 1]; + + shuffle = std::move(newShuffle); + } + + // Make sure that the indices represent a valid mapping. That is, the sum of + // all the values needs to be equal to (numUses - 1) * numUses / 2, and no + // duplicates are allowed in the list. + DenseSet<unsigned> set; + uint64_t accumulator = 0; + for (const auto &elem : shuffle) { + if (set.contains(elem)) + return failure(); + accumulator += elem; + set.insert(elem); + } + if (numUses != shuffle.size() || + accumulator != (((numUses - 1) * numUses) >> 1)) + return failure(); + + // Apply the current ordering map onto the shuffle vector to get the final + // use-list sorting indices before shuffling. + shuffle = SmallVector<unsigned, 4>(llvm::map_range( + currentOrder, [&](auto item) { return shuffle[item.first]; })); + value.shuffleUseList(shuffle); + return success(); +} + +LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) { + // Precompute operation IDs according to the pre-order walk of the IR. We + // can't do this while parsing since parseRegions ordering is not strictly + // equal to the pre-order walk. + unsigned operationID = 0; + topLevelOp->walk<mlir::WalkOrder::PreOrder>( + [&](Operation *op) { operationIDs.try_emplace(op, operationID++); }); + + auto blockWalk = topLevelOp->walk([this](Block *block) { + for (auto arg : block->getArguments()) + if (failed(sortUseListOrder(arg))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + + auto resultWalk = topLevelOp->walk([this](Operation *op) { + for (auto result : op->getResults()) + if (failed(sortUseListOrder(result))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + + return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted()); +} + +//===----------------------------------------------------------------------===// // IR Section LogicalResult @@ -1627,6 +1832,11 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData, "not all forward unresolved forward operand references"); } + // Sort use-lists according to what specified in bytecode. + if (failed(processUseLists(*moduleOp))) + return reader.emitError( + "parsed use-list orders were invalid and could not be applied"); + // Resolve dialect version. for (const BytecodeDialect &byteCodeDialect : dialects) { // Parsing is complete, give an opportunity to each dialect to visit the @@ -1812,6 +2022,17 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, } } + /// Parse the use-list orders for the results of the operation. Use-list + /// orders are available since version 3 of the bytecode. + std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt; + if (version > 2 && (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) { + size_t numResults = opState.types.size(); + auto parseResult = parseUseListOrderForRange(reader, numResults); + if (failed(parseResult)) + return failure(); + resultIdxToUseListMap = std::move(*parseResult); + } + /// Parse the regions of the operation. if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) { uint64_t numRegions; @@ -1831,6 +2052,16 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, if (op->getNumResults() && failed(defineValues(reader, op->getResults()))) return failure(); + /// Store a map for every value that received a custom use-list order from the + /// bytecode file. + if (resultIdxToUseListMap.has_value()) { + for (size_t idx = 0; idx < op->getNumResults(); idx++) { + if (resultIdxToUseListMap->contains(idx)) { + valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(), + resultIdxToUseListMap->at(idx)); + } + } + } return op; } @@ -1880,6 +2111,28 @@ BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader, if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock))) return failure(); + // Uselist orders are available since version 3 of the bytecode. + if (version < 3) + return success(); + + uint8_t hasUseListOrders = 0; + if (hasArgs && failed(reader.parseByte(hasUseListOrders))) + return failure(); + + if (!hasUseListOrders) + return success(); + + Block &blk = *readState.curBlock; + auto argIdxToUseListMap = + parseUseListOrderForRange(reader, blk.getNumArguments()); + if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty()) + return failure(); + + for (size_t idx = 0; idx < blk.getNumArguments(); idx++) + if (argIdxToUseListMap->contains(idx)) + valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(), + argIdxToUseListMap->at(idx)); + // We don't parse the operations of the block here, that's done elsewhere. return success(); } |
