diff options
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
| -rw-r--r-- | mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 75 |
1 files changed, 70 insertions, 5 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 44458d0..d29053a 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -22,10 +22,13 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Endian.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SourceMgr.h" #include <cstddef> +#include <cstdint> #include <list> #include <memory> #include <numeric> @@ -111,6 +114,9 @@ public: }; // Shift the reader position to the next alignment boundary. + // Note: this assumes the pointer alignment matches the alignment of the + // data from the start of the buffer. In other words, this code is only + // valid if `dataIt` is offsetting into an already aligned buffer. while (isUnaligned(dataIt)) { uint8_t padding; if (failed(parseByte(padding))) @@ -258,9 +264,13 @@ public: return success(); } + /// Validate that the alignment requested in the section is valid. + using ValidateAlignmentFn = function_ref<LogicalResult(unsigned alignment)>; + /// Parse a section header, placing the kind of section in `sectionID` and the /// contents of the section in `sectionData`. LogicalResult parseSection(bytecode::Section::ID §ionID, + ValidateAlignmentFn alignmentValidator, ArrayRef<uint8_t> §ionData) { uint8_t sectionIDAndHasAlignment; uint64_t length; @@ -281,8 +291,22 @@ public: // Process the section alignment if present. if (hasAlignment) { + // Read the requested alignment from the bytecode parser. uint64_t alignment; - if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) + if (failed(parseVarInt(alignment))) + return failure(); + + // Check that the requested alignment is less than or equal to the + // alignment of the root buffer. If it is not, we cannot safely guarantee + // that the specified alignment is globally correct. + // + // E.g. if the buffer is 8k aligned and the section is 16k aligned, + // we could end up at an offset of 24k, which is not globally 16k aligned. + if (failed(alignmentValidator(alignment))) + return emitError("failed to align section ID: ", unsigned(sectionID)); + + // Align the buffer. + if (failed(alignTo(alignment))) return failure(); } @@ -1396,6 +1420,29 @@ private: return success(); } + LogicalResult checkSectionAlignment( + unsigned alignment, + function_ref<InFlightDiagnostic(const Twine &error)> emitError) { + // Check that the bytecode buffer meets the requested section alignment. + // + // If it does not, the virtual address of the item in the section will + // not be aligned to the requested alignment. + // + // The typical case where this is necessary is the resource blob + // optimization in `parseAsBlob` where we reference the weights from the + // provided buffer instead of copying them to a new allocation. + const bool isGloballyAligned = + ((uintptr_t)buffer.getBufferStart() & (alignment - 1)) == 0; + + if (!isGloballyAligned) + return emitError("expected section alignment ") + << alignment << " but bytecode buffer 0x" + << Twine::utohexstr((uint64_t)buffer.getBufferStart()) + << " is not aligned"; + + return success(); + }; + /// Return the context for this config. MLIRContext *getContext() const { return config.getContext(); } @@ -1506,7 +1553,7 @@ private: UseListOrderStorage(bool isIndexPairEncoding, SmallVector<unsigned, 4> &&indices) : indices(std::move(indices)), - isIndexPairEncoding(isIndexPairEncoding){}; + isIndexPairEncoding(isIndexPairEncoding) {}; /// The vector containing the information required to reorder the /// use-list of a value. SmallVector<unsigned, 4> indices; @@ -1651,6 +1698,11 @@ LogicalResult BytecodeReader::Impl::read( return failure(); }); + const auto checkSectionAlignment = [&](unsigned alignment) { + return this->checkSectionAlignment( + alignment, [&](const auto &msg) { return reader.emitError(msg); }); + }; + // Parse the raw data for each of the top-level sections of the bytecode. std::optional<ArrayRef<uint8_t>> sectionDatas[bytecode::Section::kNumSections]; @@ -1658,7 +1710,8 @@ LogicalResult BytecodeReader::Impl::read( // Read the next section from the bytecode. bytecode::Section::ID sectionID; ArrayRef<uint8_t> sectionData; - if (failed(reader.parseSection(sectionID, sectionData))) + if (failed( + reader.parseSection(sectionID, checkSectionAlignment, sectionData))) return failure(); // Check for duplicate sections, we only expect one instance of each. @@ -1778,6 +1831,12 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { return failure(); dialects.resize(numDialects); + const auto checkSectionAlignment = [&](unsigned alignment) { + return this->checkSectionAlignment(alignment, [&](const auto &msg) { + return sectionReader.emitError(msg); + }); + }; + // Parse each of the dialects. for (uint64_t i = 0; i < numDialects; ++i) { dialects[i] = std::make_unique<BytecodeDialect>(); @@ -1800,7 +1859,7 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { return failure(); if (versionAvailable) { bytecode::Section::ID sectionID; - if (failed(sectionReader.parseSection(sectionID, + if (failed(sectionReader.parseSection(sectionID, checkSectionAlignment, dialects[i]->versionBuffer))) return failure(); if (sectionID != bytecode::Section::kDialectVersions) { @@ -2121,6 +2180,11 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData, LogicalResult BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack, RegionReadState &readState) { + const auto checkSectionAlignment = [&](unsigned alignment) { + return this->checkSectionAlignment( + alignment, [&](const auto &msg) { return emitError(fileLoc, msg); }); + }; + // Process regions, blocks, and operations until the end or if a nested // region is encountered. In this case we push a new state in regionStack and // return, the processing of the current region will resume afterward. @@ -2161,7 +2225,8 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack, if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) { bytecode::Section::ID sectionID; ArrayRef<uint8_t> sectionData; - if (failed(reader.parseSection(sectionID, sectionData))) + if (failed(reader.parseSection(sectionID, checkSectionAlignment, + sectionData))) return failure(); if (sectionID != bytecode::Section::kIR) return emitError(fileLoc, "expected IR section for region"); |
