Diffstat (limited to 'llvm/lib')
-rw-r--r--  llvm/lib/Analysis/MemoryProfileInfo.cpp  28
-rw-r--r--  llvm/lib/BinaryFormat/DXContainer.cpp  64
-rw-r--r--  llvm/lib/CAS/CMakeLists.txt  1
-rw-r--r--  llvm/lib/CAS/OnDiskDataAllocator.cpp  234
-rw-r--r--  llvm/lib/CAS/OnDiskTrieRawHashMap.cpp  31
-rw-r--r--  llvm/lib/IR/Globals.cpp  1
-rw-r--r--  llvm/lib/Option/ArgList.cpp  38
-rw-r--r--  llvm/lib/Option/OptTable.cpp  76
-rw-r--r--  llvm/lib/Passes/PassBuilderPipelines.cpp  1
-rw-r--r--  llvm/lib/Passes/PassRegistry.def  1
-rw-r--r--  llvm/lib/Target/AArch64/AArch64FrameLowering.cpp  357
-rw-r--r--  llvm/lib/Target/AArch64/AArch64InstrInfo.cpp  64
-rw-r--r--  llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp  3
-rw-r--r--  llvm/lib/Target/AArch64/AArch64RegisterInfo.td  11
-rw-r--r--  llvm/lib/Target/AArch64/AArch64Subtarget.cpp  19
-rw-r--r--  llvm/lib/Target/AArch64/AArch64Subtarget.h  2
-rw-r--r--  llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp  11
-rw-r--r--  llvm/lib/Target/AArch64/SMEInstrFormats.td  14
-rw-r--r--  llvm/lib/Target/AMDGPU/AMDGPU.td  21
-rw-r--r--  llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp  63
-rw-r--r--  llvm/lib/Target/AMDGPU/GCNSubtarget.cpp  2
-rw-r--r--  llvm/lib/Target/AMDGPU/GCNSubtarget.h  4
-rw-r--r--  llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp  11
-rw-r--r--  llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp  7
-rw-r--r--  llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp  10
-rw-r--r--  llvm/lib/Target/AMDGPU/VOP3Instructions.td  31
-rw-r--r--  llvm/lib/Target/AMDGPU/VOPInstructions.td  7
-rw-r--r--  llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp  3
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTX.h  1
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp  76
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXISelLowering.h  5
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXInstrInfo.td  54
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXIntrinsics.td  97
-rw-r--r--  llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp  54
-rw-r--r--  llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp  39
-rw-r--r--  llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h  17
-rw-r--r--  llvm/lib/Target/PowerPC/PPCInstr64Bit.td  24
-rw-r--r--  llvm/lib/Target/PowerPC/PPCInstrAltivec.td  19
-rw-r--r--  llvm/lib/Target/PowerPC/PPCRegisterInfo.td  67
-rw-r--r--  llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp  8
-rw-r--r--  llvm/lib/Target/RISCV/RISCVFeatures.td  2
-rw-r--r--  llvm/lib/Target/RISCV/RISCVISelLowering.cpp  2
-rw-r--r--  llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td  2
-rw-r--r--  llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp  2
-rw-r--r--  llvm/lib/Target/RISCV/RISCVSubtarget.cpp  12
-rw-r--r--  llvm/lib/Target/RISCV/RISCVSubtarget.h  4
-rw-r--r--  llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp  4
-rw-r--r--  llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp  134
-rw-r--r--  llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp  192
-rw-r--r--  llvm/lib/Target/SPIRV/SPIRVUtils.cpp  6
-rw-r--r--  llvm/lib/Target/SPIRV/SPIRVUtils.h  3
-rw-r--r--  llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp  1
-rw-r--r--  llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp  49
-rw-r--r--  llvm/lib/Target/X86/X86InstrAVX512.td  90
-rw-r--r--  llvm/lib/TargetParser/TargetParser.cpp  1
-rw-r--r--  llvm/lib/Transforms/IPO/FunctionAttrs.cpp  119
-rw-r--r--  llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp  9
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp  10
-rw-r--r--  llvm/lib/Transforms/Utils/SCCPSolver.cpp  39
-rw-r--r--  llvm/lib/Transforms/Utils/SimplifyCFG.cpp  189
-rw-r--r--  llvm/lib/Transforms/Vectorize/LoopVectorize.cpp  215
-rw-r--r--  llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp  6
-rw-r--r--  llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp  199
-rw-r--r--  llvm/lib/Transforms/Vectorize/VPlanTransforms.h  13
64 files changed, 1806 insertions, 1073 deletions
diff --git a/llvm/lib/Analysis/MemoryProfileInfo.cpp b/llvm/lib/Analysis/MemoryProfileInfo.cpp
index 0c1f8db..92a5b6f 100644
--- a/llvm/lib/Analysis/MemoryProfileInfo.cpp
+++ b/llvm/lib/Analysis/MemoryProfileInfo.cpp
@@ -54,6 +54,10 @@ cl::opt<unsigned> MinPercentMaxColdSize(
"memprof-min-percent-max-cold-size", cl::init(100), cl::Hidden,
cl::desc("Min percent of max cold bytes for critical cold context"));
+LLVM_ABI cl::opt<bool> MemProfUseAmbiguousAttributes(
+ "memprof-ambiguous-attributes", cl::init(true), cl::Hidden,
+ cl::desc("Apply ambiguous memprof attribute to ambiguous allocations"));
+
} // end namespace llvm
bool llvm::memprof::metadataIncludesAllContextSizeInfo() {
@@ -125,6 +129,26 @@ bool llvm::memprof::hasSingleAllocType(uint8_t AllocTypes) {
return NumAllocTypes == 1;
}
+void llvm::memprof::removeAnyExistingAmbiguousAttribute(CallBase *CB) {
+ if (!CB->hasFnAttr("memprof"))
+ return;
+ assert(CB->getFnAttr("memprof").getValueAsString() == "ambiguous");
+ CB->removeFnAttr("memprof");
+}
+
+void llvm::memprof::addAmbiguousAttribute(CallBase *CB) {
+ if (!MemProfUseAmbiguousAttributes)
+ return;
+ // We may have an existing ambiguous attribute if we are reanalyzing
+ // after inlining.
+ if (CB->hasFnAttr("memprof")) {
+ assert(CB->getFnAttr("memprof").getValueAsString() == "ambiguous");
+ } else {
+ auto A = llvm::Attribute::get(CB->getContext(), "memprof", "ambiguous");
+ CB->addFnAttr(A);
+ }
+}
+
void CallStackTrie::addCallStack(
AllocationType AllocType, ArrayRef<uint64_t> StackIds,
std::vector<ContextTotalSize> ContextSizeInfo) {
@@ -470,6 +494,9 @@ void CallStackTrie::addSingleAllocTypeAttribute(CallBase *CI, AllocationType AT,
StringRef Descriptor) {
auto AllocTypeString = getAllocTypeAttributeString(AT);
auto A = llvm::Attribute::get(CI->getContext(), "memprof", AllocTypeString);
+ // After inlining we may be able to convert an existing ambiguous allocation
+ // to an unambiguous one.
+ removeAnyExistingAmbiguousAttribute(CI);
CI->addFnAttr(A);
if (MemProfReportHintedSizes) {
std::vector<ContextTotalSize> ContextSizeInfo;
@@ -529,6 +556,7 @@ bool CallStackTrie::buildAndAttachMIBMetadata(CallBase *CI) {
assert(MIBCallStack.size() == 1 &&
"Should only be left with Alloc's location in stack");
CI->setMetadata(LLVMContext::MD_memprof, MDNode::get(Ctx, MIBNodes));
+ addAmbiguousAttribute(CI);
return true;
}
// If there exists corner case that CallStackTrie has one chain to leaf
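The MemoryProfileInfo hunks above introduce a string function attribute, "memprof"="ambiguous", attached to allocation calls that keep multi-MIB !memprof metadata and stripped again once a later pass (e.g. after inlining) resolves the call to a single allocation type. A minimal consumer-side sketch, using only the attribute name/value from this patch (the helper below is illustrative, not part of the change):

#include "llvm/IR/InstrTypes.h" // llvm::CallBase

// Illustrative helper: true if memprof annotation left this allocation
// call ambiguous (multiple possible allocation types).
static bool isAmbiguousMemProfAllocation(const llvm::CallBase &CB) {
  if (!CB.hasFnAttr("memprof"))
    return false;
  // Unambiguous allocations carry a single alloc-type value instead
  // (see addSingleAllocTypeAttribute above).
  return CB.getFnAttr("memprof").getValueAsString() == "ambiguous";
}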
diff --git a/llvm/lib/BinaryFormat/DXContainer.cpp b/llvm/lib/BinaryFormat/DXContainer.cpp
index c06a3e3..b334f86 100644
--- a/llvm/lib/BinaryFormat/DXContainer.cpp
+++ b/llvm/lib/BinaryFormat/DXContainer.cpp
@@ -18,6 +18,70 @@
using namespace llvm;
using namespace llvm::dxbc;
+#define ROOT_PARAMETER(Val, Enum) \
+ case Val: \
+ return true;
+bool llvm::dxbc::isValidParameterType(uint32_t V) {
+ switch (V) {
+#include "llvm/BinaryFormat/DXContainerConstants.def"
+ }
+ return false;
+}
+
+bool llvm::dxbc::isValidRangeType(uint32_t V) {
+ return V <= llvm::to_underlying(dxil::ResourceClass::LastEntry);
+}
+
+#define SHADER_VISIBILITY(Val, Enum) \
+ case Val: \
+ return true;
+bool llvm::dxbc::isValidShaderVisibility(uint32_t V) {
+ switch (V) {
+#include "llvm/BinaryFormat/DXContainerConstants.def"
+ }
+ return false;
+}
+
+#define FILTER(Val, Enum) \
+ case Val: \
+ return true;
+bool llvm::dxbc::isValidSamplerFilter(uint32_t V) {
+ switch (V) {
+#include "llvm/BinaryFormat/DXContainerConstants.def"
+ }
+ return false;
+}
+
+#define TEXTURE_ADDRESS_MODE(Val, Enum) \
+ case Val: \
+ return true;
+bool llvm::dxbc::isValidAddress(uint32_t V) {
+ switch (V) {
+#include "llvm/BinaryFormat/DXContainerConstants.def"
+ }
+ return false;
+}
+
+#define COMPARISON_FUNC(Val, Enum) \
+ case Val: \
+ return true;
+bool llvm::dxbc::isValidComparisonFunc(uint32_t V) {
+ switch (V) {
+#include "llvm/BinaryFormat/DXContainerConstants.def"
+ }
+ return false;
+}
+
+#define STATIC_BORDER_COLOR(Val, Enum) \
+ case Val: \
+ return true;
+bool llvm::dxbc::isValidBorderColor(uint32_t V) {
+ switch (V) {
+#include "llvm/BinaryFormat/DXContainerConstants.def"
+ }
+ return false;
+}
+
dxbc::PartType dxbc::parsePartType(StringRef S) {
#define CONTAINER_PART(PartName) .Case(#PartName, PartType::PartName)
return StringSwitch<dxbc::PartType>(S)
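All of the isValid* predicates above use the same X-macro pattern: define a one-line case-generating macro, include DXContainerConstants.def so every known enumerator expands to "case Val: return true;", and let anything unrecognized fall through to return false. A self-contained illustration of the pattern with made-up values (the real value/enumerator lists live in DXContainerConstants.def):

#include <cstdint>

// Stand-in for the .def file: one X(Value, Name) entry per known encoding.
#define DEMO_SHADER_VISIBILITY_LIST(X)                                        \
  X(0, All)                                                                   \
  X(1, Vertex)                                                                \
  X(2, Hull)

#define DEMO_SHADER_VISIBILITY(Val, Enum)                                     \
  case Val:                                                                   \
    return true;

static bool isValidShaderVisibilityDemo(uint32_t V) {
  switch (V) {
    DEMO_SHADER_VISIBILITY_LIST(DEMO_SHADER_VISIBILITY)
  }
  return false; // Reject any encoding not named in the list.
}
#undef DEMO_SHADER_VISIBILITY
#undef DEMO_SHADER_VISIBILITY_LIST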
diff --git a/llvm/lib/CAS/CMakeLists.txt b/llvm/lib/CAS/CMakeLists.txt
index 7ae5f7e..bca39b6 100644
--- a/llvm/lib/CAS/CMakeLists.txt
+++ b/llvm/lib/CAS/CMakeLists.txt
@@ -7,6 +7,7 @@ add_llvm_component_library(LLVMCAS
MappedFileRegionArena.cpp
ObjectStore.cpp
OnDiskCommon.cpp
+ OnDiskDataAllocator.cpp
OnDiskTrieRawHashMap.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/llvm/lib/CAS/OnDiskDataAllocator.cpp b/llvm/lib/CAS/OnDiskDataAllocator.cpp
new file mode 100644
index 0000000..13bbd66
--- /dev/null
+++ b/llvm/lib/CAS/OnDiskDataAllocator.cpp
@@ -0,0 +1,234 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file Implements OnDiskDataAllocator.
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CAS/OnDiskDataAllocator.h"
+#include "DatabaseFile.h"
+#include "llvm/Config/llvm-config.h"
+
+using namespace llvm;
+using namespace llvm::cas;
+using namespace llvm::cas::ondisk;
+
+#if LLVM_ENABLE_ONDISK_CAS
+
+//===----------------------------------------------------------------------===//
+// DataAllocator data structures.
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// DataAllocator table layout:
+/// - [8-bytes: Generic table header]
+/// - 8-bytes: AllocatorOffset (reserved for implementing free lists)
+/// - 8-bytes: Size for user data header
+/// - <user data buffer>
+///
+/// Record layout:
+/// - <data>
+class DataAllocatorHandle {
+public:
+ static constexpr TableHandle::TableKind Kind =
+ TableHandle::TableKind::DataAllocator;
+
+ struct Header {
+ TableHandle::Header GenericHeader;
+ std::atomic<int64_t> AllocatorOffset;
+ const uint64_t UserHeaderSize;
+ };
+
+ operator TableHandle() const {
+ if (!H)
+ return TableHandle();
+ return TableHandle(*Region, H->GenericHeader);
+ }
+
+ Expected<MutableArrayRef<char>> allocate(MappedFileRegionArena &Alloc,
+ size_t DataSize) {
+ assert(&Alloc.getRegion() == Region);
+ auto Ptr = Alloc.allocate(DataSize);
+ if (LLVM_UNLIKELY(!Ptr))
+ return Ptr.takeError();
+ return MutableArrayRef(*Ptr, DataSize);
+ }
+
+ explicit operator bool() const { return H; }
+ const Header &getHeader() const { return *H; }
+ MappedFileRegion &getRegion() const { return *Region; }
+
+ MutableArrayRef<uint8_t> getUserHeader() {
+ return MutableArrayRef(reinterpret_cast<uint8_t *>(H + 1),
+ H->UserHeaderSize);
+ }
+
+ static Expected<DataAllocatorHandle>
+ create(MappedFileRegionArena &Alloc, StringRef Name, uint32_t UserHeaderSize);
+
+ DataAllocatorHandle() = default;
+ DataAllocatorHandle(MappedFileRegion &Region, Header &H)
+ : Region(&Region), H(&H) {}
+ DataAllocatorHandle(MappedFileRegion &Region, intptr_t HeaderOffset)
+ : DataAllocatorHandle(
+ Region, *reinterpret_cast<Header *>(Region.data() + HeaderOffset)) {
+ }
+
+private:
+ MappedFileRegion *Region = nullptr;
+ Header *H = nullptr;
+};
+
+} // end anonymous namespace
+
+struct OnDiskDataAllocator::ImplType {
+ DatabaseFile File;
+ DataAllocatorHandle Store;
+};
+
+Expected<DataAllocatorHandle>
+DataAllocatorHandle::create(MappedFileRegionArena &Alloc, StringRef Name,
+ uint32_t UserHeaderSize) {
+ // Allocate.
+ auto Offset =
+ Alloc.allocateOffset(sizeof(Header) + UserHeaderSize + Name.size() + 1);
+ if (LLVM_UNLIKELY(!Offset))
+ return Offset.takeError();
+
+ // Construct the header and the name.
+ assert(Name.size() <= UINT16_MAX && "Expected smaller table name");
+ auto *H = new (Alloc.getRegion().data() + *Offset)
+ Header{{TableHandle::TableKind::DataAllocator,
+ static_cast<uint16_t>(Name.size()),
+ static_cast<int32_t>(sizeof(Header) + UserHeaderSize)},
+ /*AllocatorOffset=*/{0},
+ /*UserHeaderSize=*/UserHeaderSize};
+ // Memset UserHeader.
+ char *UserHeader = reinterpret_cast<char *>(H + 1);
+ memset(UserHeader, 0, UserHeaderSize);
+ // Write database file name (null-terminated).
+ char *NameStorage = UserHeader + UserHeaderSize;
+ llvm::copy(Name, NameStorage);
+ NameStorage[Name.size()] = 0;
+ return DataAllocatorHandle(Alloc.getRegion(), *H);
+}
+
+Expected<OnDiskDataAllocator> OnDiskDataAllocator::create(
+ const Twine &PathTwine, const Twine &TableNameTwine, uint64_t MaxFileSize,
+ std::optional<uint64_t> NewFileInitialSize, uint32_t UserHeaderSize,
+ function_ref<void(void *)> UserHeaderInit) {
+ assert(!UserHeaderSize || UserHeaderInit);
+ SmallString<128> PathStorage;
+ StringRef Path = PathTwine.toStringRef(PathStorage);
+ SmallString<128> TableNameStorage;
+ StringRef TableName = TableNameTwine.toStringRef(TableNameStorage);
+
+ // Constructor for if the file doesn't exist.
+ auto NewDBConstructor = [&](DatabaseFile &DB) -> Error {
+ auto Store =
+ DataAllocatorHandle::create(DB.getAlloc(), TableName, UserHeaderSize);
+ if (LLVM_UNLIKELY(!Store))
+ return Store.takeError();
+
+ if (auto E = DB.addTable(*Store))
+ return E;
+
+ if (UserHeaderSize)
+ UserHeaderInit(Store->getUserHeader().data());
+ return Error::success();
+ };
+
+ // Get or create the file.
+ Expected<DatabaseFile> File =
+ DatabaseFile::create(Path, MaxFileSize, NewDBConstructor);
+ if (!File)
+ return File.takeError();
+
+ // Find the table and validate it.
+ std::optional<TableHandle> Table = File->findTable(TableName);
+ if (!Table)
+ return createTableConfigError(std::errc::argument_out_of_domain, Path,
+ TableName, "table not found");
+ if (Error E = checkTable("table kind", (size_t)DataAllocatorHandle::Kind,
+ (size_t)Table->getHeader().Kind, Path, TableName))
+ return std::move(E);
+ auto Store = Table->cast<DataAllocatorHandle>();
+ assert(Store && "Already checked the kind");
+
+ // Success.
+ OnDiskDataAllocator::ImplType Impl{DatabaseFile(std::move(*File)), Store};
+ return OnDiskDataAllocator(std::make_unique<ImplType>(std::move(Impl)));
+}
+
+Expected<OnDiskDataAllocator::OnDiskPtr>
+OnDiskDataAllocator::allocate(size_t Size) {
+ auto Data = Impl->Store.allocate(Impl->File.getAlloc(), Size);
+ if (LLVM_UNLIKELY(!Data))
+ return Data.takeError();
+
+ return OnDiskPtr(FileOffset(Data->data() - Impl->Store.getRegion().data()),
+ *Data);
+}
+
+Expected<ArrayRef<char>> OnDiskDataAllocator::get(FileOffset Offset,
+ size_t Size) const {
+ assert(Offset);
+ assert(Impl);
+ if (Offset.get() + Size >= Impl->File.getAlloc().size())
+ return createStringError(make_error_code(std::errc::protocol_error),
+ "requested size too large in allocator");
+ return ArrayRef<char>{Impl->File.getRegion().data() + Offset.get(), Size};
+}
+
+MutableArrayRef<uint8_t> OnDiskDataAllocator::getUserHeader() {
+ return Impl->Store.getUserHeader();
+}
+
+size_t OnDiskDataAllocator::size() const { return Impl->File.size(); }
+size_t OnDiskDataAllocator::capacity() const {
+ return Impl->File.getRegion().size();
+}
+
+OnDiskDataAllocator::OnDiskDataAllocator(std::unique_ptr<ImplType> Impl)
+ : Impl(std::move(Impl)) {}
+
+#else // !LLVM_ENABLE_ONDISK_CAS
+
+struct OnDiskDataAllocator::ImplType {};
+
+Expected<OnDiskDataAllocator> OnDiskDataAllocator::create(
+ const Twine &Path, const Twine &TableName, uint64_t MaxFileSize,
+ std::optional<uint64_t> NewFileInitialSize, uint32_t UserHeaderSize,
+ function_ref<void(void *)> UserHeaderInit) {
+ return createStringError(make_error_code(std::errc::not_supported),
+ "OnDiskDataAllocator is not supported");
+}
+
+Expected<OnDiskDataAllocator::OnDiskPtr>
+OnDiskDataAllocator::allocate(size_t Size) {
+ return createStringError(make_error_code(std::errc::not_supported),
+ "OnDiskDataAllocator is not supported");
+}
+
+Expected<ArrayRef<char>> OnDiskDataAllocator::get(FileOffset Offset,
+ size_t Size) const {
+ return createStringError(make_error_code(std::errc::not_supported),
+ "OnDiskDataAllocator is not supported");
+}
+
+MutableArrayRef<uint8_t> OnDiskDataAllocator::getUserHeader() { return {}; }
+
+size_t OnDiskDataAllocator::size() const { return 0; }
+size_t OnDiskDataAllocator::capacity() const { return 0; }
+
+#endif // LLVM_ENABLE_ONDISK_CAS
+
+OnDiskDataAllocator::OnDiskDataAllocator(OnDiskDataAllocator &&RHS) = default;
+OnDiskDataAllocator &
+OnDiskDataAllocator::operator=(OnDiskDataAllocator &&RHS) = default;
+OnDiskDataAllocator::~OnDiskDataAllocator() = default;
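Taken together, OnDiskDataAllocator is a small persistent bump allocator layered on DatabaseFile/MappedFileRegionArena: create() opens or initializes a named table in the mapped file, allocate() returns freshly reserved bytes together with their stable FileOffset, and get() maps a previously stored offset back to data on a later run. A rough usage sketch, assuming OnDiskPtr exposes an offset accessor such as getOffset() (accessor names are not shown in this diff):

#include "llvm/CAS/OnDiskDataAllocator.h"
#include "llvm/Support/Error.h"

using namespace llvm;
using namespace llvm::cas;

static Error demoRoundTrip(StringRef Path) {
  Expected<OnDiskDataAllocator> Alloc = OnDiskDataAllocator::create(
      Path, "demo.table", /*MaxFileSize=*/16 << 20,
      /*NewFileInitialSize=*/std::nullopt, /*UserHeaderSize=*/0,
      /*UserHeaderInit=*/nullptr);
  if (!Alloc)
    return Alloc.takeError();

  auto Ptr = Alloc->allocate(64); // Expected<OnDiskPtr>
  if (!Ptr)
    return Ptr.takeError();

  // Assumed accessor: the offset is what callers persist and pass back.
  FileOffset Off = Ptr->getOffset();
  Expected<ArrayRef<char>> Bytes = Alloc->get(Off, 64);
  if (!Bytes)
    return Bytes.takeError();
  return Error::success();
}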
diff --git a/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp b/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp
index 9403893..323b21e 100644
--- a/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp
+++ b/llvm/lib/CAS/OnDiskTrieRawHashMap.cpp
@@ -427,7 +427,7 @@ TrieRawHashMapHandle::createRecord(MappedFileRegionArena &Alloc,
return Record;
}
-Expected<OnDiskTrieRawHashMap::const_pointer>
+Expected<OnDiskTrieRawHashMap::ConstOnDiskPtr>
OnDiskTrieRawHashMap::recoverFromFileOffset(FileOffset Offset) const {
// Check alignment.
if (!isAligned(MappedFileRegionArena::getAlign(), Offset.get()))
@@ -448,17 +448,17 @@ OnDiskTrieRawHashMap::recoverFromFileOffset(FileOffset Offset) const {
// Looks okay...
TrieRawHashMapHandle::RecordData D =
Impl->Trie.getRecord(SubtrieSlotValue::getDataOffset(Offset));
- return const_pointer(D.Proxy, D.getFileOffset());
+ return ConstOnDiskPtr(D.Proxy, D.getFileOffset());
}
-OnDiskTrieRawHashMap::const_pointer
+OnDiskTrieRawHashMap::ConstOnDiskPtr
OnDiskTrieRawHashMap::find(ArrayRef<uint8_t> Hash) const {
TrieRawHashMapHandle Trie = Impl->Trie;
assert(Hash.size() == Trie.getNumHashBytes() && "Invalid hash");
SubtrieHandle S = Trie.getRoot();
if (!S)
- return const_pointer();
+ return ConstOnDiskPtr();
TrieHashIndexGenerator IndexGen = Trie.getIndexGen(S, Hash);
size_t Index = IndexGen.next();
@@ -466,13 +466,13 @@ OnDiskTrieRawHashMap::find(ArrayRef<uint8_t> Hash) const {
// Try to set the content.
SubtrieSlotValue V = S.load(Index);
if (!V)
- return const_pointer();
+ return ConstOnDiskPtr();
// Check for an exact match.
if (V.isData()) {
TrieRawHashMapHandle::RecordData D = Trie.getRecord(V);
- return D.Proxy.Hash == Hash ? const_pointer(D.Proxy, D.getFileOffset())
- : const_pointer();
+ return D.Proxy.Hash == Hash ? ConstOnDiskPtr(D.Proxy, D.getFileOffset())
+ : ConstOnDiskPtr();
}
Index = IndexGen.next();
@@ -490,7 +490,7 @@ void SubtrieHandle::reinitialize(uint32_t StartBit, uint32_t NumBits) {
H->NumBits = NumBits;
}
-Expected<OnDiskTrieRawHashMap::pointer>
+Expected<OnDiskTrieRawHashMap::OnDiskPtr>
OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash,
LazyInsertOnConstructCB OnConstruct,
LazyInsertOnLeakCB OnLeak) {
@@ -523,7 +523,8 @@ OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash,
}
if (S->compare_exchange_strong(Index, Existing, NewRecord->Offset))
- return pointer(NewRecord->Proxy, NewRecord->Offset.asDataFileOffset());
+ return OnDiskPtr(NewRecord->Proxy,
+ NewRecord->Offset.asDataFileOffset());
// Race means that Existing is no longer empty; fall through...
}
@@ -540,8 +541,8 @@ OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash,
if (NewRecord && OnLeak)
OnLeak(NewRecord->Offset.asDataFileOffset(), NewRecord->Proxy,
ExistingRecord.Offset.asDataFileOffset(), ExistingRecord.Proxy);
- return pointer(ExistingRecord.Proxy,
- ExistingRecord.Offset.asDataFileOffset());
+ return OnDiskPtr(ExistingRecord.Proxy,
+ ExistingRecord.Offset.asDataFileOffset());
}
// Sink the existing content as long as the indexes match.
@@ -1135,7 +1136,7 @@ OnDiskTrieRawHashMap::create(const Twine &PathTwine, const Twine &TrieNameTwine,
"OnDiskTrieRawHashMap is not supported");
}
-Expected<OnDiskTrieRawHashMap::pointer>
+Expected<OnDiskTrieRawHashMap::OnDiskPtr>
OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash,
LazyInsertOnConstructCB OnConstruct,
LazyInsertOnLeakCB OnLeak) {
@@ -1143,15 +1144,15 @@ OnDiskTrieRawHashMap::insertLazy(ArrayRef<uint8_t> Hash,
"OnDiskTrieRawHashMap is not supported");
}
-Expected<OnDiskTrieRawHashMap::const_pointer>
+Expected<OnDiskTrieRawHashMap::ConstOnDiskPtr>
OnDiskTrieRawHashMap::recoverFromFileOffset(FileOffset Offset) const {
return createStringError(make_error_code(std::errc::not_supported),
"OnDiskTrieRawHashMap is not supported");
}
-OnDiskTrieRawHashMap::const_pointer
+OnDiskTrieRawHashMap::ConstOnDiskPtr
OnDiskTrieRawHashMap::find(ArrayRef<uint8_t> Hash) const {
- return const_pointer();
+ return ConstOnDiskPtr();
}
void OnDiskTrieRawHashMap::print(
diff --git a/llvm/lib/IR/Globals.cpp b/llvm/lib/IR/Globals.cpp
index 1a7a5c5..c3a472b 100644
--- a/llvm/lib/IR/Globals.cpp
+++ b/llvm/lib/IR/Globals.cpp
@@ -419,6 +419,7 @@ findBaseObject(const Constant *C, DenseSet<const GlobalAlias *> &Aliases,
case Instruction::PtrToAddr:
case Instruction::PtrToInt:
case Instruction::BitCast:
+ case Instruction::AddrSpaceCast:
case Instruction::GetElementPtr:
return findBaseObject(CE->getOperand(0), Aliases, Op);
default:
diff --git a/llvm/lib/Option/ArgList.cpp b/llvm/lib/Option/ArgList.cpp
index c4188b3b..2f4e212 100644
--- a/llvm/lib/Option/ArgList.cpp
+++ b/llvm/lib/Option/ArgList.cpp
@@ -14,12 +14,14 @@
#include "llvm/Config/llvm-config.h"
#include "llvm/Option/Arg.h"
#include "llvm/Option/OptSpecifier.h"
+#include "llvm/Option/OptTable.h"
#include "llvm/Option/Option.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
+#include <cstddef>
#include <memory>
#include <string>
#include <utility>
@@ -202,6 +204,42 @@ void ArgList::print(raw_ostream &O) const {
LLVM_DUMP_METHOD void ArgList::dump() const { print(dbgs()); }
#endif
+StringRef ArgList::getSubCommand(
+ ArrayRef<OptTable::SubCommand> AllSubCommands,
+ std::function<void(ArrayRef<StringRef>)> HandleMultipleSubcommands,
+ std::function<void(ArrayRef<StringRef>)> HandleOtherPositionals) const {
+
+ SmallVector<StringRef, 4> SubCommands;
+ SmallVector<StringRef, 4> OtherPositionals;
+ for (const Arg *A : *this) {
+ if (A->getOption().getKind() != Option::InputClass)
+ continue;
+
+ size_t OldSize = SubCommands.size();
+ for (const OptTable::SubCommand &CMD : AllSubCommands) {
+ if (StringRef(CMD.Name) == A->getValue())
+ SubCommands.push_back(A->getValue());
+ }
+
+ if (SubCommands.size() == OldSize)
+ OtherPositionals.push_back(A->getValue());
+ }
+
+ // Invoke callbacks if necessary.
+ if (SubCommands.size() > 1) {
+ HandleMultipleSubcommands(SubCommands);
+ return {};
+ }
+ if (!OtherPositionals.empty()) {
+ HandleOtherPositionals(OtherPositionals);
+ return {};
+ }
+
+ if (SubCommands.size() == 1)
+ return SubCommands.front();
+ return {}; // No valid usage of subcommand found.
+}
+
void InputArgList::releaseMemory() {
// An InputArgList always owns its arguments.
for (Arg *A : *this)
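The new ArgList::getSubCommand walks the positional (InputClass) arguments, splits them into names that match a registered OptTable::SubCommand and everything else, and returns a subcommand only when exactly one was named and no other positionals were seen; otherwise it invokes the corresponding callback and returns an empty StringRef. A driver-side sketch (how the subcommand list is obtained from the tool's option table is assumed, not shown in this diff):

#include "llvm/Option/ArgList.h"
#include "llvm/Option/OptTable.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;
using namespace llvm::opt;

// Illustrative dispatch step; AllSubCommands would come from the tool's
// TableGen-generated option table.
static StringRef pickSubCommand(const InputArgList &Args,
                                ArrayRef<OptTable::SubCommand> AllSubCommands) {
  return Args.getSubCommand(
      AllSubCommands,
      [](ArrayRef<StringRef> Cmds) {
        errs() << "error: more than one subcommand ('" << Cmds.front()
               << "', ...)\n";
      },
      [](ArrayRef<StringRef> Positionals) {
        errs() << "error: unexpected positional argument '"
               << Positionals.front() << "'\n";
      });
}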
diff --git a/llvm/lib/Option/OptTable.cpp b/llvm/lib/Option/OptTable.cpp
index 6d10e61..14e3b0d 100644
--- a/llvm/lib/Option/OptTable.cpp
+++ b/llvm/lib/Option/OptTable.cpp
@@ -79,9 +79,12 @@ OptSpecifier::OptSpecifier(const Option *Opt) : ID(Opt->getID()) {}
OptTable::OptTable(const StringTable &StrTable,
ArrayRef<StringTable::Offset> PrefixesTable,
- ArrayRef<Info> OptionInfos, bool IgnoreCase)
+ ArrayRef<Info> OptionInfos, bool IgnoreCase,
+ ArrayRef<SubCommand> SubCommands,
+ ArrayRef<unsigned> SubCommandIDsTable)
: StrTable(&StrTable), PrefixesTable(PrefixesTable),
- OptionInfos(OptionInfos), IgnoreCase(IgnoreCase) {
+ OptionInfos(OptionInfos), IgnoreCase(IgnoreCase),
+ SubCommands(SubCommands), SubCommandIDsTable(SubCommandIDsTable) {
// Explicitly zero initialize the error to work around a bug in array
// value-initialization on MinGW with gcc 4.3.5.
@@ -715,9 +718,10 @@ static const char *getOptionHelpGroup(const OptTable &Opts, OptSpecifier Id) {
void OptTable::printHelp(raw_ostream &OS, const char *Usage, const char *Title,
bool ShowHidden, bool ShowAllAliases,
- Visibility VisibilityMask) const {
+ Visibility VisibilityMask,
+ StringRef SubCommand) const {
return internalPrintHelp(
- OS, Usage, Title, ShowHidden, ShowAllAliases,
+ OS, Usage, Title, SubCommand, ShowHidden, ShowAllAliases,
[VisibilityMask](const Info &CandidateInfo) -> bool {
return (CandidateInfo.Visibility & VisibilityMask) == 0;
},
@@ -730,7 +734,7 @@ void OptTable::printHelp(raw_ostream &OS, const char *Usage, const char *Title,
bool ShowHidden = !(FlagsToExclude & HelpHidden);
FlagsToExclude &= ~HelpHidden;
return internalPrintHelp(
- OS, Usage, Title, ShowHidden, ShowAllAliases,
+ OS, Usage, Title, /*SubCommand=*/{}, ShowHidden, ShowAllAliases,
[FlagsToInclude, FlagsToExclude](const Info &CandidateInfo) {
if (FlagsToInclude && !(CandidateInfo.Flags & FlagsToInclude))
return true;
@@ -742,16 +746,62 @@ void OptTable::printHelp(raw_ostream &OS, const char *Usage, const char *Title,
}
void OptTable::internalPrintHelp(
- raw_ostream &OS, const char *Usage, const char *Title, bool ShowHidden,
- bool ShowAllAliases, std::function<bool(const Info &)> ExcludeOption,
+ raw_ostream &OS, const char *Usage, const char *Title, StringRef SubCommand,
+ bool ShowHidden, bool ShowAllAliases,
+ std::function<bool(const Info &)> ExcludeOption,
Visibility VisibilityMask) const {
OS << "OVERVIEW: " << Title << "\n\n";
- OS << "USAGE: " << Usage << "\n\n";
// Render help text into a map of group-name to a list of (option, help)
// pairs.
std::map<std::string, std::vector<OptionInfo>> GroupedOptionHelp;
+ auto ActiveSubCommand =
+ std::find_if(SubCommands.begin(), SubCommands.end(),
+ [&](const auto &C) { return SubCommand == C.Name; });
+ if (!SubCommand.empty()) {
+ assert(ActiveSubCommand != SubCommands.end() &&
+ "Not a valid registered subcommand.");
+ OS << ActiveSubCommand->HelpText << "\n\n";
+ if (!StringRef(ActiveSubCommand->Usage).empty())
+ OS << "USAGE: " << ActiveSubCommand->Usage << "\n\n";
+ } else {
+ OS << "USAGE: " << Usage << "\n\n";
+ if (SubCommands.size() > 1) {
+ OS << "SUBCOMMANDS:\n\n";
+ for (const auto &C : SubCommands)
+ OS << C.Name << " - " << C.HelpText << "\n";
+ OS << "\n";
+ }
+ }
+
+ auto DoesOptionBelongToSubcommand = [&](const Info &CandidateInfo) {
+ // Retrieve the SubCommandIDs registered to the given current CandidateInfo
+ // Option.
+ ArrayRef<unsigned> SubCommandIDs =
+ CandidateInfo.getSubCommandIDs(SubCommandIDsTable);
+
+ // If no registered subcommands, then only global options are to be printed.
+ // If no valid SubCommand (empty) in commandline then print the current
+ // global CandidateInfo Option.
+ if (SubCommandIDs.empty())
+ return SubCommand.empty();
+
+ // Handle CandidateInfo Option which has at least one registered SubCommand.
+ // If no valid SubCommand (empty) in commandline, this CandidateInfo option
+ // should not be printed.
+ if (SubCommand.empty())
+ return false;
+
+ // Find the ID of the valid subcommand passed in commandline (its index in
+ // the SubCommands table which contains all subcommands).
+ unsigned ActiveSubCommandID = ActiveSubCommand - &SubCommands[0];
+ // Print if the ActiveSubCommandID is registered with the CandidateInfo
+ // Option.
+ return std::find(SubCommandIDs.begin(), SubCommandIDs.end(),
+ ActiveSubCommandID) != SubCommandIDs.end();
+ };
+
for (unsigned Id = 1, e = getNumOptions() + 1; Id != e; ++Id) {
// FIXME: Split out option groups.
if (getOptionKind(Id) == Option::GroupClass)
@@ -764,6 +814,9 @@ void OptTable::internalPrintHelp(
if (ExcludeOption(CandidateInfo))
continue;
+ if (!DoesOptionBelongToSubcommand(CandidateInfo))
+ continue;
+
// If an alias doesn't have a help text, show a help text for the aliased
// option instead.
const char *HelpText = getOptionHelpText(Id, VisibilityMask);
@@ -791,8 +844,11 @@ void OptTable::internalPrintHelp(
GenericOptTable::GenericOptTable(const StringTable &StrTable,
ArrayRef<StringTable::Offset> PrefixesTable,
- ArrayRef<Info> OptionInfos, bool IgnoreCase)
- : OptTable(StrTable, PrefixesTable, OptionInfos, IgnoreCase) {
+ ArrayRef<Info> OptionInfos, bool IgnoreCase,
+ ArrayRef<SubCommand> SubCommands,
+ ArrayRef<unsigned> SubCommandIDsTable)
+ : OptTable(StrTable, PrefixesTable, OptionInfos, IgnoreCase, SubCommands,
+ SubCommandIDsTable) {
std::set<StringRef> TmpPrefixesUnion;
for (auto const &Info : OptionInfos.drop_front(FirstSearchableIndex))
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index 7069e8d..119caea 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -1960,6 +1960,7 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
// is fixed.
MPM.addPass(WholeProgramDevirtPass(ExportSummary, nullptr));
+ MPM.addPass(NoRecurseLTOInferencePass());
// Stop here at -O1.
if (Level == OptimizationLevel::O1) {
// The LowerTypeTestsPass needs to run to lower type metadata and the
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index f0e7d36..88550ea 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -119,6 +119,7 @@ MODULE_PASS("metarenamer", MetaRenamerPass())
MODULE_PASS("module-inline", ModuleInlinerPass())
MODULE_PASS("name-anon-globals", NameAnonGlobalPass())
MODULE_PASS("no-op-module", NoOpModulePass())
+MODULE_PASS("norecurse-lto-inference", NoRecurseLTOInferencePass())
MODULE_PASS("nsan", NumericalStabilitySanitizerPass())
MODULE_PASS("openmp-opt", OpenMPOptPass())
MODULE_PASS("openmp-opt-postlink",
diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
index 4357264d..c76689f 100644
--- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
@@ -345,12 +345,6 @@ static unsigned getStackHazardSize(const MachineFunction &MF) {
return MF.getSubtarget<AArch64Subtarget>().getStreamingHazardSize();
}
-/// Returns true if PPRs are spilled as ZPRs.
-static bool arePPRsSpilledAsZPR(const MachineFunction &MF) {
- return MF.getSubtarget().getRegisterInfo()->getSpillSize(
- AArch64::PPRRegClass) == 16;
-}
-
StackOffset
AArch64FrameLowering::getZPRStackSize(const MachineFunction &MF) const {
const AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
@@ -1966,8 +1960,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
StrOpc = RPI.isPaired() ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI;
break;
case RegPairInfo::PPR:
- StrOpc =
- Size == 16 ? AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO : AArch64::STR_PXI;
+ StrOpc = AArch64::STR_PXI;
break;
case RegPairInfo::VG:
StrOpc = AArch64::STRXui;
@@ -2178,8 +2171,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
LdrOpc = RPI.isPaired() ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI;
break;
case RegPairInfo::PPR:
- LdrOpc = Size == 16 ? AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO
- : AArch64::LDR_PXI;
+ LdrOpc = AArch64::LDR_PXI;
break;
case RegPairInfo::VG:
continue;
@@ -2286,9 +2278,7 @@ static std::optional<int> getLdStFrameID(const MachineInstr &MI,
// Returns true if the LDST MachineInstr \p MI is a PPR access.
static bool isPPRAccess(const MachineInstr &MI) {
- return MI.getOpcode() != AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO &&
- MI.getOpcode() != AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO &&
- AArch64::PPRRegClass.contains(MI.getOperand(0).getReg());
+ return AArch64::PPRRegClass.contains(MI.getOperand(0).getReg());
}
// Check if a Hazard slot is needed for the current function, and if so create
@@ -2390,12 +2380,6 @@ void AArch64FrameLowering::determineStackHazardSlot(
return;
}
- if (arePPRsSpilledAsZPR(MF)) {
- LLVM_DEBUG(dbgs() << "SplitSVEObjects is not supported with "
- "-aarch64-enable-zpr-predicate-spills");
- return;
- }
-
// If another calling convention is explicitly set FPRs can't be promoted to
// ZPR callee-saves.
if (!is_contained({CallingConv::C, CallingConv::Fast,
@@ -2519,14 +2503,6 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
continue;
}
- // Always save P4 when PPR spills are ZPR-sized and a predicate above p8 is
- // spilled. If all of p0-p3 are used as return values p4 is must be free
- // to reload p8-p15.
- if (RegInfo->getSpillSize(AArch64::PPRRegClass) == 16 &&
- AArch64::PPR_p8to15RegClass.contains(Reg)) {
- SavedRegs.set(AArch64::P4);
- }
-
// MachO's compact unwind format relies on all registers being stored in
// pairs.
// FIXME: the usual format is actually better if unwinding isn't needed.
@@ -2587,7 +2563,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
auto SpillSize = TRI->getSpillSize(*RC);
bool IsZPR = AArch64::ZPRRegClass.contains(Reg);
bool IsPPR = !IsZPR && AArch64::PPRRegClass.contains(Reg);
- if (IsZPR || (IsPPR && arePPRsSpilledAsZPR(MF)))
+ if (IsZPR)
ZPRCSStackSize += SpillSize;
else if (IsPPR)
PPRCSStackSize += SpillSize;
@@ -2902,7 +2878,7 @@ static SVEStackSizes determineSVEStackSizes(MachineFunction &MF,
StackTop += MFI.getObjectSize(FI);
StackTop = alignTo(StackTop, Alignment);
- assert(StackTop < std::numeric_limits<int64_t>::max() &&
+ assert(StackTop < (uint64_t)std::numeric_limits<int64_t>::max() &&
"SVE StackTop far too large?!");
int64_t Offset = -int64_t(StackTop);
@@ -2961,314 +2937,8 @@ static SVEStackSizes determineSVEStackSizes(MachineFunction &MF,
return SVEStack;
}
-/// Attempts to scavenge a register from \p ScavengeableRegs given the used
-/// registers in \p UsedRegs.
-static Register tryScavengeRegister(LiveRegUnits const &UsedRegs,
- BitVector const &ScavengeableRegs,
- Register PreferredReg) {
- if (PreferredReg != AArch64::NoRegister && UsedRegs.available(PreferredReg))
- return PreferredReg;
- for (auto Reg : ScavengeableRegs.set_bits()) {
- if (UsedRegs.available(Reg))
- return Reg;
- }
- return AArch64::NoRegister;
-}
-
-/// Propagates frame-setup/destroy flags from \p SourceMI to all instructions in
-/// \p MachineInstrs.
-static void propagateFrameFlags(MachineInstr &SourceMI,
- ArrayRef<MachineInstr *> MachineInstrs) {
- for (MachineInstr *MI : MachineInstrs) {
- if (SourceMI.getFlag(MachineInstr::FrameSetup))
- MI->setFlag(MachineInstr::FrameSetup);
- if (SourceMI.getFlag(MachineInstr::FrameDestroy))
- MI->setFlag(MachineInstr::FrameDestroy);
- }
-}
-
-/// RAII helper class for scavenging or spilling a register. On construction
-/// attempts to find a free register of class \p RC (given \p UsedRegs and \p
-/// AllocatableRegs), if no register can be found spills \p SpillCandidate to \p
-/// MaybeSpillFI to free a register. The free'd register is returned via the \p
-/// FreeReg output parameter. On destruction, if there is a spill, its previous
-/// value is reloaded. The spilling and scavenging is only valid at the
-/// insertion point \p MBBI, this class should _not_ be used in places that
-/// create or manipulate basic blocks, moving the expected insertion point.
-struct ScopedScavengeOrSpill {
- ScopedScavengeOrSpill(const ScopedScavengeOrSpill &) = delete;
- ScopedScavengeOrSpill(ScopedScavengeOrSpill &&) = delete;
-
- ScopedScavengeOrSpill(MachineFunction &MF, MachineBasicBlock &MBB,
- MachineBasicBlock::iterator MBBI,
- Register SpillCandidate, const TargetRegisterClass &RC,
- LiveRegUnits const &UsedRegs,
- BitVector const &AllocatableRegs,
- std::optional<int> *MaybeSpillFI,
- Register PreferredReg = AArch64::NoRegister)
- : MBB(MBB), MBBI(MBBI), RC(RC), TII(static_cast<const AArch64InstrInfo &>(
- *MF.getSubtarget().getInstrInfo())),
- TRI(*MF.getSubtarget().getRegisterInfo()) {
- FreeReg = tryScavengeRegister(UsedRegs, AllocatableRegs, PreferredReg);
- if (FreeReg != AArch64::NoRegister)
- return;
- assert(MaybeSpillFI && "Expected emergency spill slot FI information "
- "(attempted to spill in prologue/epilogue?)");
- if (!MaybeSpillFI->has_value()) {
- MachineFrameInfo &MFI = MF.getFrameInfo();
- *MaybeSpillFI = MFI.CreateSpillStackObject(TRI.getSpillSize(RC),
- TRI.getSpillAlign(RC));
- }
- FreeReg = SpillCandidate;
- SpillFI = MaybeSpillFI->value();
- TII.storeRegToStackSlot(MBB, MBBI, FreeReg, false, *SpillFI, &RC, &TRI,
- Register());
- }
-
- bool hasSpilled() const { return SpillFI.has_value(); }
-
- /// Returns the free register (found from scavenging or spilling a register).
- Register freeRegister() const { return FreeReg; }
-
- Register operator*() const { return freeRegister(); }
-
- ~ScopedScavengeOrSpill() {
- if (hasSpilled())
- TII.loadRegFromStackSlot(MBB, MBBI, FreeReg, *SpillFI, &RC, &TRI,
- Register());
- }
-
-private:
- MachineBasicBlock &MBB;
- MachineBasicBlock::iterator MBBI;
- const TargetRegisterClass &RC;
- const AArch64InstrInfo &TII;
- const TargetRegisterInfo &TRI;
- Register FreeReg = AArch64::NoRegister;
- std::optional<int> SpillFI;
-};
-
-/// Emergency stack slots for expanding SPILL_PPR_TO_ZPR_SLOT_PSEUDO and
-/// FILL_PPR_FROM_ZPR_SLOT_PSEUDO.
-struct EmergencyStackSlots {
- std::optional<int> ZPRSpillFI;
- std::optional<int> PPRSpillFI;
- std::optional<int> GPRSpillFI;
-};
-
-/// Registers available for scavenging (ZPR, PPR3b, GPR).
-struct ScavengeableRegs {
- BitVector ZPRRegs;
- BitVector PPR3bRegs;
- BitVector GPRRegs;
-};
-
-static bool isInPrologueOrEpilogue(const MachineInstr &MI) {
- return MI.getFlag(MachineInstr::FrameSetup) ||
- MI.getFlag(MachineInstr::FrameDestroy);
-}
-
-/// Expands:
-/// ```
-/// SPILL_PPR_TO_ZPR_SLOT_PSEUDO $p0, %stack.0, 0
-/// ```
-/// To:
-/// ```
-/// $z0 = CPY_ZPzI_B $p0, 1, 0
-/// STR_ZXI $z0, $stack.0, 0
-/// ```
-/// While ensuring a ZPR ($z0 in this example) is free for the predicate (
-/// spilling if necessary).
-static void expandSpillPPRToZPRSlotPseudo(MachineBasicBlock &MBB,
- MachineInstr &MI,
- const TargetRegisterInfo &TRI,
- LiveRegUnits const &UsedRegs,
- ScavengeableRegs const &SR,
- EmergencyStackSlots &SpillSlots) {
- MachineFunction &MF = *MBB.getParent();
- auto *TII =
- static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());
-
- ScopedScavengeOrSpill ZPredReg(
- MF, MBB, MI, AArch64::Z0, AArch64::ZPRRegClass, UsedRegs, SR.ZPRRegs,
- isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.ZPRSpillFI);
-
- SmallVector<MachineInstr *, 2> MachineInstrs;
- const DebugLoc &DL = MI.getDebugLoc();
- MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::CPY_ZPzI_B))
- .addReg(*ZPredReg, RegState::Define)
- .add(MI.getOperand(0))
- .addImm(1)
- .addImm(0)
- .getInstr());
- MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::STR_ZXI))
- .addReg(*ZPredReg)
- .add(MI.getOperand(1))
- .addImm(MI.getOperand(2).getImm())
- .setMemRefs(MI.memoperands())
- .getInstr());
- propagateFrameFlags(MI, MachineInstrs);
-}
-
-/// Expands:
-/// ```
-/// $p0 = FILL_PPR_FROM_ZPR_SLOT_PSEUDO %stack.0, 0
-/// ```
-/// To:
-/// ```
-/// $z0 = LDR_ZXI %stack.0, 0
-/// $p0 = PTRUE_B 31, implicit $vg
-/// $p0 = CMPNE_PPzZI_B $p0, $z0, 0, implicit-def $nzcv, implicit-def $nzcv
-/// ```
-/// While ensuring a ZPR ($z0 in this example) is free for the predicate (
-/// spilling if necessary). If the status flags are in use at the point of
-/// expansion they are preserved (by moving them to/from a GPR). This may cause
-/// an additional spill if no GPR is free at the expansion point.
-static bool expandFillPPRFromZPRSlotPseudo(
- MachineBasicBlock &MBB, MachineInstr &MI, const TargetRegisterInfo &TRI,
- LiveRegUnits const &UsedRegs, ScavengeableRegs const &SR,
- MachineInstr *&LastPTrue, EmergencyStackSlots &SpillSlots) {
- MachineFunction &MF = *MBB.getParent();
- auto *TII =
- static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());
-
- ScopedScavengeOrSpill ZPredReg(
- MF, MBB, MI, AArch64::Z0, AArch64::ZPRRegClass, UsedRegs, SR.ZPRRegs,
- isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.ZPRSpillFI);
-
- ScopedScavengeOrSpill PredReg(
- MF, MBB, MI, AArch64::P0, AArch64::PPR_3bRegClass, UsedRegs, SR.PPR3bRegs,
- isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.PPRSpillFI,
- /*PreferredReg=*/
- LastPTrue ? LastPTrue->getOperand(0).getReg() : AArch64::NoRegister);
-
- // Elide NZCV spills if we know it is not used.
- bool IsNZCVUsed = !UsedRegs.available(AArch64::NZCV);
- std::optional<ScopedScavengeOrSpill> NZCVSaveReg;
- if (IsNZCVUsed)
- NZCVSaveReg.emplace(
- MF, MBB, MI, AArch64::X0, AArch64::GPR64RegClass, UsedRegs, SR.GPRRegs,
- isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.GPRSpillFI);
- SmallVector<MachineInstr *, 4> MachineInstrs;
- const DebugLoc &DL = MI.getDebugLoc();
- MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::LDR_ZXI))
- .addReg(*ZPredReg, RegState::Define)
- .add(MI.getOperand(1))
- .addImm(MI.getOperand(2).getImm())
- .setMemRefs(MI.memoperands())
- .getInstr());
- if (IsNZCVUsed)
- MachineInstrs.push_back(
- BuildMI(MBB, MI, DL, TII->get(AArch64::MRS))
- .addReg(NZCVSaveReg->freeRegister(), RegState::Define)
- .addImm(AArch64SysReg::NZCV)
- .addReg(AArch64::NZCV, RegState::Implicit)
- .getInstr());
-
- // Reuse previous ptrue if we know it has not been clobbered.
- if (LastPTrue) {
- assert(*PredReg == LastPTrue->getOperand(0).getReg());
- LastPTrue->moveBefore(&MI);
- } else {
- LastPTrue = BuildMI(MBB, MI, DL, TII->get(AArch64::PTRUE_B))
- .addReg(*PredReg, RegState::Define)
- .addImm(31);
- }
- MachineInstrs.push_back(LastPTrue);
- MachineInstrs.push_back(
- BuildMI(MBB, MI, DL, TII->get(AArch64::CMPNE_PPzZI_B))
- .addReg(MI.getOperand(0).getReg(), RegState::Define)
- .addReg(*PredReg)
- .addReg(*ZPredReg)
- .addImm(0)
- .addReg(AArch64::NZCV, RegState::ImplicitDefine)
- .getInstr());
- if (IsNZCVUsed)
- MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::MSR))
- .addImm(AArch64SysReg::NZCV)
- .addReg(NZCVSaveReg->freeRegister())
- .addReg(AArch64::NZCV, RegState::ImplicitDefine)
- .getInstr());
-
- propagateFrameFlags(MI, MachineInstrs);
- return PredReg.hasSpilled();
-}
-
-/// Expands all FILL_PPR_FROM_ZPR_SLOT_PSEUDO and SPILL_PPR_TO_ZPR_SLOT_PSEUDO
-/// operations within the MachineBasicBlock \p MBB.
-static bool expandSMEPPRToZPRSpillPseudos(MachineBasicBlock &MBB,
- const TargetRegisterInfo &TRI,
- ScavengeableRegs const &SR,
- EmergencyStackSlots &SpillSlots) {
- LiveRegUnits UsedRegs(TRI);
- UsedRegs.addLiveOuts(MBB);
- bool HasPPRSpills = false;
- MachineInstr *LastPTrue = nullptr;
- for (MachineInstr &MI : make_early_inc_range(reverse(MBB))) {
- UsedRegs.stepBackward(MI);
- switch (MI.getOpcode()) {
- case AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO:
- if (LastPTrue &&
- MI.definesRegister(LastPTrue->getOperand(0).getReg(), &TRI))
- LastPTrue = nullptr;
- HasPPRSpills |= expandFillPPRFromZPRSlotPseudo(MBB, MI, TRI, UsedRegs, SR,
- LastPTrue, SpillSlots);
- MI.eraseFromParent();
- break;
- case AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO:
- expandSpillPPRToZPRSlotPseudo(MBB, MI, TRI, UsedRegs, SR, SpillSlots);
- MI.eraseFromParent();
- [[fallthrough]];
- default:
- LastPTrue = nullptr;
- break;
- }
- }
-
- return HasPPRSpills;
-}
-
void AArch64FrameLowering::processFunctionBeforeFrameFinalized(
MachineFunction &MF, RegScavenger *RS) const {
-
- AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
- const TargetSubtargetInfo &TSI = MF.getSubtarget();
- const TargetRegisterInfo &TRI = *TSI.getRegisterInfo();
-
- // If predicates spills are 16-bytes we may need to expand
- // SPILL_PPR_TO_ZPR_SLOT_PSEUDO/FILL_PPR_FROM_ZPR_SLOT_PSEUDO.
- if (AFI->hasStackFrame() && TRI.getSpillSize(AArch64::PPRRegClass) == 16) {
- auto ComputeScavengeableRegisters = [&](unsigned RegClassID) {
- BitVector Regs = TRI.getAllocatableSet(MF, TRI.getRegClass(RegClassID));
- assert(Regs.count() > 0 && "Expected scavengeable registers");
- return Regs;
- };
-
- ScavengeableRegs SR{};
- SR.ZPRRegs = ComputeScavengeableRegisters(AArch64::ZPRRegClassID);
- // Only p0-7 are possible as the second operand of cmpne (needed for fills).
- SR.PPR3bRegs = ComputeScavengeableRegisters(AArch64::PPR_3bRegClassID);
- SR.GPRRegs = ComputeScavengeableRegisters(AArch64::GPR64RegClassID);
-
- EmergencyStackSlots SpillSlots;
- for (MachineBasicBlock &MBB : MF) {
- // In the case we had to spill a predicate (in the range p0-p7) to reload
- // a predicate (>= p8), additional spill/fill pseudos will be created.
- // These need an additional expansion pass. Note: There will only be at
- // most two expansion passes, as spilling/filling a predicate in the range
- // p0-p7 never requires spilling another predicate.
- for (int Pass = 0; Pass < 2; Pass++) {
- bool HasPPRSpills =
- expandSMEPPRToZPRSpillPseudos(MBB, TRI, SR, SpillSlots);
- assert((Pass == 0 || !HasPPRSpills) && "Did not expect PPR spills");
- if (!HasPPRSpills)
- break;
- }
- }
- }
-
- MachineFrameInfo &MFI = MF.getFrameInfo();
-
assert(getStackGrowthDirection() == TargetFrameLowering::StackGrowsDown &&
"Upwards growing stack unsupported");
@@ -3279,6 +2949,9 @@ void AArch64FrameLowering::processFunctionBeforeFrameFinalized(
if (!MF.hasEHFunclets())
return;
+ MachineFrameInfo &MFI = MF.getFrameInfo();
+ auto *AFI = MF.getInfo<AArch64FunctionInfo>();
+
// Win64 C++ EH needs to allocate space for the catch objects in the fixed
// object area right next to the UnwindHelp object.
WinEHFuncInfo &EHInfo = *MF.getWinEHFuncInfo();
@@ -4280,18 +3953,10 @@ void AArch64FrameLowering::emitRemarks(
}
unsigned RegTy = StackAccess::AccessType::GPR;
- if (MFI.hasScalableStackID(FrameIdx)) {
- // SPILL_PPR_TO_ZPR_SLOT_PSEUDO and FILL_PPR_FROM_ZPR_SLOT_PSEUDO
- // spill/fill the predicate as a data vector (so are an FPR access).
- if (MI.getOpcode() != AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO &&
- MI.getOpcode() != AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO &&
- AArch64::PPRRegClass.contains(MI.getOperand(0).getReg())) {
- RegTy = StackAccess::PPR;
- } else
- RegTy = StackAccess::FPR;
- } else if (AArch64InstrInfo::isFpOrNEON(MI)) {
+ if (MFI.hasScalableStackID(FrameIdx))
+ RegTy = isPPRAccess(MI) ? StackAccess::PPR : StackAccess::FPR;
+ else if (AArch64InstrInfo::isFpOrNEON(MI))
RegTy = StackAccess::FPR;
- }
StackAccesses[ArrIdx].AccessTypes |= RegTy;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 5a90da1..b8761d97 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -2579,8 +2579,6 @@ unsigned AArch64InstrInfo::getLoadStoreImmIdx(unsigned Opc) {
case AArch64::STZ2Gi:
case AArch64::STZGi:
case AArch64::TAGPstack:
- case AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO:
- case AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO:
return 2;
case AArch64::LD1B_D_IMM:
case AArch64::LD1B_H_IMM:
@@ -4387,8 +4385,6 @@ bool AArch64InstrInfo::getMemOpInfo(unsigned Opcode, TypeSize &Scale,
MinOffset = -256;
MaxOffset = 254;
break;
- case AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO:
- case AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO:
case AArch64::LDR_ZXI:
case AArch64::STR_ZXI:
Scale = TypeSize::getScalable(16);
@@ -5098,33 +5094,31 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
BuildMI(MBB, I, DL, get(AArch64::MOVZWi), DestReg)
.addImm(0)
.addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0));
+ } else if (Subtarget.hasZeroCycleRegMoveGPR64() &&
+ !Subtarget.hasZeroCycleRegMoveGPR32()) {
+ // Cyclone recognizes "ORR Xd, XZR, Xm" as a zero-cycle register move.
+ MCRegister DestRegX = TRI->getMatchingSuperReg(DestReg, AArch64::sub_32,
+ &AArch64::GPR64spRegClass);
+ assert(DestRegX.isValid() && "Destination super-reg not valid");
+ MCRegister SrcRegX =
+ SrcReg == AArch64::WZR
+ ? AArch64::XZR
+ : TRI->getMatchingSuperReg(SrcReg, AArch64::sub_32,
+ &AArch64::GPR64spRegClass);
+ assert(SrcRegX.isValid() && "Source super-reg not valid");
+ // This instruction is reading and writing X registers. This may upset
+ // the register scavenger and machine verifier, so we need to indicate
+ // that we are reading an undefined value from SrcRegX, but a proper
+ // value from SrcReg.
+ BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestRegX)
+ .addReg(AArch64::XZR)
+ .addReg(SrcRegX, RegState::Undef)
+ .addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc));
} else {
- if (Subtarget.hasZeroCycleRegMoveGPR64() &&
- !Subtarget.hasZeroCycleRegMoveGPR32()) {
- // Cyclone recognizes "ORR Xd, XZR, Xm" as a zero-cycle register move.
- MCRegister DestRegX = TRI->getMatchingSuperReg(
- DestReg, AArch64::sub_32, &AArch64::GPR64spRegClass);
- assert(DestRegX.isValid() && "Destination super-reg not valid");
- MCRegister SrcRegX =
- SrcReg == AArch64::WZR
- ? AArch64::XZR
- : TRI->getMatchingSuperReg(SrcReg, AArch64::sub_32,
- &AArch64::GPR64spRegClass);
- assert(SrcRegX.isValid() && "Source super-reg not valid");
- // This instruction is reading and writing X registers. This may upset
- // the register scavenger and machine verifier, so we need to indicate
- // that we are reading an undefined value from SrcRegX, but a proper
- // value from SrcReg.
- BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestRegX)
- .addReg(AArch64::XZR)
- .addReg(SrcRegX, RegState::Undef)
- .addReg(SrcReg, RegState::Implicit | getKillRegState(KillSrc));
- } else {
- // Otherwise, expand to ORR WZR.
- BuildMI(MBB, I, DL, get(AArch64::ORRWrr), DestReg)
- .addReg(AArch64::WZR)
- .addReg(SrcReg, getKillRegState(KillSrc));
- }
+ // Otherwise, expand to ORR WZR.
+ BuildMI(MBB, I, DL, get(AArch64::ORRWrr), DestReg)
+ .addReg(AArch64::WZR)
+ .addReg(SrcReg, getKillRegState(KillSrc));
}
return;
}
@@ -5650,11 +5644,6 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
"Unexpected register store without SVE store instructions");
Opc = AArch64::STR_ZXI;
StackID = TargetStackID::ScalableVector;
- } else if (AArch64::PPRRegClass.hasSubClassEq(RC)) {
- assert(Subtarget.isSVEorStreamingSVEAvailable() &&
- "Unexpected predicate store without SVE store instructions");
- Opc = AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO;
- StackID = TargetStackID::ScalableVector;
}
break;
case 24:
@@ -5835,11 +5824,6 @@ void AArch64InstrInfo::loadRegFromStackSlot(
"Unexpected register load without SVE load instructions");
Opc = AArch64::LDR_ZXI;
StackID = TargetStackID::ScalableVector;
- } else if (AArch64::PPRRegClass.hasSubClassEq(RC)) {
- assert(Subtarget.isSVEorStreamingSVEAvailable() &&
- "Unexpected predicate load without SVE load instructions");
- Opc = AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO;
- StackID = TargetStackID::ScalableVector;
}
break;
case 24:
diff --git a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp
index aed137c..1568161 100644
--- a/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp
+++ b/llvm/lib/Target/AArch64/AArch64PrologueEpilogue.cpp
@@ -57,10 +57,7 @@ static bool isPartOfZPRCalleeSaves(MachineBasicBlock::iterator I) {
case AArch64::ST1B_2Z_IMM:
case AArch64::STR_ZXI:
case AArch64::LDR_ZXI:
- case AArch64::CPY_ZPzI_B:
- case AArch64::CMPNE_PPzZI_B:
case AArch64::PTRUE_C_B:
- case AArch64::PTRUE_B:
return I->getFlag(MachineInstr::FrameSetup) ||
I->getFlag(MachineInstr::FrameDestroy);
case AArch64::SEH_SavePReg:
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
index 5d89862..ef974df 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -980,19 +980,10 @@ class ZPRRegOp <string Suffix, AsmOperandClass C, ElementSizeEnum Size,
//******************************************************************************
// SVE predicate register classes.
-
-// Note: This hardware mode is enabled in AArch64Subtarget::getHwModeSet()
-// (without the use of the table-gen'd predicates).
-def SMEWithZPRPredicateSpills : HwMode<[Predicate<"false">]>;
-
-def PPRSpillFillRI : RegInfoByHwMode<
- [DefaultMode, SMEWithZPRPredicateSpills],
- [RegInfo<16,16,16>, RegInfo<16,128,128>]>;
-
class PPRClass<int firstreg, int lastreg, int step = 1> : RegisterClass<"AArch64",
[ nxv16i1, nxv8i1, nxv4i1, nxv2i1, nxv1i1 ], 16,
(sequence "P%u", firstreg, lastreg, step)> {
- let RegInfos = PPRSpillFillRI;
+ let Size = 16;
}
def PPR : PPRClass<0, 15> {
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
index 98e0a11..12ddf47 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
@@ -86,11 +86,6 @@ static cl::alias AArch64StreamingStackHazardSize(
cl::desc("alias for -aarch64-streaming-hazard-size"),
cl::aliasopt(AArch64StreamingHazardSize));
-static cl::opt<bool> EnableZPRPredicateSpills(
- "aarch64-enable-zpr-predicate-spills", cl::init(false), cl::Hidden,
- cl::desc(
- "Enables spilling/reloading SVE predicates as data vectors (ZPRs)"));
-
static cl::opt<unsigned>
VScaleForTuningOpt("sve-vscale-for-tuning", cl::Hidden,
cl::desc("Force a vscale for tuning factor for SVE"));
@@ -426,20 +421,6 @@ AArch64Subtarget::AArch64Subtarget(const Triple &TT, StringRef CPU,
EnableSubregLiveness = EnableSubregLivenessTracking.getValue();
}
-unsigned AArch64Subtarget::getHwModeSet() const {
- AArch64HwModeBits Modes = AArch64HwModeBits::DefaultMode;
-
- // Use a special hardware mode in streaming[-compatible] functions with
- // aarch64-enable-zpr-predicate-spills. This changes the spill size (and
- // alignment) for the predicate register class.
- if (EnableZPRPredicateSpills.getValue() &&
- (isStreaming() || isStreamingCompatible())) {
- Modes |= AArch64HwModeBits::SMEWithZPRPredicateSpills;
- }
-
- return to_underlying(Modes);
-}
-
const CallLowering *AArch64Subtarget::getCallLowering() const {
return CallLoweringInfo.get();
}
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h
index 671df35..8974965 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -130,8 +130,6 @@ public:
bool IsStreaming = false, bool IsStreamingCompatible = false,
bool HasMinSize = false);
- virtual unsigned getHwModeSet() const override;
-
// Getters for SubtargetFeatures defined in tablegen
#define GET_SUBTARGETINFO_MACRO(ATTRIBUTE, DEFAULT, GETTER) \
bool GETTER() const { return ATTRIBUTE; }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 50a8754..479e345 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5666,18 +5666,21 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
VectorType *AccumVectorType =
VectorType::get(AccumType, VF.divideCoefficientBy(Ratio));
// We don't yet support all kinds of legalization.
- auto TA = TLI->getTypeAction(AccumVectorType->getContext(),
- EVT::getEVT(AccumVectorType));
- switch (TA) {
+ auto TC = TLI->getTypeConversion(AccumVectorType->getContext(),
+ EVT::getEVT(AccumVectorType));
+ switch (TC.first) {
default:
return Invalid;
case TargetLowering::TypeLegal:
case TargetLowering::TypePromoteInteger:
case TargetLowering::TypeSplitVector:
+ // The legalised type (e.g. after splitting) must be legal too.
+ if (TLI->getTypeAction(AccumVectorType->getContext(), TC.second) !=
+ TargetLowering::TypeLegal)
+ return Invalid;
break;
}
- // Check what kind of type-legalisation happens.
std::pair<InstructionCost, MVT> AccumLT =
getTypeLegalizationCost(AccumVectorType);
std::pair<InstructionCost, MVT> InputLT =
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
index be44b8f..33f35ad 100644
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -58,20 +58,6 @@ def FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO :
let hasSideEffects = 0;
}
-def SPILL_PPR_TO_ZPR_SLOT_PSEUDO :
- Pseudo<(outs), (ins PPRorPNRAny:$Pt, GPR64sp:$Rn, simm9:$imm9), []>, Sched<[]>
-{
- let mayStore = 1;
- let hasSideEffects = 0;
-}
-
-def FILL_PPR_FROM_ZPR_SLOT_PSEUDO :
- Pseudo<(outs PPRorPNRAny:$Pt), (ins GPR64sp:$Rn, simm9:$imm9), []>, Sched<[]>
-{
- let mayLoad = 1;
- let hasSideEffects = 0;
-}
-
def SDTZALoadStore : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>]>;
// SME ZA loads and stores
def AArch64SMELdr : SDNode<"AArch64ISD::SME_ZA_LDR", SDTZALoadStore,
diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td
index 9446144..1a697f7 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPU.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPU.td
@@ -1411,20 +1411,6 @@ def FeatureGloballyAddressableScratch : SubtargetFeature<
"FLAT instructions can access scratch memory for any thread in any wave"
>;
-// FIXME: Remove after all users are migrated to attribute.
-def FeatureDynamicVGPR : SubtargetFeature <"dynamic-vgpr",
- "DynamicVGPR",
- "true",
- "Enable dynamic VGPR mode"
->;
-
-// FIXME: Remove after all users are migrated to attribute.
-def FeatureDynamicVGPRBlockSize32 : SubtargetFeature<"dynamic-vgpr-block-size-32",
- "DynamicVGPRBlockSize32",
- "true",
- "Use a block size of 32 for dynamic VGPR allocation (default is 16)"
->;
-
// Enable the use of SCRATCH_STORE/LOAD_BLOCK instructions for saving and
// restoring the callee-saved registers.
def FeatureUseBlockVGPROpsForCSR : SubtargetFeature<"block-vgpr-csr",
@@ -1462,6 +1448,12 @@ def Feature45BitNumRecordsBufferResource : SubtargetFeature< "45-bit-num-records
"The buffer resource (V#) supports 45-bit num_records"
>;
+def FeatureClusters : SubtargetFeature< "clusters",
+ "HasClusters",
+ "true",
+ "Has clusters of workgroups support"
+>;
+
// Dummy feature used to disable assembler instructions.
def FeatureDisable : SubtargetFeature<"",
"FeatureDisable","true",
@@ -2128,6 +2120,7 @@ def FeatureISAVersion12_50 : FeatureSet<
Feature45BitNumRecordsBufferResource,
FeatureSupportsXNACK,
FeatureXNACK,
+ FeatureClusters,
]>;
def FeatureISAVersion12_51 : FeatureSet<
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index a67a7be..d0c0822 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -1944,6 +1944,7 @@ public:
void cvtVOP3Interp(MCInst &Inst, const OperandVector &Operands);
void cvtVINTERP(MCInst &Inst, const OperandVector &Operands);
+ void cvtOpSelHelper(MCInst &Inst, unsigned OpSel);
bool parseDimId(unsigned &Encoding);
ParseStatus parseDim(OperandVector &Operands);
@@ -9239,6 +9240,33 @@ static bool isRegOrImmWithInputMods(const MCInstrDesc &Desc, unsigned OpNum) {
MCOI::OperandConstraint::TIED_TO) == -1;
}
+void AMDGPUAsmParser::cvtOpSelHelper(MCInst &Inst, unsigned OpSel) {
+ unsigned Opc = Inst.getOpcode();
+ constexpr AMDGPU::OpName Ops[] = {AMDGPU::OpName::src0, AMDGPU::OpName::src1,
+ AMDGPU::OpName::src2};
+ constexpr AMDGPU::OpName ModOps[] = {AMDGPU::OpName::src0_modifiers,
+ AMDGPU::OpName::src1_modifiers,
+ AMDGPU::OpName::src2_modifiers};
+ for (int J = 0; J < 3; ++J) {
+ int OpIdx = AMDGPU::getNamedOperandIdx(Opc, Ops[J]);
+ if (OpIdx == -1)
+ // Some instructions, e.g. v_interp_p2_f16 in GFX9, have src0, src2, but
+ // no src1. So continue instead of break.
+ continue;
+
+ int ModIdx = AMDGPU::getNamedOperandIdx(Opc, ModOps[J]);
+ uint32_t ModVal = Inst.getOperand(ModIdx).getImm();
+
+ if ((OpSel & (1 << J)) != 0)
+ ModVal |= SISrcMods::OP_SEL_0;
+ // op_sel[3] is encoded in src0_modifiers.
+ if (ModOps[J] == AMDGPU::OpName::src0_modifiers && (OpSel & (1 << 3)) != 0)
+ ModVal |= SISrcMods::DST_OP_SEL;
+
+ Inst.getOperand(ModIdx).setImm(ModVal);
+ }
+}
+
void AMDGPUAsmParser::cvtVOP3Interp(MCInst &Inst, const OperandVector &Operands)
{
OptionalImmIndexMap OptionalIdx;
@@ -9275,6 +9303,16 @@ void AMDGPUAsmParser::cvtVOP3Interp(MCInst &Inst, const OperandVector &Operands)
if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::omod))
addOptionalImmOperand(Inst, Operands, OptionalIdx,
AMDGPUOperand::ImmTyOModSI);
+
+ // Some v_interp instructions use op_sel[3] for dst.
+ if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::op_sel)) {
+ addOptionalImmOperand(Inst, Operands, OptionalIdx,
+ AMDGPUOperand::ImmTyOpSel);
+ int OpSelIdx = AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::op_sel);
+ unsigned OpSel = Inst.getOperand(OpSelIdx).getImm();
+
+ cvtOpSelHelper(Inst, OpSel);
+ }
}
void AMDGPUAsmParser::cvtVINTERP(MCInst &Inst, const OperandVector &Operands)
@@ -9310,31 +9348,10 @@ void AMDGPUAsmParser::cvtVINTERP(MCInst &Inst, const OperandVector &Operands)
if (OpSelIdx == -1)
return;
- const AMDGPU::OpName Ops[] = {AMDGPU::OpName::src0, AMDGPU::OpName::src1,
- AMDGPU::OpName::src2};
- const AMDGPU::OpName ModOps[] = {AMDGPU::OpName::src0_modifiers,
- AMDGPU::OpName::src1_modifiers,
- AMDGPU::OpName::src2_modifiers};
-
unsigned OpSel = Inst.getOperand(OpSelIdx).getImm();
-
- for (int J = 0; J < 3; ++J) {
- int OpIdx = AMDGPU::getNamedOperandIdx(Opc, Ops[J]);
- if (OpIdx == -1)
- break;
-
- int ModIdx = AMDGPU::getNamedOperandIdx(Opc, ModOps[J]);
- uint32_t ModVal = Inst.getOperand(ModIdx).getImm();
-
- if ((OpSel & (1 << J)) != 0)
- ModVal |= SISrcMods::OP_SEL_0;
- if (ModOps[J] == AMDGPU::OpName::src0_modifiers &&
- (OpSel & (1 << 3)) != 0)
- ModVal |= SISrcMods::DST_OP_SEL;
-
- Inst.getOperand(ModIdx).setImm(ModVal);
- }
+ cvtOpSelHelper(Inst, OpSel);
}
+
void AMDGPUAsmParser::cvtScaledMFMA(MCInst &Inst,
const OperandVector &Operands) {
OptionalImmIndexMap OptionalIdx;
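For reference, a standalone sketch of the bit manipulation cvtOpSelHelper performs. The OP_SEL_0/DST_OP_SEL values mirror what SISrcMods uses but are written out here as assumptions rather than taken from the header.

#include <cstdint>
#include <cstdio>

constexpr uint32_t OP_SEL_0 = 1u << 2;   // select high half of the source
constexpr uint32_t DST_OP_SEL = 1u << 3; // op_sel[3]: high half of the dest

// ModVals[J] is the current src<J>_modifiers immediate; HasSrc[J] says whether
// the instruction actually has that operand (e.g. v_interp_p2_f16 has src0 and
// src2 but no src1).
void applyOpSel(uint32_t OpSel, const bool HasSrc[3], uint32_t ModVals[3]) {
  for (int J = 0; J < 3; ++J) {
    if (!HasSrc[J])
      continue; // skip missing operands instead of stopping at the first gap
    if (OpSel & (1u << J))
      ModVals[J] |= OP_SEL_0;
    if (J == 0 && (OpSel & (1u << 3)))
      ModVals[J] |= DST_OP_SEL; // op_sel[3] is carried in src0_modifiers
  }
}

int main() {
  const bool HasSrc[3] = {true, false, true}; // src0 and src2 only
  uint32_t Mods[3] = {0, 0, 0};
  applyOpSel(/*OpSel=*/0b1101, HasSrc, Mods);
  std::printf("%x %x %x\n", Mods[0], Mods[1], Mods[2]); // prints: c 0 4
}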
diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp
index 7b94ea3..f291e37 100644
--- a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp
@@ -541,7 +541,7 @@ unsigned GCNSubtarget::getMaxNumSGPRs(const Function &F) const {
unsigned GCNSubtarget::getBaseMaxNumVGPRs(
const Function &F, std::pair<unsigned, unsigned> NumVGPRBounds) const {
- const auto &[Min, Max] = NumVGPRBounds;
+ const auto [Min, Max] = NumVGPRBounds;
// Check if maximum number of VGPRs was explicitly requested using
// "amdgpu-num-vgpr" attribute.
diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.h b/llvm/lib/Target/AMDGPU/GCNSubtarget.h
index a54d665..c2e6078 100644
--- a/llvm/lib/Target/AMDGPU/GCNSubtarget.h
+++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.h
@@ -288,6 +288,8 @@ protected:
bool Has45BitNumRecordsBufferResource = false;
+ bool HasClusters = false;
+
// Dummy feature to use for assembler in tablegen.
bool FeatureDisable = false;
@@ -1837,7 +1839,7 @@ public:
}
/// \returns true if the subtarget supports clusters of workgroups.
- bool hasClusters() const { return GFX1250Insts; }
+ bool hasClusters() const { return HasClusters; }
/// \returns true if the subtarget requires a wait for xcnt before atomic
/// flat/global stores & rmw.
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
index d3b5718..3563caa 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
@@ -1280,6 +1280,17 @@ void AMDGPUInstPrinter::printPackedModifier(const MCInst *MI,
(ModIdx != -1) ? MI->getOperand(ModIdx).getImm() : DefaultValue;
}
+ // Some instructions, e.g. v_interp_p2_f16 in GFX9, have src0, src2, but no
+ // src1.
+ if (NumOps == 1 && AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::src2) &&
+ !AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::src1)) {
+ Ops[NumOps++] = DefaultValue; // Set src1_modifiers to default.
+ int Mod2Idx =
+ AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src2_modifiers);
+ assert(Mod2Idx != -1);
+ Ops[NumOps++] = MI->getOperand(Mod2Idx).getImm();
+ }
+
const bool HasDst =
(AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::vdst) != -1) ||
(AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::sdst) != -1);
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
index 3c2dd42..3115579 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
@@ -1118,12 +1118,7 @@ SIRegisterInfo::getPointerRegClass(unsigned Kind) const {
const TargetRegisterClass *
SIRegisterInfo::getCrossCopyRegClass(const TargetRegisterClass *RC) const {
- if (isAGPRClass(RC) && !ST.hasGFX90AInsts())
- return getEquivalentVGPRClass(RC);
- if (RC == &AMDGPU::SCC_CLASSRegClass)
- return getWaveMaskRegClass();
-
- return RC;
+ return RC == &AMDGPU::SCC_CLASSRegClass ? &AMDGPU::SReg_32RegClass : RC;
}
static unsigned getNumSubRegsForSpillOp(const MachineInstr &MI,
diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp
index 20fa141..f7f4d46 100644
--- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp
@@ -1353,11 +1353,6 @@ unsigned getVGPRAllocGranule(const MCSubtargetInfo *STI,
if (DynamicVGPRBlockSize != 0)
return DynamicVGPRBlockSize;
- // Temporarily check the subtarget feature, until we fully switch to using
- // attributes.
- if (STI->getFeatureBits().test(FeatureDynamicVGPR))
- return STI->getFeatureBits().test(FeatureDynamicVGPRBlockSize32) ? 32 : 16;
-
bool IsWave32 = EnableWavefrontSize32
? *EnableWavefrontSize32
: STI->getFeatureBits().test(FeatureWavefrontSize32);
@@ -1412,10 +1407,7 @@ unsigned getAddressableNumVGPRs(const MCSubtargetInfo *STI,
if (Features.test(FeatureGFX90AInsts))
return 512;
- // Temporarily check the subtarget feature, until we fully switch to using
- // attributes.
- if (DynamicVGPRBlockSize != 0 ||
- STI->getFeatureBits().test(FeatureDynamicVGPR))
+ if (DynamicVGPRBlockSize != 0)
// On GFX12 we can allocate at most 8 blocks of VGPRs.
return 8 * getVGPRAllocGranule(STI, DynamicVGPRBlockSize);
return getAddressableNumArchVGPRs(STI);
diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td
index 4a2b54d..42ec8ba 100644
--- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td
@@ -97,6 +97,7 @@ class VOP3Interp<string OpName, VOPProfile P, list<dag> pattern = []> :
VOP3_Pseudo<OpName, P, pattern> {
let AsmMatchConverter = "cvtVOP3Interp";
let mayRaiseFPException = 0;
+ let VOP3_OPSEL = P.HasOpSel;
}
def VOP3_INTERP : VOPProfile<[f32, f32, i32, untyped]> {
@@ -119,16 +120,17 @@ def VOP3_INTERP_MOV : VOPProfile<[f32, i32, i32, untyped]> {
let HasSrc0Mods = 0;
}
-class getInterp16Asm <bit HasSrc2, bit HasOMod> {
+class getInterp16Asm <bit HasSrc2, bit HasOMod, bit OpSel> {
string src2 = !if(HasSrc2, ", $src2_modifiers", "");
string omod = !if(HasOMod, "$omod", "");
+ string opsel = !if(OpSel, "$op_sel", "");
string ret =
- " $vdst, $src0_modifiers, $attr$attrchan"#src2#"$high$clamp"#omod;
+ " $vdst, $src0_modifiers, $attr$attrchan"#src2#"$high$clamp"#omod#opsel;
}
class getInterp16Ins <bit HasSrc2, bit HasOMod,
- Operand Src0Mod, Operand Src2Mod> {
- dag ret = !if(HasSrc2,
+ Operand Src0Mod, Operand Src2Mod, bit OpSel> {
+ dag ret1 = !if(HasSrc2,
!if(HasOMod,
(ins Src0Mod:$src0_modifiers, VRegSrc_32:$src0,
InterpAttr:$attr, InterpAttrChan:$attrchan,
@@ -143,19 +145,22 @@ class getInterp16Ins <bit HasSrc2, bit HasOMod,
InterpAttr:$attr, InterpAttrChan:$attrchan,
highmod:$high, Clamp0:$clamp, omod0:$omod)
);
+ dag ret2 = !if(OpSel, (ins op_sel0:$op_sel), (ins));
+ dag ret = !con(ret1, ret2);
}
-class VOP3_INTERP16 <list<ValueType> ArgVT> : VOPProfile<ArgVT> {
+class VOP3_INTERP16 <list<ValueType> ArgVT, bit OpSel = 0> : VOPProfile<ArgVT> {
let IsSingle = 1;
let HasOMod = !ne(DstVT.Value, f16.Value);
let HasHigh = 1;
+ let HasOpSel = OpSel;
let Src0Mod = FPVRegInputMods;
let Src2Mod = FPVRegInputMods;
let Outs64 = (outs DstRC.RegClass:$vdst);
- let Ins64 = getInterp16Ins<HasSrc2, HasOMod, Src0Mod, Src2Mod>.ret;
- let Asm64 = getInterp16Asm<HasSrc2, HasOMod>.ret;
+ let Ins64 = getInterp16Ins<HasSrc2, HasOMod, Src0Mod, Src2Mod, OpSel>.ret;
+ let Asm64 = getInterp16Asm<HasSrc2, HasOMod, OpSel>.ret;
}
//===----------------------------------------------------------------------===//
@@ -480,7 +485,7 @@ let SubtargetPredicate = isGFX9Plus in {
defm V_MAD_U16_gfx9 : VOP3Inst_t16 <"v_mad_u16_gfx9", VOP_I16_I16_I16_I16>;
defm V_MAD_I16_gfx9 : VOP3Inst_t16 <"v_mad_i16_gfx9", VOP_I16_I16_I16_I16>;
let OtherPredicates = [isNotGFX90APlus] in
-def V_INTERP_P2_F16_gfx9 : VOP3Interp <"v_interp_p2_f16_gfx9", VOP3_INTERP16<[f16, f32, i32, f32]>>;
+def V_INTERP_P2_F16_opsel : VOP3Interp <"v_interp_p2_f16_opsel", VOP3_INTERP16<[f16, f32, i32, f32], /*OpSel*/ 1>>;
} // End SubtargetPredicate = isGFX9Plus
// This predicate should only apply to the selection pattern. The
@@ -2676,6 +2681,14 @@ multiclass VOP3Interp_F16_Real_gfx9<bits<10> op, string OpName, string AsmName>
}
}
+multiclass VOP3Interp_F16_OpSel_Real_gfx9<bits<10> op, string OpName, string AsmName> {
+ def _gfx9 : VOP3_Real<!cast<VOP3_Pseudo>(OpName), SIEncodingFamily.GFX9>,
+ VOP3Interp_OpSel_gfx9 <op, !cast<VOP3_Pseudo>(OpName).Pfl> {
+ VOP3_Pseudo ps = !cast<VOP3_Pseudo>(OpName);
+ let AsmString = AsmName # ps.AsmOperands;
+ }
+}
+
multiclass VOP3_Real_gfx9<bits<10> op, string AsmName> {
def _gfx9 : VOP3_Real<!cast<VOP_Pseudo>(NAME#"_e64"), SIEncodingFamily.GFX9>,
VOP3e_vi <op, !cast<VOP_Pseudo>(NAME#"_e64").Pfl> {
@@ -2788,7 +2801,7 @@ defm V_MAD_U16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x204, "v_mad_u16">;
defm V_MAD_I16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x205, "v_mad_i16">;
defm V_FMA_F16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x206, "v_fma_f16">;
defm V_DIV_FIXUP_F16_gfx9 : VOP3OpSel_F16_Real_gfx9 <0x207, "v_div_fixup_f16">;
-defm V_INTERP_P2_F16_gfx9 : VOP3Interp_F16_Real_gfx9 <0x277, "V_INTERP_P2_F16_gfx9", "v_interp_p2_f16">;
+defm V_INTERP_P2_F16_opsel : VOP3Interp_F16_OpSel_Real_gfx9 <0x277, "V_INTERP_P2_F16_opsel", "v_interp_p2_f16">;
defm V_ADD_I32 : VOP3_Real_vi <0x29c>;
defm V_SUB_I32 : VOP3_Real_vi <0x29d>;
diff --git a/llvm/lib/Target/AMDGPU/VOPInstructions.td b/llvm/lib/Target/AMDGPU/VOPInstructions.td
index 631f0f3..8325c62 100644
--- a/llvm/lib/Target/AMDGPU/VOPInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOPInstructions.td
@@ -419,6 +419,13 @@ class VOP3a_ScaleSel_gfx1250<bits<10> op, VOPProfile p> : VOP3e_gfx11_gfx12<op,
let Inst{14-11} = scale_sel;
}
+class VOP3Interp_OpSel_gfx9<bits<10> op, VOPProfile p> : VOP3Interp_vi<op, p> {
+ let Inst{11} = src0_modifiers{2};
+ // There's no src1
+ let Inst{13} = src2_modifiers{2};
+ let Inst{14} = !if(p.HasDst, src0_modifiers{3}, 0);
+}
+
class VOP3Interp_gfx10<bits<10> op, VOPProfile p> : VOP3e_gfx10<op, p> {
bits<6> attr;
bits<2> attrchan;
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index f9bdc09..77913f2 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -149,6 +149,9 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
case NVPTX::PTXCvtMode::RNA:
O << ".rna";
return;
+ case NVPTX::PTXCvtMode::RS:
+ O << ".rs";
+ return;
}
}
llvm_unreachable("Invalid conversion modifier");
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 77a0e03..1e0f747 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -207,6 +207,7 @@ enum CvtMode {
RM,
RP,
RNA,
+ RS,
BASE_MASK = 0x0F,
FTZ_FLAG = 0x10,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 8c21746..bc047a4a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1096,9 +1096,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// Enable custom lowering for the following:
// * MVT::i128 - clusterlaunchcontrol
// * MVT::i32 - prmt
+ // * MVT::v4f32 - cvt_rs fp{4/6/8}x4 intrinsics
// * MVT::Other - internal.addrspace.wrap
- setOperationAction(ISD::INTRINSIC_WO_CHAIN, {MVT::i32, MVT::i128, MVT::Other},
- Custom);
+ setOperationAction(ISD::INTRINSIC_WO_CHAIN,
+ {MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Custom);
}
const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
@@ -1181,6 +1182,11 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG1_ASHIFT)
MAKE_CASE(
NVPTXISD::TCGEN05_MMA_SP_TENSOR_SCALE_D_DISABLE_OUTPUT_LANE_CG2_ASHIFT)
+ MAKE_CASE(NVPTXISD::CVT_E4M3X4_F32X4_RS_SF)
+ MAKE_CASE(NVPTXISD::CVT_E5M2X4_F32X4_RS_SF)
+ MAKE_CASE(NVPTXISD::CVT_E2M3X4_F32X4_RS_SF)
+ MAKE_CASE(NVPTXISD::CVT_E3M2X4_F32X4_RS_SF)
+ MAKE_CASE(NVPTXISD::CVT_E2M1X4_F32X4_RS_SF)
}
return nullptr;
@@ -2903,6 +2909,61 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
{TryCancelResponse0, TryCancelResponse1});
}
+static SDValue lowerCvtRSIntrinsics(SDValue Op, SelectionDAG &DAG) {
+ SDNode *N = Op.getNode();
+ SDLoc DL(N);
+ SDValue F32Vec = N->getOperand(1);
+ SDValue RBits = N->getOperand(2);
+
+ unsigned IntrinsicID = N->getConstantOperandVal(0);
+
+ // Extract the 4 float elements from the vector
+ SmallVector<SDValue, 6> Ops;
+ for (unsigned i = 0; i < 4; ++i)
+ Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, F32Vec,
+ DAG.getIntPtrConstant(i, DL)));
+
+ using NVPTX::PTXCvtMode::CvtMode;
+
+ auto [OpCode, RetTy, CvtModeFlag] =
+ [&]() -> std::tuple<NVPTXISD::NodeType, MVT::SimpleValueType, uint32_t> {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
+ return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8,
+ CvtMode::RS | CvtMode::RELU_FLAG};
+ case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
+ return {NVPTXISD::CVT_E4M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
+ case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
+ return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8,
+ CvtMode::RS | CvtMode::RELU_FLAG};
+ case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
+ return {NVPTXISD::CVT_E5M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
+ case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
+ return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8,
+ CvtMode::RS | CvtMode::RELU_FLAG};
+ case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
+ return {NVPTXISD::CVT_E2M3X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
+ case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
+ return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8,
+ CvtMode::RS | CvtMode::RELU_FLAG};
+ case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
+ return {NVPTXISD::CVT_E3M2X4_F32X4_RS_SF, MVT::v4i8, CvtMode::RS};
+ case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
+ return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16,
+ CvtMode::RS | CvtMode::RELU_FLAG};
+ case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
+ return {NVPTXISD::CVT_E2M1X4_F32X4_RS_SF, MVT::i16, CvtMode::RS};
+ default:
+ llvm_unreachable("unsupported/unhandled intrinsic");
+ }
+ }();
+
+ Ops.push_back(RBits);
+ Ops.push_back(DAG.getConstant(CvtModeFlag, DL, MVT::i32));
+
+ return DAG.getNode(OpCode, DL, RetTy, Ops);
+}
+
static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
const unsigned Mode = [&]() {
switch (Op->getConstantOperandVal(0)) {
@@ -2972,6 +3033,17 @@ static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_y:
case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_z:
return LowerClusterLaunchControlQueryCancel(Op, DAG);
+ case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite:
+ case Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite:
+ return lowerCvtRSIntrinsics(Op, DAG);
}
}
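The selection in lowerCvtRSIntrinsics relies on an immediately-invoked lambda returning a tuple that is unpacked with structured bindings. A minimal, self-contained illustration of that idiom follows; the enum and intrinsic IDs are placeholders, not the real NVPTX node or intrinsic values.

#include <cstdio>
#include <tuple>

enum class Node { E4M3, E5M2 };

int main() {
  int IntrinsicID = 1; // stand-in for the intrinsic being lowered
  // One expression picks both the target node and the relu flag.
  auto [Op, Relu] = [&]() -> std::tuple<Node, bool> {
    switch (IntrinsicID) {
    case 0: return {Node::E4M3, false};
    case 1: return {Node::E4M3, true};
    default: return {Node::E5M2, false};
    }
  }();
  std::printf("%d %d\n", static_cast<int>(Op), static_cast<int>(Relu)); // 0 1
}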
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 769d2fe..63fa0bb 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -79,6 +79,11 @@ enum NodeType : unsigned {
CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X,
CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y,
CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z,
+ CVT_E4M3X4_F32X4_RS_SF,
+ CVT_E5M2X4_F32X4_RS_SF,
+ CVT_E2M3X4_F32X4_RS_SF,
+ CVT_E3M2X4_F32X4_RS_SF,
+ CVT_E2M1X4_F32X4_RS_SF,
FIRST_MEMORY_OPCODE,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 4cacee2..6c14cf0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -34,7 +34,8 @@ def CvtRN : PatLeaf<(i32 0x5)>;
def CvtRZ : PatLeaf<(i32 0x6)>;
def CvtRM : PatLeaf<(i32 0x7)>;
def CvtRP : PatLeaf<(i32 0x8)>;
-def CvtRNA : PatLeaf<(i32 0x9)>;
+def CvtRNA : PatLeaf<(i32 0x9)>;
+def CvtRS : PatLeaf<(i32 0xA)>;
def CvtNONE_FTZ : PatLeaf<(i32 0x10)>;
def CvtRNI_FTZ : PatLeaf<(i32 0x11)>;
@@ -50,8 +51,9 @@ def CvtSAT : PatLeaf<(i32 0x20)>;
def CvtSAT_FTZ : PatLeaf<(i32 0x30)>;
def CvtNONE_RELU : PatLeaf<(i32 0x40)>;
-def CvtRN_RELU : PatLeaf<(i32 0x45)>;
-def CvtRZ_RELU : PatLeaf<(i32 0x46)>;
+def CvtRN_RELU : PatLeaf<(i32 0x45)>;
+def CvtRZ_RELU : PatLeaf<(i32 0x46)>;
+def CvtRS_RELU : PatLeaf<(i32 0x4A)>;
def CvtMode : Operand<i32> {
let PrintMethod = "printCvtMode";
@@ -133,6 +135,11 @@ def hasSM100a : Predicate<"Subtarget->getSmVersion() == 100 && Subtarget->hasArc
def hasSM101a : Predicate<"Subtarget->getSmVersion() == 101 && Subtarget->hasArchAccelFeatures()">;
def hasSM120a : Predicate<"Subtarget->getSmVersion() == 120 && Subtarget->hasArchAccelFeatures()">;
+def hasSM100aOrSM103a :
+ Predicate<"(Subtarget->getSmVersion() == 100 || " #
+ "Subtarget->getSmVersion() == 103) " #
+ "&& Subtarget->hasArchAccelFeatures()">;
+
// non-sync shfl instructions are not available on sm_70+ in PTX6.4+
def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70"
"&& Subtarget->getPTXVersion() >= 64)">;
@@ -593,6 +600,23 @@ let hasSideEffects = false in {
defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", B32>;
defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", B32>;
+
+ multiclass CVT_FROM_FLOAT_V2_RS<string FromName, RegisterClass RC> {
+ def _f32_rs :
+ BasicFlagsNVPTXInst<(outs RC:$dst),
+ (ins B32:$src1, B32:$src2, B32:$src3),
+ (ins CvtMode:$mode),
+ "cvt${mode:base}${mode:relu}." # FromName # ".f32">;
+
+ def _f32_rs_sf :
+ BasicFlagsNVPTXInst<(outs RC:$dst),
+ (ins B32:$src1, B32:$src2, B32:$src3),
+ (ins CvtMode:$mode),
+ "cvt${mode:base}${mode:relu}.satfinite." # FromName # ".f32">;
+ }
+
+ defm CVT_f16x2 : CVT_FROM_FLOAT_V2_RS<"f16x2", B32>;
+ defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_RS<"bf16x2", B32>;
// FP8 conversions.
multiclass CVT_TO_F8X2<string F8Name> {
@@ -619,6 +643,15 @@ let hasSideEffects = false in {
def CVT_f16x2_e4m3x2 : CVT_f16x2_fp8<"e4m3">;
def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">;
+
+ class CVT_TO_FP8X4<string F8Name> :
+ NVPTXInst<(outs B32:$dst),
+ (ins B32:$src1, B32:$src2, B32:$src3, B32:$src4, B32:$src5, CvtMode:$mode),
+ "cvt${mode:base}${mode:relu}.satfinite." # F8Name #
+ "x4.f32 \t$dst, {{$src1, $src2, $src3, $src4}}, $src5;">;
+
+ def CVT_e4m3x4_f32x4_rs_sf : CVT_TO_FP8X4<"e4m3">;
+ def CVT_e5m2x4_f32x4_rs_sf : CVT_TO_FP8X4<"e5m2">;
// Float to TF32 conversions
multiclass CVT_TO_TF32<string Modifier, list<Predicate> Preds = [hasPTX<78>, hasSM<90>]> {
@@ -652,6 +685,15 @@ let hasSideEffects = false in {
"cvt${mode:base}${mode:relu}.f16x2." # type>;
}
+ class CVT_TO_FP6X4<string F6Name> :
+ NVPTXInst<(outs B32:$dst),
+ (ins B32:$src1, B32:$src2, B32:$src3, B32:$src4, B32:$src5, CvtMode:$mode),
+ "cvt${mode:base}${mode:relu}.satfinite." # F6Name #
+ "x4.f32 \t$dst, {{$src1, $src2, $src3, $src4}}, $src5;">;
+
+ def CVT_e2m3x4_f32x4_rs_sf : CVT_TO_FP6X4<"e2m3">;
+ def CVT_e3m2x4_f32x4_rs_sf : CVT_TO_FP6X4<"e3m2">;
+
// FP4 conversions.
def CVT_e2m1x2_f32_sf : NVPTXInst<(outs B16:$dst),
(ins B32:$src1, B32:$src2, CvtMode:$mode),
@@ -668,6 +710,12 @@ let hasSideEffects = false in {
"cvt.u8.u16 \t%e2m1x2_in, $src; \n\t",
"cvt${mode:base}${mode:relu}.f16x2.e2m1x2 \t$dst, %e2m1x2_in; \n\t",
"}}"), []>;
+
+ def CVT_e2m1x4_f32x4_rs_sf :
+ NVPTXInst<(outs B16:$dst),
+ (ins B32:$src1, B32:$src2, B32:$src3, B32:$src4, B32:$src5, CvtMode:$mode),
+ "cvt${mode:base}${mode:relu}.satfinite.e2m1x4.f32 \t" #
+ "$dst, {{$src1, $src2, $src3, $src4}}, $src5;">;
// UE8M0x2 conversions.
class CVT_f32_to_ue8m0x2<string sat = ""> :
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index e91171c..a8b854f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1782,11 +1782,32 @@ def : Pat<(int_nvvm_ff2bf16x2_rn_relu f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, C
def : Pat<(int_nvvm_ff2bf16x2_rz f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, CvtRZ)>;
def : Pat<(int_nvvm_ff2bf16x2_rz_relu f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, CvtRZ_RELU)>;
+let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in {
+def : Pat<(int_nvvm_ff2bf16x2_rs f32:$a, f32:$b, i32:$c),
+ (CVT_bf16x2_f32_rs $a, $b, $c, CvtRS)>;
+def : Pat<(int_nvvm_ff2bf16x2_rs_relu f32:$a, f32:$b, i32:$c),
+ (CVT_bf16x2_f32_rs $a, $b, $c, CvtRS_RELU)>;
+def : Pat<(int_nvvm_ff2bf16x2_rs_satfinite f32:$a, f32:$b, i32:$c),
+ (CVT_bf16x2_f32_rs_sf $a, $b, $c, CvtRS)>;
+def : Pat<(int_nvvm_ff2bf16x2_rs_relu_satfinite f32:$a, f32:$b, i32:$c),
+ (CVT_bf16x2_f32_rs_sf $a, $b, $c, CvtRS_RELU)>;
+}
+
def : Pat<(int_nvvm_ff2f16x2_rn f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRN)>;
def : Pat<(int_nvvm_ff2f16x2_rn_relu f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRN_RELU)>;
def : Pat<(int_nvvm_ff2f16x2_rz f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRZ)>;
def : Pat<(int_nvvm_ff2f16x2_rz_relu f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRZ_RELU)>;
+let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in {
+def : Pat<(int_nvvm_ff2f16x2_rs f32:$a, f32:$b, i32:$c),
+ (CVT_f16x2_f32_rs $a, $b, $c, CvtRS)>;
+def : Pat<(int_nvvm_ff2f16x2_rs_relu f32:$a, f32:$b, i32:$c),
+ (CVT_f16x2_f32_rs $a, $b, $c, CvtRS_RELU)>;
+def : Pat<(int_nvvm_ff2f16x2_rs_satfinite f32:$a, f32:$b, i32:$c),
+ (CVT_f16x2_f32_rs_sf $a, $b, $c, CvtRS)>;
+def : Pat<(int_nvvm_ff2f16x2_rs_relu_satfinite f32:$a, f32:$b, i32:$c),
+ (CVT_f16x2_f32_rs_sf $a, $b, $c, CvtRS_RELU)>;
+}
def : Pat<(int_nvvm_f2bf16_rn f32:$a), (CVT_bf16_f32 $a, CvtRN)>;
def : Pat<(int_nvvm_f2bf16_rn_relu f32:$a), (CVT_bf16_f32 $a, CvtRN_RELU)>;
def : Pat<(int_nvvm_f2bf16_rz f32:$a), (CVT_bf16_f32 $a, CvtRZ)>;
@@ -1929,6 +1950,52 @@ let Predicates = [hasPTX<86>, hasSM<100>, hasArchAccelFeatures] in {
(CVT_bf16x2_ue8m0x2 $a)>;
}
+def SDT_CVT_F32X4_TO_FPX4_RS_VEC :
+ SDTypeProfile<1, 6, [SDTCisVec<0>, SDTCisFP<1>, SDTCisFP<2>, SDTCisFP<3>,
+ SDTCisFP<4>, SDTCisInt<5>, SDTCisInt<6>]>;
+
+def SDT_CVT_F32X4_TO_FPX4_RS_INT :
+ SDTypeProfile<1, 6, [SDTCisInt<0>, SDTCisFP<1>, SDTCisFP<2>, SDTCisFP<3>,
+ SDTCisFP<4>, SDTCisInt<5>, SDTCisInt<6>]>;
+
+class CVT_F32X4_TO_FPX4_RS_SF_NODE<string FPName, SDTypeProfile SDT> :
+ SDNode<"NVPTXISD::CVT_" # FPName # "X4_F32X4_RS_SF", SDT, []>;
+
+multiclass CVT_F32X4_TO_FPX4_RS_SF_VEC<string FPName, VTVec RetTy> {
+ def : Pat<(RetTy (CVT_F32X4_TO_FPX4_RS_SF_NODE<!toupper(FPName),
+ SDT_CVT_F32X4_TO_FPX4_RS_VEC>
+ f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS)),
+ (!cast<NVPTXInst>("CVT_" # FPName # "x4_f32x4_rs_sf")
+ $f1, $f2, $f3, $f4, $rbits, CvtRS)>;
+
+ def : Pat<(RetTy (CVT_F32X4_TO_FPX4_RS_SF_NODE<!toupper(FPName),
+ SDT_CVT_F32X4_TO_FPX4_RS_VEC>
+ f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS_RELU)),
+ (!cast<NVPTXInst>("CVT_" # FPName # "x4_f32x4_rs_sf")
+ $f1, $f2, $f3, $f4, $rbits, CvtRS_RELU)>;
+}
+
+// RS rounding mode conversions
+let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in {
+// FP8x4 conversions
+defm : CVT_F32X4_TO_FPX4_RS_SF_VEC<"e4m3", v4i8>;
+defm : CVT_F32X4_TO_FPX4_RS_SF_VEC<"e5m2", v4i8>;
+
+// FP6x4 conversions
+defm : CVT_F32X4_TO_FPX4_RS_SF_VEC<"e2m3", v4i8>;
+defm : CVT_F32X4_TO_FPX4_RS_SF_VEC<"e3m2", v4i8>;
+
+// FP4x4 conversions
+def : Pat<(i16 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E2M1",
+ SDT_CVT_F32X4_TO_FPX4_RS_INT>
+ f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS)),
+ (CVT_e2m1x4_f32x4_rs_sf $f1, $f2, $f3, $f4, $rbits, CvtRS)>;
+def : Pat<(i16 (CVT_F32X4_TO_FPX4_RS_SF_NODE<"E2M1",
+ SDT_CVT_F32X4_TO_FPX4_RS_INT>
+ f32:$f1, f32:$f2, f32:$f3, f32:$f4, i32:$rbits, CvtRS_RELU)),
+ (CVT_e2m1x4_f32x4_rs_sf $f1, $f2, $f3, $f4, $rbits, CvtRS_RELU)>;
+}
+
//
// FNS
//
@@ -4461,6 +4528,10 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
!eq(ptx_elt_type, "e2m1"),
!ne(kind, "")) : [hasSM120a, hasPTX<87>],
+ !and(!or(!eq(ptx_elt_type,"e4m3"),
+ !eq(ptx_elt_type,"e5m2")),
+ !eq(geom, "m16n8k16")) : [hasSM<89>, hasPTX<87>],
+
!or(!eq(ptx_elt_type, "e4m3"),
!eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
@@ -4476,6 +4547,11 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
!and(!eq(geom, "m8n8k4"),
!eq(ptx_elt_type, "f64")) : [hasSM<80>, hasPTX<70>],
+ !and(!or(!eq(geom, "m16n8k4"),
+ !eq(geom, "m16n8k8"),
+ !eq(geom, "m16n8k16")),
+ !eq(ptx_elt_type, "f64")) : [hasSM<90>, hasPTX<78>],
+
// fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
!and(!or(!eq(geom, "m8n32k16"),
!eq(geom, "m32n8k16")),
@@ -4760,8 +4836,8 @@ defset list<WMMA_INSTR> WMMAs = {
// MMA
class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
- string ALayout, string BLayout, int Satfinite, string b1op>
- : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, FragA, FragB, FragC, FragD>.record,
+ string ALayout, string BLayout, int Satfinite, string b1op, string Kind>
+ : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, Kind, FragA, FragB, FragC, FragD>.record,
[FragA.Ins, FragB.Ins, FragC.Ins]>,
// Requires does not seem to have effect on Instruction w/o Patterns.
// We set it here anyways and propagate to the Pat<> we construct below.
@@ -4776,6 +4852,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# FragA.geom
# "." # ALayout
# "." # BLayout
+ # !if(!ne(Kind, ""), "." # Kind, "")
# !if(Satfinite, ".satfinite", "")
# TypeList
# b1op # "\n\t\t"
@@ -4792,13 +4869,15 @@ defset list<WMMA_INSTR> MMAs = {
foreach satf = [0, 1] in {
foreach op = NVVM_MMA_OPS.all_mma_ops in {
foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
- if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
- def : MMA<WMMA_REGINFO<op[0], "mma">,
- WMMA_REGINFO<op[1], "mma">,
- WMMA_REGINFO<op[2], "mma">,
- WMMA_REGINFO<op[3], "mma">,
- layout_a, layout_b, satf, b1op>;
- }
+ foreach kind = ["", "kind::f8f6f4"] in {
+ if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then {
+ def : MMA<WMMA_REGINFO<op[0], "mma", "", kind>,
+ WMMA_REGINFO<op[1], "mma", "", kind>,
+ WMMA_REGINFO<op[2], "mma", "", kind>,
+ WMMA_REGINFO<op[3], "mma", "", kind>,
+ layout_a, layout_b, satf, b1op, kind>;
+ }
+ } // kind
} // b1op
} // op
} // satf
diff --git a/llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp b/llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp
index 1fc475d..561a9c5 100644
--- a/llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp
+++ b/llvm/lib/Target/PowerPC/AsmParser/PPCAsmParser.cpp
@@ -349,32 +349,30 @@ public:
bool isImm() const override {
return Kind == Immediate || Kind == Expression;
}
- bool isU1Imm() const { return Kind == Immediate && isUInt<1>(getImm()); }
- bool isU2Imm() const { return Kind == Immediate && isUInt<2>(getImm()); }
- bool isU3Imm() const { return Kind == Immediate && isUInt<3>(getImm()); }
- bool isU4Imm() const { return Kind == Immediate && isUInt<4>(getImm()); }
- bool isU5Imm() const { return Kind == Immediate && isUInt<5>(getImm()); }
- bool isS5Imm() const { return Kind == Immediate && isInt<5>(getImm()); }
- bool isU6Imm() const { return Kind == Immediate && isUInt<6>(getImm()); }
- bool isU6ImmX2() const { return Kind == Immediate &&
- isUInt<6>(getImm()) &&
- (getImm() & 1) == 0; }
- bool isU7Imm() const { return Kind == Immediate && isUInt<7>(getImm()); }
- bool isU7ImmX4() const { return Kind == Immediate &&
- isUInt<7>(getImm()) &&
- (getImm() & 3) == 0; }
- bool isU8Imm() const { return Kind == Immediate && isUInt<8>(getImm()); }
- bool isU8ImmX8() const { return Kind == Immediate &&
- isUInt<8>(getImm()) &&
- (getImm() & 7) == 0; }
-
- bool isU10Imm() const { return Kind == Immediate && isUInt<10>(getImm()); }
- bool isU12Imm() const { return Kind == Immediate && isUInt<12>(getImm()); }
+
+ template <uint64_t N> bool isUImm() const {
+ return Kind == Immediate && isUInt<N>(getImm());
+ }
+ template <uint64_t N> bool isSImm() const {
+ return Kind == Immediate && isInt<N>(getImm());
+ }
+ bool isU6ImmX2() const { return isUImm<6>() && (getImm() & 1) == 0; }
+ bool isU7ImmX4() const { return isUImm<7>() && (getImm() & 3) == 0; }
+ bool isU8ImmX8() const { return isUImm<8>() && (getImm() & 7) == 0; }
+
bool isU16Imm() const { return isExtImm<16>(/*Signed*/ false, 1); }
bool isS16Imm() const { return isExtImm<16>(/*Signed*/ true, 1); }
bool isS16ImmX4() const { return isExtImm<16>(/*Signed*/ true, 4); }
bool isS16ImmX16() const { return isExtImm<16>(/*Signed*/ true, 16); }
bool isS17Imm() const { return isExtImm<17>(/*Signed*/ true, 1); }
+ bool isS34Imm() const {
+ // Once the PC-Rel ABI is finalized, evaluate whether a 34-bit
+ // ContextImmediate is needed.
+ return Kind == Expression || isSImm<34>();
+ }
+ bool isS34ImmX16() const {
+ return Kind == Expression || (isSImm<34>() && (getImm() & 15) == 0);
+ }
bool isHashImmX8() const {
// The Hash Imm form is used for instructions that check or store a hash.
@@ -384,16 +382,6 @@ public:
(getImm() & 7) == 0);
}
- bool isS34ImmX16() const {
- return Kind == Expression ||
- (Kind == Immediate && isInt<34>(getImm()) && (getImm() & 15) == 0);
- }
- bool isS34Imm() const {
- // Once the PC-Rel ABI is finalized, evaluate whether a 34-bit
- // ContextImmediate is needed.
- return Kind == Expression || (Kind == Immediate && isInt<34>(getImm()));
- }
-
bool isTLSReg() const { return Kind == TLSRegister; }
bool isDirectBr() const {
if (Kind == Expression)
@@ -1637,7 +1625,7 @@ bool PPCAsmParser::parseInstruction(ParseInstructionInfo &Info, StringRef Name,
if (Operands.size() != 5)
return false;
PPCOperand &EHOp = (PPCOperand &)*Operands[4];
- if (EHOp.isU1Imm() && EHOp.getImm() == 0)
+ if (EHOp.isUImm<1>() && EHOp.getImm() == 0)
Operands.pop_back();
}
@@ -1817,7 +1805,7 @@ unsigned PPCAsmParser::validateTargetOperandClass(MCParsedAsmOperand &AsmOp,
}
PPCOperand &Op = static_cast<PPCOperand &>(AsmOp);
- if (Op.isU3Imm() && Op.getImm() == ImmVal)
+ if (Op.isUImm<3>() && Op.getImm() == ImmVal)
return Match_Success;
return Match_InvalidOperand;
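A minimal standalone sketch of the templated predicate pattern introduced above; the helpers stand in for llvm::isUInt/isInt and the PPCOperand class, which are not reproduced here.

#include <cassert>
#include <cstdint>

// Simplified stand-ins for llvm::isUInt<N>/isInt<N>.
template <unsigned N> bool isUIntN(int64_t V) {
  return N >= 64 || (V >= 0 && V < (int64_t(1) << N));
}
template <unsigned N> bool isIntN(int64_t V) {
  return N >= 64 ||
         (V >= -(int64_t(1) << (N - 1)) && V < (int64_t(1) << (N - 1)));
}

struct Operand {
  bool IsImmediate = true;
  int64_t Imm = 0;
  // One template replaces the per-width isU1Imm()/isU2Imm()/... helpers.
  template <unsigned N> bool isUImm() const {
    return IsImmediate && isUIntN<N>(Imm);
  }
  template <unsigned N> bool isSImm() const {
    return IsImmediate && isIntN<N>(Imm);
  }
};

int main() {
  Operand Op{true, 13};
  assert(Op.isUImm<4>());   // 13 fits in 4 unsigned bits
  assert(!Op.isUImm<3>());  // but not in 3
  assert(Op.isSImm<5>());   // and is a valid signed 5-bit value
}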
diff --git a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp
index 48c31c9..81d8e94 100644
--- a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp
+++ b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.cpp
@@ -206,45 +206,24 @@ PPCMCCodeEmitter::getVSRpEvenEncoding(const MCInst &MI, unsigned OpNo,
return RegBits;
}
-unsigned PPCMCCodeEmitter::getImm16Encoding(const MCInst &MI, unsigned OpNo,
- SmallVectorImpl<MCFixup> &Fixups,
- const MCSubtargetInfo &STI) const {
- const MCOperand &MO = MI.getOperand(OpNo);
- if (MO.isReg() || MO.isImm()) return getMachineOpValue(MI, MO, Fixups, STI);
-
- // Add a fixup for the immediate field.
- addFixup(Fixups, IsLittleEndian ? 0 : 2, MO.getExpr(), PPC::fixup_ppc_half16);
- return 0;
-}
-
-uint64_t PPCMCCodeEmitter::getImm34Encoding(const MCInst &MI, unsigned OpNo,
- SmallVectorImpl<MCFixup> &Fixups,
- const MCSubtargetInfo &STI,
- MCFixupKind Fixup) const {
+template <MCFixupKind Fixup>
+uint64_t PPCMCCodeEmitter::getImmEncoding(const MCInst &MI, unsigned OpNo,
+ SmallVectorImpl<MCFixup> &Fixups,
+ const MCSubtargetInfo &STI) const {
const MCOperand &MO = MI.getOperand(OpNo);
assert(!MO.isReg() && "Not expecting a register for this operand.");
if (MO.isImm())
return getMachineOpValue(MI, MO, Fixups, STI);
+ uint32_t Offset = 0;
+ if (Fixup == PPC::fixup_ppc_half16)
+ Offset = IsLittleEndian ? 0 : 2;
+
// Add a fixup for the immediate field.
- addFixup(Fixups, 0, MO.getExpr(), Fixup);
+ addFixup(Fixups, Offset, MO.getExpr(), Fixup);
return 0;
}
-uint64_t
-PPCMCCodeEmitter::getImm34EncodingNoPCRel(const MCInst &MI, unsigned OpNo,
- SmallVectorImpl<MCFixup> &Fixups,
- const MCSubtargetInfo &STI) const {
- return getImm34Encoding(MI, OpNo, Fixups, STI, PPC::fixup_ppc_imm34);
-}
-
-uint64_t
-PPCMCCodeEmitter::getImm34EncodingPCRel(const MCInst &MI, unsigned OpNo,
- SmallVectorImpl<MCFixup> &Fixups,
- const MCSubtargetInfo &STI) const {
- return getImm34Encoding(MI, OpNo, Fixups, STI, PPC::fixup_ppc_pcrel34);
-}
-
unsigned PPCMCCodeEmitter::getDispRIEncoding(const MCInst &MI, unsigned OpNo,
SmallVectorImpl<MCFixup> &Fixups,
const MCSubtargetInfo &STI) const {
diff --git a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h
index b574557..3356513 100644
--- a/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h
+++ b/llvm/lib/Target/PowerPC/MCTargetDesc/PPCMCCodeEmitter.h
@@ -47,19 +47,10 @@ public:
unsigned getAbsCondBrEncoding(const MCInst &MI, unsigned OpNo,
SmallVectorImpl<MCFixup> &Fixups,
const MCSubtargetInfo &STI) const;
- unsigned getImm16Encoding(const MCInst &MI, unsigned OpNo,
- SmallVectorImpl<MCFixup> &Fixups,
- const MCSubtargetInfo &STI) const;
- uint64_t getImm34Encoding(const MCInst &MI, unsigned OpNo,
- SmallVectorImpl<MCFixup> &Fixups,
- const MCSubtargetInfo &STI,
- MCFixupKind Fixup) const;
- uint64_t getImm34EncodingNoPCRel(const MCInst &MI, unsigned OpNo,
- SmallVectorImpl<MCFixup> &Fixups,
- const MCSubtargetInfo &STI) const;
- uint64_t getImm34EncodingPCRel(const MCInst &MI, unsigned OpNo,
- SmallVectorImpl<MCFixup> &Fixups,
- const MCSubtargetInfo &STI) const;
+ template <MCFixupKind Fixup>
+ uint64_t getImmEncoding(const MCInst &MI, unsigned OpNo,
+ SmallVectorImpl<MCFixup> &Fixups,
+ const MCSubtargetInfo &STI) const;
unsigned getDispRIEncoding(const MCInst &MI, unsigned OpNo,
SmallVectorImpl<MCFixup> &Fixups,
const MCSubtargetInfo &STI) const;
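The three encoder helpers collapse into one function templated over the fixup kind; only the half16 fixup needs a byte offset into the instruction word. A compile-time sketch of that dispatch, with a placeholder enum instead of the real PPC fixup kinds:

#include <cstdio>

enum FixupKind { fixup_half16, fixup_imm34, fixup_pcrel34 };

template <FixupKind Fixup>
unsigned fixupOffset(bool IsLittleEndian) {
  // Only the 16-bit fixup targets half of a 4-byte word, so only it needs an
  // endianness-dependent offset; the 34-bit fixups start at offset 0.
  if constexpr (Fixup == fixup_half16)
    return IsLittleEndian ? 0 : 2;
  return 0;
}

int main() {
  std::printf("%u %u %u\n",
              fixupOffset<fixup_half16>(/*IsLittleEndian=*/false), // 2
              fixupOffset<fixup_imm34>(false),                     // 0
              fixupOffset<fixup_pcrel34>(false));                  // 0
}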
diff --git a/llvm/lib/Target/PowerPC/PPCInstr64Bit.td b/llvm/lib/Target/PowerPC/PPCInstr64Bit.td
index 60efa4c..fdca5ebc 100644
--- a/llvm/lib/Target/PowerPC/PPCInstr64Bit.td
+++ b/llvm/lib/Target/PowerPC/PPCInstr64Bit.td
@@ -14,30 +14,6 @@
//===----------------------------------------------------------------------===//
// 64-bit operands.
//
-def s16imm64 : Operand<i64> {
- let PrintMethod = "printS16ImmOperand";
- let EncoderMethod = "getImm16Encoding";
- let ParserMatchClass = PPCS16ImmAsmOperand;
- let DecoderMethod = "decodeSImmOperand<16>";
- let OperandType = "OPERAND_IMMEDIATE";
-}
-def u16imm64 : Operand<i64> {
- let PrintMethod = "printU16ImmOperand";
- let EncoderMethod = "getImm16Encoding";
- let ParserMatchClass = PPCU16ImmAsmOperand;
- let DecoderMethod = "decodeUImmOperand<16>";
- let OperandType = "OPERAND_IMMEDIATE";
-}
-def s17imm64 : Operand<i64> {
- // This operand type is used for addis/lis to allow the assembler parser
- // to accept immediates in the range -65536..65535 for compatibility with
- // the GNU assembler. The operand is treated as 16-bit otherwise.
- let PrintMethod = "printS16ImmOperand";
- let EncoderMethod = "getImm16Encoding";
- let ParserMatchClass = PPCS17ImmAsmOperand;
- let DecoderMethod = "decodeSImmOperand<16>";
- let OperandType = "OPERAND_IMMEDIATE";
-}
def tocentry : Operand<iPTR> {
let MIOperandInfo = (ops i64imm:$imm);
}
diff --git a/llvm/lib/Target/PowerPC/PPCInstrAltivec.td b/llvm/lib/Target/PowerPC/PPCInstrAltivec.td
index c616db4..23d6d88 100644
--- a/llvm/lib/Target/PowerPC/PPCInstrAltivec.td
+++ b/llvm/lib/Target/PowerPC/PPCInstrAltivec.td
@@ -30,6 +30,11 @@
// Altivec transformation functions and pattern fragments.
//
+// fneg is not legal, and is desugared as an xor.
+def desugared_fneg : PatFrag<(ops node:$x), (v4f32 (bitconvert (xor (bitconvert $x),
+ (int_ppc_altivec_vslw (bitconvert (v16i8 immAllOnesV)),
+ (bitconvert (v16i8 immAllOnesV))))))>;
+
def vpkuhum_shuffle : PatFrag<(ops node:$lhs, node:$rhs),
(vector_shuffle node:$lhs, node:$rhs), [{
return PPC::isVPKUHUMShuffleMask(cast<ShuffleVectorSDNode>(N), 0, *CurDAG);
@@ -467,11 +472,12 @@ def VMADDFP : VAForm_1<46, (outs vrrc:$RT), (ins vrrc:$RA, vrrc:$RC, vrrc:$RB),
[(set v4f32:$RT,
(fma v4f32:$RA, v4f32:$RC, v4f32:$RB))]>;
-// FIXME: The fma+fneg pattern won't match because fneg is not legal.
+// fneg is not legal, hence we have to match on the desugared version.
def VNMSUBFP: VAForm_1<47, (outs vrrc:$RT), (ins vrrc:$RA, vrrc:$RC, vrrc:$RB),
"vnmsubfp $RT, $RA, $RC, $RB", IIC_VecFP,
- [(set v4f32:$RT, (fneg (fma v4f32:$RA, v4f32:$RC,
- (fneg v4f32:$RB))))]>;
+ [(set v4f32:$RT, (desugared_fneg (fma v4f32:$RA, v4f32:$RC,
+ (desugared_fneg v4f32:$RB))))]>;
+
let hasSideEffects = 1 in {
def VMHADDSHS : VA1a_Int_Ty<32, "vmhaddshs", int_ppc_altivec_vmhaddshs, v8i16>;
def VMHRADDSHS : VA1a_Int_Ty<33, "vmhraddshs", int_ppc_altivec_vmhraddshs,
@@ -892,6 +898,13 @@ def : Pat<(mul v8i16:$vA, v8i16:$vB), (VMLADDUHM $vA, $vB, (v8i16(V_SET0H)))>;
// Add
def : Pat<(add (mul v8i16:$vA, v8i16:$vB), v8i16:$vC), (VMLADDUHM $vA, $vB, $vC)>;
+
+// Fused negated multiply-subtract
+def : Pat<(v4f32 (desugared_fneg
+ (int_ppc_altivec_vmaddfp v4f32:$RA, v4f32:$RC,
+ (desugared_fneg v4f32:$RB)))),
+ (VNMSUBFP $RA, $RC, $RB)>;
+
// Saturating adds/subtracts.
def : Pat<(v16i8 (saddsat v16i8:$vA, v16i8:$vB)), (v16i8 (VADDSBS $vA, $vB))>;
def : Pat<(v16i8 (uaddsat v16i8:$vA, v16i8:$vB)), (v16i8 (VADDUBS $vA, $vB))>;
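The desugared_fneg fragment matches the xor-with-sign-mask form that a v4f32 fneg lowers to: vslw of an all-ones vector by an all-ones shift amount (the low 5 bits, i.e. 31) produces 0x80000000 in every lane. A scalar sketch of the same identity:

#include <cassert>
#include <cstdint>
#include <cstring>

// Negate a float by flipping its sign bit, which is what xor-ing with the
// vslw(all-ones, all-ones) mask achieves per lane on Altivec.
float fnegViaXor(float X) {
  uint32_t Bits;
  std::memcpy(&Bits, &X, sizeof(Bits));
  Bits ^= 0x80000000u; // all-ones shifted left by 31 == sign-bit mask
  float R;
  std::memcpy(&R, &Bits, sizeof(R));
  return R;
}

int main() {
  assert(fnegViaXor(1.5f) == -1.5f);
  assert(fnegViaXor(-2.0f) == 2.0f);
}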
diff --git a/llvm/lib/Target/PowerPC/PPCRegisterInfo.td b/llvm/lib/Target/PowerPC/PPCRegisterInfo.td
index 6d8c122..65d0484 100644
--- a/llvm/lib/Target/PowerPC/PPCRegisterInfo.td
+++ b/llvm/lib/Target/PowerPC/PPCRegisterInfo.td
@@ -615,7 +615,8 @@ def spe4rc : RegisterOperand<GPRC> {
}
def PPCU1ImmAsmOperand : AsmOperandClass {
- let Name = "U1Imm"; let PredicateMethod = "isU1Imm";
+ let Name = "U1Imm";
+ let PredicateMethod = "isUImm<1>";
let RenderMethod = "addImmOperands";
}
def u1imm : Operand<i32> {
@@ -626,7 +627,8 @@ def u1imm : Operand<i32> {
}
def PPCU2ImmAsmOperand : AsmOperandClass {
- let Name = "U2Imm"; let PredicateMethod = "isU2Imm";
+ let Name = "U2Imm";
+ let PredicateMethod = "isUImm<2>";
let RenderMethod = "addImmOperands";
}
def u2imm : Operand<i32> {
@@ -647,7 +649,8 @@ def atimm : Operand<i32> {
}
def PPCU3ImmAsmOperand : AsmOperandClass {
- let Name = "U3Imm"; let PredicateMethod = "isU3Imm";
+ let Name = "U3Imm";
+ let PredicateMethod = "isUImm<3>";
let RenderMethod = "addImmOperands";
}
def u3imm : Operand<i32> {
@@ -658,7 +661,8 @@ def u3imm : Operand<i32> {
}
def PPCU4ImmAsmOperand : AsmOperandClass {
- let Name = "U4Imm"; let PredicateMethod = "isU4Imm";
+ let Name = "U4Imm";
+ let PredicateMethod = "isUImm<4>";
let RenderMethod = "addImmOperands";
}
def u4imm : Operand<i32> {
@@ -668,7 +672,8 @@ def u4imm : Operand<i32> {
let OperandType = "OPERAND_IMMEDIATE";
}
def PPCS5ImmAsmOperand : AsmOperandClass {
- let Name = "S5Imm"; let PredicateMethod = "isS5Imm";
+ let Name = "S5Imm";
+ let PredicateMethod = "isSImm<5>";
let RenderMethod = "addImmOperands";
}
def s5imm : Operand<i32> {
@@ -678,7 +683,8 @@ def s5imm : Operand<i32> {
let OperandType = "OPERAND_IMMEDIATE";
}
def PPCU5ImmAsmOperand : AsmOperandClass {
- let Name = "U5Imm"; let PredicateMethod = "isU5Imm";
+ let Name = "U5Imm";
+ let PredicateMethod = "isUImm<5>";
let RenderMethod = "addImmOperands";
}
def u5imm : Operand<i32> {
@@ -688,7 +694,8 @@ def u5imm : Operand<i32> {
let OperandType = "OPERAND_IMMEDIATE";
}
def PPCU6ImmAsmOperand : AsmOperandClass {
- let Name = "U6Imm"; let PredicateMethod = "isU6Imm";
+ let Name = "U6Imm";
+ let PredicateMethod = "isUImm<6>";
let RenderMethod = "addImmOperands";
}
def u6imm : Operand<i32> {
@@ -698,7 +705,8 @@ def u6imm : Operand<i32> {
let OperandType = "OPERAND_IMMEDIATE";
}
def PPCU7ImmAsmOperand : AsmOperandClass {
- let Name = "U7Imm"; let PredicateMethod = "isU7Imm";
+ let Name = "U7Imm";
+ let PredicateMethod = "isUImm<7>";
let RenderMethod = "addImmOperands";
}
def u7imm : Operand<i32> {
@@ -708,7 +716,8 @@ def u7imm : Operand<i32> {
let OperandType = "OPERAND_IMMEDIATE";
}
def PPCU8ImmAsmOperand : AsmOperandClass {
- let Name = "U8Imm"; let PredicateMethod = "isU8Imm";
+ let Name = "U8Imm";
+ let PredicateMethod = "isUImm<8>";
let RenderMethod = "addImmOperands";
}
def u8imm : Operand<i32> {
@@ -718,7 +727,8 @@ def u8imm : Operand<i32> {
let OperandType = "OPERAND_IMMEDIATE";
}
def PPCU10ImmAsmOperand : AsmOperandClass {
- let Name = "U10Imm"; let PredicateMethod = "isU10Imm";
+ let Name = "U10Imm";
+ let PredicateMethod = "isUImm<10>";
let RenderMethod = "addImmOperands";
}
def u10imm : Operand<i32> {
@@ -728,7 +738,8 @@ def u10imm : Operand<i32> {
let OperandType = "OPERAND_IMMEDIATE";
}
def PPCU12ImmAsmOperand : AsmOperandClass {
- let Name = "U12Imm"; let PredicateMethod = "isU12Imm";
+ let Name = "U12Imm";
+ let PredicateMethod = "isUImm<12>";
let RenderMethod = "addImmOperands";
}
def u12imm : Operand<i32> {
@@ -743,7 +754,14 @@ def PPCS16ImmAsmOperand : AsmOperandClass {
}
def s16imm : Operand<i32> {
let PrintMethod = "printS16ImmOperand";
- let EncoderMethod = "getImm16Encoding";
+ let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>";
+ let ParserMatchClass = PPCS16ImmAsmOperand;
+ let DecoderMethod = "decodeSImmOperand<16>";
+ let OperandType = "OPERAND_IMMEDIATE";
+}
+def s16imm64 : Operand<i64> {
+ let PrintMethod = "printS16ImmOperand";
+ let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>";
let ParserMatchClass = PPCS16ImmAsmOperand;
let DecoderMethod = "decodeSImmOperand<16>";
let OperandType = "OPERAND_IMMEDIATE";
@@ -754,7 +772,14 @@ def PPCU16ImmAsmOperand : AsmOperandClass {
}
def u16imm : Operand<i32> {
let PrintMethod = "printU16ImmOperand";
- let EncoderMethod = "getImm16Encoding";
+ let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>";
+ let ParserMatchClass = PPCU16ImmAsmOperand;
+ let DecoderMethod = "decodeUImmOperand<16>";
+ let OperandType = "OPERAND_IMMEDIATE";
+}
+def u16imm64 : Operand<i64> {
+ let PrintMethod = "printU16ImmOperand";
+ let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>";
let ParserMatchClass = PPCU16ImmAsmOperand;
let DecoderMethod = "decodeUImmOperand<16>";
let OperandType = "OPERAND_IMMEDIATE";
@@ -768,7 +793,17 @@ def s17imm : Operand<i32> {
// to accept immediates in the range -65536..65535 for compatibility with
// the GNU assembler. The operand is treated as 16-bit otherwise.
let PrintMethod = "printS16ImmOperand";
- let EncoderMethod = "getImm16Encoding";
+ let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>";
+ let ParserMatchClass = PPCS17ImmAsmOperand;
+ let DecoderMethod = "decodeSImmOperand<16>";
+ let OperandType = "OPERAND_IMMEDIATE";
+}
+def s17imm64 : Operand<i64> {
+ // This operand type is used for addis/lis to allow the assembler parser
+ // to accept immediates in the range -65536..65535 for compatibility with
+ // the GNU assembler. The operand is treated as 16-bit otherwise.
+ let PrintMethod = "printS16ImmOperand";
+ let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_half16>";
let ParserMatchClass = PPCS17ImmAsmOperand;
let DecoderMethod = "decodeSImmOperand<16>";
let OperandType = "OPERAND_IMMEDIATE";
@@ -780,14 +815,14 @@ def PPCS34ImmAsmOperand : AsmOperandClass {
}
def s34imm : Operand<i64> {
let PrintMethod = "printS34ImmOperand";
- let EncoderMethod = "getImm34EncodingNoPCRel";
+ let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_imm34>";
let ParserMatchClass = PPCS34ImmAsmOperand;
let DecoderMethod = "decodeSImmOperand<34>";
let OperandType = "OPERAND_IMMEDIATE";
}
def s34imm_pcrel : Operand<i64> {
let PrintMethod = "printS34ImmOperand";
- let EncoderMethod = "getImm34EncodingPCRel";
+ let EncoderMethod = "getImmEncoding<PPC::fixup_ppc_pcrel34>";
let ParserMatchClass = PPCS34ImmAsmOperand;
let DecoderMethod = "decodeSImmOperand<34>";
let OperandType = "OPERAND_IMMEDIATE";
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
index 597dd12..9f9ae2f 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
@@ -324,6 +324,10 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
OpdsMapping[0] = GPRValueMapping;
+ // Atomics always use GPR destinations. Don't refine any further.
+ if (cast<GLoad>(MI).isAtomic())
+ break;
+
// Use FPR64 for s64 loads on rv32.
if (GPRSize == 32 && Size.getFixedValue() == 64) {
assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
@@ -358,6 +362,10 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
OpdsMapping[0] = GPRValueMapping;
+ // Atomics always use GPR sources. Don't refine any further.
+ if (cast<GStore>(MI).isAtomic())
+ break;
+
// Use FPR64 for s64 stores on rv32.
if (GPRSize == 32 && Size.getFixedValue() == 64) {
assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td
index a02de31..27cf057 100644
--- a/llvm/lib/Target/RISCV/RISCVFeatures.td
+++ b/llvm/lib/Target/RISCV/RISCVFeatures.td
@@ -1421,7 +1421,7 @@ def HasVendorXMIPSCMov
: Predicate<"Subtarget->hasVendorXMIPSCMov()">,
AssemblerPredicate<(all_of FeatureVendorXMIPSCMov),
"'Xmipscmov' ('mips.ccmov' instruction)">;
-def UseCCMovInsn : Predicate<"Subtarget->useCCMovInsn()">;
+def UseMIPSCCMovInsn : Predicate<"Subtarget->useMIPSCCMovInsn()">;
def FeatureVendorXMIPSLSP
: RISCVExtension<1, 0, "MIPS optimization for hardware load-store bonding">;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index dcce2d2..b624076 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -434,7 +434,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::ABS, MVT::i32, Custom);
}
- if (!Subtarget.useCCMovInsn() && !Subtarget.hasVendorXTHeadCondMov())
+ if (!Subtarget.useMIPSCCMovInsn() && !Subtarget.hasVendorXTHeadCondMov())
setOperationAction(ISD::SELECT, XLenVT, Custom);
if (Subtarget.hasVendorXqcia() && !Subtarget.is64Bit()) {
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td
index 115ab38e..0b5bee1 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXMips.td
@@ -175,7 +175,7 @@ def MIPS_CCMOV : RVInstR4<0b11, 0b011, OPC_CUSTOM_0, (outs GPR:$rd),
Sched<[]>;
}
-let Predicates = [UseCCMovInsn] in {
+let Predicates = [UseMIPSCCMovInsn] in {
def : Pat<(select (riscv_setne (XLenVT GPR:$rs2)),
(XLenVT GPR:$rs1), (XLenVT GPR:$rs3)),
(MIPS_CCMOV GPR:$rs1, GPR:$rs2, GPR:$rs3)>;
diff --git a/llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp
index c81a20b..115a96e 100644
--- a/llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp
+++ b/llvm/lib/Target/RISCV/RISCVLoadStoreOptimizer.cpp
@@ -92,7 +92,7 @@ bool RISCVLoadStoreOpt::runOnMachineFunction(MachineFunction &Fn) {
if (skipFunction(Fn.getFunction()))
return false;
const RISCVSubtarget &Subtarget = Fn.getSubtarget<RISCVSubtarget>();
- if (!Subtarget.useLoadStorePairs())
+ if (!Subtarget.useMIPSLoadStorePairs())
return false;
bool MadeChange = false;
diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
index e35ffaf..715ac4c 100644
--- a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
+++ b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
@@ -65,9 +65,9 @@ static cl::opt<bool> UseMIPSLoadStorePairsOpt(
cl::desc("Enable the load/store pair optimization pass"), cl::init(false),
cl::Hidden);
-static cl::opt<bool> UseCCMovInsn("use-riscv-ccmov",
- cl::desc("Use 'mips.ccmov' instruction"),
- cl::init(true), cl::Hidden);
+static cl::opt<bool> UseMIPSCCMovInsn("use-riscv-mips-ccmov",
+ cl::desc("Use 'mips.ccmov' instruction"),
+ cl::init(true), cl::Hidden);
void RISCVSubtarget::anchor() {}
@@ -246,10 +246,10 @@ void RISCVSubtarget::overridePostRASchedPolicy(
}
}
-bool RISCVSubtarget::useLoadStorePairs() const {
+bool RISCVSubtarget::useMIPSLoadStorePairs() const {
return UseMIPSLoadStorePairsOpt && HasVendorXMIPSLSP;
}
-bool RISCVSubtarget::useCCMovInsn() const {
- return UseCCMovInsn && HasVendorXMIPSCMov;
+bool RISCVSubtarget::useMIPSCCMovInsn() const {
+ return UseMIPSCCMovInsn && HasVendorXMIPSCMov;
}
diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.h b/llvm/lib/Target/RISCV/RISCVSubtarget.h
index 7dffa63..6acf799 100644
--- a/llvm/lib/Target/RISCV/RISCVSubtarget.h
+++ b/llvm/lib/Target/RISCV/RISCVSubtarget.h
@@ -227,8 +227,8 @@ public:
unsigned getXLen() const {
return is64Bit() ? 64 : 32;
}
- bool useLoadStorePairs() const;
- bool useCCMovInsn() const;
+ bool useMIPSLoadStorePairs() const;
+ bool useMIPSCCMovInsn() const;
unsigned getFLen() const {
if (HasStdExtD)
return 64;
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 9f2e075..e16c8f0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -2811,9 +2811,7 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
GetElementPtrInst *NewGEP = simplifyZeroLengthArrayGepInst(Ref);
if (NewGEP) {
Ref->replaceAllUsesWith(NewGEP);
- if (isInstructionTriviallyDead(Ref))
- DeadInsts.insert(Ref);
-
+ DeadInsts.insert(Ref);
Ref = NewGEP;
}
if (Type *GepTy = getGEPType(Ref))
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 0afec42..989950f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -307,6 +307,10 @@ private:
bool selectHandleFromBinding(Register &ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectCounterHandleFromBinding(Register &ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectReadImageIntrinsic(Register &ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
bool selectImageWriteIntrinsic(MachineInstr &I) const;
@@ -314,6 +318,8 @@ private:
MachineInstr &I) const;
bool selectModf(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectUpdateCounter(Register &ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
bool selectFrexp(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
// Utilities
@@ -3443,6 +3449,10 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
case Intrinsic::spv_resource_handlefrombinding: {
return selectHandleFromBinding(ResVReg, ResType, I);
}
+ case Intrinsic::spv_resource_counterhandlefrombinding:
+ return selectCounterHandleFromBinding(ResVReg, ResType, I);
+ case Intrinsic::spv_resource_updatecounter:
+ return selectUpdateCounter(ResVReg, ResType, I);
case Intrinsic::spv_resource_store_typedbuffer: {
return selectImageWriteIntrinsic(I);
}
@@ -3478,6 +3488,130 @@ bool SPIRVInstructionSelector::selectHandleFromBinding(Register &ResVReg,
*cast<GIntrinsic>(&I), I);
}
+bool SPIRVInstructionSelector::selectCounterHandleFromBinding(
+ Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
+ auto &Intr = cast<GIntrinsic>(I);
+ assert(Intr.getIntrinsicID() ==
+ Intrinsic::spv_resource_counterhandlefrombinding);
+
+ // Extract information from the intrinsic call.
+ Register MainHandleReg = Intr.getOperand(2).getReg();
+ auto *MainHandleDef = cast<GIntrinsic>(getVRegDef(*MRI, MainHandleReg));
+ assert(MainHandleDef->getIntrinsicID() ==
+ Intrinsic::spv_resource_handlefrombinding);
+
+ uint32_t Set = getIConstVal(Intr.getOperand(4).getReg(), MRI);
+ uint32_t Binding = getIConstVal(Intr.getOperand(3).getReg(), MRI);
+ uint32_t ArraySize = getIConstVal(MainHandleDef->getOperand(4).getReg(), MRI);
+ Register IndexReg = MainHandleDef->getOperand(5).getReg();
+ const bool IsNonUniform = false;
+ std::string CounterName =
+ getStringValueFromReg(MainHandleDef->getOperand(6).getReg(), *MRI) +
+ ".counter";
+
+ // Create the counter variable.
+ MachineIRBuilder MIRBuilder(I);
+ Register CounterVarReg = buildPointerToResource(
+ GR.getPointeeType(ResType), GR.getPointerStorageClass(ResType), Set,
+ Binding, ArraySize, IndexReg, IsNonUniform, CounterName, MIRBuilder);
+
+ return BuildCOPY(ResVReg, CounterVarReg, I);
+}
+
+bool SPIRVInstructionSelector::selectUpdateCounter(Register &ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+ auto &Intr = cast<GIntrinsic>(I);
+ assert(Intr.getIntrinsicID() == Intrinsic::spv_resource_updatecounter);
+
+ Register CounterHandleReg = Intr.getOperand(2).getReg();
+ Register IncrReg = Intr.getOperand(3).getReg();
+
+ // The counter handle is a pointer to the counter variable (which is a struct
+ // containing an i32). We need to get a pointer to that i32 member to do the
+ // atomic operation.
+#ifndef NDEBUG
+ SPIRVType *CounterVarType = GR.getSPIRVTypeForVReg(CounterHandleReg);
+ SPIRVType *CounterVarPointeeType = GR.getPointeeType(CounterVarType);
+ assert(CounterVarPointeeType &&
+ CounterVarPointeeType->getOpcode() == SPIRV::OpTypeStruct &&
+ "Counter variable must be a struct");
+ assert(GR.getPointerStorageClass(CounterVarType) ==
+ SPIRV::StorageClass::StorageBuffer &&
+ "Counter variable must be in the storage buffer storage class");
+ assert(CounterVarPointeeType->getNumOperands() == 2 &&
+ "Counter variable must have exactly 1 member in the struct");
+ const SPIRVType *MemberType =
+ GR.getSPIRVTypeForVReg(CounterVarPointeeType->getOperand(1).getReg());
+ assert(MemberType->getOpcode() == SPIRV::OpTypeInt &&
+ "Counter variable struct must have a single i32 member");
+#endif
+
+ // The struct has a single i32 member.
+ MachineIRBuilder MIRBuilder(I);
+ const Type *LLVMIntType =
+ Type::getInt32Ty(I.getMF()->getFunction().getContext());
+
+ SPIRVType *IntPtrType = GR.getOrCreateSPIRVPointerType(
+ LLVMIntType, MIRBuilder, SPIRV::StorageClass::StorageBuffer);
+
+ auto Zero = buildI32Constant(0, I);
+ if (!Zero.second)
+ return false;
+
+ Register PtrToCounter =
+ MRI->createVirtualRegister(GR.getRegClass(IntPtrType));
+ if (!BuildMI(*I.getParent(), I, I.getDebugLoc(),
+ TII.get(SPIRV::OpAccessChain))
+ .addDef(PtrToCounter)
+ .addUse(GR.getSPIRVTypeID(IntPtrType))
+ .addUse(CounterHandleReg)
+ .addUse(Zero.first)
+ .constrainAllUses(TII, TRI, RBI)) {
+ return false;
+ }
+
+ // For UAV/SSBO counters, the scope is Device. The counter variable is not
+ // used as a flag. So the memory semantics can be None.
+ auto Scope = buildI32Constant(SPIRV::Scope::Device, I);
+ if (!Scope.second)
+ return false;
+ auto Semantics = buildI32Constant(SPIRV::MemorySemantics::None, I);
+ if (!Semantics.second)
+ return false;
+
+ int64_t IncrVal = getIConstValSext(IncrReg, MRI);
+ auto Incr = buildI32Constant(static_cast<uint32_t>(IncrVal), I);
+ if (!Incr.second)
+ return false;
+
+ Register AtomicRes = MRI->createVirtualRegister(GR.getRegClass(ResType));
+ if (!BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpAtomicIAdd))
+ .addDef(AtomicRes)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(PtrToCounter)
+ .addUse(Scope.first)
+ .addUse(Semantics.first)
+ .addUse(Incr.first)
+ .constrainAllUses(TII, TRI, RBI)) {
+ return false;
+ }
+ if (IncrVal >= 0) {
+ return BuildCOPY(ResVReg, AtomicRes, I);
+ }
+
+  // In HLSL, IncrementCounter returns the value *before* the increment, while
+  // DecrementCounter returns the value *after* the decrement. Both are lowered
+  // to the same atomic add, which returns the value *before* the operation.
+  // So for decrements (negative IncrVal), we add the (negative) increment to
+  // the atomic result to get the post-decrement value.
+ return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(AtomicRes)
+ .addUse(Incr.first)
+ .constrainAllUses(TII, TRI, RBI);
+}
bool SPIRVInstructionSelector::selectReadImageIntrinsic(
Register &ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
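
The increment/decrement adjustment performed by selectUpdateCounter can be summarized with a small standalone sketch. This illustrates only the arithmetic, under the assumption that the atomic add returns the pre-operation value; the function names are made up and none of this is backend API.

#include <cassert>

// Models OpAtomicIAdd: the counter is updated, but the value held *before*
// the addition is returned.
static int atomicAddReturningOld(int &Counter, int Incr) {
  int Old = Counter;
  Counter += Incr;
  return Old;
}

// Models the value handed back to HLSL for IncrementCounter (Incr = +1) and
// DecrementCounter (Incr = -1).
static int updateCounter(int &Counter, int Incr) {
  int Old = atomicAddReturningOld(Counter, Incr);
  // IncrementCounter wants the pre-increment value, so Old is returned as-is.
  // DecrementCounter wants the post-decrement value, i.e. Old + Incr.
  return Incr >= 0 ? Old : Old + Incr;
}

int main() {
  int Counter = 5;
  assert(updateCounter(Counter, 1) == 5 && Counter == 6);  // pre-increment
  assert(updateCounter(Counter, -1) == 5 && Counter == 5); // post-decrement
  return 0;
}
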
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp
index 205895e..fc14a03 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp
@@ -39,6 +39,10 @@ private:
void collectBindingInfo(Module &M);
uint32_t getAndReserveFirstUnusedBinding(uint32_t DescSet);
void replaceImplicitBindingCalls(Module &M);
+ void replaceResourceHandleCall(Module &M, CallInst *OldCI,
+ uint32_t NewBinding);
+ void replaceCounterHandleCall(Module &M, CallInst *OldCI,
+ uint32_t NewBinding);
void verifyUniqueOrderIdPerResource(SmallVectorImpl<CallInst *> &Calls);
// A map from descriptor set to a bit vector of used binding numbers.
@@ -56,64 +60,93 @@ struct BindingInfoCollector : public InstVisitor<BindingInfoCollector> {
: UsedBindings(UsedBindings), ImplicitBindingCalls(ImplicitBindingCalls) {
}
+ void addBinding(uint32_t DescSet, uint32_t Binding) {
+ if (UsedBindings.size() <= DescSet) {
+ UsedBindings.resize(DescSet + 1);
+ UsedBindings[DescSet].resize(64);
+ }
+ if (UsedBindings[DescSet].size() <= Binding) {
+ UsedBindings[DescSet].resize(2 * Binding + 1);
+ }
+ UsedBindings[DescSet].set(Binding);
+ }
+
void visitCallInst(CallInst &CI) {
if (CI.getIntrinsicID() == Intrinsic::spv_resource_handlefrombinding) {
const uint32_t DescSet =
cast<ConstantInt>(CI.getArgOperand(0))->getZExtValue();
const uint32_t Binding =
cast<ConstantInt>(CI.getArgOperand(1))->getZExtValue();
-
- if (UsedBindings.size() <= DescSet) {
- UsedBindings.resize(DescSet + 1);
- UsedBindings[DescSet].resize(64);
- }
- if (UsedBindings[DescSet].size() <= Binding) {
- UsedBindings[DescSet].resize(2 * Binding + 1);
- }
- UsedBindings[DescSet].set(Binding);
+ addBinding(DescSet, Binding);
} else if (CI.getIntrinsicID() ==
Intrinsic::spv_resource_handlefromimplicitbinding) {
ImplicitBindingCalls.push_back(&CI);
+ } else if (CI.getIntrinsicID() ==
+ Intrinsic::spv_resource_counterhandlefrombinding) {
+ const uint32_t DescSet =
+ cast<ConstantInt>(CI.getArgOperand(2))->getZExtValue();
+ const uint32_t Binding =
+ cast<ConstantInt>(CI.getArgOperand(1))->getZExtValue();
+ addBinding(DescSet, Binding);
+ } else if (CI.getIntrinsicID() ==
+ Intrinsic::spv_resource_counterhandlefromimplicitbinding) {
+ ImplicitBindingCalls.push_back(&CI);
}
}
};
+static uint32_t getOrderId(const CallInst *CI) {
+ uint32_t OrderIdArgIdx = 0;
+ switch (CI->getIntrinsicID()) {
+ case Intrinsic::spv_resource_handlefromimplicitbinding:
+ OrderIdArgIdx = 0;
+ break;
+ case Intrinsic::spv_resource_counterhandlefromimplicitbinding:
+ OrderIdArgIdx = 1;
+ break;
+ default:
+ llvm_unreachable("CallInst is not an implicit binding intrinsic");
+ }
+ return cast<ConstantInt>(CI->getArgOperand(OrderIdArgIdx))->getZExtValue();
+}
+
+static uint32_t getDescSet(const CallInst *CI) {
+ uint32_t DescSetArgIdx;
+ switch (CI->getIntrinsicID()) {
+ case Intrinsic::spv_resource_handlefromimplicitbinding:
+ case Intrinsic::spv_resource_handlefrombinding:
+ DescSetArgIdx = 1;
+ break;
+ case Intrinsic::spv_resource_counterhandlefromimplicitbinding:
+ case Intrinsic::spv_resource_counterhandlefrombinding:
+ DescSetArgIdx = 2;
+ break;
+ default:
+    llvm_unreachable("CallInst is not a resource handle intrinsic");
+ }
+ return cast<ConstantInt>(CI->getArgOperand(DescSetArgIdx))->getZExtValue();
+}
+
void SPIRVLegalizeImplicitBinding::collectBindingInfo(Module &M) {
BindingInfoCollector InfoCollector(UsedBindings, ImplicitBindingCalls);
InfoCollector.visit(M);
// Sort the collected calls by their order ID.
- std::sort(
- ImplicitBindingCalls.begin(), ImplicitBindingCalls.end(),
- [](const CallInst *A, const CallInst *B) {
- const uint32_t OrderIdArgIdx = 0;
- const uint32_t OrderA =
- cast<ConstantInt>(A->getArgOperand(OrderIdArgIdx))->getZExtValue();
- const uint32_t OrderB =
- cast<ConstantInt>(B->getArgOperand(OrderIdArgIdx))->getZExtValue();
- return OrderA < OrderB;
- });
+ std::sort(ImplicitBindingCalls.begin(), ImplicitBindingCalls.end(),
+ [](const CallInst *A, const CallInst *B) {
+ return getOrderId(A) < getOrderId(B);
+ });
}
void SPIRVLegalizeImplicitBinding::verifyUniqueOrderIdPerResource(
SmallVectorImpl<CallInst *> &Calls) {
// Check that the order Id is unique per resource.
for (uint32_t i = 1; i < Calls.size(); ++i) {
- const uint32_t OrderIdArgIdx = 0;
- const uint32_t DescSetArgIdx = 1;
- const uint32_t OrderA =
- cast<ConstantInt>(Calls[i - 1]->getArgOperand(OrderIdArgIdx))
- ->getZExtValue();
- const uint32_t OrderB =
- cast<ConstantInt>(Calls[i]->getArgOperand(OrderIdArgIdx))
- ->getZExtValue();
+ const uint32_t OrderA = getOrderId(Calls[i - 1]);
+ const uint32_t OrderB = getOrderId(Calls[i]);
if (OrderA == OrderB) {
- const uint32_t DescSetA =
- cast<ConstantInt>(Calls[i - 1]->getArgOperand(DescSetArgIdx))
- ->getZExtValue();
- const uint32_t DescSetB =
- cast<ConstantInt>(Calls[i]->getArgOperand(DescSetArgIdx))
- ->getZExtValue();
+ const uint32_t DescSetA = getDescSet(Calls[i - 1]);
+ const uint32_t DescSetB = getDescSet(Calls[i]);
if (DescSetA != DescSetB) {
report_fatal_error("Implicit binding calls with the same order ID must "
"have the same descriptor set");
@@ -144,36 +177,26 @@ void SPIRVLegalizeImplicitBinding::replaceImplicitBindingCalls(Module &M) {
uint32_t lastBindingNumber = -1;
for (CallInst *OldCI : ImplicitBindingCalls) {
- IRBuilder<> Builder(OldCI);
- const uint32_t OrderId =
- cast<ConstantInt>(OldCI->getArgOperand(0))->getZExtValue();
- const uint32_t DescSet =
- cast<ConstantInt>(OldCI->getArgOperand(1))->getZExtValue();
-
- // Reuse an existing binding for this order ID, if one was already assigned.
- // Otherwise, assign a new binding.
- const uint32_t NewBinding = (lastOrderId == OrderId)
- ? lastBindingNumber
- : getAndReserveFirstUnusedBinding(DescSet);
- lastOrderId = OrderId;
- lastBindingNumber = NewBinding;
-
- SmallVector<Value *, 8> Args;
- Args.push_back(Builder.getInt32(DescSet));
- Args.push_back(Builder.getInt32(NewBinding));
-
- // Copy the remaining arguments from the old call.
- for (uint32_t i = 2; i < OldCI->arg_size(); ++i) {
- Args.push_back(OldCI->getArgOperand(i));
+ const uint32_t OrderId = getOrderId(OldCI);
+ uint32_t BindingNumber;
+ if (OrderId == lastOrderId) {
+ BindingNumber = lastBindingNumber;
+ } else {
+ const uint32_t DescSet = getDescSet(OldCI);
+ BindingNumber = getAndReserveFirstUnusedBinding(DescSet);
}
- Function *NewFunc = Intrinsic::getOrInsertDeclaration(
- &M, Intrinsic::spv_resource_handlefrombinding, OldCI->getType());
- CallInst *NewCI = Builder.CreateCall(NewFunc, Args);
- NewCI->setCallingConv(OldCI->getCallingConv());
-
- OldCI->replaceAllUsesWith(NewCI);
- OldCI->eraseFromParent();
+ if (OldCI->getIntrinsicID() ==
+ Intrinsic::spv_resource_handlefromimplicitbinding) {
+ replaceResourceHandleCall(M, OldCI, BindingNumber);
+ } else {
+ assert(OldCI->getIntrinsicID() ==
+ Intrinsic::spv_resource_counterhandlefromimplicitbinding &&
+ "Unexpected implicit binding intrinsic");
+ replaceCounterHandleCall(M, OldCI, BindingNumber);
+ }
+ lastOrderId = OrderId;
+ lastBindingNumber = BindingNumber;
}
}
@@ -196,4 +219,49 @@ INITIALIZE_PASS(SPIRVLegalizeImplicitBinding, "legalize-spirv-implicit-binding",
ModulePass *llvm::createSPIRVLegalizeImplicitBindingPass() {
return new SPIRVLegalizeImplicitBinding();
-}
\ No newline at end of file
+}
+
+void SPIRVLegalizeImplicitBinding::replaceResourceHandleCall(
+ Module &M, CallInst *OldCI, uint32_t NewBinding) {
+ IRBuilder<> Builder(OldCI);
+ const uint32_t DescSet =
+ cast<ConstantInt>(OldCI->getArgOperand(1))->getZExtValue();
+
+ SmallVector<Value *, 8> Args;
+ Args.push_back(Builder.getInt32(DescSet));
+ Args.push_back(Builder.getInt32(NewBinding));
+
+ // Copy the remaining arguments from the old call.
+ for (uint32_t i = 2; i < OldCI->arg_size(); ++i) {
+ Args.push_back(OldCI->getArgOperand(i));
+ }
+
+ Function *NewFunc = Intrinsic::getOrInsertDeclaration(
+ &M, Intrinsic::spv_resource_handlefrombinding, OldCI->getType());
+ CallInst *NewCI = Builder.CreateCall(NewFunc, Args);
+ NewCI->setCallingConv(OldCI->getCallingConv());
+
+ OldCI->replaceAllUsesWith(NewCI);
+ OldCI->eraseFromParent();
+}
+
+void SPIRVLegalizeImplicitBinding::replaceCounterHandleCall(
+ Module &M, CallInst *OldCI, uint32_t NewBinding) {
+ IRBuilder<> Builder(OldCI);
+ const uint32_t DescSet =
+ cast<ConstantInt>(OldCI->getArgOperand(2))->getZExtValue();
+
+ SmallVector<Value *, 8> Args;
+ Args.push_back(OldCI->getArgOperand(0));
+ Args.push_back(Builder.getInt32(NewBinding));
+ Args.push_back(Builder.getInt32(DescSet));
+
+ Type *Tys[] = {OldCI->getType(), OldCI->getArgOperand(0)->getType()};
+ Function *NewFunc = Intrinsic::getOrInsertDeclaration(
+ &M, Intrinsic::spv_resource_counterhandlefrombinding, Tys);
+ CallInst *NewCI = Builder.CreateCall(NewFunc, Args);
+ NewCI->setCallingConv(OldCI->getCallingConv());
+
+ OldCI->replaceAllUsesWith(NewCI);
+ OldCI->eraseFromParent();
+}
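
The binding-assignment policy used by replaceImplicitBindingCalls (calls sorted by order ID; consecutive calls with the same order ID reuse the previously assigned binding; otherwise the first unused binding in the descriptor set is reserved) can be sketched outside the pass. The types and names below are illustrative stand-ins, not the pass's own data structures.

#include <cstdint>
#include <map>
#include <set>
#include <vector>

struct ImplicitCall {
  uint32_t OrderId;
  uint32_t DescSet;
};

// Calls are assumed to be pre-sorted by OrderId, as collectBindingInfo
// guarantees. Used maps a descriptor set to the binding numbers already
// reserved in it.
std::vector<uint32_t>
assignBindings(const std::vector<ImplicitCall> &Calls,
               std::map<uint32_t, std::set<uint32_t>> Used) {
  std::vector<uint32_t> Assigned;
  uint32_t LastOrderId = ~0u;
  uint32_t LastBinding = ~0u;
  for (const ImplicitCall &C : Calls) {
    uint32_t Binding;
    if (C.OrderId == LastOrderId) {
      // Consecutive calls with the same order ID reuse the binding that was
      // assigned to the first of them.
      Binding = LastBinding;
    } else {
      // Take the first unused binding number in this descriptor set.
      Binding = 0;
      while (Used[C.DescSet].count(Binding))
        ++Binding;
      Used[C.DescSet].insert(Binding);
    }
    Assigned.push_back(Binding);
    LastOrderId = C.OrderId;
    LastBinding = Binding;
  }
  return Assigned;
}
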
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 327c011..1d47c89 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -385,6 +385,12 @@ uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) {
return MI->getOperand(1).getCImm()->getValue().getZExtValue();
}
+int64_t getIConstValSext(Register ConstReg, const MachineRegisterInfo *MRI) {
+ const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI);
+ assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT);
+ return MI->getOperand(1).getCImm()->getSExtValue();
+}
+
bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) {
if (const auto *GI = dyn_cast<GIntrinsic>(&MI))
return GI->is(IntrinsicID);
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 409a0fd..5777a24 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -289,6 +289,9 @@ MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
// Get constant integer value of the given ConstReg.
uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI);
+// Get constant integer value of the given ConstReg, sign-extended.
+int64_t getIConstValSext(Register ConstReg, const MachineRegisterInfo *MRI);
+
// Check if MI is a SPIR-V specific intrinsic call.
bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID);
// Check if it's a SPIR-V specific intrinsic call.
diff --git a/llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp b/llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp
index 3090ad3..27fba34 100644
--- a/llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp
+++ b/llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp
@@ -407,6 +407,7 @@ bool X86InstructionSelector::select(MachineInstr &I) {
case TargetOpcode::G_TRUNC:
return selectTruncOrPtrToInt(I, MRI, MF);
case TargetOpcode::G_INTTOPTR:
+ case TargetOpcode::G_FREEZE:
return selectCopy(I, MRI);
case TargetOpcode::G_ZEXT:
return selectZext(I, MRI, MF);
diff --git a/llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp b/llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp
index e7709ef..11ef721 100644
--- a/llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp
+++ b/llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp
@@ -89,9 +89,29 @@ X86LegalizerInfo::X86LegalizerInfo(const X86Subtarget &STI,
// 32/64-bits needs support for s64/s128 to handle cases:
// s64 = EXTEND (G_IMPLICIT_DEF s32) -> s64 = G_IMPLICIT_DEF
// s128 = EXTEND (G_IMPLICIT_DEF s32/s64) -> s128 = G_IMPLICIT_DEF
- getActionDefinitionsBuilder(G_IMPLICIT_DEF)
+ getActionDefinitionsBuilder(
+ {G_IMPLICIT_DEF, G_PHI, G_FREEZE, G_CONSTANT_FOLD_BARRIER})
.legalFor({p0, s1, s8, s16, s32, s64})
- .legalFor(Is64Bit, {s128});
+ .legalFor(UseX87, {s80})
+ .legalFor(Is64Bit, {s128})
+ .legalFor(HasSSE2, {v16s8, v8s16, v4s32, v2s64})
+ .legalFor(HasAVX, {v32s8, v16s16, v8s32, v4s64})
+ .legalFor(HasAVX512, {v64s8, v32s16, v16s32, v8s64})
+ .widenScalarOrEltToNextPow2(0, /*Min=*/8)
+ .clampScalarOrElt(0, s8, sMaxScalar)
+ .moreElementsToNextPow2(0)
+ .clampMinNumElements(0, s8, 16)
+ .clampMinNumElements(0, s16, 8)
+ .clampMinNumElements(0, s32, 4)
+ .clampMinNumElements(0, s64, 2)
+ .clampMaxNumElements(0, s8, HasAVX512 ? 64 : (HasAVX ? 32 : 16))
+ .clampMaxNumElements(0, s16, HasAVX512 ? 32 : (HasAVX ? 16 : 8))
+ .clampMaxNumElements(0, s32, HasAVX512 ? 16 : (HasAVX ? 8 : 4))
+ .clampMaxNumElements(0, s64, HasAVX512 ? 8 : (HasAVX ? 4 : 2))
+ .clampMaxNumElements(0, p0,
+ Is64Bit ? s64MaxVector.getNumElements()
+ : s32MaxVector.getNumElements())
+ .scalarizeIf(scalarOrEltWiderThan(0, 64), 0);
getActionDefinitionsBuilder(G_CONSTANT)
.legalFor({p0, s8, s16, s32})
@@ -289,26 +309,6 @@ X86LegalizerInfo::X86LegalizerInfo(const X86Subtarget &STI,
.clampScalar(1, s16, sMaxScalar)
.scalarSameSizeAs(0, 1);
- // control flow
- getActionDefinitionsBuilder(G_PHI)
- .legalFor({s8, s16, s32, p0})
- .legalFor(UseX87, {s80})
- .legalFor(Is64Bit, {s64})
- .legalFor(HasSSE1, {v16s8, v8s16, v4s32, v2s64})
- .legalFor(HasAVX, {v32s8, v16s16, v8s32, v4s64})
- .legalFor(HasAVX512, {v64s8, v32s16, v16s32, v8s64})
- .clampMinNumElements(0, s8, 16)
- .clampMinNumElements(0, s16, 8)
- .clampMinNumElements(0, s32, 4)
- .clampMinNumElements(0, s64, 2)
- .clampMaxNumElements(0, s8, HasAVX512 ? 64 : (HasAVX ? 32 : 16))
- .clampMaxNumElements(0, s16, HasAVX512 ? 32 : (HasAVX ? 16 : 8))
- .clampMaxNumElements(0, s32, HasAVX512 ? 16 : (HasAVX ? 8 : 4))
- .clampMaxNumElements(0, s64, HasAVX512 ? 8 : (HasAVX ? 4 : 2))
- .widenScalarToNextPow2(0, /*Min=*/32)
- .clampScalar(0, s8, sMaxScalar)
- .scalarize(0);
-
getActionDefinitionsBuilder(G_BRCOND).legalFor({s1});
// pointer handling
@@ -592,11 +592,6 @@ X86LegalizerInfo::X86LegalizerInfo(const X86Subtarget &STI,
.minScalar(0, LLT::scalar(32))
.libcall();
- getActionDefinitionsBuilder({G_FREEZE, G_CONSTANT_FOLD_BARRIER})
- .legalFor({s8, s16, s32, s64, p0})
- .widenScalarToNextPow2(0, /*Min=*/8)
- .clampScalar(0, s8, sMaxScalar);
-
getLegacyLegalizerInfo().computeTables();
verify(*STI.getInstrInfo());
}
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index 564810c..83bd6ac 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -662,6 +662,7 @@ def VINSERTPSZrri : AVX512AIi8<0x21, MRMSrcReg, (outs VR128X:$dst),
"vinsertps\t{$src3, $src2, $src1, $dst|$dst, $src1, $src2, $src3}",
[(set VR128X:$dst, (X86insertps VR128X:$src1, VR128X:$src2, timm:$src3))]>,
EVEX, VVVV, Sched<[SchedWriteFShuffle.XMM]>;
+let mayLoad = 1 in
def VINSERTPSZrmi : AVX512AIi8<0x21, MRMSrcMem, (outs VR128X:$dst),
(ins VR128X:$src1, f32mem:$src2, u8imm:$src3),
"vinsertps\t{$src3, $src2, $src1, $dst|$dst, $src1, $src2, $src3}",
@@ -1293,6 +1294,7 @@ multiclass avx512_subvec_broadcast_rm<bits<8> opc, string OpcodeStr,
SDPatternOperator OpNode,
X86VectorVTInfo _Dst,
X86VectorVTInfo _Src> {
+ let hasSideEffects = 0, mayLoad = 1 in
defm rm : AVX512_maskable<opc, MRMSrcMem, _Dst, (outs _Dst.RC:$dst),
(ins _Src.MemOp:$src), OpcodeStr, "$src", "$src",
(_Dst.VT (OpNode addr:$src))>,
@@ -1748,6 +1750,7 @@ let Constraints = "$src1 = $dst", ExeDomain = _.ExeDomain in {
(_.VT (X86VPermt2 _.RC:$src1, IdxVT.RC:$src2, _.RC:$src3)), 1>,
EVEX, VVVV, AVX5128IBase, Sched<[sched]>;
+ let hasSideEffects = 0, mayLoad = 1 in
defm rm: AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins IdxVT.RC:$src2, _.MemOp:$src3),
OpcodeStr, "$src3, $src2", "$src2, $src3",
@@ -1759,7 +1762,7 @@ let Constraints = "$src1 = $dst", ExeDomain = _.ExeDomain in {
multiclass avx512_perm_t_mb<bits<8> opc, string OpcodeStr,
X86FoldableSchedWrite sched,
X86VectorVTInfo _, X86VectorVTInfo IdxVT> {
- let Constraints = "$src1 = $dst", ExeDomain = _.ExeDomain in
+ let Constraints = "$src1 = $dst", ExeDomain = _.ExeDomain, hasSideEffects = 0, mayLoad = 1 in
defm rmb: AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins IdxVT.RC:$src2, _.ScalarMemOp:$src3),
OpcodeStr, !strconcat("${src3}", _.BroadcastStr,", $src2"),
@@ -1987,6 +1990,7 @@ multiclass avx512_cmp_scalar<X86VectorVTInfo _, SDNode OpNode, SDNode OpNodeSAE,
_.FRC:$src2,
timm:$cc))]>,
EVEX, VVVV, VEX_LIG, Sched<[sched]>, SIMD_EXC;
+ let mayLoad = 1 in
def rmi : AVX512Ii8<0xC2, MRMSrcMem,
(outs _.KRC:$dst),
(ins _.FRC:$src1, _.ScalarMemOp:$src2, u8imm:$cc),
@@ -2145,6 +2149,7 @@ multiclass avx512_icmp_cc<bits<8> opc, string Suffix, PatFrag Frag,
(_.VT _.RC:$src2),
cond)))]>,
EVEX, VVVV, Sched<[sched]>;
+ let mayLoad = 1 in
def rmi : AVX512AIi8<opc, MRMSrcMem,
(outs _.KRC:$dst), (ins _.RC:$src1, _.MemOp:$src2, u8imm:$cc),
!strconcat("vpcmp", Suffix,
@@ -2167,6 +2172,7 @@ multiclass avx512_icmp_cc<bits<8> opc, string Suffix, PatFrag Frag,
(_.VT _.RC:$src2),
cond))))]>,
EVEX, VVVV, EVEX_K, Sched<[sched]>;
+ let mayLoad = 1 in
def rmik : AVX512AIi8<opc, MRMSrcMem,
(outs _.KRC:$dst), (ins _.KRCWM:$mask, _.RC:$src1, _.MemOp:$src2,
u8imm:$cc),
@@ -2198,6 +2204,7 @@ multiclass avx512_icmp_cc_rmb<bits<8> opc, string Suffix, PatFrag Frag,
PatFrag Frag_su, X86FoldableSchedWrite sched,
X86VectorVTInfo _, string Name> :
avx512_icmp_cc<opc, Suffix, Frag, Frag_su, sched, _, Name> {
+ let mayLoad = 1 in {
def rmbi : AVX512AIi8<opc, MRMSrcMem,
(outs _.KRC:$dst), (ins _.RC:$src1, _.ScalarMemOp:$src2,
u8imm:$cc),
@@ -2221,6 +2228,7 @@ multiclass avx512_icmp_cc_rmb<bits<8> opc, string Suffix, PatFrag Frag,
(_.BroadcastLdFrag addr:$src2),
cond))))]>,
EVEX, VVVV, EVEX_K, EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>;
+ }
def : Pat<(_.KVT (Frag:$cc (_.BroadcastLdFrag addr:$src2),
(_.VT _.RC:$src1), cond)),
@@ -2305,6 +2313,7 @@ let Uses = [MXCSR], mayRaiseFPException = 1 in {
(X86cmpm_su (_.VT _.RC:$src1), (_.VT _.RC:$src2), timm:$cc),
1>, Sched<[sched]>;
+ let mayLoad = 1 in {
defm rmi : AVX512_maskable_cmp<0xC2, MRMSrcMem, _,
(outs _.KRC:$dst),(ins _.RC:$src1, _.MemOp:$src2, u8imm:$cc),
"vcmp"#_.Suffix,
@@ -2329,6 +2338,7 @@ let Uses = [MXCSR], mayRaiseFPException = 1 in {
timm:$cc)>,
EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>;
}
+ }
// Patterns for selecting with loads in other operand.
def : Pat<(X86any_cmpm (_.LdFrag addr:$src2), (_.VT _.RC:$src1),
@@ -3771,6 +3781,7 @@ def VMOVDI2PDIZrr : AVX512BI<0x6E, MRMSrcReg, (outs VR128X:$dst), (ins GR32:$src
[(set VR128X:$dst,
(v4i32 (scalar_to_vector GR32:$src)))]>,
EVEX, Sched<[WriteVecMoveFromGpr]>;
+let mayLoad = 1 in
def VMOVDI2PDIZrm : AVX512BI<0x6E, MRMSrcMem, (outs VR128X:$dst), (ins i32mem:$src),
"vmovd\t{$src, $dst|$dst, $src}",
[(set VR128X:$dst,
@@ -3874,7 +3885,7 @@ def VMOVSS2DIZrr : AVX512BI<0x7E, MRMDestReg, (outs GR32:$dst),
// Move Quadword Int to Packed Quadword Int
//
-let ExeDomain = SSEPackedInt in {
+let ExeDomain = SSEPackedInt, mayLoad = 1, hasSideEffects = 0 in {
def VMOVQI2PQIZrm : AVX512XSI<0x7E, MRMSrcMem, (outs VR128X:$dst),
(ins i64mem:$src),
"vmovq\t{$src, $dst|$dst, $src}",
@@ -3930,13 +3941,13 @@ multiclass avx512_move_scalar<string asm, SDNode OpNode, PatFrag vzload_frag,
(_.VT (OpNode _.RC:$src1, _.RC:$src2)),
(_.VT _.RC:$src0))))],
_.ExeDomain>, EVEX, VVVV, EVEX_K, Sched<[SchedWriteFShuffle.XMM]>;
- let canFoldAsLoad = 1, isReMaterializable = 1 in {
+ let canFoldAsLoad = 1, isReMaterializable = 1, mayLoad = 1, hasSideEffects = 0 in {
def rm : AVX512PI<0x10, MRMSrcMem, (outs _.RC:$dst), (ins _.ScalarMemOp:$src),
!strconcat(asm, "\t{$src, $dst|$dst, $src}"),
[(set _.RC:$dst, (_.VT (vzload_frag addr:$src)))],
_.ExeDomain>, EVEX, Sched<[WriteFLoad]>;
// _alt version uses FR32/FR64 register class.
- let isCodeGenOnly = 1 in
+ let isCodeGenOnly = 1, mayLoad = 1, hasSideEffects = 0 in
def rm_alt : AVX512PI<0x10, MRMSrcMem, (outs _.FRC:$dst), (ins _.ScalarMemOp:$src),
!strconcat(asm, "\t{$src, $dst|$dst, $src}"),
[(set _.FRC:$dst, (_.ScalarLdFrag addr:$src))],
@@ -4557,6 +4568,7 @@ let Predicates = [HasAVX512] in {
// AVX-512 - Non-temporals
//===----------------------------------------------------------------------===//
+let mayLoad = 1, hasSideEffects = 0 in {
def VMOVNTDQAZrm : AVX512PI<0x2A, MRMSrcMem, (outs VR512:$dst),
(ins i512mem:$src), "vmovntdqa\t{$src, $dst|$dst, $src}",
[], SSEPackedInt>, Sched<[SchedWriteVecMoveLSNT.ZMM.RM]>,
@@ -4575,11 +4587,12 @@ let Predicates = [HasVLX] in {
[], SSEPackedInt>, Sched<[SchedWriteVecMoveLSNT.XMM.RM]>,
EVEX, T8, PD, EVEX_V128, EVEX_CD8<64, CD8VF>;
}
+}
multiclass avx512_movnt<bits<8> opc, string OpcodeStr, X86VectorVTInfo _,
X86SchedWriteMoveLS Sched,
PatFrag st_frag = alignednontemporalstore> {
- let SchedRW = [Sched.MR], AddedComplexity = 400 in
+ let mayStore = 1, SchedRW = [Sched.MR], AddedComplexity = 400 in
def mr : AVX512PI<opc, MRMDestMem, (outs), (ins _.MemOp:$dst, _.RC:$src),
!strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"),
[(st_frag (_.VT _.RC:$src), addr:$dst)],
@@ -4682,6 +4695,7 @@ multiclass avx512_binop_rm<bits<8> opc, string OpcodeStr, SDNode OpNode,
IsCommutable, IsCommutable>, AVX512BIBase, EVEX, VVVV,
Sched<[sched]>;
+ let mayLoad = 1, hasSideEffects = 0 in
defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.MemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -4694,6 +4708,7 @@ multiclass avx512_binop_rmb<bits<8> opc, string OpcodeStr, SDNode OpNode,
X86VectorVTInfo _, X86FoldableSchedWrite sched,
bit IsCommutable = 0> :
avx512_binop_rm<opc, OpcodeStr, OpNode, _, sched, IsCommutable> {
+ let mayLoad = 1, hasSideEffects = 0 in
defm rmb : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.ScalarMemOp:$src2), OpcodeStr,
"${src2}"#_.BroadcastStr#", $src1",
@@ -4811,6 +4826,7 @@ multiclass avx512_binop_rm2<bits<8> opc, string OpcodeStr,
(_Src.VT _Src.RC:$src2))),
IsCommutable>,
AVX512BIBase, EVEX, VVVV, Sched<[sched]>;
+ let mayLoad = 1, hasSideEffects = 0 in {
defm rm : AVX512_maskable<opc, MRMSrcMem, _Dst, (outs _Dst.RC:$dst),
(ins _Src.RC:$src1, _Src.MemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -4828,6 +4844,7 @@ multiclass avx512_binop_rm2<bits<8> opc, string OpcodeStr,
(_Brdct.VT (_Brdct.BroadcastLdFrag addr:$src2)))))>,
AVX512BIBase, EVEX, VVVV, EVEX_B,
Sched<[sched.Folded, sched.ReadAfterFold]>;
+ }
}
defm VPADD : avx512_binop_rm_vl_all<0xFC, 0xFD, 0xFE, 0xD4, "vpadd", add,
@@ -4893,6 +4910,7 @@ defm VPMULTISHIFTQB : avx512_binop_all<0x83, "vpmultishiftqb", SchedWriteVecALU,
multiclass avx512_packs_rmb<bits<8> opc, string OpcodeStr, SDNode OpNode,
X86VectorVTInfo _Src, X86VectorVTInfo _Dst,
X86FoldableSchedWrite sched> {
+ let mayLoad = 1, hasSideEffects = 0 in
defm rmb : AVX512_maskable<opc, MRMSrcMem, _Dst, (outs _Dst.RC:$dst),
(ins _Src.RC:$src1, _Src.ScalarMemOp:$src2),
OpcodeStr,
@@ -4916,6 +4934,7 @@ multiclass avx512_packs_rm<bits<8> opc, string OpcodeStr,
(_Src.VT _Src.RC:$src2))),
IsCommutable, IsCommutable>,
EVEX_CD8<_Src.EltSize, CD8VF>, EVEX, VVVV, Sched<[sched]>;
+ let mayLoad = 1, hasSideEffects = 0 in
defm rm : AVX512_maskable<opc, MRMSrcMem, _Dst, (outs _Dst.RC:$dst),
(ins _Src.RC:$src1, _Src.MemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -5370,6 +5389,7 @@ multiclass avx512_fp_scalar<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
(_.VT (VecNode _.RC:$src1, _.RC:$src2)), "_Int">,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -5384,6 +5404,7 @@ multiclass avx512_fp_scalar<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
Sched<[sched]> {
let isCommutable = IsCommutable;
}
+ let mayLoad = 1 in
def rm : I< opc, MRMSrcMem, (outs _.FRC:$dst),
(ins _.FRC:$src1, _.ScalarMemOp:$src2),
OpcodeStr#"\t{$src2, $src1, $dst|$dst, $src1, $src2}",
@@ -5414,6 +5435,7 @@ multiclass avx512_fp_scalar_sae<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
(_.VT (VecNode _.RC:$src1, _.RC:$src2)), "_Int">,
Sched<[sched]>, SIMD_EXC;
+ let mayLoad = 1 in
defm rm : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -5430,6 +5452,7 @@ multiclass avx512_fp_scalar_sae<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
Sched<[sched]> {
let isCommutable = IsCommutable;
}
+ let mayLoad = 1 in
def rm : I< opc, MRMSrcMem, (outs _.FRC:$dst),
(ins _.FRC:$src1, _.ScalarMemOp:$src2),
OpcodeStr#"\t{$src2, $src1, $dst|$dst, $src1, $src2}",
@@ -5509,6 +5532,7 @@ multiclass avx512_comutable_binop_s<bits<8> opc, string OpcodeStr,
Sched<[sched]> {
let isCommutable = 1;
}
+ let mayLoad = 1 in
def rm : I< opc, MRMSrcMem, (outs _.FRC:$dst),
(ins _.FRC:$src1, _.ScalarMemOp:$src2),
OpcodeStr#"\t{$src2, $src1, $dst|$dst, $src1, $src2}",
@@ -5737,6 +5761,7 @@ multiclass avx512_fp_scalef_p<bits<8> opc, string OpcodeStr, SDNode OpNode,
"$src2, $src1", "$src1, $src2",
(_.VT (OpNode _.RC:$src1, _.RC:$src2))>,
EVEX, VVVV, Sched<[sched]>;
+ let mayLoad = 1 in {
defm rm: AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.MemOp:$src2), OpcodeStr#_.Suffix,
"$src2, $src1", "$src1, $src2",
@@ -5749,6 +5774,7 @@ multiclass avx512_fp_scalef_p<bits<8> opc, string OpcodeStr, SDNode OpNode,
(OpNode _.RC:$src1, (_.VT (_.BroadcastLdFrag addr:$src2)))>,
EVEX, VVVV, EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>;
}
+ }
}
multiclass avx512_fp_scalef_scalar<bits<8> opc, string OpcodeStr, SDNode OpNode,
@@ -5759,6 +5785,7 @@ multiclass avx512_fp_scalef_scalar<bits<8> opc, string OpcodeStr, SDNode OpNode,
"$src2, $src1", "$src1, $src2",
(_.VT (OpNode _.RC:$src1, _.RC:$src2))>,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rm: AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr#_.Suffix,
"$src2, $src1", "$src1, $src2",
@@ -5916,6 +5943,7 @@ multiclass avx512_shift_rmi<bits<8> opc, Format ImmFormR, Format ImmFormM,
"$src2, $src1", "$src1, $src2",
(_.VT (OpNode _.RC:$src1, (i8 timm:$src2)))>,
Sched<[sched]>;
+ let mayLoad = 1 in
defm mi : AVX512_maskable<opc, ImmFormM, _, (outs _.RC:$dst),
(ins _.MemOp:$src1, u8imm:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -5928,7 +5956,7 @@ multiclass avx512_shift_rmi<bits<8> opc, Format ImmFormR, Format ImmFormM,
multiclass avx512_shift_rmbi<bits<8> opc, Format ImmFormM,
string OpcodeStr, SDNode OpNode,
X86FoldableSchedWrite sched, X86VectorVTInfo _> {
- let ExeDomain = _.ExeDomain in
+ let ExeDomain = _.ExeDomain, mayLoad = 1 in
defm mbi : AVX512_maskable<opc, ImmFormM, _, (outs _.RC:$dst),
(ins _.ScalarMemOp:$src1, u8imm:$src2), OpcodeStr,
"$src2, ${src1}"#_.BroadcastStr, "${src1}"#_.BroadcastStr#", $src2",
@@ -5946,6 +5974,7 @@ multiclass avx512_shift_rrm<bits<8> opc, string OpcodeStr, SDNode OpNode,
"$src2, $src1", "$src1, $src2",
(_.VT (OpNode _.RC:$src1, (SrcVT VR128X:$src2)))>,
AVX512BIBase, EVEX, VVVV, Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, i128mem:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -6095,6 +6124,7 @@ multiclass avx512_var_shift<bits<8> opc, string OpcodeStr, SDNode OpNode,
"$src2, $src1", "$src1, $src2",
(_.VT (OpNode _.RC:$src1, (_.VT _.RC:$src2)))>,
AVX5128IBase, EVEX, VVVV, Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.MemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -6107,7 +6137,7 @@ multiclass avx512_var_shift<bits<8> opc, string OpcodeStr, SDNode OpNode,
multiclass avx512_var_shift_mb<bits<8> opc, string OpcodeStr, SDNode OpNode,
X86FoldableSchedWrite sched, X86VectorVTInfo _> {
- let ExeDomain = _.ExeDomain in
+ let ExeDomain = _.ExeDomain, mayLoad = 1 in
defm rmb : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.ScalarMemOp:$src2), OpcodeStr,
"${src2}"#_.BroadcastStr#", $src1",
@@ -6372,6 +6402,7 @@ multiclass avx512_permil_vec<bits<8> OpcVar, string OpcodeStr, SDNode OpNode,
(_.VT (OpNode _.RC:$src1,
(Ctrl.VT Ctrl.RC:$src2)))>,
T8, PD, EVEX, VVVV, Sched<[sched]>;
+ let mayLoad = 1 in {
defm rm: AVX512_maskable<OpcVar, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, Ctrl.MemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -6389,6 +6420,7 @@ multiclass avx512_permil_vec<bits<8> OpcVar, string OpcodeStr, SDNode OpNode,
(Ctrl.VT (Ctrl.BroadcastLdFrag addr:$src2))))>,
T8, PD, EVEX, VVVV, EVEX_B, EVEX_CD8<_.EltSize, CD8VF>,
Sched<[sched.Folded, sched.ReadAfterFold]>;
+ }
}
multiclass avx512_permil_vec_common<string OpcodeStr, bits<8> OpcVar,
@@ -7258,6 +7290,7 @@ let ExeDomain = DstVT.ExeDomain, Uses = _Uses,
(OpNode (DstVT.VT DstVT.RC:$src1), SrcRC:$src2))]>,
EVEX, VVVV, Sched<[sched, ReadDefault, ReadInt2Fpu]>;
+ let mayLoad = 1 in
def rm_Int : SI<opc, MRMSrcMem, (outs DstVT.RC:$dst),
(ins DstVT.RC:$src1, x86memop:$src2),
asm#"{"#mem#"}\t{$src2, $src1, $dst|$dst, $src1, $src2}",
@@ -7400,6 +7433,7 @@ multiclass avx512_cvt_s_int_round<bits<8> opc, X86VectorVTInfo SrcVT,
[(set DstVT.RC:$dst, (OpNodeRnd (SrcVT.VT SrcVT.RC:$src),(i32 timm:$rc)))]>,
EVEX, VEX_LIG, EVEX_B, EVEX_RC,
Sched<[sched]>;
+ let mayLoad = 1 in
def rm_Int : SI<opc, MRMSrcMem, (outs DstVT.RC:$dst), (ins SrcVT.IntScalarMemOp:$src),
!strconcat(asm,"\t{$src, $dst|$dst, $src}"),
[(set DstVT.RC:$dst, (OpNode
@@ -7451,6 +7485,7 @@ multiclass avx512_cvt_s<bits<8> opc, string asm, X86VectorVTInfo SrcVT,
!strconcat(asm,"\t{$src, $dst|$dst, $src}"),
[(set DstVT.RC:$dst, (OpNode SrcVT.FRC:$src))]>,
EVEX, VEX_LIG, Sched<[sched]>, SIMD_EXC;
+ let mayLoad = 1 in
def rm : AVX512<opc, MRMSrcMem, (outs DstVT.RC:$dst), (ins SrcVT.ScalarMemOp:$src),
!strconcat(asm,"\t{$src, $dst|$dst, $src}"),
[(set DstVT.RC:$dst, (OpNode (SrcVT.ScalarLdFrag addr:$src)))]>,
@@ -7572,6 +7607,7 @@ let Predicates = [prd], ExeDomain = _SrcRC.ExeDomain in {
!strconcat(asm,"\t{$src, $dst|$dst, $src}"),
[(set _DstRC.RC:$dst, (OpNode _SrcRC.FRC:$src))]>,
EVEX, VEX_LIG, Sched<[sched]>, SIMD_EXC;
+ let mayLoad = 1 in
def rm : AVX512<opc, MRMSrcMem, (outs _DstRC.RC:$dst), (ins _SrcRC.ScalarMemOp:$src),
!strconcat(asm,"\t{$src, $dst|$dst, $src}"),
[(set _DstRC.RC:$dst, (OpNode (_SrcRC.ScalarLdFrag addr:$src)))]>,
@@ -7587,6 +7623,7 @@ let Predicates = [prd], ExeDomain = _SrcRC.ExeDomain in {
!strconcat(asm,"\t{{sae}, $src, $dst|$dst, $src, {sae}}"),
[(set _DstRC.RC:$dst, (OpNodeSAE (_SrcRC.VT _SrcRC.RC:$src)))]>,
EVEX, VEX_LIG, EVEX_B, Sched<[sched]>;
+ let mayLoad = 1 in
def rm_Int : AVX512<opc, MRMSrcMem, (outs _DstRC.RC:$dst),
(ins _SrcRC.IntScalarMemOp:$src),
!strconcat(asm,"\t{$src, $dst|$dst, $src}"),
@@ -7644,6 +7681,7 @@ multiclass avx512_cvt_fp_scalar<bits<8> opc, string OpcodeStr, X86VectorVTInfo _
(_.VT (OpNode (_.VT _.RC:$src1),
(_Src.VT _Src.RC:$src2))), "_Int">,
EVEX, VVVV, VEX_LIG, Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _Src.IntScalarMemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -7807,6 +7845,7 @@ let Uses = [MXCSR], mayRaiseFPException = 1 in {
_.ImmAllZerosV)>,
EVEX, Sched<[sched]>;
+ let mayLoad = 1 in {
defm rm : AVX512_maskable_cvt<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins MemOp:$src),
(ins _.RC:$src0, MaskRC:$mask, MemOp:$src),
@@ -7840,6 +7879,7 @@ let Uses = [MXCSR], mayRaiseFPException = 1 in {
_.ImmAllZerosV)>,
EVEX, EVEX_B, Sched<[sched.Folded]>;
}
+ }
}
// Conversion with SAE - suppress all exceptions
multiclass avx512_vcvt_fp_sae<bits<8> opc, string OpcodeStr, X86VectorVTInfo _,
@@ -8944,6 +8984,7 @@ multiclass avx512_cvtph2ps<X86VectorVTInfo _dest, X86VectorVTInfo _src,
(X86any_cvtph2ps (_src.VT _src.RC:$src)),
(X86cvtph2ps (_src.VT _src.RC:$src))>,
T8, PD, Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable_split<0x13, MRMSrcMem, _dest, (outs _dest.RC:$dst),
(ins x86memop:$src), "vcvtph2ps", "$src", "$src",
(X86any_cvtph2ps (_src.VT ld_dag)),
@@ -9161,6 +9202,7 @@ multiclass avx512_fp14_s<bits<8> opc, string OpcodeStr, SDNode OpNode,
"$src2, $src1", "$src1, $src2",
(OpNode (_.VT _.RC:$src1), (_.VT _.RC:$src2))>,
EVEX, VVVV, VEX_LIG, Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
@@ -9621,6 +9663,7 @@ multiclass avx512_rndscale_scalar<bits<8> opc, string OpcodeStr,
(i32 timm:$src3))), "_Int">, EVEX_B,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rmi : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.IntScalarMemOp:$src2, i32u8imm:$src3),
OpcodeStr,
@@ -9999,6 +10042,7 @@ multiclass avx512_pmovx_common<bits<8> opc, string OpcodeStr, X86FoldableSchedWr
(DestInfo.VT (OpNode (SrcInfo.VT SrcInfo.RC:$src)))>,
EVEX, Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable<opc, MRMSrcMem, DestInfo, (outs DestInfo.RC:$dst),
(ins x86memop:$src), OpcodeStr ,"$src", "$src",
(DestInfo.VT (LdFrag addr:$src))>,
@@ -10601,6 +10645,7 @@ multiclass expand_by_vec_width<bits<8> opc, X86VectorVTInfo _,
(null_frag)>, AVX5128IBase,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.MemOp:$src1), OpcodeStr, "$src1", "$src1",
(null_frag)>,
@@ -10673,6 +10718,7 @@ multiclass avx512_unary_fp_packed_imm<bits<8> opc, string OpcodeStr,
(OpNode (_.VT _.RC:$src1), (i32 timm:$src2)),
(MaskOpNode (_.VT _.RC:$src1), (i32 timm:$src2))>,
Sched<[sched]>;
+ let mayLoad = 1 in {
defm rmi : AVX512_maskable_split<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.MemOp:$src1, i32u8imm:$src2),
OpcodeStr#_.Suffix, "$src2, $src1", "$src1, $src2",
@@ -10691,6 +10737,7 @@ multiclass avx512_unary_fp_packed_imm<bits<8> opc, string OpcodeStr,
(i32 timm:$src2))>, EVEX_B,
Sched<[sched.Folded, sched.ReadAfterFold]>;
}
+ }
}
//handle instruction reg_vec1 = op(reg_vec2,reg_vec3,imm),{sae}
@@ -10739,6 +10786,7 @@ multiclass avx512_fp_packed_imm<bits<8> opc, string OpcodeStr, SDNode OpNode,
(_.VT _.RC:$src2),
(i32 timm:$src3))>,
Sched<[sched]>;
+ let mayLoad = 1 in {
defm rmi : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.MemOp:$src2, i32u8imm:$src3),
OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3",
@@ -10755,6 +10803,7 @@ multiclass avx512_fp_packed_imm<bits<8> opc, string OpcodeStr, SDNode OpNode,
(i32 timm:$src3))>, EVEX_B,
Sched<[sched.Folded, sched.ReadAfterFold]>;
}
+ }
}
//handle instruction reg_vec1 = op(reg_vec2,reg_vec3,imm)
@@ -10770,6 +10819,7 @@ multiclass avx512_3Op_rm_imm8<bits<8> opc, string OpcodeStr, SDNode OpNode,
(SrcInfo.VT SrcInfo.RC:$src2),
(i8 timm:$src3)))>,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rmi : AVX512_maskable<opc, MRMSrcMem, DestInfo, (outs DestInfo.RC:$dst),
(ins SrcInfo.RC:$src1, SrcInfo.MemOp:$src2, u8imm:$src3),
OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3",
@@ -10788,7 +10838,7 @@ multiclass avx512_3Op_imm8<bits<8> opc, string OpcodeStr, SDNode OpNode,
X86FoldableSchedWrite sched, X86VectorVTInfo _>:
avx512_3Op_rm_imm8<opc, OpcodeStr, OpNode, sched, _, _>{
- let ExeDomain = _.ExeDomain, ImmT = Imm8 in
+ let ExeDomain = _.ExeDomain, ImmT = Imm8, mayLoad = 1 in
defm rmbi : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.ScalarMemOp:$src2, u8imm:$src3),
OpcodeStr, "$src3, ${src2}"#_.BroadcastStr#", $src1",
@@ -10811,6 +10861,7 @@ multiclass avx512_fp_scalar_imm<bits<8> opc, string OpcodeStr, SDNode OpNode,
(_.VT _.RC:$src2),
(i32 timm:$src3))>,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rmi : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.IntScalarMemOp:$src2, i32u8imm:$src3),
OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3",
@@ -10979,6 +11030,7 @@ multiclass avx512_shuff_packed_128_common<bits<8> opc, string OpcodeStr,
(CastInfo.VT (X86Shuf128 _.RC:$src1, _.RC:$src2,
(i8 timm:$src3)))))>,
Sched<[sched]>;
+ let mayLoad = 1 in {
defm rmi : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.MemOp:$src2, u8imm:$src3),
OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3",
@@ -11000,6 +11052,7 @@ multiclass avx512_shuff_packed_128_common<bits<8> opc, string OpcodeStr,
(i8 timm:$src3)))))>, EVEX_B,
Sched<[sched.Folded, sched.ReadAfterFold]>;
}
+ }
}
multiclass avx512_shuff_packed_128<string OpcodeStr, X86FoldableSchedWrite sched,
@@ -11031,6 +11084,7 @@ multiclass avx512_valign<bits<8> opc, string OpcodeStr,
OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3",
(_.VT (X86VAlign _.RC:$src1, _.RC:$src2, (i8 timm:$src3)))>,
Sched<[sched]>;
+ let mayLoad = 1 in {
defm rmi : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.MemOp:$src2, u8imm:$src3),
OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3",
@@ -11048,6 +11102,7 @@ multiclass avx512_valign<bits<8> opc, string OpcodeStr,
(i8 timm:$src3))>, EVEX_B,
Sched<[sched.Folded, sched.ReadAfterFold]>;
}
+ }
}
multiclass avx512_valign_common<string OpcodeStr, X86SchedWriteWidths sched,
@@ -11202,6 +11257,7 @@ multiclass avx512_unary_rm<bits<8> opc, string OpcodeStr, SDNode OpNode,
(_.VT (OpNode (_.VT _.RC:$src1)))>, EVEX, AVX5128IBase,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.MemOp:$src1), OpcodeStr,
"$src1", "$src1",
@@ -11214,6 +11270,7 @@ multiclass avx512_unary_rm<bits<8> opc, string OpcodeStr, SDNode OpNode,
multiclass avx512_unary_rmb<bits<8> opc, string OpcodeStr, SDNode OpNode,
X86FoldableSchedWrite sched, X86VectorVTInfo _> :
avx512_unary_rm<opc, OpcodeStr, OpNode, sched, _> {
+ let mayLoad = 1 in
defm rmb : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.ScalarMemOp:$src1), OpcodeStr,
"${src1}"#_.BroadcastStr,
@@ -11368,6 +11425,7 @@ multiclass avx512_movddup_128<bits<8> opc, string OpcodeStr,
(ins _.RC:$src), OpcodeStr, "$src", "$src",
(_.VT (X86VBroadcast (_.VT _.RC:$src)))>, EVEX,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.ScalarMemOp:$src), OpcodeStr, "$src", "$src",
(_.VT (_.BroadcastLdFrag addr:$src))>,
@@ -11513,6 +11571,7 @@ defm VPEXTRQZ : avx512_extract_elt_dq<"vpextrq", v2i64x_info, GR64>, REX_W;
multiclass avx512_insert_elt_m<bits<8> opc, string OpcodeStr, SDNode OpNode,
X86VectorVTInfo _, PatFrag LdFrag,
SDPatternOperator immoperator> {
+ let mayLoad = 1 in
def rmi : AVX512Ii8<opc, MRMSrcMem, (outs _.RC:$dst),
(ins _.RC:$src1, _.ScalarMemOp:$src2, u8imm:$src3),
OpcodeStr#"\t{$src3, $src2, $src1, $dst|$dst, $src1, $src2, $src3}",
@@ -11650,6 +11709,7 @@ multiclass avx512_psadbw_packed<bits<8> opc, SDNode OpNode,
(OpNode (_src.VT _src.RC:$src1),
(_src.VT _src.RC:$src2))))]>,
Sched<[sched]>;
+ let mayLoad = 1 in
def rm : AVX512BI<opc, MRMSrcMem,
(outs _dst.RC:$dst), (ins _src.RC:$src1, _src.MemOp:$src2),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
@@ -11751,6 +11811,7 @@ multiclass avx512_ternlog<bits<8> opc, string OpcodeStr, SDNode OpNode,
(_.VT _.RC:$src3),
(i8 timm:$src4)), 1, 1>,
AVX512AIi8Base, EVEX, VVVV, Sched<[sched]>;
+ let mayLoad = 1 in {
defm rmi : AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src2, _.MemOp:$src3, u8imm:$src4),
OpcodeStr, "$src4, $src3, $src2", "$src2, $src3, $src4",
@@ -11770,6 +11831,7 @@ multiclass avx512_ternlog<bits<8> opc, string OpcodeStr, SDNode OpNode,
(i8 timm:$src4)), 1, 0>, EVEX_B,
AVX512AIi8Base, EVEX, VVVV, EVEX_CD8<_.EltSize, CD8VF>,
Sched<[sched.Folded, sched.ReadAfterFold]>;
+ }
}// Constraints = "$src1 = $dst"
// Additional patterns for matching passthru operand in other positions.
@@ -12016,6 +12078,7 @@ multiclass avx512_fixupimm_packed<bits<8> opc, string OpcodeStr,
(_.VT _.RC:$src2),
(TblVT.VT _.RC:$src3),
(i32 timm:$src4))>, Sched<[sched]>;
+ let mayLoad = 1 in {
defm rmi : AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src2, _.MemOp:$src3, i32u8imm:$src4),
OpcodeStr#_.Suffix, "$src4, $src3, $src2", "$src2, $src3, $src4",
@@ -12033,6 +12096,7 @@ multiclass avx512_fixupimm_packed<bits<8> opc, string OpcodeStr,
(TblVT.VT (TblVT.BroadcastLdFrag addr:$src3)),
(i32 timm:$src4))>,
EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>;
+ }
} // Constraints = "$src1 = $dst"
}
@@ -12075,6 +12139,7 @@ multiclass avx512_fixupimm_scalar<bits<8> opc, string OpcodeStr,
(_src3VT.VT _src3VT.RC:$src3),
(i32 timm:$src4))>,
EVEX_B, Sched<[sched.Folded, sched.ReadAfterFold]>;
+ let mayLoad = 1 in
defm rmi : AVX512_maskable_3src_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src2, _.ScalarMemOp:$src3, i32u8imm:$src4),
OpcodeStr#_.Suffix, "$src4, $src3, $src2", "$src2, $src3, $src4",
@@ -12417,6 +12482,7 @@ multiclass VNNI_rmb<bits<8> Op, string OpStr, SDNode OpNode,
VTI.RC:$src2, VTI.RC:$src3)),
IsCommutable, IsCommutable>,
EVEX, VVVV, T8, Sched<[sched]>;
+ let mayLoad = 1 in {
defm rm : AVX512_maskable_3src<Op, MRMSrcMem, VTI, (outs VTI.RC:$dst),
(ins VTI.RC:$src2, VTI.MemOp:$src3), OpStr,
"$src3, $src2", "$src2, $src3",
@@ -12435,6 +12501,7 @@ multiclass VNNI_rmb<bits<8> Op, string OpStr, SDNode OpNode,
T8, Sched<[sched.Folded, sched.ReadAfterFold,
sched.ReadAfterFold]>;
}
+ }
}
multiclass VNNI_common<bits<8> Op, string OpStr, SDNode OpNode,
@@ -12508,6 +12575,7 @@ multiclass VPSHUFBITQMB_rm<X86FoldableSchedWrite sched, X86VectorVTInfo VTI> {
(X86Vpshufbitqmb_su (VTI.VT VTI.RC:$src1),
(VTI.VT VTI.RC:$src2))>, EVEX, VVVV, T8, PD,
Sched<[sched]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable_cmp<0x8F, MRMSrcMem, VTI, (outs VTI.KRC:$dst),
(ins VTI.RC:$src1, VTI.MemOp:$src2),
"vpshufbitqmb",
@@ -12557,7 +12625,7 @@ multiclass GF2P8AFFINE_avx512_rmb_imm<bits<8> Op, string OpStr, SDNode OpNode,
X86FoldableSchedWrite sched, X86VectorVTInfo VTI,
X86VectorVTInfo BcstVTI>
: avx512_3Op_rm_imm8<Op, OpStr, OpNode, sched, VTI, VTI> {
- let ExeDomain = VTI.ExeDomain in
+ let ExeDomain = VTI.ExeDomain, mayLoad = 1 in
defm rmbi : AVX512_maskable<Op, MRMSrcMem, VTI, (outs VTI.RC:$dst),
(ins VTI.RC:$src1, BcstVTI.ScalarMemOp:$src2, u8imm:$src3),
OpStr, "$src3, ${src2}"#BcstVTI.BroadcastStr#", $src1",
@@ -12660,6 +12728,7 @@ multiclass avx512_vp2intersect_modes<X86FoldableSchedWrite sched, X86VectorVTInf
_.RC:$src1, (_.VT _.RC:$src2)))]>,
EVEX, VVVV, T8, XD, Sched<[sched]>;
+ let mayLoad = 1 in {
def rm : I<0x68, MRMSrcMem,
(outs _.KRPC:$dst),
(ins _.RC:$src1, _.MemOp:$src2),
@@ -12679,6 +12748,7 @@ multiclass avx512_vp2intersect_modes<X86FoldableSchedWrite sched, X86VectorVTInf
_.RC:$src1, (_.VT (_.BroadcastLdFrag addr:$src2))))]>,
EVEX, VVVV, T8, XD, EVEX_B, EVEX_CD8<_.EltSize, CD8VF>,
Sched<[sched.Folded, sched.ReadAfterFold]>;
+ }
}
multiclass avx512_vp2intersect<X86SchedWriteWidths sched, AVX512VLVectorVTInfo _> {
@@ -12882,6 +12952,7 @@ let Predicates = [HasFP16] in {
// Move word ( r/m16) to Packed word
def VMOVW2SHrr : AVX512<0x6E, MRMSrcReg, (outs VR128X:$dst), (ins GR32:$src),
"vmovw\t{$src, $dst|$dst, $src}", []>, T_MAP5, PD, EVEX, Sched<[WriteVecMoveFromGpr]>;
+let mayLoad = 1 in
def VMOVWrm : AVX512<0x6E, MRMSrcMem, (outs VR128X:$dst), (ins i16mem:$src),
"vmovw\t{$src, $dst|$dst, $src}",
[(set VR128X:$dst,
@@ -13607,6 +13678,7 @@ multiclass avx512_cfmbinop_sh_common<bits<8> opc, string OpcodeStr, SDNode OpNod
(v4f32 (OpNode VR128X:$src1, VR128X:$src2)),
IsCommutable, IsCommutable, IsCommutable,
X86selects, "@earlyclobber $dst">, Sched<[WriteFMAX]>;
+ let mayLoad = 1 in
defm rm : AVX512_maskable<opc, MRMSrcMem, f32x_info, (outs VR128X:$dst),
(ins VR128X:$src1, ssmem:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
diff --git a/llvm/lib/TargetParser/TargetParser.cpp b/llvm/lib/TargetParser/TargetParser.cpp
index 34b09b1..62a3c88 100644
--- a/llvm/lib/TargetParser/TargetParser.cpp
+++ b/llvm/lib/TargetParser/TargetParser.cpp
@@ -444,6 +444,7 @@ static void fillAMDGCNFeatureMap(StringRef GPU, const Triple &T,
Features["atomic-fmin-fmax-global-f32"] = true;
Features["atomic-fmin-fmax-global-f64"] = true;
Features["wavefrontsize32"] = true;
+ Features["clusters"] = true;
break;
case GK_GFX1201:
case GK_GFX1200:
diff --git a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
index 8d9a0e7..50130da 100644
--- a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp
@@ -2067,6 +2067,36 @@ static void inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes,
AI.run(SCCNodes, Changed);
}
+// Determines whether any call within 'F' could lead to a recursive call back
+// to 'F', in which case 'F' cannot be marked 'norecurse'. Returns true if such
+// a call may exist, and false otherwise.
+// The 'AnyFunctionsAddressIsTaken' parameter is a module-wide flag that is
+// true if any function's address is taken, or if any function has external
+// linkage. This is used to determine whether external/library calls could
+// call back into this module.
+static bool mayHaveRecursiveCallee(Function &F,
+ bool AnyFunctionsAddressIsTaken = true) {
+ for (const auto &BB : F) {
+ for (const auto &I : BB.instructionsWithoutDebug()) {
+ if (const auto *CB = dyn_cast<CallBase>(&I)) {
+ const Function *Callee = CB->getCalledFunction();
+ if (!Callee || Callee == &F)
+ return true;
+
+ if (Callee->doesNotRecurse())
+ continue;
+
+ if (!AnyFunctionsAddressIsTaken ||
+ (Callee->isDeclaration() &&
+ Callee->hasFnAttribute(Attribute::NoCallback)))
+ continue;
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
static void addNoRecurseAttrs(const SCCNodeSet &SCCNodes,
SmallPtrSet<Function *, 8> &Changed) {
// Try and identify functions that do not recurse.
@@ -2078,28 +2108,14 @@ static void addNoRecurseAttrs(const SCCNodeSet &SCCNodes,
Function *F = *SCCNodes.begin();
if (!F || !F->hasExactDefinition() || F->doesNotRecurse())
return;
-
- // If all of the calls in F are identifiable and are to norecurse functions, F
- // is norecurse. This check also detects self-recursion as F is not currently
- // marked norecurse, so any called from F to F will not be marked norecurse.
- for (auto &BB : *F)
- for (auto &I : BB.instructionsWithoutDebug())
- if (auto *CB = dyn_cast<CallBase>(&I)) {
- Function *Callee = CB->getCalledFunction();
- if (!Callee || Callee == F ||
- (!Callee->doesNotRecurse() &&
- !(Callee->isDeclaration() &&
- Callee->hasFnAttribute(Attribute::NoCallback))))
- // Function calls a potentially recursive function.
- return;
- }
-
- // Every call was to a non-recursive function other than this function, and
- // we have no indirect recursion as the SCC size is one. This function cannot
- // recurse.
- F->setDoesNotRecurse();
- ++NumNoRecurse;
- Changed.insert(F);
+ if (!mayHaveRecursiveCallee(*F)) {
+ // Every call was to a non-recursive function other than this function, and
+ // we have no indirect recursion as the SCC size is one. This function
+ // cannot recurse.
+ F->setDoesNotRecurse();
+ ++NumNoRecurse;
+ Changed.insert(F);
+ }
}
// Set the noreturn function attribute if possible.
@@ -2429,3 +2445,62 @@ ReversePostOrderFunctionAttrsPass::run(Module &M, ModuleAnalysisManager &AM) {
PA.preserve<LazyCallGraphAnalysis>();
return PA;
}
+
+PreservedAnalyses NoRecurseLTOInferencePass::run(Module &M,
+ ModuleAnalysisManager &MAM) {
+
+  // Check whether any function in the whole program has its address taken or
+  // is potentially visible outside the module (external linkage).
+  // We use this information when inferring the norecurse attribute: if no
+  // function's address is taken and all functions have internal linkage,
+  // there is no path for an external callee to call back into a user function.
+ bool AnyFunctionsAddressIsTaken = false;
+ for (Function &F : M) {
+ if (F.isDeclaration() || F.doesNotRecurse())
+ continue;
+ if (!F.hasLocalLinkage() || F.hasAddressTaken()) {
+ AnyFunctionsAddressIsTaken = true;
+ break;
+ }
+ }
+
+ // Run norecurse inference on all RefSCCs in the LazyCallGraph for this
+ // module.
+ bool Changed = false;
+ LazyCallGraph &CG = MAM.getResult<LazyCallGraphAnalysis>(M);
+ CG.buildRefSCCs();
+
+ for (LazyCallGraph::RefSCC &RC : CG.postorder_ref_sccs()) {
+    // Skip any RefSCC that may be part of a call cycle. A RefSCC containing
+    // more than one SCC indicates a potentially recursive relationship
+    // involving indirect calls or function references.
+ if (RC.size() > 1)
+ continue;
+
+    // The RefSCC contains a single SCC. An SCC of size > 1 indicates mutually
+    // recursive functions, e.g. foo1 -> foo2 -> foo3 -> foo1.
+ LazyCallGraph::SCC &S = *RC.begin();
+ if (S.size() > 1)
+ continue;
+
+ // Get the single function from this SCC.
+ Function &F = S.begin()->getFunction();
+ if (!F.hasExactDefinition() || F.doesNotRecurse())
+ continue;
+
+    // If the analysis confirms that this function has no potentially
+    // recursive calls (direct, indirect, or through externally visible
+    // functions), we can safely apply the norecurse attribute.
+ if (!mayHaveRecursiveCallee(F, AnyFunctionsAddressIsTaken)) {
+ F.setDoesNotRecurse();
+ ++NumNoRecurse;
+ Changed = true;
+ }
+ }
+
+ PreservedAnalyses PA;
+ if (Changed)
+ PA.preserve<LazyCallGraphAnalysis>();
+ else
+ PA = PreservedAnalyses::all();
+ return PA;
+}
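
As a rough illustration of the whole-program reasoning above, consider a deliberately artificial translation unit in which every defined function has internal linkage and no function's address is taken, so AnyFunctionsAddressIsTaken stays false. The code and the conclusions in its comments are illustrative assumptions, not output of the pass; whether inference actually fires also depends on the SCC structure checked above.

#include <cstdio>

// Both functions have internal linkage and their addresses are never taken,
// so no code outside this module can call back into them.
static int leaf() { return 42; }

// The direct call to the external declaration printf does not block norecurse
// inference here, because with no address-taken or externally visible
// definitions in the module there is no path from printf back into wrapper or
// leaf. With AnyFunctionsAddressIsTaken == true, the same call would be
// treated as potentially recursive.
static int wrapper() { return leaf() + std::printf("hello\n"); }
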
diff --git a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
index faeab95..cfdfd94 100644
--- a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
+++ b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
@@ -3986,6 +3986,7 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones(
void ModuleCallsiteContextGraph::updateAllocationCall(
CallInfo &Call, AllocationType AllocType) {
std::string AllocTypeString = getAllocTypeAttributeString(AllocType);
+ removeAnyExistingAmbiguousAttribute(cast<CallBase>(Call.call()));
auto A = llvm::Attribute::get(Call.call()->getFunction()->getContext(),
"memprof", AllocTypeString);
cast<CallBase>(Call.call())->addFnAttr(A);
@@ -5661,9 +5662,10 @@ bool MemProfContextDisambiguation::applyImport(Module &M) {
auto *MemProfMD = I.getMetadata(LLVMContext::MD_memprof);
// Include allocs that were already assigned a memprof function
- // attribute in the statistics.
- if (CB->getAttributes().hasFnAttr("memprof")) {
- assert(!MemProfMD);
+ // attribute in the statistics. Only do this for those that do not have
+ // memprof metadata, since we add an "ambiguous" memprof attribute by
+ // default.
+ if (CB->getAttributes().hasFnAttr("memprof") && !MemProfMD) {
CB->getAttributes().getFnAttr("memprof").getValueAsString() == "cold"
? AllocTypeColdThinBackend++
: AllocTypeNotColdThinBackend++;
@@ -5740,6 +5742,7 @@ bool MemProfContextDisambiguation::applyImport(Module &M) {
// clone J-1 (J==0 is the original clone and does not have a VMaps
// entry).
CBClone = cast<CallBase>((*VMaps[J - 1])[CB]);
+ removeAnyExistingAmbiguousAttribute(CBClone);
CBClone->addFnAttr(A);
ORE.emit(OptimizationRemark(DEBUG_TYPE, "MemprofAttribute", CBClone)
<< ore::NV("AllocationCall", CBClone) << " in clone "
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 8f60e50..8c8fc69 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3356,7 +3356,10 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
impliesPoisonOrCond(FalseVal, B, /*Expected=*/false)) {
// (A || B) || C --> A || (B | C)
return replaceInstUsesWith(
- SI, Builder.CreateLogicalOr(A, Builder.CreateOr(B, FalseVal)));
+ SI, Builder.CreateLogicalOr(A, Builder.CreateOr(B, FalseVal), "",
+ ProfcheckDisableMetadataFixes
+ ? nullptr
+ : cast<SelectInst>(CondVal)));
}
// (A && B) || (C && B) --> (A || C) && B
@@ -3398,7 +3401,10 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
impliesPoisonOrCond(TrueVal, B, /*Expected=*/true)) {
// (A && B) && C --> A && (B & C)
return replaceInstUsesWith(
- SI, Builder.CreateLogicalAnd(A, Builder.CreateAnd(B, TrueVal)));
+ SI, Builder.CreateLogicalAnd(A, Builder.CreateAnd(B, TrueVal), "",
+ ProfcheckDisableMetadataFixes
+ ? nullptr
+ : cast<SelectInst>(CondVal)));
}
// (A || B) && (C || B) --> (A && C) || B
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
index af216cd..9693ae6 100644
--- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp
+++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
@@ -317,24 +317,29 @@ static Value *simplifyInstruction(SCCPSolver &Solver,
// Early exit if we know nothing about X.
if (LRange.isFullSet())
return nullptr;
- // We are allowed to refine the comparison to either true or false for out
- // of range inputs. Here we refine the comparison to true, i.e. we relax
- // the range check.
- auto NewCR = CR->exactUnionWith(LRange.inverse());
- // TODO: Check if we can narrow the range check to an equality test.
- // E.g, for X in [0, 4), X - 3 u< 2 -> X == 3
- if (!NewCR)
+ auto ConvertCRToICmp =
+ [&](const std::optional<ConstantRange> &NewCR) -> Value * {
+ ICmpInst::Predicate Pred;
+ APInt RHS;
+ // Check if we can represent NewCR as an icmp predicate.
+ if (NewCR && NewCR->getEquivalentICmp(Pred, RHS)) {
+ IRBuilder<NoFolder> Builder(&Inst);
+ Value *NewICmp =
+ Builder.CreateICmp(Pred, X, ConstantInt::get(X->getType(), RHS));
+ InsertedValues.insert(NewICmp);
+ return NewICmp;
+ }
return nullptr;
- ICmpInst::Predicate Pred;
- APInt RHS;
- // Check if we can represent NewCR as an icmp predicate.
- if (NewCR->getEquivalentICmp(Pred, RHS)) {
- IRBuilder<NoFolder> Builder(&Inst);
- Value *NewICmp =
- Builder.CreateICmp(Pred, X, ConstantInt::get(X->getType(), RHS));
- InsertedValues.insert(NewICmp);
- return NewICmp;
- }
+ };
+ // We are allowed to refine the comparison to either true or false for out
+ // of range inputs.
+ // Here we refine the comparison to false, and check if we can narrow the
+ // range check to a simpler test.
+ if (auto *V = ConvertCRToICmp(CR->exactIntersectWith(LRange)))
+ return V;
+ // Here we refine the comparison to true, i.e. we relax the range check.
+ if (auto *V = ConvertCRToICmp(CR->exactUnionWith(LRange.inverse())))
+ return V;
}
}
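
A worked example of the two refinements, reusing the case from the TODO that the previous code carried (X known to be in [0, 4) and the range check `X - 3 u< 2`). The sketch below is a standalone illustration, not part of the solver, and uses only ConstantRange APIs that already appear in this patch.

#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include <optional>
using namespace llvm;

void rangeRefinementExample() {
  // X is known to lie in [0, 4); the check `X - 3 u< 2` holds exactly when
  // X is in [3, 5).
  ConstantRange LRange(APInt(32, 0), APInt(32, 4));
  ConstantRange CR(APInt(32, 3), APInt(32, 5));

  // Refine to false for impossible inputs: [3, 5) intersect [0, 4) = [3, 4),
  // for which getEquivalentICmp yields the narrower test X == 3.
  std::optional<ConstantRange> Narrowed = CR.exactIntersectWith(LRange);

  // Refine to true for impossible inputs: [3, 5) union [4, 0) = [3, 0)
  // (wrapping), i.e. X u>= 3, the relaxed check the code previously emitted.
  std::optional<ConstantRange> Relaxed = CR.exactUnionWith(LRange.inverse());
  (void)Narrowed;
  (void)Relaxed;
}
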
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 48055ad..b8cfe3a 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -4895,9 +4895,8 @@ bool SimplifyCFGOpt::simplifyTerminatorOnSelect(Instruction *OldTerm,
// We found both of the successors we were looking for.
// Create a conditional branch sharing the condition of the select.
BranchInst *NewBI = Builder.CreateCondBr(Cond, TrueBB, FalseBB);
- if (TrueWeight != FalseWeight)
- setBranchWeights(*NewBI, {TrueWeight, FalseWeight},
- /*IsExpected=*/false, /*ElideAllZero=*/true);
+ setBranchWeights(*NewBI, {TrueWeight, FalseWeight},
+ /*IsExpected=*/false, /*ElideAllZero=*/true);
}
} else if (KeepEdge1 && (KeepEdge2 || TrueBB == FalseBB)) {
// Neither of the selected blocks were successors, so this
@@ -4982,9 +4981,15 @@ bool SimplifyCFGOpt::simplifyIndirectBrOnSelect(IndirectBrInst *IBI,
BasicBlock *TrueBB = TBA->getBasicBlock();
BasicBlock *FalseBB = FBA->getBasicBlock();
+ // The select's profile becomes the profile of the conditional branch that
+ // replaces the indirect branch.
+ SmallVector<uint32_t> SelectBranchWeights(2);
+ if (!ProfcheckDisableMetadataFixes)
+ extractBranchWeights(*SI, SelectBranchWeights);
// Perform the actual simplification.
- return simplifyTerminatorOnSelect(IBI, SI->getCondition(), TrueBB, FalseBB, 0,
- 0);
+ return simplifyTerminatorOnSelect(IBI, SI->getCondition(), TrueBB, FalseBB,
+ SelectBranchWeights[0],
+ SelectBranchWeights[1]);
}
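
The weights pulled off the select come from the existing extractBranchWeights helper in llvm/IR/ProfDataUtils.h. A standalone sketch of the idiom (the wrapper function is illustrative); leaving the vector zero-initialized is what lets the earlier hunk drop the TrueWeight != FalseWeight guard and rely on ElideAllZero instead:

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/ProfDataUtils.h"
#include <cstdint>
#include <utility>
using namespace llvm;

static std::pair<uint32_t, uint32_t> selectBranchWeights(const SelectInst &SI) {
  // Zero-initialized, so a select without !prof yields {0, 0}; the
  // ElideAllZero path in the local setBranchWeights helper then attaches no
  // metadata rather than meaningless all-zero weights.
  SmallVector<uint32_t> Weights(2, 0);
  extractBranchWeights(SI, Weights);
  return {Weights[0], Weights[1]};
}
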
/// This is called when we find an icmp instruction
@@ -5734,15 +5739,66 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) {
return Changed;
}
-static bool casesAreContiguous(SmallVectorImpl<ConstantInt *> &Cases) {
+struct ContiguousCasesResult {
+ ConstantInt *Min;
+ ConstantInt *Max;
+ BasicBlock *Dest;
+ BasicBlock *OtherDest;
+ SmallVectorImpl<ConstantInt *> *Cases;
+ SmallVectorImpl<ConstantInt *> *OtherCases;
+};
+
+static std::optional<ContiguousCasesResult>
+findContiguousCases(Value *Condition, SmallVectorImpl<ConstantInt *> &Cases,
+ SmallVectorImpl<ConstantInt *> &OtherCases,
+ BasicBlock *Dest, BasicBlock *OtherDest) {
assert(Cases.size() >= 1);
array_pod_sort(Cases.begin(), Cases.end(), constantIntSortPredicate);
- for (size_t I = 1, E = Cases.size(); I != E; ++I) {
- if (Cases[I - 1]->getValue() != Cases[I]->getValue() + 1)
- return false;
+ const APInt &Min = Cases.back()->getValue();
+ const APInt &Max = Cases.front()->getValue();
+ APInt Offset = Max - Min;
+ size_t ContiguousOffset = Cases.size() - 1;
+ if (Offset == ContiguousOffset) {
+ return ContiguousCasesResult{
+ /*Min=*/Cases.back(),
+ /*Max=*/Cases.front(),
+ /*Dest=*/Dest,
+ /*OtherDest=*/OtherDest,
+ /*Cases=*/&Cases,
+ /*OtherCases=*/&OtherCases,
+ };
}
- return true;
+ ConstantRange CR = computeConstantRange(Condition, /*ForSigned=*/false);
+  // If the cases form a wrapping contiguous range [Min, OtherMin] + [OtherMax,
+  // Max] (i.e. the wrapped range [OtherMax, OtherMin]), then [OtherMin+1,
+  // OtherMax-1] is a contiguous range for the other destination. N.B. If CR is
+  // not a full range, Max+1 != Min, so the set is not arithmetically contiguous.
+ if (Max == CR.getUnsignedMax() && Min == CR.getUnsignedMin()) {
+ assert(Cases.size() >= 2);
+ auto *It =
+ std::adjacent_find(Cases.begin(), Cases.end(), [](auto L, auto R) {
+ return L->getValue() != R->getValue() + 1;
+ });
+ if (It == Cases.end())
+ return std::nullopt;
+ auto [OtherMax, OtherMin] = std::make_pair(*It, *std::next(It));
+ if ((Max - OtherMax->getValue()) + (OtherMin->getValue() - Min) ==
+ Cases.size() - 2) {
+ return ContiguousCasesResult{
+ /*Min=*/cast<ConstantInt>(
+ ConstantInt::get(OtherMin->getType(), OtherMin->getValue() + 1)),
+ /*Max=*/
+ cast<ConstantInt>(
+ ConstantInt::get(OtherMax->getType(), OtherMax->getValue() - 1)),
+ /*Dest=*/OtherDest,
+ /*OtherDest=*/Dest,
+ /*Cases=*/&OtherCases,
+ /*OtherCases=*/&Cases,
+ };
+ }
+ }
+ return std::nullopt;
}
static void createUnreachableSwitchDefault(SwitchInst *Switch,
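
A worked example of the wrapped-range case may help (the concrete numbers are hypothetical, not taken from the patch's tests): if the switch condition is known to lie in [0, 8) and cases {0, 1, 7} branch to DestA while {2, ..., 6} branch to DestB, the DestA cases touch both ends of the known range, so their complement [2, 6] is contiguous and the switch collapses to a single range check that selects DestB:

#include <cassert>
#include <cstdint>

int main() {
  // Condition range [0, 8) and the DestA cases {7, 1, 0}, sorted descending as
  // findContiguousCases sorts them.
  const uint64_t CRMin = 0, CRMax = 7;
  const uint64_t Cases[] = {7, 1, 0};
  const uint64_t Max = Cases[0], Min = Cases[2];
  // Not contiguous on their own: Max - Min (7) != number of cases - 1 (2).
  assert(Max - Min != sizeof(Cases) / sizeof(Cases[0]) - 1);
  // They do reach both ends of the known range, so look for the gap instead.
  assert(Max == CRMax && Min == CRMin);
  const uint64_t OtherMax = Cases[0], OtherMin = Cases[1]; // gap between 7 and 1
  const uint64_t GapLo = OtherMin + 1, GapHi = OtherMax - 1; // [2, 6]
  // turnSwitchRangeIntoICmp then emits: icmp ult (X - 2), 5 -> DestB, else DestA.
  assert(GapLo == 2 && GapHi == 6 && GapHi - GapLo + 1 == 5);
  return 0;
}
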
@@ -5779,7 +5835,6 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
bool HasDefault = !SI->defaultDestUnreachable();
auto *BB = SI->getParent();
-
// Partition the cases into two sets with different destinations.
BasicBlock *DestA = HasDefault ? SI->getDefaultDest() : nullptr;
BasicBlock *DestB = nullptr;
@@ -5813,37 +5868,62 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
assert(!CasesA.empty() || HasDefault);
// Figure out if one of the sets of cases form a contiguous range.
- SmallVectorImpl<ConstantInt *> *ContiguousCases = nullptr;
- BasicBlock *ContiguousDest = nullptr;
- BasicBlock *OtherDest = nullptr;
- if (!CasesA.empty() && casesAreContiguous(CasesA)) {
- ContiguousCases = &CasesA;
- ContiguousDest = DestA;
- OtherDest = DestB;
- } else if (casesAreContiguous(CasesB)) {
- ContiguousCases = &CasesB;
- ContiguousDest = DestB;
- OtherDest = DestA;
- } else
- return false;
+ std::optional<ContiguousCasesResult> ContiguousCases;
+
+ // Only one icmp is needed when there is only one case.
+ if (!HasDefault && CasesA.size() == 1)
+ ContiguousCases = ContiguousCasesResult{
+ /*Min=*/CasesA[0],
+ /*Max=*/CasesA[0],
+ /*Dest=*/DestA,
+ /*OtherDest=*/DestB,
+ /*Cases=*/&CasesA,
+ /*OtherCases=*/&CasesB,
+ };
+ else if (CasesB.size() == 1)
+ ContiguousCases = ContiguousCasesResult{
+ /*Min=*/CasesB[0],
+ /*Max=*/CasesB[0],
+ /*Dest=*/DestB,
+ /*OtherDest=*/DestA,
+ /*Cases=*/&CasesB,
+ /*OtherCases=*/&CasesA,
+ };
+ // Correctness: Cases to the default destination cannot be contiguous cases.
+ else if (!HasDefault)
+ ContiguousCases =
+ findContiguousCases(SI->getCondition(), CasesA, CasesB, DestA, DestB);
- // Start building the compare and branch.
+ if (!ContiguousCases)
+ ContiguousCases =
+ findContiguousCases(SI->getCondition(), CasesB, CasesA, DestB, DestA);
+
+ if (!ContiguousCases)
+ return false;
- Constant *Offset = ConstantExpr::getNeg(ContiguousCases->back());
- Constant *NumCases =
- ConstantInt::get(Offset->getType(), ContiguousCases->size());
+ auto [Min, Max, Dest, OtherDest, Cases, OtherCases] = *ContiguousCases;
- Value *Sub = SI->getCondition();
- if (!Offset->isNullValue())
- Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off");
+ // Start building the compare and branch.
- Value *Cmp;
+ Constant *Offset = ConstantExpr::getNeg(Min);
+ Constant *NumCases = ConstantInt::get(Offset->getType(),
+ Max->getValue() - Min->getValue() + 1);
+ BranchInst *NewBI;
+ if (NumCases->isOneValue()) {
+ assert(Max->getValue() == Min->getValue());
+ Value *Cmp = Builder.CreateICmpEQ(SI->getCondition(), Min);
+ NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest);
+ }
// If NumCases overflowed, then all possible values jump to the successor.
- if (NumCases->isNullValue() && !ContiguousCases->empty())
- Cmp = ConstantInt::getTrue(SI->getContext());
- else
- Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch");
- BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);
+ else if (NumCases->isNullValue() && !Cases->empty()) {
+ NewBI = Builder.CreateBr(Dest);
+ } else {
+ Value *Sub = SI->getCondition();
+ if (!Offset->isNullValue())
+ Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off");
+ Value *Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch");
+ NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest);
+ }
// Update weight for the newly-created conditional branch.
if (hasBranchWeightMD(*SI)) {
@@ -5853,7 +5933,7 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
uint64_t TrueWeight = 0;
uint64_t FalseWeight = 0;
for (size_t I = 0, E = Weights.size(); I != E; ++I) {
- if (SI->getSuccessor(I) == ContiguousDest)
+ if (SI->getSuccessor(I) == Dest)
TrueWeight += Weights[I];
else
FalseWeight += Weights[I];
@@ -5868,15 +5948,15 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
}
// Prune obsolete incoming values off the successors' PHI nodes.
- for (auto BBI = ContiguousDest->begin(); isa<PHINode>(BBI); ++BBI) {
- unsigned PreviousEdges = ContiguousCases->size();
- if (ContiguousDest == SI->getDefaultDest())
+ for (auto BBI = Dest->begin(); isa<PHINode>(BBI); ++BBI) {
+ unsigned PreviousEdges = Cases->size();
+ if (Dest == SI->getDefaultDest())
++PreviousEdges;
for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I)
cast<PHINode>(BBI)->removeIncomingValue(SI->getParent());
}
for (auto BBI = OtherDest->begin(); isa<PHINode>(BBI); ++BBI) {
- unsigned PreviousEdges = SI->getNumCases() - ContiguousCases->size();
+ unsigned PreviousEdges = OtherCases->size();
if (OtherDest == SI->getDefaultDest())
++PreviousEdges;
for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I)
@@ -7877,19 +7957,27 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
bool SimplifyCFGOpt::simplifyIndirectBr(IndirectBrInst *IBI) {
BasicBlock *BB = IBI->getParent();
bool Changed = false;
+ SmallVector<uint32_t> BranchWeights;
+ const bool HasBranchWeights = !ProfcheckDisableMetadataFixes &&
+ extractBranchWeights(*IBI, BranchWeights);
+
+ DenseMap<const BasicBlock *, uint64_t> TargetWeight;
+ if (HasBranchWeights)
+ for (size_t I = 0, E = IBI->getNumDestinations(); I < E; ++I)
+ TargetWeight[IBI->getDestination(I)] += BranchWeights[I];
// Eliminate redundant destinations.
SmallPtrSet<Value *, 8> Succs;
SmallSetVector<BasicBlock *, 8> RemovedSuccs;
- for (unsigned i = 0, e = IBI->getNumDestinations(); i != e; ++i) {
- BasicBlock *Dest = IBI->getDestination(i);
+ for (unsigned I = 0, E = IBI->getNumDestinations(); I != E; ++I) {
+ BasicBlock *Dest = IBI->getDestination(I);
if (!Dest->hasAddressTaken() || !Succs.insert(Dest).second) {
if (!Dest->hasAddressTaken())
RemovedSuccs.insert(Dest);
Dest->removePredecessor(BB);
- IBI->removeDestination(i);
- --i;
- --e;
+ IBI->removeDestination(I);
+ --I;
+ --E;
Changed = true;
}
}
@@ -7915,7 +8003,12 @@ bool SimplifyCFGOpt::simplifyIndirectBr(IndirectBrInst *IBI) {
eraseTerminatorAndDCECond(IBI);
return true;
}
-
+ if (HasBranchWeights) {
+ SmallVector<uint64_t> NewBranchWeights(IBI->getNumDestinations());
+ for (size_t I = 0, E = IBI->getNumDestinations(); I < E; ++I)
+ NewBranchWeights[I] += TargetWeight.find(IBI->getDestination(I))->second;
+ setFittedBranchWeights(*IBI, NewBranchWeights, /*IsExpected=*/false);
+ }
if (SelectInst *SI = dyn_cast<SelectInst>(IBI->getAddress())) {
if (simplifyIndirectBrOnSelect(IBI, SI))
return requestResimplify();
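
The weight bookkeeping above can be pictured with a small standalone sketch (plain STL containers stand in for the LLVM types, and the block names are made up): weights are accumulated per target before duplicates are dropped, then re-emitted for the surviving destination list.

#include <cstdint>
#include <map>
#include <string>
#include <vector>

int main() {
  // An indirectbr recorded these destinations (with a duplicate) and weights.
  std::vector<std::string> Dests = {"bb1", "bb2", "bb1"};
  std::vector<uint64_t> Weights = {10, 20, 5};

  // Mirror of the TargetWeight map: accumulate per destination block.
  std::map<std::string, uint64_t> TargetWeight;
  for (size_t I = 0; I < Dests.size(); ++I)
    TargetWeight[Dests[I]] += Weights[I];

  // After the duplicate "bb1" is removed, the new weights are looked up per
  // surviving destination: {15, 20}.
  std::vector<std::string> Survivors = {"bb1", "bb2"};
  std::vector<uint64_t> NewWeights;
  for (const std::string &D : Survivors)
    NewWeights.push_back(TargetWeight[D]);
  return (NewWeights[0] == 15 && NewWeights[1] == 20) ? 0 : 1;
}
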
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 56a3d6d..d393a9c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8201,211 +8201,6 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
}
}
-/// Create and return a ResumePhi for \p WideIV, unless it is truncated. If the
-/// induction recipe is not canonical, creates a VPDerivedIVRecipe to compute
-/// the end value of the induction.
-static VPInstruction *addResumePhiRecipeForInduction(
- VPWidenInductionRecipe *WideIV, VPBuilder &VectorPHBuilder,
- VPBuilder &ScalarPHBuilder, VPTypeAnalysis &TypeInfo, VPValue *VectorTC) {
- auto *WideIntOrFp = dyn_cast<VPWidenIntOrFpInductionRecipe>(WideIV);
- // Truncated wide inductions resume from the last lane of their vector value
- // in the last vector iteration which is handled elsewhere.
- if (WideIntOrFp && WideIntOrFp->getTruncInst())
- return nullptr;
-
- VPValue *Start = WideIV->getStartValue();
- VPValue *Step = WideIV->getStepValue();
- const InductionDescriptor &ID = WideIV->getInductionDescriptor();
- VPValue *EndValue = VectorTC;
- if (!WideIntOrFp || !WideIntOrFp->isCanonical()) {
- EndValue = VectorPHBuilder.createDerivedIV(
- ID.getKind(), dyn_cast_or_null<FPMathOperator>(ID.getInductionBinOp()),
- Start, VectorTC, Step);
- }
-
- // EndValue is derived from the vector trip count (which has the same type as
- // the widest induction) and thus may be wider than the induction here.
- Type *ScalarTypeOfWideIV = TypeInfo.inferScalarType(WideIV);
- if (ScalarTypeOfWideIV != TypeInfo.inferScalarType(EndValue)) {
- EndValue = VectorPHBuilder.createScalarCast(Instruction::Trunc, EndValue,
- ScalarTypeOfWideIV,
- WideIV->getDebugLoc());
- }
-
- auto *ResumePhiRecipe = ScalarPHBuilder.createScalarPhi(
- {EndValue, Start}, WideIV->getDebugLoc(), "bc.resume.val");
- return ResumePhiRecipe;
-}
-
-/// Create resume phis in the scalar preheader for first-order recurrences,
-/// reductions and inductions, and update the VPIRInstructions wrapping the
-/// original phis in the scalar header. End values for inductions are added to
-/// \p IVEndValues.
-static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan,
- DenseMap<VPValue *, VPValue *> &IVEndValues) {
- VPTypeAnalysis TypeInfo(Plan);
- auto *ScalarPH = Plan.getScalarPreheader();
- auto *MiddleVPBB = cast<VPBasicBlock>(ScalarPH->getPredecessors()[0]);
- VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion();
- VPBuilder VectorPHBuilder(
- cast<VPBasicBlock>(VectorRegion->getSinglePredecessor()));
- VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi());
- VPBuilder ScalarPHBuilder(ScalarPH);
- for (VPRecipeBase &ScalarPhiR : Plan.getScalarHeader()->phis()) {
- auto *ScalarPhiIRI = cast<VPIRPhi>(&ScalarPhiR);
-
- // TODO: Extract final value from induction recipe initially, optimize to
- // pre-computed end value together in optimizeInductionExitUsers.
- auto *VectorPhiR =
- cast<VPHeaderPHIRecipe>(Builder.getRecipe(&ScalarPhiIRI->getIRPhi()));
- if (auto *WideIVR = dyn_cast<VPWidenInductionRecipe>(VectorPhiR)) {
- if (VPInstruction *ResumePhi = addResumePhiRecipeForInduction(
- WideIVR, VectorPHBuilder, ScalarPHBuilder, TypeInfo,
- &Plan.getVectorTripCount())) {
- assert(isa<VPPhi>(ResumePhi) && "Expected a phi");
- IVEndValues[WideIVR] = ResumePhi->getOperand(0);
- ScalarPhiIRI->addOperand(ResumePhi);
- continue;
- }
- // TODO: Also handle truncated inductions here. Computing end-values
- // separately should be done as VPlan-to-VPlan optimization, after
- // legalizing all resume values to use the last lane from the loop.
- assert(cast<VPWidenIntOrFpInductionRecipe>(VectorPhiR)->getTruncInst() &&
- "should only skip truncated wide inductions");
- continue;
- }
-
- // The backedge value provides the value to resume coming out of a loop,
- // which for FORs is a vector whose last element needs to be extracted. The
- // start value provides the value if the loop is bypassed.
- bool IsFOR = isa<VPFirstOrderRecurrencePHIRecipe>(VectorPhiR);
- auto *ResumeFromVectorLoop = VectorPhiR->getBackedgeValue();
- assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() &&
- "Cannot handle loops with uncountable early exits");
- if (IsFOR)
- ResumeFromVectorLoop = MiddleBuilder.createNaryOp(
- VPInstruction::ExtractLastElement, {ResumeFromVectorLoop}, {},
- "vector.recur.extract");
- StringRef Name = IsFOR ? "scalar.recur.init" : "bc.merge.rdx";
- auto *ResumePhiR = ScalarPHBuilder.createScalarPhi(
- {ResumeFromVectorLoop, VectorPhiR->getStartValue()}, {}, Name);
- ScalarPhiIRI->addOperand(ResumePhiR);
- }
-}
-
-/// Handle users in the exit block for first order reductions in the original
-/// exit block. The penultimate value of recurrences is fed to their LCSSA phi
-/// users in the original exit block using the VPIRInstruction wrapping to the
-/// LCSSA phi.
-static void addExitUsersForFirstOrderRecurrences(VPlan &Plan, VFRange &Range) {
- VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion();
- auto *ScalarPHVPBB = Plan.getScalarPreheader();
- auto *MiddleVPBB = Plan.getMiddleBlock();
- VPBuilder ScalarPHBuilder(ScalarPHVPBB);
- VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi());
-
- auto IsScalableOne = [](ElementCount VF) -> bool {
- return VF == ElementCount::getScalable(1);
- };
-
- for (auto &HeaderPhi : VectorRegion->getEntryBasicBlock()->phis()) {
- auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&HeaderPhi);
- if (!FOR)
- continue;
-
- assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() &&
- "Cannot handle loops with uncountable early exits");
-
- // This is the second phase of vectorizing first-order recurrences, creating
- // extract for users outside the loop. An overview of the transformation is
- // described below. Suppose we have the following loop with some use after
- // the loop of the last a[i-1],
- //
- // for (int i = 0; i < n; ++i) {
- // t = a[i - 1];
- // b[i] = a[i] - t;
- // }
- // use t;
- //
- // There is a first-order recurrence on "a". For this loop, the shorthand
- // scalar IR looks like:
- //
- // scalar.ph:
- // s.init = a[-1]
- // br scalar.body
- //
- // scalar.body:
- // i = phi [0, scalar.ph], [i+1, scalar.body]
- // s1 = phi [s.init, scalar.ph], [s2, scalar.body]
- // s2 = a[i]
- // b[i] = s2 - s1
- // br cond, scalar.body, exit.block
- //
- // exit.block:
- // use = lcssa.phi [s1, scalar.body]
- //
- // In this example, s1 is a recurrence because it's value depends on the
- // previous iteration. In the first phase of vectorization, we created a
- // VPFirstOrderRecurrencePHIRecipe v1 for s1. Now we create the extracts
- // for users in the scalar preheader and exit block.
- //
- // vector.ph:
- // v_init = vector(..., ..., ..., a[-1])
- // br vector.body
- //
- // vector.body
- // i = phi [0, vector.ph], [i+4, vector.body]
- // v1 = phi [v_init, vector.ph], [v2, vector.body]
- // v2 = a[i, i+1, i+2, i+3]
- // b[i] = v2 - v1
- // // Next, third phase will introduce v1' = splice(v1(3), v2(0, 1, 2))
- // b[i, i+1, i+2, i+3] = v2 - v1
- // br cond, vector.body, middle.block
- //
- // middle.block:
- // vector.recur.extract.for.phi = v2(2)
- // vector.recur.extract = v2(3)
- // br cond, scalar.ph, exit.block
- //
- // scalar.ph:
- // scalar.recur.init = phi [vector.recur.extract, middle.block],
- // [s.init, otherwise]
- // br scalar.body
- //
- // scalar.body:
- // i = phi [0, scalar.ph], [i+1, scalar.body]
- // s1 = phi [scalar.recur.init, scalar.ph], [s2, scalar.body]
- // s2 = a[i]
- // b[i] = s2 - s1
- // br cond, scalar.body, exit.block
- //
- // exit.block:
- // lo = lcssa.phi [s1, scalar.body],
- // [vector.recur.extract.for.phi, middle.block]
- //
- // Now update VPIRInstructions modeling LCSSA phis in the exit block.
- // Extract the penultimate value of the recurrence and use it as operand for
- // the VPIRInstruction modeling the phi.
- for (VPUser *U : FOR->users()) {
- using namespace llvm::VPlanPatternMatch;
- if (!match(U, m_ExtractLastElement(m_Specific(FOR))))
- continue;
- // For VF vscale x 1, if vscale = 1, we are unable to extract the
- // penultimate value of the recurrence. Instead we rely on the existing
- // extract of the last element from the result of
- // VPInstruction::FirstOrderRecurrenceSplice.
- // TODO: Consider vscale_range info and UF.
- if (LoopVectorizationPlanner::getDecisionAndClampRange(IsScalableOne,
- Range))
- return;
- VPValue *PenultimateElement = MiddleBuilder.createNaryOp(
- VPInstruction::ExtractPenultimateElement, {FOR->getBackedgeValue()},
- {}, "vector.recur.extract.for.phi");
- cast<VPInstruction>(U)->replaceAllUsesWith(PenultimateElement);
- }
- }
-}
-
VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
VPlanPtr Plan, VFRange &Range, LoopVersioning *LVer) {
@@ -8598,9 +8393,11 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
R->setOperand(1, WideIV->getStepValue());
}
- addExitUsersForFirstOrderRecurrences(*Plan, Range);
+ // TODO: We can't call runPass on these transforms yet, due to verifier
+ // failures.
+ VPlanTransforms::addExitUsersForFirstOrderRecurrences(*Plan, Range);
DenseMap<VPValue *, VPValue *> IVEndValues;
- addScalarResumePhis(RecipeBuilder, *Plan, IVEndValues);
+ VPlanTransforms::addScalarResumePhis(*Plan, RecipeBuilder, IVEndValues);
// ---------------------------------------------------------------------------
// Transform initial VPlan: Apply previously taken decisions, in order, to
@@ -8711,7 +8508,9 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlan(VFRange &Range) {
DenseMap<VPValue *, VPValue *> IVEndValues;
// TODO: IVEndValues are not used yet in the native path, to optimize exit
// values.
- addScalarResumePhis(RecipeBuilder, *Plan, IVEndValues);
+ // TODO: We can't call runPass on the transform yet, due to verifier
+ // failures.
+ VPlanTransforms::addScalarResumePhis(*Plan, RecipeBuilder, IVEndValues);
assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid");
return Plan;
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index fedca65..91c3d42 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -10620,7 +10620,8 @@ class InstructionsCompatibilityAnalysis {
/// Checks if the opcode is supported as the main opcode for copyable
/// elements.
static bool isSupportedOpcode(const unsigned Opcode) {
- return Opcode == Instruction::Add || Opcode == Instruction::LShr;
+ return Opcode == Instruction::Add || Opcode == Instruction::LShr ||
+ Opcode == Instruction::Shl;
}
/// Identifies the best candidate value, which represents main opcode
@@ -10937,6 +10938,7 @@ public:
switch (MainOpcode) {
case Instruction::Add:
case Instruction::LShr:
+ case Instruction::Shl:
VectorCost = TTI.getArithmeticInstrCost(MainOpcode, VecTy, Kind);
break;
default:
@@ -22006,6 +22008,8 @@ bool BoUpSLP::collectValuesToDemote(
return all_of(E.Scalars, [&](Value *V) {
if (isa<PoisonValue>(V))
return true;
+ if (E.isCopyableElement(V))
+ return true;
auto *I = cast<Instruction>(V);
KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
return AmtKnownBits.getMaxValue().ult(BitWidth);
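
For context on the check that copyable elements now bypass, this is the soundness condition the non-copyable path keeps enforcing; a standalone sketch using the KnownBits helper from the hunk (the wrapper function is illustrative):

#include "llvm/ADT/APInt.h"
#include "llvm/Support/KnownBits.h"
using namespace llvm;

// Narrowing a shift to BitWidth bits is only sound when the shift amount is
// provably smaller than BitWidth; e.g. an amount known to be at most 7 allows
// demotion to i8, while one that may reach 8 or more does not.
static bool shiftAmtFitsAfterDemote(const KnownBits &AmtKnownBits,
                                    unsigned BitWidth) {
  return AmtKnownBits.getMaxValue().ult(BitWidth);
}
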
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index ca63bf3..ebf833e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -4198,3 +4198,202 @@ void VPlanTransforms::addBranchWeightToMiddleTerminator(
MDB.createBranchWeights({1, VectorStep - 1}, /*IsExpected=*/false);
MiddleTerm->addMetadata(LLVMContext::MD_prof, BranchWeights);
}
+
+/// Create and return a ResumePhi for \p WideIV, unless it is truncated. If the
+/// induction recipe is not canonical, creates a VPDerivedIVRecipe to compute
+/// the end value of the induction.
+static VPInstruction *addResumePhiRecipeForInduction(
+ VPWidenInductionRecipe *WideIV, VPBuilder &VectorPHBuilder,
+ VPBuilder &ScalarPHBuilder, VPTypeAnalysis &TypeInfo, VPValue *VectorTC) {
+ auto *WideIntOrFp = dyn_cast<VPWidenIntOrFpInductionRecipe>(WideIV);
+ // Truncated wide inductions resume from the last lane of their vector value
+ // in the last vector iteration which is handled elsewhere.
+ if (WideIntOrFp && WideIntOrFp->getTruncInst())
+ return nullptr;
+
+ VPValue *Start = WideIV->getStartValue();
+ VPValue *Step = WideIV->getStepValue();
+ const InductionDescriptor &ID = WideIV->getInductionDescriptor();
+ VPValue *EndValue = VectorTC;
+ if (!WideIntOrFp || !WideIntOrFp->isCanonical()) {
+ EndValue = VectorPHBuilder.createDerivedIV(
+ ID.getKind(), dyn_cast_or_null<FPMathOperator>(ID.getInductionBinOp()),
+ Start, VectorTC, Step);
+ }
+
+ // EndValue is derived from the vector trip count (which has the same type as
+ // the widest induction) and thus may be wider than the induction here.
+ Type *ScalarTypeOfWideIV = TypeInfo.inferScalarType(WideIV);
+ if (ScalarTypeOfWideIV != TypeInfo.inferScalarType(EndValue)) {
+ EndValue = VectorPHBuilder.createScalarCast(Instruction::Trunc, EndValue,
+ ScalarTypeOfWideIV,
+ WideIV->getDebugLoc());
+ }
+
+ auto *ResumePhiRecipe = ScalarPHBuilder.createScalarPhi(
+ {EndValue, Start}, WideIV->getDebugLoc(), "bc.resume.val");
+ return ResumePhiRecipe;
+}
+
+void VPlanTransforms::addScalarResumePhis(
+ VPlan &Plan, VPRecipeBuilder &Builder,
+ DenseMap<VPValue *, VPValue *> &IVEndValues) {
+ VPTypeAnalysis TypeInfo(Plan);
+ auto *ScalarPH = Plan.getScalarPreheader();
+ auto *MiddleVPBB = cast<VPBasicBlock>(ScalarPH->getPredecessors()[0]);
+ VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion();
+ VPBuilder VectorPHBuilder(
+ cast<VPBasicBlock>(VectorRegion->getSinglePredecessor()));
+ VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi());
+ VPBuilder ScalarPHBuilder(ScalarPH);
+ for (VPRecipeBase &ScalarPhiR : Plan.getScalarHeader()->phis()) {
+ auto *ScalarPhiIRI = cast<VPIRPhi>(&ScalarPhiR);
+
+ // TODO: Extract final value from induction recipe initially, optimize to
+ // pre-computed end value together in optimizeInductionExitUsers.
+ auto *VectorPhiR =
+ cast<VPHeaderPHIRecipe>(Builder.getRecipe(&ScalarPhiIRI->getIRPhi()));
+ if (auto *WideIVR = dyn_cast<VPWidenInductionRecipe>(VectorPhiR)) {
+ if (VPInstruction *ResumePhi = addResumePhiRecipeForInduction(
+ WideIVR, VectorPHBuilder, ScalarPHBuilder, TypeInfo,
+ &Plan.getVectorTripCount())) {
+ assert(isa<VPPhi>(ResumePhi) && "Expected a phi");
+ IVEndValues[WideIVR] = ResumePhi->getOperand(0);
+ ScalarPhiIRI->addOperand(ResumePhi);
+ continue;
+ }
+ // TODO: Also handle truncated inductions here. Computing end-values
+ // separately should be done as VPlan-to-VPlan optimization, after
+ // legalizing all resume values to use the last lane from the loop.
+ assert(cast<VPWidenIntOrFpInductionRecipe>(VectorPhiR)->getTruncInst() &&
+ "should only skip truncated wide inductions");
+ continue;
+ }
+
+ // The backedge value provides the value to resume coming out of a loop,
+ // which for FORs is a vector whose last element needs to be extracted. The
+ // start value provides the value if the loop is bypassed.
+ bool IsFOR = isa<VPFirstOrderRecurrencePHIRecipe>(VectorPhiR);
+ auto *ResumeFromVectorLoop = VectorPhiR->getBackedgeValue();
+ assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() &&
+ "Cannot handle loops with uncountable early exits");
+ if (IsFOR)
+ ResumeFromVectorLoop = MiddleBuilder.createNaryOp(
+ VPInstruction::ExtractLastElement, {ResumeFromVectorLoop}, {},
+ "vector.recur.extract");
+ StringRef Name = IsFOR ? "scalar.recur.init" : "bc.merge.rdx";
+ auto *ResumePhiR = ScalarPHBuilder.createScalarPhi(
+ {ResumeFromVectorLoop, VectorPhiR->getStartValue()}, {}, Name);
+ ScalarPhiIRI->addOperand(ResumePhiR);
+ }
+}
+
+void VPlanTransforms::addExitUsersForFirstOrderRecurrences(VPlan &Plan,
+ VFRange &Range) {
+ VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion();
+ auto *ScalarPHVPBB = Plan.getScalarPreheader();
+ auto *MiddleVPBB = Plan.getMiddleBlock();
+ VPBuilder ScalarPHBuilder(ScalarPHVPBB);
+ VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->getFirstNonPhi());
+
+ auto IsScalableOne = [](ElementCount VF) -> bool {
+ return VF == ElementCount::getScalable(1);
+ };
+
+ for (auto &HeaderPhi : VectorRegion->getEntryBasicBlock()->phis()) {
+ auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&HeaderPhi);
+ if (!FOR)
+ continue;
+
+ assert(VectorRegion->getSingleSuccessor() == Plan.getMiddleBlock() &&
+ "Cannot handle loops with uncountable early exits");
+
+ // This is the second phase of vectorizing first-order recurrences, creating
+    // extracts for users outside the loop. An overview of the transformation is
+ // described below. Suppose we have the following loop with some use after
+ // the loop of the last a[i-1],
+ //
+ // for (int i = 0; i < n; ++i) {
+ // t = a[i - 1];
+ // b[i] = a[i] - t;
+ // }
+ // use t;
+ //
+ // There is a first-order recurrence on "a". For this loop, the shorthand
+ // scalar IR looks like:
+ //
+ // scalar.ph:
+ // s.init = a[-1]
+ // br scalar.body
+ //
+ // scalar.body:
+ // i = phi [0, scalar.ph], [i+1, scalar.body]
+ // s1 = phi [s.init, scalar.ph], [s2, scalar.body]
+ // s2 = a[i]
+ // b[i] = s2 - s1
+ // br cond, scalar.body, exit.block
+ //
+ // exit.block:
+ // use = lcssa.phi [s1, scalar.body]
+ //
+    // In this example, s1 is a recurrence because its value depends on the
+ // previous iteration. In the first phase of vectorization, we created a
+ // VPFirstOrderRecurrencePHIRecipe v1 for s1. Now we create the extracts
+ // for users in the scalar preheader and exit block.
+ //
+ // vector.ph:
+ // v_init = vector(..., ..., ..., a[-1])
+ // br vector.body
+ //
+ // vector.body
+ // i = phi [0, vector.ph], [i+4, vector.body]
+ // v1 = phi [v_init, vector.ph], [v2, vector.body]
+ // v2 = a[i, i+1, i+2, i+3]
+ // b[i] = v2 - v1
+ // // Next, third phase will introduce v1' = splice(v1(3), v2(0, 1, 2))
+ // b[i, i+1, i+2, i+3] = v2 - v1
+ // br cond, vector.body, middle.block
+ //
+ // middle.block:
+ // vector.recur.extract.for.phi = v2(2)
+ // vector.recur.extract = v2(3)
+ // br cond, scalar.ph, exit.block
+ //
+ // scalar.ph:
+ // scalar.recur.init = phi [vector.recur.extract, middle.block],
+ // [s.init, otherwise]
+ // br scalar.body
+ //
+ // scalar.body:
+ // i = phi [0, scalar.ph], [i+1, scalar.body]
+ // s1 = phi [scalar.recur.init, scalar.ph], [s2, scalar.body]
+ // s2 = a[i]
+ // b[i] = s2 - s1
+ // br cond, scalar.body, exit.block
+ //
+ // exit.block:
+ // lo = lcssa.phi [s1, scalar.body],
+ // [vector.recur.extract.for.phi, middle.block]
+ //
+ // Now update VPIRInstructions modeling LCSSA phis in the exit block.
+ // Extract the penultimate value of the recurrence and use it as operand for
+ // the VPIRInstruction modeling the phi.
+ for (VPUser *U : FOR->users()) {
+ using namespace llvm::VPlanPatternMatch;
+ if (!match(U, m_ExtractLastElement(m_Specific(FOR))))
+ continue;
+ // For VF vscale x 1, if vscale = 1, we are unable to extract the
+ // penultimate value of the recurrence. Instead we rely on the existing
+ // extract of the last element from the result of
+ // VPInstruction::FirstOrderRecurrenceSplice.
+ // TODO: Consider vscale_range info and UF.
+ if (LoopVectorizationPlanner::getDecisionAndClampRange(IsScalableOne,
+ Range))
+ return;
+ VPValue *PenultimateElement = MiddleBuilder.createNaryOp(
+ VPInstruction::ExtractPenultimateElement, {FOR->getBackedgeValue()},
+ {}, "vector.recur.extract.for.phi");
+ cast<VPInstruction>(U)->replaceAllUsesWith(PenultimateElement);
+ }
+ }
+}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index 2f00e51..5a8a2bb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -363,6 +363,19 @@ struct VPlanTransforms {
static void
addBranchWeightToMiddleTerminator(VPlan &Plan, ElementCount VF,
std::optional<unsigned> VScaleForTuning);
+
+ /// Create resume phis in the scalar preheader for first-order recurrences,
+ /// reductions and inductions, and update the VPIRInstructions wrapping the
+ /// original phis in the scalar header. End values for inductions are added to
+ /// \p IVEndValues.
+ static void addScalarResumePhis(VPlan &Plan, VPRecipeBuilder &Builder,
+ DenseMap<VPValue *, VPValue *> &IVEndValues);
+
+  /// Handle users of first-order recurrences in the original exit block. The
+  /// penultimate value of each recurrence is fed to its LCSSA phi users in
+  /// the original exit block via the VPIRInstruction that wraps the
+  /// LCSSA phi.
+ static void addExitUsersForFirstOrderRecurrences(VPlan &Plan, VFRange &Range);
};
} // namespace llvm