aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Bytecode/Reader/BytecodeReader.cpp')
-rw-r--r--mlir/lib/Bytecode/Reader/BytecodeReader.cpp75
1 files changed, 70 insertions, 5 deletions
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 44458d0..d29053a 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -22,10 +22,13 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Endian.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SourceMgr.h"
#include <cstddef>
+#include <cstdint>
#include <list>
#include <memory>
#include <numeric>
@@ -111,6 +114,9 @@ public:
};
// Shift the reader position to the next alignment boundary.
+ // Note: this assumes the pointer alignment matches the alignment of the
+ // data from the start of the buffer. In other words, this code is only
+ // valid if `dataIt` is offsetting into an already aligned buffer.
while (isUnaligned(dataIt)) {
uint8_t padding;
if (failed(parseByte(padding)))
@@ -258,9 +264,13 @@ public:
return success();
}
+ /// Validate that the alignment requested in the section is valid.
+ using ValidateAlignmentFn = function_ref<LogicalResult(unsigned alignment)>;
+
/// Parse a section header, placing the kind of section in `sectionID` and the
/// contents of the section in `sectionData`.
LogicalResult parseSection(bytecode::Section::ID &sectionID,
+ ValidateAlignmentFn alignmentValidator,
ArrayRef<uint8_t> &sectionData) {
uint8_t sectionIDAndHasAlignment;
uint64_t length;
@@ -281,8 +291,22 @@ public:
// Process the section alignment if present.
if (hasAlignment) {
+ // Read the requested alignment from the bytecode parser.
uint64_t alignment;
- if (failed(parseVarInt(alignment)) || failed(alignTo(alignment)))
+ if (failed(parseVarInt(alignment)))
+ return failure();
+
+ // Check that the requested alignment is less than or equal to the
+ // alignment of the root buffer. If it is not, we cannot safely guarantee
+ // that the specified alignment is globally correct.
+ //
+ // E.g. if the buffer is 8k aligned and the section is 16k aligned,
+ // we could end up at an offset of 24k, which is not globally 16k aligned.
+ if (failed(alignmentValidator(alignment)))
+ return emitError("failed to align section ID: ", unsigned(sectionID));
+
+ // Align the buffer.
+ if (failed(alignTo(alignment)))
return failure();
}
@@ -1396,6 +1420,29 @@ private:
return success();
}
+ LogicalResult checkSectionAlignment(
+ unsigned alignment,
+ function_ref<InFlightDiagnostic(const Twine &error)> emitError) {
+ // Check that the bytecode buffer meets the requested section alignment.
+ //
+ // If it does not, the virtual address of the item in the section will
+ // not be aligned to the requested alignment.
+ //
+ // The typical case where this is necessary is the resource blob
+ // optimization in `parseAsBlob` where we reference the weights from the
+ // provided buffer instead of copying them to a new allocation.
+ const bool isGloballyAligned =
+ ((uintptr_t)buffer.getBufferStart() & (alignment - 1)) == 0;
+
+ if (!isGloballyAligned)
+ return emitError("expected section alignment ")
+ << alignment << " but bytecode buffer 0x"
+ << Twine::utohexstr((uint64_t)buffer.getBufferStart())
+ << " is not aligned";
+
+ return success();
+ };
+
/// Return the context for this config.
MLIRContext *getContext() const { return config.getContext(); }
@@ -1506,7 +1553,7 @@ private:
UseListOrderStorage(bool isIndexPairEncoding,
SmallVector<unsigned, 4> &&indices)
: indices(std::move(indices)),
- isIndexPairEncoding(isIndexPairEncoding){};
+ isIndexPairEncoding(isIndexPairEncoding) {};
/// The vector containing the information required to reorder the
/// use-list of a value.
SmallVector<unsigned, 4> indices;
@@ -1651,6 +1698,11 @@ LogicalResult BytecodeReader::Impl::read(
return failure();
});
+ const auto checkSectionAlignment = [&](unsigned alignment) {
+ return this->checkSectionAlignment(
+ alignment, [&](const auto &msg) { return reader.emitError(msg); });
+ };
+
// Parse the raw data for each of the top-level sections of the bytecode.
std::optional<ArrayRef<uint8_t>>
sectionDatas[bytecode::Section::kNumSections];
@@ -1658,7 +1710,8 @@ LogicalResult BytecodeReader::Impl::read(
// Read the next section from the bytecode.
bytecode::Section::ID sectionID;
ArrayRef<uint8_t> sectionData;
- if (failed(reader.parseSection(sectionID, sectionData)))
+ if (failed(
+ reader.parseSection(sectionID, checkSectionAlignment, sectionData)))
return failure();
// Check for duplicate sections, we only expect one instance of each.
@@ -1778,6 +1831,12 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
return failure();
dialects.resize(numDialects);
+ const auto checkSectionAlignment = [&](unsigned alignment) {
+ return this->checkSectionAlignment(alignment, [&](const auto &msg) {
+ return sectionReader.emitError(msg);
+ });
+ };
+
// Parse each of the dialects.
for (uint64_t i = 0; i < numDialects; ++i) {
dialects[i] = std::make_unique<BytecodeDialect>();
@@ -1800,7 +1859,7 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
return failure();
if (versionAvailable) {
bytecode::Section::ID sectionID;
- if (failed(sectionReader.parseSection(sectionID,
+ if (failed(sectionReader.parseSection(sectionID, checkSectionAlignment,
dialects[i]->versionBuffer)))
return failure();
if (sectionID != bytecode::Section::kDialectVersions) {
@@ -2121,6 +2180,11 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
LogicalResult
BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
RegionReadState &readState) {
+ const auto checkSectionAlignment = [&](unsigned alignment) {
+ return this->checkSectionAlignment(
+ alignment, [&](const auto &msg) { return emitError(fileLoc, msg); });
+ };
+
// Process regions, blocks, and operations until the end or if a nested
// region is encountered. In this case we push a new state in regionStack and
// return, the processing of the current region will resume afterward.
@@ -2161,7 +2225,8 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) {
bytecode::Section::ID sectionID;
ArrayRef<uint8_t> sectionData;
- if (failed(reader.parseSection(sectionID, sectionData)))
+ if (failed(reader.parseSection(sectionID, checkSectionAlignment,
+ sectionData)))
return failure();
if (sectionID != bytecode::Section::kIR)
return emitError(fileLoc, "expected IR section for region");