Diffstat (limited to 'llvm/lib')
87 files changed, 3676 insertions, 926 deletions
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index 87fae92..47dccde 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -234,9 +234,14 @@ static bool evaluatePtrAddRecAtMaxBTCWillNotWrap( // Check if we have a suitable dereferencable assumption we can use. if (!StartPtrV->canBeFreed()) { + Instruction *CtxI = &*L->getHeader()->getFirstNonPHIIt(); + if (BasicBlock *LoopPred = L->getLoopPredecessor()) { + if (isa<BranchInst>(LoopPred->getTerminator())) + CtxI = LoopPred->getTerminator(); + } + RetainedKnowledge DerefRK = getKnowledgeValidInContext( - StartPtrV, {Attribute::Dereferenceable}, *AC, - L->getLoopPredecessor()->getTerminator(), DT); + StartPtrV, {Attribute::Dereferenceable}, *AC, CtxI, DT); if (DerefRK) { DerefBytesSCEV = SE.getUMaxExpr(DerefBytesSCEV, SE.getSCEV(DerefRK.IRArgValue)); @@ -2856,8 +2861,9 @@ void LoopAccessInfo::emitUnsafeDependenceRemark() { } } -bool LoopAccessInfo::blockNeedsPredication(BasicBlock *BB, Loop *TheLoop, - DominatorTree *DT) { +bool LoopAccessInfo::blockNeedsPredication(const BasicBlock *BB, + const Loop *TheLoop, + const DominatorTree *DT) { assert(TheLoop->contains(BB) && "Unknown block used"); // Blocks that do not dominate the latch need predication. diff --git a/llvm/lib/CAS/CMakeLists.txt b/llvm/lib/CAS/CMakeLists.txt index 6ed724b..7ae5f7e 100644 --- a/llvm/lib/CAS/CMakeLists.txt +++ b/llvm/lib/CAS/CMakeLists.txt @@ -2,14 +2,19 @@ add_llvm_component_library(LLVMCAS ActionCache.cpp ActionCaches.cpp BuiltinCAS.cpp + DatabaseFile.cpp InMemoryCAS.cpp MappedFileRegionArena.cpp ObjectStore.cpp OnDiskCommon.cpp + OnDiskTrieRawHashMap.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/CAS + LINK_LIBS + ${LLVM_PTHREAD_LIB} + LINK_COMPONENTS Support ) diff --git a/llvm/lib/CAS/DatabaseFile.cpp b/llvm/lib/CAS/DatabaseFile.cpp new file mode 100644 index 0000000..db8ce1d --- /dev/null +++ b/llvm/lib/CAS/DatabaseFile.cpp @@ -0,0 +1,123 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file This file implements the common abstractions for CAS database file. +/// +//===----------------------------------------------------------------------===// + +#include "DatabaseFile.h" + +using namespace llvm; +using namespace llvm::cas; +using namespace llvm::cas::ondisk; + +Error ondisk::createTableConfigError(std::errc ErrC, StringRef Path, + StringRef TableName, const Twine &Msg) { + return createStringError(make_error_code(ErrC), + Path + "[" + TableName + "]: " + Msg); +} + +Error ondisk::checkTable(StringRef Label, size_t Expected, size_t Observed, + StringRef Path, StringRef TrieName) { + if (Expected == Observed) + return Error::success(); + return createTableConfigError(std::errc::invalid_argument, Path, TrieName, + "mismatched " + Label + + " (expected: " + Twine(Expected) + + ", observed: " + Twine(Observed) + ")"); +} + +Expected<DatabaseFile> +DatabaseFile::create(const Twine &Path, uint64_t Capacity, + function_ref<Error(DatabaseFile &)> NewDBConstructor) { + // Constructor for if the file doesn't exist. 
+  auto NewFileConstructor = [&](MappedFileRegionArena &Alloc) -> Error {
+    if (Alloc.capacity() <
+        sizeof(Header) + sizeof(MappedFileRegionArena::Header))
+      return createTableConfigError(std::errc::argument_out_of_domain,
+                                    Path.str(), "datafile",
+                                    "Allocator too small for header");
+    (void)new (Alloc.data()) Header{getMagic(), getVersion(), {0}};
+    DatabaseFile DB(Alloc);
+    return NewDBConstructor(DB);
+  };
+
+  // Get or create the file.
+  MappedFileRegionArena Alloc;
+  if (Error E = MappedFileRegionArena::create(Path, Capacity, sizeof(Header),
+                                              NewFileConstructor)
+                    .moveInto(Alloc))
+    return std::move(E);
+
+  return DatabaseFile::get(
+      std::make_unique<MappedFileRegionArena>(std::move(Alloc)));
+}
+
+Error DatabaseFile::addTable(TableHandle Table) {
+  assert(Table);
+  assert(&Table.getRegion() == &getRegion());
+  int64_t ExistingRootOffset = 0;
+  const int64_t NewOffset =
+      reinterpret_cast<const char *>(&Table.getHeader()) - getRegion().data();
+  if (H->RootTableOffset.compare_exchange_strong(ExistingRootOffset, NewOffset))
+    return Error::success();
+
+  // Silently ignore attempts to set the root to itself.
+  if (ExistingRootOffset == NewOffset)
+    return Error::success();
+
+  // Return a proper error message.
+  TableHandle Root(getRegion(), ExistingRootOffset);
+  if (Root.getName() == Table.getName())
+    return createStringError(
+        make_error_code(std::errc::not_supported),
+        "collision with existing table of the same name '" + Table.getName() +
+            "'");
+
+  return createStringError(make_error_code(std::errc::not_supported),
+                           "cannot add new table '" + Table.getName() +
+                               "'"
+                               " to existing root '" +
+                               Root.getName() + "'");
+}
+
+std::optional<TableHandle> DatabaseFile::findTable(StringRef Name) {
+  int64_t RootTableOffset = H->RootTableOffset.load();
+  if (!RootTableOffset)
+    return std::nullopt;
+
+  TableHandle Root(getRegion(), RootTableOffset);
+  if (Root.getName() == Name)
+    return Root;
+
+  return std::nullopt;
+}
+
+Error DatabaseFile::validate(MappedFileRegion &Region) {
+  if (Region.size() < sizeof(Header))
+    return createStringError(std::errc::invalid_argument,
+                             "database: missing header");
+
+  // Check the magic and version.
+  auto *H = reinterpret_cast<Header *>(Region.data());
+  if (H->Magic != getMagic())
+    return createStringError(std::errc::invalid_argument,
+                             "database: bad magic");
+  if (H->Version != getVersion())
+    return createStringError(std::errc::invalid_argument,
+                             "database: wrong version");
+
+  auto *MFH = reinterpret_cast<MappedFileRegionArena::Header *>(Region.data() +
+                                                                sizeof(Header));
+  // Check the bump-ptr, which should point past the header.
+  if (MFH->BumpPtr.load() < (int64_t)sizeof(Header))
+    return createStringError(std::errc::invalid_argument,
+                             "database: corrupt bump-ptr");
+
+  return Error::success();
+}
diff --git a/llvm/lib/CAS/DatabaseFile.h b/llvm/lib/CAS/DatabaseFile.h
new file mode 100644
index 0000000..609e5f13
--- /dev/null
+++ b/llvm/lib/CAS/DatabaseFile.h
@@ -0,0 +1,153 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This file declares the common interface for a DatabaseFile that is used to
+/// implement OnDiskCAS.
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_CAS_DATABASEFILE_H +#define LLVM_LIB_CAS_DATABASEFILE_H + +#include "llvm/ADT/StringRef.h" +#include "llvm/CAS/MappedFileRegionArena.h" +#include "llvm/Support/Error.h" + +namespace llvm::cas::ondisk { + +using MappedFileRegion = MappedFileRegionArena::RegionT; + +/// Generic handle for a table. +/// +/// Generic table header layout: +/// - 2-bytes: TableKind +/// - 2-bytes: TableNameSize +/// - 4-bytes: TableNameRelOffset (relative to header) +class TableHandle { +public: + enum class TableKind : uint16_t { + TrieRawHashMap = 1, + DataAllocator = 2, + }; + struct Header { + TableKind Kind; + uint16_t NameSize; + int32_t NameRelOffset; ///< Relative to Header. + }; + + explicit operator bool() const { return H; } + const Header &getHeader() const { return *H; } + MappedFileRegion &getRegion() const { return *Region; } + + template <class T> static void check() { + static_assert( + std::is_same<decltype(T::Header::GenericHeader), Header>::value, + "T::GenericHeader should be of type TableHandle::Header"); + static_assert(offsetof(typename T::Header, GenericHeader) == 0, + "T::GenericHeader must be the head of T::Header"); + } + template <class T> bool is() const { return T::Kind == H->Kind; } + template <class T> T dyn_cast() const { + check<T>(); + if (is<T>()) + return T(*Region, *reinterpret_cast<typename T::Header *>(H)); + return T(); + } + template <class T> T cast() const { + assert(is<T>()); + return dyn_cast<T>(); + } + + StringRef getName() const { + auto *Begin = reinterpret_cast<const char *>(H) + H->NameRelOffset; + return StringRef(Begin, H->NameSize); + } + + TableHandle() = default; + TableHandle(MappedFileRegion &Region, Header &H) : Region(&Region), H(&H) {} + TableHandle(MappedFileRegion &Region, intptr_t HeaderOffset) + : TableHandle(Region, + *reinterpret_cast<Header *>(Region.data() + HeaderOffset)) { + } + +private: + MappedFileRegion *Region = nullptr; + Header *H = nullptr; +}; + +/// Encapsulate a database file, which: +/// - Sets/checks magic. +/// - Sets/checks version. +/// - Points at an arbitrary root table. +/// - Sets up a MappedFileRegionArena for allocation. +/// +/// Top-level layout: +/// - 4-bytes: Magic +/// - 4-bytes: Version +/// - 8-bytes: RootTableOffset (16-bits: Kind; 48-bits: Offset) +/// - 8-bytes: BumpPtr from MappedFileRegionArena +class DatabaseFile { +public: + static constexpr uint32_t getMagic() { return 0xDA7ABA53UL; } + static constexpr uint32_t getVersion() { return 1UL; } + struct Header { + uint32_t Magic; + uint32_t Version; + std::atomic<int64_t> RootTableOffset; + }; + + const Header &getHeader() { return *H; } + MappedFileRegionArena &getAlloc() { return Alloc; } + MappedFileRegion &getRegion() { return Alloc.getRegion(); } + + /// Add a table. This is currently not thread safe and should be called inside + /// NewDBConstructor. + Error addTable(TableHandle Table); + + /// Find a table. May return null. + std::optional<TableHandle> findTable(StringRef Name); + + /// Create the DatabaseFile at Path with Capacity. 
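+  ///
+  /// A usage sketch for illustration only (the path, capacity, and
+  /// constructor body are assumptions of this example, not part of the
+  /// patch):
+  ///
+  /// \code
+  ///   auto DB = DatabaseFile::create(
+  ///       "/tmp/cas.db", /*Capacity=*/16 * 1024 * 1024,
+  ///       [](DatabaseFile &DB) -> Error {
+  ///         // Typically creates the root table here and registers it
+  ///         // via DB.addTable(...).
+  ///         return Error::success();
+  ///       });
+  ///   if (!DB)
+  ///     return DB.takeError();
+  /// \endcode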
+  static Expected<DatabaseFile>
+  create(const Twine &Path, uint64_t Capacity,
+         function_ref<Error(DatabaseFile &)> NewDBConstructor);
+
+  size_t size() const { return Alloc.size(); }
+
+private:
+  static Expected<DatabaseFile>
+  get(std::unique_ptr<MappedFileRegionArena> Alloc) {
+    if (Error E = validate(Alloc->getRegion()))
+      return std::move(E);
+    return DatabaseFile(std::move(Alloc));
+  }
+
+  static Error validate(MappedFileRegion &Region);
+
+  DatabaseFile(MappedFileRegionArena &Alloc)
+      : H(reinterpret_cast<Header *>(Alloc.data())), Alloc(Alloc) {}
+  DatabaseFile(std::unique_ptr<MappedFileRegionArena> Alloc)
+      : DatabaseFile(*Alloc) {
+    OwnedAlloc = std::move(Alloc);
+  }
+
+  Header *H = nullptr;
+  MappedFileRegionArena &Alloc;
+  std::unique_ptr<MappedFileRegionArena> OwnedAlloc;
+};
+
+Error createTableConfigError(std::errc ErrC, StringRef Path,
+                             StringRef TableName, const Twine &Msg);
+
+Error checkTable(StringRef Label, size_t Expected, size_t Observed,
+                 StringRef Path, StringRef TrieName);
+
+} // namespace llvm::cas::ondisk
+
+#endif
diff --git a/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp b/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp
new file mode 100644
index 0000000..9b382dd7
--- /dev/null
+++ b/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp
@@ -0,0 +1,1178 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file Implements OnDiskTrieRawHashMap.
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CAS/OnDiskTrieRawHashMap.h"
+#include "DatabaseFile.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TrieHashIndexGenerator.h"
+#include "llvm/CAS/MappedFileRegionArena.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/Support/ThreadPool.h"
+#include "llvm/Support/Threading.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+using namespace llvm::cas;
+using namespace llvm::cas::ondisk;
+
+#if LLVM_ENABLE_ONDISK_CAS
+
+//===----------------------------------------------------------------------===//
+// TrieRawHashMap data structures.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class SubtrieHandle;
+class TrieRawHashMapHandle;
+class TrieVisitor;
+
+/// A value stored in the slots inside a SubTrie. A stored value can either be
+/// a subtrie (encoded after negation), in which case it is the file offset to
+/// another subtrie, or it can be a file offset to a DataRecord.
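+///
+/// For illustration only (the values below are made up, not code from this
+/// patch), the sign encoding behaves as follows:
+///
+/// \code
+///   SubtrieSlotValue::getDataOffset(64).isData();       // true
+///   SubtrieSlotValue::getSubtrieOffset(64).isSubtrie(); // true (stored -64)
+///   SubtrieSlotValue::getSubtrieOffset(64).asSubtrie(); // 64
+/// \endcode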
+class SubtrieSlotValue {
+public:
+  explicit operator bool() const { return !isEmpty(); }
+  bool isEmpty() const { return !Offset; }
+  bool isData() const { return Offset > 0; }
+  bool isSubtrie() const { return Offset < 0; }
+  uint64_t asData() const {
+    assert(isData());
+    return Offset;
+  }
+  uint64_t asSubtrie() const {
+    assert(isSubtrie());
+    return -Offset;
+  }
+
+  FileOffset asSubtrieFileOffset() const { return FileOffset(asSubtrie()); }
+
+  FileOffset asDataFileOffset() const { return FileOffset(asData()); }
+
+  int64_t getRawOffset() const { return Offset; }
+
+  static SubtrieSlotValue getDataOffset(int64_t Offset) {
+    return SubtrieSlotValue(Offset);
+  }
+
+  static SubtrieSlotValue getSubtrieOffset(int64_t Offset) {
+    return SubtrieSlotValue(-Offset);
+  }
+
+  static SubtrieSlotValue getDataOffset(FileOffset Offset) {
+    return getDataOffset(Offset.get());
+  }
+
+  static SubtrieSlotValue getSubtrieOffset(FileOffset Offset) {
+    return getSubtrieOffset(Offset.get());
+  }
+
+  static SubtrieSlotValue getFromSlot(std::atomic<int64_t> &Slot) {
+    return SubtrieSlotValue(Slot.load());
+  }
+
+  SubtrieSlotValue() = default;
+
+private:
+  friend class SubtrieHandle;
+  explicit SubtrieSlotValue(int64_t Offset) : Offset(Offset) {}
+  int64_t Offset = 0;
+};
+
+/// Subtrie layout:
+/// - 2-bytes: StartBit
+/// - 1-bytes: NumBits=lg(num-slots)
+/// - 5-bytes: 0-pad
+/// - <slots>
+class SubtrieHandle {
+public:
+  struct Header {
+    /// The bit this subtrie starts on.
+    uint16_t StartBit;
+
+    /// The number of bits this subtrie handles. It has 2^NumBits slots.
+    uint8_t NumBits;
+
+    /// 0-pad to 8B.
+    uint8_t ZeroPad1B;
+    uint32_t ZeroPad4B;
+  };
+
+  /// Slot storage:
+  /// - zero: Empty
+  /// - positive: RecordOffset
+  /// - negative: SubtrieOffset
+  using SlotT = std::atomic<int64_t>;
+
+  static int64_t getSlotsSize(uint32_t NumBits) {
+    return sizeof(int64_t) * (1u << NumBits);
+  }
+
+  static int64_t getSize(uint32_t NumBits) {
+    return sizeof(SubtrieHandle::Header) + getSlotsSize(NumBits);
+  }
+
+  int64_t getSize() const { return getSize(H->NumBits); }
+  size_t getNumSlots() const { return Slots.size(); }
+
+  SubtrieSlotValue load(size_t I) const {
+    return SubtrieSlotValue(Slots[I].load());
+  }
+  void store(size_t I, SubtrieSlotValue V) {
+    return Slots[I].store(V.getRawOffset());
+  }
+
+  void printHash(raw_ostream &OS, ArrayRef<uint8_t> Bytes) const;
+
+  /// Return true on success. On failure, \p Expected is updated to the
+  /// existing offset.
+  bool compare_exchange_strong(size_t I, SubtrieSlotValue &Expected,
+                               SubtrieSlotValue New) {
+    return Slots[I].compare_exchange_strong(Expected.Offset, New.Offset);
+  }
+
+  /// Sink \p V from \p I in this subtrie down to \p NewI in a new subtrie with
+  /// \p NumSubtrieBits.
+  ///
+  /// \p UnusedSubtrie maintains a 1-item "free" list of unused subtries. If a
+  /// new subtrie is created that isn't used because of a lost race, then it
+  /// should be returned as an out parameter to be passed back in the future.
+  /// If it's already valid, it should be used instead of allocating a new one.
+  ///
+  /// Returns the subtrie that now lives at \p I.
+  Expected<SubtrieHandle> sink(size_t I, SubtrieSlotValue V,
+                               MappedFileRegionArena &Alloc,
+                               size_t NumSubtrieBits,
+                               SubtrieHandle &UnusedSubtrie, size_t NewI);
+
+  /// Only safe if the subtrie is empty.
+ void reinitialize(uint32_t StartBit, uint32_t NumBits); + + SubtrieSlotValue getOffset() const { + return SubtrieSlotValue::getSubtrieOffset( + reinterpret_cast<const char *>(H) - Region->data()); + } + + FileOffset getFileOffset() const { return getOffset().asSubtrieFileOffset(); } + + explicit operator bool() const { return H; } + + Header &getHeader() const { return *H; } + uint32_t getStartBit() const { return H->StartBit; } + uint32_t getNumBits() const { return H->NumBits; } + + static Expected<SubtrieHandle> create(MappedFileRegionArena &Alloc, + uint32_t StartBit, uint32_t NumBits); + + static SubtrieHandle getFromFileOffset(MappedFileRegion &Region, + FileOffset Offset) { + return SubtrieHandle(Region, SubtrieSlotValue::getSubtrieOffset(Offset)); + } + + SubtrieHandle() = default; + SubtrieHandle(MappedFileRegion &Region, Header &H) + : Region(&Region), H(&H), Slots(getSlots(H)) {} + SubtrieHandle(MappedFileRegion &Region, SubtrieSlotValue Offset) + : SubtrieHandle(Region, *reinterpret_cast<Header *>( + Region.data() + Offset.asSubtrie())) {} + +private: + MappedFileRegion *Region = nullptr; + Header *H = nullptr; + MutableArrayRef<SlotT> Slots; + + static MutableArrayRef<SlotT> getSlots(Header &H) { + return MutableArrayRef(reinterpret_cast<SlotT *>(&H + 1), 1u << H.NumBits); + } +}; + +/// Handle for a TrieRawHashMap table. +/// +/// TrieRawHashMap table layout: +/// - [8-bytes: Generic table header] +/// - 1-byte: NumSubtrieBits +/// - 1-byte: Flags (not used yet) +/// - 2-bytes: NumHashBits +/// - 4-bytes: RecordDataSize (in bytes) +/// - 8-bytes: RootTrieOffset +/// - 8-bytes: AllocatorOffset (reserved for implementing free lists) +/// - <name> '\0' +/// +/// Record layout: +/// - <hash> +/// - <data> +class TrieRawHashMapHandle { +public: + static constexpr TableHandle::TableKind Kind = + TableHandle::TableKind::TrieRawHashMap; + + struct Header { + TableHandle::Header GenericHeader; + uint8_t NumSubtrieBits; + uint8_t Flags; ///< None used yet. + uint16_t NumHashBits; + uint32_t RecordDataSize; + std::atomic<int64_t> RootTrieOffset; + std::atomic<int64_t> AllocatorOffset; + }; + + operator TableHandle() const { + if (!H) + return TableHandle(); + return TableHandle(*Region, H->GenericHeader); + } + + struct RecordData { + OnDiskTrieRawHashMap::ValueProxy Proxy; + SubtrieSlotValue Offset; + FileOffset getFileOffset() const { return Offset.asDataFileOffset(); } + }; + + enum Limits : size_t { + /// Seems like 65528 hash bits ought to be enough. + MaxNumHashBytes = UINT16_MAX >> 3, + MaxNumHashBits = MaxNumHashBytes << 3, + + /// 2^16 bits in a trie is 65536 slots. This restricts us to a 16-bit + /// index. This many slots is suspicously large anyway. + MaxNumRootBits = 16, + + /// 2^10 bits in a trie is 1024 slots. This many slots seems suspiciously + /// large for subtries. 
+ MaxNumSubtrieBits = 10, + }; + + static constexpr size_t getNumHashBytes(size_t NumHashBits) { + assert(NumHashBits % 8 == 0); + return NumHashBits / 8; + } + static constexpr size_t getRecordSize(size_t RecordDataSize, + size_t NumHashBits) { + return RecordDataSize + getNumHashBytes(NumHashBits); + } + + RecordData getRecord(SubtrieSlotValue Offset); + Expected<RecordData> createRecord(MappedFileRegionArena &Alloc, + ArrayRef<uint8_t> Hash); + + explicit operator bool() const { return H; } + const Header &getHeader() const { return *H; } + SubtrieHandle getRoot() const; + Expected<SubtrieHandle> getOrCreateRoot(MappedFileRegionArena &Alloc); + MappedFileRegion &getRegion() const { return *Region; } + + size_t getFlags() const { return H->Flags; } + size_t getNumSubtrieBits() const { return H->NumSubtrieBits; } + size_t getNumHashBits() const { return H->NumHashBits; } + size_t getNumHashBytes() const { return getNumHashBytes(H->NumHashBits); } + size_t getRecordDataSize() const { return H->RecordDataSize; } + size_t getRecordSize() const { + return getRecordSize(H->RecordDataSize, H->NumHashBits); + } + + TrieHashIndexGenerator getIndexGen(SubtrieHandle Root, + ArrayRef<uint8_t> Hash) { + assert(Root.getStartBit() == 0); + assert(getNumHashBytes() == Hash.size()); + assert(getNumHashBits() == Hash.size() * 8); + return TrieHashIndexGenerator{Root.getNumBits(), getNumSubtrieBits(), Hash}; + } + + static Expected<TrieRawHashMapHandle> + create(MappedFileRegionArena &Alloc, StringRef Name, + std::optional<uint64_t> NumRootBits, uint64_t NumSubtrieBits, + uint64_t NumHashBits, uint64_t RecordDataSize); + + void + print(raw_ostream &OS, + function_ref<void(ArrayRef<char>)> PrintRecordData = nullptr) const; + + Error validate( + function_ref<Error(FileOffset, OnDiskTrieRawHashMap::ConstValueProxy)> + RecordVerifier) const; + TrieRawHashMapHandle() = default; + TrieRawHashMapHandle(MappedFileRegion &Region, Header &H) + : Region(&Region), H(&H) {} + TrieRawHashMapHandle(MappedFileRegion &Region, intptr_t HeaderOffset) + : TrieRawHashMapHandle( + Region, *reinterpret_cast<Header *>(Region.data() + HeaderOffset)) { + } + +private: + MappedFileRegion *Region = nullptr; + Header *H = nullptr; +}; + +} // end anonymous namespace + +struct OnDiskTrieRawHashMap::ImplType { + DatabaseFile File; + TrieRawHashMapHandle Trie; +}; + +Expected<SubtrieHandle> SubtrieHandle::create(MappedFileRegionArena &Alloc, + uint32_t StartBit, + uint32_t NumBits) { + assert(StartBit <= TrieRawHashMapHandle::MaxNumHashBits); + assert(NumBits <= UINT8_MAX); + assert(NumBits <= TrieRawHashMapHandle::MaxNumRootBits); + + auto Mem = Alloc.allocate(getSize(NumBits)); + if (LLVM_UNLIKELY(!Mem)) + return Mem.takeError(); + auto *H = + new (*Mem) SubtrieHandle::Header{(uint16_t)StartBit, (uint8_t)NumBits, + /*ZeroPad1B=*/0, /*ZeroPad4B=*/0}; + SubtrieHandle S(Alloc.getRegion(), *H); + for (auto I = S.Slots.begin(), E = S.Slots.end(); I != E; ++I) + new (I) SlotT(0); + return S; +} + +SubtrieHandle TrieRawHashMapHandle::getRoot() const { + if (int64_t Root = H->RootTrieOffset) + return SubtrieHandle(getRegion(), SubtrieSlotValue::getSubtrieOffset(Root)); + return SubtrieHandle(); +} + +Expected<SubtrieHandle> +TrieRawHashMapHandle::getOrCreateRoot(MappedFileRegionArena &Alloc) { + assert(&Alloc.getRegion() == &getRegion()); + if (SubtrieHandle Root = getRoot()) + return Root; + + int64_t Race = 0; + auto LazyRoot = SubtrieHandle::create(Alloc, 0, H->NumSubtrieBits); + if (LLVM_UNLIKELY(!LazyRoot)) + return LazyRoot.takeError(); 
+ if (H->RootTrieOffset.compare_exchange_strong( + Race, LazyRoot->getOffset().asSubtrie())) + return *LazyRoot; + + // There was a race. Return the other root. + // + // TODO: Avoid leaking the lazy root by storing it in an allocator. + return SubtrieHandle(getRegion(), SubtrieSlotValue::getSubtrieOffset(Race)); +} + +Expected<TrieRawHashMapHandle> +TrieRawHashMapHandle::create(MappedFileRegionArena &Alloc, StringRef Name, + std::optional<uint64_t> NumRootBits, + uint64_t NumSubtrieBits, uint64_t NumHashBits, + uint64_t RecordDataSize) { + // Allocate. + auto Offset = Alloc.allocateOffset(sizeof(Header) + Name.size() + 1); + if (LLVM_UNLIKELY(!Offset)) + return Offset.takeError(); + + // Construct the header and the name. + assert(Name.size() <= UINT16_MAX && "Expected smaller table name"); + assert(NumSubtrieBits <= UINT8_MAX && "Expected valid subtrie bits"); + assert(NumHashBits <= UINT16_MAX && "Expected valid hash size"); + assert(RecordDataSize <= UINT32_MAX && "Expected smaller table name"); + auto *H = new (Alloc.getRegion().data() + *Offset) + Header{{TableHandle::TableKind::TrieRawHashMap, (uint16_t)Name.size(), + (uint32_t)sizeof(Header)}, + (uint8_t)NumSubtrieBits, + /*Flags=*/0, + (uint16_t)NumHashBits, + (uint32_t)RecordDataSize, + /*RootTrieOffset=*/{0}, + /*AllocatorOffset=*/{0}}; + char *NameStorage = reinterpret_cast<char *>(H + 1); + llvm::copy(Name, NameStorage); + NameStorage[Name.size()] = 0; + + // Construct a root trie, if requested. + TrieRawHashMapHandle Trie(Alloc.getRegion(), *H); + auto Sub = SubtrieHandle::create(Alloc, 0, *NumRootBits); + if (LLVM_UNLIKELY(!Sub)) + return Sub.takeError(); + if (NumRootBits) + H->RootTrieOffset = Sub->getOffset().asSubtrie(); + return Trie; +} + +TrieRawHashMapHandle::RecordData +TrieRawHashMapHandle::getRecord(SubtrieSlotValue Offset) { + char *Begin = Region->data() + Offset.asData(); + OnDiskTrieRawHashMap::ValueProxy Proxy; + Proxy.Data = MutableArrayRef(Begin, getRecordDataSize()); + Proxy.Hash = ArrayRef(reinterpret_cast<const uint8_t *>(Proxy.Data.end()), + getNumHashBytes()); + return RecordData{Proxy, Offset}; +} + +Expected<TrieRawHashMapHandle::RecordData> +TrieRawHashMapHandle::createRecord(MappedFileRegionArena &Alloc, + ArrayRef<uint8_t> Hash) { + assert(&Alloc.getRegion() == Region); + assert(Hash.size() == getNumHashBytes()); + auto Offset = Alloc.allocateOffset(getRecordSize()); + if (LLVM_UNLIKELY(!Offset)) + return Offset.takeError(); + + RecordData Record = getRecord(SubtrieSlotValue::getDataOffset(*Offset)); + llvm::copy(Hash, const_cast<uint8_t *>(Record.Proxy.Hash.begin())); + return Record; +} + +Expected<OnDiskTrieRawHashMap::const_pointer> +OnDiskTrieRawHashMap::recoverFromFileOffset(FileOffset Offset) const { + // Check alignment. + if (!isAligned(MappedFileRegionArena::getAlign(), Offset.get())) + return createStringError(make_error_code(std::errc::protocol_error), + "unaligned file offset at 0x" + + utohexstr(Offset.get(), /*LowerCase=*/true)); + + // Check bounds. + // + // Note: There's no potential overflow when using \c uint64_t because Offset + // is in valid offset range and the record size is in \c [0,UINT32_MAX]. + if (!validOffset(Offset) || + Offset.get() + Impl->Trie.getRecordSize() > Impl->File.getAlloc().size()) + return createStringError(make_error_code(std::errc::protocol_error), + "file offset too large: 0x" + + utohexstr(Offset.get(), /*LowerCase=*/true)); + + // Looks okay... 
+ TrieRawHashMapHandle::RecordData D = + Impl->Trie.getRecord(SubtrieSlotValue::getDataOffset(Offset)); + return const_pointer(D.Proxy, D.getFileOffset()); +} + +OnDiskTrieRawHashMap::const_pointer +OnDiskTrieRawHashMap::find(ArrayRef<uint8_t> Hash) const { + TrieRawHashMapHandle Trie = Impl->Trie; + assert(Hash.size() == Trie.getNumHashBytes() && "Invalid hash"); + + SubtrieHandle S = Trie.getRoot(); + if (!S) + return const_pointer(); + + TrieHashIndexGenerator IndexGen = Trie.getIndexGen(S, Hash); + size_t Index = IndexGen.next(); + for (;;) { + // Try to set the content. + SubtrieSlotValue V = S.load(Index); + if (!V) + return const_pointer(); + + // Check for an exact match. + if (V.isData()) { + TrieRawHashMapHandle::RecordData D = Trie.getRecord(V); + return D.Proxy.Hash == Hash ? const_pointer(D.Proxy, D.getFileOffset()) + : const_pointer(); + } + + Index = IndexGen.next(); + S = SubtrieHandle(Trie.getRegion(), V); + } +} + +/// Only safe if the subtrie is empty. +void SubtrieHandle::reinitialize(uint32_t StartBit, uint32_t NumBits) { + assert(StartBit > H->StartBit); + assert(NumBits <= H->NumBits); + // Ideally would also assert that all slots are empty, but that's expensive. + + H->StartBit = StartBit; + H->NumBits = NumBits; +} + +Expected<OnDiskTrieRawHashMap::pointer> +OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash, + LazyInsertOnConstructCB OnConstruct, + LazyInsertOnLeakCB OnLeak) { + TrieRawHashMapHandle Trie = Impl->Trie; + assert(Hash.size() == Trie.getNumHashBytes() && "Invalid hash"); + + MappedFileRegionArena &Alloc = Impl->File.getAlloc(); + std::optional<SubtrieHandle> S; + auto Err = Trie.getOrCreateRoot(Alloc).moveInto(S); + if (LLVM_UNLIKELY(Err)) + return std::move(Err); + + TrieHashIndexGenerator IndexGen = Trie.getIndexGen(*S, Hash); + size_t Index = IndexGen.next(); + + // Walk through the hash bytes and insert into correct trie position. + std::optional<TrieRawHashMapHandle::RecordData> NewRecord; + SubtrieHandle UnusedSubtrie; + for (;;) { + SubtrieSlotValue Existing = S->load(Index); + + // Try to set it, if it's empty. + if (!Existing) { + if (!NewRecord) { + auto Err = Trie.createRecord(Alloc, Hash).moveInto(NewRecord); + if (LLVM_UNLIKELY(Err)) + return std::move(Err); + if (OnConstruct) + OnConstruct(NewRecord->Offset.asDataFileOffset(), NewRecord->Proxy); + } + + if (S->compare_exchange_strong(Index, Existing, NewRecord->Offset)) + return pointer(NewRecord->Proxy, NewRecord->Offset.asDataFileOffset()); + + // Race means that Existing is no longer empty; fall through... + } + + if (Existing.isSubtrie()) { + S = SubtrieHandle(Trie.getRegion(), Existing); + Index = IndexGen.next(); + continue; + } + + // Check for an exact match. + TrieRawHashMapHandle::RecordData ExistingRecord = Trie.getRecord(Existing); + if (ExistingRecord.Proxy.Hash == Hash) { + if (NewRecord && OnLeak) + OnLeak(NewRecord->Offset.asDataFileOffset(), NewRecord->Proxy, + ExistingRecord.Offset.asDataFileOffset(), ExistingRecord.Proxy); + return pointer(ExistingRecord.Proxy, + ExistingRecord.Offset.asDataFileOffset()); + } + + // Sink the existing content as long as the indexes match. + for (;;) { + size_t NextIndex = IndexGen.next(); + size_t NewIndexForExistingContent = + IndexGen.getCollidingBits(ExistingRecord.Proxy.Hash); + + auto Err = S->sink(Index, Existing, Alloc, IndexGen.getNumBits(), + UnusedSubtrie, NewIndexForExistingContent) + .moveInto(S); + if (LLVM_UNLIKELY(Err)) + return std::move(Err); + Index = NextIndex; + + // Found the difference. 
+      if (NextIndex != NewIndexForExistingContent)
+        break;
+    }
+  }
+}
+
+Expected<SubtrieHandle> SubtrieHandle::sink(size_t I, SubtrieSlotValue V,
+                                            MappedFileRegionArena &Alloc,
+                                            size_t NumSubtrieBits,
+                                            SubtrieHandle &UnusedSubtrie,
+                                            size_t NewI) {
+  std::optional<SubtrieHandle> NewS;
+  if (UnusedSubtrie) {
+    // Steal UnusedSubtrie and initialize it.
+    NewS.emplace();
+    std::swap(*NewS, UnusedSubtrie);
+    NewS->reinitialize(getStartBit() + getNumBits(), NumSubtrieBits);
+  } else {
+    // Allocate a new, empty subtrie.
+    auto Err = SubtrieHandle::create(Alloc, getStartBit() + getNumBits(),
+                                     NumSubtrieBits)
+                   .moveInto(NewS);
+    if (LLVM_UNLIKELY(Err))
+      return std::move(Err);
+  }
+
+  NewS->store(NewI, V);
+  if (compare_exchange_strong(I, V, NewS->getOffset()))
+    return *NewS; // Success!
+
+  // Raced.
+  assert(V.isSubtrie() && "Expected racing sink() to add a subtrie");
+
+  // Wipe out the new slot so NewS can be reused and set the out parameter.
+  NewS->store(NewI, SubtrieSlotValue());
+  UnusedSubtrie = *NewS;
+
+  // Return the subtrie added by the concurrent sink() call.
+  return SubtrieHandle(Alloc.getRegion(), V);
+}
+
+void OnDiskTrieRawHashMap::print(
+    raw_ostream &OS, function_ref<void(ArrayRef<char>)> PrintRecordData) const {
+  Impl->Trie.print(OS, PrintRecordData);
+}
+
+Error OnDiskTrieRawHashMap::validate(
+    function_ref<Error(FileOffset, ConstValueProxy)> RecordVerifier) const {
+  return Impl->Trie.validate(RecordVerifier);
+}
+
+// Helper function that prints hex digits and supports a sub-byte starting
+// position.
+static void printHexDigits(raw_ostream &OS, ArrayRef<uint8_t> Bytes,
+                           size_t StartBit, size_t NumBits) {
+  assert(StartBit % 4 == 0);
+  assert(NumBits % 4 == 0);
+  for (size_t I = StartBit, E = StartBit + NumBits; I != E; I += 4) {
+    uint8_t HexPair = Bytes[I / 8];
+    uint8_t HexDigit = I % 8 == 0 ? HexPair >> 4 : HexPair & 0xf;
+    OS << hexdigit(HexDigit, /*LowerCase=*/true);
+  }
+}
+
+static void printBits(raw_ostream &OS, ArrayRef<uint8_t> Bytes, size_t StartBit,
+                      size_t NumBits) {
+  assert(StartBit + NumBits <= Bytes.size() * 8u);
+  for (size_t I = StartBit, E = StartBit + NumBits; I != E; ++I) {
+    uint8_t Byte = Bytes[I / 8];
+    size_t ByteOffset = I % 8;
+    if (size_t ByteShift = 8 - ByteOffset - 1)
+      Byte >>= ByteShift;
+    OS << (Byte & 0x1 ?
'1' : '0'); + } +} + +void SubtrieHandle::printHash(raw_ostream &OS, ArrayRef<uint8_t> Bytes) const { + // afb[1c:00*01110*0]def + size_t EndBit = getStartBit() + getNumBits(); + size_t HashEndBit = Bytes.size() * 8u; + + size_t FirstBinaryBit = getStartBit() & ~0x3u; + printHexDigits(OS, Bytes, 0, FirstBinaryBit); + + size_t LastBinaryBit = (EndBit + 3u) & ~0x3u; + OS << "["; + printBits(OS, Bytes, FirstBinaryBit, LastBinaryBit - FirstBinaryBit); + OS << "]"; + + printHexDigits(OS, Bytes, LastBinaryBit, HashEndBit - LastBinaryBit); +} + +static void appendIndexBits(std::string &Prefix, size_t Index, + size_t NumSlots) { + std::string Bits; + for (size_t NumBits = 1u; NumBits < NumSlots; NumBits <<= 1) { + Bits.push_back('0' + (Index & 0x1)); + Index >>= 1; + } + for (char Ch : llvm::reverse(Bits)) + Prefix += Ch; +} + +static void printPrefix(raw_ostream &OS, StringRef Prefix) { + while (Prefix.size() >= 4) { + uint8_t Digit; + bool ErrorParsingBinary = Prefix.take_front(4).getAsInteger(2, Digit); + assert(!ErrorParsingBinary); + (void)ErrorParsingBinary; + OS << hexdigit(Digit, /*LowerCase=*/true); + Prefix = Prefix.drop_front(4); + } + if (!Prefix.empty()) + OS << "[" << Prefix << "]"; +} + +LLVM_DUMP_METHOD void OnDiskTrieRawHashMap::dump() const { print(dbgs()); } + +static Expected<size_t> checkParameter(StringRef Label, size_t Max, + std::optional<size_t> Value, + std::optional<size_t> Default, + StringRef Path, StringRef TableName) { + assert(Value || Default); + assert(!Default || *Default <= Max); + if (!Value) + return *Default; + + if (*Value <= Max) + return *Value; + return createTableConfigError( + std::errc::argument_out_of_domain, Path, TableName, + "invalid " + Label + ": " + Twine(*Value) + " (max: " + Twine(Max) + ")"); +} + +size_t OnDiskTrieRawHashMap::size() const { return Impl->File.size(); } +size_t OnDiskTrieRawHashMap::capacity() const { + return Impl->File.getRegion().size(); +} + +Expected<OnDiskTrieRawHashMap> +OnDiskTrieRawHashMap::create(const Twine &PathTwine, const Twine &TrieNameTwine, + size_t NumHashBits, uint64_t DataSize, + uint64_t MaxFileSize, + std::optional<uint64_t> NewFileInitialSize, + std::optional<size_t> NewTableNumRootBits, + std::optional<size_t> NewTableNumSubtrieBits) { + SmallString<128> PathStorage; + StringRef Path = PathTwine.toStringRef(PathStorage); + SmallString<128> TrieNameStorage; + StringRef TrieName = TrieNameTwine.toStringRef(TrieNameStorage); + + constexpr size_t DefaultNumRootBits = 10; + constexpr size_t DefaultNumSubtrieBits = 6; + + size_t NumRootBits; + if (Error E = checkParameter( + "root bits", TrieRawHashMapHandle::MaxNumRootBits, + NewTableNumRootBits, DefaultNumRootBits, Path, TrieName) + .moveInto(NumRootBits)) + return std::move(E); + + size_t NumSubtrieBits; + if (Error E = checkParameter("subtrie bits", + TrieRawHashMapHandle::MaxNumSubtrieBits, + NewTableNumSubtrieBits, DefaultNumSubtrieBits, + Path, TrieName) + .moveInto(NumSubtrieBits)) + return std::move(E); + + size_t NumHashBytes = NumHashBits >> 3; + if (Error E = + checkParameter("hash size", TrieRawHashMapHandle::MaxNumHashBits, + NumHashBits, std::nullopt, Path, TrieName) + .takeError()) + return std::move(E); + assert(NumHashBits == NumHashBytes << 3 && + "Expected hash size to be byte-aligned"); + if (NumHashBits != NumHashBytes << 3) + return createTableConfigError( + std::errc::argument_out_of_domain, Path, TrieName, + "invalid hash size: " + Twine(NumHashBits) + " (not byte-aligned)"); + + // Constructor for if the file doesn't exist. 
+  auto NewDBConstructor = [&](DatabaseFile &DB) -> Error {
+    auto Trie =
+        TrieRawHashMapHandle::create(DB.getAlloc(), TrieName, NumRootBits,
+                                     NumSubtrieBits, NumHashBits, DataSize);
+    if (LLVM_UNLIKELY(!Trie))
+      return Trie.takeError();
+
+    return DB.addTable(*Trie);
+  };
+
+  // Get or create the file.
+  Expected<DatabaseFile> File =
+      DatabaseFile::create(Path, MaxFileSize, NewDBConstructor);
+  if (!File)
+    return File.takeError();
+
+  // Find the trie and validate it.
+  std::optional<TableHandle> Table = File->findTable(TrieName);
+  if (!Table)
+    return createTableConfigError(std::errc::argument_out_of_domain, Path,
+                                  TrieName, "table not found");
+  if (Error E = checkTable("table kind", (size_t)TrieRawHashMapHandle::Kind,
+                           (size_t)Table->getHeader().Kind, Path, TrieName))
+    return std::move(E);
+  auto Trie = Table->cast<TrieRawHashMapHandle>();
+  assert(Trie && "Already checked the kind");
+
+  // Check the hash and data size.
+  if (Error E = checkTable("hash size", NumHashBits, Trie.getNumHashBits(),
+                           Path, TrieName))
+    return std::move(E);
+  if (Error E = checkTable("data size", DataSize, Trie.getRecordDataSize(),
+                           Path, TrieName))
+    return std::move(E);
+
+  // No flags supported right now. Either corrupt, or coming from a future
+  // writer.
+  if (size_t Flags = Trie.getFlags())
+    return createTableConfigError(std::errc::invalid_argument, Path, TrieName,
+                                  "unsupported flags: " + Twine(Flags));
+
+  // Success.
+  OnDiskTrieRawHashMap::ImplType Impl{DatabaseFile(std::move(*File)), Trie};
+  return OnDiskTrieRawHashMap(std::make_unique<ImplType>(std::move(Impl)));
+}
+
+static Error createInvalidTrieError(uint64_t Offset, const Twine &Msg) {
+  return createStringError(make_error_code(std::errc::protocol_error),
+                           "invalid trie at 0x" +
+                               utohexstr(Offset, /*LowerCase=*/true) + ": " +
+                               Msg);
+}
+
+//===----------------------------------------------------------------------===//
+// TrieVisitor data structures.
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A multi-threaded visitor to traverse the Trie.
+///
+/// TODO: add more sanity checks beyond plain data corruption. For example,
+/// some ill-formed data can be constructed to form a cycle using sub-tries,
+/// which can lead to an infinite loop when visiting (or inserting data).
+class TrieVisitor {
+public:
+  TrieVisitor(TrieRawHashMapHandle Trie, unsigned ThreadCount = 0,
+              unsigned ErrorLimit = 50)
+      : Trie(Trie), ErrorLimit(ErrorLimit),
+        Threads(hardware_concurrency(ThreadCount)) {}
+  virtual ~TrieVisitor() = default;
+  Error visit();
+
+private:
+  // Virtual method to implement the action when visiting a sub-trie.
+  virtual Error visitSubTrie(StringRef Prefix, SubtrieHandle SubTrie) {
+    return Error::success();
+  }
+
+  // Virtual method to implement the action when visiting a slot in a trie
+  // node.
+  virtual Error visitSlot(unsigned I, SubtrieHandle Subtrie, StringRef Prefix,
+                          SubtrieSlotValue Slot) {
+    return Error::success();
+  }
+
+protected:
+  TrieRawHashMapHandle Trie;
+
+private:
+  Error traverseTrieNode(SubtrieHandle Node, StringRef Prefix);
+
+  Error validateSubTrie(SubtrieHandle Node, bool IsRoot);
+
+  // Helper function to capture errors when visiting the trie nodes.
+  void addError(Error NewError) {
+    assert(NewError && "not an error");
+    std::lock_guard<std::mutex> ErrorLock(Lock);
+    if (NumError >= ErrorLimit) {
+      // Too many errors.
+ consumeError(std::move(NewError)); + return; + } + + if (Err) + Err = joinErrors(std::move(*Err), std::move(NewError)); + else + Err = std::move(NewError); + NumError++; + } + + bool tooManyErrors() { + std::lock_guard<std::mutex> ErrorLock(Lock); + return (bool)Err && NumError >= ErrorLimit; + } + + const unsigned ErrorLimit; + std::optional<Error> Err; + unsigned NumError = 0; + std::mutex Lock; + DefaultThreadPool Threads; +}; + +/// A visitor that traverse and print the Trie. +class TriePrinter : public TrieVisitor { +public: + TriePrinter(TrieRawHashMapHandle Trie, raw_ostream &OS, + function_ref<void(ArrayRef<char>)> PrintRecordData) + : TrieVisitor(Trie, /*ThreadCount=*/1), OS(OS), + PrintRecordData(PrintRecordData) {} + + Error printRecords() { + if (Records.empty()) + return Error::success(); + + OS << "records\n"; + llvm::sort(Records); + for (int64_t Offset : Records) { + TrieRawHashMapHandle::RecordData Record = + Trie.getRecord(SubtrieSlotValue::getDataOffset(Offset)); + if (auto Err = printRecord(Record)) + return Err; + } + return Error::success(); + } + + Error printRecord(TrieRawHashMapHandle::RecordData &Record) { + OS << "- addr=" << (void *)Record.getFileOffset().get() << " "; + if (PrintRecordData) { + PrintRecordData(Record.Proxy.Data); + } else { + OS << "bytes="; + ArrayRef<uint8_t> Data( + reinterpret_cast<const uint8_t *>(Record.Proxy.Data.data()), + Record.Proxy.Data.size()); + printHexDigits(OS, Data, 0, Data.size() * 8); + } + OS << "\n"; + return Error::success(); + } + + Error visitSubTrie(StringRef Prefix, SubtrieHandle SubTrie) override { + if (Prefix.empty()) { + OS << "root"; + } else { + OS << "subtrie="; + printPrefix(OS, Prefix); + } + + OS << " addr=" + << (void *)(reinterpret_cast<const char *>(&SubTrie.getHeader()) - + Trie.getRegion().data()); + OS << " num-slots=" << SubTrie.getNumSlots() << "\n"; + return Error::success(); + } + + Error visitSlot(unsigned I, SubtrieHandle Subtrie, StringRef Prefix, + SubtrieSlotValue Slot) override { + OS << "- index="; + for (size_t Pad : {10, 100, 1000}) + if (I < Pad && Subtrie.getNumSlots() >= Pad) + OS << "0"; + OS << I << " "; + if (Slot.isSubtrie()) { + OS << "addr=" << (void *)Slot.asSubtrie(); + OS << " subtrie="; + printPrefix(OS, Prefix); + OS << "\n"; + return Error::success(); + } + TrieRawHashMapHandle::RecordData Record = Trie.getRecord(Slot); + OS << "addr=" << (void *)Record.getFileOffset().get(); + OS << " content="; + Subtrie.printHash(OS, Record.Proxy.Hash); + OS << "\n"; + Records.push_back(Slot.asData()); + return Error::success(); + } + +private: + raw_ostream &OS; + function_ref<void(ArrayRef<char>)> PrintRecordData; + SmallVector<int64_t> Records; +}; + +/// TrieVerifier that adds additional verification on top of the basic visitor. 
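+///
+/// A usage sketch for illustration (the map variable and the callback body
+/// are assumptions of this example, not part of the patch); the verifier is
+/// driven through OnDiskTrieRawHashMap::validate:
+///
+/// \code
+///   Error E = Map.validate(
+///       [](FileOffset Off, OnDiskTrieRawHashMap::ConstValueProxy Record) {
+///         // Reject records whose payload is not the expected size.
+///         if (Record.Data.size() != 8)
+///           return createStringError(std::errc::invalid_argument,
+///                                    "bad record size");
+///         return Error::success();
+///       });
+/// \endcode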
+class TrieVerifier : public TrieVisitor { +public: + TrieVerifier( + TrieRawHashMapHandle Trie, + function_ref<Error(FileOffset, OnDiskTrieRawHashMap::ConstValueProxy)> + RecordVerifier) + : TrieVisitor(Trie), RecordVerifier(RecordVerifier) {} + +private: + Error visitSubTrie(StringRef Prefix, SubtrieHandle SubTrie) final { + return Error::success(); + } + + Error visitSlot(unsigned I, SubtrieHandle Subtrie, StringRef Prefix, + SubtrieSlotValue Slot) final { + if (RecordVerifier && Slot.isData()) { + if (!isAligned(MappedFileRegionArena::getAlign(), Slot.asData())) + return createInvalidTrieError(Slot.asData(), "mis-aligned data entry"); + + TrieRawHashMapHandle::RecordData Record = + Trie.getRecord(SubtrieSlotValue::getDataOffset(Slot.asData())); + return RecordVerifier(Slot.asDataFileOffset(), + OnDiskTrieRawHashMap::ConstValueProxy{ + Record.Proxy.Hash, Record.Proxy.Data}); + } + return Error::success(); + } + + function_ref<Error(FileOffset, OnDiskTrieRawHashMap::ConstValueProxy)> + RecordVerifier; +}; +} // namespace + +Error TrieVisitor::visit() { + auto Root = Trie.getRoot(); + if (!Root) + return Error::success(); + + if (auto Err = validateSubTrie(Root, /*IsRoot=*/true)) + return Err; + + if (auto Err = visitSubTrie("", Root)) + return Err; + + SmallVector<SubtrieHandle> Subs; + SmallVector<std::string> Prefixes; + const size_t NumSlots = Root.getNumSlots(); + for (size_t I = 0, E = NumSlots; I != E; ++I) { + SubtrieSlotValue Slot = Root.load(I); + if (!Slot) + continue; + uint64_t Offset = Slot.isSubtrie() ? Slot.asSubtrie() : Slot.asData(); + if (Offset >= (uint64_t)Trie.getRegion().size()) + return createInvalidTrieError(Offset, "slot points out of bound"); + std::string SubtriePrefix; + appendIndexBits(SubtriePrefix, I, NumSlots); + if (Slot.isSubtrie()) { + SubtrieHandle S(Trie.getRegion(), Slot); + Subs.push_back(S); + Prefixes.push_back(SubtriePrefix); + } + if (auto Err = visitSlot(I, Root, SubtriePrefix, Slot)) + return Err; + } + + for (size_t I = 0, E = Subs.size(); I != E; ++I) { + Threads.async( + [&](unsigned Idx) { + // Don't run if there is an error already. 
+ if (tooManyErrors()) + return; + if (auto Err = traverseTrieNode(Subs[Idx], Prefixes[Idx])) + addError(std::move(Err)); + }, + I); + } + + Threads.wait(); + if (Err) + return std::move(*Err); + return Error::success(); +} + +Error TrieVisitor::validateSubTrie(SubtrieHandle Node, bool IsRoot) { + char *Addr = reinterpret_cast<char *>(&Node.getHeader()); + const int64_t Offset = Node.getFileOffset().get(); + if (Addr + Node.getSize() >= + Trie.getRegion().data() + Trie.getRegion().size()) + return createInvalidTrieError(Offset, "subtrie node spans out of bound"); + + if (!IsRoot && + Node.getStartBit() + Node.getNumBits() > Trie.getNumHashBits()) { + return createInvalidTrieError(Offset, + "subtrie represents too many hash bits"); + } + + if (IsRoot) { + if (Node.getStartBit() != 0) + return createInvalidTrieError(Offset, + "root node doesn't start at 0 index"); + + return Error::success(); + } + + if (Node.getNumBits() > Trie.getNumSubtrieBits()) + return createInvalidTrieError(Offset, "subtrie has wrong number of slots"); + + return Error::success(); +} + +Error TrieVisitor::traverseTrieNode(SubtrieHandle Node, StringRef Prefix) { + if (auto Err = validateSubTrie(Node, /*IsRoot=*/false)) + return Err; + + if (auto Err = visitSubTrie(Prefix, Node)) + return Err; + + SmallVector<SubtrieHandle> Subs; + SmallVector<std::string> Prefixes; + const size_t NumSlots = Node.getNumSlots(); + for (size_t I = 0, E = NumSlots; I != E; ++I) { + SubtrieSlotValue Slot = Node.load(I); + if (!Slot) + continue; + uint64_t Offset = Slot.isSubtrie() ? Slot.asSubtrie() : Slot.asData(); + if (Offset >= (uint64_t)Trie.getRegion().size()) + return createInvalidTrieError(Offset, "slot points out of bound"); + std::string SubtriePrefix = Prefix.str(); + appendIndexBits(SubtriePrefix, I, NumSlots); + if (Slot.isSubtrie()) { + SubtrieHandle S(Trie.getRegion(), Slot); + Subs.push_back(S); + Prefixes.push_back(SubtriePrefix); + } + if (auto Err = visitSlot(I, Node, SubtriePrefix, Slot)) + return Err; + } + for (size_t I = 0, E = Subs.size(); I != E; ++I) + if (auto Err = traverseTrieNode(Subs[I], Prefixes[I])) + return Err; + + return Error::success(); +} + +void TrieRawHashMapHandle::print( + raw_ostream &OS, function_ref<void(ArrayRef<char>)> PrintRecordData) const { + OS << "hash-num-bits=" << getNumHashBits() + << " hash-size=" << getNumHashBytes() + << " record-data-size=" << getRecordDataSize() << "\n"; + + TriePrinter Printer(*this, OS, PrintRecordData); + if (auto Err = Printer.visit()) + OS << "error: " << toString(std::move(Err)) << "\n"; + + if (auto Err = Printer.printRecords()) + OS << "error: " << toString(std::move(Err)) << "\n"; + + return; +} + +Error TrieRawHashMapHandle::validate( + function_ref<Error(FileOffset, OnDiskTrieRawHashMap::ConstValueProxy)> + RecordVerifier) const { + // Use the base TrieVisitor to identify the errors inside trie first. + TrieVisitor BasicVerifier(*this); + if (auto Err = BasicVerifier.visit()) + return Err; + + // If the trie data structure is sound, do a second pass to verify data and + // verifier function can assume the index is correct. However, there can be + // newly added bad entries that can still produce error. 
+ TrieVerifier Verifier(*this, RecordVerifier); + return Verifier.visit(); +} + +#else // !LLVM_ENABLE_ONDISK_CAS + +struct OnDiskTrieRawHashMap::ImplType {}; + +Expected<OnDiskTrieRawHashMap> +OnDiskTrieRawHashMap::create(const Twine &PathTwine, const Twine &TrieNameTwine, + size_t NumHashBits, uint64_t DataSize, + uint64_t MaxFileSize, + std::optional<uint64_t> NewFileInitialSize, + std::optional<size_t> NewTableNumRootBits, + std::optional<size_t> NewTableNumSubtrieBits) { + return createStringError(make_error_code(std::errc::not_supported), + "OnDiskTrieRawHashMap is not supported"); +} + +Expected<OnDiskTrieRawHashMap::pointer> +OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash, + LazyInsertOnConstructCB OnConstruct, + LazyInsertOnLeakCB OnLeak) { + return createStringError(make_error_code(std::errc::not_supported), + "OnDiskTrieRawHashMap is not supported"); +} + +Expected<OnDiskTrieRawHashMap::const_pointer> +OnDiskTrieRawHashMap::recoverFromFileOffset(FileOffset Offset) const { + return createStringError(make_error_code(std::errc::not_supported), + "OnDiskTrieRawHashMap is not supported"); +} + +OnDiskTrieRawHashMap::const_pointer +OnDiskTrieRawHashMap::find(ArrayRef<uint8_t> Hash) const { + return const_pointer(); +} + +void OnDiskTrieRawHashMap::print( + raw_ostream &OS, function_ref<void(ArrayRef<char>)> PrintRecordData) const { +} + +Error OnDiskTrieRawHashMap::validate( + function_ref<Error(FileOffset, OnDiskTrieRawHashMap::ConstValueProxy)> + RecordVerifier) const { + return createStringError(make_error_code(std::errc::not_supported), + "OnDiskTrieRawHashMap is not supported"); +} + +size_t OnDiskTrieRawHashMap::size() const { return 0; } +size_t OnDiskTrieRawHashMap::capacity() const { return 0; } + +#endif // LLVM_ENABLE_ONDISK_CAS + +OnDiskTrieRawHashMap::OnDiskTrieRawHashMap(std::unique_ptr<ImplType> Impl) + : Impl(std::move(Impl)) {} +OnDiskTrieRawHashMap::OnDiskTrieRawHashMap(OnDiskTrieRawHashMap &&RHS) = + default; +OnDiskTrieRawHashMap & +OnDiskTrieRawHashMap::operator=(OnDiskTrieRawHashMap &&RHS) = default; +OnDiskTrieRawHashMap::~OnDiskTrieRawHashMap() = default; diff --git a/llvm/lib/CodeGen/PeepholeOptimizer.cpp b/llvm/lib/CodeGen/PeepholeOptimizer.cpp index fb3e648..729a57e 100644 --- a/llvm/lib/CodeGen/PeepholeOptimizer.cpp +++ b/llvm/lib/CodeGen/PeepholeOptimizer.cpp @@ -1203,6 +1203,18 @@ bool PeepholeOptimizer::optimizeCoalescableCopyImpl(Rewriter &&CpyRewriter) { if (!NewSrc.Reg) continue; + if (NewSrc.SubReg) { + // Verify the register class supports the subregister index. ARM's + // copy-like queries return register:subreg pairs where the register's + // current class does not directly support the subregister index. + const TargetRegisterClass *RC = MRI->getRegClass(NewSrc.Reg); + const TargetRegisterClass *WithSubRC = + TRI->getSubClassWithSubReg(RC, NewSrc.SubReg); + if (!MRI->constrainRegClass(NewSrc.Reg, WithSubRC)) + continue; + Changed = true; + } + // Rewrite source. if (CpyRewriter.RewriteCurrentSource(NewSrc.Reg, NewSrc.SubReg)) { // We may have extended the live-range of NewSrc, account for that. 
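Both PeepholeOptimizer hunks rely on the same idiom: before rewriting a copy to use a (register, subregister) pair reported by ValueTracker, constrain the register's class to one that supports the subregister index. A minimal sketch of the pattern, with assumed MRI/TRI context (not lines from this patch):

  // Returns true if Reg's class supports SubReg, possibly after narrowing
  // the class; on failure the rewrite must be skipped.
  static bool canUseSubReg(MachineRegisterInfo &MRI,
                           const TargetRegisterInfo &TRI, Register Reg,
                           unsigned SubReg) {
    const TargetRegisterClass *RC = MRI.getRegClass(Reg);
    const TargetRegisterClass *WithSubRC =
        TRI.getSubClassWithSubReg(RC, SubReg);
    return MRI.constrainRegClass(Reg, WithSubRC) != nullptr;
  }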
@@ -1275,6 +1287,18 @@ MachineInstr &PeepholeOptimizer::rewriteSource(MachineInstr &CopyLike, const TargetRegisterClass *DefRC = MRI->getRegClass(Def.Reg); Register NewVReg = MRI->createVirtualRegister(DefRC); + if (NewSrc.SubReg) { + const TargetRegisterClass *NewSrcRC = MRI->getRegClass(NewSrc.Reg); + const TargetRegisterClass *WithSubRC = + TRI->getSubClassWithSubReg(NewSrcRC, NewSrc.SubReg); + + // The new source may not directly support the subregister, but we should be + // able to assume it is constrainable to support the subregister (otherwise + // ValueTracker was lying and reported a useless value). + if (!MRI->constrainRegClass(NewSrc.Reg, WithSubRC)) + llvm_unreachable("replacement register cannot support subregister"); + } + MachineInstr *NewCopy = BuildMI(*CopyLike.getParent(), &CopyLike, CopyLike.getDebugLoc(), TII->get(TargetOpcode::COPY), NewVReg) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 77df4b4..204e1f0 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -11849,9 +11849,7 @@ static bool isLegalToCombineMinNumMaxNum(SelectionDAG &DAG, SDValue LHS, if (!VT.isFloatingPoint()) return false; - const TargetOptions &Options = DAG.getTarget().Options; - - return (Flags.hasNoSignedZeros() || Options.NoSignedZerosFPMath) && + return Flags.hasNoSignedZeros() && TLI.isProfitableToCombineMinNumMaxNum(VT) && (Flags.hasNoNaNs() || (DAG.isKnownNeverNaN(RHS) && DAG.isKnownNeverNaN(LHS))); @@ -17351,7 +17349,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { // Always prefer FMAD to FMA for precision. unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; bool Aggressive = TLI.enableAggressiveFMAFusion(VT); - bool NoSignedZero = Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros(); + bool NoSignedZero = Flags.hasNoSignedZeros(); // Is the node an FMUL and contractable either due to global flags or // SDNodeFlags. @@ -18327,11 +18325,9 @@ template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) { return matcher.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2); } - // FIXME: use fast math flags instead of Options.UnsafeFPMath - // TODO: Finally migrate away from global TargetOptions. 
if ((Options.NoNaNsFPMath && Options.NoInfsFPMath) || (N->getFlags().hasNoNaNs() && N->getFlags().hasNoInfs())) { - if (Options.NoSignedZerosFPMath || N->getFlags().hasNoSignedZeros() || + if (N->getFlags().hasNoSignedZeros() || (N2CFP && !N2CFP->isExactlyValue(-0.0))) { if (N0CFP && N0CFP->isZero()) return N2; @@ -18636,8 +18632,7 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { } // Fold X/Sqrt(X) -> Sqrt(X) - if ((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) && - Flags.hasAllowReassociation()) + if (Flags.hasNoSignedZeros() && Flags.hasAllowReassociation()) if (N1.getOpcode() == ISD::FSQRT && N0 == N1.getOperand(0)) return N1; diff --git a/llvm/lib/ExecutionEngine/JITLink/ELF_loongarch.cpp b/llvm/lib/ExecutionEngine/JITLink/ELF_loongarch.cpp index f23fb34..5f956b1 100644 --- a/llvm/lib/ExecutionEngine/JITLink/ELF_loongarch.cpp +++ b/llvm/lib/ExecutionEngine/JITLink/ELF_loongarch.cpp @@ -365,6 +365,10 @@ private: uint32_t Type = Rel.getType(false); int64_t Addend = Rel.r_addend; + // ignore + if (Type == ELF::R_LARCH_MARK_LA) + return Error::success(); + if (Type == ELF::R_LARCH_RELAX) { if (BlockToFix.edges_empty()) return make_error<StringError>( diff --git a/llvm/lib/ExecutionEngine/JITLink/MachO_arm64.cpp b/llvm/lib/ExecutionEngine/JITLink/MachO_arm64.cpp index 09ac0f1..f794780 100644 --- a/llvm/lib/ExecutionEngine/JITLink/MachO_arm64.cpp +++ b/llvm/lib/ExecutionEngine/JITLink/MachO_arm64.cpp @@ -599,8 +599,7 @@ Expected<std::unique_ptr<LinkGraph>> createLinkGraphFromMachOObject_arm64( } static Error applyPACSigningToModInitPointers(LinkGraph &G) { - assert(G.getTargetTriple().getSubArch() == Triple::AArch64SubArch_arm64e && - "PAC signing only valid for arm64e"); + assert(G.getTargetTriple().isArm64e() && "PAC signing only valid for arm64e"); if (auto *ModInitSec = G.findSectionByName("__DATA,__mod_init_func")) { for (auto *B : ModInitSec->blocks()) { diff --git a/llvm/lib/ExecutionEngine/RuntimeDyld/RuntimeDyldELF.cpp b/llvm/lib/ExecutionEngine/RuntimeDyld/RuntimeDyldELF.cpp index d626803..dd1b1d3 100644 --- a/llvm/lib/ExecutionEngine/RuntimeDyld/RuntimeDyldELF.cpp +++ b/llvm/lib/ExecutionEngine/RuntimeDyld/RuntimeDyldELF.cpp @@ -781,6 +781,9 @@ void RuntimeDyldELF::resolveLoongArch64Relocation(const SectionEntry &Section, default: report_fatal_error("Relocation type not implemented yet!"); break; + case ELF::R_LARCH_MARK_LA: + // ignore + break; case ELF::R_LARCH_32: support::ulittle32_t::ref{TargetPtr} = static_cast<uint32_t>(Value + Addend); diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp index 1a51830..54b92c9 100644 --- a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -516,19 +516,15 @@ static void PrintShuffleMask(raw_ostream &Out, Type *Ty, ArrayRef<int> Mask) { if (isa<ScalableVectorType>(Ty)) Out << "vscale x "; Out << Mask.size() << " x i32> "; - bool FirstElt = true; if (all_of(Mask, [](int Elt) { return Elt == 0; })) { Out << "zeroinitializer"; } else if (all_of(Mask, [](int Elt) { return Elt == PoisonMaskElem; })) { Out << "poison"; } else { Out << "<"; + ListSeparator LS; for (int Elt : Mask) { - if (FirstElt) - FirstElt = false; - else - Out << ", "; - Out << "i32 "; + Out << LS << "i32 "; if (Elt == PoisonMaskElem) Out << "poison"; else @@ -1700,14 +1696,12 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV, if (const ConstantArray *CA = dyn_cast<ConstantArray>(CV)) { Type *ETy = CA->getType()->getElementType(); Out << '['; - WriterCtx.TypePrinter->print(ETy, Out); - Out << ' '; - 
WriteAsOperandInternal(Out, CA->getOperand(0), WriterCtx); - for (unsigned i = 1, e = CA->getNumOperands(); i != e; ++i) { - Out << ", "; + ListSeparator LS; + for (const Value *Op : CA->operands()) { + Out << LS; WriterCtx.TypePrinter->print(ETy, Out); Out << ' '; - WriteAsOperandInternal(Out, CA->getOperand(i), WriterCtx); + WriteAsOperandInternal(Out, Op, WriterCtx); } Out << ']'; return; @@ -1725,11 +1719,9 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV, Type *ETy = CA->getType()->getElementType(); Out << '['; - WriterCtx.TypePrinter->print(ETy, Out); - Out << ' '; - WriteAsOperandInternal(Out, CA->getElementAsConstant(0), WriterCtx); - for (uint64_t i = 1, e = CA->getNumElements(); i != e; ++i) { - Out << ", "; + ListSeparator LS; + for (uint64_t i = 0, e = CA->getNumElements(); i != e; ++i) { + Out << LS; WriterCtx.TypePrinter->print(ETy, Out); Out << ' '; WriteAsOperandInternal(Out, CA->getElementAsConstant(i), WriterCtx); @@ -1742,24 +1734,17 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV, if (CS->getType()->isPacked()) Out << '<'; Out << '{'; - unsigned N = CS->getNumOperands(); - if (N) { - Out << ' '; - WriterCtx.TypePrinter->print(CS->getOperand(0)->getType(), Out); + if (CS->getNumOperands() != 0) { Out << ' '; - - WriteAsOperandInternal(Out, CS->getOperand(0), WriterCtx); - - for (unsigned i = 1; i < N; i++) { - Out << ", "; - WriterCtx.TypePrinter->print(CS->getOperand(i)->getType(), Out); + ListSeparator LS; + for (const Value *Op : CS->operands()) { + Out << LS; + WriterCtx.TypePrinter->print(Op->getType(), Out); Out << ' '; - - WriteAsOperandInternal(Out, CS->getOperand(i), WriterCtx); + WriteAsOperandInternal(Out, Op, WriterCtx); } Out << ' '; } - Out << '}'; if (CS->getType()->isPacked()) Out << '>'; @@ -1787,11 +1772,9 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV, } Out << '<'; - WriterCtx.TypePrinter->print(ETy, Out); - Out << ' '; - WriteAsOperandInternal(Out, CV->getAggregateElement(0U), WriterCtx); - for (unsigned i = 1, e = CVVTy->getNumElements(); i != e; ++i) { - Out << ", "; + ListSeparator LS; + for (unsigned i = 0, e = CVVTy->getNumElements(); i != e; ++i) { + Out << LS; WriterCtx.TypePrinter->print(ETy, Out); Out << ' '; WriteAsOperandInternal(Out, CV->getAggregateElement(i), WriterCtx); @@ -1848,13 +1831,12 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV, Out << ", "; } - for (User::const_op_iterator OI = CE->op_begin(); OI != CE->op_end(); - ++OI) { - WriterCtx.TypePrinter->print((*OI)->getType(), Out); + ListSeparator LS; + for (const Value *Op : CE->operands()) { + Out << LS; + WriterCtx.TypePrinter->print(Op->getType(), Out); Out << ' '; - WriteAsOperandInternal(Out, *OI, WriterCtx); - if (OI+1 != CE->op_end()) - Out << ", "; + WriteAsOperandInternal(Out, Op, WriterCtx); } if (CE->isCast()) { @@ -1875,11 +1857,12 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV, static void writeMDTuple(raw_ostream &Out, const MDTuple *Node, AsmWriterContext &WriterCtx) { Out << "!{"; - for (unsigned mi = 0, me = Node->getNumOperands(); mi != me; ++mi) { - const Metadata *MD = Node->getOperand(mi); - if (!MD) + ListSeparator LS; + for (const Metadata *MD : Node->operands()) { + Out << LS; + if (!MD) { Out << "null"; - else if (auto *MDV = dyn_cast<ValueAsMetadata>(MD)) { + } else if (auto *MDV = dyn_cast<ValueAsMetadata>(MD)) { Value *V = MDV->getValue(); WriterCtx.TypePrinter->print(V->getType(), Out); Out << ' '; @@ -1888,8 +1871,6 @@ 
static void writeMDTuple(raw_ostream &Out, const MDTuple *Node, WriteAsOperandInternal(Out, MD, WriterCtx); WriterCtx.onWriteMetadataAsOperand(MD); } - if (mi + 1 != me) - Out << ", "; } Out << "}"; @@ -1897,24 +1878,9 @@ static void writeMDTuple(raw_ostream &Out, const MDTuple *Node, namespace { -struct FieldSeparator { - bool Skip = true; - const char *Sep; - - FieldSeparator(const char *Sep = ", ") : Sep(Sep) {} -}; - -raw_ostream &operator<<(raw_ostream &OS, FieldSeparator &FS) { - if (FS.Skip) { - FS.Skip = false; - return OS; - } - return OS << FS.Sep; -} - struct MDFieldPrinter { raw_ostream &Out; - FieldSeparator FS; + ListSeparator FS; AsmWriterContext &WriterCtx; explicit MDFieldPrinter(raw_ostream &Out) @@ -2051,7 +2017,7 @@ void MDFieldPrinter::printDIFlags(StringRef Name, DINode::DIFlags Flags) { SmallVector<DINode::DIFlags, 8> SplitFlags; auto Extra = DINode::splitFlags(Flags, SplitFlags); - FieldSeparator FlagsFS(" | "); + ListSeparator FlagsFS(" | "); for (auto F : SplitFlags) { auto StringF = DINode::getFlagString(F); assert(!StringF.empty() && "Expected valid flag"); @@ -2075,7 +2041,7 @@ void MDFieldPrinter::printDISPFlags(StringRef Name, SmallVector<DISubprogram::DISPFlags, 8> SplitFlags; auto Extra = DISubprogram::splitFlags(Flags, SplitFlags); - FieldSeparator FlagsFS(" | "); + ListSeparator FlagsFS(" | "); for (auto F : SplitFlags) { auto StringF = DISubprogram::getFlagString(F); assert(!StringF.empty() && "Expected valid flag"); @@ -2124,7 +2090,7 @@ static void writeGenericDINode(raw_ostream &Out, const GenericDINode *N, Printer.printString("header", N->getHeader()); if (N->getNumDwarfOperands()) { Out << Printer.FS << "operands: {"; - FieldSeparator IFS; + ListSeparator IFS; for (auto &I : N->dwarf_operands()) { Out << IFS; writeMetadataAsOperand(Out, I, WriterCtx); @@ -2638,7 +2604,7 @@ static void writeDILabel(raw_ostream &Out, const DILabel *N, static void writeDIExpression(raw_ostream &Out, const DIExpression *N, AsmWriterContext &WriterCtx) { Out << "!DIExpression("; - FieldSeparator FS; + ListSeparator FS; if (N->isValid()) { for (const DIExpression::ExprOperand &Op : N->expr_ops()) { auto OpStr = dwarf::OperationEncodingString(Op.getOp()); @@ -2666,7 +2632,7 @@ static void writeDIArgList(raw_ostream &Out, const DIArgList *N, assert(FromValue && "Unexpected DIArgList metadata outside of value argument"); Out << "!DIArgList("; - FieldSeparator FS; + ListSeparator FS; MDFieldPrinter Printer(Out, WriterCtx); for (Metadata *Arg : N->getArgs()) { Out << FS; @@ -3073,15 +3039,11 @@ void AssemblyWriter::writeOperandBundles(const CallBase *Call) { Out << " [ "; - bool FirstBundle = true; + ListSeparator LS; for (unsigned i = 0, e = Call->getNumOperandBundles(); i != e; ++i) { OperandBundleUse BU = Call->getOperandBundleAt(i); - if (!FirstBundle) - Out << ", "; - FirstBundle = false; - - Out << '"'; + Out << LS << '"'; printEscapedString(BU.getTagName(), Out); Out << '"'; @@ -3229,7 +3191,7 @@ void AssemblyWriter::printModuleSummaryIndex() { Out << "path: \""; printEscapedString(ModPair.first, Out); Out << "\", hash: ("; - FieldSeparator FS; + ListSeparator FS; for (auto Hash : ModPair.second) Out << FS << Hash; Out << "))\n"; @@ -3347,7 +3309,7 @@ void AssemblyWriter::printTypeIdSummary(const TypeIdSummary &TIS) { printTypeTestResolution(TIS.TTRes); if (!TIS.WPDRes.empty()) { Out << ", wpdResolutions: ("; - FieldSeparator FS; + ListSeparator FS; for (auto &WPDRes : TIS.WPDRes) { Out << FS; Out << "(offset: " << WPDRes.first << ", "; @@ -3362,7 +3324,7 @@ void 
AssemblyWriter::printTypeIdSummary(const TypeIdSummary &TIS) { void AssemblyWriter::printTypeIdCompatibleVtableSummary( const TypeIdCompatibleVtableInfo &TI) { Out << ", summary: ("; - FieldSeparator FS; + ListSeparator FS; for (auto &P : TI) { Out << FS; Out << "(offset: " << P.AddressPointOffset << ", "; @@ -3374,7 +3336,7 @@ void AssemblyWriter::printTypeIdCompatibleVtableSummary( void AssemblyWriter::printArgs(const std::vector<uint64_t> &Args) { Out << "args: ("; - FieldSeparator FS; + ListSeparator FS; for (auto arg : Args) { Out << FS; Out << arg; @@ -3391,7 +3353,7 @@ void AssemblyWriter::printWPDRes(const WholeProgramDevirtResolution &WPDRes) { if (!WPDRes.ResByArg.empty()) { Out << ", resByArg: ("; - FieldSeparator FS; + ListSeparator FS; for (auto &ResByArg : WPDRes.ResByArg) { Out << FS; printArgs(ResByArg.first); @@ -3451,7 +3413,7 @@ void AssemblyWriter::printGlobalVarSummary(const GlobalVarSummary *GS) { if (!VTableFuncs.empty()) { Out << ", vTableFuncs: ("; - FieldSeparator FS; + ListSeparator FS; for (auto &P : VTableFuncs) { Out << FS; Out << "(virtFunc: ^" << Machine.getGUIDSlot(P.FuncVI.getGUID()) @@ -3528,7 +3490,7 @@ void AssemblyWriter::printFunctionSummary(const FunctionSummary *FS) { if (!FS->calls().empty()) { Out << ", calls: ("; - FieldSeparator IFS; + ListSeparator IFS; for (auto &Call : FS->calls()) { Out << IFS; Out << "(callee: ^" << Machine.getGUIDSlot(Call.first.getGUID()); @@ -3566,22 +3528,22 @@ void AssemblyWriter::printFunctionSummary(const FunctionSummary *FS) { if (!FS->allocs().empty()) { Out << ", allocs: ("; - FieldSeparator AFS; + ListSeparator AFS; for (auto &AI : FS->allocs()) { Out << AFS; Out << "(versions: ("; - FieldSeparator VFS; + ListSeparator VFS; for (auto V : AI.Versions) { Out << VFS; Out << AllocTypeName(V); } Out << "), memProf: ("; - FieldSeparator MIBFS; + ListSeparator MIBFS; for (auto &MIB : AI.MIBs) { Out << MIBFS; Out << "(type: " << AllocTypeName((uint8_t)MIB.AllocType); Out << ", stackIds: ("; - FieldSeparator SIDFS; + ListSeparator SIDFS; for (auto Id : MIB.StackIdIndices) { Out << SIDFS; Out << TheIndex->getStackIdAtIndex(Id); @@ -3595,7 +3557,7 @@ void AssemblyWriter::printFunctionSummary(const FunctionSummary *FS) { if (!FS->callsites().empty()) { Out << ", callsites: ("; - FieldSeparator SNFS; + ListSeparator SNFS; for (auto &CI : FS->callsites()) { Out << SNFS; if (CI.Callee) @@ -3603,13 +3565,13 @@ void AssemblyWriter::printFunctionSummary(const FunctionSummary *FS) { else Out << "(callee: null"; Out << ", clones: ("; - FieldSeparator VFS; + ListSeparator VFS; for (auto V : CI.Clones) { Out << VFS; Out << V; } Out << "), stackIds: ("; - FieldSeparator SIDFS; + ListSeparator SIDFS; for (auto Id : CI.StackIdIndices) { Out << SIDFS; Out << TheIndex->getStackIdAtIndex(Id); @@ -3625,7 +3587,7 @@ void AssemblyWriter::printFunctionSummary(const FunctionSummary *FS) { if (!FS->paramAccesses().empty()) { Out << ", params: ("; - FieldSeparator IFS; + ListSeparator IFS; for (auto &PS : FS->paramAccesses()) { Out << IFS; Out << "(param: " << PS.ParamNo; @@ -3633,7 +3595,7 @@ void AssemblyWriter::printFunctionSummary(const FunctionSummary *FS) { PrintRange(PS.Use); if (!PS.Calls.empty()) { Out << ", calls: ("; - FieldSeparator IFS; + ListSeparator IFS; for (auto &Call : PS.Calls) { Out << IFS; Out << "(callee: ^" << Machine.getGUIDSlot(Call.Callee.getGUID()); @@ -3653,11 +3615,11 @@ void AssemblyWriter::printFunctionSummary(const FunctionSummary *FS) { void AssemblyWriter::printTypeIdInfo( const FunctionSummary::TypeIdInfo 
&TIDInfo) { Out << ", typeIdInfo: ("; - FieldSeparator TIDFS; + ListSeparator TIDFS; if (!TIDInfo.TypeTests.empty()) { Out << TIDFS; Out << "typeTests: ("; - FieldSeparator FS; + ListSeparator FS; for (auto &GUID : TIDInfo.TypeTests) { auto TidIter = TheIndex->typeIds().equal_range(GUID); if (TidIter.first == TidIter.second) { @@ -3706,7 +3668,7 @@ void AssemblyWriter::printVFuncId(const FunctionSummary::VFuncId VFId) { return; } // Print all type id that correspond to this GUID. - FieldSeparator FS; + ListSeparator FS; for (const auto &[GUID, TypeIdPair] : make_range(TidIter)) { Out << FS; Out << "vFuncId: ("; @@ -3721,7 +3683,7 @@ void AssemblyWriter::printVFuncId(const FunctionSummary::VFuncId VFId) { void AssemblyWriter::printNonConstVCalls( const std::vector<FunctionSummary::VFuncId> &VCallList, const char *Tag) { Out << Tag << ": ("; - FieldSeparator FS; + ListSeparator FS; for (auto &VFuncId : VCallList) { Out << FS; printVFuncId(VFuncId); @@ -3733,7 +3695,7 @@ void AssemblyWriter::printConstVCalls( const std::vector<FunctionSummary::ConstVCall> &VCallList, const char *Tag) { Out << Tag << ": ("; - FieldSeparator FS; + ListSeparator FS; for (auto &ConstVCall : VCallList) { Out << FS; Out << "("; @@ -3774,7 +3736,7 @@ void AssemblyWriter::printSummary(const GlobalValueSummary &Summary) { auto RefList = Summary.refs(); if (!RefList.empty()) { Out << ", refs: ("; - FieldSeparator FS; + ListSeparator FS; for (auto &Ref : RefList) { Out << FS; if (Ref.isReadOnly()) @@ -3797,7 +3759,7 @@ void AssemblyWriter::printSummaryInfo(unsigned Slot, const ValueInfo &VI) { Out << "guid: " << VI.getGUID(); if (!VI.getSummaryList().empty()) { Out << ", summaries: ("; - FieldSeparator FS; + ListSeparator FS; for (auto &Summary : VI.getSummaryList()) { Out << FS; printSummary(*Summary); @@ -3835,13 +3797,11 @@ void AssemblyWriter::printNamedMDNode(const NamedMDNode *NMD) { Out << '!'; printMetadataIdentifier(NMD->getName(), Out); Out << " = !{"; - for (unsigned i = 0, e = NMD->getNumOperands(); i != e; ++i) { - if (i) - Out << ", "; - + ListSeparator LS; + for (const MDNode *Op : NMD->operands()) { + Out << LS; // Write DIExpressions inline. // FIXME: Ban DIExpressions in NamedMDNodes, they will serve no purpose. - MDNode *Op = NMD->getOperand(i); if (auto *Expr = dyn_cast<DIExpression>(Op)) { writeDIExpression(Out, Expr, AsmWriterContext::getEmpty()); continue; @@ -4192,11 +4152,10 @@ void AssemblyWriter::printFunction(const Function *F) { // Loop over the arguments, printing them... if (F->isDeclaration() && !IsForDebug) { // We're only interested in the type here - don't print argument names. + ListSeparator LS; for (unsigned I = 0, E = FT->getNumParams(); I != E; ++I) { - // Insert commas as we go... the first arg doesn't get a comma - if (I) - Out << ", "; - // Output type... + Out << LS; + // Output type. TypePrinter.print(FT->getParamType(I), Out); AttributeSet ArgAttrs = Attrs.getParamAttrs(I); @@ -4207,10 +4166,9 @@ void AssemblyWriter::printFunction(const Function *F) { } } else { // The arguments are meaningful here, print them in detail. + ListSeparator LS; for (const Argument &Arg : F->args()) { - // Insert commas as we go... the first arg doesn't get a comma - if (Arg.getArgNo() != 0) - Out << ", "; + Out << LS; printArgument(&Arg, Attrs.getParamAttrs(Arg.getArgNo())); } } @@ -4332,16 +4290,14 @@ void AssemblyWriter::printBasicBlock(const BasicBlock *BB) { // Output predecessors for the block. 
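As an editorial aside on the pattern running through this whole AsmWriter change: every hunk replaces hand-rolled first-element bookkeeping (FirstElt flags, the local FieldSeparator) with llvm::ListSeparator from llvm/ADT/STLExtras.h, which prints nothing on first use and its separator string on every use after that. A minimal sketch of the idiom (function names here are illustrative only, not part of the patch):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"

static void printInts(llvm::raw_ostream &OS, llvm::ArrayRef<int> Vals) {
  llvm::ListSeparator LS; // Defaults to ", "; prints nothing the first time.
  for (int V : Vals)
    OS << LS << V; // {1, 2, 3} -> "1, 2, 3"
}

static void printJoinedFlags(llvm::raw_ostream &OS,
                             llvm::ArrayRef<llvm::StringRef> Flags) {
  llvm::ListSeparator LS(" | "); // Custom separator, as in printDIFlags.
  for (llvm::StringRef F : Flags)
    OS << LS << F; // {"A", "B"} -> "A | B"
}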
Out.PadToColumn(50); Out << ";"; - const_pred_iterator PI = pred_begin(BB), PE = pred_end(BB); - - if (PI == PE) { + if (pred_empty(BB)) { Out << " No predecessors!"; } else { Out << " preds = "; - writeOperand(*PI, false); - for (++PI; PI != PE; ++PI) { - Out << ", "; - writeOperand(*PI, false); + ListSeparator LS; + for (const BasicBlock *Pred : predecessors(BB)) { + Out << LS; + writeOperand(Pred, false); } } } @@ -4520,9 +4476,9 @@ void AssemblyWriter::printInstruction(const Instruction &I) { writeOperand(Operand, true); Out << ", ["; + ListSeparator LS; for (unsigned i = 1, e = I.getNumOperands(); i != e; ++i) { - if (i != 1) - Out << ", "; + Out << LS; writeOperand(I.getOperand(i), true); } Out << ']'; @@ -4531,9 +4487,9 @@ void AssemblyWriter::printInstruction(const Instruction &I) { TypePrinter.print(I.getType(), Out); Out << ' '; + ListSeparator LS; for (unsigned op = 0, Eop = PN->getNumIncomingValues(); op < Eop; ++op) { - if (op) Out << ", "; - Out << "[ "; + Out << LS << "[ "; writeOperand(PN->getIncomingValue(op), false); Out << ", "; writeOperand(PN->getIncomingBlock(op), false); Out << " ]"; } @@ -4570,12 +4526,10 @@ void AssemblyWriter::printInstruction(const Instruction &I) { Out << " within "; writeOperand(CatchSwitch->getParentPad(), /*PrintType=*/false); Out << " ["; - unsigned Op = 0; + ListSeparator LS; for (const BasicBlock *PadBB : CatchSwitch->handlers()) { - if (Op > 0) - Out << ", "; + Out << LS; writeOperand(PadBB, /*PrintType=*/true); - ++Op; } Out << "] unwind "; if (const BasicBlock *UnwindDest = CatchSwitch->getUnwindDest()) @@ -4586,10 +4540,10 @@ void AssemblyWriter::printInstruction(const Instruction &I) { Out << " within "; writeOperand(FPI->getParentPad(), /*PrintType=*/false); Out << " ["; - for (unsigned Op = 0, NumOps = FPI->arg_size(); Op < NumOps; ++Op) { - if (Op > 0) - Out << ", "; - writeOperand(FPI->getArgOperand(Op), /*PrintType=*/true); + ListSeparator LS; + for (const Value *Op : FPI->arg_operands()) { + Out << LS; + writeOperand(Op, /*PrintType=*/true); } Out << ']'; } else if (isa<ReturnInst>(I) && !Operand) { @@ -4635,9 +4589,9 @@ void AssemblyWriter::printInstruction(const Instruction &I) { Out << ' '; writeOperand(Operand, false); Out << '('; + ListSeparator LS; for (unsigned op = 0, Eop = CI->arg_size(); op < Eop; ++op) { - if (op > 0) - Out << ", "; + Out << LS; writeParamOperand(CI->getArgOperand(op), PAL.getParamAttrs(op)); } @@ -4683,9 +4637,9 @@ void AssemblyWriter::printInstruction(const Instruction &I) { Out << ' '; writeOperand(Operand, false); Out << '('; + ListSeparator LS; for (unsigned op = 0, Eop = II->arg_size(); op < Eop; ++op) { - if (op) - Out << ", "; + Out << LS; writeParamOperand(II->getArgOperand(op), PAL.getParamAttrs(op)); } @@ -4723,9 +4677,9 @@ void AssemblyWriter::printInstruction(const Instruction &I) { Out << ' '; writeOperand(Operand, false); Out << '('; + ListSeparator ArgLS; for (unsigned op = 0, Eop = CBI->arg_size(); op < Eop; ++op) { - if (op) - Out << ", "; + Out << ArgLS; writeParamOperand(CBI->getArgOperand(op), PAL.getParamAttrs(op)); } @@ -4738,10 +4692,10 @@ void AssemblyWriter::printInstruction(const Instruction &I) { Out << "\n to "; writeOperand(CBI->getDefaultDest(), true); Out << " ["; - for (unsigned i = 0, e = CBI->getNumIndirectDests(); i != e; ++i) { - if (i != 0) - Out << ", "; - writeOperand(CBI->getIndirectDest(i), true); + ListSeparator DestLS; + for (const BasicBlock *Dest : CBI->getIndirectDests()) { + Out << DestLS; + writeOperand(Dest, true); } Out << ']'; } else if (const 
AllocaInst *AI = dyn_cast<AllocaInst>(&I)) { @@ -4824,9 +4778,10 @@ void AssemblyWriter::printInstruction(const Instruction &I) { } Out << ' '; - for (unsigned i = 0, E = I.getNumOperands(); i != E; ++i) { - if (i) Out << ", "; - writeOperand(I.getOperand(i), PrintAllTypes); + ListSeparator LS; + for (const Value *Op : I.operands()) { + Out << LS; + writeOperand(Op, PrintAllTypes); } } diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp index 5385b1f..f28b989 100644 --- a/llvm/lib/IR/AutoUpgrade.cpp +++ b/llvm/lib/IR/AutoUpgrade.cpp @@ -594,6 +594,42 @@ static bool upgradeX86IntrinsicFunction(Function *F, StringRef Name, return false; // No other 'x86.avx512.*'. } + if (Name.consume_front("avx2.vpdpb")) { + // Added in 21.1 + ID = StringSwitch<Intrinsic::ID>(Name) + .Case("ssd.128", Intrinsic::x86_avx2_vpdpbssd_128) + .Case("ssd.256", Intrinsic::x86_avx2_vpdpbssd_256) + .Case("ssds.128", Intrinsic::x86_avx2_vpdpbssds_128) + .Case("ssds.256", Intrinsic::x86_avx2_vpdpbssds_256) + .Case("sud.128", Intrinsic::x86_avx2_vpdpbsud_128) + .Case("sud.256", Intrinsic::x86_avx2_vpdpbsud_256) + .Case("suds.128", Intrinsic::x86_avx2_vpdpbsuds_128) + .Case("suds.256", Intrinsic::x86_avx2_vpdpbsuds_256) + .Case("uud.128", Intrinsic::x86_avx2_vpdpbuud_128) + .Case("uud.256", Intrinsic::x86_avx2_vpdpbuud_256) + .Case("uuds.128", Intrinsic::x86_avx2_vpdpbuuds_128) + .Case("uuds.256", Intrinsic::x86_avx2_vpdpbuuds_256) + .Default(Intrinsic::not_intrinsic); + if (ID != Intrinsic::not_intrinsic) + return upgradeX86MultiplyAddBytes(F, ID, NewFn); + return false; // No other 'x86.avx2.*' + } + + if (Name.consume_front("avx10.vpdpb")) { + // Added in 21.1 + ID = StringSwitch<Intrinsic::ID>(Name) + .Case("ssd.512", Intrinsic::x86_avx10_vpdpbssd_512) + .Case("ssds.512", Intrinsic::x86_avx10_vpdpbssds_512) + .Case("sud.512", Intrinsic::x86_avx10_vpdpbsud_512) + .Case("suds.512", Intrinsic::x86_avx10_vpdpbsuds_512) + .Case("uud.512", Intrinsic::x86_avx10_vpdpbuud_512) + .Case("uuds.512", Intrinsic::x86_avx10_vpdpbuuds_512) + .Default(Intrinsic::not_intrinsic); + if (ID != Intrinsic::not_intrinsic) + return upgradeX86MultiplyAddBytes(F, ID, NewFn); + return false; // No other 'x86.avx10.*' + } + if (Name.consume_front("avx512bf16.")) { // Added in 9.0 ID = StringSwitch<Intrinsic::ID>(Name) @@ -5224,7 +5260,25 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) { case Intrinsic::x86_avx512_vpdpbusd_512: case Intrinsic::x86_avx512_vpdpbusds_128: case Intrinsic::x86_avx512_vpdpbusds_256: - case Intrinsic::x86_avx512_vpdpbusds_512: { + case Intrinsic::x86_avx512_vpdpbusds_512: + case Intrinsic::x86_avx2_vpdpbssd_128: + case Intrinsic::x86_avx2_vpdpbssd_256: + case Intrinsic::x86_avx10_vpdpbssd_512: + case Intrinsic::x86_avx2_vpdpbssds_128: + case Intrinsic::x86_avx2_vpdpbssds_256: + case Intrinsic::x86_avx10_vpdpbssds_512: + case Intrinsic::x86_avx2_vpdpbsud_128: + case Intrinsic::x86_avx2_vpdpbsud_256: + case Intrinsic::x86_avx10_vpdpbsud_512: + case Intrinsic::x86_avx2_vpdpbsuds_128: + case Intrinsic::x86_avx2_vpdpbsuds_256: + case Intrinsic::x86_avx10_vpdpbsuds_512: + case Intrinsic::x86_avx2_vpdpbuud_128: + case Intrinsic::x86_avx2_vpdpbuud_256: + case Intrinsic::x86_avx10_vpdpbuud_512: + case Intrinsic::x86_avx2_vpdpbuuds_128: + case Intrinsic::x86_avx2_vpdpbuuds_256: + case Intrinsic::x86_avx10_vpdpbuuds_512: { unsigned NumElts = CI->getType()->getPrimitiveSizeInBits() / 8; Value *Args[] = {CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2)}; diff --git 
a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp index 5827292..99029c1 100644 --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -252,6 +252,13 @@ void setExplicitlyUnknownBranchWeights(Instruction &I, StringRef PassName) { MDB.createString(PassName)})); } +void setExplicitlyUnknownBranchWeightsIfProfiled(Instruction &I, Function &F, + StringRef PassName) { + if (std::optional<Function::ProfileCount> EC = F.getEntryCount(); + EC && EC->getCount() > 0) + setExplicitlyUnknownBranchWeights(I, PassName); +} + void setExplicitlyUnknownFunctionEntryCount(Function &F, StringRef PassName) { MDBuilder MDB(F.getContext()); F.setMetadata( diff --git a/llvm/lib/IR/Value.cpp b/llvm/lib/IR/Value.cpp index 4e8f359..e5e062d 100644 --- a/llvm/lib/IR/Value.cpp +++ b/llvm/lib/IR/Value.cpp @@ -1000,14 +1000,12 @@ Align Value::getPointerAlignment(const DataLayout &DL) const { ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(0)); return Align(CI->getLimitedValue()); } - } else if (auto *CstPtr = dyn_cast<Constant>(this)) { - // Strip pointer casts to avoid creating unnecessary ptrtoint expression - // if the only "reduction" is combining a bitcast + ptrtoint. - CstPtr = CstPtr->stripPointerCasts(); - if (auto *CstInt = dyn_cast_or_null<ConstantInt>(ConstantExpr::getPtrToInt( - const_cast<Constant *>(CstPtr), DL.getIntPtrType(getType()), - /*OnlyIfReduced=*/true))) { - size_t TrailingZeros = CstInt->getValue().countr_zero(); + } else if (auto *CE = dyn_cast<ConstantExpr>(this)) { + // Determine the alignment of inttoptr(C). + if (CE->getOpcode() == Instruction::IntToPtr && + isa<ConstantInt>(CE->getOperand(0))) { + ConstantInt *IntPtr = cast<ConstantInt>(CE->getOperand(0)); + size_t TrailingZeros = IntPtr->getValue().countr_zero(); // While the actual alignment may be large, elsewhere we have // an arbitrary upper alignmet limit, so let's clamp to it. 
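The inttoptr case added above reads more easily with the arithmetic spelled out: the alignment of inttoptr(C) is the largest power of two dividing C, and the return that follows clamps the exponent to Value::MaxAlignmentExponent. A standalone sketch of the same computation (helper name hypothetical; the 32-bit cap mirrors the current MaxAlignmentExponent and is an assumption here):

#include "llvm/ADT/APInt.h"
#include "llvm/Support/Alignment.h"
#include <algorithm>
#include <cstdint>

static llvm::Align alignOfIntToPtrConstant(const llvm::APInt &C) {
  constexpr unsigned MaxAlignmentExponent = 32; // assumed, as in Value.h
  unsigned TrailingZeros = C.countr_zero();     // e.g. 0x1000 -> 12
  return llvm::Align(uint64_t(1)
                     << std::min(TrailingZeros, MaxAlignmentExponent));
}
// alignOfIntToPtrConstant(APInt(64, 0x1000)) == Align(4096).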
return Align(TrailingZeros < Value::MaxAlignmentExponent diff --git a/llvm/lib/Object/OffloadBundle.cpp b/llvm/lib/Object/OffloadBundle.cpp index 1e1042c..0dd378e 100644 --- a/llvm/lib/Object/OffloadBundle.cpp +++ b/llvm/lib/Object/OffloadBundle.cpp @@ -89,17 +89,17 @@ Error OffloadBundleFatBin::readEntries(StringRef Buffer, uint64_t EntryIDSize; StringRef EntryID; - if (auto EC = Reader.readInteger(EntryOffset)) - return errorCodeToError(object_error::parse_failed); + if (Error Err = Reader.readInteger(EntryOffset)) + return Err; - if (auto EC = Reader.readInteger(EntrySize)) - return errorCodeToError(object_error::parse_failed); + if (Error Err = Reader.readInteger(EntrySize)) + return Err; - if (auto EC = Reader.readInteger(EntryIDSize)) - return errorCodeToError(object_error::parse_failed); + if (Error Err = Reader.readInteger(EntryIDSize)) + return Err; - if (auto EC = Reader.readFixedString(EntryID, EntryIDSize)) - return errorCodeToError(object_error::parse_failed); + if (Error Err = Reader.readFixedString(EntryID, EntryIDSize)) + return Err; auto Entry = std::make_unique<OffloadBundleEntry>( EntryOffset + SectionOffset, EntrySize, EntryIDSize, EntryID); @@ -125,7 +125,7 @@ OffloadBundleFatBin::create(MemoryBufferRef Buf, uint64_t SectionOffset, // Read the Bundle Entries Error Err = TheBundle->readEntries(Buf.getBuffer(), SectionOffset); if (Err) - return errorCodeToError(object_error::parse_failed); + return Err; return std::unique_ptr<OffloadBundleFatBin>(TheBundle); } diff --git a/llvm/lib/Support/Mustache.cpp b/llvm/lib/Support/Mustache.cpp index 686688a..47860c0 100644 --- a/llvm/lib/Support/Mustache.cpp +++ b/llvm/lib/Support/Mustache.cpp @@ -7,9 +7,14 @@ //===----------------------------------------------------------------------===// #include "llvm/Support/Mustache.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include <cctype> +#include <optional> #include <sstream> +#define DEBUG_TYPE "mustache" + using namespace llvm; using namespace llvm::mustache; @@ -51,6 +56,33 @@ static Accessor splitMustacheString(StringRef Str) { namespace llvm::mustache { +class MustacheOutputStream : public raw_ostream { +public: + MustacheOutputStream() = default; + ~MustacheOutputStream() override = default; + + virtual void suspendIndentation() {} + virtual void resumeIndentation() {} + +private: + void anchor() override; +}; + +void MustacheOutputStream::anchor() {} + +class RawMustacheOutputStream : public MustacheOutputStream { +public: + RawMustacheOutputStream(raw_ostream &OS) : OS(OS) { SetUnbuffered(); } + +private: + raw_ostream &OS; + + void write_impl(const char *Ptr, size_t Size) override { + OS.write(Ptr, Size); + } + uint64_t current_pos() const override { return OS.tell(); } +}; + class Token { public: enum class Type { @@ -62,6 +94,7 @@ public: InvertSectionOpen, UnescapeVariable, Comment, + SetDelimiter, }; Token(std::string Str) @@ -102,6 +135,8 @@ public: return Type::Partial; case '&': return Type::UnescapeVariable; + case '=': + return Type::SetDelimiter; default: return Type::Variable; } @@ -130,26 +165,17 @@ public: InvertSection, }; - ASTNode(llvm::StringMap<AstPtr> &Partials, llvm::StringMap<Lambda> &Lambdas, - llvm::StringMap<SectionLambda> &SectionLambdas, EscapeMap &Escapes) - : Partials(Partials), Lambdas(Lambdas), SectionLambdas(SectionLambdas), - Escapes(Escapes), Ty(Type::Root), Parent(nullptr), - ParentContext(nullptr) {} + ASTNode(MustacheContext &Ctx) + : Ctx(Ctx), Ty(Type::Root), Parent(nullptr), 
ParentContext(nullptr) {} - ASTNode(std::string Body, ASTNode *Parent, llvm::StringMap<AstPtr> &Partials, - llvm::StringMap<Lambda> &Lambdas, - llvm::StringMap<SectionLambda> &SectionLambdas, EscapeMap &Escapes) - : Partials(Partials), Lambdas(Lambdas), SectionLambdas(SectionLambdas), - Escapes(Escapes), Ty(Type::Text), Body(std::move(Body)), Parent(Parent), + ASTNode(MustacheContext &Ctx, std::string Body, ASTNode *Parent) + : Ctx(Ctx), Ty(Type::Text), Body(std::move(Body)), Parent(Parent), ParentContext(nullptr) {} // Constructor for Section/InvertSection/Variable/UnescapeVariable Nodes - ASTNode(Type Ty, Accessor Accessor, ASTNode *Parent, - llvm::StringMap<AstPtr> &Partials, llvm::StringMap<Lambda> &Lambdas, - llvm::StringMap<SectionLambda> &SectionLambdas, EscapeMap &Escapes) - : Partials(Partials), Lambdas(Lambdas), SectionLambdas(SectionLambdas), - Escapes(Escapes), Ty(Ty), Parent(Parent), - AccessorValue(std::move(Accessor)), ParentContext(nullptr) {} + ASTNode(MustacheContext &Ctx, Type Ty, Accessor Accessor, ASTNode *Parent) + : Ctx(Ctx), Ty(Ty), Parent(Parent), AccessorValue(std::move(Accessor)), + ParentContext(nullptr) {} void addChild(AstPtr Child) { Children.emplace_back(std::move(Child)); }; @@ -157,26 +183,33 @@ public: void setIndentation(size_t NewIndentation) { Indentation = NewIndentation; }; - void render(const llvm::json::Value &Data, llvm::raw_ostream &OS); + void render(const llvm::json::Value &Data, MustacheOutputStream &OS); private: - void renderLambdas(const llvm::json::Value &Contexts, llvm::raw_ostream &OS, - Lambda &L); + void renderLambdas(const llvm::json::Value &Contexts, + MustacheOutputStream &OS, Lambda &L); void renderSectionLambdas(const llvm::json::Value &Contexts, - llvm::raw_ostream &OS, SectionLambda &L); + MustacheOutputStream &OS, SectionLambda &L); - void renderPartial(const llvm::json::Value &Contexts, llvm::raw_ostream &OS, - ASTNode *Partial); + void renderPartial(const llvm::json::Value &Contexts, + MustacheOutputStream &OS, ASTNode *Partial); - void renderChild(const llvm::json::Value &Context, llvm::raw_ostream &OS); + void renderChild(const llvm::json::Value &Context, MustacheOutputStream &OS); const llvm::json::Value *findContext(); - StringMap<AstPtr> &Partials; - StringMap<Lambda> &Lambdas; - StringMap<SectionLambda> &SectionLambdas; - EscapeMap &Escapes; + void renderRoot(const json::Value &CurrentCtx, MustacheOutputStream &OS); + void renderText(MustacheOutputStream &OS); + void renderPartial(const json::Value &CurrentCtx, MustacheOutputStream &OS); + void renderVariable(const json::Value &CurrentCtx, MustacheOutputStream &OS); + void renderUnescapeVariable(const json::Value &CurrentCtx, + MustacheOutputStream &OS); + void renderSection(const json::Value &CurrentCtx, MustacheOutputStream &OS); + void renderInvertSection(const json::Value &CurrentCtx, + MustacheOutputStream &OS); + + MustacheContext &Ctx; Type Ty; size_t Indentation = 0; std::string RawBody; @@ -189,29 +222,18 @@ private: }; // A wrapper for arena allocator for ASTNodes -AstPtr createRootNode(llvm::StringMap<AstPtr> &Partials, - llvm::StringMap<Lambda> &Lambdas, - llvm::StringMap<SectionLambda> &SectionLambdas, - EscapeMap &Escapes) { - return std::make_unique<ASTNode>(Partials, Lambdas, SectionLambdas, Escapes); +static AstPtr createRootNode(MustacheContext &Ctx) { + return std::make_unique<ASTNode>(Ctx); } -AstPtr createNode(ASTNode::Type T, Accessor A, ASTNode *Parent, - llvm::StringMap<AstPtr> &Partials, - llvm::StringMap<Lambda> &Lambdas, - 
llvm::StringMap<SectionLambda> &SectionLambdas, - EscapeMap &Escapes) { - return std::make_unique<ASTNode>(T, std::move(A), Parent, Partials, Lambdas, - SectionLambdas, Escapes); +static AstPtr createNode(MustacheContext &Ctx, ASTNode::Type T, Accessor A, + ASTNode *Parent) { + return std::make_unique<ASTNode>(Ctx, T, std::move(A), Parent); } -AstPtr createTextNode(std::string Body, ASTNode *Parent, - llvm::StringMap<AstPtr> &Partials, - llvm::StringMap<Lambda> &Lambdas, - llvm::StringMap<SectionLambda> &SectionLambdas, - EscapeMap &Escapes) { - return std::make_unique<ASTNode>(std::move(Body), Parent, Partials, Lambdas, - SectionLambdas, Escapes); +static AstPtr createTextNode(MustacheContext &Ctx, std::string Body, + ASTNode *Parent) { + return std::make_unique<ASTNode>(Ctx, std::move(Body), Parent); } // Function to check if there is meaningful text behind. @@ -226,7 +248,7 @@ AstPtr createTextNode(std::string Body, ASTNode *Parent, // and the current token is the second token. // For example: // "{{#Section}}" -bool hasTextBehind(size_t Idx, const ArrayRef<Token> &Tokens) { +static bool hasTextBehind(size_t Idx, const ArrayRef<Token> &Tokens) { if (Idx == 0) return true; @@ -242,7 +264,7 @@ bool hasTextBehind(size_t Idx, const ArrayRef<Token> &Tokens) { // Function to check if there's no meaningful text ahead. // We determine if a token has text ahead if the left of previous // token does not start with a newline. -bool hasTextAhead(size_t Idx, const ArrayRef<Token> &Tokens) { +static bool hasTextAhead(size_t Idx, const ArrayRef<Token> &Tokens) { if (Idx >= Tokens.size() - 1) return true; @@ -255,11 +277,11 @@ bool hasTextAhead(size_t Idx, const ArrayRef<Token> &Tokens) { return !TokenBody.starts_with("\r\n") && !TokenBody.starts_with("\n"); } -bool requiresCleanUp(Token::Type T) { +static bool requiresCleanUp(Token::Type T) { // We must clean up all the tokens that could contain child nodes. return T == Token::Type::SectionOpen || T == Token::Type::InvertSectionOpen || T == Token::Type::SectionClose || T == Token::Type::Comment || - T == Token::Type::Partial; + T == Token::Type::Partial || T == Token::Type::SetDelimiter; } // Adjust next token body if there is no text ahead. @@ -268,7 +290,7 @@ bool requiresCleanUp(Token::Type T) { // "{{! Comment }} \nLine 2" // would be considered as no text ahead and should be rendered as // " Line 2" -void stripTokenAhead(SmallVectorImpl<Token> &Tokens, size_t Idx) { +static void stripTokenAhead(SmallVectorImpl<Token> &Tokens, size_t Idx) { Token &NextToken = Tokens[Idx + 1]; StringRef NextTokenBody = NextToken.TokenBody; // Cut off the leading newline which could be \n or \r\n. @@ -282,72 +304,167 @@ void stripTokenAhead(SmallVectorImpl<Token> &Tokens, size_t Idx) { // For example: // The template string // " \t{{#section}}A{{/section}}" -// would be considered as having no text ahead and would be render as +// would be considered as having no text ahead and would be rendered as: // "A" -// The exception for this is partial tag which requires us to -// keep track of the indentation once it's rendered. 
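stripTokenBefore, defined next, is the counterpart that also records how much indentation it removed so a partial can re-apply it per line. The core bookkeeping, as an isolated sketch (helper name hypothetical):

#include "llvm/ADT/StringRef.h"
#include <string>

static size_t takeIndentation(std::string &PrevTokenBody) {
  llvm::StringRef Body(PrevTokenBody);
  llvm::StringRef Unindented = Body.rtrim(" \r\t\v");
  size_t Indentation = Body.size() - Unindented.size();
  PrevTokenBody = Unindented.str();
  return Indentation; // "Text \t" -> 2, and the body becomes "Text"
}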
void stripTokenBefore(SmallVectorImpl<Token> &Tokens, size_t Idx, Token &CurrentToken, Token::Type CurrentType) { Token &PrevToken = Tokens[Idx - 1]; StringRef PrevTokenBody = PrevToken.TokenBody; StringRef Unindented = PrevTokenBody.rtrim(" \r\t\v"); size_t Indentation = PrevTokenBody.size() - Unindented.size(); - if (CurrentType != Token::Type::Partial) - PrevToken.TokenBody = Unindented.str(); + PrevToken.TokenBody = Unindented.str(); CurrentToken.setIndentation(Indentation); } +struct Tag { + enum class Kind { + None, + Normal, // {{...}} + Triple, // {{{...}}} + }; + + Kind TagKind = Kind::None; + StringRef Content; // The content between the delimiters. + StringRef FullMatch; // The entire tag, including delimiters. + size_t StartPosition = StringRef::npos; +}; + +[[maybe_unused]] static const char *tagKindToString(Tag::Kind K) { + switch (K) { + case Tag::Kind::None: + return "None"; + case Tag::Kind::Normal: + return "Normal"; + case Tag::Kind::Triple: + return "Triple"; + } + llvm_unreachable("Unknown Tag::Kind"); +} + +[[maybe_unused]] static const char *jsonKindToString(json::Value::Kind K) { + switch (K) { + case json::Value::Kind::Null: + return "JSON_KIND_NULL"; + case json::Value::Kind::Boolean: + return "JSON_KIND_BOOLEAN"; + case json::Value::Kind::Number: + return "JSON_KIND_NUMBER"; + case json::Value::Kind::String: + return "JSON_KIND_STRING"; + case json::Value::Kind::Array: + return "JSON_KIND_ARRAY"; + case json::Value::Kind::Object: + return "JSON_KIND_OBJECT"; + } + llvm_unreachable("Unknown json::Value::Kind"); +} + +static Tag findNextTag(StringRef Template, size_t StartPos, StringRef Open, + StringRef Close) { + const StringLiteral TripleOpen("{{{"); + const StringLiteral TripleClose("}}}"); + + size_t NormalOpenPos = Template.find(Open, StartPos); + size_t TripleOpenPos = Template.find(TripleOpen, StartPos); + + Tag Result; + + // Determine which tag comes first. + if (TripleOpenPos != StringRef::npos && + (NormalOpenPos == StringRef::npos || TripleOpenPos <= NormalOpenPos)) { + // Found a triple mustache tag. + size_t EndPos = + Template.find(TripleClose, TripleOpenPos + TripleOpen.size()); + if (EndPos == StringRef::npos) + return Result; // No closing tag found. + + Result.TagKind = Tag::Kind::Triple; + Result.StartPosition = TripleOpenPos; + size_t ContentStart = TripleOpenPos + TripleOpen.size(); + Result.Content = Template.substr(ContentStart, EndPos - ContentStart); + Result.FullMatch = Template.substr( + TripleOpenPos, (EndPos + TripleClose.size()) - TripleOpenPos); + } else if (NormalOpenPos != StringRef::npos) { + // Found a normal mustache tag. + size_t EndPos = Template.find(Close, NormalOpenPos + Open.size()); + if (EndPos == StringRef::npos) + return Result; // No closing tag found. 
+ + Result.TagKind = Tag::Kind::Normal; + Result.StartPosition = NormalOpenPos; + size_t ContentStart = NormalOpenPos + Open.size(); + Result.Content = Template.substr(ContentStart, EndPos - ContentStart); + Result.FullMatch = + Template.substr(NormalOpenPos, (EndPos + Close.size()) - NormalOpenPos); + } + + return Result; +} + +static std::optional<std::pair<StringRef, StringRef>> +processTag(const Tag &T, SmallVectorImpl<Token> &Tokens) { + LLVM_DEBUG(dbgs() << "[Tag] " << T.FullMatch << ", Content: " << T.Content + << ", Kind: " << tagKindToString(T.TagKind) << "\n"); + if (T.TagKind == Tag::Kind::Triple) { + Tokens.emplace_back(T.FullMatch.str(), "&" + T.Content.str(), '&'); + return std::nullopt; + } + StringRef Interpolated = T.Content; + std::string RawBody = T.FullMatch.str(); + if (!Interpolated.trim().starts_with("=")) { + char Front = Interpolated.empty() ? ' ' : Interpolated.trim().front(); + Tokens.emplace_back(RawBody, Interpolated.str(), Front); + return std::nullopt; + } + Tokens.emplace_back(RawBody, Interpolated.str(), '='); + StringRef DelimSpec = Interpolated.trim(); + DelimSpec = DelimSpec.drop_front(1); + DelimSpec = DelimSpec.take_until([](char C) { return C == '='; }); + DelimSpec = DelimSpec.trim(); + + std::pair<StringRef, StringRef> Ret = DelimSpec.split(' '); + LLVM_DEBUG(dbgs() << "[Set Delimiter] NewOpen: " << Ret.first + << ", NewClose: " << Ret.second << "\n"); + return Ret; +} + // Simple tokenizer that splits the template into tokens. // The mustache spec allows {{{ }}} to unescape variables, // but we don't support that here. An unescape variable // is represented only by {{& variable}}. -SmallVector<Token> tokenize(StringRef Template) { +static SmallVector<Token> tokenize(StringRef Template) { + LLVM_DEBUG(dbgs() << "[Tokenize Template] \"" << Template << "\"\n"); SmallVector<Token> Tokens; - StringLiteral Open("{{"); - StringLiteral Close("}}"); - StringLiteral TripleOpen("{{{"); - StringLiteral TripleClose("}}}"); + SmallString<8> Open("{{"); + SmallString<8> Close("}}"); size_t Start = 0; - size_t DelimiterStart = Template.find(Open); - if (DelimiterStart == StringRef::npos) { - Tokens.emplace_back(Template.str()); - return Tokens; - } - while (DelimiterStart != StringRef::npos) { - if (DelimiterStart != Start) - Tokens.emplace_back(Template.substr(Start, DelimiterStart - Start).str()); - - if (Template.substr(DelimiterStart).starts_with(TripleOpen)) { - size_t DelimiterEnd = Template.find(TripleClose, DelimiterStart); - if (DelimiterEnd == StringRef::npos) - break; - size_t BodyStart = DelimiterStart + TripleOpen.size(); - std::string Body = - Template.substr(BodyStart, DelimiterEnd - BodyStart).str(); - std::string RawBody = - Template.substr(DelimiterStart, DelimiterEnd - DelimiterStart + 3) - .str(); - Tokens.emplace_back(RawBody, "&" + Body, '&'); - Start = DelimiterEnd + TripleClose.size(); - } else { - size_t DelimiterEnd = Template.find(Close, DelimiterStart); - if (DelimiterEnd == StringRef::npos) - break; - - // Extract the Interpolated variable without delimiters. 
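The Set Delimiter branch of processTag above boils down to a few StringRef operations. For the tag {{=<% %>=}} the tokenizer sees the content "=<% %>="; a sketch of how the new delimiter pair is recovered (function name hypothetical):

#include "llvm/ADT/StringRef.h"
#include <utility>

static std::pair<llvm::StringRef, llvm::StringRef>
parseDelimSpec(llvm::StringRef Content) { // Content == "=<% %>="
  llvm::StringRef Spec = Content.trim();
  Spec = Spec.drop_front(1);                               // "<% %>="
  Spec = Spec.take_until([](char C) { return C == '='; }); // "<% %>"
  return Spec.trim().split(' ');                           // {"<%", "%>"}
}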
- size_t InterpolatedStart = DelimiterStart + Open.size(); - size_t InterpolatedEnd = DelimiterEnd - DelimiterStart - Close.size(); - std::string Interpolated = - Template.substr(InterpolatedStart, InterpolatedEnd).str(); - std::string RawBody = Open.str() + Interpolated + Close.str(); - Tokens.emplace_back(RawBody, Interpolated, Interpolated[0]); - Start = DelimiterEnd + Close.size(); + + while (Start < Template.size()) { + LLVM_DEBUG(dbgs() << "[Tokenize Loop] Start:" << Start << ", Open:'" << Open + << "', Close:'" << Close << "'\n"); + Tag T = findNextTag(Template, Start, Open, Close); + + if (T.TagKind == Tag::Kind::None) { + // No more tags, the rest is text. + Tokens.emplace_back(Template.substr(Start).str()); + LLVM_DEBUG(dbgs() << " No more tags. Created final Text token: \"" + << Template.substr(Start) << "\"\n"); + break; } - DelimiterStart = Template.find(Open, Start); - } - if (Start < Template.size()) - Tokens.emplace_back(Template.substr(Start).str()); + // Add the text before the tag. + if (T.StartPosition > Start) { + StringRef Text = Template.substr(Start, T.StartPosition - Start); + Tokens.emplace_back(Text.str()); + } + + if (auto NewDelims = processTag(T, Tokens)) { + std::tie(Open, Close) = *NewDelims; + } + + // Move past the tag. + Start = T.StartPosition + T.FullMatch.size(); + } // Fix up white spaces for: // - open sections @@ -393,7 +510,7 @@ SmallVector<Token> tokenize(StringRef Template) { } // Custom stream to escape strings. -class EscapeStringStream : public raw_ostream { +class EscapeStringStream : public MustacheOutputStream { public: explicit EscapeStringStream(llvm::raw_ostream &WrappedStream, EscapeMap &Escape) @@ -435,23 +552,35 @@ private: }; // Custom stream to add indentation used to for rendering partials. 
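The AddIndentationStringStream rework below adds line-start bookkeeping on top of this; a hand-checked trace of the resulting behavior for an indentation of 2 (stream name and calls as in this file, values illustrative):

// AddIndentationStringStream IS(OS, 2);
// IS << "a\nb";   // emits "  a\n  b": each line start gets the indent
// IS.suspendIndentation();
// IS << "c\nd";   // emits "c\nd": 'c' continues b's line, and while
//                 // suspended a '\n' does not arm a fresh indent
// IS.resumeIndentation();
// IS << "\ne";    // emits "\n  e": indentation resumes on the next line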
-class AddIndentationStringStream : public raw_ostream { +class AddIndentationStringStream : public MustacheOutputStream { public: - explicit AddIndentationStringStream(llvm::raw_ostream &WrappedStream, + explicit AddIndentationStringStream(raw_ostream &WrappedStream, size_t Indentation) - : Indentation(Indentation), WrappedStream(WrappedStream) { + : Indentation(Indentation), WrappedStream(WrappedStream), + NeedsIndent(true), IsSuspended(false) { SetUnbuffered(); } + void suspendIndentation() override { IsSuspended = true; } + void resumeIndentation() override { IsSuspended = false; } + protected: void write_impl(const char *Ptr, size_t Size) override { llvm::StringRef Data(Ptr, Size); SmallString<0> Indent; Indent.resize(Indentation, ' '); + for (char C : Data) { - WrappedStream << C; - if (C == '\n') + LLVM_DEBUG(dbgs() << "[Indentation Stream] NeedsIndent:" << NeedsIndent + << ", C:'" << C << "', Indentation:" << Indentation + << "\n"); + if (NeedsIndent && C != '\n') { WrappedStream << Indent; + NeedsIndent = false; + } + WrappedStream << C; + if (C == '\n' && !IsSuspended) + NeedsIndent = true; } } @@ -459,44 +588,50 @@ protected: private: size_t Indentation; - llvm::raw_ostream &WrappedStream; + raw_ostream &WrappedStream; + bool NeedsIndent; + bool IsSuspended; }; class Parser { public: - Parser(StringRef TemplateStr) : TemplateStr(TemplateStr) {} + Parser(StringRef TemplateStr, MustacheContext &Ctx) + : Ctx(Ctx), TemplateStr(TemplateStr) {} - AstPtr parse(llvm::StringMap<AstPtr> &Partials, - llvm::StringMap<Lambda> &Lambdas, - llvm::StringMap<SectionLambda> &SectionLambdas, - EscapeMap &Escapes); + AstPtr parse(); private: - void parseMustache(ASTNode *Parent, llvm::StringMap<AstPtr> &Partials, - llvm::StringMap<Lambda> &Lambdas, - llvm::StringMap<SectionLambda> &SectionLambdas, - EscapeMap &Escapes); + void parseMustache(ASTNode *Parent); + void parseSection(ASTNode *Parent, ASTNode::Type Ty, const Accessor &A); + MustacheContext &Ctx; SmallVector<Token> Tokens; size_t CurrentPtr; StringRef TemplateStr; }; -AstPtr Parser::parse(llvm::StringMap<AstPtr> &Partials, - llvm::StringMap<Lambda> &Lambdas, - llvm::StringMap<SectionLambda> &SectionLambdas, - EscapeMap &Escapes) { +void Parser::parseSection(ASTNode *Parent, ASTNode::Type Ty, + const Accessor &A) { + AstPtr CurrentNode = createNode(Ctx, Ty, A, Parent); + size_t Start = CurrentPtr; + parseMustache(CurrentNode.get()); + const size_t End = CurrentPtr - 1; + std::string RawBody; + for (std::size_t I = Start; I < End; I++) + RawBody += Tokens[I].RawBody; + CurrentNode->setRawBody(std::move(RawBody)); + Parent->addChild(std::move(CurrentNode)); +} + +AstPtr Parser::parse() { Tokens = tokenize(TemplateStr); CurrentPtr = 0; - AstPtr RootNode = createRootNode(Partials, Lambdas, SectionLambdas, Escapes); - parseMustache(RootNode.get(), Partials, Lambdas, SectionLambdas, Escapes); + AstPtr RootNode = createRootNode(Ctx); + parseMustache(RootNode.get()); return RootNode; } -void Parser::parseMustache(ASTNode *Parent, llvm::StringMap<AstPtr> &Partials, - llvm::StringMap<Lambda> &Lambdas, - llvm::StringMap<SectionLambda> &SectionLambdas, - EscapeMap &Escapes) { +void Parser::parseMustache(ASTNode *Parent) { while (CurrentPtr < Tokens.size()) { Token CurrentToken = Tokens[CurrentPtr]; @@ -506,66 +641,48 @@ void Parser::parseMustache(ASTNode *Parent, llvm::StringMap<AstPtr> &Partials, switch (CurrentToken.getType()) { case Token::Type::Text: { - CurrentNode = createTextNode(std::move(CurrentToken.TokenBody), Parent, - Partials, Lambdas, 
SectionLambdas, Escapes); + CurrentNode = + createTextNode(Ctx, std::move(CurrentToken.TokenBody), Parent); Parent->addChild(std::move(CurrentNode)); break; } case Token::Type::Variable: { - CurrentNode = createNode(ASTNode::Variable, std::move(A), Parent, - Partials, Lambdas, SectionLambdas, Escapes); + CurrentNode = createNode(Ctx, ASTNode::Variable, std::move(A), Parent); Parent->addChild(std::move(CurrentNode)); break; } case Token::Type::UnescapeVariable: { - CurrentNode = createNode(ASTNode::UnescapeVariable, std::move(A), Parent, - Partials, Lambdas, SectionLambdas, Escapes); + CurrentNode = + createNode(Ctx, ASTNode::UnescapeVariable, std::move(A), Parent); Parent->addChild(std::move(CurrentNode)); break; } case Token::Type::Partial: { - CurrentNode = createNode(ASTNode::Partial, std::move(A), Parent, Partials, - Lambdas, SectionLambdas, Escapes); + CurrentNode = createNode(Ctx, ASTNode::Partial, std::move(A), Parent); CurrentNode->setIndentation(CurrentToken.getIndentation()); Parent->addChild(std::move(CurrentNode)); break; } case Token::Type::SectionOpen: { - CurrentNode = createNode(ASTNode::Section, A, Parent, Partials, Lambdas, - SectionLambdas, Escapes); - size_t Start = CurrentPtr; - parseMustache(CurrentNode.get(), Partials, Lambdas, SectionLambdas, - Escapes); - const size_t End = CurrentPtr - 1; - std::string RawBody; - for (std::size_t I = Start; I < End; I++) - RawBody += Tokens[I].RawBody; - CurrentNode->setRawBody(std::move(RawBody)); - Parent->addChild(std::move(CurrentNode)); + parseSection(Parent, ASTNode::Section, A); break; } case Token::Type::InvertSectionOpen: { - CurrentNode = createNode(ASTNode::InvertSection, A, Parent, Partials, - Lambdas, SectionLambdas, Escapes); - size_t Start = CurrentPtr; - parseMustache(CurrentNode.get(), Partials, Lambdas, SectionLambdas, - Escapes); - const size_t End = CurrentPtr - 1; - std::string RawBody; - for (size_t Idx = Start; Idx < End; Idx++) - RawBody += Tokens[Idx].RawBody; - CurrentNode->setRawBody(std::move(RawBody)); - Parent->addChild(std::move(CurrentNode)); + parseSection(Parent, ASTNode::InvertSection, A); break; } case Token::Type::Comment: + case Token::Type::SetDelimiter: break; case Token::Type::SectionClose: return; } } } -void toMustacheString(const json::Value &Data, raw_ostream &OS) { +static void toMustacheString(const json::Value &Data, raw_ostream &OS) { + LLVM_DEBUG(dbgs() << "[To Mustache String] Kind: " + << jsonKindToString(Data.kind()) << ", Data: " << Data + << "\n"); switch (Data.kind()) { case json::Value::Null: return; @@ -597,74 +714,106 @@ void toMustacheString(const json::Value &Data, raw_ostream &OS) { } } -void ASTNode::render(const json::Value &CurrentCtx, raw_ostream &OS) { +void ASTNode::renderRoot(const json::Value &CurrentCtx, + MustacheOutputStream &OS) { + renderChild(CurrentCtx, OS); +} + +void ASTNode::renderText(MustacheOutputStream &OS) { OS << Body; } + +void ASTNode::renderPartial(const json::Value &CurrentCtx, + MustacheOutputStream &OS) { + LLVM_DEBUG(dbgs() << "[Render Partial] Accessor:" << AccessorValue[0] + << ", Indentation:" << Indentation << "\n"); + auto Partial = Ctx.Partials.find(AccessorValue[0]); + if (Partial != Ctx.Partials.end()) + renderPartial(CurrentCtx, OS, Partial->getValue().get()); +} + +void ASTNode::renderVariable(const json::Value &CurrentCtx, + MustacheOutputStream &OS) { + auto Lambda = Ctx.Lambdas.find(AccessorValue[0]); + if (Lambda != Ctx.Lambdas.end()) { + renderLambdas(CurrentCtx, OS, Lambda->getValue()); + } else if (const json::Value 
*ContextPtr = findContext()) { + EscapeStringStream ES(OS, Ctx.Escapes); + toMustacheString(*ContextPtr, ES); + } +} + +void ASTNode::renderUnescapeVariable(const json::Value &CurrentCtx, + MustacheOutputStream &OS) { + LLVM_DEBUG(dbgs() << "[Render UnescapeVariable] Accessor:" << AccessorValue[0] + << "\n"); + auto Lambda = Ctx.Lambdas.find(AccessorValue[0]); + if (Lambda != Ctx.Lambdas.end()) { + renderLambdas(CurrentCtx, OS, Lambda->getValue()); + } else if (const json::Value *ContextPtr = findContext()) { + OS.suspendIndentation(); + toMustacheString(*ContextPtr, OS); + OS.resumeIndentation(); + } +} + +void ASTNode::renderSection(const json::Value &CurrentCtx, + MustacheOutputStream &OS) { + auto SectionLambda = Ctx.SectionLambdas.find(AccessorValue[0]); + if (SectionLambda != Ctx.SectionLambdas.end()) { + renderSectionLambdas(CurrentCtx, OS, SectionLambda->getValue()); + return; + } + + const json::Value *ContextPtr = findContext(); + if (isContextFalsey(ContextPtr)) + return; + + if (const json::Array *Arr = ContextPtr->getAsArray()) { + for (const json::Value &V : *Arr) + renderChild(V, OS); + return; + } + renderChild(*ContextPtr, OS); +} + +void ASTNode::renderInvertSection(const json::Value &CurrentCtx, + MustacheOutputStream &OS) { + bool IsLambda = Ctx.SectionLambdas.contains(AccessorValue[0]); + const json::Value *ContextPtr = findContext(); + if (isContextFalsey(ContextPtr) && !IsLambda) { + renderChild(CurrentCtx, OS); + } +} + +void ASTNode::render(const llvm::json::Value &Data, MustacheOutputStream &OS) { + if (Ty != Root && Ty != Text && AccessorValue.empty()) + return; // Set the parent context to the incoming context so that we // can walk up the context tree correctly in findContext(). - ParentContext = &CurrentCtx; - const json::Value *ContextPtr = Ty == Root ? 
ParentContext : findContext(); + ParentContext = &Data; switch (Ty) { case Root: - renderChild(CurrentCtx, OS); + renderRoot(Data, OS); return; case Text: - OS << Body; + renderText(OS); return; - case Partial: { - auto Partial = Partials.find(AccessorValue[0]); - if (Partial != Partials.end()) - renderPartial(CurrentCtx, OS, Partial->getValue().get()); + case Partial: + renderPartial(Data, OS); return; - } - case Variable: { - auto Lambda = Lambdas.find(AccessorValue[0]); - if (Lambda != Lambdas.end()) { - renderLambdas(CurrentCtx, OS, Lambda->getValue()); - } else if (ContextPtr) { - EscapeStringStream ES(OS, Escapes); - toMustacheString(*ContextPtr, ES); - } + case Variable: + renderVariable(Data, OS); return; - } - case UnescapeVariable: { - auto Lambda = Lambdas.find(AccessorValue[0]); - if (Lambda != Lambdas.end()) { - renderLambdas(CurrentCtx, OS, Lambda->getValue()); - } else if (ContextPtr) { - toMustacheString(*ContextPtr, OS); - } + case UnescapeVariable: + renderUnescapeVariable(Data, OS); return; - } - case Section: { - auto SectionLambda = SectionLambdas.find(AccessorValue[0]); - bool IsLambda = SectionLambda != SectionLambdas.end(); - - if (IsLambda) { - renderSectionLambdas(CurrentCtx, OS, SectionLambda->getValue()); - return; - } - - if (isContextFalsey(ContextPtr)) - return; - - if (const json::Array *Arr = ContextPtr->getAsArray()) { - for (const json::Value &V : *Arr) - renderChild(V, OS); - return; - } - renderChild(*ContextPtr, OS); + case Section: + renderSection(Data, OS); return; - } - case InvertSection: { - bool IsLambda = SectionLambdas.contains(AccessorValue[0]); - if (isContextFalsey(ContextPtr) && !IsLambda) { - // The context for the children remains unchanged from the parent's, so - // we pass this node's original incoming context. 
- renderChild(CurrentCtx, OS); - } + case InvertSection: + renderInvertSection(Data, OS); return; } - } llvm_unreachable("Invalid ASTNode type"); } @@ -707,27 +856,29 @@ const json::Value *ASTNode::findContext() { return Context; } -void ASTNode::renderChild(const json::Value &Contexts, llvm::raw_ostream &OS) { +void ASTNode::renderChild(const json::Value &Contexts, + MustacheOutputStream &OS) { for (AstPtr &Child : Children) Child->render(Contexts, OS); } -void ASTNode::renderPartial(const json::Value &Contexts, llvm::raw_ostream &OS, - ASTNode *Partial) { +void ASTNode::renderPartial(const json::Value &Contexts, + MustacheOutputStream &OS, ASTNode *Partial) { + LLVM_DEBUG(dbgs() << "[Render Partial Indentation] Indentation: " << Indentation << "\n"); AddIndentationStringStream IS(OS, Indentation); Partial->render(Contexts, IS); } -void ASTNode::renderLambdas(const json::Value &Contexts, llvm::raw_ostream &OS, - Lambda &L) { +void ASTNode::renderLambdas(const json::Value &Contexts, + MustacheOutputStream &OS, Lambda &L) { json::Value LambdaResult = L(); std::string LambdaStr; raw_string_ostream Output(LambdaStr); toMustacheString(LambdaResult, Output); - Parser P = Parser(LambdaStr); - AstPtr LambdaNode = P.parse(Partials, Lambdas, SectionLambdas, Escapes); + Parser P(LambdaStr, Ctx); + AstPtr LambdaNode = P.parse(); - EscapeStringStream ES(OS, Escapes); + EscapeStringStream ES(OS, Ctx.Escapes); if (Ty == Variable) { LambdaNode->render(Contexts, ES); return; @@ -736,39 +887,44 @@ void ASTNode::renderLambdas(const json::Value &Contexts, llvm::raw_ostream &OS, } void ASTNode::renderSectionLambdas(const json::Value &Contexts, - llvm::raw_ostream &OS, SectionLambda &L) { + MustacheOutputStream &OS, SectionLambda &L) { json::Value Return = L(RawBody); if (isFalsey(Return)) return; std::string LambdaStr; raw_string_ostream Output(LambdaStr); toMustacheString(Return, Output); - Parser P = Parser(LambdaStr); - AstPtr LambdaNode = P.parse(Partials, Lambdas, SectionLambdas, Escapes); + Parser P(LambdaStr, Ctx); + AstPtr LambdaNode = P.parse(); LambdaNode->render(Contexts, OS); } void Template::render(const json::Value &Data, llvm::raw_ostream &OS) { - Tree->render(Data, OS); + RawMustacheOutputStream MOS(OS); + Tree->render(Data, MOS); } void Template::registerPartial(std::string Name, std::string Partial) { - Parser P = Parser(Partial); - AstPtr PartialTree = P.parse(Partials, Lambdas, SectionLambdas, Escapes); - Partials.insert(std::make_pair(Name, std::move(PartialTree))); + Parser P(Partial, Ctx); + AstPtr PartialTree = P.parse(); + Ctx.Partials.insert(std::make_pair(Name, std::move(PartialTree))); } -void Template::registerLambda(std::string Name, Lambda L) { Lambdas[Name] = L; } +void Template::registerLambda(std::string Name, Lambda L) { + Ctx.Lambdas[Name] = L; +} void Template::registerLambda(std::string Name, SectionLambda L) { - SectionLambdas[Name] = L; + Ctx.SectionLambdas[Name] = L; } -void Template::overrideEscapeCharacters(EscapeMap E) { Escapes = std::move(E); } +void Template::overrideEscapeCharacters(EscapeMap E) { + Ctx.Escapes = std::move(E); +} Template::Template(StringRef TemplateStr) { - Parser P = Parser(TemplateStr); - Tree = P.parse(Partials, Lambdas, SectionLambdas, Escapes); + Parser P(TemplateStr, Ctx); + Tree = P.parse(); // The default behavior is to escape html entities. 
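Taken together, the refactor leaves the public API unchanged; a small usage sketch against llvm/Support/Mustache.h (function name and data hypothetical; the output comment is derived by hand from the section and escaping semantics above):

#include "llvm/Support/JSON.h"
#include "llvm/Support/Mustache.h"
#include "llvm/Support/raw_ostream.h"

void renderGreeting(llvm::raw_ostream &OS) {
  llvm::mustache::Template T(
      "Hello {{#people}}[{{name}}] {{/people}}{{^people}}nobody{{/people}}");
  llvm::json::Value Data = llvm::json::Object{
      {"people", llvm::json::Array{llvm::json::Object{{"name", "A"}},
                                   llvm::json::Object{{"name", "B & C"}}}}};
  // Sections iterate arrays; {{name}} is HTML-escaped by default, so this
  // should print "Hello [A] [B &amp; C] ".
  T.render(Data, OS);
}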
const EscapeMap HtmlEntities = {{'&', "&amp;"}, {'<', "&lt;"}, @@ -779,21 +935,18 @@ Template::Template(StringRef TemplateStr) { } Template::Template(Template &&Other) noexcept - : Partials(std::move(Other.Partials)), Lambdas(std::move(Other.Lambdas)), - SectionLambdas(std::move(Other.SectionLambdas)), - Escapes(std::move(Other.Escapes)), Tree(std::move(Other.Tree)) {} + : Ctx(std::move(Other.Ctx)), Tree(std::move(Other.Tree)) {} Template::~Template() = default; Template &Template::operator=(Template &&Other) noexcept { if (this != &Other) { - Partials = std::move(Other.Partials); - Lambdas = std::move(Other.Lambdas); - SectionLambdas = std::move(Other.SectionLambdas); - Escapes = std::move(Other.Escapes); + Ctx = std::move(Other.Ctx); Tree = std::move(Other.Tree); Other.Tree = nullptr; } return *this; } } // namespace llvm::mustache + +#undef DEBUG_TYPE diff --git a/llvm/lib/Support/VirtualFileSystem.cpp b/llvm/lib/Support/VirtualFileSystem.cpp index 7ff62d4..44d2ee7 100644 --- a/llvm/lib/Support/VirtualFileSystem.cpp +++ b/llvm/lib/Support/VirtualFileSystem.cpp @@ -1908,7 +1908,12 @@ private: FullPath = FS->getOverlayFileDir(); assert(!FullPath.empty() && "External contents prefix directory must exist"); - llvm::sys::path::append(FullPath, Value); + SmallString<256> AbsFullPath = Value; + if (FS->makeAbsolute(FullPath, AbsFullPath)) { + error(N, "failed to make 'external-contents' absolute"); + return nullptr; + } + FullPath = AbsFullPath; } else { FullPath = Value; } @@ -2204,7 +2209,7 @@ RedirectingFileSystem::create(std::unique_ptr<MemoryBuffer> Buffer, // FS->OverlayFileDir => /<absolute_path_to>/dummy.cache/vfs // SmallString<256> OverlayAbsDir = sys::path::parent_path(YAMLFilePath); - std::error_code EC = llvm::sys::fs::make_absolute(OverlayAbsDir); + std::error_code EC = FS->makeAbsolute(OverlayAbsDir); assert(!EC && "Overlay dir final path must be absolute"); (void)EC; FS->setOverlayFileDir(OverlayAbsDir); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 899baa9..45f5235 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -18867,21 +18867,25 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, (!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming()))) return SDValue(); - unsigned NumUses = N->use_size(); + // Count the number of users which are extract_subvectors. + unsigned NumExts = count_if(N->users(), [](SDNode *Use) { + return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR; + }); + auto MaskEC = N->getValueType(0).getVectorElementCount(); - if (!MaskEC.isKnownMultipleOf(NumUses)) + if (!MaskEC.isKnownMultipleOf(NumExts)) return SDValue(); - ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses); + ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts); if (ExtMinEC.getKnownMinValue() < 2) return SDValue(); - SmallVector<SDNode *> Extracts(NumUses, nullptr); + SmallVector<SDNode *> Extracts(NumExts, nullptr); for (SDNode *Use : N->users()) { if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR) - return SDValue(); + continue; - // Ensure the extract type is correct (e.g. if NumUses is 4 and + // Ensure the extract type is correct (e.g. if NumExts is 4 and // the mask return type is nxv8i1, each extract should be nxv2i1. 
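For orientation, the overall rewrite this combine now performs, hand-derived from the surrounding hunks for a nxv8i1 get.active.lane.mask with four nxv2i1 extract_subvector users (NumExts = 4, ExtMinEC = nxv2, DoubleExtVT = nxv4i1):

// {R0, R1} = whilelo_x2(Idx,  TC)     // replaces extracts 0 and 1
// Idx2     = uaddsat(Idx, 4 * vscale) // advance by two nxv2i1 fragments
// {R2, R3} = whilelo_x2(Idx2, TC)     // replaces extracts 2 and 3
// result   = concat(concat(R0, R1), concat(R2, R3)) // nxv8i1 again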
if (Use->getValueType(0).getVectorElementCount() != ExtMinEC) return SDValue(); @@ -18902,32 +18906,39 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SDValue Idx = N->getOperand(0); SDValue TC = N->getOperand(1); - EVT OpVT = Idx.getValueType(); - if (OpVT != MVT::i64) { + if (Idx.getValueType() != MVT::i64) { Idx = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Idx); TC = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, TC); } // Create the whilelo_x2 intrinsics from each pair of extracts EVT ExtVT = Extracts[0]->getValueType(0); + EVT DoubleExtVT = ExtVT.getDoubleNumVectorElementsVT(*DAG.getContext()); auto R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC}); DCI.CombineTo(Extracts[0], R.getValue(0)); DCI.CombineTo(Extracts[1], R.getValue(1)); + SmallVector<SDValue> Concats = {DAG.getNode( + ISD::CONCAT_VECTORS, DL, DoubleExtVT, R.getValue(0), R.getValue(1))}; - if (NumUses == 2) - return SDValue(N, 0); + if (NumExts == 2) { + assert(N->getValueType(0) == DoubleExtVT); + return Concats[0]; + } - auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2); - for (unsigned I = 2; I < NumUses; I += 2) { + auto Elts = + DAG.getElementCount(DL, MVT::i64, ExtVT.getVectorElementCount() * 2); + for (unsigned I = 2; I < NumExts; I += 2) { // After the first whilelo_x2, we need to increment the starting value. - Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts); + Idx = DAG.getNode(ISD::UADDSAT, DL, MVT::i64, Idx, Elts); R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC}); DCI.CombineTo(Extracts[I], R.getValue(0)); DCI.CombineTo(Extracts[I + 1], R.getValue(1)); + Concats.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, DoubleExtVT, + R.getValue(0), R.getValue(1))); } - return SDValue(N, 0); + return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Concats); } // Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce @@ -25512,6 +25523,32 @@ SDValue performCONDCombine(SDNode *N, CmpIndex, CC)) return Val; + // X & M ?= C --> (C << clz(M)) ?= (X << clz(M)) where M is a non-empty + // sequence of ones starting at the least significant bit with the remainder + // zero and C is a constant s.t. (C & ~M) == 0 that cannot be materialised + // into a SUBS (immediate). The transformed form can be matched into a SUBS + // (shifted register). 
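The equivalence claimed in the comment above can be sanity-checked in plain C++ (a hand-written check, not part of the patch): shifting left by clz(M) discards exactly the bits that M masks away, so for a non-zero low mask M and a constant C with (C & ~M) == 0 the two comparisons agree for every X.

#include "llvm/ADT/bit.h"
#include <cstdint>

static bool cmpViaAnd(uint64_t X, uint64_t M, uint64_t C) {
  return (X & M) == C; // e.g. M = 0xFFF, C = 0x123
}
static bool cmpViaShiftedSubs(uint64_t X, uint64_t M, uint64_t C) {
  unsigned Sh = llvm::countl_zero(M); // M = 0xFFF -> Sh = 52; requires M != 0
  return (X << Sh) == (C << Sh);      // the SUBS (shifted register) form
}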
+ if ((CC == AArch64CC::EQ || CC == AArch64CC::NE) && AndNode->hasOneUse() && + isa<ConstantSDNode>(AndNode->getOperand(1)) && + isa<ConstantSDNode>(SubsNode->getOperand(1))) { + SDValue X = AndNode->getOperand(0); + APInt M = AndNode->getConstantOperandAPInt(1); + APInt C = SubsNode->getConstantOperandAPInt(1); + + if (M.isMask() && C.isSubsetOf(M) && !isLegalArithImmed(C.getZExtValue())) { + SDLoc DL(SubsNode); + EVT VT = SubsNode->getValueType(0); + unsigned ShiftAmt = M.countl_zero(); + SDValue ShiftedX = DAG.getNode( + ISD::SHL, DL, VT, X, DAG.getShiftAmountConstant(ShiftAmt, VT, DL)); + SDValue ShiftedC = DAG.getConstant(C << ShiftAmt, DL, VT); + SDValue NewSubs = DAG.getNode(AArch64ISD::SUBS, DL, SubsNode->getVTList(), + ShiftedC, ShiftedX); + DCI.CombineTo(SubsNode, NewSubs, NewSubs.getValue(1)); + return SDValue(N, 0); + } + } + if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(AndNode->getOperand(1))) { uint32_t CNV = CN->getZExtValue(); if (CNV == 255) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index d8072d1..e472e7d 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -303,6 +303,16 @@ public: bool shouldFoldConstantShiftPairToMask(const SDNode *N, CombineLevel Level) const override; + /// Return true if it is profitable to fold a pair of shifts into a mask. + bool shouldFoldMaskToVariableShiftPair(SDValue Y) const override { + EVT VT = Y.getValueType(); + + if (VT.isVector()) + return false; + + return VT.getScalarSizeInBits() <= 64; + } + bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X, SDValue Y) const override; diff --git a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp index 7947469..09b3643 100644 --- a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp +++ b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp @@ -541,6 +541,13 @@ void AArch64PrologueEmitter::emitPrologue() { // to determine the end of the prologue. DebugLoc DL; + // In some cases, particularly with CallingConv::SwiftTail, it is possible to + // have a tail-call where the caller only needs to adjust the stack pointer in + // the epilogue. In this case, we still need to emit a SEH prologue sequence. + // See `seh-minimal-prologue-epilogue.ll` test cases. + if (AFI->getArgumentStackToRestore()) + HasWinCFI = true; + if (AFI->shouldSignReturnAddress(MF)) { // If pac-ret+leaf is in effect, PAUTH_PROLOGUE pseudo instructions // are inserted by emitPacRetPlusLeafHardening(). diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp index cced0fa..4749748 100644 --- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp +++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp @@ -22,7 +22,7 @@ // To handle ZA state across control flow, we make use of edge bundling. This // assigns each block an "incoming" and "outgoing" edge bundle (representing // incoming and outgoing edges). Initially, these are unique to each block; -// then, in the process of forming bundles, the outgoing block of a block is +// then, in the process of forming bundles, the outgoing bundle of a block is // joined with the incoming bundle of all successors. 
The result is that each // bundle can be assigned a single ZA state, which ensures the state required by // all of a block's successors is the same, and that each basic block will always diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td index eaa1870..7003a40 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPU.td +++ b/llvm/lib/Target/AMDGPU/AMDGPU.td @@ -2589,6 +2589,8 @@ def NotHasTrue16BitInsts : True16PredicateClass<"!Subtarget->hasTrue16BitInsts() // only allow 32-bit registers in operands and use low halves thereof. def UseRealTrue16Insts : True16PredicateClass<"Subtarget->useRealTrue16Insts()">, AssemblerPredicate<(all_of FeatureTrue16BitInsts, FeatureRealTrue16Insts)>; +def NotUseRealTrue16Insts : True16PredicateClass<"!Subtarget->useRealTrue16Insts()">, + AssemblerPredicate<(not (all_of FeatureTrue16BitInsts, FeatureRealTrue16Insts))>; def UseFakeTrue16Insts : True16PredicateClass<"Subtarget->hasTrue16BitInsts() && " "!Subtarget->useRealTrue16Insts()">, AssemblerPredicate<(all_of FeatureTrue16BitInsts)>; diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp index 0776d14..f413bbc 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp @@ -840,7 +840,9 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST, .Any({{B128, Ptr32}, {{}, {VgprB128, VgprPtr32}}}); // clang-format on - addRulesForGOpcs({G_AMDGPU_BUFFER_LOAD}, StandardB) + addRulesForGOpcs({G_AMDGPU_BUFFER_LOAD, G_AMDGPU_BUFFER_LOAD_FORMAT, + G_AMDGPU_TBUFFER_LOAD_FORMAT}, + StandardB) .Div(B32, {{VgprB32}, {SgprV4S32_WF, Vgpr32, Vgpr32, Sgpr32_WF}}) .Uni(B32, {{UniInVgprB32}, {SgprV4S32_WF, Vgpr32, Vgpr32, Sgpr32_WF}}) .Div(B64, {{VgprB64}, {SgprV4S32_WF, Vgpr32, Vgpr32, Sgpr32_WF}}) diff --git a/llvm/lib/Target/AMDGPU/DSInstructions.td b/llvm/lib/Target/AMDGPU/DSInstructions.td index f2e432f..b2ff5a1 100644 --- a/llvm/lib/Target/AMDGPU/DSInstructions.td +++ b/llvm/lib/Target/AMDGPU/DSInstructions.td @@ -969,10 +969,9 @@ multiclass DSReadPat_t16<DS_Pseudo inst, ValueType vt, string frag> { } let OtherPredicates = [NotLDSRequiresM0Init] in { - foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in - let True16Predicate = p in { - def : DSReadPat<!cast<DS_Pseudo>(!cast<string>(inst)#"_gfx9"), vt, !cast<PatFrag>(frag)>; - } + let True16Predicate = NotUseRealTrue16Insts in { + def : DSReadPat<!cast<DS_Pseudo>(!cast<string>(inst)#"_gfx9"), vt, !cast<PatFrag>(frag)>; + } let True16Predicate = UseRealTrue16Insts in { def : DSReadPat<!cast<DS_Pseudo>(!cast<string>(inst)#"_t16"), vt, !cast<PatFrag>(frag)>; } @@ -1050,10 +1049,9 @@ multiclass DSWritePat_t16 <DS_Pseudo inst, ValueType vt, string frag> { } let OtherPredicates = [NotLDSRequiresM0Init] in { - foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in - let True16Predicate = p in { - def : DSWritePat<!cast<DS_Pseudo>(!cast<string>(inst)#"_gfx9"), vt, !cast<PatFrag>(frag)>; - } + let True16Predicate = NotUseRealTrue16Insts in { + def : DSWritePat<!cast<DS_Pseudo>(!cast<string>(inst)#"_gfx9"), vt, !cast<PatFrag>(frag)>; + } let True16Predicate = UseRealTrue16Insts in { def : DSWritePat<!cast<DS_Pseudo>(!cast<string>(inst)#"_t16"), vt, !cast<PatFrag>(frag)>; } diff --git a/llvm/lib/Target/AMDGPU/FLATInstructions.td b/llvm/lib/Target/AMDGPU/FLATInstructions.td index 9f33bac..5a22b23 100644 --- a/llvm/lib/Target/AMDGPU/FLATInstructions.td +++ b/llvm/lib/Target/AMDGPU/FLATInstructions.td @@ -1982,8
+1982,7 @@ defm : FlatLoadPats <FLAT_LOAD_SSHORT, sextloadi16_flat, i32>; defm : FlatLoadPats <FLAT_LOAD_SSHORT, atomic_load_sext_16_flat, i32>; defm : FlatLoadPats <FLAT_LOAD_DWORDX3, load_flat, v3i32>; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { defm : FlatLoadPats <FLAT_LOAD_UBYTE, extloadi8_flat, i16>; defm : FlatLoadPats <FLAT_LOAD_UBYTE, zextloadi8_flat, i16>; defm : FlatLoadPats <FLAT_LOAD_SBYTE, sextloadi8_flat, i16>; @@ -2127,8 +2126,7 @@ defm : GlobalFLATLoadPats <GLOBAL_LOAD_USHORT, extloadi16_global, i32>; defm : GlobalFLATLoadPats <GLOBAL_LOAD_USHORT, zextloadi16_global, i32>; defm : GlobalFLATLoadPats <GLOBAL_LOAD_SSHORT, sextloadi16_global, i32>; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { defm : GlobalFLATLoadPats <GLOBAL_LOAD_UBYTE, extloadi8_global, i16>; defm : GlobalFLATLoadPats <GLOBAL_LOAD_UBYTE, zextloadi8_global, i16>; defm : GlobalFLATLoadPats <GLOBAL_LOAD_SBYTE, sextloadi8_global, i16>; @@ -2187,8 +2185,7 @@ defm : GlobalFLATStorePats <GLOBAL_STORE_BYTE, truncstorei8_global, i32>; defm : GlobalFLATStorePats <GLOBAL_STORE_SHORT, truncstorei16_global, i32>; defm : GlobalFLATStorePats <GLOBAL_STORE_DWORDX3, store_global, v3i32>; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let OtherPredicates = [HasFlatGlobalInsts], True16Predicate = p in { +let OtherPredicates = [HasFlatGlobalInsts], True16Predicate = NotUseRealTrue16Insts in { defm : GlobalFLATStorePats <GLOBAL_STORE_BYTE, truncstorei8_global, i16>; defm : GlobalFLATStorePats <GLOBAL_STORE_SHORT, store_global, i16>; defm : GlobalFLATStorePats <GLOBAL_STORE_BYTE, atomic_store_8_global, i16>; @@ -2356,8 +2353,7 @@ defm : ScratchFLATLoadPats <SCRATCH_LOAD_USHORT, extloadi16_private, i32>; defm : ScratchFLATLoadPats <SCRATCH_LOAD_USHORT, zextloadi16_private, i32>; defm : ScratchFLATLoadPats <SCRATCH_LOAD_SSHORT, sextloadi16_private, i32>; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { defm : ScratchFLATLoadPats <SCRATCH_LOAD_UBYTE, extloadi8_private, i16>; defm : ScratchFLATLoadPats <SCRATCH_LOAD_UBYTE, zextloadi8_private, i16>; defm : ScratchFLATLoadPats <SCRATCH_LOAD_SBYTE, sextloadi8_private, i16>; diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.h b/llvm/lib/Target/AMDGPU/SIInstrInfo.h index 31a2d55..c2252af 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.h +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.h @@ -1006,9 +1006,8 @@ public: Opcode == AMDGPU::S_BARRIER_INIT_M0 || Opcode == AMDGPU::S_BARRIER_INIT_IMM || Opcode == AMDGPU::S_BARRIER_JOIN_IMM || - Opcode == AMDGPU::S_BARRIER_LEAVE || - Opcode == AMDGPU::S_BARRIER_LEAVE_IMM || - Opcode == AMDGPU::DS_GWS_INIT || Opcode == AMDGPU::DS_GWS_BARRIER; + Opcode == AMDGPU::S_BARRIER_LEAVE || Opcode == AMDGPU::DS_GWS_INIT || + Opcode == AMDGPU::DS_GWS_BARRIER; } static bool isF16PseudoScalarTrans(unsigned Opcode) { diff --git a/llvm/lib/Target/AMDGPU/SIInstructions.td b/llvm/lib/Target/AMDGPU/SIInstructions.td index 59fd2f1..be084a9 100644 --- a/llvm/lib/Target/AMDGPU/SIInstructions.td +++ b/llvm/lib/Target/AMDGPU/SIInstructions.td @@ -1466,8 +1466,7 @@ class VOPSelectPat_t16 <ValueType vt> : GCNPat < def : VOPSelectModsPat <i32>; def : VOPSelectModsPat <f32>; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = 
NotUseRealTrue16Insts in { def : VOPSelectPat <f16>; def : VOPSelectPat <i16>; } // End True16Predicate = p @@ -2137,8 +2136,7 @@ def : GCNPat < >; foreach fp16vt = [f16, bf16] in { -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let SubtargetPredicate = p in { +let SubtargetPredicate = NotUseRealTrue16Insts in { def : GCNPat < (fabs (fp16vt VGPR_32:$src)), (V_AND_B32_e64 (S_MOV_B32 (i32 0x00007fff)), VGPR_32:$src) @@ -2230,8 +2228,7 @@ def : GCNPat < } foreach fp16vt = [f16, bf16] in { -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : GCNPat < (fcopysign fp16vt:$src0, fp16vt:$src1), (V_BFI_B32_e64 (S_MOV_B32 (i32 0x00007fff)), $src0, $src1) @@ -2354,23 +2351,21 @@ def : GCNPat < (S_MOV_B32 $ga) >; -foreach pred = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in { - let True16Predicate = pred in { - def : GCNPat < - (VGPRImm<(i16 imm)>:$imm), - (V_MOV_B32_e32 imm:$imm) - >; +let True16Predicate = NotUseRealTrue16Insts in { + def : GCNPat < + (VGPRImm<(i16 imm)>:$imm), + (V_MOV_B32_e32 imm:$imm) + >; - // FIXME: Workaround for ordering issue with peephole optimizer where - // a register class copy interferes with immediate folding. Should - // use s_mov_b32, which can be shrunk to s_movk_i32 + // FIXME: Workaround for ordering issue with peephole optimizer where + // a register class copy interferes with immediate folding. Should + // use s_mov_b32, which can be shrunk to s_movk_i32 - foreach vt = [f16, bf16] in { - def : GCNPat < - (VGPRImm<(vt fpimm)>:$imm), - (V_MOV_B32_e32 (vt (bitcast_fpimm_to_i32 $imm))) - >; - } + foreach vt = [f16, bf16] in { + def : GCNPat < + (VGPRImm<(vt fpimm)>:$imm), + (V_MOV_B32_e32 (vt (bitcast_fpimm_to_i32 $imm))) + >; } } @@ -2859,8 +2854,7 @@ def : GCNPat< (i32 (DivergentSextInreg<i1> i32:$src)), (V_BFE_I32_e64 i32:$src, (i32 0), (i32 1))>; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : GCNPat < (i16 (DivergentSextInreg<i1> i16:$src)), (V_BFE_I32_e64 $src, (i32 0), (i32 1)) @@ -3205,8 +3199,7 @@ def : GCNPat< } } // AddedComplexity = 1 -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : GCNPat< (i32 (DivergentUnaryFrag<zext> i16:$src)), (V_AND_B32_e64 (S_MOV_B32 (i32 0xffff)), $src) @@ -3416,8 +3409,7 @@ def : GCNPat < // Magic number: 1 | (0 << 8) | (12 << 16) | (12 << 24) // The 12s emit 0s. 
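// For example, $a = 0x00001234 yields 0x00003412: selector byte 1 picks
// byte 1 of $a for dest byte 0, selector 0 picks byte 0 of $a for dest
// byte 1, and the two selector-12 bytes zero the upper half.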
-foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : GCNPat < (i16 (bswap i16:$a)), (V_PERM_B32_e64 (i32 0), VSrc_b32:$a, (S_MOV_B32 (i32 0x0c0c0001))) @@ -3670,8 +3662,7 @@ def : GCNPat < (S_LSHL_B32 SReg_32:$src1, (i16 16)) >; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : GCNPat < (v2i16 (DivergentBinFrag<build_vector> (i16 0), (i16 VGPR_32:$src1))), (v2i16 (V_LSHLREV_B32_e64 (i16 16), VGPR_32:$src1)) @@ -3707,8 +3698,7 @@ def : GCNPat < (COPY_TO_REGCLASS SReg_32:$src0, SReg_32) >; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : GCNPat < (vecTy (DivergentBinFrag<build_vector> (Ty VGPR_32:$src0), (Ty undef))), (COPY_TO_REGCLASS VGPR_32:$src0, VGPR_32) @@ -3735,8 +3725,7 @@ def : GCNPat < >; let SubtargetPredicate = HasVOP3PInsts in { -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in +let True16Predicate = NotUseRealTrue16Insts in def : GCNPat < (v2i16 (DivergentBinFrag<build_vector> (i16 VGPR_32:$src0), (i16 VGPR_32:$src1))), (v2i16 (V_LSHL_OR_B32_e64 $src1, (i32 16), (i32 (V_AND_B32_e64 (i32 (V_MOV_B32_e32 (i32 0xffff))), $src0)))) @@ -3766,8 +3755,7 @@ def : GCNPat < (S_PACK_LL_B32_B16 SReg_32:$src0, SReg_32:$src1) >; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { // Take the lower 16 bits from each VGPR_32 and concat them def : GCNPat < (vecTy (DivergentBinFrag<build_vector> (Ty VGPR_32:$a), (Ty VGPR_32:$b))), @@ -3838,8 +3826,7 @@ def : GCNPat < >; // Take the upper 16 bits from each VGPR_32 and concat them -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in +let True16Predicate = NotUseRealTrue16Insts in def : GCNPat < (vecTy (DivergentBinFrag<build_vector> (Ty !if(!eq(Ty, i16), @@ -3881,8 +3868,7 @@ def : GCNPat < (v2i16 (S_PACK_HL_B32_B16 SReg_32:$src0, SReg_32:$src1)) >; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : GCNPat < (v2f16 (scalar_to_vector f16:$src0)), (COPY $src0) diff --git a/llvm/lib/Target/AMDGPU/SOPInstructions.td b/llvm/lib/Target/AMDGPU/SOPInstructions.td index 296ce5a..b3fd8c7 100644 --- a/llvm/lib/Target/AMDGPU/SOPInstructions.td +++ b/llvm/lib/Target/AMDGPU/SOPInstructions.td @@ -1616,7 +1616,8 @@ def S_BARRIER_WAIT : SOPP_Pseudo <"s_barrier_wait", (ins i16imm:$simm16), "$simm let isConvergent = 1; } -def S_BARRIER_LEAVE : SOPP_Pseudo <"s_barrier_leave", (ins)> { + def S_BARRIER_LEAVE : SOPP_Pseudo <"s_barrier_leave", + (ins), "", [(int_amdgcn_s_barrier_leave (i16 srcvalue))] > { let SchedRW = [WriteBarrier]; let simm16 = 0; let fixed_imm = 1; @@ -1624,9 +1625,6 @@ def S_BARRIER_LEAVE : SOPP_Pseudo <"s_barrier_leave", (ins)> { let Defs = [SCC]; } -def S_BARRIER_LEAVE_IMM : SOPP_Pseudo <"s_barrier_leave", - (ins i16imm:$simm16), "$simm16", [(int_amdgcn_s_barrier_leave timm:$simm16)]>; - def S_WAKEUP : SOPP_Pseudo <"s_wakeup", (ins) > { let SubtargetPredicate = isGFX8Plus; let simm16 = 0; diff --git a/llvm/lib/Target/AMDGPU/VOP1Instructions.td b/llvm/lib/Target/AMDGPU/VOP1Instructions.td index 6230c17..77df721 100644 --- a/llvm/lib/Target/AMDGPU/VOP1Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP1Instructions.td @@ 
-1561,8 +1561,7 @@ def : GCNPat < } // End OtherPredicates = [isGFX8Plus] -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let OtherPredicates = [isGFX8Plus, p] in { +let OtherPredicates = [isGFX8Plus, NotUseRealTrue16Insts] in { def : GCNPat< (i32 (anyext i16:$src)), (COPY $src) diff --git a/llvm/lib/Target/AMDGPU/VOP2Instructions.td b/llvm/lib/Target/AMDGPU/VOP2Instructions.td index 37d92bc..30dab55 100644 --- a/llvm/lib/Target/AMDGPU/VOP2Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP2Instructions.td @@ -1378,8 +1378,7 @@ class ZExt_i16_i1_Pat <SDNode ext> : GCNPat < $src) >; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in { +let True16Predicate = NotUseRealTrue16Insts in { def : GCNPat < (and i16:$src0, i16:$src1), (V_AND_B32_e64 VSrc_b32:$src0, VSrc_b32:$src1) diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td index e6a7c35..4a2b54d 100644 --- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td @@ -387,8 +387,7 @@ let SchedRW = [Write64Bit] in { } // End SchedRW = [Write64Bit] } // End isReMaterializable = 1 -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in +let True16Predicate = NotUseRealTrue16Insts in def : GCNPat< (i32 (DivergentUnaryFrag<sext> i16:$src)), (i32 (V_BFE_I32_e64 i16:$src, (i32 0), (i32 0x10))) @@ -501,8 +500,7 @@ def V_INTERP_P1LV_F16 : VOP3Interp <"v_interp_p1lv_f16", VOP3_INTERP16<[f32, f32 } // End SubtargetPredicate = Has16BitInsts, isCommutable = 1 -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in +let True16Predicate = NotUseRealTrue16Insts in def : GCNPat< (i64 (DivergentUnaryFrag<sext> i16:$src)), (REG_SEQUENCE VReg_64, diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td index 52ee1e8..5daf860 100644 --- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td @@ -402,8 +402,7 @@ defm V_FMA_MIX_F16_t16 : VOP3_VOP3PInst_t16<"v_fma_mix_f16_t16", VOP3P_Mix_Profi defm : MadFmaMixFP32Pats<fma, V_FMA_MIX_F32>; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in +let True16Predicate = NotUseRealTrue16Insts in defm : MadFmaMixFP16Pats<fma, V_FMA_MIXLO_F16, V_FMA_MIXHI_F16>; let True16Predicate = UseRealTrue16Insts in defm : MadFmaMixFP16Pats_t16<fma, V_FMA_MIX_F16_t16>; @@ -428,8 +427,7 @@ defm V_FMA_MIX_BF16_t16 : VOP3_VOP3PInst_t16<"v_fma_mix_bf16_t16", VOP3P_Mix_Pro } // End isCommutable = 1 defm : MadFmaMixFP32Pats<fma, V_FMA_MIX_F32_BF16, bf16>; -foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in -let True16Predicate = p in +let True16Predicate = NotUseRealTrue16Insts in defm : MadFmaMixFP16Pats<fma, V_FMA_MIXLO_BF16, V_FMA_MIXHI_BF16, bf16, v2bf16>; let True16Predicate = UseRealTrue16Insts in defm : MadFmaMixFP16Pats_t16<fma, V_FMA_MIX_BF16_t16>; diff --git a/llvm/lib/Target/ARM/ARMBaseRegisterInfo.cpp b/llvm/lib/Target/ARM/ARMBaseRegisterInfo.cpp index e94220a..2e8a676 100644 --- a/llvm/lib/Target/ARM/ARMBaseRegisterInfo.cpp +++ b/llvm/lib/Target/ARM/ARMBaseRegisterInfo.cpp @@ -960,17 +960,3 @@ bool ARMBaseRegisterInfo::shouldCoalesce(MachineInstr *MI, } return false; } - -bool ARMBaseRegisterInfo::shouldRewriteCopySrc(const TargetRegisterClass *DefRC, - unsigned DefSubReg, - const TargetRegisterClass *SrcRC, - unsigned SrcSubReg) const { - // We can't extract an SPR from an arbitary DPR (as opposed to a DPR_VFP2). 
- if (DefRC == &ARM::SPRRegClass && DefSubReg == 0 && - SrcRC == &ARM::DPRRegClass && - (SrcSubReg == ARM::ssub_0 || SrcSubReg == ARM::ssub_1)) - return false; - - return TargetRegisterInfo::shouldRewriteCopySrc(DefRC, DefSubReg, - SrcRC, SrcSubReg); -} diff --git a/llvm/lib/Target/ARM/ARMBaseRegisterInfo.h b/llvm/lib/Target/ARM/ARMBaseRegisterInfo.h index 5b67b34..03b0fa0 100644 --- a/llvm/lib/Target/ARM/ARMBaseRegisterInfo.h +++ b/llvm/lib/Target/ARM/ARMBaseRegisterInfo.h @@ -158,11 +158,6 @@ public: const TargetRegisterClass *NewRC, LiveIntervals &LIS) const override; - bool shouldRewriteCopySrc(const TargetRegisterClass *DefRC, - unsigned DefSubReg, - const TargetRegisterClass *SrcRC, - unsigned SrcSubReg) const override; - int getSEHRegNum(unsigned i) const { return getEncodingValue(i); } }; diff --git a/llvm/lib/Target/ARM/ARMSubtarget.cpp b/llvm/lib/Target/ARM/ARMSubtarget.cpp index 3329bea..58bc338 100644 --- a/llvm/lib/Target/ARM/ARMSubtarget.cpp +++ b/llvm/lib/Target/ARM/ARMSubtarget.cpp @@ -225,7 +225,11 @@ void ARMSubtarget::initSubtargetFeatures(StringRef CPU, StringRef FS) { (isTargetDarwin() || DM == DenormalMode::getPreserveSign())) HasNEONForFP = true; - if (isRWPI()) + const ARM::ArchKind Arch = ARM::parseArch(TargetTriple.getArchName()); + if (isRWPI() || + (isTargetIOS() && + (Arch == ARM::ArchKind::ARMV6K || Arch == ARM::ArchKind::ARMV6) && + TargetTriple.isOSVersionLT(3, 0))) ReserveR9 = true; // If MVEVectorCostFactor is still 0 (has not been set to anything else), default it to 2 diff --git a/llvm/lib/Target/LoongArch/AsmParser/LoongArchAsmParser.cpp b/llvm/lib/Target/LoongArch/AsmParser/LoongArchAsmParser.cpp index 5be4713..9b11201 100644 --- a/llvm/lib/Target/LoongArch/AsmParser/LoongArchAsmParser.cpp +++ b/llvm/lib/Target/LoongArch/AsmParser/LoongArchAsmParser.cpp @@ -957,8 +957,10 @@ void LoongArchAsmParser::emitLoadAddressAbs(MCInst &Inst, SMLoc IDLoc, : Inst.getOperand(2).getExpr(); InstSeq Insts; + // To distinguish between la.abs and %abs_hi20, la.abs will generate + // R_LARCH_MARK_LA and R_LARCH_ABS_HI20 relocations. Insts.push_back( - LoongArchAsmParser::Inst(LoongArch::LU12I_W, ELF::R_LARCH_ABS_HI20)); + LoongArchAsmParser::Inst(LoongArch::LU12I_W, ELF::R_LARCH_MARK_LA)); Insts.push_back( LoongArchAsmParser::Inst(LoongArch::ORI, ELF::R_LARCH_ABS_LO12)); diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp index 098bcfa..4cfbfca 100644 --- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp @@ -2319,6 +2319,53 @@ static SDValue lowerVECTOR_SHUFFLE_XVPICKOD(const SDLoc &DL, ArrayRef<int> Mask, return DAG.getNode(LoongArchISD::VPICKOD, DL, VT, V2, V1); } +/// Lower VECTOR_SHUFFLE into XVINSVE0 (if possible). +static SDValue +lowerVECTOR_SHUFFLE_XVINSVE0(const SDLoc &DL, ArrayRef<int> Mask, MVT VT, + SDValue V1, SDValue V2, SelectionDAG &DAG, + const LoongArchSubtarget &Subtarget) { + // LoongArch LASX only supports xvinsve0.{w/d}. + if (VT != MVT::v8i32 && VT != MVT::v8f32 && VT != MVT::v4i64 && + VT != MVT::v4f64) + return SDValue(); + + MVT GRLenVT = Subtarget.getGRLenVT(); + int MaskSize = Mask.size(); + assert(MaskSize == (int)VT.getVectorNumElements() && "Unexpected mask size"); + + // Check if exactly one element of the Mask is replaced by 'Replaced', while + // all other elements are either 'Base + i' or undef (-1). On success, return + // the index of the replaced element. Otherwise, just return -1. 
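+ // For example, with MaskSize == 8, Mask = {8, 1, 2, 3, 4, 5, 6, 7} makes
+ // checkReplaceOne(0, 8) return 0: element 0 of V1 is replaced by the
+ // lowest element of V2, and every other lane keeps its V1 element.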
+ auto checkReplaceOne = [&](int Base, int Replaced) -> int { + int Idx = -1; + for (int i = 0; i < MaskSize; ++i) { + if (Mask[i] == Base + i || Mask[i] == -1) + continue; + if (Mask[i] != Replaced) + return -1; + if (Idx == -1) + Idx = i; + else + return -1; + } + return Idx; + }; + + // Case 1: the lowest element of V2 replaces one element in V1. + int Idx = checkReplaceOne(0, MaskSize); + if (Idx != -1) + return DAG.getNode(LoongArchISD::XVINSVE0, DL, VT, V1, V2, + DAG.getConstant(Idx, DL, GRLenVT)); + + // Case 2: the lowest element of V1 replaces one element in V2. + Idx = checkReplaceOne(MaskSize, 0); + if (Idx != -1) + return DAG.getNode(LoongArchISD::XVINSVE0, DL, VT, V2, V1, + DAG.getConstant(Idx, DL, GRLenVT)); + + return SDValue(); +} + /// Lower VECTOR_SHUFFLE into XVSHUF (if possible). static SDValue lowerVECTOR_SHUFFLE_XVSHUF(const SDLoc &DL, ArrayRef<int> Mask, MVT VT, SDValue V1, SDValue V2, @@ -2595,6 +2642,9 @@ static SDValue lower256BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT, if ((Result = lowerVECTOR_SHUFFLEAsShift(DL, Mask, VT, V1, V2, DAG, Subtarget, Zeroable))) return Result; + if ((Result = + lowerVECTOR_SHUFFLE_XVINSVE0(DL, Mask, VT, V1, V2, DAG, Subtarget))) + return Result; if ((Result = lowerVECTOR_SHUFFLEAsByteRotate(DL, Mask, VT, V1, V2, DAG, Subtarget))) return Result; @@ -7453,6 +7503,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(XVPERM) NODE_NAME_CASE(XVREPLVE0) NODE_NAME_CASE(XVREPLVE0Q) + NODE_NAME_CASE(XVINSVE0) NODE_NAME_CASE(VPICK_SEXT_ELT) NODE_NAME_CASE(VPICK_ZEXT_ELT) NODE_NAME_CASE(VREPLVE) diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h index 9b60a9f..8a4d774 100644 --- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h @@ -151,6 +151,7 @@ enum NodeType : unsigned { XVPERM, XVREPLVE0, XVREPLVE0Q, + XVINSVE0, // Extended vector element extraction VPICK_SEXT_ELT, diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td index bbc0489..5143d53 100644 --- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td +++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td @@ -20,6 +20,7 @@ def loongarch_xvpermi: SDNode<"LoongArchISD::XVPERMI", SDT_LoongArchV1RUimm>; def loongarch_xvperm: SDNode<"LoongArchISD::XVPERM", SDT_LoongArchXVPERM>; def loongarch_xvreplve0: SDNode<"LoongArchISD::XVREPLVE0", SDT_LoongArchXVREPLVE0>; def loongarch_xvreplve0q: SDNode<"LoongArchISD::XVREPLVE0Q", SDT_LoongArchXVREPLVE0>; +def loongarch_xvinsve0 : SDNode<"LoongArchISD::XVINSVE0", SDT_LoongArchV2RUimm>; def loongarch_xvmskltz: SDNode<"LoongArchISD::XVMSKLTZ", SDT_LoongArchVMSKCOND>; def loongarch_xvmskgez: SDNode<"LoongArchISD::XVMSKGEZ", SDT_LoongArchVMSKCOND>; def loongarch_xvmskeqz: SDNode<"LoongArchISD::XVMSKEQZ", SDT_LoongArchVMSKCOND>; @@ -1708,6 +1709,14 @@ def : Pat<(vector_insert v4f64:$xd, (f64(bitconvert i64:$rj)), uimm2:$imm), (XVINSGR2VR_D v4f64:$xd, GPR:$rj, uimm2:$imm)>; // XVINSVE0_{W/D} +def : Pat<(loongarch_xvinsve0 v8i32:$xd, v8i32:$xj, uimm3:$imm), + (XVINSVE0_W v8i32:$xd, v8i32:$xj, uimm3:$imm)>; +def : Pat<(loongarch_xvinsve0 v4i64:$xd, v4i64:$xj, uimm2:$imm), + (XVINSVE0_D v4i64:$xd, v4i64:$xj, uimm2:$imm)>; +def : Pat<(loongarch_xvinsve0 v8f32:$xd, v8f32:$xj, uimm3:$imm), + (XVINSVE0_W v8f32:$xd, v8f32:$xj, uimm3:$imm)>; +def : Pat<(loongarch_xvinsve0 v4f64:$xd, v4f64:$xj, uimm2:$imm), + (XVINSVE0_D 
v4f64:$xd, v4f64:$xj, uimm2:$imm)>; def : Pat<(vector_insert v8f32:$xd, FPR32:$fj, uimm3:$imm), (XVINSVE0_W v8f32:$xd, (SUBREG_TO_REG(i64 0), FPR32:$fj, sub_32), uimm3:$imm)>; diff --git a/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchMCAsmInfo.cpp b/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchMCAsmInfo.cpp index 0d77617..8ecb62d 100644 --- a/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchMCAsmInfo.cpp +++ b/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchMCAsmInfo.cpp @@ -32,6 +32,7 @@ static StringRef getLoongArchSpecifierName(uint16_t S) { return "b16"; case ELF::R_LARCH_B21: return "b21"; + case ELF::R_LARCH_MARK_LA: case ELF::R_LARCH_ABS_HI20: return "abs_hi20"; case ELF::R_LARCH_ABS_LO12: diff --git a/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchMCCodeEmitter.cpp b/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchMCCodeEmitter.cpp index b7ead5e..f0e2bc4 100644 --- a/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchMCCodeEmitter.cpp +++ b/llvm/lib/Target/LoongArch/MCTargetDesc/LoongArchMCCodeEmitter.cpp @@ -161,6 +161,13 @@ LoongArchMCCodeEmitter::getExprOpValue(const MCInst &MI, const MCOperand &MO, case ELF::R_LARCH_B26: FixupKind = LoongArch::fixup_loongarch_b26; break; + case ELF::R_LARCH_MARK_LA: + // Match gas behavior: generate `R_LARCH_MARK_LA` relocation when using + // `la.abs`. + Fixups.push_back( + MCFixup::create(0, MCConstantExpr::create(0, Ctx), + FirstLiteralRelocationKind + ELF::R_LARCH_MARK_LA)); + [[fallthrough]]; case ELF::R_LARCH_ABS_HI20: FixupKind = LoongArch::fixup_loongarch_abs_hi20; break; diff --git a/llvm/lib/Target/PowerPC/PPCInstrFuture.td b/llvm/lib/Target/PowerPC/PPCInstrFuture.td index c3ab965..1aefea1 100644 --- a/llvm/lib/Target/PowerPC/PPCInstrFuture.td +++ b/llvm/lib/Target/PowerPC/PPCInstrFuture.td @@ -182,10 +182,113 @@ class XX3Form_XTAB6<bits<6> opcode, bits<8> xo, dag OOL, dag IOL, string asmstr, let Inst{31} = XT{5}; } +class XX3Form_XTAB6_S<bits<5> xo, dag OOL, dag IOL, string asmstr, + list<dag> pattern> + : I<59, OOL, IOL, asmstr, NoItinerary> { + bits<6> XT; + bits<6> XA; + bits<6> XB; + + let Pattern = pattern; + + let Inst{6...10} = XT{4...0}; + let Inst{11...15} = XA{4...0}; + let Inst{16...20} = XB{4...0}; + let Inst{24...28} = xo; + let Inst{29} = XA{5}; + let Inst{30} = XB{5}; + let Inst{31} = XT{5}; +} + +class XX3Form_XTAB6_S3<bits<5> xo, dag OOL, dag IOL, string asmstr, + list<dag> pattern> + : XX3Form_XTAB6_S<xo, OOL, IOL, asmstr, pattern> { + + bits<3> S; + let Inst{21...23} = S; +} + +class XX3Form_XTAB6_3S1<bits<5> xo, dag OOL, dag IOL, string asmstr, + list<dag> pattern> + : XX3Form_XTAB6_S<xo, OOL, IOL, asmstr, pattern> { + + bits<1> S0; + bits<1> S1; + bits<1> S2; + + let Inst{21} = S0; + let Inst{22} = S1; + let Inst{23} = S2; +} + +class XX3Form_XTAB6_2S1<bits<5> xo, dag OOL, dag IOL, string asmstr, + list<dag> pattern> + : XX3Form_XTAB6_S<xo, OOL, IOL, asmstr, pattern> { + + bits<1> S1; + bits<1> S2; + + let Inst{21} = 0; + let Inst{22} = S1; + let Inst{23} = S2; +} + +class XX3Form_XTAB6_P<bits<7> xo, dag OOL, dag IOL, string asmstr, + list<dag> pattern> + : I<59, OOL, IOL, asmstr, NoItinerary> { + + bits<6> XT; + bits<6> XA; + bits<6> XB; + bits<1> P; + + let Pattern = pattern; + + let Inst{6...10} = XT{4...0}; + let Inst{11...15} = XA{4...0}; + let Inst{16...20} = XB{4...0}; + let Inst{21} = P; + let Inst{22...28} = xo; + let Inst{29} = XA{5}; + let Inst{30} = XB{5}; + let Inst{31} = XT{5}; +} + +// Prefix instruction classes. 
+ +class 8RR_XX4Form_XTABC6_P<bits<6> opcode, dag OOL, dag IOL, string asmstr, + InstrItinClass itin, list<dag> pattern> + : PI<1, opcode, OOL, IOL, asmstr, itin> { + bits<6> XT; + bits<6> XA; + bits<6> XB; + bits<6> XC; + bits<1> P; + + let Pattern = pattern; + + // The prefix. + let Inst{6...7} = 1; + let Inst{8...11} = 0; + + // The instruction. + let Inst{38...42} = XT{4...0}; + let Inst{43...47} = XA{4...0}; + let Inst{48...52} = XB{4...0}; + let Inst{53...57} = XC{4...0}; + let Inst{58} = 1; + let Inst{59} = P; + let Inst{60} = XC{5}; + let Inst{61} = XA{5}; + let Inst{62} = XB{5}; + let Inst{63} = XT{5}; +} + //-------------------------- Instruction definitions -------------------------// // Predicate combinations available: // [IsISAFuture] // [HasVSX, IsISAFuture] +// [HasVSX, PrefixInstrs, IsISAFuture] let Predicates = [IsISAFuture] in { defm SUBFUS : XOForm_RTAB5_L1r<31, 72, (outs g8rc:$RT), @@ -294,6 +397,78 @@ let Predicates = [HasVSX, IsISAFuture] in { "xvmulhuw $XT, $XA, $XB", []>; def XVMULHUH: XX3Form_XTAB6<60, 122, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), "xvmulhuh $XT, $XA, $XB", []>; + + // Elliptic Curve Cryptography Acceleration Instructions. + def XXMULMUL + : XX3Form_XTAB6_S3<1, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB, u3imm:$S), + "xxmulmul $XT, $XA, $XB, $S", []>; + def XXMULMULHIADD + : XX3Form_XTAB6_3S1<9, (outs vsrc:$XT), + (ins vsrc:$XA, vsrc:$XB, u1imm:$S0, u1imm:$S1, + u1imm:$S2), + "xxmulmulhiadd $XT, $XA, $XB, $S0, $S1, $S2", []>; + def XXMULMULLOADD + : XX3Form_XTAB6_2S1<17, (outs vsrc:$XT), + (ins vsrc:$XA, vsrc:$XB, u1imm:$S1, u1imm:$S2), + "xxmulmulloadd $XT, $XA, $XB, $S1, $S2", []>; + def XXSSUMUDM + : XX3Form_XTAB6_P<25, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB, u1imm:$P), + "xxssumudm $XT, $XA, $XB, $P", []>; + def XXSSUMUDMC + : XX3Form_XTAB6_P<57, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB, u1imm:$P), + "xxssumudmc $XT, $XA, $XB, $P", []>; + def XSADDADDUQM + : XX3Form_XTAB6<59, 96, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsaddadduqm $XT, $XA, $XB", []>; + def XSADDADDSUQM + : XX3Form_XTAB6<59, 104, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsaddaddsuqm $XT, $XA, $XB", []>; + def XSADDSUBUQM + : XX3Form_XTAB6<59, 112, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsaddsubuqm $XT, $XA, $XB", []>; + def XSADDSUBSUQM + : XX3Form_XTAB6<59, 224, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsaddsubsuqm $XT, $XA, $XB", []>; + def XSMERGE2T1UQM + : XX3Form_XTAB6<59, 232, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsmerge2t1uqm $XT, $XA, $XB", []>; + def XSMERGE2T2UQM + : XX3Form_XTAB6<59, 240, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsmerge2t2uqm $XT, $XA, $XB", []>; + def XSMERGE2T3UQM + : XX3Form_XTAB6<59, 89, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsmerge2t3uqm $XT, $XA, $XB", []>; + def XSMERGE3T1UQM + : XX3Form_XTAB6<59, 121, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsmerge3t1uqm $XT, $XA, $XB", []>; + def XSREBASE2T1UQM + : XX3Form_XTAB6<59, 145, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsrebase2t1uqm $XT, $XA, $XB", []>; + def XSREBASE2T2UQM + : XX3Form_XTAB6<59, 177, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsrebase2t2uqm $XT, $XA, $XB", []>; + def XSREBASE2T3UQM + : XX3Form_XTAB6<59, 209, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsrebase2t3uqm $XT, $XA, $XB", []>; + def XSREBASE2T4UQM + : XX3Form_XTAB6<59, 217, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsrebase2t4uqm $XT, $XA, $XB", []>; + def XSREBASE3T1UQM + : XX3Form_XTAB6<59, 241, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + 
"xsrebase3t1uqm $XT, $XA, $XB", []>; + def XSREBASE3T2UQM + : XX3Form_XTAB6<59, 249, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsrebase3t2uqm $XT, $XA, $XB", []>; + def XSREBASE3T3UQM + : XX3Form_XTAB6<59, 195, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB), + "xsrebase3t3uqm $XT, $XA, $XB", []>; +} + +let Predicates = [HasVSX, PrefixInstrs, IsISAFuture] in { + def XXSSUMUDMCEXT + : 8RR_XX4Form_XTABC6_P< + 34, (outs vsrc:$XT), (ins vsrc:$XA, vsrc:$XB, vsrc:$XC, u1imm:$P), + "xxssumudmcext $XT, $XA, $XB, $XC, $P", IIC_VecGeneral, []>; } //---------------------------- Anonymous Patterns ----------------------------// diff --git a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp index cb57c43..d4d9e54 100644 --- a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp +++ b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp @@ -193,7 +193,7 @@ bool RISCVExpandPseudo::expandCCOp(MachineBasicBlock &MBB, // we need to invert the branch condition to jump over TrueBB when the // condition is false. auto CC = static_cast<RISCVCC::CondCode>(MI.getOperand(3).getImm()); - CC = RISCVCC::getOppositeBranchCondition(CC); + CC = RISCVCC::getInverseBranchCondition(CC); // Insert branch instruction. BuildMI(MBB, MBBI, DL, TII->get(RISCVCC::getBrCond(CC))) diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index 56db09a..70b6c7e 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -1023,6 +1023,37 @@ static void parseCondBranch(MachineInstr &LastInst, MachineBasicBlock *&Target, Cond.push_back(LastInst.getOperand(1)); } +static unsigned getInverseXqcicmOpcode(unsigned Opcode) { + switch (Opcode) { + default: + llvm_unreachable("Unexpected Opcode"); + case RISCV::QC_MVEQ: + return RISCV::QC_MVNE; + case RISCV::QC_MVNE: + return RISCV::QC_MVEQ; + case RISCV::QC_MVLT: + return RISCV::QC_MVGE; + case RISCV::QC_MVGE: + return RISCV::QC_MVLT; + case RISCV::QC_MVLTU: + return RISCV::QC_MVGEU; + case RISCV::QC_MVGEU: + return RISCV::QC_MVLTU; + case RISCV::QC_MVEQI: + return RISCV::QC_MVNEI; + case RISCV::QC_MVNEI: + return RISCV::QC_MVEQI; + case RISCV::QC_MVLTI: + return RISCV::QC_MVGEI; + case RISCV::QC_MVGEI: + return RISCV::QC_MVLTI; + case RISCV::QC_MVLTUI: + return RISCV::QC_MVGEUI; + case RISCV::QC_MVGEUI: + return RISCV::QC_MVLTUI; + } +} + unsigned RISCVCC::getBrCond(RISCVCC::CondCode CC, unsigned SelectOpc) { switch (SelectOpc) { default: @@ -1134,7 +1165,7 @@ unsigned RISCVCC::getBrCond(RISCVCC::CondCode CC, unsigned SelectOpc) { } } -RISCVCC::CondCode RISCVCC::getOppositeBranchCondition(RISCVCC::CondCode CC) { +RISCVCC::CondCode RISCVCC::getInverseBranchCondition(RISCVCC::CondCode CC) { switch (CC) { default: llvm_unreachable("Unrecognized conditional branch"); @@ -1554,7 +1585,7 @@ bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const { return Register(); }; - unsigned NewOpc = RISCVCC::getBrCond(getOppositeBranchCondition(CC)); + unsigned NewOpc = RISCVCC::getBrCond(getInverseBranchCondition(CC)); // Might be case 1. // Don't change 0 to 1 since we can use x0. @@ -1801,7 +1832,7 @@ RISCVInstrInfo::optimizeSelect(MachineInstr &MI, // Add condition code, inverting if necessary. auto CC = static_cast<RISCVCC::CondCode>(MI.getOperand(3).getImm()); if (Invert) - CC = RISCVCC::getOppositeBranchCondition(CC); + CC = RISCVCC::getInverseBranchCondition(CC); NewMI.addImm(CC); // Copy the false register. 
@@ -3762,6 +3793,19 @@ bool RISCVInstrInfo::findCommutedOpIndices(const MachineInstr &MI, return false; // Operands 1 and 2 are commutable, if we switch the opcode. return fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, 1, 2); + case RISCV::QC_MVEQ: + case RISCV::QC_MVNE: + case RISCV::QC_MVLT: + case RISCV::QC_MVGE: + case RISCV::QC_MVLTU: + case RISCV::QC_MVGEU: + case RISCV::QC_MVEQI: + case RISCV::QC_MVNEI: + case RISCV::QC_MVLTI: + case RISCV::QC_MVGEI: + case RISCV::QC_MVLTUI: + case RISCV::QC_MVGEUI: + return fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, 1, 4); case RISCV::TH_MULA: case RISCV::TH_MULAW: case RISCV::TH_MULAH: @@ -3974,11 +4018,28 @@ MachineInstr *RISCVInstrInfo::commuteInstructionImpl(MachineInstr &MI, return TargetInstrInfo::commuteInstructionImpl(WorkingMI, false, OpIdx1, OpIdx2); } + case RISCV::QC_MVEQ: + case RISCV::QC_MVNE: + case RISCV::QC_MVLT: + case RISCV::QC_MVGE: + case RISCV::QC_MVLTU: + case RISCV::QC_MVGEU: + case RISCV::QC_MVEQI: + case RISCV::QC_MVNEI: + case RISCV::QC_MVLTI: + case RISCV::QC_MVGEI: + case RISCV::QC_MVLTUI: + case RISCV::QC_MVGEUI: { + auto &WorkingMI = cloneIfNew(MI); + WorkingMI.setDesc(get(getInverseXqcicmOpcode(MI.getOpcode()))); + return TargetInstrInfo::commuteInstructionImpl(WorkingMI, false, OpIdx1, + OpIdx2); + } case RISCV::PseudoCCMOVGPRNoX0: case RISCV::PseudoCCMOVGPR: { // CCMOV can be commuted by inverting the condition. auto CC = static_cast<RISCVCC::CondCode>(MI.getOperand(3).getImm()); - CC = RISCVCC::getOppositeBranchCondition(CC); + CC = RISCVCC::getInverseBranchCondition(CC); auto &WorkingMI = cloneIfNew(MI); WorkingMI.getOperand(3).setImm(CC); return TargetInstrInfo::commuteInstructionImpl(WorkingMI, /*NewMI*/ false, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h index 2bc499b..42a0c4c 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h @@ -44,7 +44,7 @@ enum CondCode { COND_INVALID }; -CondCode getOppositeBranchCondition(CondCode); +CondCode getInverseBranchCondition(CondCode); unsigned getBrCond(CondCode CC, unsigned SelectOpc = 0); } // end of namespace RISCVCC diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td index 4eb9a3be..d998316 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td @@ -345,7 +345,7 @@ defset list<VTypeInfo> AllVectors = { } } - defset list<VTypeInfo> AllFloatAndBFloatVectors = { + defset list<VTypeInfo> AllFloatAndBF16Vectors = { defset list<VTypeInfo> AllFloatVectors = { defset list<VTypeInfo> NoGroupFloatVectors = { defset list<VTypeInfo> FractionalGroupFloatVectors = { @@ -382,16 +382,16 @@ defset list<VTypeInfo> AllVectors = { } } - defset list<VTypeInfo> AllBFloatVectors = { - defset list<VTypeInfo> NoGroupBFloatVectors = { - defset list<VTypeInfo> FractionalGroupBFloatVectors = { + defset list<VTypeInfo> AllBF16Vectors = { + defset list<VTypeInfo> NoGroupBF16Vectors = { + defset list<VTypeInfo> FractionalGroupBF16Vectors = { def VBF16MF4: VTypeInfo<vbfloat16mf4_t, vbool64_t, 16, V_MF4, bf16, FPR16>; def VBF16MF2: VTypeInfo<vbfloat16mf2_t, vbool32_t, 16, V_MF2, bf16, FPR16>; } def VBF16M1: VTypeInfo<vbfloat16m1_t, vbool16_t, 16, V_M1, bf16, FPR16>; } - defset list<GroupVTypeInfo> GroupBFloatVectors = { + defset list<GroupVTypeInfo> GroupBF16Vectors = { def VBF16M2: GroupVTypeInfo<vbfloat16m2_t, vbfloat16m1_t, vbool8_t, 16, V_M2, bf16, FPR16>; def VBF16M4: 
GroupVTypeInfo<vbfloat16m4_t, vbfloat16m1_t, vbool4_t, 16, @@ -542,7 +542,7 @@ defset list<VTypeInfoToWide> AllWidenableIntToFloatVectors = { def : VTypeInfoToWide<VI32M4, VF64M8>; } -defset list<VTypeInfoToWide> AllWidenableBFloatToFloatVectors = { +defset list<VTypeInfoToWide> AllWidenableBF16ToFloatVectors = { def : VTypeInfoToWide<VBF16MF4, VF32MF2>; def : VTypeInfoToWide<VBF16MF2, VF32M1>; def : VTypeInfoToWide<VBF16M1, VF32M2>; @@ -5870,7 +5870,7 @@ multiclass VPatConversionWF_VF<string intrinsic, string instruction, multiclass VPatConversionWF_VF_BF<string intrinsic, string instruction, bit isSEWAware = 0> { - foreach fvtiToFWti = AllWidenableBFloatToFloatVectors in + foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in { defvar fvti = fvtiToFWti.Vti; defvar fwti = fvtiToFWti.Wti; @@ -5977,7 +5977,7 @@ multiclass VPatConversionVF_WF_RTZ<string intrinsic, string instruction, multiclass VPatConversionVF_WF_BF_RM<string intrinsic, string instruction, bit isSEWAware = 0> { - foreach fvtiToFWti = AllWidenableBFloatToFloatVectors in { + foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in { defvar fvti = fvtiToFWti.Vti; defvar fwti = fvtiToFWti.Wti; let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates, @@ -7154,7 +7154,7 @@ defm : VPatConversionVI_VF<"int_riscv_vfclass", "PseudoVFCLASS">; // We can use vmerge.vvm to support vector-vector vfmerge. // NOTE: Clang previously used int_riscv_vfmerge for vector-vector, but now uses // int_riscv_vmerge. Support both for compatibility. -foreach vti = AllFloatAndBFloatVectors in { +foreach vti = AllFloatAndBF16Vectors in { let Predicates = GetVTypeMinimalPredicates<vti>.Predicates in defm : VPatBinaryCarryInTAIL<"int_riscv_vmerge", "PseudoVMERGE", "VVM", vti.Vector, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td index dc61361..139ff92 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td @@ -1388,7 +1388,7 @@ defm : VPatFPSetCCSDNode_VV_VF_FV<SETOLE, "PseudoVMFLE", "PseudoVMFGE">; // Floating-point vselects: // 11.15. Vector Integer Merge Instructions // 13.15. Vector Floating-Point Merge Instruction -foreach fvti = AllFloatAndBFloatVectors in { +foreach fvti = AllFloatAndBF16Vectors in { defvar ivti = GetIntVTypeInfo<fvti>.Vti; let Predicates = GetVTypePredicates<ivti>.Predicates in { def : Pat<(fvti.Vector (vselect (fvti.Mask VMV0:$vm), fvti.RegClass:$rs1, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td index 1511f1b..cf904ea 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -2426,7 +2426,7 @@ foreach vti = AllFloatVectors in { // Floating-point vselects: // 11.15. Vector Integer Merge Instructions // 13.15. 
Vector Floating-Point Merge Instruction -foreach fvti = AllFloatAndBFloatVectors in { +foreach fvti = AllFloatAndBF16Vectors in { defvar ivti = GetIntVTypeInfo<fvti>.Vti; let Predicates = GetVTypePredicates<ivti>.Predicates in { def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask VMV0:$vm), @@ -2770,7 +2770,7 @@ foreach vti = NoGroupFloatVectors in { } } -foreach vti = AllFloatAndBFloatVectors in { +foreach vti = AllFloatAndBF16Vectors in { defvar ivti = GetIntVTypeInfo<vti>.Vti; let Predicates = GetVTypePredicates<ivti>.Predicates in { def : Pat<(vti.Vector diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td index 9835c03..b683e89 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXAndes.td @@ -560,7 +560,7 @@ multiclass VPseudoVNCVT_BF16_S { } multiclass VPatConversionS_BF16<string intrinsic, string instruction> { - foreach fvtiToFWti = AllWidenableBFloatToFloatVectors in { + foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in { defvar fvti = fvtiToFWti.Vti; defvar fwti = fvtiToFWti.Wti; let Predicates = [HasVendorXAndesVBFHCvt] in @@ -572,7 +572,7 @@ multiclass VPatConversionS_BF16<string intrinsic, string instruction> { } multiclass VPatConversionBF16_S<string intrinsic, string instruction> { - foreach fvtiToFWti = AllWidenableBFloatToFloatVectors in { + foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in { defvar fvti = fvtiToFWti.Vti; defvar fwti = fvtiToFWti.Wti; let Predicates = [HasVendorXAndesVBFHCvt] in diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td index b546339..557d873 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td @@ -770,7 +770,7 @@ multiclass VPatVQMACCQOQ<string intrinsic, string instruction, string kind> : VPatVMACC<intrinsic, instruction, kind, VQMACCQOQInfoPairs, vint8m1_t>; multiclass VPatVFWMACC<string intrinsic, string instruction, string kind> - : VPatVMACC<intrinsic, instruction, kind, AllWidenableBFloatToFloatVectors, + : VPatVMACC<intrinsic, instruction, kind, AllWidenableBF16ToFloatVectors, vbfloat16m1_t>; defset list<VTypeInfoToWide> VFNRCLIPInfoPairs = { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td index 13b02d1..ff4a040 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXqci.td @@ -604,7 +604,7 @@ class QCILICC<bits<3> funct3, bits<2> funct2, DAGOperand InTyRs2, string opcodes let Inst{31-25} = {simm, funct2}; } -let hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +let hasSideEffects = 0, mayLoad = 0, mayStore = 0, isCommutable = 1 in class QCIMVCC<bits<3> funct3, string opcodestr> : RVInstR4<0b00, funct3, OPC_CUSTOM_2, (outs GPRNoX0:$rd_wb), (ins GPRNoX0:$rd, GPRNoX0:$rs1, GPRNoX0:$rs2, GPRNoX0:$rs3), @@ -612,7 +612,7 @@ class QCIMVCC<bits<3> funct3, string opcodestr> let Constraints = "$rd = $rd_wb"; } -let hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +let hasSideEffects = 0, mayLoad = 0, mayStore = 0, isCommutable = 1 in class QCIMVCCI<bits<3> funct3, string opcodestr, DAGOperand immType> : RVInstR4<0b10, funct3, OPC_CUSTOM_2, (outs GPRNoX0:$rd_wb), (ins GPRNoX0:$rd, GPRNoX0:$rs1, immType:$imm, GPRNoX0:$rs3), diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td index 6d8672b..0be9eab 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td +++ 
b/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td @@ -53,7 +53,7 @@ let Predicates = [HasStdExtZvfbfminOrZvfofp8min] in { defm : VPatConversionVF_WF_BF_RM<"int_riscv_vfncvtbf16_f_f_w", "PseudoVFNCVTBF16_F_F", isSEWAware=1>; - foreach fvtiToFWti = AllWidenableBFloatToFloatVectors in { + foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in { defvar fvti = fvtiToFWti.Vti; defvar fwti = fvtiToFWti.Wti; let Predicates = [HasVInstructionsBF16Minimal] in @@ -91,9 +91,9 @@ let Predicates = [HasStdExtZvfbfminOrZvfofp8min] in { let Predicates = [HasStdExtZvfbfwma] in { defm : VPatTernaryW_VV_VX_RM<"int_riscv_vfwmaccbf16", "PseudoVFWMACCBF16", - AllWidenableBFloatToFloatVectors, isSEWAware=1>; + AllWidenableBF16ToFloatVectors, isSEWAware=1>; defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmadd_vl, "PseudoVFWMACCBF16", - AllWidenableBFloatToFloatVectors>; + AllWidenableBF16ToFloatVectors>; defm : VPatWidenFPMulAccSDNode_VV_VF_RM<"PseudoVFWMACCBF16", - AllWidenableBFloatToFloatVectors>; + AllWidenableBF16ToFloatVectors>; } diff --git a/llvm/lib/Target/RISCV/RISCVProcessors.td b/llvm/lib/Target/RISCV/RISCVProcessors.td index 95f8a87..17a7948 100644 --- a/llvm/lib/Target/RISCV/RISCVProcessors.td +++ b/llvm/lib/Target/RISCV/RISCVProcessors.td @@ -347,16 +347,58 @@ defvar SiFiveP400TuneFeatures = [TuneNoDefaultUnroll, TunePostRAScheduler]; def SIFIVE_P450 : RISCVProcessorModel<"sifive-p450", SiFiveP400Model, - !listconcat(RVA22U64Features, - [FeatureStdExtZifencei, + [Feature64Bit, + FeatureStdExtI, + FeatureStdExtM, + FeatureStdExtA, + FeatureStdExtF, + FeatureStdExtD, + FeatureStdExtC, + FeatureStdExtZicsr, + FeatureStdExtZiccif, + FeatureStdExtZiccrse, + FeatureStdExtZiccamoa, + FeatureStdExtZicclsm, + FeatureStdExtZa64rs, + FeatureStdExtZihpm, + FeatureStdExtZihintpause, + FeatureStdExtB, + FeatureStdExtZic64b, + FeatureStdExtZicbom, + FeatureStdExtZicbop, + FeatureStdExtZicboz, + FeatureStdExtZfhmin, + FeatureStdExtZkt, + FeatureStdExtZifencei, FeatureStdExtZihintntl, FeatureUnalignedScalarMem, - FeatureUnalignedVectorMem]), + FeatureUnalignedVectorMem], SiFiveP400TuneFeatures>; def SIFIVE_P470 : RISCVProcessorModel<"sifive-p470", SiFiveP400Model, - !listconcat(RVA22U64Features, - [FeatureStdExtV, + [Feature64Bit, + FeatureStdExtI, + FeatureStdExtM, + FeatureStdExtA, + FeatureStdExtF, + FeatureStdExtD, + FeatureStdExtC, + FeatureStdExtZicsr, + FeatureStdExtZiccif, + FeatureStdExtZiccrse, + FeatureStdExtZiccamoa, + FeatureStdExtZicclsm, + FeatureStdExtZa64rs, + FeatureStdExtZihpm, + FeatureStdExtZihintpause, + FeatureStdExtB, + FeatureStdExtZic64b, + FeatureStdExtZicbom, + FeatureStdExtZicbop, + FeatureStdExtZicboz, + FeatureStdExtZfhmin, + FeatureStdExtZkt, + FeatureStdExtV, FeatureStdExtZifencei, FeatureStdExtZihintntl, FeatureStdExtZvl128b, @@ -368,7 +410,7 @@ def SIFIVE_P470 : RISCVProcessorModel<"sifive-p470", SiFiveP400Model, FeatureVendorXSiFivecdiscarddlone, FeatureVendorXSiFivecflushdlone, FeatureUnalignedScalarMem, - FeatureUnalignedVectorMem]), + FeatureUnalignedVectorMem], !listconcat(SiFiveP400TuneFeatures, [TuneNoSinkSplatOperands, TuneVXRMPipelineFlush])>; @@ -397,8 +439,29 @@ def SIFIVE_P550 : RISCVProcessorModel<"sifive-p550", SiFiveP500Model, } def SIFIVE_P670 : RISCVProcessorModel<"sifive-p670", SiFiveP600Model, - !listconcat(RVA22U64Features, - [FeatureStdExtV, + [Feature64Bit, + FeatureStdExtI, + FeatureStdExtM, + FeatureStdExtA, + FeatureStdExtF, + FeatureStdExtD, + FeatureStdExtC, + FeatureStdExtZicsr, + FeatureStdExtZiccif, + FeatureStdExtZiccrse, + FeatureStdExtZiccamoa, + 
FeatureStdExtZicclsm, + FeatureStdExtZa64rs, + FeatureStdExtZihpm, + FeatureStdExtZihintpause, + FeatureStdExtB, + FeatureStdExtZic64b, + FeatureStdExtZicbom, + FeatureStdExtZicbop, + FeatureStdExtZicboz, + FeatureStdExtZfhmin, + FeatureStdExtZkt, + FeatureStdExtV, FeatureStdExtZifencei, FeatureStdExtZihintntl, FeatureStdExtZvl128b, @@ -408,7 +471,7 @@ def SIFIVE_P670 : RISCVProcessorModel<"sifive-p670", SiFiveP600Model, FeatureStdExtZvksc, FeatureStdExtZvksg, FeatureUnalignedScalarMem, - FeatureUnalignedVectorMem]), + FeatureUnalignedVectorMem], [TuneNoDefaultUnroll, TuneConditionalCompressedMoveFusion, TuneLUIADDIFusion, diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp index c2a6e51..b765fec 100644 --- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp @@ -81,6 +81,7 @@ public: void outputExecutionMode(const Module &M); void outputAnnotations(const Module &M); void outputModuleSections(); + void outputFPFastMathDefaultInfo(); bool isHidden() { return MF->getFunction() .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME) @@ -498,11 +499,27 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode"); if (Node) { for (unsigned i = 0; i < Node->getNumOperands(); i++) { + // If SPV_KHR_float_controls2 is enabled and we find any of + // FPFastMathDefault, ContractionOff or SignedZeroInfNanPreserve execution + // modes, skip it, it'll be done somewhere else. + if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) { + const auto EM = + cast<ConstantInt>( + cast<ConstantAsMetadata>((Node->getOperand(i))->getOperand(1)) + ->getValue()) + ->getZExtValue(); + if (EM == SPIRV::ExecutionMode::FPFastMathDefault || + EM == SPIRV::ExecutionMode::ContractionOff || + EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) + continue; + } + MCInst Inst; Inst.setOpcode(SPIRV::OpExecutionMode); addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI); outputMCInst(Inst); } + outputFPFastMathDefaultInfo(); } for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) { const Function &F = *FI; @@ -552,12 +569,84 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { } if (ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") && !M.getNamedMetadata("opencl.enable.FP_CONTRACT")) { - MCInst Inst; - Inst.setOpcode(SPIRV::OpExecutionMode); - Inst.addOperand(MCOperand::createReg(FReg)); - unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff); - Inst.addOperand(MCOperand::createImm(EM)); - outputMCInst(Inst); + if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) { + // When SPV_KHR_float_controls2 is enabled, ContractionOff is + // deprecated. We need to use FPFastMathDefault with the appropriate + // flags instead. Since FPFastMathDefault takes a target type, we need + // to emit it for each floating-point type that exists in the module + // to match the effect of ContractionOff. As of now, there are 3 FP + // types: fp16, fp32 and fp64. + + // We only end up here because there is no "spirv.ExecutionMode" + // metadata, so that means no FPFastMathDefault. Therefore, we only + // need to make sure AllowContract is set to 0, as the rest of flags. + // We still need to emit the OpExecutionMode instruction, otherwise + // it's up to the client API to define the flags. Therefore, we need + // to find the constant with 0 value. 
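+ // Sketched with illustrative SPIR-V result IDs, the replacement is:
+ //   OpExecutionMode %entry ContractionOff
+ // becomes, for each float type %fN present in the module,
+ //   OpExecutionModeId %entry FPFastMathDefault %fN %const0
+ // where %const0 is a 32-bit integer OpConstant with value 0, i.e. every
+ // fast-math flag (including AllowContract) cleared.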
+ + // Collect the SPIRVTypes for fp16, fp32, and fp64 and the constant of + // type int32 with 0 value to represent the FP Fast Math Mode. + std::vector<const MachineInstr *> SPIRVFloatTypes; + const MachineInstr *ConstZero = nullptr; + for (const MachineInstr *MI : + MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) { + // Skip if the instruction is not OpTypeFloat or OpConstant. + unsigned OpCode = MI->getOpcode(); + if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantNull) + continue; + + // Collect the SPIRV type if it's a float. + if (OpCode == SPIRV::OpTypeFloat) { + // Skip if the target type is not fp16, fp32, fp64. + const unsigned OpTypeFloatSize = MI->getOperand(1).getImm(); + if (OpTypeFloatSize != 16 && OpTypeFloatSize != 32 && + OpTypeFloatSize != 64) { + continue; + } + SPIRVFloatTypes.push_back(MI); + } else { + // Check if the constant is int32, if not skip it. + const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo(); + MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg()); + if (!TypeMI || TypeMI->getOperand(1).getImm() != 32) + continue; + + ConstZero = MI; + } + } + + // When SPV_KHR_float_controls2 is enabled, ContractionOff is + // deprecated. We need to use FPFastMathDefault with the appropriate + // flags instead. Since FPFastMathDefault takes a target type, we need + // to emit it for each floating-point type that exists in the module + // to match the effect of ContractionOff. As of now, there are 3 FP + // types: fp16, fp32 and fp64. + for (const MachineInstr *MI : SPIRVFloatTypes) { + MCInst Inst; + Inst.setOpcode(SPIRV::OpExecutionModeId); + Inst.addOperand(MCOperand::createReg(FReg)); + unsigned EM = + static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault); + Inst.addOperand(MCOperand::createImm(EM)); + const MachineFunction *MF = MI->getMF(); + MCRegister TypeReg = + MAI->getRegisterAlias(MF, MI->getOperand(0).getReg()); + Inst.addOperand(MCOperand::createReg(TypeReg)); + assert(ConstZero && "There should be a constant zero."); + MCRegister ConstReg = MAI->getRegisterAlias( + ConstZero->getMF(), ConstZero->getOperand(0).getReg()); + Inst.addOperand(MCOperand::createReg(ConstReg)); + outputMCInst(Inst); + } + } else { + MCInst Inst; + Inst.setOpcode(SPIRV::OpExecutionMode); + Inst.addOperand(MCOperand::createReg(FReg)); + unsigned EM = + static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff); + Inst.addOperand(MCOperand::createImm(EM)); + outputMCInst(Inst); + } } } } @@ -606,6 +695,101 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) { } } +void SPIRVAsmPrinter::outputFPFastMathDefaultInfo() { + // Collect the SPIRVTypes that are OpTypeFloat and the constants of type + // int32, that might be used as FP Fast Math Mode. + std::vector<const MachineInstr *> SPIRVFloatTypes; + // Hashtable to associate immediate values with the constant holding them. + std::unordered_map<int, const MachineInstr *> ConstMap; + for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) { + // Skip if the instruction is not OpTypeFloat or OpConstant. + unsigned OpCode = MI->getOpcode(); + if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantI && + OpCode != SPIRV::OpConstantNull) + continue; + + // Collect the SPIRV type if it's a float. + if (OpCode == SPIRV::OpTypeFloat) { + SPIRVFloatTypes.push_back(MI); + } else { + // Check if the constant is int32, if not skip it. 
+ const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo(); + MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg()); + if (!TypeMI || TypeMI->getOpcode() != SPIRV::OpTypeInt || + TypeMI->getOperand(1).getImm() != 32) + continue; + + if (OpCode == SPIRV::OpConstantI) + ConstMap[MI->getOperand(2).getImm()] = MI; + else + ConstMap[0] = MI; + } + } + + for (const auto &[Func, FPFastMathDefaultInfoVec] : + MAI->FPFastMathDefaultInfoMap) { + if (FPFastMathDefaultInfoVec.empty()) + continue; + + for (const MachineInstr *MI : SPIRVFloatTypes) { + unsigned OpTypeFloatSize = MI->getOperand(1).getImm(); + unsigned Index = SPIRV::FPFastMathDefaultInfoVector:: + computeFPFastMathDefaultInfoVecIndex(OpTypeFloatSize); + assert(Index < FPFastMathDefaultInfoVec.size() && + "Index out of bounds for FPFastMathDefaultInfoVec"); + const auto &FPFastMathDefaultInfo = FPFastMathDefaultInfoVec[Index]; + assert(FPFastMathDefaultInfo.Ty && + "Expected target type for FPFastMathDefaultInfo"); + assert(FPFastMathDefaultInfo.Ty->getScalarSizeInBits() == + OpTypeFloatSize && + "Mismatched float type size"); + MCInst Inst; + Inst.setOpcode(SPIRV::OpExecutionModeId); + MCRegister FuncReg = MAI->getFuncReg(Func); + assert(FuncReg.isValid()); + Inst.addOperand(MCOperand::createReg(FuncReg)); + Inst.addOperand( + MCOperand::createImm(SPIRV::ExecutionMode::FPFastMathDefault)); + MCRegister TypeReg = + MAI->getRegisterAlias(MI->getMF(), MI->getOperand(0).getReg()); + Inst.addOperand(MCOperand::createReg(TypeReg)); + unsigned Flags = FPFastMathDefaultInfo.FastMathFlags; + if (FPFastMathDefaultInfo.ContractionOff && + (Flags & SPIRV::FPFastMathMode::AllowContract)) + report_fatal_error( + "Conflicting FPFastMathFlags: ContractionOff and AllowContract"); + + if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve && + !(Flags & + (SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf | + SPIRV::FPFastMathMode::NSZ))) { + if (FPFastMathDefaultInfo.FPFastMathDefault) + report_fatal_error("Conflicting FPFastMathFlags: " + "SignedZeroInfNanPreserve but at least one of " + "NotNaN/NotInf/NSZ is enabled."); + } + + // Don't emit if none of the execution modes was used. + if (Flags == SPIRV::FPFastMathMode::None && + !FPFastMathDefaultInfo.ContractionOff && + !FPFastMathDefaultInfo.SignedZeroInfNanPreserve && + !FPFastMathDefaultInfo.FPFastMathDefault) + continue; + + // Retrieve the constant instruction for the immediate value. + auto It = ConstMap.find(Flags); + if (It == ConstMap.end()) + report_fatal_error("Expected constant instruction for FP Fast Math " + "Mode operand of FPFastMathDefault execution mode."); + const MachineInstr *ConstMI = It->second; + MCRegister ConstReg = MAI->getRegisterAlias( + ConstMI->getMF(), ConstMI->getOperand(0).getReg()); + Inst.addOperand(MCOperand::createReg(ConstReg)); + outputMCInst(Inst); + } + } +} + void SPIRVAsmPrinter::outputModuleSections() { const Module *M = MMI->getModule(); // Get the global subtarget to output module-level info. @@ -614,7 +798,8 @@ void SPIRVAsmPrinter::outputModuleSections() { MAI = &SPIRVModuleAnalysis::MAI; assert(ST && TII && MAI && M && "Module analysis is required"); // Output instructions according to the Logical Layout of a Module: - // 1,2. All OpCapability instructions, then optional OpExtension instructions. + // 1,2. All OpCapability instructions, then optional OpExtension + // instructions. outputGlobalRequirements(); // 3. Optional OpExtInstImport instructions. 
outputOpExtInstImports(*M);
@@ -622,7 +807,8 @@ void SPIRVAsmPrinter::outputModuleSections() {
   outputOpMemoryModel();
   // 5. All entry point declarations, using OpEntryPoint.
   outputEntryPoints();
-  // 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId.
+  // 6. Execution-mode declarations, using OpExecutionMode or
+  // OpExecutionModeId.
   outputExecutionMode(*M);
   // 7a. Debug: all OpString, OpSourceExtension, OpSource, and
   // OpSourceContinued, without forward references.
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index f704d3a..0e0c454 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1162,11 +1162,24 @@ static unsigned getNumSizeComponents(SPIRVType *imgType) {
 
 static bool generateExtInst(const SPIRV::IncomingCall *Call,
                             MachineIRBuilder &MIRBuilder,
-                            SPIRVGlobalRegistry *GR) {
+                            SPIRVGlobalRegistry *GR, const CallBase &CB) {
   // Lookup the extended instruction number in the TableGen records.
   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
   uint32_t Number =
       SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set)->Number;
+  // fmin_common and fmax_common are now deprecated; we should use fmin and
+  // fmax with the NotInf and NotNaN flags instead. Keep the original number
+  // so we can add the NoNans and NoInfs flags later.
+  uint32_t OrigNumber = Number;
+  const SPIRVSubtarget &ST =
+      cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+  if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2) &&
+      (Number == SPIRV::OpenCLExtInst::fmin_common ||
+       Number == SPIRV::OpenCLExtInst::fmax_common)) {
+    Number = (Number == SPIRV::OpenCLExtInst::fmin_common)
+                 ? SPIRV::OpenCLExtInst::fmin
+                 : SPIRV::OpenCLExtInst::fmax;
+  }
 
   // Build extended instruction.
   auto MIB =
@@ -1178,6 +1191,13 @@ static bool generateExtInst(const SPIRV::IncomingCall *Call,
   for (auto Argument : Call->Arguments)
     MIB.addUse(Argument);
+  MIB.getInstr()->copyIRFlags(CB);
+  if (OrigNumber == SPIRV::OpenCLExtInst::fmin_common ||
+      OrigNumber == SPIRV::OpenCLExtInst::fmax_common) {
+    // Add the NoNans and NoInfs flags to the fmin/fmax instruction.
+    MIB.getInstr()->setFlag(MachineInstr::MIFlag::FmNoNans);
+    MIB.getInstr()->setFlag(MachineInstr::MIFlag::FmNoInfs);
+  }
   return true;
 }
 
@@ -2908,7 +2928,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  MachineIRBuilder &MIRBuilder,
                                  const Register OrigRet, const Type *OrigRetTy,
                                  const SmallVectorImpl<Register> &Args,
-                                 SPIRVGlobalRegistry *GR) {
+                                 SPIRVGlobalRegistry *GR, const CallBase &CB) {
   LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
 
   // Lookup the builtin in the TableGen records.
@@ -2931,7 +2951,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
 
   // Match the builtin with implementation based on the grouping.
switch (Call->Builtin->Group) { case SPIRV::Extended: - return generateExtInst(Call.get(), MIRBuilder, GR); + return generateExtInst(Call.get(), MIRBuilder, GR, CB); case SPIRV::Relational: return generateRelationalInst(Call.get(), MIRBuilder, GR); case SPIRV::Group: diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h index 1a8641a..f6a5234 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h @@ -39,7 +39,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall, MachineIRBuilder &MIRBuilder, const Register OrigRet, const Type *OrigRetTy, const SmallVectorImpl<Register> &Args, - SPIRVGlobalRegistry *GR); + SPIRVGlobalRegistry *GR, const CallBase &CB); /// Helper function for finding a builtin function attributes /// by a demangled function name. Defined in SPIRVBuiltins.cpp. diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index a412887..1a7c02c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -641,9 +641,9 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, GR->getPointerSize())); } } - if (auto Res = - SPIRV::lowerBuiltin(DemangledName, ST->getPreferredInstructionSet(), - MIRBuilder, ResVReg, OrigRetTy, ArgVRegs, GR)) + if (auto Res = SPIRV::lowerBuiltin( + DemangledName, ST->getPreferredInstructionSet(), MIRBuilder, + ResVReg, OrigRetTy, ArgVRegs, GR, *Info.CB)) return *Res; } diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 704edd3..9f2e075 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -25,6 +25,7 @@ #include "llvm/IR/TypedPointerType.h" #include "llvm/Transforms/Utils/Local.h" +#include <cassert> #include <queue> #include <unordered_set> @@ -152,6 +153,7 @@ class SPIRVEmitIntrinsics void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B); bool shouldTryToAddMemAliasingDecoration(Instruction *Inst); void insertSpirvDecorations(Instruction *I, IRBuilder<> &B); + void insertConstantsForFPFastMathDefault(Module &M); void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B); void processParamTypes(Function *F, IRBuilder<> &B); void processParamTypesByFunHeader(Function *F, IRBuilder<> &B); @@ -2249,6 +2251,198 @@ void SPIRVEmitIntrinsics::insertSpirvDecorations(Instruction *I, } } +static SPIRV::FPFastMathDefaultInfoVector &getOrCreateFPFastMathDefaultInfoVec( + const Module &M, + DenseMap<Function *, SPIRV::FPFastMathDefaultInfoVector> + &FPFastMathDefaultInfoMap, + Function *F) { + auto it = FPFastMathDefaultInfoMap.find(F); + if (it != FPFastMathDefaultInfoMap.end()) + return it->second; + + // If the map does not contain the entry, create a new one. Initialize it to + // contain all 3 elements sorted by bit width of target type: {half, float, + // double}. 
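+  // For example, computeFPFastMathDefaultInfoVecIndex maps the bit widths
+  // 16/32/64 to indices 0/1/2, so later lookups can index the vector directly
+  // by type.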
+ SPIRV::FPFastMathDefaultInfoVector FPFastMathDefaultInfoVec; + FPFastMathDefaultInfoVec.emplace_back(Type::getHalfTy(M.getContext()), + SPIRV::FPFastMathMode::None); + FPFastMathDefaultInfoVec.emplace_back(Type::getFloatTy(M.getContext()), + SPIRV::FPFastMathMode::None); + FPFastMathDefaultInfoVec.emplace_back(Type::getDoubleTy(M.getContext()), + SPIRV::FPFastMathMode::None); + return FPFastMathDefaultInfoMap[F] = std::move(FPFastMathDefaultInfoVec); +} + +static SPIRV::FPFastMathDefaultInfo &getFPFastMathDefaultInfo( + SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec, + const Type *Ty) { + size_t BitWidth = Ty->getScalarSizeInBits(); + int Index = + SPIRV::FPFastMathDefaultInfoVector::computeFPFastMathDefaultInfoVecIndex( + BitWidth); + assert(Index >= 0 && Index < 3 && + "Expected FPFastMathDefaultInfo for half, float, or double"); + assert(FPFastMathDefaultInfoVec.size() == 3 && + "Expected FPFastMathDefaultInfoVec to have exactly 3 elements"); + return FPFastMathDefaultInfoVec[Index]; +} + +void SPIRVEmitIntrinsics::insertConstantsForFPFastMathDefault(Module &M) { + const SPIRVSubtarget *ST = TM->getSubtargetImpl(); + if (!ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) + return; + + // Store the FPFastMathDefaultInfo in the FPFastMathDefaultInfoMap. + // We need the entry point (function) as the key, and the target + // type and flags as the value. + // We also need to check ContractionOff and SignedZeroInfNanPreserve + // execution modes, as they are now deprecated and must be replaced + // with FPFastMathDefaultInfo. + auto Node = M.getNamedMetadata("spirv.ExecutionMode"); + if (!Node) { + if (!M.getNamedMetadata("opencl.enable.FP_CONTRACT")) { + // This requires emitting ContractionOff. However, because + // ContractionOff is now deprecated, we need to replace it with + // FPFastMathDefaultInfo with FP Fast Math Mode bitmask set to all 0. + // We need to create the constant for that. + + // Create constant instruction with the bitmask flags. + Constant *InitValue = + ConstantInt::get(Type::getInt32Ty(M.getContext()), 0); + // TODO: Reuse constant if there is one already with the required + // value. + [[maybe_unused]] GlobalVariable *GV = + new GlobalVariable(M, // Module + Type::getInt32Ty(M.getContext()), // Type + true, // isConstant + GlobalValue::InternalLinkage, // Linkage + InitValue // Initializer + ); + } + return; + } + + // The table maps function pointers to their default FP fast math info. It + // can be assumed that the SmallVector is sorted by the bit width of the + // type. The first element is the smallest bit width, and the last element + // is the largest bit width, therefore, we will have {half, float, double} + // in the order of their bit widths. 
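+  // Illustrative shape of the metadata consumed below (function name
+  // assumed):
+  //   !spirv.ExecutionMode = !{!0}
+  //   !0 = !{ptr @main, i32 6028, half undef, i32 0}
+  // Operand 1 is the execution mode (6028 is FPFastMathDefault), the type of
+  // operand 2 selects the target FP type, and operand 3 holds the
+  // FPFastMathMode bitmask.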
+  DenseMap<Function *, SPIRV::FPFastMathDefaultInfoVector>
+      FPFastMathDefaultInfoMap;
+
+  for (unsigned i = 0; i < Node->getNumOperands(); i++) {
+    MDNode *MDN = cast<MDNode>(Node->getOperand(i));
+    assert(MDN->getNumOperands() >= 2 && "Expected at least 2 operands");
+    Function *F = cast<Function>(
+        cast<ConstantAsMetadata>(MDN->getOperand(0))->getValue());
+    const auto EM =
+        cast<ConstantInt>(
+            cast<ConstantAsMetadata>(MDN->getOperand(1))->getValue())
+            ->getZExtValue();
+    if (EM == SPIRV::ExecutionMode::FPFastMathDefault) {
+      assert(MDN->getNumOperands() == 4 &&
+             "Expected 4 operands for FPFastMathDefault");
+      const Type *T = cast<ValueAsMetadata>(MDN->getOperand(2))->getType();
+      unsigned Flags =
+          cast<ConstantInt>(
+              cast<ConstantAsMetadata>(MDN->getOperand(3))->getValue())
+              ->getZExtValue();
+      SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
+          getOrCreateFPFastMathDefaultInfoVec(M, FPFastMathDefaultInfoMap, F);
+      SPIRV::FPFastMathDefaultInfo &Info =
+          getFPFastMathDefaultInfo(FPFastMathDefaultInfoVec, T);
+      Info.FastMathFlags = Flags;
+      Info.FPFastMathDefault = true;
+    } else if (EM == SPIRV::ExecutionMode::ContractionOff) {
+      assert(MDN->getNumOperands() == 2 &&
+             "Expected no additional operands for ContractionOff");
+
+      // We need to save this info for every FP type we track, i.e. {half,
+      // float, double}.
+      SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
+          getOrCreateFPFastMathDefaultInfoVec(M, FPFastMathDefaultInfoMap, F);
+      for (SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) {
+        Info.ContractionOff = true;
+      }
+    } else if (EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) {
+      assert(MDN->getNumOperands() == 3 &&
+             "Expected 1 additional operand for SignedZeroInfNanPreserve");
+      unsigned TargetWidth =
+          cast<ConstantInt>(
+              cast<ConstantAsMetadata>(MDN->getOperand(2))->getValue())
+              ->getZExtValue();
+      // We need to save this info only for the FP type with TargetWidth.
+      SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
+          getOrCreateFPFastMathDefaultInfoVec(M, FPFastMathDefaultInfoMap, F);
+      int Index = SPIRV::FPFastMathDefaultInfoVector::
+          computeFPFastMathDefaultInfoVecIndex(TargetWidth);
+      assert(Index >= 0 && Index < 3 &&
+             "Expected FPFastMathDefaultInfo for half, float, or double");
+      assert(FPFastMathDefaultInfoVec.size() == 3 &&
+             "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
+      FPFastMathDefaultInfoVec[Index].SignedZeroInfNanPreserve = true;
+    }
+  }
+
+  std::unordered_map<unsigned, GlobalVariable *> GlobalVars;
+  for (auto &[Func, FPFastMathDefaultInfoVec] : FPFastMathDefaultInfoMap) {
+    if (FPFastMathDefaultInfoVec.empty())
+      continue;
+
+    for (const SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) {
+      assert(Info.Ty && "Expected target type for FPFastMathDefaultInfo");
+      // Skip if none of the execution modes was used.
+      unsigned Flags = Info.FastMathFlags;
+      if (Flags == SPIRV::FPFastMathMode::None && !Info.ContractionOff &&
+          !Info.SignedZeroInfNanPreserve && !Info.FPFastMathDefault)
+        continue;
+
+      // Check if flags are compatible.
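+      // For instance, ContractionOff combined with AllowContract (0x10000)
+      // for the same target type is contradictory and is rejected below.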
+      if (Info.ContractionOff && (Flags & SPIRV::FPFastMathMode::AllowContract))
+        report_fatal_error("Conflicting FPFastMathFlags: ContractionOff "
+                           "and AllowContract");
+
+      if (Info.SignedZeroInfNanPreserve &&
+          !(Flags &
+            (SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
+             SPIRV::FPFastMathMode::NSZ))) {
+        if (Info.FPFastMathDefault)
+          report_fatal_error("Conflicting FPFastMathFlags: "
+                             "SignedZeroInfNanPreserve but at least one of "
+                             "NotNaN/NotInf/NSZ is enabled.");
+      }
+
+      if ((Flags & SPIRV::FPFastMathMode::AllowTransform) &&
+          !((Flags & SPIRV::FPFastMathMode::AllowReassoc) &&
+            (Flags & SPIRV::FPFastMathMode::AllowContract))) {
+        report_fatal_error("Conflicting FPFastMathFlags: "
+                           "AllowTransform requires AllowReassoc and "
+                           "AllowContract to be set.");
+      }
+
+      auto it = GlobalVars.find(Flags);
+      GlobalVariable *GV = nullptr;
+      if (it != GlobalVars.end()) {
+        // Reuse existing global variable.
+        GV = it->second;
+      } else {
+        // Create constant instruction with the bitmask flags.
+        Constant *InitValue =
+            ConstantInt::get(Type::getInt32Ty(M.getContext()), Flags);
+        // TODO: Reuse constant if there is one already with the required
+        // value.
+        GV = new GlobalVariable(M,                                // Module
+                                Type::getInt32Ty(M.getContext()), // Type
+                                true,                             // isConstant
+                                GlobalValue::InternalLinkage,     // Linkage
+                                InitValue                         // Initializer
+        );
+        GlobalVars[Flags] = GV;
+      }
+    }
+  }
+}
+
 void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
                                                  IRBuilder<> &B) {
   auto *II = dyn_cast<IntrinsicInst>(I);
@@ -2569,9 +2763,9 @@ GetElementPtrInst *
 SPIRVEmitIntrinsics::simplifyZeroLengthArrayGepInst(GetElementPtrInst *GEP) {
   // getelementptr [0 x T], P, 0 (zero), I -> getelementptr T, P, I.
   // If type is 0-length array and first index is 0 (zero), drop both the
-  // 0-length array type and the first index. This is a common pattern in the
-  // IR, e.g. when using a zero-length array as a placeholder for a flexible
-  // array such as unbound arrays.
+  // 0-length array type and the first index. This is a common pattern in
+  // the IR, e.g. when using a zero-length array as a placeholder for a
+  // flexible array such as unbound arrays.
   assert(GEP && "GEP is null");
   Type *SrcTy = GEP->getSourceElementType();
   SmallVector<Value *, 8> Indices(GEP->indices());
@@ -2633,8 +2827,9 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
 
   processParamTypesByFunHeader(CurrF, B);
 
-  // StoreInst's operand type can be changed during the next transformations,
-  // so we need to store it in the set. Also store already transformed types.
+  // StoreInst's operand type can be changed during the next
+  // transformations, so we need to store it in the set. Also store already
+  // transformed types.
   for (auto &I : instructions(Func)) {
     StoreInst *SI = dyn_cast<StoreInst>(&I);
     if (!SI)
@@ -2681,8 +2876,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   for (auto &I : llvm::reverse(instructions(Func)))
     deduceOperandElementType(&I, &IncompleteRets);
 
-  // Pass forward for PHIs only, their operands are not preceed the instruction
-  // in meaning of `instructions(Func)`.
+  // Pass forward for PHIs only; their operands do not necessarily precede
+  // the instruction in the order given by `instructions(Func)`.
for (BasicBlock &BB : Func) for (PHINode &Phi : BB.phis()) if (isPointerTy(Phi.getType())) @@ -2692,8 +2887,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { TrackConstants = true; if (!I->getType()->isVoidTy() || isa<StoreInst>(I)) setInsertPointAfterDef(B, I); - // Visitors return either the original/newly created instruction for further - // processing, nullptr otherwise. + // Visitors return either the original/newly created instruction for + // further processing, nullptr otherwise. I = visit(*I); if (!I) continue; @@ -2816,6 +3011,7 @@ bool SPIRVEmitIntrinsics::runOnModule(Module &M) { bool Changed = false; parseFunDeclarations(M); + insertConstantsForFPFastMathDefault(M); TodoType.clear(); for (auto &F : M) diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 115766c..6fd1c7e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -806,7 +806,7 @@ Register SPIRVGlobalRegistry::buildGlobalVariable( // arguments. MDNode *GVarMD = nullptr; if (GVar && (GVarMD = GVar->getMetadata("spirv.Decorations")) != nullptr) - buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD); + buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD, ST); return Reg; } diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp index 45e88fc..ba95ad8 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp @@ -132,7 +132,8 @@ bool SPIRVInstrInfo::isHeaderInstr(const MachineInstr &MI) const { } } -bool SPIRVInstrInfo::canUseFastMathFlags(const MachineInstr &MI) const { +bool SPIRVInstrInfo::canUseFastMathFlags(const MachineInstr &MI, + bool KHRFloatControls2) const { switch (MI.getOpcode()) { case SPIRV::OpFAddS: case SPIRV::OpFSubS: @@ -146,6 +147,24 @@ bool SPIRVInstrInfo::canUseFastMathFlags(const MachineInstr &MI) const { case SPIRV::OpFRemV: case SPIRV::OpFMod: return true; + case SPIRV::OpFNegateV: + case SPIRV::OpFNegate: + case SPIRV::OpOrdered: + case SPIRV::OpUnordered: + case SPIRV::OpFOrdEqual: + case SPIRV::OpFOrdNotEqual: + case SPIRV::OpFOrdLessThan: + case SPIRV::OpFOrdLessThanEqual: + case SPIRV::OpFOrdGreaterThan: + case SPIRV::OpFOrdGreaterThanEqual: + case SPIRV::OpFUnordEqual: + case SPIRV::OpFUnordNotEqual: + case SPIRV::OpFUnordLessThan: + case SPIRV::OpFUnordLessThanEqual: + case SPIRV::OpFUnordGreaterThan: + case SPIRV::OpFUnordGreaterThanEqual: + case SPIRV::OpExtInst: + return KHRFloatControls2 ? 
true : false;
   default:
     return false;
   }
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h
index 72d2243..4de9d6a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.h
@@ -36,7 +36,8 @@ public:
   bool isTypeDeclInstr(const MachineInstr &MI) const;
   bool isDecorationInstr(const MachineInstr &MI) const;
   bool isAliasingInstr(const MachineInstr &MI) const;
-  bool canUseFastMathFlags(const MachineInstr &MI) const;
+  bool canUseFastMathFlags(const MachineInstr &MI,
+                           bool KHRFloatControls2) const;
   bool canUseNSW(const MachineInstr &MI) const;
   bool canUseNUW(const MachineInstr &MI) const;
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 1aadd9d..273edf3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1073,7 +1073,8 @@ bool SPIRVInstructionSelector::selectExtInst(Register ResVReg,
           .addDef(ResVReg)
           .addUse(GR.getSPIRVTypeID(ResType))
           .addImm(static_cast<uint32_t>(Set))
-          .addImm(Opcode);
+          .addImm(Opcode)
+          .setMIFlags(I.getFlags());
   const unsigned NumOps = I.getNumOperands();
   unsigned Index = 1;
   if (Index < NumOps &&
@@ -2629,6 +2630,7 @@ bool SPIRVInstructionSelector::selectCmp(Register ResVReg,
       .addUse(GR.getSPIRVTypeID(ResType))
       .addUse(Cmp0)
       .addUse(Cmp1)
+      .setMIFlags(I.getFlags())
      .constrainAllUses(TII, TRI, RBI);
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index bc159d5..dc717a6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -248,6 +248,22 @@ static InstrSignature instrToSignature(const MachineInstr &MI,
   Register DefReg;
   InstrSignature Signature{MI.getOpcode()};
   for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
+    // The only decorations that can be applied more than once to a given <id>
+    // or structure member are UserSemantic (5635), CacheControlLoadINTEL
+    // (6442), and CacheControlStoreINTEL (6443). For all other decorations, we
+    // only add to the signature the Opcode, the id to which it applies, and
+    // the decoration id, disregarding any decoration flags. This ensures that
+    // any subsequent decoration with the same id is deemed a duplicate. Then,
+    // at the call site, we can handle duplicates appropriately, e.g. by
+    // merging FPFastMathMode flags.
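+    // E.g. two OpDecorate instructions applying FPFastMathMode to the same
+    // <id> with different flag masks hash to the same signature, so the
+    // second one is caught as a duplicate and its flags are merged.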
+ unsigned Opcode = MI.getOpcode(); + if ((Opcode == SPIRV::OpDecorate) && i >= 2) { + unsigned DecorationID = MI.getOperand(1).getImm(); + if (DecorationID != SPIRV::Decoration::UserSemantic && + DecorationID != SPIRV::Decoration::CacheControlLoadINTEL && + DecorationID != SPIRV::Decoration::CacheControlStoreINTEL) + continue; + } const MachineOperand &MO = MI.getOperand(i); size_t h; if (MO.isReg()) { @@ -559,8 +575,54 @@ static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI, MAI.setSkipEmission(&MI); InstrSignature MISign = instrToSignature(MI, MAI, true); auto FoundMI = IS.insert(std::move(MISign)); - if (!FoundMI.second) + if (!FoundMI.second) { + if (MI.getOpcode() == SPIRV::OpDecorate) { + assert(MI.getNumOperands() >= 2 && + "Decoration instructions must have at least 2 operands"); + assert(MSType == SPIRV::MB_Annotations && + "Only OpDecorate instructions can be duplicates"); + // For FPFastMathMode decoration, we need to merge the flags of the + // duplicate decoration with the original one, so we need to find the + // original instruction that has the same signature. For the rest of + // instructions, we will simply skip the duplicate. + if (MI.getOperand(1).getImm() != SPIRV::Decoration::FPFastMathMode) + return; // Skip duplicates of other decorations. + + const SPIRV::InstrList &Decorations = MAI.MS[MSType]; + for (const MachineInstr *OrigMI : Decorations) { + if (instrToSignature(*OrigMI, MAI, true) == MISign) { + assert(OrigMI->getNumOperands() == MI.getNumOperands() && + "Original instruction must have the same number of operands"); + assert( + OrigMI->getNumOperands() == 3 && + "FPFastMathMode decoration must have 3 operands for OpDecorate"); + unsigned OrigFlags = OrigMI->getOperand(2).getImm(); + unsigned NewFlags = MI.getOperand(2).getImm(); + if (OrigFlags == NewFlags) + return; // No need to merge, the flags are the same. + + // Emit warning about possible conflict between flags. + unsigned FinalFlags = OrigFlags | NewFlags; + llvm::errs() + << "Warning: Conflicting FPFastMathMode decoration flags " + "in instruction: " + << *OrigMI << "Original flags: " << OrigFlags + << ", new flags: " << NewFlags + << ". They will be merged on a best effort basis, but not " + "validated. Final flags: " + << FinalFlags << "\n"; + MachineInstr *OrigMINonConst = const_cast<MachineInstr *>(OrigMI); + MachineOperand &OrigFlagsOp = OrigMINonConst->getOperand(2); + OrigFlagsOp = + MachineOperand::CreateImm(static_cast<unsigned>(FinalFlags)); + return; // Merge done, so we found a duplicate; don't add it to MAI.MS + } + } + assert(false && "No original instruction found for the duplicate " + "OpDecorate, but we found one in IS."); + } return; // insert failed, so we found a duplicate; don't add it to MAI.MS + } // No duplicates, so add it. if (Append) MAI.MS[MSType].push_back(&MI); @@ -934,6 +996,11 @@ static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex, } else if (Dec == SPIRV::Decoration::FPMaxErrorDecorationINTEL) { Reqs.addRequirements(SPIRV::Capability::FPMaxErrorINTEL); Reqs.addExtension(SPIRV::Extension::SPV_INTEL_fp_max_error); + } else if (Dec == SPIRV::Decoration::FPFastMathMode) { + if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) { + Reqs.addRequirements(SPIRV::Capability::FloatControls2); + Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2); + } } } @@ -1994,10 +2061,13 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, // Collect requirements for OpExecutionMode instructions. 
auto Node = M.getNamedMetadata("spirv.ExecutionMode"); if (Node) { - bool RequireFloatControls = false, RequireFloatControls2 = false, + bool RequireFloatControls = false, RequireIntelFloatControls2 = false, + RequireKHRFloatControls2 = false, VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4)); - bool HasFloatControls2 = + bool HasIntelFloatControls2 = ST.canUseExtension(SPIRV::Extension::SPV_INTEL_float_controls2); + bool HasKHRFloatControls2 = + ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2); for (unsigned i = 0; i < Node->getNumOperands(); i++) { MDNode *MDN = cast<MDNode>(Node->getOperand(i)); const MDOperand &MDOp = MDN->getOperand(1); @@ -2010,7 +2080,6 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, switch (EM) { case SPIRV::ExecutionMode::DenormPreserve: case SPIRV::ExecutionMode::DenormFlushToZero: - case SPIRV::ExecutionMode::SignedZeroInfNanPreserve: case SPIRV::ExecutionMode::RoundingModeRTE: case SPIRV::ExecutionMode::RoundingModeRTZ: RequireFloatControls = VerLower14; @@ -2021,8 +2090,28 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, case SPIRV::ExecutionMode::RoundingModeRTNINTEL: case SPIRV::ExecutionMode::FloatingPointModeALTINTEL: case SPIRV::ExecutionMode::FloatingPointModeIEEEINTEL: - if (HasFloatControls2) { - RequireFloatControls2 = true; + if (HasIntelFloatControls2) { + RequireIntelFloatControls2 = true; + MAI.Reqs.getAndAddRequirements( + SPIRV::OperandCategory::ExecutionModeOperand, EM, ST); + } + break; + case SPIRV::ExecutionMode::FPFastMathDefault: { + if (HasKHRFloatControls2) { + RequireKHRFloatControls2 = true; + MAI.Reqs.getAndAddRequirements( + SPIRV::OperandCategory::ExecutionModeOperand, EM, ST); + } + break; + } + case SPIRV::ExecutionMode::ContractionOff: + case SPIRV::ExecutionMode::SignedZeroInfNanPreserve: + if (HasKHRFloatControls2) { + RequireKHRFloatControls2 = true; + MAI.Reqs.getAndAddRequirements( + SPIRV::OperandCategory::ExecutionModeOperand, + SPIRV::ExecutionMode::FPFastMathDefault, ST); + } else { MAI.Reqs.getAndAddRequirements( SPIRV::OperandCategory::ExecutionModeOperand, EM, ST); } @@ -2037,8 +2126,10 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, if (RequireFloatControls && ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls)) MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls); - if (RequireFloatControls2) + if (RequireIntelFloatControls2) MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_float_controls2); + if (RequireKHRFloatControls2) + MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2); } for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) { const Function &F = *FI; @@ -2078,8 +2169,11 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, } } -static unsigned getFastMathFlags(const MachineInstr &I) { +static unsigned getFastMathFlags(const MachineInstr &I, + const SPIRVSubtarget &ST) { unsigned Flags = SPIRV::FPFastMathMode::None; + bool CanUseKHRFloatControls2 = + ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2); if (I.getFlag(MachineInstr::MIFlag::FmNoNans)) Flags |= SPIRV::FPFastMathMode::NotNaN; if (I.getFlag(MachineInstr::MIFlag::FmNoInfs)) @@ -2088,12 +2182,45 @@ static unsigned getFastMathFlags(const MachineInstr &I) { Flags |= SPIRV::FPFastMathMode::NSZ; if (I.getFlag(MachineInstr::MIFlag::FmArcp)) Flags |= SPIRV::FPFastMathMode::AllowRecip; - if (I.getFlag(MachineInstr::MIFlag::FmReassoc)) - Flags |= SPIRV::FPFastMathMode::Fast; + if 
(I.getFlag(MachineInstr::MIFlag::FmContract) && CanUseKHRFloatControls2)
+    Flags |= SPIRV::FPFastMathMode::AllowContract;
+  if (I.getFlag(MachineInstr::MIFlag::FmReassoc)) {
+    if (CanUseKHRFloatControls2)
+      // LLVM reassoc maps to SPIRV transform, see
+      // https://github.com/KhronosGroup/SPIRV-Registry/issues/326 for details.
+      // Because we are enabling AllowTransform, we must enable AllowReassoc
+      // and AllowContract too, as required by the SPIRV spec. Also, we used to
+      // map MIFlag::FmReassoc to FPFastMathMode::Fast, which should now be
+      // replaced by setting all the other bits. Therefore, we enable every bit
+      // here except None and Fast.
+      Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
+               SPIRV::FPFastMathMode::NSZ | SPIRV::FPFastMathMode::AllowRecip |
+               SPIRV::FPFastMathMode::AllowTransform |
+               SPIRV::FPFastMathMode::AllowReassoc |
+               SPIRV::FPFastMathMode::AllowContract;
+    else
+      Flags |= SPIRV::FPFastMathMode::Fast;
+  }
+
+  if (CanUseKHRFloatControls2) {
+    // Error out if SPIRV::FPFastMathMode::Fast is enabled.
+    assert(!(Flags & SPIRV::FPFastMathMode::Fast) &&
+           "SPIRV::FPFastMathMode::Fast is deprecated and should not be used "
+           "anymore.");
+
+    // Error out if AllowTransform is enabled without AllowReassoc and
+    // AllowContract.
+    assert((!(Flags & SPIRV::FPFastMathMode::AllowTransform) ||
+            ((Flags & SPIRV::FPFastMathMode::AllowReassoc &&
+              Flags & SPIRV::FPFastMathMode::AllowContract))) &&
+           "SPIRV::FPFastMathMode::AllowTransform requires AllowReassoc and "
+           "AllowContract flags to be enabled as well.");
+  }
+
+  return Flags;
 }
 
-static bool isFastMathMathModeAvailable(const SPIRVSubtarget &ST) {
+static bool isFastMathModeAvailable(const SPIRVSubtarget &ST) {
   if (ST.isKernel())
     return true;
   if (ST.getSPIRVVersion() < VersionTuple(1, 2))
@@ -2101,9 +2228,10 @@ static bool isFastMathMathModeAvailable(const SPIRVSubtarget &ST) {
   return ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
 }
 
-static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
-                                   const SPIRVInstrInfo &TII,
-                                   SPIRV::RequirementHandler &Reqs) {
+static void handleMIFlagDecoration(
+    MachineInstr &I, const SPIRVSubtarget &ST, const SPIRVInstrInfo &TII,
+    SPIRV::RequirementHandler &Reqs, const SPIRVGlobalRegistry *GR,
+    SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec) {
   if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
       getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
                                      SPIRV::Decoration::NoSignedWrap, ST, Reqs)
@@ -2119,13 +2247,53 @@ static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
     buildOpDecorate(I.getOperand(0).getReg(), I, TII,
                     SPIRV::Decoration::NoUnsignedWrap, {});
   }
-  if (!TII.canUseFastMathFlags(I))
-    return;
-  unsigned FMFlags = getFastMathFlags(I);
-  if (FMFlags == SPIRV::FPFastMathMode::None)
+  if (!TII.canUseFastMathFlags(
+          I, ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)))
     return;
-  if (isFastMathMathModeAvailable(ST)) {
+  unsigned FMFlags = getFastMathFlags(I, ST);
+  if (FMFlags == SPIRV::FPFastMathMode::None) {
+    // We also need to check if any FPFastMathDefault info was set for the
+    // types used in this instruction.
+    if (FPFastMathDefaultInfoVec.empty())
+      return;
+
+    // There are three types of instructions that can use fast math flags:
+    // 1. Arithmetic instructions (FAdd, FMul, FSub, FDiv, FRem, etc.)
+    // 2. Relational instructions (FCmp, FOrd, FUnord, etc.)
+    // 3. 
Extended instructions (ExtInst) + // For arithmetic instructions, the floating point type can be in the + // result type or in the operands, but they all must be the same. + // For the relational and logical instructions, the floating point type + // can only be in the operands 1 and 2, not the result type. Also, the + // operands must have the same type. For the extended instructions, the + // floating point type can be in the result type or in the operands. It's + // unclear if the operands and the result type must be the same. Let's + // assume they must be. Therefore, for 1. and 2., we can check the first + // operand type, and for 3. we can check the result type. + assert(I.getNumOperands() >= 3 && "Expected at least 3 operands"); + Register ResReg = I.getOpcode() == SPIRV::OpExtInst + ? I.getOperand(1).getReg() + : I.getOperand(2).getReg(); + SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResReg, I.getMF()); + const Type *Ty = GR->getTypeForSPIRVType(ResType); + Ty = Ty->isVectorTy() ? cast<VectorType>(Ty)->getElementType() : Ty; + + // Match instruction type with the FPFastMathDefaultInfoVec. + bool Emit = false; + for (SPIRV::FPFastMathDefaultInfo &Elem : FPFastMathDefaultInfoVec) { + if (Ty == Elem.Ty) { + FMFlags = Elem.FastMathFlags; + Emit = Elem.ContractionOff || Elem.SignedZeroInfNanPreserve || + Elem.FPFastMathDefault; + break; + } + } + + if (FMFlags == SPIRV::FPFastMathMode::None && !Emit) + return; + } + if (isFastMathModeAvailable(ST)) { Register DstReg = I.getOperand(0).getReg(); buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags}); @@ -2135,14 +2303,17 @@ static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST, // Walk all functions and add decorations related to MI flags. static void addDecorations(const Module &M, const SPIRVInstrInfo &TII, MachineModuleInfo *MMI, const SPIRVSubtarget &ST, - SPIRV::ModuleAnalysisInfo &MAI) { + SPIRV::ModuleAnalysisInfo &MAI, + const SPIRVGlobalRegistry *GR) { for (auto F = M.begin(), E = M.end(); F != E; ++F) { MachineFunction *MF = MMI->getMachineFunction(*F); if (!MF) continue; + for (auto &MBB : *MF) for (auto &MI : MBB) - handleMIFlagDecoration(MI, ST, TII, MAI.Reqs); + handleMIFlagDecoration(MI, ST, TII, MAI.Reqs, GR, + MAI.FPFastMathDefaultInfoMap[&(*F)]); } } @@ -2188,6 +2359,111 @@ static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR, } } +static SPIRV::FPFastMathDefaultInfoVector &getOrCreateFPFastMathDefaultInfoVec( + const Module &M, SPIRV::ModuleAnalysisInfo &MAI, const Function *F) { + auto it = MAI.FPFastMathDefaultInfoMap.find(F); + if (it != MAI.FPFastMathDefaultInfoMap.end()) + return it->second; + + // If the map does not contain the entry, create a new one. Initialize it to + // contain all 3 elements sorted by bit width of target type: {half, float, + // double}. 
+  SPIRV::FPFastMathDefaultInfoVector FPFastMathDefaultInfoVec;
+  FPFastMathDefaultInfoVec.emplace_back(Type::getHalfTy(M.getContext()),
+                                        SPIRV::FPFastMathMode::None);
+  FPFastMathDefaultInfoVec.emplace_back(Type::getFloatTy(M.getContext()),
+                                        SPIRV::FPFastMathMode::None);
+  FPFastMathDefaultInfoVec.emplace_back(Type::getDoubleTy(M.getContext()),
+                                        SPIRV::FPFastMathMode::None);
+  return MAI.FPFastMathDefaultInfoMap[F] = std::move(FPFastMathDefaultInfoVec);
+}
+
+static SPIRV::FPFastMathDefaultInfo &getFPFastMathDefaultInfo(
+    SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec,
+    const Type *Ty) {
+  size_t BitWidth = Ty->getScalarSizeInBits();
+  int Index =
+      SPIRV::FPFastMathDefaultInfoVector::computeFPFastMathDefaultInfoVecIndex(
+          BitWidth);
+  assert(Index >= 0 && Index < 3 &&
+         "Expected FPFastMathDefaultInfo for half, float, or double");
+  assert(FPFastMathDefaultInfoVec.size() == 3 &&
+         "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
+  return FPFastMathDefaultInfoVec[Index];
+}
+
+static void collectFPFastMathDefaults(const Module &M,
+                                      SPIRV::ModuleAnalysisInfo &MAI,
+                                      const SPIRVSubtarget &ST) {
+  if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2))
+    return;
+
+  // Store the FPFastMathDefaultInfo in the FPFastMathDefaultInfoMap.
+  // We need the entry point (function) as the key, and the target
+  // type and flags as the value.
+  // We also need to check ContractionOff and SignedZeroInfNanPreserve
+  // execution modes, as they are now deprecated and must be replaced
+  // with FPFastMathDefaultInfo.
+  auto Node = M.getNamedMetadata("spirv.ExecutionMode");
+  if (!Node)
+    return;
+
+  for (unsigned i = 0; i < Node->getNumOperands(); i++) {
+    MDNode *MDN = cast<MDNode>(Node->getOperand(i));
+    assert(MDN->getNumOperands() >= 2 && "Expected at least 2 operands");
+    const Function *F = cast<Function>(
+        cast<ConstantAsMetadata>(MDN->getOperand(0))->getValue());
+    const auto EM =
+        cast<ConstantInt>(
+            cast<ConstantAsMetadata>(MDN->getOperand(1))->getValue())
+            ->getZExtValue();
+    if (EM == SPIRV::ExecutionMode::FPFastMathDefault) {
+      assert(MDN->getNumOperands() == 4 &&
+             "Expected 4 operands for FPFastMathDefault");
+
+      const Type *T = cast<ValueAsMetadata>(MDN->getOperand(2))->getType();
+      unsigned Flags =
+          cast<ConstantInt>(
+              cast<ConstantAsMetadata>(MDN->getOperand(3))->getValue())
+              ->getZExtValue();
+      SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
+          getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
+      SPIRV::FPFastMathDefaultInfo &Info =
+          getFPFastMathDefaultInfo(FPFastMathDefaultInfoVec, T);
+      Info.FastMathFlags = Flags;
+      Info.FPFastMathDefault = true;
+    } else if (EM == SPIRV::ExecutionMode::ContractionOff) {
+      assert(MDN->getNumOperands() == 2 &&
+             "Expected no additional operands for ContractionOff");
+
+      // We need to save this info for every FP type we track, i.e. {half,
+      // float, double}.
+      SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
+          getOrCreateFPFastMathDefaultInfoVec(M, MAI, F);
+      for (SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) {
+        Info.ContractionOff = true;
+      }
+    } else if (EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) {
+      assert(MDN->getNumOperands() == 3 &&
+             "Expected 1 additional operand for SignedZeroInfNanPreserve");
+      unsigned TargetWidth =
+          cast<ConstantInt>(
+              cast<ConstantAsMetadata>(MDN->getOperand(2))->getValue())
+              ->getZExtValue();
+      // We need to save this info only for the FP type with TargetWidth.
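+      // E.g. a "SignedZeroInfNanPreserve 32" entry marks only the float
+      // element (index 1) of the per-function vector.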
+ SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec = + getOrCreateFPFastMathDefaultInfoVec(M, MAI, F); + int Index = SPIRV::FPFastMathDefaultInfoVector:: + computeFPFastMathDefaultInfoVecIndex(TargetWidth); + assert(Index >= 0 && Index < 3 && + "Expected FPFastMathDefaultInfo for half, float, or double"); + assert(FPFastMathDefaultInfoVec.size() == 3 && + "Expected FPFastMathDefaultInfoVec to have exactly 3 elements"); + FPFastMathDefaultInfoVec[Index].SignedZeroInfNanPreserve = true; + } + } +} + struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI; void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { @@ -2209,7 +2485,8 @@ bool SPIRVModuleAnalysis::runOnModule(Module &M) { patchPhis(M, GR, *TII, MMI); addMBBNames(M, *TII, MMI, *ST, MAI); - addDecorations(M, *TII, MMI, *ST, MAI); + collectFPFastMathDefaults(M, MAI, *ST); + addDecorations(M, *TII, MMI, *ST, MAI, GR); collectReqs(M, MAI, MMI, *ST); diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h index 41c792a..d8376cd 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h @@ -159,6 +159,13 @@ struct ModuleAnalysisInfo { InstrList MS[NUM_MODULE_SECTIONS]; // The table maps MBB number to SPIR-V unique ID register. DenseMap<std::pair<const MachineFunction *, int>, MCRegister> BBNumToRegMap; + // The table maps function pointers to their default FP fast math info. It can + // be assumed that the SmallVector is sorted by the bit width of the type. The + // first element is the smallest bit width, and the last element is the + // largest bit width, therefore, we will have {half, float, double} in + // the order of their bit widths. + DenseMap<const Function *, SPIRV::FPFastMathDefaultInfoVector> + FPFastMathDefaultInfoMap; MCRegister getFuncReg(const Function *F) { assert(F && "Function is null"); diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index 1a08c6a..db6f2d6 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -839,6 +839,7 @@ static uint32_t convertFloatToSPIRVWord(float F) { static void insertSpirvDecorations(MachineFunction &MF, SPIRVGlobalRegistry *GR, MachineIRBuilder MIB) { + const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(MIB.getMF().getSubtarget()); SmallVector<MachineInstr *, 10> ToErase; for (MachineBasicBlock &MBB : MF) { for (MachineInstr &MI : MBB) { @@ -849,7 +850,7 @@ static void insertSpirvDecorations(MachineFunction &MF, SPIRVGlobalRegistry *GR, MIB.setInsertPt(*MI.getParent(), MI.getNextNode()); if (isSpvIntrinsic(MI, Intrinsic::spv_assign_decoration)) { buildOpSpirvDecorations(MI.getOperand(1).getReg(), MIB, - MI.getOperand(2).getMetadata()); + MI.getOperand(2).getMetadata(), ST); } else if (isSpvIntrinsic(MI, Intrinsic::spv_assign_fpmaxerror_decoration)) { ConstantFP *OpV = mdconst::dyn_extract<ConstantFP>( diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index 66ce5a2..6a32dba 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -802,6 +802,7 @@ defm RoundingModeRTPINTEL : ExecutionModeOperand<5620, [RoundToInfinityINTEL]>; defm RoundingModeRTNINTEL : ExecutionModeOperand<5621, [RoundToInfinityINTEL]>; defm FloatingPointModeALTINTEL : ExecutionModeOperand<5622, [FloatingPointModeINTEL]>; defm FloatingPointModeIEEEINTEL : ExecutionModeOperand<5623, 
[FloatingPointModeINTEL]>;
+defm FPFastMathDefault : ExecutionModeOperand<6028, [FloatControls2]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define StorageClass enum values and at the same time
@@ -1153,6 +1154,9 @@ defm NotInf : FPFastMathModeOperand<0x2, [Kernel]>;
 defm NSZ : FPFastMathModeOperand<0x4, [Kernel]>;
 defm AllowRecip : FPFastMathModeOperand<0x8, [Kernel]>;
 defm Fast : FPFastMathModeOperand<0x10, [Kernel]>;
+defm AllowContract : FPFastMathModeOperand<0x10000, [FloatControls2]>;
+defm AllowReassoc : FPFastMathModeOperand<0x20000, [FloatControls2]>;
+defm AllowTransform : FPFastMathModeOperand<0x40000, [FloatControls2]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define FPRoundingMode enum values and at the same time
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 820e56b..327c011 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -181,7 +181,7 @@ void buildOpMemberDecorate(Register Reg, MachineInstr &I,
 }
 
 void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder,
-                             const MDNode *GVarMD) {
+                             const MDNode *GVarMD, const SPIRVSubtarget &ST) {
   for (unsigned I = 0, E = GVarMD->getNumOperands(); I != E; ++I) {
     auto *OpMD = dyn_cast<MDNode>(GVarMD->getOperand(I));
     if (!OpMD)
@@ -193,6 +193,20 @@ void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder,
     if (!DecorationId)
       report_fatal_error("Expect SPIR-V <Decoration> operand to be the first "
                          "element of the decoration");
+
+    // The goal of `spirv.Decorations` metadata is to provide a way to
+    // represent SPIR-V entities that do not map to LLVM in an obvious way.
+    // FP flags do have obvious matches between LLVM IR and SPIR-V.
+    // Additionally, we have no guarantee at this point that the optimizer
+    // passes have not already violated the flags passed through the
+    // decoration. Therefore, we simply ignore FP flags, including
+    // NoContraction and FPFastMathMode.
+    if (DecorationId->getZExtValue() ==
+            static_cast<uint32_t>(SPIRV::Decoration::NoContraction) ||
+        DecorationId->getZExtValue() ==
+            static_cast<uint32_t>(SPIRV::Decoration::FPFastMathMode)) {
+      continue; // Ignored.
+    }
     auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
                    .addUse(Reg)
                    .addImm(static_cast<uint32_t>(DecorationId->getZExtValue()));
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 45c520a..409a0fd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -113,6 +113,54 @@ public:
                             std::function<bool(BasicBlock *)> Op);
 };
 
+namespace SPIRV {
+struct FPFastMathDefaultInfo {
+  const Type *Ty = nullptr;
+  unsigned FastMathFlags = 0;
+  // With SPV_KHR_float_controls2, ContractionOff and SignedZeroInfNanPreserve
+  // are deprecated, and we replace them with FPFastMathDefault and the
+  // appropriate flags instead. However, we have no guarantee about the order
+  // in which we will process execution modes. Therefore it could happen that
+  // we first process ContractionOff, setting the AllowContract bit to 0, and
+  // then process FPFastMathDefault enabling the AllowContract bit, effectively
+  // invalidating ContractionOff. Because of that, it's best to keep separate
+  // bits for the different execution modes and combine them later when we
+  // emit OpExecutionMode instructions.
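+  // Hypothetical ordering hazard: processing ContractionOff first would clear
+  // AllowContract, and a later FPFastMathDefault could silently set it again;
+  // keeping separate bits lets the emitter detect and report such conflicts.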
+ bool ContractionOff = false; + bool SignedZeroInfNanPreserve = false; + bool FPFastMathDefault = false; + + FPFastMathDefaultInfo() = default; + FPFastMathDefaultInfo(const Type *Ty, unsigned FastMathFlags) + : Ty(Ty), FastMathFlags(FastMathFlags) {} + bool operator==(const FPFastMathDefaultInfo &Other) const { + return Ty == Other.Ty && FastMathFlags == Other.FastMathFlags && + ContractionOff == Other.ContractionOff && + SignedZeroInfNanPreserve == Other.SignedZeroInfNanPreserve && + FPFastMathDefault == Other.FPFastMathDefault; + } +}; + +struct FPFastMathDefaultInfoVector + : public SmallVector<SPIRV::FPFastMathDefaultInfo, 3> { + static size_t computeFPFastMathDefaultInfoVecIndex(size_t BitWidth) { + switch (BitWidth) { + case 16: // half + return 0; + case 32: // float + return 1; + case 64: // double + return 2; + default: + report_fatal_error("Expected BitWidth to be 16, 32, 64", false); + } + llvm_unreachable( + "Unreachable code in computeFPFastMathDefaultInfoVecIndex"); + } +}; + +} // namespace SPIRV + // Add the given string as a series of integer operand, inserting null // terminators and padding to make sure the operands all have 32-bit // little-endian words. @@ -161,7 +209,7 @@ void buildOpMemberDecorate(Register Reg, MachineInstr &I, // Add an OpDecorate instruction by "spirv.Decorations" metadata node. void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder, - const MDNode *GVarMD); + const MDNode *GVarMD, const SPIRVSubtarget &ST); // Return a valid position for the OpVariable instruction inside a function, // i.e., at the beginning of the first block of the function. @@ -508,6 +556,5 @@ unsigned getArrayComponentCount(const MachineRegisterInfo *MRI, const MachineInstr *ResType); MachineBasicBlock::iterator getFirstValidInstructionInsertPoint(MachineBasicBlock &BB); - } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index 64b9dc3..163bf9b 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -186,7 +186,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering( // SIMD-specific configuration if (Subtarget->hasSIMD128()) { - // Combine partial.reduce.add before legalization gets confused. setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN); // Combine wide-vector muls, with extend inputs, to extmul_half. @@ -317,6 +316,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering( setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom); setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom); } + + // Partial MLA reductions. 
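+    // (These make ISD::PARTIAL_REDUCE_SMLA/UMLA legal for v16i8/v8i16 inputs
+    // accumulating into v4i32; the tablegen patterns then select, e.g., the
+    // signed v8i16 case to i32x4.dot_i16x8_s followed by i32x4.add.)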
+ for (auto Op : {ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA}) { + setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v16i8, Legal); + setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v8i16, Legal); + } } // As a special case, these operators use the type to mean the type to @@ -416,41 +421,6 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL, return TargetLowering::getPointerMemTy(DL, AS); } -bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic( - const IntrinsicInst *I) const { - if (I->getIntrinsicID() != Intrinsic::vector_partial_reduce_add) - return true; - - EVT VT = EVT::getEVT(I->getType()); - if (VT.getSizeInBits() > 128) - return true; - - auto Op1 = I->getOperand(1); - - if (auto *InputInst = dyn_cast<Instruction>(Op1)) { - unsigned Opcode = InstructionOpcodeToISD(InputInst->getOpcode()); - if (Opcode == ISD::MUL) { - if (isa<Instruction>(InputInst->getOperand(0)) && - isa<Instruction>(InputInst->getOperand(1))) { - // dot only supports signed inputs but also support lowering unsigned. - if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() != - cast<Instruction>(InputInst->getOperand(1))->getOpcode()) - return true; - - EVT Op1VT = EVT::getEVT(Op1->getType()); - if (Op1VT.getVectorElementType() == VT.getVectorElementType() && - ((VT.getVectorElementCount() * 2 == - Op1VT.getVectorElementCount()) || - (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount()))) - return false; - } - } else if (ISD::isExtOpcode(Opcode)) { - return false; - } - } - return true; -} - TargetLowering::AtomicExpansionKind WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const { // We have wasm instructions for these @@ -2113,106 +2083,6 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op, MachinePointerInfo(SV)); } -// Try to lower partial.reduce.add to a dot or fallback to a sequence with -// extmul and adds. -SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) { - assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN); - if (N->getConstantOperandVal(0) != Intrinsic::vector_partial_reduce_add) - return SDValue(); - - assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32"); - SDLoc DL(N); - - SDValue Input = N->getOperand(2); - if (Input->getOpcode() == ISD::MUL) { - SDValue ExtendLHS = Input->getOperand(0); - SDValue ExtendRHS = Input->getOperand(1); - assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) && - ISD::isExtOpcode(ExtendRHS.getOpcode())) && - "expected widening mul or add"); - assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() && - "expected binop to use the same extend for both operands"); - - SDValue ExtendInLHS = ExtendLHS->getOperand(0); - SDValue ExtendInRHS = ExtendRHS->getOperand(0); - bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND; - unsigned LowOpc = - IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U; - unsigned HighOpc = IsSigned ? 
WebAssemblyISD::EXTEND_HIGH_S - : WebAssemblyISD::EXTEND_HIGH_U; - SDValue LowLHS; - SDValue LowRHS; - SDValue HighLHS; - SDValue HighRHS; - - auto AssignInputs = [&](MVT VT) { - LowLHS = DAG.getNode(LowOpc, DL, VT, ExtendInLHS); - LowRHS = DAG.getNode(LowOpc, DL, VT, ExtendInRHS); - HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS); - HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS); - }; - - if (ExtendInLHS->getValueType(0) == MVT::v8i16) { - if (IsSigned) { - // i32x4.dot_i16x8_s - SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, - ExtendInLHS, ExtendInRHS); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot); - } - - // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs))) - MVT VT = MVT::v4i32; - AssignInputs(VT); - SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS); - SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS); - SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh); - return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Add); - } else { - assert(ExtendInLHS->getValueType(0) == MVT::v16i8 && - "expected v16i8 input types"); - AssignInputs(MVT::v8i16); - // Lower to a wider tree, using twice the operations compared to above. - if (IsSigned) { - // Use two dots - SDValue DotLHS = - DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS); - SDValue DotRHS = - DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS); - SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add); - } - - SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS); - SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS); - - SDValue AddLow = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, - MVT::v4i32, MulLow); - SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, - MVT::v4i32, MulHigh); - SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add); - } - } else { - // Accumulate the input using extadd_pairwise. - assert(ISD::isExtOpcode(Input.getOpcode()) && "expected extend"); - bool IsSigned = Input->getOpcode() == ISD::SIGN_EXTEND; - unsigned PairwiseOpc = IsSigned ? 
WebAssemblyISD::EXT_ADD_PAIRWISE_S - : WebAssemblyISD::EXT_ADD_PAIRWISE_U; - SDValue ExtendIn = Input->getOperand(0); - if (ExtendIn->getValueType(0) == MVT::v8i16) { - SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendIn); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add); - } - - assert(ExtendIn->getValueType(0) == MVT::v16i8 && - "expected v16i8 input types"); - SDValue Add = - DAG.getNode(PairwiseOpc, DL, MVT::v4i32, - DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendIn)); - return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add); - } -} - SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op, SelectionDAG &DAG) const { MachineFunction &MF = DAG.getMachineFunction(); @@ -3683,11 +3553,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N, return performVectorTruncZeroCombine(N, DCI); case ISD::TRUNCATE: return performTruncateCombine(N, DCI); - case ISD::INTRINSIC_WO_CHAIN: { - if (auto AnyAllCombine = performAnyAllCombine(N, DCI.DAG)) - return AnyAllCombine; - return performLowerPartialReduction(N, DCI.DAG); - } + case ISD::INTRINSIC_WO_CHAIN: + return performAnyAllCombine(N, DCI.DAG); case ISD::MUL: return performMulCombine(N, DCI); } diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h index 72401a7..b33a853 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h @@ -45,8 +45,6 @@ private: /// right decision when generating code for different targets. const WebAssemblySubtarget *Subtarget; - bool - shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override; AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *) const override; bool shouldScalarizeBinop(SDValue VecOp) const override; FastISel *createFastISel(FunctionLoweringInfo &FuncInfo, @@ -89,8 +87,7 @@ private: bool CanLowerReturn(CallingConv::ID CallConv, MachineFunction &MF, bool isVarArg, const SmallVectorImpl<ISD::OutputArg> &Outs, - LLVMContext &Context, - const Type *RetTy) const override; + LLVMContext &Context, const Type *RetTy) const override; SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg, const SmallVectorImpl<ISD::OutputArg> &Outs, const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl, diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td index d8948ad..1306026 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td @@ -1505,6 +1505,51 @@ defm Q15MULR_SAT_S : SIMDBinary<I16x8, int_wasm_q15mulr_sat_signed, "q15mulr_sat_s", 0x82>; //===----------------------------------------------------------------------===// +// Partial reductions, using: dot, extmul and extadd_pairwise +//===----------------------------------------------------------------------===// +// MLA: v8i16 -> v4i32 +def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$lhs), + (v8i16 V128:$rhs))), + (ADD_I32x4 (DOT $lhs, $rhs), $acc)>; +def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$lhs), + (v8i16 V128:$rhs))), + (ADD_I32x4 (ADD_I32x4 (EXTMUL_LOW_U_I32x4 $lhs, $rhs), + (EXTMUL_HIGH_U_I32x4 $lhs, $rhs)), + $acc)>; +// MLA: v16i8 -> v4i32 +def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v16i8 V128:$lhs), + (v16i8 V128:$rhs))), + (ADD_I32x4 (ADD_I32x4 (DOT (extend_low_s_I16x8 $lhs), + (extend_low_s_I16x8 $rhs)), + (DOT 
(extend_high_s_I16x8 $lhs), + (extend_high_s_I16x8 $rhs))), + $acc)>; +def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v16i8 V128:$lhs), + (v16i8 V128:$rhs))), + (ADD_I32x4 (ADD_I32x4 (extadd_pairwise_u_I32x4 (EXTMUL_LOW_U_I16x8 $lhs, $rhs)), + (extadd_pairwise_u_I32x4 (EXTMUL_HIGH_U_I16x8 $lhs, $rhs))), + $acc)>; + +// Accumulate: v8i16 -> v4i32 +def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$in), + (I16x8.splat (i32 1)))), + (ADD_I32x4 (extadd_pairwise_s_I32x4 $in), $acc)>; + +def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$in), + (I16x8.splat (i32 1)))), + (ADD_I32x4 (extadd_pairwise_u_I32x4 $in), $acc)>; + +// Accumulate: v16i8 -> v4i32 +def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v16i8 V128:$in), + (I8x16.splat (i32 1)))), + (ADD_I32x4 (extadd_pairwise_s_I32x4 (extadd_pairwise_s_I16x8 $in)), + $acc)>; +def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v16i8 V128:$in), + (I8x16.splat (i32 1)))), + (ADD_I32x4 (extadd_pairwise_u_I32x4 (extadd_pairwise_u_I16x8 $in)), + $acc)>; + +//===----------------------------------------------------------------------===// // Relaxed swizzle //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 292eab7..cd04ff5 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -45169,6 +45169,7 @@ bool X86TargetLowering::isGuaranteedNotToBeUndefOrPoisonForTargetNode( case X86ISD::Wrapper: case X86ISD::WrapperRIP: return true; + case X86ISD::INSERTPS: case X86ISD::BLENDI: case X86ISD::PSHUFB: case X86ISD::PSHUFD: @@ -45239,6 +45240,7 @@ bool X86TargetLowering::canCreateUndefOrPoisonForTargetNode( case X86ISD::BLENDV: return false; // SSE target shuffles. + case X86ISD::INSERTPS: case X86ISD::PSHUFB: case X86ISD::PSHUFD: case X86ISD::UNPCKL: diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index ee1fec0..805bdb4 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -1350,6 +1350,10 @@ static bool foldMemChr(CallInst *Call, DomTreeUpdater *DTU, BB->getTerminator()->eraseFromParent(); SwitchInst *SI = IRB.CreateSwitch( IRB.CreateTrunc(Call->getArgOperand(1), ByteTy), BBNext, N); + // We can't know the precise weights here, as they would depend on the value + // distribution of Call->getArgOperand(1). So we just mark it as "unknown". + setExplicitlyUnknownBranchWeightsIfProfiled(*SI, *Call->getFunction(), + DEBUG_TYPE); Type *IndexTy = DL.getIndexType(Call->getType()); SmallVector<DominatorTree::UpdateType, 8> Updates; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index d1ca0a6..59e103cd 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -880,11 +880,11 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { // zext(bool) + C -> bool ? C + 1 : C if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->getScalarSizeInBits() == 1) - return SelectInst::Create(X, InstCombiner::AddOne(Op1C), Op1); + return createSelectInst(X, InstCombiner::AddOne(Op1C), Op1); // sext(bool) + C -> bool ? 
C - 1 : C if (match(Op0, m_SExt(m_Value(X))) && X->getType()->getScalarSizeInBits() == 1) - return SelectInst::Create(X, InstCombiner::SubOne(Op1C), Op1); + return createSelectInst(X, InstCombiner::SubOne(Op1C), Op1); // ~X + C --> (C-1) - X if (match(Op0, m_Not(m_Value(X)))) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 7a979c1..4f94aa2 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -23,6 +23,7 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Value.h" #include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" @@ -62,14 +63,14 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final public InstVisitor<InstCombinerImpl, Instruction *> { public: InstCombinerImpl(InstructionWorklist &Worklist, BuilderTy &Builder, - bool MinimizeSize, AAResults *AA, AssumptionCache &AC, + Function &F, AAResults *AA, AssumptionCache &AC, TargetLibraryInfo &TLI, TargetTransformInfo &TTI, DominatorTree &DT, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI, ProfileSummaryInfo *PSI, const DataLayout &DL, ReversePostOrderTraversal<BasicBlock *> &RPOT) - : InstCombiner(Worklist, Builder, MinimizeSize, AA, AC, TLI, TTI, DT, ORE, - BFI, BPI, PSI, DL, RPOT) {} + : InstCombiner(Worklist, Builder, F, AA, AC, TLI, TTI, DT, ORE, BFI, BPI, + PSI, DL, RPOT) {} virtual ~InstCombinerImpl() = default; @@ -469,6 +470,17 @@ private: Value *simplifyNonNullOperand(Value *V, bool HasDereferenceable, unsigned Depth = 0); + SelectInst *createSelectInst(Value *C, Value *S1, Value *S2, + const Twine &NameStr = "", + InsertPosition InsertBefore = nullptr, + Instruction *MDFrom = nullptr) { + SelectInst *SI = + SelectInst::Create(C, S1, S2, NameStr, InsertBefore, MDFrom); + if (!MDFrom) + setExplicitlyUnknownBranchWeightsIfProfiled(*SI, F, DEBUG_TYPE); + return SI; + } + public: /// Create and insert the idiom we use to indicate a block is unreachable /// without having to rewrite the CFG from within InstCombine. 
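For illustration, a minimal stand-alone sketch of the pattern the InstCombine changes above adopt: any fold that materializes a new select goes through a wrapper, so profiled functions get explicitly-unknown branch weights rather than no metadata. This is a sketch under assumptions, not part of the patch; createSelectForFold is an invented name, while setExplicitlyUnknownBranchWeightsIfProfiled is the llvm/IR/ProfDataUtils.h helper the patch itself calls.

#include "llvm/IR/Instructions.h"
#include "llvm/IR/ProfDataUtils.h"
using namespace llvm;

// Hypothetical free-standing version of the createSelectInst wrapper above.
static SelectInst *createSelectForFold(Value *Cond, Value *TVal, Value *FVal,
                                       Function &F, StringRef PassName) {
  SelectInst *SI = SelectInst::Create(Cond, TVal, FVal);
  // The condition's value distribution is unknowable at transform time, so
  // record the branch weights as explicitly unknown instead of omitting them.
  setExplicitlyUnknownBranchWeightsIfProfiled(*SI, F, PassName);
  return SI;
}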
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 550f095..d457e0c 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -1253,7 +1253,7 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { // shl (zext i1 X), C1 --> select (X, 1 << C1, 0) if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { auto *NewC = Builder.CreateShl(ConstantInt::get(Ty, 1), C1); - return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty)); + return createSelectInst(X, NewC, ConstantInt::getNullValue(Ty)); } } diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index f0ddd5c..8fbaf68 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -1735,7 +1735,7 @@ Instruction *InstCombinerImpl::foldBinopOfSextBoolToSelect(BinaryOperator &BO) { Constant *Zero = ConstantInt::getNullValue(BO.getType()); Value *TVal = Builder.CreateBinOp(BO.getOpcode(), Ones, C); Value *FVal = Builder.CreateBinOp(BO.getOpcode(), Zero, C); - return SelectInst::Create(X, TVal, FVal); + return createSelectInst(X, TVal, FVal); } static Value *simplifyOperationIntoSelectOperand(Instruction &I, SelectInst *SI, @@ -5934,8 +5934,8 @@ static bool combineInstructionsOverFunction( LLVM_DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on " << F.getName() << "\n"); - InstCombinerImpl IC(Worklist, Builder, F.hasMinSize(), AA, AC, TLI, TTI, DT, - ORE, BFI, BPI, PSI, DL, RPOT); + InstCombinerImpl IC(Worklist, Builder, F, AA, AC, TLI, TTI, DT, ORE, BFI, + BPI, PSI, DL, RPOT); IC.MaxArraySizeForCombine = MaxArraySize; bool MadeChangeInThisIteration = IC.prepareWorklist(F); MadeChangeInThisIteration |= IC.run(); diff --git a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp index e5bf2d1..d842275 100644 --- a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp +++ b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp @@ -35,6 +35,7 @@ #include "llvm/Support/FileSystem.h" #include "llvm/Support/Path.h" #include "llvm/Support/Regex.h" +#include "llvm/Support/VirtualFileSystem.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation/CFGMST.h" #include "llvm/Transforms/Instrumentation/GCOVProfiler.h" @@ -92,8 +93,10 @@ class GCOVFunction; class GCOVProfiler { public: - GCOVProfiler() : GCOVProfiler(GCOVOptions::getDefault()) {} - GCOVProfiler(const GCOVOptions &Opts) : Options(Opts) {} + GCOVProfiler() + : GCOVProfiler(GCOVOptions::getDefault(), *vfs::getRealFileSystem()) {} + GCOVProfiler(const GCOVOptions &Opts, vfs::FileSystem &VFS) + : Options(Opts), VFS(VFS) {} bool runOnModule(Module &M, function_ref<BlockFrequencyInfo *(Function &F)> GetBFI, function_ref<BranchProbabilityInfo *(Function &F)> GetBPI, @@ -110,6 +113,7 @@ public: os->write_zeros(4 - s.size() % 4); } void writeBytes(const char *Bytes, int Size) { os->write(Bytes, Size); } + vfs::FileSystem &getVirtualFileSystem() const { return VFS; } private: // Create the .gcno files for the Module based on DebugInfo. 
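The point of injecting a vfs::FileSystem into GCOVProfiler is that every filesystem probe the profiler makes can be redirected through a virtual filesystem. A minimal sketch of the lookup pattern, under the assumption that resolvePath is an invented helper name; the actual change to getFilename appears in the hunks that follow.

#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/VirtualFileSystem.h"
using namespace llvm;

// Sketch: prefer the relative path when the (possibly virtual) filesystem can
// see it; otherwise rebuild the absolute path from the scope's directory.
static SmallString<128> resolvePath(StringRef RelPath, StringRef Dir,
                                    vfs::FileSystem &VFS) {
  SmallString<128> Path;
  if (VFS.exists(RelPath))
    Path = RelPath;
  else
    sys::path::append(Path, Dir, RelPath);
  return Path;
}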
@@ -166,6 +170,7 @@ private: std::vector<Regex> ExcludeRe; DenseSet<const BasicBlock *> ExecBlocks; StringMap<bool> InstrumentedFiles; + vfs::FileSystem &VFS; }; struct BBInfo { @@ -214,10 +219,10 @@ static StringRef getFunctionName(const DISubprogram *SP) { /// Prefer relative paths in the coverage notes. Clang also may split /// up absolute paths into a directory and filename component. When /// the relative path doesn't exist, reconstruct the absolute path. -static SmallString<128> getFilename(const DIScope *SP) { +static SmallString<128> getFilename(const DIScope *SP, vfs::FileSystem &VFS) { SmallString<128> Path; StringRef RelPath = SP->getFilename(); - if (sys::fs::exists(RelPath)) + if (VFS.exists(RelPath)) Path = RelPath; else sys::path::append(Path, SP->getDirectory(), SP->getFilename()); @@ -357,7 +362,7 @@ namespace { void writeOut(uint32_t CfgChecksum) { write(GCOV_TAG_FUNCTION); - SmallString<128> Filename = getFilename(SP); + SmallString<128> Filename = getFilename(SP, P->getVirtualFileSystem()); uint32_t BlockLen = 3 + wordsOfString(getFunctionName(SP)); BlockLen += 1 + wordsOfString(Filename) + 4; @@ -455,7 +460,7 @@ bool GCOVProfiler::isFunctionInstrumented(const Function &F) { if (FilterRe.empty() && ExcludeRe.empty()) { return true; } - SmallString<128> Filename = getFilename(F.getSubprogram()); + SmallString<128> Filename = getFilename(F.getSubprogram(), VFS); auto It = InstrumentedFiles.find(Filename); if (It != InstrumentedFiles.end()) { return It->second; @@ -467,7 +472,7 @@ bool GCOVProfiler::isFunctionInstrumented(const Function &F) { // Path can be // /usr/lib/gcc/x86_64-linux-gnu/8/../../../../include/c++/8/bits/*.h so for // such a case we must get the real_path. - if (sys::fs::real_path(Filename, RealPath)) { + if (VFS.getRealPath(Filename, RealPath)) { // real_path can fail with path like "foo.c". RealFilename = Filename; } else { @@ -524,9 +529,10 @@ std::string GCOVProfiler::mangleName(const DICompileUnit *CU, SmallString<128> Filename = CU->getFilename(); sys::path::replace_extension(Filename, Notes ? "gcno" : "gcda"); StringRef FName = sys::path::filename(Filename); - SmallString<128> CurPath; - if (sys::fs::current_path(CurPath)) + ErrorOr<std::string> CWD = VFS.getCurrentWorkingDirectory(); + if (!CWD) return std::string(FName); + SmallString<128> CurPath{*CWD}; sys::path::append(CurPath, FName); return std::string(CurPath); } @@ -554,7 +560,7 @@ bool GCOVProfiler::runOnModule( PreservedAnalyses GCOVProfilerPass::run(Module &M, ModuleAnalysisManager &AM) { - GCOVProfiler Profiler(GCOVOpts); + GCOVProfiler Profiler(GCOVOpts, *VFS); FunctionAnalysisManager &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); @@ -789,7 +795,7 @@ bool GCOVProfiler::emitProfileNotes( // Add the function line number to the lines of the entry block // to have a counter for the function definition. 
uint32_t Line = SP->getLine(); - auto Filename = getFilename(SP); + auto Filename = getFilename(SP, VFS); BranchProbabilityInfo *BPI = GetBPI(F); BlockFrequencyInfo *BFI = GetBFI(F); @@ -881,7 +887,7 @@ bool GCOVProfiler::emitProfileNotes( if (SP != getDISubprogram(Scope)) continue; - GCOVLines &Lines = Block.getFile(getFilename(Loc->getScope())); + GCOVLines &Lines = Block.getFile(getFilename(Loc->getScope(), VFS)); Lines.addLine(Loc.getLine()); } Line = 0; diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index b988957..cf076b9a 100644 --- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -5810,10 +5810,22 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { case Intrinsic::x86_avx512_vpdpbusds_512: case Intrinsic::x86_avx2_vpdpbssd_128: case Intrinsic::x86_avx2_vpdpbssd_256: + case Intrinsic::x86_avx10_vpdpbssd_512: case Intrinsic::x86_avx2_vpdpbssds_128: case Intrinsic::x86_avx2_vpdpbssds_256: - case Intrinsic::x86_avx10_vpdpbssd_512: case Intrinsic::x86_avx10_vpdpbssds_512: + case Intrinsic::x86_avx2_vpdpbsud_128: + case Intrinsic::x86_avx2_vpdpbsud_256: + case Intrinsic::x86_avx10_vpdpbsud_512: + case Intrinsic::x86_avx2_vpdpbsuds_128: + case Intrinsic::x86_avx2_vpdpbsuds_256: + case Intrinsic::x86_avx10_vpdpbsuds_512: + case Intrinsic::x86_avx2_vpdpbuud_128: + case Intrinsic::x86_avx2_vpdpbuud_256: + case Intrinsic::x86_avx10_vpdpbuud_512: + case Intrinsic::x86_avx2_vpdpbuuds_128: + case Intrinsic::x86_avx2_vpdpbuuds_256: + case Intrinsic::x86_avx10_vpdpbuuds_512: handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/4, /*EltSize=*/8); break; diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index ab5c9c9..12fb46d 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -1762,9 +1762,10 @@ public: GeneratedRTChecks(PredicatedScalarEvolution &PSE, DominatorTree *DT, LoopInfo *LI, TargetTransformInfo *TTI, const DataLayout &DL, TTI::TargetCostKind CostKind) - : DT(DT), LI(LI), TTI(TTI), SCEVExp(*PSE.getSE(), DL, "scev.check"), - MemCheckExp(*PSE.getSE(), DL, "scev.check"), PSE(PSE), - CostKind(CostKind) {} + : DT(DT), LI(LI), TTI(TTI), + SCEVExp(*PSE.getSE(), DL, "scev.check", /*PreserveLCSSA=*/false), + MemCheckExp(*PSE.getSE(), DL, "scev.check", /*PreserveLCSSA=*/false), + PSE(PSE), CostKind(CostKind) {} /// Generate runtime checks in SCEVCheckBlock and MemCheckBlock, so we can /// accurately estimate the cost of the runtime checks. The blocks are @@ -3902,8 +3903,7 @@ void LoopVectorizationPlanner::emitInvalidCostRemarks( if (VF.isScalar()) continue; - VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind, - *CM.PSE.getSE()); + VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind); precomputeCosts(*Plan, VF, CostCtx); auto Iter = vp_depth_first_deep(Plan->getVectorLoopRegion()->getEntry()); for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Iter)) { @@ -4160,8 +4160,7 @@ VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor() { // Add on other costs that are modelled in VPlan, but not in the legacy // cost model. 
- VPCostContext CostCtx(CM.TTI, *CM.TLI, *P, CM, CM.CostKind, - *CM.PSE.getSE()); + VPCostContext CostCtx(CM.TTI, *CM.TLI, *P, CM, CM.CostKind); VPRegionBlock *VectorRegion = P->getVectorLoopRegion(); assert(VectorRegion && "Expected to have a vector region!"); for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( @@ -6836,7 +6835,7 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF, InstructionCost LoopVectorizationPlanner::cost(VPlan &Plan, ElementCount VF) const { - VPCostContext CostCtx(CM.TTI, *CM.TLI, Plan, CM, CM.CostKind, *PSE.getSE()); + VPCostContext CostCtx(CM.TTI, *CM.TLI, Plan, CM, CM.CostKind); InstructionCost Cost = precomputeCosts(Plan, VF, CostCtx); // Now compute and add the VPlan-based cost. @@ -7069,8 +7068,7 @@ VectorizationFactor LoopVectorizationPlanner::computeBestVF() { // simplifications not accounted for in the legacy cost model. If that's the // case, don't trigger the assertion, as the extra simplifications may cause a // different VF to be picked by the VPlan-based cost model. - VPCostContext CostCtx(CM.TTI, *CM.TLI, BestPlan, CM, CM.CostKind, - *CM.PSE.getSE()); + VPCostContext CostCtx(CM.TTI, *CM.TLI, BestPlan, CM, CM.CostKind); precomputeCosts(BestPlan, BestFactor.Width, CostCtx); // Verify that the VPlan-based and legacy cost models agree, except for VPlans // with early exits and plans with additional VPlan simplifications. The @@ -7486,12 +7484,13 @@ VPRecipeBuilder::tryToWidenMemory(Instruction *I, ArrayRef<VPValue *> Operands, VPSingleDefRecipe *VectorPtr; if (Reverse) { // When folding the tail, we may compute an address that we don't in the - // original scalar loop and it may not be inbounds. Drop Inbounds in that - // case. + // original scalar loop: drop the GEP no-wrap flags in this case. + // Otherwise preserve existing flags without no-unsigned-wrap, as we will + // emit negative indices. GEPNoWrapFlags Flags = - (CM.foldTailByMasking() || !GEP || !GEP->isInBounds()) + CM.foldTailByMasking() || !GEP ? GEPNoWrapFlags::none() - : GEPNoWrapFlags::inBounds(); + : GEP->getNoWrapFlags().withoutNoUnsignedWrap(); VectorPtr = new VPVectorEndPointerRecipe(Ptr, &Plan.getVF(), getLoadStoreType(I), /*Stride*/ -1, Flags, I->getDebugLoc()); @@ -8163,14 +8162,12 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, VFRange SubRange = {VF, MaxVFTimes2}; if (auto Plan = tryToBuildVPlanWithVPRecipes( std::unique_ptr<VPlan>(VPlan0->duplicate()), SubRange, &LVer)) { - bool HasScalarVF = Plan->hasScalarVFOnly(); // Now optimize the initial VPlan. - if (!HasScalarVF) - VPlanTransforms::runPass(VPlanTransforms::truncateToMinimalBitwidths, - *Plan, CM.getMinimalBitwidths()); + VPlanTransforms::runPass(VPlanTransforms::truncateToMinimalBitwidths, + *Plan, CM.getMinimalBitwidths()); VPlanTransforms::runPass(VPlanTransforms::optimize, *Plan); // TODO: try to put it close to addActiveLaneMask(). - if (CM.foldTailWithEVL() && !HasScalarVF) + if (CM.foldTailWithEVL()) VPlanTransforms::runPass(VPlanTransforms::addExplicitVectorLength, *Plan, CM.getMaxSafeElements()); assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid"); @@ -8600,8 +8597,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( // TODO: Enable following transform when the EVL-version of extended-reduction // and mulacc-reduction are implemented. 
if (!CM.foldTailWithEVL()) { - VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind, - *CM.PSE.getSE()); + VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind); VPlanTransforms::runPass(VPlanTransforms::convertToAbstractRecipes, *Plan, CostCtx, Range); } @@ -10058,7 +10054,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { bool ForceVectorization = Hints.getForce() == LoopVectorizeHints::FK_Enabled; VPCostContext CostCtx(CM.TTI, *CM.TLI, LVP.getPlanFor(VF.Width), CM, - CM.CostKind, *CM.PSE.getSE()); + CM.CostKind); if (!ForceVectorization && !isOutsideLoopWorkProfitable(Checks, VF, L, PSE, CostCtx, LVP.getPlanFor(VF.Width), SEL, diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 065622e..f77d587 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -1100,7 +1100,9 @@ class BinOpSameOpcodeHelper { // constant + x cannot be -constant - x // instead, it should be x - -constant if (Pos == 1 || - (FromOpcode == Instruction::Add && ToOpcode == Instruction::Sub)) + ((FromOpcode == Instruction::Add || FromOpcode == Instruction::Or || + FromOpcode == Instruction::Xor) && + ToOpcode == Instruction::Sub)) return SmallVector<Value *>({LHS, RHS}); return SmallVector<Value *>({RHS, LHS}); } @@ -1188,6 +1190,10 @@ public: if (CIValue.isAllOnes()) InterchangeableMask = CanBeAll; break; + case Instruction::Xor: + if (CIValue.isZero()) + InterchangeableMask = XorBIT | OrBIT | AndBIT | SubBIT | AddBIT; + break; default: if (CIValue.isZero()) InterchangeableMask = CanBeAll; @@ -2099,6 +2105,7 @@ public: UserIgnoreList = nullptr; PostponedGathers.clear(); ValueToGatherNodes.clear(); + TreeEntryToStridedPtrInfoMap.clear(); } unsigned getTreeSize() const { return VectorizableTree.size(); } @@ -8942,6 +8949,8 @@ BoUpSLP::findExternalStoreUsersReorderIndices(TreeEntry *TE) const { void BoUpSLP::buildTree(ArrayRef<Value *> Roots, const SmallDenseSet<Value *> &UserIgnoreLst) { deleteTree(); + assert(TreeEntryToStridedPtrInfoMap.empty() && + "TreeEntryToStridedPtrInfoMap is not cleared"); UserIgnoreList = &UserIgnoreLst; if (!allSameType(Roots)) return; @@ -8950,6 +8959,8 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, void BoUpSLP::buildTree(ArrayRef<Value *> Roots) { deleteTree(); + assert(TreeEntryToStridedPtrInfoMap.empty() && + "TreeEntryToStridedPtrInfoMap is not cleared"); if (!allSameType(Roots)) return; buildTreeRec(Roots, 0, EdgeInfo()); diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp index 728d291..81f1956 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -1750,8 +1750,7 @@ VPCostContext::getOperandInfo(VPValue *V) const { } InstructionCost VPCostContext::getScalarizationOverhead( - Type *ResultTy, ArrayRef<const VPValue *> Operands, ElementCount VF, - bool AlwaysIncludeReplicatingR) { + Type *ResultTy, ArrayRef<const VPValue *> Operands, ElementCount VF) { if (VF.isScalar()) return 0; @@ -1771,9 +1770,7 @@ InstructionCost VPCostContext::getScalarizationOverhead( SmallPtrSet<const VPValue *, 4> UniqueOperands; SmallVector<Type *> Tys; for (auto *Op : Operands) { - if (Op->isLiveIn() || - (!AlwaysIncludeReplicatingR && - isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op)) || + if (Op->isLiveIn() || isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op) || !UniqueOperands.insert(Op).second) continue; Tys.push_back(toVectorizedTy(Types.inferScalarType(Op), VF)); diff 
--git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 0822511..10d704d 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -2997,6 +2997,10 @@ class VPExpressionRecipe : public VPSingleDefRecipe { /// vector operands, performing a reduction.add on the result, and adding /// the scalar result to a chain. MulAccReduction, + /// Represent an inloop multiply-accumulate reduction, multiplying the + /// extended vector operands, negating the multiplication, performing a + /// reduction.add on the result, and adding the scalar result to a chain. + ExtNegatedMulAccReduction, }; /// Type of the expression. @@ -3020,6 +3024,19 @@ public: VPWidenRecipe *Mul, VPReductionRecipe *Red) : VPExpressionRecipe(ExpressionTypes::ExtMulAccReduction, {Ext0, Ext1, Mul, Red}) {} + VPExpressionRecipe(VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1, + VPWidenRecipe *Mul, VPWidenRecipe *Sub, + VPReductionRecipe *Red) + : VPExpressionRecipe(ExpressionTypes::ExtNegatedMulAccReduction, + {Ext0, Ext1, Mul, Sub, Red}) { + assert(Mul->getOpcode() == Instruction::Mul && "Expected a mul"); + assert(Red->getRecurrenceKind() == RecurKind::Add && + "Expected an add reduction"); + assert(getNumOperands() >= 3 && "Expected at least three operands"); + [[maybe_unused]] auto *SubConst = dyn_cast<ConstantInt>(getOperand(2)->getLiveInIRValue()); + assert(SubConst && SubConst->getValue() == 0 && + Sub->getOpcode() == Instruction::Sub && "Expected a negating sub"); + } ~VPExpressionRecipe() override { for (auto *R : reverse(ExpressionRecipes)) diff --git a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h index 2a8baec..fe59774 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h +++ b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h @@ -349,14 +349,12 @@ struct VPCostContext { LoopVectorizationCostModel &CM; SmallPtrSet<Instruction *, 8> SkipCostComputation; TargetTransformInfo::TargetCostKind CostKind; - ScalarEvolution &SE; VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI, const VPlan &Plan, LoopVectorizationCostModel &CM, - TargetTransformInfo::TargetCostKind CostKind, - ScalarEvolution &SE) + TargetTransformInfo::TargetCostKind CostKind) : TTI(TTI), TLI(TLI), Types(Plan), LLVMCtx(Plan.getContext()), CM(CM), - CostKind(CostKind), SE(SE) {} + CostKind(CostKind) {} /// Return the cost for \p UI with \p VF using the legacy cost model as /// fallback until computing the cost of all recipes migrates to VPlan. @@ -376,12 +374,10 @@ struct VPCostContext { /// Estimate the overhead of scalarizing a recipe with result type \p ResultTy /// and \p Operands with \p VF. This is a convenience wrapper for the - /// type-based getScalarizationOverhead API. If \p AlwaysIncludeReplicatingR - /// is true, always compute the cost of scalarizing replicating operands. - InstructionCost - getScalarizationOverhead(Type *ResultTy, ArrayRef<const VPValue *> Operands, - ElementCount VF, - bool AlwaysIncludeReplicatingR = false); + /// type-based getScalarizationOverhead API. + InstructionCost getScalarizationOverhead(Type *ResultTy, + ArrayRef<const VPValue *> Operands, + ElementCount VF); }; /// This class can be used to assign names to VPValues. 
For VPValues without diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index b5e30cb..3a55710 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -2839,12 +2839,17 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF, return Ctx.TTI.getMulAccReductionCost(false, Opcode, RedTy, SrcVecTy, Ctx.CostKind); - case ExpressionTypes::ExtMulAccReduction: + case ExpressionTypes::ExtNegatedMulAccReduction: + assert(Opcode == Instruction::Add && "Unexpected opcode"); + Opcode = Instruction::Sub; + LLVM_FALLTHROUGH; + case ExpressionTypes::ExtMulAccReduction: { return Ctx.TTI.getMulAccReductionCost( cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() == Instruction::ZExt, Opcode, RedTy, SrcVecTy, Ctx.CostKind); } + } llvm_unreachable("Unknown VPExpressionRecipe::ExpressionTypes enum"); } @@ -2890,6 +2895,30 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, O << ")"; break; } + case ExpressionTypes::ExtNegatedMulAccReduction: { + getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker); + O << " + reduce." + << Instruction::getOpcodeName( + RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind())) + << " (sub (0, mul"; + auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]); + Mul->printFlags(O); + O << "("; + getOperand(0)->printAsOperand(O, SlotTracker); + auto *Ext0 = cast<VPWidenCastRecipe>(ExpressionRecipes[0]); + O << " " << Instruction::getOpcodeName(Ext0->getOpcode()) << " to " + << *Ext0->getResultType() << "), ("; + getOperand(1)->printAsOperand(O, SlotTracker); + auto *Ext1 = cast<VPWidenCastRecipe>(ExpressionRecipes[1]); + O << " " << Instruction::getOpcodeName(Ext1->getOpcode()) << " to " + << *Ext1->getResultType() << ")"; + if (Red->isConditional()) { + O << ", "; + Red->getCondOp()->printAsOperand(O, SlotTracker); + } + O << "))"; + break; + } case ExpressionTypes::MulAccReduction: case ExpressionTypes::ExtMulAccReduction: { getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker); @@ -3069,61 +3098,6 @@ bool VPReplicateRecipe::shouldPack() const { }); } -/// Returns true if \p Ptr is a pointer computation for which the legacy cost -/// model computes a SCEV expression when computing the address cost. -static bool shouldUseAddressAccessSCEV(const VPValue *Ptr) { - auto *PtrR = Ptr->getDefiningRecipe(); - if (!PtrR || !((isa<VPReplicateRecipe>(PtrR) && - cast<VPReplicateRecipe>(PtrR)->getOpcode() == - Instruction::GetElementPtr) || - isa<VPWidenGEPRecipe>(PtrR))) - return false; - - // We are looking for a GEP where all indices are either loop invariant or - // inductions. - for (VPValue *Opd : drop_begin(PtrR->operands())) { - if (!Opd->isDefinedOutsideLoopRegions() && - !isa<VPScalarIVStepsRecipe, VPWidenIntOrFpInductionRecipe>(Opd)) - return false; - } - - return true; -} - -/// Returns true if \p V is used as part of the address of another load or -/// store. 
-static bool isUsedByLoadStoreAddress(const VPUser *V) { - SmallPtrSet<const VPUser *, 4> Seen; - SmallVector<const VPUser *> WorkList = {V}; - - while (!WorkList.empty()) { - auto *Cur = dyn_cast<VPSingleDefRecipe>(WorkList.pop_back_val()); - if (!Cur || !Seen.insert(Cur).second) - continue; - - for (VPUser *U : Cur->users()) { - if (auto *InterleaveR = dyn_cast<VPInterleaveBase>(U)) - if (InterleaveR->getAddr() == Cur) - return true; - if (auto *RepR = dyn_cast<VPReplicateRecipe>(U)) { - if (RepR->getOpcode() == Instruction::Load && - RepR->getOperand(0) == Cur) - return true; - if (RepR->getOpcode() == Instruction::Store && - RepR->getOperand(1) == Cur) - return true; - } - if (auto *MemR = dyn_cast<VPWidenMemoryRecipe>(U)) { - if (MemR->getAddr() == Cur && MemR->isConsecutive()) - return true; - } - } - - append_range(WorkList, cast<VPSingleDefRecipe>(Cur)->users()); - } - return false; -} - InstructionCost VPReplicateRecipe::computeCost(ElementCount VF, VPCostContext &Ctx) const { Instruction *UI = cast<Instruction>(getUnderlyingValue()); @@ -3231,58 +3205,21 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF, } case Instruction::Load: case Instruction::Store: { - if (VF.isScalable() && !isSingleScalar()) - return InstructionCost::getInvalid(); - + if (isSingleScalar()) { + bool IsLoad = UI->getOpcode() == Instruction::Load; + Type *ValTy = Ctx.Types.inferScalarType(IsLoad ? this : getOperand(0)); + Type *ScalarPtrTy = Ctx.Types.inferScalarType(getOperand(IsLoad ? 0 : 1)); + const Align Alignment = getLoadStoreAlignment(UI); + unsigned AS = getLoadStoreAddressSpace(UI); + TTI::OperandValueInfo OpInfo = TTI::getOperandInfo(UI->getOperand(0)); + InstructionCost ScalarMemOpCost = Ctx.TTI.getMemoryOpCost( + UI->getOpcode(), ValTy, Alignment, AS, Ctx.CostKind, OpInfo, UI); + return ScalarMemOpCost + Ctx.TTI.getAddressComputationCost( + ScalarPtrTy, nullptr, nullptr, Ctx.CostKind); + } // TODO: See getMemInstScalarizationCost for how to handle replicating and // predicated cases. - const VPRegionBlock *ParentRegion = getParent()->getParent(); - if (ParentRegion && ParentRegion->isReplicator()) - break; - - bool IsLoad = UI->getOpcode() == Instruction::Load; - const VPValue *PtrOp = getOperand(!IsLoad); - // TODO: Handle cases where we need to pass a SCEV to - // getAddressComputationCost. - if (shouldUseAddressAccessSCEV(PtrOp)) - break; - - Type *ValTy = Ctx.Types.inferScalarType(IsLoad ? this : getOperand(0)); - Type *ScalarPtrTy = Ctx.Types.inferScalarType(PtrOp); - const Align Alignment = getLoadStoreAlignment(UI); - unsigned AS = getLoadStoreAddressSpace(UI); - TTI::OperandValueInfo OpInfo = TTI::getOperandInfo(UI->getOperand(0)); - InstructionCost ScalarMemOpCost = Ctx.TTI.getMemoryOpCost( - UI->getOpcode(), ValTy, Alignment, AS, Ctx.CostKind, OpInfo); - - Type *PtrTy = isSingleScalar() ? ScalarPtrTy : toVectorTy(ScalarPtrTy, VF); - - InstructionCost ScalarCost = - ScalarMemOpCost + Ctx.TTI.getAddressComputationCost( - PtrTy, &Ctx.SE, nullptr, Ctx.CostKind); - if (isSingleScalar()) - return ScalarCost; - - SmallVector<const VPValue *> OpsToScalarize; - Type *ResultTy = Type::getVoidTy(PtrTy->getContext()); - // Set ResultTy and OpsToScalarize, if scalarization is needed. Currently we - // don't assign scalarization overhead in general, if the target prefers - // vectorized addressing or the loaded value is used as part of an address - // of another load or store. 
-    bool PreferVectorizedAddressing = Ctx.TTI.prefersVectorizedAddressing();
-    if (PreferVectorizedAddressing || !isUsedByLoadStoreAddress(this)) {
-      bool EfficientVectorLoadStore =
-          Ctx.TTI.supportsEfficientVectorElementLoadStore();
-      if (!(IsLoad && !PreferVectorizedAddressing) &&
-          !(!IsLoad && EfficientVectorLoadStore))
-        append_range(OpsToScalarize, operands());
-
-      if (!EfficientVectorLoadStore)
-        ResultTy = Ctx.Types.inferScalarType(this);
-    }
-
-    return (ScalarCost * VF.getFixedValue()) +
-           Ctx.getScalarizationOverhead(ResultTy, OpsToScalarize, VF, true);
+    break;
   }
   }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 5252e1f..a73b083 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2124,6 +2124,8 @@ static void licm(VPlan &Plan) {
 void VPlanTransforms::truncateToMinimalBitwidths(
     VPlan &Plan, const MapVector<Instruction *, uint64_t> &MinBWs) {
+  if (Plan.hasScalarVFOnly())
+    return;
   // Keep track of created truncates, so they can be re-used. Note that we
   // cannot use RAUW after creating a new truncate, as this could make
   // other uses have different types for their operands, making them invalidly
@@ -2704,6 +2706,8 @@ static void transformRecipestoEVLRecipes(VPlan &Plan, VPValue &EVL) {
 ///
 void VPlanTransforms::addExplicitVectorLength(
     VPlan &Plan, const std::optional<unsigned> &MaxSafeElements) {
+  if (Plan.hasScalarVFOnly())
+    return;
   VPBasicBlock *Header = Plan.getVectorLoopRegion()->getEntryBasicBlock();
   auto *CanonicalIVPHI = Plan.getCanonicalIV();
@@ -3543,7 +3547,15 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
   };
   VPValue *VecOp = Red->getVecOp();
+  VPRecipeBase *Sub = nullptr;
   VPValue *A, *B;
+  VPValue *Tmp = nullptr;
+  // Sub reductions could have a sub between the add reduction and vec op.
+  if (match(VecOp,
+            m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Tmp)))) {
+    Sub = VecOp->getDefiningRecipe();
+    VecOp = Tmp;
+  }
   // Try to match reduce.add(mul(...)).
   if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
     auto *RecipeA =
@@ -3560,12 +3572,21 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
         IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
                                        Instruction::CastOps::ZExt,
                                    Mul, RecipeA, RecipeB, nullptr)) {
+      if (Sub)
+        return new VPExpressionRecipe(RecipeA, RecipeB, Mul,
+                                      cast<VPWidenRecipe>(Sub), Red);
       return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
     }
     // Match reduce.add(mul).
-    if (IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
+    // TODO: Add an expression type for this variant with a negated mul
+    if (!Sub &&
+        IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
       return new VPExpressionRecipe(Mul, Red);
   }
+  // TODO: Add an expression type for negated versions of other expression
+  // variants.
+  if (Sub)
+    return nullptr;
   // Match reduce.add(ext(mul(ext(A), ext(B)))).
   // All extend recipes must have same opcode or A == B
   // which can be transformed to reduce.add(zext(mul(sext(A), sext(B)))).
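For reference, the scalar semantics the new ExtNegatedMulAccReduction bundle must preserve, as a plain C++ sketch (element and accumulator types chosen arbitrarily for illustration): a vector sub(0, mul(ext(A), ext(B))) feeding a reduce.add is simply a negated multiply-accumulate.

#include <cstddef>
#include <cstdint>

// acc += 0 - (ext(a[i]) * ext(b[i])), the form matched above.
int64_t extNegatedMulAcc(const int16_t *A, const int16_t *B, size_t N,
                         int64_t Acc) {
  for (size_t I = 0; I < N; ++I) {
    int64_t Mul = int64_t(A[I]) * int64_t(B[I]); // ext(A) * ext(B)
    Acc += 0 - Mul;                              // sub(0, mul), i.e. Acc -= Mul
  }
  return Acc;
}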
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 32704bd..d6eb00d 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1031,6 +1031,16 @@ bool VectorCombine::foldBitOpOfCastConstant(Instruction &I) {
   // Create the cast operation directly to ensure we get a new instruction
   Instruction *NewCast = CastInst::Create(CastOpcode, NewOp, I.getType());
 
+  // Preserve cast instruction flags
+  if (RHSFlags.NNeg)
+    NewCast->setNonNeg();
+  if (RHSFlags.NUW)
+    NewCast->setHasNoUnsignedWrap();
+  if (RHSFlags.NSW)
+    NewCast->setHasNoSignedWrap();
+
+  NewCast->andIRFlags(LHSCast);
+
   // Insert the new instruction
   Value *Result = Builder.Insert(NewCast);
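The flag handling in this last hunk amounts to an intersection: setting the RHS cast's flags and then calling andIRFlags with the LHS cast keeps a flag on the merged cast only if both original casts carried it. A small sketch of that rule; the CastFlags struct is invented for illustration and stands in for the flag records the pass tracks.

// Keep a flag only when both source casts had it; dropping a flag is always
// conservatively correct, keeping one that only one side had is not.
struct CastFlags {
  bool NNeg = false; // zext nneg
  bool NUW = false;  // no unsigned wrap
  bool NSW = false;  // no signed wrap
};

static CastFlags intersectFlags(const CastFlags &L, const CastFlags &R) {
  return {L.NNeg && R.NNeg, L.NUW && R.NUW, L.NSW && R.NSW};
}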