diff options
Diffstat (limited to 'llvm/lib')
119 files changed, 5332 insertions, 1988 deletions
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index a5ba197..e9e2e7d 100755 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -4056,8 +4056,8 @@ static Constant *ConstantFoldFixedVectorCall( switch (IntrinsicID) { case Intrinsic::masked_load: { auto *SrcPtr = Operands[0]; - auto *Mask = Operands[2]; - auto *Passthru = Operands[3]; + auto *Mask = Operands[1]; + auto *Passthru = Operands[2]; Constant *VecData = ConstantFoldLoadFromConstPtr(SrcPtr, FVTy, DL); diff --git a/llvm/lib/Analysis/DependenceAnalysis.cpp b/llvm/lib/Analysis/DependenceAnalysis.cpp index 805b682..0a8c2f8 100644 --- a/llvm/lib/Analysis/DependenceAnalysis.cpp +++ b/llvm/lib/Analysis/DependenceAnalysis.cpp @@ -128,6 +128,18 @@ static cl::opt<bool> RunSIVRoutinesOnly( "The purpose is mainly to exclude the influence of those routines " "in regression tests for SIV routines.")); +// TODO: This flag is disabled by default because it is still under development. +// Enable it or delete this flag when the feature is ready. +static cl::opt<bool> EnableMonotonicityCheck( + "da-enable-monotonicity-check", cl::init(false), cl::Hidden, + cl::desc("Check if the subscripts are monotonic. If they're not, dependence " + "is reported as unknown.")); + +static cl::opt<bool> DumpMonotonicityReport( + "da-dump-monotonicity-report", cl::init(false), cl::Hidden, + cl::desc( + "When printing analysis, dump the results of monotonicity checks.")); + //===----------------------------------------------------------------------===// // basics @@ -177,13 +189,196 @@ void DependenceAnalysisWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequiredTransitive<LoopInfoWrapperPass>(); } +namespace { + +/// The property of monotonicity of a SCEV. To define the monotonicity, assume +/// a SCEV defined within N-nested loops. Let i_k denote the iteration number +/// of the k-th loop. Then we can regard the SCEV as an N-ary function: +/// +/// F(i_1, i_2, ..., i_N) +/// +/// The domain of i_k is the closed range [0, BTC_k], where BTC_k is the +/// backedge-taken count of the k-th loop. +/// +/// A function F is said to be "monotonically increasing with respect to the +/// k-th loop" if x <= y implies the following condition: +/// +/// F(i_1, ..., i_{k-1}, x, i_{k+1}, ..., i_N) <= +/// F(i_1, ..., i_{k-1}, y, i_{k+1}, ..., i_N) +/// +/// where i_1, ..., i_{k-1}, i_{k+1}, ..., i_N, x, and y are elements of their +/// respective domains. +/// +/// Likewise, F is "monotonically decreasing with respect to the k-th loop" +/// if x <= y implies +/// +/// F(i_1, ..., i_{k-1}, x, i_{k+1}, ..., i_N) >= +/// F(i_1, ..., i_{k-1}, y, i_{k+1}, ..., i_N) +/// +/// A function F that is monotonically increasing or decreasing with respect to +/// the k-th loop is simply called "monotonic with respect to the k-th loop". +/// +/// A function F is said to be "multivariate monotonic" when it is monotonic +/// with respect to all of the N loops. +/// +/// Since integer comparison can be either signed or unsigned, we need to +/// distinguish monotonicity in the signed sense from that in the unsigned +/// sense. Note that the inequality "x <= y" merely indicates loop progression +/// and is not affected by the difference between signed and unsigned order. +/// +/// Currently we only consider monotonicity in a signed sense. +enum class SCEVMonotonicityType { + /// We don't know anything about the monotonicity of the SCEV.
+ Unknown, + + /// The SCEV is loop-invariant with respect to the outermost loop. In other + /// words, the function F corresponding to the SCEV is a constant function. + Invariant, + + /// The function F corresponding to the SCEV is multivariate monotonic in a + /// signed sense. Note that a multivariate monotonic function may also be a + /// constant function. The order employed in the definition of monotonicity + /// is not a strict order. + MultivariateSignedMonotonic, +}; + +struct SCEVMonotonicity { + SCEVMonotonicity(SCEVMonotonicityType Type, + const SCEV *FailurePoint = nullptr); + + SCEVMonotonicityType getType() const { return Type; } + + const SCEV *getFailurePoint() const { return FailurePoint; } + + bool isUnknown() const { return Type == SCEVMonotonicityType::Unknown; } + + void print(raw_ostream &OS, unsigned Depth) const; + +private: + SCEVMonotonicityType Type; + + /// The subexpression that caused Unknown. Mainly for debugging purposes. + const SCEV *FailurePoint; +}; + +/// Check the monotonicity of a SCEV. Since dependence tests (SIV, MIV, etc.) +/// assume that subscript expressions are (multivariate) monotonic, we need to +/// verify this property before applying those tests. Violating this assumption +/// may cause them to produce incorrect results. +struct SCEVMonotonicityChecker + : public SCEVVisitor<SCEVMonotonicityChecker, SCEVMonotonicity> { + + SCEVMonotonicityChecker(ScalarEvolution *SE) : SE(SE) {} + + /// Check the monotonicity of \p Expr. \p Expr must be of integer type. If \p + /// OutermostLoop is not null, \p Expr must be defined in \p OutermostLoop or + /// one of its nested loops. + SCEVMonotonicity checkMonotonicity(const SCEV *Expr, + const Loop *OutermostLoop); + +private: + ScalarEvolution *SE; + + /// The outermost loop that DA is analyzing. + const Loop *OutermostLoop; + + /// A helper to classify \p Expr as either Invariant or Unknown. + SCEVMonotonicity invariantOrUnknown(const SCEV *Expr); + + /// Return true if \p Expr is loop-invariant with respect to the outermost + /// loop. + bool isLoopInvariant(const SCEV *Expr) const; + + /// A helper to create an Unknown SCEVMonotonicity. + SCEVMonotonicity createUnknown(const SCEV *FailurePoint) { + return SCEVMonotonicity(SCEVMonotonicityType::Unknown, FailurePoint); + } + + SCEVMonotonicity visitAddRecExpr(const SCEVAddRecExpr *Expr); + + SCEVMonotonicity visitConstant(const SCEVConstant *) { + return SCEVMonotonicity(SCEVMonotonicityType::Invariant); + } + SCEVMonotonicity visitVScale(const SCEVVScale *) { + return SCEVMonotonicity(SCEVMonotonicityType::Invariant); + } + + // TODO: Handle more cases.
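+ // For example, a sum of AddRecs from two sibling loops, such as
+ // {0,+,1}<nsw> + {0,+,2}<nsw>, may still be multivariate monotonic, but the
+ // visitors below conservatively classify any such composite expression via
+ // invariantOrUnknown.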
+ SCEVMonotonicity visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { + return invariantOrUnknown(Expr); + } + SCEVMonotonicity visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { + return invariantOrUnknown(Expr); + } + SCEVMonotonicity visitAddExpr(const SCEVAddExpr *Expr) { + return invariantOrUnknown(Expr); + } + SCEVMonotonicity visitMulExpr(const SCEVMulExpr *Expr) { + return invariantOrUnknown(Expr); + } + SCEVMonotonicity visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { + return invariantOrUnknown(Expr); + } + SCEVMonotonicity visitTruncateExpr(const SCEVTruncateExpr *Expr) { + return invariantOrUnknown(Expr); + } + SCEVMonotonicity visitUDivExpr(const SCEVUDivExpr *Expr) { + return invariantOrUnknown(Expr); + } + SCEVMonotonicity visitSMaxExpr(const SCEVSMaxExpr *Expr) { + return invariantOrUnknown(Expr); + } + SCEVMonotonicity visitUMaxExpr(const SCEVUMaxExpr *Expr) { + return invariantOrUnknown(Expr); + } + SCEVMonotonicity visitSMinExpr(const SCEVSMinExpr *Expr) { + return invariantOrUnknown(Expr); + } + SCEVMonotonicity visitUMinExpr(const SCEVUMinExpr *Expr) { + return invariantOrUnknown(Expr); + } + SCEVMonotonicity visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) { + return invariantOrUnknown(Expr); + } + SCEVMonotonicity visitUnknown(const SCEVUnknown *Expr) { + return invariantOrUnknown(Expr); + } + SCEVMonotonicity visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { + return invariantOrUnknown(Expr); + } + + friend struct SCEVVisitor<SCEVMonotonicityChecker, SCEVMonotonicity>; +}; + +} // anonymous namespace + // Used to test the dependence analyzer. // Looks through the function, noting instructions that may access memory. // Calls depends() on every possible pair and prints out the result. // Ignores all other instructions. 
static void dumpExampleDependence(raw_ostream &OS, DependenceInfo *DA, - ScalarEvolution &SE, bool NormalizeResults) { + ScalarEvolution &SE, LoopInfo &LI, + bool NormalizeResults) { auto *F = DA->getFunction(); + + if (DumpMonotonicityReport) { + SCEVMonotonicityChecker Checker(&SE); + OS << "Monotonicity check:\n"; + for (Instruction &Inst : instructions(F)) { + if (!isa<LoadInst>(Inst) && !isa<StoreInst>(Inst)) + continue; + Value *Ptr = getLoadStorePointerOperand(&Inst); + const Loop *L = LI.getLoopFor(Inst.getParent()); + const SCEV *PtrSCEV = SE.getSCEVAtScope(Ptr, L); + const SCEV *AccessFn = SE.removePointerBase(PtrSCEV); + SCEVMonotonicity Mon = Checker.checkMonotonicity(AccessFn, L); + OS.indent(2) << "Inst: " << Inst << "\n"; + OS.indent(4) << "Expr: " << *AccessFn << "\n"; + Mon.print(OS, 4); + } + OS << "\n"; + } + for (inst_iterator SrcI = inst_begin(F), SrcE = inst_end(F); SrcI != SrcE; ++SrcI) { if (SrcI->mayReadOrWriteMemory()) { @@ -235,7 +430,8 @@ static void dumpExampleDependence(raw_ostream &OS, DependenceInfo *DA, void DependenceAnalysisWrapperPass::print(raw_ostream &OS, const Module *) const { dumpExampleDependence( - OS, info.get(), getAnalysis<ScalarEvolutionWrapperPass>().getSE(), false); + OS, info.get(), getAnalysis<ScalarEvolutionWrapperPass>().getSE(), + getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), false); } PreservedAnalyses @@ -244,7 +440,7 @@ DependenceAnalysisPrinterPass::run(Function &F, FunctionAnalysisManager &FAM) { << "':\n"; dumpExampleDependence(OS, &FAM.getResult<DependenceAnalysis>(F), FAM.getResult<ScalarEvolutionAnalysis>(F), - NormalizeResults); + FAM.getResult<LoopAnalysis>(F), NormalizeResults); return PreservedAnalyses::all(); } @@ -671,6 +867,81 @@ bool DependenceInfo::intersectConstraints(Constraint *X, const Constraint *Y) { } //===----------------------------------------------------------------------===// +// SCEVMonotonicity + +SCEVMonotonicity::SCEVMonotonicity(SCEVMonotonicityType Type, + const SCEV *FailurePoint) + : Type(Type), FailurePoint(FailurePoint) { + assert( + ((Type == SCEVMonotonicityType::Unknown) == (FailurePoint != nullptr)) && + "FailurePoint must be provided iff Type is Unknown"); +} + +void SCEVMonotonicity::print(raw_ostream &OS, unsigned Depth) const { + OS.indent(Depth) << "Monotonicity: "; + switch (Type) { + case SCEVMonotonicityType::Unknown: + assert(FailurePoint && "FailurePoint must be provided for Unknown"); + OS << "Unknown\n"; + OS.indent(Depth) << "Reason: " << *FailurePoint << "\n"; + break; + case SCEVMonotonicityType::Invariant: + OS << "Invariant\n"; + break; + case SCEVMonotonicityType::MultivariateSignedMonotonic: + OS << "MultivariateSignedMonotonic\n"; + break; + } +} + +bool SCEVMonotonicityChecker::isLoopInvariant(const SCEV *Expr) const { + return !OutermostLoop || SE->isLoopInvariant(Expr, OutermostLoop); +} + +SCEVMonotonicity SCEVMonotonicityChecker::invariantOrUnknown(const SCEV *Expr) { + if (isLoopInvariant(Expr)) + return SCEVMonotonicity(SCEVMonotonicityType::Invariant); + return createUnknown(Expr); +} + +SCEVMonotonicity +SCEVMonotonicityChecker::checkMonotonicity(const SCEV *Expr, + const Loop *OutermostLoop) { + assert(Expr->getType()->isIntegerTy() && "Expr must be integer type"); + this->OutermostLoop = OutermostLoop; + return visit(Expr); +} + +/// We only care about an affine AddRec at the moment. For an affine AddRec, +/// the monotonicity can be inferred from its nowrap property. For example, let +/// X and Y be loop-invariant, and assume Y is non-negative. 
An AddRec +/// {X,+,Y}<nsw> implies: +/// +/// X <=s (X + Y) <=s ((X + Y) + Y) <=s ... +/// +/// Thus, we can conclude that the AddRec is monotonically increasing with +/// respect to the associated loop in a signed sense. Similar reasoning +/// applies when Y is non-positive, leading to a monotonically decreasing +/// AddRec. +SCEVMonotonicity +SCEVMonotonicityChecker::visitAddRecExpr(const SCEVAddRecExpr *Expr) { + if (!Expr->isAffine() || !Expr->hasNoSignedWrap()) + return createUnknown(Expr); + + const SCEV *Start = Expr->getStart(); + const SCEV *Step = Expr->getStepRecurrence(*SE); + + SCEVMonotonicity StartMon = visit(Start); + if (StartMon.isUnknown()) + return StartMon; + + if (!isLoopInvariant(Step)) + return createUnknown(Expr); + + return SCEVMonotonicity(SCEVMonotonicityType::MultivariateSignedMonotonic); +} + +//===----------------------------------------------------------------------===// // DependenceInfo methods // For debugging purposes. Dumps a dependence to OS. @@ -3488,10 +3759,19 @@ bool DependenceInfo::tryDelinearize(Instruction *Src, Instruction *Dst, // resize Pair to contain as many pairs of subscripts as the delinearization // has found, and then initialize the pairs following the delinearization. Pair.resize(Size); + SCEVMonotonicityChecker MonChecker(SE); + const Loop *OutermostLoop = SrcLoop ? SrcLoop->getOutermostLoop() : nullptr; for (int I = 0; I < Size; ++I) { Pair[I].Src = SrcSubscripts[I]; Pair[I].Dst = DstSubscripts[I]; unifySubscriptType(&Pair[I]); + + if (EnableMonotonicityCheck) { + if (MonChecker.checkMonotonicity(Pair[I].Src, OutermostLoop).isUnknown()) + return false; + if (MonChecker.checkMonotonicity(Pair[I].Dst, OutermostLoop).isUnknown()) + return false; + } } return true; @@ -3824,6 +4104,14 @@ DependenceInfo::depends(Instruction *Src, Instruction *Dst, Pair[0].Src = SrcEv; Pair[0].Dst = DstEv; + SCEVMonotonicityChecker MonChecker(SE); + const Loop *OutermostLoop = SrcLoop ? SrcLoop->getOutermostLoop() : nullptr; + if (EnableMonotonicityCheck) + if (MonChecker.checkMonotonicity(Pair[0].Src, OutermostLoop).isUnknown() || + MonChecker.checkMonotonicity(Pair[0].Dst, OutermostLoop).isUnknown()) + return std::make_unique<Dependence>(Src, Dst, + SCEVUnionPredicate(Assume, *SE)); + if (Delinearize) { if (tryDelinearize(Src, Dst, Pair)) { LLVM_DEBUG(dbgs() << " delinearized\n"); diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index e08ef60..8da51d0 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -5106,32 +5106,33 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr, return Ptr; // The following transforms are only safe if the ptrtoint cast - // doesn't truncate the pointers. + // doesn't truncate the address of the pointers. The non-address bits + // must be the same, as the underlying objects are the same. + if (Indices[0]->getType()->getScalarSizeInBits() >= + Q.DL.getAddressSizeInBits(AS)) { auto CanSimplify = [GEPTy, &P, Ptr]() -> bool { return P->getType() == GEPTy && getUnderlyingObject(P) == getUnderlyingObject(Ptr); }; // getelementptr V, (sub P, V) -> P if P points to a type of size 1.
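      // Illustrative IR shape (hypothetical values), assuming %p and %v point
      // into the same underlying object:
      //   %pa = ptrtoaddr ptr %p to i64
      //   %va = ptrtoaddr ptr %v to i64
      //   %d  = sub i64 %pa, %va
      //   %g  = getelementptr i8, ptr %v, i64 %d   ; simplifies to %p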
if (TyAllocSize == 1 && - match(Indices[0], - m_Sub(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Specific(Ptr)))) && + match(Indices[0], m_Sub(m_PtrToIntOrAddr(m_Value(P)), + m_PtrToIntOrAddr(m_Specific(Ptr)))) && CanSimplify()) return P; // getelementptr V, (ashr (sub P, V), C) -> P if P points to a type of // size 1 << C. - if (match(Indices[0], m_AShr(m_Sub(m_PtrToInt(m_Value(P)), - m_PtrToInt(m_Specific(Ptr))), + if (match(Indices[0], m_AShr(m_Sub(m_PtrToIntOrAddr(m_Value(P)), + m_PtrToIntOrAddr(m_Specific(Ptr))), m_ConstantInt(C))) && TyAllocSize == 1ULL << C && CanSimplify()) return P; // getelementptr V, (sdiv (sub P, V), C) -> P if P points to a type of // size C. - if (match(Indices[0], m_SDiv(m_Sub(m_PtrToInt(m_Value(P)), - m_PtrToInt(m_Specific(Ptr))), + if (match(Indices[0], m_SDiv(m_Sub(m_PtrToIntOrAddr(m_Value(P)), + m_PtrToIntOrAddr(m_Specific(Ptr))), m_SpecificInt(TyAllocSize))) && CanSimplify()) return P; @@ -5440,9 +5441,10 @@ static Value *simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, // ptrtoint (ptradd (Ptr, X - ptrtoint(Ptr))) -> X Value *Ptr, *X; - if (CastOpc == Instruction::PtrToInt && - match(Op, m_PtrAdd(m_Value(Ptr), - m_Sub(m_Value(X), m_PtrToInt(m_Deferred(Ptr))))) && + if ((CastOpc == Instruction::PtrToInt || CastOpc == Instruction::PtrToAddr) && + match(Op, + m_PtrAdd(m_Value(Ptr), + m_Sub(m_Value(X), m_PtrToIntOrAddr(m_Deferred(Ptr))))) && X->getType() == Ty && Ty == Q.DL.getIndexType(Ptr->getType())) return X; @@ -6987,8 +6989,8 @@ static Value *simplifyIntrinsic(CallBase *Call, Value *Callee, switch (IID) { case Intrinsic::masked_load: case Intrinsic::masked_gather: { - Value *MaskArg = Args[2]; - Value *PassthruArg = Args[3]; + Value *MaskArg = Args[1]; + Value *PassthruArg = Args[2]; // If the mask is all zeros or undef, the "passthru" argument is the result. if (maskIsAllZeroOrUndef(MaskArg)) return PassthruArg; diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp index 0e5bc48..df75999 100644 --- a/llvm/lib/Analysis/LazyValueInfo.cpp +++ b/llvm/lib/Analysis/LazyValueInfo.cpp @@ -947,9 +947,8 @@ LazyValueInfoImpl::solveBlockValueSelect(SelectInst *SI, BasicBlock *BB) { /*UseBlockValue*/ false)); } - ValueLatticeElement Result = TrueVal; - Result.mergeIn(FalseVal); - return Result; + TrueVal.mergeIn(FalseVal); + return TrueVal; } std::optional<ConstantRange> @@ -1778,9 +1777,8 @@ ValueLatticeElement LazyValueInfoImpl::getValueInBlock(Value *V, BasicBlock *BB, assert(OptResult && "Value not available after solving"); } - ValueLatticeElement Result = *OptResult; - LLVM_DEBUG(dbgs() << " Result = " << Result << "\n"); - return Result; + LLVM_DEBUG(dbgs() << " Result = " << *OptResult << "\n"); + return *OptResult; } ValueLatticeElement LazyValueInfoImpl::getValueAt(Value *V, Instruction *CxtI) { diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp index 1d1a5560..9a5ae2a 100644 --- a/llvm/lib/Analysis/MLInlineAdvisor.cpp +++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp @@ -324,32 +324,44 @@ void MLInlineAdvisor::onSuccessfulInlining(const MLInlineAdvice &Advice, FAM.invalidate(*Caller, PA); } Advice.updateCachedCallerFPI(FAM); - int64_t IRSizeAfter = - getIRSize(*Caller) + (CalleeWasDeleted ? 
0 : Advice.CalleeIRSize); - CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize); + if (Caller == Callee) { + assert(!CalleeWasDeleted); + // We double-counted CallerAndCalleeEdges, since the caller and callee + // are the same. + assert(Advice.CallerAndCalleeEdges % 2 == 0); + CurrentIRSize += getIRSize(*Caller) - Advice.CallerIRSize; + EdgeCount += getCachedFPI(*Caller).DirectCallsToDefinedFunctions - + Advice.CallerAndCalleeEdges / 2; + // The NodeCount would stay the same. + } else { + int64_t IRSizeAfter = + getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize); + CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize); + + // We can delta-update module-wide features. We know the inlining only + // changed the caller, and maybe the callee (by deleting the latter). Nodes + // are simple to update. For edges, we 'forget' the edges that the caller + // and callee used to have before inlining, and add back what they currently + // have together. + int64_t NewCallerAndCalleeEdges = + getCachedFPI(*Caller).DirectCallsToDefinedFunctions; + + // A dead function's node is not actually removed from the call graph until + // the end of the call graph walk, but the node no longer belongs to any + // valid SCC. + if (CalleeWasDeleted) { + --NodeCount; + NodesInLastSCC.erase(CG.lookup(*Callee)); + DeadFunctions.insert(Callee); + } else { + NewCallerAndCalleeEdges += + getCachedFPI(*Callee).DirectCallsToDefinedFunctions; + } + EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges); + } if (CurrentIRSize > SizeIncreaseThreshold * InitialIRSize) ForceStop = true; - // We can delta-update module-wide features. We know the inlining only changed - // the caller, and maybe the callee (by deleting the latter). - // Nodes are simple to update. - // For edges, we 'forget' the edges that the caller and callee used to have - // before inlining, and add back what they currently have together. - int64_t NewCallerAndCalleeEdges = - getCachedFPI(*Caller).DirectCallsToDefinedFunctions; - - // A dead function's node is not actually removed from the call graph until - // the end of the call graph walk, but the node no longer belongs to any valid - // SCC.
- if (CalleeWasDeleted) { - --NodeCount; - NodesInLastSCC.erase(CG.lookup(*Callee)); - DeadFunctions.insert(Callee); - } else { - NewCallerAndCalleeEdges += - getCachedFPI(*Callee).DirectCallsToDefinedFunctions; - } - EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges); assert(CurrentIRSize >= 0 && EdgeCount >= 0 && NodeCount >= 0); } diff --git a/llvm/lib/Analysis/MemoryLocation.cpp b/llvm/lib/Analysis/MemoryLocation.cpp index dcc5117..1c5f08e 100644 --- a/llvm/lib/Analysis/MemoryLocation.cpp +++ b/llvm/lib/Analysis/MemoryLocation.cpp @@ -245,7 +245,7 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call, assert(ArgIdx == 0 && "Invalid argument index"); auto *Ty = cast<VectorType>(II->getType()); - if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(2), Ty)) + if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(1), Ty)) return MemoryLocation(Arg, DL.getTypeStoreSize(*KnownType), AATags); return MemoryLocation( @@ -255,7 +255,7 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call, assert(ArgIdx == 1 && "Invalid argument index"); auto *Ty = cast<VectorType>(II->getArgOperand(0)->getType()); - if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(3), Ty)) + if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(2), Ty)) return MemoryLocation(Arg, DL.getTypeStoreSize(*KnownType), AATags); return MemoryLocation( diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index e06b095..6f7dd79 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -15473,6 +15473,38 @@ void ScalarEvolution::LoopGuards::collectFromPHI( } } +// Return a new SCEV that modifies \p Expr to the closest number divisible by +// \p Divisor and less than or equal to \p Expr. For now, only handle constant +// Expr. +static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, + const APInt &DivisorVal, + ScalarEvolution &SE) { + const APInt *ExprVal; + if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() || + DivisorVal.isNonPositive()) + return Expr; + APInt Rem = ExprVal->urem(DivisorVal); + // return the SCEV: Expr - Expr % Divisor + return SE.getConstant(*ExprVal - Rem); +} + +// Return a new SCEV that modifies \p Expr to the closest number divisible by +// \p Divisor and greater than or equal to \p Expr. For now, only handle constant +// Expr. +static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr, + const APInt &DivisorVal, + ScalarEvolution &SE) { + const APInt *ExprVal; + if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() || + DivisorVal.isNonPositive()) + return Expr; + APInt Rem = ExprVal->urem(DivisorVal); + if (Rem.isZero()) + return Expr; + // return the SCEV: Expr + Divisor - Expr % Divisor + return SE.getConstant(*ExprVal + DivisorVal - Rem); +} + void ScalarEvolution::LoopGuards::collectFromBlock( ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards, const BasicBlock *Block, const BasicBlock *Pred, @@ -15540,36 +15572,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock( match(LHS, m_scev_APInt(C)) && C->isNonNegative(); }; - // Return a new SCEV that modifies \p Expr to the closest number divides by - // \p Divisor and greater or equal than Expr. For now, only handle constant - // Expr.
- auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr, - const APInt &DivisorVal) { - const APInt *ExprVal; - if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() || - DivisorVal.isNonPositive()) - return Expr; - APInt Rem = ExprVal->urem(DivisorVal); - if (Rem.isZero()) - return Expr; - // return the SCEV: Expr + Divisor - Expr % Divisor - return SE.getConstant(*ExprVal + DivisorVal - Rem); - }; - - // Return a new SCEV that modifies \p Expr to the closest number divides by - // \p Divisor and less or equal than Expr. For now, only handle constant - // Expr. - auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr, - const APInt &DivisorVal) { - const APInt *ExprVal; - if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() || - DivisorVal.isNonPositive()) - return Expr; - APInt Rem = ExprVal->urem(DivisorVal); - // return the SCEV: Expr - Expr % Divisor - return SE.getConstant(*ExprVal - Rem); - }; - // Apply divisibilty by \p Divisor on MinMaxExpr with constant values, // recursively. This is done by aligning up/down the constant value to the // Divisor. @@ -15591,8 +15593,9 @@ void ScalarEvolution::LoopGuards::collectFromBlock( assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!"); auto *DivisibleExpr = - IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, DivisorVal) - : GetNextSCEVDividesByDivisor(MinMaxLHS, DivisorVal); + IsMin + ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE) + : getNextSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE); SmallVector<const SCEV *> Ops = { ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr}; return SE.getMinMaxExpr(SCTy, Ops); @@ -15669,21 +15672,21 @@ void ScalarEvolution::LoopGuards::collectFromBlock( [[fallthrough]]; case CmpInst::ICMP_SLT: { RHS = SE.getMinusSCEV(RHS, One); - RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy); + RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE); break; } case CmpInst::ICMP_UGT: case CmpInst::ICMP_SGT: RHS = SE.getAddExpr(RHS, One); - RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy); + RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE); break; case CmpInst::ICMP_ULE: case CmpInst::ICMP_SLE: - RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy); + RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE); break; case CmpInst::ICMP_UGE: case CmpInst::ICMP_SGE: - RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy); + RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE); break; default: break; @@ -15737,22 +15740,29 @@ void ScalarEvolution::LoopGuards::collectFromBlock( case CmpInst::ICMP_NE: if (match(RHS, m_scev_Zero())) { const SCEV *OneAlignedUp = - GetNextSCEVDividesByDivisor(One, DividesBy); + getNextSCEVDivisibleByDivisor(One, DividesBy, SE); To = SE.getUMaxExpr(FromRewritten, OneAlignedUp); } else { + // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS), + // but creating the subtraction eagerly is expensive. Track the + // inequalities in a separate map, and materialize the rewrite lazily + // when encountering a suitable subtraction while rewriting.
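+          // For example (illustrative): a guard "%a != %b" records the pair
+          // {%a, %b} in NotEqual; later, when the rewriter encounters the
+          // subtraction (%a - %b), it can be strengthened to umax(%a - %b, 1).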
if (LHS->getType()->isPointerTy()) { LHS = SE.getLosslessPtrToIntExpr(LHS); RHS = SE.getLosslessPtrToIntExpr(RHS); if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS)) break; } - auto AddSubRewrite = [&](const SCEV *A, const SCEV *B) { - const SCEV *Sub = SE.getMinusSCEV(A, B); - AddRewrite(Sub, Sub, - SE.getUMaxExpr(Sub, SE.getOne(From->getType()))); - }; - AddSubRewrite(LHS, RHS); - AddSubRewrite(RHS, LHS); + const SCEVConstant *C; + const SCEV *A, *B; + if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) && + match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) { + RHS = A; + LHS = B; + } + if (LHS > RHS) + std::swap(LHS, RHS); + Guards.NotEqual.insert({LHS, RHS}); continue; } break; @@ -15886,13 +15896,15 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> { const DenseMap<const SCEV *, const SCEV *> ⤅ + const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> ≠ SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap; public: SCEVLoopGuardRewriter(ScalarEvolution &SE, const ScalarEvolution::LoopGuards &Guards) - : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) { + : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap), + NotEqual(Guards.NotEqual) { if (Guards.PreserveNUW) FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW); if (Guards.PreserveNSW) @@ -15947,14 +15959,39 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { } const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { + // Helper to check if S is a subtraction (A - B) where A != B, and if so, + // return UMax(S, 1). + auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * { + const SCEV *LHS, *RHS; + if (MatchBinarySub(S, LHS, RHS)) { + if (LHS > RHS) + std::swap(LHS, RHS); + if (NotEqual.contains({LHS, RHS})) { + const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor( + SE.getOne(S->getType()), SE.getConstantMultiple(S), SE); + return SE.getUMaxExpr(OneAlignedUp, S); + } + } + return nullptr; + }; + + // Check if Expr itself is a subtraction pattern with guard info. + if (const SCEV *Rewritten = RewriteSubtraction(Expr)) + return Rewritten; + // Trip count expressions sometimes consist of adding 3 operands, i.e. // (Const + A + B). There may be guard info for A + B, and if so, apply // it. // TODO: Could more generally apply guards to Add sub-expressions. 
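      // Illustrative shape (hypothetical operands): for Expr = (3 + %n + %m)
      // with guard info recorded for (%n + %m), the lookup below rewrites the
      // (%n + %m) tail and re-adds the leading constant 3.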
if (isa<SCEVConstant>(Expr->getOperand(0)) && Expr->getNumOperands() == 3) { - if (const SCEV *S = Map.lookup( - SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2)))) + const SCEV *Add = + SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2)); + if (const SCEV *Rewritten = RewriteSubtraction(Add)) + return SE.getAddExpr( + Expr->getOperand(0), Rewritten, + ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask)); + if (const SCEV *S = Map.lookup(Add)) return SE.getAddExpr(Expr->getOperand(0), S); } SmallVector<const SCEV *, 2> Operands; @@ -15989,7 +16026,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { } }; - if (RewriteMap.empty()) + if (RewriteMap.empty() && NotEqual.empty()) return Expr; SCEVLoopGuardRewriter Rewriter(SE, *this); diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 9655c88..0a72076 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -7695,6 +7695,11 @@ static bool isGuaranteedNotToBeUndefOrPoison( } if (IsWellDefined) return true; + } else if (auto *Splat = isa<ShuffleVectorInst>(Opr) ? getSplatValue(Opr) + : nullptr) { + // For splats we only need to check the value being splatted. + if (OpCheck(Splat)) + return true; } else if (all_of(Opr->operands(), OpCheck)) return true; } diff --git a/llvm/lib/CAS/CMakeLists.txt b/llvm/lib/CAS/CMakeLists.txt index bca39b6..a2f8c49 100644 --- a/llvm/lib/CAS/CMakeLists.txt +++ b/llvm/lib/CAS/CMakeLists.txt @@ -8,6 +8,8 @@ add_llvm_component_library(LLVMCAS ObjectStore.cpp OnDiskCommon.cpp OnDiskDataAllocator.cpp + OnDiskGraphDB.cpp + OnDiskKeyValueDB.cpp OnDiskTrieRawHashMap.cpp ADDITIONAL_HEADER_DIRS diff --git a/llvm/lib/CAS/OnDiskCommon.cpp b/llvm/lib/CAS/OnDiskCommon.cpp index 25aa06b..281bde9 100644 --- a/llvm/lib/CAS/OnDiskCommon.cpp +++ b/llvm/lib/CAS/OnDiskCommon.cpp @@ -7,9 +7,10 @@ //===----------------------------------------------------------------------===// #include "OnDiskCommon.h" -#include "llvm/Config/config.h" #include "llvm/Support/Error.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Support/Process.h" +#include <mutex> #include <thread> #if __has_include(<sys/file.h>) @@ -25,8 +26,44 @@ #include <fcntl.h> #endif +#if __has_include(<sys/mount.h>) +#include <sys/mount.h> // statfs +#endif + using namespace llvm; +static uint64_t OnDiskCASMaxMappingSize = 0; + +Expected<std::optional<uint64_t>> cas::ondisk::getOverriddenMaxMappingSize() { + static std::once_flag Flag; + Error Err = Error::success(); + std::call_once(Flag, [&Err] { + ErrorAsOutParameter EAO(&Err); + constexpr const char *EnvVar = "LLVM_CAS_MAX_MAPPING_SIZE"; + auto Value = sys::Process::GetEnv(EnvVar); + if (!Value) + return; + + uint64_t Size; + if (StringRef(*Value).getAsInteger(/*auto*/ 0, Size)) + Err = createStringError(inconvertibleErrorCode(), + "invalid value for %s: expected integer", EnvVar); + OnDiskCASMaxMappingSize = Size; + }); + + if (Err) + return std::move(Err); + + if (OnDiskCASMaxMappingSize == 0) + return std::nullopt; + + return OnDiskCASMaxMappingSize; +} + +void cas::ondisk::setMaxMappingSize(uint64_t Size) { + OnDiskCASMaxMappingSize = Size; +} + std::error_code cas::ondisk::lockFileThreadSafe(int FD, sys::fs::LockKind Kind) { #if HAVE_FLOCK @@ -125,3 +162,20 @@ Expected<size_t> cas::ondisk::preallocateFileTail(int FD, size_t CurrentSize, return NewSize; // Pretend it worked. #endif } + +bool cas::ondisk::useSmallMappingSize(const Twine &P) { + // Add exceptions to use small database file here. 
+#if defined(__APPLE__) && __has_include(<sys/mount.h>) + // macOS tmpfs does not support sparse tails. + SmallString<128> PathStorage; + StringRef Path = P.toNullTerminatedStringRef(PathStorage); + struct statfs StatFS; + if (statfs(Path.data(), &StatFS) != 0) + return false; + + if (strcmp(StatFS.f_fstypename, "tmpfs") == 0) + return true; +#endif + // Default to using a regular database file. + return false; +} diff --git a/llvm/lib/CAS/OnDiskCommon.h b/llvm/lib/CAS/OnDiskCommon.h index 8b79ffe..ac00662 100644 --- a/llvm/lib/CAS/OnDiskCommon.h +++ b/llvm/lib/CAS/OnDiskCommon.h @@ -12,9 +12,31 @@ #include "llvm/Support/Error.h" #include "llvm/Support/FileSystem.h" #include <chrono> +#include <optional> namespace llvm::cas::ondisk { +/// The version for all the ondisk database files. It needs to be bumped when +/// compatibility-breaking changes are introduced. +constexpr StringLiteral CASFormatVersion = "v1"; + +/// Retrieves an overridden maximum mapping size for CAS files, if any, +/// specified by LLVM_CAS_MAX_MAPPING_SIZE in the environment or set by +/// `setMaxMappingSize()`. If the value from the environment is unreadable, +/// returns an error. +Expected<std::optional<uint64_t>> getOverriddenMaxMappingSize(); + +/// Set MaxMappingSize for ondisk CAS. This function is not thread-safe and +/// should be called before creating any ondisk CAS; it does not affect CAS +/// already created. Set value 0 to use the default size. +void setMaxMappingSize(uint64_t Size); + +/// Whether to use a small file mapping for ondisk databases created in \p Path. +/// +/// For file systems that don't support sparse files, use a smaller file +/// mapping to avoid consuming too much disk space on creation. +bool useSmallMappingSize(const Twine &Path); + /// Thread-safe alternative to \c sys::fs::lockFile. This does not support all /// the platforms that \c sys::fs::lockFile does, so keep it in the CAS library /// for now. diff --git a/llvm/lib/CAS/OnDiskDataAllocator.cpp b/llvm/lib/CAS/OnDiskDataAllocator.cpp index 13bbd66..9c68bc4 100644 --- a/llvm/lib/CAS/OnDiskDataAllocator.cpp +++ b/llvm/lib/CAS/OnDiskDataAllocator.cpp @@ -185,7 +185,7 @@ Expected<ArrayRef<char>> OnDiskDataAllocator::get(FileOffset Offset, return ArrayRef<char>{Impl->File.getRegion().data() + Offset.get(), Size}; } -MutableArrayRef<uint8_t> OnDiskDataAllocator::getUserHeader() { +MutableArrayRef<uint8_t> OnDiskDataAllocator::getUserHeader() const { return Impl->Store.getUserHeader(); } @@ -221,7 +221,9 @@ Expected<ArrayRef<char>> OnDiskDataAllocator::get(FileOffset Offset, "OnDiskDataAllocator is not supported"); } -MutableArrayRef<uint8_t> OnDiskDataAllocator::getUserHeader() { return {}; } +MutableArrayRef<uint8_t> OnDiskDataAllocator::getUserHeader() const { + return {}; +} size_t OnDiskDataAllocator::size() const { return 0; } size_t OnDiskDataAllocator::capacity() const { return 0; } diff --git a/llvm/lib/CAS/OnDiskGraphDB.cpp b/llvm/lib/CAS/OnDiskGraphDB.cpp new file mode 100644 index 0000000..64cbe9d --- /dev/null +++ b/llvm/lib/CAS/OnDiskGraphDB.cpp @@ -0,0 +1,1758 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// \file +/// This file implements OnDiskGraphDB, an on-disk CAS nodes database, +/// independent of a particular hashing algorithm. It only needs to be +/// configured for the hash size and controls the schema of the storage. +/// +/// OnDiskGraphDB defines: +/// +/// - How the data is stored inside the database, either as a standalone file, or +/// allocated inside a datapool. +/// - How references to other objects inside the same database are stored. They +/// are stored as internal references, instead of full hash values, to save +/// space. +/// - How to chain databases together and import objects from upstream +/// databases. +/// +/// Here's a top-level description of the current layout: +/// +/// - db/index.<version>: a file for the "index" table, named by \a +/// IndexTableName and managed by \a TrieRawHashMap. The contents are 8B +/// that are accessed atomically, describing the object kind and where/how +/// it's stored (including an optional file offset). See \a TrieRecord for +/// more details. +/// - db/data.<version>: a file for the "data" table, named by \a +/// DataPoolTableName and managed by \a DataStore. New objects within +/// TrieRecord::MaxEmbeddedSize are inserted here as \a +/// TrieRecord::StorageKind::DataPool. +/// - db/obj.<offset>.<version>: a file storing an object outside the main +/// "data" table, named by its offset into the "index" table, with the +/// format of \a TrieRecord::StorageKind::Standalone. +/// - db/leaf.<offset>.<version>: a file storing a leaf node outside the +/// main "data" table, named by its offset into the "index" table, with +/// the format of \a TrieRecord::StorageKind::StandaloneLeaf. +/// - db/leaf+0.<offset>.<version>: a file storing a null-terminated leaf object +/// outside the main "data" table, named by its offset into the "index" table, +/// with the format of \a TrieRecord::StorageKind::StandaloneLeaf0.
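+///
+/// As an illustrative example (hypothetical offsets), a populated v1 database
+/// directory may contain: index.v1, data.v1, obj.104.v1, leaf.208.v1, and
+/// leaf+0.312.v1.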
+// +//===----------------------------------------------------------------------===// + +#include "llvm/CAS/OnDiskGraphDB.h" +#include "OnDiskCommon.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/CAS/OnDiskDataAllocator.h" +#include "llvm/CAS/OnDiskTrieRawHashMap.h" +#include "llvm/Support/Alignment.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Errc.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/Process.h" +#include <atomic> +#include <mutex> +#include <optional> + +#define DEBUG_TYPE "on-disk-cas" + +using namespace llvm; +using namespace llvm::cas; +using namespace llvm::cas::ondisk; + +static constexpr StringLiteral IndexTableName = "llvm.cas.index"; +static constexpr StringLiteral DataPoolTableName = "llvm.cas.data"; + +static constexpr StringLiteral IndexFilePrefix = "index."; +static constexpr StringLiteral DataPoolFilePrefix = "data."; + +static constexpr StringLiteral FilePrefixObject = "obj."; +static constexpr StringLiteral FilePrefixLeaf = "leaf."; +static constexpr StringLiteral FilePrefixLeaf0 = "leaf+0."; + +static Error createCorruptObjectError(Expected<ArrayRef<uint8_t>> ID) { + if (!ID) + return ID.takeError(); + + return createStringError(llvm::errc::invalid_argument, + "corrupt object '" + toHex(*ID) + "'"); +} + +namespace { + +/// Trie record data: 8 bytes, atomic<uint64_t> +/// - 1-byte: StorageKind +/// - 7-bytes: DataStoreOffset (offset into referenced file) +class TrieRecord { +public: + enum class StorageKind : uint8_t { + /// Unknown object. + Unknown = 0, + + /// data.vX: main pool, full DataStore record. + DataPool = 1, + + /// obj.<TrieRecordOffset>.vX: standalone, with a full DataStore record. + Standalone = 10, + + /// leaf.<TrieRecordOffset>.vX: standalone, just the data. File contents + /// exactly the data content and file size matches the data size. No refs. + StandaloneLeaf = 11, + + /// leaf+0.<TrieRecordOffset>.vX: standalone, just the data plus an + /// extra null character ('\0'). File size is 1 bigger than the data size. + /// No refs. + StandaloneLeaf0 = 12, + }; + + static StringRef getStandaloneFilePrefix(StorageKind SK) { + switch (SK) { + default: + llvm_unreachable("Expected standalone storage kind"); + case TrieRecord::StorageKind::Standalone: + return FilePrefixObject; + case TrieRecord::StorageKind::StandaloneLeaf: + return FilePrefixLeaf; + case TrieRecord::StorageKind::StandaloneLeaf0: + return FilePrefixLeaf0; + } + } + + enum Limits : int64_t { + /// Saves files bigger than 64KB standalone instead of embedding them. + MaxEmbeddedSize = 64LL * 1024LL - 1, + }; + + struct Data { + StorageKind SK = StorageKind::Unknown; + FileOffset Offset; + }; + + /// Pack StorageKind and Offset from Data into 8 byte TrieRecord. + static uint64_t pack(Data D) { + assert(D.Offset.get() < (int64_t)(1ULL << 56)); + uint64_t Packed = uint64_t(D.SK) << 56 | D.Offset.get(); + assert(D.SK != StorageKind::Unknown || Packed == 0); +#ifndef NDEBUG + Data RoundTrip = unpack(Packed); + assert(D.SK == RoundTrip.SK); + assert(D.Offset.get() == RoundTrip.Offset.get()); +#endif + return Packed; + } + + // Unpack TrieRecord into Data. 
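+ // For example (hypothetical values): SK = DataPool (1) with Offset 0x20
+ // packs as (1 << 56) | 0x20 = 0x0100000000000020, and unpack recovers both
+ // fields.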
+ static Data unpack(uint64_t Packed) { + Data D; + if (!Packed) + return D; + D.SK = (StorageKind)(Packed >> 56); + D.Offset = FileOffset(Packed & (UINT64_MAX >> 8)); + return D; + } + + TrieRecord() : Storage(0) {} + + Data load() const { return unpack(Storage); } + bool compare_exchange_strong(Data &Existing, Data New); + +private: + std::atomic<uint64_t> Storage; +}; + +/// DataStore record data: 4B + size? + refs? + data + 0 +/// - 4-bytes: Header +/// - {0,4,8}-bytes: DataSize (may be packed in Header) +/// - {0,4,8}-bytes: NumRefs (may be packed in Header) +/// - NumRefs*{4,8}-bytes: Refs[] (end-ptr is 8-byte aligned) +/// - <data> +/// - 1-byte: 0-term +struct DataRecordHandle { + /// NumRefs storage: 4B, 2B, 1B, or 0B (no refs). Or, 8B, for alignment + /// convenience to avoid computing padding later. + enum class NumRefsFlags : uint8_t { + Uses0B = 0U, + Uses1B = 1U, + Uses2B = 2U, + Uses4B = 3U, + Uses8B = 4U, + Max = Uses8B, + }; + + /// DataSize storage: 8B, 4B, 2B, or 1B. + enum class DataSizeFlags { + Uses1B = 0U, + Uses2B = 1U, + Uses4B = 2U, + Uses8B = 3U, + Max = Uses8B, + }; + + /// Kind of ref stored in Refs[]: InternalRef or InternalRef4B. + enum class RefKindFlags { + InternalRef = 0U, + InternalRef4B = 1U, + Max = InternalRef4B, + }; + + enum Counts : int { + NumRefsShift = 0, + NumRefsBits = 3, + DataSizeShift = NumRefsShift + NumRefsBits, + DataSizeBits = 2, + RefKindShift = DataSizeShift + DataSizeBits, + RefKindBits = 1, + }; + static_assert(((UINT32_MAX << NumRefsBits) & (uint32_t)NumRefsFlags::Max) == + 0, + "Not enough bits"); + static_assert(((UINT32_MAX << DataSizeBits) & (uint32_t)DataSizeFlags::Max) == + 0, + "Not enough bits"); + static_assert(((UINT32_MAX << RefKindBits) & (uint32_t)RefKindFlags::Max) == + 0, + "Not enough bits"); + + /// Layout of the DataRecordHandle and how to decode it. 
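+ /// For example (hypothetical combination): NumRefs = Uses2B (2), DataSize =
+ /// Uses4B (2), and RefKind = InternalRef4B (1) pack to
+ /// (2 << 0) | (2 << 3) | (1 << 5) = 0x32.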
+ struct LayoutFlags { + NumRefsFlags NumRefs; + DataSizeFlags DataSize; + RefKindFlags RefKind; + + static uint64_t pack(LayoutFlags LF) { + unsigned Packed = ((unsigned)LF.NumRefs << NumRefsShift) | + ((unsigned)LF.DataSize << DataSizeShift) | + ((unsigned)LF.RefKind << RefKindShift); +#ifndef NDEBUG + LayoutFlags RoundTrip = unpack(Packed); + assert(LF.NumRefs == RoundTrip.NumRefs); + assert(LF.DataSize == RoundTrip.DataSize); + assert(LF.RefKind == RoundTrip.RefKind); +#endif + return Packed; + } + static LayoutFlags unpack(uint64_t Storage) { + assert(Storage <= UINT8_MAX && "Expect storage to fit in a byte"); + LayoutFlags LF; + LF.NumRefs = + (NumRefsFlags)((Storage >> NumRefsShift) & ((1U << NumRefsBits) - 1)); + LF.DataSize = (DataSizeFlags)((Storage >> DataSizeShift) & + ((1U << DataSizeBits) - 1)); + LF.RefKind = + (RefKindFlags)((Storage >> RefKindShift) & ((1U << RefKindBits) - 1)); + return LF; + } + }; + + /// Header layout: + /// - 1-byte: LayoutFlags + /// - 1-byte: 1B size field + /// - {0,2}-bytes: 2B size field + struct Header { + using PackTy = uint32_t; + PackTy Packed; + + static constexpr unsigned LayoutFlagsShift = + (sizeof(PackTy) - 1) * CHAR_BIT; + }; + + struct Input { + InternalRefArrayRef Refs; + ArrayRef<char> Data; + }; + + LayoutFlags getLayoutFlags() const { + return LayoutFlags::unpack(H->Packed >> Header::LayoutFlagsShift); + } + + uint64_t getDataSize() const; + void skipDataSize(LayoutFlags LF, int64_t &RelOffset) const; + uint32_t getNumRefs() const; + void skipNumRefs(LayoutFlags LF, int64_t &RelOffset) const; + int64_t getRefsRelOffset() const; + int64_t getDataRelOffset() const; + + static uint64_t getTotalSize(uint64_t DataRelOffset, uint64_t DataSize) { + return DataRelOffset + DataSize + 1; + } + uint64_t getTotalSize() const { + return getDataRelOffset() + getDataSize() + 1; + } + + /// Describe the layout of data stored and how to decode from + /// DataRecordHandle. 
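+ /// Worked example (hypothetical input): 10 bytes of data with two 4B refs.
+ /// Both sizes initially pack into the 4B Header, but the refs would then end
+ /// at offset 12, which is not 8B-aligned, so DataSize grows to an explicit 4B
+ /// field: RefsRelOffset = 8, DataRelOffset = 16, and getTotalSize() =
+ /// 16 + 10 + 1 = 27.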
+ struct Layout { + explicit Layout(const Input &I); + + LayoutFlags Flags; + uint64_t DataSize = 0; + uint32_t NumRefs = 0; + int64_t RefsRelOffset = 0; + int64_t DataRelOffset = 0; + uint64_t getTotalSize() const { + return DataRecordHandle::getTotalSize(DataRelOffset, DataSize); + } + }; + + InternalRefArrayRef getRefs() const { + assert(H && "Expected valid handle"); + auto *BeginByte = reinterpret_cast<const char *>(H) + getRefsRelOffset(); + size_t Size = getNumRefs(); + if (!Size) + return InternalRefArrayRef(); + if (getLayoutFlags().RefKind == RefKindFlags::InternalRef4B) + return ArrayRef(reinterpret_cast<const InternalRef4B *>(BeginByte), Size); + return ArrayRef(reinterpret_cast<const InternalRef *>(BeginByte), Size); + } + + ArrayRef<char> getData() const { + assert(H && "Expected valid handle"); + return ArrayRef(reinterpret_cast<const char *>(H) + getDataRelOffset(), + getDataSize()); + } + + static DataRecordHandle create(function_ref<char *(size_t Size)> Alloc, + const Input &I); + static Expected<DataRecordHandle> + createWithError(function_ref<Expected<char *>(size_t Size)> Alloc, + const Input &I); + static DataRecordHandle construct(char *Mem, const Input &I); + + static DataRecordHandle get(const char *Mem) { + return DataRecordHandle( + *reinterpret_cast<const DataRecordHandle::Header *>(Mem)); + } + static Expected<DataRecordHandle> + getFromDataPool(const OnDiskDataAllocator &Pool, FileOffset Offset); + + explicit operator bool() const { return H; } + const Header &getHeader() const { return *H; } + + DataRecordHandle() = default; + explicit DataRecordHandle(const Header &H) : H(&H) {} + +private: + static DataRecordHandle constructImpl(char *Mem, const Input &I, + const Layout &L); + const Header *H = nullptr; +}; + +/// Proxy for any on-disk object or raw data. +struct OnDiskContent { + std::optional<DataRecordHandle> Record; + std::optional<ArrayRef<char>> Bytes; +}; + +/// Data loaded into memory from a standalone file. +class StandaloneDataInMemory { +public: + OnDiskContent getContent() const; + + StandaloneDataInMemory(std::unique_ptr<sys::fs::mapped_file_region> Region, + TrieRecord::StorageKind SK) + : Region(std::move(Region)), SK(SK) { +#ifndef NDEBUG + bool IsStandalone = false; + switch (SK) { + case TrieRecord::StorageKind::Standalone: + case TrieRecord::StorageKind::StandaloneLeaf: + case TrieRecord::StorageKind::StandaloneLeaf0: + IsStandalone = true; + break; + default: + break; + } + assert(IsStandalone); +#endif + } + +private: + std::unique_ptr<sys::fs::mapped_file_region> Region; + TrieRecord::StorageKind SK; +}; + +/// Container to look up loaded standalone objects. +template <size_t NumShards> class StandaloneDataMap { + static_assert(isPowerOf2_64(NumShards), "Expected power of 2"); + +public: + uintptr_t insert(ArrayRef<uint8_t> Hash, TrieRecord::StorageKind SK, + std::unique_ptr<sys::fs::mapped_file_region> Region); + + const StandaloneDataInMemory *lookup(ArrayRef<uint8_t> Hash) const; + bool count(ArrayRef<uint8_t> Hash) const { return bool(lookup(Hash)); } + +private: + struct Shard { + /// Needs to store a std::unique_ptr for a stable address identity.
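+ /// (Illustrative: the shard is chosen from the hash's leading byte, so with
+ /// NumShards = 16 a hash starting with 0x3a maps to shard 10.)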
+ DenseMap<const uint8_t *, std::unique_ptr<StandaloneDataInMemory>> Map; + mutable std::mutex Mutex; + }; + Shard &getShard(ArrayRef<uint8_t> Hash) { + return const_cast<Shard &>( + const_cast<const StandaloneDataMap *>(this)->getShard(Hash)); + } + const Shard &getShard(ArrayRef<uint8_t> Hash) const { + static_assert(NumShards <= 256, "Expected only 8 bits of shard"); + return Shards[Hash[0] % NumShards]; + } + + Shard Shards[NumShards]; +}; + +using StandaloneDataMapTy = StandaloneDataMap<16>; + +/// A vector of internal node references. +class InternalRefVector { +public: + void push_back(InternalRef Ref) { + if (NeedsFull) + return FullRefs.push_back(Ref); + if (std::optional<InternalRef4B> Small = InternalRef4B::tryToShrink(Ref)) + return SmallRefs.push_back(*Small); + NeedsFull = true; + assert(FullRefs.empty()); + FullRefs.reserve(SmallRefs.size() + 1); + for (InternalRef4B Small : SmallRefs) + FullRefs.push_back(Small); + FullRefs.push_back(Ref); + SmallRefs.clear(); + } + + operator InternalRefArrayRef() const { + assert(SmallRefs.empty() || FullRefs.empty()); + return NeedsFull ? InternalRefArrayRef(FullRefs) + : InternalRefArrayRef(SmallRefs); + } + +private: + bool NeedsFull = false; + SmallVector<InternalRef4B> SmallRefs; + SmallVector<InternalRef> FullRefs; +}; + +} // namespace + +Expected<DataRecordHandle> DataRecordHandle::createWithError( + function_ref<Expected<char *>(size_t Size)> Alloc, const Input &I) { + Layout L(I); + if (Expected<char *> Mem = Alloc(L.getTotalSize())) + return constructImpl(*Mem, I, L); + else + return Mem.takeError(); +} + +DataRecordHandle +DataRecordHandle::create(function_ref<char *(size_t Size)> Alloc, + const Input &I) { + Layout L(I); + return constructImpl(Alloc(L.getTotalSize()), I, L); +} + +ObjectHandle ObjectHandle::fromFileOffset(FileOffset Offset) { + // Store the file offset as it is. + assert(!(Offset.get() & 0x1)); + return ObjectHandle(Offset.get()); +} + +ObjectHandle ObjectHandle::fromMemory(uintptr_t Ptr) { + // Store the pointer from memory with lowest bit set. + assert(!(Ptr & 0x1)); + return ObjectHandle(Ptr | 1); +} + +/// Proxy for an on-disk index record. +struct OnDiskGraphDB::IndexProxy { + FileOffset Offset; + ArrayRef<uint8_t> Hash; + TrieRecord &Ref; +}; + +template <size_t N> +uintptr_t StandaloneDataMap<N>::insert( + ArrayRef<uint8_t> Hash, TrieRecord::StorageKind SK, + std::unique_ptr<sys::fs::mapped_file_region> Region) { + auto &S = getShard(Hash); + std::lock_guard<std::mutex> Lock(S.Mutex); + auto &V = S.Map[Hash.data()]; + if (!V) + V = std::make_unique<StandaloneDataInMemory>(std::move(Region), SK); + return reinterpret_cast<uintptr_t>(V.get()); +} + +template <size_t N> +const StandaloneDataInMemory * +StandaloneDataMap<N>::lookup(ArrayRef<uint8_t> Hash) const { + auto &S = getShard(Hash); + std::lock_guard<std::mutex> Lock(S.Mutex); + auto I = S.Map.find(Hash.data()); + if (I == S.Map.end()) + return nullptr; + return &*I->second; +} + +namespace { + +/// Copy of \a sys::fs::TempFile that skips RemoveOnSignal, which is too +/// expensive to register/unregister at this rate. +/// +/// FIXME: Add a TempFileManager that maintains a thread-safe list of open temp +/// files and has a signal handler registered that removes them all. +class TempFile { + bool Done = false; + TempFile(StringRef Name, int FD) : TmpName(std::string(Name)), FD(FD) {} + +public: + /// This creates a temporary file with createUniqueFile.
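+ /// Usage sketch (illustrative; hypothetical model and destination path):
+ ///   auto Tmp = TempFile::create(DBPath + ".tmp.%%%%%%");
+ ///   if (!Tmp) return Tmp.takeError();
+ ///   // ... write through Tmp->FD ...
+ ///   if (Error E = Tmp->keep(DBPath)) return E;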
+ static Expected<TempFile> create(const Twine &Model); + TempFile(TempFile &&Other) { *this = std::move(Other); } + TempFile &operator=(TempFile &&Other) { + TmpName = std::move(Other.TmpName); + FD = Other.FD; + Other.Done = true; + Other.FD = -1; + return *this; + } + + // Name of the temporary file. + std::string TmpName; + + // The open file descriptor. + int FD = -1; + + // Keep this with the given name. + Error keep(const Twine &Name); + Error discard(); + + // This checks that keep or discard was called. + ~TempFile() { consumeError(discard()); } +}; + +class MappedTempFile { +public: + char *data() const { return Map.data(); } + size_t size() const { return Map.size(); } + + Error discard() { + assert(Map && "Map already destroyed"); + Map.unmap(); + return Temp.discard(); + } + + Error keep(const Twine &Name) { + assert(Map && "Map already destroyed"); + Map.unmap(); + return Temp.keep(Name); + } + + MappedTempFile(TempFile Temp, sys::fs::mapped_file_region Map) + : Temp(std::move(Temp)), Map(std::move(Map)) {} + +private: + TempFile Temp; + sys::fs::mapped_file_region Map; +}; +} // namespace + +Error TempFile::discard() { + Done = true; + if (FD != -1) { + sys::fs::file_t File = sys::fs::convertFDToNativeFile(FD); + if (std::error_code EC = sys::fs::closeFile(File)) + return errorCodeToError(EC); + } + FD = -1; + + // Always try to close and remove. + std::error_code RemoveEC; + if (!TmpName.empty()) { + std::error_code EC = sys::fs::remove(TmpName); + if (EC) + return errorCodeToError(EC); + } + TmpName = ""; + + return Error::success(); +} + +Error TempFile::keep(const Twine &Name) { + assert(!Done); + Done = true; + // Always try to close and rename. + std::error_code RenameEC = sys::fs::rename(TmpName, Name); + + if (!RenameEC) + TmpName = ""; + + sys::fs::file_t File = sys::fs::convertFDToNativeFile(FD); + if (std::error_code EC = sys::fs::closeFile(File)) + return errorCodeToError(EC); + FD = -1; + + return errorCodeToError(RenameEC); +} + +Expected<TempFile> TempFile::create(const Twine &Model) { + int FD; + SmallString<128> ResultPath; + if (std::error_code EC = sys::fs::createUniqueFile(Model, FD, ResultPath)) + return errorCodeToError(EC); + + TempFile Ret(ResultPath, FD); + return std::move(Ret); +} + +bool TrieRecord::compare_exchange_strong(Data &Existing, Data New) { + uint64_t ExistingPacked = pack(Existing); + uint64_t NewPacked = pack(New); + if (Storage.compare_exchange_strong(ExistingPacked, NewPacked)) + return true; + Existing = unpack(ExistingPacked); + return false; +} + +DataRecordHandle DataRecordHandle::construct(char *Mem, const Input &I) { + return constructImpl(Mem, I, Layout(I)); +} + +Expected<DataRecordHandle> +DataRecordHandle::getFromDataPool(const OnDiskDataAllocator &Pool, + FileOffset Offset) { + auto HeaderData = Pool.get(Offset, sizeof(DataRecordHandle::Header)); + if (!HeaderData) + return HeaderData.takeError(); + + auto Record = DataRecordHandle::get(HeaderData->data()); + if (Record.getTotalSize() + Offset.get() > Pool.size()) + return createStringError( + make_error_code(std::errc::illegal_byte_sequence), + "data record spans past the end of the data pool"); + + return Record; +} + +DataRecordHandle DataRecordHandle::constructImpl(char *Mem, const Input &I, + const Layout &L) { + char *Next = Mem + sizeof(Header); + + // Fill in Packed and set other data, then come back to construct the header. + Header::PackTy Packed = 0; + Packed |= LayoutFlags::pack(L.Flags) << Header::LayoutFlagsShift; + + // Construct DataSize.
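+ // (Recall the Header layout: bits [31:24] of Packed hold LayoutFlags, bits
+ // [23:16] the packed 1B size field, and bits [15:0] the packed 2B field; the
+ // 4B and 8B variants are written after the Header instead.)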
+ switch (L.Flags.DataSize) { + case DataSizeFlags::Uses1B: + assert(I.Data.size() <= UINT8_MAX); + Packed |= (Header::PackTy)I.Data.size() + << ((sizeof(Packed) - 2) * CHAR_BIT); + break; + case DataSizeFlags::Uses2B: + assert(I.Data.size() <= UINT16_MAX); + Packed |= (Header::PackTy)I.Data.size() + << ((sizeof(Packed) - 4) * CHAR_BIT); + break; + case DataSizeFlags::Uses4B: + support::endian::write32le(Next, I.Data.size()); + Next += 4; + break; + case DataSizeFlags::Uses8B: + support::endian::write64le(Next, I.Data.size()); + Next += 8; + break; + } + + // Construct NumRefs. + // + // NOTE: May be writing NumRefs even if there are zero refs in order to fix + // alignment. + switch (L.Flags.NumRefs) { + case NumRefsFlags::Uses0B: + break; + case NumRefsFlags::Uses1B: + assert(I.Refs.size() <= UINT8_MAX); + Packed |= (Header::PackTy)I.Refs.size() + << ((sizeof(Packed) - 2) * CHAR_BIT); + break; + case NumRefsFlags::Uses2B: + assert(I.Refs.size() <= UINT16_MAX); + Packed |= (Header::PackTy)I.Refs.size() + << ((sizeof(Packed) - 4) * CHAR_BIT); + break; + case NumRefsFlags::Uses4B: + support::endian::write32le(Next, I.Refs.size()); + Next += 4; + break; + case NumRefsFlags::Uses8B: + support::endian::write64le(Next, I.Refs.size()); + Next += 8; + break; + } + + // Construct Refs[]. + if (!I.Refs.empty()) { + assert((L.Flags.RefKind == RefKindFlags::InternalRef4B) == I.Refs.is4B()); + ArrayRef<uint8_t> RefsBuffer = I.Refs.getBuffer(); + llvm::copy(RefsBuffer, Next); + Next += RefsBuffer.size(); + } + + // Construct Data and the trailing null. + assert(isAddrAligned(Align(8), Next)); + llvm::copy(I.Data, Next); + Next[I.Data.size()] = 0; + + // Construct the header itself and return. + Header *H = new (Mem) Header{Packed}; + DataRecordHandle Record(*H); + assert(Record.getData() == I.Data); + assert(Record.getNumRefs() == I.Refs.size()); + assert(Record.getRefs() == I.Refs); + assert(Record.getLayoutFlags().DataSize == L.Flags.DataSize); + assert(Record.getLayoutFlags().NumRefs == L.Flags.NumRefs); + assert(Record.getLayoutFlags().RefKind == L.Flags.RefKind); + return Record; +} + +DataRecordHandle::Layout::Layout(const Input &I) { + // Start initial relative offsets right after the Header. + uint64_t RelOffset = sizeof(Header); + + // Initialize the easy stuff. + DataSize = I.Data.size(); + NumRefs = I.Refs.size(); + + // Check refs size. + Flags.RefKind = + I.Refs.is4B() ? RefKindFlags::InternalRef4B : RefKindFlags::InternalRef; + + // Find the smallest slot available for DataSize. + bool Has1B = true; + bool Has2B = true; + if (DataSize <= UINT8_MAX && Has1B) { + Flags.DataSize = DataSizeFlags::Uses1B; + Has1B = false; + } else if (DataSize <= UINT16_MAX && Has2B) { + Flags.DataSize = DataSizeFlags::Uses2B; + Has2B = false; + } else if (DataSize <= UINT32_MAX) { + Flags.DataSize = DataSizeFlags::Uses4B; + RelOffset += 4; + } else { + Flags.DataSize = DataSizeFlags::Uses8B; + RelOffset += 8; + } + + // Find the smallest slot available for NumRefs. Never sets NumRefs8B here. + if (!NumRefs) { + Flags.NumRefs = NumRefsFlags::Uses0B; + } else if (NumRefs <= UINT8_MAX && Has1B) { + Flags.NumRefs = NumRefsFlags::Uses1B; + Has1B = false; + } else if (NumRefs <= UINT16_MAX && Has2B) { + Flags.NumRefs = NumRefsFlags::Uses2B; + Has2B = false; + } else { + Flags.NumRefs = NumRefsFlags::Uses4B; + RelOffset += 4; + } + + // Helper to "upgrade" either DataSize or NumRefs by 4B to avoid complicated + // padding rules when reading and writing. This also bumps RelOffset. 
+ //
+ // The value for NumRefs is strictly limited to UINT32_MAX, but it can be
+ // stored as 8B. This means we can *always* find a size to grow.
+ //
+ // NOTE: Only call this once.
+ auto GrowSizeFieldsBy4B = [&]() {
+ assert(isAligned(Align(4), RelOffset));
+ RelOffset += 4;
+
+ assert(Flags.NumRefs != NumRefsFlags::Uses8B &&
+ "Expected to be able to grow NumRefs8B");
+
+ // First try to grow DataSize. NumRefs will not (yet) be 8B, and if
+ // DataSize is upgraded to 8B it'll already be aligned.
+ //
+ // Failing that, grow NumRefs.
+ if (Flags.DataSize < DataSizeFlags::Uses4B)
+ Flags.DataSize = DataSizeFlags::Uses4B; // DataSize: Packed => 4B.
+ else if (Flags.DataSize < DataSizeFlags::Uses8B)
+ Flags.DataSize = DataSizeFlags::Uses8B; // DataSize: 4B => 8B.
+ else if (Flags.NumRefs < NumRefsFlags::Uses4B)
+ Flags.NumRefs = NumRefsFlags::Uses4B; // NumRefs: Packed => 4B.
+ else
+ Flags.NumRefs = NumRefsFlags::Uses8B; // NumRefs: 4B => 8B.
+ };
+
+ assert(isAligned(Align(4), RelOffset));
+ if (Flags.RefKind == RefKindFlags::InternalRef) {
+ // List of 8B refs should be 8B-aligned. Grow one of the sizes to get this
+ // without padding.
+ if (!isAligned(Align(8), RelOffset))
+ GrowSizeFieldsBy4B();
+
+ assert(isAligned(Align(8), RelOffset));
+ RefsRelOffset = RelOffset;
+ RelOffset += 8 * NumRefs;
+ } else {
+ // The array of 4B refs doesn't need 8B alignment, but the data will need
+ // to be 8B-aligned. Detect this now, and, if necessary, shift everything
+ // by 4B by growing one of the sizes.
+ // Removing the 8B-alignment requirement for data would yield <1% savings
+ // in disk storage for a clang build using MCCAS, but the alignment may be
+ // useful in the future, so keep it for now.
+ uint64_t RefListSize = 4 * NumRefs;
+ if (!isAligned(Align(8), RelOffset + RefListSize))
+ GrowSizeFieldsBy4B();
+ RefsRelOffset = RelOffset;
+ RelOffset += RefListSize;
+ }
+
+ assert(isAligned(Align(8), RelOffset));
+ DataRelOffset = RelOffset;
+}
+
+uint64_t DataRecordHandle::getDataSize() const {
+ int64_t RelOffset = sizeof(Header);
+ auto *DataSizePtr = reinterpret_cast<const char *>(H) + RelOffset;
+ switch (getLayoutFlags().DataSize) {
+ case DataSizeFlags::Uses1B:
+ return (H->Packed >> ((sizeof(Header::PackTy) - 2) * CHAR_BIT)) & UINT8_MAX;
+ case DataSizeFlags::Uses2B:
+ return (H->Packed >> ((sizeof(Header::PackTy) - 4) * CHAR_BIT)) &
+ UINT16_MAX;
+ case DataSizeFlags::Uses4B:
+ return support::endian::read32le(DataSizePtr);
+ case DataSizeFlags::Uses8B:
+ return support::endian::read64le(DataSizePtr);
+ }
+ llvm_unreachable("Unknown DataSizeFlags enum");
+}
+
+void DataRecordHandle::skipDataSize(LayoutFlags LF, int64_t &RelOffset) const {
+ if (LF.DataSize >= DataSizeFlags::Uses4B)
+ RelOffset += 4;
+ if (LF.DataSize >= DataSizeFlags::Uses8B)
+ RelOffset += 4;
+}
+
+uint32_t DataRecordHandle::getNumRefs() const {
+ LayoutFlags LF = getLayoutFlags();
+ int64_t RelOffset = sizeof(Header);
+ skipDataSize(LF, RelOffset);
+ auto *NumRefsPtr = reinterpret_cast<const char *>(H) + RelOffset;
+ switch (LF.NumRefs) {
+ case NumRefsFlags::Uses0B:
+ return 0;
+ case NumRefsFlags::Uses1B:
+ return (H->Packed >> ((sizeof(Header::PackTy) - 2) * CHAR_BIT)) & UINT8_MAX;
+ case NumRefsFlags::Uses2B:
+ return (H->Packed >> ((sizeof(Header::PackTy) - 4) * CHAR_BIT)) &
+ UINT16_MAX;
+ case NumRefsFlags::Uses4B:
+ return support::endian::read32le(NumRefsPtr);
+ case NumRefsFlags::Uses8B:
+ return support::endian::read64le(NumRefsPtr);
+ }
+ llvm_unreachable("Unknown NumRefsFlags enum");
+}
+
+void
DataRecordHandle::skipNumRefs(LayoutFlags LF, int64_t &RelOffset) const {
+ if (LF.NumRefs >= NumRefsFlags::Uses4B)
+ RelOffset += 4;
+ if (LF.NumRefs >= NumRefsFlags::Uses8B)
+ RelOffset += 4;
+}
+
+int64_t DataRecordHandle::getRefsRelOffset() const {
+ LayoutFlags LF = getLayoutFlags();
+ int64_t RelOffset = sizeof(Header);
+ skipDataSize(LF, RelOffset);
+ skipNumRefs(LF, RelOffset);
+ return RelOffset;
+}
+
+int64_t DataRecordHandle::getDataRelOffset() const {
+ LayoutFlags LF = getLayoutFlags();
+ int64_t RelOffset = sizeof(Header);
+ skipDataSize(LF, RelOffset);
+ skipNumRefs(LF, RelOffset);
+ uint32_t RefSize = LF.RefKind == RefKindFlags::InternalRef4B ? 4 : 8;
+ RelOffset += RefSize * getNumRefs();
+ return RelOffset;
+}
+
+Error OnDiskGraphDB::validate(bool Deep, HashingFuncT Hasher) const {
+ return Index.validate([&](FileOffset Offset,
+ OnDiskTrieRawHashMap::ConstValueProxy Record)
+ -> Error {
+ auto formatError = [&](Twine Msg) {
+ return createStringError(
+ llvm::errc::illegal_byte_sequence,
+ "bad record at 0x" +
+ utohexstr((unsigned)Offset.get(), /*LowerCase=*/true) + ": " +
+ Msg.str());
+ };
+
+ if (Record.Data.size() != sizeof(TrieRecord))
+ return formatError("wrong data record size");
+ if (!isAddrAligned(Align::Of<TrieRecord>(), Record.Data.data()))
+ return formatError("wrong data record alignment");
+
+ auto *R = reinterpret_cast<const TrieRecord *>(Record.Data.data());
+ TrieRecord::Data D = R->load();
+ std::unique_ptr<MemoryBuffer> FileBuffer;
+ if ((uint8_t)D.SK != (uint8_t)TrieRecord::StorageKind::Unknown &&
+ (uint8_t)D.SK != (uint8_t)TrieRecord::StorageKind::DataPool &&
+ (uint8_t)D.SK != (uint8_t)TrieRecord::StorageKind::Standalone &&
+ (uint8_t)D.SK != (uint8_t)TrieRecord::StorageKind::StandaloneLeaf &&
+ (uint8_t)D.SK != (uint8_t)TrieRecord::StorageKind::StandaloneLeaf0)
+ return formatError("invalid record kind value");
+
+ auto Ref = InternalRef::getFromOffset(Offset);
+ auto I = getIndexProxyFromRef(Ref);
+ if (!I)
+ return I.takeError();
+
+ switch (D.SK) {
+ case TrieRecord::StorageKind::Unknown:
+ // This could be an abandoned entry due to a termination before updating
+ // the record. It can be reused by a later insertion, so just skip this
+ // entry for now.
+ return Error::success();
+ case TrieRecord::StorageKind::DataPool:
+ // Check that the offset is a positive value and large enough to hold the
+ // header for the data record.
+ if (D.Offset.get() <= 0 ||
+ (uint64_t)D.Offset.get() + sizeof(DataRecordHandle::Header) >=
+ DataPool.size())
+ return formatError("datapool record out of bounds");
+ break;
+ case TrieRecord::StorageKind::Standalone:
+ case TrieRecord::StorageKind::StandaloneLeaf:
+ case TrieRecord::StorageKind::StandaloneLeaf0:
+ SmallString<256> Path;
+ getStandalonePath(TrieRecord::getStandaloneFilePrefix(D.SK), *I, Path);
+ // If we need to validate the content of the file later, load the buffer
+ // here. Otherwise, just check the existence of the file.
+ if (Deep) {
+ auto File = MemoryBuffer::getFile(Path, /*IsText=*/false,
+ /*RequiresNullTerminator=*/false);
+ if (!File || !*File)
+ return formatError("record file \'" + Path + "\' cannot be loaded");
+
+ FileBuffer = std::move(*File);
+ } else if (!llvm::sys::fs::exists(Path))
+ return formatError("record file \'" + Path + "\' does not exist");
+ }
+
+ if (!Deep)
+ return Error::success();
+
+ auto dataError = [&](Twine Msg) {
+ return createStringError(llvm::errc::illegal_byte_sequence,
+ "bad data for digest \'" + toHex(I->Hash) +
+ "\': " + Msg.str());
+ };
+ SmallVector<ArrayRef<uint8_t>> Refs;
+ ArrayRef<char> StoredData;
+
+ switch (D.SK) {
+ case TrieRecord::StorageKind::Unknown:
+ llvm_unreachable("already handled");
+ case TrieRecord::StorageKind::DataPool: {
+ auto DataRecord = DataRecordHandle::getFromDataPool(DataPool, D.Offset);
+ if (!DataRecord)
+ return dataError(toString(DataRecord.takeError()));
+
+ for (auto InternRef : DataRecord->getRefs()) {
+ auto Index = getIndexProxyFromRef(InternRef);
+ if (!Index)
+ return Index.takeError();
+ Refs.push_back(Index->Hash);
+ }
+ StoredData = DataRecord->getData();
+ break;
+ }
+ case TrieRecord::StorageKind::Standalone: {
+ if (FileBuffer->getBufferSize() < sizeof(DataRecordHandle::Header))
+ return dataError("data record is not big enough to read the header");
+ auto DataRecord = DataRecordHandle::get(FileBuffer->getBufferStart());
+ if (DataRecord.getTotalSize() > FileBuffer->getBufferSize())
+ return dataError(
+ "data record spans past the end of the standalone file");
+ for (auto InternRef : DataRecord.getRefs()) {
+ auto Index = getIndexProxyFromRef(InternRef);
+ if (!Index)
+ return Index.takeError();
+ Refs.push_back(Index->Hash);
+ }
+ StoredData = DataRecord.getData();
+ break;
+ }
+ case TrieRecord::StorageKind::StandaloneLeaf:
+ case TrieRecord::StorageKind::StandaloneLeaf0: {
+ StoredData = arrayRefFromStringRef<char>(FileBuffer->getBuffer());
+ if (D.SK == TrieRecord::StorageKind::StandaloneLeaf0) {
+ if (!FileBuffer->getBuffer().ends_with('\0'))
+ return dataError("standalone file is not zero terminated");
+ StoredData = StoredData.drop_back(1);
+ }
+ break;
+ }
+ }
+
+ SmallVector<uint8_t> ComputedHash;
+ Hasher(Refs, StoredData, ComputedHash);
+ if (I->Hash != ArrayRef(ComputedHash))
+ return dataError("hash mismatch, got \'" + toHex(ComputedHash) +
+ "\' instead");
+
+ return Error::success();
+ });
+}
+
+void OnDiskGraphDB::print(raw_ostream &OS) const {
+ OS << "on-disk-root-path: " << RootPath << "\n";
+
+ struct PoolInfo {
+ uint64_t Offset;
+ };
+ SmallVector<PoolInfo> Pool;
+
+ OS << "\n";
+ OS << "index:\n";
+ Index.print(OS, [&](ArrayRef<char> Data) {
+ assert(Data.size() == sizeof(TrieRecord));
+ assert(isAddrAligned(Align::Of<TrieRecord>(), Data.data()));
+ auto *R = reinterpret_cast<const TrieRecord *>(Data.data());
+ TrieRecord::Data D = R->load();
+ OS << " SK=";
+ switch (D.SK) {
+ case TrieRecord::StorageKind::Unknown:
+ OS << "unknown ";
+ break;
+ case TrieRecord::StorageKind::DataPool:
+ OS << "datapool ";
+ Pool.push_back({D.Offset.get()});
+ break;
+ case TrieRecord::StorageKind::Standalone:
+ OS << "standalone-data ";
+ break;
+ case TrieRecord::StorageKind::StandaloneLeaf:
+ OS << "standalone-leaf ";
+ break;
+ case TrieRecord::StorageKind::StandaloneLeaf0:
+ OS << "standalone-leaf+0";
+ break;
+ }
+ OS << " Offset=" << (void *)D.Offset.get();
+ });
+ if (Pool.empty())
+ return;
+
+ OS << "\n";
+ OS << "pool:\n";
+ llvm::sort(
+ Pool, [](PoolInfo LHS, PoolInfo RHS) { return LHS.Offset <
RHS.Offset; }); + for (PoolInfo PI : Pool) { + OS << "- addr=" << (void *)PI.Offset << " "; + auto D = DataRecordHandle::getFromDataPool(DataPool, FileOffset(PI.Offset)); + if (!D) { + OS << "error: " << toString(D.takeError()); + return; + } + + OS << "record refs=" << D->getNumRefs() << " data=" << D->getDataSize() + << " size=" << D->getTotalSize() + << " end=" << (void *)(PI.Offset + D->getTotalSize()) << "\n"; + } +} + +Expected<OnDiskGraphDB::IndexProxy> +OnDiskGraphDB::indexHash(ArrayRef<uint8_t> Hash) { + auto P = Index.insertLazy( + Hash, [](FileOffset TentativeOffset, + OnDiskTrieRawHashMap::ValueProxy TentativeValue) { + assert(TentativeValue.Data.size() == sizeof(TrieRecord)); + assert( + isAddrAligned(Align::Of<TrieRecord>(), TentativeValue.Data.data())); + new (TentativeValue.Data.data()) TrieRecord(); + }); + if (LLVM_UNLIKELY(!P)) + return P.takeError(); + + assert(*P && "Expected insertion"); + return getIndexProxyFromPointer(*P); +} + +OnDiskGraphDB::IndexProxy OnDiskGraphDB::getIndexProxyFromPointer( + OnDiskTrieRawHashMap::ConstOnDiskPtr P) const { + assert(P); + assert(P.getOffset()); + return IndexProxy{P.getOffset(), P->Hash, + *const_cast<TrieRecord *>( + reinterpret_cast<const TrieRecord *>(P->Data.data()))}; +} + +Expected<ObjectID> OnDiskGraphDB::getReference(ArrayRef<uint8_t> Hash) { + auto I = indexHash(Hash); + if (LLVM_UNLIKELY(!I)) + return I.takeError(); + return getExternalReference(*I); +} + +ObjectID OnDiskGraphDB::getExternalReference(const IndexProxy &I) { + return getExternalReference(makeInternalRef(I.Offset)); +} + +std::optional<ObjectID> +OnDiskGraphDB::getExistingReference(ArrayRef<uint8_t> Digest) { + auto tryUpstream = + [&](std::optional<IndexProxy> I) -> std::optional<ObjectID> { + if (!UpstreamDB) + return std::nullopt; + std::optional<ObjectID> UpstreamID = + UpstreamDB->getExistingReference(Digest); + if (LLVM_UNLIKELY(!UpstreamID)) + return std::nullopt; + auto Ref = expectedToOptional(indexHash(Digest)); + if (!Ref) + return std::nullopt; + if (!I) + I.emplace(*Ref); + return getExternalReference(*I); + }; + + OnDiskTrieRawHashMap::ConstOnDiskPtr P = Index.find(Digest); + if (!P) + return tryUpstream(std::nullopt); + IndexProxy I = getIndexProxyFromPointer(P); + TrieRecord::Data Obj = I.Ref.load(); + if (Obj.SK == TrieRecord::StorageKind::Unknown) + return tryUpstream(I); + return getExternalReference(makeInternalRef(I.Offset)); +} + +Expected<OnDiskGraphDB::IndexProxy> +OnDiskGraphDB::getIndexProxyFromRef(InternalRef Ref) const { + auto P = Index.recoverFromFileOffset(Ref.getFileOffset()); + if (LLVM_UNLIKELY(!P)) + return P.takeError(); + return getIndexProxyFromPointer(*P); +} + +Expected<ArrayRef<uint8_t>> OnDiskGraphDB::getDigest(InternalRef Ref) const { + auto I = getIndexProxyFromRef(Ref); + if (!I) + return I.takeError(); + return I->Hash; +} + +ArrayRef<uint8_t> OnDiskGraphDB::getDigest(const IndexProxy &I) const { + return I.Hash; +} + +static OnDiskContent getContentFromHandle(const OnDiskDataAllocator &DataPool, + ObjectHandle OH) { + // Decode ObjectHandle to locate the stored content. 
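+ // The opaque value is tagged: if bit 0 is set, the remaining bits are a
+ // pointer to a StandaloneDataInMemory (whose alignment keeps bit 0 free);
+ // otherwise the value is a FileOffset into DataPool.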
+ uint64_t Data = OH.getOpaqueData(); + if (Data & 1) { + const auto *SDIM = + reinterpret_cast<const StandaloneDataInMemory *>(Data & (-1ULL << 1)); + return SDIM->getContent(); + } + + auto DataHandle = + cantFail(DataRecordHandle::getFromDataPool(DataPool, FileOffset(Data))); + assert(DataHandle.getData().end()[0] == 0 && "Null termination"); + return OnDiskContent{DataHandle, std::nullopt}; +} + +ArrayRef<char> OnDiskGraphDB::getObjectData(ObjectHandle Node) const { + OnDiskContent Content = getContentFromHandle(DataPool, Node); + if (Content.Bytes) + return *Content.Bytes; + assert(Content.Record && "Expected record or bytes"); + return Content.Record->getData(); +} + +InternalRefArrayRef OnDiskGraphDB::getInternalRefs(ObjectHandle Node) const { + if (std::optional<DataRecordHandle> Record = + getContentFromHandle(DataPool, Node).Record) + return Record->getRefs(); + return std::nullopt; +} + +Expected<std::optional<ObjectHandle>> +OnDiskGraphDB::load(ObjectID ExternalRef) { + InternalRef Ref = getInternalRef(ExternalRef); + auto I = getIndexProxyFromRef(Ref); + if (!I) + return I.takeError(); + TrieRecord::Data Object = I->Ref.load(); + + if (Object.SK == TrieRecord::StorageKind::Unknown) { + if (!UpstreamDB) + return std::nullopt; + return faultInFromUpstream(ExternalRef); + } + + if (Object.SK == TrieRecord::StorageKind::DataPool) + return ObjectHandle::fromFileOffset(Object.Offset); + + // Only TrieRecord::StorageKind::Standalone (and variants) need to be + // explicitly loaded. + // + // There's corruption if standalone objects have offsets, or if we get here + // for something that isn't standalone. + if (Object.Offset) + return createCorruptObjectError(getDigest(*I)); + switch (Object.SK) { + case TrieRecord::StorageKind::Unknown: + case TrieRecord::StorageKind::DataPool: + llvm_unreachable("unexpected storage kind"); + case TrieRecord::StorageKind::Standalone: + case TrieRecord::StorageKind::StandaloneLeaf0: + case TrieRecord::StorageKind::StandaloneLeaf: + break; + } + + // Load it from disk. + // + // Note: Creation logic guarantees that data that needs null-termination is + // suitably 0-padded. Requiring null-termination here would be too expensive + // for extremely large objects that happen to be page-aligned. 
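+ // (See createStandaloneLeaf(): page-aligned leaf content is stored as
+ // StandaloneLeaf0 with one extra zero byte appended, while any other size
+ // leaves zero-fill at the tail of the final mapped page, so mapped leaf
+ // data ends up null-terminated in both cases.)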
+ SmallString<256> Path; + getStandalonePath(TrieRecord::getStandaloneFilePrefix(Object.SK), *I, Path); + + auto File = sys::fs::openNativeFileForRead(Path); + if (!File) + return createFileError(Path, File.takeError()); + + auto CloseFile = make_scope_exit([&]() { sys::fs::closeFile(*File); }); + + sys::fs::file_status Status; + if (std::error_code EC = sys::fs::status(*File, Status)) + return createCorruptObjectError(getDigest(*I)); + + std::error_code EC; + auto Region = std::make_unique<sys::fs::mapped_file_region>( + *File, sys::fs::mapped_file_region::readonly, Status.getSize(), 0, EC); + if (EC) + return createCorruptObjectError(getDigest(*I)); + + return ObjectHandle::fromMemory( + static_cast<StandaloneDataMapTy *>(StandaloneData) + ->insert(I->Hash, Object.SK, std::move(Region))); +} + +Expected<bool> OnDiskGraphDB::isMaterialized(ObjectID Ref) { + auto Presence = getObjectPresence(Ref, /*CheckUpstream=*/true); + if (!Presence) + return Presence.takeError(); + + switch (*Presence) { + case ObjectPresence::Missing: + return false; + case ObjectPresence::InPrimaryDB: + return true; + case ObjectPresence::OnlyInUpstreamDB: + if (auto FaultInResult = faultInFromUpstream(Ref); !FaultInResult) + return FaultInResult.takeError(); + return true; + } + llvm_unreachable("Unknown ObjectPresence enum"); +} + +Expected<OnDiskGraphDB::ObjectPresence> +OnDiskGraphDB::getObjectPresence(ObjectID ExternalRef, + bool CheckUpstream) const { + InternalRef Ref = getInternalRef(ExternalRef); + auto I = getIndexProxyFromRef(Ref); + if (!I) + return I.takeError(); + + TrieRecord::Data Object = I->Ref.load(); + if (Object.SK != TrieRecord::StorageKind::Unknown) + return ObjectPresence::InPrimaryDB; + if (!CheckUpstream || !UpstreamDB) + return ObjectPresence::Missing; + std::optional<ObjectID> UpstreamID = + UpstreamDB->getExistingReference(getDigest(*I)); + return UpstreamID.has_value() ? ObjectPresence::OnlyInUpstreamDB + : ObjectPresence::Missing; +} + +InternalRef OnDiskGraphDB::makeInternalRef(FileOffset IndexOffset) { + return InternalRef::getFromOffset(IndexOffset); +} + +void OnDiskGraphDB::getStandalonePath(StringRef Prefix, const IndexProxy &I, + SmallVectorImpl<char> &Path) const { + Path.assign(RootPath.begin(), RootPath.end()); + sys::path::append(Path, + Prefix + Twine(I.Offset.get()) + "." 
+ CASFormatVersion); +} + +OnDiskContent StandaloneDataInMemory::getContent() const { + bool Leaf0 = false; + bool Leaf = false; + switch (SK) { + default: + llvm_unreachable("Storage kind must be standalone"); + case TrieRecord::StorageKind::Standalone: + break; + case TrieRecord::StorageKind::StandaloneLeaf0: + Leaf = Leaf0 = true; + break; + case TrieRecord::StorageKind::StandaloneLeaf: + Leaf = true; + break; + } + + if (Leaf) { + StringRef Data(Region->data(), Region->size()); + assert(Data.drop_back(Leaf0).end()[0] == 0 && + "Standalone node data missing null termination"); + return OnDiskContent{std::nullopt, + arrayRefFromStringRef<char>(Data.drop_back(Leaf0))}; + } + + DataRecordHandle Record = DataRecordHandle::get(Region->data()); + assert(Record.getData().end()[0] == 0 && + "Standalone object record missing null termination for data"); + return OnDiskContent{Record, std::nullopt}; +} + +static Expected<MappedTempFile> createTempFile(StringRef FinalPath, + uint64_t Size) { + assert(Size && "Unexpected request for an empty temp file"); + Expected<TempFile> File = TempFile::create(FinalPath + ".%%%%%%"); + if (!File) + return File.takeError(); + + if (Error E = preallocateFileTail(File->FD, 0, Size).takeError()) + return createFileError(File->TmpName, std::move(E)); + + if (auto EC = sys::fs::resize_file_before_mapping_readwrite(File->FD, Size)) + return createFileError(File->TmpName, EC); + + std::error_code EC; + sys::fs::mapped_file_region Map(sys::fs::convertFDToNativeFile(File->FD), + sys::fs::mapped_file_region::readwrite, Size, + 0, EC); + if (EC) + return createFileError(File->TmpName, EC); + return MappedTempFile(std::move(*File), std::move(Map)); +} + +static size_t getPageSize() { + static int PageSize = sys::Process::getPageSizeEstimate(); + return PageSize; +} + +Error OnDiskGraphDB::createStandaloneLeaf(IndexProxy &I, ArrayRef<char> Data) { + assert(Data.size() > TrieRecord::MaxEmbeddedSize && + "Expected a bigger file for external content..."); + + bool Leaf0 = isAligned(Align(getPageSize()), Data.size()); + TrieRecord::StorageKind SK = Leaf0 ? TrieRecord::StorageKind::StandaloneLeaf0 + : TrieRecord::StorageKind::StandaloneLeaf; + + SmallString<256> Path; + int64_t FileSize = Data.size() + Leaf0; + getStandalonePath(TrieRecord::getStandaloneFilePrefix(SK), I, Path); + + // Write the file. Don't reuse this mapped_file_region, which is read/write. + // Let load() pull up one that's read-only. + Expected<MappedTempFile> File = createTempFile(Path, FileSize); + if (!File) + return File.takeError(); + assert(File->size() == (uint64_t)FileSize); + llvm::copy(Data, File->data()); + if (Leaf0) + File->data()[Data.size()] = 0; + assert(File->data()[Data.size()] == 0); + if (Error E = File->keep(Path)) + return E; + + // Store the object reference. + TrieRecord::Data Existing; + { + TrieRecord::Data Leaf{SK, FileOffset()}; + if (I.Ref.compare_exchange_strong(Existing, Leaf)) { + recordStandaloneSizeIncrease(FileSize); + return Error::success(); + } + } + + // If there was a race, confirm that the new value has valid storage. + if (Existing.SK == TrieRecord::StorageKind::Unknown) + return createCorruptObjectError(getDigest(I)); + + return Error::success(); +} + +Error OnDiskGraphDB::store(ObjectID ID, ArrayRef<ObjectID> Refs, + ArrayRef<char> Data) { + auto I = getIndexProxyFromRef(getInternalRef(ID)); + if (LLVM_UNLIKELY(!I)) + return I.takeError(); + + // Early return in case the node exists. 
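+ // (Content addressing guarantees that an existing entry for this ID refers
+ // to identical content, so the new data can safely be dropped.)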
+ { + TrieRecord::Data Existing = I->Ref.load(); + if (Existing.SK != TrieRecord::StorageKind::Unknown) + return Error::success(); + } + + // Big leaf nodes. + if (Refs.empty() && Data.size() > TrieRecord::MaxEmbeddedSize) + return createStandaloneLeaf(*I, Data); + + // TODO: Check whether it's worth checking the index for an already existing + // object (like storeTreeImpl() does) before building up the + // InternalRefVector. + InternalRefVector InternalRefs; + for (ObjectID Ref : Refs) + InternalRefs.push_back(getInternalRef(Ref)); + + // Create the object. + + DataRecordHandle::Input Input{InternalRefs, Data}; + + // Compute the storage kind, allocate it, and create the record. + TrieRecord::StorageKind SK = TrieRecord::StorageKind::Unknown; + FileOffset PoolOffset; + SmallString<256> Path; + std::optional<MappedTempFile> File; + std::optional<uint64_t> FileSize; + auto AllocStandaloneFile = [&](size_t Size) -> Expected<char *> { + getStandalonePath(TrieRecord::getStandaloneFilePrefix( + TrieRecord::StorageKind::Standalone), + *I, Path); + if (Error E = createTempFile(Path, Size).moveInto(File)) + return std::move(E); + assert(File->size() == Size); + FileSize = Size; + SK = TrieRecord::StorageKind::Standalone; + return File->data(); + }; + auto Alloc = [&](size_t Size) -> Expected<char *> { + if (Size <= TrieRecord::MaxEmbeddedSize) { + SK = TrieRecord::StorageKind::DataPool; + auto P = DataPool.allocate(Size); + if (LLVM_UNLIKELY(!P)) { + char *NewAlloc = nullptr; + auto NewE = handleErrors( + P.takeError(), [&](std::unique_ptr<StringError> E) -> Error { + if (E->convertToErrorCode() == std::errc::not_enough_memory) + return AllocStandaloneFile(Size).moveInto(NewAlloc); + return Error(std::move(E)); + }); + if (!NewE) + return NewAlloc; + return std::move(NewE); + } + PoolOffset = P->getOffset(); + LLVM_DEBUG({ + dbgs() << "pool-alloc addr=" << (void *)PoolOffset.get() + << " size=" << Size + << " end=" << (void *)(PoolOffset.get() + Size) << "\n"; + }); + return (*P)->data(); + } + return AllocStandaloneFile(Size); + }; + + DataRecordHandle Record; + if (Error E = + DataRecordHandle::createWithError(Alloc, Input).moveInto(Record)) + return E; + assert(Record.getData().end()[0] == 0 && "Expected null-termination"); + assert(Record.getData() == Input.Data && "Expected initialization"); + assert(SK != TrieRecord::StorageKind::Unknown); + assert(bool(File) != bool(PoolOffset) && + "Expected either a mapped file or a pooled offset"); + + // Check for a race before calling MappedTempFile::keep(). + // + // Then decide what to do with the file. Better to discard than overwrite if + // another thread/process has already added this. + TrieRecord::Data Existing = I->Ref.load(); + { + TrieRecord::Data NewObject{SK, PoolOffset}; + if (File) { + if (Existing.SK == TrieRecord::StorageKind::Unknown) { + // Keep the file! + if (Error E = File->keep(Path)) + return E; + } else { + File.reset(); + } + } + + // If we didn't already see a racing/existing write, then try storing the + // new object. If that races, confirm that the new value has valid storage. + // + // TODO: Find a way to reuse the storage from the new-but-abandoned record + // handle. + if (Existing.SK == TrieRecord::StorageKind::Unknown) { + if (I->Ref.compare_exchange_strong(Existing, NewObject)) { + if (FileSize) + recordStandaloneSizeIncrease(*FileSize); + return Error::success(); + } + } + } + + if (Existing.SK == TrieRecord::StorageKind::Unknown) + return createCorruptObjectError(getDigest(*I)); + + // Load existing object. 
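+ // (The existing object needs no explicit load here: the racing writer
+ // installed valid storage, and content addressing guarantees it matches,
+ // so the store can simply report success.)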
+ return Error::success(); +} + +void OnDiskGraphDB::recordStandaloneSizeIncrease(size_t SizeIncrease) { + standaloneStorageSize().fetch_add(SizeIncrease, std::memory_order_relaxed); +} + +std::atomic<uint64_t> &OnDiskGraphDB::standaloneStorageSize() const { + MutableArrayRef<uint8_t> UserHeader = DataPool.getUserHeader(); + assert(UserHeader.size() == sizeof(std::atomic<uint64_t>)); + assert(isAddrAligned(Align(8), UserHeader.data())); + return *reinterpret_cast<std::atomic<uint64_t> *>(UserHeader.data()); +} + +uint64_t OnDiskGraphDB::getStandaloneStorageSize() const { + return standaloneStorageSize().load(std::memory_order_relaxed); +} + +size_t OnDiskGraphDB::getStorageSize() const { + return Index.size() + DataPool.size() + getStandaloneStorageSize(); +} + +unsigned OnDiskGraphDB::getHardStorageLimitUtilization() const { + unsigned IndexPercent = Index.size() * 100ULL / Index.capacity(); + unsigned DataPercent = DataPool.size() * 100ULL / DataPool.capacity(); + return std::max(IndexPercent, DataPercent); +} + +Expected<std::unique_ptr<OnDiskGraphDB>> OnDiskGraphDB::open( + StringRef AbsPath, StringRef HashName, unsigned HashByteSize, + std::unique_ptr<OnDiskGraphDB> UpstreamDB, FaultInPolicy Policy) { + if (std::error_code EC = sys::fs::create_directories(AbsPath)) + return createFileError(AbsPath, EC); + + constexpr uint64_t MB = 1024ull * 1024ull; + constexpr uint64_t GB = 1024ull * 1024ull * 1024ull; + + uint64_t MaxIndexSize = 12 * GB; + uint64_t MaxDataPoolSize = 24 * GB; + + if (useSmallMappingSize(AbsPath)) { + MaxIndexSize = 1 * GB; + MaxDataPoolSize = 2 * GB; + } + + auto CustomSize = getOverriddenMaxMappingSize(); + if (!CustomSize) + return CustomSize.takeError(); + if (*CustomSize) + MaxIndexSize = MaxDataPoolSize = **CustomSize; + + SmallString<256> IndexPath(AbsPath); + sys::path::append(IndexPath, IndexFilePrefix + CASFormatVersion); + std::optional<OnDiskTrieRawHashMap> Index; + if (Error E = OnDiskTrieRawHashMap::create( + IndexPath, IndexTableName + "[" + HashName + "]", + HashByteSize * CHAR_BIT, + /*DataSize=*/sizeof(TrieRecord), MaxIndexSize, + /*MinFileSize=*/MB) + .moveInto(Index)) + return std::move(E); + + uint32_t UserHeaderSize = sizeof(std::atomic<uint64_t>); + + SmallString<256> DataPoolPath(AbsPath); + sys::path::append(DataPoolPath, DataPoolFilePrefix + CASFormatVersion); + std::optional<OnDiskDataAllocator> DataPool; + StringRef PolicyName = + Policy == FaultInPolicy::SingleNode ? "single" : "full"; + if (Error E = OnDiskDataAllocator::create( + DataPoolPath, + DataPoolTableName + "[" + HashName + "]" + PolicyName, + MaxDataPoolSize, /*MinFileSize=*/MB, UserHeaderSize, + [](void *UserHeaderPtr) { + new (UserHeaderPtr) std::atomic<uint64_t>(0); + }) + .moveInto(DataPool)) + return std::move(E); + if (DataPool->getUserHeader().size() != UserHeaderSize) + return createStringError(llvm::errc::argument_out_of_domain, + "unexpected user header in '" + DataPoolPath + + "'"); + + return std::unique_ptr<OnDiskGraphDB>( + new OnDiskGraphDB(AbsPath, std::move(*Index), std::move(*DataPool), + std::move(UpstreamDB), Policy)); +} + +OnDiskGraphDB::OnDiskGraphDB(StringRef RootPath, OnDiskTrieRawHashMap Index, + OnDiskDataAllocator DataPool, + std::unique_ptr<OnDiskGraphDB> UpstreamDB, + FaultInPolicy Policy) + : Index(std::move(Index)), DataPool(std::move(DataPool)), + RootPath(RootPath.str()), UpstreamDB(std::move(UpstreamDB)), + FIPolicy(Policy) { + /// Lifetime for "big" objects not in DataPool. + /// + /// NOTE: Could use ThreadSafeTrieRawHashMap here. 
For now, doing something
+ /// simpler on the assumption there won't be much contention since most data
+ /// is not big. If there is contention, and we've already fixed ObjectProxy
+ /// object handles to be cheap enough to use consistently, the fix might be
+ /// to make better use of them rather than optimizing this map.
+ ///
+ /// FIXME: Figure out the right number of shards, if any.
+ StandaloneData = new StandaloneDataMapTy();
+}
+
+OnDiskGraphDB::~OnDiskGraphDB() {
+ delete static_cast<StandaloneDataMapTy *>(StandaloneData);
+}
+
+Error OnDiskGraphDB::importFullTree(ObjectID PrimaryID,
+ ObjectHandle UpstreamNode) {
+ // Copies the full CAS tree from upstream. Uses depth-first copying to protect
+ // against the process dying during importing and leaving the database with an
+ // incomplete tree. Note that if the upstream has missing nodes then the tree
+ // will be copied with missing nodes as well; this is not considered an error.
+
+ struct UpstreamCursor {
+ ObjectHandle Node;
+ size_t RefsCount;
+ object_refs_iterator RefI;
+ object_refs_iterator RefE;
+ };
+ /// Keeps track of the state of visitation for the current node and all of
+ /// its parents.
+ SmallVector<UpstreamCursor, 16> CursorStack;
+ /// Keeps track of the currently visited nodes as they are imported into the
+ /// primary database, from the current node and its parents. When a node is
+ /// entered for visitation it appends its own ID, then appends referenced IDs
+ /// as they get imported. When a node is fully imported it removes the
+ /// referenced IDs from the top of the stack, which leaves its own ID on top,
+ /// adding to the list of referenced IDs for the parent node.
+ SmallVector<ObjectID, 128> PrimaryNodesStack;
+
+ auto enqueueNode = [&](ObjectID PrimaryID, std::optional<ObjectHandle> Node) {
+ PrimaryNodesStack.push_back(PrimaryID);
+ if (!Node)
+ return;
+ auto Refs = UpstreamDB->getObjectRefs(*Node);
+ CursorStack.push_back({*Node,
+ (size_t)std::distance(Refs.begin(), Refs.end()),
+ Refs.begin(), Refs.end()});
+ };
+
+ enqueueNode(PrimaryID, UpstreamNode);
+
+ while (!CursorStack.empty()) {
+ UpstreamCursor &Cur = CursorStack.back();
+ if (Cur.RefI == Cur.RefE) {
+ // Copy the node data into the primary store.
+ // FIXME: Use hard-link or cloning if the file-system supports it and data
+ // is stored into a separate file.
+
+ // The top of \p PrimaryNodesStack contains the primary ID for the
+ // current node plus the list of imported referenced IDs.
+ assert(PrimaryNodesStack.size() >= Cur.RefsCount + 1);
+ ObjectID PrimaryID = *(PrimaryNodesStack.end() - Cur.RefsCount - 1);
+ auto PrimaryRefs = ArrayRef(PrimaryNodesStack)
+ .slice(PrimaryNodesStack.size() - Cur.RefsCount);
+ auto Data = UpstreamDB->getObjectData(Cur.Node);
+ if (Error E = store(PrimaryID, PrimaryRefs, Data))
+ return E;
+ // Remove the imported referenced IDs from the stack; the current node's
+ // own ID stays behind as a reference for its parent.
+ PrimaryNodesStack.truncate(PrimaryNodesStack.size() - Cur.RefsCount);
+ CursorStack.pop_back();
+ continue;
+ }
+
+ ObjectID UpstreamID = *(Cur.RefI++);
+ auto PrimaryID = getReference(UpstreamDB->getDigest(UpstreamID));
+ if (LLVM_UNLIKELY(!PrimaryID))
+ return PrimaryID.takeError();
+ if (containsObject(*PrimaryID, /*CheckUpstream=*/false)) {
+ // This \p ObjectID already exists in the primary. Either it was imported
+ // via \p importFullTree or the client created it, in which case the
+ // client takes responsibility for how it was formed.
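+ // Passing std::nullopt pushes the ID as an already-finished reference
+ // without opening an upstream cursor for its refs.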
+ enqueueNode(*PrimaryID, std::nullopt);
+ continue;
+ }
+ Expected<std::optional<ObjectHandle>> UpstreamNode =
+ UpstreamDB->load(UpstreamID);
+ if (!UpstreamNode)
+ return UpstreamNode.takeError();
+ enqueueNode(*PrimaryID, *UpstreamNode);
+ }
+
+ assert(PrimaryNodesStack.size() == 1);
+ assert(PrimaryNodesStack.front() == PrimaryID);
+ return Error::success();
+}
+
+Error OnDiskGraphDB::importSingleNode(ObjectID PrimaryID,
+ ObjectHandle UpstreamNode) {
+ // Copies only a single node; it doesn't copy the referenced nodes.
+
+ // Copy the node data into the primary store.
+ // FIXME: Use hard-link or cloning if the file-system supports it and data is
+ // stored into a separate file.
+
+ auto Data = UpstreamDB->getObjectData(UpstreamNode);
+ auto UpstreamRefs = UpstreamDB->getObjectRefs(UpstreamNode);
+ SmallVector<ObjectID, 64> Refs;
+ Refs.reserve(std::distance(UpstreamRefs.begin(), UpstreamRefs.end()));
+ for (ObjectID UpstreamRef : UpstreamRefs) {
+ auto Ref = getReference(UpstreamDB->getDigest(UpstreamRef));
+ if (LLVM_UNLIKELY(!Ref))
+ return Ref.takeError();
+ Refs.push_back(*Ref);
+ }
+
+ return store(PrimaryID, Refs, Data);
+}
+
+Expected<std::optional<ObjectHandle>>
+OnDiskGraphDB::faultInFromUpstream(ObjectID PrimaryID) {
+ assert(UpstreamDB);
+
+ auto UpstreamID = UpstreamDB->getReference(getDigest(PrimaryID));
+ if (LLVM_UNLIKELY(!UpstreamID))
+ return UpstreamID.takeError();
+
+ Expected<std::optional<ObjectHandle>> UpstreamNode =
+ UpstreamDB->load(*UpstreamID);
+ if (!UpstreamNode)
+ return UpstreamNode.takeError();
+ if (!*UpstreamNode)
+ return std::nullopt;
+
+ if (Error E = FIPolicy == FaultInPolicy::SingleNode
+ ? importSingleNode(PrimaryID, **UpstreamNode)
+ : importFullTree(PrimaryID, **UpstreamNode))
+ return std::move(E);
+ return load(PrimaryID);
+}
diff --git a/llvm/lib/CAS/OnDiskKeyValueDB.cpp b/llvm/lib/CAS/OnDiskKeyValueDB.cpp
new file mode 100644
index 0000000..2186071
--- /dev/null
+++ b/llvm/lib/CAS/OnDiskKeyValueDB.cpp
@@ -0,0 +1,113 @@
+//===- OnDiskKeyValueDB.cpp -------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This file implements OnDiskKeyValueDB, an on-disk key-value database.
+///
+/// The KeyValue database file is named `actions.<version>` inside the CAS
+/// directory. The database stores a mapping between a fixed-sized key and a
+/// fixed-sized value, where the size of key and value can be configured when
+/// opening the database.
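+///
+/// For example (an illustrative sketch, not taken from a caller in this
+/// patch), a cache mapping 32-byte hashes to 8-byte values could be opened
+/// as:
+///
+/// \code
+///   auto DB = OnDiskKeyValueDB::open(Path, "blake3", /*KeySize=*/32,
+///                                    "objectid", /*ValueSize=*/8);
+/// \endcode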
+/// +// +//===----------------------------------------------------------------------===// + +#include "llvm/CAS/OnDiskKeyValueDB.h" +#include "OnDiskCommon.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Alignment.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Errc.h" +#include "llvm/Support/Path.h" + +using namespace llvm; +using namespace llvm::cas; +using namespace llvm::cas::ondisk; + +static constexpr StringLiteral ActionCacheFile = "actions."; + +Expected<ArrayRef<char>> OnDiskKeyValueDB::put(ArrayRef<uint8_t> Key, + ArrayRef<char> Value) { + if (LLVM_UNLIKELY(Value.size() != ValueSize)) + return createStringError(errc::invalid_argument, + "expected value size of " + itostr(ValueSize) + + ", got: " + itostr(Value.size())); + assert(Value.size() == ValueSize); + auto ActionP = Cache.insertLazy( + Key, [&](FileOffset TentativeOffset, + OnDiskTrieRawHashMap::ValueProxy TentativeValue) { + assert(TentativeValue.Data.size() == ValueSize); + llvm::copy(Value, TentativeValue.Data.data()); + }); + if (LLVM_UNLIKELY(!ActionP)) + return ActionP.takeError(); + return (*ActionP)->Data; +} + +Expected<std::optional<ArrayRef<char>>> +OnDiskKeyValueDB::get(ArrayRef<uint8_t> Key) { + // Check the result cache. + OnDiskTrieRawHashMap::ConstOnDiskPtr ActionP = Cache.find(Key); + if (!ActionP) + return std::nullopt; + assert(isAddrAligned(Align(8), ActionP->Data.data())); + return ActionP->Data; +} + +Expected<std::unique_ptr<OnDiskKeyValueDB>> +OnDiskKeyValueDB::open(StringRef Path, StringRef HashName, unsigned KeySize, + StringRef ValueName, size_t ValueSize) { + if (std::error_code EC = sys::fs::create_directories(Path)) + return createFileError(Path, EC); + + SmallString<256> CachePath(Path); + sys::path::append(CachePath, ActionCacheFile + CASFormatVersion); + constexpr uint64_t MB = 1024ull * 1024ull; + constexpr uint64_t GB = 1024ull * 1024ull * 1024ull; + + uint64_t MaxFileSize = GB; + auto CustomSize = getOverriddenMaxMappingSize(); + if (!CustomSize) + return CustomSize.takeError(); + if (*CustomSize) + MaxFileSize = **CustomSize; + + std::optional<OnDiskTrieRawHashMap> ActionCache; + if (Error E = OnDiskTrieRawHashMap::create( + CachePath, + "llvm.actioncache[" + HashName + "->" + ValueName + "]", + KeySize * 8, + /*DataSize=*/ValueSize, MaxFileSize, /*MinFileSize=*/MB) + .moveInto(ActionCache)) + return std::move(E); + + return std::unique_ptr<OnDiskKeyValueDB>( + new OnDiskKeyValueDB(ValueSize, std::move(*ActionCache))); +} + +Error OnDiskKeyValueDB::validate(CheckValueT CheckValue) const { + return Cache.validate( + [&](FileOffset Offset, + OnDiskTrieRawHashMap::ConstValueProxy Record) -> Error { + auto formatError = [&](Twine Msg) { + return createStringError( + llvm::errc::illegal_byte_sequence, + "bad cache value at 0x" + + utohexstr((unsigned)Offset.get(), /*LowerCase=*/true) + ": " + + Msg.str()); + }; + + if (Record.Data.size() != ValueSize) + return formatError("wrong cache value size"); + if (!isAddrAligned(Align(8), Record.Data.data())) + return formatError("wrong cache value alignment"); + if (CheckValue) + return CheckValue(Offset, Record.Data); + return Error::success(); + }); +} diff --git a/llvm/lib/CodeGen/ExpandFp.cpp b/llvm/lib/CodeGen/ExpandFp.cpp index 04c7008..2b5ced3 100644 --- a/llvm/lib/CodeGen/ExpandFp.cpp +++ b/llvm/lib/CodeGen/ExpandFp.cpp @@ -993,7 +993,6 @@ static void addToWorklist(Instruction &I, static bool runImpl(Function &F, const TargetLowering &TLI, AssumptionCache *AC) { SmallVector<Instruction *, 4> Worklist; - bool 
Modified = false; unsigned MaxLegalFpConvertBitWidth = TLI.getMaxLargeFPConvertBitWidthSupported(); @@ -1003,50 +1002,49 @@ static bool runImpl(Function &F, const TargetLowering &TLI, if (MaxLegalFpConvertBitWidth >= llvm::IntegerType::MAX_INT_BITS) return false; - for (auto It = inst_begin(&F), End = inst_end(F); It != End;) { - Instruction &I = *It++; + auto ShouldHandleInst = [&](Instruction &I) { Type *Ty = I.getType(); // TODO: This pass doesn't handle scalable vectors. if (Ty->isScalableTy()) - continue; + return false; switch (I.getOpcode()) { case Instruction::FRem: - if (!targetSupportsFrem(TLI, Ty) && - FRemExpander::canExpandType(Ty->getScalarType())) { - addToWorklist(I, Worklist); - Modified = true; - } - break; + return !targetSupportsFrem(TLI, Ty) && + FRemExpander::canExpandType(Ty->getScalarType()); + case Instruction::FPToUI: case Instruction::FPToSI: { auto *IntTy = cast<IntegerType>(Ty->getScalarType()); - if (IntTy->getIntegerBitWidth() <= MaxLegalFpConvertBitWidth) - continue; - - addToWorklist(I, Worklist); - Modified = true; - break; + return IntTy->getIntegerBitWidth() > MaxLegalFpConvertBitWidth; } + case Instruction::UIToFP: case Instruction::SIToFP: { auto *IntTy = cast<IntegerType>(I.getOperand(0)->getType()->getScalarType()); - if (IntTy->getIntegerBitWidth() <= MaxLegalFpConvertBitWidth) - continue; - - addToWorklist(I, Worklist); - Modified = true; - break; + return IntTy->getIntegerBitWidth() > MaxLegalFpConvertBitWidth; } - default: - break; } + + return false; + }; + + bool Modified = false; + for (auto It = inst_begin(&F), End = inst_end(F); It != End;) { + Instruction &I = *It++; + if (!ShouldHandleInst(I)) + continue; + + addToWorklist(I, Worklist); + Modified = true; } while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (I->getOpcode() == Instruction::FRem) { + + switch (I->getOpcode()) { + case Instruction::FRem: { auto SQ = [&]() -> std::optional<SimplifyQuery> { if (AC) { auto Res = std::make_optional<SimplifyQuery>( @@ -1058,11 +1056,18 @@ static bool runImpl(Function &F, const TargetLowering &TLI, }(); expandFRem(cast<BinaryOperator>(*I), SQ); - } else if (I->getOpcode() == Instruction::FPToUI || - I->getOpcode() == Instruction::FPToSI) { + break; + } + + case Instruction::FPToUI: + case Instruction::FPToSI: expandFPToI(I); - } else { + break; + + case Instruction::UIToFP: + case Instruction::SIToFP: expandIToFP(I); + break; } } diff --git a/llvm/lib/CodeGen/InterleavedAccessPass.cpp b/llvm/lib/CodeGen/InterleavedAccessPass.cpp index a6a9b50..5c27a20 100644 --- a/llvm/lib/CodeGen/InterleavedAccessPass.cpp +++ b/llvm/lib/CodeGen/InterleavedAccessPass.cpp @@ -258,13 +258,11 @@ static Value *getMaskOperand(IntrinsicInst *II) { default: llvm_unreachable("Unexpected intrinsic"); case Intrinsic::vp_load: - return II->getOperand(1); case Intrinsic::masked_load: - return II->getOperand(2); + return II->getOperand(1); case Intrinsic::vp_store: - return II->getOperand(2); case Intrinsic::masked_store: - return II->getOperand(3); + return II->getOperand(2); } } diff --git a/llvm/lib/CodeGen/MIRFSDiscriminator.cpp b/llvm/lib/CodeGen/MIRFSDiscriminator.cpp index d988a2a..e37f784 100644 --- a/llvm/lib/CodeGen/MIRFSDiscriminator.cpp +++ b/llvm/lib/CodeGen/MIRFSDiscriminator.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/CodeGen/MIRFSDiscriminatorOptions.h" #include "llvm/CodeGen/Passes.h" #include "llvm/IR/DebugInfoMetadata.h" #include 
"llvm/IR/Function.h" @@ -35,13 +36,10 @@ using namespace sampleprofutil; // TODO(xur): Remove this option and related code once we make true as the // default. -namespace llvm { -cl::opt<bool> ImprovedFSDiscriminator( +cl::opt<bool> llvm::ImprovedFSDiscriminator( "improved-fs-discriminator", cl::Hidden, cl::init(false), cl::desc("New FS discriminators encoding (incompatible with the original " "encoding)")); -} // namespace llvm - char MIRAddFSDiscriminators::ID = 0; INITIALIZE_PASS(MIRAddFSDiscriminators, DEBUG_TYPE, diff --git a/llvm/lib/CodeGen/MIRSampleProfile.cpp b/llvm/lib/CodeGen/MIRSampleProfile.cpp index 9bba50e8..d44f577 100644 --- a/llvm/lib/CodeGen/MIRSampleProfile.cpp +++ b/llvm/lib/CodeGen/MIRSampleProfile.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/CodeGen/MIRFSDiscriminatorOptions.h" #include "llvm/CodeGen/MachineBlockFrequencyInfo.h" #include "llvm/CodeGen/MachineBranchProbabilityInfo.h" #include "llvm/CodeGen/MachineDominators.h" @@ -62,9 +63,6 @@ static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden, cl::init(false), cl::desc("View BFI after MIR loader")); -namespace llvm { -extern cl::opt<bool> ImprovedFSDiscriminator; -} char MIRProfileLoaderPass::ID = 0; INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE, diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index c97300d..6bf9008 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -26876,6 +26876,8 @@ static SDValue combineTruncationShuffle(ShuffleVectorSDNode *SVN, // TODO: handle more extension/truncation cases as cases arise. if (EltSizeInBits != ExtSrcSizeInBits) return SDValue(); + if (VT.getSizeInBits() != N00.getValueSizeInBits()) + return SDValue(); // We can remove *extend_vector_inreg only if the truncation happens at // the same scale as the extension. 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index cb0038c..20a0efd 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -4837,29 +4837,10 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I, bool IsCompressing) { SDLoc sdl = getCurSDLoc(); - auto getMaskedStoreOps = [&](Value *&Ptr, Value *&Mask, Value *&Src0, - Align &Alignment) { - // llvm.masked.store.*(Src0, Ptr, alignment, Mask) - Src0 = I.getArgOperand(0); - Ptr = I.getArgOperand(1); - Alignment = cast<ConstantInt>(I.getArgOperand(2))->getAlignValue(); - Mask = I.getArgOperand(3); - }; - auto getCompressingStoreOps = [&](Value *&Ptr, Value *&Mask, Value *&Src0, - Align &Alignment) { - // llvm.masked.compressstore.*(Src0, Ptr, Mask) - Src0 = I.getArgOperand(0); - Ptr = I.getArgOperand(1); - Mask = I.getArgOperand(2); - Alignment = I.getParamAlign(1).valueOrOne(); - }; - - Value *PtrOperand, *MaskOperand, *Src0Operand; - Align Alignment; - if (IsCompressing) - getCompressingStoreOps(PtrOperand, MaskOperand, Src0Operand, Alignment); - else - getMaskedStoreOps(PtrOperand, MaskOperand, Src0Operand, Alignment); + Value *Src0Operand = I.getArgOperand(0); + Value *PtrOperand = I.getArgOperand(1); + Value *MaskOperand = I.getArgOperand(2); + Align Alignment = I.getParamAlign(1).valueOrOne(); SDValue Ptr = getValue(PtrOperand); SDValue Src0 = getValue(Src0Operand); @@ -4964,14 +4945,12 @@ static bool getUniformBase(const Value *Ptr, SDValue &Base, SDValue &Index, void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) { SDLoc sdl = getCurSDLoc(); - // llvm.masked.scatter.*(Src0, Ptrs, alignment, Mask) + // llvm.masked.scatter.*(Src0, Ptrs, Mask) const Value *Ptr = I.getArgOperand(1); SDValue Src0 = getValue(I.getArgOperand(0)); - SDValue Mask = getValue(I.getArgOperand(3)); + SDValue Mask = getValue(I.getArgOperand(2)); EVT VT = Src0.getValueType(); - Align Alignment = cast<ConstantInt>(I.getArgOperand(2)) - ->getMaybeAlignValue() - .value_or(DAG.getEVTAlign(VT.getScalarType())); + Align Alignment = I.getParamAlign(1).valueOrOne(); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); SDValue Base; @@ -5008,29 +4987,10 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) { void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) { SDLoc sdl = getCurSDLoc(); - auto getMaskedLoadOps = [&](Value *&Ptr, Value *&Mask, Value *&Src0, - Align &Alignment) { - // @llvm.masked.load.*(Ptr, alignment, Mask, Src0) - Ptr = I.getArgOperand(0); - Alignment = cast<ConstantInt>(I.getArgOperand(1))->getAlignValue(); - Mask = I.getArgOperand(2); - Src0 = I.getArgOperand(3); - }; - auto getExpandingLoadOps = [&](Value *&Ptr, Value *&Mask, Value *&Src0, - Align &Alignment) { - // @llvm.masked.expandload.*(Ptr, Mask, Src0) - Ptr = I.getArgOperand(0); - Alignment = I.getParamAlign(0).valueOrOne(); - Mask = I.getArgOperand(1); - Src0 = I.getArgOperand(2); - }; - - Value *PtrOperand, *MaskOperand, *Src0Operand; - Align Alignment; - if (IsExpanding) - getExpandingLoadOps(PtrOperand, MaskOperand, Src0Operand, Alignment); - else - getMaskedLoadOps(PtrOperand, MaskOperand, Src0Operand, Alignment); + Value *PtrOperand = I.getArgOperand(0); + Value *MaskOperand = I.getArgOperand(1); + Value *Src0Operand = I.getArgOperand(2); + Align Alignment = I.getParamAlign(0).valueOrOne(); SDValue Ptr = getValue(PtrOperand); SDValue Src0 = getValue(Src0Operand); @@ -5077,16 
+5037,14 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) { void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) { SDLoc sdl = getCurSDLoc(); - // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0) + // @llvm.masked.gather.*(Ptrs, Mask, Src0) const Value *Ptr = I.getArgOperand(0); - SDValue Src0 = getValue(I.getArgOperand(3)); - SDValue Mask = getValue(I.getArgOperand(2)); + SDValue Src0 = getValue(I.getArgOperand(2)); + SDValue Mask = getValue(I.getArgOperand(1)); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType()); - Align Alignment = cast<ConstantInt>(I.getArgOperand(1)) - ->getMaybeAlignValue() - .value_or(DAG.getEVTAlign(VT.getScalarType())); + Align Alignment = I.getParamAlign(0).valueOrOne(); const MDNode *Ranges = getRangeMetadata(I); diff --git a/llvm/lib/DWARFLinker/Classic/DWARFLinker.cpp b/llvm/lib/DWARFLinker/Classic/DWARFLinker.cpp index 8052773..8637b55 100644 --- a/llvm/lib/DWARFLinker/Classic/DWARFLinker.cpp +++ b/llvm/lib/DWARFLinker/Classic/DWARFLinker.cpp @@ -2427,11 +2427,13 @@ void DWARFLinker::DIECloner::generateLineTableForUnit(CompileUnit &Unit) { uint64_t OrigStmtSeq = StmtSeq.get(); // 1. Get the original row index from the stmt list offset. auto OrigRowIter = SeqOffToOrigRow.find(OrigStmtSeq); + const uint64_t InvalidOffset = + Unit.getOrigUnit().getFormParams().getDwarfMaxOffset(); // Check whether we have an output sequence for the StmtSeq offset. // Some sequences are discarded by the DWARFLinker if they are invalid // (empty). if (OrigRowIter == SeqOffToOrigRow.end()) { - StmtSeq.set(UINT64_MAX); + StmtSeq.set(InvalidOffset); continue; } size_t OrigRowIndex = OrigRowIter->second; @@ -2441,7 +2443,7 @@ void DWARFLinker::DIECloner::generateLineTableForUnit(CompileUnit &Unit) { if (NewRowIter == OrigRowToNewRow.end()) { // If the original row index is not found in the map, update the // stmt_sequence attribute to the 'invalid offset' magic value. - StmtSeq.set(UINT64_MAX); + StmtSeq.set(InvalidOffset); continue; } diff --git a/llvm/lib/DebugInfo/GSYM/DwarfTransformer.cpp b/llvm/lib/DebugInfo/GSYM/DwarfTransformer.cpp index fa39603..a326a01 100644 --- a/llvm/lib/DebugInfo/GSYM/DwarfTransformer.cpp +++ b/llvm/lib/DebugInfo/GSYM/DwarfTransformer.cpp @@ -320,12 +320,16 @@ static void convertFunctionLineTable(OutputAggregator &Out, CUInfo &CUI, // Attempt to retrieve DW_AT_LLVM_stmt_sequence if present. std::optional<uint64_t> StmtSeqOffset; if (auto StmtSeqAttr = Die.find(llvm::dwarf::DW_AT_LLVM_stmt_sequence)) { - // The `DW_AT_LLVM_stmt_sequence` attribute might be set to `UINT64_MAX` - // when it refers to an empty line sequence. In such cases, the DWARF linker - // will exclude the empty sequence from the final output and assign - // `UINT64_MAX` to the `DW_AT_LLVM_stmt_sequence` attribute. - uint64_t StmtSeqVal = dwarf::toSectionOffset(StmtSeqAttr, UINT64_MAX); - if (StmtSeqVal != UINT64_MAX) + // The `DW_AT_LLVM_stmt_sequence` attribute might be set to an invalid + // sentinel value when it refers to an empty line sequence. In such cases, + // the DWARF linker will exclude the empty sequence from the final output + // and assign the sentinel value to the `DW_AT_LLVM_stmt_sequence` + // attribute. The sentinel value is UINT32_MAX for DWARF32 and UINT64_MAX + // for DWARF64. 
+ const uint64_t InvalidOffset = + Die.getDwarfUnit()->getFormParams().getDwarfMaxOffset(); + uint64_t StmtSeqVal = dwarf::toSectionOffset(StmtSeqAttr, InvalidOffset); + if (StmtSeqVal != InvalidOffset) StmtSeqOffset = StmtSeqVal; } diff --git a/llvm/lib/ExecutionEngine/Orc/Core.cpp b/llvm/lib/ExecutionEngine/Orc/Core.cpp index f47b7ec..8d413a3 100644 --- a/llvm/lib/ExecutionEngine/Orc/Core.cpp +++ b/llvm/lib/ExecutionEngine/Orc/Core.cpp @@ -1173,39 +1173,7 @@ void JITDylib::dump(raw_ostream &OS) { << " pending queries: { "; for (const auto &Q : KV.second.pendingQueries()) OS << Q.get() << " (" << Q->getRequiredState() << ") "; - OS << "}\n Defining EDU: "; - if (KV.second.DefiningEDU) { - OS << KV.second.DefiningEDU.get() << " { "; - for (auto &[Name, Flags] : KV.second.DefiningEDU->Symbols) - OS << Name << " "; - OS << "}\n"; - OS << " Dependencies:\n"; - if (!KV.second.DefiningEDU->Dependencies.empty()) { - for (auto &[DepJD, Deps] : KV.second.DefiningEDU->Dependencies) { - OS << " " << DepJD->getName() << ": [ "; - for (auto &Dep : Deps) - OS << Dep << " "; - OS << "]\n"; - } - } else - OS << " none\n"; - } else - OS << "none\n"; - OS << " Dependant EDUs:\n"; - if (!KV.second.DependantEDUs.empty()) { - for (auto &DependantEDU : KV.second.DependantEDUs) { - OS << " " << DependantEDU << ": " - << DependantEDU->JD->getName() << " { "; - for (auto &[Name, Flags] : DependantEDU->Symbols) - OS << Name << " "; - OS << "}\n"; - } - } else - OS << " none\n"; - assert((Symbols[KV.first].getState() != SymbolState::Ready || - (KV.second.pendingQueries().empty() && !KV.second.DefiningEDU && - !KV.second.DependantEDUs.empty())) && - "Stale materializing info entry"); + OS << "}\n"; } }); } @@ -1967,9 +1935,6 @@ bool ExecutionSession::verifySessionState(Twine Phase) { return runSessionLocked([&]() { bool AllOk = true; - // We'll collect these and verify them later to avoid redundant checks. - DenseSet<JITDylib::EmissionDepUnit *> EDUsToCheck; - for (auto &JD : JDs) { auto LogFailure = [&]() -> raw_fd_ostream & { @@ -2063,86 +2028,6 @@ bool ExecutionSession::verifySessionState(Twine Phase) { << " has stale or misordered queries.\n"; } } - - // If there's a DefiningEDU then check that... - // 1. The JD matches. - // 2. The symbol is in the EDU's Symbols map. - // 3. The symbol table entry is in the Emitted state. - if (MII.DefiningEDU) { - - EDUsToCheck.insert(MII.DefiningEDU.get()); - - if (MII.DefiningEDU->JD != JD.get()) { - LogFailure() << "symbol " << Sym - << " has DefiningEDU with incorrect JD" - << (llvm::is_contained(JDs, MII.DefiningEDU->JD) - ? " (JD not currently in ExecutionSession" - : "") - << "\n"; - } - - if (SymItr->second.getState() != SymbolState::Emitted) { - LogFailure() - << "symbol " << Sym - << " has DefiningEDU, but is not in Emitted state.\n"; - } - } - - // Check that JDs for any DependantEDUs are also in the session -- - // that guarantees that we'll also visit them during this loop. - for (auto &DepEDU : MII.DependantEDUs) { - if (!llvm::is_contained(JDs, DepEDU->JD)) { - LogFailure() << "symbol " << Sym << " has DependantEDU " - << (void *)DepEDU << " with JD (" << DepEDU->JD - << ") that isn't in ExecutionSession.\n"; - } - } - } - } - } - - // Check EDUs. 
- for (auto *EDU : EDUsToCheck) { - assert(EDU->JD->State == JITDylib::Open && "EDU->JD is not Open"); - - auto LogFailure = [&]() -> raw_fd_ostream & { - AllOk = false; - auto &Stream = errs(); - Stream << "In EDU defining " << EDU->JD->getName() << ": { "; - for (auto &[Sym, Flags] : EDU->Symbols) - Stream << Sym << " "; - Stream << "}, "; - return Stream; - }; - - if (EDU->Symbols.empty()) - LogFailure() << "no symbols defined.\n"; - else { - for (auto &[Sym, Flags] : EDU->Symbols) { - if (!Sym) - LogFailure() << "null symbol defined.\n"; - else { - if (!EDU->JD->Symbols.count(SymbolStringPtr(Sym))) { - LogFailure() << "symbol " << Sym - << " isn't present in JD's symbol table.\n"; - } - } - } - } - - for (auto &[DepJD, Symbols] : EDU->Dependencies) { - if (!llvm::is_contained(JDs, DepJD)) { - LogFailure() << "dependant symbols listed for JD that isn't in " - "ExecutionSession.\n"; - } else { - for (auto &DepSym : Symbols) { - if (!DepJD->Symbols.count(SymbolStringPtr(DepSym))) { - LogFailure() - << "dependant symbol " << DepSym - << " does not appear in symbol table for dependant JD " - << DepJD->getName() << ".\n"; - } - } } } } @@ -2917,359 +2802,64 @@ Error ExecutionSession::OL_notifyResolved(MaterializationResponsibility &MR, return MR.JD.resolve(MR, Symbols); } -template <typename HandleNewDepFn> -void ExecutionSession::propagateExtraEmitDeps( - std::deque<JITDylib::EmissionDepUnit *> Worklist, EDUInfosMap &EDUInfos, - HandleNewDepFn HandleNewDep) { - - // Iterate to a fixed-point to propagate extra-emit dependencies through the - // EDU graph. - while (!Worklist.empty()) { - auto &EDU = *Worklist.front(); - Worklist.pop_front(); - - assert(EDUInfos.count(&EDU) && "No info entry for EDU"); - auto &EDUInfo = EDUInfos[&EDU]; - - // Propagate new dependencies to users. - for (auto *UserEDU : EDUInfo.IntraEmitUsers) { - - // UserEDUInfo only present if UserEDU has its own users. - JITDylib::EmissionDepUnitInfo *UserEDUInfo = nullptr; - { - auto UserEDUInfoItr = EDUInfos.find(UserEDU); - if (UserEDUInfoItr != EDUInfos.end()) - UserEDUInfo = &UserEDUInfoItr->second; - } - - for (auto &[DepJD, Deps] : EDUInfo.NewDeps) { - auto &UserEDUDepsForJD = UserEDU->Dependencies[DepJD]; - DenseSet<NonOwningSymbolStringPtr> *UserEDUNewDepsForJD = nullptr; - for (auto Dep : Deps) { - if (UserEDUDepsForJD.insert(Dep).second) { - HandleNewDep(*UserEDU, *DepJD, Dep); - if (UserEDUInfo) { - if (!UserEDUNewDepsForJD) { - // If UserEDU has no new deps then it's not in the worklist - // yet, so add it. - if (UserEDUInfo->NewDeps.empty()) - Worklist.push_back(UserEDU); - UserEDUNewDepsForJD = &UserEDUInfo->NewDeps[DepJD]; - } - // Add (DepJD, Dep) to NewDeps. - UserEDUNewDepsForJD->insert(Dep); - } - } +WaitingOnGraph::ExternalState +ExecutionSession::IL_getSymbolState(JITDylib *JD, + NonOwningSymbolStringPtr Name) { + if (JD->State != JITDylib::Open) + return WaitingOnGraph::ExternalState::Failed; + + auto I = JD->Symbols.find_as(Name); + + // FIXME: Can we eliminate this possibility if we support query binding? 
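+ // (A missing symbol table entry means the definition is gone or was never
+ // created, so a waiting node can never be satisfied; report it as failed.)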
+ if (I == JD->Symbols.end()) + return WaitingOnGraph::ExternalState::Failed; + + if (I->second.getFlags().hasError()) + return WaitingOnGraph::ExternalState::Failed; + + if (I->second.getState() == SymbolState::Ready) + return WaitingOnGraph::ExternalState::Ready; + + return WaitingOnGraph::ExternalState::None; +} + +template <typename UpdateSymbolFn, typename UpdateQueryFn> +void ExecutionSession::IL_collectQueries( + JITDylib::AsynchronousSymbolQuerySet &Qs, + WaitingOnGraph::ContainerElementsMap &QualifiedSymbols, + UpdateSymbolFn &&UpdateSymbol, UpdateQueryFn &&UpdateQuery) { + + for (auto &[JD, Symbols] : QualifiedSymbols) { + // IL_emit and JITDylib removal are synchronized by the session lock. + // Since JITDylib removal removes any contained nodes from the + // WaitingOnGraph, we should be able to assert that all nodes in the + // WaitingOnGraph have not been removed. + assert(JD->State == JITDylib::Open && + "WaitingOnGraph includes definition in defunct JITDylib"); + for (auto &Symbol : Symbols) { + // Update symbol table. + auto I = JD->Symbols.find_as(Symbol); + assert(I != JD->Symbols.end() && + "Failed Symbol missing from JD symbol table"); + auto &Entry = I->second; + UpdateSymbol(Entry); + + // Collect queries. + auto J = JD->MaterializingInfos.find_as(Symbol); + if (J != JD->MaterializingInfos.end()) { + for (auto &Q : J->second.takeAllPendingQueries()) { + UpdateQuery(*Q, *JD, Symbol, Entry); + Qs.insert(std::move(Q)); } + JD->MaterializingInfos.erase(J); } } - - EDUInfo.NewDeps.clear(); - } -} - -// Note: This method modifies the emitted set. -ExecutionSession::EDUInfosMap ExecutionSession::simplifyDepGroups( - MaterializationResponsibility &MR, - ArrayRef<SymbolDependenceGroup> EmittedDeps) { - - auto &TargetJD = MR.getTargetJITDylib(); - - // 1. Build initial EmissionDepUnit -> EmissionDepUnitInfo and - // Symbol -> EmissionDepUnit mappings. - DenseMap<JITDylib::EmissionDepUnit *, JITDylib::EmissionDepUnitInfo> EDUInfos; - EDUInfos.reserve(EmittedDeps.size()); - DenseMap<NonOwningSymbolStringPtr, JITDylib::EmissionDepUnit *> EDUForSymbol; - for (auto &DG : EmittedDeps) { - assert(!DG.Symbols.empty() && "DepGroup does not cover any symbols"); - - // Skip empty EDUs. - if (DG.Dependencies.empty()) - continue; - - auto TmpEDU = std::make_shared<JITDylib::EmissionDepUnit>(TargetJD); - auto &EDUInfo = EDUInfos[TmpEDU.get()]; - EDUInfo.EDU = std::move(TmpEDU); - for (const auto &Symbol : DG.Symbols) { - NonOwningSymbolStringPtr NonOwningSymbol(Symbol); - assert(!EDUForSymbol.count(NonOwningSymbol) && - "Symbol should not appear in more than one SymbolDependenceGroup"); - assert(MR.getSymbols().count(Symbol) && - "Symbol in DepGroups not in the emitted set"); - auto NewlyEmittedItr = MR.getSymbols().find(Symbol); - EDUInfo.EDU->Symbols[NonOwningSymbol] = NewlyEmittedItr->second; - EDUForSymbol[NonOwningSymbol] = EDUInfo.EDU.get(); - } - } - - // 2. Build a "residual" EDU to cover all symbols that have no dependencies. 
- { - DenseMap<NonOwningSymbolStringPtr, JITSymbolFlags> ResidualSymbolFlags; - for (auto &[Sym, Flags] : MR.getSymbols()) { - if (!EDUForSymbol.count(NonOwningSymbolStringPtr(Sym))) - ResidualSymbolFlags[NonOwningSymbolStringPtr(Sym)] = Flags; - } - if (!ResidualSymbolFlags.empty()) { - auto ResidualEDU = std::make_shared<JITDylib::EmissionDepUnit>(TargetJD); - ResidualEDU->Symbols = std::move(ResidualSymbolFlags); - auto &ResidualEDUInfo = EDUInfos[ResidualEDU.get()]; - ResidualEDUInfo.EDU = std::move(ResidualEDU); - - // If the residual EDU is the only one then bail out early. - if (EDUInfos.size() == 1) - return EDUInfos; - - // Otherwise add the residual EDU to the EDUForSymbol map. - for (auto &[Sym, Flags] : ResidualEDUInfo.EDU->Symbols) - EDUForSymbol[Sym] = ResidualEDUInfo.EDU.get(); - } - } - -#ifndef NDEBUG - assert(EDUForSymbol.size() == MR.getSymbols().size() && - "MR symbols not fully covered by EDUs?"); - for (auto &[Sym, Flags] : MR.getSymbols()) { - assert(EDUForSymbol.count(NonOwningSymbolStringPtr(Sym)) && - "Sym in MR not covered by EDU"); - } -#endif // NDEBUG - - // 3. Use the DepGroups array to build a graph of dependencies between - // EmissionDepUnits in this finalization. We want to remove these - // intra-finalization uses, propagating dependencies on symbols outside - // this finalization. Add EDUs to the worklist. - for (auto &DG : EmittedDeps) { - - // Skip SymbolDependenceGroups with no dependencies. - if (DG.Dependencies.empty()) - continue; - - assert(EDUForSymbol.count(NonOwningSymbolStringPtr(*DG.Symbols.begin())) && - "No EDU for DG"); - auto &EDU = - *EDUForSymbol.find(NonOwningSymbolStringPtr(*DG.Symbols.begin())) - ->second; - - for (auto &[DepJD, Deps] : DG.Dependencies) { - DenseSet<NonOwningSymbolStringPtr> NewDepsForJD; - - assert(!Deps.empty() && "Dependence set for DepJD is empty"); - - if (DepJD != &TargetJD) { - // DepJD is some other JITDylib.There can't be any intra-finalization - // edges here, so just skip. - for (auto &Dep : Deps) - NewDepsForJD.insert(NonOwningSymbolStringPtr(Dep)); - } else { - // DepJD is the Target JITDylib. Check for intra-finaliztaion edges, - // skipping any and recording the intra-finalization use instead. - for (auto &Dep : Deps) { - NonOwningSymbolStringPtr NonOwningDep(Dep); - auto I = EDUForSymbol.find(NonOwningDep); - if (I == EDUForSymbol.end()) { - if (!MR.getSymbols().count(Dep)) - NewDepsForJD.insert(NonOwningDep); - continue; - } - - if (I->second != &EDU) - EDUInfos[I->second].IntraEmitUsers.insert(&EDU); - } - } - - if (!NewDepsForJD.empty()) - EDU.Dependencies[DepJD] = std::move(NewDepsForJD); - } - } - - // 4. Build the worklist. - std::deque<JITDylib::EmissionDepUnit *> Worklist; - for (auto &[EDU, EDUInfo] : EDUInfos) { - // If this EDU has extra-finalization dependencies and intra-finalization - // users then add it to the worklist. - if (!EDU->Dependencies.empty()) { - auto I = EDUInfos.find(EDU); - if (I != EDUInfos.end()) { - auto &EDUInfo = I->second; - if (!EDUInfo.IntraEmitUsers.empty()) { - EDUInfo.NewDeps = EDU->Dependencies; - Worklist.push_back(EDU); - } - } - } - } - - // 4. Propagate dependencies through the EDU graph. - propagateExtraEmitDeps( - Worklist, EDUInfos, - [](JITDylib::EmissionDepUnit &, JITDylib &, NonOwningSymbolStringPtr) {}); - - return EDUInfos; -} - -void ExecutionSession::IL_makeEDUReady( - std::shared_ptr<JITDylib::EmissionDepUnit> EDU, - JITDylib::AsynchronousSymbolQuerySet &Queries) { - - // The symbols for this EDU are ready. 
- auto &JD = *EDU->JD; - - for (auto &[Sym, Flags] : EDU->Symbols) { - assert(JD.Symbols.count(SymbolStringPtr(Sym)) && - "JD does not have an entry for Sym"); - auto &Entry = JD.Symbols[SymbolStringPtr(Sym)]; - - assert(((Entry.getFlags().hasMaterializationSideEffectsOnly() && - Entry.getState() == SymbolState::Materializing) || - Entry.getState() == SymbolState::Resolved || - Entry.getState() == SymbolState::Emitted) && - "Emitting from state other than Resolved"); - - Entry.setState(SymbolState::Ready); - - auto MII = JD.MaterializingInfos.find(SymbolStringPtr(Sym)); - - // Check for pending queries. - if (MII == JD.MaterializingInfos.end()) - continue; - auto &MI = MII->second; - - for (auto &Q : MI.takeQueriesMeeting(SymbolState::Ready)) { - Q->notifySymbolMetRequiredState(SymbolStringPtr(Sym), Entry.getSymbol()); - if (Q->isComplete()) - Queries.insert(Q); - Q->removeQueryDependence(JD, SymbolStringPtr(Sym)); - } - - JD.MaterializingInfos.erase(MII); - } - - JD.shrinkMaterializationInfoMemory(); -} - -void ExecutionSession::IL_makeEDUEmitted( - std::shared_ptr<JITDylib::EmissionDepUnit> EDU, - JITDylib::AsynchronousSymbolQuerySet &Queries) { - - // The symbols for this EDU are emitted, but not ready. - auto &JD = *EDU->JD; - - for (auto &[Sym, Flags] : EDU->Symbols) { - assert(JD.Symbols.count(SymbolStringPtr(Sym)) && - "JD does not have an entry for Sym"); - auto &Entry = JD.Symbols[SymbolStringPtr(Sym)]; - - assert(((Entry.getFlags().hasMaterializationSideEffectsOnly() && - Entry.getState() == SymbolState::Materializing) || - Entry.getState() == SymbolState::Resolved || - Entry.getState() == SymbolState::Emitted) && - "Emitting from state other than Resolved"); - - if (Entry.getState() == SymbolState::Emitted) { - // This was already emitted, so we can skip the rest of this loop. -#ifndef NDEBUG - for (auto &[Sym, Flags] : EDU->Symbols) { - assert(JD.Symbols.count(SymbolStringPtr(Sym)) && - "JD does not have an entry for Sym"); - auto &Entry = JD.Symbols[SymbolStringPtr(Sym)]; - assert(Entry.getState() == SymbolState::Emitted && - "Symbols for EDU in inconsistent state"); - assert(JD.MaterializingInfos.count(SymbolStringPtr(Sym)) && - "Emitted symbol has no MI"); - auto MI = JD.MaterializingInfos[SymbolStringPtr(Sym)]; - assert(MI.takeQueriesMeeting(SymbolState::Emitted).empty() && - "Already-emitted symbol has waiting-on-emitted queries"); - } -#endif // NDEBUG - break; - } - - Entry.setState(SymbolState::Emitted); - auto &MI = JD.MaterializingInfos[SymbolStringPtr(Sym)]; - MI.DefiningEDU = EDU; - - for (auto &Q : MI.takeQueriesMeeting(SymbolState::Emitted)) { - Q->notifySymbolMetRequiredState(SymbolStringPtr(Sym), Entry.getSymbol()); - if (Q->isComplete()) - Queries.insert(Q); - } } - - for (auto &[DepJD, Deps] : EDU->Dependencies) { - for (auto &Dep : Deps) - DepJD->MaterializingInfos[SymbolStringPtr(Dep)].DependantEDUs.insert( - EDU.get()); - } -} - -/// Removes the given dependence from EDU. If EDU's dependence set becomes -/// empty then this function adds an entry for it to the EDUInfos map. -/// Returns true if a new EDUInfosMap entry is added. 
-bool ExecutionSession::IL_removeEDUDependence(JITDylib::EmissionDepUnit &EDU, - JITDylib &DepJD, - NonOwningSymbolStringPtr DepSym, - EDUInfosMap &EDUInfos) { - assert(EDU.Dependencies.count(&DepJD) && - "JD does not appear in Dependencies of DependantEDU"); - assert(EDU.Dependencies[&DepJD].count(DepSym) && - "Symbol does not appear in Dependencies of DependantEDU"); - auto &JDDeps = EDU.Dependencies[&DepJD]; - JDDeps.erase(DepSym); - if (JDDeps.empty()) { - EDU.Dependencies.erase(&DepJD); - if (EDU.Dependencies.empty()) { - // If the dependencies set has become empty then EDU _may_ be ready - // (we won't know for sure until we've propagated the extra-emit deps). - // Create an EDUInfo for it (if it doesn't have one already) so that - // it'll be visited after propagation. - auto &DepEDUInfo = EDUInfos[&EDU]; - if (!DepEDUInfo.EDU) { - assert(EDU.JD->Symbols.count( - SymbolStringPtr(EDU.Symbols.begin()->first)) && - "Missing symbol entry for first symbol in EDU"); - auto DepEDUFirstMI = EDU.JD->MaterializingInfos.find( - SymbolStringPtr(EDU.Symbols.begin()->first)); - assert(DepEDUFirstMI != EDU.JD->MaterializingInfos.end() && - "Missing MI for first symbol in DependantEDU"); - DepEDUInfo.EDU = DepEDUFirstMI->second.DefiningEDU; - return true; - } - } - } - return false; } -Error ExecutionSession::makeJDClosedError(JITDylib::EmissionDepUnit &EDU, - JITDylib &ClosedJD) { - SymbolNameSet FailedSymbols; - for (auto &[Sym, Flags] : EDU.Symbols) - FailedSymbols.insert(SymbolStringPtr(Sym)); - SymbolDependenceMap BadDeps; - for (auto &Dep : EDU.Dependencies[&ClosedJD]) - BadDeps[&ClosedJD].insert(SymbolStringPtr(Dep)); - return make_error<UnsatisfiedSymbolDependencies>( - ClosedJD.getExecutionSession().getSymbolStringPool(), EDU.JD, - std::move(FailedSymbols), std::move(BadDeps), - ClosedJD.getName() + " is closed"); -} - -Error ExecutionSession::makeUnsatisfiedDepsError(JITDylib::EmissionDepUnit &EDU, - JITDylib &BadJD, - SymbolNameSet BadDeps) { - SymbolNameSet FailedSymbols; - for (auto &[Sym, Flags] : EDU.Symbols) - FailedSymbols.insert(SymbolStringPtr(Sym)); - SymbolDependenceMap BadDepsMap; - BadDepsMap[&BadJD] = std::move(BadDeps); - return make_error<UnsatisfiedSymbolDependencies>( - BadJD.getExecutionSession().getSymbolStringPool(), &BadJD, - std::move(FailedSymbols), std::move(BadDepsMap), - "dependencies removed or in error state"); -} - -Expected<JITDylib::AsynchronousSymbolQuerySet> +Expected<ExecutionSession::EmitQueries> ExecutionSession::IL_emit(MaterializationResponsibility &MR, - EDUInfosMap EDUInfos) { + WaitingOnGraph::SimplifyResult SR) { if (MR.RT->isDefunct()) return make_error<ResourceTrackerDefunct>(MR.RT); @@ -3279,169 +2869,50 @@ ExecutionSession::IL_emit(MaterializationResponsibility &MR, return make_error<StringError>("JITDylib " + TargetJD.getName() + " is defunct", inconvertibleErrorCode()); + #ifdef EXPENSIVE_CHECKS verifySessionState("entering ExecutionSession::IL_emit"); #endif - // Walk all EDUs: - // 1. Verifying that dependencies are available (not removed or in the error - // state. - // 2. Removing any dependencies that are already Ready. - // 3. Lifting any EDUs for Emitted symbols into the EDUInfos map. - // 4. Finding any dependant EDUs and lifting them into the EDUInfos map. 
- std::deque<JITDylib::EmissionDepUnit *> Worklist; - for (auto &[EDU, _] : EDUInfos) - Worklist.push_back(EDU); - - for (auto *EDU : Worklist) { - auto *EDUInfo = &EDUInfos[EDU]; - - SmallVector<JITDylib *> DepJDsToRemove; - for (auto &[DepJD, Deps] : EDU->Dependencies) { - if (DepJD->State != JITDylib::Open) - return makeJDClosedError(*EDU, *DepJD); - - SymbolNameSet BadDeps; - SmallVector<NonOwningSymbolStringPtr> DepsToRemove; - for (auto &Dep : Deps) { - auto DepEntryItr = DepJD->Symbols.find(SymbolStringPtr(Dep)); - - // If this dep has been removed or moved to the error state then add it - // to the bad deps set. We aggregate these bad deps for more - // comprehensive error messages. - if (DepEntryItr == DepJD->Symbols.end() || - DepEntryItr->second.getFlags().hasError()) { - BadDeps.insert(SymbolStringPtr(Dep)); - continue; - } - - // If this dep isn't emitted yet then just add it to the NewDeps set to - // be propagated. - auto &DepEntry = DepEntryItr->second; - if (DepEntry.getState() < SymbolState::Emitted) { - EDUInfo->NewDeps[DepJD].insert(Dep); - continue; - } - - // This dep has been emitted, so add it to the list to be removed from - // EDU. - DepsToRemove.push_back(Dep); - - // If Dep is Ready then there's nothing further to do. - if (DepEntry.getState() == SymbolState::Ready) { - assert(!DepJD->MaterializingInfos.count(SymbolStringPtr(Dep)) && - "Unexpected MaterializationInfo attached to ready symbol"); - continue; - } + auto ER = G.emit(std::move(SR), + [this](JITDylib *JD, NonOwningSymbolStringPtr Name) { + return IL_getSymbolState(JD, Name); + }); - // If we get here then Dep is Emitted. We need to look up its defining - // EDU and add this EDU to the defining EDU's list of users (this means - // creating an EDUInfos entry if the defining EDU doesn't have one - // already). - assert(DepJD->MaterializingInfos.count(SymbolStringPtr(Dep)) && - "Expected MaterializationInfo for emitted dependency"); - auto &DepMI = DepJD->MaterializingInfos[SymbolStringPtr(Dep)]; - assert(DepMI.DefiningEDU && - "Emitted symbol does not have a defining EDU"); - assert(DepMI.DependantEDUs.empty() && - "Already-emitted symbol has dependant EDUs?"); - auto &DepEDUInfo = EDUInfos[DepMI.DefiningEDU.get()]; - if (!DepEDUInfo.EDU) { - // No EDUInfo yet -- build initial entry, and reset the EDUInfo - // pointer, which we will have invalidated. - EDUInfo = &EDUInfos[EDU]; - DepEDUInfo.EDU = DepMI.DefiningEDU; - for (auto &[DepDepJD, DepDeps] : DepEDUInfo.EDU->Dependencies) { - if (DepDepJD == &TargetJD) { - for (auto &DepDep : DepDeps) - if (!MR.getSymbols().count(SymbolStringPtr(DepDep))) - DepEDUInfo.NewDeps[DepDepJD].insert(DepDep); - } else - DepEDUInfo.NewDeps[DepDepJD] = DepDeps; - } - } - DepEDUInfo.IntraEmitUsers.insert(EDU); - } - - // Some dependencies were removed or in an error state -- error out. - if (!BadDeps.empty()) - return makeUnsatisfiedDepsError(*EDU, *DepJD, std::move(BadDeps)); - - // Remove the emitted / ready deps from DepJD. - for (auto &Dep : DepsToRemove) - Deps.erase(Dep); - - // If there are no further deps in DepJD then flag it for removal too. - if (Deps.empty()) - DepJDsToRemove.push_back(DepJD); - } + EmitQueries EQ; - // Remove any JDs whose dependence sets have become empty. - for (auto &DepJD : DepJDsToRemove) { - assert(EDU->Dependencies.count(DepJD) && - "Trying to remove non-existent dep entries"); - EDU->Dependencies.erase(DepJD); - } - - // Now look for users of this EDU. 
- for (auto &[Sym, Flags] : EDU->Symbols) {
- assert(TargetJD.Symbols.count(SymbolStringPtr(Sym)) &&
- "Sym not present in symbol table");
- assert((TargetJD.Symbols[SymbolStringPtr(Sym)].getState() ==
- SymbolState::Resolved ||
- TargetJD.Symbols[SymbolStringPtr(Sym)]
- .getFlags()
- .hasMaterializationSideEffectsOnly()) &&
- "Emitting symbol not in the resolved state");
- assert(!TargetJD.Symbols[SymbolStringPtr(Sym)].getFlags().hasError() &&
- "Symbol is already in an error state");
-
- auto MII = TargetJD.MaterializingInfos.find(SymbolStringPtr(Sym));
- if (MII == TargetJD.MaterializingInfos.end() ||
- MII->second.DependantEDUs.empty())
- continue;
-
- for (auto &DependantEDU : MII->second.DependantEDUs) {
- if (IL_removeEDUDependence(*DependantEDU, TargetJD, Sym, EDUInfos))
- EDUInfo = &EDUInfos[EDU];
- EDUInfo->IntraEmitUsers.insert(DependantEDU);
- }
- MII->second.DependantEDUs.clear();
- }
- }
-
- Worklist.clear();
- for (auto &[EDU, EDUInfo] : EDUInfos) {
- if (!EDUInfo.IntraEmitUsers.empty() && !EDU->Dependencies.empty()) {
- if (EDUInfo.NewDeps.empty())
- EDUInfo.NewDeps = EDU->Dependencies;
- Worklist.push_back(EDU);
- }
- }
-
- propagateExtraEmitDeps(
- Worklist, EDUInfos,
- [](JITDylib::EmissionDepUnit &EDU, JITDylib &JD,
- NonOwningSymbolStringPtr Sym) {
- JD.MaterializingInfos[SymbolStringPtr(Sym)].DependantEDUs.insert(&EDU);
- });
+ // Handle failed queries.
+ for (auto &SN : ER.Failed)
+ IL_collectQueries(
+ EQ.Failed, SN->defs(),
+ [](JITDylib::SymbolTableEntry &E) {
+ E.setFlags(E.getFlags() | JITSymbolFlags::HasError);
+ },
+ [&](AsynchronousSymbolQuery &Q, JITDylib &JD,
+ NonOwningSymbolStringPtr Name, JITDylib::SymbolTableEntry &E) {
+ auto &FS = EQ.FailedSymsForQuery[&Q];
+ if (!FS)
+ FS = std::make_shared<SymbolDependenceMap>();
+ (*FS)[&JD].insert(SymbolStringPtr(Name));
+ });
- JITDylib::AsynchronousSymbolQuerySet CompletedQueries;
+ for (auto &FQ : EQ.Failed)
+ FQ->detach();
- // Extract completed queries and lodge not-yet-ready EDUs in the
- // session.
- for (auto &[EDU, EDUInfo] : EDUInfos) { - if (EDU->Dependencies.empty()) - IL_makeEDUReady(std::move(EDUInfo.EDU), CompletedQueries); - else - IL_makeEDUEmitted(std::move(EDUInfo.EDU), CompletedQueries); - } + for (auto &SN : ER.Ready) + IL_collectQueries( + EQ.Updated, SN->defs(), + [](JITDylib::SymbolTableEntry &E) { E.setState(SymbolState::Ready); }, + [](AsynchronousSymbolQuery &Q, JITDylib &JD, + NonOwningSymbolStringPtr Name, JITDylib::SymbolTableEntry &E) { + Q.notifySymbolMetRequiredState(SymbolStringPtr(Name), E.getSymbol()); + }); #ifdef EXPENSIVE_CHECKS verifySessionState("exiting ExecutionSession::IL_emit"); #endif - return std::move(CompletedQueries); + return std::move(EQ); } Error ExecutionSession::OL_notifyEmitted( @@ -3471,40 +2942,127 @@ Error ExecutionSession::OL_notifyEmitted( } #endif // NDEBUG - auto EDUInfos = simplifyDepGroups(MR, DepGroups); + std::vector<std::unique_ptr<WaitingOnGraph::SuperNode>> SNs; + WaitingOnGraph::ContainerElementsMap Residual; + { + auto &JDResidual = Residual[&MR.getTargetJITDylib()]; + for (auto &[Name, Flags] : MR.getSymbols()) + JDResidual.insert(NonOwningSymbolStringPtr(Name)); + + for (auto &SDG : DepGroups) { + WaitingOnGraph::ContainerElementsMap Defs; + assert(!SDG.Symbols.empty()); + auto &JDDefs = Defs[&MR.getTargetJITDylib()]; + for (auto &Def : SDG.Symbols) { + JDDefs.insert(NonOwningSymbolStringPtr(Def)); + JDResidual.erase(NonOwningSymbolStringPtr(Def)); + } + WaitingOnGraph::ContainerElementsMap Deps; + if (!SDG.Dependencies.empty()) { + for (auto &[JD, Syms] : SDG.Dependencies) { + auto &JDDeps = Deps[JD]; + for (auto &Dep : Syms) + JDDeps.insert(NonOwningSymbolStringPtr(Dep)); + } + } + SNs.push_back(std::make_unique<WaitingOnGraph::SuperNode>( + std::move(Defs), std::move(Deps))); + } + if (!JDResidual.empty()) + SNs.push_back(std::make_unique<WaitingOnGraph::SuperNode>( + std::move(Residual), WaitingOnGraph::ContainerElementsMap())); + } + + auto SR = WaitingOnGraph::simplify(std::move(SNs)); LLVM_DEBUG({ dbgs() << " Simplified dependencies:\n"; - for (auto &[EDU, EDUInfo] : EDUInfos) { - dbgs() << " Symbols: { "; - for (auto &[Sym, Flags] : EDU->Symbols) - dbgs() << Sym << " "; - dbgs() << "}, Dependencies: { "; - for (auto &[DepJD, Deps] : EDU->Dependencies) { - dbgs() << "(" << DepJD->getName() << ", { "; - for (auto &Dep : Deps) - dbgs() << Dep << " "; - dbgs() << "}) "; + for (auto &SN : SR.superNodes()) { + + auto SortedLibs = [](WaitingOnGraph::ContainerElementsMap &C) { + std::vector<JITDylib *> JDs; + for (auto &[JD, _] : C) + JDs.push_back(JD); + llvm::sort(JDs, [](const JITDylib *LHS, const JITDylib *RHS) { + return LHS->getName() < RHS->getName(); + }); + return JDs; + }; + + auto SortedNames = [](WaitingOnGraph::ElementSet &Elems) { + std::vector<NonOwningSymbolStringPtr> Names(Elems.begin(), Elems.end()); + llvm::sort(Names, [](const NonOwningSymbolStringPtr &LHS, + const NonOwningSymbolStringPtr &RHS) { + return *LHS < *RHS; + }); + return Names; + }; + + dbgs() << " Defs: {"; + for (auto *JD : SortedLibs(SN->defs())) { + dbgs() << " (" << JD->getName() << ", ["; + for (auto &Sym : SortedNames(SN->defs()[JD])) + dbgs() << " " << Sym; + dbgs() << " ])"; + } + dbgs() << " }, Deps: {"; + for (auto *JD : SortedLibs(SN->deps())) { + dbgs() << " (" << JD->getName() << ", ["; + for (auto &Sym : SortedNames(SN->deps()[JD])) + dbgs() << " " << Sym; + dbgs() << " ])"; } - dbgs() << "}\n"; + dbgs() << " }\n"; } }); - - auto CompletedQueries = - runSessionLocked([&]() { return IL_emit(MR, EDUInfos); }); + auto 
EmitQueries = + runSessionLocked([&]() { return IL_emit(MR, std::move(SR)); }); // On error bail out. - if (!CompletedQueries) - return CompletedQueries.takeError(); + if (!EmitQueries) + return EmitQueries.takeError(); - MR.SymbolFlags.clear(); + // Otherwise notify failed queries, and any updated queries that have been + // completed. - // Otherwise notify all the completed queries. - for (auto &Q : *CompletedQueries) { - assert(Q->isComplete() && "Q is not complete"); - Q->handleComplete(*this); + // FIXME: Get rid of error return from notifyEmitted. + SymbolDependenceMap BadDeps; + { + for (auto &FQ : EmitQueries->Failed) { + FQ->detach(); + assert(EmitQueries->FailedSymsForQuery.count(FQ.get()) && + "Missing failed symbols for query"); + auto FailedSyms = std::move(EmitQueries->FailedSymsForQuery[FQ.get()]); + for (auto &[JD, Syms] : *FailedSyms) { + auto &BadDepsForJD = BadDeps[JD]; + for (auto &Sym : Syms) + BadDepsForJD.insert(Sym); + } + FQ->handleFailed(make_error<FailedToMaterialize>(getSymbolStringPool(), + std::move(FailedSyms))); + } + } + + for (auto &UQ : EmitQueries->Updated) + if (UQ->isComplete()) + UQ->handleComplete(*this); + + // If there are any bad dependencies then return an error. + if (!BadDeps.empty()) { + SymbolNameSet BadNames; + // Note: The name set calculated here is bogus: it includes all symbols in + // the MR, not just the ones that failed. We want to remove the error + // return path from notifyEmitted anyway, so this is just a brief + // placeholder to maintain (roughly) the current error behavior. + for (auto &[Name, Flags] : MR.getSymbols()) + BadNames.insert(Name); + MR.SymbolFlags.clear(); + return make_error<UnsatisfiedSymbolDependencies>( + getSymbolStringPool(), &MR.getTargetJITDylib(), std::move(BadNames), + std::move(BadDeps), "dependencies removed or in error state"); } + MR.SymbolFlags.clear(); return Error::success(); } @@ -3535,158 +3093,48 @@ ExecutionSession::IL_failSymbols(JITDylib &JD, #endif JITDylib::AsynchronousSymbolQuerySet FailedQueries; - auto FailedSymbolsMap = std::make_shared<SymbolDependenceMap>(); - auto ExtractFailedQueries = [&](JITDylib::MaterializingInfo &MI) { - JITDylib::AsynchronousSymbolQueryList ToDetach; - for (auto &Q : MI.pendingQueries()) { - // Add the query to the list to be failed and detach it. - FailedQueries.insert(Q); - ToDetach.push_back(Q); + auto Fail = [&](JITDylib *FailJD, NonOwningSymbolStringPtr FailSym) { + auto I = FailJD->Symbols.find_as(FailSym); + assert(I != FailJD->Symbols.end()); + I->second.setFlags(I->second.getFlags() | JITSymbolFlags::HasError); + auto J = FailJD->MaterializingInfos.find_as(FailSym); + if (J != FailJD->MaterializingInfos.end()) { + for (auto &Q : J->second.takeAllPendingQueries()) + FailedQueries.insert(std::move(Q)); + FailJD->MaterializingInfos.erase(J); } - for (auto &Q : ToDetach) - Q->detach(); - assert(!MI.hasQueriesPending() && "Queries still pending after detach"); }; - for (auto &Name : SymbolsToFail) { - (*FailedSymbolsMap)[&JD].insert(Name); - - // Look up the symbol to fail. - auto SymI = JD.Symbols.find(Name); - - // FIXME: Revisit this. We should be able to assert sequencing between - // ResourceTracker removal and symbol failure. - // - // It's possible that this symbol has already been removed, e.g. if a - // materialization failure happens concurrently with a ResourceTracker or - // JITDylib removal. In that case we can safely skip this symbol and - // continue. 
- if (SymI == JD.Symbols.end()) - continue; - auto &Sym = SymI->second; - - // If the symbol is already in the error state then we must have visited - // it earlier. - if (Sym.getFlags().hasError()) { - assert(!JD.MaterializingInfos.count(Name) && - "Symbol in error state still has MaterializingInfo"); - continue; - } + auto FailedSymbolsMap = std::make_shared<SymbolDependenceMap>(); - // Move the symbol into the error state. - Sym.setFlags(Sym.getFlags() | JITSymbolFlags::HasError); - - // FIXME: Come up with a sane mapping of state to - // presence-of-MaterializingInfo so that we can assert presence / absence - // here, rather than testing it. - auto MII = JD.MaterializingInfos.find(Name); - if (MII == JD.MaterializingInfos.end()) - continue; - - auto &MI = MII->second; - - // Collect queries to be failed for this MII. - ExtractFailedQueries(MI); - - if (MI.DefiningEDU) { - // If there is a DefiningEDU for this symbol then remove this - // symbol from it. - assert(MI.DependantEDUs.empty() && - "Symbol with DefiningEDU should not have DependantEDUs"); - assert(Sym.getState() >= SymbolState::Emitted && - "Symbol has EDU, should have been emitted"); - assert(MI.DefiningEDU->Symbols.count(NonOwningSymbolStringPtr(Name)) && - "Symbol does not appear in its DefiningEDU"); - MI.DefiningEDU->Symbols.erase(NonOwningSymbolStringPtr(Name)); - - // Remove this EDU from the dependants lists of its dependencies. - for (auto &[DepJD, DepSyms] : MI.DefiningEDU->Dependencies) { - for (auto DepSym : DepSyms) { - assert(DepJD->Symbols.count(SymbolStringPtr(DepSym)) && - "DepSym not in DepJD"); - assert(DepJD->MaterializingInfos.count(SymbolStringPtr(DepSym)) && - "DepSym has not MaterializingInfo"); - auto &SymMI = DepJD->MaterializingInfos[SymbolStringPtr(DepSym)]; - assert(SymMI.DependantEDUs.count(MI.DefiningEDU.get()) && - "DefiningEDU missing from DependantEDUs list of dependency"); - SymMI.DependantEDUs.erase(MI.DefiningEDU.get()); - } - } + { + auto &FailedSymsForJD = (*FailedSymbolsMap)[&JD]; + for (auto &Sym : SymbolsToFail) { + FailedSymsForJD.insert(Sym); + Fail(&JD, NonOwningSymbolStringPtr(Sym)); + } + } - MI.DefiningEDU = nullptr; - } else { - // Otherwise if there are any EDUs waiting on this symbol then move - // those symbols to the error state too, and deregister them from the - // symbols that they depend on. - // Note: We use a copy of DependantEDUs here since we'll be removing - // from the original set as we go. - for (auto &DependantEDU : MI.DependantEDUs) { - - // Remove DependantEDU from all of its users DependantEDUs lists. - for (auto &[DepJD, DepSyms] : DependantEDU->Dependencies) { - for (auto DepSym : DepSyms) { - // Skip self-reference to avoid invalidating the MI.DependantEDUs - // map. We'll clear this later. - if (DepJD == &JD && DepSym == Name) - continue; - assert(DepJD->Symbols.count(SymbolStringPtr(DepSym)) && - "DepSym not in DepJD?"); - assert(DepJD->MaterializingInfos.count(SymbolStringPtr(DepSym)) && - "DependantEDU not registered with symbol it depends on"); - auto &SymMI = DepJD->MaterializingInfos[SymbolStringPtr(DepSym)]; - assert(SymMI.DependantEDUs.count(DependantEDU) && - "DependantEDU missing from DependantEDUs list"); - SymMI.DependantEDUs.erase(DependantEDU); - } - } + WaitingOnGraph::ContainerElementsMap ToFail; + auto &JDToFail = ToFail[&JD]; + for (auto &Sym : SymbolsToFail) + JDToFail.insert(NonOwningSymbolStringPtr(Sym)); - // Move any symbols defined by DependantEDU into the error state and - // fail any queries waiting on them. 
- auto &DepJD = *DependantEDU->JD; - auto DepEDUSymbols = std::move(DependantEDU->Symbols); - for (auto &[DepName, Flags] : DepEDUSymbols) { - auto DepSymItr = DepJD.Symbols.find(SymbolStringPtr(DepName)); - assert(DepSymItr != DepJD.Symbols.end() && - "Symbol not present in table"); - auto &DepSym = DepSymItr->second; - - assert(DepSym.getState() >= SymbolState::Emitted && - "Symbol has EDU, should have been emitted"); - assert(!DepSym.getFlags().hasError() && - "Symbol is already in the error state?"); - DepSym.setFlags(DepSym.getFlags() | JITSymbolFlags::HasError); - (*FailedSymbolsMap)[&DepJD].insert(SymbolStringPtr(DepName)); - - // This symbol has a defining EDU so its MaterializingInfo object must - // exist. - auto DepMIItr = - DepJD.MaterializingInfos.find(SymbolStringPtr(DepName)); - assert(DepMIItr != DepJD.MaterializingInfos.end() && - "Symbol has defining EDU but not MaterializingInfo"); - auto &DepMI = DepMIItr->second; - assert(DepMI.DefiningEDU.get() == DependantEDU && - "Bad EDU dependence edge"); - assert(DepMI.DependantEDUs.empty() && - "Symbol was emitted, should not have any DependantEDUs"); - ExtractFailedQueries(DepMI); - DepJD.MaterializingInfos.erase(SymbolStringPtr(DepName)); - } + auto FailedSNs = G.fail(ToFail); - DepJD.shrinkMaterializationInfoMemory(); + for (auto &SN : FailedSNs) { + for (auto &[FailJD, Defs] : SN->defs()) { + auto &FailedSymsForFailJD = (*FailedSymbolsMap)[FailJD]; + for (auto &Def : Defs) { + FailedSymsForFailJD.insert(SymbolStringPtr(Def)); + Fail(FailJD, Def); } - - MI.DependantEDUs.clear(); } - - assert(!MI.DefiningEDU && "DefiningEDU should have been reset"); - assert(MI.DependantEDUs.empty() && - "DependantEDUs should have been removed above"); - assert(!MI.hasQueriesPending() && - "Can not delete MaterializingInfo with queries pending"); - JD.MaterializingInfos.erase(Name); } - JD.shrinkMaterializationInfoMemory(); + // Detach all failed queries. 
+ for (auto &Q : FailedQueries) + Q->detach(); #ifdef EXPENSIVE_CHECKS verifySessionState("exiting ExecutionSession::IL_failSymbols"); @@ -3721,9 +3169,11 @@ void ExecutionSession::OL_notifyFailed(MaterializationResponsibility &MR) { return IL_failSymbols(MR.getTargetJITDylib(), SymbolsToFail); }); - for (auto &Q : FailedQueries) + for (auto &Q : FailedQueries) { + Q->detach(); Q->handleFailed( make_error<FailedToMaterialize>(getSymbolStringPool(), FailedSymbols)); + } } Error ExecutionSession::OL_replace(MaterializationResponsibility &MR, diff --git a/llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp b/llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp index dec1df7..893523c 100644 --- a/llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp +++ b/llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp @@ -448,7 +448,7 @@ Error SimpleRemoteEPC::handleHangup(SimpleRemoteEPCArgBytesVector ArgBytes) { if (const char *ErrMsg = WFR.getOutOfBandError()) return make_error<StringError>(ErrMsg, inconvertibleErrorCode()); - detail::SPSSerializableError Info; + orc::shared::detail::SPSSerializableError Info; SPSInputBuffer IB(WFR.data(), WFR.size()); if (!SPSArgList<SPSError>::deserialize(IB, Info)) return make_error<StringError>("Could not deserialize hangup info", diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp index d8374b6..10f915d 100644 --- a/llvm/lib/IR/AutoUpgrade.cpp +++ b/llvm/lib/IR/AutoUpgrade.cpp @@ -1426,6 +1426,28 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn, Intrinsic::memset, ParamTypes); return true; } + + unsigned MaskedID = + StringSwitch<unsigned>(Name) + .StartsWith("masked.load", Intrinsic::masked_load) + .StartsWith("masked.gather", Intrinsic::masked_gather) + .StartsWith("masked.store", Intrinsic::masked_store) + .StartsWith("masked.scatter", Intrinsic::masked_scatter) + .Default(0); + if (MaskedID && F->arg_size() == 4) { + rename(F); + if (MaskedID == Intrinsic::masked_load || + MaskedID == Intrinsic::masked_gather) { + NewFn = Intrinsic::getOrInsertDeclaration( + F->getParent(), MaskedID, + {F->getReturnType(), F->getArg(0)->getType()}); + return true; + } + NewFn = Intrinsic::getOrInsertDeclaration( + F->getParent(), MaskedID, + {F->getArg(0)->getType(), F->getArg(1)->getType()}); + return true; + } break; } case 'n': { @@ -5231,6 +5253,54 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) { break; } + case Intrinsic::masked_load: + case Intrinsic::masked_gather: + case Intrinsic::masked_store: + case Intrinsic::masked_scatter: { + if (CI->arg_size() != 4) { + DefaultCase(); + return; + } + + const DataLayout &DL = CI->getDataLayout(); + switch (NewFn->getIntrinsicID()) { + case Intrinsic::masked_load: + NewCall = Builder.CreateMaskedLoad( + CI->getType(), CI->getArgOperand(0), + cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue(), + CI->getArgOperand(2), CI->getArgOperand(3)); + break; + case Intrinsic::masked_gather: + NewCall = Builder.CreateMaskedGather( + CI->getType(), CI->getArgOperand(0), + DL.getValueOrABITypeAlignment( + cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue(), + CI->getType()->getScalarType()), + CI->getArgOperand(2), CI->getArgOperand(3)); + break; + case Intrinsic::masked_store: + NewCall = Builder.CreateMaskedStore( + CI->getArgOperand(0), CI->getArgOperand(1), + cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue(), + CI->getArgOperand(3)); + break; + case Intrinsic::masked_scatter: + NewCall = Builder.CreateMaskedScatter( + CI->getArgOperand(0), CI->getArgOperand(1), + 
DL.getValueOrABITypeAlignment( + cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue(), + CI->getArgOperand(0)->getType()->getScalarType()), + CI->getArgOperand(3)); + break; + default: + llvm_unreachable("Unexpected intrinsic ID"); + } + // Previous metadata is still valid. + NewCall->copyMetadata(*CI); + NewCall->setTailCallKind(cast<CallInst>(CI)->getTailCallKind()); + break; + } + case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: { if (CI->arg_size() != 2) { diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp index 15c0198..88dbd17 100644 --- a/llvm/lib/IR/IRBuilder.cpp +++ b/llvm/lib/IR/IRBuilder.cpp @@ -495,9 +495,11 @@ CallInst *IRBuilderBase::CreateMaskedLoad(Type *Ty, Value *Ptr, Align Alignment, if (!PassThru) PassThru = PoisonValue::get(Ty); Type *OverloadedTypes[] = { Ty, PtrTy }; - Value *Ops[] = {Ptr, getInt32(Alignment.value()), Mask, PassThru}; - return CreateMaskedIntrinsic(Intrinsic::masked_load, Ops, - OverloadedTypes, Name); + Value *Ops[] = {Ptr, Mask, PassThru}; + CallInst *CI = + CreateMaskedIntrinsic(Intrinsic::masked_load, Ops, OverloadedTypes, Name); + CI->addParamAttr(0, Attribute::getWithAlignment(CI->getContext(), Alignment)); + return CI; } /// Create a call to a Masked Store intrinsic. @@ -513,8 +515,11 @@ CallInst *IRBuilderBase::CreateMaskedStore(Value *Val, Value *Ptr, assert(DataTy->isVectorTy() && "Val should be a vector"); assert(Mask && "Mask should not be all-ones (null)"); Type *OverloadedTypes[] = { DataTy, PtrTy }; - Value *Ops[] = {Val, Ptr, getInt32(Alignment.value()), Mask}; - return CreateMaskedIntrinsic(Intrinsic::masked_store, Ops, OverloadedTypes); + Value *Ops[] = {Val, Ptr, Mask}; + CallInst *CI = + CreateMaskedIntrinsic(Intrinsic::masked_store, Ops, OverloadedTypes); + CI->addParamAttr(1, Attribute::getWithAlignment(CI->getContext(), Alignment)); + return CI; } /// Create a call to a Masked intrinsic, with given intrinsic Id, @@ -552,12 +557,14 @@ CallInst *IRBuilderBase::CreateMaskedGather(Type *Ty, Value *Ptrs, PassThru = PoisonValue::get(Ty); Type *OverloadedTypes[] = {Ty, PtrsTy}; - Value *Ops[] = {Ptrs, getInt32(Alignment.value()), Mask, PassThru}; + Value *Ops[] = {Ptrs, Mask, PassThru}; // We specify only one type when we create this intrinsic. Types of other // arguments are derived from this type. - return CreateMaskedIntrinsic(Intrinsic::masked_gather, Ops, OverloadedTypes, - Name); + CallInst *CI = CreateMaskedIntrinsic(Intrinsic::masked_gather, Ops, + OverloadedTypes, Name); + CI->addParamAttr(0, Attribute::getWithAlignment(CI->getContext(), Alignment)); + return CI; } /// Create a call to a Masked Scatter intrinsic. @@ -577,11 +584,14 @@ CallInst *IRBuilderBase::CreateMaskedScatter(Value *Data, Value *Ptrs, Mask = getAllOnesMask(NumElts); Type *OverloadedTypes[] = {DataTy, PtrsTy}; - Value *Ops[] = {Data, Ptrs, getInt32(Alignment.value()), Mask}; + Value *Ops[] = {Data, Ptrs, Mask}; // We specify only one type when we create this intrinsic. Types of other // arguments are derived from this type. 
- return CreateMaskedIntrinsic(Intrinsic::masked_scatter, Ops, OverloadedTypes); + CallInst *CI = + CreateMaskedIntrinsic(Intrinsic::masked_scatter, Ops, OverloadedTypes); + CI->addParamAttr(1, Attribute::getWithAlignment(CI->getContext(), Alignment)); + return CI; } /// Create a call to Masked Expand Load intrinsic diff --git a/llvm/lib/IR/Intrinsics.cpp b/llvm/lib/IR/Intrinsics.cpp index 6797a10..526800e 100644 --- a/llvm/lib/IR/Intrinsics.cpp +++ b/llvm/lib/IR/Intrinsics.cpp @@ -725,6 +725,19 @@ Function *Intrinsic::getOrInsertDeclaration(Module *M, ID id, // There can never be multiple globals with the same name of different types, // because intrinsics must be a specific type. auto *FT = getType(M->getContext(), id, Tys); + Function *F = cast<Function>( + M->getOrInsertFunction( + Tys.empty() ? getName(id) : getName(id, Tys, M, FT), FT) + .getCallee()); + if (F->getFunctionType() == FT) + return F; + + // It's possible that a declaration for this intrinsic already exists with an + // incorrect signature, if the signature has changed, but this particular + // declaration has not been auto-upgraded yet. In that case, rename the + // invalid declaration and insert a new one with the correct signature. The + // invalid declaration will get upgraded later. + F->setName(F->getName() + ".invalid"); return cast<Function>( M->getOrInsertFunction( Tys.empty() ? getName(id) : getName(id, Tys, M, FT), FT) diff --git a/llvm/lib/IR/RuntimeLibcalls.cpp b/llvm/lib/IR/RuntimeLibcalls.cpp index 7ea2e46..77af29b 100644 --- a/llvm/lib/IR/RuntimeLibcalls.cpp +++ b/llvm/lib/IR/RuntimeLibcalls.cpp @@ -21,9 +21,6 @@ using namespace RTLIB; #define GET_SET_TARGET_RUNTIME_LIBCALL_SETS #define DEFINE_GET_LOOKUP_LIBCALL_IMPL_NAME #include "llvm/IR/RuntimeLibcalls.inc" -#undef GET_INIT_RUNTIME_LIBCALL_NAMES -#undef GET_SET_TARGET_RUNTIME_LIBCALL_SETS -#undef DEFINE_GET_LOOKUP_LIBCALL_IMPL_NAME /// Set default libcall names. If a target wants to opt-out of a libcall it /// should be placed here. 
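To make the masked-intrinsic signature change above concrete: the alignment moves from an explicit i32 immediate operand to an align parameter attribute on the (vector-of-)pointer argument, so a call such as "call <4 x i32> @llvm.masked.load.v4i32.p0(ptr %p, i32 16, <4 x i1> %m, <4 x i32> %pt)" becomes "call <4 x i32> @llvm.masked.load.v4i32.p0(ptr align 16 %p, <4 x i1> %m, <4 x i32> %pt)", and old 4-operand calls in existing IR are rewritten by the AutoUpgrade hunk above. The following standalone sketch (not part of the patch; module and function names are invented for illustration) drives the post-patch CreateMaskedLoad and prints the resulting call:

// Standalone illustration only -- not part of this patch. Module name,
// function name, and constants below are invented for the example.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("masked-demo", Ctx);
  IRBuilder<> B(Ctx);

  // void demo(ptr %p)
  auto *FnTy = FunctionType::get(B.getVoidTy(), {B.getPtrTy()}, false);
  auto *F = Function::Create(FnTy, Function::ExternalLinkage, "demo", M);
  B.SetInsertPoint(BasicBlock::Create(Ctx, "entry", F));

  auto *VecTy = FixedVectorType::get(B.getInt32Ty(), 4);
  Value *Mask =
      Constant::getAllOnesValue(FixedVectorType::get(B.getInt1Ty(), 4));

  // With this patch the alignment is attached as a parameter attribute, so
  // the printed call should look like (modulo constant printing details):
  //   call <4 x i32> @llvm.masked.load.v4i32.p0(ptr align 16 %p,
  //       <4 x i1> splat (i1 true), <4 x i32> poison)
  Value *V = B.CreateMaskedLoad(VecTy, F->getArg(0), Align(16), Mask);
  B.CreateRetVoid();

  V->print(outs());
  outs() << "\n";
  return 0;
}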
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index 3572852..03da154 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -6211,13 +6211,10 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) { Check(Call.getType()->isVectorTy(), "masked_load: must return a vector", Call); - ConstantInt *Alignment = cast<ConstantInt>(Call.getArgOperand(1)); - Value *Mask = Call.getArgOperand(2); - Value *PassThru = Call.getArgOperand(3); + Value *Mask = Call.getArgOperand(1); + Value *PassThru = Call.getArgOperand(2); Check(Mask->getType()->isVectorTy(), "masked_load: mask must be vector", Call); - Check(Alignment->getValue().isPowerOf2(), - "masked_load: alignment must be a power of 2", Call); Check(PassThru->getType() == Call.getType(), "masked_load: pass through and return type must match", Call); Check(cast<VectorType>(Mask->getType())->getElementCount() == @@ -6227,33 +6224,15 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) { } case Intrinsic::masked_store: { Value *Val = Call.getArgOperand(0); - ConstantInt *Alignment = cast<ConstantInt>(Call.getArgOperand(2)); - Value *Mask = Call.getArgOperand(3); + Value *Mask = Call.getArgOperand(2); Check(Mask->getType()->isVectorTy(), "masked_store: mask must be vector", Call); - Check(Alignment->getValue().isPowerOf2(), - "masked_store: alignment must be a power of 2", Call); Check(cast<VectorType>(Mask->getType())->getElementCount() == cast<VectorType>(Val->getType())->getElementCount(), "masked_store: vector mask must be same length as value", Call); break; } - case Intrinsic::masked_gather: { - const APInt &Alignment = - cast<ConstantInt>(Call.getArgOperand(1))->getValue(); - Check(Alignment.isZero() || Alignment.isPowerOf2(), - "masked_gather: alignment must be 0 or a power of 2", Call); - break; - } - case Intrinsic::masked_scatter: { - const APInt &Alignment = - cast<ConstantInt>(Call.getArgOperand(2))->getValue(); - Check(Alignment.isZero() || Alignment.isPowerOf2(), - "masked_scatter: alignment must be 0 or a power of 2", Call); - break; - } - case Intrinsic::experimental_guard: { Check(isa<CallInst>(Call), "experimental_guard cannot be invoked", Call); Check(Call.countOperandBundlesOfType(LLVMContext::OB_deopt) == 1, diff --git a/llvm/lib/MC/MCParser/MasmParser.cpp b/llvm/lib/MC/MCParser/MasmParser.cpp index 7f0ea78..d4901d9 100644 --- a/llvm/lib/MC/MCParser/MasmParser.cpp +++ b/llvm/lib/MC/MCParser/MasmParser.cpp @@ -2903,7 +2903,7 @@ bool MasmParser::parseIdentifier(StringRef &Res, if (Position == StartOfStatement && StringSwitch<bool>(Res) .CaseLower("echo", true) - .CasesLower("ifdef", "ifndef", "elseifdef", "elseifndef", true) + .CasesLower({"ifdef", "ifndef", "elseifdef", "elseifndef"}, true) .Default(false)) { ExpandNextToken = DoNotExpandMacros; } diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp index 8623c06..4787604 100644 --- a/llvm/lib/Support/APFloat.cpp +++ b/llvm/lib/Support/APFloat.cpp @@ -130,44 +130,46 @@ struct fltSemantics { bool hasSignBitInMSB = true; }; -static constexpr fltSemantics semIEEEhalf = {15, -14, 11, 16}; -static constexpr fltSemantics semBFloat = {127, -126, 8, 16}; -static constexpr fltSemantics semIEEEsingle = {127, -126, 24, 32}; -static constexpr fltSemantics semIEEEdouble = {1023, -1022, 53, 64}; -static constexpr fltSemantics semIEEEquad = {16383, -16382, 113, 128}; -static constexpr fltSemantics semFloat8E5M2 = {15, -14, 3, 8}; -static constexpr fltSemantics semFloat8E5M2FNUZ = { +constexpr fltSemantics 
APFloatBase::semIEEEhalf = {15, -14, 11, 16}; +constexpr fltSemantics APFloatBase::semBFloat = {127, -126, 8, 16}; +constexpr fltSemantics APFloatBase::semIEEEsingle = {127, -126, 24, 32}; +constexpr fltSemantics APFloatBase::semIEEEdouble = {1023, -1022, 53, 64}; +constexpr fltSemantics APFloatBase::semIEEEquad = {16383, -16382, 113, 128}; +constexpr fltSemantics APFloatBase::semFloat8E5M2 = {15, -14, 3, 8}; +constexpr fltSemantics APFloatBase::semFloat8E5M2FNUZ = { 15, -15, 3, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero}; -static constexpr fltSemantics semFloat8E4M3 = {7, -6, 4, 8}; -static constexpr fltSemantics semFloat8E4M3FN = { +constexpr fltSemantics APFloatBase::semFloat8E4M3 = {7, -6, 4, 8}; +constexpr fltSemantics APFloatBase::semFloat8E4M3FN = { 8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes}; -static constexpr fltSemantics semFloat8E4M3FNUZ = { +constexpr fltSemantics APFloatBase::semFloat8E4M3FNUZ = { 7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero}; -static constexpr fltSemantics semFloat8E4M3B11FNUZ = { +constexpr fltSemantics APFloatBase::semFloat8E4M3B11FNUZ = { 4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero}; -static constexpr fltSemantics semFloat8E3M4 = {3, -2, 5, 8}; -static constexpr fltSemantics semFloatTF32 = {127, -126, 11, 19}; -static constexpr fltSemantics semFloat8E8M0FNU = {127, - -127, - 1, - 8, - fltNonfiniteBehavior::NanOnly, - fltNanEncoding::AllOnes, - false, - false, - false}; - -static constexpr fltSemantics semFloat6E3M2FN = { +constexpr fltSemantics APFloatBase::semFloat8E3M4 = {3, -2, 5, 8}; +constexpr fltSemantics APFloatBase::semFloatTF32 = {127, -126, 11, 19}; +constexpr fltSemantics APFloatBase::semFloat8E8M0FNU = { + 127, + -127, + 1, + 8, + fltNonfiniteBehavior::NanOnly, + fltNanEncoding::AllOnes, + false, + false, + false}; + +constexpr fltSemantics APFloatBase::semFloat6E3M2FN = { 4, -2, 3, 6, fltNonfiniteBehavior::FiniteOnly}; -static constexpr fltSemantics semFloat6E2M3FN = { +constexpr fltSemantics APFloatBase::semFloat6E2M3FN = { 2, 0, 4, 6, fltNonfiniteBehavior::FiniteOnly}; -static constexpr fltSemantics semFloat4E2M1FN = { +constexpr fltSemantics APFloatBase::semFloat4E2M1FN = { 2, 0, 2, 4, fltNonfiniteBehavior::FiniteOnly}; -static constexpr fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80}; -static constexpr fltSemantics semBogus = {0, 0, 0, 0}; -static constexpr fltSemantics semPPCDoubleDouble = {-1, 0, 0, 128}; -static constexpr fltSemantics semPPCDoubleDoubleLegacy = {1023, -1022 + 53, - 53 + 53, 128}; +constexpr fltSemantics APFloatBase::semX87DoubleExtended = {16383, -16382, 64, + 80}; +constexpr fltSemantics APFloatBase::semBogus = {0, 0, 0, 0}; +constexpr fltSemantics APFloatBase::semPPCDoubleDouble = {-1, 0, 0, 128}; +constexpr fltSemantics APFloatBase::semPPCDoubleDoubleLegacy = { + 1023, -1022 + 53, 53 + 53, 128}; const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) { switch (S) { @@ -261,36 +263,6 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) { llvm_unreachable("Unknown floating semantics"); } -const fltSemantics &APFloatBase::IEEEhalf() { return semIEEEhalf; } -const fltSemantics &APFloatBase::BFloat() { return semBFloat; } -const fltSemantics &APFloatBase::IEEEsingle() { return semIEEEsingle; } -const fltSemantics &APFloatBase::IEEEdouble() { return semIEEEdouble; } -const fltSemantics &APFloatBase::IEEEquad() { return semIEEEquad; } -const fltSemantics &APFloatBase::PPCDoubleDouble() { - 
return semPPCDoubleDouble; -} -const fltSemantics &APFloatBase::PPCDoubleDoubleLegacy() { - return semPPCDoubleDoubleLegacy; -} -const fltSemantics &APFloatBase::Float8E5M2() { return semFloat8E5M2; } -const fltSemantics &APFloatBase::Float8E5M2FNUZ() { return semFloat8E5M2FNUZ; } -const fltSemantics &APFloatBase::Float8E4M3() { return semFloat8E4M3; } -const fltSemantics &APFloatBase::Float8E4M3FN() { return semFloat8E4M3FN; } -const fltSemantics &APFloatBase::Float8E4M3FNUZ() { return semFloat8E4M3FNUZ; } -const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() { - return semFloat8E4M3B11FNUZ; -} -const fltSemantics &APFloatBase::Float8E3M4() { return semFloat8E3M4; } -const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; } -const fltSemantics &APFloatBase::Float8E8M0FNU() { return semFloat8E8M0FNU; } -const fltSemantics &APFloatBase::Float6E3M2FN() { return semFloat6E3M2FN; } -const fltSemantics &APFloatBase::Float6E2M3FN() { return semFloat6E2M3FN; } -const fltSemantics &APFloatBase::Float4E2M1FN() { return semFloat4E2M1FN; } -const fltSemantics &APFloatBase::x87DoubleExtended() { - return semX87DoubleExtended; -} -const fltSemantics &APFloatBase::Bogus() { return semBogus; } - bool APFloatBase::isRepresentableBy(const fltSemantics &A, const fltSemantics &B) { return A.maxExponent <= B.maxExponent && A.minExponent >= B.minExponent && @@ -1029,7 +1001,7 @@ void IEEEFloat::makeNaN(bool SNaN, bool Negative, const APInt *fill) { // For x87 extended precision, we want to make a NaN, not a // pseudo-NaN. Maybe we should expose the ability to make // pseudo-NaNs? - if (semantics == &semX87DoubleExtended) + if (semantics == &APFloatBase::semX87DoubleExtended) APInt::tcSetBit(significand, QNaNBit + 1); } @@ -1054,7 +1026,7 @@ IEEEFloat &IEEEFloat::operator=(IEEEFloat &&rhs) { category = rhs.category; sign = rhs.sign; - rhs.semantics = &semBogus; + rhs.semantics = &APFloatBase::semBogus; return *this; } @@ -1247,7 +1219,7 @@ IEEEFloat::IEEEFloat(const IEEEFloat &rhs) { assign(rhs); } -IEEEFloat::IEEEFloat(IEEEFloat &&rhs) : semantics(&semBogus) { +IEEEFloat::IEEEFloat(IEEEFloat &&rhs) : semantics(&APFloatBase::semBogus) { *this = std::move(rhs); } @@ -2607,8 +2579,8 @@ APFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics, shift = toSemantics.precision - fromSemantics.precision; bool X86SpecialNan = false; - if (&fromSemantics == &semX87DoubleExtended && - &toSemantics != &semX87DoubleExtended && category == fcNaN && + if (&fromSemantics == &APFloatBase::semX87DoubleExtended && + &toSemantics != &APFloatBase::semX87DoubleExtended && category == fcNaN && (!(*significandParts() & 0x8000000000000000ULL) || !(*significandParts() & 0x4000000000000000ULL))) { // x86 has some unusual NaNs which cannot be represented in any other @@ -2628,8 +2600,7 @@ APFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics, int exponentChange = omsb - fromSemantics.precision; if (exponent + exponentChange < toSemantics.minExponent) exponentChange = toSemantics.minExponent - exponent; - if (exponentChange < shift) - exponentChange = shift; + exponentChange = std::max(exponentChange, shift); if (exponentChange < 0) { shift -= exponentChange; exponent += exponentChange; @@ -2694,7 +2665,7 @@ APFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics, // For x87 extended precision, we want to make a NaN, not a special NaN if // the input wasn't special either. 
- if (!X86SpecialNan && semantics == &semX87DoubleExtended) + if (!X86SpecialNan && semantics == &APFloatBase::semX87DoubleExtended) APInt::tcSetBit(significandParts(), semantics->precision - 1); // Convert of sNaN creates qNaN and raises an exception (invalid op). @@ -3071,8 +3042,7 @@ IEEEFloat::roundSignificandWithExponent(const integerPart *decSigParts, if (decSig.exponent < semantics->minExponent) { excessPrecision += (semantics->minExponent - decSig.exponent); truncatedBits = excessPrecision; - if (excessPrecision > calcSemantics.precision) - excessPrecision = calcSemantics.precision; + excessPrecision = std::min(excessPrecision, calcSemantics.precision); } /* Extra half-ulp lost in reciprocal of exponent. */ powHUerr = (powStatus == opOK && calcLostFraction == lfExactlyZero) ? 0:2; @@ -3469,8 +3439,7 @@ char *IEEEFloat::convertNormalToHexString(char *dst, unsigned int hexDigits, /* Convert as much of "part" to hexdigits as we can. */ unsigned int curDigits = integerPartWidth / 4; - if (curDigits > outputDigits) - curDigits = outputDigits; + curDigits = std::min(curDigits, outputDigits); dst += partAsHex (dst, part, curDigits, hexDigitChars); outputDigits -= curDigits; } @@ -3530,7 +3499,8 @@ hash_code hash_value(const IEEEFloat &Arg) { // the actual IEEE respresentations. We compensate for that here. APInt IEEEFloat::convertF80LongDoubleAPFloatToAPInt() const { - assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended); + assert(semantics == + (const llvm::fltSemantics *)&APFloatBase::semX87DoubleExtended); assert(partCount()==2); uint64_t myexponent, mysignificand; @@ -3560,7 +3530,8 @@ APInt IEEEFloat::convertF80LongDoubleAPFloatToAPInt() const { } APInt IEEEFloat::convertPPCDoubleDoubleLegacyAPFloatToAPInt() const { - assert(semantics == (const llvm::fltSemantics *)&semPPCDoubleDoubleLegacy); + assert(semantics == + (const llvm::fltSemantics *)&APFloatBase::semPPCDoubleDoubleLegacy); assert(partCount()==2); uint64_t words[2]; @@ -3574,14 +3545,14 @@ APInt IEEEFloat::convertPPCDoubleDoubleLegacyAPFloatToAPInt() const { // Declare fltSemantics before APFloat that uses it (and // saves pointer to it) to ensure correct destruction order. fltSemantics extendedSemantics = *semantics; - extendedSemantics.minExponent = semIEEEdouble.minExponent; + extendedSemantics.minExponent = APFloatBase::semIEEEdouble.minExponent; IEEEFloat extended(*this); fs = extended.convert(extendedSemantics, rmNearestTiesToEven, &losesInfo); assert(fs == opOK && !losesInfo); (void)fs; IEEEFloat u(extended); - fs = u.convert(semIEEEdouble, rmNearestTiesToEven, &losesInfo); + fs = u.convert(APFloatBase::semIEEEdouble, rmNearestTiesToEven, &losesInfo); assert(fs == opOK || fs == opInexact); (void)fs; words[0] = *u.convertDoubleAPFloatToAPInt().getRawData(); @@ -3597,7 +3568,7 @@ APInt IEEEFloat::convertPPCDoubleDoubleLegacyAPFloatToAPInt() const { IEEEFloat v(extended); v.subtract(u, rmNearestTiesToEven); - fs = v.convert(semIEEEdouble, rmNearestTiesToEven, &losesInfo); + fs = v.convert(APFloatBase::semIEEEdouble, rmNearestTiesToEven, &losesInfo); assert(fs == opOK && !losesInfo); (void)fs; words[1] = *v.convertDoubleAPFloatToAPInt().getRawData(); @@ -3611,8 +3582,9 @@ APInt IEEEFloat::convertPPCDoubleDoubleLegacyAPFloatToAPInt() const { template <const fltSemantics &S> APInt IEEEFloat::convertIEEEFloatToAPInt() const { assert(semantics == &S); - const int bias = - (semantics == &semFloat8E8M0FNU) ? 
-S.minExponent : -(S.minExponent - 1);
+ const int bias = (semantics == &APFloatBase::semFloat8E8M0FNU)
+ ? -S.minExponent
+ : -(S.minExponent - 1);
 constexpr unsigned int trailing_significand_bits = S.precision - 1;
 constexpr int integer_bit_part = trailing_significand_bits / integerPartWidth;
 constexpr integerPart integer_bit =
@@ -3677,87 +3649,87 @@ APInt IEEEFloat::convertIEEEFloatToAPInt() const {
 APInt IEEEFloat::convertQuadrupleAPFloatToAPInt() const {
 assert(partCount() == 2);
- return convertIEEEFloatToAPInt<semIEEEquad>();
+ return convertIEEEFloatToAPInt<APFloatBase::semIEEEquad>();
 }
 APInt IEEEFloat::convertDoubleAPFloatToAPInt() const {
 assert(partCount()==1);
- return convertIEEEFloatToAPInt<semIEEEdouble>();
+ return convertIEEEFloatToAPInt<APFloatBase::semIEEEdouble>();
 }
 APInt IEEEFloat::convertFloatAPFloatToAPInt() const {
 assert(partCount()==1);
- return convertIEEEFloatToAPInt<semIEEEsingle>();
+ return convertIEEEFloatToAPInt<APFloatBase::semIEEEsingle>();
 }
 APInt IEEEFloat::convertBFloatAPFloatToAPInt() const {
 assert(partCount() == 1);
- return convertIEEEFloatToAPInt<semBFloat>();
+ return convertIEEEFloatToAPInt<APFloatBase::semBFloat>();
 }
 APInt IEEEFloat::convertHalfAPFloatToAPInt() const {
 assert(partCount()==1);
- return convertIEEEFloatToAPInt<semIEEEhalf>();
+ return convertIEEEFloatToAPInt<APFloatBase::semIEEEhalf>();
 }
 APInt IEEEFloat::convertFloat8E5M2APFloatToAPInt() const {
 assert(partCount() == 1);
- return convertIEEEFloatToAPInt<semFloat8E5M2>();
+ return convertIEEEFloatToAPInt<APFloatBase::semFloat8E5M2>();
 }
 APInt IEEEFloat::convertFloat8E5M2FNUZAPFloatToAPInt() const {
 assert(partCount() == 1);
- return convertIEEEFloatToAPInt<semFloat8E5M2FNUZ>();
+ return convertIEEEFloatToAPInt<APFloatBase::semFloat8E5M2FNUZ>();
 }
 APInt IEEEFloat::convertFloat8E4M3APFloatToAPInt() const {
 assert(partCount() == 1);
- return convertIEEEFloatToAPInt<semFloat8E4M3>();
+ return convertIEEEFloatToAPInt<APFloatBase::semFloat8E4M3>();
 }
 APInt IEEEFloat::convertFloat8E4M3FNAPFloatToAPInt() const {
 assert(partCount() == 1);
- return convertIEEEFloatToAPInt<semFloat8E4M3FN>();
+ return convertIEEEFloatToAPInt<APFloatBase::semFloat8E4M3FN>();
 }
 APInt IEEEFloat::convertFloat8E4M3FNUZAPFloatToAPInt() const {
 assert(partCount() == 1);
- return convertIEEEFloatToAPInt<semFloat8E4M3FNUZ>();
+ return convertIEEEFloatToAPInt<APFloatBase::semFloat8E4M3FNUZ>();
 }
 APInt IEEEFloat::convertFloat8E4M3B11FNUZAPFloatToAPInt() const {
 assert(partCount() == 1);
- return convertIEEEFloatToAPInt<semFloat8E4M3B11FNUZ>();
+ return convertIEEEFloatToAPInt<APFloatBase::semFloat8E4M3B11FNUZ>();
 }
 APInt IEEEFloat::convertFloat8E3M4APFloatToAPInt() const {
 assert(partCount() == 1);
- return convertIEEEFloatToAPInt<semFloat8E3M4>();
+ return convertIEEEFloatToAPInt<APFloatBase::semFloat8E3M4>();
 }
 APInt IEEEFloat::convertFloatTF32APFloatToAPInt() const {
 assert(partCount() == 1);
- return convertIEEEFloatToAPInt<semFloatTF32>();
+ return convertIEEEFloatToAPInt<APFloatBase::semFloatTF32>();
 }
 APInt IEEEFloat::convertFloat8E8M0FNUAPFloatToAPInt() const {
 assert(partCount() == 1);
- return convertIEEEFloatToAPInt<semFloat8E8M0FNU>();
+ return convertIEEEFloatToAPInt<APFloatBase::semFloat8E8M0FNU>();
 }
 APInt IEEEFloat::convertFloat6E3M2FNAPFloatToAPInt() const {
 assert(partCount() == 1);
- return convertIEEEFloatToAPInt<semFloat6E3M2FN>();
+ return convertIEEEFloatToAPInt<APFloatBase::semFloat6E3M2FN>();
 }
 APInt IEEEFloat::convertFloat6E2M3FNAPFloatToAPInt() const {
 assert(partCount() ==
1); - return convertIEEEFloatToAPInt<semFloat6E2M3FN>(); + return convertIEEEFloatToAPInt<APFloatBase::semFloat6E2M3FN>(); } APInt IEEEFloat::convertFloat4E2M1FNAPFloatToAPInt() const { assert(partCount() == 1); - return convertIEEEFloatToAPInt<semFloat4E2M1FN>(); + return convertIEEEFloatToAPInt<APFloatBase::semFloat4E2M1FN>(); } // This function creates an APInt that is just a bit map of the floating @@ -3765,74 +3737,77 @@ APInt IEEEFloat::convertFloat4E2M1FNAPFloatToAPInt() const { // and treating the result as a normal integer is unlikely to be useful. APInt IEEEFloat::bitcastToAPInt() const { - if (semantics == (const llvm::fltSemantics*)&semIEEEhalf) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semIEEEhalf) return convertHalfAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semBFloat) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semBFloat) return convertBFloatAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics*)&semIEEEsingle) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semIEEEsingle) return convertFloatAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics*)&semIEEEdouble) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semIEEEdouble) return convertDoubleAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics*)&semIEEEquad) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semIEEEquad) return convertQuadrupleAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semPPCDoubleDoubleLegacy) + if (semantics == + (const llvm::fltSemantics *)&APFloatBase::semPPCDoubleDoubleLegacy) return convertPPCDoubleDoubleLegacyAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat8E5M2) return convertFloat8E5M2APFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2FNUZ) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat8E5M2FNUZ) return convertFloat8E5M2FNUZAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat8E4M3) return convertFloat8E4M3APFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FN) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat8E4M3FN) return convertFloat8E4M3FNAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FNUZ) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat8E4M3FNUZ) return convertFloat8E4M3FNUZAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ) + if (semantics == + (const llvm::fltSemantics *)&APFloatBase::semFloat8E4M3B11FNUZ) return convertFloat8E4M3B11FNUZAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat8E3M4) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat8E3M4) return convertFloat8E3M4APFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloatTF32) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloatTF32) return convertFloatTF32APFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat8E8M0FNU) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat8E8M0FNU) return convertFloat8E8M0FNUAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat6E3M2FN) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat6E3M2FN) return 
convertFloat6E3M2FNAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat6E2M3FN) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat6E2M3FN) return convertFloat6E2M3FNAPFloatToAPInt(); - if (semantics == (const llvm::fltSemantics *)&semFloat4E2M1FN) + if (semantics == (const llvm::fltSemantics *)&APFloatBase::semFloat4E2M1FN) return convertFloat4E2M1FNAPFloatToAPInt(); - assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended && + assert(semantics == + (const llvm::fltSemantics *)&APFloatBase::semX87DoubleExtended && "unknown format!"); return convertF80LongDoubleAPFloatToAPInt(); } float IEEEFloat::convertToFloat() const { - assert(semantics == (const llvm::fltSemantics*)&semIEEEsingle && + assert(semantics == (const llvm::fltSemantics *)&APFloatBase::semIEEEsingle && "Float semantics are not IEEEsingle"); APInt api = bitcastToAPInt(); return api.bitsToFloat(); } double IEEEFloat::convertToDouble() const { - assert(semantics == (const llvm::fltSemantics*)&semIEEEdouble && + assert(semantics == (const llvm::fltSemantics *)&APFloatBase::semIEEEdouble && "Float semantics are not IEEEdouble"); APInt api = bitcastToAPInt(); return api.bitsToDouble(); @@ -3840,7 +3815,7 @@ double IEEEFloat::convertToDouble() const { #ifdef HAS_IEE754_FLOAT128 float128 IEEEFloat::convertToQuad() const { - assert(semantics == (const llvm::fltSemantics *)&semIEEEquad && + assert(semantics == (const llvm::fltSemantics *)&APFloatBase::semIEEEquad && "Float semantics are not IEEEquads"); APInt api = bitcastToAPInt(); return api.bitsToQuad(); @@ -3861,7 +3836,7 @@ void IEEEFloat::initFromF80LongDoubleAPInt(const APInt &api) { uint64_t mysignificand = i1; uint8_t myintegerbit = mysignificand >> 63; - initialize(&semX87DoubleExtended); + initialize(&APFloatBase::semX87DoubleExtended); assert(partCount()==2); sign = static_cast<unsigned int>(i2>>15); @@ -3893,14 +3868,16 @@ void IEEEFloat::initFromPPCDoubleDoubleLegacyAPInt(const APInt &api) { // Get the first double and convert to our format. initFromDoubleAPInt(APInt(64, i1)); - fs = convert(semPPCDoubleDoubleLegacy, rmNearestTiesToEven, &losesInfo); + fs = convert(APFloatBase::semPPCDoubleDoubleLegacy, rmNearestTiesToEven, + &losesInfo); assert(fs == opOK && !losesInfo); (void)fs; // Unless we have a special case, add in second double. 
if (isFiniteNonZero()) { - IEEEFloat v(semIEEEdouble, APInt(64, i2)); - fs = v.convert(semPPCDoubleDoubleLegacy, rmNearestTiesToEven, &losesInfo); + IEEEFloat v(APFloatBase::semIEEEdouble, APInt(64, i2)); + fs = v.convert(APFloatBase::semPPCDoubleDoubleLegacy, rmNearestTiesToEven, + &losesInfo); assert(fs == opOK && !losesInfo); (void)fs; @@ -3918,7 +3895,7 @@ void IEEEFloat::initFromFloat8E8M0FNUAPInt(const APInt &api) { uint64_t val = api.getRawData()[0]; uint64_t myexponent = (val & exponent_mask); - initialize(&semFloat8E8M0FNU); + initialize(&APFloatBase::semFloat8E8M0FNU); assert(partCount() == 1); // This format has unsigned representation only @@ -4025,109 +4002,109 @@ void IEEEFloat::initFromIEEEAPInt(const APInt &api) { } void IEEEFloat::initFromQuadrupleAPInt(const APInt &api) { - initFromIEEEAPInt<semIEEEquad>(api); + initFromIEEEAPInt<APFloatBase::semIEEEquad>(api); } void IEEEFloat::initFromDoubleAPInt(const APInt &api) { - initFromIEEEAPInt<semIEEEdouble>(api); + initFromIEEEAPInt<APFloatBase::semIEEEdouble>(api); } void IEEEFloat::initFromFloatAPInt(const APInt &api) { - initFromIEEEAPInt<semIEEEsingle>(api); + initFromIEEEAPInt<APFloatBase::semIEEEsingle>(api); } void IEEEFloat::initFromBFloatAPInt(const APInt &api) { - initFromIEEEAPInt<semBFloat>(api); + initFromIEEEAPInt<APFloatBase::semBFloat>(api); } void IEEEFloat::initFromHalfAPInt(const APInt &api) { - initFromIEEEAPInt<semIEEEhalf>(api); + initFromIEEEAPInt<APFloatBase::semIEEEhalf>(api); } void IEEEFloat::initFromFloat8E5M2APInt(const APInt &api) { - initFromIEEEAPInt<semFloat8E5M2>(api); + initFromIEEEAPInt<APFloatBase::semFloat8E5M2>(api); } void IEEEFloat::initFromFloat8E5M2FNUZAPInt(const APInt &api) { - initFromIEEEAPInt<semFloat8E5M2FNUZ>(api); + initFromIEEEAPInt<APFloatBase::semFloat8E5M2FNUZ>(api); } void IEEEFloat::initFromFloat8E4M3APInt(const APInt &api) { - initFromIEEEAPInt<semFloat8E4M3>(api); + initFromIEEEAPInt<APFloatBase::semFloat8E4M3>(api); } void IEEEFloat::initFromFloat8E4M3FNAPInt(const APInt &api) { - initFromIEEEAPInt<semFloat8E4M3FN>(api); + initFromIEEEAPInt<APFloatBase::semFloat8E4M3FN>(api); } void IEEEFloat::initFromFloat8E4M3FNUZAPInt(const APInt &api) { - initFromIEEEAPInt<semFloat8E4M3FNUZ>(api); + initFromIEEEAPInt<APFloatBase::semFloat8E4M3FNUZ>(api); } void IEEEFloat::initFromFloat8E4M3B11FNUZAPInt(const APInt &api) { - initFromIEEEAPInt<semFloat8E4M3B11FNUZ>(api); + initFromIEEEAPInt<APFloatBase::semFloat8E4M3B11FNUZ>(api); } void IEEEFloat::initFromFloat8E3M4APInt(const APInt &api) { - initFromIEEEAPInt<semFloat8E3M4>(api); + initFromIEEEAPInt<APFloatBase::semFloat8E3M4>(api); } void IEEEFloat::initFromFloatTF32APInt(const APInt &api) { - initFromIEEEAPInt<semFloatTF32>(api); + initFromIEEEAPInt<APFloatBase::semFloatTF32>(api); } void IEEEFloat::initFromFloat6E3M2FNAPInt(const APInt &api) { - initFromIEEEAPInt<semFloat6E3M2FN>(api); + initFromIEEEAPInt<APFloatBase::semFloat6E3M2FN>(api); } void IEEEFloat::initFromFloat6E2M3FNAPInt(const APInt &api) { - initFromIEEEAPInt<semFloat6E2M3FN>(api); + initFromIEEEAPInt<APFloatBase::semFloat6E2M3FN>(api); } void IEEEFloat::initFromFloat4E2M1FNAPInt(const APInt &api) { - initFromIEEEAPInt<semFloat4E2M1FN>(api); + initFromIEEEAPInt<APFloatBase::semFloat4E2M1FN>(api); } /// Treat api as containing the bits of a floating point number. 
void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) { assert(api.getBitWidth() == Sem->sizeInBits); - if (Sem == &semIEEEhalf) + if (Sem == &APFloatBase::semIEEEhalf) return initFromHalfAPInt(api); - if (Sem == &semBFloat) + if (Sem == &APFloatBase::semBFloat) return initFromBFloatAPInt(api); - if (Sem == &semIEEEsingle) + if (Sem == &APFloatBase::semIEEEsingle) return initFromFloatAPInt(api); - if (Sem == &semIEEEdouble) + if (Sem == &APFloatBase::semIEEEdouble) return initFromDoubleAPInt(api); - if (Sem == &semX87DoubleExtended) + if (Sem == &APFloatBase::semX87DoubleExtended) return initFromF80LongDoubleAPInt(api); - if (Sem == &semIEEEquad) + if (Sem == &APFloatBase::semIEEEquad) return initFromQuadrupleAPInt(api); - if (Sem == &semPPCDoubleDoubleLegacy) + if (Sem == &APFloatBase::semPPCDoubleDoubleLegacy) return initFromPPCDoubleDoubleLegacyAPInt(api); - if (Sem == &semFloat8E5M2) + if (Sem == &APFloatBase::semFloat8E5M2) return initFromFloat8E5M2APInt(api); - if (Sem == &semFloat8E5M2FNUZ) + if (Sem == &APFloatBase::semFloat8E5M2FNUZ) return initFromFloat8E5M2FNUZAPInt(api); - if (Sem == &semFloat8E4M3) + if (Sem == &APFloatBase::semFloat8E4M3) return initFromFloat8E4M3APInt(api); - if (Sem == &semFloat8E4M3FN) + if (Sem == &APFloatBase::semFloat8E4M3FN) return initFromFloat8E4M3FNAPInt(api); - if (Sem == &semFloat8E4M3FNUZ) + if (Sem == &APFloatBase::semFloat8E4M3FNUZ) return initFromFloat8E4M3FNUZAPInt(api); - if (Sem == &semFloat8E4M3B11FNUZ) + if (Sem == &APFloatBase::semFloat8E4M3B11FNUZ) return initFromFloat8E4M3B11FNUZAPInt(api); - if (Sem == &semFloat8E3M4) + if (Sem == &APFloatBase::semFloat8E3M4) return initFromFloat8E3M4APInt(api); - if (Sem == &semFloatTF32) + if (Sem == &APFloatBase::semFloatTF32) return initFromFloatTF32APInt(api); - if (Sem == &semFloat8E8M0FNU) + if (Sem == &APFloatBase::semFloat8E8M0FNU) return initFromFloat8E8M0FNUAPInt(api); - if (Sem == &semFloat6E3M2FN) + if (Sem == &APFloatBase::semFloat6E3M2FN) return initFromFloat6E3M2FNAPInt(api); - if (Sem == &semFloat6E2M3FN) + if (Sem == &APFloatBase::semFloat6E2M3FN) return initFromFloat6E2M3FNAPInt(api); - if (Sem == &semFloat4E2M1FN) + if (Sem == &APFloatBase::semFloat4E2M1FN) return initFromFloat4E2M1FNAPInt(api); llvm_unreachable("unsupported semantics"); @@ -4202,11 +4179,11 @@ IEEEFloat::IEEEFloat(const fltSemantics &Sem, const APInt &API) { } IEEEFloat::IEEEFloat(float f) { - initFromAPInt(&semIEEEsingle, APInt::floatToBits(f)); + initFromAPInt(&APFloatBase::semIEEEsingle, APInt::floatToBits(f)); } IEEEFloat::IEEEFloat(double d) { - initFromAPInt(&semIEEEdouble, APInt::doubleToBits(d)); + initFromAPInt(&APFloatBase::semIEEEdouble, APInt::doubleToBits(d)); } namespace { @@ -4815,38 +4792,40 @@ IEEEFloat frexp(const IEEEFloat &Val, int &Exp, roundingMode RM) { DoubleAPFloat::DoubleAPFloat(const fltSemantics &S) : Semantics(&S), - Floats(new APFloat[2]{APFloat(semIEEEdouble), APFloat(semIEEEdouble)}) { - assert(Semantics == &semPPCDoubleDouble); + Floats(new APFloat[2]{APFloat(APFloatBase::semIEEEdouble), + APFloat(APFloatBase::semIEEEdouble)}) { + assert(Semantics == &APFloatBase::semPPCDoubleDouble); } DoubleAPFloat::DoubleAPFloat(const fltSemantics &S, uninitializedTag) - : Semantics(&S), - Floats(new APFloat[2]{APFloat(semIEEEdouble, uninitialized), - APFloat(semIEEEdouble, uninitialized)}) { - assert(Semantics == &semPPCDoubleDouble); + : Semantics(&S), Floats(new APFloat[2]{ + APFloat(APFloatBase::semIEEEdouble, uninitialized), + APFloat(APFloatBase::semIEEEdouble, 
uninitialized)}) { + assert(Semantics == &APFloatBase::semPPCDoubleDouble); } DoubleAPFloat::DoubleAPFloat(const fltSemantics &S, integerPart I) - : Semantics(&S), Floats(new APFloat[2]{APFloat(semIEEEdouble, I), - APFloat(semIEEEdouble)}) { - assert(Semantics == &semPPCDoubleDouble); + : Semantics(&S), + Floats(new APFloat[2]{APFloat(APFloatBase::semIEEEdouble, I), + APFloat(APFloatBase::semIEEEdouble)}) { + assert(Semantics == &APFloatBase::semPPCDoubleDouble); } DoubleAPFloat::DoubleAPFloat(const fltSemantics &S, const APInt &I) : Semantics(&S), Floats(new APFloat[2]{ - APFloat(semIEEEdouble, APInt(64, I.getRawData()[0])), - APFloat(semIEEEdouble, APInt(64, I.getRawData()[1]))}) { - assert(Semantics == &semPPCDoubleDouble); + APFloat(APFloatBase::semIEEEdouble, APInt(64, I.getRawData()[0])), + APFloat(APFloatBase::semIEEEdouble, APInt(64, I.getRawData()[1]))}) { + assert(Semantics == &APFloatBase::semPPCDoubleDouble); } DoubleAPFloat::DoubleAPFloat(const fltSemantics &S, APFloat &&First, APFloat &&Second) : Semantics(&S), Floats(new APFloat[2]{std::move(First), std::move(Second)}) { - assert(Semantics == &semPPCDoubleDouble); - assert(&Floats[0].getSemantics() == &semIEEEdouble); - assert(&Floats[1].getSemantics() == &semIEEEdouble); + assert(Semantics == &APFloatBase::semPPCDoubleDouble); + assert(&Floats[0].getSemantics() == &APFloatBase::semIEEEdouble); + assert(&Floats[1].getSemantics() == &APFloatBase::semIEEEdouble); } DoubleAPFloat::DoubleAPFloat(const DoubleAPFloat &RHS) @@ -4854,14 +4833,14 @@ DoubleAPFloat::DoubleAPFloat(const DoubleAPFloat &RHS) Floats(RHS.Floats ? new APFloat[2]{APFloat(RHS.Floats[0]), APFloat(RHS.Floats[1])} : nullptr) { - assert(Semantics == &semPPCDoubleDouble); + assert(Semantics == &APFloatBase::semPPCDoubleDouble); } DoubleAPFloat::DoubleAPFloat(DoubleAPFloat &&RHS) : Semantics(RHS.Semantics), Floats(RHS.Floats) { - RHS.Semantics = &semBogus; + RHS.Semantics = &APFloatBase::semBogus; RHS.Floats = nullptr; - assert(Semantics == &semPPCDoubleDouble); + assert(Semantics == &APFloatBase::semPPCDoubleDouble); } DoubleAPFloat &DoubleAPFloat::operator=(const DoubleAPFloat &RHS) { @@ -5009,12 +4988,12 @@ APFloat::opStatus DoubleAPFloat::addWithSpecial(const DoubleAPFloat &LHS, APFloat A(LHS.Floats[0]), AA(LHS.Floats[1]), C(RHS.Floats[0]), CC(RHS.Floats[1]); - assert(&A.getSemantics() == &semIEEEdouble); - assert(&AA.getSemantics() == &semIEEEdouble); - assert(&C.getSemantics() == &semIEEEdouble); - assert(&CC.getSemantics() == &semIEEEdouble); - assert(&Out.Floats[0].getSemantics() == &semIEEEdouble); - assert(&Out.Floats[1].getSemantics() == &semIEEEdouble); + assert(&A.getSemantics() == &APFloatBase::semIEEEdouble); + assert(&AA.getSemantics() == &APFloatBase::semIEEEdouble); + assert(&C.getSemantics() == &APFloatBase::semIEEEdouble); + assert(&CC.getSemantics() == &APFloatBase::semIEEEdouble); + assert(&Out.Floats[0].getSemantics() == &APFloatBase::semIEEEdouble); + assert(&Out.Floats[1].getSemantics() == &APFloatBase::semIEEEdouble); return Out.addImpl(A, AA, C, CC, RM); } @@ -5119,28 +5098,32 @@ APFloat::opStatus DoubleAPFloat::multiply(const DoubleAPFloat &RHS, APFloat::opStatus DoubleAPFloat::divide(const DoubleAPFloat &RHS, APFloat::roundingMode RM) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - APFloat Tmp(semPPCDoubleDoubleLegacy, bitcastToAPInt()); - auto Ret = - Tmp.divide(APFloat(semPPCDoubleDoubleLegacy, RHS.bitcastToAPInt()), RM); - *this = DoubleAPFloat(semPPCDoubleDouble, Tmp.bitcastToAPInt()); + assert(Semantics == 
&APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + APFloat Tmp(APFloatBase::semPPCDoubleDoubleLegacy, bitcastToAPInt()); + auto Ret = Tmp.divide( + APFloat(APFloatBase::semPPCDoubleDoubleLegacy, RHS.bitcastToAPInt()), RM); + *this = DoubleAPFloat(APFloatBase::semPPCDoubleDouble, Tmp.bitcastToAPInt()); return Ret; } APFloat::opStatus DoubleAPFloat::remainder(const DoubleAPFloat &RHS) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - APFloat Tmp(semPPCDoubleDoubleLegacy, bitcastToAPInt()); - auto Ret = - Tmp.remainder(APFloat(semPPCDoubleDoubleLegacy, RHS.bitcastToAPInt())); - *this = DoubleAPFloat(semPPCDoubleDouble, Tmp.bitcastToAPInt()); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + APFloat Tmp(APFloatBase::semPPCDoubleDoubleLegacy, bitcastToAPInt()); + auto Ret = Tmp.remainder( + APFloat(APFloatBase::semPPCDoubleDoubleLegacy, RHS.bitcastToAPInt())); + *this = DoubleAPFloat(APFloatBase::semPPCDoubleDouble, Tmp.bitcastToAPInt()); return Ret; } APFloat::opStatus DoubleAPFloat::mod(const DoubleAPFloat &RHS) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - APFloat Tmp(semPPCDoubleDoubleLegacy, bitcastToAPInt()); - auto Ret = Tmp.mod(APFloat(semPPCDoubleDoubleLegacy, RHS.bitcastToAPInt())); - *this = DoubleAPFloat(semPPCDoubleDouble, Tmp.bitcastToAPInt()); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + APFloat Tmp(APFloatBase::semPPCDoubleDoubleLegacy, bitcastToAPInt()); + auto Ret = Tmp.mod( + APFloat(APFloatBase::semPPCDoubleDoubleLegacy, RHS.bitcastToAPInt())); + *this = DoubleAPFloat(APFloatBase::semPPCDoubleDouble, Tmp.bitcastToAPInt()); return Ret; } @@ -5148,17 +5131,21 @@ APFloat::opStatus DoubleAPFloat::fusedMultiplyAdd(const DoubleAPFloat &Multiplicand, const DoubleAPFloat &Addend, APFloat::roundingMode RM) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - APFloat Tmp(semPPCDoubleDoubleLegacy, bitcastToAPInt()); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + APFloat Tmp(APFloatBase::semPPCDoubleDoubleLegacy, bitcastToAPInt()); auto Ret = Tmp.fusedMultiplyAdd( - APFloat(semPPCDoubleDoubleLegacy, Multiplicand.bitcastToAPInt()), - APFloat(semPPCDoubleDoubleLegacy, Addend.bitcastToAPInt()), RM); - *this = DoubleAPFloat(semPPCDoubleDouble, Tmp.bitcastToAPInt()); + APFloat(APFloatBase::semPPCDoubleDoubleLegacy, + Multiplicand.bitcastToAPInt()), + APFloat(APFloatBase::semPPCDoubleDoubleLegacy, Addend.bitcastToAPInt()), + RM); + *this = DoubleAPFloat(APFloatBase::semPPCDoubleDouble, Tmp.bitcastToAPInt()); return Ret; } APFloat::opStatus DoubleAPFloat::roundToIntegral(APFloat::roundingMode RM) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); const APFloat &Hi = getFirst(); const APFloat &Lo = getSecond(); @@ -5309,22 +5296,28 @@ void DoubleAPFloat::makeZero(bool Neg) { } void DoubleAPFloat::makeLargest(bool Neg) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - Floats[0] = APFloat(semIEEEdouble, APInt(64, 0x7fefffffffffffffull)); - Floats[1] = APFloat(semIEEEdouble, APInt(64, 0x7c8ffffffffffffeull)); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + Floats[0] = + APFloat(APFloatBase::semIEEEdouble, APInt(64, 0x7fefffffffffffffull)); + Floats[1] = + APFloat(APFloatBase::semIEEEdouble, APInt(64, 0x7c8ffffffffffffeull)); if 
(Neg) changeSign(); } void DoubleAPFloat::makeSmallest(bool Neg) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); Floats[0].makeSmallest(Neg); Floats[1].makeZero(/* Neg = */ false); } void DoubleAPFloat::makeSmallestNormalized(bool Neg) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - Floats[0] = APFloat(semIEEEdouble, APInt(64, 0x0360000000000000ull)); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + Floats[0] = + APFloat(APFloatBase::semIEEEdouble, APInt(64, 0x0360000000000000ull)); if (Neg) Floats[0].changeSign(); Floats[1].makeZero(/* Neg = */ false); @@ -5355,7 +5348,8 @@ hash_code hash_value(const DoubleAPFloat &Arg) { } APInt DoubleAPFloat::bitcastToAPInt() const { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); uint64_t Data[] = { Floats[0].bitcastToAPInt().getRawData()[0], Floats[1].bitcastToAPInt().getRawData()[0], @@ -5365,10 +5359,11 @@ APInt DoubleAPFloat::bitcastToAPInt() const { Expected<APFloat::opStatus> DoubleAPFloat::convertFromString(StringRef S, roundingMode RM) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - APFloat Tmp(semPPCDoubleDoubleLegacy); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + APFloat Tmp(APFloatBase::semPPCDoubleDoubleLegacy); auto Ret = Tmp.convertFromString(S, RM); - *this = DoubleAPFloat(semPPCDoubleDouble, Tmp.bitcastToAPInt()); + *this = DoubleAPFloat(APFloatBase::semPPCDoubleDouble, Tmp.bitcastToAPInt()); return Ret; } @@ -5379,7 +5374,8 @@ Expected<APFloat::opStatus> DoubleAPFloat::convertFromString(StringRef S, // nextUp must choose the smallest output > input that follows these rules. // nextDown must choose the largest output < input that follows these rules. APFloat::opStatus DoubleAPFloat::next(bool nextDown) { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); // nextDown(x) = -nextUp(-x) if (nextDown) { changeSign(); @@ -5481,7 +5477,8 @@ APFloat::opStatus DoubleAPFloat::next(bool nextDown) { APFloat::opStatus DoubleAPFloat::convertToSignExtendedInteger( MutableArrayRef<integerPart> Input, unsigned int Width, bool IsSigned, roundingMode RM, bool *IsExact) const { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); // If Hi is not finite, or Lo is zero, the value is entirely represented // by Hi. Delegate to the simpler single-APFloat conversion.
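(For context: the asserts re-qualified above all guard the PPCDoubleDouble layout, in which a value is the exact sum of two IEEE doubles. A minimal sketch of how that layout can be observed through the public APFloat API; illustrative only, not part of the patch.)

#include "llvm/ADT/APFloat.h"
using namespace llvm;

// A PPCDoubleDouble value is the exact sum of two IEEE doubles: a high part
// carrying the leading bits and a low part holding a small correction term.
void inspectDoubleDouble() {
  APFloat DD(APFloat::PPCDoubleDouble(), "1.5");
  APInt Bits = DD.bitcastToAPInt();   // 128 bits: two raw IEEE doubles
  uint64_t Hi = Bits.getRawData()[0]; // high double
  uint64_t Lo = Bits.getRawData()[1]; // low double (correction)
  (void)Hi;
  (void)Lo;
}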
@@ -5761,8 +5758,9 @@ unsigned int DoubleAPFloat::convertToHexString(char *DST, unsigned int HexDigits, bool UpperCase, roundingMode RM) const { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - return APFloat(semPPCDoubleDoubleLegacy, bitcastToAPInt()) + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + return APFloat(APFloatBase::semPPCDoubleDoubleLegacy, bitcastToAPInt()) .convertToHexString(DST, HexDigits, UpperCase, RM); } @@ -5799,7 +5797,8 @@ bool DoubleAPFloat::isLargest() const { } bool DoubleAPFloat::isInteger() const { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); return Floats[0].isInteger() && Floats[1].isInteger(); } @@ -5807,8 +5806,9 @@ void DoubleAPFloat::toString(SmallVectorImpl<char> &Str, unsigned FormatPrecision, unsigned FormatMaxPadding, bool TruncateZero) const { - assert(Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - APFloat(semPPCDoubleDoubleLegacy, bitcastToAPInt()) + assert(Semantics == &APFloatBase::semPPCDoubleDouble && + "Unexpected Semantics"); + APFloat(APFloatBase::semPPCDoubleDoubleLegacy, bitcastToAPInt()) .toString(Str, FormatPrecision, FormatMaxPadding, TruncateZero); } @@ -5840,14 +5840,17 @@ int ilogb(const DoubleAPFloat &Arg) { DoubleAPFloat scalbn(const DoubleAPFloat &Arg, int Exp, APFloat::roundingMode RM) { - assert(Arg.Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); - return DoubleAPFloat(semPPCDoubleDouble, scalbn(Arg.Floats[0], Exp, RM), + assert(Arg.Semantics == &APFloatBase::PPCDoubleDouble() && + "Unexpected Semantics"); + return DoubleAPFloat(APFloatBase::PPCDoubleDouble(), + scalbn(Arg.Floats[0], Exp, RM), scalbn(Arg.Floats[1], Exp, RM)); } DoubleAPFloat frexp(const DoubleAPFloat &Arg, int &Exp, APFloat::roundingMode RM) { - assert(Arg.Semantics == &semPPCDoubleDouble && "Unexpected Semantics"); + assert(Arg.Semantics == &APFloatBase::PPCDoubleDouble() && + "Unexpected Semantics"); // Get the unbiased exponent e of the number, where |Arg| = m * 2^e for m in // [1.0, 2.0). 
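(For context: the scalbn/frexp hunks above only re-qualify the semantics constants. The [1.0, 2.0) mantissa range in the comment is an internal normalization step; the public llvm::frexp follows the C convention, returning a mantissa with magnitude in [0.5, 1.0). A small usage sketch under that documented contract; illustrative only, not part of the patch.)

#include "llvm/ADT/APFloat.h"
using namespace llvm;

void frexpSketch() {
  APFloat Val(APFloat::IEEEdouble(), "6.0");
  int Exp = 0;
  // C-style contract: Val == Mant * 2^Exp with |Mant| in [0.5, 1.0),
  // so 6.0 decomposes as 0.75 * 2^3.
  APFloat Mant = frexp(Val, Exp, APFloat::rmNearestTiesToEven);
  (void)Mant;
}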
@@ -5943,7 +5946,8 @@ DoubleAPFloat frexp(const DoubleAPFloat &Arg, int &Exp, } APFloat First = scalbn(Hi, -Exp, RM); - return DoubleAPFloat(semPPCDoubleDouble, std::move(First), std::move(Second)); + return DoubleAPFloat(APFloatBase::PPCDoubleDouble(), std::move(First), + std::move(Second)); } } // namespace detail @@ -5955,9 +5959,8 @@ APFloat::Storage::Storage(IEEEFloat F, const fltSemantics &Semantics) { } if (usesLayout<DoubleAPFloat>(Semantics)) { const fltSemantics& S = F.getSemantics(); - new (&Double) - DoubleAPFloat(Semantics, APFloat(std::move(F), S), - APFloat(semIEEEdouble)); + new (&Double) DoubleAPFloat(Semantics, APFloat(std::move(F), S), + APFloat(APFloatBase::IEEEdouble())); return; } llvm_unreachable("Unexpected semantics"); @@ -6065,8 +6068,9 @@ APFloat::opStatus APFloat::convert(const fltSemantics &ToSemantics, return U.IEEE.convert(ToSemantics, RM, losesInfo); if (usesLayout<IEEEFloat>(getSemantics()) && usesLayout<DoubleAPFloat>(ToSemantics)) { - assert(&ToSemantics == &semPPCDoubleDouble); - auto Ret = U.IEEE.convert(semPPCDoubleDoubleLegacy, RM, losesInfo); + assert(&ToSemantics == &APFloatBase::semPPCDoubleDouble); + auto Ret = + U.IEEE.convert(APFloatBase::semPPCDoubleDoubleLegacy, RM, losesInfo); *this = APFloat(ToSemantics, U.IEEE.bitcastToAPInt()); return Ret; } @@ -6113,13 +6117,15 @@ APFloat::opStatus APFloat::convertToInteger(APSInt &result, } double APFloat::convertToDouble() const { - if (&getSemantics() == (const llvm::fltSemantics *)&semIEEEdouble) + if (&getSemantics() == + (const llvm::fltSemantics *)&APFloatBase::semIEEEdouble) return getIEEE().convertToDouble(); assert(isRepresentableBy(getSemantics(), semIEEEdouble) && "Float semantics is not representable by IEEEdouble"); APFloat Temp = *this; bool LosesInfo; - opStatus St = Temp.convert(semIEEEdouble, rmNearestTiesToEven, &LosesInfo); + opStatus St = + Temp.convert(APFloatBase::semIEEEdouble, rmNearestTiesToEven, &LosesInfo); assert(!(St & opInexact) && !LosesInfo && "Unexpected imprecision"); (void)St; return Temp.getIEEE().convertToDouble(); @@ -6127,13 +6133,14 @@ double APFloat::convertToDouble() const { #ifdef HAS_IEE754_FLOAT128 float128 APFloat::convertToQuad() const { - if (&getSemantics() == (const llvm::fltSemantics *)&semIEEEquad) + if (&getSemantics() == (const llvm::fltSemantics *)&APFloatBase::semIEEEquad) return getIEEE().convertToQuad(); assert(isRepresentableBy(getSemantics(), semIEEEquad) && "Float semantics is not representable by IEEEquad"); APFloat Temp = *this; bool LosesInfo; - opStatus St = Temp.convert(semIEEEquad, rmNearestTiesToEven, &LosesInfo); + opStatus St = + Temp.convert(APFloatBase::semIEEEquad, rmNearestTiesToEven, &LosesInfo); assert(!(St & opInexact) && !LosesInfo && "Unexpected imprecision"); (void)St; return Temp.getIEEE().convertToQuad(); @@ -6141,18 +6148,84 @@ float128 APFloat::convertToQuad() const { #endif float APFloat::convertToFloat() const { - if (&getSemantics() == (const llvm::fltSemantics *)&semIEEEsingle) + if (&getSemantics() == + (const llvm::fltSemantics *)&APFloatBase::semIEEEsingle) return getIEEE().convertToFloat(); assert(isRepresentableBy(getSemantics(), semIEEEsingle) && "Float semantics is not representable by IEEEsingle"); APFloat Temp = *this; bool LosesInfo; - opStatus St = Temp.convert(semIEEEsingle, rmNearestTiesToEven, &LosesInfo); + opStatus St = + Temp.convert(APFloatBase::semIEEEsingle, rmNearestTiesToEven, &LosesInfo); assert(!(St & opInexact) && !LosesInfo && "Unexpected imprecision"); (void)St; return 
Temp.getIEEE().convertToFloat(); } +APFloat::Storage::~Storage() { + if (usesLayout<IEEEFloat>(*semantics)) { + IEEE.~IEEEFloat(); + return; + } + if (usesLayout<DoubleAPFloat>(*semantics)) { + Double.~DoubleAPFloat(); + return; + } + llvm_unreachable("Unexpected semantics"); +} + +APFloat::Storage::Storage(const APFloat::Storage &RHS) { + if (usesLayout<IEEEFloat>(*RHS.semantics)) { + new (this) IEEEFloat(RHS.IEEE); + return; + } + if (usesLayout<DoubleAPFloat>(*RHS.semantics)) { + new (this) DoubleAPFloat(RHS.Double); + return; + } + llvm_unreachable("Unexpected semantics"); +} + +APFloat::Storage::Storage(APFloat::Storage &&RHS) { + if (usesLayout<IEEEFloat>(*RHS.semantics)) { + new (this) IEEEFloat(std::move(RHS.IEEE)); + return; + } + if (usesLayout<DoubleAPFloat>(*RHS.semantics)) { + new (this) DoubleAPFloat(std::move(RHS.Double)); + return; + } + llvm_unreachable("Unexpected semantics"); +} + +APFloat::Storage &APFloat::Storage::operator=(const APFloat::Storage &RHS) { + if (usesLayout<IEEEFloat>(*semantics) && + usesLayout<IEEEFloat>(*RHS.semantics)) { + IEEE = RHS.IEEE; + } else if (usesLayout<DoubleAPFloat>(*semantics) && + usesLayout<DoubleAPFloat>(*RHS.semantics)) { + Double = RHS.Double; + } else if (this != &RHS) { + this->~Storage(); + new (this) Storage(RHS); + } + return *this; +} + +APFloat::Storage &APFloat::Storage::operator=(APFloat::Storage &&RHS) { + if (usesLayout<IEEEFloat>(*semantics) && + usesLayout<IEEEFloat>(*RHS.semantics)) { + IEEE = std::move(RHS.IEEE); + } else if (usesLayout<DoubleAPFloat>(*semantics) && + usesLayout<DoubleAPFloat>(*RHS.semantics)) { + Double = std::move(RHS.Double); + } else if (this != &RHS) { + this->~Storage(); + new (this) Storage(std::move(RHS)); + } + return *this; +} + } // namespace llvm #undef APFLOAT_DISPATCH_ON_SEMANTICS diff --git a/llvm/lib/Support/Unix/Signals.inc b/llvm/lib/Support/Unix/Signals.inc index 573ad82..78d6540 100644 --- a/llvm/lib/Support/Unix/Signals.inc +++ b/llvm/lib/Support/Unix/Signals.inc @@ -868,8 +868,7 @@ void llvm::sys::PrintStackTrace(raw_ostream &OS, int Depth) { nwidth = strlen(name) - 1; } - if (nwidth > width) - width = nwidth; + width = std::max(nwidth, width); } for (int i = 0; i < depth; ++i) { diff --git a/llvm/lib/Target/AArch64/AArch64ExpandImm.cpp b/llvm/lib/Target/AArch64/AArch64ExpandImm.cpp index 9801627..e9660ac1 100644 --- a/llvm/lib/Target/AArch64/AArch64ExpandImm.cpp +++ b/llvm/lib/Target/AArch64/AArch64ExpandImm.cpp @@ -585,7 +585,7 @@ void AArch64_IMM::expandMOVImm(uint64_t Imm, unsigned BitSize, uint64_t ShiftedMask = (0xFFFFULL << Shift); uint64_t ZeroChunk = UImm & ~ShiftedMask; uint64_t OneChunk = UImm | ShiftedMask; - uint64_t RotatedImm = (UImm << 32) | (UImm >> 32); + uint64_t RotatedImm = llvm::rotl(UImm, 32); uint64_t ReplicateChunk = ZeroChunk | (RotatedImm & ShiftedMask); if (AArch64_AM::processLogicalImmediate(ZeroChunk, BitSize, Encoding) || AArch64_AM::processLogicalImmediate(OneChunk, BitSize, Encoding) || diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 662d84b..a81de5c 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -27602,6 +27602,15 @@ static SDValue performPTestFirstCombine(SDNode *N, static SDValue performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) { + SDLoc DL(N); + + // If a DUP(Op0) already exists, reuse it for the scalar_to_vector. 
+ if (DCI.isAfterLegalizeDAG()) { + if (SDNode *LN = DCI.DAG.getNodeIfExists(AArch64ISD::DUP, N->getVTList(), + N->getOperand(0))) + return SDValue(LN, 0); + } + // Let's do below transform. // // t34: v4i32 = AArch64ISD::UADDLV t2 @@ -27638,7 +27647,6 @@ performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, return SDValue(); // Let's generate new sequence with AArch64ISD::NVCAST. - SDLoc DL(N); SDValue EXTRACT_SUBVEC = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v2i32, UADDLV, DAG.getConstant(0, DL, MVT::i64)); diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp index 12c600f..d5117da 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -700,7 +700,7 @@ static unsigned removeCopies(const MachineRegisterInfo &MRI, unsigned VReg) { // csel instruction. If so, return the folded opcode, and the replacement // register. static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg, - unsigned *NewVReg = nullptr) { + unsigned *NewReg = nullptr) { VReg = removeCopies(MRI, VReg); if (!Register::isVirtualRegister(VReg)) return 0; @@ -708,8 +708,37 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg, bool Is64Bit = AArch64::GPR64allRegClass.hasSubClassEq(MRI.getRegClass(VReg)); const MachineInstr *DefMI = MRI.getVRegDef(VReg); unsigned Opc = 0; - unsigned SrcOpNum = 0; + unsigned SrcReg = 0; switch (DefMI->getOpcode()) { + case AArch64::SUBREG_TO_REG: + // Check for the following way to define a 64-bit immediate: + // %0:gpr32 = MOVi32imm 1 + // %1:gpr64 = SUBREG_TO_REG 0, %0:gpr32, %subreg.sub_32 + if (!DefMI->getOperand(1).isImm() || DefMI->getOperand(1).getImm() != 0) + return 0; + if (!DefMI->getOperand(2).isReg()) + return 0; + if (!DefMI->getOperand(3).isImm() || + DefMI->getOperand(3).getImm() != AArch64::sub_32) + return 0; + DefMI = MRI.getVRegDef(DefMI->getOperand(2).getReg()); + if (DefMI->getOpcode() != AArch64::MOVi32imm) + return 0; + if (!DefMI->getOperand(1).isImm() || DefMI->getOperand(1).getImm() != 1) + return 0; + assert(Is64Bit); + SrcReg = AArch64::XZR; + Opc = AArch64::CSINCXr; + break; + + case AArch64::MOVi32imm: + case AArch64::MOVi64imm: + if (!DefMI->getOperand(1).isImm() || DefMI->getOperand(1).getImm() != 1) + return 0; + SrcReg = Is64Bit ? AArch64::XZR : AArch64::WZR; + Opc = Is64Bit ? AArch64::CSINCXr : AArch64::CSINCWr; + break; + case AArch64::ADDSXri: case AArch64::ADDSWri: // if NZCV is used, do not fold. @@ -724,7 +753,7 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg, if (!DefMI->getOperand(2).isImm() || DefMI->getOperand(2).getImm() != 1 || DefMI->getOperand(3).getImm() != 0) return 0; - SrcOpNum = 1; + SrcReg = DefMI->getOperand(1).getReg(); Opc = Is64Bit ? AArch64::CSINCXr : AArch64::CSINCWr; break; @@ -734,7 +763,7 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg, unsigned ZReg = removeCopies(MRI, DefMI->getOperand(1).getReg()); if (ZReg != AArch64::XZR && ZReg != AArch64::WZR) return 0; - SrcOpNum = 2; + SrcReg = DefMI->getOperand(2).getReg(); Opc = Is64Bit ? AArch64::CSINVXr : AArch64::CSINVWr; break; } @@ -753,17 +782,17 @@ static unsigned canFoldIntoCSel(const MachineRegisterInfo &MRI, unsigned VReg, unsigned ZReg = removeCopies(MRI, DefMI->getOperand(1).getReg()); if (ZReg != AArch64::XZR && ZReg != AArch64::WZR) return 0; - SrcOpNum = 2; + SrcReg = DefMI->getOperand(2).getReg(); Opc = Is64Bit ?
AArch64::CSNEGXr : AArch64::CSNEGWr; break; } default: return 0; } - assert(Opc && SrcOpNum && "Missing parameters"); + assert(Opc && SrcReg && "Missing parameters"); - if (NewVReg) - *NewVReg = DefMI->getOperand(SrcOpNum).getReg(); + if (NewReg) + *NewReg = SrcReg; return Opc; } @@ -964,28 +993,34 @@ void AArch64InstrInfo::insertSelect(MachineBasicBlock &MBB, // Try folding simple instructions into the csel. if (TryFold) { - unsigned NewVReg = 0; - unsigned FoldedOpc = canFoldIntoCSel(MRI, TrueReg, &NewVReg); + unsigned NewReg = 0; + unsigned FoldedOpc = canFoldIntoCSel(MRI, TrueReg, &NewReg); if (FoldedOpc) { // The folded opcodes csinc, csinv and csneg apply the operation to // FalseReg, so we need to invert the condition. CC = AArch64CC::getInvertedCondCode(CC); TrueReg = FalseReg; } else - FoldedOpc = canFoldIntoCSel(MRI, FalseReg, &NewVReg); + FoldedOpc = canFoldIntoCSel(MRI, FalseReg, &NewReg); // Fold the operation. Leave any dead instructions for DCE to clean up. if (FoldedOpc) { - FalseReg = NewVReg; + FalseReg = NewReg; Opc = FoldedOpc; - // The extends the live range of NewVReg. - MRI.clearKillFlags(NewVReg); + // Extend the live range of NewReg. + MRI.clearKillFlags(NewReg); } } // Pull all virtual registers into the appropriate class. MRI.constrainRegClass(TrueReg, RC); - MRI.constrainRegClass(FalseReg, RC); + // FalseReg might be WZR or XZR if the folded operand is a literal 1. + assert( + (FalseReg.isVirtual() || FalseReg == AArch64::WZR || + FalseReg == AArch64::XZR) && + "FalseReg was folded into a non-virtual register other than WZR or XZR"); + if (FalseReg.isVirtual()) + MRI.constrainRegClass(FalseReg, RC); // Insert the csel. BuildMI(MBB, I, DL, get(Opc), DstReg) @@ -5063,7 +5098,7 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, bool RenamableDest, bool RenamableSrc) const { if (AArch64::GPR32spRegClass.contains(DestReg) && - (AArch64::GPR32spRegClass.contains(SrcReg) || SrcReg == AArch64::WZR)) { + AArch64::GPR32spRegClass.contains(SrcReg)) { if (DestReg == AArch64::WSP || SrcReg == AArch64::WSP) { // If either operand is WSP, expand to ADD #0. if (Subtarget.hasZeroCycleRegMoveGPR64() && @@ -5088,21 +5123,14 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, .addImm(0) .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); } - } else if (SrcReg == AArch64::WZR && Subtarget.hasZeroCycleZeroingGPR32()) { - BuildMI(MBB, I, DL, get(AArch64::MOVZWi), DestReg) - .addImm(0) - .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); } else if (Subtarget.hasZeroCycleRegMoveGPR64() && !Subtarget.hasZeroCycleRegMoveGPR32()) { // Cyclone recognizes "ORR Xd, XZR, Xm" as a zero-cycle register move. MCRegister DestRegX = RI.getMatchingSuperReg(DestReg, AArch64::sub_32, &AArch64::GPR64spRegClass); assert(DestRegX.isValid() && "Destination super-reg not valid"); - MCRegister SrcRegX = - SrcReg == AArch64::WZR - ? AArch64::XZR - : RI.getMatchingSuperReg(SrcReg, AArch64::sub_32, - &AArch64::GPR64spRegClass); + MCRegister SrcRegX = RI.getMatchingSuperReg(SrcReg, AArch64::sub_32, + &AArch64::GPR64spRegClass); assert(SrcRegX.isValid() && "Source super-reg not valid"); // This instruction is reading and writing X registers.
This may upset // the register scavenger and machine verifier, so we need to indicate @@ -5121,6 +5149,51 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, return; } + // GPR32 zeroing + if (AArch64::GPR32spRegClass.contains(DestReg) && SrcReg == AArch64::WZR) { + if (Subtarget.hasZeroCycleZeroingGPR32()) { + BuildMI(MBB, I, DL, get(AArch64::MOVZWi), DestReg) + .addImm(0) + .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); + } else { + BuildMI(MBB, I, DL, get(AArch64::ORRWrr), DestReg) + .addReg(AArch64::WZR) + .addReg(AArch64::WZR); + } + return; + } + + if (AArch64::GPR64spRegClass.contains(DestReg) && + AArch64::GPR64spRegClass.contains(SrcReg)) { + if (DestReg == AArch64::SP || SrcReg == AArch64::SP) { + // If either operand is SP, expand to ADD #0. + BuildMI(MBB, I, DL, get(AArch64::ADDXri), DestReg) + .addReg(SrcReg, getKillRegState(KillSrc)) + .addImm(0) + .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); + } else { + // Otherwise, expand to ORR XZR. + BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestReg) + .addReg(AArch64::XZR) + .addReg(SrcReg, getKillRegState(KillSrc)); + } + return; + } + + // GPR64 zeroing + if (AArch64::GPR64spRegClass.contains(DestReg) && SrcReg == AArch64::XZR) { + if (Subtarget.hasZeroCycleZeroingGPR64()) { + BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestReg) + .addImm(0) + .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); + } else { + BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestReg) + .addReg(AArch64::XZR) + .addReg(AArch64::XZR); + } + return; + } + // Copy a Predicate register by ORRing with itself. if (AArch64::PPRRegClass.contains(DestReg) && AArch64::PPRRegClass.contains(SrcReg)) { @@ -5205,27 +5278,6 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, return; } - if (AArch64::GPR64spRegClass.contains(DestReg) && - (AArch64::GPR64spRegClass.contains(SrcReg) || SrcReg == AArch64::XZR)) { - if (DestReg == AArch64::SP || SrcReg == AArch64::SP) { - // If either operand is SP, expand to ADD #0. - BuildMI(MBB, I, DL, get(AArch64::ADDXri), DestReg) - .addReg(SrcReg, getKillRegState(KillSrc)) - .addImm(0) - .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); - } else if (SrcReg == AArch64::XZR && Subtarget.hasZeroCycleZeroingGPR64()) { - BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestReg) - .addImm(0) - .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); - } else { - // Otherwise, expand to ORR XZR. - BuildMI(MBB, I, DL, get(AArch64::ORRXrr), DestReg) - .addReg(AArch64::XZR) - .addReg(SrcReg, getKillRegState(KillSrc)); - } - return; - } - // Copy a DDDD register quad by copying the individual sub-registers. if (AArch64::DDDDRegClass.contains(DestReg) && AArch64::DDDDRegClass.contains(SrcReg)) { diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 479e345..e3370d3 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -5722,7 +5722,7 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost( } // Add additional cost for the extends that would need to be inserted. 
- return Cost + 4; + return Cost + 2; } InstructionCost diff --git a/llvm/lib/Target/AMDGPU/AMDGPUAsanInstrumentation.cpp b/llvm/lib/Target/AMDGPU/AMDGPUAsanInstrumentation.cpp index 19e2a6a..93732a7 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUAsanInstrumentation.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUAsanInstrumentation.cpp @@ -244,11 +244,8 @@ void getInterestingMemoryOperands( // Masked store has an initial operand for the value. unsigned OpOffset = IsWrite ? 1 : 0; Type *Ty = IsWrite ? CI->getArgOperand(0)->getType() : CI->getType(); - MaybeAlign Alignment = Align(1); - // Otherwise no alignment guarantees. We probably got Undef. - if (auto *Op = dyn_cast<ConstantInt>(CI->getOperand(1 + OpOffset))) - Alignment = Op->getMaybeAlignValue(); - Value *Mask = CI->getOperand(2 + OpOffset); + MaybeAlign Alignment = CI->getParamAlign(OpOffset); + Value *Mask = CI->getOperand(1 + OpOffset); Interesting.emplace_back(I, OpOffset, IsWrite, Ty, Alignment, Mask); break; } diff --git a/llvm/lib/Target/AMDGPU/AMDGPUBarrierLatency.cpp b/llvm/lib/Target/AMDGPU/AMDGPUBarrierLatency.cpp new file mode 100644 index 0000000..30a1f05 --- /dev/null +++ b/llvm/lib/Target/AMDGPU/AMDGPUBarrierLatency.cpp @@ -0,0 +1,73 @@ +//===--- AMDGPUBarrierLatency.cpp - AMDGPU Barrier Latency ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// \file This file contains a DAG scheduling mutation to add latency to +/// barrier edges between ATOMIC_FENCE instructions and preceding +/// memory accesses potentially affected by the fence. +/// This encourages the scheduling of more instructions before +/// ATOMIC_FENCE instructions. ATOMIC_FENCE instructions may +/// introduce wait counting or indicate an impending S_BARRIER +/// wait. Having more instructions in-flight across these +/// constructs improves latency hiding. +// +//===----------------------------------------------------------------------===// + +#include "AMDGPUBarrierLatency.h" +#include "MCTargetDesc/AMDGPUMCTargetDesc.h" +#include "SIInstrInfo.h" +#include "llvm/CodeGen/ScheduleDAGInstrs.h" + +using namespace llvm; + +namespace { + +class BarrierLatency : public ScheduleDAGMutation { +public: + BarrierLatency() = default; + void apply(ScheduleDAGInstrs *DAG) override; +}; + +void BarrierLatency::apply(ScheduleDAGInstrs *DAG) { + constexpr unsigned SyntheticLatency = 2000; + for (SUnit &SU : DAG->SUnits) { + const MachineInstr *MI = SU.getInstr(); + if (MI->getOpcode() != AMDGPU::ATOMIC_FENCE) + continue; + + // Update latency on barrier edges of ATOMIC_FENCE. + // We don't consider the scope of the fence or type of instruction + // involved in the barrier edge. 
+ for (SDep &PredDep : SU.Preds) { + if (!PredDep.isBarrier()) + continue; + SUnit *PredSU = PredDep.getSUnit(); + MachineInstr *MI = PredSU->getInstr(); + // Only consider memory loads. + if (!MI->mayLoad() || MI->mayStore()) + continue; + SDep ForwardD = PredDep; + ForwardD.setSUnit(&SU); + for (SDep &SuccDep : PredSU->Succs) { + if (SuccDep == ForwardD) { + SuccDep.setLatency(SuccDep.getLatency() + SyntheticLatency); + break; + } + } + PredDep.setLatency(PredDep.getLatency() + SyntheticLatency); + PredSU->setDepthDirty(); + SU.setDepthDirty(); + } + } +} + +} // end anonymous namespace + +std::unique_ptr<ScheduleDAGMutation> +llvm::createAMDGPUBarrierLatencyDAGMutation() { + return std::make_unique<BarrierLatency>(); +} diff --git a/llvm/lib/Target/AMDGPU/AMDGPUBarrierLatency.h b/llvm/lib/Target/AMDGPU/AMDGPUBarrierLatency.h new file mode 100644 index 0000000..c23f0b9 --- /dev/null +++ b/llvm/lib/Target/AMDGPU/AMDGPUBarrierLatency.h @@ -0,0 +1,21 @@ +//===- AMDGPUBarrierLatency.h - AMDGPU Barrier Latency ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_AMDGPU_AMDGPUBARRIERLATENCY_H +#define LLVM_LIB_TARGET_AMDGPU_AMDGPUBARRIERLATENCY_H + +#include "llvm/CodeGen/ScheduleDAGMutation.h" +#include <memory> + +namespace llvm { + +std::unique_ptr<ScheduleDAGMutation> createAMDGPUBarrierLatencyDAGMutation(); + +} // namespace llvm + +#endif // LLVM_LIB_TARGET_AMDGPU_AMDGPUBARRIERLATENCY_H diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp index 12915c73..97c2c9c 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp @@ -3446,10 +3446,14 @@ bool AMDGPUInstructionSelector::selectBufferLoadLds(MachineInstr &MI) const { : 0); // swz MachineMemOperand *LoadMMO = *MI.memoperands_begin(); + // Don't set the offset value here because the pointer points to the base of + // the buffer.
MachinePointerInfo LoadPtrI = LoadMMO->getPointerInfo(); - LoadPtrI.Offset = MI.getOperand(6 + OpOffset).getImm(); + MachinePointerInfo StorePtrI = LoadPtrI; - StorePtrI.V = nullptr; + LoadPtrI.V = PoisonValue::get(PointerType::get(MF->getFunction().getContext(), + AMDGPUAS::BUFFER_RESOURCE)); + LoadPtrI.AddrSpace = AMDGPUAS::BUFFER_RESOURCE; StorePtrI.AddrSpace = AMDGPUAS::LOCAL_ADDRESS; auto F = LoadMMO->getFlags() & @@ -3627,13 +3631,17 @@ bool AMDGPUInstructionSelector::selectGlobalLoadLds(MachineInstr &MI) const{ if (isSGPR(Addr)) MIB.addReg(VOffset); - MIB.add(MI.getOperand(4)) // offset - .add(MI.getOperand(5)); // cpol + MIB.add(MI.getOperand(4)); // offset + + unsigned Aux = MI.getOperand(5).getImm(); + MIB.addImm(Aux & ~AMDGPU::CPol::VIRTUAL_BITS); // cpol MachineMemOperand *LoadMMO = *MI.memoperands_begin(); MachinePointerInfo LoadPtrI = LoadMMO->getPointerInfo(); LoadPtrI.Offset = MI.getOperand(4).getImm(); MachinePointerInfo StorePtrI = LoadPtrI; + LoadPtrI.V = PoisonValue::get(PointerType::get(MF->getFunction().getContext(), + AMDGPUAS::GLOBAL_ADDRESS)); LoadPtrI.AddrSpace = AMDGPUAS::GLOBAL_ADDRESS; StorePtrI.AddrSpace = AMDGPUAS::LOCAL_ADDRESS; auto F = LoadMMO->getFlags() & diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp index 4958a20..996b55f 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp @@ -17,6 +17,7 @@ #include "AMDGPUTargetMachine.h" #include "AMDGPU.h" #include "AMDGPUAliasAnalysis.h" +#include "AMDGPUBarrierLatency.h" #include "AMDGPUCtorDtorLowering.h" #include "AMDGPUExportClustering.h" #include "AMDGPUExportKernelRuntimeHandles.h" @@ -639,6 +640,7 @@ createGCNMaxOccupancyMachineScheduler(MachineSchedContext *C) { DAG->addMutation(createIGroupLPDAGMutation(AMDGPU::SchedulingPhase::Initial)); DAG->addMutation(createAMDGPUMacroFusionDAGMutation()); DAG->addMutation(createAMDGPUExportClusteringDAGMutation()); + DAG->addMutation(createAMDGPUBarrierLatencyDAGMutation()); return DAG; } @@ -659,6 +661,7 @@ createGCNMaxMemoryClauseMachineScheduler(MachineSchedContext *C) { if (ST.shouldClusterStores()) DAG->addMutation(createStoreClusterDAGMutation(DAG->TII, DAG->TRI)); DAG->addMutation(createAMDGPUExportClusteringDAGMutation()); + DAG->addMutation(createAMDGPUBarrierLatencyDAGMutation()); return DAG; } @@ -1197,6 +1200,7 @@ GCNTargetMachine::createPostMachineScheduler(MachineSchedContext *C) const { EnableVOPD) DAG->addMutation(createVOPDPairingMutation()); DAG->addMutation(createAMDGPUExportClusteringDAGMutation()); + DAG->addMutation(createAMDGPUBarrierLatencyDAGMutation()); return DAG; } //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/AMDGPU/CMakeLists.txt b/llvm/lib/Target/AMDGPU/CMakeLists.txt index 13f727b68..a1e0e52 100644 --- a/llvm/lib/Target/AMDGPU/CMakeLists.txt +++ b/llvm/lib/Target/AMDGPU/CMakeLists.txt @@ -52,6 +52,7 @@ add_llvm_target(AMDGPUCodeGen AMDGPUAsmPrinter.cpp AMDGPUAtomicOptimizer.cpp AMDGPUAttributor.cpp + AMDGPUBarrierLatency.cpp AMDGPUCallLowering.cpp AMDGPUCodeGenPrepare.cpp AMDGPUCombinerHelper.cpp diff --git a/llvm/lib/Target/AMDGPU/GCNRegPressure.h b/llvm/lib/Target/AMDGPU/GCNRegPressure.h index 979a8b0..4b22c68 100644 --- a/llvm/lib/Target/AMDGPU/GCNRegPressure.h +++ b/llvm/lib/Target/AMDGPU/GCNRegPressure.h @@ -21,6 +21,7 @@ #include "llvm/CodeGen/LiveIntervals.h" #include "llvm/CodeGen/RegisterPressure.h" #include <algorithm> +#include <array> 
namespace llvm { @@ -45,7 +46,7 @@ struct GCNRegPressure { return !Value[SGPR] && !Value[VGPR] && !Value[AGPR] && !Value[AVGPR]; } - void clear() { std::fill(&Value[0], &Value[ValueArraySize], 0); } + void clear() { Value.fill(0); } unsigned getNumRegs(RegKind Kind) const { assert(Kind < TOTAL_KINDS); @@ -127,9 +128,7 @@ struct GCNRegPressure { bool less(const MachineFunction &MF, const GCNRegPressure &O, unsigned MaxOccupancy = std::numeric_limits<unsigned>::max()) const; - bool operator==(const GCNRegPressure &O) const { - return std::equal(&Value[0], &Value[ValueArraySize], O.Value); - } + bool operator==(const GCNRegPressure &O) const { return Value == O.Value; } bool operator!=(const GCNRegPressure &O) const { return !(*this == O); @@ -160,7 +159,7 @@ private: /// Pressure for all register kinds (first all regular register kinds, then /// all tuple register kinds). - unsigned Value[ValueArraySize]; + std::array<unsigned, ValueArraySize> Value; static unsigned getRegKind(const TargetRegisterClass *RC, const SIRegisterInfo *STI); diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp index f291e37..c8bbcbb 100644 --- a/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp +++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.cpp @@ -169,7 +169,6 @@ GCNSubtarget::GCNSubtarget(const Triple &TT, StringRef GPU, StringRef FS, : // clang-format off AMDGPUGenSubtargetInfo(TT, GPU, /*TuneCPU*/ GPU, FS), AMDGPUSubtarget(TT), - TargetTriple(TT), TargetID(*this), InstrItins(getInstrItineraryForCPU(GPU)), InstrInfo(initializeSubtargetDependencies(TT, GPU, FS)), diff --git a/llvm/lib/Target/AMDGPU/GCNSubtarget.h b/llvm/lib/Target/AMDGPU/GCNSubtarget.h index c2e6078..a466780 100644 --- a/llvm/lib/Target/AMDGPU/GCNSubtarget.h +++ b/llvm/lib/Target/AMDGPU/GCNSubtarget.h @@ -60,7 +60,6 @@ private: protected: // Basic subtarget description. - Triple TargetTriple; AMDGPU::IsaInfo::AMDGPUTargetID TargetID; unsigned Gen = INVALID; InstrItineraryData InstrItins; diff --git a/llvm/lib/Target/AMDGPU/R600ISelLowering.cpp b/llvm/lib/Target/AMDGPU/R600ISelLowering.cpp index 2aa54c9..09ef6ac 100644 --- a/llvm/lib/Target/AMDGPU/R600ISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/R600ISelLowering.cpp @@ -45,6 +45,9 @@ R600TargetLowering::R600TargetLowering(const TargetMachine &TM, // Legalize loads and stores to the private address space. setOperationAction(ISD::LOAD, {MVT::i32, MVT::v2i32, MVT::v4i32}, Custom); + // 32-bit ABS is legal for AMDGPU except for R600 + setOperationAction(ISD::ABS, MVT::i32, Expand); + // EXTLOAD should be the same as ZEXTLOAD. It is legal for some address // spaces, so it is custom lowered to handle those where it isn't. for (auto Op : {ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}) diff --git a/llvm/lib/Target/AMDGPU/SIDefines.h b/llvm/lib/Target/AMDGPU/SIDefines.h index ecc2824..b7a92a0 100644 --- a/llvm/lib/Target/AMDGPU/SIDefines.h +++ b/llvm/lib/Target/AMDGPU/SIDefines.h @@ -423,6 +423,9 @@ enum CPol { // Volatile (used to preserve/signal operation volatility for buffer // operations not a real instruction bit) VOLATILE = 1 << 31, + // The set of "cache policy" bits used for compiler features that + // do not correspond to hardware features.
+ VIRTUAL_BITS = VOLATILE, }; } // namespace CPol diff --git a/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp b/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp index 0189e7b..5c39f7a 100644 --- a/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIFrameLowering.cpp @@ -1034,16 +1034,13 @@ void SIFrameLowering::emitCSRSpillStores( StoreWWMRegisters(WWMCalleeSavedRegs); if (FuncInfo->isWholeWaveFunction()) { - // SI_WHOLE_WAVE_FUNC_SETUP has outlived its purpose, so we can remove - // it now. If we have already saved some WWM CSR registers, then the EXEC is - // already -1 and we don't need to do anything else. Otherwise, set EXEC to - // -1 here. + // If we have already saved some WWM CSR registers, then the EXEC is already + // -1 and we don't need to do anything else. Otherwise, set EXEC to -1 here. if (!ScratchExecCopy) buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL, /*IsProlog*/ true, /*EnableInactiveLanes*/ true); else if (WWMCalleeSavedRegs.empty()) EnableAllLanes(); - TII->getWholeWaveFunctionSetup(MF)->eraseFromParent(); } else if (ScratchExecCopy) { // FIXME: Split block and make terminator. BuildMI(MBB, MBBI, DL, TII->get(LMC.MovOpc), LMC.ExecReg) @@ -1340,6 +1337,11 @@ void SIFrameLowering::emitPrologue(MachineFunction &MF, "Needed to save BP but didn't save it anywhere"); assert((HasBP || !BPSaved) && "Saved BP but didn't need it"); + + if (FuncInfo->isWholeWaveFunction()) { + // SI_WHOLE_WAVE_FUNC_SETUP has outlived its purpose. + TII->getWholeWaveFunctionSetup(MF)->eraseFromParent(); + } } void SIFrameLowering::emitEpilogue(MachineFunction &MF, diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp index a2841c11..a757421 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -1651,6 +1651,9 @@ bool SITargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, Info.memVT = EVT::getIntegerVT(CI.getContext(), Width * 8); Info.ptrVal = CI.getArgOperand(1); Info.flags |= MachineMemOperand::MOLoad | MachineMemOperand::MOStore; + auto *Aux = cast<ConstantInt>(CI.getArgOperand(CI.arg_size() - 1)); + if (Aux->getZExtValue() & AMDGPU::CPol::VOLATILE) + Info.flags |= MachineMemOperand::MOVolatile; return true; } case Intrinsic::amdgcn_ds_bvh_stack_rtn: @@ -11219,8 +11222,8 @@ SDValue SITargetLowering::LowerINTRINSIC_VOID(SDValue Op, MachinePointerInfo StorePtrI = LoadPtrI; LoadPtrI.V = PoisonValue::get( - PointerType::get(*DAG.getContext(), AMDGPUAS::GLOBAL_ADDRESS)); - LoadPtrI.AddrSpace = AMDGPUAS::GLOBAL_ADDRESS; + PointerType::get(*DAG.getContext(), AMDGPUAS::BUFFER_RESOURCE)); + LoadPtrI.AddrSpace = AMDGPUAS::BUFFER_RESOURCE; StorePtrI.AddrSpace = AMDGPUAS::LOCAL_ADDRESS; auto F = LoadMMO->getFlags() & @@ -11307,7 +11310,11 @@ SDValue SITargetLowering::LowerINTRINSIC_VOID(SDValue Op, } Ops.push_back(Op.getOperand(5)); // Offset - Ops.push_back(Op.getOperand(6)); // CPol + + unsigned Aux = Op.getConstantOperandVal(6); + Ops.push_back(DAG.getTargetConstant(Aux & ~AMDGPU::CPol::VIRTUAL_BITS, DL, + MVT::i32)); // CPol + Ops.push_back(M0Val.getValue(0)); // Chain Ops.push_back(M0Val.getValue(1)); // Glue diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp index 50447f4..2ff2d2f 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp @@ -4032,28 +4032,31 @@ static unsigned getNewFMAInst(const GCNSubtarget &ST, unsigned Opc) { } } +/// Helper struct for the implementation of 3-address conversion to 
communicate +/// updates made to instruction operands. +struct SIInstrInfo::ThreeAddressUpdates { + /// Other instruction whose def is no longer used by the converted + /// instruction. + MachineInstr *RemoveMIUse = nullptr; +}; + MachineInstr *SIInstrInfo::convertToThreeAddress(MachineInstr &MI, LiveVariables *LV, LiveIntervals *LIS) const { MachineBasicBlock &MBB = *MI.getParent(); - unsigned Opc = MI.getOpcode(); + ThreeAddressUpdates U; + MachineInstr *NewMI = convertToThreeAddressImpl(MI, U); - // Handle MFMA. - int NewMFMAOpc = AMDGPU::getMFMAEarlyClobberOp(Opc); - if (NewMFMAOpc != -1) { - MachineInstrBuilder MIB = - BuildMI(MBB, MI, MI.getDebugLoc(), get(NewMFMAOpc)); - for (unsigned I = 0, E = MI.getNumOperands(); I != E; ++I) - MIB.add(MI.getOperand(I)); - updateLiveVariables(LV, MI, *MIB); + if (NewMI) { + updateLiveVariables(LV, MI, *NewMI); if (LIS) { - LIS->ReplaceMachineInstrInMaps(MI, *MIB); + LIS->ReplaceMachineInstrInMaps(MI, *NewMI); // SlotIndex of defs needs to be updated when converting to early-clobber - MachineOperand &Def = MIB->getOperand(0); + MachineOperand &Def = NewMI->getOperand(0); if (Def.isEarlyClobber() && Def.isReg() && LIS->hasInterval(Def.getReg())) { - SlotIndex OldIndex = LIS->getInstructionIndex(*MIB).getRegSlot(false); - SlotIndex NewIndex = LIS->getInstructionIndex(*MIB).getRegSlot(true); + SlotIndex OldIndex = LIS->getInstructionIndex(*NewMI).getRegSlot(false); + SlotIndex NewIndex = LIS->getInstructionIndex(*NewMI).getRegSlot(true); auto &LI = LIS->getInterval(Def.getReg()); auto UpdateDefIndex = [&](LiveRange &LR) { auto *S = LR.find(OldIndex); @@ -4068,6 +4071,58 @@ MachineInstr *SIInstrInfo::convertToThreeAddress(MachineInstr &MI, UpdateDefIndex(SR); } } + } + + if (U.RemoveMIUse) { + MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo(); + // The only user is the instruction which will be killed. + Register DefReg = U.RemoveMIUse->getOperand(0).getReg(); + + if (MRI.hasOneNonDBGUse(DefReg)) { + // We cannot just remove the DefMI here, calling pass will crash. + U.RemoveMIUse->setDesc(get(AMDGPU::IMPLICIT_DEF)); + U.RemoveMIUse->getOperand(0).setIsDead(true); + for (unsigned I = U.RemoveMIUse->getNumOperands() - 1; I != 0; --I) + U.RemoveMIUse->removeOperand(I); + if (LV) + LV->getVarInfo(DefReg).AliveBlocks.clear(); + } + + if (LIS) { + LiveInterval &DefLI = LIS->getInterval(DefReg); + + // We cannot delete the original instruction here, so hack out the use + // in the original instruction with a dummy register so we can use + // shrinkToUses to deal with any multi-use edge cases. Other targets do + // not have the complexity of deleting a use to consider here. + Register DummyReg = MRI.cloneVirtualRegister(DefReg); + for (MachineOperand &MIOp : MI.uses()) { + if (MIOp.isReg() && MIOp.getReg() == DefReg) { + MIOp.setIsUndef(true); + MIOp.setReg(DummyReg); + } + } + + LIS->shrinkToUses(&DefLI); + } + } + + return NewMI; +} + +MachineInstr * +SIInstrInfo::convertToThreeAddressImpl(MachineInstr &MI, + ThreeAddressUpdates &U) const { + MachineBasicBlock &MBB = *MI.getParent(); + unsigned Opc = MI.getOpcode(); + + // Handle MFMA. 
+ int NewMFMAOpc = AMDGPU::getMFMAEarlyClobberOp(Opc); + if (NewMFMAOpc != -1) { + MachineInstrBuilder MIB = + BuildMI(MBB, MI, MI.getDebugLoc(), get(NewMFMAOpc)); + for (unsigned I = 0, E = MI.getNumOperands(); I != E; ++I) + MIB.add(MI.getOperand(I)); return MIB; } @@ -4077,11 +4132,6 @@ MachineInstr *SIInstrInfo::convertToThreeAddress(MachineInstr &MI, .setMIFlags(MI.getFlags()); for (unsigned I = 0, E = MI.getNumOperands(); I != E; ++I) MIB->addOperand(MI.getOperand(I)); - - updateLiveVariables(LV, MI, *MIB); - if (LIS) - LIS->ReplaceMachineInstrInMaps(MI, *MIB); - return MIB; } @@ -4152,39 +4202,6 @@ MachineInstr *SIInstrInfo::convertToThreeAddress(MachineInstr &MI, (ST.getConstantBusLimit(Opc) > 1 || !Src0->isReg() || !RI.isSGPRReg(MBB.getParent()->getRegInfo(), Src0->getReg()))) { MachineInstr *DefMI; - const auto killDef = [&]() -> void { - MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo(); - // The only user is the instruction which will be killed. - Register DefReg = DefMI->getOperand(0).getReg(); - - if (MRI.hasOneNonDBGUse(DefReg)) { - // We cannot just remove the DefMI here, calling pass will crash. - DefMI->setDesc(get(AMDGPU::IMPLICIT_DEF)); - DefMI->getOperand(0).setIsDead(true); - for (unsigned I = DefMI->getNumOperands() - 1; I != 0; --I) - DefMI->removeOperand(I); - if (LV) - LV->getVarInfo(DefReg).AliveBlocks.clear(); - } - - if (LIS) { - LiveInterval &DefLI = LIS->getInterval(DefReg); - - // We cannot delete the original instruction here, so hack out the use - // in the original instruction with a dummy register so we can use - // shrinkToUses to deal with any multi-use edge cases. Other targets do - // not have the complexity of deleting a use to consider here. - Register DummyReg = MRI.cloneVirtualRegister(DefReg); - for (MachineOperand &MIOp : MI.uses()) { - if (MIOp.isReg() && MIOp.getReg() == DefReg) { - MIOp.setIsUndef(true); - MIOp.setReg(DummyReg); - } - } - - LIS->shrinkToUses(&DefLI); - } - }; int64_t Imm; if (!Src0Literal && getFoldableImm(Src2, Imm, &DefMI)) { @@ -4196,10 +4213,7 @@ MachineInstr *SIInstrInfo::convertToThreeAddress(MachineInstr &MI, .add(*Src1) .addImm(Imm) .setMIFlags(MI.getFlags()); - updateLiveVariables(LV, MI, *MIB); - if (LIS) - LIS->ReplaceMachineInstrInMaps(MI, *MIB); - killDef(); + U.RemoveMIUse = DefMI; return MIB; } } @@ -4212,11 +4226,7 @@ MachineInstr *SIInstrInfo::convertToThreeAddress(MachineInstr &MI, .addImm(Imm) .add(*Src2) .setMIFlags(MI.getFlags()); - updateLiveVariables(LV, MI, *MIB); - - if (LIS) - LIS->ReplaceMachineInstrInMaps(MI, *MIB); - killDef(); + U.RemoveMIUse = DefMI; return MIB; } } @@ -4235,12 +4245,7 @@ MachineInstr *SIInstrInfo::convertToThreeAddress(MachineInstr &MI, .addImm(Imm) .add(*Src2) .setMIFlags(MI.getFlags()); - updateLiveVariables(LV, MI, *MIB); - - if (LIS) - LIS->ReplaceMachineInstrInMaps(MI, *MIB); - if (DefMI) - killDef(); + U.RemoveMIUse = DefMI; return MIB; } } @@ -4269,9 +4274,6 @@ MachineInstr *SIInstrInfo::convertToThreeAddress(MachineInstr &MI, .setMIFlags(MI.getFlags()); if (AMDGPU::hasNamedOperand(NewOpc, AMDGPU::OpName::op_sel)) MIB.addImm(OpSel ? 
OpSel->getImm() : 0);
-  updateLiveVariables(LV, MI, *MIB);
-  if (LIS)
-    LIS->ReplaceMachineInstrInMaps(MI, *MIB);
   return MIB;
 }
 
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.h b/llvm/lib/Target/AMDGPU/SIInstrInfo.h
index df27ec1..e1d7a07 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.h
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.h
@@ -88,6 +88,8 @@ private:
 };
 
 class SIInstrInfo final : public AMDGPUGenInstrInfo {
+  struct ThreeAddressUpdates;
+
 private:
   const SIRegisterInfo RI;
   const GCNSubtarget &ST;
@@ -190,6 +192,9 @@ private:
 
   bool resultDependsOnExec(const MachineInstr &MI) const;
 
+  MachineInstr *convertToThreeAddressImpl(MachineInstr &MI,
+                                          ThreeAddressUpdates &Updates) const;
+
 protected:
   /// If the specific machine instruction is an instruction that moves/copies
   /// value from one register to another register return destination and source
 
diff --git a/llvm/lib/Target/AMDGPU/SIInstructions.td b/llvm/lib/Target/AMDGPU/SIInstructions.td
index 27e5ee9c..6f1feb1 100644
--- a/llvm/lib/Target/AMDGPU/SIInstructions.td
+++ b/llvm/lib/Target/AMDGPU/SIInstructions.td
@@ -2223,8 +2223,8 @@ def : GCNPat <
 
 def : GCNPat <
   (DivergentUnaryFrag<fneg> (v2f32 VReg_64:$src)),
-  (V_PK_ADD_F32 11 /* OP_SEL_1 | NEG_LO | HEG_HI */, VReg_64:$src,
-                11 /* OP_SEL_1 | NEG_LO | HEG_HI */, (i64 0),
+  (V_PK_ADD_F32 !or(SRCMODS.OP_SEL_1, SRCMODS.NEG, SRCMODS.NEG_HI), VReg_64:$src,
+                !or(SRCMODS.OP_SEL_1, SRCMODS.NEG, SRCMODS.NEG_HI), (i64 0),
                 0, 0, 0, 0, 0)
 > {
   let SubtargetPredicate = HasPackedFP32Ops;
@@ -3481,30 +3481,6 @@ def : GCNPat<
 >;
 } // End True16Predicate
 
-let True16Predicate = UseRealTrue16Insts in {
-def : GCNPat<
-  (fcanonicalize (f16 (VOP3Mods f16:$src, i32:$src_mods))),
-  (V_MUL_F16_t16_e64 0, (i16 CONST.FP16_ONE), $src_mods, $src, 0/*Clamp*/, /*omod*/0, /*opsel*/0)
->;
-
-def : GCNPat<
-  (fcanonicalize (f16 (fneg (VOP3Mods f16:$src, i32:$src_mods)))),
-  (V_MUL_F16_t16_e64 0, (i16 CONST.FP16_NEG_ONE), $src_mods, $src, 0/*Clamp*/, /*omod*/0, /*opsel*/0)
->;
-} // End True16Predicate
-
-let True16Predicate = UseFakeTrue16Insts in {
-def : GCNPat<
-  (fcanonicalize (f16 (VOP3Mods f16:$src, i32:$src_mods))),
-  (V_MUL_F16_fake16_e64 0, (i32 CONST.FP16_ONE), $src_mods, $src)
->;
-
-def : GCNPat<
-  (fcanonicalize (f16 (fneg (VOP3Mods f16:$src, i32:$src_mods)))),
-  (V_MUL_F16_fake16_e64 0, (i32 CONST.FP16_NEG_ONE), $src_mods, $src)
->;
-} // End True16Predicate
-
 def : GCNPat<
   (fcanonicalize (v2f16 (VOP3PMods v2f16:$src, i32:$src_mods))),
   (V_PK_MUL_F16 0, (i32 CONST.FP16_ONE), $src_mods, $src, DSTCLAMP.NONE)
 
diff --git a/llvm/lib/Target/AMDGPU/SIMemoryLegalizer.cpp b/llvm/lib/Target/AMDGPU/SIMemoryLegalizer.cpp
index 484861d..07264d9 100644
--- a/llvm/lib/Target/AMDGPU/SIMemoryLegalizer.cpp
+++ b/llvm/lib/Target/AMDGPU/SIMemoryLegalizer.cpp
@@ -25,6 +25,7 @@
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/MemoryModelRelaxationAnnotations.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/Support/AMDGPUAddrSpace.h"
 #include "llvm/Support/AtomicOrdering.h"
 #include "llvm/TargetParser/TargetParser.h"
 
@@ -277,6 +278,12 @@ public:
   /// rmw operation, "std::nullopt" otherwise.
   std::optional<SIMemOpInfo>
   getAtomicCmpxchgOrRmwInfo(const MachineBasicBlock::iterator &MI) const;
+
+  /// \returns DMA to LDS info if \p MI is a direct-to/from-LDS load/store,
+  /// along with an indication of whether this is a load or store. If it is not
+  /// a direct-to-LDS operation, returns std::nullopt.
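+  /// SIMemoryLegalizer::run uses this to dispatch direct-to-LDS operations to
+  /// expandLDSDMA rather than the plain load/store expansion.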
+ std::optional<SIMemOpInfo> + getLDSDMAInfo(const MachineBasicBlock::iterator &MI) const; }; class SICacheControl { @@ -360,11 +367,13 @@ public: /// between memory instructions to enforce the order they become visible as /// observed by other memory instructions executing in memory scope \p Scope. /// \p IsCrossAddrSpaceOrdering indicates if the memory ordering is between - /// address spaces. Returns true iff any instructions inserted. + /// address spaces. If \p AtomicsOnly is true, only insert waits for counters + /// that are used by atomic instructions. + /// Returns true iff any instructions inserted. virtual bool insertWait(MachineBasicBlock::iterator &MI, SIAtomicScope Scope, SIAtomicAddrSpace AddrSpace, SIMemOp Op, bool IsCrossAddrSpaceOrdering, Position Pos, - AtomicOrdering Order) const = 0; + AtomicOrdering Order, bool AtomicsOnly) const = 0; /// Inserts any necessary instructions at position \p Pos relative to /// instruction \p MI to ensure any subsequent memory instructions of this @@ -388,12 +397,6 @@ public: bool IsCrossAddrSpaceOrdering, Position Pos) const = 0; - /// Inserts any necessary instructions before the barrier start instruction - /// \p MI in order to support pairing of barriers and fences. - virtual bool insertBarrierStart(MachineBasicBlock::iterator &MI) const { - return false; - }; - /// Virtual destructor to allow derivations to be deleted. virtual ~SICacheControl() = default; }; @@ -437,7 +440,7 @@ public: bool insertWait(MachineBasicBlock::iterator &MI, SIAtomicScope Scope, SIAtomicAddrSpace AddrSpace, SIMemOp Op, bool IsCrossAddrSpaceOrdering, Position Pos, - AtomicOrdering Order) const override; + AtomicOrdering Order, bool AtomicsOnly) const override; bool insertAcquire(MachineBasicBlock::iterator &MI, SIAtomicScope Scope, @@ -484,7 +487,7 @@ public: bool insertWait(MachineBasicBlock::iterator &MI, SIAtomicScope Scope, SIAtomicAddrSpace AddrSpace, SIMemOp Op, bool IsCrossAddrSpaceOrdering, Position Pos, - AtomicOrdering Order) const override; + AtomicOrdering Order, bool AtomicsOnly) const override; bool insertAcquire(MachineBasicBlock::iterator &MI, SIAtomicScope Scope, @@ -572,14 +575,10 @@ public: bool insertWait(MachineBasicBlock::iterator &MI, SIAtomicScope Scope, SIAtomicAddrSpace AddrSpace, SIMemOp Op, bool IsCrossAddrSpaceOrdering, Position Pos, - AtomicOrdering Order) const override; - - bool insertAcquire(MachineBasicBlock::iterator &MI, - SIAtomicScope Scope, - SIAtomicAddrSpace AddrSpace, - Position Pos) const override; + AtomicOrdering Order, bool AtomicsOnly) const override; - bool insertBarrierStart(MachineBasicBlock::iterator &MI) const override; + bool insertAcquire(MachineBasicBlock::iterator &MI, SIAtomicScope Scope, + SIAtomicAddrSpace AddrSpace, Position Pos) const override; }; class SIGfx11CacheControl : public SIGfx10CacheControl { @@ -629,7 +628,7 @@ public: bool insertWait(MachineBasicBlock::iterator &MI, SIAtomicScope Scope, SIAtomicAddrSpace AddrSpace, SIMemOp Op, bool IsCrossAddrSpaceOrdering, Position Pos, - AtomicOrdering Order) const override; + AtomicOrdering Order, bool AtomicsOnly) const override; bool insertAcquire(MachineBasicBlock::iterator &MI, SIAtomicScope Scope, SIAtomicAddrSpace AddrSpace, Position Pos) const override; @@ -701,6 +700,9 @@ private: /// instructions are added/deleted or \p MI is modified, false otherwise. bool expandAtomicCmpxchgOrRmw(const SIMemOpInfo &MOI, MachineBasicBlock::iterator &MI); + /// Expands LDS DMA operation \p MI. 
Returns true if instructions are
+  /// added/deleted or \p MI is modified, false otherwise.
+  bool expandLDSDMA(const SIMemOpInfo &MOI, MachineBasicBlock::iterator &MI);
 
 public:
   SIMemoryLegalizer(const MachineModuleInfo &MMI) : MMI(MMI) {};
@@ -830,6 +832,9 @@ SIAtomicAddrSpace SIMemOpAccess::toSIAtomicAddrSpace(unsigned AS) const {
     return SIAtomicAddrSpace::SCRATCH;
   if (AS == AMDGPUAS::REGION_ADDRESS)
     return SIAtomicAddrSpace::GDS;
+  if (AS == AMDGPUAS::BUFFER_FAT_POINTER || AS == AMDGPUAS::BUFFER_RESOURCE ||
+      AS == AMDGPUAS::BUFFER_STRIDED_POINTER)
+    return SIAtomicAddrSpace::GLOBAL;
 
   return SIAtomicAddrSpace::OTHER;
 }
@@ -985,6 +990,16 @@ std::optional<SIMemOpInfo> SIMemOpAccess::getAtomicCmpxchgOrRmwInfo(
   return constructFromMIWithMMO(MI);
 }
 
+std::optional<SIMemOpInfo>
+SIMemOpAccess::getLDSDMAInfo(const MachineBasicBlock::iterator &MI) const {
+  assert(MI->getDesc().TSFlags & SIInstrFlags::maybeAtomic);
+
+  if (!SIInstrInfo::isLDSDMA(*MI))
+    return std::nullopt;
+
+  return constructFromMIWithMMO(MI);
+}
+
 SICacheControl::SICacheControl(const GCNSubtarget &ST) : ST(ST) {
   TII = ST.getInstrInfo();
   IV = getIsaVersion(ST.getCPU());
@@ -1097,7 +1112,7 @@ bool SIGfx6CacheControl::enableVolatileAndOrNonTemporal(
   // Only handle load and store, not atomic read-modify-write instructions. The
   // latter use glc to indicate if the atomic returns a result and so must not
   // be used for cache control.
-  assert(MI->mayLoad() ^ MI->mayStore());
+  assert((MI->mayLoad() ^ MI->mayStore()) || SIInstrInfo::isLDSDMA(*MI));
 
   // Only update load and store, not LLVM IR atomic read-modify-write
   // instructions. The latter are always marked as volatile so cannot sensibly
@@ -1120,7 +1135,8 @@
   // observable outside the program, so no need to cause a waitcnt for LDS
   // address space operations.
   Changed |= insertWait(MI, SIAtomicScope::SYSTEM, AddrSpace, Op, false,
-                        Position::AFTER, AtomicOrdering::Unordered);
+                        Position::AFTER, AtomicOrdering::Unordered,
+                        /*AtomicsOnly=*/false);
 
   return Changed;
 }
@@ -1140,7 +1156,8 @@ bool SIGfx6CacheControl::insertWait(MachineBasicBlock::iterator &MI,
                                     SIAtomicScope Scope,
                                     SIAtomicAddrSpace AddrSpace, SIMemOp Op,
                                     bool IsCrossAddrSpaceOrdering, Position Pos,
-                                    AtomicOrdering Order) const {
+                                    AtomicOrdering Order,
+                                    bool AtomicsOnly) const {
   bool Changed = false;
 
   MachineBasicBlock &MBB = *MI->getParent();
@@ -1294,7 +1311,8 @@ bool SIGfx6CacheControl::insertRelease(MachineBasicBlock::iterator &MI,
                                        bool IsCrossAddrSpaceOrdering,
                                        Position Pos) const {
   return insertWait(MI, Scope, AddrSpace, SIMemOp::LOAD | SIMemOp::STORE,
-                    IsCrossAddrSpaceOrdering, Pos, AtomicOrdering::Release);
+                    IsCrossAddrSpaceOrdering, Pos, AtomicOrdering::Release,
+                    /*AtomicsOnly=*/false);
 }
 
 bool SIGfx7CacheControl::insertAcquire(MachineBasicBlock::iterator &MI,
@@ -1424,7 +1442,7 @@ bool SIGfx90ACacheControl::enableVolatileAndOrNonTemporal(
   // Only handle load and store, not atomic read-modify-write instructions. The
   // latter use glc to indicate if the atomic returns a result and so must not
   // be used for cache control.
-  assert(MI->mayLoad() ^ MI->mayStore());
+  assert((MI->mayLoad() ^ MI->mayStore()) || SIInstrInfo::isLDSDMA(*MI));
 
   // Only update load and store, not LLVM IR atomic read-modify-write
   // instructions. The latter are always marked as volatile so cannot sensibly
@@ -1447,7 +1465,8 @@
   // observable outside the program, so no need to cause a waitcnt for LDS
   // address space operations.
   Changed |= insertWait(MI, SIAtomicScope::SYSTEM, AddrSpace, Op, false,
-                        Position::AFTER, AtomicOrdering::Unordered);
+                        Position::AFTER, AtomicOrdering::Unordered,
+                        /*AtomicsOnly=*/false);
 
   return Changed;
 }
@@ -1467,8 +1486,8 @@ bool SIGfx90ACacheControl::insertWait(MachineBasicBlock::iterator &MI,
                                       SIAtomicScope Scope,
                                       SIAtomicAddrSpace AddrSpace, SIMemOp Op,
                                       bool IsCrossAddrSpaceOrdering,
-                                      Position Pos,
-                                      AtomicOrdering Order) const {
+                                      Position Pos, AtomicOrdering Order,
+                                      bool AtomicsOnly) const {
   if (ST.isTgSplitEnabled()) {
     // In threadgroup split mode the waves of a work-group can be executing on
     // different CUs. Therefore need to wait for global or GDS memory operations
@@ -1488,7 +1507,8 @@ bool SIGfx90ACacheControl::insertWait(MachineBasicBlock::iterator &MI,
     AddrSpace &= ~SIAtomicAddrSpace::LDS;
   }
   return SIGfx7CacheControl::insertWait(MI, Scope, AddrSpace, Op,
-                                        IsCrossAddrSpaceOrdering, Pos, Order);
+                                        IsCrossAddrSpaceOrdering, Pos, Order,
+                                        AtomicsOnly);
 }
 
 bool SIGfx90ACacheControl::insertAcquire(MachineBasicBlock::iterator &MI,
@@ -1726,7 +1746,7 @@ bool SIGfx940CacheControl::enableVolatileAndOrNonTemporal(
   // Only handle load and store, not atomic read-modify-write instructions. The
   // latter use glc to indicate if the atomic returns a result and so must not
   // be used for cache control.
-  assert(MI->mayLoad() ^ MI->mayStore());
+  assert((MI->mayLoad() ^ MI->mayStore()) || SIInstrInfo::isLDSDMA(*MI));
 
   // Only update load and store, not LLVM IR atomic read-modify-write
   // instructions. The latter are always marked as volatile so cannot sensibly
@@ -1747,7 +1767,8 @@
   // observable outside the program, so no need to cause a waitcnt for LDS
   // address space operations.
   Changed |= insertWait(MI, SIAtomicScope::SYSTEM, AddrSpace, Op, false,
-                        Position::AFTER, AtomicOrdering::Unordered);
+                        Position::AFTER, AtomicOrdering::Unordered,
+                        /*AtomicsOnly=*/false);
 
   return Changed;
 }
@@ -1904,7 +1925,8 @@ bool SIGfx940CacheControl::insertRelease(MachineBasicBlock::iterator &MI,
   // Ensure the necessary S_WAITCNT needed by any "BUFFER_WBL2" as well as other
   // S_WAITCNT needed.
   Changed |= insertWait(MI, Scope, AddrSpace, SIMemOp::LOAD | SIMemOp::STORE,
-                        IsCrossAddrSpaceOrdering, Pos, AtomicOrdering::Release);
+                        IsCrossAddrSpaceOrdering, Pos, AtomicOrdering::Release,
+                        /*AtomicsOnly=*/false);
 
   return Changed;
 }
@@ -1959,7 +1981,7 @@ bool SIGfx10CacheControl::enableVolatileAndOrNonTemporal(
   // Only handle load and store, not atomic read-modify-write instructions. The
   // latter use glc to indicate if the atomic returns a result and so must not
   // be used for cache control.
-  assert(MI->mayLoad() ^ MI->mayStore());
+  assert((MI->mayLoad() ^ MI->mayStore()) || SIInstrInfo::isLDSDMA(*MI));
 
   // Only update load and store, not LLVM IR atomic read-modify-write
   // instructions. The latter are always marked as volatile so cannot sensibly
@@ -1984,7 +2006,8 @@
   // observable outside the program, so no need to cause a waitcnt for LDS
   // address space operations.
   Changed |= insertWait(MI, SIAtomicScope::SYSTEM, AddrSpace, Op, false,
-                        Position::AFTER, AtomicOrdering::Unordered);
+                        Position::AFTER, AtomicOrdering::Unordered,
+                        /*AtomicsOnly=*/false);
 
   return Changed;
 }
@@ -2007,7 +2030,8 @@ bool SIGfx10CacheControl::insertWait(MachineBasicBlock::iterator &MI,
                                      SIAtomicScope Scope,
                                      SIAtomicAddrSpace AddrSpace, SIMemOp Op,
                                      bool IsCrossAddrSpaceOrdering,
-                                     Position Pos, AtomicOrdering Order) const {
+                                     Position Pos, AtomicOrdering Order,
+                                     bool AtomicsOnly) const {
   bool Changed = false;
 
   MachineBasicBlock &MBB = *MI->getParent();
@@ -2035,8 +2059,11 @@ bool SIGfx10CacheControl::insertWait(MachineBasicBlock::iterator &MI,
       // the WGP. Therefore need to wait for operations to complete to ensure
      // they are visible to waves in the other CU as the L0 is per CU.
      // Otherwise in CU mode and all waves of a work-group are on the same CU
-      // which shares the same L0.
+      // which shares the same L0. Note that we still need to wait when
+      // performing a release in this mode to respect the transitivity of
+      // happens-before, e.g. other waves of the workgroup must be able to
+      // release the memory from another wave at a wider scope.
-      if (!ST.isCuModeEnabled()) {
+      if (!ST.isCuModeEnabled() || isReleaseOrStronger(Order)) {
         if ((Op & SIMemOp::LOAD) != SIMemOp::NONE)
           VMCnt |= true;
         if ((Op & SIMemOp::STORE) != SIMemOp::NONE)
@@ -2191,22 +2218,6 @@ bool SIGfx10CacheControl::insertAcquire(MachineBasicBlock::iterator &MI,
   return Changed;
 }
 
-bool SIGfx10CacheControl::insertBarrierStart(
-    MachineBasicBlock::iterator &MI) const {
-  // We need to wait on vm_vsrc so barriers can pair with fences in GFX10+ CU
-  // mode. This is because a CU mode release fence does not emit any wait, which
-  // is fine when only dealing with vmem, but isn't sufficient in the presence
-  // of barriers which do not go through vmem.
-  // GFX12.5 does not require this additional wait.
-  if (!ST.isCuModeEnabled() || ST.hasGFX1250Insts())
-    return false;
-
-  BuildMI(*MI->getParent(), MI, MI->getDebugLoc(),
-          TII->get(AMDGPU::S_WAITCNT_DEPCTR))
-      .addImm(AMDGPU::DepCtr::encodeFieldVmVsrc(0));
-  return true;
-}
-
 bool SIGfx11CacheControl::enableLoadCacheBypass(
     const MachineBasicBlock::iterator &MI, SIAtomicScope Scope,
     SIAtomicAddrSpace AddrSpace) const {
@@ -2255,7 +2266,7 @@ bool SIGfx11CacheControl::enableVolatileAndOrNonTemporal(
   // Only handle load and store, not atomic read-modify-write instructions. The
   // latter use glc to indicate if the atomic returns a result and so must not
   // be used for cache control.
-  assert(MI->mayLoad() ^ MI->mayStore());
+  assert((MI->mayLoad() ^ MI->mayStore()) || SIInstrInfo::isLDSDMA(*MI));
 
   // Only update load and store, not LLVM IR atomic read-modify-write
   // instructions. The latter are always marked as volatile so cannot sensibly
@@ -2281,7 +2292,8 @@
   // observable outside the program, so no need to cause a waitcnt for LDS
   // address space operations.
Changed |= insertWait(MI, SIAtomicScope::SYSTEM, AddrSpace, Op, false, - Position::AFTER, AtomicOrdering::Unordered); + Position::AFTER, AtomicOrdering::Unordered, + /*AtomicsOnly=*/false); return Changed; } @@ -2354,7 +2366,8 @@ bool SIGfx12CacheControl::insertWait(MachineBasicBlock::iterator &MI, SIAtomicScope Scope, SIAtomicAddrSpace AddrSpace, SIMemOp Op, bool IsCrossAddrSpaceOrdering, - Position Pos, AtomicOrdering Order) const { + Position Pos, AtomicOrdering Order, + bool AtomicsOnly) const { bool Changed = false; MachineBasicBlock &MBB = *MI->getParent(); @@ -2383,15 +2396,20 @@ bool SIGfx12CacheControl::insertWait(MachineBasicBlock::iterator &MI, // In WGP mode the waves of a work-group can be executing on either CU // of the WGP. Therefore need to wait for operations to complete to // ensure they are visible to waves in the other CU as the L0 is per CU. + // // Otherwise in CU mode and all waves of a work-group are on the same CU - // which shares the same L0. + // which shares the same L0. Note that we still need to wait when + // performing a release in this mode to respect the transitivity of + // happens-before, e.g. other waves of the workgroup must be able to + // release the memory from another wave at a wider scope. // // GFX12.5: // CU$ has two ports. To ensure operations are visible at the workgroup // level, we need to ensure all operations in this port have completed // so the other SIMDs in the WG can see them. There is no ordering // guarantee between the ports. - if (!ST.isCuModeEnabled() || ST.hasGFX1250Insts()) { + if (!ST.isCuModeEnabled() || ST.hasGFX1250Insts() || + isReleaseOrStronger(Order)) { if ((Op & SIMemOp::LOAD) != SIMemOp::NONE) LOADCnt |= true; if ((Op & SIMemOp::STORE) != SIMemOp::NONE) @@ -2444,7 +2462,7 @@ bool SIGfx12CacheControl::insertWait(MachineBasicBlock::iterator &MI, // // This also applies to fences. Fences cannot pair with an instruction // tracked with bvh/samplecnt as we don't have any atomics that do that. - if (Order != AtomicOrdering::Acquire && ST.hasImageInsts()) { + if (!AtomicsOnly && ST.hasImageInsts()) { BuildMI(MBB, MI, DL, TII->get(AMDGPU::S_WAIT_BVHCNT_soft)).addImm(0); BuildMI(MBB, MI, DL, TII->get(AMDGPU::S_WAIT_SAMPLECNT_soft)).addImm(0); } @@ -2587,7 +2605,8 @@ bool SIGfx12CacheControl::insertRelease(MachineBasicBlock::iterator &MI, // complete, whether we inserted a WB or not. If we inserted a WB (storecnt), // we of course need to wait for that as well. Changed |= insertWait(MI, Scope, AddrSpace, SIMemOp::LOAD | SIMemOp::STORE, - IsCrossAddrSpaceOrdering, Pos, AtomicOrdering::Release); + IsCrossAddrSpaceOrdering, Pos, AtomicOrdering::Release, + /*AtomicsOnly=*/false); return Changed; } @@ -2597,7 +2616,7 @@ bool SIGfx12CacheControl::enableVolatileAndOrNonTemporal( bool IsVolatile, bool IsNonTemporal, bool IsLastUse = false) const { // Only handle load and store, not atomic read-modify-write instructions. - assert(MI->mayLoad() ^ MI->mayStore()); + assert((MI->mayLoad() ^ MI->mayStore()) || SIInstrInfo::isLDSDMA(*MI)); // Only update load and store, not LLVM IR atomic read-modify-write // instructions. The latter are always marked as volatile so cannot sensibly @@ -2624,7 +2643,8 @@ bool SIGfx12CacheControl::enableVolatileAndOrNonTemporal( // observable outside the program, so no need to cause a waitcnt for LDS // address space operations. 
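+  // Volatile accesses must become visible to all observers, so request waits
+  // on the full set of counters (AtomicsOnly=false), not just the counters
+  // that atomic instructions can pair with.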
Changed |= insertWait(MI, SIAtomicScope::SYSTEM, AddrSpace, Op, false, - Position::AFTER, AtomicOrdering::Unordered); + Position::AFTER, AtomicOrdering::Unordered, + /*AtomicsOnly=*/false); } return Changed; @@ -2748,13 +2768,15 @@ bool SIMemoryLegalizer::expandLoad(const SIMemOpInfo &MOI, Changed |= CC->insertWait(MI, MOI.getScope(), MOI.getOrderingAddrSpace(), SIMemOp::LOAD | SIMemOp::STORE, MOI.getIsCrossAddressSpaceOrdering(), - Position::BEFORE, Order); + Position::BEFORE, Order, /*AtomicsOnly=*/false); if (Order == AtomicOrdering::Acquire || Order == AtomicOrdering::SequentiallyConsistent) { - Changed |= CC->insertWait( - MI, MOI.getScope(), MOI.getInstrAddrSpace(), SIMemOp::LOAD, - MOI.getIsCrossAddressSpaceOrdering(), Position::AFTER, Order); + // The wait below only needs to wait on the prior atomic. + Changed |= + CC->insertWait(MI, MOI.getScope(), MOI.getInstrAddrSpace(), + SIMemOp::LOAD, MOI.getIsCrossAddressSpaceOrdering(), + Position::AFTER, Order, /*AtomicsOnly=*/true); Changed |= CC->insertAcquire(MI, MOI.getScope(), MOI.getOrderingAddrSpace(), Position::AFTER); @@ -2830,9 +2852,11 @@ bool SIMemoryLegalizer::expandAtomicFence(const SIMemOpInfo &MOI, if (MOI.isAtomic()) { const AtomicOrdering Order = MOI.getOrdering(); if (Order == AtomicOrdering::Acquire) { - Changed |= CC->insertWait( - MI, MOI.getScope(), OrderingAddrSpace, SIMemOp::LOAD | SIMemOp::STORE, - MOI.getIsCrossAddressSpaceOrdering(), Position::BEFORE, Order); + // Acquire fences only need to wait on the previous atomic they pair with. + Changed |= CC->insertWait(MI, MOI.getScope(), OrderingAddrSpace, + SIMemOp::LOAD | SIMemOp::STORE, + MOI.getIsCrossAddressSpaceOrdering(), + Position::BEFORE, Order, /*AtomicsOnly=*/true); } if (Order == AtomicOrdering::Release || @@ -2897,10 +2921,12 @@ bool SIMemoryLegalizer::expandAtomicCmpxchgOrRmw(const SIMemOpInfo &MOI, Order == AtomicOrdering::SequentiallyConsistent || MOI.getFailureOrdering() == AtomicOrdering::Acquire || MOI.getFailureOrdering() == AtomicOrdering::SequentiallyConsistent) { - Changed |= CC->insertWait( - MI, MOI.getScope(), MOI.getInstrAddrSpace(), - isAtomicRet(*MI) ? SIMemOp::LOAD : SIMemOp::STORE, - MOI.getIsCrossAddressSpaceOrdering(), Position::AFTER, Order); + // Only wait on the previous atomic. + Changed |= + CC->insertWait(MI, MOI.getScope(), MOI.getInstrAddrSpace(), + isAtomicRet(*MI) ? SIMemOp::LOAD : SIMemOp::STORE, + MOI.getIsCrossAddressSpaceOrdering(), Position::AFTER, + Order, /*AtomicsOnly=*/true); Changed |= CC->insertAcquire(MI, MOI.getScope(), MOI.getOrderingAddrSpace(), Position::AFTER); @@ -2913,6 +2939,23 @@ bool SIMemoryLegalizer::expandAtomicCmpxchgOrRmw(const SIMemOpInfo &MOI, return Changed; } +bool SIMemoryLegalizer::expandLDSDMA(const SIMemOpInfo &MOI, + MachineBasicBlock::iterator &MI) { + assert(MI->mayLoad() && MI->mayStore()); + + // The volatility or nontemporal-ness of the operation is a + // function of the global memory, not the LDS. + SIMemOp OpKind = + SIInstrInfo::mayWriteLDSThroughDMA(*MI) ? SIMemOp::LOAD : SIMemOp::STORE; + + // Handle volatile and/or nontemporal markers on direct-to-LDS loads and + // stores. The operation is treated as a volatile/nontemporal store + // to its second argument. 
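+  // (mayWriteLDSThroughDMA means the DMA writes LDS, i.e. the global side is
+  // read, so it is classified as a LOAD above; otherwise the global side is
+  // written and it is classified as a STORE.)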
+ return CC->enableVolatileAndOrNonTemporal( + MI, MOI.getInstrAddrSpace(), OpKind, MOI.isVolatile(), + MOI.isNonTemporal(), MOI.isLastUse()); +} + bool SIMemoryLegalizerLegacy::runOnMachineFunction(MachineFunction &MF) { const MachineModuleInfo &MMI = getAnalysis<MachineModuleInfoWrapperPass>().getMMI(); @@ -2956,22 +2999,20 @@ bool SIMemoryLegalizer::run(MachineFunction &MF) { MI = II->getIterator(); } - if (ST.getInstrInfo()->isBarrierStart(MI->getOpcode())) { - Changed |= CC->insertBarrierStart(MI); - continue; - } - if (!(MI->getDesc().TSFlags & SIInstrFlags::maybeAtomic)) continue; - if (const auto &MOI = MOA.getLoadInfo(MI)) + if (const auto &MOI = MOA.getLoadInfo(MI)) { Changed |= expandLoad(*MOI, MI); - else if (const auto &MOI = MOA.getStoreInfo(MI)) { + } else if (const auto &MOI = MOA.getStoreInfo(MI)) { Changed |= expandStore(*MOI, MI); - } else if (const auto &MOI = MOA.getAtomicFenceInfo(MI)) + } else if (const auto &MOI = MOA.getLDSDMAInfo(MI)) { + Changed |= expandLDSDMA(*MOI, MI); + } else if (const auto &MOI = MOA.getAtomicFenceInfo(MI)) { Changed |= expandAtomicFence(*MOI, MI); - else if (const auto &MOI = MOA.getAtomicCmpxchgOrRmwInfo(MI)) + } else if (const auto &MOI = MOA.getAtomicCmpxchgOrRmwInfo(MI)) { Changed |= expandAtomicCmpxchgOrRmw(*MOI, MI); + } } } diff --git a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp index 96ee69c..406f4c1 100644 --- a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp +++ b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp @@ -882,7 +882,7 @@ static bool producesFalseLanesZero(MachineInstr &MI, continue; // Skip the lr predicate reg int PIdx = llvm::findFirstVPTPredOperandIdx(MI); - if (PIdx != -1 && (int)MO.getOperandNo() == PIdx + 2) + if (PIdx != -1 && MO.getOperandNo() == PIdx + ARM::SUBOP_vpred_n_tp_reg) continue; // Check that this instruction will produce zeros in its false lanes: diff --git a/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp b/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp index ce59ae0..2cd5f02 100644 --- a/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp +++ b/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp @@ -407,9 +407,9 @@ Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) { // Potentially optimising the addressing modes as we do so. 
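+  // The alignment is now carried as a parameter attribute on the pointer
+  // argument rather than as an explicit constant operand, which moves the
+  // mask and passthru down to argument indices 1 and 2.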
auto *Ty = cast<FixedVectorType>(I->getType()); Value *Ptr = I->getArgOperand(0); - Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue(); - Value *Mask = I->getArgOperand(2); - Value *PassThru = I->getArgOperand(3); + Align Alignment = I->getParamAlign(0).valueOrOne(); + Value *Mask = I->getArgOperand(1); + Value *PassThru = I->getArgOperand(2); if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), Alignment)) @@ -458,7 +458,7 @@ Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase( if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) // Can't build an intrinsic for this return nullptr; - Value *Mask = I->getArgOperand(2); + Value *Mask = I->getArgOperand(1); if (match(Mask, m_One())) return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base, {Ty, Ptr->getType()}, @@ -479,7 +479,7 @@ Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB( if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) // Can't build an intrinsic for this return nullptr; - Value *Mask = I->getArgOperand(2); + Value *Mask = I->getArgOperand(1); if (match(Mask, m_One())) return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb, {Ty, Ptr->getType()}, @@ -552,7 +552,7 @@ Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset( return nullptr; Root = Extend; - Value *Mask = I->getArgOperand(2); + Value *Mask = I->getArgOperand(1); Instruction *Load = nullptr; if (!match(Mask, m_One())) Load = Builder.CreateIntrinsic( @@ -584,7 +584,7 @@ Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) { // Potentially optimising the addressing modes as we do so. Value *Input = I->getArgOperand(0); Value *Ptr = I->getArgOperand(1); - Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue(); + Align Alignment = I->getParamAlign(1).valueOrOne(); auto *Ty = cast<FixedVectorType>(Input->getType()); if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), @@ -622,7 +622,7 @@ Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase( // Can't build an intrinsic for this return nullptr; } - Value *Mask = I->getArgOperand(3); + Value *Mask = I->getArgOperand(2); // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask) LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n"); if (match(Mask, m_One())) @@ -646,7 +646,7 @@ Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB( if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) // Can't build an intrinsic for this return nullptr; - Value *Mask = I->getArgOperand(3); + Value *Mask = I->getArgOperand(2); if (match(Mask, m_One())) return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb, {Ptr->getType(), Input->getType()}, @@ -662,7 +662,7 @@ Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset( IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) { using namespace PatternMatch; Value *Input = I->getArgOperand(0); - Value *Mask = I->getArgOperand(3); + Value *Mask = I->getArgOperand(2); Type *InputTy = Input->getType(); Type *MemoryTy = InputTy; diff --git a/llvm/lib/Target/ARM/MVETPAndVPTOptimisationsPass.cpp b/llvm/lib/Target/ARM/MVETPAndVPTOptimisationsPass.cpp index 5eeb4fe..413e844 100644 --- a/llvm/lib/Target/ARM/MVETPAndVPTOptimisationsPass.cpp +++ b/llvm/lib/Target/ARM/MVETPAndVPTOptimisationsPass.cpp @@ -534,7 +534,7 @@ bool MVETPAndVPTOptimisations::ConvertTailPredLoop(MachineLoop *ML, Register LR = 
LoopPhi->getOperand(0).getReg();
     for (MachineInstr *MI : MVEInstrs) {
       int Idx = findFirstVPTPredOperandIdx(*MI);
-      MI->getOperand(Idx + 2).setReg(LR);
+      MI->getOperand(Idx + ARM::SUBOP_vpred_n_tp_reg).setReg(LR);
     }
   }
 
diff --git a/llvm/lib/Target/BPF/BTFDebug.cpp b/llvm/lib/Target/BPF/BTFDebug.cpp
index ba4b489..9b5fc9d 100644
--- a/llvm/lib/Target/BPF/BTFDebug.cpp
+++ b/llvm/lib/Target/BPF/BTFDebug.cpp
@@ -14,6 +14,7 @@
 #include "BPF.h"
 #include "BPFCORE.h"
 #include "MCTargetDesc/BPFMCTargetDesc.h"
+#include "llvm/BinaryFormat/Dwarf.h"
 #include "llvm/BinaryFormat/ELF.h"
 #include "llvm/CodeGen/AsmPrinter.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
@@ -23,6 +24,7 @@
 #include "llvm/MC/MCObjectFileInfo.h"
 #include "llvm/MC/MCSectionELF.h"
 #include "llvm/MC/MCStreamer.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/LineIterator.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Target/TargetLoweringObjectFile.h"
@@ -301,21 +303,59 @@ void BTFTypeStruct::completeType(BTFDebug &BDebug) {
 
   BTFType.NameOff = BDebug.addString(STy->getName());
 
+  if (STy->getTag() == dwarf::DW_TAG_variant_part) {
+    // Variant parts might have a discriminator, which has its own memory
+    // location, and variants, which share the memory location afterwards.
+    // LLVM DI doesn't consider the discriminator an element and instead
+    // keeps it as a separate reference.
+    // To keep BTF simple, represent the structure as a union with the
+    // discriminator as the first element.
+    // The offsets inside variant types are already handled correctly in the
+    // DI.
+    const auto *DTy = STy->getDiscriminator();
+    if (DTy) {
+      struct BTF::BTFMember Discriminator;
+
+      Discriminator.NameOff = BDebug.addString(DTy->getName());
+      Discriminator.Offset = DTy->getOffsetInBits();
+      const auto *BaseTy = DTy->getBaseType();
+      Discriminator.Type = BDebug.getTypeId(BaseTy);
+
+      Members.push_back(Discriminator);
+    }
+  }
+
   // Add struct/union members.
   const DINodeArray Elements = STy->getElements();
   for (const auto *Element : Elements) {
     struct BTF::BTFMember BTFMember;
-    const auto *DDTy = cast<DIDerivedType>(Element);
 
-    BTFMember.NameOff = BDebug.addString(DDTy->getName());
-    if (HasBitField) {
-      uint8_t BitFieldSize = DDTy->isBitField() ? DDTy->getSizeInBits() : 0;
-      BTFMember.Offset = BitFieldSize << 24 | DDTy->getOffsetInBits();
-    } else {
-      BTFMember.Offset = DDTy->getOffsetInBits();
+    switch (Element->getTag()) {
+    case dwarf::DW_TAG_member: {
+      const auto *DDTy = cast<DIDerivedType>(Element);
+
+      BTFMember.NameOff = BDebug.addString(DDTy->getName());
+      if (HasBitField) {
+        uint8_t BitFieldSize = DDTy->isBitField() ? DDTy->getSizeInBits() : 0;
+        BTFMember.Offset = BitFieldSize << 24 | DDTy->getOffsetInBits();
+      } else {
+        BTFMember.Offset = DDTy->getOffsetInBits();
+      }
+      const auto *BaseTy = tryRemoveAtomicType(DDTy->getBaseType());
+      BTFMember.Type = BDebug.getTypeId(BaseTy);
+      break;
+    }
+    case dwarf::DW_TAG_variant_part: {
+      const auto *DCTy = cast<DICompositeType>(Element);
+
+      BTFMember.NameOff = BDebug.addString(DCTy->getName());
+      BTFMember.Offset = DCTy->getOffsetInBits();
+      BTFMember.Type = BDebug.getTypeId(DCTy);
+      break;
+    }
+    default:
+      llvm_unreachable("Unexpected DI tag of a struct/union element");
     }
-    const auto *BaseTy = tryRemoveAtomicType(DDTy->getBaseType());
-    BTFMember.Type = BDebug.getTypeId(BaseTy);
     Members.push_back(BTFMember);
   }
 }
@@ -672,16 +712,28 @@ void BTFDebug::visitStructType(const DICompositeType *CTy, bool IsStruct,
                                uint32_t &TypeId) {
   const DINodeArray Elements = CTy->getElements();
   uint32_t VLen = Elements.size();
+  // Variant parts might have a discriminator. LLVM DI doesn't consider it an
+  // element and instead keeps it as a separate reference. But we represent
+  // it as an element in BTF.
+  if (CTy->getTag() == dwarf::DW_TAG_variant_part) {
+    const auto *DTy = CTy->getDiscriminator();
+    if (DTy) {
+      visitTypeEntry(DTy);
+      VLen++;
+    }
+  }
   if (VLen > BTF::MAX_VLEN)
     return;
 
   // Check whether we have any bitfield members or not
   bool HasBitField = false;
   for (const auto *Element : Elements) {
-    auto E = cast<DIDerivedType>(Element);
-    if (E->isBitField()) {
-      HasBitField = true;
-      break;
+    if (Element->getTag() == dwarf::DW_TAG_member) {
+      auto E = cast<DIDerivedType>(Element);
+      if (E->isBitField()) {
+        HasBitField = true;
+        break;
+      }
     }
   }
 
@@ -696,9 +748,22 @@ void BTFDebug::visitStructType(const DICompositeType *CTy, bool IsStruct,
   // Visit all struct members.
   int FieldNo = 0;
   for (const auto *Element : Elements) {
-    const auto Elem = cast<DIDerivedType>(Element);
-    visitTypeEntry(Elem);
-    processDeclAnnotations(Elem->getAnnotations(), TypeId, FieldNo);
+    switch (Element->getTag()) {
+    case dwarf::DW_TAG_member: {
+      const auto Elem = cast<DIDerivedType>(Element);
+      visitTypeEntry(Elem);
+      processDeclAnnotations(Elem->getAnnotations(), TypeId, FieldNo);
+      break;
+    }
+    case dwarf::DW_TAG_variant_part: {
+      const auto Elem = cast<DICompositeType>(Element);
+      visitTypeEntry(Elem);
+      processDeclAnnotations(Elem->getAnnotations(), TypeId, FieldNo);
+      break;
+    }
+    default:
+      llvm_unreachable("Unexpected DI tag of a struct/union element");
+    }
     FieldNo++;
   }
 }
@@ -781,16 +846,25 @@ void BTFDebug::visitFwdDeclType(const DICompositeType *CTy, bool IsUnion,
 void BTFDebug::visitCompositeType(const DICompositeType *CTy,
                                   uint32_t &TypeId) {
   auto Tag = CTy->getTag();
-  if (Tag == dwarf::DW_TAG_structure_type || Tag == dwarf::DW_TAG_union_type) {
+  switch (Tag) {
+  case dwarf::DW_TAG_structure_type:
+  case dwarf::DW_TAG_union_type:
+  case dwarf::DW_TAG_variant_part:
     // Handle forward declaration differently as it does not have members.
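+    // DW_TAG_variant_part shares this path: visitStructType is called with
+    // IsStruct == false for it, so it is emitted as a BTF union (see
+    // BTFTypeStruct::completeType above).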
    if (CTy->isForwardDecl())
       visitFwdDeclType(CTy, Tag == dwarf::DW_TAG_union_type, TypeId);
     else
       visitStructType(CTy, Tag == dwarf::DW_TAG_structure_type, TypeId);
-  } else if (Tag == dwarf::DW_TAG_array_type)
+    break;
+  case dwarf::DW_TAG_array_type:
     visitArrayType(CTy, TypeId);
-  else if (Tag == dwarf::DW_TAG_enumeration_type)
+    break;
+  case dwarf::DW_TAG_enumeration_type:
     visitEnumType(CTy, TypeId);
+    break;
+  default:
+    llvm_unreachable("Unexpected DI tag of a composite type");
+  }
 }
 
 bool BTFDebug::IsForwardDeclCandidate(const DIType *Base) {
diff --git a/llvm/lib/Target/Hexagon/CMakeLists.txt b/llvm/lib/Target/Hexagon/CMakeLists.txt
index d758260..1a5f096 100644
--- a/llvm/lib/Target/Hexagon/CMakeLists.txt
+++ b/llvm/lib/Target/Hexagon/CMakeLists.txt
@@ -54,6 +54,7 @@ add_llvm_target(HexagonCodeGen
   HexagonOptAddrMode.cpp
   HexagonOptimizeSZextends.cpp
   HexagonPeephole.cpp
+  HexagonQFPOptimizer.cpp
   HexagonRDFOpt.cpp
   HexagonRegisterInfo.cpp
   HexagonSelectionDAGInfo.cpp
diff --git a/llvm/lib/Target/Hexagon/Hexagon.h b/llvm/lib/Target/Hexagon/Hexagon.h
index 109aba5..422ab20 100644
--- a/llvm/lib/Target/Hexagon/Hexagon.h
+++ b/llvm/lib/Target/Hexagon/Hexagon.h
@@ -67,6 +67,8 @@ void initializeHexagonPeepholePass(PassRegistry &);
 void initializeHexagonSplitConst32AndConst64Pass(PassRegistry &);
 void initializeHexagonVectorPrintPass(PassRegistry &);
 
+void initializeHexagonQFPOptimizerPass(PassRegistry &);
+
 Pass *createHexagonLoopIdiomPass();
 Pass *createHexagonVectorLoopCarriedReuseLegacyPass();
 
@@ -112,6 +114,7 @@ FunctionPass *createHexagonVectorCombineLegacyPass();
 FunctionPass *createHexagonVectorPrint();
 FunctionPass *createHexagonVExtract();
 FunctionPass *createHexagonExpandCondsets();
+FunctionPass *createHexagonQFPOptimizer();
 
 } // end namespace llvm;
 
diff --git a/llvm/lib/Target/Hexagon/HexagonGenInsert.cpp b/llvm/lib/Target/Hexagon/HexagonGenInsert.cpp
index 4ddbe7a..ff876f6 100644
--- a/llvm/lib/Target/Hexagon/HexagonGenInsert.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonGenInsert.cpp
@@ -920,6 +920,10 @@ void HexagonGenInsert::collectInBlock(MachineBasicBlock *B,
   // successors have been processed.
   RegisterSet BlockDefs, InsDefs;
   for (MachineInstr &MI : *B) {
+    // Stop if the map size is too large.
+    if (IFMap.size() >= MaxIFMSize)
+      break;
+
     InsDefs.clear();
     getInstrDefs(&MI, InsDefs);
     // Leave those alone. They are more transparent than "insert".
@@ -942,8 +946,8 @@ void HexagonGenInsert::collectInBlock(MachineBasicBlock *B,
       findRecordInsertForms(VR, AVs);
 
       // Stop if the map size is too large.
-      if (IFMap.size() > MaxIFMSize)
-        return;
+      if (IFMap.size() >= MaxIFMSize)
+        break;
     }
 }
 
diff --git a/llvm/lib/Target/Hexagon/HexagonQFPOptimizer.cpp b/llvm/lib/Target/Hexagon/HexagonQFPOptimizer.cpp
new file mode 100644
index 0000000..479ac90
--- /dev/null
+++ b/llvm/lib/Target/Hexagon/HexagonQFPOptimizer.cpp
@@ -0,0 +1,334 @@
+//===- HexagonQFPOptimizer.cpp - Qualcomm-FP to IEEE-FP conversion optimizer -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Basic infrastructure for optimizing intermediate conversion instructions
+// generated while performing vector floating point operations.
+// Currently run at the start of code generation for Hexagon; it cleans up
+// redundant conversion instructions and replaces each use of a conversion
+// with the appropriate machine operand. Liveness is preserved by this pass.
+//
+// @note: The redundant conversion instructions are not eliminated in this
+// pass. We only replace the uses of conversion instructions with the
+// appropriate QFP instruction, and leave the removal of the now-dead
+// conversion instructions to the Dead Instruction Elimination pass.
+//
+// Brief overview of how this QFP optimizer works.
+// It iterates over each instruction and checks whether it belongs to the
+// Hexagon HVX floating-point arithmetic category (Add, Sub, Mul). It then
+// finds the unique definition of each machine operand of the instruction.
+//
+// Example:
+// Let MachineInstr *MI be an HVX vadd instruction:
+// MI -> $v0 = V6_vadd_sf $v1, $v2
+// MachineInstr *DefMI1 = MRI->getVRegDef(MI->getOperand(1).getReg());
+// MachineInstr *DefMI2 = MRI->getVRegDef(MI->getOperand(2).getReg());
+//
+// Here DefMI1 and DefMI2 give the unique definitions of MI's operands
+// ($v1 and $v2 respectively).
+//
+// If neither definition is a conversion instruction (V6_vconv_sf_qf32,
+// V6_vconv_hf_qf16), the pass skips the current instruction and moves on to
+// the next one.
+//
+// If exactly one of the definitions is a conversion instruction, the pass
+// replaces the arithmetic instruction with its corresponding mix variant.
+// In the above example, if $v1 is defined by a conversion
+// DefMI1 -> $v1 = V6_vconv_sf_qf32 $v3
+// After transformation:
+// MI -> $v0 = V6_vadd_qf32_mix $v3, $v2 ($v1 is replaced with $v3)
+//
+// If both definitions are conversion instructions, the instruction is
+// replaced with its qf variant.
+// In the above example, if $v1 and $v2 are defined by conversions
+// DefMI1 -> $v1 = V6_vconv_sf_qf32 $v3
+// DefMI2 -> $v2 = V6_vconv_sf_qf32 $v4
+// After transformation:
+// MI -> $v0 = V6_vadd_qf32 $v3, $v4 ($v1 is replaced with $v3, $v2 with $v4)
+//
+// Currently this pass does not handle the case where the definitions are PHI
+// instructions.
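+//
+// Note (derived from QFPInstMap below): V6_vmpy_qf32_sf has no mix variant;
+// its map entry points directly at V6_vmpy_qf32, so it is only rewritten when
+// both operand definitions are conversions, e.g.
+// DefMI1 -> $v1 = V6_vconv_sf_qf32 $v3
+// DefMI2 -> $v2 = V6_vconv_sf_qf32 $v4
+// MI     -> $v0 = V6_vmpy_qf32_sf $v1, $v2
+// After transformation:
+// MI     -> $v0 = V6_vmpy_qf32 $v3, $v4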
+// +//===----------------------------------------------------------------------===// +#include <unordered_set> +#define HEXAGON_QFP_OPTIMIZER "QFP optimizer pass" + +#include "Hexagon.h" +#include "HexagonInstrInfo.h" +#include "HexagonSubtarget.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/CodeGen/MachineBasicBlock.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineInstr.h" +#include "llvm/CodeGen/MachineOperand.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <map> +#include <vector> + +#define DEBUG_TYPE "hexagon-qfp-optimizer" + +using namespace llvm; + +cl::opt<bool> + DisableQFOptimizer("disable-qfp-opt", cl::init(false), + cl::desc("Disable optimization of Qfloat operations.")); + +namespace { +const std::map<unsigned short, unsigned short> QFPInstMap{ + {Hexagon::V6_vadd_hf, Hexagon::V6_vadd_qf16_mix}, + {Hexagon::V6_vadd_qf16_mix, Hexagon::V6_vadd_qf16}, + {Hexagon::V6_vadd_sf, Hexagon::V6_vadd_qf32_mix}, + {Hexagon::V6_vadd_qf32_mix, Hexagon::V6_vadd_qf32}, + {Hexagon::V6_vsub_hf, Hexagon::V6_vsub_qf16_mix}, + {Hexagon::V6_vsub_qf16_mix, Hexagon::V6_vsub_qf16}, + {Hexagon::V6_vsub_sf, Hexagon::V6_vsub_qf32_mix}, + {Hexagon::V6_vsub_qf32_mix, Hexagon::V6_vsub_qf32}, + {Hexagon::V6_vmpy_qf16_hf, Hexagon::V6_vmpy_qf16_mix_hf}, + {Hexagon::V6_vmpy_qf16_mix_hf, Hexagon::V6_vmpy_qf16}, + {Hexagon::V6_vmpy_qf32_hf, Hexagon::V6_vmpy_qf32_mix_hf}, + {Hexagon::V6_vmpy_qf32_mix_hf, Hexagon::V6_vmpy_qf32_qf16}, + {Hexagon::V6_vmpy_qf32_sf, Hexagon::V6_vmpy_qf32}}; +} // namespace + +namespace llvm { + +FunctionPass *createHexagonQFPOptimizer(); +void initializeHexagonQFPOptimizerPass(PassRegistry &); + +} // namespace llvm + +namespace { + +struct HexagonQFPOptimizer : public MachineFunctionPass { +public: + static char ID; + + HexagonQFPOptimizer() : MachineFunctionPass(ID) {} + + bool runOnMachineFunction(MachineFunction &MF) override; + + bool optimizeQfp(MachineInstr *MI, MachineBasicBlock *MBB); + + StringRef getPassName() const override { return HEXAGON_QFP_OPTIMIZER; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + MachineFunctionPass::getAnalysisUsage(AU); + } + +private: + const HexagonSubtarget *HST = nullptr; + const HexagonInstrInfo *HII = nullptr; + const MachineRegisterInfo *MRI = nullptr; +}; + +char HexagonQFPOptimizer::ID = 0; +} // namespace + +INITIALIZE_PASS(HexagonQFPOptimizer, "hexagon-qfp-optimizer", + HEXAGON_QFP_OPTIMIZER, false, false) + +FunctionPass *llvm::createHexagonQFPOptimizer() { + return new HexagonQFPOptimizer(); +} + +bool HexagonQFPOptimizer::optimizeQfp(MachineInstr *MI, + MachineBasicBlock *MBB) { + + // Early exit: + // - if instruction is invalid or has too few operands (QFP ops need 2 sources + // + 1 dest), + // - or does not have a transformation mapping. 
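+  // QFPInstMap is chained: an IEEE opcode maps to its mix variant, and a mix
+  // variant maps to the pure qf variant, so looking an opcode up twice (done
+  // below when both operands are conversions) walks sf/hf -> mix -> qf.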
+ if (MI->getNumOperands() < 3) + return false; + auto It = QFPInstMap.find(MI->getOpcode()); + if (It == QFPInstMap.end()) + return false; + unsigned short InstTy = It->second; + + unsigned Op0F = 0; + unsigned Op1F = 0; + // Get the reaching defs of MI, DefMI1 and DefMI2 + MachineInstr *DefMI1 = nullptr; + MachineInstr *DefMI2 = nullptr; + + if (MI->getOperand(1).isReg()) + DefMI1 = MRI->getVRegDef(MI->getOperand(1).getReg()); + if (MI->getOperand(2).isReg()) + DefMI2 = MRI->getVRegDef(MI->getOperand(2).getReg()); + if (!DefMI1 || !DefMI2) + return false; + + MachineOperand &Res = MI->getOperand(0); + MachineInstr *Inst1 = nullptr; + MachineInstr *Inst2 = nullptr; + LLVM_DEBUG(dbgs() << "\n[Reaching Defs of operands]: "; DefMI1->dump(); + DefMI2->dump()); + + // Get the reaching defs of DefMI + if (DefMI1->getNumOperands() > 1 && DefMI1->getOperand(1).isReg() && + DefMI1->getOperand(1).getReg().isVirtual()) + Inst1 = MRI->getVRegDef(DefMI1->getOperand(1).getReg()); + + if (DefMI2->getNumOperands() > 1 && DefMI2->getOperand(1).isReg() && + DefMI2->getOperand(1).getReg().isVirtual()) + Inst2 = MRI->getVRegDef(DefMI2->getOperand(1).getReg()); + + unsigned Def1OP = DefMI1->getOpcode(); + unsigned Def2OP = DefMI2->getOpcode(); + + MachineInstrBuilder MIB; + // Case 1: Both reaching defs of MI are qf to sf/hf conversions + if ((Def1OP == Hexagon::V6_vconv_sf_qf32 && + Def2OP == Hexagon::V6_vconv_sf_qf32) || + (Def1OP == Hexagon::V6_vconv_hf_qf16 && + Def2OP == Hexagon::V6_vconv_hf_qf16)) { + + // If the reaching defs of DefMI are W register type, we return + if ((Inst1 && Inst1->getNumOperands() > 0 && Inst1->getOperand(0).isReg() && + MRI->getRegClass(Inst1->getOperand(0).getReg()) == + &Hexagon::HvxWRRegClass) || + (Inst2 && Inst2->getNumOperands() > 0 && Inst2->getOperand(0).isReg() && + MRI->getRegClass(Inst2->getOperand(0).getReg()) == + &Hexagon::HvxWRRegClass)) + return false; + + // Analyze the use operands of the conversion to get their KILL status + MachineOperand &Src1 = DefMI1->getOperand(1); + MachineOperand &Src2 = DefMI2->getOperand(1); + + Op0F = getKillRegState(Src1.isKill()); + Src1.setIsKill(false); + + Op1F = getKillRegState(Src2.isKill()); + Src2.setIsKill(false); + + if (MI->getOpcode() != Hexagon::V6_vmpy_qf32_sf) { + auto OuterIt = QFPInstMap.find(MI->getOpcode()); + if (OuterIt == QFPInstMap.end()) + return false; + auto InnerIt = QFPInstMap.find(OuterIt->second); + if (InnerIt == QFPInstMap.end()) + return false; + InstTy = InnerIt->second; + } + + MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), HII->get(InstTy), Res.getReg()) + .addReg(Src1.getReg(), Op0F, Src1.getSubReg()) + .addReg(Src2.getReg(), Op1F, Src2.getSubReg()); + LLVM_DEBUG(dbgs() << "\n[Inserting]: "; MIB.getInstr()->dump()); + return true; + + // Case 2: Left operand is conversion to sf/hf + } else if (((Def1OP == Hexagon::V6_vconv_sf_qf32 && + Def2OP != Hexagon::V6_vconv_sf_qf32) || + (Def1OP == Hexagon::V6_vconv_hf_qf16 && + Def2OP != Hexagon::V6_vconv_hf_qf16)) && + !DefMI2->isPHI() && + (MI->getOpcode() != Hexagon::V6_vmpy_qf32_sf)) { + + if (Inst1 && MRI->getRegClass(Inst1->getOperand(0).getReg()) == + &Hexagon::HvxWRRegClass) + return false; + + MachineOperand &Src1 = DefMI1->getOperand(1); + MachineOperand &Src2 = MI->getOperand(2); + + Op0F = getKillRegState(Src1.isKill()); + Src1.setIsKill(false); + Op1F = getKillRegState(Src2.isKill()); + MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), HII->get(InstTy), Res.getReg()) + .addReg(Src1.getReg(), Op0F, Src1.getSubReg()) + .addReg(Src2.getReg(), Op1F, 
Src2.getSubReg());
+    LLVM_DEBUG(dbgs() << "\n[Inserting]: "; MIB.getInstr()->dump());
+    return true;
+
+    // Case 3: Right operand is conversion to sf/hf
+  } else if (((Def1OP != Hexagon::V6_vconv_sf_qf32 &&
+               Def2OP == Hexagon::V6_vconv_sf_qf32) ||
+              (Def1OP != Hexagon::V6_vconv_hf_qf16 &&
+               Def2OP == Hexagon::V6_vconv_hf_qf16)) &&
+             !DefMI1->isPHI() &&
+             (MI->getOpcode() != Hexagon::V6_vmpy_qf32_sf)) {
+    // The second operand of the original instruction is converted.
+    // In "mix" instructions, the "qf" operand is always the first operand.
+
+    // Caveat: vsub is not commutative w.r.t. its operands.
+    if (InstTy == Hexagon::V6_vsub_qf16_mix ||
+        InstTy == Hexagon::V6_vsub_qf32_mix)
+      return false;
+
+    if (Inst2 && MRI->getRegClass(Inst2->getOperand(0).getReg()) ==
+                     &Hexagon::HvxWRRegClass)
+      return false;
+
+    MachineOperand &Src1 = MI->getOperand(1);
+    MachineOperand &Src2 = DefMI2->getOperand(1);
+
+    Op1F = getKillRegState(Src2.isKill());
+    Src2.setIsKill(false);
+    Op0F = getKillRegState(Src1.isKill());
+    MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), HII->get(InstTy), Res.getReg())
+              .addReg(Src2.getReg(), Op1F,
+                      Src2.getSubReg()) // Notice the operands are flipped.
+              .addReg(Src1.getReg(), Op0F, Src1.getSubReg());
+    LLVM_DEBUG(dbgs() << "\n[Inserting]: "; MIB.getInstr()->dump());
+    return true;
+  }
+
+  return false;
+}
+
+bool HexagonQFPOptimizer::runOnMachineFunction(MachineFunction &MF) {
+
+  bool Changed = false;
+
+  if (DisableQFOptimizer)
+    return Changed;
+
+  HST = &MF.getSubtarget<HexagonSubtarget>();
+  if (!HST->useHVXV68Ops() || !HST->usePackets() ||
+      skipFunction(MF.getFunction()))
+    return false;
+  HII = HST->getInstrInfo();
+  MRI = &MF.getRegInfo();
+
+  MachineFunction::iterator MBBI = MF.begin();
+  LLVM_DEBUG(dbgs() << "\n=== Running QFPOptimizer Pass for: " << MF.getName()
+                    << " Optimize intermediate conversions ===\n");
+  while (MBBI != MF.end()) {
+    MachineBasicBlock *MBB = &*MBBI;
+    MachineBasicBlock::iterator MII = MBBI->instr_begin();
+    while (MII != MBBI->instr_end()) {
+      MachineInstr *MI = &*MII;
+      ++MII; // As MI might be removed.
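+      // The conversion instructions themselves are never rewritten here; once
+      // their last use is replaced they become dead and are left for the
+      // dead-instruction-elimination pass, as described in the file header.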
+ + if (QFPInstMap.count(MI->getOpcode()) && + MI->getOpcode() != Hexagon::V6_vconv_sf_qf32 && + MI->getOpcode() != Hexagon::V6_vconv_hf_qf16) { + LLVM_DEBUG(dbgs() << "\n###Analyzing for removal: "; MI->dump()); + if (optimizeQfp(MI, MBB)) { + MI->eraseFromParent(); + LLVM_DEBUG(dbgs() << "\t....Removing...."); + Changed = true; + } + } + } + ++MBBI; + } + return Changed; +} diff --git a/llvm/lib/Target/Hexagon/HexagonTargetMachine.cpp b/llvm/lib/Target/Hexagon/HexagonTargetMachine.cpp index f5d8b69..d9824a31 100644 --- a/llvm/lib/Target/Hexagon/HexagonTargetMachine.cpp +++ b/llvm/lib/Target/Hexagon/HexagonTargetMachine.cpp @@ -220,6 +220,7 @@ LLVMInitializeHexagonTarget() { initializeHexagonPeepholePass(PR); initializeHexagonSplitConst32AndConst64Pass(PR); initializeHexagonVectorPrintPass(PR); + initializeHexagonQFPOptimizerPass(PR); } HexagonTargetMachine::HexagonTargetMachine(const Target &T, const Triple &TT, @@ -386,6 +387,7 @@ bool HexagonPassConfig::addInstSelector() { addPass(createHexagonGenInsert()); if (EnableEarlyIf) addPass(createHexagonEarlyIfConversion()); + addPass(createHexagonQFPOptimizer()); } return false; diff --git a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp index e4c0a16..9ab5202 100644 --- a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp +++ b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp @@ -300,7 +300,6 @@ private: const_iterator end() const { return Blocks.end(); } }; - Align getAlignFromValue(const Value *V) const; std::optional<AddrInfo> getAddrInfo(Instruction &In) const; bool isHvx(const AddrInfo &AI) const; // This function is only used for assertions at the moment. @@ -612,12 +611,6 @@ auto AlignVectors::ByteSpan::values() const -> SmallVector<Value *, 8> { return Values; } -auto AlignVectors::getAlignFromValue(const Value *V) const -> Align { - const auto *C = dyn_cast<ConstantInt>(V); - assert(C && "Alignment must be a compile-time constant integer"); - return C->getAlignValue(); -} - auto AlignVectors::getAddrInfo(Instruction &In) const -> std::optional<AddrInfo> { if (auto *L = isCandidate<LoadInst>(&In)) @@ -631,11 +624,11 @@ auto AlignVectors::getAddrInfo(Instruction &In) const switch (ID) { case Intrinsic::masked_load: return AddrInfo(HVC, II, II->getArgOperand(0), II->getType(), - getAlignFromValue(II->getArgOperand(1))); + II->getParamAlign(0).valueOrOne()); case Intrinsic::masked_store: return AddrInfo(HVC, II, II->getArgOperand(1), II->getArgOperand(0)->getType(), - getAlignFromValue(II->getArgOperand(2))); + II->getParamAlign(1).valueOrOne()); } } return std::nullopt; @@ -660,9 +653,9 @@ auto AlignVectors::getMask(Value *Val) const -> Value * { if (auto *II = dyn_cast<IntrinsicInst>(Val)) { switch (II->getIntrinsicID()) { case Intrinsic::masked_load: - return II->getArgOperand(2); + return II->getArgOperand(1); case Intrinsic::masked_store: - return II->getArgOperand(3); + return II->getArgOperand(2); } } @@ -675,7 +668,7 @@ auto AlignVectors::getMask(Value *Val) const -> Value * { auto AlignVectors::getPassThrough(Value *Val) const -> Value * { if (auto *II = dyn_cast<IntrinsicInst>(Val)) { if (II->getIntrinsicID() == Intrinsic::masked_load) - return II->getArgOperand(3); + return II->getArgOperand(2); } return UndefValue::get(getPayload(Val)->getType()); } diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td index 5143d53..613dea6 100644 --- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td +++ 
b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -2025,10 +2025,10 @@ def : Pat<(v4i32(fp_to_uint v4f64:$vj)),
             sub_128)>;
 
 // abs
-def : Pat<(abs v32i8:$xj), (XVMAX_B v32i8:$xj, (XVNEG_B v32i8:$xj))>;
-def : Pat<(abs v16i16:$xj), (XVMAX_H v16i16:$xj, (XVNEG_H v16i16:$xj))>;
-def : Pat<(abs v8i32:$xj), (XVMAX_W v8i32:$xj, (XVNEG_W v8i32:$xj))>;
-def : Pat<(abs v4i64:$xj), (XVMAX_D v4i64:$xj, (XVNEG_D v4i64:$xj))>;
+def : Pat<(abs v32i8:$xj), (XVSIGNCOV_B v32i8:$xj, v32i8:$xj)>;
+def : Pat<(abs v16i16:$xj), (XVSIGNCOV_H v16i16:$xj, v16i16:$xj)>;
+def : Pat<(abs v8i32:$xj), (XVSIGNCOV_W v8i32:$xj, v8i32:$xj)>;
+def : Pat<(abs v4i64:$xj), (XVSIGNCOV_D v4i64:$xj, v4i64:$xj)>;
 
 // XVABSD_{B/H/W/D}[U]
 defm : PatXrXr<abds, "XVABSD">;
diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
index 8d1dc99..4619c6b 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
@@ -2155,10 +2155,10 @@ def : Pat<(f64 f64imm_vldi:$in),
           (f64 (EXTRACT_SUBREG (VLDI (to_f64imm_vldi f64imm_vldi:$in)), sub_64))>;
 
 // abs
-def : Pat<(abs v16i8:$vj), (VMAX_B v16i8:$vj, (VNEG_B v16i8:$vj))>;
-def : Pat<(abs v8i16:$vj), (VMAX_H v8i16:$vj, (VNEG_H v8i16:$vj))>;
-def : Pat<(abs v4i32:$vj), (VMAX_W v4i32:$vj, (VNEG_W v4i32:$vj))>;
-def : Pat<(abs v2i64:$vj), (VMAX_D v2i64:$vj, (VNEG_D v2i64:$vj))>;
+def : Pat<(abs v16i8:$vj), (VSIGNCOV_B v16i8:$vj, v16i8:$vj)>;
+def : Pat<(abs v8i16:$vj), (VSIGNCOV_H v8i16:$vj, v8i16:$vj)>;
+def : Pat<(abs v4i32:$vj), (VSIGNCOV_W v4i32:$vj, v4i32:$vj)>;
+def : Pat<(abs v2i64:$vj), (VSIGNCOV_D v2i64:$vj, v2i64:$vj)>;
 
 // VABSD_{B/H/W/D}[U]
 defm : PatVrVr<abds, "VABSD">;
diff --git a/llvm/lib/Target/RISCV/MCA/RISCVCustomBehaviour.cpp b/llvm/lib/Target/RISCV/MCA/RISCVCustomBehaviour.cpp
index ab93bba..b00589a 100644
--- a/llvm/lib/Target/RISCV/MCA/RISCVCustomBehaviour.cpp
+++ b/llvm/lib/Target/RISCV/MCA/RISCVCustomBehaviour.cpp
@@ -68,7 +68,7 @@ const llvm::StringRef RISCVSEWInstrument::DESC_NAME = "RISCV-SEW";
 
 bool RISCVSEWInstrument::isDataValid(llvm::StringRef Data) {
   // Return true if Data is one of the valid SEW strings
   return StringSwitch<bool>(Data)
-      .Cases("E8", "E16", "E32", "E64", true)
+      .Cases({"E8", "E16", "E32", "E64"}, true)
       .Default(false);
 }
 
diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 52dc53e..25b5af8 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -495,18 +495,19 @@ RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
 bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II) {
   VectorType *DataType;
   Value *StoreVal = nullptr, *Ptr, *Mask, *EVL = nullptr;
-  MaybeAlign MA;
+  Align Alignment;
   switch (II->getIntrinsicID()) {
   case Intrinsic::masked_gather:
     DataType = cast<VectorType>(II->getType());
     Ptr = II->getArgOperand(0);
-    MA = cast<ConstantInt>(II->getArgOperand(1))->getMaybeAlignValue();
-    Mask = II->getArgOperand(2);
+    Alignment = II->getParamAlign(0).valueOrOne();
+    Mask = II->getArgOperand(1);
    break;
  case Intrinsic::vp_gather:
    DataType = cast<VectorType>(II->getType());
    Ptr = II->getArgOperand(0);
-    MA = II->getParamAlign(0).value_or(
+    // FIXME: Falling back to ABI alignment is incorrect.
+ Alignment = II->getParamAlign(0).value_or( DL->getABITypeAlign(DataType->getElementType())); Mask = II->getArgOperand(1); EVL = II->getArgOperand(2); @@ -515,14 +516,15 @@ bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II) { DataType = cast<VectorType>(II->getArgOperand(0)->getType()); StoreVal = II->getArgOperand(0); Ptr = II->getArgOperand(1); - MA = cast<ConstantInt>(II->getArgOperand(2))->getMaybeAlignValue(); - Mask = II->getArgOperand(3); + Alignment = II->getParamAlign(1).valueOrOne(); + Mask = II->getArgOperand(2); break; case Intrinsic::vp_scatter: DataType = cast<VectorType>(II->getArgOperand(0)->getType()); StoreVal = II->getArgOperand(0); Ptr = II->getArgOperand(1); - MA = II->getParamAlign(1).value_or( + // FIXME: Falling back to ABI alignment is incorrect. + Alignment = II->getParamAlign(1).value_or( DL->getABITypeAlign(DataType->getElementType())); Mask = II->getArgOperand(2); EVL = II->getArgOperand(3); @@ -533,7 +535,7 @@ bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II) { // Make sure the operation will be supported by the backend. EVT DataTypeVT = TLI->getValueType(*DL, DataType); - if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA)) + if (!TLI->isLegalStridedLoadStore(DataTypeVT, Alignment)) return false; // FIXME: Let the backend type legalize by splitting/widening? @@ -571,7 +573,7 @@ bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II) { // Merge llvm.masked.gather's passthru if (II->getIntrinsicID() == Intrinsic::masked_gather) - Call = Builder.CreateSelect(Mask, Call, II->getArgOperand(3)); + Call = Builder.CreateSelect(Mask, Call, II->getArgOperand(2)); } else Call = Builder.CreateIntrinsic( Intrinsic::experimental_vp_strided_store, diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 169465e..a77d765 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -12649,10 +12649,7 @@ SDValue RISCVTargetLowering::lowerVECTOR_REVERSE(SDValue Op, Lo = DAG.getNode(ISD::VECTOR_REVERSE, DL, LoVT, Lo); Hi = DAG.getNode(ISD::VECTOR_REVERSE, DL, HiVT, Hi); // Reassemble the low and high pieces reversed. - // FIXME: This is a CONCAT_VECTORS. - SDValue Res = DAG.getInsertSubvector(DL, DAG.getUNDEF(VecVT), Hi, 0); - return DAG.getInsertSubvector(DL, Res, Lo, - LoVT.getVectorMinNumElements()); + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, Hi, Lo); } // Just promote the int type to i16 which will double the LMUL. @@ -24047,18 +24044,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, } } - std::pair<Register, const TargetRegisterClass *> Res = - TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT); - - // If we picked one of the Zfinx register classes, remap it to the GPR class. - // FIXME: When Zfinx is supported in CodeGen this will need to take the - // Subtarget into account. 
- if (Res.second == &RISCV::GPRF16RegClass || - Res.second == &RISCV::GPRF32RegClass || - Res.second == &RISCV::GPRPairRegClass) - return std::make_pair(Res.first, &RISCV::GPRRegClass); - - return Res; + return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT); } InlineAsm::ConstraintCode diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp index 1b7cb9b..636e31c 100644 --- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp +++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp @@ -699,7 +699,8 @@ public: "Can't encode VTYPE for uninitialized or unknown"); if (TWiden != 0) return RISCVVType::encodeXSfmmVType(SEW, TWiden, AltFmt); - return RISCVVType::encodeVTYPE(VLMul, SEW, TailAgnostic, MaskAgnostic); + return RISCVVType::encodeVTYPE(VLMul, SEW, TailAgnostic, MaskAgnostic, + AltFmt); } bool hasSEWLMULRatioOnly() const { return SEWLMULRatioOnly; } diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index ddb53a2..12f776b 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -3775,11 +3775,13 @@ std::string RISCVInstrInfo::createMIROperandComment( #define CASE_VFMA_OPCODE_VV(OP) \ CASE_VFMA_OPCODE_LMULS_MF4(OP, VV, E16): \ + case CASE_VFMA_OPCODE_LMULS_MF4(OP##_ALT, VV, E16): \ case CASE_VFMA_OPCODE_LMULS_MF2(OP, VV, E32): \ case CASE_VFMA_OPCODE_LMULS_M1(OP, VV, E64) #define CASE_VFMA_SPLATS(OP) \ CASE_VFMA_OPCODE_LMULS_MF4(OP, VFPR16, E16): \ + case CASE_VFMA_OPCODE_LMULS_MF4(OP##_ALT, VFPR16, E16): \ case CASE_VFMA_OPCODE_LMULS_MF2(OP, VFPR32, E32): \ case CASE_VFMA_OPCODE_LMULS_M1(OP, VFPR64, E64) // clang-format on @@ -4003,11 +4005,13 @@ bool RISCVInstrInfo::findCommutedOpIndices(const MachineInstr &MI, #define CASE_VFMA_CHANGE_OPCODE_VV(OLDOP, NEWOP) \ CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(OLDOP, NEWOP, VV, E16) \ + CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(OLDOP##_ALT, NEWOP##_ALT, VV, E16) \ CASE_VFMA_CHANGE_OPCODE_LMULS_MF2(OLDOP, NEWOP, VV, E32) \ CASE_VFMA_CHANGE_OPCODE_LMULS_M1(OLDOP, NEWOP, VV, E64) #define CASE_VFMA_CHANGE_OPCODE_SPLATS(OLDOP, NEWOP) \ CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(OLDOP, NEWOP, VFPR16, E16) \ + CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(OLDOP##_ALT, NEWOP##_ALT, VFPR16, E16) \ CASE_VFMA_CHANGE_OPCODE_LMULS_MF2(OLDOP, NEWOP, VFPR32, E32) \ CASE_VFMA_CHANGE_OPCODE_LMULS_M1(OLDOP, NEWOP, VFPR64, E64) // clang-format on @@ -4469,6 +4473,20 @@ bool RISCVInstrInfo::simplifyInstruction(MachineInstr &MI) const { CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M2, E32) \ CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M4, E16) \ CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M4, E32) \ + +#define CASE_FP_WIDEOP_OPCODE_LMULS_ALT(OP) \ + CASE_FP_WIDEOP_OPCODE_COMMON(OP, MF4, E16): \ + case CASE_FP_WIDEOP_OPCODE_COMMON(OP, MF2, E16): \ + case CASE_FP_WIDEOP_OPCODE_COMMON(OP, M1, E16): \ + case CASE_FP_WIDEOP_OPCODE_COMMON(OP, M2, E16): \ + case CASE_FP_WIDEOP_OPCODE_COMMON(OP, M4, E16) + +#define CASE_FP_WIDEOP_CHANGE_OPCODE_LMULS_ALT(OP) \ + CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF4, E16) \ + CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF2, E16) \ + CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M1, E16) \ + CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M2, E16) \ + CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M4, E16) // clang-format on MachineInstr *RISCVInstrInfo::convertToThreeAddress(MachineInstr &MI, @@ -4478,6 +4496,8 @@ MachineInstr *RISCVInstrInfo::convertToThreeAddress(MachineInstr &MI, switch (MI.getOpcode()) { default: return nullptr; + case 
CASE_FP_WIDEOP_OPCODE_LMULS_ALT(FWADD_ALT_WV): + case CASE_FP_WIDEOP_OPCODE_LMULS_ALT(FWSUB_ALT_WV): case CASE_FP_WIDEOP_OPCODE_LMULS(FWADD_WV): case CASE_FP_WIDEOP_OPCODE_LMULS(FWSUB_WV): { assert(RISCVII::hasVecPolicyOp(MI.getDesc().TSFlags) && @@ -4494,6 +4514,8 @@ MachineInstr *RISCVInstrInfo::convertToThreeAddress(MachineInstr &MI, llvm_unreachable("Unexpected opcode"); CASE_FP_WIDEOP_CHANGE_OPCODE_LMULS(FWADD_WV) CASE_FP_WIDEOP_CHANGE_OPCODE_LMULS(FWSUB_WV) + CASE_FP_WIDEOP_CHANGE_OPCODE_LMULS_ALT(FWADD_ALT_WV) + CASE_FP_WIDEOP_CHANGE_OPCODE_LMULS_ALT(FWSUB_ALT_WV) } // clang-format on diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td index 66717b9..7c89686 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td @@ -1511,16 +1511,16 @@ def GIShiftMask32 : GIComplexOperandMatcher<s64, "selectShiftMask32">, GIComplexPatternEquiv<shiftMask32>; -class shiftop<SDPatternOperator operator> - : PatFrag<(ops node:$val, node:$count), - (operator node:$val, (XLenVT (shiftMaskXLen node:$count)))>; -class shiftopw<SDPatternOperator operator> - : PatFrag<(ops node:$val, node:$count), - (operator node:$val, (i64 (shiftMask32 node:$count)))>; +class PatGprShiftMaskXLen<SDPatternOperator OpNode, RVInst Inst> + : Pat<(OpNode GPR:$rs1, shiftMaskXLen:$rs2), + (Inst GPR:$rs1, shiftMaskXLen:$rs2)>; +class PatGprShiftMask32<SDPatternOperator OpNode, RVInst Inst> + : Pat<(OpNode GPR:$rs1, shiftMask32:$rs2), + (Inst GPR:$rs1, shiftMask32:$rs2)>; -def : PatGprGpr<shiftop<shl>, SLL>; -def : PatGprGpr<shiftop<srl>, SRL>; -def : PatGprGpr<shiftop<sra>, SRA>; +def : PatGprShiftMaskXLen<shl, SLL>; +def : PatGprShiftMaskXLen<srl, SRL>; +def : PatGprShiftMaskXLen<sra, SRA>; // This is a special case of the ADD instruction used to facilitate the use of a // fourth operand to emit a relocation on a symbol relating to this instruction. @@ -2203,9 +2203,9 @@ def : Pat<(sra (sext_inreg GPR:$rs1, i32), uimm5:$shamt), def : Pat<(i64 (sra (shl GPR:$rs1, (i64 32)), uimm6gt32:$shamt)), (SRAIW GPR:$rs1, (ImmSub32 uimm6gt32:$shamt))>; -def : PatGprGpr<shiftopw<riscv_sllw>, SLLW>; -def : PatGprGpr<shiftopw<riscv_srlw>, SRLW>; -def : PatGprGpr<shiftopw<riscv_sraw>, SRAW>; +def : PatGprShiftMask32<riscv_sllw, SLLW>; +def : PatGprShiftMask32<riscv_srlw, SRLW>; +def : PatGprShiftMask32<riscv_sraw, SRAW>; // Select W instructions if only the lower 32 bits of the result are used. 
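The W-instruction selection below rests on a bit-level identity: on RV64, addw produces sext32(lo32(x + y)), so a consumer that reads only bits 31:0 cannot distinguish ADDW from the full 64-bit ADD. A standalone C++ model of that identity (illustrative only; the function names are made up):

  #include <cstdint>
  // Low 32 bits of the full 64-bit add.
  uint32_t lo32_add(uint64_t x, uint64_t y) { return uint32_t(x + y); }
  // RV64 addw: sign-extend the low 32 bits of the sum.
  int64_t addw(uint64_t x, uint64_t y) { return int32_t(uint32_t(x + y)); }
  // For all x, y: uint32_t(addw(x, y)) == lo32_add(x, y).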
def : PatGprGpr<binop_allwusers<add>, ADDW>; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td index 57fbaa0..62b7bcd 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td @@ -506,8 +506,8 @@ def : Pat<(XLenVT (xor GPR:$rs1, invLogicImm:$rs2)), (XNOR GPR:$rs1, invLogicImm } // Predicates = [HasStdExtZbbOrZbkb] let Predicates = [HasStdExtZbbOrZbkb] in { -def : PatGprGpr<shiftop<rotl>, ROL>; -def : PatGprGpr<shiftop<rotr>, ROR>; +def : PatGprShiftMaskXLen<rotl, ROL>; +def : PatGprShiftMaskXLen<rotr, ROR>; def : PatGprImm<rotr, RORI, uimmlog2xlen>; // There's no encoding for roli in the 'B' extension as it can be @@ -517,29 +517,29 @@ def : Pat<(XLenVT (rotl GPR:$rs1, uimmlog2xlen:$shamt)), } // Predicates = [HasStdExtZbbOrZbkb] let Predicates = [HasStdExtZbbOrZbkb, IsRV64] in { -def : PatGprGpr<shiftopw<riscv_rolw>, ROLW>; -def : PatGprGpr<shiftopw<riscv_rorw>, RORW>; +def : PatGprShiftMask32<riscv_rolw, ROLW>; +def : PatGprShiftMask32<riscv_rorw, RORW>; def : PatGprImm<riscv_rorw, RORIW, uimm5>; def : Pat<(riscv_rolw GPR:$rs1, uimm5:$rs2), (RORIW GPR:$rs1, (ImmSubFrom32 uimm5:$rs2))>; } // Predicates = [HasStdExtZbbOrZbkb, IsRV64] let Predicates = [HasStdExtZbs] in { -def : Pat<(XLenVT (and (not (shiftop<shl> 1, (XLenVT GPR:$rs2))), GPR:$rs1)), - (BCLR GPR:$rs1, GPR:$rs2)>; -def : Pat<(XLenVT (and (rotl -2, (XLenVT GPR:$rs2)), GPR:$rs1)), - (BCLR GPR:$rs1, GPR:$rs2)>; -def : Pat<(XLenVT (or (shiftop<shl> 1, (XLenVT GPR:$rs2)), GPR:$rs1)), - (BSET GPR:$rs1, GPR:$rs2)>; -def : Pat<(XLenVT (xor (shiftop<shl> 1, (XLenVT GPR:$rs2)), GPR:$rs1)), - (BINV GPR:$rs1, GPR:$rs2)>; -def : Pat<(XLenVT (and (shiftop<srl> GPR:$rs1, (XLenVT GPR:$rs2)), 1)), - (BEXT GPR:$rs1, GPR:$rs2)>; - -def : Pat<(XLenVT (shiftop<shl> 1, (XLenVT GPR:$rs2))), - (BSET (XLenVT X0), GPR:$rs2)>; -def : Pat<(XLenVT (not (shiftop<shl> -1, (XLenVT GPR:$rs2)))), - (ADDI (XLenVT (BSET (XLenVT X0), GPR:$rs2)), -1)>; +def : Pat<(XLenVT (and (not (shl 1, shiftMaskXLen:$rs2)), GPR:$rs1)), + (BCLR GPR:$rs1, shiftMaskXLen:$rs2)>; +def : Pat<(XLenVT (and (rotl -2, shiftMaskXLen:$rs2), GPR:$rs1)), + (BCLR GPR:$rs1, shiftMaskXLen:$rs2)>; +def : Pat<(XLenVT (or (shl 1, shiftMaskXLen:$rs2), GPR:$rs1)), + (BSET GPR:$rs1, shiftMaskXLen:$rs2)>; +def : Pat<(XLenVT (xor (shl 1, shiftMaskXLen:$rs2), GPR:$rs1)), + (BINV GPR:$rs1, shiftMaskXLen:$rs2)>; +def : Pat<(XLenVT (and (srl GPR:$rs1, shiftMaskXLen:$rs2), 1)), + (BEXT GPR:$rs1, shiftMaskXLen:$rs2)>; + +def : Pat<(XLenVT (shl 1, shiftMaskXLen:$rs2)), + (BSET (XLenVT X0), shiftMaskXLen:$rs2)>; +def : Pat<(XLenVT (not (shl -1, shiftMaskXLen:$rs2))), + (ADDI (XLenVT (BSET (XLenVT X0), shiftMaskXLen:$rs2)), -1)>; def : Pat<(XLenVT (and GPR:$rs1, BCLRMask:$mask)), (BCLRI GPR:$rs1, BCLRMask:$mask)>; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td index c9c1246..f7d1a09 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td @@ -44,6 +44,336 @@ let Predicates = [HasStdExtZvfbfmin] in { let mayRaiseFPException = true, Predicates = [HasStdExtZvfbfwma] in defm PseudoVFWMACCBF16 : VPseudoVWMAC_VV_VF_BF_RM; +defset list<VTypeInfoToWide> AllWidenableIntToBF16Vectors = { + def : VTypeInfoToWide<VI8MF8, VBF16MF4>; + def : VTypeInfoToWide<VI8MF4, VBF16MF2>; + def : VTypeInfoToWide<VI8MF2, VBF16M1>; + def : VTypeInfoToWide<VI8M1, VBF16M2>; + def : VTypeInfoToWide<VI8M2, VBF16M4>; + def :
VTypeInfoToWide<VI8M4, VBF16M8>; +} + +multiclass VPseudoVALU_VV_VF_RM_BF16 { + foreach m = MxListF in { + defm "" : VPseudoBinaryFV_VV_RM<m, 16/*sew*/>, + SchedBinary<"WriteVFALUV", "ReadVFALUV", "ReadVFALUV", m.MX, + 16/*sew*/, forcePassthruRead=true>; + } + + defvar f = SCALAR_F16; + foreach m = f.MxList in { + defm "" : VPseudoBinaryV_VF_RM<m, f, f.SEW>, + SchedBinary<"WriteVFALUF", "ReadVFALUV", "ReadVFALUF", m.MX, + f.SEW, forcePassthruRead=true>; + } +} + +multiclass VPseudoVALU_VF_RM_BF16 { + defvar f = SCALAR_F16; + foreach m = f.MxList in { + defm "" : VPseudoBinaryV_VF_RM<m, f, f.SEW>, + SchedBinary<"WriteVFALUF", "ReadVFALUV", "ReadVFALUF", m.MX, + f.SEW, forcePassthruRead=true>; + } +} + +multiclass VPseudoVFWALU_VV_VF_RM_BF16 { + foreach m = MxListFW in { + defm "" : VPseudoBinaryW_VV_RM<m, sew=16>, + SchedBinary<"WriteVFWALUV", "ReadVFWALUV", "ReadVFWALUV", m.MX, + 16/*sew*/, forcePassthruRead=true>; + } + + defvar f = SCALAR_F16; + foreach m = f.MxListFW in { + defm "" : VPseudoBinaryW_VF_RM<m, f, sew=f.SEW>, + SchedBinary<"WriteVFWALUF", "ReadVFWALUV", "ReadVFWALUF", m.MX, + f.SEW, forcePassthruRead=true>; + } +} + +multiclass VPseudoVFWALU_WV_WF_RM_BF16 { + foreach m = MxListFW in { + defm "" : VPseudoBinaryW_WV_RM<m, sew=16>, + SchedBinary<"WriteVFWALUV", "ReadVFWALUV", "ReadVFWALUV", m.MX, + 16/*sew*/, forcePassthruRead=true>; + } + defvar f = SCALAR_F16; + foreach m = f.MxListFW in { + defm "" : VPseudoBinaryW_WF_RM<m, f, sew=f.SEW>, + SchedBinary<"WriteVFWALUF", "ReadVFWALUV", "ReadVFWALUF", m.MX, + f.SEW, forcePassthruRead=true>; + } +} + +multiclass VPseudoVFMUL_VV_VF_RM_BF16 { + foreach m = MxListF in { + defm "" : VPseudoBinaryFV_VV_RM<m, 16/*sew*/>, + SchedBinary<"WriteVFMulV", "ReadVFMulV", "ReadVFMulV", m.MX, + 16/*sew*/, forcePassthruRead=true>; + } + + defvar f = SCALAR_F16; + foreach m = f.MxList in { + defm "" : VPseudoBinaryV_VF_RM<m, f, f.SEW>, + SchedBinary<"WriteVFMulF", "ReadVFMulV", "ReadVFMulF", m.MX, + f.SEW, forcePassthruRead=true>; + } +} + +multiclass VPseudoVWMUL_VV_VF_RM_BF16 { + foreach m = MxListFW in { + defm "" : VPseudoBinaryW_VV_RM<m, sew=16>, + SchedBinary<"WriteVFWMulV", "ReadVFWMulV", "ReadVFWMulV", m.MX, + 16/*sew*/, forcePassthruRead=true>; + } + + defvar f = SCALAR_F16; + foreach m = f.MxListFW in { + defm "" : VPseudoBinaryW_VF_RM<m, f, sew=f.SEW>, + SchedBinary<"WriteVFWMulF", "ReadVFWMulV", "ReadVFWMulF", m.MX, + f.SEW, forcePassthruRead=true>; + } +} + +multiclass VPseudoVMAC_VV_VF_AAXA_RM_BF16 { + foreach m = MxListF in { + defm "" : VPseudoTernaryV_VV_AAXA_RM<m, 16/*sew*/>, + SchedTernary<"WriteVFMulAddV", "ReadVFMulAddV", "ReadVFMulAddV", + "ReadVFMulAddV", m.MX, 16/*sew*/>; + } + + defvar f = SCALAR_F16; + foreach m = f.MxList in { + defm "" : VPseudoTernaryV_VF_AAXA_RM<m, f, f.SEW>, + SchedTernary<"WriteVFMulAddF", "ReadVFMulAddV", "ReadVFMulAddF", + "ReadVFMulAddV", m.MX, f.SEW>; + } +} + +multiclass VPseudoVWMAC_VV_VF_RM_BF16 { + foreach m = MxListFW in { + defm "" : VPseudoTernaryW_VV_RM<m, sew=16>, + SchedTernary<"WriteVFWMulAddV", "ReadVFWMulAddV", + "ReadVFWMulAddV", "ReadVFWMulAddV", m.MX, 16/*sew*/>; + } + + defvar f = SCALAR_F16; + foreach m = f.MxListFW in { + defm "" : VPseudoTernaryW_VF_RM<m, f, sew=f.SEW>, + SchedTernary<"WriteVFWMulAddF", "ReadVFWMulAddV", + "ReadVFWMulAddF", "ReadVFWMulAddV", m.MX, f.SEW>; + } +} + +multiclass VPseudoVRCP_V_BF16 { + foreach m = MxListF in { + defvar mx = m.MX; + let VLMul = m.value in { + def "_V_" # mx # "_E16" + : VPseudoUnaryNoMask<m.vrclass, m.vrclass>, + 
SchedUnary<"WriteVFRecpV", "ReadVFRecpV", mx, 16/*sew*/, + forcePassthruRead=true>; + def "_V_" # mx # "_E16_MASK" + : VPseudoUnaryMask<m.vrclass, m.vrclass>, + RISCVMaskedPseudo<MaskIdx = 2>, + SchedUnary<"WriteVFRecpV", "ReadVFRecpV", mx, 16/*sew*/, + forcePassthruRead=true>; + } + } +} + +multiclass VPseudoVRCP_V_RM_BF16 { + foreach m = MxListF in { + defvar mx = m.MX; + let VLMul = m.value in { + def "_V_" # mx # "_E16" + : VPseudoUnaryNoMaskRoundingMode<m.vrclass, m.vrclass>, + SchedUnary<"WriteVFRecpV", "ReadVFRecpV", mx, 16/*sew*/, + forcePassthruRead=true>; + def "_V_" # mx # "_E16_MASK" + : VPseudoUnaryMaskRoundingMode<m.vrclass, m.vrclass>, + RISCVMaskedPseudo<MaskIdx = 2>, + SchedUnary<"WriteVFRecpV", "ReadVFRecpV", mx, 16/*sew*/, + forcePassthruRead=true>; + } + } +} + +multiclass VPseudoVMAX_VV_VF_BF16 { + foreach m = MxListF in { + defm "" : VPseudoBinaryV_VV<m, sew=16>, + SchedBinary<"WriteVFMinMaxV", "ReadVFMinMaxV", "ReadVFMinMaxV", + m.MX, 16/*sew*/, forcePassthruRead=true>; + } + + defvar f = SCALAR_F16; + foreach m = f.MxList in { + defm "" : VPseudoBinaryV_VF<m, f, f.SEW>, + SchedBinary<"WriteVFMinMaxF", "ReadVFMinMaxV", "ReadVFMinMaxF", + m.MX, f.SEW, forcePassthruRead=true>; + } +} + +multiclass VPseudoVSGNJ_VV_VF_BF16 { + foreach m = MxListF in { + defm "" : VPseudoBinaryV_VV<m, sew=16>, + SchedBinary<"WriteVFSgnjV", "ReadVFSgnjV", "ReadVFSgnjV", m.MX, + 16/*sew*/, forcePassthruRead=true>; + } + + defvar f = SCALAR_F16; + foreach m = f.MxList in { + defm "" : VPseudoBinaryV_VF<m, f, f.SEW>, + SchedBinary<"WriteVFSgnjF", "ReadVFSgnjV", "ReadVFSgnjF", m.MX, + f.SEW, forcePassthruRead=true>; + } +} + +multiclass VPseudoVWCVTF_V_BF16 { + defvar constraint = "@earlyclobber $rd"; + foreach m = MxListW in + defm _V : VPseudoConversion<m.wvrclass, m.vrclass, m, constraint, sew=8, + TargetConstraintType=3>, + SchedUnary<"WriteVFWCvtIToFV", "ReadVFWCvtIToFV", m.MX, 8/*sew*/, + forcePassthruRead=true>; +} + +multiclass VPseudoVWCVTD_V_BF16 { + defvar constraint = "@earlyclobber $rd"; + foreach m = MxListFW in + defm _V : VPseudoConversion<m.wvrclass, m.vrclass, m, constraint, sew=16, + TargetConstraintType=3>, + SchedUnary<"WriteVFWCvtFToFV", "ReadVFWCvtFToFV", m.MX, 16/*sew*/, + forcePassthruRead=true>; +} + +multiclass VPseudoVNCVTD_W_BF16 { + defvar constraint = "@earlyclobber $rd"; + foreach m = MxListFW in + defm _W : VPseudoConversion<m.vrclass, m.wvrclass, m, constraint, sew=16, + TargetConstraintType=2>, + SchedUnary<"WriteVFNCvtFToFV", "ReadVFNCvtFToFV", m.MX, 16/*sew*/, + forcePassthruRead=true>; +} + +multiclass VPseudoVNCVTD_W_RM_BF16 { + defvar constraint = "@earlyclobber $rd"; + foreach m = MxListFW in + defm _W : VPseudoConversionRoundingMode<m.vrclass, m.wvrclass, m, + constraint, sew=16, + TargetConstraintType=2>, + SchedUnary<"WriteVFNCvtFToFV", "ReadVFNCvtFToFV", m.MX, 16/*sew*/, + forcePassthruRead=true>; +} + +let Predicates = [HasStdExtZvfbfa], AltFmtType = IS_ALTFMT in { +let mayRaiseFPException = true in { +defm PseudoVFADD_ALT : VPseudoVALU_VV_VF_RM_BF16; +defm PseudoVFSUB_ALT : VPseudoVALU_VV_VF_RM_BF16; +defm PseudoVFRSUB_ALT : VPseudoVALU_VF_RM_BF16; +} + +let mayRaiseFPException = true in { +defm PseudoVFWADD_ALT : VPseudoVFWALU_VV_VF_RM_BF16; +defm PseudoVFWSUB_ALT : VPseudoVFWALU_VV_VF_RM_BF16; +defm PseudoVFWADD_ALT : VPseudoVFWALU_WV_WF_RM_BF16; +defm PseudoVFWSUB_ALT : VPseudoVFWALU_WV_WF_RM_BF16; +} + +let mayRaiseFPException = true in +defm PseudoVFMUL_ALT : VPseudoVFMUL_VV_VF_RM_BF16; + +let mayRaiseFPException = true in +defm 
PseudoVFWMUL_ALT : VPseudoVWMUL_VV_VF_RM_BF16; + +let mayRaiseFPException = true in { +defm PseudoVFMACC_ALT : VPseudoVMAC_VV_VF_AAXA_RM_BF16; +defm PseudoVFNMACC_ALT : VPseudoVMAC_VV_VF_AAXA_RM_BF16; +defm PseudoVFMSAC_ALT : VPseudoVMAC_VV_VF_AAXA_RM_BF16; +defm PseudoVFNMSAC_ALT : VPseudoVMAC_VV_VF_AAXA_RM_BF16; +defm PseudoVFMADD_ALT : VPseudoVMAC_VV_VF_AAXA_RM_BF16; +defm PseudoVFNMADD_ALT : VPseudoVMAC_VV_VF_AAXA_RM_BF16; +defm PseudoVFMSUB_ALT : VPseudoVMAC_VV_VF_AAXA_RM_BF16; +defm PseudoVFNMSUB_ALT : VPseudoVMAC_VV_VF_AAXA_RM_BF16; +} + +let mayRaiseFPException = true in { +defm PseudoVFWMACC_ALT : VPseudoVWMAC_VV_VF_RM_BF16; +defm PseudoVFWNMACC_ALT : VPseudoVWMAC_VV_VF_RM_BF16; +defm PseudoVFWMSAC_ALT : VPseudoVWMAC_VV_VF_RM_BF16; +defm PseudoVFWNMSAC_ALT : VPseudoVWMAC_VV_VF_RM_BF16; +} + +let mayRaiseFPException = true in +defm PseudoVFRSQRT7_ALT : VPseudoVRCP_V_BF16; + +let mayRaiseFPException = true in +defm PseudoVFREC7_ALT : VPseudoVRCP_V_RM_BF16; + +let mayRaiseFPException = true in { +defm PseudoVFMIN_ALT : VPseudoVMAX_VV_VF_BF16; +defm PseudoVFMAX_ALT : VPseudoVMAX_VV_VF_BF16; +} + +defm PseudoVFSGNJ_ALT : VPseudoVSGNJ_VV_VF_BF16; +defm PseudoVFSGNJN_ALT : VPseudoVSGNJ_VV_VF_BF16; +defm PseudoVFSGNJX_ALT : VPseudoVSGNJ_VV_VF_BF16; + +let mayRaiseFPException = true in { +defm PseudoVMFEQ_ALT : VPseudoVCMPM_VV_VF; +defm PseudoVMFNE_ALT : VPseudoVCMPM_VV_VF; +defm PseudoVMFLT_ALT : VPseudoVCMPM_VV_VF; +defm PseudoVMFLE_ALT : VPseudoVCMPM_VV_VF; +defm PseudoVMFGT_ALT : VPseudoVCMPM_VF; +defm PseudoVMFGE_ALT : VPseudoVCMPM_VF; +} + +defm PseudoVFCLASS_ALT : VPseudoVCLS_V; + +defm PseudoVFMERGE_ALT : VPseudoVMRG_FM; + +defm PseudoVFMV_V_ALT : VPseudoVMV_F; + +let mayRaiseFPException = true in { +defm PseudoVFWCVT_F_XU_ALT : VPseudoVWCVTF_V_BF16; +defm PseudoVFWCVT_F_X_ALT : VPseudoVWCVTF_V_BF16; + +defm PseudoVFWCVT_F_F_ALT : VPseudoVWCVTD_V_BF16; +} // mayRaiseFPException = true + +let mayRaiseFPException = true in { +let hasSideEffects = 0, hasPostISelHook = 1 in { +defm PseudoVFNCVT_XU_F_ALT : VPseudoVNCVTI_W_RM; +defm PseudoVFNCVT_X_F_ALT : VPseudoVNCVTI_W_RM; +} + +defm PseudoVFNCVT_RTZ_XU_F_ALT : VPseudoVNCVTI_W; +defm PseudoVFNCVT_RTZ_X_F_ALT : VPseudoVNCVTI_W; + +defm PseudoVFNCVT_F_F_ALT : VPseudoVNCVTD_W_RM_BF16; + +defm PseudoVFNCVT_ROD_F_F_ALT : VPseudoVNCVTD_W_BF16; +} // mayRaiseFPException = true + +let mayLoad = 0, mayStore = 0, hasSideEffects = 0 in { + defvar f = SCALAR_F16; + let HasSEWOp = 1, BaseInstr = VFMV_F_S in + def "PseudoVFMV_" # f.FX # "_S_ALT" : + RISCVVPseudo<(outs f.fprclass:$rd), (ins VR:$rs2, sew:$sew)>, + Sched<[WriteVMovFS, ReadVMovFS]>; + let HasVLOp = 1, HasSEWOp = 1, BaseInstr = VFMV_S_F, isReMaterializable = 1, + Constraints = "$rd = $passthru" in + def "PseudoVFMV_S_" # f.FX # "_ALT" : + RISCVVPseudo<(outs VR:$rd), + (ins VR:$passthru, f.fprclass:$rs1, AVL:$vl, sew:$sew)>, + Sched<[WriteVMovSF, ReadVMovSF_V, ReadVMovSF_F]>; +} + +defm PseudoVFSLIDE1UP_ALT : VPseudoVSLD1_VF<"@earlyclobber $rd">; +defm PseudoVFSLIDE1DOWN_ALT : VPseudoVSLD1_VF; +} // Predicates = [HasStdExtZvfbfa], AltFmtType = IS_ALTFMT + //===----------------------------------------------------------------------===// // Patterns //===----------------------------------------------------------------------===// @@ -118,3 +448,224 @@ let Predicates = [HasStdExtZvfbfwma] in { defm : VPatWidenFPMulAccSDNode_VV_VF_RM<"PseudoVFWMACCBF16", AllWidenableBF16ToFloatVectors>; } + +multiclass VPatConversionVI_VF_BF16<string intrinsic, string instruction> { + foreach fvti = 
AllBF16Vectors in { + defvar ivti = GetIntVTypeInfo<fvti>.Vti; + let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates, + GetVTypePredicates<ivti>.Predicates) in + defm : VPatConversion<intrinsic, instruction, "V", + ivti.Vector, fvti.Vector, ivti.Mask, fvti.Log2SEW, + fvti.LMul, ivti.RegClass, fvti.RegClass>; + } +} + +multiclass VPatConversionWF_VI_BF16<string intrinsic, string instruction, + bit isSEWAware = 0> { + foreach vtiToWti = AllWidenableIntToBF16Vectors in { + defvar vti = vtiToWti.Vti; + defvar fwti = vtiToWti.Wti; + let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates, + GetVTypePredicates<fwti>.Predicates) in + defm : VPatConversion<intrinsic, instruction, "V", + fwti.Vector, vti.Vector, fwti.Mask, vti.Log2SEW, + vti.LMul, fwti.RegClass, vti.RegClass, isSEWAware>; + } +} + +multiclass VPatConversionWF_VF_BF16<string intrinsic, string instruction, + bit isSEWAware = 0> { + foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in { + defvar fvti = fvtiToFWti.Vti; + defvar fwti = fvtiToFWti.Wti; + let Predicates = !listconcat(GetVTypeMinimalPredicates<fvti>.Predicates, + GetVTypeMinimalPredicates<fwti>.Predicates) in + defm : VPatConversion<intrinsic, instruction, "V", + fwti.Vector, fvti.Vector, fwti.Mask, fvti.Log2SEW, + fvti.LMul, fwti.RegClass, fvti.RegClass, isSEWAware>; + } +} + +multiclass VPatConversionVI_WF_BF16<string intrinsic, string instruction> { + foreach vtiToWti = AllWidenableIntToBF16Vectors in { + defvar vti = vtiToWti.Vti; + defvar fwti = vtiToWti.Wti; + let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates, + GetVTypePredicates<fwti>.Predicates) in + defm : VPatConversion<intrinsic, instruction, "W", + vti.Vector, fwti.Vector, vti.Mask, vti.Log2SEW, + vti.LMul, vti.RegClass, fwti.RegClass>; + } +} + +multiclass VPatConversionVI_WF_RM_BF16<string intrinsic, string instruction> { + foreach vtiToWti = AllWidenableIntToBF16Vectors in { + defvar vti = vtiToWti.Vti; + defvar fwti = vtiToWti.Wti; + let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates, + GetVTypePredicates<fwti>.Predicates) in + defm : VPatConversionRoundingMode<intrinsic, instruction, "W", + vti.Vector, fwti.Vector, vti.Mask, vti.Log2SEW, + vti.LMul, vti.RegClass, fwti.RegClass>; + } +} + +multiclass VPatConversionVF_WF_BF16<string intrinsic, string instruction, + bit isSEWAware = 0> { + foreach fvtiToFWti = AllWidenableBF16ToFloatVectors in { + defvar fvti = fvtiToFWti.Vti; + defvar fwti = fvtiToFWti.Wti; + let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates, + GetVTypePredicates<fwti>.Predicates) in + defm : VPatConversion<intrinsic, instruction, "W", + fvti.Vector, fwti.Vector, fvti.Mask, fvti.Log2SEW, + fvti.LMul, fvti.RegClass, fwti.RegClass, isSEWAware>; + } +} + +let Predicates = [HasStdExtZvfbfa] in { +defm : VPatBinaryV_VV_VX_RM<"int_riscv_vfadd", "PseudoVFADD_ALT", + AllBF16Vectors, isSEWAware = 1>; +defm : VPatBinaryV_VV_VX_RM<"int_riscv_vfsub", "PseudoVFSUB_ALT", + AllBF16Vectors, isSEWAware = 1>; +defm : VPatBinaryV_VX_RM<"int_riscv_vfrsub", "PseudoVFRSUB_ALT", + AllBF16Vectors, isSEWAware = 1>; +defm : VPatBinaryW_VV_VX_RM<"int_riscv_vfwadd", "PseudoVFWADD_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatBinaryW_VV_VX_RM<"int_riscv_vfwsub", "PseudoVFWSUB_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatBinaryW_WV_WX_RM<"int_riscv_vfwadd_w", "PseudoVFWADD_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatBinaryW_WV_WX_RM<"int_riscv_vfwsub_w", "PseudoVFWSUB_ALT", + 
AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatBinaryV_VV_VX_RM<"int_riscv_vfmul", "PseudoVFMUL_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatBinaryW_VV_VX_RM<"int_riscv_vfwmul", "PseudoVFWMUL_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatTernaryV_VV_VX_AAXA_RM<"int_riscv_vfmacc", "PseudoVFMACC_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatTernaryV_VV_VX_AAXA_RM<"int_riscv_vfnmacc", "PseudoVFNMACC_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatTernaryV_VV_VX_AAXA_RM<"int_riscv_vfmsac", "PseudoVFMSAC_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatTernaryV_VV_VX_AAXA_RM<"int_riscv_vfnmsac", "PseudoVFNMSAC_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatTernaryV_VV_VX_AAXA_RM<"int_riscv_vfmadd", "PseudoVFMADD_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatTernaryV_VV_VX_AAXA_RM<"int_riscv_vfnmadd", "PseudoVFNMADD_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatTernaryV_VV_VX_AAXA_RM<"int_riscv_vfmsub", "PseudoVFMSUB_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatTernaryV_VV_VX_AAXA_RM<"int_riscv_vfnmsub", "PseudoVFNMSUB_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatTernaryW_VV_VX_RM<"int_riscv_vfwmacc", "PseudoVFWMACC_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatTernaryW_VV_VX_RM<"int_riscv_vfwnmacc", "PseudoVFWNMACC_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatTernaryW_VV_VX_RM<"int_riscv_vfwmsac", "PseudoVFWMSAC_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatTernaryW_VV_VX_RM<"int_riscv_vfwnmsac", "PseudoVFWNMSAC_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatUnaryV_V<"int_riscv_vfrsqrt7", "PseudoVFRSQRT7_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatUnaryV_V_RM<"int_riscv_vfrec7", "PseudoVFREC7_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatBinaryV_VV_VX<"int_riscv_vfmin", "PseudoVFMIN_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatBinaryV_VV_VX<"int_riscv_vfmax", "PseudoVFMAX_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatBinaryV_VV_VX<"int_riscv_vfsgnj", "PseudoVFSGNJ_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatBinaryV_VV_VX<"int_riscv_vfsgnjn", "PseudoVFSGNJN_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatBinaryV_VV_VX<"int_riscv_vfsgnjx", "PseudoVFSGNJX_ALT", + AllBF16Vectors, isSEWAware=1>; +defm : VPatBinaryM_VV_VX<"int_riscv_vmfeq", "PseudoVMFEQ_ALT", AllBF16Vectors>; +defm : VPatBinaryM_VV_VX<"int_riscv_vmfle", "PseudoVMFLE_ALT", AllBF16Vectors>; +defm : VPatBinaryM_VV_VX<"int_riscv_vmflt", "PseudoVMFLT_ALT", AllBF16Vectors>; +defm : VPatBinaryM_VV_VX<"int_riscv_vmfne", "PseudoVMFNE_ALT", AllBF16Vectors>; +defm : VPatBinaryM_VX<"int_riscv_vmfgt", "PseudoVMFGT_ALT", AllBF16Vectors>; +defm : VPatBinaryM_VX<"int_riscv_vmfge", "PseudoVMFGE_ALT", AllBF16Vectors>; +defm : VPatBinarySwappedM_VV<"int_riscv_vmfgt", "PseudoVMFLT_ALT", AllBF16Vectors>; +defm : VPatBinarySwappedM_VV<"int_riscv_vmfge", "PseudoVMFLE_ALT", AllBF16Vectors>; +defm : VPatConversionVI_VF_BF16<"int_riscv_vfclass", "PseudoVFCLASS_ALT">; +foreach vti = AllBF16Vectors in { + let Predicates = GetVTypePredicates<vti>.Predicates in + defm : VPatBinaryCarryInTAIL<"int_riscv_vfmerge", "PseudoVFMERGE_ALT", + "V"#vti.ScalarSuffix#"M", + vti.Vector, + vti.Vector, vti.Scalar, vti.Mask, + vti.Log2SEW, vti.LMul, vti.RegClass, + vti.RegClass, vti.ScalarRegClass>; +} +defm : VPatConversionWF_VI_BF16<"int_riscv_vfwcvt_f_xu_v", "PseudoVFWCVT_F_XU_ALT", + isSEWAware=1>; +defm : VPatConversionWF_VI_BF16<"int_riscv_vfwcvt_f_x_v", "PseudoVFWCVT_F_X_ALT", + 
isSEWAware=1>; +defm : VPatConversionWF_VF_BF16<"int_riscv_vfwcvt_f_f_v", "PseudoVFWCVT_F_F_ALT", + isSEWAware=1>; +defm : VPatConversionVI_WF_RM_BF16<"int_riscv_vfncvt_xu_f_w", "PseudoVFNCVT_XU_F_ALT">; +defm : VPatConversionVI_WF_RM_BF16<"int_riscv_vfncvt_x_f_w", "PseudoVFNCVT_X_F_ALT">; +defm : VPatConversionVI_WF_BF16<"int_riscv_vfncvt_rtz_xu_f_w", "PseudoVFNCVT_RTZ_XU_F_ALT">; +defm : VPatConversionVI_WF_BF16<"int_riscv_vfncvt_rtz_x_f_w", "PseudoVFNCVT_RTZ_X_F_ALT">; +defm : VPatConversionVF_WF_RM<"int_riscv_vfncvt_f_f_w", "PseudoVFNCVT_F_F_ALT", + AllWidenableBF16ToFloatVectors, isSEWAware=1>; +defm : VPatConversionVF_WF_BF16<"int_riscv_vfncvt_rod_f_f_w", "PseudoVFNCVT_ROD_F_F_ALT", + isSEWAware=1>; +defm : VPatBinaryV_VX<"int_riscv_vfslide1up", "PseudoVFSLIDE1UP_ALT", AllBF16Vectors>; +defm : VPatBinaryV_VX<"int_riscv_vfslide1down", "PseudoVFSLIDE1DOWN_ALT", AllBF16Vectors>; + +foreach fvti = AllBF16Vectors in { + defvar ivti = GetIntVTypeInfo<fvti>.Vti; + let Predicates = GetVTypePredicates<ivti>.Predicates in { + // 13.16. Vector Floating-Point Move Instruction + // If we're splatting fpimm0, use vmv.v.x vd, x0. + def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl + fvti.Vector:$passthru, (fvti.Scalar (fpimm0)), VLOpFrag)), + (!cast<Instruction>("PseudoVMV_V_I_"#fvti.LMul.MX) + $passthru, 0, GPR:$vl, fvti.Log2SEW, TU_MU)>; + def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl + fvti.Vector:$passthru, (fvti.Scalar (SelectScalarFPAsInt (XLenVT GPR:$imm))), VLOpFrag)), + (!cast<Instruction>("PseudoVMV_V_X_"#fvti.LMul.MX) + $passthru, GPR:$imm, GPR:$vl, fvti.Log2SEW, TU_MU)>; + } + + let Predicates = GetVTypePredicates<fvti>.Predicates in { + def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl + fvti.Vector:$passthru, (fvti.Scalar fvti.ScalarRegClass:$rs2), VLOpFrag)), + (!cast<Instruction>("PseudoVFMV_V_ALT_" # fvti.ScalarSuffix # "_" # + fvti.LMul.MX) + $passthru, (fvti.Scalar fvti.ScalarRegClass:$rs2), + GPR:$vl, fvti.Log2SEW, TU_MU)>; + } +} + +foreach vti = NoGroupBF16Vectors in { + let Predicates = GetVTypePredicates<vti>.Predicates in { + def : Pat<(vti.Vector (riscv_vfmv_s_f_vl (vti.Vector vti.RegClass:$passthru), + (vti.Scalar (fpimm0)), + VLOpFrag)), + (PseudoVMV_S_X $passthru, (XLenVT X0), GPR:$vl, vti.Log2SEW)>; + def : Pat<(vti.Vector (riscv_vfmv_s_f_vl (vti.Vector vti.RegClass:$passthru), + (vti.Scalar (SelectScalarFPAsInt (XLenVT GPR:$imm))), + VLOpFrag)), + (PseudoVMV_S_X $passthru, GPR:$imm, GPR:$vl, vti.Log2SEW)>; + def : Pat<(vti.Vector (riscv_vfmv_s_f_vl (vti.Vector vti.RegClass:$passthru), + vti.ScalarRegClass:$rs1, + VLOpFrag)), + (!cast<Instruction>("PseudoVFMV_S_"#vti.ScalarSuffix#"_ALT") + vti.RegClass:$passthru, + (vti.Scalar vti.ScalarRegClass:$rs1), GPR:$vl, vti.Log2SEW)>; + } + + defvar vfmv_f_s_inst = !cast<Instruction>(!strconcat("PseudoVFMV_", + vti.ScalarSuffix, + "_S_ALT")); + // Only pattern-match extract-element operations where the index is 0. Any + // other index will have been custom-lowered to slide the vector correctly + // into place. 
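In other words, a non-zero extract index never reaches ISel here: custom lowering first slides the source vector down by the index so the requested lane lands at position 0, and only the position-0 read maps to vfmv.f.s. A scalar C++ model of that slide-then-read decomposition (purely illustrative):

  // Reading lane i directly ...
  float lane(const float *v, unsigned i) { return v[i]; }
  // ... is the same as sliding down by i and reading lane 0.
  float lane_via_slide(const float *v, unsigned i) { return (v + i)[0]; }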
+ let Predicates = GetVTypePredicates<vti>.Predicates in + def : Pat<(vti.Scalar (extractelt (vti.Vector vti.RegClass:$rs2), 0)), + (vfmv_f_s_inst vti.RegClass:$rs2, vti.Log2SEW)>; +} +} // Predicates = [HasStdExtZvfbfa] diff --git a/llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp b/llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp index 5e10631..528bbdf 100644 --- a/llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp +++ b/llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp @@ -169,9 +169,9 @@ static bool getMemOperands(unsigned Factor, VectorType *VTy, Type *XLenTy, } case Intrinsic::masked_load: { Ptr = II->getOperand(0); - Alignment = cast<ConstantInt>(II->getArgOperand(1))->getAlignValue(); + Alignment = II->getParamAlign(0).valueOrOne(); - if (!isa<UndefValue>(II->getOperand(3))) + if (!isa<UndefValue>(II->getOperand(2))) return false; assert(Mask && "masked.load needs a mask!"); @@ -183,7 +183,7 @@ static bool getMemOperands(unsigned Factor, VectorType *VTy, Type *XLenTy, } case Intrinsic::masked_store: { Ptr = II->getOperand(1); - Alignment = cast<ConstantInt>(II->getArgOperand(2))->getAlignValue(); + Alignment = II->getParamAlign(1).valueOrOne(); assert(Mask && "masked.store needs a mask!"); diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.h b/llvm/lib/Target/RISCV/RISCVSubtarget.h index 6acf799..334db4b 100644 --- a/llvm/lib/Target/RISCV/RISCVSubtarget.h +++ b/llvm/lib/Target/RISCV/RISCVSubtarget.h @@ -288,9 +288,12 @@ public: bool hasVInstructionsI64() const { return HasStdExtZve64x; } bool hasVInstructionsF16Minimal() const { return HasStdExtZvfhmin; } bool hasVInstructionsF16() const { return HasStdExtZvfh; } - bool hasVInstructionsBF16Minimal() const { return HasStdExtZvfbfmin; } + bool hasVInstructionsBF16Minimal() const { + return HasStdExtZvfbfmin || HasStdExtZvfbfa; + } bool hasVInstructionsF32() const { return HasStdExtZve32f; } bool hasVInstructionsF64() const { return HasStdExtZve64d; } + bool hasVInstructionsBF16() const { return HasStdExtZvfbfa; } // F16 and F64 both require F32. bool hasVInstructionsAnyF() const { return hasVInstructionsF32(); } bool hasVInstructionsFullMultiply() const { return HasStdExtV; } diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index dbe8e18..d91923b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -507,7 +507,9 @@ static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister, static Register buildBuiltinVariableLoad( MachineIRBuilder &MIRBuilder, SPIRVType *VariableType, SPIRVGlobalRegistry *GR, SPIRV::BuiltIn::BuiltIn BuiltinValue, LLT LLType, - Register Reg = Register(0), bool isConst = true, bool hasLinkageTy = true) { + Register Reg = Register(0), bool isConst = true, + const std::optional<SPIRV::LinkageType::LinkageType> &LinkageTy = { + SPIRV::LinkageType::Import}) { Register NewRegister = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::pIDRegClass); MIRBuilder.getMRI()->setType( @@ -521,9 +523,8 @@ static Register buildBuiltinVariableLoad( // Set up the global OpVariable with the necessary builtin decorations. Register Variable = GR->buildGlobalVariable( NewRegister, PtrType, getLinkStringForBuiltIn(BuiltinValue), nullptr, - SPIRV::StorageClass::Input, nullptr, /* isConst= */ isConst, - /* HasLinkageTy */ hasLinkageTy, SPIRV::LinkageType::Import, MIRBuilder, - false); + SPIRV::StorageClass::Input, nullptr, /* isConst= */ isConst, LinkageTy, + MIRBuilder, false); // Load the value from the global variable. 
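Across the SPIRV changes in this patch, the old (bool HasLinkageTy, LinkageType) parameter pair collapses into one std::optional, with std::nullopt meaning "emit no LinkageAttributes decoration". A sketch of the resulting call-site logic, assuming the getSpirvLinkageTypeFor helper introduced later in this patch, plus ST, GV, Reg, MIRBuilder and Name already in scope:

  std::optional<SPIRV::LinkageType::LinkageType> LnkTy =
      getSpirvLinkageTypeFor(ST, GV); // nullopt for local or hidden symbols
  if (LnkTy) // decorate only externally visible symbols
    buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
                    {static_cast<uint32_t>(*LnkTy)}, Name);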
Register LoadedRegister = @@ -1851,7 +1852,7 @@ static bool generateWaveInst(const SPIRV::IncomingCall *Call, return buildBuiltinVariableLoad( MIRBuilder, Call->ReturnType, GR, Value, LLType, Call->ReturnRegister, - /* isConst= */ false, /* hasLinkageTy= */ false); + /* isConst= */ false, /* LinkageType= */ std::nullopt); } // We expect a builtin diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index 1a7c02c..9e11c3a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -479,19 +479,9 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, .addImm(static_cast<uint32_t>(getExecutionModel(*ST, F))) .addUse(FuncVReg); addStringImm(F.getName(), MIB); - } else if (F.getLinkage() != GlobalValue::InternalLinkage && - F.getLinkage() != GlobalValue::PrivateLinkage && - F.getVisibility() != GlobalValue::HiddenVisibility) { - SPIRV::LinkageType::LinkageType LnkTy = - F.isDeclaration() - ? SPIRV::LinkageType::Import - : (F.getLinkage() == GlobalValue::LinkOnceODRLinkage && - ST->canUseExtension( - SPIRV::Extension::SPV_KHR_linkonce_odr) - ? SPIRV::LinkageType::LinkOnceODR - : SPIRV::LinkageType::Export); + } else if (const auto LnkTy = getSpirvLinkageTypeFor(*ST, F)) { buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, - {static_cast<uint32_t>(LnkTy)}, F.getName()); + {static_cast<uint32_t>(*LnkTy)}, F.getName()); } // Handle function pointers decoration diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 6fd1c7e..6181abb 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -712,9 +712,9 @@ SPIRVGlobalRegistry::buildConstantSampler(Register ResReg, unsigned AddrMode, Register SPIRVGlobalRegistry::buildGlobalVariable( Register ResVReg, SPIRVType *BaseType, StringRef Name, const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage, - const MachineInstr *Init, bool IsConst, bool HasLinkageTy, - SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder, - bool IsInstSelector) { + const MachineInstr *Init, bool IsConst, + const std::optional<SPIRV::LinkageType::LinkageType> &LinkageType, + MachineIRBuilder &MIRBuilder, bool IsInstSelector) { const GlobalVariable *GVar = nullptr; if (GV) { GVar = cast<const GlobalVariable>(GV); @@ -792,9 +792,9 @@ Register SPIRVGlobalRegistry::buildGlobalVariable( buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment}); } - if (HasLinkageTy) + if (LinkageType) buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, - {static_cast<uint32_t>(LinkageType)}, Name); + {static_cast<uint32_t>(*LinkageType)}, Name); SPIRV::BuiltIn::BuiltIn BuiltInId; if (getSpirvBuiltInIdByName(Name, BuiltInId)) @@ -821,8 +821,8 @@ Register SPIRVGlobalRegistry::getOrCreateGlobalVariableWithBinding( MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass); buildGlobalVariable(VarReg, VarType, Name, nullptr, - getPointerStorageClass(VarType), nullptr, false, false, - SPIRV::LinkageType::Import, MIRBuilder, false); + getPointerStorageClass(VarType), nullptr, false, + std::nullopt, MIRBuilder, false); buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::DescriptorSet, {Set}); buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::Binding, {Binding}); diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index 
a648def..c230e62 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -548,14 +548,12 @@ public: MachineIRBuilder &MIRBuilder); Register getOrCreateUndef(MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII); - Register buildGlobalVariable(Register Reg, SPIRVType *BaseType, - StringRef Name, const GlobalValue *GV, - SPIRV::StorageClass::StorageClass Storage, - const MachineInstr *Init, bool IsConst, - bool HasLinkageTy, - SPIRV::LinkageType::LinkageType LinkageType, - MachineIRBuilder &MIRBuilder, - bool IsInstSelector); + Register buildGlobalVariable( + Register Reg, SPIRVType *BaseType, StringRef Name, const GlobalValue *GV, + SPIRV::StorageClass::StorageClass Storage, const MachineInstr *Init, + bool IsConst, + const std::optional<SPIRV::LinkageType::LinkageType> &LinkageType, + MachineIRBuilder &MIRBuilder, bool IsInstSelector); Register getOrCreateGlobalVariableWithBinding(const SPIRVType *VarType, uint32_t Set, uint32_t Binding, StringRef Name, diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index a0cff4d..5591d9f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -4350,15 +4350,8 @@ bool SPIRVInstructionSelector::selectGlobalValue( if (hasInitializer(GlobalVar) && !Init) return true; - bool HasLnkTy = !GV->hasInternalLinkage() && !GV->hasPrivateLinkage() && - !GV->hasHiddenVisibility(); - SPIRV::LinkageType::LinkageType LnkType = - GV->isDeclarationForLinker() - ? SPIRV::LinkageType::Import - : (GV->hasLinkOnceODRLinkage() && - STI.canUseExtension(SPIRV::Extension::SPV_KHR_linkonce_odr) - ? SPIRV::LinkageType::LinkOnceODR - : SPIRV::LinkageType::Export); + const std::optional<SPIRV::LinkageType::LinkageType> LnkType = + getSpirvLinkageTypeFor(STI, *GV); const unsigned AddrSpace = GV->getAddressSpace(); SPIRV::StorageClass::StorageClass StorageClass = @@ -4366,7 +4359,7 @@ bool SPIRVInstructionSelector::selectGlobalValue( SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(GVType, I, StorageClass); Register Reg = GR.buildGlobalVariable( ResVReg, ResType, GlobalIdent, GV, StorageClass, Init, - GlobalVar->isConstant(), HasLnkTy, LnkType, MIRBuilder, true); + GlobalVar->isConstant(), LnkType, MIRBuilder, true); return Reg.isValid(); } @@ -4517,8 +4510,8 @@ bool SPIRVInstructionSelector::loadVec3BuiltinInputID( // builtin variable. Register Variable = GR.buildGlobalVariable( NewRegister, PtrType, getLinkStringForBuiltIn(BuiltInValue), nullptr, - SPIRV::StorageClass::Input, nullptr, true, false, - SPIRV::LinkageType::Import, MIRBuilder, false); + SPIRV::StorageClass::Input, nullptr, true, std::nullopt, MIRBuilder, + false); // Create new register for loading value. MachineRegisterInfo *MRI = MIRBuilder.getMRI(); @@ -4570,8 +4563,8 @@ bool SPIRVInstructionSelector::loadBuiltinInputID( // builtin variable. Register Variable = GR.buildGlobalVariable( NewRegister, PtrType, getLinkStringForBuiltIn(BuiltInValue), nullptr, - SPIRV::StorageClass::Input, nullptr, true, false, - SPIRV::LinkageType::Import, MIRBuilder, false); + SPIRV::StorageClass::Input, nullptr, true, std::nullopt, MIRBuilder, + false); // Load uint value from the global variable. 
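The getFuncReg cleanup a little further below leans on DenseMap::lookup, which returns a value-initialized mapped_type when the key is absent; for MCRegister that default state is the invalid register, so the explicit find/compare was redundant. A sketch of the equivalence, assuming a DenseMap<const Function *, MCRegister> FuncMap:

  auto It = FuncMap.find(F);
  MCRegister R1 = It == FuncMap.end() ? MCRegister() : It->second;
  MCRegister R2 = FuncMap.lookup(F); // value-initialized when F is absent
  // R1 == R2 whether or not F is in the map.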
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad)) diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 61a0bbe..f7cdfcb 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -547,9 +547,9 @@ void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI, if (MI.getOpcode() == SPIRV::OpDecorate) { // If it's got Import linkage. auto Dec = MI.getOperand(1).getImm(); - if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) { + if (Dec == SPIRV::Decoration::LinkageAttributes) { auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm(); - if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) { + if (Lnk == SPIRV::LinkageType::Import) { // Map imported function name to function ID register. const Function *ImportedFunc = F->getParent()->getFunction(getStringImm(MI, 2)); @@ -635,7 +635,7 @@ static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI, void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) { InstrTraces IS; for (auto F = M.begin(), E = M.end(); F != E; ++F) { - if ((*F).isDeclaration()) + if (F->isDeclaration()) continue; MachineFunction *MF = MMI->getMachineFunction(*F); assert(MF); diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h index d8376cd..2d19f6de 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h @@ -169,9 +169,7 @@ struct ModuleAnalysisInfo { MCRegister getFuncReg(const Function *F) { assert(F && "Function is null"); - auto FuncPtrRegPair = FuncMap.find(F); - return FuncPtrRegPair == FuncMap.end() ? MCRegister() - : FuncPtrRegPair->second; + return FuncMap.lookup(F); } MCRegister getExtInstSetReg(unsigned SetNum) { return ExtInstSetMap[SetNum]; } InstrList &getMSInstrs(unsigned MSType) { return MS[MSType]; } diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index 1d47c89..4e2cc88 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -1040,4 +1040,19 @@ getFirstValidInstructionInsertPoint(MachineBasicBlock &BB) { : VarPos; } +std::optional<SPIRV::LinkageType::LinkageType> +getSpirvLinkageTypeFor(const SPIRVSubtarget &ST, const GlobalValue &GV) { + if (GV.hasLocalLinkage() || GV.hasHiddenVisibility()) + return std::nullopt; + + if (GV.isDeclarationForLinker()) + return SPIRV::LinkageType::Import; + + if (GV.hasLinkOnceODRLinkage() && + ST.canUseExtension(SPIRV::Extension::SPV_KHR_linkonce_odr)) + return SPIRV::LinkageType::LinkOnceODR; + + return SPIRV::LinkageType::Export; +} + } // namespace llvm diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index 5777a24..99d9d40 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -559,5 +559,8 @@ unsigned getArrayComponentCount(const MachineRegisterInfo *MRI, const MachineInstr *ResType); MachineBasicBlock::iterator getFirstValidInstructionInsertPoint(MachineBasicBlock &BB); + +std::optional<SPIRV::LinkageType::LinkageType> +getSpirvLinkageTypeFor(const SPIRVSubtarget &ST, const GlobalValue &GV); } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td index ed54404d..7840620 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td +++ 
b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td @@ -1583,11 +1583,9 @@ def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$lhs), // MLA: v16i8 -> v4i32 def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v16i8 V128:$lhs), (v16i8 V128:$rhs))), - (ADD_I32x4 (ADD_I32x4 (DOT (extend_low_s_I16x8 $lhs), - (extend_low_s_I16x8 $rhs)), - (DOT (extend_high_s_I16x8 $lhs), - (extend_high_s_I16x8 $rhs))), - $acc)>; + (ADD_I32x4 (ADD_I32x4 (extadd_pairwise_s_I32x4 (EXTMUL_LOW_S_I16x8 $lhs, $rhs)), + (extadd_pairwise_s_I32x4 (EXTMUL_HIGH_S_I16x8 $lhs, $rhs))), + $acc)>; def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v16i8 V128:$lhs), (v16i8 V128:$rhs))), (ADD_I32x4 (ADD_I32x4 (extadd_pairwise_u_I32x4 (EXTMUL_LOW_U_I16x8 $lhs, $rhs)), diff --git a/llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp b/llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp index 100f1ec..53ec712 100644 --- a/llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp +++ b/llvm/lib/Target/X86/GISel/X86InstructionSelector.cpp @@ -1879,28 +1879,34 @@ bool X86InstructionSelector::selectSelect(MachineInstr &I, unsigned OpCmp; LLT Ty = MRI.getType(DstReg); - switch (Ty.getSizeInBits()) { - default: - return false; - case 8: - OpCmp = X86::CMOV_GR8; - break; - case 16: - OpCmp = STI.canUseCMOV() ? X86::CMOV16rr : X86::CMOV_GR16; - break; - case 32: - OpCmp = STI.canUseCMOV() ? X86::CMOV32rr : X86::CMOV_GR32; - break; - case 64: - assert(STI.is64Bit() && STI.canUseCMOV()); - OpCmp = X86::CMOV64rr; - break; + if (Ty.getSizeInBits() == 80) { + BuildMI(*Sel.getParent(), Sel, Sel.getDebugLoc(), TII.get(X86::CMOVE_Fp80), + DstReg) + .addReg(Sel.getTrueReg()) + .addReg(Sel.getFalseReg()); + } else { + switch (Ty.getSizeInBits()) { + default: + return false; + case 8: + OpCmp = X86::CMOV_GR8; + break; + case 16: + OpCmp = STI.canUseCMOV() ? X86::CMOV16rr : X86::CMOV_GR16; + break; + case 32: + OpCmp = STI.canUseCMOV() ? X86::CMOV32rr : X86::CMOV_GR32; + break; + case 64: + assert(STI.is64Bit() && STI.canUseCMOV()); + OpCmp = X86::CMOV64rr; + break; + } + BuildMI(*Sel.getParent(), Sel, Sel.getDebugLoc(), TII.get(OpCmp), DstReg) + .addReg(Sel.getTrueReg()) + .addReg(Sel.getFalseReg()) + .addImm(X86::COND_E); } - BuildMI(*Sel.getParent(), Sel, Sel.getDebugLoc(), TII.get(OpCmp), DstReg) - .addReg(Sel.getTrueReg()) - .addReg(Sel.getFalseReg()) - .addImm(X86::COND_E); - const TargetRegisterClass *DstRC = getRegClass(Ty, DstReg, MRI); if (!RBI.constrainGenericRegister(DstReg, *DstRC, MRI)) { LLVM_DEBUG(dbgs() << "Failed to constrain CMOV\n"); diff --git a/llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp b/llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp index 28fa2cd..e792b1b 100644 --- a/llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp +++ b/llvm/lib/Target/X86/GISel/X86LegalizerInfo.cpp @@ -575,10 +575,13 @@ X86LegalizerInfo::X86LegalizerInfo(const X86Subtarget &STI, // todo: vectors and address spaces getActionDefinitionsBuilder(G_SELECT) - .legalFor({{s8, s32}, {s16, s32}, {s32, s32}, {s64, s32}, {p0, s32}}) + .legalFor({{s16, s32}, {s32, s32}, {p0, s32}}) + .legalFor(!HasCMOV, {{s8, s32}}) + .legalFor(Is64Bit, {{s64, s32}}) + .legalFor(UseX87, {{s80, s32}}) + .clampScalar(1, s32, s32) .widenScalarToNextPow2(0, /*Min=*/8) - .clampScalar(0, HasCMOV ? s16 : s8, sMaxScalar) - .clampScalar(1, s32, s32); + .clampScalar(0, HasCMOV ? 
s16 : s8, sMaxScalar); // memory intrinsics getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE, G_MEMSET}).libcall(); diff --git a/llvm/lib/Target/X86/X86.td b/llvm/lib/Target/X86/X86.td index 8e08d16..a1fd366 100644 --- a/llvm/lib/Target/X86/X86.td +++ b/llvm/lib/Target/X86/X86.td @@ -1164,7 +1164,6 @@ def ProcessorFeatures { FeatureAVXNECONVERT, FeatureAVXVNNIINT8, FeatureAVXVNNIINT16, - FeatureUSERMSR, FeatureSHA512, FeatureSM3, FeatureEGPR, diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 2feee05..b5f8ee5 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -41846,7 +41846,7 @@ static SDValue combineCommutableSHUFP(SDValue N, MVT VT, const SDLoc &DL, if (!X86::mayFoldLoad(peekThroughOneUseBitcasts(N0), Subtarget) || X86::mayFoldLoad(peekThroughOneUseBitcasts(N1), Subtarget)) return SDValue(); - Imm = ((Imm & 0x0F) << 4) | ((Imm & 0xF0) >> 4); + Imm = llvm::rotl<uint8_t>(Imm, 4); return DAG.getNode(X86ISD::SHUFP, DL, VT, N1, N0, DAG.getTargetConstant(Imm, DL, MVT::i8)); }; @@ -44813,10 +44813,16 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode( } case X86ISD::PCMPGT: // icmp sgt(0, R) == ashr(R, BitWidth-1). - // iff we only need the sign bit then we can use R directly. - if (OriginalDemandedBits.isSignMask() && - ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode())) - return TLO.CombineTo(Op, Op.getOperand(1)); + if (ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode())) { + // iff we only need the signbit then we can use R directly. + if (OriginalDemandedBits.isSignMask()) + return TLO.CombineTo(Op, Op.getOperand(1)); + // otherwise we just need R's signbit for the comparison. + APInt SignMask = APInt::getSignMask(BitWidth); + if (SimplifyDemandedBits(Op.getOperand(1), SignMask, OriginalDemandedElts, + Known, TLO, Depth + 1)) + return true; + } break; case X86ISD::MOVMSK: { SDValue Src = Op.getOperand(0); diff --git a/llvm/lib/TargetParser/ARMTargetParser.cpp b/llvm/lib/TargetParser/ARMTargetParser.cpp index 08944e6..7882045 100644 --- a/llvm/lib/TargetParser/ARMTargetParser.cpp +++ b/llvm/lib/TargetParser/ARMTargetParser.cpp @@ -235,16 +235,16 @@ ARM::NeonSupportLevel ARM::getFPUNeonSupportLevel(ARM::FPUKind FPUKind) { StringRef ARM::getFPUSynonym(StringRef FPU) { return StringSwitch<StringRef>(FPU) - .Cases("fpa", "fpe2", "fpe3", "maverick", "invalid") // Unsupported + .Cases({"fpa", "fpe2", "fpe3", "maverick"}, "invalid") // Unsupported .Case("vfp2", "vfpv2") .Case("vfp3", "vfpv3") .Case("vfp4", "vfpv4") .Case("vfp3-d16", "vfpv3-d16") .Case("vfp4-d16", "vfpv4-d16") - .Cases("fp4-sp-d16", "vfpv4-sp-d16", "fpv4-sp-d16") - .Cases("fp4-dp-d16", "fpv4-dp-d16", "vfpv4-d16") + .Cases({"fp4-sp-d16", "vfpv4-sp-d16"}, "fpv4-sp-d16") + .Cases({"fp4-dp-d16", "fpv4-dp-d16"}, "vfpv4-d16") .Case("fp5-sp-d16", "fpv5-sp-d16") - .Cases("fp5-dp-d16", "fpv5-dp-d16", "fpv5-d16") + .Cases({"fp5-dp-d16", "fpv5-dp-d16"}, "fpv5-d16") // FIXME: Clang uses it, but it's bogus, since neon defaults to vfpv3. 
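These .Cases rewrites, like the RISCV MCA one earlier in this patch, track a StringSwitch API change: the multi-key overload now takes its keys as a braced initializer list with the mapped value last, instead of a flat variadic argument list (the motivation is presumably to make the key/value boundary explicit; that reading is inferred from the usage, not stated in the patch). A minimal sketch:

  #include "llvm/ADT/StringSwitch.h"
  using namespace llvm;
  StringRef fpuSynonym(StringRef FPU) {
    return StringSwitch<StringRef>(FPU)
        .Cases({"fpa", "fpe2", "fpe3", "maverick"}, "invalid") // keys, value
        .Case("vfp2", "vfpv2")
        .Default(FPU);
  }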
.Case("neon-vfpv3", "neon") .Default(FPU); diff --git a/llvm/lib/TargetParser/X86TargetParser.cpp b/llvm/lib/TargetParser/X86TargetParser.cpp index dd13ce3..b13c795 100644 --- a/llvm/lib/TargetParser/X86TargetParser.cpp +++ b/llvm/lib/TargetParser/X86TargetParser.cpp @@ -143,8 +143,7 @@ constexpr FeatureBitset FeaturesDiamondRapids = FeatureAVXVNNIINT8 | FeatureAVXVNNIINT16 | FeatureSHA512 | FeatureSM3 | FeatureSM4 | FeatureEGPR | FeatureZU | FeatureCCMP | FeaturePush2Pop2 | FeaturePPX | FeatureNDD | FeatureNF | FeatureMOVRS | FeatureAMX_MOVRS | - FeatureAMX_AVX512 | FeatureAMX_FP8 | FeatureAMX_TF32 | - FeatureAMX_TRANSPOSE | FeatureUSERMSR; + FeatureAMX_AVX512 | FeatureAMX_FP8 | FeatureAMX_TF32 | FeatureAMX_TRANSPOSE; // Intel Atom processors. // Bonnell has feature parity with Core2 and adds MOVBE. diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index e1e24a9..dab200d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -289,12 +289,11 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) { // * Narrow width by halfs excluding zero/undef lanes Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) { Value *LoadPtr = II.getArgOperand(0); - const Align Alignment = - cast<ConstantInt>(II.getArgOperand(1))->getAlignValue(); + const Align Alignment = II.getParamAlign(0).valueOrOne(); // If the mask is all ones or undefs, this is a plain vector load of the 1st // argument. - if (maskIsAllOneOrUndef(II.getArgOperand(2))) { + if (maskIsAllOneOrUndef(II.getArgOperand(1))) { LoadInst *L = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment, "unmaskedload"); L->copyMetadata(II); @@ -308,7 +307,7 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) { LoadInst *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment, "unmaskedload"); LI->copyMetadata(II); - return Builder.CreateSelect(II.getArgOperand(2), LI, II.getArgOperand(3)); + return Builder.CreateSelect(II.getArgOperand(1), LI, II.getArgOperand(2)); } return nullptr; @@ -319,8 +318,8 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) { // * Narrow width by halfs excluding zero/undef lanes Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) { Value *StorePtr = II.getArgOperand(1); - Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue(); - auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); + Align Alignment = II.getParamAlign(1).valueOrOne(); + auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2)); if (!ConstMask) return nullptr; @@ -356,7 +355,7 @@ Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) { // * Narrow width by halfs excluding zero/undef lanes // * Vector incrementing address -> vector masked load Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) { - auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2)); + auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(1)); if (!ConstMask) return nullptr; @@ -366,8 +365,7 @@ Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) { if (ConstMask->isAllOnesValue()) if (auto *SplatPtr = getSplatValue(II.getArgOperand(0))) { auto *VecTy = cast<VectorType>(II.getType()); - const Align Alignment = - cast<ConstantInt>(II.getArgOperand(1))->getAlignValue(); + const Align Alignment = II.getParamAlign(0).valueOrOne(); LoadInst *L = Builder.CreateAlignedLoad(VecTy->getElementType(), 
SplatPtr, Alignment, "load.scalar"); Value *Shuf = @@ -384,7 +382,7 @@ Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) { // * Narrow store width by halfs excluding zero/undef lanes // * Vector incrementing address -> vector masked store Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) { - auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); + auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2)); if (!ConstMask) return nullptr; @@ -397,8 +395,7 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) { // scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) { if (maskContainsAllOneOrUndef(ConstMask)) { - Align Alignment = - cast<ConstantInt>(II.getArgOperand(2))->getAlignValue(); + Align Alignment = II.getParamAlign(1).valueOrOne(); StoreInst *S = new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment); S->copyMetadata(II); @@ -408,7 +405,7 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) { // scatter(vector, splat(ptr), splat(true)) -> store extract(vector, // lastlane), ptr if (ConstMask->isAllOnesValue()) { - Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue(); + Align Alignment = II.getParamAlign(1).valueOrOne(); VectorType *WideLoadTy = cast<VectorType>(II.getArgOperand(1)->getType()); ElementCount VF = WideLoadTy->getElementCount(); Value *RunTimeVF = Builder.CreateElementCount(Builder.getInt32Ty(), VF); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 09cb225..975498f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -3757,6 +3757,10 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder, // (x < y) ? -1 : zext(x > y) // (x > y) ? 1 : sext(x != y) // (x > y) ? 1 : sext(x < y) +// (x == y) ? 0 : (x > y ? 1 : -1) +// (x == y) ? 0 : (x < y ? -1 : 1) +// Special case: x == C ? 0 : (x > C - 1 ? 1 : -1) +// Special case: x == C ? 0 : (x < C + 1 ? -1 : 1) // Into ucmp/scmp(x, y), where signedness is determined by the signedness // of the comparison in the original sequence. Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) { @@ -3849,6 +3853,44 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) { } } + // Special cases with constants: x == C ? 0 : (x > C-1 ? 1 : -1) + if (Pred == ICmpInst::ICMP_EQ && match(TV, m_Zero())) { + const APInt *C; + if (match(RHS, m_APInt(C))) { + CmpPredicate InnerPred; + Value *InnerRHS; + const APInt *InnerTV, *InnerFV; + if (match(FV, + m_Select(m_ICmp(InnerPred, m_Specific(LHS), m_Value(InnerRHS)), + m_APInt(InnerTV), m_APInt(InnerFV)))) { + + // x == C ? 0 : (x > C-1 ? 1 : -1) + if (ICmpInst::isGT(InnerPred) && InnerTV->isOne() && + InnerFV->isAllOnes()) { + IsSigned = ICmpInst::isSigned(InnerPred); + bool CanSubOne = IsSigned ? !C->isMinSignedValue() : !C->isMinValue(); + if (CanSubOne) { + APInt Cminus1 = *C - 1; + if (match(InnerRHS, m_SpecificInt(Cminus1))) + Replace = true; + } + } + + // x == C ? 0 : (x < C+1 ? -1 : 1) + if (ICmpInst::isLT(InnerPred) && InnerTV->isAllOnes() && + InnerFV->isOne()) { + IsSigned = ICmpInst::isSigned(InnerPred); + bool CanAddOne = IsSigned ? 
!C->isMaxSignedValue() : !C->isMaxValue(); + if (CanAddOne) { + APInt Cplus1 = *C + 1; + if (match(InnerRHS, m_SpecificInt(Cplus1))) + Replace = true; + } + } + } + } + } + Intrinsic::ID IID = IsSigned ? Intrinsic::scmp : Intrinsic::ucmp; if (Replace) return replaceInstUsesWith( @@ -4459,24 +4501,24 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (Value *V = foldSelectIntoAddConstant(SI, Builder)) return replaceInstUsesWith(SI, V); - // select(mask, mload(,,mask,0), 0) -> mload(,,mask,0) + // select(mask, mload(ptr,mask,0), 0) -> mload(ptr,mask,0) // Load inst is intentionally not checked for hasOneUse() if (match(FalseVal, m_Zero()) && - (match(TrueVal, m_MaskedLoad(m_Value(), m_Value(), m_Specific(CondVal), + (match(TrueVal, m_MaskedLoad(m_Value(), m_Specific(CondVal), m_CombineOr(m_Undef(), m_Zero()))) || - match(TrueVal, m_MaskedGather(m_Value(), m_Value(), m_Specific(CondVal), + match(TrueVal, m_MaskedGather(m_Value(), m_Specific(CondVal), m_CombineOr(m_Undef(), m_Zero()))))) { auto *MaskedInst = cast<IntrinsicInst>(TrueVal); - if (isa<UndefValue>(MaskedInst->getArgOperand(3))) - MaskedInst->setArgOperand(3, FalseVal /* Zero */); + if (isa<UndefValue>(MaskedInst->getArgOperand(2))) + MaskedInst->setArgOperand(2, FalseVal /* Zero */); return replaceInstUsesWith(SI, MaskedInst); } Value *Mask; if (match(TrueVal, m_Zero()) && - (match(FalseVal, m_MaskedLoad(m_Value(), m_Value(), m_Value(Mask), + (match(FalseVal, m_MaskedLoad(m_Value(), m_Value(Mask), m_CombineOr(m_Undef(), m_Zero()))) || - match(FalseVal, m_MaskedGather(m_Value(), m_Value(), m_Value(Mask), + match(FalseVal, m_MaskedGather(m_Value(), m_Value(Mask), m_CombineOr(m_Undef(), m_Zero())))) && (CondVal->getType() == Mask->getType())) { // We can remove the select by ensuring the load zeros all lanes the @@ -4489,8 +4531,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (CanMergeSelectIntoLoad) { auto *MaskedInst = cast<IntrinsicInst>(FalseVal); - if (isa<UndefValue>(MaskedInst->getArgOperand(3))) - MaskedInst->setArgOperand(3, TrueVal /* Zero */); + if (isa<UndefValue>(MaskedInst->getArgOperand(2))) + MaskedInst->setArgOperand(2, TrueVal /* Zero */); return replaceInstUsesWith(SI, MaskedInst); } } @@ -4629,14 +4671,13 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } Value *MaskedLoadPtr; - const APInt *MaskedLoadAlignment; if (match(TrueVal, m_OneUse(m_MaskedLoad(m_Value(MaskedLoadPtr), - m_APInt(MaskedLoadAlignment), m_Specific(CondVal), m_Value())))) return replaceInstUsesWith( - SI, Builder.CreateMaskedLoad(TrueVal->getType(), MaskedLoadPtr, - Align(MaskedLoadAlignment->getZExtValue()), - CondVal, FalseVal)); + SI, Builder.CreateMaskedLoad( + TrueVal->getType(), MaskedLoadPtr, + cast<IntrinsicInst>(TrueVal)->getParamAlign(0).valueOrOne(), + CondVal, FalseVal)); return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index a330bb7..651e305 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -1892,7 +1892,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // segfaults which didn't exist in the original program. 
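All of the masked-intrinsic hunks above and below are one mechanical migration: the explicit i32 alignment operand of llvm.masked.load/store/gather/scatter is gone, every operand after it shifts down by one, and the alignment is carried by an `align` parameter attribute on the pointer instead. A minimal read-side sketch using only APIs that appear in this patch (the wrapper function itself is illustrative, not part of the patch):

    #include "llvm/IR/IntrinsicInst.h"
    #include <cassert>
    using namespace llvm;

    // Old: alignment was an explicit operand, mask and passthru came after it.
    //   Align A = cast<ConstantInt>(II->getArgOperand(1))->getAlignValue();
    //   Value *Mask = II->getArgOperand(2), *PassThru = II->getArgOperand(3);
    // New: read the attribute; a missing attribute defaults to align 1.
    static Align maskedLoadAlign(const IntrinsicInst *II) {
      assert(II->getIntrinsicID() == Intrinsic::masked_load);
      return II->getParamAlign(/*ArgNo=*/0).valueOrOne();
    }
    // Mask and passthru are now operands 1 and 2; producers record the same
    // fact via II->addParamAttr(0, Attribute::getWithAlignment(Ctx, Align)).

In textual IR the change is roughly: call @llvm.masked.load(ptr %p, i32 16, <4 x i1> %m, ...) becomes call @llvm.masked.load(ptr align 16 %p, <4 x i1> %m, ...).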
APInt DemandedPtrs(APInt::getAllOnes(VWidth)), DemandedPassThrough(DemandedElts); - if (auto *CMask = dyn_cast<Constant>(II->getOperand(2))) { + if (auto *CMask = dyn_cast<Constant>(II->getOperand(1))) { for (unsigned i = 0; i < VWidth; i++) { if (Constant *CElt = CMask->getAggregateElement(i)) { if (CElt->isNullValue()) @@ -1905,7 +1905,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, if (II->getIntrinsicID() == Intrinsic::masked_gather) simplifyAndSetOp(II, 0, DemandedPtrs, PoisonElts2); - simplifyAndSetOp(II, 3, DemandedPassThrough, PoisonElts3); + simplifyAndSetOp(II, 2, DemandedPassThrough, PoisonElts3); // Output elements are undefined if the element from both sources are. // TODO: can strengthen via mask as well. diff --git a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp index 2646334..cb6ca72 100644 --- a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -1494,11 +1494,8 @@ void AddressSanitizer::getInterestingMemoryOperands( if (ignoreAccess(I, BasePtr)) return; Type *Ty = IsWrite ? CI->getArgOperand(0)->getType() : CI->getType(); - MaybeAlign Alignment = Align(1); - // Otherwise no alignment guarantees. We probably got Undef. - if (auto *Op = dyn_cast<ConstantInt>(CI->getOperand(1 + OpOffset))) - Alignment = Op->getMaybeAlignValue(); - Value *Mask = CI->getOperand(2 + OpOffset); + MaybeAlign Alignment = CI->getParamAlign(0); + Value *Mask = CI->getOperand(1 + OpOffset); Interesting.emplace_back(I, OpOffset, IsWrite, Ty, Alignment, Mask); break; } diff --git a/llvm/lib/Transforms/Instrumentation/MemProfInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/MemProfInstrumentation.cpp index 3ae771a..3c0f185 100644 --- a/llvm/lib/Transforms/Instrumentation/MemProfInstrumentation.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemProfInstrumentation.cpp @@ -338,7 +338,7 @@ MemProfiler::isInterestingMemoryAccess(Instruction *I) const { } auto *BasePtr = CI->getOperand(0 + OpOffset); - Access.MaybeMask = CI->getOperand(2 + OpOffset); + Access.MaybeMask = CI->getOperand(1 + OpOffset); Access.Addr = BasePtr; } } diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index eff6f0c..b6cbecb 100644 --- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -4191,10 +4191,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void handleMaskedGather(IntrinsicInst &I) { IRBuilder<> IRB(&I); Value *Ptrs = I.getArgOperand(0); - const Align Alignment( - cast<ConstantInt>(I.getArgOperand(1))->getZExtValue()); - Value *Mask = I.getArgOperand(2); - Value *PassThru = I.getArgOperand(3); + const Align Alignment = I.getParamAlign(0).valueOrOne(); + Value *Mask = I.getArgOperand(1); + Value *PassThru = I.getArgOperand(2); Type *PtrsShadowTy = getShadowTy(Ptrs); if (ClCheckAccessAddress) { @@ -4230,9 +4229,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRBuilder<> IRB(&I); Value *Values = I.getArgOperand(0); Value *Ptrs = I.getArgOperand(1); - const Align Alignment( - cast<ConstantInt>(I.getArgOperand(2))->getZExtValue()); - Value *Mask = I.getArgOperand(3); + const Align Alignment = I.getParamAlign(1).valueOrOne(); + Value *Mask = I.getArgOperand(2); Type *PtrsShadowTy = getShadowTy(Ptrs); if (ClCheckAccessAddress) { @@ -4262,9 +4260,8 @@ struct 
MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRBuilder<> IRB(&I); Value *V = I.getArgOperand(0); Value *Ptr = I.getArgOperand(1); - const Align Alignment( - cast<ConstantInt>(I.getArgOperand(2))->getZExtValue()); - Value *Mask = I.getArgOperand(3); + const Align Alignment = I.getParamAlign(1).valueOrOne(); + Value *Mask = I.getArgOperand(2); Value *Shadow = getShadow(V); if (ClCheckAccessAddress) { @@ -4295,10 +4292,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void handleMaskedLoad(IntrinsicInst &I) { IRBuilder<> IRB(&I); Value *Ptr = I.getArgOperand(0); - const Align Alignment( - cast<ConstantInt>(I.getArgOperand(1))->getZExtValue()); - Value *Mask = I.getArgOperand(2); - Value *PassThru = I.getArgOperand(3); + const Align Alignment = I.getParamAlign(0).valueOrOne(); + Value *Mask = I.getArgOperand(1); + Value *PassThru = I.getArgOperand(2); if (ClCheckAccessAddress) { insertCheckShadowOf(Ptr, &I); diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp index 6141b6d..4ac1321 100644 --- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -272,7 +272,7 @@ static OverwriteResult isMaskedStoreOverwrite(const Instruction *KillingI, if (KillingII->getIntrinsicID() == Intrinsic::masked_store) { // Masks. // TODO: check that KillingII's mask is a superset of the DeadII's mask. - if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3)) + if (KillingII->getArgOperand(2) != DeadII->getArgOperand(2)) return OW_Unknown; } else if (KillingII->getIntrinsicID() == Intrinsic::vp_store) { // Masks. diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp index 2afa7b7..e30f306 100644 --- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp @@ -1017,14 +1017,14 @@ private: }; auto MaskOp = [](const IntrinsicInst *II) { if (II->getIntrinsicID() == Intrinsic::masked_load) - return II->getOperand(2); + return II->getOperand(1); if (II->getIntrinsicID() == Intrinsic::masked_store) - return II->getOperand(3); + return II->getOperand(2); llvm_unreachable("Unexpected IntrinsicInst"); }; auto ThruOp = [](const IntrinsicInst *II) { if (II->getIntrinsicID() == Intrinsic::masked_load) - return II->getOperand(3); + return II->getOperand(2); llvm_unreachable("Unexpected IntrinsicInst"); }; diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp index 42db424..72e1131 100644 --- a/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/llvm/lib/Transforms/Scalar/GVN.cpp @@ -2212,11 +2212,11 @@ bool GVNPass::processMaskedLoad(IntrinsicInst *I) { if (!DepInst || !Dep.isLocal() || !Dep.isDef()) return false; - Value *Mask = I->getOperand(2); - Value *Passthrough = I->getOperand(3); + Value *Mask = I->getOperand(1); + Value *Passthrough = I->getOperand(2); Value *StoreVal; - if (!match(DepInst, m_MaskedStore(m_Value(StoreVal), m_Value(), m_Value(), - m_Specific(Mask))) || + if (!match(DepInst, + m_MaskedStore(m_Value(StoreVal), m_Value(), m_Specific(Mask))) || StoreVal->getType() != I->getType()) return false; diff --git a/llvm/lib/Transforms/Scalar/GVNSink.cpp b/llvm/lib/Transforms/Scalar/GVNSink.cpp index b9534def..a06f832 100644 --- a/llvm/lib/Transforms/Scalar/GVNSink.cpp +++ b/llvm/lib/Transforms/Scalar/GVNSink.cpp @@ -430,6 +430,7 @@ public: case Instruction::FPTrunc: case Instruction::FPExt: case Instruction::PtrToInt: + case 
Instruction::PtrToAddr: case Instruction::IntToPtr: case Instruction::BitCast: case Instruction::AddrSpaceCast: diff --git a/llvm/lib/Transforms/Scalar/InferAlignment.cpp b/llvm/lib/Transforms/Scalar/InferAlignment.cpp index 995b803..39751c0 100644 --- a/llvm/lib/Transforms/Scalar/InferAlignment.cpp +++ b/llvm/lib/Transforms/Scalar/InferAlignment.cpp @@ -45,25 +45,20 @@ static bool tryToImproveAlign( switch (II->getIntrinsicID()) { case Intrinsic::masked_load: case Intrinsic::masked_store: { - int AlignOpIdx = II->getIntrinsicID() == Intrinsic::masked_load ? 1 : 2; - Value *PtrOp = II->getIntrinsicID() == Intrinsic::masked_load - ? II->getArgOperand(0) - : II->getArgOperand(1); + unsigned PtrOpIdx = II->getIntrinsicID() == Intrinsic::masked_load ? 0 : 1; + Value *PtrOp = II->getArgOperand(PtrOpIdx); Type *Type = II->getIntrinsicID() == Intrinsic::masked_load ? II->getType() : II->getArgOperand(0)->getType(); - Align OldAlign = - cast<ConstantInt>(II->getArgOperand(AlignOpIdx))->getAlignValue(); + Align OldAlign = II->getParamAlign(PtrOpIdx).valueOrOne(); Align PrefAlign = DL.getPrefTypeAlign(Type); Align NewAlign = Fn(PtrOp, OldAlign, PrefAlign); - if (NewAlign <= OldAlign || - NewAlign.value() > std::numeric_limits<uint32_t>().max()) + if (NewAlign <= OldAlign) return false; - Value *V = - ConstantInt::get(Type::getInt32Ty(II->getContext()), NewAlign.value()); - II->setOperand(AlignOpIdx, V); + II->addParamAttr(PtrOpIdx, + Attribute::getWithAlignment(II->getContext(), NewAlign)); return true; } default: diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp index 28ae4f0..9aaf6a5 100644 --- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -43,6 +43,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include <cassert> #include <utility> @@ -1872,6 +1873,51 @@ static void moveLCSSAPhis(BasicBlock *InnerExit, BasicBlock *InnerHeader, InnerLatch->replacePhiUsesWith(InnerLatch, OuterLatch); } +/// This deals with a corner case when an LCSSA phi node appears in a non-exit +/// block: the outer loop latch block does not need to be the exit block of the +/// inner loop. Consider a loop that was in LCSSA form, but then some +/// transformation like loop-unswitch comes along and creates an empty block, +/// where BB5 in this example is the outer loop latch block: +/// +/// BB4: +/// br label %BB5 +/// BB5: +/// %old.cond.lcssa = phi i16 [ %cond, %BB4 ] +/// br label %outer.header +/// +/// Interchange then brings it into LCSSA form again, resulting in this chain +/// of single-input phi nodes: +/// +/// BB4: +/// %new.cond.lcssa = phi i16 [ %cond, %BB3 ] +/// br label %BB5 +/// BB5: +/// %old.cond.lcssa = phi i16 [ %new.cond.lcssa, %BB4 ] +/// +/// The problem is that interchange can reorder blocks BB4 and BB5, placing the +/// use before the def if we don't check this. The solution is to simplify +/// (i.e., remove) LCSSA phi nodes if they appear in non-exit blocks. +/// +static void simplifyLCSSAPhis(Loop *OuterLoop, Loop *InnerLoop) { + BasicBlock *InnerLoopExit = InnerLoop->getExitBlock(); + BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch(); + + // Do not modify LCSSA phis where they actually belong, i.e. in exit blocks.
+ if (OuterLoopLatch == InnerLoopExit) + return; + + // Collect and remove phis in non-exit blocks if they have 1 input. + SmallVector<PHINode *, 8> Phis( + llvm::make_pointer_range(OuterLoopLatch->phis())); + for (PHINode *Phi : Phis) { + assert(Phi->getNumIncomingValues() == 1 && "Single input phi expected"); + LLVM_DEBUG(dbgs() << "Removing 1-input phi in non-exit block: " << *Phi + << "\n"); + Phi->replaceAllUsesWith(Phi->getIncomingValue(0)); + Phi->eraseFromParent(); + } +} + bool LoopInterchangeTransform::adjustLoopBranches() { LLVM_DEBUG(dbgs() << "adjustLoopBranches called\n"); std::vector<DominatorTree::UpdateType> DTUpdates; @@ -1882,6 +1928,9 @@ bool LoopInterchangeTransform::adjustLoopBranches() { assert(OuterLoopPreHeader != OuterLoop->getHeader() && InnerLoopPreHeader != InnerLoop->getHeader() && OuterLoopPreHeader && InnerLoopPreHeader && "Guaranteed by loop-simplify form"); + + simplifyLCSSAPhis(OuterLoop, InnerLoop); + // Ensure that both preheaders do not contain PHI nodes and have single // predecessors. This allows us to move them easily. We use // InsertPreHeaderForLoop to create an 'extra' preheader, if the existing diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp index 42d6680..146e7d1 100644 --- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp +++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp @@ -111,7 +111,7 @@ static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth, } // Translate a masked load intrinsic like -// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align, +// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, // <16 x i1> %mask, <16 x i32> %passthru) // to a chain of basic blocks, with loading element one-by-one if // the appropriate mask bit is set @@ -146,11 +146,10 @@ static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT) { Value *Ptr = CI->getArgOperand(0); - Value *Alignment = CI->getArgOperand(1); - Value *Mask = CI->getArgOperand(2); - Value *Src0 = CI->getArgOperand(3); + Value *Mask = CI->getArgOperand(1); + Value *Src0 = CI->getArgOperand(2); - const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue(); + const Align AlignVal = CI->getParamAlign(0).valueOrOne(); VectorType *VecType = cast<FixedVectorType>(CI->getType()); Type *EltTy = VecType->getElementType(); @@ -290,7 +289,7 @@ static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence, } // Translate a masked store intrinsic, like -// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align, +// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, // <16 x i1> %mask) // to a chain of basic blocks, that stores element one-by-one if // the appropriate mask bit is set @@ -320,10 +319,9 @@ static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence, bool &ModifiedDT) { Value *Src = CI->getArgOperand(0); Value *Ptr = CI->getArgOperand(1); - Value *Alignment = CI->getArgOperand(2); - Value *Mask = CI->getArgOperand(3); + Value *Mask = CI->getArgOperand(2); - const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue(); + const Align AlignVal = CI->getParamAlign(1).valueOrOne(); auto *VecType = cast<VectorType>(Src->getType()); Type *EltTy = VecType->getElementType(); @@ -472,9 +470,8 @@ static void scalarizeMaskedGather(const DataLayout &DL, bool HasBranchDivergence, CallInst *CI, DomTreeUpdater *DTU, bool &ModifiedDT) 
{ Value *Ptrs = CI->getArgOperand(0); - Value *Alignment = CI->getArgOperand(1); - Value *Mask = CI->getArgOperand(2); - Value *Src0 = CI->getArgOperand(3); + Value *Mask = CI->getArgOperand(1); + Value *Src0 = CI->getArgOperand(2); auto *VecType = cast<FixedVectorType>(CI->getType()); Type *EltTy = VecType->getElementType(); @@ -483,7 +480,7 @@ static void scalarizeMaskedGather(const DataLayout &DL, Instruction *InsertPt = CI; BasicBlock *IfBlock = CI->getParent(); Builder.SetInsertPoint(InsertPt); - MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue(); + Align AlignVal = CI->getParamAlign(0).valueOrOne(); Builder.SetCurrentDebugLocation(CI->getDebugLoc()); @@ -608,8 +605,7 @@ static void scalarizeMaskedScatter(const DataLayout &DL, DomTreeUpdater *DTU, bool &ModifiedDT) { Value *Src = CI->getArgOperand(0); Value *Ptrs = CI->getArgOperand(1); - Value *Alignment = CI->getArgOperand(2); - Value *Mask = CI->getArgOperand(3); + Value *Mask = CI->getArgOperand(2); auto *SrcFVTy = cast<FixedVectorType>(Src->getType()); @@ -623,7 +619,7 @@ static void scalarizeMaskedScatter(const DataLayout &DL, Builder.SetInsertPoint(InsertPt); Builder.SetCurrentDebugLocation(CI->getDebugLoc()); - MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue(); + Align AlignVal = CI->getParamAlign(1).valueOrOne(); unsigned VectorWidth = SrcFVTy->getNumElements(); // Shorten the way if the mask is a vector of constants. @@ -1125,8 +1121,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, case Intrinsic::masked_load: // Scalarize unsupported vector masked load if (TTI.isLegalMaskedLoad( - CI->getType(), - cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue(), + CI->getType(), CI->getParamAlign(0).valueOrOne(), cast<PointerType>(CI->getArgOperand(0)->getType()) ->getAddressSpace())) return false; @@ -1135,18 +1130,15 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, case Intrinsic::masked_store: if (TTI.isLegalMaskedStore( CI->getArgOperand(0)->getType(), - cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue(), + CI->getParamAlign(1).valueOrOne(), cast<PointerType>(CI->getArgOperand(1)->getType()) ->getAddressSpace())) return false; scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT); return true; case Intrinsic::masked_gather: { - MaybeAlign MA = - cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue(); + Align Alignment = CI->getParamAlign(0).valueOrOne(); Type *LoadTy = CI->getType(); - Align Alignment = DL.getValueOrABITypeAlignment(MA, - LoadTy->getScalarType()); if (TTI.isLegalMaskedGather(LoadTy, Alignment) && !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment)) return false; @@ -1154,11 +1146,8 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, return true; } case Intrinsic::masked_scatter: { - MaybeAlign MA = - cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue(); + Align Alignment = CI->getParamAlign(1).valueOrOne(); Type *StoreTy = CI->getArgOperand(0)->getType(); - Align Alignment = DL.getValueOrABITypeAlignment(MA, - StoreTy->getScalarType()); if (TTI.isLegalMaskedScatter(StoreTy, Alignment) && !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy), Alignment)) diff --git a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp index fa66a03..23e1243 100644 --- a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp +++ b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp @@ -227,6 +227,7 @@ static InstructionCost 
ComputeSpeculationCost(const Instruction *I, case Instruction::Call: case Instruction::BitCast: case Instruction::PtrToInt: + case Instruction::PtrToAddr: case Instruction::IntToPtr: case Instruction::AddrSpaceCast: case Instruction::FPToUI: diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp index 9693ae6..4947d03 100644 --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -20,6 +20,7 @@ #include "llvm/Analysis/ValueLatticeUtils.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/ConstantRange.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/Instructions.h" @@ -634,18 +635,10 @@ private: /// Merge \p MergeWithV into \p IV and push \p V to the worklist, if \p IV /// changes. bool mergeInValue(ValueLatticeElement &IV, Value *V, - ValueLatticeElement MergeWithV, + const ValueLatticeElement &MergeWithV, ValueLatticeElement::MergeOptions Opts = { /*MayIncludeUndef=*/false, /*CheckWiden=*/false}); - bool mergeInValue(Value *V, ValueLatticeElement MergeWithV, - ValueLatticeElement::MergeOptions Opts = { - /*MayIncludeUndef=*/false, /*CheckWiden=*/false}) { - assert(!V->getType()->isStructTy() && - "non-structs should use markConstant"); - return mergeInValue(ValueState[V], V, MergeWithV, Opts); - } - /// getValueState - Return the ValueLatticeElement object that corresponds to /// the value. This function handles the case when the value hasn't been seen /// yet by properly seeding constants etc. @@ -768,6 +761,7 @@ private: void handleCallArguments(CallBase &CB); void handleExtractOfWithOverflow(ExtractValueInst &EVI, const WithOverflowInst *WO, unsigned Idx); + bool isInstFullyOverDefined(Instruction &Inst); private: friend class InstVisitor<SCCPInstVisitor>; @@ -987,7 +981,7 @@ public: void trackValueOfArgument(Argument *A) { if (A->getType()->isStructTy()) return (void)markOverdefined(A); - mergeInValue(A, getArgAttributeVL(A)); + mergeInValue(ValueState[A], A, getArgAttributeVL(A)); } bool isStructLatticeConstant(Function *F, StructType *STy); @@ -1128,8 +1122,7 @@ bool SCCPInstVisitor::isStructLatticeConstant(Function *F, StructType *STy) { for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { const auto &It = TrackedMultipleRetVals.find(std::make_pair(F, i)); assert(It != TrackedMultipleRetVals.end()); - ValueLatticeElement LV = It->second; - if (!SCCPSolver::isConstant(LV)) + if (!SCCPSolver::isConstant(It->second)) return false; } return true; @@ -1160,7 +1153,7 @@ Constant *SCCPInstVisitor::getConstantOrNull(Value *V) const { std::vector<Constant *> ConstVals; auto *ST = cast<StructType>(V->getType()); for (unsigned I = 0, E = ST->getNumElements(); I != E; ++I) { - ValueLatticeElement LV = LVs[I]; + const ValueLatticeElement &LV = LVs[I]; ConstVals.push_back(SCCPSolver::isConstant(LV) ? 
getConstant(LV, ST->getElementType(I)) : UndefValue::get(ST->getElementType(I))); @@ -1225,7 +1218,7 @@ void SCCPInstVisitor::visitInstruction(Instruction &I) { } bool SCCPInstVisitor::mergeInValue(ValueLatticeElement &IV, Value *V, - ValueLatticeElement MergeWithV, + const ValueLatticeElement &MergeWithV, ValueLatticeElement::MergeOptions Opts) { if (IV.mergeIn(MergeWithV, Opts)) { pushUsersToWorkList(V); @@ -1264,7 +1257,7 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI, return; } - ValueLatticeElement BCValue = getValueState(BI->getCondition()); + const ValueLatticeElement &BCValue = getValueState(BI->getCondition()); ConstantInt *CI = getConstantInt(BCValue, BI->getCondition()->getType()); if (!CI) { // Overdefined condition variables, and branches on unfoldable constant @@ -1326,7 +1319,7 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI, // the target as executable. if (auto *IBR = dyn_cast<IndirectBrInst>(&TI)) { // Casts are folded by visitCastInst. - ValueLatticeElement IBRValue = getValueState(IBR->getAddress()); + const ValueLatticeElement &IBRValue = getValueState(IBR->getAddress()); BlockAddress *Addr = dyn_cast_or_null<BlockAddress>( getConstant(IBRValue, IBR->getAddress()->getType())); if (!Addr) { // Overdefined or unknown condition? @@ -1383,49 +1376,66 @@ bool SCCPInstVisitor::isEdgeFeasible(BasicBlock *From, BasicBlock *To) const { // 7. If a conditional branch has a value that is overdefined, make all // successors executable. void SCCPInstVisitor::visitPHINode(PHINode &PN) { - // If this PN returns a struct, just mark the result overdefined. - // TODO: We could do a lot better than this if code actually uses this. - if (PN.getType()->isStructTy()) - return (void)markOverdefined(&PN); - - if (getValueState(&PN).isOverdefined()) - return; // Quick exit - // Super-extra-high-degree PHI nodes are unlikely to ever be marked constant, // and slow us down a lot. Just mark them overdefined. if (PN.getNumIncomingValues() > 64) return (void)markOverdefined(&PN); - unsigned NumActiveIncoming = 0; + if (isInstFullyOverDefined(PN)) + return; + SmallVector<unsigned> FeasibleIncomingIndices; + for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { + if (!isEdgeFeasible(PN.getIncomingBlock(i), PN.getParent())) + continue; + FeasibleIncomingIndices.push_back(i); + } // Look at all of the executable operands of the PHI node. If any of them // are overdefined, the PHI becomes overdefined as well. If they are all // constant, and they agree with each other, the PHI becomes the identical // constant. If they are constant and don't agree, the PHI is a constant // range. If there are no executable operands, the PHI remains unknown. 
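The struct-typed branch added just below no longer pessimizes struct phis to overdefined wholesale: each struct element keeps its own lattice value (getStructValueState) and is merged independently across the feasible incoming edges, so one overdefined element no longer drags its siblings down. A toy model of the per-element join (illustrative only; the real ValueLatticeElement also tracks constant ranges):

    #include <optional>

    // Unset Const means unknown (bottom); Overdefined is top.
    struct ToyLattice {
      bool Overdefined = false;
      std::optional<int> Const;
    };

    ToyLattice join(const ToyLattice &A, const ToyLattice &B) {
      if (A.Overdefined || B.Overdefined)
        return {true, std::nullopt};
      if (!A.Const)
        return B; // unknown is the identity element of the join
      if (!B.Const)
        return A;
      // Disagreeing constants go straight to top here; the real solver
      // first tries to widen them into a constant range.
      return *A.Const == *B.Const ? A : ToyLattice{true, std::nullopt};
    }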
- ValueLatticeElement PhiState = getValueState(&PN); - for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { - if (!isEdgeFeasible(PN.getIncomingBlock(i), PN.getParent())) - continue; - - ValueLatticeElement IV = getValueState(PN.getIncomingValue(i)); - PhiState.mergeIn(IV); - NumActiveIncoming++; - if (PhiState.isOverdefined()) - break; + if (StructType *STy = dyn_cast<StructType>(PN.getType())) { + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { + ValueLatticeElement PhiState = getStructValueState(&PN, i); + if (PhiState.isOverdefined()) + continue; + for (unsigned j : FeasibleIncomingIndices) { + const ValueLatticeElement &IV = + getStructValueState(PN.getIncomingValue(j), i); + PhiState.mergeIn(IV); + if (PhiState.isOverdefined()) + break; + } + ValueLatticeElement &PhiStateRef = getStructValueState(&PN, i); + mergeInValue(PhiStateRef, &PN, PhiState, + ValueLatticeElement::MergeOptions().setMaxWidenSteps( + FeasibleIncomingIndices.size() + 1)); + PhiStateRef.setNumRangeExtensions( + std::max((unsigned)FeasibleIncomingIndices.size(), + PhiStateRef.getNumRangeExtensions())); + } + } else { + ValueLatticeElement PhiState = getValueState(&PN); + for (unsigned i : FeasibleIncomingIndices) { + const ValueLatticeElement &IV = getValueState(PN.getIncomingValue(i)); + PhiState.mergeIn(IV); + if (PhiState.isOverdefined()) + break; + } + // We allow up to 1 range extension per active incoming value and one + // additional extension. Note that we manually adjust the number of range + // extensions to match the number of active incoming values. This helps to + // limit multiple extensions caused by the same incoming value, if other + // incoming values are equal. + ValueLatticeElement &PhiStateRef = ValueState[&PN]; + mergeInValue(PhiStateRef, &PN, PhiState, + ValueLatticeElement::MergeOptions().setMaxWidenSteps( + FeasibleIncomingIndices.size() + 1)); + PhiStateRef.setNumRangeExtensions( + std::max((unsigned)FeasibleIncomingIndices.size(), + PhiStateRef.getNumRangeExtensions())); } - - // We allow up to 1 range extension per active incoming value and one - // additional extension. Note that we manually adjust the number of range - // extensions to match the number of active incoming values. This helps to - // limit multiple extensions caused by the same incoming value, if other - // incoming values are equal. 
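The widen-step budget mentioned above caps how many times a phi's lattice value may be extended before SCCP forces it to overdefined; both branches of the rewritten visitPHINode pass FeasibleIncomingIndices.size() + 1 as that cap. Without such a cap, a loop-carried phi could grow its range one step per solver visit; a minimal shape that exercises it (illustrative only):

    void count(bool (*Cond)()) {
      int I = 0;     // phi(0, I + 1): each merge nudges the upper bound by 1
      while (Cond())
        I = I + 1;   // capped widening jumps to overdefined instead
    }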
- mergeInValue(&PN, PhiState, - ValueLatticeElement::MergeOptions().setMaxWidenSteps( - NumActiveIncoming + 1)); - ValueLatticeElement &PhiStateRef = getValueState(&PN); - PhiStateRef.setNumRangeExtensions( - std::max(NumActiveIncoming, PhiStateRef.getNumRangeExtensions())); } void SCCPInstVisitor::visitReturnInst(ReturnInst &I) { @@ -1481,7 +1491,7 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) { } } - ValueLatticeElement OpSt = getValueState(I.getOperand(0)); + const ValueLatticeElement &OpSt = getValueState(I.getOperand(0)); if (OpSt.isUnknownOrUndef()) return; @@ -1496,9 +1506,9 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) { if (I.getDestTy()->isIntOrIntVectorTy() && I.getSrcTy()->isIntOrIntVectorTy() && I.getOpcode() != Instruction::BitCast) { - auto &LV = getValueState(&I); ConstantRange OpRange = OpSt.asConstantRange(I.getSrcTy(), /*UndefAllowed=*/false); + auto &LV = getValueState(&I); Type *DestTy = I.getDestTy(); ConstantRange Res = ConstantRange::getEmpty(DestTy->getScalarSizeInBits()); @@ -1516,19 +1526,24 @@ void SCCPInstVisitor::handleExtractOfWithOverflow(ExtractValueInst &EVI, const WithOverflowInst *WO, unsigned Idx) { Value *LHS = WO->getLHS(), *RHS = WO->getRHS(); - ValueLatticeElement L = getValueState(LHS); - ValueLatticeElement R = getValueState(RHS); + Type *Ty = LHS->getType(); + addAdditionalUser(LHS, &EVI); addAdditionalUser(RHS, &EVI); - if (L.isUnknownOrUndef() || R.isUnknownOrUndef()) - return; // Wait to resolve. - Type *Ty = LHS->getType(); + const ValueLatticeElement &L = getValueState(LHS); + if (L.isUnknownOrUndef()) + return; // Wait to resolve. ConstantRange LR = L.asConstantRange(Ty, /*UndefAllowed=*/false); + + const ValueLatticeElement &R = getValueState(RHS); + if (R.isUnknownOrUndef()) + return; // Wait to resolve. + ConstantRange RR = R.asConstantRange(Ty, /*UndefAllowed=*/false); if (Idx == 0) { ConstantRange Res = LR.binaryOp(WO->getBinaryOp(), RR); - mergeInValue(&EVI, ValueLatticeElement::getRange(Res)); + mergeInValue(ValueState[&EVI], &EVI, ValueLatticeElement::getRange(Res)); } else { assert(Idx == 1 && "Index can only be 0 or 1"); ConstantRange NWRegion = ConstantRange::makeGuaranteedNoWrapRegion( @@ -1560,7 +1575,7 @@ void SCCPInstVisitor::visitExtractValueInst(ExtractValueInst &EVI) { if (auto *WO = dyn_cast<WithOverflowInst>(AggVal)) return handleExtractOfWithOverflow(EVI, WO, i); ValueLatticeElement EltVal = getStructValueState(AggVal, i); - mergeInValue(getValueState(&EVI), &EVI, EltVal); + mergeInValue(ValueState[&EVI], &EVI, EltVal); } else { // Otherwise, must be extracting from an array. return (void)markOverdefined(&EVI); @@ -1616,14 +1631,18 @@ void SCCPInstVisitor::visitSelectInst(SelectInst &I) { if (ValueState[&I].isOverdefined()) return (void)markOverdefined(&I); - ValueLatticeElement CondValue = getValueState(I.getCondition()); + const ValueLatticeElement &CondValue = getValueState(I.getCondition()); if (CondValue.isUnknownOrUndef()) return; if (ConstantInt *CondCB = getConstantInt(CondValue, I.getCondition()->getType())) { Value *OpVal = CondCB->isZero() ? I.getFalseValue() : I.getTrueValue(); - mergeInValue(&I, getValueState(OpVal)); + const ValueLatticeElement &OpValState = getValueState(OpVal); + // Safety: ValueState[&I] doesn't invalidate OpValState since it is already + // in the map. 
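The "Safety:" comment above, and the many mergeInValue(ValueState[&I], &I, ...) rewrites around it, all guard the same hazard: ValueState is a DenseMap, getValueState()/operator[] may insert, and an insertion can rehash the table and invalidate any reference obtained earlier. A minimal sketch of the ordering discipline with a plain DenseMap (hypothetical keys, not SCCP code):

    #include "llvm/ADT/DenseMap.h"
    using namespace llvm;

    void safeOrder(DenseMap<int, int> &State) {
      // Hazard: int &Dst = State[1]; then State[2] = 0; the second access may
      // insert and rehash, leaving Dst dangling before it is written through.
      int Src = State.lookup(2); // lookup() never inserts; returns a copy
      State[1] = Src;            // the only potentially inserting access
    }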
+ assert(ValueState.contains(&I) && "&I is not in ValueState map."); + mergeInValue(ValueState[&I], &I, OpValState); return; } @@ -1721,7 +1740,7 @@ void SCCPInstVisitor::visitBinaryOperator(Instruction &I) { // being a special floating value. ValueLatticeElement NewV; NewV.markConstant(C, /*MayIncludeUndef=*/true); - return (void)mergeInValue(&I, NewV); + return (void)mergeInValue(ValueState[&I], &I, NewV); } } @@ -1741,7 +1760,7 @@ void SCCPInstVisitor::visitBinaryOperator(Instruction &I) { R = A.overflowingBinaryOp(BO->getOpcode(), B, OBO->getNoWrapKind()); else R = A.binaryOp(BO->getOpcode(), B); - mergeInValue(&I, ValueLatticeElement::getRange(R)); + mergeInValue(ValueState[&I], &I, ValueLatticeElement::getRange(R)); // TODO: Currently we do not exploit special values that produce something // better than overdefined with an overdefined operand for vector or floating @@ -1767,7 +1786,7 @@ void SCCPInstVisitor::visitCmpInst(CmpInst &I) { if (C) { ValueLatticeElement CV; CV.markConstant(C); - mergeInValue(&I, CV); + mergeInValue(ValueState[&I], &I, CV); return; } @@ -1802,7 +1821,7 @@ void SCCPInstVisitor::visitGetElementPtrInst(GetElementPtrInst &I) { Operands.reserve(I.getNumOperands()); for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i) { - ValueLatticeElement State = getValueState(I.getOperand(i)); + const ValueLatticeElement &State = getValueState(I.getOperand(i)); if (State.isUnknownOrUndef()) return; // Operands are not resolved yet. @@ -1881,14 +1900,13 @@ void SCCPInstVisitor::visitLoadInst(LoadInst &I) { if (ValueState[&I].isOverdefined()) return (void)markOverdefined(&I); - ValueLatticeElement PtrVal = getValueState(I.getOperand(0)); + const ValueLatticeElement &PtrVal = getValueState(I.getOperand(0)); if (PtrVal.isUnknownOrUndef()) return; // The pointer is not resolved yet! - ValueLatticeElement &IV = ValueState[&I]; - if (SCCPSolver::isConstant(PtrVal)) { Constant *Ptr = getConstant(PtrVal, I.getOperand(0)->getType()); + ValueLatticeElement &IV = ValueState[&I]; // load null is undefined. if (isa<ConstantPointerNull>(Ptr)) { @@ -1916,7 +1934,7 @@ void SCCPInstVisitor::visitLoadInst(LoadInst &I) { } // Fall back to metadata. - mergeInValue(&I, getValueFromMetadata(&I)); + mergeInValue(ValueState[&I], &I, getValueFromMetadata(&I)); } void SCCPInstVisitor::visitCallBase(CallBase &CB) { @@ -1944,7 +1962,7 @@ void SCCPInstVisitor::handleCallOverdefined(CallBase &CB) { return markOverdefined(&CB); // Can't handle struct args. if (A.get()->getType()->isMetadataTy()) continue; // Carried in CB, not allowed in Operands. - ValueLatticeElement State = getValueState(A); + const ValueLatticeElement &State = getValueState(A); if (State.isUnknownOrUndef()) return; // Operands are not resolved yet. @@ -1964,7 +1982,7 @@ void SCCPInstVisitor::handleCallOverdefined(CallBase &CB) { } // Fall back to metadata. 
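For context on the binary-operator handling above: ConstantRange::binaryOp conservatively folds the two operand ranges through the opcode, and the result is merged into the instruction's lattice value as a range. A small standalone example with the real API and toy values:

    #include "llvm/IR/ConstantRange.h"
    #include "llvm/IR/Instruction.h"
    using namespace llvm;

    ConstantRange addExample() {
      ConstantRange A(APInt(32, 0), APInt(32, 4));   // {0, 1, 2, 3}
      ConstantRange B(APInt(32, 10), APInt(32, 12)); // {10, 11}
      // Smallest range containing every pairwise sum: [10, 15) = {10, ..., 14}.
      return A.binaryOp(Instruction::Add, B);
    }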
- mergeInValue(&CB, getValueFromMetadata(&CB)); + mergeInValue(ValueState[&CB], &CB, getValueFromMetadata(&CB)); } void SCCPInstVisitor::handleCallArguments(CallBase &CB) { @@ -1992,10 +2010,11 @@ void SCCPInstVisitor::handleCallArguments(CallBase &CB) { mergeInValue(getStructValueState(&*AI, i), &*AI, CallArg, getMaxWidenStepsOpts()); } - } else - mergeInValue(&*AI, - getValueState(*CAI).intersect(getArgAttributeVL(&*AI)), - getMaxWidenStepsOpts()); + } else { + ValueLatticeElement CallArg = + getValueState(*CAI).intersect(getArgAttributeVL(&*AI)); + mergeInValue(ValueState[&*AI], &*AI, CallArg, getMaxWidenStepsOpts()); + } } } } @@ -2076,7 +2095,8 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { if (II->getIntrinsicID() == Intrinsic::vscale) { unsigned BitWidth = CB.getType()->getScalarSizeInBits(); const ConstantRange Result = getVScaleRange(II->getFunction(), BitWidth); - return (void)mergeInValue(II, ValueLatticeElement::getRange(Result)); + return (void)mergeInValue(ValueState[II], II, + ValueLatticeElement::getRange(Result)); } if (ConstantRange::isIntrinsicSupported(II->getIntrinsicID())) { @@ -2094,7 +2114,8 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { ConstantRange Result = ConstantRange::intrinsic(II->getIntrinsicID(), OpRanges); - return (void)mergeInValue(II, ValueLatticeElement::getRange(Result)); + return (void)mergeInValue(ValueState[II], II, + ValueLatticeElement::getRange(Result)); } } @@ -2121,10 +2142,25 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { return handleCallOverdefined(CB); // Not tracking this callee. // If so, propagate the return value of the callee into this call result. - mergeInValue(&CB, TFRVI->second, getMaxWidenStepsOpts()); + mergeInValue(ValueState[&CB], &CB, TFRVI->second, getMaxWidenStepsOpts()); } } +bool SCCPInstVisitor::isInstFullyOverDefined(Instruction &Inst) { + // For structure Type, we handle each member separately. + // A structure object won't be considered as overdefined when + // there is at least one member that is not overdefined. + if (StructType *STy = dyn_cast<StructType>(Inst.getType())) { + for (unsigned i = 0, e = STy->getNumElements(); i < e; ++i) { + if (!getStructValueState(&Inst, i).isOverdefined()) + return false; + } + return true; + } + + return getValueState(&Inst).isOverdefined(); +} + void SCCPInstVisitor::solve() { // Process the work lists until they are empty! while (!BBWorkList.empty() || !InstWorkList.empty()) { diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h index 7651ba1..3fed003 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h @@ -325,6 +325,8 @@ public: VPIRFlags Flags; if (Opcode == Instruction::Trunc) Flags = VPIRFlags::TruncFlagsTy(false, false); + else if (Opcode == Instruction::ZExt) + Flags = VPIRFlags::NonNegFlagsTy(false); return tryInsertInstruction( new VPWidenCastRecipe(Opcode, Op, ResultTy, Flags)); } diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 280eb20..1cc9173 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -1011,6 +1011,10 @@ public: /// \returns True if instruction \p I can be truncated to a smaller bitwidth /// for vectorization factor \p VF. 
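The guard added at the top of canTruncateToMinimalBitwidth just below covers a degenerate case: when a trunc's destination is already narrower than the minimal bit width the analysis computed, "truncating" it to the minimal width would actually widen its result, so the function now bails out. Schematically (hypothetical widths):

    // MinBW analysis says 16 bits suffice, but the trunc already produces i8:
    //   %t = trunc i32 %x to i8   ; result width 8 < MinBW 16
    // A trunc may only be narrowed to a width no larger than its destination:
    static bool truncFitsMinBW(unsigned DestBits, unsigned MinBW) {
      return DestBits >= MinBW; // 8 >= 16 is false, so bail out
    }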
bool canTruncateToMinimalBitwidth(Instruction *I, ElementCount VF) const { + // Truncs must truncate at most to their destination type. + if (isa_and_nonnull<TruncInst>(I) && MinBWs.contains(I) && + I->getType()->getScalarSizeInBits() < MinBWs.lookup(I)) + return false; return VF.isVector() && MinBWs.contains(I) && !isProfitableToScalarize(I, VF) && !isScalarAfterVectorization(I, VF); @@ -7192,7 +7196,8 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan( // TODO: Move to VPlan transform stage once the transition to the VPlan-based // cost model is complete for better cost estimates. VPlanTransforms::runPass(VPlanTransforms::unrollByUF, BestVPlan, BestUF); - VPlanTransforms::runPass(VPlanTransforms::materializeBuildVectors, BestVPlan); + VPlanTransforms::runPass(VPlanTransforms::materializePacksAndUnpacks, + BestVPlan); VPlanTransforms::runPass(VPlanTransforms::materializeBroadcasts, BestVPlan); VPlanTransforms::runPass(VPlanTransforms::replicateByVF, BestVPlan, BestVF); bool HasBranchWeights = diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index b62c8f1..3f18bd7 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -2242,8 +2242,49 @@ public: /// may not be necessary. bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const; bool isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy, - Align Alignment, const int64_t Diff, Value *Ptr0, - Value *PtrN, StridedPtrInfo &SPtrInfo) const; + Align Alignment, const int64_t Diff, + const size_t Sz) const; + + /// Return true if an array of scalar loads can be replaced with a strided + /// load (with constant stride). + /// + /// TODO: + /// It is possible that the load gets "widened". Suppose that originally each + /// load loads `k` bytes and `PointerOps` can be arranged as follows (`%s` is + /// constant): + /// %b + 0 * %s + 0 + /// %b + 0 * %s + 1 + /// %b + 0 * %s + 2 + /// ... + /// %b + 0 * %s + (w - 1) + /// + /// %b + 1 * %s + 0 + /// %b + 1 * %s + 1 + /// %b + 1 * %s + 2 + /// ... + /// %b + 1 * %s + (w - 1) + /// ... + /// + /// %b + (n - 1) * %s + 0 + /// %b + (n - 1) * %s + 1 + /// %b + (n - 1) * %s + 2 + /// ... + /// %b + (n - 1) * %s + (w - 1) + /// + /// In this case we will generate a strided load of type `<n x (k * w)>`. + /// + /// \param PointerOps list of pointer arguments of loads. + /// \param ElemTy original scalar type of loads. + /// \param Alignment alignment of the first load. + /// \param SortedIndices is the order of PointerOps as returned by + /// `sortPtrAccesses`. + /// \param Diff Pointer difference between the lowest and the highest pointer + /// in `PointerOps` as returned by `getPointersDiff`. + /// \param Ptr0 first pointer in `PointerOps`. + /// \param PtrN last pointer in `PointerOps`. + /// \param SPtrInfo If the function returns `true`, it also sets all the fields + /// of `SPtrInfo` necessary to generate the strided load later. + bool analyzeConstantStrideCandidate( + const ArrayRef<Value *> PointerOps, Type *ElemTy, Align Alignment, + const SmallVectorImpl<unsigned> &SortedIndices, const int64_t Diff, + Value *Ptr0, Value *PtrN, StridedPtrInfo &SPtrInfo) const; /// Return true if an array of scalar loads can be replaced with a strided /// load (with run-time stride).
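In loop form, the widened pattern described by the TODO above is n groups of w consecutive bytes, with consecutive groups a constant stride s apart; today each group is w separate scalar loads, and the TODO is to treat a whole group as one wide element and emit a single strided load of n such elements. Illustrative C++, names following the comment (not part of the patch):

    void stridedGroups(char *Dst, const char *B, long S, int N, int W) {
      for (int I = 0; I < N; ++I)     // n strided groups...
        for (int J = 0; J < W; ++J)   // ...of w consecutive bytes each
          Dst[I * W + J] = B[I * S + J];
    }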
@@ -5302,7 +5343,7 @@ private: unsigned &OpCnt = OrderedEntriesCount.try_emplace(TE, 0).first->getSecond(); EdgeInfo EI(TE, U.getOperandNo()); - if (!getScheduleCopyableData(EI, Op) && OpCnt < NumOps) + if (!getScheduleCopyableData(EI, Op)) continue; // Found copyable operand - continue. ++OpCnt; @@ -6849,9 +6890,8 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps, /// current graph (for masked gathers extra extractelement instructions /// might be required). bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy, - Align Alignment, const int64_t Diff, Value *Ptr0, - Value *PtrN, StridedPtrInfo &SPtrInfo) const { - const size_t Sz = PointerOps.size(); + Align Alignment, const int64_t Diff, + const size_t Sz) const { if (Diff % (Sz - 1) != 0) return false; @@ -6875,27 +6915,40 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy, return false; if (!TTI->isLegalStridedLoadStore(VecTy, Alignment)) return false; + return true; + } + return false; +} - // Iterate through all pointers and check if all distances are - // unique multiple of Dist. - SmallSet<int64_t, 4> Dists; - for (Value *Ptr : PointerOps) { - int64_t Dist = 0; - if (Ptr == PtrN) - Dist = Diff; - else if (Ptr != Ptr0) - Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE); - // If the strides are not the same or repeated, we can't - // vectorize. - if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second) - break; - } - if (Dists.size() == Sz) { - Type *StrideTy = DL->getIndexType(Ptr0->getType()); - SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride); - SPtrInfo.Ty = getWidenedType(ScalarTy, Sz); - return true; - } +bool BoUpSLP::analyzeConstantStrideCandidate( + const ArrayRef<Value *> PointerOps, Type *ScalarTy, Align Alignment, + const SmallVectorImpl<unsigned> &SortedIndices, const int64_t Diff, + Value *Ptr0, Value *PtrN, StridedPtrInfo &SPtrInfo) const { + const size_t Sz = PointerOps.size(); + if (!isStridedLoad(PointerOps, ScalarTy, Alignment, Diff, Sz)) + return false; + + int64_t Stride = Diff / static_cast<int64_t>(Sz - 1); + + // Iterate through all pointers and check if all distances are + // unique multiple of Dist. + SmallSet<int64_t, 4> Dists; + for (Value *Ptr : PointerOps) { + int64_t Dist = 0; + if (Ptr == PtrN) + Dist = Diff; + else if (Ptr != Ptr0) + Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE); + // If the strides are not the same or repeated, we can't + // vectorize. + if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second) + break; + } + if (Dists.size() == Sz) { + Type *StrideTy = DL->getIndexType(Ptr0->getType()); + SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride); + SPtrInfo.Ty = getWidenedType(ScalarTy, Sz); + return true; } return false; } @@ -6995,8 +7048,8 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads( Align Alignment = cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()]) ->getAlign(); - if (isStridedLoad(PointerOps, ScalarTy, Alignment, *Diff, Ptr0, PtrN, - SPtrInfo)) + if (analyzeConstantStrideCandidate(PointerOps, ScalarTy, Alignment, Order, + *Diff, Ptr0, PtrN, SPtrInfo)) return LoadsState::StridedVectorize; } if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) || @@ -10493,8 +10546,11 @@ static bool tryToFindDuplicates(SmallVectorImpl<Value *> &VL, PoisonValue::get(UniqueValues.front()->getType())); // Check that extended with poisons/copyable operations are still valid // for vectorization (div/rem are not allowed). 
- if (!S.areInstructionsWithCopyableElements() && - !getSameOpcode(PaddedUniqueValues, TLI).valid()) { + if ((!S.areInstructionsWithCopyableElements() && + !getSameOpcode(PaddedUniqueValues, TLI).valid()) || + (S.areInstructionsWithCopyableElements() && S.isMulDivLikeOp() && + (S.getMainOp()->isIntDivRem() || S.getMainOp()->isFPDivRem() || + isa<CallInst>(S.getMainOp())))) { LLVM_DEBUG(dbgs() << "SLP: Scalar used twice in bundle.\n"); ReuseShuffleIndices.clear(); return false; diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 0e0b042..fed04eb 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -407,6 +407,10 @@ public: VPBasicBlock *getParent() { return Parent; } const VPBasicBlock *getParent() const { return Parent; } + /// \return the VPRegionBlock which the recipe belongs to. + VPRegionBlock *getRegion(); + const VPRegionBlock *getRegion() const; + /// The method which generates the output IR instructions that correspond to /// this VPRecipe, thereby "executing" the VPlan. virtual void execute(VPTransformState &State) = 0; @@ -1003,6 +1007,11 @@ public: /// Creates a fixed-width vector containing all operands. The number of /// operands matches the vector element count. BuildVector, + /// Extracts all lanes from its (non-scalable) vector operand. This is an + /// abstract VPInstruction whose single defined VPValue represents VF + /// scalars extracted from a vector, to be replaced by VF ExtractElement + /// VPInstructions. + Unpack, /// Compute the final result of a AnyOf reduction with select(cmp(),x,y), /// where one of (x,y) is loop invariant, and both x and y are integer type. ComputeAnyOfResult, @@ -2711,6 +2720,15 @@ public: return R && classof(R); } + static inline bool classof(const VPValue *VPV) { + const VPRecipeBase *R = VPV->getDefiningRecipe(); + return R && classof(R); + } + + static inline bool classof(const VPSingleDefRecipe *R) { + return classof(static_cast<const VPRecipeBase *>(R)); + } + /// Generate the reduction in the loop. void execute(VPTransformState &State) override; @@ -3096,6 +3114,9 @@ public: /// Returns true if this expression contains recipes that may have side /// effects. bool mayHaveSideEffects() const; + + /// Returns true if the result of this VPExpressionRecipe is a single-scalar. + bool isSingleScalar() const; }; /// VPPredInstPHIRecipe is a recipe for generating the phi nodes needed when @@ -4075,6 +4096,14 @@ public: } }; +inline VPRegionBlock *VPRecipeBase::getRegion() { + return getParent()->getParent(); +} + +inline const VPRegionBlock *VPRecipeBase::getRegion() const { + return getParent()->getParent(); +} + /// VPlan models a candidate for vectorization, encoding various decisions take /// to produce efficient output IR, including which branches, basic-blocks and /// output IR instructions to generate, and their cost. 
VPlan holds a diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp index f413c63..80a2e4b 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp @@ -110,6 +110,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) { case VPInstruction::AnyOf: case VPInstruction::BuildStructVector: case VPInstruction::BuildVector: + case VPInstruction::Unpack: return SetResultTyFromOp(); case VPInstruction::ExtractLane: return inferScalarType(R->getOperand(1)); @@ -377,7 +378,7 @@ bool VPDominatorTree::properlyDominates(const VPRecipeBase *A, #ifndef NDEBUG auto GetReplicateRegion = [](VPRecipeBase *R) -> VPRegionBlock * { - auto *Region = dyn_cast_or_null<VPRegionBlock>(R->getParent()->getParent()); + VPRegionBlock *Region = R->getRegion(); if (Region && Region->isReplicator()) { assert(Region->getNumSuccessors() == 1 && Region->getNumPredecessors() == 1 && "Expected SESE region!"); diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h index d8203e2..b5b98c6 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h +++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h @@ -388,6 +388,12 @@ m_ExtractLastElement(const Op0_t &Op0) { return m_VPInstruction<VPInstruction::ExtractLastElement>(Op0); } +template <typename Op0_t, typename Op1_t> +inline VPInstruction_match<Instruction::ExtractElement, Op0_t, Op1_t> +m_ExtractElement(const Op0_t &Op0, const Op1_t &Op1) { + return m_VPInstruction<Instruction::ExtractElement>(Op0, Op1); +} + template <typename Op0_t> inline VPInstruction_match<VPInstruction::ExtractLastLanePerPart, Op0_t> m_ExtractLastLanePerPart(const Op0_t &Op0) { diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 7a98c75..1f1b42b 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -515,6 +515,7 @@ unsigned VPInstruction::getNumOperandsForOpcode(unsigned Opcode) { case VPInstruction::ExtractPenultimateElement: case VPInstruction::FirstActiveLane: case VPInstruction::Not: + case VPInstruction::Unpack: return 1; case Instruction::ICmp: case Instruction::FCmp: @@ -1246,6 +1247,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const { case VPInstruction::StepVector: case VPInstruction::ReductionStartVector: case VPInstruction::VScale: + case VPInstruction::Unpack: return false; default: return true; @@ -1290,7 +1292,8 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const { case VPInstruction::PtrAdd: return Op == getOperand(0) || vputils::onlyFirstLaneUsed(this); case VPInstruction::WidePtrAdd: - return Op == getOperand(0); + // WidePtrAdd supports scalar and vector base addresses. 
+ return false; case VPInstruction::ComputeAnyOfResult: case VPInstruction::ComputeFindIVResult: return Op == getOperand(1); @@ -1417,6 +1420,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent, case VPInstruction::ResumeForEpilogue: O << "resume-for-epilogue"; break; + case VPInstruction::Unpack: + O << "unpack"; + break; default: O << Instruction::getOpcodeName(getOpcode()); } @@ -2352,7 +2358,7 @@ bool VPWidenIntOrFpInductionRecipe::isCanonical() const { return false; auto *StepC = dyn_cast<ConstantInt>(getStepValue()->getLiveInIRValue()); auto *StartC = dyn_cast<ConstantInt>(getStartValue()->getLiveInIRValue()); - auto *CanIV = getParent()->getParent()->getCanonicalIV(); + auto *CanIV = getRegion()->getCanonicalIV(); return StartC && StartC->isZero() && StepC && StepC->isOne() && getScalarType() == CanIV->getScalarType(); } @@ -2888,6 +2894,13 @@ bool VPExpressionRecipe::mayHaveSideEffects() const { return false; } +bool VPExpressionRecipe::isSingleScalar() const { + // Cannot use vputils::isSingleScalar(), because all external operands + // of the expression will be live-ins while bundled. + return isa<VPReductionRecipe>(ExpressionRecipes.back()) && + !isa<VPPartialReductionRecipe>(ExpressionRecipes.back()); +} + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, @@ -3076,7 +3089,7 @@ static void scalarizeInstruction(const Instruction *Instr, State.AC->registerAssumption(II); assert( - (RepRecipe->getParent()->getParent() || + (RepRecipe->getRegion() || !RepRecipe->getParent()->getPlan()->getVectorLoopRegion() || all_of(RepRecipe->operands(), [](VPValue *Op) { return Op->isDefinedOutsideLoopRegions(); })) && @@ -3149,7 +3162,17 @@ static bool isUsedByLoadStoreAddress(const VPUser *V) { while (!WorkList.empty()) { auto *Cur = dyn_cast<VPSingleDefRecipe>(WorkList.pop_back_val()); - if (!Cur || !Seen.insert(Cur).second || isa<VPBlendRecipe>(Cur)) + if (!Cur || !Seen.insert(Cur).second) + continue; + + auto *Blend = dyn_cast<VPBlendRecipe>(Cur); + // Skip blends that use V only through a compare by checking if any incoming + // value was already visited. + if (Blend && none_of(seq<unsigned>(0, Blend->getNumIncomingValues()), + [&](unsigned I) { + return Seen.contains( + Blend->getIncomingValue(I)->getDefiningRecipe()); + })) continue; for (VPUser *U : Cur->users()) { @@ -3170,7 +3193,13 @@ static bool isUsedByLoadStoreAddress(const VPUser *V) { } } - append_range(WorkList, cast<VPSingleDefRecipe>(Cur)->users()); + // The legacy cost model only supports scalarization loads/stores with phi + // addresses, if the phi is directly used as load/store address. Don't + // traverse further for Blends. + if (Blend) + continue; + + append_range(WorkList, Cur->users()); } return false; } @@ -3268,7 +3297,7 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF, to_vector(operands()), VF); // If the recipe is not predicated (i.e. not in a replicate region), return // the scalar cost. Otherwise handle predicated cost. - if (!getParent()->getParent()->isReplicator()) + if (!getRegion()->isReplicator()) return ScalarCost; // Account for the phi nodes that we will create. @@ -3284,7 +3313,7 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF, case Instruction::Store: { // TODO: See getMemInstScalarizationCost for how to handle replicating and // predicated cases. 
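Much of the mechanical churn in this file and in VPlanTransforms.cpp below replaces the double hop getParent()->getParent() with the VPRecipeBase::getRegion() accessor introduced in VPlan.h above; the two spellings are equivalent, including the null result for recipes sitting in a flat, region-less plan. A typical use (hypothetical helper; getRegion() and isReplicator() are the patch's own APIs, and VPlan.h is a local header under llvm/lib/Transforms/Vectorize):

    static bool inReplicateRegion(const VPRecipeBase *R) {
      const VPRegionBlock *Region = R->getRegion(); // may be nullptr
      return Region && Region->isReplicator();
    }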
- const VPRegionBlock *ParentRegion = getParent()->getParent(); + const VPRegionBlock *ParentRegion = getRegion(); if (ParentRegion && ParentRegion->isReplicator()) break; diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index cae9aee8..688a013 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -106,7 +106,7 @@ bool VPlanTransforms::tryToConvertVPInstructionsToVPRecipes( return false; NewRecipe = new VPWidenIntrinsicRecipe( *CI, getVectorIntrinsicIDForCall(CI, &TLI), - {Ingredient.op_begin(), Ingredient.op_end() - 1}, CI->getType(), + drop_end(Ingredient.operands()), CI->getType(), CI->getDebugLoc()); } else if (SelectInst *SI = dyn_cast<SelectInst>(Inst)) { NewRecipe = new VPWidenSelectRecipe(*SI, Ingredient.operands()); @@ -356,8 +356,7 @@ static VPRegionBlock *createReplicateRegion(VPReplicateRecipe *PredRecipe, // Replace predicated replicate recipe with a replicate recipe without a // mask but in the replicate region. auto *RecipeWithoutMask = new VPReplicateRecipe( - PredRecipe->getUnderlyingInstr(), - make_range(PredRecipe->op_begin(), std::prev(PredRecipe->op_end())), + PredRecipe->getUnderlyingInstr(), drop_end(PredRecipe->operands()), PredRecipe->isSingleScalar(), nullptr /*Mask*/, *PredRecipe); auto *Pred = Plan.createVPBasicBlock(Twine(RegionName) + ".if", RecipeWithoutMask); @@ -939,7 +938,7 @@ static void recursivelyDeleteDeadRecipes(VPValue *V) { continue; if (!isDeadRecipe(*R)) continue; - WorkList.append(R->op_begin(), R->op_end()); + append_range(WorkList, R->operands()); R->eraseFromParent(); } } @@ -1224,6 +1223,13 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) { return; } + uint64_t Idx; + if (match(&R, m_ExtractElement(m_BuildVector(), m_ConstantInt(Idx)))) { + auto *BuildVector = cast<VPInstruction>(R.getOperand(0)); + Def->replaceAllUsesWith(BuildVector->getOperand(Idx)); + return; + } + if (auto *Phi = dyn_cast<VPPhi>(Def)) { if (Phi->getNumOperands() == 1) Phi->replaceAllUsesWith(Phi->getOperand(0)); @@ -1472,11 +1478,8 @@ static bool optimizeVectorInductionWidthForTCAndVFUF(VPlan &Plan, if (!Plan.getVectorLoopRegion()) return false; - if (!Plan.getTripCount()->isLiveIn()) - return false; - auto *TC = dyn_cast_if_present<ConstantInt>( - Plan.getTripCount()->getUnderlyingValue()); - if (!TC || !BestVF.isFixed()) + const APInt *TC; + if (!BestVF.isFixed() || !match(Plan.getTripCount(), m_APInt(TC))) return false; // Calculate the minimum power-of-2 bit width that can fit the known TC, VF @@ -1489,7 +1492,7 @@ static bool optimizeVectorInductionWidthForTCAndVFUF(VPlan &Plan, return std::max<unsigned>(PowerOf2Ceil(MaxVal.getActiveBits()), 8); }; unsigned NewBitWidth = - ComputeBitWidth(TC->getValue(), BestVF.getKnownMinValue() * BestUF); + ComputeBitWidth(*TC, BestVF.getKnownMinValue() * BestUF); LLVMContext &Ctx = Plan.getContext(); auto *NewIVTy = IntegerType::get(Ctx, NewBitWidth); @@ -1858,8 +1861,8 @@ static bool hoistPreviousBeforeFORUsers(VPFirstOrderRecurrencePHIRecipe *FOR, return nullptr; VPRegionBlock *EnclosingLoopRegion = HoistCandidate->getParent()->getEnclosingLoopRegion(); - assert((!HoistCandidate->getParent()->getParent() || - HoistCandidate->getParent()->getParent() == EnclosingLoopRegion) && + assert((!HoistCandidate->getRegion() || + HoistCandidate->getRegion() == EnclosingLoopRegion) && "CFG in VPlan should still be flat, without replicate regions"); // Hoist candidate was already 
visited, no need to hoist. if (!Visited.insert(HoistCandidate).second) @@ -2006,7 +2009,7 @@ struct VPCSEDenseMapInfo : public DenseMapInfo<VPSingleDefRecipe *> { .Case<VPWidenIntrinsicRecipe>([](auto *I) { return std::make_pair(true, I->getVectorIntrinsicID()); }) - .Case<VPVectorPointerRecipe>([](auto *I) { + .Case<VPVectorPointerRecipe, VPPredInstPHIRecipe>([](auto *I) { // For recipes that do not directly map to LLVM IR instructions, // assign opcodes after the last VPInstruction opcode (which is also // after the last IR Instruction opcode), based on the VPDefID. @@ -2083,6 +2086,15 @@ struct VPCSEDenseMapInfo : public DenseMapInfo<VPSingleDefRecipe *> { LFlags->getPredicate() != cast<VPRecipeWithIRFlags>(R)->getPredicate()) return false; + // Recipes in replicate regions implicitly depend on the region's + // predicate. If either recipe is in a replicate region, only consider them + // equal if both have the same parent. + const VPRegionBlock *RegionL = L->getRegion(); + const VPRegionBlock *RegionR = R->getRegion(); + if (((RegionL && RegionL->isReplicator()) || + (RegionR && RegionR->isReplicator())) && + L->getParent() != R->getParent()) + return false; const VPlan *Plan = L->getParent()->getPlan(); VPTypeAnalysis TypeInfo(*Plan); return TypeInfo.inferScalarType(L) == TypeInfo.inferScalarType(R); @@ -2898,7 +2910,7 @@ void VPlanTransforms::replaceSymbolicStrides( // evolution. auto CanUseVersionedStride = [&Plan](VPUser &U, unsigned) { auto *R = cast<VPRecipeBase>(&U); - return R->getParent()->getParent() || + return R->getRegion() || R->getParent() == Plan.getVectorLoopRegion()->getSinglePredecessor(); }; ValueToSCEVMapTy RewriteMap; @@ -3780,7 +3792,7 @@ void VPlanTransforms::materializeBackedgeTakenCount(VPlan &Plan, BTC->replaceAllUsesWith(TCMO); } -void VPlanTransforms::materializeBuildVectors(VPlan &Plan) { +void VPlanTransforms::materializePacksAndUnpacks(VPlan &Plan) { if (Plan.hasScalarVFOnly()) return; @@ -3803,8 +3815,7 @@ void VPlanTransforms::materializeBuildVectors(VPlan &Plan) { continue; auto *DefR = cast<VPRecipeWithIRFlags>(&R); auto UsesVectorOrInsideReplicateRegion = [DefR, LoopRegion](VPUser *U) { - VPRegionBlock *ParentRegion = - cast<VPRecipeBase>(U)->getParent()->getParent(); + VPRegionBlock *ParentRegion = cast<VPRecipeBase>(U)->getRegion(); return !U->usesScalars(DefR) || ParentRegion != LoopRegion; }; if ((isa<VPReplicateRecipe>(DefR) && @@ -3829,6 +3840,49 @@ void VPlanTransforms::materializeBuildVectors(VPlan &Plan) { }); } } + + // Create explicit VPInstructions to convert vectors to scalars. The current + // implementation is conservative: it may miss creating Unpacks for defs + // that may or may not produce vector values. TODO: introduce Unpacks + // speculatively and remove them later if they are known to operate on + // scalar values. + for (VPBasicBlock *VPBB : VPBBsInsideLoopRegion) { + for (VPRecipeBase &R : make_early_inc_range(*VPBB)) { + if (isa<VPReplicateRecipe, VPInstruction, VPScalarIVStepsRecipe, + VPDerivedIVRecipe, VPCanonicalIVPHIRecipe>(&R)) + continue; + for (VPValue *Def : R.definedValues()) { + // Skip recipes that are single-scalar or only have their first lane + // used. + // TODO: The Defs skipped here may or may not be vector values. + // Introduce Unpacks, and remove them later, if they are guaranteed to + // produce scalar values. + if (vputils::isSingleScalar(Def) || vputils::onlyFirstLaneUsed(Def)) + continue; + + // At the moment, we create unpacks only for scalar users outside + // replicate regions.
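(For reference, a worked instance of the ExtractElement-of-BuildVector fold added to simplifyRecipe in an earlier hunk of this file; value names and the lane index are illustrative:

// Before:  %v = BuildVector %s0, %s1, %s2, %s3
//          %e = ExtractElement %v, 2
// After:   all uses of %e are rewritten to %s2 directly; if %v then has no
//          remaining users, the existing dead-recipe cleanup removes it.)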
Recipes inside replicate regions still extract the + // required lanes implicitly. + // TODO: Remove once replicate regions are unrolled completely. + auto IsCandidateUnpackUser = [Def](VPUser *U) { + VPRegionBlock *ParentRegion = cast<VPRecipeBase>(U)->getRegion(); + return U->usesScalars(Def) && + (!ParentRegion || !ParentRegion->isReplicator()); + }; + if (none_of(Def->users(), IsCandidateUnpackUser)) + continue; + + auto *Unpack = new VPInstruction(VPInstruction::Unpack, {Def}); + if (R.isPhi()) + Unpack->insertBefore(*VPBB, VPBB->getFirstNonPhi()); + else + Unpack->insertAfter(&R); + Def->replaceUsesWithIf(Unpack, + [&IsCandidateUnpackUser](VPUser &U, unsigned) { + return IsCandidateUnpackUser(&U); + }); + } + } + } } void VPlanTransforms::materializeVectorTripCount(VPlan &Plan, diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h index 5a8a2bb..b28559b 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h @@ -325,9 +325,10 @@ struct VPlanTransforms { static void materializeBackedgeTakenCount(VPlan &Plan, VPBasicBlock *VectorPH); - /// Add explicit Build[Struct]Vector recipes that combine multiple scalar - /// values into single vectors. - static void materializeBuildVectors(VPlan &Plan); + /// Add explicit Build[Struct]Vector recipes to Pack multiple scalar values + /// into vectors and Unpack recipes to extract scalars from vectors as + /// needed. + static void materializePacksAndUnpacks(VPlan &Plan); /// Materialize VF and VFxUF to be computed explicitly using VPInstructions. static void materializeVFAndVFxUF(VPlan &Plan, VPBasicBlock *VectorPH, diff --git a/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp b/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp index 5aeda3e..cfd1a74 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp @@ -465,10 +465,21 @@ void VPlanTransforms::unrollByUF(VPlan &Plan, unsigned UF) { /// Create a single-scalar clone of \p DefR (must be a VPReplicateRecipe or /// VPInstruction) for lane \p Lane. Use \p Def2LaneDefs to look up scalar /// definitions for operands of \p DefR. -static VPRecipeWithIRFlags * +static VPValue * cloneForLane(VPlan &Plan, VPBuilder &Builder, Type *IdxTy, VPRecipeWithIRFlags *DefR, VPLane Lane, const DenseMap<VPValue *, SmallVector<VPValue *>> &Def2LaneDefs) { + VPValue *Op; + if (match(DefR, m_VPInstruction<VPInstruction::Unpack>(m_VPValue(Op)))) { + auto LaneDefs = Def2LaneDefs.find(Op); + if (LaneDefs != Def2LaneDefs.end()) + return LaneDefs->second[Lane.getKnownLane()]; + + VPValue *Idx = + Plan.getOrAddLiveIn(ConstantInt::get(IdxTy, Lane.getKnownLane())); + return Builder.createNaryOp(Instruction::ExtractElement, {Op, Idx}); + } + // Collect the operands at Lane, creating extracts as needed. SmallVector<VPValue *> NewOps; for (VPValue *Op : DefR->operands()) { @@ -480,6 +491,10 @@ cloneForLane(VPlan &Plan, VPBuilder &Builder, Type *IdxTy, continue; } if (Lane.getKind() == VPLane::Kind::ScalableLast) { + // Look through mandatory Unpack.
+ [[maybe_unused]] bool Matched = + match(Op, m_VPInstruction<VPInstruction::Unpack>(m_VPValue(Op))); + assert(Matched && "original op must have been Unpack"); NewOps.push_back( Builder.createNaryOp(VPInstruction::ExtractLastElement, {Op})); continue; } @@ -547,7 +562,8 @@ void VPlanTransforms::replicateByVF(VPlan &Plan, ElementCount VF) { (isa<VPReplicateRecipe>(&R) && cast<VPReplicateRecipe>(&R)->isSingleScalar()) || (isa<VPInstruction>(&R) && - !cast<VPInstruction>(&R)->doesGeneratePerAllLanes())) + !cast<VPInstruction>(&R)->doesGeneratePerAllLanes() && + cast<VPInstruction>(&R)->getOpcode() != VPInstruction::Unpack)) continue; auto *DefR = cast<VPRecipeWithIRFlags>(&R); diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp index 8b1b0e5..10801c0 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp @@ -113,12 +113,12 @@ bool vputils::isUniformAcrossVFsAndUFs(VPValue *V) { return TypeSwitch<const VPRecipeBase *, bool>(R) .Case<VPDerivedIVRecipe>([](const auto *R) { return true; }) .Case<VPReplicateRecipe>([](const auto *R) { - // Loads and stores that are uniform across VF lanes are handled by - // VPReplicateRecipe.IsUniform. They are also uniform across UF parts if - // all their operands are invariant. - // TODO: Further relax the restrictions. + // Be conservative about side-effects, except for the + // known-side-effecting assumes and stores, which we know will be + // uniform. return R->isSingleScalar() && - (isa<LoadInst, StoreInst>(R->getUnderlyingValue())) && + (!R->mayHaveSideEffects() || + isa<AssumeInst, StoreInst>(R->getUnderlyingInstr())) && all_of(R->operands(), isUniformAcrossVFsAndUFs); }) .Case<VPInstruction>([](const auto *VPI) { diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.h b/llvm/lib/Transforms/Vectorize/VPlanUtils.h index cf95ac0..840a5b9 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUtils.h +++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.h @@ -64,7 +64,7 @@ inline bool isSingleScalar(const VPValue *VPV) { return true; if (auto *Rep = dyn_cast<VPReplicateRecipe>(VPV)) { - const VPRegionBlock *RegionOfR = Rep->getParent()->getParent(); + const VPRegionBlock *RegionOfR = Rep->getRegion(); // Don't consider recipes in replicate regions as uniform yet; their first // lane cannot be accessed when executing the replicate region for other // lanes. @@ -84,6 +84,12 @@ inline bool isSingleScalar(const VPValue *VPV) { return VPI->isSingleScalar() || VPI->isVectorToScalar() || (PreservesUniformity(VPI->getOpcode()) && all_of(VPI->operands(), isSingleScalar)); + if (isa<VPPartialReductionRecipe>(VPV)) + return false; + if (isa<VPReductionRecipe>(VPV)) + return true; + if (auto *Expr = dyn_cast<VPExpressionRecipe>(VPV)) + return Expr->isSingleScalar(); // VPExpandSCEVRecipes must be placed in the entry and are always uniform.
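Condensing the Unpack handling spread across the materializePacksAndUnpacks and cloneForLane hunks above: during replicateByVF, an Unpack either reuses the scalar definition already materialized for that lane or degenerates into a plain extract. A minimal sketch using only calls that appear in this patch (the helper name is invented for illustration):

static VPValue *lowerUnpackForLane(
    VPlan &Plan, VPBuilder &Builder, Type *IdxTy, VPValue *Op, VPLane Lane,
    const DenseMap<VPValue *, SmallVector<VPValue *>> &Def2LaneDefs) {
  // Reuse the per-lane scalar def if one was already created for Op.
  auto LaneDefs = Def2LaneDefs.find(Op);
  if (LaneDefs != Def2LaneDefs.end())
    return LaneDefs->second[Lane.getKnownLane()];
  // Otherwise extract the requested lane from the vector value.
  VPValue *Idx =
      Plan.getOrAddLiveIn(ConstantInt::get(IdxTy, Lane.getKnownLane()));
  return Builder.createNaryOp(Instruction::ExtractElement, {Op, Idx});
}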
return isa<VPExpandSCEVRecipe>(VPV); diff --git a/llvm/lib/WindowsDriver/MSVCPaths.cpp b/llvm/lib/WindowsDriver/MSVCPaths.cpp index 1fc8974..09468da 100644 --- a/llvm/lib/WindowsDriver/MSVCPaths.cpp +++ b/llvm/lib/WindowsDriver/MSVCPaths.cpp @@ -259,9 +259,7 @@ static bool getSystemRegistryString(const char *keyPath, const char *valueName, #endif // _WIN32 } -namespace llvm { - -const char *archToWindowsSDKArch(Triple::ArchType Arch) { +const char *llvm::archToWindowsSDKArch(Triple::ArchType Arch) { switch (Arch) { case Triple::ArchType::x86: return "x86"; @@ -277,7 +275,7 @@ const char *archToWindowsSDKArch(Triple::ArchType Arch) { } } -const char *archToLegacyVCArch(Triple::ArchType Arch) { +const char *llvm::archToLegacyVCArch(Triple::ArchType Arch) { switch (Arch) { case Triple::ArchType::x86: // x86 is default in legacy VC toolchains. @@ -295,7 +293,7 @@ const char *archToLegacyVCArch(Triple::ArchType Arch) { } } -const char *archToDevDivInternalArch(Triple::ArchType Arch) { +const char *llvm::archToDevDivInternalArch(Triple::ArchType Arch) { switch (Arch) { case Triple::ArchType::x86: return "i386"; @@ -311,8 +309,9 @@ const char *archToDevDivInternalArch(Triple::ArchType Arch) { } } -bool appendArchToWindowsSDKLibPath(int SDKMajor, SmallString<128> LibPath, - Triple::ArchType Arch, std::string &path) { +bool llvm::appendArchToWindowsSDKLibPath(int SDKMajor, SmallString<128> LibPath, + Triple::ArchType Arch, + std::string &path) { if (SDKMajor >= 8) { sys::path::append(LibPath, archToWindowsSDKArch(Arch)); } else { @@ -336,10 +335,11 @@ bool appendArchToWindowsSDKLibPath(int SDKMajor, SmallString<128> LibPath, return true; } -std::string getSubDirectoryPath(SubDirectoryType Type, ToolsetLayout VSLayout, - const std::string &VCToolChainPath, - Triple::ArchType TargetArch, - StringRef SubdirParent) { +std::string llvm::getSubDirectoryPath(SubDirectoryType Type, + ToolsetLayout VSLayout, + const std::string &VCToolChainPath, + Triple::ArchType TargetArch, + StringRef SubdirParent) { const char *SubdirName; const char *IncludeName; switch (VSLayout) { @@ -390,19 +390,22 @@ std::string getSubDirectoryPath(SubDirectoryType Type, ToolsetLayout VSLayout, return std::string(Path); } -bool useUniversalCRT(ToolsetLayout VSLayout, const std::string &VCToolChainPath, - Triple::ArchType TargetArch, vfs::FileSystem &VFS) { +bool llvm::useUniversalCRT(ToolsetLayout VSLayout, + const std::string &VCToolChainPath, + Triple::ArchType TargetArch, vfs::FileSystem &VFS) { SmallString<128> TestPath(getSubDirectoryPath( SubDirectoryType::Include, VSLayout, VCToolChainPath, TargetArch)); sys::path::append(TestPath, "stdlib.h"); return !VFS.exists(TestPath); } -bool getWindowsSDKDir(vfs::FileSystem &VFS, std::optional<StringRef> WinSdkDir, - std::optional<StringRef> WinSdkVersion, - std::optional<StringRef> WinSysRoot, std::string &Path, - int &Major, std::string &WindowsSDKIncludeVersion, - std::string &WindowsSDKLibVersion) { +bool llvm::getWindowsSDKDir(vfs::FileSystem &VFS, + std::optional<StringRef> WinSdkDir, + std::optional<StringRef> WinSdkVersion, + std::optional<StringRef> WinSysRoot, + std::string &Path, int &Major, + std::string &WindowsSDKIncludeVersion, + std::string &WindowsSDKLibVersion) { // Trust /winsdkdir and /winsdkversion if present. 
if (getWindowsSDKDirViaCommandLine(VFS, WinSdkDir, WinSdkVersion, WinSysRoot, Path, Major, WindowsSDKIncludeVersion)) { @@ -460,11 +463,11 @@ bool getWindowsSDKDir(vfs::FileSystem &VFS, std::optional<StringRef> WinSdkDir, return false; } -bool getUniversalCRTSdkDir(vfs::FileSystem &VFS, - std::optional<StringRef> WinSdkDir, - std::optional<StringRef> WinSdkVersion, - std::optional<StringRef> WinSysRoot, - std::string &Path, std::string &UCRTVersion) { +bool llvm::getUniversalCRTSdkDir(vfs::FileSystem &VFS, + std::optional<StringRef> WinSdkDir, + std::optional<StringRef> WinSdkVersion, + std::optional<StringRef> WinSysRoot, + std::string &Path, std::string &UCRTVersion) { // If /winsdkdir is passed, use it as location for the UCRT too. // FIXME: Should there be a dedicated /ucrtdir to override /winsdkdir? int Major; @@ -491,11 +494,11 @@ bool getUniversalCRTSdkDir(vfs::FileSystem &VFS, return getWindows10SDKVersionFromPath(VFS, Path, UCRTVersion); } -bool findVCToolChainViaCommandLine(vfs::FileSystem &VFS, - std::optional<StringRef> VCToolsDir, - std::optional<StringRef> VCToolsVersion, - std::optional<StringRef> WinSysRoot, - std::string &Path, ToolsetLayout &VSLayout) { +bool llvm::findVCToolChainViaCommandLine( + vfs::FileSystem &VFS, std::optional<StringRef> VCToolsDir, + std::optional<StringRef> VCToolsVersion, + std::optional<StringRef> WinSysRoot, std::string &Path, + ToolsetLayout &VSLayout) { // Don't validate the input; trust the value supplied by the user. // The primary motivation is to prevent unnecessary file and registry access. if (VCToolsDir || WinSysRoot) { @@ -518,8 +521,9 @@ bool findVCToolChainViaCommandLine(vfs::FileSystem &VFS, return false; } -bool findVCToolChainViaEnvironment(vfs::FileSystem &VFS, std::string &Path, - ToolsetLayout &VSLayout) { +bool llvm::findVCToolChainViaEnvironment(vfs::FileSystem &VFS, + std::string &Path, + ToolsetLayout &VSLayout) { // These variables are typically set by vcvarsall.bat // when launching a developer command prompt. if (std::optional<std::string> VCToolsInstallDir = @@ -627,9 +631,9 @@ bool findVCToolChainViaEnvironment(vfs::FileSystem &VFS, std::string &Path, return false; } -bool findVCToolChainViaSetupConfig(vfs::FileSystem &VFS, - std::optional<StringRef> VCToolsVersion, - std::string &Path, ToolsetLayout &VSLayout) { +bool llvm::findVCToolChainViaSetupConfig( + vfs::FileSystem &VFS, std::optional<StringRef> VCToolsVersion, + std::string &Path, ToolsetLayout &VSLayout) { #if !defined(USE_MSVC_SETUP_API) return false; #else @@ -724,7 +728,8 @@ bool findVCToolChainViaSetupConfig(vfs::FileSystem &VFS, #endif } -bool findVCToolChainViaRegistry(std::string &Path, ToolsetLayout &VSLayout) { +bool llvm::findVCToolChainViaRegistry(std::string &Path, + ToolsetLayout &VSLayout) { std::string VSInstallPath; if (getSystemRegistryString(R"(SOFTWARE\Microsoft\VisualStudio\$VERSION)", "InstallDir", VSInstallPath, nullptr) || @@ -744,5 +749,3 @@ bool findVCToolChainViaRegistry(std::string &Path, ToolsetLayout &VSLayout) { } return false; } - -} // namespace llvm
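The MSVCPaths.cpp changes above are mechanical: the function definitions move out of a reopened namespace llvm { ... } block and become llvm::-qualified instead. The usual C++ rationale applies: with a qualified definition, a signature that drifts out of sync with the header declaration is a hard compile error rather than a silent new overload. A minimal illustration with invented names:

namespace demo {
int parse(int Flags); // declaration, as it would appear in a header
}

// Reopening the namespace: a mismatched signature compiles and silently
// declares a second, unrelated overload.
// namespace demo { int parse(long Flags) { return 0; } }

// Qualified definition: the same mismatch is rejected because no matching
// declaration exists.
// int demo::parse(long Flags) { return 0; } // error: no matching declaration
int demo::parse(int Flags) { return Flags; } // OK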