Diffstat (limited to 'llvm/lib')
-rw-r--r--  llvm/lib/Analysis/LoopCacheAnalysis.cpp  7
-rw-r--r--  llvm/lib/Analysis/TargetTransformInfo.cpp  20
-rw-r--r--  llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp  21
-rw-r--r--  llvm/lib/CodeGen/CMakeLists.txt  1
-rw-r--r--  llvm/lib/CodeGen/MachineBlockHashInfo.cpp  115
-rw-r--r--  llvm/lib/CodeGen/TargetPassConfig.cpp  8
-rw-r--r--  llvm/lib/ExecutionEngine/JITLink/JITLinkMemoryManager.cpp  2
-rw-r--r--  llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp  2
-rw-r--r--  llvm/lib/ExecutionEngine/Orc/LinkGraphLinkingLayer.cpp  2
-rw-r--r--  llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp  4
-rw-r--r--  llvm/lib/LTO/LTO.cpp  11
-rw-r--r--  llvm/lib/Target/AArch64/AArch64InstrInfo.td  24
-rw-r--r--  llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp  26
-rw-r--r--  llvm/lib/Target/ARM/ARMAsmPrinter.cpp  434
-rw-r--r--  llvm/lib/Target/ARM/ARMAsmPrinter.h  11
-rw-r--r--  llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp  2
-rw-r--r--  llvm/lib/Target/ARM/ARMISelLowering.cpp  69
-rw-r--r--  llvm/lib/Target/ARM/ARMISelLowering.h  6
-rw-r--r--  llvm/lib/Target/ARM/ARMInstrInfo.td  30
-rw-r--r--  llvm/lib/Target/ARM/ARMTargetMachine.cpp  12
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXIntrinsics.td  20
-rw-r--r--  llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td  2
-rw-r--r--  llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp  31
-rw-r--r--  llvm/lib/Target/SPIRV/SPIRVBuiltins.td  1
-rw-r--r--  llvm/lib/Target/X86/X86ISelLowering.cpp  4
-rw-r--r--  llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp  31
-rw-r--r--  llvm/lib/Transforms/Scalar/ConstraintElimination.cpp  20
-rw-r--r--  llvm/lib/Transforms/Scalar/MergeICmps.cpp  34
-rw-r--r--  llvm/lib/Transforms/Vectorize/VPlan.h  13
-rw-r--r--  llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp  67
-rw-r--r--  llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp  80
-rw-r--r--  llvm/lib/Transforms/Vectorize/VPlanUtils.cpp  2
32 files changed, 960 insertions, 152 deletions
diff --git a/llvm/lib/Analysis/LoopCacheAnalysis.cpp b/llvm/lib/Analysis/LoopCacheAnalysis.cpp
index 050c327..424a7fe 100644
--- a/llvm/lib/Analysis/LoopCacheAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopCacheAnalysis.cpp
@@ -436,10 +436,9 @@ bool IndexedReference::delinearize(const LoopInfo &LI) {
const SCEV *StepRec = AccessFnAR ? AccessFnAR->getStepRecurrence(SE) : nullptr;
if (StepRec && SE.isKnownNegative(StepRec))
- AccessFn = SE.getAddRecExpr(AccessFnAR->getStart(),
- SE.getNegativeSCEV(StepRec),
- AccessFnAR->getLoop(),
- AccessFnAR->getNoWrapFlags());
+ AccessFn = SE.getAddRecExpr(
+ AccessFnAR->getStart(), SE.getNegativeSCEV(StepRec),
+ AccessFnAR->getLoop(), SCEV::NoWrapFlags::FlagAnyWrap);
const SCEV *Div = SE.getUDivExactExpr(AccessFn, ElemSize);
Subscripts.push_back(Div);
Sizes.push_back(ElemSize);
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index bf62623..c47a1c1 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1001,13 +1001,25 @@ InstructionCost TargetTransformInfo::getShuffleCost(
TargetTransformInfo::PartialReductionExtendKind
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
- if (isa<SExtInst>(I))
- return PR_SignExtend;
- if (isa<ZExtInst>(I))
- return PR_ZeroExtend;
+ if (auto *Cast = dyn_cast<CastInst>(I))
+ return getPartialReductionExtendKind(Cast->getOpcode());
return PR_None;
}
+TargetTransformInfo::PartialReductionExtendKind
+TargetTransformInfo::getPartialReductionExtendKind(
+ Instruction::CastOps CastOpc) {
+ switch (CastOpc) {
+ case Instruction::CastOps::ZExt:
+ return PR_ZeroExtend;
+ case Instruction::CastOps::SExt:
+ return PR_SignExtend;
+ default:
+ return PR_None;
+ }
+ llvm_unreachable("Unhandled cast opcode");
+}
+
TTI::CastContextHint
TargetTransformInfo::getCastContextHint(const Instruction *I) {
if (!I)
diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
index fefde64f..8aa488f 100644
--- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
@@ -41,6 +41,7 @@
#include "llvm/CodeGen/GCMetadataPrinter.h"
#include "llvm/CodeGen/LazyMachineBlockFrequencyInfo.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
+#include "llvm/CodeGen/MachineBlockHashInfo.h"
#include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
#include "llvm/CodeGen/MachineConstantPool.h"
#include "llvm/CodeGen/MachineDominators.h"
@@ -184,6 +185,8 @@ static cl::opt<bool> PrintLatency(
cl::desc("Print instruction latencies as verbose asm comments"), cl::Hidden,
cl::init(false));
+extern cl::opt<bool> EmitBBHash;
+
STATISTIC(EmittedInsts, "Number of machine instrs printed");
char AsmPrinter::ID = 0;
@@ -474,6 +477,8 @@ void AsmPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequired<GCModuleInfo>();
AU.addRequired<LazyMachineBlockFrequencyInfoPass>();
AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
+ if (EmitBBHash)
+ AU.addRequired<MachineBlockHashInfo>();
}
bool AsmPrinter::doInitialization(Module &M) {
@@ -1434,14 +1439,11 @@ getBBAddrMapFeature(const MachineFunction &MF, int NumMBBSectionRanges,
"BB entries info is required for BBFreq and BrProb "
"features");
}
- return {FuncEntryCountEnabled,
- BBFreqEnabled,
- BrProbEnabled,
+ return {FuncEntryCountEnabled, BBFreqEnabled, BrProbEnabled,
MF.hasBBSections() && NumMBBSectionRanges > 1,
// Use static_cast to avoid breakage of tests on windows.
- static_cast<bool>(BBAddrMapSkipEmitBBEntries),
- HasCalls,
- false};
+ static_cast<bool>(BBAddrMapSkipEmitBBEntries), HasCalls,
+ static_cast<bool>(EmitBBHash)};
}
void AsmPrinter::emitBBAddrMapSection(const MachineFunction &MF) {
@@ -1500,6 +1502,9 @@ void AsmPrinter::emitBBAddrMapSection(const MachineFunction &MF) {
PrevMBBEndSymbol = MBBSymbol;
}
+ auto MBHI =
+ Features.BBHash ? &getAnalysis<MachineBlockHashInfo>() : nullptr;
+
if (!Features.OmitBBEntries) {
OutStreamer->AddComment("BB id");
// Emit the BB ID for this basic block.
@@ -1527,6 +1532,10 @@ void AsmPrinter::emitBBAddrMapSection(const MachineFunction &MF) {
emitLabelDifferenceAsULEB128(MBB.getEndSymbol(), CurrentLabel);
// Emit the Metadata.
OutStreamer->emitULEB128IntValue(getBBAddrMapMetadata(MBB));
+ // Emit the Hash.
+ if (MBHI) {
+ OutStreamer->emitInt64(MBHI->getMBBHash(MBB));
+ }
}
PrevMBBEndSymbol = MBB.getEndSymbol();
}
diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt
index b6872605..4373c53 100644
--- a/llvm/lib/CodeGen/CMakeLists.txt
+++ b/llvm/lib/CodeGen/CMakeLists.txt
@@ -108,6 +108,7 @@ add_llvm_component_library(LLVMCodeGen
LowerEmuTLS.cpp
MachineBasicBlock.cpp
MachineBlockFrequencyInfo.cpp
+ MachineBlockHashInfo.cpp
MachineBlockPlacement.cpp
MachineBranchProbabilityInfo.cpp
MachineCFGPrinter.cpp
diff --git a/llvm/lib/CodeGen/MachineBlockHashInfo.cpp b/llvm/lib/CodeGen/MachineBlockHashInfo.cpp
new file mode 100644
index 0000000..c4d9c0f
--- /dev/null
+++ b/llvm/lib/CodeGen/MachineBlockHashInfo.cpp
@@ -0,0 +1,115 @@
+//===- llvm/CodeGen/MachineBlockHashInfo.cpp -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Compute the hashes of basic blocks.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/MachineBlockHashInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Target/TargetMachine.h"
+
+using namespace llvm;
+
+uint64_t hashBlock(const MachineBasicBlock &MBB, bool HashOperands) {
+ uint64_t Hash = 0;
+ for (const MachineInstr &MI : MBB) {
+ if (MI.isMetaInstruction() || MI.isTerminator())
+ continue;
+ Hash = hashing::detail::hash_16_bytes(Hash, MI.getOpcode());
+ if (HashOperands) {
+ for (unsigned i = 0; i < MI.getNumOperands(); i++) {
+ Hash =
+ hashing::detail::hash_16_bytes(Hash, hash_value(MI.getOperand(i)));
+ }
+ }
+ }
+ return Hash;
+}
+
+/// Fold a 64-bit integer to a 16-bit one.
+uint16_t fold_64_to_16(const uint64_t Value) {
+ uint16_t Res = static_cast<uint16_t>(Value);
+ Res ^= static_cast<uint16_t>(Value >> 16);
+ Res ^= static_cast<uint16_t>(Value >> 32);
+ Res ^= static_cast<uint16_t>(Value >> 48);
+ return Res;
+}
+
+INITIALIZE_PASS(MachineBlockHashInfo, "machine-block-hash",
+ "Machine Block Hash Analysis", true, true)
+
+char MachineBlockHashInfo::ID = 0;
+
+MachineBlockHashInfo::MachineBlockHashInfo() : MachineFunctionPass(ID) {
+ initializeMachineBlockHashInfoPass(*PassRegistry::getPassRegistry());
+}
+
+void MachineBlockHashInfo::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesAll();
+ MachineFunctionPass::getAnalysisUsage(AU);
+}
+
+struct CollectHashInfo {
+ uint64_t Offset;
+ uint64_t OpcodeHash;
+ uint64_t InstrHash;
+ uint64_t NeighborHash;
+};
+
+bool MachineBlockHashInfo::runOnMachineFunction(MachineFunction &F) {
+ DenseMap<const MachineBasicBlock *, CollectHashInfo> HashInfos;
+ uint16_t Offset = 0;
+ // Initialize hash components
+ for (const MachineBasicBlock &MBB : F) {
+ // offset of the machine basic block
+ HashInfos[&MBB].Offset = Offset;
+ Offset += MBB.size();
+ // Hashing opcodes
+ HashInfos[&MBB].OpcodeHash = hashBlock(MBB, /*HashOperands=*/false);
+ // Hash complete instructions
+ HashInfos[&MBB].InstrHash = hashBlock(MBB, /*HashOperands=*/true);
+ }
+
+ // Initialize neighbor hash
+ for (const MachineBasicBlock &MBB : F) {
+ uint64_t Hash = HashInfos[&MBB].OpcodeHash;
+ // Append hashes of successors
+ for (const MachineBasicBlock *SuccMBB : MBB.successors()) {
+ uint64_t SuccHash = HashInfos[SuccMBB].OpcodeHash;
+ Hash = hashing::detail::hash_16_bytes(Hash, SuccHash);
+ }
+ // Append hashes of predecessors
+ for (const MachineBasicBlock *PredMBB : MBB.predecessors()) {
+ uint64_t PredHash = HashInfos[PredMBB].OpcodeHash;
+ Hash = hashing::detail::hash_16_bytes(Hash, PredHash);
+ }
+ HashInfos[&MBB].NeighborHash = Hash;
+ }
+
+ // Assign hashes
+ for (const MachineBasicBlock &MBB : F) {
+ const auto &HashInfo = HashInfos[&MBB];
+ BlendedBlockHash BlendedHash(fold_64_to_16(HashInfo.Offset),
+ fold_64_to_16(HashInfo.OpcodeHash),
+ fold_64_to_16(HashInfo.InstrHash),
+ fold_64_to_16(HashInfo.NeighborHash));
+ MBBHashInfo[&MBB] = BlendedHash.combine();
+ }
+
+ return false;
+}
+
+uint64_t MachineBlockHashInfo::getMBBHash(const MachineBasicBlock &MBB) {
+ return MBBHashInfo[&MBB];
+}
+
+MachineFunctionPass *llvm::createMachineBlockHashInfoPass() {
+ return new MachineBlockHashInfo();
+}
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index b6169e6..10b7238 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -272,6 +272,12 @@ static cl::opt<bool>
cl::desc("Split static data sections into hot and cold "
"sections using profile information"));
+cl::opt<bool> EmitBBHash(
+ "emit-bb-hash",
+ cl::desc(
+ "Emit the hash of basic block in the SHT_LLVM_BB_ADDR_MAP section."),
+ cl::init(false), cl::Optional);
+
/// Allow standard passes to be disabled by command line options. This supports
/// simple binary flags that either suppress the pass or do nothing.
/// i.e. -disable-mypass=false has no effect.
@@ -1281,6 +1287,8 @@ void TargetPassConfig::addMachinePasses() {
// address map (or both).
if (TM->getBBSectionsType() != llvm::BasicBlockSection::None ||
TM->Options.BBAddrMap) {
+ if (EmitBBHash)
+ addPass(llvm::createMachineBlockHashInfoPass());
if (TM->getBBSectionsType() == llvm::BasicBlockSection::List) {
addPass(llvm::createBasicBlockSectionsProfileReaderWrapperPass(
TM->getBBSectionsFuncListBuf()));
diff --git a/llvm/lib/ExecutionEngine/JITLink/JITLinkMemoryManager.cpp b/llvm/lib/ExecutionEngine/JITLink/JITLinkMemoryManager.cpp
index 6c7e27e..fa04976 100644
--- a/llvm/lib/ExecutionEngine/JITLink/JITLinkMemoryManager.cpp
+++ b/llvm/lib/ExecutionEngine/JITLink/JITLinkMemoryManager.cpp
@@ -247,7 +247,7 @@ public:
StandardSegments(std::move(StandardSegments)),
FinalizationSegments(std::move(FinalizationSegments)) {}
- ~IPInFlightAlloc() {
+ ~IPInFlightAlloc() override {
assert(!G && "InFlight alloc neither abandoned nor finalized");
}
diff --git a/llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp b/llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp
index 75ae80f..4ceff48 100644
--- a/llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp
@@ -38,7 +38,7 @@ public:
MachODebugObjectSynthesizerBase(LinkGraph &G, ExecutorAddr RegisterActionAddr)
: G(G), RegisterActionAddr(RegisterActionAddr) {}
- virtual ~MachODebugObjectSynthesizerBase() = default;
+ ~MachODebugObjectSynthesizerBase() override = default;
Error preserveDebugSections() {
if (G.findSectionByName(SynthDebugSectionName)) {
diff --git a/llvm/lib/ExecutionEngine/Orc/LinkGraphLinkingLayer.cpp b/llvm/lib/ExecutionEngine/Orc/LinkGraphLinkingLayer.cpp
index d1a6eaf..a2990ab 100644
--- a/llvm/lib/ExecutionEngine/Orc/LinkGraphLinkingLayer.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/LinkGraphLinkingLayer.cpp
@@ -55,7 +55,7 @@ public:
Plugins = Layer.Plugins;
}
- ~JITLinkCtx() {
+ ~JITLinkCtx() override {
// If there is an object buffer return function then use it to
// return ownership of the buffer.
if (Layer.ReturnObjectBuffer && ObjBuffer)
diff --git a/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp b/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp
index fd805fbf..cdde733 100644
--- a/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp
@@ -92,7 +92,7 @@ public:
Name(std::move(Name)), Ctx(Ctx), Materialize(Materialize),
Discard(Discard), Destroy(Destroy) {}
- ~OrcCAPIMaterializationUnit() {
+ ~OrcCAPIMaterializationUnit() override {
if (Ctx)
Destroy(Ctx);
}
@@ -264,7 +264,7 @@ public:
LLVMOrcCAPIDefinitionGeneratorTryToGenerateFunction TryToGenerate)
: Dispose(Dispose), Ctx(Ctx), TryToGenerate(TryToGenerate) {}
- ~CAPIDefinitionGenerator() {
+ ~CAPIDefinitionGenerator() override {
if (Dispose)
Dispose(Ctx);
}
diff --git a/llvm/lib/LTO/LTO.cpp b/llvm/lib/LTO/LTO.cpp
index 86780e1..9d0fa11 100644
--- a/llvm/lib/LTO/LTO.cpp
+++ b/llvm/lib/LTO/LTO.cpp
@@ -2224,6 +2224,7 @@ class OutOfProcessThinBackend : public CGThinBackend {
ArrayRef<StringRef> DistributorArgs;
SString RemoteCompiler;
+ ArrayRef<StringRef> RemoteCompilerPrependArgs;
ArrayRef<StringRef> RemoteCompilerArgs;
bool SaveTemps;
@@ -2260,12 +2261,14 @@ public:
bool ShouldEmitIndexFiles, bool ShouldEmitImportsFiles,
StringRef LinkerOutputFile, StringRef Distributor,
ArrayRef<StringRef> DistributorArgs, StringRef RemoteCompiler,
+ ArrayRef<StringRef> RemoteCompilerPrependArgs,
ArrayRef<StringRef> RemoteCompilerArgs, bool SaveTemps)
: CGThinBackend(Conf, CombinedIndex, ModuleToDefinedGVSummaries,
AddStream, OnWrite, ShouldEmitIndexFiles,
ShouldEmitImportsFiles, ThinLTOParallelism),
LinkerOutputFile(LinkerOutputFile), DistributorPath(Distributor),
DistributorArgs(DistributorArgs), RemoteCompiler(RemoteCompiler),
+ RemoteCompilerPrependArgs(RemoteCompilerPrependArgs),
RemoteCompilerArgs(RemoteCompilerArgs), SaveTemps(SaveTemps) {}
virtual void setup(unsigned ThinLTONumTasks, unsigned ThinLTOTaskOffset,
@@ -2387,6 +2390,11 @@ public:
JOS.attributeArray("args", [&]() {
JOS.value(RemoteCompiler);
+ // Forward any supplied prepend options.
+ if (!RemoteCompilerPrependArgs.empty())
+ for (auto &A : RemoteCompilerPrependArgs)
+ JOS.value(A);
+
JOS.value("-c");
JOS.value(Saver.save("--target=" + Triple.str()));
@@ -2517,6 +2525,7 @@ ThinBackend lto::createOutOfProcessThinBackend(
bool ShouldEmitIndexFiles, bool ShouldEmitImportsFiles,
StringRef LinkerOutputFile, StringRef Distributor,
ArrayRef<StringRef> DistributorArgs, StringRef RemoteCompiler,
+ ArrayRef<StringRef> RemoteCompilerPrependArgs,
ArrayRef<StringRef> RemoteCompilerArgs, bool SaveTemps) {
auto Func =
[=](const Config &Conf, ModuleSummaryIndex &CombinedIndex,
@@ -2526,7 +2535,7 @@ ThinBackend lto::createOutOfProcessThinBackend(
Conf, CombinedIndex, Parallelism, ModuleToDefinedGVSummaries,
AddStream, OnWrite, ShouldEmitIndexFiles, ShouldEmitImportsFiles,
LinkerOutputFile, Distributor, DistributorArgs, RemoteCompiler,
- RemoteCompilerArgs, SaveTemps);
+ RemoteCompilerPrependArgs, RemoteCompilerArgs, SaveTemps);
};
return ThinBackend(Func, Parallelism);
}
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index f788c75..92f260f 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -4005,24 +4005,20 @@ def : Pat<(i64 (zextloadi32 (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset))),
(SUBREG_TO_REG (i64 0), (LDRWui GPR64sp:$Rn, uimm12s4:$offset), sub_32)>;
// load zero-extended i32, bitcast to f64
-def : Pat <(f64 (bitconvert (i64 (zextloadi32 (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset))))),
- (SUBREG_TO_REG (i64 0), (LDRSui GPR64sp:$Rn, uimm12s4:$offset), ssub)>;
-
+def : Pat<(f64 (bitconvert (i64 (zextloadi32 (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset))))),
+ (SUBREG_TO_REG (i64 0), (LDRSui GPR64sp:$Rn, uimm12s4:$offset), ssub)>;
// load zero-extended i16, bitcast to f64
-def : Pat <(f64 (bitconvert (i64 (zextloadi16 (am_indexed32 GPR64sp:$Rn, uimm12s2:$offset))))),
- (SUBREG_TO_REG (i64 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>;
-
+def : Pat<(f64 (bitconvert (i64 (zextloadi16 (am_indexed16 GPR64sp:$Rn, uimm12s2:$offset))))),
+ (SUBREG_TO_REG (i64 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>;
// load zero-extended i8, bitcast to f64
-def : Pat <(f64 (bitconvert (i64 (zextloadi8 (am_indexed32 GPR64sp:$Rn, uimm12s1:$offset))))),
- (SUBREG_TO_REG (i64 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>;
-
+def : Pat<(f64 (bitconvert (i64 (zextloadi8 (am_indexed8 GPR64sp:$Rn, uimm12s1:$offset))))),
+ (SUBREG_TO_REG (i64 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>;
// load zero-extended i16, bitcast to f32
-def : Pat <(f32 (bitconvert (i32 (zextloadi16 (am_indexed16 GPR64sp:$Rn, uimm12s2:$offset))))),
- (SUBREG_TO_REG (i32 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>;
-
+def : Pat<(f32 (bitconvert (i32 (zextloadi16 (am_indexed16 GPR64sp:$Rn, uimm12s2:$offset))))),
+ (SUBREG_TO_REG (i32 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>;
// load zero-extended i8, bitcast to f32
-def : Pat <(f32 (bitconvert (i32 (zextloadi8 (am_indexed16 GPR64sp:$Rn, uimm12s1:$offset))))),
- (SUBREG_TO_REG (i32 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>;
+def : Pat<(f32 (bitconvert (i32 (zextloadi8 (am_indexed8 GPR64sp:$Rn, uimm12s1:$offset))))),
+ (SUBREG_TO_REG (i32 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>;
// Pre-fetch.
def PRFMui : PrefetchUI<0b11, 0, 0b10, "prfm",
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index e3370d3..2053fc4 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1577,18 +1577,26 @@ static SVEIntrinsicInfo constructSVEIntrinsicInfo(IntrinsicInst &II) {
}
static bool isAllActivePredicate(Value *Pred) {
- // Look through convert.from.svbool(convert.to.svbool(...) chain.
Value *UncastedPred;
+
+ // Look through predicate casts that only remove lanes.
if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_from_svbool>(
- m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
- m_Value(UncastedPred)))))
- // If the predicate has the same or less lanes than the uncasted
- // predicate then we know the casting has no effect.
- if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <=
- cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements())
- Pred = UncastedPred;
+ m_Value(UncastedPred)))) {
+ auto *OrigPredTy = cast<ScalableVectorType>(Pred->getType());
+ Pred = UncastedPred;
+
+ if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
+ m_Value(UncastedPred))))
+ // If the predicate has the same or less lanes than the uncasted predicate
+ // then we know the casting has no effect.
+ if (OrigPredTy->getMinNumElements() <=
+ cast<ScalableVectorType>(UncastedPred->getType())
+ ->getMinNumElements())
+ Pred = UncastedPred;
+ }
+
auto *C = dyn_cast<Constant>(Pred);
- return (C && C->isAllOnesValue());
+ return C && C->isAllOnesValue();
}
// Simplify `V` by only considering the operations that affect active lanes.
diff --git a/llvm/lib/Target/ARM/ARMAsmPrinter.cpp b/llvm/lib/Target/ARM/ARMAsmPrinter.cpp
index 3368a50..36b9908 100644
--- a/llvm/lib/Target/ARM/ARMAsmPrinter.cpp
+++ b/llvm/lib/Target/ARM/ARMAsmPrinter.cpp
@@ -1471,6 +1471,435 @@ void ARMAsmPrinter::EmitUnwindingInstruction(const MachineInstr *MI) {
// instructions) auto-generated.
#include "ARMGenMCPseudoLowering.inc"
+// Helper function to check if a register is live (used as an implicit operand)
+// in the given call instruction.
+static bool isRegisterLiveInCall(const MachineInstr &Call, MCRegister Reg) {
+ for (const MachineOperand &MO : Call.implicit_operands()) {
+ if (MO.isReg() && MO.getReg() == Reg && MO.isUse()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void ARMAsmPrinter::EmitKCFI_CHECK_ARM32(Register AddrReg, int64_t Type,
+ const MachineInstr &Call,
+ int64_t PrefixNops) {
+ // Choose scratch register: r12 primary, r3 if target is r12.
+ unsigned ScratchReg = ARM::R12;
+ if (AddrReg == ARM::R12) {
+ ScratchReg = ARM::R3;
+ }
+
+ // Calculate ESR for ARM mode (16-bit): 0x8000 | (scratch_reg << 5) | addr_reg
+ // Note: scratch_reg is always 0x1F since the EOR sequence clobbers it.
+ const ARMBaseRegisterInfo *TRI = static_cast<const ARMBaseRegisterInfo *>(
+ MF->getSubtarget().getRegisterInfo());
+ unsigned AddrIndex = TRI->getEncodingValue(AddrReg);
+ unsigned ESR = 0x8000 | (31 << 5) | (AddrIndex & 31);
+
+ // Check if r3 is live and needs to be spilled.
+ bool NeedSpillR3 =
+ (ScratchReg == ARM::R3) && isRegisterLiveInCall(Call, ARM::R3);
+
+ // If we need to spill r3, push it first.
+ if (NeedSpillR3) {
+ // push {r3}
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::STMDB_UPD)
+ .addReg(ARM::SP)
+ .addReg(ARM::SP)
+ .addImm(ARMCC::AL)
+ .addReg(0)
+ .addReg(ARM::R3));
+ }
+
+ // Clear bit 0 of target address to handle Thumb function pointers.
+ // In 32-bit ARM, function pointers may have the low bit set to indicate
+ // Thumb state when ARM/Thumb interworking is enabled (ARMv4T and later).
+ // We need to clear it to avoid an alignment fault when loading.
+ // bic scratch, target, #1
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::BICri)
+ .addReg(ScratchReg)
+ .addReg(AddrReg)
+ .addImm(1)
+ .addImm(ARMCC::AL)
+ .addReg(0)
+ .addReg(0));
+
+ // ldr scratch, [scratch, #-(PrefixNops * 4 + 4)]
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::LDRi12)
+ .addReg(ScratchReg)
+ .addReg(ScratchReg)
+ .addImm(-(PrefixNops * 4 + 4))
+ .addImm(ARMCC::AL)
+ .addReg(0));
+
+ // Each EOR instruction XORs one byte of the type, shifted to its position.
+ for (int i = 0; i < 4; i++) {
+ uint8_t byte = (Type >> (i * 8)) & 0xFF;
+ uint32_t imm = byte << (i * 8);
+ bool isLast = (i == 3);
+
+ // Encode as ARM modified immediate.
+ int SOImmVal = ARM_AM::getSOImmVal(imm);
+ assert(SOImmVal != -1 &&
+ "Cannot encode immediate as ARM modified immediate");
+
+ // eor[s] scratch, scratch, #imm (last one sets flags with CPSR)
+ EmitToStreamer(*OutStreamer,
+ MCInstBuilder(ARM::EORri)
+ .addReg(ScratchReg)
+ .addReg(ScratchReg)
+ .addImm(SOImmVal)
+ .addImm(ARMCC::AL)
+ .addReg(0)
+ .addReg(isLast ? ARM::CPSR : ARM::NoRegister));
+ }
+
+ // If we spilled r3, restore it immediately after the comparison.
+ // This must happen before the branch so r3 is valid on both paths.
+ if (NeedSpillR3) {
+ // pop {r3}
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::LDMIA_UPD)
+ .addReg(ARM::SP)
+ .addReg(ARM::SP)
+ .addImm(ARMCC::AL)
+ .addReg(0)
+ .addReg(ARM::R3));
+ }
+
+ // beq .Lpass (branch if types match, i.e., scratch is zero)
+ MCSymbol *Pass = OutContext.createTempSymbol();
+ EmitToStreamer(*OutStreamer,
+ MCInstBuilder(ARM::Bcc)
+ .addExpr(MCSymbolRefExpr::create(Pass, OutContext))
+ .addImm(ARMCC::EQ)
+ .addReg(ARM::CPSR));
+
+ // udf #ESR (trap with encoded diagnostic)
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::UDF).addImm(ESR));
+
+ OutStreamer->emitLabel(Pass);
+}
+
+void ARMAsmPrinter::EmitKCFI_CHECK_Thumb2(Register AddrReg, int64_t Type,
+ const MachineInstr &Call,
+ int64_t PrefixNops) {
+ // Choose scratch register: r12 primary, r3 if target is r12.
+ unsigned ScratchReg = ARM::R12;
+ if (AddrReg == ARM::R12) {
+ ScratchReg = ARM::R3;
+ }
+
+ // Calculate ESR for Thumb mode (8-bit): 0x80 | addr_reg
+ // Bit 7: KCFI trap indicator
+ // Bits 6-5: Reserved
+ // Bits 4-0: Address register encoding
+ const ARMBaseRegisterInfo *TRI = static_cast<const ARMBaseRegisterInfo *>(
+ MF->getSubtarget().getRegisterInfo());
+ unsigned AddrIndex = TRI->getEncodingValue(AddrReg);
+ unsigned ESR = 0x80 | (AddrIndex & 0x1F);
+
+ // Check if r3 is live and needs to be spilled.
+ bool NeedSpillR3 =
+ (ScratchReg == ARM::R3) && isRegisterLiveInCall(Call, ARM::R3);
+
+ // If we need to spill r3, push it first.
+ if (NeedSpillR3) {
+ // push {r3}
+ EmitToStreamer(
+ *OutStreamer,
+ MCInstBuilder(ARM::tPUSH).addImm(ARMCC::AL).addReg(0).addReg(ARM::R3));
+ }
+
+ // Clear bit 0 of target address to handle Thumb function pointers.
+ // In 32-bit ARM, function pointers may have the low bit set to indicate
+ // Thumb state when ARM/Thumb interworking is enabled (ARMv4T and later).
+ // We need to clear it to avoid an alignment fault when loading.
+ // bic scratch, target, #1
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::t2BICri)
+ .addReg(ScratchReg)
+ .addReg(AddrReg)
+ .addImm(1)
+ .addImm(ARMCC::AL)
+ .addReg(0)
+ .addReg(0));
+
+ // ldr scratch, [scratch, #-(PrefixNops * 4 + 4)]
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::t2LDRi8)
+ .addReg(ScratchReg)
+ .addReg(ScratchReg)
+ .addImm(-(PrefixNops * 4 + 4))
+ .addImm(ARMCC::AL)
+ .addReg(0));
+
+ // Each EOR instruction XORs one byte of the type, shifted to its position.
+ for (int i = 0; i < 4; i++) {
+ uint8_t byte = (Type >> (i * 8)) & 0xFF;
+ uint32_t imm = byte << (i * 8);
+ bool isLast = (i == 3);
+
+ // Verify the immediate can be encoded as Thumb2 modified immediate.
+ assert(ARM_AM::getT2SOImmVal(imm) != -1 &&
+ "Cannot encode immediate as Thumb2 modified immediate");
+
+ // eor[s] scratch, scratch, #imm (last one sets flags with CPSR)
+ EmitToStreamer(*OutStreamer,
+ MCInstBuilder(ARM::t2EORri)
+ .addReg(ScratchReg)
+ .addReg(ScratchReg)
+ .addImm(imm)
+ .addImm(ARMCC::AL)
+ .addReg(0)
+ .addReg(isLast ? ARM::CPSR : ARM::NoRegister));
+ }
+
+ // If we spilled r3, restore it immediately after the comparison.
+ // This must happen before the branch so r3 is valid on both paths.
+ if (NeedSpillR3) {
+ // pop {r3}
+ EmitToStreamer(
+ *OutStreamer,
+ MCInstBuilder(ARM::tPOP).addImm(ARMCC::AL).addReg(0).addReg(ARM::R3));
+ }
+
+ // beq .Lpass (branch if types match, i.e., scratch is zero)
+ MCSymbol *Pass = OutContext.createTempSymbol();
+ EmitToStreamer(*OutStreamer,
+ MCInstBuilder(ARM::t2Bcc)
+ .addExpr(MCSymbolRefExpr::create(Pass, OutContext))
+ .addImm(ARMCC::EQ)
+ .addReg(ARM::CPSR));
+
+ // udf #ESR (trap with encoded diagnostic)
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tUDF).addImm(ESR));
+
+ OutStreamer->emitLabel(Pass);
+}
+
+void ARMAsmPrinter::EmitKCFI_CHECK_Thumb1(Register AddrReg, int64_t Type,
+ const MachineInstr &Call,
+ int64_t PrefixNops) {
+ // For Thumb1, use R2 unconditionally as scratch register (a low register
+ // required for tLDRi). R3 is used for building the type hash.
+ unsigned ScratchReg = ARM::R2;
+ unsigned TempReg = ARM::R3;
+
+ // Check if r3 is live and needs to be spilled.
+ bool NeedSpillR3 = isRegisterLiveInCall(Call, ARM::R3);
+
+ // Spill r3 if needed
+ if (NeedSpillR3) {
+ EmitToStreamer(
+ *OutStreamer,
+ MCInstBuilder(ARM::tPUSH).addImm(ARMCC::AL).addReg(0).addReg(ARM::R3));
+ }
+
+ // Check if r2 is live and needs to be spilled.
+ bool NeedSpillR2 = isRegisterLiveInCall(Call, ARM::R2);
+
+ // Push R2 if it's live
+ if (NeedSpillR2) {
+ EmitToStreamer(
+ *OutStreamer,
+ MCInstBuilder(ARM::tPUSH).addImm(ARMCC::AL).addReg(0).addReg(ARM::R2));
+ }
+
+ // Clear bit 0 from target address
+ // TempReg (R3) is used first as helper for BIC, then later for building type
+ // hash.
+
+ // movs temp, #1
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tMOVi8)
+ .addReg(TempReg)
+ .addReg(ARM::CPSR)
+ .addImm(1)
+ .addImm(ARMCC::AL)
+ .addReg(0));
+
+ // mov scratch, target
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tMOVr)
+ .addReg(ScratchReg)
+ .addReg(AddrReg)
+ .addImm(ARMCC::AL));
+
+ // bics scratch, temp (scratch = scratch & ~temp)
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tBIC)
+ .addReg(ScratchReg)
+ .addReg(ARM::CPSR)
+ .addReg(ScratchReg)
+ .addReg(TempReg)
+ .addImm(ARMCC::AL)
+ .addReg(0));
+
+ // Load type hash. Thumb1 doesn't support negative offsets, so subtract.
+ int offset = PrefixNops * 4 + 4;
+
+ // subs scratch, #offset
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tSUBi8)
+ .addReg(ScratchReg)
+ .addReg(ARM::CPSR)
+ .addReg(ScratchReg)
+ .addImm(offset)
+ .addImm(ARMCC::AL)
+ .addReg(0));
+
+ // ldr scratch, [scratch, #0]
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tLDRi)
+ .addReg(ScratchReg)
+ .addReg(ScratchReg)
+ .addImm(0)
+ .addImm(ARMCC::AL)
+ .addReg(0));
+
+ // Load expected type inline (instead of EOR sequence)
+ //
+ // This creates the 32-bit value byte-by-byte in the temp register:
+ // movs temp, #byte3 (high byte)
+ // lsls temp, temp, #8
+ // adds temp, #byte2
+ // lsls temp, temp, #8
+ // adds temp, #byte1
+ // lsls temp, temp, #8
+ // adds temp, #byte0 (low byte)
+
+ uint8_t byte0 = (Type >> 0) & 0xFF;
+ uint8_t byte1 = (Type >> 8) & 0xFF;
+ uint8_t byte2 = (Type >> 16) & 0xFF;
+ uint8_t byte3 = (Type >> 24) & 0xFF;
+
+ // movs temp, #byte3 (start with high byte)
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tMOVi8)
+ .addReg(TempReg)
+ .addReg(ARM::CPSR)
+ .addImm(byte3)
+ .addImm(ARMCC::AL)
+ .addReg(0));
+
+ // lsls temp, temp, #8
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tLSLri)
+ .addReg(TempReg)
+ .addReg(ARM::CPSR)
+ .addReg(TempReg)
+ .addImm(8)
+ .addImm(ARMCC::AL)
+ .addReg(0));
+
+ // adds temp, #byte2
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tADDi8)
+ .addReg(TempReg)
+ .addReg(ARM::CPSR)
+ .addReg(TempReg)
+ .addImm(byte2)
+ .addImm(ARMCC::AL)
+ .addReg(0));
+
+ // lsls temp, temp, #8
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tLSLri)
+ .addReg(TempReg)
+ .addReg(ARM::CPSR)
+ .addReg(TempReg)
+ .addImm(8)
+ .addImm(ARMCC::AL)
+ .addReg(0));
+
+ // adds temp, #byte1
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tADDi8)
+ .addReg(TempReg)
+ .addReg(ARM::CPSR)
+ .addReg(TempReg)
+ .addImm(byte1)
+ .addImm(ARMCC::AL)
+ .addReg(0));
+
+ // lsls temp, temp, #8
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tLSLri)
+ .addReg(TempReg)
+ .addReg(ARM::CPSR)
+ .addReg(TempReg)
+ .addImm(8)
+ .addImm(ARMCC::AL)
+ .addReg(0));
+
+ // adds temp, #byte0 (low byte)
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tADDi8)
+ .addReg(TempReg)
+ .addReg(ARM::CPSR)
+ .addReg(TempReg)
+ .addImm(byte0)
+ .addImm(ARMCC::AL)
+ .addReg(0));
+
+ // cmp scratch, temp
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tCMPr)
+ .addReg(ScratchReg)
+ .addReg(TempReg)
+ .addImm(ARMCC::AL)
+ .addReg(0));
+
+ // Restore registers if spilled (pop in reverse order of push: R2, then R3)
+ if (NeedSpillR2) {
+ // pop {r2}
+ EmitToStreamer(
+ *OutStreamer,
+ MCInstBuilder(ARM::tPOP).addImm(ARMCC::AL).addReg(0).addReg(ARM::R2));
+ }
+
+ // Restore r3 if spilled
+ if (NeedSpillR3) {
+ // pop {r3}
+ EmitToStreamer(
+ *OutStreamer,
+ MCInstBuilder(ARM::tPOP).addImm(ARMCC::AL).addReg(0).addReg(ARM::R3));
+ }
+
+ // beq .Lpass (branch if types match, i.e., scratch == temp)
+ MCSymbol *Pass = OutContext.createTempSymbol();
+ EmitToStreamer(*OutStreamer,
+ MCInstBuilder(ARM::tBcc)
+ .addExpr(MCSymbolRefExpr::create(Pass, OutContext))
+ .addImm(ARMCC::EQ)
+ .addReg(ARM::CPSR));
+
+ // bkpt #0 (trap; unlike the ARM/Thumb2 UDF variants, no ESR diagnostic is encoded)
+ EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tBKPT).addImm(0));
+
+ OutStreamer->emitLabel(Pass);
+}
+
+void ARMAsmPrinter::LowerKCFI_CHECK(const MachineInstr &MI) {
+ Register AddrReg = MI.getOperand(0).getReg();
+ const int64_t Type = MI.getOperand(1).getImm();
+
+ // Get the call instruction that follows this KCFI_CHECK.
+ assert(std::next(MI.getIterator())->isCall() &&
+ "KCFI_CHECK not followed by a call instruction");
+ const MachineInstr &Call = *std::next(MI.getIterator());
+
+ // Adjust the offset for patchable-function-prefix.
+ int64_t PrefixNops = 0;
+ MI.getMF()
+ ->getFunction()
+ .getFnAttribute("patchable-function-prefix")
+ .getValueAsString()
+ .getAsInteger(10, PrefixNops);
+
+ // Emit the appropriate instruction sequence based on the opcode variant.
+ switch (MI.getOpcode()) {
+ case ARM::KCFI_CHECK_ARM:
+ EmitKCFI_CHECK_ARM32(AddrReg, Type, Call, PrefixNops);
+ break;
+ case ARM::KCFI_CHECK_Thumb2:
+ EmitKCFI_CHECK_Thumb2(AddrReg, Type, Call, PrefixNops);
+ break;
+ case ARM::KCFI_CHECK_Thumb1:
+ EmitKCFI_CHECK_Thumb1(AddrReg, Type, Call, PrefixNops);
+ break;
+ default:
+ llvm_unreachable("Unexpected KCFI_CHECK opcode");
+ }
+}
+
void ARMAsmPrinter::emitInstruction(const MachineInstr *MI) {
ARM_MC::verifyInstructionPredicates(MI->getOpcode(),
getSubtargetInfo().getFeatureBits());
@@ -1504,6 +1933,11 @@ void ARMAsmPrinter::emitInstruction(const MachineInstr *MI) {
switch (Opc) {
case ARM::t2MOVi32imm: llvm_unreachable("Should be lowered by thumb2it pass");
case ARM::DBG_VALUE: llvm_unreachable("Should be handled by generic printing");
+ case ARM::KCFI_CHECK_ARM:
+ case ARM::KCFI_CHECK_Thumb2:
+ case ARM::KCFI_CHECK_Thumb1:
+ LowerKCFI_CHECK(*MI);
+ return;
case ARM::LEApcrel:
case ARM::tLEApcrel:
case ARM::t2LEApcrel: {
diff --git a/llvm/lib/Target/ARM/ARMAsmPrinter.h b/llvm/lib/Target/ARM/ARMAsmPrinter.h
index 2b067c7..9e92b5a 100644
--- a/llvm/lib/Target/ARM/ARMAsmPrinter.h
+++ b/llvm/lib/Target/ARM/ARMAsmPrinter.h
@@ -123,9 +123,20 @@ public:
void LowerPATCHABLE_FUNCTION_EXIT(const MachineInstr &MI);
void LowerPATCHABLE_TAIL_CALL(const MachineInstr &MI);
+ // KCFI check lowering
+ void LowerKCFI_CHECK(const MachineInstr &MI);
+
private:
void EmitSled(const MachineInstr &MI, SledKind Kind);
+ // KCFI check emission helpers
+ void EmitKCFI_CHECK_ARM32(Register AddrReg, int64_t Type,
+ const MachineInstr &Call, int64_t PrefixNops);
+ void EmitKCFI_CHECK_Thumb2(Register AddrReg, int64_t Type,
+ const MachineInstr &Call, int64_t PrefixNops);
+ void EmitKCFI_CHECK_Thumb1(Register AddrReg, int64_t Type,
+ const MachineInstr &Call, int64_t PrefixNops);
+
// Helpers for emitStartOfAsmFile() and emitEndOfAsmFile()
void emitAttributes();
diff --git a/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp b/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp
index 0d7b6d1..fffb6373 100644
--- a/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp
@@ -2301,6 +2301,8 @@ bool ARMExpandPseudo::ExpandMI(MachineBasicBlock &MBB,
for (unsigned i = 2, e = MBBI->getNumOperands(); i != e; ++i)
NewMI->addOperand(MBBI->getOperand(i));
+ NewMI->setCFIType(*MBB.getParent(), MI.getCFIType());
+
// Update call info and delete the pseudo instruction TCRETURN.
if (MI.isCandidateForAdditionalCallInfo())
MI.getMF()->moveAdditionalCallInfo(&MI, &*NewMI);
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index b1a668e..8122db2 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -2849,6 +2849,8 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
if (isTailCall) {
MF.getFrameInfo().setHasTailCall();
SDValue Ret = DAG.getNode(ARMISD::TC_RETURN, dl, MVT::Other, Ops);
+ if (CLI.CFIType)
+ Ret.getNode()->setCFIType(CLI.CFIType->getZExtValue());
DAG.addNoMergeSiteInfo(Ret.getNode(), CLI.NoMerge);
DAG.addCallSiteInfo(Ret.getNode(), std::move(CSInfo));
return Ret;
@@ -2856,6 +2858,8 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// Returns a chain and a flag for retval copy to use.
Chain = DAG.getNode(CallOpc, dl, {MVT::Other, MVT::Glue}, Ops);
+ if (CLI.CFIType)
+ Chain.getNode()->setCFIType(CLI.CFIType->getZExtValue());
DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge);
InGlue = Chain.getValue(1);
DAG.addCallSiteInfo(Chain.getNode(), std::move(CSInfo));
@@ -12008,6 +12012,71 @@ static void genTPLoopBody(MachineBasicBlock *TpLoopBody,
.add(predOps(ARMCC::AL));
}
+bool ARMTargetLowering::supportKCFIBundles() const {
+ // KCFI is supported in all ARM/Thumb modes
+ return true;
+}
+
+MachineInstr *
+ARMTargetLowering::EmitKCFICheck(MachineBasicBlock &MBB,
+ MachineBasicBlock::instr_iterator &MBBI,
+ const TargetInstrInfo *TII) const {
+ assert(MBBI->isCall() && MBBI->getCFIType() &&
+ "Invalid call instruction for a KCFI check");
+
+ MachineOperand *TargetOp = nullptr;
+ switch (MBBI->getOpcode()) {
+ // ARM mode opcodes
+ case ARM::BLX:
+ case ARM::BLX_pred:
+ case ARM::BLX_noip:
+ case ARM::BLX_pred_noip:
+ case ARM::BX_CALL:
+ TargetOp = &MBBI->getOperand(0);
+ break;
+ case ARM::TCRETURNri:
+ case ARM::TCRETURNrinotr12:
+ case ARM::TAILJMPr:
+ case ARM::TAILJMPr4:
+ TargetOp = &MBBI->getOperand(0);
+ break;
+ // Thumb mode opcodes (Thumb1 and Thumb2)
+ // Note: Most Thumb call instructions have predicate operands before the
+ // target register Format: tBLXr pred, predreg, target_register, ...
+ case ARM::tBLXr: // Thumb1/Thumb2: BLX register (requires V5T)
+ case ARM::tBLXr_noip: // Thumb1/Thumb2: BLX register, no IP clobber
+ case ARM::tBX_CALL: // Thumb1 only: BX call (push LR, BX)
+ TargetOp = &MBBI->getOperand(2);
+ break;
+ // Tail call instructions don't have predicates, target is operand 0
+ case ARM::tTAILJMPr: // Thumb1/Thumb2: Tail call via register
+ TargetOp = &MBBI->getOperand(0);
+ break;
+ default:
+ llvm_unreachable("Unexpected CFI call opcode");
+ }
+
+ assert(TargetOp && TargetOp->isReg() && "Invalid target operand");
+ TargetOp->setIsRenamable(false);
+
+ // Select the appropriate KCFI_CHECK variant based on the instruction set
+ unsigned KCFICheckOpcode;
+ if (Subtarget->isThumb()) {
+ if (Subtarget->isThumb2()) {
+ KCFICheckOpcode = ARM::KCFI_CHECK_Thumb2;
+ } else {
+ KCFICheckOpcode = ARM::KCFI_CHECK_Thumb1;
+ }
+ } else {
+ KCFICheckOpcode = ARM::KCFI_CHECK_ARM;
+ }
+
+ return BuildMI(MBB, MBBI, MBBI->getDebugLoc(), TII->get(KCFICheckOpcode))
+ .addReg(TargetOp->getReg())
+ .addImm(MBBI->getCFIType())
+ .getInstr();
+}
+
MachineBasicBlock *
ARMTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
MachineBasicBlock *BB) const {
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h
index 70aa001..8c5e0cf 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.h
+++ b/llvm/lib/Target/ARM/ARMISelLowering.h
@@ -447,6 +447,12 @@ class VectorType;
void AdjustInstrPostInstrSelection(MachineInstr &MI,
SDNode *Node) const override;
+ bool supportKCFIBundles() const override;
+
+ MachineInstr *EmitKCFICheck(MachineBasicBlock &MBB,
+ MachineBasicBlock::instr_iterator &MBBI,
+ const TargetInstrInfo *TII) const override;
+
SDValue PerformCMOVCombine(SDNode *N, SelectionDAG &DAG) const;
SDValue PerformBRCONDCombine(SDNode *N, SelectionDAG &DAG) const;
SDValue PerformCMOVToBFICombine(SDNode *N, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/ARM/ARMInstrInfo.td b/llvm/lib/Target/ARM/ARMInstrInfo.td
index 282ff53..53be167 100644
--- a/llvm/lib/Target/ARM/ARMInstrInfo.td
+++ b/llvm/lib/Target/ARM/ARMInstrInfo.td
@@ -6536,6 +6536,36 @@ def CMP_SWAP_64 : PseudoInst<(outs GPRPair:$Rd, GPRPair:$addr_temp_out),
def : Pat<(atomic_fence (timm), 0), (MEMBARRIER)>;
//===----------------------------------------------------------------------===//
+// KCFI check pseudo-instruction.
+//===----------------------------------------------------------------------===//
+// KCFI_CHECK pseudo-instruction for Kernel Control-Flow Integrity.
+// Expands to a sequence that verifies the function pointer's type hash.
+// Different sizes for different architectures due to different expansions.
+
+def KCFI_CHECK_ARM
+ : PseudoInst<(outs), (ins GPR:$ptr, i32imm:$type), NoItinerary, []>,
+ Sched<[]>,
+ Requires<[IsARM]> {
+ let Size = 28; // 7 instructions (bic, ldr, 4x eor, beq, udf)
+}
+
+def KCFI_CHECK_Thumb2
+ : PseudoInst<(outs), (ins GPR:$ptr, i32imm:$type), NoItinerary, []>,
+ Sched<[]>,
+ Requires<[IsThumb2]> {
+ let Size =
+ 32; // worst-case 9 instructions (push, bic, ldr, 4x eor, pop, beq.w, udf)
+}
+
+def KCFI_CHECK_Thumb1
+ : PseudoInst<(outs), (ins GPR:$ptr, i32imm:$type), NoItinerary, []>,
+ Sched<[]>,
+ Requires<[IsThumb1Only]> {
+ let Size = 50; // worst-case 25 instructions (pushes, bic helper, type
+ // building, cmp, pops)
+}
+
+//===----------------------------------------------------------------------===//
// Instructions used for emitting unwind opcodes on Windows.
//===----------------------------------------------------------------------===//
let isPseudo = 1 in {
diff --git a/llvm/lib/Target/ARM/ARMTargetMachine.cpp b/llvm/lib/Target/ARM/ARMTargetMachine.cpp
index 86740a9..590d4c7 100644
--- a/llvm/lib/Target/ARM/ARMTargetMachine.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetMachine.cpp
@@ -111,6 +111,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeARMTarget() {
initializeMVELaneInterleavingPass(Registry);
initializeARMFixCortexA57AES1742098Pass(Registry);
initializeARMDAGToDAGISelLegacyPass(Registry);
+ initializeKCFIPass(Registry);
}
static std::unique_ptr<TargetLoweringObjectFile> createTLOF(const Triple &TT) {
@@ -487,6 +488,9 @@ void ARMPassConfig::addPreSched2() {
// proper scheduling.
addPass(createARMExpandPseudoPass());
+ // Emit KCFI checks for indirect calls.
+ addPass(createKCFIPass());
+
if (getOptLevel() != CodeGenOptLevel::None) {
// When optimising for size, always run the Thumb2SizeReduction pass before
// IfConversion. Otherwise, check whether IT blocks are restricted
@@ -517,9 +521,12 @@ void ARMPassConfig::addPreSched2() {
void ARMPassConfig::addPreEmitPass() {
addPass(createThumb2SizeReductionPass());
- // Constant island pass work on unbundled instructions.
+ // Unpack bundles for:
+ // - Thumb2: Constant island pass requires unbundled instructions
+ // - KCFI: KCFI_CHECK pseudo instructions need to be unbundled for AsmPrinter
addPass(createUnpackMachineBundles([](const MachineFunction &MF) {
- return MF.getSubtarget<ARMSubtarget>().isThumb2();
+ return MF.getSubtarget<ARMSubtarget>().isThumb2() ||
+ MF.getFunction().getParent()->getModuleFlag("kcfi");
}));
// Don't optimize barriers or block placement at -O0.
@@ -530,6 +537,7 @@ void ARMPassConfig::addPreEmitPass() {
}
void ARMPassConfig::addPreEmitPass2() {
+
// Inserts fixup instructions before unsafe AES operations. Instructions may
// be inserted at the start of blocks and at within blocks so this pass has to
// come before those below.
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 22cf3a7..598735f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4675,7 +4675,7 @@ class WMMA_INSTR<string _Intr, list<dag> _Args>
//
class WMMA_LOAD<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride>
- : WMMA_INSTR<WMMA_NAME_LDST<"load", Frag, Layout, WithStride>.record,
+ : WMMA_INSTR<WMMA_NAME_LDST<"load", Frag, Layout, WithStride>.record_name,
[!con((ins ADDR:$src),
!if(WithStride, (ins B32:$ldm), (ins)))]>,
Requires<Frag.Predicates> {
@@ -4714,7 +4714,7 @@ class WMMA_LOAD<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride>
//
class WMMA_STORE_D<WMMA_REGINFO Frag, string Layout, string Space,
bit WithStride>
- : WMMA_INSTR<WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.record,
+ : WMMA_INSTR<WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.record_name,
[!con((ins ADDR:$dst),
Frag.Ins,
!if(WithStride, (ins B32:$ldm), (ins)))]>,
@@ -4778,7 +4778,7 @@ class MMA_OP_PREDICATES<WMMA_REGINFO FragA, string b1op> {
class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
string ALayout, string BLayout, int Satfinite, string rnd, string b1op>
- : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, b1op, FragA, FragB, FragC, FragD>.record,
+ : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, b1op, FragA, FragB, FragC, FragD>.record_name,
[FragA.Ins, FragB.Ins, FragC.Ins]>,
// Requires does not seem to have effect on Instruction w/o Patterns.
// We set it here anyways and propagate to the Pat<> we construct below.
@@ -4837,7 +4837,7 @@ defset list<WMMA_INSTR> WMMAs = {
class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
string ALayout, string BLayout, int Satfinite, string b1op, string Kind>
- : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, Kind, FragA, FragB, FragC, FragD>.record,
+ : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, Kind, FragA, FragB, FragC, FragD>.record_name,
[FragA.Ins, FragB.Ins, FragC.Ins]>,
// Requires does not seem to have effect on Instruction w/o Patterns.
// We set it here anyways and propagate to the Pat<> we construct below.
@@ -4891,7 +4891,7 @@ class MMA_SP<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
string Metadata, string Kind, int Satfinite>
: WMMA_INSTR<MMA_SP_NAME<Metadata, Kind, Satfinite,
- FragA, FragB, FragC, FragD>.record,
+ FragA, FragB, FragC, FragD>.record_name,
[FragA.Ins, FragB.Ins, FragC.Ins,
(ins B32:$metadata, i32imm:$selector)]>,
// Requires does not seem to have effect on Instruction w/o Patterns.
@@ -4946,7 +4946,7 @@ defset list<WMMA_INSTR> MMA_SPs = {
// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
//
class LDMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space>
- : WMMA_INSTR<LDMATRIX_NAME<Frag, Transposed>.record, [(ins ADDR:$src)]>,
+ : WMMA_INSTR<LDMATRIX_NAME<Frag, Transposed>.record_name, [(ins ADDR:$src)]>,
Requires<Frag.Predicates> {
// Build PatFrag that only matches particular address space.
PatFrag IntrFrag = PatFrag<(ops node:$src), (Intr node:$src),
@@ -4981,7 +4981,7 @@ defset list<WMMA_INSTR> LDMATRIXs = {
// stmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
//
class STMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space>
- : WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record, [!con((ins ADDR:$dst), Frag.Ins)]>,
+ : WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record_name, [!con((ins ADDR:$dst), Frag.Ins)]>,
Requires<Frag.Predicates> {
// Build PatFrag that only matches particular address space.
dag PFOperands = !con((ops node:$dst),
@@ -5376,7 +5376,7 @@ class Tcgen05MMAInst<bit Sp, string KindStr, string ASpace,
Requires<PTXPredicates> {
Intrinsic Intrin = !cast<Intrinsic>(
- NVVM_TCGEN05_MMA<Sp, ASpace, AShift, ScaleInputD>.record
+ NVVM_TCGEN05_MMA<Sp, ASpace, AShift, ScaleInputD>.record_name
);
dag ScaleInpIns = !if(!eq(ScaleInputD, 1), (ins i64imm:$scale_input_d), (ins));
@@ -5618,7 +5618,7 @@ class Tcgen05MMABlockScaleInst<bit Sp, string ASpace, string KindStr,
Requires<[hasTcgen05Instructions, PTXPredicate]> {
Intrinsic Intrin = !cast<Intrinsic>(
- NVVM_TCGEN05_MMA_BLOCKSCALE<Sp, ASpace, KindStr, ScaleVecSize>.record);
+ NVVM_TCGEN05_MMA_BLOCKSCALE<Sp, ASpace, KindStr, ScaleVecSize>.record_name);
dag SparseMetadataIns = !if(!eq(Sp, 1), (ins B32:$spmetadata), (ins));
dag SparseMetadataIntr = !if(!eq(Sp, 1), (Intrin i32:$spmetadata), (Intrin));
@@ -5702,7 +5702,7 @@ class Tcgen05MMAWSInst<bit Sp, string ASpace, string KindStr,
Requires<[hasTcgen05Instructions]> {
Intrinsic Intrin = !cast<Intrinsic>(
- NVVM_TCGEN05_MMA_WS<Sp, ASpace, HasZeroColMask>.record);
+ NVVM_TCGEN05_MMA_WS<Sp, ASpace, HasZeroColMask>.record_name);
dag ZeroColMaskIns = !if(!eq(HasZeroColMask, 1),
(ins B64:$zero_col_mask), (ins));
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td
index 4104abd..4c2f7f6 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td
@@ -482,7 +482,7 @@ let Predicates = [HasVendorXSfvfwmaccqqq] in {
defm SF_VFWMACC_4x4x4 : VPseudoSiFiveVFWMACC;
}
-let Predicates = [HasVendorXSfvfnrclipxfqf] in {
+let Predicates = [HasVendorXSfvfnrclipxfqf], AltFmtType = IS_NOT_ALTFMT in {
defm SF_VFNRCLIP_XU_F_QF : VPseudoSiFiveVFNRCLIP;
defm SF_VFNRCLIP_X_F_QF : VPseudoSiFiveVFNRCLIP;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index d91923b..56a38bb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1499,18 +1499,25 @@ static bool generateKernelClockInst(const SPIRV::IncomingCall *Call,
Register ResultReg = Call->ReturnRegister;
- // Deduce the `Scope` operand from the builtin function name.
- SPIRV::Scope::Scope ScopeArg =
- StringSwitch<SPIRV::Scope::Scope>(Builtin->Name)
- .EndsWith("device", SPIRV::Scope::Scope::Device)
- .EndsWith("work_group", SPIRV::Scope::Scope::Workgroup)
- .EndsWith("sub_group", SPIRV::Scope::Scope::Subgroup);
- Register ScopeReg = buildConstantIntReg32(ScopeArg, MIRBuilder, GR);
-
- MIRBuilder.buildInstr(SPIRV::OpReadClockKHR)
- .addDef(ResultReg)
- .addUse(GR->getSPIRVTypeID(Call->ReturnType))
- .addUse(ScopeReg);
+ if (Builtin->Name == "__spirv_ReadClockKHR") {
+ MIRBuilder.buildInstr(SPIRV::OpReadClockKHR)
+ .addDef(ResultReg)
+ .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+ .addUse(Call->Arguments[0]);
+ } else {
+ // Deduce the `Scope` operand from the builtin function name.
+ SPIRV::Scope::Scope ScopeArg =
+ StringSwitch<SPIRV::Scope::Scope>(Builtin->Name)
+ .EndsWith("device", SPIRV::Scope::Scope::Device)
+ .EndsWith("work_group", SPIRV::Scope::Scope::Workgroup)
+ .EndsWith("sub_group", SPIRV::Scope::Scope::Subgroup);
+ Register ScopeReg = buildConstantIntReg32(ScopeArg, MIRBuilder, GR);
+
+ MIRBuilder.buildInstr(SPIRV::OpReadClockKHR)
+ .addDef(ResultReg)
+ .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+ .addUse(ScopeReg);
+ }
return true;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 3b8764a..c259cce 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -1174,6 +1174,7 @@ defm : DemangledNativeBuiltin<"clock_read_sub_group", OpenCL_std, KernelClock, 0
defm : DemangledNativeBuiltin<"clock_read_hilo_device", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>;
defm : DemangledNativeBuiltin<"clock_read_hilo_work_group", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>;
defm : DemangledNativeBuiltin<"clock_read_hilo_sub_group", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>;
+defm : DemangledNativeBuiltin<"__spirv_ReadClockKHR", OpenCL_std, KernelClock, 1, 1, OpReadClockKHR>;
//===----------------------------------------------------------------------===//
// Class defining an atomic instruction on floating-point numbers.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index d49f25a..4dfc400 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -2632,6 +2632,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(Op, MVT::f32, Promote);
}
+ setOperationPromotedToType(ISD::ATOMIC_LOAD, MVT::f16, MVT::i16);
+ setOperationPromotedToType(ISD::ATOMIC_LOAD, MVT::f32, MVT::i32);
+ setOperationPromotedToType(ISD::ATOMIC_LOAD, MVT::f64, MVT::i64);
+
// We have target-specific dag combine patterns for the following nodes:
setTargetDAGCombine({ISD::VECTOR_SHUFFLE,
ISD::SCALAR_TO_VECTOR,
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index a0f7ec6..2dd0fde 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -948,17 +948,17 @@ void llvm::updateVCallVisibilityInIndex(
// linker, as we have no information on their eventual use.
if (DynamicExportSymbols.count(P.first))
continue;
+ // With validation enabled, we want to exclude symbols visible to regular
+ // objects. Local symbols will be in this group due to the current
+ // implementation but those with VCallVisibilityTranslationUnit will have
+ // already been marked in clang so are unaffected.
+ if (VisibleToRegularObjSymbols.count(P.first))
+ continue;
for (auto &S : P.second.getSummaryList()) {
auto *GVar = dyn_cast<GlobalVarSummary>(S.get());
if (!GVar ||
GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic)
continue;
- // With validation enabled, we want to exclude symbols visible to regular
- // objects. Local symbols will be in this group due to the current
- // implementation but those with VCallVisibilityTranslationUnit will have
- // already been marked in clang so are unaffected.
- if (VisibleToRegularObjSymbols.count(P.first))
- continue;
GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit);
}
}
@@ -1161,14 +1161,10 @@ bool DevirtIndex::tryFindVirtualCallTargets(
// and therefore the same GUID. This can happen if there isn't enough
// distinguishing path when compiling the source file. In that case we
// conservatively return false early.
+ if (P.VTableVI.hasLocal() && P.VTableVI.getSummaryList().size() > 1)
+ return false;
const GlobalVarSummary *VS = nullptr;
- bool LocalFound = false;
for (const auto &S : P.VTableVI.getSummaryList()) {
- if (GlobalValue::isLocalLinkage(S->linkage())) {
- if (LocalFound)
- return false;
- LocalFound = true;
- }
auto *CurVS = cast<GlobalVarSummary>(S->getBaseObject());
if (!CurVS->vTableFuncs().empty() ||
// Previously clang did not attach the necessary type metadata to
@@ -1184,6 +1180,7 @@ bool DevirtIndex::tryFindVirtualCallTargets(
// with public LTO visibility.
if (VS->getVCallVisibility() == GlobalObject::VCallVisibilityPublic)
return false;
+ break;
}
}
// There will be no VS if all copies are available_externally having no
@@ -1411,9 +1408,8 @@ bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
// If the summary list contains multiple summaries where at least one is
// a local, give up, as we won't know which (possibly promoted) name to use.
- for (const auto &S : TheFn.getSummaryList())
- if (GlobalValue::isLocalLinkage(S->linkage()) && Size > 1)
- return false;
+ if (TheFn.hasLocal() && Size > 1)
+ return false;
// Collect functions devirtualized at least for one call site for stats.
if (PrintSummaryDevirt || AreStatisticsEnabled())
@@ -2591,6 +2587,11 @@ void DevirtIndex::run() {
if (ExportSummary.typeIdCompatibleVtableMap().empty())
return;
+ // Assert that we haven't made any changes that would affect the hasLocal()
+ // flag on the GUID summary info.
+ assert(!ExportSummary.withInternalizeAndPromote() &&
+ "Expect index-based WPD to run before internalization and promotion");
+
DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID;
for (const auto &P : ExportSummary.typeIdCompatibleVtableMap()) {
NameByGUID[GlobalValue::getGUIDAssumingExternalLinkage(P.first)].push_back(
diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
index 4acc3f2..d347ced 100644
--- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -614,6 +614,16 @@ static Decomposition decompose(Value *V,
return {V, IsKnownNonNegative};
}
+ if (match(V, m_Add(m_Value(Op0), m_ConstantInt(CI))) && CI->isNegative() &&
+ canUseSExt(CI)) {
+ Preconditions.emplace_back(
+ CmpInst::ICMP_UGE, Op0,
+ ConstantInt::get(Op0->getType(), CI->getSExtValue() * -1));
+ if (auto Decomp = MergeResults(Op0, CI, true))
+ return *Decomp;
+ return {V, IsKnownNonNegative};
+ }
+
if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1)))) {
if (!isKnownNonNegative(Op0, DL))
Preconditions.emplace_back(CmpInst::ICMP_SGE, Op0,
@@ -627,16 +637,6 @@ static Decomposition decompose(Value *V,
return {V, IsKnownNonNegative};
}
- if (match(V, m_Add(m_Value(Op0), m_ConstantInt(CI))) && CI->isNegative() &&
- canUseSExt(CI)) {
- Preconditions.emplace_back(
- CmpInst::ICMP_UGE, Op0,
- ConstantInt::get(Op0->getType(), CI->getSExtValue() * -1));
- if (auto Decomp = MergeResults(Op0, CI, true))
- return *Decomp;
- return {V, IsKnownNonNegative};
- }
-
// Decompose or as an add if there are no common bits between the operands.
if (match(V, m_DisjointOr(m_Value(Op0), m_ConstantInt(CI)))) {
if (auto Decomp = MergeResults(Op0, CI, IsSigned))
diff --git a/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/llvm/lib/Transforms/Scalar/MergeICmps.cpp
index a83cbd17a7..f273e9d 100644
--- a/llvm/lib/Transforms/Scalar/MergeICmps.cpp
+++ b/llvm/lib/Transforms/Scalar/MergeICmps.cpp
@@ -64,10 +64,10 @@
using namespace llvm;
-namespace {
-
#define DEBUG_TYPE "mergeicmps"
+namespace {
+
// A BCE atom "Binary Compare Expression Atom" represents an integer load
// that is a constant offset from a base value, e.g. `a` or `o.c` in the example
// at the top.
@@ -128,11 +128,12 @@ private:
unsigned Order = 1;
DenseMap<const Value*, int> BaseToIndex;
};
+} // namespace
// If this value is a load from a constant offset w.r.t. a base address, and
// there are no other users of the load or address, returns the base address and
// the offset.
-BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) {
+static BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) {
auto *const LoadI = dyn_cast<LoadInst>(Val);
if (!LoadI)
return {};
@@ -175,6 +176,7 @@ BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) {
return BCEAtom(GEP, LoadI, BaseId.getBaseId(Base), Offset);
}
+namespace {
// A comparison between two BCE atoms, e.g. `a == o.a` in the example at the
// top.
// Note: the terminology is misleading: the comparison is symmetric, so there
@@ -239,6 +241,7 @@ class BCECmpBlock {
private:
BCECmp Cmp;
};
+} // namespace
bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst,
AliasAnalysis &AA) const {
@@ -302,9 +305,9 @@ bool BCECmpBlock::doesOtherWork() const {
// Visit the given comparison. If this is a comparison between two valid
// BCE atoms, returns the comparison.
-std::optional<BCECmp> visitICmp(const ICmpInst *const CmpI,
- const ICmpInst::Predicate ExpectedPredicate,
- BaseIdentifier &BaseId) {
+static std::optional<BCECmp>
+visitICmp(const ICmpInst *const CmpI,
+ const ICmpInst::Predicate ExpectedPredicate, BaseIdentifier &BaseId) {
// The comparison can only be used once:
// - For intermediate blocks, as a branch condition.
// - For the final block, as an incoming value for the Phi.
@@ -332,10 +335,9 @@ std::optional<BCECmp> visitICmp(const ICmpInst *const CmpI,
// Visit the given comparison block. If this is a comparison between two valid
// BCE atoms, returns the comparison.
-std::optional<BCECmpBlock> visitCmpBlock(Value *const Val,
- BasicBlock *const Block,
- const BasicBlock *const PhiBlock,
- BaseIdentifier &BaseId) {
+static std::optional<BCECmpBlock>
+visitCmpBlock(Value *const Val, BasicBlock *const Block,
+ const BasicBlock *const PhiBlock, BaseIdentifier &BaseId) {
if (Block->empty())
return std::nullopt;
auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator());
@@ -397,6 +399,7 @@ static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons,
Comparisons.push_back(std::move(Comparison));
}
+namespace {
// A chain of comparisons.
class BCECmpChain {
public:
@@ -420,6 +423,7 @@ private:
// The original entry block (before sorting);
BasicBlock *EntryBlock_;
};
+} // namespace
static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) {
return First.Lhs().BaseId == Second.Lhs().BaseId &&
@@ -742,9 +746,8 @@ bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,
return true;
}
-std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi,
- BasicBlock *const LastBlock,
- int NumBlocks) {
+static std::vector<BasicBlock *>
+getOrderedBlocks(PHINode &Phi, BasicBlock *const LastBlock, int NumBlocks) {
// Walk up from the last block to find other blocks.
std::vector<BasicBlock *> Blocks(NumBlocks);
assert(LastBlock && "invalid last block");
@@ -777,8 +780,8 @@ std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi,
return Blocks;
}
-bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI, AliasAnalysis &AA,
- DomTreeUpdater &DTU) {
+static bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI,
+ AliasAnalysis &AA, DomTreeUpdater &DTU) {
LLVM_DEBUG(dbgs() << "processPhi()\n");
if (Phi.getNumIncomingValues() <= 1) {
LLVM_DEBUG(dbgs() << "skip: only one incoming value in phi\n");
@@ -874,6 +877,7 @@ static bool runImpl(Function &F, const TargetLibraryInfo &TLI,
return MadeChange;
}
+namespace {
class MergeICmpsLegacyPass : public FunctionPass {
public:
static char ID;
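
The MergeICmps changes above are purely mechanical (file-local helpers become static or move inside the anonymous namespace), but for readers new to the pass, the BCE-atom machinery exists to turn chains of field-wise compares into a single memcmp. An illustrative sketch, not taken from the patch, assuming a padding-free struct so the byte-wise and field-wise comparisons agree:

    #include <cstring>

    struct S { int a; int b; }; // no padding: fields at offsets 0 and 4

    // A chain of BCE comparisons over contiguous fields of two bases...
    bool chainOfCmps(const S &X, const S &Y) {
      return X.a == Y.a && X.b == Y.b;
    }

    // ...which the pass may rewrite into one memcmp once it proves the loads
    // are contiguous and have no other users.
    bool mergedCmp(const S &X, const S &Y) {
      return std::memcmp(&X, &Y, sizeof(S)) == 0;
    }
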
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 06bea2f..a1ad2db 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2712,7 +2712,8 @@ public:
static inline bool classof(const VPRecipeBase *R) {
return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
- R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
+ R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
+ R->getVPDefID() == VPRecipeBase::VPPartialReductionSC;
}
static inline bool classof(const VPUser *U) {
@@ -2783,7 +2784,10 @@ public:
Opcode(Opcode), VFScaleFactor(ScaleFactor) {
[[maybe_unused]] auto *AccumulatorRecipe =
getChainOp()->getDefiningRecipe();
- assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
+ // When cloning as part of a VPExpressionRecipe the chain op could have been
+ // replaced by a temporary VPValue, so it doesn't have a defining recipe.
+ assert((!AccumulatorRecipe ||
+ isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
"Unexpected operand order for partial reduction recipe");
}
@@ -3093,6 +3097,11 @@ public:
/// removed before codegen.
void decompose();
+ unsigned getVFScaleFactor() const {
+ auto *PR = dyn_cast<VPPartialReductionRecipe>(ExpressionRecipes.back());
+ return PR ? PR->getVFScaleFactor() : 1;
+ }
+
/// Method for generating code, must not be called as this recipe is abstract.
void execute(VPTransformState &State) override {
llvm_unreachable("recipe must be removed before execute");
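
To illustrate what the new VPExpressionRecipe::getVFScaleFactor() reports (a sketch of the semantics only, not of the recipe class): a partial reduction folds several input lanes into each accumulator lane, e.g. 16 x i8 into 4 x i32, giving a scale factor of 4. One valid lane association, in scalar C++:

    #include <array>
    #include <cstdint>

    // 16 i8 input lanes accumulate into 4 i32 lanes: scale factor = 16 / 4.
    // The association below is one legal choice; targets may pick another.
    std::array<int32_t, 4> partialReduceAdd(std::array<int32_t, 4> Acc,
                                            const std::array<int8_t, 16> &In) {
      for (unsigned I = 0; I != 16; ++I)
        Acc[I % 4] += In[I];
      return Acc;
    }
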
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 1f1b42b..931a5b7 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -168,6 +168,7 @@ bool VPRecipeBase::mayHaveSideEffects() const {
return cast<VPWidenIntrinsicRecipe>(this)->mayHaveSideEffects();
case VPBlendSC:
case VPReductionEVLSC:
+ case VPPartialReductionSC:
case VPReductionSC:
case VPScalarIVStepsSC:
case VPVectorPointerSC:
@@ -300,14 +301,23 @@ InstructionCost
VPPartialReductionRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
std::optional<unsigned> Opcode;
- VPValue *Op = getOperand(0);
- VPRecipeBase *OpR = Op->getDefiningRecipe();
-
- // If the partial reduction is predicated, a select will be operand 0
- if (match(getOperand(1), m_Select(m_VPValue(), m_VPValue(Op), m_VPValue()))) {
- OpR = Op->getDefiningRecipe();
+ VPValue *Op = getVecOp();
+ uint64_t MulConst;
+ // If the partial reduction is predicated, a select will be operand 1.
+ // If it isn't predicated and the mul isn't operating on a constant, then it
+ // should have been turned into a VPExpressionRecipe.
+ // FIXME: Replace the entire function with this once all partial reduction
+ // variants are bundled into VPExpressionRecipe.
+ if (!match(Op, m_Select(m_VPValue(), m_VPValue(Op), m_VPValue())) &&
+ !match(Op, m_Mul(m_VPValue(), m_ConstantInt(MulConst)))) {
+ auto *PhiType = Ctx.Types.inferScalarType(getChainOp());
+ auto *InputType = Ctx.Types.inferScalarType(getVecOp());
+ return Ctx.TTI.getPartialReductionCost(getOpcode(), InputType, InputType,
+ PhiType, VF, TTI::PR_None,
+ TTI::PR_None, {}, Ctx.CostKind);
}
+ VPRecipeBase *OpR = Op->getDefiningRecipe();
Type *InputTypeA = nullptr, *InputTypeB = nullptr;
TTI::PartialReductionExtendKind ExtAType = TTI::PR_None,
ExtBType = TTI::PR_None;
@@ -2856,11 +2866,19 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
cast<VPReductionRecipe>(ExpressionRecipes.back())->getRecurrenceKind());
switch (ExpressionType) {
case ExpressionTypes::ExtendedReduction: {
- return Ctx.TTI.getExtendedReductionCost(
- Opcode,
- cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
- Instruction::ZExt,
- RedTy, SrcVecTy, std::nullopt, Ctx.CostKind);
+ unsigned Opcode = RecurrenceDescriptor::getOpcode(
+ cast<VPReductionRecipe>(ExpressionRecipes[1])->getRecurrenceKind());
+ auto *ExtR = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
+ return isa<VPPartialReductionRecipe>(ExpressionRecipes.back())
+ ? Ctx.TTI.getPartialReductionCost(
+ Opcode, Ctx.Types.inferScalarType(getOperand(0)), nullptr,
+ RedTy, VF,
+ TargetTransformInfo::getPartialReductionExtendKind(
+ ExtR->getOpcode()),
+ TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind)
+ : Ctx.TTI.getExtendedReductionCost(
+ Opcode, ExtR->getOpcode() == Instruction::ZExt, RedTy,
+ SrcVecTy, std::nullopt, Ctx.CostKind);
}
case ExpressionTypes::MulAccReduction:
return Ctx.TTI.getMulAccReductionCost(false, Opcode, RedTy, SrcVecTy,
@@ -2871,6 +2889,19 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
Opcode = Instruction::Sub;
[[fallthrough]];
case ExpressionTypes::ExtMulAccReduction: {
+ if (isa<VPPartialReductionRecipe>(ExpressionRecipes.back())) {
+ auto *Ext0R = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
+ auto *Ext1R = cast<VPWidenCastRecipe>(ExpressionRecipes[1]);
+ auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
+ return Ctx.TTI.getPartialReductionCost(
+ Opcode, Ctx.Types.inferScalarType(getOperand(0)),
+ Ctx.Types.inferScalarType(getOperand(1)), RedTy, VF,
+ TargetTransformInfo::getPartialReductionExtendKind(
+ Ext0R->getOpcode()),
+ TargetTransformInfo::getPartialReductionExtendKind(
+ Ext1R->getOpcode()),
+ Mul->getOpcode(), Ctx.CostKind);
+ }
return Ctx.TTI.getMulAccReductionCost(
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
Instruction::ZExt,
@@ -2910,12 +2941,13 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
O << " = ";
auto *Red = cast<VPReductionRecipe>(ExpressionRecipes.back());
unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
+ bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red);
switch (ExpressionType) {
case ExpressionTypes::ExtendedReduction: {
getOperand(1)->printAsOperand(O, SlotTracker);
- O << " +";
- O << " reduce." << Instruction::getOpcodeName(Opcode) << " (";
+ O << " + " << (IsPartialReduction ? "partial." : "") << "reduce.";
+ O << Instruction::getOpcodeName(Opcode) << " (";
getOperand(0)->printAsOperand(O, SlotTracker);
Red->printFlags(O);
@@ -2931,8 +2963,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
}
case ExpressionTypes::ExtNegatedMulAccReduction: {
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
- O << " + reduce."
- << Instruction::getOpcodeName(
+ O << " + " << (IsPartialReduction ? "partial." : "") << "reduce.";
+ O << Instruction::getOpcodeName(
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
<< " (sub (0, mul";
auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
@@ -2956,9 +2988,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
case ExpressionTypes::MulAccReduction:
case ExpressionTypes::ExtMulAccReduction: {
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
- O << " + ";
- O << "reduce."
- << Instruction::getOpcodeName(
+ O << " + " << (IsPartialReduction ? "partial." : "") << "reduce.";
+ O << Instruction::getOpcodeName(
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
<< " (";
O << "mul";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index f5a3af4..3e85e6f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -3519,18 +3519,31 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
VPValue *VecOp = Red->getVecOp();
// Clamp the range if using extended-reduction is profitable.
- auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt,
- Type *SrcTy) -> bool {
+ auto IsExtendedRedValidAndClampRange =
+ [&](unsigned Opcode, Instruction::CastOps ExtOpc, Type *SrcTy) -> bool {
return LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) {
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
- InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost(
- Opcode, isZExt, RedTy, SrcVecTy, Red->getFastMathFlags(),
- CostKind);
+
+ InstructionCost ExtRedCost;
InstructionCost ExtCost =
cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
InstructionCost RedCost = Red->computeCost(VF, Ctx);
+
+ if (isa<VPPartialReductionRecipe>(Red)) {
+ TargetTransformInfo::PartialReductionExtendKind ExtKind =
+ TargetTransformInfo::getPartialReductionExtendKind(ExtOpc);
+ // FIXME: Move partial reduction creation, costing and clamping
+ // here from LoopVectorize.cpp.
+ ExtRedCost = Ctx.TTI.getPartialReductionCost(
+ Opcode, SrcTy, nullptr, RedTy, VF, ExtKind,
+ llvm::TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind);
+ } else {
+ ExtRedCost = Ctx.TTI.getExtendedReductionCost(
+ Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
+ Red->getFastMathFlags(), CostKind);
+ }
return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost;
},
Range);
@@ -3541,8 +3554,7 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
IsExtendedRedValidAndClampRange(
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()),
- cast<VPWidenCastRecipe>(VecOp)->getOpcode() ==
- Instruction::CastOps::ZExt,
+ cast<VPWidenCastRecipe>(VecOp)->getOpcode(),
Ctx.Types.inferScalarType(A)))
return new VPExpressionRecipe(cast<VPWidenCastRecipe>(VecOp), Red);
@@ -3560,6 +3572,8 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
static VPExpressionRecipe *
tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
VPCostContext &Ctx, VFRange &Range) {
+ bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red);
+
unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
return nullptr;
@@ -3568,16 +3582,41 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
// Clamp the range if using multiply-accumulate-reduction is profitable.
auto IsMulAccValidAndClampRange =
- [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
- VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
+ [&](VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1,
+ VPWidenCastRecipe *OuterExt) -> bool {
return LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) {
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
Type *SrcTy =
Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
- auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
- InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
- isZExt, Opcode, RedTy, SrcVecTy, CostKind);
+ InstructionCost MulAccCost;
+
+ if (IsPartialReduction) {
+ Type *SrcTy2 =
+ Ext1 ? Ctx.Types.inferScalarType(Ext1->getOperand(0)) : nullptr;
+ // FIXME: Move partial reduction creation, costing and clamping
+ // here from LoopVectorize.cpp.
+ MulAccCost = Ctx.TTI.getPartialReductionCost(
+ Opcode, SrcTy, SrcTy2, RedTy, VF,
+ Ext0 ? TargetTransformInfo::getPartialReductionExtendKind(
+ Ext0->getOpcode())
+ : TargetTransformInfo::PR_None,
+ Ext1 ? TargetTransformInfo::getPartialReductionExtendKind(
+ Ext1->getOpcode())
+ : TargetTransformInfo::PR_None,
+ Mul->getOpcode(), CostKind);
+ } else {
+ // Only partial reductions support mixed extends at the moment.
+ if (Ext0 && Ext1 && Ext0->getOpcode() != Ext1->getOpcode())
+ return false;
+
+ bool IsZExt =
+ !Ext0 || Ext0->getOpcode() == Instruction::CastOps::ZExt;
+ auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
+ MulAccCost = Ctx.TTI.getMulAccReductionCost(IsZExt, Opcode, RedTy,
+ SrcVecTy, CostKind);
+ }
+
InstructionCost MulCost = Mul->computeCost(VF, Ctx);
InstructionCost RedCost = Red->computeCost(VF, Ctx);
InstructionCost ExtCost = 0;
@@ -3611,14 +3650,10 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
- // Match reduce.add(mul(ext, ext)).
- if (RecipeA && RecipeB &&
- (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B) &&
- match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+ // Match reduce.add/sub(mul(ext, ext)).
+ if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
- IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
- Instruction::CastOps::ZExt,
- Mul, RecipeA, RecipeB, nullptr)) {
+ IsMulAccValidAndClampRange(Mul, RecipeA, RecipeB, nullptr)) {
if (Sub)
return new VPExpressionRecipe(RecipeA, RecipeB, Mul,
cast<VPWidenRecipe>(Sub), Red);
@@ -3626,8 +3661,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
}
// Match reduce.add(mul).
// TODO: Add an expression type for this variant with a negated mul
- if (!Sub &&
- IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
+ if (!Sub && IsMulAccValidAndClampRange(Mul, nullptr, nullptr, nullptr))
return new VPExpressionRecipe(Mul, Red);
}
// TODO: Add an expression type for negated versions of other expression
@@ -3647,9 +3681,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
Ext0->getOpcode() == Ext1->getOpcode() &&
- IsMulAccValidAndClampRange(Ext0->getOpcode() ==
- Instruction::CastOps::ZExt,
- Mul, Ext0, Ext1, Ext)) {
+ IsMulAccValidAndClampRange(Mul, Ext0, Ext1, Ext) && Mul->hasOneUse()) {
auto *NewExt0 = new VPWidenCastRecipe(
Ext0->getOpcode(), Ext0->getOperand(0), Ext->getResultType(), *Ext0,
*Ext0, Ext0->getDebugLoc());
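
Both lambdas above share the same shape of profitability check: bundle into a VPExpressionRecipe only when the fused cost is valid and strictly beats the sum of the component costs. A minimal sketch with placeholder names (for the extended-reduction case MulCost is simply zero):

    #include "llvm/Support/InstructionCost.h"
    using llvm::InstructionCost;

    // Fuse only when the combined recipe is both costable and cheaper than
    // keeping the extend(s), multiply and reduction as separate recipes.
    static bool shouldBundle(InstructionCost FusedCost, InstructionCost ExtCost,
                             InstructionCost MulCost, InstructionCost RedCost) {
      return FusedCost.isValid() && FusedCost < ExtCost + MulCost + RedCost;
    }
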
diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
index 32e4b88..06c3d75 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
@@ -151,6 +151,8 @@ unsigned vputils::getVFScaleFactor(VPRecipeBase *R) {
return RR->getVFScaleFactor();
if (auto *RR = dyn_cast<VPPartialReductionRecipe>(R))
return RR->getVFScaleFactor();
+ if (auto *ER = dyn_cast<VPExpressionRecipe>(R))
+ return ER->getVFScaleFactor();
assert(
(!isa<VPInstruction>(R) || cast<VPInstruction>(R)->getOpcode() !=
VPInstruction::ReductionStartVector) &&