diff options
| author | Nikhil Kalra <nkalra@apple.com> | 2025-09-04 22:31:40 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-04 22:31:40 -0700 |
| commit | 8106c816eb8b279ee4220936c43a0e495d1bb1a0 (patch) | |
| tree | 1bb80f6d91e0f6b434b652ca9390800713731139 /mlir/lib/Bytecode/Reader/BytecodeReader.cpp | |
| parent | ffbd6162103041697ce7387029f321d3a466ca34 (diff) | |
| download | llvm-8106c816eb8b279ee4220936c43a0e495d1bb1a0.zip llvm-8106c816eb8b279ee4220936c43a0e495d1bb1a0.tar.gz llvm-8106c816eb8b279ee4220936c43a0e495d1bb1a0.tar.bz2 | |
[MLIR][Bytecode] Enforce alignment requirements (#157004)
Adds a check that the bytecode buffer is aligned to any section
alignment requirements. Without this check, if the source buffer is not
sufficiently aligned, we may return early when aligning the data
pointer. In that case, we may end up trying to read successive sections
from an incorrect offset, giving the appearance of invalid bytecode.
This requirement is documented in the bytecode unit tests, but is not
otherwise documented in the code or bytecode reference.
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
| -rw-r--r-- | mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 75 |
1 files changed, 70 insertions, 5 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 44458d0..d29053a 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -22,10 +22,13 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Endian.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SourceMgr.h" #include <cstddef> +#include <cstdint> #include <list> #include <memory> #include <numeric> @@ -111,6 +114,9 @@ public: }; // Shift the reader position to the next alignment boundary. + // Note: this assumes the pointer alignment matches the alignment of the + // data from the start of the buffer. In other words, this code is only + // valid if `dataIt` is offsetting into an already aligned buffer. while (isUnaligned(dataIt)) { uint8_t padding; if (failed(parseByte(padding))) @@ -258,9 +264,13 @@ public: return success(); } + /// Validate that the alignment requested in the section is valid. + using ValidateAlignmentFn = function_ref<LogicalResult(unsigned alignment)>; + /// Parse a section header, placing the kind of section in `sectionID` and the /// contents of the section in `sectionData`. LogicalResult parseSection(bytecode::Section::ID §ionID, + ValidateAlignmentFn alignmentValidator, ArrayRef<uint8_t> §ionData) { uint8_t sectionIDAndHasAlignment; uint64_t length; @@ -281,8 +291,22 @@ public: // Process the section alignment if present. if (hasAlignment) { + // Read the requested alignment from the bytecode parser. uint64_t alignment; - if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) + if (failed(parseVarInt(alignment))) + return failure(); + + // Check that the requested alignment is less than or equal to the + // alignment of the root buffer. If it is not, we cannot safely guarantee + // that the specified alignment is globally correct. + // + // E.g. if the buffer is 8k aligned and the section is 16k aligned, + // we could end up at an offset of 24k, which is not globally 16k aligned. + if (failed(alignmentValidator(alignment))) + return emitError("failed to align section ID: ", unsigned(sectionID)); + + // Align the buffer. + if (failed(alignTo(alignment))) return failure(); } @@ -1396,6 +1420,29 @@ private: return success(); } + LogicalResult checkSectionAlignment( + unsigned alignment, + function_ref<InFlightDiagnostic(const Twine &error)> emitError) { + // Check that the bytecode buffer meets the requested section alignment. + // + // If it does not, the virtual address of the item in the section will + // not be aligned to the requested alignment. + // + // The typical case where this is necessary is the resource blob + // optimization in `parseAsBlob` where we reference the weights from the + // provided buffer instead of copying them to a new allocation. + const bool isGloballyAligned = + ((uintptr_t)buffer.getBufferStart() & (alignment - 1)) == 0; + + if (!isGloballyAligned) + return emitError("expected section alignment ") + << alignment << " but bytecode buffer 0x" + << Twine::utohexstr((uint64_t)buffer.getBufferStart()) + << " is not aligned"; + + return success(); + }; + /// Return the context for this config. MLIRContext *getContext() const { return config.getContext(); } @@ -1506,7 +1553,7 @@ private: UseListOrderStorage(bool isIndexPairEncoding, SmallVector<unsigned, 4> &&indices) : indices(std::move(indices)), - isIndexPairEncoding(isIndexPairEncoding){}; + isIndexPairEncoding(isIndexPairEncoding) {}; /// The vector containing the information required to reorder the /// use-list of a value. SmallVector<unsigned, 4> indices; @@ -1651,6 +1698,11 @@ LogicalResult BytecodeReader::Impl::read( return failure(); }); + const auto checkSectionAlignment = [&](unsigned alignment) { + return this->checkSectionAlignment( + alignment, [&](const auto &msg) { return reader.emitError(msg); }); + }; + // Parse the raw data for each of the top-level sections of the bytecode. std::optional<ArrayRef<uint8_t>> sectionDatas[bytecode::Section::kNumSections]; @@ -1658,7 +1710,8 @@ LogicalResult BytecodeReader::Impl::read( // Read the next section from the bytecode. bytecode::Section::ID sectionID; ArrayRef<uint8_t> sectionData; - if (failed(reader.parseSection(sectionID, sectionData))) + if (failed( + reader.parseSection(sectionID, checkSectionAlignment, sectionData))) return failure(); // Check for duplicate sections, we only expect one instance of each. @@ -1778,6 +1831,12 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { return failure(); dialects.resize(numDialects); + const auto checkSectionAlignment = [&](unsigned alignment) { + return this->checkSectionAlignment(alignment, [&](const auto &msg) { + return sectionReader.emitError(msg); + }); + }; + // Parse each of the dialects. for (uint64_t i = 0; i < numDialects; ++i) { dialects[i] = std::make_unique<BytecodeDialect>(); @@ -1800,7 +1859,7 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) { return failure(); if (versionAvailable) { bytecode::Section::ID sectionID; - if (failed(sectionReader.parseSection(sectionID, + if (failed(sectionReader.parseSection(sectionID, checkSectionAlignment, dialects[i]->versionBuffer))) return failure(); if (sectionID != bytecode::Section::kDialectVersions) { @@ -2121,6 +2180,11 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData, LogicalResult BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack, RegionReadState &readState) { + const auto checkSectionAlignment = [&](unsigned alignment) { + return this->checkSectionAlignment( + alignment, [&](const auto &msg) { return emitError(fileLoc, msg); }); + }; + // Process regions, blocks, and operations until the end or if a nested // region is encountered. In this case we push a new state in regionStack and // return, the processing of the current region will resume afterward. @@ -2161,7 +2225,8 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack, if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) { bytecode::Section::ID sectionID; ArrayRef<uint8_t> sectionData; - if (failed(reader.parseSection(sectionID, sectionData))) + if (failed(reader.parseSection(sectionID, checkSectionAlignment, + sectionData))) return failure(); if (sectionID != bytecode::Section::kIR) return emitError(fileLoc, "expected IR section for region"); |
