Diffstat (limited to 'llvm/lib')
64 files changed, 1806 insertions, 1073 deletions
diff --git a/llvm/lib/Analysis/MemoryProfileInfo.cpp b/llvm/lib/Analysis/MemoryProfileInfo.cpp index 0c1f8db..92a5b6f 100644 --- a/llvm/lib/Analysis/MemoryProfileInfo.cpp +++ b/llvm/lib/Analysis/MemoryProfileInfo.cpp @@ -54,6 +54,10 @@ cl::opt<unsigned> MinPercentMaxColdSize( "memprof-min-percent-max-cold-size", cl::init(100), cl::Hidden, cl::desc("Min percent of max cold bytes for critical cold context")); +LLVM_ABI cl::opt<bool> MemProfUseAmbiguousAttributes( + "memprof-ambiguous-attributes", cl::init(true), cl::Hidden, + cl::desc("Apply ambiguous memprof attribute to ambiguous allocations")); + } // end namespace llvm bool llvm::memprof::metadataIncludesAllContextSizeInfo() { @@ -125,6 +129,26 @@ bool llvm::memprof::hasSingleAllocType(uint8_t AllocTypes) { return NumAllocTypes == 1; } +void llvm::memprof::removeAnyExistingAmbiguousAttribute(CallBase *CB) { + if (!CB->hasFnAttr("memprof")) + return; + assert(CB->getFnAttr("memprof").getValueAsString() == "ambiguous"); + CB->removeFnAttr("memprof"); +} + +void llvm::memprof::addAmbiguousAttribute(CallBase *CB) { + if (!MemProfUseAmbiguousAttributes) + return; + // We may have an existing ambiguous attribute if we are reanalyzing + // after inlining. + if (CB->hasFnAttr("memprof")) { + assert(CB->getFnAttr("memprof").getValueAsString() == "ambiguous"); + } else { + auto A = llvm::Attribute::get(CB->getContext(), "memprof", "ambiguous"); + CB->addFnAttr(A); + } +} + void CallStackTrie::addCallStack( AllocationType AllocType, ArrayRef<uint64_t> StackIds, std::vector<ContextTotalSize> ContextSizeInfo) { @@ -470,6 +494,9 @@ void CallStackTrie::addSingleAllocTypeAttribute(CallBase *CI, AllocationType AT, StringRef Descriptor) { auto AllocTypeString = getAllocTypeAttributeString(AT); auto A = llvm::Attribute::get(CI->getContext(), "memprof", AllocTypeString); + // After inlining we may be able to convert an existing ambiguous allocation + // to an unambiguous one. 
+ removeAnyExistingAmbiguousAttribute(CI); CI->addFnAttr(A); if (MemProfReportHintedSizes) { std::vector<ContextTotalSize> ContextSizeInfo; @@ -529,6 +556,7 @@ bool CallStackTrie::buildAndAttachMIBMetadata(CallBase *CI) { assert(MIBCallStack.size() == 1 && "Should only be left with Alloc's location in stack"); CI->setMetadata(LLVMContext::MD_memprof, MDNode::get(Ctx, MIBNodes)); + addAmbiguousAttribute(CI); return true; } // If there exists corner case that CallStackTrie has one chain to leaf diff --git a/llvm/lib/BinaryFormat/DXContainer.cpp b/llvm/lib/BinaryFormat/DXContainer.cpp index c06a3e3..b334f86 100644 --- a/llvm/lib/BinaryFormat/DXContainer.cpp +++ b/llvm/lib/BinaryFormat/DXContainer.cpp @@ -18,6 +18,70 @@ using namespace llvm; using namespace llvm::dxbc; +#define ROOT_PARAMETER(Val, Enum) \ + case Val: \ + return true; +bool llvm::dxbc::isValidParameterType(uint32_t V) { + switch (V) { +#include "llvm/BinaryFormat/DXContainerConstants.def" + } + return false; +} + +bool llvm::dxbc::isValidRangeType(uint32_t V) { + return V <= llvm::to_underlying(dxil::ResourceClass::LastEntry); +} + +#define SHADER_VISIBILITY(Val, Enum) \ + case Val: \ + return true; +bool llvm::dxbc::isValidShaderVisibility(uint32_t V) { + switch (V) { +#include "llvm/BinaryFormat/DXContainerConstants.def" + } + return false; +} + +#define FILTER(Val, Enum) \ + case Val: \ + return true; +bool llvm::dxbc::isValidSamplerFilter(uint32_t V) { + switch (V) { +#include "llvm/BinaryFormat/DXContainerConstants.def" + } + return false; +} + +#define TEXTURE_ADDRESS_MODE(Val, Enum) \ + case Val: \ + return true; +bool llvm::dxbc::isValidAddress(uint32_t V) { + switch (V) { +#include "llvm/BinaryFormat/DXContainerConstants.def" + } + return false; +} + +#define COMPARISON_FUNC(Val, Enum) \ + case Val: \ + return true; +bool llvm::dxbc::isValidComparisonFunc(uint32_t V) { + switch (V) { +#include "llvm/BinaryFormat/DXContainerConstants.def" + } + return false; +} + +#define STATIC_BORDER_COLOR(Val, Enum) \ + case Val: \ + return true; +bool llvm::dxbc::isValidBorderColor(uint32_t V) { + switch (V) { +#include "llvm/BinaryFormat/DXContainerConstants.def" + } + return false; +} + dxbc::PartType dxbc::parsePartType(StringRef S) { #define CONTAINER_PART(PartName) .Case(#PartName, PartType::PartName) return StringSwitch<dxbc::PartType>(S) diff --git a/llvm/lib/CAS/CMakeLists.txt b/llvm/lib/CAS/CMakeLists.txt index 7ae5f7e..bca39b6 100644 --- a/llvm/lib/CAS/CMakeLists.txt +++ b/llvm/lib/CAS/CMakeLists.txt @@ -7,6 +7,7 @@ add_llvm_component_library(LLVMCAS MappedFileRegionArena.cpp ObjectStore.cpp OnDiskCommon.cpp + OnDiskDataAllocator.cpp OnDiskTrieRawHashMap.cpp ADDITIONAL_HEADER_DIRS diff --git a/llvm/lib/CAS/OnDiskDataAllocator.cpp b/llvm/lib/CAS/OnDiskDataAllocator.cpp new file mode 100644 index 0000000..13bbd66 --- /dev/null +++ b/llvm/lib/CAS/OnDiskDataAllocator.cpp @@ -0,0 +1,234 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file Implements OnDiskDataAllocator. 
+/// +//===----------------------------------------------------------------------===// + +#include "llvm/CAS/OnDiskDataAllocator.h" +#include "DatabaseFile.h" +#include "llvm/Config/llvm-config.h" + +using namespace llvm; +using namespace llvm::cas; +using namespace llvm::cas::ondisk; + +#if LLVM_ENABLE_ONDISK_CAS + +//===----------------------------------------------------------------------===// +// DataAllocator data structures. +//===----------------------------------------------------------------------===// + +namespace { +/// DataAllocator table layout: +/// - [8-bytes: Generic table header] +/// - 8-bytes: AllocatorOffset (reserved for implementing free lists) +/// - 8-bytes: Size for user data header +/// - <user data buffer> +/// +/// Record layout: +/// - <data> +class DataAllocatorHandle { +public: + static constexpr TableHandle::TableKind Kind = + TableHandle::TableKind::DataAllocator; + + struct Header { + TableHandle::Header GenericHeader; + std::atomic<int64_t> AllocatorOffset; + const uint64_t UserHeaderSize; + }; + + operator TableHandle() const { + if (!H) + return TableHandle(); + return TableHandle(*Region, H->GenericHeader); + } + + Expected<MutableArrayRef<char>> allocate(MappedFileRegionArena &Alloc, + size_t DataSize) { + assert(&Alloc.getRegion() == Region); + auto Ptr = Alloc.allocate(DataSize); + if (LLVM_UNLIKELY(!Ptr)) + return Ptr.takeError(); + return MutableArrayRef(*Ptr, DataSize); + } + + explicit operator bool() const { return H; } + const Header &getHeader() const { return *H; } + MappedFileRegion &getRegion() const { return *Region; } + + MutableArrayRef<uint8_t> getUserHeader() { + return MutableArrayRef(reinterpret_cast<uint8_t *>(H + 1), + H->UserHeaderSize); + } + + static Expected<DataAllocatorHandle> + create(MappedFileRegionArena &Alloc, StringRef Name, uint32_t UserHeaderSize); + + DataAllocatorHandle() = default; + DataAllocatorHandle(MappedFileRegion &Region, Header &H) + : Region(&Region), H(&H) {} + DataAllocatorHandle(MappedFileRegion &Region, intptr_t HeaderOffset) + : DataAllocatorHandle( + Region, *reinterpret_cast<Header *>(Region.data() + HeaderOffset)) { + } + +private: + MappedFileRegion *Region = nullptr; + Header *H = nullptr; +}; + +} // end anonymous namespace + +struct OnDiskDataAllocator::ImplType { + DatabaseFile File; + DataAllocatorHandle Store; +}; + +Expected<DataAllocatorHandle> +DataAllocatorHandle::create(MappedFileRegionArena &Alloc, StringRef Name, + uint32_t UserHeaderSize) { + // Allocate. + auto Offset = + Alloc.allocateOffset(sizeof(Header) + UserHeaderSize + Name.size() + 1); + if (LLVM_UNLIKELY(!Offset)) + return Offset.takeError(); + + // Construct the header and the name. + assert(Name.size() <= UINT16_MAX && "Expected smaller table name"); + auto *H = new (Alloc.getRegion().data() + *Offset) + Header{{TableHandle::TableKind::DataAllocator, + static_cast<uint16_t>(Name.size()), + static_cast<int32_t>(sizeof(Header) + UserHeaderSize)}, + /*AllocatorOffset=*/{0}, + /*UserHeaderSize=*/UserHeaderSize}; + // Memset UserHeader. + char *UserHeader = reinterpret_cast<char *>(H + 1); + memset(UserHeader, 0, UserHeaderSize); + // Write database file name (null-terminated). 
+ char *NameStorage = UserHeader + UserHeaderSize; + llvm::copy(Name, NameStorage); + NameStorage[Name.size()] = 0; + return DataAllocatorHandle(Alloc.getRegion(), *H); +} + +Expected<OnDiskDataAllocator> OnDiskDataAllocator::create( + const Twine &PathTwine, const Twine &TableNameTwine, uint64_t MaxFileSize, + std::optional<uint64_t> NewFileInitialSize, uint32_t UserHeaderSize, + function_ref<void(void *)> UserHeaderInit) { + assert(!UserHeaderSize || UserHeaderInit); + SmallString<128> PathStorage; + StringRef Path = PathTwine.toStringRef(PathStorage); + SmallString<128> TableNameStorage; + StringRef TableName = TableNameTwine.toStringRef(TableNameStorage); + + // Constructor for if the file doesn't exist. + auto NewDBConstructor = [&](DatabaseFile &DB) -> Error { + auto Store = + DataAllocatorHandle::create(DB.getAlloc(), TableName, UserHeaderSize); + if (LLVM_UNLIKELY(!Store)) + return Store.takeError(); + + if (auto E = DB.addTable(*Store)) + return E; + + if (UserHeaderSize) + UserHeaderInit(Store->getUserHeader().data()); + return Error::success(); + }; + + // Get or create the file. + Expected<DatabaseFile> File = + DatabaseFile::create(Path, MaxFileSize, NewDBConstructor); + if (!File) + return File.takeError(); + + // Find the table and validate it. + std::optional<TableHandle> Table = File->findTable(TableName); + if (!Table) + return createTableConfigError(std::errc::argument_out_of_domain, Path, + TableName, "table not found"); + if (Error E = checkTable("table kind", (size_t)DataAllocatorHandle::Kind, + (size_t)Table->getHeader().Kind, Path, TableName)) + return std::move(E); + auto Store = Table->cast<DataAllocatorHandle>(); + assert(Store && "Already checked the kind"); + + // Success. + OnDiskDataAllocator::ImplType Impl{DatabaseFile(std::move(*File)), Store}; + return OnDiskDataAllocator(std::make_unique<ImplType>(std::move(Impl))); +} + +Expected<OnDiskDataAllocator::OnDiskPtr> +OnDiskDataAllocator::allocate(size_t Size) { + auto Data = Impl->Store.allocate(Impl->File.getAlloc(), Size); + if (LLVM_UNLIKELY(!Data)) + return Data.takeError(); + + return OnDiskPtr(FileOffset(Data->data() - Impl->Store.getRegion().data()), + *Data); +} + +Expected<ArrayRef<char>> OnDiskDataAllocator::get(FileOffset Offset, + size_t Size) const { + assert(Offset); + assert(Impl); + if (Offset.get() + Size >= Impl->File.getAlloc().size()) + return createStringError(make_error_code(std::errc::protocol_error), + "requested size too large in allocator"); + return ArrayRef<char>{Impl->File.getRegion().data() + Offset.get(), Size}; +} + +MutableArrayRef<uint8_t> OnDiskDataAllocator::getUserHeader() { + return Impl->Store.getUserHeader(); +} + +size_t OnDiskDataAllocator::size() const { return Impl->File.size(); } +size_t OnDiskDataAllocator::capacity() const { + return Impl->File.getRegion().size(); +} + +OnDiskDataAllocator::OnDiskDataAllocator(std::unique_ptr<ImplType> Impl) + : Impl(std::move(Impl)) {} + +#else // !LLVM_ENABLE_ONDISK_CAS + +struct OnDiskDataAllocator::ImplType {}; + +Expected<OnDiskDataAllocator> OnDiskDataAllocator::create( + const Twine &Path, const Twine &TableName, uint64_t MaxFileSize, + std::optional<uint64_t> NewFileInitialSize, uint32_t UserHeaderSize, + function_ref<void(void *)> UserHeaderInit) { + return createStringError(make_error_code(std::errc::not_supported), + "OnDiskDataAllocator is not supported"); +} + +Expected<OnDiskDataAllocator::OnDiskPtr> +OnDiskDataAllocator::allocate(size_t Size) { + return createStringError(make_error_code(std::errc::not_supported), + 
"OnDiskDataAllocator is not supported"); +} + +Expected<ArrayRef<char>> OnDiskDataAllocator::get(FileOffset Offset, + size_t Size) const { + return createStringError(make_error_code(std::errc::not_supported), + "OnDiskDataAllocator is not supported"); +} + +MutableArrayRef<uint8_t> OnDiskDataAllocator::getUserHeader() { return {}; } + +size_t OnDiskDataAllocator::size() const { return 0; } +size_t OnDiskDataAllocator::capacity() const { return 0; } + +#endif // LLVM_ENABLE_ONDISK_CAS + +OnDiskDataAllocator::OnDiskDataAllocator(OnDiskDataAllocator &&RHS) = default; +OnDiskDataAllocator & +OnDiskDataAllocator::operator=(OnDiskDataAllocator &&RHS) = default; +OnDiskDataAllocator::~OnDiskDataAllocator() = default; diff --git a/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp b/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp index 9403893..323b21e 100644 --- a/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp +++ b/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp @@ -427,7 +427,7 @@ TrieRawHashMapHandle::createRecord(MappedFileRegionArena &Alloc, return Record; } -Expected<OnDiskTrieRawHashMap::const_pointer> +Expected<OnDiskTrieRawHashMap::ConstOnDiskPtr> OnDiskTrieRawHashMap::recoverFromFileOffset(FileOffset Offset) const { // Check alignment. if (!isAligned(MappedFileRegionArena::getAlign(), Offset.get())) @@ -448,17 +448,17 @@ OnDiskTrieRawHashMap::recoverFromFileOffset(FileOffset Offset) const { // Looks okay... TrieRawHashMapHandle::RecordData D = Impl->Trie.getRecord(SubtrieSlotValue::getDataOffset(Offset)); - return const_pointer(D.Proxy, D.getFileOffset()); + return ConstOnDiskPtr(D.Proxy, D.getFileOffset()); } -OnDiskTrieRawHashMap::const_pointer +OnDiskTrieRawHashMap::ConstOnDiskPtr OnDiskTrieRawHashMap::find(ArrayRef<uint8_t> Hash) const { TrieRawHashMapHandle Trie = Impl->Trie; assert(Hash.size() == Trie.getNumHashBytes() && "Invalid hash"); SubtrieHandle S = Trie.getRoot(); if (!S) - return const_pointer(); + return ConstOnDiskPtr(); TrieHashIndexGenerator IndexGen = Trie.getIndexGen(S, Hash); size_t Index = IndexGen.next(); @@ -466,13 +466,13 @@ OnDiskTrieRawHashMap::find(ArrayRef<uint8_t> Hash) const { // Try to set the content. SubtrieSlotValue V = S.load(Index); if (!V) - return const_pointer(); + return ConstOnDiskPtr(); // Check for an exact match. if (V.isData()) { TrieRawHashMapHandle::RecordData D = Trie.getRecord(V); - return D.Proxy.Hash == Hash ? const_pointer(D.Proxy, D.getFileOffset()) - : const_pointer(); + return D.Proxy.Hash == Hash ? ConstOnDiskPtr(D.Proxy, D.getFileOffset()) + : ConstOnDiskPtr(); } Index = IndexGen.next(); @@ -490,7 +490,7 @@ void SubtrieHandle::reinitialize(uint32_t StartBit, uint32_t NumBits) { H->NumBits = NumBits; } -Expected<OnDiskTrieRawHashMap::pointer> +Expected<OnDiskTrieRawHashMap::OnDiskPtr> OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash, LazyInsertOnConstructCB OnConstruct, LazyInsertOnLeakCB OnLeak) { @@ -523,7 +523,8 @@ OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash, } if (S->compare_exchange_strong(Index, Existing, NewRecord->Offset)) - return pointer(NewRecord->Proxy, NewRecord->Offset.asDataFileOffset()); + return OnDiskPtr(NewRecord->Proxy, + NewRecord->Offset.asDataFileOffset()); // Race means that Existing is no longer empty; fall through... 
} @@ -540,8 +541,8 @@ OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash, if (NewRecord && OnLeak) OnLeak(NewRecord->Offset.asDataFileOffset(), NewRecord->Proxy, ExistingRecord.Offset.asDataFileOffset(), ExistingRecord.Proxy); - return pointer(ExistingRecord.Proxy, - ExistingRecord.Offset.asDataFileOffset()); + return OnDiskPtr(ExistingRecord.Proxy, + ExistingRecord.Offset.asDataFileOffset()); } // Sink the existing content as long as the indexes match. @@ -1135,7 +1136,7 @@ OnDiskTrieRawHashMap::create(const Twine &PathTwine, const Twine &TrieNameTwine, "OnDiskTrieRawHashMap is not supported"); } -Expected<OnDiskTrieRawHashMap::pointer> +Expected<OnDiskTrieRawHashMap::OnDiskPtr> OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash, LazyInsertOnConstructCB OnConstruct, LazyInsertOnLeakCB OnLeak) { @@ -1143,15 +1144,15 @@ OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash, "OnDiskTrieRawHashMap is not supported"); } -Expected<OnDiskTrieRawHashMap::const_pointer> +Expected<OnDiskTrieRawHashMap::ConstOnDiskPtr> OnDiskTrieRawHashMap::recoverFromFileOffset(FileOffset Offset) const { return createStringError(make_error_code(std::errc::not_supported), "OnDiskTrieRawHashMap is not supported"); } -OnDiskTrieRawHashMap::const_pointer +OnDiskTrieRawHashMap::ConstOnDiskPtr OnDiskTrieRawHashMap::find(ArrayRef<uint8_t> Hash) const { - return const_pointer(); + return ConstOnDiskPtr(); } void OnDiskTrieRawHashMap::print( diff --git a/llvm/lib/IR/Globals.cpp b/llvm/lib/IR/Globals.cpp index 1a7a5c5..c3a472b 100644 --- a/llvm/lib/IR/Globals.cpp +++ b/llvm/lib/IR/Globals.cpp @@ -419,6 +419,7 @@ findBaseObject(const Constant *C, DenseSet<const GlobalAlias *> &Aliases, case Instruction::PtrToAddr: case Instruction::PtrToInt: case Instruction::BitCast: + case Instruction::AddrSpaceCast: case Instruction::GetElementPtr: return findBaseObject(CE->getOperand(0), Aliases, Op); default: diff --git a/llvm/lib/Option/ArgList.cpp b/llvm/lib/Option/ArgList.cpp index c4188b3b..2f4e212 100644 --- a/llvm/lib/Option/ArgList.cpp +++ b/llvm/lib/Option/ArgList.cpp @@ -14,12 +14,14 @@ #include "llvm/Config/llvm-config.h" #include "llvm/Option/Arg.h" #include "llvm/Option/OptSpecifier.h" +#include "llvm/Option/OptTable.h" #include "llvm/Option/Option.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include <algorithm> #include <cassert> +#include <cstddef> #include <memory> #include <string> #include <utility> @@ -202,6 +204,42 @@ void ArgList::print(raw_ostream &O) const { LLVM_DUMP_METHOD void ArgList::dump() const { print(dbgs()); } #endif +StringRef ArgList::getSubCommand( + ArrayRef<OptTable::SubCommand> AllSubCommands, + std::function<void(ArrayRef<StringRef>)> HandleMultipleSubcommands, + std::function<void(ArrayRef<StringRef>)> HandleOtherPositionals) const { + + SmallVector<StringRef, 4> SubCommands; + SmallVector<StringRef, 4> OtherPositionals; + for (const Arg *A : *this) { + if (A->getOption().getKind() != Option::InputClass) + continue; + + size_t OldSize = SubCommands.size(); + for (const OptTable::SubCommand &CMD : AllSubCommands) { + if (StringRef(CMD.Name) == A->getValue()) + SubCommands.push_back(A->getValue()); + } + + if (SubCommands.size() == OldSize) + OtherPositionals.push_back(A->getValue()); + } + + // Invoke callbacks if necessary. 
+ if (SubCommands.size() > 1) { + HandleMultipleSubcommands(SubCommands); + return {}; + } + if (!OtherPositionals.empty()) { + HandleOtherPositionals(OtherPositionals); + return {}; + } + + if (SubCommands.size() == 1) + return SubCommands.front(); + return {}; // No valid usage of subcommand found. +} + void InputArgList::releaseMemory() { // An InputArgList always owns its arguments. for (Arg *A : *this) diff --git a/llvm/lib/Option/OptTable.cpp b/llvm/lib/Option/OptTable.cpp index 6d10e61..14e3b0d 100644 --- a/llvm/lib/Option/OptTable.cpp +++ b/llvm/lib/Option/OptTable.cpp @@ -79,9 +79,12 @@ OptSpecifier::OptSpecifier(const Option *Opt) : ID(Opt->getID()) {} OptTable::OptTable(const StringTable &StrTable, ArrayRef<StringTable::Offset> PrefixesTable, - ArrayRef<Info> OptionInfos, bool IgnoreCase) + ArrayRef<Info> OptionInfos, bool IgnoreCase, + ArrayRef<SubCommand> SubCommands, + ArrayRef<unsigned> SubCommandIDsTable) : StrTable(&StrTable), PrefixesTable(PrefixesTable), - OptionInfos(OptionInfos), IgnoreCase(IgnoreCase) { + OptionInfos(OptionInfos), IgnoreCase(IgnoreCase), + SubCommands(SubCommands), SubCommandIDsTable(SubCommandIDsTable) { // Explicitly zero initialize the error to work around a bug in array // value-initialization on MinGW with gcc 4.3.5. @@ -715,9 +718,10 @@ static const char *getOptionHelpGroup(const OptTable &Opts, OptSpecifier Id) { void OptTable::printHelp(raw_ostream &OS, const char *Usage, const char *Title, bool ShowHidden, bool ShowAllAliases, - Visibility VisibilityMask) const { + Visibility VisibilityMask, + StringRef SubCommand) const { return internalPrintHelp( - OS, Usage, Title, ShowHidden, ShowAllAliases, + OS, Usage, Title, SubCommand, ShowHidden, ShowAllAliases, [VisibilityMask](const Info &CandidateInfo) -> bool { return (CandidateInfo.Visibility & VisibilityMask) == 0; }, @@ -730,7 +734,7 @@ void OptTable::printHelp(raw_ostream &OS, const char *Usage, const char *Title, bool ShowHidden = !(FlagsToExclude & HelpHidden); FlagsToExclude &= ~HelpHidden; return internalPrintHelp( - OS, Usage, Title, ShowHidden, ShowAllAliases, + OS, Usage, Title, /*SubCommand=*/{}, ShowHidden, ShowAllAliases, [FlagsToInclude, FlagsToExclude](const Info &CandidateInfo) { if (FlagsToInclude && !(CandidateInfo.Flags & FlagsToInclude)) return true; @@ -742,16 +746,62 @@ void OptTable::printHelp(raw_ostream &OS, const char *Usage, const char *Title, } void OptTable::internalPrintHelp( - raw_ostream &OS, const char *Usage, const char *Title, bool ShowHidden, - bool ShowAllAliases, std::function<bool(const Info &)> ExcludeOption, + raw_ostream &OS, const char *Usage, const char *Title, StringRef SubCommand, + bool ShowHidden, bool ShowAllAliases, + std::function<bool(const Info &)> ExcludeOption, Visibility VisibilityMask) const { OS << "OVERVIEW: " << Title << "\n\n"; - OS << "USAGE: " << Usage << "\n\n"; // Render help text into a map of group-name to a list of (option, help) // pairs. 
std::map<std::string, std::vector<OptionInfo>> GroupedOptionHelp; + auto ActiveSubCommand = + std::find_if(SubCommands.begin(), SubCommands.end(), + [&](const auto &C) { return SubCommand == C.Name; }); + if (!SubCommand.empty()) { + assert(ActiveSubCommand != SubCommands.end() && + "Not a valid registered subcommand."); + OS << ActiveSubCommand->HelpText << "\n\n"; + if (!StringRef(ActiveSubCommand->Usage).empty()) + OS << "USAGE: " << ActiveSubCommand->Usage << "\n\n"; + } else { + OS << "USAGE: " << Usage << "\n\n"; + if (SubCommands.size() > 1) { + OS << "SUBCOMMANDS:\n\n"; + for (const auto &C : SubCommands) + OS << C.Name << " - " << C.HelpText << "\n"; + OS << "\n"; + } + } + + auto DoesOptionBelongToSubcommand = [&](const Info &CandidateInfo) { + // Retrieve the SubCommandIDs registered to the given current CandidateInfo + // Option. + ArrayRef<unsigned> SubCommandIDs = + CandidateInfo.getSubCommandIDs(SubCommandIDsTable); + + // If no registered subcommands, then only global options are to be printed. + // If no valid SubCommand (empty) in commandline then print the current + // global CandidateInfo Option. + if (SubCommandIDs.empty()) + return SubCommand.empty(); + + // Handle CandidateInfo Option which has at least one registered SubCommand. + // If no valid SubCommand (empty) in commandline, this CandidateInfo option + // should not be printed. + if (SubCommand.empty()) + return false; + + // Find the ID of the valid subcommand passed in commandline (its index in + // the SubCommands table which contains all subcommands). + unsigned ActiveSubCommandID = ActiveSubCommand - &SubCommands[0]; + // Print if the ActiveSubCommandID is registered with the CandidateInfo + // Option. + return std::find(SubCommandIDs.begin(), SubCommandIDs.end(), + ActiveSubCommandID) != SubCommandIDs.end(); + }; + for (unsigned Id = 1, e = getNumOptions() + 1; Id != e; ++Id) { // FIXME: Split out option groups. if (getOptionKind(Id) == Option::GroupClass) @@ -764,6 +814,9 @@ void OptTable::internalPrintHelp( if (ExcludeOption(CandidateInfo)) continue; + if (!DoesOptionBelongToSubcommand(CandidateInfo)) + continue; + // If an alias doesn't have a help text, show a help text for the aliased // option instead. const char *HelpText = getOptionHelpText(Id, VisibilityMask); @@ -791,8 +844,11 @@ void OptTable::internalPrintHelp( GenericOptTable::GenericOptTable(const StringTable &StrTable, ArrayRef<StringTable::Offset> PrefixesTable, - ArrayRef<Info> OptionInfos, bool IgnoreCase) - : OptTable(StrTable, PrefixesTable, OptionInfos, IgnoreCase) { + ArrayRef<Info> OptionInfos, bool IgnoreCase, + ArrayRef<SubCommand> SubCommands, + ArrayRef<unsigned> SubCommandIDsTable) + : OptTable(StrTable, PrefixesTable, OptionInfos, IgnoreCase, SubCommands, + SubCommandIDsTable) { std::set<StringRef> TmpPrefixesUnion; for (auto const &Info : OptionInfos.drop_front(FirstSearchableIndex)) diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp index 7069e8d..119caea 100644 --- a/llvm/lib/Passes/PassBuilderPipelines.cpp +++ b/llvm/lib/Passes/PassBuilderPipelines.cpp @@ -1960,6 +1960,7 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level, // is fixed. MPM.addPass(WholeProgramDevirtPass(ExportSummary, nullptr)); + MPM.addPass(NoRecurseLTOInferencePass()); // Stop here at -O1. 
if (Level == OptimizationLevel::O1) { // The LowerTypeTestsPass needs to run to lower type metadata and the diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def index f0e7d36..88550ea 100644 --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -119,6 +119,7 @@ MODULE_PASS("metarenamer", MetaRenamerPass()) MODULE_PASS("module-inline", ModuleInlinerPass()) MODULE_PASS("name-anon-globals", NameAnonGlobalPass()) MODULE_PASS("no-op-module", NoOpModulePass()) +MODULE_PASS("norecurse-lto-inference", NoRecurseLTOInferencePass()) MODULE_PASS("nsan", NumericalStabilitySanitizerPass()) MODULE_PASS("openmp-opt", OpenMPOptPass()) MODULE_PASS("openmp-opt-postlink", diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp index 4357264d..c76689f 100644 --- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp @@ -345,12 +345,6 @@ static unsigned getStackHazardSize(const MachineFunction &MF) { return MF.getSubtarget<AArch64Subtarget>().getStreamingHazardSize(); } -/// Returns true if PPRs are spilled as ZPRs. -static bool arePPRsSpilledAsZPR(const MachineFunction &MF) { - return MF.getSubtarget().getRegisterInfo()->getSpillSize( - AArch64::PPRRegClass) == 16; -} - StackOffset AArch64FrameLowering::getZPRStackSize(const MachineFunction &MF) const { const AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>(); @@ -1966,8 +1960,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters( StrOpc = RPI.isPaired() ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI; break; case RegPairInfo::PPR: - StrOpc = - Size == 16 ? AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO : AArch64::STR_PXI; + StrOpc = AArch64::STR_PXI; break; case RegPairInfo::VG: StrOpc = AArch64::STRXui; @@ -2178,8 +2171,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters( LdrOpc = RPI.isPaired() ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI; break; case RegPairInfo::PPR: - LdrOpc = Size == 16 ? AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO - : AArch64::LDR_PXI; + LdrOpc = AArch64::LDR_PXI; break; case RegPairInfo::VG: continue; @@ -2286,9 +2278,7 @@ static std::optional<int> getLdStFrameID(const MachineInstr &MI, // Returns true if the LDST MachineInstr \p MI is a PPR access. static bool isPPRAccess(const MachineInstr &MI) { - return MI.getOpcode() != AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO && - MI.getOpcode() != AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO && - AArch64::PPRRegClass.contains(MI.getOperand(0).getReg()); + return AArch64::PPRRegClass.contains(MI.getOperand(0).getReg()); } // Check if a Hazard slot is needed for the current function, and if so create @@ -2390,12 +2380,6 @@ void AArch64FrameLowering::determineStackHazardSlot( return; } - if (arePPRsSpilledAsZPR(MF)) { - LLVM_DEBUG(dbgs() << "SplitSVEObjects is not supported with " - "-aarch64-enable-zpr-predicate-spills"); - return; - } - // If another calling convention is explicitly set FPRs can't be promoted to // ZPR callee-saves. if (!is_contained({CallingConv::C, CallingConv::Fast, @@ -2519,14 +2503,6 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF, continue; } - // Always save P4 when PPR spills are ZPR-sized and a predicate above p8 is - // spilled. If all of p0-p3 are used as return values p4 is must be free - // to reload p8-p15. 
- if (RegInfo->getSpillSize(AArch64::PPRRegClass) == 16 && - AArch64::PPR_p8to15RegClass.contains(Reg)) { - SavedRegs.set(AArch64::P4); - } - // MachO's compact unwind format relies on all registers being stored in // pairs. // FIXME: the usual format is actually better if unwinding isn't needed. @@ -2587,7 +2563,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF, auto SpillSize = TRI->getSpillSize(*RC); bool IsZPR = AArch64::ZPRRegClass.contains(Reg); bool IsPPR = !IsZPR && AArch64::PPRRegClass.contains(Reg); - if (IsZPR || (IsPPR && arePPRsSpilledAsZPR(MF))) + if (IsZPR) ZPRCSStackSize += SpillSize; else if (IsPPR) PPRCSStackSize += SpillSize; @@ -2902,7 +2878,7 @@ static SVEStackSizes determineSVEStackSizes(MachineFunction &MF, StackTop += MFI.getObjectSize(FI); StackTop = alignTo(StackTop, Alignment); - assert(StackTop < std::numeric_limits<int64_t>::max() && + assert(StackTop < (uint64_t)std::numeric_limits<int64_t>::max() && "SVE StackTop far too large?!"); int64_t Offset = -int64_t(StackTop); @@ -2961,314 +2937,8 @@ static SVEStackSizes determineSVEStackSizes(MachineFunction &MF, return SVEStack; } -/// Attempts to scavenge a register from \p ScavengeableRegs given the used -/// registers in \p UsedRegs. -static Register tryScavengeRegister(LiveRegUnits const &UsedRegs, - BitVector const &ScavengeableRegs, - Register PreferredReg) { - if (PreferredReg != AArch64::NoRegister && UsedRegs.available(PreferredReg)) - return PreferredReg; - for (auto Reg : ScavengeableRegs.set_bits()) { - if (UsedRegs.available(Reg)) - return Reg; - } - return AArch64::NoRegister; -} - -/// Propagates frame-setup/destroy flags from \p SourceMI to all instructions in -/// \p MachineInstrs. -static void propagateFrameFlags(MachineInstr &SourceMI, - ArrayRef<MachineInstr *> MachineInstrs) { - for (MachineInstr *MI : MachineInstrs) { - if (SourceMI.getFlag(MachineInstr::FrameSetup)) - MI->setFlag(MachineInstr::FrameSetup); - if (SourceMI.getFlag(MachineInstr::FrameDestroy)) - MI->setFlag(MachineInstr::FrameDestroy); - } -} - -/// RAII helper class for scavenging or spilling a register. On construction -/// attempts to find a free register of class \p RC (given \p UsedRegs and \p -/// AllocatableRegs), if no register can be found spills \p SpillCandidate to \p -/// MaybeSpillFI to free a register. The free'd register is returned via the \p -/// FreeReg output parameter. On destruction, if there is a spill, its previous -/// value is reloaded. The spilling and scavenging is only valid at the -/// insertion point \p MBBI, this class should _not_ be used in places that -/// create or manipulate basic blocks, moving the expected insertion point. 
-struct ScopedScavengeOrSpill { - ScopedScavengeOrSpill(const ScopedScavengeOrSpill &) = delete; - ScopedScavengeOrSpill(ScopedScavengeOrSpill &&) = delete; - - ScopedScavengeOrSpill(MachineFunction &MF, MachineBasicBlock &MBB, - MachineBasicBlock::iterator MBBI, - Register SpillCandidate, const TargetRegisterClass &RC, - LiveRegUnits const &UsedRegs, - BitVector const &AllocatableRegs, - std::optional<int> *MaybeSpillFI, - Register PreferredReg = AArch64::NoRegister) - : MBB(MBB), MBBI(MBBI), RC(RC), TII(static_cast<const AArch64InstrInfo &>( - *MF.getSubtarget().getInstrInfo())), - TRI(*MF.getSubtarget().getRegisterInfo()) { - FreeReg = tryScavengeRegister(UsedRegs, AllocatableRegs, PreferredReg); - if (FreeReg != AArch64::NoRegister) - return; - assert(MaybeSpillFI && "Expected emergency spill slot FI information " - "(attempted to spill in prologue/epilogue?)"); - if (!MaybeSpillFI->has_value()) { - MachineFrameInfo &MFI = MF.getFrameInfo(); - *MaybeSpillFI = MFI.CreateSpillStackObject(TRI.getSpillSize(RC), - TRI.getSpillAlign(RC)); - } - FreeReg = SpillCandidate; - SpillFI = MaybeSpillFI->value(); - TII.storeRegToStackSlot(MBB, MBBI, FreeReg, false, *SpillFI, &RC, &TRI, - Register()); - } - - bool hasSpilled() const { return SpillFI.has_value(); } - - /// Returns the free register (found from scavenging or spilling a register). - Register freeRegister() const { return FreeReg; } - - Register operator*() const { return freeRegister(); } - - ~ScopedScavengeOrSpill() { - if (hasSpilled()) - TII.loadRegFromStackSlot(MBB, MBBI, FreeReg, *SpillFI, &RC, &TRI, - Register()); - } - -private: - MachineBasicBlock &MBB; - MachineBasicBlock::iterator MBBI; - const TargetRegisterClass &RC; - const AArch64InstrInfo &TII; - const TargetRegisterInfo &TRI; - Register FreeReg = AArch64::NoRegister; - std::optional<int> SpillFI; -}; - -/// Emergency stack slots for expanding SPILL_PPR_TO_ZPR_SLOT_PSEUDO and -/// FILL_PPR_FROM_ZPR_SLOT_PSEUDO. -struct EmergencyStackSlots { - std::optional<int> ZPRSpillFI; - std::optional<int> PPRSpillFI; - std::optional<int> GPRSpillFI; -}; - -/// Registers available for scavenging (ZPR, PPR3b, GPR). -struct ScavengeableRegs { - BitVector ZPRRegs; - BitVector PPR3bRegs; - BitVector GPRRegs; -}; - -static bool isInPrologueOrEpilogue(const MachineInstr &MI) { - return MI.getFlag(MachineInstr::FrameSetup) || - MI.getFlag(MachineInstr::FrameDestroy); -} - -/// Expands: -/// ``` -/// SPILL_PPR_TO_ZPR_SLOT_PSEUDO $p0, %stack.0, 0 -/// ``` -/// To: -/// ``` -/// $z0 = CPY_ZPzI_B $p0, 1, 0 -/// STR_ZXI $z0, $stack.0, 0 -/// ``` -/// While ensuring a ZPR ($z0 in this example) is free for the predicate ( -/// spilling if necessary). -static void expandSpillPPRToZPRSlotPseudo(MachineBasicBlock &MBB, - MachineInstr &MI, - const TargetRegisterInfo &TRI, - LiveRegUnits const &UsedRegs, - ScavengeableRegs const &SR, - EmergencyStackSlots &SpillSlots) { - MachineFunction &MF = *MBB.getParent(); - auto *TII = - static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo()); - - ScopedScavengeOrSpill ZPredReg( - MF, MBB, MI, AArch64::Z0, AArch64::ZPRRegClass, UsedRegs, SR.ZPRRegs, - isInPrologueOrEpilogue(MI) ? 
nullptr : &SpillSlots.ZPRSpillFI); - - SmallVector<MachineInstr *, 2> MachineInstrs; - const DebugLoc &DL = MI.getDebugLoc(); - MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::CPY_ZPzI_B)) - .addReg(*ZPredReg, RegState::Define) - .add(MI.getOperand(0)) - .addImm(1) - .addImm(0) - .getInstr()); - MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::STR_ZXI)) - .addReg(*ZPredReg) - .add(MI.getOperand(1)) - .addImm(MI.getOperand(2).getImm()) - .setMemRefs(MI.memoperands()) - .getInstr()); - propagateFrameFlags(MI, MachineInstrs); -} - -/// Expands: -/// ``` -/// $p0 = FILL_PPR_FROM_ZPR_SLOT_PSEUDO %stack.0, 0 -/// ``` -/// To: -/// ``` -/// $z0 = LDR_ZXI %stack.0, 0 -/// $p0 = PTRUE_B 31, implicit $vg -/// $p0 = CMPNE_PPzZI_B $p0, $z0, 0, implicit-def $nzcv, implicit-def $nzcv -/// ``` -/// While ensuring a ZPR ($z0 in this example) is free for the predicate ( -/// spilling if necessary). If the status flags are in use at the point of -/// expansion they are preserved (by moving them to/from a GPR). This may cause -/// an additional spill if no GPR is free at the expansion point. -static bool expandFillPPRFromZPRSlotPseudo( - MachineBasicBlock &MBB, MachineInstr &MI, const TargetRegisterInfo &TRI, - LiveRegUnits const &UsedRegs, ScavengeableRegs const &SR, - MachineInstr *&LastPTrue, EmergencyStackSlots &SpillSlots) { - MachineFunction &MF = *MBB.getParent(); - auto *TII = - static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo()); - - ScopedScavengeOrSpill ZPredReg( - MF, MBB, MI, AArch64::Z0, AArch64::ZPRRegClass, UsedRegs, SR.ZPRRegs, - isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.ZPRSpillFI); - - ScopedScavengeOrSpill PredReg( - MF, MBB, MI, AArch64::P0, AArch64::PPR_3bRegClass, UsedRegs, SR.PPR3bRegs, - isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.PPRSpillFI, - /*PreferredReg=*/ - LastPTrue ? LastPTrue->getOperand(0).getReg() : AArch64::NoRegister); - - // Elide NZCV spills if we know it is not used. - bool IsNZCVUsed = !UsedRegs.available(AArch64::NZCV); - std::optional<ScopedScavengeOrSpill> NZCVSaveReg; - if (IsNZCVUsed) - NZCVSaveReg.emplace( - MF, MBB, MI, AArch64::X0, AArch64::GPR64RegClass, UsedRegs, SR.GPRRegs, - isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.GPRSpillFI); - SmallVector<MachineInstr *, 4> MachineInstrs; - const DebugLoc &DL = MI.getDebugLoc(); - MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::LDR_ZXI)) - .addReg(*ZPredReg, RegState::Define) - .add(MI.getOperand(1)) - .addImm(MI.getOperand(2).getImm()) - .setMemRefs(MI.memoperands()) - .getInstr()); - if (IsNZCVUsed) - MachineInstrs.push_back( - BuildMI(MBB, MI, DL, TII->get(AArch64::MRS)) - .addReg(NZCVSaveReg->freeRegister(), RegState::Define) - .addImm(AArch64SysReg::NZCV) - .addReg(AArch64::NZCV, RegState::Implicit) - .getInstr()); - - // Reuse previous ptrue if we know it has not been clobbered. 
- if (LastPTrue) { - assert(*PredReg == LastPTrue->getOperand(0).getReg()); - LastPTrue->moveBefore(&MI); - } else { - LastPTrue = BuildMI(MBB, MI, DL, TII->get(AArch64::PTRUE_B)) - .addReg(*PredReg, RegState::Define) - .addImm(31); - } - MachineInstrs.push_back(LastPTrue); - MachineInstrs.push_back( - BuildMI(MBB, MI, DL, TII->get(AArch64::CMPNE_PPzZI_B)) - .addReg(MI.getOperand(0).getReg(), RegState::Define) - .addReg(*PredReg) - .addReg(*ZPredReg) - .addImm(0) - .addReg(AArch64::NZCV, RegState::ImplicitDefine) - .getInstr()); - if (IsNZCVUsed) - MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::MSR)) - .addImm(AArch64SysReg::NZCV) - .addReg(NZCVSaveReg->freeRegister()) - .addReg(AArch64::NZCV, RegState::ImplicitDefine) - .getInstr()); - - propagateFrameFlags(MI, MachineInstrs); - return PredReg.hasSpilled(); -} - -/// Expands all FILL_PPR_FROM_ZPR_SLOT_PSEUDO and SPILL_PPR_TO_ZPR_SLOT_PSEUDO -/// operations within the MachineBasicBlock \p MBB. -static bool expandSMEPPRToZPRSpillPseudos(MachineBasicBlock &MBB, - const TargetRegisterInfo &TRI, - ScavengeableRegs const &SR, - EmergencyStackSlots &SpillSlots) { - LiveRegUnits UsedRegs(TRI); - UsedRegs.addLiveOuts(MBB); - bool HasPPRSpills = false; - MachineInstr *LastPTrue = nullptr; - for (MachineInstr &MI : make_early_inc_range(reverse(MBB))) { - UsedRegs.stepBackward(MI); - switch (MI.getOpcode()) { - case AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO: - if (LastPTrue && - MI.definesRegister(LastPTrue->getOperand(0).getReg(), &TRI)) - LastPTrue = nullptr; - HasPPRSpills |= expandFillPPRFromZPRSlotPseudo(MBB, MI, TRI, UsedRegs, SR, - LastPTrue, SpillSlots); - MI.eraseFromParent(); - break; - case AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO: - expandSpillPPRToZPRSlotPseudo(MBB, MI, TRI, UsedRegs, SR, SpillSlots); - MI.eraseFromParent(); - [[fallthrough]]; - default: - LastPTrue = nullptr; - break; - } - } - - return HasPPRSpills; -} - void AArch64FrameLowering::processFunctionBeforeFrameFinalized( MachineFunction &MF, RegScavenger *RS) const { - - AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>(); - const TargetSubtargetInfo &TSI = MF.getSubtarget(); - const TargetRegisterInfo &TRI = *TSI.getRegisterInfo(); - - // If predicates spills are 16-bytes we may need to expand - // SPILL_PPR_TO_ZPR_SLOT_PSEUDO/FILL_PPR_FROM_ZPR_SLOT_PSEUDO. - if (AFI->hasStackFrame() && TRI.getSpillSize(AArch64::PPRRegClass) == 16) { - auto ComputeScavengeableRegisters = [&](unsigned RegClassID) { - BitVector Regs = TRI.getAllocatableSet(MF, TRI.getRegClass(RegClassID)); - assert(Regs.count() > 0 && "Expected scavengeable registers"); - return Regs; - }; - - ScavengeableRegs SR{}; - SR.ZPRRegs = ComputeScavengeableRegisters(AArch64::ZPRRegClassID); - // Only p0-7 are possible as the second operand of cmpne (needed for fills). - SR.PPR3bRegs = ComputeScavengeableRegisters(AArch64::PPR_3bRegClassID); - SR.GPRRegs = ComputeScavengeableRegisters(AArch64::GPR64RegClassID); - - EmergencyStackSlots SpillSlots; - for (MachineBasicBlock &MBB : MF) { - // In the case we had to spill a predicate (in the range p0-p7) to reload - // a predicate (>= p8), additional spill/fill pseudos will be created. - // These need an additional expansion pass. Note: There will only be at - // most two expansion passes, as spilling/filling a predicate in the range - // p0-p7 never requires spilling another predicate. 
- for (int Pass = 0; Pass < 2; Pass++) { - bool HasPPRSpills = - expandSMEPPRToZPRSpillPseudos(MBB, TRI, SR, SpillSlots); - assert((Pass == 0 || !HasPPRSpills) && "Did not expect PPR spills"); - if (!HasPPRSpills) - break; - } - } - } - - MachineFrameInfo &MFI = MF.getFrameInfo(); - assert(getStackGrowthDirection() == TargetFrameLowering::StackGrowsDown && "Upwards growing stack unsupported"); @@ -3279,6 +2949,9 @@ void AArch64FrameLowering::processFunctionBeforeFrameFinalized( if (!MF.hasEHFunclets()) return; + MachineFrameInfo &MFI = MF.getFrameInfo(); + auto *AFI = MF.getInfo<AArch64FunctionInfo>(); + // Win64 C++ EH needs to allocate space for the catch objects in the fixed // object area right next to the UnwindHelp object. WinEHFuncInfo &EHInfo = *MF.getWinEHFuncInfo(); @@ -4280,18 +3953,10 @@ void AArch64FrameLowering::emitRemarks( } unsigned RegTy = StackAccess::AccessType::GPR; - if (MFI.hasScalableStackID(FrameIdx)) { - // SPILL_PPR_TO_ZPR_SLOT_PSEUDO and FILL_PPR_FROM_ZPR_SLOT_PSEUDO - // spill/fill the predicate as a data vector (so are an FPR access). - if (MI.getOpcode() != AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO && - MI.getOpcode() != AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO && - AArch64::PPRRegClass.contains(MI.getOperand(0).getReg())) { - RegTy = StackAccess::PPR; - } else - RegTy = StackAccess::FPR; - } else if (AArch64InstrInfo::isFpOrNEON(MI)) { + if (MFI.hasScalableStackID(FrameIdx)) + RegTy = isPPRAccess(MI) ? StackAccess::PPR : StackAccess::FPR; + else if (AArch64InstrInfo::isFpOrNEON(MI)) RegTy = StackAccess::FPR; - } StackAccesses[ArrIdx].AccessTypes |= RegTy; diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp index 5a90da1..b8761d97 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -2579,8 +2579,6 @@ unsigned AArch64InstrInfo::getLoadStoreImmIdx(unsigned Opc) { case AArch64::STZ2Gi: case AArch64::STZGi: case AArch64::TAGPstack: - case AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO: - case AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO: return 2; case AArch64::LD1B_D_IMM: case AArch64::LD1B_H_IMM: @@ -4387,8 +4385,6 @@ bool AArch64InstrInfo::getMemOpInfo(unsigned Opcode, TypeSize &Scale, MinOffset = -256; MaxOffset = 254; break; - case AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO: - case AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO: case AArch64::LDR_ZXI: case AArch64::STR_ZXI: Scale = TypeSize::getScalable(16); @@ -5098,33 +5094,31 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, BuildMI(MBB, I, DL, get(AArch64::MOVZWi), DestReg) .addImm(0) .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); + } else if (Subtarget.hasZeroCycleRegMoveGPR64() && + !Subtarget.hasZeroCycleRegMoveGPR32()) { + // Cyclone recognizes "ORR Xd, XZR, Xm" as a zero-cycle register move. + MCRegister DestRegX = TRI->getMatchingSuperReg(DestReg, AArch64::sub_32, + &AArch64::GPR64spRegClass); + assert(DestRegX.isValid() && "Destination super-reg not valid"); + MCRegister SrcRegX = + SrcReg == AArch64::WZR + ? AArch64::XZR + : TRI->getMatchingSuperReg(SrcReg, AArch64::sub_32, + &AArch64::GPR64spRegClass); + assert(SrcRegX.isValid() && "Source super-reg not valid"); + // This instruction is reading and writing X registers. This may upset + // the register scavenger and machine verifier, so we need to indicate + // that we are reading an undefined value from SrcRegX, but a proper + // value from SrcReg. 
+ BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestRegX) + .addReg(AArch64::XZR) + .addReg(SrcRegX, RegState::Undef) + .addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc)); } else { - if (Subtarget.hasZeroCycleRegMoveGPR64() && - !Subtarget.hasZeroCycleRegMoveGPR32()) { - // Cyclone recognizes "ORR Xd, XZR, Xm" as a zero-cycle register move. - MCRegister DestRegX = TRI->getMatchingSuperReg( - DestReg, AArch64::sub_32, &AArch64::GPR64spRegClass); - assert(DestRegX.isValid() && "Destination super-reg not valid"); - MCRegister SrcRegX = - SrcReg == AArch64::WZR - ? AArch64::XZR - : TRI->getMatchingSuperReg(SrcReg, AArch64::sub_32, - &AArch64::GPR64spRegClass); - assert(SrcRegX.isValid() && "Source super-reg not valid"); - // This instruction is reading and writing X registers. This may upset - // the register scavenger and machine verifier, so we need to indicate - // that we are reading an undefined value from SrcRegX, but a proper - // value from SrcReg. - BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestRegX) - .addReg(AArch64::XZR) - .addReg(SrcRegX, RegState::Undef) - .addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc)); - } else { - // Otherwise, expand to ORR WZR. - BuildMI(MBB, I, DL, get(AArch64::ORRWrr), DestReg) - .addReg(AArch64::WZR) - .addReg(SrcReg, getKillRegState(KillSrc)); - } + // Otherwise, expand to ORR WZR. + BuildMI(MBB, I, DL, get(AArch64::ORRWrr), DestReg) + .addReg(AArch64::WZR) + .addReg(SrcReg, getKillRegState(KillSrc)); } return; } @@ -5650,11 +5644,6 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB, "Unexpected register store without SVE store instructions"); Opc = AArch64::STR_ZXI; StackID = TargetStackID::ScalableVector; - } else if (AArch64::PPRRegClass.hasSubClassEq(RC)) { - assert(Subtarget.isSVEorStreamingSVEAvailable() && - "Unexpected predicate store without SVE store instructions"); - Opc = AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO; - StackID = TargetStackID::ScalableVector; } break; case 24: @@ -5835,11 +5824,6 @@ void AArch64InstrInfo::loadRegFromStackSlot( "Unexpected register load without SVE load instructions"); Opc = AArch64::LDR_ZXI; StackID = TargetStackID::ScalableVector; - } else if (AArch64::PPRRegClass.hasSubClassEq(RC)) { - assert(Subtarget.isSVEorStreamingSVEAvailable() && - "Unexpected predicate load without SVE load instructions"); - Opc = AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO; - StackID = TargetStackID::ScalableVector; } break; case 24: diff --git a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp index aed137c..1568161 100644 --- a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp +++ b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp @@ -57,10 +57,7 @@ static bool isPartOfZPRCalleeSaves(MachineBasicBlock::iterator I) { case AArch64::ST1B_2Z_IMM: case AArch64::STR_ZXI: case AArch64::LDR_ZXI: - case AArch64::CPY_ZPzI_B: - case AArch64::CMPNE_PPzZI_B: case AArch64::PTRUE_C_B: - case AArch64::PTRUE_B: return I->getFlag(MachineInstr::FrameSetup) || I->getFlag(MachineInstr::FrameDestroy); case AArch64::SEH_SavePReg: diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td index 5d89862..ef974df 100644 --- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td +++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td @@ -980,19 +980,10 @@ class ZPRRegOp <string Suffix, AsmOperandClass C, ElementSizeEnum Size, //****************************************************************************** // SVE predicate 
register classes. - -// Note: This hardware mode is enabled in AArch64Subtarget::getHwModeSet() -// (without the use of the table-gen'd predicates). -def SMEWithZPRPredicateSpills : HwMode<[Predicate<"false">]>; - -def PPRSpillFillRI : RegInfoByHwMode< - [DefaultMode, SMEWithZPRPredicateSpills], - [RegInfo<16,16,16>, RegInfo<16,128,128>]>; - class PPRClass<int firstreg, int lastreg, int step = 1> : RegisterClass<"AArch64", [ nxv16i1, nxv8i1, nxv4i1, nxv2i1, nxv1i1 ], 16, (sequence "P%u", firstreg, lastreg, step)> { - let RegInfos = PPRSpillFillRI; + let Size = 16; } def PPR : PPRClass<0, 15> { diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp index 98e0a11..12ddf47 100644 --- a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp +++ b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp @@ -86,11 +86,6 @@ static cl::alias AArch64StreamingStackHazardSize( cl::desc("alias for -aarch64-streaming-hazard-size"), cl::aliasopt(AArch64StreamingHazardSize)); -static cl::opt<bool> EnableZPRPredicateSpills( - "aarch64-enable-zpr-predicate-spills", cl::init(false), cl::Hidden, - cl::desc( - "Enables spilling/reloading SVE predicates as data vectors (ZPRs)")); - static cl::opt<unsigned> VScaleForTuningOpt("sve-vscale-for-tuning", cl::Hidden, cl::desc("Force a vscale for tuning factor for SVE")); @@ -426,20 +421,6 @@ AArch64Subtarget::AArch64Subtarget(const Triple &TT, StringRef CPU, EnableSubregLiveness = EnableSubregLivenessTracking.getValue(); } -unsigned AArch64Subtarget::getHwModeSet() const { - AArch64HwModeBits Modes = AArch64HwModeBits::DefaultMode; - - // Use a special hardware mode in streaming[-compatible] functions with - // aarch64-enable-zpr-predicate-spills. This changes the spill size (and - // alignment) for the predicate register class. - if (EnableZPRPredicateSpills.getValue() && - (isStreaming() || isStreamingCompatible())) { - Modes |= AArch64HwModeBits::SMEWithZPRPredicateSpills; - } - - return to_underlying(Modes); -} - const CallLowering *AArch64Subtarget::getCallLowering() const { return CallLoweringInfo.get(); } diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h index 671df35..8974965 100644 --- a/llvm/lib/Target/AArch64/AArch64Subtarget.h +++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h @@ -130,8 +130,6 @@ public: bool IsStreaming = false, bool IsStreamingCompatible = false, bool HasMinSize = false); - virtual unsigned getHwModeSet() const override; - // Getters for SubtargetFeatures defined in tablegen #define GET_SUBTARGETINFO_MACRO(ATTRIBUTE, DEFAULT, GETTER) \ bool GETTER() const { return ATTRIBUTE; } diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 50a8754..479e345 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -5666,18 +5666,21 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost( VectorType *AccumVectorType = VectorType::get(AccumType, VF.divideCoefficientBy(Ratio)); // We don't yet support all kinds of legalization. - auto TA = TLI->getTypeAction(AccumVectorType->getContext(), - EVT::getEVT(AccumVectorType)); - switch (TA) { + auto TC = TLI->getTypeConversion(AccumVectorType->getContext(), + EVT::getEVT(AccumVectorType)); + switch (TC.first) { default: return Invalid; case TargetLowering::TypeLegal: case TargetLowering::TypePromoteInteger: case TargetLowering::TypeSplitVector: + // The legalised type (e.g. 
after splitting) must be legal too. + if (TLI->getTypeAction(AccumVectorType->getContext(), TC.second) != + TargetLowering::TypeLegal) + return Invalid; break; } - // Check what kind of type-legalisation happens. std::pair<InstructionCost, MVT> AccumLT = getTypeLegalizationCost(AccumVectorType); std::pair<InstructionCost, MVT> InputLT = diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td index be44b8f..33f35ad 100644 --- a/llvm/lib/Target/AArch64/SMEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td @@ -58,20 +58,6 @@ def FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO : let hasSideEffects = 0; } -def SPILL_PPR_TO_ZPR_SLOT_PSEUDO : - Pseudo<(outs), (ins PPRorPNRAny:$Pt, GPR64sp:$Rn, simm9:$imm9), []>, Sched<[]> -{ - let mayStore = 1; - let hasSideEffects = 0; -} - -def FILL_PPR_FROM_ZPR_SLOT_PSEUDO : - Pseudo<(outs PPRorPNRAny:$Pt), (ins GPR64sp:$Rn, simm9:$imm9), []>, Sched<[]> -{ - let mayLoad = 1; - let hasSideEffects = 0; -} - def SDTZALoadStore : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>]>; // SME ZA loads and stores def AArch64SMELdr : SDNode<"AArch64ISD::SME_ZA_LDR", SDTZALoadStore, diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td index 9446144..1a697f7 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPU.td +++ b/llvm/lib/Target/AMDGPU/AMDGPU.td @@ -1411,20 +1411,6 @@ def FeatureGloballyAddressableScratch : SubtargetFeature< "FLAT instructions can access scratch memory for any thread in any wave" >; -// FIXME: Remove after all users are migrated to attribute. -def FeatureDynamicVGPR : SubtargetFeature <"dynamic-vgpr", - "DynamicVGPR", - "true", - "Enable dynamic VGPR mode" ->; - -// FIXME: Remove after all users are migrated to attribute. -def FeatureDynamicVGPRBlockSize32 : SubtargetFeature<"dynamic-vgpr-block-size-32", - "DynamicVGPRBlockSize32", - "true", - "Use a block size of 32 for dynamic VGPR allocation (default is 16)" ->; - // Enable the use of SCRATCH_STORE/LOAD_BLOCK instructions for saving and // restoring the callee-saved registers. def FeatureUseBlockVGPROpsForCSR : SubtargetFeature<"block-vgpr-csr", @@ -1462,6 +1448,12 @@ def Feature45BitNumRecordsBufferResource : SubtargetFeature< "45-bit-num-records "The buffer resource (V#) supports 45-bit num_records" >; +def FeatureClusters : SubtargetFeature< "clusters", + "HasClusters", + "true", + "Has clusters of workgroups support" +>; + // Dummy feature used to disable assembler instructions. 
def FeatureDisable : SubtargetFeature<"", "FeatureDisable","true", @@ -2128,6 +2120,7 @@ def FeatureISAVersion12_50 : FeatureSet< Feature45BitNumRecordsBufferResource, FeatureSupportsXNACK, FeatureXNACK, + FeatureClusters, ]>; def FeatureISAVersion12_51 : FeatureSet< diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp index a67a7be..d0c0822 100644 --- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp +++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp @@ -1944,6 +1944,7 @@ public: void cvtVOP3Interp(MCInst &Inst, const OperandVector &Operands); void cvtVINTERP(MCInst &Inst, const OperandVector &Operands); + void cvtOpSelHelper(MCInst &Inst, unsigned OpSel); bool parseDimId(unsigned &Encoding); ParseStatus parseDim(OperandVector &Operands); @@ -9239,6 +9240,33 @@ static bool isRegOrImmWithInputMods(const MCInstrDesc &Desc, unsigned OpNum) { MCOI::OperandConstraint::TIED_TO) == -1; } +void AMDGPUAsmParser::cvtOpSelHelper(MCInst &Inst, unsigned OpSel) { + unsigned Opc = Inst.getOpcode(); + constexpr AMDGPU::OpName Ops[] = {AMDGPU::OpName::src0, AMDGPU::OpName::src1, + AMDGPU::OpName::src2}; + constexpr AMDGPU::OpName ModOps[] = {AMDGPU::OpName::src0_modifiers, + AMDGPU::OpName::src1_modifiers, + AMDGPU::OpName::src2_modifiers}; + for (int J = 0; J < 3; ++J) { + int OpIdx = AMDGPU::getNamedOperandIdx(Opc, Ops[J]); + if (OpIdx == -1) + // Some instructions, e.g. v_interp_p2_f16 in GFX9, have src0, src2, but + // no src1. So continue instead of break. + continue; + + int ModIdx = AMDGPU::getNamedOperandIdx(Opc, ModOps[J]); + uint32_t ModVal = Inst.getOperand(ModIdx).getImm(); + + if ((OpSel & (1 << J)) != 0) + ModVal |= SISrcMods::OP_SEL_0; + // op_sel[3] is encoded in src0_modifiers. + if (ModOps[J] == AMDGPU::OpName::src0_modifiers && (OpSel & (1 << 3)) != 0) + ModVal |= SISrcMods::DST_OP_SEL; + + Inst.getOperand(ModIdx).setImm(ModVal); + } +} + void AMDGPUAsmParser::cvtVOP3Interp(MCInst &Inst, const OperandVector &Operands) { OptionalImmIndexMap OptionalIdx; @@ -9275,6 +9303,16 @@ void AMDGPUAsmParser::cvtVOP3Interp(MCInst &Inst, const OperandVector &Operands) if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::omod)) addOptionalImmOperand(Inst, Operands, OptionalIdx, AMDGPUOperand::ImmTyOModSI); + + // Some v_interp instructions use op_sel[3] for dst. 
+ if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::op_sel)) { + addOptionalImmOperand(Inst, Operands, OptionalIdx, + AMDGPUOperand::ImmTyOpSel); + int OpSelIdx = AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::op_sel); + unsigned OpSel = Inst.getOperand(OpSelIdx).getImm(); + + cvtOpSelHelper(Inst, OpSel); + } } void AMDGPUAsmParser::cvtVINTERP(MCInst &Inst, const OperandVector &Operands) @@ -9310,31 +9348,10 @@ void AMDGPUAsmParser::cvtVINTERP(MCInst &Inst, const OperandVector &Operands) if (OpSelIdx == -1) return; - const AMDGPU::OpName Ops[] = {AMDGPU::OpName::src0, AMDGPU::OpName::src1, - AMDGPU::OpName::src2}; - const AMDGPU::OpName ModOps[] = {AMDGPU::OpName::src0_modifiers, - AMDGPU::OpName::src1_modifiers, - AMDGPU::OpName::src2_modifiers}; - unsigned OpSel = Inst.getOperand(OpSelIdx).getImm(); - - for (int J = 0; J < 3; ++J) { - int OpIdx = AMDGPU::getNamedOperandIdx(Opc, Ops[J]); - if (OpIdx == -1) - break; - - int ModIdx = AMDGPU::getNamedOperandIdx(Opc, ModOps[J]); - uint32_t ModVal = Inst.getOperand(ModIdx).getImm(); - - if ((OpSel & (1 << J)) != 0) - ModVal |= SISrcMods::OP_SEL_0; - if (ModOps[J] == AMDGPU::OpName::src0_modifiers && - (OpSel & (1 << 3)) != 0) - ModVal |= SISrcMods::DST_OP_SEL; - - Inst.getOperand(ModIdx).setImm(ModVal); - } + cvtOpSelHelper(Inst, OpSel); } + void AMDGPUAsmParser::cvtScaledMFMA(MCInst &Inst, const OperandVector &Operands) { OptionalImmIndexMap OptionalIdx; diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp index 7b94ea3..f291e37 100644 --- a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp +++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp @@ -541,7 +541,7 @@ unsigned GCNSubtarget::getMaxNumSGPRs(const Function &F) const { unsigned GCNSubtarget::getBaseMaxNumVGPRs( const Function &F, std::pair<unsigned, unsigned> NumVGPRBounds) const { - const auto &[Min, Max] = NumVGPRBounds; + const auto [Min, Max] = NumVGPRBounds; // Check if maximum number of VGPRs was explicitly requested using // "amdgpu-num-vgpr" attribute. diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.h b/llvm/lib/Target/AMDGPU/GCNSubtarget.h index a54d665..c2e6078 100644 --- a/llvm/lib/Target/AMDGPU/GCNSubtarget.h +++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.h @@ -288,6 +288,8 @@ protected: bool Has45BitNumRecordsBufferResource = false; + bool HasClusters = false; + // Dummy feature to use for assembler in tablegen. bool FeatureDisable = false; @@ -1837,7 +1839,7 @@ public: } /// \returns true if the subtarget supports clusters of workgroups. - bool hasClusters() const { return GFX1250Insts; } + bool hasClusters() const { return HasClusters; } /// \returns true if the subtarget requires a wait for xcnt before atomic /// flat/global stores & rmw. diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp index d3b5718..3563caa 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp @@ -1280,6 +1280,17 @@ void AMDGPUInstPrinter::printPackedModifier(const MCInst *MI, (ModIdx != -1) ? MI->getOperand(ModIdx).getImm() : DefaultValue; } + // Some instructions, e.g. v_interp_p2_f16 in GFX9, have src0, src2, but no + // src1. + if (NumOps == 1 && AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::src2) && + !AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::src1)) { + Ops[NumOps++] = DefaultValue; // Set src1_modifiers to default. 
+ int Mod2Idx = + AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src2_modifiers); + assert(Mod2Idx != -1); + Ops[NumOps++] = MI->getOperand(Mod2Idx).getImm(); + } + const bool HasDst = (AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::vdst) != -1) || (AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::sdst) != -1); diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp index 3c2dd42..3115579 100644 --- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp @@ -1118,12 +1118,7 @@ SIRegisterInfo::getPointerRegClass(unsigned Kind) const { const TargetRegisterClass * SIRegisterInfo::getCrossCopyRegClass(const TargetRegisterClass *RC) const { - if (isAGPRClass(RC) && !ST.hasGFX90AInsts()) - return getEquivalentVGPRClass(RC); - if (RC == &AMDGPU::SCC_CLASSRegClass) - return getWaveMaskRegClass(); - - return RC; + return RC == &AMDGPU::SCC_CLASSRegClass ? &AMDGPU::SReg_32RegClass : RC; } static unsigned getNumSubRegsForSpillOp(const MachineInstr &MI, diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp index 20fa141..f7f4d46 100644 --- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp +++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp @@ -1353,11 +1353,6 @@ unsigned getVGPRAllocGranule(const MCSubtargetInfo *STI, if (DynamicVGPRBlockSize != 0) return DynamicVGPRBlockSize; - // Temporarily check the subtarget feature, until we fully switch to using - // attributes. - if (STI->getFeatureBits().test(FeatureDynamicVGPR)) - return STI->getFeatureBits().test(FeatureDynamicVGPRBlockSize32) ? 32 : 16; - bool IsWave32 = EnableWavefrontSize32 ? *EnableWavefrontSize32 : STI->getFeatureBits().test(FeatureWavefrontSize32); @@ -1412,10 +1407,7 @@ unsigned getAddressableNumVGPRs(const MCSubtargetInfo *STI, if (Features.test(FeatureGFX90AInsts)) return 512; - // Temporarily check the subtarget feature, until we fully switch to using - // attributes. - if (DynamicVGPRBlockSize != 0 || - STI->getFeatureBits().test(FeatureDynamicVGPR)) + if (DynamicVGPRBlockSize != 0) // On GFX12 we can allocate at most 8 blocks of VGPRs. 
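In concrete numbers, the dynamic-VGPR path reduces to simple arithmetic: with a dynamic block size the allocation granule is the block size itself, and the return that follows multiplies it by the 8-block cap. An illustrative check (the helper name below is made up for the sketch):

// Illustrative only: addressable VGPRs under dynamic VGPR allocation.
constexpr unsigned addressableDynVGPRs(unsigned BlockSize) {
  return 8 * BlockSize; // at most 8 blocks of BlockSize registers each
}
static_assert(addressableDynVGPRs(16) == 128, "block size 16 -> 128 VGPRs");
static_assert(addressableDynVGPRs(32) == 256, "block size 32 -> 256 VGPRs");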
return 8 * getVGPRAllocGranule(STI, DynamicVGPRBlockSize); return getAddressableNumArchVGPRs(STI); diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td index 4a2b54d..42ec8ba 100644 --- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td @@ -97,6 +97,7 @@ class VOP3Interp<string OpName, VOPProfile P, list<dag> pattern = []> : VOP3_Pseudo<OpName, P, pattern> { let AsmMatchConverter = "cvtVOP3Interp"; let mayRaiseFPException = 0; + let VOP3_OPSEL = P.HasOpSel; } def VOP3_INTERP : VOPProfile<[f32, f32, i32, untyped]> { @@ -119,16 +120,17 @@ def VOP3_INTERP_MOV : VOPProfile<[f32, i32, i32, untyped]> { let HasSrc0Mods = 0; } -class getInterp16Asm <bit HasSrc2, bit HasOMod> { +class getInterp16Asm <bit HasSrc2, bit HasOMod, bit OpSel> { string src2 = !if(HasSrc2, ", $src2_modifiers", ""); string omod = !if(HasOMod, "$omod", ""); + string opsel = !if(OpSel, "$op_sel", ""); string ret = - " $vdst, $src0_modifiers, $attr$attrchan"#src2#"$high$clamp"#omod; + " $vdst, $src0_modifiers, $attr$attrchan"#src2#"$high$clamp"#omod#opsel; } class getInterp16Ins <bit HasSrc2, bit HasOMod, - Operand Src0Mod, Operand Src2Mod> { - dag ret = !if(HasSrc2, + Operand Src0Mod, Operand Src2Mod, bit OpSel> { + dag ret1 = !if(HasSrc2, !if(HasOMod, (ins Src0Mod:$src0_modifiers, VRegSrc_32:$src0, InterpAttr:$attr, InterpAttrChan:$attrchan, @@ -143,19 +145,22 @@ class getInterp16Ins <bit HasSrc2, bit HasOMod, InterpAttr:$attr, InterpAttrChan:$attrchan, highmod:$high, Clamp0:$clamp, omod0:$omod) ); + dag ret2 = !if(OpSel, (ins op_sel0:$op_sel), (ins)); + dag ret = !con(ret1, ret2); } -class VOP3_INTERP16 <list<ValueType> ArgVT> : VOPProfile<ArgVT> { +class VOP3_INTERP16 <list<ValueType> ArgVT, bit OpSel = 0> : VOPProfile<ArgVT> { let IsSingle = 1; let HasOMod = !ne(DstVT.Value, f16.Value); let HasHigh = 1; + let HasOpSel = OpSel; let Src0Mod = FPVRegInputMods; let Src2Mod = FPVRegInputMods; let Outs64 = (outs DstRC.RegClass:$vdst); - let Ins64 = getInterp16Ins<HasSrc2, HasOMod, Src0Mod, Src2Mod>.ret; - let Asm64 = getInterp16Asm<HasSrc2, HasOMod>.ret; + let Ins64 = getInterp16Ins<HasSrc2, HasOMod, Src0Mod, Src2Mod, OpSel>.ret; + let Asm64 = getInterp16Asm<HasSrc2, HasOMod, OpSel>.ret; } //===----------------------------------------------------------------------===// @@ -480,7 +485,7 @@ let SubtargetPredicate = isGFX9Plus in { defm V_MAD_U16_gfx9 : VOP3Inst_t16 <"v_mad_u16_gfx9", VOP_I16_I16_I16_I16>; defm V_MAD_I16_gfx9 : VOP3Inst_t16 <"v_mad_i16_gfx9", VOP_I16_I16_I16_I16>; let OtherPredicates = [isNotGFX90APlus] in -def V_INTERP_P2_F16_gfx9 : VOP3Interp <"v_interp_p2_f16_gfx9", VOP3_INTERP16<[f16, f32, i32, f32]>>; +def V_INTERP_P2_F16_opsel : VOP3Interp <"v_interp_p2_f16_opsel", VOP3_INTERP16<[f16, f32, i32, f32], /*OpSel*/ 1>>; } // End SubtargetPredicate = isGFX9Plus // This predicate should only apply to the selection pattern. 
The @@ -2676,6 +2681,14 @@ multiclass VOP3Interp_F16_Real_gfx9<bits<10> op, string OpName, string AsmName> } } +multiclass VOP3Interp_F16_OpSel_Real_gfx9<bits<10> op, string OpName, string AsmName> { + def _gfx9 : VOP3_Real<!cast<VOP3_Pseudo>(OpName), SIEncodingFamily.GFX9>, + VOP3Interp_OpSel_gfx9 <op, !cast<VOP3_Pseudo>(OpName).Pfl> { + VOP3_Pseudo ps = !cast<VOP3_Pseudo>(OpName); + let AsmString = AsmName # ps.AsmOperands; + } +} + multiclass VOP3_Real_gfx9<bits<10> op, string AsmName> { def _gfx9 : VOP3_Real<!cast<VOP_Pseudo>(NAME#"_e64"), SIEncodingFamily.GFX9>, VOP3e_vi <op, !cast<VOP_Pseudo>(NAME#"_e64").Pfl> { @@ -2788,7 +2801,7 @@ defm V_MAD_U16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x204, "v_mad_u16">; defm V_MAD_I16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x205, "v_mad_i16">; defm V_FMA_F16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x206, "v_fma_f16">; defm V_DIV_FIXUP_F16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x207, "v_div_fixup_f16">; -defm V_INTERP_P2_F16_gfx9 : VOP3Interp_F16_Real_gfx9 <0x277, "V_INTERP_P2_F16_gfx9", "v_interp_p2_f16">; +defm V_INTERP_P2_F16_opsel : VOP3Interp_F16_OpSel_Real_gfx9 <0x277, "V_INTERP_P2_F16_opsel", "v_interp_p2_f16">; defm V_ADD_I32 : VOP3_Real_vi <0x29c>; defm V_SUB_I32 : VOP3_Real_vi <0x29d>; diff --git a/llvm/lib/Target/AMDGPU/VOPInstructions.td b/llvm/lib/Target/AMDGPU/VOPInstructions.td index 631f0f3..8325c62 100644 --- a/llvm/lib/Target/AMDGPU/VOPInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOPInstructions.td @@ -419,6 +419,13 @@ class VOP3a_ScaleSel_gfx1250<bits<10> op, VOPProfile p> : VOP3e_gfx11_gfx12<op, let Inst{14-11} = scale_sel; } +class VOP3Interp_OpSel_gfx9<bits<10> op, VOPProfile p> : VOP3Interp_vi<op, p> { + let Inst{11} = src0_modifiers{2}; + // There's no src1 + let Inst{13} = src2_modifiers{2}; + let Inst{14} = !if(p.HasDst, src0_modifiers{3}, 0); +} + class VOP3Interp_gfx10<bits<10> op, VOPProfile p> : VOP3e_gfx10<op, p> { bits<6> attr; bits<2> attrchan; diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp index f9bdc09..77913f2 100644 --- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp +++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp @@ -149,6 +149,9 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O, case NVPTX::PTXCvtMode::RNA: O << ".rna"; return; + case NVPTX::PTXCvtMode::RS: + O << ".rs"; + return; } } llvm_unreachable("Invalid conversion modifier"); diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h index 77a0e03..1e0f747 100644 --- a/llvm/lib/Target/NVPTX/NVPTX.h +++ b/llvm/lib/Target/NVPTX/NVPTX.h @@ -207,6 +207,7 @@ enum CvtMode { RM, RP, RNA, + RS, BASE_MASK = 0x0F, FTZ_FLAG = 0x10, diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 8c21746..bc047a4a 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -1096,9 +1096,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, // Enable custom lowering for the following: // * MVT::i128 - clusterlaunchcontrol // * MVT::i32 - prmt + // * MVT::v4f32 - cvt_rs fp{4/6/8}x4 intrinsics // * MVT::Other - internal.addrspace.wrap - setOperationAction(ISD::INTRINSIC_WO_CHAIN, {MVT::i32, MVT::i128, MVT::Other}, - Custom); + setOperationAction(ISD::INTRINSIC_WO_CHAIN, + {MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Custom); } const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { @@ -1181,6 +1182,11 @@ 
const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT) MAKE_CASE( NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT) + MAKE_CASE(NVPTXISD::CVT_E4M3X4_F32X4_RS_SF) + MAKE_CASE(NVPTXISD::CVT_E5M2X4_F32X4_RS_SF) + MAKE_CASE(NVPTXISD::CVT_E2M3X4_F32X4_RS_SF) + MAKE_CASE(NVPTXISD::CVT_E3M2X4_F32X4_RS_SF) + MAKE_CASE(NVPTXISD::CVT_E2M1X4_F32X4_RS_SF) } return nullptr; @@ -2903,6 +2909,61 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op, {TryCancelResponse0, TryCancelResponse1}); } +static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG) { + SDNode *N = Op.getNode(); + SDLoc DL(N); + SDValue F32Vec = N->getOperand(1); + SDValue RBits = N->getOperand(2); + + unsigned IntrinsicID = N->getConstantOperandVal(0); + + // Extract the 4 float elements from the vector + SmallVector<SDValue, 6> Ops; + for (unsigned i = 0; i < 4; ++i) + Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec, + DAG.getIntPtrConstant(i, DL))); + + using NVPTX::PTXCvtMode::CvtMode; + + auto [OpCode, RetTy, CvtModeFlag] = + [&]() -> std::tuple<NVPTXISD::NodeType, MVT::SimpleValueType, uint32_t> { + switch (IntrinsicID) { + case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite: + return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8, + CvtMode::RS | CvtMode::RELU_FLAG}; + case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite: + return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS}; + case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite: + return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8, + CvtMode::RS | CvtMode::RELU_FLAG}; + case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite: + return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS}; + case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite: + return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8, + CvtMode::RS | CvtMode::RELU_FLAG}; + case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite: + return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS}; + case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite: + return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8, + CvtMode::RS | CvtMode::RELU_FLAG}; + case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite: + return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS}; + case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite: + return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16, + CvtMode::RS | CvtMode::RELU_FLAG}; + case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite: + return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16, CvtMode::RS}; + default: + llvm_unreachable("unsupported/unhandled intrinsic"); + } + }(); + + Ops.push_back(RBits); + Ops.push_back(DAG.getConstant(CvtModeFlag, DL, MVT::i32)); + + return DAG.getNode(OpCode, DL, RetTy, Ops); +} + static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) { const unsigned Mode = [&]() { switch (Op->getConstantOperandVal(0)) { @@ -2972,6 +3033,17 @@ static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) { case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y: case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z: return LowerClusterLaunchControlQueryCancel(Op, DAG); + case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite: + case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite: + case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite: + case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite: + case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite: + case 
Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite: + case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite: + case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite: + case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite: + case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite: + return lowerCvtRSIntrinsics(Op, DAG); } } diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index 769d2fe..63fa0bb 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -79,6 +79,11 @@ enum NodeType : unsigned { CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X, CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y, CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z, + CVT_E4M3X4_F32X4_RS_SF, + CVT_E5M2X4_F32X4_RS_SF, + CVT_E2M3X4_F32X4_RS_SF, + CVT_E3M2X4_F32X4_RS_SF, + CVT_E2M1X4_F32X4_RS_SF, FIRST_MEMORY_OPCODE, diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 4cacee2..6c14cf0 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -34,7 +34,8 @@ def CvtRN : PatLeaf<(i32 0x5)>; def CvtRZ : PatLeaf<(i32 0x6)>; def CvtRM : PatLeaf<(i32 0x7)>; def CvtRP : PatLeaf<(i32 0x8)>; -def CvtRNA : PatLeaf<(i32 0x9)>; +def CvtRNA : PatLeaf<(i32 0x9)>; +def CvtRS : PatLeaf<(i32 0xA)>; def CvtNONE_FTZ : PatLeaf<(i32 0x10)>; def CvtRNI_FTZ : PatLeaf<(i32 0x11)>; @@ -50,8 +51,9 @@ def CvtSAT : PatLeaf<(i32 0x20)>; def CvtSAT_FTZ : PatLeaf<(i32 0x30)>; def CvtNONE_RELU : PatLeaf<(i32 0x40)>; -def CvtRN_RELU : PatLeaf<(i32 0x45)>; -def CvtRZ_RELU : PatLeaf<(i32 0x46)>; +def CvtRN_RELU : PatLeaf<(i32 0x45)>; +def CvtRZ_RELU : PatLeaf<(i32 0x46)>; +def CvtRS_RELU : PatLeaf<(i32 0x4A)>; def CvtMode : Operand<i32> { let PrintMethod = "printCvtMode"; @@ -133,6 +135,11 @@ def hasSM100a : Predicate<"Subtarget->getSmVersion() == 100 && Subtarget->hasArc def hasSM101a : Predicate<"Subtarget->getSmVersion() == 101 && Subtarget->hasArchAccelFeatures()">; def hasSM120a : Predicate<"Subtarget->getSmVersion() == 120 && Subtarget->hasArchAccelFeatures()">; +def hasSM100aOrSM103a : + Predicate<"(Subtarget->getSmVersion() == 100 || " # + "Subtarget->getSmVersion() == 103) " # + "&& Subtarget->hasArchAccelFeatures()">; + // non-sync shfl instructions are not available on sm_70+ in PTX6.4+ def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70" "&& Subtarget->getPTXVersion() >= 64)">; @@ -593,6 +600,23 @@ let hasSideEffects = false in { defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", B32>; defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", B32>; + + multiclass CVT_FROM_FLOAT_V2_RS<string FromName, RegisterClass RC> { + def _f32_rs : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B32:$src1, B32:$src2, B32:$src3), + (ins CvtMode:$mode), + "cvt${mode:base}${mode:relu}." # FromName # ".f32">; + + def _f32_rs_sf : + BasicFlagsNVPTXInst<(outs RC:$dst), + (ins B32:$src1, B32:$src2, B32:$src3), + (ins CvtMode:$mode), + "cvt${mode:base}${mode:relu}.satfinite." # FromName # ".f32">; + } + + defm CVT_f16x2 : CVT_FROM_FLOAT_V2_RS<"f16x2", B32>; + defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_RS<"bf16x2", B32>; // FP8 conversions. 
multiclass CVT_TO_F8X2<string F8Name> { @@ -619,6 +643,15 @@ let hasSideEffects = false in { def CVT_f16x2_e4m3x2 : CVT_f16x2_fp8<"e4m3">; def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">; + + class CVT_TO_FP8X4<string F8Name> : + NVPTXInst<(outs B32:$dst), + (ins B32:$src1, B32:$src2, B32:$src3, B32:$src4, B32:$src5, CvtMode:$mode), + "cvt${mode:base}${mode:relu}.satfinite." # F8Name # + "x4.f32 \t$dst, {{$src1, $src2, $src3, $src4}}, $src5;">; + + def CVT_e4m3x4_f32x4_rs_sf : CVT_TO_FP8X4<"e4m3">; + def CVT_e5m2x4_f32x4_rs_sf : CVT_TO_FP8X4<"e5m2">; // Float to TF32 conversions multiclass CVT_TO_TF32<string Modifier, list<Predicate> Preds = [hasPTX<78>, hasSM<90>]> { @@ -652,6 +685,15 @@ let hasSideEffects = false in { "cvt${mode:base}${mode:relu}.f16x2." # type>; } + class CVT_TO_FP6X4<string F6Name> : + NVPTXInst<(outs B32:$dst), + (ins B32:$src1, B32:$src2, B32:$src3, B32:$src4, B32:$src5, CvtMode:$mode), + "cvt${mode:base}${mode:relu}.satfinite." # F6Name # + "x4.f32 \t$dst, {{$src1, $src2, $src3, $src4}}, $src5;">; + + def CVT_e2m3x4_f32x4_rs_sf : CVT_TO_FP6X4<"e2m3">; + def CVT_e3m2x4_f32x4_rs_sf : CVT_TO_FP6X4<"e3m2">; + // FP4 conversions. def CVT_e2m1x2_f32_sf : NVPTXInst<(outs B16:$dst), (ins B32:$src1, B32:$src2, CvtMode:$mode), @@ -668,6 +710,12 @@ let hasSideEffects = false in { "cvt.u8.u16 \t%e2m1x2_in, $src; \n\t", "cvt${mode:base}${mode:relu}.f16x2.e2m1x2 \t$dst, %e2m1x2_in; \n\t", "}}"), []>; + + def CVT_e2m1x4_f32x4_rs_sf : + NVPTXInst<(outs B16:$dst), + (ins B32:$src1, B32:$src2, B32:$src3, B32:$src4, B32:$src5, CvtMode:$mode), + "cvt${mode:base}${mode:relu}.satfinite.e2m1x4.f32 \t" # + "$dst, {{$src1, $src2, $src3, $src4}}, $src5;">; // UE8M0x2 conversions. class CVT_f32_to_ue8m0x2<string sat = ""> : diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index e91171c..a8b854f 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1782,11 +1782,32 @@ def : Pat<(int_nvvm_ff2bf16x2_rn_relu f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, C def : Pat<(int_nvvm_ff2bf16x2_rz f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, CvtRZ)>; def : Pat<(int_nvvm_ff2bf16x2_rz_relu f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, CvtRZ_RELU)>; +let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in { +def : Pat<(int_nvvm_ff2bf16x2_rs f32:$a, f32:$b, i32:$c), + (CVT_bf16x2_f32_rs $a, $b, $c, CvtRS)>; +def : Pat<(int_nvvm_ff2bf16x2_rs_relu f32:$a, f32:$b, i32:$c), + (CVT_bf16x2_f32_rs $a, $b, $c, CvtRS_RELU)>; +def : Pat<(int_nvvm_ff2bf16x2_rs_satfinite f32:$a, f32:$b, i32:$c), + (CVT_bf16x2_f32_rs_sf $a, $b, $c, CvtRS)>; +def : Pat<(int_nvvm_ff2bf16x2_rs_relu_satfinite f32:$a, f32:$b, i32:$c), + (CVT_bf16x2_f32_rs_sf $a, $b, $c, CvtRS_RELU)>; +} + def : Pat<(int_nvvm_ff2f16x2_rn f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRN)>; def : Pat<(int_nvvm_ff2f16x2_rn_relu f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRN_RELU)>; def : Pat<(int_nvvm_ff2f16x2_rz f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRZ)>; def : Pat<(int_nvvm_ff2f16x2_rz_relu f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRZ_RELU)>; +let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in { +def : Pat<(int_nvvm_ff2f16x2_rs f32:$a, f32:$b, i32:$c), + (CVT_f16x2_f32_rs $a, $b, $c, CvtRS)>; +def : Pat<(int_nvvm_ff2f16x2_rs_relu f32:$a, f32:$b, i32:$c), + (CVT_f16x2_f32_rs $a, $b, $c, CvtRS_RELU)>; +def : Pat<(int_nvvm_ff2f16x2_rs_satfinite f32:$a, f32:$b, i32:$c), + (CVT_f16x2_f32_rs_sf $a, $b, $c, CvtRS)>; +def : Pat<(int_nvvm_ff2f16x2_rs_relu_satfinite f32:$a, f32:$b, i32:$c), + 
(CVT_f16x2_f32_rs_sf $a, $b, $c, CvtRS_RELU)>; +} def : Pat<(int_nvvm_f2bf16_rn f32:$a), (CVT_bf16_f32 $a, CvtRN)>; def : Pat<(int_nvvm_f2bf16_rn_relu f32:$a), (CVT_bf16_f32 $a, CvtRN_RELU)>; def : Pat<(int_nvvm_f2bf16_rz f32:$a), (CVT_bf16_f32 $a, CvtRZ)>; @@ -1929,6 +1950,52 @@ let Predicates = [hasPTX<86>, hasSM<100>, hasArchAccelFeatures] in { (CVT_bf16x2_ue8m0x2 $a)>; } +def SDT_CVT_F32X4_TO_FPX4_RS_VEC : + SDTypeProfile<1, 6, [SDTCisVec<0>, SDTCisFP<1>, SDTCisFP<2>, SDTCisFP<3>, + SDTCisFP<4>, SDTCisInt<5>, SDTCisInt<6>]>; + +def SDT_CVT_F32X4_TO_FPX4_RS_INT : + SDTypeProfile<1, 6, [SDTCisInt<0>, SDTCisFP<1>, SDTCisFP<2>, SDTCisFP<3>, + SDTCisFP<4>, SDTCisInt<5>, SDTCisInt<6>]>; + +class CVT_F32X4_TO_FPX4_RS_SF_NODE<string FPName, SDTypeProfile SDT> : + SDNode<"NVPTXISD::CVT_" # FPName # "X4_F32X4_RS_SF", SDT, []>; + +multiclass CVT_F32X4_TO_FPX4_RS_SF_VEC<string FPName, VTVec RetTy> { + def : Pat<(RetTy (CVT_F32X4_TO_FPX4_RS_SF_NODE<!toupper(FPName), + SDT_CVT_F32X4_TO_FPX4_RS_VEC> + f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS)), + (!cast<NVPTXInst>("CVT_" # FPName # "x4_f32x4_rs_sf") + $f1, $f2, $f3, $f4, $rbits, CvtRS)>; + + def : Pat<(RetTy (CVT_F32X4_TO_FPX4_RS_SF_NODE<!toupper(FPName), + SDT_CVT_F32X4_TO_FPX4_RS_VEC> + f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS_RELU)), + (!cast<NVPTXInst>("CVT_" # FPName # "x4_f32x4_rs_sf") + $f1, $f2, $f3, $f4, $rbits, CvtRS_RELU)>; +} + +// RS rounding mode conversions +let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in { +// FP8x4 conversions +defm : CVT_F32X4_TO_FPX4_RS_SF_VEC<"e4m3", v4i8>; +defm : CVT_F32X4_TO_FPX4_RS_SF_VEC<"e5m2", v4i8>; + +// FP6x4 conversions +defm : CVT_F32X4_TO_FPX4_RS_SF_VEC<"e2m3", v4i8>; +defm : CVT_F32X4_TO_FPX4_RS_SF_VEC<"e3m2", v4i8>; + +// FP4x4 conversions +def : Pat<(i16 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E2M1", + SDT_CVT_F32X4_TO_FPX4_RS_INT> + f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS)), + (CVT_e2m1x4_f32x4_rs_sf $f1, $f2, $f3, $f4, $rbits, CvtRS)>; +def : Pat<(i16 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E2M1", + SDT_CVT_F32X4_TO_FPX4_RS_INT> + f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS_RELU)), + (CVT_e2m1x4_f32x4_rs_sf $f1, $f2, $f3, $f4, $rbits, CvtRS_RELU)>; +} + // // FNS // @@ -4461,6 +4528,10 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = " !eq(ptx_elt_type, "e2m1"), !ne(kind, "")) : [hasSM120a, hasPTX<87>], + !and(!or(!eq(ptx_elt_type,"e4m3"), + !eq(ptx_elt_type,"e5m2")), + !eq(geom, "m16n8k16")) : [hasSM<89>, hasPTX<87>], + !or(!eq(ptx_elt_type, "e4m3"), !eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>], @@ -4476,6 +4547,11 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = " !and(!eq(geom, "m8n8k4"), !eq(ptx_elt_type, "f64")) : [hasSM<80>, hasPTX<70>], + !and(!or(!eq(geom, "m16n8k4"), + !eq(geom, "m16n8k8"), + !eq(geom, "m16n8k16")), + !eq(ptx_elt_type, "f64")) : [hasSM<90>, hasPTX<78>], + // fp16 -> fp16/fp32 @ m8n32k16/m32n8k16 !and(!or(!eq(geom, "m8n32k16"), !eq(geom, "m32n8k16")), @@ -4760,8 +4836,8 @@ defset list<WMMA_INSTR> WMMAs = { // MMA class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB, WMMA_REGINFO FragC, WMMA_REGINFO FragD, - string ALayout, string BLayout, int Satfinite, string b1op> - : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, FragA, FragB, FragC, FragD>.record, + string ALayout, string BLayout, int Satfinite, string b1op, string Kind> + : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, Kind, FragA, FragB, FragC, FragD>.record, [FragA.Ins, FragB.Ins, FragC.Ins]>, 
// Requires does not seem to have effect on Instruction w/o Patterns. // We set it here anyways and propagate to the Pat<> we construct below. @@ -4776,6 +4852,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB, # FragA.geom # "." # ALayout # "." # BLayout + # !if(!ne(Kind, ""), "." # Kind, "") # !if(Satfinite, ".satfinite", "") # TypeList # b1op # "\n\t\t" @@ -4792,13 +4869,15 @@ defset list<WMMA_INSTR> MMAs = { foreach satf = [0, 1] in { foreach op = NVVM_MMA_OPS.all_mma_ops in { foreach b1op = NVVM_MMA_B1OPS<op>.ret in { - if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then { - def : MMA<WMMA_REGINFO<op[0], "mma">, - WMMA_REGINFO<op[1], "mma">, - WMMA_REGINFO<op[2], "mma">, - WMMA_REGINFO<op[3], "mma">, - layout_a, layout_b, satf, b1op>; - } + foreach kind = ["", "kind::f8f6f4"] in { + if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then { + def : MMA<WMMA_REGINFO<op[0], "mma", "", kind>, + WMMA_REGINFO<op[1], "mma", "", kind>, + WMMA_REGINFO<op[2], "mma", "", kind>, + WMMA_REGINFO<op[3], "mma", "", kind>, + layout_a, layout_b, satf, b1op, kind>; + } + } // kind } // b1op } // op } // satf diff --git a/llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp b/llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp index 1fc475d..561a9c5 100644 --- a/llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp +++ b/llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp @@ -349,32 +349,30 @@ public: bool isImm() const override { return Kind == Immediate || Kind == Expression; } - bool isU1Imm() const { return Kind == Immediate && isUInt<1>(getImm()); } - bool isU2Imm() const { return Kind == Immediate && isUInt<2>(getImm()); } - bool isU3Imm() const { return Kind == Immediate && isUInt<3>(getImm()); } - bool isU4Imm() const { return Kind == Immediate && isUInt<4>(getImm()); } - bool isU5Imm() const { return Kind == Immediate && isUInt<5>(getImm()); } - bool isS5Imm() const { return Kind == Immediate && isInt<5>(getImm()); } - bool isU6Imm() const { return Kind == Immediate && isUInt<6>(getImm()); } - bool isU6ImmX2() const { return Kind == Immediate && - isUInt<6>(getImm()) && - (getImm() & 1) == 0; } - bool isU7Imm() const { return Kind == Immediate && isUInt<7>(getImm()); } - bool isU7ImmX4() const { return Kind == Immediate && - isUInt<7>(getImm()) && - (getImm() & 3) == 0; } - bool isU8Imm() const { return Kind == Immediate && isUInt<8>(getImm()); } - bool isU8ImmX8() const { return Kind == Immediate && - isUInt<8>(getImm()) && - (getImm() & 7) == 0; } - - bool isU10Imm() const { return Kind == Immediate && isUInt<10>(getImm()); } - bool isU12Imm() const { return Kind == Immediate && isUInt<12>(getImm()); } + + template <uint64_t N> bool isUImm() const { + return Kind == Immediate && isUInt<N>(getImm()); + } + template <uint64_t N> bool isSImm() const { + return Kind == Immediate && isInt<N>(getImm()); + } + bool isU6ImmX2() const { return isUImm<6>() && (getImm() & 1) == 0; } + bool isU7ImmX4() const { return isUImm<7>() && (getImm() & 3) == 0; } + bool isU8ImmX8() const { return isUImm<8>() && (getImm() & 7) == 0; } + bool isU16Imm() const { return isExtImm<16>(/*Signed*/ false, 1); } bool isS16Imm() const { return isExtImm<16>(/*Signed*/ true, 1); } bool isS16ImmX4() const { return isExtImm<16>(/*Signed*/ true, 4); } bool isS16ImmX16() const { return isExtImm<16>(/*Signed*/ true, 16); } bool isS17Imm() const { return isExtImm<17>(/*Signed*/ true, 1); } + bool isS34Imm() const { + // Once the PC-Rel ABI is finalized, evaluate whether a 34-bit + // ContextImmediate is needed. 
+ return Kind == Expression || isSImm<34>(); + } + bool isS34ImmX16() const { + return Kind == Expression || (isSImm<34>() && (getImm() & 15) == 0); + } bool isHashImmX8() const { // The Hash Imm form is used for instructions that check or store a hash. @@ -384,16 +382,6 @@ public: (getImm() & 7) == 0); } - bool isS34ImmX16() const { - return Kind == Expression || - (Kind == Immediate && isInt<34>(getImm()) && (getImm() & 15) == 0); - } - bool isS34Imm() const { - // Once the PC-Rel ABI is finalized, evaluate whether a 34-bit - // ContextImmediate is needed. - return Kind == Expression || (Kind == Immediate && isInt<34>(getImm())); - } - bool isTLSReg() const { return Kind == TLSRegister; } bool isDirectBr() const { if (Kind == Expression) @@ -1637,7 +1625,7 @@ bool PPCAsmParser::parseInstruction(ParseInstructionInfo &Info, StringRef Name, if (Operands.size() != 5) return false; PPCOperand &EHOp = (PPCOperand &)*Operands[4]; - if (EHOp.isU1Imm() && EHOp.getImm() == 0) + if (EHOp.isUImm<1>() && EHOp.getImm() == 0) Operands.pop_back(); } @@ -1817,7 +1805,7 @@ unsigned PPCAsmParser::validateTargetOperandClass(MCParsedAsmOperand &AsmOp, } PPCOperand &Op = static_cast<PPCOperand &>(AsmOp); - if (Op.isU3Imm() && Op.getImm() == ImmVal) + if (Op.isUImm<3>() && Op.getImm() == ImmVal) return Match_Success; return Match_InvalidOperand; diff --git a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp index 48c31c9..81d8e94 100644 --- a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp +++ b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp @@ -206,45 +206,24 @@ PPCMCCodeEmitter::getVSRpEvenEncoding(const MCInst &MI, unsigned OpNo, return RegBits; } -unsigned PPCMCCodeEmitter::getImm16Encoding(const MCInst &MI, unsigned OpNo, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI) const { - const MCOperand &MO = MI.getOperand(OpNo); - if (MO.isReg() || MO.isImm()) return getMachineOpValue(MI, MO, Fixups, STI); - - // Add a fixup for the immediate field. - addFixup(Fixups, IsLittleEndian ? 0 : 2, MO.getExpr(), PPC::fixup_ppc_half16); - return 0; -} - -uint64_t PPCMCCodeEmitter::getImm34Encoding(const MCInst &MI, unsigned OpNo, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI, - MCFixupKind Fixup) const { +template <MCFixupKind Fixup> +uint64_t PPCMCCodeEmitter::getImmEncoding(const MCInst &MI, unsigned OpNo, + SmallVectorImpl<MCFixup> &Fixups, + const MCSubtargetInfo &STI) const { const MCOperand &MO = MI.getOperand(OpNo); assert(!MO.isReg() && "Not expecting a register for this operand."); if (MO.isImm()) return getMachineOpValue(MI, MO, Fixups, STI); + uint32_t Offset = 0; + if (Fixup == PPC::fixup_ppc_half16) + Offset = IsLittleEndian ? 0 : 2; + // Add a fixup for the immediate field. 
- addFixup(Fixups, 0, MO.getExpr(), Fixup); + addFixup(Fixups, Offset, MO.getExpr(), Fixup); return 0; } -uint64_t -PPCMCCodeEmitter::getImm34EncodingNoPCRel(const MCInst &MI, unsigned OpNo, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI) const { - return getImm34Encoding(MI, OpNo, Fixups, STI, PPC::fixup_ppc_imm34); -} - -uint64_t -PPCMCCodeEmitter::getImm34EncodingPCRel(const MCInst &MI, unsigned OpNo, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI) const { - return getImm34Encoding(MI, OpNo, Fixups, STI, PPC::fixup_ppc_pcrel34); -} - unsigned PPCMCCodeEmitter::getDispRIEncoding(const MCInst &MI, unsigned OpNo, SmallVectorImpl<MCFixup> &Fixups, const MCSubtargetInfo &STI) const { diff --git a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h index b574557..3356513 100644 --- a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h +++ b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h @@ -47,19 +47,10 @@ public: unsigned getAbsCondBrEncoding(const MCInst &MI, unsigned OpNo, SmallVectorImpl<MCFixup> &Fixups, const MCSubtargetInfo &STI) const; - unsigned getImm16Encoding(const MCInst &MI, unsigned OpNo, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI) const; - uint64_t getImm34Encoding(const MCInst &MI, unsigned OpNo, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI, - MCFixupKind Fixup) const; - uint64_t getImm34EncodingNoPCRel(const MCInst &MI, unsigned OpNo, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI) const; - uint64_t getImm34EncodingPCRel(const MCInst &MI, unsigned OpNo, - SmallVectorImpl<MCFixup> &Fixups, - const MCSubtargetInfo &STI) const; + template <MCFixupKind Fixup> + uint64_t getImmEncoding(const MCInst &MI, unsigned OpNo, + SmallVectorImpl<MCFixup> &Fixups, + const MCSubtargetInfo &STI) const; unsigned getDispRIEncoding(const MCInst &MI, unsigned OpNo, SmallVectorImpl<MCFixup> &Fixups, const MCSubtargetInfo &STI) const; diff --git a/llvm/lib/Target/PowerPC/PPCInstr64Bit.td b/llvm/lib/Target/PowerPC/PPCInstr64Bit.td index 60efa4c..fdca5ebc 100644 --- a/llvm/lib/Target/PowerPC/PPCInstr64Bit.td +++ b/llvm/lib/Target/PowerPC/PPCInstr64Bit.td @@ -14,30 +14,6 @@ //===----------------------------------------------------------------------===// // 64-bit operands. // -def s16imm64 : Operand<i64> { - let PrintMethod = "printS16ImmOperand"; - let EncoderMethod = "getImm16Encoding"; - let ParserMatchClass = PPCS16ImmAsmOperand; - let DecoderMethod = "decodeSImmOperand<16>"; - let OperandType = "OPERAND_IMMEDIATE"; -} -def u16imm64 : Operand<i64> { - let PrintMethod = "printU16ImmOperand"; - let EncoderMethod = "getImm16Encoding"; - let ParserMatchClass = PPCU16ImmAsmOperand; - let DecoderMethod = "decodeUImmOperand<16>"; - let OperandType = "OPERAND_IMMEDIATE"; -} -def s17imm64 : Operand<i64> { - // This operand type is used for addis/lis to allow the assembler parser - // to accept immediates in the range -65536..65535 for compatibility with - // the GNU assembler. The operand is treated as 16-bit otherwise. 
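The PowerPC changes above collapse a family of near-identical immediate predicates and encoders into templates parameterized on the bit width (or the fixup kind). The same pattern is easy to reproduce outside of LLVM; a self-contained sketch of width-parameterized checks, using plain bit arithmetic instead of llvm::isUInt/isInt and assuming C++17:

#include <cstdint>

// Width-parameterized immediate checks, mirroring isUImm<N>/isSImm<N>.
template <unsigned N> constexpr bool isUImm(int64_t V) {
  static_assert(N > 0 && N < 64, "width must be in (0, 64)");
  return static_cast<uint64_t>(V) < (uint64_t(1) << N);
}
template <unsigned N> constexpr bool isSImm(int64_t V) {
  static_assert(N > 0 && N < 64, "width must be in (0, 64)");
  return V >= -(int64_t(1) << (N - 1)) && V < (int64_t(1) << (N - 1));
}

static_assert(isUImm<5>(31) && !isUImm<5>(32));
static_assert(isSImm<16>(-32768) && !isSImm<16>(32768));
// Alignment-constrained forms compose on top, e.g. isUImm<6>(V) && (V & 1) == 0.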
- let PrintMethod = "printS16ImmOperand"; - let EncoderMethod = "getImm16Encoding"; - let ParserMatchClass = PPCS17ImmAsmOperand; - let DecoderMethod = "decodeSImmOperand<16>"; - let OperandType = "OPERAND_IMMEDIATE"; -} def tocentry : Operand<iPTR> { let MIOperandInfo = (ops i64imm:$imm); } diff --git a/llvm/lib/Target/PowerPC/PPCInstrAltivec.td b/llvm/lib/Target/PowerPC/PPCInstrAltivec.td index c616db4..23d6d88 100644 --- a/llvm/lib/Target/PowerPC/PPCInstrAltivec.td +++ b/llvm/lib/Target/PowerPC/PPCInstrAltivec.td @@ -30,6 +30,11 @@ // Altivec transformation functions and pattern fragments. // +// fneg is not legal, and desugared as an xor. +def desugared_fneg : PatFrag<(ops node:$x), (v4f32 (bitconvert (xor (bitconvert $x), + (int_ppc_altivec_vslw (bitconvert (v16i8 immAllOnesV)), + (bitconvert (v16i8 immAllOnesV))))))>; + def vpkuhum_shuffle : PatFrag<(ops node:$lhs, node:$rhs), (vector_shuffle node:$lhs, node:$rhs), [{ return PPC::isVPKUHUMShuffleMask(cast<ShuffleVectorSDNode>(N), 0, *CurDAG); @@ -467,11 +472,12 @@ def VMADDFP : VAForm_1<46, (outs vrrc:$RT), (ins vrrc:$RA, vrrc:$RC, vrrc:$RB), [(set v4f32:$RT, (fma v4f32:$RA, v4f32:$RC, v4f32:$RB))]>; -// FIXME: The fma+fneg pattern won't match because fneg is not legal. +// fneg is not legal, hence we have to match on the desugared version. def VNMSUBFP: VAForm_1<47, (outs vrrc:$RT), (ins vrrc:$RA, vrrc:$RC, vrrc:$RB), "vnmsubfp $RT, $RA, $RC, $RB", IIC_VecFP, - [(set v4f32:$RT, (fneg (fma v4f32:$RA, v4f32:$RC, - (fneg v4f32:$RB))))]>; + [(set v4f32:$RT, (desugared_fneg (fma v4f32:$RA, v4f32:$RC, + (desugared_fneg v4f32:$RB))))]>; + let hasSideEffects = 1 in { def VMHADDSHS : VA1a_Int_Ty<32, "vmhaddshs", int_ppc_altivec_vmhaddshs, v8i16>; def VMHRADDSHS : VA1a_Int_Ty<33, "vmhraddshs", int_ppc_altivec_vmhraddshs, @@ -892,6 +898,13 @@ def : Pat<(mul v8i16:$vA, v8i16:$vB), (VMLADDUHM $vA, $vB, (v8i16(V_SET0H)))>; // Add def : Pat<(add (mul v8i16:$vA, v8i16:$vB), v8i16:$vC), (VMLADDUHM $vA, $vB, $vC)>; + +// Fused negated multiply-subtract +def : Pat<(v4f32 (desugared_fneg + (int_ppc_altivec_vmaddfp v4f32:$RA, v4f32:$RC, + (desugared_fneg v4f32:$RB)))), + (VNMSUBFP $RA, $RC, $RB)>; + // Saturating adds/subtracts. 
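Before moving on to the saturating patterns below, a note on the desugared fneg used above: vector fneg is not legal for this profile, so by the time the DAG reaches VNMSUBFP the negation appears as an xor with a sign-bit splat, and the splat itself is built as vslw(all-ones, all-ones). Because vslw only uses the low five bits of each shift amount, shifting 0xFFFFFFFF left by 31 leaves exactly the sign bit set, and xoring with it flips the sign of each f32 lane. A standalone check of that identity, in plain C++ with no Altivec intrinsics:

#include <cassert>
#include <cstdint>
#include <cstring>

// xor with 0x80000000 flips the IEEE-754 sign bit, i.e. negates the float.
static float xorSignBit(float F) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  Bits ^= UINT32_C(0x80000000);
  std::memcpy(&F, &Bits, sizeof(F));
  return F;
}

int main() {
  // vslw honors only the low 5 bits of each shift amount, so an all-ones
  // shift count behaves like a shift by 31.
  const uint32_t AllOnes = 0xFFFFFFFFu;
  const uint32_t SignMask = AllOnes << (AllOnes & 31u);
  assert(SignMask == 0x80000000u);
  assert(xorSignBit(1.5f) == -1.5f);
  assert(xorSignBit(-2.0f) == 2.0f);
  return 0;
}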
def : Pat<(v16i8 (saddsat v16i8:$vA, v16i8:$vB)), (v16i8 (VADDSBS $vA, $vB))>; def : Pat<(v16i8 (uaddsat v16i8:$vA, v16i8:$vB)), (v16i8 (VADDUBS $vA, $vB))>; diff --git a/llvm/lib/Target/PowerPC/PPCRegisterInfo.td b/llvm/lib/Target/PowerPC/PPCRegisterInfo.td index 6d8c122..65d0484 100644 --- a/llvm/lib/Target/PowerPC/PPCRegisterInfo.td +++ b/llvm/lib/Target/PowerPC/PPCRegisterInfo.td @@ -615,7 +615,8 @@ def spe4rc : RegisterOperand<GPRC> { } def PPCU1ImmAsmOperand : AsmOperandClass { - let Name = "U1Imm"; let PredicateMethod = "isU1Imm"; + let Name = "U1Imm"; + let PredicateMethod = "isUImm<1>"; let RenderMethod = "addImmOperands"; } def u1imm : Operand<i32> { @@ -626,7 +627,8 @@ def u1imm : Operand<i32> { } def PPCU2ImmAsmOperand : AsmOperandClass { - let Name = "U2Imm"; let PredicateMethod = "isU2Imm"; + let Name = "U2Imm"; + let PredicateMethod = "isUImm<2>"; let RenderMethod = "addImmOperands"; } def u2imm : Operand<i32> { @@ -647,7 +649,8 @@ def atimm : Operand<i32> { } def PPCU3ImmAsmOperand : AsmOperandClass { - let Name = "U3Imm"; let PredicateMethod = "isU3Imm"; + let Name = "U3Imm"; + let PredicateMethod = "isUImm<3>"; let RenderMethod = "addImmOperands"; } def u3imm : Operand<i32> { @@ -658,7 +661,8 @@ def u3imm : Operand<i32> { } def PPCU4ImmAsmOperand : AsmOperandClass { - let Name = "U4Imm"; let PredicateMethod = "isU4Imm"; + let Name = "U4Imm"; + let PredicateMethod = "isUImm<4>"; let RenderMethod = "addImmOperands"; } def u4imm : Operand<i32> { @@ -668,7 +672,8 @@ def u4imm : Operand<i32> { let OperandType = "OPERAND_IMMEDIATE"; } def PPCS5ImmAsmOperand : AsmOperandClass { - let Name = "S5Imm"; let PredicateMethod = "isS5Imm"; + let Name = "S5Imm"; + let PredicateMethod = "isSImm<5>"; let RenderMethod = "addImmOperands"; } def s5imm : Operand<i32> { @@ -678,7 +683,8 @@ def s5imm : Operand<i32> { let OperandType = "OPERAND_IMMEDIATE"; } def PPCU5ImmAsmOperand : AsmOperandClass { - let Name = "U5Imm"; let PredicateMethod = "isU5Imm"; + let Name = "U5Imm"; + let PredicateMethod = "isUImm<5>"; let RenderMethod = "addImmOperands"; } def u5imm : Operand<i32> { @@ -688,7 +694,8 @@ def u5imm : Operand<i32> { let OperandType = "OPERAND_IMMEDIATE"; } def PPCU6ImmAsmOperand : AsmOperandClass { - let Name = "U6Imm"; let PredicateMethod = "isU6Imm"; + let Name = "U6Imm"; + let PredicateMethod = "isUImm<6>"; let RenderMethod = "addImmOperands"; } def u6imm : Operand<i32> { @@ -698,7 +705,8 @@ def u6imm : Operand<i32> { let OperandType = "OPERAND_IMMEDIATE"; } def PPCU7ImmAsmOperand : AsmOperandClass { - let Name = "U7Imm"; let PredicateMethod = "isU7Imm"; + let Name = "U7Imm"; + let PredicateMethod = "isUImm<7>"; let RenderMethod = "addImmOperands"; } def u7imm : Operand<i32> { @@ -708,7 +716,8 @@ def u7imm : Operand<i32> { let OperandType = "OPERAND_IMMEDIATE"; } def PPCU8ImmAsmOperand : AsmOperandClass { - let Name = "U8Imm"; let PredicateMethod = "isU8Imm"; + let Name = "U8Imm"; + let PredicateMethod = "isUImm<8>"; let RenderMethod = "addImmOperands"; } def u8imm : Operand<i32> { @@ -718,7 +727,8 @@ def u8imm : Operand<i32> { let OperandType = "OPERAND_IMMEDIATE"; } def PPCU10ImmAsmOperand : AsmOperandClass { - let Name = "U10Imm"; let PredicateMethod = "isU10Imm"; + let Name = "U10Imm"; + let PredicateMethod = "isUImm<10>"; let RenderMethod = "addImmOperands"; } def u10imm : Operand<i32> { @@ -728,7 +738,8 @@ def u10imm : Operand<i32> { let OperandType = "OPERAND_IMMEDIATE"; } def PPCU12ImmAsmOperand : AsmOperandClass { - let Name = "U12Imm"; let PredicateMethod = "isU12Imm"; + let 
Name = "U12Imm"; + let PredicateMethod = "isUImm<12>"; let RenderMethod = "addImmOperands"; } def u12imm : Operand<i32> { @@ -743,7 +754,14 @@ def PPCS16ImmAsmOperand : AsmOperandClass { } def s16imm : Operand<i32> { let PrintMethod = "printS16ImmOperand"; - let EncoderMethod = "getImm16Encoding"; + let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>"; + let ParserMatchClass = PPCS16ImmAsmOperand; + let DecoderMethod = "decodeSImmOperand<16>"; + let OperandType = "OPERAND_IMMEDIATE"; +} +def s16imm64 : Operand<i64> { + let PrintMethod = "printS16ImmOperand"; + let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>"; let ParserMatchClass = PPCS16ImmAsmOperand; let DecoderMethod = "decodeSImmOperand<16>"; let OperandType = "OPERAND_IMMEDIATE"; @@ -754,7 +772,14 @@ def PPCU16ImmAsmOperand : AsmOperandClass { } def u16imm : Operand<i32> { let PrintMethod = "printU16ImmOperand"; - let EncoderMethod = "getImm16Encoding"; + let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>"; + let ParserMatchClass = PPCU16ImmAsmOperand; + let DecoderMethod = "decodeUImmOperand<16>"; + let OperandType = "OPERAND_IMMEDIATE"; +} +def u16imm64 : Operand<i64> { + let PrintMethod = "printU16ImmOperand"; + let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>"; let ParserMatchClass = PPCU16ImmAsmOperand; let DecoderMethod = "decodeUImmOperand<16>"; let OperandType = "OPERAND_IMMEDIATE"; @@ -768,7 +793,17 @@ def s17imm : Operand<i32> { // to accept immediates in the range -65536..65535 for compatibility with // the GNU assembler. The operand is treated as 16-bit otherwise. let PrintMethod = "printS16ImmOperand"; - let EncoderMethod = "getImm16Encoding"; + let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>"; + let ParserMatchClass = PPCS17ImmAsmOperand; + let DecoderMethod = "decodeSImmOperand<16>"; + let OperandType = "OPERAND_IMMEDIATE"; +} +def s17imm64 : Operand<i64> { + // This operand type is used for addis/lis to allow the assembler parser + // to accept immediates in the range -65536..65535 for compatibility with + // the GNU assembler. The operand is treated as 16-bit otherwise. + let PrintMethod = "printS16ImmOperand"; + let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>"; let ParserMatchClass = PPCS17ImmAsmOperand; let DecoderMethod = "decodeSImmOperand<16>"; let OperandType = "OPERAND_IMMEDIATE"; @@ -780,14 +815,14 @@ def PPCS34ImmAsmOperand : AsmOperandClass { } def s34imm : Operand<i64> { let PrintMethod = "printS34ImmOperand"; - let EncoderMethod = "getImm34EncodingNoPCRel"; + let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_imm34>"; let ParserMatchClass = PPCS34ImmAsmOperand; let DecoderMethod = "decodeSImmOperand<34>"; let OperandType = "OPERAND_IMMEDIATE"; } def s34imm_pcrel : Operand<i64> { let PrintMethod = "printS34ImmOperand"; - let EncoderMethod = "getImm34EncodingPCRel"; + let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_pcrel34>"; let ParserMatchClass = PPCS34ImmAsmOperand; let DecoderMethod = "decodeSImmOperand<34>"; let OperandType = "OPERAND_IMMEDIATE"; diff --git a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp index 597dd12..9f9ae2f 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp @@ -324,6 +324,10 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { OpdsMapping[0] = GPRValueMapping; + // Atomics always use GPR destinations. Don't refine any further. 
+ if (cast<GLoad>(MI).isAtomic()) + break; + // Use FPR64 for s64 loads on rv32. if (GPRSize == 32 && Size.getFixedValue() == 64) { assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD()); @@ -358,6 +362,10 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { OpdsMapping[0] = GPRValueMapping; + // Atomics always use GPR sources. Don't refine any further. + if (cast<GStore>(MI).isAtomic()) + break; + // Use FPR64 for s64 stores on rv32. if (GPRSize == 32 && Size.getFixedValue() == 64) { assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD()); diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td index a02de31..27cf057 100644 --- a/llvm/lib/Target/RISCV/RISCVFeatures.td +++ b/llvm/lib/Target/RISCV/RISCVFeatures.td @@ -1421,7 +1421,7 @@ def HasVendorXMIPSCMov : Predicate<"Subtarget->hasVendorXMIPSCMov()">, AssemblerPredicate<(all_of FeatureVendorXMIPSCMov), "'Xmipscmov' ('mips.ccmov' instruction)">; -def UseCCMovInsn : Predicate<"Subtarget->useCCMovInsn()">; +def UseMIPSCCMovInsn : Predicate<"Subtarget->useMIPSCCMovInsn()">; def FeatureVendorXMIPSLSP : RISCVExtension<1, 0, "MIPS optimization for hardware load-store bonding">; diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index dcce2d2..b624076 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -434,7 +434,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::ABS, MVT::i32, Custom); } - if (!Subtarget.useCCMovInsn() && !Subtarget.hasVendorXTHeadCondMov()) + if (!Subtarget.useMIPSCCMovInsn() && !Subtarget.hasVendorXTHeadCondMov()) setOperationAction(ISD::SELECT, XLenVT, Custom); if (Subtarget.hasVendorXqcia() && !Subtarget.is64Bit()) { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td index 115ab38e..0b5bee1 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td @@ -175,7 +175,7 @@ def MIPS_CCMOV : RVInstR4<0b11, 0b011, OPC_CUSTOM_0, (outs GPR:$rd), Sched<[]>; } -let Predicates = [UseCCMovInsn] in { +let Predicates = [UseMIPSCCMovInsn] in { def : Pat<(select (riscv_setne (XLenVT GPR:$rs2)), (XLenVT GPR:$rs1), (XLenVT GPR:$rs3)), (MIPS_CCMOV GPR:$rs1, GPR:$rs2, GPR:$rs3)>; diff --git a/llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp index c81a20b..115a96e 100644 --- a/llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp +++ b/llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp @@ -92,7 +92,7 @@ bool RISCVLoadStoreOpt::runOnMachineFunction(MachineFunction &Fn) { if (skipFunction(Fn.getFunction())) return false; const RISCVSubtarget &Subtarget = Fn.getSubtarget<RISCVSubtarget>(); - if (!Subtarget.useLoadStorePairs()) + if (!Subtarget.useMIPSLoadStorePairs()) return false; bool MadeChange = false; diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp index e35ffaf..715ac4c 100644 --- a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp +++ b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp @@ -65,9 +65,9 @@ static cl::opt<bool> UseMIPSLoadStorePairsOpt( cl::desc("Enable the load/store pair optimization pass"), cl::init(false), cl::Hidden); -static cl::opt<bool> UseCCMovInsn("use-riscv-ccmov", - cl::desc("Use 'mips.ccmov' instruction"), - cl::init(true), cl::Hidden); +static cl::opt<bool> UseMIPSCCMovInsn("use-riscv-mips-ccmov", + cl::desc("Use 'mips.ccmov' 
instruction"), + cl::init(true), cl::Hidden); void RISCVSubtarget::anchor() {} @@ -246,10 +246,10 @@ void RISCVSubtarget::overridePostRASchedPolicy( } } -bool RISCVSubtarget::useLoadStorePairs() const { +bool RISCVSubtarget::useMIPSLoadStorePairs() const { return UseMIPSLoadStorePairsOpt && HasVendorXMIPSLSP; } -bool RISCVSubtarget::useCCMovInsn() const { - return UseCCMovInsn && HasVendorXMIPSCMov; +bool RISCVSubtarget::useMIPSCCMovInsn() const { + return UseMIPSCCMovInsn && HasVendorXMIPSCMov; } diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.h b/llvm/lib/Target/RISCV/RISCVSubtarget.h index 7dffa63..6acf799 100644 --- a/llvm/lib/Target/RISCV/RISCVSubtarget.h +++ b/llvm/lib/Target/RISCV/RISCVSubtarget.h @@ -227,8 +227,8 @@ public: unsigned getXLen() const { return is64Bit() ? 64 : 32; } - bool useLoadStorePairs() const; - bool useCCMovInsn() const; + bool useMIPSLoadStorePairs() const; + bool useMIPSCCMovInsn() const; unsigned getFLen() const { if (HasStdExtD) return 64; diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 9f2e075..e16c8f0 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -2811,9 +2811,7 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { GetElementPtrInst *NewGEP = simplifyZeroLengthArrayGepInst(Ref); if (NewGEP) { Ref->replaceAllUsesWith(NewGEP); - if (isInstructionTriviallyDead(Ref)) - DeadInsts.insert(Ref); - + DeadInsts.insert(Ref); Ref = NewGEP; } if (Type *GepTy = getGEPType(Ref)) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 0afec42..989950f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -307,6 +307,10 @@ private: bool selectHandleFromBinding(Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectCounterHandleFromBinding(Register &ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const; + bool selectReadImageIntrinsic(Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const; bool selectImageWriteIntrinsic(MachineInstr &I) const; @@ -314,6 +318,8 @@ private: MachineInstr &I) const; bool selectModf(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectUpdateCounter(Register &ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; bool selectFrexp(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; // Utilities @@ -3443,6 +3449,10 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, case Intrinsic::spv_resource_handlefrombinding: { return selectHandleFromBinding(ResVReg, ResType, I); } + case Intrinsic::spv_resource_counterhandlefrombinding: + return selectCounterHandleFromBinding(ResVReg, ResType, I); + case Intrinsic::spv_resource_updatecounter: + return selectUpdateCounter(ResVReg, ResType, I); case Intrinsic::spv_resource_store_typedbuffer: { return selectImageWriteIntrinsic(I); } @@ -3478,6 +3488,130 @@ bool SPIRVInstructionSelector::selectHandleFromBinding(Register &ResVReg, *cast<GIntrinsic>(&I), I); } +bool SPIRVInstructionSelector::selectCounterHandleFromBinding( + Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const { + auto &Intr = cast<GIntrinsic>(I); + assert(Intr.getIntrinsicID() == + Intrinsic::spv_resource_counterhandlefrombinding); + + // Extract information from the intrinsic call. 
+ Register MainHandleReg = Intr.getOperand(2).getReg(); + auto *MainHandleDef = cast<GIntrinsic>(getVRegDef(*MRI, MainHandleReg)); + assert(MainHandleDef->getIntrinsicID() == + Intrinsic::spv_resource_handlefrombinding); + + uint32_t Set = getIConstVal(Intr.getOperand(4).getReg(), MRI); + uint32_t Binding = getIConstVal(Intr.getOperand(3).getReg(), MRI); + uint32_t ArraySize = getIConstVal(MainHandleDef->getOperand(4).getReg(), MRI); + Register IndexReg = MainHandleDef->getOperand(5).getReg(); + const bool IsNonUniform = false; + std::string CounterName = + getStringValueFromReg(MainHandleDef->getOperand(6).getReg(), *MRI) + + ".counter"; + + // Create the counter variable. + MachineIRBuilder MIRBuilder(I); + Register CounterVarReg = buildPointerToResource( + GR.getPointeeType(ResType), GR.getPointerStorageClass(ResType), Set, + Binding, ArraySize, IndexReg, IsNonUniform, CounterName, MIRBuilder); + + return BuildCOPY(ResVReg, CounterVarReg, I); +} + +bool SPIRVInstructionSelector::selectUpdateCounter(Register &ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + auto &Intr = cast<GIntrinsic>(I); + assert(Intr.getIntrinsicID() == Intrinsic::spv_resource_updatecounter); + + Register CounterHandleReg = Intr.getOperand(2).getReg(); + Register IncrReg = Intr.getOperand(3).getReg(); + + // The counter handle is a pointer to the counter variable (which is a struct + // containing an i32). We need to get a pointer to that i32 member to do the + // atomic operation. +#ifndef NDEBUG + SPIRVType *CounterVarType = GR.getSPIRVTypeForVReg(CounterHandleReg); + SPIRVType *CounterVarPointeeType = GR.getPointeeType(CounterVarType); + assert(CounterVarPointeeType && + CounterVarPointeeType->getOpcode() == SPIRV::OpTypeStruct && + "Counter variable must be a struct"); + assert(GR.getPointerStorageClass(CounterVarType) == + SPIRV::StorageClass::StorageBuffer && + "Counter variable must be in the storage buffer storage class"); + assert(CounterVarPointeeType->getNumOperands() == 2 && + "Counter variable must have exactly 1 member in the struct"); + const SPIRVType *MemberType = + GR.getSPIRVTypeForVReg(CounterVarPointeeType->getOperand(1).getReg()); + assert(MemberType->getOpcode() == SPIRV::OpTypeInt && + "Counter variable struct must have a single i32 member"); +#endif + + // The struct has a single i32 member. + MachineIRBuilder MIRBuilder(I); + const Type *LLVMIntType = + Type::getInt32Ty(I.getMF()->getFunction().getContext()); + + SPIRVType *IntPtrType = GR.getOrCreateSPIRVPointerType( + LLVMIntType, MIRBuilder, SPIRV::StorageClass::StorageBuffer); + + auto Zero = buildI32Constant(0, I); + if (!Zero.second) + return false; + + Register PtrToCounter = + MRI->createVirtualRegister(GR.getRegClass(IntPtrType)); + if (!BuildMI(*I.getParent(), I, I.getDebugLoc(), + TII.get(SPIRV::OpAccessChain)) + .addDef(PtrToCounter) + .addUse(GR.getSPIRVTypeID(IntPtrType)) + .addUse(CounterHandleReg) + .addUse(Zero.first) + .constrainAllUses(TII, TRI, RBI)) { + return false; + } + + // For UAV/SSBO counters, the scope is Device. The counter variable is not + // used as a flag. So the memory semantics can be None. 
+ auto Scope = buildI32Constant(SPIRV::Scope::Device, I); + if (!Scope.second) + return false; + auto Semantics = buildI32Constant(SPIRV::MemorySemantics::None, I); + if (!Semantics.second) + return false; + + int64_t IncrVal = getIConstValSext(IncrReg, MRI); + auto Incr = buildI32Constant(static_cast<uint32_t>(IncrVal), I); + if (!Incr.second) + return false; + + Register AtomicRes = MRI->createVirtualRegister(GR.getRegClass(ResType)); + if (!BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpAtomicIAdd)) + .addDef(AtomicRes) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(PtrToCounter) + .addUse(Scope.first) + .addUse(Semantics.first) + .addUse(Incr.first) + .constrainAllUses(TII, TRI, RBI)) { + return false; + } + if (IncrVal >= 0) { + return BuildCOPY(ResVReg, AtomicRes, I); + } + + // In HLSL, IncrementCounter returns the value *before* the increment, while + // DecrementCounter returns the value *after* the decrement. Both are lowered + // to the same atomic intrinsic which returns the value *before* the + // operation. So for decrements (negative IncrVal), we must subtract the + // increment value from the result to get the post-decrement value. + return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(AtomicRes) + .addUse(Incr.first) + .constrainAllUses(TII, TRI, RBI); +} bool SPIRVInstructionSelector::selectReadImageIntrinsic( Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const { diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp index 205895e..fc14a03 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp @@ -39,6 +39,10 @@ private: void collectBindingInfo(Module &M); uint32_t getAndReserveFirstUnusedBinding(uint32_t DescSet); void replaceImplicitBindingCalls(Module &M); + void replaceResourceHandleCall(Module &M, CallInst *OldCI, + uint32_t NewBinding); + void replaceCounterHandleCall(Module &M, CallInst *OldCI, + uint32_t NewBinding); void verifyUniqueOrderIdPerResource(SmallVectorImpl<CallInst *> &Calls); // A map from descriptor set to a bit vector of used binding numbers. 
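The pre/post asymmetry handled at the end of selectUpdateCounter above is easy to convince yourself of with an ordinary atomic: fetch-and-add always hands back the value before the operation, so an increment can return that value directly, while a decrement has to add its (negative) delta back in to observe the post-decrement value. A small C++ analogy using std::atomic, with relaxed ordering mirroring the "not used as a flag" comment:

#include <atomic>
#include <cassert>

// Returns what HLSL IncrementCounter/DecrementCounter are specified to return,
// given an underlying atomic add that yields the old value.
static int updateCounter(std::atomic<int> &Counter, int Delta) {
  int Old = Counter.fetch_add(Delta, std::memory_order_relaxed);
  // Increment: value *before* the add.  Decrement: value *after* the add.
  return Delta >= 0 ? Old : Old + Delta;
}

int main() {
  std::atomic<int> C{5};
  assert(updateCounter(C, 1) == 5);  // IncrementCounter -> pre-increment value
  assert(C.load() == 6);
  assert(updateCounter(C, -1) == 5); // DecrementCounter -> post-decrement value
  assert(C.load() == 5);
  return 0;
}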
@@ -56,64 +60,93 @@ struct BindingInfoCollector : public InstVisitor<BindingInfoCollector> { : UsedBindings(UsedBindings), ImplicitBindingCalls(ImplicitBindingCalls) { } + void addBinding(uint32_t DescSet, uint32_t Binding) { + if (UsedBindings.size() <= DescSet) { + UsedBindings.resize(DescSet + 1); + UsedBindings[DescSet].resize(64); + } + if (UsedBindings[DescSet].size() <= Binding) { + UsedBindings[DescSet].resize(2 * Binding + 1); + } + UsedBindings[DescSet].set(Binding); + } + void visitCallInst(CallInst &CI) { if (CI.getIntrinsicID() == Intrinsic::spv_resource_handlefrombinding) { const uint32_t DescSet = cast<ConstantInt>(CI.getArgOperand(0))->getZExtValue(); const uint32_t Binding = cast<ConstantInt>(CI.getArgOperand(1))->getZExtValue(); - - if (UsedBindings.size() <= DescSet) { - UsedBindings.resize(DescSet + 1); - UsedBindings[DescSet].resize(64); - } - if (UsedBindings[DescSet].size() <= Binding) { - UsedBindings[DescSet].resize(2 * Binding + 1); - } - UsedBindings[DescSet].set(Binding); + addBinding(DescSet, Binding); } else if (CI.getIntrinsicID() == Intrinsic::spv_resource_handlefromimplicitbinding) { ImplicitBindingCalls.push_back(&CI); + } else if (CI.getIntrinsicID() == + Intrinsic::spv_resource_counterhandlefrombinding) { + const uint32_t DescSet = + cast<ConstantInt>(CI.getArgOperand(2))->getZExtValue(); + const uint32_t Binding = + cast<ConstantInt>(CI.getArgOperand(1))->getZExtValue(); + addBinding(DescSet, Binding); + } else if (CI.getIntrinsicID() == + Intrinsic::spv_resource_counterhandlefromimplicitbinding) { + ImplicitBindingCalls.push_back(&CI); } } }; +static uint32_t getOrderId(const CallInst *CI) { + uint32_t OrderIdArgIdx = 0; + switch (CI->getIntrinsicID()) { + case Intrinsic::spv_resource_handlefromimplicitbinding: + OrderIdArgIdx = 0; + break; + case Intrinsic::spv_resource_counterhandlefromimplicitbinding: + OrderIdArgIdx = 1; + break; + default: + llvm_unreachable("CallInst is not an implicit binding intrinsic"); + } + return cast<ConstantInt>(CI->getArgOperand(OrderIdArgIdx))->getZExtValue(); +} + +static uint32_t getDescSet(const CallInst *CI) { + uint32_t DescSetArgIdx; + switch (CI->getIntrinsicID()) { + case Intrinsic::spv_resource_handlefromimplicitbinding: + case Intrinsic::spv_resource_handlefrombinding: + DescSetArgIdx = 1; + break; + case Intrinsic::spv_resource_counterhandlefromimplicitbinding: + case Intrinsic::spv_resource_counterhandlefrombinding: + DescSetArgIdx = 2; + break; + default: + llvm_unreachable("CallInst is not an implicit binding intrinsic"); + } + return cast<ConstantInt>(CI->getArgOperand(DescSetArgIdx))->getZExtValue(); +} + void SPIRVLegalizeImplicitBinding::collectBindingInfo(Module &M) { BindingInfoCollector InfoCollector(UsedBindings, ImplicitBindingCalls); InfoCollector.visit(M); // Sort the collected calls by their order ID. - std::sort( - ImplicitBindingCalls.begin(), ImplicitBindingCalls.end(), - [](const CallInst *A, const CallInst *B) { - const uint32_t OrderIdArgIdx = 0; - const uint32_t OrderA = - cast<ConstantInt>(A->getArgOperand(OrderIdArgIdx))->getZExtValue(); - const uint32_t OrderB = - cast<ConstantInt>(B->getArgOperand(OrderIdArgIdx))->getZExtValue(); - return OrderA < OrderB; - }); + std::sort(ImplicitBindingCalls.begin(), ImplicitBindingCalls.end(), + [](const CallInst *A, const CallInst *B) { + return getOrderId(A) < getOrderId(B); + }); } void SPIRVLegalizeImplicitBinding::verifyUniqueOrderIdPerResource( SmallVectorImpl<CallInst *> &Calls) { // Check that the order Id is unique per resource. 
for (uint32_t i = 1; i < Calls.size(); ++i) { - const uint32_t OrderIdArgIdx = 0; - const uint32_t DescSetArgIdx = 1; - const uint32_t OrderA = - cast<ConstantInt>(Calls[i - 1]->getArgOperand(OrderIdArgIdx)) - ->getZExtValue(); - const uint32_t OrderB = - cast<ConstantInt>(Calls[i]->getArgOperand(OrderIdArgIdx)) - ->getZExtValue(); + const uint32_t OrderA = getOrderId(Calls[i - 1]); + const uint32_t OrderB = getOrderId(Calls[i]); if (OrderA == OrderB) { - const uint32_t DescSetA = - cast<ConstantInt>(Calls[i - 1]->getArgOperand(DescSetArgIdx)) - ->getZExtValue(); - const uint32_t DescSetB = - cast<ConstantInt>(Calls[i]->getArgOperand(DescSetArgIdx)) - ->getZExtValue(); + const uint32_t DescSetA = getDescSet(Calls[i - 1]); + const uint32_t DescSetB = getDescSet(Calls[i]); if (DescSetA != DescSetB) { report_fatal_error("Implicit binding calls with the same order ID must " "have the same descriptor set"); @@ -144,36 +177,26 @@ void SPIRVLegalizeImplicitBinding::replaceImplicitBindingCalls(Module &M) { uint32_t lastBindingNumber = -1; for (CallInst *OldCI : ImplicitBindingCalls) { - IRBuilder<> Builder(OldCI); - const uint32_t OrderId = - cast<ConstantInt>(OldCI->getArgOperand(0))->getZExtValue(); - const uint32_t DescSet = - cast<ConstantInt>(OldCI->getArgOperand(1))->getZExtValue(); - - // Reuse an existing binding for this order ID, if one was already assigned. - // Otherwise, assign a new binding. - const uint32_t NewBinding = (lastOrderId == OrderId) - ? lastBindingNumber - : getAndReserveFirstUnusedBinding(DescSet); - lastOrderId = OrderId; - lastBindingNumber = NewBinding; - - SmallVector<Value *, 8> Args; - Args.push_back(Builder.getInt32(DescSet)); - Args.push_back(Builder.getInt32(NewBinding)); - - // Copy the remaining arguments from the old call. - for (uint32_t i = 2; i < OldCI->arg_size(); ++i) { - Args.push_back(OldCI->getArgOperand(i)); + const uint32_t OrderId = getOrderId(OldCI); + uint32_t BindingNumber; + if (OrderId == lastOrderId) { + BindingNumber = lastBindingNumber; + } else { + const uint32_t DescSet = getDescSet(OldCI); + BindingNumber = getAndReserveFirstUnusedBinding(DescSet); } - Function *NewFunc = Intrinsic::getOrInsertDeclaration( - &M, Intrinsic::spv_resource_handlefrombinding, OldCI->getType()); - CallInst *NewCI = Builder.CreateCall(NewFunc, Args); - NewCI->setCallingConv(OldCI->getCallingConv()); - - OldCI->replaceAllUsesWith(NewCI); - OldCI->eraseFromParent(); + if (OldCI->getIntrinsicID() == + Intrinsic::spv_resource_handlefromimplicitbinding) { + replaceResourceHandleCall(M, OldCI, BindingNumber); + } else { + assert(OldCI->getIntrinsicID() == + Intrinsic::spv_resource_counterhandlefromimplicitbinding && + "Unexpected implicit binding intrinsic"); + replaceCounterHandleCall(M, OldCI, BindingNumber); + } + lastOrderId = OrderId; + lastBindingNumber = BindingNumber; } } @@ -196,4 +219,49 @@ INITIALIZE_PASS(SPIRVLegalizeImplicitBinding, "legalize-spirv-implicit-binding", ModulePass *llvm::createSPIRVLegalizeImplicitBindingPass() { return new SPIRVLegalizeImplicitBinding(); -}
\ No newline at end of file +} + +void SPIRVLegalizeImplicitBinding::replaceResourceHandleCall( + Module &M, CallInst *OldCI, uint32_t NewBinding) { + IRBuilder<> Builder(OldCI); + const uint32_t DescSet = + cast<ConstantInt>(OldCI->getArgOperand(1))->getZExtValue(); + + SmallVector<Value *, 8> Args; + Args.push_back(Builder.getInt32(DescSet)); + Args.push_back(Builder.getInt32(NewBinding)); + + // Copy the remaining arguments from the old call. + for (uint32_t i = 2; i < OldCI->arg_size(); ++i) { + Args.push_back(OldCI->getArgOperand(i)); + } + + Function *NewFunc = Intrinsic::getOrInsertDeclaration( + &M, Intrinsic::spv_resource_handlefrombinding, OldCI->getType()); + CallInst *NewCI = Builder.CreateCall(NewFunc, Args); + NewCI->setCallingConv(OldCI->getCallingConv()); + + OldCI->replaceAllUsesWith(NewCI); + OldCI->eraseFromParent(); +} + +void SPIRVLegalizeImplicitBinding::replaceCounterHandleCall( + Module &M, CallInst *OldCI, uint32_t NewBinding) { + IRBuilder<> Builder(OldCI); + const uint32_t DescSet = + cast<ConstantInt>(OldCI->getArgOperand(2))->getZExtValue(); + + SmallVector<Value *, 8> Args; + Args.push_back(OldCI->getArgOperand(0)); + Args.push_back(Builder.getInt32(NewBinding)); + Args.push_back(Builder.getInt32(DescSet)); + + Type *Tys[] = {OldCI->getType(), OldCI->getArgOperand(0)->getType()}; + Function *NewFunc = Intrinsic::getOrInsertDeclaration( + &M, Intrinsic::spv_resource_counterhandlefrombinding, Tys); + CallInst *NewCI = Builder.CreateCall(NewFunc, Args); + NewCI->setCallingConv(OldCI->getCallingConv()); + + OldCI->replaceAllUsesWith(NewCI); + OldCI->eraseFromParent(); +} diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index 327c011..1d47c89 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -385,6 +385,12 @@ uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) { return MI->getOperand(1).getCImm()->getValue().getZExtValue(); } +int64_t getIConstValSext(Register ConstReg, const MachineRegisterInfo *MRI) { + const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI); + assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT); + return MI->getOperand(1).getCImm()->getSExtValue(); +} + bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) { if (const auto *GI = dyn_cast<GIntrinsic>(&MI)) return GI->is(IntrinsicID); diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index 409a0fd..5777a24 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -289,6 +289,9 @@ MachineInstr *getDefInstrMaybeConstant(Register &ConstReg, // Get constant integer value of the given ConstReg. uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI); +// Get constant integer value of the given ConstReg, sign-extended. +int64_t getIConstValSext(Register ConstReg, const MachineRegisterInfo *MRI); + // Check if MI is a SPIR-V specific intrinsic call. bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID); // Check if it's a SPIR-V specific intrinsic call. 
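
Taken together, the getOrderId/getDescSet helpers and the two replace* functions above implement a simple policy: calls that share an order ID (a resource handle and its counter handle) receive the same binding, while a new order ID reserves the first unused binding in its descriptor set. A rough model of that policy, with simple containers standing in for the pass's per-descriptor-set BitVectors (all names below are illustrative):

    #include <cassert>
    #include <cstdint>
    #include <map>
    #include <set>

    struct BindingAllocator {
      std::map<uint32_t, std::set<uint32_t>> Used; // DescSet -> used bindings
      std::map<uint32_t, uint32_t> ByOrderId;      // OrderId -> assigned binding

      uint32_t assign(uint32_t OrderId, uint32_t DescSet) {
        auto It = ByOrderId.find(OrderId);
        if (It != ByOrderId.end())
          return It->second; // same order ID (e.g. the counter handle): reuse
        uint32_t B = 0;
        while (Used[DescSet].count(B))
          ++B;               // first unused binding in this descriptor set
        Used[DescSet].insert(B);
        return ByOrderId[OrderId] = B;
      }
    };

    int main() {
      BindingAllocator A;
      A.Used[0] = {0, 2};                      // explicit bindings already taken
      assert(A.assign(/*OrderId=*/7, /*DescSet=*/0) == 1);
      assert(A.assign(/*OrderId=*/7, /*DescSet=*/0) == 1); // counter reuses it
      assert(A.assign(/*OrderId=*/9, /*DescSet=*/0) == 3);
      return 0;
    }
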
diff --git a/llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp b/llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp index 3090ad3..27fba34 100644 --- a/llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp +++ b/llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp @@ -407,6 +407,7 @@ bool X86InstructionSelector::select(MachineInstr &I) { case TargetOpcode::G_TRUNC: return selectTruncOrPtrToInt(I, MRI, MF); case TargetOpcode::G_INTTOPTR: + case TargetOpcode::G_FREEZE: return selectCopy(I, MRI); case TargetOpcode::G_ZEXT: return selectZext(I, MRI, MF); diff --git a/llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp b/llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp index e7709ef..11ef721 100644 --- a/llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp +++ b/llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp @@ -89,9 +89,29 @@ X86LegalizerInfo::X86LegalizerInfo(const X86Subtarget &STI, // 32/64-bits needs support for s64/s128 to handle cases: // s64 = EXTEND (G_IMPLICIT_DEF s32) -> s64 = G_IMPLICIT_DEF // s128 = EXTEND (G_IMPLICIT_DEF s32/s64) -> s128 = G_IMPLICIT_DEF - getActionDefinitionsBuilder(G_IMPLICIT_DEF) + getActionDefinitionsBuilder( + {G_IMPLICIT_DEF, G_PHI, G_FREEZE, G_CONSTANT_FOLD_BARRIER}) .legalFor({p0, s1, s8, s16, s32, s64}) - .legalFor(Is64Bit, {s128}); + .legalFor(UseX87, {s80}) + .legalFor(Is64Bit, {s128}) + .legalFor(HasSSE2, {v16s8, v8s16, v4s32, v2s64}) + .legalFor(HasAVX, {v32s8, v16s16, v8s32, v4s64}) + .legalFor(HasAVX512, {v64s8, v32s16, v16s32, v8s64}) + .widenScalarOrEltToNextPow2(0, /*Min=*/8) + .clampScalarOrElt(0, s8, sMaxScalar) + .moreElementsToNextPow2(0) + .clampMinNumElements(0, s8, 16) + .clampMinNumElements(0, s16, 8) + .clampMinNumElements(0, s32, 4) + .clampMinNumElements(0, s64, 2) + .clampMaxNumElements(0, s8, HasAVX512 ? 64 : (HasAVX ? 32 : 16)) + .clampMaxNumElements(0, s16, HasAVX512 ? 32 : (HasAVX ? 16 : 8)) + .clampMaxNumElements(0, s32, HasAVX512 ? 16 : (HasAVX ? 8 : 4)) + .clampMaxNumElements(0, s64, HasAVX512 ? 8 : (HasAVX ? 4 : 2)) + .clampMaxNumElements(0, p0, + Is64Bit ? s64MaxVector.getNumElements() + : s32MaxVector.getNumElements()) + .scalarizeIf(scalarOrEltWiderThan(0, 64), 0); getActionDefinitionsBuilder(G_CONSTANT) .legalFor({p0, s8, s16, s32}) @@ -289,26 +309,6 @@ X86LegalizerInfo::X86LegalizerInfo(const X86Subtarget &STI, .clampScalar(1, s16, sMaxScalar) .scalarSameSizeAs(0, 1); - // control flow - getActionDefinitionsBuilder(G_PHI) - .legalFor({s8, s16, s32, p0}) - .legalFor(UseX87, {s80}) - .legalFor(Is64Bit, {s64}) - .legalFor(HasSSE1, {v16s8, v8s16, v4s32, v2s64}) - .legalFor(HasAVX, {v32s8, v16s16, v8s32, v4s64}) - .legalFor(HasAVX512, {v64s8, v32s16, v16s32, v8s64}) - .clampMinNumElements(0, s8, 16) - .clampMinNumElements(0, s16, 8) - .clampMinNumElements(0, s32, 4) - .clampMinNumElements(0, s64, 2) - .clampMaxNumElements(0, s8, HasAVX512 ? 64 : (HasAVX ? 32 : 16)) - .clampMaxNumElements(0, s16, HasAVX512 ? 32 : (HasAVX ? 16 : 8)) - .clampMaxNumElements(0, s32, HasAVX512 ? 16 : (HasAVX ? 8 : 4)) - .clampMaxNumElements(0, s64, HasAVX512 ? 8 : (HasAVX ? 
4 : 2)) - .widenScalarToNextPow2(0, /*Min=*/32) - .clampScalar(0, s8, sMaxScalar) - .scalarize(0); - getActionDefinitionsBuilder(G_BRCOND).legalFor({s1}); // pointer handling @@ -592,11 +592,6 @@ X86LegalizerInfo::X86LegalizerInfo(const X86Subtarget &STI, .minScalar(0, LLT::scalar(32)) .libcall(); - getActionDefinitionsBuilder({G_FREEZE, G_CONSTANT_FOLD_BARRIER}) - .legalFor({s8, s16, s32, s64, p0}) - .widenScalarToNextPow2(0, /*Min=*/8) - .clampScalar(0, s8, sMaxScalar); - getLegacyLegalizerInfo().computeTables(); verify(*STI.getInstrInfo()); } diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td index 564810c..83bd6ac 100644 --- a/llvm/lib/Target/X86/X86InstrAVX512.td +++ b/llvm/lib/Target/X86/X86InstrAVX512.td @@ -662,6 +662,7 @@ def VINSERTPSZrri : AVX512AIi8<0x21, MRMSrcReg, (outs VR128X:$dst), "vinsertps\t{$src3, $src2, $src1, $dst|$dst, $src1, $src2, $src3}", [(set VR128X:$dst, (X86insertps VR128X:$src1, VR128X:$src2, timm:$src3))]>, EVEX, VVVV, Sched<[SchedWriteFShuffle.XMM]>; +let mayLoad = 1 in def VINSERTPSZrmi : AVX512AIi8<0x21, MRMSrcMem, (outs VR128X:$dst), (ins VR128X:$src1, f32mem:$src2, u8imm:$src3), "vinsertps\t{$src3, $src2, $src1, $dst|$dst, $src1, $src2, $src3}", @@ -1293,6 +1294,7 @@ multiclass avx512_subvec_broadcast_rm<bits<8> opc, string OpcodeStr, SDPatternOperator OpNode, X86VectorVTInfo _Dst, X86VectorVTInfo _Src> { + let hasSideEffects = 0, mayLoad = 1 in defm rm : AVX512_maskable<opc, MRMSrcMem, _Dst, (outs _Dst.RC:$dst), (ins _Src.MemOp:$src), OpcodeStr, "$src", "$src", (_Dst.VT (OpNode addr:$src))>, @@ -1748,6 +1750,7 @@ let Constraints = "$src1 = $dst", ExeDomain = _.ExeDomain in { (_.VT (X86VPermt2 _.RC:$src1, IdxVT.RC:$src2, _.RC:$src3)), 1>, EVEX, VVVV, AVX5128IBase, Sched<[sched]>; + let hasSideEffects = 0, mayLoad = 1 in defm rm: AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins IdxVT.RC:$src2, _.MemOp:$src3), OpcodeStr, "$src3, $src2", "$src2, $src3", @@ -1759,7 +1762,7 @@ let Constraints = "$src1 = $dst", ExeDomain = _.ExeDomain in { multiclass avx512_perm_t_mb<bits<8> opc, string OpcodeStr, X86FoldableSchedWrite sched, X86VectorVTInfo _, X86VectorVTInfo IdxVT> { - let Constraints = "$src1 = $dst", ExeDomain = _.ExeDomain in + let Constraints = "$src1 = $dst", ExeDomain = _.ExeDomain, hasSideEffects = 0, mayLoad = 1 in defm rmb: AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins IdxVT.RC:$src2, _.ScalarMemOp:$src3), OpcodeStr, !strconcat("${src3}", _.BroadcastStr,", $src2"), @@ -1987,6 +1990,7 @@ multiclass avx512_cmp_scalar<X86VectorVTInfo _, SDNode OpNode, SDNode OpNodeSAE, _.FRC:$src2, timm:$cc))]>, EVEX, VVVV, VEX_LIG, Sched<[sched]>, SIMD_EXC; + let mayLoad = 1 in def rmi : AVX512Ii8<0xC2, MRMSrcMem, (outs _.KRC:$dst), (ins _.FRC:$src1, _.ScalarMemOp:$src2, u8imm:$cc), @@ -2145,6 +2149,7 @@ multiclass avx512_icmp_cc<bits<8> opc, string Suffix, PatFrag Frag, (_.VT _.RC:$src2), cond)))]>, EVEX, VVVV, Sched<[sched]>; + let mayLoad = 1 in def rmi : AVX512AIi8<opc, MRMSrcMem, (outs _.KRC:$dst), (ins _.RC:$src1, _.MemOp:$src2, u8imm:$cc), !strconcat("vpcmp", Suffix, @@ -2167,6 +2172,7 @@ multiclass avx512_icmp_cc<bits<8> opc, string Suffix, PatFrag Frag, (_.VT _.RC:$src2), cond))))]>, EVEX, VVVV, EVEX_K, Sched<[sched]>; + let mayLoad = 1 in def rmik : AVX512AIi8<opc, MRMSrcMem, (outs _.KRC:$dst), (ins _.KRCWM:$mask, _.RC:$src1, _.MemOp:$src2, u8imm:$cc), @@ -2198,6 +2204,7 @@ multiclass avx512_icmp_cc_rmb<bits<8> opc, string Suffix, PatFrag Frag, PatFrag Frag_su, X86FoldableSchedWrite 
sched, X86VectorVTInfo _, string Name> : avx512_icmp_cc<opc, Suffix, Frag, Frag_su, sched, _, Name> { + let mayLoad = 1 in { def rmbi : AVX512AIi8<opc, MRMSrcMem, (outs _.KRC:$dst), (ins _.RC:$src1, _.ScalarMemOp:$src2, u8imm:$cc), @@ -2221,6 +2228,7 @@ multiclass avx512_icmp_cc_rmb<bits<8> opc, string Suffix, PatFrag Frag, (_.BroadcastLdFrag addr:$src2), cond))))]>, EVEX, VVVV, EVEX_K, EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>; + } def : Pat<(_.KVT (Frag:$cc (_.BroadcastLdFrag addr:$src2), (_.VT _.RC:$src1), cond)), @@ -2305,6 +2313,7 @@ let Uses = [MXCSR], mayRaiseFPException = 1 in { (X86cmpm_su (_.VT _.RC:$src1), (_.VT _.RC:$src2), timm:$cc), 1>, Sched<[sched]>; + let mayLoad = 1 in { defm rmi : AVX512_maskable_cmp<0xC2, MRMSrcMem, _, (outs _.KRC:$dst),(ins _.RC:$src1, _.MemOp:$src2, u8imm:$cc), "vcmp"#_.Suffix, @@ -2329,6 +2338,7 @@ let Uses = [MXCSR], mayRaiseFPException = 1 in { timm:$cc)>, EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>; } + } // Patterns for selecting with loads in other operand. def : Pat<(X86any_cmpm (_.LdFrag addr:$src2), (_.VT _.RC:$src1), @@ -3771,6 +3781,7 @@ def VMOVDI2PDIZrr : AVX512BI<0x6E, MRMSrcReg, (outs VR128X:$dst), (ins GR32:$src [(set VR128X:$dst, (v4i32 (scalar_to_vector GR32:$src)))]>, EVEX, Sched<[WriteVecMoveFromGpr]>; +let mayLoad = 1 in def VMOVDI2PDIZrm : AVX512BI<0x6E, MRMSrcMem, (outs VR128X:$dst), (ins i32mem:$src), "vmovd\t{$src, $dst|$dst, $src}", [(set VR128X:$dst, @@ -3874,7 +3885,7 @@ def VMOVSS2DIZrr : AVX512BI<0x7E, MRMDestReg, (outs GR32:$dst), // Move Quadword Int to Packed Quadword Int // -let ExeDomain = SSEPackedInt in { +let ExeDomain = SSEPackedInt, mayLoad = 1, hasSideEffects = 0 in { def VMOVQI2PQIZrm : AVX512XSI<0x7E, MRMSrcMem, (outs VR128X:$dst), (ins i64mem:$src), "vmovq\t{$src, $dst|$dst, $src}", @@ -3930,13 +3941,13 @@ multiclass avx512_move_scalar<string asm, SDNode OpNode, PatFrag vzload_frag, (_.VT (OpNode _.RC:$src1, _.RC:$src2)), (_.VT _.RC:$src0))))], _.ExeDomain>, EVEX, VVVV, EVEX_K, Sched<[SchedWriteFShuffle.XMM]>; - let canFoldAsLoad = 1, isReMaterializable = 1 in { + let canFoldAsLoad = 1, isReMaterializable = 1, mayLoad = 1, hasSideEffects = 0 in { def rm : AVX512PI<0x10, MRMSrcMem, (outs _.RC:$dst), (ins _.ScalarMemOp:$src), !strconcat(asm, "\t{$src, $dst|$dst, $src}"), [(set _.RC:$dst, (_.VT (vzload_frag addr:$src)))], _.ExeDomain>, EVEX, Sched<[WriteFLoad]>; // _alt version uses FR32/FR64 register class. 
- let isCodeGenOnly = 1 in + let isCodeGenOnly = 1, mayLoad = 1, hasSideEffects = 0 in def rm_alt : AVX512PI<0x10, MRMSrcMem, (outs _.FRC:$dst), (ins _.ScalarMemOp:$src), !strconcat(asm, "\t{$src, $dst|$dst, $src}"), [(set _.FRC:$dst, (_.ScalarLdFrag addr:$src))], @@ -4557,6 +4568,7 @@ let Predicates = [HasAVX512] in { // AVX-512 - Non-temporals //===----------------------------------------------------------------------===// +let mayLoad = 1, hasSideEffects = 0 in { def VMOVNTDQAZrm : AVX512PI<0x2A, MRMSrcMem, (outs VR512:$dst), (ins i512mem:$src), "vmovntdqa\t{$src, $dst|$dst, $src}", [], SSEPackedInt>, Sched<[SchedWriteVecMoveLSNT.ZMM.RM]>, @@ -4575,11 +4587,12 @@ let Predicates = [HasVLX] in { [], SSEPackedInt>, Sched<[SchedWriteVecMoveLSNT.XMM.RM]>, EVEX, T8, PD, EVEX_V128, EVEX_CD8<64, CD8VF>; } +} multiclass avx512_movnt<bits<8> opc, string OpcodeStr, X86VectorVTInfo _, X86SchedWriteMoveLS Sched, PatFrag st_frag = alignednontemporalstore> { - let SchedRW = [Sched.MR], AddedComplexity = 400 in + let mayStore = 1, SchedRW = [Sched.MR], AddedComplexity = 400 in def mr : AVX512PI<opc, MRMDestMem, (outs), (ins _.MemOp:$dst, _.RC:$src), !strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"), [(st_frag (_.VT _.RC:$src), addr:$dst)], @@ -4682,6 +4695,7 @@ multiclass avx512_binop_rm<bits<8> opc, string OpcodeStr, SDNode OpNode, IsCommutable, IsCommutable>, AVX512BIBase, EVEX, VVVV, Sched<[sched]>; + let mayLoad = 1, hasSideEffects = 0 in defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.MemOp:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", @@ -4694,6 +4708,7 @@ multiclass avx512_binop_rmb<bits<8> opc, string OpcodeStr, SDNode OpNode, X86VectorVTInfo _, X86FoldableSchedWrite sched, bit IsCommutable = 0> : avx512_binop_rm<opc, OpcodeStr, OpNode, _, sched, IsCommutable> { + let mayLoad = 1, hasSideEffects = 0 in defm rmb : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.ScalarMemOp:$src2), OpcodeStr, "${src2}"#_.BroadcastStr#", $src1", @@ -4811,6 +4826,7 @@ multiclass avx512_binop_rm2<bits<8> opc, string OpcodeStr, (_Src.VT _Src.RC:$src2))), IsCommutable>, AVX512BIBase, EVEX, VVVV, Sched<[sched]>; + let mayLoad = 1, hasSideEffects = 0 in { defm rm : AVX512_maskable<opc, MRMSrcMem, _Dst, (outs _Dst.RC:$dst), (ins _Src.RC:$src1, _Src.MemOp:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", @@ -4828,6 +4844,7 @@ multiclass avx512_binop_rm2<bits<8> opc, string OpcodeStr, (_Brdct.VT (_Brdct.BroadcastLdFrag addr:$src2)))))>, AVX512BIBase, EVEX, VVVV, EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>; + } } defm VPADD : avx512_binop_rm_vl_all<0xFC, 0xFD, 0xFE, 0xD4, "vpadd", add, @@ -4893,6 +4910,7 @@ defm VPMULTISHIFTQB : avx512_binop_all<0x83, "vpmultishiftqb", SchedWriteVecALU, multiclass avx512_packs_rmb<bits<8> opc, string OpcodeStr, SDNode OpNode, X86VectorVTInfo _Src, X86VectorVTInfo _Dst, X86FoldableSchedWrite sched> { + let mayLoad = 1, hasSideEffects = 0 in defm rmb : AVX512_maskable<opc, MRMSrcMem, _Dst, (outs _Dst.RC:$dst), (ins _Src.RC:$src1, _Src.ScalarMemOp:$src2), OpcodeStr, @@ -4916,6 +4934,7 @@ multiclass avx512_packs_rm<bits<8> opc, string OpcodeStr, (_Src.VT _Src.RC:$src2))), IsCommutable, IsCommutable>, EVEX_CD8<_Src.EltSize, CD8VF>, EVEX, VVVV, Sched<[sched]>; + let mayLoad = 1, hasSideEffects = 0 in defm rm : AVX512_maskable<opc, MRMSrcMem, _Dst, (outs _Dst.RC:$dst), (ins _Src.RC:$src1, _Src.MemOp:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", @@ -5370,6 +5389,7 @@ multiclass avx512_fp_scalar<bits<8> opc, string 
OpcodeStr,X86VectorVTInfo _, (_.VT (VecNode _.RC:$src1, _.RC:$src2)), "_Int">, Sched<[sched]>; + let mayLoad = 1 in defm rm : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", @@ -5384,6 +5404,7 @@ multiclass avx512_fp_scalar<bits<8> opc, string OpcodeStr,X86VectorVTInfo _, Sched<[sched]> { let isCommutable = IsCommutable; } + let mayLoad = 1 in def rm : I< opc, MRMSrcMem, (outs _.FRC:$dst), (ins _.FRC:$src1, _.ScalarMemOp:$src2), OpcodeStr#"\t{$src2, $src1, $dst|$dst, $src1, $src2}", @@ -5414,6 +5435,7 @@ multiclass avx512_fp_scalar_sae<bits<8> opc, string OpcodeStr,X86VectorVTInfo _, (_.VT (VecNode _.RC:$src1, _.RC:$src2)), "_Int">, Sched<[sched]>, SIMD_EXC; + let mayLoad = 1 in defm rm : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", @@ -5430,6 +5452,7 @@ multiclass avx512_fp_scalar_sae<bits<8> opc, string OpcodeStr,X86VectorVTInfo _, Sched<[sched]> { let isCommutable = IsCommutable; } + let mayLoad = 1 in def rm : I< opc, MRMSrcMem, (outs _.FRC:$dst), (ins _.FRC:$src1, _.ScalarMemOp:$src2), OpcodeStr#"\t{$src2, $src1, $dst|$dst, $src1, $src2}", @@ -5509,6 +5532,7 @@ multiclass avx512_comutable_binop_s<bits<8> opc, string OpcodeStr, Sched<[sched]> { let isCommutable = 1; } + let mayLoad = 1 in def rm : I< opc, MRMSrcMem, (outs _.FRC:$dst), (ins _.FRC:$src1, _.ScalarMemOp:$src2), OpcodeStr#"\t{$src2, $src1, $dst|$dst, $src1, $src2}", @@ -5737,6 +5761,7 @@ multiclass avx512_fp_scalef_p<bits<8> opc, string OpcodeStr, SDNode OpNode, "$src2, $src1", "$src1, $src2", (_.VT (OpNode _.RC:$src1, _.RC:$src2))>, EVEX, VVVV, Sched<[sched]>; + let mayLoad = 1 in { defm rm: AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.MemOp:$src2), OpcodeStr#_.Suffix, "$src2, $src1", "$src1, $src2", @@ -5749,6 +5774,7 @@ multiclass avx512_fp_scalef_p<bits<8> opc, string OpcodeStr, SDNode OpNode, (OpNode _.RC:$src1, (_.VT (_.BroadcastLdFrag addr:$src2)))>, EVEX, VVVV, EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>; } + } } multiclass avx512_fp_scalef_scalar<bits<8> opc, string OpcodeStr, SDNode OpNode, @@ -5759,6 +5785,7 @@ multiclass avx512_fp_scalef_scalar<bits<8> opc, string OpcodeStr, SDNode OpNode, "$src2, $src1", "$src1, $src2", (_.VT (OpNode _.RC:$src1, _.RC:$src2))>, Sched<[sched]>; + let mayLoad = 1 in defm rm: AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr#_.Suffix, "$src2, $src1", "$src1, $src2", @@ -5916,6 +5943,7 @@ multiclass avx512_shift_rmi<bits<8> opc, Format ImmFormR, Format ImmFormM, "$src2, $src1", "$src1, $src2", (_.VT (OpNode _.RC:$src1, (i8 timm:$src2)))>, Sched<[sched]>; + let mayLoad = 1 in defm mi : AVX512_maskable<opc, ImmFormM, _, (outs _.RC:$dst), (ins _.MemOp:$src1, u8imm:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", @@ -5928,7 +5956,7 @@ multiclass avx512_shift_rmi<bits<8> opc, Format ImmFormR, Format ImmFormM, multiclass avx512_shift_rmbi<bits<8> opc, Format ImmFormM, string OpcodeStr, SDNode OpNode, X86FoldableSchedWrite sched, X86VectorVTInfo _> { - let ExeDomain = _.ExeDomain in + let ExeDomain = _.ExeDomain, mayLoad = 1 in defm mbi : AVX512_maskable<opc, ImmFormM, _, (outs _.RC:$dst), (ins _.ScalarMemOp:$src1, u8imm:$src2), OpcodeStr, "$src2, ${src1}"#_.BroadcastStr, "${src1}"#_.BroadcastStr#", $src2", @@ -5946,6 +5974,7 @@ multiclass avx512_shift_rrm<bits<8> opc, string OpcodeStr, SDNode 
OpNode, "$src2, $src1", "$src1, $src2", (_.VT (OpNode _.RC:$src1, (SrcVT VR128X:$src2)))>, AVX512BIBase, EVEX, VVVV, Sched<[sched]>; + let mayLoad = 1 in defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, i128mem:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", @@ -6095,6 +6124,7 @@ multiclass avx512_var_shift<bits<8> opc, string OpcodeStr, SDNode OpNode, "$src2, $src1", "$src1, $src2", (_.VT (OpNode _.RC:$src1, (_.VT _.RC:$src2)))>, AVX5128IBase, EVEX, VVVV, Sched<[sched]>; + let mayLoad = 1 in defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.MemOp:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", @@ -6107,7 +6137,7 @@ multiclass avx512_var_shift<bits<8> opc, string OpcodeStr, SDNode OpNode, multiclass avx512_var_shift_mb<bits<8> opc, string OpcodeStr, SDNode OpNode, X86FoldableSchedWrite sched, X86VectorVTInfo _> { - let ExeDomain = _.ExeDomain in + let ExeDomain = _.ExeDomain, mayLoad = 1 in defm rmb : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.ScalarMemOp:$src2), OpcodeStr, "${src2}"#_.BroadcastStr#", $src1", @@ -6372,6 +6402,7 @@ multiclass avx512_permil_vec<bits<8> OpcVar, string OpcodeStr, SDNode OpNode, (_.VT (OpNode _.RC:$src1, (Ctrl.VT Ctrl.RC:$src2)))>, T8, PD, EVEX, VVVV, Sched<[sched]>; + let mayLoad = 1 in { defm rm: AVX512_maskable<OpcVar, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, Ctrl.MemOp:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", @@ -6389,6 +6420,7 @@ multiclass avx512_permil_vec<bits<8> OpcVar, string OpcodeStr, SDNode OpNode, (Ctrl.VT (Ctrl.BroadcastLdFrag addr:$src2))))>, T8, PD, EVEX, VVVV, EVEX_B, EVEX_CD8<_.EltSize, CD8VF>, Sched<[sched.Folded, sched.ReadAfterFold]>; + } } multiclass avx512_permil_vec_common<string OpcodeStr, bits<8> OpcVar, @@ -7258,6 +7290,7 @@ let ExeDomain = DstVT.ExeDomain, Uses = _Uses, (OpNode (DstVT.VT DstVT.RC:$src1), SrcRC:$src2))]>, EVEX, VVVV, Sched<[sched, ReadDefault, ReadInt2Fpu]>; + let mayLoad = 1 in def rm_Int : SI<opc, MRMSrcMem, (outs DstVT.RC:$dst), (ins DstVT.RC:$src1, x86memop:$src2), asm#"{"#mem#"}\t{$src2, $src1, $dst|$dst, $src1, $src2}", @@ -7400,6 +7433,7 @@ multiclass avx512_cvt_s_int_round<bits<8> opc, X86VectorVTInfo SrcVT, [(set DstVT.RC:$dst, (OpNodeRnd (SrcVT.VT SrcVT.RC:$src),(i32 timm:$rc)))]>, EVEX, VEX_LIG, EVEX_B, EVEX_RC, Sched<[sched]>; + let mayLoad = 1 in def rm_Int : SI<opc, MRMSrcMem, (outs DstVT.RC:$dst), (ins SrcVT.IntScalarMemOp:$src), !strconcat(asm,"\t{$src, $dst|$dst, $src}"), [(set DstVT.RC:$dst, (OpNode @@ -7451,6 +7485,7 @@ multiclass avx512_cvt_s<bits<8> opc, string asm, X86VectorVTInfo SrcVT, !strconcat(asm,"\t{$src, $dst|$dst, $src}"), [(set DstVT.RC:$dst, (OpNode SrcVT.FRC:$src))]>, EVEX, VEX_LIG, Sched<[sched]>, SIMD_EXC; + let mayLoad = 1 in def rm : AVX512<opc, MRMSrcMem, (outs DstVT.RC:$dst), (ins SrcVT.ScalarMemOp:$src), !strconcat(asm,"\t{$src, $dst|$dst, $src}"), [(set DstVT.RC:$dst, (OpNode (SrcVT.ScalarLdFrag addr:$src)))]>, @@ -7572,6 +7607,7 @@ let Predicates = [prd], ExeDomain = _SrcRC.ExeDomain in { !strconcat(asm,"\t{$src, $dst|$dst, $src}"), [(set _DstRC.RC:$dst, (OpNode _SrcRC.FRC:$src))]>, EVEX, VEX_LIG, Sched<[sched]>, SIMD_EXC; + let mayLoad = 1 in def rm : AVX512<opc, MRMSrcMem, (outs _DstRC.RC:$dst), (ins _SrcRC.ScalarMemOp:$src), !strconcat(asm,"\t{$src, $dst|$dst, $src}"), [(set _DstRC.RC:$dst, (OpNode (_SrcRC.ScalarLdFrag addr:$src)))]>, @@ -7587,6 +7623,7 @@ let Predicates = [prd], ExeDomain = _SrcRC.ExeDomain in { !strconcat(asm,"\t{{sae}, $src, $dst|$dst, 
$src, {sae}}"), [(set _DstRC.RC:$dst, (OpNodeSAE (_SrcRC.VT _SrcRC.RC:$src)))]>, EVEX, VEX_LIG, EVEX_B, Sched<[sched]>; + let mayLoad = 1 in def rm_Int : AVX512<opc, MRMSrcMem, (outs _DstRC.RC:$dst), (ins _SrcRC.IntScalarMemOp:$src), !strconcat(asm,"\t{$src, $dst|$dst, $src}"), @@ -7644,6 +7681,7 @@ multiclass avx512_cvt_fp_scalar<bits<8> opc, string OpcodeStr, X86VectorVTInfo _ (_.VT (OpNode (_.VT _.RC:$src1), (_Src.VT _Src.RC:$src2))), "_Int">, EVEX, VVVV, VEX_LIG, Sched<[sched]>; + let mayLoad = 1 in defm rm : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _Src.IntScalarMemOp:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", @@ -7807,6 +7845,7 @@ let Uses = [MXCSR], mayRaiseFPException = 1 in { _.ImmAllZerosV)>, EVEX, Sched<[sched]>; + let mayLoad = 1 in { defm rm : AVX512_maskable_cvt<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins MemOp:$src), (ins _.RC:$src0, MaskRC:$mask, MemOp:$src), @@ -7840,6 +7879,7 @@ let Uses = [MXCSR], mayRaiseFPException = 1 in { _.ImmAllZerosV)>, EVEX, EVEX_B, Sched<[sched.Folded]>; } + } } // Conversion with SAE - suppress all exceptions multiclass avx512_vcvt_fp_sae<bits<8> opc, string OpcodeStr, X86VectorVTInfo _, @@ -8944,6 +8984,7 @@ multiclass avx512_cvtph2ps<X86VectorVTInfo _dest, X86VectorVTInfo _src, (X86any_cvtph2ps (_src.VT _src.RC:$src)), (X86cvtph2ps (_src.VT _src.RC:$src))>, T8, PD, Sched<[sched]>; + let mayLoad = 1 in defm rm : AVX512_maskable_split<0x13, MRMSrcMem, _dest, (outs _dest.RC:$dst), (ins x86memop:$src), "vcvtph2ps", "$src", "$src", (X86any_cvtph2ps (_src.VT ld_dag)), @@ -9161,6 +9202,7 @@ multiclass avx512_fp14_s<bits<8> opc, string OpcodeStr, SDNode OpNode, "$src2, $src1", "$src1, $src2", (OpNode (_.VT _.RC:$src1), (_.VT _.RC:$src2))>, EVEX, VVVV, VEX_LIG, Sched<[sched]>; + let mayLoad = 1 in defm rm : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", @@ -9621,6 +9663,7 @@ multiclass avx512_rndscale_scalar<bits<8> opc, string OpcodeStr, (i32 timm:$src3))), "_Int">, EVEX_B, Sched<[sched]>; + let mayLoad = 1 in defm rmi : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.IntScalarMemOp:$src2, i32u8imm:$src3), OpcodeStr, @@ -9999,6 +10042,7 @@ multiclass avx512_pmovx_common<bits<8> opc, string OpcodeStr, X86FoldableSchedWr (DestInfo.VT (OpNode (SrcInfo.VT SrcInfo.RC:$src)))>, EVEX, Sched<[sched]>; + let mayLoad = 1 in defm rm : AVX512_maskable<opc, MRMSrcMem, DestInfo, (outs DestInfo.RC:$dst), (ins x86memop:$src), OpcodeStr ,"$src", "$src", (DestInfo.VT (LdFrag addr:$src))>, @@ -10601,6 +10645,7 @@ multiclass expand_by_vec_width<bits<8> opc, X86VectorVTInfo _, (null_frag)>, AVX5128IBase, Sched<[sched]>; + let mayLoad = 1 in defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.MemOp:$src1), OpcodeStr, "$src1", "$src1", (null_frag)>, @@ -10673,6 +10718,7 @@ multiclass avx512_unary_fp_packed_imm<bits<8> opc, string OpcodeStr, (OpNode (_.VT _.RC:$src1), (i32 timm:$src2)), (MaskOpNode (_.VT _.RC:$src1), (i32 timm:$src2))>, Sched<[sched]>; + let mayLoad = 1 in { defm rmi : AVX512_maskable_split<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.MemOp:$src1, i32u8imm:$src2), OpcodeStr#_.Suffix, "$src2, $src1", "$src1, $src2", @@ -10691,6 +10737,7 @@ multiclass avx512_unary_fp_packed_imm<bits<8> opc, string OpcodeStr, (i32 timm:$src2))>, EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>; } + } } //handle instruction reg_vec1 = op(reg_vec2,reg_vec3,imm),{sae} @@ -10739,6 +10786,7 @@ 
multiclass avx512_fp_packed_imm<bits<8> opc, string OpcodeStr, SDNode OpNode, (_.VT _.RC:$src2), (i32 timm:$src3))>, Sched<[sched]>; + let mayLoad = 1 in { defm rmi : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.MemOp:$src2, i32u8imm:$src3), OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3", @@ -10755,6 +10803,7 @@ multiclass avx512_fp_packed_imm<bits<8> opc, string OpcodeStr, SDNode OpNode, (i32 timm:$src3))>, EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>; } + } } //handle instruction reg_vec1 = op(reg_vec2,reg_vec3,imm) @@ -10770,6 +10819,7 @@ multiclass avx512_3Op_rm_imm8<bits<8> opc, string OpcodeStr, SDNode OpNode, (SrcInfo.VT SrcInfo.RC:$src2), (i8 timm:$src3)))>, Sched<[sched]>; + let mayLoad = 1 in defm rmi : AVX512_maskable<opc, MRMSrcMem, DestInfo, (outs DestInfo.RC:$dst), (ins SrcInfo.RC:$src1, SrcInfo.MemOp:$src2, u8imm:$src3), OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3", @@ -10788,7 +10838,7 @@ multiclass avx512_3Op_imm8<bits<8> opc, string OpcodeStr, SDNode OpNode, X86FoldableSchedWrite sched, X86VectorVTInfo _>: avx512_3Op_rm_imm8<opc, OpcodeStr, OpNode, sched, _, _>{ - let ExeDomain = _.ExeDomain, ImmT = Imm8 in + let ExeDomain = _.ExeDomain, ImmT = Imm8, mayLoad = 1 in defm rmbi : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.ScalarMemOp:$src2, u8imm:$src3), OpcodeStr, "$src3, ${src2}"#_.BroadcastStr#", $src1", @@ -10811,6 +10861,7 @@ multiclass avx512_fp_scalar_imm<bits<8> opc, string OpcodeStr, SDNode OpNode, (_.VT _.RC:$src2), (i32 timm:$src3))>, Sched<[sched]>; + let mayLoad = 1 in defm rmi : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.IntScalarMemOp:$src2, i32u8imm:$src3), OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3", @@ -10979,6 +11030,7 @@ multiclass avx512_shuff_packed_128_common<bits<8> opc, string OpcodeStr, (CastInfo.VT (X86Shuf128 _.RC:$src1, _.RC:$src2, (i8 timm:$src3)))))>, Sched<[sched]>; + let mayLoad = 1 in { defm rmi : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.MemOp:$src2, u8imm:$src3), OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3", @@ -11000,6 +11052,7 @@ multiclass avx512_shuff_packed_128_common<bits<8> opc, string OpcodeStr, (i8 timm:$src3)))))>, EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>; } + } } multiclass avx512_shuff_packed_128<string OpcodeStr, X86FoldableSchedWrite sched, @@ -11031,6 +11084,7 @@ multiclass avx512_valign<bits<8> opc, string OpcodeStr, OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3", (_.VT (X86VAlign _.RC:$src1, _.RC:$src2, (i8 timm:$src3)))>, Sched<[sched]>; + let mayLoad = 1 in { defm rmi : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src1, _.MemOp:$src2, u8imm:$src3), OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3", @@ -11048,6 +11102,7 @@ multiclass avx512_valign<bits<8> opc, string OpcodeStr, (i8 timm:$src3))>, EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>; } + } } multiclass avx512_valign_common<string OpcodeStr, X86SchedWriteWidths sched, @@ -11202,6 +11257,7 @@ multiclass avx512_unary_rm<bits<8> opc, string OpcodeStr, SDNode OpNode, (_.VT (OpNode (_.VT _.RC:$src1)))>, EVEX, AVX5128IBase, Sched<[sched]>; + let mayLoad = 1 in defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.MemOp:$src1), OpcodeStr, "$src1", "$src1", @@ -11214,6 +11270,7 @@ multiclass avx512_unary_rm<bits<8> opc, string OpcodeStr, SDNode OpNode, multiclass avx512_unary_rmb<bits<8> opc, string OpcodeStr, SDNode OpNode, 
X86FoldableSchedWrite sched, X86VectorVTInfo _> : avx512_unary_rm<opc, OpcodeStr, OpNode, sched, _> { + let mayLoad = 1 in defm rmb : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.ScalarMemOp:$src1), OpcodeStr, "${src1}"#_.BroadcastStr, @@ -11368,6 +11425,7 @@ multiclass avx512_movddup_128<bits<8> opc, string OpcodeStr, (ins _.RC:$src), OpcodeStr, "$src", "$src", (_.VT (X86VBroadcast (_.VT _.RC:$src)))>, EVEX, Sched<[sched]>; + let mayLoad = 1 in defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.ScalarMemOp:$src), OpcodeStr, "$src", "$src", (_.VT (_.BroadcastLdFrag addr:$src))>, @@ -11513,6 +11571,7 @@ defm VPEXTRQZ : avx512_extract_elt_dq<"vpextrq", v2i64x_info, GR64>, REX_W; multiclass avx512_insert_elt_m<bits<8> opc, string OpcodeStr, SDNode OpNode, X86VectorVTInfo _, PatFrag LdFrag, SDPatternOperator immoperator> { + let mayLoad = 1 in def rmi : AVX512Ii8<opc, MRMSrcMem, (outs _.RC:$dst), (ins _.RC:$src1, _.ScalarMemOp:$src2, u8imm:$src3), OpcodeStr#"\t{$src3, $src2, $src1, $dst|$dst, $src1, $src2, $src3}", @@ -11650,6 +11709,7 @@ multiclass avx512_psadbw_packed<bits<8> opc, SDNode OpNode, (OpNode (_src.VT _src.RC:$src1), (_src.VT _src.RC:$src2))))]>, Sched<[sched]>; + let mayLoad = 1 in def rm : AVX512BI<opc, MRMSrcMem, (outs _dst.RC:$dst), (ins _src.RC:$src1, _src.MemOp:$src2), !strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"), @@ -11751,6 +11811,7 @@ multiclass avx512_ternlog<bits<8> opc, string OpcodeStr, SDNode OpNode, (_.VT _.RC:$src3), (i8 timm:$src4)), 1, 1>, AVX512AIi8Base, EVEX, VVVV, Sched<[sched]>; + let mayLoad = 1 in { defm rmi : AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src2, _.MemOp:$src3, u8imm:$src4), OpcodeStr, "$src4, $src3, $src2", "$src2, $src3, $src4", @@ -11770,6 +11831,7 @@ multiclass avx512_ternlog<bits<8> opc, string OpcodeStr, SDNode OpNode, (i8 timm:$src4)), 1, 0>, EVEX_B, AVX512AIi8Base, EVEX, VVVV, EVEX_CD8<_.EltSize, CD8VF>, Sched<[sched.Folded, sched.ReadAfterFold]>; + } }// Constraints = "$src1 = $dst" // Additional patterns for matching passthru operand in other positions. 
@@ -12016,6 +12078,7 @@ multiclass avx512_fixupimm_packed<bits<8> opc, string OpcodeStr, (_.VT _.RC:$src2), (TblVT.VT _.RC:$src3), (i32 timm:$src4))>, Sched<[sched]>; + let mayLoad = 1 in { defm rmi : AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src2, _.MemOp:$src3, i32u8imm:$src4), OpcodeStr#_.Suffix, "$src4, $src3, $src2", "$src2, $src3, $src4", @@ -12033,6 +12096,7 @@ multiclass avx512_fixupimm_packed<bits<8> opc, string OpcodeStr, (TblVT.VT (TblVT.BroadcastLdFrag addr:$src3)), (i32 timm:$src4))>, EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>; + } } // Constraints = "$src1 = $dst" } @@ -12075,6 +12139,7 @@ multiclass avx512_fixupimm_scalar<bits<8> opc, string OpcodeStr, (_src3VT.VT _src3VT.RC:$src3), (i32 timm:$src4))>, EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>; + let mayLoad = 1 in defm rmi : AVX512_maskable_3src_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst), (ins _.RC:$src2, _.ScalarMemOp:$src3, i32u8imm:$src4), OpcodeStr#_.Suffix, "$src4, $src3, $src2", "$src2, $src3, $src4", @@ -12417,6 +12482,7 @@ multiclass VNNI_rmb<bits<8> Op, string OpStr, SDNode OpNode, VTI.RC:$src2, VTI.RC:$src3)), IsCommutable, IsCommutable>, EVEX, VVVV, T8, Sched<[sched]>; + let mayLoad = 1 in { defm rm : AVX512_maskable_3src<Op, MRMSrcMem, VTI, (outs VTI.RC:$dst), (ins VTI.RC:$src2, VTI.MemOp:$src3), OpStr, "$src3, $src2", "$src2, $src3", @@ -12435,6 +12501,7 @@ multiclass VNNI_rmb<bits<8> Op, string OpStr, SDNode OpNode, T8, Sched<[sched.Folded, sched.ReadAfterFold, sched.ReadAfterFold]>; } + } } multiclass VNNI_common<bits<8> Op, string OpStr, SDNode OpNode, @@ -12508,6 +12575,7 @@ multiclass VPSHUFBITQMB_rm<X86FoldableSchedWrite sched, X86VectorVTInfo VTI> { (X86Vpshufbitqmb_su (VTI.VT VTI.RC:$src1), (VTI.VT VTI.RC:$src2))>, EVEX, VVVV, T8, PD, Sched<[sched]>; + let mayLoad = 1 in defm rm : AVX512_maskable_cmp<0x8F, MRMSrcMem, VTI, (outs VTI.KRC:$dst), (ins VTI.RC:$src1, VTI.MemOp:$src2), "vpshufbitqmb", @@ -12557,7 +12625,7 @@ multiclass GF2P8AFFINE_avx512_rmb_imm<bits<8> Op, string OpStr, SDNode OpNode, X86FoldableSchedWrite sched, X86VectorVTInfo VTI, X86VectorVTInfo BcstVTI> : avx512_3Op_rm_imm8<Op, OpStr, OpNode, sched, VTI, VTI> { - let ExeDomain = VTI.ExeDomain in + let ExeDomain = VTI.ExeDomain, mayLoad = 1 in defm rmbi : AVX512_maskable<Op, MRMSrcMem, VTI, (outs VTI.RC:$dst), (ins VTI.RC:$src1, BcstVTI.ScalarMemOp:$src2, u8imm:$src3), OpStr, "$src3, ${src2}"#BcstVTI.BroadcastStr#", $src1", @@ -12660,6 +12728,7 @@ multiclass avx512_vp2intersect_modes<X86FoldableSchedWrite sched, X86VectorVTInf _.RC:$src1, (_.VT _.RC:$src2)))]>, EVEX, VVVV, T8, XD, Sched<[sched]>; + let mayLoad = 1 in { def rm : I<0x68, MRMSrcMem, (outs _.KRPC:$dst), (ins _.RC:$src1, _.MemOp:$src2), @@ -12679,6 +12748,7 @@ multiclass avx512_vp2intersect_modes<X86FoldableSchedWrite sched, X86VectorVTInf _.RC:$src1, (_.VT (_.BroadcastLdFrag addr:$src2))))]>, EVEX, VVVV, T8, XD, EVEX_B, EVEX_CD8<_.EltSize, CD8VF>, Sched<[sched.Folded, sched.ReadAfterFold]>; + } } multiclass avx512_vp2intersect<X86SchedWriteWidths sched, AVX512VLVectorVTInfo _> { @@ -12882,6 +12952,7 @@ let Predicates = [HasFP16] in { // Move word ( r/m16) to Packed word def VMOVW2SHrr : AVX512<0x6E, MRMSrcReg, (outs VR128X:$dst), (ins GR32:$src), "vmovw\t{$src, $dst|$dst, $src}", []>, T_MAP5, PD, EVEX, Sched<[WriteVecMoveFromGpr]>; +let mayLoad = 1 in def VMOVWrm : AVX512<0x6E, MRMSrcMem, (outs VR128X:$dst), (ins i16mem:$src), "vmovw\t{$src, $dst|$dst, $src}", [(set VR128X:$dst, @@ -13607,6 +13678,7 @@ multiclass 
avx512_cfmbinop_sh_common<bits<8> opc, string OpcodeStr, SDNode OpNod (v4f32 (OpNode VR128X:$src1, VR128X:$src2)), IsCommutable, IsCommutable, IsCommutable, X86selects, "@earlyclobber $dst">, Sched<[WriteFMAX]>; + let mayLoad = 1 in defm rm : AVX512_maskable<opc, MRMSrcMem, f32x_info, (outs VR128X:$dst), (ins VR128X:$src1, ssmem:$src2), OpcodeStr, "$src2, $src1", "$src1, $src2", diff --git a/llvm/lib/TargetParser/TargetParser.cpp b/llvm/lib/TargetParser/TargetParser.cpp index 34b09b1..62a3c88 100644 --- a/llvm/lib/TargetParser/TargetParser.cpp +++ b/llvm/lib/TargetParser/TargetParser.cpp @@ -444,6 +444,7 @@ static void fillAMDGCNFeatureMap(StringRef GPU, const Triple &T, Features["atomic-fmin-fmax-global-f32"] = true; Features["atomic-fmin-fmax-global-f64"] = true; Features["wavefrontsize32"] = true; + Features["clusters"] = true; break; case GK_GFX1201: case GK_GFX1200: diff --git a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp index 8d9a0e7..50130da 100644 --- a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp @@ -2067,6 +2067,36 @@ static void inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes, AI.run(SCCNodes, Changed); } +// Determines if the function 'F' can be marked 'norecurse'. +// It returns true if any call within 'F' could lead to a recursive +// call back to 'F', and false otherwise. +// The 'AnyFunctionsAddressIsTaken' parameter is a module-wide flag +// that is true if any function's address is taken, or if any function +// has external linkage. This is used to determine the safety of +// external/library calls. +static bool mayHaveRecursiveCallee(Function &F, + bool AnyFunctionsAddressIsTaken = true) { + for (const auto &BB : F) { + for (const auto &I : BB.instructionsWithoutDebug()) { + if (const auto *CB = dyn_cast<CallBase>(&I)) { + const Function *Callee = CB->getCalledFunction(); + if (!Callee || Callee == &F) + return true; + + if (Callee->doesNotRecurse()) + continue; + + if (!AnyFunctionsAddressIsTaken || + (Callee->isDeclaration() && + Callee->hasFnAttribute(Attribute::NoCallback))) + continue; + return true; + } + } + } + return false; +} + static void addNoRecurseAttrs(const SCCNodeSet &SCCNodes, SmallPtrSet<Function *, 8> &Changed) { // Try and identify functions that do not recurse. @@ -2078,28 +2108,14 @@ static void addNoRecurseAttrs(const SCCNodeSet &SCCNodes, Function *F = *SCCNodes.begin(); if (!F || !F->hasExactDefinition() || F->doesNotRecurse()) return; - - // If all of the calls in F are identifiable and are to norecurse functions, F - // is norecurse. This check also detects self-recursion as F is not currently - // marked norecurse, so any called from F to F will not be marked norecurse. - for (auto &BB : *F) - for (auto &I : BB.instructionsWithoutDebug()) - if (auto *CB = dyn_cast<CallBase>(&I)) { - Function *Callee = CB->getCalledFunction(); - if (!Callee || Callee == F || - (!Callee->doesNotRecurse() && - !(Callee->isDeclaration() && - Callee->hasFnAttribute(Attribute::NoCallback)))) - // Function calls a potentially recursive function. - return; - } - - // Every call was to a non-recursive function other than this function, and - // we have no indirect recursion as the SCC size is one. This function cannot - // recurse. - F->setDoesNotRecurse(); - ++NumNoRecurse; - Changed.insert(F); + if (!mayHaveRecursiveCallee(*F)) { + // Every call was to a non-recursive function other than this function, and + // we have no indirect recursion as the SCC size is one. 
This function + // cannot recurse. + F->setDoesNotRecurse(); + ++NumNoRecurse; + Changed.insert(F); + } } // Set the noreturn function attribute if possible. @@ -2429,3 +2445,62 @@ ReversePostOrderFunctionAttrsPass::run(Module &M, ModuleAnalysisManager &AM) { PA.preserve<LazyCallGraphAnalysis>(); return PA; } + +PreservedAnalyses NoRecurseLTOInferencePass::run(Module &M, + ModuleAnalysisManager &MAM) { + + // Check if any function in the whole program has its address taken or has + // potentially external linkage. + // We use this information when inferring norecurse attribute: If there is + // no function whose address is taken and all functions have internal + // linkage, there is no path for a callback to any user function. + bool AnyFunctionsAddressIsTaken = false; + for (Function &F : M) { + if (F.isDeclaration() || F.doesNotRecurse()) + continue; + if (!F.hasLocalLinkage() || F.hasAddressTaken()) { + AnyFunctionsAddressIsTaken = true; + break; + } + } + + // Run norecurse inference on all RefSCCs in the LazyCallGraph for this + // module. + bool Changed = false; + LazyCallGraph &CG = MAM.getResult<LazyCallGraphAnalysis>(M); + CG.buildRefSCCs(); + + for (LazyCallGraph::RefSCC &RC : CG.postorder_ref_sccs()) { + // Skip any RefSCC that is part of a call cycle. A RefSCC containing more + // than one SCC indicates a recursive relationship involving indirect calls. + if (RC.size() > 1) + continue; + + // RefSCC contains a single-SCC. SCC size > 1 indicates mutually recursive + // functions. Ex: foo1 -> foo2 -> foo3 -> foo1. + LazyCallGraph::SCC &S = *RC.begin(); + if (S.size() > 1) + continue; + + // Get the single function from this SCC. + Function &F = S.begin()->getFunction(); + if (!F.hasExactDefinition() || F.doesNotRecurse()) + continue; + + // If the analysis confirms that this function has no recursive calls + // (either direct, indirect, or through external linkages), + // we can safely apply the norecurse attribute. + if (!mayHaveRecursiveCallee(F, AnyFunctionsAddressIsTaken)) { + F.setDoesNotRecurse(); + ++NumNoRecurse; + Changed = true; + } + } + + PreservedAnalyses PA; + if (Changed) + PA.preserve<LazyCallGraphAnalysis>(); + else + PA = PreservedAnalyses::all(); + return PA; +} diff --git a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp index faeab95..cfdfd94 100644 --- a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp +++ b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp @@ -3986,6 +3986,7 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones( void ModuleCallsiteContextGraph::updateAllocationCall( CallInfo &Call, AllocationType AllocType) { std::string AllocTypeString = getAllocTypeAttributeString(AllocType); + removeAnyExistingAmbiguousAttribute(cast<CallBase>(Call.call())); auto A = llvm::Attribute::get(Call.call()->getFunction()->getContext(), "memprof", AllocTypeString); cast<CallBase>(Call.call())->addFnAttr(A); @@ -5661,9 +5662,10 @@ bool MemProfContextDisambiguation::applyImport(Module &M) { auto *MemProfMD = I.getMetadata(LLVMContext::MD_memprof); // Include allocs that were already assigned a memprof function - // attribute in the statistics. - if (CB->getAttributes().hasFnAttr("memprof")) { - assert(!MemProfMD); + // attribute in the statistics. Only do this for those that do not have + // memprof metadata, since we add an "ambiguous" memprof attribute by + // default. 
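
The norecurse changes above hinge on one question per call site: can this call lead back into the caller? A compact model of that decision, with field names standing in for the IR queries (getCalledFunction, doesNotRecurse, isDeclaration plus nocallback); this is a sketch of the reasoning, not the pass's actual data structures:

    #include <cassert>
    #include <vector>

    struct CalleeInfo {
      bool Known;          // statically known direct callee
      bool IsSelf;         // direct call back to the caller
      bool NoRecurse;      // callee already proven norecurse
      bool DeclNoCallback; // external declaration carrying 'nocallback'
    };

    static bool mayRecurse(const std::vector<CalleeInfo> &Calls,
                           bool AnyFunctionsAddressIsTaken) {
      for (const CalleeInfo &C : Calls) {
        if (!C.Known || C.IsSelf)
          return true;     // indirect or self call: be conservative
        if (C.NoRecurse)
          continue;        // callee proven non-recursive
        if (!AnyFunctionsAddressIsTaken || C.DeclNoCallback)
          continue;        // no path back into user code
        return true;
      }
      return false;
    }

    int main() {
      // One call to a plain external declaration (a typical library call).
      std::vector<CalleeInfo> Calls = {{true, false, false, false}};
      assert(mayRecurse(Calls, /*AnyFunctionsAddressIsTaken=*/true));
      assert(!mayRecurse(Calls, /*AnyFunctionsAddressIsTaken=*/false));
      return 0;
    }

In the LTO inference pass, the module-wide AnyFunctionsAddressIsTaken scan justifies the relaxed case: when no function escapes via a taken address or external linkage, an external callee has no path back into user code.
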
+ if (CB->getAttributes().hasFnAttr("memprof") && !MemProfMD) { CB->getAttributes().getFnAttr("memprof").getValueAsString() == "cold" ? AllocTypeColdThinBackend++ : AllocTypeNotColdThinBackend++; @@ -5740,6 +5742,7 @@ bool MemProfContextDisambiguation::applyImport(Module &M) { // clone J-1 (J==0 is the original clone and does not have a VMaps // entry). CBClone = cast<CallBase>((*VMaps[J - 1])[CB]); + removeAnyExistingAmbiguousAttribute(CBClone); CBClone->addFnAttr(A); ORE.emit(OptimizationRemark(DEBUG_TYPE, "MemprofAttribute", CBClone) << ore::NV("AllocationCall", CBClone) << " in clone " diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 8f60e50..8c8fc69 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -3356,7 +3356,10 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { impliesPoisonOrCond(FalseVal, B, /*Expected=*/false)) { // (A || B) || C --> A || (B | C) return replaceInstUsesWith( - SI, Builder.CreateLogicalOr(A, Builder.CreateOr(B, FalseVal))); + SI, Builder.CreateLogicalOr(A, Builder.CreateOr(B, FalseVal), "", + ProfcheckDisableMetadataFixes + ? nullptr + : cast<SelectInst>(CondVal))); } // (A && B) || (C && B) --> (A || C) && B @@ -3398,7 +3401,10 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { impliesPoisonOrCond(TrueVal, B, /*Expected=*/true)) { // (A && B) && C --> A && (B & C) return replaceInstUsesWith( - SI, Builder.CreateLogicalAnd(A, Builder.CreateAnd(B, TrueVal))); + SI, Builder.CreateLogicalAnd(A, Builder.CreateAnd(B, TrueVal), "", + ProfcheckDisableMetadataFixes + ? nullptr + : cast<SelectInst>(CondVal))); } // (A || B) && (C || B) --> (A && C) || B diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp index af216cd..9693ae6 100644 --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -317,24 +317,29 @@ static Value *simplifyInstruction(SCCPSolver &Solver, // Early exit if we know nothing about X. if (LRange.isFullSet()) return nullptr; - // We are allowed to refine the comparison to either true or false for out - // of range inputs. Here we refine the comparison to true, i.e. we relax - // the range check. - auto NewCR = CR->exactUnionWith(LRange.inverse()); - // TODO: Check if we can narrow the range check to an equality test. - // E.g, for X in [0, 4), X - 3 u< 2 -> X == 3 - if (!NewCR) + auto ConvertCRToICmp = + [&](const std::optional<ConstantRange> &NewCR) -> Value * { + ICmpInst::Predicate Pred; + APInt RHS; + // Check if we can represent NewCR as an icmp predicate. + if (NewCR && NewCR->getEquivalentICmp(Pred, RHS)) { + IRBuilder<NoFolder> Builder(&Inst); + Value *NewICmp = + Builder.CreateICmp(Pred, X, ConstantInt::get(X->getType(), RHS)); + InsertedValues.insert(NewICmp); + return NewICmp; + } return nullptr; - ICmpInst::Predicate Pred; - APInt RHS; - // Check if we can represent NewCR as an icmp predicate. - if (NewCR->getEquivalentICmp(Pred, RHS)) { - IRBuilder<NoFolder> Builder(&Inst); - Value *NewICmp = - Builder.CreateICmp(Pred, X, ConstantInt::get(X->getType(), RHS)); - InsertedValues.insert(NewICmp); - return NewICmp; - } + }; + // We are allowed to refine the comparison to either true or false for out + // of range inputs. + // Here we refine the comparison to false, and check if we can narrow the + // range check to a simpler test. 
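
The SCCPSolver change refines a range comparison against the known range of X in two steps: it first tries to narrow the check (intersection with the known range) and only then relaxes it (union with the inverse). A worked instance of the narrowing, reusing the example from the removed TODO, with X known to lie in [0, 4) and the instruction being `icmp ult (X - 3), 2`:

    #include <cassert>
    #include <cstdint>

    // The compare is true exactly for X in [3, 5); intersecting with the known
    // range [0, 4) leaves only {3}, so the check is equivalent to `X == 3` for
    // every value X can actually take.
    static bool originalCheck(uint32_t X) { return X - 3u < 2u; }
    static bool narrowedCheck(uint32_t X) { return X == 3u; }

    int main() {
      for (uint32_t X = 0; X < 4; ++X)
        assert(originalCheck(X) == narrowedCheck(X));
      return 0;
    }
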
+ if (auto *V = ConvertCRToICmp(CR->exactIntersectWith(LRange))) + return V; + // Here we refine the comparison to true, i.e. we relax the range check. + if (auto *V = ConvertCRToICmp(CR->exactUnionWith(LRange.inverse()))) + return V; } } diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 48055ad..b8cfe3a 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -4895,9 +4895,8 @@ bool SimplifyCFGOpt::simplifyTerminatorOnSelect(Instruction *OldTerm, // We found both of the successors we were looking for. // Create a conditional branch sharing the condition of the select. BranchInst *NewBI = Builder.CreateCondBr(Cond, TrueBB, FalseBB); - if (TrueWeight != FalseWeight) - setBranchWeights(*NewBI, {TrueWeight, FalseWeight}, - /*IsExpected=*/false, /*ElideAllZero=*/true); + setBranchWeights(*NewBI, {TrueWeight, FalseWeight}, + /*IsExpected=*/false, /*ElideAllZero=*/true); } } else if (KeepEdge1 && (KeepEdge2 || TrueBB == FalseBB)) { // Neither of the selected blocks were successors, so this @@ -4982,9 +4981,15 @@ bool SimplifyCFGOpt::simplifyIndirectBrOnSelect(IndirectBrInst *IBI, BasicBlock *TrueBB = TBA->getBasicBlock(); BasicBlock *FalseBB = FBA->getBasicBlock(); + // The select's profile becomes the profile of the conditional branch that + // replaces the indirect branch. + SmallVector<uint32_t> SelectBranchWeights(2); + if (!ProfcheckDisableMetadataFixes) + extractBranchWeights(*SI, SelectBranchWeights); // Perform the actual simplification. - return simplifyTerminatorOnSelect(IBI, SI->getCondition(), TrueBB, FalseBB, 0, - 0); + return simplifyTerminatorOnSelect(IBI, SI->getCondition(), TrueBB, FalseBB, + SelectBranchWeights[0], + SelectBranchWeights[1]); } /// This is called when we find an icmp instruction @@ -5734,15 +5739,66 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) { return Changed; } -static bool casesAreContiguous(SmallVectorImpl<ConstantInt *> &Cases) { +struct ContiguousCasesResult { + ConstantInt *Min; + ConstantInt *Max; + BasicBlock *Dest; + BasicBlock *OtherDest; + SmallVectorImpl<ConstantInt *> *Cases; + SmallVectorImpl<ConstantInt *> *OtherCases; +}; + +static std::optional<ContiguousCasesResult> +findContiguousCases(Value *Condition, SmallVectorImpl<ConstantInt *> &Cases, + SmallVectorImpl<ConstantInt *> &OtherCases, + BasicBlock *Dest, BasicBlock *OtherDest) { assert(Cases.size() >= 1); array_pod_sort(Cases.begin(), Cases.end(), constantIntSortPredicate); - for (size_t I = 1, E = Cases.size(); I != E; ++I) { - if (Cases[I - 1]->getValue() != Cases[I]->getValue() + 1) - return false; + const APInt &Min = Cases.back()->getValue(); + const APInt &Max = Cases.front()->getValue(); + APInt Offset = Max - Min; + size_t ContiguousOffset = Cases.size() - 1; + if (Offset == ContiguousOffset) { + return ContiguousCasesResult{ + /*Min=*/Cases.back(), + /*Max=*/Cases.front(), + /*Dest=*/Dest, + /*OtherDest=*/OtherDest, + /*Cases=*/&Cases, + /*OtherCases=*/&OtherCases, + }; } - return true; + ConstantRange CR = computeConstantRange(Condition, /*ForSigned=*/false); + // If this is a wrapping contiguous range, that is, [Min, OtherMin] + + // [OtherMax, Max] (also [OtherMax, OtherMin]), [OtherMin+1, OtherMax-1] is a + // contiguous range for the other destination. N.B. If CR is not a full range, + // Max+1 is not equal to Min. It's not continuous in arithmetic. 
+ if (Max == CR.getUnsignedMax() && Min == CR.getUnsignedMin()) { + assert(Cases.size() >= 2); + auto *It = + std::adjacent_find(Cases.begin(), Cases.end(), [](auto L, auto R) { + return L->getValue() != R->getValue() + 1; + }); + if (It == Cases.end()) + return std::nullopt; + auto [OtherMax, OtherMin] = std::make_pair(*It, *std::next(It)); + if ((Max - OtherMax->getValue()) + (OtherMin->getValue() - Min) == + Cases.size() - 2) { + return ContiguousCasesResult{ + /*Min=*/cast<ConstantInt>( + ConstantInt::get(OtherMin->getType(), OtherMin->getValue() + 1)), + /*Max=*/ + cast<ConstantInt>( + ConstantInt::get(OtherMax->getType(), OtherMax->getValue() - 1)), + /*Dest=*/OtherDest, + /*OtherDest=*/Dest, + /*Cases=*/&OtherCases, + /*OtherCases=*/&Cases, + }; + } + } + return std::nullopt; } static void createUnreachableSwitchDefault(SwitchInst *Switch, @@ -5779,7 +5835,6 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI, bool HasDefault = !SI->defaultDestUnreachable(); auto *BB = SI->getParent(); - // Partition the cases into two sets with different destinations. BasicBlock *DestA = HasDefault ? SI->getDefaultDest() : nullptr; BasicBlock *DestB = nullptr; @@ -5813,37 +5868,62 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI, assert(!CasesA.empty() || HasDefault); // Figure out if one of the sets of cases form a contiguous range. - SmallVectorImpl<ConstantInt *> *ContiguousCases = nullptr; - BasicBlock *ContiguousDest = nullptr; - BasicBlock *OtherDest = nullptr; - if (!CasesA.empty() && casesAreContiguous(CasesA)) { - ContiguousCases = &CasesA; - ContiguousDest = DestA; - OtherDest = DestB; - } else if (casesAreContiguous(CasesB)) { - ContiguousCases = &CasesB; - ContiguousDest = DestB; - OtherDest = DestA; - } else - return false; + std::optional<ContiguousCasesResult> ContiguousCases; + + // Only one icmp is needed when there is only one case. + if (!HasDefault && CasesA.size() == 1) + ContiguousCases = ContiguousCasesResult{ + /*Min=*/CasesA[0], + /*Max=*/CasesA[0], + /*Dest=*/DestA, + /*OtherDest=*/DestB, + /*Cases=*/&CasesA, + /*OtherCases=*/&CasesB, + }; + else if (CasesB.size() == 1) + ContiguousCases = ContiguousCasesResult{ + /*Min=*/CasesB[0], + /*Max=*/CasesB[0], + /*Dest=*/DestB, + /*OtherDest=*/DestA, + /*Cases=*/&CasesB, + /*OtherCases=*/&CasesA, + }; + // Correctness: Cases to the default destination cannot be contiguous cases. + else if (!HasDefault) + ContiguousCases = + findContiguousCases(SI->getCondition(), CasesA, CasesB, DestA, DestB); - // Start building the compare and branch. + if (!ContiguousCases) + ContiguousCases = + findContiguousCases(SI->getCondition(), CasesB, CasesA, DestB, DestA); + + if (!ContiguousCases) + return false; - Constant *Offset = ConstantExpr::getNeg(ContiguousCases->back()); - Constant *NumCases = - ConstantInt::get(Offset->getType(), ContiguousCases->size()); + auto [Min, Max, Dest, OtherDest, Cases, OtherCases] = *ContiguousCases; - Value *Sub = SI->getCondition(); - if (!Offset->isNullValue()) - Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off"); + // Start building the compare and branch. 
- Value *Cmp; + Constant *Offset = ConstantExpr::getNeg(Min); + Constant *NumCases = ConstantInt::get(Offset->getType(), + Max->getValue() - Min->getValue() + 1); + BranchInst *NewBI; + if (NumCases->isOneValue()) { + assert(Max->getValue() == Min->getValue()); + Value *Cmp = Builder.CreateICmpEQ(SI->getCondition(), Min); + NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest); + } // If NumCases overflowed, then all possible values jump to the successor. - if (NumCases->isNullValue() && !ContiguousCases->empty()) - Cmp = ConstantInt::getTrue(SI->getContext()); - else - Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch"); - BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest); + else if (NumCases->isNullValue() && !Cases->empty()) { + NewBI = Builder.CreateBr(Dest); + } else { + Value *Sub = SI->getCondition(); + if (!Offset->isNullValue()) + Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off"); + Value *Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch"); + NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest); + } // Update weight for the newly-created conditional branch. if (hasBranchWeightMD(*SI)) { @@ -5853,7 +5933,7 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI, uint64_t TrueWeight = 0; uint64_t FalseWeight = 0; for (size_t I = 0, E = Weights.size(); I != E; ++I) { - if (SI->getSuccessor(I) == ContiguousDest) + if (SI->getSuccessor(I) == Dest) TrueWeight += Weights[I]; else FalseWeight += Weights[I]; @@ -5868,15 +5948,15 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI, } // Prune obsolete incoming values off the successors' PHI nodes. - for (auto BBI = ContiguousDest->begin(); isa<PHINode>(BBI); ++BBI) { - unsigned PreviousEdges = ContiguousCases->size(); - if (ContiguousDest == SI->getDefaultDest()) + for (auto BBI = Dest->begin(); isa<PHINode>(BBI); ++BBI) { + unsigned PreviousEdges = Cases->size(); + if (Dest == SI->getDefaultDest()) ++PreviousEdges; for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I) cast<PHINode>(BBI)->removeIncomingValue(SI->getParent()); } for (auto BBI = OtherDest->begin(); isa<PHINode>(BBI); ++BBI) { - unsigned PreviousEdges = SI->getNumCases() - ContiguousCases->size(); + unsigned PreviousEdges = OtherCases->size(); if (OtherDest == SI->getDefaultDest()) ++PreviousEdges; for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I) @@ -7877,19 +7957,27 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { bool SimplifyCFGOpt::simplifyIndirectBr(IndirectBrInst *IBI) { BasicBlock *BB = IBI->getParent(); bool Changed = false; + SmallVector<uint32_t> BranchWeights; + const bool HasBranchWeights = !ProfcheckDisableMetadataFixes && + extractBranchWeights(*IBI, BranchWeights); + + DenseMap<const BasicBlock *, uint64_t> TargetWeight; + if (HasBranchWeights) + for (size_t I = 0, E = IBI->getNumDestinations(); I < E; ++I) + TargetWeight[IBI->getDestination(I)] += BranchWeights[I]; // Eliminate redundant destinations. 
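
The reworked turnSwitchRangeIntoICmp above reduces a switch whose cases for one destination form a contiguous (possibly wrapping) range to a single subtract-and-compare, with dedicated paths for a one-case range (plain equality) and an all-values range (unconditional branch). A small sketch of the plain contiguous case, with illustrative values:

    #include <cassert>
    #include <cstdint>

    // Illustrative switch: the contiguous cases {3, 4, 5} go to one destination,
    // everything else to the other.
    static bool viaSwitch(uint32_t X) {
      switch (X) {
      case 3: case 4: case 5:
        return true;   // Dest
      default:
        return false;  // OtherDest
      }
    }

    // The transformed form: offset by -Min and compare against the case count,
    // i.e. `(X - Min) u< (Max - Min + 1)`.
    static bool viaICmp(uint32_t X) { return X - 3u < 3u; }

    int main() {
      for (uint32_t X = 0; X < 16; ++X)
        assert(viaSwitch(X) == viaICmp(X));
      return 0;
    }
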
SmallPtrSet<Value *, 8> Succs; SmallSetVector<BasicBlock *, 8> RemovedSuccs; - for (unsigned i = 0, e = IBI->getNumDestinations(); i != e; ++i) { - BasicBlock *Dest = IBI->getDestination(i); + for (unsigned I = 0, E = IBI->getNumDestinations(); I != E; ++I) { + BasicBlock *Dest = IBI->getDestination(I); if (!Dest->hasAddressTaken() || !Succs.insert(Dest).second) { if (!Dest->hasAddressTaken()) RemovedSuccs.insert(Dest); Dest->removePredecessor(BB); - IBI->removeDestination(i); - --i; - --e; + IBI->removeDestination(I); + --I; + --E; Changed = true; } } @@ -7915,7 +8003,12 @@ bool SimplifyCFGOpt::simplifyIndirectBr(IndirectBrInst *IBI) { eraseTerminatorAndDCECond(IBI); return true; } - + if (HasBranchWeights) { + SmallVector<uint64_t> NewBranchWeights(IBI->getNumDestinations()); + for (size_t I = 0, E = IBI->getNumDestinations(); I < E; ++I) + NewBranchWeights[I] += TargetWeight.find(IBI->getDestination(I))->second; + setFittedBranchWeights(*IBI, NewBranchWeights, /*IsExpected=*/false); + } if (SelectInst *SI = dyn_cast<SelectInst>(IBI->getAddress())) { if (simplifyIndirectBrOnSelect(IBI, SI)) return requestResimplify(); diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 56a3d6d..d393a9c 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -8201,211 +8201,6 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, } } -/// Create and return a ResumePhi for \p WideIV, unless it is truncated. If the -/// induction recipe is not canonical, creates a VPDerivedIVRecipe to compute -/// the end value of the induction. -static VPInstruction *addResumePhiRecipeForInduction( - VPWidenInductionRecipe *WideIV, VPBuilder &VectorPHBuilder, - VPBuilder &ScalarPHBuilder, VPTypeAnalysis &TypeInfo, VPValue *VectorTC) { - auto *WideIntOrFp = dyn_cast<VPWidenIntOrFpInductionRecipe>(WideIV); - // Truncated wide inductions resume from the last lane of their vector value - // in the last vector iteration which is handled elsewhere. - if (WideIntOrFp && WideIntOrFp->getTruncInst()) - return nullptr; - - VPValue *Start = WideIV->getStartValue(); - VPValue *Step = WideIV->getStepValue(); - const InductionDescriptor &ID = WideIV->getInductionDescriptor(); - VPValue *EndValue = VectorTC; - if (!WideIntOrFp || !WideIntOrFp->isCanonical()) { - EndValue = VectorPHBuilder.createDerivedIV( - ID.getKind(), dyn_cast_or_null<FPMathOperator>(ID.getInductionBinOp()), - Start, VectorTC, Step); - } - - // EndValue is derived from the vector trip count (which has the same type as - // the widest induction) and thus may be wider than the induction here. - Type *ScalarTypeOfWideIV = TypeInfo.inferScalarType(WideIV); - if (ScalarTypeOfWideIV != TypeInfo.inferScalarType(EndValue)) { - EndValue = VectorPHBuilder.createScalarCast(Instruction::Trunc, EndValue, - ScalarTypeOfWideIV, - WideIV->getDebugLoc()); - } - - auto *ResumePhiRecipe = ScalarPHBuilder.createScalarPhi( - {EndValue, Start}, WideIV->getDebugLoc(), "bc.resume.val"); - return ResumePhiRecipe; -} - -/// Create resume phis in the scalar preheader for first-order recurrences, -/// reductions and inductions, and update the VPIRInstructions wrapping the -/// original phis in the scalar header. End values for inductions are added to -/// \p IVEndValues. 
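// A sketch of the profile bookkeeping above, under stated assumptions:
// std::string stand-ins for BasicBlocks and weights carried over verbatim
// rather than re-fitted by setFittedBranchWeights. The weights are summed per
// unique destination before duplicates are removed, so each surviving
// destination keeps the combined weight of every edge that used to reach it.
#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

struct IndirectBrModel {
  std::vector<std::string> Dests;  // May contain duplicate destinations.
  std::vector<uint32_t> Weights;   // Parallel to Dests.
};

static void dedupeDestinationsWithWeights(IndirectBrModel &IBI) {
  std::map<std::string, uint64_t> TargetWeight;
  for (size_t I = 0, E = IBI.Dests.size(); I != E; ++I)
    TargetWeight[IBI.Dests[I]] += IBI.Weights[I];

  std::vector<std::string> NewDests;
  std::vector<uint32_t> NewWeights;
  std::set<std::string> Seen;
  for (const std::string &D : IBI.Dests) {
    if (!Seen.insert(D).second)
      continue; // Drop the duplicate edge; its weight was already folded in.
    NewDests.push_back(D);
    NewWeights.push_back(static_cast<uint32_t>(TargetWeight[D]));
  }
  IBI.Dests = std::move(NewDests);
  IBI.Weights = std::move(NewWeights);
}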
-static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan, - DenseMap<VPValue *, VPValue *> &IVEndValues) { - VPTypeAnalysis TypeInfo(Plan); - auto *ScalarPH = Plan.getScalarPreheader(); - auto *MiddleVPBB = cast<VPBasicBlock>(ScalarPH->getPredecessors()[0]); - VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion(); - VPBuilder VectorPHBuilder( - cast<VPBasicBlock>(VectorRegion->getSinglePredecessor())); - VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi()); - VPBuilder ScalarPHBuilder(ScalarPH); - for (VPRecipeBase &ScalarPhiR : Plan.getScalarHeader()->phis()) { - auto *ScalarPhiIRI = cast<VPIRPhi>(&ScalarPhiR); - - // TODO: Extract final value from induction recipe initially, optimize to - // pre-computed end value together in optimizeInductionExitUsers. - auto *VectorPhiR = - cast<VPHeaderPHIRecipe>(Builder.getRecipe(&ScalarPhiIRI->getIRPhi())); - if (auto *WideIVR = dyn_cast<VPWidenInductionRecipe>(VectorPhiR)) { - if (VPInstruction *ResumePhi = addResumePhiRecipeForInduction( - WideIVR, VectorPHBuilder, ScalarPHBuilder, TypeInfo, - &Plan.getVectorTripCount())) { - assert(isa<VPPhi>(ResumePhi) && "Expected a phi"); - IVEndValues[WideIVR] = ResumePhi->getOperand(0); - ScalarPhiIRI->addOperand(ResumePhi); - continue; - } - // TODO: Also handle truncated inductions here. Computing end-values - // separately should be done as VPlan-to-VPlan optimization, after - // legalizing all resume values to use the last lane from the loop. - assert(cast<VPWidenIntOrFpInductionRecipe>(VectorPhiR)->getTruncInst() && - "should only skip truncated wide inductions"); - continue; - } - - // The backedge value provides the value to resume coming out of a loop, - // which for FORs is a vector whose last element needs to be extracted. The - // start value provides the value if the loop is bypassed. - bool IsFOR = isa<VPFirstOrderRecurrencePHIRecipe>(VectorPhiR); - auto *ResumeFromVectorLoop = VectorPhiR->getBackedgeValue(); - assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() && - "Cannot handle loops with uncountable early exits"); - if (IsFOR) - ResumeFromVectorLoop = MiddleBuilder.createNaryOp( - VPInstruction::ExtractLastElement, {ResumeFromVectorLoop}, {}, - "vector.recur.extract"); - StringRef Name = IsFOR ? "scalar.recur.init" : "bc.merge.rdx"; - auto *ResumePhiR = ScalarPHBuilder.createScalarPhi( - {ResumeFromVectorLoop, VectorPhiR->getStartValue()}, {}, Name); - ScalarPhiIRI->addOperand(ResumePhiR); - } -} - -/// Handle users in the exit block for first order reductions in the original -/// exit block. The penultimate value of recurrences is fed to their LCSSA phi -/// users in the original exit block using the VPIRInstruction wrapping to the -/// LCSSA phi. 
-static void addExitUsersForFirstOrderRecurrences(VPlan &Plan, VFRange &Range) { - VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion(); - auto *ScalarPHVPBB = Plan.getScalarPreheader(); - auto *MiddleVPBB = Plan.getMiddleBlock(); - VPBuilder ScalarPHBuilder(ScalarPHVPBB); - VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi()); - - auto IsScalableOne = [](ElementCount VF) -> bool { - return VF == ElementCount::getScalable(1); - }; - - for (auto &HeaderPhi : VectorRegion->getEntryBasicBlock()->phis()) { - auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&HeaderPhi); - if (!FOR) - continue; - - assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() && - "Cannot handle loops with uncountable early exits"); - - // This is the second phase of vectorizing first-order recurrences, creating - // extract for users outside the loop. An overview of the transformation is - // described below. Suppose we have the following loop with some use after - // the loop of the last a[i-1], - // - // for (int i = 0; i < n; ++i) { - // t = a[i - 1]; - // b[i] = a[i] - t; - // } - // use t; - // - // There is a first-order recurrence on "a". For this loop, the shorthand - // scalar IR looks like: - // - // scalar.ph: - // s.init = a[-1] - // br scalar.body - // - // scalar.body: - // i = phi [0, scalar.ph], [i+1, scalar.body] - // s1 = phi [s.init, scalar.ph], [s2, scalar.body] - // s2 = a[i] - // b[i] = s2 - s1 - // br cond, scalar.body, exit.block - // - // exit.block: - // use = lcssa.phi [s1, scalar.body] - // - // In this example, s1 is a recurrence because it's value depends on the - // previous iteration. In the first phase of vectorization, we created a - // VPFirstOrderRecurrencePHIRecipe v1 for s1. Now we create the extracts - // for users in the scalar preheader and exit block. - // - // vector.ph: - // v_init = vector(..., ..., ..., a[-1]) - // br vector.body - // - // vector.body - // i = phi [0, vector.ph], [i+4, vector.body] - // v1 = phi [v_init, vector.ph], [v2, vector.body] - // v2 = a[i, i+1, i+2, i+3] - // b[i] = v2 - v1 - // // Next, third phase will introduce v1' = splice(v1(3), v2(0, 1, 2)) - // b[i, i+1, i+2, i+3] = v2 - v1 - // br cond, vector.body, middle.block - // - // middle.block: - // vector.recur.extract.for.phi = v2(2) - // vector.recur.extract = v2(3) - // br cond, scalar.ph, exit.block - // - // scalar.ph: - // scalar.recur.init = phi [vector.recur.extract, middle.block], - // [s.init, otherwise] - // br scalar.body - // - // scalar.body: - // i = phi [0, scalar.ph], [i+1, scalar.body] - // s1 = phi [scalar.recur.init, scalar.ph], [s2, scalar.body] - // s2 = a[i] - // b[i] = s2 - s1 - // br cond, scalar.body, exit.block - // - // exit.block: - // lo = lcssa.phi [s1, scalar.body], - // [vector.recur.extract.for.phi, middle.block] - // - // Now update VPIRInstructions modeling LCSSA phis in the exit block. - // Extract the penultimate value of the recurrence and use it as operand for - // the VPIRInstruction modeling the phi. - for (VPUser *U : FOR->users()) { - using namespace llvm::VPlanPatternMatch; - if (!match(U, m_ExtractLastElement(m_Specific(FOR)))) - continue; - // For VF vscale x 1, if vscale = 1, we are unable to extract the - // penultimate value of the recurrence. Instead we rely on the existing - // extract of the last element from the result of - // VPInstruction::FirstOrderRecurrenceSplice. - // TODO: Consider vscale_range info and UF. 
- if (LoopVectorizationPlanner::getDecisionAndClampRange(IsScalableOne, - Range)) - return; - VPValue *PenultimateElement = MiddleBuilder.createNaryOp( - VPInstruction::ExtractPenultimateElement, {FOR->getBackedgeValue()}, - {}, "vector.recur.extract.for.phi"); - cast<VPInstruction>(U)->replaceAllUsesWith(PenultimateElement); - } - } -} - VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( VPlanPtr Plan, VFRange &Range, LoopVersioning *LVer) { @@ -8598,9 +8393,11 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes( R->setOperand(1, WideIV->getStepValue()); } - addExitUsersForFirstOrderRecurrences(*Plan, Range); + // TODO: We can't call runPass on these transforms yet, due to verifier + // failures. + VPlanTransforms::addExitUsersForFirstOrderRecurrences(*Plan, Range); DenseMap<VPValue *, VPValue *> IVEndValues; - addScalarResumePhis(RecipeBuilder, *Plan, IVEndValues); + VPlanTransforms::addScalarResumePhis(*Plan, RecipeBuilder, IVEndValues); // --------------------------------------------------------------------------- // Transform initial VPlan: Apply previously taken decisions, in order, to @@ -8711,7 +8508,9 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlan(VFRange &Range) { DenseMap<VPValue *, VPValue *> IVEndValues; // TODO: IVEndValues are not used yet in the native path, to optimize exit // values. - addScalarResumePhis(RecipeBuilder, *Plan, IVEndValues); + // TODO: We can't call runPass on the transform yet, due to verifier + // failures. + VPlanTransforms::addScalarResumePhis(*Plan, RecipeBuilder, IVEndValues); assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid"); return Plan; diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index fedca65..91c3d42 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -10620,7 +10620,8 @@ class InstructionsCompatibilityAnalysis { /// Checks if the opcode is supported as the main opcode for copyable /// elements. static bool isSupportedOpcode(const unsigned Opcode) { - return Opcode == Instruction::Add || Opcode == Instruction::LShr; + return Opcode == Instruction::Add || Opcode == Instruction::LShr || + Opcode == Instruction::Shl; } /// Identifies the best candidate value, which represents main opcode @@ -10937,6 +10938,7 @@ public: switch (MainOpcode) { case Instruction::Add: case Instruction::LShr: + case Instruction::Shl: VectorCost = TTI.getArithmeticInstrCost(MainOpcode, VecTy, Kind); break; default: @@ -22006,6 +22008,8 @@ bool BoUpSLP::collectValuesToDemote( return all_of(E.Scalars, [&](Value *V) { if (isa<PoisonValue>(V)) return true; + if (E.isCopyableElement(V)) + return true; auto *I = cast<Instruction>(V); KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL); return AmtKnownBits.getMaxValue().ult(BitWidth); diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index ca63bf3..ebf833e 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -4198,3 +4198,202 @@ void VPlanTransforms::addBranchWeightToMiddleTerminator( MDB.createBranchWeights({1, VectorStep - 1}, /*IsExpected=*/false); MiddleTerm->addMetadata(LLVMContext::MD_prof, BranchWeights); } + +/// Create and return a ResumePhi for \p WideIV, unless it is truncated. If the +/// induction recipe is not canonical, creates a VPDerivedIVRecipe to compute +/// the end value of the induction. 
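// A rough scalar model of the end value that the VPDerivedIVRecipe mentioned
// above computes, under simplifying assumptions: a 64-bit integer induction,
// no FP inductions and no truncation. The scalar loop resumes from
// Start + VectorTripCount * Step once the vector loop has run to completion.
#include <cstdint>

static int64_t inductionEndValue(int64_t Start, int64_t Step,
                                 uint64_t VectorTripCount) {
  return Start + static_cast<int64_t>(VectorTripCount) * Step;
}
// E.g. an induction starting at 100 with step -3 and a vector trip count of 8
// resumes the scalar loop at inductionEndValue(100, -3, 8) == 76.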
+static VPInstruction *addResumePhiRecipeForInduction( + VPWidenInductionRecipe *WideIV, VPBuilder &VectorPHBuilder, + VPBuilder &ScalarPHBuilder, VPTypeAnalysis &TypeInfo, VPValue *VectorTC) { + auto *WideIntOrFp = dyn_cast<VPWidenIntOrFpInductionRecipe>(WideIV); + // Truncated wide inductions resume from the last lane of their vector value + // in the last vector iteration which is handled elsewhere. + if (WideIntOrFp && WideIntOrFp->getTruncInst()) + return nullptr; + + VPValue *Start = WideIV->getStartValue(); + VPValue *Step = WideIV->getStepValue(); + const InductionDescriptor &ID = WideIV->getInductionDescriptor(); + VPValue *EndValue = VectorTC; + if (!WideIntOrFp || !WideIntOrFp->isCanonical()) { + EndValue = VectorPHBuilder.createDerivedIV( + ID.getKind(), dyn_cast_or_null<FPMathOperator>(ID.getInductionBinOp()), + Start, VectorTC, Step); + } + + // EndValue is derived from the vector trip count (which has the same type as + // the widest induction) and thus may be wider than the induction here. + Type *ScalarTypeOfWideIV = TypeInfo.inferScalarType(WideIV); + if (ScalarTypeOfWideIV != TypeInfo.inferScalarType(EndValue)) { + EndValue = VectorPHBuilder.createScalarCast(Instruction::Trunc, EndValue, + ScalarTypeOfWideIV, + WideIV->getDebugLoc()); + } + + auto *ResumePhiRecipe = ScalarPHBuilder.createScalarPhi( + {EndValue, Start}, WideIV->getDebugLoc(), "bc.resume.val"); + return ResumePhiRecipe; +} + +void VPlanTransforms::addScalarResumePhis( + VPlan &Plan, VPRecipeBuilder &Builder, + DenseMap<VPValue *, VPValue *> &IVEndValues) { + VPTypeAnalysis TypeInfo(Plan); + auto *ScalarPH = Plan.getScalarPreheader(); + auto *MiddleVPBB = cast<VPBasicBlock>(ScalarPH->getPredecessors()[0]); + VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion(); + VPBuilder VectorPHBuilder( + cast<VPBasicBlock>(VectorRegion->getSinglePredecessor())); + VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi()); + VPBuilder ScalarPHBuilder(ScalarPH); + for (VPRecipeBase &ScalarPhiR : Plan.getScalarHeader()->phis()) { + auto *ScalarPhiIRI = cast<VPIRPhi>(&ScalarPhiR); + + // TODO: Extract final value from induction recipe initially, optimize to + // pre-computed end value together in optimizeInductionExitUsers. + auto *VectorPhiR = + cast<VPHeaderPHIRecipe>(Builder.getRecipe(&ScalarPhiIRI->getIRPhi())); + if (auto *WideIVR = dyn_cast<VPWidenInductionRecipe>(VectorPhiR)) { + if (VPInstruction *ResumePhi = addResumePhiRecipeForInduction( + WideIVR, VectorPHBuilder, ScalarPHBuilder, TypeInfo, + &Plan.getVectorTripCount())) { + assert(isa<VPPhi>(ResumePhi) && "Expected a phi"); + IVEndValues[WideIVR] = ResumePhi->getOperand(0); + ScalarPhiIRI->addOperand(ResumePhi); + continue; + } + // TODO: Also handle truncated inductions here. Computing end-values + // separately should be done as VPlan-to-VPlan optimization, after + // legalizing all resume values to use the last lane from the loop. + assert(cast<VPWidenIntOrFpInductionRecipe>(VectorPhiR)->getTruncInst() && + "should only skip truncated wide inductions"); + continue; + } + + // The backedge value provides the value to resume coming out of a loop, + // which for FORs is a vector whose last element needs to be extracted. The + // start value provides the value if the loop is bypassed. 
+ bool IsFOR = isa<VPFirstOrderRecurrencePHIRecipe>(VectorPhiR); + auto *ResumeFromVectorLoop = VectorPhiR->getBackedgeValue(); + assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() && + "Cannot handle loops with uncountable early exits"); + if (IsFOR) + ResumeFromVectorLoop = MiddleBuilder.createNaryOp( + VPInstruction::ExtractLastElement, {ResumeFromVectorLoop}, {}, + "vector.recur.extract"); + StringRef Name = IsFOR ? "scalar.recur.init" : "bc.merge.rdx"; + auto *ResumePhiR = ScalarPHBuilder.createScalarPhi( + {ResumeFromVectorLoop, VectorPhiR->getStartValue()}, {}, Name); + ScalarPhiIRI->addOperand(ResumePhiR); + } +} + +void VPlanTransforms::addExitUsersForFirstOrderRecurrences(VPlan &Plan, + VFRange &Range) { + VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion(); + auto *ScalarPHVPBB = Plan.getScalarPreheader(); + auto *MiddleVPBB = Plan.getMiddleBlock(); + VPBuilder ScalarPHBuilder(ScalarPHVPBB); + VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi()); + + auto IsScalableOne = [](ElementCount VF) -> bool { + return VF == ElementCount::getScalable(1); + }; + + for (auto &HeaderPhi : VectorRegion->getEntryBasicBlock()->phis()) { + auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&HeaderPhi); + if (!FOR) + continue; + + assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() && + "Cannot handle loops with uncountable early exits"); + + // This is the second phase of vectorizing first-order recurrences, creating + // extract for users outside the loop. An overview of the transformation is + // described below. Suppose we have the following loop with some use after + // the loop of the last a[i-1], + // + // for (int i = 0; i < n; ++i) { + // t = a[i - 1]; + // b[i] = a[i] - t; + // } + // use t; + // + // There is a first-order recurrence on "a". For this loop, the shorthand + // scalar IR looks like: + // + // scalar.ph: + // s.init = a[-1] + // br scalar.body + // + // scalar.body: + // i = phi [0, scalar.ph], [i+1, scalar.body] + // s1 = phi [s.init, scalar.ph], [s2, scalar.body] + // s2 = a[i] + // b[i] = s2 - s1 + // br cond, scalar.body, exit.block + // + // exit.block: + // use = lcssa.phi [s1, scalar.body] + // + // In this example, s1 is a recurrence because it's value depends on the + // previous iteration. In the first phase of vectorization, we created a + // VPFirstOrderRecurrencePHIRecipe v1 for s1. Now we create the extracts + // for users in the scalar preheader and exit block. 
+ // + // vector.ph: + // v_init = vector(..., ..., ..., a[-1]) + // br vector.body + // + // vector.body + // i = phi [0, vector.ph], [i+4, vector.body] + // v1 = phi [v_init, vector.ph], [v2, vector.body] + // v2 = a[i, i+1, i+2, i+3] + // b[i] = v2 - v1 + // // Next, third phase will introduce v1' = splice(v1(3), v2(0, 1, 2)) + // b[i, i+1, i+2, i+3] = v2 - v1 + // br cond, vector.body, middle.block + // + // middle.block: + // vector.recur.extract.for.phi = v2(2) + // vector.recur.extract = v2(3) + // br cond, scalar.ph, exit.block + // + // scalar.ph: + // scalar.recur.init = phi [vector.recur.extract, middle.block], + // [s.init, otherwise] + // br scalar.body + // + // scalar.body: + // i = phi [0, scalar.ph], [i+1, scalar.body] + // s1 = phi [scalar.recur.init, scalar.ph], [s2, scalar.body] + // s2 = a[i] + // b[i] = s2 - s1 + // br cond, scalar.body, exit.block + // + // exit.block: + // lo = lcssa.phi [s1, scalar.body], + // [vector.recur.extract.for.phi, middle.block] + // + // Now update VPIRInstructions modeling LCSSA phis in the exit block. + // Extract the penultimate value of the recurrence and use it as operand for + // the VPIRInstruction modeling the phi. + for (VPUser *U : FOR->users()) { + using namespace llvm::VPlanPatternMatch; + if (!match(U, m_ExtractLastElement(m_Specific(FOR)))) + continue; + // For VF vscale x 1, if vscale = 1, we are unable to extract the + // penultimate value of the recurrence. Instead we rely on the existing + // extract of the last element from the result of + // VPInstruction::FirstOrderRecurrenceSplice. + // TODO: Consider vscale_range info and UF. + if (LoopVectorizationPlanner::getDecisionAndClampRange(IsScalableOne, + Range)) + return; + VPValue *PenultimateElement = MiddleBuilder.createNaryOp( + VPInstruction::ExtractPenultimateElement, {FOR->getBackedgeValue()}, + {}, "vector.recur.extract.for.phi"); + cast<VPInstruction>(U)->replaceAllUsesWith(PenultimateElement); + } + } +} diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h index 2f00e51..5a8a2bb 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h @@ -363,6 +363,19 @@ struct VPlanTransforms { static void addBranchWeightToMiddleTerminator(VPlan &Plan, ElementCount VF, std::optional<unsigned> VScaleForTuning); + + /// Create resume phis in the scalar preheader for first-order recurrences, + /// reductions and inductions, and update the VPIRInstructions wrapping the + /// original phis in the scalar header. End values for inductions are added to + /// \p IVEndValues. + static void addScalarResumePhis(VPlan &Plan, VPRecipeBuilder &Builder, + DenseMap<VPValue *, VPValue *> &IVEndValues); + + /// Handle users in the exit block for first order reductions in the original + /// exit block. The penultimate value of recurrences is fed to their LCSSA phi + /// users in the original exit block using the VPIRInstruction wrapping to the + /// LCSSA phi. + static void addExitUsersForFirstOrderRecurrences(VPlan &Plan, VFRange &Range); }; } // namespace llvm |
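A minimal scalar model of the two extracts walked through in the comment block
above, assuming a fixed VF of 4 and int elements (the vector and variable names
are illustrative, not part of the patch): the last element of the recurrence
vector from the final vector iteration seeds the scalar resume phi, while the
penultimate element feeds LCSSA users of the previous iteration's value in the
exit block.

    #include <array>
    #include <cstdio>

    int main() {
      // v2 holds a[i .. i+3] from the final vector iteration.
      std::array<int, 4> V2 = {40, 41, 42, 43};
      int VectorRecurExtract = V2[3];       // -> scalar.recur.init on resume.
      int VectorRecurExtractForPhi = V2[2]; // -> the lcssa phi in exit.block.
      std::printf("resume = %d, exit use = %d\n", VectorRecurExtract,
                  VectorRecurExtractForPhi);
      return 0;
    }

This mirrors the "vector.recur.extract = v2(3)" and
"vector.recur.extract.for.phi = v2(2)" lines in the middle.block shown above;
for a scalable VF of vscale x 1 with vscale = 1 there is no penultimate lane,
which is why the transform falls back to the existing last-element extract of
the FirstOrderRecurrenceSplice result in that case.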