Diffstat (limited to 'llvm/lib')
52 files changed, 1610 insertions, 469 deletions
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index d52b073..b744537 100755 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -1482,6 +1482,15 @@ Constant *llvm::ConstantFoldFPInstOperands(unsigned Opcode, Constant *LHS, Constant *llvm::ConstantFoldCastOperand(unsigned Opcode, Constant *C, Type *DestTy, const DataLayout &DL) { assert(Instruction::isCast(Opcode)); + + if (auto *CE = dyn_cast<ConstantExpr>(C)) + if (CE->isCast()) + if (unsigned NewOp = CastInst::isEliminableCastPair( + Instruction::CastOps(CE->getOpcode()), + Instruction::CastOps(Opcode), CE->getOperand(0)->getType(), + C->getType(), DestTy, &DL)) + return ConstantFoldCastOperand(NewOp, CE->getOperand(0), DestTy, DL); + switch (Opcode) { default: llvm_unreachable("Missing case"); diff --git a/llvm/lib/Analysis/HeatUtils.cpp b/llvm/lib/Analysis/HeatUtils.cpp index b6a0d58..a1cc707 100644 --- a/llvm/lib/Analysis/HeatUtils.cpp +++ b/llvm/lib/Analysis/HeatUtils.cpp @@ -17,10 +17,10 @@ #include <cmath> -namespace llvm { +using namespace llvm; -static const unsigned heatSize = 100; -static const char heatPalette[heatSize][8] = { +static constexpr unsigned HeatSize = 100; +static constexpr char HeatPalette[HeatSize][8] = { "#3d50c3", "#4055c8", "#4358cb", "#465ecf", "#4961d2", "#4c66d6", "#4f69d9", "#536edd", "#5572df", "#5977e3", "#5b7ae5", "#5f7fe8", "#6282ea", "#6687ed", "#6a8bef", "#6c8ff1", "#7093f3", "#7396f5", "#779af7", "#7a9df8", "#7ea1fa", @@ -37,43 +37,37 @@ static const char heatPalette[heatSize][8] = { "#d24b40", "#d0473d", "#cc403a", "#ca3b37", "#c53334", "#c32e31", "#be242e", "#bb1b2c", "#b70d28"}; -uint64_t -getNumOfCalls(Function &callerFunction, Function &calledFunction) { - uint64_t counter = 0; - for (User *U : calledFunction.users()) { - if (auto CI = dyn_cast<CallInst>(U)) { - if (CI->getCaller() == (&callerFunction)) { - counter += 1; - } - } - } - return counter; +uint64_t llvm::getNumOfCalls(const Function &CallerFunction, + const Function &CalledFunction) { + uint64_t Counter = 0; + for (const User *U : CalledFunction.users()) + if (auto CI = dyn_cast<CallInst>(U)) + Counter += CI->getCaller() == &CallerFunction; + return Counter; } -uint64_t getMaxFreq(const Function &F, const BlockFrequencyInfo *BFI) { - uint64_t maxFreq = 0; +uint64_t llvm::getMaxFreq(const Function &F, const BlockFrequencyInfo *BFI) { + uint64_t MaxFreq = 0; for (const BasicBlock &BB : F) { - uint64_t freqVal = BFI->getBlockFreq(&BB).getFrequency(); - if (freqVal >= maxFreq) - maxFreq = freqVal; + uint64_t FreqVal = BFI->getBlockFreq(&BB).getFrequency(); + if (FreqVal >= MaxFreq) + MaxFreq = FreqVal; } - return maxFreq; + return MaxFreq; } -std::string getHeatColor(uint64_t freq, uint64_t maxFreq) { - if (freq > maxFreq) - freq = maxFreq; - double percent = (freq > 0) ? log2(double(freq)) / log2(maxFreq) : 0; - return getHeatColor(percent); +std::string llvm::getHeatColor(uint64_t Freq, uint64_t MaxFreq) { + if (Freq > MaxFreq) + Freq = MaxFreq; + double Percent = (Freq > 0) ? 
log2(double(Freq)) / log2(MaxFreq) : 0; + return getHeatColor(Percent); } -std::string getHeatColor(double percent) { - if (percent > 1.0) - percent = 1.0; - if (percent < 0.0) - percent = 0.0; - unsigned colorId = unsigned(round(percent * (heatSize - 1.0))); - return heatPalette[colorId]; +std::string llvm::getHeatColor(double Percent) { + if (Percent > 1.0) + Percent = 1.0; + if (Percent < 0.0) + Percent = 0.0; + unsigned ColorID = unsigned(round(Percent * (HeatSize - 1.0))); + return HeatPalette[ColorID]; } - -} // namespace llvm diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 0d978d4..d1977f0 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -5425,15 +5425,8 @@ static Value *simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, if (Src->getType() == Ty) { auto FirstOp = CI->getOpcode(); auto SecondOp = static_cast<Instruction::CastOps>(CastOpc); - Type *SrcIntPtrTy = - SrcTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(SrcTy) : nullptr; - Type *MidIntPtrTy = - MidTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(MidTy) : nullptr; - Type *DstIntPtrTy = - DstTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(DstTy) : nullptr; if (CastInst::isEliminableCastPair(FirstOp, SecondOp, SrcTy, MidTy, DstTy, - SrcIntPtrTy, MidIntPtrTy, - DstIntPtrTy) == Instruction::BitCast) + &Q.DL) == Instruction::BitCast) return Src; } } @@ -6473,7 +6466,8 @@ static Value *foldMinMaxSharedOp(Intrinsic::ID IID, Value *Op0, Value *Op1) { static Value *foldMinimumMaximumSharedOp(Intrinsic::ID IID, Value *Op0, Value *Op1) { assert((IID == Intrinsic::maxnum || IID == Intrinsic::minnum || - IID == Intrinsic::maximum || IID == Intrinsic::minimum) && + IID == Intrinsic::maximum || IID == Intrinsic::minimum || + IID == Intrinsic::maximumnum || IID == Intrinsic::minimumnum) && "Unsupported intrinsic"); auto *M0 = dyn_cast<IntrinsicInst>(Op0); @@ -6512,6 +6506,82 @@ static Value *foldMinimumMaximumSharedOp(Intrinsic::ID IID, Value *Op0, return nullptr; } +enum class MinMaxOptResult { + CannotOptimize = 0, + UseNewConstVal = 1, + UseOtherVal = 2, + // For undef/poison, we can choose to either propagate undef/poison or + // use the LHS value depending on what will allow more optimization. + UseEither = 3 +}; +// Get the optimized value for a min/max instruction with a single constant +// input (either undef or scalar ConstantFP). The result may indicate to +// use the non-const LHS value, use a new constant value instead (with NaNs +// quieted), or to choose either option in the case of undef/poison. 
+static MinMaxOptResult OptimizeConstMinMax(const Constant *RHSConst, + const Intrinsic::ID IID, + const CallBase *Call, + Constant **OutNewConstVal) { + assert(OutNewConstVal != nullptr); + + bool PropagateNaN = IID == Intrinsic::minimum || IID == Intrinsic::maximum; + bool PropagateSNaN = IID == Intrinsic::minnum || IID == Intrinsic::maxnum; + bool IsMin = IID == Intrinsic::minimum || IID == Intrinsic::minnum || + IID == Intrinsic::minimumnum; + + // min/max(x, poison) -> either x or poison + if (isa<UndefValue>(RHSConst)) { + *OutNewConstVal = const_cast<Constant *>(RHSConst); + return MinMaxOptResult::UseEither; + } + + const ConstantFP *CFP = dyn_cast<ConstantFP>(RHSConst); + if (!CFP) + return MinMaxOptResult::CannotOptimize; + APFloat CAPF = CFP->getValueAPF(); + + // minnum(x, qnan) -> x + // maxnum(x, qnan) -> x + // minnum(x, snan) -> qnan + // maxnum(x, snan) -> qnan + // minimum(X, nan) -> qnan + // maximum(X, nan) -> qnan + // minimumnum(X, nan) -> x + // maximumnum(X, nan) -> x + if (CAPF.isNaN()) { + if (PropagateNaN || (PropagateSNaN && CAPF.isSignaling())) { + *OutNewConstVal = ConstantFP::get(CFP->getType(), CAPF.makeQuiet()); + return MinMaxOptResult::UseNewConstVal; + } + return MinMaxOptResult::UseOtherVal; + } + + if (CAPF.isInfinity() || (Call && Call->hasNoInfs() && CAPF.isLargest())) { + // minnum(X, -inf) -> -inf (ignoring sNaN -> qNaN propagation) + // maxnum(X, +inf) -> +inf (ignoring sNaN -> qNaN propagation) + // minimum(X, -inf) -> -inf if nnan + // maximum(X, +inf) -> +inf if nnan + // minimumnum(X, -inf) -> -inf + // maximumnum(X, +inf) -> +inf + if (CAPF.isNegative() == IsMin && + (!PropagateNaN || (Call && Call->hasNoNaNs()))) { + *OutNewConstVal = const_cast<Constant *>(RHSConst); + return MinMaxOptResult::UseNewConstVal; + } + + // minnum(X, +inf) -> X if nnan + // maxnum(X, -inf) -> X if nnan + // minimum(X, +inf) -> X (ignoring quieting of sNaNs) + // maximum(X, -inf) -> X (ignoring quieting of sNaNs) + // minimumnum(X, +inf) -> X if nnan + // maximumnum(X, -inf) -> X if nnan + if (CAPF.isNegative() != IsMin && + (PropagateNaN || (Call && Call->hasNoNaNs()))) + return MinMaxOptResult::UseOtherVal; + } + return MinMaxOptResult::CannotOptimize; +} + Value *llvm::simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType, Value *Op0, Value *Op1, const SimplifyQuery &Q, @@ -6780,8 +6850,17 @@ Value *llvm::simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType, case Intrinsic::maxnum: case Intrinsic::minnum: case Intrinsic::maximum: - case Intrinsic::minimum: { - // If the arguments are the same, this is a no-op. + case Intrinsic::minimum: + case Intrinsic::maximumnum: + case Intrinsic::minimumnum: { + // In several cases here, we deviate from exact IEEE 754 semantics + // to enable optimizations (as allowed by the LLVM IR spec). + // + // For instance, we may return one of the arguments unmodified instead of + // inserting an llvm.canonicalize to transform input sNaNs into qNaNs, + // or may assume all NaN inputs are qNaNs. + + // If the arguments are the same, this is a no-op (ignoring NaN quieting) if (Op0 == Op1) return Op0; @@ -6789,40 +6868,55 @@ Value *llvm::simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType, if (isa<Constant>(Op0)) std::swap(Op0, Op1); - // If an argument is undef, return the other argument. 
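The NaN rows of the table above are the subtle part: minnum/maxnum drop a quiet NaN operand but return a quieted copy of a signaling one, minimumnum/maximumnum drop any NaN, and minimum/maximum always return the NaN, quieted. A minimal standalone sketch of the quieting step that UseNewConstVal carries back, assuming only APFloat from an LLVM checkout:

#include "llvm/ADT/APFloat.h"
#include <cstdio>

using llvm::APFloat;

int main() {
  const llvm::fltSemantics &Sem = APFloat::IEEEsingle();

  // minnum(x, qnan) -> x: a quiet NaN operand is simply dropped
  // (UseOtherVal in the fold above).
  APFloat QNaN = APFloat::getQNaN(Sem);
  std::printf("qnan signaling? %d\n", QNaN.isSignaling()); // 0

  // minnum(x, snan) -> qnan: a signaling NaN cannot be dropped; the fold
  // returns a quieted copy (UseNewConstVal built via CAPF.makeQuiet()).
  APFloat SNaN = APFloat::getSNaN(Sem);
  APFloat Quieted = SNaN.makeQuiet();
  std::printf("snan %d, quieted %d\n", SNaN.isSignaling(),
              Quieted.isSignaling()); // snan 1, quieted 0
  return 0;
}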
- if (Q.isUndefValue(Op1)) - return Op0; + if (Constant *C = dyn_cast<Constant>(Op1)) { + MinMaxOptResult OptResult = MinMaxOptResult::CannotOptimize; + Constant *NewConst = nullptr; + + if (VectorType *VTy = dyn_cast<VectorType>(C->getType())) { + ElementCount ElemCount = VTy->getElementCount(); + + if (Constant *SplatVal = C->getSplatValue()) { + // Handle splat vectors (including scalable vectors) + OptResult = OptimizeConstMinMax(SplatVal, IID, Call, &NewConst); + if (OptResult == MinMaxOptResult::UseNewConstVal) + NewConst = ConstantVector::getSplat(ElemCount, NewConst); + + } else if (ElemCount.isFixed()) { + // Storage to build up new const return value (with NaNs quieted) + SmallVector<Constant *, 16> NewC(ElemCount.getFixedValue()); + + // Check elementwise whether we can optimize to either a constant + // value or return the LHS value. We cannot mix and match LHS + + // constant elements, as this would require inserting a new + // VectorShuffle instruction, which is not allowed in simplifyBinOp. + OptResult = MinMaxOptResult::UseEither; + for (unsigned i = 0; i != ElemCount.getFixedValue(); ++i) { + auto ElemResult = OptimizeConstMinMax(C->getAggregateElement(i), + IID, Call, &NewConst); + if (ElemResult == MinMaxOptResult::CannotOptimize || + (ElemResult != OptResult && + OptResult != MinMaxOptResult::UseEither && + ElemResult != MinMaxOptResult::UseEither)) { + OptResult = MinMaxOptResult::CannotOptimize; + break; + } + NewC[i] = NewConst; + if (ElemResult != MinMaxOptResult::UseEither) + OptResult = ElemResult; + } + if (OptResult == MinMaxOptResult::UseNewConstVal) + NewConst = ConstantVector::get(NewC); + } + } else { + // Handle scalar inputs + OptResult = OptimizeConstMinMax(C, IID, Call, &NewConst); + } - bool PropagateNaN = IID == Intrinsic::minimum || IID == Intrinsic::maximum; - bool IsMin = IID == Intrinsic::minimum || IID == Intrinsic::minnum; - - // minnum(X, nan) -> X - // maxnum(X, nan) -> X - // minimum(X, nan) -> nan - // maximum(X, nan) -> nan - if (match(Op1, m_NaN())) - return PropagateNaN ? propagateNaN(cast<Constant>(Op1)) : Op0; - - // In the following folds, inf can be replaced with the largest finite - // float, if the ninf flag is set. 
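For the infinity rows the decision reduces to a sign test: an infinity whose sign agrees with the operation's direction (-inf for min, +inf for max) absorbs, so the fold returns the constant, while a disagreeing infinity is an identity, so the fold returns the other operand; the intrinsic's NaN-propagation behaviour and the call's nnan flag gate both cases. A self-contained toy of just that predicate (Fold and foldWithInfinity are illustrative names, not the patch's):

#include <cstdio>

enum class Fold { None, Constant, OtherOperand };

// IsMin/PropagateNaN mirror the flags derived from the intrinsic ID in
// OptimizeConstMinMax; NegConst is CAPF.isNegative() and HasNoNaNs is the
// call's nnan fast-math flag.
static Fold foldWithInfinity(bool IsMin, bool PropagateNaN, bool NegConst,
                             bool HasNoNaNs) {
  if (NegConst == IsMin && (!PropagateNaN || HasNoNaNs))
    return Fold::Constant; // e.g. minnum(x, -inf) -> -inf
  if (NegConst != IsMin && (PropagateNaN || HasNoNaNs))
    return Fold::OtherOperand; // e.g. maximum(x, -inf) -> x
  return Fold::None;
}

int main() {
  // minnum(x, -inf) folds to -inf without any fast-math flags.
  std::printf("%d\n", (int)foldWithInfinity(true, false, true, false)); // 1
  // minnum(x, +inf) only folds to x under nnan.
  std::printf("%d\n", (int)foldWithInfinity(true, false, false, false)); // 0
  std::printf("%d\n", (int)foldWithInfinity(true, false, false, true));  // 2
  return 0;
}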
- const APFloat *C; - if (match(Op1, m_APFloat(C)) && - (C->isInfinity() || (Call && Call->hasNoInfs() && C->isLargest()))) { - // minnum(X, -inf) -> -inf - // maxnum(X, +inf) -> +inf - // minimum(X, -inf) -> -inf if nnan - // maximum(X, +inf) -> +inf if nnan - if (C->isNegative() == IsMin && - (!PropagateNaN || (Call && Call->hasNoNaNs()))) - return ConstantFP::get(ReturnType, *C); - - // minnum(X, +inf) -> X if nnan - // maxnum(X, -inf) -> X if nnan - // minimum(X, +inf) -> X - // maximum(X, -inf) -> X - if (C->isNegative() != IsMin && - (PropagateNaN || (Call && Call->hasNoNaNs()))) - return Op0; + if (OptResult == MinMaxOptResult::UseOtherVal || + OptResult == MinMaxOptResult::UseEither) + return Op0; // Return the other arg (ignoring NaN quieting) + else if (OptResult == MinMaxOptResult::UseNewConstVal) + return NewConst; } // Min/max of the same operation with common operand: diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp index 0c4e3a2..4c2e1fe 100644 --- a/llvm/lib/Analysis/Loads.cpp +++ b/llvm/lib/Analysis/Loads.cpp @@ -37,17 +37,13 @@ static bool isDereferenceableAndAlignedPointerViaAssumption( function_ref<bool(const RetainedKnowledge &RK)> CheckSize, const DataLayout &DL, const Instruction *CtxI, AssumptionCache *AC, const DominatorTree *DT) { - // Dereferenceable information from assumptions is only valid if the value - // cannot be freed between the assumption and use. For now just use the - // information for values that cannot be freed in the function. - // TODO: More precisely check if the pointer can be freed between assumption - // and use. - if (!CtxI || Ptr->canBeFreed()) + if (!CtxI) return false; /// Look through assumes to see if both dereferencability and alignment can /// be proven by an assume if needed. RetainedKnowledge AlignRK; RetainedKnowledge DerefRK; + bool PtrCanBeFreed = Ptr->canBeFreed(); bool IsAligned = Ptr->getPointerAlignment(DL) >= Alignment; return getKnowledgeForValue( Ptr, {Attribute::Dereferenceable, Attribute::Alignment}, *AC, @@ -56,7 +52,11 @@ static bool isDereferenceableAndAlignedPointerViaAssumption( return false; if (RK.AttrKind == Attribute::Alignment) AlignRK = std::max(AlignRK, RK); - if (RK.AttrKind == Attribute::Dereferenceable) + + // Dereferenceable information from assumptions is only valid if the + // value cannot be freed between the assumption and use. 
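Only the dereferenceability facts are sensitive to frees: an alignment fact stays true even if the object is freed, so the condition that follows applies the willNotFreeBetween() escape hatch to Dereferenceable knowledge only. A toy model of that filter, with PtrCanBeFreed standing in for Ptr->canBeFreed() and NoFreeBetween for willNotFreeBetween(Assume, CtxI) (illustrative names, plain C++):

#include <cstdio>

struct Fact { bool IsDereferenceable; };

static bool mayUseFact(const Fact &F, bool PtrCanBeFreed, bool NoFreeBetween) {
  if (!F.IsDereferenceable)
    return true; // alignment facts survive a free of the object
  return !PtrCanBeFreed || NoFreeBetween;
}

int main() {
  Fact Deref{true}, Align{false};
  std::printf("%d\n", mayUseFact(Deref, true, false)); // 0: must be dropped
  std::printf("%d\n", mayUseFact(Deref, true, true));  // 1: assume covers use
  std::printf("%d\n", mayUseFact(Align, true, false)); // 1: always usable
  return 0;
}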
+ if ((!PtrCanBeFreed || willNotFreeBetween(Assume, CtxI)) && + RK.AttrKind == Attribute::Dereferenceable) DerefRK = std::max(DerefRK, RK); IsAligned |= AlignRK && AlignRK.ArgValue >= Alignment.value(); if (IsAligned && DerefRK && CheckSize(DerefRK)) @@ -390,7 +390,11 @@ bool llvm::isDereferenceableAndAlignedInLoop( } else return false; - Instruction *HeaderFirstNonPHI = &*L->getHeader()->getFirstNonPHIIt(); + Instruction *CtxI = &*L->getHeader()->getFirstNonPHIIt(); + if (BasicBlock *LoopPred = L->getLoopPredecessor()) { + if (isa<BranchInst>(LoopPred->getTerminator())) + CtxI = LoopPred->getTerminator(); + } return isDereferenceableAndAlignedPointerViaAssumption( Base, Alignment, [&SE, AccessSizeSCEV, &LoopGuards](const RetainedKnowledge &RK) { @@ -399,9 +403,9 @@ bool llvm::isDereferenceableAndAlignedInLoop( SE.applyLoopGuards(AccessSizeSCEV, *LoopGuards), SE.applyLoopGuards(SE.getSCEV(RK.IRArgValue), *LoopGuards)); }, - DL, HeaderFirstNonPHI, AC, &DT) || + DL, CtxI, AC, &DT) || isDereferenceableAndAlignedPointer(Base, Alignment, AccessSize, DL, - HeaderFirstNonPHI, AC, &DT); + CtxI, AC, &DT); } static bool suppressSpeculativeLoadForSanitizers(const Instruction &CtxI) { diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index a42c061..9655c88 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -9095,6 +9095,10 @@ Intrinsic::ID llvm::getInverseMinMaxIntrinsic(Intrinsic::ID MinMaxID) { case Intrinsic::minimum: return Intrinsic::maximum; case Intrinsic::maxnum: return Intrinsic::minnum; case Intrinsic::minnum: return Intrinsic::maxnum; + case Intrinsic::maximumnum: + return Intrinsic::minimumnum; + case Intrinsic::minimumnum: + return Intrinsic::maximumnum; default: llvm_unreachable("Unexpected intrinsic"); } } diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp index 832aa9f..aaee1f0 100644 --- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -2203,6 +2203,8 @@ static Attribute::AttrKind getAttrFromCode(uint64_t Code) { return Attribute::SanitizeRealtime; case bitc::ATTR_KIND_SANITIZE_REALTIME_BLOCKING: return Attribute::SanitizeRealtimeBlocking; + case bitc::ATTR_KIND_SANITIZE_ALLOC_TOKEN: + return Attribute::SanitizeAllocToken; case bitc::ATTR_KIND_SPECULATIVE_LOAD_HARDENING: return Attribute::SpeculativeLoadHardening; case bitc::ATTR_KIND_SWIFT_ERROR: diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp index c4070e1..6d86809 100644 --- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -883,6 +883,8 @@ static uint64_t getAttrKindEncoding(Attribute::AttrKind Kind) { return bitc::ATTR_KIND_STRUCT_RET; case Attribute::SanitizeAddress: return bitc::ATTR_KIND_SANITIZE_ADDRESS; + case Attribute::SanitizeAllocToken: + return bitc::ATTR_KIND_SANITIZE_ALLOC_TOKEN; case Attribute::SanitizeHWAddress: return bitc::ATTR_KIND_SANITIZE_HWADDRESS; case Attribute::SanitizeThread: diff --git a/llvm/lib/CodeGen/AsmPrinter/DebugHandlerBase.cpp b/llvm/lib/CodeGen/AsmPrinter/DebugHandlerBase.cpp index d98d180..dc38f5a 100644 --- a/llvm/lib/CodeGen/AsmPrinter/DebugHandlerBase.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/DebugHandlerBase.cpp @@ -240,6 +240,8 @@ bool DebugHandlerBase::isUnsignedDIType(const DIType *Ty) { Encoding == dwarf::DW_ATE_complex_float || Encoding == dwarf::DW_ATE_signed_fixed || Encoding == dwarf::DW_ATE_unsigned_fixed || + (Encoding 
>= dwarf::DW_ATE_lo_user && + Encoding <= dwarf::DW_ATE_hi_user) || (Ty->getTag() == dwarf::DW_TAG_unspecified_type && Ty->getName() == "decltype(nullptr)")) && "Unsupported encoding"); diff --git a/llvm/lib/CodeGen/InlineSpiller.cpp b/llvm/lib/CodeGen/InlineSpiller.cpp index 0c2b74c..d6e8505 100644 --- a/llvm/lib/CodeGen/InlineSpiller.cpp +++ b/llvm/lib/CodeGen/InlineSpiller.cpp @@ -671,10 +671,22 @@ bool InlineSpiller::reMaterializeFor(LiveInterval &VirtReg, MachineInstr &MI) { LiveInterval &OrigLI = LIS.getInterval(Original); VNInfo *OrigVNI = OrigLI.getVNInfoAt(UseIdx); - LiveRangeEdit::Remat RM(ParentVNI); - RM.OrigMI = LIS.getInstructionFromIndex(OrigVNI->def); + assert(OrigVNI && "corrupted sub-interval"); + MachineInstr *DefMI = LIS.getInstructionFromIndex(OrigVNI->def); + // This can happen for two reasons: 1) this could be a phi valno, + // or 2) the remat def has already been removed from the original + // live interval; this happens if we rematted to all uses, and + // then further split one of those live ranges. + if (!DefMI) { + markValueUsed(&VirtReg, ParentVNI); + LLVM_DEBUG(dbgs() << "\tcannot remat missing def for " << UseIdx << '\t' + << MI); + return false; + } - if (!Edit->canRematerializeAt(RM, OrigVNI, UseIdx)) { + LiveRangeEdit::Remat RM(ParentVNI); + RM.OrigMI = DefMI; + if (!Edit->canRematerializeAt(RM, UseIdx)) { markValueUsed(&VirtReg, ParentVNI); LLVM_DEBUG(dbgs() << "\tcannot remat for " << UseIdx << '\t' << MI); return false; @@ -739,9 +751,6 @@ bool InlineSpiller::reMaterializeFor(LiveInterval &VirtReg, MachineInstr &MI) { /// reMaterializeAll - Try to rematerialize as many uses as possible, /// and trim the live ranges after. void InlineSpiller::reMaterializeAll() { - if (!Edit->anyRematerializable()) - return; - UsedValues.clear(); // Try to remat before all uses of snippets. diff --git a/llvm/lib/CodeGen/LiveRangeEdit.cpp b/llvm/lib/CodeGen/LiveRangeEdit.cpp index 59bc82d..5b0365d 100644 --- a/llvm/lib/CodeGen/LiveRangeEdit.cpp +++ b/llvm/lib/CodeGen/LiveRangeEdit.cpp @@ -68,41 +68,12 @@ Register LiveRangeEdit::createFrom(Register OldReg) { return VReg; } -void LiveRangeEdit::scanRemattable() { - for (VNInfo *VNI : getParent().valnos) { - if (VNI->isUnused()) - continue; - Register Original = VRM->getOriginal(getReg()); - LiveInterval &OrigLI = LIS.getInterval(Original); - VNInfo *OrigVNI = OrigLI.getVNInfoAt(VNI->def); - if (!OrigVNI) - continue; - MachineInstr *DefMI = LIS.getInstructionFromIndex(OrigVNI->def); - if (!DefMI) - continue; - if (TII.isReMaterializable(*DefMI)) - Remattable.insert(OrigVNI); - } - ScannedRemattable = true; -} - -bool LiveRangeEdit::anyRematerializable() { - if (!ScannedRemattable) - scanRemattable(); - return !Remattable.empty(); -} - -bool LiveRangeEdit::canRematerializeAt(Remat &RM, VNInfo *OrigVNI, - SlotIndex UseIdx) { - assert(ScannedRemattable && "Call anyRematerializable first"); +bool LiveRangeEdit::canRematerializeAt(Remat &RM, SlotIndex UseIdx) { + assert(RM.OrigMI && "No defining instruction for remattable value"); - // Use scanRemattable info. - if (!Remattable.count(OrigVNI)) + if (!TII.isReMaterializable(*RM.OrigMI)) return false; - // No defining instruction provided. - assert(RM.OrigMI && "No defining instruction for remattable value"); - // Verify that all used registers are available with the same values. 
if (!VirtRegAuxInfo::allUsesAvailableAt(RM.OrigMI, UseIdx, LIS, MRI, TII)) return false; @@ -303,6 +274,37 @@ void LiveRangeEdit::eliminateDeadDef(MachineInstr *MI, ToShrinkSet &ToShrink) { } } + // If the dest of MI is an original reg and MI is reMaterializable, + // don't delete the inst. Replace the dest with a new reg, and keep + // the inst for remat of other siblings. The inst is saved in + // LiveRangeEdit::DeadRemats and will be deleted after all the + // allocations of the func are done. Note that if we keep the + // instruction with the original operands, that handles the physreg + // operand case (described just below) as well. + // However, immediately delete instructions which have unshrunk virtual + // register uses. That may provoke RA to split an interval at the KILL + // and later result in an invalid live segment end. + if (isOrigDef && DeadRemats && !HasLiveVRegUses && + TII.isReMaterializable(*MI)) { + LiveInterval &NewLI = createEmptyIntervalFrom(Dest, false); + VNInfo::Allocator &Alloc = LIS.getVNInfoAllocator(); + VNInfo *VNI = NewLI.getNextValue(Idx, Alloc); + NewLI.addSegment(LiveInterval::Segment(Idx, Idx.getDeadSlot(), VNI)); + + if (DestSubReg) { + const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); + auto *SR = + NewLI.createSubRange(Alloc, TRI->getSubRegIndexLaneMask(DestSubReg)); + SR->addSegment(LiveInterval::Segment(Idx, Idx.getDeadSlot(), + SR->getNextValue(Idx, Alloc))); + } + + pop_back(); + DeadRemats->insert(MI); + const TargetRegisterInfo &TRI = *MRI.getTargetRegisterInfo(); + MI->substituteRegister(Dest, NewLI.reg(), 0, TRI); + assert(MI->registerDefIsDead(NewLI.reg(), &TRI)); + } // Currently, we don't support DCE of physreg live ranges. If MI reads // any unreserved physregs, don't erase the instruction, but turn it into // a KILL instead. This way, the physreg live ranges don't end up @@ -310,7 +312,7 @@ void LiveRangeEdit::eliminateDeadDef(MachineInstr *MI, ToShrinkSet &ToShrink) { // FIXME: It would be better to have something like shrinkToUses() for // physregs. That could potentially enable more DCE and it would free up // the physreg. It would not happen often, though. - if (ReadsPhysRegs) { + else if (ReadsPhysRegs) { MI->setDesc(TII.get(TargetOpcode::KILL)); // Remove all operands that aren't physregs. for (unsigned i = MI->getNumOperands(); i; --i) { @@ -322,41 +324,11 @@ void LiveRangeEdit::eliminateDeadDef(MachineInstr *MI, ToShrinkSet &ToShrink) { MI->dropMemRefs(*MI->getMF()); LLVM_DEBUG(dbgs() << "Converted physregs to:\t" << *MI); } else { - // If the dest of MI is an original reg and MI is reMaterializable, - // don't delete the inst. Replace the dest with a new reg, and keep - // the inst for remat of other siblings. The inst is saved in - // LiveRangeEdit::DeadRemats and will be deleted after all the - // allocations of the func are done. - // However, immediately delete instructions which have unshrunk virtual - // register uses. That may provoke RA to split an interval at the KILL - // and later result in an invalid live segment end. 
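The hoisted block above keeps the same three outcomes as before; only the nesting changes, so remat-preservation is tried first, then the physreg KILL conversion, then plain erasure (the lines removed below are the old copy of the remat case inside the final else). A compact sketch of the resulting priority order (classify and DeadDefAction are illustrative names, not LLVM API):

#include <cstdio>

enum class DeadDefAction { KeepForRemat, TurnIntoKill, Erase };

// Mirrors the new if / else-if / else chain in eliminateDeadDef().
static DeadDefAction classify(bool IsOrigDef, bool HaveDeadRemats,
                              bool HasLiveVRegUses, bool IsReMaterializable,
                              bool ReadsPhysRegs) {
  if (IsOrigDef && HaveDeadRemats && !HasLiveVRegUses && IsReMaterializable)
    return DeadDefAction::KeepForRemat; // stash in DeadRemats for siblings
  if (ReadsPhysRegs)
    return DeadDefAction::TurnIntoKill; // keep physreg live ranges intact
  return DeadDefAction::Erase;
}

int main() {
  std::printf("%d\n", (int)classify(true, true, false, true, true));   // 0
  std::printf("%d\n", (int)classify(false, true, false, true, true));  // 1
  std::printf("%d\n", (int)classify(false, true, false, true, false)); // 2
  return 0;
}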
- if (isOrigDef && DeadRemats && !HasLiveVRegUses && - TII.isReMaterializable(*MI)) { - LiveInterval &NewLI = createEmptyIntervalFrom(Dest, false); - VNInfo::Allocator &Alloc = LIS.getVNInfoAllocator(); - VNInfo *VNI = NewLI.getNextValue(Idx, Alloc); - NewLI.addSegment(LiveInterval::Segment(Idx, Idx.getDeadSlot(), VNI)); - - if (DestSubReg) { - const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); - auto *SR = NewLI.createSubRange( - Alloc, TRI->getSubRegIndexLaneMask(DestSubReg)); - SR->addSegment(LiveInterval::Segment(Idx, Idx.getDeadSlot(), - SR->getNextValue(Idx, Alloc))); - } - - pop_back(); - DeadRemats->insert(MI); - const TargetRegisterInfo &TRI = *MRI.getTargetRegisterInfo(); - MI->substituteRegister(Dest, NewLI.reg(), 0, TRI); - assert(MI->registerDefIsDead(NewLI.reg(), &TRI)); - } else { - if (TheDelegate) - TheDelegate->LRE_WillEraseInstruction(MI); - LIS.RemoveMachineInstrFromMaps(*MI); - MI->eraseFromParent(); - ++NumDCEDeleted; - } + if (TheDelegate) + TheDelegate->LRE_WillEraseInstruction(MI); + LIS.RemoveMachineInstrFromMaps(*MI); + MI->eraseFromParent(); + ++NumDCEDeleted; } // Erase any virtregs that are now empty and unused. There may be <undef> diff --git a/llvm/lib/CodeGen/SplitKit.cpp b/llvm/lib/CodeGen/SplitKit.cpp index f118ee5..f9ecb2c 100644 --- a/llvm/lib/CodeGen/SplitKit.cpp +++ b/llvm/lib/CodeGen/SplitKit.cpp @@ -376,8 +376,6 @@ void SplitEditor::reset(LiveRangeEdit &LRE, ComplementSpillMode SM) { if (SpillMode) LICalc[1].reset(&VRM.getMachineFunction(), LIS.getSlotIndexes(), &MDT, &LIS.getVNInfoAllocator()); - - Edit->anyRematerializable(); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -638,7 +636,7 @@ VNInfo *SplitEditor::defFromParent(unsigned RegIdx, const VNInfo *ParentVNI, LiveRangeEdit::Remat RM(ParentVNI); RM.OrigMI = LIS.getInstructionFromIndex(OrigVNI->def); if (RM.OrigMI && TII.isAsCheapAsAMove(*RM.OrigMI) && - Edit->canRematerializeAt(RM, OrigVNI, UseIdx)) { + Edit->canRematerializeAt(RM, UseIdx)) { if (!rematWillIncreaseRestriction(RM.OrigMI, MBB, UseIdx)) { SlotIndex Def = Edit->rematerializeAt(MBB, I, Reg, RM, TRI, Late); ++NumRemats; diff --git a/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt b/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt index f159d59..0ffe3ae 100644 --- a/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt +++ b/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt @@ -24,6 +24,7 @@ add_llvm_component_library(LLVMOrcJIT EPCGenericRTDyldMemoryManager.cpp EPCIndirectionUtils.cpp ExecutionUtils.cpp + ExecutorResolutionGenerator.cpp ObjectFileInterface.cpp GetDylibInterface.cpp IndirectionUtils.cpp diff --git a/llvm/lib/ExecutionEngine/Orc/EPCDebugObjectRegistrar.cpp b/llvm/lib/ExecutionEngine/Orc/EPCDebugObjectRegistrar.cpp index 9f7d517..08bef37 100644 --- a/llvm/lib/ExecutionEngine/Orc/EPCDebugObjectRegistrar.cpp +++ b/llvm/lib/ExecutionEngine/Orc/EPCDebugObjectRegistrar.cpp @@ -42,7 +42,12 @@ Expected<std::unique_ptr<EPCDebugObjectRegistrar>> createJITLoaderGDBRegistrar( assert((*Result)[0].size() == 1 && "Unexpected number of addresses in result"); - ExecutorAddr RegisterAddr = (*Result)[0][0].getAddress(); + if (!(*Result)[0][0].has_value()) + return make_error<StringError>( + "Expected a valid address in the lookup result", + inconvertibleErrorCode()); + + ExecutorAddr RegisterAddr = (*Result)[0][0]->getAddress(); return std::make_unique<EPCDebugObjectRegistrar>(ES, RegisterAddr); } diff --git a/llvm/lib/ExecutionEngine/Orc/EPCDynamicLibrarySearchGenerator.cpp b/llvm/lib/ExecutionEngine/Orc/EPCDynamicLibrarySearchGenerator.cpp 
index 59d66b2..1e83c07 100644 --- a/llvm/lib/ExecutionEngine/Orc/EPCDynamicLibrarySearchGenerator.cpp +++ b/llvm/lib/ExecutionEngine/Orc/EPCDynamicLibrarySearchGenerator.cpp @@ -79,12 +79,16 @@ Error EPCDynamicLibrarySearchGenerator::tryToGenerate( assert(Result->front().size() == LookupSymbols.size() && "Result has incorrect number of elements"); + auto SymsIt = Result->front().begin(); + SymbolNameSet MissingSymbols; SymbolMap NewSymbols; - auto ResultI = Result->front().begin(); - for (auto &KV : LookupSymbols) { - if (ResultI->getAddress()) - NewSymbols[KV.first] = *ResultI; - ++ResultI; + for (auto &[Name, Flags] : LookupSymbols) { + const auto &Sym = *SymsIt++; + if (Sym && Sym->getAddress()) + NewSymbols[Name] = *Sym; + else if (LLVM_UNLIKELY(!Sym && + Flags == SymbolLookupFlags::RequiredSymbol)) + MissingSymbols.insert(Name); } LLVM_DEBUG({ @@ -96,6 +100,10 @@ Error EPCDynamicLibrarySearchGenerator::tryToGenerate( if (NewSymbols.empty()) return LS.continueLookup(Error::success()); + if (LLVM_UNLIKELY(!MissingSymbols.empty())) + return LS.continueLookup(make_error<SymbolsNotFound>( + this->EPC.getSymbolStringPool(), std::move(MissingSymbols))); + // Define resolved symbols. Error Err = addAbsolutes(JD, std::move(NewSymbols)); diff --git a/llvm/lib/ExecutionEngine/Orc/EPCGenericDylibManager.cpp b/llvm/lib/ExecutionEngine/Orc/EPCGenericDylibManager.cpp index f98b18c..1f19d17 100644 --- a/llvm/lib/ExecutionEngine/Orc/EPCGenericDylibManager.cpp +++ b/llvm/lib/ExecutionEngine/Orc/EPCGenericDylibManager.cpp @@ -66,7 +66,7 @@ EPCGenericDylibManager::CreateWithDefaultBootstrapSymbols( if (auto Err = EPC.getBootstrapSymbols( {{SAs.Instance, rt::SimpleExecutorDylibManagerInstanceName}, {SAs.Open, rt::SimpleExecutorDylibManagerOpenWrapperName}, - {SAs.Lookup, rt::SimpleExecutorDylibManagerLookupWrapperName}})) + {SAs.Resolve, rt::SimpleExecutorDylibManagerResolveWrapperName}})) return std::move(Err); return EPCGenericDylibManager(EPC, std::move(SAs)); } @@ -84,11 +84,12 @@ Expected<tpctypes::DylibHandle> EPCGenericDylibManager::open(StringRef Path, void EPCGenericDylibManager::lookupAsync(tpctypes::DylibHandle H, const SymbolLookupSet &Lookup, SymbolLookupCompleteFn Complete) { - EPC.callSPSWrapperAsync<rt::SPSSimpleExecutorDylibManagerLookupSignature>( - SAs.Lookup, + EPC.callSPSWrapperAsync<rt::SPSSimpleExecutorDylibManagerResolveSignature>( + SAs.Resolve, [Complete = std::move(Complete)]( Error SerializationErr, - Expected<std::vector<ExecutorSymbolDef>> Result) mutable { + Expected<std::vector<std::optional<ExecutorSymbolDef>>> + Result) mutable { if (SerializationErr) { cantFail(Result.takeError()); Complete(std::move(SerializationErr)); @@ -96,17 +97,18 @@ void EPCGenericDylibManager::lookupAsync(tpctypes::DylibHandle H, } Complete(std::move(Result)); }, - SAs.Instance, H, Lookup); + H, Lookup); } void EPCGenericDylibManager::lookupAsync(tpctypes::DylibHandle H, const RemoteSymbolLookupSet &Lookup, SymbolLookupCompleteFn Complete) { - EPC.callSPSWrapperAsync<rt::SPSSimpleExecutorDylibManagerLookupSignature>( - SAs.Lookup, + EPC.callSPSWrapperAsync<rt::SPSSimpleExecutorDylibManagerResolveSignature>( + SAs.Resolve, [Complete = std::move(Complete)]( Error SerializationErr, - Expected<std::vector<ExecutorSymbolDef>> Result) mutable { + Expected<std::vector<std::optional<ExecutorSymbolDef>>> + Result) mutable { if (SerializationErr) { cantFail(Result.takeError()); Complete(std::move(SerializationErr)); @@ -114,7 +116,7 @@ void EPCGenericDylibManager::lookupAsync(tpctypes::DylibHandle H, } 
Complete(std::move(Result)); }, - SAs.Instance, H, Lookup); + H, Lookup); } } // end namespace orc diff --git a/llvm/lib/ExecutionEngine/Orc/ExecutorResolutionGenerator.cpp b/llvm/lib/ExecutionEngine/Orc/ExecutorResolutionGenerator.cpp new file mode 100644 index 0000000..e5b0bd3 --- /dev/null +++ b/llvm/lib/ExecutionEngine/Orc/ExecutorResolutionGenerator.cpp @@ -0,0 +1,98 @@ +//===---- ExecutorResolutionGenerator.cpp - Executor symbol resolution ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ExecutionEngine/Orc/ExecutorResolutionGenerator.h" + +#include "llvm/ExecutionEngine/Orc/DebugUtils.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h" +#include "llvm/Support/Error.h" + +#define DEBUG_TYPE "orc" + +namespace llvm { +namespace orc { + +Expected<std::unique_ptr<ExecutorResolutionGenerator>> +ExecutorResolutionGenerator::Load(ExecutionSession &ES, const char *LibraryPath, SymbolPredicate Allow, AbsoluteSymbolsFn AbsoluteSymbols) { auto H = ES.getExecutorProcessControl().getDylibMgr().loadDylib(LibraryPath); if (!H) return H.takeError(); return std::make_unique<ExecutorResolutionGenerator>( ES, *H, std::move(Allow), std::move(AbsoluteSymbols)); } + +Error ExecutorResolutionGenerator::tryToGenerate( + LookupState &LS, LookupKind K, JITDylib &JD, + JITDylibLookupFlags JDLookupFlags, const SymbolLookupSet &LookupSet) { + + if (LookupSet.empty()) + return Error::success(); + + LLVM_DEBUG({ + dbgs() << "ExecutorResolutionGenerator trying to generate " << LookupSet + << "\n"; + }); + + SymbolLookupSet LookupSymbols; + for (auto &[Name, LookupFlag] : LookupSet) { + if (Allow && !Allow(Name)) + continue; + LookupSymbols.add(Name, LookupFlag); + } + + DylibManager::LookupRequest LR(H, LookupSymbols); + EPC.getDylibMgr().lookupSymbolsAsync( + LR, [this, LS = std::move(LS), JD = JITDylibSP(&JD), + LookupSymbols](auto Result) mutable { + if (!Result) { + LLVM_DEBUG({ + dbgs() << "ExecutorResolutionGenerator lookup failed due to error"; + }); + return LS.continueLookup(Result.takeError()); + } + assert(Result->size() == 1 && + "Results for more than one library returned"); + assert(Result->front().size() == LookupSymbols.size() && + "Result has incorrect number of elements"); + + // const tpctypes::LookupResult &Syms = Result->front(); + // size_t SymIdx = 0; + auto Syms = Result->front().begin(); + SymbolNameSet MissingSymbols; + SymbolMap NewSyms; + for (auto &[Name, Flags] : LookupSymbols) { + const auto &Sym = *Syms++; + if (Sym && Sym->getAddress()) + NewSyms[Name] = *Sym; + else if (LLVM_UNLIKELY(!Sym && + Flags == SymbolLookupFlags::RequiredSymbol)) + MissingSymbols.insert(Name); + } + + LLVM_DEBUG({ + dbgs() << "ExecutorResolutionGenerator lookup returned " << NewSyms + << "\n"; + }); + + if (NewSyms.empty()) + return LS.continueLookup(Error::success()); + + if (LLVM_UNLIKELY(!MissingSymbols.empty())) + return LS.continueLookup(make_error<SymbolsNotFound>( + this->EPC.getSymbolStringPool(), std::move(MissingSymbols))); + + LS.continueLookup(JD->define(AbsoluteSymbols(std::move(NewSyms)))); + }); + + return Error::success(); +} + +} // end namespace orc +} // end namespace llvm diff --git a/llvm/lib/ExecutionEngine/Orc/LookupAndRecordAddrs.cpp 
b/llvm/lib/ExecutionEngine/Orc/LookupAndRecordAddrs.cpp index 78169a2..42d630d 100644 --- a/llvm/lib/ExecutionEngine/Orc/LookupAndRecordAddrs.cpp +++ b/llvm/lib/ExecutionEngine/Orc/LookupAndRecordAddrs.cpp @@ -72,9 +72,10 @@ Error lookupAndRecordAddrs( return make_error<StringError>("Error in lookup result elements", inconvertibleErrorCode()); - for (unsigned I = 0; I != Pairs.size(); ++I) - *Pairs[I].second = Result->front()[I].getAddress(); - + for (unsigned I = 0; I != Pairs.size(); ++I) { + if (Result->front()[I]) + *Pairs[I].second = Result->front()[I]->getAddress(); + } return Error::success(); } diff --git a/llvm/lib/ExecutionEngine/Orc/SelfExecutorProcessControl.cpp b/llvm/lib/ExecutionEngine/Orc/SelfExecutorProcessControl.cpp index 78045f1..f8a2bd3 100644 --- a/llvm/lib/ExecutionEngine/Orc/SelfExecutorProcessControl.cpp +++ b/llvm/lib/ExecutionEngine/Orc/SelfExecutorProcessControl.cpp @@ -91,22 +91,18 @@ void SelfExecutorProcessControl::lookupSymbolsAsync( for (auto &Elem : Request) { sys::DynamicLibrary Dylib(Elem.Handle.toPtr<void *>()); - R.push_back(std::vector<ExecutorSymbolDef>()); + R.push_back(tpctypes::LookupResult()); for (auto &KV : Elem.Symbols) { auto &Sym = KV.first; std::string Tmp((*Sym).data() + !!GlobalManglingPrefix, (*Sym).size() - !!GlobalManglingPrefix); void *Addr = Dylib.getAddressOfSymbol(Tmp.c_str()); - if (!Addr && KV.second == SymbolLookupFlags::RequiredSymbol) { - // FIXME: Collect all failing symbols before erroring out. - SymbolNameVector MissingSymbols; - MissingSymbols.push_back(Sym); - return Complete( - make_error<SymbolsNotFound>(SSP, std::move(MissingSymbols))); - } - // FIXME: determine accurate JITSymbolFlags. - R.back().push_back( - {ExecutorAddr::fromPtr(Addr), JITSymbolFlags::Exported}); + if (!Addr && KV.second == SymbolLookupFlags::RequiredSymbol) + R.back().emplace_back(); + else + // FIXME: determine accurate JITSymbolFlags. 
+ R.back().emplace_back(ExecutorSymbolDef(ExecutorAddr::fromPtr(Addr), + JITSymbolFlags::Exported)); } } diff --git a/llvm/lib/ExecutionEngine/Orc/Shared/OrcRTBridge.cpp b/llvm/lib/ExecutionEngine/Orc/Shared/OrcRTBridge.cpp index 123651f..26e8f53 100644 --- a/llvm/lib/ExecutionEngine/Orc/Shared/OrcRTBridge.cpp +++ b/llvm/lib/ExecutionEngine/Orc/Shared/OrcRTBridge.cpp @@ -16,8 +16,8 @@ const char *SimpleExecutorDylibManagerInstanceName = "__llvm_orc_SimpleExecutorDylibManager_Instance"; const char *SimpleExecutorDylibManagerOpenWrapperName = "__llvm_orc_SimpleExecutorDylibManager_open_wrapper"; -const char *SimpleExecutorDylibManagerLookupWrapperName = - "__llvm_orc_SimpleExecutorDylibManager_lookup_wrapper"; +const char *SimpleExecutorDylibManagerResolveWrapperName = - "__llvm_orc_SimpleExecutorDylibManager_resolve_wrapper"; const char *SimpleExecutorMemoryManagerInstanceName = "__llvm_orc_SimpleExecutorMemoryManager_Instance"; diff --git a/llvm/lib/ExecutionEngine/Orc/TargetProcess/CMakeLists.txt b/llvm/lib/ExecutionEngine/Orc/TargetProcess/CMakeLists.txt index 9f3abac..9275586 100644 --- a/llvm/lib/ExecutionEngine/Orc/TargetProcess/CMakeLists.txt +++ b/llvm/lib/ExecutionEngine/Orc/TargetProcess/CMakeLists.txt @@ -15,6 +15,7 @@ endif() add_llvm_component_library(LLVMOrcTargetProcess ExecutorSharedMemoryMapperService.cpp DefaultHostBootstrapValues.cpp + ExecutorResolver.cpp JITLoaderGDB.cpp JITLoaderPerf.cpp JITLoaderVTune.cpp diff --git a/llvm/lib/ExecutionEngine/Orc/TargetProcess/ExecutorResolver.cpp b/llvm/lib/ExecutionEngine/Orc/TargetProcess/ExecutorResolver.cpp new file mode 100644 index 0000000..6054d86 --- /dev/null +++ b/llvm/lib/ExecutionEngine/Orc/TargetProcess/ExecutorResolver.cpp @@ -0,0 +1,47 @@ + +#include "llvm/ExecutionEngine/Orc/TargetProcess/ExecutorResolver.h" + +#include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/Error.h" + +namespace llvm::orc { + +void DylibSymbolResolver::resolveAsync( + const RemoteSymbolLookupSet &L, + ExecutorResolver::YieldResolveResultFn &&OnResolve) { + std::vector<std::optional<ExecutorSymbolDef>> Result; + auto DL = sys::DynamicLibrary(Handle.toPtr<void *>()); + + for (const auto &E : L) { + if (E.Name.empty()) { + if (E.Required) { + OnResolve( + make_error<StringError>("Required address for empty symbol \"\"", + inconvertibleErrorCode())); + return; + } + Result.emplace_back(); + } else { + + const char *DemangledSymName = E.Name.c_str(); +#ifdef __APPLE__ + if (E.Name.front() != '_') { + OnResolve(make_error<StringError>(Twine("MachO symbol \"") + E.Name + + "\" missing leading '_'", + inconvertibleErrorCode())); + return; + } + ++DemangledSymName; +#endif + + void *Addr = DL.getAddressOfSymbol(DemangledSymName); + if (!Addr && E.Required) + Result.emplace_back(); + else + // FIXME: determine accurate JITSymbolFlags. + Result.emplace_back(ExecutorSymbolDef(ExecutorAddr::fromPtr(Addr), + JITSymbolFlags::Exported)); + } + } + + OnResolve(std::move(Result)); +} + +} // end namespace llvm::orc
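Executor-side resolvers now report each symbol as a std::optional<ExecutorSymbolDef>, with an empty optional meaning a required symbol was not found; this lets the client-side generators above collect every missing name into one SymbolsNotFound error instead of failing on the first. A minimal standalone consumer of that result shape, assuming an LLVM tree for the ORC headers:

#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h"
#include <cstdio>
#include <optional>
#include <vector>

using namespace llvm;
using namespace llvm::orc;

int main() {
  std::vector<std::optional<ExecutorSymbolDef>> Result;
  int Dummy = 0; // stand-in for a real symbol in the executor
  Result.emplace_back(ExecutorSymbolDef(ExecutorAddr::fromPtr(&Dummy),
                                        JITSymbolFlags::Exported));
  Result.emplace_back(); // empty slot: a required symbol was not found

  for (const auto &Sym : Result) {
    if (Sym && Sym->getAddress())
      std::printf("resolved at 0x%llx\n",
                  (unsigned long long)Sym->getAddress().getValue());
    else
      std::printf("missing\n");
  }
  return 0;
}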
\ No newline at end of file diff --git a/llvm/lib/ExecutionEngine/Orc/TargetProcess/SimpleExecutorDylibManager.cpp b/llvm/lib/ExecutionEngine/Orc/TargetProcess/SimpleExecutorDylibManager.cpp index db6f201..52bb55d 100644 --- a/llvm/lib/ExecutionEngine/Orc/TargetProcess/SimpleExecutorDylibManager.cpp +++ b/llvm/lib/ExecutionEngine/Orc/TargetProcess/SimpleExecutorDylibManager.cpp @@ -10,6 +10,10 @@ #include "llvm/ExecutionEngine/Orc/Shared/OrcRTBridge.h" +#include "llvm/Support/MSVCErrorWorkarounds.h" + +#include <future> + #define DEBUG_TYPE "orc" namespace llvm { @@ -35,46 +39,9 @@ SimpleExecutorDylibManager::open(const std::string &Path, uint64_t Mode) { std::lock_guard<std::mutex> Lock(M); auto H = ExecutorAddr::fromPtr(DL.getOSSpecificHandle()); + Resolvers.push_back(std::make_unique<DylibSymbolResolver>(H)); Dylibs.insert(DL.getOSSpecificHandle()); - return H; -} - -Expected<std::vector<ExecutorSymbolDef>> -SimpleExecutorDylibManager::lookup(tpctypes::DylibHandle H, - const RemoteSymbolLookupSet &L) { - std::vector<ExecutorSymbolDef> Result; - auto DL = sys::DynamicLibrary(H.toPtr<void *>()); - - for (const auto &E : L) { - if (E.Name.empty()) { - if (E.Required) - return make_error<StringError>("Required address for empty symbol \"\"", - inconvertibleErrorCode()); - else - Result.push_back(ExecutorSymbolDef()); - } else { - - const char *DemangledSymName = E.Name.c_str(); -#ifdef __APPLE__ - if (E.Name.front() != '_') - return make_error<StringError>(Twine("MachO symbol \"") + E.Name + - "\" missing leading '_'", - inconvertibleErrorCode()); - ++DemangledSymName; -#endif - - void *Addr = DL.getAddressOfSymbol(DemangledSymName); - if (!Addr && E.Required) - return make_error<StringError>(Twine("Missing definition for ") + - DemangledSymName, - inconvertibleErrorCode()); - - // FIXME: determine accurate JITSymbolFlags. 
- Result.push_back({ExecutorAddr::fromPtr(Addr), JITSymbolFlags::Exported}); - } - } - - return Result; + return ExecutorAddr::fromPtr(Resolvers.back().get()); } Error SimpleExecutorDylibManager::shutdown() { @@ -94,8 +61,8 @@ void SimpleExecutorDylibManager::addBootstrapSymbols( M[rt::SimpleExecutorDylibManagerInstanceName] = ExecutorAddr::fromPtr(this); M[rt::SimpleExecutorDylibManagerOpenWrapperName] = ExecutorAddr::fromPtr(&openWrapper); - M[rt::SimpleExecutorDylibManagerLookupWrapperName] = - ExecutorAddr::fromPtr(&lookupWrapper); + M[rt::SimpleExecutorDylibManagerResolveWrapperName] = + ExecutorAddr::fromPtr(&resolveWrapper); } llvm::orc::shared::CWrapperFunctionResult @@ -109,12 +76,22 @@ SimpleExecutorDylibManager::openWrapper(const char *ArgData, size_t ArgSize) { } llvm::orc::shared::CWrapperFunctionResult -SimpleExecutorDylibManager::lookupWrapper(const char *ArgData, size_t ArgSize) { - return shared:: - WrapperFunction<rt::SPSSimpleExecutorDylibManagerLookupSignature>::handle( - ArgData, ArgSize, - shared::makeMethodWrapperHandler( - &SimpleExecutorDylibManager::lookup)) +SimpleExecutorDylibManager::resolveWrapper(const char *ArgData, + size_t ArgSize) { + using ResolveResult = ExecutorResolver::ResolveResult; + return shared::WrapperFunction< + rt::SPSSimpleExecutorDylibManagerResolveSignature>:: + handle(ArgData, ArgSize, + [](ExecutorAddr Obj, RemoteSymbolLookupSet L) -> ResolveResult { + using TmpResult = + MSVCPExpected<std::vector<std::optional<ExecutorSymbolDef>>>; + std::promise<TmpResult> P; + auto F = P.get_future(); + Obj.toPtr<ExecutorResolver *>()->resolveAsync( + std::move(L), + [&](TmpResult R) { P.set_value(std::move(R)); }); + return F.get(); + }) .release(); } diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp index 6b202ba..3842b1a 100644 --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -55,15 +55,8 @@ foldConstantCastPair( Type *MidTy = Op->getType(); Instruction::CastOps firstOp = Instruction::CastOps(Op->getOpcode()); Instruction::CastOps secondOp = Instruction::CastOps(opc); - - // Assume that pointers are never more than 64 bits wide, and only use this - // for the middle type. Otherwise we could end up folding away illegal - // bitcasts between address spaces with different sizes. - IntegerType *FakeIntPtrTy = Type::getInt64Ty(DstTy->getContext()); - - // Let CastInst::isEliminableCastPair do the heavy lifting. 
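foldConstantCastPair now passes a null DataLayout instead of the old fake i64 intptr type, conservatively rejecting the ptrtoint/inttoptr folds when no layout is available, while callers that do have one (such as simplifyCastInst above) simply pass its address. A hedged usage sketch against the patched signature, assuming a tree that contains this change:

#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include <cstdio>

using namespace llvm;

int main() {
  LLVMContext Ctx;
  DataLayout DL("p:64:64"); // assumption: 64-bit pointers in AS0
  Type *PtrTy = PointerType::get(Ctx, 0);
  Type *I64 = Type::getInt64Ty(Ctx);

  // ptrtoint ptr -> i64, then inttoptr i64 -> ptr: with a pointer-sized
  // intermediate integer the pair folds to a no-op (bitcast).
  unsigned WithDL = CastInst::isEliminableCastPair(
      Instruction::PtrToInt, Instruction::IntToPtr, PtrTy, I64, PtrTy, &DL);
  std::printf("%d\n", WithDL == Instruction::BitCast); // 1

  // Without a DataLayout the same query is conservatively rejected.
  unsigned NoDL = CastInst::isEliminableCastPair(
      Instruction::PtrToInt, Instruction::IntToPtr, PtrTy, I64, PtrTy,
      /*DL=*/nullptr);
  std::printf("%d\n", NoDL); // 0
  return 0;
}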
return CastInst::isEliminableCastPair(firstOp, secondOp, SrcTy, MidTy, DstTy, - nullptr, FakeIntPtrTy, nullptr); + /*DL=*/nullptr); } static Constant *FoldBitCast(Constant *V, Type *DestTy) { diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp index df0c85b..3f1cc1e 100644 --- a/llvm/lib/IR/Core.cpp +++ b/llvm/lib/IR/Core.cpp @@ -2403,6 +2403,14 @@ LLVMValueRef LLVMAddFunction(LLVMModuleRef M, const char *Name, GlobalValue::ExternalLinkage, Name, unwrap(M))); } +LLVMValueRef LLVMGetOrInsertFunction(LLVMModuleRef M, const char *Name, + size_t NameLen, LLVMTypeRef FunctionTy) { + return wrap(unwrap(M) + ->getOrInsertFunction(StringRef(Name, NameLen), + unwrap<FunctionType>(FunctionTy)) + .getCallee()); +} + LLVMValueRef LLVMGetNamedFunction(LLVMModuleRef M, const char *Name) { return wrap(unwrap(M)->getFunction(Name)); } diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index 941e41f..88e7c44 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -2824,10 +2824,10 @@ bool CastInst::isNoopCast(const DataLayout &DL) const { /// The function returns a resultOpcode so these two casts can be replaced with: /// * %Replacement = resultOpcode %SrcTy %x to DstTy /// If no such cast is permitted, the function returns 0. -unsigned CastInst::isEliminableCastPair( - Instruction::CastOps firstOp, Instruction::CastOps secondOp, - Type *SrcTy, Type *MidTy, Type *DstTy, Type *SrcIntPtrTy, Type *MidIntPtrTy, - Type *DstIntPtrTy) { +unsigned CastInst::isEliminableCastPair(Instruction::CastOps firstOp, + Instruction::CastOps secondOp, + Type *SrcTy, Type *MidTy, Type *DstTy, + const DataLayout *DL) { // Define the 144 possibilities for these two cast instructions. The values // in this matrix determine what to do in a given situation and select the // case in the switch below. The rows correspond to firstOp, the columns @@ -2936,24 +2936,16 @@ unsigned CastInst::isEliminableCastPair( return 0; // Cannot simplify if address spaces are different! - if (SrcTy->getPointerAddressSpace() != DstTy->getPointerAddressSpace()) + if (SrcTy != DstTy) return 0; - unsigned MidSize = MidTy->getScalarSizeInBits(); - // We can still fold this without knowing the actual sizes as long we - // know that the intermediate pointer is the largest possible + // Cannot simplify if the intermediate integer size is smaller than the // pointer size. - // FIXME: Is this always true? - if (MidSize == 64) - return Instruction::BitCast; - - // ptrtoint, inttoptr -> bitcast (ptr -> ptr) if int size is >= ptr size. - if (!SrcIntPtrTy || DstIntPtrTy != SrcIntPtrTy) + unsigned MidSize = MidTy->getScalarSizeInBits(); + if (!DL || MidSize < DL->getPointerTypeSizeInBits(SrcTy)) return 0; - unsigned PtrSize = SrcIntPtrTy->getScalarSizeInBits(); - if (MidSize >= PtrSize) - return Instruction::BitCast; - return 0; + + return Instruction::BitCast; } case 8: { // ext, trunc -> bitcast, if the SrcTy and DstTy are the same @@ -2973,14 +2965,17 @@ unsigned CastInst::isEliminableCastPair( // zext, sext -> zext, because sext can't sign extend after zext return Instruction::ZExt; case 11: { - // inttoptr, ptrtoint/ptrtoaddr -> bitcast if SrcSize<=PtrSize and - // SrcSize==DstSize - if (!MidIntPtrTy) + // inttoptr, ptrtoint/ptrtoaddr -> bitcast if SrcSize<=PtrSize/AddrSize + // and SrcSize==DstSize + if (!DL) return 0; - unsigned PtrSize = MidIntPtrTy->getScalarSizeInBits(); + unsigned MidSize = secondOp == Instruction::PtrToAddr + ? 
DL->getAddressSizeInBits(MidTy) + : DL->getPointerTypeSizeInBits(MidTy); unsigned SrcSize = SrcTy->getScalarSizeInBits(); unsigned DstSize = DstTy->getScalarSizeInBits(); - if (SrcSize <= PtrSize && SrcSize == DstSize) + // TODO: Could also produce zext or trunc here. + if (SrcSize <= MidSize && SrcSize == DstSize) return Instruction::BitCast; return 0; } diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index 6b3cd27..71a8a38 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -543,6 +543,7 @@ private: void visitAliasScopeListMetadata(const MDNode *MD); void visitAccessGroupMetadata(const MDNode *MD); void visitCapturesMetadata(Instruction &I, const MDNode *Captures); + void visitAllocTokenMetadata(Instruction &I, MDNode *MD); template <class Ty> bool isValidMetadataArray(const MDTuple &N); #define HANDLE_SPECIALIZED_MDNODE_LEAF(CLASS) void visit##CLASS(const CLASS &N); @@ -5395,6 +5396,12 @@ void Verifier::visitCapturesMetadata(Instruction &I, const MDNode *Captures) { } } +void Verifier::visitAllocTokenMetadata(Instruction &I, MDNode *MD) { + Check(isa<CallBase>(I), "!alloc_token should only exist on calls", &I); + Check(MD->getNumOperands() == 1, "!alloc_token must have 1 operand", MD); + Check(isa<MDString>(MD->getOperand(0)), "expected string", MD); +} + /// verifyInstruction - Verify that an instruction is well formed. /// void Verifier::visitInstruction(Instruction &I) { @@ -5625,6 +5632,9 @@ void Verifier::visitInstruction(Instruction &I) { if (MDNode *Captures = I.getMetadata(LLVMContext::MD_captures)) visitCapturesMetadata(I, Captures); + if (MDNode *MD = I.getMetadata(LLVMContext::MD_alloc_token)) + visitAllocTokenMetadata(I, MD); + if (MDNode *N = I.getDebugLoc().getAsMDNode()) { CheckDI(isa<DILocation>(N), "invalid !dbg metadata attachment", &I, N); visitMDNode(*N, AreDebugLocsAllowed::Yes); diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp index c234623..20dcde8 100644 --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -240,6 +240,7 @@ #include "llvm/Transforms/IPO/WholeProgramDevirt.h" #include "llvm/Transforms/InstCombine/InstCombine.h" #include "llvm/Transforms/Instrumentation/AddressSanitizer.h" +#include "llvm/Transforms/Instrumentation/AllocToken.h" #include "llvm/Transforms/Instrumentation/BoundsChecking.h" #include "llvm/Transforms/Instrumentation/CGProfile.h" #include "llvm/Transforms/Instrumentation/ControlHeightReduction.h" diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def index 88550ea..c5c0d64 100644 --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -125,6 +125,7 @@ MODULE_PASS("openmp-opt", OpenMPOptPass()) MODULE_PASS("openmp-opt-postlink", OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink)) MODULE_PASS("partial-inliner", PartialInlinerPass()) +MODULE_PASS("alloc-token", AllocTokenPass()) MODULE_PASS("pgo-icall-prom", PGOIndirectCallPromotion()) MODULE_PASS("pgo-instr-gen", PGOInstrumentationGen()) MODULE_PASS("pgo-instr-use", PGOInstrumentationUse()) diff --git a/llvm/lib/Support/SpecialCaseList.cpp b/llvm/lib/Support/SpecialCaseList.cpp index 8d4e043..4b03885 100644 --- a/llvm/lib/Support/SpecialCaseList.cpp +++ b/llvm/lib/Support/SpecialCaseList.cpp @@ -135,7 +135,7 @@ SpecialCaseList::addSection(StringRef SectionStr, unsigned FileNo, Sections.emplace_back(SectionStr, FileNo); auto &Section = Sections.back(); - if (auto Err = Section.SectionMatcher->insert(SectionStr, LineNo, UseGlobs)) { + if (auto Err = 
Section.SectionMatcher.insert(SectionStr, LineNo, UseGlobs)) { return createStringError(errc::invalid_argument, "malformed section at line " + Twine(LineNo) + ": '" + SectionStr + @@ -218,7 +218,7 @@ std::pair<unsigned, unsigned> SpecialCaseList::inSectionBlame(StringRef Section, StringRef Prefix, StringRef Query, StringRef Category) const { for (const auto &S : reverse(Sections)) { - if (S.SectionMatcher->match(Section)) { + if (S.SectionMatcher.match(Section)) { unsigned Blame = inSectionBlame(S.Entries, Prefix, Query, Category); if (Blame) return {S.FileIdx, Blame}; diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td index 1a697f7..502a8e8 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPU.td +++ b/llvm/lib/Target/AMDGPU/AMDGPU.td @@ -2592,6 +2592,9 @@ def UseFakeTrue16Insts : True16PredicateClass<"Subtarget->hasTrue16BitInsts() && // FIXME When we default to RealTrue16 instead of Fake, change the line as follows. // AssemblerPredicate<(all_of FeatureTrue16BitInsts, (not FeatureRealTrue16Insts))>; +def UseTrue16WithSramECC : True16PredicateClass<"Subtarget->useRealTrue16Insts() && " + "!Subtarget->d16PreservesUnusedBits()">; + def HasD16Writes32BitVgpr: Predicate<"Subtarget->hasD16Writes32BitVgpr()">, AssemblerPredicate<(all_of FeatureTrue16BitInsts, FeatureRealTrue16Insts, FeatureD16Writes32BitVgpr)>; def NotHasD16Writes32BitVgpr: Predicate<"!Subtarget->hasD16Writes32BitVgpr()">, diff --git a/llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp b/llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp index cb49936..65d049e 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUAttributor.cpp @@ -1504,7 +1504,6 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM, A.getOrCreateAAFor<AAAMDAttributes>(IRPosition::function(*F)); A.getOrCreateAAFor<AAUniformWorkGroupSize>(IRPosition::function(*F)); A.getOrCreateAAFor<AAAMDMaxNumWorkgroups>(IRPosition::function(*F)); - A.getOrCreateAAFor<AAAMDGPUNoAGPR>(IRPosition::function(*F)); CallingConv::ID CC = F->getCallingConv(); if (!AMDGPU::isEntryFunctionCC(CC)) { A.getOrCreateAAFor<AAAMDFlatWorkGroupSize>(IRPosition::function(*F)); @@ -1515,6 +1514,9 @@ static bool runImpl(Module &M, AnalysisGetter &AG, TargetMachine &TM, if (!F->isDeclaration() && ST.hasClusters()) A.getOrCreateAAFor<AAAMDGPUClusterDims>(IRPosition::function(*F)); + if (ST.hasGFX90AInsts()) + A.getOrCreateAAFor<AAAMDGPUNoAGPR>(IRPosition::function(*F)); + for (auto &I : instructions(F)) { Value *Ptr = nullptr; if (auto *LI = dyn_cast<LoadInst>(&I)) diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp index 280fbe2..723d07e 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp @@ -929,8 +929,10 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) { ThinOrFullLTOPhase Phase) { if (Level != OptimizationLevel::O0) { if (!isLTOPreLink(Phase)) { - AMDGPUAttributorOptions Opts; - MPM.addPass(AMDGPUAttributorPass(*this, Opts, Phase)); + if (getTargetTriple().isAMDGCN()) { + AMDGPUAttributorOptions Opts; + MPM.addPass(AMDGPUAttributorPass(*this, Opts, Phase)); + } } } }); @@ -1296,7 +1298,8 @@ void AMDGPUPassConfig::addIRPasses() { if (LowerCtorDtor) addPass(createAMDGPUCtorDtorLoweringLegacyPass()); - if (isPassEnabled(EnableImageIntrinsicOptimizer)) + if (TM.getTargetTriple().isAMDGCN() && + isPassEnabled(EnableImageIntrinsicOptimizer)) addPass(createAMDGPUImageIntrinsicOptimizerPass(&TM)); // 
This can be disabled by passing ::Disable here or on the command line diff --git a/llvm/lib/Target/AMDGPU/DSInstructions.td b/llvm/lib/Target/AMDGPU/DSInstructions.td index b2ff5a1..18582ed 100644 --- a/llvm/lib/Target/AMDGPU/DSInstructions.td +++ b/llvm/lib/Target/AMDGPU/DSInstructions.td @@ -951,6 +951,11 @@ class DSReadPat <DS_Pseudo inst, ValueType vt, PatFrag frag, int gds=0> : GCNPat (inst $ptr, Offset:$offset, (i1 gds)) >; +class DSReadPat_t16 <DS_Pseudo inst, ValueType vt, PatFrag frag, int gds=0> : GCNPat < + (vt (frag (DS1Addr1Offset i32:$ptr, i32:$offset))), + (EXTRACT_SUBREG (inst $ptr, Offset:$offset, (i1 gds)), lo16) +>; + multiclass DSReadPat_mc<DS_Pseudo inst, ValueType vt, string frag> { let OtherPredicates = [LDSRequiresM0Init] in { @@ -968,13 +973,14 @@ multiclass DSReadPat_t16<DS_Pseudo inst, ValueType vt, string frag> { def : DSReadPat<inst, vt, !cast<PatFrag>(frag#"_m0")>; } - let OtherPredicates = [NotLDSRequiresM0Init] in { - let True16Predicate = NotUseRealTrue16Insts in { - def : DSReadPat<!cast<DS_Pseudo>(!cast<string>(inst)#"_gfx9"), vt, !cast<PatFrag>(frag)>; - } - let True16Predicate = UseRealTrue16Insts in { - def : DSReadPat<!cast<DS_Pseudo>(!cast<string>(inst)#"_t16"), vt, !cast<PatFrag>(frag)>; - } + let OtherPredicates = [NotLDSRequiresM0Init], True16Predicate = NotUseRealTrue16Insts in { + def : DSReadPat<!cast<DS_Pseudo>(!cast<string>(inst)#"_gfx9"), vt, !cast<PatFrag>(frag)>; + } + let OtherPredicates = [NotLDSRequiresM0Init, D16PreservesUnusedBits], True16Predicate = UseRealTrue16Insts in { + def : DSReadPat<!cast<DS_Pseudo>(!cast<string>(inst)#"_t16"), vt, !cast<PatFrag>(frag)>; + } + let OtherPredicates = [NotLDSRequiresM0Init], True16Predicate = UseTrue16WithSramECC in { + def : DSReadPat_t16<!cast<DS_Pseudo>(!cast<string>(inst)#"_gfx9"), vt, !cast<PatFrag>(frag)>; } } diff --git a/llvm/lib/Target/AMDGPU/FLATInstructions.td b/llvm/lib/Target/AMDGPU/FLATInstructions.td index 5a22b23..e86816d 100644 --- a/llvm/lib/Target/AMDGPU/FLATInstructions.td +++ b/llvm/lib/Target/AMDGPU/FLATInstructions.td @@ -1383,6 +1383,11 @@ class FlatLoadPat_D16_t16 <FLAT_Pseudo inst, SDPatternOperator node, ValueType v (inst $vaddr, $offset, (i32 0)) >; +class FlatLoadPat_t16 <FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> : GCNPat < + (vt (node (FlatOffset i64:$vaddr, i32:$offset))), + (EXTRACT_SUBREG (inst $vaddr, $offset), lo16) +>; + class FlatSignedLoadPat_D16 <FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> : GCNPat < (node (GlobalOffset (i64 VReg_64:$vaddr), i32:$offset), vt:$in), (inst $vaddr, $offset, 0, $in) @@ -1393,6 +1398,11 @@ class FlatSignedLoadPat_D16_t16 <FLAT_Pseudo inst, SDPatternOperator node, Value (inst $vaddr, $offset, (i32 0)) >; +class FlatSignedLoadPat_t16 <FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> : GCNPat < + (vt (node (GlobalOffset (i64 VReg_64:$vaddr), i32:$offset))), + (EXTRACT_SUBREG (inst $vaddr, $offset, (i32 0)), lo16) +>; + class GlobalLoadSaddrPat_D16 <FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> : GCNPat < (vt (node (GlobalSAddr (i64 SReg_64:$saddr), (i32 VGPR_32:$voffset), i32:$offset, CPol:$cpol), vt:$in)), (inst $saddr, $voffset, $offset, $cpol, $in) @@ -1408,6 +1418,11 @@ class FlatLoadSaddrPat_D16_t16 <FLAT_Pseudo inst, SDPatternOperator node, ValueT (inst $saddr, $voffset, $offset, $cpol) >; +class FlatLoadSaddrPat_t16 <FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> : GCNPat < + (vt (node (GlobalSAddr (i64 SReg_64:$saddr), (i32 VGPR_32:$voffset), i32:$offset, CPol:$cpol))), + 
(EXTRACT_SUBREG (inst $saddr, $voffset, $offset, $cpol), lo16) +>; + class FlatLoadLDSSignedPat_M0 <FLAT_Pseudo inst, SDPatternOperator node> : GCNPat < (node (i64 VReg_64:$vaddr), (i32 VGPR_32:$dsaddr), (i32 timm:$offset), (i32 timm:$cpol), M0), (inst $dsaddr, $vaddr, $offset, $cpol) @@ -1443,6 +1458,11 @@ class GlobalLoadSaddrPat_D16_t16 <FLAT_Pseudo inst, SDPatternOperator node, Valu (inst $saddr, $voffset, $offset, $cpol) >; +class GlobalLoadSaddrPat_t16 <FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> : GCNPat < + (vt (node (GlobalSAddr (i64 SReg_64:$saddr), (i32 VGPR_32:$voffset), i32:$offset, CPol:$cpol))), + (EXTRACT_SUBREG (inst $saddr, $voffset, $offset, $cpol), lo16) +>; + class FlatLoadSignedPat <FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> : GCNPat < (vt (node (GlobalOffset (i64 VReg_64:$vaddr), i32:$offset))), (inst $vaddr, $offset) @@ -1625,6 +1645,11 @@ class ScratchLoadSignedPat_D16_t16 <FLAT_Pseudo inst, SDPatternOperator node, Va (inst $vaddr, $offset, 0) >; +class ScratchLoadSignedPat_t16 <FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> : GCNPat < + (vt (node (ScratchOffset (i32 VGPR_32:$vaddr), i32:$offset))), + (EXTRACT_SUBREG (inst $vaddr, $offset), lo16) +>; + class ScratchStoreSignedPat <FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> : GCNPat < (node vt:$data, (ScratchOffset (i32 VGPR_32:$vaddr), i32:$offset)), (inst getVregSrcForVT<vt>.ret:$data, $vaddr, $offset) @@ -1645,6 +1670,11 @@ class ScratchLoadSaddrPat_D16_t16 <FLAT_Pseudo inst, SDPatternOperator node, Val (inst $saddr, $offset, 0) >; +class ScratchLoadSaddrPat_t16 <FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> : GCNPat < + (vt (node (ScratchSAddr (i32 SGPR_32:$saddr), i32:$offset))), + (EXTRACT_SUBREG (inst $saddr, $offset), lo16) +>; + class ScratchStoreSaddrPat <FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> : GCNPat < (node vt:$data, (ScratchSAddr (i32 SGPR_32:$saddr), i32:$offset)), @@ -1672,6 +1702,11 @@ class ScratchLoadSVaddrPat_D16_t16 <FLAT_Pseudo inst, SDPatternOperator node, Va (inst $vaddr, $saddr, $offset, $cpol) >; +class ScratchLoadSVaddrPat_t16 <FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> : GCNPat < + (vt (node (ScratchSVAddr (i32 VGPR_32:$vaddr), (i32 SGPR_32:$saddr), i32:$offset, CPol:$cpol))), + (EXTRACT_SUBREG (inst $vaddr, $saddr, $offset, $cpol), lo16) +>; + multiclass GlobalLoadLDSPats_M0<FLAT_Pseudo inst, SDPatternOperator node> { def : FlatLoadLDSSignedPat_M0 <inst, node> { let AddedComplexity = 10; @@ -1764,6 +1799,16 @@ multiclass GlobalFLATLoadPats_D16_t16<string inst, SDPatternOperator node, Value } } +multiclass GlobalFLATLoadPats_t16<FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> { + def : FlatSignedLoadPat_t16<inst, node, vt> { + let AddedComplexity = 10; + } + + def : GlobalLoadSaddrPat_t16<!cast<FLAT_Pseudo>(!cast<string>(inst)#"_SADDR"), node, vt> { + let AddedComplexity = 11; + } +} + multiclass GlobalFLATStorePats<FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> { def : FlatStoreSignedPat <inst, node, vt> { @@ -1872,8 +1917,8 @@ multiclass ScratchFLATStorePats<FLAT_Pseudo inst, SDPatternOperator node, } } -multiclass ScratchFLATStorePats_t16<string inst, SDPatternOperator node, - ValueType vt> { +multiclass ScratchFLATStorePats_D16_t16<string inst, SDPatternOperator node, + ValueType vt> { def : ScratchStoreSignedPat <!cast<FLAT_Pseudo>(inst#"_t16"), node, vt> { let AddedComplexity = 25; } @@ -1918,6 +1963,21 @@ multiclass ScratchFLATLoadPats_D16_t16<string inst, SDPatternOperator node, 
Valu } } +multiclass ScratchFLATLoadPats_t16<FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> { + def : ScratchLoadSignedPat_t16 <inst, node, vt> { + let AddedComplexity = 25; + } + + def : ScratchLoadSaddrPat_t16<!cast<FLAT_Pseudo>(!cast<string>(inst)#"_SADDR"), node, vt> { + let AddedComplexity = 26; + } + + def : ScratchLoadSVaddrPat_t16<!cast<FLAT_Pseudo>(!cast<string>(inst)#"_SVS"), node, vt> { + let SubtargetPredicate = HasFlatScratchSVSMode; + let AddedComplexity = 27; + } +} + multiclass FlatLoadPats<FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> { def : FlatLoadPat <inst, node, vt> { let OtherPredicates = [HasFlatAddressSpace]; @@ -1947,6 +2007,17 @@ multiclass FlatLoadPats_D16_t16<FLAT_Pseudo inst, SDPatternOperator node, ValueT } } +multiclass FlatLoadPats_t16<FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> { + def : FlatLoadPat_t16 <inst, node, vt> { + let OtherPredicates = [HasFlatAddressSpace]; + } + + def : FlatLoadSaddrPat_t16<!cast<FLAT_Pseudo>(!cast<string>(inst)#"_SADDR"), node, vt> { + let AddedComplexity = 9; + let SubtargetPredicate = HasFlatGVSMode; + } +} + multiclass FlatStorePats<FLAT_Pseudo inst, SDPatternOperator node, ValueType vt> { def : FlatStorePat <inst, node, vt> { let OtherPredicates = [HasFlatAddressSpace]; @@ -1997,6 +2068,17 @@ let True16Predicate = NotUseRealTrue16Insts in { defm : FlatStorePats <FLAT_STORE_SHORT, atomic_store_16_flat, i16>; } +let True16Predicate = UseTrue16WithSramECC in { + defm : FlatLoadPats_t16 <FLAT_LOAD_UBYTE, extloadi8_flat, i16>; + defm : FlatLoadPats_t16 <FLAT_LOAD_UBYTE, zextloadi8_flat, i16>; + defm : FlatLoadPats_t16 <FLAT_LOAD_SBYTE, sextloadi8_flat, i16>; + defm : FlatLoadPats_t16 <FLAT_LOAD_USHORT, load_flat, i16>; + defm : FlatLoadPats_t16 <FLAT_LOAD_UBYTE, atomic_load_aext_8_flat, i16>; + defm : FlatLoadPats_t16 <FLAT_LOAD_UBYTE, atomic_load_zext_8_flat, i16>; + defm : FlatLoadPats_t16 <FLAT_LOAD_USHORT, atomic_load_nonext_16_flat, i16>; + defm : FlatLoadPats_t16 <FLAT_LOAD_SBYTE, atomic_load_sext_8_flat, i16>; +} + let OtherPredicates = [D16PreservesUnusedBits, HasFlatAddressSpace], True16Predicate = UseRealTrue16Insts in { defm : FlatLoadPats_D16_t16<FLAT_LOAD_UBYTE_D16_t16, extloadi8_flat, i16>; defm : FlatLoadPats_D16_t16<FLAT_LOAD_UBYTE_D16_t16, zextloadi8_flat, i16>; @@ -2006,11 +2088,14 @@ let OtherPredicates = [D16PreservesUnusedBits, HasFlatAddressSpace], True16Predi defm : FlatLoadPats_D16_t16<FLAT_LOAD_UBYTE_D16_t16, atomic_load_zext_8_flat, i16>; defm : FlatLoadPats_D16_t16<FLAT_LOAD_SHORT_D16_t16, atomic_load_nonext_16_flat, i16>; defm : FlatLoadPats_D16_t16<FLAT_LOAD_SBYTE_D16_t16, atomic_load_sext_8_flat, i16>; +} // End let OtherPredicates = [D16PreservesUnusedBits, HasFlatAddressSpace], True16Predicate = UseRealTrue16Insts + +let OtherPredicates = [D16PreservesUnusedBits], True16Predicate = UseRealTrue16Insts in { defm : FlatStorePats_t16 <FLAT_STORE_BYTE, truncstorei8_flat, i16>; defm : FlatStorePats_t16 <FLAT_STORE_SHORT, store_flat, i16>; defm : FlatStorePats_t16 <FLAT_STORE_BYTE, atomic_store_8_flat, i16>; defm : FlatStorePats_t16 <FLAT_STORE_SHORT, atomic_store_16_flat, i16>; -} // End let OtherPredicates = [D16PreservesUnusedBits, HasFlatAddressSpace], True16Predicate = UseRealTrue16Insts +} defm : FlatLoadPats <FLAT_LOAD_DWORD, atomic_load_nonext_32_flat, i32>; defm : FlatLoadPats <FLAT_LOAD_DWORDX2, atomic_load_nonext_64_flat, i64>; @@ -2140,6 +2225,20 @@ defm : GlobalFLATLoadPats <GLOBAL_LOAD_USHORT, atomic_load_nonext_16_global, i16 defm : GlobalFLATLoadPats 
<GLOBAL_LOAD_USHORT, atomic_load_zext_16_global, i16>; } +let True16Predicate = UseTrue16WithSramECC in { +defm : GlobalFLATLoadPats_t16 <GLOBAL_LOAD_UBYTE, extloadi8_global, i16>; +defm : GlobalFLATLoadPats_t16 <GLOBAL_LOAD_UBYTE, zextloadi8_global, i16>; +defm : GlobalFLATLoadPats_t16 <GLOBAL_LOAD_SBYTE, sextloadi8_global, i16>; +defm : GlobalFLATLoadPats_t16 <GLOBAL_LOAD_SSHORT, atomic_load_sext_16_global, i32>; +defm : GlobalFLATLoadPats_t16 <GLOBAL_LOAD_USHORT, atomic_load_zext_16_global, i32>; +defm : GlobalFLATLoadPats_t16 <GLOBAL_LOAD_USHORT, load_global, i16>; +defm : GlobalFLATLoadPats_t16 <GLOBAL_LOAD_UBYTE, atomic_load_aext_8_global, i16>; +defm : GlobalFLATLoadPats_t16 <GLOBAL_LOAD_UBYTE, atomic_load_zext_8_global, i16>; +defm : GlobalFLATLoadPats_t16 <GLOBAL_LOAD_SBYTE, atomic_load_sext_8_global, i16>; +defm : GlobalFLATLoadPats_t16 <GLOBAL_LOAD_USHORT, atomic_load_nonext_16_global, i16>; +defm : GlobalFLATLoadPats_t16 <GLOBAL_LOAD_USHORT, atomic_load_zext_16_global, i16>; +} + let OtherPredicates = [D16PreservesUnusedBits], True16Predicate = UseRealTrue16Insts in { defm : GlobalFLATLoadPats_D16_t16<"GLOBAL_LOAD_UBYTE_D16", extloadi8_global, i16>; defm : GlobalFLATLoadPats_D16_t16<"GLOBAL_LOAD_UBYTE_D16", zextloadi8_global, i16>; @@ -2192,6 +2291,13 @@ defm : GlobalFLATStorePats <GLOBAL_STORE_BYTE, atomic_store_8_global, i16>; defm : GlobalFLATStorePats <GLOBAL_STORE_SHORT, atomic_store_16_global, i16>; } +let OtherPredicates = [HasFlatGlobalInsts], True16Predicate = UseRealTrue16Insts in { +defm : GlobalFLATStorePats_D16_t16 <"GLOBAL_STORE_BYTE", truncstorei8_global, i16>; +defm : GlobalFLATStorePats_D16_t16 <"GLOBAL_STORE_SHORT", store_global, i16>; +defm : GlobalFLATStorePats_D16_t16 <"GLOBAL_STORE_BYTE", atomic_store_8_global, i16>; +defm : GlobalFLATStorePats_D16_t16 <"GLOBAL_STORE_SHORT", atomic_store_16_global, i16>; +} + let OtherPredicates = [HasD16LoadStore] in { defm : GlobalFLATStorePats <GLOBAL_STORE_SHORT_D16_HI, truncstorei16_hi16_global, i32>; defm : GlobalFLATStorePats <GLOBAL_STORE_BYTE_D16_HI, truncstorei8_hi16_global, i32>; @@ -2362,14 +2468,24 @@ defm : ScratchFLATStorePats <SCRATCH_STORE_SHORT, store_private, i16>; defm : ScratchFLATStorePats <SCRATCH_STORE_BYTE, truncstorei8_private, i16>; } -let True16Predicate = UseRealTrue16Insts in { +let True16Predicate = UseTrue16WithSramECC in { +defm : ScratchFLATLoadPats_t16 <SCRATCH_LOAD_UBYTE, extloadi8_private, i16>; +defm : ScratchFLATLoadPats_t16 <SCRATCH_LOAD_UBYTE, zextloadi8_private, i16>; +defm : ScratchFLATLoadPats_t16 <SCRATCH_LOAD_SBYTE, sextloadi8_private, i16>; +defm : ScratchFLATLoadPats_t16 <SCRATCH_LOAD_USHORT, load_private, i16>; +} + +let OtherPredicates = [D16PreservesUnusedBits], True16Predicate = UseRealTrue16Insts in { defm : ScratchFLATLoadPats_D16_t16<"SCRATCH_LOAD_UBYTE_D16", extloadi8_private, i16>; defm : ScratchFLATLoadPats_D16_t16<"SCRATCH_LOAD_UBYTE_D16", zextloadi8_private, i16>; defm : ScratchFLATLoadPats_D16_t16<"SCRATCH_LOAD_SBYTE_D16", sextloadi8_private, i16>; defm : ScratchFLATLoadPats_D16_t16<"SCRATCH_LOAD_SHORT_D16", load_private, i16>; -defm : ScratchFLATStorePats_t16 <"SCRATCH_STORE_SHORT", store_private, i16>; -defm : ScratchFLATStorePats_t16 <"SCRATCH_STORE_BYTE", truncstorei8_private, i16>; -} // End True16Predicate = UseRealTrue16Insts +} // End OtherPredicates = [D16PreservesUnusedBits], True16Predicate = UseRealTrue16Insts + +let True16Predicate = UseRealTrue16Insts in { +defm : ScratchFLATStorePats_D16_t16 <"SCRATCH_STORE_SHORT", store_private, i16>; +defm : 
ScratchFLATStorePats_D16_t16 <"SCRATCH_STORE_BYTE", truncstorei8_private, i16>; +} foreach vt = Reg32Types.types in { defm : ScratchFLATLoadPats <SCRATCH_LOAD_DWORD, load_private, vt>; diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 6234714..a3a4cf2 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -16498,43 +16498,60 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, SDValue X = N->getOperand(0); if (Subtarget.hasShlAdd(3)) { - for (uint64_t Divisor : {3, 5, 9}) { - if (MulAmt % Divisor != 0) - continue; - uint64_t MulAmt2 = MulAmt / Divisor; - // 3/5/9 * 2^N -> shl (shXadd X, X), N - if (isPowerOf2_64(MulAmt2)) { - SDLoc DL(N); - SDValue X = N->getOperand(0); - // Put the shift first if we can fold a zext into the - // shift forming a slli.uw. - if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) && - X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) { - SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X, - DAG.getConstant(Log2_64(MulAmt2), DL, VT)); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), - Shl); - } - // Otherwise, put rhe shl second so that it can fold with following - // instructions (e.g. sext or add). - SDValue Mul359 = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); - return DAG.getNode(ISD::SHL, DL, VT, Mul359, - DAG.getConstant(Log2_64(MulAmt2), DL, VT)); - } - - // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X) - if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) { - SDLoc DL(N); - SDValue Mul359 = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, - DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT), - Mul359); + int Shift; + if (int ShXAmount = isShifted359(MulAmt, Shift)) { + // 3/5/9 * 2^N -> shl (shXadd X, X), N + SDLoc DL(N); + SDValue X = N->getOperand(0); + // Put the shift first if we can fold a zext into the shift forming + // a slli.uw. + if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) && + X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) { + SDValue Shl = + DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(Shift, DL, VT)); + return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl, + DAG.getConstant(ShXAmount, DL, VT), Shl); } + // Otherwise, put the shl second so that it can fold with following + // instructions (e.g. sext or add). + SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getConstant(ShXAmount, DL, VT), X); + return DAG.getNode(ISD::SHL, DL, VT, Mul359, + DAG.getConstant(Shift, DL, VT)); + } + + // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X) + int ShX; + int ShY; + switch (MulAmt) { + case 3 * 5: + ShY = 1; + ShX = 2; + break; + case 3 * 9: + ShY = 1; + ShX = 3; + break; + case 5 * 5: + ShX = ShY = 2; + break; + case 5 * 9: + ShY = 2; + ShX = 3; + break; + case 9 * 9: + ShX = ShY = 3; + break; + default: + ShX = ShY = 0; + break; + } + if (ShX) { + SDLoc DL(N); + SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getConstant(ShY, DL, VT), X); + return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, + DAG.getConstant(ShX, DL, VT), Mul359); } // If this is a power 2 + 2/4/8, we can use a shift followed by a single @@ -16557,18 +16574,14 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, // variants we could implement. e.g. 
// (2^(1,2,3) * 3,5,9 + 1) << C2 // 2^(C1>3) * 3,5,9 +/- 1 - for (uint64_t Divisor : {3, 5, 9}) { - uint64_t C = MulAmt - 1; - if (C <= Divisor) - continue; - unsigned TZ = llvm::countr_zero(C); - if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) { + if (int ShXAmount = isShifted359(MulAmt - 1, Shift)) { + assert(Shift != 0 && "MulAmt=4,6,10 handled before"); + if (Shift <= 3) { SDLoc DL(N); - SDValue Mul359 = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); + SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getConstant(ShXAmount, DL, VT), X); return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, - DAG.getConstant(TZ, DL, VT), X); + DAG.getConstant(Shift, DL, VT), X); } } @@ -16576,7 +16589,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) { unsigned ScaleShift = llvm::countr_zero(MulAmt - 1); if (ScaleShift >= 1 && ScaleShift < 4) { - unsigned ShiftAmt = Log2_64(((MulAmt - 1) & (MulAmt - 2))); + unsigned ShiftAmt = llvm::countr_zero((MulAmt - 1) & (MulAmt - 2)); SDLoc DL(N); SDValue Shift1 = DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT)); @@ -16589,7 +16602,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, // 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x)) for (uint64_t Offset : {3, 5, 9}) { if (isPowerOf2_64(MulAmt + Offset)) { - unsigned ShAmt = Log2_64(MulAmt + Offset); + unsigned ShAmt = llvm::countr_zero(MulAmt + Offset); if (ShAmt >= VT.getSizeInBits()) continue; SDLoc DL(N); @@ -16608,21 +16621,16 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, uint64_t MulAmt2 = MulAmt / Divisor; // 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples // of 25 which happen to be quite common. - for (uint64_t Divisor2 : {3, 5, 9}) { - if (MulAmt2 % Divisor2 != 0) - continue; - uint64_t MulAmt3 = MulAmt2 / Divisor2; - if (isPowerOf2_64(MulAmt3)) { - SDLoc DL(N); - SDValue Mul359A = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); - SDValue Mul359B = DAG.getNode( - RISCVISD::SHL_ADD, DL, VT, Mul359A, - DAG.getConstant(Log2_64(Divisor2 - 1), DL, VT), Mul359A); - return DAG.getNode(ISD::SHL, DL, VT, Mul359B, - DAG.getConstant(Log2_64(MulAmt3), DL, VT)); - } + if (int ShBAmount = isShifted359(MulAmt2, Shift)) { + SDLoc DL(N); + SDValue Mul359A = + DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); + SDValue Mul359B = + DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359A, + DAG.getConstant(ShBAmount, DL, VT), Mul359A); + return DAG.getNode(ISD::SHL, DL, VT, Mul359B, + DAG.getConstant(Shift, DL, VT)); } } } diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index 7db4832..96e1078 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -4586,24 +4586,23 @@ void RISCVInstrInfo::mulImm(MachineFunction &MF, MachineBasicBlock &MBB, .addReg(DestReg, RegState::Kill) .addImm(ShiftAmount) .setMIFlag(Flag); - } else if (STI.hasShlAdd(3) && - ((Amount % 3 == 0 && isPowerOf2_64(Amount / 3)) || - (Amount % 5 == 0 && isPowerOf2_64(Amount / 5)) || - (Amount % 9 == 0 && isPowerOf2_64(Amount / 9)))) { + } else if (int ShXAmount, ShiftAmount; + STI.hasShlAdd(3) && + (ShXAmount = isShifted359(Amount, ShiftAmount)) != 0) { // We can use Zba SHXADD+SLLI instructions for multiply in some cases. 
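+ // For example (illustrative, not part of the original comment): Amount = 40 is 5 << 3, so + // isShifted359 returns ShXAmount = 2 with ShiftAmount = 3, and x * 40 lowers to one sh2add + // (x * 5 == (x << 2) + x) plus one slli by 3.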
unsigned Opc; - uint32_t ShiftAmount; - if (Amount % 9 == 0) { - Opc = RISCV::SH3ADD; - ShiftAmount = Log2_64(Amount / 9); - } else if (Amount % 5 == 0) { - Opc = RISCV::SH2ADD; - ShiftAmount = Log2_64(Amount / 5); - } else if (Amount % 3 == 0) { + switch (ShXAmount) { + case 1: Opc = RISCV::SH1ADD; - ShiftAmount = Log2_64(Amount / 3); - } else { - llvm_unreachable("implied by if-clause"); + break; + case 2: + Opc = RISCV::SH2ADD; + break; + case 3: + Opc = RISCV::SH3ADD; + break; + default: + llvm_unreachable("unexpected result of isShifted359"); } if (ShiftAmount) BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg) diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h index 42a0c4c..c5eddb9 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h @@ -25,6 +25,25 @@ namespace llvm { +// If Value is of the form C1<<C2, where C1 = 3, 5 or 9, +// returns log2(C1 - 1) and assigns Shift = C2. +// Otherwise, returns 0. +template <typename T> int isShifted359(T Value, int &Shift) { + if (Value == 0) + return 0; + Shift = llvm::countr_zero(Value); + switch (Value >> Shift) { + case 3: + return 1; + case 5: + return 2; + case 9: + return 3; + default: + return 0; + } +} + class RISCVSubtarget; static const MachineMemOperand::Flags MONontemporalBit0 = diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZa.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZa.td index 7cf6d5f..87b9f45 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZa.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZa.td @@ -9,8 +9,8 @@ // This file describes the RISC-V instructions from the standard atomic 'Za*' // extensions: // - Zawrs (v1.0) : Wait-on-Reservation-Set. -// - Zacas (v1.0-rc1) : Atomic Compare-and-Swap. -// - Zabha (v1.0-rc1) : Byte and Halfword Atomic Memory Operations. +// - Zacas (v1.0) : Atomic Compare-and-Swap. +// - Zabha (v1.0) : Byte and Halfword Atomic Memory Operations. 
// //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp index ee25f69..7bc0b5b 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp @@ -2747,20 +2747,72 @@ bool RISCVTTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst, Intrinsic::ID IID = Inst->getIntrinsicID(); LLVMContext &C = Inst->getContext(); bool HasMask = false; + + auto getSegNum = [](const IntrinsicInst *II, unsigned PtrOperandNo, + bool IsWrite) -> int64_t { + if (auto *TarExtTy = + dyn_cast<TargetExtType>(II->getArgOperand(0)->getType())) + return TarExtTy->getIntParameter(0); + + return 1; + }; + switch (IID) { case Intrinsic::riscv_vle_mask: case Intrinsic::riscv_vse_mask: + case Intrinsic::riscv_vlseg2_mask: + case Intrinsic::riscv_vlseg3_mask: + case Intrinsic::riscv_vlseg4_mask: + case Intrinsic::riscv_vlseg5_mask: + case Intrinsic::riscv_vlseg6_mask: + case Intrinsic::riscv_vlseg7_mask: + case Intrinsic::riscv_vlseg8_mask: + case Intrinsic::riscv_vsseg2_mask: + case Intrinsic::riscv_vsseg3_mask: + case Intrinsic::riscv_vsseg4_mask: + case Intrinsic::riscv_vsseg5_mask: + case Intrinsic::riscv_vsseg6_mask: + case Intrinsic::riscv_vsseg7_mask: + case Intrinsic::riscv_vsseg8_mask: HasMask = true; [[fallthrough]]; case Intrinsic::riscv_vle: - case Intrinsic::riscv_vse: { + case Intrinsic::riscv_vse: + case Intrinsic::riscv_vlseg2: + case Intrinsic::riscv_vlseg3: + case Intrinsic::riscv_vlseg4: + case Intrinsic::riscv_vlseg5: + case Intrinsic::riscv_vlseg6: + case Intrinsic::riscv_vlseg7: + case Intrinsic::riscv_vlseg8: + case Intrinsic::riscv_vsseg2: + case Intrinsic::riscv_vsseg3: + case Intrinsic::riscv_vsseg4: + case Intrinsic::riscv_vsseg5: + case Intrinsic::riscv_vsseg6: + case Intrinsic::riscv_vsseg7: + case Intrinsic::riscv_vsseg8: { // Intrinsic interface: // riscv_vle(merge, ptr, vl) // riscv_vle_mask(merge, ptr, mask, vl, policy) // riscv_vse(val, ptr, vl) // riscv_vse_mask(val, ptr, mask, vl, policy) + // riscv_vlseg#(merge, ptr, vl, sew) + // riscv_vlseg#_mask(merge, ptr, mask, vl, policy, sew) + // riscv_vsseg#(val, ptr, vl, sew) + // riscv_vsseg#_mask(val, ptr, mask, vl, sew) bool IsWrite = Inst->getType()->isVoidTy(); Type *Ty = IsWrite ? Inst->getArgOperand(0)->getType() : Inst->getType(); + // The results of segment loads are TargetExtType. + if (auto *TarExtTy = dyn_cast<TargetExtType>(Ty)) { + unsigned SEW = + 1 << cast<ConstantInt>(Inst->getArgOperand(Inst->arg_size() - 1)) + ->getZExtValue(); + Ty = TarExtTy->getTypeParameter(0U); + Ty = ScalableVectorType::get( + IntegerType::get(C, SEW), + cast<ScalableVectorType>(Ty)->getMinNumElements() * 8 / SEW); + } const auto *RVVIInfo = RISCVVIntrinsicsTable::getRISCVVIntrinsicInfo(IID); unsigned VLIndex = RVVIInfo->VLOperand; unsigned PtrOperandNo = VLIndex - 1 - HasMask; @@ -2771,23 +2823,72 @@ bool RISCVTTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst, if (HasMask) Mask = Inst->getArgOperand(VLIndex - 1); Value *EVL = Inst->getArgOperand(VLIndex); + unsigned SegNum = getSegNum(Inst, PtrOperandNo, IsWrite); + // RVV uses contiguous elements as a segment. 
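+ // For example (illustrative, not part of the original comment): a vlseg4 whose field type is + // <vscale x 4 x i16> is widened below to <vscale x 4 x i64>, one 64-bit element per segment of + // four contiguous 16-bit fields.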
+ if (SegNum > 1) { + unsigned ElemSize = Ty->getScalarSizeInBits(); + auto *SegTy = IntegerType::get(C, ElemSize * SegNum); + Ty = VectorType::get(SegTy, cast<VectorType>(Ty)); + } Info.InterestingOperands.emplace_back(Inst, PtrOperandNo, IsWrite, Ty, Alignment, Mask, EVL); return true; } case Intrinsic::riscv_vlse_mask: case Intrinsic::riscv_vsse_mask: + case Intrinsic::riscv_vlsseg2_mask: + case Intrinsic::riscv_vlsseg3_mask: + case Intrinsic::riscv_vlsseg4_mask: + case Intrinsic::riscv_vlsseg5_mask: + case Intrinsic::riscv_vlsseg6_mask: + case Intrinsic::riscv_vlsseg7_mask: + case Intrinsic::riscv_vlsseg8_mask: + case Intrinsic::riscv_vssseg2_mask: + case Intrinsic::riscv_vssseg3_mask: + case Intrinsic::riscv_vssseg4_mask: + case Intrinsic::riscv_vssseg5_mask: + case Intrinsic::riscv_vssseg6_mask: + case Intrinsic::riscv_vssseg7_mask: + case Intrinsic::riscv_vssseg8_mask: HasMask = true; [[fallthrough]]; case Intrinsic::riscv_vlse: - case Intrinsic::riscv_vsse: { + case Intrinsic::riscv_vsse: + case Intrinsic::riscv_vlsseg2: + case Intrinsic::riscv_vlsseg3: + case Intrinsic::riscv_vlsseg4: + case Intrinsic::riscv_vlsseg5: + case Intrinsic::riscv_vlsseg6: + case Intrinsic::riscv_vlsseg7: + case Intrinsic::riscv_vlsseg8: + case Intrinsic::riscv_vssseg2: + case Intrinsic::riscv_vssseg3: + case Intrinsic::riscv_vssseg4: + case Intrinsic::riscv_vssseg5: + case Intrinsic::riscv_vssseg6: + case Intrinsic::riscv_vssseg7: + case Intrinsic::riscv_vssseg8: { // Intrinsic interface: // riscv_vlse(merge, ptr, stride, vl) // riscv_vlse_mask(merge, ptr, stride, mask, vl, policy) // riscv_vsse(val, ptr, stride, vl) // riscv_vsse_mask(val, ptr, stride, mask, vl, policy) + // riscv_vlsseg#(merge, ptr, stride, vl, sew) + // riscv_vlsseg#_mask(merge, ptr, stride, mask, vl, policy, sew) + // riscv_vssseg#(val, ptr, stride, vl, sew) + // riscv_vssseg#_mask(val, ptr, stride, mask, vl, sew) bool IsWrite = Inst->getType()->isVoidTy(); Type *Ty = IsWrite ? Inst->getArgOperand(0)->getType() : Inst->getType(); + // The results of segment loads are TargetExtType. + if (auto *TarExtTy = dyn_cast<TargetExtType>(Ty)) { + unsigned SEW = + 1 << cast<ConstantInt>(Inst->getArgOperand(Inst->arg_size() - 1)) + ->getZExtValue(); + Ty = TarExtTy->getTypeParameter(0U); + Ty = ScalableVectorType::get( + IntegerType::get(C, SEW), + cast<ScalableVectorType>(Ty)->getMinNumElements() * 8 / SEW); + } const auto *RVVIInfo = RISCVVIntrinsicsTable::getRISCVVIntrinsicInfo(IID); unsigned VLIndex = RVVIInfo->VLOperand; unsigned PtrOperandNo = VLIndex - 2 - HasMask; @@ -2809,6 +2910,13 @@ bool RISCVTTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst, if (HasMask) Mask = Inst->getArgOperand(VLIndex - 1); Value *EVL = Inst->getArgOperand(VLIndex); + unsigned SegNum = getSegNum(Inst, PtrOperandNo, IsWrite); + // RVV uses contiguous elements as a segment.
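+ // (Assumed RVV semantics, annotation not from this patch: for the strided forms the byte + // stride separates consecutive segments, so modelling a segment as a single wide element + // keeps the recorded (type, stride) access pattern accurate.)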
+ if (SegNum > 1) { + unsigned ElemSize = Ty->getScalarSizeInBits(); + auto *SegTy = IntegerType::get(C, ElemSize * SegNum); + Ty = VectorType::get(SegTy, cast<VectorType>(Ty)); + } Info.InterestingOperands.emplace_back(Inst, PtrOperandNo, IsWrite, Ty, Alignment, Mask, EVL, Stride); return true; @@ -2817,19 +2925,89 @@ bool RISCVTTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst, case Intrinsic::riscv_vluxei_mask: case Intrinsic::riscv_vsoxei_mask: case Intrinsic::riscv_vsuxei_mask: + case Intrinsic::riscv_vloxseg2_mask: + case Intrinsic::riscv_vloxseg3_mask: + case Intrinsic::riscv_vloxseg4_mask: + case Intrinsic::riscv_vloxseg5_mask: + case Intrinsic::riscv_vloxseg6_mask: + case Intrinsic::riscv_vloxseg7_mask: + case Intrinsic::riscv_vloxseg8_mask: + case Intrinsic::riscv_vluxseg2_mask: + case Intrinsic::riscv_vluxseg3_mask: + case Intrinsic::riscv_vluxseg4_mask: + case Intrinsic::riscv_vluxseg5_mask: + case Intrinsic::riscv_vluxseg6_mask: + case Intrinsic::riscv_vluxseg7_mask: + case Intrinsic::riscv_vluxseg8_mask: + case Intrinsic::riscv_vsoxseg2_mask: + case Intrinsic::riscv_vsoxseg3_mask: + case Intrinsic::riscv_vsoxseg4_mask: + case Intrinsic::riscv_vsoxseg5_mask: + case Intrinsic::riscv_vsoxseg6_mask: + case Intrinsic::riscv_vsoxseg7_mask: + case Intrinsic::riscv_vsoxseg8_mask: + case Intrinsic::riscv_vsuxseg2_mask: + case Intrinsic::riscv_vsuxseg3_mask: + case Intrinsic::riscv_vsuxseg4_mask: + case Intrinsic::riscv_vsuxseg5_mask: + case Intrinsic::riscv_vsuxseg6_mask: + case Intrinsic::riscv_vsuxseg7_mask: + case Intrinsic::riscv_vsuxseg8_mask: HasMask = true; [[fallthrough]]; case Intrinsic::riscv_vloxei: case Intrinsic::riscv_vluxei: case Intrinsic::riscv_vsoxei: - case Intrinsic::riscv_vsuxei: { + case Intrinsic::riscv_vsuxei: + case Intrinsic::riscv_vloxseg2: + case Intrinsic::riscv_vloxseg3: + case Intrinsic::riscv_vloxseg4: + case Intrinsic::riscv_vloxseg5: + case Intrinsic::riscv_vloxseg6: + case Intrinsic::riscv_vloxseg7: + case Intrinsic::riscv_vloxseg8: + case Intrinsic::riscv_vluxseg2: + case Intrinsic::riscv_vluxseg3: + case Intrinsic::riscv_vluxseg4: + case Intrinsic::riscv_vluxseg5: + case Intrinsic::riscv_vluxseg6: + case Intrinsic::riscv_vluxseg7: + case Intrinsic::riscv_vluxseg8: + case Intrinsic::riscv_vsoxseg2: + case Intrinsic::riscv_vsoxseg3: + case Intrinsic::riscv_vsoxseg4: + case Intrinsic::riscv_vsoxseg5: + case Intrinsic::riscv_vsoxseg6: + case Intrinsic::riscv_vsoxseg7: + case Intrinsic::riscv_vsoxseg8: + case Intrinsic::riscv_vsuxseg2: + case Intrinsic::riscv_vsuxseg3: + case Intrinsic::riscv_vsuxseg4: + case Intrinsic::riscv_vsuxseg5: + case Intrinsic::riscv_vsuxseg6: + case Intrinsic::riscv_vsuxseg7: + case Intrinsic::riscv_vsuxseg8: { // Intrinsic interface (only listed ordered version): // riscv_vloxei(merge, ptr, index, vl) // riscv_vloxei_mask(merge, ptr, index, mask, vl, policy) // riscv_vsoxei(val, ptr, index, vl) // riscv_vsoxei_mask(val, ptr, index, mask, vl, policy) + // riscv_vloxseg#(merge, ptr, index, vl, sew) + // riscv_vloxseg#_mask(merge, ptr, index, mask, vl, policy, sew) + // riscv_vsoxseg#(val, ptr, index, vl, sew) + // riscv_vsoxseg#_mask(val, ptr, index, mask, vl, sew) bool IsWrite = Inst->getType()->isVoidTy(); Type *Ty = IsWrite ? Inst->getArgOperand(0)->getType() : Inst->getType(); + // The results of segment loads are TargetExtType. 
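+ // e.g. (hypothetical operands): a tuple with storage type <vscale x 8 x i8> and a trailing + // sew argument of 4 (SEW = 1 << 4 = 16) is rebuilt below as <vscale x 4 x i16>, since + // 8 elements * 8 bits / 16 = 4.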
+ if (auto *TarExtTy = dyn_cast<TargetExtType>(Ty)) { + unsigned SEW = + 1 << cast<ConstantInt>(Inst->getArgOperand(Inst->arg_size() - 1)) + ->getZExtValue(); + Ty = TarExtTy->getTypeParameter(0U); + Ty = ScalableVectorType::get( + IntegerType::get(C, SEW), + cast<ScalableVectorType>(Ty)->getMinNumElements() * 8 / SEW); + } const auto *RVVIInfo = RISCVVIntrinsicsTable::getRISCVVIntrinsicInfo(IID); unsigned VLIndex = RVVIInfo->VLOperand; unsigned PtrOperandNo = VLIndex - 2 - HasMask; @@ -2845,6 +3023,13 @@ bool RISCVTTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst, Mask = ConstantInt::getTrue(MaskType); } Value *EVL = Inst->getArgOperand(VLIndex); + unsigned SegNum = getSegNum(Inst, PtrOperandNo, IsWrite); + // RVV uses contiguous elements as a segment. + if (SegNum > 1) { + unsigned ElemSize = Ty->getScalarSizeInBits(); + auto *SegTy = IntegerType::get(C, ElemSize * SegNum); + Ty = VectorType::get(SegTy, cast<VectorType>(Ty)); + } Value *OffsetOp = Inst->getArgOperand(PtrOperandNo + 1); Info.InterestingOperands.emplace_back(Inst, PtrOperandNo, IsWrite, Ty, Align(1), Mask, EVL, diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 9ca8194..56194fe 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -137,13 +137,10 @@ InstCombinerImpl::isEliminableCastPair(const CastInst *CI1, Instruction::CastOps secondOp = CI2->getOpcode(); Type *SrcIntPtrTy = SrcTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(SrcTy) : nullptr; - Type *MidIntPtrTy = - MidTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(MidTy) : nullptr; Type *DstIntPtrTy = DstTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(DstTy) : nullptr; unsigned Res = CastInst::isEliminableCastPair(firstOp, secondOp, SrcTy, MidTy, - DstTy, SrcIntPtrTy, MidIntPtrTy, - DstIntPtrTy); + DstTy, &DL); // We don't want to form an inttoptr or ptrtoint that converts to an integer // type that differs from the pointer size. diff --git a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp index cdae9a7..3704ad7 100644 --- a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -2662,7 +2662,7 @@ void ModuleAddressSanitizer::instrumentGlobals(IRBuilder<> &IRB, G->eraseFromParent(); NewGlobals[i] = NewGlobal; - Constant *ODRIndicator = ConstantPointerNull::get(PtrTy); + Constant *ODRIndicator = Constant::getNullValue(IntptrTy); GlobalValue *InstrumentedGlobal = NewGlobal; bool CanUsePrivateAliases = @@ -2677,8 +2677,7 @@ void ModuleAddressSanitizer::instrumentGlobals(IRBuilder<> &IRB, // ODR should not happen for local linkage. if (NewGlobal->hasLocalLinkage()) { - ODRIndicator = - ConstantExpr::getIntToPtr(ConstantInt::get(IntptrTy, -1), PtrTy); + ODRIndicator = ConstantInt::get(IntptrTy, -1); } else if (UseOdrIndicator) { // With local aliases, we need to provide another externally visible // symbol __odr_asan_XXX to detect ODR violation. 
@@ -2692,7 +2691,7 @@ void ModuleAddressSanitizer::instrumentGlobals(IRBuilder<> &IRB, ODRIndicatorSym->setVisibility(NewGlobal->getVisibility()); ODRIndicatorSym->setDLLStorageClass(NewGlobal->getDLLStorageClass()); ODRIndicatorSym->setAlignment(Align(1)); - ODRIndicator = ODRIndicatorSym; + ODRIndicator = ConstantExpr::getPtrToInt(ODRIndicatorSym, IntptrTy); } Constant *Initializer = ConstantStruct::get( @@ -2703,8 +2702,7 @@ void ModuleAddressSanitizer::instrumentGlobals(IRBuilder<> &IRB, ConstantExpr::getPointerCast(Name, IntptrTy), ConstantExpr::getPointerCast(getOrCreateModuleName(), IntptrTy), ConstantInt::get(IntptrTy, MD.IsDynInit), - Constant::getNullValue(IntptrTy), - ConstantExpr::getPointerCast(ODRIndicator, IntptrTy)); + Constant::getNullValue(IntptrTy), ODRIndicator); LLVM_DEBUG(dbgs() << "NEW GLOBAL: " << *NewGlobal << "\n"); diff --git a/llvm/lib/Transforms/Instrumentation/AllocToken.cpp b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp new file mode 100644 index 0000000..782d5a1 --- /dev/null +++ b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp @@ -0,0 +1,494 @@ +//===- AllocToken.cpp - Allocation token instrumentation ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements AllocToken, an instrumentation pass that +// replaces allocation calls with token-enabled versions. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Instrumentation/AllocToken.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Analysis.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/RandomNumberGenerator.h" +#include "llvm/Support/SipHash.h" +#include "llvm/Support/raw_ostream.h" +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <limits> +#include <memory> +#include <optional> +#include <string> +#include <utility> +#include <variant> + +using namespace llvm; + +#define DEBUG_TYPE "alloc-token" + +namespace { + +//===--- Constants --------------------------------------------------------===// + +enum class TokenMode : unsigned { + /// Incrementally increasing token ID. + Increment = 0, + + /// Simple mode that returns a statically-assigned random token ID. + Random = 1, + + /// Token ID based on allocated type hash. 
+ TypeHash = 2, +}; + +//===--- Command-line options ---------------------------------------------===// + +cl::opt<TokenMode> + ClMode("alloc-token-mode", cl::Hidden, cl::desc("Token assignment mode"), + cl::init(TokenMode::TypeHash), + cl::values(clEnumValN(TokenMode::Increment, "increment", + "Incrementally increasing token ID"), + clEnumValN(TokenMode::Random, "random", + "Statically-assigned random token ID"), + clEnumValN(TokenMode::TypeHash, "typehash", + "Token ID based on allocated type hash"))); + +cl::opt<std::string> ClFuncPrefix("alloc-token-prefix", + cl::desc("The allocation function prefix"), + cl::Hidden, cl::init("__alloc_token_")); + +cl::opt<uint64_t> ClMaxTokens("alloc-token-max", + cl::desc("Maximum number of tokens (0 = no max)"), + cl::Hidden, cl::init(0)); + +cl::opt<bool> + ClFastABI("alloc-token-fast-abi", + cl::desc("The token ID is encoded in the function name"), + cl::Hidden, cl::init(false)); + +// Instrument libcalls only by default - compatible allocators only need to take +// care of providing standard allocation functions. With extended coverage, also +// instrument non-libcall allocation function calls with !alloc_token +// metadata. +cl::opt<bool> + ClExtended("alloc-token-extended", + cl::desc("Extend coverage to custom allocation functions"), + cl::Hidden, cl::init(false)); + +// C++ defines ::operator new (and variants) as replaceable (vs. standard +// library versions), which are nobuiltin, and are therefore not covered by +// isAllocationFn(). Cover by default, as users of AllocToken are already +// required to provide token-aware allocation functions (no defaults). +cl::opt<bool> ClCoverReplaceableNew("alloc-token-cover-replaceable-new", + cl::desc("Cover replaceable operator new"), + cl::Hidden, cl::init(true)); + +cl::opt<uint64_t> ClFallbackToken( + "alloc-token-fallback", + cl::desc("The default fallback token where none could be determined"), + cl::Hidden, cl::init(0)); + +//===--- Statistics -------------------------------------------------------===// + +STATISTIC(NumFunctionsInstrumented, "Functions instrumented"); +STATISTIC(NumAllocationsInstrumented, "Allocations instrumented"); + +//===----------------------------------------------------------------------===// + +/// Returns the !alloc_token metadata if available. +/// +/// Expected format is: !{<type-name>} +MDNode *getAllocTokenMetadata(const CallBase &CB) { + MDNode *Ret = CB.getMetadata(LLVMContext::MD_alloc_token); + if (!Ret) + return nullptr; + assert(Ret->getNumOperands() == 1 && "bad !alloc_token"); + assert(isa<MDString>(Ret->getOperand(0))); + return Ret; +} + +class ModeBase { +public: + explicit ModeBase(const IntegerType &TokenTy, uint64_t MaxTokens) + : MaxTokens(MaxTokens ? MaxTokens : TokenTy.getBitMask()) { + assert(MaxTokens <= TokenTy.getBitMask()); + } + +protected: + uint64_t boundedToken(uint64_t Val) const { + assert(MaxTokens != 0); + return Val % MaxTokens; + } + + const uint64_t MaxTokens; +}; + +/// Implementation for TokenMode::Increment. +class IncrementMode : public ModeBase { +public: + using ModeBase::ModeBase; + + uint64_t operator()(const CallBase &CB, OptimizationRemarkEmitter &) { + return boundedToken(Counter++); + } + +private: + uint64_t Counter = 0; +}; + +/// Implementation for TokenMode::Random. 
+class RandomMode : public ModeBase { +public: + RandomMode(const IntegerType &TokenTy, uint64_t MaxTokens, + std::unique_ptr<RandomNumberGenerator> RNG) + : ModeBase(TokenTy, MaxTokens), RNG(std::move(RNG)) {} + uint64_t operator()(const CallBase &CB, OptimizationRemarkEmitter &) { + return boundedToken((*RNG)()); + } + +private: + std::unique_ptr<RandomNumberGenerator> RNG; +}; + +/// Implementation for TokenMode::TypeHash. The implementation ensures +/// hashes are stable across different compiler invocations. Uses SipHash as the +/// hash function. +class TypeHashMode : public ModeBase { +public: + using ModeBase::ModeBase; + + uint64_t operator()(const CallBase &CB, OptimizationRemarkEmitter &ORE) { + if (MDNode *N = getAllocTokenMetadata(CB)) { + MDString *S = cast<MDString>(N->getOperand(0)); + return boundedToken(getStableSipHash(S->getString())); + } + remarkNoMetadata(CB, ORE); + return ClFallbackToken; + } + + /// Remark that there was no precise type information. + static void remarkNoMetadata(const CallBase &CB, + OptimizationRemarkEmitter &ORE) { + ORE.emit([&] { + ore::NV FuncNV("Function", CB.getParent()->getParent()); + const Function *Callee = CB.getCalledFunction(); + ore::NV CalleeNV("Callee", Callee ? Callee->getName() : "<unknown>"); + return OptimizationRemark(DEBUG_TYPE, "NoAllocToken", &CB) + << "Call to '" << CalleeNV << "' in '" << FuncNV + << "' without source-level type token"; + }); + } +}; + +// Apply opt overrides. +AllocTokenOptions transformOptionsFromCl(AllocTokenOptions Opts) { + if (!Opts.MaxTokens.has_value()) + Opts.MaxTokens = ClMaxTokens; + Opts.FastABI |= ClFastABI; + Opts.Extended |= ClExtended; + return Opts; +} + +class AllocToken { +public: + explicit AllocToken(AllocTokenOptions Opts, Module &M, + ModuleAnalysisManager &MAM) + : Options(transformOptionsFromCl(std::move(Opts))), Mod(M), + FAM(MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()), + Mode(IncrementMode(*IntPtrTy, *Options.MaxTokens)) { + switch (ClMode.getValue()) { + case TokenMode::Increment: + break; + case TokenMode::Random: + Mode.emplace<RandomMode>(*IntPtrTy, *Options.MaxTokens, + M.createRNG(DEBUG_TYPE)); + break; + case TokenMode::TypeHash: + Mode.emplace<TypeHashMode>(*IntPtrTy, *Options.MaxTokens); + break; + } + } + + bool instrumentFunction(Function &F); + +private: + /// Returns the LibFunc (or NotLibFunc) if this call should be instrumented. + std::optional<LibFunc> + shouldInstrumentCall(const CallBase &CB, const TargetLibraryInfo &TLI) const; + + /// Returns true for functions that are eligible for instrumentation. + static bool isInstrumentableLibFunc(LibFunc Func, const CallBase &CB, + const TargetLibraryInfo &TLI); + + /// Returns true for isAllocationFn() functions that we should ignore. + static bool ignoreInstrumentableLibFunc(LibFunc Func); + + /// Replace a call/invoke with a call/invoke to the allocation function + /// with token ID. + bool replaceAllocationCall(CallBase *CB, LibFunc Func, + OptimizationRemarkEmitter &ORE, + const TargetLibraryInfo &TLI); + + /// Return replacement function for a LibFunc that takes a token ID. + FunctionCallee getTokenAllocFunction(const CallBase &CB, uint64_t TokenID, + LibFunc OriginalFunc); + + /// Return the token ID from metadata in the call. 
+ uint64_t getToken(const CallBase &CB, OptimizationRemarkEmitter &ORE) { + return std::visit([&](auto &&Mode) { return Mode(CB, ORE); }, Mode); + } + + const AllocTokenOptions Options; + Module &Mod; + IntegerType *IntPtrTy = Mod.getDataLayout().getIntPtrType(Mod.getContext()); + FunctionAnalysisManager &FAM; + // Cache for replacement functions. + DenseMap<std::pair<LibFunc, uint64_t>, FunctionCallee> TokenAllocFunctions; + // Selected mode. + std::variant<IncrementMode, RandomMode, TypeHashMode> Mode; +}; + +bool AllocToken::instrumentFunction(Function &F) { + // Do not apply any instrumentation for naked functions. + if (F.hasFnAttribute(Attribute::Naked)) + return false; + if (F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation)) + return false; + // Don't touch available_externally functions, their actual body is elsewhere. + if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) + return false; + // Only instrument functions that have the sanitize_alloc_token attribute. + if (!F.hasFnAttribute(Attribute::SanitizeAllocToken)) + return false; + + auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); + auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F); + SmallVector<std::pair<CallBase *, LibFunc>, 4> AllocCalls; + + // Collect all allocation calls to avoid iterator invalidation. + for (Instruction &I : instructions(F)) { + auto *CB = dyn_cast<CallBase>(&I); + if (!CB) + continue; + if (std::optional<LibFunc> Func = shouldInstrumentCall(*CB, TLI)) + AllocCalls.emplace_back(CB, Func.value()); + } + + bool Modified = false; + for (auto &[CB, Func] : AllocCalls) + Modified |= replaceAllocationCall(CB, Func, ORE, TLI); + + if (Modified) + NumFunctionsInstrumented++; + return Modified; +} + +std::optional<LibFunc> +AllocToken::shouldInstrumentCall(const CallBase &CB, + const TargetLibraryInfo &TLI) const { + const Function *Callee = CB.getCalledFunction(); + if (!Callee) + return std::nullopt; + + // Ignore nobuiltin of the CallBase, so that we can cover nobuiltin libcalls + // if requested via isInstrumentableLibFunc(). Note that isAllocationFn() is + // returning false for nobuiltin calls. + LibFunc Func; + if (TLI.getLibFunc(*Callee, Func)) { + if (isInstrumentableLibFunc(Func, CB, TLI)) + return Func; + } else if (Options.Extended && getAllocTokenMetadata(CB)) { + return NotLibFunc; + } + + return std::nullopt; +} + +bool AllocToken::isInstrumentableLibFunc(LibFunc Func, const CallBase &CB, + const TargetLibraryInfo &TLI) { + if (ignoreInstrumentableLibFunc(Func)) + return false; + + if (isAllocationFn(&CB, &TLI)) + return true; + + switch (Func) { + // These libfuncs don't return normal pointers, and are therefore not handled + // by isAllocationFn(). + case LibFunc_posix_memalign: + case LibFunc_size_returning_new: + case LibFunc_size_returning_new_hot_cold: + case LibFunc_size_returning_new_aligned: + case LibFunc_size_returning_new_aligned_hot_cold: + return true; + + // See comment above ClCoverReplaceableNew. 
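+ // (For reference: these mangled names are the replaceable global operator new/new[] variants, + // e.g. _Znwm is ::operator new(unsigned long) and _Znaj is ::operator new[](unsigned int), + // together with their align_val_t/nothrow_t/__hot_cold_t overloads.)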
+ case LibFunc_Znwj: + case LibFunc_ZnwjRKSt9nothrow_t: + case LibFunc_ZnwjSt11align_val_t: + case LibFunc_ZnwjSt11align_val_tRKSt9nothrow_t: + case LibFunc_Znwm: + case LibFunc_Znwm12__hot_cold_t: + case LibFunc_ZnwmRKSt9nothrow_t: + case LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t: + case LibFunc_ZnwmSt11align_val_t: + case LibFunc_ZnwmSt11align_val_t12__hot_cold_t: + case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t: + case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t: + case LibFunc_Znaj: + case LibFunc_ZnajRKSt9nothrow_t: + case LibFunc_ZnajSt11align_val_t: + case LibFunc_ZnajSt11align_val_tRKSt9nothrow_t: + case LibFunc_Znam: + case LibFunc_Znam12__hot_cold_t: + case LibFunc_ZnamRKSt9nothrow_t: + case LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t: + case LibFunc_ZnamSt11align_val_t: + case LibFunc_ZnamSt11align_val_t12__hot_cold_t: + case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t: + case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t: + return ClCoverReplaceableNew; + + default: + return false; + } +} + +bool AllocToken::ignoreInstrumentableLibFunc(LibFunc Func) { + switch (Func) { + case LibFunc_strdup: + case LibFunc_dunder_strdup: + case LibFunc_strndup: + case LibFunc_dunder_strndup: + return true; + default: + return false; + } +} + +bool AllocToken::replaceAllocationCall(CallBase *CB, LibFunc Func, + OptimizationRemarkEmitter &ORE, + const TargetLibraryInfo &TLI) { + uint64_t TokenID = getToken(*CB, ORE); + + FunctionCallee TokenAlloc = getTokenAllocFunction(*CB, TokenID, Func); + if (!TokenAlloc) + return false; + NumAllocationsInstrumented++; + + if (Options.FastABI) { + assert(TokenAlloc.getFunctionType()->getNumParams() == CB->arg_size()); + CB->setCalledFunction(TokenAlloc); + return true; + } + + IRBuilder<> IRB(CB); + // Original args. + SmallVector<Value *, 4> NewArgs{CB->args()}; + // Add token ID, truncated to IntPtrTy width. + NewArgs.push_back(ConstantInt::get(IntPtrTy, TokenID)); + assert(TokenAlloc.getFunctionType()->getNumParams() == NewArgs.size()); + + // Preserve invoke vs call semantics for exception handling. + CallBase *NewCall; + if (auto *II = dyn_cast<InvokeInst>(CB)) { + NewCall = IRB.CreateInvoke(TokenAlloc, II->getNormalDest(), + II->getUnwindDest(), NewArgs); + } else { + NewCall = IRB.CreateCall(TokenAlloc, NewArgs); + cast<CallInst>(NewCall)->setTailCall(CB->isTailCall()); + } + NewCall->setCallingConv(CB->getCallingConv()); + NewCall->copyMetadata(*CB); + NewCall->setAttributes(CB->getAttributes()); + + // Replace all uses and delete the old call. + CB->replaceAllUsesWith(NewCall); + CB->eraseFromParent(); + return true; +} + +FunctionCallee AllocToken::getTokenAllocFunction(const CallBase &CB, + uint64_t TokenID, + LibFunc OriginalFunc) { + std::optional<std::pair<LibFunc, uint64_t>> Key; + if (OriginalFunc != NotLibFunc) { + Key = std::make_pair(OriginalFunc, Options.FastABI ? TokenID : 0); + auto It = TokenAllocFunctions.find(*Key); + if (It != TokenAllocFunctions.end()) + return It->second; + } + + const Function *Callee = CB.getCalledFunction(); + if (!Callee) + return FunctionCallee(); + const FunctionType *OldFTy = Callee->getFunctionType(); + if (OldFTy->isVarArg()) + return FunctionCallee(); + // Copy params, and append token ID type. 
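+ // e.g. (illustrative, with a 64-bit IntPtrTy and a made-up token ID): the default ABI rewrites + // declare ptr @malloc(i64) to declare ptr @__alloc_token_malloc(i64, i64) with the token + // appended, while the fast ABI emits @__alloc_token_123_malloc(i64) and keeps the signature + // unchanged.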
+ Type *RetTy = OldFTy->getReturnType(); + SmallVector<Type *, 4> NewParams{OldFTy->params()}; + std::string TokenAllocName = ClFuncPrefix; + if (Options.FastABI) + TokenAllocName += utostr(TokenID) + "_"; + else + NewParams.push_back(IntPtrTy); // token ID + TokenAllocName += Callee->getName(); + FunctionType *NewFTy = FunctionType::get(RetTy, NewParams, false); + FunctionCallee TokenAlloc = Mod.getOrInsertFunction(TokenAllocName, NewFTy); + if (Function *F = dyn_cast<Function>(TokenAlloc.getCallee())) + F->copyAttributesFrom(Callee); // preserve attrs + + if (Key.has_value()) + TokenAllocFunctions[*Key] = TokenAlloc; + return TokenAlloc; +} + +} // namespace + +AllocTokenPass::AllocTokenPass(AllocTokenOptions Opts) + : Options(std::move(Opts)) {} + +PreservedAnalyses AllocTokenPass::run(Module &M, ModuleAnalysisManager &MAM) { + AllocToken Pass(Options, M, MAM); + bool Modified = false; + + for (Function &F : M) { + if (F.empty()) + continue; // declaration + Modified |= Pass.instrumentFunction(F); + } + + return Modified ? PreservedAnalyses::none().preserveSet<CFGAnalyses>() + : PreservedAnalyses::all(); +} diff --git a/llvm/lib/Transforms/Instrumentation/CMakeLists.txt b/llvm/lib/Transforms/Instrumentation/CMakeLists.txt index 15fd421..80576c6 100644 --- a/llvm/lib/Transforms/Instrumentation/CMakeLists.txt +++ b/llvm/lib/Transforms/Instrumentation/CMakeLists.txt @@ -1,5 +1,6 @@ add_llvm_component_library(LLVMInstrumentation AddressSanitizer.cpp + AllocToken.cpp BoundsChecking.cpp CGProfile.cpp ControlHeightReduction.cpp diff --git a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp index 480ff4a..5ba2167 100644 --- a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp @@ -261,6 +261,11 @@ static cl::opt<bool> ClIgnorePersonalityRoutine( "list, do not create a wrapper for it."), cl::Hidden, cl::init(false)); +static cl::opt<bool> ClAddGlobalNameSuffix( + "dfsan-add-global-name-suffix", + cl::desc("Whether to add .dfsan suffix to global names"), cl::Hidden, + cl::init(true)); + static StringRef getGlobalTypeString(const GlobalValue &G) { // Types of GlobalVariables are always pointer types. 
Type *GType = G.getValueType(); @@ -1256,6 +1261,9 @@ DataFlowSanitizer::WrapperKind DataFlowSanitizer::getWrapperKind(Function *F) { } void DataFlowSanitizer::addGlobalNameSuffix(GlobalValue *GV) { + if (!ClAddGlobalNameSuffix) + return; + std::string GVName = std::string(GV->getName()), Suffix = ".dfsan"; GV->setName(GVName + Suffix); @@ -1784,10 +1792,8 @@ bool DataFlowSanitizer::runImpl( } Value *DFSanFunction::getArgTLS(Type *T, unsigned ArgOffset, IRBuilder<> &IRB) { - Value *Base = IRB.CreatePointerCast(DFS.ArgTLS, DFS.IntptrTy); - if (ArgOffset) - Base = IRB.CreateAdd(Base, ConstantInt::get(DFS.IntptrTy, ArgOffset)); - return IRB.CreateIntToPtr(Base, PointerType::get(*DFS.Ctx, 0), "_dfsarg"); + return IRB.CreatePtrAdd(DFS.ArgTLS, ConstantInt::get(DFS.IntptrTy, ArgOffset), + "_dfsarg"); } Value *DFSanFunction::getRetvalTLS(Type *T, IRBuilder<> &IRB) { diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp index e9a3e98..584cdad 100644 --- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp @@ -89,6 +89,7 @@ STATISTIC(NumTransforms, "Number of transformations done"); STATISTIC(NumCloned, "Number of blocks cloned"); STATISTIC(NumPaths, "Number of individual paths threaded"); +namespace llvm { static cl::opt<bool> ClViewCfgBefore("dfa-jump-view-cfg-before", cl::desc("View the CFG before DFA Jump Threading"), @@ -120,8 +121,17 @@ static cl::opt<unsigned> cl::desc("Maximum cost accepted for the transformation"), cl::Hidden, cl::init(50)); -namespace { +extern cl::opt<bool> ProfcheckDisableMetadataFixes; + +} // namespace llvm + +static cl::opt<double> MaxClonedRate( + "dfa-max-cloned-rate", + cl::desc( + "Maximum cloned instructions rate accepted for the transformation"), + cl::Hidden, cl::init(7.5)); +namespace { class SelectInstToUnfold { SelectInst *SI; PHINode *SIUse; @@ -135,10 +145,6 @@ public: explicit operator bool() const { return SI && SIUse; } }; -void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold, - std::vector<SelectInstToUnfold> *NewSIsToUnfold, - std::vector<BasicBlock *> *NewBBs); - class DFAJumpThreading { public: DFAJumpThreading(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI, @@ -152,7 +158,8 @@ private: void unfoldSelectInstrs(DominatorTree *DT, const SmallVector<SelectInstToUnfold, 4> &SelectInsts) { - DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); + // TODO: Have everything use a single lazy DTU + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); SmallVector<SelectInstToUnfold, 4> Stack(SelectInsts); while (!Stack.empty()) { @@ -167,16 +174,18 @@ private: } } + static void unfold(DomTreeUpdater *DTU, LoopInfo *LI, + SelectInstToUnfold SIToUnfold, + std::vector<SelectInstToUnfold> *NewSIsToUnfold, + std::vector<BasicBlock *> *NewBBs); + AssumptionCache *AC; DominatorTree *DT; LoopInfo *LI; TargetTransformInfo *TTI; OptimizationRemarkEmitter *ORE; }; - -} // end anonymous namespace - -namespace { +} // namespace /// Unfold the select instruction held in \p SIToUnfold by replacing it with /// control flow. @@ -185,9 +194,10 @@ namespace { /// created basic blocks into \p NewBBs. /// /// TODO: merge it with CodeGenPrepare::optimizeSelectInst() if possible. 
-void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold, - std::vector<SelectInstToUnfold> *NewSIsToUnfold, - std::vector<BasicBlock *> *NewBBs) { +void DFAJumpThreading::unfold(DomTreeUpdater *DTU, LoopInfo *LI, + SelectInstToUnfold SIToUnfold, + std::vector<SelectInstToUnfold> *NewSIsToUnfold, + std::vector<BasicBlock *> *NewBBs) { SelectInst *SI = SIToUnfold.getInst(); PHINode *SIUse = SIToUnfold.getUse(); assert(SI->hasOneUse()); @@ -253,7 +263,11 @@ void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold, // Insert the real conditional branch based on the original condition. StartBlockTerm->eraseFromParent(); - BranchInst::Create(EndBlock, NewBlock, SI->getCondition(), StartBlock); + auto *BI = + BranchInst::Create(EndBlock, NewBlock, SI->getCondition(), StartBlock); + if (!ProfcheckDisableMetadataFixes) + BI->setMetadata(LLVMContext::MD_prof, + SI->getMetadata(LLVMContext::MD_prof)); DTU->applyUpdates({{DominatorTree::Insert, StartBlock, EndBlock}, {DominatorTree::Insert, StartBlock, NewBlock}}); } else { @@ -288,7 +302,11 @@ void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold, // (Use) BranchInst::Create(EndBlock, NewBlockF); // Insert the real conditional branch based on the original condition. - BranchInst::Create(EndBlock, NewBlockF, SI->getCondition(), NewBlockT); + auto *BI = + BranchInst::Create(EndBlock, NewBlockF, SI->getCondition(), NewBlockT); + if (!ProfcheckDisableMetadataFixes) + BI->setMetadata(LLVMContext::MD_prof, + SI->getMetadata(LLVMContext::MD_prof)); DTU->applyUpdates({{DominatorTree::Insert, NewBlockT, NewBlockF}, {DominatorTree::Insert, NewBlockT, EndBlock}, {DominatorTree::Insert, NewBlockF, EndBlock}}); @@ -342,10 +360,12 @@ void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold, SI->eraseFromParent(); } +namespace { struct ClonedBlock { BasicBlock *BB; APInt State; ///< \p State corresponds to the next value of a switch stmnt. }; +} // namespace typedef std::deque<BasicBlock *> PathType; typedef std::vector<PathType> PathsType; @@ -375,6 +395,7 @@ inline raw_ostream &operator<<(raw_ostream &OS, const PathType &Path) { return OS; } +namespace { /// ThreadingPath is a path in the control flow of a loop that can be threaded /// by cloning necessary basic blocks and replacing conditional branches with /// unconditional ones. A threading path includes a list of basic blocks, the @@ -814,11 +835,13 @@ struct TransformDFA { : SwitchPaths(SwitchPaths), DT(DT), AC(AC), TTI(TTI), ORE(ORE), EphValues(EphValues) {} - void run() { + bool run() { if (isLegalAndProfitableToTransform()) { createAllExitPaths(); NumTransforms++; + return true; } + return false; } private: @@ -828,6 +851,7 @@ private: /// also returns false if it is illegal to clone some required block. bool isLegalAndProfitableToTransform() { CodeMetrics Metrics; + uint64_t NumClonedInst = 0; SwitchInst *Switch = SwitchPaths->getSwitchInst(); // Don't thread switch without multiple successors. @@ -837,7 +861,6 @@ private: // Note that DuplicateBlockMap is not being used as intended here. It is // just being used to ensure (BB, State) pairs are only counted once. 
    DuplicateBlockMap DuplicateMap;
-
    for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
      PathType PathBBs = TPath.getPath();
      APInt NextState = TPath.getExitValue();
@@ -848,6 +871,7 @@ private:
      BasicBlock *VisitedBB = getClonedBB(BB, NextState, DuplicateMap);
      if (!VisitedBB) {
        Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
+        NumClonedInst += BB->sizeWithoutDebug();
        DuplicateMap[BB].push_back({BB, NextState});
      }
@@ -865,6 +889,7 @@ private:
        if (VisitedBB)
          continue;
        Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
+        NumClonedInst += BB->sizeWithoutDebug();
        DuplicateMap[BB].push_back({BB, NextState});
      }
@@ -901,6 +926,22 @@ private:
      }
    }
+    // Too many cloned instructions slow down later optimizations, especially
+    // SLPVectorizer.
+    // TODO: Thread the switch partially before reaching the threshold.
+    uint64_t NumOrigInst = 0;
+    for (auto *BB : DuplicateMap.keys())
+      NumOrigInst += BB->sizeWithoutDebug();
+    if (double(NumClonedInst) / double(NumOrigInst) > MaxClonedRate) {
+      LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, too many "
+                           "instructions will be cloned\n");
+      ORE->emit([&]() {
+        return OptimizationRemarkMissed(DEBUG_TYPE, "NotProfitable", Switch)
+               << "Too many instructions will be cloned.";
+      });
+      return false;
+    }
+
    InstructionCost DuplicationCost = 0;
    unsigned JumpTableSize = 0;
@@ -951,8 +992,6 @@ private:
  /// Transform each threading path to effectively jump thread the DFA.
  void createAllExitPaths() {
-    DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Eager);
-
    // Move the switch block to the end of the path, since it will be duplicated
    BasicBlock *SwitchBlock = SwitchPaths->getSwitchBlock();
    for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
@@ -969,15 +1008,18 @@ private:
    SmallPtrSet<BasicBlock *, 16> BlocksToClean;
    BlocksToClean.insert_range(successors(SwitchBlock));
-    for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
-      createExitPath(NewDefs, TPath, DuplicateMap, BlocksToClean, &DTU);
-      NumPaths++;
-    }
+    {
+      DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Lazy);
+      for (const ThreadingPath &TPath : SwitchPaths->getThreadingPaths()) {
+        createExitPath(NewDefs, TPath, DuplicateMap, BlocksToClean, &DTU);
+        NumPaths++;
+      }
-    // After all paths are cloned, now update the last successor of the cloned
-    // path so it skips over the switch statement
-    for (ThreadingPath &TPath : SwitchPaths->getThreadingPaths())
-      updateLastSuccessor(TPath, DuplicateMap, &DTU);
+      // After all paths are cloned, now update the last successor of the cloned
+      // path so it skips over the switch statement
+      for (const ThreadingPath &TPath : SwitchPaths->getThreadingPaths())
+        updateLastSuccessor(TPath, DuplicateMap, &DTU);
+    }
    // For each instruction that was cloned and used outside, update its uses
    updateSSA(NewDefs);
@@ -993,7 +1035,7 @@ private:
  /// To remember the correct destination, we have to duplicate blocks
  /// corresponding to each state. Also update the terminating instruction of
  /// the predecessors, and phis in the successor blocks.
-  void createExitPath(DefMap &NewDefs, ThreadingPath &Path,
+  void createExitPath(DefMap &NewDefs, const ThreadingPath &Path,
                      DuplicateBlockMap &DuplicateMap,
                      SmallPtrSet<BasicBlock *, 16> &BlocksToClean,
                      DomTreeUpdater *DTU) {
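[Editor's note] To make the new dfa-max-cloned-rate gate concrete: with the default of 7.5, threading paths that would materialize, say, more than 300 cloned instructions over 40 distinct original instructions are rejected. A standalone hedged sketch of the check (not the pass's code):

// NumClonedInst counts every (block, state) clone the threading paths need;
// NumOrigInst counts each distinct original block once.
static bool clonedRateTooHigh(uint64_t NumClonedInst, uint64_t NumOrigInst,
                              double MaxClonedRate = 7.5) {
  if (NumOrigInst == 0) // Defensive; the pass always counts the switch block.
    return NumClonedInst != 0;
  return double(NumClonedInst) / double(NumOrigInst) > MaxClonedRate;
}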
@@ -1239,7 +1281,7 @@ private:
  ///
  /// Note that this is an optional step and would have been done in later
  /// optimizations, but it makes the CFG significantly easier to work with.
-  void updateLastSuccessor(ThreadingPath &TPath,
+  void updateLastSuccessor(const ThreadingPath &TPath,
                           DuplicateBlockMap &DuplicateMap,
                           DomTreeUpdater *DTU) {
    APInt NextState = TPath.getExitValue();
@@ -1336,6 +1378,7 @@ private:
  SmallPtrSet<const Value *, 32> EphValues;
  std::vector<ThreadingPath> TPaths;
};
+} // namespace

bool DFAJumpThreading::run(Function &F) {
  LLVM_DEBUG(dbgs() << "\nDFA Jump threading: " << F.getName() << "\n");
@@ -1402,9 +1445,8 @@ bool DFAJumpThreading::run(Function &F) {

  for (AllSwitchPaths SwitchPaths : ThreadableLoops) {
    TransformDFA Transform(&SwitchPaths, DT, AC, TTI, ORE, EphValues);
-    Transform.run();
-    MadeChanges = true;
-    LoopInfoBroken = true;
+    if (Transform.run())
+      MadeChanges = LoopInfoBroken = true;
  }

#ifdef EXPENSIVE_CHECKS
@@ -1415,8 +1457,6 @@ bool DFAJumpThreading::run(Function &F) {
  return MadeChanges;
}

-} // end anonymous namespace
-
/// Integrate with the new Pass Manager
PreservedAnalyses DFAJumpThreadingPass::run(Function &F,
                                            FunctionAnalysisManager &AM) {
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index bbd1ed6..5ba6f95f 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -970,6 +970,7 @@ Function *CodeExtractor::constructFunctionDeclaration(
    case Attribute::SanitizeMemTag:
    case Attribute::SanitizeRealtime:
    case Attribute::SanitizeRealtimeBlocking:
+    case Attribute::SanitizeAllocToken:
    case Attribute::SpeculativeLoadHardening:
    case Attribute::StackProtect:
    case Attribute::StackProtectReq:
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp index 21b2652..b6ca52e 100644 --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -3031,6 +3031,13 @@ static void combineMetadata(Instruction *K, const Instruction *J,
                       K->getContext(), MDNode::toCaptureComponents(JMD) |
                                            MDNode::toCaptureComponents(KMD)));
      break;
+    case LLVMContext::MD_alloc_token:
+      // Preserve !alloc_token if both K and J have it, and they are equal.
+      if (KMD == JMD)
+        K->setMetadata(Kind, JMD);
+      else
+        K->setMetadata(Kind, nullptr);
+      break;
    }
  }
  // Set !invariant.group from J if J has it. If both instructions have it
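[Editor's note] The !alloc_token case above is the usual "keep only when both sides agree" rule in combineMetadata. A hedged sketch of that semantic in isolation (helper name invented):

// When instructions J and K are merged, per-site metadata may only survive
// if both sites carried the same node; otherwise drop it to stay sound.
static void mergeEqualOnlyMD(Instruction *K, const Instruction *J,
                             unsigned Kind) {
  MDNode *JMD = J->getMetadata(Kind);
  MDNode *KMD = K->getMetadata(Kind);
  // MDNodes are uniqued in the LLVMContext, so pointer equality is semantic
  // equality: equal nodes are kept, anything else is erased.
  K->setMetadata(Kind, JMD == KMD ? JMD : nullptr);
}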
diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp index bf882d7..6312831 100644 --- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -201,18 +201,27 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,
/// unroll count is non-zero.
///
/// This function performs the following:
-/// - Update PHI nodes at the unrolling loop exit and epilog loop exit
-/// - Create PHI nodes at the unrolling loop exit to combine
-///   values that exit the unrolling loop code and jump around it.
+/// - Update PHI nodes at the epilog loop exit
+/// - Create PHI nodes at the unrolling loop exit and epilog preheader to
+///   combine values that exit the unrolling loop code and jump around it.
/// - Update PHI operands in the epilog loop by the new PHI nodes
-/// - Branch around the epilog loop if extra iters (ModVal) is zero.
+/// - At the unrolling loop exit, branch around the epilog loop if extra iters
+///   (ModVal) is zero.
+/// - At the epilog preheader, add an llvm.assume call that extra iters is
+///   non-zero. If the unrolling loop exit is the predecessor, the above new
+///   branch guarantees that assumption. If the unrolling loop preheader is the
+///   predecessor, then the required first iteration from the original loop has
+///   yet to be executed, so it must be executed in the epilog loop. If we
+///   later unroll the epilog loop, that llvm.assume call somehow enables
+///   ScalarEvolution to compute an epilog loop maximum trip count, which
+///   enables eliminating the branch at the end of the final unrolled epilog
+///   iteration.
///
static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
                          BasicBlock *Exit, BasicBlock *PreHeader,
                          BasicBlock *EpilogPreHeader, BasicBlock *NewPreHeader,
                          ValueToValueMapTy &VMap, DominatorTree *DT,
                          LoopInfo *LI, bool PreserveLCSSA, ScalarEvolution &SE,
-                          unsigned Count) {
+                          unsigned Count, AssumptionCache &AC) {
  BasicBlock *Latch = L->getLoopLatch();
  assert(Latch && "Loop must have a latch");
  BasicBlock *EpilogLatch = cast<BasicBlock>(VMap[Latch]);
@@ -231,7 +240,7 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
  //   EpilogLatch
  // Exit (EpilogPN)

-  // Update PHI nodes at NewExit and Exit.
+  // Update PHI nodes at Exit.
  for (PHINode &PN : NewExit->phis()) {
    // PN should be used in another PHI located in Exit block as
    // Exit was split by SplitBlockPredecessors into Exit and NewExit
@@ -246,15 +255,11 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
    // epilogue edges have already been added.
    //
    // There is EpilogPreHeader incoming block instead of NewExit as
-    // NewExit was spilt 1 more time to get EpilogPreHeader.
+    // NewExit was split 1 more time to get EpilogPreHeader.
    assert(PN.hasOneUse() && "The phi should have 1 use");
    PHINode *EpilogPN = cast<PHINode>(PN.use_begin()->getUser());
    assert(EpilogPN->getParent() == Exit && "EpilogPN should be in Exit block");

-    // Add incoming PreHeader from branch around the Loop
-    PN.addIncoming(PoisonValue::get(PN.getType()), PreHeader);
-    SE.forgetValue(&PN);
-
    Value *V = PN.getIncomingValueForBlock(Latch);
    Instruction *I = dyn_cast<Instruction>(V);
    if (I && L->contains(I))
@@ -271,35 +276,52 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
                                NewExit);
    // Now PHIs should look like:
    // NewExit:
-    //   PN = PHI [I, Latch], [poison, PreHeader]
+    //   PN = PHI [I, Latch]
    // ...
    // Exit:
    //   EpilogPN = PHI [PN, NewExit], [VMap[I], EpilogLatch]
  }

-  // Create PHI nodes at NewExit (from the unrolling loop Latch and PreHeader).
-  // Update corresponding PHI nodes in epilog loop.
+  // Create PHI nodes at NewExit (from the unrolling loop Latch) and at
+  // EpilogPreHeader (from PreHeader and NewExit). Update corresponding PHI
+  // nodes in epilog loop.
  for (BasicBlock *Succ : successors(Latch)) {
    // Skip this as we already updated phis in exit blocks.
    if (!L->contains(Succ))
      continue;
+
+    // Succ here appears to always be just L->getHeader(). Otherwise, how do we
+    // know its corresponding epilog block (from VMap) is EpilogHeader and thus
+    // EpilogPreHeader is the right incoming block for VPN, as set below?
+    // TODO: Can we thus avoid the enclosing loop over successors?
+    assert(Succ == L->getHeader() &&
+           "Expect the only in-loop successor of latch to be the loop header");
+
    for (PHINode &PN : Succ->phis()) {
-      // Add new PHI nodes to the loop exit block and update epilog
-      // PHIs with the new PHI values.
-      PHINode *NewPN = PHINode::Create(PN.getType(), 2, PN.getName() + ".unr");
-      NewPN->insertBefore(NewExit->getFirstNonPHIIt());
-      // Adding a value to the new PHI node from the unrolling loop preheader.
- NewPN->addIncoming(PN.getIncomingValueForBlock(NewPreHeader), PreHeader); - // Adding a value to the new PHI node from the unrolling loop latch. - NewPN->addIncoming(PN.getIncomingValueForBlock(Latch), Latch); + // Add new PHI nodes to the loop exit block. + PHINode *NewPN0 = PHINode::Create(PN.getType(), /*NumReservedValues=*/1, + PN.getName() + ".unr"); + NewPN0->insertBefore(NewExit->getFirstNonPHIIt()); + // Add value to the new PHI node from the unrolling loop latch. + NewPN0->addIncoming(PN.getIncomingValueForBlock(Latch), Latch); + + // Add new PHI nodes to EpilogPreHeader. + PHINode *NewPN1 = PHINode::Create(PN.getType(), /*NumReservedValues=*/2, + PN.getName() + ".epil.init"); + NewPN1->insertBefore(EpilogPreHeader->getFirstNonPHIIt()); + // Add value to the new PHI node from the unrolling loop preheader. + NewPN1->addIncoming(PN.getIncomingValueForBlock(NewPreHeader), PreHeader); + // Add value to the new PHI node from the epilog loop guard. + NewPN1->addIncoming(NewPN0, NewExit); // Update the existing PHI node operand with the value from the new PHI // node. Corresponding instruction in epilog loop should be PHI. PHINode *VPN = cast<PHINode>(VMap[&PN]); - VPN->setIncomingValueForBlock(EpilogPreHeader, NewPN); + VPN->setIncomingValueForBlock(EpilogPreHeader, NewPN1); } } + // In NewExit, branch around the epilog loop if no extra iters. Instruction *InsertPt = NewExit->getTerminator(); IRBuilder<> B(InsertPt); Value *BrLoopExit = B.CreateIsNotNull(ModVal, "lcmp.mod"); @@ -308,7 +330,7 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, SmallVector<BasicBlock*, 4> Preds(predecessors(Exit)); SplitBlockPredecessors(Exit, Preds, ".epilog-lcssa", DT, LI, nullptr, PreserveLCSSA); - // Add the branch to the exit block (around the unrolling loop) + // Add the branch to the exit block (around the epilog loop) MDNode *BranchWeights = nullptr; if (hasBranchWeightMD(*Latch->getTerminator())) { // Assume equal distribution in interval [0, Count). @@ -322,10 +344,11 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, DT->changeImmediateDominator(Exit, NewDom); } - // Split the main loop exit to maintain canonicalization guarantees. - SmallVector<BasicBlock*, 4> NewExitPreds{Latch}; - SplitBlockPredecessors(NewExit, NewExitPreds, ".loopexit", DT, LI, nullptr, - PreserveLCSSA); + // In EpilogPreHeader, assume extra iters is non-zero. + IRBuilder<> B2(EpilogPreHeader, EpilogPreHeader->getFirstNonPHIIt()); + Value *ModIsNotNull = B2.CreateIsNotNull(ModVal, "lcmp.mod"); + AssumeInst *AI = cast<AssumeInst>(B2.CreateAssumption(ModIsNotNull)); + AC.registerAssumption(AI); } /// Create a clone of the blocks in a loop and connect them together. A new @@ -795,7 +818,8 @@ bool llvm::UnrollRuntimeLoopRemainder( ConstantInt::get(BECount->getType(), Count - 1)) : B.CreateIsNotNull(ModVal, "lcmp.mod"); - BasicBlock *RemainderLoop = UseEpilogRemainder ? NewExit : PrologPreHeader; + BasicBlock *RemainderLoop = + UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader; BasicBlock *UnrollingLoop = UseEpilogRemainder ? NewPreHeader : PrologExit; // Branch to either remainder (extra iterations) loop or unrolling loop. 
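[Editor's note] For the llvm.assume that ConnectEpilog now inserts into the epilog preheader (see the hunk above), a minimal hedged sketch of the pattern, with an invented helper name; the real code operates on ModVal and EpilogPreHeader directly:

// Materialize "X != 0" as an llvm.assume after the block's PHIs and register
// it so AssumptionCache-based analyses see it without rescanning the function.
static void assumeNonZero(BasicBlock *BB, Value *X, AssumptionCache &AC) {
  IRBuilder<> B(BB, BB->getFirstNonPHIIt());
  Value *NonZero = B.CreateIsNotNull(X, "lcmp.mod");
  auto *AI = cast<AssumeInst>(B.CreateAssumption(NonZero));
  AC.registerAssumption(AI);
}

Registering the assumption keeps the cache coherent, so later consumers (for example SCEV-based trip-count reasoning, as the comment above describes) can pick up the new fact without a rebuild.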
  MDNode *BranchWeights = nullptr;
@@ -808,7 +832,7 @@ bool llvm::UnrollRuntimeLoopRemainder(
  PreHeaderBR->eraseFromParent();
  if (DT) {
    if (UseEpilogRemainder)
-      DT->changeImmediateDominator(NewExit, PreHeader);
+      DT->changeImmediateDominator(EpilogPreHeader, PreHeader);
    else
      DT->changeImmediateDominator(PrologExit, PreHeader);
  }
@@ -880,7 +904,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
  // from both the original loop and the remainder code reaching the exit
  // blocks. While the IDom of these exit blocks were from the original loop,
  // now the IDom is the preheader (which decides whether the original loop or
-  // remainder code should run).
+  // remainder code should run) unless the block still has just the original
+  // predecessor (such as NewExit in the case of an epilog remainder).
  if (DT && !L->getExitingBlock()) {
    SmallVector<BasicBlock *, 16> ChildrenToUpdate;
    // NB! We have to examine the dom children of all loop blocks, not just
@@ -891,7 +916,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
      auto *DomNodeBB = DT->getNode(BB);
      for (auto *DomChild : DomNodeBB->children()) {
        auto *DomChildBB = DomChild->getBlock();
-        if (!L->contains(LI->getLoopFor(DomChildBB)))
+        if (!L->contains(LI->getLoopFor(DomChildBB)) &&
+            DomChildBB->getUniquePredecessor() != BB)
          ChildrenToUpdate.push_back(DomChildBB);
      }
    }
@@ -930,7 +956,7 @@ bool llvm::UnrollRuntimeLoopRemainder(
    // Connect the epilog code to the original loop and update the
    // PHI functions.
    ConnectEpilog(L, ModVal, NewExit, LatchExit, PreHeader, EpilogPreHeader,
-                  NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count);
+                  NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count, *AC);
    // Update counter in loop for unrolling.
    // Use an incrementing IV. Pre-incr/post-incr is backedge/trip count.
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index c167dd7..fb696be 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -2263,8 +2263,7 @@ public:
  /// debug location \p DL.
  VPWidenPHIRecipe(PHINode *Phi, VPValue *Start = nullptr,
                   DebugLoc DL = DebugLoc::getUnknown(), const Twine &Name = "")
-      : VPSingleDefRecipe(VPDef::VPWidenPHISC, ArrayRef<VPValue *>(), Phi, DL),
-        Name(Name.str()) {
+      : VPSingleDefRecipe(VPDef::VPWidenPHISC, {}, Phi, DL), Name(Name.str()) {
    if (Start)
      addOperand(Start);
  }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp index c8212af..b36298f 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp @@ -763,6 +763,8 @@ void VPlanTransforms::addMinimumVectorEpilogueIterationCheck(
  // Add the minimum iteration check for the epilogue vector loop.
  VPValue *TC = Plan.getOrAddLiveIn(TripCount);
  VPBuilder Builder(cast<VPBasicBlock>(Plan.getEntry()));
+  VPValue *VFxUF = Builder.createExpandSCEV(SE.getElementCount(
+      TripCount->getType(), (EpilogueVF * EpilogueUF), SCEV::FlagNUW));
  VPValue *Count = Builder.createNaryOp(
      Instruction::Sub, {TC, Plan.getOrAddLiveIn(VectorTripCount)},
      DebugLoc::getUnknown(), "n.vec.remaining");
@@ -770,9 +772,6 @@ void VPlanTransforms::addMinimumVectorEpilogueIterationCheck(
  // Generate code to check if the loop's trip count is less than VF * UF of
  // the vector epilogue loop.
  auto P = RequiresScalarEpilogue ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_ULT;
-  VPValue *VFxUF = Builder.createExpandSCEV(SE.getElementCount(
-      TripCount->getType(), (EpilogueVF * EpilogueUF), SCEV::FlagNUW));
-
  auto *CheckMinIters = Builder.createICmp(
      P, Count, VFxUF, DebugLoc::getUnknown(), "min.epilog.iters.check");
  VPInstruction *Branch =
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index ebf833e..c8a2d84 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -3180,9 +3180,8 @@ expandVPWidenIntOrFpInduction(VPWidenIntOrFpInductionRecipe *WidenIVR,
                              DebugLoc::getUnknown(), "induction");

  // Create the widened phi of the vector IV.
-  auto *WidePHI = new VPWidenPHIRecipe(WidenIVR->getPHINode(), nullptr,
+  auto *WidePHI = new VPWidenPHIRecipe(WidenIVR->getPHINode(), Init,
                                       WidenIVR->getDebugLoc(), "vec.ind");
-  WidePHI->addOperand(Init);
  WidePHI->insertBefore(WidenIVR);

  // Create the backedge value for the vector IV.
@@ -3545,8 +3544,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
  VPValue *A, *B;
  VPValue *Tmp = nullptr;
  // Sub reductions could have a sub between the add reduction and vec op.
-  if (match(VecOp,
-            m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Tmp)))) {
+  if (match(VecOp, m_Sub(m_ZeroInt(), m_VPValue(Tmp)))) {
    Sub = VecOp->getDefiningRecipe();
    VecOp = Tmp;
  }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp index 0599930..66748c5 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp @@ -71,8 +71,8 @@ bool vputils::isHeaderMask(const VPValue *V, VPlan &Plan) {
                 m_Specific(&Plan.getVF()))) ||
         IsWideCanonicalIV(A));

-  return match(V, m_Binary<Instruction::ICmp>(m_VPValue(A), m_VPValue(B))) &&
-         IsWideCanonicalIV(A) && B == Plan.getOrCreateBackedgeTakenCount();
+  return match(V, m_ICmp(m_VPValue(A), m_VPValue(B))) && IsWideCanonicalIV(A) &&
+         B == Plan.getOrCreateBackedgeTakenCount();
}

const SCEV *vputils::getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE) {
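[Editor's note] The last two hunks switch from spelled-out m_Binary forms to the dedicated VPlanPatternMatch shorthands used elsewhere in VPlan. A hedged sketch of how the negation matcher reads in isolation (function name invented):

// Matches "sub 0, X" in VPlan and binds X; equivalent to the replaced
// m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(X)) spelling.
static bool matchNegation(VPValue *V, VPValue *&X) {
  using namespace llvm::VPlanPatternMatch;
  return match(V, m_Sub(m_ZeroInt(), m_VPValue(X)));
}

Likewise, m_ICmp(m_VPValue(A), m_VPValue(B)) matches any integer compare while binding both operands, which is all isHeaderMask needs before its follow-up checks.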