Diffstat (limited to 'llvm/lib')
44 files changed, 269 insertions, 103 deletions
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index 27114e0..033f516 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -23,7 +23,6 @@
 #include "llvm/Support/DXILABI.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
-#include <optional>
 
 #define DEBUG_TYPE "dxil-resource"
diff --git a/llvm/lib/Analysis/TFLiteUtils.cpp b/llvm/lib/Analysis/TFLiteUtils.cpp
index 2762e22..fcef1c8 100644
--- a/llvm/lib/Analysis/TFLiteUtils.cpp
+++ b/llvm/lib/Analysis/TFLiteUtils.cpp
@@ -30,7 +30,6 @@
 #include "tensorflow/lite/logger.h"
 
 #include <cassert>
-#include <numeric>
 #include <optional>
 
 using namespace llvm;
diff --git a/llvm/lib/Analysis/TrainingLogger.cpp b/llvm/lib/Analysis/TrainingLogger.cpp
index 344ca92..39f79cf 100644
--- a/llvm/lib/Analysis/TrainingLogger.cpp
+++ b/llvm/lib/Analysis/TrainingLogger.cpp
@@ -23,7 +23,6 @@
 #include "llvm/Support/raw_ostream.h"
 
 #include <cassert>
-#include <numeric>
 
 using namespace llvm;
diff --git a/llvm/lib/CodeGen/AsmPrinter/DwarfDebug.cpp b/llvm/lib/CodeGen/AsmPrinter/DwarfDebug.cpp
index 567acf7..fde48a3 100644
--- a/llvm/lib/CodeGen/AsmPrinter/DwarfDebug.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/DwarfDebug.cpp
@@ -1544,18 +1544,8 @@ void DwarfDebug::ensureAbstractEntityIsCreatedIfScoped(DwarfCompileUnit &CU,
 }
 
 static const DILocalScope *getRetainedNodeScope(const MDNode *N) {
-  const DIScope *S;
-  if (const auto *LV = dyn_cast<DILocalVariable>(N))
-    S = LV->getScope();
-  else if (const auto *L = dyn_cast<DILabel>(N))
-    S = L->getScope();
-  else if (const auto *IE = dyn_cast<DIImportedEntity>(N))
-    S = IE->getScope();
-  else
-    llvm_unreachable("Unexpected retained node!");
-  // Ensure the scope is not a DILexicalBlockFile.
-  return cast<DILocalScope>(S)->getNonLexicalBlockFileScope();
+  return DISubprogram::getRetainedNodeScope(N)->getNonLexicalBlockFileScope();
 }
 
 // Collect variable information from side table maintained by MF.
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 0309e22..b6dd174 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -1839,7 +1839,8 @@ bool CodeGenPrepare::unfoldPowerOf2Test(CmpInst *Cmp) {
 /// lose; some adjustment may be wanted there.
 ///
 /// Return true if any changes are made.
-static bool sinkCmpExpression(CmpInst *Cmp, const TargetLowering &TLI) {
+static bool sinkCmpExpression(CmpInst *Cmp, const TargetLowering &TLI,
+                              const DataLayout &DL) {
   if (TLI.hasMultipleConditionRegisters(EVT::getEVT(Cmp->getType())))
     return false;
 
@@ -1847,6 +1848,18 @@ static bool sinkCmpExpression(CmpInst *Cmp, const TargetLowering &TLI) {
   if (TLI.useSoftFloat() && isa<FCmpInst>(Cmp))
     return false;
 
+  bool UsedInPhiOrCurrentBlock = any_of(Cmp->users(), [Cmp](User *U) {
+    return isa<PHINode>(U) ||
+           cast<Instruction>(U)->getParent() == Cmp->getParent();
+  });
+
+  // Avoid sinking larger-than-legal integer comparisons unless it's only used
+  // in another BB.
+  if (UsedInPhiOrCurrentBlock && Cmp->getOperand(0)->getType()->isIntegerTy() &&
+      Cmp->getOperand(0)->getType()->getScalarSizeInBits() >
+          DL.getLargestLegalIntTypeSizeInBits())
+    return false;
+
   // Only insert a cmp in each block once.
   DenseMap<BasicBlock *, CmpInst *> InsertedCmps;
@@ -2224,7 +2237,7 @@ bool CodeGenPrepare::optimizeURem(Instruction *Rem) {
 }
 
 bool CodeGenPrepare::optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT) {
-  if (sinkCmpExpression(Cmp, *TLI))
+  if (sinkCmpExpression(Cmp, *TLI, *DL))
     return true;
 
   if (combineToUAddWithOverflow(Cmp, ModifiedDT))
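Note on the CodeGenPrepare change above: the new bail-out only fires for integer compares wider than the target's largest legal integer type that also have a phi or same-block user. A minimal sketch of the guard in isolation, assuming a DataLayout whose largest legal integer is 64 bits; the helper name is hypothetical, not part of the patch:

    #include "llvm/IR/DataLayout.h"
    #include "llvm/IR/Instructions.h"
    using namespace llvm;

    // Mirrors the new bail-out: e.g. an i128 compare on a 64-bit target is no
    // longer duplicated into other blocks when it also has a phi or
    // same-block user.
    static bool cmpIsWiderThanLegal(const CmpInst &Cmp, const DataLayout &DL) {
      Type *OpTy = Cmp.getOperand(0)->getType();
      return OpTy->isIntegerTy() &&
             OpTy->getScalarSizeInBits() > DL.getLargestLegalIntTypeSizeInBits();
    }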
diff --git a/llvm/lib/CodeGen/RegAllocGreedy.h b/llvm/lib/CodeGen/RegAllocGreedy.h
index 7f013d1..4affa27 100644
--- a/llvm/lib/CodeGen/RegAllocGreedy.h
+++ b/llvm/lib/CodeGen/RegAllocGreedy.h
@@ -33,7 +33,6 @@
 #include "llvm/CodeGen/SpillPlacement.h"
 #include "llvm/CodeGen/Spiller.h"
 #include "llvm/CodeGen/TargetRegisterInfo.h"
-#include <algorithm>
 #include <cstdint>
 #include <memory>
 #include <queue>
diff --git a/llvm/lib/CodeGen/TargetSchedule.cpp b/llvm/lib/CodeGen/TargetSchedule.cpp
index 7ae9e0e..cd951a1 100644
--- a/llvm/lib/CodeGen/TargetSchedule.cpp
+++ b/llvm/lib/CodeGen/TargetSchedule.cpp
@@ -25,7 +25,6 @@
 #include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 #include <cassert>
-#include <numeric>
 
 using namespace llvm;
diff --git a/llvm/lib/DebugInfo/DWARF/DWARFUnwindTablePrinter.cpp b/llvm/lib/DebugInfo/DWARF/DWARFUnwindTablePrinter.cpp
index a88f4a5..a4bdd1f 100644
--- a/llvm/lib/DebugInfo/DWARF/DWARFUnwindTablePrinter.cpp
+++ b/llvm/lib/DebugInfo/DWARF/DWARFUnwindTablePrinter.cpp
@@ -15,7 +15,6 @@
 #include <cassert>
 #include <cinttypes>
 #include <cstdint>
-#include <optional>
 
 using namespace llvm;
 using namespace dwarf;
diff --git a/llvm/lib/Frontend/Offloading/OffloadWrapper.cpp b/llvm/lib/Frontend/Offloading/OffloadWrapper.cpp
index c4aa2c7..45818de 100644
--- a/llvm/lib/Frontend/Offloading/OffloadWrapper.cpp
+++ b/llvm/lib/Frontend/Offloading/OffloadWrapper.cpp
@@ -29,7 +29,6 @@
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 
 #include <memory>
-#include <string>
 #include <utility>
 
 using namespace llvm;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index fff9a81..7dc32fd 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -8460,9 +8460,8 @@ OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
                                             Config.separator());
 }
 
-GlobalVariable *
-OpenMPIRBuilder::getOrCreateInternalVariable(Type *Ty, const StringRef &Name,
-                                             unsigned AddressSpace) {
+GlobalVariable *OpenMPIRBuilder::getOrCreateInternalVariable(
+    Type *Ty, const StringRef &Name, std::optional<unsigned> AddressSpace) {
   auto &Elem = *InternalVars.try_emplace(Name, nullptr).first;
   if (Elem.second) {
     assert(Elem.second->getValueType() == Ty &&
@@ -8472,16 +8471,18 @@ OpenMPIRBuilder::getOrCreateInternalVariable(Type *Ty, const StringRef &Name,
     // variable for possibly changing that to internal or private, or maybe
     // create different versions of the function for different OMP internal
     // variables.
+    const DataLayout &DL = M.getDataLayout();
+    unsigned AddressSpaceVal =
+        AddressSpace ? *AddressSpace : DL.getDefaultGlobalsAddressSpace();
     auto Linkage = this->M.getTargetTriple().getArch() == Triple::wasm32
                        ? GlobalValue::InternalLinkage
                        : GlobalValue::CommonLinkage;
     auto *GV = new GlobalVariable(M, Ty, /*IsConstant=*/false, Linkage,
                                   Constant::getNullValue(Ty), Elem.first(),
                                   /*InsertBefore=*/nullptr,
-                                  GlobalValue::NotThreadLocal, AddressSpace);
-    const DataLayout &DL = M.getDataLayout();
+                                  GlobalValue::NotThreadLocal, AddressSpaceVal);
     const llvm::Align TypeAlign = DL.getABITypeAlign(Ty);
-    const llvm::Align PtrAlign = DL.getPointerABIAlignment(AddressSpace);
+    const llvm::Align PtrAlign = DL.getPointerABIAlignment(AddressSpaceVal);
     GV->setAlignment(std::max(TypeAlign, PtrAlign));
     Elem.second = GV;
   }
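The OMPIRBuilder change above makes the globals address space optional. A sketch of the resulting call-site behavior, assuming the declaration defaults AddressSpace to std::nullopt and that OMPBuilder is an initialized OpenMPIRBuilder for module M (the variable names are illustrative, not ones the OpenMP runtime expects):

    Type *Int32Ty = Type::getInt32Ty(M.getContext());
    // With no address space given, the variable now lands in the DataLayout's
    // default globals address space rather than hard-coded AS 0.
    GlobalVariable *Default =
        OMPBuilder.getOrCreateInternalVariable(Int32Ty, ".omp.example.var");
    // An explicit request still wins.
    GlobalVariable *InAS3 = OMPBuilder.getOrCreateInternalVariable(
        Int32Ty, ".omp.example.var.as3", /*AddressSpace=*/3u);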
diff --git a/llvm/lib/IR/DebugInfo.cpp b/llvm/lib/IR/DebugInfo.cpp
index 5883606..4c5fb74 100644
--- a/llvm/lib/IR/DebugInfo.cpp
+++ b/llvm/lib/IR/DebugInfo.cpp
@@ -247,7 +247,7 @@ void DebugInfoFinder::processType(DIType *DT) {
   }
 }
 
-void DebugInfoFinder::processImportedEntity(DIImportedEntity *Import) {
+void DebugInfoFinder::processImportedEntity(const DIImportedEntity *Import) {
   auto *Entity = Import->getEntity();
   if (auto *T = dyn_cast<DIType>(Entity))
     processType(T);
@@ -307,15 +307,13 @@ void DebugInfoFinder::processSubprogram(DISubprogram *SP) {
     }
   }
 
-  for (auto *N : SP->getRetainedNodes()) {
-    if (auto *Var = dyn_cast_or_null<DILocalVariable>(N))
-      processVariable(Var);
-    else if (auto *Import = dyn_cast_or_null<DIImportedEntity>(N))
-      processImportedEntity(Import);
-  }
+  SP->forEachRetainedNode(
+      [this](const DILocalVariable *LV) { processVariable(LV); },
+      [](const DILabel *L) {},
+      [this](const DIImportedEntity *IE) { processImportedEntity(IE); });
 }
 
-void DebugInfoFinder::processVariable(DILocalVariable *DV) {
+void DebugInfoFinder::processVariable(const DILocalVariable *DV) {
   if (!NodesSeen.insert(DV).second)
     return;
   processScope(DV->getScope());
diff --git a/llvm/lib/IR/DebugInfoMetadata.cpp b/llvm/lib/IR/DebugInfoMetadata.cpp
index a98e925..1a6a25e 100644
--- a/llvm/lib/IR/DebugInfoMetadata.cpp
+++ b/llvm/lib/IR/DebugInfoMetadata.cpp
@@ -1441,6 +1441,19 @@ bool DISubprogram::describes(const Function *F) const {
   assert(F && "Invalid function");
   return F->getSubprogram() == this;
 }
+
+const DIScope *DISubprogram::getRawRetainedNodeScope(const MDNode *N) {
+  return visitRetainedNode<DIScope *>(
+      N, [](const DILocalVariable *LV) { return LV->getScope(); },
+      [](const DILabel *L) { return L->getScope(); },
+      [](const DIImportedEntity *IE) { return IE->getScope(); },
+      [](const Metadata *N) { return nullptr; });
+}
+
+const DILocalScope *DISubprogram::getRetainedNodeScope(const MDNode *N) {
+  return cast<DILocalScope>(getRawRetainedNodeScope(N));
+}
+
 DILexicalBlockBase::DILexicalBlockBase(LLVMContext &C, unsigned ID,
                                        StorageType Storage,
                                        ArrayRef<Metadata *> Ops)
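Both helpers above are built on DISubprogram::visitRetainedNode, which, as used here and in the Verifier below, takes one callback per retained-node kind (DILocalVariable, DILabel, DIImportedEntity) plus a fallback for anything else. A minimal sketch of dispatching on a retained node N; the callback bodies are illustrative:

    // Assumes N is an operand of a subprogram's retainedNodes tuple.
    StringRef Kind = DISubprogram::visitRetainedNode<StringRef>(
        N, [](const DILocalVariable *) { return StringRef("local variable"); },
        [](const DILabel *) { return StringRef("label"); },
        [](const DIImportedEntity *) { return StringRef("imported entity"); },
        [](const Metadata *) { return StringRef("unknown"); });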
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 59eb870..a4f6474 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -1559,11 +1559,27 @@ void Verifier::visitDISubprogram(const DISubprogram &N) {
     auto *Node = dyn_cast<MDTuple>(RawNode);
     CheckDI(Node, "invalid retained nodes list", &N, RawNode);
     for (Metadata *Op : Node->operands()) {
-      CheckDI(Op && (isa<DILocalVariable>(Op) || isa<DILabel>(Op) ||
-                     isa<DIImportedEntity>(Op)),
+      CheckDI(Op, "nullptr in retained nodes", &N, Node);
+
+      auto True = [](const Metadata *) { return true; };
+      auto False = [](const Metadata *) { return false; };
+      bool IsTypeCorrect =
+          DISubprogram::visitRetainedNode<bool>(Op, True, True, True, False);
+      CheckDI(IsTypeCorrect,
              "invalid retained nodes, expected DILocalVariable, DILabel or "
              "DIImportedEntity",
              &N, Node, Op);
+
+      auto *RetainedNode = cast<DINode>(Op);
+      auto *RetainedNodeScope = dyn_cast_or_null<DILocalScope>(
+          DISubprogram::getRawRetainedNodeScope(RetainedNode));
+      CheckDI(RetainedNodeScope,
+              "invalid retained nodes, retained node is not local", &N, Node,
+              RetainedNode);
+      CheckDI(
+          RetainedNodeScope->getSubprogram() == &N,
+          "invalid retained nodes, retained node does not belong to subprogram",
+          &N, Node, RetainedNode, RetainedNodeScope);
     }
   }
   CheckDI(!hasConflictingReferenceFlags(N.getFlags()),
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 40ceb6f..e0babc4 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -900,6 +900,11 @@ Expected<bool> parseEntryExitInstrumenterPassOptions(StringRef Params) {
                                             "EntryExitInstrumenter");
 }
 
+Expected<bool> parseDropUnnecessaryAssumesPassOptions(StringRef Params) {
+  return PassBuilder::parseSinglePassOption(Params, "drop-deref",
+                                            "DropUnnecessaryAssumes");
+}
+
 Expected<bool> parseLoopExtractorPassOptions(StringRef Params) {
   return PassBuilder::parseSinglePassOption(Params, "single", "LoopExtractor");
 }
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index 3f41618..2fe963b 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -1298,10 +1298,18 @@ PassBuilder::buildModuleSimplificationPipeline(OptimizationLevel Level,
 /// TODO: Should LTO cause any differences to this set of passes?
 void PassBuilder::addVectorPasses(OptimizationLevel Level,
-                                  FunctionPassManager &FPM, bool IsFullLTO) {
+                                  FunctionPassManager &FPM,
+                                  ThinOrFullLTOPhase LTOPhase) {
+  const bool IsFullLTO = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink;
+
   FPM.addPass(LoopVectorizePass(
       LoopVectorizeOptions(!PTO.LoopInterleaving, !PTO.LoopVectorization)));
 
+  // Drop dereferenceable assumes after vectorization, as they are no longer
+  // needed and can inhibit further optimization.
+  if (!isLTOPreLink(LTOPhase))
+    FPM.addPass(DropUnnecessaryAssumesPass(/*DropDereferenceable=*/true));
+
   FPM.addPass(InferAlignmentPass());
   if (IsFullLTO) {
     // The vectorizer may have significantly shortened a loop body; unroll
@@ -1572,7 +1580,7 @@ PassBuilder::buildModuleOptimizationPipeline(OptimizationLevel Level,
   // from the TargetLibraryInfo.
   OptimizePM.addPass(InjectTLIMappings());
 
-  addVectorPasses(Level, OptimizePM, /* IsFullLTO */ false);
+  addVectorPasses(Level, OptimizePM, LTOPhase);
 
   invokeVectorizerEndEPCallbacks(OptimizePM, Level);
 
@@ -2162,7 +2170,7 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
 
   MainFPM.addPass(LoopDistributePass());
 
-  addVectorPasses(Level, MainFPM, /* IsFullLTO */ true);
+  addVectorPasses(Level, MainFPM, ThinOrFullLTOPhase::FullLTOPostLink);
 
   invokeVectorizerEndEPCallbacks(MainFPM, Level);
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index d870f99..d8305fe5 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -432,7 +432,6 @@ FUNCTION_PASS("dot-post-dom", PostDomPrinter())
 FUNCTION_PASS("dot-post-dom-only", PostDomOnlyPrinter())
 FUNCTION_PASS("dse", DSEPass())
 FUNCTION_PASS("dwarf-eh-prepare", DwarfEHPreparePass(*TM))
-FUNCTION_PASS("drop-unnecessary-assumes", DropUnnecessaryAssumesPass())
 FUNCTION_PASS("expand-large-div-rem", ExpandLargeDivRemPass(*TM))
 FUNCTION_PASS("expand-memcmp", ExpandMemCmpPass(*TM))
 FUNCTION_PASS("expand-reductions", ExpandReductionsPass())
@@ -585,6 +584,10 @@ FUNCTION_PASS_WITH_PARAMS(
     [](bool UseMemorySSA) { return EarlyCSEPass(UseMemorySSA); },
     parseEarlyCSEPassOptions, "memssa")
 FUNCTION_PASS_WITH_PARAMS(
+    "drop-unnecessary-assumes", "DropUnnecessaryAssumesPass",
+    [](bool DropDereferenceable) { return DropUnnecessaryAssumesPass(DropDereferenceable); },
+    parseDropUnnecessaryAssumesPassOptions, "drop-deref")
+FUNCTION_PASS_WITH_PARAMS(
     "ee-instrument", "EntryExitInstrumenterPass",
     [](bool PostInlining) { return EntryExitInstrumenterPass(PostInlining); },
     parseEntryExitInstrumenterPassOptions, "post-inline")
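With the registration above, drop-unnecessary-assumes becomes a parameterized pass: the plain spelling keeps dereferenceable assumption bundles, while the drop-deref variant (the mode addVectorPasses now schedules after the loop vectorizer in non-pre-link pipelines) also drops them. Assuming the usual new-pass-manager parameter syntax, the two forms would be invoked as:

    opt -passes='drop-unnecessary-assumes' input.ll
    opt -passes='drop-unnecessary-assumes<drop-deref>' input.ll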
diff --git a/llvm/lib/Support/Mustache.cpp b/llvm/lib/Support/Mustache.cpp
index 6c140be..8b95049 100644
--- a/llvm/lib/Support/Mustache.cpp
+++ b/llvm/lib/Support/Mustache.cpp
@@ -10,7 +10,6 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cctype>
-#include <optional>
 #include <sstream>
 
 #define DEBUG_TYPE "mustache"
diff --git a/llvm/lib/Support/Unix/Unix.h b/llvm/lib/Support/Unix/Unix.h
index a1d44c6..f24d524 100644
--- a/llvm/lib/Support/Unix/Unix.h
+++ b/llvm/lib/Support/Unix/Unix.h
@@ -22,7 +22,6 @@
 #include "llvm/Support/Chrono.h"
 #include "llvm/Support/Errno.h"
 #include "llvm/Support/ErrorHandling.h"
-#include <algorithm>
 #include <assert.h>
 #include <cerrno>
 #include <cstdio>
diff --git a/llvm/lib/Support/Windows/Program.inc b/llvm/lib/Support/Windows/Program.inc
index ec785e4..5dcd2c9 100644
--- a/llvm/lib/Support/Windows/Program.inc
+++ b/llvm/lib/Support/Windows/Program.inc
@@ -23,7 +23,6 @@
 #include <fcntl.h>
 #include <io.h>
 #include <malloc.h>
-#include <numeric>
 #include <psapi.h>
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Support/Windows/Signals.inc b/llvm/lib/Support/Windows/Signals.inc
index da68994..bacbb76 100644
--- a/llvm/lib/Support/Windows/Signals.inc
+++ b/llvm/lib/Support/Windows/Signals.inc
@@ -16,7 +16,6 @@
 #include "llvm/Support/Path.h"
 #include "llvm/Support/Process.h"
 #include "llvm/Support/WindowsError.h"
-#include <algorithm>
 #include <io.h>
 #include <signal.h>
 #include <stdio.h>
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 76a790dc..8457f61 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -22308,6 +22308,37 @@ static SDValue performExtBinopLoadFold(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(N->getOpcode(), DL, VT, Ext0, NShift);
 }
 
+// Attempt to combine the following patterns:
+// SUB x, (CSET LO, (CMP a, b)) -> SBC x, 0, (CMP a, b)
+// SUB (SUB x, y), (CSET LO, (CMP a, b)) -> SBC x, y, (CMP a, b)
+// The CSET may be preceded by a ZEXT.
+static SDValue performSubWithBorrowCombine(SDNode *N, SelectionDAG &DAG) {
+  if (N->getOpcode() != ISD::SUB)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::i32 && VT != MVT::i64)
+    return SDValue();
+
+  SDValue N1 = N->getOperand(1);
+  if (N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse())
+    N1 = N1.getOperand(0);
+  if (!N1.hasOneUse() || getCSETCondCode(N1) != AArch64CC::LO)
+    return SDValue();
+
+  SDValue Flags = N1.getOperand(3);
+  if (Flags.getOpcode() != AArch64ISD::SUBS)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue N0 = N->getOperand(0);
+  if (N0->getOpcode() == ISD::SUB)
+    return DAG.getNode(AArch64ISD::SBC, DL, VT, N0.getOperand(0),
+                       N0.getOperand(1), Flags);
+  return DAG.getNode(AArch64ISD::SBC, DL, VT, N0, DAG.getConstant(0, DL, VT),
+                     Flags);
+}
+
 static SDValue performAddSubCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
   // Try to change sum of two reductions.
@@ -22329,6 +22360,8 @@ static SDValue performAddSubCombine(SDNode *N,
     return Val;
   if (SDValue Val = performAddSubIntoVectorOp(N, DCI.DAG))
     return Val;
+  if (SDValue Val = performSubWithBorrowCombine(N, DCI.DAG))
+    return Val;
   if (SDValue Val = performExtBinopLoadFold(N, DCI.DAG))
     return Val;
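A sketch in C of the borrow pattern the new AArch64 combine targets. With the combine, the carry-out of the compare feeds an SBC directly instead of being materialized by CSET and then subtracted; the instruction sequence in the comment is the expected lowering, not verified compiler output:

    // Expected: cmp a, b ; sbc x0, x, y   (previously: cmp + cset + sub)
    unsigned long sub_with_borrow(unsigned long x, unsigned long y,
                                  unsigned long a, unsigned long b) {
      return (x - y) - (a < b);
    }

This works because after SUBS a, b the AArch64 carry flag is the inverted borrow: SBC computes x - y - 1 exactly when a < b (LO), and x - y otherwise.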
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 76f076a..b30e3d0 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -4444,6 +4444,11 @@ defm PRFUM : PrefetchUnscaled<0b11, 0, 0b10, "prfum",
                  [(AArch64Prefetch timm:$Rt,
                                    (am_unscaled64 GPR64sp:$Rn, simm9:$offset))]>;
 
+// PRFM falls back to PRFUM for negative or unaligned offsets (not a multiple
+// of 8).
+def : InstAlias<"prfm $Rt, [$Rn, $offset]",
+                (PRFUMi prfop:$Rt, GPR64sp:$Rn, simm9_offset_fb64:$offset), 0>;
+
 //---
 // (unscaled immediate, unprivileged)
 defm LDTRX : LoadUnprivileged<0b11, 0, 0b01, GPR64, "ldtr">;
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index e1f4386..65b6077 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3597,6 +3597,18 @@ let Predicates = [HasSVE_or_SME] in {
   def : Pat<(sext (i32 (vector_extract nxv4i32:$vec, VectorIndexS:$index))),
             (SMOVvi32to64 (v4i32 (EXTRACT_SUBREG ZPR:$vec, zsub)), VectorIndexS:$index)>;
+
+  // Extracts of ``unsigned'' i8 or i16 elements lead to the zero-extend being
+  // transformed to an AND mask. The mask is redundant since UMOV already zeroes
+  // the high bits of the destination register.
+  def : Pat<(i32 (and (vector_extract nxv16i8:$vec, VectorIndexB:$index), 0xff)),
+            (UMOVvi8 (v16i8 (EXTRACT_SUBREG ZPR:$vec, zsub)), VectorIndexB:$index)>;
+  def : Pat<(i32 (and (vector_extract nxv8i16:$vec, VectorIndexH:$index), 0xffff)),
+            (UMOVvi16 (v8i16 (EXTRACT_SUBREG ZPR:$vec, zsub)), VectorIndexH:$index)>;
+  def : Pat<(i64 (and (i64 (anyext (i32 (vector_extract nxv16i8:$vec, VectorIndexB:$index)))), (i64 0xff))),
+            (SUBREG_TO_REG (i64 0), (i32 (UMOVvi8 (v16i8 (EXTRACT_SUBREG ZPR:$vec, zsub)), VectorIndexB:$index)), sub_32)>;
+  def : Pat<(i64 (and (i64 (anyext (i32 (vector_extract nxv8i16:$vec, VectorIndexH:$index)))), (i64 0xffff))),
+            (SUBREG_TO_REG (i64 0), (i32 (UMOVvi16 (v8i16 (EXTRACT_SUBREG ZPR:$vec, zsub)), VectorIndexH:$index)), sub_32)>;
 } // End HasNEON
 
 // Extract first element from vector.
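The PRFM alias above is assembler-facing: a form such as prfm pldl1keep, [x0, #-8], whose offset the scaled PRFM encoding cannot represent, should now assemble to the unscaled PRFUM encoding instead of being rejected (the prefetch-op spelling here is illustrative). The SVE patterns that follow are in the same spirit: they elide the AND mask left over from a zero-extended element extract, since UMOV already zeroes the high bits of the destination register.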
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
index 068954f..0bf2b31 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -54,7 +54,6 @@
 #include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
 #include <memory>
 #include <optional>
-#include <string>
 
 using namespace llvm;
diff --git a/llvm/lib/Target/AVR/AVRTargetTransformInfo.h b/llvm/lib/Target/AVR/AVRTargetTransformInfo.h
index 0daeeb8..338a7c8 100644
--- a/llvm/lib/Target/AVR/AVRTargetTransformInfo.h
+++ b/llvm/lib/Target/AVR/AVRTargetTransformInfo.h
@@ -21,7 +21,6 @@
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/BasicTTIImpl.h"
 #include "llvm/IR/Function.h"
-#include <optional>
 
 namespace llvm {
diff --git a/llvm/lib/Target/CSKY/CSKYConstantIslandPass.cpp b/llvm/lib/Target/CSKY/CSKYConstantIslandPass.cpp
index 7885d93..a2cf0a5 100644
--- a/llvm/lib/Target/CSKY/CSKYConstantIslandPass.cpp
+++ b/llvm/lib/Target/CSKY/CSKYConstantIslandPass.cpp
@@ -48,7 +48,6 @@
 #include "llvm/Support/Format.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
-#include <algorithm>
 #include <cassert>
 #include <cstdint>
 #include <iterator>
diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index 677203d..95577dd 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -29,7 +29,6 @@
 #include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 #include <cstdint>
-#include <optional>
 
 using namespace llvm;
 using namespace llvm::dxil;
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h
index b990b6c..ec82aa9 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.h
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.h
@@ -21,7 +21,6 @@
 #include "llvm/IR/PassManager.h"
 #include "llvm/MC/DXContainerRootSignature.h"
 #include "llvm/Pass.h"
-#include <optional>
 
 namespace llvm {
 namespace dxil {
diff --git a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.h b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.h
index f2c00c7..7cbc092 100644
--- a/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.h
+++ b/llvm/lib/Target/DirectX/DXILWriter/DXILBitcodeWriter.h
@@ -19,7 +19,6 @@
 #include "llvm/Support/Allocator.h"
 #include "llvm/Support/MemoryBufferRef.h"
 #include <memory>
-#include <string>
 #include <vector>
 
 namespace llvm {
diff --git a/llvm/lib/Target/M68k/M68kSubtarget.h b/llvm/lib/Target/M68k/M68kSubtarget.h
index 16ca7d2..4f96858 100644
--- a/llvm/lib/Target/M68k/M68kSubtarget.h
+++ b/llvm/lib/Target/M68k/M68kSubtarget.h
@@ -27,8 +27,6 @@
 #include "llvm/MC/MCInstrItineraries.h"
 #include "llvm/Support/Alignment.h"
 
-#include <string>
-
 #define GET_SUBTARGETINFO_HEADER
 #include "M68kGenSubtargetInfo.inc"
diff --git a/llvm/lib/Target/PowerPC/PPCSubtarget.h b/llvm/lib/Target/PowerPC/PPCSubtarget.h
index f275802..7d93358 100644
--- a/llvm/lib/Target/PowerPC/PPCSubtarget.h
+++ b/llvm/lib/Target/PowerPC/PPCSubtarget.h
@@ -23,7 +23,6 @@
 #include "llvm/IR/DataLayout.h"
 #include "llvm/MC/MCInstrItineraries.h"
 #include "llvm/TargetParser/Triple.h"
-#include <string>
 
 #define GET_SUBTARGETINFO_HEADER
 #include "PPCGenSubtargetInfo.inc"
diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td
index 5b72334..0b964c4 100644
--- a/llvm/lib/Target/RISCV/RISCVFeatures.td
+++ b/llvm/lib/Target/RISCV/RISCVFeatures.td
@@ -956,6 +956,9 @@ def FeatureStdExtSsdbltrp
 def FeatureStdExtSmepmp
     : RISCVExtension<1, 0, "Enhanced Physical Memory Protection">;
 
+def FeatureStdExtSmpmpmt
+    : RISCVExperimentalExtension<0, 6, "PMP-based Memory Types Extension">;
+
 def FeatureStdExtSmrnmi
     : RISCVExtension<1, 0, "Resumable Non-Maskable Interrupts">;
 def HasStdExtSmrnmi : Predicate<"Subtarget->hasStdExtSmrnmi()">,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 332433b..3d8eb40 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1683,7 +1683,8 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
       !TypeSize::isKnownLE(DL.getTypeSizeInBits(Src),
                            SrcLT.second.getSizeInBits()) ||
       !TypeSize::isKnownLE(DL.getTypeSizeInBits(Dst),
-                           DstLT.second.getSizeInBits()))
+                           DstLT.second.getSizeInBits()) ||
+      SrcLT.first > 1 || DstLT.first > 1)
     return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
 
   // The split cost is handled by the base getCastInstrCost
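The extra SrcLT.first > 1 || DstLT.first > 1 operands above make the RISC-V cast costing bail out to the base-class cost whenever type legalization splits the source or destination into more than one register: the .first of the getTypeLegalizationCost result is the number of legalized parts. For instance, a cast whose destination vector is wider than a single register on the given subtarget now takes the generic split-cost path rather than the custom single-register RISC-V costing; the specific types affected depend on the subtarget's vector configuration.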
diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
index 0175f2f..970b83d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
@@ -612,13 +612,10 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
   // Collect the SPIRVTypes for fp16, fp32, and fp64 and the constant of
   // type int32 with 0 value to represent the FP Fast Math Mode.
   std::vector<const MachineInstr *> SPIRVFloatTypes;
-  const MachineInstr *ConstZero = nullptr;
+  const MachineInstr *ConstZeroInt32 = nullptr;
   for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
-    // Skip if the instruction is not OpTypeFloat or OpConstant.
     unsigned OpCode = MI->getOpcode();
-    if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantNull)
-      continue;
 
     // Collect the SPIRV type if it's a float.
     if (OpCode == SPIRV::OpTypeFloat) {
@@ -629,14 +626,18 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
         continue;
       }
       SPIRVFloatTypes.push_back(MI);
-    } else {
+      continue;
+    }
+
+    if (OpCode == SPIRV::OpConstantNull) {
       // Check if the constant is int32, if not skip it.
       const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
       MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg());
-      if (!TypeMI || TypeMI->getOperand(1).getImm() != 32)
-        continue;
-
-      ConstZero = MI;
+      bool IsInt32Ty = TypeMI &&
+                       TypeMI->getOpcode() == SPIRV::OpTypeInt &&
+                       TypeMI->getOperand(1).getImm() == 32;
+      if (IsInt32Ty)
+        ConstZeroInt32 = MI;
     }
   }
 
@@ -657,9 +658,9 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
     MCRegister TypeReg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
     Inst.addOperand(MCOperand::createReg(TypeReg));
-    assert(ConstZero && "There should be a constant zero.");
+    assert(ConstZeroInt32 && "There should be a constant zero.");
     MCRegister ConstReg = MAI->getRegisterAlias(
-        ConstZero->getMF(), ConstZero->getOperand(0).getReg());
+        ConstZeroInt32->getMF(), ConstZeroInt32->getOperand(0).getReg());
     Inst.addOperand(MCOperand::createReg(ConstReg));
     outputMCInst(Inst);
   }
diff --git a/llvm/lib/Target/Sparc/SparcSubtarget.h b/llvm/lib/Target/Sparc/SparcSubtarget.h
index b1decca..f575f6d 100644
--- a/llvm/lib/Target/Sparc/SparcSubtarget.h
+++ b/llvm/lib/Target/Sparc/SparcSubtarget.h
@@ -21,7 +21,6 @@
 #include "llvm/IR/DataLayout.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/TargetParser/Triple.h"
-#include <string>
 
 #define GET_SUBTARGETINFO_HEADER
 #include "SparcGenSubtargetInfo.inc"
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index 92a9812..70f7b88 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -119,18 +119,82 @@ InstructionCost WebAssemblyTTIImpl::getCastInstrCost(
     }
   }
 
-  // extend_low
   static constexpr TypeConversionCostTblEntry ConversionTbl[] = {
+      // extend_low
       {ISD::SIGN_EXTEND, MVT::v2i64, MVT::v2i32, 1},
       {ISD::ZERO_EXTEND, MVT::v2i64, MVT::v2i32, 1},
       {ISD::SIGN_EXTEND, MVT::v4i32, MVT::v4i16, 1},
       {ISD::ZERO_EXTEND, MVT::v4i32, MVT::v4i16, 1},
       {ISD::SIGN_EXTEND, MVT::v8i16, MVT::v8i8, 1},
       {ISD::ZERO_EXTEND, MVT::v8i16, MVT::v8i8, 1},
+      // 2 x extend_low
       {ISD::SIGN_EXTEND, MVT::v2i64, MVT::v2i16, 2},
      {ISD::ZERO_EXTEND, MVT::v2i64, MVT::v2i16, 2},
       {ISD::SIGN_EXTEND, MVT::v4i32, MVT::v4i8, 2},
       {ISD::ZERO_EXTEND, MVT::v4i32, MVT::v4i8, 2},
+      // extend_low, extend_high
+      {ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i32, 2},
+      {ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i32, 2},
+      {ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i16, 2},
+      {ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i16, 2},
+      {ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2},
+      {ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2},
+      // 2x extend_low, extend_high
+      {ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i32, 4},
+      {ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i32, 4},
+      {ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i16, 4},
+      {ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i16, 4},
+      // shuffle
+      {ISD::TRUNCATE, MVT::v2i16, MVT::v2i32, 2},
+      {ISD::TRUNCATE, MVT::v2i8, MVT::v2i32, 4},
+      {ISD::TRUNCATE, MVT::v4i16, MVT::v4i32, 2},
+      {ISD::TRUNCATE, MVT::v4i8, MVT::v4i32, 4},
+      // narrow, and
+      {ISD::TRUNCATE, MVT::v8i16, MVT::v8i32, 2},
+      {ISD::TRUNCATE, MVT::v8i8, MVT::v8i16, 2},
+      // narrow, 2x and
+      {ISD::TRUNCATE, MVT::v16i8, MVT::v16i16, 3},
+      // 3x narrow, 4x and
+      {ISD::TRUNCATE, MVT::v8i16, MVT::v8i64, 7},
+      {ISD::TRUNCATE, MVT::v16i8, MVT::v16i32, 7},
+      // 7x narrow, 8x and
+      {ISD::TRUNCATE, MVT::v16i8, MVT::v16i64, 15},
+      // convert_i32x4
+      {ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1},
+      {ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1},
+      {ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1},
+      {ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1},
+      // extend_low, convert
+      {ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 2},
+      {ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 2},
+      {ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2},
+      {ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2},
+      // extend_low x 2, convert
+      {ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8, 3},
+      {ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8, 3},
+      {ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8, 3},
+      {ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8, 3},
+      // several shuffles
+      {ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8, 10},
+      {ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8, 10},
+      {ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 10},
+      {ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 10},
+      // trunc_sat, const, and, 3x narrow
+      {ISD::FP_TO_SINT, MVT::v2i8, MVT::v2f32, 6},
+      {ISD::FP_TO_UINT, MVT::v2i8, MVT::v2f32, 6},
+      {ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f32, 6},
+      {ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f32, 6},
+      // trunc_sat, const, and, narrow
+      {ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 4},
+      {ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 4},
+      {ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 4},
+      {ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 4},
+      // 2x trunc_sat, const, 2x and, 3x narrow
+      {ISD::FP_TO_SINT, MVT::v8i8, MVT::v8f32, 8},
+      {ISD::FP_TO_UINT, MVT::v8i8, MVT::v8f32, 8},
+      // 2x trunc_sat, const, 2x and, narrow
+      {ISD::FP_TO_SINT, MVT::v8i16, MVT::v8f32, 6},
+      {ISD::FP_TO_UINT, MVT::v8i16, MVT::v8f32, 6},
   };
 
   if (const auto *Entry =
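Reading the new table entries: the cost is the expected number of WebAssembly SIMD instructions, per the comments. For example, sext <8 x i16> to <8 x i32> maps to the extend_low/extend_high pair (cost 2), and trunc <16 x i16> to <16 x i8> is one narrow plus two ANDs (cost 3). A hedged sketch of querying such a cost through the generic TTI interface, assuming TTI was built for a wasm SIMD target:

    #include "llvm/Analysis/TargetTransformInfo.h"
    #include "llvm/IR/DerivedTypes.h"
    #include "llvm/IR/Instruction.h"
    using namespace llvm;

    // Expected result for this sext on a SIMD-enabled wasm target: 2
    // (extend_low + extend_high, per the table above).
    InstructionCost costOfSext8x16(const TargetTransformInfo &TTI,
                                   LLVMContext &Ctx) {
      auto *SrcTy = FixedVectorType::get(Type::getInt16Ty(Ctx), 8);
      auto *DstTy = FixedVectorType::get(Type::getInt32Ty(Ctx), 8);
      return TTI.getCastInstrCost(Instruction::SExt, DstTy, SrcTy,
                                  TargetTransformInfo::CastContextHint::None,
                                  TargetTransformInfo::TCK_RecipThroughput);
    }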
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
index 2573066..4146c0e 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
@@ -21,7 +21,6 @@
 #include "WebAssemblyTargetMachine.h"
 #include "llvm/CodeGen/BasicTTIImpl.h"
-#include <algorithm>
 
 namespace llvm {
diff --git a/llvm/lib/Target/X86/X86FlagsCopyLowering.cpp b/llvm/lib/Target/X86/X86FlagsCopyLowering.cpp
index ab6e6d0..b3bf37a 100644
--- a/llvm/lib/Target/X86/X86FlagsCopyLowering.cpp
+++ b/llvm/lib/Target/X86/X86FlagsCopyLowering.cpp
@@ -50,7 +50,6 @@
 #include "llvm/Pass.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
-#include <algorithm>
 #include <cassert>
 #include <iterator>
 #include <utility>
diff --git a/llvm/lib/Target/X86/X86TargetMachine.cpp b/llvm/lib/Target/X86/X86TargetMachine.cpp
index 0c2bd7c..d4ad98a 100644
--- a/llvm/lib/Target/X86/X86TargetMachine.cpp
+++ b/llvm/lib/Target/X86/X86TargetMachine.cpp
@@ -50,7 +50,6 @@
 #include "llvm/Transforms/CFGuard.h"
 #include <memory>
 #include <optional>
-#include <string>
 
 using namespace llvm;
diff --git a/llvm/lib/Transforms/Scalar/DropUnnecessaryAssumes.cpp b/llvm/lib/Transforms/Scalar/DropUnnecessaryAssumes.cpp
index a577f51..4a7144f 100644
--- a/llvm/lib/Transforms/Scalar/DropUnnecessaryAssumes.cpp
+++ b/llvm/lib/Transforms/Scalar/DropUnnecessaryAssumes.cpp
@@ -78,11 +78,16 @@ DropUnnecessaryAssumesPass::run(Function &F, FunctionAnalysisManager &FAM) {
     SmallVector<OperandBundleDef> KeptBundles;
     unsigned NumBundles = Assume->getNumOperandBundles();
     for (unsigned I = 0; I != NumBundles; ++I) {
-      auto IsDead = [](OperandBundleUse Bundle) {
+      auto IsDead = [&](OperandBundleUse Bundle) {
         // "ignore" operand bundles are always dead.
         if (Bundle.getTagName() == "ignore")
          return true;
 
+        // "dereferenceable" operand bundles are only dropped if requested
+        // (e.g., after loop vectorization has run).
+        if (Bundle.getTagName() == "dereferenceable")
+          return DropDereferenceable;
+
         // Bundles without arguments do not affect any specific values.
         // Always keep them for now.
         if (Bundle.Inputs.empty())
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 45b5570..566d6ea 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1232,6 +1232,30 @@ public:
   /// Superset of instructions that return true for isScalarWithPredication.
   bool isPredicatedInst(Instruction *I) const;
 
+  /// A helper function that returns how much we should divide the cost of a
+  /// predicated block by. Typically this is the reciprocal of the block
+  /// probability, i.e. if we return X we are assuming the predicated block
+  /// will execute once for every X iterations of the loop header, so the
+  /// block should only contribute 1/X of its cost to the total cost
+  /// calculation. When optimizing for code size it will just be 1, as code
+  /// size costs don't depend on execution probabilities.
+  ///
+  /// TODO: We should use actual block probability here, if available.
+  /// Currently, we always assume predicated blocks have a 50% chance of
+  /// executing, apart from blocks that are only predicated due to tail
+  /// folding.
+  inline unsigned
+  getPredBlockCostDivisor(TargetTransformInfo::TargetCostKind CostKind,
+                          BasicBlock *BB) const {
+    // If a block wasn't originally predicated but was predicated due to
+    // e.g. tail folding, don't divide the cost. Tail folded loops may still
+    // be predicated in the final vector loop iteration, but for most loops
+    // that don't have low trip counts we can expect that probability to be
+    // close to zero.
+    if (!Legal->blockNeedsPredication(BB))
+      return 1;
+    return CostKind == TTI::TCK_CodeSize ? 1 : 2;
+  }
+
   /// Return the costs for our two available strategies for lowering a
   /// div/rem operation which requires speculating at least one lane.
   /// First result is for scalarization (will be invalid for scalable
@@ -2887,7 +2911,8 @@ LoopVectorizationCostModel::getDivRemSpeculationCost(Instruction *I,
     // Scale the cost by the probability of executing the predicated blocks.
     // This assumes the predicated block for each vector lane is equally
     // likely.
-    ScalarizationCost = ScalarizationCost / getPredBlockCostDivisor(CostKind);
+    ScalarizationCost =
+        ScalarizationCost / getPredBlockCostDivisor(CostKind, I->getParent());
   }
 
   InstructionCost SafeDivisorCost = 0;
@@ -5032,7 +5057,7 @@ InstructionCost LoopVectorizationCostModel::computePredInstDiscount(
     }
 
     // Scale the total scalar cost by block probability.
-    ScalarCost /= getPredBlockCostDivisor(CostKind);
+    ScalarCost /= getPredBlockCostDivisor(CostKind, I->getParent());
 
     // Compute the discount. A non-negative discount means the vector version
     // of the instruction costs more, and scalarizing would be beneficial.
@@ -5082,10 +5107,11 @@ InstructionCost LoopVectorizationCostModel::expectedCost(ElementCount VF) {
     // stores and instructions that may divide by zero) will now be
     // unconditionally executed. For the scalar case, we may not always execute
     // the predicated block, if it is an if-else block. Thus, scale the block's
-    // cost by the probability of executing it. blockNeedsPredication from
-    // Legal is used so as to not include all blocks in tail folded loops.
-    if (VF.isScalar() && Legal->blockNeedsPredication(BB))
-      BlockCost /= getPredBlockCostDivisor(CostKind);
+    // cost by the probability of executing it.
+    // getPredBlockCostDivisor will return 1 for blocks that are only
+    // predicated by the header mask when folding the tail.
+    if (VF.isScalar())
+      BlockCost /= getPredBlockCostDivisor(CostKind, BB);
 
     Cost += BlockCost;
   }
@@ -5164,7 +5190,7 @@ LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I,
   // conditional branches, but may not be executed for each vector lane. Scale
   // the cost by the probability of executing the predicated block.
   if (isPredicatedInst(I)) {
-    Cost /= getPredBlockCostDivisor(CostKind);
+    Cost /= getPredBlockCostDivisor(CostKind, I->getParent());
 
     // Add the cost of an i1 extract and a branch
     auto *VecI1Ty =
@@ -6732,6 +6758,10 @@ bool VPCostContext::skipCostComputation(Instruction *UI, bool IsVector) const {
          SkipCostComputation.contains(UI);
 }
 
+unsigned VPCostContext::getPredBlockCostDivisor(BasicBlock *BB) const {
+  return CM.getPredBlockCostDivisor(CostKind, BB);
+}
+
 InstructionCost
 LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
                                           VPCostContext &CostCtx) const {
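Concretely: for a throughput cost kind, a block that LoopVectorizationLegality marks as needing predication contributes half its cost (divisor 2, reflecting the assumed 50% execution probability), so a predicated block costing 8 adds 4 to the estimate. At TCK_CodeSize, or for a block that is only predicated because of tail folding, the divisor is 1 and the full 8 is counted.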
diff --git a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
index 965426f..caabfa7 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h
@@ -50,21 +50,6 @@ Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF);
 Value *createStepForVF(IRBuilderBase &B, Type *Ty, ElementCount VF,
                        int64_t Step);
 
-/// A helper function that returns how much we should divide the cost of a
-/// predicated block by. Typically this is the reciprocal of the block
-/// probability, i.e. if we return X we are assuming the predicated block will
-/// execute once for every X iterations of the loop header so the block should
-/// only contribute 1/X of its cost to the total cost calculation, but when
-/// optimizing for code size it will just be 1 as code size costs don't depend
-/// on execution probabilities.
-///
-/// TODO: We should use actual block probability here, if available. Currently,
-/// we always assume predicated blocks have a 50% chance of executing.
-inline unsigned
-getPredBlockCostDivisor(TargetTransformInfo::TargetCostKind CostKind) {
-  return CostKind == TTI::TCK_CodeSize ? 1 : 2;
-}
-
 /// A range of powers-of-2 vectorization factors with fixed start and
 /// adjustable end. The range includes start and excludes end, e.g.,:
 /// [1, 16) = {1, 2, 4, 8}
@@ -367,6 +352,10 @@ struct VPCostContext {
   /// has already been pre-computed.
   bool skipCostComputation(Instruction *UI, bool IsVector) const;
 
+  /// \returns how much the cost of a predicated block should be divided by.
+  /// Forwards to LoopVectorizationCostModel::getPredBlockCostDivisor.
+  unsigned getPredBlockCostDivisor(BasicBlock *BB) const;
+
   /// Returns the OperandInfo for \p V, if it is a live-in.
   TargetTransformInfo::OperandValueInfo getOperandInfo(VPValue *V) const;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 80cd112..707886f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -3349,7 +3349,7 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
     // Scale the cost by the probability of executing the predicated blocks.
     // This assumes the predicated block for each vector lane is equally
     // likely.
-    ScalarCost /= getPredBlockCostDivisor(Ctx.CostKind);
+    ScalarCost /= Ctx.getPredBlockCostDivisor(UI->getParent());
     return ScalarCost;
   }
   case Instruction::Load:
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 634df51..b319fbc7 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -1745,17 +1745,17 @@ static bool simplifyBranchConditionForVFAndUF(VPlan &Plan, ElementCount BestVF,
   if (match(Term, m_BranchOnCount()) ||
       match(Term, m_BranchOnCond(m_Not(m_ActiveLaneMask(
                       m_VPValue(), m_VPValue(), m_VPValue()))))) {
-    // Try to simplify the branch condition if TC <= VF * UF when the latch
-    // terminator is BranchOnCount or BranchOnCond where the input is
-    // Not(ActiveLaneMask).
-    const SCEV *TripCount =
-        vputils::getSCEVExprForVPValue(Plan.getTripCount(), SE);
-    assert(!isa<SCEVCouldNotCompute>(TripCount) &&
+    // Try to simplify the branch condition if VectorTC <= VF * UF when the
+    // latch terminator is BranchOnCount or BranchOnCond(Not(ActiveLaneMask)).
+    const SCEV *VectorTripCount =
+        vputils::getSCEVExprForVPValue(&Plan.getVectorTripCount(), SE);
+    if (isa<SCEVCouldNotCompute>(VectorTripCount))
+      VectorTripCount =
+          vputils::getSCEVExprForVPValue(Plan.getTripCount(), SE);
+    assert(!isa<SCEVCouldNotCompute>(VectorTripCount) &&
            "Trip count SCEV must be computable");
     ElementCount NumElements = BestVF.multiplyCoefficientBy(BestUF);
-    const SCEV *C = SE.getElementCount(TripCount->getType(), NumElements);
-    if (TripCount->isZero() ||
-        !SE.isKnownPredicate(CmpInst::ICMP_ULE, TripCount, C))
+    const SCEV *C =
+        SE.getElementCount(VectorTripCount->getType(), NumElements);
+    if (!SE.isKnownPredicate(CmpInst::ICMP_ULE, VectorTripCount, C))
       return false;
   } else if (match(Term, m_BranchOnCond(m_VPValue(Cond)))) {
     // For BranchOnCond, check if we can prove the condition to be true using VF
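As an illustration of why the vector trip count is the better bound here: with a scalar trip count of 17, VF = 8 and UF = 2, the old check 17 <= 16 fails, yet the vector loop body runs exactly once, since the vector trip count rounds 17 down to 16, a multiple of VF * UF. The new check 16 <= 16 holds, so the latch branch can still be simplified. The numbers are illustrative, not taken from a test in this change.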
