diff options
Diffstat (limited to 'llvm/lib')
32 files changed, 960 insertions, 152 deletions
diff --git a/llvm/lib/Analysis/LoopCacheAnalysis.cpp b/llvm/lib/Analysis/LoopCacheAnalysis.cpp index 050c327..424a7fe 100644 --- a/llvm/lib/Analysis/LoopCacheAnalysis.cpp +++ b/llvm/lib/Analysis/LoopCacheAnalysis.cpp @@ -436,10 +436,9 @@ bool IndexedReference::delinearize(const LoopInfo &LI) { const SCEV *StepRec = AccessFnAR ? AccessFnAR->getStepRecurrence(SE) : nullptr; if (StepRec && SE.isKnownNegative(StepRec)) - AccessFn = SE.getAddRecExpr(AccessFnAR->getStart(), - SE.getNegativeSCEV(StepRec), - AccessFnAR->getLoop(), - AccessFnAR->getNoWrapFlags()); + AccessFn = SE.getAddRecExpr( + AccessFnAR->getStart(), SE.getNegativeSCEV(StepRec), + AccessFnAR->getLoop(), SCEV::NoWrapFlags::FlagAnyWrap); const SCEV *Div = SE.getUDivExactExpr(AccessFn, ElemSize); Subscripts.push_back(Div); Sizes.push_back(ElemSize); diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index bf62623..c47a1c1 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -1001,13 +1001,25 @@ InstructionCost TargetTransformInfo::getShuffleCost( TargetTransformInfo::PartialReductionExtendKind TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) { - if (isa<SExtInst>(I)) - return PR_SignExtend; - if (isa<ZExtInst>(I)) - return PR_ZeroExtend; + if (auto *Cast = dyn_cast<CastInst>(I)) + return getPartialReductionExtendKind(Cast->getOpcode()); return PR_None; } +TargetTransformInfo::PartialReductionExtendKind +TargetTransformInfo::getPartialReductionExtendKind( + Instruction::CastOps CastOpc) { + switch (CastOpc) { + case Instruction::CastOps::ZExt: + return PR_ZeroExtend; + case Instruction::CastOps::SExt: + return PR_SignExtend; + default: + return PR_None; + } + llvm_unreachable("Unhandled cast opcode"); +} + TTI::CastContextHint TargetTransformInfo::getCastContextHint(const Instruction *I) { if (!I) diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp index fefde64f..8aa488f 100644 --- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp @@ -41,6 +41,7 @@ #include "llvm/CodeGen/GCMetadataPrinter.h" #include "llvm/CodeGen/LazyMachineBlockFrequencyInfo.h" #include "llvm/CodeGen/MachineBasicBlock.h" +#include "llvm/CodeGen/MachineBlockHashInfo.h" #include "llvm/CodeGen/MachineBranchProbabilityInfo.h" #include "llvm/CodeGen/MachineConstantPool.h" #include "llvm/CodeGen/MachineDominators.h" @@ -184,6 +185,8 @@ static cl::opt<bool> PrintLatency( cl::desc("Print instruction latencies as verbose asm comments"), cl::Hidden, cl::init(false)); +extern cl::opt<bool> EmitBBHash; + STATISTIC(EmittedInsts, "Number of machine instrs printed"); char AsmPrinter::ID = 0; @@ -474,6 +477,8 @@ void AsmPrinter::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<GCModuleInfo>(); AU.addRequired<LazyMachineBlockFrequencyInfoPass>(); AU.addRequired<MachineBranchProbabilityInfoWrapperPass>(); + if (EmitBBHash) + AU.addRequired<MachineBlockHashInfo>(); } bool AsmPrinter::doInitialization(Module &M) { @@ -1434,14 +1439,11 @@ getBBAddrMapFeature(const MachineFunction &MF, int NumMBBSectionRanges, "BB entries info is required for BBFreq and BrProb " "features"); } - return {FuncEntryCountEnabled, - BBFreqEnabled, - BrProbEnabled, + return {FuncEntryCountEnabled, BBFreqEnabled, BrProbEnabled, MF.hasBBSections() && NumMBBSectionRanges > 1, // Use static_cast to avoid breakage of tests on windows. - static_cast<bool>(BBAddrMapSkipEmitBBEntries), - HasCalls, - false}; + static_cast<bool>(BBAddrMapSkipEmitBBEntries), HasCalls, + static_cast<bool>(EmitBBHash)}; } void AsmPrinter::emitBBAddrMapSection(const MachineFunction &MF) { @@ -1500,6 +1502,9 @@ void AsmPrinter::emitBBAddrMapSection(const MachineFunction &MF) { PrevMBBEndSymbol = MBBSymbol; } + auto MBHI = + Features.BBHash ? &getAnalysis<MachineBlockHashInfo>() : nullptr; + if (!Features.OmitBBEntries) { OutStreamer->AddComment("BB id"); // Emit the BB ID for this basic block. @@ -1527,6 +1532,10 @@ void AsmPrinter::emitBBAddrMapSection(const MachineFunction &MF) { emitLabelDifferenceAsULEB128(MBB.getEndSymbol(), CurrentLabel); // Emit the Metadata. OutStreamer->emitULEB128IntValue(getBBAddrMapMetadata(MBB)); + // Emit the Hash. + if (MBHI) { + OutStreamer->emitInt64(MBHI->getMBBHash(MBB)); + } } PrevMBBEndSymbol = MBB.getEndSymbol(); } diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt index b6872605..4373c53 100644 --- a/llvm/lib/CodeGen/CMakeLists.txt +++ b/llvm/lib/CodeGen/CMakeLists.txt @@ -108,6 +108,7 @@ add_llvm_component_library(LLVMCodeGen LowerEmuTLS.cpp MachineBasicBlock.cpp MachineBlockFrequencyInfo.cpp + MachineBlockHashInfo.cpp MachineBlockPlacement.cpp MachineBranchProbabilityInfo.cpp MachineCFGPrinter.cpp diff --git a/llvm/lib/CodeGen/MachineBlockHashInfo.cpp b/llvm/lib/CodeGen/MachineBlockHashInfo.cpp new file mode 100644 index 0000000..c4d9c0f --- /dev/null +++ b/llvm/lib/CodeGen/MachineBlockHashInfo.cpp @@ -0,0 +1,115 @@ +//===- llvm/CodeGen/MachineBlockHashInfo.cpp---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Compute the hashes of basic blocks. +// +//===----------------------------------------------------------------------===// + +#include "llvm/CodeGen/MachineBlockHashInfo.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/InitializePasses.h" +#include "llvm/Target/TargetMachine.h" + +using namespace llvm; + +uint64_t hashBlock(const MachineBasicBlock &MBB, bool HashOperands) { + uint64_t Hash = 0; + for (const MachineInstr &MI : MBB) { + if (MI.isMetaInstruction() || MI.isTerminator()) + continue; + Hash = hashing::detail::hash_16_bytes(Hash, MI.getOpcode()); + if (HashOperands) { + for (unsigned i = 0; i < MI.getNumOperands(); i++) { + Hash = + hashing::detail::hash_16_bytes(Hash, hash_value(MI.getOperand(i))); + } + } + } + return Hash; +} + +/// Fold a 64-bit integer to a 16-bit one. +uint16_t fold_64_to_16(const uint64_t Value) { + uint16_t Res = static_cast<uint16_t>(Value); + Res ^= static_cast<uint16_t>(Value >> 16); + Res ^= static_cast<uint16_t>(Value >> 32); + Res ^= static_cast<uint16_t>(Value >> 48); + return Res; +} + +INITIALIZE_PASS(MachineBlockHashInfo, "machine-block-hash", + "Machine Block Hash Analysis", true, true) + +char MachineBlockHashInfo::ID = 0; + +MachineBlockHashInfo::MachineBlockHashInfo() : MachineFunctionPass(ID) { + initializeMachineBlockHashInfoPass(*PassRegistry::getPassRegistry()); +} + +void MachineBlockHashInfo::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + MachineFunctionPass::getAnalysisUsage(AU); +} + +struct CollectHashInfo { + uint64_t Offset; + uint64_t OpcodeHash; + uint64_t InstrHash; + uint64_t NeighborHash; +}; + +bool MachineBlockHashInfo::runOnMachineFunction(MachineFunction &F) { + DenseMap<const MachineBasicBlock *, CollectHashInfo> HashInfos; + uint16_t Offset = 0; + // Initialize hash components + for (const MachineBasicBlock &MBB : F) { + // offset of the machine basic block + HashInfos[&MBB].Offset = Offset; + Offset += MBB.size(); + // Hashing opcodes + HashInfos[&MBB].OpcodeHash = hashBlock(MBB, /*HashOperands=*/false); + // Hash complete instructions + HashInfos[&MBB].InstrHash = hashBlock(MBB, /*HashOperands=*/true); + } + + // Initialize neighbor hash + for (const MachineBasicBlock &MBB : F) { + uint64_t Hash = HashInfos[&MBB].OpcodeHash; + // Append hashes of successors + for (const MachineBasicBlock *SuccMBB : MBB.successors()) { + uint64_t SuccHash = HashInfos[SuccMBB].OpcodeHash; + Hash = hashing::detail::hash_16_bytes(Hash, SuccHash); + } + // Append hashes of predecessors + for (const MachineBasicBlock *PredMBB : MBB.predecessors()) { + uint64_t PredHash = HashInfos[PredMBB].OpcodeHash; + Hash = hashing::detail::hash_16_bytes(Hash, PredHash); + } + HashInfos[&MBB].NeighborHash = Hash; + } + + // Assign hashes + for (const MachineBasicBlock &MBB : F) { + const auto &HashInfo = HashInfos[&MBB]; + BlendedBlockHash BlendedHash(fold_64_to_16(HashInfo.Offset), + fold_64_to_16(HashInfo.OpcodeHash), + fold_64_to_16(HashInfo.InstrHash), + fold_64_to_16(HashInfo.NeighborHash)); + MBBHashInfo[&MBB] = BlendedHash.combine(); + } + + return false; +} + +uint64_t MachineBlockHashInfo::getMBBHash(const MachineBasicBlock &MBB) { + return MBBHashInfo[&MBB]; +} + +MachineFunctionPass *llvm::createMachineBlockHashInfoPass() { + return new MachineBlockHashInfo(); +} diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp index b6169e6..10b7238 100644 --- a/llvm/lib/CodeGen/TargetPassConfig.cpp +++ b/llvm/lib/CodeGen/TargetPassConfig.cpp @@ -272,6 +272,12 @@ static cl::opt<bool> cl::desc("Split static data sections into hot and cold " "sections using profile information")); +cl::opt<bool> EmitBBHash( + "emit-bb-hash", + cl::desc( + "Emit the hash of basic block in the SHT_LLVM_BB_ADDR_MAP section."), + cl::init(false), cl::Optional); + /// Allow standard passes to be disabled by command line options. This supports /// simple binary flags that either suppress the pass or do nothing. /// i.e. -disable-mypass=false has no effect. @@ -1281,6 +1287,8 @@ void TargetPassConfig::addMachinePasses() { // address map (or both). if (TM->getBBSectionsType() != llvm::BasicBlockSection::None || TM->Options.BBAddrMap) { + if (EmitBBHash) + addPass(llvm::createMachineBlockHashInfoPass()); if (TM->getBBSectionsType() == llvm::BasicBlockSection::List) { addPass(llvm::createBasicBlockSectionsProfileReaderWrapperPass( TM->getBBSectionsFuncListBuf())); diff --git a/llvm/lib/ExecutionEngine/JITLink/JITLinkMemoryManager.cpp b/llvm/lib/ExecutionEngine/JITLink/JITLinkMemoryManager.cpp index 6c7e27e..fa04976 100644 --- a/llvm/lib/ExecutionEngine/JITLink/JITLinkMemoryManager.cpp +++ b/llvm/lib/ExecutionEngine/JITLink/JITLinkMemoryManager.cpp @@ -247,7 +247,7 @@ public: StandardSegments(std::move(StandardSegments)), FinalizationSegments(std::move(FinalizationSegments)) {} - ~IPInFlightAlloc() { + ~IPInFlightAlloc() override { assert(!G && "InFlight alloc neither abandoned nor finalized"); } diff --git a/llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp b/llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp index 75ae80f..4ceff48 100644 --- a/llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp +++ b/llvm/lib/ExecutionEngine/Orc/Debugging/DebuggerSupportPlugin.cpp @@ -38,7 +38,7 @@ public: MachODebugObjectSynthesizerBase(LinkGraph &G, ExecutorAddr RegisterActionAddr) : G(G), RegisterActionAddr(RegisterActionAddr) {} - virtual ~MachODebugObjectSynthesizerBase() = default; + ~MachODebugObjectSynthesizerBase() override = default; Error preserveDebugSections() { if (G.findSectionByName(SynthDebugSectionName)) { diff --git a/llvm/lib/ExecutionEngine/Orc/LinkGraphLinkingLayer.cpp b/llvm/lib/ExecutionEngine/Orc/LinkGraphLinkingLayer.cpp index d1a6eaf..a2990ab 100644 --- a/llvm/lib/ExecutionEngine/Orc/LinkGraphLinkingLayer.cpp +++ b/llvm/lib/ExecutionEngine/Orc/LinkGraphLinkingLayer.cpp @@ -55,7 +55,7 @@ public: Plugins = Layer.Plugins; } - ~JITLinkCtx() { + ~JITLinkCtx() override { // If there is an object buffer return function then use it to // return ownership of the buffer. if (Layer.ReturnObjectBuffer && ObjBuffer) diff --git a/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp b/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp index fd805fbf..cdde733 100644 --- a/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp +++ b/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp @@ -92,7 +92,7 @@ public: Name(std::move(Name)), Ctx(Ctx), Materialize(Materialize), Discard(Discard), Destroy(Destroy) {} - ~OrcCAPIMaterializationUnit() { + ~OrcCAPIMaterializationUnit() override { if (Ctx) Destroy(Ctx); } @@ -264,7 +264,7 @@ public: LLVMOrcCAPIDefinitionGeneratorTryToGenerateFunction TryToGenerate) : Dispose(Dispose), Ctx(Ctx), TryToGenerate(TryToGenerate) {} - ~CAPIDefinitionGenerator() { + ~CAPIDefinitionGenerator() override { if (Dispose) Dispose(Ctx); } diff --git a/llvm/lib/LTO/LTO.cpp b/llvm/lib/LTO/LTO.cpp index 86780e1..9d0fa11 100644 --- a/llvm/lib/LTO/LTO.cpp +++ b/llvm/lib/LTO/LTO.cpp @@ -2224,6 +2224,7 @@ class OutOfProcessThinBackend : public CGThinBackend { ArrayRef<StringRef> DistributorArgs; SString RemoteCompiler; + ArrayRef<StringRef> RemoteCompilerPrependArgs; ArrayRef<StringRef> RemoteCompilerArgs; bool SaveTemps; @@ -2260,12 +2261,14 @@ public: bool ShouldEmitIndexFiles, bool ShouldEmitImportsFiles, StringRef LinkerOutputFile, StringRef Distributor, ArrayRef<StringRef> DistributorArgs, StringRef RemoteCompiler, + ArrayRef<StringRef> RemoteCompilerPrependArgs, ArrayRef<StringRef> RemoteCompilerArgs, bool SaveTemps) : CGThinBackend(Conf, CombinedIndex, ModuleToDefinedGVSummaries, AddStream, OnWrite, ShouldEmitIndexFiles, ShouldEmitImportsFiles, ThinLTOParallelism), LinkerOutputFile(LinkerOutputFile), DistributorPath(Distributor), DistributorArgs(DistributorArgs), RemoteCompiler(RemoteCompiler), + RemoteCompilerPrependArgs(RemoteCompilerPrependArgs), RemoteCompilerArgs(RemoteCompilerArgs), SaveTemps(SaveTemps) {} virtual void setup(unsigned ThinLTONumTasks, unsigned ThinLTOTaskOffset, @@ -2387,6 +2390,11 @@ public: JOS.attributeArray("args", [&]() { JOS.value(RemoteCompiler); + // Forward any supplied prepend options. + if (!RemoteCompilerPrependArgs.empty()) + for (auto &A : RemoteCompilerPrependArgs) + JOS.value(A); + JOS.value("-c"); JOS.value(Saver.save("--target=" + Triple.str())); @@ -2517,6 +2525,7 @@ ThinBackend lto::createOutOfProcessThinBackend( bool ShouldEmitIndexFiles, bool ShouldEmitImportsFiles, StringRef LinkerOutputFile, StringRef Distributor, ArrayRef<StringRef> DistributorArgs, StringRef RemoteCompiler, + ArrayRef<StringRef> RemoteCompilerPrependArgs, ArrayRef<StringRef> RemoteCompilerArgs, bool SaveTemps) { auto Func = [=](const Config &Conf, ModuleSummaryIndex &CombinedIndex, @@ -2526,7 +2535,7 @@ ThinBackend lto::createOutOfProcessThinBackend( Conf, CombinedIndex, Parallelism, ModuleToDefinedGVSummaries, AddStream, OnWrite, ShouldEmitIndexFiles, ShouldEmitImportsFiles, LinkerOutputFile, Distributor, DistributorArgs, RemoteCompiler, - RemoteCompilerArgs, SaveTemps); + RemoteCompilerPrependArgs, RemoteCompilerArgs, SaveTemps); }; return ThinBackend(Func, Parallelism); } diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index f788c75..92f260f 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -4005,24 +4005,20 @@ def : Pat<(i64 (zextloadi32 (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset))), (SUBREG_TO_REG (i64 0), (LDRWui GPR64sp:$Rn, uimm12s4:$offset), sub_32)>; // load zero-extended i32, bitcast to f64 -def : Pat <(f64 (bitconvert (i64 (zextloadi32 (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset))))), - (SUBREG_TO_REG (i64 0), (LDRSui GPR64sp:$Rn, uimm12s4:$offset), ssub)>; - +def : Pat<(f64 (bitconvert (i64 (zextloadi32 (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset))))), + (SUBREG_TO_REG (i64 0), (LDRSui GPR64sp:$Rn, uimm12s4:$offset), ssub)>; // load zero-extended i16, bitcast to f64 -def : Pat <(f64 (bitconvert (i64 (zextloadi16 (am_indexed32 GPR64sp:$Rn, uimm12s2:$offset))))), - (SUBREG_TO_REG (i64 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>; - +def : Pat<(f64 (bitconvert (i64 (zextloadi16 (am_indexed16 GPR64sp:$Rn, uimm12s2:$offset))))), + (SUBREG_TO_REG (i64 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>; // load zero-extended i8, bitcast to f64 -def : Pat <(f64 (bitconvert (i64 (zextloadi8 (am_indexed32 GPR64sp:$Rn, uimm12s1:$offset))))), - (SUBREG_TO_REG (i64 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>; - +def : Pat<(f64 (bitconvert (i64 (zextloadi8 (am_indexed8 GPR64sp:$Rn, uimm12s1:$offset))))), + (SUBREG_TO_REG (i64 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>; // load zero-extended i16, bitcast to f32 -def : Pat <(f32 (bitconvert (i32 (zextloadi16 (am_indexed16 GPR64sp:$Rn, uimm12s2:$offset))))), - (SUBREG_TO_REG (i32 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>; - +def : Pat<(f32 (bitconvert (i32 (zextloadi16 (am_indexed16 GPR64sp:$Rn, uimm12s2:$offset))))), + (SUBREG_TO_REG (i32 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>; // load zero-extended i8, bitcast to f32 -def : Pat <(f32 (bitconvert (i32 (zextloadi8 (am_indexed16 GPR64sp:$Rn, uimm12s1:$offset))))), - (SUBREG_TO_REG (i32 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>; +def : Pat<(f32 (bitconvert (i32 (zextloadi8 (am_indexed8 GPR64sp:$Rn, uimm12s1:$offset))))), + (SUBREG_TO_REG (i32 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>; // Pre-fetch. def PRFMui : PrefetchUI<0b11, 0, 0b10, "prfm", diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index e3370d3..2053fc4 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -1577,18 +1577,26 @@ static SVEIntrinsicInfo constructSVEIntrinsicInfo(IntrinsicInst &II) { } static bool isAllActivePredicate(Value *Pred) { - // Look through convert.from.svbool(convert.to.svbool(...) chain. Value *UncastedPred; + + // Look through predicate casts that only remove lanes. if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_from_svbool>( - m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>( - m_Value(UncastedPred))))) - // If the predicate has the same or less lanes than the uncasted - // predicate then we know the casting has no effect. - if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <= - cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements()) - Pred = UncastedPred; + m_Value(UncastedPred)))) { + auto *OrigPredTy = cast<ScalableVectorType>(Pred->getType()); + Pred = UncastedPred; + + if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>( + m_Value(UncastedPred)))) + // If the predicate has the same or less lanes than the uncasted predicate + // then we know the casting has no effect. + if (OrigPredTy->getMinNumElements() <= + cast<ScalableVectorType>(UncastedPred->getType()) + ->getMinNumElements()) + Pred = UncastedPred; + } + auto *C = dyn_cast<Constant>(Pred); - return (C && C->isAllOnesValue()); + return C && C->isAllOnesValue(); } // Simplify `V` by only considering the operations that affect active lanes. diff --git a/llvm/lib/Target/ARM/ARMAsmPrinter.cpp b/llvm/lib/Target/ARM/ARMAsmPrinter.cpp index 3368a50..36b9908 100644 --- a/llvm/lib/Target/ARM/ARMAsmPrinter.cpp +++ b/llvm/lib/Target/ARM/ARMAsmPrinter.cpp @@ -1471,6 +1471,435 @@ void ARMAsmPrinter::EmitUnwindingInstruction(const MachineInstr *MI) { // instructions) auto-generated. #include "ARMGenMCPseudoLowering.inc" +// Helper function to check if a register is live (used as an implicit operand) +// in the given call instruction. +static bool isRegisterLiveInCall(const MachineInstr &Call, MCRegister Reg) { + for (const MachineOperand &MO : Call.implicit_operands()) { + if (MO.isReg() && MO.getReg() == Reg && MO.isUse()) { + return true; + } + } + return false; +} + +void ARMAsmPrinter::EmitKCFI_CHECK_ARM32(Register AddrReg, int64_t Type, + const MachineInstr &Call, + int64_t PrefixNops) { + // Choose scratch register: r12 primary, r3 if target is r12. + unsigned ScratchReg = ARM::R12; + if (AddrReg == ARM::R12) { + ScratchReg = ARM::R3; + } + + // Calculate ESR for ARM mode (16-bit): 0x8000 | (scratch_reg << 5) | addr_reg + // Note: scratch_reg is always 0x1F since the EOR sequence clobbers it. + const ARMBaseRegisterInfo *TRI = static_cast<const ARMBaseRegisterInfo *>( + MF->getSubtarget().getRegisterInfo()); + unsigned AddrIndex = TRI->getEncodingValue(AddrReg); + unsigned ESR = 0x8000 | (31 << 5) | (AddrIndex & 31); + + // Check if r3 is live and needs to be spilled. + bool NeedSpillR3 = + (ScratchReg == ARM::R3) && isRegisterLiveInCall(Call, ARM::R3); + + // If we need to spill r3, push it first. + if (NeedSpillR3) { + // push {r3} + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::STMDB_UPD) + .addReg(ARM::SP) + .addReg(ARM::SP) + .addImm(ARMCC::AL) + .addReg(0) + .addReg(ARM::R3)); + } + + // Clear bit 0 of target address to handle Thumb function pointers. + // In 32-bit ARM, function pointers may have the low bit set to indicate + // Thumb state when ARM/Thumb interworking is enabled (ARMv4T and later). + // We need to clear it to avoid an alignment fault when loading. + // bic scratch, target, #1 + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::BICri) + .addReg(ScratchReg) + .addReg(AddrReg) + .addImm(1) + .addImm(ARMCC::AL) + .addReg(0) + .addReg(0)); + + // ldr scratch, [scratch, #-(PrefixNops * 4 + 4)] + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::LDRi12) + .addReg(ScratchReg) + .addReg(ScratchReg) + .addImm(-(PrefixNops * 4 + 4)) + .addImm(ARMCC::AL) + .addReg(0)); + + // Each EOR instruction XORs one byte of the type, shifted to its position. + for (int i = 0; i < 4; i++) { + uint8_t byte = (Type >> (i * 8)) & 0xFF; + uint32_t imm = byte << (i * 8); + bool isLast = (i == 3); + + // Encode as ARM modified immediate. + int SOImmVal = ARM_AM::getSOImmVal(imm); + assert(SOImmVal != -1 && + "Cannot encode immediate as ARM modified immediate"); + + // eor[s] scratch, scratch, #imm (last one sets flags with CPSR) + EmitToStreamer(*OutStreamer, + MCInstBuilder(ARM::EORri) + .addReg(ScratchReg) + .addReg(ScratchReg) + .addImm(SOImmVal) + .addImm(ARMCC::AL) + .addReg(0) + .addReg(isLast ? ARM::CPSR : ARM::NoRegister)); + } + + // If we spilled r3, restore it immediately after the comparison. + // This must happen before the branch so r3 is valid on both paths. + if (NeedSpillR3) { + // pop {r3} + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::LDMIA_UPD) + .addReg(ARM::SP) + .addReg(ARM::SP) + .addImm(ARMCC::AL) + .addReg(0) + .addReg(ARM::R3)); + } + + // beq .Lpass (branch if types match, i.e., scratch is zero) + MCSymbol *Pass = OutContext.createTempSymbol(); + EmitToStreamer(*OutStreamer, + MCInstBuilder(ARM::Bcc) + .addExpr(MCSymbolRefExpr::create(Pass, OutContext)) + .addImm(ARMCC::EQ) + .addReg(ARM::CPSR)); + + // udf #ESR (trap with encoded diagnostic) + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::UDF).addImm(ESR)); + + OutStreamer->emitLabel(Pass); +} + +void ARMAsmPrinter::EmitKCFI_CHECK_Thumb2(Register AddrReg, int64_t Type, + const MachineInstr &Call, + int64_t PrefixNops) { + // Choose scratch register: r12 primary, r3 if target is r12. + unsigned ScratchReg = ARM::R12; + if (AddrReg == ARM::R12) { + ScratchReg = ARM::R3; + } + + // Calculate ESR for Thumb mode (8-bit): 0x80 | addr_reg + // Bit 7: KCFI trap indicator + // Bits 6-5: Reserved + // Bits 4-0: Address register encoding + const ARMBaseRegisterInfo *TRI = static_cast<const ARMBaseRegisterInfo *>( + MF->getSubtarget().getRegisterInfo()); + unsigned AddrIndex = TRI->getEncodingValue(AddrReg); + unsigned ESR = 0x80 | (AddrIndex & 0x1F); + + // Check if r3 is live and needs to be spilled. + bool NeedSpillR3 = + (ScratchReg == ARM::R3) && isRegisterLiveInCall(Call, ARM::R3); + + // If we need to spill r3, push it first. + if (NeedSpillR3) { + // push {r3} + EmitToStreamer( + *OutStreamer, + MCInstBuilder(ARM::tPUSH).addImm(ARMCC::AL).addReg(0).addReg(ARM::R3)); + } + + // Clear bit 0 of target address to handle Thumb function pointers. + // In 32-bit ARM, function pointers may have the low bit set to indicate + // Thumb state when ARM/Thumb interworking is enabled (ARMv4T and later). + // We need to clear it to avoid an alignment fault when loading. + // bic scratch, target, #1 + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::t2BICri) + .addReg(ScratchReg) + .addReg(AddrReg) + .addImm(1) + .addImm(ARMCC::AL) + .addReg(0) + .addReg(0)); + + // ldr scratch, [scratch, #-(PrefixNops * 4 + 4)] + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::t2LDRi8) + .addReg(ScratchReg) + .addReg(ScratchReg) + .addImm(-(PrefixNops * 4 + 4)) + .addImm(ARMCC::AL) + .addReg(0)); + + // Each EOR instruction XORs one byte of the type, shifted to its position. + for (int i = 0; i < 4; i++) { + uint8_t byte = (Type >> (i * 8)) & 0xFF; + uint32_t imm = byte << (i * 8); + bool isLast = (i == 3); + + // Verify the immediate can be encoded as Thumb2 modified immediate. + assert(ARM_AM::getT2SOImmVal(imm) != -1 && + "Cannot encode immediate as Thumb2 modified immediate"); + + // eor[s] scratch, scratch, #imm (last one sets flags with CPSR) + EmitToStreamer(*OutStreamer, + MCInstBuilder(ARM::t2EORri) + .addReg(ScratchReg) + .addReg(ScratchReg) + .addImm(imm) + .addImm(ARMCC::AL) + .addReg(0) + .addReg(isLast ? ARM::CPSR : ARM::NoRegister)); + } + + // If we spilled r3, restore it immediately after the comparison. + // This must happen before the branch so r3 is valid on both paths. + if (NeedSpillR3) { + // pop {r3} + EmitToStreamer( + *OutStreamer, + MCInstBuilder(ARM::tPOP).addImm(ARMCC::AL).addReg(0).addReg(ARM::R3)); + } + + // beq .Lpass (branch if types match, i.e., scratch is zero) + MCSymbol *Pass = OutContext.createTempSymbol(); + EmitToStreamer(*OutStreamer, + MCInstBuilder(ARM::t2Bcc) + .addExpr(MCSymbolRefExpr::create(Pass, OutContext)) + .addImm(ARMCC::EQ) + .addReg(ARM::CPSR)); + + // udf #ESR (trap with encoded diagnostic) + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tUDF).addImm(ESR)); + + OutStreamer->emitLabel(Pass); +} + +void ARMAsmPrinter::EmitKCFI_CHECK_Thumb1(Register AddrReg, int64_t Type, + const MachineInstr &Call, + int64_t PrefixNops) { + // For Thumb1, use R2 unconditionally as scratch register (a low register + // required for tLDRi). R3 is used for building the type hash. + unsigned ScratchReg = ARM::R2; + unsigned TempReg = ARM::R3; + + // Check if r3 is live and needs to be spilled. + bool NeedSpillR3 = isRegisterLiveInCall(Call, ARM::R3); + + // Spill r3 if needed + if (NeedSpillR3) { + EmitToStreamer( + *OutStreamer, + MCInstBuilder(ARM::tPUSH).addImm(ARMCC::AL).addReg(0).addReg(ARM::R3)); + } + + // Check if r2 is live and needs to be spilled. + bool NeedSpillR2 = isRegisterLiveInCall(Call, ARM::R2); + + // Push R2 if it's live + if (NeedSpillR2) { + EmitToStreamer( + *OutStreamer, + MCInstBuilder(ARM::tPUSH).addImm(ARMCC::AL).addReg(0).addReg(ARM::R2)); + } + + // Clear bit 0 from target address + // TempReg (R3) is used first as helper for BIC, then later for building type + // hash. + + // movs temp, #1 + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tMOVi8) + .addReg(TempReg) + .addReg(ARM::CPSR) + .addImm(1) + .addImm(ARMCC::AL) + .addReg(0)); + + // mov scratch, target + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tMOVr) + .addReg(ScratchReg) + .addReg(AddrReg) + .addImm(ARMCC::AL)); + + // bics scratch, temp (scratch = scratch & ~temp) + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tBIC) + .addReg(ScratchReg) + .addReg(ARM::CPSR) + .addReg(ScratchReg) + .addReg(TempReg) + .addImm(ARMCC::AL) + .addReg(0)); + + // Load type hash. Thumb1 doesn't support negative offsets, so subtract. + int offset = PrefixNops * 4 + 4; + + // subs scratch, #offset + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tSUBi8) + .addReg(ScratchReg) + .addReg(ARM::CPSR) + .addReg(ScratchReg) + .addImm(offset) + .addImm(ARMCC::AL) + .addReg(0)); + + // ldr scratch, [scratch, #0] + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tLDRi) + .addReg(ScratchReg) + .addReg(ScratchReg) + .addImm(0) + .addImm(ARMCC::AL) + .addReg(0)); + + // Load expected type inline (instead of EOR sequence) + // + // This creates the 32-bit value byte-by-byte in the temp register: + // movs temp, #byte3 (high byte) + // lsls temp, temp, #8 + // adds temp, #byte2 + // lsls temp, temp, #8 + // adds temp, #byte1 + // lsls temp, temp, #8 + // adds temp, #byte0 (low byte) + + uint8_t byte0 = (Type >> 0) & 0xFF; + uint8_t byte1 = (Type >> 8) & 0xFF; + uint8_t byte2 = (Type >> 16) & 0xFF; + uint8_t byte3 = (Type >> 24) & 0xFF; + + // movs temp, #byte3 (start with high byte) + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tMOVi8) + .addReg(TempReg) + .addReg(ARM::CPSR) + .addImm(byte3) + .addImm(ARMCC::AL) + .addReg(0)); + + // lsls temp, temp, #8 + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tLSLri) + .addReg(TempReg) + .addReg(ARM::CPSR) + .addReg(TempReg) + .addImm(8) + .addImm(ARMCC::AL) + .addReg(0)); + + // adds temp, #byte2 + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tADDi8) + .addReg(TempReg) + .addReg(ARM::CPSR) + .addReg(TempReg) + .addImm(byte2) + .addImm(ARMCC::AL) + .addReg(0)); + + // lsls temp, temp, #8 + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tLSLri) + .addReg(TempReg) + .addReg(ARM::CPSR) + .addReg(TempReg) + .addImm(8) + .addImm(ARMCC::AL) + .addReg(0)); + + // adds temp, #byte1 + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tADDi8) + .addReg(TempReg) + .addReg(ARM::CPSR) + .addReg(TempReg) + .addImm(byte1) + .addImm(ARMCC::AL) + .addReg(0)); + + // lsls temp, temp, #8 + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tLSLri) + .addReg(TempReg) + .addReg(ARM::CPSR) + .addReg(TempReg) + .addImm(8) + .addImm(ARMCC::AL) + .addReg(0)); + + // adds temp, #byte0 (low byte) + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tADDi8) + .addReg(TempReg) + .addReg(ARM::CPSR) + .addReg(TempReg) + .addImm(byte0) + .addImm(ARMCC::AL) + .addReg(0)); + + // cmp scratch, temp + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tCMPr) + .addReg(ScratchReg) + .addReg(TempReg) + .addImm(ARMCC::AL) + .addReg(0)); + + // Restore registers if spilled (pop in reverse order of push: R2, then R3) + if (NeedSpillR2) { + // pop {r2} + EmitToStreamer( + *OutStreamer, + MCInstBuilder(ARM::tPOP).addImm(ARMCC::AL).addReg(0).addReg(ARM::R2)); + } + + // Restore r3 if spilled + if (NeedSpillR3) { + // pop {r3} + EmitToStreamer( + *OutStreamer, + MCInstBuilder(ARM::tPOP).addImm(ARMCC::AL).addReg(0).addReg(ARM::R3)); + } + + // beq .Lpass (branch if types match, i.e., scratch == temp) + MCSymbol *Pass = OutContext.createTempSymbol(); + EmitToStreamer(*OutStreamer, + MCInstBuilder(ARM::tBcc) + .addExpr(MCSymbolRefExpr::create(Pass, OutContext)) + .addImm(ARMCC::EQ) + .addReg(ARM::CPSR)); + + // bkpt #0 (trap with encoded diagnostic) + EmitToStreamer(*OutStreamer, MCInstBuilder(ARM::tBKPT).addImm(0)); + + OutStreamer->emitLabel(Pass); +} + +void ARMAsmPrinter::LowerKCFI_CHECK(const MachineInstr &MI) { + Register AddrReg = MI.getOperand(0).getReg(); + const int64_t Type = MI.getOperand(1).getImm(); + + // Get the call instruction that follows this KCFI_CHECK. + assert(std::next(MI.getIterator())->isCall() && + "KCFI_CHECK not followed by a call instruction"); + const MachineInstr &Call = *std::next(MI.getIterator()); + + // Adjust the offset for patchable-function-prefix. + int64_t PrefixNops = 0; + MI.getMF() + ->getFunction() + .getFnAttribute("patchable-function-prefix") + .getValueAsString() + .getAsInteger(10, PrefixNops); + + // Emit the appropriate instruction sequence based on the opcode variant. + switch (MI.getOpcode()) { + case ARM::KCFI_CHECK_ARM: + EmitKCFI_CHECK_ARM32(AddrReg, Type, Call, PrefixNops); + break; + case ARM::KCFI_CHECK_Thumb2: + EmitKCFI_CHECK_Thumb2(AddrReg, Type, Call, PrefixNops); + break; + case ARM::KCFI_CHECK_Thumb1: + EmitKCFI_CHECK_Thumb1(AddrReg, Type, Call, PrefixNops); + break; + default: + llvm_unreachable("Unexpected KCFI_CHECK opcode"); + } +} + void ARMAsmPrinter::emitInstruction(const MachineInstr *MI) { ARM_MC::verifyInstructionPredicates(MI->getOpcode(), getSubtargetInfo().getFeatureBits()); @@ -1504,6 +1933,11 @@ void ARMAsmPrinter::emitInstruction(const MachineInstr *MI) { switch (Opc) { case ARM::t2MOVi32imm: llvm_unreachable("Should be lowered by thumb2it pass"); case ARM::DBG_VALUE: llvm_unreachable("Should be handled by generic printing"); + case ARM::KCFI_CHECK_ARM: + case ARM::KCFI_CHECK_Thumb2: + case ARM::KCFI_CHECK_Thumb1: + LowerKCFI_CHECK(*MI); + return; case ARM::LEApcrel: case ARM::tLEApcrel: case ARM::t2LEApcrel: { diff --git a/llvm/lib/Target/ARM/ARMAsmPrinter.h b/llvm/lib/Target/ARM/ARMAsmPrinter.h index 2b067c7..9e92b5a 100644 --- a/llvm/lib/Target/ARM/ARMAsmPrinter.h +++ b/llvm/lib/Target/ARM/ARMAsmPrinter.h @@ -123,9 +123,20 @@ public: void LowerPATCHABLE_FUNCTION_EXIT(const MachineInstr &MI); void LowerPATCHABLE_TAIL_CALL(const MachineInstr &MI); + // KCFI check lowering + void LowerKCFI_CHECK(const MachineInstr &MI); + private: void EmitSled(const MachineInstr &MI, SledKind Kind); + // KCFI check emission helpers + void EmitKCFI_CHECK_ARM32(Register AddrReg, int64_t Type, + const MachineInstr &Call, int64_t PrefixNops); + void EmitKCFI_CHECK_Thumb2(Register AddrReg, int64_t Type, + const MachineInstr &Call, int64_t PrefixNops); + void EmitKCFI_CHECK_Thumb1(Register AddrReg, int64_t Type, + const MachineInstr &Call, int64_t PrefixNops); + // Helpers for emitStartOfAsmFile() and emitEndOfAsmFile() void emitAttributes(); diff --git a/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp b/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp index 0d7b6d1..fffb6373 100644 --- a/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp +++ b/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp @@ -2301,6 +2301,8 @@ bool ARMExpandPseudo::ExpandMI(MachineBasicBlock &MBB, for (unsigned i = 2, e = MBBI->getNumOperands(); i != e; ++i) NewMI->addOperand(MBBI->getOperand(i)); + NewMI->setCFIType(*MBB.getParent(), MI.getCFIType()); + // Update call info and delete the pseudo instruction TCRETURN. if (MI.isCandidateForAdditionalCallInfo()) MI.getMF()->moveAdditionalCallInfo(&MI, &*NewMI); diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp index b1a668e..8122db2 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -2849,6 +2849,8 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, if (isTailCall) { MF.getFrameInfo().setHasTailCall(); SDValue Ret = DAG.getNode(ARMISD::TC_RETURN, dl, MVT::Other, Ops); + if (CLI.CFIType) + Ret.getNode()->setCFIType(CLI.CFIType->getZExtValue()); DAG.addNoMergeSiteInfo(Ret.getNode(), CLI.NoMerge); DAG.addCallSiteInfo(Ret.getNode(), std::move(CSInfo)); return Ret; @@ -2856,6 +2858,8 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // Returns a chain and a flag for retval copy to use. Chain = DAG.getNode(CallOpc, dl, {MVT::Other, MVT::Glue}, Ops); + if (CLI.CFIType) + Chain.getNode()->setCFIType(CLI.CFIType->getZExtValue()); DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge); InGlue = Chain.getValue(1); DAG.addCallSiteInfo(Chain.getNode(), std::move(CSInfo)); @@ -12008,6 +12012,71 @@ static void genTPLoopBody(MachineBasicBlock *TpLoopBody, .add(predOps(ARMCC::AL)); } +bool ARMTargetLowering::supportKCFIBundles() const { + // KCFI is supported in all ARM/Thumb modes + return true; +} + +MachineInstr * +ARMTargetLowering::EmitKCFICheck(MachineBasicBlock &MBB, + MachineBasicBlock::instr_iterator &MBBI, + const TargetInstrInfo *TII) const { + assert(MBBI->isCall() && MBBI->getCFIType() && + "Invalid call instruction for a KCFI check"); + + MachineOperand *TargetOp = nullptr; + switch (MBBI->getOpcode()) { + // ARM mode opcodes + case ARM::BLX: + case ARM::BLX_pred: + case ARM::BLX_noip: + case ARM::BLX_pred_noip: + case ARM::BX_CALL: + TargetOp = &MBBI->getOperand(0); + break; + case ARM::TCRETURNri: + case ARM::TCRETURNrinotr12: + case ARM::TAILJMPr: + case ARM::TAILJMPr4: + TargetOp = &MBBI->getOperand(0); + break; + // Thumb mode opcodes (Thumb1 and Thumb2) + // Note: Most Thumb call instructions have predicate operands before the + // target register Format: tBLXr pred, predreg, target_register, ... + case ARM::tBLXr: // Thumb1/Thumb2: BLX register (requires V5T) + case ARM::tBLXr_noip: // Thumb1/Thumb2: BLX register, no IP clobber + case ARM::tBX_CALL: // Thumb1 only: BX call (push LR, BX) + TargetOp = &MBBI->getOperand(2); + break; + // Tail call instructions don't have predicates, target is operand 0 + case ARM::tTAILJMPr: // Thumb1/Thumb2: Tail call via register + TargetOp = &MBBI->getOperand(0); + break; + default: + llvm_unreachable("Unexpected CFI call opcode"); + } + + assert(TargetOp && TargetOp->isReg() && "Invalid target operand"); + TargetOp->setIsRenamable(false); + + // Select the appropriate KCFI_CHECK variant based on the instruction set + unsigned KCFICheckOpcode; + if (Subtarget->isThumb()) { + if (Subtarget->isThumb2()) { + KCFICheckOpcode = ARM::KCFI_CHECK_Thumb2; + } else { + KCFICheckOpcode = ARM::KCFI_CHECK_Thumb1; + } + } else { + KCFICheckOpcode = ARM::KCFI_CHECK_ARM; + } + + return BuildMI(MBB, MBBI, MBBI->getDebugLoc(), TII->get(KCFICheckOpcode)) + .addReg(TargetOp->getReg()) + .addImm(MBBI->getCFIType()) + .getInstr(); +} + MachineBasicBlock * ARMTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, MachineBasicBlock *BB) const { diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h index 70aa001..8c5e0cf 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.h +++ b/llvm/lib/Target/ARM/ARMISelLowering.h @@ -447,6 +447,12 @@ class VectorType; void AdjustInstrPostInstrSelection(MachineInstr &MI, SDNode *Node) const override; + bool supportKCFIBundles() const override; + + MachineInstr *EmitKCFICheck(MachineBasicBlock &MBB, + MachineBasicBlock::instr_iterator &MBBI, + const TargetInstrInfo *TII) const override; + SDValue PerformCMOVCombine(SDNode *N, SelectionDAG &DAG) const; SDValue PerformBRCONDCombine(SDNode *N, SelectionDAG &DAG) const; SDValue PerformCMOVToBFICombine(SDNode *N, SelectionDAG &DAG) const; diff --git a/llvm/lib/Target/ARM/ARMInstrInfo.td b/llvm/lib/Target/ARM/ARMInstrInfo.td index 282ff53..53be167 100644 --- a/llvm/lib/Target/ARM/ARMInstrInfo.td +++ b/llvm/lib/Target/ARM/ARMInstrInfo.td @@ -6536,6 +6536,36 @@ def CMP_SWAP_64 : PseudoInst<(outs GPRPair:$Rd, GPRPair:$addr_temp_out), def : Pat<(atomic_fence (timm), 0), (MEMBARRIER)>; //===----------------------------------------------------------------------===// +// KCFI check pseudo-instruction. +//===----------------------------------------------------------------------===// +// KCFI_CHECK pseudo-instruction for Kernel Control-Flow Integrity. +// Expands to a sequence that verifies the function pointer's type hash. +// Different sizes for different architectures due to different expansions. + +def KCFI_CHECK_ARM + : PseudoInst<(outs), (ins GPR:$ptr, i32imm:$type), NoItinerary, []>, + Sched<[]>, + Requires<[IsARM]> { + let Size = 28; // 7 instructions (bic, ldr, 4x eor, beq, udf) +} + +def KCFI_CHECK_Thumb2 + : PseudoInst<(outs), (ins GPR:$ptr, i32imm:$type), NoItinerary, []>, + Sched<[]>, + Requires<[IsThumb2]> { + let Size = + 32; // worst-case 9 instructions (push, bic, ldr, 4x eor, pop, beq.w, udf) +} + +def KCFI_CHECK_Thumb1 + : PseudoInst<(outs), (ins GPR:$ptr, i32imm:$type), NoItinerary, []>, + Sched<[]>, + Requires<[IsThumb1Only]> { + let Size = 50; // worst-case 25 instructions (pushes, bic helper, type + // building, cmp, pops) +} + +//===----------------------------------------------------------------------===// // Instructions used for emitting unwind opcodes on Windows. //===----------------------------------------------------------------------===// let isPseudo = 1 in { diff --git a/llvm/lib/Target/ARM/ARMTargetMachine.cpp b/llvm/lib/Target/ARM/ARMTargetMachine.cpp index 86740a9..590d4c7 100644 --- a/llvm/lib/Target/ARM/ARMTargetMachine.cpp +++ b/llvm/lib/Target/ARM/ARMTargetMachine.cpp @@ -111,6 +111,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeARMTarget() { initializeMVELaneInterleavingPass(Registry); initializeARMFixCortexA57AES1742098Pass(Registry); initializeARMDAGToDAGISelLegacyPass(Registry); + initializeKCFIPass(Registry); } static std::unique_ptr<TargetLoweringObjectFile> createTLOF(const Triple &TT) { @@ -487,6 +488,9 @@ void ARMPassConfig::addPreSched2() { // proper scheduling. addPass(createARMExpandPseudoPass()); + // Emit KCFI checks for indirect calls. + addPass(createKCFIPass()); + if (getOptLevel() != CodeGenOptLevel::None) { // When optimising for size, always run the Thumb2SizeReduction pass before // IfConversion. Otherwise, check whether IT blocks are restricted @@ -517,9 +521,12 @@ void ARMPassConfig::addPreSched2() { void ARMPassConfig::addPreEmitPass() { addPass(createThumb2SizeReductionPass()); - // Constant island pass work on unbundled instructions. + // Unpack bundles for: + // - Thumb2: Constant island pass requires unbundled instructions + // - KCFI: KCFI_CHECK pseudo instructions need to be unbundled for AsmPrinter addPass(createUnpackMachineBundles([](const MachineFunction &MF) { - return MF.getSubtarget<ARMSubtarget>().isThumb2(); + return MF.getSubtarget<ARMSubtarget>().isThumb2() || + MF.getFunction().getParent()->getModuleFlag("kcfi"); })); // Don't optimize barriers or block placement at -O0. @@ -530,6 +537,7 @@ void ARMPassConfig::addPreEmitPass() { } void ARMPassConfig::addPreEmitPass2() { + // Inserts fixup instructions before unsafe AES operations. Instructions may // be inserted at the start of blocks and at within blocks so this pass has to // come before those below. diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 22cf3a7..598735f 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -4675,7 +4675,7 @@ class WMMA_INSTR<string _Intr, list<dag> _Args> // class WMMA_LOAD<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride> - : WMMA_INSTR<WMMA_NAME_LDST<"load", Frag, Layout, WithStride>.record, + : WMMA_INSTR<WMMA_NAME_LDST<"load", Frag, Layout, WithStride>.record_name, [!con((ins ADDR:$src), !if(WithStride, (ins B32:$ldm), (ins)))]>, Requires<Frag.Predicates> { @@ -4714,7 +4714,7 @@ class WMMA_LOAD<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride> // class WMMA_STORE_D<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride> - : WMMA_INSTR<WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.record, + : WMMA_INSTR<WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.record_name, [!con((ins ADDR:$dst), Frag.Ins, !if(WithStride, (ins B32:$ldm), (ins)))]>, @@ -4778,7 +4778,7 @@ class MMA_OP_PREDICATES<WMMA_REGINFO FragA, string b1op> { class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB, WMMA_REGINFO FragC, WMMA_REGINFO FragD, string ALayout, string BLayout, int Satfinite, string rnd, string b1op> - : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, b1op, FragA, FragB, FragC, FragD>.record, + : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, b1op, FragA, FragB, FragC, FragD>.record_name, [FragA.Ins, FragB.Ins, FragC.Ins]>, // Requires does not seem to have effect on Instruction w/o Patterns. // We set it here anyways and propagate to the Pat<> we construct below. @@ -4837,7 +4837,7 @@ defset list<WMMA_INSTR> WMMAs = { class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB, WMMA_REGINFO FragC, WMMA_REGINFO FragD, string ALayout, string BLayout, int Satfinite, string b1op, string Kind> - : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, Kind, FragA, FragB, FragC, FragD>.record, + : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, Kind, FragA, FragB, FragC, FragD>.record_name, [FragA.Ins, FragB.Ins, FragC.Ins]>, // Requires does not seem to have effect on Instruction w/o Patterns. // We set it here anyways and propagate to the Pat<> we construct below. @@ -4891,7 +4891,7 @@ class MMA_SP<WMMA_REGINFO FragA, WMMA_REGINFO FragB, WMMA_REGINFO FragC, WMMA_REGINFO FragD, string Metadata, string Kind, int Satfinite> : WMMA_INSTR<MMA_SP_NAME<Metadata, Kind, Satfinite, - FragA, FragB, FragC, FragD>.record, + FragA, FragB, FragC, FragD>.record_name, [FragA.Ins, FragB.Ins, FragC.Ins, (ins B32:$metadata, i32imm:$selector)]>, // Requires does not seem to have effect on Instruction w/o Patterns. @@ -4946,7 +4946,7 @@ defset list<WMMA_INSTR> MMA_SPs = { // ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16 // class LDMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space> - : WMMA_INSTR<LDMATRIX_NAME<Frag, Transposed>.record, [(ins ADDR:$src)]>, + : WMMA_INSTR<LDMATRIX_NAME<Frag, Transposed>.record_name, [(ins ADDR:$src)]>, Requires<Frag.Predicates> { // Build PatFrag that only matches particular address space. PatFrag IntrFrag = PatFrag<(ops node:$src), (Intr node:$src), @@ -4981,7 +4981,7 @@ defset list<WMMA_INSTR> LDMATRIXs = { // stmatrix.sync.aligned.m8n8[|.trans][|.shared].b16 // class STMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space> - : WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record, [!con((ins ADDR:$dst), Frag.Ins)]>, + : WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record_name, [!con((ins ADDR:$dst), Frag.Ins)]>, Requires<Frag.Predicates> { // Build PatFrag that only matches particular address space. dag PFOperands = !con((ops node:$dst), @@ -5376,7 +5376,7 @@ class Tcgen05MMAInst<bit Sp, string KindStr, string ASpace, Requires<PTXPredicates> { Intrinsic Intrin = !cast<Intrinsic>( - NVVM_TCGEN05_MMA<Sp, ASpace, AShift, ScaleInputD>.record + NVVM_TCGEN05_MMA<Sp, ASpace, AShift, ScaleInputD>.record_name ); dag ScaleInpIns = !if(!eq(ScaleInputD, 1), (ins i64imm:$scale_input_d), (ins)); @@ -5618,7 +5618,7 @@ class Tcgen05MMABlockScaleInst<bit Sp, string ASpace, string KindStr, Requires<[hasTcgen05Instructions, PTXPredicate]> { Intrinsic Intrin = !cast<Intrinsic>( - NVVM_TCGEN05_MMA_BLOCKSCALE<Sp, ASpace, KindStr, ScaleVecSize>.record); + NVVM_TCGEN05_MMA_BLOCKSCALE<Sp, ASpace, KindStr, ScaleVecSize>.record_name); dag SparseMetadataIns = !if(!eq(Sp, 1), (ins B32:$spmetadata), (ins)); dag SparseMetadataIntr = !if(!eq(Sp, 1), (Intrin i32:$spmetadata), (Intrin)); @@ -5702,7 +5702,7 @@ class Tcgen05MMAWSInst<bit Sp, string ASpace, string KindStr, Requires<[hasTcgen05Instructions]> { Intrinsic Intrin = !cast<Intrinsic>( - NVVM_TCGEN05_MMA_WS<Sp, ASpace, HasZeroColMask>.record); + NVVM_TCGEN05_MMA_WS<Sp, ASpace, HasZeroColMask>.record_name); dag ZeroColMaskIns = !if(!eq(HasZeroColMask, 1), (ins B64:$zero_col_mask), (ins)); diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td index 4104abd..4c2f7f6 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXSf.td @@ -482,7 +482,7 @@ let Predicates = [HasVendorXSfvfwmaccqqq] in { defm SF_VFWMACC_4x4x4 : VPseudoSiFiveVFWMACC; } -let Predicates = [HasVendorXSfvfnrclipxfqf] in { +let Predicates = [HasVendorXSfvfnrclipxfqf], AltFmtType = IS_NOT_ALTFMT in { defm SF_VFNRCLIP_XU_F_QF : VPseudoSiFiveVFNRCLIP; defm SF_VFNRCLIP_X_F_QF : VPseudoSiFiveVFNRCLIP; } diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index d91923b..56a38bb 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -1499,18 +1499,25 @@ static bool generateKernelClockInst(const SPIRV::IncomingCall *Call, Register ResultReg = Call->ReturnRegister; - // Deduce the `Scope` operand from the builtin function name. - SPIRV::Scope::Scope ScopeArg = - StringSwitch<SPIRV::Scope::Scope>(Builtin->Name) - .EndsWith("device", SPIRV::Scope::Scope::Device) - .EndsWith("work_group", SPIRV::Scope::Scope::Workgroup) - .EndsWith("sub_group", SPIRV::Scope::Scope::Subgroup); - Register ScopeReg = buildConstantIntReg32(ScopeArg, MIRBuilder, GR); - - MIRBuilder.buildInstr(SPIRV::OpReadClockKHR) - .addDef(ResultReg) - .addUse(GR->getSPIRVTypeID(Call->ReturnType)) - .addUse(ScopeReg); + if (Builtin->Name == "__spirv_ReadClockKHR") { + MIRBuilder.buildInstr(SPIRV::OpReadClockKHR) + .addDef(ResultReg) + .addUse(GR->getSPIRVTypeID(Call->ReturnType)) + .addUse(Call->Arguments[0]); + } else { + // Deduce the `Scope` operand from the builtin function name. + SPIRV::Scope::Scope ScopeArg = + StringSwitch<SPIRV::Scope::Scope>(Builtin->Name) + .EndsWith("device", SPIRV::Scope::Scope::Device) + .EndsWith("work_group", SPIRV::Scope::Scope::Workgroup) + .EndsWith("sub_group", SPIRV::Scope::Scope::Subgroup); + Register ScopeReg = buildConstantIntReg32(ScopeArg, MIRBuilder, GR); + + MIRBuilder.buildInstr(SPIRV::OpReadClockKHR) + .addDef(ResultReg) + .addUse(GR->getSPIRVTypeID(Call->ReturnType)) + .addUse(ScopeReg); + } return true; } diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td index 3b8764a..c259cce 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td @@ -1174,6 +1174,7 @@ defm : DemangledNativeBuiltin<"clock_read_sub_group", OpenCL_std, KernelClock, 0 defm : DemangledNativeBuiltin<"clock_read_hilo_device", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>; defm : DemangledNativeBuiltin<"clock_read_hilo_work_group", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>; defm : DemangledNativeBuiltin<"clock_read_hilo_sub_group", OpenCL_std, KernelClock, 0, 0, OpReadClockKHR>; +defm : DemangledNativeBuiltin<"__spirv_ReadClockKHR", OpenCL_std, KernelClock, 1, 1, OpReadClockKHR>; //===----------------------------------------------------------------------===// // Class defining an atomic instruction on floating-point numbers. diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index d49f25a..4dfc400 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -2632,6 +2632,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM, setOperationAction(Op, MVT::f32, Promote); } + setOperationPromotedToType(ISD::ATOMIC_LOAD, MVT::f16, MVT::i16); + setOperationPromotedToType(ISD::ATOMIC_LOAD, MVT::f32, MVT::i32); + setOperationPromotedToType(ISD::ATOMIC_LOAD, MVT::f64, MVT::i64); + // We have target-specific dag combine patterns for the following nodes: setTargetDAGCombine({ISD::VECTOR_SHUFFLE, ISD::SCALAR_TO_VECTOR, diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index a0f7ec6..2dd0fde 100644 --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -948,17 +948,17 @@ void llvm::updateVCallVisibilityInIndex( // linker, as we have no information on their eventual use. if (DynamicExportSymbols.count(P.first)) continue; + // With validation enabled, we want to exclude symbols visible to regular + // objects. Local symbols will be in this group due to the current + // implementation but those with VCallVisibilityTranslationUnit will have + // already been marked in clang so are unaffected. + if (VisibleToRegularObjSymbols.count(P.first)) + continue; for (auto &S : P.second.getSummaryList()) { auto *GVar = dyn_cast<GlobalVarSummary>(S.get()); if (!GVar || GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic) continue; - // With validation enabled, we want to exclude symbols visible to regular - // objects. Local symbols will be in this group due to the current - // implementation but those with VCallVisibilityTranslationUnit will have - // already been marked in clang so are unaffected. - if (VisibleToRegularObjSymbols.count(P.first)) - continue; GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit); } } @@ -1161,14 +1161,10 @@ bool DevirtIndex::tryFindVirtualCallTargets( // and therefore the same GUID. This can happen if there isn't enough // distinguishing path when compiling the source file. In that case we // conservatively return false early. + if (P.VTableVI.hasLocal() && P.VTableVI.getSummaryList().size() > 1) + return false; const GlobalVarSummary *VS = nullptr; - bool LocalFound = false; for (const auto &S : P.VTableVI.getSummaryList()) { - if (GlobalValue::isLocalLinkage(S->linkage())) { - if (LocalFound) - return false; - LocalFound = true; - } auto *CurVS = cast<GlobalVarSummary>(S->getBaseObject()); if (!CurVS->vTableFuncs().empty() || // Previously clang did not attach the necessary type metadata to @@ -1184,6 +1180,7 @@ bool DevirtIndex::tryFindVirtualCallTargets( // with public LTO visibility. if (VS->getVCallVisibility() == GlobalObject::VCallVisibilityPublic) return false; + break; } } // There will be no VS if all copies are available_externally having no @@ -1411,9 +1408,8 @@ bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot, // If the summary list contains multiple summaries where at least one is // a local, give up, as we won't know which (possibly promoted) name to use. - for (const auto &S : TheFn.getSummaryList()) - if (GlobalValue::isLocalLinkage(S->linkage()) && Size > 1) - return false; + if (TheFn.hasLocal() && Size > 1) + return false; // Collect functions devirtualized at least for one call site for stats. if (PrintSummaryDevirt || AreStatisticsEnabled()) @@ -2591,6 +2587,11 @@ void DevirtIndex::run() { if (ExportSummary.typeIdCompatibleVtableMap().empty()) return; + // Assert that we haven't made any changes that would affect the hasLocal() + // flag on the GUID summary info. + assert(!ExportSummary.withInternalizeAndPromote() && + "Expect index-based WPD to run before internalization and promotion"); + DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID; for (const auto &P : ExportSummary.typeIdCompatibleVtableMap()) { NameByGUID[GlobalValue::getGUIDAssumingExternalLinkage(P.first)].push_back( diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp index 4acc3f2..d347ced 100644 --- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp +++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp @@ -614,6 +614,16 @@ static Decomposition decompose(Value *V, return {V, IsKnownNonNegative}; } + if (match(V, m_Add(m_Value(Op0), m_ConstantInt(CI))) && CI->isNegative() && + canUseSExt(CI)) { + Preconditions.emplace_back( + CmpInst::ICMP_UGE, Op0, + ConstantInt::get(Op0->getType(), CI->getSExtValue() * -1)); + if (auto Decomp = MergeResults(Op0, CI, true)) + return *Decomp; + return {V, IsKnownNonNegative}; + } + if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1)))) { if (!isKnownNonNegative(Op0, DL)) Preconditions.emplace_back(CmpInst::ICMP_SGE, Op0, @@ -627,16 +637,6 @@ static Decomposition decompose(Value *V, return {V, IsKnownNonNegative}; } - if (match(V, m_Add(m_Value(Op0), m_ConstantInt(CI))) && CI->isNegative() && - canUseSExt(CI)) { - Preconditions.emplace_back( - CmpInst::ICMP_UGE, Op0, - ConstantInt::get(Op0->getType(), CI->getSExtValue() * -1)); - if (auto Decomp = MergeResults(Op0, CI, true)) - return *Decomp; - return {V, IsKnownNonNegative}; - } - // Decompose or as an add if there are no common bits between the operands. if (match(V, m_DisjointOr(m_Value(Op0), m_ConstantInt(CI)))) { if (auto Decomp = MergeResults(Op0, CI, IsSigned)) diff --git a/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/llvm/lib/Transforms/Scalar/MergeICmps.cpp index a83cbd17a7..f273e9d 100644 --- a/llvm/lib/Transforms/Scalar/MergeICmps.cpp +++ b/llvm/lib/Transforms/Scalar/MergeICmps.cpp @@ -64,10 +64,10 @@ using namespace llvm; -namespace { - #define DEBUG_TYPE "mergeicmps" +namespace { + // A BCE atom "Binary Compare Expression Atom" represents an integer load // that is a constant offset from a base value, e.g. `a` or `o.c` in the example // at the top. @@ -128,11 +128,12 @@ private: unsigned Order = 1; DenseMap<const Value*, int> BaseToIndex; }; +} // namespace // If this value is a load from a constant offset w.r.t. a base address, and // there are no other users of the load or address, returns the base address and // the offset. -BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { +static BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { auto *const LoadI = dyn_cast<LoadInst>(Val); if (!LoadI) return {}; @@ -175,6 +176,7 @@ BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { return BCEAtom(GEP, LoadI, BaseId.getBaseId(Base), Offset); } +namespace { // A comparison between two BCE atoms, e.g. `a == o.a` in the example at the // top. // Note: the terminology is misleading: the comparison is symmetric, so there @@ -239,6 +241,7 @@ class BCECmpBlock { private: BCECmp Cmp; }; +} // namespace bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst, AliasAnalysis &AA) const { @@ -302,9 +305,9 @@ bool BCECmpBlock::doesOtherWork() const { // Visit the given comparison. If this is a comparison between two valid // BCE atoms, returns the comparison. -std::optional<BCECmp> visitICmp(const ICmpInst *const CmpI, - const ICmpInst::Predicate ExpectedPredicate, - BaseIdentifier &BaseId) { +static std::optional<BCECmp> +visitICmp(const ICmpInst *const CmpI, + const ICmpInst::Predicate ExpectedPredicate, BaseIdentifier &BaseId) { // The comparison can only be used once: // - For intermediate blocks, as a branch condition. // - For the final block, as an incoming value for the Phi. @@ -332,10 +335,9 @@ std::optional<BCECmp> visitICmp(const ICmpInst *const CmpI, // Visit the given comparison block. If this is a comparison between two valid // BCE atoms, returns the comparison. -std::optional<BCECmpBlock> visitCmpBlock(Value *const Val, - BasicBlock *const Block, - const BasicBlock *const PhiBlock, - BaseIdentifier &BaseId) { +static std::optional<BCECmpBlock> +visitCmpBlock(Value *const Val, BasicBlock *const Block, + const BasicBlock *const PhiBlock, BaseIdentifier &BaseId) { if (Block->empty()) return std::nullopt; auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator()); @@ -397,6 +399,7 @@ static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons, Comparisons.push_back(std::move(Comparison)); } +namespace { // A chain of comparisons. class BCECmpChain { public: @@ -420,6 +423,7 @@ private: // The original entry block (before sorting); BasicBlock *EntryBlock_; }; +} // namespace static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) { return First.Lhs().BaseId == Second.Lhs().BaseId && @@ -742,9 +746,8 @@ bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA, return true; } -std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi, - BasicBlock *const LastBlock, - int NumBlocks) { +static std::vector<BasicBlock *> +getOrderedBlocks(PHINode &Phi, BasicBlock *const LastBlock, int NumBlocks) { // Walk up from the last block to find other blocks. std::vector<BasicBlock *> Blocks(NumBlocks); assert(LastBlock && "invalid last block"); @@ -777,8 +780,8 @@ std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi, return Blocks; } -bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI, AliasAnalysis &AA, - DomTreeUpdater &DTU) { +static bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI, + AliasAnalysis &AA, DomTreeUpdater &DTU) { LLVM_DEBUG(dbgs() << "processPhi()\n"); if (Phi.getNumIncomingValues() <= 1) { LLVM_DEBUG(dbgs() << "skip: only one incoming value in phi\n"); @@ -874,6 +877,7 @@ static bool runImpl(Function &F, const TargetLibraryInfo &TLI, return MadeChange; } +namespace { class MergeICmpsLegacyPass : public FunctionPass { public: static char ID; diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 06bea2f..a1ad2db 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -2712,7 +2712,8 @@ public: static inline bool classof(const VPRecipeBase *R) { return R->getVPDefID() == VPRecipeBase::VPReductionSC || - R->getVPDefID() == VPRecipeBase::VPReductionEVLSC; + R->getVPDefID() == VPRecipeBase::VPReductionEVLSC || + R->getVPDefID() == VPRecipeBase::VPPartialReductionSC; } static inline bool classof(const VPUser *U) { @@ -2783,7 +2784,10 @@ public: Opcode(Opcode), VFScaleFactor(ScaleFactor) { [[maybe_unused]] auto *AccumulatorRecipe = getChainOp()->getDefiningRecipe(); - assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) || + // When cloning as part of a VPExpressionRecipe the chain op could have + // replaced by a temporary VPValue, so it doesn't have a defining recipe. + assert((!AccumulatorRecipe || + isa<VPReductionPHIRecipe>(AccumulatorRecipe) || isa<VPPartialReductionRecipe>(AccumulatorRecipe)) && "Unexpected operand order for partial reduction recipe"); } @@ -3093,6 +3097,11 @@ public: /// removed before codegen. void decompose(); + unsigned getVFScaleFactor() const { + auto *PR = dyn_cast<VPPartialReductionRecipe>(ExpressionRecipes.back()); + return PR ? PR->getVFScaleFactor() : 1; + } + /// Method for generating code, must not be called as this recipe is abstract. void execute(VPTransformState &State) override { llvm_unreachable("recipe must be removed before execute"); diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 1f1b42b..931a5b7 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -168,6 +168,7 @@ bool VPRecipeBase::mayHaveSideEffects() const { return cast<VPWidenIntrinsicRecipe>(this)->mayHaveSideEffects(); case VPBlendSC: case VPReductionEVLSC: + case VPPartialReductionSC: case VPReductionSC: case VPScalarIVStepsSC: case VPVectorPointerSC: @@ -300,14 +301,23 @@ InstructionCost VPPartialReductionRecipe::computeCost(ElementCount VF, VPCostContext &Ctx) const { std::optional<unsigned> Opcode; - VPValue *Op = getOperand(0); - VPRecipeBase *OpR = Op->getDefiningRecipe(); - - // If the partial reduction is predicated, a select will be operand 0 - if (match(getOperand(1), m_Select(m_VPValue(), m_VPValue(Op), m_VPValue()))) { - OpR = Op->getDefiningRecipe(); + VPValue *Op = getVecOp(); + uint64_t MulConst; + // If the partial reduction is predicated, a select will be operand 1. + // If it isn't predicated and the mul isn't operating on a constant, then it + // should have been turned into a VPExpressionRecipe. + // FIXME: Replace the entire function with this once all partial reduction + // variants are bundled into VPExpressionRecipe. + if (!match(Op, m_Select(m_VPValue(), m_VPValue(Op), m_VPValue())) && + !match(Op, m_Mul(m_VPValue(), m_ConstantInt(MulConst)))) { + auto *PhiType = Ctx.Types.inferScalarType(getChainOp()); + auto *InputType = Ctx.Types.inferScalarType(getVecOp()); + return Ctx.TTI.getPartialReductionCost(getOpcode(), InputType, InputType, + PhiType, VF, TTI::PR_None, + TTI::PR_None, {}, Ctx.CostKind); } + VPRecipeBase *OpR = Op->getDefiningRecipe(); Type *InputTypeA = nullptr, *InputTypeB = nullptr; TTI::PartialReductionExtendKind ExtAType = TTI::PR_None, ExtBType = TTI::PR_None; @@ -2856,11 +2866,19 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF, cast<VPReductionRecipe>(ExpressionRecipes.back())->getRecurrenceKind()); switch (ExpressionType) { case ExpressionTypes::ExtendedReduction: { - return Ctx.TTI.getExtendedReductionCost( - Opcode, - cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() == - Instruction::ZExt, - RedTy, SrcVecTy, std::nullopt, Ctx.CostKind); + unsigned Opcode = RecurrenceDescriptor::getOpcode( + cast<VPReductionRecipe>(ExpressionRecipes[1])->getRecurrenceKind()); + auto *ExtR = cast<VPWidenCastRecipe>(ExpressionRecipes[0]); + return isa<VPPartialReductionRecipe>(ExpressionRecipes.back()) + ? Ctx.TTI.getPartialReductionCost( + Opcode, Ctx.Types.inferScalarType(getOperand(0)), nullptr, + RedTy, VF, + TargetTransformInfo::getPartialReductionExtendKind( + ExtR->getOpcode()), + TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind) + : Ctx.TTI.getExtendedReductionCost( + Opcode, ExtR->getOpcode() == Instruction::ZExt, RedTy, + SrcVecTy, std::nullopt, Ctx.CostKind); } case ExpressionTypes::MulAccReduction: return Ctx.TTI.getMulAccReductionCost(false, Opcode, RedTy, SrcVecTy, @@ -2871,6 +2889,19 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF, Opcode = Instruction::Sub; [[fallthrough]]; case ExpressionTypes::ExtMulAccReduction: { + if (isa<VPPartialReductionRecipe>(ExpressionRecipes.back())) { + auto *Ext0R = cast<VPWidenCastRecipe>(ExpressionRecipes[0]); + auto *Ext1R = cast<VPWidenCastRecipe>(ExpressionRecipes[1]); + auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]); + return Ctx.TTI.getPartialReductionCost( + Opcode, Ctx.Types.inferScalarType(getOperand(0)), + Ctx.Types.inferScalarType(getOperand(1)), RedTy, VF, + TargetTransformInfo::getPartialReductionExtendKind( + Ext0R->getOpcode()), + TargetTransformInfo::getPartialReductionExtendKind( + Ext1R->getOpcode()), + Mul->getOpcode(), Ctx.CostKind); + } return Ctx.TTI.getMulAccReductionCost( cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() == Instruction::ZExt, @@ -2910,12 +2941,13 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, O << " = "; auto *Red = cast<VPReductionRecipe>(ExpressionRecipes.back()); unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()); + bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red); switch (ExpressionType) { case ExpressionTypes::ExtendedReduction: { getOperand(1)->printAsOperand(O, SlotTracker); - O << " +"; - O << " reduce." << Instruction::getOpcodeName(Opcode) << " ("; + O << " + " << (IsPartialReduction ? "partial." : "") << "reduce."; + O << Instruction::getOpcodeName(Opcode) << " ("; getOperand(0)->printAsOperand(O, SlotTracker); Red->printFlags(O); @@ -2931,8 +2963,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, } case ExpressionTypes::ExtNegatedMulAccReduction: { getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker); - O << " + reduce." - << Instruction::getOpcodeName( + O << " + " << (IsPartialReduction ? "partial." : "") << "reduce."; + O << Instruction::getOpcodeName( RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind())) << " (sub (0, mul"; auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]); @@ -2956,9 +2988,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, case ExpressionTypes::MulAccReduction: case ExpressionTypes::ExtMulAccReduction: { getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker); - O << " + "; - O << "reduce." - << Instruction::getOpcodeName( + O << " + " << (IsPartialReduction ? "partial." : "") << "reduce."; + O << Instruction::getOpcodeName( RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind())) << " ("; O << "mul"; diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index f5a3af4..3e85e6f 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -3519,18 +3519,31 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx, VPValue *VecOp = Red->getVecOp(); // Clamp the range if using extended-reduction is profitable. - auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt, - Type *SrcTy) -> bool { + auto IsExtendedRedValidAndClampRange = + [&](unsigned Opcode, Instruction::CastOps ExtOpc, Type *SrcTy) -> bool { return LoopVectorizationPlanner::getDecisionAndClampRange( [&](ElementCount VF) { auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF)); TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; - InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost( - Opcode, isZExt, RedTy, SrcVecTy, Red->getFastMathFlags(), - CostKind); + + InstructionCost ExtRedCost; InstructionCost ExtCost = cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx); InstructionCost RedCost = Red->computeCost(VF, Ctx); + + if (isa<VPPartialReductionRecipe>(Red)) { + TargetTransformInfo::PartialReductionExtendKind ExtKind = + TargetTransformInfo::getPartialReductionExtendKind(ExtOpc); + // FIXME: Move partial reduction creation, costing and clamping + // here from LoopVectorize.cpp. + ExtRedCost = Ctx.TTI.getPartialReductionCost( + Opcode, SrcTy, nullptr, RedTy, VF, ExtKind, + llvm::TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind); + } else { + ExtRedCost = Ctx.TTI.getExtendedReductionCost( + Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy, + Red->getFastMathFlags(), CostKind); + } return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost; }, Range); @@ -3541,8 +3554,7 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx, if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) && IsExtendedRedValidAndClampRange( RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()), - cast<VPWidenCastRecipe>(VecOp)->getOpcode() == - Instruction::CastOps::ZExt, + cast<VPWidenCastRecipe>(VecOp)->getOpcode(), Ctx.Types.inferScalarType(A))) return new VPExpressionRecipe(cast<VPWidenCastRecipe>(VecOp), Red); @@ -3560,6 +3572,8 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx, static VPExpressionRecipe * tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, VPCostContext &Ctx, VFRange &Range) { + bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red); + unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()); if (Opcode != Instruction::Add && Opcode != Instruction::Sub) return nullptr; @@ -3568,16 +3582,41 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, // Clamp the range if using multiply-accumulate-reduction is profitable. auto IsMulAccValidAndClampRange = - [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0, - VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool { + [&](VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1, + VPWidenCastRecipe *OuterExt) -> bool { return LoopVectorizationPlanner::getDecisionAndClampRange( [&](ElementCount VF) { TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; Type *SrcTy = Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy; - auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF)); - InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost( - isZExt, Opcode, RedTy, SrcVecTy, CostKind); + InstructionCost MulAccCost; + + if (IsPartialReduction) { + Type *SrcTy2 = + Ext1 ? Ctx.Types.inferScalarType(Ext1->getOperand(0)) : nullptr; + // FIXME: Move partial reduction creation, costing and clamping + // here from LoopVectorize.cpp. + MulAccCost = Ctx.TTI.getPartialReductionCost( + Opcode, SrcTy, SrcTy2, RedTy, VF, + Ext0 ? TargetTransformInfo::getPartialReductionExtendKind( + Ext0->getOpcode()) + : TargetTransformInfo::PR_None, + Ext1 ? TargetTransformInfo::getPartialReductionExtendKind( + Ext1->getOpcode()) + : TargetTransformInfo::PR_None, + Mul->getOpcode(), CostKind); + } else { + // Only partial reductions support mixed extends at the moment. + if (Ext0 && Ext1 && Ext0->getOpcode() != Ext1->getOpcode()) + return false; + + bool IsZExt = + !Ext0 || Ext0->getOpcode() == Instruction::CastOps::ZExt; + auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF)); + MulAccCost = Ctx.TTI.getMulAccReductionCost(IsZExt, Opcode, RedTy, + SrcVecTy, CostKind); + } + InstructionCost MulCost = Mul->computeCost(VF, Ctx); InstructionCost RedCost = Red->computeCost(VF, Ctx); InstructionCost ExtCost = 0; @@ -3611,14 +3650,10 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe()); auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe()); - // Match reduce.add(mul(ext, ext)). - if (RecipeA && RecipeB && - (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B) && - match(RecipeA, m_ZExtOrSExt(m_VPValue())) && + // Match reduce.add/sub(mul(ext, ext)). + if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) && match(RecipeB, m_ZExtOrSExt(m_VPValue())) && - IsMulAccValidAndClampRange(RecipeA->getOpcode() == - Instruction::CastOps::ZExt, - Mul, RecipeA, RecipeB, nullptr)) { + IsMulAccValidAndClampRange(Mul, RecipeA, RecipeB, nullptr)) { if (Sub) return new VPExpressionRecipe(RecipeA, RecipeB, Mul, cast<VPWidenRecipe>(Sub), Red); @@ -3626,8 +3661,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, } // Match reduce.add(mul). // TODO: Add an expression type for this variant with a negated mul - if (!Sub && - IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr)) + if (!Sub && IsMulAccValidAndClampRange(Mul, nullptr, nullptr, nullptr)) return new VPExpressionRecipe(Mul, Red); } // TODO: Add an expression type for negated versions of other expression @@ -3647,9 +3681,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe()); if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) && Ext0->getOpcode() == Ext1->getOpcode() && - IsMulAccValidAndClampRange(Ext0->getOpcode() == - Instruction::CastOps::ZExt, - Mul, Ext0, Ext1, Ext)) { + IsMulAccValidAndClampRange(Mul, Ext0, Ext1, Ext) && Mul->hasOneUse()) { auto *NewExt0 = new VPWidenCastRecipe( Ext0->getOpcode(), Ext0->getOperand(0), Ext->getResultType(), *Ext0, *Ext0, Ext0->getDebugLoc()); diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp index 32e4b88..06c3d75 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp @@ -151,6 +151,8 @@ unsigned vputils::getVFScaleFactor(VPRecipeBase *R) { return RR->getVFScaleFactor(); if (auto *RR = dyn_cast<VPPartialReductionRecipe>(R)) return RR->getVFScaleFactor(); + if (auto *ER = dyn_cast<VPExpressionRecipe>(R)) + return ER->getVFScaleFactor(); assert( (!isa<VPInstruction>(R) || cast<VPInstruction>(R)->getOpcode() != VPInstruction::ReductionStartVector) && |
