diff options
Diffstat (limited to 'llvm/lib/Object/OffloadBundle.cpp')
-rw-r--r-- | llvm/lib/Object/OffloadBundle.cpp | 473 |
1 files changed, 473 insertions, 0 deletions
diff --git a/llvm/lib/Object/OffloadBundle.cpp b/llvm/lib/Object/OffloadBundle.cpp new file mode 100644 index 0000000..5f087a5 --- /dev/null +++ b/llvm/lib/Object/OffloadBundle.cpp @@ -0,0 +1,473 @@ +//===- OffloadBundle.cpp - Utilities for offload bundles---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------===// + +#include "llvm/Object/OffloadBundle.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/BinaryFormat/Magic.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Module.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/MC/StringTableBuilder.h" +#include "llvm/Object/Archive.h" +#include "llvm/Object/Binary.h" +#include "llvm/Object/COFF.h" +#include "llvm/Object/ELFObjectFile.h" +#include "llvm/Object/Error.h" +#include "llvm/Object/IRObjectFile.h" +#include "llvm/Object/ObjectFile.h" +#include "llvm/Support/Alignment.h" +#include "llvm/Support/BinaryStreamReader.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/Timer.h" + +using namespace llvm; +using namespace llvm::object; + +static llvm::TimerGroup + OffloadBundlerTimerGroup("Offload Bundler Timer Group", + "Timer group for offload bundler"); + +// Extract an Offload bundle (usually a Offload Bundle) from a fat_bin +// section +Error extractOffloadBundle(MemoryBufferRef Contents, uint64_t SectionOffset, + StringRef FileName, + SmallVectorImpl<OffloadBundleFatBin> &Bundles) { + + uint64_t Offset = 0; + int64_t NextbundleStart = 0; + + // There could be multiple offloading bundles stored at this section. + while (NextbundleStart >= 0) { + + std::unique_ptr<MemoryBuffer> Buffer = + MemoryBuffer::getMemBuffer(Contents.getBuffer().drop_front(Offset), "", + /*RequiresNullTerminator=*/false); + + // Create the FatBinBindle object. This will also create the Bundle Entry + // list info. + auto FatBundleOrErr = + OffloadBundleFatBin::create(*Buffer, SectionOffset + Offset, FileName); + if (!FatBundleOrErr) + return FatBundleOrErr.takeError(); + + // Add current Bundle to list. + Bundles.emplace_back(std::move(**FatBundleOrErr)); + + // Find the next bundle by searching for the magic string + StringRef Str = Buffer->getBuffer(); + NextbundleStart = + (int64_t)Str.find(StringRef("__CLANG_OFFLOAD_BUNDLE__"), 24); + + if (NextbundleStart >= 0) + Offset += NextbundleStart; + } + + return Error::success(); +} + +Error OffloadBundleFatBin::readEntries(StringRef Buffer, + uint64_t SectionOffset) { + uint64_t NumOfEntries = 0; + + BinaryStreamReader Reader(Buffer, llvm::endianness::little); + + // Read the Magic String first. + StringRef Magic; + if (auto EC = Reader.readFixedString(Magic, 24)) + return errorCodeToError(object_error::parse_failed); + + // Read the number of Code Objects (Entries) in the current Bundle. + if (auto EC = Reader.readInteger(NumOfEntries)) + return errorCodeToError(object_error::parse_failed); + + NumberOfEntries = NumOfEntries; + + // For each Bundle Entry (code object) + for (uint64_t I = 0; I < NumOfEntries; I++) { + uint64_t EntrySize; + uint64_t EntryOffset; + uint64_t EntryIDSize; + StringRef EntryID; + + if (auto EC = Reader.readInteger(EntryOffset)) + return errorCodeToError(object_error::parse_failed); + + if (auto EC = Reader.readInteger(EntrySize)) + return errorCodeToError(object_error::parse_failed); + + if (auto EC = Reader.readInteger(EntryIDSize)) + return errorCodeToError(object_error::parse_failed); + + if (auto EC = Reader.readFixedString(EntryID, EntryIDSize)) + return errorCodeToError(object_error::parse_failed); + + auto Entry = std::make_unique<OffloadBundleEntry>( + EntryOffset + SectionOffset, EntrySize, EntryIDSize, EntryID); + + Entries.push_back(*Entry); + } + + return Error::success(); +} + +Expected<std::unique_ptr<OffloadBundleFatBin>> +OffloadBundleFatBin::create(MemoryBufferRef Buf, uint64_t SectionOffset, + StringRef FileName) { + if (Buf.getBufferSize() < 24) + return errorCodeToError(object_error::parse_failed); + + // Check for magic bytes. + if (identify_magic(Buf.getBuffer()) != file_magic::offload_bundle) + return errorCodeToError(object_error::parse_failed); + + OffloadBundleFatBin *TheBundle = new OffloadBundleFatBin(Buf, FileName); + + // Read the Bundle Entries + Error Err = TheBundle->readEntries(Buf.getBuffer(), SectionOffset); + if (Err) + return errorCodeToError(object_error::parse_failed); + + return std::unique_ptr<OffloadBundleFatBin>(TheBundle); +} + +Error OffloadBundleFatBin::extractBundle(const ObjectFile &Source) { + // This will extract all entries in the Bundle + for (OffloadBundleEntry &Entry : Entries) { + + if (Entry.Size == 0) + continue; + + // create output file name. Which should be + // <fileName>-offset<Offset>-size<Size>.co" + std::string Str = getFileName().str() + "-offset" + itostr(Entry.Offset) + + "-size" + itostr(Entry.Size) + ".co"; + if (Error Err = object::extractCodeObject(Source, Entry.Offset, Entry.Size, + StringRef(Str))) + return Err; + } + + return Error::success(); +} + +Error object::extractOffloadBundleFatBinary( + const ObjectFile &Obj, SmallVectorImpl<OffloadBundleFatBin> &Bundles) { + assert((Obj.isELF() || Obj.isCOFF()) && "Invalid file type"); + + // Iterate through Sections until we find an offload_bundle section. + for (SectionRef Sec : Obj.sections()) { + Expected<StringRef> Buffer = Sec.getContents(); + if (!Buffer) + return Buffer.takeError(); + + // If it does not start with the reserved suffix, just skip this section. + if ((llvm::identify_magic(*Buffer) == llvm::file_magic::offload_bundle) || + (llvm::identify_magic(*Buffer) == + llvm::file_magic::offload_bundle_compressed)) { + + uint64_t SectionOffset = 0; + if (Obj.isELF()) { + SectionOffset = ELFSectionRef(Sec).getOffset(); + } else if (Obj.isCOFF()) // TODO: add COFF Support + return createStringError(object_error::parse_failed, + "COFF object files not supported.\n"); + + MemoryBufferRef Contents(*Buffer, Obj.getFileName()); + + if (llvm::identify_magic(*Buffer) == + llvm::file_magic::offload_bundle_compressed) { + // Decompress the input if necessary. + Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr = + CompressedOffloadBundle::decompress(Contents, false); + + if (!DecompressedBufferOrErr) + return createStringError( + inconvertibleErrorCode(), + "Failed to decompress input: " + + llvm::toString(DecompressedBufferOrErr.takeError())); + + MemoryBuffer &DecompressedInput = **DecompressedBufferOrErr; + if (Error Err = extractOffloadBundle(DecompressedInput, SectionOffset, + Obj.getFileName(), Bundles)) + return Err; + } else { + if (Error Err = extractOffloadBundle(Contents, SectionOffset, + Obj.getFileName(), Bundles)) + return Err; + } + } + } + return Error::success(); +} + +Error object::extractCodeObject(const ObjectFile &Source, int64_t Offset, + int64_t Size, StringRef OutputFileName) { + Expected<std::unique_ptr<FileOutputBuffer>> BufferOrErr = + FileOutputBuffer::create(OutputFileName, Size); + + if (!BufferOrErr) + return BufferOrErr.takeError(); + + Expected<MemoryBufferRef> InputBuffOrErr = Source.getMemoryBufferRef(); + if (Error Err = InputBuffOrErr.takeError()) + return Err; + + std::unique_ptr<FileOutputBuffer> Buf = std::move(*BufferOrErr); + std::copy(InputBuffOrErr->getBufferStart() + Offset, + InputBuffOrErr->getBufferStart() + Offset + Size, + Buf->getBufferStart()); + if (Error E = Buf->commit()) + return E; + + return Error::success(); +} + +// given a file name, offset, and size, extract data into a code object file, +// into file <SourceFile>-offset<Offset>-size<Size>.co +Error object::extractOffloadBundleByURI(StringRef URIstr) { + // create a URI object + Expected<std::unique_ptr<OffloadBundleURI>> UriOrErr( + OffloadBundleURI::createOffloadBundleURI(URIstr, FILE_URI)); + if (!UriOrErr) + return UriOrErr.takeError(); + + OffloadBundleURI &Uri = **UriOrErr; + std::string OutputFile = Uri.FileName.str(); + OutputFile += + "-offset" + itostr(Uri.Offset) + "-size" + itostr(Uri.Size) + ".co"; + + // Create an ObjectFile object from uri.file_uri + auto ObjOrErr = ObjectFile::createObjectFile(Uri.FileName); + if (!ObjOrErr) + return ObjOrErr.takeError(); + + auto Obj = ObjOrErr->getBinary(); + if (Error Err = + object::extractCodeObject(*Obj, Uri.Offset, Uri.Size, OutputFile)) + return Err; + + return Error::success(); +} + +// Utility function to format numbers with commas +static std::string formatWithCommas(unsigned long long Value) { + std::string Num = std::to_string(Value); + int InsertPosition = Num.length() - 3; + while (InsertPosition > 0) { + Num.insert(InsertPosition, ","); + InsertPosition -= 3; + } + return Num; +} + +llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> +CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input, + bool Verbose) { + StringRef Blob = Input.getBuffer(); + + if (Blob.size() < V1HeaderSize) + return llvm::MemoryBuffer::getMemBufferCopy(Blob); + + if (llvm::identify_magic(Blob) != + llvm::file_magic::offload_bundle_compressed) { + if (Verbose) + llvm::errs() << "Uncompressed bundle.\n"; + return llvm::MemoryBuffer::getMemBufferCopy(Blob); + } + + size_t CurrentOffset = MagicSize; + + uint16_t ThisVersion; + memcpy(&ThisVersion, Blob.data() + CurrentOffset, sizeof(uint16_t)); + CurrentOffset += VersionFieldSize; + + uint16_t CompressionMethod; + memcpy(&CompressionMethod, Blob.data() + CurrentOffset, sizeof(uint16_t)); + CurrentOffset += MethodFieldSize; + + uint32_t TotalFileSize; + if (ThisVersion >= 2) { + if (Blob.size() < V2HeaderSize) + return createStringError(inconvertibleErrorCode(), + "Compressed bundle header size too small"); + memcpy(&TotalFileSize, Blob.data() + CurrentOffset, sizeof(uint32_t)); + CurrentOffset += FileSizeFieldSize; + } + + uint32_t UncompressedSize; + memcpy(&UncompressedSize, Blob.data() + CurrentOffset, sizeof(uint32_t)); + CurrentOffset += UncompressedSizeFieldSize; + + uint64_t StoredHash; + memcpy(&StoredHash, Blob.data() + CurrentOffset, sizeof(uint64_t)); + CurrentOffset += HashFieldSize; + + llvm::compression::Format CompressionFormat; + if (CompressionMethod == + static_cast<uint16_t>(llvm::compression::Format::Zlib)) + CompressionFormat = llvm::compression::Format::Zlib; + else if (CompressionMethod == + static_cast<uint16_t>(llvm::compression::Format::Zstd)) + CompressionFormat = llvm::compression::Format::Zstd; + else + return createStringError(inconvertibleErrorCode(), + "Unknown compressing method"); + + llvm::Timer DecompressTimer("Decompression Timer", "Decompression time", + OffloadBundlerTimerGroup); + if (Verbose) + DecompressTimer.startTimer(); + + SmallVector<uint8_t, 0> DecompressedData; + StringRef CompressedData = Blob.substr(CurrentOffset); + if (llvm::Error DecompressionError = llvm::compression::decompress( + CompressionFormat, llvm::arrayRefFromStringRef(CompressedData), + DecompressedData, UncompressedSize)) + return createStringError(inconvertibleErrorCode(), + "Could not decompress embedded file contents: " + + llvm::toString(std::move(DecompressionError))); + + if (Verbose) { + DecompressTimer.stopTimer(); + + double DecompressionTimeSeconds = + DecompressTimer.getTotalTime().getWallTime(); + + // Recalculate MD5 hash for integrity check. + llvm::Timer HashRecalcTimer("Hash Recalculation Timer", + "Hash recalculation time", + OffloadBundlerTimerGroup); + HashRecalcTimer.startTimer(); + llvm::MD5 Hash; + llvm::MD5::MD5Result Result; + Hash.update(llvm::ArrayRef<uint8_t>(DecompressedData.data(), + DecompressedData.size())); + Hash.final(Result); + uint64_t RecalculatedHash = Result.low(); + HashRecalcTimer.stopTimer(); + bool HashMatch = (StoredHash == RecalculatedHash); + + double CompressionRate = + static_cast<double>(UncompressedSize) / CompressedData.size(); + double DecompressionSpeedMBs = + (UncompressedSize / (1024.0 * 1024.0)) / DecompressionTimeSeconds; + + llvm::errs() << "Compressed bundle format version: " << ThisVersion << "\n"; + if (ThisVersion >= 2) + llvm::errs() << "Total file size (from header): " + << formatWithCommas(TotalFileSize) << " bytes\n"; + llvm::errs() << "Decompression method: " + << (CompressionFormat == llvm::compression::Format::Zlib + ? "zlib" + : "zstd") + << "\n" + << "Size before decompression: " + << formatWithCommas(CompressedData.size()) << " bytes\n" + << "Size after decompression: " + << formatWithCommas(UncompressedSize) << " bytes\n" + << "Compression rate: " + << llvm::format("%.2lf", CompressionRate) << "\n" + << "Compression ratio: " + << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n" + << "Decompression speed: " + << llvm::format("%.2lf MB/s", DecompressionSpeedMBs) << "\n" + << "Stored hash: " << llvm::format_hex(StoredHash, 16) << "\n" + << "Recalculated hash: " + << llvm::format_hex(RecalculatedHash, 16) << "\n" + << "Hashes match: " << (HashMatch ? "Yes" : "No") << "\n"; + } + + return llvm::MemoryBuffer::getMemBufferCopy( + llvm::toStringRef(DecompressedData)); +} + +llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> +CompressedOffloadBundle::compress(llvm::compression::Params P, + const llvm::MemoryBuffer &Input, + bool Verbose) { + if (!llvm::compression::zstd::isAvailable() && + !llvm::compression::zlib::isAvailable()) + return createStringError(llvm::inconvertibleErrorCode(), + "Compression not supported"); + + llvm::Timer HashTimer("Hash Calculation Timer", "Hash calculation time", + OffloadBundlerTimerGroup); + if (Verbose) + HashTimer.startTimer(); + llvm::MD5 Hash; + llvm::MD5::MD5Result Result; + Hash.update(Input.getBuffer()); + Hash.final(Result); + uint64_t TruncatedHash = Result.low(); + if (Verbose) + HashTimer.stopTimer(); + + SmallVector<uint8_t, 0> CompressedBuffer; + auto BufferUint8 = llvm::ArrayRef<uint8_t>( + reinterpret_cast<const uint8_t *>(Input.getBuffer().data()), + Input.getBuffer().size()); + + llvm::Timer CompressTimer("Compression Timer", "Compression time", + OffloadBundlerTimerGroup); + if (Verbose) + CompressTimer.startTimer(); + llvm::compression::compress(P, BufferUint8, CompressedBuffer); + if (Verbose) + CompressTimer.stopTimer(); + + uint16_t CompressionMethod = static_cast<uint16_t>(P.format); + uint32_t UncompressedSize = Input.getBuffer().size(); + uint32_t TotalFileSize = MagicNumber.size() + sizeof(TotalFileSize) + + sizeof(Version) + sizeof(CompressionMethod) + + sizeof(UncompressedSize) + sizeof(TruncatedHash) + + CompressedBuffer.size(); + + SmallVector<char, 0> FinalBuffer; + llvm::raw_svector_ostream OS(FinalBuffer); + OS << MagicNumber; + OS.write(reinterpret_cast<const char *>(&Version), sizeof(Version)); + OS.write(reinterpret_cast<const char *>(&CompressionMethod), + sizeof(CompressionMethod)); + OS.write(reinterpret_cast<const char *>(&TotalFileSize), + sizeof(TotalFileSize)); + OS.write(reinterpret_cast<const char *>(&UncompressedSize), + sizeof(UncompressedSize)); + OS.write(reinterpret_cast<const char *>(&TruncatedHash), + sizeof(TruncatedHash)); + OS.write(reinterpret_cast<const char *>(CompressedBuffer.data()), + CompressedBuffer.size()); + + if (Verbose) { + auto MethodUsed = + P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib"; + double CompressionRate = + static_cast<double>(UncompressedSize) / CompressedBuffer.size(); + double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime(); + double CompressionSpeedMBs = + (UncompressedSize / (1024.0 * 1024.0)) / CompressionTimeSeconds; + + llvm::errs() << "Compressed bundle format version: " << Version << "\n" + << "Total file size (including headers): " + << formatWithCommas(TotalFileSize) << " bytes\n" + << "Compression method used: " << MethodUsed << "\n" + << "Compression level: " << P.level << "\n" + << "Binary size before compression: " + << formatWithCommas(UncompressedSize) << " bytes\n" + << "Binary size after compression: " + << formatWithCommas(CompressedBuffer.size()) << " bytes\n" + << "Compression rate: " + << llvm::format("%.2lf", CompressionRate) << "\n" + << "Compression ratio: " + << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n" + << "Compression speed: " + << llvm::format("%.2lf MB/s", CompressionSpeedMBs) << "\n" + << "Truncated MD5 hash: " + << llvm::format_hex(TruncatedHash, 16) << "\n"; + } + return llvm::MemoryBuffer::getMemBufferCopy( + llvm::StringRef(FinalBuffer.data(), FinalBuffer.size())); +} |