From 8106c816eb8b279ee4220936c43a0e495d1bb1a0 Mon Sep 17 00:00:00 2001 From: Nikhil Kalra Date: Thu, 4 Sep 2025 22:31:40 -0700 Subject: [MLIR][Bytecode] Enforce alignment requirements (#157004) Adds a check that the bytecode buffer is aligned to any section alignment requirements. Without this check, if the source buffer is not sufficiently aligned, we may return early when aligning the data pointer. In that case, we may end up trying to read successive sections from an incorrect offset, giving the appearance of invalid bytecode. This requirement is documented in the bytecode unit tests, but is not otherwise documented in the code or bytecode reference. --- mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 75 +++++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 5 deletions(-) (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp') diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 44458d0..d29053a 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -22,10 +22,13 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Endian.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SourceMgr.h" #include +#include #include #include #include @@ -111,6 +114,9 @@ public: }; // Shift the reader position to the next alignment boundary. + // Note: this assumes the pointer alignment matches the alignment of the + // data from the start of the buffer. In other words, this code is only + // valid if `dataIt` is offsetting into an already aligned buffer. while (isUnaligned(dataIt)) { uint8_t padding; if (failed(parseByte(padding))) @@ -258,9 +264,13 @@ public: return success(); } + /// Validate that the alignment requested in the section is valid. + using ValidateAlignmentFn = function_ref; + /// Parse a section header, placing the kind of section in `sectionID` and the /// contents of the section in `sectionData`. LogicalResult parseSection(bytecode::Section::ID §ionID, + ValidateAlignmentFn alignmentValidator, ArrayRef §ionData) { uint8_t sectionIDAndHasAlignment; uint64_t length; @@ -281,8 +291,22 @@ public: // Process the section alignment if present. if (hasAlignment) { + // Read the requested alignment from the bytecode parser. uint64_t alignment; - if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) + if (failed(parseVarInt(alignment))) + return failure(); + + // Check that the requested alignment is less than or equal to the + // alignment of the root buffer. If it is not, we cannot safely guarantee + // that the specified alignment is globally correct. + // + // E.g. if the buffer is 8k aligned and the section is 16k aligned, + // we could end up at an offset of 24k, which is not globally 16k aligned. + if (failed(alignmentValidator(alignment))) + return emitError("failed to align section ID: ", unsigned(sectionID)); + + // Align the buffer. + if (failed(alignTo(alignment))) return failure(); } @@ -1396,6 +1420,29 @@ private: return success(); } + LogicalResult checkSectionAlignment( + unsigned alignment, + function_ref emitError) { + // Check that the bytecode buffer meets the requested section alignment. + // + // If it does not, the virtual address of the item in the section will + // not be aligned to the requested alignment. + // + // The typical case where this is necessary is the resource blob + // optimization in `parseAsBlob` where we reference the weights from the + // provided buffer instead of copying them to a new allocation. + const bool isGloballyAligned = + ((uintptr_t)buffer.getBufferStart() & (alignment - 1)) == 0; + + if (!isGloballyAligned) + return emitError("expected section alignment ") + << alignment << " but bytecode buffer 0x" + << Twine::utohexstr((uint64_t)buffer.getBufferStart()) + << " is not aligned"; + + return success(); + }; + /// Return the context for this config. MLIRContext *getContext() const { return config.getContext(); } @@ -1506,7 +1553,7 @@ private: UseListOrderStorage(bool isIndexPairEncoding, SmallVector &&indices) : indices(std::move(indices)), - isIndexPairEncoding(isIndexPairEncoding){}; + isIndexPairEncoding(isIndexPairEncoding) {}; /// The vector containing the information required to reorder the /// use-list of a value. SmallVector indices; @@ -1651,6 +1698,11 @@ LogicalResult BytecodeReader::Impl::read( return failure(); }); + const auto checkSectionAlignment = [&](unsigned alignment) { + return this->checkSectionAlignment( + alignment, [&](const auto &msg) { return reader.emitError(msg); }); + }; + // Parse the raw data for each of the top-level sections of the bytecode. std::optional> sectionDatas[bytecode::Section::kNumSections]; @@ -1658,7 +1710,8 @@ LogicalResult BytecodeReader::Impl::read( // Read the next section from the bytecode. bytecode::Section::ID sectionID; ArrayRef sectionData; - if (failed(reader.parseSection(sectionID, sectionData))) + if (failed( + reader.parseSection(sectionID, checkSectionAlignment, sectionData))) return failure(); // Check for duplicate sections, we only expect one instance of each. @@ -1778,6 +1831,12 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef sectionData) { return failure(); dialects.resize(numDialects); + const auto checkSectionAlignment = [&](unsigned alignment) { + return this->checkSectionAlignment(alignment, [&](const auto &msg) { + return sectionReader.emitError(msg); + }); + }; + // Parse each of the dialects. for (uint64_t i = 0; i < numDialects; ++i) { dialects[i] = std::make_unique(); @@ -1800,7 +1859,7 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef sectionData) { return failure(); if (versionAvailable) { bytecode::Section::ID sectionID; - if (failed(sectionReader.parseSection(sectionID, + if (failed(sectionReader.parseSection(sectionID, checkSectionAlignment, dialects[i]->versionBuffer))) return failure(); if (sectionID != bytecode::Section::kDialectVersions) { @@ -2121,6 +2180,11 @@ BytecodeReader::Impl::parseIRSection(ArrayRef sectionData, LogicalResult BytecodeReader::Impl::parseRegions(std::vector ®ionStack, RegionReadState &readState) { + const auto checkSectionAlignment = [&](unsigned alignment) { + return this->checkSectionAlignment( + alignment, [&](const auto &msg) { return emitError(fileLoc, msg); }); + }; + // Process regions, blocks, and operations until the end or if a nested // region is encountered. In this case we push a new state in regionStack and // return, the processing of the current region will resume afterward. @@ -2161,7 +2225,8 @@ BytecodeReader::Impl::parseRegions(std::vector ®ionStack, if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) { bytecode::Section::ID sectionID; ArrayRef sectionData; - if (failed(reader.parseSection(sectionID, sectionData))) + if (failed(reader.parseSection(sectionID, checkSectionAlignment, + sectionData))) return failure(); if (sectionID != bytecode::Section::kIR) return emitError(fileLoc, "expected IR section for region"); -- cgit v1.1