aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
diff options
context:
space:
mode:
authorMatteo Franciolini <m_franciolini@apple.com>2023-05-21 16:46:59 -0700
committerMehdi Amini <joker.eph@gmail.com>2023-05-21 16:48:12 -0700
commit612781918fb01a2a0985a1c4c9200f5d5d1581cc (patch)
tree8a0383af5f3512dd9e1d256eb0fa230d33b4fa9d /mlir/lib/Bytecode/Reader/BytecodeReader.cpp
parentf81ccb520927247b02708873567428d6988e2a07 (diff)
downloadllvm-612781918fb01a2a0985a1c4c9200f5d5d1581cc.zip
llvm-612781918fb01a2a0985a1c4c9200f5d5d1581cc.tar.gz
llvm-612781918fb01a2a0985a1c4c9200f5d5d1581cc.tar.bz2
Preserve use-list orders in mlir bytecode
This patch implements a mechanism to read/write use-list orders from/to the mlir bytecode format. When producing bytecode, use-list orders are appended to each value of the IR. When reading bytecode, use-list orders are loaded in memory and used at the end of parsing to sort the existing use-list chains. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D149755
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
-rw-r--r--mlir/lib/Bytecode/Reader/BytecodeReader.cpp257
1 files changed, 255 insertions, 2 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 58145fa..92584d5 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -7,12 +7,11 @@
//===----------------------------------------------------------------------===//
// TODO: Support for big-endian architectures.
-// TODO: Properly preserve use lists of values.
#include "mlir/Bytecode/BytecodeReader.h"
-#include "../Encoding.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Bytecode/Encoding.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpImplementation.h"
@@ -29,6 +28,7 @@
#include "llvm/Support/SourceMgr.h"
#include <list>
#include <memory>
+#include <numeric>
#include <optional>
#define DEBUG_TYPE "mlir-bytecode-reader"
@@ -1282,6 +1282,42 @@ private:
Value createForwardRef();
//===--------------------------------------------------------------------===//
+ // Use-list order helpers
+
+ /// Plain record holding the data needed to permute the use-list of a single
+ /// value, relative to the ordering given by the pre-order traversal of the
+ /// IR.
+ struct UseListOrderStorage {
+   UseListOrderStorage(bool pairEncoded, SmallVector<unsigned, 4> &&order)
+       : indices(std::move(order)), isIndexPairEncoding(pairEncoded) {}
+
+   /// Raw reordering data parsed for the value. How it must be interpreted
+   /// depends on `isIndexPairEncoding`.
+   SmallVector<unsigned, 4> indices;
+
+   /// True when `indices` is a flat list of `(src, dst)` pairs; false when it
+   /// is a direct indexing, such as `dst = order[src]`.
+   bool isIndexPairEncoding;
+ };
+
+ /// Parse the use-list orders from bytecode for a range of values, if
+ /// available. The range is expected to be either a block argument or an op
+ /// result range. On success, return a map from the position of a value in
+ /// the range to its use-list order encoding. The caller must provide the size
+ /// of the range being processed.
+ using UseListMapT = DenseMap<unsigned, UseListOrderStorage>;
+ FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader,
+ uint64_t rangeSize);
+
+ /// Shuffle the use-chain according to the order parsed.
+ LogicalResult sortUseListOrder(Value value);
+
+ /// Recursively visit all the values defined within topLevelOp and sort the
+ /// use-list orders according to the indices parsed.
+ LogicalResult processUseLists(Operation *topLevelOp);
+
+ //===--------------------------------------------------------------------===//
// Fields
/// This class represents a single value scope, in which a value scope is
@@ -1341,17 +1377,27 @@ private:
/// The reader used to process resources within the bytecode.
ResourceSectionReader resourceReader;
+ /// Worklist of values with custom use-list orders to process before the end
+ /// of the parsing.
+ DenseMap<void *, UseListOrderStorage> valueToUseListMap;
+
/// The table of strings referenced within the bytecode file.
StringSectionReader stringReader;
/// The current set of available IR value scopes.
std::vector<ValueScope> valueScopes;
+
+ /// The global pre-order operation ordering.
+ DenseMap<Operation *, unsigned> operationIDs;
+
/// A block containing the set of operations defined to create forward
/// references.
Block forwardRefOps;
+
/// A block containing previously created, and no longer used, forward
/// reference operations.
Block openForwardRefOps;
+
/// An operation state used when instantiating forward references.
OperationState forwardRefOpState;
@@ -1598,6 +1644,165 @@ LogicalResult BytecodeReader::Impl::parseResourceSection(
}
//===----------------------------------------------------------------------===//
+// UseListOrder Helpers
+
+/// Reads the use-list orders encoded for a range of `numResults` values and
+/// returns them keyed by the position of the value inside the range. Returns
+/// failure as soon as any read from `reader` fails.
+FailureOr<BytecodeReader::Impl::UseListMapT>
+BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader,
+                                                uint64_t numResults) {
+  // For a range of at most one value, neither the number of encoded values
+  // nor the position of each value is read from the stream: a single order
+  // at position 0 is assumed.
+  const bool hasExplicitPositions = numResults > 1;
+
+  uint64_t numEncodedValues = 1;
+  if (hasExplicitPositions && failed(reader.parseVarInt(numEncodedValues)))
+    return failure();
+
+  BytecodeReader::Impl::UseListMapT ordersByPosition;
+  for (size_t valueNo = 0; valueNo < numEncodedValues; ++valueNo) {
+    // Position of the value inside the range.
+    uint64_t position = 0;
+    if (hasExplicitPositions && failed(reader.parseVarInt(position)))
+      return failure();
+
+    // Number of indices that follow, with the encoding kind stored as flag.
+    uint64_t numIndices;
+    bool isPairEncoding;
+    if (failed(reader.parseVarIntWithFlag(numIndices, isPairEncoding)))
+      return failure();
+
+    SmallVector<unsigned, 4> order;
+    for (size_t entryNo = 0; entryNo < numIndices; ++entryNo) {
+      uint64_t entry;
+      if (failed(reader.parseVarInt(entry)))
+        return failure();
+      order.push_back(entry);
+    }
+
+    // Record the parsed order under the position of its value; an entry
+    // already present for that position is left untouched.
+    ordersByPosition.try_emplace(
+        position, UseListOrderStorage(isPairEncoding, std::move(order)));
+  }
+
+  return ordersByPosition;
+}
+
+/// Sorts the uses of `value` according to the order specified in the use-list
+/// parsed from the bytecode. If no custom use-list order was recorded for the
+/// value, the order needs to be consistent with the reverse pre-order walk of
+/// the IR. If multiple uses lie on the same operation, the order will follow
+/// the reverse operand number ordering.
+///
+/// Returns failure when the recorded order is malformed: an odd number of
+/// entries for the `(src, dst)` pair encoding, a `src` entry that does not
+/// name an existing use, or a set of indices that is not a permutation of
+/// `[0, numUses)`.
+LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) {
+  // Early return for trivial use-lists.
+  if (value.use_empty() || value.hasOneUse())
+    return success();
+
+  bool hasIncomingOrder =
+      valueToUseListMap.contains(value.getAsOpaquePointer());
+
+  // Compute the current order of the use-list with respect to the global
+  // ordering. Detect if the order is already sorted while doing so.
+  bool alreadySorted = true;
+  auto &firstUse = *value.use_begin();
+  uint64_t prevID =
+      bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner()));
+  llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}};
+  for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) {
+    uint64_t currentID = bytecode::getUseID(
+        item.value(), operationIDs.at(item.value().getOwner()));
+    alreadySorted &= prevID > currentID;
+    currentOrder.push_back({item.index(), currentID});
+    prevID = currentID;
+  }
+
+  // If the order is already sorted, and there wasn't a custom order to apply
+  // from the bytecode file, we are done.
+  if (alreadySorted && !hasIncomingOrder)
+    return success();
+
+  // If not already sorted, sort the indices of the current order by descending
+  // useIDs.
+  if (!alreadySorted)
+    std::sort(
+        currentOrder.begin(), currentOrder.end(),
+        [](auto elem1, auto elem2) { return elem1.second > elem2.second; });
+
+  if (!hasIncomingOrder) {
+    // If the bytecode file did not contain any custom use-list order, it means
+    // that the order was descending useID. Hence, shuffle by the first index
+    // of the `currentOrder` pair.
+    SmallVector<unsigned> shuffle = SmallVector<unsigned>(
+        llvm::map_range(currentOrder, [&](auto item) { return item.first; }));
+    value.shuffleUseList(shuffle);
+    return success();
+  }
+
+  // Pull the custom order info from the map. This works on a copy, so the
+  // entry stored in the map is left untouched.
+  UseListOrderStorage customOrder =
+      valueToUseListMap.at(value.getAsOpaquePointer());
+  SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices);
+  uint64_t numUses =
+      std::distance(value.getUses().begin(), value.getUses().end());
+
+  // If the encoding was a pair of indices `(src, dst)` for every permutation,
+  // reconstruct the shuffle vector for every use. Initialize the shuffle vector
+  // as identity, and then apply the mapping encoded in the indices.
+  if (customOrder.isIndexPairEncoding) {
+    // Return failure if the number of indices was not representing pairs.
+    if (shuffle.size() & 1)
+      return failure();
+
+    SmallVector<unsigned, 4> newShuffle(numUses);
+    size_t idx = 0;
+    std::iota(newShuffle.begin(), newShuffle.end(), idx);
+    for (idx = 0; idx < shuffle.size(); idx += 2) {
+      // The source position comes straight from the bytecode file. Reject it
+      // when it does not name an existing use, instead of writing past the
+      // end of `newShuffle`.
+      if (shuffle[idx] >= numUses)
+        return failure();
+      newShuffle[shuffle[idx]] = shuffle[idx + 1];
+    }
+
+    shuffle = std::move(newShuffle);
+  }
+
+  // Make sure that the indices represent a valid mapping. That is, the sum of
+  // all the values needs to be equal to (numUses - 1) * numUses / 2, and no
+  // duplicates are allowed in the list.
+  DenseSet<unsigned> set;
+  uint64_t accumulator = 0;
+  for (const auto &elem : shuffle) {
+    if (set.contains(elem))
+      return failure();
+    accumulator += elem;
+    set.insert(elem);
+  }
+  if (numUses != shuffle.size() ||
+      accumulator != (((numUses - 1) * numUses) >> 1))
+    return failure();
+
+  // Apply the current ordering map onto the shuffle vector to get the final
+  // use-list sorting indices before shuffling. Indexing `shuffle` is safe
+  // here: it has been checked to hold exactly `numUses` entries.
+  shuffle = SmallVector<unsigned, 4>(llvm::map_range(
+      currentOrder, [&](auto item) { return shuffle[item.first]; }));
+  value.shuffleUseList(shuffle);
+  return success();
+}
+
+/// Assigns a pre-order ID to every operation nested under `topLevelOp`, then
+/// sorts the use-list of every block argument and of every operation result
+/// found under it. Returns failure if sorting failed for any value.
+LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) {
+  // Number the operations following the pre-order walk of the IR. This can't
+  // be done while parsing, since the parseRegions ordering is not strictly
+  // equal to the pre-order walk.
+  unsigned nextOperationID = 0;
+  topLevelOp->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
+    operationIDs.try_emplace(op, nextOperationID++);
+  });
+
+  // First pass: use-lists of the block arguments. The walk stops at the first
+  // value whose use-list could not be sorted.
+  WalkResult argumentsStatus = topLevelOp->walk([this](Block *block) {
+    for (BlockArgument argument : block->getArguments()) {
+      if (failed(sortUseListOrder(argument)))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+
+  // Second pass: use-lists of the operation results. It runs regardless of
+  // the outcome of the first pass.
+  WalkResult resultsStatus = topLevelOp->walk([this](Operation *op) {
+    for (OpResult opResult : op->getResults()) {
+      if (failed(sortUseListOrder(opResult)))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+
+  if (argumentsStatus.wasInterrupted() || resultsStatus.wasInterrupted())
+    return failure();
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
// IR Section
LogicalResult
@@ -1627,6 +1832,11 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
"not all forward unresolved forward operand references");
}
+ // Sort use-lists according to what is specified in the bytecode.
+ if (failed(processUseLists(*moduleOp)))
+ return reader.emitError(
+ "parsed use-list orders were invalid and could not be applied");
+
// Resolve dialect version.
for (const BytecodeDialect &byteCodeDialect : dialects) {
// Parsing is complete, give an opportunity to each dialect to visit the
@@ -1812,6 +2022,17 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
}
}
+ /// Parse the use-list orders for the results of the operation. Use-list
+ /// orders are available since version 3 of the bytecode.
+ std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt;
+ if (version > 2 && (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) {
+ size_t numResults = opState.types.size();
+ auto parseResult = parseUseListOrderForRange(reader, numResults);
+ if (failed(parseResult))
+ return failure();
+ resultIdxToUseListMap = std::move(*parseResult);
+ }
+
/// Parse the regions of the operation.
if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) {
uint64_t numRegions;
@@ -1831,6 +2052,16 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
if (op->getNumResults() && failed(defineValues(reader, op->getResults())))
return failure();
+ /// Store a map for every value that received a custom use-list order from the
+ /// bytecode file.
+ if (resultIdxToUseListMap.has_value()) {
+ for (size_t idx = 0; idx < op->getNumResults(); idx++) {
+ if (resultIdxToUseListMap->contains(idx)) {
+ valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(),
+ resultIdxToUseListMap->at(idx));
+ }
+ }
+ }
return op;
}
@@ -1880,6 +2111,28 @@ BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader,
if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock)))
return failure();
+ // Uselist orders are available since version 3 of the bytecode.
+ if (version < 3)
+ return success();
+
+ uint8_t hasUseListOrders = 0;
+ if (hasArgs && failed(reader.parseByte(hasUseListOrders)))
+ return failure();
+
+ if (!hasUseListOrders)
+ return success();
+
+ Block &blk = *readState.curBlock;
+ auto argIdxToUseListMap =
+ parseUseListOrderForRange(reader, blk.getNumArguments());
+ if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty())
+ return failure();
+
+ for (size_t idx = 0; idx < blk.getNumArguments(); idx++)
+ if (argIdxToUseListMap->contains(idx))
+ valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(),
+ argIdxToUseListMap->at(idx));
+
// We don't parse the operations of the block here, that's done elsewhere.
return success();
}