diff options
Diffstat (limited to 'llvm/lib')
66 files changed, 2081 insertions, 1004 deletions
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 295b6d3..6885351 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -200,6 +200,8 @@ void Embedder::computeEmbeddings() const { if (F.isDeclaration()) return; + FuncVector = Embedding(Dimension, 0.0); + // Consider only the basic blocks that are reachable from entry for (const BasicBlock *BB : depth_first(&F)) { computeEmbeddings(*BB); diff --git a/llvm/lib/BinaryFormat/DXContainer.cpp b/llvm/lib/BinaryFormat/DXContainer.cpp index c06a3e3..22f5180 100644 --- a/llvm/lib/BinaryFormat/DXContainer.cpp +++ b/llvm/lib/BinaryFormat/DXContainer.cpp @@ -18,6 +18,91 @@ using namespace llvm; using namespace llvm::dxbc; +#define ROOT_PARAMETER(Val, Enum) \ + case Val: \ + return true; +bool llvm::dxbc::isValidParameterType(uint32_t V) { + switch (V) { +#include "llvm/BinaryFormat/DXContainerConstants.def" + } + return false; +} + +bool llvm::dxbc::isValidRangeType(uint32_t V) { + return V <= llvm::to_underlying(dxil::ResourceClass::LastEntry); +} + +#define SHADER_VISIBILITY(Val, Enum) \ + case Val: \ + return true; +bool llvm::dxbc::isValidShaderVisibility(uint32_t V) { + switch (V) { +#include "llvm/BinaryFormat/DXContainerConstants.def" + } + return false; +} + +#define FILTER(Val, Enum) \ + case Val: \ + return true; +bool llvm::dxbc::isValidSamplerFilter(uint32_t V) { + switch (V) { +#include "llvm/BinaryFormat/DXContainerConstants.def" + } + return false; +} + +#define TEXTURE_ADDRESS_MODE(Val, Enum) \ + case Val: \ + return true; +bool llvm::dxbc::isValidAddress(uint32_t V) { + switch (V) { +#include "llvm/BinaryFormat/DXContainerConstants.def" + } + return false; +} + +#define COMPARISON_FUNC(Val, Enum) \ + case Val: \ + return true; +bool llvm::dxbc::isValidComparisonFunc(uint32_t V) { + switch (V) { +#include "llvm/BinaryFormat/DXContainerConstants.def" + } + return false; +} + +#define STATIC_BORDER_COLOR(Val, Enum) \ + case Val: \ + return true; +bool llvm::dxbc::isValidBorderColor(uint32_t V) { + switch (V) { +#include "llvm/BinaryFormat/DXContainerConstants.def" + } + return false; +} + +bool llvm::dxbc::isValidRootDesciptorFlags(uint32_t V) { + using FlagT = dxbc::RootDescriptorFlags; + uint32_t LargestValue = + llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); + return V < NextPowerOf2(LargestValue); +} + +bool llvm::dxbc::isValidDescriptorRangeFlags(uint32_t V) { + using FlagT = dxbc::DescriptorRangeFlags; + uint32_t LargestValue = + llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); + return V < NextPowerOf2(LargestValue); +} + +bool llvm::dxbc::isValidStaticSamplerFlags(uint32_t V) { + using FlagT = dxbc::StaticSamplerFlags; + uint32_t LargestValue = + llvm::to_underlying(FlagT::LLVM_BITMASK_LARGEST_ENUMERATOR); + return V < NextPowerOf2(LargestValue); +} + dxbc::PartType dxbc::parsePartType(StringRef S) { #define CONTAINER_PART(PartName) .Case(#PartName, PartType::PartName) return StringSwitch<dxbc::PartType>(S) diff --git a/llvm/lib/CAS/CMakeLists.txt b/llvm/lib/CAS/CMakeLists.txt index 7ae5f7e..bca39b6 100644 --- a/llvm/lib/CAS/CMakeLists.txt +++ b/llvm/lib/CAS/CMakeLists.txt @@ -7,6 +7,7 @@ add_llvm_component_library(LLVMCAS MappedFileRegionArena.cpp ObjectStore.cpp OnDiskCommon.cpp + OnDiskDataAllocator.cpp OnDiskTrieRawHashMap.cpp ADDITIONAL_HEADER_DIRS diff --git a/llvm/lib/CAS/OnDiskDataAllocator.cpp b/llvm/lib/CAS/OnDiskDataAllocator.cpp new file mode 100644 index 0000000..13bbd66 --- /dev/null +++ b/llvm/lib/CAS/OnDiskDataAllocator.cpp @@ -0,0 +1,234 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file Implements OnDiskDataAllocator. +/// +//===----------------------------------------------------------------------===// + +#include "llvm/CAS/OnDiskDataAllocator.h" +#include "DatabaseFile.h" +#include "llvm/Config/llvm-config.h" + +using namespace llvm; +using namespace llvm::cas; +using namespace llvm::cas::ondisk; + +#if LLVM_ENABLE_ONDISK_CAS + +//===----------------------------------------------------------------------===// +// DataAllocator data structures. +//===----------------------------------------------------------------------===// + +namespace { +/// DataAllocator table layout: +/// - [8-bytes: Generic table header] +/// - 8-bytes: AllocatorOffset (reserved for implementing free lists) +/// - 8-bytes: Size for user data header +/// - <user data buffer> +/// +/// Record layout: +/// - <data> +class DataAllocatorHandle { +public: + static constexpr TableHandle::TableKind Kind = + TableHandle::TableKind::DataAllocator; + + struct Header { + TableHandle::Header GenericHeader; + std::atomic<int64_t> AllocatorOffset; + const uint64_t UserHeaderSize; + }; + + operator TableHandle() const { + if (!H) + return TableHandle(); + return TableHandle(*Region, H->GenericHeader); + } + + Expected<MutableArrayRef<char>> allocate(MappedFileRegionArena &Alloc, + size_t DataSize) { + assert(&Alloc.getRegion() == Region); + auto Ptr = Alloc.allocate(DataSize); + if (LLVM_UNLIKELY(!Ptr)) + return Ptr.takeError(); + return MutableArrayRef(*Ptr, DataSize); + } + + explicit operator bool() const { return H; } + const Header &getHeader() const { return *H; } + MappedFileRegion &getRegion() const { return *Region; } + + MutableArrayRef<uint8_t> getUserHeader() { + return MutableArrayRef(reinterpret_cast<uint8_t *>(H + 1), + H->UserHeaderSize); + } + + static Expected<DataAllocatorHandle> + create(MappedFileRegionArena &Alloc, StringRef Name, uint32_t UserHeaderSize); + + DataAllocatorHandle() = default; + DataAllocatorHandle(MappedFileRegion &Region, Header &H) + : Region(&Region), H(&H) {} + DataAllocatorHandle(MappedFileRegion &Region, intptr_t HeaderOffset) + : DataAllocatorHandle( + Region, *reinterpret_cast<Header *>(Region.data() + HeaderOffset)) { + } + +private: + MappedFileRegion *Region = nullptr; + Header *H = nullptr; +}; + +} // end anonymous namespace + +struct OnDiskDataAllocator::ImplType { + DatabaseFile File; + DataAllocatorHandle Store; +}; + +Expected<DataAllocatorHandle> +DataAllocatorHandle::create(MappedFileRegionArena &Alloc, StringRef Name, + uint32_t UserHeaderSize) { + // Allocate. + auto Offset = + Alloc.allocateOffset(sizeof(Header) + UserHeaderSize + Name.size() + 1); + if (LLVM_UNLIKELY(!Offset)) + return Offset.takeError(); + + // Construct the header and the name. + assert(Name.size() <= UINT16_MAX && "Expected smaller table name"); + auto *H = new (Alloc.getRegion().data() + *Offset) + Header{{TableHandle::TableKind::DataAllocator, + static_cast<uint16_t>(Name.size()), + static_cast<int32_t>(sizeof(Header) + UserHeaderSize)}, + /*AllocatorOffset=*/{0}, + /*UserHeaderSize=*/UserHeaderSize}; + // Memset UserHeader. + char *UserHeader = reinterpret_cast<char *>(H + 1); + memset(UserHeader, 0, UserHeaderSize); + // Write database file name (null-terminated). + char *NameStorage = UserHeader + UserHeaderSize; + llvm::copy(Name, NameStorage); + NameStorage[Name.size()] = 0; + return DataAllocatorHandle(Alloc.getRegion(), *H); +} + +Expected<OnDiskDataAllocator> OnDiskDataAllocator::create( + const Twine &PathTwine, const Twine &TableNameTwine, uint64_t MaxFileSize, + std::optional<uint64_t> NewFileInitialSize, uint32_t UserHeaderSize, + function_ref<void(void *)> UserHeaderInit) { + assert(!UserHeaderSize || UserHeaderInit); + SmallString<128> PathStorage; + StringRef Path = PathTwine.toStringRef(PathStorage); + SmallString<128> TableNameStorage; + StringRef TableName = TableNameTwine.toStringRef(TableNameStorage); + + // Constructor for if the file doesn't exist. + auto NewDBConstructor = [&](DatabaseFile &DB) -> Error { + auto Store = + DataAllocatorHandle::create(DB.getAlloc(), TableName, UserHeaderSize); + if (LLVM_UNLIKELY(!Store)) + return Store.takeError(); + + if (auto E = DB.addTable(*Store)) + return E; + + if (UserHeaderSize) + UserHeaderInit(Store->getUserHeader().data()); + return Error::success(); + }; + + // Get or create the file. + Expected<DatabaseFile> File = + DatabaseFile::create(Path, MaxFileSize, NewDBConstructor); + if (!File) + return File.takeError(); + + // Find the table and validate it. + std::optional<TableHandle> Table = File->findTable(TableName); + if (!Table) + return createTableConfigError(std::errc::argument_out_of_domain, Path, + TableName, "table not found"); + if (Error E = checkTable("table kind", (size_t)DataAllocatorHandle::Kind, + (size_t)Table->getHeader().Kind, Path, TableName)) + return std::move(E); + auto Store = Table->cast<DataAllocatorHandle>(); + assert(Store && "Already checked the kind"); + + // Success. + OnDiskDataAllocator::ImplType Impl{DatabaseFile(std::move(*File)), Store}; + return OnDiskDataAllocator(std::make_unique<ImplType>(std::move(Impl))); +} + +Expected<OnDiskDataAllocator::OnDiskPtr> +OnDiskDataAllocator::allocate(size_t Size) { + auto Data = Impl->Store.allocate(Impl->File.getAlloc(), Size); + if (LLVM_UNLIKELY(!Data)) + return Data.takeError(); + + return OnDiskPtr(FileOffset(Data->data() - Impl->Store.getRegion().data()), + *Data); +} + +Expected<ArrayRef<char>> OnDiskDataAllocator::get(FileOffset Offset, + size_t Size) const { + assert(Offset); + assert(Impl); + if (Offset.get() + Size >= Impl->File.getAlloc().size()) + return createStringError(make_error_code(std::errc::protocol_error), + "requested size too large in allocator"); + return ArrayRef<char>{Impl->File.getRegion().data() + Offset.get(), Size}; +} + +MutableArrayRef<uint8_t> OnDiskDataAllocator::getUserHeader() { + return Impl->Store.getUserHeader(); +} + +size_t OnDiskDataAllocator::size() const { return Impl->File.size(); } +size_t OnDiskDataAllocator::capacity() const { + return Impl->File.getRegion().size(); +} + +OnDiskDataAllocator::OnDiskDataAllocator(std::unique_ptr<ImplType> Impl) + : Impl(std::move(Impl)) {} + +#else // !LLVM_ENABLE_ONDISK_CAS + +struct OnDiskDataAllocator::ImplType {}; + +Expected<OnDiskDataAllocator> OnDiskDataAllocator::create( + const Twine &Path, const Twine &TableName, uint64_t MaxFileSize, + std::optional<uint64_t> NewFileInitialSize, uint32_t UserHeaderSize, + function_ref<void(void *)> UserHeaderInit) { + return createStringError(make_error_code(std::errc::not_supported), + "OnDiskDataAllocator is not supported"); +} + +Expected<OnDiskDataAllocator::OnDiskPtr> +OnDiskDataAllocator::allocate(size_t Size) { + return createStringError(make_error_code(std::errc::not_supported), + "OnDiskDataAllocator is not supported"); +} + +Expected<ArrayRef<char>> OnDiskDataAllocator::get(FileOffset Offset, + size_t Size) const { + return createStringError(make_error_code(std::errc::not_supported), + "OnDiskDataAllocator is not supported"); +} + +MutableArrayRef<uint8_t> OnDiskDataAllocator::getUserHeader() { return {}; } + +size_t OnDiskDataAllocator::size() const { return 0; } +size_t OnDiskDataAllocator::capacity() const { return 0; } + +#endif // LLVM_ENABLE_ONDISK_CAS + +OnDiskDataAllocator::OnDiskDataAllocator(OnDiskDataAllocator &&RHS) = default; +OnDiskDataAllocator & +OnDiskDataAllocator::operator=(OnDiskDataAllocator &&RHS) = default; +OnDiskDataAllocator::~OnDiskDataAllocator() = default; diff --git a/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp b/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp index 9403893..323b21e 100644 --- a/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp +++ b/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp @@ -427,7 +427,7 @@ TrieRawHashMapHandle::createRecord(MappedFileRegionArena &Alloc, return Record; } -Expected<OnDiskTrieRawHashMap::const_pointer> +Expected<OnDiskTrieRawHashMap::ConstOnDiskPtr> OnDiskTrieRawHashMap::recoverFromFileOffset(FileOffset Offset) const { // Check alignment. if (!isAligned(MappedFileRegionArena::getAlign(), Offset.get())) @@ -448,17 +448,17 @@ OnDiskTrieRawHashMap::recoverFromFileOffset(FileOffset Offset) const { // Looks okay... TrieRawHashMapHandle::RecordData D = Impl->Trie.getRecord(SubtrieSlotValue::getDataOffset(Offset)); - return const_pointer(D.Proxy, D.getFileOffset()); + return ConstOnDiskPtr(D.Proxy, D.getFileOffset()); } -OnDiskTrieRawHashMap::const_pointer +OnDiskTrieRawHashMap::ConstOnDiskPtr OnDiskTrieRawHashMap::find(ArrayRef<uint8_t> Hash) const { TrieRawHashMapHandle Trie = Impl->Trie; assert(Hash.size() == Trie.getNumHashBytes() && "Invalid hash"); SubtrieHandle S = Trie.getRoot(); if (!S) - return const_pointer(); + return ConstOnDiskPtr(); TrieHashIndexGenerator IndexGen = Trie.getIndexGen(S, Hash); size_t Index = IndexGen.next(); @@ -466,13 +466,13 @@ OnDiskTrieRawHashMap::find(ArrayRef<uint8_t> Hash) const { // Try to set the content. SubtrieSlotValue V = S.load(Index); if (!V) - return const_pointer(); + return ConstOnDiskPtr(); // Check for an exact match. if (V.isData()) { TrieRawHashMapHandle::RecordData D = Trie.getRecord(V); - return D.Proxy.Hash == Hash ? const_pointer(D.Proxy, D.getFileOffset()) - : const_pointer(); + return D.Proxy.Hash == Hash ? ConstOnDiskPtr(D.Proxy, D.getFileOffset()) + : ConstOnDiskPtr(); } Index = IndexGen.next(); @@ -490,7 +490,7 @@ void SubtrieHandle::reinitialize(uint32_t StartBit, uint32_t NumBits) { H->NumBits = NumBits; } -Expected<OnDiskTrieRawHashMap::pointer> +Expected<OnDiskTrieRawHashMap::OnDiskPtr> OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash, LazyInsertOnConstructCB OnConstruct, LazyInsertOnLeakCB OnLeak) { @@ -523,7 +523,8 @@ OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash, } if (S->compare_exchange_strong(Index, Existing, NewRecord->Offset)) - return pointer(NewRecord->Proxy, NewRecord->Offset.asDataFileOffset()); + return OnDiskPtr(NewRecord->Proxy, + NewRecord->Offset.asDataFileOffset()); // Race means that Existing is no longer empty; fall through... } @@ -540,8 +541,8 @@ OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash, if (NewRecord && OnLeak) OnLeak(NewRecord->Offset.asDataFileOffset(), NewRecord->Proxy, ExistingRecord.Offset.asDataFileOffset(), ExistingRecord.Proxy); - return pointer(ExistingRecord.Proxy, - ExistingRecord.Offset.asDataFileOffset()); + return OnDiskPtr(ExistingRecord.Proxy, + ExistingRecord.Offset.asDataFileOffset()); } // Sink the existing content as long as the indexes match. @@ -1135,7 +1136,7 @@ OnDiskTrieRawHashMap::create(const Twine &PathTwine, const Twine &TrieNameTwine, "OnDiskTrieRawHashMap is not supported"); } -Expected<OnDiskTrieRawHashMap::pointer> +Expected<OnDiskTrieRawHashMap::OnDiskPtr> OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash, LazyInsertOnConstructCB OnConstruct, LazyInsertOnLeakCB OnLeak) { @@ -1143,15 +1144,15 @@ OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash, "OnDiskTrieRawHashMap is not supported"); } -Expected<OnDiskTrieRawHashMap::const_pointer> +Expected<OnDiskTrieRawHashMap::ConstOnDiskPtr> OnDiskTrieRawHashMap::recoverFromFileOffset(FileOffset Offset) const { return createStringError(make_error_code(std::errc::not_supported), "OnDiskTrieRawHashMap is not supported"); } -OnDiskTrieRawHashMap::const_pointer +OnDiskTrieRawHashMap::ConstOnDiskPtr OnDiskTrieRawHashMap::find(ArrayRef<uint8_t> Hash) const { - return const_pointer(); + return ConstOnDiskPtr(); } void OnDiskTrieRawHashMap::print( diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp index 7a0cf40..707f0c3 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp @@ -651,8 +651,11 @@ Error MetadataParser::validateRootSignature( "RegisterSpace", Descriptor.RegisterSpace)); if (RSD.Version > 1) { - if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version, - Descriptor.Flags)) + bool IsValidFlag = + dxbc::isValidRootDesciptorFlags(Descriptor.Flags) && + hlsl::rootsig::verifyRootDescriptorFlag( + RSD.Version, dxbc::RootDescriptorFlags(Descriptor.Flags)); + if (!IsValidFlag) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error<RootSignatureValidationError<uint32_t>>( @@ -676,9 +679,11 @@ Error MetadataParser::validateRootSignature( make_error<RootSignatureValidationError<uint32_t>>( "NumDescriptors", Range.NumDescriptors)); - if (!hlsl::rootsig::verifyDescriptorRangeFlag( - RSD.Version, Range.RangeType, - dxbc::DescriptorRangeFlags(Range.Flags))) + bool IsValidFlag = dxbc::isValidDescriptorRangeFlags(Range.Flags) && + hlsl::rootsig::verifyDescriptorRangeFlag( + RSD.Version, Range.RangeType, + dxbc::DescriptorRangeFlags(Range.Flags)); + if (!IsValidFlag) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error<RootSignatureValidationError<uint32_t>>( @@ -731,8 +736,11 @@ Error MetadataParser::validateRootSignature( joinErrors(std::move(DeferredErrs), make_error<RootSignatureValidationError<uint32_t>>( "RegisterSpace", Sampler.RegisterSpace)); - - if (!hlsl::rootsig::verifyStaticSamplerFlags(RSD.Version, Sampler.Flags)) + bool IsValidFlag = + dxbc::isValidStaticSamplerFlags(Sampler.Flags) && + hlsl::rootsig::verifyStaticSamplerFlags( + RSD.Version, dxbc::StaticSamplerFlags(Sampler.Flags)); + if (!IsValidFlag) DeferredErrs = joinErrors(std::move(DeferredErrs), make_error<RootSignatureValidationError<uint32_t>>( diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp index 8a2b03d..30408df 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp @@ -34,7 +34,8 @@ bool verifyRegisterSpace(uint32_t RegisterSpace) { return !(RegisterSpace >= 0xFFFFFFF0); } -bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { +bool verifyRootDescriptorFlag(uint32_t Version, + dxbc::RootDescriptorFlags FlagsVal) { using FlagT = dxbc::RootDescriptorFlags; FlagT Flags = FlagT(FlagsVal); if (Version == 1) @@ -56,7 +57,6 @@ bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) { bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, dxbc::DescriptorRangeFlags Flags) { using FlagT = dxbc::DescriptorRangeFlags; - const bool IsSampler = (Type == dxil::ResourceClass::Sampler); if (Version == 1) { @@ -113,13 +113,8 @@ bool verifyDescriptorRangeFlag(uint32_t Version, dxil::ResourceClass Type, return (Flags & ~Mask) == FlagT::None; } -bool verifyStaticSamplerFlags(uint32_t Version, uint32_t FlagsNumber) { - uint32_t LargestValue = llvm::to_underlying( - dxbc::StaticSamplerFlags::LLVM_BITMASK_LARGEST_ENUMERATOR); - if (FlagsNumber >= NextPowerOf2(LargestValue)) - return false; - - dxbc::StaticSamplerFlags Flags = dxbc::StaticSamplerFlags(FlagsNumber); +bool verifyStaticSamplerFlags(uint32_t Version, + dxbc::StaticSamplerFlags Flags) { if (Version <= 2) return Flags == dxbc::StaticSamplerFlags::None; diff --git a/llvm/lib/IR/Globals.cpp b/llvm/lib/IR/Globals.cpp index 1a7a5c5..c3a472b 100644 --- a/llvm/lib/IR/Globals.cpp +++ b/llvm/lib/IR/Globals.cpp @@ -419,6 +419,7 @@ findBaseObject(const Constant *C, DenseSet<const GlobalAlias *> &Aliases, case Instruction::PtrToAddr: case Instruction::PtrToInt: case Instruction::BitCast: + case Instruction::AddrSpaceCast: case Instruction::GetElementPtr: return findBaseObject(CE->getOperand(0), Aliases, Op); default: diff --git a/llvm/lib/IR/Mangler.cpp b/llvm/lib/IR/Mangler.cpp index ca6a480..55c825d 100644 --- a/llvm/lib/IR/Mangler.cpp +++ b/llvm/lib/IR/Mangler.cpp @@ -307,6 +307,19 @@ std::optional<std::string> llvm::getArm64ECMangledFunctionName(StringRef Name) { if (Name.contains("$$h")) return std::nullopt; + // Handle MD5 mangled names, which use a slightly different rule from + // other C++ manglings. + // + // A non-Arm64EC function: + // + // ??@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa@ + // + // An Arm64EC function: + // + // ??@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa@$$h@ + if (Name.starts_with("??@") && Name.ends_with("@")) + return (Name + "$$h@").str(); + // Ask the demangler where we should insert "$$h". auto InsertIdx = getArm64ECInsertionPointInMangledName(Name); if (!InsertIdx) @@ -324,6 +337,10 @@ llvm::getArm64ECDemangledFunctionName(StringRef Name) { if (Name[0] != '?') return std::nullopt; + // MD5 mangled name; see comment in getArm64ECMangledFunctionName. + if (Name.starts_with("??@") && Name.ends_with("@$$h@")) + return Name.drop_back(4).str(); + // Drop the ARM64EC "$$h" tag. std::pair<StringRef, StringRef> Pair = Name.split("$$h"); if (Pair.second.empty()) diff --git a/llvm/lib/Object/OffloadBundle.cpp b/llvm/lib/Object/OffloadBundle.cpp index 329dcbf..046cde8 100644 --- a/llvm/lib/Object/OffloadBundle.cpp +++ b/llvm/lib/Object/OffloadBundle.cpp @@ -25,38 +25,71 @@ using namespace llvm; using namespace llvm::object; -static llvm::TimerGroup - OffloadBundlerTimerGroup("Offload Bundler Timer Group", - "Timer group for offload bundler"); +static TimerGroup OffloadBundlerTimerGroup("Offload Bundler Timer Group", + "Timer group for offload bundler"); // Extract an Offload bundle (usually a Offload Bundle) from a fat_bin -// section +// section. Error extractOffloadBundle(MemoryBufferRef Contents, uint64_t SectionOffset, StringRef FileName, SmallVectorImpl<OffloadBundleFatBin> &Bundles) { size_t Offset = 0; size_t NextbundleStart = 0; + StringRef Magic; + std::unique_ptr<MemoryBuffer> Buffer; // There could be multiple offloading bundles stored at this section. - while (NextbundleStart != StringRef::npos) { - std::unique_ptr<MemoryBuffer> Buffer = + while ((NextbundleStart != StringRef::npos) && + (Offset < Contents.getBuffer().size())) { + Buffer = MemoryBuffer::getMemBuffer(Contents.getBuffer().drop_front(Offset), "", /*RequiresNullTerminator=*/false); - // Create the FatBinBindle object. This will also create the Bundle Entry - // list info. - auto FatBundleOrErr = - OffloadBundleFatBin::create(*Buffer, SectionOffset + Offset, FileName); - if (!FatBundleOrErr) - return FatBundleOrErr.takeError(); - - // Add current Bundle to list. - Bundles.emplace_back(std::move(**FatBundleOrErr)); - - // Find the next bundle by searching for the magic string - StringRef Str = Buffer->getBuffer(); - NextbundleStart = Str.find(StringRef("__CLANG_OFFLOAD_BUNDLE__"), 24); + if (identify_magic((*Buffer).getBuffer()) == + file_magic::offload_bundle_compressed) { + Magic = "CCOB"; + // Decompress this bundle first. + NextbundleStart = (*Buffer).getBuffer().find(Magic, Magic.size()); + if (NextbundleStart == StringRef::npos) + NextbundleStart = (*Buffer).getBuffer().size(); + + ErrorOr<std::unique_ptr<MemoryBuffer>> CodeOrErr = + MemoryBuffer::getMemBuffer( + (*Buffer).getBuffer().take_front(NextbundleStart), FileName, + false); + if (std::error_code EC = CodeOrErr.getError()) + return createFileError(FileName, EC); + + Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr = + CompressedOffloadBundle::decompress(**CodeOrErr, nullptr); + if (!DecompressedBufferOrErr) + return createStringError("failed to decompress input: " + + toString(DecompressedBufferOrErr.takeError())); + + auto FatBundleOrErr = OffloadBundleFatBin::create( + **DecompressedBufferOrErr, Offset, FileName, true); + if (!FatBundleOrErr) + return FatBundleOrErr.takeError(); + + // Add current Bundle to list. + Bundles.emplace_back(std::move(**FatBundleOrErr)); + + } else if (identify_magic((*Buffer).getBuffer()) == + file_magic::offload_bundle) { + // Create the OffloadBundleFatBin object. This will also create the Bundle + // Entry list info. + auto FatBundleOrErr = OffloadBundleFatBin::create( + *Buffer, SectionOffset + Offset, FileName); + if (!FatBundleOrErr) + return FatBundleOrErr.takeError(); + + // Add current Bundle to list. + Bundles.emplace_back(std::move(**FatBundleOrErr)); + + Magic = "__CLANG_OFFLOAD_BUNDLE__"; + NextbundleStart = (*Buffer).getBuffer().find(Magic, Magic.size()); + } if (NextbundleStart != StringRef::npos) Offset += NextbundleStart; @@ -82,7 +115,7 @@ Error OffloadBundleFatBin::readEntries(StringRef Buffer, NumberOfEntries = NumOfEntries; - // For each Bundle Entry (code object) + // For each Bundle Entry (code object). for (uint64_t I = 0; I < NumOfEntries; I++) { uint64_t EntrySize; uint64_t EntryOffset; @@ -112,19 +145,22 @@ Error OffloadBundleFatBin::readEntries(StringRef Buffer, Expected<std::unique_ptr<OffloadBundleFatBin>> OffloadBundleFatBin::create(MemoryBufferRef Buf, uint64_t SectionOffset, - StringRef FileName) { + StringRef FileName, bool Decompress) { if (Buf.getBufferSize() < 24) return errorCodeToError(object_error::parse_failed); // Check for magic bytes. - if (identify_magic(Buf.getBuffer()) != file_magic::offload_bundle) + if ((identify_magic(Buf.getBuffer()) != file_magic::offload_bundle) && + (identify_magic(Buf.getBuffer()) != + file_magic::offload_bundle_compressed)) return errorCodeToError(object_error::parse_failed); std::unique_ptr<OffloadBundleFatBin> TheBundle( new OffloadBundleFatBin(Buf, FileName)); - // Read the Bundle Entries - Error Err = TheBundle->readEntries(Buf.getBuffer(), SectionOffset); + // Read the Bundle Entries. + Error Err = + TheBundle->readEntries(Buf.getBuffer(), Decompress ? 0 : SectionOffset); if (Err) return Err; @@ -132,7 +168,7 @@ OffloadBundleFatBin::create(MemoryBufferRef Buf, uint64_t SectionOffset, } Error OffloadBundleFatBin::extractBundle(const ObjectFile &Source) { - // This will extract all entries in the Bundle + // This will extract all entries in the Bundle. for (OffloadBundleEntry &Entry : Entries) { if (Entry.Size == 0) @@ -161,40 +197,21 @@ Error object::extractOffloadBundleFatBinary( return Buffer.takeError(); // If it does not start with the reserved suffix, just skip this section. - if ((llvm::identify_magic(*Buffer) == llvm::file_magic::offload_bundle) || + if ((llvm::identify_magic(*Buffer) == file_magic::offload_bundle) || (llvm::identify_magic(*Buffer) == - llvm::file_magic::offload_bundle_compressed)) { + file_magic::offload_bundle_compressed)) { uint64_t SectionOffset = 0; if (Obj.isELF()) { SectionOffset = ELFSectionRef(Sec).getOffset(); - } else if (Obj.isCOFF()) // TODO: add COFF Support + } else if (Obj.isCOFF()) // TODO: add COFF Support. return createStringError(object_error::parse_failed, - "COFF object files not supported.\n"); + "COFF object files not supported"); MemoryBufferRef Contents(*Buffer, Obj.getFileName()); - - if (llvm::identify_magic(*Buffer) == - llvm::file_magic::offload_bundle_compressed) { - // Decompress the input if necessary. - Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr = - CompressedOffloadBundle::decompress(Contents, false); - - if (!DecompressedBufferOrErr) - return createStringError( - inconvertibleErrorCode(), - "Failed to decompress input: " + - llvm::toString(DecompressedBufferOrErr.takeError())); - - MemoryBuffer &DecompressedInput = **DecompressedBufferOrErr; - if (Error Err = extractOffloadBundle(DecompressedInput, SectionOffset, - Obj.getFileName(), Bundles)) - return Err; - } else { - if (Error Err = extractOffloadBundle(Contents, SectionOffset, - Obj.getFileName(), Bundles)) - return Err; - } + if (Error Err = extractOffloadBundle(Contents, SectionOffset, + Obj.getFileName(), Bundles)) + return Err; } } return Error::success(); @@ -222,8 +239,22 @@ Error object::extractCodeObject(const ObjectFile &Source, int64_t Offset, return Error::success(); } +Error object::extractCodeObject(const MemoryBufferRef Buffer, int64_t Offset, + int64_t Size, StringRef OutputFileName) { + Expected<std::unique_ptr<FileOutputBuffer>> BufferOrErr = + FileOutputBuffer::create(OutputFileName, Size); + if (!BufferOrErr) + return BufferOrErr.takeError(); + + std::unique_ptr<FileOutputBuffer> Buf = std::move(*BufferOrErr); + std::copy(Buffer.getBufferStart() + Offset, + Buffer.getBufferStart() + Offset + Size, Buf->getBufferStart()); + + return Buf->commit(); +} + // given a file name, offset, and size, extract data into a code object file, -// into file <SourceFile>-offset<Offset>-size<Size>.co +// into file "<SourceFile>-offset<Offset>-size<Size>.co". Error object::extractOffloadBundleByURI(StringRef URIstr) { // create a URI object Expected<std::unique_ptr<OffloadBundleURI>> UriOrErr( @@ -236,7 +267,7 @@ Error object::extractOffloadBundleByURI(StringRef URIstr) { OutputFile += "-offset" + itostr(Uri.Offset) + "-size" + itostr(Uri.Size) + ".co"; - // Create an ObjectFile object from uri.file_uri + // Create an ObjectFile object from uri.file_uri. auto ObjOrErr = ObjectFile::createObjectFile(Uri.FileName); if (!ObjOrErr) return ObjOrErr.takeError(); @@ -249,7 +280,7 @@ Error object::extractOffloadBundleByURI(StringRef URIstr) { return Error::success(); } -// Utility function to format numbers with commas +// Utility function to format numbers with commas. static std::string formatWithCommas(unsigned long long Value) { std::string Num = std::to_string(Value); int InsertPosition = Num.length() - 3; @@ -260,87 +291,278 @@ static std::string formatWithCommas(unsigned long long Value) { return Num; } -llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> -CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input, - bool Verbose) { - StringRef Blob = Input.getBuffer(); +Expected<std::unique_ptr<MemoryBuffer>> +CompressedOffloadBundle::compress(compression::Params P, + const MemoryBuffer &Input, uint16_t Version, + raw_ostream *VerboseStream) { + if (!compression::zstd::isAvailable() && !compression::zlib::isAvailable()) + return createStringError("compression not supported."); + Timer HashTimer("Hash Calculation Timer", "Hash calculation time", + OffloadBundlerTimerGroup); + if (VerboseStream) + HashTimer.startTimer(); + MD5 Hash; + MD5::MD5Result Result; + Hash.update(Input.getBuffer()); + Hash.final(Result); + uint64_t TruncatedHash = Result.low(); + if (VerboseStream) + HashTimer.stopTimer(); + + SmallVector<uint8_t, 0> CompressedBuffer; + auto BufferUint8 = ArrayRef<uint8_t>( + reinterpret_cast<const uint8_t *>(Input.getBuffer().data()), + Input.getBuffer().size()); + Timer CompressTimer("Compression Timer", "Compression time", + OffloadBundlerTimerGroup); + if (VerboseStream) + CompressTimer.startTimer(); + compression::compress(P, BufferUint8, CompressedBuffer); + if (VerboseStream) + CompressTimer.stopTimer(); + + uint16_t CompressionMethod = static_cast<uint16_t>(P.format); + + // Store sizes in 64-bit variables first. + uint64_t UncompressedSize64 = Input.getBuffer().size(); + uint64_t TotalFileSize64; + + // Calculate total file size based on version. + if (Version == 2) { + // For V2, ensure the sizes don't exceed 32-bit limit. + if (UncompressedSize64 > std::numeric_limits<uint32_t>::max()) + return createStringError("uncompressed size (%llu) exceeds version 2 " + "unsigned 32-bit integer limit", + UncompressedSize64); + TotalFileSize64 = MagicNumber.size() + sizeof(uint32_t) + sizeof(Version) + + sizeof(CompressionMethod) + sizeof(uint32_t) + + sizeof(TruncatedHash) + CompressedBuffer.size(); + if (TotalFileSize64 > std::numeric_limits<uint32_t>::max()) + return createStringError("total file size (%llu) exceeds version 2 " + "unsigned 32-bit integer limit", + TotalFileSize64); + + } else { // Version 3. + TotalFileSize64 = MagicNumber.size() + sizeof(uint64_t) + sizeof(Version) + + sizeof(CompressionMethod) + sizeof(uint64_t) + + sizeof(TruncatedHash) + CompressedBuffer.size(); + } + + SmallVector<char, 0> FinalBuffer; + raw_svector_ostream OS(FinalBuffer); + OS << MagicNumber; + OS.write(reinterpret_cast<const char *>(&Version), sizeof(Version)); + OS.write(reinterpret_cast<const char *>(&CompressionMethod), + sizeof(CompressionMethod)); + + // Write size fields according to version. + if (Version == 2) { + uint32_t TotalFileSize32 = static_cast<uint32_t>(TotalFileSize64); + uint32_t UncompressedSize32 = static_cast<uint32_t>(UncompressedSize64); + OS.write(reinterpret_cast<const char *>(&TotalFileSize32), + sizeof(TotalFileSize32)); + OS.write(reinterpret_cast<const char *>(&UncompressedSize32), + sizeof(UncompressedSize32)); + } else { // Version 3. + OS.write(reinterpret_cast<const char *>(&TotalFileSize64), + sizeof(TotalFileSize64)); + OS.write(reinterpret_cast<const char *>(&UncompressedSize64), + sizeof(UncompressedSize64)); + } + + OS.write(reinterpret_cast<const char *>(&TruncatedHash), + sizeof(TruncatedHash)); + OS.write(reinterpret_cast<const char *>(CompressedBuffer.data()), + CompressedBuffer.size()); + + if (VerboseStream) { + auto MethodUsed = P.format == compression::Format::Zstd ? "zstd" : "zlib"; + double CompressionRate = + static_cast<double>(UncompressedSize64) / CompressedBuffer.size(); + double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime(); + double CompressionSpeedMBs = + (UncompressedSize64 / (1024.0 * 1024.0)) / CompressionTimeSeconds; + *VerboseStream << "Compressed bundle format version: " << Version << "\n" + << "Total file size (including headers): " + << formatWithCommas(TotalFileSize64) << " bytes\n" + << "Compression method used: " << MethodUsed << "\n" + << "Compression level: " << P.level << "\n" + << "Binary size before compression: " + << formatWithCommas(UncompressedSize64) << " bytes\n" + << "Binary size after compression: " + << formatWithCommas(CompressedBuffer.size()) << " bytes\n" + << "Compression rate: " << format("%.2lf", CompressionRate) + << "\n" + << "Compression ratio: " + << format("%.2lf%%", 100.0 / CompressionRate) << "\n" + << "Compression speed: " + << format("%.2lf MB/s", CompressionSpeedMBs) << "\n" + << "Truncated MD5 hash: " << format_hex(TruncatedHash, 16) + << "\n"; + } + + return MemoryBuffer::getMemBufferCopy( + StringRef(FinalBuffer.data(), FinalBuffer.size())); +} + +// Use packed structs to avoid padding, such that the structs map the serialized +// format. +LLVM_PACKED_START +union RawCompressedBundleHeader { + struct CommonFields { + uint32_t Magic; + uint16_t Version; + uint16_t Method; + }; + + struct V1Header { + CommonFields Common; + uint32_t UncompressedFileSize; + uint64_t Hash; + }; + + struct V2Header { + CommonFields Common; + uint32_t FileSize; + uint32_t UncompressedFileSize; + uint64_t Hash; + }; + + struct V3Header { + CommonFields Common; + uint64_t FileSize; + uint64_t UncompressedFileSize; + uint64_t Hash; + }; + + CommonFields Common; + V1Header V1; + V2Header V2; + V3Header V3; +}; +LLVM_PACKED_END + +// Helper method to get header size based on version. +static size_t getHeaderSize(uint16_t Version) { + switch (Version) { + case 1: + return sizeof(RawCompressedBundleHeader::V1Header); + case 2: + return sizeof(RawCompressedBundleHeader::V2Header); + case 3: + return sizeof(RawCompressedBundleHeader::V3Header); + default: + llvm_unreachable("Unsupported version"); + } +} - if (Blob.size() < V1HeaderSize) - return llvm::MemoryBuffer::getMemBufferCopy(Blob); +Expected<CompressedOffloadBundle::CompressedBundleHeader> +CompressedOffloadBundle::CompressedBundleHeader::tryParse(StringRef Blob) { + assert(Blob.size() >= sizeof(RawCompressedBundleHeader::CommonFields)); + assert(identify_magic(Blob) == file_magic::offload_bundle_compressed); + + RawCompressedBundleHeader Header; + std::memcpy(&Header, Blob.data(), std::min(Blob.size(), sizeof(Header))); + + CompressedBundleHeader Normalized; + Normalized.Version = Header.Common.Version; + + size_t RequiredSize = getHeaderSize(Normalized.Version); + + if (Blob.size() < RequiredSize) + return createStringError("compressed bundle header size too small"); + + switch (Normalized.Version) { + case 1: + Normalized.UncompressedFileSize = Header.V1.UncompressedFileSize; + Normalized.Hash = Header.V1.Hash; + break; + case 2: + Normalized.FileSize = Header.V2.FileSize; + Normalized.UncompressedFileSize = Header.V2.UncompressedFileSize; + Normalized.Hash = Header.V2.Hash; + break; + case 3: + Normalized.FileSize = Header.V3.FileSize; + Normalized.UncompressedFileSize = Header.V3.UncompressedFileSize; + Normalized.Hash = Header.V3.Hash; + break; + default: + return createStringError("unknown compressed bundle version"); + } - if (llvm::identify_magic(Blob) != - llvm::file_magic::offload_bundle_compressed) { - if (Verbose) - llvm::errs() << "Uncompressed bundle.\n"; - return llvm::MemoryBuffer::getMemBufferCopy(Blob); + // Determine compression format. + switch (Header.Common.Method) { + case static_cast<uint16_t>(compression::Format::Zlib): + case static_cast<uint16_t>(compression::Format::Zstd): + Normalized.CompressionFormat = + static_cast<compression::Format>(Header.Common.Method); + break; + default: + return createStringError("unknown compressing method"); } - size_t CurrentOffset = MagicSize; + return Normalized; +} - uint16_t ThisVersion; - memcpy(&ThisVersion, Blob.data() + CurrentOffset, sizeof(uint16_t)); - CurrentOffset += VersionFieldSize; +Expected<std::unique_ptr<MemoryBuffer>> +CompressedOffloadBundle::decompress(const MemoryBuffer &Input, + raw_ostream *VerboseStream) { + StringRef Blob = Input.getBuffer(); - uint16_t CompressionMethod; - memcpy(&CompressionMethod, Blob.data() + CurrentOffset, sizeof(uint16_t)); - CurrentOffset += MethodFieldSize; + // Check minimum header size (using V1 as it's the smallest). + if (Blob.size() < sizeof(RawCompressedBundleHeader::CommonFields)) + return MemoryBuffer::getMemBufferCopy(Blob); - uint32_t TotalFileSize; - if (ThisVersion >= 2) { - if (Blob.size() < V2HeaderSize) - return createStringError(inconvertibleErrorCode(), - "Compressed bundle header size too small"); - memcpy(&TotalFileSize, Blob.data() + CurrentOffset, sizeof(uint32_t)); - CurrentOffset += FileSizeFieldSize; + if (identify_magic(Blob) != file_magic::offload_bundle_compressed) { + if (VerboseStream) + *VerboseStream << "Uncompressed bundle\n"; + return MemoryBuffer::getMemBufferCopy(Blob); } - uint32_t UncompressedSize; - memcpy(&UncompressedSize, Blob.data() + CurrentOffset, sizeof(uint32_t)); - CurrentOffset += UncompressedSizeFieldSize; - - uint64_t StoredHash; - memcpy(&StoredHash, Blob.data() + CurrentOffset, sizeof(uint64_t)); - CurrentOffset += HashFieldSize; - - llvm::compression::Format CompressionFormat; - if (CompressionMethod == - static_cast<uint16_t>(llvm::compression::Format::Zlib)) - CompressionFormat = llvm::compression::Format::Zlib; - else if (CompressionMethod == - static_cast<uint16_t>(llvm::compression::Format::Zstd)) - CompressionFormat = llvm::compression::Format::Zstd; - else - return createStringError(inconvertibleErrorCode(), - "Unknown compressing method"); - - llvm::Timer DecompressTimer("Decompression Timer", "Decompression time", - OffloadBundlerTimerGroup); - if (Verbose) + Expected<CompressedBundleHeader> HeaderOrErr = + CompressedBundleHeader::tryParse(Blob); + if (!HeaderOrErr) + return HeaderOrErr.takeError(); + + const CompressedBundleHeader &Normalized = *HeaderOrErr; + unsigned ThisVersion = Normalized.Version; + size_t HeaderSize = getHeaderSize(ThisVersion); + + compression::Format CompressionFormat = Normalized.CompressionFormat; + + size_t TotalFileSize = Normalized.FileSize.value_or(0); + size_t UncompressedSize = Normalized.UncompressedFileSize; + auto StoredHash = Normalized.Hash; + + Timer DecompressTimer("Decompression Timer", "Decompression time", + OffloadBundlerTimerGroup); + if (VerboseStream) DecompressTimer.startTimer(); SmallVector<uint8_t, 0> DecompressedData; - StringRef CompressedData = Blob.substr(CurrentOffset); - if (llvm::Error DecompressionError = llvm::compression::decompress( - CompressionFormat, llvm::arrayRefFromStringRef(CompressedData), + StringRef CompressedData = + Blob.substr(HeaderSize, TotalFileSize - HeaderSize); + + if (Error DecompressionError = compression::decompress( + CompressionFormat, arrayRefFromStringRef(CompressedData), DecompressedData, UncompressedSize)) - return createStringError(inconvertibleErrorCode(), - "Could not decompress embedded file contents: " + - llvm::toString(std::move(DecompressionError))); + return createStringError("could not decompress embedded file contents: " + + toString(std::move(DecompressionError))); - if (Verbose) { + if (VerboseStream) { DecompressTimer.stopTimer(); double DecompressionTimeSeconds = DecompressTimer.getTotalTime().getWallTime(); // Recalculate MD5 hash for integrity check. - llvm::Timer HashRecalcTimer("Hash Recalculation Timer", - "Hash recalculation time", - OffloadBundlerTimerGroup); + Timer HashRecalcTimer("Hash Recalculation Timer", "Hash recalculation time", + OffloadBundlerTimerGroup); HashRecalcTimer.startTimer(); - llvm::MD5 Hash; - llvm::MD5::MD5Result Result; - Hash.update(llvm::ArrayRef<uint8_t>(DecompressedData)); + MD5 Hash; + MD5::MD5Result Result; + Hash.update(ArrayRef<uint8_t>(DecompressedData)); Hash.final(Result); uint64_t RecalculatedHash = Result.low(); HashRecalcTimer.stopTimer(); @@ -351,118 +573,28 @@ CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input, double DecompressionSpeedMBs = (UncompressedSize / (1024.0 * 1024.0)) / DecompressionTimeSeconds; - llvm::errs() << "Compressed bundle format version: " << ThisVersion << "\n"; + *VerboseStream << "Compressed bundle format version: " << ThisVersion + << "\n"; if (ThisVersion >= 2) - llvm::errs() << "Total file size (from header): " - << formatWithCommas(TotalFileSize) << " bytes\n"; - llvm::errs() << "Decompression method: " - << (CompressionFormat == llvm::compression::Format::Zlib - ? "zlib" - : "zstd") - << "\n" - << "Size before decompression: " - << formatWithCommas(CompressedData.size()) << " bytes\n" - << "Size after decompression: " - << formatWithCommas(UncompressedSize) << " bytes\n" - << "Compression rate: " - << llvm::format("%.2lf", CompressionRate) << "\n" - << "Compression ratio: " - << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n" - << "Decompression speed: " - << llvm::format("%.2lf MB/s", DecompressionSpeedMBs) << "\n" - << "Stored hash: " << llvm::format_hex(StoredHash, 16) << "\n" - << "Recalculated hash: " - << llvm::format_hex(RecalculatedHash, 16) << "\n" - << "Hashes match: " << (HashMatch ? "Yes" : "No") << "\n"; + *VerboseStream << "Total file size (from header): " + << formatWithCommas(TotalFileSize) << " bytes\n"; + *VerboseStream + << "Decompression method: " + << (CompressionFormat == compression::Format::Zlib ? "zlib" : "zstd") + << "\n" + << "Size before decompression: " + << formatWithCommas(CompressedData.size()) << " bytes\n" + << "Size after decompression: " << formatWithCommas(UncompressedSize) + << " bytes\n" + << "Compression rate: " << format("%.2lf", CompressionRate) << "\n" + << "Compression ratio: " << format("%.2lf%%", 100.0 / CompressionRate) + << "\n" + << "Decompression speed: " + << format("%.2lf MB/s", DecompressionSpeedMBs) << "\n" + << "Stored hash: " << format_hex(StoredHash, 16) << "\n" + << "Recalculated hash: " << format_hex(RecalculatedHash, 16) << "\n" + << "Hashes match: " << (HashMatch ? "Yes" : "No") << "\n"; } - return llvm::MemoryBuffer::getMemBufferCopy( - llvm::toStringRef(DecompressedData)); -} - -llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> -CompressedOffloadBundle::compress(llvm::compression::Params P, - const llvm::MemoryBuffer &Input, - bool Verbose) { - if (!llvm::compression::zstd::isAvailable() && - !llvm::compression::zlib::isAvailable()) - return createStringError(llvm::inconvertibleErrorCode(), - "Compression not supported"); - - llvm::Timer HashTimer("Hash Calculation Timer", "Hash calculation time", - OffloadBundlerTimerGroup); - if (Verbose) - HashTimer.startTimer(); - llvm::MD5 Hash; - llvm::MD5::MD5Result Result; - Hash.update(Input.getBuffer()); - Hash.final(Result); - uint64_t TruncatedHash = Result.low(); - if (Verbose) - HashTimer.stopTimer(); - - SmallVector<uint8_t, 0> CompressedBuffer; - auto BufferUint8 = llvm::ArrayRef<uint8_t>( - reinterpret_cast<const uint8_t *>(Input.getBuffer().data()), - Input.getBuffer().size()); - - llvm::Timer CompressTimer("Compression Timer", "Compression time", - OffloadBundlerTimerGroup); - if (Verbose) - CompressTimer.startTimer(); - llvm::compression::compress(P, BufferUint8, CompressedBuffer); - if (Verbose) - CompressTimer.stopTimer(); - - uint16_t CompressionMethod = static_cast<uint16_t>(P.format); - uint32_t UncompressedSize = Input.getBuffer().size(); - uint32_t TotalFileSize = MagicNumber.size() + sizeof(TotalFileSize) + - sizeof(Version) + sizeof(CompressionMethod) + - sizeof(UncompressedSize) + sizeof(TruncatedHash) + - CompressedBuffer.size(); - - SmallVector<char, 0> FinalBuffer; - llvm::raw_svector_ostream OS(FinalBuffer); - OS << MagicNumber; - OS.write(reinterpret_cast<const char *>(&Version), sizeof(Version)); - OS.write(reinterpret_cast<const char *>(&CompressionMethod), - sizeof(CompressionMethod)); - OS.write(reinterpret_cast<const char *>(&TotalFileSize), - sizeof(TotalFileSize)); - OS.write(reinterpret_cast<const char *>(&UncompressedSize), - sizeof(UncompressedSize)); - OS.write(reinterpret_cast<const char *>(&TruncatedHash), - sizeof(TruncatedHash)); - OS.write(reinterpret_cast<const char *>(CompressedBuffer.data()), - CompressedBuffer.size()); - - if (Verbose) { - auto MethodUsed = - P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib"; - double CompressionRate = - static_cast<double>(UncompressedSize) / CompressedBuffer.size(); - double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime(); - double CompressionSpeedMBs = - (UncompressedSize / (1024.0 * 1024.0)) / CompressionTimeSeconds; - - llvm::errs() << "Compressed bundle format version: " << Version << "\n" - << "Total file size (including headers): " - << formatWithCommas(TotalFileSize) << " bytes\n" - << "Compression method used: " << MethodUsed << "\n" - << "Compression level: " << P.level << "\n" - << "Binary size before compression: " - << formatWithCommas(UncompressedSize) << " bytes\n" - << "Binary size after compression: " - << formatWithCommas(CompressedBuffer.size()) << " bytes\n" - << "Compression rate: " - << llvm::format("%.2lf", CompressionRate) << "\n" - << "Compression ratio: " - << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n" - << "Compression speed: " - << llvm::format("%.2lf MB/s", CompressionSpeedMBs) << "\n" - << "Truncated MD5 hash: " - << llvm::format_hex(TruncatedHash, 16) << "\n"; - } - return llvm::MemoryBuffer::getMemBufferCopy( - llvm::StringRef(FinalBuffer.data(), FinalBuffer.size())); + return MemoryBuffer::getMemBufferCopy(toStringRef(DecompressedData)); } diff --git a/llvm/lib/Option/ArgList.cpp b/llvm/lib/Option/ArgList.cpp index c4188b3b..2f4e212 100644 --- a/llvm/lib/Option/ArgList.cpp +++ b/llvm/lib/Option/ArgList.cpp @@ -14,12 +14,14 @@ #include "llvm/Config/llvm-config.h" #include "llvm/Option/Arg.h" #include "llvm/Option/OptSpecifier.h" +#include "llvm/Option/OptTable.h" #include "llvm/Option/Option.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include <algorithm> #include <cassert> +#include <cstddef> #include <memory> #include <string> #include <utility> @@ -202,6 +204,42 @@ void ArgList::print(raw_ostream &O) const { LLVM_DUMP_METHOD void ArgList::dump() const { print(dbgs()); } #endif +StringRef ArgList::getSubCommand( + ArrayRef<OptTable::SubCommand> AllSubCommands, + std::function<void(ArrayRef<StringRef>)> HandleMultipleSubcommands, + std::function<void(ArrayRef<StringRef>)> HandleOtherPositionals) const { + + SmallVector<StringRef, 4> SubCommands; + SmallVector<StringRef, 4> OtherPositionals; + for (const Arg *A : *this) { + if (A->getOption().getKind() != Option::InputClass) + continue; + + size_t OldSize = SubCommands.size(); + for (const OptTable::SubCommand &CMD : AllSubCommands) { + if (StringRef(CMD.Name) == A->getValue()) + SubCommands.push_back(A->getValue()); + } + + if (SubCommands.size() == OldSize) + OtherPositionals.push_back(A->getValue()); + } + + // Invoke callbacks if necessary. + if (SubCommands.size() > 1) { + HandleMultipleSubcommands(SubCommands); + return {}; + } + if (!OtherPositionals.empty()) { + HandleOtherPositionals(OtherPositionals); + return {}; + } + + if (SubCommands.size() == 1) + return SubCommands.front(); + return {}; // No valid usage of subcommand found. +} + void InputArgList::releaseMemory() { // An InputArgList always owns its arguments. for (Arg *A : *this) diff --git a/llvm/lib/Option/OptTable.cpp b/llvm/lib/Option/OptTable.cpp index 6d10e61..14e3b0d 100644 --- a/llvm/lib/Option/OptTable.cpp +++ b/llvm/lib/Option/OptTable.cpp @@ -79,9 +79,12 @@ OptSpecifier::OptSpecifier(const Option *Opt) : ID(Opt->getID()) {} OptTable::OptTable(const StringTable &StrTable, ArrayRef<StringTable::Offset> PrefixesTable, - ArrayRef<Info> OptionInfos, bool IgnoreCase) + ArrayRef<Info> OptionInfos, bool IgnoreCase, + ArrayRef<SubCommand> SubCommands, + ArrayRef<unsigned> SubCommandIDsTable) : StrTable(&StrTable), PrefixesTable(PrefixesTable), - OptionInfos(OptionInfos), IgnoreCase(IgnoreCase) { + OptionInfos(OptionInfos), IgnoreCase(IgnoreCase), + SubCommands(SubCommands), SubCommandIDsTable(SubCommandIDsTable) { // Explicitly zero initialize the error to work around a bug in array // value-initialization on MinGW with gcc 4.3.5. @@ -715,9 +718,10 @@ static const char *getOptionHelpGroup(const OptTable &Opts, OptSpecifier Id) { void OptTable::printHelp(raw_ostream &OS, const char *Usage, const char *Title, bool ShowHidden, bool ShowAllAliases, - Visibility VisibilityMask) const { + Visibility VisibilityMask, + StringRef SubCommand) const { return internalPrintHelp( - OS, Usage, Title, ShowHidden, ShowAllAliases, + OS, Usage, Title, SubCommand, ShowHidden, ShowAllAliases, [VisibilityMask](const Info &CandidateInfo) -> bool { return (CandidateInfo.Visibility & VisibilityMask) == 0; }, @@ -730,7 +734,7 @@ void OptTable::printHelp(raw_ostream &OS, const char *Usage, const char *Title, bool ShowHidden = !(FlagsToExclude & HelpHidden); FlagsToExclude &= ~HelpHidden; return internalPrintHelp( - OS, Usage, Title, ShowHidden, ShowAllAliases, + OS, Usage, Title, /*SubCommand=*/{}, ShowHidden, ShowAllAliases, [FlagsToInclude, FlagsToExclude](const Info &CandidateInfo) { if (FlagsToInclude && !(CandidateInfo.Flags & FlagsToInclude)) return true; @@ -742,16 +746,62 @@ void OptTable::printHelp(raw_ostream &OS, const char *Usage, const char *Title, } void OptTable::internalPrintHelp( - raw_ostream &OS, const char *Usage, const char *Title, bool ShowHidden, - bool ShowAllAliases, std::function<bool(const Info &)> ExcludeOption, + raw_ostream &OS, const char *Usage, const char *Title, StringRef SubCommand, + bool ShowHidden, bool ShowAllAliases, + std::function<bool(const Info &)> ExcludeOption, Visibility VisibilityMask) const { OS << "OVERVIEW: " << Title << "\n\n"; - OS << "USAGE: " << Usage << "\n\n"; // Render help text into a map of group-name to a list of (option, help) // pairs. std::map<std::string, std::vector<OptionInfo>> GroupedOptionHelp; + auto ActiveSubCommand = + std::find_if(SubCommands.begin(), SubCommands.end(), + [&](const auto &C) { return SubCommand == C.Name; }); + if (!SubCommand.empty()) { + assert(ActiveSubCommand != SubCommands.end() && + "Not a valid registered subcommand."); + OS << ActiveSubCommand->HelpText << "\n\n"; + if (!StringRef(ActiveSubCommand->Usage).empty()) + OS << "USAGE: " << ActiveSubCommand->Usage << "\n\n"; + } else { + OS << "USAGE: " << Usage << "\n\n"; + if (SubCommands.size() > 1) { + OS << "SUBCOMMANDS:\n\n"; + for (const auto &C : SubCommands) + OS << C.Name << " - " << C.HelpText << "\n"; + OS << "\n"; + } + } + + auto DoesOptionBelongToSubcommand = [&](const Info &CandidateInfo) { + // Retrieve the SubCommandIDs registered to the given current CandidateInfo + // Option. + ArrayRef<unsigned> SubCommandIDs = + CandidateInfo.getSubCommandIDs(SubCommandIDsTable); + + // If no registered subcommands, then only global options are to be printed. + // If no valid SubCommand (empty) in commandline then print the current + // global CandidateInfo Option. + if (SubCommandIDs.empty()) + return SubCommand.empty(); + + // Handle CandidateInfo Option which has at least one registered SubCommand. + // If no valid SubCommand (empty) in commandline, this CandidateInfo option + // should not be printed. + if (SubCommand.empty()) + return false; + + // Find the ID of the valid subcommand passed in commandline (its index in + // the SubCommands table which contains all subcommands). + unsigned ActiveSubCommandID = ActiveSubCommand - &SubCommands[0]; + // Print if the ActiveSubCommandID is registered with the CandidateInfo + // Option. + return std::find(SubCommandIDs.begin(), SubCommandIDs.end(), + ActiveSubCommandID) != SubCommandIDs.end(); + }; + for (unsigned Id = 1, e = getNumOptions() + 1; Id != e; ++Id) { // FIXME: Split out option groups. if (getOptionKind(Id) == Option::GroupClass) @@ -764,6 +814,9 @@ void OptTable::internalPrintHelp( if (ExcludeOption(CandidateInfo)) continue; + if (!DoesOptionBelongToSubcommand(CandidateInfo)) + continue; + // If an alias doesn't have a help text, show a help text for the aliased // option instead. const char *HelpText = getOptionHelpText(Id, VisibilityMask); @@ -791,8 +844,11 @@ void OptTable::internalPrintHelp( GenericOptTable::GenericOptTable(const StringTable &StrTable, ArrayRef<StringTable::Offset> PrefixesTable, - ArrayRef<Info> OptionInfos, bool IgnoreCase) - : OptTable(StrTable, PrefixesTable, OptionInfos, IgnoreCase) { + ArrayRef<Info> OptionInfos, bool IgnoreCase, + ArrayRef<SubCommand> SubCommands, + ArrayRef<unsigned> SubCommandIDsTable) + : OptTable(StrTable, PrefixesTable, OptionInfos, IgnoreCase, SubCommands, + SubCommandIDsTable) { std::set<StringRef> TmpPrefixesUnion; for (auto const &Info : OptionInfos.drop_front(FirstSearchableIndex)) diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp index 7069e8d..119caea 100644 --- a/llvm/lib/Passes/PassBuilderPipelines.cpp +++ b/llvm/lib/Passes/PassBuilderPipelines.cpp @@ -1960,6 +1960,7 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level, // is fixed. MPM.addPass(WholeProgramDevirtPass(ExportSummary, nullptr)); + MPM.addPass(NoRecurseLTOInferencePass()); // Stop here at -O1. if (Level == OptimizationLevel::O1) { // The LowerTypeTestsPass needs to run to lower type metadata and the diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def index f0e7d36..88550ea 100644 --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -119,6 +119,7 @@ MODULE_PASS("metarenamer", MetaRenamerPass()) MODULE_PASS("module-inline", ModuleInlinerPass()) MODULE_PASS("name-anon-globals", NameAnonGlobalPass()) MODULE_PASS("no-op-module", NoOpModulePass()) +MODULE_PASS("norecurse-lto-inference", NoRecurseLTOInferencePass()) MODULE_PASS("nsan", NumericalStabilitySanitizerPass()) MODULE_PASS("openmp-opt", OpenMPOptPass()) MODULE_PASS("openmp-opt-postlink", diff --git a/llvm/lib/Support/GlobPattern.cpp b/llvm/lib/Support/GlobPattern.cpp index 7004adf..0ecf47d 100644 --- a/llvm/lib/Support/GlobPattern.cpp +++ b/llvm/lib/Support/GlobPattern.cpp @@ -143,6 +143,15 @@ GlobPattern::create(StringRef S, std::optional<size_t> MaxSubPatterns) { return Pat; S = S.substr(PrefixSize); + // Just in case we stop on unmatched opening brackets. + size_t SuffixStart = S.find_last_of("?*[]{}\\"); + assert(SuffixStart != std::string::npos); + if (S[SuffixStart] == '\\') + ++SuffixStart; + ++SuffixStart; + Pat.Suffix = S.substr(SuffixStart); + S = S.substr(0, SuffixStart); + SmallVector<std::string, 1> SubPats; if (auto Err = parseBraceExpansions(S, MaxSubPatterns).moveInto(SubPats)) return std::move(Err); @@ -193,6 +202,8 @@ GlobPattern::SubGlobPattern::create(StringRef S) { bool GlobPattern::match(StringRef S) const { if (!S.consume_front(Prefix)) return false; + if (!S.consume_back(Suffix)) + return false; if (SubGlobs.empty() && S.empty()) return true; for (auto &Glob : SubGlobs) diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp index b98f797..c76689f 100644 --- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp @@ -2878,7 +2878,7 @@ static SVEStackSizes determineSVEStackSizes(MachineFunction &MF, StackTop += MFI.getObjectSize(FI); StackTop = alignTo(StackTop, Alignment); - assert(StackTop < std::numeric_limits<int64_t>::max() && + assert(StackTop < (uint64_t)std::numeric_limits<int64_t>::max() && "SVE StackTop far too large?!"); int64_t Offset = -int64_t(StackTop); diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 50a8754..479e345 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -5666,18 +5666,21 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost( VectorType *AccumVectorType = VectorType::get(AccumType, VF.divideCoefficientBy(Ratio)); // We don't yet support all kinds of legalization. - auto TA = TLI->getTypeAction(AccumVectorType->getContext(), - EVT::getEVT(AccumVectorType)); - switch (TA) { + auto TC = TLI->getTypeConversion(AccumVectorType->getContext(), + EVT::getEVT(AccumVectorType)); + switch (TC.first) { default: return Invalid; case TargetLowering::TypeLegal: case TargetLowering::TypePromoteInteger: case TargetLowering::TypeSplitVector: + // The legalised type (e.g. after splitting) must be legal too. + if (TLI->getTypeAction(AccumVectorType->getContext(), TC.second) != + TargetLowering::TypeLegal) + return Invalid; break; } - // Check what kind of type-legalisation happens. std::pair<InstructionCost, MVT> AccumLT = getTypeLegalizationCost(AccumVectorType); std::pair<InstructionCost, MVT> InputLT = diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td index 6b3c151..1a697f7 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPU.td +++ b/llvm/lib/Target/AMDGPU/AMDGPU.td @@ -1448,10 +1448,10 @@ def Feature45BitNumRecordsBufferResource : SubtargetFeature< "45-bit-num-records "The buffer resource (V#) supports 45-bit num_records" >; -def FeatureCluster : SubtargetFeature< "cluster", - "HasCluster", +def FeatureClusters : SubtargetFeature< "clusters", + "HasClusters", "true", - "Has cluster support" + "Has clusters of workgroups support" >; // Dummy feature used to disable assembler instructions. @@ -2120,7 +2120,7 @@ def FeatureISAVersion12_50 : FeatureSet< Feature45BitNumRecordsBufferResource, FeatureSupportsXNACK, FeatureXNACK, - FeatureCluster, + FeatureClusters, ]>; def FeatureISAVersion12_51 : FeatureSet< diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp index 848d9a5..557d87f 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp @@ -5043,6 +5043,9 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case Intrinsic::amdgcn_mfma_i32_16x16x64_i8: case Intrinsic::amdgcn_mfma_i32_32x32x32_i8: case Intrinsic::amdgcn_mfma_f32_16x16x32_bf16: { + unsigned DstSize = MRI.getType(MI.getOperand(0).getReg()).getSizeInBits(); + unsigned MinNumRegsRequired = DstSize / 32; + // Default for MAI intrinsics. // srcC can also be an immediate which can be folded later. // FIXME: Should we eventually add an alternative mapping with AGPR src @@ -5051,29 +5054,32 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { // vdst, srcA, srcB, srcC const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>(); OpdsMapping[0] = - Info->mayNeedAGPRs() + Info->getMinNumAGPRs() >= MinNumRegsRequired ? getAGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI) : getVGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI); OpdsMapping[2] = getVGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI); OpdsMapping[3] = getVGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI); OpdsMapping[4] = - Info->mayNeedAGPRs() + Info->getMinNumAGPRs() >= MinNumRegsRequired ? getAGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI) : getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI); break; } case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4: case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: { + unsigned DstSize = MRI.getType(MI.getOperand(0).getReg()).getSizeInBits(); + unsigned MinNumRegsRequired = DstSize / 32; + const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>(); OpdsMapping[0] = - Info->mayNeedAGPRs() + Info->getMinNumAGPRs() >= MinNumRegsRequired ? getAGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI) : getVGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI); OpdsMapping[2] = getVGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI); OpdsMapping[3] = getVGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI); OpdsMapping[4] = - Info->mayNeedAGPRs() + Info->getMinNumAGPRs() >= MinNumRegsRequired ? getAGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI) : getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI); diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp index a67a7be..d0c0822 100644 --- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp +++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp @@ -1944,6 +1944,7 @@ public: void cvtVOP3Interp(MCInst &Inst, const OperandVector &Operands); void cvtVINTERP(MCInst &Inst, const OperandVector &Operands); + void cvtOpSelHelper(MCInst &Inst, unsigned OpSel); bool parseDimId(unsigned &Encoding); ParseStatus parseDim(OperandVector &Operands); @@ -9239,6 +9240,33 @@ static bool isRegOrImmWithInputMods(const MCInstrDesc &Desc, unsigned OpNum) { MCOI::OperandConstraint::TIED_TO) == -1; } +void AMDGPUAsmParser::cvtOpSelHelper(MCInst &Inst, unsigned OpSel) { + unsigned Opc = Inst.getOpcode(); + constexpr AMDGPU::OpName Ops[] = {AMDGPU::OpName::src0, AMDGPU::OpName::src1, + AMDGPU::OpName::src2}; + constexpr AMDGPU::OpName ModOps[] = {AMDGPU::OpName::src0_modifiers, + AMDGPU::OpName::src1_modifiers, + AMDGPU::OpName::src2_modifiers}; + for (int J = 0; J < 3; ++J) { + int OpIdx = AMDGPU::getNamedOperandIdx(Opc, Ops[J]); + if (OpIdx == -1) + // Some instructions, e.g. v_interp_p2_f16 in GFX9, have src0, src2, but + // no src1. So continue instead of break. + continue; + + int ModIdx = AMDGPU::getNamedOperandIdx(Opc, ModOps[J]); + uint32_t ModVal = Inst.getOperand(ModIdx).getImm(); + + if ((OpSel & (1 << J)) != 0) + ModVal |= SISrcMods::OP_SEL_0; + // op_sel[3] is encoded in src0_modifiers. + if (ModOps[J] == AMDGPU::OpName::src0_modifiers && (OpSel & (1 << 3)) != 0) + ModVal |= SISrcMods::DST_OP_SEL; + + Inst.getOperand(ModIdx).setImm(ModVal); + } +} + void AMDGPUAsmParser::cvtVOP3Interp(MCInst &Inst, const OperandVector &Operands) { OptionalImmIndexMap OptionalIdx; @@ -9275,6 +9303,16 @@ void AMDGPUAsmParser::cvtVOP3Interp(MCInst &Inst, const OperandVector &Operands) if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::omod)) addOptionalImmOperand(Inst, Operands, OptionalIdx, AMDGPUOperand::ImmTyOModSI); + + // Some v_interp instructions use op_sel[3] for dst. + if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::op_sel)) { + addOptionalImmOperand(Inst, Operands, OptionalIdx, + AMDGPUOperand::ImmTyOpSel); + int OpSelIdx = AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::op_sel); + unsigned OpSel = Inst.getOperand(OpSelIdx).getImm(); + + cvtOpSelHelper(Inst, OpSel); + } } void AMDGPUAsmParser::cvtVINTERP(MCInst &Inst, const OperandVector &Operands) @@ -9310,31 +9348,10 @@ void AMDGPUAsmParser::cvtVINTERP(MCInst &Inst, const OperandVector &Operands) if (OpSelIdx == -1) return; - const AMDGPU::OpName Ops[] = {AMDGPU::OpName::src0, AMDGPU::OpName::src1, - AMDGPU::OpName::src2}; - const AMDGPU::OpName ModOps[] = {AMDGPU::OpName::src0_modifiers, - AMDGPU::OpName::src1_modifiers, - AMDGPU::OpName::src2_modifiers}; - unsigned OpSel = Inst.getOperand(OpSelIdx).getImm(); - - for (int J = 0; J < 3; ++J) { - int OpIdx = AMDGPU::getNamedOperandIdx(Opc, Ops[J]); - if (OpIdx == -1) - break; - - int ModIdx = AMDGPU::getNamedOperandIdx(Opc, ModOps[J]); - uint32_t ModVal = Inst.getOperand(ModIdx).getImm(); - - if ((OpSel & (1 << J)) != 0) - ModVal |= SISrcMods::OP_SEL_0; - if (ModOps[J] == AMDGPU::OpName::src0_modifiers && - (OpSel & (1 << 3)) != 0) - ModVal |= SISrcMods::DST_OP_SEL; - - Inst.getOperand(ModIdx).setImm(ModVal); - } + cvtOpSelHelper(Inst, OpSel); } + void AMDGPUAsmParser::cvtScaledMFMA(MCInst &Inst, const OperandVector &Operands) { OptionalImmIndexMap OptionalIdx; diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp index 7b94ea3..f291e37 100644 --- a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp +++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp @@ -541,7 +541,7 @@ unsigned GCNSubtarget::getMaxNumSGPRs(const Function &F) const { unsigned GCNSubtarget::getBaseMaxNumVGPRs( const Function &F, std::pair<unsigned, unsigned> NumVGPRBounds) const { - const auto &[Min, Max] = NumVGPRBounds; + const auto [Min, Max] = NumVGPRBounds; // Check if maximum number of VGPRs was explicitly requested using // "amdgpu-num-vgpr" attribute. diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.h b/llvm/lib/Target/AMDGPU/GCNSubtarget.h index 879bf5a..c2e6078 100644 --- a/llvm/lib/Target/AMDGPU/GCNSubtarget.h +++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.h @@ -288,7 +288,7 @@ protected: bool Has45BitNumRecordsBufferResource = false; - bool HasCluster = false; + bool HasClusters = false; // Dummy feature to use for assembler in tablegen. bool FeatureDisable = false; @@ -1839,7 +1839,7 @@ public: } /// \returns true if the subtarget supports clusters of workgroups. - bool hasClusters() const { return HasCluster; } + bool hasClusters() const { return HasClusters; } /// \returns true if the subtarget requires a wait for xcnt before atomic /// flat/global stores & rmw. diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp index d3b5718..3563caa 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp @@ -1280,6 +1280,17 @@ void AMDGPUInstPrinter::printPackedModifier(const MCInst *MI, (ModIdx != -1) ? MI->getOperand(ModIdx).getImm() : DefaultValue; } + // Some instructions, e.g. v_interp_p2_f16 in GFX9, have src0, src2, but no + // src1. + if (NumOps == 1 && AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::src2) && + !AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::src1)) { + Ops[NumOps++] = DefaultValue; // Set src1_modifiers to default. + int Mod2Idx = + AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src2_modifiers); + assert(Mod2Idx != -1); + Ops[NumOps++] = MI->getOperand(Mod2Idx).getImm(); + } + const bool HasDst = (AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::vdst) != -1) || (AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::sdst) != -1); diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp index e233457..1a686a9 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -17346,74 +17346,24 @@ void SITargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI, MachineFunction *MF = MI.getParent()->getParent(); MachineRegisterInfo &MRI = MF->getRegInfo(); - SIMachineFunctionInfo *Info = MF->getInfo<SIMachineFunctionInfo>(); if (TII->isVOP3(MI.getOpcode())) { // Make sure constant bus requirements are respected. TII->legalizeOperandsVOP3(MRI, MI); - // Prefer VGPRs over AGPRs in mAI instructions where possible. - // This saves a chain-copy of registers and better balance register - // use between vgpr and agpr as agpr tuples tend to be big. - if (!MI.getDesc().operands().empty()) { - unsigned Opc = MI.getOpcode(); - bool HasAGPRs = Info->mayNeedAGPRs(); - const SIRegisterInfo *TRI = Subtarget->getRegisterInfo(); - int16_t Src2Idx = AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src2); - for (auto I : - {AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src0), - AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src1), Src2Idx}) { - if (I == -1) - break; - if ((I == Src2Idx) && (HasAGPRs)) - break; - MachineOperand &Op = MI.getOperand(I); - if (!Op.isReg() || !Op.getReg().isVirtual()) - continue; - auto *RC = TRI->getRegClassForReg(MRI, Op.getReg()); - if (!TRI->hasAGPRs(RC)) - continue; - auto *Src = MRI.getUniqueVRegDef(Op.getReg()); - if (!Src || !Src->isCopy() || - !TRI->isSGPRReg(MRI, Src->getOperand(1).getReg())) - continue; - auto *NewRC = TRI->getEquivalentVGPRClass(RC); - // All uses of agpr64 and agpr32 can also accept vgpr except for - // v_accvgpr_read, but we do not produce agpr reads during selection, - // so no use checks are needed. - MRI.setRegClass(Op.getReg(), NewRC); - } - - if (TII->isMAI(MI)) { - // The ordinary src0, src1, src2 were legalized above. - // - // We have to also legalize the appended v_mfma_ld_scale_b32 operands, - // as a separate instruction. - int Src0Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(), - AMDGPU::OpName::scale_src0); - if (Src0Idx != -1) { - int Src1Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(), - AMDGPU::OpName::scale_src1); - if (TII->usesConstantBus(MRI, MI, Src0Idx) && - TII->usesConstantBus(MRI, MI, Src1Idx)) - TII->legalizeOpWithMove(MI, Src1Idx); - } - } - - if (!HasAGPRs) - return; - - // Resolve the rest of AV operands to AGPRs. - if (auto *Src2 = TII->getNamedOperand(MI, AMDGPU::OpName::src2)) { - if (Src2->isReg() && Src2->getReg().isVirtual()) { - auto *RC = TRI->getRegClassForReg(MRI, Src2->getReg()); - if (TRI->isVectorSuperClass(RC)) { - auto *NewRC = TRI->getEquivalentAGPRClass(RC); - MRI.setRegClass(Src2->getReg(), NewRC); - if (Src2->isTied()) - MRI.setRegClass(MI.getOperand(0).getReg(), NewRC); - } - } + if (TII->isMAI(MI)) { + // The ordinary src0, src1, src2 were legalized above. + // + // We have to also legalize the appended v_mfma_ld_scale_b32 operands, + // as a separate instruction. + int Src0Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(), + AMDGPU::OpName::scale_src0); + if (Src0Idx != -1) { + int Src1Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(), + AMDGPU::OpName::scale_src1); + if (TII->usesConstantBus(MRI, MI, Src0Idx) && + TII->usesConstantBus(MRI, MI, Src1Idx)) + TII->legalizeOpWithMove(MI, Src1Idx); } } diff --git a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp index 908d856..b398db4 100644 --- a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp @@ -33,17 +33,20 @@ using namespace llvm; // optimal RC for Opc and Dest of MFMA. In particular, there are high RP cases // where it is better to produce the VGPR form (e.g. if there are VGPR users // of the MFMA result). -static cl::opt<bool> MFMAVGPRForm( - "amdgpu-mfma-vgpr-form", cl::Hidden, +static cl::opt<bool, true> MFMAVGPRFormOpt( + "amdgpu-mfma-vgpr-form", cl::desc("Whether to force use VGPR for Opc and Dest of MFMA. If " "unspecified, default to compiler heuristics"), - cl::init(false)); + cl::location(SIMachineFunctionInfo::MFMAVGPRForm), cl::init(false), + cl::Hidden); const GCNTargetMachine &getTM(const GCNSubtarget *STI) { const SITargetLowering *TLI = STI->getTargetLowering(); return static_cast<const GCNTargetMachine &>(TLI->getTargetMachine()); } +bool SIMachineFunctionInfo::MFMAVGPRForm = false; + SIMachineFunctionInfo::SIMachineFunctionInfo(const Function &F, const GCNSubtarget *STI) : AMDGPUMachineFunction(F, *STI), Mode(F, *STI), GWSResourcePSV(getTM(STI)), @@ -81,14 +84,13 @@ SIMachineFunctionInfo::SIMachineFunctionInfo(const Function &F, PSInputAddr = AMDGPU::getInitialPSInputAddr(F); } - MayNeedAGPRs = ST.hasMAIInsts(); if (ST.hasGFX90AInsts()) { - // FIXME: MayNeedAGPRs is a misnomer for how this is used. MFMA selection - // should be separated from availability of AGPRs - if (MFMAVGPRForm || - (ST.getMaxNumVGPRs(F) <= ST.getAddressableNumArchVGPRs() && - !mayUseAGPRs(F))) - MayNeedAGPRs = false; // We will select all MAI with VGPR operands. + // FIXME: Extract logic out of getMaxNumVectorRegs; we need to apply the + // allocation granule and clamping. + auto [MinNumAGPRAttr, MaxNumAGPRAttr] = + AMDGPU::getIntegerPairAttribute(F, "amdgpu-agpr-alloc", {~0u, ~0u}, + /*OnlyFirstRequired=*/true); + MinNumAGPRs = MinNumAGPRAttr; } if (AMDGPU::isChainCC(CC)) { diff --git a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h index 4560615..b7dbb59 100644 --- a/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h +++ b/llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h @@ -509,7 +509,9 @@ private: // user arguments. This is an offset from the KernargSegmentPtr. bool ImplicitArgPtr : 1; - bool MayNeedAGPRs : 1; + /// Minimum number of AGPRs required to allocate in the function. Only + /// relevant for gfx90a-gfx950. For gfx908, this should be infinite. + unsigned MinNumAGPRs = ~0u; // The hard-wired high half of the address of the global information table // for AMDPAL OS type. 0xffffffff represents no hard-wired high half, since @@ -537,6 +539,8 @@ private: void MRI_NoteCloneVirtualRegister(Register NewReg, Register SrcReg) override; public: + static bool MFMAVGPRForm; + struct VGPRSpillToAGPR { SmallVector<MCPhysReg, 32> Lanes; bool FullyAllocated = false; @@ -1196,9 +1200,7 @@ public: unsigned getMaxMemoryClusterDWords() const { return MaxMemoryClusterDWords; } - bool mayNeedAGPRs() const { - return MayNeedAGPRs; - } + unsigned getMinNumAGPRs() const { return MinNumAGPRs; } // \returns true if a function has a use of AGPRs via inline asm or // has a call which may use it. diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp index 3c2dd42..3115579 100644 --- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp @@ -1118,12 +1118,7 @@ SIRegisterInfo::getPointerRegClass(unsigned Kind) const { const TargetRegisterClass * SIRegisterInfo::getCrossCopyRegClass(const TargetRegisterClass *RC) const { - if (isAGPRClass(RC) && !ST.hasGFX90AInsts()) - return getEquivalentVGPRClass(RC); - if (RC == &AMDGPU::SCC_CLASSRegClass) - return getWaveMaskRegClass(); - - return RC; + return RC == &AMDGPU::SCC_CLASSRegClass ? &AMDGPU::SReg_32RegClass : RC; } static unsigned getNumSubRegsForSpillOp(const MachineInstr &MI, diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td index 4a2b54d..42ec8ba 100644 --- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td @@ -97,6 +97,7 @@ class VOP3Interp<string OpName, VOPProfile P, list<dag> pattern = []> : VOP3_Pseudo<OpName, P, pattern> { let AsmMatchConverter = "cvtVOP3Interp"; let mayRaiseFPException = 0; + let VOP3_OPSEL = P.HasOpSel; } def VOP3_INTERP : VOPProfile<[f32, f32, i32, untyped]> { @@ -119,16 +120,17 @@ def VOP3_INTERP_MOV : VOPProfile<[f32, i32, i32, untyped]> { let HasSrc0Mods = 0; } -class getInterp16Asm <bit HasSrc2, bit HasOMod> { +class getInterp16Asm <bit HasSrc2, bit HasOMod, bit OpSel> { string src2 = !if(HasSrc2, ", $src2_modifiers", ""); string omod = !if(HasOMod, "$omod", ""); + string opsel = !if(OpSel, "$op_sel", ""); string ret = - " $vdst, $src0_modifiers, $attr$attrchan"#src2#"$high$clamp"#omod; + " $vdst, $src0_modifiers, $attr$attrchan"#src2#"$high$clamp"#omod#opsel; } class getInterp16Ins <bit HasSrc2, bit HasOMod, - Operand Src0Mod, Operand Src2Mod> { - dag ret = !if(HasSrc2, + Operand Src0Mod, Operand Src2Mod, bit OpSel> { + dag ret1 = !if(HasSrc2, !if(HasOMod, (ins Src0Mod:$src0_modifiers, VRegSrc_32:$src0, InterpAttr:$attr, InterpAttrChan:$attrchan, @@ -143,19 +145,22 @@ class getInterp16Ins <bit HasSrc2, bit HasOMod, InterpAttr:$attr, InterpAttrChan:$attrchan, highmod:$high, Clamp0:$clamp, omod0:$omod) ); + dag ret2 = !if(OpSel, (ins op_sel0:$op_sel), (ins)); + dag ret = !con(ret1, ret2); } -class VOP3_INTERP16 <list<ValueType> ArgVT> : VOPProfile<ArgVT> { +class VOP3_INTERP16 <list<ValueType> ArgVT, bit OpSel = 0> : VOPProfile<ArgVT> { let IsSingle = 1; let HasOMod = !ne(DstVT.Value, f16.Value); let HasHigh = 1; + let HasOpSel = OpSel; let Src0Mod = FPVRegInputMods; let Src2Mod = FPVRegInputMods; let Outs64 = (outs DstRC.RegClass:$vdst); - let Ins64 = getInterp16Ins<HasSrc2, HasOMod, Src0Mod, Src2Mod>.ret; - let Asm64 = getInterp16Asm<HasSrc2, HasOMod>.ret; + let Ins64 = getInterp16Ins<HasSrc2, HasOMod, Src0Mod, Src2Mod, OpSel>.ret; + let Asm64 = getInterp16Asm<HasSrc2, HasOMod, OpSel>.ret; } //===----------------------------------------------------------------------===// @@ -480,7 +485,7 @@ let SubtargetPredicate = isGFX9Plus in { defm V_MAD_U16_gfx9 : VOP3Inst_t16 <"v_mad_u16_gfx9", VOP_I16_I16_I16_I16>; defm V_MAD_I16_gfx9 : VOP3Inst_t16 <"v_mad_i16_gfx9", VOP_I16_I16_I16_I16>; let OtherPredicates = [isNotGFX90APlus] in -def V_INTERP_P2_F16_gfx9 : VOP3Interp <"v_interp_p2_f16_gfx9", VOP3_INTERP16<[f16, f32, i32, f32]>>; +def V_INTERP_P2_F16_opsel : VOP3Interp <"v_interp_p2_f16_opsel", VOP3_INTERP16<[f16, f32, i32, f32], /*OpSel*/ 1>>; } // End SubtargetPredicate = isGFX9Plus // This predicate should only apply to the selection pattern. The @@ -2676,6 +2681,14 @@ multiclass VOP3Interp_F16_Real_gfx9<bits<10> op, string OpName, string AsmName> } } +multiclass VOP3Interp_F16_OpSel_Real_gfx9<bits<10> op, string OpName, string AsmName> { + def _gfx9 : VOP3_Real<!cast<VOP3_Pseudo>(OpName), SIEncodingFamily.GFX9>, + VOP3Interp_OpSel_gfx9 <op, !cast<VOP3_Pseudo>(OpName).Pfl> { + VOP3_Pseudo ps = !cast<VOP3_Pseudo>(OpName); + let AsmString = AsmName # ps.AsmOperands; + } +} + multiclass VOP3_Real_gfx9<bits<10> op, string AsmName> { def _gfx9 : VOP3_Real<!cast<VOP_Pseudo>(NAME#"_e64"), SIEncodingFamily.GFX9>, VOP3e_vi <op, !cast<VOP_Pseudo>(NAME#"_e64").Pfl> { @@ -2788,7 +2801,7 @@ defm V_MAD_U16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x204, "v_mad_u16">; defm V_MAD_I16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x205, "v_mad_i16">; defm V_FMA_F16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x206, "v_fma_f16">; defm V_DIV_FIXUP_F16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x207, "v_div_fixup_f16">; -defm V_INTERP_P2_F16_gfx9 : VOP3Interp_F16_Real_gfx9 <0x277, "V_INTERP_P2_F16_gfx9", "v_interp_p2_f16">; +defm V_INTERP_P2_F16_opsel : VOP3Interp_F16_OpSel_Real_gfx9 <0x277, "V_INTERP_P2_F16_opsel", "v_interp_p2_f16">; defm V_ADD_I32 : VOP3_Real_vi <0x29c>; defm V_SUB_I32 : VOP3_Real_vi <0x29d>; diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td index 5daf860..3a0cc35 100644 --- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td @@ -67,7 +67,7 @@ class VOP3P_Mix_Profile<VOPProfile P, VOP3Features Features = VOP3_REGULAR, class VOP3P_Mix_Profile_t16<VOPProfile P, VOP3Features Features = VOP3_REGULAR> : VOP3P_Mix_Profile<P, Features, 0> { let IsTrue16 = 1; - let IsRealTrue16 = 1; + let IsRealTrue16 = 1; let DstRC64 = getVALUDstForVT<P.DstVT, 1 /*IsTrue16*/, 1 /*IsVOP3Encoding*/>.ret; } @@ -950,7 +950,7 @@ class MFMA_F8F6F4_WithSizeTable_Helper<VOP3_Pseudo ps, string F8F8Op> : } // Currently assumes scaled instructions never have abid -class MAIFrag<SDPatternOperator Op, code pred, bit HasAbid = true, bit Scaled = false> : PatFrag < +class MAIFrag<SDPatternOperator Op, bit HasAbid = true, bit Scaled = false> : PatFrag < !if(Scaled, (ops node:$src0, node:$src1, node:$src2, node:$cbsz, node:$blgp, node:$src0_modifiers, node:$scale_src0, node:$src1_modifiers, node:$scale_src1), @@ -959,37 +959,30 @@ class MAIFrag<SDPatternOperator Op, code pred, bit HasAbid = true, bit Scaled = (ops node:$blgp))), !if(Scaled, (Op $src0, $src1, $src2, $cbsz, $blgp, $src0_modifiers, $scale_src0, $src1_modifiers, $scale_src1), !if(HasAbid, (Op $src0, $src1, $src2, $cbsz, $abid, $blgp), - (Op $src0, $src1, $src2, $cbsz, $blgp))), - pred ->; - -defvar MayNeedAGPRs = [{ - return MF->getInfo<SIMachineFunctionInfo>()->mayNeedAGPRs(); -}]; - -defvar MayNeedAGPRs_gisel = [{ - return MF.getInfo<SIMachineFunctionInfo>()->mayNeedAGPRs(); -}]; + (Op $src0, $src1, $src2, $cbsz, $blgp)))>; -defvar MayNotNeedAGPRs = [{ - return !MF->getInfo<SIMachineFunctionInfo>()->mayNeedAGPRs(); -}]; +class CanUseAGPR_MAI<ValueType vt> { + code PredicateCode = [{ + return !Subtarget->hasGFX90AInsts() || + (!SIMachineFunctionInfo::MFMAVGPRForm && + MF->getInfo<SIMachineFunctionInfo>()->getMinNumAGPRs() >= + }] # !srl(vt.Size, 5) # ");"; -defvar MayNotNeedAGPRs_gisel = [{ - return !MF.getInfo<SIMachineFunctionInfo>()->mayNeedAGPRs(); -}]; + code GISelPredicateCode = [{ + return !Subtarget->hasGFX90AInsts() || + (!SIMachineFunctionInfo::MFMAVGPRForm && + MF.getInfo<SIMachineFunctionInfo>()->getMinNumAGPRs() >= + }] # !srl(vt.Size, 5) # ");"; +} -class AgprMAIFrag<SDPatternOperator Op, bit HasAbid = true, +class AgprMAIFrag<SDPatternOperator Op, ValueType vt, bit HasAbid = true, bit Scaled = false> : - MAIFrag<Op, MayNeedAGPRs, HasAbid, Scaled> { - let GISelPredicateCode = MayNeedAGPRs_gisel; -} + MAIFrag<Op, HasAbid, Scaled>, + CanUseAGPR_MAI<vt>; class VgprMAIFrag<SDPatternOperator Op, bit HasAbid = true, - bit Scaled = false> : - MAIFrag<Op, MayNotNeedAGPRs, HasAbid, Scaled> { - let GISelPredicateCode = MayNotNeedAGPRs_gisel; -} + bit Scaled = false> : + MAIFrag<Op, HasAbid, Scaled>; let isAsCheapAsAMove = 1, isReMaterializable = 1 in { defm V_ACCVGPR_READ_B32 : VOP3Inst<"v_accvgpr_read_b32", VOPProfileAccRead>; @@ -1037,16 +1030,19 @@ multiclass MAIInst<string OpName, string P, SDPatternOperator node = null_frag, bit HasAbid = true, bit Scaled = false> { defvar NoDstOverlap = !cast<VOPProfileMAI>("VOPProfileMAI_" # P).NoDstOverlap; + defvar ProfileAGPR = !cast<VOPProfileMAI>("VOPProfileMAI_" # P); + defvar ProfileVGPR = !cast<VOPProfileMAI>("VOPProfileMAI_" # P # "_VCD"); + let isConvergent = 1, mayRaiseFPException = 0, ReadsModeReg = 1 in { // FP32 denorm mode is respected, rounding mode is not. Exceptions are not supported. let Constraints = !if(NoDstOverlap, "@earlyclobber $vdst", "") in { - def _e64 : MAIInst<OpName, !cast<VOPProfileMAI>("VOPProfileMAI_" # P), - !if(!or(NoDstOverlap, !eq(node, null_frag)), null_frag, AgprMAIFrag<node, HasAbid, Scaled>), Scaled>, + def _e64 : MAIInst<OpName, ProfileAGPR, + !if(!or(NoDstOverlap, !eq(node, null_frag)), null_frag, AgprMAIFrag<node, ProfileAGPR.DstVT, HasAbid, Scaled>), Scaled>, MFMATable<0, "AGPR", NAME # "_e64">; let OtherPredicates = [isGFX90APlus], Mnemonic = OpName in - def _vgprcd_e64 : MAIInst<OpName # "_vgprcd", !cast<VOPProfileMAI>("VOPProfileMAI_" # P # "_VCD"), + def _vgprcd_e64 : MAIInst<OpName # "_vgprcd", ProfileVGPR, !if(!or(NoDstOverlap, !eq(node, null_frag)), null_frag, VgprMAIFrag<node, HasAbid, Scaled>), Scaled>, MFMATable<0, "VGPR", NAME # "_vgprcd_e64", NAME # "_e64">; } @@ -1055,12 +1051,12 @@ multiclass MAIInst<string OpName, string P, SDPatternOperator node = null_frag, let Constraints = !if(NoDstOverlap, "$vdst = $src2", ""), isConvertibleToThreeAddress = NoDstOverlap, Mnemonic = OpName in { - def "_mac_e64" : MAIInst<OpName # "_mac", !cast<VOPProfileMAI>("VOPProfileMAI_" # P), - !if(!eq(node, null_frag), null_frag, AgprMAIFrag<node, HasAbid, Scaled>), Scaled>, + def "_mac_e64" : MAIInst<OpName # "_mac", ProfileAGPR, + !if(!eq(node, null_frag), null_frag, AgprMAIFrag<node, ProfileAGPR.DstVT, HasAbid, Scaled>), Scaled>, MFMATable<1, "AGPR", NAME # "_e64", NAME # "_mac_e64">; let OtherPredicates = [isGFX90APlus] in - def _mac_vgprcd_e64 : MAIInst<OpName # "_mac_vgprcd", !cast<VOPProfileMAI>("VOPProfileMAI_" # P # "_VCD"), + def _mac_vgprcd_e64 : MAIInst<OpName # "_mac_vgprcd", ProfileVGPR, !if(!eq(node, null_frag), null_frag, VgprMAIFrag<node, HasAbid, Scaled>), Scaled>, MFMATable<1, "VGPR", NAME # "_vgprcd_e64", NAME # "_mac_e64">; } @@ -1074,11 +1070,11 @@ multiclass ScaledMAIInst_mc<string OpName, string UnscaledOpName_, SDPatternOper defvar UnscaledOpName = UnscaledOpName_#VariantSuffix; defvar HasAbid = false; - - defvar NoDstOverlap = !cast<VOPProfileMAI>(!cast<MAIInst>(UnscaledOpName#"_e64").Pfl).NoDstOverlap; + defvar Profile = !cast<VOPProfileMAI>(!cast<MAIInst>(UnscaledOpName#"_e64").Pfl); + defvar NoDstOverlap = Profile.NoDstOverlap; def _e64 : ScaledMAIInst<OpName, - !cast<MAIInst>(UnscaledOpName#"_e64"), !if(NoDstOverlap, null_frag, AgprMAIFrag<node, HasAbid, true>)>, + !cast<MAIInst>(UnscaledOpName#"_e64"), !if(NoDstOverlap, null_frag, AgprMAIFrag<node, Profile.DstVT, HasAbid, true>)>, MFMATable<0, "AGPR", NAME # "_e64">; def _vgprcd_e64 : ScaledMAIInst<OpName # "_vgprcd", @@ -1090,7 +1086,7 @@ multiclass ScaledMAIInst_mc<string OpName, string UnscaledOpName_, SDPatternOper isConvertibleToThreeAddress = NoDstOverlap, Mnemonic = UnscaledOpName_ in { def _mac_e64 : ScaledMAIInst<OpName # "_mac", - !cast<MAIInst>(UnscaledOpName # "_mac_e64"), AgprMAIFrag<node, HasAbid, true>>, + !cast<MAIInst>(UnscaledOpName # "_mac_e64"), AgprMAIFrag<node, Profile.DstVT, HasAbid, true>>, MFMATable<1, "AGPR", NAME # "_e64">; def _mac_vgprcd_e64 : ScaledMAIInst<OpName # " _mac_vgprcd", diff --git a/llvm/lib/Target/AMDGPU/VOPInstructions.td b/llvm/lib/Target/AMDGPU/VOPInstructions.td index 631f0f3..8325c62 100644 --- a/llvm/lib/Target/AMDGPU/VOPInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOPInstructions.td @@ -419,6 +419,13 @@ class VOP3a_ScaleSel_gfx1250<bits<10> op, VOPProfile p> : VOP3e_gfx11_gfx12<op, let Inst{14-11} = scale_sel; } +class VOP3Interp_OpSel_gfx9<bits<10> op, VOPProfile p> : VOP3Interp_vi<op, p> { + let Inst{11} = src0_modifiers{2}; + // There's no src1 + let Inst{13} = src2_modifiers{2}; + let Inst{14} = !if(p.HasDst, src0_modifiers{3}, 0); +} + class VOP3Interp_gfx10<bits<10> op, VOPProfile p> : VOP3e_gfx10<op, p> { bits<6> attr; bits<2> attrchan; diff --git a/llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp b/llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp index 1fc475d..561a9c5 100644 --- a/llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp +++ b/llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp @@ -349,32 +349,30 @@ public: bool isImm() const override { return Kind == Immediate || Kind == Expression; } - bool isU1Imm() const { return Kind == Immediate && isUInt<1>(getImm()); } - bool isU2Imm() const { return Kind == Immediate && isUInt<2>(getImm()); } - bool isU3Imm() const { return Kind == Immediate && isUInt<3>(getImm()); } - bool isU4Imm() const { return Kind == Immediate && isUInt<4>(getImm()); } - bool isU5Imm() const { return Kind == Immediate && isUInt<5>(getImm()); } - bool isS5Imm() const { return Kind == Immediate && isInt<5>(getImm()); } - bool isU6Imm() const { return Kind == Immediate && isUInt<6>(getImm()); } - bool isU6ImmX2() const { return Kind == Immediate && - isUInt<6>(getImm()) && - (getImm() & 1) == 0; } - bool isU7Imm() const { return Kind == Immediate && isUInt<7>(getImm()); } - bool isU7ImmX4() const { return Kind == Immediate && - isUInt<7>(getImm()) && - (getImm() & 3) == 0; } - bool isU8Imm() const { return Kind == Immediate && isUInt<8>(getImm()); } - bool isU8ImmX8() const { return Kind == Immediate && - isUInt<8>(getImm()) && - (getImm() & 7) == 0; } - - bool isU10Imm() const { return Kind == Immediate && isUInt<10>(getImm()); } - bool isU12Imm() const { return Kind == Immediate && isUInt<12>(getImm()); } + + template <uint64_t N> bool isUImm() const { + return Kind == Immediate && isUInt<N>(getImm()); + } + template <uint64_t N> bool isSImm() const { + return Kind == Immediate && isInt<N>(getImm()); + } + bool isU6ImmX2() const { return isUImm<6>() && (getImm() & 1) == 0; } + bool isU7ImmX4() const { return isUImm<7>() && (getImm() & 3) == 0; } + bool isU8ImmX8() const { return isUImm<8>() && (getImm() & 7) == 0; } + bool isU16Imm() const { return isExtImm<16>(/*Signed*/ false, 1); } bool isS16Imm() const { return isExtImm<16>(/*Signed*/ true, 1); } bool isS16ImmX4() const { return isExtImm<16>(/*Signed*/ true, 4); } bool isS16ImmX16() const { return isExtImm<16>(/*Signed*/ true, 16); } bool isS17Imm() const { return isExtImm<17>(/*Signed*/ true, 1); } + bool isS34Imm() const { + // Once the PC-Rel ABI is finalized, evaluate whether a 34-bit + // ContextImmediate is needed. + return Kind == Expression || isSImm<34>(); + } + bool isS34ImmX16() const { + return Kind == Expression || (isSImm<34>() && (getImm() & 15) == 0); + } bool isHashImmX8() const { // The Hash Imm form is used for instructions that check or store a hash. @@ -384,16 +382,6 @@ public: (getImm() & 7) == 0); } - bool isS34ImmX16() const { - return Kind == Expression || - (Kind == Immediate && isInt<34>(getImm()) && (getImm() & 15) == 0); - } - bool isS34Imm() const { - // Once the PC-Rel ABI is finalized, evaluate whether a 34-bit - // ContextImmediate is needed. - return Kind == Expression || (Kind == Immediate && isInt<34>(getImm())); - } - bool isTLSReg() const { return Kind == TLSRegister; } bool isDirectBr() const { if (Kind == Expression) @@ -1637,7 +1625,7 @@ bool PPCAsmParser::parseInstruction(ParseInstructionInfo &Info, StringRef Name, if (Operands.size() != 5) return false; PPCOperand &EHOp = (PPCOperand &)*Operands[4]; - if (EHOp.isU1Imm() && EHOp.getImm() == 0) + if (EHOp.isUImm<1>() && EHOp.getImm() == 0) Operands.pop_back(); } @@ -1817,7 +1805,7 @@ unsigned PPCAsmParser::validateTargetOperandClass(MCParsedAsmOperand &AsmOp, } PPCOperand &Op = static_cast<PPCOperand &>(AsmOp); - if (Op.isU3Imm() && Op.getImm() == ImmVal) + if (Op.isUImm<3>() && Op.getImm() == ImmVal) return Match_Success; return Match_InvalidOperand; diff --git a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp index 48c31c9..81d8e94 100644 --- a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp +++ b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp @@ -206,45 +206,24 @@ PPCMCCodeEmitter::getVSRpEvenEncoding(const MCInst &MI, unsigned OpNo, return RegBits; } -unsigned PPCMCCodeEmitter::getImm16Encoding(const MCInst &MI, unsigned OpNo, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI) const { - const MCOperand &MO = MI.getOperand(OpNo); - if (MO.isReg() || MO.isImm()) return getMachineOpValue(MI, MO, Fixups, STI); - - // Add a fixup for the immediate field. - addFixup(Fixups, IsLittleEndian ? 0 : 2, MO.getExpr(), PPC::fixup_ppc_half16); - return 0; -} - -uint64_t PPCMCCodeEmitter::getImm34Encoding(const MCInst &MI, unsigned OpNo, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI, - MCFixupKind Fixup) const { +template <MCFixupKind Fixup> +uint64_t PPCMCCodeEmitter::getImmEncoding(const MCInst &MI, unsigned OpNo, + SmallVectorImpl<MCFixup> &Fixups, + const MCSubtargetInfo &STI) const { const MCOperand &MO = MI.getOperand(OpNo); assert(!MO.isReg() && "Not expecting a register for this operand."); if (MO.isImm()) return getMachineOpValue(MI, MO, Fixups, STI); + uint32_t Offset = 0; + if (Fixup == PPC::fixup_ppc_half16) + Offset = IsLittleEndian ? 0 : 2; + // Add a fixup for the immediate field. - addFixup(Fixups, 0, MO.getExpr(), Fixup); + addFixup(Fixups, Offset, MO.getExpr(), Fixup); return 0; } -uint64_t -PPCMCCodeEmitter::getImm34EncodingNoPCRel(const MCInst &MI, unsigned OpNo, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI) const { - return getImm34Encoding(MI, OpNo, Fixups, STI, PPC::fixup_ppc_imm34); -} - -uint64_t -PPCMCCodeEmitter::getImm34EncodingPCRel(const MCInst &MI, unsigned OpNo, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI) const { - return getImm34Encoding(MI, OpNo, Fixups, STI, PPC::fixup_ppc_pcrel34); -} - unsigned PPCMCCodeEmitter::getDispRIEncoding(const MCInst &MI, unsigned OpNo, SmallVectorImpl<MCFixup> &Fixups, const MCSubtargetInfo &STI) const { diff --git a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h index b574557..3356513 100644 --- a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h +++ b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h @@ -47,19 +47,10 @@ public: unsigned getAbsCondBrEncoding(const MCInst &MI, unsigned OpNo, SmallVectorImpl<MCFixup> &Fixups, const MCSubtargetInfo &STI) const; - unsigned getImm16Encoding(const MCInst &MI, unsigned OpNo, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI) const; - uint64_t getImm34Encoding(const MCInst &MI, unsigned OpNo, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI, - MCFixupKind Fixup) const; - uint64_t getImm34EncodingNoPCRel(const MCInst &MI, unsigned OpNo, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI) const; - uint64_t getImm34EncodingPCRel(const MCInst &MI, unsigned OpNo, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI) const; + template <MCFixupKind Fixup> + uint64_t getImmEncoding(const MCInst &MI, unsigned OpNo, + SmallVectorImpl<MCFixup> &Fixups, + const MCSubtargetInfo &STI) const; unsigned getDispRIEncoding(const MCInst &MI, unsigned OpNo, SmallVectorImpl<MCFixup> &Fixups, const MCSubtargetInfo &STI) const; diff --git a/llvm/lib/Target/PowerPC/PPCInstr64Bit.td b/llvm/lib/Target/PowerPC/PPCInstr64Bit.td index 60efa4c..fdca5ebc 100644 --- a/llvm/lib/Target/PowerPC/PPCInstr64Bit.td +++ b/llvm/lib/Target/PowerPC/PPCInstr64Bit.td @@ -14,30 +14,6 @@ //===----------------------------------------------------------------------===// // 64-bit operands. // -def s16imm64 : Operand<i64> { - let PrintMethod = "printS16ImmOperand"; - let EncoderMethod = "getImm16Encoding"; - let ParserMatchClass = PPCS16ImmAsmOperand; - let DecoderMethod = "decodeSImmOperand<16>"; - let OperandType = "OPERAND_IMMEDIATE"; -} -def u16imm64 : Operand<i64> { - let PrintMethod = "printU16ImmOperand"; - let EncoderMethod = "getImm16Encoding"; - let ParserMatchClass = PPCU16ImmAsmOperand; - let DecoderMethod = "decodeUImmOperand<16>"; - let OperandType = "OPERAND_IMMEDIATE"; -} -def s17imm64 : Operand<i64> { - // This operand type is used for addis/lis to allow the assembler parser - // to accept immediates in the range -65536..65535 for compatibility with - // the GNU assembler. The operand is treated as 16-bit otherwise. - let PrintMethod = "printS16ImmOperand"; - let EncoderMethod = "getImm16Encoding"; - let ParserMatchClass = PPCS17ImmAsmOperand; - let DecoderMethod = "decodeSImmOperand<16>"; - let OperandType = "OPERAND_IMMEDIATE"; -} def tocentry : Operand<iPTR> { let MIOperandInfo = (ops i64imm:$imm); } diff --git a/llvm/lib/Target/PowerPC/PPCInstrAltivec.td b/llvm/lib/Target/PowerPC/PPCInstrAltivec.td index c616db4..23d6d88 100644 --- a/llvm/lib/Target/PowerPC/PPCInstrAltivec.td +++ b/llvm/lib/Target/PowerPC/PPCInstrAltivec.td @@ -30,6 +30,11 @@ // Altivec transformation functions and pattern fragments. // +// fneg is not legal, and desugared as an xor. +def desugared_fneg : PatFrag<(ops node:$x), (v4f32 (bitconvert (xor (bitconvert $x), + (int_ppc_altivec_vslw (bitconvert (v16i8 immAllOnesV)), + (bitconvert (v16i8 immAllOnesV))))))>; + def vpkuhum_shuffle : PatFrag<(ops node:$lhs, node:$rhs), (vector_shuffle node:$lhs, node:$rhs), [{ return PPC::isVPKUHUMShuffleMask(cast<ShuffleVectorSDNode>(N), 0, *CurDAG); @@ -467,11 +472,12 @@ def VMADDFP : VAForm_1<46, (outs vrrc:$RT), (ins vrrc:$RA, vrrc:$RC, vrrc:$RB), [(set v4f32:$RT, (fma v4f32:$RA, v4f32:$RC, v4f32:$RB))]>; -// FIXME: The fma+fneg pattern won't match because fneg is not legal. +// fneg is not legal, hence we have to match on the desugared version. def VNMSUBFP: VAForm_1<47, (outs vrrc:$RT), (ins vrrc:$RA, vrrc:$RC, vrrc:$RB), "vnmsubfp $RT, $RA, $RC, $RB", IIC_VecFP, - [(set v4f32:$RT, (fneg (fma v4f32:$RA, v4f32:$RC, - (fneg v4f32:$RB))))]>; + [(set v4f32:$RT, (desugared_fneg (fma v4f32:$RA, v4f32:$RC, + (desugared_fneg v4f32:$RB))))]>; + let hasSideEffects = 1 in { def VMHADDSHS : VA1a_Int_Ty<32, "vmhaddshs", int_ppc_altivec_vmhaddshs, v8i16>; def VMHRADDSHS : VA1a_Int_Ty<33, "vmhraddshs", int_ppc_altivec_vmhraddshs, @@ -892,6 +898,13 @@ def : Pat<(mul v8i16:$vA, v8i16:$vB), (VMLADDUHM $vA, $vB, (v8i16(V_SET0H)))>; // Add def : Pat<(add (mul v8i16:$vA, v8i16:$vB), v8i16:$vC), (VMLADDUHM $vA, $vB, $vC)>; + +// Fused negated multiply-subtract +def : Pat<(v4f32 (desugared_fneg + (int_ppc_altivec_vmaddfp v4f32:$RA, v4f32:$RC, + (desugared_fneg v4f32:$RB)))), + (VNMSUBFP $RA, $RC, $RB)>; + // Saturating adds/subtracts. def : Pat<(v16i8 (saddsat v16i8:$vA, v16i8:$vB)), (v16i8 (VADDSBS $vA, $vB))>; def : Pat<(v16i8 (uaddsat v16i8:$vA, v16i8:$vB)), (v16i8 (VADDUBS $vA, $vB))>; diff --git a/llvm/lib/Target/PowerPC/PPCRegisterInfo.td b/llvm/lib/Target/PowerPC/PPCRegisterInfo.td index 6d8c122..65d0484 100644 --- a/llvm/lib/Target/PowerPC/PPCRegisterInfo.td +++ b/llvm/lib/Target/PowerPC/PPCRegisterInfo.td @@ -615,7 +615,8 @@ def spe4rc : RegisterOperand<GPRC> { } def PPCU1ImmAsmOperand : AsmOperandClass { - let Name = "U1Imm"; let PredicateMethod = "isU1Imm"; + let Name = "U1Imm"; + let PredicateMethod = "isUImm<1>"; let RenderMethod = "addImmOperands"; } def u1imm : Operand<i32> { @@ -626,7 +627,8 @@ def u1imm : Operand<i32> { } def PPCU2ImmAsmOperand : AsmOperandClass { - let Name = "U2Imm"; let PredicateMethod = "isU2Imm"; + let Name = "U2Imm"; + let PredicateMethod = "isUImm<2>"; let RenderMethod = "addImmOperands"; } def u2imm : Operand<i32> { @@ -647,7 +649,8 @@ def atimm : Operand<i32> { } def PPCU3ImmAsmOperand : AsmOperandClass { - let Name = "U3Imm"; let PredicateMethod = "isU3Imm"; + let Name = "U3Imm"; + let PredicateMethod = "isUImm<3>"; let RenderMethod = "addImmOperands"; } def u3imm : Operand<i32> { @@ -658,7 +661,8 @@ def u3imm : Operand<i32> { } def PPCU4ImmAsmOperand : AsmOperandClass { - let Name = "U4Imm"; let PredicateMethod = "isU4Imm"; + let Name = "U4Imm"; + let PredicateMethod = "isUImm<4>"; let RenderMethod = "addImmOperands"; } def u4imm : Operand<i32> { @@ -668,7 +672,8 @@ def u4imm : Operand<i32> { let OperandType = "OPERAND_IMMEDIATE"; } def PPCS5ImmAsmOperand : AsmOperandClass { - let Name = "S5Imm"; let PredicateMethod = "isS5Imm"; + let Name = "S5Imm"; + let PredicateMethod = "isSImm<5>"; let RenderMethod = "addImmOperands"; } def s5imm : Operand<i32> { @@ -678,7 +683,8 @@ def s5imm : Operand<i32> { let OperandType = "OPERAND_IMMEDIATE"; } def PPCU5ImmAsmOperand : AsmOperandClass { - let Name = "U5Imm"; let PredicateMethod = "isU5Imm"; + let Name = "U5Imm"; + let PredicateMethod = "isUImm<5>"; let RenderMethod = "addImmOperands"; } def u5imm : Operand<i32> { @@ -688,7 +694,8 @@ def u5imm : Operand<i32> { let OperandType = "OPERAND_IMMEDIATE"; } def PPCU6ImmAsmOperand : AsmOperandClass { - let Name = "U6Imm"; let PredicateMethod = "isU6Imm"; + let Name = "U6Imm"; + let PredicateMethod = "isUImm<6>"; let RenderMethod = "addImmOperands"; } def u6imm : Operand<i32> { @@ -698,7 +705,8 @@ def u6imm : Operand<i32> { let OperandType = "OPERAND_IMMEDIATE"; } def PPCU7ImmAsmOperand : AsmOperandClass { - let Name = "U7Imm"; let PredicateMethod = "isU7Imm"; + let Name = "U7Imm"; + let PredicateMethod = "isUImm<7>"; let RenderMethod = "addImmOperands"; } def u7imm : Operand<i32> { @@ -708,7 +716,8 @@ def u7imm : Operand<i32> { let OperandType = "OPERAND_IMMEDIATE"; } def PPCU8ImmAsmOperand : AsmOperandClass { - let Name = "U8Imm"; let PredicateMethod = "isU8Imm"; + let Name = "U8Imm"; + let PredicateMethod = "isUImm<8>"; let RenderMethod = "addImmOperands"; } def u8imm : Operand<i32> { @@ -718,7 +727,8 @@ def u8imm : Operand<i32> { let OperandType = "OPERAND_IMMEDIATE"; } def PPCU10ImmAsmOperand : AsmOperandClass { - let Name = "U10Imm"; let PredicateMethod = "isU10Imm"; + let Name = "U10Imm"; + let PredicateMethod = "isUImm<10>"; let RenderMethod = "addImmOperands"; } def u10imm : Operand<i32> { @@ -728,7 +738,8 @@ def u10imm : Operand<i32> { let OperandType = "OPERAND_IMMEDIATE"; } def PPCU12ImmAsmOperand : AsmOperandClass { - let Name = "U12Imm"; let PredicateMethod = "isU12Imm"; + let Name = "U12Imm"; + let PredicateMethod = "isUImm<12>"; let RenderMethod = "addImmOperands"; } def u12imm : Operand<i32> { @@ -743,7 +754,14 @@ def PPCS16ImmAsmOperand : AsmOperandClass { } def s16imm : Operand<i32> { let PrintMethod = "printS16ImmOperand"; - let EncoderMethod = "getImm16Encoding"; + let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>"; + let ParserMatchClass = PPCS16ImmAsmOperand; + let DecoderMethod = "decodeSImmOperand<16>"; + let OperandType = "OPERAND_IMMEDIATE"; +} +def s16imm64 : Operand<i64> { + let PrintMethod = "printS16ImmOperand"; + let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>"; let ParserMatchClass = PPCS16ImmAsmOperand; let DecoderMethod = "decodeSImmOperand<16>"; let OperandType = "OPERAND_IMMEDIATE"; @@ -754,7 +772,14 @@ def PPCU16ImmAsmOperand : AsmOperandClass { } def u16imm : Operand<i32> { let PrintMethod = "printU16ImmOperand"; - let EncoderMethod = "getImm16Encoding"; + let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>"; + let ParserMatchClass = PPCU16ImmAsmOperand; + let DecoderMethod = "decodeUImmOperand<16>"; + let OperandType = "OPERAND_IMMEDIATE"; +} +def u16imm64 : Operand<i64> { + let PrintMethod = "printU16ImmOperand"; + let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>"; let ParserMatchClass = PPCU16ImmAsmOperand; let DecoderMethod = "decodeUImmOperand<16>"; let OperandType = "OPERAND_IMMEDIATE"; @@ -768,7 +793,17 @@ def s17imm : Operand<i32> { // to accept immediates in the range -65536..65535 for compatibility with // the GNU assembler. The operand is treated as 16-bit otherwise. let PrintMethod = "printS16ImmOperand"; - let EncoderMethod = "getImm16Encoding"; + let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>"; + let ParserMatchClass = PPCS17ImmAsmOperand; + let DecoderMethod = "decodeSImmOperand<16>"; + let OperandType = "OPERAND_IMMEDIATE"; +} +def s17imm64 : Operand<i64> { + // This operand type is used for addis/lis to allow the assembler parser + // to accept immediates in the range -65536..65535 for compatibility with + // the GNU assembler. The operand is treated as 16-bit otherwise. + let PrintMethod = "printS16ImmOperand"; + let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>"; let ParserMatchClass = PPCS17ImmAsmOperand; let DecoderMethod = "decodeSImmOperand<16>"; let OperandType = "OPERAND_IMMEDIATE"; @@ -780,14 +815,14 @@ def PPCS34ImmAsmOperand : AsmOperandClass { } def s34imm : Operand<i64> { let PrintMethod = "printS34ImmOperand"; - let EncoderMethod = "getImm34EncodingNoPCRel"; + let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_imm34>"; let ParserMatchClass = PPCS34ImmAsmOperand; let DecoderMethod = "decodeSImmOperand<34>"; let OperandType = "OPERAND_IMMEDIATE"; } def s34imm_pcrel : Operand<i64> { let PrintMethod = "printS34ImmOperand"; - let EncoderMethod = "getImm34EncodingPCRel"; + let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_pcrel34>"; let ParserMatchClass = PPCS34ImmAsmOperand; let DecoderMethod = "decodeSImmOperand<34>"; let OperandType = "OPERAND_IMMEDIATE"; diff --git a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp index 34026ed..ecfb5fe 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp @@ -439,18 +439,6 @@ bool RISCVCallLowering::canLowerReturn(MachineFunction &MF, CCState CCInfo(CallConv, IsVarArg, MF, ArgLocs, MF.getFunction().getContext()); - const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>(); - - std::optional<unsigned> FirstMaskArgument = std::nullopt; - // Preassign the first mask argument. - if (Subtarget.hasVInstructions()) { - for (const auto &ArgIdx : enumerate(Outs)) { - MVT ArgVT = MVT::getVT(ArgIdx.value().Ty); - if (ArgVT.isVector() && ArgVT.getVectorElementType() == MVT::i1) - FirstMaskArgument = ArgIdx.index(); - } - } - for (unsigned I = 0, E = Outs.size(); I < E; ++I) { MVT VT = MVT::getVT(Outs[I].Ty); if (CC_RISCV(I, VT, VT, CCValAssign::Full, Outs[I].Flags[0], CCInfo, diff --git a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp index 597dd12..9f9ae2f 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp @@ -324,6 +324,10 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { OpdsMapping[0] = GPRValueMapping; + // Atomics always use GPR destinations. Don't refine any further. + if (cast<GLoad>(MI).isAtomic()) + break; + // Use FPR64 for s64 loads on rv32. if (GPRSize == 32 && Size.getFixedValue() == 64) { assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD()); @@ -358,6 +362,10 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { OpdsMapping[0] = GPRValueMapping; + // Atomics always use GPR sources. Don't refine any further. + if (cast<GStore>(MI).isAtomic()) + break; + // Use FPR64 for s64 stores on rv32. if (GPRSize == 32 && Size.getFixedValue() == 64) { assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD()); diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td index a02de31..27cf057 100644 --- a/llvm/lib/Target/RISCV/RISCVFeatures.td +++ b/llvm/lib/Target/RISCV/RISCVFeatures.td @@ -1421,7 +1421,7 @@ def HasVendorXMIPSCMov : Predicate<"Subtarget->hasVendorXMIPSCMov()">, AssemblerPredicate<(all_of FeatureVendorXMIPSCMov), "'Xmipscmov' ('mips.ccmov' instruction)">; -def UseCCMovInsn : Predicate<"Subtarget->useCCMovInsn()">; +def UseMIPSCCMovInsn : Predicate<"Subtarget->useMIPSCCMovInsn()">; def FeatureVendorXMIPSLSP : RISCVExtension<1, 0, "MIPS optimization for hardware load-store bonding">; diff --git a/llvm/lib/Target/RISCV/RISCVGISel.td b/llvm/lib/Target/RISCV/RISCVGISel.td index 7f5d0af..6d01250 100644 --- a/llvm/lib/Target/RISCV/RISCVGISel.td +++ b/llvm/lib/Target/RISCV/RISCVGISel.td @@ -190,3 +190,29 @@ let Predicates = [HasStdExtZbkb, NoStdExtZbb, IsRV64] in { def : Pat<(i64 (zext (i16 GPR:$rs))), (PACKW GPR:$rs, (XLenVT X0))>; def : Pat<(i32 (zext (i16 GPR:$rs))), (PACKW GPR:$rs, (XLenVT X0))>; } + +//===----------------------------------------------------------------------===// +// Zalasr patterns not used by SelectionDAG +//===----------------------------------------------------------------------===// + +let Predicates = [HasStdExtZalasr] in { + // the sequentially consistent loads use + // .aq instead of .aqrl to match the psABI/A.7 + def : PatLAQ<acquiring_load<atomic_load_aext_8>, LB_AQ, i16>; + def : PatLAQ<seq_cst_load<atomic_load_aext_8>, LB_AQ, i16>; + + def : PatLAQ<acquiring_load<atomic_load_nonext_16>, LH_AQ, i16>; + def : PatLAQ<seq_cst_load<atomic_load_nonext_16>, LH_AQ, i16>; + + def : PatSRL<releasing_store<atomic_store_8>, SB_RL, i16>; + def : PatSRL<seq_cst_store<atomic_store_8>, SB_RL, i16>; + + def : PatSRL<releasing_store<atomic_store_16>, SH_RL, i16>; + def : PatSRL<seq_cst_store<atomic_store_16>, SH_RL, i16>; +} + +let Predicates = [HasStdExtZalasr, IsRV64] in { + // Load pattern is in RISCVInstrInfoZalasr.td and shared with RV32. + def : PatSRL<releasing_store<atomic_store_32>, SW_RL, i32>; + def : PatSRL<seq_cst_store<atomic_store_32>, SW_RL, i32>; +} diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index dcce2d2..a3a4cf2 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -434,7 +434,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::ABS, MVT::i32, Custom); } - if (!Subtarget.useCCMovInsn() && !Subtarget.hasVendorXTHeadCondMov()) + if (!Subtarget.useMIPSCCMovInsn() && !Subtarget.hasVendorXTHeadCondMov()) setOperationAction(ISD::SELECT, XLenVT, Custom); if (Subtarget.hasVendorXqcia() && !Subtarget.is64Bit()) { @@ -16498,43 +16498,60 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, SDValue X = N->getOperand(0); if (Subtarget.hasShlAdd(3)) { - for (uint64_t Divisor : {3, 5, 9}) { - if (MulAmt % Divisor != 0) - continue; - uint64_t MulAmt2 = MulAmt / Divisor; - // 3/5/9 * 2^N -> shl (shXadd X, X), N - if (isPowerOf2_64(MulAmt2)) { - SDLoc DL(N); - SDValue X = N->getOperand(0); - // Put the shift first if we can fold a zext into the - // shift forming a slli.uw. - if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) && - X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) { - SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X, - DAG.getConstant(Log2_64(MulAmt2), DL, VT)); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), - Shl); - } - // Otherwise, put rhe shl second so that it can fold with following - // instructions (e.g. sext or add). - SDValue Mul359 = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); - return DAG.getNode(ISD::SHL, DL, VT, Mul359, - DAG.getConstant(Log2_64(MulAmt2), DL, VT)); - } - - // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X) - if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) { - SDLoc DL(N); - SDValue Mul359 = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, - DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT), - Mul359); + int Shift; + if (int ShXAmount = isShifted359(MulAmt, Shift)) { + // 3/5/9 * 2^N -> shl (shXadd X, X), N + SDLoc DL(N); + SDValue X = N->getOperand(0); + // Put the shift first if we can fold a zext into the shift forming + // a slli.uw. + if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) && + X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) { + SDValue Shl = + DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(Shift, DL, VT)); + return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl, + DAG.getConstant(ShXAmount, DL, VT), Shl); } + // Otherwise, put the shl second so that it can fold with following + // instructions (e.g. sext or add). + SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getConstant(ShXAmount, DL, VT), X); + return DAG.getNode(ISD::SHL, DL, VT, Mul359, + DAG.getConstant(Shift, DL, VT)); + } + + // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X) + int ShX; + int ShY; + switch (MulAmt) { + case 3 * 5: + ShY = 1; + ShX = 2; + break; + case 3 * 9: + ShY = 1; + ShX = 3; + break; + case 5 * 5: + ShX = ShY = 2; + break; + case 5 * 9: + ShY = 2; + ShX = 3; + break; + case 9 * 9: + ShX = ShY = 3; + break; + default: + ShX = ShY = 0; + break; + } + if (ShX) { + SDLoc DL(N); + SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getConstant(ShY, DL, VT), X); + return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, + DAG.getConstant(ShX, DL, VT), Mul359); } // If this is a power 2 + 2/4/8, we can use a shift followed by a single @@ -16557,18 +16574,14 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, // variants we could implement. e.g. // (2^(1,2,3) * 3,5,9 + 1) << C2 // 2^(C1>3) * 3,5,9 +/- 1 - for (uint64_t Divisor : {3, 5, 9}) { - uint64_t C = MulAmt - 1; - if (C <= Divisor) - continue; - unsigned TZ = llvm::countr_zero(C); - if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) { + if (int ShXAmount = isShifted359(MulAmt - 1, Shift)) { + assert(Shift != 0 && "MulAmt=4,6,10 handled before"); + if (Shift <= 3) { SDLoc DL(N); - SDValue Mul359 = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); + SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getConstant(ShXAmount, DL, VT), X); return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, - DAG.getConstant(TZ, DL, VT), X); + DAG.getConstant(Shift, DL, VT), X); } } @@ -16576,7 +16589,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) { unsigned ScaleShift = llvm::countr_zero(MulAmt - 1); if (ScaleShift >= 1 && ScaleShift < 4) { - unsigned ShiftAmt = Log2_64(((MulAmt - 1) & (MulAmt - 2))); + unsigned ShiftAmt = llvm::countr_zero((MulAmt - 1) & (MulAmt - 2)); SDLoc DL(N); SDValue Shift1 = DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT)); @@ -16589,7 +16602,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, // 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x)) for (uint64_t Offset : {3, 5, 9}) { if (isPowerOf2_64(MulAmt + Offset)) { - unsigned ShAmt = Log2_64(MulAmt + Offset); + unsigned ShAmt = llvm::countr_zero(MulAmt + Offset); if (ShAmt >= VT.getSizeInBits()) continue; SDLoc DL(N); @@ -16608,21 +16621,16 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, uint64_t MulAmt2 = MulAmt / Divisor; // 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples // of 25 which happen to be quite common. - for (uint64_t Divisor2 : {3, 5, 9}) { - if (MulAmt2 % Divisor2 != 0) - continue; - uint64_t MulAmt3 = MulAmt2 / Divisor2; - if (isPowerOf2_64(MulAmt3)) { - SDLoc DL(N); - SDValue Mul359A = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); - SDValue Mul359B = DAG.getNode( - RISCVISD::SHL_ADD, DL, VT, Mul359A, - DAG.getConstant(Log2_64(Divisor2 - 1), DL, VT), Mul359A); - return DAG.getNode(ISD::SHL, DL, VT, Mul359B, - DAG.getConstant(Log2_64(MulAmt3), DL, VT)); - } + if (int ShBAmount = isShifted359(MulAmt2, Shift)) { + SDLoc DL(N); + SDValue Mul359A = + DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); + SDValue Mul359B = + DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359A, + DAG.getConstant(ShBAmount, DL, VT), Mul359A); + return DAG.getNode(ISD::SHL, DL, VT, Mul359B, + DAG.getConstant(Shift, DL, VT)); } } } @@ -25031,8 +25039,17 @@ bool RISCVTargetLowering::fallBackToDAGISel(const Instruction &Inst) const { if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) { // Mark RVV intrinsic as supported. - if (RISCVVIntrinsicsTable::getRISCVVIntrinsicInfo(II->getIntrinsicID())) + if (RISCVVIntrinsicsTable::getRISCVVIntrinsicInfo(II->getIntrinsicID())) { + // GISel doesn't support tuple types yet. + if (Inst.getType()->isRISCVVectorTupleTy()) + return true; + + for (unsigned i = 0; i < II->arg_size(); ++i) + if (II->getArgOperand(i)->getType()->isRISCVVectorTupleTy()) + return true; + return false; + } } if (Inst.getType()->isScalableTy()) diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index 7db4832..96e1078 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -4586,24 +4586,23 @@ void RISCVInstrInfo::mulImm(MachineFunction &MF, MachineBasicBlock &MBB, .addReg(DestReg, RegState::Kill) .addImm(ShiftAmount) .setMIFlag(Flag); - } else if (STI.hasShlAdd(3) && - ((Amount % 3 == 0 && isPowerOf2_64(Amount / 3)) || - (Amount % 5 == 0 && isPowerOf2_64(Amount / 5)) || - (Amount % 9 == 0 && isPowerOf2_64(Amount / 9)))) { + } else if (int ShXAmount, ShiftAmount; + STI.hasShlAdd(3) && + (ShXAmount = isShifted359(Amount, ShiftAmount)) != 0) { // We can use Zba SHXADD+SLLI instructions for multiply in some cases. unsigned Opc; - uint32_t ShiftAmount; - if (Amount % 9 == 0) { - Opc = RISCV::SH3ADD; - ShiftAmount = Log2_64(Amount / 9); - } else if (Amount % 5 == 0) { - Opc = RISCV::SH2ADD; - ShiftAmount = Log2_64(Amount / 5); - } else if (Amount % 3 == 0) { + switch (ShXAmount) { + case 1: Opc = RISCV::SH1ADD; - ShiftAmount = Log2_64(Amount / 3); - } else { - llvm_unreachable("implied by if-clause"); + break; + case 2: + Opc = RISCV::SH2ADD; + break; + case 3: + Opc = RISCV::SH3ADD; + break; + default: + llvm_unreachable("unexpected result of isShifted359"); } if (ShiftAmount) BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg) diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h index 42a0c4c..c5eddb9 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h @@ -25,6 +25,25 @@ namespace llvm { +// If Value is of the form C1<<C2, where C1 = 3, 5 or 9, +// returns log2(C1 - 1) and assigns Shift = C2. +// Otherwise, returns 0. +template <typename T> int isShifted359(T Value, int &Shift) { + if (Value == 0) + return 0; + Shift = llvm::countr_zero(Value); + switch (Value >> Shift) { + case 3: + return 1; + case 5: + return 2; + case 9: + return 3; + default: + return 0; + } +} + class RISCVSubtarget; static const MachineMemOperand::Flags MONontemporalBit0 = diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td index 115ab38e..0b5bee1 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td @@ -175,7 +175,7 @@ def MIPS_CCMOV : RVInstR4<0b11, 0b011, OPC_CUSTOM_0, (outs GPR:$rd), Sched<[]>; } -let Predicates = [UseCCMovInsn] in { +let Predicates = [UseMIPSCCMovInsn] in { def : Pat<(select (riscv_setne (XLenVT GPR:$rs2)), (XLenVT GPR:$rs1), (XLenVT GPR:$rs3)), (MIPS_CCMOV GPR:$rs1, GPR:$rs2, GPR:$rs3)>; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZalasr.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZalasr.td index 1dd7332..1deecd2 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZalasr.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZalasr.td @@ -93,12 +93,11 @@ let Predicates = [HasStdExtZalasr] in { def : PatSRL<releasing_store<atomic_store_32>, SW_RL>; def : PatSRL<seq_cst_store<atomic_store_32>, SW_RL>; -} // Predicates = [HasStdExtZalasr] -let Predicates = [HasStdExtZalasr, IsRV32] in { - def : PatLAQ<acquiring_load<atomic_load_nonext_32>, LW_AQ>; - def : PatLAQ<seq_cst_load<atomic_load_nonext_32>, LW_AQ>; -} // Predicates = [HasStdExtZalasr, IsRV32] + // Used by GISel for RV32 and RV64. + def : PatLAQ<acquiring_load<atomic_load_nonext_32>, LW_AQ, i32>; + def : PatLAQ<seq_cst_load<atomic_load_nonext_32>, LW_AQ, i32>; +} // Predicates = [HasStdExtZalasr] let Predicates = [HasStdExtZalasr, IsRV64] in { def : PatLAQ<acquiring_load<atomic_load_asext_32>, LW_AQ, i64>; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td index ce21d83..8d9b777 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td @@ -808,9 +808,9 @@ multiclass Sh2Add_UWPat<Instruction sh2add_uw> { } multiclass Sh3Add_UWPat<Instruction sh3add_uw> { - def : Pat<(i64 (add_like_non_imm12 (and GPR:$rs1, 0xFFFFFFF8), + def : Pat<(i64 (add_like_non_imm12 (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), (XLenVT GPR:$rs2))), - (sh3add_uw (XLenVT (SRLIW GPR:$rs1, 3)), GPR:$rs2)>; + (sh3add_uw GPR:$rs1, GPR:$rs2)>; // Use SRLI to clear the LSBs and SHXADD_UW to mask and shift. def : Pat<(i64 (add_like_non_imm12 (and GPR:$rs1, 0x7FFFFFFF8), (XLenVT GPR:$rs2))), diff --git a/llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp index c81a20b..115a96e 100644 --- a/llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp +++ b/llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp @@ -92,7 +92,7 @@ bool RISCVLoadStoreOpt::runOnMachineFunction(MachineFunction &Fn) { if (skipFunction(Fn.getFunction())) return false; const RISCVSubtarget &Subtarget = Fn.getSubtarget<RISCVSubtarget>(); - if (!Subtarget.useLoadStorePairs()) + if (!Subtarget.useMIPSLoadStorePairs()) return false; bool MadeChange = false; diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp index e35ffaf..715ac4c 100644 --- a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp +++ b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp @@ -65,9 +65,9 @@ static cl::opt<bool> UseMIPSLoadStorePairsOpt( cl::desc("Enable the load/store pair optimization pass"), cl::init(false), cl::Hidden); -static cl::opt<bool> UseCCMovInsn("use-riscv-ccmov", - cl::desc("Use 'mips.ccmov' instruction"), - cl::init(true), cl::Hidden); +static cl::opt<bool> UseMIPSCCMovInsn("use-riscv-mips-ccmov", + cl::desc("Use 'mips.ccmov' instruction"), + cl::init(true), cl::Hidden); void RISCVSubtarget::anchor() {} @@ -246,10 +246,10 @@ void RISCVSubtarget::overridePostRASchedPolicy( } } -bool RISCVSubtarget::useLoadStorePairs() const { +bool RISCVSubtarget::useMIPSLoadStorePairs() const { return UseMIPSLoadStorePairsOpt && HasVendorXMIPSLSP; } -bool RISCVSubtarget::useCCMovInsn() const { - return UseCCMovInsn && HasVendorXMIPSCMov; +bool RISCVSubtarget::useMIPSCCMovInsn() const { + return UseMIPSCCMovInsn && HasVendorXMIPSCMov; } diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.h b/llvm/lib/Target/RISCV/RISCVSubtarget.h index 7dffa63..6acf799 100644 --- a/llvm/lib/Target/RISCV/RISCVSubtarget.h +++ b/llvm/lib/Target/RISCV/RISCVSubtarget.h @@ -227,8 +227,8 @@ public: unsigned getXLen() const { return is64Bit() ? 64 : 32; } - bool useLoadStorePairs() const; - bool useCCMovInsn() const; + bool useMIPSLoadStorePairs() const; + bool useMIPSCCMovInsn() const; unsigned getFLen() const { if (HasStdExtD) return 64; diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 9f2e075..e16c8f0 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -2811,9 +2811,7 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { GetElementPtrInst *NewGEP = simplifyZeroLengthArrayGepInst(Ref); if (NewGEP) { Ref->replaceAllUsesWith(NewGEP); - if (isInstructionTriviallyDead(Ref)) - DeadInsts.insert(Ref); - + DeadInsts.insert(Ref); Ref = NewGEP; } if (Type *GepTy = getGEPType(Ref)) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 0afec42..989950f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -307,6 +307,10 @@ private: bool selectHandleFromBinding(Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectCounterHandleFromBinding(Register &ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const; + bool selectReadImageIntrinsic(Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const; bool selectImageWriteIntrinsic(MachineInstr &I) const; @@ -314,6 +318,8 @@ private: MachineInstr &I) const; bool selectModf(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectUpdateCounter(Register &ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; bool selectFrexp(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; // Utilities @@ -3443,6 +3449,10 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, case Intrinsic::spv_resource_handlefrombinding: { return selectHandleFromBinding(ResVReg, ResType, I); } + case Intrinsic::spv_resource_counterhandlefrombinding: + return selectCounterHandleFromBinding(ResVReg, ResType, I); + case Intrinsic::spv_resource_updatecounter: + return selectUpdateCounter(ResVReg, ResType, I); case Intrinsic::spv_resource_store_typedbuffer: { return selectImageWriteIntrinsic(I); } @@ -3478,6 +3488,130 @@ bool SPIRVInstructionSelector::selectHandleFromBinding(Register &ResVReg, *cast<GIntrinsic>(&I), I); } +bool SPIRVInstructionSelector::selectCounterHandleFromBinding( + Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const { + auto &Intr = cast<GIntrinsic>(I); + assert(Intr.getIntrinsicID() == + Intrinsic::spv_resource_counterhandlefrombinding); + + // Extract information from the intrinsic call. + Register MainHandleReg = Intr.getOperand(2).getReg(); + auto *MainHandleDef = cast<GIntrinsic>(getVRegDef(*MRI, MainHandleReg)); + assert(MainHandleDef->getIntrinsicID() == + Intrinsic::spv_resource_handlefrombinding); + + uint32_t Set = getIConstVal(Intr.getOperand(4).getReg(), MRI); + uint32_t Binding = getIConstVal(Intr.getOperand(3).getReg(), MRI); + uint32_t ArraySize = getIConstVal(MainHandleDef->getOperand(4).getReg(), MRI); + Register IndexReg = MainHandleDef->getOperand(5).getReg(); + const bool IsNonUniform = false; + std::string CounterName = + getStringValueFromReg(MainHandleDef->getOperand(6).getReg(), *MRI) + + ".counter"; + + // Create the counter variable. + MachineIRBuilder MIRBuilder(I); + Register CounterVarReg = buildPointerToResource( + GR.getPointeeType(ResType), GR.getPointerStorageClass(ResType), Set, + Binding, ArraySize, IndexReg, IsNonUniform, CounterName, MIRBuilder); + + return BuildCOPY(ResVReg, CounterVarReg, I); +} + +bool SPIRVInstructionSelector::selectUpdateCounter(Register &ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + auto &Intr = cast<GIntrinsic>(I); + assert(Intr.getIntrinsicID() == Intrinsic::spv_resource_updatecounter); + + Register CounterHandleReg = Intr.getOperand(2).getReg(); + Register IncrReg = Intr.getOperand(3).getReg(); + + // The counter handle is a pointer to the counter variable (which is a struct + // containing an i32). We need to get a pointer to that i32 member to do the + // atomic operation. +#ifndef NDEBUG + SPIRVType *CounterVarType = GR.getSPIRVTypeForVReg(CounterHandleReg); + SPIRVType *CounterVarPointeeType = GR.getPointeeType(CounterVarType); + assert(CounterVarPointeeType && + CounterVarPointeeType->getOpcode() == SPIRV::OpTypeStruct && + "Counter variable must be a struct"); + assert(GR.getPointerStorageClass(CounterVarType) == + SPIRV::StorageClass::StorageBuffer && + "Counter variable must be in the storage buffer storage class"); + assert(CounterVarPointeeType->getNumOperands() == 2 && + "Counter variable must have exactly 1 member in the struct"); + const SPIRVType *MemberType = + GR.getSPIRVTypeForVReg(CounterVarPointeeType->getOperand(1).getReg()); + assert(MemberType->getOpcode() == SPIRV::OpTypeInt && + "Counter variable struct must have a single i32 member"); +#endif + + // The struct has a single i32 member. + MachineIRBuilder MIRBuilder(I); + const Type *LLVMIntType = + Type::getInt32Ty(I.getMF()->getFunction().getContext()); + + SPIRVType *IntPtrType = GR.getOrCreateSPIRVPointerType( + LLVMIntType, MIRBuilder, SPIRV::StorageClass::StorageBuffer); + + auto Zero = buildI32Constant(0, I); + if (!Zero.second) + return false; + + Register PtrToCounter = + MRI->createVirtualRegister(GR.getRegClass(IntPtrType)); + if (!BuildMI(*I.getParent(), I, I.getDebugLoc(), + TII.get(SPIRV::OpAccessChain)) + .addDef(PtrToCounter) + .addUse(GR.getSPIRVTypeID(IntPtrType)) + .addUse(CounterHandleReg) + .addUse(Zero.first) + .constrainAllUses(TII, TRI, RBI)) { + return false; + } + + // For UAV/SSBO counters, the scope is Device. The counter variable is not + // used as a flag. So the memory semantics can be None. + auto Scope = buildI32Constant(SPIRV::Scope::Device, I); + if (!Scope.second) + return false; + auto Semantics = buildI32Constant(SPIRV::MemorySemantics::None, I); + if (!Semantics.second) + return false; + + int64_t IncrVal = getIConstValSext(IncrReg, MRI); + auto Incr = buildI32Constant(static_cast<uint32_t>(IncrVal), I); + if (!Incr.second) + return false; + + Register AtomicRes = MRI->createVirtualRegister(GR.getRegClass(ResType)); + if (!BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpAtomicIAdd)) + .addDef(AtomicRes) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(PtrToCounter) + .addUse(Scope.first) + .addUse(Semantics.first) + .addUse(Incr.first) + .constrainAllUses(TII, TRI, RBI)) { + return false; + } + if (IncrVal >= 0) { + return BuildCOPY(ResVReg, AtomicRes, I); + } + + // In HLSL, IncrementCounter returns the value *before* the increment, while + // DecrementCounter returns the value *after* the decrement. Both are lowered + // to the same atomic intrinsic which returns the value *before* the + // operation. So for decrements (negative IncrVal), we must subtract the + // increment value from the result to get the post-decrement value. + return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(AtomicRes) + .addUse(Incr.first) + .constrainAllUses(TII, TRI, RBI); +} bool SPIRVInstructionSelector::selectReadImageIntrinsic( Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const { diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp index 205895e..fc14a03 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp @@ -39,6 +39,10 @@ private: void collectBindingInfo(Module &M); uint32_t getAndReserveFirstUnusedBinding(uint32_t DescSet); void replaceImplicitBindingCalls(Module &M); + void replaceResourceHandleCall(Module &M, CallInst *OldCI, + uint32_t NewBinding); + void replaceCounterHandleCall(Module &M, CallInst *OldCI, + uint32_t NewBinding); void verifyUniqueOrderIdPerResource(SmallVectorImpl<CallInst *> &Calls); // A map from descriptor set to a bit vector of used binding numbers. @@ -56,64 +60,93 @@ struct BindingInfoCollector : public InstVisitor<BindingInfoCollector> { : UsedBindings(UsedBindings), ImplicitBindingCalls(ImplicitBindingCalls) { } + void addBinding(uint32_t DescSet, uint32_t Binding) { + if (UsedBindings.size() <= DescSet) { + UsedBindings.resize(DescSet + 1); + UsedBindings[DescSet].resize(64); + } + if (UsedBindings[DescSet].size() <= Binding) { + UsedBindings[DescSet].resize(2 * Binding + 1); + } + UsedBindings[DescSet].set(Binding); + } + void visitCallInst(CallInst &CI) { if (CI.getIntrinsicID() == Intrinsic::spv_resource_handlefrombinding) { const uint32_t DescSet = cast<ConstantInt>(CI.getArgOperand(0))->getZExtValue(); const uint32_t Binding = cast<ConstantInt>(CI.getArgOperand(1))->getZExtValue(); - - if (UsedBindings.size() <= DescSet) { - UsedBindings.resize(DescSet + 1); - UsedBindings[DescSet].resize(64); - } - if (UsedBindings[DescSet].size() <= Binding) { - UsedBindings[DescSet].resize(2 * Binding + 1); - } - UsedBindings[DescSet].set(Binding); + addBinding(DescSet, Binding); } else if (CI.getIntrinsicID() == Intrinsic::spv_resource_handlefromimplicitbinding) { ImplicitBindingCalls.push_back(&CI); + } else if (CI.getIntrinsicID() == + Intrinsic::spv_resource_counterhandlefrombinding) { + const uint32_t DescSet = + cast<ConstantInt>(CI.getArgOperand(2))->getZExtValue(); + const uint32_t Binding = + cast<ConstantInt>(CI.getArgOperand(1))->getZExtValue(); + addBinding(DescSet, Binding); + } else if (CI.getIntrinsicID() == + Intrinsic::spv_resource_counterhandlefromimplicitbinding) { + ImplicitBindingCalls.push_back(&CI); } } }; +static uint32_t getOrderId(const CallInst *CI) { + uint32_t OrderIdArgIdx = 0; + switch (CI->getIntrinsicID()) { + case Intrinsic::spv_resource_handlefromimplicitbinding: + OrderIdArgIdx = 0; + break; + case Intrinsic::spv_resource_counterhandlefromimplicitbinding: + OrderIdArgIdx = 1; + break; + default: + llvm_unreachable("CallInst is not an implicit binding intrinsic"); + } + return cast<ConstantInt>(CI->getArgOperand(OrderIdArgIdx))->getZExtValue(); +} + +static uint32_t getDescSet(const CallInst *CI) { + uint32_t DescSetArgIdx; + switch (CI->getIntrinsicID()) { + case Intrinsic::spv_resource_handlefromimplicitbinding: + case Intrinsic::spv_resource_handlefrombinding: + DescSetArgIdx = 1; + break; + case Intrinsic::spv_resource_counterhandlefromimplicitbinding: + case Intrinsic::spv_resource_counterhandlefrombinding: + DescSetArgIdx = 2; + break; + default: + llvm_unreachable("CallInst is not an implicit binding intrinsic"); + } + return cast<ConstantInt>(CI->getArgOperand(DescSetArgIdx))->getZExtValue(); +} + void SPIRVLegalizeImplicitBinding::collectBindingInfo(Module &M) { BindingInfoCollector InfoCollector(UsedBindings, ImplicitBindingCalls); InfoCollector.visit(M); // Sort the collected calls by their order ID. - std::sort( - ImplicitBindingCalls.begin(), ImplicitBindingCalls.end(), - [](const CallInst *A, const CallInst *B) { - const uint32_t OrderIdArgIdx = 0; - const uint32_t OrderA = - cast<ConstantInt>(A->getArgOperand(OrderIdArgIdx))->getZExtValue(); - const uint32_t OrderB = - cast<ConstantInt>(B->getArgOperand(OrderIdArgIdx))->getZExtValue(); - return OrderA < OrderB; - }); + std::sort(ImplicitBindingCalls.begin(), ImplicitBindingCalls.end(), + [](const CallInst *A, const CallInst *B) { + return getOrderId(A) < getOrderId(B); + }); } void SPIRVLegalizeImplicitBinding::verifyUniqueOrderIdPerResource( SmallVectorImpl<CallInst *> &Calls) { // Check that the order Id is unique per resource. for (uint32_t i = 1; i < Calls.size(); ++i) { - const uint32_t OrderIdArgIdx = 0; - const uint32_t DescSetArgIdx = 1; - const uint32_t OrderA = - cast<ConstantInt>(Calls[i - 1]->getArgOperand(OrderIdArgIdx)) - ->getZExtValue(); - const uint32_t OrderB = - cast<ConstantInt>(Calls[i]->getArgOperand(OrderIdArgIdx)) - ->getZExtValue(); + const uint32_t OrderA = getOrderId(Calls[i - 1]); + const uint32_t OrderB = getOrderId(Calls[i]); if (OrderA == OrderB) { - const uint32_t DescSetA = - cast<ConstantInt>(Calls[i - 1]->getArgOperand(DescSetArgIdx)) - ->getZExtValue(); - const uint32_t DescSetB = - cast<ConstantInt>(Calls[i]->getArgOperand(DescSetArgIdx)) - ->getZExtValue(); + const uint32_t DescSetA = getDescSet(Calls[i - 1]); + const uint32_t DescSetB = getDescSet(Calls[i]); if (DescSetA != DescSetB) { report_fatal_error("Implicit binding calls with the same order ID must " "have the same descriptor set"); @@ -144,36 +177,26 @@ void SPIRVLegalizeImplicitBinding::replaceImplicitBindingCalls(Module &M) { uint32_t lastBindingNumber = -1; for (CallInst *OldCI : ImplicitBindingCalls) { - IRBuilder<> Builder(OldCI); - const uint32_t OrderId = - cast<ConstantInt>(OldCI->getArgOperand(0))->getZExtValue(); - const uint32_t DescSet = - cast<ConstantInt>(OldCI->getArgOperand(1))->getZExtValue(); - - // Reuse an existing binding for this order ID, if one was already assigned. - // Otherwise, assign a new binding. - const uint32_t NewBinding = (lastOrderId == OrderId) - ? lastBindingNumber - : getAndReserveFirstUnusedBinding(DescSet); - lastOrderId = OrderId; - lastBindingNumber = NewBinding; - - SmallVector<Value *, 8> Args; - Args.push_back(Builder.getInt32(DescSet)); - Args.push_back(Builder.getInt32(NewBinding)); - - // Copy the remaining arguments from the old call. - for (uint32_t i = 2; i < OldCI->arg_size(); ++i) { - Args.push_back(OldCI->getArgOperand(i)); + const uint32_t OrderId = getOrderId(OldCI); + uint32_t BindingNumber; + if (OrderId == lastOrderId) { + BindingNumber = lastBindingNumber; + } else { + const uint32_t DescSet = getDescSet(OldCI); + BindingNumber = getAndReserveFirstUnusedBinding(DescSet); } - Function *NewFunc = Intrinsic::getOrInsertDeclaration( - &M, Intrinsic::spv_resource_handlefrombinding, OldCI->getType()); - CallInst *NewCI = Builder.CreateCall(NewFunc, Args); - NewCI->setCallingConv(OldCI->getCallingConv()); - - OldCI->replaceAllUsesWith(NewCI); - OldCI->eraseFromParent(); + if (OldCI->getIntrinsicID() == + Intrinsic::spv_resource_handlefromimplicitbinding) { + replaceResourceHandleCall(M, OldCI, BindingNumber); + } else { + assert(OldCI->getIntrinsicID() == + Intrinsic::spv_resource_counterhandlefromimplicitbinding && + "Unexpected implicit binding intrinsic"); + replaceCounterHandleCall(M, OldCI, BindingNumber); + } + lastOrderId = OrderId; + lastBindingNumber = BindingNumber; } } @@ -196,4 +219,49 @@ INITIALIZE_PASS(SPIRVLegalizeImplicitBinding, "legalize-spirv-implicit-binding", ModulePass *llvm::createSPIRVLegalizeImplicitBindingPass() { return new SPIRVLegalizeImplicitBinding(); -}
\ No newline at end of file +} + +void SPIRVLegalizeImplicitBinding::replaceResourceHandleCall( + Module &M, CallInst *OldCI, uint32_t NewBinding) { + IRBuilder<> Builder(OldCI); + const uint32_t DescSet = + cast<ConstantInt>(OldCI->getArgOperand(1))->getZExtValue(); + + SmallVector<Value *, 8> Args; + Args.push_back(Builder.getInt32(DescSet)); + Args.push_back(Builder.getInt32(NewBinding)); + + // Copy the remaining arguments from the old call. + for (uint32_t i = 2; i < OldCI->arg_size(); ++i) { + Args.push_back(OldCI->getArgOperand(i)); + } + + Function *NewFunc = Intrinsic::getOrInsertDeclaration( + &M, Intrinsic::spv_resource_handlefrombinding, OldCI->getType()); + CallInst *NewCI = Builder.CreateCall(NewFunc, Args); + NewCI->setCallingConv(OldCI->getCallingConv()); + + OldCI->replaceAllUsesWith(NewCI); + OldCI->eraseFromParent(); +} + +void SPIRVLegalizeImplicitBinding::replaceCounterHandleCall( + Module &M, CallInst *OldCI, uint32_t NewBinding) { + IRBuilder<> Builder(OldCI); + const uint32_t DescSet = + cast<ConstantInt>(OldCI->getArgOperand(2))->getZExtValue(); + + SmallVector<Value *, 8> Args; + Args.push_back(OldCI->getArgOperand(0)); + Args.push_back(Builder.getInt32(NewBinding)); + Args.push_back(Builder.getInt32(DescSet)); + + Type *Tys[] = {OldCI->getType(), OldCI->getArgOperand(0)->getType()}; + Function *NewFunc = Intrinsic::getOrInsertDeclaration( + &M, Intrinsic::spv_resource_counterhandlefrombinding, Tys); + CallInst *NewCI = Builder.CreateCall(NewFunc, Args); + NewCI->setCallingConv(OldCI->getCallingConv()); + + OldCI->replaceAllUsesWith(NewCI); + OldCI->eraseFromParent(); +} diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index 327c011..1d47c89 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -385,6 +385,12 @@ uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) { return MI->getOperand(1).getCImm()->getValue().getZExtValue(); } +int64_t getIConstValSext(Register ConstReg, const MachineRegisterInfo *MRI) { + const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI); + assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT); + return MI->getOperand(1).getCImm()->getSExtValue(); +} + bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) { if (const auto *GI = dyn_cast<GIntrinsic>(&MI)) return GI->is(IntrinsicID); diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index 409a0fd..5777a24 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -289,6 +289,9 @@ MachineInstr *getDefInstrMaybeConstant(Register &ConstReg, // Get constant integer value of the given ConstReg. uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI); +// Get constant integer value of the given ConstReg, sign-extended. +int64_t getIConstValSext(Register ConstReg, const MachineRegisterInfo *MRI); + // Check if MI is a SPIR-V specific intrinsic call. bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID); // Check if it's a SPIR-V specific intrinsic call. diff --git a/llvm/lib/TargetParser/TargetParser.cpp b/llvm/lib/TargetParser/TargetParser.cpp index b906690..62a3c88 100644 --- a/llvm/lib/TargetParser/TargetParser.cpp +++ b/llvm/lib/TargetParser/TargetParser.cpp @@ -444,7 +444,7 @@ static void fillAMDGCNFeatureMap(StringRef GPU, const Triple &T, Features["atomic-fmin-fmax-global-f32"] = true; Features["atomic-fmin-fmax-global-f64"] = true; Features["wavefrontsize32"] = true; - Features["cluster"] = true; + Features["clusters"] = true; break; case GK_GFX1201: case GK_GFX1200: diff --git a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp index 8d9a0e7..50130da 100644 --- a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp @@ -2067,6 +2067,36 @@ static void inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes, AI.run(SCCNodes, Changed); } +// Determines if the function 'F' can be marked 'norecurse'. +// It returns true if any call within 'F' could lead to a recursive +// call back to 'F', and false otherwise. +// The 'AnyFunctionsAddressIsTaken' parameter is a module-wide flag +// that is true if any function's address is taken, or if any function +// has external linkage. This is used to determine the safety of +// external/library calls. +static bool mayHaveRecursiveCallee(Function &F, + bool AnyFunctionsAddressIsTaken = true) { + for (const auto &BB : F) { + for (const auto &I : BB.instructionsWithoutDebug()) { + if (const auto *CB = dyn_cast<CallBase>(&I)) { + const Function *Callee = CB->getCalledFunction(); + if (!Callee || Callee == &F) + return true; + + if (Callee->doesNotRecurse()) + continue; + + if (!AnyFunctionsAddressIsTaken || + (Callee->isDeclaration() && + Callee->hasFnAttribute(Attribute::NoCallback))) + continue; + return true; + } + } + } + return false; +} + static void addNoRecurseAttrs(const SCCNodeSet &SCCNodes, SmallPtrSet<Function *, 8> &Changed) { // Try and identify functions that do not recurse. @@ -2078,28 +2108,14 @@ static void addNoRecurseAttrs(const SCCNodeSet &SCCNodes, Function *F = *SCCNodes.begin(); if (!F || !F->hasExactDefinition() || F->doesNotRecurse()) return; - - // If all of the calls in F are identifiable and are to norecurse functions, F - // is norecurse. This check also detects self-recursion as F is not currently - // marked norecurse, so any called from F to F will not be marked norecurse. - for (auto &BB : *F) - for (auto &I : BB.instructionsWithoutDebug()) - if (auto *CB = dyn_cast<CallBase>(&I)) { - Function *Callee = CB->getCalledFunction(); - if (!Callee || Callee == F || - (!Callee->doesNotRecurse() && - !(Callee->isDeclaration() && - Callee->hasFnAttribute(Attribute::NoCallback)))) - // Function calls a potentially recursive function. - return; - } - - // Every call was to a non-recursive function other than this function, and - // we have no indirect recursion as the SCC size is one. This function cannot - // recurse. - F->setDoesNotRecurse(); - ++NumNoRecurse; - Changed.insert(F); + if (!mayHaveRecursiveCallee(*F)) { + // Every call was to a non-recursive function other than this function, and + // we have no indirect recursion as the SCC size is one. This function + // cannot recurse. + F->setDoesNotRecurse(); + ++NumNoRecurse; + Changed.insert(F); + } } // Set the noreturn function attribute if possible. @@ -2429,3 +2445,62 @@ ReversePostOrderFunctionAttrsPass::run(Module &M, ModuleAnalysisManager &AM) { PA.preserve<LazyCallGraphAnalysis>(); return PA; } + +PreservedAnalyses NoRecurseLTOInferencePass::run(Module &M, + ModuleAnalysisManager &MAM) { + + // Check if any function in the whole program has its address taken or has + // potentially external linkage. + // We use this information when inferring norecurse attribute: If there is + // no function whose address is taken and all functions have internal + // linkage, there is no path for a callback to any user function. + bool AnyFunctionsAddressIsTaken = false; + for (Function &F : M) { + if (F.isDeclaration() || F.doesNotRecurse()) + continue; + if (!F.hasLocalLinkage() || F.hasAddressTaken()) { + AnyFunctionsAddressIsTaken = true; + break; + } + } + + // Run norecurse inference on all RefSCCs in the LazyCallGraph for this + // module. + bool Changed = false; + LazyCallGraph &CG = MAM.getResult<LazyCallGraphAnalysis>(M); + CG.buildRefSCCs(); + + for (LazyCallGraph::RefSCC &RC : CG.postorder_ref_sccs()) { + // Skip any RefSCC that is part of a call cycle. A RefSCC containing more + // than one SCC indicates a recursive relationship involving indirect calls. + if (RC.size() > 1) + continue; + + // RefSCC contains a single-SCC. SCC size > 1 indicates mutually recursive + // functions. Ex: foo1 -> foo2 -> foo3 -> foo1. + LazyCallGraph::SCC &S = *RC.begin(); + if (S.size() > 1) + continue; + + // Get the single function from this SCC. + Function &F = S.begin()->getFunction(); + if (!F.hasExactDefinition() || F.doesNotRecurse()) + continue; + + // If the analysis confirms that this function has no recursive calls + // (either direct, indirect, or through external linkages), + // we can safely apply the norecurse attribute. + if (!mayHaveRecursiveCallee(F, AnyFunctionsAddressIsTaken)) { + F.setDoesNotRecurse(); + ++NumNoRecurse; + Changed = true; + } + } + + PreservedAnalyses PA; + if (Changed) + PA.preserve<LazyCallGraphAnalysis>(); + else + PA = PreservedAnalyses::all(); + return PA; +} diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 8f60e50..8c8fc69 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -3356,7 +3356,10 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { impliesPoisonOrCond(FalseVal, B, /*Expected=*/false)) { // (A || B) || C --> A || (B | C) return replaceInstUsesWith( - SI, Builder.CreateLogicalOr(A, Builder.CreateOr(B, FalseVal))); + SI, Builder.CreateLogicalOr(A, Builder.CreateOr(B, FalseVal), "", + ProfcheckDisableMetadataFixes + ? nullptr + : cast<SelectInst>(CondVal))); } // (A && B) || (C && B) --> (A || C) && B @@ -3398,7 +3401,10 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { impliesPoisonOrCond(TrueVal, B, /*Expected=*/true)) { // (A && B) && C --> A && (B & C) return replaceInstUsesWith( - SI, Builder.CreateLogicalAnd(A, Builder.CreateAnd(B, TrueVal))); + SI, Builder.CreateLogicalAnd(A, Builder.CreateAnd(B, TrueVal), "", + ProfcheckDisableMetadataFixes + ? nullptr + : cast<SelectInst>(CondVal))); } // (A || B) && (C || B) --> (A && C) || B diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp index e9a3e98..41a6c80 100644 --- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp @@ -120,6 +120,12 @@ static cl::opt<unsigned> cl::desc("Maximum cost accepted for the transformation"), cl::Hidden, cl::init(50)); +static cl::opt<double> MaxClonedRate( + "dfa-max-cloned-rate", + cl::desc( + "Maximum cloned instructions rate accepted for the transformation"), + cl::Hidden, cl::init(7.5)); + namespace { class SelectInstToUnfold { @@ -828,6 +834,7 @@ private: /// also returns false if it is illegal to clone some required block. bool isLegalAndProfitableToTransform() { CodeMetrics Metrics; + uint64_t NumClonedInst = 0; SwitchInst *Switch = SwitchPaths->getSwitchInst(); // Don't thread switch without multiple successors. @@ -837,7 +844,6 @@ private: // Note that DuplicateBlockMap is not being used as intended here. It is // just being used to ensure (BB, State) pairs are only counted once. DuplicateBlockMap DuplicateMap; - for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) { PathType PathBBs = TPath.getPath(); APInt NextState = TPath.getExitValue(); @@ -848,6 +854,7 @@ private: BasicBlock *VisitedBB = getClonedBB(BB, NextState, DuplicateMap); if (!VisitedBB) { Metrics.analyzeBasicBlock(BB, *TTI, EphValues); + NumClonedInst += BB->sizeWithoutDebug(); DuplicateMap[BB].push_back({BB, NextState}); } @@ -865,6 +872,7 @@ private: if (VisitedBB) continue; Metrics.analyzeBasicBlock(BB, *TTI, EphValues); + NumClonedInst += BB->sizeWithoutDebug(); DuplicateMap[BB].push_back({BB, NextState}); } @@ -901,6 +909,22 @@ private: } } + // Too much cloned instructions slow down later optimizations, especially + // SLPVectorizer. + // TODO: Thread the switch partially before reaching the threshold. + uint64_t NumOrigInst = 0; + for (auto *BB : DuplicateMap.keys()) + NumOrigInst += BB->sizeWithoutDebug(); + if (double(NumClonedInst) / double(NumOrigInst) > MaxClonedRate) { + LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, too much " + "instructions wll be cloned\n"); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NotProfitable", Switch) + << "Too much instructions will be cloned."; + }); + return false; + } + InstructionCost DuplicationCost = 0; unsigned JumpTableSize = 0; diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 148bfa8..b8cfe3a 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -4895,9 +4895,8 @@ bool SimplifyCFGOpt::simplifyTerminatorOnSelect(Instruction *OldTerm, // We found both of the successors we were looking for. // Create a conditional branch sharing the condition of the select. BranchInst *NewBI = Builder.CreateCondBr(Cond, TrueBB, FalseBB); - if (TrueWeight != FalseWeight) - setBranchWeights(*NewBI, {TrueWeight, FalseWeight}, - /*IsExpected=*/false, /*ElideAllZero=*/true); + setBranchWeights(*NewBI, {TrueWeight, FalseWeight}, + /*IsExpected=*/false, /*ElideAllZero=*/true); } } else if (KeepEdge1 && (KeepEdge2 || TrueBB == FalseBB)) { // Neither of the selected blocks were successors, so this @@ -4982,9 +4981,15 @@ bool SimplifyCFGOpt::simplifyIndirectBrOnSelect(IndirectBrInst *IBI, BasicBlock *TrueBB = TBA->getBasicBlock(); BasicBlock *FalseBB = FBA->getBasicBlock(); + // The select's profile becomes the profile of the conditional branch that + // replaces the indirect branch. + SmallVector<uint32_t> SelectBranchWeights(2); + if (!ProfcheckDisableMetadataFixes) + extractBranchWeights(*SI, SelectBranchWeights); // Perform the actual simplification. - return simplifyTerminatorOnSelect(IBI, SI->getCondition(), TrueBB, FalseBB, 0, - 0); + return simplifyTerminatorOnSelect(IBI, SI->getCondition(), TrueBB, FalseBB, + SelectBranchWeights[0], + SelectBranchWeights[1]); } /// This is called when we find an icmp instruction @@ -7952,19 +7957,27 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { bool SimplifyCFGOpt::simplifyIndirectBr(IndirectBrInst *IBI) { BasicBlock *BB = IBI->getParent(); bool Changed = false; + SmallVector<uint32_t> BranchWeights; + const bool HasBranchWeights = !ProfcheckDisableMetadataFixes && + extractBranchWeights(*IBI, BranchWeights); + + DenseMap<const BasicBlock *, uint64_t> TargetWeight; + if (HasBranchWeights) + for (size_t I = 0, E = IBI->getNumDestinations(); I < E; ++I) + TargetWeight[IBI->getDestination(I)] += BranchWeights[I]; // Eliminate redundant destinations. SmallPtrSet<Value *, 8> Succs; SmallSetVector<BasicBlock *, 8> RemovedSuccs; - for (unsigned i = 0, e = IBI->getNumDestinations(); i != e; ++i) { - BasicBlock *Dest = IBI->getDestination(i); + for (unsigned I = 0, E = IBI->getNumDestinations(); I != E; ++I) { + BasicBlock *Dest = IBI->getDestination(I); if (!Dest->hasAddressTaken() || !Succs.insert(Dest).second) { if (!Dest->hasAddressTaken()) RemovedSuccs.insert(Dest); Dest->removePredecessor(BB); - IBI->removeDestination(i); - --i; - --e; + IBI->removeDestination(I); + --I; + --E; Changed = true; } } @@ -7990,7 +8003,12 @@ bool SimplifyCFGOpt::simplifyIndirectBr(IndirectBrInst *IBI) { eraseTerminatorAndDCECond(IBI); return true; } - + if (HasBranchWeights) { + SmallVector<uint64_t> NewBranchWeights(IBI->getNumDestinations()); + for (size_t I = 0, E = IBI->getNumDestinations(); I < E; ++I) + NewBranchWeights[I] += TargetWeight.find(IBI->getDestination(I))->second; + setFittedBranchWeights(*IBI, NewBranchWeights, /*IsExpected=*/false); + } if (SelectInst *SI = dyn_cast<SelectInst>(IBI->getAddress())) { if (simplifyIndirectBrOnSelect(IBI, SI)) return requestResimplify(); diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 56a3d6d..cee08ef 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -3903,7 +3903,8 @@ void LoopVectorizationPlanner::emitInvalidCostRemarks( if (VF.isScalar()) continue; - VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind); + VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind, + *CM.PSE.getSE()); precomputeCosts(*Plan, VF, CostCtx); auto Iter = vp_depth_first_deep(Plan->getVectorLoopRegion()->getEntry()); for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Iter)) { @@ -4160,7 +4161,8 @@ VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor() { // Add on other costs that are modelled in VPlan, but not in the legacy // cost model. - VPCostContext CostCtx(CM.TTI, *CM.TLI, *P, CM, CM.CostKind); + VPCostContext CostCtx(CM.TTI, *CM.TLI, *P, CM, CM.CostKind, + *CM.PSE.getSE()); VPRegionBlock *VectorRegion = P->getVectorLoopRegion(); assert(VectorRegion && "Expected to have a vector region!"); for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( @@ -6852,7 +6854,7 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF, InstructionCost LoopVectorizationPlanner::cost(VPlan &Plan, ElementCount VF) const { - VPCostContext CostCtx(CM.TTI, *CM.TLI, Plan, CM, CM.CostKind); + VPCostContext CostCtx(CM.TTI, *CM.TLI, Plan, CM, CM.CostKind, *PSE.getSE()); InstructionCost Cost = precomputeCosts(Plan, VF, CostCtx); // Now compute and add the VPlan-based cost. @@ -7085,7 +7087,8 @@ VectorizationFactor LoopVectorizationPlanner::computeBestVF() { // simplifications not accounted for in the legacy cost model. If that's the // case, don't trigger the assertion, as the extra simplifications may cause a // different VF to be picked by the VPlan-based cost model. - VPCostContext CostCtx(CM.TTI, *CM.TLI, BestPlan, CM, CM.CostKind); + VPCostContext CostCtx(CM.TTI, *CM.TLI, BestPlan, CM, CM.CostKind, + *CM.PSE.getSE()); precomputeCosts(BestPlan, BestFactor.Width, CostCtx); // Verify that the VPlan-based and legacy cost models agree, except for VPlans // with early exits and plans with additional VPlan simplifications. The @@ -8201,211 +8204,6 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, } } -/// Create and return a ResumePhi for \p WideIV, unless it is truncated. If the -/// induction recipe is not canonical, creates a VPDerivedIVRecipe to compute -/// the end value of the induction. -static VPInstruction *addResumePhiRecipeForInduction( - VPWidenInductionRecipe *WideIV, VPBuilder &VectorPHBuilder, - VPBuilder &ScalarPHBuilder, VPTypeAnalysis &TypeInfo, VPValue *VectorTC) { - auto *WideIntOrFp = dyn_cast<VPWidenIntOrFpInductionRecipe>(WideIV); - // Truncated wide inductions resume from the last lane of their vector value - // in the last vector iteration which is handled elsewhere. - if (WideIntOrFp && WideIntOrFp->getTruncInst()) - return nullptr; - - VPValue *Start = WideIV->getStartValue(); - VPValue *Step = WideIV->getStepValue(); - const InductionDescriptor &ID = WideIV->getInductionDescriptor(); - VPValue *EndValue = VectorTC; - if (!WideIntOrFp || !WideIntOrFp->isCanonical()) { - EndValue = VectorPHBuilder.createDerivedIV( - ID.getKind(), dyn_cast_or_null<FPMathOperator>(ID.getInductionBinOp()), - Start, VectorTC, Step); - } - - // EndValue is derived from the vector trip count (which has the same type as - // the widest induction) and thus may be wider than the induction here. - Type *ScalarTypeOfWideIV = TypeInfo.inferScalarType(WideIV); - if (ScalarTypeOfWideIV != TypeInfo.inferScalarType(EndValue)) { - EndValue = VectorPHBuilder.createScalarCast(Instruction::Trunc, EndValue, - ScalarTypeOfWideIV, - WideIV->getDebugLoc()); - } - - auto *ResumePhiRecipe = ScalarPHBuilder.createScalarPhi( - {EndValue, Start}, WideIV->getDebugLoc(), "bc.resume.val"); - return ResumePhiRecipe; -} - -/// Create resume phis in the scalar preheader for first-order recurrences, -/// reductions and inductions, and update the VPIRInstructions wrapping the -/// original phis in the scalar header. End values for inductions are added to -/// \p IVEndValues. -static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan, - DenseMap<VPValue *, VPValue *> &IVEndValues) { - VPTypeAnalysis TypeInfo(Plan); - auto *ScalarPH = Plan.getScalarPreheader(); - auto *MiddleVPBB = cast<VPBasicBlock>(ScalarPH->getPredecessors()[0]); - VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion(); - VPBuilder VectorPHBuilder( - cast<VPBasicBlock>(VectorRegion->getSinglePredecessor())); - VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi()); - VPBuilder ScalarPHBuilder(ScalarPH); - for (VPRecipeBase &ScalarPhiR : Plan.getScalarHeader()->phis()) { - auto *ScalarPhiIRI = cast<VPIRPhi>(&ScalarPhiR); - - // TODO: Extract final value from induction recipe initially, optimize to - // pre-computed end value together in optimizeInductionExitUsers. - auto *VectorPhiR = - cast<VPHeaderPHIRecipe>(Builder.getRecipe(&ScalarPhiIRI->getIRPhi())); - if (auto *WideIVR = dyn_cast<VPWidenInductionRecipe>(VectorPhiR)) { - if (VPInstruction *ResumePhi = addResumePhiRecipeForInduction( - WideIVR, VectorPHBuilder, ScalarPHBuilder, TypeInfo, - &Plan.getVectorTripCount())) { - assert(isa<VPPhi>(ResumePhi) && "Expected a phi"); - IVEndValues[WideIVR] = ResumePhi->getOperand(0); - ScalarPhiIRI->addOperand(ResumePhi); - continue; - } - // TODO: Also handle truncated inductions here. Computing end-values - // separately should be done as VPlan-to-VPlan optimization, after - // legalizing all resume values to use the last lane from the loop. - assert(cast<VPWidenIntOrFpInductionRecipe>(VectorPhiR)->getTruncInst() && - "should only skip truncated wide inductions"); - continue; - } - - // The backedge value provides the value to resume coming out of a loop, - // which for FORs is a vector whose last element needs to be extracted. The - // start value provides the value if the loop is bypassed. - bool IsFOR = isa<VPFirstOrderRecurrencePHIRecipe>(VectorPhiR); - auto *ResumeFromVectorLoop = VectorPhiR->getBackedgeValue(); - assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() && - "Cannot handle loops with uncountable early exits"); - if (IsFOR) - ResumeFromVectorLoop = MiddleBuilder.createNaryOp( - VPInstruction::ExtractLastElement, {ResumeFromVectorLoop}, {}, - "vector.recur.extract"); - StringRef Name = IsFOR ? "scalar.recur.init" : "bc.merge.rdx"; - auto *ResumePhiR = ScalarPHBuilder.createScalarPhi( - {ResumeFromVectorLoop, VectorPhiR->getStartValue()}, {}, Name); - ScalarPhiIRI->addOperand(ResumePhiR); - } -} - -/// Handle users in the exit block for first order reductions in the original -/// exit block. The penultimate value of recurrences is fed to their LCSSA phi -/// users in the original exit block using the VPIRInstruction wrapping to the -/// LCSSA phi. -static void addExitUsersForFirstOrderRecurrences(VPlan &Plan, VFRange &Range) { - VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion(); - auto *ScalarPHVPBB = Plan.getScalarPreheader(); - auto *MiddleVPBB = Plan.getMiddleBlock(); - VPBuilder ScalarPHBuilder(ScalarPHVPBB); - VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi()); - - auto IsScalableOne = [](ElementCount VF) -> bool { - return VF == ElementCount::getScalable(1); - }; - - for (auto &HeaderPhi : VectorRegion->getEntryBasicBlock()->phis()) { - auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&HeaderPhi); - if (!FOR) - continue; - - assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() && - "Cannot handle loops with uncountable early exits"); - - // This is the second phase of vectorizing first-order recurrences, creating - // extract for users outside the loop. An overview of the transformation is - // described below. Suppose we have the following loop with some use after - // the loop of the last a[i-1], - // - // for (int i = 0; i < n; ++i) { - // t = a[i - 1]; - // b[i] = a[i] - t; - // } - // use t; - // - // There is a first-order recurrence on "a". For this loop, the shorthand - // scalar IR looks like: - // - // scalar.ph: - // s.init = a[-1] - // br scalar.body - // - // scalar.body: - // i = phi [0, scalar.ph], [i+1, scalar.body] - // s1 = phi [s.init, scalar.ph], [s2, scalar.body] - // s2 = a[i] - // b[i] = s2 - s1 - // br cond, scalar.body, exit.block - // - // exit.block: - // use = lcssa.phi [s1, scalar.body] - // - // In this example, s1 is a recurrence because it's value depends on the - // previous iteration. In the first phase of vectorization, we created a - // VPFirstOrderRecurrencePHIRecipe v1 for s1. Now we create the extracts - // for users in the scalar preheader and exit block. - // - // vector.ph: - // v_init = vector(..., ..., ..., a[-1]) - // br vector.body - // - // vector.body - // i = phi [0, vector.ph], [i+4, vector.body] - // v1 = phi [v_init, vector.ph], [v2, vector.body] - // v2 = a[i, i+1, i+2, i+3] - // b[i] = v2 - v1 - // // Next, third phase will introduce v1' = splice(v1(3), v2(0, 1, 2)) - // b[i, i+1, i+2, i+3] = v2 - v1 - // br cond, vector.body, middle.block - // - // middle.block: - // vector.recur.extract.for.phi = v2(2) - // vector.recur.extract = v2(3) - // br cond, scalar.ph, exit.block - // - // scalar.ph: - // scalar.recur.init = phi [vector.recur.extract, middle.block], - // [s.init, otherwise] - // br scalar.body - // - // scalar.body: - // i = phi [0, scalar.ph], [i+1, scalar.body] - // s1 = phi [scalar.recur.init, scalar.ph], [s2, scalar.body] - // s2 = a[i] - // b[i] = s2 - s1 - // br cond, scalar.body, exit.block - // - // exit.block: - // lo = lcssa.phi [s1, scalar.body], - // [vector.recur.extract.for.phi, middle.block] - // - // Now update VPIRInstructions modeling LCSSA phis in the exit block. - // Extract the penultimate value of the recurrence and use it as operand for - // the VPIRInstruction modeling the phi. - for (VPUser *U : FOR->users()) { - using namespace llvm::VPlanPatternMatch; - if (!match(U, m_ExtractLastElement(m_Specific(FOR)))) - continue; - // For VF vscale x 1, if vscale = 1, we are unable to extract the - // penultimate value of the recurrence. Instead we rely on the existing - // extract of the last element from the result of - // VPInstruction::FirstOrderRecurrenceSplice. - // TODO: Consider vscale_range info and UF. - if (LoopVectorizationPlanner::getDecisionAndClampRange(IsScalableOne, - Range)) - return; - VPValue *PenultimateElement = MiddleBuilder.createNaryOp( - VPInstruction::ExtractPenultimateElement, {FOR->getBackedgeValue()}, - {}, "vector.recur.extract.for.phi"); - cast<VPInstruction>(U)->replaceAllUsesWith(PenultimateElement); - } - } -} - VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( VPlanPtr Plan, VFRange &Range, LoopVersioning *LVer) { @@ -8598,9 +8396,11 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( R->setOperand(1, WideIV->getStepValue()); } - addExitUsersForFirstOrderRecurrences(*Plan, Range); + // TODO: We can't call runPass on these transforms yet, due to verifier + // failures. + VPlanTransforms::addExitUsersForFirstOrderRecurrences(*Plan, Range); DenseMap<VPValue *, VPValue *> IVEndValues; - addScalarResumePhis(RecipeBuilder, *Plan, IVEndValues); + VPlanTransforms::addScalarResumePhis(*Plan, RecipeBuilder, IVEndValues); // --------------------------------------------------------------------------- // Transform initial VPlan: Apply previously taken decisions, in order, to @@ -8621,7 +8421,8 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( // TODO: Enable following transform when the EVL-version of extended-reduction // and mulacc-reduction are implemented. if (!CM.foldTailWithEVL()) { - VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind); + VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind, + *CM.PSE.getSE()); VPlanTransforms::runPass(VPlanTransforms::convertToAbstractRecipes, *Plan, CostCtx, Range); } @@ -8711,7 +8512,9 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlan(VFRange &Range) { DenseMap<VPValue *, VPValue *> IVEndValues; // TODO: IVEndValues are not used yet in the native path, to optimize exit // values. - addScalarResumePhis(RecipeBuilder, *Plan, IVEndValues); + // TODO: We can't call runPass on the transform yet, due to verifier + // failures. + VPlanTransforms::addScalarResumePhis(*Plan, RecipeBuilder, IVEndValues); assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid"); return Plan; @@ -10075,7 +9878,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { bool ForceVectorization = Hints.getForce() == LoopVectorizeHints::FK_Enabled; VPCostContext CostCtx(CM.TTI, *CM.TLI, LVP.getPlanFor(VF.Width), CM, - CM.CostKind); + CM.CostKind, *CM.PSE.getSE()); if (!ForceVectorization && !isOutsideLoopWorkProfitable(Checks, VF, L, PSE, CostCtx, LVP.getPlanFor(VF.Width), SEL, diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index fedca65..91c3d42 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -10620,7 +10620,8 @@ class InstructionsCompatibilityAnalysis { /// Checks if the opcode is supported as the main opcode for copyable /// elements. static bool isSupportedOpcode(const unsigned Opcode) { - return Opcode == Instruction::Add || Opcode == Instruction::LShr; + return Opcode == Instruction::Add || Opcode == Instruction::LShr || + Opcode == Instruction::Shl; } /// Identifies the best candidate value, which represents main opcode @@ -10937,6 +10938,7 @@ public: switch (MainOpcode) { case Instruction::Add: case Instruction::LShr: + case Instruction::Shl: VectorCost = TTI.getArithmeticInstrCost(MainOpcode, VecTy, Kind); break; default: @@ -22006,6 +22008,8 @@ bool BoUpSLP::collectValuesToDemote( return all_of(E.Scalars, [&](Value *V) { if (isa<PoisonValue>(V)) return true; + if (E.isCopyableElement(V)) + return true; auto *I = cast<Instruction>(V); KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL); return AmtKnownBits.getMaxValue().ult(BitWidth); diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp index 07b191a..2555ebe 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -1772,7 +1772,8 @@ VPCostContext::getOperandInfo(VPValue *V) const { } InstructionCost VPCostContext::getScalarizationOverhead( - Type *ResultTy, ArrayRef<const VPValue *> Operands, ElementCount VF) { + Type *ResultTy, ArrayRef<const VPValue *> Operands, ElementCount VF, + bool AlwaysIncludeReplicatingR) { if (VF.isScalar()) return 0; @@ -1792,7 +1793,11 @@ InstructionCost VPCostContext::getScalarizationOverhead( SmallPtrSet<const VPValue *, 4> UniqueOperands; SmallVector<Type *> Tys; for (auto *Op : Operands) { - if (Op->isLiveIn() || isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op) || + if (Op->isLiveIn() || + (!AlwaysIncludeReplicatingR && + isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op)) || + (isa<VPReplicateRecipe>(Op) && + cast<VPReplicateRecipe>(Op)->getOpcode() == Instruction::Load) || !UniqueOperands.insert(Op).second) continue; Tys.push_back(toVectorizedTy(Types.inferScalarType(Op), VF)); diff --git a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h index fc1a09e..1580a3b 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h +++ b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h @@ -349,12 +349,14 @@ struct VPCostContext { LoopVectorizationCostModel &CM; SmallPtrSet<Instruction *, 8> SkipCostComputation; TargetTransformInfo::TargetCostKind CostKind; + ScalarEvolution &SE; VPCostContext(const TargetTransformInfo &TTI, const TargetLibraryInfo &TLI, const VPlan &Plan, LoopVectorizationCostModel &CM, - TargetTransformInfo::TargetCostKind CostKind) + TargetTransformInfo::TargetCostKind CostKind, + ScalarEvolution &SE) : TTI(TTI), TLI(TLI), Types(Plan), LLVMCtx(Plan.getContext()), CM(CM), - CostKind(CostKind) {} + CostKind(CostKind), SE(SE) {} /// Return the cost for \p UI with \p VF using the legacy cost model as /// fallback until computing the cost of all recipes migrates to VPlan. @@ -374,10 +376,12 @@ struct VPCostContext { /// Estimate the overhead of scalarizing a recipe with result type \p ResultTy /// and \p Operands with \p VF. This is a convenience wrapper for the - /// type-based getScalarizationOverhead API. - InstructionCost getScalarizationOverhead(Type *ResultTy, - ArrayRef<const VPValue *> Operands, - ElementCount VF); + /// type-based getScalarizationOverhead API. If \p AlwaysIncludeReplicatingR + /// is true, always compute the cost of scalarizing replicating operands. + InstructionCost + getScalarizationOverhead(Type *ResultTy, ArrayRef<const VPValue *> Operands, + ElementCount VF, + bool AlwaysIncludeReplicatingR = false); }; /// This class can be used to assign names to VPValues. For VPValues without diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 67b9244..94e2628 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -40,6 +40,7 @@ #include <cassert> using namespace llvm; +using namespace llvm::VPlanPatternMatch; using VectorParts = SmallVector<Value *, 2>; @@ -303,7 +304,6 @@ VPPartialReductionRecipe::computeCost(ElementCount VF, VPRecipeBase *OpR = Op->getDefiningRecipe(); // If the partial reduction is predicated, a select will be operand 0 - using namespace llvm::VPlanPatternMatch; if (match(getOperand(1), m_Select(m_VPValue(), m_VPValue(Op), m_VPValue()))) { OpR = Op->getDefiningRecipe(); } @@ -1963,7 +1963,6 @@ InstructionCost VPWidenSelectRecipe::computeCost(ElementCount VF, Type *VectorTy = toVectorTy(Ctx.Types.inferScalarType(this), VF); VPValue *Op0, *Op1; - using namespace llvm::VPlanPatternMatch; if (!ScalarCond && ScalarTy->getScalarSizeInBits() == 1 && (match(this, m_LogicalAnd(m_VPValue(Op0), m_VPValue(Op1))) || match(this, m_LogicalOr(m_VPValue(Op0), m_VPValue(Op1))))) { @@ -2778,7 +2777,7 @@ VPExpressionRecipe::VPExpressionRecipe( // Recipes in the expression, except the last one, must only be used by // (other) recipes inside the expression. If there are other users, external // to the expression, use a clone of the recipe for external users. - for (VPSingleDefRecipe *R : ExpressionRecipes) { + for (VPSingleDefRecipe *R : reverse(ExpressionRecipes)) { if (R != ExpressionRecipes.back() && any_of(R->users(), [&ExpressionRecipesAsSetOfUsers](VPUser *U) { return !ExpressionRecipesAsSetOfUsers.contains(U); @@ -3111,6 +3110,62 @@ bool VPReplicateRecipe::shouldPack() const { }); } +/// Returns true if \p Ptr is a pointer computation for which the legacy cost +/// model computes a SCEV expression when computing the address cost. +static bool shouldUseAddressAccessSCEV(const VPValue *Ptr) { + auto *PtrR = Ptr->getDefiningRecipe(); + if (!PtrR || !((isa<VPReplicateRecipe>(PtrR) && + cast<VPReplicateRecipe>(PtrR)->getOpcode() == + Instruction::GetElementPtr) || + isa<VPWidenGEPRecipe>(PtrR) || + match(Ptr, m_GetElementPtr(m_VPValue(), m_VPValue())))) + return false; + + // We are looking for a GEP where all indices are either loop invariant or + // inductions. + for (VPValue *Opd : drop_begin(PtrR->operands())) { + if (!Opd->isDefinedOutsideLoopRegions() && + !isa<VPScalarIVStepsRecipe, VPWidenIntOrFpInductionRecipe>(Opd)) + return false; + } + + return true; +} + +/// Returns true if \p V is used as part of the address of another load or +/// store. +static bool isUsedByLoadStoreAddress(const VPUser *V) { + SmallPtrSet<const VPUser *, 4> Seen; + SmallVector<const VPUser *> WorkList = {V}; + + while (!WorkList.empty()) { + auto *Cur = dyn_cast<VPSingleDefRecipe>(WorkList.pop_back_val()); + if (!Cur || !Seen.insert(Cur).second) + continue; + + for (VPUser *U : Cur->users()) { + if (auto *InterleaveR = dyn_cast<VPInterleaveBase>(U)) + if (InterleaveR->getAddr() == Cur) + return true; + if (auto *RepR = dyn_cast<VPReplicateRecipe>(U)) { + if (RepR->getOpcode() == Instruction::Load && + RepR->getOperand(0) == Cur) + return true; + if (RepR->getOpcode() == Instruction::Store && + RepR->getOperand(1) == Cur) + return true; + } + if (auto *MemR = dyn_cast<VPWidenMemoryRecipe>(U)) { + if (MemR->getAddr() == Cur && MemR->isConsecutive()) + return true; + } + } + + append_range(WorkList, cast<VPSingleDefRecipe>(Cur)->users()); + } + return false; +} + InstructionCost VPReplicateRecipe::computeCost(ElementCount VF, VPCostContext &Ctx) const { Instruction *UI = cast<Instruction>(getUnderlyingValue()); @@ -3218,21 +3273,60 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF, } case Instruction::Load: case Instruction::Store: { - if (isSingleScalar()) { - bool IsLoad = UI->getOpcode() == Instruction::Load; - Type *ValTy = Ctx.Types.inferScalarType(IsLoad ? this : getOperand(0)); - Type *ScalarPtrTy = Ctx.Types.inferScalarType(getOperand(IsLoad ? 0 : 1)); - const Align Alignment = getLoadStoreAlignment(UI); - unsigned AS = getLoadStoreAddressSpace(UI); - TTI::OperandValueInfo OpInfo = TTI::getOperandInfo(UI->getOperand(0)); - InstructionCost ScalarMemOpCost = Ctx.TTI.getMemoryOpCost( - UI->getOpcode(), ValTy, Alignment, AS, Ctx.CostKind, OpInfo, UI); - return ScalarMemOpCost + Ctx.TTI.getAddressComputationCost( - ScalarPtrTy, nullptr, nullptr, Ctx.CostKind); - } + if (VF.isScalable() && !isSingleScalar()) + return InstructionCost::getInvalid(); + // TODO: See getMemInstScalarizationCost for how to handle replicating and // predicated cases. - break; + const VPRegionBlock *ParentRegion = getParent()->getParent(); + if (ParentRegion && ParentRegion->isReplicator()) + break; + + bool IsLoad = UI->getOpcode() == Instruction::Load; + const VPValue *PtrOp = getOperand(!IsLoad); + // TODO: Handle cases where we need to pass a SCEV to + // getAddressComputationCost. + if (shouldUseAddressAccessSCEV(PtrOp)) + break; + + Type *ValTy = Ctx.Types.inferScalarType(IsLoad ? this : getOperand(0)); + Type *ScalarPtrTy = Ctx.Types.inferScalarType(PtrOp); + const Align Alignment = getLoadStoreAlignment(UI); + unsigned AS = getLoadStoreAddressSpace(UI); + TTI::OperandValueInfo OpInfo = TTI::getOperandInfo(UI->getOperand(0)); + InstructionCost ScalarMemOpCost = Ctx.TTI.getMemoryOpCost( + UI->getOpcode(), ValTy, Alignment, AS, Ctx.CostKind, OpInfo); + + Type *PtrTy = isSingleScalar() ? ScalarPtrTy : toVectorTy(ScalarPtrTy, VF); + bool PreferVectorizedAddressing = Ctx.TTI.prefersVectorizedAddressing(); + bool UsedByLoadStoreAddress = + !PreferVectorizedAddressing && isUsedByLoadStoreAddress(this); + InstructionCost ScalarCost = + ScalarMemOpCost + Ctx.TTI.getAddressComputationCost( + PtrTy, UsedByLoadStoreAddress ? nullptr : &Ctx.SE, + nullptr, Ctx.CostKind); + if (isSingleScalar()) + return ScalarCost; + + SmallVector<const VPValue *> OpsToScalarize; + Type *ResultTy = Type::getVoidTy(PtrTy->getContext()); + // Set ResultTy and OpsToScalarize, if scalarization is needed. Currently we + // don't assign scalarization overhead in general, if the target prefers + // vectorized addressing or the loaded value is used as part of an address + // of another load or store. + if (!UsedByLoadStoreAddress) { + bool EfficientVectorLoadStore = + Ctx.TTI.supportsEfficientVectorElementLoadStore(); + if (!(IsLoad && !PreferVectorizedAddressing) && + !(!IsLoad && EfficientVectorLoadStore)) + append_range(OpsToScalarize, operands()); + + if (!EfficientVectorLoadStore) + ResultTy = Ctx.Types.inferScalarType(this); + } + + return (ScalarCost * VF.getFixedValue()) + + Ctx.getScalarizationOverhead(ResultTy, OpsToScalarize, VF, true); } } diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index ca63bf3..ebf833e 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -4198,3 +4198,202 @@ void VPlanTransforms::addBranchWeightToMiddleTerminator( MDB.createBranchWeights({1, VectorStep - 1}, /*IsExpected=*/false); MiddleTerm->addMetadata(LLVMContext::MD_prof, BranchWeights); } + +/// Create and return a ResumePhi for \p WideIV, unless it is truncated. If the +/// induction recipe is not canonical, creates a VPDerivedIVRecipe to compute +/// the end value of the induction. +static VPInstruction *addResumePhiRecipeForInduction( + VPWidenInductionRecipe *WideIV, VPBuilder &VectorPHBuilder, + VPBuilder &ScalarPHBuilder, VPTypeAnalysis &TypeInfo, VPValue *VectorTC) { + auto *WideIntOrFp = dyn_cast<VPWidenIntOrFpInductionRecipe>(WideIV); + // Truncated wide inductions resume from the last lane of their vector value + // in the last vector iteration which is handled elsewhere. + if (WideIntOrFp && WideIntOrFp->getTruncInst()) + return nullptr; + + VPValue *Start = WideIV->getStartValue(); + VPValue *Step = WideIV->getStepValue(); + const InductionDescriptor &ID = WideIV->getInductionDescriptor(); + VPValue *EndValue = VectorTC; + if (!WideIntOrFp || !WideIntOrFp->isCanonical()) { + EndValue = VectorPHBuilder.createDerivedIV( + ID.getKind(), dyn_cast_or_null<FPMathOperator>(ID.getInductionBinOp()), + Start, VectorTC, Step); + } + + // EndValue is derived from the vector trip count (which has the same type as + // the widest induction) and thus may be wider than the induction here. + Type *ScalarTypeOfWideIV = TypeInfo.inferScalarType(WideIV); + if (ScalarTypeOfWideIV != TypeInfo.inferScalarType(EndValue)) { + EndValue = VectorPHBuilder.createScalarCast(Instruction::Trunc, EndValue, + ScalarTypeOfWideIV, + WideIV->getDebugLoc()); + } + + auto *ResumePhiRecipe = ScalarPHBuilder.createScalarPhi( + {EndValue, Start}, WideIV->getDebugLoc(), "bc.resume.val"); + return ResumePhiRecipe; +} + +void VPlanTransforms::addScalarResumePhis( + VPlan &Plan, VPRecipeBuilder &Builder, + DenseMap<VPValue *, VPValue *> &IVEndValues) { + VPTypeAnalysis TypeInfo(Plan); + auto *ScalarPH = Plan.getScalarPreheader(); + auto *MiddleVPBB = cast<VPBasicBlock>(ScalarPH->getPredecessors()[0]); + VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion(); + VPBuilder VectorPHBuilder( + cast<VPBasicBlock>(VectorRegion->getSinglePredecessor())); + VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi()); + VPBuilder ScalarPHBuilder(ScalarPH); + for (VPRecipeBase &ScalarPhiR : Plan.getScalarHeader()->phis()) { + auto *ScalarPhiIRI = cast<VPIRPhi>(&ScalarPhiR); + + // TODO: Extract final value from induction recipe initially, optimize to + // pre-computed end value together in optimizeInductionExitUsers. + auto *VectorPhiR = + cast<VPHeaderPHIRecipe>(Builder.getRecipe(&ScalarPhiIRI->getIRPhi())); + if (auto *WideIVR = dyn_cast<VPWidenInductionRecipe>(VectorPhiR)) { + if (VPInstruction *ResumePhi = addResumePhiRecipeForInduction( + WideIVR, VectorPHBuilder, ScalarPHBuilder, TypeInfo, + &Plan.getVectorTripCount())) { + assert(isa<VPPhi>(ResumePhi) && "Expected a phi"); + IVEndValues[WideIVR] = ResumePhi->getOperand(0); + ScalarPhiIRI->addOperand(ResumePhi); + continue; + } + // TODO: Also handle truncated inductions here. Computing end-values + // separately should be done as VPlan-to-VPlan optimization, after + // legalizing all resume values to use the last lane from the loop. + assert(cast<VPWidenIntOrFpInductionRecipe>(VectorPhiR)->getTruncInst() && + "should only skip truncated wide inductions"); + continue; + } + + // The backedge value provides the value to resume coming out of a loop, + // which for FORs is a vector whose last element needs to be extracted. The + // start value provides the value if the loop is bypassed. + bool IsFOR = isa<VPFirstOrderRecurrencePHIRecipe>(VectorPhiR); + auto *ResumeFromVectorLoop = VectorPhiR->getBackedgeValue(); + assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() && + "Cannot handle loops with uncountable early exits"); + if (IsFOR) + ResumeFromVectorLoop = MiddleBuilder.createNaryOp( + VPInstruction::ExtractLastElement, {ResumeFromVectorLoop}, {}, + "vector.recur.extract"); + StringRef Name = IsFOR ? "scalar.recur.init" : "bc.merge.rdx"; + auto *ResumePhiR = ScalarPHBuilder.createScalarPhi( + {ResumeFromVectorLoop, VectorPhiR->getStartValue()}, {}, Name); + ScalarPhiIRI->addOperand(ResumePhiR); + } +} + +void VPlanTransforms::addExitUsersForFirstOrderRecurrences(VPlan &Plan, + VFRange &Range) { + VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion(); + auto *ScalarPHVPBB = Plan.getScalarPreheader(); + auto *MiddleVPBB = Plan.getMiddleBlock(); + VPBuilder ScalarPHBuilder(ScalarPHVPBB); + VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi()); + + auto IsScalableOne = [](ElementCount VF) -> bool { + return VF == ElementCount::getScalable(1); + }; + + for (auto &HeaderPhi : VectorRegion->getEntryBasicBlock()->phis()) { + auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&HeaderPhi); + if (!FOR) + continue; + + assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() && + "Cannot handle loops with uncountable early exits"); + + // This is the second phase of vectorizing first-order recurrences, creating + // extract for users outside the loop. An overview of the transformation is + // described below. Suppose we have the following loop with some use after + // the loop of the last a[i-1], + // + // for (int i = 0; i < n; ++i) { + // t = a[i - 1]; + // b[i] = a[i] - t; + // } + // use t; + // + // There is a first-order recurrence on "a". For this loop, the shorthand + // scalar IR looks like: + // + // scalar.ph: + // s.init = a[-1] + // br scalar.body + // + // scalar.body: + // i = phi [0, scalar.ph], [i+1, scalar.body] + // s1 = phi [s.init, scalar.ph], [s2, scalar.body] + // s2 = a[i] + // b[i] = s2 - s1 + // br cond, scalar.body, exit.block + // + // exit.block: + // use = lcssa.phi [s1, scalar.body] + // + // In this example, s1 is a recurrence because it's value depends on the + // previous iteration. In the first phase of vectorization, we created a + // VPFirstOrderRecurrencePHIRecipe v1 for s1. Now we create the extracts + // for users in the scalar preheader and exit block. + // + // vector.ph: + // v_init = vector(..., ..., ..., a[-1]) + // br vector.body + // + // vector.body + // i = phi [0, vector.ph], [i+4, vector.body] + // v1 = phi [v_init, vector.ph], [v2, vector.body] + // v2 = a[i, i+1, i+2, i+3] + // b[i] = v2 - v1 + // // Next, third phase will introduce v1' = splice(v1(3), v2(0, 1, 2)) + // b[i, i+1, i+2, i+3] = v2 - v1 + // br cond, vector.body, middle.block + // + // middle.block: + // vector.recur.extract.for.phi = v2(2) + // vector.recur.extract = v2(3) + // br cond, scalar.ph, exit.block + // + // scalar.ph: + // scalar.recur.init = phi [vector.recur.extract, middle.block], + // [s.init, otherwise] + // br scalar.body + // + // scalar.body: + // i = phi [0, scalar.ph], [i+1, scalar.body] + // s1 = phi [scalar.recur.init, scalar.ph], [s2, scalar.body] + // s2 = a[i] + // b[i] = s2 - s1 + // br cond, scalar.body, exit.block + // + // exit.block: + // lo = lcssa.phi [s1, scalar.body], + // [vector.recur.extract.for.phi, middle.block] + // + // Now update VPIRInstructions modeling LCSSA phis in the exit block. + // Extract the penultimate value of the recurrence and use it as operand for + // the VPIRInstruction modeling the phi. + for (VPUser *U : FOR->users()) { + using namespace llvm::VPlanPatternMatch; + if (!match(U, m_ExtractLastElement(m_Specific(FOR)))) + continue; + // For VF vscale x 1, if vscale = 1, we are unable to extract the + // penultimate value of the recurrence. Instead we rely on the existing + // extract of the last element from the result of + // VPInstruction::FirstOrderRecurrenceSplice. + // TODO: Consider vscale_range info and UF. + if (LoopVectorizationPlanner::getDecisionAndClampRange(IsScalableOne, + Range)) + return; + VPValue *PenultimateElement = MiddleBuilder.createNaryOp( + VPInstruction::ExtractPenultimateElement, {FOR->getBackedgeValue()}, + {}, "vector.recur.extract.for.phi"); + cast<VPInstruction>(U)->replaceAllUsesWith(PenultimateElement); + } + } +} diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h index 2f00e51..5a8a2bb 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h @@ -363,6 +363,19 @@ struct VPlanTransforms { static void addBranchWeightToMiddleTerminator(VPlan &Plan, ElementCount VF, std::optional<unsigned> VScaleForTuning); + + /// Create resume phis in the scalar preheader for first-order recurrences, + /// reductions and inductions, and update the VPIRInstructions wrapping the + /// original phis in the scalar header. End values for inductions are added to + /// \p IVEndValues. + static void addScalarResumePhis(VPlan &Plan, VPRecipeBuilder &Builder, + DenseMap<VPValue *, VPValue *> &IVEndValues); + + /// Handle users in the exit block for first order reductions in the original + /// exit block. The penultimate value of recurrences is fed to their LCSSA phi + /// users in the original exit block using the VPIRInstruction wrapping to the + /// LCSSA phi. + static void addExitUsersForFirstOrderRecurrences(VPlan &Plan, VFRange &Range); }; } // namespace llvm |